From 618ccea0908bc03914f71fdb6aca5062c8f01699 Mon Sep 17 00:00:00 2001 From: Mark Campbell Date: Mon, 14 Jul 2025 14:11:34 +0100 Subject: [PATCH] feat: add input validation for search mode of rag query config (#2275) # What does this PR do? Adds input validation for mode in RAGQueryConfig. This will prevent users from inputting search modes other than `vector` and `keyword` for the time being, with `hybrid` to follow when that functionality is implemented. ## Test Plan ``` # Check out this PR and enter the LS directory uv sync --extra dev ``` Run the quickstart [example](https://llama-stack.readthedocs.io/en/latest/getting_started/#step-3-run-the-demo) Alter the Agent to include a query_config ``` agent = Agent( client, model=model_id, instructions="You are a helpful assistant", tools=[ { "name": "builtin::rag/knowledge_search", "args": { "vector_db_ids": [vector_db_id], "query_config": { "mode": "i-am-not-vector", # Test for invalid search mode "max_chunks": 6 } }, } ], ) ``` Ensure you get the following error: ``` 400: {'errors': [{'loc': ['mode'], 'msg': "Value error, mode must be either 'vector' or 'keyword' if supported by the vector_io provider", 'type': 'value_error'}]} ``` ## Running unit tests ``` uv sync --extra dev uv run pytest tests/unit/rag/test_rag_query.py -v ``` [//]: # (## Documentation) --- docs/_static/llama-stack-spec.html | 13 ++++++++++++- docs/_static/llama-stack-spec.yaml | 14 +++++++++++++- llama_stack/apis/tools/rag_tool.py | 16 +++++++++++++++- tests/unit/rag/test_rag_query.py | 12 ++++++++++++ 4 files changed, 52 insertions(+), 3 deletions(-) diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index 8021e0e55..6794d1fbb 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -14796,7 +14796,8 @@ "description": "Template for 
formatting each retrieved chunk in the context. Available placeholders: {index} (1-based chunk ordinal), {chunk.content} (chunk content string), {metadata} (chunk metadata dict). Default: \"Result {index}\\nContent: {chunk.content}\\nMetadata: {metadata}\\n\"" }, "mode": { - "type": "string", + "$ref": "#/components/schemas/RAGSearchMode", + "default": "vector", "description": "Search mode for retrieval—either \"vector\", \"keyword\", or \"hybrid\". Default \"vector\"." }, "ranker": { @@ -14831,6 +14832,16 @@ } } }, + "RAGSearchMode": { + "type": "string", + "enum": [ + "vector", + "keyword", + "hybrid" + ], + "title": "RAGSearchMode", + "description": "Search modes for RAG query retrieval: - VECTOR: Uses vector similarity search for semantic matching - KEYWORD: Uses keyword-based search for exact matching - HYBRID: Combines both vector and keyword search for better results" + }, "RRFRanker": { "type": "object", "properties": { diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index a18474646..548c5a988 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -10346,7 +10346,8 @@ components: content string), {metadata} (chunk metadata dict). Default: "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n" mode: - type: string + $ref: '#/components/schemas/RAGSearchMode' + default: vector description: >- Search mode for retrieval—either "vector", "keyword", or "hybrid". Default "vector". 
@@ -10373,6 +10374,17 @@ components: mapping: default: '#/components/schemas/DefaultRAGQueryGeneratorConfig' llm: '#/components/schemas/LLMRAGQueryGeneratorConfig' + RAGSearchMode: + type: string + enum: + - vector + - keyword + - hybrid + title: RAGSearchMode + description: >- + Search modes for RAG query retrieval: - VECTOR: Uses vector similarity search + for semantic matching - KEYWORD: Uses keyword-based search for exact matching + - HYBRID: Combines both vector and keyword search for better results RRFRanker: type: object properties: diff --git a/llama_stack/apis/tools/rag_tool.py b/llama_stack/apis/tools/rag_tool.py index d497fe1a7..cfaa49488 100644 --- a/llama_stack/apis/tools/rag_tool.py +++ b/llama_stack/apis/tools/rag_tool.py @@ -87,6 +87,20 @@ class RAGQueryGenerator(Enum): custom = "custom" +@json_schema_type +class RAGSearchMode(Enum): + """ + Search modes for RAG query retrieval: + - VECTOR: Uses vector similarity search for semantic matching + - KEYWORD: Uses keyword-based search for exact matching + - HYBRID: Combines both vector and keyword search for better results + """ + + VECTOR = "vector" + KEYWORD = "keyword" + HYBRID = "hybrid" + + @json_schema_type class DefaultRAGQueryGeneratorConfig(BaseModel): type: Literal["default"] = "default" @@ -128,7 +142,7 @@ class RAGQueryConfig(BaseModel): max_tokens_in_context: int = 4096 max_chunks: int = 5 chunk_template: str = "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n" - mode: str | None = None + mode: RAGSearchMode | None = RAGSearchMode.VECTOR ranker: Ranker | None = Field(default=None) # Only used for hybrid mode @field_validator("chunk_template") diff --git a/tests/unit/rag/test_rag_query.py b/tests/unit/rag/test_rag_query.py index b2baa744a..ad155c205 100644 --- a/tests/unit/rag/test_rag_query.py +++ b/tests/unit/rag/test_rag_query.py @@ -8,6 +8,7 @@ from unittest.mock import AsyncMock, MagicMock import pytest +from llama_stack.apis.tools.rag_tool import RAGQueryConfig from 
llama_stack.apis.vector_io import ( Chunk, ChunkMetadata, @@ -58,3 +59,14 @@ class TestRagQuery: ) assert expected_metadata_string in result.content[1].text assert result.content is not None + + async def test_query_raises_incorrect_mode(self): + with pytest.raises(ValueError): + RAGQueryConfig(mode="invalid_mode") + + @pytest.mark.asyncio + async def test_query_accepts_valid_modes(self): + RAGQueryConfig() # Test default (vector) + RAGQueryConfig(mode="vector") # Test vector + RAGQueryConfig(mode="keyword") # Test keyword + RAGQueryConfig(mode="hybrid") # Test hybrid