added query expansion model to extra_body

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
This commit is contained in:
Francisco Javier Arceo 2025-11-19 22:41:19 -05:00
parent ac7cb1ba5a
commit 2cc7943fd6
4 changed files with 130 additions and 12 deletions

View file

@@ -379,6 +379,11 @@ class OpenAIVectorStoreMixin(ABC):
f"Using embedding config from extra_body: model='{embedding_model}', dimension={embedding_dimension}"
)
# Extract query expansion model from extra_body if provided
query_expansion_model = extra_body.get("query_expansion_model")
if query_expansion_model:
logger.debug(f"Using per-store query expansion model: {query_expansion_model}")
# use provider_id set by router; fallback to provider's own ID when used directly via --stack-config
provider_id = extra_body.get("provider_id") or getattr(self, "__provider_id__", None)
# Derive the canonical vector_store_id (allow override, else generate)
@@ -402,6 +407,7 @@ class OpenAIVectorStoreMixin(ABC):
provider_id=provider_id,
provider_resource_id=vector_store_id,
vector_store_name=params.name,
query_expansion_model=query_expansion_model,
)
await self.register_vector_store(vector_store)
@@ -607,12 +613,14 @@ class OpenAIVectorStoreMixin(ABC):
if ranking_options and ranking_options.score_threshold is not None
else 0.0
)
params = {
"max_chunks": max_num_results * CHUNK_MULTIPLIER,
"score_threshold": score_threshold,
"mode": search_mode,
"rewrite_query": rewrite_query,
}
# Add vector_stores_config if available (for query rewriting)
if hasattr(self, "vector_stores_config"):
params["vector_stores_config"] = self.vector_stores_config

View file

@@ -17,7 +17,7 @@ import numpy as np
from numpy.typing import NDArray
from pydantic import BaseModel
from llama_stack.core.datatypes import VectorStoresConfig
from llama_stack.core.datatypes import QualifiedModel, VectorStoresConfig
from llama_stack.log import get_logger
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
from llama_stack.providers.utils.inference.prompt_adapter import (
@@ -366,18 +366,33 @@ class VectorStoreWithIndex:
:param query: The original user query
:returns: The rewritten query optimized for vector search
"""
# Check if query expansion model is configured
if not self.vector_stores_config:
raise ValueError(
f"No vector_stores_config found! self.vector_stores_config is: {self.vector_stores_config}"
)
if not self.vector_stores_config.default_query_expansion_model:
raise ValueError(
f"No default_query_expansion_model configured! vector_stores_config: {self.vector_stores_config}, default_query_expansion_model: {self.vector_stores_config.default_query_expansion_model}"
)
expansion_model = None
# Check for per-store query expansion model first
if self.vector_store.query_expansion_model:
# Parse the model string into provider_id and model_id
model_parts = self.vector_store.query_expansion_model.split("/", 1)
if len(model_parts) == 2:
expansion_model = QualifiedModel(provider_id=model_parts[0], model_id=model_parts[1])
log.debug(f"Using per-store query expansion model: {expansion_model}")
else:
log.warning(
f"Invalid query_expansion_model format: {self.vector_store.query_expansion_model}. Expected 'provider_id/model_id'"
)
# Fall back to global default if no per-store model
if not expansion_model:
if not self.vector_stores_config:
raise ValueError(
f"No vector_stores_config found and no per-store query_expansion_model! self.vector_stores_config is: {self.vector_stores_config}"
)
if not self.vector_stores_config.default_query_expansion_model:
raise ValueError(
f"No default_query_expansion_model configured and no per-store query_expansion_model! vector_stores_config: {self.vector_stores_config}, default_query_expansion_model: {self.vector_stores_config.default_query_expansion_model}"
)
expansion_model = self.vector_stores_config.default_query_expansion_model
log.debug(f"Using global default query expansion model: {expansion_model}")
# Use the configured model
expansion_model = self.vector_stores_config.default_query_expansion_model
chat_model = f"{expansion_model.provider_id}/{expansion_model.model_id}"
# Validate that the model is available and is an LLM

View file

@@ -25,6 +25,7 @@ class VectorStore(Resource):
embedding_model: str
embedding_dimension: int
vector_store_name: str | None = None
query_expansion_model: str | None = None
@property
def vector_store_id(self) -> str: