mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-11 19:56:03 +00:00
80 lines
No EOL
2.2 KiB
Python
80 lines
No EOL
2.2 KiB
Python
from llama_stack_client import LlamaStackClient
|
|
from llama_stack_client import Agent, AgentEventLogger
|
|
from llama_stack_client.types import Document
|
|
import uuid
|
|
|
|
# Headers passed as the client's defaults; they tag this script's requests
# with a telemetry service name and version.
telemetry_headers = {
    "X-Telemetry-Service": "llama-stack-rag-agent",
    "X-Telemetry-Version": "1.0.0",
}

# Client for the Llama Stack server listening on localhost:8321.
client = LlamaStackClient(
    base_url="http://localhost:8321",
    default_headers=telemetry_headers,
)
|
|
|
|
# Create a vector store that uses the first embedding model the server lists.
embed_lm = next(
    (m for m in client.models.list() if m.model_type == "embedding"),
    None,
)
if embed_lm is None:
    # Fail with an actionable message instead of a bare StopIteration.
    raise RuntimeError(
        "No embedding model is registered with the Llama Stack server; "
        "cannot create a vector store."
    )
embedding_model = embed_lm.identifier

# Random, per-run name so repeated runs do not collide.
vector_db_name = f"v{uuid.uuid4().hex}"

# The VectorDB API is deprecated; the server now returns its own authoritative ID.
# We capture that ID from the response's .id attribute and use it for every
# later call instead of the name chosen above.
vector_store = client.vector_stores.create(
    name=vector_db_name,
    embedding_model=embedding_model,
)
vector_db_id = vector_store.id
|
|
|
|
# Create Documents
# Despite the name, these are file names; each one is appended to the
# torchtune tutorials base URL when its Document is built.
urls = [
    "memory_optimizations.rst",
    "chat.rst",
    "llama3.rst",
    "qat_finetune.rst",
    "lora_finetune.rst",
]


def _tutorial_document(index, file_name):
    """Return the Document describing one torchtune tutorial file.

    The document id is ``num-<index>`` and the content is the raw GitHub URL
    of ``file_name`` (presumably fetched server-side at ingestion; confirm).
    """
    return Document(
        document_id=f"num-{index}",
        content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{file_name}",
        mime_type="text/plain",
        metadata={},
    )


documents = [
    _tutorial_document(position, name) for position, name in enumerate(urls)
]

# Insert documents into the vector store, chunked at 512 tokens.
client.tool_runtime.rag_tool.insert(
    vector_db_id=vector_db_id,
    documents=documents,
    chunk_size_in_tokens=512,
)
|
|
|
|
# Get the model being served: the first LLM registered under the "vllm" provider.
llm = next(
    (
        m
        for m in client.models.list()
        if m.model_type == "llm" and m.provider_id == "vllm"
    ),
    None,
)
if llm is None:
    # Fail with an actionable message instead of a bare StopIteration.
    raise RuntimeError(
        "No LLM served by the 'vllm' provider is registered with the "
        "Llama Stack server; cannot create the RAG agent."
    )
model = llm.identifier
|
|
|
|
# Create the RAG agent
# Tool entry for the built-in knowledge search, pointed at the vector store
# created earlier in this script.
knowledge_search_tool = {
    "name": "builtin::rag/knowledge_search",
    "args": {"vector_db_ids": [vector_db_id]},
}

rag_agent = Agent(
    client,
    model=model,
    instructions="You are a helpful assistant. Use the RAG tool to answer questions as needed.",
    tools=[knowledge_search_tool],
)
|
|
|
|
# Open a session under a random, per-run name.
session_name = f"s{uuid.uuid4().hex}"
session_id = rag_agent.create_session(session_name=session_name)

# User prompts sent to the agent, one turn each, in order.
turns = ["what is torchtune", "tell me about lora"]

for prompt in turns:
    print("user>", prompt)
    user_message = {"role": "user", "content": prompt}
    stream = rag_agent.create_turn(
        messages=[user_message],
        session_id=session_id,
        stream=True,
    )
    # A fresh logger per turn turns the streamed response into printable events.
    for event in AgentEventLogger().log(stream):
        event.print()