diff --git a/tests/unit/rag/test_rag_query.py b/tests/unit/rag/test_rag_query.py index d2dd1783b..a902609a0 100644 --- a/tests/unit/rag/test_rag_query.py +++ b/tests/unit/rag/test_rag_query.py @@ -8,6 +8,7 @@ from unittest.mock import AsyncMock, MagicMock import pytest +from llama_stack.apis.tools.rag_tool import RAGQueryConfig from llama_stack.apis.vector_io import ( Chunk, ChunkMetadata, @@ -60,3 +61,14 @@ class TestRagQuery: ) assert expected_metadata_string in result.content[1].text assert result.content is not None + + async def test_query_raises_incorrect_mode(self): + with pytest.raises(ValueError): + RAGQueryConfig(mode="invalid_mode") + + @pytest.mark.asyncio + async def test_query_accepts_valid_modes(self): + RAGQueryConfig() # Test default (vector) + RAGQueryConfig(mode="vector") # Test vector + RAGQueryConfig(mode="keyword") # Test keyword + RAGQueryConfig(mode="hybrid") # Test hybrid