From fa9e2dd5433d2b0c9a1a2a4e555dfc6528ba1a61 Mon Sep 17 00:00:00 2001
From: Ben Browning
Date: Fri, 11 Jul 2025 10:11:31 -0400
Subject: [PATCH] fix: Don't cache clients for passthrough auth providers

Some of our inference providers support passthrough authentication via
`x-llamastack-provider-data` header values. This fixes the providers that
support passthrough auth to not cache their clients to the backend providers
(mostly OpenAI client instances) so that the client connecting to Llama Stack
has to provide those auth values on each and every request.

Signed-off-by: Ben Browning
---
 .../providers/remote/inference/groq/groq.py   | 14 +---
 .../remote/inference/openai/openai.py         | 14 ++--
 .../remote/inference/together/together.py     | 47 +++++-------
 .../test_inference_client_caching.py          | 73 +++++++++++++++++++
 4 files changed, 103 insertions(+), 45 deletions(-)
 create mode 100644 tests/unit/providers/inference/test_inference_client_caching.py

diff --git a/llama_stack/providers/remote/inference/groq/groq.py b/llama_stack/providers/remote/inference/groq/groq.py
index 4b295e788..91c6b6c17 100644
--- a/llama_stack/providers/remote/inference/groq/groq.py
+++ b/llama_stack/providers/remote/inference/groq/groq.py
@@ -38,24 +38,18 @@ class GroqInferenceAdapter(LiteLLMOpenAIMixin):
             provider_data_api_key_field="groq_api_key",
         )
         self.config = config
-        self._openai_client = None
 
     async def initialize(self):
         await super().initialize()
 
     async def shutdown(self):
         await super().shutdown()
-        if self._openai_client:
-            await self._openai_client.close()
-            self._openai_client = None
 
     def _get_openai_client(self) -> AsyncOpenAI:
-        if not self._openai_client:
-            self._openai_client = AsyncOpenAI(
-                base_url=f"{self.config.url}/openai/v1",
-                api_key=self.config.api_key,
-            )
-        return self._openai_client
+        return AsyncOpenAI(
+            base_url=f"{self.config.url}/openai/v1",
+            api_key=self.get_api_key(),
+        )
 
     async def openai_chat_completion(
         self,
diff --git a/llama_stack/providers/remote/inference/openai/openai.py b/llama_stack/providers/remote/inference/openai/openai.py
index 72428422f..818883919 100644
--- a/llama_stack/providers/remote/inference/openai/openai.py
+++ b/llama_stack/providers/remote/inference/openai/openai.py
@@ -59,9 +59,6 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
         # if we do not set this, users will be exposed to the
         # litellm specific model names, an abstraction leak.
         self.is_openai_compat = True
-        self._openai_client = AsyncOpenAI(
-            api_key=self.config.api_key,
-        )
 
     async def initialize(self) -> None:
         await super().initialize()
@@ -69,6 +66,11 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
     async def shutdown(self) -> None:
         await super().shutdown()
 
+    def _get_openai_client(self) -> AsyncOpenAI:
+        return AsyncOpenAI(
+            api_key=self.get_api_key(),
+        )
+
     async def openai_completion(
         self,
         model: str,
@@ -120,7 +122,7 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
             user=user,
             suffix=suffix,
         )
-        return await self._openai_client.completions.create(**params)
+        return await self._get_openai_client().completions.create(**params)
 
     async def openai_chat_completion(
         self,
@@ -176,7 +178,7 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
             top_p=top_p,
             user=user,
         )
-        return await self._openai_client.chat.completions.create(**params)
+        return await self._get_openai_client().chat.completions.create(**params)
 
     async def openai_embeddings(
         self,
@@ -204,7 +206,7 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
             params["user"] = user
 
         # Call OpenAI embeddings API
-        response = await self._openai_client.embeddings.create(**params)
+        response = await self._get_openai_client().embeddings.create(**params)
 
         data = []
         for i, embedding_data in enumerate(response.data):
diff --git a/llama_stack/providers/remote/inference/together/together.py b/llama_stack/providers/remote/inference/together/together.py
index 9e6877b7c..e1eb934c5 100644
--- a/llama_stack/providers/remote/inference/together/together.py
+++ b/llama_stack/providers/remote/inference/together/together.py
@@ -68,19 +68,12 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
     def __init__(self, config: TogetherImplConfig) -> None:
         ModelRegistryHelper.__init__(self, MODEL_ENTRIES)
         self.config = config
-        self._client = None
-        self._openai_client = None
 
     async def initialize(self) -> None:
         pass
 
     async def shutdown(self) -> None:
-        if self._client:
-            # Together client has no close method, so just set to None
-            self._client = None
-        if self._openai_client:
-            await self._openai_client.close()
-            self._openai_client = None
+        pass
 
     async def completion(
         self,
@@ -108,29 +101,25 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
         return await self._nonstream_completion(request)
 
     def _get_client(self) -> AsyncTogether:
-        if not self._client:
-            together_api_key = None
-            config_api_key = self.config.api_key.get_secret_value() if self.config.api_key else None
-            if config_api_key:
-                together_api_key = config_api_key
-            else:
-                provider_data = self.get_request_provider_data()
-                if provider_data is None or not provider_data.together_api_key:
-                    raise ValueError(
-                        'Pass Together API Key in the header X-LlamaStack-Provider-Data as { "together_api_key": <your api key>}'
-                    )
-                together_api_key = provider_data.together_api_key
-            self._client = AsyncTogether(api_key=together_api_key)
-        return self._client
+        together_api_key = None
+        config_api_key = self.config.api_key.get_secret_value() if self.config.api_key else None
+        if config_api_key:
+            together_api_key = config_api_key
+        else:
+            provider_data = self.get_request_provider_data()
+            if provider_data is None or not provider_data.together_api_key:
+                raise ValueError(
+                    'Pass Together API Key in the header X-LlamaStack-Provider-Data as { "together_api_key": <your api key>}'
+                )
+            together_api_key = provider_data.together_api_key
+        return AsyncTogether(api_key=together_api_key)
 
     def _get_openai_client(self) -> AsyncOpenAI:
-        if not self._openai_client:
-            together_client = self._get_client().client
-            self._openai_client = AsyncOpenAI(
-                base_url=together_client.base_url,
-                api_key=together_client.api_key,
-            )
-        return self._openai_client
+        together_client = self._get_client().client
+        return AsyncOpenAI(
+            base_url=together_client.base_url,
+            api_key=together_client.api_key,
+        )
 
     async def _nonstream_completion(self, request: CompletionRequest) -> ChatCompletionResponse:
         params = await self._get_params(request)
diff --git a/tests/unit/providers/inference/test_inference_client_caching.py b/tests/unit/providers/inference/test_inference_client_caching.py
new file mode 100644
index 000000000..c9a931d47
--- /dev/null
+++ b/tests/unit/providers/inference/test_inference_client_caching.py
@@ -0,0 +1,73 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import json
+from unittest.mock import MagicMock
+
+from llama_stack.distribution.request_headers import request_provider_data_context
+from llama_stack.providers.remote.inference.groq.config import GroqConfig
+from llama_stack.providers.remote.inference.groq.groq import GroqInferenceAdapter
+from llama_stack.providers.remote.inference.openai.config import OpenAIConfig
+from llama_stack.providers.remote.inference.openai.openai import OpenAIInferenceAdapter
+from llama_stack.providers.remote.inference.together.config import TogetherImplConfig
+from llama_stack.providers.remote.inference.together.together import TogetherInferenceAdapter
+
+
+def test_groq_provider_openai_client_caching():
+    """Ensure the Groq provider does not cache api keys across client requests"""
+
+    config = GroqConfig()
+    inference_adapter = GroqInferenceAdapter(config)
+
+    inference_adapter.__provider_spec__ = MagicMock()
+    inference_adapter.__provider_spec__.provider_data_validator = (
+        "llama_stack.providers.remote.inference.groq.config.GroqProviderDataValidator"
+    )
+
+    for api_key in ["test1", "test2"]:
+        with request_provider_data_context(
+            {"x-llamastack-provider-data": json.dumps({inference_adapter.provider_data_api_key_field: api_key})}
+        ):
+            openai_client = inference_adapter._get_openai_client()
+            assert openai_client.api_key == api_key
+
+
+def test_openai_provider_openai_client_caching():
+    """Ensure the OpenAI provider does not cache api keys across client requests"""
+
+    config = OpenAIConfig()
+    inference_adapter = OpenAIInferenceAdapter(config)
+
+    inference_adapter.__provider_spec__ = MagicMock()
+    inference_adapter.__provider_spec__.provider_data_validator = (
+        "llama_stack.providers.remote.inference.openai.config.OpenAIProviderDataValidator"
+    )
+
+    for api_key in ["test1", "test2"]:
+        with request_provider_data_context(
+            {"x-llamastack-provider-data": json.dumps({inference_adapter.provider_data_api_key_field: api_key})}
+        ):
+            openai_client = inference_adapter._get_openai_client()
+            assert openai_client.api_key == api_key
+
+
+def test_together_provider_openai_client_caching():
+    """Ensure the Together provider does not cache api keys across client requests"""
+
+    config = TogetherImplConfig()
+    inference_adapter = TogetherInferenceAdapter(config)
+
+    inference_adapter.__provider_spec__ = MagicMock()
+    inference_adapter.__provider_spec__.provider_data_validator = (
+        "llama_stack.providers.remote.inference.together.TogetherProviderDataValidator"
+    )
+
+    for api_key in ["test1", "test2"]:
+        with request_provider_data_context({"x-llamastack-provider-data": json.dumps({"together_api_key": api_key})}):
+            together_client = inference_adapter._get_client()
+            assert together_client.client.api_key == api_key
+            openai_client = inference_adapter._get_openai_client()
+            assert openai_client.api_key == api_key
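
The unit tests above verify the fix indirectly, through each adapter's client getter. The toy classes below are not part of the patch and do not exist in llama_stack; they are a minimal sketch of the failure mode the change removes: a cached backend client keeps the first request's key, while a per-request client always reflects the current caller's key.

# Illustration only; invented names, not llama_stack APIs.

class CachedClientAdapter:
    """Caches the backend client, so the first caller's key leaks to later callers."""

    def __init__(self) -> None:
        self._client: dict | None = None

    def get_client(self, per_request_key: str) -> dict:
        if self._client is None:
            self._client = {"api_key": per_request_key}
        return self._client


class PerRequestClientAdapter:
    """Builds a fresh client on every call, so each request uses its own key."""

    def get_client(self, per_request_key: str) -> dict:
        return {"api_key": per_request_key}


cached = CachedClientAdapter()
cached.get_client("key-from-user-a")
# The second caller silently reuses the first caller's credentials.
assert cached.get_client("key-from-user-b")["api_key"] == "key-from-user-a"

fresh = PerRequestClientAdapter()
# Constructing the client per request picks up the right key each time.
assert fresh.get_client("key-from-user-b")["api_key"] == "key-from-user-b"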
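
For context, a caller exercises this passthrough path by JSON-encoding the provider's key field (e.g. "groq_api_key", "together_api_key") into the `x-llamastack-provider-data` header on every request. The sketch below assumes a locally running Llama Stack server; the base URL, endpoint path, and model identifier are illustrative assumptions, not taken from this patch.

import json

import requests

# Assumed deployment details -- adjust to your Llama Stack instance.
BASE_URL = "http://localhost:8321"                   # assumed server address
MODEL_ID = "meta-llama/Llama-3.3-70B-Instruct"       # assumed registered model

# Per-request passthrough credentials, read by the Together provider.
provider_data = {"together_api_key": "request-scoped-together-key"}

response = requests.post(
    f"{BASE_URL}/v1/inference/chat-completion",      # assumed endpoint path
    headers={"x-llamastack-provider-data": json.dumps(provider_data)},
    json={
        "model_id": MODEL_ID,
        "messages": [{"role": "user", "content": "Hello!"}],
    },
)
print(response.json())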