baseline.py
"""Main module for running FEMNIST experiments."""
import pathlib
from functools import partial
from typing import Type, Union
import flwr as fl
import hydra
import pandas as pd
import torch
from flwr.server.strategy import FedAvg
from omegaconf import DictConfig
from client import create_client
from dataset.dataset import (
create_federated_dataloaders,
)
from strategy import FedAvgSameClients
from utils import setup_seed, weighted_average
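
# Example invocation (a sketch, not part of the original file): the override keys
# below come from the cfg.* references in this module, but the config name is an
# assumption about what lives under conf/.
#
#   python baseline.py --config-name <your-config> \
#       training.num_rounds=10 training.num_clients_per_round=5 dataset.batch_size=32
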

# pylint: disable=too-many-locals
@hydra.main(config_path="conf", version_base=None)
def main(cfg: DictConfig):
    """Main function for running FEMNIST experiments."""
    # Ensure reproducibility
    setup_seed(cfg.random_seed)

    # Specify PyTorch device
    # pylint: disable=no-member
    device = torch.device(cfg.device)

    # Create datasets for federated learning
    trainloaders, valloaders, testloaders, _ = create_federated_dataloaders(
        cfg.dataset.distribution_type,
        cfg.dataset.dataset_fraction,
        cfg.dataset.batch_size,
        cfg.dataset.train_fraction,
        cfg.dataset.validation_fraction,
        cfg.dataset.test_fraction,
        cfg.random_seed,
    )
    # The total number of clients produced by sampling differs across random seeds
    total_n_clients = len(trainloaders)
    print(f"Total number of clients: {total_n_clients}")
    client_fnc = partial(
        create_client,
        trainloaders=trainloaders,
        valloaders=valloaders,
        testloaders=testloaders,
        device=device,
        num_epochs=cfg.training.epochs_per_round,
        learning_rate=cfg.training.learning_rate,
        # There exist other variants of the NIST dataset with a different number of classes
        dataset=cfg.dataset.dataset,
        num_classes=cfg.dataset.num_classes,
        num_batches=cfg.training.batches_per_round,
    )
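    # Note: Flower's start_simulation calls this partial with the client id (cid) as the
    # only remaining positional argument, so create_client is assumed to accept it.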
    flwr_strategy: Union[Type[FedAvg], Type[FedAvgSameClients]]
    if cfg.training.same_train_test_clients:
        # Assign a reference to the class; it is instantiated below
        flwr_strategy = FedAvgSameClients
    else:
        flwr_strategy = FedAvg
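    # Judging by its name and the same_train_test_clients flag, FedAvgSameClients
    # (defined in strategy.py) presumably reuses the same sampled clients for both
    # fit and evaluate within a round, whereas plain FedAvg samples them independently.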
    strategy = flwr_strategy(
        min_available_clients=total_n_clients,
        # Keep fraction_fit low (non-zero only for consistency with fraction_evaluate)
        # and let min_fit_clients determine the number of clients sampled per round:
        # Flower samples max(int(fraction_fit * available_clients), min_fit_clients).
        fraction_fit=0.001,
        min_fit_clients=cfg.training.num_clients_per_round,
        fraction_evaluate=0.001,
        min_evaluate_clients=cfg.training.num_clients_per_round,
        # evaluate_fn=None,  # Left unset; it would be responsible for centralized evaluation
        fit_metrics_aggregation_fn=weighted_average,
        evaluate_metrics_aggregation_fn=weighted_average,
    )
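    # Worked example (hypothetical numbers): with 3500 available clients and
    # num_clients_per_round=5, int(0.001 * 3500) = 3, so max(3, 5) = 5 clients
    # are sampled for fit (and likewise for evaluate) in each round.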
    client_resources = None
    if device.type == "cuda":
        client_resources = {"num_gpus": 2.0}

    # Start simulation
    history = fl.simulation.start_simulation(
        client_fn=client_fnc,  # type: ignore
        num_clients=total_n_clients,  # total number of clients in a simulation
        config=fl.server.ServerConfig(num_rounds=cfg.training.num_rounds),
        strategy=strategy,
        client_resources=client_resources,
    )
    # Save the results
    results_dir_path = pathlib.Path(cfg.training.results_dir_path)
    if not results_dir_path.exists():
        results_dir_path.mkdir(parents=True)

    distributed_history_dict = {}
    for metric, round_value_tuple_list in history.metrics_distributed.items():
        distributed_history_dict["distributed_test_" + metric] = [
            val for _, val in round_value_tuple_list
        ]
    for metric, round_value_tuple_list in history.metrics_distributed_fit.items():  # type: ignore
        distributed_history_dict["distributed_" + metric] = [
            val for _, val in round_value_tuple_list
        ]
    distributed_history_dict["distributed_test_loss"] = [
        val for _, val in history.losses_distributed
    ]

    results_df = pd.DataFrame.from_dict(distributed_history_dict)
    results_df.to_csv(results_dir_path / "history.csv")
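    # The resulting history.csv has one row per round; column names follow the
    # prefixes above (e.g. "distributed_test_accuracy" if the clients report an
    # "accuracy" metric, which is an assumption about client.py).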
if __name__ == "__main__":
# pylint: disable=no-value-for-parameter
main()