From 7f43bc6d765c3453ed9bbfeb96397d64b1ecfcaf Mon Sep 17 00:00:00 2001 From: Bobbins228 Date: Tue, 27 May 2025 10:14:41 +0100 Subject: [PATCH 1/2] feat: add input validation for search mode of rag query --- docs/_static/llama-stack-spec.html | 13 ++++++++++++- docs/_static/llama-stack-spec.yaml | 14 +++++++++++++- llama_stack/apis/tools/rag_tool.py | 16 +++++++++++++++- 3 files changed, 40 insertions(+), 3 deletions(-) diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index f9e4bb38e..a5558d718 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -14676,7 +14676,8 @@ "description": "Template for formatting each retrieved chunk in the context. Available placeholders: {index} (1-based chunk ordinal), {chunk.content} (chunk content string), {metadata} (chunk metadata dict). Default: \"Result {index}\\nContent: {chunk.content}\\nMetadata: {metadata}\\n\"" }, "mode": { - "type": "string", + "$ref": "#/components/schemas/RAGSearchMode", + "default": "vector", "description": "Search mode for retrieval—either \"vector\", \"keyword\", or \"hybrid\". Default \"vector\"." }, "ranker": { @@ -14711,6 +14712,16 @@ } } }, + "RAGSearchMode": { + "type": "string", + "enum": [ + "vector", + "keyword", + "hybrid" + ], + "title": "RAGSearchMode", + "description": "Search modes for RAG query retrieval: - VECTOR: Uses vector similarity search for semantic matching - KEYWORD: Uses keyword-based search for exact matching - HYBRID: Combines both vector and keyword search for better results" + }, "RRFRanker": { "type": "object", "properties": { diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index 9175c97fc..82b67df2a 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -10260,7 +10260,8 @@ components: content string), {metadata} (chunk metadata dict). Default: "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n" mode: - type: string + $ref: '#/components/schemas/RAGSearchMode' + default: vector description: >- Search mode for retrieval—either "vector", "keyword", or "hybrid". Default "vector". @@ -10287,6 +10288,17 @@ components: mapping: default: '#/components/schemas/DefaultRAGQueryGeneratorConfig' llm: '#/components/schemas/LLMRAGQueryGeneratorConfig' + RAGSearchMode: + type: string + enum: + - vector + - keyword + - hybrid + title: RAGSearchMode + description: >- + Search modes for RAG query retrieval: - VECTOR: Uses vector similarity search + for semantic matching - KEYWORD: Uses keyword-based search for exact matching + - HYBRID: Combines both vector and keyword search for better results RRFRanker: type: object properties: diff --git a/llama_stack/apis/tools/rag_tool.py b/llama_stack/apis/tools/rag_tool.py index d497fe1a7..cfaa49488 100644 --- a/llama_stack/apis/tools/rag_tool.py +++ b/llama_stack/apis/tools/rag_tool.py @@ -87,6 +87,20 @@ class RAGQueryGenerator(Enum): custom = "custom" +@json_schema_type +class RAGSearchMode(Enum): + """ + Search modes for RAG query retrieval: + - VECTOR: Uses vector similarity search for semantic matching + - KEYWORD: Uses keyword-based search for exact matching + - HYBRID: Combines both vector and keyword search for better results + """ + + VECTOR = "vector" + KEYWORD = "keyword" + HYBRID = "hybrid" + + @json_schema_type class DefaultRAGQueryGeneratorConfig(BaseModel): type: Literal["default"] = "default" @@ -128,7 +142,7 @@ class RAGQueryConfig(BaseModel): max_tokens_in_context: int = 4096 max_chunks: int = 5 chunk_template: str = "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n" - mode: str | None = None + mode: RAGSearchMode | None = RAGSearchMode.VECTOR ranker: Ranker | None = Field(default=None) # Only used for hybrid mode @field_validator("chunk_template") From bfef0ee0e9945fa1706cd695086bad3f88e0e4c1 Mon Sep 17 00:00:00 2001 From: Bobbins228 Date: Wed, 28 May 2025 14:34:09 +0100 Subject: [PATCH 2/2] test: add unit tests for vector db search mode validation --- tests/unit/rag/test_rag_query.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tests/unit/rag/test_rag_query.py b/tests/unit/rag/test_rag_query.py index d2dd1783b..a902609a0 100644 --- a/tests/unit/rag/test_rag_query.py +++ b/tests/unit/rag/test_rag_query.py @@ -8,6 +8,7 @@ from unittest.mock import AsyncMock, MagicMock import pytest +from llama_stack.apis.tools.rag_tool import RAGQueryConfig from llama_stack.apis.vector_io import ( Chunk, ChunkMetadata, @@ -60,3 +61,14 @@ class TestRagQuery: ) assert expected_metadata_string in result.content[1].text assert result.content is not None + + async def test_query_raises_incorrect_mode(self): + with pytest.raises(ValueError): + RAGQueryConfig(mode="invalid_mode") + + @pytest.mark.asyncio + async def test_query_accepts_valid_modes(self): + RAGQueryConfig() # Test default (vector) + RAGQueryConfig(mode="vector") # Test vector + RAGQueryConfig(mode="keyword") # Test keyword + RAGQueryConfig(mode="hybrid") # Test hybrid