mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-27 18:50:41 +00:00
chore: enable pyupgrade fixes (#1806)
# What does this PR do? The goal of this PR is code base modernization. Schema reflection code needed a minor adjustment to handle UnionTypes and collections.abc.AsyncIterator. (Both are preferred for latest Python releases.) Note to reviewers: almost all changes here are automatically generated by pyupgrade. Some additional unused imports were cleaned up. The only change worth of note can be found under `docs/openapi_generator` and `llama_stack/strong_typing/schema.py` where reflection code was updated to deal with "newer" types. Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
This commit is contained in:
parent
ffe3d0b2cd
commit
9e6561a1ec
319 changed files with 2843 additions and 3033 deletions
|
@ -4,20 +4,10 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import (
|
||||
Annotated,
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Dict,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Protocol,
|
||||
Union,
|
||||
runtime_checkable,
|
||||
)
|
||||
from typing import Annotated, Any, Literal, Protocol, runtime_checkable
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
@ -79,8 +69,8 @@ class StepCommon(BaseModel):
|
|||
|
||||
turn_id: str
|
||||
step_id: str
|
||||
started_at: Optional[datetime] = None
|
||||
completed_at: Optional[datetime] = None
|
||||
started_at: datetime | None = None
|
||||
completed_at: datetime | None = None
|
||||
|
||||
|
||||
class StepType(Enum):
|
||||
|
@ -120,8 +110,8 @@ class ToolExecutionStep(StepCommon):
|
|||
"""
|
||||
|
||||
step_type: Literal[StepType.tool_execution.value] = StepType.tool_execution.value
|
||||
tool_calls: List[ToolCall]
|
||||
tool_responses: List[ToolResponse]
|
||||
tool_calls: list[ToolCall]
|
||||
tool_responses: list[ToolResponse]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -132,7 +122,7 @@ class ShieldCallStep(StepCommon):
|
|||
"""
|
||||
|
||||
step_type: Literal[StepType.shield_call.value] = StepType.shield_call.value
|
||||
violation: Optional[SafetyViolation]
|
||||
violation: SafetyViolation | None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -150,12 +140,7 @@ class MemoryRetrievalStep(StepCommon):
|
|||
|
||||
|
||||
Step = Annotated[
|
||||
Union[
|
||||
InferenceStep,
|
||||
ToolExecutionStep,
|
||||
ShieldCallStep,
|
||||
MemoryRetrievalStep,
|
||||
],
|
||||
InferenceStep | ToolExecutionStep | ShieldCallStep | MemoryRetrievalStep,
|
||||
Field(discriminator="step_type"),
|
||||
]
|
||||
|
||||
|
@ -166,18 +151,13 @@ class Turn(BaseModel):
|
|||
|
||||
turn_id: str
|
||||
session_id: str
|
||||
input_messages: List[
|
||||
Union[
|
||||
UserMessage,
|
||||
ToolResponseMessage,
|
||||
]
|
||||
]
|
||||
steps: List[Step]
|
||||
input_messages: list[UserMessage | ToolResponseMessage]
|
||||
steps: list[Step]
|
||||
output_message: CompletionMessage
|
||||
output_attachments: Optional[List[Attachment]] = Field(default_factory=list)
|
||||
output_attachments: list[Attachment] | None = Field(default_factory=list)
|
||||
|
||||
started_at: datetime
|
||||
completed_at: Optional[datetime] = None
|
||||
completed_at: datetime | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -186,34 +166,31 @@ class Session(BaseModel):
|
|||
|
||||
session_id: str
|
||||
session_name: str
|
||||
turns: List[Turn]
|
||||
turns: list[Turn]
|
||||
started_at: datetime
|
||||
|
||||
|
||||
class AgentToolGroupWithArgs(BaseModel):
|
||||
name: str
|
||||
args: Dict[str, Any]
|
||||
args: dict[str, Any]
|
||||
|
||||
|
||||
AgentToolGroup = Union[
|
||||
str,
|
||||
AgentToolGroupWithArgs,
|
||||
]
|
||||
AgentToolGroup = str | AgentToolGroupWithArgs
|
||||
register_schema(AgentToolGroup, name="AgentTool")
|
||||
|
||||
|
||||
class AgentConfigCommon(BaseModel):
|
||||
sampling_params: Optional[SamplingParams] = Field(default_factory=SamplingParams)
|
||||
sampling_params: SamplingParams | None = Field(default_factory=SamplingParams)
|
||||
|
||||
input_shields: Optional[List[str]] = Field(default_factory=list)
|
||||
output_shields: Optional[List[str]] = Field(default_factory=list)
|
||||
toolgroups: Optional[List[AgentToolGroup]] = Field(default_factory=list)
|
||||
client_tools: Optional[List[ToolDef]] = Field(default_factory=list)
|
||||
tool_choice: Optional[ToolChoice] = Field(default=None, deprecated="use tool_config instead")
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None, deprecated="use tool_config instead")
|
||||
tool_config: Optional[ToolConfig] = Field(default=None)
|
||||
input_shields: list[str] | None = Field(default_factory=list)
|
||||
output_shields: list[str] | None = Field(default_factory=list)
|
||||
toolgroups: list[AgentToolGroup] | None = Field(default_factory=list)
|
||||
client_tools: list[ToolDef] | None = Field(default_factory=list)
|
||||
tool_choice: ToolChoice | None = Field(default=None, deprecated="use tool_config instead")
|
||||
tool_prompt_format: ToolPromptFormat | None = Field(default=None, deprecated="use tool_config instead")
|
||||
tool_config: ToolConfig | None = Field(default=None)
|
||||
|
||||
max_infer_iters: Optional[int] = 10
|
||||
max_infer_iters: int | None = 10
|
||||
|
||||
def model_post_init(self, __context):
|
||||
if self.tool_config:
|
||||
|
@ -243,9 +220,9 @@ class AgentConfig(AgentConfigCommon):
|
|||
|
||||
model: str
|
||||
instructions: str
|
||||
name: Optional[str] = None
|
||||
enable_session_persistence: Optional[bool] = False
|
||||
response_format: Optional[ResponseFormat] = None
|
||||
name: str | None = None
|
||||
enable_session_persistence: bool | None = False
|
||||
response_format: ResponseFormat | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -257,16 +234,16 @@ class Agent(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class ListAgentsResponse(BaseModel):
|
||||
data: List[Agent]
|
||||
data: list[Agent]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ListAgentSessionsResponse(BaseModel):
|
||||
data: List[Session]
|
||||
data: list[Session]
|
||||
|
||||
|
||||
class AgentConfigOverridablePerTurn(AgentConfigCommon):
|
||||
instructions: Optional[str] = None
|
||||
instructions: str | None = None
|
||||
|
||||
|
||||
class AgentTurnResponseEventType(Enum):
|
||||
|
@ -284,7 +261,7 @@ class AgentTurnResponseStepStartPayload(BaseModel):
|
|||
event_type: Literal[AgentTurnResponseEventType.step_start.value] = AgentTurnResponseEventType.step_start.value
|
||||
step_type: StepType
|
||||
step_id: str
|
||||
metadata: Optional[Dict[str, Any]] = Field(default_factory=dict)
|
||||
metadata: dict[str, Any] | None = Field(default_factory=dict)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -327,14 +304,12 @@ class AgentTurnResponseTurnAwaitingInputPayload(BaseModel):
|
|||
|
||||
|
||||
AgentTurnResponseEventPayload = Annotated[
|
||||
Union[
|
||||
AgentTurnResponseStepStartPayload,
|
||||
AgentTurnResponseStepProgressPayload,
|
||||
AgentTurnResponseStepCompletePayload,
|
||||
AgentTurnResponseTurnStartPayload,
|
||||
AgentTurnResponseTurnCompletePayload,
|
||||
AgentTurnResponseTurnAwaitingInputPayload,
|
||||
],
|
||||
AgentTurnResponseStepStartPayload
|
||||
| AgentTurnResponseStepProgressPayload
|
||||
| AgentTurnResponseStepCompletePayload
|
||||
| AgentTurnResponseTurnStartPayload
|
||||
| AgentTurnResponseTurnCompletePayload
|
||||
| AgentTurnResponseTurnAwaitingInputPayload,
|
||||
Field(discriminator="event_type"),
|
||||
]
|
||||
register_schema(AgentTurnResponseEventPayload, name="AgentTurnResponseEventPayload")
|
||||
|
@ -363,18 +338,13 @@ class AgentTurnCreateRequest(AgentConfigOverridablePerTurn):
|
|||
# TODO: figure out how we can simplify this and make why
|
||||
# ToolResponseMessage needs to be here (it is function call
|
||||
# execution from outside the system)
|
||||
messages: List[
|
||||
Union[
|
||||
UserMessage,
|
||||
ToolResponseMessage,
|
||||
]
|
||||
]
|
||||
messages: list[UserMessage | ToolResponseMessage]
|
||||
|
||||
documents: Optional[List[Document]] = None
|
||||
toolgroups: Optional[List[AgentToolGroup]] = None
|
||||
documents: list[Document] | None = None
|
||||
toolgroups: list[AgentToolGroup] | None = None
|
||||
|
||||
stream: Optional[bool] = False
|
||||
tool_config: Optional[ToolConfig] = None
|
||||
stream: bool | None = False
|
||||
tool_config: ToolConfig | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -382,8 +352,8 @@ class AgentTurnResumeRequest(BaseModel):
|
|||
agent_id: str
|
||||
session_id: str
|
||||
turn_id: str
|
||||
tool_responses: List[ToolResponse]
|
||||
stream: Optional[bool] = False
|
||||
tool_responses: list[ToolResponse]
|
||||
stream: bool | None = False
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -429,17 +399,12 @@ class Agents(Protocol):
|
|||
self,
|
||||
agent_id: str,
|
||||
session_id: str,
|
||||
messages: List[
|
||||
Union[
|
||||
UserMessage,
|
||||
ToolResponseMessage,
|
||||
]
|
||||
],
|
||||
stream: Optional[bool] = False,
|
||||
documents: Optional[List[Document]] = None,
|
||||
toolgroups: Optional[List[AgentToolGroup]] = None,
|
||||
tool_config: Optional[ToolConfig] = None,
|
||||
) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]:
|
||||
messages: list[UserMessage | ToolResponseMessage],
|
||||
stream: bool | None = False,
|
||||
documents: list[Document] | None = None,
|
||||
toolgroups: list[AgentToolGroup] | None = None,
|
||||
tool_config: ToolConfig | None = None,
|
||||
) -> Turn | AsyncIterator[AgentTurnResponseStreamChunk]:
|
||||
"""Create a new turn for an agent.
|
||||
|
||||
:param agent_id: The ID of the agent to create the turn for.
|
||||
|
@ -463,9 +428,9 @@ class Agents(Protocol):
|
|||
agent_id: str,
|
||||
session_id: str,
|
||||
turn_id: str,
|
||||
tool_responses: List[ToolResponse],
|
||||
stream: Optional[bool] = False,
|
||||
) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]:
|
||||
tool_responses: list[ToolResponse],
|
||||
stream: bool | None = False,
|
||||
) -> Turn | AsyncIterator[AgentTurnResponseStreamChunk]:
|
||||
"""Resume an agent turn with executed tool call responses.
|
||||
|
||||
When a Turn has the status `awaiting_input` due to pending input from client side tool calls, this endpoint can be used to submit the outputs from the tool calls once they are ready.
|
||||
|
@ -538,7 +503,7 @@ class Agents(Protocol):
|
|||
self,
|
||||
session_id: str,
|
||||
agent_id: str,
|
||||
turn_ids: Optional[List[str]] = None,
|
||||
turn_ids: list[str] | None = None,
|
||||
) -> Session:
|
||||
"""Retrieve an agent session by its ID.
|
||||
|
||||
|
@ -623,14 +588,14 @@ class Agents(Protocol):
|
|||
@webmethod(route="/openai/v1/responses", method="POST")
|
||||
async def create_openai_response(
|
||||
self,
|
||||
input: Union[str, List[OpenAIResponseInputMessage]],
|
||||
input: str | list[OpenAIResponseInputMessage],
|
||||
model: str,
|
||||
previous_response_id: Optional[str] = None,
|
||||
store: Optional[bool] = True,
|
||||
stream: Optional[bool] = False,
|
||||
temperature: Optional[float] = None,
|
||||
tools: Optional[List[OpenAIResponseInputTool]] = None,
|
||||
) -> Union[OpenAIResponseObject, AsyncIterator[OpenAIResponseObjectStream]]:
|
||||
previous_response_id: str | None = None,
|
||||
store: bool | None = True,
|
||||
stream: bool | None = False,
|
||||
temperature: float | None = None,
|
||||
tools: list[OpenAIResponseInputTool] | None = None,
|
||||
) -> OpenAIResponseObject | AsyncIterator[OpenAIResponseObjectStream]:
|
||||
"""Create a new OpenAI response.
|
||||
|
||||
:param input: Input message(s) to create the response.
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue