# Federated Training of the Donut Model for Document Recognition with the Simulator
## General
The following example shows a distributed training run of the Donut model with the NVIDIA FLARE simulator (https://nvflare-22-docs-update.readthedocs.io/en/latest/user_guide/fl_simulator.html), including code, configuration, and outputs. Training runs were carried out with 2, 5, and 10 clients.
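A run like this can be started either with the `nvflare simulator` command-line tool or via the simulator's Python API, as documented on the page linked above. A minimal sketch for the 10-client case; the job folder and workspace paths are placeholders:

```python
# Minimal sketch: launching the 10-client simulation via the FL simulator's
# Python API. The job folder and workspace paths are placeholders.
from nvflare.private.fed.app.simulator.simulator_runner import SimulatorRunner

simulator = SimulatorRunner(
    job_folder="jobs/donut_fedavg",     # job containing the configs shown below
    workspace="/tmp/donut_simulation",  # simulator output workspace
    n_clients=10,                       # the simulator names the clients site-1 ... site-10
    threads=10,                         # run all clients in parallel
)
run_status = simulator.run()
```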
## Configuration
```json
# config_fed_client.json
{
  "format_version": 2,
  "executors": [
    {
      "tasks": ["train", "validate"],
      "executor": {
        "name": "LearnerExecutor",
        "path": "nvflare.app_common.executors.model_learner_executor.ModelLearnerExecutor",
        "args": {
          "learner_id": "donut_learner"
        }
      }
    }
  ],
  "task_result_filters": [],
  "task_data_filters": [],
  "components": [
    {
      "id": "donut_learner",
      "path": "donut_learner.DonutLearner",
      "args": {
        "lr": 3e-4,
        "aggregation_epochs": 2,
        "split_dict": {
          "site-1": "[0%:10%]",
          "site-2": "[10%:20%]",
          "site-3": "[20%:30%]",
          "site-4": "[30%:40%]",
          "site-5": "[40%:50%]",
          "site-6": "[50%:60%]",
          "site-7": "[60%:70%]",
          "site-8": "[70%:80%]",
          "site-9": "[80%:90%]",
          "site-10": "[90%:100%]"
        }
      }
    }
  ]
}
```
This configures our client task for training and validation. The training data is split so that client `site-1` uses the first 10%, `site-2` the second 10%, and so on up to `site-10`.
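The split strings follow the Hugging Face `datasets` slicing syntax, which lets each client load only its share of a centrally available dataset. A hypothetical sketch of the `_create_datasets` helper used by the learner below (the dataset name and the omitted preprocessing into Donut model inputs are assumptions for illustration):

```python
# Hypothetical sketch of DonutLearner._create_datasets: each site loads only
# its slice of the training data via the split string from split_dict.
# The dataset name and preprocessing are assumptions for illustration.
from datasets import load_dataset

def _create_datasets(self):
    if self.train_dataset is None:
        # e.g. self.split == "[0%:10%]" -> split="train[0%:10%]"
        self.train_dataset = load_dataset("naver-clova-ix/cord-v2", split=f"train{self.split}")
    if self.validation_dataset is None:
        # every site validates on the same validation split
        self.validation_dataset = load_dataset("naver-clova-ix/cord-v2", split="validation")
```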
```json
# config_fed_server.json
{
  "format_version": 2,
  "task_data_filters": [],
  "task_result_filters": [],
  "model_class_path": "donut.DonutModelPLModule",
  "workflows": [
    {
      "id": "scatter_and_gather",
      "path": "nvflare.app_common.workflows.scatter_and_gather.ScatterAndGather",
      "args": {
        "min_clients": 10,
        "num_rounds": 5,
        "start_round": 0,
        "wait_time_after_min_received": 0,
        "aggregator_id": "aggregator",
        "persistor_id": "persistor",
        "shareable_generator_id": "shareable_generator",
        "train_task_name": "train",
        "train_timeout": 0
      }
    }
  ],
  "components": [
    {
      "id": "persistor",
      "path": "nvflare.app_opt.pt.file_model_persistor.PTFileModelPersistor",
      "args": {
        "model": {
          "path": "{model_class_path}"
        }
      }
    },
    {
      "id": "shareable_generator",
      "path": "nvflare.app_common.shareablegenerators.full_model_shareable_generator.FullModelShareableGenerator",
      "args": {}
    },
    {
      "id": "aggregator",
      "path": "nvflare.app_common.aggregators.intime_accumulate_model_aggregator.InTimeAccumulateWeightedAggregator",
      "args": {
        "expected_data_kind": "WEIGHT_DIFF",
        "aggregation_weights": {
          "site-1": 1.0,
          "site-2": 1.0,
          "site-3": 1.0,
          "site-4": 1.0,
          "site-5": 1.0,
          "site-6": 1.0,
          "site-7": 1.0,
          "site-8": 1.0,
          "site-9": 1.0,
          "site-10": 1.0
        }
      }
    },
    {
      "id": "model_selector",
      "path": "nvflare.app_common.widgets.intime_model_selector.IntimeModelSelector",
      "args": {
        "key_metric": "val_edit_distance",
        "negate_key_metric": true
      }
    },
    {
      "id": "receiver",
      "path": "nvflare.app_opt.tracking.tb.tb_receiver.TBAnalyticsReceiver",
      "args": {
        "events": [
          "fed.analytix_log_stats"
        ]
      }
    }
  ]
}
```
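On the server side, `ScatterAndGather` runs five rounds across all ten clients. Because `expected_data_kind` is `"WEIGHT_DIFF"`, each client returns weight differences rather than full weights, and the aggregator combines them using the configured `aggregation_weights` together with the number of local steps each client reports. Conceptually, the aggregation reduces to the following simplified sketch (an illustration, not NVFlare's actual implementation):

```python
# Simplified sketch of the weighted aggregation of WEIGHT_DIFF contributions;
# an illustration, not NVFlare's actual implementation.
import numpy as np

def aggregate_diffs(site_diffs, site_weights):
    """site_diffs: {site: {param_name: np.ndarray}}, site_weights: {site: float}."""
    total_weight = sum(site_weights[site] for site in site_diffs)
    aggregated = {}
    for site, diff in site_diffs.items():
        factor = site_weights[site] / total_weight
        for name, delta in diff.items():
            aggregated[name] = aggregated.get(name, 0.0) + factor * delta
    return aggregated  # added onto the current global weights for the next round
```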
## Site initialization
```python
from copy import deepcopy
from typing import Union

import numpy as np
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader

from nvflare.apis.fl_constant import FLMetaKey, ReturnCode
from nvflare.app_common.abstract.fl_model import FLModel, ParamsType
from nvflare.app_common.abstract.model_learner import ModelLearner
from nvflare.app_common.app_constant import AppConstants, ValidateType
from nvflare.app_common.utils.fl_model_utils import FLModelUtils
from nvflare.app_opt.pt.model_persistence_format_manager import PTModelPersistenceFormatManager

from donut import DonutModelPLModule


class DonutLearner(ModelLearner):
    def __init__(
        self,
        aggregation_epochs=1,
        lr=1e-2,
        batch_size=1,
        split_dict={"site-1": "[:50%]", "site-2": "[50%:]"},
    ):
        super().__init__()
        self.aggregation_epochs = aggregation_epochs
        self.lr = lr
        self.batch_size = batch_size
        self.split_dict = split_dict
        self.split = None
        # best (lowest) validation edit distance seen so far; lower is better
        self.edit_distance = float("inf")
        self.model = DonutModelPLModule(lr=self.lr)
        self.default_train_conf = None
        self.persistence_manager = None
        self.n_iterations = None
        self.train_dataset = None
        self.validation_dataset = None
        self.local_model_file = None
        self.best_local_model_file = None
        self.trainer = pl.Trainer(
            devices=1,
            max_epochs=self.aggregation_epochs,
            precision="16-mixed",
            num_sanity_val_steps=0,
            accumulate_grad_batches=1,
        )

    def initialize(self):
        # get the data split assigned to this site
        self.info(f"Initializing {self.site_name} for training")
        self.split = self.split_dict[self.site_name]
        self.info(f"Training on split {self.split}")
        # persistence manager setup for saving/loading local checkpoints
        self.default_train_conf = {"train": {"model": type(self.model).__name__}}
        self.persistence_manager = PTModelPersistenceFormatManager(
            data=self.model.state_dict(),
            default_train_conf=self.default_train_conf,
        )
```
## Training
```python
def train(self, model: FLModel) -> Union[str, FLModel]:
    self.info(f"Current/Total Round: {self.current_round + 1}/{self.total_rounds}")
    self.info(f"Client identity: {self.site_name}")
    self._create_datasets()

    # update local model weights with the received global weights;
    # before loading, tensors might need to be reshaped to support HE for secure aggregation
    global_weights = model.params
    local_var_dict = self.model.state_dict()
    model_keys = global_weights.keys()
    for var_name in local_var_dict.keys():
        if var_name in model_keys:
            weights = global_weights[var_name]
            try:
                # reshape global weights to compute the difference later on
                global_weights[var_name] = np.reshape(weights, local_var_dict[var_name].shape)
                # update the local dict
                local_var_dict[var_name] = torch.as_tensor(global_weights[var_name])
            except Exception as e:
                raise ValueError(f"Converting weight {var_name} failed with error: {e}")
    self.model.load_state_dict(local_var_dict)

    # keep a frozen copy of the global model
    model_global = deepcopy(self.model)
    for param in model_global.parameters():
        param.requires_grad = False

    # create dataloaders
    train_dataloader = DataLoader(self.train_dataset, shuffle=True, num_workers=0, batch_size=1)
    validation_dataloader = DataLoader(self.validation_dataset, shuffle=False, num_workers=0, batch_size=1)

    # train and validate
    self.trainer.fit(self.model, train_dataloader)
    edit_distance = self.trainer.validate(self.model, validation_dataloader)[0]["val_edit_distance"]
    self.n_iterations = len(train_dataloader)
    if edit_distance < self.edit_distance:
        self.edit_distance = edit_distance
        self.trainer.save_checkpoint("best_model.ckpt")
    else:
        self.trainer.save_checkpoint("model.ckpt")

    # compute the difference between local and global weights
    local_weights = self.model.state_dict()
    diff = {}
    for name in global_weights:
        if name not in local_weights:
            continue
        diff[name] = np.subtract(local_weights[name].cpu().numpy(), global_weights[name], dtype=np.float32)
        if np.any(np.isnan(diff[name])):
            self.stop_task(f"{name} weights became NaN...")
            return ReturnCode.EXECUTION_EXCEPTION

    # return an FLModel containing the model differences
    fl_model = FLModel(params_type=ParamsType.DIFF, params=diff)
    FLModelUtils.set_meta_prop(fl_model, FLMetaKey.NUM_STEPS_CURRENT_ROUND, self.n_iterations)
    self.info("Local epochs finished. Returning FLModel")
    return fl_model
```
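The returned `FLModel` carries only parameter differences (`ParamsType.DIFF`); the `NUM_STEPS_CURRENT_ROUND` meta property lets the server-side aggregator weight each contribution by the amount of local training it represents. After aggregation, the `FullModelShareableGenerator` turns the combined difference back into the next global model by adding it onto the current global weights; schematically (a simplified sketch, not the actual NVFlare code):

```python
# Schematic of how the aggregated DIFF updates the global model between rounds;
# the real logic lives in FullModelShareableGenerator, this is only a sketch.
def apply_aggregated_diff(global_weights, aggregated_diff):
    for name, delta in aggregated_diff.items():
        global_weights[name] = global_weights[name] + delta
    return global_weights
```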
## Validation
```python
def validate(self, model: FLModel) -> Union[str, FLModel]:
    """Typical validation task pipeline:
    1. Get global model weights (potentially with HE)
    2. Validate on local data
    3. Return the validation edit distance
    """
    self._create_datasets()

    # get validation information
    self.info(f"Client identity: {self.site_name}")

    # update local model weights with the received weights;
    # before loading, tensors might need to be reshaped to support HE for secure aggregation
    global_weights = model.params
    local_var_dict = self.model.state_dict()
    model_keys = global_weights.keys()
    n_loaded = 0
    for var_name in local_var_dict:
        if var_name in model_keys:
            weights = torch.as_tensor(global_weights[var_name])
            try:
                # update the local dict
                local_var_dict[var_name] = torch.reshape(weights, local_var_dict[var_name].shape)
                n_loaded += 1
            except Exception as e:
                raise ValueError(f"Converting weight {var_name} failed with error: {e}")
    self.model.load_state_dict(local_var_dict)
    if n_loaded == 0:
        raise ValueError(f"No weights loaded for validation! Weights were searched in dict {local_var_dict}")

    # before_train_validate only; can be extended to other validate types
    validate_type = FLModelUtils.get_meta_prop(model, FLMetaKey.VALIDATE_TYPE, ValidateType.MODEL_VALIDATE)
    model_owner = self.get_shareable_header(AppConstants.MODEL_OWNER)
    self.info(f"Validating model from {model_owner} (validate type: {validate_type})")

    # validate on the local validation split
    validation_dataloader = DataLoader(self.validation_dataset, shuffle=False, num_workers=0, batch_size=1)
    edit_distance = self.trainer.validate(self.model, validation_dataloader)[0]["val_edit_distance"]
    return FLModel(metrics={"val_edit_distance": edit_distance})
```
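The `val_edit_distance` metric itself is produced inside `DonutModelPLModule`, which is not shown in this example. Donut-style training pipelines typically score the normalized Levenshtein distance between the generated token sequence and the ground truth; a hedged sketch of such a metric:

```python
# Hedged sketch of a normalized edit distance as it is typically computed in
# Donut validation steps; DonutModelPLModule's internals are not shown here.
from nltk import edit_distance

def normalized_edit_distance(prediction: str, answer: str) -> float:
    # 0.0 means an exact match, values near 1.0 mean completely different sequences
    return edit_distance(prediction, answer) / max(len(prediction), len(answer), 1)
```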
## Logs
```
# Server logs during initialization
2024-04-19 09:40:25,539 - ClientManager - INFO - Client: New client site-1@10.10.2.11 joined. Sent token: e356583b-82af-4386-92f9-cdec7e195c3a. Total clients: 1
2024-04-19 09:40:25,539 - FederatedClient - INFO - Successfully registered client:site-1 for project simulator_server. Token:e356583b-82af-4386-92f9-cdec7e195c3a SSID:
2024-04-19 09:40:25,540 - ClientManager - INFO - Client: New client site-2@10.10.2.11 joined. Sent token: 267c7798-af41-4494-b383-6076cdee22f7. Total clients: 2
... Token:649ce318-351f-45ea-9456-364fdefc10eb SSID:
2024-04-19 09:40:25,552 - ClientManager - INFO - Client: New client site-10@10.10.2.11 joined. Sent token: 59120408-6ba9-4e51-8125-c9f08efd3ae1. Total clients: 10
...
2024-04-19 09:40:46,105 - ServerRunner - INFO - [identity=simulator_server, run=simulate_job, wf=scatter_and_gather, peer=site-1, peer_run=simulate_job, task_name=train, task_id=31257a7c-f977-4be1-9a07-e06cdd34752c]: assigned task to client site-1: name=train, id=31257a7c-f977-4be1-9a07-e06cdd34752c
...
2024-04-19 09:40:47,290 - DonutLearner - INFO - [identity=site-1, run=simulate_job, peer=simulator_server, peer_run=simulate_job, task_name=train, task_id=31257a7c-f977-4be1-9a07-e06cdd34752c]: Training on split [0%:10%]
...
2024-04-19 09:41:26,860 - DonutLearner - INFO - [identity=site-1, run=simulate_job, peer=simulator_server, peer_run=simulate_job, task_name=train, task_id=31257a7c-f977-4be1-9a07-e06cdd34752c]: Current/Total Round: 1/5
...
2024-04-19 09:44:08,222 - ScatterAndGather - INFO - [identity=simulator_server, run=simulate_job, wf=scatter_and_gather, peer=site-1, peer_run=simulate_job, peer_rc=OK, task_name=train, task_id=31257a7c-f977-4be1-9a07-e06cdd34752c]: Contribution from site-1 ACCEPTED by the aggregator at round 0.
...
2024-04-23 20:56:39,751 - DonutLearner - INFO - [identity=site-1, run=simulate_job, peer=simulator_server, peer_run=simulate_job, task_name=train, task_id=7bdc014b-72c3-4cd7-b482-326813231e9d]: Edit Distance: 0.9060616493225098
...
2024-04-23 21:34:37,165 - IntimeModelSelector - INFO - [identity=simulator_server, run=simulate_job, wf=scatter_and_gather, peer=site-1, peer_run=simulate_job, peer_rc=OK, task_name=train, task_id=0e35ebab-17b0-413c-85ea-45cd15d9870e]: validation metric -0.9060616493225098 from client site-1
...
2024-04-23 22:12:47,170 - DXOAggregator - INFO - [identity=simulator_server, run=simulate_job, wf=scatter_and_gather]: aggregating 10 update(s) at round 1
2024-04-23 22:12:47,586 - ScatterAndGather - INFO - [identity=simulator_server, run=simulate_job, wf=scatter_and_gather]: End aggregation.
...
2024-04-23 22:12:51,282 - ScatterAndGather - INFO - [identity=simulator_server, run=simulate_job, wf=scatter_and_gather]: Round 1 finished.
2024-04-23 22:12:51,282 - ScatterAndGather - INFO - [identity=simulator_server, run=simulate_job, wf=scatter_and_gather]: Round 2 started.
...
2024-04-23 23:43:03,965 - ScatterAndGather - INFO - [identity=simulator_server, run=simulate_job, wf=scatter_and_gather]: Round 4 finished.
2024-04-23 23:43:03,965 - ScatterAndGather - INFO - [identity=simulator_server, run=simulate_job, wf=scatter_and_gather]: Finished ScatterAndGather Training.
2024-04-23 23:43:03,967 - ServerRunner - INFO - [identity=simulator_server, run=simulate_job, wf=scatter_and_gather]: Workflow: scatter_and_gather finalizing ...
2024-04-23 23:43:04,463 - ServerRunner - INFO - [identity=simulator_server, run=simulate_job, wf=scatter_and_gather]: ABOUT_TO_END_RUN fired
2024-04-23 23:43:04,463 - ServerRunner - INFO - [identity=simulator_server, run=simulate_job, wf=scatter_and_gather]: Firing CHECK_END_RUN_READINESS ...
2024-04-23 23:43:04,464 - ServerRunner - INFO - [identity=simulator_server, run=simulate_job, wf=scatter_and_gather]: END_RUN fired
...
```
## Training results
The following results show that the edit distance converges toward a stable value with each additional epoch.
We also see that distributing the training data across more and more clients requires more training epochs to reach a stable result. This is because with more clients, each individual client has less training data available.
Overall, our tests show that the simulator is a very helpful tool for estimating how a planned training run will behave with a larger number of clients, without the far greater effort of a physically distributed training setup. This is only possible, however, if at least a small subset of the eventual distributed training data is centrally available for the simulator test. Other scenarios that matter in practice can also be examined this way, for example individual clients holding corrupt data, or clients failing or becoming temporarily unreachable: what impact do these cases have on the training, and what is a suitable policy to counteract them?