llama-stack-mirror/examples/rag_agent.py
Antony Sallas 17e74251e2 updated with vllm based values
rh-pre-commit.version: 2.3.2
rh-pre-commit.check-secrets: ENABLED
2025-10-22 18:20:32 +08:00

80 lines
No EOL
2.2 KiB
Python

from llama_stack_client import LlamaStackClient
from llama_stack_client import Agent, AgentEventLogger
from llama_stack_client.types import Document
import uuid
# Telemetry identification passed to the client as its default headers.
telemetry_headers = {
    "X-Telemetry-Service": "llama-stack-rag-agent",
    "X-Telemetry-Version": "1.0.0",
}

# Client for the Llama Stack server listening on localhost:8321.
client = LlamaStackClient(
    base_url="http://localhost:8321",
    default_headers=telemetry_headers,
)
# Create a vector store that uses the first embedding model the server lists.
embed_lm = next(
    (m for m in client.models.list() if m.model_type == "embedding"),
    None,
)
if embed_lm is None:
    # A bare next() would surface an opaque StopIteration here; fail with an
    # actionable message instead.
    raise RuntimeError(
        "No embedding model is registered with the Llama Stack server; "
        "register one before running this example."
    )
embedding_model = embed_lm.identifier

# Name for the new store, derived from a fresh random UUID.
vector_db_name = f"v{uuid.uuid4().hex}"

# The VectorDB API is deprecated; the server now returns its own authoritative ID.
# We capture that ID from the response's .id attribute and use it from here on.
vector_store = client.vector_stores.create(
    name=vector_db_name,
    embedding_model=embedding_model,
)
vector_db_id = vector_store.id
# Tutorial pages to ingest. Each entry is a file name that gets appended to
# the torchtune tutorials URL when the Document objects are built.
urls = [
    "memory_optimizations.rst",
    "chat.rst",
    "llama3.rst",
    "qat_finetune.rst",
    "lora_finetune.rst",
]

# One Document per page: the ID is "num-<position>" and the content field
# carries the page's raw GitHub URL.
documents = [
    Document(
        document_id=f"num-{index}",
        content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{filename}",
        mime_type="text/plain",
        metadata={},
    )
    for index, filename in enumerate(urls)
]

# Hand the documents to the RAG tool, targeting the vector store created
# earlier and requesting 512-token chunks.
client.tool_runtime.rag_tool.insert(
    documents=documents,
    vector_db_id=vector_db_id,
    chunk_size_in_tokens=512,
)
# Pick the first LLM that is served through the vLLM provider.
llm = next(
    (
        m
        for m in client.models.list()
        if m.model_type == "llm" and m.provider_id == "vllm"
    ),
    None,
)
if llm is None:
    # A bare next() would surface an opaque StopIteration here; fail with an
    # actionable message instead.
    raise RuntimeError(
        "No LLM served by the 'vllm' provider is registered with the "
        "Llama Stack server; register one before running this example."
    )
model = llm.identifier
# Tool entry for the built-in knowledge search, scoped to the vector store
# identified by vector_db_id.
knowledge_search_tool = {
    "name": "builtin::rag/knowledge_search",
    "args": {"vector_db_ids": [vector_db_id]},
}

# Build the RAG agent on top of the selected model with that single tool.
rag_agent = Agent(
    client,
    model=model,
    instructions="You are a helpful assistant. Use the RAG tool to answer questions as needed.",
    tools=[knowledge_search_tool],
)
# Open a session under a random name; every turn below runs in this session.
session_name = f"s{uuid.uuid4().hex}"
session_id = rag_agent.create_session(session_name=session_name)

# Questions sent to the agent, in order.
turns = ["what is torchtune", "tell me about lora"]

for prompt in turns:
    print("user>", prompt)
    user_message = {"role": "user", "content": prompt}
    # stream=True asks for a streamed response, which is passed to the event
    # logger below.
    stream = rag_agent.create_turn(
        messages=[user_message],
        session_id=session_id,
        stream=True,
    )
    # A new AgentEventLogger is created for each turn; print every event it yields.
    for event in AgentEventLogger().log(stream):
        event.print()