chore: Refactor to use OpenAIMixin

Swapna Lekkala 2025-09-16 09:38:51 -07:00
parent e3fd70c321
commit e56a3f266c


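This refactor drops the hand-rolled `Fireworks` and `AsyncOpenAI` client plumbing and lets `OpenAIMixin` own the client. The mixin's contract, as the renamed hooks in the diff suggest, is that the adapter supplies `get_api_key()` and `get_base_url()` and then talks to Fireworks through `self.client`. A minimal sketch of that contract (an illustration, not the actual `llama_stack.providers.utils.inference.openai_mixin` implementation):

```python
# Sketch only: assumes OpenAIMixin builds an AsyncOpenAI client from the two
# hooks the adapter overrides (get_api_key / get_base_url).
from openai import AsyncOpenAI


class OpenAIMixinSketch:
    def get_api_key(self) -> str:
        raise NotImplementedError  # supplied by the adapter

    def get_base_url(self) -> str:
        raise NotImplementedError  # supplied by the adapter

    @property
    def client(self) -> AsyncOpenAI:
        # Rebuilt on each access so a per-request key (e.g. one read from
        # request provider data, as in the adapter's get_api_key) is honored.
        return AsyncOpenAI(base_url=self.get_base_url(), api_key=self.get_api_key())
```

With that in place, `openai_completion` and `openai_chat_completion` can delegate to `super()` instead of building params and clients by hand, as the diff below shows.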
@@ -7,9 +7,6 @@
 from collections.abc import AsyncGenerator, AsyncIterator
 from typing import Any
 
-from fireworks.client import Fireworks
-from openai import AsyncOpenAI
-
 from llama_stack.apis.common.content_types import (
     InterleavedContent,
     InterleavedContentItem,
@@ -48,12 +45,12 @@ from llama_stack.providers.utils.inference.openai_compat import (
     OpenAIChatCompletionToLlamaStackMixin,
     convert_message_to_openai_dict,
     get_sampling_options,
-    prepare_openai_completion_params,
     process_chat_completion_response,
     process_chat_completion_stream_response,
     process_completion_response,
     process_completion_stream_response,
 )
+from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
 from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_prompt,
     completion_request_to_prompt,
@@ -68,7 +65,7 @@ from .models import MODEL_ENTRIES
 logger = get_logger(name=__name__, category="inference::fireworks")
 
 
-class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
+class FireworksInferenceAdapter(OpenAIMixin, ModelRegistryHelper, Inference, NeedsRequestProviderData):
     def __init__(self, config: FireworksImplConfig) -> None:
         ModelRegistryHelper.__init__(self, MODEL_ENTRIES, config.allowed_models)
         self.config = config
@@ -79,7 +76,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
     async def shutdown(self) -> None:
         pass
 
-    def _get_api_key(self) -> str:
+    def get_api_key(self) -> str:
         config_api_key = self.config.api_key.get_secret_value() if self.config.api_key else None
         if config_api_key:
             return config_api_key
@@ -91,15 +88,14 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
             )
         return provider_data.fireworks_api_key
 
-    def _get_base_url(self) -> str:
+    def get_base_url(self) -> str:
         return "https://api.fireworks.ai/inference/v1"
 
-    def _get_client(self) -> Fireworks:
-        fireworks_api_key = self._get_api_key()
-        return Fireworks(api_key=fireworks_api_key)
-
-    def _get_openai_client(self) -> AsyncOpenAI:
-        return AsyncOpenAI(base_url=self._get_base_url(), api_key=self._get_api_key())
+    def _preprocess_prompt_for_fireworks(self, prompt: str) -> str:
+        """Remove BOS token as Fireworks automatically prepends it"""
+        if prompt.startswith("<|begin_of_text|>"):
+            return prompt[len("<|begin_of_text|>") :]
+        return prompt
 
     async def completion(
         self,
@@ -128,19 +124,13 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
     async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
         params = await self._get_params(request)
-        r = await self._get_client().completion.acreate(**params)
+        r = await self.client.completions.create(**params)
         return process_completion_response(r)
 
     async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)
 
-        # Wrapper for async generator similar
-        async def _to_async_generator():
-            stream = self._get_client().completion.create(**params)
-            for chunk in stream:
-                yield chunk
-
-        stream = _to_async_generator()
+        stream = await self.client.completions.create(**params)
 
         async for chunk in process_completion_stream_response(stream):
             yield chunk
@@ -209,23 +199,18 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
     async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
         params = await self._get_params(request)
         if "messages" in params:
-            r = await self._get_client().chat.completions.acreate(**params)
+            r = await self.client.chat.completions.create(**params)
         else:
-            r = await self._get_client().completion.acreate(**params)
+            r = await self.client.completions.create(**params)
         return process_chat_completion_response(r, request)
 
     async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)
 
-        async def _to_async_generator():
-            if "messages" in params:
-                stream = self._get_client().chat.completions.acreate(**params)
-            else:
-                stream = self._get_client().completion.acreate(**params)
-            async for chunk in stream:
-                yield chunk
-
-        stream = _to_async_generator()
+        if "messages" in params:
+            stream = await self.client.chat.completions.create(**params)
+        else:
+            stream = await self.client.completions.create(**params)
 
         async for chunk in process_chat_completion_stream_response(stream, request):
             yield chunk
@@ -248,8 +233,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
         # Fireworks always prepends with BOS
         if "prompt" in input_dict:
-            if input_dict["prompt"].startswith("<|begin_of_text|>"):
-                input_dict["prompt"] = input_dict["prompt"][len("<|begin_of_text|>") :]
+            input_dict["prompt"] = self._preprocess_prompt_for_fireworks(input_dict["prompt"])
 
         params = {
             "model": request.model,
@@ -277,7 +261,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
         assert all(not content_has_media(content) for content in contents), (
             "Fireworks does not support media for embeddings"
         )
-        response = self._get_client().embeddings.create(
+        response = await self.client.embeddings.create(
             model=model.provider_resource_id,
             input=[interleaved_content_as_str(content) for content in contents],
             **kwargs,
@@ -319,14 +303,11 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
         prompt_logprobs: int | None = None,
         suffix: str | None = None,
     ) -> OpenAICompletion:
-        model_obj = await self.model_store.get_model(model)
-
-        # Fireworks always prepends with BOS
-        if isinstance(prompt, str) and prompt.startswith("<|begin_of_text|>"):
-            prompt = prompt[len("<|begin_of_text|>") :]
-
-        params = await prepare_openai_completion_params(
-            model=model_obj.provider_resource_id,
+        if isinstance(prompt, str):
+            prompt = self._preprocess_prompt_for_fireworks(prompt)
+
+        return await super().openai_completion(
+            model=model,
             prompt=prompt,
             best_of=best_of,
             echo=echo,
@@ -343,10 +324,11 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
             temperature=temperature,
             top_p=top_p,
             user=user,
+            guided_choice=guided_choice,
+            prompt_logprobs=prompt_logprobs,
+            suffix=suffix,
         )
-        return await self._get_openai_client().completions.create(**params)
 
     async def openai_chat_completion(
         self,
         model: str,
@@ -408,7 +390,8 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
             user=user,
         )
 
-        params = await prepare_openai_completion_params(
+        return await super().openai_chat_completion(
+            model=model,
             messages=messages,
             frequency_penalty=frequency_penalty,
             function_call=function_call,
@@ -432,6 +415,3 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
             top_p=top_p,
             user=user,
         )
-        logger.debug(f"fireworks params: {params}")
-        return await self._get_openai_client().chat.completions.create(model=model_obj.provider_resource_id, **params)
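The BOS handling is now shared: both `_get_params` and `openai_completion` call the new `_preprocess_prompt_for_fireworks` helper instead of stripping the token inline. A standalone check of that behavior (a hypothetical harness mirroring the helper above):

```python
def _preprocess_prompt_for_fireworks(prompt: str) -> str:
    # Mirrors the adapter method: Fireworks prepends <|begin_of_text|> itself,
    # so a caller-supplied BOS token is dropped before the request is sent.
    if prompt.startswith("<|begin_of_text|>"):
        return prompt[len("<|begin_of_text|>") :]
    return prompt


assert _preprocess_prompt_for_fireworks("<|begin_of_text|>Hello") == "Hello"
assert _preprocess_prompt_for_fireworks("Hello") == "Hello"
```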