From 2cc7943fd61f69783dff1d810fa4e4fcc03a4b41 Mon Sep 17 00:00:00 2001
From: Francisco Javier Arceo
Date: Wed, 19 Nov 2025 22:41:19 -0500
Subject: [PATCH] added query expansion model to extra_body

Signed-off-by: Francisco Javier Arceo
---
 .../utils/memory/openai_vector_store_mixin.py |  8 ++
 .../providers/utils/memory/vector_store.py    | 39 +++++---
 src/llama_stack_api/vector_stores.py          |  1 +
 .../test_vector_io_openai_vector_stores.py    | 94 +++++++++++++++++++
 4 files changed, 130 insertions(+), 12 deletions(-)

diff --git a/src/llama_stack/providers/utils/memory/openai_vector_store_mixin.py b/src/llama_stack/providers/utils/memory/openai_vector_store_mixin.py
index d83aa6dc1..4e67cf24b 100644
--- a/src/llama_stack/providers/utils/memory/openai_vector_store_mixin.py
+++ b/src/llama_stack/providers/utils/memory/openai_vector_store_mixin.py
@@ -379,6 +379,11 @@ class OpenAIVectorStoreMixin(ABC):
             f"Using embedding config from extra_body: model='{embedding_model}', dimension={embedding_dimension}"
         )
 
+        # Extract query expansion model from extra_body if provided
+        query_expansion_model = extra_body.get("query_expansion_model")
+        if query_expansion_model:
+            logger.debug(f"Using per-store query expansion model: {query_expansion_model}")
+
         # use provider_id set by router; fallback to provider's own ID when used directly via --stack-config
         provider_id = extra_body.get("provider_id") or getattr(self, "__provider_id__", None)
         # Derive the canonical vector_store_id (allow override, else generate)
@@ -402,6 +407,7 @@ class OpenAIVectorStoreMixin(ABC):
             provider_id=provider_id,
             provider_resource_id=vector_store_id,
             vector_store_name=params.name,
+            query_expansion_model=query_expansion_model,
         )
 
         await self.register_vector_store(vector_store)
@@ -607,12 +613,14 @@ class OpenAIVectorStoreMixin(ABC):
             if ranking_options and ranking_options.score_threshold is not None
             else 0.0
         )
+
         params = {
             "max_chunks": max_num_results * CHUNK_MULTIPLIER,
             "score_threshold": score_threshold,
             "mode": search_mode,
             "rewrite_query": rewrite_query,
         }
+
         # Add vector_stores_config if available (for query rewriting)
         if hasattr(self, "vector_stores_config"):
             params["vector_stores_config"] = self.vector_stores_config
diff --git a/src/llama_stack/providers/utils/memory/vector_store.py b/src/llama_stack/providers/utils/memory/vector_store.py
index 2a7b94292..71d61787a 100644
--- a/src/llama_stack/providers/utils/memory/vector_store.py
+++ b/src/llama_stack/providers/utils/memory/vector_store.py
@@ -17,7 +17,7 @@ import numpy as np
 from numpy.typing import NDArray
 from pydantic import BaseModel
 
-from llama_stack.core.datatypes import VectorStoresConfig
+from llama_stack.core.datatypes import QualifiedModel, VectorStoresConfig
 from llama_stack.log import get_logger
 from llama_stack.models.llama.llama3.tokenizer import Tokenizer
 from llama_stack.providers.utils.inference.prompt_adapter import (
@@ -366,18 +366,33 @@ class VectorStoreWithIndex:
         :param query: The original user query
         :returns: The rewritten query optimized for vector search
         """
-        # Check if query expansion model is configured
-        if not self.vector_stores_config:
-            raise ValueError(
-                f"No vector_stores_config found! self.vector_stores_config is: {self.vector_stores_config}"
-            )
-        if not self.vector_stores_config.default_query_expansion_model:
-            raise ValueError(
-                f"No default_query_expansion_model configured! vector_stores_config: {self.vector_stores_config}, default_query_expansion_model: {self.vector_stores_config.default_query_expansion_model}"
-            )
+        expansion_model = None
+
+        # Check for per-store query expansion model first
+        if self.vector_store.query_expansion_model:
+            # Parse the model string into provider_id and model_id
+            model_parts = self.vector_store.query_expansion_model.split("/", 1)
+            if len(model_parts) == 2:
+                expansion_model = QualifiedModel(provider_id=model_parts[0], model_id=model_parts[1])
+                log.debug(f"Using per-store query expansion model: {expansion_model}")
+            else:
+                log.warning(
+                    f"Invalid query_expansion_model format: {self.vector_store.query_expansion_model}. Expected 'provider_id/model_id'"
+                )
+
+        # Fall back to global default if no per-store model
+        if not expansion_model:
+            if not self.vector_stores_config:
+                raise ValueError(
+                    f"No vector_stores_config found and no per-store query_expansion_model! self.vector_stores_config is: {self.vector_stores_config}"
+                )
+            if not self.vector_stores_config.default_query_expansion_model:
+                raise ValueError(
+                    f"No default_query_expansion_model configured and no per-store query_expansion_model! vector_stores_config: {self.vector_stores_config}, default_query_expansion_model: {self.vector_stores_config.default_query_expansion_model}"
+                )
+            expansion_model = self.vector_stores_config.default_query_expansion_model
+            log.debug(f"Using global default query expansion model: {expansion_model}")
 
-        # Use the configured model
-        expansion_model = self.vector_stores_config.default_query_expansion_model
         chat_model = f"{expansion_model.provider_id}/{expansion_model.model_id}"
 
         # Validate that the model is available and is an LLM
diff --git a/src/llama_stack_api/vector_stores.py b/src/llama_stack_api/vector_stores.py
index 0a1e6c53c..4c0d1ced2 100644
--- a/src/llama_stack_api/vector_stores.py
+++ b/src/llama_stack_api/vector_stores.py
@@ -25,6 +25,7 @@ class VectorStore(Resource):
     embedding_model: str
     embedding_dimension: int
     vector_store_name: str | None = None
+    query_expansion_model: str | None = None
 
     @property
     def vector_store_id(self) -> str:
diff --git a/tests/unit/providers/vector_io/test_vector_io_openai_vector_stores.py b/tests/unit/providers/vector_io/test_vector_io_openai_vector_stores.py
index 3797abb2c..cfda7aa5e 100644
--- a/tests/unit/providers/vector_io/test_vector_io_openai_vector_stores.py
+++ b/tests/unit/providers/vector_io/test_vector_io_openai_vector_stores.py
@@ -1230,3 +1230,97 @@ async def test_embedding_config_required_model_missing(vector_io_adapter):
 
     with pytest.raises(ValueError, match="embedding_model is required"):
         await vector_io_adapter.openai_create_vector_store(params)
+
+
+async def test_query_expansion_functionality(vector_io_adapter):
+    """Test query expansion with per-store models, global defaults, and error validation."""
+    from unittest.mock import MagicMock
+
+    from llama_stack.core.datatypes import QualifiedModel, VectorStoresConfig
+    from llama_stack.providers.utils.memory.vector_store import VectorStoreWithIndex
+    from llama_stack_api.models import Model, ModelType
+
+    vector_io_adapter.register_vector_store = AsyncMock()
+    vector_io_adapter.__provider_id__ = "test_provider"
+
+    # Test 1: Per-store model usage
+    params = OpenAICreateVectorStoreRequestWithExtraBody(
+        name="test_store",
+        metadata={},
+        **{"embedding_model": "test/embedding", "query_expansion_model": "test/llama-model"},
+    )
+    await vector_io_adapter.openai_create_vector_store(params)
+    call_args = vector_io_adapter.register_vector_store.call_args[0][0]
+    assert call_args.query_expansion_model == "test/llama-model"
+
+    # Test 2: Global default fallback
+    vector_io_adapter.register_vector_store.reset_mock()
+    params_no_model = OpenAICreateVectorStoreRequestWithExtraBody(
+        name="test_store2", metadata={}, **{"embedding_model": "test/embedding"}
+    )
+    await vector_io_adapter.openai_create_vector_store(params_no_model)
+    call_args2 = vector_io_adapter.register_vector_store.call_args[0][0]
+    assert call_args2.query_expansion_model is None
+
+    # Test query rewriting scenarios
+    mock_inference_api = MagicMock()
+
+    # Per-store model scenario
+    mock_vector_store = MagicMock()
+    mock_vector_store.query_expansion_model = "test/llama-model"
+    mock_inference_api.routing_table.list_models = AsyncMock(
+        return_value=MagicMock(
+            data=[Model(identifier="test/llama-model", provider_id="test", model_type=ModelType.llm)]
+        )
+    )
+    mock_inference_api.openai_chat_completion = AsyncMock(
+        return_value=MagicMock(choices=[MagicMock(message=MagicMock(content="per-store expanded"))])
+    )
+
+    vector_store_with_index = VectorStoreWithIndex(
+        vector_store=mock_vector_store,
+        index=MagicMock(),
+        inference_api=mock_inference_api,
+        vector_stores_config=VectorStoresConfig(
+            default_query_expansion_model=QualifiedModel(provider_id="global", model_id="default")
+        ),
+    )
+
+    result = await vector_store_with_index._rewrite_query_for_search("test")
+    assert mock_inference_api.openai_chat_completion.call_args[0][0].model == "test/llama-model"
+    assert result == "per-store expanded"
+
+    # Global default fallback scenario
+    mock_inference_api.reset_mock()
+    mock_vector_store.query_expansion_model = None
+    mock_inference_api.routing_table.list_models = AsyncMock(
+        return_value=MagicMock(
+            data=[Model(identifier="global/default", provider_id="global", model_type=ModelType.llm)]
+        )
+    )
+    mock_inference_api.openai_chat_completion = AsyncMock(
+        return_value=MagicMock(choices=[MagicMock(message=MagicMock(content="global expanded"))])
+    )
+
+    result = await vector_store_with_index._rewrite_query_for_search("test")
+    assert mock_inference_api.openai_chat_completion.call_args[0][0].model == "global/default"
+    assert result == "global expanded"
+
+    # Test 3: Error cases
+    # Model not found
+    mock_vector_store.query_expansion_model = "missing/model"
+    mock_inference_api.routing_table.list_models = AsyncMock(return_value=MagicMock(data=[]))
+
+    with pytest.raises(ValueError, match="Configured query expansion model .* is not available"):
+        await vector_store_with_index._rewrite_query_for_search("test")
+
+    # Non-LLM model
+    mock_vector_store.query_expansion_model = "test/embedding-model"
+    mock_inference_api.routing_table.list_models = AsyncMock(
+        return_value=MagicMock(
+            data=[Model(identifier="test/embedding-model", provider_id="test", model_type=ModelType.embedding)]
+        )
+    )
+
+    with pytest.raises(ValueError, match="is not an LLM model.*Query rewriting requires an LLM model"):
+        await vector_store_with_index._rewrite_query_for_search("test")
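
The sketch below is a hypothetical client-side illustration of the "query_expansion_model" extra_body key this patch introduces, assuming the llama-stack-client Python SDK and an OpenAI-compatible vector_stores.create call; the base URL, store name, and model identifiers are placeholders, and only the extra_body keys and the "provider_id/model_id" format come from the patch itself.

# Hypothetical usage sketch -- client class, URL, and model names are assumptions, not part of this patch.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")

# Per-store override: searches against this store rewrite queries with the named LLM;
# if the key is omitted, the global default_query_expansion_model is used instead.
store = client.vector_stores.create(
    name="docs",
    extra_body={
        "embedding_model": "sentence-transformers/all-MiniLM-L6-v2",
        "query_expansion_model": "ollama/llama3.2:3b",  # "provider_id/model_id"
    },
)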