Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-13 05:17:26 +00:00)

Merge branch 'main' into export_agent_dataset

Commit: 48174b5422

16 changed files with 210 additions and 213 deletions

@@ -803,7 +803,7 @@
 }
 ],
 "source": [
-"model_id = \"meta-llama/Llama-3.1-70B-Instruct\"\n",
+"model_id = \"meta-llama/Llama-3.3-70B-Instruct\"\n",
 "\n",
 "model_id\n"
 ]

@@ -1688,7 +1688,7 @@
 " enable_session_persistence=False,\n",
 " toolgroups = [\n",
 " {\n",
-" \"name\": \"builtin::rag\",\n",
+" \"name\": \"builtin::rag/knowledge_search\",\n",
 " \"args\" : {\n",
 " \"vector_db_ids\": [vector_db_id],\n",
 " }\n",

@@ -3,6 +3,8 @@
 {
 "cell_type": "markdown",
 "source": [
+"[](https://colab.research.google.com/github/meta-llama/llama-stack/blob/main/docs/notebooks/Alpha_Llama_Stack_Post_Training.ipynb)\n",
+"\n",
 "# [Alpha] Llama Stack Post Training\n",
 "This notebook will use a real world problem (improve LLM as tax preparer) to walk through the main sets of APIs we offer with Llama stack for post training to improve the LLM performance for agentic apps (We support supervised finetune now, RLHF and knowledge distillation will come soon!).\n",
 "\n",

@@ -64,7 +66,7 @@
 "output_type": "stream",
 "name": "stdout",
 "text": [
-"Collecting git+https://github.com/meta-llama/llama-stack.git@hf_format_checkpointer\n",
+"Collecting git+https://github.com/meta-llama/llama-stack.git\n",
 " Cloning https://github.com/meta-llama/llama-stack.git (to revision hf_format_checkpointer) to /tmp/pip-req-build-j_1bxqzm\n",
 " Running command git clone --filter=blob:none --quiet https://github.com/meta-llama/llama-stack.git /tmp/pip-req-build-j_1bxqzm\n",
 " Running command git checkout -b hf_format_checkpointer --track origin/hf_format_checkpointer\n",

@@ -76,7 +78,7 @@
 }
 ],
 "source": [
-"!pip install git+https://github.com/meta-llama/llama-stack.git@hf_format_checkpointer"
+"!pip install git+https://github.com/meta-llama/llama-stack.git #TODO: update this after the next pkg release"
 ]
 },
 {

@@ -7,12 +7,12 @@ Each agent turn follows these key steps:
 1. **Initial Safety Check**: The user's input is first screened through configured safety shields
 
 2. **Context Retrieval**:
-   - If RAG is enabled, the agent queries relevant documents from memory banks
-   - For new documents, they are first inserted into the memory bank
-   - Retrieved context is augmented to the user's prompt
+   - If RAG is enabled, the agent can choose to query relevant documents from memory banks. You can use the `instructions` field to steer the agent.
+   - For new documents, they are first inserted into the memory bank.
+   - Retrieved context is provided to the LLM as a tool response in the message history.
 
 3. **Inference Loop**: The agent enters its main execution loop:
-   - The LLM receives the augmented prompt (with context and/or previous tool outputs)
+   - The LLM receives a user prompt (with previous tool outputs)
   - The LLM generates a response, potentially with tool calls
   - If tool calls are present:
     - Tool inputs are safety-checked

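The third bullet above is the substantive change in this hunk: retrieved context now reaches the model as a tool response in the message history instead of being spliced into the user prompt. A minimal sketch of what one RAG-enabled turn's history might look like under the new flow (role names and shapes here are illustrative, not the exact llama-stack wire format):

```python
# Hypothetical message history for a single turn after this change.
messages = [
    {"role": "user", "content": "What attention type does Llama3-8B use?"},
    # The model chooses to call the knowledge_search tool...
    {"role": "assistant", "tool_calls": [
        {"tool_name": "knowledge_search", "arguments": {"query": "Llama3-8B attention type"}},
    ]},
    # ...and the retrieved chunks come back as a tool response in history.
    {"role": "tool", "content": "knowledge_search tool found 5 chunks:\nBEGIN of knowledge_search tool results.\n..."},
]
```
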
@@ -40,19 +40,16 @@ sequenceDiagram
     S->>E: Input Safety Check
     deactivate S
 
-    E->>M: 2.1 Query Context
-    M-->>E: 2.2 Retrieved Documents
-
     loop Inference Loop
-        E->>L: 3.1 Augment with Context
-        L-->>E: 3.2 Response (with/without tool calls)
+        E->>L: 2.1 Augment with Context
+        L-->>E: 2.2 Response (with/without tool calls)
 
         alt Has Tool Calls
             E->>S: Check Tool Input
-            S->>T: 4.1 Execute Tool
-            T-->>E: 4.2 Tool Response
-            E->>L: 5.1 Tool Response
-            L-->>E: 5.2 Synthesized Response
+            S->>T: 3.1 Execute Tool
+            T-->>E: 3.2 Tool Response
+            E->>L: 4.1 Tool Response
+            L-->>E: 4.2 Synthesized Response
         end
 
         opt Stop Conditions

@@ -64,7 +61,7 @@ sequenceDiagram
     end
 
     E->>S: Output Safety Check
-    S->>U: 6. Final Response
+    S->>U: 5. Final Response
 ```
 
 Each step in this process can be monitored and controlled through configurations. Here's an example that demonstrates monitoring the agent's execution:

@@ -77,7 +74,10 @@ agent_config = AgentConfig(
     instructions="You are a helpful assistant",
     # Enable both RAG and tool usage
     toolgroups=[
-        {"name": "builtin::rag", "args": {"vector_db_ids": ["my_docs"]}},
+        {
+            "name": "builtin::rag/knowledge_search",
+            "args": {"vector_db_ids": ["my_docs"]},
+        },
         "builtin::code_interpreter",
     ],
     # Configure safety

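The `"builtin::rag/knowledge_search"` string is the new `<toolgroup>/<tool>` convention: the part before the slash names the registered toolgroup, the optional part after it selects a single tool within that group. A minimal sketch of the two accepted forms, assuming the `AgentConfig` usage shown above (the vector DB id is illustrative):

```python
toolgroups = [
    # Whole group: every tool registered under builtin::rag is exposed.
    {"name": "builtin::rag", "args": {"vector_db_ids": ["my_docs"]}},
    # Single tool: only knowledge_search from builtin::rag is exposed.
    {"name": "builtin::rag/knowledge_search", "args": {"vector_db_ids": ["my_docs"]}},
]
```
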
@@ -91,7 +91,7 @@ agent_config = AgentConfig(
     enable_session_persistence=False,
     toolgroups=[
         {
-            "name": "builtin::rag",
+            "name": "builtin::rag/knowledge_search",
             "args": {
                 "vector_db_ids": [vector_db_id],
             },

@@ -13,6 +13,13 @@
 # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
 
 from docutils import nodes
+import tomli  # Import tomli for TOML parsing
+from pathlib import Path
+
+# Read version from pyproject.toml
+with Path(__file__).parent.parent.parent.joinpath("pyproject.toml").open("rb") as f:
+    pyproject = tomli.load(f)
+llama_stack_version = pyproject["project"]["version"]
 
 project = "llama-stack"
 copyright = "2025, Meta"

@@ -66,6 +73,7 @@ myst_enable_extensions = [
 
 myst_substitutions = {
     "docker_hub": "https://hub.docker.com/repository/docker/llamastack",
+    "llama_stack_version": llama_stack_version,
 }
 
 suppress_warnings = ['myst.header']

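Taken together, these two hunks make the package version available to every MyST page at build time. A minimal sketch of the resulting flow (paths as in the diff; the `index.md` line is the one changed further down):

```python
# conf.py (sketch): the version is read once from pyproject.toml...
import tomli
from pathlib import Path

with Path(__file__).parent.parent.parent.joinpath("pyproject.toml").open("rb") as f:
    llama_stack_version = tomli.load(f)["project"]["version"]

# ...and exposed as a MyST substitution, so markdown pages can write
# "Llama Stack {{ llama_stack_version }} is now available!" instead of
# hard-coding a release number like 0.1.4.
myst_substitutions = {"llama_stack_version": llama_stack_version}
```
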
@@ -243,7 +243,7 @@ agent_config = AgentConfig(
     # Define tools available to the agent
     toolgroups=[
         {
-            "name": "builtin::rag",
+            "name": "builtin::rag/knowledge_search",
             "args": {
                 "vector_db_ids": [vector_db_id],
             },

@@ -1,8 +1,7 @@
-
 ```{admonition} News
 :class: tip
 
-Llama Stack 0.1.4 is now available! See the [release notes](https://github.com/meta-llama/llama-stack/releases/tag/v0.1.4) for more details.
+Llama Stack {{ llama_stack_version }} is now available! See the [release notes](https://github.com/meta-llama/llama-stack/releases/tag/v{{ llama_stack_version }}) for more details.
 ```
 
 # Llama Stack

@@ -441,7 +441,7 @@ class ToolRuntimeRouter(ToolRuntime):
         vector_db_ids: List[str],
         query_config: Optional[RAGQueryConfig] = None,
     ) -> RAGQueryResult:
-        return await self.routing_table.get_provider_impl("query_from_memory").query(
+        return await self.routing_table.get_provider_impl("knowledge_search").query(
             content, vector_db_ids, query_config
         )
 

@@ -33,7 +33,7 @@ class DistributionRegistry(Protocol):
 
 
 REGISTER_PREFIX = "distributions:registry"
-KEY_VERSION = "v7"
+KEY_VERSION = "v8"
 KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}"
 
 

@@ -132,7 +132,7 @@ def rag_chat_page():
         },
         toolgroups=[
             dict(
-                name="builtin::rag",
+                name="builtin::rag/knowledge_search",
                 args={
                     "vector_db_ids": [vector_db_id for vector_db_id in selected_vector_dbs],
                 },

@@ -17,7 +17,6 @@ from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple
 from urllib.parse import urlparse
 
 import httpx
-from pydantic import TypeAdapter
 
 from llama_stack.apis.agents import (
     AgentConfig,

@@ -62,7 +61,7 @@ from llama_stack.apis.inference import (
     UserMessage,
 )
 from llama_stack.apis.safety import Safety
-from llama_stack.apis.tools import RAGDocument, RAGQueryConfig, ToolGroups, ToolInvocationResult, ToolRuntime
+from llama_stack.apis.tools import RAGDocument, ToolGroups, ToolInvocationResult, ToolRuntime
 from llama_stack.apis.vector_io import VectorIO
 from llama_stack.models.llama.datatypes import (
     BuiltinTool,

@@ -70,7 +69,6 @@ from llama_stack.models.llama.datatypes import (
     ToolParamDefinition,
 )
 from llama_stack.providers.utils.kvstore import KVStore
-from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content
 from llama_stack.providers.utils.telemetry import tracing
 
 from .persistence import AgentPersistence

@@ -84,7 +82,7 @@ def make_random_string(length: int = 8):
 
 
 TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})")
-MEMORY_QUERY_TOOL = "query_from_memory"
+MEMORY_QUERY_TOOL = "knowledge_search"
 WEB_SEARCH_TOOL = "web_search"
 RAG_TOOL_GROUP = "builtin::rag"
 

@@ -499,17 +497,11 @@ class ChatAgent(ShieldRunnerMixin):
         # TODO: simplify all of this code, it can be simpler
         toolgroup_args = {}
         toolgroups = set()
-        for toolgroup in self.agent_config.toolgroups:
+        for toolgroup in self.agent_config.toolgroups + (toolgroups_for_turn or []):
             if isinstance(toolgroup, AgentToolGroupWithArgs):
-                toolgroups.add(toolgroup.name)
-                toolgroup_args[toolgroup.name] = toolgroup.args
-            else:
-                toolgroups.add(toolgroup)
-        if toolgroups_for_turn:
-            for toolgroup in toolgroups_for_turn:
-                if isinstance(toolgroup, AgentToolGroupWithArgs):
-                    toolgroups.add(toolgroup.name)
-                    toolgroup_args[toolgroup.name] = toolgroup.args
+                tool_group_name, tool_name = self._parse_toolgroup_name(toolgroup.name)
+                toolgroups.add(tool_group_name)
+                toolgroup_args[tool_group_name] = toolgroup.args
             else:
                 toolgroups.add(toolgroup)
 

@@ -517,93 +509,6 @@ class ChatAgent(ShieldRunnerMixin):
         if documents:
             await self.handle_documents(session_id, documents, input_messages, tool_defs)
 
-        if RAG_TOOL_GROUP in toolgroups and len(input_messages) > 0:
-            with tracing.span(MEMORY_QUERY_TOOL) as span:
-                step_id = str(uuid.uuid4())
-                yield AgentTurnResponseStreamChunk(
-                    event=AgentTurnResponseEvent(
-                        payload=AgentTurnResponseStepStartPayload(
-                            step_type=StepType.tool_execution.value,
-                            step_id=step_id,
-                        )
-                    )
-                )
-
-                args = toolgroup_args.get(RAG_TOOL_GROUP, {})
-                vector_db_ids = args.get("vector_db_ids", [])
-                query_config = args.get("query_config")
-                if query_config:
-                    query_config = TypeAdapter(RAGQueryConfig).validate_python(query_config)
-                else:
-                    # handle someone passing an empty dict
-                    query_config = RAGQueryConfig()
-
-                session_info = await self.storage.get_session_info(session_id)
-
-                # if the session has a memory bank id, let the memory tool use it
-                if session_info.vector_db_id:
-                    vector_db_ids.append(session_info.vector_db_id)
-
-                yield AgentTurnResponseStreamChunk(
-                    event=AgentTurnResponseEvent(
-                        payload=AgentTurnResponseStepProgressPayload(
-                            step_type=StepType.tool_execution.value,
-                            step_id=step_id,
-                            delta=ToolCallDelta(
-                                parse_status=ToolCallParseStatus.succeeded,
-                                tool_call=ToolCall(
-                                    call_id="",
-                                    tool_name=MEMORY_QUERY_TOOL,
-                                    arguments={},
-                                ),
-                            ),
-                        )
-                    )
-                )
-                result = await self.tool_runtime_api.rag_tool.query(
-                    content=concat_interleaved_content([msg.content for msg in input_messages]),
-                    vector_db_ids=vector_db_ids,
-                    query_config=query_config,
-                )
-                retrieved_context = result.content
-
-                yield AgentTurnResponseStreamChunk(
-                    event=AgentTurnResponseEvent(
-                        payload=AgentTurnResponseStepCompletePayload(
-                            step_type=StepType.tool_execution.value,
-                            step_id=step_id,
-                            step_details=ToolExecutionStep(
-                                step_id=step_id,
-                                turn_id=turn_id,
-                                tool_calls=[
-                                    ToolCall(
-                                        call_id="",
-                                        tool_name=MEMORY_QUERY_TOOL,
-                                        arguments={},
-                                    )
-                                ],
-                                tool_responses=[
-                                    ToolResponse(
-                                        call_id="",
-                                        tool_name=MEMORY_QUERY_TOOL,
-                                        content=retrieved_context or [],
-                                        metadata=result.metadata,
-                                    )
-                                ],
-                            ),
-                        )
-                    )
-                )
-                span.set_attribute("input", [m.model_dump_json() for m in input_messages])
-                span.set_attribute("output", retrieved_context)
-                span.set_attribute("tool_name", MEMORY_QUERY_TOOL)
-
-                # append retrieved_context to the last user message
-                for message in input_messages[::-1]:
-                    if isinstance(message, UserMessage):
-                        message.context = retrieved_context
-                        break
-
         output_attachments = []
 
         n_iter = 0

@@ -631,9 +536,7 @@ class ChatAgent(ShieldRunnerMixin):
         async for chunk in await self.inference_api.chat_completion(
             self.agent_config.model,
             input_messages,
-            tools=[
-                tool for tool in tool_defs.values() if tool_to_group.get(tool.tool_name, None) != RAG_TOOL_GROUP
-            ],
+            tools=tool_defs,
             tool_prompt_format=self.agent_config.tool_config.tool_prompt_format,
             response_format=self.agent_config.response_format,
             stream=True,

@@ -837,7 +740,7 @@ class ChatAgent(ShieldRunnerMixin):
                         )
                     ],
                     started_at=tool_execution_start_time,
-                    completed_at=datetime.now(),
+                    completed_at=datetime.now().astimezone().isoformat(),
                 ),
             )
         )

@@ -845,8 +748,9 @@ class ChatAgent(ShieldRunnerMixin):
 
             # TODO: add tool-input touchpoint and a "start" event for this step also
             # but that needs a lot more refactoring of Tool code potentially
-            if out_attachment := _interpret_content_as_attachment(result_message.content):
+            if (type(result_message.content) is str) and (
+                out_attachment := _interpret_content_as_attachment(result_message.content)
+            ):
                 # NOTE: when we push this message back to the model, the model may ignore the
                 # attached file path etc. since the model is trained to only provide a user message
                 # with the summary. We keep all generated attachments and then attach them to final message

@@ -858,7 +762,7 @@ class ChatAgent(ShieldRunnerMixin):
 
     async def _get_tool_defs(
         self, toolgroups_for_turn: Optional[List[AgentToolGroup]] = None
-    ) -> Tuple[Dict[str, ToolDefinition], Dict[str, str]]:
+    ) -> Tuple[List[ToolDefinition], Dict[str, str]]:
         # Determine which tools to include
         agent_config_toolgroups = set(
             (toolgroup.name if isinstance(toolgroup, AgentToolGroupWithArgs) else toolgroup)

@@ -873,13 +777,13 @@ class ChatAgent(ShieldRunnerMixin):
             }
         )
 
-        tool_def_map = {}
+        tool_name_to_def = {}
         tool_to_group = {}
 
        for tool_def in self.agent_config.client_tools:
-            if tool_def_map.get(tool_def.name, None):
+            if tool_name_to_def.get(tool_def.name, None):
                 raise ValueError(f"Tool {tool_def.name} already exists")
-            tool_def_map[tool_def.name] = ToolDefinition(
+            tool_name_to_def[tool_def.name] = ToolDefinition(
                 tool_name=tool_def.name,
                 description=tool_def.description,
                 parameters={

@@ -893,10 +797,17 @@ class ChatAgent(ShieldRunnerMixin):
                 },
             )
             tool_to_group[tool_def.name] = "__client_tools__"
-        for toolgroup_name in agent_config_toolgroups:
-            if toolgroup_name not in toolgroups_for_turn_set:
+        for toolgroup_name_with_maybe_tool_name in agent_config_toolgroups:
+            if toolgroup_name_with_maybe_tool_name not in toolgroups_for_turn_set:
                 continue
 
+            toolgroup_name, tool_name = self._parse_toolgroup_name(toolgroup_name_with_maybe_tool_name)
             tools = await self.tool_groups_api.list_tools(toolgroup_id=toolgroup_name)
+            if tool_name is not None and not any(tool.identifier == tool_name for tool in tools.data):
+                raise ValueError(
+                    f"Tool {tool_name} not found in toolgroup {toolgroup_name}. Available tools: {', '.join([tool.identifier for tool in tools.data])}"
+                )
+
             for tool_def in tools.data:
                 if toolgroup_name.startswith("builtin") and toolgroup_name != RAG_TOOL_GROUP:
                     tool_name = tool_def.identifier

@@ -906,10 +817,10 @@ class ChatAgent(ShieldRunnerMixin):
                     else:
                         built_in_type = BuiltinTool(tool_name)
 
-                    if tool_def_map.get(built_in_type, None):
+                    if tool_name_to_def.get(built_in_type, None):
                         raise ValueError(f"Tool {built_in_type} already exists")
 
-                    tool_def_map[built_in_type] = ToolDefinition(
+                    tool_name_to_def[built_in_type] = ToolDefinition(
                         tool_name=built_in_type,
                         description=tool_def.description,
                         parameters={

@@ -925,9 +836,10 @@ class ChatAgent(ShieldRunnerMixin):
                     tool_to_group[built_in_type] = tool_def.toolgroup_id
                     continue
 
-                if tool_def_map.get(tool_def.identifier, None):
+                if tool_name_to_def.get(tool_def.identifier, None):
                     raise ValueError(f"Tool {tool_def.identifier} already exists")
-                tool_def_map[tool_def.identifier] = ToolDefinition(
+                if tool_name in (None, tool_def.identifier):
+                    tool_name_to_def[tool_def.identifier] = ToolDefinition(
                         tool_name=tool_def.identifier,
                         description=tool_def.description,
                         parameters={

@@ -942,7 +854,24 @@ class ChatAgent(ShieldRunnerMixin):
                     )
                 tool_to_group[tool_def.identifier] = tool_def.toolgroup_id
 
-        return tool_def_map, tool_to_group
+        return list(tool_name_to_def.values()), tool_to_group
 
+    def _parse_toolgroup_name(self, toolgroup_name_with_maybe_tool_name: str) -> tuple[str, Optional[str]]:
+        """Parse a toolgroup name into its components.
+
+        Args:
+            toolgroup_name: The toolgroup name to parse (e.g. "builtin::rag/knowledge_search")
+
+        Returns:
+            A tuple of (tool_type, tool_group, tool_name)
+        """
+        split_names = toolgroup_name_with_maybe_tool_name.split("/")
+        if len(split_names) == 2:
+            # e.g. "builtin::rag"
+            tool_group, tool_name = split_names
+        else:
+            tool_group, tool_name = split_names[0], None
+        return tool_group, tool_name
+
     async def handle_documents(
         self,

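A quick illustration of what the new helper returns (the body below mirrors `_parse_toolgroup_name` as added above; the assertions are illustrative inputs, not exhaustive):

```python
def parse_toolgroup_name(name: str) -> tuple[str, str | None]:
    # An optional "/<tool>" suffix selects a single tool from the group.
    split_names = name.split("/")
    if len(split_names) == 2:
        tool_group, tool_name = split_names
    else:
        tool_group, tool_name = split_names[0], None
    return tool_group, tool_name

assert parse_toolgroup_name("builtin::rag/knowledge_search") == ("builtin::rag", "knowledge_search")
assert parse_toolgroup_name("builtin::rag") == ("builtin::rag", None)
```
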
@@ -951,8 +880,8 @@ class ChatAgent(ShieldRunnerMixin):
         input_messages: List[Message],
         tool_defs: Dict[str, ToolDefinition],
     ) -> None:
-        memory_tool = tool_defs.get(MEMORY_QUERY_TOOL, None)
-        code_interpreter_tool = tool_defs.get(BuiltinTool.code_interpreter, None)
+        memory_tool = any(tool_def.tool_name == MEMORY_QUERY_TOOL for tool_def in tool_defs)
+        code_interpreter_tool = any(tool_def.tool_name == BuiltinTool.code_interpreter for tool_def in tool_defs)
         content_items = []
         url_items = []
         pattern = re.compile("^(https?://|file://|data:)")

@@ -1072,7 +1001,11 @@ async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessage:
         else:
             raise ValueError(f"Unsupported URL {url}")
 
-        content.append(TextContentItem(text=f'# There is a file accessible to you at "{filepath}"\n'))
+        content.append(
+            TextContentItem(
+                text=f'# User provided a file accessible to you at "{filepath}"\nYou can use code_interpreter to load and inspect it.'
+            )
+        )
 
     return ToolResponseMessage(
         call_id="",

@@ -7,6 +7,7 @@
 import json
 import os
 import sqlite3
+import threading
 from datetime import datetime
 
 from opentelemetry.sdk.trace import SpanProcessor

@@ -17,14 +18,18 @@ class SQLiteSpanProcessor(SpanProcessor):
     def __init__(self, conn_string):
         """Initialize the SQLite span processor with a connection string."""
         self.conn_string = conn_string
-        self.conn = None
+        self._local = threading.local()  # Thread-local storage for connections
         self.setup_database()
 
-    def _get_connection(self) -> sqlite3.Connection:
-        """Get the database connection."""
-        if self.conn is None:
-            self.conn = sqlite3.connect(self.conn_string, check_same_thread=False)
-        return self.conn
+    def _get_connection(self):
+        """Get a thread-local database connection."""
+        if not hasattr(self._local, "conn"):
+            try:
+                self._local.conn = sqlite3.connect(self.conn_string)
+            except Exception as e:
+                print(f"Error connecting to SQLite database: {e}")
+                raise e
+        return self._local.conn
 
     def setup_database(self):
         """Create the necessary tables if they don't exist."""

@@ -168,9 +173,14 @@ class SQLiteSpanProcessor(SpanProcessor):
 
     def shutdown(self):
         """Cleanup any resources."""
-        if self.conn:
-            self.conn.close()
-            self.conn = None
+        # We can't access other threads' connections, so we just close our own
+        if hasattr(self._local, "conn"):
+            try:
+                self._local.conn.close()
+            except Exception as e:
+                print(f"Error closing SQLite connection: {e}")
+            finally:
+                del self._local.conn
 
     def force_flush(self, timeout_millis=30000):
         """Force export of spans."""

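The motivation here: a sqlite3 connection must not be shared across threads unless `check_same_thread=False` is passed, and even then concurrent use needs external locking. The pattern being adopted, shown in isolation (a minimal sketch; the class name below is illustrative, not the file's API):

```python
import sqlite3
import threading


class ThreadLocalConnections:
    """Each thread lazily opens, uses, and closes its own connection."""

    def __init__(self, conn_string: str):
        self._conn_string = conn_string
        self._local = threading.local()  # per-thread attribute namespace

    def _get_connection(self) -> sqlite3.Connection:
        if not hasattr(self._local, "conn"):  # first use on this thread
            self._local.conn = sqlite3.connect(self._conn_string)
        return self._local.conn

    def shutdown(self) -> None:
        # Only the calling thread's connection is reachable here.
        if hasattr(self._local, "conn"):
            try:
                self._local.conn.close()
            finally:
                del self._local.conn
```
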
@@ -10,6 +10,8 @@ import secrets
 import string
 from typing import Any, Dict, List, Optional
 
+from pydantic import TypeAdapter
+
 from llama_stack.apis.common.content_types import (
     URL,
     InterleavedContent,

@@ -23,6 +25,7 @@ from llama_stack.apis.tools import (
     RAGToolRuntime,
     ToolDef,
     ToolInvocationResult,
+    ToolParameter,
     ToolRuntime,
 )
 from llama_stack.apis.vector_io import QueryChunksResponse, VectorIO

@@ -120,9 +123,14 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
         # sort by score
         chunks, scores = zip(*sorted(zip(chunks, scores, strict=False), key=lambda x: x[1], reverse=True), strict=False)
         chunks = chunks[: query_config.max_chunks]
 
         tokens = 0
-        picked = []
-        for c in chunks:
+        picked = [
+            TextContentItem(
+                text=f"knowledge_search tool found {len(chunks)} chunks:\nBEGIN of knowledge_search tool results.\n"
+            )
+        ]
+        for i, c in enumerate(chunks):
             metadata = c.metadata
             tokens += metadata["token_count"]
             if tokens > query_config.max_tokens_in_context:

@@ -132,20 +140,13 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
                 break
             picked.append(
                 TextContentItem(
-                    text=f"id:{metadata['document_id']}; content:{c.content}",
+                    text=f"Result {i + 1}:\nDocument_id:{metadata['document_id'][:5]}\nContent: {c.content}\n",
                 )
             )
+        picked.append(TextContentItem(text="END of knowledge_search tool results.\n"))
 
         return RAGQueryResult(
-            content=[
-                TextContentItem(
-                    text="Here are the retrieved documents for relevant context:\n=== START-RETRIEVED-CONTEXT ===\n",
-                ),
-                *picked,
-                TextContentItem(
-                    text="\n=== END-RETRIEVED-CONTEXT ===\n",
-                ),
-            ],
+            content=picked,
             metadata={
                 "document_ids": [c.metadata["document_id"] for c in chunks[: len(picked)]],
             },

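With, say, two chunks surviving the token budget, the tool response content produced after this hunk would render roughly as follows (document ids and contents are illustrative):

```
knowledge_search tool found 2 chunks:
BEGIN of knowledge_search tool results.
Result 1:
Document_id:num-0
Content: Llama3-8B uses grouped-query attention ...
Result 2:
Document_id:num-1
Content: ...
END of knowledge_search tool results.
```
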
@@ -158,17 +159,40 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
         # by the LLM. The method is only implemented so things like /tools can list without
         # encountering fatals.
         return [
-            ToolDef(
-                name="query_from_memory",
-                description="Retrieve context from memory",
-            ),
             ToolDef(
                 name="insert_into_memory",
                 description="Insert documents into memory",
             ),
+            ToolDef(
+                name="knowledge_search",
+                description="Search for information in a database.",
+                parameters=[
+                    ToolParameter(
+                        name="query",
+                        description="The query to search for. Can be a natural language sentence or keywords.",
+                        parameter_type="string",
+                    ),
+                ],
+            ),
         ]
 
     async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> ToolInvocationResult:
-        raise RuntimeError(
-            "This toolgroup should not be called generically but only through specific methods of the RAGToolRuntime protocol"
+        vector_db_ids = kwargs.get("vector_db_ids", [])
+        query_config = kwargs.get("query_config")
+        if query_config:
+            query_config = TypeAdapter(RAGQueryConfig).validate_python(query_config)
+        else:
+            # handle someone passing an empty dict
+            query_config = RAGQueryConfig()
+
+        query = kwargs["query"]
+        result = await self.query(
+            content=query,
+            vector_db_ids=vector_db_ids,
+            query_config=query_config,
+        )
+
+        return ToolInvocationResult(
+            content=result.content,
+            metadata=result.metadata,
         )

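A hedged sketch of calling the tool directly after this change (the `runtime` instance and vector DB id are assumptions for illustration; the kwargs keys follow the hunk above):

```python
# `runtime` is assumed to be a MemoryToolRuntimeImpl wired to a vector store.
result = await runtime.invoke_tool(
    tool_name="knowledge_search",
    kwargs={
        "query": "when was the nba created?",  # required
        "vector_db_ids": ["my_docs"],          # defaults to []
        # "query_config" is optional; a missing or empty value falls back
        # to RAGQueryConfig() defaults, as in the hunk above.
    },
)
print(result.content)   # BEGIN/END-delimited knowledge_search results
print(result.metadata)  # {"document_ids": [...]}
```
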
@@ -74,6 +74,7 @@ docs = [
     "sphinxcontrib.redoc",
     "sphinxcontrib.video",
     "sphinxcontrib.mermaid",
+    "tomli",
 ]
 
 [project.urls]

@@ -96,7 +96,7 @@ def agent_config(llama_stack_client, text_model_id):
         sampling_params={
             "strategy": {
                 "type": "top_p",
-                "temperature": 1.0,
+                "temperature": 0.0001,
                 "top_p": 0.9,
             },
         },

@@ -441,7 +441,8 @@ def xtest_override_system_message_behavior(llama_stack_client, agent_config):
     assert "get_boiling_point" in logs_str
 
 
-def test_rag_agent(llama_stack_client, agent_config):
+@pytest.mark.parametrize("rag_tool_name", ["builtin::rag/knowledge_search", "builtin::rag"])
+def test_rag_agent(llama_stack_client, agent_config, rag_tool_name):
     urls = ["chat.rst", "llama3.rst", "memory_optimizations.rst", "lora_finetune.rst"]
     documents = [
         Document(

@@ -469,7 +470,7 @@ def test_rag_agent(llama_stack_client, agent_config):
             **agent_config,
             "toolgroups": [
                 dict(
-                    name="builtin::rag",
+                    name=rag_tool_name,
                     args={
                         "vector_db_ids": [vector_db_id],
                     },

@@ -483,10 +484,6 @@ def test_rag_agent(llama_stack_client, agent_config):
             "Instead of the standard multi-head attention, what attention type does Llama3-8B use?",
             "grouped",
         ),
-        (
-            "What `tune` command to use for getting access to Llama3-8B-Instruct ?",
-            "download",
-        ),
     ]
     for prompt, expected_kw in user_prompts:
         response = rag_agent.create_turn(

@@ -496,23 +493,36 @@ def test_rag_agent(llama_stack_client, agent_config):
         )
         # rag is called
         tool_execution_step = next(step for step in response.steps if step.step_type == "tool_execution")
-        assert tool_execution_step.tool_calls[0].tool_name == "query_from_memory"
+        assert tool_execution_step.tool_calls[0].tool_name == "knowledge_search"
         # document ids are present in metadata
-        assert "num-0" in tool_execution_step.tool_responses[0].metadata["document_ids"]
+        assert all(
+            doc_id.startswith("num-") for doc_id in tool_execution_step.tool_responses[0].metadata["document_ids"]
+        )
+        if expected_kw:
             assert expected_kw in response.output_message.content.lower()
 
 
 def test_rag_and_code_agent(llama_stack_client, agent_config):
-    urls = ["chat.rst"]
-    documents = [
+    documents = []
+    documents.append(
         Document(
-            document_id=f"num-{i}",
-            content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
-            mime_type="text/plain",
+            document_id="nba_wiki",
+            content="The NBA was created on August 3, 1949, with the merger of the Basketball Association of America (BAA) and the National Basketball League (NBL).",
             metadata={},
         )
-        for i, url in enumerate(urls)
-    ]
+    )
+    documents.append(
+        Document(
+            document_id="perplexity_wiki",
+            content="""Perplexity the company was founded in 2022 by Aravind Srinivas, Andy Konwinski, Denis Yarats and Johnny Ho, engineers with backgrounds in back-end systems, artificial intelligence (AI) and machine learning:
+
+Srinivas, the CEO, worked at OpenAI as an AI researcher.
+Konwinski was among the founding team at Databricks.
+Yarats, the CTO, was an AI research scientist at Meta.
+Ho, the CSO, worked as an engineer at Quora, then as a quantitative trader on Wall Street.[5]""",
+            metadata={},
+        )
+    )
     vector_db_id = f"test-vector-db-{uuid4()}"
     llama_stack_client.vector_dbs.register(
         vector_db_id=vector_db_id,

@@ -528,7 +538,7 @@ def test_rag_and_code_agent(llama_stack_client, agent_config):
             **agent_config,
             "toolgroups": [
                 dict(
-                    name="builtin::rag",
+                    name="builtin::rag/knowledge_search",
                     args={"vector_db_ids": [vector_db_id]},
                 ),
                 "builtin::code_interpreter",

@@ -546,24 +556,34 @@ def test_rag_and_code_agent(llama_stack_client, agent_config):
             "Here is a csv file, can you describe it?",
             [inflation_doc],
             "code_interpreter",
+            "",
         ),
         (
-            "What are the top 5 topics that were explained? Only list succinct bullet points.",
+            "when was Perplexity the company founded?",
             [],
-            "query_from_memory",
+            "knowledge_search",
+            "2022",
+        ),
+        (
+            "when was the nba created?",
+            [],
+            "knowledge_search",
+            "1949",
         ),
     ]
 
-    for prompt, docs, tool_name in user_prompts:
+    for prompt, docs, tool_name, expected_kw in user_prompts:
         session_id = agent.create_session(f"test-session-{uuid4()}")
         response = agent.create_turn(
             messages=[{"role": "user", "content": prompt}],
             session_id=session_id,
             documents=docs,
+            stream=False,
         )
-        logs = [str(log) for log in EventLogger().log(response) if log is not None]
-        logs_str = "".join(logs)
-        assert f"Tool:{tool_name}" in logs_str
+        tool_execution_step = next(step for step in response.steps if step.step_type == "tool_execution")
+        assert tool_execution_step.tool_calls[0].tool_name == tool_name
+        if expected_kw:
+            assert expected_kw in response.output_message.content.lower()
 
 
 def test_create_turn_response(llama_stack_client, agent_config):

uv.lock (generated file, 4 lines changed)

@@ -1,5 +1,4 @@
 version = 1
-revision = 1
 requires-python = ">=3.10"
 resolution-markers = [
     "(python_full_version < '3.11' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version < '3.11' and sys_platform != 'darwin' and sys_platform != 'linux')",

@@ -913,6 +912,7 @@ docs = [
     { name = "sphinxcontrib-mermaid" },
     { name = "sphinxcontrib-redoc" },
     { name = "sphinxcontrib-video" },
+    { name = "tomli" },
 ]
 test = [
     { name = "aiosqlite" },

@@ -971,13 +971,13 @@ requires-dist = [
     { name = "sphinxcontrib-redoc", marker = "extra == 'docs'" },
     { name = "sphinxcontrib-video", marker = "extra == 'docs'" },
     { name = "termcolor" },
+    { name = "tomli", marker = "extra == 'docs'" },
     { name = "torch", marker = "extra == 'test'", specifier = ">=2.6.0", index = "https://download.pytorch.org/whl/cpu" },
     { name = "torchvision", marker = "extra == 'test'", specifier = ">=0.21.0", index = "https://download.pytorch.org/whl/cpu" },
     { name = "types-requests", marker = "extra == 'dev'" },
     { name = "types-setuptools", marker = "extra == 'dev'" },
     { name = "uvicorn", marker = "extra == 'dev'" },
 ]
-provides-extras = ["dev", "test", "docs"]
 
 [[package]]
 name = "llama-stack-client"