backward_inv for MAP on conditional glow #69

Open · wants to merge 1 commit into base: master
38 changes: 35 additions & 3 deletions src/conditional_layers/conditional_layer_glow.jl
@@ -68,13 +68,13 @@ end
@Flux.functor ConditionalLayerGlow

# Constructor from 1x1 convolution and residual block
function ConditionalLayerGlow(C::Conv1x1, RB::ResidualBlock; logdet=false, activation::ActivationFunction=SigmoidLayer())
function ConditionalLayerGlow(C::Conv1x1, RB::ResidualBlock; logdet=true, activation::ActivationFunction=SigmoidLayer())
RB.fan == false && throw("Set ResidualBlock.fan == true")
return ConditionalLayerGlow(C, RB, logdet, activation)
end

# Constructor from input dimensions
function ConditionalLayerGlow(n_in::Int64, n_cond::Int64, n_hidden::Int64; k1=3, k2=1, p1=1, p2=0, s1=1, s2=1, logdet=false, activation::ActivationFunction=SigmoidLayer(), rb_activation::ActivationFunction=RELUlayer(), ndims=2)
function ConditionalLayerGlow(n_in::Int64, n_cond::Int64, n_hidden::Int64; k1=3, k2=1, p1=1, p2=0, s1=1, s2=1, logdet=true, activation::ActivationFunction=SigmoidLayer(), rb_activation::ActivationFunction=RELUlayer(), ndims=2)

# 1x1 Convolution and residual block for invertible layers
C = Conv1x1(n_in)
@@ -122,7 +122,8 @@ function inverse(Y::AbstractArray{T, N}, C::AbstractArray{T, N}, L::ConditionalL
X_ = tensor_cat(X1, X2)
X = L.C.inverse(X_)

save == true ? (return X, X1, X2, Sm) : (return X)
save && (return X, X1, X2, Sm, Tm)
L.logdet ? (return X, glow_logdet_forward(Sm)) : (return X)
end

# Backward pass: Input (ΔY, Y), Output (ΔX, X)
Expand Down Expand Up @@ -151,3 +152,34 @@ function backward(ΔY::AbstractArray{T, N}, Y::AbstractArray{T, N}, C::AbstractA

return ΔX, X, ΔC
end

function backward_inv(ΔX::AbstractArray{T, N}, X::AbstractArray{T, N}, C::AbstractArray{T, N}, L::ConditionalLayerGlow) where {T, N}
ΔX, X = L.C.forward((ΔX, X))
X1, X2 = tensor_split(X)
ΔX1, ΔX2 = tensor_split(ΔX)

# Recompute forward state
rb_input = tensor_cat(X2,C)
logS_T = L.RB.forward(rb_input)
logSm, Tm = tensor_split(logS_T)
Sm = L.activation.forward(logSm)
Y1 = Sm.*X1 + Tm

# Backpropagate residual
ΔT = -ΔX1 ./ Sm
ΔS = X1 .* ΔT
if L.logdet == true
ΔS += coupling_logdet_backward(Sm)
end

ΔY2_ΔC = L.RB.backward(tensor_cat(L.activation.backward(ΔS, Sm), ΔT), rb_input)
ΔY2, ΔC = tensor_split(ΔY2_ΔC; split_index=Int(size(ΔX)[N-1]/2))
ΔY2 += ΔX2

ΔY1 = -ΔT

ΔY = tensor_cat(ΔY1, ΔY2)
Y = tensor_cat(Y1, X2)

return ΔY, Y, ΔC
end
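
A minimal, editorial usage sketch of the new `backward_inv` for `ConditionalLayerGlow` (not part of this PR). The layer sizes, random data, and variable names are hypothetical; it assumes the dot-call `forward`/`inverse` API used elsewhere in this diff, and that `backward_inv` is reachable from the module (qualify it as `InvertibleNetworks.backward_inv` if it is not exported).

```julia
using InvertibleNetworks

# Hypothetical sizes: 2D data with an even number of input channels
nx, ny, n_in, n_cond, n_hidden, batch = 16, 16, 4, 2, 32, 2
CL = ConditionalLayerGlow(n_in, n_cond, n_hidden; logdet=true)

X = randn(Float32, nx, ny, n_in, batch)    # sample
C = randn(Float32, nx, ny, n_cond, batch)  # condition

# Forward pass returns the output and a logdet term (logdet=true is now the default)
Y, lgdet = CL.forward(X, C)

# With this PR, the inverse pass also returns a logdet term when logdet=true
X_rec, lgdet_inv = CL.inverse(Y, C)

# backward_inv propagates a gradient w.r.t. the layer input X back through the
# inverse map, returning the gradient w.r.t. Y, the recomputed Y, and the
# gradient w.r.t. the condition C (used for MAP with a reversed network)
ΔX = randn(Float32, size(X)...)
ΔY, Y_rec, ΔC = backward_inv(ΔX, X, C, CL)
```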
2 changes: 1 addition & 1 deletion src/layers/invertible_layer_actnorm.jl
@@ -78,7 +78,7 @@ end

# 2-3D Inverse pass: Input Y, Output X
function inverse(Y::AbstractArray{T, N}, AN::ActNorm; logdet=nothing) where {T, N}
isnothing(logdet) ? logdet = (AN.logdet && AN.is_reversed) : logdet = logdet
isnothing(logdet) ? logdet = (AN.logdet && ~AN.is_reversed) : logdet = logdet
inds = [i!=(N-1) ? 1 : Colon() for i=1:N]
dims = collect(1:N-1); dims[end] +=1

2 changes: 1 addition & 1 deletion src/layers/invertible_layer_basic.jl
@@ -104,7 +104,7 @@ end

# 2D/3D Inverse pass: Input Y, Output X
function inverse(Y1::AbstractArray{T, N}, Y2::AbstractArray{T, N}, L::CouplingLayerBasic; save::Bool=false, logdet=nothing) where {T, N}
isnothing(logdet) ? logdet = (L.logdet && L.is_reversed) : logdet = logdet
isnothing(logdet) ? logdet = (L.logdet && ~L.is_reversed) : logdet = logdet

# Inverse layer
logS_T1, logS_T2 = tensor_split(L.RB.forward(Y1))
13 changes: 7 additions & 6 deletions src/layers/invertible_layer_glow.jl
@@ -162,7 +162,6 @@ end


## Jacobian-related functions

function jacobian(ΔX::AbstractArray{T, N}, Δθ::Array{Parameter, 1}, X, L::CouplingLayerGlow) where {T,N}

# Get dimensions
@@ -175,17 +174,19 @@ function jacobian(ΔX::AbstractArray{T, N}, Δθ::Array{Parameter, 1}, X, L::Cou
Y2 = copy(X2)
ΔY2 = copy(ΔX2)
ΔlogS_T, logS_T = L.RB.jacobian(ΔX2, Δθ[4:end], X2)
Sm = L.activation.forward(logS_T[:,:,1:k,:])
ΔS = L.activation.backward(ΔlogS_T[:,:,1:k,:], nothing;x=logS_T[:,:,1:k,:])
Tm = logS_T[:, :, k+1:end, :]
ΔT = ΔlogS_T[:, :, k+1:end, :]
logS, logT = tensor_split(logS_T)
ΔlogS, ΔlogT = tensor_split(ΔlogS_T)
Sm = L.activation.forward(logS)
ΔS = L.activation.backward(ΔlogS, nothing;x=logS)
Tm = logT
ΔT = ΔlogT
Y1 = Sm.*X1 + Tm
ΔY1 = ΔS.*X1 + Sm.*ΔX1 + ΔT
Y = tensor_cat(Y1, Y2)
ΔY = tensor_cat(ΔY1, ΔY2)

# Gauss-Newton approximation of logdet terms
JΔθ = L.RB.jacobian(cuzeros(ΔX2, size(ΔX2)), Δθ[4:end], X2)[1][:, :, 1:k, :]
JΔθ = tensor_split(L.RB.jacobian(cuzeros(ΔX2, size(ΔX2)), Δθ[4:end], X2)[1])[1]
GNΔθ = cat(0f0*Δθ[1:3], -L.RB.adjointJacobian(tensor_cat(L.activation.backward(JΔθ, Sm), zeros(Float32, size(Sm))), X2)[2]; dims=1)

L.logdet ? (return ΔY, Y, glow_logdet_forward(Sm), GNΔθ) : (return ΔY, Y)
2 changes: 1 addition & 1 deletion src/layers/invertible_layer_hint.jl
@@ -155,7 +155,7 @@ end

# Input is tensor Y
function inverse(Y::AbstractArray{T, N} , H::CouplingLayerHINT; scale=1, permute=nothing, logdet=nothing) where {T, N}
isnothing(logdet) ? logdet = (H.logdet && H.is_reversed) : logdet = logdet
isnothing(logdet) ? logdet = (H.logdet && ~H.is_reversed) : logdet = logdet
isnothing(permute) ? permute = H.permute : permute = permute

# Permutation
67 changes: 52 additions & 15 deletions src/networks/invertible_network_conditional_glow.jl
@@ -10,11 +10,13 @@ export NetworkConditionalGlow, NetworkConditionalGlow3D

G = NetworkGlow3D(n_in, n_hidden, L, K; k1=3, k2=1, p1=1, p2=0, s1=1, s2=1)

Create an invertible network based on the Glow architecture. Each flow step in the inner loop
Create a conditional invertible network based on the Glow architecture. Each flow step in the inner loop
consists of an activation normalization layer, followed by an invertible coupling layer with
1x1 convolutions and a residual block. The outer loop performs a squeezing operation prior
to the inner loop, and a splitting operation afterwards.

NOTE: the condition output Zc from (Zx, Zc = G.forward(X, C)) must be passed as input to the inverse/backward calls, e.g. G.inverse(Zx, Zc) and G.backward(ΔZx, Zx, Zc).

*Input*:

- 'n_in': number of input channels
@@ -68,12 +70,13 @@ struct NetworkConditionalGlow <: InvertibleNetwork
K::Int64
squeezer::Squeezer
split_scales::Bool
logdet::Bool
end

@Flux.functor NetworkConditionalGlow

# Constructor
function NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K; split_scales=false, rb_activation::ActivationFunction=ReLUlayer(), k1=3, k2=1, p1=1, p2=0, s1=1, s2=1, ndims=2, squeezer::Squeezer=ShuffleLayer(), activation::ActivationFunction=SigmoidLayer())
function NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K;logdet=true, split_scales=false, rb_activation::ActivationFunction=ReLUlayer(), k1=3, k2=1, p1=1, p2=0, s1=1, s2=1, ndims=2, squeezer::Squeezer=ShuffleLayer(), activation::ActivationFunction=SigmoidLayer())
AN = Array{ActNorm}(undef, L, K) # activation normalization
AN_C = ActNorm(n_cond; logdet=false) # activation normalization for condition
CL = Array{ConditionalLayerGlow}(undef, L, K) # coupling layers w/ 1x1 convolution and residual block
@@ -90,13 +93,13 @@ function NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K; split_scales=false
n_in *= channel_factor # squeeze if split_scales is turned on
n_cond *= channel_factor # squeeze if split_scales is turned on
for j=1:K
AN[i, j] = ActNorm(n_in; logdet=true)
CL[i, j] = ConditionalLayerGlow(n_in, n_cond, n_hidden; rb_activation=rb_activation, k1=k1, k2=k2, p1=p1, p2=p2, s1=s1, s2=s2, logdet=true, activation=activation, ndims=ndims)
AN[i, j] = ActNorm(n_in; logdet=logdet)
CL[i, j] = ConditionalLayerGlow(n_in, n_cond, n_hidden; rb_activation=rb_activation, k1=k1, k2=k2, p1=p1, p2=p2, s1=s1, s2=s2, logdet=logdet, activation=activation, ndims=ndims)
end
(i < L && split_scales) && (n_in = Int64(n_in/2)) # split
end

return NetworkConditionalGlow(AN, AN_C, CL, Z_dims, L, K, squeezer, split_scales)
return NetworkConditionalGlow(AN, AN_C, CL, Z_dims, L, K, squeezer, split_scales,logdet)
end

NetworkConditionalGlow3D(args; kw...) = NetworkConditionalGlow(args...; kw..., ndims=3)
@@ -111,12 +114,13 @@ function forward(X::AbstractArray{T, N}, C::AbstractArray{T, N}, G::NetworkCondi

logdet = 0
for i=1:G.L
#println("L = $(i)")
(G.split_scales) && (X = G.squeezer.forward(X))
(G.split_scales) && (C = G.squeezer.forward(C))
for j=1:G.K
X, logdet1 = G.AN[i, j].forward(X)
X, logdet2 = G.CL[i, j].forward(X, C)
logdet += (logdet1 + logdet2)
for j=1:G.K
G.logdet ? (X, logdet1) = G.AN[i, j].forward(X) : X = G.AN[i, j].forward(X)
G.logdet ? (X, logdet2) = G.CL[i, j].forward(X, C) : X = G.CL[i, j].forward(X, C)
G.logdet && (logdet += (logdet1 + logdet2))
end
if G.split_scales && i < G.L # don't split after last iteration
X, Z = tensor_split(X)
@@ -125,27 +129,28 @@ end
end
end
G.split_scales && (X = reshape(cat_states(Z_save, X),orig_shape))
return X, C, logdet
G.logdet ? (return X, C, logdet) : (return X, C)
end

# Inverse pass
function inverse(X::AbstractArray{T, N}, C::AbstractArray{T, N}, G::NetworkConditionalGlow) where {T, N}
G.split_scales && ((Z_save, X) = split_states(X[:], G.Z_dims))

logdet = 0
for i=G.L:-1:1
if G.split_scales && i < G.L
X = tensor_cat(X, Z_save[i])
end
for j=G.K:-1:1
X = G.CL[i, j].inverse(X,C)
X = G.AN[i, j].inverse(X)
G.logdet ? (X, logdet1) = G.CL[i, j].inverse(X,C) : X = G.CL[i, j].inverse(X,C)
G.logdet ? (X, logdet2) = G.AN[i, j].inverse(X) : X = G.AN[i, j].inverse(X)
G.logdet && (logdet += (logdet1 + logdet2))
end

(G.split_scales) && (X = G.squeezer.inverse(X))
(G.split_scales) && (C = G.squeezer.inverse(C))
end
return X
G.logdet ? (return X, C, logdet) : (return X, C)
end

# Backward pass and compute gradients
function backward(ΔX::AbstractArray{T, N}, X::AbstractArray{T, N}, C::AbstractArray{T, N}, G::NetworkConditionalGlow) where {T, N}

@@ -180,3 +185,35 @@ function backward(ΔX::AbstractArray{T, N}, X::AbstractArray{T, N}, C::AbstractA

return ΔX, X
end

# Backward pass through the inverse (reverse direction, used for MAP) and compute gradients
function backward_inv(ΔX::AbstractArray{T, N}, X::AbstractArray{T, N}, C::AbstractArray{T, N}, G::NetworkConditionalGlow; C_save=nothing) where {T, N}
G.split_scales && (X_save = array_of_array(X, G.L-1))
G.split_scales && (ΔX_save = array_of_array(ΔX, G.L-1))
orig_shape = size(X)

for i=1:G.L
G.split_scales && (ΔX = G.squeezer.forward(ΔX))
G.split_scales && (X = G.squeezer.forward(X))
G.split_scales && (C = G.squeezer.forward(C))
for j=1:G.K
ΔX_, X_ = backward_inv(ΔX, X, G.AN[i, j])
ΔX, X, ΔC = backward_inv(ΔX_, X_, C, G.CL[i, j])
end

if G.split_scales && i < G.L # don't split after last iteration
X, Z = tensor_split(X)
ΔX, ΔZx = tensor_split(ΔX)

X_save[i] = Z
ΔX_save[i] = ΔZx

G.Z_dims[i] = collect(size(X))
end
end

G.split_scales && (X = reshape(cat_states(X_save, X), orig_shape))
G.split_scales && (ΔX = reshape(cat_states(ΔX_save, ΔX), orig_shape))
return ΔX, X
end
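
A minimal, editorial sketch of the conditional network API (not part of this PR), following the NOTE in the docstring above: forward returns the latent Zx together with the squeezed condition Zc and the logdet, and Zc, not the original C, is what inverse and backward expect. Sizes and data are hypothetical placeholders, and the dot-call API is assumed as used elsewhere in this diff.

```julia
using InvertibleNetworks, LinearAlgebra

nx, ny, n_in, n_cond, n_hidden, batch = 16, 16, 2, 1, 32, 4
L, K = 2, 4
G = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K; split_scales=true)  # logdet=true by default

X = randn(Float32, nx, ny, n_in, batch)
C = randn(Float32, nx, ny, n_cond, batch)

# Forward: latent Zx, squeezed condition Zc, and the log-determinant
Zx, Zc, lgdet = G.forward(X, C)

# Inverse and backward take Zc, not C; with this PR, inverse also returns the
# condition and a logdet term when logdet=true
X_rec, C_rec, _ = G.inverse(Zx, Zc)
norm(X_rec - X) / norm(X)  # should be small (numerical round-off)

# Gradient of 0.5*norm(Zx)^2 / batch w.r.t. Zx, backpropagated to X while
# accumulating parameter gradients; the new backward_inv above follows the
# same (ΔX, X, C) argument pattern for reversed (MAP) use
ΔZx = Zx ./ batch
ΔX, X_ = G.backward(ΔZx, Zx, Zc)
```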

15 changes: 13 additions & 2 deletions src/utils/invertible_network_sequential.jl
@@ -86,11 +86,22 @@ function forward(X::AbstractArray{T, N1}, N::ComposedInvertibleNetwork) where {T
N.logdet ? (return X, logdet) : (return X)
end

#Y = N.forward(X)[1]
#@test isapprox(X, N.inverse(N.forward(X)[1])[1]; rtol=1f-3)

function inverse(Y::AbstractArray{T, N1}, N::ComposedInvertibleNetwork) where {T, N1}
N.logdet && (logdet = 0)
for i = length(N):-1:1
Y = N.layers[i].inverse(Y)
println(i)
if N.logdet_array[i]
Y, logdet_ = N.layers[i].inverse(Y)
logdet += logdet_
else
Y = N.layers[i].inverse(Y)
end

end
return Y
N.logdet ? (return Y, logdet) : (return Y)
end

function backward(ΔY::AbstractArray{T, N1}, Y::AbstractArray{T, N1}, N::ComposedInvertibleNetwork; set_grad::Bool = true) where {T, N1}