Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-12-04 02:03:44 +00:00
renaming to query_rewrite, consolidating, and cleaning up validation
Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
parent d887f1f8bb
commit 31e28b6d17
12 changed files with 138 additions and 180 deletions

@@ -1236,9 +1236,9 @@ async def test_query_expansion_functionality(vector_io_adapter):
     """Test query expansion with simplified global configuration approach."""
     from unittest.mock import MagicMock
 
-    from llama_stack.core.datatypes import QualifiedModel
+    from llama_stack.core.datatypes import QualifiedModel, RewriteQueryParams
-    from llama_stack.providers.utils.memory.constants import DEFAULT_QUERY_EXPANSION_PROMPT
-    from llama_stack.providers.utils.memory.query_expansion_config import set_default_query_expansion_config
+    from llama_stack.providers.utils.memory.rewrite_query_config import set_default_rewrite_query_config
     from llama_stack.providers.utils.memory.vector_store import VectorStoreWithIndex
     from llama_stack_api import QueryChunksResponse
 

@@ -1266,13 +1266,12 @@ async def test_query_expansion_functionality(vector_io_adapter):
 
     # Test 1: Query expansion with default prompt (no custom prompt configured)
     mock_vector_stores_config = MagicMock()
-    mock_vector_stores_config.default_query_expansion_model = QualifiedModel(provider_id="test", model_id="llama")
-    mock_vector_stores_config.query_expansion_prompt = None  # Use built-in default prompt
-    mock_vector_stores_config.query_expansion_max_tokens = 100  # Default value
-    mock_vector_stores_config.query_expansion_temperature = 0.3  # Default value
+    mock_vector_stores_config.rewrite_query_params = RewriteQueryParams(
+        model=QualifiedModel(provider_id="test", model_id="llama"), max_tokens=100, temperature=0.3
+    )
 
     # Set global config
-    set_default_query_expansion_config(mock_vector_stores_config)
+    set_default_rewrite_query_config(mock_vector_stores_config)
 
     # Mock chat completion for query rewriting
     mock_inference_api.openai_chat_completion = AsyncMock(

@@ -1305,10 +1304,13 @@ async def test_query_expansion_functionality(vector_io_adapter):
     mock_inference_api.reset_mock()
     mock_index.reset_mock()
 
-    mock_vector_stores_config.query_expansion_prompt = "Custom prompt for rewriting: {query}"
-    mock_vector_stores_config.query_expansion_max_tokens = 150
-    mock_vector_stores_config.query_expansion_temperature = 0.7
-    set_default_query_expansion_config(mock_vector_stores_config)
+    mock_vector_stores_config.rewrite_query_params = RewriteQueryParams(
+        model=QualifiedModel(provider_id="test", model_id="llama"),
+        prompt="Custom prompt for rewriting: {query}",
+        max_tokens=150,
+        temperature=0.7,
+    )
+    set_default_rewrite_query_config(mock_vector_stores_config)
 
     result = await vector_store_with_index.query_chunks("test query", params)
 

@@ -1328,7 +1330,7 @@ async def test_query_expansion_functionality(vector_io_adapter):
     mock_index.reset_mock()
 
     # Clear global config
-    set_default_query_expansion_config(None)
+    set_default_rewrite_query_config(None)
 
     params = {"rewrite_query": True, "max_chunks": 5}
     result2 = await vector_store_with_index.query_chunks("test query", params)
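
Read together, the hunks above show the shape of the consolidation: the four separate query_expansion_* attributes and set_default_query_expansion_config() are replaced by a single RewriteQueryParams object registered through set_default_rewrite_query_config(). A minimal sketch of the new configuration pattern, assembled only from the lines added in this diff (field names and call signatures are assumed to be exactly as shown above, not verified against the rest of the codebase):

    from unittest.mock import MagicMock

    from llama_stack.core.datatypes import QualifiedModel, RewriteQueryParams
    from llama_stack.providers.utils.memory.rewrite_query_config import set_default_rewrite_query_config

    # Build a config object carrying the consolidated rewrite settings.
    # MagicMock stands in for the real vector-stores config, as in the test above.
    vector_stores_config = MagicMock()
    vector_stores_config.rewrite_query_params = RewriteQueryParams(
        model=QualifiedModel(provider_id="test", model_id="llama"),
        prompt="Custom prompt for rewriting: {query}",  # omit to use the built-in default prompt, as in Test 1
        max_tokens=150,
        temperature=0.7,
    )

    # Register the config globally; passing None clears it again, as in the last hunk.
    set_default_rewrite_query_config(vector_stores_config)
    set_default_rewrite_query_config(None)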