diff --git a/tests/unit/rag/test_rag_query.py b/tests/unit/rag/test_rag_query.py index a9149541a..05ccecb99 100644 --- a/tests/unit/rag/test_rag_query.py +++ b/tests/unit/rag/test_rag_query.py @@ -65,7 +65,15 @@ class TestRagQuery: RAGQueryConfig(mode="invalid_mode") async def test_query_accepts_valid_modes(self): - RAGQueryConfig() # Test default (vector) - RAGQueryConfig(mode="vector") # Test vector - RAGQueryConfig(mode="keyword") # Test keyword - RAGQueryConfig(mode="hybrid") # Test hybrid + default_config = RAGQueryConfig() # Test default (vector) + assert default_config.mode == "vector" + vector_config = RAGQueryConfig(mode="vector") # Test vector + assert vector_config.mode == "vector" + keyword_config = RAGQueryConfig(mode="keyword") # Test keyword + assert keyword_config.mode == "keyword" + hybrid_config = RAGQueryConfig(mode="hybrid") # Test hybrid + assert hybrid_config.mode == "hybrid" + + # Test that invalid mode raises an error + with pytest.raises(ValueError): + RAGQueryConfig(mode="wrong_mode")