mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-27 18:50:41 +00:00
See https://github.com/meta-llama/llama-stack/issues/827 for the broader design. Third part: - We need to make the `tool_runtime.rag_tool.query_context()` and `tool_runtime.rag_tool.insert_documents()` methods work smoothly with complete type safety. To that end, we introduce a sub-resource path `tool-runtime/rag-tool/` and make changes to the resolver to make things work. - The PR updates the agents implementation to directly call these typed APIs for memory accesses rather than going through the complex, untyped "invoke_tool" API. The code looks much nicer and simpler (as expected). - There are still a number of hacks in the server resolver implementation; we will live with some and fix others. Note that we must make sure the client SDKs are also able to handle this subresource complexity. Stainless has support for subresources, so this should be possible, but beware. ## Test Plan Our RAG test is sad (it doesn't actually test for actual RAG output), but I verified that the implementation works. I will work on fixing the RAG test afterwards. ```bash pytest -s -v tests/agents/test_agents.py -k "rag and together" --safety-shield=meta-llama/Llama-Guard-3-8B ```
105 lines
3.4 KiB
Python
105 lines
3.4 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
import json
|
|
from typing import List
|
|
|
|
import pytest
|
|
import requests
|
|
from pydantic import TypeAdapter
|
|
|
|
from llama_stack.apis.tools import (
|
|
DefaultRAGQueryGeneratorConfig,
|
|
RAGDocument,
|
|
RAGQueryConfig,
|
|
RAGQueryResult,
|
|
)
|
|
from llama_stack.apis.vector_dbs import VectorDB
|
|
from llama_stack.providers.utils.memory.vector_store import interleaved_content_as_str
|
|
|
|
|
|
class TestRAGToolEndpoints:
    """End-to-end check of the RAG tool HTTP endpoints.

    The test issues plain HTTP requests against a server expected to be
    listening at the URL returned by the ``base_url`` fixture; it does not
    start that server itself.
    """

    @pytest.fixture
    def base_url(self) -> str:
        """Return the root URL that all requests in this class are sent to."""
        return "http://localhost:8321/v1"  # Adjust port if needed

    @pytest.fixture
    def sample_documents(self) -> List[RAGDocument]:
        """Return three small documents, one of which mentions "Python"."""
        return [
            RAGDocument(
                document_id="doc1",
                content="Python is a high-level programming language.",
                metadata={"category": "programming", "difficulty": "beginner"},
            ),
            RAGDocument(
                document_id="doc2",
                content="Machine learning is a subset of artificial intelligence.",
                metadata={"category": "AI", "difficulty": "advanced"},
            ),
            RAGDocument(
                document_id="doc3",
                content="Data structures are fundamental to computer science.",
                metadata={"category": "computer science", "difficulty": "intermediate"},
            ),
        ]

    @pytest.mark.asyncio
    async def test_rag_workflow(
        self, base_url: str, sample_documents: List[RAGDocument]
    ):
        """Register a vector DB, insert documents, query them, then delete the DB.

        Steps:
          1. POST ``/vector-dbs`` to register ``test_vector_db``.
          2. POST ``/tool-runtime/rag-tool/insert-documents`` with the samples.
          3. POST ``/tool-runtime/rag-tool/query-context`` and check that the
             returned context is non-empty and mentions "Python".
          4. DELETE the vector DB.

        Step 4 is attempted even when step 2 or 3 fails, so a failing run
        does not leave the vector DB registered on the server.
        """
        # NOTE(review): the body only makes blocking `requests` calls and never
        # awaits; the coroutine form and asyncio marker are kept as they were.
        vector_db_payload = {
            "vector_db_id": "test_vector_db",
            "embedding_model": "all-MiniLM-L6-v2",
            "embedding_dimension": 384,
        }

        response = requests.post(f"{base_url}/vector-dbs", json=vector_db_payload)
        assert response.status_code == 200
        vector_db = VectorDB(**response.json())

        try:
            # Round-trip through JSON so each document becomes a plain,
            # JSON-serializable dict that `requests` can send.
            insert_payload = {
                "documents": [
                    json.loads(doc.model_dump_json()) for doc in sample_documents
                ],
                "vector_db_id": vector_db.identifier,
                "chunk_size_in_tokens": 512,
            }

            response = requests.post(
                f"{base_url}/tool-runtime/rag-tool/insert-documents",
                json=insert_payload,
            )
            assert response.status_code == 200

            query = "What is Python?"
            query_config = RAGQueryConfig(
                query_generator_config=DefaultRAGQueryGeneratorConfig(),
                max_tokens_in_context=4096,
                max_chunks=2,
            )

            query_payload = {
                "content": query,
                "query_config": json.loads(query_config.model_dump_json()),
                "vector_db_ids": [vector_db.identifier],
            }

            response = requests.post(
                f"{base_url}/tool-runtime/rag-tool/query-context",
                json=query_payload,
            )
            assert response.status_code == 200

            # Validate the raw JSON against the typed result model before use.
            raw_result = response.json()
            result = TypeAdapter(RAGQueryResult).validate_python(raw_result)

            content_str = interleaved_content_as_str(result.content)
            print(f"content: {content_str}")
            assert len(content_str) > 0
            assert "Python" in content_str
        finally:
            # Clean up: Delete the vector DB, whether or not the steps above
            # succeeded.
            delete_response = requests.delete(
                f"{base_url}/vector-dbs/{vector_db.identifier}"
            )

        # Checked outside `finally` so a cleanup failure cannot replace an
        # assertion error raised by the workflow itself.
        assert delete_response.status_code == 200