From e56a3f266c28e5cc4856a61bbd5f2a498a00fbbe Mon Sep 17 00:00:00 2001
From: Swapna Lekkala
Date: Tue, 16 Sep 2025 09:38:51 -0700
Subject: [PATCH] chore: Refactor to use OpenAIMixin

---
 .../remote/inference/fireworks/fireworks.py | 76 +++++++------------
 1 file changed, 28 insertions(+), 48 deletions(-)

diff --git a/llama_stack/providers/remote/inference/fireworks/fireworks.py b/llama_stack/providers/remote/inference/fireworks/fireworks.py
index e907e8ec6..2c01d192c 100644
--- a/llama_stack/providers/remote/inference/fireworks/fireworks.py
+++ b/llama_stack/providers/remote/inference/fireworks/fireworks.py
@@ -7,9 +7,6 @@
 from collections.abc import AsyncGenerator, AsyncIterator
 from typing import Any
 
-from fireworks.client import Fireworks
-from openai import AsyncOpenAI
-
 from llama_stack.apis.common.content_types import (
     InterleavedContent,
     InterleavedContentItem,
@@ -48,12 +45,12 @@ from llama_stack.providers.utils.inference.openai_compat import (
     OpenAIChatCompletionToLlamaStackMixin,
     convert_message_to_openai_dict,
     get_sampling_options,
-    prepare_openai_completion_params,
     process_chat_completion_response,
     process_chat_completion_stream_response,
     process_completion_response,
     process_completion_stream_response,
 )
+from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
 from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_prompt,
     completion_request_to_prompt,
@@ -68,7 +65,7 @@ from .models import MODEL_ENTRIES
 logger = get_logger(name=__name__, category="inference::fireworks")
 
 
-class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
+class FireworksInferenceAdapter(OpenAIMixin, ModelRegistryHelper, Inference, NeedsRequestProviderData):
     def __init__(self, config: FireworksImplConfig) -> None:
         ModelRegistryHelper.__init__(self, MODEL_ENTRIES, config.allowed_models)
         self.config = config
@@ -79,7 +76,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
     async def shutdown(self) -> None:
         pass
 
-    def _get_api_key(self) -> str:
+    def get_api_key(self) -> str:
         config_api_key = self.config.api_key.get_secret_value() if self.config.api_key else None
         if config_api_key:
             return config_api_key
@@ -91,15 +88,14 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
             )
         return provider_data.fireworks_api_key
 
-    def _get_base_url(self) -> str:
+    def get_base_url(self) -> str:
         return "https://api.fireworks.ai/inference/v1"
 
-    def _get_client(self) -> Fireworks:
-        fireworks_api_key = self._get_api_key()
-        return Fireworks(api_key=fireworks_api_key)
-
-    def _get_openai_client(self) -> AsyncOpenAI:
-        return AsyncOpenAI(base_url=self._get_base_url(), api_key=self._get_api_key())
+    def _preprocess_prompt_for_fireworks(self, prompt: str) -> str:
+        """Remove BOS token as Fireworks automatically prepends it"""
+        if prompt.startswith("<|begin_of_text|>"):
+            return prompt[len("<|begin_of_text|>") :]
+        return prompt
 
     async def completion(
         self,
@@ -128,19 +124,13 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
 
     async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
         params = await self._get_params(request)
-        r = await self._get_client().completion.acreate(**params)
+        r = await self.client.completions.create(**params)
         return process_completion_response(r)
 
     async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)
 
-        # Wrapper for async generator similar
-        async def _to_async_generator():
-            stream = self._get_client().completion.create(**params)
-            for chunk in stream:
-                yield chunk
-
-        stream = _to_async_generator()
+        stream = await self.client.completions.create(**params)
         async for chunk in process_completion_stream_response(stream):
             yield chunk
 
@@ -209,23 +199,18 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
     async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
         params = await self._get_params(request)
         if "messages" in params:
-            r = await self._get_client().chat.completions.acreate(**params)
+            r = await self.client.chat.completions.create(**params)
         else:
-            r = await self._get_client().completion.acreate(**params)
+            r = await self.client.completions.create(**params)
         return process_chat_completion_response(r, request)
 
     async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)
 
-        async def _to_async_generator():
-            if "messages" in params:
-                stream = self._get_client().chat.completions.acreate(**params)
-            else:
-                stream = self._get_client().completion.acreate(**params)
-            async for chunk in stream:
-                yield chunk
-
-        stream = _to_async_generator()
+        if "messages" in params:
+            stream = await self.client.chat.completions.create(**params)
+        else:
+            stream = await self.client.completions.create(**params)
         async for chunk in process_chat_completion_stream_response(stream, request):
             yield chunk
 
@@ -248,8 +233,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
 
         # Fireworks always prepends with BOS
         if "prompt" in input_dict:
-            if input_dict["prompt"].startswith("<|begin_of_text|>"):
-                input_dict["prompt"] = input_dict["prompt"][len("<|begin_of_text|>") :]
+            input_dict["prompt"] = self._preprocess_prompt_for_fireworks(input_dict["prompt"])
 
         params = {
             "model": request.model,
@@ -277,7 +261,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
         assert all(not content_has_media(content) for content in contents), (
             "Fireworks does not support media for embeddings"
         )
-        response = self._get_client().embeddings.create(
+        response = await self.client.embeddings.create(
             model=model.provider_resource_id,
             input=[interleaved_content_as_str(content) for content in contents],
             **kwargs,
@@ -319,14 +303,11 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
         prompt_logprobs: int | None = None,
         suffix: str | None = None,
     ) -> OpenAICompletion:
-        model_obj = await self.model_store.get_model(model)
+        if isinstance(prompt, str):
+            prompt = self._preprocess_prompt_for_fireworks(prompt)
 
-        # Fireworks always prepends with BOS
-        if isinstance(prompt, str) and prompt.startswith("<|begin_of_text|>"):
-            prompt = prompt[len("<|begin_of_text|>") :]
-
-        params = await prepare_openai_completion_params(
-            model=model_obj.provider_resource_id,
+        return await super().openai_completion(
+            model=model,
             prompt=prompt,
             best_of=best_of,
             echo=echo,
@@ -343,10 +324,11 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
             temperature=temperature,
             top_p=top_p,
             user=user,
+            guided_choice=guided_choice,
+            prompt_logprobs=prompt_logprobs,
+            suffix=suffix,
         )
 
-        return await self._get_openai_client().completions.create(**params)
-
     async def openai_chat_completion(
         self,
         model: str,
@@ -408,7 +390,8 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
                 user=user,
             )
 
-        params = await prepare_openai_completion_params(
+        return await super().openai_chat_completion(
+            model=model,
             messages=messages,
             frequency_penalty=frequency_penalty,
             function_call=function_call,
@@ -432,6 +415,3 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
             top_p=top_p,
             user=user,
         )
-
-        logger.debug(f"fireworks params: {params}")
-        return await self._get_openai_client().chat.completions.create(model=model_obj.provider_resource_id, **params)
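
Reviewer note (not part of the patch): the refactor leans on OpenAIMixin to supply the AsyncOpenAI `client` used at every call site above, built from the adapter's get_api_key()/get_base_url(). That is also why the streaming call sites await client.completions.create(...) and client.chat.completions.create(...): AsyncOpenAI's create() is a coroutine that resolves to an AsyncStream when stream=True. Below is a minimal sketch of the assumed contract, with the hypothetical name OpenAIMixinSketch standing in for the real mixin in llama_stack/providers/utils/inference/openai_mixin.py:

    # Sketch of the mixin contract this adapter relies on; an illustration,
    # not the actual llama-stack implementation.
    from openai import AsyncOpenAI

    class OpenAIMixinSketch:
        def get_api_key(self) -> str:
            # Overridden by the adapter: config value or per-request
            # X-LlamaStack-Provider-Data header.
            raise NotImplementedError

        def get_base_url(self) -> str:
            # Overridden by the adapter with the provider's
            # OpenAI-compatible endpoint.
            raise NotImplementedError

        @property
        def client(self) -> AsyncOpenAI:
            # Constructed on each access, so a per-request API key
            # is always honored.
            return AsyncOpenAI(base_url=self.get_base_url(), api_key=self.get_api_key())

Under this contract, super().openai_completion(...) and super().openai_chat_completion(...) reuse the mixin's OpenAI-compatible entry points, replacing the deleted prepare_openai_completion_params() plumbing and the hand-built Fireworks/AsyncOpenAI clients.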