chore(api): add mypy coverage to meta_reference_agent_instance

Signed-off-by: Mustafa Elbehery <melbeher@redhat.com>
This commit is contained in:
Mustafa Elbehery 2025-07-08 21:11:45 +02:00
parent 81109a0f72
commit 1d52c303d1
3 changed files with 105 additions and 81 deletions

View file

@ -12,6 +12,7 @@ import string
import uuid
from collections.abc import AsyncGenerator
from datetime import UTC, datetime
from typing import Any
import httpx
@ -73,7 +74,7 @@ from .persistence import AgentPersistence
from .safety import SafetyException, ShieldRunnerMixin
def make_random_string(length: int = 8):
def make_random_string(length: int = 8) -> str:
return "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(length))
@ -117,12 +118,13 @@ class ChatAgent(ShieldRunnerMixin):
)
def turn_to_messages(self, turn: Turn) -> list[Message]:
messages = []
messages: list[Message] = []
# NOTE: if a toolcall response is in a step, we do not add it when processing the input messages
tool_call_ids = set()
for step in turn.steps:
if step.step_type == StepType.tool_execution.value:
if isinstance(step, ToolExecutionStep):
for response in step.tool_responses:
tool_call_ids.add(response.call_id)
@ -142,8 +144,10 @@ class ChatAgent(ShieldRunnerMixin):
for step in turn.steps:
if step.step_type == StepType.inference.value:
if isinstance(step, InferenceStep):
messages.append(step.model_response)
elif step.step_type == StepType.tool_execution.value:
if isinstance(step, ToolExecutionStep):
for response in step.tool_responses:
messages.append(
ToolResponseMessage(
@ -152,11 +156,12 @@ class ChatAgent(ShieldRunnerMixin):
)
)
elif step.step_type == StepType.shield_call.value:
if isinstance(step, ShieldCallStep):
if step.violation:
# CompletionMessage itself in the ShieldResponse
messages.append(
CompletionMessage(
content=step.violation.user_message,
content=step.violation.user_message or "",
stop_reason=StopReason.end_of_turn,
)
)
@ -166,7 +171,7 @@ class ChatAgent(ShieldRunnerMixin):
return await self.storage.create_session(name)
async def get_messages_from_turns(self, turns: list[Turn]) -> list[Message]:
messages = []
messages: list[Message] = []
if self.agent_config.instructions != "":
messages.append(SystemMessage(content=self.agent_config.instructions))
@ -174,7 +179,9 @@ class ChatAgent(ShieldRunnerMixin):
messages.extend(self.turn_to_messages(turn))
return messages
async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> AsyncGenerator:
async def create_and_execute_turn(
self, request: AgentTurnCreateRequest
) -> AsyncGenerator[AgentTurnResponseStreamChunk, None]:
span = tracing.get_current_span()
if span:
span.set_attribute("session_id", request.session_id)
@ -189,7 +196,7 @@ class ChatAgent(ShieldRunnerMixin):
async for chunk in self._run_turn(request, turn_id):
yield chunk
async def resume_turn(self, request: AgentTurnResumeRequest) -> AsyncGenerator:
async def resume_turn(self, request: AgentTurnResumeRequest) -> AsyncGenerator[AgentTurnResponseStreamChunk, None]:
span = tracing.get_current_span()
if span:
span.set_attribute("agent_id", self.agent_id)
@ -207,7 +214,7 @@ class ChatAgent(ShieldRunnerMixin):
self,
request: AgentTurnCreateRequest | AgentTurnResumeRequest,
turn_id: str | None = None,
) -> AsyncGenerator:
) -> AsyncGenerator[AgentTurnResponseStreamChunk, None]:
assert request.stream is True, "Non-streaming not supported"
is_resume = isinstance(request, AgentTurnResumeRequest)
@ -271,11 +278,13 @@ class ChatAgent(ShieldRunnerMixin):
input_messages = request.messages
output_message = None
turn_id_final = turn_id or str(uuid.uuid4())
sampling_params = self.agent_config.sampling_params or SamplingParams()
async for chunk in self.run(
session_id=request.session_id,
turn_id=turn_id,
turn_id=turn_id_final,
input_messages=messages,
sampling_params=self.agent_config.sampling_params,
sampling_params=sampling_params,
stream=request.stream,
documents=request.documents if not is_resume else None,
):
@ -286,6 +295,7 @@ class ChatAgent(ShieldRunnerMixin):
assert isinstance(chunk, AgentTurnResponseStreamChunk), f"Unexpected type {type(chunk)}"
event = chunk.event
if event.payload.event_type == AgentTurnResponseEventType.step_complete.value:
if hasattr(event.payload, "step_details"):
steps.append(event.payload.step_details)
yield chunk
@ -293,12 +303,12 @@ class ChatAgent(ShieldRunnerMixin):
assert output_message is not None
turn = Turn(
turn_id=turn_id,
turn_id=turn_id_final,
session_id=request.session_id,
input_messages=input_messages,
output_message=output_message,
started_at=start_time,
completed_at=datetime.now(UTC).isoformat(),
completed_at=datetime.now(UTC),
steps=steps,
)
await self.storage.add_turn_to_session(request.session_id, turn)
@ -329,7 +339,7 @@ class ChatAgent(ShieldRunnerMixin):
sampling_params: SamplingParams,
stream: bool = False,
documents: list[Document] | None = None,
) -> AsyncGenerator:
) -> AsyncGenerator[AgentTurnResponseStreamChunk | CompletionMessage, None]:
# Doing async generators makes downstream code much simpler and everything amenable to
# streaming. However, it also makes things complicated here because AsyncGenerators cannot
# return a "final value" for the `yield from` statement. we simulate that by yielding a
@ -381,7 +391,7 @@ class ChatAgent(ShieldRunnerMixin):
messages: list[Message],
shields: list[str],
touchpoint: str,
) -> AsyncGenerator:
) -> AsyncGenerator[AgentTurnResponseStreamChunk | CompletionMessage | bool, None]:
async with tracing.span("run_shields") as span:
span.set_attribute("input", [m.model_dump_json() for m in messages])
if len(shields) == 0:
@ -389,12 +399,12 @@ class ChatAgent(ShieldRunnerMixin):
return
step_id = str(uuid.uuid4())
shield_call_start_time = datetime.now(UTC).isoformat()
shield_call_start_time = datetime.now(UTC)
try:
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepStartPayload(
step_type=StepType.shield_call.value,
step_type=StepType.shield_call,
step_id=step_id,
metadata=dict(touchpoint=touchpoint),
)
@ -406,14 +416,14 @@ class ChatAgent(ShieldRunnerMixin):
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload(
step_type=StepType.shield_call.value,
step_type=StepType.shield_call,
step_id=step_id,
step_details=ShieldCallStep(
step_id=step_id,
turn_id=turn_id,
violation=e.violation,
started_at=shield_call_start_time,
completed_at=datetime.now(UTC).isoformat(),
completed_at=datetime.now(UTC),
),
)
)
@ -429,14 +439,14 @@ class ChatAgent(ShieldRunnerMixin):
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload(
step_type=StepType.shield_call.value,
step_type=StepType.shield_call,
step_id=step_id,
step_details=ShieldCallStep(
step_id=step_id,
turn_id=turn_id,
violation=None,
started_at=shield_call_start_time,
completed_at=datetime.now(UTC).isoformat(),
completed_at=datetime.now(UTC),
),
)
)
@ -451,7 +461,7 @@ class ChatAgent(ShieldRunnerMixin):
sampling_params: SamplingParams,
stream: bool = False,
documents: list[Document] | None = None,
) -> AsyncGenerator:
) -> AsyncGenerator[AgentTurnResponseStreamChunk | CompletionMessage | bool, None]:
# if documents are passed in a turn, we parse their raw text
# and send it as a user message
if documents:
@ -481,43 +491,46 @@ class ChatAgent(ShieldRunnerMixin):
else:
self.tool_name_to_args[tool_name]["vector_db_ids"].append(session_info.vector_db_id)
output_attachments = []
output_attachments: list[Attachment] = []
n_iter = await self.storage.get_num_infer_iters_in_turn(session_id, turn_id) or 0
# Build a map of custom tools to their definitions for faster lookup
client_tools = {}
for tool in self.agent_config.client_tools:
client_tools: dict[str, Any] = {}
for tool in self.agent_config.client_tools or []:
client_tools[tool.name] = tool
while True:
step_id = str(uuid.uuid4())
inference_start_time = datetime.now(UTC).isoformat()
inference_start_time = datetime.now(UTC)
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepStartPayload(
step_type=StepType.inference.value,
step_type=StepType.inference,
step_id=step_id,
)
)
)
tool_calls = []
tool_calls: list[ToolCall] = []
content = ""
stop_reason = None
async with tracing.span("inference") as span:
if self.agent_config.name:
span.set_attribute("agent_name", self.agent_config.name)
async for chunk in await self.inference_api.chat_completion(
chat_completion_response = await self.inference_api.chat_completion(
self.agent_config.model,
input_messages,
tools=self.tool_defs,
tool_prompt_format=self.agent_config.tool_config.tool_prompt_format,
tool_prompt_format=self.agent_config.tool_config.tool_prompt_format
if self.agent_config.tool_config
else None,
response_format=self.agent_config.response_format,
stream=True,
sampling_params=sampling_params,
tool_config=self.agent_config.tool_config,
):
)
async for chunk in chat_completion_response:
event = chunk.event
if event.event_type == ChatCompletionResponseEventType.start:
continue
@ -527,16 +540,18 @@ class ChatAgent(ShieldRunnerMixin):
delta = event.delta
if delta.type == "tool_call":
if delta.parse_status == ToolCallParseStatus.succeeded:
if hasattr(delta, "parse_status") and delta.parse_status == ToolCallParseStatus.succeeded:
if hasattr(delta, "tool_call"):
tool_calls.append(delta.tool_call)
elif delta.parse_status == ToolCallParseStatus.failed:
elif hasattr(delta, "parse_status") and delta.parse_status == ToolCallParseStatus.failed:
# If we cannot parse the tools, set the content to the unparsed raw text
content = delta.tool_call
if hasattr(delta, "tool_call"):
content = str(delta.tool_call)
if stream:
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepProgressPayload(
step_type=StepType.inference.value,
step_type=StepType.inference,
step_id=step_id,
delta=delta,
)
@ -544,12 +559,13 @@ class ChatAgent(ShieldRunnerMixin):
)
elif delta.type == "text":
if hasattr(delta, "text"):
content += delta.text
if stream and event.stop_reason is None:
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepProgressPayload(
step_type=StepType.inference.value,
step_type=StepType.inference,
step_id=step_id,
delta=delta,
)
@ -565,10 +581,16 @@ class ChatAgent(ShieldRunnerMixin):
"input",
json.dumps([json.loads(m.model_dump_json()) for m in input_messages]),
)
tool_calls_json = []
for t in tool_calls:
if hasattr(t, "model_dump_json"):
tool_calls_json.append(json.loads(t.model_dump_json()))
else:
tool_calls_json.append(str(t))
output_attr = json.dumps(
{
"content": content,
"tool_calls": [json.loads(t.model_dump_json()) for t in tool_calls],
"tool_calls": tool_calls_json,
}
)
span.set_attribute("output", output_attr)
@ -593,7 +615,7 @@ class ChatAgent(ShieldRunnerMixin):
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload(
step_type=StepType.inference.value,
step_type=StepType.inference,
step_id=step_id,
step_details=InferenceStep(
# somewhere deep, we are re-assigning message or closing over some
@ -603,13 +625,14 @@ class ChatAgent(ShieldRunnerMixin):
turn_id=turn_id,
model_response=copy.deepcopy(message),
started_at=inference_start_time,
completed_at=datetime.now(UTC).isoformat(),
completed_at=datetime.now(UTC),
),
)
)
)
if n_iter >= self.agent_config.max_infer_iters:
max_infer_iters = self.agent_config.max_infer_iters or 10
if n_iter >= max_infer_iters:
logger.info(f"done with MAX iterations ({n_iter}), exiting.")
# NOTE: mark end_of_turn to indicate to client that we are done with the turn
# Do not continue the tool call loop after this point
@ -622,7 +645,8 @@ class ChatAgent(ShieldRunnerMixin):
yield message
break
if len(message.tool_calls) == 0:
tool_calls_to_process = message.tool_calls or []
if len(tool_calls_to_process) == 0:
if stop_reason == StopReason.end_of_turn:
# TODO: UPDATE RETURN TYPE TO SEND A TUPLE OF (MESSAGE, ATTACHMENTS)
if len(output_attachments) > 0:
@ -642,7 +666,7 @@ class ChatAgent(ShieldRunnerMixin):
non_client_tool_calls = []
# Separate client and non-client tool calls
for tool_call in message.tool_calls:
for tool_call in tool_calls_to_process:
if tool_call.tool_name in client_tools:
client_tool_calls.append(tool_call)
else:
@ -654,7 +678,7 @@ class ChatAgent(ShieldRunnerMixin):
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepStartPayload(
step_type=StepType.tool_execution.value,
step_type=StepType.tool_execution,
step_id=step_id,
)
)
@ -663,7 +687,7 @@ class ChatAgent(ShieldRunnerMixin):
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepProgressPayload(
step_type=StepType.tool_execution.value,
step_type=StepType.tool_execution,
step_id=step_id,
delta=ToolCallDelta(
parse_status=ToolCallParseStatus.in_progress,
@ -681,7 +705,7 @@ class ChatAgent(ShieldRunnerMixin):
"input": message.model_dump_json(),
},
) as span:
tool_execution_start_time = datetime.now(UTC).isoformat()
tool_execution_start_time = datetime.now(UTC)
tool_result = await self.execute_tool_call_maybe(
session_id,
tool_call,
@ -710,14 +734,14 @@ class ChatAgent(ShieldRunnerMixin):
)
],
started_at=tool_execution_start_time,
completed_at=datetime.now(UTC).isoformat(),
completed_at=datetime.now(UTC),
)
# Yield the step completion event
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload(
step_type=StepType.tool_execution.value,
step_type=StepType.tool_execution,
step_id=step_id,
step_details=tool_execution_step,
)
@ -747,7 +771,7 @@ class ChatAgent(ShieldRunnerMixin):
turn_id=turn_id,
tool_calls=client_tool_calls,
tool_responses=[],
started_at=datetime.now(UTC).isoformat(),
started_at=datetime.now(UTC),
),
)
@ -766,7 +790,7 @@ class ChatAgent(ShieldRunnerMixin):
self,
toolgroups_for_turn: list[AgentToolGroup] | None = None,
) -> None:
toolgroup_to_args = {}
toolgroup_to_args: dict[str, Any] = {}
for toolgroup in (self.agent_config.toolgroups or []) + (toolgroups_for_turn or []):
if isinstance(toolgroup, AgentToolGroupWithArgs):
tool_group_name, _ = self._parse_toolgroup_name(toolgroup.name)
@ -782,10 +806,10 @@ class ChatAgent(ShieldRunnerMixin):
toolgroup_to_args = toolgroup_to_args or {}
tool_name_to_def = {}
tool_name_to_args = {}
tool_name_to_def: dict[str, ToolDefinition] = {}
tool_name_to_args: dict[str, Any] = {}
for tool_def in self.agent_config.client_tools:
for tool_def in self.agent_config.client_tools or []:
if tool_name_to_def.get(tool_def.name, None):
raise ValueError(f"Tool {tool_def.name} already exists")
tool_name_to_def[tool_def.name] = ToolDefinition(
@ -798,7 +822,7 @@ class ChatAgent(ShieldRunnerMixin):
required=param.required,
default=param.default,
)
for param in tool_def.parameters
for param in tool_def.parameters or []
},
)
for toolgroup_name_with_maybe_tool_name in agent_config_toolgroups:
@ -828,7 +852,7 @@ class ChatAgent(ShieldRunnerMixin):
else:
identifier = None
if tool_name_to_def.get(identifier, None):
if identifier is not None and tool_name_to_def.get(str(identifier), None):
raise ValueError(f"Tool {identifier} already exists")
if identifier:
tool_name_to_def[tool_def.identifier] = ToolDefinition(
@ -841,7 +865,7 @@ class ChatAgent(ShieldRunnerMixin):
required=param.required,
default=param.default,
)
for param in tool_def.parameters
for param in tool_def.parameters or []
},
)
tool_name_to_args[tool_def.identifier] = toolgroup_to_args.get(toolgroup_name, {})
@ -888,14 +912,15 @@ class ChatAgent(ShieldRunnerMixin):
tool_name_str = tool_name
logger.info(f"executing tool call: {tool_name_str} with args: {tool_call.arguments}")
kwargs_dict = {"session_id": session_id}
if tool_call.arguments and isinstance(tool_call.arguments, dict):
kwargs_dict.update(tool_call.arguments)
tool_args = self.tool_name_to_args.get(tool_name_str, {})
if tool_args:
kwargs_dict.update(tool_args)
result = await self.tool_runtime_api.invoke_tool(
tool_name=tool_name_str,
kwargs={
"session_id": session_id,
# get the arguments generated by the model and augment with toolgroup arg overrides for the agent
**tool_call.arguments,
**self.tool_name_to_args.get(tool_name_str, {}),
},
kwargs=kwargs_dict,
)
logger.debug(f"tool call {tool_name_str} completed with result: {result}")
return result
@ -931,7 +956,7 @@ def _interpret_content_as_attachment(
snippet = match.group(1)
data = json.loads(snippet)
return Attachment(
url=URL(uri="file://" + data["filepath"]),
content=URL(uri="file://" + data["filepath"]),
mime_type=data["mimetype"],
)

View file

@ -28,8 +28,8 @@ class ShieldRunnerMixin:
output_shields: list[str] | None = None,
):
self.safety_api = safety_api
self.input_shields = input_shields
self.output_shields = output_shields
self.input_shields = input_shields or []
self.output_shields = output_shields or []
async def run_multiple_shields(self, messages: list[Message], identifiers: list[str]) -> None:
async def run_shield_with_span(identifier: str):

View file

@ -243,7 +243,6 @@ exclude = [
"^llama_stack/models/llama/llama3/tokenizer\\.py$",
"^llama_stack/models/llama/llama3/tool_utils\\.py$",
"^llama_stack/providers/inline/agents/meta_reference/",
"^llama_stack/providers/inline/agents/meta_reference/agent_instance\\.py$",
"^llama_stack/providers/inline/agents/meta_reference/agents\\.py$",
"^llama_stack/providers/inline/datasetio/localfs/",
"^llama_stack/providers/inline/eval/meta_reference/eval\\.py$",