added query expansion model to extra_body

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
Author: Francisco Javier Arceo
Date: 2025-11-19 22:41:19 -05:00
parent ac7cb1ba5a
commit 2cc7943fd6
4 changed files with 130 additions and 12 deletions
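For context, the new field rides along with embedding_model in the request's extra_body. A minimal sketch of the client-side call, assuming the stack's OpenAI-compatible endpoint and the openai Python client (the base URL and model IDs here are illustrative, not taken from this commit):

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/v1", api_key="none")  # illustrative endpoint

# Both fields travel in extra_body, mirroring the request shape exercised by the test below.
store = client.vector_stores.create(
    name="docs",
    extra_body={
        "embedding_model": "test/embedding",
        "query_expansion_model": "test/llama-model",
    },
)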


@@ -1230,3 +1230,97 @@ async def test_embedding_config_required_model_missing(vector_io_adapter):
    with pytest.raises(ValueError, match="embedding_model is required"):
        await vector_io_adapter.openai_create_vector_store(params)


async def test_query_expansion_functionality(vector_io_adapter):
    """Test query expansion with per-store models, global defaults, and error validation."""
    from unittest.mock import MagicMock

    from llama_stack.core.datatypes import QualifiedModel, VectorStoresConfig
    from llama_stack.providers.utils.memory.vector_store import VectorStoreWithIndex
    from llama_stack_api.models import Model, ModelType

    vector_io_adapter.register_vector_store = AsyncMock()
    vector_io_adapter.__provider_id__ = "test_provider"

    # Test 1: a per-store model passed via extra_body is forwarded to the registered store
    params = OpenAICreateVectorStoreRequestWithExtraBody(
        name="test_store",
        metadata={},
        **{"embedding_model": "test/embedding", "query_expansion_model": "test/llama-model"},
    )
    await vector_io_adapter.openai_create_vector_store(params)
    call_args = vector_io_adapter.register_vector_store.call_args[0][0]
    assert call_args.query_expansion_model == "test/llama-model"

    # Test 2: with no per-store model, the field stays None (the global default is resolved at query time)
    vector_io_adapter.register_vector_store.reset_mock()
    params_no_model = OpenAICreateVectorStoreRequestWithExtraBody(
        name="test_store2", metadata={}, **{"embedding_model": "test/embedding"}
    )
    await vector_io_adapter.openai_create_vector_store(params_no_model)
    call_args2 = vector_io_adapter.register_vector_store.call_args[0][0]
    assert call_args2.query_expansion_model is None

    # Query rewriting scenarios
    mock_inference_api = MagicMock()

    # Per-store model scenario: the store's own model takes precedence over the global default
    mock_vector_store = MagicMock()
    mock_vector_store.query_expansion_model = "test/llama-model"
    mock_inference_api.routing_table.list_models = AsyncMock(
        return_value=MagicMock(
            data=[Model(identifier="test/llama-model", provider_id="test", model_type=ModelType.llm)]
        )
    )
    mock_inference_api.openai_chat_completion = AsyncMock(
        return_value=MagicMock(choices=[MagicMock(message=MagicMock(content="per-store expanded"))])
    )
    vector_store_with_index = VectorStoreWithIndex(
        vector_store=mock_vector_store,
        index=MagicMock(),
        inference_api=mock_inference_api,
        vector_stores_config=VectorStoresConfig(
            default_query_expansion_model=QualifiedModel(provider_id="global", model_id="default")
        ),
    )
    result = await vector_store_with_index._rewrite_query_for_search("test")
    assert mock_inference_api.openai_chat_completion.call_args[0][0].model == "test/llama-model"
    assert result == "per-store expanded"

    # Global default fallback scenario: no per-store model, so the configured default is used
    mock_inference_api.reset_mock()
    mock_vector_store.query_expansion_model = None
    mock_inference_api.routing_table.list_models = AsyncMock(
        return_value=MagicMock(
            data=[Model(identifier="global/default", provider_id="global", model_type=ModelType.llm)]
        )
    )
    mock_inference_api.openai_chat_completion = AsyncMock(
        return_value=MagicMock(choices=[MagicMock(message=MagicMock(content="global expanded"))])
    )
    result = await vector_store_with_index._rewrite_query_for_search("test")
    assert mock_inference_api.openai_chat_completion.call_args[0][0].model == "global/default"
    assert result == "global expanded"

    # Test 3: error cases
    # Model not found in the routing table
    mock_vector_store.query_expansion_model = "missing/model"
    mock_inference_api.routing_table.list_models = AsyncMock(return_value=MagicMock(data=[]))
    with pytest.raises(ValueError, match="Configured query expansion model .* is not available"):
        await vector_store_with_index._rewrite_query_for_search("test")

    # Configured model exists but is not an LLM
    mock_vector_store.query_expansion_model = "test/embedding-model"
    mock_inference_api.routing_table.list_models = AsyncMock(
        return_value=MagicMock(
            data=[Model(identifier="test/embedding-model", provider_id="test", model_type=ModelType.embedding)]
        )
    )
    with pytest.raises(ValueError, match="is not an LLM model.*Query rewriting requires an LLM model"):
        await vector_store_with_index._rewrite_query_for_search("test")
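
The assertions above pin down the resolution order _rewrite_query_for_search is expected to follow: the per-store model first, then the global default, with validation that the resolved model exists in the routing table and is an LLM. A sketch of that resolution step, reconstructed from the test's assertions only (the helper name and structure are assumptions, not the actual implementation in this commit):

from llama_stack_api.models import ModelType


async def _resolve_query_expansion_model(vector_store, vector_stores_config, inference_api) -> str:
    # Assumed logic: per-store model wins; otherwise fall back to the configured global default.
    model_id = vector_store.query_expansion_model
    if model_id is None and vector_stores_config and vector_stores_config.default_query_expansion_model:
        default = vector_stores_config.default_query_expansion_model
        model_id = f"{default.provider_id}/{default.model_id}"  # e.g. "global/default" in the test

    # Validate against the routing table, matching the two error messages the test expects.
    models = (await inference_api.routing_table.list_models()).data
    match = next((m for m in models if m.identifier == model_id), None)
    if match is None:
        raise ValueError(f"Configured query expansion model '{model_id}' is not available")
    if match.model_type != ModelType.llm:
        raise ValueError(f"Model '{model_id}' is not an LLM model. Query rewriting requires an LLM model.")
    return model_id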