mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 04:04:14 +00:00
chore: Updating documentation, adding exception handling for Vector Stores in RAG Tool, more tests on migration, and migrate off of inference_api for context_retriever for RAG (#3367)
# What does this PR do? - Updating documentation on migration from RAG Tool to Vector Stores and Files APIs - Adding exception handling for Vector Stores in RAG Tool - Add more tests on migration from RAG Tool to Vector Stores - Migrate off of inference_api for context_retriever for RAG <!-- If resolving an issue, uncomment and update the line below --> <!-- Closes #[issue-number] --> ## Test Plan Integration and unit tests added Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
This commit is contained in:
parent
f31bcc11bc
commit
d15368a302
5 changed files with 355 additions and 45 deletions
|
@ -8,7 +8,7 @@
|
|||
from jinja2 import Template
|
||||
|
||||
from llama_stack.apis.common.content_types import InterleavedContent
|
||||
from llama_stack.apis.inference import UserMessage
|
||||
from llama_stack.apis.inference import OpenAIUserMessageParam
|
||||
from llama_stack.apis.tools.rag_tool import (
|
||||
DefaultRAGQueryGeneratorConfig,
|
||||
LLMRAGQueryGeneratorConfig,
|
||||
|
@ -61,16 +61,16 @@ async def llm_rag_query_generator(
|
|||
messages = [interleaved_content_as_str(content)]
|
||||
|
||||
template = Template(config.template)
|
||||
content = template.render({"messages": messages})
|
||||
rendered_content: str = template.render({"messages": messages})
|
||||
|
||||
model = config.model
|
||||
message = UserMessage(content=content)
|
||||
response = await inference_api.chat_completion(
|
||||
model_id=model,
|
||||
message = OpenAIUserMessageParam(content=rendered_content)
|
||||
response = await inference_api.openai_chat_completion(
|
||||
model=model,
|
||||
messages=[message],
|
||||
stream=False,
|
||||
)
|
||||
|
||||
query = response.completion_message.content
|
||||
query = response.choices[0].message.content
|
||||
|
||||
return query
|
||||
|
|
|
@ -45,10 +45,7 @@ from llama_stack.apis.vector_io import (
|
|||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
content_from_doc,
|
||||
parse_data_url,
|
||||
)
|
||||
from llama_stack.providers.utils.memory.vector_store import parse_data_url
|
||||
|
||||
from .config import RagToolRuntimeConfig
|
||||
from .context_retriever import generate_rag_query
|
||||
|
@ -60,6 +57,47 @@ def make_random_string(length: int = 8):
|
|||
return "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(length))
|
||||
|
||||
|
||||
async def raw_data_from_doc(doc: RAGDocument) -> tuple[bytes, str]:
|
||||
"""Get raw binary data and mime type from a RAGDocument for file upload."""
|
||||
if isinstance(doc.content, URL):
|
||||
if doc.content.uri.startswith("data:"):
|
||||
parts = parse_data_url(doc.content.uri)
|
||||
mime_type = parts["mimetype"]
|
||||
data = parts["data"]
|
||||
|
||||
if parts["is_base64"]:
|
||||
file_data = base64.b64decode(data)
|
||||
else:
|
||||
file_data = data.encode("utf-8")
|
||||
|
||||
return file_data, mime_type
|
||||
else:
|
||||
async with httpx.AsyncClient() as client:
|
||||
r = await client.get(doc.content.uri)
|
||||
r.raise_for_status()
|
||||
mime_type = r.headers.get("content-type", "application/octet-stream")
|
||||
return r.content, mime_type
|
||||
else:
|
||||
if isinstance(doc.content, str):
|
||||
content_str = doc.content
|
||||
else:
|
||||
content_str = interleaved_content_as_str(doc.content)
|
||||
|
||||
if content_str.startswith("data:"):
|
||||
parts = parse_data_url(content_str)
|
||||
mime_type = parts["mimetype"]
|
||||
data = parts["data"]
|
||||
|
||||
if parts["is_base64"]:
|
||||
file_data = base64.b64decode(data)
|
||||
else:
|
||||
file_data = data.encode("utf-8")
|
||||
|
||||
return file_data, mime_type
|
||||
else:
|
||||
return content_str.encode("utf-8"), "text/plain"
|
||||
|
||||
|
||||
class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRuntime):
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -95,46 +133,52 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
|
|||
return
|
||||
|
||||
for doc in documents:
|
||||
if isinstance(doc.content, URL):
|
||||
if doc.content.uri.startswith("data:"):
|
||||
parts = parse_data_url(doc.content.uri)
|
||||
file_data = base64.b64decode(parts["data"]) if parts["is_base64"] else parts["data"].encode()
|
||||
mime_type = parts["mimetype"]
|
||||
else:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(doc.content.uri)
|
||||
file_data = response.content
|
||||
mime_type = doc.mime_type or response.headers.get("content-type", "application/octet-stream")
|
||||
else:
|
||||
content_str = await content_from_doc(doc)
|
||||
file_data = content_str.encode("utf-8")
|
||||
mime_type = doc.mime_type or "text/plain"
|
||||
try:
|
||||
try:
|
||||
file_data, mime_type = await raw_data_from_doc(doc)
|
||||
except Exception as e:
|
||||
log.error(f"Failed to extract content from document {doc.document_id}: {e}")
|
||||
continue
|
||||
|
||||
file_extension = mimetypes.guess_extension(mime_type) or ".txt"
|
||||
filename = doc.metadata.get("filename", f"{doc.document_id}{file_extension}")
|
||||
file_extension = mimetypes.guess_extension(mime_type) or ".txt"
|
||||
filename = doc.metadata.get("filename", f"{doc.document_id}{file_extension}")
|
||||
|
||||
file_obj = io.BytesIO(file_data)
|
||||
file_obj.name = filename
|
||||
file_obj = io.BytesIO(file_data)
|
||||
file_obj.name = filename
|
||||
|
||||
upload_file = UploadFile(file=file_obj, filename=filename)
|
||||
upload_file = UploadFile(file=file_obj, filename=filename)
|
||||
|
||||
created_file = await self.files_api.openai_upload_file(
|
||||
file=upload_file, purpose=OpenAIFilePurpose.ASSISTANTS
|
||||
)
|
||||
try:
|
||||
created_file = await self.files_api.openai_upload_file(
|
||||
file=upload_file, purpose=OpenAIFilePurpose.ASSISTANTS
|
||||
)
|
||||
except Exception as e:
|
||||
log.error(f"Failed to upload file for document {doc.document_id}: {e}")
|
||||
continue
|
||||
|
||||
chunking_strategy = VectorStoreChunkingStrategyStatic(
|
||||
static=VectorStoreChunkingStrategyStaticConfig(
|
||||
max_chunk_size_tokens=chunk_size_in_tokens,
|
||||
chunk_overlap_tokens=chunk_size_in_tokens // 4,
|
||||
chunking_strategy = VectorStoreChunkingStrategyStatic(
|
||||
static=VectorStoreChunkingStrategyStaticConfig(
|
||||
max_chunk_size_tokens=chunk_size_in_tokens,
|
||||
chunk_overlap_tokens=chunk_size_in_tokens // 4,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
await self.vector_io_api.openai_attach_file_to_vector_store(
|
||||
vector_store_id=vector_db_id,
|
||||
file_id=created_file.id,
|
||||
attributes=doc.metadata,
|
||||
chunking_strategy=chunking_strategy,
|
||||
)
|
||||
try:
|
||||
await self.vector_io_api.openai_attach_file_to_vector_store(
|
||||
vector_store_id=vector_db_id,
|
||||
file_id=created_file.id,
|
||||
attributes=doc.metadata,
|
||||
chunking_strategy=chunking_strategy,
|
||||
)
|
||||
except Exception as e:
|
||||
log.error(
|
||||
f"Failed to attach file {created_file.id} to vector store {vector_db_id} for document {doc.document_id}: {e}"
|
||||
)
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
log.error(f"Unexpected error processing document {doc.document_id}: {e}")
|
||||
continue
|
||||
|
||||
async def query(
|
||||
self,
|
||||
|
@ -274,7 +318,6 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
|
|||
if query_config:
|
||||
query_config = TypeAdapter(RAGQueryConfig).validate_python(query_config)
|
||||
else:
|
||||
# handle someone passing an empty dict
|
||||
query_config = RAGQueryConfig()
|
||||
|
||||
query = kwargs["query"]
|
||||
|
@ -285,6 +328,6 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
|
|||
)
|
||||
|
||||
return ToolInvocationResult(
|
||||
content=result.content,
|
||||
content=result.content or [],
|
||||
metadata=result.metadata,
|
||||
)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue