Skip to content

Commit

Permalink
Check for broadcast op and format. Fixes #143.
Browse files Browse the repository at this point in the history
  • Loading branch information
chriselrod committed Apr 17, 2024
1 parent 0b2d48a commit ea9eb34
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 25 deletions.
5 changes: 3 additions & 2 deletions src/closure.jl
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,9 @@ function extractargs!(

startind = 1
if head === :call
if args[1] isa Symbol
startind = isdefined(mod, args[1]) ? 2 : 1
arg1 = args[1]
if arg1 isa Symbol && (first(string(arg1)) != '.')
startind = isdefined(mod, arg1) ? 2 : 1
else
startind = 2
end
Expand Down
76 changes: 53 additions & 23 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -114,21 +114,21 @@ end

function issue108!(y::Vector{T1}, x::Vector{T2}) where {T1,T2}
@batch for i in eachindex(y)
y[i] = sum(x[j] for j in 2i-oneunit(i):2i)
y[i] = sum(x[j] for j = 2i-oneunit(i):2i)
end
end

function issue108_comment!(data::Vector{T}, functions) where {T}
@batch for i in eachindex(data)
for f in functions
data[i] += f(data[i])
data[i] += f(data[i])
end
end
end

function issue116!(y::Vector{T}, x::Vector{T}) where {T}
@batch for i in 1:length(x)
y[i] = exp(x[i] + one(T))
@batch for i = 1:length(x)
y[i] = exp(x[i] + one(T))
end
end

Expand Down Expand Up @@ -280,15 +280,15 @@ end
x = collect(1:12)
y = zeros(6)
issue108!(y, x)
@test y == [sum(x[j] for j in 2i-oneunit(i):2i) for i in 1:6]
functions = [x -> n*x for n in 1:3]
@test y == [sum(x[j] for j = 2i-oneunit(i):2i) for i = 1:6]

functions = [x -> n * x for n = 1:3]
data = rand(100)
data1 = deepcopy(data)
issue108_comment!(data, functions)
for i in eachindex(data1)
for f in functions
data1[i] += f(data1[i])
data1[i] += f(data1[i])
end
end
@test data == data1
Expand Down Expand Up @@ -467,7 +467,7 @@ end
end
local7, local8 = let
red = 0
@batch minbatch = 100 stride = true reduction = (+,red) threadlocal = red for i = 0:9
@batch minbatch = 100 stride = true reduction = (+, red) threadlocal = red for i = 0:9
red += 1
threadlocal += 1
end
Expand All @@ -480,11 +480,19 @@ end
end
red
end
@test local1==local2==local3==local4==local5==local6==local7==local8==localsr
@test local1 ==
local2 ==
local3 ==
local4 ==
local5 ==
local6 ==
local7 ==
local8 ==
localsr
# check different operations
local9 = let
red = 1.0
@batch reduction = (*,red) for i = 1:100
@batch reduction = (*, red) for i = 1:100
red *= 4i^2 / (4i^2 - 1)
end
2red
Expand All @@ -495,7 +503,7 @@ end
red1 = 0
red2 = 0
red3 = 0
@batch reduction = ((+,red1), (+,red2), (+,red3)) for i = 0:9
@batch reduction = ((+, red1), (+, red2), (+, red3)) for i = 0:9
red1 += 1
red2 += 1
red3 -= 1
Expand All @@ -507,13 +515,19 @@ end
function f()
n = 1000
threadlocal = 0
@batch minbatch = 10 reduction = (+,threadlocal) for i = 1:n
@batch minbatch = 10 reduction = (+, threadlocal) for i = 1:n
threadlocal += 1
end
return threadlocal
end
allocated(f::F) where {F} = @allocated f()
inferred(f::F) where {F} = try @inferred f(); true catch; false end
inferred(f::F) where {F} =
try
@inferred f()
true
catch
false
end
allocated(f)
@test allocated(f) == 0
@test inferred(f) == true
Expand All @@ -524,16 +538,20 @@ end
red2 = false
red3 = typemax(eltype(arr))
red4 = typemin(eltype(arr))
@batch reduction = ((&,red1), (|,red2), (min,red3), (max,red4)) for x in arr
red1 &= x > 0.5
red2 |= x > 0.5
red3 = min(red3, x)
red4 = max(red4, x)
@batch reduction = ((&, red1), (|, red2), (min, red3), (max, red4)) for x in arr
red1 &= x > 0.5
red2 |= x > 0.5
red3 = min(red3, x)
red4 = max(red4, x)
end
red1, red2, red3, red4
end
@test (local13, local14, local15, local16) ==
(mapreduce(x->x>0.5, &, arr), mapreduce(x->x>0.5, |, arr), minimum(arr), maximum(arr))
@test (local13, local14, local15, local16) == (
mapreduce(x -> x > 0.5, &, arr),
mapreduce(x -> x > 0.5, |, arr),
minimum(arr),
maximum(arr),
)
end

@testset "locks and refvalues" begin
Expand Down Expand Up @@ -747,14 +765,26 @@ end
return any(find_call_to_nthreads, expr.args)
end

expr = @macroexpand @batch for i in 1:100
expr = @macroexpand @batch for i = 1:100
a[i] = i
end

@test find_call_to_nthreads(expr)
end


function dummy_broadcast!(x)
@batch for i = 1:2
a = (1,) .+ (1,)
x[i] = only(a)
end
end
let x = Vector{Float64}(undef, 2)
dummy_broadcast!(x)
@test x == fill(2.0, 2)
end

if VERSION v"1.6"
println("Package tests complete. Running `Aqua` checks.")
Aqua.test_all(Polyester; deps_compat = (check_extras=false,))
Aqua.test_all(Polyester; deps_compat = (check_extras = false,))
end

0 comments on commit ea9eb34

Please sign in to comment.