mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-22 22:22:32 +00:00
add a RAG test to client SDK
This commit is contained in:
parent
c76f5f418f
commit
97798c8442
3 changed files with 105 additions and 9 deletions
|
|
@ -15,6 +15,7 @@ from llama_stack_client.lib.agents.event_logger import EventLogger
|
|||
from llama_stack_client.types import ToolResponseMessage
|
||||
from llama_stack_client.types.agent_create_params import AgentConfig
|
||||
from llama_stack_client.types.custom_tool_def import Parameter
|
||||
from llama_stack_client.types.memory_insert_params import Document
|
||||
from llama_stack_client.types.shared.completion_message import CompletionMessage
|
||||
|
||||
|
||||
|
|
@ -230,3 +231,68 @@ def test_custom_tool(llama_stack_client, agent_config):
|
|||
logs_str = "".join(logs)
|
||||
assert "-100" in logs_str
|
||||
assert "CustomTool" in logs_str
|
||||
|
||||
|
||||
def test_rag_agent(llama_stack_client, agent_config):
    """Check that an agent with the memory tool invokes it on every turn.

    The test registers a vector memory bank named ``test_bank`` with the
    ``faiss`` provider, inserts six documents whose content is the URL of a
    torchtune tutorial page, and creates an agent whose config adds
    ``memory-tool`` as a preprocessing tool. It then sends four user prompts
    in one session and asserts that the event log of each turn contains
    ``Tool:memory-tool``.

    Args:
        llama_stack_client: Client fixture used to register the bank, insert
            documents and drive the agent.
        agent_config: Base agent configuration fixture; it is copied and
            extended here, not mutated.
    """
    urls = [
        "memory_optimizations.rst",
        "chat.rst",
        "llama3.rst",
        "datasets.rst",
        "qat_finetune.rst",
        "lora_finetune.rst",
    ]

    # One Document per tutorial page; the content field carries the page URL.
    documents = []
    for i, url in enumerate(urls):
        documents.append(
            Document(
                document_id=f"num-{i}",
                content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
                mime_type="text/plain",
                metadata={},
            )
        )

    bank_params = {
        "memory_bank_type": "vector",
        "embedding_model": "all-MiniLM-L6-v2",
        "chunk_size_in_tokens": 512,
        "overlap_size_in_tokens": 64,
    }
    llama_stack_client.memory_banks.register(
        memory_bank_id="test_bank",
        params=bank_params,
        provider_id="faiss",
    )

    # Populate the freshly registered bank.
    llama_stack_client.memory.insert(bank_id="test_bank", documents=documents)

    # Extend a copy of the fixture config so the agent runs the memory tool
    # as a preprocessing step.
    rag_config = {**agent_config, "preprocessing_tools": ["memory-tool"]}
    agent = Agent(llama_stack_client, rag_config)
    session_id = agent.create_session(f"test-session-{uuid4()}")

    user_prompts = [
        "What are the top 5 topics that were explained in the documentation? Only list succinct bullet points.",
        "Was anything related to 'Llama3' discussed, if so what?",
        "Tell me how to use LoRA",
        "What about Quantization?",
    ]

    # All prompts share one session; every turn must show the memory tool in
    # its event log.
    for prompt in user_prompts:
        user_message = {"role": "user", "content": prompt}
        response = agent.create_turn(
            messages=[user_message],
            session_id=session_id,
        )

        logs_str = "".join(
            str(event) for event in EventLogger().log(response) if event is not None
        )
        assert "Tool:memory-tool" in logs_str
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue