mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-06 02:32:40 +00:00
clean up
This commit is contained in:
parent
8fa487bdc3
commit
9cf356ccce
1 changed files with 16 additions and 154 deletions
|
@ -466,21 +466,6 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
)
|
||||
span.set_attribute("output", "no violations")
|
||||
|
||||
async def get_raw_document_text(self, document: Document) -> str:
|
||||
if isinstance(document.content, URL):
|
||||
return await load_data_from_url(document.content.uri)
|
||||
elif isinstance(document.content, str):
|
||||
if document.content.startswith("http"):
|
||||
return await load_data_from_url(document.content)
|
||||
else:
|
||||
return document.content
|
||||
elif isinstance(document.content, TextContentItem):
|
||||
return document.content.text
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unexpected document content type: {type(document.content)}"
|
||||
)
|
||||
|
||||
async def _run(
|
||||
self,
|
||||
session_id: str,
|
||||
|
@ -490,25 +475,14 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
stream: bool = False,
|
||||
documents: Optional[List[Document]] = None,
|
||||
) -> AsyncGenerator:
|
||||
# if documents:
|
||||
# await self.handle_documents(session_id, documents, input_messages)
|
||||
|
||||
# if document is passed in a turn, we parse the raw text of the document
|
||||
# and sent it as a user message
|
||||
if documents:
|
||||
contexts = []
|
||||
for document in documents:
|
||||
raw_document_text = await self.get_raw_document_text(document)
|
||||
raw_document_text = await get_raw_document_text(document)
|
||||
contexts.append(raw_document_text)
|
||||
print("documents", contexts)
|
||||
input_messages[-1].context = "\n".join(contexts)
|
||||
# modify the last user message to include the document
|
||||
# input_messages.append(
|
||||
# ToolResponseMessage(
|
||||
# call_id=str(uuid.uuid4()),
|
||||
# content=contexts,
|
||||
# )
|
||||
# )
|
||||
|
||||
session_info = await self.storage.get_session_info(session_id)
|
||||
# if the session has a memory bank id, let the memory tool use it
|
||||
|
@ -982,95 +956,6 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
logger.debug(f"tool call {tool_name_str} completed with result: {result}")
|
||||
return result
|
||||
|
||||
# async def handle_documents(
|
||||
# self,
|
||||
# session_id: str,
|
||||
# documents: List[Document],
|
||||
# input_messages: List[Message],
|
||||
# ) -> None:
|
||||
# memory_tool = any(tool_def.tool_name == MEMORY_QUERY_TOOL for tool_def in self.tool_defs)
|
||||
# code_interpreter_tool = any(tool_def.tool_name == BuiltinTool.code_interpreter for tool_def in self.tool_defs)
|
||||
# content_items = []
|
||||
# url_items = []
|
||||
# pattern = re.compile("^(https?://|file://|data:)")
|
||||
# for d in documents:
|
||||
# if isinstance(d.content, URL):
|
||||
# url_items.append(d.content)
|
||||
# elif pattern.match(d.content):
|
||||
# url_items.append(URL(uri=d.content))
|
||||
# else:
|
||||
# content_items.append(d)
|
||||
|
||||
# # Save the contents to a tempdir and use its path as a URL if code interpreter is present
|
||||
# if code_interpreter_tool:
|
||||
# for c in content_items:
|
||||
# temp_file_path = os.path.join(self.tempdir, f"{make_random_string()}.txt")
|
||||
# with open(temp_file_path, "w") as temp_file:
|
||||
# temp_file.write(c.content)
|
||||
# url_items.append(URL(uri=f"file://{temp_file_path}"))
|
||||
|
||||
# if memory_tool and code_interpreter_tool:
|
||||
# # if both memory and code_interpreter are available, we download the URLs
|
||||
# # and attach the data to the last message.
|
||||
# await attachment_message(self.tempdir, url_items, input_messages[-1])
|
||||
# # Since memory is present, add all the data to the memory bank
|
||||
# await self.add_to_session_vector_db(session_id, documents)
|
||||
# elif code_interpreter_tool:
|
||||
# # if only code_interpreter is available, we download the URLs to a tempdir
|
||||
# # and attach the path to them as a message to inference with the
|
||||
# # assumption that the model invokes the code_interpreter tool with the path
|
||||
# await attachment_message(self.tempdir, url_items, input_messages[-1])
|
||||
# elif memory_tool:
|
||||
# # if only memory is available, we load the data from the URLs and content items to the memory bank
|
||||
# await self.add_to_session_vector_db(session_id, documents)
|
||||
# else:
|
||||
# # if no memory or code_interpreter tool is available,
|
||||
# # we try to load the data from the URLs and content items as a message to inference
|
||||
# # and add it to the last message's context
|
||||
# input_messages[-1].context = "\n".join(
|
||||
# [doc.content for doc in content_items] + await load_data_from_urls(url_items)
|
||||
# )
|
||||
|
||||
# async def _ensure_vector_db(self, session_id: str) -> str:
|
||||
# session_info = await self.storage.get_session_info(session_id)
|
||||
# if session_info is None:
|
||||
# raise ValueError(f"Session {session_id} not found")
|
||||
|
||||
# if session_info.vector_db_id is None:
|
||||
# vector_db_id = f"vector_db_{session_id}"
|
||||
|
||||
# # TODO: the semantic for registration is definitely not "creation"
|
||||
# # so we need to fix it if we expect the agent to create a new vector db
|
||||
# # for each session
|
||||
# await self.vector_io_api.register_vector_db(
|
||||
# vector_db_id=vector_db_id,
|
||||
# embedding_model="all-MiniLM-L6-v2",
|
||||
# )
|
||||
# await self.storage.add_vector_db_to_session(session_id, vector_db_id)
|
||||
# else:
|
||||
# vector_db_id = session_info.vector_db_id
|
||||
|
||||
# return vector_db_id
|
||||
|
||||
# async def add_to_session_vector_db(
|
||||
# self, session_id: str, data: List[Document]
|
||||
# ) -> None:
|
||||
# vector_db_id = await self._ensure_vector_db(session_id)
|
||||
# documents = [
|
||||
# RAGDocument(
|
||||
# document_id=str(uuid.uuid4()),
|
||||
# content=a.content,
|
||||
# mime_type=a.mime_type,
|
||||
# metadata={},
|
||||
# )
|
||||
# for a in data
|
||||
# ]
|
||||
# await self.tool_runtime_api.rag_tool.insert(
|
||||
# documents=documents,
|
||||
# vector_db_id=vector_db_id,
|
||||
# chunk_size_in_tokens=512,
|
||||
# )
|
||||
|
||||
|
||||
async def load_data_from_url(url: str) -> str:
|
||||
if url.startswith("http"):
|
||||
|
@ -1080,43 +965,20 @@ async def load_data_from_url(url: str) -> str:
|
|||
return resp
|
||||
return ""
|
||||
|
||||
|
||||
# async def attachment_message(
|
||||
# tempdir: str, urls: List[URL], message: UserMessage
|
||||
# ) -> None:
|
||||
# contents = []
|
||||
|
||||
# for url in urls:
|
||||
# uri = url.uri
|
||||
# if uri.startswith("file://"):
|
||||
# filepath = uri[len("file://") :]
|
||||
# elif uri.startswith("http"):
|
||||
# path = urlparse(uri).path
|
||||
# basename = os.path.basename(path)
|
||||
# filepath = f"{tempdir}/{make_random_string() + basename}"
|
||||
# logger.info(f"Downloading {url} -> {filepath}")
|
||||
|
||||
# async with httpx.AsyncClient() as client:
|
||||
# r = await client.get(uri)
|
||||
# resp = r.text
|
||||
# with open(filepath, "w") as fp:
|
||||
# fp.write(resp)
|
||||
# else:
|
||||
# raise ValueError(f"Unsupported URL {url}")
|
||||
|
||||
# contents.append(
|
||||
# TextContentItem(
|
||||
# text=f'# User provided a file accessible to you at "{filepath}"\nYou can use code_interpreter to load and inspect it.'
|
||||
# )
|
||||
# )
|
||||
|
||||
# if isinstance(message.content, list):
|
||||
# message.content.extend(contents)
|
||||
# else:
|
||||
# if isinstance(message.content, str):
|
||||
# message.content = [TextContentItem(text=message.content)] + contents
|
||||
# else:
|
||||
# message.content = [message.content] + contents
|
||||
async def get_raw_document_text(document: Document) -> str:
|
||||
if isinstance(document.content, URL):
|
||||
return await load_data_from_url(document.content.uri)
|
||||
elif isinstance(document.content, str):
|
||||
if document.content.startswith("http"):
|
||||
return await load_data_from_url(document.content)
|
||||
else:
|
||||
return document.content
|
||||
elif isinstance(document.content, TextContentItem):
|
||||
return document.content.text
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unexpected document content type: {type(document.content)}"
|
||||
)
|
||||
|
||||
|
||||
def _interpret_content_as_attachment(
|
||||
|
@ -1131,4 +993,4 @@ def _interpret_content_as_attachment(
|
|||
mime_type=data["mimetype"],
|
||||
)
|
||||
|
||||
return None
|
||||
return None
|
Loading…
Add table
Add a link
Reference in a new issue