mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-06 02:32:40 +00:00
append document context
This commit is contained in:
parent
0b1e71718c
commit
8fa487bdc3
2 changed files with 146 additions and 38 deletions
File diff suppressed because one or more lines are too long
|
@ -468,9 +468,12 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
|
||||
async def get_raw_document_text(self, document: Document) -> str:
|
||||
if isinstance(document.content, URL):
|
||||
return await load_data_from_url(document.content)
|
||||
return await load_data_from_url(document.content.uri)
|
||||
elif isinstance(document.content, str):
|
||||
return document.content
|
||||
if document.content.startswith("http"):
|
||||
return await load_data_from_url(document.content)
|
||||
else:
|
||||
return document.content
|
||||
elif isinstance(document.content, TextContentItem):
|
||||
return document.content.text
|
||||
else:
|
||||
|
@ -496,14 +499,16 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
contexts = []
|
||||
for document in documents:
|
||||
raw_document_text = await self.get_raw_document_text(document)
|
||||
contexts.append(TextContentItem(text=raw_document_text))
|
||||
contexts.append(raw_document_text)
|
||||
print("documents", contexts)
|
||||
input_messages[-1].context = "\n".join(contexts)
|
||||
# modify the last user message to include the document
|
||||
input_messages.append(
|
||||
ToolResponseMessage(
|
||||
call_id=str(uuid.uuid4()),
|
||||
content=contexts,
|
||||
)
|
||||
)
|
||||
# input_messages.append(
|
||||
# ToolResponseMessage(
|
||||
# call_id=str(uuid.uuid4()),
|
||||
# content=contexts,
|
||||
# )
|
||||
# )
|
||||
|
||||
session_info = await self.storage.get_session_info(session_id)
|
||||
# if the session has a memory bank id, let the memory tool use it
|
||||
|
@ -1067,11 +1072,10 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
# )
|
||||
|
||||
|
||||
async def load_data_from_url(url: URL) -> str:
|
||||
uri = url.uri
|
||||
if uri.startswith("http"):
|
||||
async def load_data_from_url(url: str) -> str:
|
||||
if url.startswith("http"):
|
||||
async with httpx.AsyncClient() as client:
|
||||
r = await client.get(uri)
|
||||
r = await client.get(url)
|
||||
resp = r.text
|
||||
return resp
|
||||
return ""
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue