diff --git a/core/database_arango.py b/core/database_arango.py index 30f9a54d9..ebb2ae889 100644 --- a/core/database_arango.py +++ b/core/database_arango.py @@ -28,6 +28,8 @@ from .interfaces import AbstractYetiConnector +CODE_DB_VERSION = 2 + LINK_TYPE_TO_GRAPH = { "tagged": "tags", "stix": "stix", @@ -58,6 +60,7 @@ def connect( username: str = None, password: str = None, database: str = None, + check_db_sync: bool = False, ): host = host or yeti_config.get("arangodb", "host") port = port or yeti_config.get("arangodb", "port") @@ -88,6 +91,8 @@ def connect( sys_db.create_database(database) self.db = client.db(database, username=username, password=password) + if check_db_sync: + self.check_database_version() self.create_edge_definition( self.graph("tags"), @@ -120,6 +125,17 @@ def connect( self.create_analyzers() self.create_views() + def check_database_version(self, skip_if_testing: bool = True): + if TESTING and skip_if_testing: + return + system = list(self.db.collection("system").all()) + if not system: + raise RuntimeError("Database version not found, please run migrations.") + if system[0]["db_version"] != CODE_DB_VERSION: + raise RuntimeError( + f"Database version mismatch. Expected {CODE_DB_VERSION}, got {system[0]['db_version']}" + ) + def create_analyzers(self): self.db.create_analyzer( name="norm", diff --git a/core/migrations/__init__.py b/core/migrations/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/core/migrations/arangodb.py b/core/migrations/arangodb.py new file mode 100644 index 000000000..23c2fce6b --- /dev/null +++ b/core/migrations/arangodb.py @@ -0,0 +1,63 @@ +import time + +from core.database_arango import ASYNC_JOB_WAIT_TIME, ArangoDatabase +from core.migrations import migration + + +class ArangoMigrationManager(migration.MigrationManager): + DB_TYPE = "arangodb" + + def connect_to_db(self): + self.db = ArangoDatabase() + self.db.connect(check_db_sync=False) + + system_coll = self.db.collection("system") + job = system_coll.all() + while job.status() != "done": + time.sleep(ASYNC_JOB_WAIT_TIME) + migrations = list(job.result()) + if not migrations: + job = system_coll.insert( + {"db_version": 0, "db_type": self.DB_TYPE}, + ) + while job.status() != "done": + time.sleep(ASYNC_JOB_WAIT_TIME) + + job = system_coll.all() + while job.status() != "done": + time.sleep(ASYNC_JOB_WAIT_TIME) + migrations = list(job.result()) + + db_version = migrations[0]["db_version"] + db_type = migrations[0]["db_type"] + + self.db_version = db_version + self.db_type = db_type + + def update_db_version(self, version: int): + job = self.db.collection("system").update_match( + {"db_version": self.db_version, "db_type": self.DB_TYPE}, + {"db_version": version}, + ) + while job.status() != "done": + time.sleep(ASYNC_JOB_WAIT_TIME) + self.db_version = version + + +def migration_0(): + pass + + +def migration_1(): + from core.schemas import observable + + for obs in observable.Observable.list(): + obs.save() + + +ArangoMigrationManager.register_migration(migration_0) +ArangoMigrationManager.register_migration(migration_1) + +if __name__ == "__main__": + migration_manager = ArangoMigrationManager() + migration_manager.migrate_to_latest() diff --git a/core/migrations/migration.py b/core/migrations/migration.py new file mode 100644 index 000000000..cf018146b --- /dev/null +++ b/core/migrations/migration.py @@ -0,0 +1,30 @@ +from typing import Callable + + +class MigrationManager: + MIGRATIONS: list[Callable] = [] + + def __init__(self): + self.connect_to_db() + + def connect_to_db(self): + raise NotImplementedError + + def update_db_version(self, version: int): + raise NotImplementedError + + def migrate_to_latest(self, stop_at: int | None = None): + for idx, migration in enumerate(self.MIGRATIONS): + if stop_at is not None and idx >= stop_at: + print(f"Stopping at migration {idx}") + elif idx >= self.db_version and (stop_at is None or idx < stop_at): + print(f"Running migration {idx} -> {idx + 1}") + migration() + self.update_db_version(idx + 1) + else: + print(f"Skipping migration {idx}, current version is {self.db_version}") + continue + + @classmethod + def register_migration(cls, migration): + cls.MIGRATIONS.append(migration) diff --git a/extras/docker/docker-entrypoint.sh b/extras/docker/docker-entrypoint.sh index 023826ce2..c5862c54d 100755 --- a/extras/docker/docker-entrypoint.sh +++ b/extras/docker/docker-entrypoint.sh @@ -17,6 +17,8 @@ elif [[ "$1" = 'toggle-user' ]]; then poetry run python yetictl/cli.py toggle-user "${@:2}" elif [[ "$1" = 'toggle-admin' ]]; then poetry run python yetictl/cli.py toggle-admin "${@:2}" +elif [[ "$1" = 'migrate-arangodb' ]]; then + poetry run python yetictl/cli.py migrate-arangodb "${@:2}" elif [[ "$1" = 'envshell' ]]; then poetry shell else diff --git a/tests/migration.py b/tests/migration.py new file mode 100644 index 000000000..68698bd57 --- /dev/null +++ b/tests/migration.py @@ -0,0 +1,48 @@ +import time +import unittest + +from core.migrations import arangodb + + +class ArangoMigrationTest(unittest.TestCase): + def setUp(self): + self.migration_manager = arangodb.ArangoMigrationManager() + self.migration_manager.update_db_version(0) + + def test_migration_init(self): + self.assertEqual(self.migration_manager.db_version, 0) + + def test_migration_0(self): + self.migration_manager.migrate_to_latest(stop_at=1) + self.assertEqual(self.migration_manager.db_version, 1) + + def test_migration_1(self): + observable_col = self.migration_manager.db.collection("observables") + observable_col.truncate() + observable_col.insert( + { + "value": "test.com", + "type": "hostname", + "root_type": "observable", + "created": "2024-11-14T11:58:49.757379Z", + } + ) + observable_col.insert( + { + "value": "test.com123", + "type": "hostname", + "root_type": "observable", + "created": "2024-11-14T11:58:49.757379Z", + } + ) + self.migration_manager.migrate_to_latest(stop_at=2) + self.assertEqual(self.migration_manager.db_version, 2) + job = observable_col.all() + while job.status() != "done": + time.sleep(0.1) + obs = list(job.result()) + self.assertEqual(len(obs), 2) + self.assertEqual(obs[0]["value"], "test.com") + self.assertEqual(obs[0]["is_valid"], True) + self.assertEqual(obs[1]["value"], "test.com123") + self.assertEqual(obs[1]["is_valid"], False) diff --git a/yetictl/cli.py b/yetictl/cli.py index 9f4a5d01d..c85691fc4 100644 --- a/yetictl/cli.py +++ b/yetictl/cli.py @@ -125,7 +125,7 @@ def list_tasks(task_type="") -> None: @cli.command() @click.argument("task_name") @click.argument("task_params", required=False) -def run_task(task_name: str, task_params: dict = None) -> None: +def run_task(task_name: str, task_params: dict | None = None) -> None: """Runs a task.""" # Load all tasks. Take into account new tasks that have not been registered logging.getLogger().setLevel(logging.INFO) @@ -149,5 +149,15 @@ def run_task(task_name: str, task_params: dict = None) -> None: click.echo(traceback.format_exc()) +@cli.command() +@click.argument("stop_at", required=False) +def migrate_arangodb(stop_at: int | None = None) -> None: + """Runs the database migrations.""" + from core.migrations.arangodb import ArangoMigrationManager + + migration_manager = ArangoMigrationManager() + migration_manager.migrate_to_latest(stop_at=stop_at) + + if __name__ == "__main__": cli()