Added in registry and tests passed

This commit is contained in:
Sarthak Deshpande 2024-10-23 23:45:01 +05:30
parent c2d74188ee
commit 07e9da19b3
5 changed files with 42 additions and 25 deletions

View file

@ -20,10 +20,12 @@ providers:
config:
host: localhost
port: 6333
- provider_id: test-pinecone
provider_type: remote::pinecone
config: {}
# if a provider needs private keys from the client, they use the
# "get_request_provider_data" function (see distribution/request_headers.py)
# this is a place to provide such data.
provider_data:
"test-weaviate":
weaviate_api_key: 0xdeadbeefputrealapikeyhere
weaviate_cluster_url: http://foobarbaz
"test-pinecone":
pinecone_api_key:

View file

@ -69,7 +69,7 @@ def sample_documents():
async def register_memory_bank(banks_impl: MemoryBanks):
bank = VectorMemoryBankDef(
identifier="test_bank",
identifier="test-bank",
embedding_model="all-MiniLM-L6-v2",
chunk_size_in_tokens=512,
overlap_size_in_tokens=64,
@ -95,7 +95,7 @@ async def test_banks_register(memory_settings):
# but so far we don't have an unregister API unfortunately, so be careful
banks_impl = memory_settings["memory_banks_impl"]
bank = VectorMemoryBankDef(
identifier="test_bank_no_provider",
identifier="test-bank-no-provider",
embedding_model="all-MiniLM-L6-v2",
chunk_size_in_tokens=512,
overlap_size_in_tokens=64,
@ -119,33 +119,33 @@ async def test_query_documents(memory_settings, sample_documents):
banks_impl = memory_settings["memory_banks_impl"]
with pytest.raises(ValueError):
await memory_impl.insert_documents("test_bank", sample_documents)
await memory_impl.insert_documents("test-bank", sample_documents)
await register_memory_bank(banks_impl)
await memory_impl.insert_documents("test_bank", sample_documents)
await memory_impl.insert_documents("test-bank", sample_documents)
query1 = "programming language"
response1 = await memory_impl.query_documents("test_bank", query1)
response1 = await memory_impl.query_documents("test-bank", query1)
assert_valid_response(response1)
assert any("Python" in chunk.content for chunk in response1.chunks)
# Test case 3: Query with semantic similarity
query3 = "AI and brain-inspired computing"
response3 = await memory_impl.query_documents("test_bank", query3)
response3 = await memory_impl.query_documents("test-bank", query3)
assert_valid_response(response3)
assert any("neural networks" in chunk.content.lower() for chunk in response3.chunks)
# Test case 4: Query with limit on number of results
query4 = "computer"
params4 = {"max_chunks": 2}
response4 = await memory_impl.query_documents("test_bank", query4, params4)
response4 = await memory_impl.query_documents("test-bank", query4, params4)
assert_valid_response(response4)
assert len(response4.chunks) <= 2
# Test case 5: Query with threshold on similarity score
query5 = "quantum computing" # Not directly related to any document
params5 = {"score_threshold": 0.2}
response5 = await memory_impl.query_documents("test_bank", query5, params5)
response5 = await memory_impl.query_documents("test-bank", query5, params5)
assert_valid_response(response5)
print("The scores are:", response5.scores)
assert all(score >= 0.2 for score in response5.scores)