chore: more mypy fixes (#2029)

# What does this PR do?

Mainly tried to cover the entire llama_stack/apis directory, we only
have one left. Some excludes were just noop.

Signed-off-by: Sébastien Han <seb@redhat.com>
This commit is contained in:
Sébastien Han 2025-05-06 18:52:31 +02:00 committed by GitHub
parent feb9eb8b0d
commit 1a529705da
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
27 changed files with 581 additions and 166 deletions

View file

@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import sys
from collections.abc import AsyncIterator
from datetime import datetime
from enum import Enum
@ -35,6 +36,14 @@ from .openai_responses import (
OpenAIResponseObjectStream,
)
# TODO: use enum.StrEnum when we drop support for python 3.10
if sys.version_info >= (3, 11):
from enum import StrEnum
else:
class StrEnum(str, Enum):
"""Backport of StrEnum for Python 3.10 and below."""
class Attachment(BaseModel):
"""An attachment to an agent turn.
@ -73,7 +82,7 @@ class StepCommon(BaseModel):
completed_at: datetime | None = None
class StepType(Enum):
class StepType(StrEnum):
"""Type of the step in an agent turn.
:cvar inference: The step is an inference step that calls an LLM.
@ -97,7 +106,7 @@ class InferenceStep(StepCommon):
model_config = ConfigDict(protected_namespaces=())
step_type: Literal[StepType.inference.value] = StepType.inference.value
step_type: Literal[StepType.inference] = StepType.inference
model_response: CompletionMessage
@ -109,7 +118,7 @@ class ToolExecutionStep(StepCommon):
:param tool_responses: The tool responses from the tool calls.
"""
step_type: Literal[StepType.tool_execution.value] = StepType.tool_execution.value
step_type: Literal[StepType.tool_execution] = StepType.tool_execution
tool_calls: list[ToolCall]
tool_responses: list[ToolResponse]
@ -121,7 +130,7 @@ class ShieldCallStep(StepCommon):
:param violation: The violation from the shield call.
"""
step_type: Literal[StepType.shield_call.value] = StepType.shield_call.value
step_type: Literal[StepType.shield_call] = StepType.shield_call
violation: SafetyViolation | None
@ -133,7 +142,7 @@ class MemoryRetrievalStep(StepCommon):
:param inserted_context: The context retrieved from the vector databases.
"""
step_type: Literal[StepType.memory_retrieval.value] = StepType.memory_retrieval.value
step_type: Literal[StepType.memory_retrieval] = StepType.memory_retrieval
# TODO: should this be List[str]?
vector_db_ids: str
inserted_context: InterleavedContent
@ -154,7 +163,7 @@ class Turn(BaseModel):
input_messages: list[UserMessage | ToolResponseMessage]
steps: list[Step]
output_message: CompletionMessage
output_attachments: list[Attachment] | None = Field(default_factory=list)
output_attachments: list[Attachment] | None = Field(default_factory=lambda: [])
started_at: datetime
completed_at: datetime | None = None
@ -182,10 +191,10 @@ register_schema(AgentToolGroup, name="AgentTool")
class AgentConfigCommon(BaseModel):
sampling_params: SamplingParams | None = Field(default_factory=SamplingParams)
input_shields: list[str] | None = Field(default_factory=list)
output_shields: list[str] | None = Field(default_factory=list)
toolgroups: list[AgentToolGroup] | None = Field(default_factory=list)
client_tools: list[ToolDef] | None = Field(default_factory=list)
input_shields: list[str] | None = Field(default_factory=lambda: [])
output_shields: list[str] | None = Field(default_factory=lambda: [])
toolgroups: list[AgentToolGroup] | None = Field(default_factory=lambda: [])
client_tools: list[ToolDef] | None = Field(default_factory=lambda: [])
tool_choice: ToolChoice | None = Field(default=None, deprecated="use tool_config instead")
tool_prompt_format: ToolPromptFormat | None = Field(default=None, deprecated="use tool_config instead")
tool_config: ToolConfig | None = Field(default=None)
@ -246,7 +255,7 @@ class AgentConfigOverridablePerTurn(AgentConfigCommon):
instructions: str | None = None
class AgentTurnResponseEventType(Enum):
class AgentTurnResponseEventType(StrEnum):
step_start = "step_start"
step_complete = "step_complete"
step_progress = "step_progress"
@ -258,15 +267,15 @@ class AgentTurnResponseEventType(Enum):
@json_schema_type
class AgentTurnResponseStepStartPayload(BaseModel):
event_type: Literal[AgentTurnResponseEventType.step_start.value] = AgentTurnResponseEventType.step_start.value
event_type: Literal[AgentTurnResponseEventType.step_start] = AgentTurnResponseEventType.step_start
step_type: StepType
step_id: str
metadata: dict[str, Any] | None = Field(default_factory=dict)
metadata: dict[str, Any] | None = Field(default_factory=lambda: {})
@json_schema_type
class AgentTurnResponseStepCompletePayload(BaseModel):
event_type: Literal[AgentTurnResponseEventType.step_complete.value] = AgentTurnResponseEventType.step_complete.value
event_type: Literal[AgentTurnResponseEventType.step_complete] = AgentTurnResponseEventType.step_complete
step_type: StepType
step_id: str
step_details: Step
@ -276,7 +285,7 @@ class AgentTurnResponseStepCompletePayload(BaseModel):
class AgentTurnResponseStepProgressPayload(BaseModel):
model_config = ConfigDict(protected_namespaces=())
event_type: Literal[AgentTurnResponseEventType.step_progress.value] = AgentTurnResponseEventType.step_progress.value
event_type: Literal[AgentTurnResponseEventType.step_progress] = AgentTurnResponseEventType.step_progress
step_type: StepType
step_id: str
@ -285,21 +294,19 @@ class AgentTurnResponseStepProgressPayload(BaseModel):
@json_schema_type
class AgentTurnResponseTurnStartPayload(BaseModel):
event_type: Literal[AgentTurnResponseEventType.turn_start.value] = AgentTurnResponseEventType.turn_start.value
event_type: Literal[AgentTurnResponseEventType.turn_start] = AgentTurnResponseEventType.turn_start
turn_id: str
@json_schema_type
class AgentTurnResponseTurnCompletePayload(BaseModel):
event_type: Literal[AgentTurnResponseEventType.turn_complete.value] = AgentTurnResponseEventType.turn_complete.value
event_type: Literal[AgentTurnResponseEventType.turn_complete] = AgentTurnResponseEventType.turn_complete
turn: Turn
@json_schema_type
class AgentTurnResponseTurnAwaitingInputPayload(BaseModel):
event_type: Literal[AgentTurnResponseEventType.turn_awaiting_input.value] = (
AgentTurnResponseEventType.turn_awaiting_input.value
)
event_type: Literal[AgentTurnResponseEventType.turn_awaiting_input] = AgentTurnResponseEventType.turn_awaiting_input
turn: Turn
@ -341,7 +348,7 @@ class AgentTurnCreateRequest(AgentConfigOverridablePerTurn):
messages: list[UserMessage | ToolResponseMessage]
documents: list[Document] | None = None
toolgroups: list[AgentToolGroup] | None = None
toolgroups: list[AgentToolGroup] | None = Field(default_factory=lambda: [])
stream: bool | None = False
tool_config: ToolConfig | None = None