Skip to content

Commit

Permalink
Use EnvConfigValue for passing env-configured arguments to services (#…
Browse files Browse the repository at this point in the history
…1704)

Contributes to #1701

## By Submitting this PR I confirm:
- I am familiar with the [Contributing Guidelines](https://github.com/nv-morpheus/Morpheus/blob/main/docs/source/developer_guide/contributing.md).
- When the PR is ready for review, new or existing tests cover these changes.
- When the PR is ready for review, the documentation is up to date with these changes.

Authors:
  - Christopher Harris (https://github.com/cwharris)

Approvers:
  - Michael Demoret (https://github.com/mdemoret-nv)

URL: #1704
  • Loading branch information
cwharris authored May 21, 2024
1 parent 2fe4dd3 commit 6c722c7
Show file tree
Hide file tree
Showing 5 changed files with 283 additions and 10 deletions.
42 changes: 35 additions & 7 deletions morpheus/llm/services/nemo_llm_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@

import asyncio
import logging
import os
import typing
import warnings

from morpheus.llm.services.llm_service import LLMClient
from morpheus.llm.services.llm_service import LLMService
from morpheus.utils.env_config_value import EnvConfigValue

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -190,7 +190,24 @@ class NeMoLLMService(LLMService):
A service for interacting with NeMo LLM models, this class should be used to create a client for a specific model.
"""

def __init__(self, *, api_key: str = None, org_id: str = None, retry_count=5) -> None:
class APIKey(EnvConfigValue):
    """NGC API key; when no value is supplied it is looked up in the `NGC_API_KEY` environment variable."""

    # A missing key is tolerated: the wrapped value is simply `None`.
    _ALLOW_NONE: bool = True
    _ENV_KEY: str = "NGC_API_KEY"

class OrgId(EnvConfigValue):
    """NGC organization ID; when no value is supplied it is looked up in the `NGC_ORG_ID` environment variable."""

    # A missing organization ID is tolerated: the wrapped value is simply `None`.
    _ALLOW_NONE: bool = True
    _ENV_KEY: str = "NGC_ORG_ID"

class BaseURI(EnvConfigValue):
    """NGC API base URI; when no value is supplied it is looked up in the `NGC_API_BASE` environment variable."""

    # A missing base URI is tolerated: the wrapped value is simply `None`.
    _ALLOW_NONE: bool = True
    _ENV_KEY: str = "NGC_API_BASE"

def __init__(self,
*,
api_key: APIKey | str = None,
org_id: OrgId | str = None,
base_uri: BaseURI | str = None,
retry_count=5) -> None:
"""
Creates a service for interacting with NeMo LLM models.
Expand All @@ -203,6 +220,10 @@ def __init__(self, *, api_key: str = None, org_id: str = None, retry_count=5) ->
The organization ID for the LLM service, by default None. If `None` the organization ID will be read from
the `NGC_ORG_ID` environment variable. This value is only required if the account associated with the
`api_key` is a member of multiple NGC organizations.
base_uri : str, optional
The base URI for the LLM service, by default None. If `None` the base URI will be read from
the `NGC_API_BASE` environment variable; if that is also unset, `None` is passed to the NeMo LLM client as
its `api_host`.
retry_count : int, optional
The number of times to retry a request before raising an exception, by default 5
Expand All @@ -212,22 +233,29 @@ def __init__(self, *, api_key: str = None, org_id: str = None, retry_count=5) ->
raise ImportError(IMPORT_ERROR_MESSAGE) from IMPORT_EXCEPTION

super().__init__()
api_key = api_key if api_key is not None else os.environ.get("NGC_API_KEY", None)
org_id = org_id if org_id is not None else os.environ.get("NGC_ORG_ID", None)

if not isinstance(api_key, NeMoLLMService.APIKey):
api_key = NeMoLLMService.APIKey(api_key)

if not isinstance(org_id, NeMoLLMService.OrgId):
org_id = NeMoLLMService.OrgId(org_id)

if not isinstance(base_uri, NeMoLLMService.BaseURI):
base_uri = NeMoLLMService.BaseURI(base_uri)

self._retry_count = retry_count

self._conn = nemollm.NemoLLM(
api_host=os.environ.get("NGC_API_BASE", None),
api_host=base_uri.value,
# The client must configure the authentication and authorization parameters
# in accordance with the API server security policy.
# Configure Bearer authorization
api_key=api_key,
api_key=api_key.value,

# If you are in more than one LLM-enabled organization, you must
# specify your org ID in the form of a header. This is optional
# if you are only in one LLM-enabled org.
org_id=org_id,
org_id=org_id.value,
)

def get_client(self, *, model_name: str, **model_kwargs) -> NeMoLLMClient:
Expand Down
37 changes: 35 additions & 2 deletions morpheus/llm/services/openai_chat_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

from morpheus.llm.services.llm_service import LLMClient
from morpheus.llm.services.llm_service import LLMService
from morpheus.utils.env_config_value import EnvConfigValue

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -65,6 +66,20 @@ def set_output(self, output: typing.Any) -> None:
self.outputs = output


class OpenAIOrgId(EnvConfigValue):
    """OpenAI organization ID; when no value is supplied it is looked up in `OPENAI_ORG_ID`."""

    # A missing organization ID is tolerated: the wrapped value is simply `None`.
    _ALLOW_NONE: bool = True
    _ENV_KEY: str = "OPENAI_ORG_ID"


class OpenAIAPIKey(EnvConfigValue):
    """OpenAI API key; when no value is supplied it is looked up in `OPENAI_API_KEY`."""

    # `_ALLOW_NONE` is not overridden here, so the base-class default applies to a missing key.
    _ENV_KEY: str = "OPENAI_API_KEY"


class OpenAIBaseURL(EnvConfigValue):
    """OpenAI base URL; when no value is supplied it is looked up in `OPENAI_BASE_URL`."""

    # A missing base URL is tolerated: the wrapped value is simply `None`.
    _ALLOW_NONE: bool = True
    _ENV_KEY: str = "OPENAI_BASE_URL"


class OpenAIChatClient(LLMClient):
"""
Client for interacting with a specific OpenAI chat model. This class should be constructed with the
Expand Down Expand Up @@ -94,12 +109,24 @@ def __init__(self,
model_name: str,
set_assistant: bool = False,
max_retries: int = 10,
org_id: str | OpenAIOrgId = None,
api_key: str | OpenAIAPIKey = None,
base_url: str | OpenAIBaseURL = None,
**model_kwargs) -> None:
if IMPORT_EXCEPTION is not None:
raise ImportError(IMPORT_ERROR_MESSAGE) from IMPORT_EXCEPTION

super().__init__()

if not isinstance(org_id, OpenAIOrgId):
org_id = OpenAIOrgId(org_id)

if not isinstance(api_key, OpenAIOrgId):
api_key = OpenAIOrgId(api_key)

if not isinstance(base_url, OpenAIBaseURL):
base_url = OpenAIBaseURL(base_url)

assert parent is not None, "Parent service cannot be None."

self._parent = parent
Expand All @@ -113,8 +140,14 @@ def __init__(self,
self._model_kwargs = copy.deepcopy(model_kwargs)

# Create the client objects for both sync and async
self._client = openai.OpenAI(max_retries=max_retries)
self._client_async = openai.AsyncOpenAI(max_retries=max_retries)
self._client = openai.OpenAI(max_retries=max_retries,
organization=org_id.value,
api_key=api_key.value,
base_url=base_url.value)
self._client_async = openai.AsyncOpenAI(max_retries=max_retries,
organization=org_id.value,
api_key=api_key.value,
base_url=base_url.value)

def get_input_names(self) -> list[str]:
input_names = [self._prompt_key]
Expand Down
94 changes: 94 additions & 0 deletions morpheus/utils/env_config_value.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from abc import ABC
from enum import Enum


class EnvConfigValueSource(Enum):
ENV_DEFAULT = 1
CONSTRUCTOR = 2
ENV_OVERRIDE = 3


class EnvConfigValue(ABC):
"""
A wrapper for a string used as a configuration value which can be loaded from the system environment or injected via
the constructor. This class should be subclassed and the class fields `_ENV_KEY` and `_ENV_KEY_OVERRIDE` can be set
to enable environment-loading functionality. Convienience properties are available to check from where the value was
loaded.
"""

_ENV_KEY: str | None = None
_ENV_KEY_OVERRIDE: str | None = None
_ALLOW_NONE: bool = False

def __init__(self, value: str | None = None, use_env: bool = True):
"""
Parameters
----------
value : str, optional
The value to be contained in the EnvConfigValue. If the value is `None`, an attempt will be made to load it
from the environment using `_ENV_KEY`. if the `_ENV_KEY_OVERRIDE` field is not `None`, an attempt will be
made to load that environment variable in place of the passed-in value.
use_env : bool
If False, all environment-loading logic will be bypassed and the passed-in value will be used as-is.
defaults to True.
"""

self._source = EnvConfigValueSource.CONSTRUCTOR

if use_env:
if value is None and self.__class__._ENV_KEY is not None:
value = os.environ.get(self.__class__._ENV_KEY, None)
self._source = EnvConfigValueSource.ENV_DEFAULT

if self.__class__._ENV_KEY_OVERRIDE is not None and self.__class__._ENV_KEY_OVERRIDE in os.environ:
value = os.environ[self.__class__._ENV_KEY_OVERRIDE]
self._source = EnvConfigValueSource.ENV_OVERRIDE

if not self.__class__._ALLOW_NONE and value is None:

message = ("value must not be None, but provided value was None and no environment-based default or "
"override was found.")

if self.__class__._ENV_KEY is None:
raise ValueError(message)

raise ValueError(
f"{message} Try passing a value to the constructor, or setting the `{self.__class__._ENV_KEY}` "
"environment variable.")

else:
if not self.__class__._ALLOW_NONE and value is None:
raise ValueError("value must not be none")

assert isinstance(value, str) or value is None

self._value = value
self._use_env = use_env

@property
def source(self) -> EnvConfigValueSource:
return self._source

@property
def use_env(self) -> bool:
return self._use_env

@property
def value(self) -> str | None:
return self._value
2 changes: 1 addition & 1 deletion tests/llm/services/test_openai_chat_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def test_constructor(mock_chat_completion: tuple[mock.MagicMock, mock.MagicMock]
assert isinstance(client, LLMClient)

for mock_client in mock_chat_completion:
mock_client.assert_called_once_with(max_retries=max_retries)
mock_client.assert_called_once_with(max_retries=max_retries, organization=None, api_key=None, base_url=None)


@pytest.mark.parametrize("use_async", [True, False])
Expand Down
118 changes: 118 additions & 0 deletions tests/test_env_config_value.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from unittest import mock

import pytest

from morpheus.utils.env_config_value import EnvConfigValue
from morpheus.utils.env_config_value import EnvConfigValueSource


class EnvDrivenValue(EnvConfigValue):
    """Test subclass declaring both a default (`DEFAULT`) and an override (`OVERRIDE`) environment key."""

    _ENV_KEY_OVERRIDE = "OVERRIDE"
    _ENV_KEY = "DEFAULT"


def test_env_driven_value():
    """Check `EnvDrivenValue`, which declares both a default and an override environment key."""
    default_only_env = {"DEFAULT": "default.api.com"}
    override_only_env = {"OVERRIDE": "override.api.com"}

    with mock.patch.dict(os.environ, clear=True, values=default_only_env):

        # No constructor value: the `DEFAULT` environment variable supplies it.
        from_env = EnvDrivenValue()
        assert from_env.value == "default.api.com"
        assert from_env.source == EnvConfigValueSource.ENV_DEFAULT
        assert from_env.use_env

        # With environment loading disabled there is nothing to fall back on.
        with pytest.raises(ValueError):
            EnvDrivenValue(use_env=False)

        # An explicit value wins over the default environment variable.
        explicit = EnvDrivenValue("api.com")
        assert explicit.value == "api.com"
        assert explicit.source == EnvConfigValueSource.CONSTRUCTOR
        assert explicit.use_env

    with mock.patch.dict(os.environ, clear=True, values=override_only_env):

        # The override environment variable replaces even an explicit value.
        overridden = EnvDrivenValue("api.com")
        assert overridden.value == "override.api.com"
        assert overridden.source == EnvConfigValueSource.ENV_OVERRIDE
        assert overridden.use_env

        # Unless environment loading is disabled altogether.
        untouched = EnvDrivenValue("api.com", use_env=False)
        assert untouched.value == "api.com"
        assert untouched.source == EnvConfigValueSource.CONSTRUCTOR
        assert not untouched.use_env


class EnvDriverValueNoOverride(EnvConfigValue):
    """Test subclass declaring only a default environment key (`DEFAULT`) and no override key."""

    _ENV_KEY = "DEFAULT"


def test_env_driven_value_no_override():
    """Check `EnvDriverValueNoOverride`, which declares a default environment key but no override key."""
    default_only_env = {"DEFAULT": "default.api.com"}
    override_only_env = {"OVERRIDE": "override.api.com"}

    with mock.patch.dict(os.environ, clear=True, values=default_only_env):

        # No constructor value: the `DEFAULT` environment variable supplies it.
        from_env = EnvDriverValueNoOverride()
        assert from_env.value == "default.api.com"
        assert from_env.source == EnvConfigValueSource.ENV_DEFAULT
        assert from_env.use_env

        # With environment loading disabled there is nothing to fall back on.
        with pytest.raises(ValueError):
            EnvDriverValueNoOverride(use_env=False)

        # An explicit value wins over the default environment variable.
        explicit = EnvDriverValueNoOverride("api.com")
        assert explicit.value == "api.com"
        assert explicit.source == EnvConfigValueSource.CONSTRUCTOR
        assert explicit.use_env

    with mock.patch.dict(os.environ, clear=True, values=override_only_env):

        # `OVERRIDE` is set, but this class declares no override key, so it is ignored.
        explicit = EnvDriverValueNoOverride("api.com")
        assert explicit.value == "api.com"
        assert explicit.source == EnvConfigValueSource.CONSTRUCTOR
        assert explicit.use_env


class EnvDrivenValueNoDefault(EnvConfigValue):
    """Test subclass declaring only an override environment key (`OVERRIDE`) and no default key."""

    _ENV_KEY_OVERRIDE = "OVERRIDE"


def test_env_driven_value_no_default():
    """Check `EnvDrivenValueNoDefault`, which declares an override environment key but no default key."""
    default_only_env = {"DEFAULT": "default.api.com"}
    override_only_env = {"OVERRIDE": "override.api.com"}

    with mock.patch.dict(os.environ, clear=True, values=default_only_env):

        # `DEFAULT` is set, but this class declares no default key, so no value can be resolved.
        with pytest.raises(ValueError):
            EnvDrivenValueNoDefault()

        explicit = EnvDrivenValueNoDefault("api.com")
        assert explicit.value == "api.com"
        assert explicit.source == EnvConfigValueSource.CONSTRUCTOR
        assert explicit.use_env

    with mock.patch.dict(os.environ, clear=True, values=override_only_env):

        # The override environment variable replaces the explicit value.
        overridden = EnvDrivenValueNoDefault("api.com")
        assert overridden.value == "override.api.com"
        assert overridden.source == EnvConfigValueSource.ENV_OVERRIDE
        assert overridden.use_env


class EnvOptionalValue(EnvConfigValue):
    """Test subclass with no environment keys whose value is allowed to be `None`."""

    _ALLOW_NONE = True


def test_env_optional_value():
    """Check that a value permitting `None` can be built with no argument and no environment keys."""
    optional = EnvOptionalValue()

    # Nothing was loaded from the environment, so the source stays at the constructor.
    assert optional.value is None
    assert optional.source == EnvConfigValueSource.CONSTRUCTOR
    assert optional.use_env

0 comments on commit 6c722c7

Please sign in to comment.