diff --git a/llama_stack/apis/tools/rag_tool.py b/llama_stack/apis/tools/rag_tool.py index cfaa49488..1d5e7b6cb 100644 --- a/llama_stack/apis/tools/rag_tool.py +++ b/llama_stack/apis/tools/rag_tool.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from enum import Enum +from enum import Enum, StrEnum from typing import Annotated, Any, Literal, Protocol from pydantic import BaseModel, Field, field_validator @@ -88,7 +88,7 @@ class RAGQueryGenerator(Enum): @json_schema_type -class RAGSearchMode(Enum): +class RAGSearchMode(StrEnum): """ Search modes for RAG query retrieval: - VECTOR: Uses vector similarity search for semantic matching diff --git a/tests/unit/rag/test_rag_query.py b/tests/unit/rag/test_rag_query.py index a9149541a..05ccecb99 100644 --- a/tests/unit/rag/test_rag_query.py +++ b/tests/unit/rag/test_rag_query.py @@ -65,7 +65,15 @@ class TestRagQuery: RAGQueryConfig(mode="invalid_mode") async def test_query_accepts_valid_modes(self): - RAGQueryConfig() # Test default (vector) - RAGQueryConfig(mode="vector") # Test vector - RAGQueryConfig(mode="keyword") # Test keyword - RAGQueryConfig(mode="hybrid") # Test hybrid + default_config = RAGQueryConfig() # Test default (vector) + assert default_config.mode == "vector" + vector_config = RAGQueryConfig(mode="vector") # Test vector + assert vector_config.mode == "vector" + keyword_config = RAGQueryConfig(mode="keyword") # Test keyword + assert keyword_config.mode == "keyword" + hybrid_config = RAGQueryConfig(mode="hybrid") # Test hybrid + assert hybrid_config.mode == "hybrid" + + # Test that invalid mode raises an error + with pytest.raises(ValueError): + RAGQueryConfig(mode="wrong_mode")