From 77c70d05bb59ed770eb826298e5a406c076c0cb2 Mon Sep 17 00:00:00 2001 From: Chris Elrod Date: Thu, 9 Jun 2022 18:50:42 -0400 Subject: [PATCH 1/2] Broadcasted bias should not broadcast against duals. --- src/forwarddiff_matmul.jl | 82 ++++++++++++++++++++++++++------------ test/matmul_tests.jl | 16 ++++---- test/runtests.jl | 84 ++++++++++++++------------------------- 3 files changed, 93 insertions(+), 89 deletions(-) diff --git a/src/forwarddiff_matmul.jl b/src/forwarddiff_matmul.jl index e6c09b0..aec7b8e 100644 --- a/src/forwarddiff_matmul.jl +++ b/src/forwarddiff_matmul.jl @@ -108,11 +108,11 @@ end end @inline reinterpret_dual(A::AbstractArray{ForwardDiff.Dual{T,V,N}}) where {T,V,N} = - reinterpret(V, A) + zero_offsets(reinterpret(V, A)) @inline function reinterpret_reshape_dual( A::AbstractArray{ForwardDiff.Dual{T,V,N}}, ) where {T,V,N} - reinterpret(reshape, V, A) + zero_offsets(reinterpret(reshape, V, A)) end function contract_loops( DC::Vector{Int}, @@ -120,6 +120,7 @@ function contract_loops( DB::Vector{Int}, lao::Bool, update::Bool, + skipbroadcast::Vector{Int} ) contract_dims = Int[] Cref = Expr(:ref, :C) @@ -170,6 +171,16 @@ function contract_loops( q = if lao Arefo = copy(Aref) Arefo.args[end] = :K + # lao will not broadcast along excluded dims... + cond = Arefo + for d in skipbroadcast + i = findfirst(==(d), DC)::Int + c = :($(Cinds[i]) == $(ArrayInterface.offsets)(C, $(StaticInt{i}()))) + cond = cond === Arefo ? c : :($cond & $c) + end + if cond !== Arefo # found + Arefo = :(ifelse($cond, $Arefo, zero(Caccum))) + end quote Caccum = $Cinit $q @@ -201,7 +212,7 @@ function contract_loops( q = :(@turbo $q) if lao cd::Int = findfirst(==(first(contract_dims)), DA) - Expr(:block, :(K = lastindex(A, StaticInt{$cd}())), q) + Expr(:block, :(K = $(ArrayInterface.static_last)($axes(A, StaticInt{$cd}()))), q) else q end @@ -216,7 +227,8 @@ end ::Val{DB}, ::Val{LAO}, ::Val{U}, -) where {TC<:NativeTypes,TA<:NativeTypes,TB<:NativeTypes,DC,DA,DB,LAO,U} + ::Val{SB} +) where {TC<:NativeTypes,TA<:NativeTypes,TB<:NativeTypes,DC,DA,DB,LAO,U,SB} dc = Vector{Int}(undef, length(DC)) for i in eachindex(DC) dc[i] = DC[i] @@ -229,7 +241,11 @@ end for i in eachindex(DB) db[i] = DB[i] end - contract_loops(dc, da, db, LAO, U) + sb = LAO ? Vector{Int}(undef, length(SB)) : dc + for i in eachindex(SB) + sb[i] = SB[i] + end + contract_loops(dc, da, db, LAO, U, sb) end @generated function contract!( C::PtrArray{<:Any,DDC,TC}, @@ -240,6 +256,7 @@ end ::Val{DB}, ::Val{LAO}, ::Val{U}, + ::Val{SB} ) where { T, P, @@ -253,6 +270,7 @@ end DB, LAO, U, + SB } if (DDC[1] & DDA[1]) & (DC[1] == DA[1]) & Bool(is_column_major(A)) & Bool(is_column_major(C)) r = reinterpret_dual @@ -266,7 +284,7 @@ end DAN = (new_dim, DA...) end quote - contract!($r(C), $r(A), B, Val{$DCN}(), Val{$DAN}(), Val{$DB}(), Val{$LAO}(), Val{$U}()) + contract!($r(C), $r(A), B, Val{$DCN}(), Val{$DAN}(), Val{$DB}(), Val{$LAO}(), Val{$U}(), Val{$SB}()) end end @generated function contract!( @@ -278,6 +296,7 @@ end ::Val{DB}, ::Val{LAO}, ::Val{U}, + ::Val{SB} ) where { T, P, @@ -291,34 +310,39 @@ end DB, LAO, U, + SB } - r = reinterpret_reshape_dual dimC::Int = Int(length(DC)) new_dim = ((Int(length(DA))::Int + Int(length(DB))::Int - dimC) >>> 1) + dimC DCN = (new_dim, DC...) DBN = (new_dim, DB...) + if LAO + SBN = (new_dim, SB...) + else + SBN = () + end quote - contract!($r(C), A, $r(B), Val{$DCN}(), Val{$DA}(), Val{$DBN}(), Val{$LAO}(), Val{$U}()) + contract!($r(C), A, $r(B), Val{$DCN}(), Val{$DA}(), Val{$DBN}(), Val{$LAO}(), Val{$U}(), Val{$SBN}()) end end -function view_d1_first(A::AbstractArray{<:Any,N}) where {N} - view(A, firstindex(A, static(1)), ntuple(_ -> (:), Val(N - 1))...) +@inline function view_d1_first(A::AbstractArray{<:Any,N}) where {N} + zero_offsets(view(A, firstindex(A, static(1)), ntuple(_ -> (:), Val(N - 1))...)) end -_increment_first(r::CloseOpen) = CloseOpen(r.lower + static(1), r.upper) +_increment_first(r::CloseOpen) = CloseOpen(getfield(r,:start) + static(1), getfield(r,:upper)) _increment_first(r::AbstractUnitRange) = first(r)+static(1):last(r) -function view_d1_notfirst(A::AbstractArray{<:Any,N}) where {N} +@inline function view_d1_notfirst(A::AbstractArray{<:Any,N}) where {N} r = _increment_first(axes(A, static(1))) - view(A, r, ntuple(_ -> (:), Val(N - 1))...) + zero_offsets(view(A, r, ntuple(_ -> (:), Val(N - 1))...)) end -_decrement_last(r::CloseOpen) = CloseOpen(r.lower, r.upper - static(1)) +_decrement_last(r::CloseOpen) = CloseOpen(getfield(r,:start), getfield(r,:upper) - static(1)) _decrement_last(r::AbstractUnitRange) = CloseOpen(first(r), last(r)) -function view_dlast_front(A::AbstractArray{<:Any,N}) where {N} +@inline function view_dlast_front(A::AbstractArray{<:Any,N}) where {N} r = _decrement_last(axes(A, static(N))) - view(A, ntuple(_ -> (:), Val(N - 1))..., r) + zero_offsets(view(A, ntuple(_ -> (:), Val(N - 1))..., r)) end @generated function contract!( @@ -330,6 +354,7 @@ end ::Val{DB}, ::Val{LAO}, ::Val{U}, + ::Val{SB} ) where { T, P, @@ -344,6 +369,7 @@ end DB, LAO, U, + SB } rr = reinterpret_reshape_dual @@ -373,6 +399,7 @@ end Val{$DB}(), Val{$LAO}(), Val{$U}(), + Val{$SB}() )), ) end @@ -396,6 +423,7 @@ end Val{$DBN}(), Val{false}(), Val{true}(), + Val{()}() ), ), ) @@ -410,7 +438,7 @@ function matmul!( B::PtrVector, ::True, ) where {D<:ForwardDiff.Dual} - contract!(C, A, B, Val{(0,)}(), Val{(0, 1)}(), Val{(1,)}(), Val{true}(), Val{false}()) + contract!(zero_offsets(C), zero_offsets(A), zero_offsets(B), Val{(0,)}(), Val{(0, 1)}(), Val{(1,)}(), Val{true}(), Val{false}(), Val{()}()) end function matmul!( C::PtrMatrix{<:Any,<:Any,D}, @@ -418,7 +446,7 @@ function matmul!( B::PtrMatrix, ::True, ) where {D<:ForwardDiff.Dual} - contract!(C, A, B, Val{(0, 1)}(), Val{(0, 2)}(), Val{(2, 1)}(), Val{true}(), Val{false}()) + contract!(zero_offsets(C), zero_offsets(A), zero_offsets(B), Val{(0, 1)}(), Val{(0, 2)}(), Val{(2, 1)}(), Val{true}(), Val{false}(), Val{()}()) end function matmul!( C::PtrVector{<:Any,<:Any,D}, @@ -426,7 +454,7 @@ function matmul!( B::PtrVector, ::False, ) where {D<:ForwardDiff.Dual} - contract!(C, A, B, Val{(0,)}(), Val{(0, 1)}(), Val{(1,)}(), Val{false}(), Val{false}()) + contract!(zero_offsets(C), zero_offsets(A), zero_offsets(B), Val{(0,)}(), Val{(0, 1)}(), Val{(1,)}(), Val{false}(), Val{false}(), Val{()}()) end function matmul!( C::PtrMatrix{<:Any,<:Any,D}, @@ -435,14 +463,15 @@ function matmul!( ::False, ) where {D<:ForwardDiff.Dual} contract!( - C, - A, - B, + zero_offsets(C), + zero_offsets(A), + zero_offsets(B), Val{(0, 1)}(), Val{(0, 2)}(), Val{(2, 1)}(), Val{false}(), Val{false}(), + Val{()}(), ) end @@ -452,7 +481,7 @@ function matmul!( B::PtrVector, bias::StaticBool, ) where {D<:ForwardDiff.Dual} - matmul!(vec(C), A, B, bias) + matmul!(zero_offsets(vec(C)), zero_offsets(A), zero_offsets(B), bias) end function matmul!( C::PtrVector{<:Any,<:Any,D}, @@ -460,11 +489,11 @@ function matmul!( B::PtrMatrix, bias::StaticBool, ) where {D<:ForwardDiff.Dual} - matmul!(C, A, vec(B), bias) + matmul!(zero_offsets(C), zero_offsets(A), zero_offsets(vec(B)), bias) end function matmul!(C, A, B, bias::StaticBool) - GC.@preserve C A B matmul!(PtrArray(C), PtrArray(A), PtrArray(B), bias) + GC.@preserve C A B matmul!(zero_offsets(PtrArray(C)), zero_offsets(PtrArray(A)), zero_offsets(PtrArray(B)), bias) end function dense!( @@ -475,6 +504,7 @@ function dense!( ::BT, ::FF, ) where {F,BT<:StaticBool,FF,T,P,D<:ForwardDiff.Dual{<:Any,T,P}} - matmul!(Cdual, A, B, BT()) + + matmul!(zero_offsets(Cdual), zero_offsets(A), zero_offsets(B), BT()) dualeval!(f, Cdual) end diff --git a/test/matmul_tests.jl b/test/matmul_tests.jl index 31ef043..666b425 100644 --- a/test/matmul_tests.jl +++ b/test/matmul_tests.jl @@ -16,9 +16,9 @@ for bias in (true, false) M = 16 K = 20 N = 17 - A = rand(M, K + bias) - B = rand(K, N) - bm = rand(K, 1) + A = rand(M, K + bias); + B = rand(K, N); + bm = rand(K, 1); for fa1 in (identity, dual4x3), fa2 in (identity, dual4x3), @@ -44,15 +44,15 @@ for bias in (true, false) end SimpleChains.matmul!(C, A, B, static(bias)) - @test C ≈ AB + @test reinterpret(Float64,C) ≈ reinterpret(Float64,AB) SimpleChains.matmul!(c, A, b, static(bias)) - @test c ≈ Ab + @test reinterpret(Float64,c) ≈ reinterpret(Float64,Ab) SimpleChains.matmul!(c, A, bm, static(bias)) - @test c ≈ Ab + @test reinterpret(Float64,c) ≈ reinterpret(Float64,Ab) SimpleChains.matmul!(cm, A, b, static(bias)) - @test vec(cm) ≈ Ab + @test reinterpret(Float64,vec(cm)) ≈ reinterpret(Float64,Ab) SimpleChains.matmul!(cm, A, bm, static(bias)) - @test vec(cm) ≈ Ab + @test reinterpret(Float64,vec(cm)) ≈ reinterpret(Float64,Ab) end end end diff --git a/test/runtests.jl b/test/runtests.jl index d3589a6..0e8ee28 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -4,8 +4,10 @@ using Test, Aqua, ForwardDiff, Zygote function countallocations!(g, sc, x, p) @allocated valgrad!(g, sc, x, p) end -dual(x) = ForwardDiff.Dual(x, randn(), randn(), randn()) -dual(x::ForwardDiff.Dual) = ForwardDiff.Dual(x, dual(randn()), dual(randn())) +dual(x::T) where {T} = ForwardDiff.Dual(x, 4randn(T), 4randn(T), 4randn(T)) +function dual(x::ForwardDiff.Dual{<:Any,T}) where {T} + ForwardDiff.Dual(x, dual(4randn(T)), dual(4randn(T))) +end import InteractiveUtils InteractiveUtils.versioninfo(verbose=true) @@ -252,65 +254,37 @@ InteractiveUtils.versioninfo(verbose=true) undef, first(SimpleChains.layer_output_size(Val(eltype(xdd)), td, size(x))), ) - - A = reshape(view(p, 1:8*24), (8, 24)) - b = view(p, 1+8*24:8*25) - Ad = reshape(view(pd, 1:8*24), (8, 24)) - bd = view(pd, 1+8*24:8*25) + dim = size(x,1); + A = reshape(view(p, 1:8dim), (8, dim)) + b = view(p, 1+8dim:8*25) + Ad = reshape(view(pd, 1:8dim), (8, dim)) + bd = view(pd, 1+8dim:8*25) ld = tanh.(Ad * x .+ bd) l_d = tanh.(A * xd .+ b) ld_d = tanh.(Ad * xd .+ bd) - Add = reshape(view(pdd, 1:8*24), (8, 24)) - bdd = view(pdd, 1+8*24:8*25) + Add = reshape(view(pdd, 1:8dim), (8, dim)) + bdd = view(pdd, 1+8dim:8*25) ldd = tanh.(Add * x .+ bdd) ldd_dd = tanh.(Add * xdd .+ bdd) - if T === Float64 - GC.@preserve pd pu begin - @test reinterpret(T, ld) ≈ reinterpret(T, td(x, pointer(pd), pointer(pu))[1]) - @test reinterpret(T, ld) ≈ - reinterpret(T, td(permutedims(x)', pointer(pd), pointer(pu))[1]) - @test reinterpret(T, l_d) ≈ reinterpret(T, td(xd, pointer(p), pointer(pu))[1]) - @test reinterpret(T, l_d) ≈ - reinterpret(T, td(permutedims(xd)', pointer(p), pointer(pu))[1]) - @test reinterpret(T, ld_d) ≈ reinterpret(T, td(xd, pointer(pd), pointer(pu))[1]) - @test reinterpret(T, ld_d) ≈ - reinterpret(T, td(permutedims(xd)', pointer(pd), pointer(pu))[1]) - - @test reinterpret(T, ldd) ≈ reinterpret(T, td(x, pointer(pdd), pointer(pu))[1]) - @test reinterpret(T, ldd_dd) ≈ - reinterpret(T, td(xdd, pointer(pdd), pointer(pu))[1]) - @test reinterpret(T, ldd) ≈ - reinterpret(T, td(permutedims(x)', pointer(pdd), pointer(pu))[1]) - @test reinterpret(T, ldd_dd) ≈ - reinterpret(T, td(permutedims(xdd)', pointer(pdd), pointer(pu))[1]) - end - else - GC.@preserve pd pu begin - @test_broken reinterpret(T, ld) ≈ - reinterpret(T, td(x, pointer(pd), pointer(pu))[1]) - @test_broken reinterpret(T, l_d) ≈ - reinterpret(T, td(xd, pointer(p), pointer(pu))[1]) - @test_broken reinterpret(T, ld_d) ≈ - reinterpret(T, td(xd, pointer(pd), pointer(pu))[1]) - - @test_broken reinterpret(T, ldd) ≈ - reinterpret(T, td(x, pointer(pdd), pointer(pu))[1]) - @test_broken reinterpret(T, ldd_dd) ≈ - reinterpret(T, td(xdd, pointer(pdd), pointer(pu))[1]) - - @test_broken reinterpret(T, ld) ≈ - reinterpret(T, td(permutedims(x)', pointer(pd), pointer(pu))[1]) - @test_broken reinterpret(T, l_d) ≈ - reinterpret(T, td(permutedims(xd)', pointer(p), pointer(pu))[1]) - @test_broken reinterpret(T, ld_d) ≈ - reinterpret(T, td(permutedims(xd)', pointer(pd), pointer(pu))[1]) - - @test_broken reinterpret(T, ldd) ≈ - reinterpret(T, td(permutedims(x)', pointer(pdd), pointer(pu))[1]) - @test_broken reinterpret(T, ldd_dd) ≈ - reinterpret(T, td(permutedims(xdd)', pointer(pdd), pointer(pu))[1]) - end + GC.@preserve pd pu begin + @test reinterpret(T, ld) ≈ reinterpret(T, td(x, pointer(pd), pointer(pu))[1]) + @test reinterpret(T, ld) ≈ + reinterpret(T, td(permutedims(x)', pointer(pd), pointer(pu))[1]) + @test reinterpret(T, l_d) ≈ reinterpret(T, td(xd, pointer(p), pointer(pu))[1]) + @test reinterpret(T, l_d) ≈ + reinterpret(T, td(permutedims(xd)', pointer(p), pointer(pu))[1]) + @test reinterpret(T, ld_d) ≈ reinterpret(T, td(xd, pointer(pd), pointer(pu))[1]) + @test reinterpret(T, ld_d) ≈ + reinterpret(T, td(permutedims(xd)', pointer(pd), pointer(pu))[1]) + + @test reinterpret(T, ldd) ≈ reinterpret(T, td(x, pointer(pdd), pointer(pu))[1]) + @test reinterpret(T, ldd_dd) ≈ + reinterpret(T, td(xdd, pointer(pdd), pointer(pu))[1]) + @test reinterpret(T, ldd) ≈ + reinterpret(T, td(permutedims(x)', pointer(pdd), pointer(pu))[1]) + @test reinterpret(T, ldd_dd) ≈ + reinterpret(T, td(permutedims(xdd)', pointer(pdd), pointer(pu))[1]) end @testset "training" begin p .= randn.() .* 100 From 4580517eaf83f21f819b3851199ad2031394f004 Mon Sep 17 00:00:00 2001 From: Chris Elrod Date: Thu, 9 Jun 2022 18:52:47 -0400 Subject: [PATCH 2/2] Bump version. --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 46e0fa2..8d7d77b 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "SimpleChains" uuid = "de6bee2f-e2f4-4ec7-b6ed-219cc6f6e9e5" authors = ["Chris Elrod and contributors"] -version = "0.2.9" +version = "0.2.10" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"