From bfbd396c399de7ee8067dc5218e02705f9caa4bf Mon Sep 17 00:00:00 2001
From: Paco Nathan
Date: Wed, 10 Jan 2024 16:43:09 -0800
Subject: [PATCH] unit test coverage for deserialization

---
 docs/index.md       |  2 +-
 pkg_doc.cfg         |  3 ++-
 tests/test_load.py  | 62 +++++++++++++++++++++++++++++++++++++++++++++
 textgraphs/graph.py | 55 ++++++++++++++++++++++++++++++++++++++++
 4 files changed, 120 insertions(+), 2 deletions(-)
 create mode 100644 tests/test_load.py

diff --git a/docs/index.md b/docs/index.md
index 09c4dd7..5bad87d 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -18,7 +18,7 @@ expensive process and poor results, relative to other methods.
 That said, are there other ways transformers might help augment
 natural language workflows?
 
-This project results from an ongoing pursuit of that line of inquiry.
+This project results from an ongoing pursuit along these lines of inquiry.
 With sufficiently narrowed task focus and ample software engineering,
 transformers can be used to augment specific _components_ of natural
 language workflows.
diff --git a/pkg_doc.cfg b/pkg_doc.cfg
index 214e46d..494b821 100644
--- a/pkg_doc.cfg
+++ b/pkg_doc.cfg
@@ -7,8 +7,9 @@
     "TextGraphs",
     "SimpleGraph",
     "Node",
-    "NodeEnum",
     "Edge",
+    "EnumBase",
+    "NodeEnum",
     "RelEnum",
     "PipelineFactory",
     "Pipeline",
diff --git a/tests/test_load.py b/tests/test_load.py
new file mode 100644
index 0000000..68ab4da
--- /dev/null
+++ b/tests/test_load.py
@@ -0,0 +1,62 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+"""
+unit tests:
+
+  * serialization and deserialization
+
+see copyright/license https://huggingface.co/spaces/DerwenAI/textgraphs/blob/main/README.md
+"""
+
+from os.path import abspath, dirname
+import json
+import pathlib
+import sys
+
+import deepdiff  # pylint: disable=E0401
+
+sys.path.insert(0, str(pathlib.Path(dirname(dirname(abspath(__file__))))))
+import textgraphs  # pylint: disable=C0413
+
+
+def test_load_minimal (
+    ) -> None:
+    """
+Construct a _lemma graph_ from a minimal example, then compare
+serialized and deserialized data to ensure no fields get corrupted
+in the conversions.
+    """
+    text: str = """
+See Spot run.
+ """ + + tg: textgraphs.TextGraphs = textgraphs.TextGraphs() # pylint: disable=C0103 + pipe: textgraphs.Pipeline = tg.create_pipeline(text.strip()) + + # serialize into node-link format + tg.collect_graph_elements(pipe) + tg.construct_lemma_graph() + tg.calc_phrase_ranks() + + json_str: str = tg.dump_lemma_graph() + exp_graph = json.loads(json_str) + + # deserialize from node-link format + tg = textgraphs.TextGraphs() # pylint: disable=C0103 + tg.load_lemma_graph(json_str) + tg.construct_lemma_graph() + + obs_graph: dict = json.loads(tg.dump_lemma_graph()) + + # compare + diff: deepdiff.diff.DeepDiff = deepdiff.DeepDiff(exp_graph, obs_graph) + + if len(diff) > 0: + print(json.dumps(json.loads(diff.to_json()), indent = 2)) + + assert len(diff) == 0 + + +if __name__ == "__main__": + test_load_minimal() diff --git a/textgraphs/graph.py b/textgraphs/graph.py index 705935e..4ec2588 100644 --- a/textgraphs/graph.py +++ b/textgraphs/graph.py @@ -281,3 +281,58 @@ def dump_lemma_graph ( indent = 2, separators = ( ",", ":" ), ) + + + def load_lemma_graph ( + self, + json_str: str, + ) -> None: + """ +Load from a JSON string in +a JSON representation of the exported _lemma graph_ in +[_node-link_](https://networkx.org/documentation/stable/reference/readwrite/json_graph.html) +format + """ + dat: dict = json.loads(json_str) + tokens: typing.List[ Node ] = [] + + # deserialize the nodes + for nx_node in dat.get("nodes"): # type: ignore + label: typing.Optional[ str ] = None + kind: NodeEnum = NodeEnum.decode(nx_node["kind"]) # type: ignore + + if kind in [ NodeEnum.ENT ]: + label = nx_node["label"] + + node: Node = self.make_node( + tokens, + nx_node["lemma"], + None, + kind, + 0, + 0, + 0, + label = label, + length = nx_node["length"], + ) + + node.text = nx_node["name"] + node.pos = nx_node["pos"] + node.loc = eval(nx_node["loc"]) # pylint: disable=W0123 + node.count = int(nx_node["count"]) + node.neighbors = int(nx_node["hood"]) + + # deserialize the edges + node_list: typing.List[ Node ] = list(self.nodes.values()) + + for nx_edge in dat.get("links"): # type: ignore + edge: Edge = self.make_edge( # type: ignore + node_list[nx_edge["source"]], + node_list[nx_edge["target"]], + RelEnum.decode(nx_edge["kind"]), # type: ignore + nx_edge["title"], + float(nx_edge["prob"]), + key = nx_edge["lemma"], + ) + + edge.count = int(nx_edge["count"])