forked from NVIDIA-AI-Blueprints/vulnerability-analysis
-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathinput.py
313 lines (226 loc) · 12.9 KB
/
input.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import logging
import os
import typing
import aiohttp
import appdirs
import json5
import mrc
from pydantic import BaseModel
from morpheus.config import Config
from morpheus.pipeline.linear_pipeline import LinearPipeline
from morpheus.pipeline.stage_decorator import source
from morpheus.pipeline.stage_decorator import stage
from ..data_models.config import FileInputConfig
from ..data_models.config import HttpInputConfig
from ..data_models.config import ManualInputConfig
from ..data_models.config import PluginInputConfig
from ..data_models.config import RunConfig
from ..data_models.cve_intel import CveIntel
from ..data_models.dependencies import VulnerableDependencies
from ..data_models.dependencies import VulnerableSBOMPackage
from ..data_models.info import AgentMorpheusInfo
from ..data_models.input import AgentMorpheusEngineInput
from ..data_models.input import AgentMorpheusInput
from ..data_models.input import FileSBOMInfoInput
from ..data_models.input import ManualSBOMInfoInput
from ..data_models.input import SBOMPackage
from ..data_models.plugin import InputPluginSchema
from ..stages.build_vdb_stage import BuildSourceCodeVdbStage
from ..stages.pydantic_http_stage import PydanticHttpStage
from ..utils.document_embedding import DocumentEmbedding
from ..utils.embedding_loader import EmbeddingLoader
from ..utils.intel_retriever import IntelRetriever
from ..utils.vulnerable_dependency_checker import VulnerableDependencyChecker
# Module-level logger; build_http_input derives its "<module>.http_server" child logger from this logger's name.
logger = logging.getLogger(__name__)
def build_http_input(pipe: LinearPipeline, config: Config, input: HttpInputConfig,
                     input_schema: type[BaseModel]) -> logging.Logger:
    """Set an HTTP server as the pipeline source and add a stage that logs every received payload.

    A ``PydanticHttpStage`` configured from ``input`` becomes the source of ``pipe``. It is followed by a
    pass-through stage that writes each payload, serialized as indented JSON, to a dedicated ``http_server``
    logger. That logger writes to ``http_server.log`` in the user's Morpheus log directory, in addition to any
    handlers it inherits from its parents.

    Parameters
    ----------
    pipe : LinearPipeline
        Pipeline whose source is set to the HTTP server stage.
    config : Config
        Morpheus configuration passed to the created stages.
    input : HttpInputConfig
        HTTP server settings: bind address, port, endpoint, HTTP method and ``stop_after``.
    input_schema : type[BaseModel]
        Pydantic model passed to ``PydanticHttpStage`` as ``input_schema`` (presumably the request body
        schema -- confirm in ``PydanticHttpStage``).

    Returns
    -------
    logging.Logger
        The ``<module>.http_server`` logger used for the server's request log.
    """
    # A dedicated child logger lets the HTTP server logs be saved to a file in addition to the console.
    http_server_logger = logging.getLogger(logger.name + ".http_server")

    # Log to a file in the user's log directory next to the morpheus logs. logging.FileHandler does not create
    # missing parent directories, so make sure the directory exists first.
    log_dir = appdirs.user_log_dir(appauthor="NVIDIA", appname="morpheus")
    os.makedirs(log_dir, exist_ok=True)
    # FileHandler stores os.path.abspath(filename) as baseFilename; normalize the same way so the check matches.
    log_file = os.path.abspath(os.path.join(log_dir, "http_server.log"))

    # getLogger returns the same logger object on every call; only attach the file handler once so that
    # calling this function repeatedly does not write each record to the file several times.
    already_attached = any(
        isinstance(handler, logging.FileHandler) and handler.baseFilename == log_file
        for handler in http_server_logger.handlers)
    if not already_attached:
        file_handler = logging.FileHandler(log_file)
        file_handler.setFormatter(logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s"))
        http_server_logger.addHandler(file_handler)

    http_server_logger.setLevel(logging.INFO)
    http_server_logger.info("Http Server Started for %s at %s:%s%s.",
                            input.http_method.name,
                            input.address,
                            input.port,
                            input.endpoint)

    pipe.set_source(
        PydanticHttpStage(config,
                          bind_address=input.address,
                          port=input.port,
                          endpoint=input.endpoint,
                          method=input.http_method,
                          stop_after=input.stop_after,
                          input_schema=input_schema))

    @stage
    def print_payload(payload: typing.Any) -> typing.Any:
        # Pass-through stage: log the received payload as indented JSON and forward it unchanged.
        assert isinstance(payload, BaseModel)
        serialized_str = payload.model_dump_json(indent=2)
        http_server_logger.info("======= Got Request =======\n%s\n===========================", serialized_str)
        return payload

    pipe.add_stage(print_payload(config))

    return http_server_logger
def _parse_sbom_file(file_path: typing.Union[str, os.PathLike]) -> list[SBOMPackage]:
    """Read a whitespace-delimited SBOM text file into a list of ``SBOMPackage`` objects.

    The first line is treated as a header and skipped. Each remaining line is split on whitespace. Its first
    three fields are read as ``name version system``, and any further fields are ignored. Lines with fewer than
    three fields, including blank lines, are skipped.
    """
    packages: list[SBOMPackage] = []

    with open(file_path, "r", encoding="utf-8") as f:
        # Skip the header line (no-op for an empty file)
        next(f, None)

        for line in f:
            fields = line.split()
            if len(fields) < 3:
                continue

            p_name, p_version, p_sys = fields[:3]
            packages.append(SBOMPackage(name=p_name, version=p_version, system=p_sys))

    return packages


def build_input(pipe: LinearPipeline, config: Config, run_config: RunConfig):
    """Build the input portion of the pipeline and return the embedding model it created.

    Stages are added to ``pipe`` in this order:

    1. A source selected by ``run_config.input.type``: manual, file, HTTP, or plugin. For the plugin type, the
       located plugin's ``build_input`` configures the pipeline.
    2. ``BuildSourceCodeVdbStage``, driven by ``DocumentEmbedding.build_vdbs``.
    3. ``fetch_intel``: retrieves intel for every vulnerability in ``message.input.scan.vulns``.
    4. ``process_sbom``: populates ``message.info.sbom`` from the manual or file SBOM input.
    5. ``check_vulnerable_dependencies``: matches the SBOM against each vulnerability's intel.

    Parameters
    ----------
    pipe : LinearPipeline
        Pipeline to which the source and stages are added.
    config : Config
        Morpheus configuration passed to the created stages.
    run_config : RunConfig
        Run configuration selecting the input type, embedding model and working directories.

    Returns
    -------
    The embedding model created by ``EmbeddingLoader.create``, for reuse later in the pipeline.

    Raises
    ------
    ValueError
        If ``run_config.input.type`` is not one of the supported input types.
    """
    # Create the base directories for the VDB and source code if they don't already exist
    for file_dir in (run_config.general.base_vdb_dir, run_config.general.base_git_dir):
        os.makedirs(file_dir, exist_ok=True)

    if run_config.input.type == ManualInputConfig.static_type():

        @source
        def emit_input_object(subscription: mrc.Subscription) -> typing.Generator[AgentMorpheusInput, None, None]:
            assert isinstance(run_config.input, ManualInputConfig)

            for _ in range(run_config.input.repeat_count):
                # Stop early once the subscription is no longer active
                if not subscription.is_subscribed():
                    return
                # Deep-copy each emission so repeated messages never share the same object
                yield run_config.input.message.model_copy(deep=True)

        pipe.set_source(emit_input_object(config))

    elif run_config.input.type == FileInputConfig.static_type():
        assert isinstance(run_config.input, FileInputConfig)

        # Read and validate the input data once, up front; each emission is a deep copy of the validated object
        with open(run_config.input.file, "r", encoding="utf-8") as f:
            json_dict = json5.load(fp=f)

        input_obj = AgentMorpheusInput.model_validate(json_dict)
        repeat_count = run_config.input.repeat_count

        @source
        def emit_input_object(subscription: mrc.Subscription) -> typing.Generator[AgentMorpheusInput, None, None]:
            for _ in range(repeat_count):
                if not subscription.is_subscribed():
                    return
                yield input_obj.model_copy(deep=True)

        pipe.set_source(emit_input_object(config))

    elif run_config.input.type == HttpInputConfig.static_type():
        assert isinstance(run_config.input, HttpInputConfig)
        build_http_input(pipe, config, run_config.input, AgentMorpheusInput)

    elif run_config.input.type == PluginInputConfig.static_type():
        # The plugin is responsible for configuring the pipeline input itself
        plugin = InputPluginSchema.locate(run_config.input.plugin_name)
        plugin.build_input(pipe, config, run_config)

    else:
        raise ValueError(f"Invalid source type: {run_config.input.type}")

    # Load embedding model to be used throughout the pipeline
    embedding = EmbeddingLoader.create(run_config.engine.rag_embedding.type,
                                       **run_config.engine.rag_embedding.model_dump(exclude={"type"}))

    embedder = DocumentEmbedding(embedding=embedding,
                                 vdb_directory=run_config.general.base_vdb_dir,
                                 git_directory=run_config.general.base_git_dir)

    # ignore_code_embedding follows code_search_tool: code embedding is skipped when that option is truthy
    build_vdb_stage = BuildSourceCodeVdbStage(config,
                                              build_vdb_fn=embedder.build_vdbs,
                                              ignore_errors=run_config.general.ignore_build_vdb_errors,
                                              ignore_code_embedding=run_config.general.code_search_tool)
    pipe.add_stage(build_vdb_stage)

    @stage
    def fetch_intel(message: AgentMorpheusEngineInput) -> AgentMorpheusEngineInput:
        """Retrieve intel for every vulnerability in the scan concurrently and store it in ``message.info.intel``.

        ``asyncio.gather`` preserves input order, so ``message.info.intel[i]`` corresponds to
        ``message.input.scan.vulns[i]``.
        """

        async def _inner():
            async with aiohttp.ClientSession() as session:
                intel_retriever = IntelRetriever(session=session)
                intel_coros = [intel_retriever.retrieve(vuln_id=cve.vuln_id) for cve in message.input.scan.vulns]
                return await asyncio.gather(*intel_coros)

        # NOTE: asyncio.run starts a fresh event loop; it cannot be used from a thread that already runs one.
        message.info.intel = asyncio.run(_inner())
        return message

    pipe.add_stage(fetch_intel(config))

    @stage
    def process_sbom(message: AgentMorpheusEngineInput) -> AgentMorpheusEngineInput:
        """Populate ``message.info.sbom`` from the manual package list or the SBOM file referenced by the input."""
        sbom_info = message.input.image.sbom_info

        if sbom_info.type == ManualSBOMInfoInput.static_type():
            assert isinstance(sbom_info, ManualSBOMInfoInput)
            message.info.sbom = AgentMorpheusInfo.SBOMInfo(packages=sbom_info.packages)

        elif sbom_info.type == FileSBOMInfoInput.static_type():
            assert isinstance(sbom_info, FileSBOMInfoInput)
            message.info.sbom = AgentMorpheusInfo.SBOMInfo(packages=_parse_sbom_file(sbom_info.file_path))

        else:
            # Make an unsupported SBOM type visible instead of silently leaving message.info.sbom unpopulated;
            # check_vulnerable_dependencies reads it next.
            logger.warning("Unsupported SBOM info type '%s'; message.info.sbom was not populated.", sbom_info.type)

        return message

    pipe.add_stage(process_sbom(config))

    @stage
    def check_vulnerable_dependencies(message: AgentMorpheusEngineInput) -> AgentMorpheusEngineInput:
        """Check for vulnerable packages in the dependency graph and update the message object.

        For each entry of ``message.info.intel``, the SBOM packages are checked against the GHSA, NVD, Ubuntu
        and RHSA intel that is present. The resulting ``VulnerableDependencies`` list, in intel order, is stored
        in ``message.info.vulnerable_dependencies``.
        """
        sbom = message.info.sbom.packages
        image = f"{message.input.image.name}:{message.input.image.tag}"

        async def _inner():
            async with aiohttp.ClientSession() as session:
                vuln_dep_checker = VulnerableDependencyChecker(session=session, image=image, sbom_list=sbom)
                await vuln_dep_checker.load_dependencies()

                async def _calc_dep(cve_intel: CveIntel):
                    vuln_id = cve_intel.vuln_id
                    vuln_deps = []
                    vuln_package_intel_sources = []

                    # Check vulnerabilities from GHSA
                    if cve_intel.ghsa is not None and cve_intel.ghsa.vulnerabilities:
                        vuln_package_intel_sources.append("ghsa")
                        vuln_deps.extend(await vuln_dep_checker.run_ghsa(cve_intel.ghsa.vulnerabilities))

                    # Check vulnerabilities from NVD
                    if cve_intel.nvd is not None and cve_intel.nvd.configurations:
                        vuln_package_intel_sources.append("nvd")
                        vuln_deps.extend(await vuln_dep_checker.run_nvd(cve_intel.nvd.configurations))

                    # Evaluated before the Ubuntu/RHSA checks, so it reflects GHSA and NVD only
                    if not vuln_package_intel_sources:
                        logger.warning("No vulnerabilities were found in either GHSA or NVD intel for %s.", vuln_id)

                    # Check vulnerabilities from Ubuntu notices
                    if (cve_intel.ubuntu is not None and hasattr(cve_intel.ubuntu, 'notices')
                            and cve_intel.ubuntu.notices):
                        vuln_package_intel_sources.append("ubuntu")
                        vuln_deps.extend(await vuln_dep_checker.run_ubuntu(cve_intel.ubuntu.notices))

                    # Check vulnerabilities from RHSA
                    if (cve_intel.rhsa is not None and hasattr(cve_intel.rhsa, 'affected_release')
                            and cve_intel.rhsa.affected_release):
                        vuln_package_intel_sources.append("rhsa")
                        vuln_deps.extend(await vuln_dep_checker.run_rhsa(cve_intel.rhsa.affected_release))

                    # Each checker result maps (sbom package name, sbom package version) -> vulnerable dependency
                    vulnerable_sbom_packages: list[VulnerableSBOMPackage] = []
                    for vuln_dep in vuln_deps:
                        for (sbom_pkg_name, sbom_pkg_version), vuln_dep_pkg in vuln_dep.items():
                            vulnerable_sbom_packages.append(
                                VulnerableSBOMPackage(name=sbom_pkg_name,
                                                      version=sbom_pkg_version,
                                                      vulnerable_dependency_package=vuln_dep_pkg))

                    if vulnerable_sbom_packages:
                        logger.info("Found vulnerable dependencies for %s.", vuln_id)
                    else:
                        logger.info("No vulnerable dependencies found for %s.", vuln_id)

                    return VulnerableDependencies(vuln_id=vuln_id,
                                                  vuln_package_intel_sources=vuln_package_intel_sources,
                                                  vulnerable_sbom_packages=vulnerable_sbom_packages)

                # Check vulnerable dependencies for each CVE concurrently; gather preserves the intel order
                vulnerable_dependencies: list[VulnerableDependencies] = await asyncio.gather(
                    *[_calc_dep(cve_intel) for cve_intel in message.info.intel])
                return vulnerable_dependencies

        # Update the message info with the vulnerable dependencies list
        message.info.vulnerable_dependencies = asyncio.run(_inner())
        return message

    pipe.add_stage(check_vulnerable_dependencies(config))

    # Return embedding model to be used later in the pipeline
    return embedding