Skip to content

Commit

Permalink
Fix context patcher (#526)
Browse files Browse the repository at this point in the history
  • Loading branch information
antonmedv authored Jan 14, 2024
1 parent 523a091 commit e213a66
Show file tree
Hide file tree
Showing 6 changed files with 57 additions and 15 deletions.
10 changes: 0 additions & 10 deletions builtin/builtin.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,6 @@ import (
"github.com/expr-lang/expr/vm/runtime"
)

type Function struct {
Name string
Func func(args ...any) (any, error)
Fast func(arg any) any
ValidateArgs func(args ...any) (any, error)
Types []reflect.Type
Validate func(args []reflect.Type) (reflect.Type, error)
Predicate bool
}

var (
Index map[string]int
Names []string
Expand Down
22 changes: 22 additions & 0 deletions builtin/function.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package builtin

import (
"reflect"
)

type Function struct {
Name string
Func func(args ...any) (any, error)
Fast func(arg any) any
ValidateArgs func(args ...any) (any, error)
Types []reflect.Type
Validate func(args []reflect.Type) (reflect.Type, error)
Predicate bool
}

func (f *Function) Type() reflect.Type {
if len(f.Types) > 0 {
return f.Types[0]
}
return reflect.TypeOf(f.Func)
}
6 changes: 3 additions & 3 deletions checker/checker.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,10 +169,10 @@ func (v *checker) ident(node ast.Node, name string, strict, builtins bool) (refl
}
if builtins {
if fn, ok := v.config.Functions[name]; ok {
return functionType, info{fn: fn}
return fn.Type(), info{fn: fn}
}
if fn, ok := v.config.Builtins[name]; ok {
return functionType, info{fn: fn}
return fn.Type(), info{fn: fn}
}
}
if v.config.Strict && strict {
Expand Down Expand Up @@ -833,7 +833,7 @@ func (v *checker) checkFunction(f *builtin.Function, node ast.Node, arguments []
}
return t, info{}
} else if len(f.Types) == 0 {
t, err := v.checkArguments(f.Name, functionType, false, arguments, node)
t, err := v.checkArguments(f.Name, f.Type(), false, arguments, node)
if err != nil {
if v.err == nil {
v.err = err
Expand Down
2 changes: 1 addition & 1 deletion checker/checker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -994,7 +994,7 @@ func TestCheck_builtin_without_call(t *testing.T) {
err string
}{
{`len + 1`, "invalid operation: + (mismatched types func(...interface {}) (interface {}, error) and int) (1:5)\n | len + 1\n | ....^"},
{`string.A`, "type func(...interface {}) (interface {}, error)[string] is undefined (1:8)\n | string.A\n | .......^"},
{`string.A`, "type func(interface {}) string[string] is undefined (1:8)\n | string.A\n | .......^"},
}

for _, test := range tests {
Expand Down
1 change: 0 additions & 1 deletion checker/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ var (
anyType = reflect.TypeOf(new(any)).Elem()
timeType = reflect.TypeOf(time.Time{})
durationType = reflect.TypeOf(time.Duration(0))
functionType = reflect.TypeOf(new(func(...any) (any, error))).Elem()
)

func combined(a, b reflect.Type) reflect.Type {
Expand Down
31 changes: 31 additions & 0 deletions patcher/with_context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,34 @@ func TestWithContext(t *testing.T) {
require.NoError(t, err)
require.Equal(t, 42, output)
}

func TestWithContext_with_env_Function(t *testing.T) {
env := map[string]any{
"ctx": context.TODO(),
}

fn := expr.Function("fn",
func(params ...any) (any, error) {
ctx := params[0].(context.Context)
a := params[1].(int)

return ctx.Value("value").(int) + a, nil
},
new(func(context.Context, int) int),
)

program, err := expr.Compile(
`fn(40)`,
expr.Env(env),
expr.WithContext("ctx"),
fn,
)
require.NoError(t, err)

ctx := context.WithValue(context.Background(), "value", 2)
env["ctx"] = ctx

output, err := expr.Run(program, env)
require.NoError(t, err)
require.Equal(t, 42, output)
}

0 comments on commit e213a66

Please sign in to comment.