diff --git a/llama_stack/distribution/datatypes.py b/llama_stack/distribution/datatypes.py index f08eec462..c987d4c87 100644 --- a/llama_stack/distribution/datatypes.py +++ b/llama_stack/distribution/datatypes.py @@ -39,14 +39,6 @@ RoutedProtocol = Union[ ] -class ModelRegistry(Protocol): - def get_model(self, identifier: str) -> ModelDef: ... - - -class MemoryBankRegistry(Protocol): - def get_memory_bank(self, identifier: str) -> MemoryBankDef: ... - - # Example: /inference, /safety class AutoRoutedProviderSpec(ProviderSpec): provider_type: str = "router" diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py index 0fc9bd72e..2d3679177 100644 --- a/llama_stack/distribution/resolver.py +++ b/llama_stack/distribution/resolver.py @@ -64,7 +64,6 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An + [x.value for x in routing_table_apis] + [x.value for x in router_apis] ) - print(f"{apis_to_serve=}") for info in builtin_automatically_routed_apis(): if info.router_api.value not in apis_to_serve: diff --git a/llama_stack/providers/adapters/memory/weaviate/config.py b/llama_stack/providers/adapters/memory/weaviate/config.py index db73604d2..d0811acb4 100644 --- a/llama_stack/providers/adapters/memory/weaviate/config.py +++ b/llama_stack/providers/adapters/memory/weaviate/config.py @@ -4,15 +4,13 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from llama_models.schema_utils import json_schema_type -from pydantic import BaseModel, Field +from pydantic import BaseModel + class WeaviateRequestProviderData(BaseModel): - # if there _is_ provider data, it must specify the API KEY - # if you want it to be optional, use Optional[str] weaviate_api_key: str weaviate_cluster_url: str -@json_schema_type + class WeaviateConfig(BaseModel): - collection: str = Field(default="MemoryBank") + pass diff --git a/llama_stack/providers/adapters/memory/weaviate/weaviate.py b/llama_stack/providers/adapters/memory/weaviate/weaviate.py index abfe27150..9f8e93434 100644 --- a/llama_stack/providers/adapters/memory/weaviate/weaviate.py +++ b/llama_stack/providers/adapters/memory/weaviate/weaviate.py @@ -1,14 +1,19 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + import json -import uuid -from typing import List, Optional, Dict, Any -from numpy.typing import NDArray +from typing import Any, Dict, List, Optional import weaviate import weaviate.classes as wvc +from numpy.typing import NDArray from weaviate.classes.init import Auth -from llama_stack.apis.memory import * -from llama_stack.distribution.request_headers import get_request_provider_data +from llama_stack.apis.memory import * # noqa: F403 +from llama_stack.distribution.request_headers import NeedsRequestProviderData from llama_stack.providers.utils.memory.vector_store import ( BankWithIndex, EmbeddingIndex, @@ -16,40 +21,43 @@ from llama_stack.providers.utils.memory.vector_store import ( from .config import WeaviateConfig, WeaviateRequestProviderData + class WeaviateIndex(EmbeddingIndex): def __init__(self, client: weaviate.Client, collection: str): self.client = client self.collection = collection async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray): - assert len(chunks) == len(embeddings), f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}" + assert len(chunks) == len( + embeddings + ), f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}" data_objects = [] for i, chunk in enumerate(chunks): - - data_objects.append(wvc.data.DataObject( - properties={ - "chunk_content": chunk, - }, - vector = embeddings[i].tolist() - )) + data_objects.append( + wvc.data.DataObject( + properties={ + "chunk_content": chunk, + }, + vector=embeddings[i].tolist(), + ) + ) # Inserting chunks into a prespecified Weaviate collection assert self.collection is not None, "Collection name must be specified" my_collection = self.client.collections.get(self.collection) - - await my_collection.data.insert_many(data_objects) + await my_collection.data.insert_many(data_objects) async def query(self, embedding: NDArray, k: int) -> QueryDocumentsResponse: assert self.collection is not None, "Collection name must be specified" my_collection = self.client.collections.get(self.collection) - + results = my_collection.query.near_vector( - near_vector = embedding.tolist(), - limit = k, - return_meta_data = wvc.query.MetadataQuery(distance=True) + near_vector=embedding.tolist(), + limit=k, + return_meta_data=wvc.query.MetadataQuery(distance=True), ) chunks = [] @@ -59,102 +67,84 @@ class WeaviateIndex(EmbeddingIndex): chunk = doc.properties["chunk_content"] chunks.append(chunk) scores.append(1.0 / doc.metadata.distance) - + except Exception as e: import traceback + traceback.print_exc() print(f"Failed to parse document: {e}") return QueryDocumentsResponse(chunks=chunks, scores=scores) -class WeaviateMemoryAdapter(Memory): +class WeaviateMemoryAdapter(Memory, NeedsRequestProviderData): def __init__(self, config: WeaviateConfig) -> None: self.config = config - self.client = None + self.client_cache = {} self.cache = {} def _get_client(self) -> weaviate.Client: - request_provider_data = get_request_provider_data() - - if request_provider_data is not None: - assert isinstance(request_provider_data, WeaviateRequestProviderData) - - # Connect to Weaviate Cloud - return weaviate.connect_to_weaviate_cloud( - cluster_url = request_provider_data.weaviate_cluster_url, - auth_credentials = Auth.api_key(request_provider_data.weaviate_api_key), - ) + provider_data = self.get_request_provider_data() + assert provider_data is not None, "Request provider data must be set" + assert isinstance(provider_data, WeaviateRequestProviderData) + + key = f"{provider_data.weaviate_cluster_url}::{provider_data.weaviate_api_key}" + if key in self.client_cache: + return self.client_cache[key] + + client = weaviate.connect_to_weaviate_cloud( + cluster_url=provider_data.weaviate_cluster_url, + auth_credentials=Auth.api_key(provider_data.weaviate_api_key), + ) + self.client_cache[key] = client + return client async def initialize(self) -> None: - try: - self.client = self._get_client() - - # Create collection if it doesn't exist - if not self.client.collections.exists(self.config.collection): - self.client.collections.create( - name = self.config.collection, - vectorizer_config = wvc.config.Configure.Vectorizer.none(), - properties=[ - wvc.config.Property( - name="chunk_content", - data_type=wvc.config.DataType.TEXT, - ), - ] - ) - - except Exception as e: - import traceback - traceback.print_exc() - raise RuntimeError("Could not connect to Weaviate server") from e + pass async def shutdown(self) -> None: - self.client = self._get_client() + for client in self.client_cache.values(): + client.close() - if self.client: - self.client.close() - - async def create_memory_bank( + async def register_memory_bank( self, - name: str, - config: MemoryBankConfig, - url: Optional[URL] = None, - ) -> MemoryBank: - bank_id = str(uuid.uuid4()) - bank = MemoryBank( - bank_id=bank_id, - name=name, - config=config, - url=url, - ) - self.client = self._get_client() - - # Store the bank as a new collection in Weaviate - self.client.collections.create( - name=bank_id - ) + memory_bank: MemoryBankDef, + ) -> None: + assert ( + memory_bank.type == MemoryBankType.vector.value + ), f"Only vector banks are supported {memory_bank.type}" + + client = await self._get_client() + + # Create collection if it doesn't exist + if not client.collections.exists(memory_bank.identifier): + client.collections.create( + name=smemory_bank.identifier, + vectorizer_config=wvc.config.Configure.Vectorizer.none(), + properties=[ + wvc.config.Property( + name="chunk_content", + data_type=wvc.config.DataType.TEXT, + ), + ], + ) index = BankWithIndex( - bank=bank, - index=WeaviateIndex(cleint = self.client, collection = bank_id), + bank=memory_bank, + index=WeaviateIndex(client=client, collection=memory_bank.identifier), ) self.cache[bank_id] = index - return bank - - async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]: - bank_index = await self._get_and_cache_bank_index(bank_id) - if bank_index is None: - return None - return bank_index.bank async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]: - - self.client = self._get_client() - if bank_id in self.cache: return self.cache[bank_id] - collections = await self.client.collections.list_all().keys() + bank = await self.memory_bank_store.get_memory_bank(bank_id) + if not bank: + raise ValueError(f"Bank {bank_id} not found") + + client = await self._get_client() + collections = await client.collections.list_all().keys() for collection in collections: if collection == bank_id: @@ -189,4 +179,4 @@ class WeaviateMemoryAdapter(Memory): if not index: raise ValueError(f"Bank {bank_id} not found") - return await index.query_documents(query, params) \ No newline at end of file + return await index.query_documents(query, params)