chore: enable pyupgrade fixes (#1806)

# What does this PR do?

The goal of this PR is to modernize the code base by enabling pyupgrade fixes.

Schema reflection code needed a minor adjustment to handle `UnionType` and
`collections.abc.AsyncIterator`, both of which are preferred in recent Python
releases.
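
To illustrate what that means (a minimal, hypothetical sketch; the helper names below are illustrative and are not the actual code in `llama_stack/strong_typing/schema.py` or `docs/openapi_generator`): `X | Y` annotations produce `types.UnionType` rather than `typing.Union`, and parameterized async iterator hints resolve to `collections.abc.AsyncIterator` as their origin, so reflection has to recognize both spellings.

```python
# Illustrative only: recognizing PEP 604 unions and collections.abc.AsyncIterator
# alongside the older typing-module spellings.
import collections.abc
import types
import typing


def is_union(hint: object) -> bool:
    # `int | None` is a types.UnionType instance; typing.Optional/typing.Union
    # report typing.Union as their origin.
    return isinstance(hint, types.UnionType) or typing.get_origin(hint) is typing.Union


def is_async_iterator(hint: object) -> bool:
    # typing.AsyncIterator[...] and collections.abc.AsyncIterator[...] both
    # resolve to collections.abc.AsyncIterator as their origin.
    return typing.get_origin(hint) is collections.abc.AsyncIterator


assert is_union(int | None)
assert is_union(typing.Optional[int])
assert is_async_iterator(collections.abc.AsyncIterator[str])
```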

Note to reviewers: almost all changes here were generated automatically by
pyupgrade, plus some cleanup of additional unused imports. The only changes
worth noting are under `docs/openapi_generator` and in
`llama_stack/strong_typing/schema.py`, where the reflection code was updated
to deal with the "newer" types.
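
For reference, the two mechanical rewrites pyupgrade applies throughout the diff are PEP 585 builtin generics (`List[str]` -> `list[str]`) and PEP 604 unions (`Optional[str]` -> `str | None`, `Union[str, int]` -> `str | int`). The before/after below is a made-up example (`lookup` and `normalize` do not exist in the code base):

```python
# Before (old spellings removed by this commit):
#   from typing import Dict, List, Optional, Union
#
#   def lookup(aliases: Dict[str, List[str]], name: str) -> Optional[str]: ...
#   def normalize(value: Union[str, List[str]]) -> List[str]: ...

# After (no typing imports needed for these annotations):
def lookup(aliases: dict[str, list[str]], name: str) -> str | None:
    """Return the first alias registered for a name, if any."""
    matches = aliases.get(name, [])
    return matches[0] if matches else None


def normalize(value: str | list[str]) -> list[str]:
    """Wrap a bare string so callers can always iterate over a list."""
    return [value] if isinstance(value, str) else value


assert lookup({"model": ["alias-a"]}, "model") == "alias-a"
assert normalize("a") == ["a"]
```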

Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
Ihar Hrachyshka committed on 2025-05-01 17:23:50 -04:00 (commit 9e6561a1ec, parent ffe3d0b2cd)
319 changed files with 2843 additions and 3033 deletions


```diff
@@ -4,8 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from typing import List
 from llama_stack.models.llama.sku_list import all_registered_models
 from llama_stack.models.llama.sku_types import *  # noqa: F403
@@ -22,7 +20,7 @@ def is_supported_safety_model(model: Model) -> bool:
     ]
-def supported_inference_models() -> List[Model]:
+def supported_inference_models() -> list[Model]:
     return [
         m
         for m in all_registered_models()
```

```diff
@@ -5,7 +5,7 @@
 # the root directory of this source tree.
 import logging
-from typing import TYPE_CHECKING, List, Optional
+from typing import TYPE_CHECKING
 if TYPE_CHECKING:
     from sentence_transformers import SentenceTransformer
@@ -31,10 +31,10 @@ class SentenceTransformerEmbeddingMixin:
     async def embeddings(
         self,
         model_id: str,
-        contents: List[str] | List[InterleavedContentItem],
-        text_truncation: Optional[TextTruncation] = TextTruncation.none,
-        output_dimension: Optional[int] = None,
-        task_type: Optional[EmbeddingTaskType] = None,
+        contents: list[str] | list[InterleavedContentItem],
+        text_truncation: TextTruncation | None = TextTruncation.none,
+        output_dimension: int | None = None,
+        task_type: EmbeddingTaskType | None = None,
     ) -> EmbeddingsResponse:
         model = await self.model_store.get_model(model_id)
         embedding_model = self._load_sentence_transformer_model(model.provider_resource_id)
```


```diff
@@ -4,7 +4,8 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
+from collections.abc import AsyncGenerator, AsyncIterator
+from typing import Any
 import litellm
@@ -64,7 +65,7 @@ class LiteLLMOpenAIMixin(
     def __init__(
         self,
         model_entries,
-        api_key_from_config: Optional[str],
+        api_key_from_config: str | None,
         provider_data_api_key_field: str,
         openai_compat_api_base: str | None = None,
     ):
@@ -97,26 +98,26 @@ class LiteLLMOpenAIMixin(
         self,
         model_id: str,
         content: InterleavedContent,
-        sampling_params: Optional[SamplingParams] = None,
-        response_format: Optional[ResponseFormat] = None,
-        stream: Optional[bool] = False,
-        logprobs: Optional[LogProbConfig] = None,
+        sampling_params: SamplingParams | None = None,
+        response_format: ResponseFormat | None = None,
+        stream: bool | None = False,
+        logprobs: LogProbConfig | None = None,
     ) -> AsyncGenerator:
         raise NotImplementedError("LiteLLM does not support completion requests")
     async def chat_completion(
         self,
         model_id: str,
-        messages: List[Message],
-        sampling_params: Optional[SamplingParams] = None,
-        tools: Optional[List[ToolDefinition]] = None,
-        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
-        tool_prompt_format: Optional[ToolPromptFormat] = None,
-        response_format: Optional[ResponseFormat] = None,
-        stream: Optional[bool] = False,
-        logprobs: Optional[LogProbConfig] = None,
-        tool_config: Optional[ToolConfig] = None,
-    ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
+        messages: list[Message],
+        sampling_params: SamplingParams | None = None,
+        tools: list[ToolDefinition] | None = None,
+        tool_choice: ToolChoice | None = ToolChoice.auto,
+        tool_prompt_format: ToolPromptFormat | None = None,
+        response_format: ResponseFormat | None = None,
+        stream: bool | None = False,
+        logprobs: LogProbConfig | None = None,
+        tool_config: ToolConfig | None = None,
+    ) -> ChatCompletionResponse | AsyncIterator[ChatCompletionResponseStreamChunk]:
         if sampling_params is None:
             sampling_params = SamplingParams()
@@ -243,10 +244,10 @@ class LiteLLMOpenAIMixin(
     async def embeddings(
         self,
         model_id: str,
-        contents: List[str] | List[InterleavedContentItem],
-        text_truncation: Optional[TextTruncation] = TextTruncation.none,
-        output_dimension: Optional[int] = None,
-        task_type: Optional[EmbeddingTaskType] = None,
+        contents: list[str] | list[InterleavedContentItem],
+        text_truncation: TextTruncation | None = TextTruncation.none,
+        output_dimension: int | None = None,
+        task_type: EmbeddingTaskType | None = None,
     ) -> EmbeddingsResponse:
         model = await self.model_store.get_model(model_id)
@@ -261,24 +262,24 @@ class LiteLLMOpenAIMixin(
     async def openai_completion(
         self,
         model: str,
-        prompt: Union[str, List[str], List[int], List[List[int]]],
-        best_of: Optional[int] = None,
-        echo: Optional[bool] = None,
-        frequency_penalty: Optional[float] = None,
-        logit_bias: Optional[Dict[str, float]] = None,
-        logprobs: Optional[bool] = None,
-        max_tokens: Optional[int] = None,
-        n: Optional[int] = None,
-        presence_penalty: Optional[float] = None,
-        seed: Optional[int] = None,
-        stop: Optional[Union[str, List[str]]] = None,
-        stream: Optional[bool] = None,
-        stream_options: Optional[Dict[str, Any]] = None,
-        temperature: Optional[float] = None,
-        top_p: Optional[float] = None,
-        user: Optional[str] = None,
-        guided_choice: Optional[List[str]] = None,
-        prompt_logprobs: Optional[int] = None,
+        prompt: str | list[str] | list[int] | list[list[int]],
+        best_of: int | None = None,
+        echo: bool | None = None,
+        frequency_penalty: float | None = None,
+        logit_bias: dict[str, float] | None = None,
+        logprobs: bool | None = None,
+        max_tokens: int | None = None,
+        n: int | None = None,
+        presence_penalty: float | None = None,
+        seed: int | None = None,
+        stop: str | list[str] | None = None,
+        stream: bool | None = None,
+        stream_options: dict[str, Any] | None = None,
+        temperature: float | None = None,
+        top_p: float | None = None,
+        user: str | None = None,
+        guided_choice: list[str] | None = None,
+        prompt_logprobs: int | None = None,
     ) -> OpenAICompletion:
         model_obj = await self.model_store.get_model(model)
         params = await prepare_openai_completion_params(
@@ -309,29 +310,29 @@ class LiteLLMOpenAIMixin(
     async def openai_chat_completion(
         self,
         model: str,
-        messages: List[OpenAIMessageParam],
-        frequency_penalty: Optional[float] = None,
-        function_call: Optional[Union[str, Dict[str, Any]]] = None,
-        functions: Optional[List[Dict[str, Any]]] = None,
-        logit_bias: Optional[Dict[str, float]] = None,
-        logprobs: Optional[bool] = None,
-        max_completion_tokens: Optional[int] = None,
-        max_tokens: Optional[int] = None,
-        n: Optional[int] = None,
-        parallel_tool_calls: Optional[bool] = None,
-        presence_penalty: Optional[float] = None,
-        response_format: Optional[OpenAIResponseFormatParam] = None,
-        seed: Optional[int] = None,
-        stop: Optional[Union[str, List[str]]] = None,
-        stream: Optional[bool] = None,
-        stream_options: Optional[Dict[str, Any]] = None,
-        temperature: Optional[float] = None,
-        tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
-        tools: Optional[List[Dict[str, Any]]] = None,
-        top_logprobs: Optional[int] = None,
-        top_p: Optional[float] = None,
-        user: Optional[str] = None,
-    ) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
+        messages: list[OpenAIMessageParam],
+        frequency_penalty: float | None = None,
+        function_call: str | dict[str, Any] | None = None,
+        functions: list[dict[str, Any]] | None = None,
+        logit_bias: dict[str, float] | None = None,
+        logprobs: bool | None = None,
+        max_completion_tokens: int | None = None,
+        max_tokens: int | None = None,
+        n: int | None = None,
+        parallel_tool_calls: bool | None = None,
+        presence_penalty: float | None = None,
+        response_format: OpenAIResponseFormatParam | None = None,
+        seed: int | None = None,
+        stop: str | list[str] | None = None,
+        stream: bool | None = None,
+        stream_options: dict[str, Any] | None = None,
+        temperature: float | None = None,
+        tool_choice: str | dict[str, Any] | None = None,
+        tools: list[dict[str, Any]] | None = None,
+        top_logprobs: int | None = None,
+        top_p: float | None = None,
+        user: str | None = None,
+    ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
         model_obj = await self.model_store.get_model(model)
         params = await prepare_openai_completion_params(
             model=self.get_litellm_model_name(model_obj.provider_resource_id),
@@ -365,21 +366,21 @@ class LiteLLMOpenAIMixin(
     async def batch_completion(
         self,
         model_id: str,
-        content_batch: List[InterleavedContent],
-        sampling_params: Optional[SamplingParams] = None,
-        response_format: Optional[ResponseFormat] = None,
-        logprobs: Optional[LogProbConfig] = None,
+        content_batch: list[InterleavedContent],
+        sampling_params: SamplingParams | None = None,
+        response_format: ResponseFormat | None = None,
+        logprobs: LogProbConfig | None = None,
     ):
         raise NotImplementedError("Batch completion is not supported for OpenAI Compat")
     async def batch_chat_completion(
         self,
         model_id: str,
-        messages_batch: List[List[Message]],
-        sampling_params: Optional[SamplingParams] = None,
-        tools: Optional[List[ToolDefinition]] = None,
-        tool_config: Optional[ToolConfig] = None,
-        response_format: Optional[ResponseFormat] = None,
-        logprobs: Optional[LogProbConfig] = None,
+        messages_batch: list[list[Message]],
+        sampling_params: SamplingParams | None = None,
+        tools: list[ToolDefinition] | None = None,
+        tool_config: ToolConfig | None = None,
+        response_format: ResponseFormat | None = None,
+        logprobs: LogProbConfig | None = None,
     ):
         raise NotImplementedError("Batch chat completion is not supported for OpenAI Compat")
```


```diff
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from typing import Any, Dict, List, Optional
+from typing import Any
 from pydantic import BaseModel, Field
@@ -20,13 +20,13 @@ from llama_stack.providers.utils.inference import (
 # more closer to the Model class.
 class ProviderModelEntry(BaseModel):
     provider_model_id: str
-    aliases: List[str] = Field(default_factory=list)
-    llama_model: Optional[str] = None
+    aliases: list[str] = Field(default_factory=list)
+    llama_model: str | None = None
     model_type: ModelType = ModelType.llm
-    metadata: Dict[str, Any] = Field(default_factory=dict)
+    metadata: dict[str, Any] = Field(default_factory=dict)
-def get_huggingface_repo(model_descriptor: str) -> Optional[str]:
+def get_huggingface_repo(model_descriptor: str) -> str | None:
     for model in all_registered_models():
         if model.descriptor() == model_descriptor:
             return model.huggingface_repo
@@ -34,7 +34,7 @@ def get_huggingface_repo(model_descriptor: str) -> Optional[str]:
 def build_hf_repo_model_entry(
-    provider_model_id: str, model_descriptor: str, additional_aliases: Optional[List[str]] = None
+    provider_model_id: str, model_descriptor: str, additional_aliases: list[str] | None = None
 ) -> ProviderModelEntry:
     aliases = [
         get_huggingface_repo(model_descriptor),
@@ -58,7 +58,7 @@ def build_model_entry(provider_model_id: str, model_descriptor: str) -> Provider
 class ModelRegistryHelper(ModelsProtocolPrivate):
-    def __init__(self, model_entries: List[ProviderModelEntry]):
+    def __init__(self, model_entries: list[ProviderModelEntry]):
         self.alias_to_provider_id_map = {}
         self.provider_id_to_llama_model_map = {}
         for entry in model_entries:
@@ -72,11 +72,11 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
                 self.alias_to_provider_id_map[entry.llama_model] = entry.provider_model_id
                 self.provider_id_to_llama_model_map[entry.provider_model_id] = entry.llama_model
-    def get_provider_model_id(self, identifier: str) -> Optional[str]:
+    def get_provider_model_id(self, identifier: str) -> str | None:
         return self.alias_to_provider_id_map.get(identifier, None)
     # TODO: why keep a separate llama model mapping?
-    def get_llama_model(self, provider_model_id: str) -> Optional[str]:
+    def get_llama_model(self, provider_model_id: str) -> str | None:
         return self.provider_id_to_llama_model_map.get(provider_model_id, None)
     async def register_model(self, model: Model) -> Model:
```


```diff
@@ -8,16 +8,9 @@ import logging
 import time
 import uuid
 import warnings
-from typing import (
-    Any,
-    AsyncGenerator,
-    AsyncIterator,
-    Awaitable,
-    Dict,
-    Iterable,
-    List,
-    Optional,
-    Union,
-)
+from collections.abc import AsyncGenerator, AsyncIterator, Awaitable, Iterable
+from typing import Any
 from openai import AsyncStream
@@ -141,24 +134,24 @@ class OpenAICompatCompletionChoiceDelta(BaseModel):
 class OpenAICompatLogprobs(BaseModel):
-    text_offset: Optional[List[int]] = None
-    token_logprobs: Optional[List[float]] = None
-    tokens: Optional[List[str]] = None
-    top_logprobs: Optional[List[Dict[str, float]]] = None
+    text_offset: list[int] | None = None
+    token_logprobs: list[float] | None = None
+    tokens: list[str] | None = None
+    top_logprobs: list[dict[str, float]] | None = None
 class OpenAICompatCompletionChoice(BaseModel):
-    finish_reason: Optional[str] = None
-    text: Optional[str] = None
-    delta: Optional[OpenAICompatCompletionChoiceDelta] = None
-    logprobs: Optional[OpenAICompatLogprobs] = None
+    finish_reason: str | None = None
+    text: str | None = None
+    delta: OpenAICompatCompletionChoiceDelta | None = None
+    logprobs: OpenAICompatLogprobs | None = None
 class OpenAICompatCompletionResponse(BaseModel):
-    choices: List[OpenAICompatCompletionChoice]
+    choices: list[OpenAICompatCompletionChoice]
 def get_sampling_strategy_options(params: SamplingParams) -> dict:
@@ -217,8 +210,8 @@ def get_stop_reason(finish_reason: str) -> StopReason:
 def convert_openai_completion_logprobs(
-    logprobs: Optional[OpenAICompatLogprobs],
-) -> Optional[List[TokenLogProbs]]:
+    logprobs: OpenAICompatLogprobs | None,
+) -> list[TokenLogProbs] | None:
     if not logprobs:
         return None
     if hasattr(logprobs, "top_logprobs"):
@@ -235,7 +228,7 @@ def convert_openai_completion_logprobs(
     return None
-def convert_openai_completion_logprobs_stream(text: str, logprobs: Optional[Union[float, OpenAICompatLogprobs]]):
+def convert_openai_completion_logprobs_stream(text: str, logprobs: float | OpenAICompatLogprobs | None):
    if logprobs is None:
        return None
    if isinstance(logprobs, float):
@@ -562,7 +555,7 @@ class UnparseableToolCall(BaseModel):
 async def convert_message_to_openai_dict_new(
-    message: Message | Dict,
+    message: Message | dict,
 ) -> OpenAIChatCompletionMessage:
     """
     Convert a Message to an OpenAI API-compatible dictionary.
@@ -591,14 +584,10 @@ async def convert_message_to_openai_dict_new(
     # List[...] -> List[...]
     async def _convert_message_content(
         content: InterleavedContent,
-    ) -> Union[str, Iterable[OpenAIChatCompletionContentPartParam]]:
+    ) -> str | Iterable[OpenAIChatCompletionContentPartParam]:
         async def impl(
             content_: InterleavedContent,
-        ) -> Union[
-            str,
-            OpenAIChatCompletionContentPartParam,
-            List[OpenAIChatCompletionContentPartParam],
-        ]:
+        ) -> str | OpenAIChatCompletionContentPartParam | list[OpenAIChatCompletionContentPartParam]:
             # Llama Stack and OpenAI spec match for str and text input
             if isinstance(content_, str):
                 return content_
@@ -670,7 +659,7 @@ async def convert_message_to_openai_dict_new(
 def convert_tool_call(
     tool_call: ChatCompletionMessageToolCall,
-) -> Union[ToolCall, UnparseableToolCall]:
+) -> ToolCall | UnparseableToolCall:
     """
     Convert a ChatCompletionMessageToolCall tool call to either a
     ToolCall or UnparseableToolCall. Returns an UnparseableToolCall
@@ -846,7 +835,7 @@ def _convert_openai_finish_reason(finish_reason: str) -> StopReason:
     }.get(finish_reason, StopReason.end_of_turn)
-def _convert_openai_request_tool_config(tool_choice: Optional[Union[str, Dict[str, Any]]] = None) -> ToolConfig:
+def _convert_openai_request_tool_config(tool_choice: str | dict[str, Any] | None = None) -> ToolConfig:
     tool_config = ToolConfig()
     if tool_choice:
         try:
@@ -857,7 +846,7 @@ def _convert_openai_request_tool_config(tool_choice: Optional[Union[str, Dict[st
     return tool_config
-def _convert_openai_request_tools(tools: Optional[List[Dict[str, Any]]] = None) -> List[ToolDefinition]:
+def _convert_openai_request_tools(tools: list[dict[str, Any]] | None = None) -> list[ToolDefinition]:
     lls_tools = []
     if not tools:
         return lls_tools
@@ -903,8 +892,8 @@ def _convert_openai_request_response_format(
 def _convert_openai_tool_calls(
-    tool_calls: List[OpenAIChatCompletionMessageToolCall],
-) -> List[ToolCall]:
+    tool_calls: list[OpenAIChatCompletionMessageToolCall],
+) -> list[ToolCall]:
     """
     Convert an OpenAI ChatCompletionMessageToolCall list into a list of ToolCall.
@@ -940,7 +929,7 @@ def _convert_openai_tool_calls(
 def _convert_openai_logprobs(
     logprobs: OpenAIChoiceLogprobs,
-) -> Optional[List[TokenLogProbs]]:
+) -> list[TokenLogProbs] | None:
     """
     Convert an OpenAI ChoiceLogprobs into a list of TokenLogProbs.
@@ -973,9 +962,9 @@ def _convert_openai_logprobs(
 def _convert_openai_sampling_params(
-    max_tokens: Optional[int] = None,
-    temperature: Optional[float] = None,
-    top_p: Optional[float] = None,
+    max_tokens: int | None = None,
+    temperature: float | None = None,
+    top_p: float | None = None,
 ) -> SamplingParams:
     sampling_params = SamplingParams()
@@ -998,8 +987,8 @@ def _convert_openai_sampling_params(
 def openai_messages_to_messages(
-    messages: List[OpenAIChatCompletionMessage],
-) -> List[Message]:
+    messages: list[OpenAIChatCompletionMessage],
+) -> list[Message]:
     """
     Convert a list of OpenAIChatCompletionMessage into a list of Message.
     """
@@ -1027,7 +1016,7 @@ def openai_messages_to_messages(
     return converted_messages
-def openai_content_to_content(content: Union[str, Iterable[OpenAIChatCompletionContentPartParam]]):
+def openai_content_to_content(content: str | Iterable[OpenAIChatCompletionContentPartParam]):
     if isinstance(content, str):
         return content
     elif isinstance(content, list):
@@ -1273,24 +1262,24 @@ class OpenAICompletionToLlamaStackMixin:
    async def openai_completion(
         self,
         model: str,
-        prompt: Union[str, List[str], List[int], List[List[int]]],
-        best_of: Optional[int] = None,
-        echo: Optional[bool] = None,
-        frequency_penalty: Optional[float] = None,
-        logit_bias: Optional[Dict[str, float]] = None,
-        logprobs: Optional[bool] = None,
-        max_tokens: Optional[int] = None,
-        n: Optional[int] = None,
-        presence_penalty: Optional[float] = None,
-        seed: Optional[int] = None,
-        stop: Optional[Union[str, List[str]]] = None,
-        stream: Optional[bool] = None,
-        stream_options: Optional[Dict[str, Any]] = None,
-        temperature: Optional[float] = None,
-        top_p: Optional[float] = None,
-        user: Optional[str] = None,
-        guided_choice: Optional[List[str]] = None,
-        prompt_logprobs: Optional[int] = None,
+        prompt: str | list[str] | list[int] | list[list[int]],
+        best_of: int | None = None,
+        echo: bool | None = None,
+        frequency_penalty: float | None = None,
+        logit_bias: dict[str, float] | None = None,
+        logprobs: bool | None = None,
+        max_tokens: int | None = None,
+        n: int | None = None,
+        presence_penalty: float | None = None,
+        seed: int | None = None,
+        stop: str | list[str] | None = None,
+        stream: bool | None = None,
+        stream_options: dict[str, Any] | None = None,
+        temperature: float | None = None,
+        top_p: float | None = None,
+        user: str | None = None,
+        guided_choice: list[str] | None = None,
+        prompt_logprobs: int | None = None,
     ) -> OpenAICompletion:
         if stream:
             raise ValueError(f"{self.__class__.__name__} doesn't support streaming openai completions")
@@ -1342,29 +1331,29 @@ class OpenAIChatCompletionToLlamaStackMixin:
     async def openai_chat_completion(
         self,
         model: str,
-        messages: List[OpenAIChatCompletionMessage],
-        frequency_penalty: Optional[float] = None,
-        function_call: Optional[Union[str, Dict[str, Any]]] = None,
-        functions: Optional[List[Dict[str, Any]]] = None,
-        logit_bias: Optional[Dict[str, float]] = None,
-        logprobs: Optional[bool] = None,
-        max_completion_tokens: Optional[int] = None,
-        max_tokens: Optional[int] = None,
-        n: Optional[int] = None,
-        parallel_tool_calls: Optional[bool] = None,
-        presence_penalty: Optional[float] = None,
-        response_format: Optional[OpenAIResponseFormatParam] = None,
-        seed: Optional[int] = None,
-        stop: Optional[Union[str, List[str]]] = None,
-        stream: Optional[bool] = None,
-        stream_options: Optional[Dict[str, Any]] = None,
-        temperature: Optional[float] = None,
-        tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
-        tools: Optional[List[Dict[str, Any]]] = None,
-        top_logprobs: Optional[int] = None,
-        top_p: Optional[float] = None,
-        user: Optional[str] = None,
-    ) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
+        messages: list[OpenAIChatCompletionMessage],
+        frequency_penalty: float | None = None,
+        function_call: str | dict[str, Any] | None = None,
+        functions: list[dict[str, Any]] | None = None,
+        logit_bias: dict[str, float] | None = None,
+        logprobs: bool | None = None,
+        max_completion_tokens: int | None = None,
+        max_tokens: int | None = None,
+        n: int | None = None,
+        parallel_tool_calls: bool | None = None,
+        presence_penalty: float | None = None,
+        response_format: OpenAIResponseFormatParam | None = None,
+        seed: int | None = None,
+        stop: str | list[str] | None = None,
+        stream: bool | None = None,
+        stream_options: dict[str, Any] | None = None,
+        temperature: float | None = None,
+        tool_choice: str | dict[str, Any] | None = None,
+        tools: list[dict[str, Any]] | None = None,
+        top_logprobs: int | None = None,
+        top_p: float | None = None,
+        user: str | None = None,
+    ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
         messages = openai_messages_to_messages(messages)
         response_format = _convert_openai_request_response_format(response_format)
         sampling_params = _convert_openai_sampling_params(
@@ -1403,7 +1392,7 @@ class OpenAIChatCompletionToLlamaStackMixin:
     async def _process_stream_response(
         self,
         model: str,
-        outstanding_responses: List[Awaitable[AsyncIterator[ChatCompletionResponseStreamChunk]]],
+        outstanding_responses: list[Awaitable[AsyncIterator[ChatCompletionResponseStreamChunk]]],
     ):
         id = f"chatcmpl-{uuid.uuid4()}"
         for outstanding_response in outstanding_responses:
@@ -1466,7 +1455,7 @@ class OpenAIChatCompletionToLlamaStackMixin:
             i = i + 1
     async def _process_non_stream_response(
-        self, model: str, outstanding_responses: List[Awaitable[ChatCompletionResponse]]
+        self, model: str, outstanding_responses: list[Awaitable[ChatCompletionResponse]]
     ) -> OpenAIChatCompletion:
         choices = []
         for outstanding_response in outstanding_responses:
```


```diff
@@ -9,7 +9,6 @@ import base64
 import io
 import json
 import re
-from typing import List, Optional, Tuple, Union
 import httpx
 from PIL import Image as PIL_Image
@@ -63,7 +62,7 @@ log = get_logger(name=__name__, category="inference")
 class ChatCompletionRequestWithRawContent(ChatCompletionRequest):
-    messages: List[RawMessage]
+    messages: list[RawMessage]
 class CompletionRequestWithRawContent(CompletionRequest):
@@ -93,8 +92,8 @@ def interleaved_content_as_str(content: InterleavedContent, sep: str = " ") -> s
 async def convert_request_to_raw(
-    request: Union[ChatCompletionRequest, CompletionRequest],
-) -> Union[ChatCompletionRequestWithRawContent, CompletionRequestWithRawContent]:
+    request: ChatCompletionRequest | CompletionRequest,
+) -> ChatCompletionRequestWithRawContent | CompletionRequestWithRawContent:
     if isinstance(request, ChatCompletionRequest):
         messages = []
         for m in request.messages:
@@ -170,18 +169,18 @@ def content_has_media(content: InterleavedContent):
     return _has_media_content(content)
-def messages_have_media(messages: List[Message]):
+def messages_have_media(messages: list[Message]):
     return any(content_has_media(m.content) for m in messages)
-def request_has_media(request: Union[ChatCompletionRequest, CompletionRequest]):
+def request_has_media(request: ChatCompletionRequest | CompletionRequest):
     if isinstance(request, ChatCompletionRequest):
         return messages_have_media(request.messages)
     else:
         return content_has_media(request.content)
-async def localize_image_content(media: ImageContentItem) -> Tuple[bytes, str]:
+async def localize_image_content(media: ImageContentItem) -> tuple[bytes, str]:
     image = media.image
     if image.url and image.url.uri.startswith("http"):
         async with httpx.AsyncClient() as client:
@@ -228,7 +227,7 @@ async def completion_request_to_prompt(request: CompletionRequest) -> str:
 async def completion_request_to_prompt_model_input_info(
     request: CompletionRequest,
-) -> Tuple[str, int]:
+) -> tuple[str, int]:
     content = augment_content_with_response_format_prompt(request.response_format, request.content)
     request.content = content
     request = await convert_request_to_raw(request)
@@ -265,7 +264,7 @@ async def chat_completion_request_to_prompt(request: ChatCompletionRequest, llam
 async def chat_completion_request_to_model_input_info(
     request: ChatCompletionRequest, llama_model: str
-) -> Tuple[str, int]:
+) -> tuple[str, int]:
     messages = chat_completion_request_to_messages(request, llama_model)
     request.messages = messages
     request = await convert_request_to_raw(request)
@@ -284,7 +283,7 @@ async def chat_completion_request_to_model_input_info(
 def chat_completion_request_to_messages(
     request: ChatCompletionRequest,
     llama_model: str,
-) -> List[Message]:
+) -> list[Message]:
     """Reads chat completion request and augments the messages to handle tools.
     For eg. for llama_3_1, add system message with the appropriate tools or
     add user messsage for custom tools, etc.
@@ -323,7 +322,7 @@ def chat_completion_request_to_messages(
     return messages
-def response_format_prompt(fmt: Optional[ResponseFormat]):
+def response_format_prompt(fmt: ResponseFormat | None):
     if not fmt:
         return None
@@ -337,7 +336,7 @@ def response_format_prompt(fmt: Optional[ResponseFormat]):
 def augment_messages_for_tools_llama_3_1(
     request: ChatCompletionRequest,
-) -> List[Message]:
+) -> list[Message]:
     existing_messages = request.messages
     existing_system_message = None
     if existing_messages[0].role == Role.system.value:
@@ -406,7 +405,7 @@ def augment_messages_for_tools_llama_3_1(
 def augment_messages_for_tools_llama(
     request: ChatCompletionRequest,
     custom_tool_prompt_generator,
-) -> List[Message]:
+) -> list[Message]:
     existing_messages = request.messages
     existing_system_message = None
     if existing_messages[0].role == Role.system.value:
@@ -457,7 +456,7 @@ def augment_messages_for_tools_llama(
     return messages
-def _get_tool_choice_prompt(tool_choice: ToolChoice | str, tools: List[ToolDefinition]) -> str:
+def _get_tool_choice_prompt(tool_choice: ToolChoice | str, tools: list[ToolDefinition]) -> str:
     if tool_choice == ToolChoice.auto:
         return ""
     elif tool_choice == ToolChoice.required:
```