chore(api): add mypy coverage to meta_reference_agent_instance

Signed-off-by: Mustafa Elbehery <melbeher@redhat.com>
This commit is contained in:
Mustafa Elbehery 2025-07-08 21:11:45 +02:00
parent 81109a0f72
commit 1d52c303d1
3 changed files with 105 additions and 81 deletions

View file

@ -12,6 +12,7 @@ import string
import uuid import uuid
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
from datetime import UTC, datetime from datetime import UTC, datetime
from typing import Any
import httpx import httpx
@ -73,7 +74,7 @@ from .persistence import AgentPersistence
from .safety import SafetyException, ShieldRunnerMixin from .safety import SafetyException, ShieldRunnerMixin
def make_random_string(length: int = 8): def make_random_string(length: int = 8) -> str:
return "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(length)) return "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(length))
@ -117,14 +118,15 @@ class ChatAgent(ShieldRunnerMixin):
) )
def turn_to_messages(self, turn: Turn) -> list[Message]: def turn_to_messages(self, turn: Turn) -> list[Message]:
messages = [] messages: list[Message] = []
# NOTE: if a toolcall response is in a step, we do not add it when processing the input messages # NOTE: if a toolcall response is in a step, we do not add it when processing the input messages
tool_call_ids = set() tool_call_ids = set()
for step in turn.steps: for step in turn.steps:
if step.step_type == StepType.tool_execution.value: if step.step_type == StepType.tool_execution.value:
for response in step.tool_responses: if isinstance(step, ToolExecutionStep):
tool_call_ids.add(response.call_id) for response in step.tool_responses:
tool_call_ids.add(response.call_id)
for m in turn.input_messages: for m in turn.input_messages:
msg = m.model_copy() msg = m.model_copy()
@ -142,31 +144,34 @@ class ChatAgent(ShieldRunnerMixin):
for step in turn.steps: for step in turn.steps:
if step.step_type == StepType.inference.value: if step.step_type == StepType.inference.value:
messages.append(step.model_response) if isinstance(step, InferenceStep):
messages.append(step.model_response)
elif step.step_type == StepType.tool_execution.value: elif step.step_type == StepType.tool_execution.value:
for response in step.tool_responses: if isinstance(step, ToolExecutionStep):
messages.append( for response in step.tool_responses:
ToolResponseMessage( messages.append(
call_id=response.call_id, ToolResponseMessage(
content=response.content, call_id=response.call_id,
content=response.content,
)
) )
)
elif step.step_type == StepType.shield_call.value: elif step.step_type == StepType.shield_call.value:
if step.violation: if isinstance(step, ShieldCallStep):
# CompletionMessage itself in the ShieldResponse if step.violation:
messages.append( # CompletionMessage itself in the ShieldResponse
CompletionMessage( messages.append(
content=step.violation.user_message, CompletionMessage(
stop_reason=StopReason.end_of_turn, content=step.violation.user_message or "",
stop_reason=StopReason.end_of_turn,
)
) )
)
return messages return messages
async def create_session(self, name: str) -> str: async def create_session(self, name: str) -> str:
return await self.storage.create_session(name) return await self.storage.create_session(name)
async def get_messages_from_turns(self, turns: list[Turn]) -> list[Message]: async def get_messages_from_turns(self, turns: list[Turn]) -> list[Message]:
messages = [] messages: list[Message] = []
if self.agent_config.instructions != "": if self.agent_config.instructions != "":
messages.append(SystemMessage(content=self.agent_config.instructions)) messages.append(SystemMessage(content=self.agent_config.instructions))
@ -174,7 +179,9 @@ class ChatAgent(ShieldRunnerMixin):
messages.extend(self.turn_to_messages(turn)) messages.extend(self.turn_to_messages(turn))
return messages return messages
async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> AsyncGenerator: async def create_and_execute_turn(
self, request: AgentTurnCreateRequest
) -> AsyncGenerator[AgentTurnResponseStreamChunk, None]:
span = tracing.get_current_span() span = tracing.get_current_span()
if span: if span:
span.set_attribute("session_id", request.session_id) span.set_attribute("session_id", request.session_id)
@ -189,7 +196,7 @@ class ChatAgent(ShieldRunnerMixin):
async for chunk in self._run_turn(request, turn_id): async for chunk in self._run_turn(request, turn_id):
yield chunk yield chunk
async def resume_turn(self, request: AgentTurnResumeRequest) -> AsyncGenerator: async def resume_turn(self, request: AgentTurnResumeRequest) -> AsyncGenerator[AgentTurnResponseStreamChunk, None]:
span = tracing.get_current_span() span = tracing.get_current_span()
if span: if span:
span.set_attribute("agent_id", self.agent_id) span.set_attribute("agent_id", self.agent_id)
@ -207,7 +214,7 @@ class ChatAgent(ShieldRunnerMixin):
self, self,
request: AgentTurnCreateRequest | AgentTurnResumeRequest, request: AgentTurnCreateRequest | AgentTurnResumeRequest,
turn_id: str | None = None, turn_id: str | None = None,
) -> AsyncGenerator: ) -> AsyncGenerator[AgentTurnResponseStreamChunk, None]:
assert request.stream is True, "Non-streaming not supported" assert request.stream is True, "Non-streaming not supported"
is_resume = isinstance(request, AgentTurnResumeRequest) is_resume = isinstance(request, AgentTurnResumeRequest)
@ -271,11 +278,13 @@ class ChatAgent(ShieldRunnerMixin):
input_messages = request.messages input_messages = request.messages
output_message = None output_message = None
turn_id_final = turn_id or str(uuid.uuid4())
sampling_params = self.agent_config.sampling_params or SamplingParams()
async for chunk in self.run( async for chunk in self.run(
session_id=request.session_id, session_id=request.session_id,
turn_id=turn_id, turn_id=turn_id_final,
input_messages=messages, input_messages=messages,
sampling_params=self.agent_config.sampling_params, sampling_params=sampling_params,
stream=request.stream, stream=request.stream,
documents=request.documents if not is_resume else None, documents=request.documents if not is_resume else None,
): ):
@ -286,19 +295,20 @@ class ChatAgent(ShieldRunnerMixin):
assert isinstance(chunk, AgentTurnResponseStreamChunk), f"Unexpected type {type(chunk)}" assert isinstance(chunk, AgentTurnResponseStreamChunk), f"Unexpected type {type(chunk)}"
event = chunk.event event = chunk.event
if event.payload.event_type == AgentTurnResponseEventType.step_complete.value: if event.payload.event_type == AgentTurnResponseEventType.step_complete.value:
steps.append(event.payload.step_details) if hasattr(event.payload, "step_details"):
steps.append(event.payload.step_details)
yield chunk yield chunk
assert output_message is not None assert output_message is not None
turn = Turn( turn = Turn(
turn_id=turn_id, turn_id=turn_id_final,
session_id=request.session_id, session_id=request.session_id,
input_messages=input_messages, input_messages=input_messages,
output_message=output_message, output_message=output_message,
started_at=start_time, started_at=start_time,
completed_at=datetime.now(UTC).isoformat(), completed_at=datetime.now(UTC),
steps=steps, steps=steps,
) )
await self.storage.add_turn_to_session(request.session_id, turn) await self.storage.add_turn_to_session(request.session_id, turn)
@ -329,7 +339,7 @@ class ChatAgent(ShieldRunnerMixin):
sampling_params: SamplingParams, sampling_params: SamplingParams,
stream: bool = False, stream: bool = False,
documents: list[Document] | None = None, documents: list[Document] | None = None,
) -> AsyncGenerator: ) -> AsyncGenerator[AgentTurnResponseStreamChunk | CompletionMessage, None]:
# Doing async generators makes downstream code much simpler and everything amenable to # Doing async generators makes downstream code much simpler and everything amenable to
# streaming. However, it also makes things complicated here because AsyncGenerators cannot # streaming. However, it also makes things complicated here because AsyncGenerators cannot
# return a "final value" for the `yield from` statement. we simulate that by yielding a # return a "final value" for the `yield from` statement. we simulate that by yielding a
@ -381,7 +391,7 @@ class ChatAgent(ShieldRunnerMixin):
messages: list[Message], messages: list[Message],
shields: list[str], shields: list[str],
touchpoint: str, touchpoint: str,
) -> AsyncGenerator: ) -> AsyncGenerator[AgentTurnResponseStreamChunk | CompletionMessage | bool, None]:
async with tracing.span("run_shields") as span: async with tracing.span("run_shields") as span:
span.set_attribute("input", [m.model_dump_json() for m in messages]) span.set_attribute("input", [m.model_dump_json() for m in messages])
if len(shields) == 0: if len(shields) == 0:
@ -389,12 +399,12 @@ class ChatAgent(ShieldRunnerMixin):
return return
step_id = str(uuid.uuid4()) step_id = str(uuid.uuid4())
shield_call_start_time = datetime.now(UTC).isoformat() shield_call_start_time = datetime.now(UTC)
try: try:
yield AgentTurnResponseStreamChunk( yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent( event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepStartPayload( payload=AgentTurnResponseStepStartPayload(
step_type=StepType.shield_call.value, step_type=StepType.shield_call,
step_id=step_id, step_id=step_id,
metadata=dict(touchpoint=touchpoint), metadata=dict(touchpoint=touchpoint),
) )
@ -406,14 +416,14 @@ class ChatAgent(ShieldRunnerMixin):
yield AgentTurnResponseStreamChunk( yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent( event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload( payload=AgentTurnResponseStepCompletePayload(
step_type=StepType.shield_call.value, step_type=StepType.shield_call,
step_id=step_id, step_id=step_id,
step_details=ShieldCallStep( step_details=ShieldCallStep(
step_id=step_id, step_id=step_id,
turn_id=turn_id, turn_id=turn_id,
violation=e.violation, violation=e.violation,
started_at=shield_call_start_time, started_at=shield_call_start_time,
completed_at=datetime.now(UTC).isoformat(), completed_at=datetime.now(UTC),
), ),
) )
) )
@ -429,14 +439,14 @@ class ChatAgent(ShieldRunnerMixin):
yield AgentTurnResponseStreamChunk( yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent( event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload( payload=AgentTurnResponseStepCompletePayload(
step_type=StepType.shield_call.value, step_type=StepType.shield_call,
step_id=step_id, step_id=step_id,
step_details=ShieldCallStep( step_details=ShieldCallStep(
step_id=step_id, step_id=step_id,
turn_id=turn_id, turn_id=turn_id,
violation=None, violation=None,
started_at=shield_call_start_time, started_at=shield_call_start_time,
completed_at=datetime.now(UTC).isoformat(), completed_at=datetime.now(UTC),
), ),
) )
) )
@ -451,7 +461,7 @@ class ChatAgent(ShieldRunnerMixin):
sampling_params: SamplingParams, sampling_params: SamplingParams,
stream: bool = False, stream: bool = False,
documents: list[Document] | None = None, documents: list[Document] | None = None,
) -> AsyncGenerator: ) -> AsyncGenerator[AgentTurnResponseStreamChunk | CompletionMessage | bool, None]:
# if document is passed in a turn, we parse the raw text of the document # if document is passed in a turn, we parse the raw text of the document
# and sent it as a user message # and sent it as a user message
if documents: if documents:
@ -481,43 +491,46 @@ class ChatAgent(ShieldRunnerMixin):
else: else:
self.tool_name_to_args[tool_name]["vector_db_ids"].append(session_info.vector_db_id) self.tool_name_to_args[tool_name]["vector_db_ids"].append(session_info.vector_db_id)
output_attachments = [] output_attachments: list[Attachment] = []
n_iter = await self.storage.get_num_infer_iters_in_turn(session_id, turn_id) or 0 n_iter = await self.storage.get_num_infer_iters_in_turn(session_id, turn_id) or 0
# Build a map of custom tools to their definitions for faster lookup # Build a map of custom tools to their definitions for faster lookup
client_tools = {} client_tools: dict[str, Any] = {}
for tool in self.agent_config.client_tools: for tool in self.agent_config.client_tools or []:
client_tools[tool.name] = tool client_tools[tool.name] = tool
while True: while True:
step_id = str(uuid.uuid4()) step_id = str(uuid.uuid4())
inference_start_time = datetime.now(UTC).isoformat() inference_start_time = datetime.now(UTC)
yield AgentTurnResponseStreamChunk( yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent( event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepStartPayload( payload=AgentTurnResponseStepStartPayload(
step_type=StepType.inference.value, step_type=StepType.inference,
step_id=step_id, step_id=step_id,
) )
) )
) )
tool_calls = [] tool_calls: list[ToolCall] = []
content = "" content = ""
stop_reason = None stop_reason = None
async with tracing.span("inference") as span: async with tracing.span("inference") as span:
if self.agent_config.name: if self.agent_config.name:
span.set_attribute("agent_name", self.agent_config.name) span.set_attribute("agent_name", self.agent_config.name)
async for chunk in await self.inference_api.chat_completion( chat_completion_response = await self.inference_api.chat_completion(
self.agent_config.model, self.agent_config.model,
input_messages, input_messages,
tools=self.tool_defs, tools=self.tool_defs,
tool_prompt_format=self.agent_config.tool_config.tool_prompt_format, tool_prompt_format=self.agent_config.tool_config.tool_prompt_format
if self.agent_config.tool_config
else None,
response_format=self.agent_config.response_format, response_format=self.agent_config.response_format,
stream=True, stream=True,
sampling_params=sampling_params, sampling_params=sampling_params,
tool_config=self.agent_config.tool_config, tool_config=self.agent_config.tool_config,
): )
async for chunk in chat_completion_response:
event = chunk.event event = chunk.event
if event.event_type == ChatCompletionResponseEventType.start: if event.event_type == ChatCompletionResponseEventType.start:
continue continue
@ -527,16 +540,18 @@ class ChatAgent(ShieldRunnerMixin):
delta = event.delta delta = event.delta
if delta.type == "tool_call": if delta.type == "tool_call":
if delta.parse_status == ToolCallParseStatus.succeeded: if hasattr(delta, "parse_status") and delta.parse_status == ToolCallParseStatus.succeeded:
tool_calls.append(delta.tool_call) if hasattr(delta, "tool_call"):
elif delta.parse_status == ToolCallParseStatus.failed: tool_calls.append(delta.tool_call)
elif hasattr(delta, "parse_status") and delta.parse_status == ToolCallParseStatus.failed:
# If we cannot parse the tools, set the content to the unparsed raw text # If we cannot parse the tools, set the content to the unparsed raw text
content = delta.tool_call if hasattr(delta, "tool_call"):
content = str(delta.tool_call)
if stream: if stream:
yield AgentTurnResponseStreamChunk( yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent( event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepProgressPayload( payload=AgentTurnResponseStepProgressPayload(
step_type=StepType.inference.value, step_type=StepType.inference,
step_id=step_id, step_id=step_id,
delta=delta, delta=delta,
) )
@ -544,12 +559,13 @@ class ChatAgent(ShieldRunnerMixin):
) )
elif delta.type == "text": elif delta.type == "text":
content += delta.text if hasattr(delta, "text"):
content += delta.text
if stream and event.stop_reason is None: if stream and event.stop_reason is None:
yield AgentTurnResponseStreamChunk( yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent( event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepProgressPayload( payload=AgentTurnResponseStepProgressPayload(
step_type=StepType.inference.value, step_type=StepType.inference,
step_id=step_id, step_id=step_id,
delta=delta, delta=delta,
) )
@ -565,10 +581,16 @@ class ChatAgent(ShieldRunnerMixin):
"input", "input",
json.dumps([json.loads(m.model_dump_json()) for m in input_messages]), json.dumps([json.loads(m.model_dump_json()) for m in input_messages]),
) )
tool_calls_json = []
for t in tool_calls:
if hasattr(t, "model_dump_json"):
tool_calls_json.append(json.loads(t.model_dump_json()))
else:
tool_calls_json.append(str(t))
output_attr = json.dumps( output_attr = json.dumps(
{ {
"content": content, "content": content,
"tool_calls": [json.loads(t.model_dump_json()) for t in tool_calls], "tool_calls": tool_calls_json,
} }
) )
span.set_attribute("output", output_attr) span.set_attribute("output", output_attr)
@ -593,7 +615,7 @@ class ChatAgent(ShieldRunnerMixin):
yield AgentTurnResponseStreamChunk( yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent( event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload( payload=AgentTurnResponseStepCompletePayload(
step_type=StepType.inference.value, step_type=StepType.inference,
step_id=step_id, step_id=step_id,
step_details=InferenceStep( step_details=InferenceStep(
# somewhere deep, we are re-assigning message or closing over some # somewhere deep, we are re-assigning message or closing over some
@ -603,13 +625,14 @@ class ChatAgent(ShieldRunnerMixin):
turn_id=turn_id, turn_id=turn_id,
model_response=copy.deepcopy(message), model_response=copy.deepcopy(message),
started_at=inference_start_time, started_at=inference_start_time,
completed_at=datetime.now(UTC).isoformat(), completed_at=datetime.now(UTC),
), ),
) )
) )
) )
if n_iter >= self.agent_config.max_infer_iters: max_infer_iters = self.agent_config.max_infer_iters or 10
if n_iter >= max_infer_iters:
logger.info(f"done with MAX iterations ({n_iter}), exiting.") logger.info(f"done with MAX iterations ({n_iter}), exiting.")
# NOTE: mark end_of_turn to indicate to client that we are done with the turn # NOTE: mark end_of_turn to indicate to client that we are done with the turn
# Do not continue the tool call loop after this point # Do not continue the tool call loop after this point
@ -622,7 +645,8 @@ class ChatAgent(ShieldRunnerMixin):
yield message yield message
break break
if len(message.tool_calls) == 0: tool_calls_to_process = message.tool_calls or []
if len(tool_calls_to_process) == 0:
if stop_reason == StopReason.end_of_turn: if stop_reason == StopReason.end_of_turn:
# TODO: UPDATE RETURN TYPE TO SEND A TUPLE OF (MESSAGE, ATTACHMENTS) # TODO: UPDATE RETURN TYPE TO SEND A TUPLE OF (MESSAGE, ATTACHMENTS)
if len(output_attachments) > 0: if len(output_attachments) > 0:
@ -642,7 +666,7 @@ class ChatAgent(ShieldRunnerMixin):
non_client_tool_calls = [] non_client_tool_calls = []
# Separate client and non-client tool calls # Separate client and non-client tool calls
for tool_call in message.tool_calls: for tool_call in tool_calls_to_process:
if tool_call.tool_name in client_tools: if tool_call.tool_name in client_tools:
client_tool_calls.append(tool_call) client_tool_calls.append(tool_call)
else: else:
@ -654,7 +678,7 @@ class ChatAgent(ShieldRunnerMixin):
yield AgentTurnResponseStreamChunk( yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent( event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepStartPayload( payload=AgentTurnResponseStepStartPayload(
step_type=StepType.tool_execution.value, step_type=StepType.tool_execution,
step_id=step_id, step_id=step_id,
) )
) )
@ -663,7 +687,7 @@ class ChatAgent(ShieldRunnerMixin):
yield AgentTurnResponseStreamChunk( yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent( event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepProgressPayload( payload=AgentTurnResponseStepProgressPayload(
step_type=StepType.tool_execution.value, step_type=StepType.tool_execution,
step_id=step_id, step_id=step_id,
delta=ToolCallDelta( delta=ToolCallDelta(
parse_status=ToolCallParseStatus.in_progress, parse_status=ToolCallParseStatus.in_progress,
@ -681,7 +705,7 @@ class ChatAgent(ShieldRunnerMixin):
"input": message.model_dump_json(), "input": message.model_dump_json(),
}, },
) as span: ) as span:
tool_execution_start_time = datetime.now(UTC).isoformat() tool_execution_start_time = datetime.now(UTC)
tool_result = await self.execute_tool_call_maybe( tool_result = await self.execute_tool_call_maybe(
session_id, session_id,
tool_call, tool_call,
@ -710,14 +734,14 @@ class ChatAgent(ShieldRunnerMixin):
) )
], ],
started_at=tool_execution_start_time, started_at=tool_execution_start_time,
completed_at=datetime.now(UTC).isoformat(), completed_at=datetime.now(UTC),
) )
# Yield the step completion event # Yield the step completion event
yield AgentTurnResponseStreamChunk( yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent( event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload( payload=AgentTurnResponseStepCompletePayload(
step_type=StepType.tool_execution.value, step_type=StepType.tool_execution,
step_id=step_id, step_id=step_id,
step_details=tool_execution_step, step_details=tool_execution_step,
) )
@ -747,7 +771,7 @@ class ChatAgent(ShieldRunnerMixin):
turn_id=turn_id, turn_id=turn_id,
tool_calls=client_tool_calls, tool_calls=client_tool_calls,
tool_responses=[], tool_responses=[],
started_at=datetime.now(UTC).isoformat(), started_at=datetime.now(UTC),
), ),
) )
@ -766,7 +790,7 @@ class ChatAgent(ShieldRunnerMixin):
self, self,
toolgroups_for_turn: list[AgentToolGroup] | None = None, toolgroups_for_turn: list[AgentToolGroup] | None = None,
) -> None: ) -> None:
toolgroup_to_args = {} toolgroup_to_args: dict[str, Any] = {}
for toolgroup in (self.agent_config.toolgroups or []) + (toolgroups_for_turn or []): for toolgroup in (self.agent_config.toolgroups or []) + (toolgroups_for_turn or []):
if isinstance(toolgroup, AgentToolGroupWithArgs): if isinstance(toolgroup, AgentToolGroupWithArgs):
tool_group_name, _ = self._parse_toolgroup_name(toolgroup.name) tool_group_name, _ = self._parse_toolgroup_name(toolgroup.name)
@ -782,10 +806,10 @@ class ChatAgent(ShieldRunnerMixin):
toolgroup_to_args = toolgroup_to_args or {} toolgroup_to_args = toolgroup_to_args or {}
tool_name_to_def = {} tool_name_to_def: dict[str, ToolDefinition] = {}
tool_name_to_args = {} tool_name_to_args: dict[str, Any] = {}
for tool_def in self.agent_config.client_tools: for tool_def in self.agent_config.client_tools or []:
if tool_name_to_def.get(tool_def.name, None): if tool_name_to_def.get(tool_def.name, None):
raise ValueError(f"Tool {tool_def.name} already exists") raise ValueError(f"Tool {tool_def.name} already exists")
tool_name_to_def[tool_def.name] = ToolDefinition( tool_name_to_def[tool_def.name] = ToolDefinition(
@ -798,7 +822,7 @@ class ChatAgent(ShieldRunnerMixin):
required=param.required, required=param.required,
default=param.default, default=param.default,
) )
for param in tool_def.parameters for param in tool_def.parameters or []
}, },
) )
for toolgroup_name_with_maybe_tool_name in agent_config_toolgroups: for toolgroup_name_with_maybe_tool_name in agent_config_toolgroups:
@ -828,7 +852,7 @@ class ChatAgent(ShieldRunnerMixin):
else: else:
identifier = None identifier = None
if tool_name_to_def.get(identifier, None): if identifier is not None and tool_name_to_def.get(str(identifier), None):
raise ValueError(f"Tool {identifier} already exists") raise ValueError(f"Tool {identifier} already exists")
if identifier: if identifier:
tool_name_to_def[tool_def.identifier] = ToolDefinition( tool_name_to_def[tool_def.identifier] = ToolDefinition(
@ -841,7 +865,7 @@ class ChatAgent(ShieldRunnerMixin):
required=param.required, required=param.required,
default=param.default, default=param.default,
) )
for param in tool_def.parameters for param in tool_def.parameters or []
}, },
) )
tool_name_to_args[tool_def.identifier] = toolgroup_to_args.get(toolgroup_name, {}) tool_name_to_args[tool_def.identifier] = toolgroup_to_args.get(toolgroup_name, {})
@ -888,14 +912,15 @@ class ChatAgent(ShieldRunnerMixin):
tool_name_str = tool_name tool_name_str = tool_name
logger.info(f"executing tool call: {tool_name_str} with args: {tool_call.arguments}") logger.info(f"executing tool call: {tool_name_str} with args: {tool_call.arguments}")
kwargs_dict = {"session_id": session_id}
if tool_call.arguments and isinstance(tool_call.arguments, dict):
kwargs_dict.update(tool_call.arguments)
tool_args = self.tool_name_to_args.get(tool_name_str, {})
if tool_args:
kwargs_dict.update(tool_args)
result = await self.tool_runtime_api.invoke_tool( result = await self.tool_runtime_api.invoke_tool(
tool_name=tool_name_str, tool_name=tool_name_str,
kwargs={ kwargs=kwargs_dict,
"session_id": session_id,
# get the arguments generated by the model and augment with toolgroup arg overrides for the agent
**tool_call.arguments,
**self.tool_name_to_args.get(tool_name_str, {}),
},
) )
logger.debug(f"tool call {tool_name_str} completed with result: {result}") logger.debug(f"tool call {tool_name_str} completed with result: {result}")
return result return result
@ -931,7 +956,7 @@ def _interpret_content_as_attachment(
snippet = match.group(1) snippet = match.group(1)
data = json.loads(snippet) data = json.loads(snippet)
return Attachment( return Attachment(
url=URL(uri="file://" + data["filepath"]), content=URL(uri="file://" + data["filepath"]),
mime_type=data["mimetype"], mime_type=data["mimetype"],
) )

View file

@ -28,8 +28,8 @@ class ShieldRunnerMixin:
output_shields: list[str] | None = None, output_shields: list[str] | None = None,
): ):
self.safety_api = safety_api self.safety_api = safety_api
self.input_shields = input_shields self.input_shields = input_shields or []
self.output_shields = output_shields self.output_shields = output_shields or []
async def run_multiple_shields(self, messages: list[Message], identifiers: list[str]) -> None: async def run_multiple_shields(self, messages: list[Message], identifiers: list[str]) -> None:
async def run_shield_with_span(identifier: str): async def run_shield_with_span(identifier: str):

View file

@ -243,7 +243,6 @@ exclude = [
"^llama_stack/models/llama/llama3/tokenizer\\.py$", "^llama_stack/models/llama/llama3/tokenizer\\.py$",
"^llama_stack/models/llama/llama3/tool_utils\\.py$", "^llama_stack/models/llama/llama3/tool_utils\\.py$",
"^llama_stack/providers/inline/agents/meta_reference/", "^llama_stack/providers/inline/agents/meta_reference/",
"^llama_stack/providers/inline/agents/meta_reference/agent_instance\\.py$",
"^llama_stack/providers/inline/agents/meta_reference/agents\\.py$", "^llama_stack/providers/inline/agents/meta_reference/agents\\.py$",
"^llama_stack/providers/inline/datasetio/localfs/", "^llama_stack/providers/inline/datasetio/localfs/",
"^llama_stack/providers/inline/eval/meta_reference/eval\\.py$", "^llama_stack/providers/inline/eval/meta_reference/eval\\.py$",