Mirror of https://github.com/meta-llama/llama-stack.git
added query expansion model to extra_body
Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
parent ac7cb1ba5a, commit 2cc7943fd6
4 changed files with 130 additions and 12 deletions
@@ -379,6 +379,11 @@ class OpenAIVectorStoreMixin(ABC):
                f"Using embedding config from extra_body: model='{embedding_model}', dimension={embedding_dimension}"
            )

        # Extract query expansion model from extra_body if provided
        query_expansion_model = extra_body.get("query_expansion_model")
        if query_expansion_model:
            logger.debug(f"Using per-store query expansion model: {query_expansion_model}")

        # use provider_id set by router; fallback to provider's own ID when used directly via --stack-config
        provider_id = extra_body.get("provider_id") or getattr(self, "__provider_id__", None)
        # Derive the canonical vector_store_id (allow override, else generate)
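For context (not part of the diff): a minimal sketch of how a caller could supply the new field when creating a vector store, assuming the OpenAI Python client pointed at a Llama Stack server. The base URL, API key, and model identifiers are illustrative, and passing these fields relies on the SDK's generic extra_body support rather than anything shown in this commit.

# Hypothetical client-side usage; values are illustrative, not from this commit.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/v1", api_key="none")  # assumed local Llama Stack endpoint

store = client.vector_stores.create(
    name="docs",
    # extra_body carries the Llama Stack specific fields read by openai_create_vector_store above
    extra_body={
        "embedding_model": "sentence-transformers/all-MiniLM-L6-v2",  # illustrative model id
        "embedding_dimension": 384,
        "query_expansion_model": "ollama/llama3.2:3b",  # per-store "provider_id/model_id"
    },
)
print(store.id)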
@@ -402,6 +407,7 @@ class OpenAIVectorStoreMixin(ABC):
            provider_id=provider_id,
            provider_resource_id=vector_store_id,
            vector_store_name=params.name,
            query_expansion_model=query_expansion_model,
        )
        await self.register_vector_store(vector_store)
@@ -607,12 +613,14 @@ class OpenAIVectorStoreMixin(ABC):
            if ranking_options and ranking_options.score_threshold is not None
            else 0.0
        )

        params = {
            "max_chunks": max_num_results * CHUNK_MULTIPLIER,
            "score_threshold": score_threshold,
            "mode": search_mode,
            "rewrite_query": rewrite_query,
        }

        # Add vector_stores_config if available (for query rewriting)
        if hasattr(self, "vector_stores_config"):
            params["vector_stores_config"] = self.vector_stores_config
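Continuing the client sketch above, a search call that opts into query rewriting might look like this; it assumes a recent OpenAI Python client whose vector store search method exposes rewrite_query, and it reuses the client and store from the earlier example.

# Hypothetical client-side search; continues the create example above.
results = client.vector_stores.search(
    vector_store_id=store.id,
    query="how do I rotate API keys?",
    rewrite_query=True,  # lets the server rewrite/expand the query before embedding it
    max_num_results=5,
)
for hit in results.data:
    print(hit.score, hit.filename)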
@@ -17,7 +17,7 @@ import numpy as np
 from numpy.typing import NDArray
 from pydantic import BaseModel

-from llama_stack.core.datatypes import VectorStoresConfig
+from llama_stack.core.datatypes import QualifiedModel, VectorStoresConfig
 from llama_stack.log import get_logger
 from llama_stack.models.llama.llama3.tokenizer import Tokenizer
 from llama_stack.providers.utils.inference.prompt_adapter import (
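The newly imported QualifiedModel is only exercised below through its provider_id and model_id fields; a rough stand-in with just that shape is sketched here for orientation (the real definition in llama_stack.core.datatypes may carry more).

# Hypothetical stand-in for illustration; the real class lives in llama_stack.core.datatypes.
from pydantic import BaseModel

class QualifiedModelSketch(BaseModel):
    provider_id: str
    model_id: str

m = QualifiedModelSketch(provider_id="ollama", model_id="llama3.2:3b")
print(f"{m.provider_id}/{m.model_id}")  # "ollama/llama3.2:3b"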
@@ -366,18 +366,33 @@ class VectorStoreWithIndex:
         :param query: The original user query
         :returns: The rewritten query optimized for vector search
         """
-        # Check if query expansion model is configured
-        if not self.vector_stores_config:
-            raise ValueError(
-                f"No vector_stores_config found! self.vector_stores_config is: {self.vector_stores_config}"
-            )
-        if not self.vector_stores_config.default_query_expansion_model:
-            raise ValueError(
-                f"No default_query_expansion_model configured! vector_stores_config: {self.vector_stores_config}, default_query_expansion_model: {self.vector_stores_config.default_query_expansion_model}"
-            )
+        expansion_model = None
+
+        # Check for per-store query expansion model first
+        if self.vector_store.query_expansion_model:
+            # Parse the model string into provider_id and model_id
+            model_parts = self.vector_store.query_expansion_model.split("/", 1)
+            if len(model_parts) == 2:
+                expansion_model = QualifiedModel(provider_id=model_parts[0], model_id=model_parts[1])
+                log.debug(f"Using per-store query expansion model: {expansion_model}")
+            else:
+                log.warning(
+                    f"Invalid query_expansion_model format: {self.vector_store.query_expansion_model}. Expected 'provider_id/model_id'"
+                )
+
+        # Fall back to global default if no per-store model
+        if not expansion_model:
+            if not self.vector_stores_config:
+                raise ValueError(
+                    f"No vector_stores_config found and no per-store query_expansion_model! self.vector_stores_config is: {self.vector_stores_config}"
+                )
+            if not self.vector_stores_config.default_query_expansion_model:
+                raise ValueError(
+                    f"No default_query_expansion_model configured and no per-store query_expansion_model! vector_stores_config: {self.vector_stores_config}, default_query_expansion_model: {self.vector_stores_config.default_query_expansion_model}"
+                )
+            expansion_model = self.vector_stores_config.default_query_expansion_model
+            log.debug(f"Using global default query expansion model: {expansion_model}")
 
-        # Use the configured model
-        expansion_model = self.vector_stores_config.default_query_expansion_model
         chat_model = f"{expansion_model.provider_id}/{expansion_model.model_id}"
 
         # Validate that the model is available and is an LLM
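To restate the precedence the new code implements: a store-level query_expansion_model of the form "provider_id/model_id" wins, a malformed string is logged and ignored, and otherwise the globally configured default_query_expansion_model is used, with a ValueError if neither is available. A small self-contained sketch of that rule follows; the helper name and simplified type are hypothetical, for illustration only.

# Hypothetical illustration of the precedence rule; not part of the diff.
from dataclasses import dataclass

@dataclass
class ModelRef:
    provider_id: str
    model_id: str

def resolve_expansion_model(per_store: str | None, global_default: ModelRef | None) -> ModelRef:
    # 1. A well-formed per-store "provider_id/model_id" string takes priority.
    if per_store:
        parts = per_store.split("/", 1)
        if len(parts) == 2:
            return ModelRef(provider_id=parts[0], model_id=parts[1])
        # Malformed strings are ignored here (the real code logs a warning) and we fall through.
    # 2. Otherwise fall back to the globally configured default.
    if global_default is None:
        raise ValueError("No query expansion model configured for this store or globally")
    return global_default

assert resolve_expansion_model("ollama/llama3.2:3b", None).model_id == "llama3.2:3b"
assert resolve_expansion_model(None, ModelRef("global", "default")).provider_id == "global"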
@@ -25,6 +25,7 @@ class VectorStore(Resource):
    embedding_model: str
    embedding_dimension: int
    vector_store_name: str | None = None
    query_expansion_model: str | None = None

    @property
    def vector_store_id(self) -> str:
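Putting the pieces together, a registered store carrying the new field might look roughly like the following. The fields grounded in this diff are embedding_model, embedding_dimension, vector_store_name, query_expansion_model, provider_id, and provider_resource_id; the identifier keyword, the import path, and the concrete values are assumptions.

# Hypothetical construction; import path and values are assumptions, not from this commit.
from llama_stack.core.datatypes import VectorStore  # assumed location of the VectorStore resource

store = VectorStore(
    identifier="vs_123",                    # assumed base Resource field
    provider_id="faiss",
    provider_resource_id="vs_123",
    embedding_model="sentence-transformers/all-MiniLM-L6-v2",
    embedding_dimension=384,
    vector_store_name="docs",
    query_expansion_model="ollama/llama3.2:3b",  # new optional field from this commit
)
print(store.vector_store_id)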
@@ -1230,3 +1230,97 @@ async def test_embedding_config_required_model_missing(vector_io_adapter):

    with pytest.raises(ValueError, match="embedding_model is required"):
        await vector_io_adapter.openai_create_vector_store(params)


async def test_query_expansion_functionality(vector_io_adapter):
    """Test query expansion with per-store models, global defaults, and error validation."""
    from unittest.mock import MagicMock

    from llama_stack.core.datatypes import QualifiedModel, VectorStoresConfig
    from llama_stack.providers.utils.memory.vector_store import VectorStoreWithIndex
    from llama_stack_api.models import Model, ModelType

    vector_io_adapter.register_vector_store = AsyncMock()
    vector_io_adapter.__provider_id__ = "test_provider"

    # Test 1: Per-store model usage
    params = OpenAICreateVectorStoreRequestWithExtraBody(
        name="test_store",
        metadata={},
        **{"embedding_model": "test/embedding", "query_expansion_model": "test/llama-model"},
    )
    await vector_io_adapter.openai_create_vector_store(params)
    call_args = vector_io_adapter.register_vector_store.call_args[0][0]
    assert call_args.query_expansion_model == "test/llama-model"

    # Test 2: Global default fallback
    vector_io_adapter.register_vector_store.reset_mock()
    params_no_model = OpenAICreateVectorStoreRequestWithExtraBody(
        name="test_store2", metadata={}, **{"embedding_model": "test/embedding"}
    )
    await vector_io_adapter.openai_create_vector_store(params_no_model)
    call_args2 = vector_io_adapter.register_vector_store.call_args[0][0]
    assert call_args2.query_expansion_model is None

    # Test query rewriting scenarios
    mock_inference_api = MagicMock()

    # Per-store model scenario
    mock_vector_store = MagicMock()
    mock_vector_store.query_expansion_model = "test/llama-model"
    mock_inference_api.routing_table.list_models = AsyncMock(
        return_value=MagicMock(
            data=[Model(identifier="test/llama-model", provider_id="test", model_type=ModelType.llm)]
        )
    )
    mock_inference_api.openai_chat_completion = AsyncMock(
        return_value=MagicMock(choices=[MagicMock(message=MagicMock(content="per-store expanded"))])
    )

    vector_store_with_index = VectorStoreWithIndex(
        vector_store=mock_vector_store,
        index=MagicMock(),
        inference_api=mock_inference_api,
        vector_stores_config=VectorStoresConfig(
            default_query_expansion_model=QualifiedModel(provider_id="global", model_id="default")
        ),
    )

    result = await vector_store_with_index._rewrite_query_for_search("test")
    assert mock_inference_api.openai_chat_completion.call_args[0][0].model == "test/llama-model"
    assert result == "per-store expanded"

    # Global default fallback scenario
    mock_inference_api.reset_mock()
    mock_vector_store.query_expansion_model = None
    mock_inference_api.routing_table.list_models = AsyncMock(
        return_value=MagicMock(
            data=[Model(identifier="global/default", provider_id="global", model_type=ModelType.llm)]
        )
    )
    mock_inference_api.openai_chat_completion = AsyncMock(
        return_value=MagicMock(choices=[MagicMock(message=MagicMock(content="global expanded"))])
    )

    result = await vector_store_with_index._rewrite_query_for_search("test")
    assert mock_inference_api.openai_chat_completion.call_args[0][0].model == "global/default"
    assert result == "global expanded"

    # Test 3: Error cases
    # Model not found
    mock_vector_store.query_expansion_model = "missing/model"
    mock_inference_api.routing_table.list_models = AsyncMock(return_value=MagicMock(data=[]))

    with pytest.raises(ValueError, match="Configured query expansion model .* is not available"):
        await vector_store_with_index._rewrite_query_for_search("test")

    # Non-LLM model
    mock_vector_store.query_expansion_model = "test/embedding-model"
    mock_inference_api.routing_table.list_models = AsyncMock(
        return_value=MagicMock(
            data=[Model(identifier="test/embedding-model", provider_id="test", model_type=ModelType.embedding)]
        )
    )

    with pytest.raises(ValueError, match="is not an LLM model.*Query rewriting requires an LLM model"):
        await vector_store_with_index._rewrite_query_for_search("test")
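To run only the new test locally without knowing the exact test file path (which this diff view does not show), filtering by name should work, for example: pytest -k test_query_expansion_functionality -q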