reformatting

Omar Abdelwahab 2025-10-06 14:35:38 -07:00
parent 6adaca3d96
commit f4104756f6


@@ -15,9 +15,17 @@ from typing import Any
 from openai import AsyncStream
 from openai.types.chat import (
     ChatCompletionAssistantMessageParam as OpenAIChatCompletionAssistantMessage,
+)
+from openai.types.chat import (
     ChatCompletionChunk as OpenAIChatCompletionChunk,
+)
+from openai.types.chat import (
     ChatCompletionContentPartImageParam as OpenAIChatCompletionContentPartImageParam,
+)
+from openai.types.chat import (
     ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam,
+)
+from openai.types.chat import (
     ChatCompletionContentPartTextParam as OpenAIChatCompletionContentPartTextParam,
 )
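The first two hunks are a pure mechanical split: each grouped "from openai.types.chat import (...)" statement becomes one statement per imported name. Runtime behavior is unchanged; each name just gains its own statement, so later diffs can add or drop an import without touching neighboring lines. A minimal sketch of the equivalence (real module path; the selection of names is illustrative):

    # Grouped form, as before the commit:
    from openai.types.chat import (
        ChatCompletionChunk as OpenAIChatCompletionChunk,
        ChatCompletionMessageToolCall,
    )

    # Split form, as after the commit: imports exactly the same objects.
    from openai.types.chat import (
        ChatCompletionChunk as OpenAIChatCompletionChunk,
    )
    from openai.types.chat import (
        ChatCompletionMessageToolCall,
    )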
@@ -29,15 +37,56 @@ except ImportError:
     from openai.types.chat.chat_completion_message_tool_call import (
         ChatCompletionMessageToolCall as OpenAIChatCompletionMessageFunctionToolCall,
     )
+from openai.types.chat import (
+    ChatCompletionMessageParam as OpenAIChatCompletionMessage,
+)
+from openai.types.chat import (
+    ChatCompletionMessageToolCall,
+)
+from openai.types.chat import (
+    ChatCompletionSystemMessageParam as OpenAIChatCompletionSystemMessage,
+)
+from openai.types.chat import (
+    ChatCompletionToolMessageParam as OpenAIChatCompletionToolMessage,
+)
+from openai.types.chat import (
+    ChatCompletionUserMessageParam as OpenAIChatCompletionUserMessage,
+)
+from openai.types.chat.chat_completion import (
+    Choice as OpenAIChoice,
+)
+from openai.types.chat.chat_completion import (
+    ChoiceLogprobs as OpenAIChoiceLogprobs,  # same as chat_completion_chunk ChoiceLogprobs
+)
+from openai.types.chat.chat_completion_chunk import (
+    Choice as OpenAIChatCompletionChunkChoice,
+)
+from openai.types.chat.chat_completion_chunk import (
+    ChoiceDelta as OpenAIChoiceDelta,
+)
+from openai.types.chat.chat_completion_chunk import (
+    ChoiceDeltaToolCall as OpenAIChoiceDeltaToolCall,
+)
+from openai.types.chat.chat_completion_chunk import (
+    ChoiceDeltaToolCallFunction as OpenAIChoiceDeltaToolCallFunction,
+)
+from openai.types.chat.chat_completion_content_part_image_param import (
+    ImageURL as OpenAIImageURL,
+)
+from openai.types.chat.chat_completion_message_tool_call import (
+    Function as OpenAIFunction,
+)
+from pydantic import BaseModel
 from llama_stack.apis.common.content_types import (
-    _URLOrData,
+    URL,
     ImageContentItem,
     InterleavedContent,
     TextContentItem,
     TextDelta,
     ToolCallDelta,
     ToolCallParseStatus,
-    URL,
+    _URLOrData,
 )
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
@@ -50,7 +99,6 @@ from llama_stack.apis.inference import (
     JsonSchemaResponseFormat,
     Message,
     OpenAIChatCompletion,
-    OpenAIChoice as OpenAIChatCompletionChoice,
     OpenAIEmbeddingData,
     OpenAIMessageParam,
     OpenAIResponseFormatParam,
@@ -64,6 +112,9 @@ from llama_stack.apis.inference import (
     TopPSamplingStrategy,
     UserMessage,
 )
+from llama_stack.apis.inference import (
+    OpenAIChoice as OpenAIChatCompletionChoice,
+)
 from llama_stack.log import get_logger
 from llama_stack.models.llama.datatypes import (
     BuiltinTool,
@@ -75,30 +126,6 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
     convert_image_content_to_url,
     decode_assistant_message,
 )
-from openai.types.chat import (
-    ChatCompletionMessageParam as OpenAIChatCompletionMessage,
-    ChatCompletionMessageToolCall,
-    ChatCompletionSystemMessageParam as OpenAIChatCompletionSystemMessage,
-    ChatCompletionToolMessageParam as OpenAIChatCompletionToolMessage,
-    ChatCompletionUserMessageParam as OpenAIChatCompletionUserMessage,
-)
-from openai.types.chat.chat_completion import (
-    Choice as OpenAIChoice,
-    ChoiceLogprobs as OpenAIChoiceLogprobs,  # same as chat_completion_chunk ChoiceLogprobs
-)
-from openai.types.chat.chat_completion_chunk import (
-    Choice as OpenAIChatCompletionChunkChoice,
-    ChoiceDelta as OpenAIChoiceDelta,
-    ChoiceDeltaToolCall as OpenAIChoiceDeltaToolCall,
-    ChoiceDeltaToolCallFunction as OpenAIChoiceDeltaToolCallFunction,
-)
-from openai.types.chat.chat_completion_content_part_image_param import (
-    ImageURL as OpenAIImageURL,
-)
-from openai.types.chat.chat_completion_message_tool_call import (
-    Function as OpenAIFunction,
-)
-from pydantic import BaseModel


 logger = get_logger(name=__name__, category="providers::utils")
@@ -197,16 +224,12 @@ def convert_openai_completion_logprobs(
     if logprobs.tokens and logprobs.token_logprobs:
         return [
             TokenLogProbs(logprobs_by_token={token: token_lp})
-            for token, token_lp in zip(
-                logprobs.tokens, logprobs.token_logprobs, strict=False
-            )
+            for token, token_lp in zip(logprobs.tokens, logprobs.token_logprobs, strict=False)
         ]
     return None


-def convert_openai_completion_logprobs_stream(
-    text: str, logprobs: float | OpenAICompatLogprobs | None
-):
+def convert_openai_completion_logprobs_stream(text: str, logprobs: float | OpenAICompatLogprobs | None):
     if logprobs is None:
         return None
     if isinstance(logprobs, float):
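Almost every hunk from here down is the same transformation as the one just above: a call that the previous formatter had wrapped across several lines is joined onto a single line once it fits the (apparently longer) line limit. A small self-contained illustration of the pattern, using made-up stand-in data rather than the module's real types:

    tokens = ["Hello", ",", " world"]
    token_logprobs = [-0.1, -0.5, -0.2]

    # Before: the zip() call wrapped across three lines by the old layout.
    pairs_wrapped = [
        {token: lp}
        for token, lp in zip(
            tokens, token_logprobs, strict=False
        )
    ]

    # After: the same call joined onto one line, as this commit rewrites it.
    pairs_joined = [{token: lp} for token, lp in zip(tokens, token_logprobs, strict=False)]

    assert pairs_wrapped == pairs_joined  # identical behavior; only the layout changed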
@@ -226,9 +249,7 @@ def process_chat_completion_response(
     if not choice.message or not choice.message.tool_calls:
         raise ValueError("Tool calls are not present in the response")

-    tool_calls = [
-        convert_tool_call(tool_call) for tool_call in choice.message.tool_calls
-    ]
+    tool_calls = [convert_tool_call(tool_call) for tool_call in choice.message.tool_calls]
     if any(isinstance(tool_call, UnparseableToolCall) for tool_call in tool_calls):
         # If we couldn't parse a tool call, jsonify the tool calls and return them
         return ChatCompletionResponse(
@@ -252,9 +273,7 @@

     # TODO: This does not work well with tool calls for vLLM remote provider
     # Ref: https://github.com/meta-llama/llama-stack/issues/1058
-    raw_message = decode_assistant_message(
-        text_from_choice(choice), get_stop_reason(choice.finish_reason)
-    )
+    raw_message = decode_assistant_message(text_from_choice(choice), get_stop_reason(choice.finish_reason))

     # NOTE: If we do not set tools in chat-completion request, we should not
     # expect the ToolCall in the response. Instead, we should return the raw
@@ -455,17 +474,13 @@ async def process_chat_completion_stream_response(
     )


-async def convert_message_to_openai_dict(
-    message: Message, download: bool = False
-) -> dict:
+async def convert_message_to_openai_dict(message: Message, download: bool = False) -> dict:
     async def _convert_content(content) -> dict:
         if isinstance(content, ImageContentItem):
             return {
                 "type": "image_url",
                 "image_url": {
-                    "url": await convert_image_content_to_url(
-                        content, download=download
-                    ),
+                    "url": await convert_image_content_to_url(content, download=download),
                 },
             }
         else:
@@ -550,11 +565,7 @@ async def convert_message_to_openai_dict_new(
     ) -> str | Iterable[OpenAIChatCompletionContentPartParam]:
         async def impl(
             content_: InterleavedContent,
-        ) -> (
-            str
-            | OpenAIChatCompletionContentPartParam
-            | list[OpenAIChatCompletionContentPartParam]
-        ):
+        ) -> str | OpenAIChatCompletionContentPartParam | list[OpenAIChatCompletionContentPartParam]:
             # Llama Stack and OpenAI spec match for str and text input
             if isinstance(content_, str):
                 return content_
@@ -567,9 +578,7 @@ async def convert_message_to_openai_dict_new(
                 return OpenAIChatCompletionContentPartImageParam(
                     type="image_url",
                     image_url=OpenAIImageURL(
-                        url=await convert_image_content_to_url(
-                            content_, download=download_images
-                        )
+                        url=await convert_image_content_to_url(content_, download=download_images)
                     ),
                 )
             elif isinstance(content_, list):
@@ -596,11 +605,7 @@ async def convert_message_to_openai_dict_new(
             OpenAIChatCompletionMessageFunctionToolCall(
                 id=tool.call_id,
                 function=OpenAIFunction(
-                    name=(
-                        tool.tool_name
-                        if not isinstance(tool.tool_name, BuiltinTool)
-                        else tool.tool_name.value
-                    ),
+                    name=(tool.tool_name if not isinstance(tool.tool_name, BuiltinTool) else tool.tool_name.value),
                     arguments=tool.arguments,  # Already a JSON string, don't double-encode
                 ),
                 type="function",
@@ -780,9 +785,7 @@ def _convert_openai_finish_reason(finish_reason: str) -> StopReason:
     }.get(finish_reason, StopReason.end_of_turn)


-def _convert_openai_request_tool_config(
-    tool_choice: str | dict[str, Any] | None = None
-) -> ToolConfig:
+def _convert_openai_request_tool_config(tool_choice: str | dict[str, Any] | None = None) -> ToolConfig:
     tool_config = ToolConfig()
     if tool_choice:
         try:
@@ -793,9 +796,7 @@ def _convert_openai_request_tool_config(
     return tool_config


-def _convert_openai_request_tools(
-    tools: list[dict[str, Any]] | None = None
-) -> list[ToolDefinition]:
+def _convert_openai_request_tools(tools: list[dict[str, Any]] | None = None) -> list[ToolDefinition]:
     lls_tools = []
     if not tools:
         return lls_tools
@@ -894,11 +895,7 @@ def _convert_openai_logprobs(
         return None

     return [
-        TokenLogProbs(
-            logprobs_by_token={
-                logprobs.token: logprobs.logprob for logprobs in content.top_logprobs
-            }
-        )
+        TokenLogProbs(logprobs_by_token={logprobs.token: logprobs.logprob for logprobs in content.top_logprobs})
         for content in logprobs.content
     ]
@@ -937,13 +934,9 @@ def openai_messages_to_messages(
     converted_messages = []
     for message in messages:
         if message.role == "system":
-            converted_message = SystemMessage(
-                content=openai_content_to_content(message.content)
-            )
+            converted_message = SystemMessage(content=openai_content_to_content(message.content))
         elif message.role == "user":
-            converted_message = UserMessage(
-                content=openai_content_to_content(message.content)
-            )
+            converted_message = UserMessage(content=openai_content_to_content(message.content))
         elif message.role == "assistant":
             converted_message = CompletionMessage(
                 content=openai_content_to_content(message.content),
@@ -975,9 +968,7 @@ def openai_content_to_content(
         if content.type == "text":
             return TextContentItem(type="text", text=content.text)
         elif content.type == "image_url":
-            return ImageContentItem(
-                type="image", image=_URLOrData(url=URL(uri=content.image_url.url))
-            )
+            return ImageContentItem(type="image", image=_URLOrData(url=URL(uri=content.image_url.url)))
         else:
             raise ValueError(f"Unknown content type: {content.type}")
     else:
@@ -1017,17 +1008,14 @@ def convert_openai_chat_completion_choice(
         end_of_message = "end_of_message"
         out_of_tokens = "out_of_tokens"
     """
-    assert (
-        hasattr(choice, "message") and choice.message
-    ), "error in server response: message not found"
-    assert (
-        hasattr(choice, "finish_reason") and choice.finish_reason
-    ), "error in server response: finish_reason not found"
+    assert hasattr(choice, "message") and choice.message, "error in server response: message not found"
+    assert hasattr(choice, "finish_reason") and choice.finish_reason, (
+        "error in server response: finish_reason not found"
+    )

     return ChatCompletionResponse(
         completion_message=CompletionMessage(
-            content=choice.message.content
-            or "",  # CompletionMessage content is not optional
+            content=choice.message.content or "",  # CompletionMessage content is not optional
             stop_reason=_convert_openai_finish_reason(choice.finish_reason),
             tool_calls=_convert_openai_tool_calls(choice.message.tool_calls),
         ),
@@ -1267,9 +1255,7 @@ class OpenAIChatCompletionToLlamaStackMixin:
             outstanding_responses.append(response)

         if stream:
-            return OpenAIChatCompletionToLlamaStackMixin._process_stream_response(
-                self, model, outstanding_responses
-            )
+            return OpenAIChatCompletionToLlamaStackMixin._process_stream_response(self, model, outstanding_responses)

         return await OpenAIChatCompletionToLlamaStackMixin._process_non_stream_response(
             self, model, outstanding_responses
@@ -1278,29 +1264,21 @@ class OpenAIChatCompletionToLlamaStackMixin:
     async def _process_stream_response(
         self,
         model: str,
-        outstanding_responses: list[
-            Awaitable[AsyncIterator[ChatCompletionResponseStreamChunk]]
-        ],
+        outstanding_responses: list[Awaitable[AsyncIterator[ChatCompletionResponseStreamChunk]]],
     ):
         id = f"chatcmpl-{uuid.uuid4()}"
         for i, outstanding_response in enumerate(outstanding_responses):
             response = await outstanding_response
             async for chunk in response:
                 event = chunk.event
-                finish_reason = _convert_stop_reason_to_openai_finish_reason(
-                    event.stop_reason
-                )
+                finish_reason = _convert_stop_reason_to_openai_finish_reason(event.stop_reason)

                 if isinstance(event.delta, TextDelta):
                     text_delta = event.delta.text
                     delta = OpenAIChoiceDelta(content=text_delta)
                     yield OpenAIChatCompletionChunk(
                         id=id,
-                        choices=[
-                            OpenAIChatCompletionChunkChoice(
-                                index=i, finish_reason=finish_reason, delta=delta
-                            )
-                        ],
+                        choices=[OpenAIChatCompletionChunkChoice(index=i, finish_reason=finish_reason, delta=delta)],
                         created=int(time.time()),
                         model=model,
                         object="chat.completion.chunk",
@@ -1322,9 +1300,7 @@ class OpenAIChatCompletionToLlamaStackMixin:
                         yield OpenAIChatCompletionChunk(
                             id=id,
                             choices=[
-                                OpenAIChatCompletionChunkChoice(
-                                    index=i, finish_reason=finish_reason, delta=delta
-                                )
+                                OpenAIChatCompletionChunkChoice(index=i, finish_reason=finish_reason, delta=delta)
                             ],
                             created=int(time.time()),
                             model=model,
@@ -1341,9 +1317,7 @@ class OpenAIChatCompletionToLlamaStackMixin:
                         yield OpenAIChatCompletionChunk(
                             id=id,
                             choices=[
-                                OpenAIChatCompletionChunkChoice(
-                                    index=i, finish_reason=finish_reason, delta=delta
-                                )
+                                OpenAIChatCompletionChunkChoice(index=i, finish_reason=finish_reason, delta=delta)
                             ],
                             created=int(time.time()),
                             model=model,
@@ -1358,9 +1332,7 @@ class OpenAIChatCompletionToLlamaStackMixin:
             response = await outstanding_response
             completion_message = response.completion_message
             message = await convert_message_to_openai_dict_new(completion_message)
-            finish_reason = _convert_stop_reason_to_openai_finish_reason(
-                completion_message.stop_reason
-            )
+            finish_reason = _convert_stop_reason_to_openai_finish_reason(completion_message.stop_reason)

             choice = OpenAIChatCompletionChoice(
                 index=len(choices),