mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-10 05:24:39 +00:00
fix(vector-io): unify score calculation to use cosine and normalize to [0,1]
This commit is contained in:
parent
9618adba89
commit
a0e0c7030b
9 changed files with 166 additions and 42 deletions
|
@ -222,3 +222,63 @@ def test_query_returns_valid_object_when_identical_to_embedding_in_vdb(
|
|||
assert len(response.chunks) > 0
|
||||
assert response.chunks[0].metadata["document_id"] == "doc1"
|
||||
assert response.chunks[0].metadata["source"] == "precomputed"
|
||||
|
||||
|
||||
def test_vector_similarity_scores_are_normalized(
|
||||
client_with_empty_registry, embedding_model_id, embedding_dimension, sample_chunks
|
||||
):
|
||||
"""Test that vector similarity scores are properly normalized to [0,1] range for all vector providers."""
|
||||
vector_db_name = "test_score_normalization_db"
|
||||
register_response = client_with_empty_registry.vector_dbs.register(
|
||||
vector_db_id=vector_db_name,
|
||||
embedding_model=embedding_model_id,
|
||||
embedding_dimension=embedding_dimension,
|
||||
)
|
||||
|
||||
actual_vector_db_id = register_response.identifier
|
||||
|
||||
# Insert sample chunks
|
||||
client_with_empty_registry.vector_io.insert(
|
||||
vector_db_id=actual_vector_db_id,
|
||||
chunks=sample_chunks,
|
||||
)
|
||||
|
||||
# Test various queries to ensure score normalization across different similarity levels
|
||||
test_queries = [
|
||||
# High similarity query that should match Python doc chunk
|
||||
"Python programming language with readable code",
|
||||
# Medium similarity query
|
||||
"artificial intelligence and machine learning systems",
|
||||
# Lower similarity query
|
||||
"What is the capital of France?",
|
||||
# High similarity query that should match neural networks chunk
|
||||
"biological neural networks and artificial neurons",
|
||||
# Very low similarity query to test edge case normalization
|
||||
"xyzabc random nonsense gibberish qwerty asdfgh",
|
||||
]
|
||||
|
||||
for query in test_queries:
|
||||
response = client_with_empty_registry.vector_io.query(
|
||||
vector_db_id=actual_vector_db_id,
|
||||
query=query,
|
||||
)
|
||||
|
||||
# Verify response structure
|
||||
assert response is not None, f"Query '{query}' returned None response"
|
||||
assert len(response.chunks) > 0, f"Query '{query}' returned no chunks"
|
||||
assert len(response.scores) > 0, f"Query '{query}' returned no scores"
|
||||
assert len(response.chunks) == len(response.scores), "Mismatch between chunks and scores count"
|
||||
|
||||
# Verify all scores are normalized to [0,1] range
|
||||
for i, score in enumerate(response.scores):
|
||||
assert isinstance(score, (int | float)), f"Score at index {i} is not numeric: {type(score)}"
|
||||
assert 0.0 <= score <= 1.0, (
|
||||
f"Score at index {i} is not normalized: {score} (should be in [0,1] range) for query '{query}'"
|
||||
)
|
||||
|
||||
# Verify scores are in descending order (most similar first)
|
||||
for i in range(1, len(response.scores)):
|
||||
assert response.scores[i - 1] >= response.scores[i], (
|
||||
f"Scores not in descending order at indices {i - 1} and {i}: "
|
||||
f"{response.scores[i - 1]} >= {response.scores[i]} for query '{query}'"
|
||||
)
|
||||
|
|
|
@ -112,14 +112,16 @@ async def faiss_adapter(faiss_config, mock_inference_api, mock_files_api) -> Fai
|
|||
yield adapter
|
||||
|
||||
|
||||
async def test_faiss_query_vector_returns_infinity_when_query_and_embedding_are_identical(
|
||||
async def test_faiss_query_vector_returns_perfect_score_when_query_and_embedding_are_identical(
|
||||
faiss_index, sample_chunks, sample_embeddings, embedding_dimension
|
||||
):
|
||||
await faiss_index.add_chunks(sample_chunks, sample_embeddings)
|
||||
query_embedding = np.random.rand(embedding_dimension).astype(np.float32)
|
||||
|
||||
with patch.object(faiss_index.index, "search") as mock_search:
|
||||
mock_search.return_value = (np.array([[0.0, 0.1]]), np.array([[0, 1]]))
|
||||
# IndexFlatIP with normalized vectors returns cosine similarity scores [-1,1]
|
||||
# These will be normalized to [0,1] using (score + 1.0) / 2.0
|
||||
mock_search.return_value = (np.array([[1.0, 0.6]]), np.array([[0, 1]]))
|
||||
|
||||
response = await faiss_index.query_vector(embedding=query_embedding, k=2, score_threshold=0.0)
|
||||
|
||||
|
@ -127,8 +129,8 @@ async def test_faiss_query_vector_returns_infinity_when_query_and_embedding_are_
|
|||
assert len(response.chunks) == 2
|
||||
assert len(response.scores) == 2
|
||||
|
||||
assert response.scores[0] == float("inf") # infinity (1.0 / 0.0)
|
||||
assert response.scores[1] == 10.0 # (1.0 / 0.1 = 10.0)
|
||||
assert response.scores[0] == 1.0 # (1.0 + 1.0) / 2.0 = 1.0 (perfect similarity)
|
||||
assert response.scores[1] == 0.8 # (0.6 + 1.0) / 2.0 = 0.8 (high similarity)
|
||||
|
||||
assert response.chunks[0] == sample_chunks[0]
|
||||
assert response.chunks[1] == sample_chunks[1]
|
||||
|
@ -141,7 +143,7 @@ async def test_health_success():
|
|||
inference_api = MagicMock()
|
||||
files_api = MagicMock()
|
||||
|
||||
with patch("llama_stack.providers.inline.vector_io.faiss.faiss.faiss.IndexFlatL2") as mock_index_flat:
|
||||
with patch("llama_stack.providers.inline.vector_io.faiss.faiss.faiss.IndexFlatIP") as mock_index_flat:
|
||||
mock_index_flat.return_value = MagicMock()
|
||||
adapter = FaissVectorIOAdapter(config=config, inference_api=inference_api, files_api=files_api)
|
||||
|
||||
|
@ -153,7 +155,7 @@ async def test_health_success():
|
|||
assert response["status"] == HealthStatus.OK
|
||||
assert "message" not in response
|
||||
|
||||
# Verifying that IndexFlatL2 was called with the correct dimension
|
||||
# Verifying that IndexFlatIP was called with the correct dimension
|
||||
mock_index_flat.assert_called_once_with(128) # VECTOR_DIMENSION is 128
|
||||
|
||||
|
||||
|
@ -164,7 +166,7 @@ async def test_health_failure():
|
|||
inference_api = MagicMock()
|
||||
files_api = MagicMock()
|
||||
|
||||
with patch("llama_stack.providers.inline.vector_io.faiss.faiss.faiss.IndexFlatL2") as mock_index_flat:
|
||||
with patch("llama_stack.providers.inline.vector_io.faiss.faiss.faiss.IndexFlatIP") as mock_index_flat:
|
||||
mock_index_flat.side_effect = Exception("Test error")
|
||||
|
||||
adapter = FaissVectorIOAdapter(config=config, inference_api=inference_api, files_api=files_api)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue