mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-28 15:02:37 +00:00
chore(api): add mypy coverage to meta_reference_agent_instance
Signed-off-by: Mustafa Elbehery <melbeher@redhat.com>
This commit is contained in:
parent
81109a0f72
commit
1d52c303d1
3 changed files with 105 additions and 81 deletions
|
@ -12,6 +12,7 @@ import string
|
|||
import uuid
|
||||
from collections.abc import AsyncGenerator
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
|
@ -73,7 +74,7 @@ from .persistence import AgentPersistence
|
|||
from .safety import SafetyException, ShieldRunnerMixin
|
||||
|
||||
|
||||
def make_random_string(length: int = 8):
|
||||
def make_random_string(length: int = 8) -> str:
|
||||
return "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(length))
|
||||
|
||||
|
||||
|
@ -117,14 +118,15 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
)
|
||||
|
||||
def turn_to_messages(self, turn: Turn) -> list[Message]:
|
||||
messages = []
|
||||
messages: list[Message] = []
|
||||
|
||||
# NOTE: if a toolcall response is in a step, we do not add it when processing the input messages
|
||||
tool_call_ids = set()
|
||||
for step in turn.steps:
|
||||
if step.step_type == StepType.tool_execution.value:
|
||||
for response in step.tool_responses:
|
||||
tool_call_ids.add(response.call_id)
|
||||
if isinstance(step, ToolExecutionStep):
|
||||
for response in step.tool_responses:
|
||||
tool_call_ids.add(response.call_id)
|
||||
|
||||
for m in turn.input_messages:
|
||||
msg = m.model_copy()
|
||||
|
@ -142,31 +144,34 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
|
||||
for step in turn.steps:
|
||||
if step.step_type == StepType.inference.value:
|
||||
messages.append(step.model_response)
|
||||
if isinstance(step, InferenceStep):
|
||||
messages.append(step.model_response)
|
||||
elif step.step_type == StepType.tool_execution.value:
|
||||
for response in step.tool_responses:
|
||||
messages.append(
|
||||
ToolResponseMessage(
|
||||
call_id=response.call_id,
|
||||
content=response.content,
|
||||
if isinstance(step, ToolExecutionStep):
|
||||
for response in step.tool_responses:
|
||||
messages.append(
|
||||
ToolResponseMessage(
|
||||
call_id=response.call_id,
|
||||
content=response.content,
|
||||
)
|
||||
)
|
||||
)
|
||||
elif step.step_type == StepType.shield_call.value:
|
||||
if step.violation:
|
||||
# CompletionMessage itself in the ShieldResponse
|
||||
messages.append(
|
||||
CompletionMessage(
|
||||
content=step.violation.user_message,
|
||||
stop_reason=StopReason.end_of_turn,
|
||||
if isinstance(step, ShieldCallStep):
|
||||
if step.violation:
|
||||
# CompletionMessage itself in the ShieldResponse
|
||||
messages.append(
|
||||
CompletionMessage(
|
||||
content=step.violation.user_message or "",
|
||||
stop_reason=StopReason.end_of_turn,
|
||||
)
|
||||
)
|
||||
)
|
||||
return messages
|
||||
|
||||
async def create_session(self, name: str) -> str:
|
||||
return await self.storage.create_session(name)
|
||||
|
||||
async def get_messages_from_turns(self, turns: list[Turn]) -> list[Message]:
|
||||
messages = []
|
||||
messages: list[Message] = []
|
||||
if self.agent_config.instructions != "":
|
||||
messages.append(SystemMessage(content=self.agent_config.instructions))
|
||||
|
||||
|
@ -174,7 +179,9 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
messages.extend(self.turn_to_messages(turn))
|
||||
return messages
|
||||
|
||||
async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> AsyncGenerator:
|
||||
async def create_and_execute_turn(
|
||||
self, request: AgentTurnCreateRequest
|
||||
) -> AsyncGenerator[AgentTurnResponseStreamChunk, None]:
|
||||
span = tracing.get_current_span()
|
||||
if span:
|
||||
span.set_attribute("session_id", request.session_id)
|
||||
|
@ -189,7 +196,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
async for chunk in self._run_turn(request, turn_id):
|
||||
yield chunk
|
||||
|
||||
async def resume_turn(self, request: AgentTurnResumeRequest) -> AsyncGenerator:
|
||||
async def resume_turn(self, request: AgentTurnResumeRequest) -> AsyncGenerator[AgentTurnResponseStreamChunk, None]:
|
||||
span = tracing.get_current_span()
|
||||
if span:
|
||||
span.set_attribute("agent_id", self.agent_id)
|
||||
|
@ -207,7 +214,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
self,
|
||||
request: AgentTurnCreateRequest | AgentTurnResumeRequest,
|
||||
turn_id: str | None = None,
|
||||
) -> AsyncGenerator:
|
||||
) -> AsyncGenerator[AgentTurnResponseStreamChunk, None]:
|
||||
assert request.stream is True, "Non-streaming not supported"
|
||||
|
||||
is_resume = isinstance(request, AgentTurnResumeRequest)
|
||||
|
@ -271,11 +278,13 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
input_messages = request.messages
|
||||
|
||||
output_message = None
|
||||
turn_id_final = turn_id or str(uuid.uuid4())
|
||||
sampling_params = self.agent_config.sampling_params or SamplingParams()
|
||||
async for chunk in self.run(
|
||||
session_id=request.session_id,
|
||||
turn_id=turn_id,
|
||||
turn_id=turn_id_final,
|
||||
input_messages=messages,
|
||||
sampling_params=self.agent_config.sampling_params,
|
||||
sampling_params=sampling_params,
|
||||
stream=request.stream,
|
||||
documents=request.documents if not is_resume else None,
|
||||
):
|
||||
|
@ -286,19 +295,20 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
assert isinstance(chunk, AgentTurnResponseStreamChunk), f"Unexpected type {type(chunk)}"
|
||||
event = chunk.event
|
||||
if event.payload.event_type == AgentTurnResponseEventType.step_complete.value:
|
||||
steps.append(event.payload.step_details)
|
||||
if hasattr(event.payload, "step_details"):
|
||||
steps.append(event.payload.step_details)
|
||||
|
||||
yield chunk
|
||||
|
||||
assert output_message is not None
|
||||
|
||||
turn = Turn(
|
||||
turn_id=turn_id,
|
||||
turn_id=turn_id_final,
|
||||
session_id=request.session_id,
|
||||
input_messages=input_messages,
|
||||
output_message=output_message,
|
||||
started_at=start_time,
|
||||
completed_at=datetime.now(UTC).isoformat(),
|
||||
completed_at=datetime.now(UTC),
|
||||
steps=steps,
|
||||
)
|
||||
await self.storage.add_turn_to_session(request.session_id, turn)
|
||||
|
@ -329,7 +339,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
sampling_params: SamplingParams,
|
||||
stream: bool = False,
|
||||
documents: list[Document] | None = None,
|
||||
) -> AsyncGenerator:
|
||||
) -> AsyncGenerator[AgentTurnResponseStreamChunk | CompletionMessage, None]:
|
||||
# Doing async generators makes downstream code much simpler and everything amenable to
|
||||
# streaming. However, it also makes things complicated here because AsyncGenerators cannot
|
||||
# return a "final value" for the `yield from` statement. we simulate that by yielding a
|
||||
|
@ -381,7 +391,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
messages: list[Message],
|
||||
shields: list[str],
|
||||
touchpoint: str,
|
||||
) -> AsyncGenerator:
|
||||
) -> AsyncGenerator[AgentTurnResponseStreamChunk | CompletionMessage | bool, None]:
|
||||
async with tracing.span("run_shields") as span:
|
||||
span.set_attribute("input", [m.model_dump_json() for m in messages])
|
||||
if len(shields) == 0:
|
||||
|
@ -389,12 +399,12 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
return
|
||||
|
||||
step_id = str(uuid.uuid4())
|
||||
shield_call_start_time = datetime.now(UTC).isoformat()
|
||||
shield_call_start_time = datetime.now(UTC)
|
||||
try:
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepStartPayload(
|
||||
step_type=StepType.shield_call.value,
|
||||
step_type=StepType.shield_call,
|
||||
step_id=step_id,
|
||||
metadata=dict(touchpoint=touchpoint),
|
||||
)
|
||||
|
@ -406,14 +416,14 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepCompletePayload(
|
||||
step_type=StepType.shield_call.value,
|
||||
step_type=StepType.shield_call,
|
||||
step_id=step_id,
|
||||
step_details=ShieldCallStep(
|
||||
step_id=step_id,
|
||||
turn_id=turn_id,
|
||||
violation=e.violation,
|
||||
started_at=shield_call_start_time,
|
||||
completed_at=datetime.now(UTC).isoformat(),
|
||||
completed_at=datetime.now(UTC),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
@ -429,14 +439,14 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepCompletePayload(
|
||||
step_type=StepType.shield_call.value,
|
||||
step_type=StepType.shield_call,
|
||||
step_id=step_id,
|
||||
step_details=ShieldCallStep(
|
||||
step_id=step_id,
|
||||
turn_id=turn_id,
|
||||
violation=None,
|
||||
started_at=shield_call_start_time,
|
||||
completed_at=datetime.now(UTC).isoformat(),
|
||||
completed_at=datetime.now(UTC),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
@ -451,7 +461,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
sampling_params: SamplingParams,
|
||||
stream: bool = False,
|
||||
documents: list[Document] | None = None,
|
||||
) -> AsyncGenerator:
|
||||
) -> AsyncGenerator[AgentTurnResponseStreamChunk | CompletionMessage | bool, None]:
|
||||
# if document is passed in a turn, we parse the raw text of the document
|
||||
# and sent it as a user message
|
||||
if documents:
|
||||
|
@ -481,43 +491,46 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
else:
|
||||
self.tool_name_to_args[tool_name]["vector_db_ids"].append(session_info.vector_db_id)
|
||||
|
||||
output_attachments = []
|
||||
output_attachments: list[Attachment] = []
|
||||
|
||||
n_iter = await self.storage.get_num_infer_iters_in_turn(session_id, turn_id) or 0
|
||||
|
||||
# Build a map of custom tools to their definitions for faster lookup
|
||||
client_tools = {}
|
||||
for tool in self.agent_config.client_tools:
|
||||
client_tools: dict[str, Any] = {}
|
||||
for tool in self.agent_config.client_tools or []:
|
||||
client_tools[tool.name] = tool
|
||||
while True:
|
||||
step_id = str(uuid.uuid4())
|
||||
inference_start_time = datetime.now(UTC).isoformat()
|
||||
inference_start_time = datetime.now(UTC)
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepStartPayload(
|
||||
step_type=StepType.inference.value,
|
||||
step_type=StepType.inference,
|
||||
step_id=step_id,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
tool_calls = []
|
||||
tool_calls: list[ToolCall] = []
|
||||
content = ""
|
||||
stop_reason = None
|
||||
|
||||
async with tracing.span("inference") as span:
|
||||
if self.agent_config.name:
|
||||
span.set_attribute("agent_name", self.agent_config.name)
|
||||
async for chunk in await self.inference_api.chat_completion(
|
||||
chat_completion_response = await self.inference_api.chat_completion(
|
||||
self.agent_config.model,
|
||||
input_messages,
|
||||
tools=self.tool_defs,
|
||||
tool_prompt_format=self.agent_config.tool_config.tool_prompt_format,
|
||||
tool_prompt_format=self.agent_config.tool_config.tool_prompt_format
|
||||
if self.agent_config.tool_config
|
||||
else None,
|
||||
response_format=self.agent_config.response_format,
|
||||
stream=True,
|
||||
sampling_params=sampling_params,
|
||||
tool_config=self.agent_config.tool_config,
|
||||
):
|
||||
)
|
||||
async for chunk in chat_completion_response:
|
||||
event = chunk.event
|
||||
if event.event_type == ChatCompletionResponseEventType.start:
|
||||
continue
|
||||
|
@ -527,16 +540,18 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
|
||||
delta = event.delta
|
||||
if delta.type == "tool_call":
|
||||
if delta.parse_status == ToolCallParseStatus.succeeded:
|
||||
tool_calls.append(delta.tool_call)
|
||||
elif delta.parse_status == ToolCallParseStatus.failed:
|
||||
if hasattr(delta, "parse_status") and delta.parse_status == ToolCallParseStatus.succeeded:
|
||||
if hasattr(delta, "tool_call"):
|
||||
tool_calls.append(delta.tool_call)
|
||||
elif hasattr(delta, "parse_status") and delta.parse_status == ToolCallParseStatus.failed:
|
||||
# If we cannot parse the tools, set the content to the unparsed raw text
|
||||
content = delta.tool_call
|
||||
if hasattr(delta, "tool_call"):
|
||||
content = str(delta.tool_call)
|
||||
if stream:
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepProgressPayload(
|
||||
step_type=StepType.inference.value,
|
||||
step_type=StepType.inference,
|
||||
step_id=step_id,
|
||||
delta=delta,
|
||||
)
|
||||
|
@ -544,12 +559,13 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
)
|
||||
|
||||
elif delta.type == "text":
|
||||
content += delta.text
|
||||
if hasattr(delta, "text"):
|
||||
content += delta.text
|
||||
if stream and event.stop_reason is None:
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepProgressPayload(
|
||||
step_type=StepType.inference.value,
|
||||
step_type=StepType.inference,
|
||||
step_id=step_id,
|
||||
delta=delta,
|
||||
)
|
||||
|
@ -565,10 +581,16 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
"input",
|
||||
json.dumps([json.loads(m.model_dump_json()) for m in input_messages]),
|
||||
)
|
||||
tool_calls_json = []
|
||||
for t in tool_calls:
|
||||
if hasattr(t, "model_dump_json"):
|
||||
tool_calls_json.append(json.loads(t.model_dump_json()))
|
||||
else:
|
||||
tool_calls_json.append(str(t))
|
||||
output_attr = json.dumps(
|
||||
{
|
||||
"content": content,
|
||||
"tool_calls": [json.loads(t.model_dump_json()) for t in tool_calls],
|
||||
"tool_calls": tool_calls_json,
|
||||
}
|
||||
)
|
||||
span.set_attribute("output", output_attr)
|
||||
|
@ -593,7 +615,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepCompletePayload(
|
||||
step_type=StepType.inference.value,
|
||||
step_type=StepType.inference,
|
||||
step_id=step_id,
|
||||
step_details=InferenceStep(
|
||||
# somewhere deep, we are re-assigning message or closing over some
|
||||
|
@ -603,13 +625,14 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
turn_id=turn_id,
|
||||
model_response=copy.deepcopy(message),
|
||||
started_at=inference_start_time,
|
||||
completed_at=datetime.now(UTC).isoformat(),
|
||||
completed_at=datetime.now(UTC),
|
||||
),
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
if n_iter >= self.agent_config.max_infer_iters:
|
||||
max_infer_iters = self.agent_config.max_infer_iters or 10
|
||||
if n_iter >= max_infer_iters:
|
||||
logger.info(f"done with MAX iterations ({n_iter}), exiting.")
|
||||
# NOTE: mark end_of_turn to indicate to client that we are done with the turn
|
||||
# Do not continue the tool call loop after this point
|
||||
|
@ -622,7 +645,8 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
yield message
|
||||
break
|
||||
|
||||
if len(message.tool_calls) == 0:
|
||||
tool_calls_to_process = message.tool_calls or []
|
||||
if len(tool_calls_to_process) == 0:
|
||||
if stop_reason == StopReason.end_of_turn:
|
||||
# TODO: UPDATE RETURN TYPE TO SEND A TUPLE OF (MESSAGE, ATTACHMENTS)
|
||||
if len(output_attachments) > 0:
|
||||
|
@ -642,7 +666,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
non_client_tool_calls = []
|
||||
|
||||
# Separate client and non-client tool calls
|
||||
for tool_call in message.tool_calls:
|
||||
for tool_call in tool_calls_to_process:
|
||||
if tool_call.tool_name in client_tools:
|
||||
client_tool_calls.append(tool_call)
|
||||
else:
|
||||
|
@ -654,7 +678,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepStartPayload(
|
||||
step_type=StepType.tool_execution.value,
|
||||
step_type=StepType.tool_execution,
|
||||
step_id=step_id,
|
||||
)
|
||||
)
|
||||
|
@ -663,7 +687,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepProgressPayload(
|
||||
step_type=StepType.tool_execution.value,
|
||||
step_type=StepType.tool_execution,
|
||||
step_id=step_id,
|
||||
delta=ToolCallDelta(
|
||||
parse_status=ToolCallParseStatus.in_progress,
|
||||
|
@ -681,7 +705,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
"input": message.model_dump_json(),
|
||||
},
|
||||
) as span:
|
||||
tool_execution_start_time = datetime.now(UTC).isoformat()
|
||||
tool_execution_start_time = datetime.now(UTC)
|
||||
tool_result = await self.execute_tool_call_maybe(
|
||||
session_id,
|
||||
tool_call,
|
||||
|
@ -710,14 +734,14 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
)
|
||||
],
|
||||
started_at=tool_execution_start_time,
|
||||
completed_at=datetime.now(UTC).isoformat(),
|
||||
completed_at=datetime.now(UTC),
|
||||
)
|
||||
|
||||
# Yield the step completion event
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepCompletePayload(
|
||||
step_type=StepType.tool_execution.value,
|
||||
step_type=StepType.tool_execution,
|
||||
step_id=step_id,
|
||||
step_details=tool_execution_step,
|
||||
)
|
||||
|
@ -747,7 +771,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
turn_id=turn_id,
|
||||
tool_calls=client_tool_calls,
|
||||
tool_responses=[],
|
||||
started_at=datetime.now(UTC).isoformat(),
|
||||
started_at=datetime.now(UTC),
|
||||
),
|
||||
)
|
||||
|
||||
|
@ -766,7 +790,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
self,
|
||||
toolgroups_for_turn: list[AgentToolGroup] | None = None,
|
||||
) -> None:
|
||||
toolgroup_to_args = {}
|
||||
toolgroup_to_args: dict[str, Any] = {}
|
||||
for toolgroup in (self.agent_config.toolgroups or []) + (toolgroups_for_turn or []):
|
||||
if isinstance(toolgroup, AgentToolGroupWithArgs):
|
||||
tool_group_name, _ = self._parse_toolgroup_name(toolgroup.name)
|
||||
|
@ -782,10 +806,10 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
|
||||
toolgroup_to_args = toolgroup_to_args or {}
|
||||
|
||||
tool_name_to_def = {}
|
||||
tool_name_to_args = {}
|
||||
tool_name_to_def: dict[str, ToolDefinition] = {}
|
||||
tool_name_to_args: dict[str, Any] = {}
|
||||
|
||||
for tool_def in self.agent_config.client_tools:
|
||||
for tool_def in self.agent_config.client_tools or []:
|
||||
if tool_name_to_def.get(tool_def.name, None):
|
||||
raise ValueError(f"Tool {tool_def.name} already exists")
|
||||
tool_name_to_def[tool_def.name] = ToolDefinition(
|
||||
|
@ -798,7 +822,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
required=param.required,
|
||||
default=param.default,
|
||||
)
|
||||
for param in tool_def.parameters
|
||||
for param in tool_def.parameters or []
|
||||
},
|
||||
)
|
||||
for toolgroup_name_with_maybe_tool_name in agent_config_toolgroups:
|
||||
|
@ -828,7 +852,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
else:
|
||||
identifier = None
|
||||
|
||||
if tool_name_to_def.get(identifier, None):
|
||||
if identifier is not None and tool_name_to_def.get(str(identifier), None):
|
||||
raise ValueError(f"Tool {identifier} already exists")
|
||||
if identifier:
|
||||
tool_name_to_def[tool_def.identifier] = ToolDefinition(
|
||||
|
@ -841,7 +865,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
required=param.required,
|
||||
default=param.default,
|
||||
)
|
||||
for param in tool_def.parameters
|
||||
for param in tool_def.parameters or []
|
||||
},
|
||||
)
|
||||
tool_name_to_args[tool_def.identifier] = toolgroup_to_args.get(toolgroup_name, {})
|
||||
|
@ -888,14 +912,15 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
tool_name_str = tool_name
|
||||
|
||||
logger.info(f"executing tool call: {tool_name_str} with args: {tool_call.arguments}")
|
||||
kwargs_dict = {"session_id": session_id}
|
||||
if tool_call.arguments and isinstance(tool_call.arguments, dict):
|
||||
kwargs_dict.update(tool_call.arguments)
|
||||
tool_args = self.tool_name_to_args.get(tool_name_str, {})
|
||||
if tool_args:
|
||||
kwargs_dict.update(tool_args)
|
||||
result = await self.tool_runtime_api.invoke_tool(
|
||||
tool_name=tool_name_str,
|
||||
kwargs={
|
||||
"session_id": session_id,
|
||||
# get the arguments generated by the model and augment with toolgroup arg overrides for the agent
|
||||
**tool_call.arguments,
|
||||
**self.tool_name_to_args.get(tool_name_str, {}),
|
||||
},
|
||||
kwargs=kwargs_dict,
|
||||
)
|
||||
logger.debug(f"tool call {tool_name_str} completed with result: {result}")
|
||||
return result
|
||||
|
@ -931,7 +956,7 @@ def _interpret_content_as_attachment(
|
|||
snippet = match.group(1)
|
||||
data = json.loads(snippet)
|
||||
return Attachment(
|
||||
url=URL(uri="file://" + data["filepath"]),
|
||||
content=URL(uri="file://" + data["filepath"]),
|
||||
mime_type=data["mimetype"],
|
||||
)
|
||||
|
||||
|
|
|
@ -28,8 +28,8 @@ class ShieldRunnerMixin:
|
|||
output_shields: list[str] | None = None,
|
||||
):
|
||||
self.safety_api = safety_api
|
||||
self.input_shields = input_shields
|
||||
self.output_shields = output_shields
|
||||
self.input_shields = input_shields or []
|
||||
self.output_shields = output_shields or []
|
||||
|
||||
async def run_multiple_shields(self, messages: list[Message], identifiers: list[str]) -> None:
|
||||
async def run_shield_with_span(identifier: str):
|
||||
|
|
|
@ -243,7 +243,6 @@ exclude = [
|
|||
"^llama_stack/models/llama/llama3/tokenizer\\.py$",
|
||||
"^llama_stack/models/llama/llama3/tool_utils\\.py$",
|
||||
"^llama_stack/providers/inline/agents/meta_reference/",
|
||||
"^llama_stack/providers/inline/agents/meta_reference/agent_instance\\.py$",
|
||||
"^llama_stack/providers/inline/agents/meta_reference/agents\\.py$",
|
||||
"^llama_stack/providers/inline/datasetio/localfs/",
|
||||
"^llama_stack/providers/inline/eval/meta_reference/eval\\.py$",
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue