mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 04:04:14 +00:00
feat: Updating Rag Tool to use Files API and Vector Stores API (#3344)
Some checks failed
SqlStore Integration Tests / test-postgres (3.13) (push) Failing after 1s
Python Package Build Test / build (3.12) (push) Failing after 1s
Test External Providers Installed via Module / test-external-providers-from-module (venv) (push) Has been skipped
SqlStore Integration Tests / test-postgres (3.12) (push) Failing after 18s
Update ReadTheDocs / update-readthedocs (push) Failing after 15s
Python Package Build Test / build (3.13) (push) Failing after 19s
Test External API and Providers / test-external (venv) (push) Failing after 17s
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 23s
Integration Tests (Replay) / Integration Tests (, , , client=, ) (push) Failing after 22s
Unit Tests / unit-tests (3.12) (push) Failing after 19s
Unit Tests / unit-tests (3.13) (push) Failing after 19s
Vector IO Integration Tests / test-matrix (push) Failing after 23s
UI Tests / ui-tests (22) (push) Successful in 44s
Pre-commit / pre-commit (push) Successful in 1m32s
Some checks failed
SqlStore Integration Tests / test-postgres (3.13) (push) Failing after 1s
Python Package Build Test / build (3.12) (push) Failing after 1s
Test External Providers Installed via Module / test-external-providers-from-module (venv) (push) Has been skipped
SqlStore Integration Tests / test-postgres (3.12) (push) Failing after 18s
Update ReadTheDocs / update-readthedocs (push) Failing after 15s
Python Package Build Test / build (3.13) (push) Failing after 19s
Test External API and Providers / test-external (venv) (push) Failing after 17s
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 23s
Integration Tests (Replay) / Integration Tests (, , , client=, ) (push) Failing after 22s
Unit Tests / unit-tests (3.12) (push) Failing after 19s
Unit Tests / unit-tests (3.13) (push) Failing after 19s
Vector IO Integration Tests / test-matrix (push) Failing after 23s
UI Tests / ui-tests (22) (push) Successful in 44s
Pre-commit / pre-commit (push) Successful in 1m32s
This commit is contained in:
parent
47b640370e
commit
7cd1c2c238
6 changed files with 93 additions and 39 deletions
|
@ -14,6 +14,6 @@ from .config import RagToolRuntimeConfig
|
|||
async def get_provider_impl(config: RagToolRuntimeConfig, deps: dict[Api, Any]):
|
||||
from .memory import MemoryToolRuntimeImpl
|
||||
|
||||
impl = MemoryToolRuntimeImpl(config, deps[Api.vector_io], deps[Api.inference])
|
||||
impl = MemoryToolRuntimeImpl(config, deps[Api.vector_io], deps[Api.inference], deps[Api.files])
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
@ -5,10 +5,15 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import io
|
||||
import mimetypes
|
||||
import secrets
|
||||
import string
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from fastapi import UploadFile
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
|
@ -17,6 +22,7 @@ from llama_stack.apis.common.content_types import (
|
|||
InterleavedContentItem,
|
||||
TextContentItem,
|
||||
)
|
||||
from llama_stack.apis.files import Files, OpenAIFilePurpose
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.tools import (
|
||||
ListToolDefsResponse,
|
||||
|
@ -30,13 +36,18 @@ from llama_stack.apis.tools import (
|
|||
ToolParameter,
|
||||
ToolRuntime,
|
||||
)
|
||||
from llama_stack.apis.vector_io import QueryChunksResponse, VectorIO
|
||||
from llama_stack.apis.vector_io import (
|
||||
QueryChunksResponse,
|
||||
VectorIO,
|
||||
VectorStoreChunkingStrategyStatic,
|
||||
VectorStoreChunkingStrategyStaticConfig,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
content_from_doc,
|
||||
make_overlapped_chunks,
|
||||
parse_data_url,
|
||||
)
|
||||
|
||||
from .config import RagToolRuntimeConfig
|
||||
|
@ -55,10 +66,12 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
|
|||
config: RagToolRuntimeConfig,
|
||||
vector_io_api: VectorIO,
|
||||
inference_api: Inference,
|
||||
files_api: Files,
|
||||
):
|
||||
self.config = config
|
||||
self.vector_io_api = vector_io_api
|
||||
self.inference_api = inference_api
|
||||
self.files_api = files_api
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
@ -78,27 +91,50 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
|
|||
vector_db_id: str,
|
||||
chunk_size_in_tokens: int = 512,
|
||||
) -> None:
|
||||
chunks = []
|
||||
if not documents:
|
||||
return
|
||||
|
||||
for doc in documents:
|
||||
content = await content_from_doc(doc)
|
||||
# TODO: we should add enrichment here as URLs won't be added to the metadata by default
|
||||
chunks.extend(
|
||||
make_overlapped_chunks(
|
||||
doc.document_id,
|
||||
content,
|
||||
chunk_size_in_tokens,
|
||||
chunk_size_in_tokens // 4,
|
||||
doc.metadata,
|
||||
if isinstance(doc.content, URL):
|
||||
if doc.content.uri.startswith("data:"):
|
||||
parts = parse_data_url(doc.content.uri)
|
||||
file_data = base64.b64decode(parts["data"]) if parts["is_base64"] else parts["data"].encode()
|
||||
mime_type = parts["mimetype"]
|
||||
else:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(doc.content.uri)
|
||||
file_data = response.content
|
||||
mime_type = doc.mime_type or response.headers.get("content-type", "application/octet-stream")
|
||||
else:
|
||||
content_str = await content_from_doc(doc)
|
||||
file_data = content_str.encode("utf-8")
|
||||
mime_type = doc.mime_type or "text/plain"
|
||||
|
||||
file_extension = mimetypes.guess_extension(mime_type) or ".txt"
|
||||
filename = doc.metadata.get("filename", f"{doc.document_id}{file_extension}")
|
||||
|
||||
file_obj = io.BytesIO(file_data)
|
||||
file_obj.name = filename
|
||||
|
||||
upload_file = UploadFile(file=file_obj, filename=filename)
|
||||
|
||||
created_file = await self.files_api.openai_upload_file(
|
||||
file=upload_file, purpose=OpenAIFilePurpose.ASSISTANTS
|
||||
)
|
||||
|
||||
chunking_strategy = VectorStoreChunkingStrategyStatic(
|
||||
static=VectorStoreChunkingStrategyStaticConfig(
|
||||
max_chunk_size_tokens=chunk_size_in_tokens,
|
||||
chunk_overlap_tokens=chunk_size_in_tokens // 4,
|
||||
)
|
||||
)
|
||||
|
||||
if not chunks:
|
||||
return
|
||||
|
||||
await self.vector_io_api.insert_chunks(
|
||||
chunks=chunks,
|
||||
vector_db_id=vector_db_id,
|
||||
)
|
||||
await self.vector_io_api.openai_attach_file_to_vector_store(
|
||||
vector_store_id=vector_db_id,
|
||||
file_id=created_file.id,
|
||||
attributes=doc.metadata,
|
||||
chunking_strategy=chunking_strategy,
|
||||
)
|
||||
|
||||
async def query(
|
||||
self,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue