chore: enable pyupgrade fixes (#1806)
# What does this PR do?

The goal of this PR is code base modernization. Schema reflection code needed a minor adjustment to handle `types.UnionType` and `collections.abc.AsyncIterator`. (Both are preferred for the latest Python releases.)

Note to reviewers: almost all changes here were generated automatically by pyupgrade; some additional unused imports were cleaned up as well. The only change worth noting is under `docs/openapi_generator` and `llama_stack/strong_typing/schema.py`, where reflection code was updated to deal with the "newer" types.

Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
parent ffe3d0b2cd
commit 9e6561a1ec

319 changed files with 2843 additions and 3033 deletions
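The `schema.py` and `docs/openapi_generator` changes are not part of the excerpt below, but the reflection adjustment the message refers to boils down to a runtime difference between the two union spellings. The following is a minimal sketch of that difference (Python 3.10+), not the project's actual code; the `is_union` helper is hypothetical:

```python
# Illustration only: why reflection code must treat PEP 604 unions
# (types.UnionType) the same as typing.Union. Hypothetical helper, not
# the code changed in this PR.
import types
import typing
from collections.abc import AsyncIterator
from typing import Union, get_args, get_origin


def is_union(hint) -> bool:
    # typing.Union[X, Y] reports typing.Union as its origin;
    # X | Y (PEP 604) reports types.UnionType instead.
    return get_origin(hint) in (Union, types.UnionType)


assert is_union(Union[int, str])
assert is_union(int | str)
assert get_args(int | str) == (int, str)

# AsyncIterator subscripted from either module normalizes to the same origin.
assert get_origin(AsyncIterator[str]) is AsyncIterator
assert get_origin(typing.AsyncIterator[str]) is AsyncIterator
```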
@@ -18,8 +18,9 @@ import os
 import tempfile
 import time
 import uuid
+from collections.abc import Callable, Generator
 from enum import Enum
-from typing import Callable, Generator, List, Literal, Optional, Tuple, Union
+from typing import Annotated, Literal
 
 import torch
 import zmq
@@ -30,7 +31,6 @@ from fairscale.nn.model_parallel.initialize import (
 )
 from pydantic import BaseModel, Field
 from torch.distributed.launcher.api import LaunchConfig, elastic_launch
-from typing_extensions import Annotated
 
 from llama_stack.models.llama.datatypes import GenerationResult
 from llama_stack.providers.utils.inference.prompt_adapter import (
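For context on the two import hunks above: since Python 3.9 (PEP 585) the `collections.abc` ABCs and the builtin containers are subscriptable directly, and `Annotated` lives in `typing`, so `typing_extensions` and the capitalized `List`/`Tuple`/`Callable` aliases are no longer needed. A small self-contained illustration, unrelated to the file being patched:

```python
# Minimal illustration (Python 3.9+): ABCs and builtins are generic on their
# own, and Annotated comes straight from typing.
from collections.abc import Callable, Generator
from typing import Annotated

Handler = Callable[[str], None]            # instead of typing.Callable[...]
Pair = tuple[str, list[int]]               # instead of Tuple[str, List[int]]
Positive = Annotated[int, "must be > 0"]   # instead of typing_extensions.Annotated


def count_up(limit: Positive) -> Generator[int, None, None]:
    # Yields 0..limit-1; the return type uses the subscripted ABC.
    yield from range(limit)


assert list(count_up(3)) == [0, 1, 2]
```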
@@ -69,15 +69,15 @@ class CancelSentinel(BaseModel):
 
 class TaskRequest(BaseModel):
     type: Literal[ProcessingMessageName.task_request] = ProcessingMessageName.task_request
-    task: Tuple[
+    task: tuple[
         str,
-        List[CompletionRequestWithRawContent] | List[ChatCompletionRequestWithRawContent],
+        list[CompletionRequestWithRawContent] | list[ChatCompletionRequestWithRawContent],
     ]
 
 
 class TaskResponse(BaseModel):
     type: Literal[ProcessingMessageName.task_response] = ProcessingMessageName.task_response
-    result: List[GenerationResult]
+    result: list[GenerationResult]
 
 
 class ExceptionResponse(BaseModel):
@@ -85,15 +85,9 @@ class ExceptionResponse(BaseModel):
     error: str
 
 
-ProcessingMessage = Union[
-    ReadyRequest,
-    ReadyResponse,
-    EndSentinel,
-    CancelSentinel,
-    TaskRequest,
-    TaskResponse,
-    ExceptionResponse,
-]
+ProcessingMessage = (
+    ReadyRequest | ReadyResponse | EndSentinel | CancelSentinel | TaskRequest | TaskResponse | ExceptionResponse
+)
 
 
 class ProcessingMessageWrapper(BaseModel):
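The `ProcessingMessage` rewrite above is behavior-preserving: a PEP 604 union built with `|` is a real runtime object, so isinstance checks and member introspection keep working as they did with `Union[...]`. A tiny sketch with made-up member classes (not the real message types):

```python
# Sketch only (Python 3.10+); Ready, Done, and Failed are illustrative stand-ins.
from typing import get_args


class Ready: ...


class Done: ...


class Failed: ...


Message = Ready | Done | Failed   # same runtime meaning as Union[Ready, Done, Failed]

assert isinstance(Done(), Message)                 # isinstance accepts X | Y unions
assert get_args(Message) == (Ready, Done, Failed)  # members remain introspectable
```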
@@ -203,7 +197,7 @@ def maybe_get_work(sock: zmq.Socket):
     return client_id, message
 
 
-def maybe_parse_message(maybe_json: Optional[str]) -> Optional[ProcessingMessage]:
+def maybe_parse_message(maybe_json: str | None) -> ProcessingMessage | None:
     if maybe_json is None:
         return None
     try:
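Likewise, `Optional[X]` and `X | None` are interchangeable annotations; pyupgrade only changes the spelling. A throwaway example (not from the patch):

```python
# Equivalent annotations; pyupgrade rewrites the first spelling into the second.
from typing import Optional


def parse_old(raw: Optional[str]) -> Optional[int]:
    return int(raw) if raw is not None else None


def parse_new(raw: str | None) -> int | None:
    return int(raw) if raw is not None else None


assert parse_old("3") == parse_new("3") == 3
assert parse_old(None) is None and parse_new(None) is None
```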
@@ -334,9 +328,9 @@ class ModelParallelProcessGroup:
 
     def run_inference(
         self,
-        req: Tuple[
+        req: tuple[
             str,
-            List[CompletionRequestWithRawContent] | List[ChatCompletionRequestWithRawContent],
+            list[CompletionRequestWithRawContent] | list[ChatCompletionRequestWithRawContent],
         ],
     ) -> Generator:
         assert not self.running, "inference already running"