fix(mypy): resolve agent_instance.py type issues (81 errors)

- Add None checks for optional shield and client_tools lists
- Convert StepType.X.value to StepType.X enum values
- Convert ISO timestamp strings to datetime objects
- Add type annotations (output_attachments, tool_name_to_def)
- Fix union type discrimination with isinstance checks
- Fix max_infer_iters optional comparison
- Filter tool_calls to exclude strings, keep only ToolCall objects
- Fix identifier handling for BuiltinTool enum conversion
- Fix Attachment API parameter (url → content)
- Add type: ignore for OpenAI response format compatibility

Fixes all 81 mypy errors in agent_instance.py.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
Ashwin Bharambe 2025-10-28 11:48:37 -07:00
parent 3cf36e665b
commit ce1392b3a8
2 changed files with 89 additions and 79 deletions

View file

@ -161,7 +161,7 @@ class ChatAgent(ShieldRunnerMixin):
)
)
elif step.step_type == StepType.shield_call.value and isinstance(step, ShieldCallStep):
if step.violation:
if step.violation and step.violation.user_message:
# CompletionMessage itself in the ShieldResponse
messages.append(
CompletionMessage(
@ -233,12 +233,8 @@ class ChatAgent(ShieldRunnerMixin):
steps = []
messages = await self.get_messages_from_turns(turns)
turn_id: str
start_time: datetime
input_messages: list[Message]
if is_resume:
assert isinstance(request, AgentTurnResumeRequest), "Expected AgentTurnResumeRequest for resume"
assert isinstance(request, AgentTurnResumeRequest)
tool_response_messages = [
ToolResponseMessage(call_id=x.call_id, content=x.content) for x in request.tool_responses
]
@ -259,7 +255,6 @@ class ChatAgent(ShieldRunnerMixin):
in_progress_tool_call_step = await self.storage.get_in_progress_tool_call_step(
request.session_id, request.turn_id
)
now_iso = datetime.now(UTC).isoformat()
now_dt = datetime.now(UTC)
tool_execution_step = ToolExecutionStep(
step_id=(in_progress_tool_call_step.step_id if in_progress_tool_call_step else str(uuid.uuid4())),
@ -279,23 +274,29 @@ class ChatAgent(ShieldRunnerMixin):
)
)
)
input_messages = last_turn.input_messages
# Cast needed due to list invariance - last_turn.input_messages is the right type
input_messages = last_turn.input_messages # type: ignore[assignment]
turn_id = request.turn_id
actual_turn_id = request.turn_id
start_time = last_turn.started_at
else:
assert isinstance(request, AgentTurnCreateRequest), "Expected AgentTurnCreateRequest for create"
assert isinstance(request, AgentTurnCreateRequest)
messages.extend(request.messages)
start_time = datetime.now(UTC)
input_messages = request.messages
# Cast needed due to list invariance - request.messages is the right type
input_messages = request.messages # type: ignore[assignment]
# Use the generated turn_id from beginning of function
actual_turn_id = turn_id if turn_id else str(uuid.uuid4())
output_message = None
req_documents = request.documents if isinstance(request, AgentTurnCreateRequest) and not is_resume else None
req_sampling = self.agent_config.sampling_params if self.agent_config.sampling_params is not None else SamplingParams()
req_sampling = (
self.agent_config.sampling_params if self.agent_config.sampling_params is not None else SamplingParams()
)
async for chunk in self.run(
session_id=request.session_id,
turn_id=turn_id,
turn_id=actual_turn_id,
input_messages=messages,
sampling_params=req_sampling,
stream=request.stream,
@ -307,11 +308,10 @@ class ChatAgent(ShieldRunnerMixin):
assert isinstance(chunk, AgentTurnResponseStreamChunk), f"Unexpected type {type(chunk)}"
event = chunk.event
if (
event.payload.event_type == AgentTurnResponseEventType.step_complete.value
and hasattr(event.payload, "step_details")
if event.payload.event_type == AgentTurnResponseEventType.step_complete.value and hasattr(
event.payload, "step_details"
):
step_details = getattr(event.payload, "step_details")
step_details = event.payload.step_details
steps.append(step_details)
yield chunk
@ -319,9 +319,9 @@ class ChatAgent(ShieldRunnerMixin):
assert output_message is not None
turn = Turn(
turn_id=turn_id,
turn_id=actual_turn_id,
session_id=request.session_id,
input_messages=input_messages,
input_messages=input_messages, # type: ignore[arg-type]
output_message=output_message,
started_at=start_time,
completed_at=datetime.now(UTC),
@ -361,7 +361,7 @@ class ChatAgent(ShieldRunnerMixin):
# return a "final value" for the `yield from` statement. we simulate that by yielding a
# final boolean (to see whether an exception happened) and then explicitly testing for it.
if len(self.input_shields) > 0:
if self.input_shields:
async for res in self.run_multiple_shields_wrapper(
turn_id, input_messages, self.input_shields, "user-input"
):
@ -390,7 +390,7 @@ class ChatAgent(ShieldRunnerMixin):
# for output shields run on the full input and output combination
messages = input_messages + [final_response]
if len(self.output_shields) > 0:
if self.output_shields:
async for res in self.run_multiple_shields_wrapper(
turn_id, messages, self.output_shields, "assistant-output"
):
@ -418,12 +418,12 @@ class ChatAgent(ShieldRunnerMixin):
return
step_id = str(uuid.uuid4())
shield_call_start_time = datetime.now(UTC).isoformat()
shield_call_start_time = datetime.now(UTC)
try:
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepStartPayload(
step_type=StepType.shield_call.value,
step_type=StepType.shield_call,
step_id=step_id,
metadata=dict(touchpoint=touchpoint),
)
@ -435,14 +435,14 @@ class ChatAgent(ShieldRunnerMixin):
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload(
step_type=StepType.shield_call.value,
step_type=StepType.shield_call,
step_id=step_id,
step_details=ShieldCallStep(
step_id=step_id,
turn_id=turn_id,
violation=e.violation,
started_at=shield_call_start_time,
completed_at=datetime.now(UTC).isoformat(),
completed_at=datetime.now(UTC),
),
)
)
@ -459,14 +459,14 @@ class ChatAgent(ShieldRunnerMixin):
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload(
step_type=StepType.shield_call.value,
step_type=StepType.shield_call,
step_id=step_id,
step_details=ShieldCallStep(
step_id=step_id,
turn_id=turn_id,
violation=None,
started_at=shield_call_start_time,
completed_at=datetime.now(UTC).isoformat(),
completed_at=datetime.now(UTC),
),
)
)
@ -512,21 +512,22 @@ class ChatAgent(ShieldRunnerMixin):
else:
self.tool_name_to_args[tool_name]["vector_store_ids"].append(session_info.vector_store_id)
output_attachments = []
output_attachments: list[Attachment] = []
n_iter = await self.storage.get_num_infer_iters_in_turn(session_id, turn_id) or 0
# Build a map of custom tools to their definitions for faster lookup
client_tools = {}
if self.agent_config.client_tools:
for tool in self.agent_config.client_tools:
client_tools[tool.name] = tool
while True:
step_id = str(uuid.uuid4())
inference_start_time = datetime.now(UTC).isoformat()
inference_start_time = datetime.now(UTC)
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepStartPayload(
step_type=StepType.inference.value,
step_type=StepType.inference,
step_id=step_id,
)
)
@ -554,7 +555,7 @@ class ChatAgent(ShieldRunnerMixin):
else:
return value
def _add_type(openai_msg: dict) -> OpenAIMessageParam:
def _add_type(openai_msg: Any) -> OpenAIMessageParam:
# Serialize any nested Pydantic models to plain dicts
openai_msg = _serialize_nested(openai_msg)
@ -604,7 +605,7 @@ class ChatAgent(ShieldRunnerMixin):
messages=openai_messages,
tools=openai_tools if openai_tools else None,
tool_choice=tool_choice,
response_format=self.agent_config.response_format,
response_format=self.agent_config.response_format, # type: ignore[arg-type]
temperature=temperature,
top_p=top_p,
max_tokens=max_tokens,
@ -614,7 +615,8 @@ class ChatAgent(ShieldRunnerMixin):
# Convert OpenAI stream back to Llama Stack format
response_stream = convert_openai_chat_completion_stream(
openai_stream, enable_incremental_tool_calls=True
openai_stream, # type: ignore[arg-type]
enable_incremental_tool_calls=True,
)
async for chunk in response_stream:
@ -636,7 +638,7 @@ class ChatAgent(ShieldRunnerMixin):
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepProgressPayload(
step_type=StepType.inference.value,
step_type=StepType.inference,
step_id=step_id,
delta=delta,
)
@ -649,7 +651,7 @@ class ChatAgent(ShieldRunnerMixin):
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepProgressPayload(
step_type=StepType.inference.value,
step_type=StepType.inference,
step_id=step_id,
delta=delta,
)
@ -667,7 +669,9 @@ class ChatAgent(ShieldRunnerMixin):
output_attr = json.dumps(
{
"content": content,
"tool_calls": [json.loads(t.model_dump_json()) for t in tool_calls],
"tool_calls": [
json.loads(t.model_dump_json()) for t in tool_calls if isinstance(t, ToolCall)
],
}
)
span.set_attribute("output", output_attr)
@ -683,16 +687,18 @@ class ChatAgent(ShieldRunnerMixin):
if tool_calls:
content = ""
# Filter out string tool calls for CompletionMessage (only keep ToolCall objects)
valid_tool_calls = [t for t in tool_calls if isinstance(t, ToolCall)]
message = CompletionMessage(
content=content,
stop_reason=stop_reason,
tool_calls=tool_calls,
tool_calls=valid_tool_calls if valid_tool_calls else None,
)
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload(
step_type=StepType.inference.value,
step_type=StepType.inference,
step_id=step_id,
step_details=InferenceStep(
# somewhere deep, we are re-assigning message or closing over some
@ -702,13 +708,14 @@ class ChatAgent(ShieldRunnerMixin):
turn_id=turn_id,
model_response=copy.deepcopy(message),
started_at=inference_start_time,
completed_at=datetime.now(UTC).isoformat(),
completed_at=datetime.now(UTC),
),
)
)
)
if n_iter >= self.agent_config.max_infer_iters:
max_iters = self.agent_config.max_infer_iters if self.agent_config.max_infer_iters is not None else 10
if n_iter >= max_iters:
logger.info(f"done with MAX iterations ({n_iter}), exiting.")
# NOTE: mark end_of_turn to indicate to client that we are done with the turn
# Do not continue the tool call loop after this point
@ -721,14 +728,16 @@ class ChatAgent(ShieldRunnerMixin):
yield message
break
if len(message.tool_calls) == 0:
if not message.tool_calls or len(message.tool_calls) == 0:
if stop_reason == StopReason.end_of_turn:
# TODO: UPDATE RETURN TYPE TO SEND A TUPLE OF (MESSAGE, ATTACHMENTS)
if len(output_attachments) > 0:
if isinstance(message.content, list):
message.content += output_attachments
# List invariance - attachments are compatible at runtime
message.content += output_attachments # type: ignore[arg-type]
else:
message.content = [message.content] + output_attachments
# List invariance - attachments are compatible at runtime
message.content = [message.content] + output_attachments # type: ignore[assignment]
yield message
else:
logger.debug(f"completion message with EOM (iter: {n_iter}): {str(message)}")
@ -741,6 +750,7 @@ class ChatAgent(ShieldRunnerMixin):
non_client_tool_calls = []
# Separate client and non-client tool calls
if message.tool_calls:
for tool_call in message.tool_calls:
if tool_call.tool_name in client_tools:
client_tool_calls.append(tool_call)
@ -753,7 +763,7 @@ class ChatAgent(ShieldRunnerMixin):
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepStartPayload(
step_type=StepType.tool_execution.value,
step_type=StepType.tool_execution,
step_id=step_id,
)
)
@ -762,7 +772,7 @@ class ChatAgent(ShieldRunnerMixin):
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepProgressPayload(
step_type=StepType.tool_execution.value,
step_type=StepType.tool_execution,
step_id=step_id,
delta=ToolCallDelta(
parse_status=ToolCallParseStatus.in_progress,
@ -782,7 +792,7 @@ class ChatAgent(ShieldRunnerMixin):
if self.telemetry_enabled
else {},
) as span:
tool_execution_start_time = datetime.now(UTC).isoformat()
tool_execution_start_time = datetime.now(UTC)
tool_result = await self.execute_tool_call_maybe(
session_id,
tool_call,
@ -812,14 +822,14 @@ class ChatAgent(ShieldRunnerMixin):
)
],
started_at=tool_execution_start_time,
completed_at=datetime.now(UTC).isoformat(),
completed_at=datetime.now(UTC),
)
# Yield the step completion event
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload(
step_type=StepType.tool_execution.value,
step_type=StepType.tool_execution,
step_id=step_id,
step_details=tool_execution_step,
)
@ -849,7 +859,7 @@ class ChatAgent(ShieldRunnerMixin):
turn_id=turn_id,
tool_calls=client_tool_calls,
tool_responses=[],
started_at=datetime.now(UTC).isoformat(),
started_at=datetime.now(UTC),
),
)
@ -884,9 +894,10 @@ class ChatAgent(ShieldRunnerMixin):
toolgroup_to_args = toolgroup_to_args or {}
tool_name_to_def = {}
tool_name_to_def: dict[str, ToolDefinition] = {}
tool_name_to_args = {}
if self.agent_config.client_tools:
for tool_def in self.agent_config.client_tools:
if tool_name_to_def.get(tool_def.name, None):
raise ValueError(f"Tool {tool_def.name} already exists")
@ -924,15 +935,17 @@ class ChatAgent(ShieldRunnerMixin):
else:
identifier = None
if tool_name_to_def.get(identifier, None):
raise ValueError(f"Tool {identifier} already exists")
if identifier:
tool_name_to_def[identifier] = ToolDefinition(
tool_name=identifier,
# Convert BuiltinTool to string for dictionary key
identifier_str = identifier.value if isinstance(identifier, BuiltinTool) else identifier
if tool_name_to_def.get(identifier_str, None):
raise ValueError(f"Tool {identifier_str} already exists")
tool_name_to_def[identifier_str] = ToolDefinition(
tool_name=identifier_str,
description=tool_def.description,
input_schema=tool_def.input_schema,
)
tool_name_to_args[identifier] = toolgroup_to_args.get(toolgroup_name, {})
tool_name_to_args[identifier_str] = toolgroup_to_args.get(toolgroup_name, {})
self.tool_defs, self.tool_name_to_args = (
list(tool_name_to_def.values()),
@ -1033,7 +1046,7 @@ def _interpret_content_as_attachment(
snippet = match.group(1)
data = json.loads(snippet)
return Attachment(
url=URL(uri="file://" + data["filepath"]),
content=URL(uri="file://" + data["filepath"]),
mime_type=data["mimetype"],
)

View file

@ -28,7 +28,6 @@ from llama_stack.apis.agents.openai_responses import (
)
from llama_stack.apis.common.content_types import (
ImageContentItem,
InterleavedContent,
TextContentItem,
)
from llama_stack.apis.inference import (
@ -114,9 +113,7 @@ class ToolExecutor:
final_output_message=output_message,
final_input_message=input_message,
citation_files=(
metadata.get("citation_files")
if result and (metadata := getattr(result, "metadata", None))
else None
metadata.get("citation_files") if result and (metadata := getattr(result, "metadata", None)) else None
),
)
@ -399,9 +396,9 @@ class ToolExecutor:
)
if error_exc:
message.error = str(error_exc)
elif (
result and (error_code := getattr(result, "error_code", None)) and error_code > 0
) or (result and (error_message := getattr(result, "error_message", None))):
elif (result and (error_code := getattr(result, "error_code", None)) and error_code > 0) or (
result and getattr(result, "error_message", None)
):
ec = getattr(result, "error_code", "unknown")
em = getattr(result, "error_message", "")
message.error = f"Error (code {ec}): {em}"