Partial cleanup of weaviate

This commit is contained in:
Zain Hasan 2024-10-07 01:21:50 -04:00 committed by Ashwin Bharambe
parent 862f8ddb8d
commit 118c0ef105
4 changed files with 82 additions and 103 deletions

View file

@ -39,14 +39,6 @@ RoutedProtocol = Union[
] ]
class ModelRegistry(Protocol):
def get_model(self, identifier: str) -> ModelDef: ...
class MemoryBankRegistry(Protocol):
def get_memory_bank(self, identifier: str) -> MemoryBankDef: ...
# Example: /inference, /safety # Example: /inference, /safety
class AutoRoutedProviderSpec(ProviderSpec): class AutoRoutedProviderSpec(ProviderSpec):
provider_type: str = "router" provider_type: str = "router"

View file

@ -64,7 +64,6 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An
+ [x.value for x in routing_table_apis] + [x.value for x in routing_table_apis]
+ [x.value for x in router_apis] + [x.value for x in router_apis]
) )
print(f"{apis_to_serve=}")
for info in builtin_automatically_routed_apis(): for info in builtin_automatically_routed_apis():
if info.router_api.value not in apis_to_serve: if info.router_api.value not in apis_to_serve:

View file

@ -4,15 +4,13 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from llama_models.schema_utils import json_schema_type from pydantic import BaseModel
from pydantic import BaseModel, Field
class WeaviateRequestProviderData(BaseModel): class WeaviateRequestProviderData(BaseModel):
# if there _is_ provider data, it must specify the API KEY
# if you want it to be optional, use Optional[str]
weaviate_api_key: str weaviate_api_key: str
weaviate_cluster_url: str weaviate_cluster_url: str
@json_schema_type
class WeaviateConfig(BaseModel): class WeaviateConfig(BaseModel):
collection: str = Field(default="MemoryBank") pass

View file

@ -1,14 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json import json
import uuid from typing import Any, Dict, List, Optional
from typing import List, Optional, Dict, Any
from numpy.typing import NDArray
import weaviate import weaviate
import weaviate.classes as wvc import weaviate.classes as wvc
from numpy.typing import NDArray
from weaviate.classes.init import Auth from weaviate.classes.init import Auth
from llama_stack.apis.memory import * from llama_stack.apis.memory import * # noqa: F403
from llama_stack.distribution.request_headers import get_request_provider_data from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.utils.memory.vector_store import ( from llama_stack.providers.utils.memory.vector_store import (
BankWithIndex, BankWithIndex,
EmbeddingIndex, EmbeddingIndex,
@ -16,23 +21,27 @@ from llama_stack.providers.utils.memory.vector_store import (
from .config import WeaviateConfig, WeaviateRequestProviderData from .config import WeaviateConfig, WeaviateRequestProviderData
class WeaviateIndex(EmbeddingIndex): class WeaviateIndex(EmbeddingIndex):
def __init__(self, client: weaviate.Client, collection: str): def __init__(self, client: weaviate.Client, collection: str):
self.client = client self.client = client
self.collection = collection self.collection = collection
async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray): async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
assert len(chunks) == len(embeddings), f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}" assert len(chunks) == len(
embeddings
), f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
data_objects = [] data_objects = []
for i, chunk in enumerate(chunks): for i, chunk in enumerate(chunks):
data_objects.append(
data_objects.append(wvc.data.DataObject( wvc.data.DataObject(
properties={ properties={
"chunk_content": chunk, "chunk_content": chunk,
}, },
vector = embeddings[i].tolist() vector=embeddings[i].tolist(),
)) )
)
# Inserting chunks into a prespecified Weaviate collection # Inserting chunks into a prespecified Weaviate collection
assert self.collection is not None, "Collection name must be specified" assert self.collection is not None, "Collection name must be specified"
@ -40,16 +49,15 @@ class WeaviateIndex(EmbeddingIndex):
await my_collection.data.insert_many(data_objects) await my_collection.data.insert_many(data_objects)
async def query(self, embedding: NDArray, k: int) -> QueryDocumentsResponse: async def query(self, embedding: NDArray, k: int) -> QueryDocumentsResponse:
assert self.collection is not None, "Collection name must be specified" assert self.collection is not None, "Collection name must be specified"
my_collection = self.client.collections.get(self.collection) my_collection = self.client.collections.get(self.collection)
results = my_collection.query.near_vector( results = my_collection.query.near_vector(
near_vector = embedding.tolist(), near_vector=embedding.tolist(),
limit = k, limit=k,
return_meta_data = wvc.query.MetadataQuery(distance=True) return_meta_data=wvc.query.MetadataQuery(distance=True),
) )
chunks = [] chunks = []
@ -62,99 +70,81 @@ class WeaviateIndex(EmbeddingIndex):
except Exception as e: except Exception as e:
import traceback import traceback
traceback.print_exc() traceback.print_exc()
print(f"Failed to parse document: {e}") print(f"Failed to parse document: {e}")
return QueryDocumentsResponse(chunks=chunks, scores=scores) return QueryDocumentsResponse(chunks=chunks, scores=scores)
class WeaviateMemoryAdapter(Memory): class WeaviateMemoryAdapter(Memory, NeedsRequestProviderData):
def __init__(self, config: WeaviateConfig) -> None: def __init__(self, config: WeaviateConfig) -> None:
self.config = config self.config = config
self.client = None self.client_cache = {}
self.cache = {} self.cache = {}
def _get_client(self) -> weaviate.Client: def _get_client(self) -> weaviate.Client:
request_provider_data = get_request_provider_data() provider_data = self.get_request_provider_data()
assert provider_data is not None, "Request provider data must be set"
assert isinstance(provider_data, WeaviateRequestProviderData)
if request_provider_data is not None: key = f"{provider_data.weaviate_cluster_url}::{provider_data.weaviate_api_key}"
assert isinstance(request_provider_data, WeaviateRequestProviderData) if key in self.client_cache:
return self.client_cache[key]
# Connect to Weaviate Cloud client = weaviate.connect_to_weaviate_cloud(
return weaviate.connect_to_weaviate_cloud( cluster_url=provider_data.weaviate_cluster_url,
cluster_url = request_provider_data.weaviate_cluster_url, auth_credentials=Auth.api_key(provider_data.weaviate_api_key),
auth_credentials = Auth.api_key(request_provider_data.weaviate_api_key), )
) self.client_cache[key] = client
return client
async def initialize(self) -> None: async def initialize(self) -> None:
try: pass
self.client = self._get_client()
# Create collection if it doesn't exist
if not self.client.collections.exists(self.config.collection):
self.client.collections.create(
name = self.config.collection,
vectorizer_config = wvc.config.Configure.Vectorizer.none(),
properties=[
wvc.config.Property(
name="chunk_content",
data_type=wvc.config.DataType.TEXT,
),
]
)
except Exception as e:
import traceback
traceback.print_exc()
raise RuntimeError("Could not connect to Weaviate server") from e
async def shutdown(self) -> None: async def shutdown(self) -> None:
self.client = self._get_client() for client in self.client_cache.values():
client.close()
if self.client: async def register_memory_bank(
self.client.close()
async def create_memory_bank(
self, self,
name: str, memory_bank: MemoryBankDef,
config: MemoryBankConfig, ) -> None:
url: Optional[URL] = None, assert (
) -> MemoryBank: memory_bank.type == MemoryBankType.vector.value
bank_id = str(uuid.uuid4()) ), f"Only vector banks are supported {memory_bank.type}"
bank = MemoryBank(
bank_id=bank_id,
name=name,
config=config,
url=url,
)
self.client = self._get_client()
# Store the bank as a new collection in Weaviate client = await self._get_client()
self.client.collections.create(
name=bank_id # Create collection if it doesn't exist
) if not client.collections.exists(memory_bank.identifier):
client.collections.create(
name=memory_bank.identifier,
vectorizer_config=wvc.config.Configure.Vectorizer.none(),
properties=[
wvc.config.Property(
name="chunk_content",
data_type=wvc.config.DataType.TEXT,
),
],
)
index = BankWithIndex( index = BankWithIndex(
bank=bank, bank=memory_bank,
index=WeaviateIndex(cleint = self.client, collection = bank_id), index=WeaviateIndex(client=client, collection=memory_bank.identifier),
) )
self.cache[bank_id] = index self.cache[bank_id] = index
return bank
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
bank_index = await self._get_and_cache_bank_index(bank_id)
if bank_index is None:
return None
return bank_index.bank
async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]: async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]:
self.client = self._get_client()
if bank_id in self.cache: if bank_id in self.cache:
return self.cache[bank_id] return self.cache[bank_id]
collections = await self.client.collections.list_all().keys() bank = await self.memory_bank_store.get_memory_bank(bank_id)
if not bank:
raise ValueError(f"Bank {bank_id} not found")
client = await self._get_client()
collections = await client.collections.list_all().keys()
for collection in collections: for collection in collections:
if collection == bank_id: if collection == bank_id: