diff --git a/llama_stack/providers/remote/inference/groq/groq.py b/llama_stack/providers/remote/inference/groq/groq.py
index 4b295e788..91c6b6c17 100644
--- a/llama_stack/providers/remote/inference/groq/groq.py
+++ b/llama_stack/providers/remote/inference/groq/groq.py
@@ -38,24 +38,18 @@ class GroqInferenceAdapter(LiteLLMOpenAIMixin):
             provider_data_api_key_field="groq_api_key",
         )
         self.config = config
-        self._openai_client = None

     async def initialize(self):
         await super().initialize()

     async def shutdown(self):
         await super().shutdown()
-        if self._openai_client:
-            await self._openai_client.close()
-            self._openai_client = None

     def _get_openai_client(self) -> AsyncOpenAI:
-        if not self._openai_client:
-            self._openai_client = AsyncOpenAI(
-                base_url=f"{self.config.url}/openai/v1",
-                api_key=self.config.api_key,
-            )
-        return self._openai_client
+        return AsyncOpenAI(
+            base_url=f"{self.config.url}/openai/v1",
+            api_key=self.get_api_key(),
+        )

     async def openai_chat_completion(
         self,
diff --git a/llama_stack/providers/remote/inference/openai/openai.py b/llama_stack/providers/remote/inference/openai/openai.py
index 72428422f..818883919 100644
--- a/llama_stack/providers/remote/inference/openai/openai.py
+++ b/llama_stack/providers/remote/inference/openai/openai.py
@@ -59,9 +59,6 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
         # if we do not set this, users will be exposed to the
         # litellm specific model names, an abstraction leak.
         self.is_openai_compat = True
-        self._openai_client = AsyncOpenAI(
-            api_key=self.config.api_key,
-        )

     async def initialize(self) -> None:
         await super().initialize()
@@ -69,6 +66,11 @@
     async def shutdown(self) -> None:
         await super().shutdown()

+    def _get_openai_client(self) -> AsyncOpenAI:
+        return AsyncOpenAI(
+            api_key=self.get_api_key(),
+        )
+
     async def openai_completion(
         self,
         model: str,
@@ -120,7 +122,7 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
             user=user,
             suffix=suffix,
         )
-        return await self._openai_client.completions.create(**params)
+        return await self._get_openai_client().completions.create(**params)

     async def openai_chat_completion(
         self,
@@ -176,7 +178,7 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
             top_p=top_p,
             user=user,
         )
-        return await self._openai_client.chat.completions.create(**params)
+        return await self._get_openai_client().chat.completions.create(**params)

     async def openai_embeddings(
         self,
@@ -204,7 +206,7 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
             params["user"] = user

         # Call OpenAI embeddings API
-        response = await self._openai_client.embeddings.create(**params)
+        response = await self._get_openai_client().embeddings.create(**params)

         data = []
         for i, embedding_data in enumerate(response.data):
diff --git a/llama_stack/providers/remote/inference/together/together.py b/llama_stack/providers/remote/inference/together/together.py
index 9e6877b7c..e1eb934c5 100644
--- a/llama_stack/providers/remote/inference/together/together.py
+++ b/llama_stack/providers/remote/inference/together/together.py
@@ -68,19 +68,12 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
     def __init__(self, config: TogetherImplConfig) -> None:
         ModelRegistryHelper.__init__(self, MODEL_ENTRIES)
         self.config = config
-        self._client = None
-        self._openai_client = None

     async def initialize(self) -> None:
         pass

     async def shutdown(self) -> None:
-        if self._client:
-            # Together client has no close method, so just set to None
-            self._client = None
-        if self._openai_client:
-            await self._openai_client.close()
-            self._openai_client = None
+        pass

     async def completion(
         self,
@@ -108,29 +101,25 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
         return await self._nonstream_completion(request)

     def _get_client(self) -> AsyncTogether:
-        if not self._client:
-            together_api_key = None
-            config_api_key = self.config.api_key.get_secret_value() if self.config.api_key else None
-            if config_api_key:
-                together_api_key = config_api_key
-            else:
-                provider_data = self.get_request_provider_data()
-                if provider_data is None or not provider_data.together_api_key:
-                    raise ValueError(
-                        'Pass Together API Key in the header X-LlamaStack-Provider-Data as { "together_api_key": <your api key>}'
-                    )
-                together_api_key = provider_data.together_api_key
-            self._client = AsyncTogether(api_key=together_api_key)
-        return self._client
+        together_api_key = None
+        config_api_key = self.config.api_key.get_secret_value() if self.config.api_key else None
+        if config_api_key:
+            together_api_key = config_api_key
+        else:
+            provider_data = self.get_request_provider_data()
+            if provider_data is None or not provider_data.together_api_key:
+                raise ValueError(
+                    'Pass Together API Key in the header X-LlamaStack-Provider-Data as { "together_api_key": <your api key>}'
+                )
+            together_api_key = provider_data.together_api_key
+        return AsyncTogether(api_key=together_api_key)

     def _get_openai_client(self) -> AsyncOpenAI:
-        if not self._openai_client:
-            together_client = self._get_client().client
-            self._openai_client = AsyncOpenAI(
-                base_url=together_client.base_url,
-                api_key=together_client.api_key,
-            )
-        return self._openai_client
+        together_client = self._get_client().client
+        return AsyncOpenAI(
+            base_url=together_client.base_url,
+            api_key=together_client.api_key,
+        )

     async def _nonstream_completion(self, request: CompletionRequest) -> ChatCompletionResponse:
         params = await self._get_params(request)
diff --git a/tests/unit/providers/inference/test_inference_client_caching.py b/tests/unit/providers/inference/test_inference_client_caching.py
new file mode 100644
index 000000000..c9a931d47
--- /dev/null
+++ b/tests/unit/providers/inference/test_inference_client_caching.py
@@ -0,0 +1,73 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import json
+from unittest.mock import MagicMock
+
+from llama_stack.distribution.request_headers import request_provider_data_context
+from llama_stack.providers.remote.inference.groq.config import GroqConfig
+from llama_stack.providers.remote.inference.groq.groq import GroqInferenceAdapter
+from llama_stack.providers.remote.inference.openai.config import OpenAIConfig
+from llama_stack.providers.remote.inference.openai.openai import OpenAIInferenceAdapter
+from llama_stack.providers.remote.inference.together.config import TogetherImplConfig
+from llama_stack.providers.remote.inference.together.together import TogetherInferenceAdapter
+
+
+def test_groq_provider_openai_client_caching():
+    """Ensure the Groq provider does not cache api keys across client requests"""
+
+    config = GroqConfig()
+    inference_adapter = GroqInferenceAdapter(config)
+
+    inference_adapter.__provider_spec__ = MagicMock()
+    inference_adapter.__provider_spec__.provider_data_validator = (
+        "llama_stack.providers.remote.inference.groq.config.GroqProviderDataValidator"
+    )
+
+    for api_key in ["test1", "test2"]:
+        with request_provider_data_context(
+            {"x-llamastack-provider-data": json.dumps({inference_adapter.provider_data_api_key_field: api_key})}
+        ):
+            openai_client = inference_adapter._get_openai_client()
+            assert openai_client.api_key == api_key
+
+
+def test_openai_provider_openai_client_caching():
+    """Ensure the OpenAI provider does not cache api keys across client requests"""
+
+    config = OpenAIConfig()
+    inference_adapter = OpenAIInferenceAdapter(config)
+
+    inference_adapter.__provider_spec__ = MagicMock()
+    inference_adapter.__provider_spec__.provider_data_validator = (
+        "llama_stack.providers.remote.inference.openai.config.OpenAIProviderDataValidator"
+    )
+
+    for api_key in ["test1", "test2"]:
+        with request_provider_data_context(
+            {"x-llamastack-provider-data": json.dumps({inference_adapter.provider_data_api_key_field: api_key})}
+        ):
+            openai_client = inference_adapter._get_openai_client()
+            assert openai_client.api_key == api_key
+
+
+def test_together_provider_openai_client_caching():
+    """Ensure the Together provider does not cache api keys across client requests"""
+
+    config = TogetherImplConfig()
+    inference_adapter = TogetherInferenceAdapter(config)
+
+    inference_adapter.__provider_spec__ = MagicMock()
+    inference_adapter.__provider_spec__.provider_data_validator = (
+        "llama_stack.providers.remote.inference.together.TogetherProviderDataValidator"
+    )
+
+    for api_key in ["test1", "test2"]:
+        with request_provider_data_context({"x-llamastack-provider-data": json.dumps({"together_api_key": api_key})}):
+            together_client = inference_adapter._get_client()
+            assert together_client.client.api_key == api_key
+            openai_client = inference_adapter._get_openai_client()
+            assert openai_client.api_key == api_key
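
Reviewer note: the behavior these tests pin down, restated as a short runnable sketch. It reuses only names that already appear in this diff and its tests; the MagicMock'd __provider_spec__ stands in for the provider wiring a running stack performs at startup, and the key values are placeholders. Before this change, the first key seen would have been cached on the adapter and silently reused for every subsequent request.

import json
from unittest.mock import MagicMock

from llama_stack.distribution.request_headers import request_provider_data_context
from llama_stack.providers.remote.inference.groq.config import GroqConfig
from llama_stack.providers.remote.inference.groq.groq import GroqInferenceAdapter

adapter = GroqInferenceAdapter(GroqConfig())

# Stand-in for the wiring a running stack performs; same technique as the tests above.
adapter.__provider_spec__ = MagicMock()
adapter.__provider_spec__.provider_data_validator = (
    "llama_stack.providers.remote.inference.groq.config.GroqProviderDataValidator"
)

# Two request-scoped contexts with different (placeholder) keys: each call now
# builds a fresh AsyncOpenAI client from the current request's provider data,
# so one caller's credentials can never leak into another caller's requests.
for key in ("tenant-a-key", "tenant-b-key"):
    with request_provider_data_context({"x-llamastack-provider-data": json.dumps({"groq_api_key": key})}):
        assert adapter._get_openai_client().api_key == key

The trade-off is that a fresh AsyncOpenAI instance is constructed on every call rather than reused, giving up client reuse in exchange for request-scoped credentials always being honored.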