-
Notifications
You must be signed in to change notification settings - Fork 28
/
Copy pathkmeans.py
37 lines (23 loc) · 908 Bytes
/
kmeans.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import torch
import numpy as np
from kmeans_pytorch.pairwise import pairwise_distance
def forgy(X, n_clusters):
    """Pick initial centers with the Forgy method: distinct random rows of ``X``.

    Args:
        X: NumPy array or torch tensor indexed along its first axis
            (one sample per row).
        n_clusters: number of centers to draw; must not exceed ``len(X)``.

    Returns:
        ``X[indices]`` for ``n_clusters`` distinct random row indices.
        Advanced indexing returns a copy, so callers may modify the result
        without touching ``X``.

    Raises:
        ValueError: if ``n_clusters > len(X)`` (raised by ``np.random.choice``
            when sampling without replacement).
    """
    # Sample WITHOUT replacement. Duplicate starting centers would collapse
    # into one cluster and leave the other empty, which yields a NaN mean
    # in the center update.
    indices = np.random.choice(len(X), n_clusters, replace=False)
    return X[indices]
def lloyd(X, n_clusters, device=0, tol=1e-4, max_iter=None):
    """Cluster the rows of ``X`` with Lloyd's k-means iterations in PyTorch.

    Args:
        X: NumPy array of samples, one per row; converted to float32.
        n_clusters: number of clusters ``k``.
        device: an int CUDA ordinal (original behavior: ``0`` -> ``cuda:0``),
            or any other value accepted by ``Tensor.to`` such as ``"cpu"``.
        tol: convergence threshold. Iteration stops once the square of the
            summed per-center Euclidean shifts drops below ``tol``.
        max_iter: optional cap on the number of iterations. ``None`` (the
            default) loops until the ``tol`` criterion is met.

    Returns:
        ``(labels, centers)`` as NumPy arrays. ``labels`` is the
        ``argmin`` cluster index of each row from the last assignment step.
        ``centers`` holds the cluster centers recomputed from those labels.
    """
    if isinstance(device, int):
        # Preserve the original ``.cuda(device)`` semantics for int ordinals.
        device = torch.device("cuda", device)
    X = torch.from_numpy(X).float().to(device)
    initial_state = forgy(X, n_clusters)
    iteration = 0
    while True:
        # NOTE(review): assumes pairwise_distance returns a (len(X), k)
        # distance matrix, as the argmin over dim=1 implies -- confirm.
        dis = pairwise_distance(X, initial_state)
        choice_cluster = torch.argmin(dis, dim=1)
        initial_state_pre = initial_state.clone()
        for index in range(n_clusters):
            # Boolean-mask selection always yields a 2-D (m, d) tensor, even
            # for m == 0 or m == 1 (nonzero().squeeze() gave a 0-d index
            # for a single member).
            members = X[choice_cluster == index]
            # The mean of zero rows is NaN. A NaN center makes center_shift
            # NaN, which never passes the `< tol` test, so the loop could run
            # forever. Keep the previous center for an empty cluster.
            if members.shape[0] > 0:
                initial_state[index] = members.mean(dim=0)
        center_shift = torch.sum(
            torch.sqrt(torch.sum((initial_state - initial_state_pre) ** 2, dim=1))
        )
        iteration += 1
        if center_shift ** 2 < tol:
            break
        if max_iter is not None and iteration >= max_iter:
            break
    return choice_cluster.cpu().numpy(), initial_state.cpu().numpy()