Federated Training of the Donut Model for Document Recognition with the Simulator

Overview

The following example shows a distributed training run with the NVIDIA FLARE Simulator (https://nvflare-22-docs-update.readthedocs.io/en/latest/user_guide/fl_simulator.html) for the Donut model, including code, configuration, and outputs. Training runs were carried out with 2, 5, and 10 clients.

Client configuration

Our clients were configured as follows:

## Configuration

```json
# config_fed_client.json
{
   "format_version": 2,
   "executors": [
      {
         "tasks": [
            "train",
            "validate"
         ],
         "executor": {
            "name": "LearnerExecutor",
            "path": "nvflare.app_common.executors.model_learner_executor.ModelLearnerExecutor",
            "args": {
               "learner_id": "donut_learner"
            }
         }
      }
   ],
   "task_result_filters": [],
   "task_data_filters": [],
   "components": [
      {
         "id": "donut_learner",
         "path": "donut_learner.DonutLearner",
         "args": {
            "lr": 3e-4,
            "aggregation_epochs": 2,
            "split_dict": {
               "site-1": "[0%:10%]",
               "site-2": "[10%:20%]",
               "site-3": "[20%:30%]",
               "site-4": "[30%:40%]",
               "site-5": "[40%:50%]",
               "site-6": "[50%:60%]",
               "site-7": "[60%:70%]",
               "site-8": "[70%:80%]",
               "site-9": "[80%:90%]",
               "site-10": "[90%:100%]"
            }
         }
      }
   ]
}
```

Here the client task is configured for training and validation. The training data is split so that client `site-1` uses the first 10%, `site-2` the second 10%, and so on up to `site-10`.
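The split strings look like Hugging Face `datasets` slice expressions. A hypothetical helper (the function `resolve_split` and the use of `datasets`-style slicing are assumptions, not part of the job code above) could turn a `split_dict` entry into a concrete split argument:

```python
# Hypothetical helper: map a site name to a datasets-style split
# expression such as "train[0%:10%]".
def resolve_split(site_name: str, split_dict: dict, base_split: str = "train") -> str:
    return f"{base_split}{split_dict[site_name]}"

split_dict = {"site-1": "[0%:10%]", "site-2": "[10%:20%]"}
print(resolve_split("site-1", split_dict))  # → train[0%:10%]
# e.g.: datasets.load_dataset(dataset_name, split=resolve_split(site_name, split_dict))
```

Each client would call such a helper with its own `site_name`, so identical code yields disjoint training subsets.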

```hocon
# config_fed_server.json
{
  format_version = 2
  task_data_filters = []
  task_result_filters = []
  model_class_path = "donut.DonutModelPLModule"
  workflows = [
    {
      id = "scatter_and_gather"
      path = "nvflare.app_common.workflows.scatter_and_gather.ScatterAndGather"
      args {
        min_clients = 10
        num_rounds = 5
        start_round = 0
        wait_time_after_min_received = 0
        aggregator_id = "aggregator"
        persistor_id = "persistor"
        shareable_generator_id = "shareable_generator"
        train_task_name = "train"
        train_timeout = 0
      }
    }
  ]
  components = [
    {
      id = "persistor"
      path = "nvflare.app_opt.pt.file_model_persistor.PTFileModelPersistor"
      args {
        model {
          path = "{model_class_path}"
        }
      }
    }
    {
      id = "shareable_generator"
      path = "nvflare.app_common.shareablegenerators.full_model_shareable_generator.FullModelShareableGenerator"
      args {}
    }
    {
      id = "aggregator"
      path = "nvflare.app_common.aggregators.intime_accumulate_model_aggregator.InTimeAccumulateWeightedAggregator"
      args {
        expected_data_kind = "WEIGHT_DIFF"
        aggregation_weights = {
          "site-1" : 1.0,
          "site-2" : 1.0,
          "site-3" : 1.0,
          "site-4" : 1.0,
          "site-5" : 1.0,
          "site-6" : 1.0,
          "site-7" : 1.0,
          "site-8" : 1.0,
          "site-9" : 1.0,
          "site-10" : 1.0
        }
      }
    }
    {
      id = "model_selector"
      path = "nvflare.app_common.widgets.intime_model_selector.IntimeModelSelector"
      args {
        key_metric = "val_edit_distance"
        negate_key_metric = true
      }
    }
    {
      id = "receiver"
      path = "nvflare.app_opt.tracking.tb.tb_receiver.TBAnalyticsReceiver"
      args {
        events = [
          "fed.analytix_log_stats"
        ]
      }
    }
  ]
}
```

Server configuration

The configuration above sets up the central server for the distributed training: we have 10 clients, 5 federated rounds, and the best model is the one with the smallest edit distance (https://en.wikipedia.org/wiki/Edit_distance).
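Because the model selector keeps the model with the *highest* key metric, `negate_key_metric` flips the sign of the reported edit distance so that the lowest distance wins. A minimal sketch of that selection logic (simplified for illustration, not the actual NVFlare implementation):

```python
# Simplified model selection: the selector tracks the maximum score,
# so negating a lower-is-better metric makes the smallest value win.
def select_best(edit_distances, negate_key_metric=True):
    scores = [-d if negate_key_metric else d for d in edit_distances]
    return max(range(len(scores)), key=scores.__getitem__)

per_round_distances = [0.91, 0.42, 0.77]
print(select_best(per_round_distances))  # → 1, the round with the smallest distance
```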

## Site initialization

```python
class DonutLearner(ModelLearner):
    def __init__(
        self,
        aggregation_epochs=1,
        lr=1e-2,
        batch_size=1,
        split_dict={"site-1": "[:50%]", "site-2": "[50%:]"},
    ):
        super().__init__()
        self.aggregation_epochs = aggregation_epochs
        self.lr = lr
        self.batch_size = batch_size
        self.split_dict = split_dict
        self.split = None
        # best (lowest) validation edit distance seen so far; start at
        # infinity so the first validation result is always saved as best
        self.edit_distance = float("inf")
        self.model = DonutModelPLModule(lr=self.lr)
        self.default_train_conf = None
        self.persistence_manager = None
        self.n_iterations = None
        self.train_dataset = None
        self.validation_dataset = None
        self.local_model_file = None
        self.best_local_model_file = None
        self.trainer = pl.Trainer(
            devices=1,
            max_epochs=self.aggregation_epochs,
            precision="16-mixed",
            num_sanity_val_steps=0,
            accumulate_grad_batches=1,
        )

    def initialize(self):
        # get the current data split
        self.info(f"Initializing {self.site_name} for training")
        self.split = self.split_dict[self.site_name]
        self.info(f"Training on split {self.split}")
        # persistence manager setup
        self.default_train_conf = {"train": {"model": type(self.model).__name__}}
        self.persistence_manager = PTModelPersistenceFormatManager(
            data=self.model.state_dict(),
            default_train_conf=self.default_train_conf,
        )
```

Server prepares the clients for training

In `initialize`, each ModelLearner prepares itself for training. Since all clients run identical code, the server tells each client which training site it is and which data split it should train on.

## Training

```python
def train(self, model: FLModel) -> Union[str, FLModel]:
    self.info(f"Current/Total Round: {self.current_round + 1}/{self.total_rounds}")
    self.info(f"Client identity: {self.site_name}")
    self._create_datasets()
    global_weights = model.params
    # Before loading weights, tensors might need to be reshaped to support HE for secure aggregation.
    local_var_dict = self.model.state_dict()
    model_keys = global_weights.keys()
    for var_name in local_var_dict.keys():
        if var_name in model_keys:
            weights = global_weights[var_name]
            try:
                # reshape global weights to compute difference later on
                global_weights[var_name] = np.reshape(weights, local_var_dict[var_name].shape)
                # update the local dict
                local_var_dict[var_name] = torch.as_tensor(global_weights[var_name])
            except Exception as e:
                raise ValueError(f"Convert weight from {var_name} failed with error: {e}")
    self.model.load_state_dict(local_var_dict)
    # keep a frozen copy of the global model
    model_global = deepcopy(self.model)
    for param in model_global.parameters():
        param.requires_grad = False
    # create dataloaders
    train_dataloader = DataLoader(self.train_dataset, shuffle=True, num_workers=0, batch_size=1)
    validation_dataloader = DataLoader(self.validation_dataset, shuffle=False, num_workers=0, batch_size=1)
    # train and validate
    self.trainer.fit(self.model, train_dataloader)
    edit_distance = self.trainer.validate(self.model, validation_dataloader)[0]["val_edit_distance"]
    self.n_iterations = len(train_dataloader)
    if edit_distance < self.edit_distance:
        self.edit_distance = edit_distance
        self.trainer.save_checkpoint("best_model.ckpt")
    else:
        self.trainer.save_checkpoint("model.ckpt")
    local_weights = self.model.state_dict()
    diff = {}
    for name in global_weights:
        if name not in local_weights:
            continue
        diff[name] = np.subtract(local_weights[name].cpu().numpy(), global_weights[name], dtype=np.float32)
        if np.any(np.isnan(diff[name])):
            self.stop_task(f"{name} weights became NaN...")
            return ReturnCode.EXECUTION_EXCEPTION
    # return an FLModel containing the model differences
    fl_model = FLModel(params_type=ParamsType.DIFF, params=diff)
    FLModelUtils.set_meta_prop(fl_model, FLMetaKey.NUM_STEPS_CURRENT_ROUND, self.n_iterations)
    self.info("Local epochs finished. Returning FLModel")
    return fl_model
```

Clients load the model

Here each client loads the global model from the server, trains it on its local data (in our case the assigned split), and then sends the resulting weight differences back to the server for aggregation.
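The `WEIGHT_DIFF` round trip can be sketched with made-up tensors: the client returns `local - global`, and the server adds the aggregated difference back onto the global weights (the names `global_w` and `local_w` are illustrative only):

```python
import numpy as np

# Client side: send the difference between locally trained and
# received global weights instead of the full weights.
global_w = {"layer.weight": np.array([1.0, 2.0], dtype=np.float32)}
local_w = {"layer.weight": np.array([1.5, 1.0], dtype=np.float32)}

diff = {k: local_w[k] - global_w[k] for k in global_w}

# Server side: apply the (here: single-client) aggregated diff
# to recover the updated global model.
new_global = {k: global_w[k] + diff[k] for k in global_w}
print(new_global["layer.weight"])
```

With a single client this round trip exactly reproduces the local weights; with several clients the diffs are averaged first.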

## Validation

```python
def validate(self, model: FLModel) -> Union[str, FLModel]:
    """Typical validation task pipeline:
    get global model weights (potentially with HE),
    validate on local data,
    return the validation edit distance.
    """
    self._create_datasets()
    # get validation information
    self.info(f"Client identity: {self.site_name}")
    # update local model weights with received weights
    global_weights = model.params
    # Before loading weights, tensors might need to be reshaped to support HE for secure aggregation.
    local_var_dict = self.model.state_dict()
    model_keys = global_weights.keys()
    n_loaded = 0
    for var_name in local_var_dict:
        if var_name in model_keys:
            weights = torch.as_tensor(global_weights[var_name])
            try:
                # update the local dict
                local_var_dict[var_name] = torch.reshape(weights, local_var_dict[var_name].shape)
                n_loaded += 1
            except Exception as e:
                raise ValueError(f"Convert weight from {var_name} failed with error: {e}")
    self.model.load_state_dict(local_var_dict)
    if n_loaded == 0:
        raise ValueError(f"No weights loaded for validation! Weights were searched in dict {local_var_dict}")
    # before_train_validate only, can extend to other validate types
    validate_type = FLModelUtils.get_meta_prop(model, FLMetaKey.VALIDATE_TYPE, ValidateType.MODEL_VALIDATE)
    model_owner = self.get_shareable_header(AppConstants.MODEL_OWNER)
    # get split for validation dataset
    validation_dataloader = DataLoader(self.validation_dataset, shuffle=False, num_workers=0, batch_size=1)
    edit_distance = self.trainer.validate(self.model, validation_dataloader)[0]["val_edit_distance"]
    return FLModel(metrics={"val_edit_distance": edit_distance})
```

Running a training with the simulator

First, the simulator connects the clients:

```
# Server logs during initialization
2024-04-19 09:40:25,539 - ClientManager - INFO - Client: New client site-1@10.10.2.11 joined. Sent token: e356583b-82af-4386-92f9-cdec7e195c3a.  Total clients: 1
2024-04-19 09:40:25,539 - FederatedClient - INFO - Successfully registered client:site-1 for project simulator_server. Token:e356583b-82af-4386-92f9-cdec7e195c3a SSID:
2024-04-19 09:40:25,540 - ClientManager - INFO - Client: New client site-2@10.10.2.11 joined. Sent token: 267c7798-af41-4494-b383-6076cdee22f7.  Total clients: 2
...
Token:649ce318-351f-45ea-9456-364fdefc10eb SSID:
2024-04-19 09:40:25,552 - ClientManager - INFO - Client: New client site-10@10.10.2.11 joined. Sent token: 59120408-6ba9-4e51-8125-c9f08efd3ae1.  Total clients: 10
```

Then the server initializes the training:

```
2024-04-19 09:40:46,105 - ServerRunner - INFO - [identity=simulator_server, run=simulate_job, wf=scatter_and_gather, peer=site-1, peer_run=simulate_job, task_name=train, task_id=31257a7c-f977-4be1-9a07-e06cdd34752c]: assigned task to client site-1: name=train, id=31257a7c-f977-4be1-9a07-e06cdd34752c
```

Then the client receives the signal to start training and runs its local epochs:

```
2024-04-19 09:40:47,290 - DonutLearner - INFO - [identity=site-1, run=simulate_job, peer=simulator_server, peer_run=simulate_job, task_name=train, task_id=31257a7c-f977-4be1-9a07-e06cdd34752c]: Training on split [0%:10%]
...
2024-04-19 09:41:26,860 - DonutLearner - INFO - [identity=site-1, run=simulate_job, peer=simulator_server, peer_run=simulate_job, task_name=train, task_id=31257a7c-f977-4be1-9a07-e06cdd34752c]: Current/Total Round: 1/5
```

After each local training run, the server receives the model update:

```
2024-04-19 09:44:08,222 - ScatterAndGather - INFO - [identity=simulator_server, run=simulate_job, wf=scatter_and_gather, peer=site-1, peer_run=simulate_job, peer_rc=OK, task_name=train, task_id=31257a7c-f977-4be1-9a07-e06cdd34752c]: Contribution from site-1 ACCEPTED by the aggregator at round 0.
```

The model is then validated on the client, and the resulting metric is sent to the server:

```
2024-04-23 20:56:39,751 - DonutLearner - INFO - [identity=site-1, run=simulate_job, peer=simulator_server, peer_run=simulate_job, task_name=train, task_id=7bdc014b-72c3-4cd7-b482-326813231e9d]: Edit Distance: 0.9060616493225098
```

The metric is then negated on the server side, since we use a metric where smaller is better:

```
2024-04-23 21:34:37,165 - IntimeModelSelector - INFO - [identity=simulator_server, run=simulate_job, wf=scatter_and_gather, peer=site-1, peer_run=simulate_job, peer_rc=OK, task_name=train, task_id=0e35ebab-17b0-413c-85ea-45cd15d9870e]: validation metric -0.9060616493225098 from client site-1
```

This cycle is repeated across all clients; the server aggregates all client contributions and starts the next round:

```
2024-04-23 22:12:47,170 - DXOAggregator - INFO - [identity=simulator_server, run=simulate_job, wf=scatter_and_gather]: aggregating 10 update(s) at round 1
2024-04-23 22:12:47,586 - ScatterAndGather - INFO - [identity=simulator_server, run=simulate_job, wf=scatter_and_gather]: End aggregation.
...
2024-04-23 22:12:51,282 - ScatterAndGather - INFO - [identity=simulator_server, run=simulate_job, wf=scatter_and_gather]: Round 1 finished.
2024-04-23 22:12:51,282 - ScatterAndGather - INFO - [identity=simulator_server, run=simulate_job, wf=scatter_and_gather]: Round 2 started.
```

These rounds repeat until the configured number is reached, and the final model is saved:

```
2024-04-23 23:43:03,965 - ScatterAndGather - INFO - [identity=simulator_server, run=simulate_job, wf=scatter_and_gather]: Round 4 finished.
2024-04-23 23:43:03,965 - ScatterAndGather - INFO - [identity=simulator_server, run=simulate_job, wf=scatter_and_gather]: Finished ScatterAndGather Training.
2024-04-23 23:43:03,967 - ServerRunner - INFO - [identity=simulator_server, run=simulate_job, wf=scatter_and_gather]: Workflow: scatter_and_gather finalizing ...
2024-04-23 23:43:04,463 - ServerRunner - INFO - [identity=simulator_server, run=simulate_job, wf=scatter_and_gather]: ABOUT_TO_END_RUN fired
2024-04-23 23:43:04,463 - ServerRunner - INFO - [identity=simulator_server, run=simulate_job, wf=scatter_and_gather]: Firing CHECK_END_RUN_READINESS ...
2024-04-23 23:43:04,464 - ServerRunner - INFO - [identity=simulator_server, run=simulate_job, wf=scatter_and_gather]: END_RUN fired
```
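The round structure visible in the logs can be condensed into a few lines. This is a simplified sketch of the scatter-and-gather loop, not the actual NVFlare implementation; the `client.train` interface is hypothetical, and diffs are averaged weighted by local step count, in the spirit of the in-time accumulating aggregator:

```python
import numpy as np

def scatter_and_gather(global_weights, clients, num_rounds):
    """Simplified sketch of the ScatterAndGather workflow.

    Each (hypothetical) client returns (weight_diff, num_steps); the
    server averages the diffs weighted by step count and applies the
    result to the global weights.
    """
    for _ in range(num_rounds):
        # scatter: every client trains on the current global weights
        updates = [client.train(global_weights) for client in clients]
        # gather: step-weighted average of the returned diffs
        total_steps = sum(steps for _, steps in updates)
        avg_diff = {
            name: sum(diff[name] * steps for diff, steps in updates) / total_steps
            for name in global_weights
        }
        # apply the aggregated diff to obtain the next global model
        global_weights = {k: v + avg_diff[k] for k, v in global_weights.items()}
    return global_weights
```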

Training results

The following results show that the edit distance converges toward a stable value with each round.

We also see that distributing the training data across more and more clients requires more training rounds to reach a stable result. This is because, with more clients, each individual client has less training data available.
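A back-of-the-envelope calculation illustrates the effect; the pool size of 1000 documents is a made-up number for illustration only:

```python
# With an even split, each client trains on total // n_clients samples,
# so adding clients shrinks the local training set proportionally.
def samples_per_client(total_samples: int, n_clients: int) -> int:
    return total_samples // n_clients

for n in (2, 5, 10):
    print(f"{n:>2} clients -> {samples_per_client(1000, n)} samples each")
```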

Overall, our tests show that the simulator is a very helpful tool for estimating the behavior of a planned training run with a larger number of clients, without the far more elaborate effort of a physically distributed training setup. This is only possible, however, if at least a small subset of the later distributed training data is available centrally for the simulator test. Further tests that matter in practice can also be carried out this way, for example cases where individual clients hold corrupt data, or where clients fail temporarily or become unreachable. What effects do these have on the training, and what is a suitable policy to counteract them?