Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use EnvConfigValue for passing env-configured arguments to services #1704

Merged
merged 15 commits into from
May 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 35 additions & 7 deletions morpheus/llm/services/nemo_llm_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@

import asyncio
import logging
import os
import typing
import warnings

from morpheus.llm.services.llm_service import LLMClient
from morpheus.llm.services.llm_service import LLMService
from morpheus.utils.env_config_value import EnvConfigValue

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -190,7 +190,24 @@ class NeMoLLMService(LLMService):
A service for interacting with NeMo LLM models, this class should be used to create a client for a specific model.
"""

def __init__(self, *, api_key: str = None, org_id: str = None, retry_count=5) -> None:
class APIKey(EnvConfigValue):
    """NGC API key, loaded from the `NGC_API_KEY` environment variable when no value is passed in."""

    # Environment variable consulted when the constructor receives no value.
    _ENV_KEY: str = "NGC_API_KEY"
    # A resolved value of `None` is accepted rather than raising.
    _ALLOW_NONE: bool = True

class OrgId(EnvConfigValue):
    """NGC organization ID, loaded from the `NGC_ORG_ID` environment variable when no value is passed in."""

    # Environment variable consulted when the constructor receives no value.
    _ENV_KEY: str = "NGC_ORG_ID"
    # A resolved value of `None` is accepted rather than raising.
    _ALLOW_NONE: bool = True

class BaseURI(EnvConfigValue):
    """NGC API base URI, loaded from the `NGC_API_BASE` environment variable when no value is passed in."""

    # Environment variable consulted when the constructor receives no value.
    _ENV_KEY: str = "NGC_API_BASE"
    # A resolved value of `None` is accepted rather than raising.
    _ALLOW_NONE: bool = True

def __init__(self,
*,
api_key: APIKey | str = None,
org_id: OrgId | str = None,
base_uri: BaseURI | str = None,
retry_count=5) -> None:
"""
Creates a service for interacting with NeMo LLM models.

Expand All @@ -203,6 +220,10 @@ def __init__(self, *, api_key: str = None, org_id: str = None, retry_count=5) ->
The organization ID for the LLM service, by default None. If `None` the organization ID will be read from
the `NGC_ORG_ID` environment variable. This value is only required if the account associated with the
`api_key` is a member of multiple NGC organizations.
base_uri : str, optional
The base URI for the LLM service, by default None. If `None` the base URI will be read from
the `NGC_API_BASE` environment variable. If that variable is also unset, `None` is passed to the NeMo client
as its `api_host`.
retry_count : int, optional
The number of times to retry a request before raising an exception, by default 5

Expand All @@ -212,22 +233,29 @@ def __init__(self, *, api_key: str = None, org_id: str = None, retry_count=5) ->
raise ImportError(IMPORT_ERROR_MESSAGE) from IMPORT_EXCEPTION

super().__init__()
api_key = api_key if api_key is not None else os.environ.get("NGC_API_KEY", None)
org_id = org_id if org_id is not None else os.environ.get("NGC_ORG_ID", None)

if not isinstance(api_key, NeMoLLMService.APIKey):
api_key = NeMoLLMService.APIKey(api_key)

if not isinstance(org_id, NeMoLLMService.OrgId):
org_id = NeMoLLMService.OrgId(org_id)

if not isinstance(base_uri, NeMoLLMService.BaseURI):
base_uri = NeMoLLMService.BaseURI(base_uri)

self._retry_count = retry_count

self._conn = nemollm.NemoLLM(
api_host=os.environ.get("NGC_API_BASE", None),
api_host=base_uri.value,
# The client must configure the authentication and authorization parameters
# in accordance with the API server security policy.
# Configure Bearer authorization
api_key=api_key,
api_key=api_key.value,

# If you are in more than one LLM-enabled organization, you must
# specify your org ID in the form of a header. This is optional
# if you are only in one LLM-enabled org.
org_id=org_id,
org_id=org_id.value,
)

def get_client(self, *, model_name: str, **model_kwargs) -> NeMoLLMClient:
Expand Down
37 changes: 35 additions & 2 deletions morpheus/llm/services/openai_chat_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

from morpheus.llm.services.llm_service import LLMClient
from morpheus.llm.services.llm_service import LLMService
from morpheus.utils.env_config_value import EnvConfigValue

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -65,6 +66,20 @@ def set_output(self, output: typing.Any) -> None:
self.outputs = output


class OpenAIOrgId(EnvConfigValue):
    """OpenAI organization ID, loaded from the `OPENAI_ORG_ID` environment variable when no value is passed in."""

    # Environment variable consulted when the constructor receives no value.
    _ENV_KEY: str = "OPENAI_ORG_ID"
    # A resolved value of `None` is accepted rather than raising.
    _ALLOW_NONE: bool = True


class OpenAIAPIKey(EnvConfigValue):
    """
    OpenAI API key, loaded from the `OPENAI_API_KEY` environment variable when no value is passed in.

    `_ALLOW_NONE` is left at the base-class setting, so it is not relaxed for this value.
    """

    # Environment variable consulted when the constructor receives no value.
    _ENV_KEY: str = "OPENAI_API_KEY"


class OpenAIBaseURL(EnvConfigValue):
    """OpenAI base URL, loaded from the `OPENAI_BASE_URL` environment variable when no value is passed in."""

    # Environment variable consulted when the constructor receives no value.
    _ENV_KEY: str = "OPENAI_BASE_URL"
    # A resolved value of `None` is accepted rather than raising.
    _ALLOW_NONE: bool = True


class OpenAIChatClient(LLMClient):
"""
Client for interacting with a specific OpenAI chat model. This class should be constructed with the
Expand Down Expand Up @@ -94,12 +109,24 @@ def __init__(self,
model_name: str,
set_assistant: bool = False,
max_retries: int = 10,
org_id: str | OpenAIOrgId = None,
api_key: str | OpenAIAPIKey = None,
base_url: str | OpenAIBaseURL = None,
**model_kwargs) -> None:
if IMPORT_EXCEPTION is not None:
raise ImportError(IMPORT_ERROR_MESSAGE) from IMPORT_EXCEPTION

super().__init__()

if not isinstance(org_id, OpenAIOrgId):
org_id = OpenAIOrgId(org_id)

# Wrap the API key in its own config type. Wrapping it in `OpenAIOrgId` would make a missing key fall back to
# `OPENAI_ORG_ID` and would reject an `OpenAIAPIKey` instance passed by the caller.
# NOTE: `OpenAIAPIKey` does not allow `None`, so a `ValueError` is raised when neither the argument nor the
# `OPENAI_API_KEY` environment variable supplies a key.
if not isinstance(api_key, OpenAIAPIKey):
    api_key = OpenAIAPIKey(api_key)

if not isinstance(base_url, OpenAIBaseURL):
base_url = OpenAIBaseURL(base_url)

assert parent is not None, "Parent service cannot be None."

self._parent = parent
Expand All @@ -113,8 +140,14 @@ def __init__(self,
self._model_kwargs = copy.deepcopy(model_kwargs)

# Create the client objects for both sync and async
self._client = openai.OpenAI(max_retries=max_retries)
self._client_async = openai.AsyncOpenAI(max_retries=max_retries)
self._client = openai.OpenAI(max_retries=max_retries,
organization=org_id.value,
api_key=api_key.value,
base_url=base_url.value)
self._client_async = openai.AsyncOpenAI(max_retries=max_retries,
organization=org_id.value,
api_key=api_key.value,
base_url=base_url.value)

def get_input_names(self) -> list[str]:
input_names = [self._prompt_key]
Expand Down
94 changes: 94 additions & 0 deletions morpheus/utils/env_config_value.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from abc import ABC
from enum import Enum


class EnvConfigValueSource(Enum):
    """Identifies where an `EnvConfigValue` obtained its value."""

    # No value was passed in, so the `_ENV_KEY` environment variable was consulted.
    ENV_DEFAULT = 1
    # The value handed to the constructor was kept as-is.
    CONSTRUCTOR = 2
    # The `_ENV_KEY_OVERRIDE` environment variable replaced whatever was passed in.
    ENV_OVERRIDE = 3


class EnvConfigValue(ABC):
    """
    String-valued configuration setting that can be injected via the constructor or loaded from the system
    environment. Subclasses enable environment loading by setting the class fields `_ENV_KEY` (used when no value is
    passed in) and `_ENV_KEY_OVERRIDE` (replaces a passed-in value when present in the environment). `_ALLOW_NONE`
    controls whether a resolved value of `None` is accepted. Convenience properties report the resolved value and
    from where it was loaded.
    """

    _ENV_KEY: str | None = None
    _ENV_KEY_OVERRIDE: str | None = None
    _ALLOW_NONE: bool = False

    def __init__(self, value: str | None = None, use_env: bool = True):
        """
        Parameters
        ----------
        value : str, optional
            The value to wrap. When it is `None` and `_ENV_KEY` is set, the value is looked up in the environment
            under `_ENV_KEY`. When `_ENV_KEY_OVERRIDE` is set and present in the environment, that environment
            variable is used in place of the passed-in value.
        use_env : bool
            When False, no environment lookup is performed and the passed-in value is used as-is. Defaults to True.

        Raises
        ------
        ValueError
            If the resolved value is `None` and `_ALLOW_NONE` is False.
        """

        cls = type(self)
        self._source = EnvConfigValueSource.CONSTRUCTOR

        if use_env:
            default_key = cls._ENV_KEY
            override_key = cls._ENV_KEY_OVERRIDE

            # Fall back to the default environment variable only when nothing was passed in.
            if value is None and default_key is not None:
                value = os.environ.get(default_key, None)
                self._source = EnvConfigValueSource.ENV_DEFAULT

            # An override variable that is present wins over any value resolved so far.
            if override_key is not None and override_key in os.environ:
                value = os.environ[override_key]
                self._source = EnvConfigValueSource.ENV_OVERRIDE

        if value is None and not cls._ALLOW_NONE:
            if not use_env:
                raise ValueError("value must not be none")

            message = ("value must not be None, but provided value was None and no environment-based default or "
                       "override was found.")

            # Only suggest an environment variable when the subclass actually defines one.
            if cls._ENV_KEY is not None:
                message = (f"{message} Try passing a value to the constructor, or setting the `{cls._ENV_KEY}` "
                           "environment variable.")

            raise ValueError(message)

        assert value is None or isinstance(value, str)

        self._value = value
        self._use_env = use_env

    @property
    def source(self) -> EnvConfigValueSource:
        """Where the value was loaded from."""
        return self._source

    @property
    def use_env(self) -> bool:
        """The `use_env` flag that was passed to the constructor."""
        return self._use_env

    @property
    def value(self) -> str | None:
        """The resolved configuration value."""
        return self._value
2 changes: 1 addition & 1 deletion tests/llm/services/test_openai_chat_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def test_constructor(mock_chat_completion: tuple[mock.MagicMock, mock.MagicMock]
assert isinstance(client, LLMClient)

for mock_client in mock_chat_completion:
mock_client.assert_called_once_with(max_retries=max_retries)
mock_client.assert_called_once_with(max_retries=max_retries, organization=None, api_key=None, base_url=None)


@pytest.mark.parametrize("use_async", [True, False])
Expand Down
118 changes: 118 additions & 0 deletions tests/test_env_config_value.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from unittest import mock

import pytest

from morpheus.utils.env_config_value import EnvConfigValue
from morpheus.utils.env_config_value import EnvConfigValueSource


class EnvDrivenValue(EnvConfigValue):
    """Test value that defines both a default and an override environment key."""

    _ENV_KEY = "DEFAULT"
    _ENV_KEY_OVERRIDE = "OVERRIDE"


def test_env_driven_value():
    """Check value resolution for a class that defines both a default and an override environment key."""
    default_only_env = {"DEFAULT": "default.api.com"}
    override_only_env = {"OVERRIDE": "override.api.com"}

    with mock.patch.dict(os.environ, default_only_env, clear=True):

        # Nothing passed in: the default environment variable supplies the value.
        from_env = EnvDrivenValue()
        assert from_env.value == "default.api.com"
        assert from_env.source == EnvConfigValueSource.ENV_DEFAULT
        assert from_env.use_env

        # With environment loading disabled there is no value at all.
        with pytest.raises(ValueError):
            EnvDrivenValue(use_env=False)

        # An explicit value is kept when no override variable is set.
        explicit = EnvDrivenValue("api.com")
        assert explicit.value == "api.com"
        assert explicit.source == EnvConfigValueSource.CONSTRUCTOR
        assert explicit.use_env

    with mock.patch.dict(os.environ, override_only_env, clear=True):

        # The override variable replaces the explicit value.
        overridden = EnvDrivenValue("api.com")
        assert overridden.value == "override.api.com"
        assert overridden.source == EnvConfigValueSource.ENV_OVERRIDE
        assert overridden.use_env

        # Disabling environment loading also disables the override.
        not_overridden = EnvDrivenValue("api.com", use_env=False)
        assert not_overridden.value == "api.com"
        assert not_overridden.source == EnvConfigValueSource.CONSTRUCTOR
        assert not not_overridden.use_env


class EnvDriverValueNoOverride(EnvConfigValue):
    """Test value that defines only a default environment key."""

    _ENV_KEY = "DEFAULT"


def test_env_driven_value_no_override():
    """Check value resolution for a class that defines a default environment key but no override key."""
    default_only_env = {"DEFAULT": "default.api.com"}
    override_only_env = {"OVERRIDE": "override.api.com"}

    with mock.patch.dict(os.environ, default_only_env, clear=True):

        # Nothing passed in: the default environment variable supplies the value.
        from_env = EnvDriverValueNoOverride()
        assert from_env.value == "default.api.com"
        assert from_env.source == EnvConfigValueSource.ENV_DEFAULT
        assert from_env.use_env

        # With environment loading disabled there is no value at all.
        with pytest.raises(ValueError):
            EnvDriverValueNoOverride(use_env=False)

        explicit = EnvDriverValueNoOverride("api.com")
        assert explicit.value == "api.com"
        assert explicit.source == EnvConfigValueSource.CONSTRUCTOR
        assert explicit.use_env

    with mock.patch.dict(os.environ, override_only_env, clear=True):

        # The class declares no override key, so the `OVERRIDE` variable is ignored.
        untouched = EnvDriverValueNoOverride("api.com")
        assert untouched.value == "api.com"
        assert untouched.source == EnvConfigValueSource.CONSTRUCTOR
        assert untouched.use_env


class EnvDrivenValueNoDefault(EnvConfigValue):
    """Test value that defines only an override environment key."""

    _ENV_KEY_OVERRIDE = "OVERRIDE"


def test_env_driven_value_no_default():
    """Check value resolution for a class that defines an override environment key but no default key."""
    default_only_env = {"DEFAULT": "default.api.com"}
    override_only_env = {"OVERRIDE": "override.api.com"}

    with mock.patch.dict(os.environ, default_only_env, clear=True):

        # The class declares no default key, so the `DEFAULT` variable cannot supply a value.
        with pytest.raises(ValueError):
            EnvDrivenValueNoDefault()

        explicit = EnvDrivenValueNoDefault("api.com")
        assert explicit.value == "api.com"
        assert explicit.source == EnvConfigValueSource.CONSTRUCTOR
        assert explicit.use_env

    with mock.patch.dict(os.environ, override_only_env, clear=True):

        # The override variable replaces the explicit value.
        overridden = EnvDrivenValueNoDefault("api.com")
        assert overridden.value == "override.api.com"
        assert overridden.source == EnvConfigValueSource.ENV_OVERRIDE
        assert overridden.use_env


class EnvOptionalValue(EnvConfigValue):
    """Test value with no environment keys that accepts `None`."""

    _ALLOW_NONE = True


def test_env_optional_value():
    """A value that allows `None` and declares no environment keys resolves to `None` without raising."""
    optional = EnvOptionalValue()

    assert optional.value is None
    # No environment key is declared, so the source stays at the constructor.
    assert optional.source == EnvConfigValueSource.CONSTRUCTOR
    assert optional.use_env
Loading