Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-06-28 02:53:30 +00:00)
fix: OAI compat endpoint for meta reference inference provider (#1962)
Test plan:
python tests/verifications/generate_report.py --providers fireworks,together,llama_meta_ref,openai

Co-authored-by: Eric Huang <erichuang@fb.com>
This commit is contained in:
parent 8bd6665775
commit 2976b5d992
8 changed files with 1184 additions and 44 deletions
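For a quick manual check of the OpenAI-compatible endpoint this commit fixes, a stock OpenAI client can be pointed at the stack's compat base URL. A minimal sketch, assuming a meta-reference-gpu stack is running locally on port 5002 as in tests/verifications/conf/meta_reference.yaml (the API key value is a placeholder):

```python
# Sketch only: talk to the Llama Stack OpenAI-compat endpoint with the standard client.
from openai import OpenAI

client = OpenAI(
    base_url="http://localhost:5002/v1/openai/v1",  # from tests/verifications/conf/meta_reference.yaml
    api_key="dummy",  # placeholder; assumed not to be checked by a local meta-reference stack
)

response = client.chat.completions.create(
    model="meta-llama/Llama-4-Scout-17B-16E-Instruct",
    messages=[{"role": "user", "content": "Name the planet with rings whose name starts with S."}],
)
print(response.choices[0].message.content)
```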
@@ -5,6 +5,7 @@
 # the root directory of this source tree.

 import io
+import json
 import uuid
 from dataclasses import dataclass
 from typing import Dict, List, Optional, Tuple
@@ -299,6 +300,7 @@ class ChatFormat:
                     call_id=call_id,
                     tool_name=tool_name,
                     arguments=tool_arguments,
+                    arguments_json=json.dumps(tool_arguments),
                 )
             )

@@ -515,7 +515,8 @@ class MetaReferenceInferenceImpl(
         stop_reason = None
         ipython = False

-        for token_result in self.generator.chat_completion(request):
+        for token_results in self.generator.chat_completion([request]):
+            token_result = token_results[0]
             if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1":
                 cprint(token_result.text, "cyan", end="")
             if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "2":
@@ -8,7 +8,17 @@ import logging
 import time
 import uuid
 import warnings
-from typing import Any, AsyncGenerator, AsyncIterator, Awaitable, Dict, Iterable, List, Optional, Union
+from typing import (
+    Any,
+    AsyncGenerator,
+    AsyncIterator,
+    Awaitable,
+    Dict,
+    Iterable,
+    List,
+    Optional,
+    Union,
+)

 from openai import AsyncStream
 from openai.types.chat import (
@@ -78,6 +88,7 @@ from llama_stack.apis.common.content_types import (
     TextDelta,
     ToolCallDelta,
     ToolCallParseStatus,
+    _URLOrData,
 )
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
@@ -93,6 +104,7 @@ from llama_stack.apis.inference import (
     SamplingParams,
     SystemMessage,
     TokenLogProbs,
+    ToolChoice,
     ToolResponseMessage,
     TopKSamplingStrategy,
     TopPSamplingStrategy,
@@ -103,7 +115,6 @@ from llama_stack.apis.inference.inference import (
     OpenAIChatCompletion,
     OpenAICompletion,
     OpenAICompletionChoice,
-    OpenAIMessageParam,
     OpenAIResponseFormatParam,
     ToolConfig,
 )
@@ -612,13 +623,10 @@ async def convert_message_to_openai_dict_new(
             )
             for tool in message.tool_calls
         ]
-        params = {}
-        if tool_calls:
-            params = {"tool_calls": tool_calls}
         out = OpenAIChatCompletionAssistantMessage(
             role="assistant",
             content=await _convert_message_content(message.content),
-            **params,
+            tool_calls=tool_calls or None,
         )
     elif isinstance(message, ToolResponseMessage):
         out = OpenAIChatCompletionToolMessage(
@@ -695,7 +703,10 @@ def to_openai_param_type(param_type: str) -> dict:
     if param_type.startswith("list[") and param_type.endswith("]"):
         inner_type = param_type[5:-1]
         if inner_type in basic_types:
-            return {"type": "array", "items": {"type": basic_types.get(inner_type, inner_type)}}
+            return {
+                "type": "array",
+                "items": {"type": basic_types.get(inner_type, inner_type)},
+            }

     return {"type": param_type}
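For reference, the reshaped return value for list parameters can be sketched in isolation (hedged: `basic_types` is defined earlier in the file and is only assumed here to map names like "str" to JSON Schema types):

```python
# Standalone sketch of the list[...] branch above; basic_types below is an
# assumed stand-in for the mapping defined earlier in the real module.
basic_types = {"str": "string", "int": "integer", "float": "number", "bool": "boolean"}


def to_openai_param_type_sketch(param_type: str) -> dict:
    if param_type.startswith("list[") and param_type.endswith("]"):
        inner_type = param_type[5:-1]
        if inner_type in basic_types:
            return {
                "type": "array",
                "items": {"type": basic_types.get(inner_type, inner_type)},
            }
    return {"type": param_type}


print(to_openai_param_type_sketch("list[str]"))  # {'type': 'array', 'items': {'type': 'string'}}
```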
@@ -815,6 +826,10 @@ def _convert_openai_finish_reason(finish_reason: str) -> StopReason:
 def _convert_openai_request_tool_config(tool_choice: Optional[Union[str, Dict[str, Any]]] = None) -> ToolConfig:
     tool_config = ToolConfig()
     if tool_choice:
+        try:
+            tool_choice = ToolChoice(tool_choice)
+        except ValueError:
+            pass
         tool_config.tool_choice = tool_choice
     return tool_config

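The net effect of the new try/except is that plain tool_choice strings become ToolChoice enum members while named-tool objects pass through untouched. A self-contained sketch with a stand-in enum (the real ToolChoice lives in llama_stack.apis.inference; its members are assumed from their use elsewhere in this diff):

```python
from enum import Enum
from typing import Any, Dict, Optional, Union


class ToolChoiceSketch(str, Enum):
    # Stand-in for llama_stack's ToolChoice; member names assumed from this diff.
    auto = "auto"
    required = "required"
    none = "none"


def convert_tool_choice(tool_choice: Optional[Union[str, Dict[str, Any]]]):
    if not tool_choice:
        return None
    try:
        # Strings such as "auto" / "required" / "none" become enum members ...
        return ToolChoiceSketch(tool_choice)
    except ValueError:
        # ... while objects like {"type": "function", "function": {...}} are kept as-is.
        return tool_choice


print(convert_tool_choice("required"))
print(convert_tool_choice({"type": "function", "function": {"name": "get_weather"}}))
```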
@@ -849,7 +864,9 @@ def _convert_openai_request_tools(tools: Optional[List[Dict[str, Any]]] = None)
     return lls_tools


-def _convert_openai_request_response_format(response_format: OpenAIResponseFormatParam = None):
+def _convert_openai_request_response_format(
+    response_format: OpenAIResponseFormatParam = None,
+):
     if not response_format:
         return None
     # response_format can be a dict or a pydantic model
@@ -957,38 +974,50 @@ def _convert_openai_sampling_params(
     return sampling_params


-def _convert_openai_request_messages(messages: List[OpenAIMessageParam]):
-    # Llama Stack messages and OpenAI messages are similar, but not identical.
-    lls_messages = []
-    for message in messages:
-        lls_message = dict(message)
-
-        # Llama Stack expects `call_id` but OpenAI uses `tool_call_id`
-        tool_call_id = lls_message.pop("tool_call_id", None)
-        if tool_call_id:
-            lls_message["call_id"] = tool_call_id
-
-        content = lls_message.get("content", None)
-        if isinstance(content, list):
-            lls_content = []
-            for item in content:
-                # items can either by pydantic models or dicts here...
-                item = dict(item)
-                if item.get("type", "") == "image_url":
-                    lls_item = ImageContentItem(
-                        type="image",
-                        image=URL(uri=item.get("image_url", {}).get("url", "")),
-                    )
-                elif item.get("type", "") == "text":
-                    lls_item = TextContentItem(
-                        type="text",
-                        text=item.get("text", ""),
-                    )
-                lls_content.append(lls_item)
-            lls_message["content"] = lls_content
-        lls_messages.append(lls_message)
-
-    return lls_messages
+def openai_messages_to_messages(
+    messages: List[OpenAIChatCompletionMessage],
+) -> List[Message]:
+    """
+    Convert a list of OpenAIChatCompletionMessage into a list of Message.
+    """
+    converted_messages = []
+    for message in messages:
+        if message.role == "system":
+            converted_message = SystemMessage(content=message.content)
+        elif message.role == "user":
+            converted_message = UserMessage(content=openai_content_to_content(message.content))
+        elif message.role == "assistant":
+            converted_message = CompletionMessage(
+                content=message.content,
+                tool_calls=_convert_openai_tool_calls(message.tool_calls),
+                stop_reason=StopReason.end_of_turn,
+            )
+        elif message.role == "tool":
+            converted_message = ToolResponseMessage(
+                role="tool",
+                call_id=message.tool_call_id,
+                content=openai_content_to_content(message.content),
+            )
+        else:
+            raise ValueError(f"Unknown role {message.role}")
+        converted_messages.append(converted_message)
+    return converted_messages
+
+
+def openai_content_to_content(content: Union[str, Iterable[OpenAIChatCompletionContentPartParam]]):
+    if isinstance(content, str):
+        return content
+    elif isinstance(content, list):
+        return [openai_content_to_content(c) for c in content]
+    elif hasattr(content, "type"):
+        if content.type == "text":
+            return TextContentItem(type="text", text=content.text)
+        elif content.type == "image_url":
+            return ImageContentItem(type="image", image=_URLOrData(url=URL(uri=content.image_url.url)))
+        else:
+            raise ValueError(f"Unknown content type: {content.type}")
+    else:
+        raise ValueError(f"Unknown content type: {content}")


 def convert_openai_chat_completion_choice(
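To make the role mapping above concrete, here is a hedged illustration using SimpleNamespace stand-ins for the OpenAI-style message objects (the real call sites pass pydantic models exposing the same attributes):

```python
# Stand-ins only; not the real OpenAI or Llama Stack types.
from types import SimpleNamespace

openai_style_messages = [
    SimpleNamespace(role="system", content="You are a helpful assistant."),
    SimpleNamespace(role="user", content="What is 2 + 2?"),
    SimpleNamespace(role="tool", tool_call_id="call_1", content="4"),
]

# Fed to openai_messages_to_messages(...), these would become, roughly:
#   SystemMessage(content="You are a helpful assistant.")
#   UserMessage(content="What is 2 + 2?")
#   ToolResponseMessage(role="tool", call_id="call_1", content="4")
# Assistant messages additionally carry converted tool_calls and
# stop_reason=StopReason.end_of_turn, and unknown roles raise ValueError.
```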
@@ -1313,7 +1342,7 @@ class OpenAIChatCompletionToLlamaStackMixin:
         top_p: Optional[float] = None,
         user: Optional[str] = None,
     ) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
-        messages = _convert_openai_request_messages(messages)
+        messages = openai_messages_to_messages(messages)
         response_format = _convert_openai_request_response_format(response_format)
         sampling_params = _convert_openai_sampling_params(
             max_tokens=max_tokens,
@@ -1321,7 +1350,10 @@ class OpenAIChatCompletionToLlamaStackMixin:
             top_p=top_p,
         )
         tool_config = _convert_openai_request_tool_config(tool_choice)

         tools = _convert_openai_request_tools(tools)
+        if tool_config.tool_choice == ToolChoice.none:
+            tools = []
+
         outstanding_responses = []
         # "n" is the number of completions to generate per prompt
@@ -1346,7 +1378,9 @@
         )

     async def _process_stream_response(
-        self, model: str, outstanding_responses: List[Awaitable[AsyncIterator[ChatCompletionResponseStreamChunk]]]
+        self,
+        model: str,
+        outstanding_responses: List[Awaitable[AsyncIterator[ChatCompletionResponseStreamChunk]]],
     ):
         id = f"chatcmpl-{uuid.uuid4()}"
         for outstanding_response in outstanding_responses:
@@ -1369,11 +1403,31 @@
                 elif isinstance(event.delta, ToolCallDelta):
                     if event.delta.parse_status == ToolCallParseStatus.succeeded:
                         tool_call = event.delta.tool_call
+
+                        # First chunk includes full structure
                         openai_tool_call = OpenAIChoiceDeltaToolCall(
                             index=0,
                             id=tool_call.call_id,
                             function=OpenAIChoiceDeltaToolCallFunction(
-                                name=tool_call.tool_name, arguments=tool_call.arguments_json
+                                name=tool_call.tool_name,
+                                arguments="",
+                            ),
+                        )
+                        delta = OpenAIChoiceDelta(tool_calls=[openai_tool_call])
+                        yield OpenAIChatCompletionChunk(
+                            id=id,
+                            choices=[
+                                OpenAIChatCompletionChunkChoice(index=i, finish_reason=finish_reason, delta=delta)
+                            ],
+                            created=int(time.time()),
+                            model=model,
+                            object="chat.completion.chunk",
+                        )
+                        # arguments
+                        openai_tool_call = OpenAIChoiceDeltaToolCall(
+                            index=0,
+                            function=OpenAIChoiceDeltaToolCallFunction(
+                                arguments=tool_call.arguments_json,
                             ),
                         )
                         delta = OpenAIChoiceDelta(tool_calls=[openai_tool_call])
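On the consumer side, a standard OpenAI client sees the tool call split across chunks: the first carries the call id and function name with empty arguments, a later one carries the full arguments JSON. A sketch of reassembling it (endpoint and model as in tests/verifications/conf/meta_reference.yaml; the tool schema and API key are placeholders):

```python
# Sketch only: accumulate streamed tool-call deltas with the stock OpenAI client.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:5002/v1/openai/v1", api_key="dummy")  # placeholder key

stream = client.chat.completions.create(
    model="meta-llama/Llama-4-Scout-17B-16E-Instruct",
    messages=[{"role": "user", "content": "What's the weather like in Paris today?"}],
    tools=[
        {
            "type": "function",
            "function": {
                "name": "get_weather",  # hypothetical tool, for illustration only
                "description": "Get the current weather for a city",
                "parameters": {
                    "type": "object",
                    "properties": {"city": {"type": "string"}},
                    "required": ["city"],
                },
            },
        }
    ],
    stream=True,
)

call_id, name, args = None, None, ""
for chunk in stream:
    if not chunk.choices:
        continue
    for tc in chunk.choices[0].delta.tool_calls or []:
        if tc.id:  # first chunk: id + name, arguments=""
            call_id, name = tc.id, tc.function.name
        if tc.function and tc.function.arguments:  # later chunk: arguments JSON
            args += tc.function.arguments

print(call_id, name, args)
```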
@@ -1,6 +1,6 @@
 # Test Results Report

-*Generated on: 2025-04-16 15:10:57*
+*Generated on: 2025-04-17 11:08:16*

 *This report was generated by running `python tests/verifications/generate_report.py`*

@@ -15,12 +15,62 @@

 | Provider | Pass Rate | Tests Passed | Total Tests |
 | --- | --- | --- | --- |
+| Meta_reference | 100.0% | 26 | 26 |
 | Together | 51.3% | 39 | 76 |
 | Fireworks | 47.4% | 36 | 76 |
 | Openai | 100.0% | 52 | 52 |


+## Meta_reference
+
+*Tests run on: 2025-04-15 17:08:59*
+
+```bash
+# Run all tests for this provider:
+pytest tests/verifications/openai_api/test_chat_completion.py --provider=meta_reference -v
+
+# Example: Run only the 'earth' case of test_chat_non_streaming_basic:
+pytest tests/verifications/openai_api/test_chat_completion.py --provider=meta_reference -k "test_chat_non_streaming_basic and earth"
+```
+
+**Model Key (Meta_reference)**
+
+| Display Name | Full Model ID |
+| --- | --- |
+| Llama-4-Scout-Instruct | `meta-llama/Llama-4-Scout-17B-16E-Instruct` |
+
+| Test | Llama-4-Scout-Instruct |
+| --- | --- |
+| test_chat_non_streaming_basic (earth) | ✅ |
+| test_chat_non_streaming_basic (saturn) | ✅ |
+| test_chat_non_streaming_image | ✅ |
+| test_chat_non_streaming_multi_turn_tool_calling (add_product_tool) | ✅ |
+| test_chat_non_streaming_multi_turn_tool_calling (compare_monthly_expense_tool) | ✅ |
+| test_chat_non_streaming_multi_turn_tool_calling (get_then_create_event_tool) | ✅ |
+| test_chat_non_streaming_multi_turn_tool_calling (text_then_weather_tool) | ✅ |
+| test_chat_non_streaming_multi_turn_tool_calling (weather_tool_then_text) | ✅ |
+| test_chat_non_streaming_structured_output (calendar) | ✅ |
+| test_chat_non_streaming_structured_output (math) | ✅ |
+| test_chat_non_streaming_tool_calling | ✅ |
+| test_chat_non_streaming_tool_choice_none | ✅ |
+| test_chat_non_streaming_tool_choice_required | ✅ |
+| test_chat_streaming_basic (earth) | ✅ |
+| test_chat_streaming_basic (saturn) | ✅ |
+| test_chat_streaming_image | ✅ |
+| test_chat_streaming_multi_turn_tool_calling (add_product_tool) | ✅ |
+| test_chat_streaming_multi_turn_tool_calling (compare_monthly_expense_tool) | ✅ |
+| test_chat_streaming_multi_turn_tool_calling (get_then_create_event_tool) | ✅ |
+| test_chat_streaming_multi_turn_tool_calling (text_then_weather_tool) | ✅ |
+| test_chat_streaming_multi_turn_tool_calling (weather_tool_then_text) | ✅ |
+| test_chat_streaming_structured_output (calendar) | ✅ |
+| test_chat_streaming_structured_output (math) | ✅ |
+| test_chat_streaming_tool_calling | ✅ |
+| test_chat_streaming_tool_choice_none | ✅ |
+| test_chat_streaming_tool_choice_required | ✅ |
+
 ## Together

 *Tests run on: 2025-04-16 15:03:51*
tests/verifications/conf/meta_reference.yaml (new file, 8 lines)
@@ -0,0 +1,8 @@
+# LLAMA_STACK_PORT=5002 llama stack run meta-reference-gpu --env INFERENCE_MODEL=meta-llama/Llama-4-Scout-17B-16E-Instruct --env INFERENCE_CHECKPOINT_DIR=<path_to_ckpt>
+base_url: http://localhost:5002/v1/openai/v1
+api_key_var: foo
+models:
+- meta-llama/Llama-4-Scout-17B-16E-Instruct
+model_display_names:
+  meta-llama/Llama-4-Scout-17B-16E-Instruct: Llama-4-Scout-Instruct
+test_exclusions: {}
@@ -60,6 +60,7 @@ RESULTS_DIR.mkdir(exist_ok=True)
 MAX_RESULTS_PER_PROVIDER = 1

 DEFAULT_PROVIDERS = [
+    "meta_reference",
     "together",
     "fireworks",
     "openai",
@@ -12,7 +12,9 @@ from typing import Any
 import pytest
 from pydantic import BaseModel

-from tests.verifications.openai_api.fixtures.fixtures import _load_all_verification_configs
+from tests.verifications.openai_api.fixtures.fixtures import (
+    _load_all_verification_configs,
+)
 from tests.verifications.openai_api.fixtures.load import load_test_cases

 chat_completion_test_cases = load_test_cases("chat_completion")
@@ -272,7 +274,6 @@ def test_chat_non_streaming_tool_choice_required(request, openai_client, model,
         tool_choice="required",  # Force tool call
         stream=False,
     )
-    print(response)

     assert response.choices[0].message.role == "assistant"
     assert len(response.choices[0].message.tool_calls) > 0, "Expected tool call when tool_choice='required'"
tests/verifications/test_results/meta_reference.json (new file, 1023 lines)
File diff suppressed because it is too large.