Merge branch 'main' into export_agent_dataset

Xi Yan 2025-02-26 14:55:50 -08:00
commit 48174b5422
16 changed files with 210 additions and 213 deletions


@@ -803,7 +803,7 @@
}
],
"source": [
"model_id = \"meta-llama/Llama-3.1-70B-Instruct\"\n",
"model_id = \"meta-llama/Llama-3.3-70B-Instruct\"\n",
"\n",
"model_id\n"
]
@@ -1688,7 +1688,7 @@
" enable_session_persistence=False,\n",
" toolgroups = [\n",
" {\n",
" \"name\": \"builtin::rag\",\n",
" \"name\": \"builtin::rag/knowledge_search\",\n",
" \"args\" : {\n",
" \"vector_db_ids\": [vector_db_id],\n",
" }\n",


@@ -3,6 +3,8 @@
{
"cell_type": "markdown",
"source": [
"[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/meta-llama/llama-stack/blob/main/docs/notebooks/Alpha_Llama_Stack_Post_Training.ipynb)\n",
"\n",
"# [Alpha] Llama Stack Post Training\n",
"This notebook will use a real world problem (improve LLM as tax preparer) to walk through the main sets of APIs we offer with Llama stack for post training to improve the LLM performance for agentic apps (We support supervised finetune now, RLHF and knowledge distillation will come soon!).\n",
"\n",
@@ -64,7 +66,7 @@
"output_type": "stream",
"name": "stdout",
"text": [
"Collecting git+https://github.com/meta-llama/llama-stack.git@hf_format_checkpointer\n",
"Collecting git+https://github.com/meta-llama/llama-stack.git\n",
" Cloning https://github.com/meta-llama/llama-stack.git (to revision hf_format_checkpointer) to /tmp/pip-req-build-j_1bxqzm\n",
" Running command git clone --filter=blob:none --quiet https://github.com/meta-llama/llama-stack.git /tmp/pip-req-build-j_1bxqzm\n",
" Running command git checkout -b hf_format_checkpointer --track origin/hf_format_checkpointer\n",
@@ -76,7 +78,7 @@
}
],
"source": [
"!pip install git+https://github.com/meta-llama/llama-stack.git@hf_format_checkpointer"
"!pip install git+https://github.com/meta-llama/llama-stack.git #TODO: update this after the next pkg release"
]
},
{


@@ -7,12 +7,12 @@ Each agent turn follows these key steps:
1. **Initial Safety Check**: The user's input is first screened through configured safety shields
2. **Context Retrieval**:
- If RAG is enabled, the agent queries relevant documents from memory banks
- For new documents, they are first inserted into the memory bank
- Retrieved context is augmented to the user's prompt
- If RAG is enabled, the agent can choose to query relevant documents from memory banks. You can use the `instructions` field to steer the agent.
- For new documents, they are first inserted into the memory bank.
- Retrieved context is provided to the LLM as a tool response in the message history.
3. **Inference Loop**: The agent enters its main execution loop:
- The LLM receives the augmented prompt (with context and/or previous tool outputs)
- The LLM receives a user prompt (with previous tool outputs)
- The LLM generates a response, potentially with tool calls
- If tool calls are present:
- Tool inputs are safety-checked
@@ -40,19 +40,16 @@ sequenceDiagram
S->>E: Input Safety Check
deactivate S
E->>M: 2.1 Query Context
M-->>E: 2.2 Retrieved Documents
loop Inference Loop
E->>L: 3.1 Augment with Context
L-->>E: 3.2 Response (with/without tool calls)
E->>L: 2.1 Augment with Context
L-->>E: 2.2 Response (with/without tool calls)
alt Has Tool Calls
E->>S: Check Tool Input
S->>T: 4.1 Execute Tool
T-->>E: 4.2 Tool Response
E->>L: 5.1 Tool Response
L-->>E: 5.2 Synthesized Response
S->>T: 3.1 Execute Tool
T-->>E: 3.2 Tool Response
E->>L: 4.1 Tool Response
L-->>E: 4.2 Synthesized Response
end
opt Stop Conditions
@@ -64,7 +61,7 @@ sequenceDiagram
end
E->>S: Output Safety Check
S->>U: 6. Final Response
S->>U: 5. Final Response
```
Each step in this process can be monitored and controlled through configurations. Here's an example that demonstrates monitoring the agent's execution:
@@ -77,7 +74,10 @@ agent_config = AgentConfig(
instructions="You are a helpful assistant",
# Enable both RAG and tool usage
toolgroups=[
{"name": "builtin::rag", "args": {"vector_db_ids": ["my_docs"]}},
{
"name": "builtin::rag/knowledge_search",
"args": {"vector_db_ids": ["my_docs"]},
},
"builtin::code_interpreter",
],
# Configure safety
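
To make the new flow concrete, here is a minimal sketch (not part of this commit) of driving one turn against such a config and watching retrieval surface as a regular `knowledge_search` tool execution. The base URL, session name, and prompt are assumptions; the step and field names follow the integration tests changed later in this diff.

```python
from llama_stack_client import LlamaStackClient
from llama_stack_client.lib.agents.agent import Agent

client = LlamaStackClient(base_url="http://localhost:8321")  # assumed server address
agent = Agent(client, agent_config)  # agent_config as defined above
session_id = agent.create_session("monitoring-demo")

response = agent.create_turn(
    messages=[{"role": "user", "content": "What attention type does Llama3-8B use?"}],
    session_id=session_id,
    stream=False,
)

# Retrieval now appears as an ordinary tool_execution step rather than
# silently augmented prompt context.
for step in response.steps:
    if step.step_type == "tool_execution":
        print(step.tool_calls[0].tool_name)     # "knowledge_search"
        print(step.tool_responses[0].metadata)  # {"document_ids": [...]}
```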


@@ -91,7 +91,7 @@ agent_config = AgentConfig(
enable_session_persistence=False,
toolgroups=[
{
"name": "builtin::rag",
"name": "builtin::rag/knowledge_search",
"args": {
"vector_db_ids": [vector_db_id],
},


@@ -13,6 +13,13 @@
# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
from docutils import nodes
import tomli # Import tomli for TOML parsing
from pathlib import Path
# Read version from pyproject.toml
with Path(__file__).parent.parent.parent.joinpath("pyproject.toml").open("rb") as f:
pyproject = tomli.load(f)
llama_stack_version = pyproject["project"]["version"]
project = "llama-stack"
copyright = "2025, Meta"
@@ -66,6 +73,7 @@ myst_enable_extensions = [
myst_substitutions = {
"docker_hub": "https://hub.docker.com/repository/docker/llamastack",
"llama_stack_version": llama_stack_version,
}
suppress_warnings = ['myst.header']


@@ -243,7 +243,7 @@ agent_config = AgentConfig(
# Define tools available to the agent
toolgroups=[
{
"name": "builtin::rag",
"name": "builtin::rag/knowledge_search",
"args": {
"vector_db_ids": [vector_db_id],
},


@@ -1,8 +1,7 @@
```{admonition} News
:class: tip
Llama Stack 0.1.4 is now available! See the [release notes](https://github.com/meta-llama/llama-stack/releases/tag/v0.1.4) for more details.
Llama Stack {{ llama_stack_version }} is now available! See the [release notes](https://github.com/meta-llama/llama-stack/releases/tag/v{{ llama_stack_version }}) for more details.
```
# Llama Stack


@@ -441,7 +441,7 @@ class ToolRuntimeRouter(ToolRuntime):
vector_db_ids: List[str],
query_config: Optional[RAGQueryConfig] = None,
) -> RAGQueryResult:
return await self.routing_table.get_provider_impl("query_from_memory").query(
return await self.routing_table.get_provider_impl("knowledge_search").query(
content, vector_db_ids, query_config
)


@@ -33,7 +33,7 @@ class DistributionRegistry(Protocol):
REGISTER_PREFIX = "distributions:registry"
KEY_VERSION = "v7"
KEY_VERSION = "v8"
KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}"


@@ -132,7 +132,7 @@ def rag_chat_page():
},
toolgroups=[
dict(
name="builtin::rag",
name="builtin::rag/knowledge_search",
args={
"vector_db_ids": [vector_db_id for vector_db_id in selected_vector_dbs],
},


@@ -17,7 +17,6 @@ from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple
from urllib.parse import urlparse
import httpx
from pydantic import TypeAdapter
from llama_stack.apis.agents import (
AgentConfig,
@@ -62,7 +61,7 @@ from llama_stack.apis.inference import (
UserMessage,
)
from llama_stack.apis.safety import Safety
from llama_stack.apis.tools import RAGDocument, RAGQueryConfig, ToolGroups, ToolInvocationResult, ToolRuntime
from llama_stack.apis.tools import RAGDocument, ToolGroups, ToolInvocationResult, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
from llama_stack.models.llama.datatypes import (
BuiltinTool,
@@ -70,7 +69,6 @@ from llama_stack.models.llama.datatypes import (
ToolParamDefinition,
)
from llama_stack.providers.utils.kvstore import KVStore
from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content
from llama_stack.providers.utils.telemetry import tracing
from .persistence import AgentPersistence
@@ -84,7 +82,7 @@ def make_random_string(length: int = 8):
TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})")
MEMORY_QUERY_TOOL = "query_from_memory"
MEMORY_QUERY_TOOL = "knowledge_search"
WEB_SEARCH_TOOL = "web_search"
RAG_TOOL_GROUP = "builtin::rag"
@@ -499,111 +497,18 @@ class ChatAgent(ShieldRunnerMixin):
# TODO: simplify all of this code, it can be simpler
toolgroup_args = {}
toolgroups = set()
for toolgroup in self.agent_config.toolgroups:
for toolgroup in self.agent_config.toolgroups + (toolgroups_for_turn or []):
if isinstance(toolgroup, AgentToolGroupWithArgs):
toolgroups.add(toolgroup.name)
toolgroup_args[toolgroup.name] = toolgroup.args
tool_group_name, tool_name = self._parse_toolgroup_name(toolgroup.name)
toolgroups.add(tool_group_name)
toolgroup_args[tool_group_name] = toolgroup.args
else:
toolgroups.add(toolgroup)
if toolgroups_for_turn:
for toolgroup in toolgroups_for_turn:
if isinstance(toolgroup, AgentToolGroupWithArgs):
toolgroups.add(toolgroup.name)
toolgroup_args[toolgroup.name] = toolgroup.args
else:
toolgroups.add(toolgroup)
tool_defs, tool_to_group = await self._get_tool_defs(toolgroups_for_turn)
if documents:
await self.handle_documents(session_id, documents, input_messages, tool_defs)
if RAG_TOOL_GROUP in toolgroups and len(input_messages) > 0:
with tracing.span(MEMORY_QUERY_TOOL) as span:
step_id = str(uuid.uuid4())
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepStartPayload(
step_type=StepType.tool_execution.value,
step_id=step_id,
)
)
)
args = toolgroup_args.get(RAG_TOOL_GROUP, {})
vector_db_ids = args.get("vector_db_ids", [])
query_config = args.get("query_config")
if query_config:
query_config = TypeAdapter(RAGQueryConfig).validate_python(query_config)
else:
# handle someone passing an empty dict
query_config = RAGQueryConfig()
session_info = await self.storage.get_session_info(session_id)
# if the session has a memory bank id, let the memory tool use it
if session_info.vector_db_id:
vector_db_ids.append(session_info.vector_db_id)
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepProgressPayload(
step_type=StepType.tool_execution.value,
step_id=step_id,
delta=ToolCallDelta(
parse_status=ToolCallParseStatus.succeeded,
tool_call=ToolCall(
call_id="",
tool_name=MEMORY_QUERY_TOOL,
arguments={},
),
),
)
)
)
result = await self.tool_runtime_api.rag_tool.query(
content=concat_interleaved_content([msg.content for msg in input_messages]),
vector_db_ids=vector_db_ids,
query_config=query_config,
)
retrieved_context = result.content
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload(
step_type=StepType.tool_execution.value,
step_id=step_id,
step_details=ToolExecutionStep(
step_id=step_id,
turn_id=turn_id,
tool_calls=[
ToolCall(
call_id="",
tool_name=MEMORY_QUERY_TOOL,
arguments={},
)
],
tool_responses=[
ToolResponse(
call_id="",
tool_name=MEMORY_QUERY_TOOL,
content=retrieved_context or [],
metadata=result.metadata,
)
],
),
)
)
)
span.set_attribute("input", [m.model_dump_json() for m in input_messages])
span.set_attribute("output", retrieved_context)
span.set_attribute("tool_name", MEMORY_QUERY_TOOL)
# append retrieved_context to the last user message
for message in input_messages[::-1]:
if isinstance(message, UserMessage):
message.context = retrieved_context
break
output_attachments = []
n_iter = 0
@@ -631,9 +536,7 @@ class ChatAgent(ShieldRunnerMixin):
async for chunk in await self.inference_api.chat_completion(
self.agent_config.model,
input_messages,
tools=[
tool for tool in tool_defs.values() if tool_to_group.get(tool.tool_name, None) != RAG_TOOL_GROUP
],
tools=tool_defs,
tool_prompt_format=self.agent_config.tool_config.tool_prompt_format,
response_format=self.agent_config.response_format,
stream=True,
@@ -837,7 +740,7 @@ class ChatAgent(ShieldRunnerMixin):
)
],
started_at=tool_execution_start_time,
completed_at=datetime.now(),
completed_at=datetime.now().astimezone().isoformat(),
),
)
)
@@ -845,8 +748,9 @@
# TODO: add tool-input touchpoint and a "start" event for this step also
# but that needs a lot more refactoring of Tool code potentially
if out_attachment := _interpret_content_as_attachment(result_message.content):
if (type(result_message.content) is str) and (
out_attachment := _interpret_content_as_attachment(result_message.content)
):
# NOTE: when we push this message back to the model, the model may ignore the
# attached file path etc. since the model is trained to only provide a user message
# with the summary. We keep all generated attachments and then attach them to final message
@@ -858,7 +762,7 @@ class ChatAgent(ShieldRunnerMixin):
async def _get_tool_defs(
self, toolgroups_for_turn: Optional[List[AgentToolGroup]] = None
) -> Tuple[Dict[str, ToolDefinition], Dict[str, str]]:
) -> Tuple[List[ToolDefinition], Dict[str, str]]:
# Determine which tools to include
agent_config_toolgroups = set(
(toolgroup.name if isinstance(toolgroup, AgentToolGroupWithArgs) else toolgroup)
@@ -873,13 +777,13 @@ class ChatAgent(ShieldRunnerMixin):
}
)
tool_def_map = {}
tool_name_to_def = {}
tool_to_group = {}
for tool_def in self.agent_config.client_tools:
if tool_def_map.get(tool_def.name, None):
if tool_name_to_def.get(tool_def.name, None):
raise ValueError(f"Tool {tool_def.name} already exists")
tool_def_map[tool_def.name] = ToolDefinition(
tool_name_to_def[tool_def.name] = ToolDefinition(
tool_name=tool_def.name,
description=tool_def.description,
parameters={
@@ -893,10 +797,17 @@
},
)
tool_to_group[tool_def.name] = "__client_tools__"
for toolgroup_name in agent_config_toolgroups:
if toolgroup_name not in toolgroups_for_turn_set:
for toolgroup_name_with_maybe_tool_name in agent_config_toolgroups:
if toolgroup_name_with_maybe_tool_name not in toolgroups_for_turn_set:
continue
toolgroup_name, tool_name = self._parse_toolgroup_name(toolgroup_name_with_maybe_tool_name)
tools = await self.tool_groups_api.list_tools(toolgroup_id=toolgroup_name)
if tool_name is not None and not any(tool.identifier == tool_name for tool in tools.data):
raise ValueError(
f"Tool {tool_name} not found in toolgroup {toolgroup_name}. Available tools: {', '.join([tool.identifier for tool in tools.data])}"
)
for tool_def in tools.data:
if toolgroup_name.startswith("builtin") and toolgroup_name != RAG_TOOL_GROUP:
tool_name = tool_def.identifier
@@ -906,10 +817,10 @@
else:
built_in_type = BuiltinTool(tool_name)
if tool_def_map.get(built_in_type, None):
if tool_name_to_def.get(built_in_type, None):
raise ValueError(f"Tool {built_in_type} already exists")
tool_def_map[built_in_type] = ToolDefinition(
tool_name_to_def[built_in_type] = ToolDefinition(
tool_name=built_in_type,
description=tool_def.description,
parameters={
@@ -925,24 +836,42 @@
tool_to_group[built_in_type] = tool_def.toolgroup_id
continue
if tool_def_map.get(tool_def.identifier, None):
if tool_name_to_def.get(tool_def.identifier, None):
raise ValueError(f"Tool {tool_def.identifier} already exists")
tool_def_map[tool_def.identifier] = ToolDefinition(
tool_name=tool_def.identifier,
description=tool_def.description,
parameters={
param.name: ToolParamDefinition(
param_type=param.parameter_type,
description=param.description,
required=param.required,
default=param.default,
)
for param in tool_def.parameters
},
)
tool_to_group[tool_def.identifier] = tool_def.toolgroup_id
if tool_name in (None, tool_def.identifier):
tool_name_to_def[tool_def.identifier] = ToolDefinition(
tool_name=tool_def.identifier,
description=tool_def.description,
parameters={
param.name: ToolParamDefinition(
param_type=param.parameter_type,
description=param.description,
required=param.required,
default=param.default,
)
for param in tool_def.parameters
},
)
tool_to_group[tool_def.identifier] = tool_def.toolgroup_id
return tool_def_map, tool_to_group
return list(tool_name_to_def.values()), tool_to_group
def _parse_toolgroup_name(self, toolgroup_name_with_maybe_tool_name: str) -> tuple[str, Optional[str]]:
"""Parse a toolgroup name into its components.
Args:
toolgroup_name: The toolgroup name to parse (e.g. "builtin::rag/knowledge_search")
Returns:
A tuple of (tool_type, tool_group, tool_name)
"""
split_names = toolgroup_name_with_maybe_tool_name.split("/")
if len(split_names) == 2:
# e.g. "builtin::rag"
tool_group, tool_name = split_names
else:
tool_group, tool_name = split_names[0], None
return tool_group, tool_name
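# Standalone restatement (illustrative, not part of the diff) of the rule the
# method implements, with the two accepted spellings:
#
#   "builtin::rag"                   -> ("builtin::rag", None)
#   "builtin::rag/knowledge_search"  -> ("builtin::rag", "knowledge_search")
#
# def parse_toolgroup_name(name: str) -> tuple[str, Optional[str]]:
#     group, _, tool = name.partition("/")
#     return group, (tool or None)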
async def handle_documents(
self,
@@ -951,8 +880,8 @@ class ChatAgent(ShieldRunnerMixin):
input_messages: List[Message],
tool_defs: List[ToolDefinition],
) -> None:
memory_tool = tool_defs.get(MEMORY_QUERY_TOOL, None)
code_interpreter_tool = tool_defs.get(BuiltinTool.code_interpreter, None)
memory_tool = any(tool_def.tool_name == MEMORY_QUERY_TOOL for tool_def in tool_defs)
code_interpreter_tool = any(tool_def.tool_name == BuiltinTool.code_interpreter for tool_def in tool_defs)
content_items = []
url_items = []
pattern = re.compile("^(https?://|file://|data:)")
@@ -1072,7 +1001,11 @@ async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessa
else:
raise ValueError(f"Unsupported URL {url}")
content.append(TextContentItem(text=f'# There is a file accessible to you at "{filepath}"\n'))
content.append(
TextContentItem(
text=f'# User provided a file accessible to you at "{filepath}"\nYou can use code_interpreter to load and inspect it.'
)
)
return ToolResponseMessage(
call_id="",


@@ -7,6 +7,7 @@
import json
import os
import sqlite3
import threading
from datetime import datetime
from opentelemetry.sdk.trace import SpanProcessor
@@ -17,14 +18,18 @@ class SQLiteSpanProcessor(SpanProcessor):
def __init__(self, conn_string):
"""Initialize the SQLite span processor with a connection string."""
self.conn_string = conn_string
self.conn = None
self._local = threading.local() # Thread-local storage for connections
self.setup_database()
def _get_connection(self) -> sqlite3.Connection:
"""Get the database connection."""
if self.conn is None:
self.conn = sqlite3.connect(self.conn_string, check_same_thread=False)
return self.conn
def _get_connection(self):
"""Get a thread-local database connection."""
if not hasattr(self._local, "conn"):
try:
self._local.conn = sqlite3.connect(self.conn_string)
except Exception as e:
print(f"Error connecting to SQLite database: {e}")
raise e
return self._local.conn
def setup_database(self):
"""Create the necessary tables if they don't exist."""
@@ -168,9 +173,14 @@
def shutdown(self):
"""Cleanup any resources."""
if self.conn:
self.conn.close()
self.conn = None
# We can't access other threads' connections, so we just close our own
if hasattr(self._local, "conn"):
try:
self._local.conn.close()
except Exception as e:
print(f"Error closing SQLite connection: {e}")
finally:
del self._local.conn
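# Background for this change (illustrative sketch, not part of the diff):
# sqlite3 connections may only be used on the thread that created them unless
# check_same_thread=False, and spans can be exported from several worker
# threads. threading.local() gives each thread its own lazily created
# connection:
#
#   _local = threading.local()
#
#   def get_conn(path: str) -> sqlite3.Connection:
#       if not hasattr(_local, "conn"):   # each thread sees its own attribute
#           _local.conn = sqlite3.connect(path)
#       return _local.conn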
def force_flush(self, timeout_millis=30000):
"""Force export of spans."""


@@ -10,6 +10,8 @@ import secrets
import string
from typing import Any, Dict, List, Optional
from pydantic import TypeAdapter
from llama_stack.apis.common.content_types import (
URL,
InterleavedContent,
@@ -23,6 +25,7 @@ from llama_stack.apis.tools import (
RAGToolRuntime,
ToolDef,
ToolInvocationResult,
ToolParameter,
ToolRuntime,
)
from llama_stack.apis.vector_io import QueryChunksResponse, VectorIO
@@ -120,9 +123,14 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
# sort by score
chunks, scores = zip(*sorted(zip(chunks, scores, strict=False), key=lambda x: x[1], reverse=True), strict=False)
chunks = chunks[: query_config.max_chunks]
tokens = 0
picked = []
for c in chunks:
picked = [
TextContentItem(
text=f"knowledge_search tool found {len(chunks)} chunks:\nBEGIN of knowledge_search tool results.\n"
)
]
for i, c in enumerate(chunks):
metadata = c.metadata
tokens += metadata["token_count"]
if tokens > query_config.max_tokens_in_context:
@@ -132,20 +140,13 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
break
picked.append(
TextContentItem(
text=f"id:{metadata['document_id']}; content:{c.content}",
text=f"Result {i + 1}:\nDocument_id:{metadata['document_id'][:5]}\nContent: {c.content}\n",
)
)
picked.append(TextContentItem(text="END of knowledge_search tool results.\n"))
return RAGQueryResult(
content=[
TextContentItem(
text="Here are the retrieved documents for relevant context:\n=== START-RETRIEVED-CONTEXT ===\n",
),
*picked,
TextContentItem(
text="\n=== END-RETRIEVED-CONTEXT ===\n",
),
],
content=picked,
metadata={
"document_ids": [c.metadata["document_id"] for c in chunks[: len(picked)]],
},
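# For illustration, the content assembled above renders as:
#
#   knowledge_search tool found N chunks:
#   BEGIN of knowledge_search tool results.
#   Result 1:
#   Document_id:num-0
#   Content: ...
#   ...
#   END of knowledge_search tool results.
#
# i.e. retrieved context now reads as an ordinary tool response instead of an
# inline START/END-RETRIEVED-CONTEXT block spliced into the user prompt.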
@@ -158,17 +159,40 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
# by the LLM. The method is only implemented so things like /tools can list without
# encountering fatals.
return [
ToolDef(
name="query_from_memory",
description="Retrieve context from memory",
),
ToolDef(
name="insert_into_memory",
description="Insert documents into memory",
),
ToolDef(
name="knowledge_search",
description="Search for information in a database.",
parameters=[
ToolParameter(
name="query",
description="The query to search for. Can be a natural language sentence or keywords.",
parameter_type="string",
),
],
),
]
async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> ToolInvocationResult:
raise RuntimeError(
"This toolgroup should not be called generically but only through specific methods of the RAGToolRuntime protocol"
vector_db_ids = kwargs.get("vector_db_ids", [])
query_config = kwargs.get("query_config")
if query_config:
query_config = TypeAdapter(RAGQueryConfig).validate_python(query_config)
else:
# handle someone passing an empty dict
query_config = RAGQueryConfig()
query = kwargs["query"]
result = await self.query(
content=query,
vector_db_ids=vector_db_ids,
query_config=query_config,
)
return ToolInvocationResult(
content=result.content,
metadata=result.metadata,
)
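# Illustrative call flow (names assumed from the surrounding code): the model's
# "knowledge_search" tool call is routed here with the model-written query,
# while vector_db_ids and query_config come from the toolgroup args:
#
#   result = await impl.invoke_tool(
#       tool_name="knowledge_search",
#       kwargs={"query": "when was the nba created?", "vector_db_ids": [vector_db_id]},
#   )
#   # result.content carries the formatted results; result.metadata["document_ids"]
#   # lists the matching documents.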


@@ -74,6 +74,7 @@ docs = [
"sphinxcontrib.redoc",
"sphinxcontrib.video",
"sphinxcontrib.mermaid",
"tomli",
]
[project.urls]


@@ -96,7 +96,7 @@ def agent_config(llama_stack_client, text_model_id):
sampling_params={
"strategy": {
"type": "top_p",
"temperature": 1.0,
"temperature": 0.0001,
"top_p": 0.9,
},
},
@@ -441,7 +441,8 @@ def xtest_override_system_message_behavior(llama_stack_client, agent_config):
assert "get_boiling_point" in logs_str
def test_rag_agent(llama_stack_client, agent_config):
@pytest.mark.parametrize("rag_tool_name", ["builtin::rag/knowledge_search", "builtin::rag"])
def test_rag_agent(llama_stack_client, agent_config, rag_tool_name):
urls = ["chat.rst", "llama3.rst", "memory_optimizations.rst", "lora_finetune.rst"]
documents = [
Document(
@@ -469,7 +470,7 @@ def test_rag_agent(llama_stack_client, agent_config):
**agent_config,
"toolgroups": [
dict(
name="builtin::rag",
name=rag_tool_name,
args={
"vector_db_ids": [vector_db_id],
},
@@ -483,10 +484,6 @@
"Instead of the standard multi-head attention, what attention type does Llama3-8B use?",
"grouped",
),
(
"What `tune` command to use for getting access to Llama3-8B-Instruct ?",
"download",
),
]
for prompt, expected_kw in user_prompts:
response = rag_agent.create_turn(
@@ -496,23 +493,36 @@
)
# rag is called
tool_execution_step = next(step for step in response.steps if step.step_type == "tool_execution")
assert tool_execution_step.tool_calls[0].tool_name == "query_from_memory"
assert tool_execution_step.tool_calls[0].tool_name == "knowledge_search"
# document ids are present in metadata
assert "num-0" in tool_execution_step.tool_responses[0].metadata["document_ids"]
assert expected_kw in response.output_message.content.lower()
assert all(
doc_id.startswith("num-") for doc_id in tool_execution_step.tool_responses[0].metadata["document_ids"]
)
if expected_kw:
assert expected_kw in response.output_message.content.lower()
def test_rag_and_code_agent(llama_stack_client, agent_config):
urls = ["chat.rst"]
documents = [
documents = []
documents.append(
Document(
document_id=f"num-{i}",
content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
mime_type="text/plain",
document_id="nba_wiki",
content="The NBA was created on August 3, 1949, with the merger of the Basketball Association of America (BAA) and the National Basketball League (NBL).",
metadata={},
)
for i, url in enumerate(urls)
]
)
documents.append(
Document(
document_id="perplexity_wiki",
content="""Perplexity the company was founded in 2022 by Aravind Srinivas, Andy Konwinski, Denis Yarats and Johnny Ho, engineers with backgrounds in back-end systems, artificial intelligence (AI) and machine learning:
Srinivas, the CEO, worked at OpenAI as an AI researcher.
Konwinski was among the founding team at Databricks.
Yarats, the CTO, was an AI research scientist at Meta.
Ho, the CSO, worked as an engineer at Quora, then as a quantitative trader on Wall Street.[5]""",
metadata={},
)
)
vector_db_id = f"test-vector-db-{uuid4()}"
llama_stack_client.vector_dbs.register(
vector_db_id=vector_db_id,
@@ -528,7 +538,7 @@ def test_rag_and_code_agent(llama_stack_client, agent_config):
**agent_config,
"toolgroups": [
dict(
name="builtin::rag",
name="builtin::rag/knowledge_search",
args={"vector_db_ids": [vector_db_id]},
),
"builtin::code_interpreter",
@@ -546,24 +556,34 @@ def test_rag_and_code_agent(llama_stack_client, agent_config):
"Here is a csv file, can you describe it?",
[inflation_doc],
"code_interpreter",
"",
),
(
"What are the top 5 topics that were explained? Only list succinct bullet points.",
"when was Perplexity the company founded?",
[],
"query_from_memory",
"knowledge_search",
"2022",
),
(
"when was the nba created?",
[],
"knowledge_search",
"1949",
),
]
for prompt, docs, tool_name in user_prompts:
for prompt, docs, tool_name, expected_kw in user_prompts:
session_id = agent.create_session(f"test-session-{uuid4()}")
response = agent.create_turn(
messages=[{"role": "user", "content": prompt}],
session_id=session_id,
documents=docs,
stream=False,
)
logs = [str(log) for log in EventLogger().log(response) if log is not None]
logs_str = "".join(logs)
assert f"Tool:{tool_name}" in logs_str
tool_execution_step = next(step for step in response.steps if step.step_type == "tool_execution")
assert tool_execution_step.tool_calls[0].tool_name == tool_name
if expected_kw:
assert expected_kw in response.output_message.content.lower()
def test_create_turn_response(llama_stack_client, agent_config):

uv.lock (generated)

@@ -1,5 +1,4 @@
version = 1
revision = 1
requires-python = ">=3.10"
resolution-markers = [
"(python_full_version < '3.11' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version < '3.11' and sys_platform != 'darwin' and sys_platform != 'linux')",
@@ -913,6 +912,7 @@ docs = [
{ name = "sphinxcontrib-mermaid" },
{ name = "sphinxcontrib-redoc" },
{ name = "sphinxcontrib-video" },
{ name = "tomli" },
]
test = [
{ name = "aiosqlite" },
@@ -971,13 +971,13 @@ requires-dist = [
{ name = "sphinxcontrib-redoc", marker = "extra == 'docs'" },
{ name = "sphinxcontrib-video", marker = "extra == 'docs'" },
{ name = "termcolor" },
{ name = "tomli", marker = "extra == 'docs'" },
{ name = "torch", marker = "extra == 'test'", specifier = ">=2.6.0", index = "https://download.pytorch.org/whl/cpu" },
{ name = "torchvision", marker = "extra == 'test'", specifier = ">=0.21.0", index = "https://download.pytorch.org/whl/cpu" },
{ name = "types-requests", marker = "extra == 'dev'" },
{ name = "types-setuptools", marker = "extra == 'dev'" },
{ name = "uvicorn", marker = "extra == 'dev'" },
]
provides-extras = ["dev", "test", "docs"]
[[package]]
name = "llama-stack-client"