renaming to query_rewrite, consolidating, and cleaning up validation

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
This commit is contained in:
Francisco Javier Arceo 2025-11-21 23:38:13 -05:00
parent d887f1f8bb
commit 31e28b6d17
12 changed files with 138 additions and 180 deletions

View file

@ -1236,9 +1236,9 @@ async def test_query_expansion_functionality(vector_io_adapter):
"""Test query expansion with simplified global configuration approach."""
from unittest.mock import MagicMock
from llama_stack.core.datatypes import QualifiedModel
from llama_stack.core.datatypes import QualifiedModel, RewriteQueryParams
from llama_stack.providers.utils.memory.constants import DEFAULT_QUERY_EXPANSION_PROMPT
from llama_stack.providers.utils.memory.query_expansion_config import set_default_query_expansion_config
from llama_stack.providers.utils.memory.rewrite_query_config import set_default_rewrite_query_config
from llama_stack.providers.utils.memory.vector_store import VectorStoreWithIndex
from llama_stack_api import QueryChunksResponse
@ -1266,13 +1266,12 @@ async def test_query_expansion_functionality(vector_io_adapter):
# Test 1: Query expansion with default prompt (no custom prompt configured)
mock_vector_stores_config = MagicMock()
mock_vector_stores_config.default_query_expansion_model = QualifiedModel(provider_id="test", model_id="llama")
mock_vector_stores_config.query_expansion_prompt = None # Use built-in default prompt
mock_vector_stores_config.query_expansion_max_tokens = 100 # Default value
mock_vector_stores_config.query_expansion_temperature = 0.3 # Default value
mock_vector_stores_config.rewrite_query_params = RewriteQueryParams(
model=QualifiedModel(provider_id="test", model_id="llama"), max_tokens=100, temperature=0.3
)
# Set global config
set_default_query_expansion_config(mock_vector_stores_config)
set_default_rewrite_query_config(mock_vector_stores_config)
# Mock chat completion for query rewriting
mock_inference_api.openai_chat_completion = AsyncMock(
@ -1305,10 +1304,13 @@ async def test_query_expansion_functionality(vector_io_adapter):
mock_inference_api.reset_mock()
mock_index.reset_mock()
mock_vector_stores_config.query_expansion_prompt = "Custom prompt for rewriting: {query}"
mock_vector_stores_config.query_expansion_max_tokens = 150
mock_vector_stores_config.query_expansion_temperature = 0.7
set_default_query_expansion_config(mock_vector_stores_config)
mock_vector_stores_config.rewrite_query_params = RewriteQueryParams(
model=QualifiedModel(provider_id="test", model_id="llama"),
prompt="Custom prompt for rewriting: {query}",
max_tokens=150,
temperature=0.7,
)
set_default_rewrite_query_config(mock_vector_stores_config)
result = await vector_store_with_index.query_chunks("test query", params)
@ -1328,7 +1330,7 @@ async def test_query_expansion_functionality(vector_io_adapter):
mock_index.reset_mock()
# Clear global config
set_default_query_expansion_config(None)
set_default_rewrite_query_config(None)
params = {"rewrite_query": True, "max_chunks": 5}
result2 = await vector_store_with_index.query_chunks("test query", params)