mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00
feat: Qdrant Vector index support
This commit is contained in:
parent
f4f7618120
commit
a0c888c071
4 changed files with 240 additions and 0 deletions
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -13,3 +13,4 @@ xcuserdata/
|
||||||
Package.resolved
|
Package.resolved
|
||||||
*.pte
|
*.pte
|
||||||
*.ipynb_checkpoints*
|
*.ipynb_checkpoints*
|
||||||
|
.venv/
|
||||||
|
|
15
llama_stack/providers/adapters/memory/qdrant/__init__.py
Normal file
15
llama_stack/providers/adapters/memory/qdrant/__init__.py
Normal file
|
@ -0,0 +1,15 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from .config import QdrantConfig
|
||||||
|
|
||||||
|
|
||||||
|
async def get_adapter_impl(config: QdrantConfig, _deps):
|
||||||
|
from .qdrant import QdrantVectorMemoryAdapter
|
||||||
|
|
||||||
|
impl = QdrantVectorMemoryAdapter(config)
|
||||||
|
await impl.initialize()
|
||||||
|
return impl
|
25
llama_stack/providers/adapters/memory/qdrant/config.py
Normal file
25
llama_stack/providers/adapters/memory/qdrant/config.py
Normal file
|
@ -0,0 +1,25 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from llama_models.schema_utils import json_schema_type
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class QdrantConfig(BaseModel):
|
||||||
|
location: Optional[str] = None
|
||||||
|
url: Optional[str] = None
|
||||||
|
port: Optional[int] = 6333
|
||||||
|
grpc_port: int = 6334
|
||||||
|
prefer_grpc: bool = False
|
||||||
|
https: Optional[bool] = None
|
||||||
|
api_key: Optional[str] = None
|
||||||
|
prefix: Optional[str] = None
|
||||||
|
timeout: Optional[int] = None
|
||||||
|
host: Optional[str] = None
|
||||||
|
path: Optional[str] = None
|
199
llama_stack/providers/adapters/memory/qdrant/qdrant.py
Normal file
199
llama_stack/providers/adapters/memory/qdrant/qdrant.py
Normal file
|
@ -0,0 +1,199 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import traceback
|
||||||
|
import uuid
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from numpy.typing import NDArray
|
||||||
|
from qdrant_client import AsyncQdrantClient, models
|
||||||
|
from qdrant_client.models import PointStruct
|
||||||
|
|
||||||
|
from llama_stack.apis.memory import * # noqa: F403
|
||||||
|
from llama_stack.distribution.datatypes import RoutableProvider
|
||||||
|
|
||||||
|
from llama_stack.providers.adapters.memory.qdrant.config import QdrantConfig
|
||||||
|
from llama_stack.providers.utils.memory.vector_store import (
|
||||||
|
BankWithIndex,
|
||||||
|
EmbeddingIndex,
|
||||||
|
)
|
||||||
|
|
||||||
|
CHUNK_ID_KEY = "_chunk_id"
|
||||||
|
METADATA_COLLECTION_NAME = "metadata_store"
|
||||||
|
|
||||||
|
|
||||||
|
def convert_id(_id: str) -> str:
|
||||||
|
"""
|
||||||
|
Converts any string into a UUID string based on a seed.
|
||||||
|
|
||||||
|
Qdrant accepts UUID strings and unsigned integers as point ID.
|
||||||
|
We use a seed to convert each string into a UUID string deterministically.
|
||||||
|
This allows us to overwrite the same point with the original ID.
|
||||||
|
"""
|
||||||
|
return str(uuid.uuid5(uuid.NAMESPACE_DNS, _id))
|
||||||
|
|
||||||
|
|
||||||
|
class QdrantIndex(EmbeddingIndex):
|
||||||
|
def __init__(self, client: AsyncQdrantClient, bank: MemoryBank):
|
||||||
|
self.client = client
|
||||||
|
self.collection_name = bank.name
|
||||||
|
|
||||||
|
async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
|
||||||
|
assert len(chunks) == len(
|
||||||
|
embeddings
|
||||||
|
), f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
|
||||||
|
|
||||||
|
if not await self.client.collection_exists(self.collection_name):
|
||||||
|
await self.client.create_collection(
|
||||||
|
self.collection_name,
|
||||||
|
vectors_config=models.VectorParams(
|
||||||
|
size=len(embeddings[0]), distance=models.Distance.COSINE
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
points = []
|
||||||
|
for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)):
|
||||||
|
chunk_id = f"{chunk.document_id}:chunk-{i}"
|
||||||
|
points.append(
|
||||||
|
PointStruct(
|
||||||
|
id=convert_id(chunk_id),
|
||||||
|
vector=embedding,
|
||||||
|
payload=chunk.model_dump() | {CHUNK_ID_KEY: chunk_id},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
await self.client.upsert(collection_name=self.collection_name, points=points)
|
||||||
|
|
||||||
|
async def query(self, embedding: NDArray, k: int) -> QueryDocumentsResponse:
|
||||||
|
results = (
|
||||||
|
await self.client.query_points(
|
||||||
|
collection_name=self.collection_name, query=embedding.tolist(), limit=k
|
||||||
|
)
|
||||||
|
).points
|
||||||
|
|
||||||
|
chunks, scores = [], []
|
||||||
|
for point in results:
|
||||||
|
assert isinstance(point, models.ScoredPoint)
|
||||||
|
assert point.payload is not None
|
||||||
|
|
||||||
|
try:
|
||||||
|
point.payload.pop(CHUNK_ID_KEY, None)
|
||||||
|
chunk = Chunk(**point.payload)
|
||||||
|
except Exception:
|
||||||
|
traceback.print_exc()
|
||||||
|
continue
|
||||||
|
|
||||||
|
chunks.append(chunk)
|
||||||
|
scores.append(point.score)
|
||||||
|
|
||||||
|
return QueryDocumentsResponse(chunks=chunks, scores=scores)
|
||||||
|
|
||||||
|
|
||||||
|
class QdrantVectorMemoryAdapter(Memory, RoutableProvider):
|
||||||
|
def __init__(self, config: QdrantConfig) -> None:
|
||||||
|
self.config = config
|
||||||
|
self.client = None
|
||||||
|
self.cache = {}
|
||||||
|
|
||||||
|
async def initialize(self) -> None:
|
||||||
|
try:
|
||||||
|
self.client = AsyncQdrantClient(**self.config.model_dump(exclude_none=True))
|
||||||
|
|
||||||
|
if not await self.client.collection_exists(METADATA_COLLECTION_NAME):
|
||||||
|
await self.client.create_collection(
|
||||||
|
METADATA_COLLECTION_NAME, vectors_config={}
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
traceback.print_exc()
|
||||||
|
raise RuntimeError(f"Could not connect to Qdrant: {e}") from e
|
||||||
|
|
||||||
|
async def shutdown(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def validate_routing_keys(self, routing_keys: List[str]) -> None:
|
||||||
|
print(f"[qdrant] Registering memory bank routing keys: {routing_keys}")
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def create_memory_bank(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
config: MemoryBankConfig,
|
||||||
|
url: Optional[URL] = None,
|
||||||
|
) -> MemoryBank:
|
||||||
|
bank_id = str(uuid.uuid4())
|
||||||
|
bank = MemoryBank(
|
||||||
|
bank_id=bank_id,
|
||||||
|
name=name,
|
||||||
|
config=config,
|
||||||
|
url=url,
|
||||||
|
)
|
||||||
|
|
||||||
|
await self.client.upsert(
|
||||||
|
METADATA_COLLECTION_NAME,
|
||||||
|
points=[
|
||||||
|
PointStruct(
|
||||||
|
id=convert_id(bank_id), vector={}, payload=bank.model_dump()
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
index = BankWithIndex(
|
||||||
|
bank=bank,
|
||||||
|
index=QdrantIndex(self.client, bank),
|
||||||
|
)
|
||||||
|
self.cache[bank_id] = index
|
||||||
|
return bank
|
||||||
|
|
||||||
|
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
|
||||||
|
bank_index = await self._get_and_cache_bank_index(bank_id)
|
||||||
|
if bank_index is None:
|
||||||
|
return None
|
||||||
|
return bank_index.bank
|
||||||
|
|
||||||
|
async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]:
|
||||||
|
if bank_id in self.cache:
|
||||||
|
return self.cache[bank_id]
|
||||||
|
|
||||||
|
bank_point = await self.client.retrieve(
|
||||||
|
METADATA_COLLECTION_NAME, ids=[convert_id(bank_id)], with_payload=True
|
||||||
|
)
|
||||||
|
|
||||||
|
if not bank_point:
|
||||||
|
return None
|
||||||
|
|
||||||
|
bank = MemoryBank(**bank_point[0].payload)
|
||||||
|
index = BankWithIndex(
|
||||||
|
bank=bank,
|
||||||
|
index=QdrantIndex(self.client, bank),
|
||||||
|
)
|
||||||
|
self.cache[bank_id] = index
|
||||||
|
return index
|
||||||
|
|
||||||
|
async def insert_documents(
|
||||||
|
self,
|
||||||
|
bank_id: str,
|
||||||
|
documents: List[MemoryBankDocument],
|
||||||
|
ttl_seconds: Optional[int] = None,
|
||||||
|
) -> None:
|
||||||
|
index = await self._get_and_cache_bank_index(bank_id)
|
||||||
|
if not index:
|
||||||
|
raise ValueError(f"Bank {bank_id} not found")
|
||||||
|
|
||||||
|
await index.insert_documents(documents)
|
||||||
|
|
||||||
|
async def query_documents(
|
||||||
|
self,
|
||||||
|
bank_id: str,
|
||||||
|
query: InterleavedTextMedia,
|
||||||
|
params: Optional[Dict[str, Any]] = None,
|
||||||
|
) -> QueryDocumentsResponse:
|
||||||
|
index = await self._get_and_cache_bank_index(bank_id)
|
||||||
|
if not index:
|
||||||
|
raise ValueError(f"Bank {bank_id} not found")
|
||||||
|
|
||||||
|
return await index.query_documents(query, params)
|
Loading…
Add table
Add a link
Reference in a new issue