-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathtrain_model.py
142 lines (129 loc) · 4.85 KB
/
train_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
import pandas as pd
import numpy as np
import torch
from tqdm import tqdm
from load_dataset import Dataset
import os
class MNIST_PAQ:
    """Federated-averaging (FedAvg-style) trainer for a small MNIST CNN.

    Each aggregate epoch, a random fraction ``r`` of the client shards is
    selected; each selected client trains a local copy of the model for
    ``local_epochs`` epochs starting from the current global weights, and the
    clients' weights (rounded to ``precision`` decimals — the quantization
    step) are averaged to form the new global model.
    """

    def __init__(self, filename="saved_models", number_of_clients=1, aggregate_epochs=10, local_epochs=5, precision=7, r=1.0):
        # filename: directory where aggregated models are checkpointed.
        # number_of_clients: how many equal shards client_generator produces.
        # aggregate_epochs: number of server (aggregation) rounds.
        # local_epochs: epochs each selected client trains per round.
        # precision: decimal places kept when weights are exported.
        # r: fraction of clients sampled each aggregate epoch.
        self.model = None
        self.criterion = torch.nn.CrossEntropyLoss()
        self.optimizer = None
        self.number_of_clients = number_of_clients
        self.aggregate_epochs = aggregate_epochs
        self.local_epochs = local_epochs
        self.precision = precision
        self.r = r
        self.filename = filename

    def define_model(self):
        """(Re)build the CNN; the model outputs raw logits.

        NOTE(fix): the original ended with ``torch.nn.Softmax(dim=1)``, but
        ``CrossEntropyLoss`` expects unnormalized logits (it applies
        log-softmax internally). The extra Softmax squashed gradients without
        changing argmax predictions, so it has been removed.
        """
        self.model = torch.nn.Sequential(
            torch.nn.Conv2d(1, 2, kernel_size=5),   # 28x28 -> 24x24
            torch.nn.ReLU(),
            torch.nn.Conv2d(2, 4, kernel_size=7),   # 24x24 -> 18x18
            torch.nn.ReLU(),
            torch.nn.Flatten(),                     # 4 * 18 * 18 = 1296
            torch.nn.Linear(1296, 512),
            torch.nn.ReLU(),
            torch.nn.Linear(512, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, 32),
            torch.nn.ReLU(),
            torch.nn.Linear(32, 10),
        )

    def get_weights(self, dtype=np.float32):
        """Export ``[weight, bias]`` pairs for every parameterized layer.

        Values are rounded to ``self.precision`` decimals (quantization).
        Returns an object-dtype ndarray: per-layer shapes are ragged, and
        modern NumPy raises on implicit ragged ``np.array(...)`` creation,
        which the original code relied on.
        """
        precision = self.precision
        weights = []
        for layer in self.model:
            # Only Conv2d/Linear layers carry weight+bias; activations don't.
            if hasattr(layer, "weight") and layer.weight is not None:
                weights.append([
                    np.around(layer.weight.detach().numpy().astype(dtype), decimals=precision),
                    np.around(layer.bias.detach().numpy().astype(dtype), decimals=precision),
                ])
        return np.array(weights, dtype=object)

    def set_weights(self, weights):
        """Load per-layer ``[weight, bias]`` pairs (in ``get_weights`` order)
        into every parameterized layer of ``self.model``.

        Accepts torch tensors (as produced by ``average_weights``) or numpy
        arrays — ``torch.as_tensor`` handles both without copying tensors.
        """
        index = 0
        for layer in self.model:
            if hasattr(layer, "weight") and layer.weight is not None:
                layer.weight = torch.nn.Parameter(torch.as_tensor(weights[index][0]))
                layer.bias = torch.nn.Parameter(torch.as_tensor(weights[index][1]))
                index += 1

    def average_weights(self, all_weights):
        """Element-wise mean of the clients' weights.

        all_weights: list of per-client outputs of ``get_weights``.
        Returns a list of ``[weight, bias]`` float32 torch-tensor pairs ready
        for ``set_weights``. Averaging is done layer-by-layer explicitly
        instead of stacking into one ndarray (which fails on ragged shapes
        under modern NumPy).
        """
        n = float(len(all_weights))
        averaged = []
        for layer_stack in zip(*all_weights):   # one tuple of [w, b] per client
            w = sum(client[0] for client in layer_stack) / n
            b = sum(client[1] for client in layer_stack) / n
            averaged.append([
                torch.from_numpy(w.astype(np.float32)),
                torch.from_numpy(b.astype(np.float32)),
            ])
        return averaged

    def client_generator(self, train_x, train_y):
        """Split the training tensors into ``number_of_clients`` equal shards.

        Trailing samples that don't fill a whole shard are dropped. Returns
        torch tensors shaped ``(clients, shard_size, *sample_shape)``.
        """
        number_of_clients = self.number_of_clients
        size = train_y.shape[0] // number_of_clients
        train_x, train_y = train_x.numpy(), train_y.numpy()
        train_x = np.array([train_x[i:i + size] for i in range(0, len(train_x) - len(train_x) % size, size)])
        train_y = np.array([train_y[i:i + size] for i in range(0, len(train_y) - len(train_y) % size, size)])
        return torch.from_numpy(train_x), torch.from_numpy(train_y)

    def single_client(self, dataset, weights, E):
        """Train one client for ``E`` local epochs.

        A fresh model is built on every call so the client always starts from
        the current global weights (or random init on the first round).

        dataset: ``{'x': ..., 'y': ...}`` — this client's shard.
        weights: aggregate weights to start from, or None on round one.
        E: number of local epochs.
        Returns ``(quantized weights, mean loss of the final epoch)``.
        """
        self.define_model()
        if weights is not None:
            self.set_weights(weights)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001)
        running_loss = 0.0  # defined even when E == 0
        for _ in range(E):
            running_loss = 0.0
            for batch_x, target in zip(dataset['x'], dataset['y']):
                output = self.model(batch_x)
                loss = self.criterion(output, target)
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                running_loss += loss.item()
            running_loss /= len(dataset['y'])
        return self.get_weights(), running_loss

    def test_aggregated_model(self, test_x, test_y, epoch):
        """Mean per-batch accuracy of the current model over batched test
        data; also checkpoints the full model for this aggregate epoch.
        Returns a 0-dim torch tensor (callers use ``.item()``)."""
        acc = 0
        with torch.no_grad():
            for batch_x, batch_y in zip(test_x, test_y):
                preds = torch.argmax(self.model(batch_x), dim=1)
                acc += torch.sum(preds == batch_y) / preds.shape[0]
        torch.save(self.model, os.path.join(self.filename, "model_epoch_" + str(epoch + 1) + ".pt"))
        return acc / test_x.shape[0]

    def train_aggregator(self, datasets, datasets_test):
        """Run the federated loop: sample ``r``·clients each round, train them
        locally, average their weights, then evaluate and checkpoint.

        datasets: ``{'x', 'y'}`` sharded by ``client_generator``.
        datasets_test: ``{'x', 'y'}`` batched test tensors.
        """
        # Portable directory creation; os.system('mkdir ...') was
        # shell-dependent and failed on re-runs when the dir existed.
        os.makedirs(self.filename, exist_ok=True)
        E = self.local_epochs
        aggregate_weights = None
        for epoch in range(self.aggregate_epochs):
            all_weights = []
            client = 0
            running_loss = 0
            # Sample a random fraction r of the client shards for this round.
            selections = np.arange(datasets['x'].shape[0])
            np.random.shuffle(selections)
            selections = selections[:int(self.r * datasets['x'].shape[0])]
            clients = tqdm(zip(datasets['x'][selections], datasets['y'][selections]), total=selections.shape[0])
            for dataset_x, dataset_y in clients:
                dataset = {'x': dataset_x, 'y': dataset_y}
                weights, loss = self.single_client(dataset, aggregate_weights, E)
                running_loss += loss
                all_weights.append(weights)
                client += 1
                clients.set_description(str({"Epoch": epoch + 1, "Loss": round(running_loss / client, 5)}))
                clients.refresh()
            aggregate_weights = self.average_weights(all_weights)
            self.set_weights(aggregate_weights)
            test_acc = self.test_aggregated_model(datasets_test['x'], datasets_test['y'], epoch)
            print("Test Accuracy:", round(test_acc.item(), 5))
            clients.close()
if __name__ == '__main__':
    # Federated-training hyper-parameters.
    clients_total = 328        # number of simulated clients / data shards
    rounds = 10                # aggregate (server) epochs
    epochs_per_client = 3      # local epochs per selected client
    participation = 0.5        # fraction of clients sampled each round
    checkpoint_dir = "saved_models"

    # Load MNIST, shard the training split across clients, and train.
    train_x, train_y, test_x, test_y = Dataset().load_csv()
    trainer = MNIST_PAQ(filename=checkpoint_dir,
                        r=participation,
                        number_of_clients=clients_total,
                        aggregate_epochs=rounds,
                        local_epochs=epochs_per_client)
    train_x, train_y = trainer.client_generator(train_x, train_y)
    trainer.train_aggregator({'x': train_x, 'y': train_y},
                             {'x': test_x, 'y': test_y})