mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00
chore(api): add mypy coverage to meta_reference_agent_instance
Signed-off-by: Mustafa Elbehery <melbeher@redhat.com>
This commit is contained in:
parent
81109a0f72
commit
1d52c303d1
3 changed files with 105 additions and 81 deletions
|
@ -12,6 +12,7 @@ import string
|
||||||
import uuid
|
import uuid
|
||||||
from collections.abc import AsyncGenerator
|
from collections.abc import AsyncGenerator
|
||||||
from datetime import UTC, datetime
|
from datetime import UTC, datetime
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
|
@ -73,7 +74,7 @@ from .persistence import AgentPersistence
|
||||||
from .safety import SafetyException, ShieldRunnerMixin
|
from .safety import SafetyException, ShieldRunnerMixin
|
||||||
|
|
||||||
|
|
||||||
def make_random_string(length: int = 8):
|
def make_random_string(length: int = 8) -> str:
|
||||||
return "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(length))
|
return "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(length))
|
||||||
|
|
||||||
|
|
||||||
|
@ -117,14 +118,15 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
)
|
)
|
||||||
|
|
||||||
def turn_to_messages(self, turn: Turn) -> list[Message]:
|
def turn_to_messages(self, turn: Turn) -> list[Message]:
|
||||||
messages = []
|
messages: list[Message] = []
|
||||||
|
|
||||||
# NOTE: if a toolcall response is in a step, we do not add it when processing the input messages
|
# NOTE: if a toolcall response is in a step, we do not add it when processing the input messages
|
||||||
tool_call_ids = set()
|
tool_call_ids = set()
|
||||||
for step in turn.steps:
|
for step in turn.steps:
|
||||||
if step.step_type == StepType.tool_execution.value:
|
if step.step_type == StepType.tool_execution.value:
|
||||||
for response in step.tool_responses:
|
if isinstance(step, ToolExecutionStep):
|
||||||
tool_call_ids.add(response.call_id)
|
for response in step.tool_responses:
|
||||||
|
tool_call_ids.add(response.call_id)
|
||||||
|
|
||||||
for m in turn.input_messages:
|
for m in turn.input_messages:
|
||||||
msg = m.model_copy()
|
msg = m.model_copy()
|
||||||
|
@ -142,31 +144,34 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
|
|
||||||
for step in turn.steps:
|
for step in turn.steps:
|
||||||
if step.step_type == StepType.inference.value:
|
if step.step_type == StepType.inference.value:
|
||||||
messages.append(step.model_response)
|
if isinstance(step, InferenceStep):
|
||||||
|
messages.append(step.model_response)
|
||||||
elif step.step_type == StepType.tool_execution.value:
|
elif step.step_type == StepType.tool_execution.value:
|
||||||
for response in step.tool_responses:
|
if isinstance(step, ToolExecutionStep):
|
||||||
messages.append(
|
for response in step.tool_responses:
|
||||||
ToolResponseMessage(
|
messages.append(
|
||||||
call_id=response.call_id,
|
ToolResponseMessage(
|
||||||
content=response.content,
|
call_id=response.call_id,
|
||||||
|
content=response.content,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
|
||||||
elif step.step_type == StepType.shield_call.value:
|
elif step.step_type == StepType.shield_call.value:
|
||||||
if step.violation:
|
if isinstance(step, ShieldCallStep):
|
||||||
# CompletionMessage itself in the ShieldResponse
|
if step.violation:
|
||||||
messages.append(
|
# CompletionMessage itself in the ShieldResponse
|
||||||
CompletionMessage(
|
messages.append(
|
||||||
content=step.violation.user_message,
|
CompletionMessage(
|
||||||
stop_reason=StopReason.end_of_turn,
|
content=step.violation.user_message or "",
|
||||||
|
stop_reason=StopReason.end_of_turn,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
async def create_session(self, name: str) -> str:
|
async def create_session(self, name: str) -> str:
|
||||||
return await self.storage.create_session(name)
|
return await self.storage.create_session(name)
|
||||||
|
|
||||||
async def get_messages_from_turns(self, turns: list[Turn]) -> list[Message]:
|
async def get_messages_from_turns(self, turns: list[Turn]) -> list[Message]:
|
||||||
messages = []
|
messages: list[Message] = []
|
||||||
if self.agent_config.instructions != "":
|
if self.agent_config.instructions != "":
|
||||||
messages.append(SystemMessage(content=self.agent_config.instructions))
|
messages.append(SystemMessage(content=self.agent_config.instructions))
|
||||||
|
|
||||||
|
@ -174,7 +179,9 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
messages.extend(self.turn_to_messages(turn))
|
messages.extend(self.turn_to_messages(turn))
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> AsyncGenerator:
|
async def create_and_execute_turn(
|
||||||
|
self, request: AgentTurnCreateRequest
|
||||||
|
) -> AsyncGenerator[AgentTurnResponseStreamChunk, None]:
|
||||||
span = tracing.get_current_span()
|
span = tracing.get_current_span()
|
||||||
if span:
|
if span:
|
||||||
span.set_attribute("session_id", request.session_id)
|
span.set_attribute("session_id", request.session_id)
|
||||||
|
@ -189,7 +196,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
async for chunk in self._run_turn(request, turn_id):
|
async for chunk in self._run_turn(request, turn_id):
|
||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
async def resume_turn(self, request: AgentTurnResumeRequest) -> AsyncGenerator:
|
async def resume_turn(self, request: AgentTurnResumeRequest) -> AsyncGenerator[AgentTurnResponseStreamChunk, None]:
|
||||||
span = tracing.get_current_span()
|
span = tracing.get_current_span()
|
||||||
if span:
|
if span:
|
||||||
span.set_attribute("agent_id", self.agent_id)
|
span.set_attribute("agent_id", self.agent_id)
|
||||||
|
@ -207,7 +214,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
self,
|
self,
|
||||||
request: AgentTurnCreateRequest | AgentTurnResumeRequest,
|
request: AgentTurnCreateRequest | AgentTurnResumeRequest,
|
||||||
turn_id: str | None = None,
|
turn_id: str | None = None,
|
||||||
) -> AsyncGenerator:
|
) -> AsyncGenerator[AgentTurnResponseStreamChunk, None]:
|
||||||
assert request.stream is True, "Non-streaming not supported"
|
assert request.stream is True, "Non-streaming not supported"
|
||||||
|
|
||||||
is_resume = isinstance(request, AgentTurnResumeRequest)
|
is_resume = isinstance(request, AgentTurnResumeRequest)
|
||||||
|
@ -271,11 +278,13 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
input_messages = request.messages
|
input_messages = request.messages
|
||||||
|
|
||||||
output_message = None
|
output_message = None
|
||||||
|
turn_id_final = turn_id or str(uuid.uuid4())
|
||||||
|
sampling_params = self.agent_config.sampling_params or SamplingParams()
|
||||||
async for chunk in self.run(
|
async for chunk in self.run(
|
||||||
session_id=request.session_id,
|
session_id=request.session_id,
|
||||||
turn_id=turn_id,
|
turn_id=turn_id_final,
|
||||||
input_messages=messages,
|
input_messages=messages,
|
||||||
sampling_params=self.agent_config.sampling_params,
|
sampling_params=sampling_params,
|
||||||
stream=request.stream,
|
stream=request.stream,
|
||||||
documents=request.documents if not is_resume else None,
|
documents=request.documents if not is_resume else None,
|
||||||
):
|
):
|
||||||
|
@ -286,19 +295,20 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
assert isinstance(chunk, AgentTurnResponseStreamChunk), f"Unexpected type {type(chunk)}"
|
assert isinstance(chunk, AgentTurnResponseStreamChunk), f"Unexpected type {type(chunk)}"
|
||||||
event = chunk.event
|
event = chunk.event
|
||||||
if event.payload.event_type == AgentTurnResponseEventType.step_complete.value:
|
if event.payload.event_type == AgentTurnResponseEventType.step_complete.value:
|
||||||
steps.append(event.payload.step_details)
|
if hasattr(event.payload, "step_details"):
|
||||||
|
steps.append(event.payload.step_details)
|
||||||
|
|
||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
assert output_message is not None
|
assert output_message is not None
|
||||||
|
|
||||||
turn = Turn(
|
turn = Turn(
|
||||||
turn_id=turn_id,
|
turn_id=turn_id_final,
|
||||||
session_id=request.session_id,
|
session_id=request.session_id,
|
||||||
input_messages=input_messages,
|
input_messages=input_messages,
|
||||||
output_message=output_message,
|
output_message=output_message,
|
||||||
started_at=start_time,
|
started_at=start_time,
|
||||||
completed_at=datetime.now(UTC).isoformat(),
|
completed_at=datetime.now(UTC),
|
||||||
steps=steps,
|
steps=steps,
|
||||||
)
|
)
|
||||||
await self.storage.add_turn_to_session(request.session_id, turn)
|
await self.storage.add_turn_to_session(request.session_id, turn)
|
||||||
|
@ -329,7 +339,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
sampling_params: SamplingParams,
|
sampling_params: SamplingParams,
|
||||||
stream: bool = False,
|
stream: bool = False,
|
||||||
documents: list[Document] | None = None,
|
documents: list[Document] | None = None,
|
||||||
) -> AsyncGenerator:
|
) -> AsyncGenerator[AgentTurnResponseStreamChunk | CompletionMessage, None]:
|
||||||
# Doing async generators makes downstream code much simpler and everything amenable to
|
# Doing async generators makes downstream code much simpler and everything amenable to
|
||||||
# streaming. However, it also makes things complicated here because AsyncGenerators cannot
|
# streaming. However, it also makes things complicated here because AsyncGenerators cannot
|
||||||
# return a "final value" for the `yield from` statement. we simulate that by yielding a
|
# return a "final value" for the `yield from` statement. we simulate that by yielding a
|
||||||
|
@ -381,7 +391,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
messages: list[Message],
|
messages: list[Message],
|
||||||
shields: list[str],
|
shields: list[str],
|
||||||
touchpoint: str,
|
touchpoint: str,
|
||||||
) -> AsyncGenerator:
|
) -> AsyncGenerator[AgentTurnResponseStreamChunk | CompletionMessage | bool, None]:
|
||||||
async with tracing.span("run_shields") as span:
|
async with tracing.span("run_shields") as span:
|
||||||
span.set_attribute("input", [m.model_dump_json() for m in messages])
|
span.set_attribute("input", [m.model_dump_json() for m in messages])
|
||||||
if len(shields) == 0:
|
if len(shields) == 0:
|
||||||
|
@ -389,12 +399,12 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
return
|
return
|
||||||
|
|
||||||
step_id = str(uuid.uuid4())
|
step_id = str(uuid.uuid4())
|
||||||
shield_call_start_time = datetime.now(UTC).isoformat()
|
shield_call_start_time = datetime.now(UTC)
|
||||||
try:
|
try:
|
||||||
yield AgentTurnResponseStreamChunk(
|
yield AgentTurnResponseStreamChunk(
|
||||||
event=AgentTurnResponseEvent(
|
event=AgentTurnResponseEvent(
|
||||||
payload=AgentTurnResponseStepStartPayload(
|
payload=AgentTurnResponseStepStartPayload(
|
||||||
step_type=StepType.shield_call.value,
|
step_type=StepType.shield_call,
|
||||||
step_id=step_id,
|
step_id=step_id,
|
||||||
metadata=dict(touchpoint=touchpoint),
|
metadata=dict(touchpoint=touchpoint),
|
||||||
)
|
)
|
||||||
|
@ -406,14 +416,14 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
yield AgentTurnResponseStreamChunk(
|
yield AgentTurnResponseStreamChunk(
|
||||||
event=AgentTurnResponseEvent(
|
event=AgentTurnResponseEvent(
|
||||||
payload=AgentTurnResponseStepCompletePayload(
|
payload=AgentTurnResponseStepCompletePayload(
|
||||||
step_type=StepType.shield_call.value,
|
step_type=StepType.shield_call,
|
||||||
step_id=step_id,
|
step_id=step_id,
|
||||||
step_details=ShieldCallStep(
|
step_details=ShieldCallStep(
|
||||||
step_id=step_id,
|
step_id=step_id,
|
||||||
turn_id=turn_id,
|
turn_id=turn_id,
|
||||||
violation=e.violation,
|
violation=e.violation,
|
||||||
started_at=shield_call_start_time,
|
started_at=shield_call_start_time,
|
||||||
completed_at=datetime.now(UTC).isoformat(),
|
completed_at=datetime.now(UTC),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
@ -429,14 +439,14 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
yield AgentTurnResponseStreamChunk(
|
yield AgentTurnResponseStreamChunk(
|
||||||
event=AgentTurnResponseEvent(
|
event=AgentTurnResponseEvent(
|
||||||
payload=AgentTurnResponseStepCompletePayload(
|
payload=AgentTurnResponseStepCompletePayload(
|
||||||
step_type=StepType.shield_call.value,
|
step_type=StepType.shield_call,
|
||||||
step_id=step_id,
|
step_id=step_id,
|
||||||
step_details=ShieldCallStep(
|
step_details=ShieldCallStep(
|
||||||
step_id=step_id,
|
step_id=step_id,
|
||||||
turn_id=turn_id,
|
turn_id=turn_id,
|
||||||
violation=None,
|
violation=None,
|
||||||
started_at=shield_call_start_time,
|
started_at=shield_call_start_time,
|
||||||
completed_at=datetime.now(UTC).isoformat(),
|
completed_at=datetime.now(UTC),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
@ -451,7 +461,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
sampling_params: SamplingParams,
|
sampling_params: SamplingParams,
|
||||||
stream: bool = False,
|
stream: bool = False,
|
||||||
documents: list[Document] | None = None,
|
documents: list[Document] | None = None,
|
||||||
) -> AsyncGenerator:
|
) -> AsyncGenerator[AgentTurnResponseStreamChunk | CompletionMessage | bool, None]:
|
||||||
# if document is passed in a turn, we parse the raw text of the document
|
# if document is passed in a turn, we parse the raw text of the document
|
||||||
# and sent it as a user message
|
# and sent it as a user message
|
||||||
if documents:
|
if documents:
|
||||||
|
@ -481,43 +491,46 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
else:
|
else:
|
||||||
self.tool_name_to_args[tool_name]["vector_db_ids"].append(session_info.vector_db_id)
|
self.tool_name_to_args[tool_name]["vector_db_ids"].append(session_info.vector_db_id)
|
||||||
|
|
||||||
output_attachments = []
|
output_attachments: list[Attachment] = []
|
||||||
|
|
||||||
n_iter = await self.storage.get_num_infer_iters_in_turn(session_id, turn_id) or 0
|
n_iter = await self.storage.get_num_infer_iters_in_turn(session_id, turn_id) or 0
|
||||||
|
|
||||||
# Build a map of custom tools to their definitions for faster lookup
|
# Build a map of custom tools to their definitions for faster lookup
|
||||||
client_tools = {}
|
client_tools: dict[str, Any] = {}
|
||||||
for tool in self.agent_config.client_tools:
|
for tool in self.agent_config.client_tools or []:
|
||||||
client_tools[tool.name] = tool
|
client_tools[tool.name] = tool
|
||||||
while True:
|
while True:
|
||||||
step_id = str(uuid.uuid4())
|
step_id = str(uuid.uuid4())
|
||||||
inference_start_time = datetime.now(UTC).isoformat()
|
inference_start_time = datetime.now(UTC)
|
||||||
yield AgentTurnResponseStreamChunk(
|
yield AgentTurnResponseStreamChunk(
|
||||||
event=AgentTurnResponseEvent(
|
event=AgentTurnResponseEvent(
|
||||||
payload=AgentTurnResponseStepStartPayload(
|
payload=AgentTurnResponseStepStartPayload(
|
||||||
step_type=StepType.inference.value,
|
step_type=StepType.inference,
|
||||||
step_id=step_id,
|
step_id=step_id,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
tool_calls = []
|
tool_calls: list[ToolCall] = []
|
||||||
content = ""
|
content = ""
|
||||||
stop_reason = None
|
stop_reason = None
|
||||||
|
|
||||||
async with tracing.span("inference") as span:
|
async with tracing.span("inference") as span:
|
||||||
if self.agent_config.name:
|
if self.agent_config.name:
|
||||||
span.set_attribute("agent_name", self.agent_config.name)
|
span.set_attribute("agent_name", self.agent_config.name)
|
||||||
async for chunk in await self.inference_api.chat_completion(
|
chat_completion_response = await self.inference_api.chat_completion(
|
||||||
self.agent_config.model,
|
self.agent_config.model,
|
||||||
input_messages,
|
input_messages,
|
||||||
tools=self.tool_defs,
|
tools=self.tool_defs,
|
||||||
tool_prompt_format=self.agent_config.tool_config.tool_prompt_format,
|
tool_prompt_format=self.agent_config.tool_config.tool_prompt_format
|
||||||
|
if self.agent_config.tool_config
|
||||||
|
else None,
|
||||||
response_format=self.agent_config.response_format,
|
response_format=self.agent_config.response_format,
|
||||||
stream=True,
|
stream=True,
|
||||||
sampling_params=sampling_params,
|
sampling_params=sampling_params,
|
||||||
tool_config=self.agent_config.tool_config,
|
tool_config=self.agent_config.tool_config,
|
||||||
):
|
)
|
||||||
|
async for chunk in chat_completion_response:
|
||||||
event = chunk.event
|
event = chunk.event
|
||||||
if event.event_type == ChatCompletionResponseEventType.start:
|
if event.event_type == ChatCompletionResponseEventType.start:
|
||||||
continue
|
continue
|
||||||
|
@ -527,16 +540,18 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
|
|
||||||
delta = event.delta
|
delta = event.delta
|
||||||
if delta.type == "tool_call":
|
if delta.type == "tool_call":
|
||||||
if delta.parse_status == ToolCallParseStatus.succeeded:
|
if hasattr(delta, "parse_status") and delta.parse_status == ToolCallParseStatus.succeeded:
|
||||||
tool_calls.append(delta.tool_call)
|
if hasattr(delta, "tool_call"):
|
||||||
elif delta.parse_status == ToolCallParseStatus.failed:
|
tool_calls.append(delta.tool_call)
|
||||||
|
elif hasattr(delta, "parse_status") and delta.parse_status == ToolCallParseStatus.failed:
|
||||||
# If we cannot parse the tools, set the content to the unparsed raw text
|
# If we cannot parse the tools, set the content to the unparsed raw text
|
||||||
content = delta.tool_call
|
if hasattr(delta, "tool_call"):
|
||||||
|
content = str(delta.tool_call)
|
||||||
if stream:
|
if stream:
|
||||||
yield AgentTurnResponseStreamChunk(
|
yield AgentTurnResponseStreamChunk(
|
||||||
event=AgentTurnResponseEvent(
|
event=AgentTurnResponseEvent(
|
||||||
payload=AgentTurnResponseStepProgressPayload(
|
payload=AgentTurnResponseStepProgressPayload(
|
||||||
step_type=StepType.inference.value,
|
step_type=StepType.inference,
|
||||||
step_id=step_id,
|
step_id=step_id,
|
||||||
delta=delta,
|
delta=delta,
|
||||||
)
|
)
|
||||||
|
@ -544,12 +559,13 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
)
|
)
|
||||||
|
|
||||||
elif delta.type == "text":
|
elif delta.type == "text":
|
||||||
content += delta.text
|
if hasattr(delta, "text"):
|
||||||
|
content += delta.text
|
||||||
if stream and event.stop_reason is None:
|
if stream and event.stop_reason is None:
|
||||||
yield AgentTurnResponseStreamChunk(
|
yield AgentTurnResponseStreamChunk(
|
||||||
event=AgentTurnResponseEvent(
|
event=AgentTurnResponseEvent(
|
||||||
payload=AgentTurnResponseStepProgressPayload(
|
payload=AgentTurnResponseStepProgressPayload(
|
||||||
step_type=StepType.inference.value,
|
step_type=StepType.inference,
|
||||||
step_id=step_id,
|
step_id=step_id,
|
||||||
delta=delta,
|
delta=delta,
|
||||||
)
|
)
|
||||||
|
@ -565,10 +581,16 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
"input",
|
"input",
|
||||||
json.dumps([json.loads(m.model_dump_json()) for m in input_messages]),
|
json.dumps([json.loads(m.model_dump_json()) for m in input_messages]),
|
||||||
)
|
)
|
||||||
|
tool_calls_json = []
|
||||||
|
for t in tool_calls:
|
||||||
|
if hasattr(t, "model_dump_json"):
|
||||||
|
tool_calls_json.append(json.loads(t.model_dump_json()))
|
||||||
|
else:
|
||||||
|
tool_calls_json.append(str(t))
|
||||||
output_attr = json.dumps(
|
output_attr = json.dumps(
|
||||||
{
|
{
|
||||||
"content": content,
|
"content": content,
|
||||||
"tool_calls": [json.loads(t.model_dump_json()) for t in tool_calls],
|
"tool_calls": tool_calls_json,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
span.set_attribute("output", output_attr)
|
span.set_attribute("output", output_attr)
|
||||||
|
@ -593,7 +615,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
yield AgentTurnResponseStreamChunk(
|
yield AgentTurnResponseStreamChunk(
|
||||||
event=AgentTurnResponseEvent(
|
event=AgentTurnResponseEvent(
|
||||||
payload=AgentTurnResponseStepCompletePayload(
|
payload=AgentTurnResponseStepCompletePayload(
|
||||||
step_type=StepType.inference.value,
|
step_type=StepType.inference,
|
||||||
step_id=step_id,
|
step_id=step_id,
|
||||||
step_details=InferenceStep(
|
step_details=InferenceStep(
|
||||||
# somewhere deep, we are re-assigning message or closing over some
|
# somewhere deep, we are re-assigning message or closing over some
|
||||||
|
@ -603,13 +625,14 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
turn_id=turn_id,
|
turn_id=turn_id,
|
||||||
model_response=copy.deepcopy(message),
|
model_response=copy.deepcopy(message),
|
||||||
started_at=inference_start_time,
|
started_at=inference_start_time,
|
||||||
completed_at=datetime.now(UTC).isoformat(),
|
completed_at=datetime.now(UTC),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
if n_iter >= self.agent_config.max_infer_iters:
|
max_infer_iters = self.agent_config.max_infer_iters or 10
|
||||||
|
if n_iter >= max_infer_iters:
|
||||||
logger.info(f"done with MAX iterations ({n_iter}), exiting.")
|
logger.info(f"done with MAX iterations ({n_iter}), exiting.")
|
||||||
# NOTE: mark end_of_turn to indicate to client that we are done with the turn
|
# NOTE: mark end_of_turn to indicate to client that we are done with the turn
|
||||||
# Do not continue the tool call loop after this point
|
# Do not continue the tool call loop after this point
|
||||||
|
@ -622,7 +645,8 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
yield message
|
yield message
|
||||||
break
|
break
|
||||||
|
|
||||||
if len(message.tool_calls) == 0:
|
tool_calls_to_process = message.tool_calls or []
|
||||||
|
if len(tool_calls_to_process) == 0:
|
||||||
if stop_reason == StopReason.end_of_turn:
|
if stop_reason == StopReason.end_of_turn:
|
||||||
# TODO: UPDATE RETURN TYPE TO SEND A TUPLE OF (MESSAGE, ATTACHMENTS)
|
# TODO: UPDATE RETURN TYPE TO SEND A TUPLE OF (MESSAGE, ATTACHMENTS)
|
||||||
if len(output_attachments) > 0:
|
if len(output_attachments) > 0:
|
||||||
|
@ -642,7 +666,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
non_client_tool_calls = []
|
non_client_tool_calls = []
|
||||||
|
|
||||||
# Separate client and non-client tool calls
|
# Separate client and non-client tool calls
|
||||||
for tool_call in message.tool_calls:
|
for tool_call in tool_calls_to_process:
|
||||||
if tool_call.tool_name in client_tools:
|
if tool_call.tool_name in client_tools:
|
||||||
client_tool_calls.append(tool_call)
|
client_tool_calls.append(tool_call)
|
||||||
else:
|
else:
|
||||||
|
@ -654,7 +678,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
yield AgentTurnResponseStreamChunk(
|
yield AgentTurnResponseStreamChunk(
|
||||||
event=AgentTurnResponseEvent(
|
event=AgentTurnResponseEvent(
|
||||||
payload=AgentTurnResponseStepStartPayload(
|
payload=AgentTurnResponseStepStartPayload(
|
||||||
step_type=StepType.tool_execution.value,
|
step_type=StepType.tool_execution,
|
||||||
step_id=step_id,
|
step_id=step_id,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
@ -663,7 +687,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
yield AgentTurnResponseStreamChunk(
|
yield AgentTurnResponseStreamChunk(
|
||||||
event=AgentTurnResponseEvent(
|
event=AgentTurnResponseEvent(
|
||||||
payload=AgentTurnResponseStepProgressPayload(
|
payload=AgentTurnResponseStepProgressPayload(
|
||||||
step_type=StepType.tool_execution.value,
|
step_type=StepType.tool_execution,
|
||||||
step_id=step_id,
|
step_id=step_id,
|
||||||
delta=ToolCallDelta(
|
delta=ToolCallDelta(
|
||||||
parse_status=ToolCallParseStatus.in_progress,
|
parse_status=ToolCallParseStatus.in_progress,
|
||||||
|
@ -681,7 +705,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
"input": message.model_dump_json(),
|
"input": message.model_dump_json(),
|
||||||
},
|
},
|
||||||
) as span:
|
) as span:
|
||||||
tool_execution_start_time = datetime.now(UTC).isoformat()
|
tool_execution_start_time = datetime.now(UTC)
|
||||||
tool_result = await self.execute_tool_call_maybe(
|
tool_result = await self.execute_tool_call_maybe(
|
||||||
session_id,
|
session_id,
|
||||||
tool_call,
|
tool_call,
|
||||||
|
@ -710,14 +734,14 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
started_at=tool_execution_start_time,
|
started_at=tool_execution_start_time,
|
||||||
completed_at=datetime.now(UTC).isoformat(),
|
completed_at=datetime.now(UTC),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Yield the step completion event
|
# Yield the step completion event
|
||||||
yield AgentTurnResponseStreamChunk(
|
yield AgentTurnResponseStreamChunk(
|
||||||
event=AgentTurnResponseEvent(
|
event=AgentTurnResponseEvent(
|
||||||
payload=AgentTurnResponseStepCompletePayload(
|
payload=AgentTurnResponseStepCompletePayload(
|
||||||
step_type=StepType.tool_execution.value,
|
step_type=StepType.tool_execution,
|
||||||
step_id=step_id,
|
step_id=step_id,
|
||||||
step_details=tool_execution_step,
|
step_details=tool_execution_step,
|
||||||
)
|
)
|
||||||
|
@ -747,7 +771,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
turn_id=turn_id,
|
turn_id=turn_id,
|
||||||
tool_calls=client_tool_calls,
|
tool_calls=client_tool_calls,
|
||||||
tool_responses=[],
|
tool_responses=[],
|
||||||
started_at=datetime.now(UTC).isoformat(),
|
started_at=datetime.now(UTC),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -766,7 +790,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
self,
|
self,
|
||||||
toolgroups_for_turn: list[AgentToolGroup] | None = None,
|
toolgroups_for_turn: list[AgentToolGroup] | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
toolgroup_to_args = {}
|
toolgroup_to_args: dict[str, Any] = {}
|
||||||
for toolgroup in (self.agent_config.toolgroups or []) + (toolgroups_for_turn or []):
|
for toolgroup in (self.agent_config.toolgroups or []) + (toolgroups_for_turn or []):
|
||||||
if isinstance(toolgroup, AgentToolGroupWithArgs):
|
if isinstance(toolgroup, AgentToolGroupWithArgs):
|
||||||
tool_group_name, _ = self._parse_toolgroup_name(toolgroup.name)
|
tool_group_name, _ = self._parse_toolgroup_name(toolgroup.name)
|
||||||
|
@ -782,10 +806,10 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
|
|
||||||
toolgroup_to_args = toolgroup_to_args or {}
|
toolgroup_to_args = toolgroup_to_args or {}
|
||||||
|
|
||||||
tool_name_to_def = {}
|
tool_name_to_def: dict[str, ToolDefinition] = {}
|
||||||
tool_name_to_args = {}
|
tool_name_to_args: dict[str, Any] = {}
|
||||||
|
|
||||||
for tool_def in self.agent_config.client_tools:
|
for tool_def in self.agent_config.client_tools or []:
|
||||||
if tool_name_to_def.get(tool_def.name, None):
|
if tool_name_to_def.get(tool_def.name, None):
|
||||||
raise ValueError(f"Tool {tool_def.name} already exists")
|
raise ValueError(f"Tool {tool_def.name} already exists")
|
||||||
tool_name_to_def[tool_def.name] = ToolDefinition(
|
tool_name_to_def[tool_def.name] = ToolDefinition(
|
||||||
|
@ -798,7 +822,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
required=param.required,
|
required=param.required,
|
||||||
default=param.default,
|
default=param.default,
|
||||||
)
|
)
|
||||||
for param in tool_def.parameters
|
for param in tool_def.parameters or []
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
for toolgroup_name_with_maybe_tool_name in agent_config_toolgroups:
|
for toolgroup_name_with_maybe_tool_name in agent_config_toolgroups:
|
||||||
|
@ -828,7 +852,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
else:
|
else:
|
||||||
identifier = None
|
identifier = None
|
||||||
|
|
||||||
if tool_name_to_def.get(identifier, None):
|
if identifier is not None and tool_name_to_def.get(str(identifier), None):
|
||||||
raise ValueError(f"Tool {identifier} already exists")
|
raise ValueError(f"Tool {identifier} already exists")
|
||||||
if identifier:
|
if identifier:
|
||||||
tool_name_to_def[tool_def.identifier] = ToolDefinition(
|
tool_name_to_def[tool_def.identifier] = ToolDefinition(
|
||||||
|
@ -841,7 +865,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
required=param.required,
|
required=param.required,
|
||||||
default=param.default,
|
default=param.default,
|
||||||
)
|
)
|
||||||
for param in tool_def.parameters
|
for param in tool_def.parameters or []
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
tool_name_to_args[tool_def.identifier] = toolgroup_to_args.get(toolgroup_name, {})
|
tool_name_to_args[tool_def.identifier] = toolgroup_to_args.get(toolgroup_name, {})
|
||||||
|
@ -888,14 +912,15 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
tool_name_str = tool_name
|
tool_name_str = tool_name
|
||||||
|
|
||||||
logger.info(f"executing tool call: {tool_name_str} with args: {tool_call.arguments}")
|
logger.info(f"executing tool call: {tool_name_str} with args: {tool_call.arguments}")
|
||||||
|
kwargs_dict = {"session_id": session_id}
|
||||||
|
if tool_call.arguments and isinstance(tool_call.arguments, dict):
|
||||||
|
kwargs_dict.update(tool_call.arguments)
|
||||||
|
tool_args = self.tool_name_to_args.get(tool_name_str, {})
|
||||||
|
if tool_args:
|
||||||
|
kwargs_dict.update(tool_args)
|
||||||
result = await self.tool_runtime_api.invoke_tool(
|
result = await self.tool_runtime_api.invoke_tool(
|
||||||
tool_name=tool_name_str,
|
tool_name=tool_name_str,
|
||||||
kwargs={
|
kwargs=kwargs_dict,
|
||||||
"session_id": session_id,
|
|
||||||
# get the arguments generated by the model and augment with toolgroup arg overrides for the agent
|
|
||||||
**tool_call.arguments,
|
|
||||||
**self.tool_name_to_args.get(tool_name_str, {}),
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
logger.debug(f"tool call {tool_name_str} completed with result: {result}")
|
logger.debug(f"tool call {tool_name_str} completed with result: {result}")
|
||||||
return result
|
return result
|
||||||
|
@ -931,7 +956,7 @@ def _interpret_content_as_attachment(
|
||||||
snippet = match.group(1)
|
snippet = match.group(1)
|
||||||
data = json.loads(snippet)
|
data = json.loads(snippet)
|
||||||
return Attachment(
|
return Attachment(
|
||||||
url=URL(uri="file://" + data["filepath"]),
|
content=URL(uri="file://" + data["filepath"]),
|
||||||
mime_type=data["mimetype"],
|
mime_type=data["mimetype"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -28,8 +28,8 @@ class ShieldRunnerMixin:
|
||||||
output_shields: list[str] | None = None,
|
output_shields: list[str] | None = None,
|
||||||
):
|
):
|
||||||
self.safety_api = safety_api
|
self.safety_api = safety_api
|
||||||
self.input_shields = input_shields
|
self.input_shields = input_shields or []
|
||||||
self.output_shields = output_shields
|
self.output_shields = output_shields or []
|
||||||
|
|
||||||
async def run_multiple_shields(self, messages: list[Message], identifiers: list[str]) -> None:
|
async def run_multiple_shields(self, messages: list[Message], identifiers: list[str]) -> None:
|
||||||
async def run_shield_with_span(identifier: str):
|
async def run_shield_with_span(identifier: str):
|
||||||
|
|
|
@ -243,7 +243,6 @@ exclude = [
|
||||||
"^llama_stack/models/llama/llama3/tokenizer\\.py$",
|
"^llama_stack/models/llama/llama3/tokenizer\\.py$",
|
||||||
"^llama_stack/models/llama/llama3/tool_utils\\.py$",
|
"^llama_stack/models/llama/llama3/tool_utils\\.py$",
|
||||||
"^llama_stack/providers/inline/agents/meta_reference/",
|
"^llama_stack/providers/inline/agents/meta_reference/",
|
||||||
"^llama_stack/providers/inline/agents/meta_reference/agent_instance\\.py$",
|
|
||||||
"^llama_stack/providers/inline/agents/meta_reference/agents\\.py$",
|
"^llama_stack/providers/inline/agents/meta_reference/agents\\.py$",
|
||||||
"^llama_stack/providers/inline/datasetio/localfs/",
|
"^llama_stack/providers/inline/datasetio/localfs/",
|
||||||
"^llama_stack/providers/inline/eval/meta_reference/eval\\.py$",
|
"^llama_stack/providers/inline/eval/meta_reference/eval\\.py$",
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue