llama-stack/llama_stack/providers/remote/inference/fireworks/fireworks.py
Ben Browning 602e949a46
fix: OpenAI Completions API and Fireworks (#1997)
# What does this PR do?

We were passing a dict into the compat mixin for OpenAI Completions when
using Llama models with Fireworks, which broke some strongly-typed code
that was added to openai_compat.py. We shouldn't have been converting
these params to a dict in that case anyway, so this change passes the
params in with their original types when calling the
OpenAIChatCompletionToLlamaStackMixin.

## Test Plan

All of the fireworks provider verification tests were failing due to
some OpenAI compatibility cleanup in #1962. The changes in that PR were
good to make, and this just cleans up the fireworks provider code to
stop passing in untyped dicts to some of those `openai_compat.py`
methods since we have the original strongly-typed parameters we can pass
in.

```
llama stack run --image-type venv tests/verifications/openai-api-verification-run.yaml
```

```
python -m pytest -s -v tests/verifications/openai_api/test_chat_completion.py  --provider=fireworks-llama-stack
```

Before this PR, all of the fireworks OpenAI verification tests were
failing. Now, most of them are passing.

Signed-off-by: Ben Browning <bbrownin@redhat.com>
2025-04-21 11:49:12 -07:00

423 lines
16 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
from fireworks.client import Fireworks
from openai import AsyncOpenAI
from llama_stack.apis.common.content_types import (
InterleavedContent,
InterleavedContentItem,
)
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
CompletionRequest,
CompletionResponse,
EmbeddingsResponse,
EmbeddingTaskType,
Inference,
LogProbConfig,
Message,
ResponseFormat,
ResponseFormatType,
SamplingParams,
TextTruncation,
ToolChoice,
ToolConfig,
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIMessageParam,
OpenAIResponseFormatParam,
)
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
)
from llama_stack.providers.utils.inference.openai_compat import (
OpenAIChatCompletionToLlamaStackMixin,
convert_message_to_openai_dict,
get_sampling_options,
prepare_openai_completion_params,
process_chat_completion_response,
process_chat_completion_stream_response,
process_completion_response,
process_completion_stream_response,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
completion_request_to_prompt,
content_has_media,
interleaved_content_as_str,
request_has_media,
)
from .config import FireworksImplConfig
from .models import MODEL_ENTRIES
# Module-level logger for this provider, registered under the "inference" category.
logger = get_logger(name=__name__, category="inference")
class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
    """Inference provider adapter for the Fireworks AI hosted API.

    Llama Stack ``completion`` / ``chat_completion`` / ``embeddings`` calls go
    through the native ``fireworks`` client (``_get_client``). The OpenAI-compatible
    ``openai_completion`` / ``openai_chat_completion`` calls go through an
    ``AsyncOpenAI`` client pointed at the Fireworks base URL
    (``_get_openai_client``). The exception is OpenAI chat completions for Llama
    models, which are routed through ``OpenAIChatCompletionToLlamaStackMixin``.
    """

    # Fireworks always prepends BOS itself, so prompts we send must not start with it.
    _BOS_TOKEN = "<|begin_of_text|>"

    def __init__(self, config: FireworksImplConfig) -> None:
        ModelRegistryHelper.__init__(self, MODEL_ENTRIES)
        self.config = config

    async def initialize(self) -> None:
        pass

    async def shutdown(self) -> None:
        pass

    def _get_api_key(self) -> str:
        """Return the Fireworks API key.

        A key in the provider config takes precedence. Otherwise the key is read
        from the per-request provider data (``fireworks_api_key``).

        Raises:
            ValueError: if neither the config nor the request provides a key.
        """
        config_api_key = self.config.api_key.get_secret_value() if self.config.api_key else None
        if config_api_key:
            return config_api_key
        provider_data = self.get_request_provider_data()
        if provider_data is None or not provider_data.fireworks_api_key:
            raise ValueError(
                'Pass Fireworks API Key in the header X-LlamaStack-Provider-Data as { "fireworks_api_key": <your api key>}'
            )
        return provider_data.fireworks_api_key

    def _get_base_url(self) -> str:
        """Return the Fireworks OpenAI-compatible inference endpoint."""
        return "https://api.fireworks.ai/inference/v1"

    def _get_client(self) -> Fireworks:
        """Build a native Fireworks client for the current request's API key."""
        fireworks_api_key = self._get_api_key()
        return Fireworks(api_key=fireworks_api_key)

    def _get_openai_client(self) -> AsyncOpenAI:
        """Build an AsyncOpenAI client targeting Fireworks' OpenAI-compatible API."""
        return AsyncOpenAI(base_url=self._get_base_url(), api_key=self._get_api_key())

    @classmethod
    def _strip_bos(cls, prompt: str) -> str:
        """Remove one leading BOS token from ``prompt`` (Fireworks adds its own)."""
        return prompt.removeprefix(cls._BOS_TOKEN)

    async def completion(
        self,
        model_id: str,
        content: InterleavedContent,
        sampling_params: Optional[SamplingParams] = None,
        response_format: Optional[ResponseFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> AsyncGenerator:
        """Run a Llama Stack text completion against Fireworks.

        Returns an async generator of streamed chunks when ``stream`` is truthy,
        otherwise the awaited non-streaming ``CompletionResponse``.
        """
        if sampling_params is None:
            sampling_params = SamplingParams()
        model = await self.model_store.get_model(model_id)
        request = CompletionRequest(
            model=model.provider_resource_id,
            content=content,
            sampling_params=sampling_params,
            response_format=response_format,
            stream=stream,
            logprobs=logprobs,
        )
        if stream:
            return self._stream_completion(request)
        else:
            return await self._nonstream_completion(request)

    async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
        params = await self._get_params(request)
        r = await self._get_client().completion.acreate(**params)
        return process_completion_response(r)

    async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
        params = await self._get_params(request)

        async def _to_async_generator():
            # Use the async client API (as _stream_chat_completion does) so that
            # waiting on the network does not block the event loop. The previous
            # synchronous ``create`` + ``for`` loop stalled every other coroutine
            # for the duration of the stream.
            stream = self._get_client().completion.acreate(**params)
            async for chunk in stream:
                yield chunk

        async for chunk in process_completion_stream_response(_to_async_generator()):
            yield chunk

    def _build_options(
        self,
        sampling_params: Optional[SamplingParams],
        fmt: Optional[ResponseFormat],
        logprobs: Optional[LogProbConfig],
    ) -> dict:
        """Translate Llama Stack sampling/format/logprob settings into Fireworks options.

        ``max_tokens`` defaults to 512 when the sampling options do not already set it.

        Raises:
            ValueError: on an unknown response format type, or when
                ``logprobs.top_k`` is outside ``0 < top_k < 5``.
        """
        options = get_sampling_options(sampling_params)
        options.setdefault("max_tokens", 512)

        if fmt:
            if fmt.type == ResponseFormatType.json_schema.value:
                options["response_format"] = {
                    "type": "json_object",
                    "schema": fmt.json_schema,
                }
            elif fmt.type == ResponseFormatType.grammar.value:
                options["response_format"] = {
                    "type": "grammar",
                    "grammar": fmt.bnf,
                }
            else:
                raise ValueError(f"Unknown response format {fmt.type}")

        if logprobs and logprobs.top_k:
            options["logprobs"] = logprobs.top_k
            if options["logprobs"] <= 0 or options["logprobs"] >= 5:
                raise ValueError("Required range: 0 < top_k < 5")

        return options

    async def chat_completion(
        self,
        model_id: str,
        messages: List[Message],
        sampling_params: Optional[SamplingParams] = None,
        tools: Optional[List[ToolDefinition]] = None,
        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
        tool_prompt_format: Optional[ToolPromptFormat] = None,
        response_format: Optional[ResponseFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
        tool_config: Optional[ToolConfig] = None,
    ) -> AsyncGenerator:
        """Run a Llama Stack chat completion against Fireworks.

        Only ``tool_config`` is forwarded into the request. ``tool_choice`` and
        ``tool_prompt_format`` are not read here.
        NOTE(review): presumably callers fold them into ``tool_config`` upstream; confirm.
        Returns an async generator when ``stream`` is truthy, otherwise the awaited
        ``ChatCompletionResponse``.
        """
        if sampling_params is None:
            sampling_params = SamplingParams()
        model = await self.model_store.get_model(model_id)
        request = ChatCompletionRequest(
            model=model.provider_resource_id,
            messages=messages,
            sampling_params=sampling_params,
            tools=tools or [],
            response_format=response_format,
            stream=stream,
            logprobs=logprobs,
            tool_config=tool_config,
        )
        if stream:
            return self._stream_chat_completion(request)
        else:
            return await self._nonstream_chat_completion(request)

    async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
        """Call the chat endpoint when ``_get_params`` built ``messages``, else the prompt endpoint."""
        params = await self._get_params(request)
        if "messages" in params:
            r = await self._get_client().chat.completions.acreate(**params)
        else:
            r = await self._get_client().completion.acreate(**params)
        return process_chat_completion_response(r, request)

    async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
        """Streaming counterpart of ``_nonstream_chat_completion``."""
        params = await self._get_params(request)

        async def _to_async_generator():
            if "messages" in params:
                stream = self._get_client().chat.completions.acreate(**params)
            else:
                stream = self._get_client().completion.acreate(**params)
            async for chunk in stream:
                yield chunk

        stream = _to_async_generator()
        async for chunk in process_chat_completion_stream_response(stream, request):
            yield chunk

    async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict:
        """Build the keyword arguments for a native Fireworks client call.

        Chat requests with media, or for models that have no Llama model mapping,
        are sent as OpenAI-style ``messages``. All other requests are rendered
        into a single raw ``prompt`` with any leading BOS token removed.
        """
        input_dict: Dict[str, Any] = {}
        media_present = request_has_media(request)
        llama_model = self.get_llama_model(request.model)
        if isinstance(request, ChatCompletionRequest):
            if media_present or not llama_model:
                input_dict["messages"] = [
                    await convert_message_to_openai_dict(m, download=True) for m in request.messages
                ]
            else:
                input_dict["prompt"] = await chat_completion_request_to_prompt(request, llama_model)
        else:
            assert not media_present, "Fireworks does not support media for Completion requests"
            input_dict["prompt"] = await completion_request_to_prompt(request)

        # Fireworks always prepends with BOS
        if "prompt" in input_dict:
            input_dict["prompt"] = self._strip_bos(input_dict["prompt"])

        params = {
            "model": request.model,
            **input_dict,
            "stream": request.stream,
            **self._build_options(request.sampling_params, request.response_format, request.logprobs),
        }
        logger.debug(f"params to fireworks: {params}")

        return params

    async def embeddings(
        self,
        model_id: str,
        contents: List[str] | List[InterleavedContentItem],
        text_truncation: Optional[TextTruncation] = TextTruncation.none,
        output_dimension: Optional[int] = None,
        task_type: Optional[EmbeddingTaskType] = None,
    ) -> EmbeddingsResponse:
        """Embed ``contents`` with a Fireworks embedding model.

        The requested ``dimensions`` come from the model's ``embedding_dimension``
        metadata. ``text_truncation``, ``output_dimension`` and ``task_type`` are
        accepted for interface compatibility but are not used here.
        NOTE(review): the Fireworks client call below is synchronous and runs on
        the event loop.
        """
        model = await self.model_store.get_model(model_id)

        kwargs = {}
        if model.metadata.get("embedding_dimension"):
            kwargs["dimensions"] = model.metadata.get("embedding_dimension")
        assert all(not content_has_media(content) for content in contents), (
            "Fireworks does not support media for embeddings"
        )
        response = self._get_client().embeddings.create(
            model=model.provider_resource_id,
            input=[interleaved_content_as_str(content) for content in contents],
            **kwargs,
        )

        embeddings = [data.embedding for data in response.data]
        return EmbeddingsResponse(embeddings=embeddings)

    async def openai_completion(
        self,
        model: str,
        prompt: Union[str, List[str], List[int], List[List[int]]],
        best_of: Optional[int] = None,
        echo: Optional[bool] = None,
        frequency_penalty: Optional[float] = None,
        logit_bias: Optional[Dict[str, float]] = None,
        logprobs: Optional[bool] = None,
        max_tokens: Optional[int] = None,
        n: Optional[int] = None,
        presence_penalty: Optional[float] = None,
        seed: Optional[int] = None,
        stop: Optional[Union[str, List[str]]] = None,
        stream: Optional[bool] = None,
        stream_options: Optional[Dict[str, Any]] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        user: Optional[str] = None,
        guided_choice: Optional[List[str]] = None,
        prompt_logprobs: Optional[int] = None,
    ) -> OpenAICompletion:
        """Proxy an OpenAI-style completion to Fireworks' OpenAI-compatible API.

        ``guided_choice`` and ``prompt_logprobs`` are accepted but not forwarded.
        """
        model_obj = await self.model_store.get_model(model)

        # Fireworks always prepends with BOS, so strip a caller-supplied one from
        # every text prompt (single string or batch of strings). Token-id prompts
        # are passed through untouched.
        if isinstance(prompt, str):
            prompt = self._strip_bos(prompt)
        elif isinstance(prompt, list) and all(isinstance(p, str) for p in prompt):
            prompt = [self._strip_bos(p) for p in prompt]

        params = await prepare_openai_completion_params(
            model=model_obj.provider_resource_id,
            prompt=prompt,
            best_of=best_of,
            echo=echo,
            frequency_penalty=frequency_penalty,
            logit_bias=logit_bias,
            logprobs=logprobs,
            max_tokens=max_tokens,
            n=n,
            presence_penalty=presence_penalty,
            seed=seed,
            stop=stop,
            stream=stream,
            stream_options=stream_options,
            temperature=temperature,
            top_p=top_p,
            user=user,
        )

        return await self._get_openai_client().completions.create(**params)

    async def openai_chat_completion(
        self,
        model: str,
        messages: List[OpenAIMessageParam],
        frequency_penalty: Optional[float] = None,
        function_call: Optional[Union[str, Dict[str, Any]]] = None,
        functions: Optional[List[Dict[str, Any]]] = None,
        logit_bias: Optional[Dict[str, float]] = None,
        logprobs: Optional[bool] = None,
        max_completion_tokens: Optional[int] = None,
        max_tokens: Optional[int] = None,
        n: Optional[int] = None,
        parallel_tool_calls: Optional[bool] = None,
        presence_penalty: Optional[float] = None,
        response_format: Optional[OpenAIResponseFormatParam] = None,
        seed: Optional[int] = None,
        stop: Optional[Union[str, List[str]]] = None,
        stream: Optional[bool] = None,
        stream_options: Optional[Dict[str, Any]] = None,
        temperature: Optional[float] = None,
        tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
        tools: Optional[List[Dict[str, Any]]] = None,
        top_logprobs: Optional[int] = None,
        top_p: Optional[float] = None,
        user: Optional[str] = None,
    ) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
        """Serve an OpenAI-style chat completion.

        Llama models are routed through ``OpenAIChatCompletionToLlamaStackMixin``
        with the original, strongly-typed arguments. Other models are proxied to
        Fireworks' OpenAI-compatible chat endpoint.
        """
        model_obj = await self.model_store.get_model(model)

        # Divert Llama Models through Llama Stack inference APIs because
        # Fireworks chat completions OpenAI-compatible API does not support
        # tool calls properly.
        llama_model = self.get_llama_model(model_obj.provider_resource_id)
        if llama_model:
            # Pass the Llama Stack model id (not the provider id); the mixin
            # resolves it itself.
            return await OpenAIChatCompletionToLlamaStackMixin.openai_chat_completion(
                self,
                model=model,
                messages=messages,
                frequency_penalty=frequency_penalty,
                function_call=function_call,
                functions=functions,
                logit_bias=logit_bias,
                logprobs=logprobs,
                max_completion_tokens=max_completion_tokens,
                max_tokens=max_tokens,
                n=n,
                parallel_tool_calls=parallel_tool_calls,
                presence_penalty=presence_penalty,
                response_format=response_format,
                seed=seed,
                stop=stop,
                stream=stream,
                stream_options=stream_options,
                temperature=temperature,
                tool_choice=tool_choice,
                tools=tools,
                top_logprobs=top_logprobs,
                top_p=top_p,
                user=user,
            )

        params = await prepare_openai_completion_params(
            messages=messages,
            frequency_penalty=frequency_penalty,
            function_call=function_call,
            functions=functions,
            logit_bias=logit_bias,
            logprobs=logprobs,
            max_completion_tokens=max_completion_tokens,
            max_tokens=max_tokens,
            n=n,
            parallel_tool_calls=parallel_tool_calls,
            presence_penalty=presence_penalty,
            response_format=response_format,
            seed=seed,
            stop=stop,
            stream=stream,
            stream_options=stream_options,
            temperature=temperature,
            tool_choice=tool_choice,
            tools=tools,
            top_logprobs=top_logprobs,
            top_p=top_p,
            user=user,
        )

        return await self._get_openai_client().chat.completions.create(model=model_obj.provider_resource_id, **params)