Skip to content

Commit

Permalink
Merge pull request #87 from PumasAI/fixderivatives
Browse files Browse the repository at this point in the history
Broadcasted bias should not broadcast against duals.
  • Loading branch information
korsbo authored Jun 10, 2022
2 parents d1ffecc + 4580517 commit 80696ec
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 90 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "SimpleChains"
uuid = "de6bee2f-e2f4-4ec7-b6ed-219cc6f6e9e5"
authors = ["Chris Elrod <[email protected]> and contributors"]
version = "0.2.9"
version = "0.2.10"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
Expand Down
82 changes: 56 additions & 26 deletions src/forwarddiff_matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -108,18 +108,19 @@ end
end

# Reinterpret an array of `ForwardDiff.Dual{T,V,N}` as an array whose element type is the
# underlying value type `V`, using the plain (non-`reshape`) form of `reinterpret`.
# NOTE(review): the +/- diff markers were stripped from this view; the two body lines below
# appear to be the removed/added pair. The second wraps the same result in `zero_offsets`,
# which is defined outside this view (presumably it resets index offsets to zero — confirm).
@inline reinterpret_dual(A::AbstractArray{ForwardDiff.Dual{T,V,N}}) where {T,V,N} =
reinterpret(V, A)
zero_offsets(reinterpret(V, A))
# Reinterpret an array of `ForwardDiff.Dual{T,V,N}` as an array of `V` using the
# `reinterpret(reshape, ...)` form, i.e. the counterpart of `reinterpret_dual` that goes
# through `reshape` instead of the plain reinterpretation.
# NOTE(review): diff markers were stripped; the first body line looks like the removed
# version and the second like the added one, which passes the result through
# `zero_offsets` (defined elsewhere — semantics not visible here).
@inline function reinterpret_reshape_dual(
A::AbstractArray{ForwardDiff.Dual{T,V,N}},
) where {T,V,N}
reinterpret(reshape, V, A)
zero_offsets(reinterpret(reshape, V, A))
end
function contract_loops(
DC::Vector{Int},
DA::Vector{Int},
DB::Vector{Int},
lao::Bool,
update::Bool,
skipbroadcast::Vector{Int}
)
contract_dims = Int[]
Cref = Expr(:ref, :C)
Expand Down Expand Up @@ -170,6 +171,16 @@ function contract_loops(
q = if lao
Arefo = copy(Aref)
Arefo.args[end] = :K
# lao will not broadcast along excluded dims...
cond = Arefo
for d in skipbroadcast
i = findfirst(==(d), DC)::Int
c = :($(Cinds[i]) == $(ArrayInterface.offsets)(C, $(StaticInt{i}())))
cond = cond === Arefo ? c : :($cond & $c)
end
if cond !== Arefo # found
Arefo = :(ifelse($cond, $Arefo, zero(Caccum)))
end
quote
Caccum = $Cinit
$q
Expand Down Expand Up @@ -201,7 +212,7 @@ function contract_loops(
q = :(@turbo $q)
if lao
cd::Int = findfirst(==(first(contract_dims)), DA)
Expr(:block, :(K = lastindex(A, StaticInt{$cd}())), q)
Expr(:block, :(K = $(ArrayInterface.static_last)($axes(A, StaticInt{$cd}()))), q)
else
q
end
Expand All @@ -216,7 +227,8 @@ end
::Val{DB},
::Val{LAO},
::Val{U},
) where {TC<:NativeTypes,TA<:NativeTypes,TB<:NativeTypes,DC,DA,DB,LAO,U}
::Val{SB}
) where {TC<:NativeTypes,TA<:NativeTypes,TB<:NativeTypes,DC,DA,DB,LAO,U,SB}
dc = Vector{Int}(undef, length(DC))
for i in eachindex(DC)
dc[i] = DC[i]
Expand All @@ -229,7 +241,11 @@ end
for i in eachindex(DB)
db[i] = DB[i]
end
contract_loops(dc, da, db, LAO, U)
sb = LAO ? Vector{Int}(undef, length(SB)) : dc
for i in eachindex(SB)
sb[i] = SB[i]
end
contract_loops(dc, da, db, LAO, U, sb)
end
@generated function contract!(
C::PtrArray{<:Any,DDC,TC},
Expand All @@ -240,6 +256,7 @@ end
::Val{DB},
::Val{LAO},
::Val{U},
::Val{SB}
) where {
T,
P,
Expand All @@ -253,6 +270,7 @@ end
DB,
LAO,
U,
SB
}
if (DDC[1] & DDA[1]) & (DC[1] == DA[1]) & Bool(is_column_major(A)) & Bool(is_column_major(C))
r = reinterpret_dual
Expand All @@ -266,7 +284,7 @@ end
DAN = (new_dim, DA...)
end
quote
contract!($r(C), $r(A), B, Val{$DCN}(), Val{$DAN}(), Val{$DB}(), Val{$LAO}(), Val{$U}())
contract!($r(C), $r(A), B, Val{$DCN}(), Val{$DAN}(), Val{$DB}(), Val{$LAO}(), Val{$U}(), Val{$SB}())
end
end
@generated function contract!(
Expand All @@ -278,6 +296,7 @@ end
::Val{DB},
::Val{LAO},
::Val{U},
::Val{SB}
) where {
T,
P,
Expand All @@ -291,34 +310,39 @@ end
DB,
LAO,
U,
SB
}

r = reinterpret_reshape_dual
dimC::Int = Int(length(DC))
new_dim = ((Int(length(DA))::Int + Int(length(DB))::Int - dimC) >>> 1) + dimC
DCN = (new_dim, DC...)
DBN = (new_dim, DB...)
if LAO
SBN = (new_dim, SB...)
else
SBN = ()
end
quote
contract!($r(C), A, $r(B), Val{$DCN}(), Val{$DA}(), Val{$DBN}(), Val{$LAO}(), Val{$U}())
contract!($r(C), A, $r(B), Val{$DCN}(), Val{$DA}(), Val{$DBN}(), Val{$LAO}(), Val{$U}(), Val{$SBN}())
end
end


# View into `A` that fixes dimension 1 at its first index and takes `:` along each of the
# remaining `N - 1` dimensions.
# NOTE(review): diff markers were stripped; the plain `function` header and its body line
# appear to be the removed version, and the `@inline` header with the `zero_offsets(...)`
# body the added one. `zero_offsets` is defined outside this view.
function view_d1_first(A::AbstractArray{<:Any,N}) where {N}
view(A, firstindex(A, static(1)), ntuple(_ -> (:), Val(N - 1))...)
@inline function view_d1_first(A::AbstractArray{<:Any,N}) where {N}
zero_offsets(view(A, firstindex(A, static(1)), ntuple(_ -> (:), Val(N - 1))...))
end
# Return the range `r` with its lower end advanced by one (i.e. without its first index).
# For `CloseOpen` the upper bound is kept as-is; the generic unit-range method builds
# `first(r) + 1 : last(r)`.
# NOTE(review): the two `CloseOpen` lines appear to be a removed/added diff pair — the
# first reads `r.lower`, the second reads the `:start` field via `getfield`.
_increment_first(r::CloseOpen) = CloseOpen(r.lower + static(1), r.upper)
_increment_first(r::CloseOpen) = CloseOpen(getfield(r,:start) + static(1), getfield(r,:upper))
_increment_first(r::AbstractUnitRange) = first(r)+static(1):last(r)

# View into `A` covering everything except the first index of dimension 1: the axis of
# dimension 1 is shortened with `_increment_first`, and `:` is taken along the other
# `N - 1` dimensions.
# NOTE(review): diff markers were stripped; the two headers and the two `view` lines appear
# to be removed/added pairs (added version is `@inline` and wraps the view in
# `zero_offsets`, which is defined outside this view).
function view_d1_notfirst(A::AbstractArray{<:Any,N}) where {N}
@inline function view_d1_notfirst(A::AbstractArray{<:Any,N}) where {N}
r = _increment_first(axes(A, static(1)))
view(A, r, ntuple(_ -> (:), Val(N - 1))...)
zero_offsets(view(A, r, ntuple(_ -> (:), Val(N - 1))...))
end
# Return the range `r` with its upper end reduced by one (i.e. without its last index).
# The generic unit-range method returns `CloseOpen(first(r), last(r))` — presumably a
# half-open range that excludes `last(r)`; confirm against `CloseOpen`'s definition.
# NOTE(review): the two `CloseOpen` lines appear to be a removed/added diff pair — the
# first reads `r.lower`, the second reads the `:start` field via `getfield`.
_decrement_last(r::CloseOpen) = CloseOpen(r.lower, r.upper - static(1))
_decrement_last(r::CloseOpen) = CloseOpen(getfield(r,:start), getfield(r,:upper) - static(1))
_decrement_last(r::AbstractUnitRange) = CloseOpen(first(r), last(r))
# View into `A` covering everything except the last index of the final dimension `N`: the
# axis of dimension `N` is shortened with `_decrement_last`, and `:` is taken along the
# leading `N - 1` dimensions.
# NOTE(review): diff markers were stripped; the two headers and the two `view` lines appear
# to be removed/added pairs (added version is `@inline` and wraps the view in
# `zero_offsets`, which is defined outside this view).
function view_dlast_front(A::AbstractArray{<:Any,N}) where {N}
@inline function view_dlast_front(A::AbstractArray{<:Any,N}) where {N}
r = _decrement_last(axes(A, static(N)))
view(A, ntuple(_ -> (:), Val(N - 1))..., r)
zero_offsets(view(A, ntuple(_ -> (:), Val(N - 1))..., r))
end

@generated function contract!(
Expand All @@ -330,6 +354,7 @@ end
::Val{DB},
::Val{LAO},
::Val{U},
::Val{SB}
) where {
T,
P,
Expand All @@ -344,6 +369,7 @@ end
DB,
LAO,
U,
SB
}

rr = reinterpret_reshape_dual
Expand Down Expand Up @@ -373,6 +399,7 @@ end
Val{$DB}(),
Val{$LAO}(),
Val{$U}(),
Val{$SB}()
)),
)
end
Expand All @@ -396,6 +423,7 @@ end
Val{$DBN}(),
Val{false}(),
Val{true}(),
Val{()}()
),
),
)
Expand All @@ -410,23 +438,23 @@ function matmul!(
B::PtrVector,
::True,
) where {D<:ForwardDiff.Dual}
contract!(C, A, B, Val{(0,)}(), Val{(0, 1)}(), Val{(1,)}(), Val{true}(), Val{false}())
contract!(zero_offsets(C), zero_offsets(A), zero_offsets(B), Val{(0,)}(), Val{(0, 1)}(), Val{(1,)}(), Val{true}(), Val{false}(), Val{()}())
end
# `matmul!` for a dual-valued matrix `C` with matrix `A` and matrix `B`, selected when the
# fourth argument is `True`. Forwards to `contract!` with dimension labels `(0, 1)` for
# `C`, `(0, 2)` for `A` and `(2, 1)` for `B`, and `Val{true}()` / `Val{false}()` in the
# two flag positions.
# NOTE(review): the label `2` shared by `A` and `B` is presumably the contracted
# dimension, and the `True` argument presumably corresponds to the `bias` flag of the
# other `matmul!` methods — confirm against `contract_loops`.
# NOTE(review): diff markers were stripped; the two `contract!` calls appear to be the
# removed/added pair. The added call wraps every array in `zero_offsets` and passes an
# extra trailing `Val{()}()`.
function matmul!(
C::PtrMatrix{<:Any,<:Any,D},
A::PtrMatrix,
B::PtrMatrix,
::True,
) where {D<:ForwardDiff.Dual}
contract!(C, A, B, Val{(0, 1)}(), Val{(0, 2)}(), Val{(2, 1)}(), Val{true}(), Val{false}())
contract!(zero_offsets(C), zero_offsets(A), zero_offsets(B), Val{(0, 1)}(), Val{(0, 2)}(), Val{(2, 1)}(), Val{true}(), Val{false}(), Val{()}())
end
# `matmul!` for a dual-valued vector `C` with matrix `A` and vector `B`, selected when the
# fourth argument is `False`. Forwards to `contract!` with dimension labels `(0,)` for
# `C`, `(0, 1)` for `A` and `(1,)` for `B`, and `Val{false}()` in both flag positions.
# NOTE(review): the label `1` shared by `A` and `B` is presumably the contracted
# dimension — confirm against `contract_loops`.
# NOTE(review): diff markers were stripped; the two `contract!` calls appear to be the
# removed/added pair. The added call wraps every array in `zero_offsets` and passes an
# extra trailing `Val{()}()`.
function matmul!(
C::PtrVector{<:Any,<:Any,D},
A::PtrMatrix{},
B::PtrVector,
::False,
) where {D<:ForwardDiff.Dual}
contract!(C, A, B, Val{(0,)}(), Val{(0, 1)}(), Val{(1,)}(), Val{false}(), Val{false}())
contract!(zero_offsets(C), zero_offsets(A), zero_offsets(B), Val{(0,)}(), Val{(0, 1)}(), Val{(1,)}(), Val{false}(), Val{false}(), Val{()}())
end
function matmul!(
C::PtrMatrix{<:Any,<:Any,D},
Expand All @@ -435,14 +463,15 @@ function matmul!(
::False,
) where {D<:ForwardDiff.Dual}
contract!(
C,
A,
B,
zero_offsets(C),
zero_offsets(A),
zero_offsets(B),
Val{(0, 1)}(),
Val{(0, 2)}(),
Val{(2, 1)}(),
Val{false}(),
Val{false}(),
Val{()}(),
)
end

Expand All @@ -452,19 +481,19 @@ function matmul!(
B::PtrVector,
bias::StaticBool,
) where {D<:ForwardDiff.Dual}
matmul!(vec(C), A, B, bias)
matmul!(zero_offsets(vec(C)), zero_offsets(A), zero_offsets(B), bias)
end
# `matmul!` for a dual-valued vector `C` when `B` is supplied as a matrix: `B` is
# flattened with `vec` and the call is forwarded to another `matmul!` method with the
# same `bias` flag.
# NOTE(review): diff markers were stripped; the two forwarding calls appear to be the
# removed/added pair. The added call additionally wraps each argument in `zero_offsets`
# (defined outside this view).
function matmul!(
C::PtrVector{<:Any,<:Any,D},
A::PtrMatrix,
B::PtrMatrix,
bias::StaticBool,
) where {D<:ForwardDiff.Dual}
matmul!(C, A, vec(B), bias)
matmul!(zero_offsets(C), zero_offsets(A), zero_offsets(vec(B)), bias)
end

# Untyped fallback: converts `C`, `A` and `B` to `PtrArray`s and forwards to the
# `PtrArray`-based `matmul!` methods. `GC.@preserve` keeps the original arrays rooted for
# the duration of the forwarded call.
# NOTE(review): diff markers were stripped; the two forwarding lines appear to be the
# removed/added pair. The added line additionally wraps each `PtrArray` in `zero_offsets`
# (defined outside this view).
function matmul!(C, A, B, bias::StaticBool)
GC.@preserve C A B matmul!(PtrArray(C), PtrArray(A), PtrArray(B), bias)
GC.@preserve C A B matmul!(zero_offsets(PtrArray(C)), zero_offsets(PtrArray(A)), zero_offsets(PtrArray(B)), bias)
end

function dense!(
Expand All @@ -475,6 +504,7 @@ function dense!(
::BT,
::FF,
) where {F,BT<:StaticBool,FF,T,P,D<:ForwardDiff.Dual{<:Any,T,P}}
matmul!(Cdual, A, B, BT())

matmul!(zero_offsets(Cdual), zero_offsets(A), zero_offsets(B), BT())
dualeval!(f, Cdual)
end
16 changes: 8 additions & 8 deletions test/matmul_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ for bias in (true, false)
M = 16
K = 20
N = 17
A = rand(M, K + bias)
B = rand(K, N)
bm = rand(K, 1)
A = rand(M, K + bias);
B = rand(K, N);
bm = rand(K, 1);

for fa1 in (identity, dual4x3),
fa2 in (identity, dual4x3),
Expand All @@ -44,15 +44,15 @@ for bias in (true, false)
end

SimpleChains.matmul!(C, A, B, static(bias))
@test C ≈ AB
@test reinterpret(Float64,C) ≈ reinterpret(Float64,AB)
SimpleChains.matmul!(c, A, b, static(bias))
@test c ≈ Ab
@test reinterpret(Float64,c) ≈ reinterpret(Float64,Ab)
SimpleChains.matmul!(c, A, bm, static(bias))
@test c ≈ Ab
@test reinterpret(Float64,c) ≈ reinterpret(Float64,Ab)
SimpleChains.matmul!(cm, A, b, static(bias))
@test vec(cm) ≈ Ab
@test reinterpret(Float64,vec(cm)) ≈ reinterpret(Float64,Ab)
SimpleChains.matmul!(cm, A, bm, static(bias))
@test vec(cm) ≈ Ab
@test reinterpret(Float64,vec(cm)) ≈ reinterpret(Float64,Ab)
end
end
end
Expand Down
Loading

2 comments on commit 80696ec

@korsbo
Copy link
Member Author

@korsbo korsbo commented on 80696ec Jun 10, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/62102

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.2.10 -m "<description of version>" 80696ecb025e102d9470f159abb57238b76077fb
git push origin v0.2.10

Please sign in to comment.