diff --git a/llama_stack/providers/remote/inference/fireworks/fireworks.py b/llama_stack/providers/remote/inference/fireworks/fireworks.py
index e907e8ec6..2fcf1be2e 100644
--- a/llama_stack/providers/remote/inference/fireworks/fireworks.py
+++ b/llama_stack/providers/remote/inference/fireworks/fireworks.py
@@ -4,11 +4,9 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from collections.abc import AsyncGenerator, AsyncIterator
-from typing import Any
+from collections.abc import AsyncGenerator
 
 from fireworks.client import Fireworks
-from openai import AsyncOpenAI
 
 from llama_stack.apis.common.content_types import (
     InterleavedContent,
@@ -24,12 +22,6 @@ from llama_stack.apis.inference import (
     Inference,
     LogProbConfig,
     Message,
-    OpenAIChatCompletion,
-    OpenAIChatCompletionChunk,
-    OpenAICompletion,
-    OpenAIEmbeddingsResponse,
-    OpenAIMessageParam,
-    OpenAIResponseFormatParam,
     ResponseFormat,
     ResponseFormatType,
     SamplingParams,
@@ -45,15 +37,14 @@ from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
 )
 from llama_stack.providers.utils.inference.openai_compat import (
-    OpenAIChatCompletionToLlamaStackMixin,
     convert_message_to_openai_dict,
     get_sampling_options,
-    prepare_openai_completion_params,
     process_chat_completion_response,
     process_chat_completion_stream_response,
     process_completion_response,
     process_completion_stream_response,
 )
+from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
 from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_prompt,
     completion_request_to_prompt,
@@ -68,7 +59,7 @@ from .models import MODEL_ENTRIES
 logger = get_logger(name=__name__, category="inference::fireworks")
 
 
-class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
+class FireworksInferenceAdapter(OpenAIMixin, ModelRegistryHelper, Inference, NeedsRequestProviderData):
     def __init__(self, config: FireworksImplConfig) -> None:
         ModelRegistryHelper.__init__(self, MODEL_ENTRIES, config.allowed_models)
         self.config = config
@@ -79,7 +70,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
     async def shutdown(self) -> None:
         pass
 
-    def _get_api_key(self) -> str:
+    def get_api_key(self) -> str:
         config_api_key = self.config.api_key.get_secret_value() if self.config.api_key else None
         if config_api_key:
             return config_api_key
@@ -91,15 +82,18 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
             )
         return provider_data.fireworks_api_key
 
-    def _get_base_url(self) -> str:
+    def get_base_url(self) -> str:
         return "https://api.fireworks.ai/inference/v1"
 
     def _get_client(self) -> Fireworks:
-        fireworks_api_key = self._get_api_key()
+        fireworks_api_key = self.get_api_key()
         return Fireworks(api_key=fireworks_api_key)
 
-    def _get_openai_client(self) -> AsyncOpenAI:
-        return AsyncOpenAI(base_url=self._get_base_url(), api_key=self._get_api_key())
+    def _preprocess_prompt_for_fireworks(self, prompt: str) -> str:
+        """Remove BOS token as Fireworks automatically prepends it"""
+        if prompt.startswith("<|begin_of_text|>"):
+            return prompt[len("<|begin_of_text|>") :]
+        return prompt
 
     async def completion(
         self,
@@ -285,153 +279,3 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
 
         embeddings = [data.embedding for data in response.data]
         return EmbeddingsResponse(embeddings=embeddings)
-
-    async def openai_embeddings(
-        self,
-        model: str,
-        input: str | list[str],
-        encoding_format: str | None = "float",
-        dimensions: int | None = None,
-        user: str | None = None,
-    ) -> OpenAIEmbeddingsResponse:
-        raise NotImplementedError()
-
-    async def openai_completion(
-        self,
-        model: str,
-        prompt: str | list[str] | list[int] | list[list[int]],
-        best_of: int | None = None,
-        echo: bool | None = None,
-        frequency_penalty: float | None = None,
-        logit_bias: dict[str, float] | None = None,
-        logprobs: bool | None = None,
-        max_tokens: int | None = None,
-        n: int | None = None,
-        presence_penalty: float | None = None,
-        seed: int | None = None,
-        stop: str | list[str] | None = None,
-        stream: bool | None = None,
-        stream_options: dict[str, Any] | None = None,
-        temperature: float | None = None,
-        top_p: float | None = None,
-        user: str | None = None,
-        guided_choice: list[str] | None = None,
-        prompt_logprobs: int | None = None,
-        suffix: str | None = None,
-    ) -> OpenAICompletion:
-        model_obj = await self.model_store.get_model(model)
-
-        # Fireworks always prepends with BOS
-        if isinstance(prompt, str) and prompt.startswith("<|begin_of_text|>"):
-            prompt = prompt[len("<|begin_of_text|>") :]
-
-        params = await prepare_openai_completion_params(
-            model=model_obj.provider_resource_id,
-            prompt=prompt,
-            best_of=best_of,
-            echo=echo,
-            frequency_penalty=frequency_penalty,
-            logit_bias=logit_bias,
-            logprobs=logprobs,
-            max_tokens=max_tokens,
-            n=n,
-            presence_penalty=presence_penalty,
-            seed=seed,
-            stop=stop,
-            stream=stream,
-            stream_options=stream_options,
-            temperature=temperature,
-            top_p=top_p,
-            user=user,
-        )
-
-        return await self._get_openai_client().completions.create(**params)
-
-    async def openai_chat_completion(
-        self,
-        model: str,
-        messages: list[OpenAIMessageParam],
-        frequency_penalty: float | None = None,
-        function_call: str | dict[str, Any] | None = None,
-        functions: list[dict[str, Any]] | None = None,
-        logit_bias: dict[str, float] | None = None,
-        logprobs: bool | None = None,
-        max_completion_tokens: int | None = None,
-        max_tokens: int | None = None,
-        n: int | None = None,
-        parallel_tool_calls: bool | None = None,
-        presence_penalty: float | None = None,
-        response_format: OpenAIResponseFormatParam | None = None,
-        seed: int | None = None,
-        stop: str | list[str] | None = None,
-        stream: bool | None = None,
-        stream_options: dict[str, Any] | None = None,
-        temperature: float | None = None,
-        tool_choice: str | dict[str, Any] | None = None,
-        tools: list[dict[str, Any]] | None = None,
-        top_logprobs: int | None = None,
-        top_p: float | None = None,
-        user: str | None = None,
-    ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
-        model_obj = await self.model_store.get_model(model)
-
-        # Divert Llama Models through Llama Stack inference APIs because
-        # Fireworks chat completions OpenAI-compatible API does not support
-        # tool calls properly.
-        llama_model = self.get_llama_model(model_obj.provider_resource_id)
-
-        if llama_model:
-            return await OpenAIChatCompletionToLlamaStackMixin.openai_chat_completion(
-                self,
-                model=model,
-                messages=messages,
-                frequency_penalty=frequency_penalty,
-                function_call=function_call,
-                functions=functions,
-                logit_bias=logit_bias,
-                logprobs=logprobs,
-                max_completion_tokens=max_completion_tokens,
-                max_tokens=max_tokens,
-                n=n,
-                parallel_tool_calls=parallel_tool_calls,
-                presence_penalty=presence_penalty,
-                response_format=response_format,
-                seed=seed,
-                stop=stop,
-                stream=stream,
-                stream_options=stream_options,
-                temperature=temperature,
-                tool_choice=tool_choice,
-                tools=tools,
-                top_logprobs=top_logprobs,
-                top_p=top_p,
-                user=user,
-            )
-
-        params = await prepare_openai_completion_params(
-            messages=messages,
-            frequency_penalty=frequency_penalty,
-            function_call=function_call,
-            functions=functions,
-            logit_bias=logit_bias,
-            logprobs=logprobs,
-            max_completion_tokens=max_completion_tokens,
-            max_tokens=max_tokens,
-            n=n,
-            parallel_tool_calls=parallel_tool_calls,
-            presence_penalty=presence_penalty,
-            response_format=response_format,
-            seed=seed,
-            stop=stop,
-            stream=stream,
-            stream_options=stream_options,
-            temperature=temperature,
-            tool_choice=tool_choice,
-            tools=tools,
-            top_logprobs=top_logprobs,
-            top_p=top_p,
-            user=user,
-        )
-
-        logger.debug(f"fireworks params: {params}")
-        return await self._get_openai_client().chat.completions.create(model=model_obj.provider_resource_id, **params)
diff --git a/tests/integration/inference/test_openai_embeddings.py b/tests/integration/inference/test_openai_embeddings.py
index 622b97287..ce3d2a8ea 100644
--- a/tests/integration/inference/test_openai_embeddings.py
+++ b/tests/integration/inference/test_openai_embeddings.py
@@ -33,6 +33,7 @@ def skip_if_model_doesnt_support_user_param(client, model_id):
     provider = provider_from_model(client, model_id)
     if provider.provider_type in (
         "remote::together",  # service returns 400
+        "remote::fireworks",  # service returns 400 malformed input
     ):
         pytest.skip(f"Model {model_id} hosted by {provider.provider_type} does not support user param.")
 
@@ -41,6 +42,7 @@ def skip_if_model_doesnt_support_encoding_format_base64(client, model_id):
     provider = provider_from_model(client, model_id)
     if provider.provider_type in (
         "remote::together",  # param silently ignored, always returns floats
+        "remote::fireworks",  # param silently ignored, always returns list of floats
     ):
         pytest.skip(f"Model {model_id} hosted by {provider.provider_type} does not support encoding_format='base64'.")
 
@@ -287,7 +289,6 @@ def test_openai_embeddings_base64_batch_processing(compat_client, client_with_mo
         input=input_texts,
         encoding_format="base64",
     )
-
     # Validate response structure
     assert response.object == "list"
     assert response.model == embedding_model_id
diff --git a/tests/integration/suites.py b/tests/integration/suites.py
index 231480447..fb2c44308 100644
--- a/tests/integration/suites.py
+++ b/tests/integration/suites.py
@@ -108,6 +108,15 @@ SETUP_DEFINITIONS: dict[str, Setup] = {
             "embedding_model": "together/togethercomputer/m2-bert-80M-32k-retrieval",
         },
     ),
+    "fireworks": Setup(
+        name="fireworks",
+        description="Fireworks provider with a text model",
+        defaults={
+            "text_model": "accounts/fireworks/models/llama-v3p1-8b-instruct",
+            "vision_model": "accounts/fireworks/models/llama-v3p2-90b-vision-instruct",
+            "embedding_model": "nomic-ai/nomic-embed-text-v1.5",
+        },
+    ),
 }