Skip to content

Commit

Permalink
support generic type oveload method
Browse files Browse the repository at this point in the history
  • Loading branch information
visualfc committed Jan 28, 2024
1 parent 49986c4 commit 1dae57f
Show file tree
Hide file tree
Showing 5 changed files with 177 additions and 11 deletions.
16 changes: 12 additions & 4 deletions codebuild.go
Original file line number Diff line number Diff line change
Expand Up @@ -1598,7 +1598,7 @@ retry:
return kind
}
}
if kind := p.method(t, name, aliasName, flag, arg, srcExpr); kind != MemberInvalid {
if kind := p.method(t, name, aliasName, flag, arg, srcExpr, t.TypeArgs() != nil); kind != MemberInvalid {
return kind
}
if fstruc {
Expand All @@ -1614,7 +1614,7 @@ retry:
}
case *types.Named:
named, typ = o, p.getUnderlying(o) // may cause to loadNamed (delay-loaded)
if kind := p.method(o, name, aliasName, flag, arg, srcExpr); kind != MemberInvalid {
if kind := p.method(o, name, aliasName, flag, arg, srcExpr, o.TypeArgs() != nil); kind != MemberInvalid {
return kind
}
if _, ok := typ.(*types.Struct); ok {
Expand All @@ -1630,7 +1630,7 @@ retry:
}
case *types.Interface:
o.Complete()
if kind := p.method(o, name, aliasName, flag, arg, srcExpr); kind != MemberInvalid {
if kind := p.method(o, name, aliasName, flag, arg, srcExpr, false); kind != MemberInvalid {
return kind
}
case *types.Basic, *types.Slice, *types.Map, *types.Chan:
Expand All @@ -1640,6 +1640,7 @@ retry:
}

type methodList interface {
types.Type
NumMethods() int
Method(i int) *types.Func
}
Expand All @@ -1666,7 +1667,7 @@ func (p *CodeBuilder) allowAccess(pkg *types.Package, name string) bool {
}

func (p *CodeBuilder) method(
o methodList, name, aliasName string, flag MemberFlag, arg *Element, src ast.Node) (kind MemberKind) {
o methodList, name, aliasName string, flag MemberFlag, arg *Element, src ast.Node, namedHasTypeArgs bool) (kind MemberKind) {
var found *types.Func
var exact bool
for i, n := 0, o.NumMethods(); i < n; i++ {
Expand All @@ -1691,6 +1692,13 @@ func (p *CodeBuilder) method(
if autoprop && !methodHasAutoProperty(typ, 0) {
return memberBad
}
if namedHasTypeArgs {
if t, ok := CheckFuncEx(typ.(*types.Signature)); ok {
if m, ok := t.(*TyOverloadMethod); ok && m.IsGeneric() {
typ = m.Instantiate(o.(*types.Named))
}
}
}
sel := selector(arg, found.Name())
p.stk.Ret(1, &internal.Elem{
Val: sel,
Expand Down
46 changes: 43 additions & 3 deletions func_ext.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,9 @@ func CheckOverloadFunc(sig *types.Signature) (funcs []types.Object, ok bool) {

// TyOverloadMethod: overload function type
type TyOverloadMethod struct {
Methods []types.Object
Methods []types.Object
indexs []int // func object indexs
instance map[*types.Named]*types.Signature // cache type signature for named
}

func (p *TyOverloadMethod) At(i int) types.Object { return p.Methods[i] }
Expand All @@ -127,8 +129,46 @@ func (p *TyOverloadMethod) Underlying() types.Type { return p }
func (p *TyOverloadMethod) String() string { return "TyOverloadMethod" }
func (p *TyOverloadMethod) funcEx() {}

func NewOverloadMethod(typ *types.Named, pos token.Pos, pkg *types.Package, name string, methods ...types.Object) *types.Func {
return newMethodEx(typ, pos, pkg, name, &TyOverloadMethod{methods})
func NewOverloadMethod(typ *types.Named, pos token.Pos, pkg *types.Package, name string, objectIndex map[types.Object]int, methods ...types.Object) *types.Func {
t := &TyOverloadMethod{Methods: methods}
if typ.TypeParams() != nil {
t.indexs = make([]int, len(methods))
for i, obj := range methods {
t.indexs[i] = objectIndex[obj]
}
t.instance = make(map[*types.Named]*types.Signature)
}
return newMethodEx(typ, pos, pkg, name, t)
}

func (m *TyOverloadMethod) IsGeneric() bool {
return len(m.indexs) != 0
}

func (m *TyOverloadMethod) Instantiate(named *types.Named) *types.Signature {
sig, ok := m.instance[named]
if !ok {
sig = newOverloadMethodType(named, m)
m.instance[named] = sig
}
return sig
}

func newOverloadMethodType(named *types.Named, m *TyOverloadMethod) *types.Signature {
var list methodList
switch t := named.Underlying().(type) {
case *types.Interface:
list = t
default:
list = named
}
pkg := named.Obj().Pkg()
recv := types.NewVar(token.NoPos, pkg, "", named)
methods := make([]types.Object, len(m.indexs))
for i, index := range m.indexs {
methods[i] = list.Method(index)
}
return sigFuncEx(pkg, recv, &TyOverloadMethod{Methods: methods})
}

func CheckOverloadMethod(sig *types.Signature) (methods []types.Object, ok bool) {
Expand Down
10 changes: 6 additions & 4 deletions import.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ func initThisGopPkg(pkg *types.Package) {
}
gopos := make([]string, 0, 4)
overloads := make(map[omthd][]types.Object)
mobjectIndexs := make(map[types.Object]int)
onameds := make(map[string][]*types.Named)
names := scope.Names()
for _, name := range names {
Expand All @@ -133,6 +134,7 @@ func initThisGopPkg(pkg *types.Package) {
mthd := mName[:len(mName)-3]
key := omthd{named, mthd}
overloads[key] = append(overloads[key], m)
mobjectIndexs[m] = i
}
}
if isOverload(name) { // overload named
Expand Down Expand Up @@ -160,14 +162,14 @@ func initThisGopPkg(pkg *types.Package) {
}
fns[i] = lookupFunc(scope, name, tname)
}
newOverload(pkg, scope, m, fns)
newOverload(pkg, scope, m, fns, nil)
delete(overloads, m)
}
}
for key, items := range overloads {
off := len(key.name) + 2
fns := overloadFuncs(off, items)
newOverload(pkg, scope, key, fns)
newOverload(pkg, scope, key, fns, mobjectIndexs)
}
for name, items := range onameds {
off := len(name) + 2
Expand Down Expand Up @@ -290,7 +292,7 @@ func checkOverloads(scope *types.Scope, gopoName string) (ret []string, exists b
return
}

func newOverload(pkg *types.Package, scope *types.Scope, m omthd, fns []types.Object) {
func newOverload(pkg *types.Package, scope *types.Scope, m omthd, fns []types.Object, mobjectIndexs map[types.Object]int) {
if m.typ == nil {
if debugImport {
log.Println("==> NewOverloadFunc", m.name)
Expand All @@ -302,7 +304,7 @@ func newOverload(pkg *types.Package, scope *types.Scope, m omthd, fns []types.Ob
if debugImport {
log.Println("==> NewOverloadMethod", m.typ.Obj().Name(), m.name)
}
NewOverloadMethod(m.typ, token.NoPos, pkg, m.name, fns...)
NewOverloadMethod(m.typ, token.NoPos, pkg, m.name, mobjectIndexs, fns...)
}
}

Expand Down
32 changes: 32 additions & 0 deletions internal/foo/foo.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,4 +116,36 @@ type NodeSeter interface {
Attr__1(k, v string) (ret NodeSeter)
}

type Data[T any] struct {
data []T
}

func (p *Data[T]) Size() int {
return len(p.data)
}

func (p *Data[T]) Add__0(v ...T) {
p.data = append(p.data, v...)
}

func (p *Data[T]) Add__1(v Data[T]) {
p.data = append(p.data, v.data...)
}

func (p *Data[T]) IndexOf__0(v T) int {
return -1
}

func (p *Data[T]) IndexOf__1(pos int, v T) int {
return -1
}

type DataInterface[T any] interface {
Size() int
Add__0(v ...T)
Add__1(v DataInterface[T])
IndexOf__0(v T) int
IndexOf__1(pos int, v T) int
}

// -----------------------------------------------------------------------------
84 changes: 84 additions & 0 deletions typeparams_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -962,3 +962,87 @@ func main() {
}
`)
}

func TestGenericTypeOverloadMethod(t *testing.T) {
pkg := newMainPackage()
foo := pkg.Import("github.com/goplus/gox/internal/foo")
tyDataT := foo.Ref("Data").Type()
tyInt := types.Typ[types.Int]
tyData, _ := types.Instantiate(nil, tyDataT, []types.Type{tyInt}, true)
v := pkg.NewParam(token.NoPos, "v", tyData)
pkg.NewFunc(nil, "bar", types.NewTuple(v), nil, false).BodyStart(pkg).
DefineVarStart(token.NoPos, "n").Val(v).
Debug(func(cb *gox.CodeBuilder) {
cb.Member("size", gox.MemberFlagMethodAlias)
}).
Call(0).EndInit(1).EndStmt().
Val(v).
Debug(func(cb *gox.CodeBuilder) {
cb.Member("add", gox.MemberFlagMethodAlias)
}).
Val(0).Val(1).Call(2).EndStmt().
Val(v).
Debug(func(cb *gox.CodeBuilder) {
cb.Member("add", gox.MemberFlagMethodAlias)
}).
Val(v).Call(1).EndStmt().
DefineVarStart(token.NoPos, "i").Val(v).
Debug(func(cb *gox.CodeBuilder) {
cb.Member("indexOf", gox.MemberFlagMethodAlias)
}).
Val(0).Val(1).Call(2).EndInit(1).EndStmt().
End()
domTest(t, pkg, `package main
import "github.com/goplus/gox/internal/foo"
func bar(v foo.Data[int]) {
n := v.Size()
v.Add__0(0, 1)
v.Add__1(v)
i := v.IndexOf__1(0, 1)
}
`)
}

func TestGenericInterfaceOverloadMethod(t *testing.T) {
pkg := newMainPackage()
foo := pkg.Import("github.com/goplus/gox/internal/foo")
tyDataT := foo.Ref("DataInterface").Type()
tyInt := types.Typ[types.Int]
tyData, _ := types.Instantiate(nil, tyDataT, []types.Type{tyInt}, true)
v := pkg.NewParam(token.NoPos, "v", tyData)
pkg.NewFunc(nil, "bar", types.NewTuple(v), nil, false).BodyStart(pkg).
DefineVarStart(token.NoPos, "n").Val(v).
Debug(func(cb *gox.CodeBuilder) {
cb.Member("size", gox.MemberFlagMethodAlias)
}).
Call(0).EndInit(1).EndStmt().
Val(v).
Debug(func(cb *gox.CodeBuilder) {
cb.Member("add", gox.MemberFlagMethodAlias)
}).
Val(0).Val(1).Call(2).EndStmt().
Val(v).
Debug(func(cb *gox.CodeBuilder) {
cb.Member("add", gox.MemberFlagMethodAlias)
}).
Val(v).Call(1).EndStmt().
DefineVarStart(token.NoPos, "i").Val(v).
Debug(func(cb *gox.CodeBuilder) {
cb.Member("indexOf", gox.MemberFlagMethodAlias)
}).
Val(0).Val(1).Call(2).EndInit(1).EndStmt().
End()
domTest(t, pkg, `package main
import "github.com/goplus/gox/internal/foo"
func bar(v foo.DataInterface[int]) {
n := v.Size()
v.Add__0(0, 1)
v.Add__1(v)
i := v.IndexOf__1(0, 1)
}
`)
}

0 comments on commit 1dae57f

Please sign in to comment.