From 51d9fd48083de32e9796f00030b57be10a8ac7bd Mon Sep 17 00:00:00 2001
From: Ben Browning
Date: Fri, 11 Jul 2025 16:38:27 -0400
Subject: [PATCH] fix: Don't cache clients for passthrough auth providers (#2728)

# What does this PR do?

Some of our inference providers support passthrough authentication via
`x-llamastack-provider-data` header values. This change stops the providers
that support passthrough auth from caching their clients to the backend
services (mostly OpenAI client instances), so the client connecting to
Llama Stack has to provide those auth values on each and every request.
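
For illustration, here is a minimal client-side sketch of passthrough auth (not part of this change): it assumes a locally running Llama Stack server like the one in the test plan below, and it assumes the OpenAI-compatible API is served under `/v1/openai/v1` (the exact base path may differ by deployment). The backend key travels in the `x-llamastack-provider-data` header, so the server builds a fresh backend client for every request instead of reusing a cached one.

```
import json

from openai import OpenAI

# Hypothetical usage sketch: the Together API key is supplied per request via the
# x-llamastack-provider-data header rather than being configured on the server.
client = OpenAI(
    base_url="http://localhost:8321/v1/openai/v1",  # assumed Llama Stack OpenAI-compatible base URL
    api_key="unused",  # placeholder; the backend auth comes from the header below
    default_headers={
        "x-llamastack-provider-data": json.dumps({"together_api_key": "<your Together API key>"}),
    },
)

response = client.chat.completions.create(
    model="together/meta-llama/Llama-3.1-8B-Instruct",
    messages=[{"role": "user", "content": "Hello!"}],
)
print(response.choices[0].message.content)
```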

## Test Plan

I added unit tests to ensure we're not caching clients across requests for all
of the providers fixed in this PR.

```
uv run pytest -sv tests/unit/providers/inference/test_inference_client_caching.py
```

I also ran some of our OpenAI-compatible API integration tests for each of the
changed providers, just to ensure they still work. Note that these providers
don't actually pass all of these tests (for unrelated reasons due to quirks of
the Groq and Together SaaS services), but enough of the tests passed to confirm
the clients are still working as intended.

### Together

```
ENABLE_TOGETHER="together" \
  uv run llama stack run llama_stack/templates/starter/run.yaml

LLAMA_STACK_CONFIG=http://localhost:8321 \
  uv run pytest -sv \
  tests/integration/inference/test_openai_completion.py \
  --text-model "together/meta-llama/Llama-3.1-8B-Instruct"
```

### OpenAI

```
ENABLE_OPENAI="openai" \
  uv run llama stack run llama_stack/templates/starter/run.yaml

LLAMA_STACK_CONFIG=http://localhost:8321 \
  uv run pytest -sv \
  tests/integration/inference/test_openai_completion.py \
  --text-model "openai/gpt-4o-mini"
```

### Groq

```
ENABLE_GROQ="groq" \
  uv run llama stack run llama_stack/templates/starter/run.yaml

LLAMA_STACK_CONFIG=http://localhost:8321 \
  uv run pytest -sv \
  tests/integration/inference/test_openai_completion.py \
  --text-model "groq/meta-llama/Llama-3.1-8B-Instruct"
```

---------

Signed-off-by: Ben Browning
---
 .../providers/remote/inference/groq/groq.py  | 14 +--
 .../remote/inference/openai/openai.py        | 14 +--
 .../remote/inference/together/together.py    | 47 ++++------
 pyproject.toml                               |  2 +
 .../test_inference_client_caching.py         | 73 +++++++++++++++
 uv.lock                                      | 91 +++++++++++++++++++
 6 files changed, 196 insertions(+), 45 deletions(-)
 create mode 100644 tests/unit/providers/inference/test_inference_client_caching.py

diff --git a/llama_stack/providers/remote/inference/groq/groq.py b/llama_stack/providers/remote/inference/groq/groq.py
index 4b295e788..91c6b6c17 100644
--- a/llama_stack/providers/remote/inference/groq/groq.py
+++ b/llama_stack/providers/remote/inference/groq/groq.py
@@ -38,24 +38,18 @@ class GroqInferenceAdapter(LiteLLMOpenAIMixin):
             provider_data_api_key_field="groq_api_key",
         )
         self.config = config
-        self._openai_client = None

     async def initialize(self):
         await super().initialize()

     async def shutdown(self):
         await super().shutdown()
-        if self._openai_client:
-            await self._openai_client.close()
-            self._openai_client = None

     def _get_openai_client(self) -> AsyncOpenAI:
-        if not self._openai_client:
-            self._openai_client = AsyncOpenAI(
-                base_url=f"{self.config.url}/openai/v1",
-                api_key=self.config.api_key,
-            )
-        return self._openai_client
+        return AsyncOpenAI(
+            base_url=f"{self.config.url}/openai/v1",
+            api_key=self.get_api_key(),
+        )

     async def openai_chat_completion(
         self,
diff --git a/llama_stack/providers/remote/inference/openai/openai.py b/llama_stack/providers/remote/inference/openai/openai.py
index 72428422f..818883919 100644
--- a/llama_stack/providers/remote/inference/openai/openai.py
+++ b/llama_stack/providers/remote/inference/openai/openai.py
@@ -59,9 +59,6 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
         # if we do not set this, users will be exposed to the
         # litellm specific model names, an abstraction leak.
         self.is_openai_compat = True
-        self._openai_client = AsyncOpenAI(
-            api_key=self.config.api_key,
-        )

     async def initialize(self) -> None:
         await super().initialize()
@@ -69,6 +66,11 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
     async def shutdown(self) -> None:
         await super().shutdown()

+    def _get_openai_client(self) -> AsyncOpenAI:
+        return AsyncOpenAI(
+            api_key=self.get_api_key(),
+        )
+
     async def openai_completion(
         self,
         model: str,
@@ -120,7 +122,7 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
             user=user,
             suffix=suffix,
         )
-        return await self._openai_client.completions.create(**params)
+        return await self._get_openai_client().completions.create(**params)

     async def openai_chat_completion(
         self,
@@ -176,7 +178,7 @@
             top_p=top_p,
             user=user,
         )
-        return await self._openai_client.chat.completions.create(**params)
+        return await self._get_openai_client().chat.completions.create(**params)

     async def openai_embeddings(
         self,
@@ -204,7 +206,7 @@
             params["user"] = user

         # Call OpenAI embeddings API
-        response = await self._openai_client.embeddings.create(**params)
+        response = await self._get_openai_client().embeddings.create(**params)

         data = []
         for i, embedding_data in enumerate(response.data):
diff --git a/llama_stack/providers/remote/inference/together/together.py b/llama_stack/providers/remote/inference/together/together.py
index 9e6877b7c..e1eb934c5 100644
--- a/llama_stack/providers/remote/inference/together/together.py
+++ b/llama_stack/providers/remote/inference/together/together.py
@@ -68,19 +68,12 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
     def __init__(self, config: TogetherImplConfig) -> None:
         ModelRegistryHelper.__init__(self, MODEL_ENTRIES)
         self.config = config
-        self._client = None
-        self._openai_client = None

     async def initialize(self) -> None:
         pass

     async def shutdown(self) -> None:
-        if self._client:
-            # Together client has no close method, so just set to None
-            self._client = None
-        if self._openai_client:
-            await self._openai_client.close()
-            self._openai_client = None
+        pass

     async def completion(
         self,
@@ -108,29 +101,25 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
         return await self._nonstream_completion(request)

     def _get_client(self) -> AsyncTogether:
-        if not self._client:
-            together_api_key = None
-            config_api_key = self.config.api_key.get_secret_value() if self.config.api_key else None
-            if config_api_key:
-                together_api_key = config_api_key
-            else:
-                provider_data = self.get_request_provider_data()
-                if provider_data is None or not provider_data.together_api_key:
-                    raise ValueError(
-                        'Pass Together API Key in the header X-LlamaStack-Provider-Data as { "together_api_key": }'
-                    )
-                together_api_key = provider_data.together_api_key
-            self._client = AsyncTogether(api_key=together_api_key)
-        return self._client
+        together_api_key = None
+        config_api_key = self.config.api_key.get_secret_value() if self.config.api_key else None
+        if config_api_key:
+            together_api_key = config_api_key
+        else:
+            provider_data = self.get_request_provider_data()
+            if provider_data is None or not provider_data.together_api_key:
+                raise ValueError(
+                    'Pass Together API Key in the header X-LlamaStack-Provider-Data as { "together_api_key": }'
+                )
+            together_api_key = provider_data.together_api_key
+        return AsyncTogether(api_key=together_api_key)

     def _get_openai_client(self) -> AsyncOpenAI:
-        if not self._openai_client:
-            together_client = self._get_client().client
-            self._openai_client = AsyncOpenAI(
-                base_url=together_client.base_url,
-                api_key=together_client.api_key,
-            )
-        return self._openai_client
+        together_client = self._get_client().client
+        return AsyncOpenAI(
+            base_url=together_client.base_url,
+            api_key=together_client.api_key,
+        )

     async def _nonstream_completion(self, request: CompletionRequest) -> ChatCompletionResponse:
         params = await self._get_params(request)
diff --git a/pyproject.toml b/pyproject.toml
index f4115d028..9977d7372 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -87,6 +87,8 @@ unit = [
     "blobfile",
     "faiss-cpu",
     "pymilvus>=2.5.12",
+    "litellm",
+    "together",
 ]
 # These are the core dependencies required for running integration tests. They are shared across all
 # providers. If a provider requires additional dependencies, please add them to your environment
diff --git a/tests/unit/providers/inference/test_inference_client_caching.py b/tests/unit/providers/inference/test_inference_client_caching.py
new file mode 100644
index 000000000..c9a931d47
--- /dev/null
+++ b/tests/unit/providers/inference/test_inference_client_caching.py
@@ -0,0 +1,73 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import json
+from unittest.mock import MagicMock
+
+from llama_stack.distribution.request_headers import request_provider_data_context
+from llama_stack.providers.remote.inference.groq.config import GroqConfig
+from llama_stack.providers.remote.inference.groq.groq import GroqInferenceAdapter
+from llama_stack.providers.remote.inference.openai.config import OpenAIConfig
+from llama_stack.providers.remote.inference.openai.openai import OpenAIInferenceAdapter
+from llama_stack.providers.remote.inference.together.config import TogetherImplConfig
+from llama_stack.providers.remote.inference.together.together import TogetherInferenceAdapter
+
+
+def test_groq_provider_openai_client_caching():
+    """Ensure the Groq provider does not cache api keys across client requests"""
+
+    config = GroqConfig()
+    inference_adapter = GroqInferenceAdapter(config)
+
+    inference_adapter.__provider_spec__ = MagicMock()
+    inference_adapter.__provider_spec__.provider_data_validator = (
+        "llama_stack.providers.remote.inference.groq.config.GroqProviderDataValidator"
+    )
+
+    for api_key in ["test1", "test2"]:
+        with request_provider_data_context(
+            {"x-llamastack-provider-data": json.dumps({inference_adapter.provider_data_api_key_field: api_key})}
+        ):
+            openai_client = inference_adapter._get_openai_client()
+            assert openai_client.api_key == api_key
+
+
+def test_openai_provider_openai_client_caching():
+    """Ensure the OpenAI provider does not cache api keys across client requests"""
+
+    config = OpenAIConfig()
+    inference_adapter = OpenAIInferenceAdapter(config)
+
+    inference_adapter.__provider_spec__ = MagicMock()
+    inference_adapter.__provider_spec__.provider_data_validator = (
+        "llama_stack.providers.remote.inference.openai.config.OpenAIProviderDataValidator"
+    )
+
+    for api_key in ["test1", "test2"]:
+        with request_provider_data_context(
+            {"x-llamastack-provider-data": json.dumps({inference_adapter.provider_data_api_key_field: api_key})}
+        ):
+            openai_client = inference_adapter._get_openai_client()
+            assert openai_client.api_key == api_key
+
+
+def test_together_provider_openai_client_caching():
+    """Ensure the Together provider does not cache api keys across client requests"""
+
+    config = TogetherImplConfig()
+    inference_adapter = TogetherInferenceAdapter(config)
+
+    inference_adapter.__provider_spec__ = MagicMock()
+    inference_adapter.__provider_spec__.provider_data_validator = (
+        "llama_stack.providers.remote.inference.together.TogetherProviderDataValidator"
+    )
+
+    for api_key in ["test1", "test2"]:
+        with request_provider_data_context({"x-llamastack-provider-data": json.dumps({"together_api_key": api_key})}):
+            together_client = inference_adapter._get_client()
+            assert together_client.client.api_key == api_key
+            openai_client = inference_adapter._get_openai_client()
+            assert openai_client.api_key == api_key
diff --git a/uv.lock b/uv.lock
index fe50f88aa..bca12fc51 100644
--- a/uv.lock
+++ b/uv.lock
@@ -615,6 +615,15 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/cb/a3/460c57f094a4a165c84a1341c373b0a4f5ec6ac244b998d5021aade89b77/ecdsa-0.19.1-py2.py3-none-any.whl", hash = "sha256:30638e27cf77b7e15c4c4cc1973720149e1033827cfd00661ca5c8cc0cdb24c3", size = 150607, upload-time = "2025-03-13T11:52:41.757Z" },
 ]

+[[package]]
+name = "eval-type-backport"
+version = "0.2.2"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/30/ea/8b0ac4469d4c347c6a385ff09dc3c048c2d021696664e26c7ee6791631b5/eval_type_backport-0.2.2.tar.gz", hash = "sha256:f0576b4cf01ebb5bd358d02314d31846af5e07678387486e2c798af0e7d849c1", size = 9079, upload-time = "2024-12-21T20:09:46.005Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/ce/31/55cd413eaccd39125368be33c46de24a1f639f2e12349b0361b4678f3915/eval_type_backport-0.2.2-py3-none-any.whl", hash = "sha256:cb6ad7c393517f476f96d456d0412ea80f0a8cf96f6892834cd9340149111b0a", size = 5830, upload-time = "2024-12-21T20:09:44.175Z" },
+]
+
 [[package]]
 name = "executing"
 version = "2.2.0"
@@ -1238,6 +1247,28 @@ version = "1.4"
 source = { registry = "https://pypi.org/simple" }
 sdist = { url = "https://files.pythonhosted.org/packages/65/c6/246100fa3967074d9725b3716913bd495823547bde5047050d4c3462f994/linkify-1.4.tar.gz", hash = "sha256:9ba276ba179525f7262820d90f009604e51cd4f1466c1112b882ef7eda243d5e", size = 1749, upload-time = "2009-11-12T21:42:00.934Z" }

+[[package]]
+name = "litellm"
+version = "1.74.2"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "aiohttp" },
+    { name = "click" },
+    { name = "httpx" },
+    { name = "importlib-metadata" },
+    { name = "jinja2" },
+    { name = "jsonschema" },
+    { name = "openai" },
+    { name = "pydantic" },
+    { name = "python-dotenv" },
+    { name = "tiktoken" },
+    { name = "tokenizers" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/38/10/63cdae1b1d581ad1db51153dfd06c4e18394a3ba8de495f73f2d797ece3b/litellm-1.74.2.tar.gz", hash = "sha256:cbacffe93976c60ca674fec0a942c70b900b4ad1c8069395174049a162f255bf", size = 9230641, upload-time = "2025-07-11T03:31:07.925Z" }
+wheels = [
upload-time = "2025-07-11T03:31:05.598Z" }, +] + [[package]] name = "llama-stack" version = "0.2.14" @@ -1341,6 +1372,7 @@ unit = [ { name = "blobfile" }, { name = "chardet" }, { name = "faiss-cpu" }, + { name = "litellm" }, { name = "mcp" }, { name = "openai" }, { name = "pymilvus" }, @@ -1348,6 +1380,7 @@ unit = [ { name = "qdrant-client" }, { name = "sqlalchemy", extra = ["asyncio"] }, { name = "sqlite-vec" }, + { name = "together" }, ] [package.metadata] @@ -1446,6 +1479,7 @@ unit = [ { name = "blobfile" }, { name = "chardet" }, { name = "faiss-cpu" }, + { name = "litellm" }, { name = "mcp" }, { name = "openai" }, { name = "pymilvus", specifier = ">=2.5.12" }, @@ -1454,6 +1488,7 @@ unit = [ { name = "sqlalchemy" }, { name = "sqlalchemy", extras = ["asyncio"], specifier = ">=2.0.41" }, { name = "sqlite-vec" }, + { name = "together" }, ] [[package]] @@ -2952,6 +2987,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/58/29/93c53c098d301132196c3238c312825324740851d77a8500a2462c0fd888/setuptools-80.8.0-py3-none-any.whl", hash = "sha256:95a60484590d24103af13b686121328cc2736bee85de8936383111e421b9edc0", size = 1201470, upload-time = "2025-05-20T14:02:51.348Z" }, ] +[[package]] +name = "shellingham" +version = "1.5.4" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/58/15/8b3609fd3830ef7b27b655beb4b4e9c62313a4e8da8c676e142cc210d58e/shellingham-1.5.4.tar.gz", hash = "sha256:8dbca0739d487e5bd35ab3ca4b36e11c4078f3a234bfce294b0a0291363404de", size = 10310, upload-time = "2023-10-24T04:13:40.426Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e0/f9/0595336914c5619e5f28a1fb793285925a8cd4b432c9da0a987836c7f822/shellingham-1.5.4-py2.py3-none-any.whl", hash = "sha256:7ecfff8f2fd72616f7481040475a65b2bf8af90a56c89140852d1120324e8686", size = 9755, upload-time = "2023-10-24T04:13:38.866Z" }, +] + [[package]] name = "six" version = "1.17.0" @@ -3384,6 +3428,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b2/fe/81695a1aa331a842b582453b605175f419fe8540355886031328089d840a/sympy-1.13.1-py3-none-any.whl", hash = "sha256:db36cdc64bf61b9b24578b6f7bab1ecdd2452cf008f34faa33776680c26d66f8", size = 6189177, upload-time = "2024-07-19T09:26:48.863Z" }, ] +[[package]] +name = "tabulate" +version = "0.9.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ec/fe/802052aecb21e3797b8f7902564ab6ea0d60ff8ca23952079064155d1ae1/tabulate-0.9.0.tar.gz", hash = "sha256:0095b12bf5966de529c0feb1fa08671671b3368eec77d7ef7ab114be2c068b3c", size = 81090, upload-time = "2022-10-06T17:21:48.54Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/40/44/4a5f08c96eb108af5cb50b41f76142f0afa346dfa99d5296fe7202a11854/tabulate-0.9.0-py3-none-any.whl", hash = "sha256:024ca478df22e9340661486f85298cff5f6dcdba14f3813e8830015b9ed1948f", size = 35252, upload-time = "2022-10-06T17:21:44.262Z" }, +] + [[package]] name = "tenacity" version = "9.1.2" @@ -3426,6 +3479,29 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/de/a8/8f499c179ec900783ffe133e9aab10044481679bb9aad78436d239eee716/tiktoken-0.9.0-cp313-cp313-win_amd64.whl", hash = "sha256:5ea0edb6f83dc56d794723286215918c1cde03712cbbafa0348b33448faf5b95", size = 894669, upload-time = "2025-02-14T06:02:47.341Z" }, ] +[[package]] +name = "together" +version = "1.5.21" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "aiohttp" }, + { name = "click" }, + { name = 
"eval-type-backport" }, + { name = "filelock" }, + { name = "numpy" }, + { name = "pillow" }, + { name = "pydantic" }, + { name = "requests" }, + { name = "rich" }, + { name = "tabulate" }, + { name = "tqdm" }, + { name = "typer" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/ea/53/e33c5e6d53c2e2bbd07f9dcb1979e27ac670fca0e4e238b169aa4c358ee2/together-1.5.21.tar.gz", hash = "sha256:59adb8cf4c5b77eca76b8c66a73c47c45fd828aaf4f059f33f893f8c5f68f85a", size = 69887, upload-time = "2025-07-10T21:04:43.781Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/fc/31/6556a303ea39929fa016f4260eef289b620cf366a576c304507cb75b4d12/together-1.5.21-py3-none-any.whl", hash = "sha256:35e6c0072033a2e5f1105de8781e969f41cffc85dae508b6f4dc293360026872", size = 96141, upload-time = "2025-07-10T21:04:42.418Z" }, +] + [[package]] name = "tokenizers" version = "0.21.1" @@ -3644,6 +3720,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/aa/22/733a6fc4a6445d835242f64c490fdd30f4a08d58f2b788613de3f9170692/transformers-4.50.3-py3-none-any.whl", hash = "sha256:6111610a43dec24ef32c3df0632c6b25b07d9711c01d9e1077bdd2ff6b14a38c", size = 10180411, upload-time = "2025-03-28T18:20:59.265Z" }, ] +[[package]] +name = "typer" +version = "0.15.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "click" }, + { name = "rich" }, + { name = "shellingham" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/6c/89/c527e6c848739be8ceb5c44eb8208c52ea3515c6cf6406aa61932887bf58/typer-0.15.4.tar.gz", hash = "sha256:89507b104f9b6a0730354f27c39fae5b63ccd0c95b1ce1f1a6ba0cfd329997c3", size = 101559, upload-time = "2025-05-14T16:34:57.704Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c9/62/d4ba7afe2096d5659ec3db8b15d8665bdcb92a3c6ff0b95e99895b335a9c/typer-0.15.4-py3-none-any.whl", hash = "sha256:eb0651654dcdea706780c466cf06d8f174405a659ffff8f163cfbfee98c0e173", size = 45258, upload-time = "2025-05-14T16:34:55.583Z" }, +] + [[package]] name = "types-requests" version = "2.32.0.20241016"