fix: Don't cache clients for passthrough auth providers

Some of our inference providers support passthrough authentication via
`x-llamastack-provider-data` header values. This fixes the providers
that support passthrough auth to not cache their clients to the
backend providers (mostly OpenAI client instances) so that the client
connecting to Llama Stack has to provide those auth values on each and
every request.

Signed-off-by: Ben Browning <bbrownin@redhat.com>
This commit is contained in:
Ben Browning 2025-07-11 10:11:31 -04:00
parent d880c2df0e
commit fa9e2dd543
4 changed files with 103 additions and 45 deletions

View file

@ -59,9 +59,6 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
# if we do not set this, users will be exposed to the
# litellm specific model names, an abstraction leak.
self.is_openai_compat = True
self._openai_client = AsyncOpenAI(
api_key=self.config.api_key,
)
async def initialize(self) -> None:
await super().initialize()
@ -69,6 +66,11 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
async def shutdown(self) -> None:
await super().shutdown()
def _get_openai_client(self) -> AsyncOpenAI:
return AsyncOpenAI(
api_key=self.get_api_key(),
)
async def openai_completion(
self,
model: str,
@ -120,7 +122,7 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
user=user,
suffix=suffix,
)
return await self._openai_client.completions.create(**params)
return await self._get_openai_client().completions.create(**params)
async def openai_chat_completion(
self,
@ -176,7 +178,7 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
top_p=top_p,
user=user,
)
return await self._openai_client.chat.completions.create(**params)
return await self._get_openai_client().chat.completions.create(**params)
async def openai_embeddings(
self,
@ -204,7 +206,7 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
params["user"] = user
# Call OpenAI embeddings API
response = await self._openai_client.embeddings.create(**params)
response = await self._get_openai_client().embeddings.create(**params)
data = []
for i, embedding_data in enumerate(response.data):