diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py
index e50d4d561..e3f1d0913 100644
--- a/llama_stack/providers/utils/inference/openai_compat.py
+++ b/llama_stack/providers/utils/inference/openai_compat.py
@@ -15,9 +15,17 @@ from typing import Any

 from openai import AsyncStream
 from openai.types.chat import (
     ChatCompletionAssistantMessageParam as OpenAIChatCompletionAssistantMessage,
+)
+from openai.types.chat import (
     ChatCompletionChunk as OpenAIChatCompletionChunk,
+)
+from openai.types.chat import (
     ChatCompletionContentPartImageParam as OpenAIChatCompletionContentPartImageParam,
+)
+from openai.types.chat import (
     ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam,
+)
+from openai.types.chat import (
     ChatCompletionContentPartTextParam as OpenAIChatCompletionContentPartTextParam,
 )
@@ -29,15 +37,56 @@ except ImportError:
     from openai.types.chat.chat_completion_message_tool_call import (
         ChatCompletionMessageToolCall as OpenAIChatCompletionMessageFunctionToolCall,
     )
+from openai.types.chat import (
+    ChatCompletionMessageParam as OpenAIChatCompletionMessage,
+)
+from openai.types.chat import (
+    ChatCompletionMessageToolCall,
+)
+from openai.types.chat import (
+    ChatCompletionSystemMessageParam as OpenAIChatCompletionSystemMessage,
+)
+from openai.types.chat import (
+    ChatCompletionToolMessageParam as OpenAIChatCompletionToolMessage,
+)
+from openai.types.chat import (
+    ChatCompletionUserMessageParam as OpenAIChatCompletionUserMessage,
+)
+from openai.types.chat.chat_completion import (
+    Choice as OpenAIChoice,
+)
+from openai.types.chat.chat_completion import (
+    ChoiceLogprobs as OpenAIChoiceLogprobs,  # same as chat_completion_chunk ChoiceLogprobs
+)
+from openai.types.chat.chat_completion_chunk import (
+    Choice as OpenAIChatCompletionChunkChoice,
+)
+from openai.types.chat.chat_completion_chunk import (
+    ChoiceDelta as OpenAIChoiceDelta,
+)
+from openai.types.chat.chat_completion_chunk import (
+    ChoiceDeltaToolCall as OpenAIChoiceDeltaToolCall,
+)
+from openai.types.chat.chat_completion_chunk import (
+    ChoiceDeltaToolCallFunction as OpenAIChoiceDeltaToolCallFunction,
+)
+from openai.types.chat.chat_completion_content_part_image_param import (
+    ImageURL as OpenAIImageURL,
+)
+from openai.types.chat.chat_completion_message_tool_call import (
+    Function as OpenAIFunction,
+)
+from pydantic import BaseModel
+
 from llama_stack.apis.common.content_types import (
-    _URLOrData,
+    URL,
     ImageContentItem,
     InterleavedContent,
     TextContentItem,
     TextDelta,
     ToolCallDelta,
     ToolCallParseStatus,
-    URL,
+    _URLOrData,
 )
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
@@ -50,7 +99,6 @@ from llama_stack.apis.inference import (
     JsonSchemaResponseFormat,
     Message,
     OpenAIChatCompletion,
-    OpenAIChoice as OpenAIChatCompletionChoice,
     OpenAIEmbeddingData,
     OpenAIMessageParam,
     OpenAIResponseFormatParam,
@@ -64,6 +112,9 @@ from llama_stack.apis.inference import (
     TopPSamplingStrategy,
     UserMessage,
 )
+from llama_stack.apis.inference import (
+    OpenAIChoice as OpenAIChatCompletionChoice,
+)
 from llama_stack.log import get_logger
 from llama_stack.models.llama.datatypes import (
     BuiltinTool,
@@ -75,30 +126,6 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
     convert_image_content_to_url,
     decode_assistant_message,
 )
-from openai.types.chat import (
-    ChatCompletionMessageParam as OpenAIChatCompletionMessage,
-    ChatCompletionMessageToolCall,
-    ChatCompletionSystemMessageParam as OpenAIChatCompletionSystemMessage,
-    ChatCompletionToolMessageParam as OpenAIChatCompletionToolMessage,
-    ChatCompletionUserMessageParam as OpenAIChatCompletionUserMessage,
-)
-from openai.types.chat.chat_completion import (
-    Choice as OpenAIChoice,
-    ChoiceLogprobs as OpenAIChoiceLogprobs,  # same as chat_completion_chunk ChoiceLogprobs
-)
-from openai.types.chat.chat_completion_chunk import (
-    Choice as OpenAIChatCompletionChunkChoice,
-    ChoiceDelta as OpenAIChoiceDelta,
-    ChoiceDeltaToolCall as OpenAIChoiceDeltaToolCall,
-    ChoiceDeltaToolCallFunction as OpenAIChoiceDeltaToolCallFunction,
-)
-from openai.types.chat.chat_completion_content_part_image_param import (
-    ImageURL as OpenAIImageURL,
-)
-from openai.types.chat.chat_completion_message_tool_call import (
-    Function as OpenAIFunction,
-)
-from pydantic import BaseModel


 logger = get_logger(name=__name__, category="providers::utils")
@@ -197,16 +224,12 @@ def convert_openai_completion_logprobs(
     if logprobs.tokens and logprobs.token_logprobs:
         return [
             TokenLogProbs(logprobs_by_token={token: token_lp})
-            for token, token_lp in zip(
-                logprobs.tokens, logprobs.token_logprobs, strict=False
-            )
+            for token, token_lp in zip(logprobs.tokens, logprobs.token_logprobs, strict=False)
         ]
     return None


-def convert_openai_completion_logprobs_stream(
-    text: str, logprobs: float | OpenAICompatLogprobs | None
-):
+def convert_openai_completion_logprobs_stream(text: str, logprobs: float | OpenAICompatLogprobs | None):
     if logprobs is None:
         return None
     if isinstance(logprobs, float):
@@ -226,9 +249,7 @@ def process_chat_completion_response(
         if not choice.message or not choice.message.tool_calls:
             raise ValueError("Tool calls are not present in the response")

-        tool_calls = [
-            convert_tool_call(tool_call) for tool_call in choice.message.tool_calls
-        ]
+        tool_calls = [convert_tool_call(tool_call) for tool_call in choice.message.tool_calls]
         if any(isinstance(tool_call, UnparseableToolCall) for tool_call in tool_calls):
             # If we couldn't parse a tool call, jsonify the tool calls and return them
             return ChatCompletionResponse(
@@ -252,9 +273,7 @@

     # TODO: This does not work well with tool calls for vLLM remote provider
     # Ref: https://github.com/meta-llama/llama-stack/issues/1058
-    raw_message = decode_assistant_message(
-        text_from_choice(choice), get_stop_reason(choice.finish_reason)
-    )
+    raw_message = decode_assistant_message(text_from_choice(choice), get_stop_reason(choice.finish_reason))

     # NOTE: If we do not set tools in chat-completion request, we should not
     # expect the ToolCall in the response. Instead, we should return the raw
@@ -455,17 +474,13 @@ async def process_chat_completion_stream_response(
     )


-async def convert_message_to_openai_dict(
-    message: Message, download: bool = False
-) -> dict:
+async def convert_message_to_openai_dict(message: Message, download: bool = False) -> dict:
     async def _convert_content(content) -> dict:
         if isinstance(content, ImageContentItem):
             return {
                 "type": "image_url",
                 "image_url": {
-                    "url": await convert_image_content_to_url(
-                        content, download=download
-                    ),
+                    "url": await convert_image_content_to_url(content, download=download),
                 },
             }
         else:
@@ -550,11 +565,7 @@ async def convert_message_to_openai_dict_new(
     ) -> str | Iterable[OpenAIChatCompletionContentPartParam]:
         async def impl(
             content_: InterleavedContent,
-        ) -> (
-            str
-            | OpenAIChatCompletionContentPartParam
-            | list[OpenAIChatCompletionContentPartParam]
-        ):
+        ) -> str | OpenAIChatCompletionContentPartParam | list[OpenAIChatCompletionContentPartParam]:
             # Llama Stack and OpenAI spec match for str and text input
             if isinstance(content_, str):
                 return content_
@@ -567,9 +578,7 @@ async def convert_message_to_openai_dict_new(
                 return OpenAIChatCompletionContentPartImageParam(
                     type="image_url",
                     image_url=OpenAIImageURL(
-                        url=await convert_image_content_to_url(
-                            content_, download=download_images
-                        )
+                        url=await convert_image_content_to_url(content_, download=download_images)
                     ),
                 )
             elif isinstance(content_, list):
@@ -596,11 +605,7 @@ async def convert_message_to_openai_dict_new(
             OpenAIChatCompletionMessageFunctionToolCall(
                 id=tool.call_id,
                 function=OpenAIFunction(
-                    name=(
-                        tool.tool_name
-                        if not isinstance(tool.tool_name, BuiltinTool)
-                        else tool.tool_name.value
-                    ),
+                    name=(tool.tool_name if not isinstance(tool.tool_name, BuiltinTool) else tool.tool_name.value),
                     arguments=tool.arguments,  # Already a JSON string, don't double-encode
                 ),
                 type="function",
@@ -780,9 +785,7 @@ def _convert_openai_finish_reason(finish_reason: str) -> StopReason:
     }.get(finish_reason, StopReason.end_of_turn)


-def _convert_openai_request_tool_config(
-    tool_choice: str | dict[str, Any] | None = None
-) -> ToolConfig:
+def _convert_openai_request_tool_config(tool_choice: str | dict[str, Any] | None = None) -> ToolConfig:
     tool_config = ToolConfig()
     if tool_choice:
         try:
@@ -793,9 +796,7 @@
     return tool_config


-def _convert_openai_request_tools(
-    tools: list[dict[str, Any]] | None = None
-) -> list[ToolDefinition]:
+def _convert_openai_request_tools(tools: list[dict[str, Any]] | None = None) -> list[ToolDefinition]:
     lls_tools = []
     if not tools:
         return lls_tools
@@ -894,11 +895,7 @@ def _convert_openai_logprobs(
         return None

     return [
-        TokenLogProbs(
-            logprobs_by_token={
-                logprobs.token: logprobs.logprob for logprobs in content.top_logprobs
-            }
-        )
+        TokenLogProbs(logprobs_by_token={logprobs.token: logprobs.logprob for logprobs in content.top_logprobs})
        for content in logprobs.content
     ]

@@ -937,13 +934,9 @@ def openai_messages_to_messages(
     converted_messages = []
     for message in messages:
         if message.role == "system":
-            converted_message = SystemMessage(
-                content=openai_content_to_content(message.content)
-            )
+            converted_message = SystemMessage(content=openai_content_to_content(message.content))
         elif message.role == "user":
-            converted_message = UserMessage(
-                content=openai_content_to_content(message.content)
-            )
+            converted_message = UserMessage(content=openai_content_to_content(message.content))
         elif message.role == "assistant":
             converted_message = CompletionMessage(
                 content=openai_content_to_content(message.content),
@@ -975,9 +968,7 @@ def openai_content_to_content(
         if content.type == "text":
             return TextContentItem(type="text", text=content.text)
         elif content.type == "image_url":
-            return ImageContentItem(
-                type="image", image=_URLOrData(url=URL(uri=content.image_url.url))
-            )
+            return ImageContentItem(type="image", image=_URLOrData(url=URL(uri=content.image_url.url)))
         else:
             raise ValueError(f"Unknown content type: {content.type}")
     else:
@@ -1017,17 +1008,14 @@ def convert_openai_chat_completion_choice(
         end_of_message = "end_of_message"
         out_of_tokens = "out_of_tokens"
     """
-    assert (
-        hasattr(choice, "message") and choice.message
-    ), "error in server response: message not found"
-    assert (
-        hasattr(choice, "finish_reason") and choice.finish_reason
-    ), "error in server response: finish_reason not found"
+    assert hasattr(choice, "message") and choice.message, "error in server response: message not found"
+    assert hasattr(choice, "finish_reason") and choice.finish_reason, (
+        "error in server response: finish_reason not found"
+    )

     return ChatCompletionResponse(
         completion_message=CompletionMessage(
-            content=choice.message.content
-            or "",  # CompletionMessage content is not optional
+            content=choice.message.content or "",  # CompletionMessage content is not optional
             stop_reason=_convert_openai_finish_reason(choice.finish_reason),
             tool_calls=_convert_openai_tool_calls(choice.message.tool_calls),
         ),
@@ -1267,9 +1255,7 @@ class OpenAIChatCompletionToLlamaStackMixin:
             outstanding_responses.append(response)

         if stream:
-            return OpenAIChatCompletionToLlamaStackMixin._process_stream_response(
-                self, model, outstanding_responses
-            )
+            return OpenAIChatCompletionToLlamaStackMixin._process_stream_response(self, model, outstanding_responses)
         return await OpenAIChatCompletionToLlamaStackMixin._process_non_stream_response(
             self, model, outstanding_responses
         )
@@ -1278,29 +1264,21 @@ class OpenAIChatCompletionToLlamaStackMixin:
     async def _process_stream_response(
         self,
         model: str,
-        outstanding_responses: list[
-            Awaitable[AsyncIterator[ChatCompletionResponseStreamChunk]]
-        ],
+        outstanding_responses: list[Awaitable[AsyncIterator[ChatCompletionResponseStreamChunk]]],
     ):
         id = f"chatcmpl-{uuid.uuid4()}"

         for i, outstanding_response in enumerate(outstanding_responses):
             response = await outstanding_response
             async for chunk in response:
                 event = chunk.event
-                finish_reason = _convert_stop_reason_to_openai_finish_reason(
-                    event.stop_reason
-                )
+                finish_reason = _convert_stop_reason_to_openai_finish_reason(event.stop_reason)
                 if isinstance(event.delta, TextDelta):
                     text_delta = event.delta.text
                     delta = OpenAIChoiceDelta(content=text_delta)
                     yield OpenAIChatCompletionChunk(
                         id=id,
-                        choices=[
-                            OpenAIChatCompletionChunkChoice(
-                                index=i, finish_reason=finish_reason, delta=delta
-                            )
-                        ],
+                        choices=[OpenAIChatCompletionChunkChoice(index=i, finish_reason=finish_reason, delta=delta)],
                         created=int(time.time()),
                         model=model,
                         object="chat.completion.chunk",
@@ -1322,9 +1300,7 @@ class OpenAIChatCompletionToLlamaStackMixin:
                         yield OpenAIChatCompletionChunk(
                             id=id,
                             choices=[
-                                OpenAIChatCompletionChunkChoice(
-                                    index=i, finish_reason=finish_reason, delta=delta
-                                )
+                                OpenAIChatCompletionChunkChoice(index=i, finish_reason=finish_reason, delta=delta)
                             ],
                             created=int(time.time()),
                             model=model,
@@ -1341,9 +1317,7 @@ class OpenAIChatCompletionToLlamaStackMixin:
                         yield OpenAIChatCompletionChunk(
                             id=id,
                             choices=[
-                                OpenAIChatCompletionChunkChoice(
-                                    index=i, finish_reason=finish_reason, delta=delta
-                                )
+                                OpenAIChatCompletionChunkChoice(index=i, finish_reason=finish_reason, delta=delta)
                             ],
                             created=int(time.time()),
                             model=model,
@@ -1358,9 +1332,7 @@
             response = await outstanding_response
             completion_message = response.completion_message
             message = await convert_message_to_openai_dict_new(completion_message)
-            finish_reason = _convert_stop_reason_to_openai_finish_reason(
-                completion_message.stop_reason
-            )
+            finish_reason = _convert_stop_reason_to_openai_finish_reason(completion_message.stop_reason)
             choice = OpenAIChatCompletionChoice(
                 index=len(choices),