diff --git a/bigquery/integration_test.go b/bigquery/integration_test.go index 81c1b123444c..9ad8b9584b07 100644 --- a/bigquery/integration_test.go +++ b/bigquery/integration_test.go @@ -1421,7 +1421,88 @@ func TestIntegration_Load(t *testing.T) { t.Fatal(err) } checkReadAndTotalRows(t, "reader load", table.Read(ctx), wantRows) +} + +func TestIntegration_LoadWithSessionSupport(t *testing.T) { + if client == nil { + t.Skip("Integration tests skipped") + } + + ctx := context.Background() + sessionDataset := client.Dataset("_SESSION") + sessionTable := sessionDataset.Table("test_temp_destination_table") + + schema := Schema{ + {Name: "username", Type: StringFieldType, Required: false}, + {Name: "tweet", Type: StringFieldType, Required: false}, + {Name: "timestamp", Type: StringFieldType, Required: false}, + {Name: "likes", Type: IntegerFieldType, Required: false}, + } + sourceURIs := []string{ + "gs://cloud-samples-data/bigquery/federated-formats-reference-file-schema/a-twitter.parquet", + } + + source := NewGCSReference(sourceURIs...) + source.SourceFormat = Parquet + source.Schema = schema + loader := sessionTable.LoaderFrom(source) + loader.CreateSession = true + loader.CreateDisposition = CreateIfNeeded + + job, err := loader.Run(ctx) + if err != nil { + t.Fatalf("loader.Run: %v", err) + } + err = wait(ctx, job) + if err != nil { + t.Fatalf("wait: %v", err) + } + sessionInfo := job.lastStatus.Statistics.SessionInfo + if sessionInfo == nil { + t.Fatalf("empty job.lastStatus.Statistics.SessionInfo: %v", sessionInfo) + } + + sessionID := sessionInfo.SessionID + loaderWithSession := sessionTable.LoaderFrom(source) + loaderWithSession.CreateDisposition = CreateIfNeeded + loaderWithSession.ConnectionProperties = []*ConnectionProperty{ + { + Key: "session_id", + Value: sessionID, + }, + } + jobWithSession, err := loaderWithSession.Run(ctx) + if err != nil { + t.Fatalf("loaderWithSession.Run: %v", err) + } + err = wait(ctx, jobWithSession) + if err != nil { + t.Fatalf("wait: %v", err) + } + + sessionJobInfo := jobWithSession.lastStatus.Statistics.SessionInfo + if sessionJobInfo == nil { + t.Fatalf("empty jobWithSession.lastStatus.Statistics.SessionInfo: %v", sessionJobInfo) + } + + if sessionID != sessionJobInfo.SessionID { + t.Fatalf("expected session ID %q, but found %q", sessionID, sessionJobInfo.SessionID) + } + + sql := "SELECT * FROM _SESSION.test_temp_destination_table;" + q := client.Query(sql) + q.ConnectionProperties = []*ConnectionProperty{ + { + Key: "session_id", + Value: sessionID, + }, + } + sessionQueryJob, err := q.Run(ctx) + err = wait(ctx, sessionQueryJob) + if err != nil { + t.Fatalf("wait: %v", err) + } } func TestIntegration_LoadWithReferenceSchemaFile(t *testing.T) { diff --git a/bigquery/load.go b/bigquery/load.go index cf515beefe1e..8af650d7a869 100644 --- a/bigquery/load.go +++ b/bigquery/load.go @@ -92,6 +92,15 @@ type LoadConfig struct { // When loading a table with external data, the user can provide a reference file with the table schema. // This is enabled for the following formats: AVRO, PARQUET, ORC. ReferenceFileSchemaURI string + + // If true, creates a new session, where session id will + // be a server generated random id. If false, runs query with an + // existing session_id passed in ConnectionProperty, otherwise runs the + // load job in non-session mode. + CreateSession bool + + // ConnectionProperties are optional key-values settings. + ConnectionProperties []*ConnectionProperty } func (l *LoadConfig) toBQ() (*bq.JobConfiguration, io.Reader) { @@ -110,12 +119,16 @@ func (l *LoadConfig) toBQ() (*bq.JobConfiguration, io.Reader) { ProjectionFields: l.ProjectionFields, HivePartitioningOptions: l.HivePartitioningOptions.toBQ(), ReferenceFileSchemaUri: l.ReferenceFileSchemaURI, + CreateSession: l.CreateSession, }, JobTimeoutMs: l.JobTimeout.Milliseconds(), } for _, v := range l.DecimalTargetTypes { config.Load.DecimalTargetTypes = append(config.Load.DecimalTargetTypes, string(v)) } + for _, v := range l.ConnectionProperties { + config.Load.ConnectionProperties = append(config.Load.ConnectionProperties, v.toBQ()) + } media := l.Src.populateLoadConfig(config.Load) return config, media } @@ -135,6 +148,7 @@ func bqToLoadConfig(q *bq.JobConfiguration, c *Client) *LoadConfig { ProjectionFields: q.Load.ProjectionFields, HivePartitioningOptions: bqToHivePartitioningOptions(q.Load.HivePartitioningOptions), ReferenceFileSchemaURI: q.Load.ReferenceFileSchemaUri, + CreateSession: q.Load.CreateSession, } if q.JobTimeoutMs > 0 { lc.JobTimeout = time.Duration(q.JobTimeoutMs) * time.Millisecond @@ -142,6 +156,9 @@ func bqToLoadConfig(q *bq.JobConfiguration, c *Client) *LoadConfig { for _, v := range q.Load.DecimalTargetTypes { lc.DecimalTargetTypes = append(lc.DecimalTargetTypes, DecimalTargetType(v)) } + for _, v := range q.Load.ConnectionProperties { + lc.ConnectionProperties = append(lc.ConnectionProperties, bqToConnectionProperty(v)) + } var fc *FileConfig if len(q.Load.SourceUris) == 0 { s := NewReaderSource(nil) diff --git a/bigquery/load_test.go b/bigquery/load_test.go index 7ef95310fadf..650ed066886f 100644 --- a/bigquery/load_test.go +++ b/bigquery/load_test.go @@ -405,6 +405,33 @@ func TestLoad(t *testing.T) { return j }(), }, + { + dst: c.Dataset("dataset-id").Table("table-id"), + src: func() *GCSReference { + g := NewGCSReference("uri") + return g + }(), + config: LoadConfig{ + CreateSession: true, + ConnectionProperties: []*ConnectionProperty{ + { + Key: "session_id", + Value: "session_id_1234567890", + }, + }, + }, + want: func() *bq.Job { + j := defaultLoadJob() + j.Configuration.Load.CreateSession = true + j.Configuration.Load.ConnectionProperties = []*bq.ConnectionProperty{ + { + Key: "session_id", + Value: "session_id_1234567890", + }, + } + return j + }(), + }, } for i, tc := range testCases { diff --git a/bigquery/query.go b/bigquery/query.go index 7ac35f88af3d..7030cfd50c4f 100644 --- a/bigquery/query.go +++ b/bigquery/query.go @@ -487,7 +487,7 @@ func (q *Query) probeFastPath() (*bq.QueryRequest, error) { return qRequest, nil } -// ConnectionProperty represents a single key and value pair that can be sent alongside a query request. +// ConnectionProperty represents a single key and value pair that can be sent alongside a query request or load job. type ConnectionProperty struct { // Name of the connection property to set. Key string