Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-06-28 02:53:30 +00:00)
chore: enable pyupgrade fixes (#1806)
# What does this PR do?

The goal of this PR is code base modernization. Schema reflection code needed a minor adjustment to handle `types.UnionType` and `collections.abc.AsyncIterator`, both of which are the preferred forms in the latest Python releases.

Note to reviewers: almost all changes here were generated automatically by pyupgrade, and some additional unused imports were cleaned up. The only change worthy of note is under `docs/openapi_generator` and `llama_stack/strong_typing/schema.py`, where reflection code was updated to deal with the "newer" types.

Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
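The one change called out above is worth a closer look: with PEP 604 syntax, an annotation like `X | Y` evaluates to `types.UnionType` rather than `typing.Union`, and `collections.abc.AsyncIterator` replaces `typing.AsyncIterator`, so reflection code that dispatches on type origins must recognize both spellings. A minimal sketch of that kind of adjustment, with hypothetical helper names rather than the repository's actual `schema.py` code:

```python
# Hypothetical helpers (union_members, is_async_iterator) illustrating the
# reflection adjustment described above; not taken from llama_stack.
import collections.abc
import types
import typing


def union_members(hint):
    """Return the member types of a union annotation, whether written as
    typing.Union[X, Y] or the PEP 604 form X | Y."""
    if typing.get_origin(hint) is typing.Union or isinstance(hint, types.UnionType):
        return typing.get_args(hint)
    return None


def is_async_iterator(hint) -> bool:
    """typing.AsyncIterator[T] and collections.abc.AsyncIterator[T]
    normalize to the same origin under typing.get_origin()."""
    return typing.get_origin(hint) is collections.abc.AsyncIterator


assert union_members(int | None) == (int, type(None))
assert union_members(typing.Union[int, str]) == (int, str)
assert is_async_iterator(collections.abc.AsyncIterator[bytes])
```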
Commit 9e6561a1ec (parent ffe3d0b2cd): 319 changed files with 2843 additions and 3033 deletions.
@@ -9,7 +9,6 @@ import base64
 import io
 import json
 import re
-from typing import List, Optional, Tuple, Union

 import httpx
 from PIL import Image as PIL_Image
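This import disappears because pyupgrade (presumably run in its `--py310-plus` mode, given the `X | Y` unions elsewhere in the diff) rewrites the `typing` aliases to their modern equivalents: builtin generics from PEP 585 and union syntax from PEP 604. The two spellings compare equal at runtime, as this quick sketch shows:

```python
# PEP 585 / PEP 604 equivalences behind the mechanical rewrites in this diff
# (Python >= 3.10; the typing spellings still work, they are just redundant).
from typing import List, Optional, Tuple, Union

assert List[int] == list[int]          # PEP 585: builtin generics
assert Tuple[str, int] == tuple[str, int]
assert Union[int, str] == (int | str)  # PEP 604: union operator
assert Optional[str] == (str | None)   # Optional[X] is just Union[X, None]
```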
@@ -63,7 +62,7 @@ log = get_logger(name=__name__, category="inference")


 class ChatCompletionRequestWithRawContent(ChatCompletionRequest):
-    messages: List[RawMessage]
+    messages: list[RawMessage]


 class CompletionRequestWithRawContent(CompletionRequest):
@@ -93,8 +92,8 @@ def interleaved_content_as_str(content: InterleavedContent, sep: str = " ") -> s


 async def convert_request_to_raw(
-    request: Union[ChatCompletionRequest, CompletionRequest],
-) -> Union[ChatCompletionRequestWithRawContent, CompletionRequestWithRawContent]:
+    request: ChatCompletionRequest | CompletionRequest,
+) -> ChatCompletionRequestWithRawContent | CompletionRequestWithRawContent:
     if isinstance(request, ChatCompletionRequest):
         messages = []
         for m in request.messages:
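Note that these rewrites only touch annotations; the `isinstance` narrowing in the body is unaffected. On Python 3.10+ a PEP 604 union can even be passed to `isinstance` directly. A small sketch with stand-in classes (the real request types are richer models defined in llama_stack):

```python
# Stand-in classes to demonstrate runtime behavior only; they are not the
# actual ChatCompletionRequest / CompletionRequest from llama_stack.
class ChatCompletionRequest: ...


class CompletionRequest: ...


def describe(request: ChatCompletionRequest | CompletionRequest) -> str:
    # A PEP 604 union object is a valid second argument to isinstance (3.10+).
    assert isinstance(request, ChatCompletionRequest | CompletionRequest)
    if isinstance(request, ChatCompletionRequest):
        return "chat"
    return "completion"


assert describe(ChatCompletionRequest()) == "chat"
assert describe(CompletionRequest()) == "completion"
```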
@@ -170,18 +169,18 @@ def content_has_media(content: InterleavedContent):
     return _has_media_content(content)


-def messages_have_media(messages: List[Message]):
+def messages_have_media(messages: list[Message]):
     return any(content_has_media(m.content) for m in messages)


-def request_has_media(request: Union[ChatCompletionRequest, CompletionRequest]):
+def request_has_media(request: ChatCompletionRequest | CompletionRequest):
     if isinstance(request, ChatCompletionRequest):
         return messages_have_media(request.messages)
     else:
         return content_has_media(request.content)


-async def localize_image_content(media: ImageContentItem) -> Tuple[bytes, str]:
+async def localize_image_content(media: ImageContentItem) -> tuple[bytes, str]:
     image = media.image
     if image.url and image.url.uri.startswith("http"):
         async with httpx.AsyncClient() as client:
@@ -228,7 +227,7 @@ async def completion_request_to_prompt(request: CompletionRequest) -> str:

 async def completion_request_to_prompt_model_input_info(
     request: CompletionRequest,
-) -> Tuple[str, int]:
+) -> tuple[str, int]:
     content = augment_content_with_response_format_prompt(request.response_format, request.content)
     request.content = content
     request = await convert_request_to_raw(request)
@@ -265,7 +264,7 @@ async def chat_completion_request_to_prompt(request: ChatCompletionRequest, llam

 async def chat_completion_request_to_model_input_info(
     request: ChatCompletionRequest, llama_model: str
-) -> Tuple[str, int]:
+) -> tuple[str, int]:
     messages = chat_completion_request_to_messages(request, llama_model)
     request.messages = messages
     request = await convert_request_to_raw(request)
@@ -284,7 +283,7 @@ async def chat_completion_request_to_model_input_info(
 def chat_completion_request_to_messages(
     request: ChatCompletionRequest,
     llama_model: str,
-) -> List[Message]:
+) -> list[Message]:
     """Reads chat completion request and augments the messages to handle tools.
     For eg. for llama_3_1, add system message with the appropriate tools or
     add user messsage for custom tools, etc.
@@ -323,7 +322,7 @@ def chat_completion_request_to_messages(
     return messages


-def response_format_prompt(fmt: Optional[ResponseFormat]):
+def response_format_prompt(fmt: ResponseFormat | None):
     if not fmt:
         return None

@@ -337,7 +336,7 @@ def response_format_prompt(fmt: Optional[ResponseFormat]):

 def augment_messages_for_tools_llama_3_1(
     request: ChatCompletionRequest,
-) -> List[Message]:
+) -> list[Message]:
     existing_messages = request.messages
     existing_system_message = None
     if existing_messages[0].role == Role.system.value:
@@ -406,7 +405,7 @@ def augment_messages_for_tools_llama_3_1(
 def augment_messages_for_tools_llama(
     request: ChatCompletionRequest,
     custom_tool_prompt_generator,
-) -> List[Message]:
+) -> list[Message]:
     existing_messages = request.messages
     existing_system_message = None
     if existing_messages[0].role == Role.system.value:
@@ -457,7 +456,7 @@ def augment_messages_for_tools_llama(
     return messages


-def _get_tool_choice_prompt(tool_choice: ToolChoice | str, tools: List[ToolDefinition]) -> str:
+def _get_tool_choice_prompt(tool_choice: ToolChoice | str, tools: list[ToolDefinition]) -> str:
     if tool_choice == ToolChoice.auto:
         return ""
     elif tool_choice == ToolChoice.required: