llama-stack/llama_stack/providers/utils/inference/litellm_openai_mixin.py
Ben Browning 7641a5cd0b
fix: 100% OpenAI API verification for together and fireworks (#1946)
# What does this PR do?

TLDR: Changes needed to get 100% passing tests for OpenAI API
verification tests when run against Llama Stack with the `together`,
`fireworks`, and `openai` providers. And `groq` is better than before,
at 88% passing.

This cleans up the OpenAI API support for image message types
(specifically `image_url` types) and handling of the `response_format`
chat completion parameter. Both of these required a few more Pydantic
model definitions in our Inference API, just to move from the
not-quite-right stubs I had in place to something fleshed out to match
the actual OpenAI API specs.

As part of testing this, I also found and fixed a bug in the litellm
implementation of openai_completion and openai_chat_completion, so the
providers based on those should actually be working now.

The method `prepare_openai_completion_params` in
`llama_stack/providers/utils/inference/openai_compat.py` was improved to
actually recursively clean up input parameters, including handling of
lists, dicts, and dumping of Pydantic models to dicts. These changes
were required to get to 100% passing tests on the OpenAI API
verification against the `openai` provider.

With the above, the together.ai provider was passing as well as it is
without Llama Stack. But, since we have Llama Stack in the middle, I
took the opportunity to clean up the together.ai provider so that it now
also passes the OpenAI API spec tests we have at 100%. That means
together.ai is now passing our verification test better when using an
OpenAI client talking to Llama Stack than it is when hitting together.ai
directly, without Llama Stack in the middle.

And, another round of work for Fireworks to improve translation of
incoming OpenAI chat completion requests to Llama Stack chat completion
requests gets the fireworks provider passing at 100%. The server-side
fireworks.ai tool calling support with OpenAI chat completions and Llama
4 models isn't great yet, but by pointing the OpenAI clients at Llama
Stack's API we can clean things up and get everything working as
expected for Llama 4 models.

## Test Plan

### OpenAI API Verification Tests

I ran the OpenAI API verification tests as below and 100% of the tests
passed.

First, start a Llama Stack server that runs the `openai` provider with
the `gpt-4o` and `gpt-4o-mini` models deployed. There's not a template
setup to do this out of the box, so I added a
`tests/verifications/openai-api-verification-run.yaml` to do this.

First, ensure you have the necessary API key environment variables set:

```
export TOGETHER_API_KEY="..."
export FIREWORKS_API_KEY="..."
export OPENAI_API_KEY="..."
```

Then, run a Llama Stack server that serves up all these providers:

```
llama stack run \
      --image-type venv \
      tests/verifications/openai-api-verification-run.yaml
```

Finally, generate a new verification report against all these providers,
both with and without the Llama Stack server in the middle.

```
python tests/verifications/generate_report.py \
      --run-tests \
      --provider \
        together \
        fireworks \
        groq \
        openai \
        together-llama-stack \
        fireworks-llama-stack \
        groq-llama-stack \
        openai-llama-stack
```

You'll see that most of the configurations with Llama Stack in the
middle now pass at 100%, even though some of them do not pass at 100%
when hitting the backend provider's API directly with an OpenAI client.

### OpenAI Completion Integration Tests with vLLM:

I also ran the smaller `test_openai_completion.py` test suite (that's
not yet merged with the verification tests) on multiple of the
providers, since I had to adjust the method signature of
openai_chat_completion a bit and thus had to touch lots of these
providers to match. Here's the tests I ran there, all passing:

```
VLLM_URL="http://localhost:8000/v1" INFERENCE_MODEL="meta-llama/Llama-3.2-3B-Instruct" llama stack build --template remote-vllm --image-type venv --run
```

in another terminal

```
LLAMA_STACK_CONFIG=http://localhost:8321 INFERENCE_MODEL="meta-llama/Llama-3.2-3B-Instruct" python -m pytest -v tests/integration/inference/test_openai_completion.py --text-model "meta-llama/Llama-3.2-3B-Instruct"
```

### OpenAI Completion Integration Tests with ollama

```
INFERENCE_MODEL="llama3.2:3b-instruct-q8_0" llama stack build --template ollama --image-type venv --run
```

in another terminal

```
LLAMA_STACK_CONFIG=http://localhost:8321 INFERENCE_MODEL="llama3.2:3b-instruct-q8_0" python -m pytest -v tests/integration/inference/test_openai_completion.py --text-model "llama3.2:3b-instruct-q8_0"
```

### OpenAI Completion Integration Tests with together.ai

```
INFERENCE_MODEL="meta-llama/Llama-3.2-3B-Instruct-Turbo" llama stack build --template together --image-type venv --run
```

in another terminal

```
LLAMA_STACK_CONFIG=http://localhost:8321 INFERENCE_MODEL="meta-llama/Llama-3.2-3B-Instruct-Turbo" python -m pytest -v tests/integration/inference/test_openai_completion.py --text-model "meta-llama/Llama-3.2-3B-Instruct-Turbo"
```

### OpenAI Completion Integration Tests with fireworks.ai

```
INFERENCE_MODEL="meta-llama/Llama-3.1-8B-Instruct" llama stack build --template fireworks --image-type venv --run
```

in another terminal

```
LLAMA_STACK_CONFIG=http://localhost:8321 INFERENCE_MODEL="meta-llama/Llama-3.1-8B-Instruct" python -m pytest -v tests/integration/inference/test_openai_completion.py --text-model "meta-llama/Llama-3.1-8B-Instruct"

---------

Signed-off-by: Ben Browning <bbrownin@redhat.com>
2025-04-14 08:56:29 -07:00

377 lines
14 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
import litellm
from llama_stack.apis.common.content_types import (
InterleavedContent,
InterleavedContentItem,
)
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseStreamChunk,
EmbeddingsResponse,
EmbeddingTaskType,
Inference,
JsonSchemaResponseFormat,
LogProbConfig,
Message,
ResponseFormat,
SamplingParams,
TextTruncation,
ToolChoice,
ToolConfig,
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIMessageParam,
OpenAIResponseFormatParam,
)
from llama_stack.apis.models.models import Model
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.openai_compat import (
convert_message_to_openai_dict_new,
convert_openai_chat_completion_choice,
convert_openai_chat_completion_stream,
convert_tooldef_to_openai_tool,
get_sampling_options,
prepare_openai_completion_params,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str,
)
logger = get_logger(name=__name__, category="inference")
class LiteLLMOpenAIMixin(
ModelRegistryHelper,
Inference,
NeedsRequestProviderData,
):
def __init__(
self,
model_entries,
api_key_from_config: Optional[str],
provider_data_api_key_field: str,
openai_compat_api_base: str | None = None,
):
ModelRegistryHelper.__init__(self, model_entries)
self.api_key_from_config = api_key_from_config
self.provider_data_api_key_field = provider_data_api_key_field
self.api_base = openai_compat_api_base
if openai_compat_api_base:
self.is_openai_compat = True
else:
self.is_openai_compat = False
async def initialize(self):
pass
async def shutdown(self):
pass
async def register_model(self, model: Model) -> Model:
model_id = self.get_provider_model_id(model.provider_resource_id)
if model_id is None:
raise ValueError(f"Unsupported model: {model.provider_resource_id}")
return model
async def completion(
self,
model_id: str,
content: InterleavedContent,
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
raise NotImplementedError("LiteLLM does not support completion requests")
async def chat_completion(
self,
model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
if sampling_params is None:
sampling_params = SamplingParams()
model = await self.model_store.get_model(model_id)
request = ChatCompletionRequest(
model=model.provider_resource_id,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
response_format=response_format,
stream=stream,
logprobs=logprobs,
tool_config=tool_config,
)
params = await self._get_params(request)
if self.is_openai_compat:
params["model"] = "openai/" + params["model"]
logger.debug(f"params to litellm (openai compat): {params}")
# unfortunately, we need to use synchronous litellm.completion here because litellm
# caches various httpx.client objects in a non-eventloop aware manner
response = litellm.completion(**params)
if stream:
return self._stream_chat_completion(response)
else:
return convert_openai_chat_completion_choice(response.choices[0])
async def _stream_chat_completion(
self, response: litellm.ModelResponse
) -> AsyncIterator[ChatCompletionResponseStreamChunk]:
async def _stream_generator():
for chunk in response:
yield chunk
async for chunk in convert_openai_chat_completion_stream(
_stream_generator(), enable_incremental_tool_calls=True
):
yield chunk
def _add_additional_properties_recursive(self, schema):
"""
Recursively add additionalProperties: False to all object schemas
"""
if isinstance(schema, dict):
if schema.get("type") == "object":
schema["additionalProperties"] = False
# Add required field with all property keys if properties exist
if "properties" in schema and schema["properties"]:
schema["required"] = list(schema["properties"].keys())
if "properties" in schema:
for prop_schema in schema["properties"].values():
self._add_additional_properties_recursive(prop_schema)
for key in ["anyOf", "allOf", "oneOf"]:
if key in schema:
for sub_schema in schema[key]:
self._add_additional_properties_recursive(sub_schema)
if "not" in schema:
self._add_additional_properties_recursive(schema["not"])
# Handle $defs/$ref
if "$defs" in schema:
for def_schema in schema["$defs"].values():
self._add_additional_properties_recursive(def_schema)
return schema
async def _get_params(self, request: ChatCompletionRequest) -> dict:
input_dict = {}
input_dict["messages"] = [await convert_message_to_openai_dict_new(m) for m in request.messages]
if fmt := request.response_format:
if not isinstance(fmt, JsonSchemaResponseFormat):
raise ValueError(
f"Unsupported response format: {type(fmt)}. Only JsonSchemaResponseFormat is supported."
)
fmt = fmt.json_schema
name = fmt["title"]
del fmt["title"]
fmt["additionalProperties"] = False
# Apply additionalProperties: False recursively to all objects
fmt = self._add_additional_properties_recursive(fmt)
input_dict["response_format"] = {
"type": "json_schema",
"json_schema": {
"name": name,
"schema": fmt,
"strict": True,
},
}
if request.tools:
input_dict["tools"] = [convert_tooldef_to_openai_tool(tool) for tool in request.tools]
if request.tool_config.tool_choice:
input_dict["tool_choice"] = (
request.tool_config.tool_choice.value
if isinstance(request.tool_config.tool_choice, ToolChoice)
else request.tool_config.tool_choice
)
provider_data = self.get_request_provider_data()
key_field = self.provider_data_api_key_field
if provider_data and getattr(provider_data, key_field, None):
api_key = getattr(provider_data, key_field)
else:
api_key = self.api_key_from_config
return {
"model": request.model,
"api_key": api_key,
"api_base": self.api_base,
**input_dict,
"stream": request.stream,
**get_sampling_options(request.sampling_params),
}
async def embeddings(
self,
model_id: str,
contents: List[str] | List[InterleavedContentItem],
text_truncation: Optional[TextTruncation] = TextTruncation.none,
output_dimension: Optional[int] = None,
task_type: Optional[EmbeddingTaskType] = None,
) -> EmbeddingsResponse:
model = await self.model_store.get_model(model_id)
response = litellm.embedding(
model=model.provider_resource_id,
input=[interleaved_content_as_str(content) for content in contents],
)
embeddings = [data["embedding"] for data in response["data"]]
return EmbeddingsResponse(embeddings=embeddings)
async def openai_completion(
self,
model: str,
prompt: Union[str, List[str], List[int], List[List[int]]],
best_of: Optional[int] = None,
echo: Optional[bool] = None,
frequency_penalty: Optional[float] = None,
logit_bias: Optional[Dict[str, float]] = None,
logprobs: Optional[bool] = None,
max_tokens: Optional[int] = None,
n: Optional[int] = None,
presence_penalty: Optional[float] = None,
seed: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stream: Optional[bool] = None,
stream_options: Optional[Dict[str, Any]] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
user: Optional[str] = None,
guided_choice: Optional[List[str]] = None,
prompt_logprobs: Optional[int] = None,
) -> OpenAICompletion:
model_obj = await self.model_store.get_model(model)
params = await prepare_openai_completion_params(
model=model_obj.provider_resource_id,
prompt=prompt,
best_of=best_of,
echo=echo,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
logprobs=logprobs,
max_tokens=max_tokens,
n=n,
presence_penalty=presence_penalty,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
top_p=top_p,
user=user,
guided_choice=guided_choice,
prompt_logprobs=prompt_logprobs,
)
return await litellm.atext_completion(**params)
async def openai_chat_completion(
self,
model: str,
messages: List[OpenAIMessageParam],
frequency_penalty: Optional[float] = None,
function_call: Optional[Union[str, Dict[str, Any]]] = None,
functions: Optional[List[Dict[str, Any]]] = None,
logit_bias: Optional[Dict[str, float]] = None,
logprobs: Optional[bool] = None,
max_completion_tokens: Optional[int] = None,
max_tokens: Optional[int] = None,
n: Optional[int] = None,
parallel_tool_calls: Optional[bool] = None,
presence_penalty: Optional[float] = None,
response_format: Optional[OpenAIResponseFormatParam] = None,
seed: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stream: Optional[bool] = None,
stream_options: Optional[Dict[str, Any]] = None,
temperature: Optional[float] = None,
tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
tools: Optional[List[Dict[str, Any]]] = None,
top_logprobs: Optional[int] = None,
top_p: Optional[float] = None,
user: Optional[str] = None,
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
model_obj = await self.model_store.get_model(model)
params = await prepare_openai_completion_params(
model=model_obj.provider_resource_id,
messages=messages,
frequency_penalty=frequency_penalty,
function_call=function_call,
functions=functions,
logit_bias=logit_bias,
logprobs=logprobs,
max_completion_tokens=max_completion_tokens,
max_tokens=max_tokens,
n=n,
parallel_tool_calls=parallel_tool_calls,
presence_penalty=presence_penalty,
response_format=response_format,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
tool_choice=tool_choice,
tools=tools,
top_logprobs=top_logprobs,
top_p=top_p,
user=user,
)
return await litellm.acompletion(**params)
async def batch_completion(
self,
model_id: str,
content_batch: List[InterleavedContent],
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
logprobs: Optional[LogProbConfig] = None,
):
raise NotImplementedError("Batch completion is not supported for OpenAI Compat")
async def batch_chat_completion(
self,
model_id: str,
messages_batch: List[List[Message]],
sampling_params: Optional[SamplingParams] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_config: Optional[ToolConfig] = None,
response_format: Optional[ResponseFormat] = None,
logprobs: Optional[LogProbConfig] = None,
):
raise NotImplementedError("Batch chat completion is not supported for OpenAI Compat")