-
Notifications
You must be signed in to change notification settings - Fork 293
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
7 changed files
with
170 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
import time | ||
|
||
from core.database_arango import ASYNC_JOB_WAIT_TIME, ArangoDatabase | ||
from core.migrations import migration | ||
|
||
|
||
class ArangoMigrationManager(migration.MigrationManager): | ||
DB_TYPE = "arangodb" | ||
|
||
def connect_to_db(self): | ||
self.db = ArangoDatabase() | ||
self.db.connect(check_db_sync=False) | ||
|
||
system_coll = self.db.collection("system") | ||
job = system_coll.all() | ||
while job.status() != "done": | ||
time.sleep(ASYNC_JOB_WAIT_TIME) | ||
migrations = list(job.result()) | ||
if not migrations: | ||
job = system_coll.insert( | ||
{"db_version": 0, "db_type": self.DB_TYPE}, | ||
) | ||
while job.status() != "done": | ||
time.sleep(ASYNC_JOB_WAIT_TIME) | ||
|
||
job = system_coll.all() | ||
while job.status() != "done": | ||
time.sleep(ASYNC_JOB_WAIT_TIME) | ||
migrations = list(job.result()) | ||
|
||
db_version = migrations[0]["db_version"] | ||
db_type = migrations[0]["db_type"] | ||
|
||
self.db_version = db_version | ||
self.db_type = db_type | ||
|
||
def update_db_version(self, version: int): | ||
job = self.db.collection("system").update_match( | ||
{"db_version": self.db_version, "db_type": self.DB_TYPE}, | ||
{"db_version": version}, | ||
) | ||
while job.status() != "done": | ||
time.sleep(ASYNC_JOB_WAIT_TIME) | ||
self.db_version = version | ||
|
||
|
||
def migration_0(): | ||
pass | ||
|
||
|
||
def migration_1(): | ||
from core.schemas import observable | ||
|
||
for obs in observable.Observable.list(): | ||
obs.save() | ||
|
||
|
||
ArangoMigrationManager.register_migration(migration_0) | ||
ArangoMigrationManager.register_migration(migration_1) | ||
|
||
if __name__ == "__main__": | ||
migration_manager = ArangoMigrationManager() | ||
migration_manager.migrate_to_latest() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
from typing import Callable | ||
|
||
|
||
class MigrationManager: | ||
MIGRATIONS: list[Callable] = [] | ||
|
||
def __init__(self): | ||
self.connect_to_db() | ||
|
||
def connect_to_db(self): | ||
raise NotImplementedError | ||
|
||
def update_db_version(self, version: int): | ||
raise NotImplementedError | ||
|
||
def migrate_to_latest(self, stop_at: int | None = None): | ||
for idx, migration in enumerate(self.MIGRATIONS): | ||
if stop_at is not None and idx >= stop_at: | ||
print(f"Stopping at migration {idx}") | ||
elif idx >= self.db_version and (stop_at is None or idx < stop_at): | ||
print(f"Running migration {idx} -> {idx + 1}") | ||
migration() | ||
self.update_db_version(idx + 1) | ||
else: | ||
print(f"Skipping migration {idx}, current version is {self.db_version}") | ||
continue | ||
|
||
@classmethod | ||
def register_migration(cls, migration): | ||
cls.MIGRATIONS.append(migration) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
import time | ||
import unittest | ||
|
||
from core.migrations import arangodb | ||
|
||
|
||
class ArangoMigrationTest(unittest.TestCase): | ||
def setUp(self): | ||
self.migration_manager = arangodb.ArangoMigrationManager() | ||
self.migration_manager.update_db_version(0) | ||
|
||
def test_migration_init(self): | ||
self.assertEqual(self.migration_manager.db_version, 0) | ||
|
||
def test_migration_0(self): | ||
self.migration_manager.migrate_to_latest(stop_at=1) | ||
self.assertEqual(self.migration_manager.db_version, 1) | ||
|
||
def test_migration_1(self): | ||
observable_col = self.migration_manager.db.collection("observables") | ||
observable_col.truncate() | ||
observable_col.insert( | ||
{ | ||
"value": "test.com", | ||
"type": "hostname", | ||
"root_type": "observable", | ||
"created": "2024-11-14T11:58:49.757379Z", | ||
} | ||
) | ||
observable_col.insert( | ||
{ | ||
"value": "test.com123", | ||
"type": "hostname", | ||
"root_type": "observable", | ||
"created": "2024-11-14T11:58:49.757379Z", | ||
} | ||
) | ||
self.migration_manager.migrate_to_latest(stop_at=2) | ||
self.assertEqual(self.migration_manager.db_version, 2) | ||
job = observable_col.all() | ||
while job.status() != "done": | ||
time.sleep(0.1) | ||
obs = list(job.result()) | ||
self.assertEqual(len(obs), 2) | ||
self.assertEqual(obs[0]["value"], "test.com") | ||
self.assertEqual(obs[0]["is_valid"], True) | ||
self.assertEqual(obs[1]["value"], "test.com123") | ||
self.assertEqual(obs[1]["is_valid"], False) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters