fixes and linting

Hardik Shah 2025-03-28 18:33:36 -07:00
parent 021dd0d35d
commit 5251d2422d
8 changed files with 149 additions and 345 deletions

@@ -6,7 +6,49 @@
import json
import logging
import warnings
from typing import AsyncGenerator, Dict, Generator, Iterable, List, Optional, Union
from typing import AsyncGenerator, Dict, Iterable, List, Optional, Union
from openai import AsyncStream
from openai.types.chat import (
ChatCompletionAssistantMessageParam as OpenAIChatCompletionAssistantMessage,
)
from openai.types.chat import (
ChatCompletionChunk as OpenAIChatCompletionChunk,
)
from openai.types.chat import (
ChatCompletionContentPartImageParam as OpenAIChatCompletionContentPartImageParam,
)
from openai.types.chat import (
ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam,
)
from openai.types.chat import (
ChatCompletionContentPartTextParam as OpenAIChatCompletionContentPartTextParam,
)
from openai.types.chat import (
ChatCompletionMessageParam as OpenAIChatCompletionMessage,
)
from openai.types.chat import (
ChatCompletionMessageToolCall as OpenAIChatCompletionMessageToolCall,
)
from openai.types.chat import (
ChatCompletionSystemMessageParam as OpenAIChatCompletionSystemMessage,
)
from openai.types.chat import (
ChatCompletionToolMessageParam as OpenAIChatCompletionToolMessage,
)
from openai.types.chat import (
ChatCompletionUserMessageParam as OpenAIChatCompletionUserMessage,
)
from openai.types.chat.chat_completion import (
Choice as OpenAIChoice,
)
from openai.types.chat.chat_completion import (
ChoiceLogprobs as OpenAIChoiceLogprobs, # same as chat_completion_chunk ChoiceLogprobs
)
from openai.types.chat.chat_completion_content_part_image_param import (
ImageURL as OpenAIImageURL,
)
from pydantic import BaseModel
from llama_stack.apis.common.content_types import (
ImageContentItem,
@@ -23,7 +65,6 @@ from llama_stack.apis.inference import (
ChatCompletionResponseEventType,
ChatCompletionResponseStreamChunk,
CompletionMessage,
CompletionRequest,
CompletionResponse,
CompletionResponseStreamChunk,
JsonSchemaResponseFormat,
@@ -49,32 +90,6 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
decode_assistant_message,
)
from openai import AsyncStream, Stream
from openai.types.chat import (
ChatCompletionAssistantMessageParam as OpenAIChatCompletionAssistantMessage,
ChatCompletionChunk as OpenAIChatCompletionChunk,
ChatCompletionContentPartImageParam as OpenAIChatCompletionContentPartImageParam,
ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam,
ChatCompletionContentPartTextParam as OpenAIChatCompletionContentPartTextParam,
ChatCompletionMessageParam as OpenAIChatCompletionMessage,
ChatCompletionMessageToolCall as OpenAIChatCompletionMessageToolCall,
ChatCompletionMessageToolCallParam as OpenAIChatCompletionMessageToolCallParam,
ChatCompletionSystemMessageParam as OpenAIChatCompletionSystemMessage,
ChatCompletionToolMessageParam as OpenAIChatCompletionToolMessage,
ChatCompletionUserMessageParam as OpenAIChatCompletionUserMessage,
)
from openai.types.chat.chat_completion import (
Choice as OpenAIChoice,
ChoiceLogprobs as OpenAIChoiceLogprobs, # same as chat_completion_chunk ChoiceLogprobs
)
from openai.types.chat.chat_completion_content_part_image_param import (
ImageURL as OpenAIImageURL,
)
from openai.types.chat.chat_completion_message_tool_call_param import (
Function as OpenAIFunction,
)
from pydantic import BaseModel
logger = logging.getLogger(__name__)
@@ -169,16 +184,12 @@ def convert_openai_completion_logprobs(
if logprobs.tokens and logprobs.token_logprobs:
return [
TokenLogProbs(logprobs_by_token={token: token_lp})
for token, token_lp in zip(
logprobs.tokens, logprobs.token_logprobs, strict=False
)
for token, token_lp in zip(logprobs.tokens, logprobs.token_logprobs, strict=False)
]
return None
def convert_openai_completion_logprobs_stream(
text: str, logprobs: Optional[Union[float, OpenAICompatLogprobs]]
):
def convert_openai_completion_logprobs_stream(text: str, logprobs: Optional[Union[float, OpenAICompatLogprobs]]):
if logprobs is None:
return None
if isinstance(logprobs, float):
@@ -223,9 +234,7 @@ def process_chat_completion_response(
if not choice.message or not choice.message.tool_calls:
raise ValueError("Tool calls are not present in the response")
tool_calls = [
convert_tool_call(tool_call) for tool_call in choice.message.tool_calls
]
tool_calls = [convert_tool_call(tool_call) for tool_call in choice.message.tool_calls]
if any(isinstance(tool_call, UnparseableToolCall) for tool_call in tool_calls):
# If we couldn't parse a tool call, jsonify the tool calls and return them
return ChatCompletionResponse(
@@ -249,9 +258,7 @@ def process_chat_completion_response(
# TODO: This does not work well with tool calls for vLLM remote provider
# Ref: https://github.com/meta-llama/llama-stack/issues/1058
raw_message = decode_assistant_message(
text_from_choice(choice), get_stop_reason(choice.finish_reason)
)
raw_message = decode_assistant_message(text_from_choice(choice), get_stop_reason(choice.finish_reason))
# NOTE: If we do not set tools in chat-completion request, we should not
# expect the ToolCall in the response. Instead, we should return the raw
@@ -452,17 +459,13 @@ async def process_chat_completion_stream_response(
)
async def convert_message_to_openai_dict(
message: Message, download: bool = False
) -> dict:
async def convert_message_to_openai_dict(message: Message, download: bool = False) -> dict:
async def _convert_content(content) -> dict:
if isinstance(content, ImageContentItem):
return {
"type": "image_url",
"image_url": {
"url": await convert_image_content_to_url(
content, download=download
),
"url": await convert_image_content_to_url(content, download=download),
},
}
else:
@@ -541,9 +544,7 @@ async def convert_message_to_openai_dict_new(
elif isinstance(content_, ImageContentItem):
return OpenAIChatCompletionContentPartImageParam(
type="image_url",
image_url=OpenAIImageURL(
url=await convert_image_content_to_url(content_)
),
image_url=OpenAIImageURL(url=await convert_image_content_to_url(content_)),
)
elif isinstance(content_, list):
return [await impl(item) for item in content_]
@@ -587,9 +588,7 @@ async def convert_message_to_openai_dict_new(
"id": tool.call_id,
"function": {
"name": (
tool.tool_name
if not isinstance(tool.tool_name, BuiltinTool)
else tool.tool_name.value
tool.tool_name if not isinstance(tool.tool_name, BuiltinTool) else tool.tool_name.value
),
"arguments": json.dumps(tool.arguments),
},
@@ -709,11 +708,7 @@ def convert_tooldef_to_openai_tool(tool: ToolDefinition) -> dict:
properties = parameters["properties"]
required = []
for param_name, param in tool.parameters.items():
properties[param_name] = {
"type": PYTHON_TYPE_TO_LITELLM_TYPE.get(
param.param_type, param.param_type
)
}
properties[param_name] = {"type": PYTHON_TYPE_TO_LITELLM_TYPE.get(param.param_type, param.param_type)}
if param.description:
properties[param_name].update(description=param.description)
if param.default:
@@ -834,11 +829,7 @@ def _convert_openai_logprobs(
return None
return [
TokenLogProbs(
logprobs_by_token={
logprobs.token: logprobs.logprob for logprobs in content.top_logprobs
}
)
TokenLogProbs(logprobs_by_token={logprobs.token: logprobs.logprob for logprobs in content.top_logprobs})
for content in logprobs.content
]
@@ -876,17 +867,14 @@ def convert_openai_chat_completion_choice(
end_of_message = "end_of_message"
out_of_tokens = "out_of_tokens"
"""
assert (
hasattr(choice, "message") and choice.message
), "error in server response: message not found"
assert (
hasattr(choice, "finish_reason") and choice.finish_reason
), "error in server response: finish_reason not found"
assert hasattr(choice, "message") and choice.message, "error in server response: message not found"
assert hasattr(choice, "finish_reason") and choice.finish_reason, (
"error in server response: finish_reason not found"
)
return ChatCompletionResponse(
completion_message=CompletionMessage(
content=choice.message.content
or "", # CompletionMessage content is not optional
content=choice.message.content or "", # CompletionMessage content is not optional
stop_reason=_convert_openai_finish_reason(choice.finish_reason),
tool_calls=_convert_openai_tool_calls(choice.message.tool_calls),
),
@@ -895,9 +883,7 @@ def convert_openai_chat_completion_choice(
async def convert_openai_chat_completion_stream(
stream: Union[
AsyncStream[OpenAIChatCompletionChunk], Stream[OpenAIChatCompletionChunk]
],
stream: AsyncStream[OpenAIChatCompletionChunk],
enable_incremental_tool_calls: bool,
) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
"""
@@ -905,14 +891,6 @@ async def convert_openai_chat_completion_stream(
of ChatCompletionResponseStreamChunk.
"""
async def yield_from_stream(stream):
if isinstance(stream, AsyncGenerator):
async for chunk in stream:
yield chunk
elif isinstance(stream, Generator):
for chunk in stream:
yield chunk
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.start,
@@ -924,7 +902,7 @@ async def convert_openai_chat_completion_stream(
stop_reason = None
tool_call_idx_to_buffer = {}
async for chunk in yield_from_stream(stream):
async for chunk in stream:
choice = chunk.choices[0] # assuming only one choice per chunk
# we assume there's only one finish_reason in the stream
@@ -1092,7 +1070,7 @@ async def convert_openai_chat_completion_stream(
stop_reason=stop_reason,
)
)
except json.JSONDecodeError as e:
except json.JSONDecodeError:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
@@ -1137,7 +1115,7 @@ async def convert_openai_chat_completion_stream(
stop_reason=stop_reason,
)
)
except (KeyError, json.JSONDecodeError) as e:
except (KeyError, json.JSONDecodeError):
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
@@ -1158,26 +1136,6 @@ async def convert_openai_chat_completion_stream(
)
async def convert_completion_request_to_openai_params(
request: CompletionRequest,
) -> dict:
"""
Convert a ChatCompletionRequest to an OpenAI API-compatible dictionary.
"""
input_dict = {}
if request.logprobs:
input_dict["logprobs"] = request.logprobs.top_k
return {
"model": request.model,
"prompt": request.content,
**input_dict,
"stream": request.stream,
**get_sampling_options(request.sampling_params),
"n": 1,
}
async def convert_chat_completion_request_to_openai_params(
request: ChatCompletionRequest,
) -> dict:
@@ -1186,14 +1144,10 @@ async def convert_chat_completion_request_to_openai_params(
"""
input_dict = {}
input_dict["messages"] = [
await convert_message_to_openai_dict_new(m) for m in request.messages
]
input_dict["messages"] = [await convert_message_to_openai_dict_new(m) for m in request.messages]
if fmt := request.response_format:
if not isinstance(fmt, JsonSchemaResponseFormat):
raise ValueError(
f"Unsupported response format: {type(fmt)}. Only JsonSchemaResponseFormat is supported."
)
raise ValueError(f"Unsupported response format: {type(fmt)}. Only JsonSchemaResponseFormat is supported.")
fmt = fmt.json_schema
name = fmt["title"]
@@ -1217,9 +1171,7 @@ async def convert_chat_completion_request_to_openai_params(
}
if request.tools:
input_dict["tools"] = [
convert_tooldef_to_openai_tool(tool) for tool in request.tools
]
input_dict["tools"] = [convert_tooldef_to_openai_tool(tool) for tool in request.tools]
if request.tool_config.tool_choice:
input_dict["tool_choice"] = (
request.tool_config.tool_choice.value