# Patch summary (reviewer annotation; text before the first "diff --git" header is not part of any hunk).
#
# NOTE: this unified diff has had its line breaks collapsed, so the indentation and the
# whitespace-only context lines of the original hunks are not recoverable from this text.
#
# llama_stack/providers/adapters/memory/chroma/chroma.py
#   - In ChromaMemoryAdapter, the call `self.client.create_collection(name=memory_bank.identifier)`
#     becomes `self.client.get_or_create_collection(name=memory_bank.identifier)`.
#   - NOTE(review): presumably this lets an already-existing collection be reused when a bank is
#     registered again -- confirm against the chromadb client documentation.
#
# llama_stack/providers/tests/memory/test_memory.py
#   - The `sample_document` fixture (one MemoryBankDocument "doc1" with mime_type "text/plain") is
#     replaced by `sample_documents`, which returns a list of four MemoryBankDocument objects
#     ("doc1".."doc4") carrying "category" / "difficulty" metadata and no mime_type.
#   - `test_query_documents` now takes `sample_documents`, still expects ValueError when inserting
#     before the bank is registered, then inserts the list and issues four queries:
#       * "programming language"            -> some chunk content contains "Python"
#       * "AI and brain-inspired computing" -> some chunk content contains "neural networks" (lower-cased)
#       * "computer" with {"max_chunks": 2} -> at most 2 chunks returned
#       * "quantum computing" with {"score_threshold": 0.5} -> every score >= 0.5
#   - New helper `assert_valid_response` takes over the four pre-existing response assertions
#     (type, non-empty chunks, non-empty scores, equal lengths) and adds per-chunk checks that
#     `content` is a str and `document_id` is not None.
#   - NOTE(review): the inline query comments are numbered 3, 4 and 5; the first query has no label
#     and there is no "Test case 2" -- confirm whether a case was dropped or renumber.
#   - NOTE(review): `assert_valid_response` requires at least one chunk, yet it is also applied to
#     the score_threshold query whose own comment says it is not directly related to any document;
#     confirm every provider returns a chunk scoring >= 0.5 there, otherwise the test fails.
diff --git a/llama_stack/providers/adapters/memory/chroma/chroma.py b/llama_stack/providers/adapters/memory/chroma/chroma.py index f720159a5..f8af9ac5c 100644 --- a/llama_stack/providers/adapters/memory/chroma/chroma.py +++ b/llama_stack/providers/adapters/memory/chroma/chroma.py @@ -99,7 +99,7 @@ class ChromaMemoryAdapter(Memory): memory_bank.type == MemoryBankType.vector.value ), f"Only vector banks are supported {memory_bank.type}" - collection = await self.client.create_collection( + collection = await self.client.get_or_create_collection( name=memory_bank.identifier, ) bank_index = BankWithIndex( diff --git a/llama_stack/providers/tests/memory/test_memory.py b/llama_stack/providers/tests/memory/test_memory.py index 4f6dadb14..1e9db2161 100644 --- a/llama_stack/providers/tests/memory/test_memory.py +++ b/llama_stack/providers/tests/memory/test_memory.py @@ -22,13 +22,29 @@ async def memory_impl(): @pytest.fixture -def sample_document(): - return MemoryBankDocument( - document_id="doc1", - content="This is a sample document for testing.", - mime_type="text/plain", - metadata={"author": "Test Author"}, - ) +def sample_documents(): + return [ + MemoryBankDocument( + document_id="doc1", + content="Python is a high-level programming language.", + metadata={"category": "programming", "difficulty": "beginner"}, + ), + MemoryBankDocument( + document_id="doc2", + content="Machine learning is a subset of artificial intelligence.", + metadata={"category": "AI", "difficulty": "advanced"}, + ), + MemoryBankDocument( + document_id="doc3", + content="Data structures are fundamental to computer science.", + metadata={"category": "computer science", "difficulty": "intermediate"}, + ), + MemoryBankDocument( + document_id="doc4", + content="Neural networks are inspired by biological neural networks.", + metadata={"category": "AI", "difficulty": "advanced"}, + ), + ] async def register_memory_bank(memory_impl: Memory): @@ -44,17 +60,44 @@ async def 
register_memory_bank(memory_impl: Memory): @pytest.mark.asyncio -async def test_query_documents(memory_impl, sample_document): +async def test_query_documents(memory_impl, sample_documents): with pytest.raises(ValueError): - await memory_impl.insert_documents("test_bank", [sample_document]) + await memory_impl.insert_documents("test_bank", sample_documents) await register_memory_bank(memory_impl) - await memory_impl.insert_documents("test_bank", [sample_document]) + await memory_impl.insert_documents("test_bank", sample_documents) - query = ["sample ", "document"] - response = await memory_impl.query_documents("test_bank", query) + query1 = "programming language" + response1 = await memory_impl.query_documents("test_bank", query1) + assert_valid_response(response1) + assert any("Python" in chunk.content for chunk in response1.chunks) + # Test case 3: Query with semantic similarity + query3 = "AI and brain-inspired computing" + response3 = await memory_impl.query_documents("test_bank", query3) + assert_valid_response(response3) + assert any("neural networks" in chunk.content.lower() for chunk in response3.chunks) + + # Test case 4: Query with limit on number of results + query4 = "computer" + params4 = {"max_chunks": 2} + response4 = await memory_impl.query_documents("test_bank", query4, params4) + assert_valid_response(response4) + assert len(response4.chunks) <= 2 + + # Test case 5: Query with threshold on similarity score + query5 = "quantum computing" # Not directly related to any document + params5 = {"score_threshold": 0.5} + response5 = await memory_impl.query_documents("test_bank", query5, params5) + assert_valid_response(response5) + assert all(score >= 0.5 for score in response5.scores) + + +def assert_valid_response(response: QueryDocumentsResponse): assert isinstance(response, QueryDocumentsResponse) assert len(response.chunks) > 0 assert len(response.scores) > 0 assert len(response.chunks) == len(response.scores) + for chunk in response.chunks: + assert 
isinstance(chunk.content, str) + assert chunk.document_id is not None