From 240daa975f5d9e20dd2e19e38713ed4b46e5ac38 Mon Sep 17 00:00:00 2001 From: Rafael Orozco Date: Mon, 28 Nov 2022 17:25:24 -0500 Subject: [PATCH] backward_inv for MAP on conditional glow --- .../conditional_layer_glow.jl | 38 ++- src/layers/invertible_layer_actnorm.jl | 2 +- src/layers/invertible_layer_basic.jl | 2 +- src/layers/invertible_layer_glow.jl | 13 +- src/layers/invertible_layer_hint.jl | 2 +- .../invertible_network_conditional_glow.jl | 67 ++++-- src/utils/invertible_network_sequential.jl | 15 +- .../test_conditional_glow_network.jl | 219 +++++++++++++++++- test/test_networks/test_glow.jl | 24 +- test/test_utils/test_sequential.jl | 5 +- 10 files changed, 340 insertions(+), 47 deletions(-) diff --git a/src/conditional_layers/conditional_layer_glow.jl b/src/conditional_layers/conditional_layer_glow.jl index fb7fea61..c8e17b38 100644 --- a/src/conditional_layers/conditional_layer_glow.jl +++ b/src/conditional_layers/conditional_layer_glow.jl @@ -68,13 +68,13 @@ end @Flux.functor ConditionalLayerGlow # Constructor from 1x1 convolution and residual block -function ConditionalLayerGlow(C::Conv1x1, RB::ResidualBlock; logdet=false, activation::ActivationFunction=SigmoidLayer()) +function ConditionalLayerGlow(C::Conv1x1, RB::ResidualBlock; logdet=true, activation::ActivationFunction=SigmoidLayer()) RB.fan == false && throw("Set ResidualBlock.fan == true") return ConditionalLayerGlow(C, RB, logdet, activation) end # Constructor from input dimensions -function ConditionalLayerGlow(n_in::Int64, n_cond::Int64, n_hidden::Int64; k1=3, k2=1, p1=1, p2=0, s1=1, s2=1, logdet=false, activation::ActivationFunction=SigmoidLayer(), rb_activation::ActivationFunction=RELUlayer(), ndims=2) +function ConditionalLayerGlow(n_in::Int64, n_cond::Int64, n_hidden::Int64; k1=3, k2=1, p1=1, p2=0, s1=1, s2=1, logdet=true, activation::ActivationFunction=SigmoidLayer(), rb_activation::ActivationFunction=RELUlayer(), ndims=2) # 1x1 Convolution and residual block for invertible layers C = Conv1x1(n_in) @@ -122,7 +122,8 @@ function inverse(Y::AbstractArray{T, N}, C::AbstractArray{T, N}, L::ConditionalL X_ = tensor_cat(X1, X2) X = L.C.inverse(X_) - save == true ? (return X, X1, X2, Sm) : (return X) + save && (return X, X1, X2, Sm,Tm) + L.logdet ? 
(return X, glow_logdet_forward(Sm)) : (return X) end # Backward pass: Input (ΔY, Y), Output (ΔX, X) @@ -151,3 +152,34 @@ function backward(ΔY::AbstractArray{T, N}, Y::AbstractArray{T, N}, C::AbstractA return ΔX, X, ΔC end + +function backward_inv(ΔX::AbstractArray{T, N}, X::AbstractArray{T, N},C::AbstractArray{T, N}, L::ConditionalLayerGlow; ) where {T, N} + ΔX, X = L.C.forward((ΔX, X)) + X1, X2 = tensor_split(X) + ΔX1, ΔX2 = tensor_split(ΔX) + + # Recompute forward state + rb_input = tensor_cat(X2,C) + logS_T = L.RB.forward(rb_input) + logSm, Tm = tensor_split(logS_T) + Sm = L.activation.forward(logSm) + Y1 = Sm.*X1 + Tm + + # Backpropagate residual + ΔT = -ΔX1 ./ Sm + ΔS = X1 .* ΔT + if L.logdet == true + ΔS += coupling_logdet_backward(Sm) + end + + ΔY2_ΔC = L.RB.backward(tensor_cat(L.activation.backward(ΔS, Sm), ΔT), rb_input) + ΔY2, ΔC = tensor_split(ΔY2_ΔC; split_index=Int(size(ΔX)[N-1]/2)) + ΔY2 += ΔX2 + + ΔY1 = -ΔT + + ΔY = tensor_cat(ΔY1, ΔY2) + Y = tensor_cat(Y1, X2) + + return ΔY, Y, ΔC +end \ No newline at end of file diff --git a/src/layers/invertible_layer_actnorm.jl b/src/layers/invertible_layer_actnorm.jl index c1c348b8..4d460414 100644 --- a/src/layers/invertible_layer_actnorm.jl +++ b/src/layers/invertible_layer_actnorm.jl @@ -78,7 +78,7 @@ end # 2-3D Inverse pass: Input Y, Output X function inverse(Y::AbstractArray{T, N}, AN::ActNorm; logdet=nothing) where {T, N} - isnothing(logdet) ? logdet = (AN.logdet && AN.is_reversed) : logdet = logdet + isnothing(logdet) ? logdet = (AN.logdet && ~AN.is_reversed) : logdet = logdet inds = [i!=(N-1) ? 1 : Colon() for i=1:N] dims = collect(1:N-1); dims[end] +=1 diff --git a/src/layers/invertible_layer_basic.jl b/src/layers/invertible_layer_basic.jl index 469e86af..f0105c8d 100644 --- a/src/layers/invertible_layer_basic.jl +++ b/src/layers/invertible_layer_basic.jl @@ -104,7 +104,7 @@ end # 2D/3D Inverse pass: Input Y, Output X function inverse(Y1::AbstractArray{T, N}, Y2::AbstractArray{T, N}, L::CouplingLayerBasic; save::Bool=false, logdet=nothing) where {T, N} - isnothing(logdet) ? logdet = (L.logdet && L.is_reversed) : logdet = logdet + isnothing(logdet) ? 
logdet = (L.logdet && ~L.is_reversed) : logdet = logdet # Inverse layer logS_T1, logS_T2 = tensor_split(L.RB.forward(Y1)) diff --git a/src/layers/invertible_layer_glow.jl b/src/layers/invertible_layer_glow.jl index 39d8e1cf..6de565cf 100644 --- a/src/layers/invertible_layer_glow.jl +++ b/src/layers/invertible_layer_glow.jl @@ -162,7 +162,6 @@ end ## Jacobian-related functions - function jacobian(ΔX::AbstractArray{T, N}, Δθ::Array{Parameter, 1}, X, L::CouplingLayerGlow) where {T,N} # Get dimensions @@ -175,17 +174,19 @@ function jacobian(ΔX::AbstractArray{T, N}, Δθ::Array{Parameter, 1}, X, L::Cou Y2 = copy(X2) ΔY2 = copy(ΔX2) ΔlogS_T, logS_T = L.RB.jacobian(ΔX2, Δθ[4:end], X2) - Sm = L.activation.forward(logS_T[:,:,1:k,:]) - ΔS = L.activation.backward(ΔlogS_T[:,:,1:k,:], nothing;x=logS_T[:,:,1:k,:]) - Tm = logS_T[:, :, k+1:end, :] - ΔT = ΔlogS_T[:, :, k+1:end, :] + logS, logT = tensor_split(logS_T) + ΔlogS, ΔlogT = tensor_split(ΔlogS_T) + Sm = L.activation.forward(logS) + ΔS = L.activation.backward(ΔlogS, nothing;x=logS) + Tm = logT + ΔT = ΔlogT Y1 = Sm.*X1 + Tm ΔY1 = ΔS.*X1 + Sm.*ΔX1 + ΔT Y = tensor_cat(Y1, Y2) ΔY = tensor_cat(ΔY1, ΔY2) # Gauss-Newton approximation of logdet terms - JΔθ = L.RB.jacobian(cuzeros(ΔX2, size(ΔX2)), Δθ[4:end], X2)[1][:, :, 1:k, :] + JΔθ = tensor_split(L.RB.jacobian(cuzeros(ΔX2, size(ΔX2)), Δθ[4:end], X2)[1])[1] GNΔθ = cat(0f0*Δθ[1:3], -L.RB.adjointJacobian(tensor_cat(L.activation.backward(JΔθ, Sm), zeros(Float32, size(Sm))), X2)[2]; dims=1) L.logdet ? (return ΔY, Y, glow_logdet_forward(Sm), GNΔθ) : (return ΔY, Y) diff --git a/src/layers/invertible_layer_hint.jl b/src/layers/invertible_layer_hint.jl index ca21a45a..20a63a4c 100644 --- a/src/layers/invertible_layer_hint.jl +++ b/src/layers/invertible_layer_hint.jl @@ -155,7 +155,7 @@ end # Input is tensor Y function inverse(Y::AbstractArray{T, N} , H::CouplingLayerHINT; scale=1, permute=nothing, logdet=nothing) where {T, N} - isnothing(logdet) ? logdet = (H.logdet && H.is_reversed) : logdet = logdet + isnothing(logdet) ? logdet = (H.logdet && ~H.is_reversed) : logdet = logdet isnothing(permute) ? permute = H.permute : permute = permute # Permutation diff --git a/src/networks/invertible_network_conditional_glow.jl b/src/networks/invertible_network_conditional_glow.jl index 253b329a..0c37c5e4 100644 --- a/src/networks/invertible_network_conditional_glow.jl +++ b/src/networks/invertible_network_conditional_glow.jl @@ -10,11 +10,13 @@ export NetworkConditionalGlow, NetworkConditionalGlow3D G = NetworkGlow3D(n_in, n_hidden, L, K; k1=3, k2=1, p1=1, p2=0, s1=1, s2=1) - Create an invertible network based on the Glow architecture. Each flow step in the inner loop + Create an conditional invertible network based on the Glow architecture. Each flow step in the inner loop consists of an activation normalization layer, followed by an invertible coupling layer with 1x1 convolutions and a residual block. The outer loop performs a squeezing operation prior to the inner loop, and a splitting operation afterwards. 
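+ Round-trip sketch (illustrative; mirrors the package tests): Zx, Zc, logdet = G.forward(X, C),
+ then X_, C_ = G.inverse(Zx, Zc) recovers X up to numerical precision. The returned Zc is what
+ inverse/backward expect as their condition argument, as the NOTE below stresses.
+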
+ NOTE: NEED TO GIVE OUTPUT Zc (Zx, Zc = G.forward(X,C)) AS INPUT TO inverss/backwards LAYERS G.backward(X,C) + *Input*: - 'n_in': number of input channels @@ -68,12 +70,13 @@ struct NetworkConditionalGlow <: InvertibleNetwork K::Int64 squeezer::Squeezer split_scales::Bool + logdet::Bool end @Flux.functor NetworkConditionalGlow # Constructor -function NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K; split_scales=false, rb_activation::ActivationFunction=ReLUlayer(), k1=3, k2=1, p1=1, p2=0, s1=1, s2=1, ndims=2, squeezer::Squeezer=ShuffleLayer(), activation::ActivationFunction=SigmoidLayer()) +function NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K;logdet=true, split_scales=false, rb_activation::ActivationFunction=ReLUlayer(), k1=3, k2=1, p1=1, p2=0, s1=1, s2=1, ndims=2, squeezer::Squeezer=ShuffleLayer(), activation::ActivationFunction=SigmoidLayer()) AN = Array{ActNorm}(undef, L, K) # activation normalization AN_C = ActNorm(n_cond; logdet=false) # activation normalization for condition CL = Array{ConditionalLayerGlow}(undef, L, K) # coupling layers w/ 1x1 convolution and residual block @@ -90,13 +93,13 @@ function NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K; split_scales=false n_in *= channel_factor # squeeze if split_scales is turned on n_cond *= channel_factor # squeeze if split_scales is turned on for j=1:K - AN[i, j] = ActNorm(n_in; logdet=true) - CL[i, j] = ConditionalLayerGlow(n_in, n_cond, n_hidden; rb_activation=rb_activation, k1=k1, k2=k2, p1=p1, p2=p2, s1=s1, s2=s2, logdet=true, activation=activation, ndims=ndims) + AN[i, j] = ActNorm(n_in; logdet=logdet) + CL[i, j] = ConditionalLayerGlow(n_in, n_cond, n_hidden; rb_activation=rb_activation, k1=k1, k2=k2, p1=p1, p2=p2, s1=s1, s2=s2, logdet=logdet, activation=activation, ndims=ndims) end (i < L && split_scales) && (n_in = Int64(n_in/2)) # split end - return NetworkConditionalGlow(AN, AN_C, CL, Z_dims, L, K, squeezer, split_scales) + return NetworkConditionalGlow(AN, AN_C, CL, Z_dims, L, K, squeezer, split_scales,logdet) end NetworkConditionalGlow3D(args; kw...) = NetworkConditionalGlow(args...; kw..., ndims=3) @@ -111,12 +114,13 @@ function forward(X::AbstractArray{T, N}, C::AbstractArray{T, N}, G::NetworkCondi logdet = 0 for i=1:G.L + #println("L = $(i)") (G.split_scales) && (X = G.squeezer.forward(X)) (G.split_scales) && (C = G.squeezer.forward(C)) - for j=1:G.K - X, logdet1 = G.AN[i, j].forward(X) - X, logdet2 = G.CL[i, j].forward(X, C) - logdet += (logdet1 + logdet2) + for j=1:G.K + G.logdet ? (X, logdet1) = G.AN[i, j].forward(X) : X = G.AN[i, j].forward(X) + G.logdet ? (X, logdet2) = G.CL[i, j].forward(X, C) : X = G.CL[i, j].forward(X, C) + G.logdet && (logdet += (logdet1 + logdet2)) end if G.split_scales && i < G.L # don't split after last iteration X, Z = tensor_split(X) @@ -125,27 +129,28 @@ function forward(X::AbstractArray{T, N}, C::AbstractArray{T, N}, G::NetworkCondi end end G.split_scales && (X = reshape(cat_states(Z_save, X),orig_shape)) - return X, C, logdet + G.logdet ? (return X, C, logdet) : (return X, C) end # Inverse pass function inverse(X::AbstractArray{T, N}, C::AbstractArray{T, N}, G::NetworkConditionalGlow) where {T, N} G.split_scales && ((Z_save, X) = split_states(X[:], G.Z_dims)) + + logdet = 0 for i=G.L:-1:1 if G.split_scales && i < G.L X = tensor_cat(X, Z_save[i]) end for j=G.K:-1:1 - X = G.CL[i, j].inverse(X,C) - X = G.AN[i, j].inverse(X) + G.logdet ? (X, logdet1) = G.CL[i, j].inverse(X,C) : X = G.CL[i, j].inverse(X,C) + G.logdet ? 
(X, logdet2) = G.AN[i, j].inverse(X) : X = G.AN[i, j].inverse(X) + G.logdet && (logdet += (logdet1 + logdet2)) end - (G.split_scales) && (X = G.squeezer.inverse(X)) (G.split_scales) && (C = G.squeezer.inverse(C)) end - return X + G.logdet ? (return X, C, logdet) : (return X, C) end - # Backward pass and compute gradients function backward(ΔX::AbstractArray{T, N}, X::AbstractArray{T, N}, C::AbstractArray{T, N}, G::NetworkConditionalGlow) where {T, N} @@ -180,3 +185,35 @@ function backward(ΔX::AbstractArray{T, N}, X::AbstractArray{T, N}, C::AbstractA return ΔX, X end + +# Backward pass and compute gradients +function backward_inv(ΔX::AbstractArray{T, N}, X::AbstractArray{T, N}, C::AbstractArray{T, N}, G::NetworkConditionalGlow; C_save=nothing) where {T, N} + G.split_scales && (X_save = array_of_array(X, G.L-1)) + G.split_scales && (ΔX_save = array_of_array(ΔX, G.L-1)) + orig_shape = size(X) + + for i=1:G.L + G.split_scales && (ΔX = G.squeezer.forward(ΔX)) + G.split_scales && (X = G.squeezer.forward(X)) + G.split_scales && (C = G.squeezer.forward(C)) + for j=1:G.K + ΔX_, X_ = backward_inv(ΔX, X, G.AN[i, j]) + ΔX, X, ΔC = backward_inv(ΔX_, X_, C, G.CL[i, j]) + end + + if G.split_scales && i < G.L # don't split after last iteration + X, Z = tensor_split(X) + ΔX, ΔZx = tensor_split(ΔX) + + X_save[i] = Z + ΔX_save[i] = ΔZx + + G.Z_dims[i] = collect(size(X)) + end + end + + G.split_scales && (X = reshape(cat_states(X_save, X), orig_shape)) + G.split_scales && (ΔX = reshape(cat_states(ΔX_save, ΔX), orig_shape)) + return ΔX, X +end + diff --git a/src/utils/invertible_network_sequential.jl b/src/utils/invertible_network_sequential.jl index 3d8e9cab..415f1855 100644 --- a/src/utils/invertible_network_sequential.jl +++ b/src/utils/invertible_network_sequential.jl @@ -86,11 +86,22 @@ function forward(X::AbstractArray{T, N1}, N::ComposedInvertibleNetwork) where {T N.logdet ? (return X, logdet) : (return X) end +#Y = N.forward(X)[1] +#@test isapprox(X, N.inverse(N.forward(X)[1])[1]; rtol=1f-3) + function inverse(Y::AbstractArray{T, N1}, N::ComposedInvertibleNetwork) where {T, N1} + N.logdet && (logdet = 0) for i = length(N):-1:1 - Y = N.layers[i].inverse(Y) + println(i) + if N.logdet_array[i] + Y, logdet_ = N.layers[i].inverse(Y) + logdet += logdet_ + else + Y = N.layers[i].inverse(Y) + end + end - return Y + N.logdet ? 
(return Y, logdet) : (return Y) end function backward(ΔY::AbstractArray{T, N1}, Y::AbstractArray{T, N1}, N::ComposedInvertibleNetwork; set_grad::Bool = true) where {T, N1} diff --git a/test/test_networks/test_conditional_glow_network.jl b/test/test_networks/test_conditional_glow_network.jl index a42d4657..83a6b342 100644 --- a/test/test_networks/test_conditional_glow_network.jl +++ b/test/test_networks/test_conditional_glow_network.jl @@ -17,6 +17,217 @@ batchsize = 2 L = 2 K = 2 +########################################### Test with split_scales = true ######################### +# Invertibility + +# Network and input +G = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K;split_scales=true) +X = rand(Float32, nx, ny, n_in, batchsize) +Cond = rand(Float32, nx, ny, n_cond, batchsize) + +G.CL[1,1].forward(rand(Float32, nx, ny, 8, batchsize),rand(Float32, nx, ny, 8, batchsize)) + +Y, Cond_z = G.forward(X,Cond) +X_,Cond_ = G.inverse(Y,Cond_z) + +@test isapprox(norm(X - X_)/norm(X), 0f0; atol=1f-5) + +################################################################################################### +# Test gradients are set and cleared +G.backward(Y, Y, Cond_z) + +P = get_params(G) +gsum = 0 +for p in P + ~isnothing(p.grad) && (global gsum += 1) +end +@test isequal(gsum, L*K*10+L) + +clear_grad!(G) +gsum = 0 +for p in P + ~isnothing(p.grad) && (global gsum += 1) +end +@test isequal(gsum, 0) + + +################################################################################################### +# Gradient test + +function loss(G, X, Cond) + Y, ZC, logdet = G.forward(X, Cond) + f = -log_likelihood(Y) - logdet + ΔY = -∇log_likelihood(Y) + ΔX, X_ = G.backward(ΔY, Y, ZC) + return f, ΔX, G.CL[1,1].RB.W1.grad, G.CL[1,1].C.v1.grad +end + +# Gradient test w.r.t. input +G = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K;split_scales=true) +X = rand(Float32, nx, ny, n_in, batchsize) +Cond = rand(Float32, nx, ny, n_cond, batchsize) +X0 = rand(Float32, nx, ny, n_in, batchsize) +Cond0 = rand(Float32, nx, ny, n_cond, batchsize) + +dX = X - X0 + +f0, ΔX = loss(G, X0, Cond0)[1:2] +h = 0.1f0 +maxiter = 4 +err1 = zeros(Float32, maxiter) +err2 = zeros(Float32, maxiter) + +print("\nGradient test glow: input\n") +for j=1:maxiter + f = loss(G, X0 + h*dX, Cond0)[1] + err1[j] = abs(f - f0) + err2[j] = abs(f - f0 - h*dot(dX, ΔX)) + print(err1[j], "; ", err2[j], "\n") + global h = h/2f0 +end + +@test isapprox(err1[end] / (err1[1]/2^(maxiter-1)), 1f0; atol=1f0) +@test isapprox(err2[end] / (err2[1]/4^(maxiter-1)), 1f0; atol=1f0) + + +# Gradient test w.r.t. 
parameters +X = rand(Float32, nx, ny, n_in, batchsize) +G = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K;split_scales=true) +G0 = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K;split_scales=true) +Gini = deepcopy(G0) + +# Test one parameter from residual block and 1x1 conv +dW = G.CL[1,1].RB.W1.data - G0.CL[1,1].RB.W1.data +dv = G.CL[1,1].C.v1.data - G0.CL[1,1].C.v1.data + +f0, ΔX, ΔW, Δv = loss(G0, X, Cond) +h = 0.1f0 +maxiter = 4 +err3 = zeros(Float32, maxiter) +err4 = zeros(Float32, maxiter) + +print("\nGradient test glow: input\n") +for j=1:maxiter + G0.CL[1,1].RB.W1.data = Gini.CL[1,1].RB.W1.data + h*dW + G0.CL[1,1].C.v1.data = Gini.CL[1,1].C.v1.data + h*dv + + f = loss(G0, X, Cond)[1] + err3[j] = abs(f - f0) + err4[j] = abs(f - f0 - h*dot(dW, ΔW) - h*dot(dv, Δv)) + print(err3[j], "; ", err4[j], "\n") + global h = h/2f0 +end + +@test isapprox(err3[end] / (err3[1]/2^(maxiter-1)), 1f0; atol=1f0) +@test isapprox(err4[end] / (err4[1]/4^(maxiter-1)), 1f0; atol=1f0) + + +########################################### Test with split_scales = true and REV ######################### +# Invertibility + +# Network and input +G = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K; logdet=false, split_scales=true) +G = reverse(G) +X = rand(Float32, nx, ny, n_in, batchsize) +Cond = rand(Float32, nx, ny, n_cond, batchsize) + + +Y_, Cond_ = G.inverse(X,Cond) +X_, Cond_z = G.forward(Y_,Cond_) + +@test isapprox(norm(X - X_)/norm(X), 0f0; atol=1f-5) + +################################################################################################### +# Test gradients are set and cleared +G.backward(X_, X_, Cond_z) + +P = get_params(G) +gsum = 0 +for p in P + ~isnothing(p.grad) && (global gsum += 1) +end +@test isequal(gsum, L*K*10) #maybe we need to set act? Definitely. not in split_scales + +clear_grad!(G) +gsum = 0 +for p in P + ~isnothing(p.grad) && (global gsum += 1) +end +@test isequal(gsum, 0) + + +################################################################################################### +# Gradient test + +function loss(G, X, Cond) + Y, ZC = G.forward(X, Cond) + f = -log_likelihood(Y) + ΔY = -∇log_likelihood(Y) + ΔX, X_ = G.backward(ΔY, Y, ZC) + return f, ΔX, G.CL[1,1].RB.W1.grad, G.CL[1,1].C.v1.grad +end + +# Gradient test w.r.t. input +G = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K; logdet=false, split_scales=true) +X = rand(Float32, nx, ny, n_in, batchsize) +Cond = rand(Float32, nx, ny, n_cond, batchsize) +X0 = rand(Float32, nx, ny, n_in, batchsize) +Cond0 = rand(Float32, nx, ny, n_cond, batchsize) + +dX = X - X0 + +f0, ΔX = loss(G, X0, Cond0)[1:2] +h = 0.1f0 +maxiter = 4 +err1 = zeros(Float32, maxiter) +err2 = zeros(Float32, maxiter) + +print("\nGradient test glow: input\n") +for j=1:maxiter + f = loss(G, X0 + h*dX, Cond0)[1] + err1[j] = abs(f - f0) + err2[j] = abs(f - f0 - h*dot(dX, ΔX)) + print(err1[j], "; ", err2[j], "\n") + global h = h/2f0 +end + +@test isapprox(err1[end] / (err1[1]/2^(maxiter-1)), 1f0; atol=1f0) +@test isapprox(err2[end] / (err2[1]/4^(maxiter-1)), 1f0; atol=1f0) + + +# Gradient test w.r.t. 
parameters +X = rand(Float32, nx, ny, n_in, batchsize) +G = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K; logdet=false, split_scales=true) +G0 = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K; logdet=false, split_scales=true) +Gini = deepcopy(G0) + +# Test one parameter from residual block and 1x1 conv +dW = G.CL[1,1].RB.W1.data - G0.CL[1,1].RB.W1.data +dv = G.CL[1,1].C.v1.data - G0.CL[1,1].C.v1.data + +f0, ΔX, ΔW, Δv = loss(G0, X, Cond) +h = 0.1f0 +maxiter = 4 +err3 = zeros(Float32, maxiter) +err4 = zeros(Float32, maxiter) + +print("\nGradient test glow: input\n") +for j=1:maxiter + G0.CL[1,1].RB.W1.data = Gini.CL[1,1].RB.W1.data + h*dW + G0.CL[1,1].C.v1.data = Gini.CL[1,1].C.v1.data + h*dv + + f = loss(G0, X, Cond)[1] + err3[j] = abs(f - f0) + err4[j] = abs(f - f0 - h*dot(dW, ΔW) - h*dot(dv, Δv)) + print(err3[j], "; ", err4[j], "\n") + global h = h/2f0 +end + +@test isapprox(err3[end] / (err3[1]/2^(maxiter-1)), 1f0; atol=1f0) +@test isapprox(err4[end] / (err4[1]/4^(maxiter-1)), 1f0; atol=1f0) + + + ########################################### Test with split_scales = false ######################### # Invertibility @@ -25,21 +236,21 @@ G = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K) X = rand(Float32, nx, ny, n_in, batchsize) Cond = rand(Float32, nx, ny, n_cond, batchsize) -Y, Cond = G.forward(X,Cond) -X_ = G.inverse(Y,Cond) # saving the cond is important in split scales because of reshapes +ZX, ZCond = G.forward(X,Cond) +X_ = G.inverse(ZX,ZCond)[1] @test isapprox(norm(X - X_)/norm(X), 0f0; atol=1f-5) ################################################################################################### # Test gradients are set and cleared -G.backward(Y, Y, Cond) +G.backward(ZX, ZX, ZCond) P = get_params(G) gsum = 0 for p in P ~isnothing(p.grad) && (global gsum += 1) end -@test isequal(gsum, L*K*10+2) +@test isequal(gsum, L*K*10+L) clear_grad!(G) gsum = 0 diff --git a/test/test_networks/test_glow.jl b/test/test_networks/test_glow.jl index 9bbd0542..e5ba18dd 100644 --- a/test/test_networks/test_glow.jl +++ b/test/test_networks/test_glow.jl @@ -19,7 +19,7 @@ K = 2 for split_scales = [true,false] for N in [(nx, ny), (nx, ny, nz)] - ###########################################Test with split_scales = false ######################### + println("Testing Glow with split_scales = $(split_scales) and dims = $(N)") # Invertibility # Network and input @@ -62,9 +62,9 @@ for split_scales = [true,false] end # Gradient test w.r.t. input - G = NetworkGlow(n_in, n_hidden, L, K) - X = rand(Float32, nx, ny, n_in, batchsize) - X0 = rand(Float32, nx, ny, n_in, batchsize) + G = NetworkGlow(n_in, n_hidden, L, K; split_scales=split_scales, ndims=length(N)) + X = rand(Float32, N..., n_in, batchsize) + X0 = rand(Float32, N..., n_in, batchsize) dX = X - X0 f0, ΔX = loss(G, X0)[1:2] @@ -86,9 +86,9 @@ for split_scales = [true,false] @test isapprox(err2[end] / (err2[1]/4^(maxiter-1)), 1f0; atol=1f1) # Gradient test w.r.t. 
parameters - X = rand(Float32, nx, ny, n_in, batchsize) - G = NetworkGlow(n_in, n_hidden, L, K) - G0 = NetworkGlow(n_in, n_hidden, L, K) + X = rand(Float32, N..., n_in, batchsize) + #G = NetworkGlow(n_in, n_hidden, L, K) + G0 = NetworkGlow(n_in, n_hidden, L, K; split_scales=split_scales, ndims=length(N)) Gini = deepcopy(G0) # Test one parameter from residual block and 1x1 conv @@ -122,16 +122,18 @@ for split_scales = [true,false] # Gradient test # Initialization - G = NetworkGlow(n_in, n_hidden, L, K); G.forward(randn(Float32, nx, ny, n_in, batchsize)) + G = NetworkGlow(n_in, n_hidden, L, K; split_scales=split_scales, ndims=length(N)) + G.forward(randn(Float32, N..., n_in, batchsize)) θ = deepcopy(get_params(G)) - G0 = NetworkGlow(n_in, n_hidden, L, K); G0.forward(randn(Float32, nx, ny, n_in, batchsize)) + G0 = NetworkGlow(n_in, n_hidden, L, K; split_scales=split_scales, ndims=length(N)) + G0.forward(randn(Float32, N..., n_in, batchsize)) θ0 = deepcopy(get_params(G0)) - X = randn(Float32, nx, ny, n_in, batchsize) + X = randn(Float32, N..., n_in, batchsize) # Perturbation (normalized) dθ = θ-θ0 dθ .*= norm.(θ0)./(norm.(dθ).+1f-10) - dX = randn(Float32, nx, ny, n_in, batchsize); dX *= norm(X)/norm(dX) + dX = randn(Float32, N..., n_in, batchsize); dX *= norm(X)/norm(dX) # Jacobian eval dY, Y, _, _ = G.jacobian(dX, dθ, X) diff --git a/test/test_utils/test_sequential.jl b/test/test_utils/test_sequential.jl index db45fb7b..49336af7 100644 --- a/test/test_utils/test_sequential.jl +++ b/test/test_utils/test_sequential.jl @@ -56,9 +56,8 @@ Y_, l_ = N_.forward(X) ############################################################################### # Test invertibility - -@test isapprox(X, N.inverse(N.forward(X)[1]); rtol=1f-3) -@test isapprox(X, N.forward(N.inverse(X))[1]; rtol=1f-3) +@test isapprox(X, N.inverse(N.forward(X)[1])[1]; rtol=1f-3) +@test isapprox(X, N.forward(N.inverse(X)[1])[1]; rtol=1f-3) ###############################################################################
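
Reviewer note (not part of the patch): below is a minimal Julia sketch of the MAP-style use case named in
the commit title, driven the same way as the reversed-network test added in
test/test_networks/test_conditional_glow_network.jl. The least-squares misfit, the plain gradient-descent
loop, and names such as X_obs, lr, n_iter and X_map are illustrative assumptions only; the library calls
(NetworkConditionalGlow, reverse, forward/inverse/backward, clear_grad!) are the ones exercised in the
tests above.

using InvertibleNetworks

# Illustrative sizes, same scale as the test file
nx, ny, n_in, n_cond, n_hidden, L, K, batchsize = 16, 16, 2, 2, 4, 2, 2, 2

# Reversed conditional Glow: after reverse(G), G.forward maps latent -> data,
# so G.backward runs through the new backward_inv code path.
G = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K; logdet=false, split_scales=true)
G = reverse(G)

X_obs = rand(Float32, nx, ny, n_in,   batchsize)   # "observed" data (illustrative)
C     = rand(Float32, nx, ny, n_cond, batchsize)   # conditioning input

# The condition returned here is what forward/backward expect (see the docstring NOTE);
# this call also initializes the internal Z_dims bookkeeping used with split_scales.
_, Zc = G.inverse(X_obs, C)

Zx = randn(Float32, size(X_obs))                   # latent initial guess for the MAP-style search
lr, n_iter = 1f-3, 10                              # illustrative step size and iteration count
for it = 1:n_iter
    X_, C_ = G.forward(Zx, Zc)                     # current estimate in data space
    ΔX = X_ - X_obs                                # gradient of 0.5*||X_ - X_obs||^2 w.r.t. X_
    ΔZ, _ = G.backward(ΔX, X_, C_)                 # pull the gradient back to the latent via backward_inv
    global Zx = Zx - lr * ΔZ                       # plain gradient step on the latent
    clear_grad!(G)                                 # parameter grads are also set; reset them here
end
X_map = G.forward(Zx, Zc)[1]                       # MAP-style estimate given the condition C

A proper MAP objective would add a prior term on the latent Zx (for example the Gaussian
log_likelihood / ∇log_likelihood helpers used in the tests) to the misfit; it is omitted here
to keep the sketch short.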