fix: OAI compat endpoint for meta reference inference provider (#1962)

Test plan:
python tests/verifications/generate_report.py --providers fireworks,together,llama_meta_ref,openai

Co-authored-by: Eric Huang <erichuang@fb.com>
ehhuang committed 2025-04-17 11:16:04 -07:00 (via GitHub)
commit 2976b5d992 (parent 8bd6665775)
8 changed files with 1184 additions and 44 deletions


@@ -5,6 +5,7 @@
 # the root directory of this source tree.
 import io
+import json
 import uuid
 from dataclasses import dataclass
 from typing import Dict, List, Optional, Tuple

@@ -299,6 +300,7 @@ class ChatFormat:
                     call_id=call_id,
                     tool_name=tool_name,
                     arguments=tool_arguments,
+                    arguments_json=json.dumps(tool_arguments),
                 )
             )
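For context, a minimal sketch (not part of the diff; values are illustrative) of what the new `arguments_json` field carries: the same parsed tool arguments, serialized back to a JSON string so the OpenAI-compat layer can emit them verbatim.

```python
import json

# Hypothetical parsed tool arguments, as the chat formatter would produce them.
tool_arguments = {"location": "San Francisco", "unit": "celsius"}

# The added field stores the JSON-serialized form next to the dict so the
# OpenAI-compat streaming path can forward it without re-serializing.
arguments_json = json.dumps(tool_arguments)
print(arguments_json)  # {"location": "San Francisco", "unit": "celsius"}
```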


@@ -515,7 +515,8 @@ class MetaReferenceInferenceImpl(
         stop_reason = None
         ipython = False

-        for token_result in self.generator.chat_completion(request):
+        for token_results in self.generator.chat_completion([request]):
+            token_result = token_results[0]
             if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1":
                 cprint(token_result.text, "cyan", end="")
             if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "2":
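A rough, self-contained sketch (assumed shape, not the real generator) of the batched interface this new call site implies: each step yields one token result per request in the submitted batch, so the single-request path wraps its request in a list and reads element 0.

```python
from typing import Iterator, List


class FakeTokenResult:
    def __init__(self, text: str) -> None:
        self.text = text


def chat_completion(requests: List[str]) -> Iterator[List[FakeTokenResult]]:
    # Stand-in for the batched generator: one result per request, per step.
    for token in ["Hello", ",", " world"]:
        yield [FakeTokenResult(token) for _ in requests]


request = "say hi"
for token_results in chat_completion([request]):  # batch of size 1
    token_result = token_results[0]  # results align with input order
    print(token_result.text, end="")
```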


@@ -8,7 +8,17 @@ import logging
 import time
 import uuid
 import warnings
-from typing import Any, AsyncGenerator, AsyncIterator, Awaitable, Dict, Iterable, List, Optional, Union
+from typing import (
+    Any,
+    AsyncGenerator,
+    AsyncIterator,
+    Awaitable,
+    Dict,
+    Iterable,
+    List,
+    Optional,
+    Union,
+)

 from openai import AsyncStream
 from openai.types.chat import (
@@ -78,6 +88,7 @@ from llama_stack.apis.common.content_types import (
     TextDelta,
     ToolCallDelta,
     ToolCallParseStatus,
+    _URLOrData,
 )
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
@@ -93,6 +104,7 @@ from llama_stack.apis.inference import (
     SamplingParams,
     SystemMessage,
     TokenLogProbs,
+    ToolChoice,
     ToolResponseMessage,
     TopKSamplingStrategy,
     TopPSamplingStrategy,
@@ -103,7 +115,6 @@ from llama_stack.apis.inference.inference import (
     OpenAIChatCompletion,
     OpenAICompletion,
     OpenAICompletionChoice,
-    OpenAIMessageParam,
     OpenAIResponseFormatParam,
     ToolConfig,
 )
@@ -612,13 +623,10 @@ async def convert_message_to_openai_dict_new(
             )
             for tool in message.tool_calls
         ]
-        params = {}
-        if tool_calls:
-            params = {"tool_calls": tool_calls}
         out = OpenAIChatCompletionAssistantMessage(
             role="assistant",
             content=await _convert_message_content(message.content),
-            **params,
+            tool_calls=tool_calls or None,
         )
     elif isinstance(message, ToolResponseMessage):
         out = OpenAIChatCompletionToolMessage(
@@ -695,7 +703,10 @@ def to_openai_param_type(param_type: str) -> dict:
     if param_type.startswith("list[") and param_type.endswith("]"):
         inner_type = param_type[5:-1]
         if inner_type in basic_types:
-            return {"type": "array", "items": {"type": basic_types.get(inner_type, inner_type)}}
+            return {
+                "type": "array",
+                "items": {"type": basic_types.get(inner_type, inner_type)},
+            }

     return {"type": param_type}
@@ -815,6 +826,10 @@ def _convert_openai_finish_reason(finish_reason: str) -> StopReason:

 def _convert_openai_request_tool_config(tool_choice: Optional[Union[str, Dict[str, Any]]] = None) -> ToolConfig:
     tool_config = ToolConfig()
     if tool_choice:
+        try:
+            tool_choice = ToolChoice(tool_choice)
+        except ValueError:
+            pass
         tool_config.tool_choice = tool_choice
     return tool_config
@@ -849,7 +864,9 @@ def _convert_openai_request_tools(tools: Optional[List[Dict[str, Any]]] = None)
     return lls_tools


-def _convert_openai_request_response_format(response_format: OpenAIResponseFormatParam = None):
+def _convert_openai_request_response_format(
+    response_format: OpenAIResponseFormatParam = None,
+):
     if not response_format:
         return None
     # response_format can be a dict or a pydantic model
@@ -957,38 +974,50 @@ def _convert_openai_sampling_params(
     return sampling_params


-def _convert_openai_request_messages(messages: List[OpenAIMessageParam]):
-    # Llama Stack messages and OpenAI messages are similar, but not identical.
-    lls_messages = []
-    for message in messages:
-        lls_message = dict(message)
-
-        # Llama Stack expects `call_id` but OpenAI uses `tool_call_id`
-        tool_call_id = lls_message.pop("tool_call_id", None)
-        if tool_call_id:
-            lls_message["call_id"] = tool_call_id
-
-        content = lls_message.get("content", None)
-        if isinstance(content, list):
-            lls_content = []
-            for item in content:
-                # items can either by pydantic models or dicts here...
-                item = dict(item)
-                if item.get("type", "") == "image_url":
-                    lls_item = ImageContentItem(
-                        type="image",
-                        image=URL(uri=item.get("image_url", {}).get("url", "")),
-                    )
-                elif item.get("type", "") == "text":
-                    lls_item = TextContentItem(
-                        type="text",
-                        text=item.get("text", ""),
-                    )
-                lls_content.append(lls_item)
-            lls_message["content"] = lls_content
-        lls_messages.append(lls_message)
-
-    return lls_messages
+def openai_messages_to_messages(
+    messages: List[OpenAIChatCompletionMessage],
+) -> List[Message]:
+    """
+    Convert a list of OpenAIChatCompletionMessage into a list of Message.
+    """
+    converted_messages = []
+    for message in messages:
+        if message.role == "system":
+            converted_message = SystemMessage(content=message.content)
+        elif message.role == "user":
+            converted_message = UserMessage(content=openai_content_to_content(message.content))
+        elif message.role == "assistant":
+            converted_message = CompletionMessage(
+                content=message.content,
+                tool_calls=_convert_openai_tool_calls(message.tool_calls),
+                stop_reason=StopReason.end_of_turn,
+            )
+        elif message.role == "tool":
+            converted_message = ToolResponseMessage(
+                role="tool",
+                call_id=message.tool_call_id,
+                content=openai_content_to_content(message.content),
+            )
+        else:
+            raise ValueError(f"Unknown role {message.role}")
+        converted_messages.append(converted_message)
+    return converted_messages
+
+
+def openai_content_to_content(content: Union[str, Iterable[OpenAIChatCompletionContentPartParam]]):
+    if isinstance(content, str):
+        return content
+    elif isinstance(content, list):
+        return [openai_content_to_content(c) for c in content]
+    elif hasattr(content, "type"):
+        if content.type == "text":
+            return TextContentItem(type="text", text=content.text)
+        elif content.type == "image_url":
+            return ImageContentItem(type="image", image=_URLOrData(url=URL(uri=content.image_url.url)))
+        else:
+            raise ValueError(f"Unknown content type: {content.type}")
+    else:
+        raise ValueError(f"Unknown content type: {content}")


 def convert_openai_chat_completion_choice(
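For orientation, a rough usage sketch of the new converter (inputs shown as plain objects exposing only the attributes it reads; the real inputs are the OpenAI pydantic message models):

```python
from types import SimpleNamespace as Msg

openai_messages = [
    Msg(role="system", content="You are a helpful assistant.", tool_calls=None, tool_call_id=None),
    Msg(role="user", content="What is the capital of France?", tool_calls=None, tool_call_id=None),
]

# openai_messages_to_messages(openai_messages) would return, roughly:
#   [SystemMessage(content="You are a helpful assistant."),
#    UserMessage(content="What is the capital of France?")]
```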
@@ -1313,7 +1342,7 @@ class OpenAIChatCompletionToLlamaStackMixin:
         top_p: Optional[float] = None,
         user: Optional[str] = None,
     ) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
-        messages = _convert_openai_request_messages(messages)
+        messages = openai_messages_to_messages(messages)
         response_format = _convert_openai_request_response_format(response_format)
         sampling_params = _convert_openai_sampling_params(
             max_tokens=max_tokens,
@@ -1321,7 +1350,10 @@ class OpenAIChatCompletionToLlamaStackMixin:
             top_p=top_p,
         )
         tool_config = _convert_openai_request_tool_config(tool_choice)
         tools = _convert_openai_request_tools(tools)
+        if tool_config.tool_choice == ToolChoice.none:
+            tools = []
+
         outstanding_responses = []
         # "n" is the number of completions to generate per prompt
@@ -1346,7 +1378,9 @@ class OpenAIChatCompletionToLlamaStackMixin:
         )

     async def _process_stream_response(
-        self, model: str, outstanding_responses: List[Awaitable[AsyncIterator[ChatCompletionResponseStreamChunk]]]
+        self,
+        model: str,
+        outstanding_responses: List[Awaitable[AsyncIterator[ChatCompletionResponseStreamChunk]]],
     ):
         id = f"chatcmpl-{uuid.uuid4()}"
         for outstanding_response in outstanding_responses:
@@ -1369,11 +1403,31 @@ class OpenAIChatCompletionToLlamaStackMixin:
             elif isinstance(event.delta, ToolCallDelta):
                 if event.delta.parse_status == ToolCallParseStatus.succeeded:
                     tool_call = event.delta.tool_call
+                    # First chunk includes full structure
                     openai_tool_call = OpenAIChoiceDeltaToolCall(
                         index=0,
                         id=tool_call.call_id,
                         function=OpenAIChoiceDeltaToolCallFunction(
-                            name=tool_call.tool_name, arguments=tool_call.arguments_json
+                            name=tool_call.tool_name,
+                            arguments="",
+                        ),
+                    )
+                    delta = OpenAIChoiceDelta(tool_calls=[openai_tool_call])
+                    yield OpenAIChatCompletionChunk(
+                        id=id,
+                        choices=[
+                            OpenAIChatCompletionChunkChoice(index=i, finish_reason=finish_reason, delta=delta)
+                        ],
+                        created=int(time.time()),
+                        model=model,
+                        object="chat.completion.chunk",
+                    )
+
+                    # arguments
+                    openai_tool_call = OpenAIChoiceDeltaToolCall(
+                        index=0,
+                        function=OpenAIChoiceDeltaToolCallFunction(
+                            arguments=tool_call.arguments_json,
                         ),
                     )
                     delta = OpenAIChoiceDelta(tool_calls=[openai_tool_call])
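A rough client-side sketch (not part of the diff; assumes the standard OpenAI chunk schema) of how the two emitted chunks are consumed: the first delta carries the call id and function name with empty arguments, the next carries only the arguments, and a client accumulates them by tool-call index.

```python
# Minimal accumulator that merges streamed tool_call deltas chunk by chunk.
tool_calls: dict = {}


def merge_chunk(chunk) -> None:
    for tc in chunk.choices[0].delta.tool_calls or []:
        entry = tool_calls.setdefault(tc.index, {"id": None, "name": None, "arguments": ""})
        if tc.id:
            entry["id"] = tc.id
        if tc.function and tc.function.name:
            entry["name"] = tc.function.name
        if tc.function and tc.function.arguments:
            entry["arguments"] += tc.function.arguments

# After both chunks above are merged, entry["arguments"] holds the full
# JSON string (tool_call.arguments_json) for the call.
```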


@@ -1,6 +1,6 @@
 # Test Results Report

-*Generated on: 2025-04-16 15:10:57*
+*Generated on: 2025-04-17 11:08:16*

 *This report was generated by running `python tests/verifications/generate_report.py`*
@@ -15,12 +15,62 @@

 | Provider | Pass Rate | Tests Passed | Total Tests |
 | --- | --- | --- | --- |
+| Meta_reference | 100.0% | 26 | 26 |
 | Together | 51.3% | 39 | 76 |
 | Fireworks | 47.4% | 36 | 76 |
 | Openai | 100.0% | 52 | 52 |

+## Meta_reference
+
+*Tests run on: 2025-04-15 17:08:59*
+
+```bash
+# Run all tests for this provider:
+pytest tests/verifications/openai_api/test_chat_completion.py --provider=meta_reference -v
+
+# Example: Run only the 'earth' case of test_chat_non_streaming_basic:
+pytest tests/verifications/openai_api/test_chat_completion.py --provider=meta_reference -k "test_chat_non_streaming_basic and earth"
+```
+
+**Model Key (Meta_reference)**
+
+| Display Name | Full Model ID |
+| --- | --- |
+| Llama-4-Scout-Instruct | `meta-llama/Llama-4-Scout-17B-16E-Instruct` |
+
+| Test | Llama-4-Scout-Instruct |
+| --- | --- |
+| test_chat_non_streaming_basic (earth) | ✅ |
+| test_chat_non_streaming_basic (saturn) | ✅ |
+| test_chat_non_streaming_image | ✅ |
+| test_chat_non_streaming_multi_turn_tool_calling (add_product_tool) | ✅ |
+| test_chat_non_streaming_multi_turn_tool_calling (compare_monthly_expense_tool) | ✅ |
+| test_chat_non_streaming_multi_turn_tool_calling (get_then_create_event_tool) | ✅ |
+| test_chat_non_streaming_multi_turn_tool_calling (text_then_weather_tool) | ✅ |
+| test_chat_non_streaming_multi_turn_tool_calling (weather_tool_then_text) | ✅ |
+| test_chat_non_streaming_structured_output (calendar) | ✅ |
+| test_chat_non_streaming_structured_output (math) | ✅ |
+| test_chat_non_streaming_tool_calling | ✅ |
+| test_chat_non_streaming_tool_choice_none | ✅ |
+| test_chat_non_streaming_tool_choice_required | ✅ |
+| test_chat_streaming_basic (earth) | ✅ |
+| test_chat_streaming_basic (saturn) | ✅ |
+| test_chat_streaming_image | ✅ |
+| test_chat_streaming_multi_turn_tool_calling (add_product_tool) | ✅ |
+| test_chat_streaming_multi_turn_tool_calling (compare_monthly_expense_tool) | ✅ |
+| test_chat_streaming_multi_turn_tool_calling (get_then_create_event_tool) | ✅ |
+| test_chat_streaming_multi_turn_tool_calling (text_then_weather_tool) | ✅ |
+| test_chat_streaming_multi_turn_tool_calling (weather_tool_then_text) | ✅ |
+| test_chat_streaming_structured_output (calendar) | ✅ |
+| test_chat_streaming_structured_output (math) | ✅ |
+| test_chat_streaming_tool_calling | ✅ |
+| test_chat_streaming_tool_choice_none | ✅ |
+| test_chat_streaming_tool_choice_required | ✅ |
+
 ## Together

 *Tests run on: 2025-04-16 15:03:51*


@@ -0,0 +1,8 @@
+# LLAMA_STACK_PORT=5002 llama stack run meta-reference-gpu --env INFERENCE_MODEL=meta-llama/Llama-4-Scout-17B-16E-Instruct --env INFERENCE_CHECKPOINT_DIR=<path_to_ckpt>
+base_url: http://localhost:5002/v1/openai/v1
+api_key_var: foo
+models:
+- meta-llama/Llama-4-Scout-17B-16E-Instruct
+model_display_names:
+  meta-llama/Llama-4-Scout-17B-16E-Instruct: Llama-4-Scout-Instruct
+test_exclusions: {}
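As a quick sanity check (hypothetical example; assumes the stack is running on port 5002 as in the comment above and that the `openai` Python client is installed), the OpenAI-compatible endpoint behind this `base_url` can be exercised directly:

```python
from openai import OpenAI

# base_url and api key variable mirror the config above; the key value is unused locally.
client = OpenAI(base_url="http://localhost:5002/v1/openai/v1", api_key="foo")

response = client.chat.completions.create(
    model="meta-llama/Llama-4-Scout-17B-16E-Instruct",
    messages=[{"role": "user", "content": "Say hello in one word."}],
)
print(response.choices[0].message.content)
```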


@@ -60,6 +60,7 @@ RESULTS_DIR.mkdir(exist_ok=True)
 MAX_RESULTS_PER_PROVIDER = 1

 DEFAULT_PROVIDERS = [
+    "meta_reference",
     "together",
     "fireworks",
     "openai",


@@ -12,7 +12,9 @@ from typing import Any

 import pytest
 from pydantic import BaseModel

-from tests.verifications.openai_api.fixtures.fixtures import _load_all_verification_configs
+from tests.verifications.openai_api.fixtures.fixtures import (
+    _load_all_verification_configs,
+)
 from tests.verifications.openai_api.fixtures.load import load_test_cases

 chat_completion_test_cases = load_test_cases("chat_completion")

@@ -272,7 +274,6 @@ def test_chat_non_streaming_tool_choice_required(request, openai_client, model,
         tool_choice="required",  # Force tool call
         stream=False,
     )
-    print(response)

     assert response.choices[0].message.role == "assistant"
     assert len(response.choices[0].message.tool_calls) > 0, "Expected tool call when tool_choice='required'"

File diff suppressed because it is too large.