From b0310af1776f175d01a2d31b9a930ce4c585f062 Mon Sep 17 00:00:00 2001
From: Ashwin Bharambe
Date: Tue, 25 Feb 2025 22:02:11 -0800
Subject: [PATCH] refactor: move OpenAI compat utilities from nvidia to openai_compat (#1258)

# What does this PR do?

This PR:
- refactors the code that converts between Llama Stack <> OpenAI compat servers, which was previously used only by the nvidia implementation, so it can be used more broadly. Next PRs in the stack will show usage.
- adds incremental tool call parsing (for backends that stream tool calls incrementally, not just wholesale).

## Test Plan

Run:
```bash
pytest -s -v -k nvidia llama_stack/providers/tests/inference/ --env NVIDIA_API_KEY=....
```

Text model tests pass (the completion tests do not):
```
test_text_inference.py::TestInference::test_model_list[-nvidia] PASSED
test_text_inference.py::TestInference::test_text_completion_non_streaming[-nvidia-inference:completion:non_streaming] FAILED
test_text_inference.py::TestInference::test_text_completion_streaming[-nvidia-inference:completion:streaming] FAILED
test_text_inference.py::TestInference::test_text_completion_logprobs_non_streaming[-nvidia-inference:completion:logprobs_non_streaming] FAILED
test_text_inference.py::TestInference::test_text_completion_logprobs_streaming[-nvidia-inference:completion:logprobs_streaming] FAILED
test_text_inference.py::TestInference::test_text_completion_structured_output[-nvidia-inference:completion:structured_output] FAILED
test_text_inference.py::TestInference::test_text_chat_completion_non_streaming[-nvidia-inference:chat_completion:sample_messages] PASSED
test_text_inference.py::TestInference::test_text_chat_completion_structured_output[-nvidia-inference:chat_completion:structured_output] PASSED
test_text_inference.py::TestInference::test_text_chat_completion_streaming[-nvidia-inference:chat_completion:sample_messages] PASSED
test_text_inference.py::TestInference::test_text_chat_completion_with_tool_calling[-nvidia-inference:chat_completion:sample_messages_tool_calling] PASSED
test_text_inference.py::TestInference::test_text_chat_completion_with_tool_calling_streaming[-nvidia-inference:chat_completion:sample_messages_tool_calling] PASSED
```

Vision model tests don't:
```
FAILED test_vision_inference.py::TestVisionModelInference::test_vision_chat_completion_non_streaming[-nvidia-image0-expected_strings0] - openai.BadRequestError: Error code: 400 - {'type': 'about:blank', 'status': 400, 'title': 'Bad Request', 'detail': 'Inference error'}
FAILED test_vision_inference.py::TestVisionModelInference::test_vision_chat_completion_non_streaming[-nvidia-image1-expected_strings1] - openai.BadRequestError: Error code: 400 - {'type': 'about:blank', 'status': 400, 'title': 'Bad Request', 'detail': 'Inference error'}
FAILED test_vision_inference.py::TestVisionModelInference::test_vision_chat_completion_streaming[-nvidia] - openai.BadRequestError: Error code: 400 - {'object': 'error', 'message': "[{'type': 'string_type', 'loc': ('body', 'messages', 1, 'content'), 'msg': 'Input should be a valid string', 'input': [{'image_url': {'url': 'https://raw.githubusercontent.com/meta-llama/llam... 
``` --- .../remote/inference/nvidia/nvidia.py | 8 +- .../remote/inference/nvidia/openai_utils.py | 460 +-------------- .../providers/tests/inference/conftest.py | 3 +- .../providers/tests/inference/fixtures.py | 2 +- .../utils/inference/openai_compat.py | 524 +++++++++++++++++- 5 files changed, 538 insertions(+), 459 deletions(-) diff --git a/llama_stack/providers/remote/inference/nvidia/nvidia.py b/llama_stack/providers/remote/inference/nvidia/nvidia.py index ecd53e91c..cc3bd85bb 100644 --- a/llama_stack/providers/remote/inference/nvidia/nvidia.py +++ b/llama_stack/providers/remote/inference/nvidia/nvidia.py @@ -40,6 +40,10 @@ from llama_stack.models.llama.datatypes import ( from llama_stack.providers.utils.inference.model_registry import ( ModelRegistryHelper, ) +from llama_stack.providers.utils.inference.openai_compat import ( + convert_openai_chat_completion_choice, + convert_openai_chat_completion_stream, +) from llama_stack.providers.utils.inference.prompt_adapter import content_has_media from . import NVIDIAConfig @@ -47,8 +51,6 @@ from .models import _MODEL_ENTRIES from .openai_utils import ( convert_chat_completion_request, convert_completion_request, - convert_openai_chat_completion_choice, - convert_openai_chat_completion_stream, convert_openai_completion_choice, convert_openai_completion_stream, ) @@ -201,7 +203,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper): raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e if stream: - return convert_openai_chat_completion_stream(response) + return convert_openai_chat_completion_stream(response, enable_incremental_tool_calls=False) else: # we pass n=1 to get only one completion return convert_openai_chat_completion_choice(response.choices[0]) diff --git a/llama_stack/providers/remote/inference/nvidia/openai_utils.py b/llama_stack/providers/remote/inference/nvidia/openai_utils.py index 9799eedcc..1849fda6d 100644 --- a/llama_stack/providers/remote/inference/nvidia/openai_utils.py +++ b/llama_stack/providers/remote/inference/nvidia/openai_utils.py @@ -4,249 +4,36 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
-import json import warnings -from typing import Any, AsyncGenerator, Dict, Generator, Iterable, List, Optional, Union +from typing import Any, AsyncGenerator, Dict, List, Optional from openai import AsyncStream -from openai.types.chat import ( - ChatCompletionAssistantMessageParam as OpenAIChatCompletionAssistantMessage, -) -from openai.types.chat import ( - ChatCompletionChunk as OpenAIChatCompletionChunk, -) -from openai.types.chat import ( - ChatCompletionContentPartImageParam as OpenAIChatCompletionContentPartImageParam, -) -from openai.types.chat import ( - ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam, -) -from openai.types.chat import ( - ChatCompletionMessageParam as OpenAIChatCompletionMessage, -) -from openai.types.chat import ( - ChatCompletionMessageToolCallParam as OpenAIChatCompletionMessageToolCall, -) -from openai.types.chat import ( - ChatCompletionSystemMessageParam as OpenAIChatCompletionSystemMessage, -) -from openai.types.chat import ( - ChatCompletionToolMessageParam as OpenAIChatCompletionToolMessage, -) -from openai.types.chat import ( - ChatCompletionUserMessageParam as OpenAIChatCompletionUserMessage, -) from openai.types.chat.chat_completion import ( Choice as OpenAIChoice, ) -from openai.types.chat.chat_completion import ( - ChoiceLogprobs as OpenAIChoiceLogprobs, # same as chat_completion_chunk ChoiceLogprobs -) -from openai.types.chat.chat_completion_content_part_image_param import ( - ImageURL as OpenAIImageURL, -) -from openai.types.chat.chat_completion_message_tool_call_param import ( - Function as OpenAIFunction, -) from openai.types.completion import Completion as OpenAICompletion from openai.types.completion_choice import Logprobs as OpenAICompletionLogprobs -from llama_stack.apis.common.content_types import ( - ImageContentItem, - InterleavedContent, - TextContentItem, - TextDelta, - ToolCallDelta, - ToolCallParseStatus, -) from llama_stack.apis.inference import ( ChatCompletionRequest, - ChatCompletionResponse, - ChatCompletionResponseEvent, - ChatCompletionResponseEventType, - ChatCompletionResponseStreamChunk, - CompletionMessage, CompletionRequest, CompletionResponse, CompletionResponseStreamChunk, JsonSchemaResponseFormat, - Message, - SystemMessage, TokenLogProbs, - ToolResponseMessage, - UserMessage, ) from llama_stack.models.llama.datatypes import ( - BuiltinTool, GreedySamplingStrategy, - StopReason, - ToolCall, - ToolDefinition, TopKSamplingStrategy, TopPSamplingStrategy, ) -from llama_stack.providers.utils.inference.prompt_adapter import ( - convert_image_content_to_url, +from llama_stack.providers.utils.inference.openai_compat import ( + _convert_openai_finish_reason, + convert_message_to_openai_dict_new, + convert_tooldef_to_openai_tool, ) -def _convert_tooldef_to_openai_tool(tool: ToolDefinition) -> dict: - """ - Convert a ToolDefinition to an OpenAI API-compatible dictionary. - - ToolDefinition: - tool_name: str | BuiltinTool - description: Optional[str] - parameters: Optional[Dict[str, ToolParamDefinition]] - - ToolParamDefinition: - param_type: str - description: Optional[str] - required: Optional[bool] - default: Optional[Any] - - - OpenAI spec - - - { - "type": "function", - "function": { - "name": tool_name, - "description": description, - "parameters": { - "type": "object", - "properties": { - param_name: { - "type": param_type, - "description": description, - "default": default, - }, - ... 
- }, - "required": [param_name, ...], - }, - }, - } - """ - out = { - "type": "function", - "function": {}, - } - function = out["function"] - - if isinstance(tool.tool_name, BuiltinTool): - function.update(name=tool.tool_name.value) # TODO(mf): is this sufficient? - else: - function.update(name=tool.tool_name) - - if tool.description: - function.update(description=tool.description) - - if tool.parameters: - parameters = { - "type": "object", - "properties": {}, - } - properties = parameters["properties"] - required = [] - for param_name, param in tool.parameters.items(): - properties[param_name] = {"type": param.param_type} - if param.description: - properties[param_name].update(description=param.description) - if param.default: - properties[param_name].update(default=param.default) - if param.required: - required.append(param_name) - - if required: - parameters.update(required=required) - - function.update(parameters=parameters) - - return out - - -async def _convert_message(message: Message | Dict) -> OpenAIChatCompletionMessage: - """ - Convert a Message to an OpenAI API-compatible dictionary. - """ - # users can supply a dict instead of a Message object, we'll - # convert it to a Message object and proceed with some type safety. - if isinstance(message, dict): - if "role" not in message: - raise ValueError("role is required in message") - if message["role"] == "user": - message = UserMessage(**message) - elif message["role"] == "assistant": - message = CompletionMessage(**message) - elif message["role"] == "tool": - message = ToolResponseMessage(**message) - elif message["role"] == "system": - message = SystemMessage(**message) - else: - raise ValueError(f"Unsupported message role: {message['role']}") - - # Map Llama Stack spec to OpenAI spec - - # str -> str - # {"type": "text", "text": ...} -> {"type": "text", "text": ...} - # {"type": "image", "image": {"url": {"uri": ...}}} -> {"type": "image_url", "image_url": {"url": ...}} - # {"type": "image", "image": {"data": ...}} -> {"type": "image_url", "image_url": {"url": "data:image/?;base64,..."}} - # List[...] -> List[...] 
- async def _convert_user_message_content( - content: InterleavedContent, - ) -> Union[str, Iterable[OpenAIChatCompletionContentPartParam]]: - # Llama Stack and OpenAI spec match for str and text input - if isinstance(content, str) or isinstance(content, TextContentItem): - return content - elif isinstance(content, ImageContentItem): - return OpenAIChatCompletionContentPartImageParam( - image_url=OpenAIImageURL(url=await convert_image_content_to_url(content)), - type="image_url", - ) - elif isinstance(content, List): - return [await _convert_user_message_content(item) for item in content] - else: - raise ValueError(f"Unsupported content type: {type(content)}") - - out: OpenAIChatCompletionMessage = None - if isinstance(message, UserMessage): - out = OpenAIChatCompletionUserMessage( - role="user", - content=await _convert_user_message_content(message.content), - ) - elif isinstance(message, CompletionMessage): - out = OpenAIChatCompletionAssistantMessage( - role="assistant", - content=message.content, - tool_calls=[ - OpenAIChatCompletionMessageToolCall( - id=tool.call_id, - function=OpenAIFunction( - name=tool.tool_name, - arguments=json.dumps(tool.arguments), - ), - type="function", - ) - for tool in message.tool_calls - ], - ) - elif isinstance(message, ToolResponseMessage): - out = OpenAIChatCompletionToolMessage( - role="tool", - tool_call_id=message.call_id, - content=message.content, - ) - elif isinstance(message, SystemMessage): - out = OpenAIChatCompletionSystemMessage( - role="system", - content=message.content, - ) - else: - raise ValueError(f"Unsupported message type: {type(message)}") - - return out - - async def convert_chat_completion_request( request: ChatCompletionRequest, n: int = 1, @@ -281,7 +68,7 @@ async def convert_chat_completion_request( nvext = {} payload: Dict[str, Any] = dict( model=request.model, - messages=[await _convert_message(message) for message in request.messages], + messages=[await convert_message_to_openai_dict_new(message) for message in request.messages], stream=request.stream, n=n, extra_body=dict(nvext=nvext), @@ -296,7 +83,7 @@ async def convert_chat_completion_request( nvext.update(guided_json=request.response_format.json_schema) if request.tools: - payload.update(tools=[_convert_tooldef_to_openai_tool(tool) for tool in request.tools]) + payload.update(tools=[convert_tooldef_to_openai_tool(tool) for tool in request.tools]) if request.tool_config.tool_choice: payload.update( tool_choice=request.tool_config.tool_choice.value @@ -329,239 +116,6 @@ async def convert_chat_completion_request( return payload -def _convert_openai_finish_reason(finish_reason: str) -> StopReason: - """ - Convert an OpenAI chat completion finish_reason to a StopReason. - - finish_reason: Literal["stop", "length", "tool_calls", ...] - - stop: model hit a natural stop point or a provided stop sequence - - length: maximum number of tokens specified in the request was reached - - tool_calls: model called a tool - - -> - - class StopReason(Enum): - end_of_turn = "end_of_turn" - end_of_message = "end_of_message" - out_of_tokens = "out_of_tokens" - """ - - # TODO(mf): are end_of_turn and end_of_message semantics correct? 
- return { - "stop": StopReason.end_of_turn, - "length": StopReason.out_of_tokens, - "tool_calls": StopReason.end_of_message, - }.get(finish_reason, StopReason.end_of_turn) - - -def _convert_openai_tool_calls( - tool_calls: List[OpenAIChatCompletionMessageToolCall], -) -> List[ToolCall]: - """ - Convert an OpenAI ChatCompletionMessageToolCall list into a list of ToolCall. - - OpenAI ChatCompletionMessageToolCall: - id: str - function: Function - type: Literal["function"] - - OpenAI Function: - arguments: str - name: str - - -> - - ToolCall: - call_id: str - tool_name: str - arguments: Dict[str, ...] - """ - if not tool_calls: - return [] # CompletionMessage tool_calls is not optional - - return [ - ToolCall( - call_id=call.id, - tool_name=call.function.name, - arguments=json.loads(call.function.arguments), - ) - for call in tool_calls - ] - - -def _convert_openai_logprobs( - logprobs: OpenAIChoiceLogprobs, -) -> Optional[List[TokenLogProbs]]: - """ - Convert an OpenAI ChoiceLogprobs into a list of TokenLogProbs. - - OpenAI ChoiceLogprobs: - content: Optional[List[ChatCompletionTokenLogprob]] - - OpenAI ChatCompletionTokenLogprob: - token: str - logprob: float - top_logprobs: List[TopLogprob] - - OpenAI TopLogprob: - token: str - logprob: float - - -> - - TokenLogProbs: - logprobs_by_token: Dict[str, float] - - token, logprob - - """ - if not logprobs: - return None - - return [ - TokenLogProbs(logprobs_by_token={logprobs.token: logprobs.logprob for logprobs in content.top_logprobs}) - for content in logprobs.content - ] - - -def convert_openai_chat_completion_choice( - choice: OpenAIChoice, -) -> ChatCompletionResponse: - """ - Convert an OpenAI Choice into a ChatCompletionResponse. - - OpenAI Choice: - message: ChatCompletionMessage - finish_reason: str - logprobs: Optional[ChoiceLogprobs] - - OpenAI ChatCompletionMessage: - role: Literal["assistant"] - content: Optional[str] - tool_calls: Optional[List[ChatCompletionMessageToolCall]] - - -> - - ChatCompletionResponse: - completion_message: CompletionMessage - logprobs: Optional[List[TokenLogProbs]] - - CompletionMessage: - role: Literal["assistant"] - content: str | ImageMedia | List[str | ImageMedia] - stop_reason: StopReason - tool_calls: List[ToolCall] - - class StopReason(Enum): - end_of_turn = "end_of_turn" - end_of_message = "end_of_message" - out_of_tokens = "out_of_tokens" - """ - assert hasattr(choice, "message") and choice.message, "error in server response: message not found" - assert hasattr(choice, "finish_reason") and choice.finish_reason, ( - "error in server response: finish_reason not found" - ) - - return ChatCompletionResponse( - completion_message=CompletionMessage( - content=choice.message.content or "", # CompletionMessage content is not optional - stop_reason=_convert_openai_finish_reason(choice.finish_reason), - tool_calls=_convert_openai_tool_calls(choice.message.tool_calls), - ), - logprobs=_convert_openai_logprobs(choice.logprobs), - ) - - -async def convert_openai_chat_completion_stream( - stream: AsyncStream[OpenAIChatCompletionChunk], -) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]: - """ - Convert a stream of OpenAI chat completion chunks into a stream - of ChatCompletionResponseStreamChunk. - """ - - # generate a stream of ChatCompletionResponseEventType: start -> progress -> progress -> ... 
- def _event_type_generator() -> Generator[ChatCompletionResponseEventType, None, None]: - yield ChatCompletionResponseEventType.start - while True: - yield ChatCompletionResponseEventType.progress - - event_type = _event_type_generator() - - # we implement NIM specific semantics, the main difference from OpenAI - # is that tool_calls are always produced as a complete call. there is no - # intermediate / partial tool call streamed. because of this, we can - # simplify the logic and not concern outselves with parse_status of - # started/in_progress/failed. we can always assume success. - # - # a stream of ChatCompletionResponseStreamChunk consists of - # 0. a start event - # 1. zero or more progress events - # - each progress event has a delta - # - each progress event may have a stop_reason - # - each progress event may have logprobs - # - each progress event may have tool_calls - # if a progress event has tool_calls, - # it is fully formed and - # can be emitted with a parse_status of success - # 2. a complete event - - stop_reason = None - - async for chunk in stream: - choice = chunk.choices[0] # assuming only one choice per chunk - - # we assume there's only one finish_reason in the stream - stop_reason = _convert_openai_finish_reason(choice.finish_reason) or stop_reason - - # if there's a tool call, emit an event for each tool in the list - # if tool call and content, emit both separately - - if choice.delta.tool_calls: - # the call may have content and a tool call. ChatCompletionResponseEvent - # does not support both, so we emit the content first - if choice.delta.content: - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=next(event_type), - delta=TextDelta(text=choice.delta.content), - logprobs=_convert_openai_logprobs(choice.logprobs), - ) - ) - - # it is possible to have parallel tool calls in stream, but - # ChatCompletionResponseEvent only supports one per stream - if len(choice.delta.tool_calls) > 1: - warnings.warn("multiple tool calls found in a single delta, using the first, ignoring the rest") - - # NIM only produces fully formed tool calls, so we can assume success - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=next(event_type), - delta=ToolCallDelta( - tool_call=_convert_openai_tool_calls(choice.delta.tool_calls)[0], - parse_status=ToolCallParseStatus.succeeded, - ), - logprobs=_convert_openai_logprobs(choice.logprobs), - ) - ) - else: - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=next(event_type), - delta=TextDelta(text=choice.delta.content or ""), - logprobs=_convert_openai_logprobs(choice.logprobs), - ) - ) - - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.complete, - delta=TextDelta(text=""), - stop_reason=stop_reason, - ) - ) - - def convert_completion_request( request: CompletionRequest, n: int = 1, diff --git a/llama_stack/providers/tests/inference/conftest.py b/llama_stack/providers/tests/inference/conftest.py index 2e9b5bcff..0075ff80d 100644 --- a/llama_stack/providers/tests/inference/conftest.py +++ b/llama_stack/providers/tests/inference/conftest.py @@ -46,9 +46,10 @@ def pytest_generate_tests(metafunc): if ("Vision" in cls_name and "Vision" in model) or ("Vision" not in cls_name and "Vision" not in model): params.append(pytest.param(model, id=model)) + print(f"params: {params}") if not params: model = metafunc.config.getoption("--inference-model") - 
params = [pytest.param(model, id="")] + params = [pytest.param(model, id=model)] metafunc.parametrize( "inference_model", diff --git a/llama_stack/providers/tests/inference/fixtures.py b/llama_stack/providers/tests/inference/fixtures.py index 5291bffb3..80ee68ba8 100644 --- a/llama_stack/providers/tests/inference/fixtures.py +++ b/llama_stack/providers/tests/inference/fixtures.py @@ -197,7 +197,7 @@ def inference_nvidia() -> ProviderFixture: Provider( provider_id="nvidia", provider_type="remote::nvidia", - config=NVIDIAConfig().model_dump(), + config=NVIDIAConfig(api_key=get_env_or_fail("NVIDIA_API_KEY")).model_dump(), ) ], ) diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py index 946d27763..1e684f4a3 100644 --- a/llama_stack/providers/utils/inference/openai_compat.py +++ b/llama_stack/providers/utils/inference/openai_compat.py @@ -5,13 +5,58 @@ # the root directory of this source tree. import json import logging -from typing import AsyncGenerator, Dict, List, Optional, Union +import warnings +from typing import AsyncGenerator, Dict, Generator, Iterable, List, Optional, Union +from openai import AsyncStream +from openai.types.chat import ( + ChatCompletionAssistantMessageParam as OpenAIChatCompletionAssistantMessage, +) +from openai.types.chat import ( + ChatCompletionChunk as OpenAIChatCompletionChunk, +) +from openai.types.chat import ( + ChatCompletionContentPartImageParam as OpenAIChatCompletionContentPartImageParam, +) +from openai.types.chat import ( + ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam, +) +from openai.types.chat import ( + ChatCompletionContentPartTextParam as OpenAIChatCompletionContentPartTextParam, +) +from openai.types.chat import ( + ChatCompletionMessageParam as OpenAIChatCompletionMessage, +) from openai.types.chat import ChatCompletionMessageToolCall +from openai.types.chat import ( + ChatCompletionMessageToolCallParam as OpenAIChatCompletionMessageToolCall, +) +from openai.types.chat import ( + ChatCompletionSystemMessageParam as OpenAIChatCompletionSystemMessage, +) +from openai.types.chat import ( + ChatCompletionToolMessageParam as OpenAIChatCompletionToolMessage, +) +from openai.types.chat import ( + ChatCompletionUserMessageParam as OpenAIChatCompletionUserMessage, +) +from openai.types.chat.chat_completion import ( + Choice as OpenAIChoice, +) +from openai.types.chat.chat_completion import ( + ChoiceLogprobs as OpenAIChoiceLogprobs, # same as chat_completion_chunk ChoiceLogprobs +) +from openai.types.chat.chat_completion_content_part_image_param import ( + ImageURL as OpenAIImageURL, +) +from openai.types.chat.chat_completion_message_tool_call_param import ( + Function as OpenAIFunction, +) from pydantic import BaseModel from llama_stack.apis.common.content_types import ( ImageContentItem, + InterleavedContent, TextContentItem, TextDelta, ToolCallDelta, @@ -27,13 +72,18 @@ from llama_stack.apis.inference import ( CompletionResponse, CompletionResponseStreamChunk, Message, + SystemMessage, TokenLogProbs, + ToolResponseMessage, + UserMessage, ) from llama_stack.models.llama.datatypes import ( + BuiltinTool, GreedySamplingStrategy, SamplingParams, StopReason, ToolCall, + ToolDefinition, TopKSamplingStrategy, TopPSamplingStrategy, ) @@ -177,6 +227,31 @@ def process_chat_completion_response( request: ChatCompletionRequest, ) -> ChatCompletionResponse: choice = response.choices[0] + if choice.finish_reason == "tool_calls": + if not choice.message or not 
choice.message.tool_calls: + raise ValueError("Tool calls are not present in the response") + + tool_calls = [convert_tool_call(tool_call) for tool_call in choice.message.tool_calls] + if any(isinstance(tool_call, UnparseableToolCall) for tool_call in tool_calls): + # If we couldn't parse a tool call, jsonify the tool calls and return them + return ChatCompletionResponse( + completion_message=CompletionMessage( + stop_reason=StopReason.end_of_turn, + content=json.dumps(tool_calls, default=lambda x: x.model_dump()), + ), + logprobs=None, + ) + else: + # Otherwise, return tool calls as normal + return ChatCompletionResponse( + completion_message=CompletionMessage( + tool_calls=tool_calls, + stop_reason=StopReason.end_of_turn, + # Content is not optional + content="", + ), + logprobs=None, + ) # TODO: This does not work well with tool calls for vLLM remote provider # Ref: https://github.com/meta-llama/llama-stack/issues/1058 @@ -417,6 +492,91 @@ class UnparseableToolCall(BaseModel): arguments: str = "" +async def convert_message_to_openai_dict_new(message: Message | Dict) -> OpenAIChatCompletionMessage: + """ + Convert a Message to an OpenAI API-compatible dictionary. + """ + # users can supply a dict instead of a Message object, we'll + # convert it to a Message object and proceed with some type safety. + if isinstance(message, dict): + if "role" not in message: + raise ValueError("role is required in message") + if message["role"] == "user": + message = UserMessage(**message) + elif message["role"] == "assistant": + message = CompletionMessage(**message) + elif message["role"] == "tool": + message = ToolResponseMessage(**message) + elif message["role"] == "system": + message = SystemMessage(**message) + else: + raise ValueError(f"Unsupported message role: {message['role']}") + + # Map Llama Stack spec to OpenAI spec - + # str -> str + # {"type": "text", "text": ...} -> {"type": "text", "text": ...} + # {"type": "image", "image": {"url": {"uri": ...}}} -> {"type": "image_url", "image_url": {"url": ...}} + # {"type": "image", "image": {"data": ...}} -> {"type": "image_url", "image_url": {"url": "data:image/?;base64,..."}} + # List[...] -> List[...] 
+ async def _convert_user_message_content( + content: InterleavedContent, + ) -> Union[str, Iterable[OpenAIChatCompletionContentPartParam]]: + # Llama Stack and OpenAI spec match for str and text input + if isinstance(content, str): + return content + elif isinstance(content, TextContentItem): + return OpenAIChatCompletionContentPartTextParam( + text=content.text, + ) + elif isinstance(content, ImageContentItem): + return OpenAIChatCompletionContentPartImageParam( + image_url=OpenAIImageURL(url=await convert_image_content_to_url(content)), + type="image_url", + ) + elif isinstance(content, List): + return [await _convert_user_message_content(item) for item in content] + else: + raise ValueError(f"Unsupported content type: {type(content)}") + + out: OpenAIChatCompletionMessage = None + if isinstance(message, UserMessage): + out = OpenAIChatCompletionUserMessage( + role="user", + content=await _convert_user_message_content(message.content), + ) + elif isinstance(message, CompletionMessage): + out = OpenAIChatCompletionAssistantMessage( + role="assistant", + content=message.content, + tool_calls=[ + OpenAIChatCompletionMessageToolCall( + id=tool.call_id, + function=OpenAIFunction( + name=tool.tool_name, + arguments=json.dumps(tool.arguments), + ), + type="function", + ) + for tool in message.tool_calls + ], + ) + elif isinstance(message, ToolResponseMessage): + out = OpenAIChatCompletionToolMessage( + role="tool", + tool_call_id=message.call_id, + content=message.content, + ) + elif isinstance(message, SystemMessage): + out = OpenAIChatCompletionSystemMessage( + role="system", + content=message.content, + ) + else: + raise ValueError(f"Unsupported message type: {type(message)}") + + return out + + def convert_tool_call( tool_call: ChatCompletionMessageToolCall, ) -> Union[ToolCall, UnparseableToolCall]: @@ -439,3 +599,365 @@ def convert_tool_call( ) return valid_tool_call + + +def convert_tooldef_to_openai_tool(tool: ToolDefinition) -> dict: + """ + Convert a ToolDefinition to an OpenAI API-compatible dictionary. + + ToolDefinition: + tool_name: str | BuiltinTool + description: Optional[str] + parameters: Optional[Dict[str, ToolParamDefinition]] + + ToolParamDefinition: + param_type: str + description: Optional[str] + required: Optional[bool] + default: Optional[Any] + + + OpenAI spec - + + { + "type": "function", + "function": { + "name": tool_name, + "description": description, + "parameters": { + "type": "object", + "properties": { + param_name: { + "type": param_type, + "description": description, + "default": default, + }, + ... + }, + "required": [param_name, ...], + }, + }, + } + """ + out = { + "type": "function", + "function": {}, + } + function = out["function"] + + if isinstance(tool.tool_name, BuiltinTool): + function.update(name=tool.tool_name.value) # TODO(mf): is this sufficient? 
+ else: + function.update(name=tool.tool_name) + + if tool.description: + function.update(description=tool.description) + + if tool.parameters: + parameters = { + "type": "object", + "properties": {}, + } + properties = parameters["properties"] + required = [] + for param_name, param in tool.parameters.items(): + properties[param_name] = {"type": param.param_type} + if param.description: + properties[param_name].update(description=param.description) + if param.default: + properties[param_name].update(default=param.default) + if param.required: + required.append(param_name) + + if required: + parameters.update(required=required) + + function.update(parameters=parameters) + + return out + + +def _convert_openai_finish_reason(finish_reason: str) -> StopReason: + """ + Convert an OpenAI chat completion finish_reason to a StopReason. + + finish_reason: Literal["stop", "length", "tool_calls", ...] + - stop: model hit a natural stop point or a provided stop sequence + - length: maximum number of tokens specified in the request was reached + - tool_calls: model called a tool + + -> + + class StopReason(Enum): + end_of_turn = "end_of_turn" + end_of_message = "end_of_message" + out_of_tokens = "out_of_tokens" + """ + + # TODO(mf): are end_of_turn and end_of_message semantics correct? + return { + "stop": StopReason.end_of_turn, + "length": StopReason.out_of_tokens, + "tool_calls": StopReason.end_of_message, + }.get(finish_reason, StopReason.end_of_turn) + + +def _convert_openai_tool_calls( + tool_calls: List[OpenAIChatCompletionMessageToolCall], +) -> List[ToolCall]: + """ + Convert an OpenAI ChatCompletionMessageToolCall list into a list of ToolCall. + + OpenAI ChatCompletionMessageToolCall: + id: str + function: Function + type: Literal["function"] + + OpenAI Function: + arguments: str + name: str + + -> + + ToolCall: + call_id: str + tool_name: str + arguments: Dict[str, ...] + """ + if not tool_calls: + return [] # CompletionMessage tool_calls is not optional + + return [ + ToolCall( + call_id=call.id, + tool_name=call.function.name, + arguments=json.loads(call.function.arguments), + ) + for call in tool_calls + ] + + +def _convert_openai_logprobs( + logprobs: OpenAIChoiceLogprobs, +) -> Optional[List[TokenLogProbs]]: + """ + Convert an OpenAI ChoiceLogprobs into a list of TokenLogProbs. + + OpenAI ChoiceLogprobs: + content: Optional[List[ChatCompletionTokenLogprob]] + + OpenAI ChatCompletionTokenLogprob: + token: str + logprob: float + top_logprobs: List[TopLogprob] + + OpenAI TopLogprob: + token: str + logprob: float + + -> + + TokenLogProbs: + logprobs_by_token: Dict[str, float] + - token, logprob + + """ + if not logprobs: + return None + + return [ + TokenLogProbs(logprobs_by_token={logprobs.token: logprobs.logprob for logprobs in content.top_logprobs}) + for content in logprobs.content + ] + + +def convert_openai_chat_completion_choice( + choice: OpenAIChoice, +) -> ChatCompletionResponse: + """ + Convert an OpenAI Choice into a ChatCompletionResponse. 
+ + OpenAI Choice: + message: ChatCompletionMessage + finish_reason: str + logprobs: Optional[ChoiceLogprobs] + + OpenAI ChatCompletionMessage: + role: Literal["assistant"] + content: Optional[str] + tool_calls: Optional[List[ChatCompletionMessageToolCall]] + + -> + + ChatCompletionResponse: + completion_message: CompletionMessage + logprobs: Optional[List[TokenLogProbs]] + + CompletionMessage: + role: Literal["assistant"] + content: str | ImageMedia | List[str | ImageMedia] + stop_reason: StopReason + tool_calls: List[ToolCall] + + class StopReason(Enum): + end_of_turn = "end_of_turn" + end_of_message = "end_of_message" + out_of_tokens = "out_of_tokens" + """ + assert hasattr(choice, "message") and choice.message, "error in server response: message not found" + assert hasattr(choice, "finish_reason") and choice.finish_reason, ( + "error in server response: finish_reason not found" + ) + + return ChatCompletionResponse( + completion_message=CompletionMessage( + content=choice.message.content or "", # CompletionMessage content is not optional + stop_reason=_convert_openai_finish_reason(choice.finish_reason), + tool_calls=_convert_openai_tool_calls(choice.message.tool_calls), + ), + logprobs=_convert_openai_logprobs(getattr(choice, "logprobs", None)), + ) + + +async def convert_openai_chat_completion_stream( + stream: AsyncStream[OpenAIChatCompletionChunk], + enable_incremental_tool_calls: bool, +) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]: + """ + Convert a stream of OpenAI chat completion chunks into a stream + of ChatCompletionResponseStreamChunk. + """ + + # generate a stream of ChatCompletionResponseEventType: start -> progress -> progress -> ... + def _event_type_generator() -> Generator[ChatCompletionResponseEventType, None, None]: + yield ChatCompletionResponseEventType.start + while True: + yield ChatCompletionResponseEventType.progress + + event_type = _event_type_generator() + + stop_reason = None + toolcall_buffer = {} + async for chunk in stream: + choice = chunk.choices[0] # assuming only one choice per chunk + + # we assume there's only one finish_reason in the stream + stop_reason = _convert_openai_finish_reason(choice.finish_reason) or stop_reason + logprobs = getattr(choice, "logprobs", None) + + # if there's a tool call, emit an event for each tool in the list + # if tool call and content, emit both separately + + if choice.delta.tool_calls: + # the call may have content and a tool call. 
ChatCompletionResponseEvent + # does not support both, so we emit the content first + if choice.delta.content: + yield ChatCompletionResponseStreamChunk( + event=ChatCompletionResponseEvent( + event_type=next(event_type), + delta=TextDelta(text=choice.delta.content), + logprobs=_convert_openai_logprobs(logprobs), + ) + ) + + # it is possible to have parallel tool calls in stream, but + # ChatCompletionResponseEvent only supports one per stream + if len(choice.delta.tool_calls) > 1: + warnings.warn("multiple tool calls found in a single delta, using the first, ignoring the rest") + + if not enable_incremental_tool_calls: + yield ChatCompletionResponseStreamChunk( + event=ChatCompletionResponseEvent( + event_type=next(event_type), + delta=ToolCallDelta( + tool_call=_convert_openai_tool_calls(choice.delta.tool_calls)[0], + parse_status=ToolCallParseStatus.succeeded, + ), + logprobs=_convert_openai_logprobs(logprobs), + ) + ) + else: + tool_call = choice.delta.tool_calls[0] + if "name" not in toolcall_buffer: + toolcall_buffer["call_id"] = tool_call.id + toolcall_buffer["name"] = None + toolcall_buffer["content"] = "" + if "arguments" not in toolcall_buffer: + toolcall_buffer["arguments"] = "" + + if tool_call.function.name: + toolcall_buffer["name"] = tool_call.function.name + delta = f"{toolcall_buffer['name']}(" + if tool_call.function.arguments: + toolcall_buffer["arguments"] += tool_call.function.arguments + delta = toolcall_buffer["arguments"] + + toolcall_buffer["content"] += delta + yield ChatCompletionResponseStreamChunk( + event=ChatCompletionResponseEvent( + event_type=next(event_type), + delta=ToolCallDelta( + tool_call=delta, + parse_status=ToolCallParseStatus.in_progress, + ), + logprobs=_convert_openai_logprobs(logprobs), + ) + ) + else: + yield ChatCompletionResponseStreamChunk( + event=ChatCompletionResponseEvent( + event_type=next(event_type), + delta=TextDelta(text=choice.delta.content or ""), + logprobs=_convert_openai_logprobs(logprobs), + ) + ) + + if toolcall_buffer: + delta = ")" + toolcall_buffer["content"] += delta + yield ChatCompletionResponseStreamChunk( + event=ChatCompletionResponseEvent( + event_type=next(event_type), + delta=ToolCallDelta( + tool_call=delta, + parse_status=ToolCallParseStatus.in_progress, + ), + logprobs=_convert_openai_logprobs(logprobs), + ) + ) + try: + arguments = json.loads(toolcall_buffer["arguments"]) + tool_call = ToolCall( + call_id=toolcall_buffer["call_id"], + tool_name=toolcall_buffer["name"], + arguments=arguments, + ) + yield ChatCompletionResponseStreamChunk( + event=ChatCompletionResponseEvent( + event_type=ChatCompletionResponseEventType.complete, + delta=ToolCallDelta( + tool_call=tool_call, + parse_status=ToolCallParseStatus.succeeded, + ), + stop_reason=stop_reason, + ) + ) + except json.JSONDecodeError: + yield ChatCompletionResponseStreamChunk( + event=ChatCompletionResponseEvent( + event_type=ChatCompletionResponseEventType.complete, + delta=ToolCallDelta( + tool_call=toolcall_buffer["content"], + parse_status=ToolCallParseStatus.failed, + ), + stop_reason=stop_reason, + ) + ) + + yield ChatCompletionResponseStreamChunk( + event=ChatCompletionResponseEvent( + event_type=ChatCompletionResponseEventType.complete, + delta=TextDelta(text=""), + stop_reason=stop_reason, + ) + )
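
---

Reviewer note (not part of the patch): the stream converter moved above is the entry point other providers will import, and the new `enable_incremental_tool_calls` flag is the behavioral change. Below is a minimal sketch of driving it against an OpenAI-compatible endpoint; the `base_url`, model name, and prompt are placeholder assumptions, while the imports and event fields follow the code in this diff.

```python
# Reviewer sketch, not part of this patch: exercise the relocated stream converter.
# Assumes an OpenAI-compatible endpoint; base_url, model, and prompt are placeholders.
import asyncio

from openai import AsyncOpenAI

from llama_stack.apis.common.content_types import ToolCallDelta
from llama_stack.providers.utils.inference.openai_compat import (
    convert_openai_chat_completion_stream,
)


async def main() -> None:
    client = AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="none")
    raw_stream = await client.chat.completions.create(
        model="meta/llama-3.1-8b-instruct",
        messages=[{"role": "user", "content": "What is the weather in Tokyo?"}],
        stream=True,
    )
    # enable_incremental_tool_calls=True buffers partial tool-call deltas, emits
    # in_progress chunks as arguments stream in, and finishes with a succeeded or
    # failed chunk; passing False (as the NVIDIA adapter does) expects each delta
    # to carry a fully formed tool call.
    async for chunk in convert_openai_chat_completion_stream(
        raw_stream, enable_incremental_tool_calls=True
    ):
        delta = chunk.event.delta
        if isinstance(delta, ToolCallDelta):
            print(f"[{delta.parse_status}] {delta.tool_call}")
        else:
            print(delta.text, end="")


if __name__ == "__main__":
    asyncio.run(main())
```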
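Similarly, `convert_tooldef_to_openai_tool` is now importable from `openai_compat`. A small sketch of the expected output shape, assuming `ToolParamDefinition` lives alongside `ToolDefinition` in `llama_stack.models.llama.datatypes` (the field names come from the docstring in the diff):

```python
# Reviewer sketch, not part of this patch. The ToolParamDefinition import path is
# assumed; the printed shape mirrors the convert_tooldef_to_openai_tool docstring.
from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition
from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool

get_weather = ToolDefinition(
    tool_name="get_weather",
    description="Look up the current weather for a city",
    parameters={
        "city": ToolParamDefinition(
            param_type="string",
            description="City name",
            required=True,
        ),
    },
)

print(convert_tooldef_to_openai_tool(get_weather))
# Expected, roughly:
# {"type": "function",
#  "function": {"name": "get_weather",
#               "description": "Look up the current weather for a city",
#               "parameters": {"type": "object",
#                              "properties": {"city": {"type": "string",
#                                                      "description": "City name"}},
#                              "required": ["city"]}}}
```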