Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enhanced schema fixes #34

Merged
merged 7 commits into from
Jan 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,18 @@

## Next

### Added

- Optional parameter to specify embedding dimension in `Neo4jVector`, avoiding the need to query the embedding model.

### Changed

- Made the `source` parameter of `GraphDocument` optional and updated related methods to support this.
- Suppressed `AggregationSkippedNull` warnings raised by the Neo4j driver in the `Neo4jGraph` class when fetching the `enhanced_schema`.

### Fixed

- Disabled warnings from the Neo4j driver for the `Neo4jGraph` class.
- Resolved syntax errors in `GraphCypherQAChain` by ensuring node labels with spaces are correctly quoted in Cypher queries.

## 0.2.0

Expand Down
107 changes: 56 additions & 51 deletions libs/neo4j/langchain_neo4j/graphs/neo4j_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,6 @@ class Neo4jGraph(GraphStore):
enhanced_schema (bool): A flag whether to scan the database for
example values and use them in the graph schema. Default is False.
driver_config (Dict): Configuration passed to Neo4j Driver.
Defaults to {"notifications_min_severity": "OFF"} if not set.

*Security note*: Make sure that the database connection uses credentials
that are narrowly-scoped to only include necessary permissions.
Expand Down Expand Up @@ -366,10 +365,9 @@ def __init__(
{"database": database}, "database", "NEO4J_DATABASE", "neo4j"
)

if driver_config is None:
driver_config = {}
driver_config.setdefault("notifications_min_severity", "OFF")
self._driver = neo4j.GraphDatabase.driver(url, auth=auth, **driver_config)
self._driver = neo4j.GraphDatabase.driver(
url, auth=auth, **(driver_config or {})
)
self._database = database
self.timeout = timeout
self.sanitize = sanitize
Expand All @@ -379,20 +377,11 @@ def __init__(
# Verify connection
try:
self._driver.verify_connectivity()
except neo4j.exceptions.ConfigurationError as e:
# If notification filtering is not supported
if "Notification filtering is not supported" in str(e):
# Retry without notifications_min_severity
driver_config.pop("notifications_min_severity", None)
self._driver = neo4j.GraphDatabase.driver(
url, auth=auth, **driver_config
)
self._driver.verify_connectivity()
else:
raise ValueError(
"Could not connect to Neo4j database. "
"Please ensure that the driver config is correct"
)
except neo4j.exceptions.ConfigurationError:
raise ValueError(
"Could not connect to Neo4j database. "
"Please ensure that the driver config is correct"
)
except neo4j.exceptions.ServiceUnavailable:
raise ValueError(
"Could not connect to Neo4j database. "
Expand Down Expand Up @@ -442,12 +431,15 @@ def query(
self,
query: str,
params: dict = {},
session_params: dict = {},
) -> List[Dict[str, Any]]:
"""Query Neo4j database.

Args:
query (str): The Cypher query to execute.
params (dict): The parameters to pass to the query.
session_params (dict): Parameters to pass to the session used for executing
the query.

Returns:
List[Dict[str, Any]]: The list of dictionaries containing the query results.
Expand All @@ -459,39 +451,42 @@ def query(
from neo4j import Query
from neo4j.exceptions import Neo4jError

try:
data, _, _ = self._driver.execute_query(
Query(text=query, timeout=self.timeout),
database_=self._database,
parameters_=params,
)
json_data = [r.data() for r in data]
if self.sanitize:
json_data = [value_sanitize(el) for el in json_data]
return json_data
except Neo4jError as e:
if not (
(
( # isCallInTransactionError
e.code == "Neo.DatabaseError.Statement.ExecutionFailed"
or e.code
== "Neo.DatabaseError.Transaction.TransactionStartFailed"
)
and e.message is not None
and "in an implicit transaction" in e.message
if not session_params:
try:
data, _, _ = self._driver.execute_query(
Query(text=query, timeout=self.timeout),
database_=self._database,
parameters_=params,
)
or ( # isPeriodicCommitError
e.code == "Neo.ClientError.Statement.SemanticError"
and e.message is not None
and (
"in an open transaction is not possible" in e.message
or "tried to execute in an explicit transaction" in e.message
json_data = [r.data() for r in data]
if self.sanitize:
json_data = [value_sanitize(el) for el in json_data]
return json_data
except Neo4jError as e:
if not (
(
( # isCallInTransactionError
e.code == "Neo.DatabaseError.Statement.ExecutionFailed"
or e.code
== "Neo.DatabaseError.Transaction.TransactionStartFailed"
)
and e.message is not None
and "in an implicit transaction" in e.message
)
)
):
raise
or ( # isPeriodicCommitError
e.code == "Neo.ClientError.Statement.SemanticError"
and e.message is not None
and (
"in an open transaction is not possible" in e.message
or "tried to execute in an explicit transaction"
in e.message
)
)
):
raise
# fallback to allow implicit transactions
with self._driver.session(database=self._database) as session:
session_params.setdefault("database", self._database)
with self._driver.session(**session_params) as session:
result = session.run(Query(text=query, timeout=self.timeout), params)
json_data = [r.data() for r in result]
if self.sanitize:
Expand Down Expand Up @@ -551,7 +546,8 @@ def refresh_schema(self) -> None:
}
if self._enhanced_schema:
schema_counts = self.query(
"CALL apoc.meta.graphSample() YIELD nodes, relationships "
"CALL apoc.meta.graph({sample: 1000, maxRels: 100}) "
"YIELD nodes, relationships "
"RETURN nodes, [rel in relationships | {name:apoc.any.property"
"(rel, 'type'), count: apoc.any.property(rel, 'count')}]"
" AS relationships"
Expand All @@ -569,7 +565,16 @@ def refresh_schema(self) -> None:
)
# Due to schema-flexible nature of neo4j errors can happen
try:
enhanced_info = self.query(enhanced_cypher)[0]["output"]
enhanced_info = self.query(
enhanced_cypher,
# Disable the
# Neo.ClientNotification.Statement.AggregationSkippedNull
# notifications raised by the use of collect in the enhanced
# schema query
session_params={
"notifications_disabled_categories": ["UNRECOGNIZED"]
},
)[0]["output"]
for prop in node_props:
if prop["property"] in enhanced_info:
prop.update(enhanced_info[prop["property"]])
Expand Down
30 changes: 1 addition & 29 deletions libs/neo4j/tests/unit_tests/graphs/test_neo4j_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,35 +202,7 @@ def test_neo4j_graph_init_with_empty_credentials() -> None:
Neo4jGraph(
url="bolt://localhost:7687", username="", password="", refresh_schema=False
)
mock_driver.assert_called_with(
"bolt://localhost:7687", auth=None, notifications_min_severity="OFF"
)


def test_neo4j_graph_init_notification_filtering_err() -> None:
"""Test the __init__ method when notification filtering is disabled."""
with patch("neo4j.GraphDatabase.driver", autospec=True) as mock_driver:
mock_driver_instance = MagicMock()
mock_driver.return_value = mock_driver_instance
err = ConfigurationError("Notification filtering is not supported")
mock_driver_instance.verify_connectivity.side_effect = [err, None]
Neo4jGraph(
url="bolt://localhost:7687",
username="username",
password="password",
refresh_schema=False,
)
mock_driver.assert_any_call(
"bolt://localhost:7687",
auth=("username", "password"),
notifications_min_severity="OFF",
)
# The first call to verify_connectivity should fail, causing the driver to be
# recreated without the notifications_min_severity parameter.
mock_driver.assert_any_call(
"bolt://localhost:7687",
auth=("username", "password"),
)
mock_driver.assert_called_with("bolt://localhost:7687", auth=None)


def test_neo4j_graph_init_driver_config_err() -> None:
Expand Down
Loading