This commit is contained in:
Xi Yan 2025-03-22 12:31:55 -07:00
parent 8fa487bdc3
commit 9cf356ccce

View file

@ -466,21 +466,6 @@ class ChatAgent(ShieldRunnerMixin):
)
span.set_attribute("output", "no violations")
async def get_raw_document_text(self, document: Document) -> str:
if isinstance(document.content, URL):
return await load_data_from_url(document.content.uri)
elif isinstance(document.content, str):
if document.content.startswith("http"):
return await load_data_from_url(document.content)
else:
return document.content
elif isinstance(document.content, TextContentItem):
return document.content.text
else:
raise ValueError(
f"Unexpected document content type: {type(document.content)}"
)
async def _run(
self,
session_id: str,
@ -490,25 +475,14 @@ class ChatAgent(ShieldRunnerMixin):
stream: bool = False,
documents: Optional[List[Document]] = None,
) -> AsyncGenerator:
# if documents:
# await self.handle_documents(session_id, documents, input_messages)
# if document is passed in a turn, we parse the raw text of the document
# and sent it as a user message
if documents:
contexts = []
for document in documents:
raw_document_text = await self.get_raw_document_text(document)
raw_document_text = await get_raw_document_text(document)
contexts.append(raw_document_text)
print("documents", contexts)
input_messages[-1].context = "\n".join(contexts)
# modify the last user message to include the document
# input_messages.append(
# ToolResponseMessage(
# call_id=str(uuid.uuid4()),
# content=contexts,
# )
# )
session_info = await self.storage.get_session_info(session_id)
# if the session has a memory bank id, let the memory tool use it
@ -982,95 +956,6 @@ class ChatAgent(ShieldRunnerMixin):
logger.debug(f"tool call {tool_name_str} completed with result: {result}")
return result
# async def handle_documents(
# self,
# session_id: str,
# documents: List[Document],
# input_messages: List[Message],
# ) -> None:
# memory_tool = any(tool_def.tool_name == MEMORY_QUERY_TOOL for tool_def in self.tool_defs)
# code_interpreter_tool = any(tool_def.tool_name == BuiltinTool.code_interpreter for tool_def in self.tool_defs)
# content_items = []
# url_items = []
# pattern = re.compile("^(https?://|file://|data:)")
# for d in documents:
# if isinstance(d.content, URL):
# url_items.append(d.content)
# elif pattern.match(d.content):
# url_items.append(URL(uri=d.content))
# else:
# content_items.append(d)
# # Save the contents to a tempdir and use its path as a URL if code interpreter is present
# if code_interpreter_tool:
# for c in content_items:
# temp_file_path = os.path.join(self.tempdir, f"{make_random_string()}.txt")
# with open(temp_file_path, "w") as temp_file:
# temp_file.write(c.content)
# url_items.append(URL(uri=f"file://{temp_file_path}"))
# if memory_tool and code_interpreter_tool:
# # if both memory and code_interpreter are available, we download the URLs
# # and attach the data to the last message.
# await attachment_message(self.tempdir, url_items, input_messages[-1])
# # Since memory is present, add all the data to the memory bank
# await self.add_to_session_vector_db(session_id, documents)
# elif code_interpreter_tool:
# # if only code_interpreter is available, we download the URLs to a tempdir
# # and attach the path to them as a message to inference with the
# # assumption that the model invokes the code_interpreter tool with the path
# await attachment_message(self.tempdir, url_items, input_messages[-1])
# elif memory_tool:
# # if only memory is available, we load the data from the URLs and content items to the memory bank
# await self.add_to_session_vector_db(session_id, documents)
# else:
# # if no memory or code_interpreter tool is available,
# # we try to load the data from the URLs and content items as a message to inference
# # and add it to the last message's context
# input_messages[-1].context = "\n".join(
# [doc.content for doc in content_items] + await load_data_from_urls(url_items)
# )
# async def _ensure_vector_db(self, session_id: str) -> str:
# session_info = await self.storage.get_session_info(session_id)
# if session_info is None:
# raise ValueError(f"Session {session_id} not found")
# if session_info.vector_db_id is None:
# vector_db_id = f"vector_db_{session_id}"
# # TODO: the semantic for registration is definitely not "creation"
# # so we need to fix it if we expect the agent to create a new vector db
# # for each session
# await self.vector_io_api.register_vector_db(
# vector_db_id=vector_db_id,
# embedding_model="all-MiniLM-L6-v2",
# )
# await self.storage.add_vector_db_to_session(session_id, vector_db_id)
# else:
# vector_db_id = session_info.vector_db_id
# return vector_db_id
# async def add_to_session_vector_db(
# self, session_id: str, data: List[Document]
# ) -> None:
# vector_db_id = await self._ensure_vector_db(session_id)
# documents = [
# RAGDocument(
# document_id=str(uuid.uuid4()),
# content=a.content,
# mime_type=a.mime_type,
# metadata={},
# )
# for a in data
# ]
# await self.tool_runtime_api.rag_tool.insert(
# documents=documents,
# vector_db_id=vector_db_id,
# chunk_size_in_tokens=512,
# )
async def load_data_from_url(url: str) -> str:
if url.startswith("http"):
@ -1080,43 +965,20 @@ async def load_data_from_url(url: str) -> str:
return resp
return ""
# async def attachment_message(
# tempdir: str, urls: List[URL], message: UserMessage
# ) -> None:
# contents = []
# for url in urls:
# uri = url.uri
# if uri.startswith("file://"):
# filepath = uri[len("file://") :]
# elif uri.startswith("http"):
# path = urlparse(uri).path
# basename = os.path.basename(path)
# filepath = f"{tempdir}/{make_random_string() + basename}"
# logger.info(f"Downloading {url} -> {filepath}")
# async with httpx.AsyncClient() as client:
# r = await client.get(uri)
# resp = r.text
# with open(filepath, "w") as fp:
# fp.write(resp)
# else:
# raise ValueError(f"Unsupported URL {url}")
# contents.append(
# TextContentItem(
# text=f'# User provided a file accessible to you at "{filepath}"\nYou can use code_interpreter to load and inspect it.'
# )
# )
# if isinstance(message.content, list):
# message.content.extend(contents)
# else:
# if isinstance(message.content, str):
# message.content = [TextContentItem(text=message.content)] + contents
# else:
# message.content = [message.content] + contents
async def get_raw_document_text(document: Document) -> str:
if isinstance(document.content, URL):
return await load_data_from_url(document.content.uri)
elif isinstance(document.content, str):
if document.content.startswith("http"):
return await load_data_from_url(document.content)
else:
return document.content
elif isinstance(document.content, TextContentItem):
return document.content.text
else:
raise ValueError(
f"Unexpected document content type: {type(document.content)}"
)
def _interpret_content_as_attachment(
@ -1131,4 +993,4 @@ def _interpret_content_as_attachment(
mime_type=data["mimetype"],
)
return None
return None