diff --git a/setup.py b/setup.py index f9e514cd..9d367469 100644 --- a/setup.py +++ b/setup.py @@ -1,5 +1,8 @@ from setuptools import setup +with open("README.md", "r", encoding="utf-8") as fh: + long_description = fh.read() + setup( name="torch_struct", version="0.5", @@ -9,9 +12,12 @@ "torch_struct", "torch_struct.semirings", ], + long_description=long_description, package_data={"torch_struct": []}, - url="https://github.com/harvardnlp/pytorch_struct", + long_description_content_type="text/markdown", + url="https://github.com/harvardnlp/pytorch-struct", install_requires=["torch"], setup_requires=["pytest-runner"], tests_require=["pytest"], + python_requires='>=3.6', ) diff --git a/torch_struct/semirings/checkpoint.py b/torch_struct/semirings/checkpoint.py index b2dacba5..c4e10c4f 100644 --- a/torch_struct/semirings/checkpoint.py +++ b/torch_struct/semirings/checkpoint.py @@ -1,8 +1,10 @@ import torch +has_genbmm = False try: import genbmm from genbmm import BandedMatrix + has_genbmm = True except ImportError: pass @@ -52,7 +54,7 @@ def backward(ctx, grad_output): class _CheckpointSemiring(cls): @staticmethod def matmul(a, b): - if isinstance(a, genbmm.BandedMatrix): + if has_genbmm and isinstance(a, genbmm.BandedMatrix): lu = a.lu + b.lu ld = a.ld + b.ld c = _CheckBand.apply(a.data, a.lu, a.ld, b.data, b.lu, b.ld)