Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

compare and convert system types properly #2700

Merged
merged 2 commits into from
Oct 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions enginetest/queries/script_queries.go
Original file line number Diff line number Diff line change
Expand Up @@ -7499,6 +7499,22 @@ where
},
},
},
{
Name: "coalesce with system types",
SetUpScript: []string{
"create table t as select @@admin_port as port1, @@port as port2, COALESCE(@@admin_port, @@port) as\n port3;",
},
Assertions: []ScriptTestAssertion{
{
Query: "describe t;",
Expected: []sql.Row{
{"port1", "bigint", "NO", "", nil, ""},
{"port2", "bigint", "NO", "", nil, ""},
{"port3", "bigint", "NO", "", nil, ""},
},
},
},
},
}

var SpatialScriptTests = []ScriptTest{
Expand Down
4 changes: 4 additions & 0 deletions sql/analyzer/resolve_create_select.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ func resolveCreateSelect(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.
for i, col := range mergedSchema {
tempCol := *col
tempCol.Source = ct.Name()
// replace system variable types with their underlying types
if sysType, isSysTyp := tempCol.Type.(sql.SystemVariableType); isSysTyp {
tempCol.Type = sysType.UnderlyingType()
}
newSch[i] = &tempCol
}

Expand Down
103 changes: 58 additions & 45 deletions sql/expression/function/coalesce.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@ import (
"fmt"
"strings"

"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/expression"

"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/types"
)

Expand Down Expand Up @@ -53,61 +54,73 @@ func (c *Coalesce) Description() string {
// Type implements the sql.Expression interface.
// The return type of Type() is the aggregated type of the argument types.
func (c *Coalesce) Type() sql.Type {
typ := types.Null
for _, arg := range c.args {
retType := types.Null
for i, arg := range c.args {
if arg == nil {
continue
}
t := arg.Type()
argType := arg.Type()
if sysVarType, ok := argType.(sql.SystemVariableType); ok {
argType = sysVarType.UnderlyingType()
}
if i == 0 {
retType = argType
continue
}
if argType == nil || argType == types.Null {
continue
}
if retType.Equals(argType) {
continue
}

// special case for signed and unsigned integers
if (types.IsSigned(typ) && types.IsUnsigned(t)) || (types.IsUnsigned(typ) && types.IsSigned(t)) {
typ = types.MustCreateDecimalType(20, 0)
if (types.IsSigned(retType) && types.IsUnsigned(argType)) || (types.IsUnsigned(retType) && types.IsSigned(argType)) {
retType = types.MustCreateDecimalType(20, 0)
continue
}

if t != nil && t != types.Null {
convType := expression.GetConvertToType(typ, t)
switch convType {
case expression.ConvertToChar:
// special case for float64s
if (t == types.Float64 || typ == types.Float64) && !types.IsText(t) && !types.IsText(typ) {
typ = types.Float64
continue
}
// Can't get any larger than this
return types.LongText
case expression.ConvertToDecimal:
if typ == types.Float64 || t == types.Float64 {
typ = types.Float64
} else if types.IsDecimal(t) {
typ = t
} else if !types.IsDecimal(typ) {
typ = types.MustCreateDecimalType(10, 0)
}
case expression.ConvertToUnsigned:
if typ == types.Uint64 || t == types.Uint64 {
typ = types.Uint64
} else {
typ = types.Uint32
}
case expression.ConvertToSigned:
if typ == types.Int64 || t == types.Int64 {
typ = types.Int64
} else {
typ = types.Int32
}
case expression.ConvertToFloat:
if typ == types.Float64 || t == types.Float64 {
typ = types.Float64
} else {
typ = types.Float32
}
default:
convType := expression.GetConvertToType(retType, argType)
switch convType {
case expression.ConvertToChar:
// special case for float64s
if (argType == types.Float64 || retType == types.Float64) && !types.IsText(argType) && !types.IsText(retType) {
retType = types.Float64
continue
}
// Can't get any larger than this
return types.LongText
case expression.ConvertToDecimal:
if retType == types.Float64 || argType == types.Float64 {
retType = types.Float64
} else if types.IsDecimal(argType) {
retType = argType
} else if !types.IsDecimal(retType) {
retType = types.MustCreateDecimalType(10, 0)
}
case expression.ConvertToUnsigned:
if retType == types.Uint64 || argType == types.Uint64 {
retType = types.Uint64
} else {
retType = types.Uint32
}
case expression.ConvertToSigned:
if retType == types.Int64 || argType == types.Int64 {
retType = types.Int64
} else {
retType = types.Int32
}
case expression.ConvertToFloat:
if retType == types.Float64 || argType == types.Float64 {
retType = types.Float64
} else {
retType = types.Float32
}
default:
}
}

return typ
return retType
}

// CollationCoercibility implements the interface sql.CollationCoercible.
Expand Down
83 changes: 76 additions & 7 deletions sql/expression/function/coalesce_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,17 +152,86 @@ func TestCoalesce(t *testing.T) {
typ: types.Float64,
nullable: false,
},
{
name: "coalesce(sysInt, sysInt)",
input: []sql.Expression{
expression.NewLiteral(1, types.NewSystemIntType("int1", 0, 10, false)),
expression.NewLiteral(2, types.NewSystemIntType("int2", 0, 10, false)),
},
expected: 1,
typ: types.Int64,
nullable: false,
},
{
name: "coalesce(sysInt, sysUint)",
input: []sql.Expression{
expression.NewLiteral(1, types.NewSystemIntType("int1", 0, 10, false)),
expression.NewLiteral(2, types.NewSystemUintType("int2", 0, 10)),
},
expected: 1,
typ: types.MustCreateDecimalType(20, 0),
nullable: false,
},
{
name: "coalesce(sysUint, sysUint)",
input: []sql.Expression{
expression.NewLiteral(1, types.NewSystemUintType("int1", 0, 10)),
expression.NewLiteral(2, types.NewSystemUintType("int2", 0, 10)),
},
expected: 1,
typ: types.Uint64,
nullable: false,
},
{
name: "coalesce(sysDouble, sysDouble)",
input: []sql.Expression{
expression.NewLiteral(1.0, types.NewSystemDoubleType("dbl1", 0.0, 10.0)),
expression.NewLiteral(2.0, types.NewSystemDoubleType("dbl2", 0.0, 10.0)),
},
expected: 1.0,
typ: types.Float64,
nullable: false,
},
{
name: "coalesce(sysText)",
input: []sql.Expression{
expression.NewLiteral("abc", types.NewSystemStringType("str1")),
},
expected: "abc",
typ: types.LongText,
nullable: false,
},
{
name: "coalesce(sysEnum)",
input: []sql.Expression{
expression.NewLiteral("abc", types.NewSystemEnumType("str1")),
},
expected: "abc",
typ: types.EnumType{},
nullable: false,
},
{
name: "coalesce(sysSet)",
input: []sql.Expression{
expression.NewLiteral("abc", types.NewSystemSetType("str1", "abc")),
},
expected: "abc",
typ: types.MustCreateSetType([]string{"abc"}, sql.Collation_Default),
nullable: false,
},
}

for _, tt := range testCases {
c, err := NewCoalesce(tt.input...)
require.NoError(t, err)
t.Run(tt.name, func(t *testing.T) {
c, err := NewCoalesce(tt.input...)
require.NoError(t, err)

require.Equal(t, tt.typ, c.Type())
require.Equal(t, tt.nullable, c.IsNullable())
v, err := c.Eval(sql.NewEmptyContext(), nil)
require.NoError(t, err)
require.Equal(t, tt.expected, v)
require.Equal(t, tt.typ, c.Type())
require.Equal(t, tt.nullable, c.IsNullable())
v, err := c.Eval(sql.NewEmptyContext(), nil)
require.NoError(t, err)
require.Equal(t, tt.expected, v)
})
}
}

Expand Down
47 changes: 31 additions & 16 deletions sql/types/typecheck.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,10 @@ func IsNull(ex sql.Expression) bool {

// IsNumber checks if t is a number type
func IsNumber(t sql.Type) bool {
switch t.(type) {
case NumberTypeImpl_, DecimalType_, BitType_, YearType_, SystemBoolType:
switch typ := t.(type) {
case sql.SystemVariableType:
return IsNumber(typ.UnderlyingType())
case NumberTypeImpl_, DecimalType_, BitType_, YearType_:
return true
default:
return false
Expand All @@ -101,23 +103,25 @@ func IsNumber(t sql.Type) bool {

// IsSigned checks if t is a signed type.
func IsSigned(t sql.Type) bool {
// systemBoolType is Int8
if _, ok := t.(SystemBoolType); ok {
return true
if svt, ok := t.(sql.SystemVariableType); ok {
t = svt.UnderlyingType()
}
return t == Int8 || t == Int16 || t == Int24 || t == Int32 || t == Int64 || t == Boolean
}

// IsText checks if t is a CHAR, VARCHAR, TEXT, BINARY, VARBINARY, or BLOB (including TEXT and BLOB variants).
func IsText(t sql.Type) bool {
if _, ok := t.(StringType); ok {
return ok
}
if extendedType, ok := t.(ExtendedType); ok {
_, isString := extendedType.Zero().(string)
switch typ := t.(type) {
case sql.SystemVariableType:
return IsText(typ.UnderlyingType())
case StringType:
return true
case ExtendedType:
_, isString := typ.Zero().(string)
return isString
default:
return false
}
return false
}

// IsTextBlob checks if t is one of the TEXTs or BLOBs.
Expand Down Expand Up @@ -178,14 +182,26 @@ func IsTimestampType(t sql.Type) bool {

// IsEnum checks if t is a enum
func IsEnum(t sql.Type) bool {
_, ok := t.(EnumType)
return ok
switch typ := t.(type) {
case sql.SystemVariableType:
return IsEnum(typ.UnderlyingType())
case EnumType:
return true
default:
return false
}
}

// IsSet checks if t is a set
func IsSet(t sql.Type) bool {
_, ok := t.(SetType)
return ok
switch typ := t.(type) {
case sql.SystemVariableType:
return IsSet(typ.UnderlyingType())
case SetType:
return true
default:
return false
}
}

// IsTuple checks if t is a tuple type.
Expand All @@ -201,7 +217,6 @@ func IsUnsigned(t sql.Type) bool {
if svt, ok := t.(sql.SystemVariableType); ok {
t = svt.UnderlyingType()
}

return t == Uint8 || t == Uint16 || t == Uint24 || t == Uint32 || t == Uint64
}

Expand Down
Loading
Loading