mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-23 19:13:59 +00:00
fix: Don't cache clients for passthrough auth providers
Some of our inference providers support passthrough authentication via `x-llamastack-provider-data` header values. This fixes the providers that support passthrough auth to not cache their clients to the backend providers (mostly OpenAI client instances) so that the client connecting to Llama Stack has to provide those auth values on each and every request. Signed-off-by: Ben Browning <bbrownin@redhat.com>
This commit is contained in:
parent
d880c2df0e
commit
fa9e2dd543
4 changed files with 103 additions and 45 deletions
|
|
@ -59,9 +59,6 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
|
|||
# if we do not set this, users will be exposed to the
|
||||
# litellm specific model names, an abstraction leak.
|
||||
self.is_openai_compat = True
|
||||
self._openai_client = AsyncOpenAI(
|
||||
api_key=self.config.api_key,
|
||||
)
|
||||
|
||||
async def initialize(self) -> None:
|
||||
await super().initialize()
|
||||
|
|
@ -69,6 +66,11 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
|
|||
async def shutdown(self) -> None:
|
||||
await super().shutdown()
|
||||
|
||||
def _get_openai_client(self) -> AsyncOpenAI:
|
||||
return AsyncOpenAI(
|
||||
api_key=self.get_api_key(),
|
||||
)
|
||||
|
||||
async def openai_completion(
|
||||
self,
|
||||
model: str,
|
||||
|
|
@ -120,7 +122,7 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
|
|||
user=user,
|
||||
suffix=suffix,
|
||||
)
|
||||
return await self._openai_client.completions.create(**params)
|
||||
return await self._get_openai_client().completions.create(**params)
|
||||
|
||||
async def openai_chat_completion(
|
||||
self,
|
||||
|
|
@ -176,7 +178,7 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
|
|||
top_p=top_p,
|
||||
user=user,
|
||||
)
|
||||
return await self._openai_client.chat.completions.create(**params)
|
||||
return await self._get_openai_client().chat.completions.create(**params)
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
|
|
@ -204,7 +206,7 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
|
|||
params["user"] = user
|
||||
|
||||
# Call OpenAI embeddings API
|
||||
response = await self._openai_client.embeddings.create(**params)
|
||||
response = await self._get_openai_client().embeddings.create(**params)
|
||||
|
||||
data = []
|
||||
for i, embedding_data in enumerate(response.data):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue