mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-17 02:18:13 +00:00
feat: add input validation for search mode of rag query config (#2275)
# What does this PR do? Adds input validation for mode in RagQueryConfig This will prevent users from inputting search modes other than `vector` and `keyword` for the time being with `hybrid` to follow when that functionality is implemented. ## Test Plan [Describe the tests you ran to verify your changes with result summaries. *Provide clear instructions so the plan can be easily re-executed.*] ``` # Check out this PR and enter the LS directory uv sync --extra dev ``` Run the quickstart [example](https://llama-stack.readthedocs.io/en/latest/getting_started/#step-3-run-the-demo) Alter the Agent to include a query_config ``` agent = Agent( client, model=model_id, instructions="You are a helpful assistant", tools=[ { "name": "builtin::rag/knowledge_search", "args": { "vector_db_ids": [vector_db_id], "query_config": { "mode": "i-am-not-vector", # Test for non valid search mode "max_chunks": 6 } }, } ], ) ``` Ensure you get the following error: ``` 400: {'errors': [{'loc': ['mode'], 'msg': "Value error, mode must be either 'vector' or 'keyword' if supported by the vector_io provider", 'type': 'value_error'}]} ``` ## Running unit tests ``` uv sync --extra dev uv run pytest tests/unit/rag/test_rag_query.py -v ``` [//]: # (## Documentation)
This commit is contained in:
parent
958fc92b1b
commit
618ccea090
4 changed files with 52 additions and 3 deletions
|
@ -8,6 +8,7 @@ from unittest.mock import AsyncMock, MagicMock
|
|||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.tools.rag_tool import RAGQueryConfig
|
||||
from llama_stack.apis.vector_io import (
|
||||
Chunk,
|
||||
ChunkMetadata,
|
||||
|
@ -58,3 +59,14 @@ class TestRagQuery:
|
|||
)
|
||||
assert expected_metadata_string in result.content[1].text
|
||||
assert result.content is not None
|
||||
|
||||
async def test_query_raises_incorrect_mode(self):
|
||||
with pytest.raises(ValueError):
|
||||
RAGQueryConfig(mode="invalid_mode")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_accepts_valid_modes(self):
|
||||
RAGQueryConfig() # Test default (vector)
|
||||
RAGQueryConfig(mode="vector") # Test vector
|
||||
RAGQueryConfig(mode="keyword") # Test keyword
|
||||
RAGQueryConfig(mode="hybrid") # Test hybrid
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue