Skip to content

Commit

Permalink
Remove templates from database, move to filesystem (#1141)
Browse files Browse the repository at this point in the history
  • Loading branch information
tomchop authored Sep 30, 2024
1 parent a389eee commit 235f703
Show file tree
Hide file tree
Showing 5 changed files with 120 additions and 107 deletions.
43 changes: 31 additions & 12 deletions core/schemas/template.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,22 @@
import os
from typing import TYPE_CHECKING, ClassVar
from pathlib import Path
from typing import TYPE_CHECKING, Optional

import jinja2
from pydantic import BaseModel

from core import database_arango
from core.schemas.model import YetiModel
from core.config.config import yeti_config

if TYPE_CHECKING:
from core.schemas.observable import Observable

# TODO: Import Jinja functions to render templates


class Template(YetiModel, database_arango.ArangoYetiConnector):
class Template(BaseModel):
"""A template for exporting data to an external system."""

_collection_name: ClassVar[str] = "templates"

name: str
template: str

@classmethod
def load(cls, object: dict) -> "Template":
return cls(**object)

def render(self, data: list["Observable"], output_file: str | None) -> None | str:
"""Renders the template with the given data to the output file."""

Expand All @@ -37,3 +30,29 @@ def render(self, data: list["Observable"], output_file: str | None) -> None | st
return None
else:
return result

def save(self) -> "Template":
directory = Path(
yeti_config.get("system", "template_dir", "/opt/yeti/templates")
)
Path.mkdir(directory, parents=True, exist_ok=True)
file = directory / f"{self.name}.jinja2"
file.write_text(self.template)
return self

def delete(self) -> None:
directory = Path(
yeti_config.get("system", "template_dir", "/opt/yeti/templates")
)
file = directory / f"{self.name}.jinja2"
file.unlink()

@classmethod
def find(cls, name: str) -> Optional["Template"]:
directory = Path(
yeti_config.get("system", "template_dir", "/opt/yeti/templates")
)
file = directory / f"{name}.jinja2"
if file.exists():
return Template(name=name, template=file.read_text())
return None
96 changes: 43 additions & 53 deletions core/web/apiv2/templates.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
import logging
from pathlib import Path

from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, ConfigDict

from core.config.config import yeti_config
from core.schemas.observable import Observable
from core.schemas.template import Template

Expand All @@ -10,8 +14,7 @@
class TemplateSearchRequest(BaseModel):
model_config = ConfigDict(extra="forbid")

query: dict[str, str | int | list] = {}
sorting: list[tuple[str, bool]] = []
name: str = ""
count: int = 50
page: int = 0

Expand All @@ -23,71 +26,62 @@ class TemplateSearchResponse(BaseModel):
total: int


class PatchTemplateRequest(BaseModel):
model_config = ConfigDict(extra="forbid")
# class PatchTemplateRequest(BaseModel):
# model_config = ConfigDict(extra="forbid")

template: Template
# template: Template


class RenderExportRequest(BaseModel):
class RenderTemplateRequest(BaseModel):
model_config = ConfigDict(extra="forbid")

template_id: str
observable_ids: list[str] | None = None
template_name: str
observable_ids: list[str] = []
search_query: str | None = None


# API endpoints
router = APIRouter()


@router.post("/")
async def new(request: PatchTemplateRequest) -> Template:
"""Creates a new template."""
# TODO: Validate template
return request.template.save()


@router.patch("/{template_id}")
async def update(template_id: str, request: PatchTemplateRequest) -> Template:
"""Updates a template."""
db_template = Template.get(template_id)
if not db_template:
raise HTTPException(
status_code=404, detail=f"Template {template_id} not found."
)
update_data = request.template.model_dump(exclude_unset=True)
updated_template = db_template.model_copy(update=update_data)
new = updated_template.save()
return new


@router.post("/search")
async def search(request: TemplateSearchRequest) -> TemplateSearchResponse:
"""Searches for observables."""
query = request.query
templates, total = Template.filter(
query,
offset=request.page * request.count,
count=request.count,
sorting=request.sorting,
)
glob = "*"
if request.name:
glob = f"*{request.name}*"

template_dir = yeti_config.get("system", "template_dir", "/opt/yeti/templates")
files = []
total = 0
for file in Path(template_dir).rglob(f"{glob}.jinja2"):
total += 1
files.append(file)

files = sorted(files)
templates = []
for file in files[
(request.page * request.count) : ((request.page + 1) * request.count)
]:
template = Template(name=file.stem, template=file.read_text())
templates.append(template)

return TemplateSearchResponse(templates=templates, total=total)


@router.post("/render")
async def render(request: RenderExportRequest) -> StreamingResponse:
async def render(request: RenderTemplateRequest) -> StreamingResponse:
"""Renders a template."""
if not request.search_query and not request.observable_ids:
raise HTTPException(
status_code=400,
detail="Must specify either search_query or observable_ids.",
)

template = Template.get(request.template_id)
template = Template.find(name=request.template_name)
if not template:
raise HTTPException(
status_code=404, detail=f"Template {request.template_id} not found."
status_code=404, detail=f"Template {request.template_name} not found."
)

if request.search_query:
Expand All @@ -97,9 +91,16 @@ async def render(request: RenderExportRequest) -> StreamingResponse:
status_code=404, detail="No observables found for search query."
)
else:
observables = [
Observable.get(observable_id) for observable_id in request.observable_ids
]
observables = []
for observable_id in request.observable_ids:
db_obs = Observable.get(observable_id)
if not db_obs:
logging.warning(
f"Observable with id {observable_id} not found, skipping..."
)
continue
observables.append(db_obs)

data = template.render(observables, None)

def _stream():
Expand All @@ -111,14 +112,3 @@ def _stream():
media_type="text/plain",
headers={"Content-Disposition": f"attachment; filename={template.name}.txt"},
)


@router.delete("/{template_id}")
async def delete(template_id: str):
"""Deletes a template from the database."""
template = Template.get(template_id)
if not template:
raise HTTPException(
status_code=404, detail=f"Template {template_id} not found."
)
template.delete()
1 change: 1 addition & 0 deletions tests/apiv2/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,3 +183,4 @@ def test_delete_export(self):

def tearDown(self) -> None:
database_arango.db.clear()
self.template.delete()
85 changes: 44 additions & 41 deletions tests/apiv2/templates.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import json
import logging
import sys
import unittest
from pathlib import Path

from fastapi.testclient import TestClient

Expand Down Expand Up @@ -30,70 +32,61 @@ def setUp(self) -> None:
"/api/v2/auth/api-token", headers={"x-yeti-apikey": user.api_key}
).json()
client.headers = {"Authorization": "Bearer " + token_data["access_token"]}
self.template = Template(name="FakeTemplate", template=TEST_TEMPLATE).save()
temp_path = Path("/opt/yeti/templates")
temp_path.mkdir(parents=True, exist_ok=True)
self.temp_template_path = temp_path

Template(name="FakeTemplate", template=TEST_TEMPLATE).save()
for i in range(0, 100):
Template(name=f"template_blah_{i:02}", template=f"fake_template_{i}").save()

def tearDown(self) -> None:
for file in Path(self.temp_template_path).rglob("*.jinja2"):
file.unlink()
database_arango.db.clear()

def test_search_template(self):
response = client.post("/api/v2/templates/search", json={"query": {"name": ""}})
response = client.post("/api/v2/templates/search", json={"name": "Fake"})
data = response.json()
self.assertEqual(response.status_code, 200, data)
self.assertEqual(data["templates"][0]["name"], "FakeTemplate")
self.assertEqual(data["total"], 1)

def test_delete_template(self):
response = client.delete(f"/api/v2/templates/{self.template.id}")
self.assertEqual(response.status_code, 200, response.json())
self.assertEqual(Template.get(self.template.id), None)

def test_create_template(self):
response = client.post(
"/api/v2/templates/",
json={"template": {"name": "FakeTemplate2", "template": "<BLAH>"}},
)
def test_pagination(self):
response = client.post("/api/v2/templates/search", json={"name": "blah"})
data = response.json()

self.assertEqual(response.status_code, 200, data)
self.assertEqual(data["name"], "FakeTemplate2")
self.assertEqual(data["template"], "<BLAH>")
self.assertEqual(data["id"], Template.find(name="FakeTemplate2").id)
self.assertEqual(len(data["templates"]), 50)
self.assertEqual(data["templates"][0]["name"], "template_blah_00")
self.assertEqual(data["templates"][49]["name"], "template_blah_49")
self.assertEqual(data["total"], 100)

def test_update_template(self):
response = client.patch(
f"/api/v2/templates/{self.template.id}",
json={
"template": {
"name": "FakeTemplateFoo",
"template": "<FOO>",
}
},
response = client.post(
"/api/v2/templates/search", json={"name": "blah", "page": 3, "count": 5}
)
data = response.json()
self.assertEqual(response.status_code, 200, data)
self.assertEqual(data["name"], "FakeTemplateFoo")
self.assertEqual(data["template"], "<FOO>")
self.assertEqual(data["id"], self.template.id)
db_template = Template.get(self.template.id)
self.assertEqual(db_template.template, "<FOO>")
self.assertEqual(db_template.name, "FakeTemplateFoo")
self.assertEqual(db_template.id, data["id"])
self.assertEqual(len(data["templates"]), 5)
self.assertEqual(data["templates"][0]["name"], "template_blah_15")
self.assertEqual(data["templates"][4]["name"], "template_blah_19")

def test_render_template_by_id(self):
def test_render_template_by_obs_ids(self):
ipv4.IPv4(value="1.1.1.1").save()
ipv4.IPv4(value="2.2.2.2").save()
ipv4.IPv4(value="3.3.3.3").save()
response = client.post(
"/api/v2/templates/render",
json={
"template_id": self.template.id,
"template_name": "FakeTemplate",
"observable_ids": [o.id for o in Observable.list()],
},
)
data = response.text
response.headers["Content-Disposition"] = (
"attachment; filename=FakeTemplate.txt"
)
self.assertEqual(response.status_code, 200, data)
self.assertEqual(
response.headers["Content-Disposition"],
"attachment; filename=FakeTemplate.txt",
)
self.assertEqual(data, "<blah>\n1.1.1.1\n2.2.2.2\n3.3.3.3\n\n</blah>\n")

def test_render_template_by_search(self):
Expand All @@ -103,11 +96,21 @@ def test_render_template_by_search(self):
hostname.Hostname(value="hacker.com").save()
response = client.post(
"/api/v2/templates/render",
json={"template_id": self.template.id, "search_query": "yeti"},
json={"template_name": "FakeTemplate", "search_query": "yeti"},
)
data = response.text
response.headers["Content-Disposition"] = (
"attachment; filename=FakeTemplate.txt"
)
self.assertEqual(response.status_code, 200, data)
self.assertEqual(
response.headers["Content-Disposition"],
"attachment; filename=FakeTemplate.txt",
)
self.assertEqual(data, "<blah>\nyeti1.com\nyeti2.com\nyeti3.com\n\n</blah>\n")

def test_render_nonexistent(self):
response = client.post(
"/api/v2/templates/render",
json={"template_name": "NotExist", "search_query": "yeti"},
)
data = response.text
self.assertEqual(response.status_code, 404, data)
self.assertEqual(json.loads(data), {"detail": "Template NotExist not found."})
2 changes: 1 addition & 1 deletion tests/schemas/fixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def setUp(self) -> None:
database_arango.db.connect(database="yeti_test")
database_arango.db.clear()

def test_something(self):
def general_fixture_test(self):
user = UserSensitive(username="yeti", admin=True, enabled=True)
user.set_password("yeti")
user.save()
Expand Down

0 comments on commit 235f703

Please sign in to comment.