This commit is contained in:
Xi Yan 2025-03-20 16:10:59 -07:00
parent 1f04ca357b
commit 0b1e71718c
2 changed files with 448 additions and 156 deletions

View file

@ -40,10 +40,10 @@ from llama_stack.apis.agents import (
Turn,
)
from llama_stack.apis.common.content_types import (
URL,
TextContentItem,
ToolCallDelta,
ToolCallParseStatus,
URL,
)
from llama_stack.apis.inference import (
ChatCompletionResponseEventType,
@ -80,7 +80,9 @@ from .safety import SafetyException, ShieldRunnerMixin
def make_random_string(length: int = 8) -> str:
    """Return a random string of ``length`` ASCII letters and digits.

    Characters are picked with ``secrets.choice``; a ``length`` of 0 yields
    an empty string.
    """
    alphabet = string.ascii_letters + string.digits
    return "".join(secrets.choice(alphabet) for _ in range(length))
TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})")
@ -179,7 +181,9 @@ class ChatAgent(ShieldRunnerMixin):
messages.extend(self.turn_to_messages(turn))
return messages
async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> AsyncGenerator:
async def create_and_execute_turn(
self, request: AgentTurnCreateRequest
) -> AsyncGenerator:
await self._initialize_tools(request.toolgroups)
async with tracing.span("create_and_execute_turn") as span:
span.set_attribute("session_id", request.session_id)
@ -220,13 +224,16 @@ class ChatAgent(ShieldRunnerMixin):
messages = await self.get_messages_from_turns(turns)
if is_resume:
tool_response_messages = [
ToolResponseMessage(call_id=x.call_id, content=x.content) for x in request.tool_responses
ToolResponseMessage(call_id=x.call_id, content=x.content)
for x in request.tool_responses
]
messages.extend(tool_response_messages)
last_turn = turns[-1]
last_turn_messages = self.turn_to_messages(last_turn)
last_turn_messages = [
x for x in last_turn_messages if isinstance(x, UserMessage) or isinstance(x, ToolResponseMessage)
x
for x in last_turn_messages
if isinstance(x, UserMessage) or isinstance(x, ToolResponseMessage)
]
last_turn_messages.extend(tool_response_messages)
@ -236,17 +243,31 @@ class ChatAgent(ShieldRunnerMixin):
# mark tool execution step as complete
# if there's no tool execution in progress step (due to storage, or tool call parsing on client),
# we'll create a new tool execution step with current time
in_progress_tool_call_step = await self.storage.get_in_progress_tool_call_step(
request.session_id, request.turn_id
in_progress_tool_call_step = (
await self.storage.get_in_progress_tool_call_step(
request.session_id, request.turn_id
)
)
now = datetime.now(timezone.utc).isoformat()
tool_execution_step = ToolExecutionStep(
step_id=(in_progress_tool_call_step.step_id if in_progress_tool_call_step else str(uuid.uuid4())),
step_id=(
in_progress_tool_call_step.step_id
if in_progress_tool_call_step
else str(uuid.uuid4())
),
turn_id=request.turn_id,
tool_calls=(in_progress_tool_call_step.tool_calls if in_progress_tool_call_step else []),
tool_calls=(
in_progress_tool_call_step.tool_calls
if in_progress_tool_call_step
else []
),
tool_responses=request.tool_responses,
completed_at=now,
started_at=(in_progress_tool_call_step.started_at if in_progress_tool_call_step else now),
started_at=(
in_progress_tool_call_step.started_at
if in_progress_tool_call_step
else now
),
)
steps.append(tool_execution_step)
yield AgentTurnResponseStreamChunk(
@ -280,9 +301,14 @@ class ChatAgent(ShieldRunnerMixin):
output_message = chunk
continue
assert isinstance(chunk, AgentTurnResponseStreamChunk), f"Unexpected type {type(chunk)}"
assert isinstance(
chunk, AgentTurnResponseStreamChunk
), f"Unexpected type {type(chunk)}"
event = chunk.event
if event.payload.event_type == AgentTurnResponseEventType.step_complete.value:
if (
event.payload.event_type
== AgentTurnResponseEventType.step_complete.value
):
steps.append(event.payload.step_details)
yield chunk
@ -440,6 +466,18 @@ class ChatAgent(ShieldRunnerMixin):
)
span.set_attribute("output", "no violations")
async def get_raw_document_text(self, document: Document) -> str:
    """Return the raw text carried by ``document``.

    Args:
        document: Document whose ``content`` is a plain string, a ``URL``
            or a ``TextContentItem``.

    Returns:
        The string itself, the text fetched through ``load_data_from_url``
        for a ``URL``, or the ``text`` attribute of a ``TextContentItem``.

    Raises:
        ValueError: If ``document.content`` is of any other type.
    """
    content = document.content
    # The three accepted types are distinct, so the check order does not
    # change the result; the plain-string case is tested first as the cheapest.
    if isinstance(content, str):
        return content
    if isinstance(content, URL):
        return await load_data_from_url(content)
    if isinstance(content, TextContentItem):
        return content.text
    raise ValueError(f"Unexpected document content type: {type(content)}")
async def _run(
self,
session_id: str,
@ -449,8 +487,23 @@ class ChatAgent(ShieldRunnerMixin):
stream: bool = False,
documents: Optional[List[Document]] = None,
) -> AsyncGenerator:
# if documents:
# await self.handle_documents(session_id, documents, input_messages)
# if documents are passed in a turn, we extract the raw text of each document
# and send it to the model as an additional input message
if documents:
await self.handle_documents(session_id, documents, input_messages)
contexts = []
for document in documents:
raw_document_text = await self.get_raw_document_text(document)
contexts.append(TextContentItem(text=raw_document_text))
# append the document texts to the input messages as a tool response message
input_messages.append(
ToolResponseMessage(
call_id=str(uuid.uuid4()),
content=contexts,
)
)
session_info = await self.storage.get_session_info(session_id)
# if the session has a memory bank id, let the memory tool use it
@ -458,13 +511,19 @@ class ChatAgent(ShieldRunnerMixin):
for tool_name in self.tool_name_to_args.keys():
if tool_name == MEMORY_QUERY_TOOL:
if "vector_db_ids" not in self.tool_name_to_args[tool_name]:
self.tool_name_to_args[tool_name]["vector_db_ids"] = [session_info.vector_db_id]
self.tool_name_to_args[tool_name]["vector_db_ids"] = [
session_info.vector_db_id
]
else:
self.tool_name_to_args[tool_name]["vector_db_ids"].append(session_info.vector_db_id)
self.tool_name_to_args[tool_name]["vector_db_ids"].append(
session_info.vector_db_id
)
output_attachments = []
n_iter = await self.storage.get_num_infer_iters_in_turn(session_id, turn_id) or 0
n_iter = (
await self.storage.get_num_infer_iters_in_turn(session_id, turn_id) or 0
)
# Build a map of custom tools to their definitions for faster lookup
client_tools = {}
@ -487,6 +546,9 @@ class ChatAgent(ShieldRunnerMixin):
stop_reason = None
async with tracing.span("inference") as span:
from rich.pretty import pprint
pprint(input_messages)
async for chunk in await self.inference_api.chat_completion(
self.agent_config.model,
input_messages,
@ -542,12 +604,16 @@ class ChatAgent(ShieldRunnerMixin):
span.set_attribute("stop_reason", stop_reason)
span.set_attribute(
"input",
json.dumps([json.loads(m.model_dump_json()) for m in input_messages]),
json.dumps(
[json.loads(m.model_dump_json()) for m in input_messages]
),
)
output_attr = json.dumps(
{
"content": content,
"tool_calls": [json.loads(t.model_dump_json()) for t in tool_calls],
"tool_calls": [
json.loads(t.model_dump_json()) for t in tool_calls
],
}
)
span.set_attribute("output", output_attr)
@ -611,7 +677,9 @@ class ChatAgent(ShieldRunnerMixin):
message.content = [message.content] + output_attachments
yield message
else:
logger.debug(f"completion message with EOM (iter: {n_iter}): {str(message)}")
logger.debug(
f"completion message with EOM (iter: {n_iter}): {str(message)}"
)
input_messages = input_messages + [message]
else:
input_messages = input_messages + [message]
@ -660,7 +728,9 @@ class ChatAgent(ShieldRunnerMixin):
"input": message.model_dump_json(),
},
) as span:
tool_execution_start_time = datetime.now(timezone.utc).isoformat()
tool_execution_start_time = datetime.now(
timezone.utc
).isoformat()
tool_result = await self.execute_tool_call_maybe(
session_id,
tool_call,
@ -709,7 +779,9 @@ class ChatAgent(ShieldRunnerMixin):
# TODO: add tool-input touchpoint and a "start" event for this step also
# but that needs a lot more refactoring of Tool code potentially
if (type(result_message.content) is str) and (
out_attachment := _interpret_content_as_attachment(result_message.content)
out_attachment := _interpret_content_as_attachment(
result_message.content
)
):
# NOTE: when we push this message back to the model, the model may ignore the
# attached file path etc. since the model is trained to only provide a user message
@ -746,16 +818,24 @@ class ChatAgent(ShieldRunnerMixin):
toolgroups_for_turn: Optional[List[AgentToolGroup]] = None,
) -> None:
toolgroup_to_args = {}
for toolgroup in (self.agent_config.toolgroups or []) + (toolgroups_for_turn or []):
for toolgroup in (self.agent_config.toolgroups or []) + (
toolgroups_for_turn or []
):
if isinstance(toolgroup, AgentToolGroupWithArgs):
tool_group_name, _ = self._parse_toolgroup_name(toolgroup.name)
toolgroup_to_args[tool_group_name] = toolgroup.args
# Determine which tools to include
tool_groups_to_include = toolgroups_for_turn or self.agent_config.toolgroups or []
tool_groups_to_include = (
toolgroups_for_turn or self.agent_config.toolgroups or []
)
agent_config_toolgroups = []
for toolgroup in tool_groups_to_include:
name = toolgroup.name if isinstance(toolgroup, AgentToolGroupWithArgs) else toolgroup
name = (
toolgroup.name
if isinstance(toolgroup, AgentToolGroupWithArgs)
else toolgroup
)
if name not in agent_config_toolgroups:
agent_config_toolgroups.append(name)
@ -781,20 +861,32 @@ class ChatAgent(ShieldRunnerMixin):
},
)
for toolgroup_name_with_maybe_tool_name in agent_config_toolgroups:
toolgroup_name, input_tool_name = self._parse_toolgroup_name(toolgroup_name_with_maybe_tool_name)
toolgroup_name, input_tool_name = self._parse_toolgroup_name(
toolgroup_name_with_maybe_tool_name
)
tools = await self.tool_groups_api.list_tools(toolgroup_id=toolgroup_name)
if not tools.data:
available_tool_groups = ", ".join(
[t.identifier for t in (await self.tool_groups_api.list_tool_groups()).data]
[
t.identifier
for t in (await self.tool_groups_api.list_tool_groups()).data
]
)
raise ValueError(f"Toolgroup {toolgroup_name} not found, available toolgroups: {available_tool_groups}")
if input_tool_name is not None and not any(tool.identifier == input_tool_name for tool in tools.data):
raise ValueError(
f"Toolgroup {toolgroup_name} not found, available toolgroups: {available_tool_groups}"
)
if input_tool_name is not None and not any(
tool.identifier == input_tool_name for tool in tools.data
):
raise ValueError(
f"Tool {input_tool_name} not found in toolgroup {toolgroup_name}. Available tools: {', '.join([tool.identifier for tool in tools.data])}"
)
for tool_def in tools.data:
if toolgroup_name.startswith("builtin") and toolgroup_name != RAG_TOOL_GROUP:
if (
toolgroup_name.startswith("builtin")
and toolgroup_name != RAG_TOOL_GROUP
):
identifier: str | BuiltinTool | None = tool_def.identifier
if identifier == "web_search":
identifier = BuiltinTool.brave_search
@ -823,11 +915,18 @@ class ChatAgent(ShieldRunnerMixin):
for param in tool_def.parameters
},
)
tool_name_to_args[tool_def.identifier] = toolgroup_to_args.get(toolgroup_name, {})
tool_name_to_args[tool_def.identifier] = toolgroup_to_args.get(
toolgroup_name, {}
)
self.tool_defs, self.tool_name_to_args = list(tool_name_to_def.values()), tool_name_to_args
self.tool_defs, self.tool_name_to_args = (
list(tool_name_to_def.values()),
tool_name_to_args,
)
def _parse_toolgroup_name(self, toolgroup_name_with_maybe_tool_name: str) -> tuple[str, Optional[str]]:
def _parse_toolgroup_name(
self, toolgroup_name_with_maybe_tool_name: str
) -> tuple[str, Optional[str]]:
"""Parse a toolgroup name into its components.
Args:
@ -863,7 +962,9 @@ class ChatAgent(ShieldRunnerMixin):
else:
tool_name_str = tool_name
logger.info(f"executing tool call: {tool_name_str} with args: {tool_call.arguments}")
logger.info(
f"executing tool call: {tool_name_str} with args: {tool_call.arguments}"
)
result = await self.tool_runtime_api.invoke_tool(
tool_name=tool_name_str,
kwargs={
@ -876,144 +977,142 @@ class ChatAgent(ShieldRunnerMixin):
logger.debug(f"tool call {tool_name_str} completed with result: {result}")
return result
async def handle_documents(
    self,
    session_id: str,
    documents: List[Document],
    input_messages: List[Message],
) -> None:
    """Make the turn's documents available through the configured tools.

    Documents are first split into URL items (a ``URL`` instance, or string
    content starting with ``http(s)://``, ``file://`` or ``data:``) and
    inline content items. Then:

    * if a code-interpreter tool is defined, inline contents are written to
      files under ``self.tempdir`` and, together with the URL items, are
      announced on the last input message via ``attachment_message``;
    * if a tool named ``MEMORY_QUERY_TOOL`` is defined, all documents are
      passed to ``add_to_session_vector_db``;
    * if neither tool is defined, the inline contents and the text loaded
      from the URL items are joined with newlines and stored as the
      ``context`` of the last input message.

    Args:
        session_id: Session the documents belong to.
        documents: Documents supplied with the turn.
        input_messages: Messages of the turn; the last one is modified.
    """
    memory_tool = any(
        tool_def.tool_name == MEMORY_QUERY_TOOL for tool_def in self.tool_defs
    )
    code_interpreter_tool = any(
        tool_def.tool_name == BuiltinTool.code_interpreter
        for tool_def in self.tool_defs
    )

    content_items = []
    url_items = []
    pattern = re.compile("^(https?://|file://|data:)")
    for d in documents:
        if isinstance(d.content, URL):
            url_items.append(d.content)
        # NOTE(review): assumes any non-URL content is a str; other content
        # types would fail in the regex match below -- confirm with callers.
        elif pattern.match(d.content):
            url_items.append(URL(uri=d.content))
        else:
            content_items.append(d)

    if code_interpreter_tool:
        # Save the inline contents to the tempdir and reference them as file
        # URLs so they can be announced alongside the other URL items.
        for c in content_items:
            temp_file_path = os.path.join(
                self.tempdir, f"{make_random_string()}.txt"
            )
            with open(temp_file_path, "w") as temp_file:
                temp_file.write(c.content)
            url_items.append(URL(uri=f"file://{temp_file_path}"))

        # Download the URLs and attach their local paths to the last message,
        # with the assumption that the model invokes the code_interpreter
        # tool with those paths.
        await attachment_message(self.tempdir, url_items, input_messages[-1])

    if memory_tool:
        # Since memory is present, add all the data to the memory bank.
        await self.add_to_session_vector_db(session_id, documents)

    if not code_interpreter_tool and not memory_tool:
        # Without either tool, load everything as text and pass it to
        # inference as the context of the last message.
        inline_texts = [doc.content for doc in content_items]
        url_texts = await load_data_from_urls(url_items)
        input_messages[-1].context = "\n".join(inline_texts + url_texts)
async def _ensure_vector_db(self, session_id: str) -> str:
session_info = await self.storage.get_session_info(session_id)
if session_info is None:
raise ValueError(f"Session {session_id} not found")
# async def _ensure_vector_db(self, session_id: str) -> str:
# session_info = await self.storage.get_session_info(session_id)
# if session_info is None:
# raise ValueError(f"Session {session_id} not found")
if session_info.vector_db_id is None:
vector_db_id = f"vector_db_{session_id}"
# if session_info.vector_db_id is None:
# vector_db_id = f"vector_db_{session_id}"
# TODO: the semantic for registration is definitely not "creation"
# so we need to fix it if we expect the agent to create a new vector db
# for each session
await self.vector_io_api.register_vector_db(
vector_db_id=vector_db_id,
embedding_model="all-MiniLM-L6-v2",
)
await self.storage.add_vector_db_to_session(session_id, vector_db_id)
else:
vector_db_id = session_info.vector_db_id
# # TODO: the semantic for registration is definitely not "creation"
# # so we need to fix it if we expect the agent to create a new vector db
# # for each session
# await self.vector_io_api.register_vector_db(
# vector_db_id=vector_db_id,
# embedding_model="all-MiniLM-L6-v2",
# )
# await self.storage.add_vector_db_to_session(session_id, vector_db_id)
# else:
# vector_db_id = session_info.vector_db_id
return vector_db_id
# return vector_db_id
async def add_to_session_vector_db(
    self, session_id: str, data: List[Document]
) -> None:
    """Insert the given documents into the session's vector DB.

    The vector DB id is obtained from ``_ensure_vector_db``. Each input
    document becomes a ``RAGDocument`` with a fresh UUID, the same content
    and mime type, and empty metadata; the batch is then handed to the RAG
    tool's ``insert`` with a chunk size of 512 tokens.

    Args:
        session_id: Session whose vector DB receives the documents.
        data: Documents to insert.
    """
    vector_db_id = await self._ensure_vector_db(session_id)
    rag_documents = [
        RAGDocument(
            document_id=str(uuid.uuid4()),
            content=doc.content,
            mime_type=doc.mime_type,
            metadata={},
        )
        for doc in data
    ]
    await self.tool_runtime_api.rag_tool.insert(
        documents=rag_documents,
        vector_db_id=vector_db_id,
        chunk_size_in_tokens=512,
    )
async def load_data_from_urls(urls: List[URL]) -> List[str]:
    """Load the text behind each URL and return the texts in input order.

    ``file://`` URIs are read from the local filesystem and URIs starting
    with ``http`` are fetched with an HTTP GET. URIs with any other prefix
    are skipped, so the result may be shorter than ``urls``.

    Args:
        urls: Objects exposing the target location as ``uri``.

    Returns:
        One string per URL that was loaded.
    """
    texts = []
    for url in urls:
        uri = url.uri
        if uri.startswith("file://"):
            local_path = uri[len("file://") :]
            with open(local_path, "r") as f:
                texts.append(f.read())
        elif uri.startswith("http"):
            async with httpx.AsyncClient() as client:
                response = await client.get(uri)
                texts.append(response.text)
    return texts
async def load_data_from_url(url: URL) -> str:
    """Fetch the text behind ``url`` with an HTTP GET.

    Args:
        url: Object exposing the target location as ``uri``.

    Returns:
        The response body as text when ``uri`` starts with ``http``;
        an empty string for every other URI.
    """
    uri = url.uri
    # NOTE(review): non-http URIs (e.g. file:// or data:) produce empty text
    # rather than an error -- confirm this is intended.
    if not uri.startswith("http"):
        return ""
    async with httpx.AsyncClient() as client:
        response = await client.get(uri)
        return response.text
async def attachment_message(
    tempdir: str, urls: List[URL], message: UserMessage
) -> None:
    """Append to ``message`` a note for each URL pointing at a local file.

    ``file://`` URIs are used as-is (prefix stripped). URIs starting with
    ``http`` are downloaded as text into ``tempdir`` under a randomized file
    name. For every URL a ``TextContentItem`` naming the local path is added
    to ``message.content``, which is converted to a list if it is not one
    already.

    Args:
        tempdir: Directory that receives downloaded files.
        urls: Objects exposing the target location as ``uri``.
        message: Message whose ``content`` is extended in place.

    Raises:
        ValueError: If a URI starts with neither ``file://`` nor ``http``.
    """
    notes = []

    for url in urls:
        uri = url.uri
        if uri.startswith("file://"):
            filepath = uri[len("file://") :]
        elif uri.startswith("http"):
            path = urlparse(uri).path
            basename = os.path.basename(path)
            # The random prefix keeps downloads with equal basenames apart.
            filepath = f"{tempdir}/{make_random_string() + basename}"
            logger.info(f"Downloading {url} -> {filepath}")

            async with httpx.AsyncClient() as client:
                response = await client.get(uri)
                body = response.text
            with open(filepath, "w") as fp:
                fp.write(body)
        else:
            raise ValueError(f"Unsupported URL {url}")

        notes.append(
            TextContentItem(
                text=f'# User provided a file accessible to you at "{filepath}"\nYou can use code_interpreter to load and inspect it.'
            )
        )

    if isinstance(message.content, list):
        message.content.extend(notes)
    elif isinstance(message.content, str):
        message.content = [TextContentItem(text=message.content)] + notes
    else:
        message.content = [message.content] + notes
def _interpret_content_as_attachment(