"""Initialize the distributed services"""# pylint: disable=line-too-longimportatexitimportgcimportmultiprocessingasmpimportosimportqueueimportsysimporttimeimporttracebackfromenumimportEnumfrom..importutilsfrom..baseimportdgl_warning,DGLErrorfrom.importrpcfrom.constantsimportMAX_QUEUE_SIZEfrom.kvstoreimportclose_kvstore,init_kvstorefrom.roleimportinit_rolefrom.rpc_clientimportconnect_to_serverSAMPLER_POOL=NoneNUM_SAMPLER_WORKERS=0INITIALIZED=Falsedefset_initialized(value=True):"""Set the initialized state of rpc"""globalINITIALIZEDINITIALIZED=valuedefget_sampler_pool():"""Return the sampler pool and num_workers"""returnSAMPLER_POOL,NUM_SAMPLER_WORKERSdef_init_rpc(ip_config,num_servers,max_queue_size,role,num_threads,group_id,):"""This init function is called in the worker processes."""try:utils.set_num_threads(num_threads)ifos.environ.get("DGL_DIST_MODE","standalone")!="standalone":connect_to_server(ip_config,num_servers,max_queue_size,group_id)init_role(role)init_kvstore(ip_config,num_servers,role)exceptExceptionase:print(e,flush=True)traceback.print_exc()raiseeclassMpCommand(Enum):"""Enum class for multiprocessing command"""INIT_RPC=0# Not used in the task queueSET_COLLATE_FN=1CALL_BARRIER=2DELETE_COLLATE_FN=3CALL_COLLATE_FN=4CALL_FN_ALL_WORKERS=5FINALIZE_POOL=6definit_process(rpc_config,mp_contexts):"""Work loop in the worker"""try:_init_rpc(*rpc_config)keep_polling=Truedata_queue,task_queue,barrier=mp_contextscollate_fn_dict={}whilekeep_polling:try:# Follow https://github.com/pytorch/pytorch/blob/d57ce8cf8989c0b737e636d8d7abe16c1f08f70b/torch/utils/data/_utils/worker.py#L260command,args=task_queue.get(timeout=5)exceptqueue.Empty:continueifcommand==MpCommand.SET_COLLATE_FN:dataloader_name,func=argscollate_fn_dict[dataloader_name]=funcelifcommand==MpCommand.CALL_BARRIER:barrier.wait()elifcommand==MpCommand.DELETE_COLLATE_FN:(dataloader_name,)=argsdelcollate_fn_dict[dataloader_name]elifcommand==MpCommand.CALL_COLLATE_FN:dataloader_name,collate_args=argsdata_queue.put((dataloader_name,collate_fn_dict[dataloader_name](collate_args),))elifcommand==MpCommand.CALL_FN_ALL_WORKERS:func,func_args=argsfunc(func_args)elifcommand==MpCommand.FINALIZE_POOL:_exit()keep_polling=Falseelse:raiseException("Unknown command")exceptExceptionase:traceback.print_exc()raiseeclassCustomPool:"""Customized worker pool"""def__init__(self,num_workers,rpc_config):""" Customized worker pool init function """ctx=mp.get_context("spawn")self.num_workers=num_workers# As pool could be used by any number of dataloaders, queues# should be able to take infinite elements to avoid dead lock.self.queue_size=0self.result_queue=ctx.Queue(self.queue_size)self.results={}# key is dataloader name, value is fetched batch.self.task_queues=[]self.process_list=[]self.current_proc_id=0self.cache_result_dict={}self.barrier=ctx.Barrier(num_workers)for_inrange(num_workers):task_queue=ctx.Queue(self.queue_size)self.task_queues.append(task_queue)proc=ctx.Process(target=init_process,args=(rpc_config,(self.result_queue,task_queue,self.barrier),),)proc.daemon=Trueproc.start()self.process_list.append(proc)defset_collate_fn(self,func,dataloader_name):"""Set collate function in subprocess"""foriinrange(self.num_workers):self.task_queues[i].put((MpCommand.SET_COLLATE_FN,(dataloader_name,func)))self.results[dataloader_name]=[]defsubmit_task(self,dataloader_name,args):"""Submit task to workers"""# Round 


class CustomPool:
    """Customized worker pool"""

    def __init__(self, num_workers, rpc_config):
        """Customized worker pool init function"""
        ctx = mp.get_context("spawn")
        self.num_workers = num_workers
        # As the pool could be used by any number of dataloaders, queues
        # should be able to take an unbounded number of elements to avoid
        # deadlock (a maxsize of 0 means unlimited).
        self.queue_size = 0
        self.result_queue = ctx.Queue(self.queue_size)
        self.results = {}  # key is dataloader name, value is fetched batch.
        self.task_queues = []
        self.process_list = []
        self.current_proc_id = 0
        self.cache_result_dict = {}
        self.barrier = ctx.Barrier(num_workers)
        for _ in range(num_workers):
            task_queue = ctx.Queue(self.queue_size)
            self.task_queues.append(task_queue)
            proc = ctx.Process(
                target=init_process,
                args=(
                    rpc_config,
                    (self.result_queue, task_queue, self.barrier),
                ),
            )
            proc.daemon = True
            proc.start()
            self.process_list.append(proc)

    def set_collate_fn(self, func, dataloader_name):
        """Set collate function in subprocess"""
        for i in range(self.num_workers):
            self.task_queues[i].put(
                (MpCommand.SET_COLLATE_FN, (dataloader_name, func))
            )
        self.results[dataloader_name] = []

    def submit_task(self, dataloader_name, args):
        """Submit task to workers in a round-robin manner"""
        self.task_queues[self.current_proc_id].put(
            (MpCommand.CALL_COLLATE_FN, (dataloader_name, args))
        )
        self.current_proc_id = (self.current_proc_id + 1) % self.num_workers

    def submit_task_to_all_workers(self, func, args):
        """Submit task to all workers"""
        for i in range(self.num_workers):
            self.task_queues[i].put(
                (MpCommand.CALL_FN_ALL_WORKERS, (func, args))
            )

    def get_result(self, dataloader_name, timeout=1800):
        """Get result from result queue"""
        if dataloader_name not in self.results:
            raise DGLError(
                f"Got result from an unknown dataloader {dataloader_name}."
            )
        while len(self.results[dataloader_name]) == 0:
            dl_name, data = self.result_queue.get(timeout=timeout)
            self.results[dl_name].append(data)
        return self.results[dataloader_name].pop(0)

    def delete_collate_fn(self, dataloader_name):
        """Delete collate function"""
        for i in range(self.num_workers):
            self.task_queues[i].put(
                (MpCommand.DELETE_COLLATE_FN, (dataloader_name,))
            )
        del self.results[dataloader_name]

    def call_barrier(self):
        """Call barrier at all workers"""
        for i in range(self.num_workers):
            self.task_queues[i].put((MpCommand.CALL_BARRIER, tuple()))

    def close(self):
        """Close worker pool"""
        for i in range(self.num_workers):
            self.task_queues[i].put(
                (MpCommand.FINALIZE_POOL, tuple()), block=False
            )
        time.sleep(0.5)  # Fix for early Python versions

    def join(self):
        """Join the close process of worker pool"""
        for i in range(self.num_workers):
            self.process_list[i].join()
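

# A minimal sketch of driving the pool end to end (illustrative only, never
# called by this module). The ``rpc_config`` tuple mirrors the one built in
# ``initialize`` below; the ip-config path and dataloader name are
# hypothetical placeholders.
def _example_pool_usage():
    pool = CustomPool(
        num_workers=2,
        rpc_config=("ip_config.txt", 1, MAX_QUEUE_SIZE, "sampler", 1, 0),
    )
    pool.set_collate_fn(sorted, "loader-0")  # any picklable callable works
    pool.submit_task("loader-0", [3, 1, 2])  # round-robins to worker 0
    batch = pool.get_result("loader-0")  # blocks until the worker replies
    pool.delete_collate_fn("loader-0")
    pool.close()
    pool.join()
    return batch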


def initialize(
    ip_config,
    max_queue_size=MAX_QUEUE_SIZE,
    net_type=None,
    num_worker_threads=1,
    use_graphbolt=False,
):
    """Initialize DGL's distributed module

    This function initializes DGL's distributed module. It acts differently
    in server and client modes. In the server mode, it runs the server code
    and never returns. In the client mode, it builds connections with servers
    for communication and creates worker processes for distributed sampling.

    Parameters
    ----------
    ip_config : str
        File path of ip_config file
    max_queue_size : int
        Maximal size (bytes) of client queue buffer (~20 GB by default).
        Note that the 20 GB is just an upper bound; DGL uses zero-copy and
        will not allocate 20 GB of memory at once.
    net_type : str, optional
        [Deprecated] Networking type, can be 'socket' only.
    num_worker_threads : int
        The number of OMP threads in each sampler process.
    use_graphbolt : bool, optional
        Whether to use GraphBolt for distributed training.

    Note
    ----
    Users have to invoke this API before any of DGL's distributed APIs and
    framework-specific distributed APIs. For example, when used with PyTorch,
    users have to invoke this function before PyTorch's
    `torch.distributed.init_process_group`.
    """
    print(
        f"Initialize the distributed services with graphbolt: {use_graphbolt}"
    )
    if net_type is not None:
        dgl_warning(
            "net_type is deprecated and will be removed in a future release."
        )
    if os.environ.get("DGL_ROLE", "client") == "server":
        from .dist_graph import DistGraphServer

        assert (
            os.environ.get("DGL_SERVER_ID") is not None
        ), "Please define DGL_SERVER_ID to run DistGraph server"
        assert (
            os.environ.get("DGL_IP_CONFIG") is not None
        ), "Please define DGL_IP_CONFIG to run DistGraph server"
        assert (
            os.environ.get("DGL_NUM_SERVER") is not None
        ), "Please define DGL_NUM_SERVER to run DistGraph server"
        assert (
            os.environ.get("DGL_NUM_CLIENT") is not None
        ), "Please define DGL_NUM_CLIENT to run DistGraph server"
        assert (
            os.environ.get("DGL_CONF_PATH") is not None
        ), "Please define DGL_CONF_PATH to run DistGraph server"
        formats = os.environ.get("DGL_GRAPH_FORMAT", "csc").split(",")
        formats = [f.strip() for f in formats]
        rpc.reset()
        serv = DistGraphServer(
            int(os.environ.get("DGL_SERVER_ID")),
            os.environ.get("DGL_IP_CONFIG"),
            int(os.environ.get("DGL_NUM_SERVER")),
            int(os.environ.get("DGL_NUM_CLIENT")),
            os.environ.get("DGL_CONF_PATH"),
            graph_format=formats,
            use_graphbolt=use_graphbolt,
        )
        serv.start()
        sys.exit()
    else:
        num_workers = int(os.environ.get("DGL_NUM_SAMPLER", 0))
        num_servers = int(os.environ.get("DGL_NUM_SERVER", 1))
        group_id = int(os.environ.get("DGL_GROUP_ID", 0))
        rpc.reset()
        global SAMPLER_POOL
        global NUM_SAMPLER_WORKERS
        is_standalone = (
            os.environ.get("DGL_DIST_MODE", "standalone") == "standalone"
        )
        if num_workers > 0 and not is_standalone:
            SAMPLER_POOL = CustomPool(
                num_workers,
                (
                    ip_config,
                    num_servers,
                    max_queue_size,
                    "sampler",
                    num_worker_threads,
                    group_id,
                ),
            )
        else:
            SAMPLER_POOL = None
        NUM_SAMPLER_WORKERS = num_workers
        if not is_standalone:
            assert (
                num_servers is not None and num_servers > 0
            ), "The number of servers per machine must be specified with a positive number."
            connect_to_server(
                ip_config,
                num_servers,
                max_queue_size,
                group_id=group_id,
            )
        init_role("default")
        init_kvstore(ip_config, num_servers, "default")
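

# A minimal sketch of the trainer-side call order (illustrative only, not
# part of the public API). It assumes the process was started by DGL's launch
# tool, which sets the DGL_* environment variables, and that PyTorch is the
# training framework; "ip_config.txt" is a hypothetical path.
def _example_trainer_setup():
    import torch.distributed as dist  # assumed available on trainers

    # Per the Note in the docstring above, DGL's distributed module must be
    # initialized before the framework's process group.
    initialize("ip_config.txt")
    dist.init_process_group(backend="gloo")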


def finalize_client():
    """Release resources of this client."""
    if os.environ.get("DGL_DIST_MODE", "standalone") != "standalone":
        rpc.finalize_sender()
        rpc.finalize_receiver()


def _exit():
    exit_client()
    time.sleep(1)


def finalize_worker():
    """Finalize workers

    Python's multiprocessing pool will not call the atexit function when it
    is closed.
    """
    global SAMPLER_POOL
    if SAMPLER_POOL is not None:
        SAMPLER_POOL.close()


def join_finalize_worker():
    """Join the worker close process"""
    global SAMPLER_POOL
    if SAMPLER_POOL is not None:
        SAMPLER_POOL.join()
        SAMPLER_POOL = None


def is_initialized():
    """Is RPC initialized?"""
    return INITIALIZED


def _shutdown_servers():
    set_initialized(False)
    # Send ShutDownRequest to servers.
    if rpc.get_rank() == 0:  # Only client_0 issues this command.
        req = rpc.ShutDownRequest(rpc.get_rank())
        for server_id in range(rpc.get_num_server()):
            rpc.send_request(server_id, req)


def exit_client():
    """Trainer exits

    This function is called automatically when a Python process exits.
    Normally, the training script does not need to invoke this function at
    the end.

    In the case that the training script needs to initialize the distributed
    module multiple times (so far, this is needed in the unit tests), the
    training script needs to call `exit_client` before calling `initialize`
    again.
    """
    # Only the client with rank 0 will send a shutdown request to the servers.
    print(
        "Client[{}] in group[{}] is exiting...".format(
            rpc.get_rank(), rpc.get_group_id()
        )
    )
    # Finalizing workers should happen earlier than the barrier, and it is
    # non-blocking.
    finalize_worker()
    # Collect data such as DistTensor before exit.
    gc.collect()
    if os.environ.get("DGL_DIST_MODE", "standalone") != "standalone":
        rpc.client_barrier()
        _shutdown_servers()
        finalize_client()
    join_finalize_worker()
    close_kvstore()
    atexit.unregister(exit_client)
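

# A minimal sketch of re-initializing the distributed module within a single
# process (illustrative only; per the ``exit_client`` docstring this is so
# far only needed in unit tests). "ip_config.txt" is a hypothetical path.
def _example_reinitialize():
    initialize("ip_config.txt")
    # ... training or test code ...
    exit_client()  # must run before initialize() is called again
    initialize("ip_config.txt")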