llama-stack/llama_stack/providers/tests/memory/test_memory.py
Ashwin Bharambe d9d271a684
Allow specifying resources in StackRunConfig (#425)
# What does this PR do? 

This PR brings back the facility to not force registration of resources
onto the user. This is not just annoying but actually not feasible
sometimes. For example, you may have a Stack which boots up with private
providers for inference for models A and B. There is no way for the user
to actually know which model is being served by these providers now (to
be able to register it.)

How will this avoid the users needing to do registration? In a follow-up
diff, I will make sure I update the sample run.yaml files so they list
the models served by the distributions explicitly. So when users do
`llama stack build --template <...>` and run it, their distributions
come up with the right set of models they expect.

For self-hosted distributions, it also gives us a place to explicitly
list the models that need to be served to make the "complete" stack
(including, e.g., safety).

## Test Plan

Started ollama locally with two lightweight models: Llama3.2-3B-Instruct
and Llama-Guard-3-1B.

Updated all the tests including agents. Here are the tests I ran so far:

```bash
pytest -s -v -m "fireworks and llama_3b" test_text_inference.py::TestInference \
  --env FIREWORKS_API_KEY=...

pytest -s -v -m "ollama and llama_3b" test_text_inference.py::TestInference 

pytest -s -v -m ollama test_safety.py

pytest -s -v -m faiss test_memory.py

pytest -s -v -m ollama  test_agents.py \
  --inference-model=Llama3.2-3B-Instruct --safety-model=Llama-Guard-3-1B
```

These test runs surfaced a few pre-existing bugs here and there, which are fixed here.
2024-11-12 10:58:49 -08:00

144 lines
5.4 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
from llama_stack.apis.memory import * # noqa: F403
from llama_stack.distribution.datatypes import * # noqa: F403
from llama_stack.apis.memory_banks.memory_banks import VectorMemoryBankParams
# How to run this test:
#
# pytest llama_stack/providers/tests/memory/test_memory.py
# -m "meta_reference"
# -v -s --tb=short --disable-warnings
@pytest.fixture
def sample_documents():
    """Provide four short documents used to populate a memory bank.

    Each document carries a ``category`` and a ``difficulty`` metadata entry.
    The texts cover programming, machine learning, data structures and neural
    networks so that queries can be matched against distinct topics.
    """
    # (document_id, content, category, difficulty)
    corpus = (
        (
            "doc1",
            "Python is a high-level programming language.",
            "programming",
            "beginner",
        ),
        (
            "doc2",
            "Machine learning is a subset of artificial intelligence.",
            "AI",
            "advanced",
        ),
        (
            "doc3",
            "Data structures are fundamental to computer science.",
            "computer science",
            "intermediate",
        ),
        (
            "doc4",
            "Neural networks are inspired by biological neural networks.",
            "AI",
            "advanced",
        ),
    )
    return [
        MemoryBankDocument(
            document_id=doc_id,
            content=text,
            metadata={"category": category, "difficulty": difficulty},
        )
        for doc_id, text, category, difficulty in corpus
    ]
async def register_memory_bank(banks_impl: MemoryBanks):
    """Register the vector memory bank ``test_bank`` on ``banks_impl``.

    The bank is configured with the ``all-MiniLM-L6-v2`` embedding model,
    512-token chunks and a 64-token overlap.

    Returns:
        Whatever ``banks_impl.register_memory_bank`` returns.
    """
    bank_params = VectorMemoryBankParams(
        embedding_model="all-MiniLM-L6-v2",
        chunk_size_in_tokens=512,
        overlap_size_in_tokens=64,
    )
    registration = await banks_impl.register_memory_bank(
        memory_bank_id="test_bank",
        params=bank_params,
    )
    return registration
class TestMemory:
    """Memory bank registration and document query tests.

    Every test receives the ``memory_stack`` fixture, which is unpacked as a
    ``(memory_impl, banks_impl)`` pair.
    """

    @pytest.mark.asyncio
    async def test_banks_list(self, memory_stack):
        """Listing banks returns an empty list before anything is registered."""
        # NOTE: this only holds when starting from a clean state. There is no
        # unregister API so far, so be careful about leftover banks.
        _memory_impl, bank_registry = memory_stack

        listed = await bank_registry.list_memory_banks()

        assert isinstance(listed, list)
        assert len(listed) == 0

    @pytest.mark.asyncio
    async def test_banks_register(self, memory_stack):
        """Registering the same bank id twice leaves exactly one bank listed."""
        # NOTE: this only holds when starting from a clean state. There is no
        # unregister API so far, so be careful about leftover banks.
        _memory_impl, bank_registry = memory_stack
        bank_id = "test_bank_no_provider"

        # Pass 1 registers the bank; pass 2 registers the same id again.
        # Neither call is expected to raise here, and after each one the
        # listing must contain exactly one bank (i.e. no duplicate entry).
        for _attempt in range(2):
            await bank_registry.register_memory_bank(
                memory_bank_id=bank_id,
                params=VectorMemoryBankParams(
                    embedding_model="all-MiniLM-L6-v2",
                    chunk_size_in_tokens=512,
                    overlap_size_in_tokens=64,
                ),
            )
            listed = await bank_registry.list_memory_banks()
            assert isinstance(listed, list)
            assert len(listed) == 1

    @pytest.mark.asyncio
    async def test_query_documents(self, memory_stack, sample_documents):
        """Insert the sample documents and run several queries against them."""
        memory_impl, bank_registry = memory_stack
        # Must match the id that register_memory_bank() registers.
        bank_id = "test_bank"

        # Inserting into a bank that has not been registered yet must raise.
        with pytest.raises(ValueError):
            await memory_impl.insert_documents(bank_id, sample_documents)

        await register_memory_bank(bank_registry)
        await memory_impl.insert_documents(bank_id, sample_documents)

        # Keyword-style query: the Python document should be among the hits.
        keyword_hits = await memory_impl.query_documents(
            bank_id, "programming language"
        )
        assert_valid_response(keyword_hits)
        assert any("Python" in chunk.content for chunk in keyword_hits.chunks)

        # Semantic-similarity query: wording differs from the stored text.
        semantic_hits = await memory_impl.query_documents(
            bank_id, "AI and brain-inspired computing"
        )
        assert_valid_response(semantic_hits)
        assert any(
            "neural networks" in chunk.content.lower()
            for chunk in semantic_hits.chunks
        )

        # Query with a limit on the number of returned chunks.
        capped_hits = await memory_impl.query_documents(
            bank_id, "computer", {"max_chunks": 2}
        )
        assert_valid_response(capped_hits)
        assert len(capped_hits.chunks) <= 2

        # Query with a similarity-score threshold; the query text is not
        # directly related to any of the sample documents.
        thresholded_hits = await memory_impl.query_documents(
            bank_id, "quantum computing", {"score_threshold": 0.2}
        )
        assert_valid_response(thresholded_hits)
        print("The scores are:", thresholded_hits.scores)
        assert all(score >= 0.2 for score in thresholded_hits.scores)
def assert_valid_response(response: QueryDocumentsResponse):
    """Assert that ``response`` is a well-formed, non-empty query result.

    Checks that ``response`` is a ``QueryDocumentsResponse``, that it holds at
    least one chunk and one score, that chunks and scores have equal length,
    and that every chunk has string content and a non-``None`` document id.
    """
    assert isinstance(response, QueryDocumentsResponse)

    returned_chunks = response.chunks
    assert len(returned_chunks) > 0

    returned_scores = response.scores
    assert len(returned_scores) > 0

    # Each chunk is expected to have a matching score.
    assert len(returned_chunks) == len(returned_scores)

    for item in returned_chunks:
        assert isinstance(item.content, str)
        assert item.document_id is not None