mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-12 13:00:39 +00:00
Merge branch 'main' into export_agent_dataset
This commit is contained in:
commit
48174b5422
16 changed files with 210 additions and 213 deletions
|
@ -803,7 +803,7 @@
|
|||
}
|
||||
],
|
||||
"source": [
|
||||
"model_id = \"meta-llama/Llama-3.1-70B-Instruct\"\n",
|
||||
"model_id = \"meta-llama/Llama-3.3-70B-Instruct\"\n",
|
||||
"\n",
|
||||
"model_id\n"
|
||||
]
|
||||
|
@ -1688,7 +1688,7 @@
|
|||
" enable_session_persistence=False,\n",
|
||||
" toolgroups = [\n",
|
||||
" {\n",
|
||||
" \"name\": \"builtin::rag\",\n",
|
||||
" \"name\": \"builtin::rag/knowledge_search\",\n",
|
||||
" \"args\" : {\n",
|
||||
" \"vector_db_ids\": [vector_db_id],\n",
|
||||
" }\n",
|
||||
|
|
|
@ -3,6 +3,8 @@
|
|||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"[](https://colab.research.google.com/github/meta-llama/llama-stack/blob/main/docs/notebooks/Alpha_Llama_Stack_Post_Training.ipynb)\n",
|
||||
"\n",
|
||||
"# [Alpha] Llama Stack Post Training\n",
|
||||
"This notebook will use a real world problem (improve LLM as tax preparer) to walk through the main sets of APIs we offer with Llama stack for post training to improve the LLM performance for agentic apps (We support supervised finetune now, RLHF and knowledge distillation will come soon!).\n",
|
||||
"\n",
|
||||
|
@ -64,7 +66,7 @@
|
|||
"output_type": "stream",
|
||||
"name": "stdout",
|
||||
"text": [
|
||||
"Collecting git+https://github.com/meta-llama/llama-stack.git@hf_format_checkpointer\n",
|
||||
"Collecting git+https://github.com/meta-llama/llama-stack.git\n",
|
||||
" Cloning https://github.com/meta-llama/llama-stack.git (to revision hf_format_checkpointer) to /tmp/pip-req-build-j_1bxqzm\n",
|
||||
" Running command git clone --filter=blob:none --quiet https://github.com/meta-llama/llama-stack.git /tmp/pip-req-build-j_1bxqzm\n",
|
||||
" Running command git checkout -b hf_format_checkpointer --track origin/hf_format_checkpointer\n",
|
||||
|
@ -76,7 +78,7 @@
|
|||
}
|
||||
],
|
||||
"source": [
|
||||
"!pip install git+https://github.com/meta-llama/llama-stack.git@hf_format_checkpointer"
|
||||
"!pip install git+https://github.com/meta-llama/llama-stack.git #TODO: update this after the next pkg release"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
|
|
@ -7,12 +7,12 @@ Each agent turn follows these key steps:
|
|||
1. **Initial Safety Check**: The user's input is first screened through configured safety shields
|
||||
|
||||
2. **Context Retrieval**:
|
||||
- If RAG is enabled, the agent queries relevant documents from memory banks
|
||||
- For new documents, they are first inserted into the memory bank
|
||||
- Retrieved context is augmented to the user's prompt
|
||||
- If RAG is enabled, the agent can choose to query relevant documents from memory banks. You can use the `instructions` field to steer the agent.
|
||||
- For new documents, they are first inserted into the memory bank.
|
||||
- Retrieved context is provided to the LLM as a tool response in the message history.
|
||||
|
||||
3. **Inference Loop**: The agent enters its main execution loop:
|
||||
- The LLM receives the augmented prompt (with context and/or previous tool outputs)
|
||||
- The LLM receives a user prompt (with previous tool outputs)
|
||||
- The LLM generates a response, potentially with tool calls
|
||||
- If tool calls are present:
|
||||
- Tool inputs are safety-checked
|
||||
|
@ -40,19 +40,16 @@ sequenceDiagram
|
|||
S->>E: Input Safety Check
|
||||
deactivate S
|
||||
|
||||
E->>M: 2.1 Query Context
|
||||
M-->>E: 2.2 Retrieved Documents
|
||||
|
||||
loop Inference Loop
|
||||
E->>L: 3.1 Augment with Context
|
||||
L-->>E: 3.2 Response (with/without tool calls)
|
||||
E->>L: 2.1 Augment with Context
|
||||
L-->>E: 2.2 Response (with/without tool calls)
|
||||
|
||||
alt Has Tool Calls
|
||||
E->>S: Check Tool Input
|
||||
S->>T: 4.1 Execute Tool
|
||||
T-->>E: 4.2 Tool Response
|
||||
E->>L: 5.1 Tool Response
|
||||
L-->>E: 5.2 Synthesized Response
|
||||
S->>T: 3.1 Execute Tool
|
||||
T-->>E: 3.2 Tool Response
|
||||
E->>L: 4.1 Tool Response
|
||||
L-->>E: 4.2 Synthesized Response
|
||||
end
|
||||
|
||||
opt Stop Conditions
|
||||
|
@ -64,7 +61,7 @@ sequenceDiagram
|
|||
end
|
||||
|
||||
E->>S: Output Safety Check
|
||||
S->>U: 6. Final Response
|
||||
S->>U: 5. Final Response
|
||||
```
|
||||
|
||||
Each step in this process can be monitored and controlled through configurations. Here's an example that demonstrates monitoring the agent's execution:
|
||||
|
@ -77,7 +74,10 @@ agent_config = AgentConfig(
|
|||
instructions="You are a helpful assistant",
|
||||
# Enable both RAG and tool usage
|
||||
toolgroups=[
|
||||
{"name": "builtin::rag", "args": {"vector_db_ids": ["my_docs"]}},
|
||||
{
|
||||
"name": "builtin::rag/knowledge_search",
|
||||
"args": {"vector_db_ids": ["my_docs"]},
|
||||
},
|
||||
"builtin::code_interpreter",
|
||||
],
|
||||
# Configure safety
|
||||
|
|
|
@ -91,7 +91,7 @@ agent_config = AgentConfig(
|
|||
enable_session_persistence=False,
|
||||
toolgroups=[
|
||||
{
|
||||
"name": "builtin::rag",
|
||||
"name": "builtin::rag/knowledge_search",
|
||||
"args": {
|
||||
"vector_db_ids": [vector_db_id],
|
||||
},
|
||||
|
|
|
@ -13,6 +13,13 @@
|
|||
# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
|
||||
|
||||
from docutils import nodes
|
||||
import tomli # Import tomli for TOML parsing
|
||||
from pathlib import Path
|
||||
|
||||
# Read version from pyproject.toml
|
||||
with Path(__file__).parent.parent.parent.joinpath("pyproject.toml").open("rb") as f:
|
||||
pyproject = tomli.load(f)
|
||||
llama_stack_version = pyproject["project"]["version"]
|
||||
|
||||
project = "llama-stack"
|
||||
copyright = "2025, Meta"
|
||||
|
@ -66,6 +73,7 @@ myst_enable_extensions = [
|
|||
|
||||
myst_substitutions = {
|
||||
"docker_hub": "https://hub.docker.com/repository/docker/llamastack",
|
||||
"llama_stack_version": llama_stack_version,
|
||||
}
|
||||
|
||||
suppress_warnings = ['myst.header']
|
||||
|
|
|
@ -243,7 +243,7 @@ agent_config = AgentConfig(
|
|||
# Define tools available to the agent
|
||||
toolgroups=[
|
||||
{
|
||||
"name": "builtin::rag",
|
||||
"name": "builtin::rag/knowledge_search",
|
||||
"args": {
|
||||
"vector_db_ids": [vector_db_id],
|
||||
},
|
||||
|
|
|
@ -1,8 +1,7 @@
|
|||
|
||||
```{admonition} News
|
||||
:class: tip
|
||||
|
||||
Llama Stack 0.1.4 is now available! See the [release notes](https://github.com/meta-llama/llama-stack/releases/tag/v0.1.4) for more details.
|
||||
Llama Stack {{ llama_stack_version }} is now available! See the [release notes](https://github.com/meta-llama/llama-stack/releases/tag/v{{ llama_stack_version }}) for more details.
|
||||
```
|
||||
|
||||
# Llama Stack
|
||||
|
|
|
@ -441,7 +441,7 @@ class ToolRuntimeRouter(ToolRuntime):
|
|||
vector_db_ids: List[str],
|
||||
query_config: Optional[RAGQueryConfig] = None,
|
||||
) -> RAGQueryResult:
|
||||
return await self.routing_table.get_provider_impl("query_from_memory").query(
|
||||
return await self.routing_table.get_provider_impl("knowledge_search").query(
|
||||
content, vector_db_ids, query_config
|
||||
)
|
||||
|
||||
|
|
|
@ -33,7 +33,7 @@ class DistributionRegistry(Protocol):
|
|||
|
||||
|
||||
REGISTER_PREFIX = "distributions:registry"
|
||||
KEY_VERSION = "v7"
|
||||
KEY_VERSION = "v8"
|
||||
KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}"
|
||||
|
||||
|
||||
|
|
|
@ -132,7 +132,7 @@ def rag_chat_page():
|
|||
},
|
||||
toolgroups=[
|
||||
dict(
|
||||
name="builtin::rag",
|
||||
name="builtin::rag/knowledge_search",
|
||||
args={
|
||||
"vector_db_ids": [vector_db_id for vector_db_id in selected_vector_dbs],
|
||||
},
|
||||
|
|
|
@ -17,7 +17,6 @@ from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple
|
|||
from urllib.parse import urlparse
|
||||
|
||||
import httpx
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
from llama_stack.apis.agents import (
|
||||
AgentConfig,
|
||||
|
@ -62,7 +61,7 @@ from llama_stack.apis.inference import (
|
|||
UserMessage,
|
||||
)
|
||||
from llama_stack.apis.safety import Safety
|
||||
from llama_stack.apis.tools import RAGDocument, RAGQueryConfig, ToolGroups, ToolInvocationResult, ToolRuntime
|
||||
from llama_stack.apis.tools import RAGDocument, ToolGroups, ToolInvocationResult, ToolRuntime
|
||||
from llama_stack.apis.vector_io import VectorIO
|
||||
from llama_stack.models.llama.datatypes import (
|
||||
BuiltinTool,
|
||||
|
@ -70,7 +69,6 @@ from llama_stack.models.llama.datatypes import (
|
|||
ToolParamDefinition,
|
||||
)
|
||||
from llama_stack.providers.utils.kvstore import KVStore
|
||||
from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content
|
||||
from llama_stack.providers.utils.telemetry import tracing
|
||||
|
||||
from .persistence import AgentPersistence
|
||||
|
@ -84,7 +82,7 @@ def make_random_string(length: int = 8):
|
|||
|
||||
|
||||
TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})")
|
||||
MEMORY_QUERY_TOOL = "query_from_memory"
|
||||
MEMORY_QUERY_TOOL = "knowledge_search"
|
||||
WEB_SEARCH_TOOL = "web_search"
|
||||
RAG_TOOL_GROUP = "builtin::rag"
|
||||
|
||||
|
@ -499,111 +497,18 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
# TODO: simplify all of this code, it can be simpler
|
||||
toolgroup_args = {}
|
||||
toolgroups = set()
|
||||
for toolgroup in self.agent_config.toolgroups:
|
||||
for toolgroup in self.agent_config.toolgroups + (toolgroups_for_turn or []):
|
||||
if isinstance(toolgroup, AgentToolGroupWithArgs):
|
||||
toolgroups.add(toolgroup.name)
|
||||
toolgroup_args[toolgroup.name] = toolgroup.args
|
||||
tool_group_name, tool_name = self._parse_toolgroup_name(toolgroup.name)
|
||||
toolgroups.add(tool_group_name)
|
||||
toolgroup_args[tool_group_name] = toolgroup.args
|
||||
else:
|
||||
toolgroups.add(toolgroup)
|
||||
if toolgroups_for_turn:
|
||||
for toolgroup in toolgroups_for_turn:
|
||||
if isinstance(toolgroup, AgentToolGroupWithArgs):
|
||||
toolgroups.add(toolgroup.name)
|
||||
toolgroup_args[toolgroup.name] = toolgroup.args
|
||||
else:
|
||||
toolgroups.add(toolgroup)
|
||||
|
||||
tool_defs, tool_to_group = await self._get_tool_defs(toolgroups_for_turn)
|
||||
if documents:
|
||||
await self.handle_documents(session_id, documents, input_messages, tool_defs)
|
||||
|
||||
if RAG_TOOL_GROUP in toolgroups and len(input_messages) > 0:
|
||||
with tracing.span(MEMORY_QUERY_TOOL) as span:
|
||||
step_id = str(uuid.uuid4())
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepStartPayload(
|
||||
step_type=StepType.tool_execution.value,
|
||||
step_id=step_id,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
args = toolgroup_args.get(RAG_TOOL_GROUP, {})
|
||||
vector_db_ids = args.get("vector_db_ids", [])
|
||||
query_config = args.get("query_config")
|
||||
if query_config:
|
||||
query_config = TypeAdapter(RAGQueryConfig).validate_python(query_config)
|
||||
else:
|
||||
# handle someone passing an empty dict
|
||||
query_config = RAGQueryConfig()
|
||||
|
||||
session_info = await self.storage.get_session_info(session_id)
|
||||
|
||||
# if the session has a memory bank id, let the memory tool use it
|
||||
if session_info.vector_db_id:
|
||||
vector_db_ids.append(session_info.vector_db_id)
|
||||
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepProgressPayload(
|
||||
step_type=StepType.tool_execution.value,
|
||||
step_id=step_id,
|
||||
delta=ToolCallDelta(
|
||||
parse_status=ToolCallParseStatus.succeeded,
|
||||
tool_call=ToolCall(
|
||||
call_id="",
|
||||
tool_name=MEMORY_QUERY_TOOL,
|
||||
arguments={},
|
||||
),
|
||||
),
|
||||
)
|
||||
)
|
||||
)
|
||||
result = await self.tool_runtime_api.rag_tool.query(
|
||||
content=concat_interleaved_content([msg.content for msg in input_messages]),
|
||||
vector_db_ids=vector_db_ids,
|
||||
query_config=query_config,
|
||||
)
|
||||
retrieved_context = result.content
|
||||
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepCompletePayload(
|
||||
step_type=StepType.tool_execution.value,
|
||||
step_id=step_id,
|
||||
step_details=ToolExecutionStep(
|
||||
step_id=step_id,
|
||||
turn_id=turn_id,
|
||||
tool_calls=[
|
||||
ToolCall(
|
||||
call_id="",
|
||||
tool_name=MEMORY_QUERY_TOOL,
|
||||
arguments={},
|
||||
)
|
||||
],
|
||||
tool_responses=[
|
||||
ToolResponse(
|
||||
call_id="",
|
||||
tool_name=MEMORY_QUERY_TOOL,
|
||||
content=retrieved_context or [],
|
||||
metadata=result.metadata,
|
||||
)
|
||||
],
|
||||
),
|
||||
)
|
||||
)
|
||||
)
|
||||
span.set_attribute("input", [m.model_dump_json() for m in input_messages])
|
||||
span.set_attribute("output", retrieved_context)
|
||||
span.set_attribute("tool_name", MEMORY_QUERY_TOOL)
|
||||
|
||||
# append retrieved_context to the last user message
|
||||
for message in input_messages[::-1]:
|
||||
if isinstance(message, UserMessage):
|
||||
message.context = retrieved_context
|
||||
break
|
||||
|
||||
output_attachments = []
|
||||
|
||||
n_iter = 0
|
||||
|
@ -631,9 +536,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
async for chunk in await self.inference_api.chat_completion(
|
||||
self.agent_config.model,
|
||||
input_messages,
|
||||
tools=[
|
||||
tool for tool in tool_defs.values() if tool_to_group.get(tool.tool_name, None) != RAG_TOOL_GROUP
|
||||
],
|
||||
tools=tool_defs,
|
||||
tool_prompt_format=self.agent_config.tool_config.tool_prompt_format,
|
||||
response_format=self.agent_config.response_format,
|
||||
stream=True,
|
||||
|
@ -837,7 +740,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
)
|
||||
],
|
||||
started_at=tool_execution_start_time,
|
||||
completed_at=datetime.now(),
|
||||
completed_at=datetime.now().astimezone().isoformat(),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
@ -845,8 +748,9 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
|
||||
# TODO: add tool-input touchpoint and a "start" event for this step also
|
||||
# but that needs a lot more refactoring of Tool code potentially
|
||||
|
||||
if out_attachment := _interpret_content_as_attachment(result_message.content):
|
||||
if (type(result_message.content) is str) and (
|
||||
out_attachment := _interpret_content_as_attachment(result_message.content)
|
||||
):
|
||||
# NOTE: when we push this message back to the model, the model may ignore the
|
||||
# attached file path etc. since the model is trained to only provide a user message
|
||||
# with the summary. We keep all generated attachments and then attach them to final message
|
||||
|
@ -858,7 +762,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
|
||||
async def _get_tool_defs(
|
||||
self, toolgroups_for_turn: Optional[List[AgentToolGroup]] = None
|
||||
) -> Tuple[Dict[str, ToolDefinition], Dict[str, str]]:
|
||||
) -> Tuple[List[ToolDefinition], Dict[str, str]]:
|
||||
# Determine which tools to include
|
||||
agent_config_toolgroups = set(
|
||||
(toolgroup.name if isinstance(toolgroup, AgentToolGroupWithArgs) else toolgroup)
|
||||
|
@ -873,13 +777,13 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
}
|
||||
)
|
||||
|
||||
tool_def_map = {}
|
||||
tool_name_to_def = {}
|
||||
tool_to_group = {}
|
||||
|
||||
for tool_def in self.agent_config.client_tools:
|
||||
if tool_def_map.get(tool_def.name, None):
|
||||
if tool_name_to_def.get(tool_def.name, None):
|
||||
raise ValueError(f"Tool {tool_def.name} already exists")
|
||||
tool_def_map[tool_def.name] = ToolDefinition(
|
||||
tool_name_to_def[tool_def.name] = ToolDefinition(
|
||||
tool_name=tool_def.name,
|
||||
description=tool_def.description,
|
||||
parameters={
|
||||
|
@ -893,10 +797,17 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
},
|
||||
)
|
||||
tool_to_group[tool_def.name] = "__client_tools__"
|
||||
for toolgroup_name in agent_config_toolgroups:
|
||||
if toolgroup_name not in toolgroups_for_turn_set:
|
||||
for toolgroup_name_with_maybe_tool_name in agent_config_toolgroups:
|
||||
if toolgroup_name_with_maybe_tool_name not in toolgroups_for_turn_set:
|
||||
continue
|
||||
|
||||
toolgroup_name, tool_name = self._parse_toolgroup_name(toolgroup_name_with_maybe_tool_name)
|
||||
tools = await self.tool_groups_api.list_tools(toolgroup_id=toolgroup_name)
|
||||
if tool_name is not None and not any(tool.identifier == tool_name for tool in tools.data):
|
||||
raise ValueError(
|
||||
f"Tool {tool_name} not found in toolgroup {toolgroup_name}. Available tools: {', '.join([tool.identifier for tool in tools.data])}"
|
||||
)
|
||||
|
||||
for tool_def in tools.data:
|
||||
if toolgroup_name.startswith("builtin") and toolgroup_name != RAG_TOOL_GROUP:
|
||||
tool_name = tool_def.identifier
|
||||
|
@ -906,10 +817,10 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
else:
|
||||
built_in_type = BuiltinTool(tool_name)
|
||||
|
||||
if tool_def_map.get(built_in_type, None):
|
||||
if tool_name_to_def.get(built_in_type, None):
|
||||
raise ValueError(f"Tool {built_in_type} already exists")
|
||||
|
||||
tool_def_map[built_in_type] = ToolDefinition(
|
||||
tool_name_to_def[built_in_type] = ToolDefinition(
|
||||
tool_name=built_in_type,
|
||||
description=tool_def.description,
|
||||
parameters={
|
||||
|
@ -925,24 +836,42 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
tool_to_group[built_in_type] = tool_def.toolgroup_id
|
||||
continue
|
||||
|
||||
if tool_def_map.get(tool_def.identifier, None):
|
||||
if tool_name_to_def.get(tool_def.identifier, None):
|
||||
raise ValueError(f"Tool {tool_def.identifier} already exists")
|
||||
tool_def_map[tool_def.identifier] = ToolDefinition(
|
||||
tool_name=tool_def.identifier,
|
||||
description=tool_def.description,
|
||||
parameters={
|
||||
param.name: ToolParamDefinition(
|
||||
param_type=param.parameter_type,
|
||||
description=param.description,
|
||||
required=param.required,
|
||||
default=param.default,
|
||||
)
|
||||
for param in tool_def.parameters
|
||||
},
|
||||
)
|
||||
tool_to_group[tool_def.identifier] = tool_def.toolgroup_id
|
||||
if tool_name in (None, tool_def.identifier):
|
||||
tool_name_to_def[tool_def.identifier] = ToolDefinition(
|
||||
tool_name=tool_def.identifier,
|
||||
description=tool_def.description,
|
||||
parameters={
|
||||
param.name: ToolParamDefinition(
|
||||
param_type=param.parameter_type,
|
||||
description=param.description,
|
||||
required=param.required,
|
||||
default=param.default,
|
||||
)
|
||||
for param in tool_def.parameters
|
||||
},
|
||||
)
|
||||
tool_to_group[tool_def.identifier] = tool_def.toolgroup_id
|
||||
|
||||
return tool_def_map, tool_to_group
|
||||
return list(tool_name_to_def.values()), tool_to_group
|
||||
|
||||
def _parse_toolgroup_name(self, toolgroup_name_with_maybe_tool_name: str) -> tuple[str, Optional[str]]:
|
||||
"""Parse a toolgroup name into its components.
|
||||
|
||||
Args:
|
||||
toolgroup_name: The toolgroup name to parse (e.g. "builtin::rag/knowledge_search")
|
||||
|
||||
Returns:
|
||||
A tuple of (tool_type, tool_group, tool_name)
|
||||
"""
|
||||
split_names = toolgroup_name_with_maybe_tool_name.split("/")
|
||||
if len(split_names) == 2:
|
||||
# e.g. "builtin::rag"
|
||||
tool_group, tool_name = split_names
|
||||
else:
|
||||
tool_group, tool_name = split_names[0], None
|
||||
return tool_group, tool_name
|
||||
|
||||
async def handle_documents(
|
||||
self,
|
||||
|
@ -951,8 +880,8 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
input_messages: List[Message],
|
||||
tool_defs: Dict[str, ToolDefinition],
|
||||
) -> None:
|
||||
memory_tool = tool_defs.get(MEMORY_QUERY_TOOL, None)
|
||||
code_interpreter_tool = tool_defs.get(BuiltinTool.code_interpreter, None)
|
||||
memory_tool = any(tool_def.tool_name == MEMORY_QUERY_TOOL for tool_def in tool_defs)
|
||||
code_interpreter_tool = any(tool_def.tool_name == BuiltinTool.code_interpreter for tool_def in tool_defs)
|
||||
content_items = []
|
||||
url_items = []
|
||||
pattern = re.compile("^(https?://|file://|data:)")
|
||||
|
@ -1072,7 +1001,11 @@ async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessa
|
|||
else:
|
||||
raise ValueError(f"Unsupported URL {url}")
|
||||
|
||||
content.append(TextContentItem(text=f'# There is a file accessible to you at "{filepath}"\n'))
|
||||
content.append(
|
||||
TextContentItem(
|
||||
text=f'# User provided a file accessible to you at "{filepath}"\nYou can use code_interpreter to load and inspect it.'
|
||||
)
|
||||
)
|
||||
|
||||
return ToolResponseMessage(
|
||||
call_id="",
|
||||
|
|
|
@ -7,6 +7,7 @@
|
|||
import json
|
||||
import os
|
||||
import sqlite3
|
||||
import threading
|
||||
from datetime import datetime
|
||||
|
||||
from opentelemetry.sdk.trace import SpanProcessor
|
||||
|
@ -17,14 +18,18 @@ class SQLiteSpanProcessor(SpanProcessor):
|
|||
def __init__(self, conn_string):
|
||||
"""Initialize the SQLite span processor with a connection string."""
|
||||
self.conn_string = conn_string
|
||||
self.conn = None
|
||||
self._local = threading.local() # Thread-local storage for connections
|
||||
self.setup_database()
|
||||
|
||||
def _get_connection(self) -> sqlite3.Connection:
|
||||
"""Get the database connection."""
|
||||
if self.conn is None:
|
||||
self.conn = sqlite3.connect(self.conn_string, check_same_thread=False)
|
||||
return self.conn
|
||||
def _get_connection(self):
|
||||
"""Get a thread-local database connection."""
|
||||
if not hasattr(self._local, "conn"):
|
||||
try:
|
||||
self._local.conn = sqlite3.connect(self.conn_string)
|
||||
except Exception as e:
|
||||
print(f"Error connecting to SQLite database: {e}")
|
||||
raise e
|
||||
return self._local.conn
|
||||
|
||||
def setup_database(self):
|
||||
"""Create the necessary tables if they don't exist."""
|
||||
|
@ -168,9 +173,14 @@ class SQLiteSpanProcessor(SpanProcessor):
|
|||
|
||||
def shutdown(self):
|
||||
"""Cleanup any resources."""
|
||||
if self.conn:
|
||||
self.conn.close()
|
||||
self.conn = None
|
||||
# We can't access other threads' connections, so we just close our own
|
||||
if hasattr(self._local, "conn"):
|
||||
try:
|
||||
self._local.conn.close()
|
||||
except Exception as e:
|
||||
print(f"Error closing SQLite connection: {e}")
|
||||
finally:
|
||||
del self._local.conn
|
||||
|
||||
def force_flush(self, timeout_millis=30000):
|
||||
"""Force export of spans."""
|
||||
|
|
|
@ -10,6 +10,8 @@ import secrets
|
|||
import string
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
URL,
|
||||
InterleavedContent,
|
||||
|
@ -23,6 +25,7 @@ from llama_stack.apis.tools import (
|
|||
RAGToolRuntime,
|
||||
ToolDef,
|
||||
ToolInvocationResult,
|
||||
ToolParameter,
|
||||
ToolRuntime,
|
||||
)
|
||||
from llama_stack.apis.vector_io import QueryChunksResponse, VectorIO
|
||||
|
@ -120,9 +123,14 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
|
|||
# sort by score
|
||||
chunks, scores = zip(*sorted(zip(chunks, scores, strict=False), key=lambda x: x[1], reverse=True), strict=False)
|
||||
chunks = chunks[: query_config.max_chunks]
|
||||
|
||||
tokens = 0
|
||||
picked = []
|
||||
for c in chunks:
|
||||
picked = [
|
||||
TextContentItem(
|
||||
text=f"knowledge_search tool found {len(chunks)} chunks:\nBEGIN of knowledge_search tool results.\n"
|
||||
)
|
||||
]
|
||||
for i, c in enumerate(chunks):
|
||||
metadata = c.metadata
|
||||
tokens += metadata["token_count"]
|
||||
if tokens > query_config.max_tokens_in_context:
|
||||
|
@ -132,20 +140,13 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
|
|||
break
|
||||
picked.append(
|
||||
TextContentItem(
|
||||
text=f"id:{metadata['document_id']}; content:{c.content}",
|
||||
text=f"Result {i + 1}:\nDocument_id:{metadata['document_id'][:5]}\nContent: {c.content}\n",
|
||||
)
|
||||
)
|
||||
picked.append(TextContentItem(text="END of knowledge_search tool results.\n"))
|
||||
|
||||
return RAGQueryResult(
|
||||
content=[
|
||||
TextContentItem(
|
||||
text="Here are the retrieved documents for relevant context:\n=== START-RETRIEVED-CONTEXT ===\n",
|
||||
),
|
||||
*picked,
|
||||
TextContentItem(
|
||||
text="\n=== END-RETRIEVED-CONTEXT ===\n",
|
||||
),
|
||||
],
|
||||
content=picked,
|
||||
metadata={
|
||||
"document_ids": [c.metadata["document_id"] for c in chunks[: len(picked)]],
|
||||
},
|
||||
|
@ -158,17 +159,40 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
|
|||
# by the LLM. The method is only implemented so things like /tools can list without
|
||||
# encountering fatals.
|
||||
return [
|
||||
ToolDef(
|
||||
name="query_from_memory",
|
||||
description="Retrieve context from memory",
|
||||
),
|
||||
ToolDef(
|
||||
name="insert_into_memory",
|
||||
description="Insert documents into memory",
|
||||
),
|
||||
ToolDef(
|
||||
name="knowledge_search",
|
||||
description="Search for information in a database.",
|
||||
parameters=[
|
||||
ToolParameter(
|
||||
name="query",
|
||||
description="The query to search for. Can be a natural language sentence or keywords.",
|
||||
parameter_type="string",
|
||||
),
|
||||
],
|
||||
),
|
||||
]
|
||||
|
||||
async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> ToolInvocationResult:
|
||||
raise RuntimeError(
|
||||
"This toolgroup should not be called generically but only through specific methods of the RAGToolRuntime protocol"
|
||||
vector_db_ids = kwargs.get("vector_db_ids", [])
|
||||
query_config = kwargs.get("query_config")
|
||||
if query_config:
|
||||
query_config = TypeAdapter(RAGQueryConfig).validate_python(query_config)
|
||||
else:
|
||||
# handle someone passing an empty dict
|
||||
query_config = RAGQueryConfig()
|
||||
|
||||
query = kwargs["query"]
|
||||
result = await self.query(
|
||||
content=query,
|
||||
vector_db_ids=vector_db_ids,
|
||||
query_config=query_config,
|
||||
)
|
||||
|
||||
return ToolInvocationResult(
|
||||
content=result.content,
|
||||
metadata=result.metadata,
|
||||
)
|
||||
|
|
|
@ -74,6 +74,7 @@ docs = [
|
|||
"sphinxcontrib.redoc",
|
||||
"sphinxcontrib.video",
|
||||
"sphinxcontrib.mermaid",
|
||||
"tomli",
|
||||
]
|
||||
|
||||
[project.urls]
|
||||
|
|
|
@ -96,7 +96,7 @@ def agent_config(llama_stack_client, text_model_id):
|
|||
sampling_params={
|
||||
"strategy": {
|
||||
"type": "top_p",
|
||||
"temperature": 1.0,
|
||||
"temperature": 0.0001,
|
||||
"top_p": 0.9,
|
||||
},
|
||||
},
|
||||
|
@ -441,7 +441,8 @@ def xtest_override_system_message_behavior(llama_stack_client, agent_config):
|
|||
assert "get_boiling_point" in logs_str
|
||||
|
||||
|
||||
def test_rag_agent(llama_stack_client, agent_config):
|
||||
@pytest.mark.parametrize("rag_tool_name", ["builtin::rag/knowledge_search", "builtin::rag"])
|
||||
def test_rag_agent(llama_stack_client, agent_config, rag_tool_name):
|
||||
urls = ["chat.rst", "llama3.rst", "memory_optimizations.rst", "lora_finetune.rst"]
|
||||
documents = [
|
||||
Document(
|
||||
|
@ -469,7 +470,7 @@ def test_rag_agent(llama_stack_client, agent_config):
|
|||
**agent_config,
|
||||
"toolgroups": [
|
||||
dict(
|
||||
name="builtin::rag",
|
||||
name=rag_tool_name,
|
||||
args={
|
||||
"vector_db_ids": [vector_db_id],
|
||||
},
|
||||
|
@ -483,10 +484,6 @@ def test_rag_agent(llama_stack_client, agent_config):
|
|||
"Instead of the standard multi-head attention, what attention type does Llama3-8B use?",
|
||||
"grouped",
|
||||
),
|
||||
(
|
||||
"What `tune` command to use for getting access to Llama3-8B-Instruct ?",
|
||||
"download",
|
||||
),
|
||||
]
|
||||
for prompt, expected_kw in user_prompts:
|
||||
response = rag_agent.create_turn(
|
||||
|
@ -496,23 +493,36 @@ def test_rag_agent(llama_stack_client, agent_config):
|
|||
)
|
||||
# rag is called
|
||||
tool_execution_step = next(step for step in response.steps if step.step_type == "tool_execution")
|
||||
assert tool_execution_step.tool_calls[0].tool_name == "query_from_memory"
|
||||
assert tool_execution_step.tool_calls[0].tool_name == "knowledge_search"
|
||||
# document ids are present in metadata
|
||||
assert "num-0" in tool_execution_step.tool_responses[0].metadata["document_ids"]
|
||||
assert expected_kw in response.output_message.content.lower()
|
||||
assert all(
|
||||
doc_id.startswith("num-") for doc_id in tool_execution_step.tool_responses[0].metadata["document_ids"]
|
||||
)
|
||||
if expected_kw:
|
||||
assert expected_kw in response.output_message.content.lower()
|
||||
|
||||
|
||||
def test_rag_and_code_agent(llama_stack_client, agent_config):
|
||||
urls = ["chat.rst"]
|
||||
documents = [
|
||||
documents = []
|
||||
documents.append(
|
||||
Document(
|
||||
document_id=f"num-{i}",
|
||||
content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
|
||||
mime_type="text/plain",
|
||||
document_id="nba_wiki",
|
||||
content="The NBA was created on August 3, 1949, with the merger of the Basketball Association of America (BAA) and the National Basketball League (NBL).",
|
||||
metadata={},
|
||||
)
|
||||
for i, url in enumerate(urls)
|
||||
]
|
||||
)
|
||||
documents.append(
|
||||
Document(
|
||||
document_id="perplexity_wiki",
|
||||
content="""Perplexity the company was founded in 2022 by Aravind Srinivas, Andy Konwinski, Denis Yarats and Johnny Ho, engineers with backgrounds in back-end systems, artificial intelligence (AI) and machine learning:
|
||||
|
||||
Srinivas, the CEO, worked at OpenAI as an AI researcher.
|
||||
Konwinski was among the founding team at Databricks.
|
||||
Yarats, the CTO, was an AI research scientist at Meta.
|
||||
Ho, the CSO, worked as an engineer at Quora, then as a quantitative trader on Wall Street.[5]""",
|
||||
metadata={},
|
||||
)
|
||||
)
|
||||
vector_db_id = f"test-vector-db-{uuid4()}"
|
||||
llama_stack_client.vector_dbs.register(
|
||||
vector_db_id=vector_db_id,
|
||||
|
@ -528,7 +538,7 @@ def test_rag_and_code_agent(llama_stack_client, agent_config):
|
|||
**agent_config,
|
||||
"toolgroups": [
|
||||
dict(
|
||||
name="builtin::rag",
|
||||
name="builtin::rag/knowledge_search",
|
||||
args={"vector_db_ids": [vector_db_id]},
|
||||
),
|
||||
"builtin::code_interpreter",
|
||||
|
@ -546,24 +556,34 @@ def test_rag_and_code_agent(llama_stack_client, agent_config):
|
|||
"Here is a csv file, can you describe it?",
|
||||
[inflation_doc],
|
||||
"code_interpreter",
|
||||
"",
|
||||
),
|
||||
(
|
||||
"What are the top 5 topics that were explained? Only list succinct bullet points.",
|
||||
"when was Perplexity the company founded?",
|
||||
[],
|
||||
"query_from_memory",
|
||||
"knowledge_search",
|
||||
"2022",
|
||||
),
|
||||
(
|
||||
"when was the nba created?",
|
||||
[],
|
||||
"knowledge_search",
|
||||
"1949",
|
||||
),
|
||||
]
|
||||
|
||||
for prompt, docs, tool_name in user_prompts:
|
||||
for prompt, docs, tool_name, expected_kw in user_prompts:
|
||||
session_id = agent.create_session(f"test-session-{uuid4()}")
|
||||
response = agent.create_turn(
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
session_id=session_id,
|
||||
documents=docs,
|
||||
stream=False,
|
||||
)
|
||||
logs = [str(log) for log in EventLogger().log(response) if log is not None]
|
||||
logs_str = "".join(logs)
|
||||
assert f"Tool:{tool_name}" in logs_str
|
||||
tool_execution_step = next(step for step in response.steps if step.step_type == "tool_execution")
|
||||
assert tool_execution_step.tool_calls[0].tool_name == tool_name
|
||||
if expected_kw:
|
||||
assert expected_kw in response.output_message.content.lower()
|
||||
|
||||
|
||||
def test_create_turn_response(llama_stack_client, agent_config):
|
||||
|
|
4
uv.lock
generated
4
uv.lock
generated
|
@ -1,5 +1,4 @@
|
|||
version = 1
|
||||
revision = 1
|
||||
requires-python = ">=3.10"
|
||||
resolution-markers = [
|
||||
"(python_full_version < '3.11' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version < '3.11' and sys_platform != 'darwin' and sys_platform != 'linux')",
|
||||
|
@ -913,6 +912,7 @@ docs = [
|
|||
{ name = "sphinxcontrib-mermaid" },
|
||||
{ name = "sphinxcontrib-redoc" },
|
||||
{ name = "sphinxcontrib-video" },
|
||||
{ name = "tomli" },
|
||||
]
|
||||
test = [
|
||||
{ name = "aiosqlite" },
|
||||
|
@ -971,13 +971,13 @@ requires-dist = [
|
|||
{ name = "sphinxcontrib-redoc", marker = "extra == 'docs'" },
|
||||
{ name = "sphinxcontrib-video", marker = "extra == 'docs'" },
|
||||
{ name = "termcolor" },
|
||||
{ name = "tomli", marker = "extra == 'docs'" },
|
||||
{ name = "torch", marker = "extra == 'test'", specifier = ">=2.6.0", index = "https://download.pytorch.org/whl/cpu" },
|
||||
{ name = "torchvision", marker = "extra == 'test'", specifier = ">=0.21.0", index = "https://download.pytorch.org/whl/cpu" },
|
||||
{ name = "types-requests", marker = "extra == 'dev'" },
|
||||
{ name = "types-setuptools", marker = "extra == 'dev'" },
|
||||
{ name = "uvicorn", marker = "extra == 'dev'" },
|
||||
]
|
||||
provides-extras = ["dev", "test", "docs"]
|
||||
|
||||
[[package]]
|
||||
name = "llama-stack-client"
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue