Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix add wishart wishartfast prod #217

Merged
merged 5 commits into from
Oct 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/distributions/wishart.jl
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,12 @@ function BayesBase.prod(::PreserveTypeProd{Distribution}, left::WishartFast, rig
return WishartFast(df, invV)
end

BayesBase.default_prod_rule(::Type{<:Wishart}, ::Type{<:WishartFast}) = PreserveTypeProd(Distribution)

function BayesBase.prod(::PreserveTypeProd{Distribution}, left::Wishart, right::WishartFast)
return prod(PreserveTypeProd(Distribution), convert(WishartFast, left), right)
end

function BayesBase.insupport(ef::ExponentialFamilyDistribution{WishartFast}, x::Matrix)
return size(getindex(unpack_parameters(ef), 2)) == size(x) && isposdef(x)
end
Expand Down
12 changes: 12 additions & 0 deletions src/distributions/wishart_inverse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,18 @@ function BayesBase.prod(::PreserveTypeProd{Distribution}, left::InverseWishartFa
return InverseWishartFast(df, V)
end

BayesBase.default_prod_rule(::Type{<:InverseWishart}, ::Type{<:InverseWishartFast}) = PreserveTypeProd(Distribution)

function BayesBase.prod(::PreserveTypeProd{Distribution}, left::InverseWishart, right::InverseWishartFast)
return prod(PreserveTypeProd(Distribution), convert(InverseWishartFast, left), right)
end

BayesBase.default_prod_rule(::Type{<:InverseWishart}, ::Type{<:InverseWishart}) = PreserveTypeProd(Distribution)

function BayesBase.prod(::PreserveTypeProd{Distribution}, left::InverseWishart, right::InverseWishart)
return prod(PreserveTypeProd(Distribution), convert(InverseWishartFast, left), convert(InverseWishartFast, right))
end

function BayesBase.insupport(ef::ExponentialFamilyDistribution{InverseWishartFast}, x::Matrix)
return size(getindex(unpack_parameters(ef), 2)) == size(x) && isposdef(x)
end
Expand Down
34 changes: 34 additions & 0 deletions test/distributions/wishart_inverse_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -228,3 +228,37 @@ end
end
end
end

@testitem "InverseWishart: prod between InverseWishart and InverseWishartFast" begin
include("distributions_setuptests.jl")

import ExponentialFamily: InverseWishartFast
import Distributions: InverseWishart

for Sleft in rand(InverseWishart(10, Array(Eye(2))), 2), Sright in rand(InverseWishart(10, Array(Eye(2))), 2), νright in (6, 7), νleft in (4, 5)
let left = InverseWishart(νleft, Sleft), right = InverseWishart(νleft, Sleft), right_fast = convert(InverseWishartFast, right)
# Test commutativity of the product
prod_result1 = prod(PreserveTypeProd(Distribution), left, right_fast)
prod_result2 = prod(PreserveTypeProd(Distribution), right_fast, left)

@test prod_result1.ν ≈ prod_result2.ν
@test prod_result1.S ≈ prod_result2.S

# Test that the product preserves type
@test prod_result1 isa InverseWishartFast
@test prod_result2 isa InverseWishartFast

# prod stays if we convert fisrt and then do product
left_fast = convert(InverseWishartFast, left)
prod_fast = prod(ClosedProd(), left_fast, right_fast)

@test prod_fast.ν ≈ prod_result1.ν
@test prod_fast.S ≈ prod_result2.S

# prod for Inverse Wishart is defenied
prod_result_not_fast = prod(PreserveTypeProd(Distribution), left, right)
@test prod_result_not_fast.ν ≈ prod_result1.ν
@test prod_result_not_fast.S ≈ prod_result1.S
end
end
end
30 changes: 30 additions & 0 deletions test/distributions/wishart_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -131,3 +131,33 @@ end
end
end
end

@testitem "Wishart: prod between Wishart and WishartFast" begin
include("distributions_setuptests.jl")

import ExponentialFamily: WishartFast
import Distributions: Wishart

for Sleft in rand(Wishart(10, Array(Eye(2))), 2), Sright in rand(Wishart(10, Array(Eye(2))), 2), νright in (6, 7), νleft in (4, 5)
let left = Wishart(νleft, Sleft), right = WishartFast(νright, Sright)
# Test commutativity of the product
prod_result1 = prod(PreserveTypeProd(Distribution), left, right)
prod_result2 = prod(PreserveTypeProd(Distribution), right, left)

@test prod_result1.ν ≈ prod_result2.ν
@test prod_result1.invS ≈ prod_result2.invS

# Test that the product preserves type
@test prod_result1 isa WishartFast
@test prod_result2 isa WishartFast

# prod stays the same if we convert fisrt and then do product
left_fast = convert(WishartFast, left)
prod_fast = prod(ClosedProd(), left_fast, right)

@test prod_fast.ν ≈ prod_result1.ν
@test prod_fast.invS ≈ prod_result2.invS
end
end
end

Loading