fix(mypy): resolve agent_instance type issues (part 1 of 2)

Fixed 35 errors, 46 remaining:
- Add isinstance() checks for union type discrimination
- Fix list type annotations for Message types
- Convert strings to datetime/StepType where needed
- Use assert to narrow AgentTurnCreateRequest vs AgentTurnResumeRequest
- Add explicit type annotations to avoid inference issues

Still to fix:
- Remaining str to datetime/StepType conversions
- Optional list handling for shields
- Type annotations for tool maps
- List variance issues for input_messages
- Fix turn_id variable redefinition
This commit is contained in:
Ashwin Bharambe 2025-10-28 11:34:30 -07:00
parent 3a437d80af
commit 3cf36e665b

View file

@ -11,6 +11,7 @@ import uuid
import warnings import warnings
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
from datetime import UTC, datetime from datetime import UTC, datetime
from typing import Any
import httpx import httpx
@ -125,12 +126,12 @@ class ChatAgent(ShieldRunnerMixin):
) )
def turn_to_messages(self, turn: Turn) -> list[Message]: def turn_to_messages(self, turn: Turn) -> list[Message]:
messages = [] messages: list[Message] = []
# NOTE: if a toolcall response is in a step, we do not add it when processing the input messages # NOTE: if a toolcall response is in a step, we do not add it when processing the input messages
tool_call_ids = set() tool_call_ids = set()
for step in turn.steps: for step in turn.steps:
if step.step_type == StepType.tool_execution.value: if step.step_type == StepType.tool_execution.value and isinstance(step, ToolExecutionStep):
for response in step.tool_responses: for response in step.tool_responses:
tool_call_ids.add(response.call_id) tool_call_ids.add(response.call_id)
@ -149,9 +150,9 @@ class ChatAgent(ShieldRunnerMixin):
messages.append(msg) messages.append(msg)
for step in turn.steps: for step in turn.steps:
if step.step_type == StepType.inference.value: if step.step_type == StepType.inference.value and isinstance(step, InferenceStep):
messages.append(step.model_response) messages.append(step.model_response)
elif step.step_type == StepType.tool_execution.value: elif step.step_type == StepType.tool_execution.value and isinstance(step, ToolExecutionStep):
for response in step.tool_responses: for response in step.tool_responses:
messages.append( messages.append(
ToolResponseMessage( ToolResponseMessage(
@ -159,7 +160,7 @@ class ChatAgent(ShieldRunnerMixin):
content=response.content, content=response.content,
) )
) )
elif step.step_type == StepType.shield_call.value: elif step.step_type == StepType.shield_call.value and isinstance(step, ShieldCallStep):
if step.violation: if step.violation:
# CompletionMessage itself in the ShieldResponse # CompletionMessage itself in the ShieldResponse
messages.append( messages.append(
@ -174,7 +175,7 @@ class ChatAgent(ShieldRunnerMixin):
return await self.storage.create_session(name) return await self.storage.create_session(name)
async def get_messages_from_turns(self, turns: list[Turn]) -> list[Message]: async def get_messages_from_turns(self, turns: list[Turn]) -> list[Message]:
messages = [] messages: list[Message] = []
if self.agent_config.instructions != "": if self.agent_config.instructions != "":
messages.append(SystemMessage(content=self.agent_config.instructions)) messages.append(SystemMessage(content=self.agent_config.instructions))
@ -231,7 +232,13 @@ class ChatAgent(ShieldRunnerMixin):
steps = [] steps = []
messages = await self.get_messages_from_turns(turns) messages = await self.get_messages_from_turns(turns)
turn_id: str
start_time: datetime
input_messages: list[Message]
if is_resume: if is_resume:
assert isinstance(request, AgentTurnResumeRequest), "Expected AgentTurnResumeRequest for resume"
tool_response_messages = [ tool_response_messages = [
ToolResponseMessage(call_id=x.call_id, content=x.content) for x in request.tool_responses ToolResponseMessage(call_id=x.call_id, content=x.content) for x in request.tool_responses
] ]
@ -252,20 +259,21 @@ class ChatAgent(ShieldRunnerMixin):
in_progress_tool_call_step = await self.storage.get_in_progress_tool_call_step( in_progress_tool_call_step = await self.storage.get_in_progress_tool_call_step(
request.session_id, request.turn_id request.session_id, request.turn_id
) )
now = datetime.now(UTC).isoformat() now_iso = datetime.now(UTC).isoformat()
now_dt = datetime.now(UTC)
tool_execution_step = ToolExecutionStep( tool_execution_step = ToolExecutionStep(
step_id=(in_progress_tool_call_step.step_id if in_progress_tool_call_step else str(uuid.uuid4())), step_id=(in_progress_tool_call_step.step_id if in_progress_tool_call_step else str(uuid.uuid4())),
turn_id=request.turn_id, turn_id=request.turn_id,
tool_calls=(in_progress_tool_call_step.tool_calls if in_progress_tool_call_step else []), tool_calls=(in_progress_tool_call_step.tool_calls if in_progress_tool_call_step else []),
tool_responses=request.tool_responses, tool_responses=request.tool_responses,
completed_at=now, completed_at=now_dt,
started_at=(in_progress_tool_call_step.started_at if in_progress_tool_call_step else now), started_at=(in_progress_tool_call_step.started_at if in_progress_tool_call_step else now_dt),
) )
steps.append(tool_execution_step) steps.append(tool_execution_step)
yield AgentTurnResponseStreamChunk( yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent( event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload( payload=AgentTurnResponseStepCompletePayload(
step_type=StepType.tool_execution.value, step_type=StepType.tool_execution,
step_id=tool_execution_step.step_id, step_id=tool_execution_step.step_id,
step_details=tool_execution_step, step_details=tool_execution_step,
) )
@ -276,18 +284,22 @@ class ChatAgent(ShieldRunnerMixin):
turn_id = request.turn_id turn_id = request.turn_id
start_time = last_turn.started_at start_time = last_turn.started_at
else: else:
assert isinstance(request, AgentTurnCreateRequest), "Expected AgentTurnCreateRequest for create"
messages.extend(request.messages) messages.extend(request.messages)
start_time = datetime.now(UTC).isoformat() start_time = datetime.now(UTC)
input_messages = request.messages input_messages = request.messages
output_message = None output_message = None
req_documents = request.documents if isinstance(request, AgentTurnCreateRequest) and not is_resume else None
req_sampling = self.agent_config.sampling_params if self.agent_config.sampling_params is not None else SamplingParams()
async for chunk in self.run( async for chunk in self.run(
session_id=request.session_id, session_id=request.session_id,
turn_id=turn_id, turn_id=turn_id,
input_messages=messages, input_messages=messages,
sampling_params=self.agent_config.sampling_params, sampling_params=req_sampling,
stream=request.stream, stream=request.stream,
documents=request.documents if not is_resume else None, documents=req_documents,
): ):
if isinstance(chunk, CompletionMessage): if isinstance(chunk, CompletionMessage):
output_message = chunk output_message = chunk
@ -295,8 +307,12 @@ class ChatAgent(ShieldRunnerMixin):
assert isinstance(chunk, AgentTurnResponseStreamChunk), f"Unexpected type {type(chunk)}" assert isinstance(chunk, AgentTurnResponseStreamChunk), f"Unexpected type {type(chunk)}"
event = chunk.event event = chunk.event
if event.payload.event_type == AgentTurnResponseEventType.step_complete.value: if (
steps.append(event.payload.step_details) event.payload.event_type == AgentTurnResponseEventType.step_complete.value
and hasattr(event.payload, "step_details")
):
step_details = getattr(event.payload, "step_details")
steps.append(step_details)
yield chunk yield chunk
@ -308,7 +324,7 @@ class ChatAgent(ShieldRunnerMixin):
input_messages=input_messages, input_messages=input_messages,
output_message=output_message, output_message=output_message,
started_at=start_time, started_at=start_time,
completed_at=datetime.now(UTC).isoformat(), completed_at=datetime.now(UTC),
steps=steps, steps=steps,
) )
await self.storage.add_turn_to_session(request.session_id, turn) await self.storage.add_turn_to_session(request.session_id, turn)