Skip to content

Commit

Permalink
Schemas validation and easy creation (#1159)
Browse files Browse the repository at this point in the history
Co-authored-by: Thomas Chopitea <[email protected]>
  • Loading branch information
udgover and tomchop authored Nov 7, 2024
1 parent 5950b20 commit c8bd5c5
Show file tree
Hide file tree
Showing 41 changed files with 1,644 additions and 795 deletions.
28 changes: 27 additions & 1 deletion core/schemas/entity.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import datetime
from enum import Enum
from typing import ClassVar, Literal
from typing import ClassVar, List, Literal

from pydantic import Field, computed_field

Expand Down Expand Up @@ -78,3 +78,29 @@ def add_context(
context["source"] = source
self.context.append(context)
return self.save()


def create(*, name: str, type: str, **kwargs) -> "EntityTypes":
    """
    Build an entity of the requested type without saving it to the database.

    Args:
        name: name given to the new entity.
        type: key of TYPE_MAPPING selecting the entity class to instantiate.
        **kwargs: additional fields, forwarded unchanged to the entity class
            constructor.

    Returns:
        A new, unsaved instance of the class registered under ``type``.

    Raises:
        ValueError: if ``type`` is not a key of TYPE_MAPPING.
    """
    entity_cls = TYPE_MAPPING.get(type)
    if entity_cls is None:
        raise ValueError(f"{type} is not a valid entity type")
    return entity_cls(name=name, **kwargs)


def save(*, name: str, type: str, tags: List[str] = None, **kwargs):
    """
    Create an entity, save it, and optionally tag it.

    Args:
        name: name given to the entity.
        type: entity type, validated by create().
        tags: optional tags; when non-empty they are passed to the saved
            object's tag() method.
        **kwargs: additional fields forwarded to create().

    Returns:
        The object returned by the new entity's save() call. The value
        returned by tag() is not used.

    Raises:
        ValueError: propagated from create() when ``type`` is not valid.
    """
    entity = create(name=name, type=type, **kwargs)
    stored = entity.save()
    if not tags:
        return stored
    stored.tag(tags)
    return stored


def find(*, name: str, **kwargs) -> "EntityTypes":
    """
    Look up an entity by name.

    Thin wrapper around Entity.find(): ``name`` and any extra keyword
    arguments are passed through as search criteria, and whatever
    Entity.find() returns is returned unchanged.
    """
    criteria = {"name": name, **kwargs}
    return Entity.find(**criteria)
40 changes: 39 additions & 1 deletion core/schemas/indicator.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import datetime
import logging
from enum import Enum
from typing import ClassVar, Literal
from typing import ClassVar, List, Literal

from pydantic import BaseModel, Field, computed_field

Expand Down Expand Up @@ -88,3 +88,41 @@ def search(cls, observables: list[str]) -> list[tuple[str, "Indicator"]]:
logging.error(
f"Indicator type {indicator.type} has not implemented match(): {error}"
)


def create(
    *, name: str, type: str, pattern: str, diamond: DiamondModel, **kwargs
) -> "IndicatorTypes":
    """
    Build an indicator of the requested type without saving it to the database.

    Args:
        name: name given to the new indicator.
        type: key of TYPE_MAPPING selecting the indicator class to instantiate.
        pattern: pattern passed to the indicator class constructor.
        diamond: diamond model value passed to the indicator class constructor.
        **kwargs: additional fields, forwarded unchanged to the constructor.

    Returns:
        A new, unsaved instance of the class registered under ``type``.

    Raises:
        ValueError: if ``type`` is not a key of TYPE_MAPPING.
    """
    indicator_cls = TYPE_MAPPING.get(type)
    if indicator_cls is None:
        raise ValueError(f"{type} is not a valid indicator type")
    return indicator_cls(name=name, pattern=pattern, diamond=diamond, **kwargs)


def save(
    *,
    name: str,
    type: str,
    pattern: str,
    diamond: DiamondModel,
    tags: List[str] = None,
    **kwargs,
):
    """
    Create an indicator, save it, and optionally tag it.

    Args:
        name: name given to the indicator.
        type: indicator type, validated by create().
        pattern: pattern forwarded to create().
        diamond: diamond model value forwarded to create().
        tags: optional tags; when non-empty they are passed to the saved
            object's tag() method.
        **kwargs: additional fields forwarded to create().

    Returns:
        The object returned by the new indicator's save() call. The value
        returned by tag() is not used.

    Raises:
        ValueError: propagated from create() when ``type`` is not valid.
    """
    indicator = create(
        name=name, type=type, pattern=pattern, diamond=diamond, **kwargs
    )
    stored = indicator.save()
    if not tags:
        return stored
    stored.tag(tags)
    return stored


def find(*, name: str, **kwargs) -> "IndicatorTypes":
    """
    Look up an indicator by name.

    Thin wrapper around Indicator.find(): ``name`` and any extra keyword
    arguments are passed through as search criteria, and whatever
    Indicator.find() returns is returned unchanged.
    """
    criteria = {"name": name, **kwargs}
    return Indicator.find(**criteria)
216 changes: 180 additions & 36 deletions core/schemas/observable.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
# TODO Observable value normalization

import datetime
import re
import io
import os
import tempfile

# Data Schema
# Dynamically register all observable types
from enum import Enum

# from enum import Enum, EnumMeta
from typing import Any, ClassVar, Literal
from typing import IO, ClassVar, List, Literal, Tuple

import requests
from bs4 import BeautifulSoup
from pydantic import Field, computed_field

from core import database_arango
Expand All @@ -24,6 +28,7 @@ class ObservableType(str, Enum): ...

# Name used in annotations across this module for "any observable type".
# Starts out as an empty tuple; presumably replaced or extended once the
# concrete observable classes are registered -- TODO confirm against the
# registration code.
ObservableTypes = ()
# Maps an observable type name to the class implementing it. Looked up by
# Observable.load(), guess_type() and create().
TYPE_MAPPING = {}
# Kinds of input annotated as a "file" by create_from_file() and
# save_from_file(): a path or an already opened file object.
FileLikeObject = str | os.PathLike | IO | tempfile.SpooledTemporaryFile


class Observable(YetiTagModel, database_arango.ArangoYetiConnector):
Expand All @@ -47,37 +52,6 @@ def load(cls, object: dict) -> "ObservableTypes": # noqa: F821
return TYPE_MAPPING[object["type"]](**object)
raise ValueError("Attempted to instantiate an undefined observable type.")

@staticmethod
def is_valid(value: Any) -> bool:
return False

@classmethod
def add_text(cls, text: str, tags: list[str] | None = None) -> "ObservableTypes":  # noqa: F821
    """Adds and returns an observable for a given string.

    The text is refanged first, then its type is looked up with find_type().
    An existing observable with the same refanged value is reused; otherwise
    a new one is created with the current UTC time as ``created`` and saved.

    Args:
        text: the text that will be used to add an Observable from.
        tags: an optional list of tags to add to the Observable. Defaults to
            None (no tags) rather than a shared mutable list.

    Returns:
        The Observable that was found or saved; when tags are given, the
        value returned by its tag() call.

    Raises:
        ValueError: if no observable type matches the refanged text.
    """
    refanged = refang(text)
    observable_type = find_type(refanged)
    if not observable_type:
        raise ValueError(f"Invalid type for observable '{text}'")

    observable = Observable.find(value=refanged)
    if not observable:
        observable = TYPE_MAPPING[observable_type](
            value=refanged,
            created=datetime.datetime.now(datetime.timezone.utc),
        ).save()
    # Called on both the reused and the freshly saved object; the result is
    # not used here.
    observable.get_tags()
    if tags:
        observable = observable.tag(tags)
    return observable

def add_context(
self,
source: str,
Expand Down Expand Up @@ -128,8 +102,178 @@ def delete_context(
return self.save()


def find_type(value: str) -> ObservableType | None:
    """
    Older name for guess_type(), kept so existing callers keep working.

    Delegates to guess_type() and returns its result unchanged.
    """
    return guess_type(value)


def guess_type(value: str) -> str | None:
    """
    Guess the type of an observable based on its value.

    Every class registered in TYPE_MAPPING is tried in the mapping's
    iteration order. A class claims the value when its is_valid() returns a
    truthy result or, failing that, when it has a validate_value() that
    returns a truthy result without raising ValueError.

    Args:
        value: the raw observable value to classify.

    Returns:
        The TYPE_MAPPING key of the first class that accepts the value, or
        None when no class does.
    """
    for obs_type, obj in TYPE_MAPPING.items():
        if obj.is_valid(value):
            return obs_type
        if not hasattr(obj, "validate_value"):
            continue
        # Keep the try body to the one call that is expected to raise.
        try:
            accepted = obj.validate_value(value)
        except ValueError:
            continue
        if accepted:
            return obs_type
    return None


def create(*, value: str, type: str | None = None, **kwargs) -> ObservableTypes:
    """
    Build an observable object without saving it to the database.

    Args:
        value: the value of the observable.
        type: key of TYPE_MAPPING selecting the observable class. When it is
            empty, None or the string "guess", the type is determined from
            the value with guess_type().
        **kwargs: additional fields, forwarded unchanged to the observable
            class constructor.

    Returns:
        A new, unsaved instance of the selected observable class.

    Raises:
        ValueError: if the type cannot be guessed from the value, or if an
            explicit ``type`` is not a key of TYPE_MAPPING.
    """
    if type and type != "guess":
        # Explicit type: it has to be a registered one.
        if type not in TYPE_MAPPING:
            raise ValueError(f"{type} is not a valid observable type")
    else:
        type = guess_type(value)
        if not type:
            raise ValueError(f"Invalid type for observable '{value}'")
    observable_cls = TYPE_MAPPING[type]
    return observable_cls(value=value, **kwargs)


def save(
    *, value: str, type: str | None = None, tags: List[str] = None, **kwargs
) -> ObservableTypes:
    """
    Create an observable, save it, and optionally tag it.

    Args:
        value: the value of the observable.
        type: observable type; when omitted, create() determines it from the
            value.
        tags: optional tags; when non-empty they are passed to the saved
            object's tag() method.
        **kwargs: additional fields forwarded to create().

    Returns:
        The object returned by the new observable's save() call. The value
        returned by tag() is not used.

    Raises:
        ValueError: propagated from create() when the type is not recognized.
    """
    candidate = create(value=value, type=type, **kwargs)
    stored = candidate.save()
    if not tags:
        return stored
    stored.tag(tags)
    return stored


def find(*, value, **kwargs) -> ObservableTypes:
    """
    Look up an observable by value.

    The value is passed through refang() before the lookup; the refanged
    value and any extra keyword arguments are handed to Observable.find(),
    whose result is returned unchanged.
    """
    refanged = refang(value)
    return Observable.find(value=refanged, **kwargs)


def create_from_text(text: str) -> Tuple[List["ObservableTypes"], List[str]]:
    """
    Build unsaved observables from a block of text, one per line.

    The text is split on "\\n"; each line is stripped of surrounding
    whitespace and blank lines are ignored. Every remaining line is handed
    to create() with no explicit type.

    Args:
        text: the block of text to parse.

    Returns:
        A ``(observables, unknown)`` tuple: the created objects, and the
        stripped lines for which create() raised ValueError.
    """
    created = []
    rejected = []
    stripped_lines = (raw.strip() for raw in text.split("\n"))
    for candidate in filter(None, stripped_lines):
        try:
            created.append(create(value=candidate))
        except ValueError:
            rejected.append(candidate)
    return created, rejected


def save_from_text(
    text: str, tags: List[str] = None
) -> Tuple[List["ObservableTypes"], List[str]]:
    """
    Create and save observables from a block of text, one per line.

    Parsing is delegated to create_from_text(); every observable it returns
    is saved and, when ``tags`` is non-empty, tagged.

    Args:
        text: the block of text to parse.
        tags: optional tags applied to every saved observable.

    Returns:
        A ``(saved_observables, unknown)`` tuple: the objects returned by
        each save() call, and the lines that could not be turned into an
        observable.
    """
    parsed, unknown = create_from_text(text)
    stored_all = []
    for candidate in parsed:
        stored = candidate.save()
        if tags:
            stored.tag(tags)
        stored_all.append(stored)
    return stored_all, unknown


def create_from_file(file: FileLikeObject) -> Tuple[List["ObservableTypes"], List[str]]:
    """
    Build unsaved observables from the lines of a file.

    Each line is decoded as UTF-8 when it is bytes, stripped of surrounding
    whitespace, and handed to create() with no explicit type. Blank lines
    are ignored.

    Args:
        file: either a path (str, bytes or os.PathLike), which is opened
            here in text mode as UTF-8, or an already opened file object
            (io.IOBase or tempfile.SpooledTemporaryFile), which is read
            from its current position and left open for the caller.

    Returns:
        A ``(observables, unknown)`` tuple: the created objects, and the
        stripped lines for which create() raised ValueError.

    Raises:
        ValueError: if ``file`` is none of the supported kinds.
    """
    opened = False
    if isinstance(file, (str, bytes, os.PathLike)):
        f = open(file, "r", encoding="utf-8")
        opened = True
    elif isinstance(file, (io.IOBase, tempfile.SpooledTemporaryFile)):
        f = file
    else:
        raise ValueError("Invalid file type")
    observables = []
    unknown = []
    try:
        for line in f:
            if isinstance(line, bytes):
                # Decoding happens outside the inner try on purpose: a
                # decode error is not an "unknown observable" and must
                # propagate to the caller.
                line = line.decode("utf-8")
            line = line.strip()
            if not line:
                continue
            try:
                observables.append(create(value=line))
            except ValueError:
                unknown.append(line)
    finally:
        # Only close what this function opened, and do so even when reading
        # or decoding fails, so the handle is never leaked.
        if opened:
            f.close()
    return observables, unknown


def save_from_file(
    file: FileLikeObject, tags: List[str] = None
) -> Tuple[List["ObservableTypes"], List[str]]:
    """
    Create and save observables from the lines of a file.

    Parsing is delegated to create_from_file(); every observable it returns
    is saved and, when ``tags`` is non-empty, tagged.

    Args:
        file: a path or an opened file object, as accepted by
            create_from_file().
        tags: optional tags applied to every saved observable.

    Returns:
        A ``(saved_observables, unknown)`` tuple: the objects returned by
        each save() call, and the lines that could not be turned into an
        observable.
    """
    parsed, unknown = create_from_file(file)
    stored_all = []
    for candidate in parsed:
        stored = candidate.save()
        if tags:
            stored.tag(tags)
        stored_all.append(stored)
    return stored_all, unknown


def create_from_url(
    url: str, *, timeout: float = 30.0
) -> Tuple[List["ObservableTypes"], List[str]]:
    """
    Build unsaved observables from the content served at a URL.

    The URL is fetched with an HTTP GET, the response body is parsed with
    BeautifulSoup's "html.parser", and the extracted text is handed to
    create_from_text(), which treats each line as a candidate observable.

    Args:
        url: the URL to fetch.
        timeout: seconds to wait for the server before giving up. Without a
            timeout, requests can block indefinitely on an unresponsive
            server.

    Returns:
        The ``(observables, unknown)`` tuple produced by create_from_text().

    Raises:
        requests.HTTPError: if the response status signals an error
            (raised by ``raise_for_status()``).
        requests.Timeout: if the server does not answer within ``timeout``.
    """
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    soup = BeautifulSoup(response.text, "html.parser")
    return create_from_text(soup.get_text())


def save_from_url(
    url: str, tags: List[str] = None
) -> Tuple[List["ObservableTypes"], List[str]]:
    """
    Create and save observables from the content served at a URL.

    Fetching and parsing are delegated to create_from_url(); every
    observable it returns is saved and, when ``tags`` is non-empty, tagged.

    Args:
        url: the URL to fetch.
        tags: optional tags applied to every saved observable.

    Returns:
        A ``(saved_observables, unknown)`` tuple: the objects returned by
        each save() call, and the lines that could not be turned into an
        observable.
    """
    parsed, unknown = create_from_url(url)
    stored_all = []
    for candidate in parsed:
        stored = candidate.save()
        if tags:
            stored.tag(tags)
        stored_all.append(stored)
    return stored_all, unknown
11 changes: 8 additions & 3 deletions core/schemas/observables/bic.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import re
from typing import Literal

from pydantic import field_validator

from core.schemas import observable

# Six letters, two alphanumerics, then an optional three-character
# alphanumeric suffix, i.e. 8 or 11 characters in total.
# The suffix is wrapped in an optional group: a bare "{3}?" is a lazy
# "exactly three" and would reject the 8-character form. The trailing "$"
# stops longer strings from matching on their prefix alone.
BIC_MATCHER_REGEX = re.compile(r"^[A-Z]{6}[A-Z0-9]{2}(?:[A-Z0-9]{3})?$")
Expand All @@ -9,6 +11,9 @@
# Observable whose value must match BIC_MATCHER_REGEX.
class BIC(observable.Observable):
    type: Literal["bic"] = "bic"

    @staticmethod
    def is_valid(value: str) -> bool:
        """Return True when ``value`` matches BIC_MATCHER_REGEX, False otherwise."""
        # Return a real bool, as annotated, instead of leaking the
        # re.Match / None result of the regex call.
        return BIC_MATCHER_REGEX.match(value) is not None

    @field_validator("value")
    @classmethod
    def validate_value(cls, value: str) -> str:
        """
        Validate the ``value`` field.

        Returns:
            The value unchanged when it matches BIC_MATCHER_REGEX.

        Raises:
            ValueError: when the value does not match.
        """
        if not BIC_MATCHER_REGEX.match(value):
            raise ValueError("Invalid BIC")
        return value
11 changes: 8 additions & 3 deletions core/schemas/observables/email.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
from typing import Literal

import validators
from pydantic import field_validator

from core.schemas import observable


# Observable whose value must be accepted by validators.email().
class Email(observable.Observable):
    type: Literal["email"] = "email"

    @staticmethod
    def is_valid(value: str) -> bool:
        """Return True when validators.email() accepts ``value`` as given (no refanging)."""
        # Coerce to a real bool, as annotated, instead of leaking the
        # library's result object on failure.
        return bool(validators.email(value))

    @field_validator("value")
    @classmethod
    def validate_value(cls, value: str) -> str:
        """
        Validate and normalize the ``value`` field.

        The value is refanged first, so the stored value is the refanged form.

        Returns:
            The refanged value when validators.email() accepts it.

        Raises:
            ValueError: when the refanged value is not accepted.
        """
        value = observable.refang(value)
        if not validators.email(value):
            raise ValueError("Invalid email address")
        return value
Loading

0 comments on commit c8bd5c5

Please sign in to comment.