# Federated Training of the Donut Model for Document Recognition with the Simulator
## Overview
The following example shows a distributed training run for the Donut model with the NVIDIA FLARE Simulator (https://nvflare-22-docs-update.readthedocs.io/en/latest/user_guide/fl_simulator.html), including code, configuration, and output. Training runs were carried out with 2, 5, and 10 clients.
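Such runs can be launched through the simulator's Python API; a minimal sketch, assuming a recent NVFlare release (the exact import path can differ between versions, and the job folder and workspace paths below are hypothetical):

```python
from nvflare.private.fed.app.simulator.simulator_runner import SimulatorRunner

# Hypothetical job folder and workspace; n_clients/threads match the 10-client run.
simulator = SimulatorRunner(
    job_folder="jobs/donut_federated",
    workspace="/tmp/nvflare/donut_simulation",
    n_clients=10,
    threads=10,
)
run_status = simulator.run()
print("Simulator finished with status:", run_status)
```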
## Configuration

```json
# config_fed_client.json
{
  "format_version": 2,
  "executors": [
    {
      "tasks": ["train", "validate"],
      "executor": {
        "name": "LearnerExecutor",
        "path": "nvflare.app_common.executors.model_learner_executor.ModelLearnerExecutor",
        "args": {
          "learner_id": "donut_learner"
        }
      }
    }
  ],
  "task_result_filters": [],
  "task_data_filters": [],
  "components": [
    {
      "id": "donut_learner",
      "path": "donut_learner.DonutLearner",
      "args": {
        "lr": 3e-4,
        "aggregation_epochs": 2,
        "split_dict": {
          "site-1": "[0%:10%]",
          "site-2": "[10%:20%]",
          "site-3": "[20%:30%]",
          "site-4": "[30%:40%]",
          "site-5": "[40%:50%]",
          "site-6": "[50%:60%]",
          "site-7": "[60%:70%]",
          "site-8": "[70%:80%]",
          "site-9": "[80%:90%]",
          "site-10": "[90%:100%]"
        }
      }
    }
  ]
}
```
This configures our client task for training and validation. The training data is split so that client `site-1` uses the first 10%, `site-2` the second 10%, and so on up to `site-10`.
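The split strings follow the slicing syntax of the Hugging Face `datasets` library, so each site can load its share of the data directly. A minimal sketch (the dataset name is a placeholder assumption, based on the CORD-v2 dataset commonly used for Donut fine-tuning):

```python
from datasets import load_dataset

# "train[0%:10%]" selects the first 10% of the training split
# (Hugging Face datasets slicing syntax); the dataset name is a placeholder.
split = "[0%:10%]"  # value from split_dict for site-1
site_data = load_dataset("naver-clova-ix/cord-v2", split=f"train{split}")
print(f"site-1 trains on {len(site_data)} samples")
```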
```json
# config_fed_server.json
{
  format_version = 2
  task_data_filters = []
  task_result_filters = []
  model_class_path = "donut.DonutModelPLModule"
  workflows = [
    {
      id = "scatter_and_gather"
      path = "nvflare.app_common.workflows.scatter_and_gather.ScatterAndGather"
      args {
        min_clients = 10
        num_rounds = 5
        start_round = 0
        wait_time_after_min_received = 0
        aggregator_id = "aggregator"
        persistor_id = "persistor"
        shareable_generator_id = "shareable_generator"
        train_task_name = "train"
        train_timeout = 0
      }
    }
  ]
  components = [
    {
      id = "persistor"
      path = "nvflare.app_opt.pt.file_model_persistor.PTFileModelPersistor"
      args {
        model {
          path = "{model_class_path}"
        }
      }
    }
    {
      id = "shareable_generator"
      path = "nvflare.app_common.shareablegenerators.full_model_shareable_generator.FullModelShareableGenerator"
      args {}
    }
    {
      id = "aggregator"
      path = "nvflare.app_common.aggregators.intime_accumulate_model_aggregator.InTimeAccumulateWeightedAggregator"
      args {
        expected_data_kind = "WEIGHT_DIFF"
        aggregation_weights = {
          "site-1" : 1.0,
          "site-2" : 1.0,
          "site-3" : 1.0,
          "site-4" : 1.0,
          "site-5" : 1.0,
          "site-6" : 1.0,
          "site-7" : 1.0,
          "site-8" : 1.0,
          "site-9" : 1.0,
          "site-10" : 1.0
        }
      }
    }
    {
      id = "model_selector"
      path = "nvflare.app_common.widgets.intime_model_selector.IntimeModelSelector"
      args {
        key_metric = "val_edit_distance",
        negate_key_metric = True
      }
    }
    {
      id = "receiver"
      path = "nvflare.app_opt.tracking.tb.tb_receiver.TBAnalyticsReceiver"
      args {
        events = ["fed.analytix_log_stats"]
      }
    }
  ]
}
```
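The `InTimeAccumulateWeightedAggregator` combines the received weight differences into a weighted average, scaling each client's contribution by its `aggregation_weights` entry and the number of training steps it reports for the round. A simplified sketch of this logic, not the actual NVFlare implementation:

```python
import numpy as np

def aggregate(diffs, n_steps, agg_weights):
    """Weighted average of per-client weight differences (simplified sketch).

    diffs:       {client_name: {param_name: np.ndarray}}
    n_steps:     {client_name: int}   # training steps reported per round
    agg_weights: {client_name: float} # from aggregation_weights in the config
    """
    total = {c: n_steps[c] * agg_weights[c] for c in diffs}
    norm = sum(total.values())
    aggregated = {}
    for client, diff in diffs.items():
        for name, value in diff.items():
            contribution = value * (total[client] / norm)
            aggregated[name] = aggregated.get(name, 0.0) + contribution
    return aggregated  # applied to the global model by the shareable generator
```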
## Site initialization

```python
import pytorch_lightning as pl
from nvflare.app_common.abstract.model_learner import ModelLearner
from nvflare.app_opt.pt.model_persistence_format_manager import PTModelPersistenceFormatManager
from donut import DonutModelPLModule


class DonutLearner(ModelLearner):
    def __init__(
        self,
        aggregation_epochs=1,
        lr=1e-2,
        batch_size=1,
        split_dict={"site-1": "[:50%]", "site-2": "[50%:]"},
    ):
        super().__init__()
        self.aggregation_epochs = aggregation_epochs
        self.lr = lr
        self.batch_size = batch_size
        self.split_dict = split_dict
        self.split = None
        # lower edit distance is better; start at infinity so the first
        # validation result is saved as the best model
        self.edit_distance = float("inf")
        self.model = DonutModelPLModule(lr=self.lr)
        self.default_train_conf = None
        self.persistence_manager = None
        self.n_iterations = None
        self.train_dataset = None
        self.validation_dataset = None
        self.local_model_file = None
        self.best_local_model_file = None
        self.trainer = pl.Trainer(
            devices=1,
            max_epochs=self.aggregation_epochs,
            precision="16-mixed",
            num_sanity_val_steps=0,
            accumulate_grad_batches=1,
        )

    def initialize(self):
        # get the current data split
        self.info(f"Initializing {self.site_name} for training")
        self.split = self.split_dict[self.site_name]
        self.info(f"Training on split {self.split}")

        # persistence manager setup
        self.default_train_conf = {"train": {"model": type(self.model).__name__}}
        self.persistence_manager = PTModelPersistenceFormatManager(
            data=self.model.state_dict(), default_train_conf=self.default_train_conf
        )
    ...
```
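`_create_datasets`, which the training and validation tasks call, is not part of the excerpt. A hedged sketch of what it might look like, reusing the slicing shown earlier (the dataset name is again a placeholder):

```python
from datasets import load_dataset

def _create_datasets(self):
    # Hypothetical sketch: lazily load this site's slice of the training
    # split plus a shared validation split.
    if self.train_dataset is None:
        self.train_dataset = load_dataset(
            "naver-clova-ix/cord-v2", split=f"train{self.split}"
        )
    if self.validation_dataset is None:
        self.validation_dataset = load_dataset(
            "naver-clova-ix/cord-v2", split="validation"
        )
```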
## Training

```python
def train(self, model: FLModel) -> Union[str, FLModel]:
    self.info(f"Current/Total Round: {self.current_round + 1}/{self.total_rounds}")
    self.info(f"Client identity: {self.site_name}")

    self._create_datasets()

    global_weights = model.params

    # Before loading weights, tensors might need to be reshaped to support HE for secure aggregation.
    local_var_dict = self.model.state_dict()
    model_keys = global_weights.keys()
    for var_name in local_var_dict.keys():
        if var_name in model_keys:
            weights = global_weights[var_name]
            try:
                # reshape global weights to compute the difference later on
                global_weights[var_name] = np.reshape(weights, local_var_dict[var_name].shape)
                # update the local dict
                local_var_dict[var_name] = torch.as_tensor(global_weights[var_name])
            except Exception as e:
                raise ValueError("Convert weight from {} failed with error: {}".format(var_name, str(e)))
    self.model.load_state_dict(local_var_dict)

    # keep a frozen copy of the global model
    model_global = deepcopy(self.model)
    for param in model_global.parameters():
        param.requires_grad = False

    # create dataloaders
    train_dataloader = DataLoader(self.train_dataset, shuffle=True, num_workers=0, batch_size=1)
    validation_dataloader = DataLoader(self.validation_dataset, shuffle=False, num_workers=0, batch_size=1)

    # train and validate
    self.trainer.fit(self.model, train_dataloader)
    edit_distance = self.trainer.validate(self.model, validation_dataloader)[0]["val_edit_distance"]
    self.n_iterations = len(train_dataloader)

    if edit_distance < self.edit_distance:
        self.edit_distance = edit_distance
        self.trainer.save_checkpoint("best_model.ckpt")
    else:
        self.trainer.save_checkpoint("model.ckpt")

    # compute the difference between local and global weights
    local_weights = self.model.state_dict()
    diff = {}
    for name in global_weights:
        if name not in local_weights:
            continue
        diff[name] = np.subtract(local_weights[name].cpu().numpy(), global_weights[name], dtype=np.float32)
        if np.any(np.isnan(diff[name])):
            self.stop_task(f"{name} weights became NaN...")
            return ReturnCode.EXECUTION_EXCEPTION

    # return an FLModel containing the model differences
    fl_model = FLModel(params_type=ParamsType.DIFF, params=diff)
    FLModelUtils.set_meta_prop(fl_model, FLMetaKey.NUM_STEPS_CURRENT_ROUND, self.n_iterations)
    self.info("Local epochs finished. Returning FLModel")
    return fl_model
...
```
## Validation

```python
def validate(self, model: FLModel) -> Union[str, FLModel]:
    """Typical validation task pipeline:
    1. Get global model weights (potentially with HE)
    2. Validate on local data
    3. Return the validation edit distance
    """
    self._create_datasets()

    # get validation information
    self.info(f"Client identity: {self.site_name}")

    # update local model weights with the received weights
    global_weights = model.params

    # Before loading weights, tensors might need to be reshaped to support HE for secure aggregation.
    local_var_dict = self.model.state_dict()
    model_keys = global_weights.keys()
    n_loaded = 0
    for var_name in local_var_dict:
        if var_name in model_keys:
            weights = torch.as_tensor(global_weights[var_name])
            try:
                # update the local dict
                local_var_dict[var_name] = torch.as_tensor(torch.reshape(weights, local_var_dict[var_name].shape))
                n_loaded += 1
            except Exception as e:
                raise ValueError("Convert weight from {} failed with error: {}".format(var_name, str(e)))
    self.model.load_state_dict(local_var_dict)
    if n_loaded == 0:
        raise ValueError(f"No weights loaded for validation! Weights were searched in dict {local_var_dict}")

    # before_train_validate only, can extend to other validate types
    validate_type = FLModelUtils.get_meta_prop(model, FLMetaKey.VALIDATE_TYPE, ValidateType.MODEL_VALIDATE)
    model_owner = self.get_shareable_header(AppConstants.MODEL_OWNER)

    # get split for validation dataset
    validation_dataloader = DataLoader(self.validation_dataset, shuffle=False, num_workers=0, batch_size=1)
    edit_distance = self.trainer.validate(self.model, validation_dataloader)[0]["val_edit_distance"]

    return FLModel(metrics={"val_edit_distance": edit_distance})
...
```
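The `val_edit_distance` metric logged by `DonutModelPLModule` is, in the usual Donut fine-tuning setup, a normalized Levenshtein distance between the predicted and the ground-truth token sequence (an assumption here, since the Lightning module itself is not shown). A minimal sketch:

```python
from nltk import edit_distance

def normalized_edit_distance(prediction: str, target: str) -> float:
    """Levenshtein distance normalized by the longer sequence length (in [0, 1])."""
    return edit_distance(prediction, target) / max(len(prediction), len(target))

# Example: identical strings give 0.0, completely different strings approach 1.0.
print(normalized_edit_distance("<s_total>12.00</s_total>", "<s_total>21.00</s_total>"))
```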
```
# Server logs during initialization
2024-04-19 09:40:25,539 - ClientManager - INFO - Client: New client site-1@10.10.2.11 joined. Sent token: e356583b-82af-4386-92f9-cdec7e195c3a. Total clients: 1
2024-04-19 09:40:25,539 - FederatedClient - INFO - Successfully registered client:site-1 for project simulator_server. Token:e356583b-82af-4386-92f9-cdec7e195c3a SSID:
2024-04-19 09:40:25,540 - ClientManager - INFO - Client: New client site-2@10.10.2.11 joined. Sent token: 267c7798-af41-4494-b383-6076cdee22f7. Total clients: 2
...
Token:649ce318-351f-45ea-9456-364fdefc10eb SSID:
2024-04-19 09:40:25,552 - ClientManager - INFO - Client: New client site-10@10.10.2.11 joined. Sent token: 59120408-6ba9-4e51-8125-c9f08efd3ae1. Total clients: 10
...
```
```
# Server assigns the train task to the clients
...
2024-04-19 09:40:46,105 - ServerRunner - INFO - [identity=simulator_server, run=simulate_job, wf=scatter_and_gather, peer=site-1, peer_run=simulate_job, task_name=train, task_id=31257a7c-f977-4be1-9a07-e06cdd34752c]: assigned task to client site-1: name=train, id=31257a7c-f977-4be1-9a07-e06cdd34752c
...
```
```
# Client-side training on the configured data split
...
2024-04-19 09:40:47,290 - DonutLearner - INFO - [identity=site-1, run=simulate_job, peer=simulator_server, peer_run=simulate_job, task_name=train, task_id=31257a7c-f977-4be1-9a07-e06cdd34752c]: Training on split [0%:10%]
...
2024-04-19 09:41:26,860 - DonutLearner - INFO - [identity=site-1, run=simulate_job, peer=simulator_server, peer_run=simulate_job, task_name=train, task_id=31257a7c-f977-4be1-9a07-e06cdd34752c]: Current/Total Round: 1/5
...
```
```
# Aggregator accepts the client contributions
...
2024-04-19 09:44:08,222 - ScatterAndGather - INFO - [identity=simulator_server, run=simulate_job, wf=scatter_and_gather, peer=site-1, peer_run=simulate_job, peer_rc=OK, task_name=train, task_id=31257a7c-f977-4be1-9a07-e06cdd34752c]: Contribution from site-1 ACCEPTED by the aggregator at round 0.
...
```
```
# Local validation result reported by the learner
...
2024-04-23 20:56:39,751 - DonutLearner - INFO - [identity=site-1, run=simulate_job, peer=simulator_server, peer_run=simulate_job, task_name=train, task_id=7bdc014b-72c3-4cd7-b482-326813231e9d]: Edit Distance: 0.9060616493225098
...
```
```
# The IntimeModelSelector sees the negated metric (negate_key_metric=True), so a lower edit distance ranks higher
...
2024-04-23 21:34:37,165 - IntimeModelSelector - INFO - [identity=simulator_server, run=simulate_job, wf=scatter_and_gather, peer=site-1, peer_run=simulate_job, peer_rc=OK, task_name=train, task_id=0e35ebab-17b0-413c-85ea-45cd15d9870e]: validation metric -0.9060616493225098 from client site-1
...
```
```
# Aggregation and start of the next round
...
2024-04-23 22:12:47,170 - DXOAggregator - INFO - [identity=simulator_server, run=simulate_job, wf=scatter_and_gather]: aggregating 10 update(s) at round 1
2024-04-23 22:12:47,586 - ScatterAndGather - INFO - [identity=simulator_server, run=simulate_job, wf=scatter_and_gather]: End aggregation.
...
2024-04-23 22:12:51,282 - ScatterAndGather - INFO - [identity=simulator_server, run=simulate_job, wf=scatter_and_gather]: Round 1 finished.
2024-04-23 22:12:51,282 - ScatterAndGather - INFO - [identity=simulator_server, run=simulate_job, wf=scatter_and_gather]: Round 2 started.
...
```
```
# End of the final round and workflow shutdown
...
2024-04-23 23:43:03,965 - ScatterAndGather - INFO - [identity=simulator_server, run=simulate_job, wf=scatter_and_gather]: Round 4 finished.
2024-04-23 23:43:03,965 - ScatterAndGather - INFO - [identity=simulator_server, run=simulate_job, wf=scatter_and_gather]: Finished ScatterAndGather Training.
2024-04-23 23:43:03,967 - ServerRunner - INFO - [identity=simulator_server, run=simulate_job, wf=scatter_and_gather]: Workflow: scatter_and_gather finalizing ...
2024-04-23 23:43:04,463 - ServerRunner - INFO - [identity=simulator_server, run=simulate_job, wf=scatter_and_gather]: ABOUT_TO_END_RUN fired
2024-04-23 23:43:04,463 - ServerRunner - INFO - [identity=simulator_server, run=simulate_job, wf=scatter_and_gather]: Firing CHECK_END_RUN_READINESS ...
2024-04-23 23:43:04,464 - ServerRunner - INFO - [identity=simulator_server, run=simulate_job, wf=scatter_and_gather]: END_RUN fired
...
```
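Because a `TBAnalyticsReceiver` is configured on the server, the streamed metrics can be inspected after the run, either with TensorBoard itself or programmatically. A sketch, assuming the default event-file location inside the simulator workspace (the exact path may differ between setups):

```python
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

# The path below is an assumption: the TBAnalyticsReceiver writes its event
# files under the server's job directory in the simulator workspace by default.
acc = EventAccumulator("/tmp/nvflare/donut_simulation/server/simulate_job/tb_events")
acc.Reload()
for tag in acc.Tags()["scalars"]:
    print(tag, [(event.step, event.value) for event in acc.Scalars(tag)])
```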
## Training results
The following results show that the edit distance approaches a stable value with each epoch.
We also see that as the training data is spread across more and more clients, more training epochs are needed to reach a stable result. This is because with more clients, each individual client has less training data available.

Overall, our tests show that the simulator is a very helpful tool for estimating the behavior of a planned training run with a larger number of clients, without the far greater effort of a physically distributed training setup. However, this is only possible if at least a small subset of the later distributed training data is centrally available for the simulator test. Further tests that matter in practice can also be carried out this way, for example the cases where individual clients hold corrupt data, or where clients fail or become temporarily unreachable. This makes it possible to study what effects such situations have on the training and what a suitable policy for counteracting them looks like.