mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-04 13:15:24 +00:00
add Weaviate memory adapter (#95)
This commit is contained in:
parent
27587f32bc
commit
f4f7618120
4 changed files with 227 additions and 0 deletions
|
@ -0,0 +1,8 @@
|
||||||
|
from .config import WeaviateConfig
|
||||||
|
|
||||||
|
async def get_adapter_impl(config: WeaviateConfig, _deps):
|
||||||
|
from .weaviate import WeaviateMemoryAdapter
|
||||||
|
|
||||||
|
impl = WeaviateMemoryAdapter(config)
|
||||||
|
await impl.initialize()
|
||||||
|
return impl
|
18
llama_stack/providers/adapters/memory/weaviate/config.py
Normal file
18
llama_stack/providers/adapters/memory/weaviate/config.py
Normal file
|
@ -0,0 +1,18 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from llama_models.schema_utils import json_schema_type
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
class WeaviateRequestProviderData(BaseModel):
|
||||||
|
# if there _is_ provider data, it must specify the API KEY
|
||||||
|
# if you want it to be optional, use Optional[str]
|
||||||
|
weaviate_api_key: str
|
||||||
|
weaviate_cluster_url: str
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class WeaviateConfig(BaseModel):
|
||||||
|
collection: str = Field(default="MemoryBank")
|
192
llama_stack/providers/adapters/memory/weaviate/weaviate.py
Normal file
192
llama_stack/providers/adapters/memory/weaviate/weaviate.py
Normal file
|
@ -0,0 +1,192 @@
|
||||||
|
import json
|
||||||
|
import uuid
|
||||||
|
from typing import List, Optional, Dict, Any
|
||||||
|
from numpy.typing import NDArray
|
||||||
|
|
||||||
|
import weaviate
|
||||||
|
import weaviate.classes as wvc
|
||||||
|
from weaviate.classes.init import Auth
|
||||||
|
|
||||||
|
from llama_stack.apis.memory import *
|
||||||
|
from llama_stack.distribution.request_headers import get_request_provider_data
|
||||||
|
from llama_stack.providers.utils.memory.vector_store import (
|
||||||
|
BankWithIndex,
|
||||||
|
EmbeddingIndex,
|
||||||
|
)
|
||||||
|
|
||||||
|
from .config import WeaviateConfig, WeaviateRequestProviderData
|
||||||
|
|
||||||
|
class WeaviateIndex(EmbeddingIndex):
|
||||||
|
def __init__(self, client: weaviate.Client, collection: str):
|
||||||
|
self.client = client
|
||||||
|
self.collection = collection
|
||||||
|
|
||||||
|
async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
|
||||||
|
assert len(chunks) == len(embeddings), f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
|
||||||
|
|
||||||
|
data_objects = []
|
||||||
|
for i, chunk in enumerate(chunks):
|
||||||
|
|
||||||
|
data_objects.append(wvc.data.DataObject(
|
||||||
|
properties={
|
||||||
|
"chunk_content": chunk,
|
||||||
|
},
|
||||||
|
vector = embeddings[i].tolist()
|
||||||
|
))
|
||||||
|
|
||||||
|
# Inserting chunks into a prespecified Weaviate collection
|
||||||
|
assert self.collection is not None, "Collection name must be specified"
|
||||||
|
my_collection = self.client.collections.get(self.collection)
|
||||||
|
|
||||||
|
await my_collection.data.insert_many(data_objects)
|
||||||
|
|
||||||
|
|
||||||
|
async def query(self, embedding: NDArray, k: int) -> QueryDocumentsResponse:
|
||||||
|
assert self.collection is not None, "Collection name must be specified"
|
||||||
|
|
||||||
|
my_collection = self.client.collections.get(self.collection)
|
||||||
|
|
||||||
|
results = my_collection.query.near_vector(
|
||||||
|
near_vector = embedding.tolist(),
|
||||||
|
limit = k,
|
||||||
|
return_meta_data = wvc.query.MetadataQuery(distance=True)
|
||||||
|
)
|
||||||
|
|
||||||
|
chunks = []
|
||||||
|
scores = []
|
||||||
|
for doc in results.objects:
|
||||||
|
try:
|
||||||
|
chunk = doc.properties["chunk_content"]
|
||||||
|
chunks.append(chunk)
|
||||||
|
scores.append(1.0 / doc.metadata.distance)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
print(f"Failed to parse document: {e}")
|
||||||
|
|
||||||
|
return QueryDocumentsResponse(chunks=chunks, scores=scores)
|
||||||
|
|
||||||
|
|
||||||
|
class WeaviateMemoryAdapter(Memory):
|
||||||
|
def __init__(self, config: WeaviateConfig) -> None:
|
||||||
|
self.config = config
|
||||||
|
self.client = None
|
||||||
|
self.cache = {}
|
||||||
|
|
||||||
|
def _get_client(self) -> weaviate.Client:
|
||||||
|
request_provider_data = get_request_provider_data()
|
||||||
|
|
||||||
|
if request_provider_data is not None:
|
||||||
|
assert isinstance(request_provider_data, WeaviateRequestProviderData)
|
||||||
|
|
||||||
|
# Connect to Weaviate Cloud
|
||||||
|
return weaviate.connect_to_weaviate_cloud(
|
||||||
|
cluster_url = request_provider_data.weaviate_cluster_url,
|
||||||
|
auth_credentials = Auth.api_key(request_provider_data.weaviate_api_key),
|
||||||
|
)
|
||||||
|
|
||||||
|
async def initialize(self) -> None:
|
||||||
|
try:
|
||||||
|
self.client = self._get_client()
|
||||||
|
|
||||||
|
# Create collection if it doesn't exist
|
||||||
|
if not self.client.collections.exists(self.config.collection):
|
||||||
|
self.client.collections.create(
|
||||||
|
name = self.config.collection,
|
||||||
|
vectorizer_config = wvc.config.Configure.Vectorizer.none(),
|
||||||
|
properties=[
|
||||||
|
wvc.config.Property(
|
||||||
|
name="chunk_content",
|
||||||
|
data_type=wvc.config.DataType.TEXT,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
raise RuntimeError("Could not connect to Weaviate server") from e
|
||||||
|
|
||||||
|
async def shutdown(self) -> None:
|
||||||
|
self.client = self._get_client()
|
||||||
|
|
||||||
|
if self.client:
|
||||||
|
self.client.close()
|
||||||
|
|
||||||
|
async def create_memory_bank(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
config: MemoryBankConfig,
|
||||||
|
url: Optional[URL] = None,
|
||||||
|
) -> MemoryBank:
|
||||||
|
bank_id = str(uuid.uuid4())
|
||||||
|
bank = MemoryBank(
|
||||||
|
bank_id=bank_id,
|
||||||
|
name=name,
|
||||||
|
config=config,
|
||||||
|
url=url,
|
||||||
|
)
|
||||||
|
self.client = self._get_client()
|
||||||
|
|
||||||
|
# Store the bank as a new collection in Weaviate
|
||||||
|
self.client.collections.create(
|
||||||
|
name=bank_id
|
||||||
|
)
|
||||||
|
|
||||||
|
index = BankWithIndex(
|
||||||
|
bank=bank,
|
||||||
|
index=WeaviateIndex(cleint = self.client, collection = bank_id),
|
||||||
|
)
|
||||||
|
self.cache[bank_id] = index
|
||||||
|
return bank
|
||||||
|
|
||||||
|
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
|
||||||
|
bank_index = await self._get_and_cache_bank_index(bank_id)
|
||||||
|
if bank_index is None:
|
||||||
|
return None
|
||||||
|
return bank_index.bank
|
||||||
|
|
||||||
|
async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]:
|
||||||
|
|
||||||
|
self.client = self._get_client()
|
||||||
|
|
||||||
|
if bank_id in self.cache:
|
||||||
|
return self.cache[bank_id]
|
||||||
|
|
||||||
|
collections = await self.client.collections.list_all().keys()
|
||||||
|
|
||||||
|
for collection in collections:
|
||||||
|
if collection == bank_id:
|
||||||
|
bank = MemoryBank(**json.loads(collection.metadata["bank"]))
|
||||||
|
index = BankWithIndex(
|
||||||
|
bank=bank,
|
||||||
|
index=WeaviateIndex(self.client, collection),
|
||||||
|
)
|
||||||
|
self.cache[bank_id] = index
|
||||||
|
return index
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def insert_documents(
|
||||||
|
self,
|
||||||
|
bank_id: str,
|
||||||
|
documents: List[MemoryBankDocument],
|
||||||
|
) -> None:
|
||||||
|
index = await self._get_and_cache_bank_index(bank_id)
|
||||||
|
if not index:
|
||||||
|
raise ValueError(f"Bank {bank_id} not found")
|
||||||
|
|
||||||
|
await index.insert_documents(documents)
|
||||||
|
|
||||||
|
async def query_documents(
|
||||||
|
self,
|
||||||
|
bank_id: str,
|
||||||
|
query: InterleavedTextMedia,
|
||||||
|
params: Optional[Dict[str, Any]] = None,
|
||||||
|
) -> QueryDocumentsResponse:
|
||||||
|
index = await self._get_and_cache_bank_index(bank_id)
|
||||||
|
if not index:
|
||||||
|
raise ValueError(f"Bank {bank_id} not found")
|
||||||
|
|
||||||
|
return await index.query_documents(query, params)
|
|
@ -56,6 +56,15 @@ def available_providers() -> List[ProviderSpec]:
|
||||||
config_class="llama_stack.providers.adapters.memory.pgvector.PGVectorConfig",
|
config_class="llama_stack.providers.adapters.memory.pgvector.PGVectorConfig",
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
|
remote_provider_spec(
|
||||||
|
Api.memory,
|
||||||
|
AdapterSpec(
|
||||||
|
adapter_id="weaviate",
|
||||||
|
pip_packages=EMBEDDING_DEPS + ["weaviate-client"],
|
||||||
|
module="llama_stack.providers.adapters.memory.weaviate",
|
||||||
|
provider_data_validator="llama_stack.providers.adapters.memory.weaviate.WeaviateRequestProviderData",
|
||||||
|
),
|
||||||
|
),
|
||||||
remote_provider_spec(
|
remote_provider_spec(
|
||||||
api=Api.memory,
|
api=Api.memory,
|
||||||
adapter=AdapterSpec(
|
adapter=AdapterSpec(
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue