#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import json
from contextlib import contextmanager
import collections
import logging
import math
import os
import random
import re
import shutil
import subprocess
import sys
import tempfile
import textwrap
import time
from typing import (
    Union,
    Callable,
    List,
    Dict,
    Optional,
    Any,
    Tuple,
    Generator,
    Iterator,
)

from pyspark import cloudpickle
from pyspark.resource.information import ResourceInformation
from pyspark.sql import DataFrame, SparkSession
from pyspark.taskcontext import BarrierTaskContext
from pyspark.ml.torch.log_communication import (  # type: ignore
    LogStreamingClient,
    LogStreamingServer,
)
from pyspark.ml.dl_util import FunctionPickler


def _get_resources(session: SparkSession) -> Dict[str, ResourceInformation]:
    resources: Dict[str, ResourceInformation] = {}
    try:
        resources = session.sparkContext.resources
    except Exception:
        resources = session._client._resources()  # type: ignore[attr-defined]
    return resources


def _get_conf(spark: SparkSession, key: str, default_value: str) -> str:
    """Get the conf "key" from the given spark session,
    or return the default value if the conf is not set.

    Parameters
    ----------
    spark : :class:`SparkSession`
        The :class:`SparkSession` for the distributor.
    key : str
        string for conf name
    default_value : str
        default value for the conf value for the given key

    Returns
    -------
    str
        Returns the string value that corresponds to the conf
    """
    value = spark.conf.get(key, default_value)
    assert value is not None
    return value


# TODO(SPARK-41589): will move the functions and tests to an external file
#   once we are in agreement about which functions should be in utils.py
def _get_conf_boolean(spark: SparkSession, key: str, default_value: str) -> bool:
    value = _get_conf(spark=spark, key=key, default_value=default_value)
    value = value.lower()
    assert value in ["true", "false"]
    return value == "true"


def _get_logger(name: str) -> logging.Logger:
    """
    Gets a logger by name, or creates and configures it for the first time.
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)
    # If the logger (or the root logger) already has handlers, skip configuration.
    if not logger.handlers and not logging.getLogger().handlers:
        handler = logging.StreamHandler(sys.stderr)
        logger.addHandler(handler)
    return logger


def _get_gpus_owned(context: Union[SparkSession, BarrierTaskContext]) -> List[str]:
    """Gets the GPU addresses that Spark assigned to the calling task or session.

    Parameters
    ----------
    context : :class:`SparkSession` or :class:`BarrierTaskContext`
        The :class:`SparkSession` or :class:`BarrierTaskContext` that has GPUs available.

    Returns
    -------
    list
        The correct mapping of addresses to workers.

    Raises
    ------
    ValueError
        Raised if the input addresses were not found.
    """
    CUDA_VISIBLE_DEVICES = "CUDA_VISIBLE_DEVICES"
    pattern = re.compile("^[1-9][0-9]*|0$")
    if isinstance(context, BarrierTaskContext):
        addresses = context.resources()["gpu"].addresses
    else:
        addresses = _get_resources(context)["gpu"].addresses
    if any(not pattern.match(address) for address in addresses):
        raise ValueError(
            f"Found GPU addresses {addresses} which "
            "are not all in the correct format "
            "for CUDA_VISIBLE_DEVICES, which requires "
            "integers with no zero padding."
        )
    if CUDA_VISIBLE_DEVICES in os.environ:
        gpu_indices = list(map(int, addresses))
        gpu_list = os.environ[CUDA_VISIBLE_DEVICES].split(",")
        gpu_owned = [gpu_list[i] for i in gpu_indices]
        return gpu_owned
    return addresses


SPARK_PARTITION_ARROW_DATA_FILE = "SPARK_PARTITION_ARROW_DATA_FILE"
SPARK_DATAFRAME_SCHEMA_FILE = "SPARK_DATAFRAME_SCHEMA_FILE"


class Distributor:
    """
    The parent class for TorchDistributor. This class shouldn't be instantiated directly.
    """

    def __init__(
        self,
        num_processes: int = 1,
        local_mode: bool = True,
        use_gpu: bool = True,
        ssl_conf: Optional[str] = None,
    ):
        from pyspark.sql.utils import is_remote

        self.is_remote = is_remote()
        self.spark = SparkSession.active()

        # indicate whether the server side is local mode
        self.is_spark_local_master = False
        # Refer to 'org.apache.spark.util.Utils#isLocalMaster'
        master = _get_conf(self.spark, "spark.master", "")
        if master == "local" or master.startswith("local["):
            self.is_spark_local_master = True

        self.logger = _get_logger(self.__class__.__name__)
        self.num_processes = num_processes
        self.local_mode = local_mode
        self.use_gpu = use_gpu
        self.num_tasks = self._get_num_tasks()
        self.ssl_conf = ssl_conf

    def _create_input_params(self) -> Dict[str, Any]:
        input_params = self.__dict__.copy()
        for unneeded_param in [
            "spark",
            "ssl_conf",
            "logger",
            "is_remote",
            "is_spark_local_master",
        ]:
            del input_params[unneeded_param]
        return input_params

    def _get_num_tasks(self) -> int:
        """
        Returns the number of Spark tasks to use for distributed training

        Returns
        -------
        int
            The number of Spark tasks to use for distributed training

        Raises
        ------
        RuntimeError
            Raised when the SparkConf was misconfigured.
        """
        if self.use_gpu:
            if not self.local_mode:
                key = "spark.task.resource.gpu.amount"
                task_gpu_amount = int(_get_conf(self.spark, key, "0"))
                if task_gpu_amount < 1:
                    raise RuntimeError(f"'{key}' was unset, so gpu usage is unavailable.")
                # TODO(SPARK-41916): Address situation when spark.task.resource.gpu.amount > 1
                return math.ceil(self.num_processes / task_gpu_amount)
            else:
                key = "spark.driver.resource.gpu.amount"
                if "gpu" not in _get_resources(self.spark):
                    raise RuntimeError("GPUs were unable to be found on the driver.")
                num_available_gpus = int(_get_conf(self.spark, key, "0"))
                if num_available_gpus == 0:
                    raise RuntimeError(
                        "GPU resources were not configured properly on the driver."
                    )
                if self.num_processes > num_available_gpus:
                    self.logger.warning(
                        "'num_processes' cannot be set to a value greater than the number of "
                        f"available GPUs on the driver, which is {num_available_gpus}. "
                        "'num_processes' was reset to be equal to the number of available GPUs.",
                    )
                    self.num_processes = num_available_gpus
        return self.num_processes

    def _validate_input_params(self) -> None:
        if self.num_processes <= 0:
            raise ValueError("num_processes has to be a positive integer")

    def _check_encryption(self) -> None:
        """Checks to see if the user requires encryption of data.
        If required, throws an exception since we don't support that.

        Raises
        ------
        RuntimeError
            Thrown when the user requires ssl encryption or when the user initializes
            the Distributor parent class.
        """
        if not hasattr(self, "ssl_conf"):
            raise RuntimeError(
                "Distributor doesn't have this functionality. Use TorchDistributor instead."
            )
        is_ssl_enabled = _get_conf_boolean(self.spark, "spark.ssl.enabled", "false")
        ignore_ssl = _get_conf_boolean(self.spark, self.ssl_conf, "false")  # type: ignore
        if is_ssl_enabled:
            name = self.__class__.__name__
            if ignore_ssl:
                self.logger.warning(
                    textwrap.dedent(
                        f"""
                        This cluster has TLS encryption enabled;
                        however, {name} does not
                        support data encryption in transit.
                        The Spark configuration
                        '{self.ssl_conf}' has been set to
                        'true' to override this
                        configuration and use {name} anyway. Please
                        note this will cause model
                        parameters and possibly training data to
                        be sent between nodes unencrypted.
                        """,
                    )
                )
                return
            raise RuntimeError(
                textwrap.dedent(
                    f"""
                    This cluster has TLS encryption enabled;
                    however, {name} does not support
                    data encryption in transit. To override
                    this configuration and use {name}
                    anyway, you may set '{self.ssl_conf}'
                    to 'true' in the Spark configuration. Please note this
                    will cause model parameters and possibly training
                    data to be sent between nodes unencrypted.
                    """
                )
            )


class TorchDistributor(Distributor):
    """
    A class to support distributed training on PyTorch and PyTorch Lightning using PySpark.

    .. versionadded:: 3.4.0

    .. versionchanged:: 3.5.0
        Supports Spark Connect.

    Parameters
    ----------
    num_processes : int, optional
        An integer that determines how many different concurrent
        tasks are allowed. We expect spark.task.gpus = 1 for GPU-enabled training. Default
        should be 1; we don't want to invoke multiple cores/gpus without explicit mention.
    local_mode : bool, optional
        A boolean that determines whether we are using the driver
        node for training. Default should be false; we don't want to invoke executors without
        explicit mention.
    use_gpu : bool, optional
        A boolean that indicates whether or not we are doing training
        on the GPU. Note that GPU-enabled code looks different from CPU-specific code.

    Examples
    --------
    Run PyTorch Training locally on GPU (using a PyTorch native function)

    >>> def train(learning_rate):
    ...     import torch.distributed
    ...     torch.distributed.init_process_group(backend="nccl")
    ...     # ...
    ...     torch.distributed.destroy_process_group()
    ...     return model  # or anything else
    ...
    >>> distributor = TorchDistributor(
    ...     num_processes=2,
    ...     local_mode=True,
    ...     use_gpu=True)
    >>> model = distributor.run(train, 1e-3)

    Run PyTorch Training on GPU (using a file with PyTorch code)

    >>> distributor = TorchDistributor(
    ...     num_processes=2,
    ...     local_mode=False,
    ...     use_gpu=True)
    >>> distributor.run("/path/to/train.py", "--learning-rate=1e-3")

    Run PyTorch Lightning Training on GPU

    >>> num_proc = 2
    >>> def train():
    ...     from pytorch_lightning import Trainer
    ...     # ...
    ...     # required to set devices = 1 and num_nodes = num_processes for multi node
    ...     # required to set devices = num_processes and num_nodes = 1 for single node multi GPU
    ...     trainer = Trainer(accelerator="gpu", devices=1, num_nodes=num_proc, strategy="ddp")
    ...     trainer.fit()
    ...     # ...
    ...     return trainer
    ...
    >>> distributor = TorchDistributor(
    ...     num_processes=num_proc,
    ...     local_mode=True,
    ...     use_gpu=True)
    >>> trainer = distributor.run(train)
    """

    _PICKLED_FUNC_FILE = "func.pickle"
    _TRAIN_FILE = "train.py"
    _PICKLED_OUTPUT_FILE = "output.pickle"
    _TORCH_SSL_CONF = "pytorch.spark.distributor.ignoreSsl"

    def __init__(
        self,
        num_processes: int = 1,
        local_mode: bool = True,
        use_gpu: bool = True,
        _ssl_conf: str = _TORCH_SSL_CONF,
    ):
        """Initializes the distributor.

        Parameters
        ----------
        num_processes : int, optional
            An integer that determines how many different concurrent
            tasks are allowed. We expect spark.task.gpus = 1 for GPU-enabled training. Default
            should be 1; we don't want to invoke multiple cores/gpus without explicit mention.
        local_mode : bool, optional
            A boolean that determines whether we are using the driver
            node for training. Default should be false; we don't want to invoke executors
            without explicit mention.
        use_gpu : bool, optional
            A boolean that indicates whether or not we are doing training
            on the GPU. Note that GPU-enabled code looks different from CPU-specific code.

        Raises
        ------
        ValueError
            If any of the parameters are incorrect.
        RuntimeError
            If an active SparkSession is unavailable.
        """
        super().__init__(num_processes, local_mode, use_gpu, ssl_conf=_ssl_conf)
        self._validate_input_params()
        self.input_params = self._create_input_params()

    @staticmethod
    def _get_torchrun_args(local_mode: bool, num_processes: int) -> Tuple[List[Any], int]:
        """
        Given the mode and the number of processes, create the arguments to pass to torchrun.

        Parameters
        ----------
        local_mode: bool
            Whether we are running training locally or in a distributed fashion.
        num_processes: int
            The number of processes that we are going to use.

        Returns
        -------
        Tuple[List[Any], int]
            A tuple containing a list of arguments to pass as pytorch args, as well as the
            number of processes per node.
        """
        if local_mode:
            torchrun_args = ["--standalone", "--nnodes=1"]
            processes_per_node = num_processes
            return torchrun_args, processes_per_node

        master_addr = os.environ["MASTER_ADDR"]
        master_port = os.environ["MASTER_PORT"]
        node_rank = os.environ["RANK"]
        torchrun_args = [
            f"--nnodes={num_processes}",
            f"--node_rank={node_rank}",
            f"--rdzv_endpoint={master_addr}:{master_port}",
            "--rdzv_id=0",  # TODO: setup random ID that is gleaned from env variables
        ]
        processes_per_node = 1
        return torchrun_args, processes_per_node

    @staticmethod
    def _create_torchrun_command(
        input_params: Dict[str, Any], path_to_train_file: str, *args: Any
    ) -> List[str]:
        local_mode = input_params["local_mode"]
        num_processes = input_params["num_processes"]

        torchrun_args, processes_per_node = TorchDistributor._get_torchrun_args(
            local_mode=local_mode, num_processes=num_processes
        )
        args_string = list(map(str, args))  # converting all args to strings

        return [
            sys.executable,
            "-m",
            "pyspark.ml.torch.torch_run_process_wrapper",
            *torchrun_args,
            f"--nproc_per_node={processes_per_node}",
            path_to_train_file,
            *args_string,
        ]

    @staticmethod
    def _execute_command(
        cmd: List[str],
        _prctl: bool = True,
        redirect_to_stdout: bool = True,
        log_streaming_client: Optional[LogStreamingClient] = None,
    ) -> None:
        _TAIL_LINES_TO_KEEP = 100

        task = subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            stdin=subprocess.PIPE,
            env=os.environ,
        )
        task.stdin.close()  # type: ignore
        tail: collections.deque = collections.deque(maxlen=_TAIL_LINES_TO_KEEP)
        try:
            for line in task.stdout:  # type: ignore
                decoded = line.decode()
                tail.append(decoded)
                if redirect_to_stdout:
                    if (
                        log_streaming_client
                        and not log_streaming_client.failed
                        and (
                            log_streaming_client.sock.getsockname()[0]
                            == log_streaming_client.sock.getpeername()[0]
                        )
                    ):
                        # If log_streaming_client and log_streaming_server are on the same
                        # node (the typical case is Spark local mode), the server side
                        # redirects the log to STDOUT. To avoid duplicated STDOUT output,
                        # skip redirecting logs to STDOUT on the client side.
                        pass
                    else:
                        sys.stdout.write(decoded)
                if log_streaming_client:
                    log_streaming_client.send(decoded.rstrip())
            task.wait()
        finally:
            if task.poll() is None:
                try:
                    task.terminate()  # SIGTERM
                    time.sleep(0.5)
                    if task.poll() is None:
                        task.kill()  # SIGKILL
                except OSError:
                    pass
        if task.returncode != os.EX_OK:
            if len(tail) == _TAIL_LINES_TO_KEEP:
                last_n_msg = f"last {_TAIL_LINES_TO_KEEP} lines of the task output are"
            else:
                last_n_msg = "task output is"
            task_output = "".join(tail)
            raise RuntimeError(
                f"Command {cmd} failed with return code {task.returncode}. "
                f"The {last_n_msg} included below: {task_output}"
            )

    @staticmethod
    def _get_output_from_framework_wrapper(
        framework_wrapper: Optional[Callable],
        input_params: Dict,
        train_object: Union[Callable, str],
        run_pytorch_file_fn: Optional[Callable],
        *args: Any,
        **kwargs: Any,
    ) -> Optional[Any]:
        """
        This function is meant to get the output from the framework wrapper function by
        passing in the correct arguments, depending on the type of train_object.

        Parameters
        ----------
        framework_wrapper: Optional[Callable]
            Function pointer that will be invoked. It is the function that runs distributed
            training on a file if train_object is a string; otherwise, it is the function
            that runs distributed training for a function if train_object is a Callable.
        input_params: Dict
            A dictionary that maps parameters to arguments for the command to be created.
        train_object: Union[Callable, str]
            This input comes from the user. If the user inputs a string, it is treated as a
            filepath. If the input is a function, the user wants to run this function in a
            distributed manner.
        run_pytorch_file_fn: Optional[Callable]
            The function that will be used to run distributed training of a file;
            mainly used for the distributed training using a function.
        *args: Any
            Extra arguments to be used by the framework wrapper.
        **kwargs: Any
            Extra keyword args to be used. Not currently supported, but kept for future
            improvement.

        Returns
        -------
        Optional[Any]
            Returns the result of the framework_wrapper.
        """
        if not framework_wrapper:
            raise RuntimeError("`framework_wrapper` is not set. ...")
        # The object to train is a file path, so framework_wrapper is some
        # run_training_on_pytorch_file function.
        if type(train_object) is str:
            return framework_wrapper(input_params, train_object, *args, **kwargs)
        else:
            # We are doing training with a function, will call run_training_on_pytorch_function
            if not run_pytorch_file_fn:
                run_pytorch_file_fn = TorchDistributor._run_training_on_pytorch_file
            return framework_wrapper(
                input_params, train_object, run_pytorch_file_fn, *args, **kwargs
            )

    def _run_local_training(
        self,
        framework_wrapper_fn: Callable,
        train_object: Union[Callable, str],
        run_pytorch_file_fn: Optional[Callable],
        *args: Any,
        **kwargs: Any,
    ) -> Optional[Any]:
        CUDA_VISIBLE_DEVICES = "CUDA_VISIBLE_DEVICES"
        cuda_state_was_set = CUDA_VISIBLE_DEVICES in os.environ
        old_cuda_visible_devices = os.environ.get(CUDA_VISIBLE_DEVICES, "")
        try:
            # Only replace the GPUs with 'SparkContext.resources' in legacy mode.
            # In connect mode, this replacement is skipped since only GPUs on the client side
            # can be used.
            if self.use_gpu and not self.is_remote:
                gpus_owned = _get_gpus_owned(self.spark)
                random.seed(hash(train_object))
                selected_gpus = [str(e) for e in random.sample(gpus_owned, self.num_processes)]
                os.environ[CUDA_VISIBLE_DEVICES] = ",".join(selected_gpus)

            self.logger.info(f"Started local training with {self.num_processes} processes")
            output = TorchDistributor._get_output_from_framework_wrapper(
                framework_wrapper_fn,
                self.input_params,
                train_object,
                run_pytorch_file_fn,
                *args,
                **kwargs,
            )
            self.logger.info(f"Finished local training with {self.num_processes} processes")
        finally:
            if cuda_state_was_set:
                os.environ[CUDA_VISIBLE_DEVICES] = old_cuda_visible_devices
            else:
                if CUDA_VISIBLE_DEVICES in os.environ:
                    del os.environ[CUDA_VISIBLE_DEVICES]

        return output

    def _get_spark_task_function(
        self,
        framework_wrapper_fn: Optional[Callable],
        train_object: Union[Callable, str],
        run_pytorch_file_fn: Optional[Callable],
        input_dataframe: Optional["DataFrame"],
        *args: Any,
        **kwargs: Any,
    ) -> Callable:
        """Creates a spark task function that is used inside `mapPartitions`.

        Parameters
        ----------
        framework_wrapper_fn : Optional[Callable]
            The function that determines whether we are running training
            on a PyTorch file or a PyTorch function.
        train_object : Union[Callable, str]
            The actual train function/file.

        Returns
        -------
        Callable
            The wrapped function ready for use with `mapPartitions`.
        """
        num_processes = self.num_processes
        use_gpu = self.use_gpu
        input_params = self.input_params
        driver_address = self.driver_address
        log_streaming_server_port = self.log_streaming_server_port
        is_spark_local_master = self.is_spark_local_master
        driver_owned_gpus: List[str] = []
        if is_spark_local_master and use_gpu:
            driver_owned_gpus = _get_gpus_owned(self.spark)

        if input_dataframe is not None:
            schema_json = input_dataframe.schema.jsonValue()
        else:
            schema_json = None

        # Spark task program
        def wrapped_train_fn(iterator):  # type: ignore[no-untyped-def]
            import os
            import pandas as pd
            import pyarrow
            from pyspark import BarrierTaskContext

            CUDA_VISIBLE_DEVICES = "CUDA_VISIBLE_DEVICES"

            def get_free_port(address: str, context: "BarrierTaskContext") -> int:
                port = ""
                if context.partitionId() == 0:
                    try:
                        import socket

                        sock = socket.socket()
                        sock.bind((address, 0))
                        port = sock.getsockname()[1]
                    except socket.error:
                        pass
                available_port = context.allGather(str(port))[0]
                if not available_port:
                    raise RuntimeError("Failed to find free port for distributed training.")
                return int(available_port)

            def set_torch_config(context: "BarrierTaskContext") -> None:
                addrs = [e.address.split(":")[0] for e in context.getTaskInfos()]

                os.environ["MASTER_ADDR"] = str(addrs[0])
                os.environ["MASTER_PORT"] = str(get_free_port(addrs[0], context))
                os.environ["WORLD_SIZE"] = str(num_processes)
                os.environ["NODE_RANK"] = str(context.partitionId())
                os.environ["RANK"] = str(context.partitionId())
                if context.partitionId() >= num_processes:
                    raise ValueError(
                        "TorchDistributor._train_on_dataframe requires setting num_processes "
                        "equal to input spark dataframe partition number."
                    )

            if is_spark_local_master:
                # distributed training on a local mode spark cluster
                def set_gpus(context: "BarrierTaskContext") -> None:
                    if CUDA_VISIBLE_DEVICES in os.environ:
                        return

                    gpu_owned = driver_owned_gpus[context.partitionId()]
                    os.environ[CUDA_VISIBLE_DEVICES] = gpu_owned

            else:

                def set_gpus(context: "BarrierTaskContext") -> None:
                    if CUDA_VISIBLE_DEVICES in os.environ:
                        return

                    gpus_owned = _get_gpus_owned(context)
                    os.environ[CUDA_VISIBLE_DEVICES] = ",".join(gpus_owned)

            context = BarrierTaskContext.get()

            if use_gpu:
                set_gpus(context)
            else:
                os.environ[CUDA_VISIBLE_DEVICES] = ""
            set_torch_config(context)

            log_streaming_client = LogStreamingClient(driver_address, log_streaming_server_port)
            input_params["log_streaming_client"] = log_streaming_client
            try:
                with TorchDistributor._setup_spark_partition_data(iterator, schema_json):
                    output = TorchDistributor._get_output_from_framework_wrapper(
                        framework_wrapper_fn,
                        input_params,
                        train_object,
                        run_pytorch_file_fn,
                        *args,
                        **kwargs,
                    )
            finally:
                try:
                    LogStreamingClient._destroy()
                except BaseException:
                    pass

            if context.partitionId() == 0:
                output_bytes = cloudpickle.dumps(output)
                output_size = len(output_bytes)

                # In Spark Connect, DataFrame.collect stacks rows to size
                # 'spark.connect.grpc.arrow.maxBatchSize' (default 4MiB).
                # Here we use 4KiB for each chunk, so each arrow batch
                # may contain about 1000 chunks.
                chunks = []
                chunk_size = 4096
                index = 0
                while index < output_size:
                    chunks.append(output_bytes[index : index + chunk_size])
                    index += chunk_size

                yield pyarrow.RecordBatch.from_pandas(pd.DataFrame(data={"chunk": chunks}))

        return wrapped_train_fn

    def _run_distributed_training(
        self,
        framework_wrapper_fn: Callable,
        train_object: Union[Callable, str],
        run_pytorch_file_fn: Optional[Callable],
        spark_dataframe: Optional["DataFrame"],
        *args: Any,
        **kwargs: Any,
    ) -> Optional[Any]:
        log_streaming_server = LogStreamingServer()
        self.driver_address = _get_conf(self.spark, "spark.driver.host", "")
        assert self.driver_address != ""
        try:
            log_streaming_server.start(spark_host_address=self.driver_address)
            time.sleep(1)  # wait for the server to start
            self.log_streaming_server_port = log_streaming_server.port
        except Exception as e:
            # If starting the log streaming server fails, we don't need to break
            # distributor training; emit a warning instead.
            self.log_streaming_server_port = -1
            self.logger.warning(
                "Starting the torch distributor log streaming server failed; "
                "you will not receive logs sent from distributor workers. "
                f"Error: {repr(e)}."
            )

        try:
            spark_task_function = self._get_spark_task_function(
                framework_wrapper_fn,
                train_object,
                run_pytorch_file_fn,
                spark_dataframe,
                *args,
                **kwargs,
            )
            self._check_encryption()
            self.logger.info(
                f"Started distributed training with {self.num_processes} executor processes"
            )
            if spark_dataframe is not None:
                input_df = spark_dataframe
            else:
                input_df = self.spark.range(
                    start=0, end=self.num_tasks, step=1, numPartitions=self.num_tasks
                )
            rows = input_df.mapInArrow(
                func=spark_task_function, schema="chunk binary", barrier=True
            ).collect()
            output_bytes = b"".join([row.chunk for row in rows])
            result = cloudpickle.loads(output_bytes)
        finally:
            log_streaming_server.shutdown()
        self.logger.info(
            f"Finished distributed training with {self.num_processes} executor processes"
        )
        return result

    @staticmethod
    def _run_training_on_pytorch_file(
        input_params: Dict[str, Any], train_path: str, *args: Any, **kwargs: Any
    ) -> None:
        if kwargs:
            raise ValueError("Running pytorch file does not support key-word type arguments.")
        log_streaming_client = input_params.get("log_streaming_client", None)
        training_command = TorchDistributor._create_torchrun_command(
            input_params, train_path, *args
        )
        TorchDistributor._execute_command(
            training_command, log_streaming_client=log_streaming_client
        )

    @staticmethod
    @contextmanager
    def _setup_files(
        train_fn: Callable, *args: Any, **kwargs: Any
    ) -> Generator[Tuple[str, str], None, None]:
        save_dir = TorchDistributor._create_save_dir()
        pickle_file_path = FunctionPickler.pickle_fn_and_save(
            train_fn, "", save_dir, *args, **kwargs
        )
        output_file_path = os.path.join(save_dir, TorchDistributor._PICKLED_OUTPUT_FILE)
        script_path = os.path.join(save_dir, TorchDistributor._TRAIN_FILE)
        train_file_path = FunctionPickler.create_fn_run_script(
            pickle_file_path, output_file_path, script_path
        )
        try:
            yield (train_file_path, output_file_path)
        finally:
            TorchDistributor._cleanup_files(save_dir)

    @staticmethod
    @contextmanager
    def _setup_spark_partition_data(
        partition_data_iterator: Iterator[Any], input_schema_json: Dict[str, Any]
    ) -> Iterator[Any]:
        from pyspark.sql.pandas.serializers import ArrowStreamSerializer
        from pyspark.core.files import SparkFiles
        import json

        if input_schema_json is None:
            yield
            return

        # We need to temporarily write partition data into a temp dir.
        # Partition data might be huge, so we need to write it under
        # the configured `SPARK_LOCAL_DIRS`.
        save_dir = TorchDistributor._create_save_dir(root_dir=SparkFiles.getRootDirectory())

        try:
            serializer = ArrowStreamSerializer()
            arrow_file_path = os.path.join(save_dir, "data.arrow")
            with open(arrow_file_path, "wb") as f:
                serializer.dump_stream(partition_data_iterator, f)
                if f.tell() == 0:
                    # Nothing is written to file, this partition is empty
                    raise ValueError(
                        "Empty Spark partition is not allowed in "
                        "TorchDistributor.train_on_dataframe."
                    )

            schema_file_path = os.path.join(save_dir, "schema.json")
            schema_json_string = json.dumps(input_schema_json)

            with open(schema_file_path, "w") as f:
                f.write(schema_json_string)

            os.environ[SPARK_PARTITION_ARROW_DATA_FILE] = arrow_file_path
            os.environ[SPARK_DATAFRAME_SCHEMA_FILE] = schema_file_path

            yield
        finally:
            os.environ.pop(SPARK_PARTITION_ARROW_DATA_FILE)
            os.environ.pop(SPARK_DATAFRAME_SCHEMA_FILE)
            TorchDistributor._cleanup_files(save_dir)

    @staticmethod
    def _run_training_on_pytorch_function(
        input_params: Dict[str, Any],
        train_fn: Callable,
        run_pytorch_file_fn: Optional[Callable],
        *args: Any,
        **kwargs: Any,
    ) -> Any:
        if not run_pytorch_file_fn:
            run_pytorch_file_fn = TorchDistributor._run_training_on_pytorch_file
        with TorchDistributor._setup_files(train_fn, *args, **kwargs) as (
            train_file_path,
            output_file_path,
        ):
            run_pytorch_file_fn(input_params, train_file_path)
            if not os.path.exists(output_file_path):
                raise RuntimeError(
                    "TorchDistributor failed during training. "
                    "View stdout logs for detailed error message."
                )
            try:
                output = FunctionPickler.get_fn_output(output_file_path)
            except Exception as e:
                raise RuntimeError(
                    "TorchDistributor failed due to a pickling error. "
                    "View stdout logs for detailed error message."
                ) from e
        return output

    @staticmethod
    def _create_save_dir(root_dir: Optional[str] = None) -> str:
        # TODO: need to do this in a safe way to avoid issues during concurrent runs
        return tempfile.mkdtemp(dir=root_dir)

    @staticmethod
    def _cleanup_files(save_dir: str) -> None:
        shutil.rmtree(save_dir, ignore_errors=True)
    def run(self, train_object: Union[Callable, str], *args: Any, **kwargs: Any) -> Optional[Any]:
        """Runs distributed training.

        Parameters
        ----------
        train_object : callable object or str
            Either a PyTorch function, PyTorch Lightning function, or the path to a python file
            that launches distributed training.
        args :
            If train_object is a python function and not a path to a python file, args need
            to be the input parameters to that function. It would look like

            >>> model = distributor.run(train, 1e-3, 64)

            where train is a function and 1e-3 and 64 are regular numeric inputs to the function.

            If train_object is a python file, then args would be the command-line arguments for
            that python file which are all in the form of strings. An example would be

            >>> distributor.run("/path/to/train.py", "--learning-rate=1e-3", "--batch-size=64")

            where since the input is a path, all of the parameters are strings that can be
            handled by argparse in that python file.
        kwargs :
            If train_object is a python function and not a path to a python file, kwargs need
            to be the key-word input parameters to that function. It would look like

            >>> model = distributor.run(train, tol=1e-3, max_iter=64)

            where train is a function of 2 arguments `tol` and `max_iter`.

            If train_object is a python file, then you should not set kwargs arguments.

        Returns
        -------
            Returns the output of train_object called with args inside spark rank 0 task if the
            train_object is a Callable with an expected output. Returns None if train_object is
            a file.
        """
        return self._run(
            train_object, TorchDistributor._run_training_on_pytorch_file, *args, **kwargs
        )
    def _run(
        self,
        train_object: Union[Callable, str],
        run_pytorch_file_fn: Callable,
        *args: Any,
        **kwargs: Any,
    ) -> Optional[Any]:
        if isinstance(train_object, str):
            framework_wrapper_fn = run_pytorch_file_fn
        else:
            framework_wrapper_fn = TorchDistributor._run_training_on_pytorch_function
        if self.local_mode:
            output = self._run_local_training(
                framework_wrapper_fn, train_object, run_pytorch_file_fn, *args, **kwargs
            )
        else:
            output = self._run_distributed_training(
                framework_wrapper_fn, train_object, run_pytorch_file_fn, None, *args, **kwargs
            )
        return output

    def _train_on_dataframe(
        self,
        train_function: Callable,
        spark_dataframe: "DataFrame",
        *args: Any,
        **kwargs: Any,
    ) -> Any:
        """
        Runs distributed training using the provided Spark DataFrame as input data.
        You should ensure the input Spark DataFrame has evenly distributed partitions;
        this method starts a barrier Spark job in which each Spark task processes one
        partition of the input Spark DataFrame.

        Parameters
        ----------
        train_function :
            Either a PyTorch function or a PyTorch Lightning function that launches distributed
            training. Note that inside the function you can call the
            `pyspark.ml.torch.distributor.get_spark_partition_data_loader` API to get a torch
            data loader; the data loader loads data from the corresponding partition of the
            input Spark DataFrame.
        spark_dataframe :
            An input Spark DataFrame that can be used in the PyTorch `train_function` function.
            See the `train_function` argument doc for details.
        args :
            `args` need to be the input parameters to the `train_function` function. It would
            look like

            >>> model = distributor.run(train, 1e-3, 64)

            where train is a function and 1e-3 and 64 are regular numeric inputs to the function.
        kwargs :
            `kwargs` need to be the key-word input parameters to the `train_function` function.
            It would look like

            >>> model = distributor.run(train, tol=1e-3, max_iter=64)

            where train is a function of 2 arguments `tol` and `max_iter`.

        Returns
        -------
            Returns the output of `train_function` called with args inside the Spark rank 0 task.
        """
        if self.local_mode:
            raise ValueError(
                "TorchDistributor.train_on_dataframe requires setting "
                "TorchDistributor.local_mode to False."
            )

        return self._run_distributed_training(
            TorchDistributor._run_training_on_pytorch_function,
            train_function,
            TorchDistributor._run_training_on_pytorch_file,
            spark_dataframe,
            *args,
            **kwargs,
        )

def _get_spark_partition_data_loader(
    num_samples: int, batch_size: int, num_workers: int = 1, prefetch_factor: int = 2
) -> Any:
    """
    This function must be called inside the `train_function` where `train_function`
    is the input argument of `TorchDistributor.train_on_dataframe`.
    The function returns a pytorch data loader that loads data from
    the corresponding spark partition data.

    Parameters
    ----------
    num_samples :
        Number of samples to generate per epoch. If `num_samples` is less than the number of
        rows in the spark partition, it generates the first `num_samples` rows of the spark
        partition; if `num_samples` is greater than the number of rows in the spark partition,
        then after the iterator has loaded all rows from the partition, it wraps around to the
        first row.
    batch_size:
        How many samples per batch to load.
    num_workers:
        How many subprocesses to use for data loading.
        0 means that the data will be loaded in the main process.
    prefetch_factor:
        Number of batches loaded in advance by each worker.
    """
    from pyspark.sql.types import StructType
    from pyspark.ml.torch.data import _SparkPartitionTorchDataset
    from torch.utils.data import DataLoader

    arrow_file = os.environ[SPARK_PARTITION_ARROW_DATA_FILE]
    schema_file = os.environ[SPARK_DATAFRAME_SCHEMA_FILE]

    with open(schema_file, "r") as fp:
        schema = StructType.fromJson(json.load(fp))

    dataset = _SparkPartitionTorchDataset(arrow_file, schema, num_samples)

    if num_workers > 0:
        return DataLoader(
            dataset, batch_size, num_workers=num_workers, prefetch_factor=prefetch_factor
        )
    else:
        # if num_workers is zero, we cannot set `prefetch_factor`, otherwise
        # torch will raise an error.
        return DataLoader(dataset, batch_size, num_workers=num_workers)
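

if __name__ == "__main__":
    # Minimal usage sketch (not part of the library API): it shows how the
    # DataFrame-based training path above is typically wired together. It
    # assumes a GPU-enabled cluster with "spark.task.resource.gpu.amount=1";
    # the input DataFrame, the number of processes, and the training loop body
    # are placeholders, not values mandated by this module.
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()
    # _train_on_dataframe requires one input partition per training process.
    df = spark.range(1024).repartition(2)

    def train_fn() -> None:
        import torch.distributed

        torch.distributed.init_process_group(backend="nccl")
        # Each barrier task reads its own partition via the helper defined above.
        loader = _get_spark_partition_data_loader(num_samples=512, batch_size=32)
        for batch in loader:
            pass  # forward/backward passes on a model would go here
        torch.distributed.destroy_process_group()

    distributor = TorchDistributor(num_processes=2, local_mode=False, use_gpu=True)
    distributor._train_on_dataframe(train_fn, df)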