Skip to content

Commit

Permalink
add serial running example on MNIST
Browse files Browse the repository at this point in the history
  • Loading branch information
Zilinghan committed Mar 13, 2024
1 parent 85091b9 commit abb00d4
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 3 deletions.
2 changes: 1 addition & 1 deletion examples/run_client_1.py → examples/grpc/run_client_1.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
init_global_model = client_comm.get_global_model(init_model=True)
client_agent.load_parameters(init_global_model)

# Send the number of load data to the server
# Send the number of local data to the server
sample_size = client_agent.get_sample_size()
print(f"Sample size: {sample_size}")
client_comm.invoke_custom_action(action='set_sample_size', sample_size=sample_size)
Expand Down
2 changes: 1 addition & 1 deletion examples/run_client_2.py → examples/grpc/run_client_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
init_global_model = client_comm.get_global_model(init_model=True)
client_agent.load_parameters(init_global_model)

# Send the number of load data to the server
# Send the number of local data to the server
sample_size = client_agent.get_sample_size()
print(f"Sample size: {sample_size}")
client_comm.invoke_custom_action(action='set_sample_size', sample_size=sample_size)
Expand Down
File renamed without changes.
84 changes: 84 additions & 0 deletions examples/serial/run_serial.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
"""
Serial simulation of Federated learning.
It should be noted that only synchronous FL can be simulated in this way.
"""
import argparse
from omegaconf import OmegaConf
from appfl.agent import APPFLClientAgent, APPFLServerAgent

argparser = argparse.ArgumentParser()
argparser.add_argument(
"--server_config",
type=str,
default="config/server_fedavg.yaml",
)
argparser.add_argument(
"--client_config",
type=str,
default="config/client_1.yaml",
)
argparser.add_argument(
"--num_clients",
type=int,
default=10,
)
args = argparser.parse_args()

# Load server agent configurations and set the number of clients
server_agent_config = OmegaConf.load(args.server_config)
server_agent_config.server_configs.scheduler_kwargs.num_clients = args.num_clients
if hasattr(server_agent_config.server_configs.aggregator_kwargs, "num_clients"):
server_agent_config.server_configs.aggregator_kwargs.num_clients = args.num_clients

# Create server agent
server_agent = APPFLServerAgent(server_agent_config=server_agent_config)

# Load base client configurations and set corresponding fields for different clients
client_agent_configs = [OmegaConf.load(args.client_config) for _ in range(args.num_clients)]
for i in range(args.num_clients):
client_agent_configs[i].train_configs.logging_id = f'Client{i+1}'
client_agent_configs[i].data_configs.dataset_kwargs.num_clients = args.num_clients
client_agent_configs[i].data_configs.dataset_kwargs.client_id = i
client_agent_configs[i].data_configs.dataset_kwargs.visualization = True if i == 0 else False

# Load client agents
client_agents = [
APPFLClientAgent(client_agent_config=client_agent_configs[i])
for i in range(args.num_clients)
]

# Get additional client configurations from the server
client_config_from_server = server_agent.get_client_configs()
for client_agent in client_agents:
client_agent.load_config(client_config_from_server)

# Load initial global model from the server
init_global_model = server_agent.get_parameters(serial_run=True)
for client_agent in client_agents:
client_agent.load_parameters(init_global_model)

# [Optional] Set number of local data to the server
for i in range(args.num_clients):
sample_size = client_agents[i].get_sample_size()
server_agent.set_sample_size(
client_id=client_agents[i].get_id(),
sample_size=sample_size
)

for i in range(5):
new_global_models = []
for client_agent in client_agents:
# Client local training
client_agent.train()
local_model = client_agent.get_parameters()
# "Send" local model to server and get a Future object for the new global model
# The Future object will be resolved when the server receives local models from all clients
new_global_model_future = server_agent.global_update(
client_id=client_agent.get_id(),
local_model=local_model,
blocking=False,
)
new_global_models.append(new_global_model_future)
# Load the new global model from the server
for client_agent, new_global_model_future in zip(client_agents, new_global_models):
client_agent.load_parameters(new_global_model_future.result())
6 changes: 5 additions & 1 deletion src/appfl/scheduler/base_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,11 @@ def get_parameters(self, **kwargs) -> Union[Future, Dict, OrderedDict, Tuple[Uni
:params `kwargs['init_model']` (default is `True`): whether to get the initial global model or not
:return the global model or a `Future` object for the global model
"""
if kwargs.get("init_model", True) and self.scheduler_configs.get("same_init_model", True):
if (
kwargs.get("init_model", True)
and self.scheduler_configs.get("same_init_model", True)
and (not kwargs.get("serial_run", False))
):
if not hasattr(self, "init_model_requests"):
self.init_model_requests = 0
self.init_model_futures = []
Expand Down

0 comments on commit abb00d4

Please sign in to comment.