mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-21 09:48:41 +00:00
feat: Add support for query rewrite in vector_store.search (#4171)
Some checks failed
SqlStore Integration Tests / test-postgres (3.12) (push) Failing after 0s
SqlStore Integration Tests / test-postgres (3.13) (push) Failing after 0s
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 1s
Test External Providers Installed via Module / test-external-providers-from-module (venv) (push) Has been skipped
Integration Tests (Replay) / generate-matrix (push) Successful in 3s
API Conformance Tests / check-schema-compatibility (push) Successful in 11s
Python Package Build Test / build (3.12) (push) Successful in 15s
Python Package Build Test / build (3.13) (push) Successful in 20s
Test External API and Providers / test-external (venv) (push) Failing after 41s
Vector IO Integration Tests / test-matrix (push) Failing after 49s
UI Tests / ui-tests (22) (push) Successful in 51s
Unit Tests / unit-tests (3.13) (push) Failing after 1m27s
Unit Tests / unit-tests (3.12) (push) Failing after 1m45s
Pre-commit / pre-commit (22) (push) Failing after 2m30s
Integration Tests (Replay) / Integration Tests (, , , client=, ) (push) Failing after 4m22s
Some checks failed
SqlStore Integration Tests / test-postgres (3.12) (push) Failing after 0s
SqlStore Integration Tests / test-postgres (3.13) (push) Failing after 0s
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 1s
Test External Providers Installed via Module / test-external-providers-from-module (venv) (push) Has been skipped
Integration Tests (Replay) / generate-matrix (push) Successful in 3s
API Conformance Tests / check-schema-compatibility (push) Successful in 11s
Python Package Build Test / build (3.12) (push) Successful in 15s
Python Package Build Test / build (3.13) (push) Successful in 20s
Test External API and Providers / test-external (venv) (push) Failing after 41s
Vector IO Integration Tests / test-matrix (push) Failing after 49s
UI Tests / ui-tests (22) (push) Successful in 51s
Unit Tests / unit-tests (3.13) (push) Failing after 1m27s
Unit Tests / unit-tests (3.12) (push) Failing after 1m45s
Pre-commit / pre-commit (22) (push) Failing after 2m30s
Integration Tests (Replay) / Integration Tests (, , , client=, ) (push) Failing after 4m22s
# What does this PR do?
Implement query rewriting in the search API, and add
`default_query_expansion_model` and `query_expansion_prompt` to
`VectorStoresConfig`.
Makes `rewrite_query` parameter functional in vector store search.
- `rewrite_query=false` (default): Use original query
- `rewrite_query=true`: Expand query via LLM, or fail gracefully if no
LLM available
Adds 4 parameters to `VectorStoresConfig`:
- `default_query_expansion_model`: LLM model for query expansion
(optional)
- `query_expansion_prompt`: Custom prompt template (optional, uses
built-in default)
- `query_expansion_max_tokens`: Configurable token limit (default: 100)
- `query_expansion_temperature`: Configurable temperature (default: 0.3)
Enabled `run.yaml`:
```yaml
vector_stores:
  rewrite_query_params:
    model:
      provider_id: "ollama"
      model_id: "llama3.2:3b-instruct-fp16"
    # prompt defaults to built-in
    # max_tokens defaults to 100
    # temperature defaults to 0.3
```
Fully customized `run.yaml`:
```yaml
vector_stores:
  default_provider_id: faiss
  default_embedding_model:
    provider_id: sentence-transformers
    model_id: nomic-ai/nomic-embed-text-v1.5
  rewrite_query_params:
    model:
      provider_id: ollama
      model_id: llama3.2:3b-instruct-fp16
    prompt: "Rewrite this search query to improve retrieval results by expanding it with relevant synonyms and related terms: {query}"
    max_tokens: 100
    temperature: 0.3
```
## Test Plan
Added a test and its recording.
An example script is included as well:
```python
import asyncio
from llama_stack_client import LlamaStackClient
from io import BytesIO
def gen_file(client, text: str = ""):
    """Upload ``text`` as an in-memory UTF-8 file named ``my_file.txt``.

    Args:
        client: Object exposing ``files.create(file=..., purpose=...)``.
        text: Contents of the uploaded file; defaults to an empty string.

    Returns:
        Whatever ``client.files.create`` returns for the upload.
    """
    file_buffer = BytesIO(text.encode("utf-8"))
    # BytesIO has no name of its own; presumably the client uses this
    # attribute as the upload's filename -- confirm against the client docs.
    file_buffer.name = "my_file.txt"
    return client.files.create(file=file_buffer, purpose="assistants")
async def test_query_rewriting():
    """Compare vector store search results with and without query rewriting.

    Uploads two small files, attaches them to a new vector store, runs one
    search with ``rewrite_query=False`` and one with ``rewrite_query=True``,
    prints both responses, and then deletes the files and the vector store.

    The body contains no ``await``; it is a coroutine only so that it can be
    driven by ``asyncio.run``.
    """
    client = LlamaStackClient(base_url="http://0.0.0.0:8321/")
    file_ids = []
    vs = None
    try:
        for text in ("banana banana apple", "orange orange kiwi"):
            file_ids.append(gen_file(client, text).id)

        vs = client.vector_stores.create()
        for file_id in file_ids:
            client.vector_stores.files.create(vector_store_id=vs.id, file_id=file_id)

        response1 = client.vector_stores.search(
            vector_store_id=vs.id,
            query="apple",
            max_num_results=3,
            rewrite_query=False,
        )
        response2 = client.vector_stores.search(
            vector_store_id=vs.id,
            query="kiwi",
            max_num_results=3,
            rewrite_query=True,
        )
        print(f"\n🔵 Response 1 (rewrite_query=False):\n\033[94m{response1}\033[0m")
        print(f"\n🟢 Response 2 (rewrite_query=True):\n\033[92m{response2}\033[0m")
    finally:
        # Clean up even when an upload or search fails, so repeated runs do
        # not leave files or vector stores behind on the server.
        for file_id in file_ids:
            client.files.delete(file_id=file_id)
        if vs is not None:
            client.vector_stores.delete(vector_store_id=vs.id)
if __name__ == "__main__":
    # Run the demo coroutine when the file is executed as a script.
    asyncio.run(test_query_rewriting())
```
The screenshot of the server logs below shows that it worked.
<img width="1111" height="826" alt="Screenshot 2025-11-19 at 1 16 03 PM"
src="https://github.com/user-attachments/assets/2d188b44-1fef-4df5-b465-2d6728ca49ce"
/>
Notice the log:
```bash
Query rewritten:
'kiwi' → 'kiwi, a small brown or green fruit native to New Zealand, or a person having a fuzzy brown outer skin similar in appearance.'
```
So `kiwi` was expanded.
---------
Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
Co-authored-by: Matthew Farrellee <matt@cs.wisc.edu>
This commit is contained in:
parent
ff375f1abb
commit
95b2948d11
22 changed files with 7636 additions and 20 deletions
|
|
@ -87,6 +87,7 @@ async def get_auto_router_impl(
|
|||
api_to_dep_impl["store"] = inference_store
|
||||
elif api == Api.vector_io:
|
||||
api_to_dep_impl["vector_stores_config"] = run_config.vector_stores
|
||||
api_to_dep_impl["inference_api"] = deps.get(Api.inference)
|
||||
elif api == Api.safety:
|
||||
api_to_dep_impl["safety_config"] = run_config.safety
|
||||
|
||||
|
|
|
|||
|
|
@ -16,12 +16,15 @@ from llama_stack_api import (
|
|||
Chunk,
|
||||
HealthResponse,
|
||||
HealthStatus,
|
||||
Inference,
|
||||
InterleavedContent,
|
||||
ModelNotFoundError,
|
||||
ModelType,
|
||||
ModelTypeError,
|
||||
OpenAIChatCompletionRequestWithExtraBody,
|
||||
OpenAICreateVectorStoreFileBatchRequestWithExtraBody,
|
||||
OpenAICreateVectorStoreRequestWithExtraBody,
|
||||
OpenAIUserMessageParam,
|
||||
QueryChunksResponse,
|
||||
RoutingTable,
|
||||
SearchRankingOptions,
|
||||
|
|
@ -51,10 +54,11 @@ class VectorIORouter(VectorIO):
|
|||
self,
|
||||
routing_table: RoutingTable,
|
||||
vector_stores_config: VectorStoresConfig | None = None,
|
||||
inference_api: Inference | None = None,
|
||||
) -> None:
|
||||
logger.debug("Initializing VectorIORouter")
|
||||
self.routing_table = routing_table
|
||||
self.vector_stores_config = vector_stores_config
|
||||
self.inference_api = inference_api
|
||||
|
||||
async def initialize(self) -> None:
|
||||
logger.debug("VectorIORouter.initialize")
|
||||
|
|
@ -64,6 +68,46 @@ class VectorIORouter(VectorIO):
|
|||
logger.debug("VectorIORouter.shutdown")
|
||||
pass
|
||||
|
||||
async def _rewrite_query_for_search(self, query: str) -> str:
    """Rewrite a search query using the configured LLM model for better retrieval results.

    Args:
        query: The original search query text; it is substituted into the
            configured prompt template via ``str.format(query=...)``.

    Returns:
        The content of the first chat-completion choice, stripped of
        surrounding whitespace.

    Raises:
        ValueError: If ``rewrite_query_params.model`` is not configured, or
            if no inference API was given to the router.
        RuntimeError: If the chat-completion call fails or the model returns
            no content.
    """
    config = self.vector_stores_config
    params = config.rewrite_query_params if config else None
    if not params or not params.model:
        logger.warning(
            "User is trying to use vector_store query rewriting, but it is not configured. Please configure rewrite_query_params.model in vector_stores config."
        )
        raise ValueError("Query rewriting is not available")

    if not self.inference_api:
        logger.warning("Query rewriting requires inference API but it is not available")
        raise ValueError("Query rewriting is not available")

    model_id = f"{params.model.provider_id}/{params.model.model_id}"
    prompt = params.prompt.format(query=query)

    # Fall back only when the value is unset: 0.0 is a legitimate temperature
    # and a plain ``or`` would silently replace it with 0.3.
    temperature = params.temperature if params.temperature is not None else 0.3

    request = OpenAIChatCompletionRequestWithExtraBody(
        model=model_id,
        messages=[OpenAIUserMessageParam(role="user", content=prompt)],
        # An unset or zero token limit falls back to 100, as before.
        max_tokens=params.max_tokens or 100,
        temperature=temperature,
    )

    try:
        response = await self.inference_api.openai_chat_completion(request)
        content = response.choices[0].message.content
        # Stripping stays inside the try so that unexpected content types are
        # still reported as RuntimeError, like any other failure of the call.
        rewritten_query = content.strip() if content is not None else None
    except Exception as e:
        logger.error(f"Query rewrite failed with LLM call error. Model: {model_id}, Error: {e}")
        raise RuntimeError("Query rewrite failed due to an internal error") from e

    # Checked outside the try so this error is logged once and is not caught
    # and re-wrapped by the handler above.
    if rewritten_query is None:
        logger.error(f"LLM returned None content for query rewriting. Model: {model_id}")
        raise RuntimeError("Query rewrite failed due to an internal error")

    return rewritten_query
|
||||
|
||||
async def _get_embedding_model_dimension(self, embedding_model_id: str) -> int:
|
||||
"""Get the embedding dimension for a specific embedding model."""
|
||||
all_models = await self.routing_table.get_all_with_type("model")
|
||||
|
|
@ -292,14 +336,24 @@ class VectorIORouter(VectorIO):
|
|||
search_mode: str | None = "vector",
|
||||
) -> VectorStoreSearchResponsePage:
|
||||
logger.debug(f"VectorIORouter.openai_search_vector_store: {vector_store_id}")
|
||||
|
||||
# Handle query rewriting at the router level
|
||||
search_query = query
|
||||
if rewrite_query:
|
||||
if isinstance(query, list):
|
||||
original_query = " ".join(query)
|
||||
else:
|
||||
original_query = query
|
||||
search_query = await self._rewrite_query_for_search(original_query)
|
||||
|
||||
provider = await self.routing_table.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_search_vector_store(
|
||||
vector_store_id=vector_store_id,
|
||||
query=query,
|
||||
query=search_query,
|
||||
filters=filters,
|
||||
max_num_results=max_num_results,
|
||||
ranking_options=ranking_options,
|
||||
rewrite_query=rewrite_query,
|
||||
rewrite_query=False, # Already handled at router level
|
||||
search_mode=search_mode,
|
||||
)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue