append document context

This commit is contained in:
Xi Yan 2025-03-22 12:30:09 -07:00
parent 0b1e71718c
commit 8fa487bdc3
2 changed files with 146 additions and 38 deletions

File diff suppressed because one or more lines are too long

View file

@ -468,9 +468,12 @@ class ChatAgent(ShieldRunnerMixin):
async def get_raw_document_text(self, document: Document) -> str: async def get_raw_document_text(self, document: Document) -> str:
if isinstance(document.content, URL): if isinstance(document.content, URL):
return await load_data_from_url(document.content) return await load_data_from_url(document.content.uri)
elif isinstance(document.content, str): elif isinstance(document.content, str):
return document.content if document.content.startswith("http"):
return await load_data_from_url(document.content)
else:
return document.content
elif isinstance(document.content, TextContentItem): elif isinstance(document.content, TextContentItem):
return document.content.text return document.content.text
else: else:
@ -496,14 +499,16 @@ class ChatAgent(ShieldRunnerMixin):
contexts = [] contexts = []
for document in documents: for document in documents:
raw_document_text = await self.get_raw_document_text(document) raw_document_text = await self.get_raw_document_text(document)
contexts.append(TextContentItem(text=raw_document_text)) contexts.append(raw_document_text)
print("documents", contexts)
input_messages[-1].context = "\n".join(contexts)
# modify the last user message to include the document # modify the last user message to include the document
input_messages.append( # input_messages.append(
ToolResponseMessage( # ToolResponseMessage(
call_id=str(uuid.uuid4()), # call_id=str(uuid.uuid4()),
content=contexts, # content=contexts,
) # )
) # )
session_info = await self.storage.get_session_info(session_id) session_info = await self.storage.get_session_info(session_id)
# if the session has a memory bank id, let the memory tool use it # if the session has a memory bank id, let the memory tool use it
@ -1067,11 +1072,10 @@ class ChatAgent(ShieldRunnerMixin):
# ) # )
async def load_data_from_url(url: URL) -> str: async def load_data_from_url(url: str) -> str:
uri = url.uri if url.startswith("http"):
if uri.startswith("http"):
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
r = await client.get(uri) r = await client.get(url)
resp = r.text resp = r.text
return resp return resp
return "" return ""