forked from phoenix-oss/llama-stack-mirror
# What does this PR do?

This PR introduces support for keyword-based FTS5 search with BM25 relevance scoring. It changes the existing `EmbeddingIndex` base class to support `search_mode` and `query_str` parameters, which can be used by keyword-based search implementations.

[//]: # (If resolving an issue, uncomment and update the line below)
[//]: # (Closes #[issue-number])

## Test Plan

Run:

```
pytest llama_stack/providers/tests/vector_io/test_sqlite_vec.py -v -s --tb=short --disable-warnings --asyncio-mode=auto
```

Output:

```
pytest llama_stack/providers/tests/vector_io/test_sqlite_vec.py -v -s --tb=short --disable-warnings --asyncio-mode=auto
/Users/vnarsing/miniconda3/envs/stack-client/lib/python3.10/site-packages/pytest_asyncio/plugin.py:207: PytestDeprecationWarning: The configuration option "asyncio_default_fixture_loop_scope" is unset.
The event loop scope for asynchronous fixtures will default to the fixture caching scope. Future versions of pytest-asyncio will default the loop scope for asynchronous fixtures to function scope. Set the default fixture loop scope explicitly in order to avoid unexpected behavior in the future.
Valid fixture loop scopes are: "function", "class", "module", "package", "session"

  warnings.warn(PytestDeprecationWarning(_DEFAULT_FIXTURE_LOOP_SCOPE_UNSET))
====================================================== test session starts =======================================================
platform darwin -- Python 3.10.16, pytest-8.3.4, pluggy-1.5.0 -- /Users/vnarsing/miniconda3/envs/stack-client/bin/python
cachedir: .pytest_cache
metadata: {'Python': '3.10.16', 'Platform': 'macOS-14.7.4-arm64-arm-64bit', 'Packages': {'pytest': '8.3.4', 'pluggy': '1.5.0'}, 'Plugins': {'html': '4.1.1', 'metadata': '3.1.1', 'asyncio': '0.25.3', 'anyio': '4.8.0'}}
rootdir: /Users/vnarsing/go/src/github/meta-llama/llama-stack
configfile: pyproject.toml
plugins: html-4.1.1, metadata-3.1.1, asyncio-0.25.3, anyio-4.8.0
asyncio: mode=auto, asyncio_default_fixture_loop_scope=None
collected 7 items

llama_stack/providers/tests/vector_io/test_sqlite_vec.py::test_add_chunks PASSED
llama_stack/providers/tests/vector_io/test_sqlite_vec.py::test_query_chunks_vector PASSED
llama_stack/providers/tests/vector_io/test_sqlite_vec.py::test_query_chunks_fts PASSED
llama_stack/providers/tests/vector_io/test_sqlite_vec.py::test_chunk_id_conflict PASSED
llama_stack/providers/tests/vector_io/test_sqlite_vec.py::test_register_vector_db PASSED
llama_stack/providers/tests/vector_io/test_sqlite_vec.py::test_unregister_vector_db PASSED
llama_stack/providers/tests/vector_io/test_sqlite_vec.py::test_generate_chunk_id PASSED
```

For reference, with the implementation, the FTS table looks like this:

```
Chunk ID: 9fbc39ce-c729-64a2-260f-c5ec9bb2a33e, Content: Sentence 0 from document 0
Chunk ID: 94062914-3e23-44cf-1e50-9e25821ba882, Content: Sentence 1 from document 0
Chunk ID: e6cfd559-4641-33ba-6ce1-7038226495eb, Content: Sentence 2 from document 0
Chunk ID: 1383af9b-f1f0-f417-4de5-65fe9456cc20, Content: Sentence 3 from document 0
Chunk ID: 2db19b1a-de14-353b-f4e1-085e8463361c, Content: Sentence 4 from document 0
Chunk ID: 9faf986a-f028-7714-068a-1c795e8f2598, Content: Sentence 5 from document 0
Chunk ID: ef593ead-5a4a-392f-7ad8-471a50f033e8, Content: Sentence 6 from document 0
Chunk ID: e161950f-021f-7300-4d05-3166738b94cf, Content: Sentence 7 from document 0
Chunk ID: 90610fc4-67c1-e740-f043-709c5978867a, Content: Sentence 8 from document 0
Chunk ID: 97712879-6fff-98ad-0558-e9f42e6b81d3, Content: Sentence 9 from document 0
Chunk ID: aea70411-51df-61ba-d2f0-cb2b5972c210, Content: Sentence 0 from document 1
Chunk ID: b678a463-7b84-92b8-abb2-27e9a1977e3c, Content: Sentence 1 from document 1
Chunk ID: 27bd63da-909c-1606-a109-75bdb9479882, Content: Sentence 2 from document 1
Chunk ID: a2ad49ad-f9be-5372-e0c7-7b0221d0b53e, Content: Sentence 3 from document 1
Chunk ID: cac53bcd-1965-082a-c0f4-ceee7323fc70, Content: Sentence 4 from document 1
```

Query results:

Result 1: Sentence 5 from document 0
Result 2: Sentence 5 from document 1
Result 3: Sentence 5 from document 2

[//]: # (## Documentation)

---------

Signed-off-by: Varsha Prasad Narsing <varshaprasad96@gmail.com>
221 lines
7.2 KiB
Python
221 lines
7.2 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
import asyncio
|
|
import logging
|
|
import secrets
|
|
import string
|
|
from typing import Any
|
|
|
|
from pydantic import TypeAdapter
|
|
|
|
from llama_stack.apis.common.content_types import (
|
|
URL,
|
|
InterleavedContent,
|
|
InterleavedContentItem,
|
|
TextContentItem,
|
|
)
|
|
from llama_stack.apis.inference import Inference
|
|
from llama_stack.apis.tools import (
|
|
ListToolDefsResponse,
|
|
RAGDocument,
|
|
RAGQueryConfig,
|
|
RAGQueryResult,
|
|
RAGToolRuntime,
|
|
Tool,
|
|
ToolDef,
|
|
ToolInvocationResult,
|
|
ToolParameter,
|
|
ToolRuntime,
|
|
)
|
|
from llama_stack.apis.vector_io import QueryChunksResponse, VectorIO
|
|
from llama_stack.providers.datatypes import ToolsProtocolPrivate
|
|
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
|
|
from llama_stack.providers.utils.memory.vector_store import (
|
|
content_from_doc,
|
|
make_overlapped_chunks,
|
|
)
|
|
|
|
from .config import RagToolRuntimeConfig
|
|
from .context_retriever import generate_rag_query
|
|
|
|
log = logging.getLogger(__name__)
|
|
|
|
|
|
def make_random_string(length: int = 8) -> str:
    """Return a random string of ``length`` ASCII letters and digits.

    Characters are drawn one at a time with ``secrets.choice``; a
    non-positive ``length`` yields an empty string.
    """
    alphabet = string.ascii_letters + string.digits
    picks = [secrets.choice(alphabet) for _ in range(length)]
    return "".join(picks)
|
|
|
|
|
|
class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
    """Tool runtime that inserts documents into vector DBs and serves ``knowledge_search`` queries.

    Insertion chunks documents and forwards them to the vector IO API; querying
    fans a generated query out to every requested vector DB, merges the hits by
    score, and renders them into a list of text items bounded by the token
    budget in the query config.
    """

    def __init__(
        self,
        config: RagToolRuntimeConfig,
        vector_io_api: VectorIO,
        inference_api: Inference,
    ):
        self.config = config
        self.vector_io_api = vector_io_api
        self.inference_api = inference_api

    async def initialize(self):
        """No-op lifecycle hook."""
        pass

    async def shutdown(self):
        """No-op lifecycle hook."""
        pass

    async def register_tool(self, tool: Tool) -> None:
        """No-op: this runtime keeps no per-tool state."""
        pass

    async def unregister_tool(self, tool_id: str) -> None:
        """No-op: this runtime keeps no per-tool state."""
        pass

    async def insert(
        self,
        documents: list[RAGDocument],
        vector_db_id: str,
        chunk_size_in_tokens: int = 512,
    ) -> None:
        """Chunk ``documents`` and insert the chunks into ``vector_db_id``.

        Nothing is sent to the vector IO API when chunking produces no chunks
        (for example, when ``documents`` is empty).
        """
        chunks = []
        for doc in documents:
            content = await content_from_doc(doc)
            chunks.extend(
                make_overlapped_chunks(
                    doc.document_id,
                    content,
                    chunk_size_in_tokens,
                    # A quarter of the chunk size; presumably the overlap between
                    # consecutive chunks — confirm against make_overlapped_chunks.
                    chunk_size_in_tokens // 4,
                    doc.metadata,
                )
            )

        if not chunks:
            return

        await self.vector_io_api.insert_chunks(
            chunks=chunks,
            vector_db_id=vector_db_id,
        )

    async def query(
        self,
        content: InterleavedContent,
        vector_db_ids: list[str],
        query_config: RAGQueryConfig | None = None,
    ) -> RAGQueryResult:
        """Retrieve chunks relevant to ``content`` from the given vector DBs.

        Args:
            content: The user content the retrieval query is generated from.
            vector_db_ids: Vector DBs to search; all are queried concurrently.
            query_config: Retrieval settings; a default ``RAGQueryConfig()`` is
                used when omitted.

        Returns:
            A ``RAGQueryResult`` whose ``content`` is ``None`` when nothing was
            retrieved, otherwise a list of text items (header, one item per
            included chunk, footer). ``metadata["document_ids"]`` lists the
            document IDs of exactly the chunks that were included.

        Raises:
            ValueError: If ``vector_db_ids`` is empty.
        """
        if not vector_db_ids:
            raise ValueError(
                "No vector DBs were provided to the knowledge search tool. Please provide at least one vector DB ID."
            )

        query_config = query_config or RAGQueryConfig()
        query = await generate_rag_query(
            query_config.query_generator_config,
            content,
            inference_api=self.inference_api,
        )
        tasks = [
            self.vector_io_api.query_chunks(
                vector_db_id=vector_db_id,
                query=query,
                params={
                    "max_chunks": query_config.max_chunks,
                    "mode": query_config.mode,
                },
            )
            for vector_db_id in vector_db_ids
        ]
        results: list[QueryChunksResponse] = await asyncio.gather(*tasks)
        chunks = [c for r in results for c in r.chunks]
        scores = [s for r in results for s in r.scores]

        if not chunks:
            return RAGQueryResult(content=None)

        # Merge hits from all DBs, highest score first, and keep the global top N.
        ranked = sorted(zip(chunks, scores, strict=False), key=lambda pair: pair[1], reverse=True)
        top_chunks = [chunk for chunk, _ in ranked[: query_config.max_chunks]]

        tokens = 0
        picked: list[InterleavedContentItem] = [
            TextContentItem(
                text=f"knowledge_search tool found {len(top_chunks)} chunks:\nBEGIN of knowledge_search tool results.\n"
            )
        ]
        # Chunks that actually made it into the context. Tracked separately because
        # `picked` also holds the header/footer items, so its length is not a chunk count.
        used_chunks = []
        for i, chunk in enumerate(top_chunks):
            metadata = chunk.metadata
            # Chunks stored without token accounting count as zero rather than
            # failing the whole query with a KeyError.
            tokens += metadata.get("token_count", 0)
            tokens += metadata.get("metadata_token_count", 0)

            if tokens > query_config.max_tokens_in_context:
                log.error(
                    "Using %s chunks; reached max tokens in context: %s",
                    len(used_chunks),
                    tokens,
                )
                break

            # Token bookkeeping fields are internal and are not exposed to the template.
            metadata_subset = {k: v for k, v in metadata.items() if k not in ["token_count", "metadata_token_count"]}
            text_content = query_config.chunk_template.format(index=i + 1, chunk=chunk, metadata=metadata_subset)
            picked.append(TextContentItem(text=text_content))
            used_chunks.append(chunk)

        picked.append(TextContentItem(text="END of knowledge_search tool results.\n"))
        picked.append(
            TextContentItem(
                text=f'The above results were retrieved to help answer the user\'s query: "{interleaved_content_as_str(content)}". Use them as supporting information only in answering this query.\n',
            )
        )

        return RAGQueryResult(
            content=picked,
            metadata={
                "document_ids": [c.metadata["document_id"] for c in used_chunks],
            },
        )

    async def list_runtime_tools(
        self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
    ) -> ListToolDefsResponse:
        """Return the static tool definitions exposed by this runtime.

        Both arguments are accepted for interface compatibility and are not used.
        """
        # Parameters are not listed since these methods are not yet invoked automatically
        # by the LLM. The method is only implemented so things like /tools can list without
        # encountering fatals.
        return ListToolDefsResponse(
            data=[
                ToolDef(
                    name="insert_into_memory",
                    description="Insert documents into memory",
                ),
                ToolDef(
                    name="knowledge_search",
                    description="Search for information in a database.",
                    parameters=[
                        ToolParameter(
                            name="query",
                            description="The query to search for. Can be a natural language sentence or keywords.",
                            parameter_type="string",
                        ),
                    ],
                ),
            ]
        )

    async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult:
        """Run a knowledge search described by ``kwargs`` and wrap the result.

        ``kwargs`` must contain ``"query"`` (a ``KeyError`` is raised otherwise)
        and may contain ``"vector_db_ids"`` and ``"query_config"``.

        NOTE(review): ``tool_name`` is not consulted — every invocation is
        treated as a knowledge search, including ``insert_into_memory``.
        """
        vector_db_ids = kwargs.get("vector_db_ids", [])
        query_config = kwargs.get("query_config")
        if query_config:
            query_config = TypeAdapter(RAGQueryConfig).validate_python(query_config)
        else:
            # handle someone passing an empty dict
            query_config = RAGQueryConfig()

        query = kwargs["query"]
        result = await self.query(
            content=query,
            vector_db_ids=vector_db_ids,
            query_config=query_config,
        )

        return ToolInvocationResult(
            content=result.content,
            metadata=result.metadata,
        )