feat: add default LoadJobConfig to Client #1526

Merged · 1 commit · Mar 17, 2023
121 changes: 71 additions & 50 deletions google/cloud/bigquery/client.py
@@ -210,6 +210,9 @@ class Client(ClientWithProject):
default_query_job_config (Optional[google.cloud.bigquery.job.QueryJobConfig]):
Default ``QueryJobConfig``.
Will be merged into job configs passed into the ``query`` method.
+        default_load_job_config (Optional[google.cloud.bigquery.job.LoadJobConfig]):
+            Default ``LoadJobConfig``.
+            Will be merged into job configs passed into the ``load_table_*`` methods.
client_info (Optional[google.api_core.client_info.ClientInfo]):
The client info used to send a user-agent string along with API
requests. If ``None``, then default info will be used. Generally,
@@ -235,6 +238,7 @@ def __init__(
_http=None,
location=None,
default_query_job_config=None,
+        default_load_job_config=None,
client_info=None,
client_options=None,
) -> None:
@@ -260,6 +264,7 @@ def __init__(
self._connection = Connection(self, **kw_args)
self._location = location
self._default_query_job_config = copy.deepcopy(default_query_job_config)
+        self._default_load_job_config = copy.deepcopy(default_load_job_config)

@property
def location(self):
@@ -277,6 +282,17 @@ def default_query_job_config(self):
def default_query_job_config(self, value: QueryJobConfig):
self._default_query_job_config = copy.deepcopy(value)

+    @property
+    def default_load_job_config(self):
+        """Default ``LoadJobConfig``.
+        Will be merged into job configs passed into the ``load_table_*`` methods.
+        """
+        return self._default_load_job_config
+
+    @default_load_job_config.setter
+    def default_load_job_config(self, value: LoadJobConfig):
+        self._default_load_job_config = copy.deepcopy(value)
+
def close(self):
"""Close the underlying transport objects, releasing system resources.

@@ -2330,8 +2346,8 @@ def load_table_from_uri(

Raises:
TypeError:
-                If ``job_config`` is not an instance of :class:`~google.cloud.bigquery.job.LoadJobConfig`
-                class.
+                If ``job_config`` is not an instance of
+                :class:`~google.cloud.bigquery.job.LoadJobConfig` class.
"""
job_id = _make_job_id(job_id, job_id_prefix)

@@ -2348,11 +2364,14 @@ def load_table_from_uri(

destination = _table_arg_to_table_ref(destination, default_project=self.project)

-        if job_config:
-            job_config = copy.deepcopy(job_config)
-            _verify_job_config_type(job_config, google.cloud.bigquery.job.LoadJobConfig)
+        if job_config is not None:
+            _verify_job_config_type(job_config, LoadJobConfig)
+        else:
+            job_config = job.LoadJobConfig()

-        load_job = job.LoadJob(job_ref, source_uris, destination, self, job_config)
+        new_job_config = job_config._fill_from_default(self._default_load_job_config)
+
+        load_job = job.LoadJob(job_ref, source_uris, destination, self, new_job_config)
load_job._begin(retry=retry, timeout=timeout)

return load_job
@@ -2424,8 +2443,8 @@ def load_table_from_file(
mode.

TypeError:
-                If ``job_config`` is not an instance of :class:`~google.cloud.bigquery.job.LoadJobConfig`
-                class.
+                If ``job_config`` is not an instance of
+                :class:`~google.cloud.bigquery.job.LoadJobConfig` class.
"""
job_id = _make_job_id(job_id, job_id_prefix)

@@ -2437,10 +2456,15 @@ def load_table_from_file(

destination = _table_arg_to_table_ref(destination, default_project=self.project)
job_ref = job._JobReference(job_id, project=project, location=location)
-        if job_config:
-            job_config = copy.deepcopy(job_config)
-            _verify_job_config_type(job_config, google.cloud.bigquery.job.LoadJobConfig)
-        load_job = job.LoadJob(job_ref, None, destination, self, job_config)
+
+        if job_config is not None:
+            _verify_job_config_type(job_config, LoadJobConfig)
+        else:
+            job_config = job.LoadJobConfig()
+
+        new_job_config = job_config._fill_from_default(self._default_load_job_config)
+
+        load_job = job.LoadJob(job_ref, None, destination, self, new_job_config)
job_resource = load_job.to_api_repr()

if rewind:
@@ -2564,43 +2588,40 @@ def load_table_from_dataframe(
If a usable parquet engine cannot be found. This method
requires :mod:`pyarrow` to be installed.
TypeError:
-                If ``job_config`` is not an instance of :class:`~google.cloud.bigquery.job.LoadJobConfig`
-                class.
+                If ``job_config`` is not an instance of
+                :class:`~google.cloud.bigquery.job.LoadJobConfig` class.
"""
job_id = _make_job_id(job_id, job_id_prefix)

-        if job_config:
-            _verify_job_config_type(job_config, google.cloud.bigquery.job.LoadJobConfig)
-            # Make a copy so that the job config isn't modified in-place.
-            job_config_properties = copy.deepcopy(job_config._properties)
-            job_config = job.LoadJobConfig()
-            job_config._properties = job_config_properties
+        if job_config is not None:
+            _verify_job_config_type(job_config, LoadJobConfig)
+        else:
+            job_config = job.LoadJobConfig()
+
+        new_job_config = job_config._fill_from_default(self._default_load_job_config)

supported_formats = {job.SourceFormat.CSV, job.SourceFormat.PARQUET}
-        if job_config.source_format is None:
+        if new_job_config.source_format is None:
            # default value
-            job_config.source_format = job.SourceFormat.PARQUET
+            new_job_config.source_format = job.SourceFormat.PARQUET

if (
-            job_config.source_format == job.SourceFormat.PARQUET
-            and job_config.parquet_options is None
+            new_job_config.source_format == job.SourceFormat.PARQUET
+            and new_job_config.parquet_options is None
):
parquet_options = ParquetOptions()
# default value
parquet_options.enable_list_inference = True
-            job_config.parquet_options = parquet_options
+            new_job_config.parquet_options = parquet_options

-        if job_config.source_format not in supported_formats:
+        if new_job_config.source_format not in supported_formats:
raise ValueError(
"Got unexpected source_format: '{}'. Currently, only PARQUET and CSV are supported".format(
-                    job_config.source_format
+                    new_job_config.source_format
)
)

-        if pyarrow is None and job_config.source_format == job.SourceFormat.PARQUET:
+        if pyarrow is None and new_job_config.source_format == job.SourceFormat.PARQUET:
# pyarrow is now the only supported parquet engine.
raise ValueError("This method requires pyarrow to be installed")

@@ -2611,8 +2632,8 @@ def load_table_from_dataframe(
# schema, and check if dataframe schema is compatible with it - except
# for WRITE_TRUNCATE jobs, the existing schema does not matter then.
if (
-            not job_config.schema
-            and job_config.write_disposition != job.WriteDisposition.WRITE_TRUNCATE
+            not new_job_config.schema
+            and new_job_config.write_disposition != job.WriteDisposition.WRITE_TRUNCATE
):
try:
table = self.get_table(destination)
@@ -2623,7 +2644,7 @@ def load_table_from_dataframe(
name
for name, _ in _pandas_helpers.list_columns_and_indexes(dataframe)
)
-                job_config.schema = [
+                new_job_config.schema = [
# Field description and policy tags are not needed to
# serialize a data frame.
SchemaField(
@@ -2637,11 +2658,11 @@ def load_table_from_dataframe(
if field.name in columns_and_indexes
]

-        job_config.schema = _pandas_helpers.dataframe_to_bq_schema(
-            dataframe, job_config.schema
+        new_job_config.schema = _pandas_helpers.dataframe_to_bq_schema(
+            dataframe, new_job_config.schema
)

-        if not job_config.schema:
+        if not new_job_config.schema:
# the schema could not be fully detected
warnings.warn(
"Schema could not be detected for all columns. Loading from a "
@@ -2652,13 +2673,13 @@ def load_table_from_dataframe(
)

tmpfd, tmppath = tempfile.mkstemp(
suffix="_job_{}.{}".format(job_id[:8], job_config.source_format.lower())
suffix="_job_{}.{}".format(job_id[:8], new_job_config.source_format.lower())
)
os.close(tmpfd)

try:

-            if job_config.source_format == job.SourceFormat.PARQUET:
+            if new_job_config.source_format == job.SourceFormat.PARQUET:
if _PYARROW_VERSION in _PYARROW_BAD_VERSIONS:
msg = (
"Loading dataframe data in PARQUET format with pyarrow "
@@ -2669,13 +2690,13 @@ def load_table_from_dataframe(
)
warnings.warn(msg, category=RuntimeWarning)

-                if job_config.schema:
+                if new_job_config.schema:
if parquet_compression == "snappy": # adjust the default value
parquet_compression = parquet_compression.upper()

_pandas_helpers.dataframe_to_parquet(
dataframe,
-                        job_config.schema,
+                        new_job_config.schema,
tmppath,
parquet_compression=parquet_compression,
parquet_use_compliant_nested_type=True,
@@ -2715,7 +2736,7 @@ def load_table_from_dataframe(
job_id_prefix=job_id_prefix,
location=location,
project=project,
-                job_config=job_config,
+                job_config=new_job_config,
timeout=timeout,
)

@@ -2791,22 +2812,22 @@ def load_table_from_json(

Raises:
TypeError:
-                If ``job_config`` is not an instance of :class:`~google.cloud.bigquery.job.LoadJobConfig`
-                class.
+                If ``job_config`` is not an instance of
+                :class:`~google.cloud.bigquery.job.LoadJobConfig` class.
"""
job_id = _make_job_id(job_id, job_id_prefix)

-        if job_config:
-            _verify_job_config_type(job_config, google.cloud.bigquery.job.LoadJobConfig)
-            # Make a copy so that the job config isn't modified in-place.
-            job_config = copy.deepcopy(job_config)
+        if job_config is not None:
+            _verify_job_config_type(job_config, LoadJobConfig)
+        else:
+            job_config = job.LoadJobConfig()

-        job_config.source_format = job.SourceFormat.NEWLINE_DELIMITED_JSON
+        new_job_config = job_config._fill_from_default(self._default_load_job_config)
+
+        new_job_config.source_format = job.SourceFormat.NEWLINE_DELIMITED_JSON

-        if job_config.schema is None:
-            job_config.autodetect = True
+        if new_job_config.schema is None:
+            new_job_config.autodetect = True

if project is None:
project = self.project
@@ -2828,7 +2849,7 @@ def load_table_from_json(
job_id_prefix=job_id_prefix,
location=location,
project=project,
-            job_config=job_config,
+            job_config=new_job_config,
timeout=timeout,
)

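Taken together, the `client.py` changes give every `load_table_*` method the same normalize-then-merge flow: an explicit `job_config` is type-checked, a missing one is replaced by an empty `LoadJobConfig`, and the result is merged over the client-wide default via `_fill_from_default`, with per-call keys winning. A minimal usage sketch of the new parameter (the project ID, bucket, and table ID below are placeholders):

```python
from google.cloud import bigquery

# Client-wide defaults applied to every load job (new in this PR).
default_load_config = bigquery.LoadJobConfig(
    source_format=bigquery.SourceFormat.CSV,
    skip_leading_rows=1,
)
client = bigquery.Client(
    project="my-project",  # placeholder project ID
    default_load_job_config=default_load_config,
)

# Per-call keys take precedence; unset keys fall back to the default,
# so this job truncates the table *and* still skips the header row.
load_job = client.load_table_from_uri(
    "gs://my-bucket/data.csv",         # placeholder source URI
    "my-project.my_dataset.my_table",  # placeholder destination table
    job_config=bigquery.LoadJobConfig(
        write_disposition=bigquery.WriteDisposition.WRITE_TRUNCATE,
    ),
)
load_job.result()  # block until the load job finishes
```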
6 changes: 5 additions & 1 deletion google/cloud/bigquery/job/base.py
@@ -269,7 +269,7 @@ def to_api_repr(self) -> dict:
"""
return copy.deepcopy(self._properties)

-    def _fill_from_default(self, default_job_config):
+    def _fill_from_default(self, default_job_config=None):
"""Merge this job config with a default job config.

The keys in this object take precedence over the keys in the default
@@ -283,6 +283,10 @@ def _fill_from_default(self, default_job_config):
Returns:
google.cloud.bigquery.job._JobConfig: A new (merged) job config.
"""
+        if not default_job_config:
+            new_job_config = copy.deepcopy(self)
+            return new_job_config
+
if self._job_type != default_job_config._job_type:
raise TypeError(
"attempted to merge two incompatible job types: "
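The `base.py` change makes the default argument of `_fill_from_default` optional: with no (or a falsy) default, merging degenerates to a deep copy. A short sketch of the merge semantics, using arbitrary example properties:

```python
from google.cloud.bigquery import LoadJobConfig

per_call = LoadJobConfig(encoding="UTF-8")
default = LoadJobConfig(encoding="ISO-8859-1", ignore_unknown_values=True)

merged = per_call._fill_from_default(default)
assert merged.encoding == "UTF-8"            # key on self takes precedence
assert merged.ignore_unknown_values is True  # filled in from the default

# New in this PR: a missing default is valid and yields a plain deep copy.
copied = per_call._fill_from_default(None)
assert copied.encoding == "UTF-8"
assert copied is not per_call
```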
8 changes: 4 additions & 4 deletions tests/system/test_client.py
@@ -2319,7 +2319,7 @@ def _table_exists(t):
return False


-def test_dbapi_create_view(dataset_id):
+def test_dbapi_create_view(dataset_id: str):

query = f"""
CREATE VIEW {dataset_id}.dbapi_create_view
Expand All @@ -2332,7 +2332,7 @@ def test_dbapi_create_view(dataset_id):
assert Config.CURSOR.rowcount == 0, "expected 0 rows"


-def test_parameterized_types_round_trip(dataset_id):
+def test_parameterized_types_round_trip(dataset_id: str):
client = Config.CLIENT
table_id = f"{dataset_id}.test_parameterized_types_round_trip"
fields = (
@@ -2358,7 +2358,7 @@ def test_parameterized_types_round_trip(dataset_id):
assert tuple(s._key()[:2] for s in table2.schema) == fields


-def test_table_snapshots(dataset_id):
+def test_table_snapshots(dataset_id: str):
from google.cloud.bigquery import CopyJobConfig
from google.cloud.bigquery import OperationType

@@ -2429,7 +2429,7 @@ def test_table_snapshots(dataset_id):
assert rows == [(1, "one"), (2, "two")]


-def test_table_clones(dataset_id):
+def test_table_clones(dataset_id: str):
from google.cloud.bigquery import CopyJobConfig
from google.cloud.bigquery import OperationType

29 changes: 28 additions & 1 deletion tests/unit/job/test_base.py
@@ -1104,7 +1104,7 @@ def test_ctor_with_unknown_property_raises_error(self):
config = self._make_one()
config.wrong_name = None

-    def test_fill_from_default(self):
+    def test_fill_query_job_config_from_default(self):
from google.cloud.bigquery import QueryJobConfig

job_config = QueryJobConfig()
@@ -1120,6 +1120,22 @@ def test_fill_from_default(self):
self.assertTrue(final_job_config.use_query_cache)
self.assertEqual(final_job_config.maximum_bytes_billed, 1000)

+    def test_fill_load_job_from_default(self):
+        from google.cloud.bigquery import LoadJobConfig
+
+        job_config = LoadJobConfig()
+        job_config.create_session = True
+        job_config.encoding = "UTF-8"
+
+        default_job_config = LoadJobConfig()
+        default_job_config.ignore_unknown_values = True
+        default_job_config.encoding = "ISO-8859-1"
+
+        final_job_config = job_config._fill_from_default(default_job_config)
+        self.assertTrue(final_job_config.create_session)
+        self.assertTrue(final_job_config.ignore_unknown_values)
+        self.assertEqual(final_job_config.encoding, "UTF-8")
+
def test_fill_from_default_conflict(self):
from google.cloud.bigquery import QueryJobConfig

@@ -1132,6 +1148,17 @@ def test_fill_from_default_conflict(self):
with self.assertRaises(TypeError):
basic_job_config._fill_from_default(conflicting_job_config)

+    def test_fill_from_empty_default_conflict(self):
+        from google.cloud.bigquery import QueryJobConfig
+
+        job_config = QueryJobConfig()
+        job_config.dry_run = True
+        job_config.maximum_bytes_billed = 1000
+
+        final_job_config = job_config._fill_from_default(default_job_config=None)
+        self.assertTrue(final_job_config.dry_run)
+        self.assertEqual(final_job_config.maximum_bytes_billed, 1000)
+
@mock.patch("google.cloud.bigquery._helpers._get_sub_prop")
def test__get_sub_prop_wo_default(self, _get_sub_prop):
job_config = self._make_one()
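One behavioral detail worth noting: like the existing `default_query_job_config` setter, the new `default_load_job_config` setter stores a deep copy, so mutating the original config object afterwards does not affect the client. A small illustration (assumes application default credentials are available):

```python
from google.cloud import bigquery

client = bigquery.Client()  # assumes application default credentials

config = bigquery.LoadJobConfig(ignore_unknown_values=True)
client.default_load_job_config = config  # stored as a deep copy

config.ignore_unknown_values = False  # later mutation of the original
assert client.default_load_job_config.ignore_unknown_values is True
```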