diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html
index 348c5d869..124e6b0fa 100644
--- a/docs/_static/llama-stack-spec.html
+++ b/docs/_static/llama-stack-spec.html
@@ -8852,9 +8852,9 @@
"tool_calls": {
"type": "array",
"items": {
- "$ref": "#/components/schemas/ToolCall"
+ "$ref": "#/components/schemas/OpenAIChatCompletionToolCall"
},
- "description": "List of tool calls. Each tool call is a ToolCall object."
+ "description": "List of tool calls. Each tool call is an OpenAIChatCompletionToolCall object."
}
},
"additionalProperties": false,
@@ -8920,6 +8920,46 @@
],
"title": "OpenAIChatCompletionContentPartTextParam"
},
+ "OpenAIChatCompletionToolCall": {
+ "type": "object",
+ "properties": {
+ "id": {
+ "type": "string"
+ },
+ "type": {
+ "type": "string",
+ "const": "function",
+ "default": "function"
+ },
+ "function": {
+ "$ref": "#/components/schemas/OpenAIChatCompletionToolCallFunction"
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "id",
+ "type",
+ "function"
+ ],
+ "title": "OpenAIChatCompletionToolCall"
+ },
+ "OpenAIChatCompletionToolCallFunction": {
+ "type": "object",
+ "properties": {
+ "name": {
+ "type": "string"
+ },
+ "arguments": {
+ "type": "string"
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "name",
+ "arguments"
+ ],
+ "title": "OpenAIChatCompletionToolCallFunction"
+ },
"OpenAIDeveloperMessageParam": {
"type": "object",
"properties": {
diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml
index 18e39601d..781fbc618 100644
--- a/docs/_static/llama-stack-spec.yaml
+++ b/docs/_static/llama-stack-spec.yaml
@@ -6074,9 +6074,10 @@ components:
tool_calls:
type: array
items:
- $ref: '#/components/schemas/ToolCall'
+ $ref: '#/components/schemas/OpenAIChatCompletionToolCall'
description: >-
- List of tool calls. Each tool call is a ToolCall object.
+ List of tool calls. Each tool call is an OpenAIChatCompletionToolCall
+ object.
additionalProperties: false
required:
- role
@@ -6123,6 +6124,35 @@ components:
- type
- text
title: OpenAIChatCompletionContentPartTextParam
+ OpenAIChatCompletionToolCall:
+ type: object
+ properties:
+ id:
+ type: string
+ type:
+ type: string
+ const: function
+ default: function
+ function:
+ $ref: '#/components/schemas/OpenAIChatCompletionToolCallFunction'
+ additionalProperties: false
+ required:
+ - id
+ - type
+ - function
+ title: OpenAIChatCompletionToolCall
+ OpenAIChatCompletionToolCallFunction:
+ type: object
+ properties:
+ name:
+ type: string
+ arguments:
+ type: string
+ additionalProperties: false
+ required:
+ - name
+ - arguments
+ title: OpenAIChatCompletionToolCallFunction
OpenAIDeveloperMessageParam:
type: object
properties:
diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py
index f843041b3..1f3e64dd6 100644
--- a/llama_stack/apis/inference/inference.py
+++ b/llama_stack/apis/inference/inference.py
@@ -501,6 +501,19 @@ class OpenAISystemMessageParam(BaseModel):
name: Optional[str] = None
+@json_schema_type
+class OpenAIChatCompletionToolCallFunction(BaseModel):
+ name: str
+ arguments: str
+
+
+@json_schema_type
+class OpenAIChatCompletionToolCall(BaseModel):
+ id: str
+ type: Literal["function"] = "function"
+ function: OpenAIChatCompletionToolCallFunction
+
+
@json_schema_type
class OpenAIAssistantMessageParam(BaseModel):
"""A message containing the model's (assistant) response in an OpenAI-compatible chat completion request.
@@ -508,13 +521,13 @@ class OpenAIAssistantMessageParam(BaseModel):
:param role: Must be "assistant" to identify this as the model's response
:param content: The content of the model's response
:param name: (Optional) The name of the assistant message participant.
- :param tool_calls: List of tool calls. Each tool call is a ToolCall object.
+ :param tool_calls: List of tool calls. Each tool call is an OpenAIChatCompletionToolCall object.
"""
role: Literal["assistant"] = "assistant"
content: OpenAIChatCompletionMessageContent
name: Optional[str] = None
- tool_calls: Optional[List[ToolCall]] = Field(default_factory=list)
+ tool_calls: Optional[List[OpenAIChatCompletionToolCall]] = Field(default_factory=list)
@json_schema_type
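For orientation, the two models added above map onto the OpenAI tool-call wire format (an id, a literal "function" type, and a function object with a name and a JSON-encoded arguments string). The following is a minimal standalone sketch, not part of the patch: it re-declares the models with plain Pydantic, keeping the field names from the patch while everything else is illustrative.

# Standalone sketch mirroring the OpenAIChatCompletionToolCall models added above.
# Field names follow the patch; the example values are made up.
from typing import Literal

from pydantic import BaseModel


class OpenAIChatCompletionToolCallFunction(BaseModel):
    name: str
    arguments: str  # JSON-encoded arguments string, as in the OpenAI API


class OpenAIChatCompletionToolCall(BaseModel):
    id: str
    type: Literal["function"] = "function"
    function: OpenAIChatCompletionToolCallFunction


call = OpenAIChatCompletionToolCall(
    id="call_123",
    function=OpenAIChatCompletionToolCallFunction(
        name="get_weather",
        arguments='{"city": "Tokyo"}',
    ),
)
# Pydantic v2 serialization; produces {"id": ..., "type": "function", "function": {"name": ..., "arguments": ...}}
print(call.model_dump_json())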
diff --git a/llama_stack/models/llama/llama3/tool_utils.py b/llama_stack/models/llama/llama3/tool_utils.py
index ef39ba0a5..91b46ec98 100644
--- a/llama_stack/models/llama/llama3/tool_utils.py
+++ b/llama_stack/models/llama/llama3/tool_utils.py
@@ -204,7 +204,9 @@ class ToolUtils:
return None
elif is_json(message_body):
response = json.loads(message_body)
- if ("type" in response and response["type"] == "function") or ("name" in response):
+ if ("type" in response and response["type"] == "function") or (
+ "name" in response and "parameters" in response
+ ):
function_name = response["name"]
args = response["parameters"]
return function_name, args
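The change above tightens how ToolUtils decides that a JSON message body is a tool call: a bare {"name": ...} object no longer qualifies unless it also carries "parameters" (or an explicit "type": "function"). A small self-contained sketch of just that predicate, with illustrative inputs:

# Standalone sketch of the tightened detection logic above (not the ToolUtils code itself).
import json


def looks_like_tool_call(message_body: str) -> bool:
    try:
        response = json.loads(message_body)
    except json.JSONDecodeError:
        return False
    return ("type" in response and response["type"] == "function") or (
        "name" in response and "parameters" in response
    )


# A real tool call: has both "name" and "parameters".
print(looks_like_tool_call('{"name": "get_weather", "parameters": {"city": "Tokyo"}}'))  # True
# Ordinary JSON output that merely contains a "name" key is no longer misclassified.
print(looks_like_tool_call('{"name": "Alice", "age": 30}'))  # False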
diff --git a/llama_stack/providers/inline/inference/meta_reference/inference.py b/llama_stack/providers/inline/inference/meta_reference/inference.py
index 0b56ba1f7..2b9a27982 100644
--- a/llama_stack/providers/inline/inference/meta_reference/inference.py
+++ b/llama_stack/providers/inline/inference/meta_reference/inference.py
@@ -59,8 +59,8 @@ from llama_stack.providers.utils.inference.model_registry import (
build_hf_repo_model_entry,
)
from llama_stack.providers.utils.inference.openai_compat import (
- OpenAIChatCompletionUnsupportedMixin,
- OpenAICompletionUnsupportedMixin,
+ OpenAIChatCompletionToLlamaStackMixin,
+ OpenAICompletionToLlamaStackMixin,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
augment_content_with_response_format_prompt,
@@ -83,8 +83,8 @@ def llama_builder_fn(config: MetaReferenceInferenceConfig, model_id: str, llama_
class MetaReferenceInferenceImpl(
- OpenAICompletionUnsupportedMixin,
- OpenAIChatCompletionUnsupportedMixin,
+ OpenAICompletionToLlamaStackMixin,
+ OpenAIChatCompletionToLlamaStackMixin,
SentenceTransformerEmbeddingMixin,
Inference,
ModelsProtocolPrivate,
diff --git a/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py b/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py
index 5bc20e3c2..d717d055f 100644
--- a/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py
+++ b/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py
@@ -25,8 +25,8 @@ from llama_stack.providers.utils.inference.embedding_mixin import (
SentenceTransformerEmbeddingMixin,
)
from llama_stack.providers.utils.inference.openai_compat import (
- OpenAIChatCompletionUnsupportedMixin,
- OpenAICompletionUnsupportedMixin,
+ OpenAIChatCompletionToLlamaStackMixin,
+ OpenAICompletionToLlamaStackMixin,
)
from .config import SentenceTransformersInferenceConfig
@@ -35,8 +35,8 @@ log = logging.getLogger(__name__)
class SentenceTransformersInferenceImpl(
- OpenAIChatCompletionUnsupportedMixin,
- OpenAICompletionUnsupportedMixin,
+ OpenAIChatCompletionToLlamaStackMixin,
+ OpenAICompletionToLlamaStackMixin,
SentenceTransformerEmbeddingMixin,
Inference,
ModelsProtocolPrivate,
diff --git a/llama_stack/providers/inline/inference/vllm/vllm.py b/llama_stack/providers/inline/inference/vllm/vllm.py
index 085c79d6b..9d742c39c 100644
--- a/llama_stack/providers/inline/inference/vllm/vllm.py
+++ b/llama_stack/providers/inline/inference/vllm/vllm.py
@@ -66,10 +66,10 @@ from llama_stack.providers.utils.inference.model_registry import (
ModelsProtocolPrivate,
)
from llama_stack.providers.utils.inference.openai_compat import (
- OpenAIChatCompletionUnsupportedMixin,
+ OpenAIChatCompletionToLlamaStackMixin,
OpenAICompatCompletionChoice,
OpenAICompatCompletionResponse,
- OpenAICompletionUnsupportedMixin,
+ OpenAICompletionToLlamaStackMixin,
get_stop_reason,
process_chat_completion_stream_response,
)
@@ -176,8 +176,8 @@ def _convert_sampling_params(
class VLLMInferenceImpl(
Inference,
- OpenAIChatCompletionUnsupportedMixin,
- OpenAICompletionUnsupportedMixin,
+ OpenAIChatCompletionToLlamaStackMixin,
+ OpenAICompletionToLlamaStackMixin,
ModelsProtocolPrivate,
):
"""
diff --git a/llama_stack/providers/remote/inference/bedrock/bedrock.py b/llama_stack/providers/remote/inference/bedrock/bedrock.py
index 0a485da8f..f8dbcf31a 100644
--- a/llama_stack/providers/remote/inference/bedrock/bedrock.py
+++ b/llama_stack/providers/remote/inference/bedrock/bedrock.py
@@ -36,10 +36,10 @@ from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
)
from llama_stack.providers.utils.inference.openai_compat import (
- OpenAIChatCompletionUnsupportedMixin,
+ OpenAIChatCompletionToLlamaStackMixin,
OpenAICompatCompletionChoice,
OpenAICompatCompletionResponse,
- OpenAICompletionUnsupportedMixin,
+ OpenAICompletionToLlamaStackMixin,
get_sampling_strategy_options,
process_chat_completion_response,
process_chat_completion_stream_response,
@@ -56,8 +56,8 @@ from .models import MODEL_ENTRIES
class BedrockInferenceAdapter(
ModelRegistryHelper,
Inference,
- OpenAIChatCompletionUnsupportedMixin,
- OpenAICompletionUnsupportedMixin,
+ OpenAIChatCompletionToLlamaStackMixin,
+ OpenAICompletionToLlamaStackMixin,
):
def __init__(self, config: BedrockConfig) -> None:
ModelRegistryHelper.__init__(self, MODEL_ENTRIES)
diff --git a/llama_stack/providers/remote/inference/cerebras/cerebras.py b/llama_stack/providers/remote/inference/cerebras/cerebras.py
index 5e0a5b484..3156601be 100644
--- a/llama_stack/providers/remote/inference/cerebras/cerebras.py
+++ b/llama_stack/providers/remote/inference/cerebras/cerebras.py
@@ -34,8 +34,8 @@ from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
)
from llama_stack.providers.utils.inference.openai_compat import (
- OpenAIChatCompletionUnsupportedMixin,
- OpenAICompletionUnsupportedMixin,
+ OpenAIChatCompletionToLlamaStackMixin,
+ OpenAICompletionToLlamaStackMixin,
get_sampling_options,
process_chat_completion_response,
process_chat_completion_stream_response,
@@ -54,8 +54,8 @@ from .models import MODEL_ENTRIES
class CerebrasInferenceAdapter(
ModelRegistryHelper,
Inference,
- OpenAIChatCompletionUnsupportedMixin,
- OpenAICompletionUnsupportedMixin,
+ OpenAIChatCompletionToLlamaStackMixin,
+ OpenAICompletionToLlamaStackMixin,
):
def __init__(self, config: CerebrasImplConfig) -> None:
ModelRegistryHelper.__init__(
diff --git a/llama_stack/providers/remote/inference/databricks/databricks.py b/llama_stack/providers/remote/inference/databricks/databricks.py
index a10878b27..27d96eb7d 100644
--- a/llama_stack/providers/remote/inference/databricks/databricks.py
+++ b/llama_stack/providers/remote/inference/databricks/databricks.py
@@ -34,8 +34,8 @@ from llama_stack.providers.utils.inference.model_registry import (
build_hf_repo_model_entry,
)
from llama_stack.providers.utils.inference.openai_compat import (
- OpenAIChatCompletionUnsupportedMixin,
- OpenAICompletionUnsupportedMixin,
+ OpenAIChatCompletionToLlamaStackMixin,
+ OpenAICompletionToLlamaStackMixin,
get_sampling_options,
process_chat_completion_response,
process_chat_completion_stream_response,
@@ -61,8 +61,8 @@ model_entries = [
class DatabricksInferenceAdapter(
ModelRegistryHelper,
Inference,
- OpenAIChatCompletionUnsupportedMixin,
- OpenAICompletionUnsupportedMixin,
+ OpenAIChatCompletionToLlamaStackMixin,
+ OpenAICompletionToLlamaStackMixin,
):
def __init__(self, config: DatabricksImplConfig) -> None:
ModelRegistryHelper.__init__(self, model_entries=model_entries)
diff --git a/llama_stack/providers/remote/inference/fireworks/fireworks.py b/llama_stack/providers/remote/inference/fireworks/fireworks.py
index 69bf7c863..48c163c87 100644
--- a/llama_stack/providers/remote/inference/fireworks/fireworks.py
+++ b/llama_stack/providers/remote/inference/fireworks/fireworks.py
@@ -45,6 +45,7 @@ from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
)
from llama_stack.providers.utils.inference.openai_compat import (
+ OpenAIChatCompletionToLlamaStackMixin,
convert_message_to_openai_dict,
get_sampling_options,
prepare_openai_completion_params,
@@ -307,6 +308,11 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
prompt_logprobs: Optional[int] = None,
) -> OpenAICompletion:
model_obj = await self.model_store.get_model(model)
+
+ # Fireworks always prepends the prompt with BOS itself, so strip any BOS already present
+ if isinstance(prompt, str) and prompt.startswith("<|begin_of_text|>"):
+ prompt = prompt[len("<|begin_of_text|>") :]
+
params = await prepare_openai_completion_params(
model=model_obj.provider_resource_id,
prompt=prompt,
@@ -326,6 +332,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
top_p=top_p,
user=user,
)
+
return await self._get_openai_client().completions.create(**params)
async def openai_chat_completion(
@@ -356,7 +363,6 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
model_obj = await self.model_store.get_model(model)
params = await prepare_openai_completion_params(
- model=model_obj.provider_resource_id,
messages=messages,
frequency_penalty=frequency_penalty,
function_call=function_call,
@@ -380,4 +386,12 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
top_p=top_p,
user=user,
)
- return await self._get_openai_client().chat.completions.create(**params)
+
+ # Divert Llama models through the Llama Stack inference APIs because
+ # the Fireworks OpenAI-compatible chat completions API does not support
+ # tool calls properly.
+ llama_model = self.get_llama_model(model_obj.provider_resource_id)
+ if llama_model:
+ return await OpenAIChatCompletionToLlamaStackMixin.openai_chat_completion(self, model=model, **params)
+
+ return await self._get_openai_client().chat.completions.create(model=model_obj.provider_resource_id, **params)
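Two behaviors above are easy to miss in the diff: the adapter strips a leading BOS token because Fireworks adds its own, and it routes known Llama models through the Llama Stack chat-completion path instead of the provider's OpenAI endpoint. A minimal sketch of the BOS handling in isolation (the helper name is hypothetical; the token literal matches the patch):

# Hypothetical helper illustrating the BOS stripping done in openai_completion above.
BOS = "<|begin_of_text|>"


def strip_leading_bos(prompt):
    # Fireworks prepends BOS itself, so avoid sending it twice.
    if isinstance(prompt, str) and prompt.startswith(BOS):
        return prompt[len(BOS):]
    return prompt


print(strip_leading_bos("<|begin_of_text|>Hello"))  # "Hello"
print(strip_leading_bos(["<|begin_of_text|>Hello"]))  # non-string prompts pass through unchanged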
diff --git a/llama_stack/providers/remote/inference/runpod/runpod.py b/llama_stack/providers/remote/inference/runpod/runpod.py
index 878460122..72cbead9b 100644
--- a/llama_stack/providers/remote/inference/runpod/runpod.py
+++ b/llama_stack/providers/remote/inference/runpod/runpod.py
@@ -12,8 +12,8 @@ from llama_stack.apis.inference import * # noqa: F403
# from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.openai_compat import (
- OpenAIChatCompletionUnsupportedMixin,
- OpenAICompletionUnsupportedMixin,
+ OpenAIChatCompletionToLlamaStackMixin,
+ OpenAICompletionToLlamaStackMixin,
get_sampling_options,
process_chat_completion_response,
process_chat_completion_stream_response,
@@ -43,8 +43,8 @@ RUNPOD_SUPPORTED_MODELS = {
class RunpodInferenceAdapter(
ModelRegistryHelper,
Inference,
- OpenAIChatCompletionUnsupportedMixin,
- OpenAICompletionUnsupportedMixin,
+ OpenAIChatCompletionToLlamaStackMixin,
+ OpenAICompletionToLlamaStackMixin,
):
def __init__(self, config: RunpodImplConfig) -> None:
ModelRegistryHelper.__init__(self, stack_to_provider_models_map=RUNPOD_SUPPORTED_MODELS)
diff --git a/llama_stack/providers/remote/inference/sambanova/sambanova.py b/llama_stack/providers/remote/inference/sambanova/sambanova.py
index c503657eb..1665e72b8 100644
--- a/llama_stack/providers/remote/inference/sambanova/sambanova.py
+++ b/llama_stack/providers/remote/inference/sambanova/sambanova.py
@@ -42,8 +42,8 @@ from llama_stack.apis.inference import (
)
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.openai_compat import (
- OpenAIChatCompletionUnsupportedMixin,
- OpenAICompletionUnsupportedMixin,
+ OpenAIChatCompletionToLlamaStackMixin,
+ OpenAICompletionToLlamaStackMixin,
process_chat_completion_stream_response,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
@@ -57,8 +57,8 @@ from .models import MODEL_ENTRIES
class SambaNovaInferenceAdapter(
ModelRegistryHelper,
Inference,
- OpenAIChatCompletionUnsupportedMixin,
- OpenAICompletionUnsupportedMixin,
+ OpenAIChatCompletionToLlamaStackMixin,
+ OpenAICompletionToLlamaStackMixin,
):
def __init__(self, config: SambaNovaImplConfig) -> None:
ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)
diff --git a/llama_stack/providers/remote/inference/tgi/tgi.py b/llama_stack/providers/remote/inference/tgi/tgi.py
index 8f5b5e3cc..4ee386a15 100644
--- a/llama_stack/providers/remote/inference/tgi/tgi.py
+++ b/llama_stack/providers/remote/inference/tgi/tgi.py
@@ -40,10 +40,10 @@ from llama_stack.providers.utils.inference.model_registry import (
build_hf_repo_model_entry,
)
from llama_stack.providers.utils.inference.openai_compat import (
- OpenAIChatCompletionUnsupportedMixin,
+ OpenAIChatCompletionToLlamaStackMixin,
OpenAICompatCompletionChoice,
OpenAICompatCompletionResponse,
- OpenAICompletionUnsupportedMixin,
+ OpenAICompletionToLlamaStackMixin,
get_sampling_options,
process_chat_completion_response,
process_chat_completion_stream_response,
@@ -73,8 +73,8 @@ def build_hf_repo_model_entries():
class _HfAdapter(
Inference,
- OpenAIChatCompletionUnsupportedMixin,
- OpenAICompletionUnsupportedMixin,
+ OpenAIChatCompletionToLlamaStackMixin,
+ OpenAICompletionToLlamaStackMixin,
ModelsProtocolPrivate,
):
client: AsyncInferenceClient
diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py
index 07fd75ea8..2fcfa341e 100644
--- a/llama_stack/providers/utils/inference/openai_compat.py
+++ b/llama_stack/providers/utils/inference/openai_compat.py
@@ -8,7 +8,7 @@ import logging
import time
import uuid
import warnings
-from typing import Any, AsyncGenerator, AsyncIterator, Dict, Iterable, List, Optional, Union
+from typing import Any, AsyncGenerator, AsyncIterator, Awaitable, Dict, Iterable, List, Optional, Union
from openai import AsyncStream
from openai.types.chat import (
@@ -50,6 +50,18 @@ from openai.types.chat.chat_completion import (
from openai.types.chat.chat_completion import (
ChoiceLogprobs as OpenAIChoiceLogprobs, # same as chat_completion_chunk ChoiceLogprobs
)
+from openai.types.chat.chat_completion_chunk import (
+ Choice as OpenAIChatCompletionChunkChoice,
+)
+from openai.types.chat.chat_completion_chunk import (
+ ChoiceDelta as OpenAIChoiceDelta,
+)
+from openai.types.chat.chat_completion_chunk import (
+ ChoiceDeltaToolCall as OpenAIChoiceDeltaToolCall,
+)
+from openai.types.chat.chat_completion_chunk import (
+ ChoiceDeltaToolCallFunction as OpenAIChoiceDeltaToolCallFunction,
+)
from openai.types.chat.chat_completion_content_part_image_param import (
ImageURL as OpenAIImageURL,
)
@@ -59,6 +71,7 @@ from openai.types.chat.chat_completion_message_tool_call_param import (
from pydantic import BaseModel
from llama_stack.apis.common.content_types import (
+ URL,
ImageContentItem,
InterleavedContent,
TextContentItem,
@@ -86,16 +99,23 @@ from llama_stack.apis.inference import (
UserMessage,
)
from llama_stack.apis.inference.inference import (
+ JsonSchemaResponseFormat,
OpenAIChatCompletion,
OpenAICompletion,
OpenAICompletionChoice,
+ OpenAIMessageParam,
OpenAIResponseFormatParam,
+ ToolConfig,
+)
+from llama_stack.apis.inference.inference import (
+ OpenAIChoice as OpenAIChatCompletionChoice,
)
from llama_stack.models.llama.datatypes import (
BuiltinTool,
StopReason,
ToolCall,
ToolDefinition,
+ ToolParamDefinition,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
convert_image_content_to_url,
@@ -756,6 +776,17 @@ def convert_tooldef_to_openai_tool(tool: ToolDefinition) -> dict:
return out
+def _convert_stop_reason_to_openai_finish_reason(stop_reason: StopReason) -> str:
+ """
+ Convert a StopReason to an OpenAI chat completion finish_reason.
+ """
+ return {
+ StopReason.end_of_turn: "stop",
+ StopReason.end_of_message: "tool_calls",
+ StopReason.out_of_tokens: "length",
+ }.get(stop_reason, "stop")
+
+
def _convert_openai_finish_reason(finish_reason: str) -> StopReason:
"""
Convert an OpenAI chat completion finish_reason to a StopReason.
@@ -781,6 +812,56 @@ def _convert_openai_finish_reason(finish_reason: str) -> StopReason:
}.get(finish_reason, StopReason.end_of_turn)
+def _convert_openai_request_tool_config(tool_choice: Optional[Union[str, Dict[str, Any]]] = None) -> ToolConfig:
+ tool_config = ToolConfig()
+ if tool_choice:
+ tool_config.tool_choice = tool_choice
+ return tool_config
+
+
+def _convert_openai_request_tools(tools: Optional[List[Dict[str, Any]]] = None) -> List[ToolDefinition]:
+ lls_tools = []
+ if not tools:
+ return lls_tools
+
+ for tool in tools:
+ tool_fn = tool.get("function", {})
+ tool_name = tool_fn.get("name", None)
+ tool_desc = tool_fn.get("description", None)
+
+ tool_params = tool_fn.get("parameters", None)
+ lls_tool_params = {}
+ if tool_params is not None:
+ tool_param_properties = tool_params.get("properties", {})
+ for tool_param_key, tool_param_value in tool_param_properties.items():
+ tool_param_def = ToolParamDefinition(
+ param_type=tool_param_value.get("type", None),
+ description=tool_param_value.get("description", None),
+ )
+ lls_tool_params[tool_param_key] = tool_param_def
+
+ lls_tool = ToolDefinition(
+ tool_name=tool_name,
+ description=tool_desc,
+ parameters=lls_tool_params,
+ )
+ lls_tools.append(lls_tool)
+ return lls_tools
+
+
+def _convert_openai_request_response_format(response_format: OpenAIResponseFormatParam = None):
+ if not response_format:
+ return None
+ # response_format can be a dict or a pydantic model
+ response_format = dict(response_format)
+ if response_format.get("type", "") == "json_schema":
+ return JsonSchemaResponseFormat(
+ type="json_schema",
+ json_schema=response_format.get("json_schema", {}).get("schema", ""),
+ )
+ return None
+
+
def _convert_openai_tool_calls(
tool_calls: List[OpenAIChatCompletionMessageToolCall],
) -> List[ToolCall]:
@@ -876,6 +957,40 @@ def _convert_openai_sampling_params(
return sampling_params
+def _convert_openai_request_messages(messages: List[OpenAIMessageParam]):
+ # Llama Stack messages and OpenAI messages are similar, but not identical.
+ lls_messages = []
+ for message in messages:
+ lls_message = dict(message)
+
+ # Llama Stack expects `call_id` but OpenAI uses `tool_call_id`
+ tool_call_id = lls_message.pop("tool_call_id", None)
+ if tool_call_id:
+ lls_message["call_id"] = tool_call_id
+
+ content = lls_message.get("content", None)
+ if isinstance(content, list):
+ lls_content = []
+ for item in content:
+ # items can be either pydantic models or dicts here
+ item = dict(item)
+ if item.get("type", "") == "image_url":
+ lls_item = ImageContentItem(
+ type="image",
+ image=URL(uri=item.get("image_url", {}).get("url", "")),
+ )
+ elif item.get("type", "") == "text":
+ lls_item = TextContentItem(
+ type="text",
+ text=item.get("text", ""),
+ )
+ lls_content.append(lls_item)
+ lls_message["content"] = lls_content
+ lls_messages.append(lls_message)
+
+ return lls_messages
+
+
def convert_openai_chat_completion_choice(
choice: OpenAIChoice,
) -> ChatCompletionResponse:
@@ -1102,7 +1217,7 @@ async def prepare_openai_completion_params(**params):
return completion_params
-class OpenAICompletionUnsupportedMixin:
+class OpenAICompletionToLlamaStackMixin:
async def openai_completion(
self,
model: str,
@@ -1140,6 +1255,7 @@ class OpenAICompletionUnsupportedMixin:
choices = []
# "n" is the number of completions to generate per prompt
+ n = n or 1
for _i in range(0, n):
# and we may have multiple prompts, if batching was used
@@ -1152,7 +1268,7 @@ class OpenAICompletionUnsupportedMixin:
index = len(choices)
text = result.content
- finish_reason = _convert_openai_finish_reason(result.stop_reason)
+ finish_reason = _convert_stop_reason_to_openai_finish_reason(result.stop_reason)
choice = OpenAICompletionChoice(
index=index,
@@ -1170,7 +1286,7 @@ class OpenAICompletionUnsupportedMixin:
)
-class OpenAIChatCompletionUnsupportedMixin:
+class OpenAIChatCompletionToLlamaStackMixin:
async def openai_chat_completion(
self,
model: str,
@@ -1197,4 +1313,109 @@ class OpenAIChatCompletionUnsupportedMixin:
top_p: Optional[float] = None,
user: Optional[str] = None,
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
- raise ValueError(f"{self.__class__.__name__} doesn't support openai chat completion")
+ messages = _convert_openai_request_messages(messages)
+ response_format = _convert_openai_request_response_format(response_format)
+ sampling_params = _convert_openai_sampling_params(
+ max_tokens=max_tokens,
+ temperature=temperature,
+ top_p=top_p,
+ )
+ tool_config = _convert_openai_request_tool_config(tool_choice)
+ tools = _convert_openai_request_tools(tools)
+
+ outstanding_responses = []
+ # "n" is the number of completions to generate per prompt
+ n = n or 1
+ for _i in range(0, n):
+ response = self.chat_completion(
+ model_id=model,
+ messages=messages,
+ sampling_params=sampling_params,
+ response_format=response_format,
+ stream=stream,
+ tool_config=tool_config,
+ tools=tools,
+ )
+ outstanding_responses.append(response)
+
+ if stream:
+ return OpenAIChatCompletionToLlamaStackMixin._process_stream_response(self, model, outstanding_responses)
+
+ return await OpenAIChatCompletionToLlamaStackMixin._process_non_stream_response(
+ self, model, outstanding_responses
+ )
+
+ async def _process_stream_response(
+ self, model: str, outstanding_responses: List[Awaitable[AsyncIterator[ChatCompletionResponseStreamChunk]]]
+ ):
+ id = f"chatcmpl-{uuid.uuid4()}"
+ for outstanding_response in outstanding_responses:
+ response = await outstanding_response
+ i = 0
+ async for chunk in response:
+ event = chunk.event
+ if event.stop_reason == StopReason.end_of_turn:
+ finish_reason = "stop"
+ elif event.stop_reason == StopReason.end_of_message:
+ finish_reason = "eos"
+ elif event.stop_reason == StopReason.out_of_tokens:
+ finish_reason = "length"
+ else:
+ finish_reason = None
+
+ if isinstance(event.delta, TextDelta):
+ text_delta = event.delta.text
+ delta = OpenAIChoiceDelta(content=text_delta)
+ yield OpenAIChatCompletionChunk(
+ id=id,
+ choices=[OpenAIChatCompletionChunkChoice(index=i, finish_reason=finish_reason, delta=delta)],
+ created=int(time.time()),
+ model=model,
+ object="chat.completion.chunk",
+ )
+ elif isinstance(event.delta, ToolCallDelta):
+ if event.delta.parse_status == ToolCallParseStatus.succeeded:
+ tool_call = event.delta.tool_call
+ openai_tool_call = OpenAIChoiceDeltaToolCall(
+ index=0,
+ id=tool_call.call_id,
+ function=OpenAIChoiceDeltaToolCallFunction(
+ name=tool_call.tool_name, arguments=tool_call.arguments_json
+ ),
+ )
+ delta = OpenAIChoiceDelta(tool_calls=[openai_tool_call])
+ yield OpenAIChatCompletionChunk(
+ id=id,
+ choices=[
+ OpenAIChatCompletionChunkChoice(index=i, finish_reason=finish_reason, delta=delta)
+ ],
+ created=int(time.time()),
+ model=model,
+ object="chat.completion.chunk",
+ )
+ i = i + 1
+
+ async def _process_non_stream_response(
+ self, model: str, outstanding_responses: List[Awaitable[ChatCompletionResponse]]
+ ) -> OpenAIChatCompletion:
+ choices = []
+ for outstanding_response in outstanding_responses:
+ response = await outstanding_response
+ completion_message = response.completion_message
+ message = await convert_message_to_openai_dict_new(completion_message)
+ finish_reason = _convert_stop_reason_to_openai_finish_reason(completion_message.stop_reason)
+
+ choice = OpenAIChatCompletionChoice(
+ index=len(choices),
+ message=message,
+ finish_reason=finish_reason,
+ )
+ choices.append(choice)
+
+ return OpenAIChatCompletion(
+ id=f"chatcmpl-{uuid.uuid4()}",
+ choices=choices,
+ created=int(time.time()),
+ model=model,
+ object="chat.completion",
+ )
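The renamed mixins now translate OpenAI-style requests into Llama Stack chat completions and back, rather than raising. One piece of that translation is the new stop-reason mapping; here is a self-contained sketch of it (the enum below stands in for llama_stack.models.llama.datatypes.StopReason and is not imported from the library):

# Sketch of _convert_stop_reason_to_openai_finish_reason using a stand-in enum.
from enum import Enum


class StopReason(Enum):
    end_of_turn = "end_of_turn"
    end_of_message = "end_of_message"
    out_of_tokens = "out_of_tokens"


def to_openai_finish_reason(stop_reason: StopReason) -> str:
    return {
        StopReason.end_of_turn: "stop",           # normal end of the assistant turn
        StopReason.end_of_message: "tool_calls",  # the model stopped to emit tool calls
        StopReason.out_of_tokens: "length",       # hit the token limit
    }.get(stop_reason, "stop")


print(to_openai_finish_reason(StopReason.out_of_tokens))  # "length"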
diff --git a/tests/verifications/conf/fireworks-llama-stack.yaml b/tests/verifications/conf/fireworks-llama-stack.yaml
new file mode 100644
index 000000000..d91443dd9
--- /dev/null
+++ b/tests/verifications/conf/fireworks-llama-stack.yaml
@@ -0,0 +1,14 @@
+base_url: http://localhost:8321/v1/openai/v1
+api_key_var: FIREWORKS_API_KEY
+models:
+- fireworks/llama-v3p3-70b-instruct
+- fireworks/llama4-scout-instruct-basic
+- fireworks/llama4-maverick-instruct-basic
+model_display_names:
+ fireworks/llama-v3p3-70b-instruct: Llama-3.3-70B-Instruct
+ fireworks/llama4-scout-instruct-basic: Llama-4-Scout-Instruct
+ fireworks/llama4-maverick-instruct-basic: Llama-4-Maverick-Instruct
+test_exclusions:
+ fireworks/llama-v3p3-70b-instruct:
+ - test_chat_non_streaming_image
+ - test_chat_streaming_image
diff --git a/tests/verifications/conf/openai-llama-stack.yaml b/tests/verifications/conf/openai-llama-stack.yaml
new file mode 100644
index 000000000..ee116dcf0
--- /dev/null
+++ b/tests/verifications/conf/openai-llama-stack.yaml
@@ -0,0 +1,9 @@
+base_url: http://localhost:8321/v1/openai/v1
+api_key_var: OPENAI_API_KEY
+models:
+- gpt-4o
+- gpt-4o-mini
+model_display_names:
+ gpt-4o: gpt-4o
+ gpt-4o-mini: gpt-4o-mini
+test_exclusions: {}
diff --git a/tests/verifications/conf/together-llama-stack.yaml b/tests/verifications/conf/together-llama-stack.yaml
new file mode 100644
index 000000000..e49d82604
--- /dev/null
+++ b/tests/verifications/conf/together-llama-stack.yaml
@@ -0,0 +1,14 @@
+base_url: http://localhost:8321/v1/openai/v1
+api_key_var: TOGETHER_API_KEY
+models:
+- together/meta-llama/Llama-3.3-70B-Instruct-Turbo
+- together/meta-llama/Llama-4-Scout-17B-16E-Instruct
+- together/meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8
+model_display_names:
+ together/meta-llama/Llama-3.3-70B-Instruct-Turbo: Llama-3.3-70B-Instruct
+ together/meta-llama/Llama-4-Scout-17B-16E-Instruct: Llama-4-Scout-Instruct
+ together/meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8: Llama-4-Maverick-Instruct
+test_exclusions:
+ together/meta-llama/Llama-3.3-70B-Instruct-Turbo:
+ - test_chat_non_streaming_image
+ - test_chat_streaming_image
diff --git a/tests/verifications/generate_report.py b/tests/verifications/generate_report.py
index 6a7c39ee2..c1eac8a33 100755
--- a/tests/verifications/generate_report.py
+++ b/tests/verifications/generate_report.py
@@ -67,7 +67,16 @@ RESULTS_DIR.mkdir(exist_ok=True)
# Maximum number of test result files to keep per provider
MAX_RESULTS_PER_PROVIDER = 1
-PROVIDER_ORDER = ["together", "fireworks", "groq", "cerebras", "openai"]
+PROVIDER_ORDER = [
+ "together",
+ "fireworks",
+ "groq",
+ "cerebras",
+ "openai",
+ "together-llama-stack",
+ "fireworks-llama-stack",
+ "openai-llama-stack",
+]
VERIFICATION_CONFIG = _load_all_verification_configs()
diff --git a/tests/verifications/openai-api-verification-run.yaml b/tests/verifications/openai-api-verification-run.yaml
new file mode 100644
index 000000000..0e8b99e4f
--- /dev/null
+++ b/tests/verifications/openai-api-verification-run.yaml
@@ -0,0 +1,126 @@
+version: '2'
+image_name: openai-api-verification
+apis:
+- inference
+- telemetry
+- tool_runtime
+- vector_io
+providers:
+ inference:
+ - provider_id: together
+ provider_type: remote::together
+ config:
+ url: https://api.together.xyz/v1
+ api_key: ${env.TOGETHER_API_KEY:}
+ - provider_id: fireworks
+ provider_type: remote::fireworks
+ config:
+ url: https://api.fireworks.ai/inference/v1
+ api_key: ${env.FIREWORKS_API_KEY}
+ - provider_id: openai
+ provider_type: remote::openai
+ config:
+ url: https://api.openai.com/v1
+ api_key: ${env.OPENAI_API_KEY:}
+ - provider_id: sentence-transformers
+ provider_type: inline::sentence-transformers
+ config: {}
+ vector_io:
+ - provider_id: faiss
+ provider_type: inline::faiss
+ config:
+ kvstore:
+ type: sqlite
+ namespace: null
+ db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/openai}/faiss_store.db
+ telemetry:
+ - provider_id: meta-reference
+ provider_type: inline::meta-reference
+ config:
+ service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
+ sinks: ${env.TELEMETRY_SINKS:console,sqlite}
+ sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/openai/trace_store.db}
+ tool_runtime:
+ - provider_id: brave-search
+ provider_type: remote::brave-search
+ config:
+ api_key: ${env.BRAVE_SEARCH_API_KEY:}
+ max_results: 3
+ - provider_id: tavily-search
+ provider_type: remote::tavily-search
+ config:
+ api_key: ${env.TAVILY_SEARCH_API_KEY:}
+ max_results: 3
+ - provider_id: code-interpreter
+ provider_type: inline::code-interpreter
+ config: {}
+ - provider_id: rag-runtime
+ provider_type: inline::rag-runtime
+ config: {}
+ - provider_id: model-context-protocol
+ provider_type: remote::model-context-protocol
+ config: {}
+ - provider_id: wolfram-alpha
+ provider_type: remote::wolfram-alpha
+ config:
+ api_key: ${env.WOLFRAM_ALPHA_API_KEY:}
+metadata_store:
+ type: sqlite
+ db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/openai}/registry.db
+models:
+- metadata: {}
+ model_id: together/meta-llama/Llama-3.3-70B-Instruct-Turbo
+ provider_id: together
+ provider_model_id: meta-llama/Llama-3.3-70B-Instruct-Turbo
+ model_type: llm
+- metadata: {}
+ model_id: together/meta-llama/Llama-4-Scout-17B-16E-Instruct
+ provider_id: together
+ provider_model_id: meta-llama/Llama-4-Scout-17B-16E-Instruct
+ model_type: llm
+- metadata: {}
+ model_id: together/meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8
+ provider_id: together
+ provider_model_id: meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8
+ model_type: llm
+- metadata: {}
+ model_id: fireworks/llama-v3p3-70b-instruct
+ provider_id: fireworks
+ provider_model_id: accounts/fireworks/models/llama-v3p3-70b-instruct
+ model_type: llm
+- metadata: {}
+ model_id: fireworks/llama4-scout-instruct-basic
+ provider_id: fireworks
+ provider_model_id: accounts/fireworks/models/llama4-scout-instruct-basic
+ model_type: llm
+- metadata: {}
+ model_id: fireworks/llama4-maverick-instruct-basic
+ provider_id: fireworks
+ provider_model_id: accounts/fireworks/models/llama4-maverick-instruct-basic
+ model_type: llm
+- metadata: {}
+ model_id: openai/gpt-4o
+ provider_id: openai
+ provider_model_id: openai/gpt-4o
+ model_type: llm
+- metadata: {}
+ model_id: openai/gpt-4o-mini
+ provider_id: openai
+ provider_model_id: openai/gpt-4o-mini
+ model_type: llm
+shields: []
+vector_dbs: []
+datasets: []
+scoring_fns: []
+benchmarks: []
+tool_groups:
+- toolgroup_id: builtin::websearch
+ provider_id: tavily-search
+- toolgroup_id: builtin::rag
+ provider_id: rag-runtime
+- toolgroup_id: builtin::code_interpreter
+ provider_id: code-interpreter
+- toolgroup_id: builtin::wolfram_alpha
+ provider_id: wolfram-alpha
+server:
+ port: 8321
diff --git a/tests/verifications/openai_api/fixtures/fixtures.py b/tests/verifications/openai_api/fixtures/fixtures.py
index 4f8c2e017..940b99b2a 100644
--- a/tests/verifications/openai_api/fixtures/fixtures.py
+++ b/tests/verifications/openai_api/fixtures/fixtures.py
@@ -99,6 +99,9 @@ def model_mapping(provider, providers_model_mapping):
@pytest.fixture
def openai_client(base_url, api_key):
+ # Simplify running against a local Llama Stack
+ if "localhost" in base_url and not api_key:
+ api_key = "empty"
return OpenAI(
base_url=base_url,
api_key=api_key,