mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-06 18:50:44 +00:00
append document context
This commit is contained in:
parent
0b1e71718c
commit
8fa487bdc3
2 changed files with 146 additions and 38 deletions
File diff suppressed because one or more lines are too long
|
@ -468,9 +468,12 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
|
|
||||||
async def get_raw_document_text(self, document: Document) -> str:
|
async def get_raw_document_text(self, document: Document) -> str:
|
||||||
if isinstance(document.content, URL):
|
if isinstance(document.content, URL):
|
||||||
return await load_data_from_url(document.content)
|
return await load_data_from_url(document.content.uri)
|
||||||
elif isinstance(document.content, str):
|
elif isinstance(document.content, str):
|
||||||
return document.content
|
if document.content.startswith("http"):
|
||||||
|
return await load_data_from_url(document.content)
|
||||||
|
else:
|
||||||
|
return document.content
|
||||||
elif isinstance(document.content, TextContentItem):
|
elif isinstance(document.content, TextContentItem):
|
||||||
return document.content.text
|
return document.content.text
|
||||||
else:
|
else:
|
||||||
|
@ -496,14 +499,16 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
contexts = []
|
contexts = []
|
||||||
for document in documents:
|
for document in documents:
|
||||||
raw_document_text = await self.get_raw_document_text(document)
|
raw_document_text = await self.get_raw_document_text(document)
|
||||||
contexts.append(TextContentItem(text=raw_document_text))
|
contexts.append(raw_document_text)
|
||||||
|
print("documents", contexts)
|
||||||
|
input_messages[-1].context = "\n".join(contexts)
|
||||||
# modify the last user message to include the document
|
# modify the last user message to include the document
|
||||||
input_messages.append(
|
# input_messages.append(
|
||||||
ToolResponseMessage(
|
# ToolResponseMessage(
|
||||||
call_id=str(uuid.uuid4()),
|
# call_id=str(uuid.uuid4()),
|
||||||
content=contexts,
|
# content=contexts,
|
||||||
)
|
# )
|
||||||
)
|
# )
|
||||||
|
|
||||||
session_info = await self.storage.get_session_info(session_id)
|
session_info = await self.storage.get_session_info(session_id)
|
||||||
# if the session has a memory bank id, let the memory tool use it
|
# if the session has a memory bank id, let the memory tool use it
|
||||||
|
@ -1067,11 +1072,10 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
# )
|
# )
|
||||||
|
|
||||||
|
|
||||||
async def load_data_from_url(url: URL) -> str:
|
async def load_data_from_url(url: str) -> str:
|
||||||
uri = url.uri
|
if url.startswith("http"):
|
||||||
if uri.startswith("http"):
|
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
r = await client.get(uri)
|
r = await client.get(url)
|
||||||
resp = r.text
|
resp = r.text
|
||||||
return resp
|
return resp
|
||||||
return ""
|
return ""
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue