fix(mypy): resolve agent_instance type issues (part 1 of 2)

Fixed 35 errors, 46 remaining:
- Add isinstance() checks for union type discrimination
- Fix list type annotations for Message types
- Convert strings to datetime/StepType where needed
- Use assert to narrow AgentTurnCreateRequest vs AgentTurnResumeRequest
- Add explicit type annotations to avoid inference issues

Still to fix:
- Remaining str to datetime/StepType conversions
- Optional list handling for shields
- Type annotations for tool maps
- List variance issues for input_messages
- Fix turn_id variable redefinition
This commit is contained in:
Ashwin Bharambe 2025-10-28 11:34:30 -07:00
parent 3a437d80af
commit 3cf36e665b

View file

@ -11,6 +11,7 @@ import uuid
import warnings
from collections.abc import AsyncGenerator
from datetime import UTC, datetime
from typing import Any
import httpx
@ -125,12 +126,12 @@ class ChatAgent(ShieldRunnerMixin):
)
def turn_to_messages(self, turn: Turn) -> list[Message]:
messages = []
messages: list[Message] = []
# NOTE: if a toolcall response is in a step, we do not add it when processing the input messages
tool_call_ids = set()
for step in turn.steps:
if step.step_type == StepType.tool_execution.value:
if step.step_type == StepType.tool_execution.value and isinstance(step, ToolExecutionStep):
for response in step.tool_responses:
tool_call_ids.add(response.call_id)
@ -149,9 +150,9 @@ class ChatAgent(ShieldRunnerMixin):
messages.append(msg)
for step in turn.steps:
if step.step_type == StepType.inference.value:
if step.step_type == StepType.inference.value and isinstance(step, InferenceStep):
messages.append(step.model_response)
elif step.step_type == StepType.tool_execution.value:
elif step.step_type == StepType.tool_execution.value and isinstance(step, ToolExecutionStep):
for response in step.tool_responses:
messages.append(
ToolResponseMessage(
@ -159,7 +160,7 @@ class ChatAgent(ShieldRunnerMixin):
content=response.content,
)
)
elif step.step_type == StepType.shield_call.value:
elif step.step_type == StepType.shield_call.value and isinstance(step, ShieldCallStep):
if step.violation:
# CompletionMessage itself in the ShieldResponse
messages.append(
@ -174,7 +175,7 @@ class ChatAgent(ShieldRunnerMixin):
return await self.storage.create_session(name)
async def get_messages_from_turns(self, turns: list[Turn]) -> list[Message]:
messages = []
messages: list[Message] = []
if self.agent_config.instructions != "":
messages.append(SystemMessage(content=self.agent_config.instructions))
@ -231,7 +232,13 @@ class ChatAgent(ShieldRunnerMixin):
steps = []
messages = await self.get_messages_from_turns(turns)
turn_id: str
start_time: datetime
input_messages: list[Message]
if is_resume:
assert isinstance(request, AgentTurnResumeRequest), "Expected AgentTurnResumeRequest for resume"
tool_response_messages = [
ToolResponseMessage(call_id=x.call_id, content=x.content) for x in request.tool_responses
]
@ -252,20 +259,21 @@ class ChatAgent(ShieldRunnerMixin):
in_progress_tool_call_step = await self.storage.get_in_progress_tool_call_step(
request.session_id, request.turn_id
)
now = datetime.now(UTC).isoformat()
now_iso = datetime.now(UTC).isoformat()
now_dt = datetime.now(UTC)
tool_execution_step = ToolExecutionStep(
step_id=(in_progress_tool_call_step.step_id if in_progress_tool_call_step else str(uuid.uuid4())),
turn_id=request.turn_id,
tool_calls=(in_progress_tool_call_step.tool_calls if in_progress_tool_call_step else []),
tool_responses=request.tool_responses,
completed_at=now,
started_at=(in_progress_tool_call_step.started_at if in_progress_tool_call_step else now),
completed_at=now_dt,
started_at=(in_progress_tool_call_step.started_at if in_progress_tool_call_step else now_dt),
)
steps.append(tool_execution_step)
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload(
step_type=StepType.tool_execution.value,
step_type=StepType.tool_execution,
step_id=tool_execution_step.step_id,
step_details=tool_execution_step,
)
@ -276,18 +284,22 @@ class ChatAgent(ShieldRunnerMixin):
turn_id = request.turn_id
start_time = last_turn.started_at
else:
assert isinstance(request, AgentTurnCreateRequest), "Expected AgentTurnCreateRequest for create"
messages.extend(request.messages)
start_time = datetime.now(UTC).isoformat()
start_time = datetime.now(UTC)
input_messages = request.messages
output_message = None
req_documents = request.documents if isinstance(request, AgentTurnCreateRequest) and not is_resume else None
req_sampling = self.agent_config.sampling_params if self.agent_config.sampling_params is not None else SamplingParams()
async for chunk in self.run(
session_id=request.session_id,
turn_id=turn_id,
input_messages=messages,
sampling_params=self.agent_config.sampling_params,
sampling_params=req_sampling,
stream=request.stream,
documents=request.documents if not is_resume else None,
documents=req_documents,
):
if isinstance(chunk, CompletionMessage):
output_message = chunk
@ -295,8 +307,12 @@ class ChatAgent(ShieldRunnerMixin):
assert isinstance(chunk, AgentTurnResponseStreamChunk), f"Unexpected type {type(chunk)}"
event = chunk.event
if event.payload.event_type == AgentTurnResponseEventType.step_complete.value:
steps.append(event.payload.step_details)
if (
event.payload.event_type == AgentTurnResponseEventType.step_complete.value
and hasattr(event.payload, "step_details")
):
step_details = getattr(event.payload, "step_details")
steps.append(step_details)
yield chunk
@ -308,7 +324,7 @@ class ChatAgent(ShieldRunnerMixin):
input_messages=input_messages,
output_message=output_message,
started_at=start_time,
completed_at=datetime.now(UTC).isoformat(),
completed_at=datetime.now(UTC),
steps=steps,
)
await self.storage.add_turn_to_session(request.session_id, turn)