fix: OAI compat endpoint for meta reference inference provider (#1962)

Test plan:
python tests/verifications/generate_report.py --providers fireworks,together,llama_meta_ref,openai

Co-authored-by: Eric Huang <erichuang@fb.com>
This commit is contained in:
ehhuang 2025-04-17 11:16:04 -07:00 committed by GitHub
parent 8bd6665775
commit 2976b5d992
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 1184 additions and 44 deletions

View file

@ -5,6 +5,7 @@
# the root directory of this source tree.
import io
import json
import uuid
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
@ -299,6 +300,7 @@ class ChatFormat:
call_id=call_id,
tool_name=tool_name,
arguments=tool_arguments,
arguments_json=json.dumps(tool_arguments),
)
)

View file

@ -515,7 +515,8 @@ class MetaReferenceInferenceImpl(
stop_reason = None
ipython = False
for token_result in self.generator.chat_completion(request):
for token_results in self.generator.chat_completion([request]):
token_result = token_results[0]
if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1":
cprint(token_result.text, "cyan", end="")
if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "2":

View file

@ -8,7 +8,17 @@ import logging
import time
import uuid
import warnings
from typing import Any, AsyncGenerator, AsyncIterator, Awaitable, Dict, Iterable, List, Optional, Union
from typing import (
Any,
AsyncGenerator,
AsyncIterator,
Awaitable,
Dict,
Iterable,
List,
Optional,
Union,
)
from openai import AsyncStream
from openai.types.chat import (
@ -78,6 +88,7 @@ from llama_stack.apis.common.content_types import (
TextDelta,
ToolCallDelta,
ToolCallParseStatus,
_URLOrData,
)
from llama_stack.apis.inference import (
ChatCompletionRequest,
@ -93,6 +104,7 @@ from llama_stack.apis.inference import (
SamplingParams,
SystemMessage,
TokenLogProbs,
ToolChoice,
ToolResponseMessage,
TopKSamplingStrategy,
TopPSamplingStrategy,
@ -103,7 +115,6 @@ from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
OpenAICompletion,
OpenAICompletionChoice,
OpenAIMessageParam,
OpenAIResponseFormatParam,
ToolConfig,
)
@ -612,13 +623,10 @@ async def convert_message_to_openai_dict_new(
)
for tool in message.tool_calls
]
params = {}
if tool_calls:
params = {"tool_calls": tool_calls}
out = OpenAIChatCompletionAssistantMessage(
role="assistant",
content=await _convert_message_content(message.content),
**params,
tool_calls=tool_calls or None,
)
elif isinstance(message, ToolResponseMessage):
out = OpenAIChatCompletionToolMessage(
@ -695,7 +703,10 @@ def to_openai_param_type(param_type: str) -> dict:
if param_type.startswith("list[") and param_type.endswith("]"):
inner_type = param_type[5:-1]
if inner_type in basic_types:
return {"type": "array", "items": {"type": basic_types.get(inner_type, inner_type)}}
return {
"type": "array",
"items": {"type": basic_types.get(inner_type, inner_type)},
}
return {"type": param_type}
@ -815,6 +826,10 @@ def _convert_openai_finish_reason(finish_reason: str) -> StopReason:
def _convert_openai_request_tool_config(tool_choice: Optional[Union[str, Dict[str, Any]]] = None) -> ToolConfig:
    """
    Build a ToolConfig from an OpenAI-style ``tool_choice`` request value.

    A falsy ``tool_choice`` leaves the ToolConfig exactly as constructed.
    Otherwise the value is looked up as a ToolChoice member; when that lookup
    raises ValueError the value is stored on the config unchanged.
    """
    config = ToolConfig()
    if not tool_choice:
        return config

    try:
        selected = ToolChoice(tool_choice)
    except ValueError:
        # Not a ToolChoice value: keep whatever the caller sent.
        selected = tool_choice
    config.tool_choice = selected
    return config
@ -849,7 +864,9 @@ def _convert_openai_request_tools(tools: Optional[List[Dict[str, Any]]] = None)
return lls_tools
def _convert_openai_request_response_format(response_format: OpenAIResponseFormatParam = None):
def _convert_openai_request_response_format(
response_format: OpenAIResponseFormatParam = None,
):
if not response_format:
return None
# response_format can be a dict or a pydantic model
@ -957,38 +974,50 @@ def _convert_openai_sampling_params(
return sampling_params
def _convert_openai_request_messages(messages: List[OpenAIMessageParam]):
# Llama Stack messages and OpenAI messages are similar, but not identical.
lls_messages = []
def openai_messages_to_messages(
    messages: List[OpenAIChatCompletionMessage],
) -> List[Message]:
    """
    Convert a list of OpenAIChatCompletionMessage into a list of Message.

    The ``role`` of each message selects the result type:

    - "system"    -> SystemMessage
    - "user"      -> UserMessage
    - "assistant" -> CompletionMessage (stop_reason is always end_of_turn)
    - "tool"      -> ToolResponseMessage (``tool_call_id`` becomes ``call_id``)

    Content for every role is normalized with openai_content_to_content, so
    list-form content parts on system and assistant messages are converted the
    same way as on user and tool messages. A ``None`` content on a system or
    assistant message is passed through unchanged.

    Raises:
        ValueError: if a message has a role other than the four listed above.
    """

    def _content_or_none(content):
        # Assistant messages that only carry tool calls may have no content;
        # hand None through untouched instead of asking the converter to
        # interpret it.
        if content is None:
            return None
        return openai_content_to_content(content)

    converted_messages = []
    for message in messages:
        if message.role == "system":
            converted_message = SystemMessage(content=_content_or_none(message.content))
        elif message.role == "user":
            converted_message = UserMessage(content=openai_content_to_content(message.content))
        elif message.role == "assistant":
            converted_message = CompletionMessage(
                content=_content_or_none(message.content),
                tool_calls=_convert_openai_tool_calls(message.tool_calls),
                stop_reason=StopReason.end_of_turn,
            )
        elif message.role == "tool":
            converted_message = ToolResponseMessage(
                role="tool",
                call_id=message.tool_call_id,
                content=openai_content_to_content(message.content),
            )
        else:
            raise ValueError(f"Unknown role {message.role}")
        converted_messages.append(converted_message)
    return converted_messages
# Llama Stack expects `call_id` but OpenAI uses `tool_call_id`
tool_call_id = lls_message.pop("tool_call_id", None)
if tool_call_id:
lls_message["call_id"] = tool_call_id
content = lls_message.get("content", None)
if isinstance(content, list):
lls_content = []
for item in content:
# items can either be pydantic models or dicts here...
item = dict(item)
if item.get("type", "") == "image_url":
lls_item = ImageContentItem(
type="image",
image=URL(uri=item.get("image_url", {}).get("url", "")),
)
elif item.get("type", "") == "text":
lls_item = TextContentItem(
type="text",
text=item.get("text", ""),
)
lls_content.append(lls_item)
lls_message["content"] = lls_content
lls_messages.append(lls_message)
return lls_messages
def openai_content_to_content(content: Union[str, Iterable[OpenAIChatCompletionContentPartParam]]):
    """
    Convert OpenAI-style message content into Llama Stack content.

    - A plain string is returned unchanged.
    - A list or tuple is converted element by element; the result is a list.
    - A single content part of type "text" becomes a TextContentItem, and one
      of type "image_url" becomes an ImageContentItem wrapping the URL.

    A content part may be an object exposing ``.type`` or a plain dict with a
    "type" key (the OpenAI wire shape); both are read the same way.

    Raises:
        ValueError: if the content is none of the above, or a content part
            has a type other than "text" or "image_url".
    """

    def _field(obj, name):
        # Content parts arrive either as attribute-style objects or as dicts.
        return obj[name] if isinstance(obj, dict) else getattr(obj, name)

    if isinstance(content, str):
        return content
    if isinstance(content, (list, tuple)):
        return [openai_content_to_content(part) for part in content]

    if isinstance(content, dict) and "type" in content:
        part_type = content["type"]
    elif hasattr(content, "type"):
        part_type = content.type
    else:
        raise ValueError(f"Unknown content type: {content}")

    if part_type == "text":
        return TextContentItem(type="text", text=_field(content, "text"))
    if part_type == "image_url":
        image_url = _field(content, "image_url")
        return ImageContentItem(type="image", image=_URLOrData(url=URL(uri=_field(image_url, "url"))))
    raise ValueError(f"Unknown content type: {part_type}")
def convert_openai_chat_completion_choice(
@ -1313,7 +1342,7 @@ class OpenAIChatCompletionToLlamaStackMixin:
top_p: Optional[float] = None,
user: Optional[str] = None,
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
messages = _convert_openai_request_messages(messages)
messages = openai_messages_to_messages(messages)
response_format = _convert_openai_request_response_format(response_format)
sampling_params = _convert_openai_sampling_params(
max_tokens=max_tokens,
@ -1321,7 +1350,10 @@ class OpenAIChatCompletionToLlamaStackMixin:
top_p=top_p,
)
tool_config = _convert_openai_request_tool_config(tool_choice)
tools = _convert_openai_request_tools(tools)
if tool_config.tool_choice == ToolChoice.none:
tools = []
outstanding_responses = []
# "n" is the number of completions to generate per prompt
@ -1346,7 +1378,9 @@ class OpenAIChatCompletionToLlamaStackMixin:
)
async def _process_stream_response(
self, model: str, outstanding_responses: List[Awaitable[AsyncIterator[ChatCompletionResponseStreamChunk]]]
self,
model: str,
outstanding_responses: List[Awaitable[AsyncIterator[ChatCompletionResponseStreamChunk]]],
):
id = f"chatcmpl-{uuid.uuid4()}"
for outstanding_response in outstanding_responses:
@ -1369,11 +1403,31 @@ class OpenAIChatCompletionToLlamaStackMixin:
elif isinstance(event.delta, ToolCallDelta):
if event.delta.parse_status == ToolCallParseStatus.succeeded:
tool_call = event.delta.tool_call
# First chunk includes full structure
openai_tool_call = OpenAIChoiceDeltaToolCall(
index=0,
id=tool_call.call_id,
function=OpenAIChoiceDeltaToolCallFunction(
name=tool_call.tool_name, arguments=tool_call.arguments_json
name=tool_call.tool_name,
arguments="",
),
)
delta = OpenAIChoiceDelta(tool_calls=[openai_tool_call])
yield OpenAIChatCompletionChunk(
id=id,
choices=[
OpenAIChatCompletionChunkChoice(index=i, finish_reason=finish_reason, delta=delta)
],
created=int(time.time()),
model=model,
object="chat.completion.chunk",
)
# arguments
openai_tool_call = OpenAIChoiceDeltaToolCall(
index=0,
function=OpenAIChoiceDeltaToolCallFunction(
arguments=tool_call.arguments_json,
),
)
delta = OpenAIChoiceDelta(tool_calls=[openai_tool_call])