This commit is contained in:
Francisco Javier Arceo 2025-12-03 01:04:06 +00:00 committed by GitHub
commit 5bd80a693c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
18 changed files with 7531 additions and 14 deletions

View file

@ -1230,3 +1230,121 @@ async def test_embedding_config_required_model_missing(vector_io_adapter):
with pytest.raises(ValueError, match="embedding_model is required"):
await vector_io_adapter.openai_create_vector_store(params)
async def test_query_expansion_functionality(vector_io_adapter):
"""Test query expansion with simplified global configuration approach."""
from unittest.mock import MagicMock
from llama_stack.core.datatypes import QualifiedModel, RewriteQueryParams
from llama_stack.providers.utils.memory.constants import DEFAULT_QUERY_REWRITE_PROMPT
from llama_stack.providers.utils.memory.rewrite_query_config import set_default_rewrite_query_config
from llama_stack.providers.utils.memory.vector_store import VectorStoreWithIndex
from llama_stack_api import QueryChunksResponse
# Mock a simple vector store and index
mock_vector_store = MagicMock()
mock_vector_store.embedding_model = "test/embedding"
mock_inference_api = MagicMock()
mock_index = MagicMock()
# Create VectorStoreWithIndex with simplified constructor
vector_store_with_index = VectorStoreWithIndex(
vector_store=mock_vector_store,
index=mock_index,
inference_api=mock_inference_api,
)
# Mock the query_vector method to return a simple response
mock_response = QueryChunksResponse(chunks=[], scores=[])
mock_index.query_vector = AsyncMock(return_value=mock_response)
# Mock embeddings generation
mock_inference_api.openai_embeddings = AsyncMock(
return_value=MagicMock(data=[MagicMock(embedding=[0.1, 0.2, 0.3])])
)
# Test 1: Query expansion with default prompt (no custom prompt configured)
mock_vector_stores_config = MagicMock()
mock_vector_stores_config.rewrite_query_params = RewriteQueryParams(
model=QualifiedModel(provider_id="test", model_id="llama"), max_tokens=100, temperature=0.3
)
# Set global config
set_default_rewrite_query_config(mock_vector_stores_config)
# Mock chat completion for query rewriting
mock_inference_api.openai_chat_completion = AsyncMock(
return_value=MagicMock(choices=[MagicMock(message=MagicMock(content="expanded test query"))])
)
params = {"rewrite_query": True, "max_chunks": 5}
result = await vector_store_with_index.query_chunks("test query", params)
# Verify chat completion was called for query rewriting
mock_inference_api.openai_chat_completion.assert_called_once()
chat_call_args = mock_inference_api.openai_chat_completion.call_args[0][0]
assert chat_call_args.model == "test/llama"
# Verify default prompt is used (contains our built-in prompt text)
prompt_text = chat_call_args.messages[0].content
expected_prompt = DEFAULT_QUERY_REWRITE_PROMPT.format(query="test query")
assert prompt_text == expected_prompt
# Verify default inference parameters are used
assert chat_call_args.max_tokens == 100 # Default value
assert chat_call_args.temperature == 0.3 # Default value
# Verify the rest of the flow proceeded normally
mock_inference_api.openai_embeddings.assert_called_once()
mock_index.query_vector.assert_called_once()
assert result == mock_response
# Test 1b: Query expansion with custom prompt override and inference parameters
mock_inference_api.reset_mock()
mock_index.reset_mock()
mock_vector_stores_config.rewrite_query_params = RewriteQueryParams(
model=QualifiedModel(provider_id="test", model_id="llama"),
prompt="Custom prompt for rewriting: {query}",
max_tokens=150,
temperature=0.7,
)
set_default_rewrite_query_config(mock_vector_stores_config)
result = await vector_store_with_index.query_chunks("test query", params)
# Verify custom prompt and parameters are used
mock_inference_api.openai_chat_completion.assert_called_once()
chat_call_args = mock_inference_api.openai_chat_completion.call_args[0][0]
prompt_text = chat_call_args.messages[0].content
assert prompt_text == "Custom prompt for rewriting: test query"
assert "Expand this query with relevant synonyms" not in prompt_text # Default not used
# Verify custom inference parameters
assert chat_call_args.max_tokens == 150
assert chat_call_args.temperature == 0.7
# Test 2: Error when query rewriting is requested but no global model is configured
mock_inference_api.reset_mock()
mock_index.reset_mock()
# Clear global config
set_default_rewrite_query_config(None)
params = {"rewrite_query": True, "max_chunks": 5}
with pytest.raises(ValueError, match="Query rewriting requested but not configured"):
await vector_store_with_index.query_chunks("test query", params)
# Test 3: Normal behavior without rewrite_query parameter
mock_inference_api.reset_mock()
mock_index.reset_mock()
params_no_rewrite = {"max_chunks": 5}
result3 = await vector_store_with_index.query_chunks("test query", params_no_rewrite)
# Neither chat completion nor query rewriting should be called
mock_inference_api.openai_chat_completion.assert_not_called()
mock_inference_api.openai_embeddings.assert_called_once()
mock_index.query_vector.assert_called_once()
assert result3 == mock_response