Skip to content

Commit

Permalink
Merge pull request #356 from xushiwei/op
Browse files Browse the repository at this point in the history
refactor cb.BinaryOp
  • Loading branch information
xushiwei authored Jan 28, 2024
2 parents d79c312 + a2480c4 commit 49986c4
Show file tree
Hide file tree
Showing 8 changed files with 197 additions and 93 deletions.
7 changes: 4 additions & 3 deletions ast.go
Original file line number Diff line number Diff line change
Expand Up @@ -573,6 +573,7 @@ func getParam1st(sig *types.Signature) int {
return 0
}

// TODO: check if fn.recv != nil
func matchFuncCall(pkg *Package, fn *internal.Elem, args []*internal.Elem, flags InstrFlags) (ret *internal.Elem, err error) {
fnType := fn.Type
if debugMatch {
Expand Down Expand Up @@ -633,10 +634,10 @@ retry:
backup := backupArgs(args)
for _, o := range ft.Methods {
mfn := *fn
mfn.Val.(*ast.SelectorExpr).Sel = ident(o.Name())
if (flags & instrFlagOpFunc) != 0 { // from callOpFunc
mfn.Type = o.Type()
if (flags & instrFlagBinaryOp) != 0 { // from cb.BinaryOp
mfn.Type = methodToFuncSig(pkg, o, &mfn)
} else {
mfn.Val.(*ast.SelectorExpr).Sel = ident(o.Name())
mfn.Type = methodCallSig(o.Type())
}
if ret, err = matchFuncCall(pkg, &mfn, args, flags); err == nil {
Expand Down
25 changes: 17 additions & 8 deletions builtin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,17 +44,26 @@ func getConf() *Config {
return &Config{Fset: fset, Importer: imp}
}

func TestErrMethodSigOf(t *testing.T) {
func TestCheckNamed(t *testing.T) {
foo := types.NewPackage("github.com/bar/foo", "foo")
tn := types.NewTypeName(0, foo, "t", nil)
typ := types.NewNamed(tn, types.Typ[types.Int], nil)
if v, ok := checkNamed(types.NewPointer(typ)); !ok || v != typ {
t.Fatal("TestCheckNamed failed:", v, ok)
}
}

func TestErrMethodSig(t *testing.T) {
pkg := NewPackage("", "foo", nil)
foo := types.NewPackage("github.com/bar/foo", "foo")
tn := types.NewTypeName(0, foo, "t", nil)
recv := types.NewNamed(tn, types.Typ[types.Int], nil)
t.Run("Go+ extended method", func(t *testing.T) {
defer func() {
if e := recover(); e != "can't call methodToFunc to Go+ extended method\n" {
t.Fatal("TestErrMethodSigOf:", e)
}
}()
methodSigOf(NewOverloadFunc(0, foo, "foo").Type(), memberFlagMethodToFunc, nil, nil)
t.Run("methodToFuncSig global func", func(t *testing.T) {
fnt := types.NewSignatureType(nil, nil, nil, nil, nil, false)
fn := types.NewFunc(0, foo, "bar", fnt)
if methodToFuncSig(pkg, fn, &internal.Elem{}) != fnt {
t.Fatal("methodToFuncSig failed")
}
})
t.Run("recv not pointer", func(t *testing.T) {
defer func() {
Expand Down
176 changes: 115 additions & 61 deletions codebuild.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"reflect"
"strconv"
"strings"
"syscall"

"github.com/goplus/gox/internal"
"golang.org/x/tools/go/types/typeutil"
Expand Down Expand Up @@ -1798,19 +1799,44 @@ func (p *CodeBuilder) field(
return p.embeddedField(o, name, aliasName, flag, arg, src)
}

func toFuncSig(sig *types.Signature, recv *types.Var) *types.Signature {
sp := sig.Params()
spLen := sp.Len()
vars := make([]*types.Var, spLen+1)
vars[0] = recv
for i := 0; i < spLen; i++ {
vars[i+1] = sp.At(i)
}
return types.NewSignatureType(nil, nil, nil, types.NewTuple(vars...), sig.Results(), sig.Variadic())
}

func methodToFuncSig(pkg *Package, o types.Object, fn *Element) *types.Signature {
sig := o.Type().(*types.Signature)
recv := sig.Recv()
if recv == nil {
fn.Val = toObjectExpr(pkg, o)
return sig
}

sel := fn.Val.(*ast.SelectorExpr)
sel.Sel = ident(o.Name())
sel.X = &ast.ParenExpr{X: sel.X}
return toFuncSig(sig, recv)
}

func methodSigOf(typ types.Type, flag MemberFlag, arg *Element, sel *ast.SelectorExpr) types.Type {
if flag != memberFlagMethodToFunc {
return methodCallSig(typ)
}

sig := typ.(*types.Signature)
if _, ok := CheckFuncEx(sig); ok {
log.Panicln("can't call methodToFunc to Go+ extended method")
return typ
}

at := arg.Type.(*TypeType).typ
recv := sig.Recv().Type()
_, isPtr := recv.(*types.Pointer) // recv is a pointer
at := arg.Type.(*TypeType).typ
if t, ok := at.(*types.Pointer); ok {
if !isPtr {
if _, ok := recv.Underlying().(*types.Interface); !ok { // and recv isn't a interface
Expand All @@ -1823,14 +1849,7 @@ func methodSigOf(typ types.Type, flag MemberFlag, arg *Element, sel *ast.Selecto
}
sel.X = &ast.ParenExpr{X: sel.X}

sp := sig.Params()
spLen := sp.Len()
vars := make([]*types.Var, spLen+1)
vars[0] = types.NewVar(token.NoPos, nil, "", at)
for i := 0; i < spLen; i++ {
vars[i+1] = sp.At(i)
}
return types.NewSignatureType(nil, nil, nil, types.NewTuple(vars...), sig.Results(), sig.Variadic())
return toFuncSig(sig, types.NewVar(token.NoPos, nil, "", at))
}

func methodCallSig(typ types.Type) types.Type {
Expand Down Expand Up @@ -2080,8 +2099,8 @@ func lookupMethod(t *types.Named, name string) types.Object {
return nil
}

func callOpFunc(cb *CodeBuilder, op token.Token, tokenOps []string, args []*internal.Elem, flags InstrFlags) (ret *internal.Elem, err error) {
name := goxPrefix + tokenOps[op]
func doUnaryOp(cb *CodeBuilder, op token.Token, args []*internal.Elem, flags InstrFlags) (ret *internal.Elem, err error) {
name := goxPrefix + unaryOps[op]
pkg := cb.pkg
typ := args[0].Type
retry:
Expand All @@ -2099,49 +2118,112 @@ retry:
typ = t.Elem()
goto retry
}
if op == token.QUO {
checkDivisionByZero(cb, args[0], args[1])
}
if op == token.EQL || op == token.NEQ {
if !ComparableTo(pkg, args[0], args[1]) {
return nil, errors.New("mismatched types")
}
ret = &internal.Elem{
Val: &ast.BinaryExpr{
X: checkParenExpr(args[0].Val), Op: op,
Y: checkParenExpr(args[1].Val),
},
Type: types.Typ[types.UntypedBool],
CVal: binaryOp(cb, op, args),
}
return
}
lm := pkg.builtin.Ref(name)
return matchFuncCall(pkg, toObject(pkg, lm, nil), args, flags)
}

// UnaryOp:
// - cb.UnaryOp(op token.Token)
// - cb.UnaryOp(op token.Token, twoValue bool)
// - cb.UnaryOp(op token.Token, twoValue bool, src ast.Node)
func (p *CodeBuilder) UnaryOp(op token.Token, params ...interface{}) *CodeBuilder {
var src ast.Node
var flags InstrFlags
switch len(params) {
case 2:
src, _ = params[1].(ast.Node)
fallthrough
case 1:
if params[0].(bool) {
flags = InstrFlagTwoValue
}
}
if debugInstr {
log.Println("UnaryOp", op, "flags:", flags)
}
ret, err := doUnaryOp(p, op, p.stk.GetArgs(1), flags)
if err != nil {
panic(err)
}
ret.Src = src
p.stk.Ret(1, ret)
return p
}

// BinaryOp func
func (p *CodeBuilder) BinaryOp(op token.Token, src ...ast.Node) *CodeBuilder {
if debugInstr {
log.Println("BinaryOp", op)
}
expr := getSrc(src)
pkg := p.pkg
name := goxPrefix + binaryOps[op]
args := p.stk.GetArgs(2)

var ret *internal.Elem
var err error
if ret, err = callOpFunc(p, op, binaryOps[:], args, 0); err != nil {
var err error = syscall.ENOENT
isUserDef := false
arg0 := args[0].Type
named0, ok0 := checkNamed(arg0)
if ok0 {
if fn, e := pkg.MethodToFunc(arg0, name, src...); e == nil {
ret, err = matchFuncCall(pkg, fn, args, instrFlagBinaryOp)
isUserDef = true
}
}
if err != nil {
arg1 := args[1].Type
if named1, ok1 := checkNamed(arg1); ok1 && named0 != named1 {
if fn, e := pkg.MethodToFunc(arg1, name, src...); e == nil {
ret, err = matchFuncCall(pkg, fn, args, instrFlagBinaryOp)
isUserDef = true
}
}
}
if err != nil && !isUserDef {
if op == token.QUO {
checkDivisionByZero(p, args[0], args[1])
}
if op == token.EQL || op == token.NEQ {
if !ComparableTo(pkg, args[0], args[1]) {
err = errors.New("mismatched types")
} else {
ret, err = &internal.Elem{
Val: &ast.BinaryExpr{
X: checkParenExpr(args[0].Val), Op: op,
Y: checkParenExpr(args[1].Val),
},
Type: types.Typ[types.UntypedBool],
CVal: binaryOp(p, op, args),
}, nil
}
} else {
lm := pkg.builtin.Ref(name)
ret, err = matchFuncCall(pkg, toObject(pkg, lm, nil), args, 0)
}
}

expr := getSrc(src)
if err != nil {
src, pos := p.loadExpr(expr)
if src == "" {
src = op.String()
}
p.panicCodeErrorf(
pos, "invalid operation: %s (mismatched types %v and %v)", src, args[0].Type, args[1].Type)
pos, "invalid operation: %s (mismatched types %v and %v)", src, arg0, args[1].Type)
}
ret.Src = expr
p.stk.Ret(2, ret)
return p
}

func checkNamed(typ types.Type) (ret *types.Named, ok bool) {
if t, ok := typ.(*types.Pointer); ok {
typ = t.Elem()
}
ret, ok = typ.(*types.Named)
return
}

var (
unaryOps = [...]string{
token.SUB: "Neg",
Expand Down Expand Up @@ -2182,34 +2264,6 @@ func (p *CodeBuilder) CompareNil(op token.Token, src ...ast.Node) *CodeBuilder {
return p.Val(nil).BinaryOp(op)
}

// UnaryOp:
// - cb.UnaryOp(op token.Token)
// - cb.UnaryOp(op token.Token, twoValue bool)
// - cb.UnaryOp(op token.Token, twoValue bool, src ast.Node)
func (p *CodeBuilder) UnaryOp(op token.Token, params ...interface{}) *CodeBuilder {
var src ast.Node
var flags InstrFlags
switch len(params) {
case 2:
src, _ = params[1].(ast.Node)
fallthrough
case 1:
if params[0].(bool) {
flags = InstrFlagTwoValue
}
}
if debugInstr {
log.Println("UnaryOp", op, "flags:", flags)
}
ret, err := callOpFunc(p, op, unaryOps[:], p.stk.GetArgs(1), flags)
if err != nil {
panic(err)
}
ret.Src = src
p.stk.Ret(1, ret)
return p
}

// Send func
func (p *CodeBuilder) Send() *CodeBuilder {
if debugInstr {
Expand Down
3 changes: 2 additions & 1 deletion func.go
Original file line number Diff line number Diff line change
Expand Up @@ -247,8 +247,9 @@ const (
InstrFlagTwoValue

instrFlagApproxType // restricts to all types whose underlying type is T
instrFlagOpFunc // from callOpFunc
instrFlagGopxFunc // call Gopx_xxx functions
instrFlagOpFunc // from callOpFunc
instrFlagBinaryOp // from cb.BinaryOp
)

type Instruction interface {
Expand Down
28 changes: 22 additions & 6 deletions gop_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,23 @@ func TestBigInt(t *testing.T) {
import "github.com/goplus/gox/internal/builtin"
var a, b builtin.Gop_bigint
var c builtin.Gop_bigint = a.Gop_Add(b)
var c builtin.Gop_bigint = (builtin.Gop_bigint).Gop_Add(a, b)
`)
}

func TestBigInt2(t *testing.T) {
pkg := newGopMainPackage()
big := pkg.Import("github.com/goplus/gox/internal/builtin")
typ := types.NewPointer(big.Ref("Gop_bigint").Type())
pkg.CB().NewVar(typ, "a", "b")
pkg.CB().NewVarStart(typ, "c").
VarVal("a").VarVal("b").BinaryOp(token.AND_NOT).EndInit(1)
domTest(t, pkg, `package main
import "github.com/goplus/gox/internal/builtin"
var a, b *builtin.Gop_bigint
var c *builtin.Gop_bigint = (*builtin.Gop_bigint).Gop_AndNot__0(a, b)
`)
}

Expand All @@ -151,7 +167,7 @@ func TestBigRat(t *testing.T) {
import "github.com/goplus/gox/internal/builtin"
var a, b builtin.Gop_bigrat
var c builtin.Gop_bigrat = a.Gop_Quo(b)
var c builtin.Gop_bigrat = (builtin.Gop_bigrat).Gop_Quo(a, b)
var d builtin.Gop_bigrat = a.Gop_Neg()
var e builtin.Gop_bigrat = builtin.Gop_bigrat_Cast__5()
var f builtin.Gop_bigrat = builtin.Gop_bigrat_Cast__3(1, 2)
Expand Down Expand Up @@ -567,7 +583,7 @@ import (
)
var a builtin.Gop_bigrat
var b = a.Gop_Add(builtin.Gop_bigrat_Init__2(big.NewRat(1, 6)))
var b = (builtin.Gop_bigrat).Gop_Add(a, builtin.Gop_bigrat_Init__2(big.NewRat(1, 6)))
`)
}

Expand All @@ -585,7 +601,7 @@ func TestUntypedBigRatAdd5(t *testing.T) {
import "github.com/goplus/gox/internal/builtin"
var a builtin.Gop_bigrat
var b = a.Gop_Add(builtin.Gop_bigrat_Init__0(100))
var b = (builtin.Gop_bigrat).Gop_Add(a, builtin.Gop_bigrat_Init__0(100))
`)
}

Expand All @@ -603,7 +619,7 @@ func TestUntypedBigRatAdd6(t *testing.T) {
import "github.com/goplus/gox/internal/builtin"
var a builtin.Gop_bigrat
var b = builtin.Gop_bigrat_Init__0(100) + a
var b = (builtin.Gop_bigrat).Gop_Add(builtin.Gop_bigrat_Init__0(100), a)
`)
}

Expand Down Expand Up @@ -642,7 +658,7 @@ import (
)
var a builtin.Gop_bigrat
var b = a.Gop_Sub__0(builtin.Gop_bigrat_Init__2(big.NewRat(1, 6)))
var b = (builtin.Gop_bigrat).Gop_Sub__0(a, builtin.Gop_bigrat_Init__2(big.NewRat(1, 6)))
`)
}

Expand Down
4 changes: 2 additions & 2 deletions internal/builtin/big.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,8 @@ func (a Gop_bigint) Gop_And(b Gop_bigint) Gop_bigint {
}

// Gop_AndNot: func (a bigint) &^ (b bigint) bigint
func (a Gop_bigint) Gop_AndNot(b Gop_bigint) Gop_bigint {
return Gop_bigint{tmpint(a, b).AndNot(a.Int, b.Int)}
func (a *Gop_bigint) Gop_AndNot__0(b *Gop_bigint) *Gop_bigint {
return a
}

// Gop_Lsh: func (a bigint) << (n untyped_uint) bigint
Expand Down
Loading

0 comments on commit 49986c4

Please sign in to comment.