add Weaviate memory adapter

This commit is contained in:
Zain Hasan 2024-09-23 20:54:06 -04:00
parent 70fb70a71c
commit 49763c4452
5 changed files with 240 additions and 4 deletions

View file

@ -42,7 +42,10 @@ class MemoryClient(Memory):
params={
"bank_id": bank_id,
},
headers={"Content-Type": "application/json"},
headers={
"Content-Type": "application/json",
"X-LlamaStack-ProviderData": json.dumps({"weaviate_api_key": "1234"}),
},
timeout=20,
)
r.raise_for_status()
@ -65,7 +68,10 @@ class MemoryClient(Memory):
"config": config.dict(),
"url": url,
},
headers={"Content-Type": "application/json"},
headers={
"Content-Type": "application/json",
"X-LlamaStack-ProviderData": json.dumps({"weaviate_api_key": "1234"}),
},
timeout=20,
)
r.raise_for_status()
@ -86,7 +92,10 @@ class MemoryClient(Memory):
"bank_id": bank_id,
"documents": [d.dict() for d in documents],
},
headers={"Content-Type": "application/json"},
headers={
"Content-Type": "application/json",
"X-LlamaStack-ProviderData": json.dumps({"weaviate_api_key": "1234"}),
},
timeout=20,
)
r.raise_for_status()
@ -105,7 +114,10 @@ class MemoryClient(Memory):
"query": query,
"params": params,
},
headers={"Content-Type": "application/json"},
headers={
"Content-Type": "application/json",
"X-LlamaStack-ProviderData": json.dumps({"weaviate_api_key": "1234"}),
},
timeout=20,
)
r.raise_for_status()

View file

@ -0,0 +1,8 @@
from llama_stack.distribution.datatypes import RemoteProviderConfig
async def get_adapter_impl(config: RemoteProviderConfig, _deps):
    """Build and initialize the Weaviate memory adapter.

    Args:
        config: Provider configuration; handed to the adapter unchanged.
        _deps: Not used by this provider.

    Returns:
        The ``WeaviateMemoryAdapter`` instance after ``initialize()`` has run.
    """
    # Deferred import: the adapter module (and its weaviate dependency) is
    # only imported when this provider is actually instantiated.
    from .weaviate import WeaviateMemoryAdapter

    # The adapter constructor takes one config object, not separate
    # url/username/password arguments.
    # NOTE(review): the adapter reads ``config.url`` and ``config.collection``;
    # confirm the config class registered for this provider exposes both.
    impl = WeaviateMemoryAdapter(config)
    await impl.initialize()
    return impl

View file

@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field
class WeaviateRequestProviderData(BaseModel):
    # Per-request provider data accepted by the Weaviate adapter.
    # If provider data is sent at all, it must carry the API key; declare the
    # field as Optional[str] instead if the key should be optional.
    weaviate_api_key: str
@json_schema_type
class WeaviateConfig(BaseModel):
    # Static configuration for the Weaviate memory adapter.

    # URL of the Weaviate instance to connect to.
    url: str = Field(default="http://localhost:8080")
    # API key for the Weaviate instance; defaults to an empty string.
    api_key: str = Field(default="")
    # Name of the adapter's default collection.
    collection: str = Field(default="MemoryBank")

View file

@ -0,0 +1,188 @@
import json
import uuid
from typing import List, Optional, Dict, Any
from urllib.parse import urlparse
from numpy.typing import NDArray
import weaviate
import weaviate.classes as wvc
from weaviate.classes.init import Auth
from llama_stack.apis.memory import *
from llama_stack.distribution.request_headers import get_request_provider_data
from llama_stack.providers.utils.memory.vector_store import (
BankWithIndex,
EmbeddingIndex,
)
from .config import WeaviateConfig, WeaviateRequestProviderData
class WeaviateIndex(EmbeddingIndex):
    """Embedding index backed by a single Weaviate collection.

    Each chunk is stored as one Weaviate object: the chunk serialized to JSON
    in the ``chunk_content`` text property, with its embedding as the vector.
    """

    # Floor substituted for a distance of exactly zero so that an exact match
    # yields a large finite score instead of a ZeroDivisionError.
    _MIN_DISTANCE = 1e-9

    def __init__(self, client: weaviate.Client, collection: str):
        """Remember the client and the name of the collection to operate on."""
        self.client = client
        self.collection = collection

    async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
        """Insert ``chunks`` and their ``embeddings`` into the collection.

        Args:
            chunks: Chunks to store; ``chunks[i]`` pairs with ``embeddings[i]``.
            embeddings: One embedding per chunk, in the same order.

        Raises:
            AssertionError: If the lengths differ or no collection is set.
        """
        assert len(chunks) == len(embeddings), (
            f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
        )
        assert self.collection is not None, "Collection name must be specified"

        data_objects = []
        for i, chunk in enumerate(chunks):
            print(f"Adding chunk #{i} tokens={chunk.token_count}")
            data_objects.append(
                wvc.data.DataObject(
                    # The property is a TEXT field, so store the chunk as a
                    # JSON string rather than the model object itself.
                    # (assumes Chunk is a pydantic model exposing .json())
                    properties={"chunk_content": chunk.json()},
                    vector=embeddings[i].tolist(),
                )
            )

        if not data_objects:
            # Nothing to insert; skip the round trip to the server.
            return

        # Inserting chunks into a prespecified Weaviate collection.
        # The client used here is synchronous (query() below also uses its
        # results without awaiting), so the call must not be awaited.
        my_collection = self.client.collections.get(self.collection)
        my_collection.data.insert_many(data_objects)

    async def query(self, embedding: NDArray, k: int) -> QueryDocumentsResponse:
        """Return up to ``k`` stored chunks nearest to ``embedding``.

        Scores are the reciprocal of the distance reported by Weaviate.
        Documents that cannot be parsed are logged and skipped, and the
        returned ``chunks`` and ``scores`` lists always have equal length.

        Raises:
            AssertionError: If no collection is set.
        """
        assert self.collection is not None, "Collection name must be specified"
        my_collection = self.client.collections.get(self.collection)

        results = my_collection.query.near_vector(
            near_vector=embedding.tolist(),
            limit=k,
            return_metadata=wvc.query.MetadataQuery(distance=True),
        )

        chunks = []
        scores = []
        for doc in results.objects:
            try:
                # Rebuild the Chunk from the JSON written by add_chunks().
                chunk = Chunk(**json.loads(doc.properties["chunk_content"]))
                distance = doc.metadata.distance
                if distance == 0:
                    distance = self._MIN_DISTANCE
                score = 1.0 / distance
            except Exception as e:
                # Best effort: one malformed document must not fail the query.
                import traceback

                traceback.print_exc()
                print(f"Failed to parse document: {e}")
                continue
            # Append only once both values exist so the lists stay aligned.
            chunks.append(chunk)
            scores.append(score)

        return QueryDocumentsResponse(chunks=chunks, scores=scores)
class WeaviateMemoryAdapter(Memory):
    """Memory API implementation that keeps each memory bank in its own Weaviate collection.

    The serialized ``MemoryBank`` is stored in the collection's description so
    that a bank can be reloaded when it is not in the in-process cache.
    """

    def __init__(self, config: WeaviateConfig) -> None:
        """Store the configuration; no connection is made until ``initialize()``."""
        print(f"Initializing WeaviateMemoryAdapter with URL: {config.url}")
        self.config = config
        self.client = None
        # bank_id -> BankWithIndex for banks already created or looked up.
        self.cache = {}

    @staticmethod
    def _collection_name(bank_id: str) -> str:
        """Map a bank id to a Weaviate collection name.

        Weaviate collection names must start with an uppercase letter and may
        contain only letters, digits and underscores, so a raw UUID string
        (hyphenated, possibly starting with a digit) cannot be used directly.
        """
        return "Bank_" + bank_id.replace("-", "_")

    def _require_client(self):
        """Return the connected client, or raise if ``initialize()`` has not succeeded."""
        if self.client is None:
            raise RuntimeError("Weaviate client is not initialized")
        return self.client

    def _resolve_api_key(self) -> str:
        """Pick the API key from request provider data, else from the config."""
        request_provider_data = get_request_provider_data()
        if request_provider_data is not None:
            if not isinstance(request_provider_data, WeaviateRequestProviderData):
                raise TypeError(
                    "Expected WeaviateRequestProviderData, got "
                    f"{type(request_provider_data).__name__}"
                )
            return request_provider_data.weaviate_api_key
        if not self.config.api_key:
            raise ValueError("No Weaviate API key in provider data or config")
        return self.config.api_key

    def _create_collection(self, name: str, description: Optional[str] = None) -> None:
        """Create a collection with no vectorizer and a ``chunk_content`` text property."""
        self._require_client().collections.create(
            name=name,
            description=description,
            vectorizer_config=wvc.config.Configure.Vectorizer.none(),
            properties=[
                wvc.config.Property(
                    name="chunk_content",
                    data_type=wvc.config.DataType.TEXT,
                ),
            ],
        )

    async def initialize(self) -> None:
        """Connect to Weaviate Cloud and make sure the configured collection exists.

        Raises:
            RuntimeError: If the key cannot be resolved, the connection fails,
                or the collection cannot be created.
        """
        try:
            # The API key is a secret and is deliberately never printed.
            api_key = self._resolve_api_key()

            # Connect to Weaviate Cloud
            self.client = weaviate.connect_to_weaviate_cloud(
                cluster_url=self.config.url,
                auth_credentials=Auth.api_key(api_key),
            )

            # Create collection if it doesn't exist
            if not self.client.collections.exists(self.config.collection):
                self._create_collection(self.config.collection)
        except Exception as e:
            import traceback

            traceback.print_exc()
            # Do not leave a half-initialized connection open.
            if self.client is not None:
                self.client.close()
                self.client = None
            raise RuntimeError("Could not connect to Weaviate server") from e

    async def shutdown(self) -> None:
        """Close the Weaviate connection if one was opened."""
        if self.client:
            self.client.close()

    async def create_memory_bank(
        self,
        name: str,
        config: MemoryBankConfig,
        url: Optional[URL] = None,
    ) -> MemoryBank:
        """Create a memory bank backed by a new Weaviate collection.

        Args:
            name: Name for the bank.
            config: Memory bank configuration.
            url: Optional URL associated with the bank.

        Returns:
            The new ``MemoryBank``, whose ``bank_id`` is a fresh UUID string.

        Raises:
            RuntimeError: If the adapter has not been initialized.
        """
        bank_id = str(uuid.uuid4())
        bank = MemoryBank(
            bank_id=bank_id,
            name=name,
            config=config,
            url=url,
        )

        # Store the bank as a new collection in Weaviate, keeping the
        # serialized bank in the description so it can be reloaded later.
        collection_name = self._collection_name(bank_id)
        self._create_collection(collection_name, description=bank.json())

        index = BankWithIndex(
            bank=bank,
            index=WeaviateIndex(client=self.client, collection=collection_name),
        )
        self.cache[bank_id] = index
        return bank

    async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
        """Return the bank with ``bank_id``, or ``None`` if it does not exist."""
        bank_index = await self._get_and_cache_bank_index(bank_id)
        if bank_index is None:
            return None
        return bank_index.bank

    async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]:
        """Look up the index for ``bank_id``, consulting the cache first.

        On a cache miss the bank is rebuilt from the JSON stored in the
        description of its Weaviate collection and then cached.

        Returns:
            The ``BankWithIndex``, or ``None`` when no matching collection
            exists or its description does not hold a parsable bank.
        """
        if bank_id in self.cache:
            return self.cache[bank_id]

        collection_name = self._collection_name(bank_id)
        # list_all() is a synchronous call returning a name -> config mapping.
        collection_config = self._require_client().collections.list_all().get(collection_name)
        if collection_config is None or not collection_config.description:
            return None

        try:
            bank = MemoryBank(**json.loads(collection_config.description))
        except (TypeError, ValueError) as e:
            # The collection was not written by this adapter, or is corrupt.
            print(f"Failed to parse memory bank stored in collection {collection_name}: {e}")
            return None

        index = BankWithIndex(
            bank=bank,
            index=WeaviateIndex(self.client, collection_name),
        )
        self.cache[bank_id] = index
        return index

    async def insert_documents(
        self,
        bank_id: str,
        documents: List[MemoryBankDocument],
        ttl_seconds: Optional[int] = None,
    ) -> None:
        """Insert ``documents`` into the bank identified by ``bank_id``.

        ``ttl_seconds`` is accepted but not used by this implementation.

        Raises:
            ValueError: If the bank does not exist.
        """
        index = await self._get_and_cache_bank_index(bank_id)
        if not index:
            raise ValueError(f"Bank {bank_id} not found")

        await index.insert_documents(documents)

    async def query_documents(
        self,
        bank_id: str,
        query: InterleavedTextMedia,
        params: Optional[Dict[str, Any]] = None,
    ) -> QueryDocumentsResponse:
        """Query the bank identified by ``bank_id``.

        Raises:
            ValueError: If the bank does not exist.
        """
        index = await self._get_and_cache_bank_index(bank_id)
        if not index:
            raise ValueError(f"Bank {bank_id} not found")

        return await index.query_documents(query, params)

View file

@ -42,6 +42,15 @@ def available_providers() -> List[ProviderSpec]:
config_class="llama_stack.providers.adapters.memory.pgvector.PGVectorConfig",
),
),
remote_provider_spec(
Api.memory,
AdapterSpec(
adapter_id="weaviate",
pip_packages=EMBEDDING_DEPS + ["weaviate-client"],
module="llama_stack.providers.adapters.memory.weaviate",
provider_data_validator="llama_stack.providers.adapters.memory.weaviate.WeaviateRequestProviderData",
),
),
remote_provider_spec(
api=Api.memory,
adapter=AdapterSpec(