improve Llama Stack -> OpenAI message conversion

Matthew Farrellee 2024-11-22 10:34:31 -05:00
parent e6b82a44eb
commit 0e496d1557


@@ -17,13 +17,23 @@ from llama_models.llama3.api.datatypes import (
     ToolDefinition,
 )
 from openai import AsyncStream
-from openai.types.chat import ChatCompletionChunk as OpenAIChatCompletionChunk
+from openai.types.chat import (
+    ChatCompletionAssistantMessageParam as OpenAIChatCompletionAssistantMessage,
+    ChatCompletionChunk as OpenAIChatCompletionChunk,
+    ChatCompletionMessageParam as OpenAIChatCompletionMessage,
+    ChatCompletionMessageToolCallParam as OpenAIChatCompletionMessageToolCall,
+    ChatCompletionSystemMessageParam as OpenAIChatCompletionSystemMessage,
+    ChatCompletionToolMessageParam as OpenAIChatCompletionToolMessage,
+    ChatCompletionUserMessageParam as OpenAIChatCompletionUserMessage,
+)
 from openai.types.chat.chat_completion import (
     Choice as OpenAIChoice,
     ChoiceLogprobs as OpenAIChoiceLogprobs,  # same as chat_completion_chunk ChoiceLogprobs
 )
-from openai.types.chat.chat_completion_message_tool_call import (
-    ChatCompletionMessageToolCall as OpenAIChatCompletionMessageToolCall,
+from openai.types.chat.chat_completion_message_tool_call_param import (
     Function as OpenAIFunction,
 )
 from llama_stack.apis.inference import (
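
For context, the *MessageParam names aliased above are the request-side message types from the openai Python SDK. They are TypedDicts, so calling one (as the rewritten _convert_message does below) just builds a plain dict with the expected shape. A minimal illustration, not part of this commit:

    from openai.types.chat import (
        ChatCompletionUserMessageParam as OpenAIChatCompletionUserMessage,
    )

    # TypedDict construction: this is an ordinary dict at runtime
    msg = OpenAIChatCompletionUserMessage(role="user", content="Hello!")
    assert msg == {"role": "user", "content": "Hello!"}
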
@@ -34,8 +44,11 @@ from llama_stack.apis.inference import (
     ChatCompletionResponseStreamChunk,
     JsonSchemaResponseFormat,
     Message,
+    SystemMessage,
     ToolCallDelta,
     ToolCallParseStatus,
+    ToolResponseMessage,
+    UserMessage,
 )
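
SystemMessage, ToolResponseMessage, and UserMessage are the Llama Stack message models the rewritten _convert_message dispatches on. As a rough sketch only (the real definitions in llama_stack.apis.inference are pydantic models with literal role types, richer content types, and additional fields), the converter relies on this shape:

    from pydantic import BaseModel

    # illustrative approximations; only the fields the converter reads are shown
    class UserMessage(BaseModel):
        role: str = "user"
        content: str            # the real model also allows image/media content

    class SystemMessage(BaseModel):
        role: str = "system"
        content: str

    class ToolResponseMessage(BaseModel):
        role: str = "ipython"   # Llama Stack's tool-result role, mapped to OpenAI's "tool"
        call_id: str
        content: str

    class CompletionMessage(BaseModel):
        role: str = "assistant"
        content: str
        tool_calls: list = []   # ToolCall objects carrying call_id, tool_name, arguments
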
@@ -115,21 +128,63 @@ def _convert_tooldef_to_openai_tool(tool: ToolDefinition) -> dict:
     return out
 
 
-def _convert_message(message: Message) -> Dict:
+def _convert_message(message: Message | Dict) -> OpenAIChatCompletionMessage:
     """
     Convert a Message to an OpenAI API-compatible dictionary.
     """
-    out_dict = message.model_dump()
-    # Llama Stack uses role="ipython" for tool call messages, OpenAI uses "tool"
-    if out_dict["role"] == "ipython":
-        out_dict.update(role="tool")
+    # users can supply a dict instead of a Message object, we'll
+    # convert it to a Message object and proceed with some type safety.
+    if isinstance(message, dict):
+        if "role" not in message:
+            raise ValueError("role is required in message")
+        if message["role"] == "user":
+            message = UserMessage(**message)
+        elif message["role"] == "assistant":
+            message = CompletionMessage(**message)
+        elif message["role"] == "ipython":
+            message = ToolResponseMessage(**message)
+        elif message["role"] == "system":
+            message = SystemMessage(**message)
+        else:
+            raise ValueError(f"Unsupported message role: {message['role']}")
 
-    if "stop_reason" in out_dict:
-        out_dict.update(stop_reason=out_dict["stop_reason"].value)
+    out: OpenAIChatCompletionMessage = None
+    if isinstance(message, UserMessage):
+        out = OpenAIChatCompletionUserMessage(
+            role="user",
+            content=message.content,  # TODO(mf): handle image content
+        )
+    elif isinstance(message, CompletionMessage):
+        out = OpenAIChatCompletionAssistantMessage(
+            role="assistant",
+            content=message.content,
+            tool_calls=[
+                OpenAIChatCompletionMessageToolCall(
+                    id=tool.call_id,
+                    function=OpenAIFunction(
+                        name=tool.tool_name,
+                        arguments=json.dumps(tool.arguments),
+                    ),
+                    type="function",
+                )
+                for tool in message.tool_calls
+            ],
+        )
+    elif isinstance(message, ToolResponseMessage):
+        out = OpenAIChatCompletionToolMessage(
+            role="tool",
+            tool_call_id=message.call_id,
+            content=message.content,
+        )
+    elif isinstance(message, SystemMessage):
+        out = OpenAIChatCompletionSystemMessage(
+            role="system",
+            content=message.content,
+        )
+    else:
+        raise ValueError(f"Unsupported message type: {type(message)}")
 
-    # TODO(mf): tool_calls
-    return out_dict
+    return out
 
 
 def convert_chat_completion_request(
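
A hedged usage sketch of the new conversion path (message contents are made up; exact return values depend on the real model definitions):

    # dict input: the role field picks the Llama Stack model, which validates the rest
    msg = _convert_message({"role": "user", "content": "What is 2 + 2?"})
    # msg == {"role": "user", "content": "What is 2 + 2?"}, a ChatCompletionUserMessageParam

    # typed input is dispatched directly on its class
    msg = _convert_message(SystemMessage(content="You are a helpful assistant."))
    # msg == {"role": "system", "content": "You are a helpful assistant."}

    # malformed input now fails loudly instead of producing a bad request
    _convert_message({"content": "missing role"})        # ValueError: role is required in message
    _convert_message({"role": "critic", "content": ""})  # ValueError: Unsupported message role: critic
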