diff --git a/llama_stack/providers/remote/inference/nvidia/openai_utils.py b/llama_stack/providers/remote/inference/nvidia/openai_utils.py
index 2dddeadf9..b74aa05da 100644
--- a/llama_stack/providers/remote/inference/nvidia/openai_utils.py
+++ b/llama_stack/providers/remote/inference/nvidia/openai_utils.py
@@ -17,13 +17,23 @@ from llama_models.llama3.api.datatypes import (
     ToolDefinition,
 )
 from openai import AsyncStream
-from openai.types.chat import ChatCompletionChunk as OpenAIChatCompletionChunk
+
+from openai.types.chat import (
+    ChatCompletionAssistantMessageParam as OpenAIChatCompletionAssistantMessage,
+    ChatCompletionChunk as OpenAIChatCompletionChunk,
+    ChatCompletionMessageParam as OpenAIChatCompletionMessage,
+    ChatCompletionMessageToolCallParam as OpenAIChatCompletionMessageToolCall,
+    ChatCompletionSystemMessageParam as OpenAIChatCompletionSystemMessage,
+    ChatCompletionToolMessageParam as OpenAIChatCompletionToolMessage,
+    ChatCompletionUserMessageParam as OpenAIChatCompletionUserMessage,
+)
 from openai.types.chat.chat_completion import (
     Choice as OpenAIChoice,
     ChoiceLogprobs as OpenAIChoiceLogprobs,  # same as chat_completion_chunk ChoiceLogprobs
 )
-from openai.types.chat.chat_completion_message_tool_call import (
-    ChatCompletionMessageToolCall as OpenAIChatCompletionMessageToolCall,
+
+from openai.types.chat.chat_completion_message_tool_call_param import (
+    Function as OpenAIFunction,
 )
 
 from llama_stack.apis.inference import (
@@ -34,8 +44,11 @@ from llama_stack.apis.inference import (
     ChatCompletionResponseStreamChunk,
     JsonSchemaResponseFormat,
     Message,
+    SystemMessage,
     ToolCallDelta,
     ToolCallParseStatus,
+    ToolResponseMessage,
+    UserMessage,
 )
 
 
@@ -115,21 +128,63 @@ def _convert_tooldef_to_openai_tool(tool: ToolDefinition) -> dict:
     return out
 
 
-def _convert_message(message: Message) -> Dict:
+def _convert_message(message: Message | Dict) -> OpenAIChatCompletionMessage:
     """
-    Convert a Message to an OpenAI API-compatible dictionary.
+    Convert a Message (or message dict) to an OpenAI API-compatible message param.
     """
-    out_dict = message.model_dump()
-    # Llama Stack uses role="ipython" for tool call messages, OpenAI uses "tool"
-    if out_dict["role"] == "ipython":
-        out_dict.update(role="tool")
+    # Users can supply a dict instead of a Message object; convert it to a
+    # Message and proceed with some type safety.
+    if isinstance(message, dict):
+        if "role" not in message:
+            raise ValueError("role is required in message")
+        if message["role"] == "user":
+            message = UserMessage(**message)
+        elif message["role"] == "assistant":
+            message = CompletionMessage(**message)
+        elif message["role"] == "ipython":
+            message = ToolResponseMessage(**message)
+        elif message["role"] == "system":
+            message = SystemMessage(**message)
+        else:
+            raise ValueError(f"Unsupported message role: {message['role']}")
 
-    if "stop_reason" in out_dict:
-        out_dict.update(stop_reason=out_dict["stop_reason"].value)
+    out: OpenAIChatCompletionMessage
+    if isinstance(message, UserMessage):
+        out = OpenAIChatCompletionUserMessage(
+            role="user",
+            content=message.content,  # TODO(mf): handle image content
+        )
+    elif isinstance(message, CompletionMessage):
+        out = OpenAIChatCompletionAssistantMessage(
+            role="assistant",
+            content=message.content,
+            tool_calls=[
+                OpenAIChatCompletionMessageToolCall(
+                    id=tool.call_id,
+                    function=OpenAIFunction(
+                        name=tool.tool_name,
+                        arguments=json.dumps(tool.arguments),
+                    ),
+                    type="function",
+                )
+                for tool in message.tool_calls
+            ],
+        )
+    elif isinstance(message, ToolResponseMessage):
+        out = OpenAIChatCompletionToolMessage(
+            role="tool",
+            tool_call_id=message.call_id,
+            content=message.content,
+        )
+    elif isinstance(message, SystemMessage):
+        out = OpenAIChatCompletionSystemMessage(
+            role="system",
+            content=message.content,
+        )
+    else:
+        raise ValueError(f"Unsupported message type: {type(message)}")
 
-    # TODO(mf): tool_calls
-
-    return out_dict
+    return out
 
 
 def convert_chat_completion_request(
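For review context, here is a minimal sketch of how the reworked `_convert_message` behaves. The import path mirrors the file touched above, but the message payloads and assertions are illustrative assumptions, not tests from this PR:

```python
# Illustrative only: exercises _convert_message as reworked in the diff above.
# The module path matches the diff; the payloads below are hypothetical.
from llama_stack.providers.remote.inference.nvidia.openai_utils import (
    _convert_message,
)

# A plain dict is now accepted; role="user" passes through unchanged as an
# OpenAI ChatCompletionUserMessageParam (a TypedDict, i.e. a dict at runtime).
user = _convert_message({"role": "user", "content": "What is the capital of France?"})
assert user["role"] == "user"

# Llama Stack's role="ipython" tool-response messages map to OpenAI's
# role="tool", with call_id carried over as tool_call_id. Field names here
# assume ToolResponseMessage's schema.
tool = _convert_message(
    {
        "role": "ipython",
        "call_id": "call-1",
        "tool_name": "get_weather",
        "content": '{"temp_f": 72}',
    }
)
assert tool["role"] == "tool"
assert tool["tool_call_id"] == "call-1"

# Unknown roles now fail fast instead of being forwarded to the server.
try:
    _convert_message({"role": "critic", "content": "hmm"})
except ValueError as e:
    print(e)  # Unsupported message role: critic
```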