diff --git a/llama_stack/providers/adapters/memory/weaviate/weaviate.py b/llama_stack/providers/adapters/memory/weaviate/weaviate.py index 573802c84..8f5cafdc5 100644 --- a/llama_stack/providers/adapters/memory/weaviate/weaviate.py +++ b/llama_stack/providers/adapters/memory/weaviate/weaviate.py @@ -3,6 +3,7 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import json from typing import Any, Dict, List, Optional @@ -36,7 +37,7 @@ class WeaviateIndex(EmbeddingIndex): data_objects.append( wvc.data.DataObject( properties={ - "chunk_content": chunk, + "chunk_content": chunk.json(), }, vector=embeddings[i].tolist(), ) @@ -44,7 +45,9 @@ class WeaviateIndex(EmbeddingIndex): # Inserting chunks into a prespecified Weaviate collection collection = self.client.collections.get(self.collection_name) - await collection.data.insert_many(data_objects) + + # TODO: make this async friendly + collection.data.insert_many(data_objects) async def query(self, embedding: NDArray, k: int) -> QueryDocumentsResponse: collection = self.client.collections.get(self.collection_name) @@ -52,13 +55,23 @@ class WeaviateIndex(EmbeddingIndex): results = collection.query.near_vector( near_vector=embedding.tolist(), limit=k, - return_meta_data=wvc.query.MetadataQuery(distance=True), + return_metadata=wvc.query.MetadataQuery(distance=True), ) chunks = [] scores = [] for doc in results.objects: - chunk = doc.properties["chunk_content"] + chunk_json = doc.properties["chunk_content"] + try: + chunk_dict = json.loads(chunk_json) + chunk = Chunk(**chunk_dict) + except Exception: + import traceback + + traceback.print_exc() + print(f"Failed to parse document: {chunk_json}") + continue + chunks.append(chunk) scores.append(1.0 / doc.metadata.distance) @@ -102,12 +115,12 @@ class WeaviateMemoryAdapter(Memory, NeedsRequestProviderData): memory_bank.type == MemoryBankType.vector.value ), f"Only vector banks are supported 
{memory_bank.type}" - client = await self._get_client() + client = self._get_client() # Create collection if it doesn't exist if not client.collections.exists(memory_bank.identifier): client.collections.create( - name=smemory_bank.identifier, + name=memory_bank.identifier, vectorizer_config=wvc.config.Configure.Vectorizer.none(), properties=[ wvc.config.Property( @@ -121,7 +134,7 @@ class WeaviateMemoryAdapter(Memory, NeedsRequestProviderData): bank=memory_bank, index=WeaviateIndex(client=client, collection_name=memory_bank.identifier), ) - self.cache[bank_id] = index + self.cache[memory_bank.identifier] = index async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]: if bank_id in self.cache: @@ -131,7 +144,7 @@ class WeaviateMemoryAdapter(Memory, NeedsRequestProviderData): if not bank: raise ValueError(f"Bank {bank_id} not found") - client = await self._get_client() + client = self._get_client() if not client.collections.exists(bank_id): raise ValueError(f"Collection with name `{bank_id}` not found") @@ -146,6 +159,7 @@ class WeaviateMemoryAdapter(Memory, NeedsRequestProviderData): self, bank_id: str, documents: List[MemoryBankDocument], + ttl_seconds: Optional[int] = None, ) -> None: index = await self._get_and_cache_bank_index(bank_id) if not index: diff --git a/llama_stack/providers/tests/inference/test_inference.py b/llama_stack/providers/tests/inference/test_inference.py index de8241b20..38b9ff860 100644 --- a/llama_stack/providers/tests/inference/test_inference.py +++ b/llama_stack/providers/tests/inference/test_inference.py @@ -15,6 +15,23 @@ from llama_stack.apis.inference import * # noqa: F403 from llama_stack.distribution.datatypes import * # noqa: F403 from llama_stack.providers.tests.resolver import resolve_impls_for_test +# How to run this test: +# +# 1. Ensure you have a conda with the right dependencies installed. This is a bit tricky +# since it depends on the provider you are testing. 
On top of that you need +# `pytest` and `pytest-asyncio` installed. +# +# 2. Copy and modify the provider_config_example.yaml depending on the provider you are testing. +# +# 3. Run: +# +# ```bash +# PROVIDER_ID=your_provider_id \ +# PROVIDER_CONFIG=provider_config.yaml \ +# pytest -s llama_stack/providers/tests/inference/test_inference.py \ +# --tb=short --disable-warnings +# ``` + def group_chunks(response): return { diff --git a/llama_stack/providers/tests/memory/test_memory.py b/llama_stack/providers/tests/memory/test_memory.py index 1e9db2161..4351ae699 100644 --- a/llama_stack/providers/tests/memory/test_memory.py +++ b/llama_stack/providers/tests/memory/test_memory.py @@ -11,6 +11,23 @@ from llama_stack.apis.memory import * # noqa: F403 from llama_stack.distribution.datatypes import * # noqa: F403 from llama_stack.providers.tests.resolver import resolve_impls_for_test +# How to run this test: +# +# 1. Ensure you have a conda with the right dependencies installed. This is a bit tricky +# since it depends on the provider you are testing. On top of that you need +# `pytest` and `pytest-asyncio` installed. +# +# 2. Copy and modify the provider_config_example.yaml depending on the provider you are testing. +# +# 3. Run: +# +# ```bash +# PROVIDER_ID=your_provider_id \ +# PROVIDER_CONFIG=provider_config.yaml \ +# pytest -s llama_stack/providers/tests/memory/test_memory.py \ +# --tb=short --disable-warnings +# ``` + @pytest_asyncio.fixture(scope="session") async def memory_impl():