mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-29 03:14:19 +00:00
feat: Qdrant Vector index support (#221)
This PR adds support for Qdrant - https://qdrant.tech/ to be used as a vector memory. I've unit-tested the methods to confirm that they work as intended. To run Qdrant ``` docker run -p 6333:6333 qdrant/qdrant ```
This commit is contained in:
parent
668a495aba
commit
4c3d33e6f4
11 changed files with 242 additions and 7 deletions
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -13,5 +13,6 @@ xcuserdata/
|
||||||
Package.resolved
|
Package.resolved
|
||||||
*.pte
|
*.pte
|
||||||
*.ipynb_checkpoints*
|
*.ipynb_checkpoints*
|
||||||
|
.venv/
|
||||||
.idea
|
.idea
|
||||||
_build
|
_build
|
||||||
|
|
|
@ -38,7 +38,9 @@ class ChromaIndex(EmbeddingIndex):
|
||||||
ids=[f"{c.document_id}:chunk-{i}" for i, c in enumerate(chunks)],
|
ids=[f"{c.document_id}:chunk-{i}" for i, c in enumerate(chunks)],
|
||||||
)
|
)
|
||||||
|
|
||||||
async def query(self, embedding: NDArray, k: int) -> QueryDocumentsResponse:
|
async def query(
|
||||||
|
self, embedding: NDArray, k: int, score_threshold: float
|
||||||
|
) -> QueryDocumentsResponse:
|
||||||
results = await self.collection.query(
|
results = await self.collection.query(
|
||||||
query_embeddings=[embedding.tolist()],
|
query_embeddings=[embedding.tolist()],
|
||||||
n_results=k,
|
n_results=k,
|
||||||
|
|
|
@ -91,7 +91,9 @@ class PGVectorIndex(EmbeddingIndex):
|
||||||
)
|
)
|
||||||
execute_values(self.cursor, query, values, template="(%s, %s, %s::vector)")
|
execute_values(self.cursor, query, values, template="(%s, %s, %s::vector)")
|
||||||
|
|
||||||
async def query(self, embedding: NDArray, k: int) -> QueryDocumentsResponse:
|
async def query(
|
||||||
|
self, embedding: NDArray, k: int, score_threshold: float
|
||||||
|
) -> QueryDocumentsResponse:
|
||||||
self.cursor.execute(
|
self.cursor.execute(
|
||||||
f"""
|
f"""
|
||||||
SELECT document, embedding <-> %s::vector AS distance
|
SELECT document, embedding <-> %s::vector AS distance
|
||||||
|
|
15
llama_stack/providers/adapters/memory/qdrant/__init__.py
Normal file
15
llama_stack/providers/adapters/memory/qdrant/__init__.py
Normal file
|
@ -0,0 +1,15 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from .config import QdrantConfig
|
||||||
|
|
||||||
|
|
||||||
|
async def get_adapter_impl(config: QdrantConfig, _deps):
|
||||||
|
from .qdrant import QdrantVectorMemoryAdapter
|
||||||
|
|
||||||
|
impl = QdrantVectorMemoryAdapter(config)
|
||||||
|
await impl.initialize()
|
||||||
|
return impl
|
25
llama_stack/providers/adapters/memory/qdrant/config.py
Normal file
25
llama_stack/providers/adapters/memory/qdrant/config.py
Normal file
|
@ -0,0 +1,25 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from llama_models.schema_utils import json_schema_type
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class QdrantConfig(BaseModel):
|
||||||
|
location: Optional[str] = None
|
||||||
|
url: Optional[str] = None
|
||||||
|
port: Optional[int] = 6333
|
||||||
|
grpc_port: int = 6334
|
||||||
|
prefer_grpc: bool = False
|
||||||
|
https: Optional[bool] = None
|
||||||
|
api_key: Optional[str] = None
|
||||||
|
prefix: Optional[str] = None
|
||||||
|
timeout: Optional[int] = None
|
||||||
|
host: Optional[str] = None
|
||||||
|
path: Optional[str] = None
|
170
llama_stack/providers/adapters/memory/qdrant/qdrant.py
Normal file
170
llama_stack/providers/adapters/memory/qdrant/qdrant.py
Normal file
|
@ -0,0 +1,170 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import traceback
|
||||||
|
import uuid
|
||||||
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
|
from numpy.typing import NDArray
|
||||||
|
from qdrant_client import AsyncQdrantClient, models
|
||||||
|
from qdrant_client.models import PointStruct
|
||||||
|
|
||||||
|
from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate
|
||||||
|
|
||||||
|
from llama_stack.apis.memory import * # noqa: F403
|
||||||
|
|
||||||
|
from llama_stack.providers.adapters.memory.qdrant.config import QdrantConfig
|
||||||
|
from llama_stack.providers.utils.memory.vector_store import (
|
||||||
|
BankWithIndex,
|
||||||
|
EmbeddingIndex,
|
||||||
|
)
|
||||||
|
|
||||||
|
CHUNK_ID_KEY = "_chunk_id"
|
||||||
|
|
||||||
|
|
||||||
|
def convert_id(_id: str) -> str:
|
||||||
|
"""
|
||||||
|
Converts any string into a UUID string based on a seed.
|
||||||
|
|
||||||
|
Qdrant accepts UUID strings and unsigned integers as point ID.
|
||||||
|
We use a seed to convert each string into a UUID string deterministically.
|
||||||
|
This allows us to overwrite the same point with the original ID.
|
||||||
|
"""
|
||||||
|
return str(uuid.uuid5(uuid.NAMESPACE_DNS, _id))
|
||||||
|
|
||||||
|
|
||||||
|
class QdrantIndex(EmbeddingIndex):
|
||||||
|
def __init__(self, client: AsyncQdrantClient, collection_name: str):
|
||||||
|
self.client = client
|
||||||
|
self.collection_name = collection_name
|
||||||
|
|
||||||
|
async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
|
||||||
|
assert len(chunks) == len(
|
||||||
|
embeddings
|
||||||
|
), f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
|
||||||
|
|
||||||
|
if not await self.client.collection_exists(self.collection_name):
|
||||||
|
await self.client.create_collection(
|
||||||
|
self.collection_name,
|
||||||
|
vectors_config=models.VectorParams(
|
||||||
|
size=len(embeddings[0]), distance=models.Distance.COSINE
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
points = []
|
||||||
|
for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)):
|
||||||
|
chunk_id = f"{chunk.document_id}:chunk-{i}"
|
||||||
|
points.append(
|
||||||
|
PointStruct(
|
||||||
|
id=convert_id(chunk_id),
|
||||||
|
vector=embedding,
|
||||||
|
payload={"chunk_content": chunk.model_dump()}
|
||||||
|
| {CHUNK_ID_KEY: chunk_id},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
await self.client.upsert(collection_name=self.collection_name, points=points)
|
||||||
|
|
||||||
|
async def query(
|
||||||
|
self, embedding: NDArray, k: int, score_threshold: float
|
||||||
|
) -> QueryDocumentsResponse:
|
||||||
|
results = (
|
||||||
|
await self.client.query_points(
|
||||||
|
collection_name=self.collection_name,
|
||||||
|
query=embedding.tolist(),
|
||||||
|
limit=k,
|
||||||
|
with_payload=True,
|
||||||
|
score_threshold=score_threshold,
|
||||||
|
)
|
||||||
|
).points
|
||||||
|
|
||||||
|
chunks, scores = [], []
|
||||||
|
for point in results:
|
||||||
|
assert isinstance(point, models.ScoredPoint)
|
||||||
|
assert point.payload is not None
|
||||||
|
|
||||||
|
try:
|
||||||
|
chunk = Chunk(**point.payload["chunk_content"])
|
||||||
|
except Exception:
|
||||||
|
traceback.print_exc()
|
||||||
|
continue
|
||||||
|
|
||||||
|
chunks.append(chunk)
|
||||||
|
scores.append(point.score)
|
||||||
|
|
||||||
|
return QueryDocumentsResponse(chunks=chunks, scores=scores)
|
||||||
|
|
||||||
|
|
||||||
|
class QdrantVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
|
||||||
|
def __init__(self, config: QdrantConfig) -> None:
|
||||||
|
self.config = config
|
||||||
|
self.client = AsyncQdrantClient(**self.config.model_dump(exclude_none=True))
|
||||||
|
self.cache = {}
|
||||||
|
|
||||||
|
async def initialize(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def shutdown(self) -> None:
|
||||||
|
self.client.close()
|
||||||
|
|
||||||
|
async def register_memory_bank(
|
||||||
|
self,
|
||||||
|
memory_bank: MemoryBankDef,
|
||||||
|
) -> None:
|
||||||
|
assert (
|
||||||
|
memory_bank.type == MemoryBankType.vector.value
|
||||||
|
), f"Only vector banks are supported {memory_bank.type}"
|
||||||
|
|
||||||
|
index = BankWithIndex(
|
||||||
|
bank=memory_bank,
|
||||||
|
index=QdrantIndex(self.client, memory_bank.identifier),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.cache[memory_bank.identifier] = index
|
||||||
|
|
||||||
|
async def list_memory_banks(self) -> List[MemoryBankDef]:
|
||||||
|
# Qdrant doesn't have collection level metadata to store the bank properties
|
||||||
|
# So we only return from the cache value
|
||||||
|
return [i.bank for i in self.cache.values()]
|
||||||
|
|
||||||
|
async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]:
|
||||||
|
if bank_id in self.cache:
|
||||||
|
return self.cache[bank_id]
|
||||||
|
|
||||||
|
bank = await self.memory_bank_store.get_memory_bank(bank_id)
|
||||||
|
if not bank:
|
||||||
|
raise ValueError(f"Bank {bank_id} not found")
|
||||||
|
|
||||||
|
index = BankWithIndex(
|
||||||
|
bank=bank,
|
||||||
|
index=QdrantIndex(client=self.client, collection_name=bank_id),
|
||||||
|
)
|
||||||
|
self.cache[bank_id] = index
|
||||||
|
return index
|
||||||
|
|
||||||
|
async def insert_documents(
|
||||||
|
self,
|
||||||
|
bank_id: str,
|
||||||
|
documents: List[MemoryBankDocument],
|
||||||
|
ttl_seconds: Optional[int] = None,
|
||||||
|
) -> None:
|
||||||
|
index = await self._get_and_cache_bank_index(bank_id)
|
||||||
|
if not index:
|
||||||
|
raise ValueError(f"Bank {bank_id} not found")
|
||||||
|
|
||||||
|
await index.insert_documents(documents)
|
||||||
|
|
||||||
|
async def query_documents(
|
||||||
|
self,
|
||||||
|
bank_id: str,
|
||||||
|
query: InterleavedTextMedia,
|
||||||
|
params: Optional[Dict[str, Any]] = None,
|
||||||
|
) -> QueryDocumentsResponse:
|
||||||
|
index = await self._get_and_cache_bank_index(bank_id)
|
||||||
|
if not index:
|
||||||
|
raise ValueError(f"Bank {bank_id} not found")
|
||||||
|
|
||||||
|
return await index.query_documents(query, params)
|
|
@ -50,7 +50,9 @@ class WeaviateIndex(EmbeddingIndex):
|
||||||
# TODO: make this async friendly
|
# TODO: make this async friendly
|
||||||
collection.data.insert_many(data_objects)
|
collection.data.insert_many(data_objects)
|
||||||
|
|
||||||
async def query(self, embedding: NDArray, k: int) -> QueryDocumentsResponse:
|
async def query(
|
||||||
|
self, embedding: NDArray, k: int, score_threshold: float
|
||||||
|
) -> QueryDocumentsResponse:
|
||||||
collection = self.client.collections.get(self.collection_name)
|
collection = self.client.collections.get(self.collection_name)
|
||||||
|
|
||||||
results = collection.query.near_vector(
|
results = collection.query.near_vector(
|
||||||
|
|
|
@ -75,4 +75,13 @@ def available_providers() -> List[ProviderSpec]:
|
||||||
config_class="llama_stack.providers.adapters.memory.sample.SampleConfig",
|
config_class="llama_stack.providers.adapters.memory.sample.SampleConfig",
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
|
remote_provider_spec(
|
||||||
|
Api.memory,
|
||||||
|
AdapterSpec(
|
||||||
|
adapter_type="qdrant",
|
||||||
|
pip_packages=EMBEDDING_DEPS + ["qdrant-client"],
|
||||||
|
module="llama_stack.providers.adapters.memory.qdrant",
|
||||||
|
config_class="llama_stack.providers.adapters.memory.qdrant.QdrantConfig",
|
||||||
|
),
|
||||||
|
),
|
||||||
]
|
]
|
||||||
|
|
|
@ -15,6 +15,11 @@ providers:
|
||||||
- provider_id: test-weaviate
|
- provider_id: test-weaviate
|
||||||
provider_type: remote::weaviate
|
provider_type: remote::weaviate
|
||||||
config: {}
|
config: {}
|
||||||
|
- provider_id: test-qdrant
|
||||||
|
provider_type: remote::qdrant
|
||||||
|
config:
|
||||||
|
host: localhost
|
||||||
|
port: 6333
|
||||||
# if a provider needs private keys from the client, they use the
|
# if a provider needs private keys from the client, they use the
|
||||||
# "get_request_provider_data" function (see distribution/request_headers.py)
|
# "get_request_provider_data" function (see distribution/request_headers.py)
|
||||||
# this is a place to provide such data.
|
# this is a place to provide such data.
|
||||||
|
|
|
@ -144,10 +144,11 @@ async def test_query_documents(memory_settings, sample_documents):
|
||||||
|
|
||||||
# Test case 5: Query with threshold on similarity score
|
# Test case 5: Query with threshold on similarity score
|
||||||
query5 = "quantum computing" # Not directly related to any document
|
query5 = "quantum computing" # Not directly related to any document
|
||||||
params5 = {"score_threshold": 0.5}
|
params5 = {"score_threshold": 0.2}
|
||||||
response5 = await memory_impl.query_documents("test_bank", query5, params5)
|
response5 = await memory_impl.query_documents("test_bank", query5, params5)
|
||||||
assert_valid_response(response5)
|
assert_valid_response(response5)
|
||||||
assert all(score >= 0.5 for score in response5.scores)
|
print("The scores are:", response5.scores)
|
||||||
|
assert all(score >= 0.2 for score in response5.scores)
|
||||||
|
|
||||||
|
|
||||||
def assert_valid_response(response: QueryDocumentsResponse):
|
def assert_valid_response(response: QueryDocumentsResponse):
|
||||||
|
|
|
@ -140,7 +140,9 @@ class EmbeddingIndex(ABC):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def query(self, embedding: NDArray, k: int) -> QueryDocumentsResponse:
|
async def query(
|
||||||
|
self, embedding: NDArray, k: int, score_threshold: float
|
||||||
|
) -> QueryDocumentsResponse:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
|
@ -177,6 +179,7 @@ class BankWithIndex:
|
||||||
if params is None:
|
if params is None:
|
||||||
params = {}
|
params = {}
|
||||||
k = params.get("max_chunks", 3)
|
k = params.get("max_chunks", 3)
|
||||||
|
score_threshold = params.get("score_threshold", 0.0)
|
||||||
|
|
||||||
def _process(c) -> str:
|
def _process(c) -> str:
|
||||||
if isinstance(c, str):
|
if isinstance(c, str):
|
||||||
|
@ -191,4 +194,4 @@ class BankWithIndex:
|
||||||
|
|
||||||
model = get_embedding_model(self.bank.embedding_model)
|
model = get_embedding_model(self.bank.embedding_model)
|
||||||
query_vector = model.encode([query_str])[0].astype(np.float32)
|
query_vector = model.encode([query_str])[0].astype(np.float32)
|
||||||
return await self.index.query(query_vector, k)
|
return await self.index.query(query_vector, k, score_threshold)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue