From 9aee115abe2f53436a4fed78e95fbe5f150af16e Mon Sep 17 00:00:00 2001
From: Swapna Lekkala
Date: Wed, 17 Sep 2025 18:19:15 -0700
Subject: [PATCH] revert client change to non openai methods

---
 .../remote/inference/fireworks/fireworks.py   | 20 +++++++++++++-------
 1 file changed, 13 insertions(+), 7 deletions(-)

diff --git a/llama_stack/providers/remote/inference/fireworks/fireworks.py b/llama_stack/providers/remote/inference/fireworks/fireworks.py
index fc77a7214..8f6338afc 100644
--- a/llama_stack/providers/remote/inference/fireworks/fireworks.py
+++ b/llama_stack/providers/remote/inference/fireworks/fireworks.py
@@ -7,6 +7,8 @@
 from collections.abc import AsyncGenerator, AsyncIterator
 from typing import Any
 
+from fireworks.client import Fireworks
+
 from llama_stack.apis.common.content_types import (
     InterleavedContent,
     InterleavedContentItem,
@@ -91,6 +93,10 @@ class FireworksInferenceAdapter(OpenAIMixin, ModelRegistryHelper, Inference, Nee
     def get_base_url(self) -> str:
         return "https://api.fireworks.ai/inference/v1"
 
+    def _get_client(self) -> Fireworks:
+        fireworks_api_key = self.get_api_key()
+        return Fireworks(api_key=fireworks_api_key)
+
     def _preprocess_prompt_for_fireworks(self, prompt: str) -> str:
         """Remove BOS token as Fireworks automatically prepends it"""
         if prompt.startswith("<|begin_of_text|>"):
@@ -124,13 +130,13 @@ class FireworksInferenceAdapter(OpenAIMixin, ModelRegistryHelper, Inference, Nee
 
     async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
         params = await self._get_params(request)
-        r = await self.client.completions.create(**params)
+        r = await self._get_client().completions.create(**params)
         return process_completion_response(r)
 
     async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)
-        stream = await self.client.completions.create(**params)
+        stream = await self._get_client().completions.create(**params)
 
         async for chunk in process_completion_stream_response(stream):
             yield chunk
@@ -199,18 +205,18 @@ class FireworksInferenceAdapter(OpenAIMixin, ModelRegistryHelper, Inference, Nee
     async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
         params = await self._get_params(request)
         if "messages" in params:
-            r = await self.client.chat.completions.create(**params)
+            r = await self._get_client().chat.completions.create(**params)
         else:
-            r = await self.client.completions.create(**params)
+            r = await self._get_client().completions.create(**params)
         return process_chat_completion_response(r, request)
 
     async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)
         if "messages" in params:
-            stream = await self.client.chat.completions.create(**params)
+            stream = await self._get_client().chat.completions.create(**params)
         else:
-            stream = await self.client.completions.create(**params)
+            stream = await self._get_client().completions.create(**params)
 
         async for chunk in process_chat_completion_stream_response(stream, request):
             yield chunk
 
@@ -261,7 +267,7 @@ class FireworksInferenceAdapter(OpenAIMixin, ModelRegistryHelper, Inference, Nee
         assert all(not content_has_media(content) for content in contents), (
             "Fireworks does not support media for embeddings"
        )
-        response = await self.client.embeddings.create(
+        response = await self._get_client().embeddings.create(
             model=model.provider_resource_id,
             input=[interleaved_content_as_str(content) for content in contents],
             **kwargs,
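
For reference, the pattern this patch restores builds a fresh native
fireworks.client SDK client on every non-OpenAI call (completion, chat
completion, embeddings), so the API key is re-read at call time instead
of being frozen into a long-lived client; the OpenAI-compatible paths
keep using self.client from OpenAIMixin. Below is a minimal sketch of
that pattern, not the adapter itself: it assumes the fireworks-ai SDK is
installed and takes the key from a FIREWORKS_API_KEY environment
variable; ExampleAdapter and the model name are hypothetical stand-ins
for the adapter's real config plumbing.

    import asyncio
    import os

    from fireworks.client import Fireworks


    class ExampleAdapter:
        def get_api_key(self) -> str:
            # Hypothetical key lookup; the real adapter resolves the key
            # from provider config or per-request provider data.
            return os.environ["FIREWORKS_API_KEY"]

        def _get_client(self) -> Fireworks:
            # A fresh client per call, mirroring the revert: the key is
            # read on each invocation rather than cached at init time.
            return Fireworks(api_key=self.get_api_key())


    async def main() -> None:
        adapter = ExampleAdapter()
        # Non-OpenAI code path; create() is awaited exactly as in
        # _nonstream_completion() in the patch above.
        r = await adapter._get_client().completions.create(
            model="accounts/fireworks/models/llama-v3p1-8b-instruct",  # hypothetical
            prompt="Say hello.",
        )
        print(r.choices[0].text)


    if __name__ == "__main__":
        asyncio.run(main())

The per-call construction trades a small amount of per-request overhead
for correctness when keys are supplied per request rather than at
adapter startup.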