mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 07:14:20 +00:00
Partial cleanup of weaviate
This commit is contained in:
parent
862f8ddb8d
commit
118c0ef105
4 changed files with 82 additions and 103 deletions
|
@ -39,14 +39,6 @@ RoutedProtocol = Union[
|
|||
]
|
||||
|
||||
|
||||
class ModelRegistry(Protocol):
|
||||
def get_model(self, identifier: str) -> ModelDef: ...
|
||||
|
||||
|
||||
class MemoryBankRegistry(Protocol):
|
||||
def get_memory_bank(self, identifier: str) -> MemoryBankDef: ...
|
||||
|
||||
|
||||
# Example: /inference, /safety
|
||||
class AutoRoutedProviderSpec(ProviderSpec):
|
||||
provider_type: str = "router"
|
||||
|
|
|
@ -64,7 +64,6 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An
|
|||
+ [x.value for x in routing_table_apis]
|
||||
+ [x.value for x in router_apis]
|
||||
)
|
||||
print(f"{apis_to_serve=}")
|
||||
|
||||
for info in builtin_automatically_routed_apis():
|
||||
if info.router_api.value not in apis_to_serve:
|
||||
|
|
|
@ -4,15 +4,13 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_models.schema_utils import json_schema_type
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class WeaviateRequestProviderData(BaseModel):
|
||||
# if there _is_ provider data, it must specify the API KEY
|
||||
# if you want it to be optional, use Optional[str]
|
||||
weaviate_api_key: str
|
||||
weaviate_cluster_url: str
|
||||
|
||||
@json_schema_type
|
||||
|
||||
class WeaviateConfig(BaseModel):
|
||||
collection: str = Field(default="MemoryBank")
|
||||
pass
|
||||
|
|
|
@ -1,14 +1,19 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from typing import List, Optional, Dict, Any
|
||||
from numpy.typing import NDArray
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import weaviate
|
||||
import weaviate.classes as wvc
|
||||
from numpy.typing import NDArray
|
||||
from weaviate.classes.init import Auth
|
||||
|
||||
from llama_stack.apis.memory import *
|
||||
from llama_stack.distribution.request_headers import get_request_provider_data
|
||||
from llama_stack.apis.memory import * # noqa: F403
|
||||
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
BankWithIndex,
|
||||
EmbeddingIndex,
|
||||
|
@ -16,23 +21,27 @@ from llama_stack.providers.utils.memory.vector_store import (
|
|||
|
||||
from .config import WeaviateConfig, WeaviateRequestProviderData
|
||||
|
||||
|
||||
class WeaviateIndex(EmbeddingIndex):
|
||||
def __init__(self, client: weaviate.Client, collection: str):
|
||||
self.client = client
|
||||
self.collection = collection
|
||||
|
||||
async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
|
||||
assert len(chunks) == len(embeddings), f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
|
||||
assert len(chunks) == len(
|
||||
embeddings
|
||||
), f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
|
||||
|
||||
data_objects = []
|
||||
for i, chunk in enumerate(chunks):
|
||||
|
||||
data_objects.append(wvc.data.DataObject(
|
||||
data_objects.append(
|
||||
wvc.data.DataObject(
|
||||
properties={
|
||||
"chunk_content": chunk,
|
||||
},
|
||||
vector = embeddings[i].tolist()
|
||||
))
|
||||
vector=embeddings[i].tolist(),
|
||||
)
|
||||
)
|
||||
|
||||
# Inserting chunks into a prespecified Weaviate collection
|
||||
assert self.collection is not None, "Collection name must be specified"
|
||||
|
@ -40,7 +49,6 @@ class WeaviateIndex(EmbeddingIndex):
|
|||
|
||||
await my_collection.data.insert_many(data_objects)
|
||||
|
||||
|
||||
async def query(self, embedding: NDArray, k: int) -> QueryDocumentsResponse:
|
||||
assert self.collection is not None, "Collection name must be specified"
|
||||
|
||||
|
@ -49,7 +57,7 @@ class WeaviateIndex(EmbeddingIndex):
|
|||
results = my_collection.query.near_vector(
|
||||
near_vector=embedding.tolist(),
|
||||
limit=k,
|
||||
return_meta_data = wvc.query.MetadataQuery(distance=True)
|
||||
return_meta_data=wvc.query.MetadataQuery(distance=True),
|
||||
)
|
||||
|
||||
chunks = []
|
||||
|
@ -62,99 +70,81 @@ class WeaviateIndex(EmbeddingIndex):
|
|||
|
||||
except Exception as e:
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
print(f"Failed to parse document: {e}")
|
||||
|
||||
return QueryDocumentsResponse(chunks=chunks, scores=scores)
|
||||
|
||||
|
||||
class WeaviateMemoryAdapter(Memory):
|
||||
class WeaviateMemoryAdapter(Memory, NeedsRequestProviderData):
|
||||
def __init__(self, config: WeaviateConfig) -> None:
|
||||
self.config = config
|
||||
self.client = None
|
||||
self.client_cache = {}
|
||||
self.cache = {}
|
||||
|
||||
def _get_client(self) -> weaviate.Client:
|
||||
request_provider_data = get_request_provider_data()
|
||||
provider_data = self.get_request_provider_data()
|
||||
assert provider_data is not None, "Request provider data must be set"
|
||||
assert isinstance(provider_data, WeaviateRequestProviderData)
|
||||
|
||||
if request_provider_data is not None:
|
||||
assert isinstance(request_provider_data, WeaviateRequestProviderData)
|
||||
key = f"{provider_data.weaviate_cluster_url}::{provider_data.weaviate_api_key}"
|
||||
if key in self.client_cache:
|
||||
return self.client_cache[key]
|
||||
|
||||
# Connect to Weaviate Cloud
|
||||
return weaviate.connect_to_weaviate_cloud(
|
||||
cluster_url = request_provider_data.weaviate_cluster_url,
|
||||
auth_credentials = Auth.api_key(request_provider_data.weaviate_api_key),
|
||||
client = weaviate.connect_to_weaviate_cloud(
|
||||
cluster_url=provider_data.weaviate_cluster_url,
|
||||
auth_credentials=Auth.api_key(provider_data.weaviate_api_key),
|
||||
)
|
||||
self.client_cache[key] = client
|
||||
return client
|
||||
|
||||
async def initialize(self) -> None:
|
||||
try:
|
||||
self.client = self._get_client()
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
for client in self.client_cache.values():
|
||||
client.close()
|
||||
|
||||
async def register_memory_bank(
|
||||
self,
|
||||
memory_bank: MemoryBankDef,
|
||||
) -> None:
|
||||
assert (
|
||||
memory_bank.type == MemoryBankType.vector.value
|
||||
), f"Only vector banks are supported {memory_bank.type}"
|
||||
|
||||
client = await self._get_client()
|
||||
|
||||
# Create collection if it doesn't exist
|
||||
if not self.client.collections.exists(self.config.collection):
|
||||
self.client.collections.create(
|
||||
name = self.config.collection,
|
||||
if not client.collections.exists(memory_bank.identifier):
|
||||
client.collections.create(
|
||||
name=smemory_bank.identifier,
|
||||
vectorizer_config=wvc.config.Configure.Vectorizer.none(),
|
||||
properties=[
|
||||
wvc.config.Property(
|
||||
name="chunk_content",
|
||||
data_type=wvc.config.DataType.TEXT,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
raise RuntimeError("Could not connect to Weaviate server") from e
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
self.client = self._get_client()
|
||||
|
||||
if self.client:
|
||||
self.client.close()
|
||||
|
||||
async def create_memory_bank(
|
||||
self,
|
||||
name: str,
|
||||
config: MemoryBankConfig,
|
||||
url: Optional[URL] = None,
|
||||
) -> MemoryBank:
|
||||
bank_id = str(uuid.uuid4())
|
||||
bank = MemoryBank(
|
||||
bank_id=bank_id,
|
||||
name=name,
|
||||
config=config,
|
||||
url=url,
|
||||
)
|
||||
self.client = self._get_client()
|
||||
|
||||
# Store the bank as a new collection in Weaviate
|
||||
self.client.collections.create(
|
||||
name=bank_id
|
||||
],
|
||||
)
|
||||
|
||||
index = BankWithIndex(
|
||||
bank=bank,
|
||||
index=WeaviateIndex(cleint = self.client, collection = bank_id),
|
||||
bank=memory_bank,
|
||||
index=WeaviateIndex(client=client, collection=memory_bank.identifier),
|
||||
)
|
||||
self.cache[bank_id] = index
|
||||
return bank
|
||||
|
||||
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
|
||||
bank_index = await self._get_and_cache_bank_index(bank_id)
|
||||
if bank_index is None:
|
||||
return None
|
||||
return bank_index.bank
|
||||
|
||||
async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]:
|
||||
|
||||
self.client = self._get_client()
|
||||
|
||||
if bank_id in self.cache:
|
||||
return self.cache[bank_id]
|
||||
|
||||
collections = await self.client.collections.list_all().keys()
|
||||
bank = await self.memory_bank_store.get_memory_bank(bank_id)
|
||||
if not bank:
|
||||
raise ValueError(f"Bank {bank_id} not found")
|
||||
|
||||
client = await self._get_client()
|
||||
collections = await client.collections.list_all().keys()
|
||||
|
||||
for collection in collections:
|
||||
if collection == bank_id:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue