add a RAG test to client SDK

This commit is contained in:
Dinesh Yeduguru 2024-12-26 09:13:34 -08:00
parent c76f5f418f
commit 97798c8442
3 changed files with 105 additions and 9 deletions

View file

@ -15,6 +15,7 @@ from llama_stack_client.lib.agents.event_logger import EventLogger
from llama_stack_client.types import ToolResponseMessage
from llama_stack_client.types.agent_create_params import AgentConfig
from llama_stack_client.types.custom_tool_def import Parameter
from llama_stack_client.types.memory_insert_params import Document
from llama_stack_client.types.shared.completion_message import CompletionMessage
@ -230,3 +231,68 @@ def test_custom_tool(llama_stack_client, agent_config):
logs_str = "".join(logs)
assert "-100" in logs_str
assert "CustomTool" in logs_str
def test_rag_agent(llama_stack_client, agent_config):
urls = [
"memory_optimizations.rst",
"chat.rst",
"llama3.rst",
"datasets.rst",
"qat_finetune.rst",
"lora_finetune.rst",
]
documents = [
Document(
document_id=f"num-{i}",
content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
mime_type="text/plain",
metadata={},
)
for i, url in enumerate(urls)
]
llama_stack_client.memory_banks.register(
memory_bank_id="test_bank",
params={
"memory_bank_type": "vector",
"embedding_model": "all-MiniLM-L6-v2",
"chunk_size_in_tokens": 512,
"overlap_size_in_tokens": 64,
},
provider_id="faiss",
)
# insert some documents
llama_stack_client.memory.insert(
bank_id="test_bank",
documents=documents,
)
agent_config = {
**agent_config,
"preprocessing_tools": ["memory-tool"],
}
agent = Agent(llama_stack_client, agent_config)
session_id = agent.create_session(f"test-session-{uuid4()}")
user_prompts = [
"What are the top 5 topics that were explained in the documentation? Only list succinct bullet points.",
"Was anything related to 'Llama3' discussed, if so what?",
"Tell me how to use LoRA",
"What about Quantization?",
]
for prompt in user_prompts:
response = agent.create_turn(
messages=[
{
"role": "user",
"content": prompt,
}
],
session_id=session_id,
)
logs = [str(log) for log in EventLogger().log(response) if log is not None]
logs_str = "".join(logs)
assert "Tool:memory-tool" in logs_str