Skip to content

Commit

Permalink
Fix templating fields and do_xcom_push in DatabricksSQLOperator (#27868)
Browse files Browse the repository at this point in the history
When SQLExecuteQueryOperator was introduced in #25717, it
introduced some errors in the Databricks SQL operator:

* The templated "schema" field has not been set as field in the
  operator.
* The do_xcom_push parameter was ignored

This PR fixes it by:

* storing schema as a field and using it via self-reference
* do_xcom_push is removed (the one inherited from BaseOperator is used).
  • Loading branch information
potiuk authored Nov 23, 2022
1 parent 68ee3bf commit a343bba
Showing 1 changed file with 19 additions and 17 deletions.
36 changes: 19 additions & 17 deletions airflow/providers/databricks/operators/databricks_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ class DatabricksSqlOperator(SQLExecuteQueryOperator):
:param output_format: format of output data if ``output_path`` is specified.
Possible values are ``csv``, ``json``, ``jsonl``. Default is ``csv``.
:param csv_params: parameters that will be passed to the ``csv.DictWriter`` class used to write CSV data.
:param do_xcom_push: If True, then the result of SQL executed will be pushed to an XCom.
"""

template_fields: Sequence[str] = (
Expand All @@ -87,7 +86,6 @@ def __init__(
http_headers: list[tuple[str, str]] | None = None,
catalog: str | None = None,
schema: str | None = None,
do_xcom_push: bool = False,
output_path: str | None = None,
output_format: str = "csv",
csv_params: dict[str, Any] | None = None,
Expand All @@ -99,24 +97,28 @@ def __init__(
self._output_path = output_path
self._output_format = output_format
self._csv_params = csv_params
self.http_path = http_path
self.sql_endpoint_name = sql_endpoint_name
self.session_configuration = session_configuration
self.client_parameters = {} if client_parameters is None else client_parameters
self.hook_params = kwargs.pop("hook_params", {})
self.http_headers = http_headers
self.catalog = catalog
self.schema = schema

client_parameters = {} if client_parameters is None else client_parameters
hook_params = kwargs.pop("hook_params", {})

self.hook_params = {
"http_path": http_path,
"session_configuration": session_configuration,
"sql_endpoint_name": sql_endpoint_name,
"http_headers": http_headers,
"catalog": catalog,
"schema": schema,
def get_db_hook(self) -> DatabricksSqlHook:
    """Build and return a :class:`DatabricksSqlHook` for this operator.

    The hook parameters are assembled here, at call time, rather than in
    ``__init__`` so that templated fields (e.g. ``schema``, ``catalog``,
    ``http_path``) are read from ``self`` only after template rendering —
    this is the fix this commit introduces.

    Explicit ``client_parameters`` and ``hook_params`` are spread last so
    they can override the individual operator fields.

    :return: a configured DatabricksSqlHook bound to ``databricks_conn_id``.
    """
    hook_params = {
        "http_path": self.http_path,
        "session_configuration": self.session_configuration,
        "sql_endpoint_name": self.sql_endpoint_name,
        "http_headers": self.http_headers,
        "catalog": self.catalog,
        "schema": self.schema,
        # identifies the calling operator to the hook (e.g. for usage tracking)
        "caller": "DatabricksSqlOperator",
        **self.client_parameters,
        **self.hook_params,
    }
    return DatabricksSqlHook(self.databricks_conn_id, **hook_params)

def _process_output(self, schema, results):
if not self._output_path:
Expand Down

0 comments on commit a343bba

Please sign in to comment.