-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathBCG_lib.py
144 lines (113 loc) · 4.33 KB
/
BCG_lib.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import torch
import numpy as np
class BCG():
    """Bayesian conjugate gradient (BayesCG) solver for the linear system A x = b.

    Given a Gaussian prior N(prior_mean, prior_cov) on the solution, each
    iteration conditions on one projection of the residual, producing a
    posterior mean ``x_m`` and low-rank posterior-covariance factors.
    """

    def __init__(self, A, b, prior_mean, prior_cov, eps, m_max, batch_directions=False):
        """Store the problem and prior.

        A                : (d, d) system matrix.
        b                : (d, 1) right-hand side.
        prior_mean       : (d, 1) prior mean on the solution (not mutated by bcg()).
        prior_cov        : (d, d) prior covariance.
        eps              : tolerance on the posterior scale sigma_m (stopping rule).
        m_max            : maximum number of iterations.
        batch_directions : if True, re-orthogonalise each new search direction
                           against all previous ones instead of using the
                           two-term CG recurrence.
        """
        self.A = A
        self.b = b
        self.x0 = prior_mean
        self.sigma0 = prior_cov
        self.eps = eps
        self.max = m_max
        self.batch_directions = batch_directions

    def bcg(self, x_true):
        """Run BayesCG; ``x_true`` is used only to record the relative error.

        Returns (x_m, sigmaF, nu_m/m, rel_error, rel_trace) where sigmaF holds
        the normalised posterior-covariance downdate factors column by column.
        """
        A = self.A
        b = self.b
        x0 = self.x0
        sigma0 = self.sigma0
        eps = self.eps
        m_max = self.max
        batch_directions = self.batch_directions
        d = b.shape[0]
        r_m = b - A.dot(x0)
        r_m_dot_r_m = r_m.T.dot(r_m)
        # BUG FIX: copy both. Previously s_m aliased r_m (so the in-place
        # residual update corrupted the search direction) and x_m aliased x0
        # (so the in-place solution update mutated the caller's prior_mean).
        s_m = r_m.copy()
        x_m = x0.copy()
        sigmaF = np.zeros(sigma0.shape)
        search_directions = np.zeros(sigma0.shape)
        A_sigma_A_search_directions = np.zeros(sigma0.shape)
        nu_m = 0
        m = 0
        rel_error = np.zeros(m_max)
        rel_trace = np.zeros(m_max)
        search_normalisations = np.zeros(m_max)
        while True:
            sigma_At_s = np.dot(sigma0, np.dot(A.T, s_m))      # Sigma0 A^T s
            A_sigma_A_s = np.dot(A, sigma_At_s)                # A Sigma0 A^T s
            E_2 = np.dot(s_m.T, A_sigma_A_s)                   # ||s||^2 in the A Sigma0 A^T inner product
            alpha_m = r_m_dot_r_m / E_2
            x_m += alpha_m * sigma_At_s
            r_m -= alpha_m * A_sigma_A_s
            if batch_directions:
                # BUG FIX: d instead of the hard-coded 100 (and s_m is now the
                # true direction, no longer clobbered by the residual update).
                search_directions[:, m] = s_m.reshape(d,)
                A_sigma_A_search_directions[:, m] = A_sigma_A_s.reshape(d,)
                search_normalisations[m] = 1. / np.sqrt(E_2)
            rel_error[m] = np.linalg.norm(x_true - x_m) / np.linalg.norm(x_true)
            nu_m += r_m_dot_r_m * r_m_dot_r_m / E_2
            # Posterior scale estimate; the (d - 1 - m) factor makes this hit
            # zero at the d-th iteration, forcing termination.
            sigma_m = np.sqrt((d - 1 - m) * nu_m / (m + 1))
            prev_r_m_dot_r_m = r_m_dot_r_m
            r_m_dot_r_m = np.dot(r_m.T, r_m)
            E = np.sqrt(E_2)
            sigmaF[:, m] = (sigma_At_s / E).reshape(d,)
            m += 1
            Sigma_m = sigma0 - sigmaF[:, :m].dot(sigmaF[:, :m].T)
            rel_trace[m - 1] = np.trace(Sigma_m) / np.trace(sigma0)
            if batch_directions:
                # NOTE(review): the sign/normalisation of this
                # re-orthogonalisation looks unconventional (Gram-Schmidt
                # would subtract the projection); preserved as written —
                # confirm against the BayesCG reference.
                s_m = r_m.copy()
                for i in range(m):
                    coeff = r_m.T.dot(A_sigma_A_search_directions[:, i].reshape(d, 1)) * search_normalisations[i]
                    s_m += coeff * search_directions[:, i].reshape(d, 1)
            else:
                beta_m = r_m_dot_r_m / prev_r_m_dot_r_m
                s_m = r_m + beta_m * s_m
            if sigma_m < eps:
                break
            if m == m_max or m == d:
                break
        return x_m, sigmaF, nu_m / m, rel_error, rel_trace
def conjugate_grad(A, b, maxiter, x_true):
    """Classical conjugate gradient for A x = b from a zero initial guess.

    A       : (n, n) symmetric positive-definite matrix.
    b       : (n, 1) right-hand side.
    maxiter : number of iterations to run (expected <= n).
    x_true  : (n, 1) reference solution, used only for error tracking.

    Returns (x, rel_error, trace_error): the iterate, the per-iteration
    relative error ||x_true - x|| / ||x_true||, and the relative trace of the
    implied posterior covariance I - F F^T built from the normalised search
    directions.
    """
    n = len(b)
    x = np.zeros([n, 1])
    sigmaF = np.zeros(A.shape)
    rel_error = np.zeros(n)
    trace_error = np.zeros(n)
    r = np.dot(A, x) - b
    s = -r
    r_k_norm = np.dot(r.T, r)
    for i in range(maxiter):
        As = np.dot(A, s)
        sAs = np.dot(s.T, As)          # hoisted: used for both alpha and the normalisation
        alpha = r_k_norm / sAs
        x += alpha * s
        rel_error[i] = np.linalg.norm(x_true - x) / np.linalg.norm(x_true)
        # BUG FIX: use n instead of the hard-coded dimension 100 (reshape,
        # identity, and the trace normaliser, which was always 100).
        sigmaF[:, i] = s.reshape(n,) / np.sqrt(sAs)
        Sigma_m = np.eye(n) - sigmaF[:, :i + 1].dot(sigmaF[:, :i + 1].T)
        trace_error[i] = np.trace(Sigma_m) / n
        r += alpha * As
        r_kplus1_norm = np.dot(r.T, r)
        beta = r_kplus1_norm / r_k_norm
        r_k_norm = r_kplus1_norm
        s = beta * s - r
    return x, rel_error, trace_error
def ichol(A):
    """Incomplete Cholesky factorisation of a symmetric matrix A.

    Returns a lower-triangular L with A ~= L L^T, where fill-in is restricted
    to the sparsity pattern of A: whenever A[j, i] == 0 the corresponding
    L[j, i] is left at zero rather than computed.
    """
    n = A.shape[0]
    L = np.zeros(A.shape)
    for col in range(n):
        # Diagonal entry: sqrt of the Schur complement of the row built so far.
        row_done = L[col, :col]
        pivot = np.sqrt(A[col, col] - np.dot(row_done, row_done))
        L[col, col] = pivot
        inv_pivot = 1. / pivot
        # Fill the sub-diagonal part of this column, skipping zeros of A.
        for row in range(col + 1, n):
            entry = A[row, col]
            if entry == 0:
                continue
            L[row, col] = (entry - np.dot(L[col, :col], L[row, :col])) * inv_pivot
    return L