From 8353ad498117a1bcc2a40406faa7734bb18f47ea Mon Sep 17 00:00:00 2001 From: Mark Campbell Date: Wed, 23 Jul 2025 19:25:12 +0100 Subject: [PATCH] fix: search mode validation for rag query (#2857) # What does this PR do? I noticed a few issues with my implementation of the search mode validation for RagQuery. This PR replaces the check for search mode in RagQuery with a Literal. There were issues before with ``` TypeError: Object of type RAGSearchMode is not JSON serializable ``` When using ``` query_config = RAGQueryConfig(max_chunks=6, mode="vector").model_dump() ``` It also fixes the fact that despite user input "vector" was always the used search mode. ## Test Plan Verify that a chosen search mode works when using Rag Query or use below agent config: ``` agent = Agent( client, model=model_id, instructions="You are a helpful assistant", tools=[ { "name": "builtin::rag/knowledge_search", "args": { "vector_db_ids": [vector_db_id], "query_config": { "mode": "keyword", "max_chunks": 6 } }, } ], ) ``` Running Unit Tests: ``` uv sync --extra dev uv run pytest tests/unit/rag/test_rag_query.py -v ``` --- llama_stack/apis/tools/rag_tool.py | 4 ++-- tests/unit/rag/test_rag_query.py | 16 ++++++++++++---- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/llama_stack/apis/tools/rag_tool.py b/llama_stack/apis/tools/rag_tool.py index cfaa49488..1d5e7b6cb 100644 --- a/llama_stack/apis/tools/rag_tool.py +++ b/llama_stack/apis/tools/rag_tool.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from enum import Enum +from enum import Enum, StrEnum from typing import Annotated, Any, Literal, Protocol from pydantic import BaseModel, Field, field_validator @@ -88,7 +88,7 @@ class RAGQueryGenerator(Enum): @json_schema_type -class RAGSearchMode(Enum): +class RAGSearchMode(StrEnum): """ Search modes for RAG query retrieval: - VECTOR: Uses vector similarity search for semantic matching diff --git a/tests/unit/rag/test_rag_query.py b/tests/unit/rag/test_rag_query.py index a9149541a..05ccecb99 100644 --- a/tests/unit/rag/test_rag_query.py +++ b/tests/unit/rag/test_rag_query.py @@ -65,7 +65,15 @@ class TestRagQuery: RAGQueryConfig(mode="invalid_mode") async def test_query_accepts_valid_modes(self): - RAGQueryConfig() # Test default (vector) - RAGQueryConfig(mode="vector") # Test vector - RAGQueryConfig(mode="keyword") # Test keyword - RAGQueryConfig(mode="hybrid") # Test hybrid + default_config = RAGQueryConfig() # Test default (vector) + assert default_config.mode == "vector" + vector_config = RAGQueryConfig(mode="vector") # Test vector + assert vector_config.mode == "vector" + keyword_config = RAGQueryConfig(mode="keyword") # Test keyword + assert keyword_config.mode == "keyword" + hybrid_config = RAGQueryConfig(mode="hybrid") # Test hybrid + assert hybrid_config.mode == "hybrid" + + # Test that invalid mode raises an error + with pytest.raises(ValueError): + RAGQueryConfig(mode="wrong_mode")