fix: search mode validation for rag query (#2857)

# What does this PR do?
I noticed a few issues with my implementation of the search mode
validation for RagQuery.
This PR replaces the check for search mode in RagQuery with a Literal.
Previously, serializing the query config could fail with:
```
TypeError: Object of type RAGSearchMode is not JSON serializable
```
for example when building the config like this:
```
query_config = RAGQueryConfig(max_chunks=6, mode="vector").model_dump()
```

It also fixes a bug where "vector" was always used as the search mode,
regardless of the mode the user passed in.
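
The serialization error above can be reproduced without the library. The sketch below is a minimal illustration, assuming the dumped config dict is later passed to `json.dumps`; `PlainMode`, `StringMode`, and `dump_mode` are hypothetical stand-ins and not names from this codebase. `StrEnum` requires Python 3.11, so a rough fallback is included for older interpreters.

```python
import json
from enum import Enum

try:
    from enum import StrEnum  # available from Python 3.11
except ImportError:
    class StrEnum(str, Enum):  # rough fallback for older interpreters
        pass


class PlainMode(Enum):  # hypothetical stand-in for an Enum-based mode
    VECTOR = "vector"


class StringMode(StrEnum):  # hypothetical stand-in for a StrEnum-based mode
    VECTOR = "vector"


def dump_mode(mode):
    """Serialize a config-like dict; return the error text if that fails."""
    try:
        return json.dumps({"max_chunks": 6, "mode": mode})
    except TypeError as exc:
        return f"TypeError: {exc}"


print(dump_mode(PlainMode.VECTOR))   # plain Enum members are rejected by json
print(dump_mode(StringMode.VECTOR))  # str-based members serialize as strings
```

The first call returns the `TypeError` text, while the second returns a JSON string, because members of a `str`-based enum are themselves `str` instances.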

## Test Plan

Verify that the chosen search mode is actually applied when using RAG
query directly, or use the agent config below:
```
agent = Agent(
    client,
    model=model_id,
    instructions="You are a helpful assistant",
    tools=[
        {
            "name": "builtin::rag/knowledge_search",
            "args": {
                "vector_db_ids": [vector_db_id],
                "query_config": {
                    "mode": "keyword",
                    "max_chunks": 6
                }
            },
        }
    ],
)
```

Running Unit Tests:
```
uv sync --extra dev
uv run pytest tests/unit/rag/test_rag_query.py -v
```
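
For readers without the full environment, the behavior those unit tests assert can be sketched with the standard library alone. This is only an approximation under stated assumptions: `SearchMode` and `parse_mode` are hypothetical stand-ins, the member names `KEYWORD` and `HYBRID` are assumed from the mode strings used in the tests, and the real `RAGQueryConfig` validates through pydantic rather than through a helper like this.

```python
from enum import Enum

try:
    from enum import StrEnum  # available from Python 3.11
except ImportError:
    class StrEnum(str, Enum):  # rough fallback for older interpreters
        pass


class SearchMode(StrEnum):  # hypothetical stand-in for RAGSearchMode
    VECTOR = "vector"
    KEYWORD = "keyword"  # assumed member name
    HYBRID = "hybrid"    # assumed member name


def parse_mode(raw: str = "vector") -> SearchMode:
    """Hypothetical helper: look up a mode by value, defaulting to vector."""
    # Enum lookup by value raises ValueError for unknown strings.
    return SearchMode(raw)


for raw in ("vector", "keyword", "hybrid"):
    mode = parse_mode(raw)
    assert mode == raw  # the chosen mode is preserved, not reset to "vector"

try:
    parse_mode("wrong_mode")
except ValueError as exc:
    print(f"rejected: {exc}")
```

Because the members compare equal to plain strings, assertions of the form `config.mode == "keyword"` hold without unwrapping `.value`.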
Mark Campbell 2025-07-23 19:25:12 +01:00 committed by GitHub
parent 2aba2c1236
commit 8353ad4981
2 changed files with 14 additions and 6 deletions


@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from enum import Enum
+from enum import Enum, StrEnum
 from typing import Annotated, Any, Literal, Protocol
 from pydantic import BaseModel, Field, field_validator
@@ -88,7 +88,7 @@ class RAGQueryGenerator(Enum):
 @json_schema_type
-class RAGSearchMode(Enum):
+class RAGSearchMode(StrEnum):
     """
     Search modes for RAG query retrieval:
     - VECTOR: Uses vector similarity search for semantic matching


@@ -65,7 +65,15 @@ class TestRagQuery:
             RAGQueryConfig(mode="invalid_mode")
     async def test_query_accepts_valid_modes(self):
-        RAGQueryConfig()  # Test default (vector)
-        RAGQueryConfig(mode="vector")  # Test vector
-        RAGQueryConfig(mode="keyword")  # Test keyword
-        RAGQueryConfig(mode="hybrid")  # Test hybrid
+        default_config = RAGQueryConfig()  # Test default (vector)
+        assert default_config.mode == "vector"
+        vector_config = RAGQueryConfig(mode="vector")  # Test vector
+        assert vector_config.mode == "vector"
+        keyword_config = RAGQueryConfig(mode="keyword")  # Test keyword
+        assert keyword_config.mode == "keyword"
+        hybrid_config = RAGQueryConfig(mode="hybrid")  # Test hybrid
+        assert hybrid_config.mode == "hybrid"
+        # Test that invalid mode raises an error
+        with pytest.raises(ValueError):
+            RAGQueryConfig(mode="wrong_mode")