chore: enable pyupgrade fixes (#1806)

# What does this PR do?

The goal of this PR is codebase modernization: applying pyupgrade fixes across the tree so the code uses current Python syntax.

The schema reflection code needed a minor adjustment to handle `types.UnionType`
(produced by the PEP 604 `X | Y` syntax) and `collections.abc.AsyncIterator`,
both of which are preferred in recent Python releases.
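
For context, the pattern is roughly the following. This is only a minimal sketch with hypothetical helper names, not the actual code in `llama_stack/strong_typing/schema.py`: annotations written as `X | Y` report `types.UnionType` rather than `typing.Union`, and `collections.abc.AsyncIterator` replaces the deprecated `typing.AsyncIterator` alias, so reflection code has to recognize both spellings.

```python
import collections.abc
import types
import typing


def unwrap_annotation(annotation: object) -> tuple[object, tuple[object, ...]]:
    """Return (origin, args), treating PEP 604 unions like typing.Union."""
    origin = typing.get_origin(annotation)
    # `int | None` reports types.UnionType; `Optional[int]` reports typing.Union.
    if origin is types.UnionType:
        origin = typing.Union
    return origin, typing.get_args(annotation)


def is_async_iterator(annotation: object) -> bool:
    """True for typing.AsyncIterator[...] and collections.abc.AsyncIterator[...]."""
    return typing.get_origin(annotation) is collections.abc.AsyncIterator


# Both spellings normalize to the same (origin, args) pair.
assert unwrap_annotation(int | None) == unwrap_annotation(typing.Optional[int])
assert is_async_iterator(typing.AsyncIterator[str])
assert is_async_iterator(collections.abc.AsyncIterator[str])
```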

Note to reviewers: almost all changes here were generated automatically by
pyupgrade; some additional unused imports were also cleaned up. The only
changes worth noting are under `docs/openapi_generator` and
`llama_stack/strong_typing/schema.py`, where the reflection code was updated
to deal with the "newer" types.
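
For reviewers unfamiliar with the tool, the snippet below is a hypothetical before/after pair (not taken from this diff) showing the kind of mechanical rewrite pyupgrade applies when targeting newer Python: typing-module generics become builtin generics (PEP 585) and `Optional`/`Union` become `X | None` / `X | Y` (PEP 604), which in turn leaves `typing` imports unused.

```python
# Before: typing-module generics and Optional/Union spellings.
from typing import Dict, List, Optional, Union

def summarize(rows: List[Dict[str, int]], limit: Optional[int] = None) -> Union[str, None]:
    ...

# After: builtin generics and `|` unions (the now-unused imports are cleaned up separately).
def summarize(rows: list[dict[str, int]], limit: int | None = None) -> str | None:
    ...
```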

Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
Ihar Hrachyshka 2025-05-01 17:23:50 -04:00 committed by GitHub
parent ffe3d0b2cd
commit 9e6561a1ec
319 changed files with 2843 additions and 3033 deletions


@ -4,14 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from typing import Any
from .config import MetaReferenceInferenceConfig
async def get_provider_impl(
config: MetaReferenceInferenceConfig,
_deps: Dict[str, Any],
_deps: dict[str, Any],
):
from .inference import MetaReferenceInferenceImpl


@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, Optional
from typing import Any
from pydantic import BaseModel, field_validator
@ -17,11 +17,11 @@ class MetaReferenceInferenceConfig(BaseModel):
# the actual inference model id is determined by the model id in the request
# Note: you need to register the model before using it for inference
# models in the resource list in the run.yaml config will be registered automatically
model: Optional[str] = None
torch_seed: Optional[int] = None
model: str | None = None
torch_seed: int | None = None
max_seq_len: int = 4096
max_batch_size: int = 1
model_parallel_size: Optional[int] = None
model_parallel_size: int | None = None
# when this is False, we assume that the distributed process group is setup by someone
# outside of this code (e.g., when run inside `torchrun`). that is useful for clients
@ -30,9 +30,9 @@ class MetaReferenceInferenceConfig(BaseModel):
# By default, the implementation will look at ~/.llama/checkpoints/<model> but you
# can override by specifying the directory explicitly
checkpoint_dir: Optional[str] = None
checkpoint_dir: str | None = None
quantization: Optional[QuantizationConfig] = None
quantization: QuantizationConfig | None = None
@field_validator("model")
@classmethod
@ -55,7 +55,7 @@ class MetaReferenceInferenceConfig(BaseModel):
max_batch_size: str = "${env.MAX_BATCH_SIZE:1}",
max_seq_len: str = "${env.MAX_SEQ_LEN:4096}",
**kwargs,
) -> Dict[str, Any]:
) -> dict[str, Any]:
return {
"model": model,
"checkpoint_dir": checkpoint_dir,


@ -5,7 +5,8 @@
# the root directory of this source tree.
import math
from typing import Generator, List, Optional, Tuple
from collections.abc import Generator
from typing import Optional
import torch
from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData
@ -39,7 +40,7 @@ Tokenizer = Llama4Tokenizer | Llama3Tokenizer
class LogitsProcessor:
def __init__(self, token_enforcer: TokenEnforcer):
self.token_enforcer = token_enforcer
self.mask: Optional[torch.Tensor] = None
self.mask: torch.Tensor | None = None
def __call__(self, tokens: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
token_sequence = tokens[0, :].tolist()
@ -58,7 +59,7 @@ class LogitsProcessor:
def get_logits_processor(
tokenizer: Tokenizer,
vocab_size: int,
response_format: Optional[ResponseFormat],
response_format: ResponseFormat | None,
) -> Optional["LogitsProcessor"]:
if response_format is None:
return None
@ -76,7 +77,7 @@ def get_logits_processor(
return LogitsProcessor(token_enforcer)
def _build_regular_tokens_list(tokenizer: Tokenizer, vocab_size: int) -> List[Tuple[int, str, bool]]:
def _build_regular_tokens_list(tokenizer: Tokenizer, vocab_size: int) -> list[tuple[int, str, bool]]:
token_0 = tokenizer.encode("0", bos=False, eos=False)[-1]
regular_tokens = []
@ -158,7 +159,7 @@ class LlamaGenerator:
def completion(
self,
request_batch: List[CompletionRequestWithRawContent],
request_batch: list[CompletionRequestWithRawContent],
) -> Generator:
first_request = request_batch[0]
sampling_params = first_request.sampling_params or SamplingParams()
@ -167,7 +168,7 @@ class LlamaGenerator:
max_gen_len = self.args.max_seq_len - 1
temperature, top_p = _infer_sampling_params(sampling_params)
for result in self.inner_generator.generate(
yield from self.inner_generator.generate(
llm_inputs=[self.formatter.encode_content(request.content) for request in request_batch],
max_gen_len=max_gen_len,
temperature=temperature,
@ -179,12 +180,11 @@ class LlamaGenerator:
self.args.vocab_size,
first_request.response_format,
),
):
yield result
)
def chat_completion(
self,
request_batch: List[ChatCompletionRequestWithRawContent],
request_batch: list[ChatCompletionRequestWithRawContent],
) -> Generator:
first_request = request_batch[0]
sampling_params = first_request.sampling_params or SamplingParams()
@ -193,7 +193,7 @@ class LlamaGenerator:
max_gen_len = self.args.max_seq_len - 1
temperature, top_p = _infer_sampling_params(sampling_params)
for result in self.inner_generator.generate(
yield from self.inner_generator.generate(
llm_inputs=[
self.formatter.encode_dialog_prompt(request.messages, _infer_tool_prompt_format(request))
for request in request_batch
@ -208,5 +208,4 @@ class LlamaGenerator:
self.args.vocab_size,
first_request.response_format,
),
):
yield result
)
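
(Aside for reviewers: the `for result in ...: yield result` loops above were collapsed to `yield from ...`. For plain iteration the two forms are equivalent; `yield from` simply delegates to the inner generator, as this hypothetical sketch illustrates.)

```python
def numbers_loop():
    # Original pattern: explicit loop that re-yields each item.
    for value in range(3):
        yield value


def numbers_delegate():
    # pyupgrade-style rewrite: delegate directly to the inner iterable.
    yield from range(3)


assert list(numbers_loop()) == list(numbers_delegate()) == [0, 1, 2]
```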


@ -6,7 +6,7 @@
import asyncio
import os
from typing import AsyncGenerator, List, Optional, Union
from collections.abc import AsyncGenerator
from pydantic import BaseModel
from termcolor import cprint
@ -184,11 +184,11 @@ class MetaReferenceInferenceImpl(
self,
model_id: str,
content: InterleavedContent,
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
) -> CompletionResponse | CompletionResponseStreamChunk:
if sampling_params is None:
sampling_params = SamplingParams()
if logprobs:
@ -215,11 +215,11 @@ class MetaReferenceInferenceImpl(
async def batch_completion(
self,
model_id: str,
content_batch: List[InterleavedContent],
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
content_batch: list[InterleavedContent],
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
) -> BatchCompletionResponse:
if sampling_params is None:
sampling_params = SamplingParams()
@ -291,14 +291,14 @@ class MetaReferenceInferenceImpl(
for x in impl():
yield x
async def _nonstream_completion(self, request_batch: List[CompletionRequest]) -> List[CompletionResponse]:
async def _nonstream_completion(self, request_batch: list[CompletionRequest]) -> list[CompletionResponse]:
tokenizer = self.generator.formatter.tokenizer
first_request = request_batch[0]
class ItemState(BaseModel):
tokens: List[int] = []
logprobs: List[TokenLogProbs] = []
tokens: list[int] = []
logprobs: list[TokenLogProbs] = []
stop_reason: StopReason | None = None
finished: bool = False
@ -349,15 +349,15 @@ class MetaReferenceInferenceImpl(
async def chat_completion(
self,
model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
messages: list[Message],
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
tools: list[ToolDefinition] | None = None,
tool_choice: ToolChoice | None = ToolChoice.auto,
tool_prompt_format: ToolPromptFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
tool_config: ToolConfig | None = None,
) -> AsyncGenerator:
if sampling_params is None:
sampling_params = SamplingParams()
@ -395,13 +395,13 @@ class MetaReferenceInferenceImpl(
async def batch_chat_completion(
self,
model_id: str,
messages_batch: List[List[Message]],
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
tools: Optional[List[ToolDefinition]] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
messages_batch: list[list[Message]],
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
tools: list[ToolDefinition] | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
tool_config: ToolConfig | None = None,
) -> BatchChatCompletionResponse:
if sampling_params is None:
sampling_params = SamplingParams()
@ -436,15 +436,15 @@ class MetaReferenceInferenceImpl(
return BatchChatCompletionResponse(batch=results)
async def _nonstream_chat_completion(
self, request_batch: List[ChatCompletionRequest]
) -> List[ChatCompletionResponse]:
self, request_batch: list[ChatCompletionRequest]
) -> list[ChatCompletionResponse]:
tokenizer = self.generator.formatter.tokenizer
first_request = request_batch[0]
class ItemState(BaseModel):
tokens: List[int] = []
logprobs: List[TokenLogProbs] = []
tokens: list[int] = []
logprobs: list[TokenLogProbs] = []
stop_reason: StopReason | None = None
finished: bool = False


@ -4,9 +4,10 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import Callable, Generator
from copy import deepcopy
from functools import partial
from typing import Any, Callable, Generator, List
from typing import Any
from llama_stack.models.llama.llama3.chat_format import ChatFormat as Llama3ChatFormat
from llama_stack.models.llama.llama4.chat_format import ChatFormat as Llama4ChatFormat
@ -82,7 +83,7 @@ class LlamaModelParallelGenerator:
def completion(
self,
request_batch: List[CompletionRequestWithRawContent],
request_batch: list[CompletionRequestWithRawContent],
) -> Generator:
req_obj = deepcopy(request_batch)
gen = self.group.run_inference(("completion", req_obj))
@ -90,7 +91,7 @@ class LlamaModelParallelGenerator:
def chat_completion(
self,
request_batch: List[ChatCompletionRequestWithRawContent],
request_batch: list[ChatCompletionRequestWithRawContent],
) -> Generator:
req_obj = deepcopy(request_batch)
gen = self.group.run_inference(("chat_completion", req_obj))


@ -18,8 +18,9 @@ import os
import tempfile
import time
import uuid
from collections.abc import Callable, Generator
from enum import Enum
from typing import Callable, Generator, List, Literal, Optional, Tuple, Union
from typing import Annotated, Literal
import torch
import zmq
@ -30,7 +31,6 @@ from fairscale.nn.model_parallel.initialize import (
)
from pydantic import BaseModel, Field
from torch.distributed.launcher.api import LaunchConfig, elastic_launch
from typing_extensions import Annotated
from llama_stack.models.llama.datatypes import GenerationResult
from llama_stack.providers.utils.inference.prompt_adapter import (
@ -69,15 +69,15 @@ class CancelSentinel(BaseModel):
class TaskRequest(BaseModel):
type: Literal[ProcessingMessageName.task_request] = ProcessingMessageName.task_request
task: Tuple[
task: tuple[
str,
List[CompletionRequestWithRawContent] | List[ChatCompletionRequestWithRawContent],
list[CompletionRequestWithRawContent] | list[ChatCompletionRequestWithRawContent],
]
class TaskResponse(BaseModel):
type: Literal[ProcessingMessageName.task_response] = ProcessingMessageName.task_response
result: List[GenerationResult]
result: list[GenerationResult]
class ExceptionResponse(BaseModel):
@ -85,15 +85,9 @@ class ExceptionResponse(BaseModel):
error: str
ProcessingMessage = Union[
ReadyRequest,
ReadyResponse,
EndSentinel,
CancelSentinel,
TaskRequest,
TaskResponse,
ExceptionResponse,
]
ProcessingMessage = (
ReadyRequest | ReadyResponse | EndSentinel | CancelSentinel | TaskRequest | TaskResponse | ExceptionResponse
)
class ProcessingMessageWrapper(BaseModel):
@ -203,7 +197,7 @@ def maybe_get_work(sock: zmq.Socket):
return client_id, message
def maybe_parse_message(maybe_json: Optional[str]) -> Optional[ProcessingMessage]:
def maybe_parse_message(maybe_json: str | None) -> ProcessingMessage | None:
if maybe_json is None:
return None
try:
@ -334,9 +328,9 @@ class ModelParallelProcessGroup:
def run_inference(
self,
req: Tuple[
req: tuple[
str,
List[CompletionRequestWithRawContent] | List[ChatCompletionRequestWithRawContent],
list[CompletionRequestWithRawContent] | list[ChatCompletionRequestWithRawContent],
],
) -> Generator:
assert not self.running, "inference already running"


@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from typing import Any
from llama_stack.providers.inline.inference.sentence_transformers.config import (
SentenceTransformersInferenceConfig,
@ -13,7 +13,7 @@ from llama_stack.providers.inline.inference.sentence_transformers.config import
async def get_provider_impl(
config: SentenceTransformersInferenceConfig,
_deps: Dict[str, Any],
_deps: dict[str, Any],
):
from .sentence_transformers import SentenceTransformersInferenceImpl


@ -4,12 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from typing import Any
from pydantic import BaseModel
class SentenceTransformersInferenceConfig(BaseModel):
@classmethod
def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
def sample_run_config(cls, **kwargs) -> dict[str, Any]:
return {}


@ -5,7 +5,7 @@
# the root directory of this source tree.
import logging
from typing import AsyncGenerator, List, Optional, Union
from collections.abc import AsyncGenerator
from llama_stack.apis.inference import (
CompletionResponse,
@ -60,46 +60,46 @@ class SentenceTransformersInferenceImpl(
self,
model_id: str,
content: str,
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> Union[CompletionResponse, AsyncGenerator]:
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
) -> CompletionResponse | AsyncGenerator:
raise ValueError("Sentence transformers don't support completion")
async def chat_completion(
self,
model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
messages: list[Message],
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
tools: list[ToolDefinition] | None = None,
tool_choice: ToolChoice | None = ToolChoice.auto,
tool_prompt_format: ToolPromptFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
tool_config: ToolConfig | None = None,
) -> AsyncGenerator:
raise ValueError("Sentence transformers don't support chat completion")
async def batch_completion(
self,
model_id: str,
content_batch: List[InterleavedContent],
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
logprobs: Optional[LogProbConfig] = None,
content_batch: list[InterleavedContent],
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
logprobs: LogProbConfig | None = None,
):
raise NotImplementedError("Batch completion is not supported for Sentence Transformers")
async def batch_chat_completion(
self,
model_id: str,
messages_batch: List[List[Message]],
sampling_params: Optional[SamplingParams] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_config: Optional[ToolConfig] = None,
response_format: Optional[ResponseFormat] = None,
logprobs: Optional[LogProbConfig] = None,
messages_batch: list[list[Message]],
sampling_params: SamplingParams | None = None,
tools: list[ToolDefinition] | None = None,
tool_config: ToolConfig | None = None,
response_format: ResponseFormat | None = None,
logprobs: LogProbConfig | None = None,
):
raise NotImplementedError("Batch chat completion is not supported for Sentence Transformers")


@ -4,12 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from typing import Any
from .config import VLLMConfig
async def get_provider_impl(config: VLLMConfig, _deps: Dict[str, Any]):
async def get_provider_impl(config: VLLMConfig, _deps: dict[str, Any]):
from .vllm import VLLMInferenceImpl
impl = VLLMInferenceImpl(config)


@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from typing import Any
from pydantic import BaseModel, Field
@ -42,7 +42,7 @@ class VLLMConfig(BaseModel):
)
@classmethod
def sample_run_config(cls, **kwargs: Any) -> Dict[str, Any]:
def sample_run_config(cls, **kwargs: Any) -> dict[str, Any]:
return {
"tensor_parallel_size": "${env.TENSOR_PARALLEL_SIZE:1}",
"max_tokens": "${env.MAX_TOKENS:4096}",


@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List, Optional
import vllm
@ -55,8 +54,8 @@ def _merge_context_into_content(message: Message) -> Message: # type: ignore
def _llama_stack_tools_to_openai_tools(
tools: Optional[List[ToolDefinition]] = None,
) -> List[vllm.entrypoints.openai.protocol.ChatCompletionToolsParam]:
tools: list[ToolDefinition] | None = None,
) -> list[vllm.entrypoints.openai.protocol.ChatCompletionToolsParam]:
"""
Convert the list of available tools from Llama Stack's format to vLLM's
version of OpenAI's format.


@ -7,7 +7,7 @@
import json
import re
import uuid
from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
from collections.abc import AsyncGenerator, AsyncIterator
# These vLLM modules contain names that overlap with Llama Stack names, so we import
# fully-qualified names
@ -100,7 +100,7 @@ def _random_uuid_str() -> str:
def _response_format_to_guided_decoding_params(
response_format: Optional[ResponseFormat], # type: ignore
response_format: ResponseFormat | None, # type: ignore
) -> vllm.sampling_params.GuidedDecodingParams:
"""
Translate constrained decoding parameters from Llama Stack's format to vLLM's format.
@ -131,9 +131,9 @@ def _response_format_to_guided_decoding_params(
def _convert_sampling_params(
sampling_params: Optional[SamplingParams],
response_format: Optional[ResponseFormat], # type: ignore
log_prob_config: Optional[LogProbConfig],
sampling_params: SamplingParams | None,
response_format: ResponseFormat | None, # type: ignore
log_prob_config: LogProbConfig | None,
) -> vllm.SamplingParams:
"""Convert sampling and constrained decoding configuration from Llama Stack's format to vLLM's
format."""
@ -370,11 +370,11 @@ class VLLMInferenceImpl(
self,
model_id: str,
content: InterleavedContent,
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]:
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
) -> CompletionResponse | AsyncIterator[CompletionResponseStreamChunk]:
if model_id not in self.model_ids:
raise ValueError(
f"This adapter is not registered to model id '{model_id}'. Registered IDs are: {self.model_ids}"
@ -403,25 +403,25 @@ class VLLMInferenceImpl(
async def embeddings(
self,
model_id: str,
contents: List[str] | List[InterleavedContentItem],
text_truncation: Optional[TextTruncation] = TextTruncation.none,
output_dimension: Optional[int] = None,
task_type: Optional[EmbeddingTaskType] = None,
contents: list[str] | list[InterleavedContentItem],
text_truncation: TextTruncation | None = TextTruncation.none,
output_dimension: int | None = None,
task_type: EmbeddingTaskType | None = None,
) -> EmbeddingsResponse:
raise NotImplementedError()
async def chat_completion(
self,
model_id: str,
messages: List[Message], # type: ignore
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None, # type: ignore
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
messages: list[Message], # type: ignore
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None, # type: ignore
tools: list[ToolDefinition] | None = None,
tool_choice: ToolChoice | None = ToolChoice.auto,
tool_prompt_format: ToolPromptFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
tool_config: ToolConfig | None = None,
) -> ChatCompletionResponse | ChatCompletionResponseStreamChunk:
sampling_params = sampling_params or SamplingParams()
if model_id not in self.model_ids:
@ -605,7 +605,7 @@ class VLLMInferenceImpl(
async def _chat_completion_for_meta_llama(
self, request: ChatCompletionRequest
) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
) -> ChatCompletionResponse | AsyncIterator[ChatCompletionResponseStreamChunk]:
"""
Subroutine that routes chat completions for Meta Llama models through Llama Stack's
chat template instead of using vLLM's version of that template. The Llama Stack version
@ -701,7 +701,7 @@ class VLLMInferenceImpl(
# Tool calls come in pieces, but Llama Stack expects them in bigger chunks. We build up
# those chunks and output them at the end.
# This data structure holds the current set of partial tool calls.
index_to_tool_call: Dict[int, Dict] = dict()
index_to_tool_call: dict[int, dict] = dict()
# The Llama Stack event stream must always start with a start event. Use an empty one to
# simplify logic below