Fix errors in Databricks SQL operator introduced when refactoring

When SQLExecuteQueryOperator was introduced in apache#25717, it
introduced several errors in the Databricks SQL operator:

* The schema (description) parameter was passed to _process_output
  as part of the Hook's output
* The run() method of DatabricksSqlHook did not conform to the run()
  methods of the other Hooks - it returned a tuple of
  (description, result) instead of just the result
* The _process_output contract was not specified - with scalar output
  it returned a different shape than without it, and the method was
  not defined on the base SQLExecuteQueryOperator at all.

This PR fixes it by:

* making the Databricks Hook conform to the other DbApiHooks in terms
  of the value returned by run() (backwards incompatible, so we need
  to bump the major version of the provider)
* adding a "last_description" field to DbApiHook - this makes the hook
  stateful, but the state only reflects the description of the last
  run() call and is not a problem to keep. This requires version 1.4
  of the common-sql provider, as it is a new feature of that provider
  (see the usage sketch after this list)
* adding a "scalar_return_last" field to DbApiHook that indicates
  whether scalar output was requested (a single SQL string with
  return_last=True)
* properly naming Python DB-API's "description" - previously it was
  called "schema", which clashed with the "schema" passed to hook
  initialisation (the actual database schema)
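
For illustration, a minimal sketch of the resulting hook contract (the connection id and queries are placeholders):

    from airflow.providers.common.sql.hooks.sql import fetch_all_handler
    from airflow.providers.databricks.hooks.databricks_sql import DatabricksSqlHook

    # connection id and SQL are placeholders for illustration
    hook = DatabricksSqlHook(databricks_conn_id="databricks_default")

    # a single statement with return_last=True (the default) now returns the
    # handler result itself - no more (description, result) tuples
    rows = hook.run("SELECT id, value FROM some_table", handler=fetch_all_handler)

    # the cursor description of the last executed statement is kept on the hook
    if hook.last_description is not None:
        column_names = [col[0] for col in hook.last_description]

    # a list of statements returns one handler result per statement
    results = hook.run(["SELECT 1", "SELECT 2"], handler=fetch_all_handler)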
potiuk committed Nov 23, 2022
1 parent a343bba commit d581af8
Showing 8 changed files with 78 additions and 36 deletions.
11 changes: 7 additions & 4 deletions airflow/providers/common/sql/hooks/sql.py
@@ -18,7 +18,7 @@

from contextlib import closing
from datetime import datetime
from typing import Any, Callable, Iterable, Mapping, cast
from typing import Any, Callable, Iterable, Mapping, Sequence, cast

import sqlparse
from packaging.version import Version
@@ -111,9 +111,11 @@ def __init__(self, *args, schema: str | None = None, log_sql: bool = True, **kwa
# We should not make schema available in deriving hooks for backwards compatibility
# If a hook deriving from DBApiHook has a need to access schema, then it should retrieve it
# from kwargs and store it on its own. We do not run "pop" here as we want to give the
# Hook deriving from the DBApiHook to still have access to the field in it's constructor
# Hook deriving from the DBApiHook to still have access to the field in its constructor
self.__schema = schema
self.log_sql = log_sql
self.scalar_return_last = False
self.last_description: Sequence[Sequence] | None = None

def get_conn(self):
"""Returns a connection object"""
@@ -244,7 +246,7 @@ def run(
:param return_last: Whether to return result for only last statement or for all after split
:return: return only result of the ALL SQL expressions if handler was provided.
"""
scalar_return_last = isinstance(sql, str) and return_last
self.scalar_return_last = isinstance(sql, str) and return_last
if isinstance(sql, str):
if split_statements:
sql = self.split_sql_string(sql)
@@ -268,14 +270,15 @@
if handler is not None:
result = handler(cur)
results.append(result)
self.last_description = cur.description

# If autocommit was set to False or db does not support autocommit, we do a manual commit.
if not self.get_autocommit(conn):
conn.commit()

if handler is None:
return None
elif scalar_return_last:
elif self.scalar_return_last:
return results[-1]
else:
return results
36 changes: 30 additions & 6 deletions airflow/providers/common/sql/operators/sql.py
@@ -19,13 +19,14 @@

import ast
import re
from typing import TYPE_CHECKING, Any, Callable, Iterable, Mapping, NoReturn, Sequence, SupportsAbs
from typing import TYPE_CHECKING, Any, Callable, Iterable, Mapping, NoReturn, Sequence, SupportsAbs, overload

from airflow.compat.functools import cached_property
from airflow.exceptions import AirflowException, AirflowFailException
from airflow.hooks.base import BaseHook
from airflow.models import BaseOperator, SkipMixin
from airflow.providers.common.sql.hooks.sql import DbApiHook, fetch_all_handler
from airflow.typing_compat import Literal

if TYPE_CHECKING:
from airflow.utils.context import Context
@@ -224,6 +225,33 @@ def __init__(
self.split_statements = split_statements
self.return_last = return_last

@overload
def _process_output(
self, results: Any, description: Sequence[Sequence] | None, scalar_results: Literal[True]
) -> Any:
pass

@overload
def _process_output(
self, results: list[Any], description: Sequence[Sequence] | None, scalar_results: Literal[False]
) -> Any:
pass

def _process_output(
self, results: Any | list[Any], description: Sequence[Sequence] | None, scalar_results: bool
) -> Any:
"""
Can be overridden by the subclass in case some extra processing is needed.
The "process_output" method can override the returned output - augmenting or processing the
output as needed - the output returned will be returned as execute return value and if
do_xcom_push is set to True, it will be set as XCom returned
:param results: results in the form of list of rows.
:param description: as returned by ``cur.description`` in the Python DBAPI
:param scalar_results: True if result is single scalar value rather than list of rows
"""
return results

def execute(self, context):
self.log.info("Executing: %s", self.sql)
hook = self.get_db_hook()
Expand All @@ -244,11 +272,7 @@ def execute(self, context):
split_statements=self.split_statements,
)

if hasattr(self, "_process_output"):
for out in output:
self._process_output(*out)

return output
return self._process_output(output, hook.last_description, hook.scalar_return_last)

def prepare_template(self) -> None:
"""Parse template file for attribute parameters."""
10 changes: 5 additions & 5 deletions airflow/providers/databricks/hooks/databricks_sql.py
@@ -147,7 +147,7 @@ def run(
handler: Callable | None = None,
split_statements: bool = True,
return_last: bool = True,
) -> tuple[str, Any] | list[tuple[str, Any]] | None:
) -> Any | list[Any] | None:
"""
Runs a command or a list of commands. Pass a list of sql
statements to the sql parameter to get them to execute
@@ -163,7 +163,7 @@
:param return_last: Whether to return result for only last statement or for all after split
:return: return only result of the LAST SQL expression if handler was provided.
"""
scalar_return_last = isinstance(sql, str) and return_last
self.scalar_return_last = isinstance(sql, str) and return_last
if isinstance(sql, str):
if split_statements:
sql = self.split_sql_string(sql)
@@ -186,14 +186,14 @@

if handler is not None:
result = handler(cur)
schema = cur.description
results.append((schema, result))
results.append(result)
self.last_description = cur.description

self._sql_conn = None

if handler is None:
return None
elif scalar_return_last:
elif self.scalar_return_last:
return results[-1]
else:
return results
22 changes: 16 additions & 6 deletions airflow/providers/databricks/operators/databricks_sql.py
@@ -120,12 +120,21 @@ def get_db_hook(self) -> DatabricksSqlHook:
}
return DatabricksSqlHook(self.databricks_conn_id, **hook_params)

def _process_output(self, schema, results):
def _process_output(
self, results: Any | list[Any], description: Sequence[Sequence] | None, scalar_results: bool
) -> Any:
if not self._output_path:
return
return description, results
if not self._output_format:
raise AirflowException("Output format should be specified!")
field_names = [field[0] for field in schema]
if description is None:
self.log.warning("Description of the cursor is missing. Will not process the output")
return description, results
field_names = [field[0] for field in description]
if scalar_results:
list_results: list[Any] = [results]
else:
list_results = results
if self._output_format.lower() == "csv":
with open(self._output_path, "w", newline="") as file:
if self._csv_params:
@@ -138,18 +147,19 @@ def _process_output(self, schema, results):
writer = csv.DictWriter(file, fieldnames=field_names, **csv_params)
if write_header:
writer.writeheader()
for row in results:
for row in list_results:
writer.writerow(row.asDict())
elif self._output_format.lower() == "json":
with open(self._output_path, "w") as file:
file.write(json.dumps([row.asDict() for row in results]))
file.write(json.dumps([row.asDict() for row in list_results]))
elif self._output_format.lower() == "jsonl":
with open(self._output_path, "w") as file:
for row in results:
for row in list_results:
file.write(json.dumps(row.asDict()))
file.write("\n")
else:
raise AirflowException(f"Unsupported output format: '{self._output_format}'")
return description, results


COPY_INTO_APPROVED_FORMATS = ["CSV", "JSON", "AVRO", "ORC", "PARQUET", "TEXT", "BINARYFILE"]
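
For illustration, a usage sketch of the operator writing results to a file after this change (output_path appears in the tests below; output_format is assumed to be the constructor argument backing the _output_format attribute used above):

    from airflow.providers.databricks.operators.databricks_sql import DatabricksSqlOperator

    # connection id, SQL and output path are placeholders
    export_results = DatabricksSqlOperator(
        task_id="export_results",
        databricks_conn_id="databricks_default",
        sql="SELECT id, value FROM some_schema.some_table",
        output_path="/tmp/results.csv",
        output_format="csv",  # "csv", "json" or "jsonl", per _process_output above
    )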
4 changes: 2 additions & 2 deletions airflow/providers/exasol/hooks/exasol.py
@@ -157,7 +157,7 @@ def run(
:param return_last: Whether to return result for only last statement or for all after split
:return: return only result of the LAST SQL expression if handler was provided.
"""
scalar_return_last = isinstance(sql, str) and return_last
self.scalar_return_last = isinstance(sql, str) and return_last
if isinstance(sql, str):
if split_statements:
sql = self.split_sql_string(sql)
@@ -187,7 +187,7 @@

if handler is None:
return None
elif scalar_return_last:
elif self.scalar_return_last:
return results[-1]
else:
return results
4 changes: 2 additions & 2 deletions airflow/providers/snowflake/hooks/snowflake.py
@@ -350,7 +350,7 @@ def run(
"""
self.query_ids = []

scalar_return_last = isinstance(sql, str) and return_last
self.scalar_return_last = isinstance(sql, str) and return_last
if isinstance(sql, str):
if split_statements:
split_statements_tuple = util_text.split_statements(StringIO(sql))
@@ -387,7 +387,7 @@

if handler is None:
return None
elif scalar_return_last:
elif self.scalar_return_last:
return results[-1]
else:
return results
11 changes: 6 additions & 5 deletions tests/providers/databricks/hooks/test_databricks_sql.py
@@ -1,4 +1,3 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
@@ -15,6 +14,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

#
from __future__ import annotations

import unittest
@@ -70,17 +71,17 @@ def test_query(self, mock_requests, mock_conn):
type(mock_requests.get.return_value).status_code = status_code_mock

test_fields = ["id", "value"]
test_schema = [(field,) for field in test_fields]
test_description = [(field,) for field in test_fields]

conn = mock_conn.return_value
cur = mock.MagicMock(rowcount=0, description=test_schema)
cur = mock.MagicMock(rowcount=0, description=test_description)
cur.fetchall.return_value = []
conn.cursor.return_value = cur

query = "select * from test.test;"
schema, results = self.hook.run(sql=query, handler=fetch_all_handler)
results = self.hook.run(sql=query, handler=fetch_all_handler)

assert schema == test_schema
assert self.hook.last_description == test_description
assert results == []

cur.execute.assert_has_calls([mock.call(q) for q in [query]])
16 changes: 10 additions & 6 deletions tests/providers/databricks/operators/test_databricks_sql.py
@@ -47,13 +47,15 @@ def test_exec_success(self, db_mock_class):
sql = "select * from dummy"
op = DatabricksSqlOperator(task_id=TASK_ID, sql=sql, do_xcom_push=True)
db_mock = db_mock_class.return_value
mock_schema = [("id",), ("value",)]
mock_description = [("id",), ("value",)]
mock_results = [Row(id=1, value="value1")]
db_mock.run.return_value = [(mock_schema, mock_results)]
db_mock.run.return_value = mock_results
db_mock.last_description = mock_description
db_mock.scalar_return_last = False

results = op.execute(None)
execute_results = op.execute(None)

assert results[0][1] == mock_results
assert execute_results == (mock_description, mock_results)
db_mock_class.assert_called_once_with(
DEFAULT_CONN_ID,
http_path=None,
@@ -82,9 +84,11 @@ def test_exec_write_file(self, db_mock_class):
tempfile_path = tempfile.mkstemp()[1]
op = DatabricksSqlOperator(task_id=TASK_ID, sql=sql, output_path=tempfile_path)
db_mock = db_mock_class.return_value
mock_schema = [("id",), ("value",)]
mock_description = [("id",), ("value",)]
mock_results = [Row(id=1, value="value1")]
db_mock.run.return_value = [(mock_schema, mock_results)]
db_mock.run.return_value = mock_results
db_mock.last_description = mock_description
db_mock.scalar_return_last = False

try:
op.execute(None)
