From 2ac180a3adc5cb6b3fc3083e0e268034f1d42669 Mon Sep 17 00:00:00 2001 From: Philippe Modard Date: Wed, 6 Dec 2023 12:18:04 +0100 Subject: [PATCH] Use a new KAGGLE_GRPC_DATA_PROXY_URL env variable for gRPC proxying (#1337) http://b/308644984 --------- Co-authored-by: Prathamesh Bang --- patches/sitecustomize.py | 14 +++++++++----- tests/test_google_generativeai_patch.py | 1 + tests/test_google_generativeai_patch_disabled.py | 11 +++++++---- 3 files changed, 17 insertions(+), 9 deletions(-) diff --git a/patches/sitecustomize.py b/patches/sitecustomize.py index 6ac7400e..ea47698b 100644 --- a/patches/sitecustomize.py +++ b/patches/sitecustomize.py @@ -81,8 +81,9 @@ def post_import_logic(module): if os.getenv('KAGGLE_DISABLE_GOOGLE_GENERATIVE_AI_INTEGRATION') != None: return if (os.getenv('KAGGLE_DATA_PROXY_TOKEN') == None or - os.getenv('KAGGLE_USER_SECRETS_TOKEN') == None or - os.getenv('KAGGLE_DATA_PROXY_URL') == None): + os.getenv('KAGGLE_USER_SECRETS_TOKEN') == None or + (os.getenv('KAGGLE_DATA_PROXY_URL') == None and + os.getenv('KAGGLE_GRPC_DATA_PROXY_URL') == None)): return old_configure = module.configure @@ -101,12 +102,15 @@ def new_configure(*args, **kwargs): client_options = kwargs['client_options'] else: client_options = {} - client_options['api_endpoint'] = os.environ['KAGGLE_DATA_PROXY_URL'] + if os.getenv('KAGGLE_GOOGLE_GENERATIVE_AI_USE_REST_ONLY') != None: - client_options['api_endpoint'] += '/palmapi' kwargs['transport'] = 'rest' - elif 'transport' in kwargs and kwargs['transport'] == 'rest': + + if 'transport' in kwargs and kwargs['transport'] == 'rest': + client_options['api_endpoint'] = os.environ['KAGGLE_DATA_PROXY_URL'] client_options['api_endpoint'] += '/palmapi' + else: + client_options['api_endpoint'] = os.environ['KAGGLE_GRPC_DATA_PROXY_URL'] kwargs['client_options'] = client_options old_configure(*args, **kwargs) diff --git a/tests/test_google_generativeai_patch.py b/tests/test_google_generativeai_patch.py index 68e766c2..87ac4ecd 100644 --- a/tests/test_google_generativeai_patch.py +++ b/tests/test_google_generativeai_patch.py @@ -33,6 +33,7 @@ def test_proxy_enabled(self): env.set("KAGGLE_USER_SECRETS_TOKEN", secrets_token) env.set("KAGGLE_DATA_PROXY_TOKEN", proxy_token) env.set("KAGGLE_DATA_PROXY_URL", self.endpoint) + env.set("KAGGLE_GRPC_DATA_PROXY_URL", "http://127.0.0.1:50001") env.set("KAGGLE_GOOGLE_GENERATIVE_AI_USE_REST_ONLY", "True") server_address = urlparse(self.endpoint) with env: diff --git a/tests/test_google_generativeai_patch_disabled.py b/tests/test_google_generativeai_patch_disabled.py index 65e02845..2f34af21 100644 --- a/tests/test_google_generativeai_patch_disabled.py +++ b/tests/test_google_generativeai_patch_disabled.py @@ -14,21 +14,24 @@ def do_HEAD(self): self.send_response(200) def do_GET(self): + print('YO MOD', self.path) HTTPHandler.called = True self.send_response(200) self.send_header("Content-type", "application/json") self.end_headers() class TestGoogleGenerativeAiPatchDisabled(unittest.TestCase): - endpoint = "http://127.0.0.1:80" + http_endpoint = "http://127.0.0.1:80" + grpc_endpoint = "http://127.0.0.1:50001" def test_disabled(self): env = EnvironmentVarGuard() env.set("KAGGLE_USER_SECRETS_TOKEN", "foobar") env.set("KAGGLE_DATA_PROXY_TOKEN", "foobar") - env.set("KAGGLE_DATA_PROXY_URL", self.endpoint) + env.set("KAGGLE_DATA_PROXY_URL", self.http_endpoint) + env.set("KAGGLE_GRPC_DATA_PROXY_URL", self.grpc_endpoint) env.set("KAGGLE_DISABLE_GOOGLE_GENERATIVE_AI_INTEGRATION", "True") - server_address = urlparse(self.endpoint) + server_address = urlparse(self.http_endpoint) with env: with HTTPServer((server_address.hostname, server_address.port), HTTPHandler) as httpd: threading.Thread(target=httpd.serve_forever).start() @@ -40,4 +43,4 @@ def test_disabled(self): except: pass httpd.shutdown() - self.assertFalse(HTTPHandler.called) \ No newline at end of file + self.assertFalse(HTTPHandler.called)