mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-25 13:49:51 +00:00
fix: search mode validation for rag query (#2857)
# What does this PR do? <!-- Provide a short summary of what this PR does and why. Link to relevant issues if applicable. --> I noticed a few issues with my implementation of the search mode validation for RagQuery. This PR replaces the check for search mode in RagQuery with a Literal. There were issues before with ``` TypeError: Object of type RAGSearchMode is not JSON serializable ``` When using ``` query_config = RAGQueryConfig(max_chunks=6, mode="vector").model_dump() ``` It also fixes the fact that despite user input "vector" was always the used search mode. <!-- If resolving an issue, uncomment and update the line below --> <!-- Closes #[issue-number] --> ## Test Plan <!-- Describe the tests you ran to verify your changes with result summaries. *Provide clear instructions so the plan can be easily re-executed.* --> Verify that a chosen search mode works when using Rag Query or use below agent config: ``` agent = Agent( client, model=model_id, instructions="You are a helpful assistant", tools=[ { "name": "builtin::rag/knowledge_search", "args": { "vector_db_ids": [vector_db_id], "query_config": { "mode": "keyword", "max_chunks": 6 } }, } ], ) ``` Running Unit Tests: ``` uv sync --extra dev uv run pytest tests/unit/rag/test_rag_query.py -v ```
This commit is contained in:
parent
2aba2c1236
commit
8353ad4981
2 changed files with 14 additions and 6 deletions
|
@ -65,7 +65,15 @@ class TestRagQuery:
|
|||
RAGQueryConfig(mode="invalid_mode")
|
||||
|
||||
async def test_query_accepts_valid_modes(self):
|
||||
RAGQueryConfig() # Test default (vector)
|
||||
RAGQueryConfig(mode="vector") # Test vector
|
||||
RAGQueryConfig(mode="keyword") # Test keyword
|
||||
RAGQueryConfig(mode="hybrid") # Test hybrid
|
||||
default_config = RAGQueryConfig() # Test default (vector)
|
||||
assert default_config.mode == "vector"
|
||||
vector_config = RAGQueryConfig(mode="vector") # Test vector
|
||||
assert vector_config.mode == "vector"
|
||||
keyword_config = RAGQueryConfig(mode="keyword") # Test keyword
|
||||
assert keyword_config.mode == "keyword"
|
||||
hybrid_config = RAGQueryConfig(mode="hybrid") # Test hybrid
|
||||
assert hybrid_config.mode == "hybrid"
|
||||
|
||||
# Test that invalid mode raises an error
|
||||
with pytest.raises(ValueError):
|
||||
RAGQueryConfig(mode="wrong_mode")
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue