Skip to content

Commit

Permalink
fix: to_gbq allows strings for DATE and floats for NUMERIC with `ap…
Browse files Browse the repository at this point in the history
…i_method="load_parquet"` (#423)

deps: require pandas 0.24+ and db-dtypes for TIME/DATE extension dtypes (#423)
  • Loading branch information
tswast authored Nov 22, 2021
1 parent 3e70975 commit 2180836
Show file tree
Hide file tree
Showing 12 changed files with 279 additions and 80 deletions.
2 changes: 1 addition & 1 deletion .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ jobs:
- image: continuumio/miniconda3
environment:
PYTHON: "3.7"
PANDAS: "0.23.2"
PANDAS: "0.24.2"
steps:
- checkout
- run: ci/config_auth.sh
Expand Down
2 changes: 1 addition & 1 deletion .coveragerc
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ omit =
google/cloud/__init__.py

[report]
fail_under = 86
fail_under = 88
show_missing = True
exclude_lines =
# Re-enable the standard pragma
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
codecov
coverage
db-dtypes==0.3.0
fastavro
flake8
numpy==1.16.6
Expand Down
1 change: 1 addition & 0 deletions ci/requirements-3.9-NIGHTLY.conda
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
db-dtypes
pydata-google-auth
google-cloud-bigquery
google-cloud-bigquery-storage
Expand Down
8 changes: 2 additions & 6 deletions noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,11 +146,7 @@ def system(session):
# Install all test dependencies, then install this package into the
# virtualenv's dist-packages.
session.install("mock", "pytest", "google-cloud-testutils", "-c", constraints_path)
if session.python == "3.9":
extras = "[tqdm,db-dtypes]"
else:
extras = "[tqdm]"
session.install("-e", f".{extras}", "-c", constraints_path)
session.install("-e", ".[tqdm]", "-c", constraints_path)

# Run py.test against the system tests.
if system_test_exists:
Expand Down Expand Up @@ -179,7 +175,7 @@ def cover(session):
test runs (not system test runs), and then erases coverage data.
"""
session.install("coverage", "pytest-cov")
session.run("coverage", "report", "--show-missing", "--fail-under=86")
session.run("coverage", "report", "--show-missing", "--fail-under=88")

session.run("coverage", "erase")

Expand Down
6 changes: 1 addition & 5 deletions owlbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,12 @@
# ----------------------------------------------------------------------------

extras = ["tqdm"]
extras_by_python = {
"3.9": ["tqdm", "db-dtypes"],
}
templated_files = common.py_library(
unit_test_python_versions=["3.7", "3.8", "3.9", "3.10"],
system_test_python_versions=["3.7", "3.8", "3.9", "3.10"],
cov_level=86,
cov_level=88,
unit_test_extras=extras,
system_test_extras=extras,
system_test_extras_by_python=extras_by_python,
intersphinx_dependencies={
"pandas": "https://pandas.pydata.org/pandas-docs/stable/",
"pydata-google-auth": "https://pydata-google-auth.readthedocs.io/en/latest/",
Expand Down
52 changes: 52 additions & 0 deletions pandas_gbq/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@

"""Helper methods for loading data into BigQuery"""

import decimal
import io
from typing import Any, Callable, Dict, List, Optional

import db_dtypes
import pandas
import pyarrow.lib
from google.cloud import bigquery
Expand Down Expand Up @@ -56,6 +58,55 @@ def split_dataframe(dataframe, chunksize=None):
yield remaining_rows, chunk


def cast_dataframe_for_parquet(
dataframe: pandas.DataFrame, schema: Optional[Dict[str, Any]],
) -> pandas.DataFrame:
"""Cast columns to needed dtype when writing parquet files.
See: https://github.com/googleapis/python-bigquery-pandas/issues/421
"""

columns = schema.get("fields", [])

# Protect against an explicit None in the dictionary.
columns = columns if columns is not None else []

for column in columns:
# Schema can be a superset of the columns in the dataframe, so ignore
# columns that aren't present.
column_name = column.get("name")
if column_name not in dataframe.columns:
continue

# Skip array columns for now. Potentially casting the elements of the
# array would be possible, but not worth the effort until there is
# demand for it.
if column.get("mode", "NULLABLE").upper() == "REPEATED":
continue

column_type = column.get("type", "").upper()
if (
column_type == "DATE"
# Use extension dtype first so that it uses the correct equality operator.
and db_dtypes.DateDtype() != dataframe[column_name].dtype
):
# Construct converted column manually, because I can't use
# .astype() with DateDtype. With .astype(), I get the error:
#
# TypeError: Cannot interpret '<db_dtypes.DateDtype ...>' as a data type
cast_column = pandas.Series(
dataframe[column_name], dtype=db_dtypes.DateDtype()
)
elif column_type in {"NUMERIC", "DECIMAL", "BIGNUMERIC", "BIGDECIMAL"}:
cast_column = dataframe[column_name].map(decimal.Decimal)
else:
cast_column = None

if cast_column is not None:
dataframe = dataframe.assign(**{column_name: cast_column})
return dataframe


def load_parquet(
client: bigquery.Client,
dataframe: pandas.DataFrame,
Expand All @@ -70,6 +121,7 @@ def load_parquet(
if schema is not None:
schema = pandas_gbq.schema.remove_policy_tags(schema)
job_config.schema = pandas_gbq.schema.to_google_cloud_bigquery(schema)
dataframe = cast_dataframe_for_parquet(dataframe, schema)

try:
client.load_table_from_dataframe(
Expand Down
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@
release_status = "Development Status :: 4 - Beta"
dependencies = [
"setuptools",
"db-dtypes >=0.3.0,<2.0.0",
"numpy>=1.16.6",
"pandas>=0.23.2",
"pandas>=0.24.2",
"pyarrow >=3.0.0, <7.0dev",
"pydata-google-auth",
"google-auth",
Expand All @@ -35,7 +36,6 @@
]
extras = {
"tqdm": "tqdm>=4.23.0",
"db-dtypes": "db-dtypes >=0.3.0,<2.0.0",
}

# Setup boilerplate below this line.
Expand Down
3 changes: 2 additions & 1 deletion testing/constraints-3.7.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@
#
# e.g., if setup.py has "foo >= 1.14.0, < 2.0.0dev",
# Then this file should have foo==1.14.0
db-dtypes==0.3.0
google-auth==1.4.1
google-auth-oauthlib==0.0.1
google-cloud-bigquery==1.11.1
google-cloud-bigquery-storage==1.1.0
numpy==1.16.6
pandas==0.23.2
pandas==0.24.2
pyarrow==3.0.0
pydata-google-auth==0.1.2
tqdm==4.23.0
21 changes: 3 additions & 18 deletions tests/system/test_gbq.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@

TABLE_ID = "new_test"
PANDAS_VERSION = pkg_resources.parse_version(pandas.__version__)
NULLABLE_INT_PANDAS_VERSION = pkg_resources.parse_version("0.24.0")
NULLABLE_INT_MESSAGE = "Require pandas 0.24+ in order to use nullable integer type."


def test_imports():
Expand Down Expand Up @@ -173,9 +171,6 @@ def test_should_properly_handle_valid_integers(self, project_id):
tm.assert_frame_equal(df, DataFrame({"valid_integer": [3]}))

def test_should_properly_handle_nullable_integers(self, project_id):
if PANDAS_VERSION < NULLABLE_INT_PANDAS_VERSION:
pytest.skip(msg=NULLABLE_INT_MESSAGE)

query = """SELECT * FROM
UNNEST([1, NULL]) AS nullable_integer
"""
Expand All @@ -188,9 +183,7 @@ def test_should_properly_handle_nullable_integers(self, project_id):
)
tm.assert_frame_equal(
df,
DataFrame(
{"nullable_integer": pandas.Series([1, pandas.NA], dtype="Int64")}
),
DataFrame({"nullable_integer": pandas.Series([1, None], dtype="Int64")}),
)

def test_should_properly_handle_valid_longs(self, project_id):
Expand All @@ -204,9 +197,6 @@ def test_should_properly_handle_valid_longs(self, project_id):
tm.assert_frame_equal(df, DataFrame({"valid_long": [1 << 62]}))

def test_should_properly_handle_nullable_longs(self, project_id):
if PANDAS_VERSION < NULLABLE_INT_PANDAS_VERSION:
pytest.skip(msg=NULLABLE_INT_MESSAGE)

query = """SELECT * FROM
UNNEST([1 << 62, NULL]) AS nullable_long
"""
Expand All @@ -219,15 +209,10 @@ def test_should_properly_handle_nullable_longs(self, project_id):
)
tm.assert_frame_equal(
df,
DataFrame(
{"nullable_long": pandas.Series([1 << 62, pandas.NA], dtype="Int64")}
),
DataFrame({"nullable_long": pandas.Series([1 << 62, None], dtype="Int64")}),
)

def test_should_properly_handle_null_integers(self, project_id):
if PANDAS_VERSION < NULLABLE_INT_PANDAS_VERSION:
pytest.skip(msg=NULLABLE_INT_MESSAGE)

query = "SELECT CAST(NULL AS INT64) AS null_integer"
df = gbq.read_gbq(
query,
Expand All @@ -237,7 +222,7 @@ def test_should_properly_handle_null_integers(self, project_id):
dtypes={"null_integer": "Int64"},
)
tm.assert_frame_equal(
df, DataFrame({"null_integer": pandas.Series([pandas.NA], dtype="Int64")}),
df, DataFrame({"null_integer": pandas.Series([None], dtype="Int64")}),
)

def test_should_properly_handle_valid_floats(self, project_id):
Expand Down
Loading

0 comments on commit 2180836

Please sign in to comment.