agents to use tools api (#673)

# What does this PR do?

PR #639 introduced the notion of a Tools API and the ability to invoke tools through the API just like any other resource. This PR changes the Agents API to start using the Tools API to invoke tools. Major changes include:
1) Ability to specify tool groups with AgentConfig.
2) The agent fetches the corresponding tool definitions for the specified tools and passes them along to the model.
3) Attachments are now called Documents; their behavior is mostly unchanged from the user's perspective.
4) You can specify args to be injected into a tool call through the agent config. This is especially useful for the memory tool, where you want the tool to operate on a specific memory bank.
5) You can also register tool groups with args, which the agent likewise injects into the tool call (see the sketches after this list).
6) All tests, including client SDK tests, have been migrated to use the new Tools API and fixtures.
7) Telemetry just works with the Tools API because of our trace protocol decorator.
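To make items 1, 4, and 5 concrete, here is a minimal sketch of the new configuration surface. The toolgroup id `builtin::code_interpreter`, the provider id `memory-runtime`, the bank names, and the `tool_groups_api` handle are illustrative assumptions; `builtin::memory`, `AgentToolGroupWithArgs`, the `memory_bank_ids` arg, and the `register_tool_group` signature come from this PR.

```python
from llama_stack.apis.agents import AgentConfig, AgentToolGroupWithArgs

# 1) and 4): specify tool groups on the agent; a group can carry args that are
# injected into every call of that group's tools, e.g. pinning the memory tool
# to a specific memory bank.
agent_config = AgentConfig(
    model="meta-llama/Llama-3.1-8B-Instruct",
    instructions="You are a helpful assistant.",
    toolgroups=[
        "builtin::code_interpreter",  # plain group reference by name (assumed id)
        AgentToolGroupWithArgs(
            name="builtin::memory",
            args={"memory_bank_ids": ["my_docs_bank"]},  # injected into the tool call
        ),
    ],
    enable_session_persistence=False,
)


async def register_memory_group(tool_groups_api) -> None:
    # 5): tool groups can also be registered with args, which the agent then
    # injects into tool calls. `tool_groups_api` stands in for a ToolGroups
    # implementation or client handle (assumption).
    await tool_groups_api.register_tool_group(
        toolgroup_id="builtin::memory",
        provider_id="memory-runtime",  # assumed provider id
        args={"memory_bank_ids": ["shared_bank"]},
    )
```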


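And a sketch of a single turn using the renamed documents parameter and per-turn toolgroups (items 2 and 3). The `agents_api` handle and the example URL are assumptions; `Document`, `create_agent_turn`, and its parameter names come from this PR.

```python
from llama_stack.apis.agents import Document
from llama_stack.apis.inference import UserMessage


async def run_turn(agents_api, agent_id: str, session_id: str) -> None:
    # `agents_api` stands in for an Agents implementation or client handle.
    chunks = await agents_api.create_agent_turn(
        agent_id=agent_id,
        session_id=session_id,
        messages=[UserMessage(content="Summarize the attached report.")],
        # 3) former "attachments" are now passed as documents
        documents=[
            Document(content="https://example.com/report.txt", mime_type="text/plain"),
        ],
        # toolgroups can also be selected per turn
        toolgroups=["builtin::memory"],
        stream=True,
    )
    async for chunk in chunks:
        print(chunk.event.payload.event_type)
```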
## Test Plan
```
pytest -s -v -k fireworks llama_stack/providers/tests/agents/test_agents.py  \
   --safety-shield=meta-llama/Llama-Guard-3-8B \
   --inference-model=meta-llama/Llama-3.1-8B-Instruct

pytest -s -v -k together  llama_stack/providers/tests/tools/test_tools.py \
   --safety-shield=meta-llama/Llama-Guard-3-8B \
   --inference-model=meta-llama/Llama-3.1-8B-Instruct

LLAMA_STACK_CONFIG="/Users/dineshyv/.llama/distributions/llamastack-together/together-run.yaml" pytest -v tests/client-sdk/agents/test_agents.py
```
run.yaml:
https://gist.github.com/dineshyv/0365845ad325e1c2cab755788ccc5994

Notebook:
https://colab.research.google.com/drive/1ck7hXQxRl6UvT-ijNRZ-gMZxH1G3cN2d?usp=sharing
Authored by Dinesh Yeduguru on 2025-01-08 19:01:00 -08:00, committed by GitHub.
Commit a5c57cd381 (parent 596afc6497).
116 changed files with 4959 additions and 2778 deletions


@ -23,6 +23,7 @@
"psycopg2-binary",
"pypdf",
"redis",
"requests",
"scikit-learn",
"scipy",
"sentencepiece",
@ -54,6 +55,7 @@
"psycopg2-binary",
"pypdf",
"redis",
"requests",
"scikit-learn",
"scipy",
"sentencepiece",
@ -86,6 +88,7 @@
"psycopg2-binary",
"pypdf",
"redis",
"requests",
"scikit-learn",
"scipy",
"sentencepiece",
@ -116,6 +119,7 @@
"psycopg2-binary",
"pypdf",
"redis",
"requests",
"scikit-learn",
"scipy",
"sentencepiece",
@ -148,6 +152,7 @@
"psycopg2-binary",
"pypdf",
"redis",
"requests",
"scikit-learn",
"scipy",
"sentencepiece",
@ -181,6 +186,7 @@
"psycopg2-binary",
"pypdf",
"redis",
"requests",
"scikit-learn",
"scipy",
"sentencepiece",
@ -213,6 +219,7 @@
"psycopg2-binary",
"pypdf",
"redis",
"requests",
"scikit-learn",
"scipy",
"sentencepiece",
@ -247,6 +254,7 @@
"psycopg2-binary",
"pypdf",
"redis",
"requests",
"scikit-learn",
"scipy",
"sentence-transformers",
@ -286,6 +294,7 @@
"psycopg2-binary",
"pypdf",
"redis",
"requests",
"scikit-learn",
"scipy",
"sentence-transformers",
@ -319,6 +328,7 @@
"psycopg2-binary",
"pypdf",
"redis",
"requests",
"scikit-learn",
"scipy",
"sentencepiece",
@ -352,6 +362,7 @@
"psycopg2-binary",
"pypdf",
"redis",
"requests",
"scikit-learn",
"scipy",
"sentencepiece",
@ -385,6 +396,7 @@
"psycopg2-binary",
"pypdf",
"redis",
"requests",
"scikit-learn",
"scipy",
"sentencepiece",

File diff suppressed because one or more lines are too long

File diff suppressed because it is too large

File diff suppressed because it is too large


@ -19,6 +19,7 @@ The `llamastack/distribution-bedrock` distribution consists of the following pro
| safety | `remote::bedrock` |
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
| telemetry | `inline::meta-reference` |
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::memory-runtime` |


@ -9,6 +9,7 @@ The `llamastack/distribution-cerebras` distribution consists of the following pr
| memory | `inline::meta-reference` |
| safety | `inline::llama-guard` |
| telemetry | `inline::meta-reference` |
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::memory-runtime` |
### Environment Variables


@ -22,6 +22,7 @@ The `llamastack/distribution-fireworks` distribution consists of the following p
| safety | `inline::llama-guard` |
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
| telemetry | `inline::meta-reference` |
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::memory-runtime` |
### Environment Variables


@ -22,6 +22,7 @@ The `llamastack/distribution-meta-reference-gpu` distribution consists of the fo
| safety | `inline::llama-guard` |
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
| telemetry | `inline::meta-reference` |
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::memory-runtime` |
Note that you need access to nvidia GPUs to run this distribution. This distribution is not compatible with CPU-only machines or machines with AMD GPUs.


@ -22,6 +22,7 @@ The `llamastack/distribution-meta-reference-quantized-gpu` distribution consists
| safety | `inline::llama-guard` |
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
| telemetry | `inline::meta-reference` |
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::memory-runtime` |
The only difference vs. the `meta-reference-gpu` distribution is that it has support for more efficient inference -- with fp8, int4 quantization, etc.


@ -22,6 +22,7 @@ The `llamastack/distribution-ollama` distribution consists of the following prov
| safety | `inline::llama-guard` |
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
| telemetry | `inline::meta-reference` |
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::memory-runtime` |
You should use this distribution if you have a regular desktop machine without very powerful GPUs. Of course, if you have powerful GPUs, you can still continue using this distribution since Ollama supports GPU acceleration.
### Environment Variables


@ -18,6 +18,7 @@ The `llamastack/distribution-remote-vllm` distribution consists of the following
| memory | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |
| safety | `inline::llama-guard` |
| telemetry | `inline::meta-reference` |
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::memory-runtime` |
You can use this distribution if you have GPUs and want to run an independent vLLM server container for running inference.


@ -23,6 +23,7 @@ The `llamastack/distribution-tgi` distribution consists of the following provide
| safety | `inline::llama-guard` |
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
| telemetry | `inline::meta-reference` |
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::memory-runtime` |
You can use this distribution if you have GPUs and want to run an independent TGI server container for running inference.


@ -22,6 +22,7 @@ The `llamastack/distribution-together` distribution consists of the following pr
| safety | `inline::llama-guard` |
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
| telemetry | `inline::meta-reference` |
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::memory-runtime` |
### Environment Variables


@ -18,15 +18,11 @@ from typing import (
Union,
)
from llama_models.llama3.api.datatypes import ToolParamDefinition
from llama_models.schema_utils import json_schema_type, webmethod
from llama_models.schema_utils import json_schema_type, register_schema, webmethod
from pydantic import BaseModel, ConfigDict, Field
from typing_extensions import Annotated
from llama_stack.apis.common.content_types import InterleavedContent, URL
from llama_stack.apis.common.deployment_types import RestAPIExecutionConfig
from llama_stack.apis.inference import (
CompletionMessage,
SamplingParams,
@ -40,166 +36,18 @@ from llama_stack.apis.inference import (
)
from llama_stack.apis.memory import MemoryBank
from llama_stack.apis.safety import SafetyViolation
from llama_stack.apis.tools import ToolDef
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
@json_schema_type
class Attachment(BaseModel):
content: InterleavedContent | URL
mime_type: str
class AgentTool(Enum):
brave_search = "brave_search"
wolfram_alpha = "wolfram_alpha"
photogen = "photogen"
code_interpreter = "code_interpreter"
function_call = "function_call"
memory = "memory"
class ToolDefinitionCommon(BaseModel):
input_shields: Optional[List[str]] = Field(default_factory=list)
output_shields: Optional[List[str]] = Field(default_factory=list)
class SearchEngineType(Enum):
bing = "bing"
brave = "brave"
tavily = "tavily"
@json_schema_type
class SearchToolDefinition(ToolDefinitionCommon):
# NOTE: brave_search is just a placeholder since model always uses
# brave_search as tool call name
type: Literal[AgentTool.brave_search.value] = AgentTool.brave_search.value
api_key: str
engine: SearchEngineType = SearchEngineType.brave
remote_execution: Optional[RestAPIExecutionConfig] = None
@json_schema_type
class WolframAlphaToolDefinition(ToolDefinitionCommon):
type: Literal[AgentTool.wolfram_alpha.value] = AgentTool.wolfram_alpha.value
api_key: str
remote_execution: Optional[RestAPIExecutionConfig] = None
@json_schema_type
class PhotogenToolDefinition(ToolDefinitionCommon):
type: Literal[AgentTool.photogen.value] = AgentTool.photogen.value
remote_execution: Optional[RestAPIExecutionConfig] = None
@json_schema_type
class CodeInterpreterToolDefinition(ToolDefinitionCommon):
type: Literal[AgentTool.code_interpreter.value] = AgentTool.code_interpreter.value
enable_inline_code_execution: bool = True
remote_execution: Optional[RestAPIExecutionConfig] = None
@json_schema_type
class FunctionCallToolDefinition(ToolDefinitionCommon):
type: Literal[AgentTool.function_call.value] = AgentTool.function_call.value
function_name: str
description: str
parameters: Dict[str, ToolParamDefinition]
remote_execution: Optional[RestAPIExecutionConfig] = None
class _MemoryBankConfigCommon(BaseModel):
bank_id: str
class AgentVectorMemoryBankConfig(_MemoryBankConfigCommon):
type: Literal["vector"] = "vector"
class AgentKeyValueMemoryBankConfig(_MemoryBankConfigCommon):
type: Literal["keyvalue"] = "keyvalue"
keys: List[str] # what keys to focus on
class AgentKeywordMemoryBankConfig(_MemoryBankConfigCommon):
type: Literal["keyword"] = "keyword"
class AgentGraphMemoryBankConfig(_MemoryBankConfigCommon):
type: Literal["graph"] = "graph"
entities: List[str] # what entities to focus on
MemoryBankConfig = Annotated[
Union[
AgentVectorMemoryBankConfig,
AgentKeyValueMemoryBankConfig,
AgentKeywordMemoryBankConfig,
AgentGraphMemoryBankConfig,
],
Field(discriminator="type"),
]
class MemoryQueryGenerator(Enum):
default = "default"
llm = "llm"
custom = "custom"
class DefaultMemoryQueryGeneratorConfig(BaseModel):
type: Literal[MemoryQueryGenerator.default.value] = (
MemoryQueryGenerator.default.value
)
sep: str = " "
class LLMMemoryQueryGeneratorConfig(BaseModel):
type: Literal[MemoryQueryGenerator.llm.value] = MemoryQueryGenerator.llm.value
model: str
template: str
class CustomMemoryQueryGeneratorConfig(BaseModel):
type: Literal[MemoryQueryGenerator.custom.value] = MemoryQueryGenerator.custom.value
MemoryQueryGeneratorConfig = Annotated[
Union[
DefaultMemoryQueryGeneratorConfig,
LLMMemoryQueryGeneratorConfig,
CustomMemoryQueryGeneratorConfig,
],
Field(discriminator="type"),
]
@json_schema_type
class MemoryToolDefinition(ToolDefinitionCommon):
type: Literal[AgentTool.memory.value] = AgentTool.memory.value
memory_bank_configs: List[MemoryBankConfig] = Field(default_factory=list)
# This config defines how a query is generated using the messages
# for memory bank retrieval.
query_generator_config: MemoryQueryGeneratorConfig = Field(
default=DefaultMemoryQueryGeneratorConfig()
)
max_tokens_in_context: int = 4096
max_chunks: int = 10
AgentToolDefinition = Annotated[
Union[
SearchToolDefinition,
WolframAlphaToolDefinition,
PhotogenToolDefinition,
CodeInterpreterToolDefinition,
FunctionCallToolDefinition,
MemoryToolDefinition,
],
Field(discriminator="type"),
]
class Document(BaseModel):
content: InterleavedContent | URL
mime_type: str
class StepCommon(BaseModel):
@ -289,13 +137,27 @@ class Session(BaseModel):
memory_bank: Optional[MemoryBank] = None
class AgentToolGroupWithArgs(BaseModel):
name: str
args: Dict[str, Any]
AgentToolGroup = register_schema(
Union[
str,
AgentToolGroupWithArgs,
],
name="AgentTool",
)
class AgentConfigCommon(BaseModel):
sampling_params: Optional[SamplingParams] = SamplingParams()
input_shields: Optional[List[str]] = Field(default_factory=list)
output_shields: Optional[List[str]] = Field(default_factory=list)
tools: Optional[List[AgentToolDefinition]] = Field(default_factory=list)
toolgroups: Optional[List[AgentToolGroup]] = Field(default_factory=list)
client_tools: Optional[List[ToolDef]] = Field(default_factory=list)
tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
tool_prompt_format: Optional[ToolPromptFormat] = Field(
default=ToolPromptFormat.json
@ -340,6 +202,7 @@ class AgentTurnResponseStepCompletePayload(BaseModel):
AgentTurnResponseEventType.step_complete.value
)
step_type: StepType
step_id: str
step_details: Step
@ -413,7 +276,9 @@ class AgentTurnCreateRequest(AgentConfigOverridablePerTurn):
ToolResponseMessage,
]
]
attachments: Optional[List[Attachment]] = None
documents: Optional[List[Document]] = None
toolgroups: Optional[List[AgentToolGroup]] = None
stream: Optional[bool] = False
@ -450,8 +315,9 @@ class Agents(Protocol):
ToolResponseMessage,
]
],
attachments: Optional[List[Attachment]] = None,
stream: Optional[bool] = False,
documents: Optional[List[Document]] = None,
toolgroups: Optional[List[AgentToolGroup]] = None,
) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ...
@webmethod(route="/agents/turn/get")


@ -4,10 +4,11 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Annotated, Any, Dict, List, Literal, Optional, Union
from enum import Enum
from typing import Any, Dict, List, Literal, Optional
from llama_models.llama3.api.datatypes import ToolPromptFormat
from llama_models.schema_utils import json_schema_type, register_schema, webmethod
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
from typing_extensions import Protocol, runtime_checkable
@ -21,15 +22,24 @@ class ToolParameter(BaseModel):
name: str
parameter_type: str
description: str
required: bool = Field(default=True)
default: Optional[Any] = None
@json_schema_type
class ToolHost(Enum):
distribution = "distribution"
client = "client"
model_context_protocol = "model_context_protocol"
@json_schema_type
class Tool(Resource):
type: Literal[ResourceType.tool.value] = ResourceType.tool.value
tool_group: str
toolgroup_id: str
tool_host: ToolHost
description: str
parameters: List[ToolParameter]
provider_id: Optional[str] = None
metadata: Optional[Dict[str, Any]] = None
tool_prompt_format: Optional[ToolPromptFormat] = Field(
default=ToolPromptFormat.json
@ -39,41 +49,27 @@ class Tool(Resource):
@json_schema_type
class ToolDef(BaseModel):
name: str
description: str
parameters: List[ToolParameter]
metadata: Dict[str, Any]
description: Optional[str] = None
parameters: Optional[List[ToolParameter]] = None
metadata: Optional[Dict[str, Any]] = None
tool_prompt_format: Optional[ToolPromptFormat] = Field(
default=ToolPromptFormat.json
)
@json_schema_type
class MCPToolGroupDef(BaseModel):
"""
A tool group that is defined by in a model context protocol server.
Refer to https://modelcontextprotocol.io/docs/concepts/tools for more information.
"""
type: Literal["model_context_protocol"] = "model_context_protocol"
endpoint: URL
class ToolGroupInput(BaseModel):
toolgroup_id: str
provider_id: str
args: Optional[Dict[str, Any]] = None
mcp_endpoint: Optional[URL] = None
@json_schema_type
class UserDefinedToolGroupDef(BaseModel):
type: Literal["user_defined"] = "user_defined"
tools: List[ToolDef]
ToolGroupDef = register_schema(
Annotated[
Union[MCPToolGroupDef, UserDefinedToolGroupDef], Field(discriminator="type")
],
name="ToolGroup",
)
class ToolGroup(Resource):
type: Literal[ResourceType.tool_group.value] = ResourceType.tool_group.value
mcp_endpoint: Optional[URL] = None
args: Optional[Dict[str, Any]] = None
@json_schema_type
@ -85,6 +81,7 @@ class ToolInvocationResult(BaseModel):
class ToolStore(Protocol):
def get_tool(self, tool_name: str) -> Tool: ...
def get_tool_group(self, tool_group_id: str) -> ToolGroup: ...
@runtime_checkable
@ -93,9 +90,10 @@ class ToolGroups(Protocol):
@webmethod(route="/toolgroups/register", method="POST")
async def register_tool_group(
self,
tool_group_id: str,
tool_group: ToolGroupDef,
provider_id: Optional[str] = None,
toolgroup_id: str,
provider_id: str,
mcp_endpoint: Optional[URL] = None,
args: Optional[Dict[str, Any]] = None,
) -> None:
"""Register a tool group"""
...
@ -103,7 +101,7 @@ class ToolGroups(Protocol):
@webmethod(route="/toolgroups/get", method="GET")
async def get_tool_group(
self,
tool_group_id: str,
toolgroup_id: str,
) -> ToolGroup: ...
@webmethod(route="/toolgroups/list", method="GET")
@ -130,8 +128,11 @@ class ToolGroups(Protocol):
class ToolRuntime(Protocol):
tool_store: ToolStore
@webmethod(route="/tool-runtime/discover", method="POST")
async def discover_tools(self, tool_group: ToolGroupDef) -> List[ToolDef]: ...
# TODO: This needs to be renamed once OPEN API generator name conflict issue is fixed.
@webmethod(route="/tool-runtime/list-tools", method="GET")
async def list_runtime_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
) -> List[ToolDef]: ...
@webmethod(route="/tool-runtime/invoke", method="POST")
async def invoke_tool(


@ -20,7 +20,7 @@ from llama_stack.apis.safety import Safety
from llama_stack.apis.scoring import Scoring
from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnInput
from llama_stack.apis.shields import Shield, ShieldInput
from llama_stack.apis.tools import Tool, ToolGroup, ToolRuntime
from llama_stack.apis.tools import Tool, ToolGroup, ToolGroupInput, ToolRuntime
from llama_stack.providers.datatypes import Api, ProviderSpec
from llama_stack.providers.utils.kvstore.config import KVStoreConfig
@ -161,6 +161,7 @@ a default SQLite store will be used.""",
datasets: List[DatasetInput] = Field(default_factory=list)
scoring_fns: List[ScoringFnInput] = Field(default_factory=list)
eval_tasks: List[EvalTaskInput] = Field(default_factory=list)
tool_groups: List[ToolGroupInput] = Field(default_factory=list)
class BuildConfig(BaseModel):


@ -267,6 +267,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
self.config, self.custom_provider_registry
)
except ModuleNotFoundError as _e:
cprint(_e.msg, "red")
cprint(
"Using llama-stack as a library requires installing dependencies depending on the template (providers) you choose.\n",
"yellow",


@ -5,9 +5,7 @@
# the root directory of this source tree.
import importlib
import inspect
import logging
from typing import Any, Dict, List, Set
from llama_stack.apis.agents import Agents
@ -28,7 +26,6 @@ from llama_stack.apis.shields import Shields
from llama_stack.apis.telemetry import Telemetry
from llama_stack.apis.tools import ToolGroups, ToolRuntime
from llama_stack.distribution.client import get_client_impl
from llama_stack.distribution.datatypes import (
AutoRoutedProviderSpec,
Provider,
@ -38,7 +35,6 @@ from llama_stack.distribution.datatypes import (
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
from llama_stack.distribution.store import DistributionRegistry
from llama_stack.distribution.utils.dynamic import instantiate_class_type
from llama_stack.providers.datatypes import (
Api,
DatasetsProtocolPrivate,


@ -6,7 +6,7 @@
from typing import Any, AsyncGenerator, Dict, List, Optional
from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.common.content_types import InterleavedContent, URL
from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult
from llama_stack.apis.eval import (
AppEvalTaskConfig,
@ -38,7 +38,7 @@ from llama_stack.apis.scoring import (
ScoringFnParams,
)
from llama_stack.apis.shields import Shield
from llama_stack.apis.tools import Tool, ToolGroupDef, ToolRuntime
from llama_stack.apis.tools import ToolDef, ToolRuntime
from llama_stack.providers.datatypes import RoutingTable
@ -417,7 +417,9 @@ class ToolRuntimeRouter(ToolRuntime):
args=args,
)
async def discover_tools(self, tool_group: ToolGroupDef) -> List[Tool]:
return await self.routing_table.get_provider_impl(
tool_group.name
).discover_tools(tool_group)
async def list_runtime_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
) -> List[ToolDef]:
return await self.routing_table.get_provider_impl(tool_group_id).list_tools(
tool_group_id, mcp_endpoint
)


@ -6,7 +6,7 @@
from typing import Any, Dict, List, Optional
from pydantic import parse_obj_as
from pydantic import TypeAdapter
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.common.type_system import ParamType
@ -26,20 +26,12 @@ from llama_stack.apis.scoring_functions import (
ScoringFunctions,
)
from llama_stack.apis.shields import Shield, Shields
from llama_stack.apis.tools import (
MCPToolGroupDef,
Tool,
ToolGroup,
ToolGroupDef,
ToolGroups,
UserDefinedToolGroupDef,
)
from llama_stack.apis.tools import Tool, ToolGroup, ToolGroups, ToolHost
from llama_stack.distribution.datatypes import (
RoutableObject,
RoutableObjectWithProvider,
RoutedProtocol,
)
from llama_stack.distribution.store import DistributionRegistry
from llama_stack.providers.datatypes import Api, RoutingTable
@ -361,7 +353,7 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
memory_bank_data["embedding_dimension"] = model.metadata[
"embedding_dimension"
]
memory_bank = parse_obj_as(MemoryBank, memory_bank_data)
memory_bank = TypeAdapter(MemoryBank).validate_python(memory_bank_data)
await self.register_object(memory_bank)
return memory_bank
@ -496,54 +488,45 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
async def list_tools(self, tool_group_id: Optional[str] = None) -> List[Tool]:
tools = await self.get_all_with_type("tool")
if tool_group_id:
tools = [tool for tool in tools if tool.tool_group == tool_group_id]
tools = [tool for tool in tools if tool.toolgroup_id == tool_group_id]
return tools
async def list_tool_groups(self) -> List[ToolGroup]:
return await self.get_all_with_type("tool_group")
async def get_tool_group(self, tool_group_id: str) -> ToolGroup:
return await self.get_object_by_identifier("tool_group", tool_group_id)
async def get_tool_group(self, toolgroup_id: str) -> ToolGroup:
return await self.get_object_by_identifier("tool_group", toolgroup_id)
async def get_tool(self, tool_name: str) -> Tool:
return await self.get_object_by_identifier("tool", tool_name)
async def register_tool_group(
self,
tool_group_id: str,
tool_group: ToolGroupDef,
provider_id: Optional[str] = None,
toolgroup_id: str,
provider_id: str,
mcp_endpoint: Optional[URL] = None,
args: Optional[Dict[str, Any]] = None,
) -> None:
tools = []
tool_defs = []
if provider_id is None:
if len(self.impls_by_provider_id.keys()) > 1:
raise ValueError(
f"No provider_id specified and multiple providers available. Please specify a provider_id. Available providers: {', '.join(self.impls_by_provider_id.keys())}"
)
provider_id = list(self.impls_by_provider_id.keys())[0]
if isinstance(tool_group, MCPToolGroupDef):
tool_defs = await self.impls_by_provider_id[provider_id].discover_tools(
tool_group
)
elif isinstance(tool_group, UserDefinedToolGroupDef):
tool_defs = tool_group.tools
else:
raise ValueError(f"Unknown tool group: {tool_group}")
tool_defs = await self.impls_by_provider_id[provider_id].list_runtime_tools(
toolgroup_id, mcp_endpoint
)
tool_host = (
ToolHost.model_context_protocol if mcp_endpoint else ToolHost.distribution
)
for tool_def in tool_defs:
tools.append(
Tool(
identifier=tool_def.name,
tool_group=tool_group_id,
description=tool_def.description,
parameters=tool_def.parameters,
toolgroup_id=toolgroup_id,
description=tool_def.description or "",
parameters=tool_def.parameters or [],
provider_id=provider_id,
tool_prompt_format=tool_def.tool_prompt_format,
provider_resource_id=tool_def.name,
metadata=tool_def.metadata,
tool_host=tool_host,
)
)
for tool in tools:
@ -561,9 +544,11 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
await self.dist_registry.register(
ToolGroup(
identifier=tool_group_id,
identifier=toolgroup_id,
provider_id=provider_id,
provider_resource_id=tool_group_id,
provider_resource_id=toolgroup_id,
mcp_endpoint=mcp_endpoint,
args=args,
)
)


@ -12,7 +12,6 @@ from typing import Any, Dict, Optional
import pkg_resources
import yaml
from termcolor import colored
from llama_stack.apis.agents import Agents
@ -33,14 +32,13 @@ from llama_stack.apis.scoring_functions import ScoringFunctions
from llama_stack.apis.shields import Shields
from llama_stack.apis.synthetic_data_generation import SyntheticDataGeneration
from llama_stack.apis.telemetry import Telemetry
from llama_stack.apis.tools import ToolGroups, ToolRuntime
from llama_stack.distribution.datatypes import StackRunConfig
from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.resolver import ProviderRegistry, resolve_impls
from llama_stack.distribution.store.registry import create_dist_registry
from llama_stack.providers.datatypes import Api
log = logging.getLogger(__name__)
LLAMA_STACK_API_VERSION = "alpha"
@ -65,6 +63,8 @@ class LlamaStack(
Models,
Shields,
Inspect,
ToolGroups,
ToolRuntime,
):
pass
@ -81,6 +81,7 @@ RESOURCES = [
"list_scoring_functions",
),
("eval_tasks", Api.eval_tasks, "register_eval_task", "list_eval_tasks"),
("tool_groups", Api.tool_groups, "register_tool_group", "list_tool_groups"),
]


@ -12,7 +12,6 @@ import pydantic
from llama_stack.distribution.datatypes import KVStoreConfig, RoutableObjectWithProvider
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
@ -36,7 +35,7 @@ class DistributionRegistry(Protocol):
REGISTER_PREFIX = "distributions:registry"
KEY_VERSION = "v3"
KEY_VERSION = "v4"
KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}"


@ -22,6 +22,8 @@ async def get_provider_impl(
deps[Api.memory],
deps[Api.safety],
deps[Api.memory_banks],
deps[Api.tool_runtime],
deps[Api.tool_groups],
)
await impl.initialize()
return impl


@ -4,8 +4,8 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import copy
import json
import logging
import os
import re
@ -13,16 +13,16 @@ import secrets
import string
import uuid
from datetime import datetime
from typing import AsyncGenerator, Dict, List, Optional, Tuple
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple
from urllib.parse import urlparse
import httpx
from llama_models.llama3.api.datatypes import BuiltinTool
from llama_models.llama3.api.datatypes import BuiltinTool, ToolCall, ToolParamDefinition
from llama_stack.apis.agents import (
AgentConfig,
AgentTool,
AgentToolGroup,
AgentToolGroupWithArgs,
AgentTurnCreateRequest,
AgentTurnResponseEvent,
AgentTurnResponseEventType,
@ -33,25 +33,14 @@ from llama_stack.apis.agents import (
AgentTurnResponseTurnCompletePayload,
AgentTurnResponseTurnStartPayload,
Attachment,
CodeInterpreterToolDefinition,
FunctionCallToolDefinition,
Document,
InferenceStep,
MemoryRetrievalStep,
MemoryToolDefinition,
PhotogenToolDefinition,
SearchToolDefinition,
ShieldCallStep,
StepType,
ToolExecutionStep,
Turn,
WolframAlphaToolDefinition,
)
from llama_stack.apis.common.content_types import (
InterleavedContent,
TextContentItem,
URL,
)
from llama_stack.apis.common.content_types import TextContentItem, URL
from llama_stack.apis.inference import (
ChatCompletionResponseEventType,
CompletionMessage,
@ -62,32 +51,20 @@ from llama_stack.apis.inference import (
SystemMessage,
ToolCallDelta,
ToolCallParseStatus,
ToolChoice,
ToolDefinition,
ToolResponse,
ToolResponseMessage,
UserMessage,
)
from llama_stack.apis.memory import Memory, MemoryBankDocument, QueryDocumentsResponse
from llama_stack.apis.memory import Memory, MemoryBankDocument
from llama_stack.apis.memory_banks import MemoryBanks, VectorMemoryBankParams
from llama_stack.apis.safety import Safety
from llama_stack.apis.tools import ToolGroups, ToolRuntime
from llama_stack.providers.utils.kvstore import KVStore
from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content
from llama_stack.providers.utils.telemetry import tracing
from .persistence import AgentPersistence
from .rag.context_retriever import generate_rag_query
from .safety import SafetyException, ShieldRunnerMixin
from .tools.base import BaseTool
from .tools.builtin import (
CodeInterpreterTool,
interpret_content_as_attachment,
PhotogenTool,
SearchTool,
WolframAlphaTool,
)
from .tools.safety import SafeTool
log = logging.getLogger(__name__)
@ -98,6 +75,12 @@ def make_random_string(length: int = 8):
)
TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})")
MEMORY_QUERY_TOOL = "query_memory"
WEB_SEARCH_TOOL = "web_search"
MEMORY_GROUP = "builtin::memory"
class ChatAgent(ShieldRunnerMixin):
def __init__(
self,
@ -108,6 +91,8 @@ class ChatAgent(ShieldRunnerMixin):
memory_api: Memory,
memory_banks_api: MemoryBanks,
safety_api: Safety,
tool_runtime_api: ToolRuntime,
tool_groups_api: ToolGroups,
persistence_store: KVStore,
):
self.agent_id = agent_id
@ -118,29 +103,8 @@ class ChatAgent(ShieldRunnerMixin):
self.memory_banks_api = memory_banks_api
self.safety_api = safety_api
self.storage = AgentPersistence(agent_id, persistence_store)
builtin_tools = []
for tool_defn in agent_config.tools:
if isinstance(tool_defn, WolframAlphaToolDefinition):
tool = WolframAlphaTool(tool_defn.api_key)
elif isinstance(tool_defn, SearchToolDefinition):
tool = SearchTool(tool_defn.engine, tool_defn.api_key)
elif isinstance(tool_defn, CodeInterpreterToolDefinition):
tool = CodeInterpreterTool()
elif isinstance(tool_defn, PhotogenToolDefinition):
tool = PhotogenTool(dump_dir=self.tempdir)
else:
continue
builtin_tools.append(
SafeTool(
tool,
safety_api,
tool_defn.input_shields,
tool_defn.output_shields,
)
)
self.tools_dict = {t.get_name(): t for t in builtin_tools}
self.tool_runtime_api = tool_runtime_api
self.tool_groups_api = tool_groups_api
ShieldRunnerMixin.__init__(
self,
@ -228,9 +192,10 @@ class ChatAgent(ShieldRunnerMixin):
session_id=request.session_id,
turn_id=turn_id,
input_messages=messages,
attachments=request.attachments or [],
sampling_params=self.agent_config.sampling_params,
stream=request.stream,
documents=request.documents,
toolgroups_for_turn=request.toolgroups,
):
if isinstance(chunk, CompletionMessage):
log.info(
@ -278,9 +243,10 @@ class ChatAgent(ShieldRunnerMixin):
session_id: str,
turn_id: str,
input_messages: List[Message],
attachments: List[Attachment],
sampling_params: SamplingParams,
stream: bool = False,
documents: Optional[List[Document]] = None,
toolgroups_for_turn: Optional[List[AgentToolGroup]] = None,
) -> AsyncGenerator:
# Doing async generators makes downstream code much simpler and everything amenable to
# streaming. However, it also makes things complicated here because AsyncGenerators cannot
@ -297,7 +263,13 @@ class ChatAgent(ShieldRunnerMixin):
yield res
async for res in self._run(
session_id, turn_id, input_messages, attachments, sampling_params, stream
session_id,
turn_id,
input_messages,
sampling_params,
stream,
documents,
toolgroups_for_turn,
):
if isinstance(res, bool):
return
@ -353,6 +325,7 @@ class ChatAgent(ShieldRunnerMixin):
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload(
step_type=StepType.shield_call.value,
step_id=step_id,
step_details=ShieldCallStep(
step_id=step_id,
turn_id=turn_id,
@ -373,6 +346,7 @@ class ChatAgent(ShieldRunnerMixin):
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload(
step_type=StepType.shield_call.value,
step_id=step_id,
step_details=ShieldCallStep(
step_id=step_id,
turn_id=turn_id,
@ -388,73 +362,116 @@ class ChatAgent(ShieldRunnerMixin):
session_id: str,
turn_id: str,
input_messages: List[Message],
attachments: List[Attachment],
sampling_params: SamplingParams,
stream: bool = False,
documents: Optional[List[Document]] = None,
toolgroups_for_turn: Optional[List[AgentToolGroup]] = None,
) -> AsyncGenerator:
enabled_tools = set(t.type for t in self.agent_config.tools)
need_rag_context = await self._should_retrieve_context(
input_messages, attachments
)
if need_rag_context:
step_id = str(uuid.uuid4())
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepStartPayload(
step_type=StepType.memory_retrieval.value,
step_id=step_id,
toolgroup_args = {}
for toolgroup in self.agent_config.toolgroups:
if isinstance(toolgroup, AgentToolGroupWithArgs):
toolgroup_args[toolgroup.name] = toolgroup.args
if toolgroups_for_turn:
for toolgroup in toolgroups_for_turn:
if isinstance(toolgroup, AgentToolGroupWithArgs):
toolgroup_args[toolgroup.name] = toolgroup.args
tool_defs, tool_to_group = await self._get_tool_defs(toolgroups_for_turn)
if documents:
await self.handle_documents(
session_id, documents, input_messages, tool_defs
)
if MEMORY_QUERY_TOOL in tool_defs and len(input_messages) > 0:
memory_tool_group = tool_to_group.get(MEMORY_QUERY_TOOL, None)
if memory_tool_group is None:
raise ValueError(f"Memory tool group not found for {MEMORY_QUERY_TOOL}")
with tracing.span(MEMORY_QUERY_TOOL) as span:
step_id = str(uuid.uuid4())
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepStartPayload(
step_type=StepType.tool_execution.value,
step_id=step_id,
)
)
)
)
query_args = {
"messages": [msg.content for msg in input_messages],
**toolgroup_args.get(memory_tool_group, {}),
}
# TODO: find older context from the session and either replace it
# or append with a sliding window. this is really a very simplistic implementation
with tracing.span("retrieve_rag_context") as span:
rag_context, bank_ids = await self._retrieve_context(
session_id, input_messages, attachments
session_info = await self.storage.get_session_info(session_id)
# if the session has a memory bank id, let the memory tool use it
if session_info.memory_bank_id:
if "memory_bank_ids" not in query_args:
query_args["memory_bank_ids"] = []
query_args["memory_bank_ids"].append(session_info.memory_bank_id)
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepProgressPayload(
step_type=StepType.tool_execution.value,
step_id=step_id,
tool_call_delta=ToolCallDelta(
parse_status=ToolCallParseStatus.success,
content=ToolCall(
call_id="",
tool_name=MEMORY_QUERY_TOOL,
arguments={},
),
),
)
)
)
result = await self.tool_runtime_api.invoke_tool(
tool_name=MEMORY_QUERY_TOOL,
args=query_args,
)
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload(
step_type=StepType.tool_execution.value,
step_id=step_id,
step_details=ToolExecutionStep(
step_id=step_id,
turn_id=turn_id,
tool_calls=[
ToolCall(
call_id="",
tool_name=MEMORY_QUERY_TOOL,
arguments={},
)
],
tool_responses=[
ToolResponse(
call_id="",
tool_name=MEMORY_QUERY_TOOL,
content=result.content,
)
],
),
)
)
)
span.set_attribute(
"input", [m.model_dump_json() for m in input_messages]
)
span.set_attribute("output", rag_context)
span.set_attribute("bank_ids", bank_ids)
step_id = str(uuid.uuid4())
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload(
step_type=StepType.memory_retrieval.value,
step_id=step_id,
step_details=MemoryRetrievalStep(
turn_id=turn_id,
step_id=step_id,
memory_bank_ids=bank_ids,
inserted_context=rag_context or "",
),
)
)
)
if rag_context:
last_message = input_messages[-1]
last_message.context = rag_context
elif attachments and AgentTool.code_interpreter.value in enabled_tools:
urls = [a.content for a in attachments if isinstance(a.content, URL)]
# TODO: we need to migrate URL away from str type
pattern = re.compile("^(https?://|file://|data:)")
urls += [
URL(uri=a.content) for a in attachments if pattern.match(a.content)
]
msg = await attachment_message(self.tempdir, urls)
input_messages.append(msg)
span.set_attribute("output", result.content)
span.set_attribute("error_code", result.error_code)
span.set_attribute("error_message", result.error_message)
span.set_attribute("tool_name", MEMORY_QUERY_TOOL)
if result.error_code == 0:
last_message = input_messages[-1]
last_message.context = result.content
output_attachments = []
n_iter = 0
# Build a map of custom tools to their definitions for faster lookup
client_tools = {}
for tool in self.agent_config.client_tools:
client_tools[tool.name] = tool
while True:
msg = input_messages[-1]
step_id = str(uuid.uuid4())
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
@ -473,7 +490,11 @@ class ChatAgent(ShieldRunnerMixin):
async for chunk in await self.inference_api.chat_completion(
self.agent_config.model,
input_messages,
tools=self._get_tools(),
tools=[
tool
for tool in tool_defs.values()
if tool_to_group.get(tool.tool_name, None) != MEMORY_GROUP
],
tool_prompt_format=self.agent_config.tool_prompt_format,
stream=True,
sampling_params=sampling_params,
@ -572,9 +593,9 @@ class ChatAgent(ShieldRunnerMixin):
# TODO: UPDATE RETURN TYPE TO SEND A TUPLE OF (MESSAGE, ATTACHMENTS)
if len(output_attachments) > 0:
if isinstance(message.content, list):
message.content += attachments
message.content += output_attachments
else:
message.content = [message.content] + attachments
message.content = [message.content] + output_attachments
yield message
else:
log.info(f"Partial message: {str(message)}")
@ -582,9 +603,7 @@ class ChatAgent(ShieldRunnerMixin):
else:
log.info(f"{str(message)}")
tool_call = message.tool_calls[0]
name = tool_call.tool_name
if not isinstance(name, BuiltinTool) or name not in enabled_tools:
if tool_call.tool_name in client_tools:
yield message
return
@ -607,16 +626,22 @@ class ChatAgent(ShieldRunnerMixin):
)
)
tool_name = tool_call.tool_name
if isinstance(tool_name, BuiltinTool):
tool_name = tool_name.value
with tracing.span(
"tool_execution",
{
"tool_name": tool_call.tool_name,
"tool_name": tool_name,
"input": message.model_dump_json(),
},
) as span:
result_messages = await execute_tool_call_maybe(
self.tools_dict,
self.tool_runtime_api,
session_id,
[message],
toolgroup_args,
tool_to_group,
)
assert (
len(result_messages) == 1
@ -628,6 +653,7 @@ class ChatAgent(ShieldRunnerMixin):
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload(
step_type=StepType.tool_execution.value,
step_id=step_id,
step_details=ToolExecutionStep(
step_id=step_id,
turn_id=turn_id,
@ -647,7 +673,7 @@ class ChatAgent(ShieldRunnerMixin):
# TODO: add tool-input touchpoint and a "start" event for this step also
# but that needs a lot more refactoring of Tool code potentially
if out_attachment := interpret_content_as_attachment(
if out_attachment := _interpret_content_as_attachment(
result_message.content
):
# NOTE: when we push this message back to the model, the model may ignore the
@ -659,6 +685,150 @@ class ChatAgent(ShieldRunnerMixin):
n_iter += 1
async def _get_tool_defs(
self, toolgroups_for_turn: Optional[List[AgentToolGroup]] = None
) -> Tuple[Dict[str, ToolDefinition], Dict[str, str]]:
# Determine which tools to include
agent_config_toolgroups = set(
(
toolgroup.name
if isinstance(toolgroup, AgentToolGroupWithArgs)
else toolgroup
)
for toolgroup in self.agent_config.toolgroups
)
toolgroups_for_turn_set = (
agent_config_toolgroups
if toolgroups_for_turn is None
else {
(
toolgroup.name
if isinstance(toolgroup, AgentToolGroupWithArgs)
else toolgroup
)
for toolgroup in toolgroups_for_turn
}
)
tool_def_map = {}
tool_to_group = {}
for tool_def in self.agent_config.client_tools:
if tool_def_map.get(tool_def.name, None):
raise ValueError(f"Tool {tool_def.name} already exists")
tool_def_map[tool_def.name] = ToolDefinition(
tool_name=tool_def.name,
description=tool_def.description,
parameters={
param.name: ToolParamDefinition(
param_type=param.parameter_type,
description=param.description,
required=param.required,
default=param.default,
)
for param in tool_def.parameters
},
)
tool_to_group[tool_def.name] = "__client_tools__"
for toolgroup_name in agent_config_toolgroups:
if toolgroup_name not in toolgroups_for_turn_set:
continue
tools = await self.tool_groups_api.list_tools(tool_group_id=toolgroup_name)
for tool_def in tools:
if (
toolgroup_name.startswith("builtin")
and toolgroup_name != MEMORY_GROUP
):
tool_name = tool_def.identifier
built_in_type = BuiltinTool.brave_search
if tool_name == "web_search":
built_in_type = BuiltinTool.brave_search
else:
built_in_type = BuiltinTool(tool_name)
if tool_def_map.get(built_in_type, None):
raise ValueError(f"Tool {built_in_type} already exists")
tool_def_map[built_in_type] = ToolDefinition(
tool_name=built_in_type
)
tool_to_group[built_in_type] = tool_def.toolgroup_id
continue
if tool_def_map.get(tool_def.identifier, None):
raise ValueError(f"Tool {tool_def.identifier} already exists")
tool_def_map[tool_def.identifier] = ToolDefinition(
tool_name=tool_def.identifier,
description=tool_def.description,
parameters={
param.name: ToolParamDefinition(
param_type=param.parameter_type,
description=param.description,
required=param.required,
default=param.default,
)
for param in tool_def.parameters
},
)
tool_to_group[tool_def.identifier] = tool_def.toolgroup_id
return tool_def_map, tool_to_group
async def handle_documents(
self,
session_id: str,
documents: List[Document],
input_messages: List[Message],
tool_defs: Dict[str, ToolDefinition],
) -> None:
memory_tool = tool_defs.get(MEMORY_QUERY_TOOL, None)
code_interpreter_tool = tool_defs.get(BuiltinTool.code_interpreter, None)
content_items = []
url_items = []
pattern = re.compile("^(https?://|file://|data:)")
for d in documents:
if isinstance(d.content, URL):
url_items.append(d.content)
elif pattern.match(d.content):
url_items.append(URL(uri=d.content))
else:
content_items.append(d)
# Save the contents to a tempdir and use its path as a URL if code interpreter is present
if code_interpreter_tool:
for c in content_items:
temp_file_path = os.path.join(
self.tempdir, f"{make_random_string()}.txt"
)
with open(temp_file_path, "w") as temp_file:
temp_file.write(c.content)
url_items.append(URL(uri=f"file://{temp_file_path}"))
if memory_tool and code_interpreter_tool:
# if both memory and code_interpreter are available, we download the URLs
# and attach the data to the last message.
msg = await attachment_message(self.tempdir, url_items)
input_messages.append(msg)
# Since memory is present, add all the data to the memory bank
await self.add_to_session_memory_bank(session_id, documents)
elif code_interpreter_tool:
# if only code_interpreter is available, we download the URLs to a tempdir
# and attach the path to them as a message to inference with the
# assumption that the model invokes the code_interpreter tool with the path
msg = await attachment_message(self.tempdir, url_items)
input_messages.append(msg)
elif memory_tool:
# if only memory is available, we load the data from the URLs and content items to the memory bank
await self.add_to_session_memory_bank(session_id, documents)
else:
# if no memory or code_interpreter tool is available,
# we try to load the data from the URLs and content items as a message to inference
# and add it to the last message's context
input_messages[-1].context = "\n".join(
[doc.content for doc in content_items]
+ await load_data_from_urls(url_items)
)
async def _ensure_memory_bank(self, session_id: str) -> str:
session_info = await self.storage.get_session_info(session_id)
if session_info is None:
@ -679,129 +849,39 @@ class ChatAgent(ShieldRunnerMixin):
return bank_id
async def _should_retrieve_context(
self, messages: List[Message], attachments: List[Attachment]
) -> bool:
enabled_tools = set(t.type for t in self.agent_config.tools)
if attachments:
if (
AgentTool.code_interpreter.value in enabled_tools
and self.agent_config.tool_choice == ToolChoice.required
):
return False
else:
return True
return AgentTool.memory.value in enabled_tools
def _memory_tool_definition(self) -> Optional[MemoryToolDefinition]:
for t in self.agent_config.tools:
if t.type == AgentTool.memory.value:
return t
return None
async def _retrieve_context(
self, session_id: str, messages: List[Message], attachments: List[Attachment]
) -> Tuple[Optional[InterleavedContent], List[int]]: # (rag_context, bank_ids)
bank_ids = []
memory = self._memory_tool_definition()
assert memory is not None, "Memory tool not configured"
bank_ids.extend(c.bank_id for c in memory.memory_bank_configs)
if attachments:
bank_id = await self._ensure_memory_bank(session_id)
bank_ids.append(bank_id)
documents = [
MemoryBankDocument(
document_id=str(uuid.uuid4()),
content=a.content,
mime_type=a.mime_type,
metadata={},
)
for a in attachments
]
with tracing.span("insert_documents"):
await self.memory_api.insert_documents(bank_id, documents)
else:
session_info = await self.storage.get_session_info(session_id)
if session_info.memory_bank_id:
bank_ids.append(session_info.memory_bank_id)
if not bank_ids:
# this can happen if the per-session memory bank is not yet populated
# (i.e., no prior turns uploaded an Attachment)
return None, []
query = await generate_rag_query(
memory.query_generator_config, messages, inference_api=self.inference_api
)
tasks = [
self.memory_api.query_documents(
bank_id=bank_id,
query=query,
params={
"max_chunks": 5,
},
async def add_to_session_memory_bank(
self, session_id: str, data: List[Document]
) -> None:
bank_id = await self._ensure_memory_bank(session_id)
documents = [
MemoryBankDocument(
document_id=str(uuid.uuid4()),
content=a.content,
mime_type=a.mime_type,
metadata={},
)
for bank_id in bank_ids
for a in data
]
results: List[QueryDocumentsResponse] = await asyncio.gather(*tasks)
chunks = [c for r in results for c in r.chunks]
scores = [s for r in results for s in r.scores]
if not chunks:
return None, bank_ids
# sort by score
chunks, scores = zip(
*sorted(zip(chunks, scores), key=lambda x: x[1], reverse=True)
await self.memory_api.insert_documents(
bank_id=bank_id,
documents=documents,
)
tokens = 0
picked = []
for c in chunks[: memory.max_chunks]:
tokens += c.token_count
if tokens > memory.max_tokens_in_context:
log.error(
f"Using {len(picked)} chunks; reached max tokens in context: {tokens}",
)
break
picked.append(f"id:{c.document_id}; content:{c.content}")
return (
concat_interleaved_content(
[
"Here are the retrieved documents for relevant context:\n=== START-RETRIEVED-CONTEXT ===\n",
*picked,
"\n=== END-RETRIEVED-CONTEXT ===\n",
]
),
bank_ids,
)
def _get_tools(self) -> List[ToolDefinition]:
ret = []
for t in self.agent_config.tools:
if isinstance(t, SearchToolDefinition):
ret.append(ToolDefinition(tool_name=BuiltinTool.brave_search))
elif isinstance(t, WolframAlphaToolDefinition):
ret.append(ToolDefinition(tool_name=BuiltinTool.wolfram_alpha))
elif isinstance(t, PhotogenToolDefinition):
ret.append(ToolDefinition(tool_name=BuiltinTool.photogen))
elif isinstance(t, CodeInterpreterToolDefinition):
ret.append(ToolDefinition(tool_name=BuiltinTool.code_interpreter))
elif isinstance(t, FunctionCallToolDefinition):
ret.append(
ToolDefinition(
tool_name=t.function_name,
description=t.description,
parameters=t.parameters,
)
)
return ret
async def load_data_from_urls(urls: List[URL]) -> List[str]:
data = []
for url in urls:
uri = url.uri
if uri.startswith("file://"):
filepath = uri[len("file://") :]
with open(filepath, "r") as f:
data.append(f.read())
elif uri.startswith("http"):
async with httpx.AsyncClient() as client:
r = await client.get(uri)
resp = r.text
data.append(resp)
return data
async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessage:
@ -839,7 +919,11 @@ async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessa
async def execute_tool_call_maybe(
tools_dict: Dict[str, BaseTool], messages: List[CompletionMessage]
tool_runtime_api: ToolRuntime,
session_id: str,
messages: List[CompletionMessage],
toolgroup_args: Dict[str, Dict[str, Any]],
tool_to_group: Dict[str, str],
) -> List[ToolResponseMessage]:
# While Tools.run interface takes a list of messages,
# All tools currently only run on a single message
@ -851,11 +935,45 @@ async def execute_tool_call_maybe(
tool_call = message.tool_calls[0]
name = tool_call.tool_name
assert isinstance(name, BuiltinTool)
group_name = tool_to_group.get(name, None)
if group_name is None:
raise ValueError(f"Tool {name} not found in any tool group")
# get the arguments generated by the model and augment with toolgroup arg overrides for the agent
tool_call_args = tool_call.arguments
tool_call_args.update(toolgroup_args.get(group_name, {}))
if isinstance(name, BuiltinTool):
if name == BuiltinTool.brave_search:
name = WEB_SEARCH_TOOL
else:
name = name.value
name = name.value
result = await tool_runtime_api.invoke_tool(
tool_name=name,
args=dict(
session_id=session_id,
**tool_call_args,
),
)
assert name in tools_dict, f"Tool {name} not found"
tool = tools_dict[name]
result_messages = await tool.run(messages)
return result_messages
return [
ToolResponseMessage(
call_id=tool_call.call_id,
tool_name=tool_call.tool_name,
content=result.content,
)
]
def _interpret_content_as_attachment(
content: str,
) -> Optional[Attachment]:
match = re.search(TOOLS_ATTACHMENT_KEY_REGEX, content)
if match:
snippet = match.group(1)
data = json.loads(snippet)
return Attachment(
url=URL(uri="file://" + data["filepath"]),
mime_type=data["mimetype"],
)
return None


@ -19,17 +19,17 @@ from llama_stack.apis.agents import (
Agents,
AgentSessionCreateResponse,
AgentStepResponse,
AgentToolGroup,
AgentTurnCreateRequest,
Attachment,
Document,
Session,
Turn,
)
from llama_stack.apis.inference import Inference, ToolResponseMessage, UserMessage
from llama_stack.apis.memory import Memory
from llama_stack.apis.memory_banks import MemoryBanks
from llama_stack.apis.safety import Safety
from llama_stack.apis.tools import ToolGroups, ToolRuntime
from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl
from .agent_instance import ChatAgent
@ -47,12 +47,16 @@ class MetaReferenceAgentsImpl(Agents):
memory_api: Memory,
safety_api: Safety,
memory_banks_api: MemoryBanks,
tool_runtime_api: ToolRuntime,
tool_groups_api: ToolGroups,
):
self.config = config
self.inference_api = inference_api
self.memory_api = memory_api
self.safety_api = safety_api
self.memory_banks_api = memory_banks_api
self.tool_runtime_api = tool_runtime_api
self.tool_groups_api = tool_groups_api
self.in_memory_store = InmemoryKVStoreImpl()
self.tempdir = tempfile.mkdtemp()
@ -112,6 +116,8 @@ class MetaReferenceAgentsImpl(Agents):
safety_api=self.safety_api,
memory_api=self.memory_api,
memory_banks_api=self.memory_banks_api,
tool_runtime_api=self.tool_runtime_api,
tool_groups_api=self.tool_groups_api,
persistence_store=(
self.persistence_store
if agent_config.enable_session_persistence
@ -141,15 +147,17 @@ class MetaReferenceAgentsImpl(Agents):
ToolResponseMessage,
]
],
attachments: Optional[List[Attachment]] = None,
toolgroups: Optional[List[AgentToolGroup]] = None,
documents: Optional[List[Document]] = None,
stream: Optional[bool] = False,
) -> AsyncGenerator:
request = AgentTurnCreateRequest(
agent_id=agent_id,
session_id=session_id,
messages=messages,
attachments=attachments,
stream=True,
toolgroups=toolgroups,
documents=documents,
)
if stream:
return self._create_agent_turn_streaming(request)


@ -8,13 +8,11 @@ import json
import logging
import uuid
from datetime import datetime
from typing import List, Optional
from pydantic import BaseModel
from llama_stack.apis.agents import Turn
from llama_stack.providers.utils.kvstore import KVStore
log = logging.getLogger(__name__)


@ -1,93 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import unittest
from llama_models.llama3.api.datatypes import (
Attachment,
BuiltinTool,
CompletionMessage,
StopReason,
ToolCall,
)
from ..tools.builtin import CodeInterpreterTool
class TestCodeInterpreter(unittest.IsolatedAsyncioTestCase):
async def test_matplotlib(self):
tool = CodeInterpreterTool()
code = """
import matplotlib.pyplot as plt
import numpy as np
x = np.array([1, 1])
y = np.array([0, 10])
plt.plot(x, y)
plt.title('x = 1')
plt.xlabel('x')
plt.ylabel('y')
plt.grid(True)
plt.axvline(x=1, color='r')
plt.show()
"""
message = CompletionMessage(
role="assistant",
content="",
tool_calls=[
ToolCall(
call_id="call_id",
tool_name=BuiltinTool.code_interpreter,
arguments={"code": code},
)
],
stop_reason=StopReason.end_of_message,
)
ret = await tool.run([message])
self.assertEqual(len(ret), 1)
output = ret[0].content
self.assertIsInstance(output, Attachment)
self.assertEqual(output.mime_type, "image/png")
async def test_path_unlink(self):
tool = CodeInterpreterTool()
code = """
import os
from pathlib import Path
import tempfile
dpath = Path(os.environ["MPLCONFIGDIR"])
with open(dpath / "test", "w") as f:
f.write("hello")
Path(dpath / "test").unlink()
print("_OK_")
"""
message = CompletionMessage(
role="assistant",
content="",
tool_calls=[
ToolCall(
call_id="call_id",
tool_name=BuiltinTool.code_interpreter,
arguments={"code": code},
)
],
stop_reason=StopReason.end_of_message,
)
ret = await tool.run([message])
self.assertEqual(len(ret), 1)
output = ret[0].content
self.assertTrue("_OK_" in output)
if __name__ == "__main__":
unittest.main()


@ -4,21 +4,26 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import tempfile
from typing import AsyncIterator, List, Optional, Union
import pytest
from llama_models.llama3.api.datatypes import BuiltinTool
from llama_stack.apis.agents import (
AgentConfig,
AgentToolGroupWithArgs,
AgentTurnCreateRequest,
AgentTurnResponseTurnCompletePayload,
StepType,
)
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.inference import (
ChatCompletionResponse,
ChatCompletionResponseEvent,
ChatCompletionResponseStreamChunk,
CompletionMessage,
LogProbConfig,
Message,
ResponseFormat,
SamplingParams,
@ -27,13 +32,24 @@ from llama_stack.apis.inference import (
UserMessage,
)
from llama_stack.apis.memory import MemoryBank
from llama_stack.apis.memory_banks import BankParams, VectorMemoryBank
from llama_stack.apis.safety import RunShieldResponse
from ..agents import (
AGENT_INSTANCES_BY_ID,
MetaReferenceAgentsImpl,
MetaReferenceInferenceConfig,
from llama_stack.apis.tools import (
Tool,
ToolDef,
ToolGroup,
ToolHost,
ToolInvocationResult,
ToolPromptFormat,
)
from llama_stack.providers.inline.agents.meta_reference.agent_instance import (
MEMORY_QUERY_TOOL,
)
from llama_stack.providers.inline.agents.meta_reference.agents import (
MetaReferenceAgentsImpl,
MetaReferenceAgentsImplConfig,
)
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
class MockInferenceAPI:
@ -48,10 +64,10 @@ class MockInferenceAPI:
tool_prompt_format: Optional[ToolPromptFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncIterator[
Union[ChatCompletionResponseStreamChunk, ChatCompletionResponse]
) -> Union[
ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
]:
if stream:
async def stream_response():
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type="start",
@ -65,19 +81,7 @@ class MockInferenceAPI:
delta="AI is a fascinating field...",
)
)
# yield ChatCompletionResponseStreamChunk(
# event=ChatCompletionResponseEvent(
# event_type="progress",
# delta=ToolCallDelta(
# content=ToolCall(
# call_id="123",
# tool_name=BuiltinTool.brave_search.value,
# arguments={"query": "AI history"},
# ),
# parse_status="success",
# ),
# )
# )
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type="complete",
@ -85,12 +89,17 @@ class MockInferenceAPI:
stop_reason="end_of_turn",
)
)
if stream:
return stream_response()
else:
yield ChatCompletionResponse(
return ChatCompletionResponse(
completion_message=CompletionMessage(
role="assistant", content="Mock response", stop_reason="end_of_turn"
role="assistant",
content="Mock response",
stop_reason="end_of_turn",
),
logprobs=[0.1, 0.2, 0.3] if logprobs else None,
logprobs={"token_logprobs": [0.1, 0.2, 0.3]} if logprobs else None,
)
@ -165,6 +174,98 @@ class MockMemoryAPI:
self.documents[bank_id].pop(doc_id, None)
class MockToolGroupsAPI:
async def register_tool_group(
self, toolgroup_id: str, provider_id: str, mcp_endpoint=None, args=None
) -> None:
pass
async def get_tool_group(self, toolgroup_id: str) -> ToolGroup:
return ToolGroup(
identifier=toolgroup_id,
provider_resource_id=toolgroup_id,
)
async def list_tool_groups(self) -> List[ToolGroup]:
return []
async def list_tools(self, tool_group_id: Optional[str] = None) -> List[Tool]:
if tool_group_id == MEMORY_TOOLGROUP:
return [
Tool(
identifier=MEMORY_QUERY_TOOL,
provider_resource_id=MEMORY_QUERY_TOOL,
toolgroup_id=MEMORY_TOOLGROUP,
tool_host=ToolHost.client,
description="Mock tool",
provider_id="builtin::memory",
parameters=[],
)
]
if tool_group_id == CODE_INTERPRETER_TOOLGROUP:
return [
Tool(
identifier="code_interpreter",
provider_resource_id="code_interpreter",
toolgroup_id=CODE_INTERPRETER_TOOLGROUP,
tool_host=ToolHost.client,
description="Mock tool",
provider_id="builtin::code_interpreter",
parameters=[],
)
]
return []
async def get_tool(self, tool_name: str) -> Tool:
return Tool(
identifier=tool_name,
provider_resource_id=tool_name,
toolgroup_id="mock_group",
tool_host=ToolHost.client,
description="Mock tool",
provider_id="mock_provider",
parameters=[],
)
async def unregister_tool_group(self, tool_group_id: str) -> None:
pass
class MockToolRuntimeAPI:
async def list_runtime_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
) -> List[ToolDef]:
return []
async def invoke_tool(self, tool_name: str, args: dict) -> ToolInvocationResult:
return ToolInvocationResult(content={"result": "Mock tool result"})
class MockMemoryBanksAPI:
async def list_memory_banks(self) -> List[MemoryBank]:
return []
async def get_memory_bank(self, memory_bank_id: str) -> Optional[MemoryBank]:
return None
async def register_memory_bank(
self,
memory_bank_id: str,
params: BankParams,
provider_id: Optional[str] = None,
provider_memory_bank_id: Optional[str] = None,
) -> MemoryBank:
return VectorMemoryBank(
identifier=memory_bank_id,
provider_resource_id=provider_memory_bank_id or memory_bank_id,
embedding_model="mock_model",
chunk_size_in_tokens=512,
)
async def unregister_memory_bank(self, memory_bank_id: str) -> None:
pass
@pytest.fixture
def mock_inference_api():
return MockInferenceAPI()
@ -181,64 +282,107 @@ def mock_memory_api():
@pytest.fixture
async def chat_agent(mock_inference_api, mock_safety_api, mock_memory_api):
def mock_tool_groups_api():
return MockToolGroupsAPI()
@pytest.fixture
def mock_tool_runtime_api():
return MockToolRuntimeAPI()
@pytest.fixture
def mock_memory_banks_api():
return MockMemoryBanksAPI()
@pytest.fixture
async def get_agents_impl(
mock_inference_api,
mock_safety_api,
mock_memory_api,
mock_memory_banks_api,
mock_tool_runtime_api,
mock_tool_groups_api,
):
sqlite_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
impl = MetaReferenceAgentsImpl(
config=MetaReferenceInferenceConfig(),
config=MetaReferenceAgentsImplConfig(
persistence_store=SqliteKVStoreConfig(
db_name=sqlite_file.name,
),
),
inference_api=mock_inference_api,
safety_api=mock_safety_api,
memory_api=mock_memory_api,
memory_banks_api=mock_memory_banks_api,
tool_runtime_api=mock_tool_runtime_api,
tool_groups_api=mock_tool_groups_api,
)
await impl.initialize()
return impl
@pytest.fixture
async def get_chat_agent(get_agents_impl):
impl = await get_agents_impl
agent_config = AgentConfig(
model="test_model",
instructions="You are a helpful assistant.",
sampling_params=SamplingParams(),
tools=[
# SearchToolDefinition(
# name="brave_search",
# api_key="test_key",
# ),
],
toolgroups=[],
tool_choice=ToolChoice.auto,
enable_session_persistence=False,
input_shields=[],
output_shields=[],
input_shields=["test_shield"],
)
response = await impl.create_agent(agent_config)
agent = AGENT_INSTANCES_BY_ID[response.agent_id]
return agent
return await impl.get_agent(response.agent_id)
MEMORY_TOOLGROUP = "builtin::memory"
CODE_INTERPRETER_TOOLGROUP = "builtin::code_interpreter"
@pytest.fixture
async def get_chat_agent_with_tools(get_agents_impl, request):
impl = await get_agents_impl
toolgroups = request.param
agent_config = AgentConfig(
model="test_model",
instructions="You are a helpful assistant.",
toolgroups=toolgroups,
tool_choice=ToolChoice.auto,
enable_session_persistence=False,
input_shields=["test_shield"],
)
response = await impl.create_agent(agent_config)
return await impl.get_agent(response.agent_id)
@pytest.mark.asyncio
async def test_chat_agent_create_session(chat_agent):
session = chat_agent.create_session("Test Session")
assert session.session_name == "Test Session"
assert session.turns == []
assert session.session_id in chat_agent.sessions
@pytest.mark.asyncio
async def test_chat_agent_create_and_execute_turn(chat_agent):
session = chat_agent.create_session("Test Session")
async def test_chat_agent_create_and_execute_turn(get_chat_agent):
chat_agent = await get_chat_agent
session_id = await chat_agent.create_session("Test Session")
request = AgentTurnCreateRequest(
agent_id="random",
session_id=session.session_id,
agent_id=chat_agent.agent_id,
session_id=session_id,
messages=[UserMessage(content="Hello")],
stream=True,
)
responses = []
async for response in chat_agent.create_and_execute_turn(request):
responses.append(response)
print(responses)
assert len(responses) > 0
assert len(responses) == 4 # TurnStart, StepStart, StepComplete, TurnComplete
assert (
len(responses) == 7
) # TurnStart, ShieldCallStart, ShieldCallComplete, StepStart, StepProgress, StepComplete, TurnComplete
assert responses[0].event.payload.turn_id is not None
@pytest.mark.asyncio
async def test_run_multiple_shields_wrapper(chat_agent):
async def test_run_multiple_shields_wrapper(get_chat_agent):
chat_agent = await get_chat_agent
messages = [UserMessage(content="Test message")]
shields = ["test_shield"]
@ -254,69 +398,95 @@ async def test_run_multiple_shields_wrapper(chat_agent):
assert len(responses) == 2 # StepStart, StepComplete
assert responses[0].event.payload.step_type.value == "shield_call"
assert not responses[1].event.payload.step_details.response.is_violation
assert not responses[1].event.payload.step_details.violation
@pytest.mark.asyncio
@pytest.mark.skip(reason="Not yet implemented; need to mock out tool execution easily")
async def test_chat_agent_complex_turn(chat_agent):
# Setup
session = chat_agent.create_session("Test Session")
async def test_chat_agent_complex_turn(get_chat_agent):
chat_agent = await get_chat_agent
session_id = await chat_agent.create_session("Test Session")
request = AgentTurnCreateRequest(
agent_id="random",
session_id=session.session_id,
agent_id=chat_agent.agent_id,
session_id=session_id,
messages=[UserMessage(content="Tell me about AI and then use a tool.")],
stream=True,
)
# Execute the turn
responses = []
async for response in chat_agent.create_and_execute_turn(request):
responses.append(response)
# Assertions
assert len(responses) > 0
# Check for the presence of different step types
step_types = [
response.event.payload.step_type
for response in responses
if hasattr(response.event.payload, "step_type")
]
assert "shield_call" in step_types, "Shield call step is missing"
assert "inference" in step_types, "Inference step is missing"
assert "tool_execution" in step_types, "Tool execution step is missing"
assert StepType.shield_call in step_types, "Shield call step is missing"
assert StepType.inference in step_types, "Inference step is missing"
# Check for the presence of start and complete events
event_types = [
response.event.payload.event_type
for response in responses
if hasattr(response.event.payload, "event_type")
]
assert "start" in event_types, "Start event is missing"
assert "complete" in event_types, "Complete event is missing"
assert "turn_start" in event_types, "Start event is missing"
assert "turn_complete" in event_types, "Complete event is missing"
# Check for the presence of tool call
tool_calls = [
response.event.payload.tool_call
for response in responses
if hasattr(response.event.payload, "tool_call")
]
assert any(
tool_call
for tool_call in tool_calls
if tool_call and tool_call.content.get("name") == "memory"
), "Memory tool call is missing"
# Check for the final turn complete event
assert any(
isinstance(response.event.payload, AgentTurnResponseTurnCompletePayload)
for response in responses
), "Turn complete event is missing"
turn_complete_payload = next(
response.event.payload
for response in responses
if isinstance(response.event.payload, AgentTurnResponseTurnCompletePayload)
)
turn = turn_complete_payload.turn
assert turn.input_messages == request.messages, "Input messages do not match"
# Verify the turn was added to the session
assert len(session.turns) == 1, "Turn was not added to the session"
assert (
session.turns[0].input_messages == request.messages
), "Input messages do not match"
@pytest.mark.asyncio
@pytest.mark.parametrize(
"toolgroups, expected_memory, expected_code_interpreter",
[
([], False, False), # no tools
([MEMORY_TOOLGROUP], True, False), # memory only
([CODE_INTERPRETER_TOOLGROUP], False, True), # code interpreter only
([MEMORY_TOOLGROUP, CODE_INTERPRETER_TOOLGROUP], True, True), # all tools
],
)
async def test_chat_agent_tools(
get_agents_impl, toolgroups, expected_memory, expected_code_interpreter
):
impl = await get_agents_impl
agent_config = AgentConfig(
model="test_model",
instructions="You are a helpful assistant.",
toolgroups=toolgroups,
tool_choice=ToolChoice.auto,
enable_session_persistence=False,
input_shields=["test_shield"],
)
response = await impl.create_agent(agent_config)
chat_agent = await impl.get_agent(response.agent_id)
tool_defs, _ = await chat_agent._get_tool_defs()
if expected_memory:
assert MEMORY_QUERY_TOOL in tool_defs
if expected_code_interpreter:
assert BuiltinTool.code_interpreter in tool_defs
if expected_memory and expected_code_interpreter:
# override the tools for turn
new_tool_defs, _ = await chat_agent._get_tool_defs(
toolgroups_for_turn=[
AgentToolGroupWithArgs(
name=MEMORY_TOOLGROUP,
args={"memory_banks": ["test_memory_bank"]},
)
]
)
assert MEMORY_QUERY_TOOL in new_tool_defs
assert BuiltinTool.code_interpreter not in new_tool_defs
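
For reference, a minimal sketch of the agent configuration shape these tests exercise, assuming `toolgroups` accepts both a bare identifier and `AgentToolGroupWithArgs` as the per-turn override above suggests; the model, shield and bank names are placeholders.

```
from llama_stack.apis.agents import AgentConfig, AgentToolGroupWithArgs

# Placeholder model, shield and bank names, mirroring the fixtures above.
agent_config = AgentConfig(
    model="test_model",
    instructions="You are a helpful assistant.",
    toolgroups=[
        "builtin::code_interpreter",
        AgentToolGroupWithArgs(
            name="builtin::memory",
            args={"memory_banks": ["test_memory_bank"]},
        ),
    ],
    enable_session_persistence=False,
    input_shields=["test_shield"],
)
```
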

View file

@ -1,20 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from abc import ABC, abstractmethod
from typing import List
from llama_stack.apis.inference import Message
class BaseTool(ABC):
@abstractmethod
def get_name(self) -> str:
raise NotImplementedError
@abstractmethod
async def run(self, messages: List[Message]) -> List[Message]:
raise NotImplementedError

View file

@ -1,396 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import logging
import re
import tempfile
from abc import abstractmethod
from typing import List, Optional
import requests
from .ipython_tool.code_execution import (
CodeExecutionContext,
CodeExecutionRequest,
CodeExecutor,
TOOLS_ATTACHMENT_KEY_REGEX,
)
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.agents import * # noqa: F403
from .base import BaseTool
log = logging.getLogger(__name__)
def interpret_content_as_attachment(content: str) -> Optional[Attachment]:
match = re.search(TOOLS_ATTACHMENT_KEY_REGEX, content)
if match:
snippet = match.group(1)
data = json.loads(snippet)
return Attachment(
url=URL(uri="file://" + data["filepath"]), mime_type=data["mimetype"]
)
return None
class SingleMessageBuiltinTool(BaseTool):
async def run(self, messages: List[CompletionMessage]) -> List[ToolResponseMessage]:
assert len(messages) == 1, f"Expected single message, got {len(messages)}"
message = messages[0]
assert len(message.tool_calls) == 1, "Expected a single tool call"
tool_call = messages[0].tool_calls[0]
query = tool_call.arguments["query"]
response: str = await self.run_impl(query)
message = ToolResponseMessage(
call_id=tool_call.call_id,
tool_name=tool_call.tool_name,
content=response,
)
return [message]
@abstractmethod
async def run_impl(self, query: str) -> str:
raise NotImplementedError()
class PhotogenTool(SingleMessageBuiltinTool):
def __init__(self, dump_dir: str) -> None:
self.dump_dir = dump_dir
def get_name(self) -> str:
return BuiltinTool.photogen.value
async def run_impl(self, query: str) -> str:
"""
Implement this to give the model an ability to generate images.
Return:
info = {
"filepath": str(image_filepath),
"mimetype": "image/png",
}
"""
raise NotImplementedError()
class SearchTool(SingleMessageBuiltinTool):
def __init__(self, engine: SearchEngineType, api_key: str, **kwargs) -> None:
self.api_key = api_key
self.engine_type = engine
if engine == SearchEngineType.bing:
self.engine = BingSearch(api_key, **kwargs)
elif engine == SearchEngineType.brave:
self.engine = BraveSearch(api_key, **kwargs)
elif engine == SearchEngineType.tavily:
self.engine = TavilySearch(api_key, **kwargs)
else:
raise ValueError(f"Unknown search engine: {engine}")
def get_name(self) -> str:
return BuiltinTool.brave_search.value
async def run_impl(self, query: str) -> str:
return await self.engine.search(query)
class BingSearch:
def __init__(self, api_key: str, top_k: int = 3, **kwargs) -> None:
self.api_key = api_key
self.top_k = top_k
async def search(self, query: str) -> str:
url = "https://api.bing.microsoft.com/v7.0/search"
headers = {
"Ocp-Apim-Subscription-Key": self.api_key,
}
params = {
"count": self.top_k,
"textDecorations": True,
"textFormat": "HTML",
"q": query,
}
response = requests.get(url=url, params=params, headers=headers)
response.raise_for_status()
clean = self._clean_response(response.json())
return json.dumps(clean)
def _clean_response(self, search_response):
clean_response = []
query = search_response["queryContext"]["originalQuery"]
if "webPages" in search_response:
pages = search_response["webPages"]["value"]
for p in pages:
selected_keys = {"name", "url", "snippet"}
clean_response.append(
{k: v for k, v in p.items() if k in selected_keys}
)
if "news" in search_response:
clean_news = []
news = search_response["news"]["value"]
for n in news:
selected_keys = {"name", "url", "description"}
clean_news.append({k: v for k, v in n.items() if k in selected_keys})
clean_response.append(clean_news)
return {"query": query, "top_k": clean_response}
class BraveSearch:
def __init__(self, api_key: str) -> None:
self.api_key = api_key
async def search(self, query: str) -> str:
url = "https://api.search.brave.com/res/v1/web/search"
headers = {
"X-Subscription-Token": self.api_key,
"Accept-Encoding": "gzip",
"Accept": "application/json",
}
payload = {"q": query}
response = requests.get(url=url, params=payload, headers=headers)
return json.dumps(self._clean_brave_response(response.json()))
def _clean_brave_response(self, search_response, top_k=3):
query = None
clean_response = []
if "query" in search_response:
if "original" in search_response["query"]:
query = search_response["query"]["original"]
if "mixed" in search_response:
mixed_results = search_response["mixed"]
for m in mixed_results["main"][:top_k]:
r_type = m["type"]
results = search_response[r_type]["results"]
if r_type == "web":
# For web data - add a single output from the search
idx = m["index"]
selected_keys = [
"type",
"title",
"url",
"description",
"date",
"extra_snippets",
]
cleaned = {
k: v for k, v in results[idx].items() if k in selected_keys
}
elif r_type == "faq":
# For faq data - take a list of all the questions & answers
selected_keys = ["type", "question", "answer", "title", "url"]
cleaned = []
for q in results:
cleaned.append(
{k: v for k, v in q.items() if k in selected_keys}
)
elif r_type == "infobox":
idx = m["index"]
selected_keys = [
"type",
"title",
"url",
"description",
"long_desc",
]
cleaned = {
k: v for k, v in results[idx].items() if k in selected_keys
}
elif r_type == "videos":
selected_keys = [
"type",
"url",
"title",
"description",
"date",
]
cleaned = []
for q in results:
cleaned.append(
{k: v for k, v in q.items() if k in selected_keys}
)
elif r_type == "locations":
# For location data - take a list of all the locations
selected_keys = [
"type",
"title",
"url",
"description",
"coordinates",
"postal_address",
"contact",
"rating",
"distance",
"zoom_level",
]
cleaned = []
for q in results:
cleaned.append(
{k: v for k, v in q.items() if k in selected_keys}
)
elif r_type == "news":
# For news data - take a list of all the news items
selected_keys = [
"type",
"title",
"url",
"description",
]
cleaned = []
for q in results:
cleaned.append(
{k: v for k, v in q.items() if k in selected_keys}
)
else:
cleaned = []
clean_response.append(cleaned)
return {"query": query, "top_k": clean_response}
class TavilySearch:
def __init__(self, api_key: str) -> None:
self.api_key = api_key
async def search(self, query: str) -> str:
response = requests.post(
"https://api.tavily.com/search",
json={"api_key": self.api_key, "query": query},
)
return json.dumps(self._clean_tavily_response(response.json()))
def _clean_tavily_response(self, search_response, top_k=3):
return {"query": search_response["query"], "top_k": search_response["results"]}
class WolframAlphaTool(SingleMessageBuiltinTool):
def __init__(self, api_key: str) -> None:
self.api_key = api_key
self.url = "https://api.wolframalpha.com/v2/query"
def get_name(self) -> str:
return BuiltinTool.wolfram_alpha.value
async def run_impl(self, query: str) -> str:
params = {
"input": query,
"appid": self.api_key,
"format": "plaintext",
"output": "json",
}
response = requests.get(
self.url,
params=params,
)
return json.dumps(self._clean_wolfram_alpha_response(response.json()))
def _clean_wolfram_alpha_response(self, wa_response):
remove = {
"queryresult": [
"datatypes",
"error",
"timedout",
"timedoutpods",
"numpods",
"timing",
"parsetiming",
"parsetimedout",
"recalculate",
"id",
"host",
"server",
"related",
"version",
{
"pods": [
"scanner",
"id",
"error",
"expressiontypes",
"states",
"infos",
"position",
"numsubpods",
]
},
"assumptions",
],
}
for main_key in remove:
for key_to_remove in remove[main_key]:
try:
if key_to_remove == "assumptions":
if "assumptions" in wa_response[main_key]:
del wa_response[main_key][key_to_remove]
if isinstance(key_to_remove, dict):
for sub_key in key_to_remove:
if sub_key == "pods":
for i in range(len(wa_response[main_key][sub_key])):
if (
wa_response[main_key][sub_key][i]["title"]
== "Result"
):
del wa_response[main_key][sub_key][i + 1 :]
break
sub_items = wa_response[main_key][sub_key]
for i in range(len(sub_items)):
for sub_key_to_remove in key_to_remove[sub_key]:
if sub_key_to_remove in sub_items[i]:
del sub_items[i][sub_key_to_remove]
elif key_to_remove in wa_response[main_key]:
del wa_response[main_key][key_to_remove]
except KeyError:
pass
return wa_response
class CodeInterpreterTool(BaseTool):
def __init__(self) -> None:
ctx = CodeExecutionContext(
matplotlib_dump_dir=tempfile.mkdtemp(),
)
self.code_executor = CodeExecutor(ctx)
def get_name(self) -> str:
return BuiltinTool.code_interpreter.value
async def run(self, messages: List[CompletionMessage]) -> List[ToolResponseMessage]:
message = messages[0]
assert len(message.tool_calls) == 1, "Expected a single tool call"
tool_call = messages[0].tool_calls[0]
script = tool_call.arguments["code"]
req = CodeExecutionRequest(scripts=[script])
res = self.code_executor.execute(req)
pieces = [res["process_status"]]
for out_type in ["stdout", "stderr"]:
res_out = res[out_type]
if res_out != "":
pieces.extend([f"[{out_type}]", res_out, f"[/{out_type}]"])
if out_type == "stderr":
log.error(f"ipython tool error: ↓\n{res_out}")
message = ToolResponseMessage(
call_id=tool_call.call_id,
tool_name=tool_call.tool_name,
content="\n".join(pieces),
)
return [message]

View file

@ -1,42 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List
from llama_stack.apis.inference import Message
from llama_stack.apis.safety import Safety
from ..safety import ShieldRunnerMixin
from .builtin import BaseTool
class SafeTool(BaseTool, ShieldRunnerMixin):
"""A tool that makes other tools safety enabled"""
def __init__(
self,
tool: BaseTool,
safety_api: Safety,
input_shields: List[str] = None,
output_shields: List[str] = None,
):
self._tool = tool
ShieldRunnerMixin.__init__(
self, safety_api, input_shields=input_shields, output_shields=output_shields
)
def get_name(self) -> str:
return self._tool.get_name()
async def run(self, messages: List[Message]) -> List[Message]:
if self.input_shields:
await self.run_multiple_shields(messages, self.input_shields)
# run the underlying tool
res = await self._tool.run(messages)
if self.output_shields:
await self.run_multiple_shields(messages, self.output_shields)
return res

View file

@ -0,0 +1,16 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .code_interpreter import CodeInterpreterToolRuntimeImpl
from .config import CodeInterpreterToolConfig
__all__ = ["CodeInterpreterToolConfig", "CodeInterpreterToolRuntimeImpl"]
async def get_provider_impl(config: CodeInterpreterToolConfig, _deps):
impl = CodeInterpreterToolRuntimeImpl(config)
await impl.initialize()
return impl

View file

@ -0,0 +1,75 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
import tempfile
from typing import Any, Dict, List, Optional
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import (
Tool,
ToolDef,
ToolInvocationResult,
ToolParameter,
ToolRuntime,
)
from llama_stack.providers.datatypes import ToolsProtocolPrivate
from .code_execution import CodeExecutionContext, CodeExecutionRequest, CodeExecutor
from .config import CodeInterpreterToolConfig
log = logging.getLogger(__name__)
class CodeInterpreterToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
def __init__(self, config: CodeInterpreterToolConfig):
self.config = config
ctx = CodeExecutionContext(
matplotlib_dump_dir=tempfile.mkdtemp(),
)
self.code_executor = CodeExecutor(ctx)
async def initialize(self):
pass
async def register_tool(self, tool: Tool):
pass
async def unregister_tool(self, tool_id: str) -> None:
return
async def list_runtime_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
) -> List[ToolDef]:
return [
ToolDef(
name="code_interpreter",
description="Execute code",
parameters=[
ToolParameter(
name="code",
description="The code to execute",
parameter_type="string",
),
],
)
]
async def invoke_tool(
self, tool_name: str, args: Dict[str, Any]
) -> ToolInvocationResult:
script = args["code"]
req = CodeExecutionRequest(scripts=[script])
res = self.code_executor.execute(req)
pieces = [res["process_status"]]
for out_type in ["stdout", "stderr"]:
res_out = res[out_type]
if res_out != "":
pieces.extend([f"[{out_type}]", res_out, f"[/{out_type}]"])
if out_type == "stderr":
log.error(f"ipython tool error: ↓\n{res_out}")
return ToolInvocationResult(content="\n".join(pieces))
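
A minimal sketch of driving this runtime directly, using the exports from the `__init__` above and the provider module path registered later in the registry; the returned content carries the process status plus any [stdout]/[stderr] sections.

```
import asyncio

from llama_stack.providers.inline.tool_runtime.code_interpreter import (
    CodeInterpreterToolConfig,
    CodeInterpreterToolRuntimeImpl,
)


async def main() -> None:
    runtime = CodeInterpreterToolRuntimeImpl(CodeInterpreterToolConfig())
    await runtime.initialize()
    result = await runtime.invoke_tool("code_interpreter", {"code": "print(2 + 2)"})
    print(result.content)


asyncio.run(main())
```
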

View file

@ -3,3 +3,9 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel
class CodeInterpreterToolConfig(BaseModel):
pass

View file

@ -0,0 +1,20 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from llama_stack.providers.datatypes import Api
from .config import MemoryToolRuntimeConfig
from .memory import MemoryToolRuntimeImpl
async def get_provider_impl(config: MemoryToolRuntimeConfig, deps: Dict[str, Any]):
impl = MemoryToolRuntimeImpl(
config, deps[Api.memory], deps[Api.memory_banks], deps[Api.inference]
)
await impl.initialize()
return impl

View file

@ -0,0 +1,90 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from typing import Annotated, List, Literal, Union
from pydantic import BaseModel, Field
class _MemoryBankConfigCommon(BaseModel):
bank_id: str
class VectorMemoryBankConfig(_MemoryBankConfigCommon):
type: Literal["vector"] = "vector"
class KeyValueMemoryBankConfig(_MemoryBankConfigCommon):
type: Literal["keyvalue"] = "keyvalue"
keys: List[str] # what keys to focus on
class KeywordMemoryBankConfig(_MemoryBankConfigCommon):
type: Literal["keyword"] = "keyword"
class GraphMemoryBankConfig(_MemoryBankConfigCommon):
type: Literal["graph"] = "graph"
entities: List[str] # what entities to focus on
MemoryBankConfig = Annotated[
Union[
VectorMemoryBankConfig,
KeyValueMemoryBankConfig,
KeywordMemoryBankConfig,
GraphMemoryBankConfig,
],
Field(discriminator="type"),
]
class MemoryQueryGenerator(Enum):
default = "default"
llm = "llm"
custom = "custom"
class DefaultMemoryQueryGeneratorConfig(BaseModel):
type: Literal[MemoryQueryGenerator.default.value] = (
MemoryQueryGenerator.default.value
)
sep: str = " "
class LLMMemoryQueryGeneratorConfig(BaseModel):
type: Literal[MemoryQueryGenerator.llm.value] = MemoryQueryGenerator.llm.value
model: str
template: str
class CustomMemoryQueryGeneratorConfig(BaseModel):
type: Literal[MemoryQueryGenerator.custom.value] = MemoryQueryGenerator.custom.value
MemoryQueryGeneratorConfig = Annotated[
Union[
DefaultMemoryQueryGeneratorConfig,
LLMMemoryQueryGeneratorConfig,
CustomMemoryQueryGeneratorConfig,
],
Field(discriminator="type"),
]
class MemoryToolConfig(BaseModel):
memory_bank_configs: List[MemoryBankConfig] = Field(default_factory=list)
class MemoryToolRuntimeConfig(BaseModel):
# This config defines how a query is generated using the messages
# for memory bank retrieval.
query_generator_config: MemoryQueryGeneratorConfig = Field(
default=DefaultMemoryQueryGeneratorConfig()
)
max_tokens_in_context: int = 4096
max_chunks: int = 5
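
A short sketch of how this config might look when the retrieval query should be rewritten by a model instead of simple concatenation; the model name and Jinja template are placeholders.

```
from llama_stack.providers.inline.tool_runtime.memory.config import (
    LLMMemoryQueryGeneratorConfig,
    MemoryToolRuntimeConfig,
)

# Placeholder model and template; the discriminated "type" field selects the generator.
config = MemoryToolRuntimeConfig(
    query_generator_config=LLMMemoryQueryGeneratorConfig(
        model="meta-llama/Llama-3.1-8B-Instruct",
        template="Summarize the user's question: {{ messages }}",
    ),
    max_tokens_in_context=4096,
    max_chunks=5,
)
```
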

View file

@ -4,25 +4,29 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List
from jinja2 import Template
from pydantic import BaseModel
from llama_stack.apis.agents import (
from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.inference import UserMessage
from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str,
)
from .config import (
DefaultMemoryQueryGeneratorConfig,
LLMMemoryQueryGeneratorConfig,
MemoryQueryGenerator,
MemoryQueryGeneratorConfig,
)
from llama_stack.apis.inference import Message, UserMessage
from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str,
)
async def generate_rag_query(
config: MemoryQueryGeneratorConfig,
messages: List[Message],
messages: List[InterleavedContent],
**kwargs,
):
"""
@ -40,21 +44,26 @@ async def generate_rag_query(
async def default_rag_query_generator(
config: DefaultMemoryQueryGeneratorConfig,
messages: List[Message],
messages: List[InterleavedContent],
**kwargs,
):
return config.sep.join(interleaved_content_as_str(m.content) for m in messages)
return config.sep.join(interleaved_content_as_str(m) for m in messages)
async def llm_rag_query_generator(
config: LLMMemoryQueryGeneratorConfig,
messages: List[Message],
messages: List[InterleavedContent],
**kwargs,
):
assert "inference_api" in kwargs, "LLMRAGQueryGenerator needs inference_api"
inference_api = kwargs["inference_api"]
m_dict = {"messages": [m.model_dump() for m in messages]}
m_dict = {
"messages": [
message.model_dump() if isinstance(message, BaseModel) else message
for message in messages
]
}
template = Template(config.template)
content = template.render(m_dict)

View file

@ -0,0 +1,146 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import logging
import secrets
import string
from typing import Any, Dict, List, Optional
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.inference import Inference, InterleavedContent
from llama_stack.apis.memory import Memory, QueryDocumentsResponse
from llama_stack.apis.memory_banks import MemoryBanks
from llama_stack.apis.tools import (
ToolDef,
ToolInvocationResult,
ToolParameter,
ToolRuntime,
)
from llama_stack.providers.datatypes import ToolsProtocolPrivate
from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content
from .config import MemoryToolConfig, MemoryToolRuntimeConfig
from .context_retriever import generate_rag_query
log = logging.getLogger(__name__)
def make_random_string(length: int = 8):
return "".join(
secrets.choice(string.ascii_letters + string.digits) for _ in range(length)
)
class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
def __init__(
self,
config: MemoryToolRuntimeConfig,
memory_api: Memory,
memory_banks_api: MemoryBanks,
inference_api: Inference,
):
self.config = config
self.memory_api = memory_api
self.memory_banks_api = memory_banks_api
self.inference_api = inference_api
async def initialize(self):
pass
async def list_runtime_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
) -> List[ToolDef]:
return [
ToolDef(
name="query_memory",
description="Retrieve context from memory",
parameters=[
ToolParameter(
name="messages",
description="The input messages to search for",
parameter_type="array",
),
],
)
]
async def _retrieve_context(
self, input_messages: List[InterleavedContent], bank_ids: List[str]
) -> Optional[List[InterleavedContent]]:
if not bank_ids:
return None
query = await generate_rag_query(
self.config.query_generator_config,
input_messages,
inference_api=self.inference_api,
)
tasks = [
self.memory_api.query_documents(
bank_id=bank_id,
query=query,
params={
"max_chunks": self.config.max_chunks,
},
)
for bank_id in bank_ids
]
results: List[QueryDocumentsResponse] = await asyncio.gather(*tasks)
chunks = [c for r in results for c in r.chunks]
scores = [s for r in results for s in r.scores]
if not chunks:
return None
# sort by score
chunks, scores = zip(
*sorted(zip(chunks, scores), key=lambda x: x[1], reverse=True)
)
tokens = 0
picked = []
for c in chunks[: self.config.max_chunks]:
tokens += c.token_count
if tokens > self.config.max_tokens_in_context:
log.error(
f"Using {len(picked)} chunks; reached max tokens in context: {tokens}",
)
break
picked.append(f"id:{c.document_id}; content:{c.content}")
return [
"Here are the retrieved documents for relevant context:\n=== START-RETRIEVED-CONTEXT ===\n",
*picked,
"\n=== END-RETRIEVED-CONTEXT ===\n",
]
async def invoke_tool(
self, tool_name: str, args: Dict[str, Any]
) -> ToolInvocationResult:
tool = await self.tool_store.get_tool(tool_name)
tool_group = await self.tool_store.get_tool_group(tool.toolgroup_id)
final_args = tool_group.args or {}
final_args.update(args)
config = MemoryToolConfig()
if tool.metadata and tool.metadata.get("config") is not None:
config = MemoryToolConfig(**tool.metadata["config"])
if "memory_bank_ids" in final_args:
bank_ids = final_args["memory_bank_ids"]
else:
bank_ids = [
bank_config.bank_id for bank_config in config.memory_bank_configs
]
if "messages" not in final_args:
raise ValueError("messages are required")
context = await self._retrieve_context(
final_args["messages"],
bank_ids,
)
if context is None:
context = []
return ToolInvocationResult(
content=concat_interleaved_content(context), error_code=0
)
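
For reference, a sketch of the two paths by which bank ids reach invoke_tool: defaults stored in the tool's metadata and read back as MemoryToolConfig, or args injected at call time (for example from registered toolgroup args), which take precedence; the bank id and message are placeholders.

```
from llama_stack.providers.inline.tool_runtime.memory.config import (
    MemoryToolConfig,
    VectorMemoryBankConfig,
)

# Option 1: defaults baked into the tool's metadata, read back via tool.metadata["config"].
metadata = {
    "config": MemoryToolConfig(
        memory_bank_configs=[VectorMemoryBankConfig(bank_id="test_memory_bank")]
    ).model_dump()
}

# Option 2: args injected at call time, which take precedence over the metadata config.
args = {
    "messages": ["What does the tutorial say about LoRA finetuning?"],
    "memory_bank_ids": ["test_memory_bank"],
}
```
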

View file

@ -35,6 +35,8 @@ def available_providers() -> List[ProviderSpec]:
Api.safety,
Api.memory,
Api.memory_banks,
Api.tool_runtime,
Api.tool_groups,
],
),
remote_provider_spec(

View file

@ -19,11 +19,58 @@ def available_providers() -> List[ProviderSpec]:
return [
InlineProviderSpec(
api=Api.tool_runtime,
provider_type="inline::brave-search",
provider_type="inline::memory-runtime",
pip_packages=[],
module="llama_stack.providers.inline.tool_runtime.brave_search",
config_class="llama_stack.providers.inline.tool_runtime.brave_search.config.BraveSearchToolConfig",
provider_data_validator="llama_stack.providers.inline.tool_runtime.brave_search.BraveSearchToolProviderDataValidator",
module="llama_stack.providers.inline.tool_runtime.memory",
config_class="llama_stack.providers.inline.tool_runtime.memory.config.MemoryToolRuntimeConfig",
api_dependencies=[Api.memory, Api.memory_banks, Api.inference],
),
InlineProviderSpec(
api=Api.tool_runtime,
provider_type="inline::code-interpreter",
pip_packages=[],
module="llama_stack.providers.inline.tool_runtime.code_interpreter",
config_class="llama_stack.providers.inline.tool_runtime.code_interpreter.config.CodeInterpreterToolConfig",
),
remote_provider_spec(
api=Api.tool_runtime,
adapter=AdapterSpec(
adapter_type="brave-search",
module="llama_stack.providers.remote.tool_runtime.brave_search",
config_class="llama_stack.providers.remote.tool_runtime.brave_search.config.BraveSearchToolConfig",
pip_packages=["requests"],
provider_data_validator="llama_stack.providers.remote.tool_runtime.brave_search.BraveSearchToolProviderDataValidator",
),
),
remote_provider_spec(
api=Api.tool_runtime,
adapter=AdapterSpec(
adapter_type="bing-search",
module="llama_stack.providers.remote.tool_runtime.bing_search",
config_class="llama_stack.providers.remote.tool_runtime.bing_search.config.BingSearchToolConfig",
pip_packages=["requests"],
provider_data_validator="llama_stack.providers.remote.tool_runtime.bing_search.BingSearchToolProviderDataValidator",
),
),
remote_provider_spec(
api=Api.tool_runtime,
adapter=AdapterSpec(
adapter_type="tavily-search",
module="llama_stack.providers.remote.tool_runtime.tavily_search",
config_class="llama_stack.providers.remote.tool_runtime.tavily_search.config.TavilySearchToolConfig",
pip_packages=["requests"],
provider_data_validator="llama_stack.providers.remote.tool_runtime.tavily_search.TavilySearchToolProviderDataValidator",
),
),
remote_provider_spec(
api=Api.tool_runtime,
adapter=AdapterSpec(
adapter_type="wolfram-alpha",
module="llama_stack.providers.remote.tool_runtime.wolfram_alpha",
config_class="llama_stack.providers.remote.tool_runtime.wolfram_alpha.config.WolframAlphaToolConfig",
pip_packages=["requests"],
provider_data_validator="llama_stack.providers.remote.tool_runtime.wolfram_alpha.WolframAlphaToolProviderDataValidator",
),
),
remote_provider_spec(
api=Api.tool_runtime,

View file

@ -7,11 +7,8 @@
from typing import AsyncGenerator, List, Optional, Union
from llama_models.datatypes import CoreModelId
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer
from together import Together
from llama_stack.apis.common.content_types import InterleavedContent
@ -53,7 +50,6 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .config import TogetherImplConfig
MODEL_ALIASES = [
build_model_alias(
"meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",

View file

@ -0,0 +1,21 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .bing_search import BingSearchToolRuntimeImpl
from .config import BingSearchToolConfig
__all__ = ["BingSearchToolConfig", "BingSearchToolRuntimeImpl"]
from pydantic import BaseModel
class BingSearchToolProviderDataValidator(BaseModel):
api_key: str
async def get_adapter_impl(config: BingSearchToolConfig, _deps):
impl = BingSearchToolRuntimeImpl(config)
await impl.initialize()
return impl

View file

@ -0,0 +1,114 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
from typing import Any, Dict, List, Optional
import requests
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import (
Tool,
ToolDef,
ToolInvocationResult,
ToolParameter,
ToolRuntime,
)
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.datatypes import ToolsProtocolPrivate
from .config import BingSearchToolConfig
class BingSearchToolRuntimeImpl(
ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData
):
def __init__(self, config: BingSearchToolConfig):
self.config = config
self.url = "https://api.bing.microsoft.com/v7.0/search"
async def initialize(self):
pass
async def register_tool(self, tool: Tool):
pass
async def unregister_tool(self, tool_id: str) -> None:
return
def _get_api_key(self) -> str:
if self.config.api_key:
return self.config.api_key
provider_data = self.get_request_provider_data()
if provider_data is None or not provider_data.api_key:
raise ValueError(
'Pass Bing Search API Key in the header X-LlamaStack-ProviderData as { "api_key": <your api key>}'
)
return provider_data.api_key
async def list_runtime_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
) -> List[ToolDef]:
return [
ToolDef(
name="web_search",
description="Search the web using Bing Search API",
parameters=[
ToolParameter(
name="query",
description="The query to search for",
parameter_type="string",
)
],
)
]
async def invoke_tool(
self, tool_name: str, args: Dict[str, Any]
) -> ToolInvocationResult:
api_key = self._get_api_key()
headers = {
"Ocp-Apim-Subscription-Key": api_key,
}
params = {
"count": self.config.top_k,
"textDecorations": True,
"textFormat": "HTML",
"q": args["query"],
}
response = requests.get(
url=self.url,
params=params,
headers=headers,
)
response.raise_for_status()
return ToolInvocationResult(
content=json.dumps(self._clean_response(response.json()))
)
def _clean_response(self, search_response):
clean_response = []
query = search_response["queryContext"]["originalQuery"]
if "webPages" in search_response:
pages = search_response["webPages"]["value"]
for p in pages:
selected_keys = {"name", "url", "snippet"}
clean_response.append(
{k: v for k, v in p.items() if k in selected_keys}
)
if "news" in search_response:
clean_news = []
news = search_response["news"]["value"]
for n in news:
selected_keys = {"name", "url", "description"}
clean_news.append({k: v for k, v in n.items() if k in selected_keys})
clean_response.append(clean_news)
return {"query": query, "top_k": clean_response}

View file

@ -0,0 +1,16 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Optional
from pydantic import BaseModel
class BingSearchToolConfig(BaseModel):
"""Configuration for Bing Search Tool Runtime"""
api_key: Optional[str] = None
top_k: int = 3

View file

@ -14,7 +14,7 @@ class BraveSearchToolProviderDataValidator(BaseModel):
api_key: str
async def get_provider_impl(config: BraveSearchToolConfig, _deps):
async def get_adapter_impl(config: BraveSearchToolConfig, _deps):
impl = BraveSearchToolRuntimeImpl(config)
await impl.initialize()
return impl

View file

@ -4,11 +4,19 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional
import requests
from llama_models.llama3.api.datatypes import BuiltinTool
from llama_stack.apis.tools import Tool, ToolGroupDef, ToolInvocationResult, ToolRuntime
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import (
Tool,
ToolDef,
ToolInvocationResult,
ToolParameter,
ToolRuntime,
)
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.datatypes import ToolsProtocolPrivate
@ -25,8 +33,7 @@ class BraveSearchToolRuntimeImpl(
pass
async def register_tool(self, tool: Tool):
if tool.identifier != "brave_search":
raise ValueError(f"Tool identifier {tool.identifier} is not supported")
pass
async def unregister_tool(self, tool_id: str) -> None:
return
@ -42,8 +49,23 @@ class BraveSearchToolRuntimeImpl(
)
return provider_data.api_key
async def discover_tools(self, tool_group: ToolGroupDef) -> List[Tool]:
raise NotImplementedError("Brave search tool group not supported")
async def list_runtime_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
) -> List[ToolDef]:
return [
ToolDef(
name="web_search",
description="Search the web for information",
parameters=[
ToolParameter(
name="query",
description="The query to search for",
parameter_type="string",
)
],
built_in_type=BuiltinTool.brave_search,
)
]
async def invoke_tool(
self, tool_name: str, args: Dict[str, Any]

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Optional
from typing import Any, Dict, Optional
from pydantic import BaseModel, Field
@ -18,3 +18,10 @@ class BraveSearchToolConfig(BaseModel):
default=3,
description="The maximum number of results to return",
)
@classmethod
def sample_run_config(cls, __distro_dir__: str) -> Dict[str, Any]:
return {
"api_key": "${env.BRAVE_SEARCH_API_KEY:}",
"max_results": 3,
}

View file

@ -4,22 +4,21 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional
from urllib.parse import urlparse
from mcp import ClientSession
from mcp.client.sse import sse_client
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import (
MCPToolGroupDef,
ToolDef,
ToolGroupDef,
ToolInvocationResult,
ToolParameter,
ToolRuntime,
)
from llama_stack.providers.datatypes import ToolsProtocolPrivate
from mcp import ClientSession
from mcp.client.sse import sse_client
from .config import ModelContextProtocolConfig
@ -30,12 +29,14 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
async def initialize(self):
pass
async def discover_tools(self, tool_group: ToolGroupDef) -> List[ToolDef]:
if not isinstance(tool_group, MCPToolGroupDef):
raise ValueError(f"Unsupported tool group type: {type(tool_group)}")
async def list_runtime_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
) -> List[ToolDef]:
if mcp_endpoint is None:
raise ValueError("mcp_endpoint is required")
tools = []
async with sse_client(tool_group.endpoint.uri) as streams:
async with sse_client(mcp_endpoint.uri) as streams:
async with ClientSession(*streams) as session:
await session.initialize()
tools_result = await session.list_tools()
@ -57,7 +58,7 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
description=tool.description,
parameters=parameters,
metadata={
"endpoint": tool_group.endpoint.uri,
"endpoint": mcp_endpoint.uri,
},
)
)

View file

@ -0,0 +1,20 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel
from .config import TavilySearchToolConfig
from .tavily_search import TavilySearchToolRuntimeImpl
class TavilySearchToolProviderDataValidator(BaseModel):
api_key: str
async def get_adapter_impl(config: TavilySearchToolConfig, _deps):
impl = TavilySearchToolRuntimeImpl(config)
await impl.initialize()
return impl

View file

@ -0,0 +1,27 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, Optional
from pydantic import BaseModel, Field
class TavilySearchToolConfig(BaseModel):
api_key: Optional[str] = Field(
default=None,
description="The Tavily Search API Key",
)
max_results: int = Field(
default=3,
description="The maximum number of results to return",
)
@classmethod
def sample_run_config(cls, __distro_dir__: str) -> Dict[str, Any]:
return {
"api_key": "${env.TAVILY_SEARCH_API_KEY:}",
"max_results": 3,
}

View file

@ -0,0 +1,83 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
from typing import Any, Dict, List, Optional
import requests
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import (
Tool,
ToolDef,
ToolInvocationResult,
ToolParameter,
ToolRuntime,
)
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.datatypes import ToolsProtocolPrivate
from .config import TavilySearchToolConfig
class TavilySearchToolRuntimeImpl(
ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData
):
def __init__(self, config: TavilySearchToolConfig):
self.config = config
async def initialize(self):
pass
async def register_tool(self, tool: Tool):
pass
async def unregister_tool(self, tool_id: str) -> None:
return
def _get_api_key(self) -> str:
if self.config.api_key:
return self.config.api_key
provider_data = self.get_request_provider_data()
if provider_data is None or not provider_data.api_key:
raise ValueError(
'Pass Search provider\'s API Key in the header X-LlamaStack-ProviderData as { "api_key": <your api key>}'
)
return provider_data.api_key
async def list_runtime_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
) -> List[ToolDef]:
return [
ToolDef(
name="web_search",
description="Search the web for information",
parameters=[
ToolParameter(
name="query",
description="The query to search for",
parameter_type="string",
)
],
)
]
async def invoke_tool(
self, tool_name: str, args: Dict[str, Any]
) -> ToolInvocationResult:
api_key = self._get_api_key()
response = requests.post(
"https://api.tavily.com/search",
json={"api_key": api_key, "query": args["query"]},
)
return ToolInvocationResult(
content=json.dumps(self._clean_tavily_response(response.json()))
)
def _clean_tavily_response(self, search_response, top_k=3):
return {"query": search_response["query"], "top_k": search_response["results"]}

View file

@ -0,0 +1,22 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel
from .config import WolframAlphaToolConfig
from .wolfram_alpha import WolframAlphaToolRuntimeImpl
__all__ = ["WolframAlphaToolConfig", "WolframAlphaToolRuntimeImpl"]
class WolframAlphaToolProviderDataValidator(BaseModel):
api_key: str
async def get_adapter_impl(config: WolframAlphaToolConfig, _deps):
impl = WolframAlphaToolRuntimeImpl(config)
await impl.initialize()
return impl

View file

@ -0,0 +1,15 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Optional
from pydantic import BaseModel
class WolframAlphaToolConfig(BaseModel):
"""Configuration for WolframAlpha Tool Runtime"""
api_key: Optional[str] = None

View file

@ -0,0 +1,146 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
from typing import Any, Dict, List, Optional
import requests
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import (
Tool,
ToolDef,
ToolInvocationResult,
ToolParameter,
ToolRuntime,
)
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.datatypes import ToolsProtocolPrivate
from .config import WolframAlphaToolConfig
class WolframAlphaToolRuntimeImpl(
ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData
):
def __init__(self, config: WolframAlphaToolConfig):
self.config = config
self.url = "https://api.wolframalpha.com/v2/query"
async def initialize(self):
pass
async def register_tool(self, tool: Tool):
pass
async def unregister_tool(self, tool_id: str) -> None:
return
def _get_api_key(self) -> str:
if self.config.api_key:
return self.config.api_key
provider_data = self.get_request_provider_data()
if provider_data is None or not provider_data.api_key:
raise ValueError(
'Pass WolframAlpha API Key in the header X-LlamaStack-ProviderData as { "api_key": <your api key>}'
)
return provider_data.api_key
async def list_runtime_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
) -> List[ToolDef]:
return [
ToolDef(
name="wolfram_alpha",
description="Query WolframAlpha for computational knowledge",
parameters=[
ToolParameter(
name="query",
description="The query to compute",
parameter_type="string",
)
],
)
]
async def invoke_tool(
self, tool_name: str, args: Dict[str, Any]
) -> ToolInvocationResult:
api_key = self._get_api_key()
params = {
"input": args["query"],
"appid": api_key,
"format": "plaintext",
"output": "json",
}
response = requests.get(
self.url,
params=params,
)
return ToolInvocationResult(
content=json.dumps(self._clean_wolfram_alpha_response(response.json()))
)
def _clean_wolfram_alpha_response(self, wa_response):
remove = {
"queryresult": [
"datatypes",
"error",
"timedout",
"timedoutpods",
"numpods",
"timing",
"parsetiming",
"parsetimedout",
"recalculate",
"id",
"host",
"server",
"related",
"version",
{
"pods": [
"scanner",
"id",
"error",
"expressiontypes",
"states",
"infos",
"position",
"numsubpods",
]
},
"assumptions",
],
}
for main_key in remove:
for key_to_remove in remove[main_key]:
try:
if key_to_remove == "assumptions":
if "assumptions" in wa_response[main_key]:
del wa_response[main_key][key_to_remove]
if isinstance(key_to_remove, dict):
for sub_key in key_to_remove:
if sub_key == "pods":
for i in range(len(wa_response[main_key][sub_key])):
if (
wa_response[main_key][sub_key][i]["title"]
== "Result"
):
del wa_response[main_key][sub_key][i + 1 :]
break
sub_items = wa_response[main_key][sub_key]
for i in range(len(sub_items)):
for sub_key_to_remove in key_to_remove[sub_key]:
if sub_key_to_remove in sub_items[i]:
del sub_items[i][sub_key_to_remove]
elif key_to_remove in wa_response[main_key]:
del wa_response[main_key][key_to_remove]
except KeyError:
pass
return wa_response

View file

@ -7,13 +7,12 @@
import pytest
from ..conftest import get_provider_fixture_overrides
from ..inference.fixtures import INFERENCE_FIXTURES
from ..memory.fixtures import MEMORY_FIXTURES
from ..safety.fixtures import SAFETY_FIXTURES, safety_model_from_shield
from ..tools.fixtures import TOOL_RUNTIME_FIXTURES
from .fixtures import AGENTS_FIXTURES
DEFAULT_PROVIDER_COMBINATIONS = [
pytest.param(
{
@ -21,6 +20,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
"safety": "llama_guard",
"memory": "faiss",
"agents": "meta_reference",
"tool_runtime": "memory_and_search",
},
id="meta_reference",
marks=pytest.mark.meta_reference,
@ -31,6 +31,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
"safety": "llama_guard",
"memory": "faiss",
"agents": "meta_reference",
"tool_runtime": "memory_and_search",
},
id="ollama",
marks=pytest.mark.ollama,
@ -42,6 +43,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
# make this work with Weaviate which is what the together distro supports
"memory": "faiss",
"agents": "meta_reference",
"tool_runtime": "memory_and_search",
},
id="together",
marks=pytest.mark.together,
@ -52,6 +54,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
"safety": "llama_guard",
"memory": "faiss",
"agents": "meta_reference",
"tool_runtime": "memory_and_search",
},
id="fireworks",
marks=pytest.mark.fireworks,
@ -62,6 +65,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
"safety": "remote",
"memory": "remote",
"agents": "remote",
"tool_runtime": "memory_and_search",
},
id="remote",
marks=pytest.mark.remote,
@ -117,6 +121,7 @@ def pytest_generate_tests(metafunc):
"safety": SAFETY_FIXTURES,
"memory": MEMORY_FIXTURES,
"agents": AGENTS_FIXTURES,
"tool_runtime": TOOL_RUNTIME_FIXTURES,
}
combinations = (
get_provider_fixture_overrides(metafunc.config, available_fixtures)

View file

@ -11,13 +11,12 @@ import pytest_asyncio
from llama_stack.apis.models import ModelInput, ModelType
from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.inline.agents.meta_reference import (
MetaReferenceAgentsImplConfig,
)
from llama_stack.providers.tests.resolver import construct_stack_for_test
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
from ..conftest import ProviderFixture, remote_stack_fixture
@ -59,12 +58,18 @@ AGENTS_FIXTURES = ["meta_reference", "remote"]
@pytest_asyncio.fixture(scope="session")
async def agents_stack(request, inference_model, safety_shield):
async def agents_stack(
request,
inference_model,
safety_shield,
tool_group_input_memory,
tool_group_input_tavily_search,
):
fixture_dict = request.param
providers = {}
provider_data = {}
for key in ["inference", "safety", "memory", "agents"]:
for key in ["inference", "safety", "memory", "agents", "tool_runtime"]:
fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
providers[key] = fixture.providers
if key == "inference":
@ -113,10 +118,11 @@ async def agents_stack(request, inference_model, safety_shield):
)
test_stack = await construct_stack_for_test(
[Api.agents, Api.inference, Api.safety, Api.memory],
[Api.agents, Api.inference, Api.safety, Api.memory, Api.tool_runtime],
providers,
provider_data,
models=models,
shields=[safety_shield] if safety_shield else [],
tool_groups=[tool_group_input_memory, tool_group_input_tavily_search],
)
return test_stack

View file

@ -5,22 +5,17 @@
# the root directory of this source tree.
import os
from typing import Dict, List
import pytest
from llama_models.llama3.api.datatypes import BuiltinTool
from llama_stack.apis.agents import (
AgentConfig,
AgentTool,
AgentTurnResponseEventType,
AgentTurnResponseStepCompletePayload,
AgentTurnResponseStreamChunk,
AgentTurnResponseTurnCompletePayload,
Attachment,
MemoryToolDefinition,
SearchEngineType,
SearchToolDefinition,
Document,
ShieldCallStep,
StepType,
ToolChoice,
@ -35,7 +30,6 @@ from llama_stack.providers.datatypes import Api
#
# pytest -v -s llama_stack/providers/tests/agents/test_agents.py
# -m "meta_reference"
from .fixtures import pick_inference_model
from .utils import create_agent_session
@ -51,7 +45,7 @@ def common_params(inference_model):
sampling_params=SamplingParams(temperature=0.7, top_p=0.95),
input_shields=[],
output_shields=[],
tools=[],
toolgroups=[],
max_infer_iters=5,
)
@ -88,73 +82,6 @@ def query_attachment_messages():
]
async def create_agent_turn_with_search_tool(
agents_stack: Dict[str, object],
search_query_messages: List[object],
common_params: Dict[str, str],
search_tool_definition: SearchToolDefinition,
) -> None:
"""
Create an agent turn with a search tool.
Args:
agents_stack (Dict[str, object]): The agents stack.
search_query_messages (List[object]): The search query messages.
common_params (Dict[str, str]): The common parameters.
search_tool_definition (SearchToolDefinition): The search tool definition.
"""
# Create an agent with the search tool
agent_config = AgentConfig(
**{
**common_params,
"tools": [search_tool_definition],
}
)
agent_id, session_id = await create_agent_session(
agents_stack.impls[Api.agents], agent_config
)
turn_request = dict(
agent_id=agent_id,
session_id=session_id,
messages=search_query_messages,
stream=True,
)
turn_response = [
chunk
async for chunk in await agents_stack.impls[Api.agents].create_agent_turn(
**turn_request
)
]
assert len(turn_response) > 0
assert all(
isinstance(chunk, AgentTurnResponseStreamChunk) for chunk in turn_response
)
check_event_types(turn_response)
# Check for tool execution events
tool_execution_events = [
chunk
for chunk in turn_response
if isinstance(chunk.event.payload, AgentTurnResponseStepCompletePayload)
and chunk.event.payload.step_details.step_type == StepType.tool_execution.value
]
assert len(tool_execution_events) > 0, "No tool execution events found"
# Check the tool execution details
tool_execution = tool_execution_events[0].event.payload.step_details
assert isinstance(tool_execution, ToolExecutionStep)
assert len(tool_execution.tool_calls) > 0
assert tool_execution.tool_calls[0].tool_name == BuiltinTool.brave_search
assert len(tool_execution.tool_responses) > 0
check_turn_complete_event(turn_response, session_id, search_query_messages)
class TestAgents:
@pytest.mark.asyncio
async def test_agent_turns_with_safety(
@ -227,7 +154,7 @@ class TestAgents:
check_turn_complete_event(turn_response, session_id, sample_messages)
@pytest.mark.asyncio
async def test_rag_agent_as_attachments(
async def test_rag_agent(
self,
agents_stack,
attachment_message,
@ -243,29 +170,17 @@ class TestAgents:
"qat_finetune.rst",
"lora_finetune.rst",
]
attachments = [
Attachment(
documents = [
Document(
content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
mime_type="text/plain",
)
for i, url in enumerate(urls)
]
agent_config = AgentConfig(
**{
**common_params,
"tools": [
MemoryToolDefinition(
memory_bank_configs=[],
query_generator_config={
"type": "default",
"sep": " ",
},
max_tokens_in_context=4096,
max_chunks=10,
),
],
"toolgroups": ["builtin::memory"],
"tool_choice": ToolChoice.auto,
}
)
@ -275,7 +190,7 @@ class TestAgents:
agent_id=agent_id,
session_id=session_id,
messages=attachment_message,
attachments=attachments,
documents=documents,
stream=True,
)
turn_response = [
@ -298,22 +213,6 @@ class TestAgents:
assert len(turn_response) > 0
@pytest.mark.asyncio
async def test_create_agent_turn_with_brave_search(
self, agents_stack, search_query_messages, common_params
):
if "BRAVE_SEARCH_API_KEY" not in os.environ:
pytest.skip("BRAVE_SEARCH_API_KEY not set, skipping test")
search_tool_definition = SearchToolDefinition(
type=AgentTool.brave_search.value,
api_key=os.environ["BRAVE_SEARCH_API_KEY"],
engine=SearchEngineType.brave,
)
await create_agent_turn_with_search_tool(
agents_stack, search_query_messages, common_params, search_tool_definition
)
@pytest.mark.asyncio
async def test_create_agent_turn_with_tavily_search(
self, agents_stack, search_query_messages, common_params
@ -321,14 +220,57 @@ class TestAgents:
if "TAVILY_SEARCH_API_KEY" not in os.environ:
pytest.skip("TAVILY_SEARCH_API_KEY not set, skipping test")
search_tool_definition = SearchToolDefinition(
type=AgentTool.brave_search.value, # place holder only
api_key=os.environ["TAVILY_SEARCH_API_KEY"],
engine=SearchEngineType.tavily,
# Create an agent with the toolgroup
agent_config = AgentConfig(
**{
**common_params,
"toolgroups": ["builtin::web_search"],
}
)
await create_agent_turn_with_search_tool(
agents_stack, search_query_messages, common_params, search_tool_definition
agent_id, session_id = await create_agent_session(
agents_stack.impls[Api.agents], agent_config
)
turn_request = dict(
agent_id=agent_id,
session_id=session_id,
messages=search_query_messages,
stream=True,
)
turn_response = [
chunk
async for chunk in await agents_stack.impls[Api.agents].create_agent_turn(
**turn_request
)
]
assert len(turn_response) > 0
assert all(
isinstance(chunk, AgentTurnResponseStreamChunk) for chunk in turn_response
)
check_event_types(turn_response)
# Check for tool execution events
tool_execution_events = [
chunk
for chunk in turn_response
if isinstance(chunk.event.payload, AgentTurnResponseStepCompletePayload)
and chunk.event.payload.step_details.step_type
== StepType.tool_execution.value
]
assert len(tool_execution_events) > 0, "No tool execution events found"
# Check the tool execution details
tool_execution = tool_execution_events[0].event.payload.step_details
assert isinstance(tool_execution, ToolExecutionStep)
assert len(tool_execution.tool_calls) > 0
actual_tool_name = tool_execution.tool_calls[0].tool_name
assert actual_tool_name == BuiltinTool.brave_search
assert len(tool_execution.tool_responses) > 0
check_turn_complete_event(turn_response, session_id, search_query_messages)
def check_event_types(turn_response):
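
The `check_event_types` helper above is unchanged context. Condensed, the migrated flow swaps per-tool definitions for toolgroup names on `AgentConfig` and passes `Document`s with the turn. A minimal sketch of an async test body, not part of the diff, assuming the `common_params` and `attachment_message` fixtures from this file and the `agents_stack` fixture are in scope:

```
# Minimal sketch: an agent turn using a toolgroup plus documents, mirroring the
# test fragments above.
agent_config = AgentConfig(
    **{
        **common_params,                    # fixture defined earlier in this file
        "toolgroups": ["builtin::memory"],  # resolved through the Tools API
        "tool_choice": ToolChoice.auto,
    }
)
agent_id, session_id = await create_agent_session(
    agents_stack.impls[Api.agents], agent_config
)
documents = [
    Document(
        content="https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/chat.rst",
        mime_type="text/plain",
    )
]
turn_response = [
    chunk
    async for chunk in await agents_stack.impls[Api.agents].create_agent_turn(
        agent_id=agent_id,
        session_id=session_id,
        messages=attachment_message,        # fixture defined earlier in this file
        documents=documents,
        stream=True,
    )
]
assert len(turn_response) > 0
```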


@ -157,4 +157,5 @@ pytest_plugins = [
"llama_stack.providers.tests.scoring.fixtures",
"llama_stack.providers.tests.eval.fixtures",
"llama_stack.providers.tests.post_training.fixtures",
"llama_stack.providers.tests.tools.fixtures",
]


@ -19,6 +19,7 @@ from llama_stack.providers.remote.memory.pgvector import PGVectorConfig
from llama_stack.providers.remote.memory.weaviate import WeaviateConfig
from llama_stack.providers.tests.resolver import construct_stack_for_test
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
from ..conftest import ProviderFixture, remote_stack_fixture
from ..env import get_env_or_fail

View file

@ -16,7 +16,7 @@ from llama_stack.apis.memory_banks import MemoryBankInput
from llama_stack.apis.models import ModelInput
from llama_stack.apis.scoring_functions import ScoringFnInput
from llama_stack.apis.shields import ShieldInput
from llama_stack.apis.tools import ToolGroupInput
from llama_stack.distribution.build import print_pip_install_help
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
from llama_stack.distribution.datatypes import Provider, StackRunConfig
@ -43,6 +43,7 @@ async def construct_stack_for_test(
datasets: Optional[List[DatasetInput]] = None,
scoring_fns: Optional[List[ScoringFnInput]] = None,
eval_tasks: Optional[List[EvalTaskInput]] = None,
tool_groups: Optional[List[ToolGroupInput]] = None,
) -> TestStack:
sqlite_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
run_config = dict(
@ -56,6 +57,7 @@ async def construct_stack_for_test(
datasets=datasets or [],
scoring_fns=scoring_fns or [],
eval_tasks=eval_tasks or [],
tool_groups=tool_groups or [],
)
run_config = parse_and_maybe_upgrade_config(run_config)
try:


@ -0,0 +1,65 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
from ..conftest import get_provider_fixture_overrides
from ..inference.fixtures import INFERENCE_FIXTURES
from ..memory.fixtures import MEMORY_FIXTURES
from ..safety.fixtures import SAFETY_FIXTURES
from .fixtures import TOOL_RUNTIME_FIXTURES
DEFAULT_PROVIDER_COMBINATIONS = [
pytest.param(
{
"inference": "together",
"safety": "llama_guard",
"memory": "faiss",
"tool_runtime": "memory_and_search",
},
id="together",
marks=pytest.mark.together,
),
]
def pytest_configure(config):
for mark in ["together"]:
config.addinivalue_line(
"markers",
f"{mark}: marks tests as {mark} specific",
)
def pytest_addoption(parser):
parser.addoption(
"--inference-model",
action="store",
default="meta-llama/Llama-3.2-3B-Instruct",
help="Specify the inference model to use for testing",
)
parser.addoption(
"--safety-shield",
action="store",
default="meta-llama/Llama-Guard-3-1B",
help="Specify the safety shield to use for testing",
)
def pytest_generate_tests(metafunc):
if "tools_stack" in metafunc.fixturenames:
available_fixtures = {
"inference": INFERENCE_FIXTURES,
"safety": SAFETY_FIXTURES,
"memory": MEMORY_FIXTURES,
"tool_runtime": TOOL_RUNTIME_FIXTURES,
}
combinations = (
get_provider_fixture_overrides(metafunc.config, available_fixtures)
or DEFAULT_PROVIDER_COMBINATIONS
)
print(combinations)
metafunc.parametrize("tools_stack", combinations, indirect=True)


@ -0,0 +1,130 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
import pytest
import pytest_asyncio
from llama_stack.apis.models import ModelInput, ModelType
from llama_stack.apis.tools import ToolGroupInput
from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.tests.resolver import construct_stack_for_test
from ..conftest import ProviderFixture
@pytest.fixture(scope="session")
def tool_runtime_memory_and_search() -> ProviderFixture:
return ProviderFixture(
providers=[
Provider(
provider_id="memory-runtime",
provider_type="inline::memory-runtime",
config={},
),
Provider(
provider_id="tavily-search",
provider_type="remote::tavily-search",
config={
"api_key": os.environ["TAVILY_SEARCH_API_KEY"],
},
),
Provider(
provider_id="wolfram-alpha",
provider_type="remote::wolfram-alpha",
config={
"api_key": os.environ["WOLFRAM_ALPHA_API_KEY"],
},
),
],
)
@pytest.fixture(scope="session")
def tool_group_input_memory() -> ToolGroupInput:
return ToolGroupInput(
toolgroup_id="builtin::memory",
provider_id="memory-runtime",
)
@pytest.fixture(scope="session")
def tool_group_input_tavily_search() -> ToolGroupInput:
return ToolGroupInput(
toolgroup_id="builtin::web_search",
provider_id="tavily-search",
)
@pytest.fixture(scope="session")
def tool_group_input_wolfram_alpha() -> ToolGroupInput:
return ToolGroupInput(
toolgroup_id="builtin::wolfram_alpha",
provider_id="wolfram-alpha",
)
TOOL_RUNTIME_FIXTURES = ["memory_and_search"]
@pytest_asyncio.fixture(scope="session")
async def tools_stack(
request,
inference_model,
tool_group_input_memory,
tool_group_input_tavily_search,
tool_group_input_wolfram_alpha,
):
fixture_dict = request.param
providers = {}
provider_data = {}
for key in ["inference", "memory", "tool_runtime"]:
fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
providers[key] = fixture.providers
if key == "inference":
providers[key].append(
Provider(
provider_id="tools_memory_provider",
provider_type="inline::sentence-transformers",
config={},
)
)
if fixture.provider_data:
provider_data.update(fixture.provider_data)
inference_models = (
inference_model if isinstance(inference_model, list) else [inference_model]
)
models = [
ModelInput(
model_id=model,
model_type=ModelType.llm,
provider_id=providers["inference"][0].provider_id,
)
for model in inference_models
]
models.append(
ModelInput(
model_id="all-MiniLM-L6-v2",
model_type=ModelType.embedding,
provider_id="tools_memory_provider",
metadata={"embedding_dimension": 384},
)
)
test_stack = await construct_stack_for_test(
[Api.tool_groups, Api.inference, Api.memory, Api.tool_runtime],
providers,
provider_data,
models=models,
tool_groups=[
tool_group_input_tavily_search,
tool_group_input_wolfram_alpha,
tool_group_input_memory,
],
)
return test_stack
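
A consuming test resolves the tool runtime impl from this stack and invokes tools by name through the Tools API. A minimal sketch of an async test body, not part of the diff; the full versions follow in the new test_tools.py below:

```
# Minimal consumption sketch: the tool names map to the tool groups registered
# by the fixture above.
tools_impl = tools_stack.impls[Api.tool_runtime]
result = await tools_impl.invoke_tool(
    tool_name="web_search",
    args={"query": "What are the latest developments in quantum computing?"},
)
assert isinstance(result, ToolInvocationResult)
assert isinstance(result.content, str) and len(result.content) > 0
```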


@ -0,0 +1,127 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
import pytest
from llama_stack.apis.inference import UserMessage
from llama_stack.apis.memory import MemoryBankDocument
from llama_stack.apis.memory_banks import VectorMemoryBankParams
from llama_stack.apis.tools import ToolInvocationResult
from llama_stack.providers.datatypes import Api
@pytest.fixture
def sample_search_query():
return "What are the latest developments in quantum computing?"
@pytest.fixture
def sample_wolfram_alpha_query():
return "What is the square root of 16?"
@pytest.fixture
def sample_documents():
urls = [
"memory_optimizations.rst",
"chat.rst",
"llama3.rst",
"datasets.rst",
"qat_finetune.rst",
"lora_finetune.rst",
]
return [
MemoryBankDocument(
document_id=f"num-{i}",
content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
mime_type="text/plain",
metadata={},
)
for i, url in enumerate(urls)
]
class TestTools:
@pytest.mark.asyncio
async def test_web_search_tool(self, tools_stack, sample_search_query):
"""Test the web search tool functionality."""
if "TAVILY_SEARCH_API_KEY" not in os.environ:
pytest.skip("TAVILY_SEARCH_API_KEY not set, skipping test")
tools_impl = tools_stack.impls[Api.tool_runtime]
# Execute the tool
response = await tools_impl.invoke_tool(
tool_name="web_search", args={"query": sample_search_query}
)
# Verify the response
assert isinstance(response, ToolInvocationResult)
assert response.content is not None
assert len(response.content) > 0
assert isinstance(response.content, str)
@pytest.mark.asyncio
async def test_wolfram_alpha_tool(self, tools_stack, sample_wolfram_alpha_query):
"""Test the wolfram alpha tool functionality."""
if "WOLFRAM_ALPHA_API_KEY" not in os.environ:
pytest.skip("WOLFRAM_ALPHA_API_KEY not set, skipping test")
tools_impl = tools_stack.impls[Api.tool_runtime]
response = await tools_impl.invoke_tool(
tool_name="wolfram_alpha", args={"query": sample_wolfram_alpha_query}
)
# Verify the response
assert isinstance(response, ToolInvocationResult)
assert response.content is not None
assert len(response.content) > 0
assert isinstance(response.content, str)
@pytest.mark.asyncio
async def test_memory_tool(self, tools_stack, sample_documents):
"""Test the memory tool functionality."""
memory_banks_impl = tools_stack.impls[Api.memory_banks]
memory_impl = tools_stack.impls[Api.memory]
tools_impl = tools_stack.impls[Api.tool_runtime]
# Register memory bank
await memory_banks_impl.register_memory_bank(
memory_bank_id="test_bank",
params=VectorMemoryBankParams(
embedding_model="all-MiniLM-L6-v2",
chunk_size_in_tokens=512,
overlap_size_in_tokens=64,
),
provider_id="faiss",
)
# Insert documents into memory
await memory_impl.insert_documents(
bank_id="test_bank",
documents=sample_documents,
)
# Execute the memory tool
response = await tools_impl.invoke_tool(
tool_name="memory",
args={
"messages": [
UserMessage(
content="What are the main topics covered in the documentation?",
)
],
"memory_bank_ids": ["test_bank"],
},
)
# Verify the response
assert isinstance(response, ToolInvocationResult)
assert response.content is not None
assert len(response.content) > 0


@ -14,7 +14,6 @@ from typing import List, Optional, Tuple, Union
import httpx
from llama_models.datatypes import is_multimodal, ModelFamily
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import (
RawContent,
@ -41,7 +40,6 @@ from llama_stack.apis.common.content_types import (
InterleavedContentItem,
TextContentItem,
)
from llama_stack.apis.inference import (
ChatCompletionRequest,
CompletionRequest,
@ -52,7 +50,6 @@ from llama_stack.apis.inference import (
ToolChoice,
UserMessage,
)
from llama_stack.providers.utils.inference import supported_inference_models
log = logging.getLogger(__name__)


@ -9,8 +9,7 @@ from pathlib import Path
from llama_models.sku_list import all_registered_models
from llama_stack.apis.models import ModelInput
from llama_stack.distribution.datatypes import Provider
from llama_stack.distribution.datatypes import Provider, ToolGroupInput
from llama_stack.providers.inline.memory.faiss.config import FaissImplConfig
from llama_stack.providers.remote.inference.bedrock.bedrock import MODEL_ALIASES
from llama_stack.templates.template import DistributionTemplate, RunConfigSettings
@ -26,6 +25,12 @@ def get_distribution_template() -> DistributionTemplate:
"eval": ["inline::meta-reference"],
"datasetio": ["remote::huggingface", "inline::localfs"],
"scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"],
"tool_runtime": [
"remote::brave-search",
"remote::tavily-search",
"inline::code-interpreter",
"inline::memory-runtime",
],
}
name = "bedrock"
memory_provider = Provider(
@ -46,6 +51,20 @@ def get_distribution_template() -> DistributionTemplate:
)
for m in MODEL_ALIASES
]
default_tool_groups = [
ToolGroupInput(
toolgroup_id="builtin::websearch",
provider_id="tavily-search",
),
ToolGroupInput(
toolgroup_id="builtin::memory",
provider_id="memory-runtime",
),
ToolGroupInput(
toolgroup_id="builtin::code_interpreter",
provider_id="code-interpreter",
),
]
return DistributionTemplate(
name=name,
@ -61,6 +80,7 @@ def get_distribution_template() -> DistributionTemplate:
"memory": [memory_provider],
},
default_models=default_models,
default_tool_groups=default_tool_groups,
),
},
run_config_env_vars={


@ -2,7 +2,6 @@ version: '2'
name: bedrock
distribution_spec:
description: Use AWS Bedrock for running LLM inference and safety
docker_image: null
providers:
inference:
- remote::bedrock
@ -25,4 +24,9 @@ distribution_spec:
- inline::basic
- inline::llm-as-judge
- inline::braintrust
tool_runtime:
- remote::brave-search
- remote::tavily-search
- inline::code-interpreter
- inline::memory-runtime
image_type: conda


@ -1,6 +1,5 @@
version: '2'
image_name: bedrock
docker_image: null
conda_env: bedrock
apis:
- agents
@ -11,6 +10,7 @@ apis:
- safety
- scoring
- telemetry
- tool_runtime
providers:
inference:
- provider_id: bedrock
@ -65,8 +65,24 @@ providers:
provider_type: inline::braintrust
config:
openai_api_key: ${env.OPENAI_API_KEY:}
tool_runtime:
- provider_id: brave-search
provider_type: remote::brave-search
config:
api_key: ${env.BRAVE_SEARCH_API_KEY:}
max_results: 3
- provider_id: tavily-search
provider_type: remote::tavily-search
config:
api_key: ${env.TAVILY_SEARCH_API_KEY:}
max_results: 3
- provider_id: code-interpreter
provider_type: inline::code-interpreter
config: {}
- provider_id: memory-runtime
provider_type: inline::memory-runtime
config: {}
metadata_store:
namespace: null
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/bedrock}/registry.db
models:
@ -90,3 +106,10 @@ memory_banks: []
datasets: []
scoring_fns: []
eval_tasks: []
tool_groups:
- toolgroup_id: builtin::websearch
provider_id: tavily-search
- toolgroup_id: builtin::memory
provider_id: memory-runtime
- toolgroup_id: builtin::code_interpreter
provider_id: code-interpreter


@ -2,7 +2,6 @@ version: '2'
name: cerebras
distribution_spec:
description: Use Cerebras for running LLM inference
docker_image: null
providers:
inference:
- remote::cerebras
@ -14,4 +13,9 @@ distribution_spec:
- inline::meta-reference
telemetry:
- inline::meta-reference
tool_runtime:
- remote::brave-search
- remote::tavily-search
- inline::code-interpreter
- inline::memory-runtime
image_type: conda


@ -9,8 +9,12 @@ from pathlib import Path
from llama_models.sku_list import all_registered_models
from llama_stack.apis.models.models import ModelType
from llama_stack.distribution.datatypes import ModelInput, Provider, ShieldInput
from llama_stack.distribution.datatypes import (
ModelInput,
Provider,
ShieldInput,
ToolGroupInput,
)
from llama_stack.providers.inline.inference.sentence_transformers import (
SentenceTransformersInferenceConfig,
)
@ -26,6 +30,12 @@ def get_distribution_template() -> DistributionTemplate:
"memory": ["inline::meta-reference"],
"agents": ["inline::meta-reference"],
"telemetry": ["inline::meta-reference"],
"tool_runtime": [
"remote::brave-search",
"remote::tavily-search",
"inline::code-interpreter",
"inline::memory-runtime",
],
}
inference_provider = Provider(
@ -58,6 +68,20 @@ def get_distribution_template() -> DistributionTemplate:
"embedding_dimension": 384,
},
)
default_tool_groups = [
ToolGroupInput(
toolgroup_id="builtin::websearch",
provider_id="tavily-search",
),
ToolGroupInput(
toolgroup_id="builtin::memory",
provider_id="memory-runtime",
),
ToolGroupInput(
toolgroup_id="builtin::code_interpreter",
provider_id="code-interpreter",
),
]
return DistributionTemplate(
name="cerebras",
@ -74,6 +98,7 @@ def get_distribution_template() -> DistributionTemplate:
},
default_models=default_models + [embedding_model],
default_shields=[ShieldInput(shield_id="meta-llama/Llama-Guard-3-8B")],
default_tool_groups=default_tool_groups,
),
},
run_config_env_vars={


@ -1,6 +1,5 @@
version: '2'
image_name: cerebras
docker_image: null
conda_env: cerebras
apis:
- agents
@ -8,6 +7,7 @@ apis:
- memory
- safety
- telemetry
- tool_runtime
providers:
inference:
- provider_id: cerebras
@ -45,8 +45,24 @@ providers:
service_name: ${env.OTEL_SERVICE_NAME:llama-stack}
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/cerebras/trace_store.db}
tool_runtime:
- provider_id: brave-search
provider_type: remote::brave-search
config:
api_key: ${env.BRAVE_SEARCH_API_KEY:}
max_results: 3
- provider_id: tavily-search
provider_type: remote::tavily-search
config:
api_key: ${env.TAVILY_SEARCH_API_KEY:}
max_results: 3
- provider_id: code-interpreter
provider_type: inline::code-interpreter
config: {}
- provider_id: memory-runtime
provider_type: inline::memory-runtime
config: {}
metadata_store:
namespace: null
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/cerebras}/registry.db
models:
@ -64,14 +80,17 @@ models:
embedding_dimension: 384
model_id: all-MiniLM-L6-v2
provider_id: sentence-transformers
provider_model_id: null
model_type: embedding
shields:
- params: null
shield_id: meta-llama/Llama-Guard-3-8B
provider_id: null
provider_shield_id: null
- shield_id: meta-llama/Llama-Guard-3-8B
memory_banks: []
datasets: []
scoring_fns: []
eval_tasks: []
tool_groups:
- toolgroup_id: builtin::websearch
provider_id: tavily-search
- toolgroup_id: builtin::memory
provider_id: memory-runtime
- toolgroup_id: builtin::code_interpreter
provider_id: code-interpreter


@ -2,7 +2,6 @@ version: '2'
name: fireworks
distribution_spec:
description: Use Fireworks.AI for running LLM inference
docker_image: null
providers:
inference:
- remote::fireworks
@ -25,4 +24,9 @@ distribution_spec:
- inline::basic
- inline::llm-as-judge
- inline::braintrust
tool_runtime:
- remote::brave-search
- remote::tavily-search
- inline::code-interpreter
- inline::memory-runtime
image_type: conda


@ -9,8 +9,12 @@ from pathlib import Path
from llama_models.sku_list import all_registered_models
from llama_stack.apis.models.models import ModelType
from llama_stack.distribution.datatypes import ModelInput, Provider, ShieldInput
from llama_stack.distribution.datatypes import (
ModelInput,
Provider,
ShieldInput,
ToolGroupInput,
)
from llama_stack.providers.inline.inference.sentence_transformers import (
SentenceTransformersInferenceConfig,
)
@ -30,6 +34,12 @@ def get_distribution_template() -> DistributionTemplate:
"eval": ["inline::meta-reference"],
"datasetio": ["remote::huggingface", "inline::localfs"],
"scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"],
"tool_runtime": [
"remote::brave-search",
"remote::tavily-search",
"inline::code-interpreter",
"inline::memory-runtime",
],
}
name = "fireworks"
@ -69,6 +79,20 @@ def get_distribution_template() -> DistributionTemplate:
"embedding_dimension": 384,
},
)
default_tool_groups = [
ToolGroupInput(
toolgroup_id="builtin::websearch",
provider_id="tavily-search",
),
ToolGroupInput(
toolgroup_id="builtin::memory",
provider_id="memory-runtime",
),
ToolGroupInput(
toolgroup_id="builtin::code_interpreter",
provider_id="code-interpreter",
),
]
return DistributionTemplate(
name=name,
@ -86,6 +110,7 @@ def get_distribution_template() -> DistributionTemplate:
},
default_models=default_models + [embedding_model],
default_shields=[ShieldInput(shield_id="meta-llama/Llama-Guard-3-8B")],
default_tool_groups=default_tool_groups,
),
},
run_config_env_vars={


@ -1,6 +1,5 @@
version: '2'
image_name: fireworks
docker_image: null
conda_env: fireworks
apis:
- agents
@ -11,6 +10,7 @@ apis:
- safety
- scoring
- telemetry
- tool_runtime
providers:
inference:
- provider_id: fireworks
@ -70,8 +70,24 @@ providers:
provider_type: inline::braintrust
config:
openai_api_key: ${env.OPENAI_API_KEY:}
tool_runtime:
- provider_id: brave-search
provider_type: remote::brave-search
config:
api_key: ${env.BRAVE_SEARCH_API_KEY:}
max_results: 3
- provider_id: tavily-search
provider_type: remote::tavily-search
config:
api_key: ${env.TAVILY_SEARCH_API_KEY:}
max_results: 3
- provider_id: code-interpreter
provider_type: inline::code-interpreter
config: {}
- provider_id: memory-runtime
provider_type: inline::memory-runtime
config: {}
metadata_store:
namespace: null
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/fireworks}/registry.db
models:
@ -129,14 +145,17 @@ models:
embedding_dimension: 384
model_id: all-MiniLM-L6-v2
provider_id: sentence-transformers
provider_model_id: null
model_type: embedding
shields:
- params: null
shield_id: meta-llama/Llama-Guard-3-8B
provider_id: null
provider_shield_id: null
- shield_id: meta-llama/Llama-Guard-3-8B
memory_banks: []
datasets: []
scoring_fns: []
eval_tasks: []
tool_groups:
- toolgroup_id: builtin::websearch
provider_id: tavily-search
- toolgroup_id: builtin::memory
provider_id: memory-runtime
- toolgroup_id: builtin::code_interpreter
provider_id: code-interpreter


@ -2,7 +2,6 @@ version: '2'
name: hf-endpoint
distribution_spec:
description: Use (an external) Hugging Face Inference Endpoint for running LLM inference
docker_image: null
providers:
inference:
- remote::hf::endpoint
@ -25,4 +24,9 @@ distribution_spec:
- inline::basic
- inline::llm-as-judge
- inline::braintrust
tool_runtime:
- remote::brave-search
- remote::tavily-search
- inline::code-interpreter
- inline::memory-runtime
image_type: conda


@ -5,7 +5,12 @@
# the root directory of this source tree.
from llama_stack.apis.models.models import ModelType
from llama_stack.distribution.datatypes import ModelInput, Provider, ShieldInput
from llama_stack.distribution.datatypes import (
ModelInput,
Provider,
ShieldInput,
ToolGroupInput,
)
from llama_stack.providers.inline.inference.sentence_transformers import (
SentenceTransformersInferenceConfig,
)
@ -24,6 +29,12 @@ def get_distribution_template() -> DistributionTemplate:
"eval": ["inline::meta-reference"],
"datasetio": ["remote::huggingface", "inline::localfs"],
"scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"],
"tool_runtime": [
"remote::brave-search",
"remote::tavily-search",
"inline::code-interpreter",
"inline::memory-runtime",
],
}
name = "hf-endpoint"
inference_provider = Provider(
@ -58,6 +69,20 @@ def get_distribution_template() -> DistributionTemplate:
"embedding_dimension": 384,
},
)
default_tool_groups = [
ToolGroupInput(
toolgroup_id="builtin::websearch",
provider_id="tavily-search",
),
ToolGroupInput(
toolgroup_id="builtin::memory",
provider_id="memory-runtime",
),
ToolGroupInput(
toolgroup_id="builtin::code_interpreter",
provider_id="code-interpreter",
),
]
return DistributionTemplate(
name=name,
@ -74,6 +99,7 @@ def get_distribution_template() -> DistributionTemplate:
"memory": [memory_provider],
},
default_models=[inference_model, embedding_model],
default_tool_groups=default_tool_groups,
),
"run-with-safety.yaml": RunConfigSettings(
provider_overrides={
@ -96,6 +122,7 @@ def get_distribution_template() -> DistributionTemplate:
embedding_model,
],
default_shields=[ShieldInput(shield_id="${env.SAFETY_MODEL}")],
default_tool_groups=default_tool_groups,
),
},
run_config_env_vars={


@ -1,6 +1,5 @@
version: '2'
image_name: hf-endpoint
docker_image: null
conda_env: hf-endpoint
apis:
- agents
@ -11,6 +10,7 @@ apis:
- safety
- scoring
- telemetry
- tool_runtime
providers:
inference:
- provider_id: hf-endpoint
@ -75,33 +75,50 @@ providers:
provider_type: inline::braintrust
config:
openai_api_key: ${env.OPENAI_API_KEY:}
tool_runtime:
- provider_id: brave-search
provider_type: remote::brave-search
config:
api_key: ${env.BRAVE_SEARCH_API_KEY:}
max_results: 3
- provider_id: tavily-search
provider_type: remote::tavily-search
config:
api_key: ${env.TAVILY_SEARCH_API_KEY:}
max_results: 3
- provider_id: code-interpreter
provider_type: inline::code-interpreter
config: {}
- provider_id: memory-runtime
provider_type: inline::memory-runtime
config: {}
metadata_store:
namespace: null
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-endpoint}/registry.db
models:
- metadata: {}
model_id: ${env.INFERENCE_MODEL}
provider_id: hf-endpoint
provider_model_id: null
model_type: llm
- metadata: {}
model_id: ${env.SAFETY_MODEL}
provider_id: hf-endpoint-safety
provider_model_id: null
model_type: llm
- metadata:
embedding_dimension: 384
model_id: all-MiniLM-L6-v2
provider_id: sentence-transformers
provider_model_id: null
model_type: embedding
shields:
- params: null
shield_id: ${env.SAFETY_MODEL}
provider_id: null
provider_shield_id: null
- shield_id: ${env.SAFETY_MODEL}
memory_banks: []
datasets: []
scoring_fns: []
eval_tasks: []
tool_groups:
- toolgroup_id: builtin::websearch
provider_id: tavily-search
- toolgroup_id: builtin::memory
provider_id: memory-runtime
- toolgroup_id: builtin::code_interpreter
provider_id: code-interpreter


@ -1,6 +1,5 @@
version: '2'
image_name: hf-endpoint
docker_image: null
conda_env: hf-endpoint
apis:
- agents
@ -11,6 +10,7 @@ apis:
- safety
- scoring
- telemetry
- tool_runtime
providers:
inference:
- provider_id: hf-endpoint
@ -70,24 +70,45 @@ providers:
provider_type: inline::braintrust
config:
openai_api_key: ${env.OPENAI_API_KEY:}
tool_runtime:
- provider_id: brave-search
provider_type: remote::brave-search
config:
api_key: ${env.BRAVE_SEARCH_API_KEY:}
max_results: 3
- provider_id: tavily-search
provider_type: remote::tavily-search
config:
api_key: ${env.TAVILY_SEARCH_API_KEY:}
max_results: 3
- provider_id: code-interpreter
provider_type: inline::code-interpreter
config: {}
- provider_id: memory-runtime
provider_type: inline::memory-runtime
config: {}
metadata_store:
namespace: null
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-endpoint}/registry.db
models:
- metadata: {}
model_id: ${env.INFERENCE_MODEL}
provider_id: hf-endpoint
provider_model_id: null
model_type: llm
- metadata:
embedding_dimension: 384
model_id: all-MiniLM-L6-v2
provider_id: sentence-transformers
provider_model_id: null
model_type: embedding
shields: []
memory_banks: []
datasets: []
scoring_fns: []
eval_tasks: []
tool_groups:
- toolgroup_id: builtin::websearch
provider_id: tavily-search
- toolgroup_id: builtin::memory
provider_id: memory-runtime
- toolgroup_id: builtin::code_interpreter
provider_id: code-interpreter


@ -2,7 +2,6 @@ version: '2'
name: hf-serverless
distribution_spec:
description: Use (an external) Hugging Face Inference Endpoint for running LLM inference
docker_image: null
providers:
inference:
- remote::hf::serverless
@ -25,4 +24,9 @@ distribution_spec:
- inline::basic
- inline::llm-as-judge
- inline::braintrust
tool_runtime:
- remote::brave-search
- remote::tavily-search
- inline::code-interpreter
- inline::memory-runtime
image_type: conda


@ -5,7 +5,12 @@
# the root directory of this source tree.
from llama_stack.apis.models.models import ModelType
from llama_stack.distribution.datatypes import ModelInput, Provider, ShieldInput
from llama_stack.distribution.datatypes import (
ModelInput,
Provider,
ShieldInput,
ToolGroupInput,
)
from llama_stack.providers.inline.inference.sentence_transformers import (
SentenceTransformersInferenceConfig,
)
@ -24,6 +29,12 @@ def get_distribution_template() -> DistributionTemplate:
"eval": ["inline::meta-reference"],
"datasetio": ["remote::huggingface", "inline::localfs"],
"scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"],
"tool_runtime": [
"remote::brave-search",
"remote::tavily-search",
"inline::code-interpreter",
"inline::memory-runtime",
],
}
name = "hf-serverless"
@ -59,6 +70,20 @@ def get_distribution_template() -> DistributionTemplate:
"embedding_dimension": 384,
},
)
default_tool_groups = [
ToolGroupInput(
toolgroup_id="builtin::websearch",
provider_id="tavily-search",
),
ToolGroupInput(
toolgroup_id="builtin::memory",
provider_id="memory-runtime",
),
ToolGroupInput(
toolgroup_id="builtin::code_interpreter",
provider_id="code-interpreter",
),
]
return DistributionTemplate(
name=name,
@ -97,6 +122,7 @@ def get_distribution_template() -> DistributionTemplate:
embedding_model,
],
default_shields=[ShieldInput(shield_id="${env.SAFETY_MODEL}")],
default_tool_groups=default_tool_groups,
),
},
run_config_env_vars={


@ -1,6 +1,5 @@
version: '2'
image_name: hf-serverless
docker_image: null
conda_env: hf-serverless
apis:
- agents
@ -11,6 +10,7 @@ apis:
- safety
- scoring
- telemetry
- tool_runtime
providers:
inference:
- provider_id: hf-serverless
@ -75,33 +75,50 @@ providers:
provider_type: inline::braintrust
config:
openai_api_key: ${env.OPENAI_API_KEY:}
tool_runtime:
- provider_id: brave-search
provider_type: remote::brave-search
config:
api_key: ${env.BRAVE_SEARCH_API_KEY:}
max_results: 3
- provider_id: tavily-search
provider_type: remote::tavily-search
config:
api_key: ${env.TAVILY_SEARCH_API_KEY:}
max_results: 3
- provider_id: code-interpreter
provider_type: inline::code-interpreter
config: {}
- provider_id: memory-runtime
provider_type: inline::memory-runtime
config: {}
metadata_store:
namespace: null
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-serverless}/registry.db
models:
- metadata: {}
model_id: ${env.INFERENCE_MODEL}
provider_id: hf-serverless
provider_model_id: null
model_type: llm
- metadata: {}
model_id: ${env.SAFETY_MODEL}
provider_id: hf-serverless-safety
provider_model_id: null
model_type: llm
- metadata:
embedding_dimension: 384
model_id: all-MiniLM-L6-v2
provider_id: sentence-transformers
provider_model_id: null
model_type: embedding
shields:
- params: null
shield_id: ${env.SAFETY_MODEL}
provider_id: null
provider_shield_id: null
- shield_id: ${env.SAFETY_MODEL}
memory_banks: []
datasets: []
scoring_fns: []
eval_tasks: []
tool_groups:
- toolgroup_id: builtin::websearch
provider_id: tavily-search
- toolgroup_id: builtin::memory
provider_id: memory-runtime
- toolgroup_id: builtin::code_interpreter
provider_id: code-interpreter


@ -1,6 +1,5 @@
version: '2'
image_name: hf-serverless
docker_image: null
conda_env: hf-serverless
apis:
- agents
@ -11,6 +10,7 @@ apis:
- safety
- scoring
- telemetry
- tool_runtime
providers:
inference:
- provider_id: hf-serverless
@ -70,24 +70,39 @@ providers:
provider_type: inline::braintrust
config:
openai_api_key: ${env.OPENAI_API_KEY:}
tool_runtime:
- provider_id: brave-search
provider_type: remote::brave-search
config:
api_key: ${env.BRAVE_SEARCH_API_KEY:}
max_results: 3
- provider_id: tavily-search
provider_type: remote::tavily-search
config:
api_key: ${env.TAVILY_SEARCH_API_KEY:}
max_results: 3
- provider_id: code-interpreter
provider_type: inline::code-interpreter
config: {}
- provider_id: memory-runtime
provider_type: inline::memory-runtime
config: {}
metadata_store:
namespace: null
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-serverless}/registry.db
models:
- metadata: {}
model_id: ${env.INFERENCE_MODEL}
provider_id: hf-serverless
provider_model_id: null
model_type: llm
- metadata:
embedding_dimension: 384
model_id: all-MiniLM-L6-v2
provider_id: sentence-transformers
provider_model_id: null
model_type: embedding
shields: []
memory_banks: []
datasets: []
scoring_fns: []
eval_tasks: []
tool_groups: []


@ -2,7 +2,6 @@ version: '2'
name: meta-reference-gpu
distribution_spec:
description: Use Meta Reference for running LLM inference
docker_image: null
providers:
inference:
- inline::meta-reference
@ -25,4 +24,9 @@ distribution_spec:
- inline::basic
- inline::llm-as-judge
- inline::braintrust
tool_runtime:
- remote::brave-search
- remote::tavily-search
- inline::code-interpreter
- inline::memory-runtime
image_type: conda


@ -7,8 +7,12 @@
from pathlib import Path
from llama_stack.apis.models.models import ModelType
from llama_stack.distribution.datatypes import ModelInput, Provider, ShieldInput
from llama_stack.distribution.datatypes import (
ModelInput,
Provider,
ShieldInput,
ToolGroupInput,
)
from llama_stack.providers.inline.inference.meta_reference import (
MetaReferenceInferenceConfig,
)
@ -29,6 +33,12 @@ def get_distribution_template() -> DistributionTemplate:
"eval": ["inline::meta-reference"],
"datasetio": ["remote::huggingface", "inline::localfs"],
"scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"],
"tool_runtime": [
"remote::brave-search",
"remote::tavily-search",
"inline::code-interpreter",
"inline::memory-runtime",
],
}
name = "meta-reference-gpu"
inference_provider = Provider(
@ -66,6 +76,20 @@ def get_distribution_template() -> DistributionTemplate:
model_id="${env.SAFETY_MODEL}",
provider_id="meta-reference-safety",
)
default_tool_groups = [
ToolGroupInput(
toolgroup_id="builtin::websearch",
provider_id="tavily-search",
),
ToolGroupInput(
toolgroup_id="builtin::memory",
provider_id="memory-runtime",
),
ToolGroupInput(
toolgroup_id="builtin::code_interpreter",
provider_id="code-interpreter",
),
]
return DistributionTemplate(
name=name,
@ -104,6 +128,7 @@ def get_distribution_template() -> DistributionTemplate:
embedding_model,
],
default_shields=[ShieldInput(shield_id="${env.SAFETY_MODEL}")],
default_tool_groups=default_tool_groups,
),
},
run_config_env_vars={


@ -1,6 +1,5 @@
version: '2'
image_name: meta-reference-gpu
docker_image: null
conda_env: meta-reference-gpu
apis:
- agents
@ -11,6 +10,7 @@ apis:
- safety
- scoring
- telemetry
- tool_runtime
providers:
inference:
- provider_id: meta-reference-inference
@ -77,33 +77,50 @@ providers:
provider_type: inline::braintrust
config:
openai_api_key: ${env.OPENAI_API_KEY:}
tool_runtime:
- provider_id: brave-search
provider_type: remote::brave-search
config:
api_key: ${env.BRAVE_SEARCH_API_KEY:}
max_results: 3
- provider_id: tavily-search
provider_type: remote::tavily-search
config:
api_key: ${env.TAVILY_SEARCH_API_KEY:}
max_results: 3
- provider_id: code-interpreter
provider_type: inline::code-interpreter
config: {}
- provider_id: memory-runtime
provider_type: inline::memory-runtime
config: {}
metadata_store:
namespace: null
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/registry.db
models:
- metadata: {}
model_id: ${env.INFERENCE_MODEL}
provider_id: meta-reference-inference
provider_model_id: null
model_type: llm
- metadata: {}
model_id: ${env.SAFETY_MODEL}
provider_id: meta-reference-safety
provider_model_id: null
model_type: llm
- metadata:
embedding_dimension: 384
model_id: all-MiniLM-L6-v2
provider_id: sentence-transformers
provider_model_id: null
model_type: embedding
shields:
- params: null
shield_id: ${env.SAFETY_MODEL}
provider_id: null
provider_shield_id: null
- shield_id: ${env.SAFETY_MODEL}
memory_banks: []
datasets: []
scoring_fns: []
eval_tasks: []
tool_groups:
- toolgroup_id: builtin::websearch
provider_id: tavily-search
- toolgroup_id: builtin::memory
provider_id: memory-runtime
- toolgroup_id: builtin::code_interpreter
provider_id: code-interpreter


@ -1,6 +1,5 @@
version: '2'
image_name: meta-reference-gpu
docker_image: null
conda_env: meta-reference-gpu
apis:
- agents
@ -11,6 +10,7 @@ apis:
- safety
- scoring
- telemetry
- tool_runtime
providers:
inference:
- provider_id: meta-reference-inference
@ -71,24 +71,39 @@ providers:
provider_type: inline::braintrust
config:
openai_api_key: ${env.OPENAI_API_KEY:}
tool_runtime:
- provider_id: brave-search
provider_type: remote::brave-search
config:
api_key: ${env.BRAVE_SEARCH_API_KEY:}
max_results: 3
- provider_id: tavily-search
provider_type: remote::tavily-search
config:
api_key: ${env.TAVILY_SEARCH_API_KEY:}
max_results: 3
- provider_id: code-interpreter
provider_type: inline::code-interpreter
config: {}
- provider_id: memory-runtime
provider_type: inline::memory-runtime
config: {}
metadata_store:
namespace: null
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/registry.db
models:
- metadata: {}
model_id: ${env.INFERENCE_MODEL}
provider_id: meta-reference-inference
provider_model_id: null
model_type: llm
- metadata:
embedding_dimension: 384
model_id: all-MiniLM-L6-v2
provider_id: sentence-transformers
provider_model_id: null
model_type: embedding
shields: []
memory_banks: []
datasets: []
scoring_fns: []
eval_tasks: []
tool_groups: []


@ -2,7 +2,6 @@ version: '2'
name: meta-reference-quantized-gpu
distribution_spec:
description: Use Meta Reference with fp8, int4 quantization for running LLM inference
docker_image: null
providers:
inference:
- inline::meta-reference-quantized
@ -25,4 +24,9 @@ distribution_spec:
- inline::basic
- inline::llm-as-judge
- inline::braintrust
tool_runtime:
- remote::brave-search
- remote::tavily-search
- inline::code-interpreter
- inline::memory-runtime
image_type: conda


@ -7,8 +7,7 @@
from pathlib import Path
from llama_stack.apis.models.models import ModelType
from llama_stack.distribution.datatypes import ModelInput, Provider
from llama_stack.distribution.datatypes import ModelInput, Provider, ToolGroupInput
from llama_stack.providers.inline.inference.meta_reference import (
MetaReferenceQuantizedInferenceConfig,
)
@ -29,7 +28,27 @@ def get_distribution_template() -> DistributionTemplate:
"eval": ["inline::meta-reference"],
"datasetio": ["remote::huggingface", "inline::localfs"],
"scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"],
"tool_runtime": [
"remote::brave-search",
"remote::tavily-search",
"inline::code-interpreter",
"inline::memory-runtime",
],
}
default_tool_groups = [
ToolGroupInput(
toolgroup_id="builtin::websearch",
provider_id="tavily-search",
),
ToolGroupInput(
toolgroup_id="builtin::memory",
provider_id="memory-runtime",
),
ToolGroupInput(
toolgroup_id="builtin::code_interpreter",
provider_id="code-interpreter",
),
]
name = "meta-reference-quantized-gpu"
inference_provider = Provider(
provider_id="meta-reference-inference",
@ -76,6 +95,7 @@ def get_distribution_template() -> DistributionTemplate:
"memory": [memory_provider],
},
default_models=[inference_model, embedding_model],
default_tool_groups=default_tool_groups,
),
},
run_config_env_vars={


@ -1,6 +1,5 @@
version: '2'
image_name: meta-reference-quantized-gpu
docker_image: null
conda_env: meta-reference-quantized-gpu
apis:
- agents
@ -11,6 +10,7 @@ apis:
- safety
- scoring
- telemetry
- tool_runtime
providers:
inference:
- provider_id: meta-reference-inference
@ -73,24 +73,45 @@ providers:
provider_type: inline::braintrust
config:
openai_api_key: ${env.OPENAI_API_KEY:}
tool_runtime:
- provider_id: brave-search
provider_type: remote::brave-search
config:
api_key: ${env.BRAVE_SEARCH_API_KEY:}
max_results: 3
- provider_id: tavily-search
provider_type: remote::tavily-search
config:
api_key: ${env.TAVILY_SEARCH_API_KEY:}
max_results: 3
- provider_id: code-interpreter
provider_type: inline::code-interpreter
config: {}
- provider_id: memory-runtime
provider_type: inline::memory-runtime
config: {}
metadata_store:
namespace: null
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-quantized-gpu}/registry.db
models:
- metadata: {}
model_id: ${env.INFERENCE_MODEL}
provider_id: meta-reference-inference
provider_model_id: null
model_type: llm
- metadata:
embedding_dimension: 384
model_id: all-MiniLM-L6-v2
provider_id: sentence-transformers
provider_model_id: null
model_type: embedding
shields: []
memory_banks: []
datasets: []
scoring_fns: []
eval_tasks: []
tool_groups:
- toolgroup_id: builtin::websearch
provider_id: tavily-search
- toolgroup_id: builtin::memory
provider_id: memory-runtime
- toolgroup_id: builtin::code_interpreter
provider_id: code-interpreter


@ -2,7 +2,6 @@ version: '2'
name: ollama
distribution_spec:
description: Use (an external) Ollama server for running LLM inference
docker_image: null
providers:
inference:
- remote::ollama
@ -25,4 +24,9 @@ distribution_spec:
- inline::basic
- inline::llm-as-judge
- inline::braintrust
tool_runtime:
- remote::brave-search
- remote::tavily-search
- inline::code-interpreter
- inline::memory-runtime
image_type: conda


@ -7,8 +7,12 @@
from pathlib import Path
from llama_stack.apis.models.models import ModelType
from llama_stack.distribution.datatypes import ModelInput, Provider, ShieldInput
from llama_stack.distribution.datatypes import (
ModelInput,
Provider,
ShieldInput,
ToolGroupInput,
)
from llama_stack.providers.inline.inference.sentence_transformers import (
SentenceTransformersInferenceConfig,
)
@ -27,6 +31,12 @@ def get_distribution_template() -> DistributionTemplate:
"eval": ["inline::meta-reference"],
"datasetio": ["remote::huggingface", "inline::localfs"],
"scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"],
"tool_runtime": [
"remote::brave-search",
"remote::tavily-search",
"inline::code-interpreter",
"inline::memory-runtime",
],
}
name = "ollama"
inference_provider = Provider(
@ -61,6 +71,20 @@ def get_distribution_template() -> DistributionTemplate:
"embedding_dimension": 384,
},
)
default_tool_groups = [
ToolGroupInput(
toolgroup_id="builtin::websearch",
provider_id="tavily-search",
),
ToolGroupInput(
toolgroup_id="builtin::memory",
provider_id="memory-runtime",
),
ToolGroupInput(
toolgroup_id="builtin::code_interpreter",
provider_id="code-interpreter",
),
]
return DistributionTemplate(
name=name,
@ -92,6 +116,7 @@ def get_distribution_template() -> DistributionTemplate:
embedding_model,
],
default_shields=[ShieldInput(shield_id="${env.SAFETY_MODEL}")],
default_tool_groups=default_tool_groups,
),
},
run_config_env_vars={


@ -1,6 +1,5 @@
version: '2'
image_name: ollama
docker_image: null
conda_env: ollama
apis:
- agents
@ -11,6 +10,7 @@ apis:
- safety
- scoring
- telemetry
- tool_runtime
providers:
inference:
- provider_id: ollama
@ -69,33 +69,50 @@ providers:
provider_type: inline::braintrust
config:
openai_api_key: ${env.OPENAI_API_KEY:}
tool_runtime:
- provider_id: brave-search
provider_type: remote::brave-search
config:
api_key: ${env.BRAVE_SEARCH_API_KEY:}
max_results: 3
- provider_id: tavily-search
provider_type: remote::tavily-search
config:
api_key: ${env.TAVILY_SEARCH_API_KEY:}
max_results: 3
- provider_id: code-interpreter
provider_type: inline::code-interpreter
config: {}
- provider_id: memory-runtime
provider_type: inline::memory-runtime
config: {}
metadata_store:
namespace: null
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/registry.db
models:
- metadata: {}
model_id: ${env.INFERENCE_MODEL}
provider_id: ollama
provider_model_id: null
model_type: llm
- metadata: {}
model_id: ${env.SAFETY_MODEL}
provider_id: ollama
provider_model_id: null
model_type: llm
- metadata:
embedding_dimension: 384
model_id: all-MiniLM-L6-v2
provider_id: sentence-transformers
provider_model_id: null
model_type: embedding
shields:
- params: null
shield_id: ${env.SAFETY_MODEL}
provider_id: null
provider_shield_id: null
- shield_id: ${env.SAFETY_MODEL}
memory_banks: []
datasets: []
scoring_fns: []
eval_tasks: []
tool_groups:
- toolgroup_id: builtin::websearch
provider_id: tavily-search
- toolgroup_id: builtin::memory
provider_id: memory-runtime
- toolgroup_id: builtin::code_interpreter
provider_id: code-interpreter


@ -1,6 +1,5 @@
version: '2'
image_name: ollama
docker_image: null
conda_env: ollama
apis:
- agents
@ -11,6 +10,7 @@ apis:
- safety
- scoring
- telemetry
- tool_runtime
providers:
inference:
- provider_id: ollama
@ -69,24 +69,39 @@ providers:
provider_type: inline::braintrust
config:
openai_api_key: ${env.OPENAI_API_KEY:}
tool_runtime:
- provider_id: brave-search
provider_type: remote::brave-search
config:
api_key: ${env.BRAVE_SEARCH_API_KEY:}
max_results: 3
- provider_id: tavily-search
provider_type: remote::tavily-search
config:
api_key: ${env.TAVILY_SEARCH_API_KEY:}
max_results: 3
- provider_id: code-interpreter
provider_type: inline::code-interpreter
config: {}
- provider_id: memory-runtime
provider_type: inline::memory-runtime
config: {}
metadata_store:
namespace: null
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/registry.db
models:
- metadata: {}
model_id: ${env.INFERENCE_MODEL}
provider_id: ollama
provider_model_id: null
model_type: llm
- metadata:
embedding_dimension: 384
model_id: all-MiniLM-L6-v2
provider_id: sentence-transformers
provider_model_id: null
model_type: embedding
shields: []
memory_banks: []
datasets: []
scoring_fns: []
eval_tasks: []
tool_groups: []


@ -2,7 +2,6 @@ version: '2'
name: remote-vllm
distribution_spec:
description: Use (an external) vLLM server for running LLM inference
docker_image: null
providers:
inference:
- remote::vllm
@ -16,4 +15,9 @@ distribution_spec:
- inline::meta-reference
telemetry:
- inline::meta-reference
tool_runtime:
- remote::brave-search
- remote::tavily-search
- inline::code-interpreter
- inline::memory-runtime
image_type: conda
