mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-01 16:24:44 +00:00
improve Llama Stack -> OpenAI message conversion
This commit is contained in:
parent
e6b82a44eb
commit
0e496d1557
1 changed files with 68 additions and 13 deletions
|
@ -17,13 +17,23 @@ from llama_models.llama3.api.datatypes import (
|
|||
ToolDefinition,
|
||||
)
|
||||
from openai import AsyncStream
|
||||
from openai.types.chat import ChatCompletionChunk as OpenAIChatCompletionChunk
|
||||
|
||||
from openai.types.chat import (
|
||||
ChatCompletionAssistantMessageParam as OpenAIChatCompletionAssistantMessage,
|
||||
ChatCompletionChunk as OpenAIChatCompletionChunk,
|
||||
ChatCompletionMessageParam as OpenAIChatCompletionMessage,
|
||||
ChatCompletionMessageToolCallParam as OpenAIChatCompletionMessageToolCall,
|
||||
ChatCompletionSystemMessageParam as OpenAIChatCompletionSystemMessage,
|
||||
ChatCompletionToolMessageParam as OpenAIChatCompletionToolMessage,
|
||||
ChatCompletionUserMessageParam as OpenAIChatCompletionUserMessage,
|
||||
)
|
||||
from openai.types.chat.chat_completion import (
|
||||
Choice as OpenAIChoice,
|
||||
ChoiceLogprobs as OpenAIChoiceLogprobs, # same as chat_completion_chunk ChoiceLogprobs
|
||||
)
|
||||
from openai.types.chat.chat_completion_message_tool_call import (
|
||||
ChatCompletionMessageToolCall as OpenAIChatCompletionMessageToolCall,
|
||||
|
||||
from openai.types.chat.chat_completion_message_tool_call_param import (
|
||||
Function as OpenAIFunction,
|
||||
)
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
|
@ -34,8 +44,11 @@ from llama_stack.apis.inference import (
|
|||
ChatCompletionResponseStreamChunk,
|
||||
JsonSchemaResponseFormat,
|
||||
Message,
|
||||
SystemMessage,
|
||||
ToolCallDelta,
|
||||
ToolCallParseStatus,
|
||||
ToolResponseMessage,
|
||||
UserMessage,
|
||||
)
|
||||
|
||||
|
||||
|
@ -115,21 +128,63 @@ def _convert_tooldef_to_openai_tool(tool: ToolDefinition) -> dict:
|
|||
return out
|
||||
|
||||
|
||||
def _convert_message(message: Message) -> Dict:
|
||||
def _convert_message(message: Message | Dict) -> OpenAIChatCompletionMessage:
|
||||
"""
|
||||
Convert a Message to an OpenAI API-compatible dictionary.
|
||||
"""
|
||||
out_dict = message.model_dump()
|
||||
# Llama Stack uses role="ipython" for tool call messages, OpenAI uses "tool"
|
||||
if out_dict["role"] == "ipython":
|
||||
out_dict.update(role="tool")
|
||||
# users can supply a dict instead of a Message object, we'll
|
||||
# convert it to a Message object and proceed with some type safety.
|
||||
if isinstance(message, dict):
|
||||
if "role" not in message:
|
||||
raise ValueError("role is required in message")
|
||||
if message["role"] == "user":
|
||||
message = UserMessage(**message)
|
||||
elif message["role"] == "assistant":
|
||||
message = CompletionMessage(**message)
|
||||
elif message["role"] == "ipython":
|
||||
message = ToolResponseMessage(**message)
|
||||
elif message["role"] == "system":
|
||||
message = SystemMessage(**message)
|
||||
else:
|
||||
raise ValueError(f"Unsupported message role: {message['role']}")
|
||||
|
||||
if "stop_reason" in out_dict:
|
||||
out_dict.update(stop_reason=out_dict["stop_reason"].value)
|
||||
out: OpenAIChatCompletionMessage = None
|
||||
if isinstance(message, UserMessage):
|
||||
out = OpenAIChatCompletionUserMessage(
|
||||
role="user",
|
||||
content=message.content, # TODO(mf): handle image content
|
||||
)
|
||||
elif isinstance(message, CompletionMessage):
|
||||
out = OpenAIChatCompletionAssistantMessage(
|
||||
role="assistant",
|
||||
content=message.content,
|
||||
tool_calls=[
|
||||
OpenAIChatCompletionMessageToolCall(
|
||||
id=tool.call_id,
|
||||
function=OpenAIFunction(
|
||||
name=tool.tool_name,
|
||||
arguments=json.dumps(tool.arguments),
|
||||
),
|
||||
type="function",
|
||||
)
|
||||
for tool in message.tool_calls
|
||||
],
|
||||
)
|
||||
elif isinstance(message, ToolResponseMessage):
|
||||
out = OpenAIChatCompletionToolMessage(
|
||||
role="tool",
|
||||
tool_call_id=message.call_id,
|
||||
content=message.content,
|
||||
)
|
||||
elif isinstance(message, SystemMessage):
|
||||
out = OpenAIChatCompletionSystemMessage(
|
||||
role="system",
|
||||
content=message.content,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported message type: {type(message)}")
|
||||
|
||||
# TODO(mf): tool_calls
|
||||
|
||||
return out_dict
|
||||
return out
|
||||
|
||||
|
||||
def convert_chat_completion_request(
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue