From 31181c070bb1d10c770bf9fc4bec9395a4a3864f Mon Sep 17 00:00:00 2001
From: Ben Browning
Date: Thu, 10 Apr 2025 15:29:32 -0400
Subject: [PATCH] Fireworks provider support for OpenAI API endpoints

This wires up the openai_completion and openai_chat_completion API
methods for the remote Fireworks inference provider.

Signed-off-by: Ben Browning
---
 .../remote/inference/fireworks/fireworks.py  | 109 +++++++++++++++++-
 .../inference/test_openai_completion.py      |   5 +-
 2 files changed, 112 insertions(+), 2 deletions(-)

diff --git a/llama_stack/providers/remote/inference/fireworks/fireworks.py b/llama_stack/providers/remote/inference/fireworks/fireworks.py
index 4acbe43f8..b59e9f2cb 100644
--- a/llama_stack/providers/remote/inference/fireworks/fireworks.py
+++ b/llama_stack/providers/remote/inference/fireworks/fireworks.py
@@ -4,9 +4,10 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import AsyncGenerator, List, Optional, Union
+from typing import Any, AsyncGenerator, Dict, List, Optional, Union
 
 from fireworks.client import Fireworks
+from openai import AsyncOpenAI
 
 from llama_stack.apis.common.content_types import (
     InterleavedContent,
@@ -31,6 +32,7 @@ from llama_stack.apis.inference import (
     ToolDefinition,
     ToolPromptFormat,
 )
+from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion, OpenAIMessageParam
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
 from llama_stack.log import get_logger
 from llama_stack.providers.utils.inference.model_registry import (
@@ -39,6 +41,7 @@ from llama_stack.providers.utils.inference.model_registry import (
 from llama_stack.providers.utils.inference.openai_compat import (
     convert_message_to_openai_dict,
     get_sampling_options,
+    prepare_openai_completion_params,
     process_chat_completion_response,
     process_chat_completion_stream_response,
     process_completion_response,
@@ -81,10 +84,16 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
             )
         return provider_data.fireworks_api_key
 
+    def _get_base_url(self) -> str:
+        return "https://api.fireworks.ai/inference/v1"
+
     def _get_client(self) -> Fireworks:
         fireworks_api_key = self._get_api_key()
         return Fireworks(api_key=fireworks_api_key)
 
+    def _get_openai_client(self) -> AsyncOpenAI:
+        return AsyncOpenAI(base_url=self._get_base_url(), api_key=self._get_api_key())
+
     async def completion(
         self,
         model_id: str,
@@ -268,3 +277,101 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
 
         embeddings = [data.embedding for data in response.data]
         return EmbeddingsResponse(embeddings=embeddings)
+
+    async def openai_completion(
+        self,
+        model: str,
+        prompt: Union[str, List[str], List[int], List[List[int]]],
+        best_of: Optional[int] = None,
+        echo: Optional[bool] = None,
+        frequency_penalty: Optional[float] = None,
+        logit_bias: Optional[Dict[str, float]] = None,
+        logprobs: Optional[bool] = None,
+        max_tokens: Optional[int] = None,
+        n: Optional[int] = None,
+        presence_penalty: Optional[float] = None,
+        seed: Optional[int] = None,
+        stop: Optional[Union[str, List[str]]] = None,
+        stream: Optional[bool] = None,
+        stream_options: Optional[Dict[str, Any]] = None,
+        temperature: Optional[float] = None,
+        top_p: Optional[float] = None,
+        user: Optional[str] = None,
+        guided_choice: Optional[List[str]] = None,
+        prompt_logprobs: Optional[int] = None,
+    ) -> OpenAICompletion:
+        model_obj = await self.model_store.get_model(model)
+        params = await prepare_openai_completion_params(
+            model=model_obj.provider_resource_id,
+            prompt=prompt,
+            best_of=best_of,
+            echo=echo,
+            frequency_penalty=frequency_penalty,
+            logit_bias=logit_bias,
+            logprobs=logprobs,
+            max_tokens=max_tokens,
+            n=n,
+            presence_penalty=presence_penalty,
+            seed=seed,
+            stop=stop,
+            stream=stream,
+            stream_options=stream_options,
+            temperature=temperature,
+            top_p=top_p,
+            user=user,
+        )
+        return await self._get_openai_client().completions.create(**params)
+
+    async def openai_chat_completion(
+        self,
+        model: str,
+        messages: List[OpenAIMessageParam],
+        frequency_penalty: Optional[float] = None,
+        function_call: Optional[Union[str, Dict[str, Any]]] = None,
+        functions: Optional[List[Dict[str, Any]]] = None,
+        logit_bias: Optional[Dict[str, float]] = None,
+        logprobs: Optional[bool] = None,
+        max_completion_tokens: Optional[int] = None,
+        max_tokens: Optional[int] = None,
+        n: Optional[int] = None,
+        parallel_tool_calls: Optional[bool] = None,
+        presence_penalty: Optional[float] = None,
+        response_format: Optional[Dict[str, str]] = None,
+        seed: Optional[int] = None,
+        stop: Optional[Union[str, List[str]]] = None,
+        stream: Optional[bool] = None,
+        stream_options: Optional[Dict[str, Any]] = None,
+        temperature: Optional[float] = None,
+        tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
+        tools: Optional[List[Dict[str, Any]]] = None,
+        top_logprobs: Optional[int] = None,
+        top_p: Optional[float] = None,
+        user: Optional[str] = None,
+    ) -> OpenAIChatCompletion:
+        model_obj = await self.model_store.get_model(model)
+        params = await prepare_openai_completion_params(
+            model=model_obj.provider_resource_id,
+            messages=messages,
+            frequency_penalty=frequency_penalty,
+            function_call=function_call,
+            functions=functions,
+            logit_bias=logit_bias,
+            logprobs=logprobs,
+            max_completion_tokens=max_completion_tokens,
+            max_tokens=max_tokens,
+            n=n,
+            parallel_tool_calls=parallel_tool_calls,
+            presence_penalty=presence_penalty,
+            response_format=response_format,
+            seed=seed,
+            stop=stop,
+            stream=stream,
+            stream_options=stream_options,
+            temperature=temperature,
+            tool_choice=tool_choice,
+            tools=tools,
+            top_logprobs=top_logprobs,
+            top_p=top_p,
+            user=user,
+        )
+        return await self._get_openai_client().chat.completions.create(**params)
diff --git a/tests/integration/inference/test_openai_completion.py b/tests/integration/inference/test_openai_completion.py
index e6e584727..0905d5817 100644
--- a/tests/integration/inference/test_openai_completion.py
+++ b/tests/integration/inference/test_openai_completion.py
@@ -208,6 +208,9 @@ def test_openai_chat_completion_streaming(openai_client, client_with_models, tex
         stream=True,
         timeout=120,  # Increase timeout to 2 minutes for large conversation history
     )
-    streamed_content = [str(chunk.choices[0].delta.content.lower().strip()) for chunk in response]
+    streamed_content = []
+    for chunk in response:
+        if chunk.choices[0].delta.content:
+            streamed_content.append(chunk.choices[0].delta.content.lower().strip())
     assert len(streamed_content) > 0
     assert expected.lower() in "".join(streamed_content)
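
Usage sketch (not part of the patch): the new _get_openai_client() helper simply points the openai package's AsyncOpenAI client at https://api.fireworks.ai/inference/v1, so the round trip that openai_completion and openai_chat_completion perform can be exercised directly against Fireworks as below. The model identifier and the FIREWORKS_API_KEY environment variable are illustrative assumptions rather than values taken from this diff; substitute whatever your Fireworks account actually exposes.

import asyncio
import os

from openai import AsyncOpenAI

# Illustrative model id -- not taken from this patch; use one your account can access.
MODEL = "accounts/fireworks/models/llama-v3p1-8b-instruct"


async def main() -> None:
    # Same base URL and API-key handling that the adapter's _get_openai_client() uses.
    client = AsyncOpenAI(
        base_url="https://api.fireworks.ai/inference/v1",
        api_key=os.environ["FIREWORKS_API_KEY"],  # assumed env var for the key
    )

    # Non-streaming chat completion, analogous to openai_chat_completion() above.
    response = await client.chat.completions.create(
        model=MODEL,
        messages=[{"role": "user", "content": "Say hello in one short sentence."}],
        temperature=0.7,
    )
    print(response.choices[0].message.content)

    # Streaming variant; delta.content can be None on some chunks, which is exactly
    # why the test above now guards the delta before lowercasing/stripping it.
    stream = await client.chat.completions.create(
        model=MODEL,
        messages=[{"role": "user", "content": "Count to three."}],
        stream=True,
    )
    async for chunk in stream:
        if chunk.choices and chunk.choices[0].delta.content:
            print(chunk.choices[0].delta.content, end="")
    print()


if __name__ == "__main__":
    asyncio.run(main())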