diff --git a/llama_stack/scripts/test_rag_via_curl.py b/llama_stack/scripts/test_rag_via_curl.py new file mode 100644 index 000000000..28d6fb601 --- /dev/null +++ b/llama_stack/scripts/test_rag_via_curl.py @@ -0,0 +1,105 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import json +from typing import List + +import pytest +import requests +from pydantic import TypeAdapter + +from llama_stack.apis.tools import ( + DefaultRAGQueryGeneratorConfig, + RAGDocument, + RAGQueryConfig, + RAGQueryResult, +) +from llama_stack.apis.vector_dbs import VectorDB +from llama_stack.providers.utils.memory.vector_store import interleaved_content_as_str + + +class TestRAGToolEndpoints: + @pytest.fixture + def base_url(self) -> str: + return "http://localhost:8321/v1" # Adjust port if needed + + @pytest.fixture + def sample_documents(self) -> List[RAGDocument]: + return [ + RAGDocument( + document_id="doc1", + content="Python is a high-level programming language.", + metadata={"category": "programming", "difficulty": "beginner"}, + ), + RAGDocument( + document_id="doc2", + content="Machine learning is a subset of artificial intelligence.", + metadata={"category": "AI", "difficulty": "advanced"}, + ), + RAGDocument( + document_id="doc3", + content="Data structures are fundamental to computer science.", + metadata={"category": "computer science", "difficulty": "intermediate"}, + ), + ] + + @pytest.mark.asyncio + async def test_rag_workflow( + self, base_url: str, sample_documents: List[RAGDocument] + ): + vector_db_payload = { + "vector_db_id": "test_vector_db", + "embedding_model": "all-MiniLM-L6-v2", + "embedding_dimension": 384, + } + + response = requests.post(f"{base_url}/vector-dbs", json=vector_db_payload) + assert response.status_code == 200 + vector_db = VectorDB(**response.json()) + + insert_payload = { + "documents": [ + json.loads(doc.model_dump_json()) for doc in sample_documents + ], + "vector_db_id": vector_db.identifier, + "chunk_size_in_tokens": 512, + } + + response = requests.post( + f"{base_url}/tool-runtime/rag-tool/insert-documents", + json=insert_payload, + ) + assert response.status_code == 200 + + query = "What is Python?" + query_config = RAGQueryConfig( + query_generator_config=DefaultRAGQueryGeneratorConfig(), + max_tokens_in_context=4096, + max_chunks=2, + ) + + query_payload = { + "content": query, + "query_config": json.loads(query_config.model_dump_json()), + "vector_db_ids": [vector_db.identifier], + } + + response = requests.post( + f"{base_url}/tool-runtime/rag-tool/query-context", + json=query_payload, + ) + assert response.status_code == 200 + result = response.json() + result = TypeAdapter(RAGQueryResult).validate_python(result) + + content_str = interleaved_content_as_str(result.content) + print(f"content: {content_str}") + assert len(content_str) > 0 + assert "Python" in content_str + + # Clean up: Delete the vector DB + response = requests.delete(f"{base_url}/vector-dbs/{vector_db.identifier}") + assert response.status_code == 200