append document context

This commit is contained in:
Xi Yan 2025-03-22 12:30:09 -07:00
parent 0b1e71718c
commit 8fa487bdc3
2 changed files with 146 additions and 38 deletions

File diff suppressed because one or more lines are too long

View file

@ -468,9 +468,12 @@ class ChatAgent(ShieldRunnerMixin):
async def get_raw_document_text(self, document: Document) -> str:
if isinstance(document.content, URL):
return await load_data_from_url(document.content)
return await load_data_from_url(document.content.uri)
elif isinstance(document.content, str):
return document.content
if document.content.startswith("http"):
return await load_data_from_url(document.content)
else:
return document.content
elif isinstance(document.content, TextContentItem):
return document.content.text
else:
@ -496,14 +499,16 @@ class ChatAgent(ShieldRunnerMixin):
contexts = []
for document in documents:
raw_document_text = await self.get_raw_document_text(document)
contexts.append(TextContentItem(text=raw_document_text))
contexts.append(raw_document_text)
print("documents", contexts)
input_messages[-1].context = "\n".join(contexts)
# modify the last user message to include the document
input_messages.append(
ToolResponseMessage(
call_id=str(uuid.uuid4()),
content=contexts,
)
)
# input_messages.append(
# ToolResponseMessage(
# call_id=str(uuid.uuid4()),
# content=contexts,
# )
# )
session_info = await self.storage.get_session_info(session_id)
# if the session has a memory bank id, let the memory tool use it
@ -1067,11 +1072,10 @@ class ChatAgent(ShieldRunnerMixin):
# )
async def load_data_from_url(url: URL) -> str:
uri = url.uri
if uri.startswith("http"):
async def load_data_from_url(url: str) -> str:
if url.startswith("http"):
async with httpx.AsyncClient() as client:
r = await client.get(uri)
r = await client.get(url)
resp = r.text
return resp
return ""