Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-06-27 18:50:41 +00:00)
agents to use tools api (#673)
# What does this PR do?

PR #639 introduced the notion of a Tools API and the ability to invoke tools through the API just as any other resource. This PR changes the Agents implementation to use the Tools API to invoke tools. Major changes include:

1) Tool groups can be specified in `AgentConfig`.
2) The agent fetches the tool definitions for the specified tools and passes them along to the model.
3) Attachments are now named Documents, and their behavior is mostly unchanged from the user's perspective.
4) You can specify args in the agent config that are injected into a tool call. This is especially useful for the memory tool, where you want the tool to operate on a specific memory bank.
5) You can also register tool groups with args, which the agent likewise injects into the tool call.
6) All tests have been migrated to the new Tools API and fixtures, including the client SDK tests.
7) Telemetry just works with the Tools API because of our trace protocol decorator.

## Test Plan

```
pytest -s -v -k fireworks llama_stack/providers/tests/agents/test_agents.py \
  --safety-shield=meta-llama/Llama-Guard-3-8B \
  --inference-model=meta-llama/Llama-3.1-8B-Instruct

pytest -s -v -k together llama_stack/providers/tests/tools/test_tools.py \
  --safety-shield=meta-llama/Llama-Guard-3-8B \
  --inference-model=meta-llama/Llama-3.1-8B-Instruct

LLAMA_STACK_CONFIG="/Users/dineshyv/.llama/distributions/llamastack-together/together-run.yaml" pytest -v tests/client-sdk/agents/test_agents.py
```

run.yaml: https://gist.github.com/dineshyv/0365845ad325e1c2cab755788ccc5994
Notebook: https://colab.research.google.com/drive/1ck7hXQxRl6UvT-ijNRZ-gMZxH1G3cN2d?usp=sharing
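To make the new surface concrete, here is a minimal sketch of an agent that uses tool groups with per-agent injected args. This is a hedged illustration: the import path follows the client SDK of this era, and the toolgroup and bank names are assumptions, not taken from this PR.

```python
from llama_stack_client.types.agent_create_params import AgentConfig

# Hedged sketch: toolgroups may be plain names or {"name", "args"} entries;
# args (e.g., memory_bank_ids) are injected into that tool group's calls.
agent_config = AgentConfig(
    model="meta-llama/Llama-3.1-8B-Instruct",
    instructions="You are a helpful assistant.",
    toolgroups=[
        "builtin::code_interpreter",
        # bank id is an assumption for illustration
        {"name": "builtin::memory", "args": {"memory_bank_ids": ["my-bank"]}},
    ],
    enable_session_persistence=True,
)
```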
This commit is contained in:
parent 596afc6497 → commit a5c57cd381
116 changed files with 4959 additions and 2778 deletions
The change to the distribution build specs is the same one-line addition repeated across twelve hunks: `"pypdf"` is added to each template's dependency list (two of the hunks list `sentence-transformers` in place of `sentencepiece`; they are otherwise identical):

```diff
@@ -23,6 +23,7 @@
     "psycopg2-binary",
+    "pypdf",
    "redis",
    "requests",
    "scikit-learn",
    "scipy",
    "sentencepiece",
```
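The `pypdf` addition presumably supports parsing PDF content for the new Documents flow (renamed from Attachments); this is an inference from the change set, not stated in the PR description.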
|
File diff suppressed because one or more lines are too long
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
Each distribution's documentation gains a `tool_runtime` row in its provider table:

`distribution-bedrock`:

```diff
@@ -19,6 +19,7 @@
 | safety | `remote::bedrock` |
 | scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
 | telemetry | `inline::meta-reference` |
+| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::memory-runtime` |
```

`distribution-cerebras`:

```diff
@@ -9,6 +9,7 @@
 | memory | `inline::meta-reference` |
 | safety | `inline::llama-guard` |
 | telemetry | `inline::meta-reference` |
+| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::memory-runtime` |
 
 ### Environment Variables
```

`distribution-fireworks`:

```diff
@@ -22,6 +22,7 @@
 | safety | `inline::llama-guard` |
 | scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
 | telemetry | `inline::meta-reference` |
+| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::memory-runtime` |
 
 ### Environment Variables
```

`distribution-meta-reference-gpu`:

```diff
@@ -22,6 +22,7 @@
 | safety | `inline::llama-guard` |
 | scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
 | telemetry | `inline::meta-reference` |
+| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::memory-runtime` |
 
 Note that you need access to nvidia GPUs to run this distribution. This distribution is not compatible with CPU-only machines or machines with AMD GPUs.
```

`distribution-meta-reference-quantized-gpu`:

```diff
@@ -22,6 +22,7 @@
 | safety | `inline::llama-guard` |
 | scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
 | telemetry | `inline::meta-reference` |
+| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::memory-runtime` |
 
 The only difference vs. the `meta-reference-gpu` distribution is that it has support for more efficient inference -- with fp8, int4 quantization, etc.
```

`distribution-ollama`:

```diff
@@ -22,6 +22,7 @@
 | safety | `inline::llama-guard` |
 | scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
 | telemetry | `inline::meta-reference` |
+| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::memory-runtime` |
 
 You should use this distribution if you have a regular desktop machine without very powerful GPUs. Of course, if you have powerful GPUs, you can still continue using this distribution since Ollama supports GPU acceleration.
 
 ### Environment Variables
```

`distribution-remote-vllm`:

```diff
@@ -18,6 +18,7 @@
 | memory | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |
 | safety | `inline::llama-guard` |
 | telemetry | `inline::meta-reference` |
+| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::memory-runtime` |
 
 You can use this distribution if you have GPUs and want to run an independent vLLM server container for running inference.
```

`distribution-tgi`:

```diff
@@ -23,6 +23,7 @@
 | safety | `inline::llama-guard` |
 | scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
 | telemetry | `inline::meta-reference` |
+| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::memory-runtime` |
 
 You can use this distribution if you have GPUs and want to run an independent TGI server container for running inference.
```

`distribution-together`:

```diff
@@ -22,6 +22,7 @@
 | safety | `inline::llama-guard` |
 | scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
 | telemetry | `inline::meta-reference` |
+| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::memory-runtime` |
 
 ### Environment Variables
```
Agents API (the per-agent tool-definition zoo is replaced by tool groups, and attachments become documents):

```diff
@@ -18,15 +18,11 @@ from typing import (
     Union,
 )
 
-from llama_models.llama3.api.datatypes import ToolParamDefinition
-
-from llama_models.schema_utils import json_schema_type, webmethod
-
+from llama_models.schema_utils import json_schema_type, register_schema, webmethod
 from pydantic import BaseModel, ConfigDict, Field
 from typing_extensions import Annotated
 
 from llama_stack.apis.common.content_types import InterleavedContent, URL
-from llama_stack.apis.common.deployment_types import RestAPIExecutionConfig
 from llama_stack.apis.inference import (
     CompletionMessage,
     SamplingParams,
@@ -40,166 +36,18 @@ from llama_stack.apis.inference import (
 )
 from llama_stack.apis.memory import MemoryBank
 from llama_stack.apis.safety import SafetyViolation
-
+from llama_stack.apis.tools import ToolDef
 from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
 
 
-@json_schema_type
-class Attachment(BaseModel):
-    content: InterleavedContent | URL
-    mime_type: str
-
-
-class AgentTool(Enum):
-    brave_search = "brave_search"
-    wolfram_alpha = "wolfram_alpha"
-    photogen = "photogen"
-    code_interpreter = "code_interpreter"
-
-    function_call = "function_call"
-    memory = "memory"
-
-
-class ToolDefinitionCommon(BaseModel):
-    input_shields: Optional[List[str]] = Field(default_factory=list)
-    output_shields: Optional[List[str]] = Field(default_factory=list)
-
-
-class SearchEngineType(Enum):
-    bing = "bing"
-    brave = "brave"
-    tavily = "tavily"
-
-
-@json_schema_type
-class SearchToolDefinition(ToolDefinitionCommon):
-    # NOTE: brave_search is just a placeholder since model always uses
-    # brave_search as tool call name
-    type: Literal[AgentTool.brave_search.value] = AgentTool.brave_search.value
-    api_key: str
-    engine: SearchEngineType = SearchEngineType.brave
-    remote_execution: Optional[RestAPIExecutionConfig] = None
-
-
-@json_schema_type
-class WolframAlphaToolDefinition(ToolDefinitionCommon):
-    type: Literal[AgentTool.wolfram_alpha.value] = AgentTool.wolfram_alpha.value
-    api_key: str
-    remote_execution: Optional[RestAPIExecutionConfig] = None
-
-
-@json_schema_type
-class PhotogenToolDefinition(ToolDefinitionCommon):
-    type: Literal[AgentTool.photogen.value] = AgentTool.photogen.value
-    remote_execution: Optional[RestAPIExecutionConfig] = None
-
-
-@json_schema_type
-class CodeInterpreterToolDefinition(ToolDefinitionCommon):
-    type: Literal[AgentTool.code_interpreter.value] = AgentTool.code_interpreter.value
-    enable_inline_code_execution: bool = True
-    remote_execution: Optional[RestAPIExecutionConfig] = None
-
-
-@json_schema_type
-class FunctionCallToolDefinition(ToolDefinitionCommon):
-    type: Literal[AgentTool.function_call.value] = AgentTool.function_call.value
-    function_name: str
-    description: str
-    parameters: Dict[str, ToolParamDefinition]
-    remote_execution: Optional[RestAPIExecutionConfig] = None
-
-
-class _MemoryBankConfigCommon(BaseModel):
-    bank_id: str
-
-
-class AgentVectorMemoryBankConfig(_MemoryBankConfigCommon):
-    type: Literal["vector"] = "vector"
-
-
-class AgentKeyValueMemoryBankConfig(_MemoryBankConfigCommon):
-    type: Literal["keyvalue"] = "keyvalue"
-    keys: List[str]  # what keys to focus on
-
-
-class AgentKeywordMemoryBankConfig(_MemoryBankConfigCommon):
-    type: Literal["keyword"] = "keyword"
-
-
-class AgentGraphMemoryBankConfig(_MemoryBankConfigCommon):
-    type: Literal["graph"] = "graph"
-    entities: List[str]  # what entities to focus on
-
-
-MemoryBankConfig = Annotated[
-    Union[
-        AgentVectorMemoryBankConfig,
-        AgentKeyValueMemoryBankConfig,
-        AgentKeywordMemoryBankConfig,
-        AgentGraphMemoryBankConfig,
-    ],
-    Field(discriminator="type"),
-]
-
-
-class MemoryQueryGenerator(Enum):
-    default = "default"
-    llm = "llm"
-    custom = "custom"
-
-
-class DefaultMemoryQueryGeneratorConfig(BaseModel):
-    type: Literal[MemoryQueryGenerator.default.value] = (
-        MemoryQueryGenerator.default.value
-    )
-    sep: str = " "
-
-
-class LLMMemoryQueryGeneratorConfig(BaseModel):
-    type: Literal[MemoryQueryGenerator.llm.value] = MemoryQueryGenerator.llm.value
-    model: str
-    template: str
-
-
-class CustomMemoryQueryGeneratorConfig(BaseModel):
-    type: Literal[MemoryQueryGenerator.custom.value] = MemoryQueryGenerator.custom.value
-
-
-MemoryQueryGeneratorConfig = Annotated[
-    Union[
-        DefaultMemoryQueryGeneratorConfig,
-        LLMMemoryQueryGeneratorConfig,
-        CustomMemoryQueryGeneratorConfig,
-    ],
-    Field(discriminator="type"),
-]
-
-
-@json_schema_type
-class MemoryToolDefinition(ToolDefinitionCommon):
-    type: Literal[AgentTool.memory.value] = AgentTool.memory.value
-    memory_bank_configs: List[MemoryBankConfig] = Field(default_factory=list)
-    # This config defines how a query is generated using the messages
-    # for memory bank retrieval.
-    query_generator_config: MemoryQueryGeneratorConfig = Field(
-        default=DefaultMemoryQueryGeneratorConfig()
-    )
-    max_tokens_in_context: int = 4096
-    max_chunks: int = 10
-
-
-AgentToolDefinition = Annotated[
-    Union[
-        SearchToolDefinition,
-        WolframAlphaToolDefinition,
-        PhotogenToolDefinition,
-        CodeInterpreterToolDefinition,
-        FunctionCallToolDefinition,
-        MemoryToolDefinition,
-    ],
-    Field(discriminator="type"),
-]
+class Document(BaseModel):
+    content: InterleavedContent | URL
+    mime_type: str
 
 
 class StepCommon(BaseModel):
@@ -289,13 +137,27 @@ class Session(BaseModel):
     memory_bank: Optional[MemoryBank] = None
 
 
+class AgentToolGroupWithArgs(BaseModel):
+    name: str
+    args: Dict[str, Any]
+
+
+AgentToolGroup = register_schema(
+    Union[
+        str,
+        AgentToolGroupWithArgs,
+    ],
+    name="AgentTool",
+)
+
+
 class AgentConfigCommon(BaseModel):
     sampling_params: Optional[SamplingParams] = SamplingParams()
 
     input_shields: Optional[List[str]] = Field(default_factory=list)
     output_shields: Optional[List[str]] = Field(default_factory=list)
 
-    tools: Optional[List[AgentToolDefinition]] = Field(default_factory=list)
+    toolgroups: Optional[List[AgentToolGroup]] = Field(default_factory=list)
+    client_tools: Optional[List[ToolDef]] = Field(default_factory=list)
     tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
     tool_prompt_format: Optional[ToolPromptFormat] = Field(
         default=ToolPromptFormat.json
@@ -340,6 +202,7 @@ class AgentTurnResponseStepCompletePayload(BaseModel):
         AgentTurnResponseEventType.step_complete.value
     )
     step_type: StepType
+    step_id: str
     step_details: Step
 
 
@@ -413,7 +276,9 @@ class AgentTurnCreateRequest(AgentConfigOverridablePerTurn):
             ToolResponseMessage,
         ]
     ]
-    attachments: Optional[List[Attachment]] = None
+    documents: Optional[List[Document]] = None
+    toolgroups: Optional[List[AgentToolGroup]] = None
 
     stream: Optional[bool] = False
 
@@ -450,8 +315,9 @@ class Agents(Protocol):
                 ToolResponseMessage,
             ]
         ],
-        attachments: Optional[List[Attachment]] = None,
         stream: Optional[bool] = False,
+        documents: Optional[List[Document]] = None,
+        toolgroups: Optional[List[AgentToolGroup]] = None,
     ) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ...
 
     @webmethod(route="/agents/turn/get")
```
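Read against the protocol above, a turn now carries documents instead of attachments and can override toolgroups per turn. A hedged sketch (the ids are placeholders, and the method name `create_agent_turn` is taken from the upstream API, not shown in this hunk):

```python
# Hedged sketch against the Agents protocol above; ids are placeholders.
response = await agents.create_agent_turn(
    agent_id="agent-123",
    session_id="session-456",
    messages=[UserMessage(content="Summarize the attached report.")],
    documents=[
        Document(
            content="https://example.com/report.txt",  # URL-as-string is matched by the agent
            mime_type="text/plain",
        )
    ],
    toolgroups=["builtin::memory"],  # per-turn override
    stream=True,
)
```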
Tools API (`ToolGroupDef` and the `discover_tools` endpoint are removed; tool groups are now registered by id and enumerated via `list_runtime_tools`):

```diff
@@ -4,10 +4,11 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import Annotated, Any, Dict, List, Literal, Optional, Union
+from enum import Enum
+from typing import Any, Dict, List, Literal, Optional
 
 from llama_models.llama3.api.datatypes import ToolPromptFormat
-from llama_models.schema_utils import json_schema_type, register_schema, webmethod
+from llama_models.schema_utils import json_schema_type, webmethod
 from pydantic import BaseModel, Field
 from typing_extensions import Protocol, runtime_checkable
 
@@ -21,15 +22,24 @@ class ToolParameter(BaseModel):
     name: str
     parameter_type: str
     description: str
     required: bool = Field(default=True)
+    default: Optional[Any] = None
+
+
+@json_schema_type
+class ToolHost(Enum):
+    distribution = "distribution"
+    client = "client"
+    model_context_protocol = "model_context_protocol"
 
 
 @json_schema_type
 class Tool(Resource):
     type: Literal[ResourceType.tool.value] = ResourceType.tool.value
-    tool_group: str
+    toolgroup_id: str
+    tool_host: ToolHost
     description: str
     parameters: List[ToolParameter]
     provider_id: Optional[str] = None
     metadata: Optional[Dict[str, Any]] = None
     tool_prompt_format: Optional[ToolPromptFormat] = Field(
         default=ToolPromptFormat.json
@@ -39,41 +49,27 @@ class Tool(Resource):
 @json_schema_type
 class ToolDef(BaseModel):
     name: str
-    description: str
-    parameters: List[ToolParameter]
-    metadata: Dict[str, Any]
+    description: Optional[str] = None
+    parameters: Optional[List[ToolParameter]] = None
+    metadata: Optional[Dict[str, Any]] = None
+    tool_prompt_format: Optional[ToolPromptFormat] = Field(
+        default=ToolPromptFormat.json
+    )
 
 
-@json_schema_type
-class MCPToolGroupDef(BaseModel):
-    """
-    A tool group that is defined by in a model context protocol server.
-    Refer to https://modelcontextprotocol.io/docs/concepts/tools for more information.
-    """
-
-    type: Literal["model_context_protocol"] = "model_context_protocol"
-    endpoint: URL
+class ToolGroupInput(BaseModel):
+    toolgroup_id: str
+    provider_id: str
+    args: Optional[Dict[str, Any]] = None
+    mcp_endpoint: Optional[URL] = None
 
 
-@json_schema_type
-class UserDefinedToolGroupDef(BaseModel):
-    type: Literal["user_defined"] = "user_defined"
-    tools: List[ToolDef]
-
-
-ToolGroupDef = register_schema(
-    Annotated[
-        Union[MCPToolGroupDef, UserDefinedToolGroupDef], Field(discriminator="type")
-    ],
-    name="ToolGroup",
-)
-
-
 class ToolGroup(Resource):
     type: Literal[ResourceType.tool_group.value] = ResourceType.tool_group.value
+    mcp_endpoint: Optional[URL] = None
+    args: Optional[Dict[str, Any]] = None
 
 
 @json_schema_type
@@ -85,6 +81,7 @@ class ToolInvocationResult(BaseModel):
 
 class ToolStore(Protocol):
     def get_tool(self, tool_name: str) -> Tool: ...
+    def get_tool_group(self, tool_group_id: str) -> ToolGroup: ...
 
 
 @runtime_checkable
@@ -93,9 +90,10 @@ class ToolGroups(Protocol):
     @webmethod(route="/toolgroups/register", method="POST")
     async def register_tool_group(
         self,
-        tool_group_id: str,
-        tool_group: ToolGroupDef,
-        provider_id: Optional[str] = None,
+        toolgroup_id: str,
+        provider_id: str,
+        mcp_endpoint: Optional[URL] = None,
+        args: Optional[Dict[str, Any]] = None,
     ) -> None:
         """Register a tool group"""
         ...
@@ -103,7 +101,7 @@ class ToolGroups(Protocol):
     @webmethod(route="/toolgroups/get", method="GET")
     async def get_tool_group(
         self,
-        tool_group_id: str,
+        toolgroup_id: str,
     ) -> ToolGroup: ...
 
     @webmethod(route="/toolgroups/list", method="GET")
@@ -130,8 +128,11 @@ class ToolGroups(Protocol):
 class ToolRuntime(Protocol):
     tool_store: ToolStore
 
-    @webmethod(route="/tool-runtime/discover", method="POST")
-    async def discover_tools(self, tool_group: ToolGroupDef) -> List[ToolDef]: ...
+    # TODO: This needs to be renamed once OPEN API generator name conflict issue is fixed.
+    @webmethod(route="/tool-runtime/list-tools", method="GET")
+    async def list_runtime_tools(
+        self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
+    ) -> List[ToolDef]: ...
 
     @webmethod(route="/tool-runtime/invoke", method="POST")
     async def invoke_tool(
```
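The new `register_tool_group` signature drops `ToolGroupDef` entirely; the provider enumerates its own tools via `list_runtime_tools`. A hedged sketch of a registration call against the protocol above (the provider id, toolgroup id, and args are assumptions for illustration):

```python
# Hedged sketch: register the builtin memory tool group with injected args.
await tool_groups.register_tool_group(
    toolgroup_id="builtin::memory",
    provider_id="memory-runtime",        # assumed provider id
    args={"memory_bank_ids": ["my-bank"]},  # injected into every tool call
)
```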
Distribution datatypes (run configs can now declare tool groups):

```diff
@@ -20,7 +20,7 @@ from llama_stack.apis.safety import Safety
 from llama_stack.apis.scoring import Scoring
 from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnInput
 from llama_stack.apis.shields import Shield, ShieldInput
-from llama_stack.apis.tools import Tool, ToolGroup, ToolRuntime
+from llama_stack.apis.tools import Tool, ToolGroup, ToolGroupInput, ToolRuntime
 from llama_stack.providers.datatypes import Api, ProviderSpec
 from llama_stack.providers.utils.kvstore.config import KVStoreConfig
 
@@ -161,6 +161,7 @@ a default SQLite store will be used.""",
     datasets: List[DatasetInput] = Field(default_factory=list)
     scoring_fns: List[ScoringFnInput] = Field(default_factory=list)
     eval_tasks: List[EvalTaskInput] = Field(default_factory=list)
+    tool_groups: List[ToolGroupInput] = Field(default_factory=list)
 
 
 class BuildConfig(BaseModel):
```
Library client (surface the underlying error message):

```diff
@@ -267,6 +267,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
                 self.config, self.custom_provider_registry
             )
         except ModuleNotFoundError as _e:
+            cprint(_e.msg, "red")
             cprint(
                 "Using llama-stack as a library requires installing dependencies depending on the template (providers) you choose.\n",
                 "yellow",
```
Provider resolver (import cleanup):

```diff
@@ -5,9 +5,7 @@
 # the root directory of this source tree.
 import importlib
 import inspect
-
 import logging
-
 from typing import Any, Dict, List, Set
 
 from llama_stack.apis.agents import Agents
@@ -28,7 +26,6 @@ from llama_stack.apis.shields import Shields
 from llama_stack.apis.telemetry import Telemetry
 from llama_stack.apis.tools import ToolGroups, ToolRuntime
 from llama_stack.distribution.client import get_client_impl
-
 from llama_stack.distribution.datatypes import (
     AutoRoutedProviderSpec,
     Provider,
@@ -38,7 +35,6 @@ from llama_stack.distribution.datatypes import (
 from llama_stack.distribution.distribution import builtin_automatically_routed_apis
 from llama_stack.distribution.store import DistributionRegistry
 from llama_stack.distribution.utils.dynamic import instantiate_class_type
-
 from llama_stack.providers.datatypes import (
     Api,
     DatasetsProtocolPrivate,
```
Routers (`ToolRuntimeRouter` now lists tools instead of discovering them):

```diff
@@ -6,7 +6,7 @@
 
 from typing import Any, AsyncGenerator, Dict, List, Optional
 
-from llama_stack.apis.common.content_types import InterleavedContent
+from llama_stack.apis.common.content_types import InterleavedContent, URL
 from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult
 from llama_stack.apis.eval import (
     AppEvalTaskConfig,
@@ -38,7 +38,7 @@ from llama_stack.apis.scoring import (
     ScoringFnParams,
 )
 from llama_stack.apis.shields import Shield
-from llama_stack.apis.tools import Tool, ToolGroupDef, ToolRuntime
+from llama_stack.apis.tools import ToolDef, ToolRuntime
 from llama_stack.providers.datatypes import RoutingTable
 
 
@@ -417,7 +417,9 @@ class ToolRuntimeRouter(ToolRuntime):
             args=args,
         )
 
-    async def discover_tools(self, tool_group: ToolGroupDef) -> List[Tool]:
-        return await self.routing_table.get_provider_impl(
-            tool_group.name
-        ).discover_tools(tool_group)
+    async def list_runtime_tools(
+        self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
+    ) -> List[ToolDef]:
+        return await self.routing_table.get_provider_impl(tool_group_id).list_tools(
+            tool_group_id, mcp_endpoint
+        )
```
Routing tables (imports and pydantic v2 migration):

```diff
@@ -6,7 +6,7 @@
 
 from typing import Any, Dict, List, Optional
 
-from pydantic import parse_obj_as
+from pydantic import TypeAdapter
 
 from llama_stack.apis.common.content_types import URL
 from llama_stack.apis.common.type_system import ParamType
@@ -26,20 +26,12 @@ from llama_stack.apis.scoring_functions import (
     ScoringFunctions,
 )
 from llama_stack.apis.shields import Shield, Shields
-from llama_stack.apis.tools import (
-    MCPToolGroupDef,
-    Tool,
-    ToolGroup,
-    ToolGroupDef,
-    ToolGroups,
-    UserDefinedToolGroupDef,
-)
+from llama_stack.apis.tools import Tool, ToolGroup, ToolGroups, ToolHost
 from llama_stack.distribution.datatypes import (
     RoutableObject,
     RoutableObjectWithProvider,
     RoutedProtocol,
 )
-
 from llama_stack.distribution.store import DistributionRegistry
 from llama_stack.providers.datatypes import Api, RoutingTable
 
@@ -361,7 +353,7 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
             memory_bank_data["embedding_dimension"] = model.metadata[
                 "embedding_dimension"
             ]
-        memory_bank = parse_obj_as(MemoryBank, memory_bank_data)
+        memory_bank = TypeAdapter(MemoryBank).validate_python(memory_bank_data)
         await self.register_object(memory_bank)
         return memory_bank
 
```
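The `parse_obj_as` → `TypeAdapter` switch tracks pydantic v2, where `parse_obj_as` is deprecated. A minimal sketch of the equivalence (the `Bank` model is a stand-in for `MemoryBank`):

```python
from pydantic import BaseModel, TypeAdapter

class Bank(BaseModel):  # stand-in model for illustration
    bank_id: str

# pydantic v1 style (deprecated in v2): parse_obj_as(Bank, {"bank_id": "b1"})
# pydantic v2 style, as used above:
bank = TypeAdapter(Bank).validate_python({"bank_id": "b1"})
```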
```diff
@@ -496,54 +488,45 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
     async def list_tools(self, tool_group_id: Optional[str] = None) -> List[Tool]:
         tools = await self.get_all_with_type("tool")
         if tool_group_id:
-            tools = [tool for tool in tools if tool.tool_group == tool_group_id]
+            tools = [tool for tool in tools if tool.toolgroup_id == tool_group_id]
         return tools
 
     async def list_tool_groups(self) -> List[ToolGroup]:
         return await self.get_all_with_type("tool_group")
 
-    async def get_tool_group(self, tool_group_id: str) -> ToolGroup:
-        return await self.get_object_by_identifier("tool_group", tool_group_id)
+    async def get_tool_group(self, toolgroup_id: str) -> ToolGroup:
+        return await self.get_object_by_identifier("tool_group", toolgroup_id)
 
     async def get_tool(self, tool_name: str) -> Tool:
         return await self.get_object_by_identifier("tool", tool_name)
 
     async def register_tool_group(
         self,
-        tool_group_id: str,
-        tool_group: ToolGroupDef,
-        provider_id: Optional[str] = None,
+        toolgroup_id: str,
+        provider_id: str,
+        mcp_endpoint: Optional[URL] = None,
+        args: Optional[Dict[str, Any]] = None,
     ) -> None:
         tools = []
-        tool_defs = []
-        if provider_id is None:
-            if len(self.impls_by_provider_id.keys()) > 1:
-                raise ValueError(
-                    f"No provider_id specified and multiple providers available. Please specify a provider_id. Available providers: {', '.join(self.impls_by_provider_id.keys())}"
-                )
-            provider_id = list(self.impls_by_provider_id.keys())[0]
-
-        if isinstance(tool_group, MCPToolGroupDef):
-            tool_defs = await self.impls_by_provider_id[provider_id].discover_tools(
-                tool_group
-            )
-        elif isinstance(tool_group, UserDefinedToolGroupDef):
-            tool_defs = tool_group.tools
-        else:
-            raise ValueError(f"Unknown tool group: {tool_group}")
+        tool_defs = await self.impls_by_provider_id[provider_id].list_runtime_tools(
+            toolgroup_id, mcp_endpoint
+        )
+        tool_host = (
+            ToolHost.model_context_protocol if mcp_endpoint else ToolHost.distribution
+        )
 
         for tool_def in tool_defs:
             tools.append(
                 Tool(
                     identifier=tool_def.name,
-                    tool_group=tool_group_id,
-                    description=tool_def.description,
-                    parameters=tool_def.parameters,
+                    toolgroup_id=toolgroup_id,
+                    description=tool_def.description or "",
+                    parameters=tool_def.parameters or [],
                     provider_id=provider_id,
+                    tool_prompt_format=tool_def.tool_prompt_format,
                     provider_resource_id=tool_def.name,
                     metadata=tool_def.metadata,
+                    tool_host=tool_host,
                 )
             )
         for tool in tools:
@@ -561,9 +544,11 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
 
         await self.dist_registry.register(
             ToolGroup(
-                identifier=tool_group_id,
+                identifier=toolgroup_id,
                 provider_id=provider_id,
-                provider_resource_id=tool_group_id,
+                provider_resource_id=toolgroup_id,
+                mcp_endpoint=mcp_endpoint,
+                args=args,
             )
         )
 
```
Stack assembly (the stack now exposes `ToolGroups` and `ToolRuntime`, and registers tool groups as resources):

```diff
@@ -12,7 +12,6 @@ from typing import Any, Dict, Optional
 
 import pkg_resources
 import yaml
-
 from termcolor import colored
 
 from llama_stack.apis.agents import Agents
@@ -33,14 +32,13 @@ from llama_stack.apis.scoring_functions import ScoringFunctions
 from llama_stack.apis.shields import Shields
 from llama_stack.apis.synthetic_data_generation import SyntheticDataGeneration
 from llama_stack.apis.telemetry import Telemetry
-
+from llama_stack.apis.tools import ToolGroups, ToolRuntime
 from llama_stack.distribution.datatypes import StackRunConfig
 from llama_stack.distribution.distribution import get_provider_registry
 from llama_stack.distribution.resolver import ProviderRegistry, resolve_impls
 from llama_stack.distribution.store.registry import create_dist_registry
 from llama_stack.providers.datatypes import Api
-
 
 log = logging.getLogger(__name__)
 
 LLAMA_STACK_API_VERSION = "alpha"
@@ -65,6 +63,8 @@ class LlamaStack(
     Models,
     Shields,
     Inspect,
+    ToolGroups,
+    ToolRuntime,
 ):
     pass
 
@@ -81,6 +81,7 @@ RESOURCES = [
         "list_scoring_functions",
     ),
     ("eval_tasks", Api.eval_tasks, "register_eval_task", "list_eval_tasks"),
+    ("tool_groups", Api.tool_groups, "register_tool_group", "list_tool_groups"),
 ]
 
 
```
Distribution registry:

```diff
@@ -12,7 +12,6 @@ import pydantic
 
 from llama_stack.distribution.datatypes import KVStoreConfig, RoutableObjectWithProvider
 from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
-
 from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl
 from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
 
@@ -36,7 +35,7 @@ class DistributionRegistry(Protocol):
 
 
 REGISTER_PREFIX = "distributions:registry"
-KEY_VERSION = "v3"
+KEY_VERSION = "v4"
 KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}"
 
 
```
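Bumping `KEY_VERSION` to `v4` presumably invalidates previously persisted registry entries; this is needed here because the serialized `Tool`/`ToolGroup` resources changed shape in this PR.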
Agents provider entrypoint (new tool runtime and tool groups dependencies):

```diff
@@ -22,6 +22,8 @@ async def get_provider_impl(
         deps[Api.memory],
         deps[Api.safety],
         deps[Api.memory_banks],
+        deps[Api.tool_runtime],
+        deps[Api.tool_groups],
     )
     await impl.initialize()
     return impl
```
Meta-reference agent implementation (`ChatAgent`):

```diff
@@ -4,8 +4,8 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-import asyncio
+import copy
 import json
 import logging
 import os
 import re
@@ -13,16 +13,16 @@ import secrets
 import string
 import uuid
 from datetime import datetime
-from typing import AsyncGenerator, Dict, List, Optional, Tuple
+from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple
 from urllib.parse import urlparse
 
 import httpx
 
-from llama_models.llama3.api.datatypes import BuiltinTool
+from llama_models.llama3.api.datatypes import BuiltinTool, ToolCall, ToolParamDefinition
 
 from llama_stack.apis.agents import (
     AgentConfig,
-    AgentTool,
+    AgentToolGroup,
+    AgentToolGroupWithArgs,
     AgentTurnCreateRequest,
     AgentTurnResponseEvent,
     AgentTurnResponseEventType,
@@ -33,25 +33,14 @@ from llama_stack.apis.agents import (
     AgentTurnResponseTurnCompletePayload,
     AgentTurnResponseTurnStartPayload,
-    Attachment,
-    CodeInterpreterToolDefinition,
-    FunctionCallToolDefinition,
+    Document,
     InferenceStep,
     MemoryRetrievalStep,
-    MemoryToolDefinition,
-    PhotogenToolDefinition,
-    SearchToolDefinition,
     ShieldCallStep,
     StepType,
     ToolExecutionStep,
     Turn,
-    WolframAlphaToolDefinition,
 )
 
-from llama_stack.apis.common.content_types import (
-    InterleavedContent,
-    TextContentItem,
-    URL,
-)
+from llama_stack.apis.common.content_types import TextContentItem, URL
 from llama_stack.apis.inference import (
     ChatCompletionResponseEventType,
     CompletionMessage,
@@ -62,32 +51,20 @@ from llama_stack.apis.inference import (
     SystemMessage,
     ToolCallDelta,
     ToolCallParseStatus,
     ToolChoice,
     ToolDefinition,
     ToolResponse,
     ToolResponseMessage,
     UserMessage,
 )
-from llama_stack.apis.memory import Memory, MemoryBankDocument, QueryDocumentsResponse
+from llama_stack.apis.memory import Memory, MemoryBankDocument
 from llama_stack.apis.memory_banks import MemoryBanks, VectorMemoryBankParams
 from llama_stack.apis.safety import Safety
+from llama_stack.apis.tools import ToolGroups, ToolRuntime
 from llama_stack.providers.utils.kvstore import KVStore
-from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content
 from llama_stack.providers.utils.telemetry import tracing
 
 from .persistence import AgentPersistence
-from .rag.context_retriever import generate_rag_query
 from .safety import SafetyException, ShieldRunnerMixin
-from .tools.base import BaseTool
-from .tools.builtin import (
-    CodeInterpreterTool,
-    interpret_content_as_attachment,
-    PhotogenTool,
-    SearchTool,
-    WolframAlphaTool,
-)
-from .tools.safety import SafeTool
 
 log = logging.getLogger(__name__)
 
@@ -98,6 +75,12 @@ def make_random_string(length: int = 8):
     )
 
 
+TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})")
+MEMORY_QUERY_TOOL = "query_memory"
+WEB_SEARCH_TOOL = "web_search"
+MEMORY_GROUP = "builtin::memory"
+
+
 class ChatAgent(ShieldRunnerMixin):
     def __init__(
         self,
@@ -108,6 +91,8 @@ class ChatAgent(ShieldRunnerMixin):
         memory_api: Memory,
         memory_banks_api: MemoryBanks,
         safety_api: Safety,
+        tool_runtime_api: ToolRuntime,
+        tool_groups_api: ToolGroups,
         persistence_store: KVStore,
     ):
         self.agent_id = agent_id
@@ -118,29 +103,8 @@ class ChatAgent(ShieldRunnerMixin):
         self.memory_banks_api = memory_banks_api
         self.safety_api = safety_api
         self.storage = AgentPersistence(agent_id, persistence_store)
-
-        builtin_tools = []
-        for tool_defn in agent_config.tools:
-            if isinstance(tool_defn, WolframAlphaToolDefinition):
-                tool = WolframAlphaTool(tool_defn.api_key)
-            elif isinstance(tool_defn, SearchToolDefinition):
-                tool = SearchTool(tool_defn.engine, tool_defn.api_key)
-            elif isinstance(tool_defn, CodeInterpreterToolDefinition):
-                tool = CodeInterpreterTool()
-            elif isinstance(tool_defn, PhotogenToolDefinition):
-                tool = PhotogenTool(dump_dir=self.tempdir)
-            else:
-                continue
-
-            builtin_tools.append(
-                SafeTool(
-                    tool,
-                    safety_api,
-                    tool_defn.input_shields,
-                    tool_defn.output_shields,
-                )
-            )
-        self.tools_dict = {t.get_name(): t for t in builtin_tools}
+        self.tool_runtime_api = tool_runtime_api
+        self.tool_groups_api = tool_groups_api
 
         ShieldRunnerMixin.__init__(
             self,
@@ -228,9 +192,10 @@ class ChatAgent(ShieldRunnerMixin):
             session_id=request.session_id,
             turn_id=turn_id,
             input_messages=messages,
-            attachments=request.attachments or [],
             sampling_params=self.agent_config.sampling_params,
             stream=request.stream,
+            documents=request.documents,
+            toolgroups_for_turn=request.toolgroups,
         ):
             if isinstance(chunk, CompletionMessage):
                 log.info(
@@ -278,9 +243,10 @@ class ChatAgent(ShieldRunnerMixin):
         session_id: str,
         turn_id: str,
         input_messages: List[Message],
-        attachments: List[Attachment],
         sampling_params: SamplingParams,
         stream: bool = False,
+        documents: Optional[List[Document]] = None,
+        toolgroups_for_turn: Optional[List[AgentToolGroup]] = None,
     ) -> AsyncGenerator:
         # Doing async generators makes downstream code much simpler and everything amenable to
         # streaming. However, it also makes things complicated here because AsyncGenerators cannot
@@ -297,7 +263,13 @@ class ChatAgent(ShieldRunnerMixin):
                 yield res
 
         async for res in self._run(
-            session_id, turn_id, input_messages, attachments, sampling_params, stream
+            session_id,
+            turn_id,
+            input_messages,
+            sampling_params,
+            stream,
+            documents,
+            toolgroups_for_turn,
         ):
             if isinstance(res, bool):
                 return
@@ -353,6 +325,7 @@ class ChatAgent(ShieldRunnerMixin):
                 event=AgentTurnResponseEvent(
                     payload=AgentTurnResponseStepCompletePayload(
                         step_type=StepType.shield_call.value,
+                        step_id=step_id,
                        step_details=ShieldCallStep(
                            step_id=step_id,
                            turn_id=turn_id,
@@ -373,6 +346,7 @@ class ChatAgent(ShieldRunnerMixin):
                 event=AgentTurnResponseEvent(
                     payload=AgentTurnResponseStepCompletePayload(
                         step_type=StepType.shield_call.value,
+                        step_id=step_id,
                        step_details=ShieldCallStep(
                            step_id=step_id,
                            turn_id=turn_id,
```
```diff
@@ -388,73 +362,116 @@ class ChatAgent(ShieldRunnerMixin):
         session_id: str,
         turn_id: str,
         input_messages: List[Message],
-        attachments: List[Attachment],
         sampling_params: SamplingParams,
         stream: bool = False,
+        documents: Optional[List[Document]] = None,
+        toolgroups_for_turn: Optional[List[AgentToolGroup]] = None,
     ) -> AsyncGenerator:
-        enabled_tools = set(t.type for t in self.agent_config.tools)
-        need_rag_context = await self._should_retrieve_context(
-            input_messages, attachments
-        )
-        if need_rag_context:
-            step_id = str(uuid.uuid4())
-            yield AgentTurnResponseStreamChunk(
-                event=AgentTurnResponseEvent(
-                    payload=AgentTurnResponseStepStartPayload(
-                        step_type=StepType.memory_retrieval.value,
-                        step_id=step_id,
+        toolgroup_args = {}
+        for toolgroup in self.agent_config.toolgroups:
+            if isinstance(toolgroup, AgentToolGroupWithArgs):
+                toolgroup_args[toolgroup.name] = toolgroup.args
+        if toolgroups_for_turn:
+            for toolgroup in toolgroups_for_turn:
+                if isinstance(toolgroup, AgentToolGroupWithArgs):
+                    toolgroup_args[toolgroup.name] = toolgroup.args
+
+        tool_defs, tool_to_group = await self._get_tool_defs(toolgroups_for_turn)
+        if documents:
+            await self.handle_documents(
+                session_id, documents, input_messages, tool_defs
+            )
+        if MEMORY_QUERY_TOOL in tool_defs and len(input_messages) > 0:
+            memory_tool_group = tool_to_group.get(MEMORY_QUERY_TOOL, None)
+            if memory_tool_group is None:
+                raise ValueError(f"Memory tool group not found for {MEMORY_QUERY_TOOL}")
+            with tracing.span(MEMORY_QUERY_TOOL) as span:
+                step_id = str(uuid.uuid4())
+                yield AgentTurnResponseStreamChunk(
+                    event=AgentTurnResponseEvent(
+                        payload=AgentTurnResponseStepStartPayload(
+                            step_type=StepType.tool_execution.value,
+                            step_id=step_id,
+                        )
                     )
                 )
-            )
+                query_args = {
+                    "messages": [msg.content for msg in input_messages],
+                    **toolgroup_args.get(memory_tool_group, {}),
+                }
 
-            # TODO: find older context from the session and either replace it
-            # or append with a sliding window. this is really a very simplistic implementation
-            with tracing.span("retrieve_rag_context") as span:
-                rag_context, bank_ids = await self._retrieve_context(
-                    session_id, input_messages, attachments
-                )
+                session_info = await self.storage.get_session_info(session_id)
+                # if the session has a memory bank id, let the memory tool use it
+                if session_info.memory_bank_id:
+                    if "memory_bank_ids" not in query_args:
+                        query_args["memory_bank_ids"] = []
+                    query_args["memory_bank_ids"].append(session_info.memory_bank_id)
+                yield AgentTurnResponseStreamChunk(
+                    event=AgentTurnResponseEvent(
+                        payload=AgentTurnResponseStepProgressPayload(
+                            step_type=StepType.tool_execution.value,
+                            step_id=step_id,
+                            tool_call_delta=ToolCallDelta(
+                                parse_status=ToolCallParseStatus.success,
+                                content=ToolCall(
+                                    call_id="",
+                                    tool_name=MEMORY_QUERY_TOOL,
+                                    arguments={},
+                                ),
+                            ),
+                        )
+                    )
+                )
+                result = await self.tool_runtime_api.invoke_tool(
+                    tool_name=MEMORY_QUERY_TOOL,
+                    args=query_args,
+                )
+
+                yield AgentTurnResponseStreamChunk(
+                    event=AgentTurnResponseEvent(
+                        payload=AgentTurnResponseStepCompletePayload(
+                            step_type=StepType.tool_execution.value,
+                            step_id=step_id,
+                            step_details=ToolExecutionStep(
+                                step_id=step_id,
+                                turn_id=turn_id,
+                                tool_calls=[
+                                    ToolCall(
+                                        call_id="",
+                                        tool_name=MEMORY_QUERY_TOOL,
+                                        arguments={},
+                                    )
+                                ],
+                                tool_responses=[
+                                    ToolResponse(
+                                        call_id="",
+                                        tool_name=MEMORY_QUERY_TOOL,
+                                        content=result.content,
+                                    )
+                                ],
+                            ),
+                        )
+                    )
+                )
                 span.set_attribute(
                     "input", [m.model_dump_json() for m in input_messages]
                 )
-                span.set_attribute("output", rag_context)
-                span.set_attribute("bank_ids", bank_ids)
-
-            step_id = str(uuid.uuid4())
-            yield AgentTurnResponseStreamChunk(
-                event=AgentTurnResponseEvent(
-                    payload=AgentTurnResponseStepCompletePayload(
-                        step_type=StepType.memory_retrieval.value,
-                        step_id=step_id,
-                        step_details=MemoryRetrievalStep(
-                            turn_id=turn_id,
-                            step_id=step_id,
-                            memory_bank_ids=bank_ids,
-                            inserted_context=rag_context or "",
-                        ),
-                    )
-                )
-            )
-
-            if rag_context:
-                last_message = input_messages[-1]
-                last_message.context = rag_context
-
-        elif attachments and AgentTool.code_interpreter.value in enabled_tools:
-            urls = [a.content for a in attachments if isinstance(a.content, URL)]
-            # TODO: we need to migrate URL away from str type
-            pattern = re.compile("^(https?://|file://|data:)")
-            urls += [
-                URL(uri=a.content) for a in attachments if pattern.match(a.content)
-            ]
-            msg = await attachment_message(self.tempdir, urls)
-            input_messages.append(msg)
+                span.set_attribute("output", result.content)
+                span.set_attribute("error_code", result.error_code)
+                span.set_attribute("error_message", result.error_message)
+                span.set_attribute("tool_name", MEMORY_QUERY_TOOL)
+                if result.error_code == 0:
+                    last_message = input_messages[-1]
+                    last_message.context = result.content
 
         output_attachments = []
 
         n_iter = 0
+        # Build a map of custom tools to their definitions for faster lookup
+        client_tools = {}
+        for tool in self.agent_config.client_tools:
+            client_tools[tool.name] = tool
         while True:
             msg = input_messages[-1]
 
             step_id = str(uuid.uuid4())
             yield AgentTurnResponseStreamChunk(
                 event=AgentTurnResponseEvent(
```
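Putting the memory-query plumbing above together, a hedged sketch of the effective invocation when the agent config injects toolgroup args and the session has its own bank (all ids illustrative):

```python
# With toolgroup_args = {"builtin::memory": {"memory_bank_ids": ["config-bank"]}}
# and a per-session bank "session-bank", the agent effectively runs:
result = await tool_runtime_api.invoke_tool(
    tool_name="query_memory",
    args={
        "messages": [msg.content for msg in input_messages],
        "memory_bank_ids": ["config-bank", "session-bank"],
    },
)
```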
@ -473,7 +490,11 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
async for chunk in await self.inference_api.chat_completion(
|
||||
self.agent_config.model,
|
||||
input_messages,
|
||||
tools=self._get_tools(),
|
||||
tools=[
|
||||
tool
|
||||
for tool in tool_defs.values()
|
||||
if tool_to_group.get(tool.tool_name, None) != MEMORY_GROUP
|
||||
],
|
||||
tool_prompt_format=self.agent_config.tool_prompt_format,
|
||||
stream=True,
|
||||
sampling_params=sampling_params,
|
||||
|
@ -572,9 +593,9 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
# TODO: UPDATE RETURN TYPE TO SEND A TUPLE OF (MESSAGE, ATTACHMENTS)
|
||||
if len(output_attachments) > 0:
|
||||
if isinstance(message.content, list):
|
||||
message.content += attachments
|
||||
message.content += output_attachments
|
||||
else:
|
||||
message.content = [message.content] + attachments
|
||||
message.content = [message.content] + output_attachments
|
||||
yield message
|
||||
else:
|
||||
log.info(f"Partial message: {str(message)}")
|
||||
|
@ -582,9 +603,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
else:
|
||||
log.info(f"{str(message)}")
|
||||
tool_call = message.tool_calls[0]
|
||||
|
||||
name = tool_call.tool_name
|
||||
if not isinstance(name, BuiltinTool) or name not in enabled_tools:
|
||||
if tool_call.tool_name in client_tools:
|
||||
yield message
|
||||
return
|
||||
|
||||
|
@ -607,16 +626,22 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
)
|
||||
)
|
||||
|
||||
tool_name = tool_call.tool_name
|
||||
if isinstance(tool_name, BuiltinTool):
|
||||
tool_name = tool_name.value
|
||||
with tracing.span(
|
||||
"tool_execution",
|
||||
{
|
||||
"tool_name": tool_call.tool_name,
|
||||
"tool_name": tool_name,
|
||||
"input": message.model_dump_json(),
|
||||
},
|
||||
) as span:
|
||||
result_messages = await execute_tool_call_maybe(
|
||||
self.tools_dict,
|
||||
self.tool_runtime_api,
|
||||
session_id,
|
||||
[message],
|
||||
toolgroup_args,
|
||||
tool_to_group,
|
||||
)
|
||||
assert (
|
||||
len(result_messages) == 1
|
||||
|
@ -628,6 +653,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepCompletePayload(
|
||||
step_type=StepType.tool_execution.value,
|
||||
step_id=step_id,
|
||||
step_details=ToolExecutionStep(
|
||||
step_id=step_id,
|
||||
turn_id=turn_id,
|
||||
|
@ -647,7 +673,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
# TODO: add tool-input touchpoint and a "start" event for this step also
|
||||
# but that needs a lot more refactoring of Tool code potentially
|
||||
|
||||
if out_attachment := interpret_content_as_attachment(
|
||||
if out_attachment := _interpret_content_as_attachment(
|
||||
result_message.content
|
||||
):
|
||||
# NOTE: when we push this message back to the model, the model may ignore the
|
||||
|
@ -659,6 +685,150 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
|
||||
n_iter += 1
|
||||
|
||||
async def _get_tool_defs(
|
||||
self, toolgroups_for_turn: Optional[List[AgentToolGroup]] = None
|
||||
) -> Tuple[Dict[str, ToolDefinition], Dict[str, str]]:
|
||||
# Determine which tools to include
|
||||
agent_config_toolgroups = set(
|
||||
(
|
||||
toolgroup.name
|
||||
if isinstance(toolgroup, AgentToolGroupWithArgs)
|
||||
else toolgroup
|
||||
)
|
||||
for toolgroup in self.agent_config.toolgroups
|
||||
)
|
||||
toolgroups_for_turn_set = (
|
||||
agent_config_toolgroups
|
||||
if toolgroups_for_turn is None
|
||||
else {
|
||||
(
|
||||
toolgroup.name
|
||||
if isinstance(toolgroup, AgentToolGroupWithArgs)
|
||||
else toolgroup
|
||||
)
|
||||
for toolgroup in toolgroups_for_turn
|
||||
}
|
||||
)
|
||||
|
||||
tool_def_map = {}
|
||||
tool_to_group = {}
|
||||
|
||||
for tool_def in self.agent_config.client_tools:
|
||||
if tool_def_map.get(tool_def.name, None):
|
||||
raise ValueError(f"Tool {tool_def.name} already exists")
|
||||
tool_def_map[tool_def.name] = ToolDefinition(
|
||||
tool_name=tool_def.name,
|
||||
description=tool_def.description,
|
||||
parameters={
|
||||
param.name: ToolParamDefinition(
|
||||
param_type=param.parameter_type,
|
||||
description=param.description,
|
||||
required=param.required,
|
||||
default=param.default,
|
||||
)
|
||||
for param in tool_def.parameters
|
||||
},
|
||||
)
|
||||
tool_to_group[tool_def.name] = "__client_tools__"
|
||||
for toolgroup_name in agent_config_toolgroups:
|
||||
if toolgroup_name not in toolgroups_for_turn_set:
|
||||
continue
|
||||
tools = await self.tool_groups_api.list_tools(tool_group_id=toolgroup_name)
|
||||
for tool_def in tools:
|
||||
if (
|
||||
toolgroup_name.startswith("builtin")
|
||||
and toolgroup_name != MEMORY_GROUP
|
||||
):
|
||||
tool_name = tool_def.identifier
|
||||
built_in_type = BuiltinTool.brave_search
|
||||
if tool_name == "web_search":
|
||||
built_in_type = BuiltinTool.brave_search
|
||||
else:
|
||||
built_in_type = BuiltinTool(tool_name)
|
||||
|
||||
if tool_def_map.get(built_in_type, None):
|
||||
raise ValueError(f"Tool {built_in_type} already exists")
|
||||
|
||||
tool_def_map[built_in_type] = ToolDefinition(
|
||||
tool_name=built_in_type
|
||||
)
|
||||
tool_to_group[built_in_type] = tool_def.toolgroup_id
|
||||
continue
|
||||
|
||||
if tool_def_map.get(tool_def.identifier, None):
|
||||
raise ValueError(f"Tool {tool_def.identifier} already exists")
|
||||
tool_def_map[tool_def.identifier] = ToolDefinition(
|
||||
tool_name=tool_def.identifier,
|
||||
description=tool_def.description,
|
||||
parameters={
|
||||
param.name: ToolParamDefinition(
|
||||
param_type=param.parameter_type,
|
||||
description=param.description,
|
||||
required=param.required,
|
||||
default=param.default,
|
||||
)
|
||||
for param in tool_def.parameters
|
||||
},
|
||||
)
|
||||
tool_to_group[tool_def.identifier] = tool_def.toolgroup_id
|
||||
|
||||
return tool_def_map, tool_to_group
|
||||
|
||||
async def handle_documents(
|
||||
self,
|
||||
session_id: str,
|
||||
documents: List[Document],
|
||||
input_messages: List[Message],
|
||||
tool_defs: Dict[str, ToolDefinition],
|
||||
) -> None:
|
||||
memory_tool = tool_defs.get(MEMORY_QUERY_TOOL, None)
|
||||
code_interpreter_tool = tool_defs.get(BuiltinTool.code_interpreter, None)
|
||||
content_items = []
|
||||
url_items = []
|
||||
pattern = re.compile("^(https?://|file://|data:)")
|
||||
for d in documents:
|
||||
if isinstance(d.content, URL):
|
||||
url_items.append(d.content)
|
||||
elif pattern.match(d.content):
|
||||
url_items.append(URL(uri=d.content))
|
||||
else:
|
||||
content_items.append(d)
|
||||
|
||||
# Save the contents to a tempdir and use its path as a URL if code interpreter is present
|
||||
if code_interpreter_tool:
|
||||
for c in content_items:
|
||||
temp_file_path = os.path.join(
|
||||
self.tempdir, f"{make_random_string()}.txt"
|
||||
)
|
||||
with open(temp_file_path, "w") as temp_file:
|
||||
temp_file.write(c.content)
|
||||
url_items.append(URL(uri=f"file://{temp_file_path}"))
|
||||
|
||||
if memory_tool and code_interpreter_tool:
|
||||
# if both memory and code_interpreter are available, we download the URLs
|
||||
# and attach the data to the last message.
|
||||
msg = await attachment_message(self.tempdir, url_items)
|
||||
input_messages.append(msg)
|
||||
# Since memory is present, add all the data to the memory bank
|
||||
await self.add_to_session_memory_bank(session_id, documents)
|
||||
elif code_interpreter_tool:
|
||||
# if only code_interpreter is available, we download the URLs to a tempdir
|
||||
# and attach the path to them as a message to inference with the
|
||||
# assumption that the model invokes the code_interpreter tool with the path
|
||||
msg = await attachment_message(self.tempdir, url_items)
|
||||
input_messages.append(msg)
|
||||
elif memory_tool:
|
||||
# if only memory is available, we load the data from the URLs and content items to the memory bank
|
||||
await self.add_to_session_memory_bank(session_id, documents)
|
||||
else:
|
||||
# if no memory or code_interpreter tool is available,
|
||||
# we try to load the data from the URLs and content items as a message to inference
|
||||
# and add it to the last message's context
|
||||
input_messages[-1].context = "\n".join(
|
||||
[doc.content for doc in content_items]
|
||||
+ await load_data_from_urls(url_items)
|
||||
)
|
||||
|
||||
async def _ensure_memory_bank(self, session_id: str) -> str:
|
||||
session_info = await self.storage.get_session_info(session_id)
|
||||
if session_info is None:
|
||||
|
@ -679,129 +849,39 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
|
||||
        return bank_id

-    async def _should_retrieve_context(
-        self, messages: List[Message], attachments: List[Attachment]
-    ) -> bool:
-        enabled_tools = set(t.type for t in self.agent_config.tools)
-        if attachments:
-            if (
-                AgentTool.code_interpreter.value in enabled_tools
-                and self.agent_config.tool_choice == ToolChoice.required
-            ):
-                return False
-            else:
-                return True
-
-        return AgentTool.memory.value in enabled_tools
-
-    def _memory_tool_definition(self) -> Optional[MemoryToolDefinition]:
-        for t in self.agent_config.tools:
-            if t.type == AgentTool.memory.value:
-                return t
-
-        return None
-
-    async def _retrieve_context(
-        self, session_id: str, messages: List[Message], attachments: List[Attachment]
-    ) -> Tuple[Optional[InterleavedContent], List[int]]:  # (rag_context, bank_ids)
-        bank_ids = []
-
-        memory = self._memory_tool_definition()
-        assert memory is not None, "Memory tool not configured"
-        bank_ids.extend(c.bank_id for c in memory.memory_bank_configs)
-
-        if attachments:
-            bank_id = await self._ensure_memory_bank(session_id)
-            bank_ids.append(bank_id)
-
-            documents = [
-                MemoryBankDocument(
-                    document_id=str(uuid.uuid4()),
-                    content=a.content,
-                    mime_type=a.mime_type,
-                    metadata={},
-                )
-                for a in attachments
-            ]
-            with tracing.span("insert_documents"):
-                await self.memory_api.insert_documents(bank_id, documents)
-        else:
-            session_info = await self.storage.get_session_info(session_id)
-            if session_info.memory_bank_id:
-                bank_ids.append(session_info.memory_bank_id)
-
-        if not bank_ids:
-            # this can happen if the per-session memory bank is not yet populated
-            # (i.e., no prior turns uploaded an Attachment)
-            return None, []
-
-        query = await generate_rag_query(
-            memory.query_generator_config, messages, inference_api=self.inference_api
-        )
-        tasks = [
-            self.memory_api.query_documents(
-                bank_id=bank_id,
-                query=query,
-                params={
-                    "max_chunks": 5,
-                },
-            )
-            for bank_id in bank_ids
-        ]
-        results: List[QueryDocumentsResponse] = await asyncio.gather(*tasks)
-        chunks = [c for r in results for c in r.chunks]
-        scores = [s for r in results for s in r.scores]
-
-        if not chunks:
-            return None, bank_ids
-
-        # sort by score
-        chunks, scores = zip(
-            *sorted(zip(chunks, scores), key=lambda x: x[1], reverse=True)
-        )
-
-        tokens = 0
-        picked = []
-        for c in chunks[: memory.max_chunks]:
-            tokens += c.token_count
-            if tokens > memory.max_tokens_in_context:
-                log.error(
-                    f"Using {len(picked)} chunks; reached max tokens in context: {tokens}",
-                )
-                break
-            picked.append(f"id:{c.document_id}; content:{c.content}")
-
-        return (
-            concat_interleaved_content(
-                [
-                    "Here are the retrieved documents for relevant context:\n=== START-RETRIEVED-CONTEXT ===\n",
-                    *picked,
-                    "\n=== END-RETRIEVED-CONTEXT ===\n",
-                ]
-            ),
-            bank_ids,
-        )
-
-    def _get_tools(self) -> List[ToolDefinition]:
-        ret = []
-        for t in self.agent_config.tools:
-            if isinstance(t, SearchToolDefinition):
-                ret.append(ToolDefinition(tool_name=BuiltinTool.brave_search))
-            elif isinstance(t, WolframAlphaToolDefinition):
-                ret.append(ToolDefinition(tool_name=BuiltinTool.wolfram_alpha))
-            elif isinstance(t, PhotogenToolDefinition):
-                ret.append(ToolDefinition(tool_name=BuiltinTool.photogen))
-            elif isinstance(t, CodeInterpreterToolDefinition):
-                ret.append(ToolDefinition(tool_name=BuiltinTool.code_interpreter))
-            elif isinstance(t, FunctionCallToolDefinition):
-                ret.append(
-                    ToolDefinition(
-                        tool_name=t.function_name,
-                        description=t.description,
-                        parameters=t.parameters,
-                    )
-                )
-        return ret
+    async def add_to_session_memory_bank(
+        self, session_id: str, data: List[Document]
+    ) -> None:
+        bank_id = await self._ensure_memory_bank(session_id)
+        documents = [
+            MemoryBankDocument(
+                document_id=str(uuid.uuid4()),
+                content=a.content,
+                mime_type=a.mime_type,
+                metadata={},
+            )
+            for a in data
+        ]
+        await self.memory_api.insert_documents(
+            bank_id=bank_id,
+            documents=documents,
+        )

+async def load_data_from_urls(urls: List[URL]) -> List[str]:
+    data = []
+    for url in urls:
+        uri = url.uri
+        if uri.startswith("file://"):
+            filepath = uri[len("file://") :]
+            with open(filepath, "r") as f:
+                data.append(f.read())
+        elif uri.startswith("http"):
+            async with httpx.AsyncClient() as client:
+                r = await client.get(uri)
+                resp = r.text
+                data.append(resp)
+    return data

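As a usage illustration only (the file path and page URL below are made up), the new `load_data_from_urls` helper can be driven directly:

import asyncio

from llama_stack.apis.common.content_types import URL

# hypothetical inputs: one local file, one web page
urls = [URL(uri="file:///tmp/notes.txt"), URL(uri="https://example.com")]
texts = asyncio.run(load_data_from_urls(urls))
print(len(texts))  # one plain-text string per URL
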
async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessage:

@@ -839,7 +919,11 @@ async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessage:

async def execute_tool_call_maybe(
-    tools_dict: Dict[str, BaseTool], messages: List[CompletionMessage]
+    tool_runtime_api: ToolRuntime,
+    session_id: str,
+    messages: List[CompletionMessage],
+    toolgroup_args: Dict[str, Dict[str, Any]],
+    tool_to_group: Dict[str, str],
) -> List[ToolResponseMessage]:
    # While Tools.run interface takes a list of messages,
    # all tools currently only run on a single message

@@ -851,11 +935,45 @@ async def execute_tool_call_maybe(

    tool_call = message.tool_calls[0]
    name = tool_call.tool_name
-    assert isinstance(name, BuiltinTool)
+    group_name = tool_to_group.get(name, None)
+    if group_name is None:
+        raise ValueError(f"Tool {name} not found in any tool group")
+    # get the arguments generated by the model and augment with toolgroup arg overrides for the agent
+    tool_call_args = tool_call.arguments
+    tool_call_args.update(toolgroup_args.get(group_name, {}))
+    if isinstance(name, BuiltinTool):
+        if name == BuiltinTool.brave_search:
+            name = WEB_SEARCH_TOOL
+        else:
+            name = name.value

-    name = name.value
+    result = await tool_runtime_api.invoke_tool(
+        tool_name=name,
+        args=dict(
+            session_id=session_id,
+            **tool_call_args,
+        ),
+    )

-    assert name in tools_dict, f"Tool {name} not found"
-    tool = tools_dict[name]
-    result_messages = await tool.run(messages)
-    return result_messages
+    return [
+        ToolResponseMessage(
+            call_id=tool_call.call_id,
+            tool_name=tool_call.tool_name,
+            content=result.content,
+        )
+    ]


def _interpret_content_as_attachment(
    content: str,
) -> Optional[Attachment]:
    match = re.search(TOOLS_ATTACHMENT_KEY_REGEX, content)
    if match:
        snippet = match.group(1)
        data = json.loads(snippet)
        return Attachment(
            url=URL(uri="file://" + data["filepath"]),
            mime_type=data["mimetype"],
        )

    return None
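To make the merge step concrete, here is a self-contained sketch of how agent-level toolgroup args override what the model generated; the function name and example values are illustrative, not part of the stack:

from typing import Any, Dict

def merge_tool_args(
    model_args: Dict[str, Any],
    toolgroup_args: Dict[str, Dict[str, Any]],
    group_name: str,
) -> Dict[str, Any]:
    # start from the model-generated arguments, then apply per-group overrides
    merged = dict(model_args)
    merged.update(toolgroup_args.get(group_name, {}))
    return merged

# the model asked for a query; the agent pins the memory bank to search
print(
    merge_tool_args(
        {"query": "llama 3"},
        {"builtin::memory": {"memory_bank_ids": ["bank-1"]}},
        "builtin::memory",
    )
)  # {'query': 'llama 3', 'memory_bank_ids': ['bank-1']}
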
@@ -19,17 +19,17 @@ from llama_stack.apis.agents import (
    Agents,
    AgentSessionCreateResponse,
    AgentStepResponse,
+    AgentToolGroup,
    AgentTurnCreateRequest,
-    Attachment,
+    Document,
    Session,
    Turn,
)

from llama_stack.apis.inference import Inference, ToolResponseMessage, UserMessage
from llama_stack.apis.memory import Memory
from llama_stack.apis.memory_banks import MemoryBanks
from llama_stack.apis.safety import Safety
+from llama_stack.apis.tools import ToolGroups, ToolRuntime
from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl

from .agent_instance import ChatAgent

@@ -47,12 +47,16 @@ class MetaReferenceAgentsImpl(Agents):
        memory_api: Memory,
        safety_api: Safety,
        memory_banks_api: MemoryBanks,
+        tool_runtime_api: ToolRuntime,
+        tool_groups_api: ToolGroups,
    ):
        self.config = config
        self.inference_api = inference_api
        self.memory_api = memory_api
        self.safety_api = safety_api
        self.memory_banks_api = memory_banks_api
+        self.tool_runtime_api = tool_runtime_api
+        self.tool_groups_api = tool_groups_api

        self.in_memory_store = InmemoryKVStoreImpl()
        self.tempdir = tempfile.mkdtemp()

@@ -112,6 +116,8 @@ class MetaReferenceAgentsImpl(Agents):
            safety_api=self.safety_api,
            memory_api=self.memory_api,
            memory_banks_api=self.memory_banks_api,
+            tool_runtime_api=self.tool_runtime_api,
+            tool_groups_api=self.tool_groups_api,
            persistence_store=(
                self.persistence_store
                if agent_config.enable_session_persistence

@@ -141,15 +147,17 @@ class MetaReferenceAgentsImpl(Agents):
                ToolResponseMessage,
            ]
        ],
-        attachments: Optional[List[Attachment]] = None,
+        toolgroups: Optional[List[AgentToolGroup]] = None,
+        documents: Optional[List[Document]] = None,
        stream: Optional[bool] = False,
    ) -> AsyncGenerator:
        request = AgentTurnCreateRequest(
            agent_id=agent_id,
            session_id=session_id,
            messages=messages,
-            attachments=attachments,
            stream=True,
+            toolgroups=toolgroups,
+            documents=documents,
        )
        if stream:
            return self._create_agent_turn_streaming(request)
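For orientation, a caller-side sketch of the new turn shape; the ids are placeholders, and the `Document` fields assume it carries content and mime type much like the old `Attachment`:

request = AgentTurnCreateRequest(
    agent_id="agent-123",        # placeholder
    session_id="session-456",    # placeholder
    messages=[UserMessage(content="Summarize the attached notes")],
    toolgroups=["builtin::memory"],
    documents=[Document(content="Meeting notes ...", mime_type="text/plain")],
    stream=True,
)
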
@@ -8,13 +8,11 @@ import json
import logging
import uuid
from datetime import datetime

from typing import List, Optional

from pydantic import BaseModel

from llama_stack.apis.agents import Turn

from llama_stack.providers.utils.kvstore import KVStore

log = logging.getLogger(__name__)
@@ -1,93 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import unittest

from llama_models.llama3.api.datatypes import (
    Attachment,
    BuiltinTool,
    CompletionMessage,
    StopReason,
    ToolCall,
)

from ..tools.builtin import CodeInterpreterTool


class TestCodeInterpreter(unittest.IsolatedAsyncioTestCase):
    async def test_matplotlib(self):
        tool = CodeInterpreterTool()
        code = """
import matplotlib.pyplot as plt
import numpy as np

x = np.array([1, 1])
y = np.array([0, 10])

plt.plot(x, y)
plt.title('x = 1')
plt.xlabel('x')
plt.ylabel('y')
plt.grid(True)
plt.axvline(x=1, color='r')
plt.show()
"""
        message = CompletionMessage(
            role="assistant",
            content="",
            tool_calls=[
                ToolCall(
                    call_id="call_id",
                    tool_name=BuiltinTool.code_interpreter,
                    arguments={"code": code},
                )
            ],
            stop_reason=StopReason.end_of_message,
        )
        ret = await tool.run([message])

        self.assertEqual(len(ret), 1)

        output = ret[0].content
        self.assertIsInstance(output, Attachment)
        self.assertEqual(output.mime_type, "image/png")

    async def test_path_unlink(self):
        tool = CodeInterpreterTool()
        code = """
import os
from pathlib import Path
import tempfile

dpath = Path(os.environ["MPLCONFIGDIR"])
with open(dpath / "test", "w") as f:
    f.write("hello")

Path(dpath / "test").unlink()
print("_OK_")
"""
        message = CompletionMessage(
            role="assistant",
            content="",
            tool_calls=[
                ToolCall(
                    call_id="call_id",
                    tool_name=BuiltinTool.code_interpreter,
                    arguments={"code": code},
                )
            ],
            stop_reason=StopReason.end_of_message,
        )
        ret = await tool.run([message])

        self.assertEqual(len(ret), 1)

        output = ret[0].content
        self.assertTrue("_OK_" in output)


if __name__ == "__main__":
    unittest.main()
@@ -4,21 +4,26 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import tempfile
from typing import AsyncIterator, List, Optional, Union

import pytest
+from llama_models.llama3.api.datatypes import BuiltinTool

from llama_stack.apis.agents import (
    AgentConfig,
+    AgentToolGroupWithArgs,
    AgentTurnCreateRequest,
    AgentTurnResponseTurnCompletePayload,
+    StepType,
)

+from llama_stack.apis.common.content_types import URL
from llama_stack.apis.inference import (
    ChatCompletionResponse,
    ChatCompletionResponseEvent,
    ChatCompletionResponseStreamChunk,
    CompletionMessage,
    LogProbConfig,
    Message,
    ResponseFormat,
    SamplingParams,

@@ -27,13 +32,24 @@ from llama_stack.apis.inference import (
    UserMessage,
)
+from llama_stack.apis.memory import MemoryBank
+from llama_stack.apis.memory_banks import BankParams, VectorMemoryBank
from llama_stack.apis.safety import RunShieldResponse

-from ..agents import (
-    AGENT_INSTANCES_BY_ID,
-    MetaReferenceAgentsImpl,
-    MetaReferenceInferenceConfig,
+from llama_stack.apis.tools import (
+    Tool,
+    ToolDef,
+    ToolGroup,
+    ToolHost,
+    ToolInvocationResult,
+    ToolPromptFormat,
)
+from llama_stack.providers.inline.agents.meta_reference.agent_instance import (
+    MEMORY_QUERY_TOOL,
+)
+from llama_stack.providers.inline.agents.meta_reference.agents import (
+    MetaReferenceAgentsImpl,
+    MetaReferenceAgentsImplConfig,
+)
+from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig


class MockInferenceAPI:

@@ -48,10 +64,10 @@ class MockInferenceAPI:
        tool_prompt_format: Optional[ToolPromptFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
-    ) -> AsyncIterator[
-        Union[ChatCompletionResponseStreamChunk, ChatCompletionResponse]
+    ) -> Union[
+        ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
    ]:
-        if stream:
+        async def stream_response():
            yield ChatCompletionResponseStreamChunk(
                event=ChatCompletionResponseEvent(
                    event_type="start",

@@ -65,19 +81,7 @@ class MockInferenceAPI:
                    delta="AI is a fascinating field...",
                )
            )
-            # yield ChatCompletionResponseStreamChunk(
-            #     event=ChatCompletionResponseEvent(
-            #         event_type="progress",
-            #         delta=ToolCallDelta(
-            #             content=ToolCall(
-            #                 call_id="123",
-            #                 tool_name=BuiltinTool.brave_search.value,
-            #                 arguments={"query": "AI history"},
-            #             ),
-            #             parse_status="success",
-            #         ),
-            #     )
-            # )

            yield ChatCompletionResponseStreamChunk(
                event=ChatCompletionResponseEvent(
                    event_type="complete",

@@ -85,12 +89,17 @@ class MockInferenceAPI:
                    stop_reason="end_of_turn",
                )
            )

+        if stream:
+            return stream_response()
+        else:
-            yield ChatCompletionResponse(
+            return ChatCompletionResponse(
                completion_message=CompletionMessage(
-                    role="assistant", content="Mock response", stop_reason="end_of_turn"
+                    role="assistant",
+                    content="Mock response",
+                    stop_reason="end_of_turn",
                ),
-                logprobs=[0.1, 0.2, 0.3] if logprobs else None,
+                logprobs={"token_logprobs": [0.1, 0.2, 0.3]} if logprobs else None,
            )

@@ -165,6 +174,98 @@ class MockMemoryAPI:
        self.documents[bank_id].pop(doc_id, None)


class MockToolGroupsAPI:
    async def register_tool_group(
        self, toolgroup_id: str, provider_id: str, mcp_endpoint=None, args=None
    ) -> None:
        pass

    async def get_tool_group(self, toolgroup_id: str) -> ToolGroup:
        return ToolGroup(
            identifier=toolgroup_id,
            provider_resource_id=toolgroup_id,
        )

    async def list_tool_groups(self) -> List[ToolGroup]:
        return []

    async def list_tools(self, tool_group_id: Optional[str] = None) -> List[Tool]:
        if tool_group_id == MEMORY_TOOLGROUP:
            return [
                Tool(
                    identifier=MEMORY_QUERY_TOOL,
                    provider_resource_id=MEMORY_QUERY_TOOL,
                    toolgroup_id=MEMORY_TOOLGROUP,
                    tool_host=ToolHost.client,
                    description="Mock tool",
                    provider_id="builtin::memory",
                    parameters=[],
                )
            ]
        if tool_group_id == CODE_INTERPRETER_TOOLGROUP:
            return [
                Tool(
                    identifier="code_interpreter",
                    provider_resource_id="code_interpreter",
                    toolgroup_id=CODE_INTERPRETER_TOOLGROUP,
                    tool_host=ToolHost.client,
                    description="Mock tool",
                    provider_id="builtin::code_interpreter",
                    parameters=[],
                )
            ]
        return []

    async def get_tool(self, tool_name: str) -> Tool:
        return Tool(
            identifier=tool_name,
            provider_resource_id=tool_name,
            toolgroup_id="mock_group",
            tool_host=ToolHost.client,
            description="Mock tool",
            provider_id="mock_provider",
            parameters=[],
        )

    async def unregister_tool_group(self, tool_group_id: str) -> None:
        pass


class MockToolRuntimeAPI:
    async def list_runtime_tools(
        self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
    ) -> List[ToolDef]:
        return []

    async def invoke_tool(self, tool_name: str, args: dict) -> ToolInvocationResult:
        return ToolInvocationResult(content={"result": "Mock tool result"})


class MockMemoryBanksAPI:
    async def list_memory_banks(self) -> List[MemoryBank]:
        return []

    async def get_memory_bank(self, memory_bank_id: str) -> Optional[MemoryBank]:
        return None

    async def register_memory_bank(
        self,
        memory_bank_id: str,
        params: BankParams,
        provider_id: Optional[str] = None,
        provider_memory_bank_id: Optional[str] = None,
    ) -> MemoryBank:
        return VectorMemoryBank(
            identifier=memory_bank_id,
            provider_resource_id=provider_memory_bank_id or memory_bank_id,
            embedding_model="mock_model",
            chunk_size_in_tokens=512,
        )

    async def unregister_memory_bank(self, memory_bank_id: str) -> None:
        pass


@pytest.fixture
def mock_inference_api():
    return MockInferenceAPI()

@@ -181,64 +282,107 @@ def mock_memory_api():


@pytest.fixture
-async def chat_agent(mock_inference_api, mock_safety_api, mock_memory_api):
+def mock_tool_groups_api():
+    return MockToolGroupsAPI()
+
+
+@pytest.fixture
+def mock_tool_runtime_api():
+    return MockToolRuntimeAPI()
+
+
+@pytest.fixture
+def mock_memory_banks_api():
+    return MockMemoryBanksAPI()
+
+
+@pytest.fixture
+async def get_agents_impl(
+    mock_inference_api,
+    mock_safety_api,
+    mock_memory_api,
+    mock_memory_banks_api,
+    mock_tool_runtime_api,
+    mock_tool_groups_api,
+):
+    sqlite_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
    impl = MetaReferenceAgentsImpl(
-        config=MetaReferenceInferenceConfig(),
+        config=MetaReferenceAgentsImplConfig(
+            persistence_store=SqliteKVStoreConfig(
+                db_name=sqlite_file.name,
+            ),
+        ),
        inference_api=mock_inference_api,
        safety_api=mock_safety_api,
        memory_api=mock_memory_api,
+        memory_banks_api=mock_memory_banks_api,
+        tool_runtime_api=mock_tool_runtime_api,
+        tool_groups_api=mock_tool_groups_api,
    )
    await impl.initialize()
+    return impl
+
+
+@pytest.fixture
+async def get_chat_agent(get_agents_impl):
+    impl = await get_agents_impl
    agent_config = AgentConfig(
        model="test_model",
        instructions="You are a helpful assistant.",
        sampling_params=SamplingParams(),
-        tools=[
-            # SearchToolDefinition(
-            #     name="brave_search",
-            #     api_key="test_key",
-            # ),
-        ],
+        toolgroups=[],
        tool_choice=ToolChoice.auto,
        enable_session_persistence=False,
-        input_shields=[],
-        output_shields=[],
+        input_shields=["test_shield"],
    )
    response = await impl.create_agent(agent_config)
-    agent = AGENT_INSTANCES_BY_ID[response.agent_id]
-    return agent
+    return await impl.get_agent(response.agent_id)


+MEMORY_TOOLGROUP = "builtin::memory"
+CODE_INTERPRETER_TOOLGROUP = "builtin::code_interpreter"
+
+
+@pytest.fixture
+async def get_chat_agent_with_tools(get_agents_impl, request):
+    impl = await get_agents_impl
+    toolgroups = request.param
+    agent_config = AgentConfig(
+        model="test_model",
+        instructions="You are a helpful assistant.",
+        toolgroups=toolgroups,
+        tool_choice=ToolChoice.auto,
+        enable_session_persistence=False,
+        input_shields=["test_shield"],
+    )
+    response = await impl.create_agent(agent_config)
+    return await impl.get_agent(response.agent_id)
+
+
@pytest.mark.asyncio
async def test_chat_agent_create_session(chat_agent):
    session = chat_agent.create_session("Test Session")
    assert session.session_name == "Test Session"
    assert session.turns == []
    assert session.session_id in chat_agent.sessions


@pytest.mark.asyncio
-async def test_chat_agent_create_and_execute_turn(chat_agent):
-    session = chat_agent.create_session("Test Session")
+async def test_chat_agent_create_and_execute_turn(get_chat_agent):
+    chat_agent = await get_chat_agent
+    session_id = await chat_agent.create_session("Test Session")
    request = AgentTurnCreateRequest(
-        agent_id="random",
-        session_id=session.session_id,
+        agent_id=chat_agent.agent_id,
+        session_id=session_id,
        messages=[UserMessage(content="Hello")],
        stream=True,
    )

    responses = []
    async for response in chat_agent.create_and_execute_turn(request):
        responses.append(response)

    print(responses)
    assert len(responses) > 0
-    assert len(responses) == 4  # TurnStart, StepStart, StepComplete, TurnComplete
+    assert (
+        len(responses) == 7
+    )  # TurnStart, ShieldCallStart, ShieldCallComplete, StepStart, StepProgress, StepComplete, TurnComplete
    assert responses[0].event.payload.turn_id is not None


@pytest.mark.asyncio
-async def test_run_multiple_shields_wrapper(chat_agent):
+async def test_run_multiple_shields_wrapper(get_chat_agent):
+    chat_agent = await get_chat_agent
    messages = [UserMessage(content="Test message")]
    shields = ["test_shield"]

@@ -254,69 +398,95 @@ async def test_run_multiple_shields_wrapper(chat_agent):

    assert len(responses) == 2  # StepStart, StepComplete
    assert responses[0].event.payload.step_type.value == "shield_call"
-    assert not responses[1].event.payload.step_details.response.is_violation
+    assert not responses[1].event.payload.step_details.violation


@pytest.mark.asyncio
+@pytest.mark.skip(reason="Not yet implemented; need to mock out tool execution easily")
-async def test_chat_agent_complex_turn(chat_agent):
-    # Setup
-    session = chat_agent.create_session("Test Session")
+async def test_chat_agent_complex_turn(get_chat_agent):
+    chat_agent = await get_chat_agent
+    session_id = await chat_agent.create_session("Test Session")
    request = AgentTurnCreateRequest(
-        agent_id="random",
-        session_id=session.session_id,
+        agent_id=chat_agent.agent_id,
+        session_id=session_id,
        messages=[UserMessage(content="Tell me about AI and then use a tool.")],
        stream=True,
    )

-    # Execute the turn
    responses = []
    async for response in chat_agent.create_and_execute_turn(request):
        responses.append(response)

-    # Assertions
    assert len(responses) > 0

    # Check for the presence of different step types
    step_types = [
        response.event.payload.step_type
        for response in responses
        if hasattr(response.event.payload, "step_type")
    ]

-    assert "shield_call" in step_types, "Shield call step is missing"
-    assert "inference" in step_types, "Inference step is missing"
-    assert "tool_execution" in step_types, "Tool execution step is missing"
+    assert StepType.shield_call in step_types, "Shield call step is missing"
+    assert StepType.inference in step_types, "Inference step is missing"

    # Check for the presence of start and complete events
    event_types = [
        response.event.payload.event_type
        for response in responses
        if hasattr(response.event.payload, "event_type")
    ]
-    assert "start" in event_types, "Start event is missing"
-    assert "complete" in event_types, "Complete event is missing"
+    assert "turn_start" in event_types, "Start event is missing"
+    assert "turn_complete" in event_types, "Complete event is missing"

-    # Check for the presence of tool call
-    tool_calls = [
-        response.event.payload.tool_call
-        for response in responses
-        if hasattr(response.event.payload, "tool_call")
-    ]
-    assert any(
-        tool_call
-        for tool_call in tool_calls
-        if tool_call and tool_call.content.get("name") == "memory"
-    ), "Memory tool call is missing"

    # Check for the final turn complete event
    assert any(
        isinstance(response.event.payload, AgentTurnResponseTurnCompletePayload)
        for response in responses
    ), "Turn complete event is missing"
+    turn_complete_payload = next(
+        response.event.payload
+        for response in responses
+        if isinstance(response.event.payload, AgentTurnResponseTurnCompletePayload)
+    )
+    turn = turn_complete_payload.turn
+    assert turn.input_messages == request.messages, "Input messages do not match"

-    # Verify the turn was added to the session
-    assert len(session.turns) == 1, "Turn was not added to the session"
-    assert (
-        session.turns[0].input_messages == request.messages
-    ), "Input messages do not match"
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "toolgroups, expected_memory, expected_code_interpreter",
+    [
+        ([], False, False),  # no tools
+        ([MEMORY_TOOLGROUP], True, False),  # memory only
+        ([CODE_INTERPRETER_TOOLGROUP], False, True),  # code interpreter only
+        ([MEMORY_TOOLGROUP, CODE_INTERPRETER_TOOLGROUP], True, True),  # all tools
+    ],
+)
+async def test_chat_agent_tools(
+    get_agents_impl, toolgroups, expected_memory, expected_code_interpreter
+):
+    impl = await get_agents_impl
+    agent_config = AgentConfig(
+        model="test_model",
+        instructions="You are a helpful assistant.",
+        toolgroups=toolgroups,
+        tool_choice=ToolChoice.auto,
+        enable_session_persistence=False,
+        input_shields=["test_shield"],
+    )
+    response = await impl.create_agent(agent_config)
+    chat_agent = await impl.get_agent(response.agent_id)
+
+    tool_defs, _ = await chat_agent._get_tool_defs()
+    if expected_memory:
+        assert MEMORY_QUERY_TOOL in tool_defs
+    if expected_code_interpreter:
+        assert BuiltinTool.code_interpreter in tool_defs
+    if expected_memory and expected_code_interpreter:
+        # override the tools for turn
+        new_tool_defs, _ = await chat_agent._get_tool_defs(
+            toolgroups_for_turn=[
+                AgentToolGroupWithArgs(
+                    name=MEMORY_TOOLGROUP,
+                    args={"memory_banks": ["test_memory_bank"]},
+                )
+            ]
+        )
+        assert MEMORY_QUERY_TOOL in new_tool_defs
+        assert BuiltinTool.code_interpreter not in new_tool_defs
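Because `get_chat_agent_with_tools` reads `request.param`, other tests can feed it toolgroups through pytest's indirect parametrization. A hedged sketch of such a consumer (the test name and assertion are illustrative, not from this PR):

@pytest.mark.asyncio
@pytest.mark.parametrize(
    "get_chat_agent_with_tools",
    [[MEMORY_TOOLGROUP], [CODE_INTERPRETER_TOOLGROUP]],
    indirect=True,
)
async def test_agent_with_selected_tools(get_chat_agent_with_tools):
    chat_agent = await get_chat_agent_with_tools
    # illustrative: the agent should at least have been constructed
    assert chat_agent is not None
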
@@ -1,20 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from abc import ABC, abstractmethod
from typing import List

from llama_stack.apis.inference import Message


class BaseTool(ABC):
    @abstractmethod
    def get_name(self) -> str:
        raise NotImplementedError

    @abstractmethod
    async def run(self, messages: List[Message]) -> List[Message]:
        raise NotImplementedError

@@ -1,396 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import json
import logging
import re
import tempfile

from abc import abstractmethod
from typing import List, Optional

import requests

from .ipython_tool.code_execution import (
    CodeExecutionContext,
    CodeExecutionRequest,
    CodeExecutor,
    TOOLS_ATTACHMENT_KEY_REGEX,
)

from llama_stack.apis.inference import *  # noqa: F403
from llama_stack.apis.agents import *  # noqa: F403

from .base import BaseTool


log = logging.getLogger(__name__)


def interpret_content_as_attachment(content: str) -> Optional[Attachment]:
    match = re.search(TOOLS_ATTACHMENT_KEY_REGEX, content)
    if match:
        snippet = match.group(1)
        data = json.loads(snippet)
        return Attachment(
            url=URL(uri="file://" + data["filepath"]), mime_type=data["mimetype"]
        )

    return None


class SingleMessageBuiltinTool(BaseTool):
    async def run(self, messages: List[CompletionMessage]) -> List[ToolResponseMessage]:
        assert len(messages) == 1, f"Expected single message, got {len(messages)}"

        message = messages[0]
        assert len(message.tool_calls) == 1, "Expected a single tool call"

        tool_call = messages[0].tool_calls[0]

        query = tool_call.arguments["query"]
        response: str = await self.run_impl(query)

        message = ToolResponseMessage(
            call_id=tool_call.call_id,
            tool_name=tool_call.tool_name,
            content=response,
        )
        return [message]

    @abstractmethod
    async def run_impl(self, query: str) -> str:
        raise NotImplementedError()


class PhotogenTool(SingleMessageBuiltinTool):
    def __init__(self, dump_dir: str) -> None:
        self.dump_dir = dump_dir

    def get_name(self) -> str:
        return BuiltinTool.photogen.value

    async def run_impl(self, query: str) -> str:
        """
        Implement this to give the model an ability to generate images.

        Return:
            info = {
                "filepath": str(image_filepath),
                "mimetype": "image/png",
            }
        """
        raise NotImplementedError()


class SearchTool(SingleMessageBuiltinTool):
    def __init__(self, engine: SearchEngineType, api_key: str, **kwargs) -> None:
        self.api_key = api_key
        self.engine_type = engine
        if engine == SearchEngineType.bing:
            self.engine = BingSearch(api_key, **kwargs)
        elif engine == SearchEngineType.brave:
            self.engine = BraveSearch(api_key, **kwargs)
        elif engine == SearchEngineType.tavily:
            self.engine = TavilySearch(api_key, **kwargs)
        else:
            raise ValueError(f"Unknown search engine: {engine}")

    def get_name(self) -> str:
        return BuiltinTool.brave_search.value

    async def run_impl(self, query: str) -> str:
        return await self.engine.search(query)


class BingSearch:
    def __init__(self, api_key: str, top_k: int = 3, **kwargs) -> None:
        self.api_key = api_key
        self.top_k = top_k

    async def search(self, query: str) -> str:
        url = "https://api.bing.microsoft.com/v7.0/search"
        headers = {
            "Ocp-Apim-Subscription-Key": self.api_key,
        }
        params = {
            "count": self.top_k,
            "textDecorations": True,
            "textFormat": "HTML",
            "q": query,
        }

        response = requests.get(url=url, params=params, headers=headers)
        response.raise_for_status()
        clean = self._clean_response(response.json())
        return json.dumps(clean)

    def _clean_response(self, search_response):
        clean_response = []
        query = search_response["queryContext"]["originalQuery"]
        if "webPages" in search_response:
            pages = search_response["webPages"]["value"]
            for p in pages:
                selected_keys = {"name", "url", "snippet"}
                clean_response.append(
                    {k: v for k, v in p.items() if k in selected_keys}
                )
        if "news" in search_response:
            clean_news = []
            news = search_response["news"]["value"]
            for n in news:
                selected_keys = {"name", "url", "description"}
                clean_news.append({k: v for k, v in n.items() if k in selected_keys})

            clean_response.append(clean_news)

        return {"query": query, "top_k": clean_response}


class BraveSearch:
    def __init__(self, api_key: str) -> None:
        self.api_key = api_key

    async def search(self, query: str) -> str:
        url = "https://api.search.brave.com/res/v1/web/search"
        headers = {
            "X-Subscription-Token": self.api_key,
            "Accept-Encoding": "gzip",
            "Accept": "application/json",
        }
        payload = {"q": query}
        response = requests.get(url=url, params=payload, headers=headers)
        return json.dumps(self._clean_brave_response(response.json()))

    def _clean_brave_response(self, search_response, top_k=3):
        query = None
        clean_response = []
        if "query" in search_response:
            if "original" in search_response["query"]:
                query = search_response["query"]["original"]
        if "mixed" in search_response:
            mixed_results = search_response["mixed"]
            for m in mixed_results["main"][:top_k]:
                r_type = m["type"]
                results = search_response[r_type]["results"]
                if r_type == "web":
                    # For web data - add a single output from the search
                    idx = m["index"]
                    selected_keys = [
                        "type",
                        "title",
                        "url",
                        "description",
                        "date",
                        "extra_snippets",
                    ]
                    cleaned = {
                        k: v for k, v in results[idx].items() if k in selected_keys
                    }
                elif r_type == "faq":
                    # For faq data - take a list of all the questions & answers
                    selected_keys = ["type", "question", "answer", "title", "url"]
                    cleaned = []
                    for q in results:
                        cleaned.append(
                            {k: v for k, v in q.items() if k in selected_keys}
                        )
                elif r_type == "infobox":
                    idx = m["index"]
                    selected_keys = [
                        "type",
                        "title",
                        "url",
                        "description",
                        "long_desc",
                    ]
                    cleaned = {
                        k: v for k, v in results[idx].items() if k in selected_keys
                    }
                elif r_type == "videos":
                    selected_keys = [
                        "type",
                        "url",
                        "title",
                        "description",
                        "date",
                    ]
                    cleaned = []
                    for q in results:
                        cleaned.append(
                            {k: v for k, v in q.items() if k in selected_keys}
                        )
                elif r_type == "locations":
                    # For locations data - take a list of all the results
                    selected_keys = [
                        "type",
                        "title",
                        "url",
                        "description",
                        "coordinates",
                        "postal_address",
                        "contact",
                        "rating",
                        "distance",
                        "zoom_level",
                    ]
                    cleaned = []
                    for q in results:
                        cleaned.append(
                            {k: v for k, v in q.items() if k in selected_keys}
                        )
                elif r_type == "news":
                    # For news data - take a list of all the results
                    selected_keys = [
                        "type",
                        "title",
                        "url",
                        "description",
                    ]
                    cleaned = []
                    for q in results:
                        cleaned.append(
                            {k: v for k, v in q.items() if k in selected_keys}
                        )
                else:
                    cleaned = []

                clean_response.append(cleaned)

        return {"query": query, "top_k": clean_response}


class TavilySearch:
    def __init__(self, api_key: str) -> None:
        self.api_key = api_key

    async def search(self, query: str) -> str:
        response = requests.post(
            "https://api.tavily.com/search",
            json={"api_key": self.api_key, "query": query},
        )
        return json.dumps(self._clean_tavily_response(response.json()))

    def _clean_tavily_response(self, search_response, top_k=3):
        return {"query": search_response["query"], "top_k": search_response["results"]}


class WolframAlphaTool(SingleMessageBuiltinTool):
    def __init__(self, api_key: str) -> None:
        self.api_key = api_key
        self.url = "https://api.wolframalpha.com/v2/query"

    def get_name(self) -> str:
        return BuiltinTool.wolfram_alpha.value

    async def run_impl(self, query: str) -> str:
        params = {
            "input": query,
            "appid": self.api_key,
            "format": "plaintext",
            "output": "json",
        }
        response = requests.get(
            self.url,
            params=params,
        )

        return json.dumps(self._clean_wolfram_alpha_response(response.json()))

    def _clean_wolfram_alpha_response(self, wa_response):
        remove = {
            "queryresult": [
                "datatypes",
                "error",
                "timedout",
                "timedoutpods",
                "numpods",
                "timing",
                "parsetiming",
                "parsetimedout",
                "recalculate",
                "id",
                "host",
                "server",
                "related",
                "version",
                {
                    "pods": [
                        "scanner",
                        "id",
                        "error",
                        "expressiontypes",
                        "states",
                        "infos",
                        "position",
                        "numsubpods",
                    ]
                },
                "assumptions",
            ],
        }
        for main_key in remove:
            for key_to_remove in remove[main_key]:
                try:
                    if key_to_remove == "assumptions":
                        if "assumptions" in wa_response[main_key]:
                            del wa_response[main_key][key_to_remove]
                    if isinstance(key_to_remove, dict):
                        for sub_key in key_to_remove:
                            if sub_key == "pods":
                                for i in range(len(wa_response[main_key][sub_key])):
                                    if (
                                        wa_response[main_key][sub_key][i]["title"]
                                        == "Result"
                                    ):
                                        del wa_response[main_key][sub_key][i + 1 :]
                                        break
                            sub_items = wa_response[main_key][sub_key]
                            for i in range(len(sub_items)):
                                for sub_key_to_remove in key_to_remove[sub_key]:
                                    if sub_key_to_remove in sub_items[i]:
                                        del sub_items[i][sub_key_to_remove]
                    elif key_to_remove in wa_response[main_key]:
                        del wa_response[main_key][key_to_remove]
                except KeyError:
                    pass
        return wa_response


class CodeInterpreterTool(BaseTool):
    def __init__(self) -> None:
        ctx = CodeExecutionContext(
            matplotlib_dump_dir=tempfile.mkdtemp(),
        )
        self.code_executor = CodeExecutor(ctx)

    def get_name(self) -> str:
        return BuiltinTool.code_interpreter.value

    async def run(self, messages: List[CompletionMessage]) -> List[ToolResponseMessage]:
        message = messages[0]
        assert len(message.tool_calls) == 1, "Expected a single tool call"

        tool_call = messages[0].tool_calls[0]
        script = tool_call.arguments["code"]

        req = CodeExecutionRequest(scripts=[script])
        res = self.code_executor.execute(req)

        pieces = [res["process_status"]]
        for out_type in ["stdout", "stderr"]:
            res_out = res[out_type]
            if res_out != "":
                pieces.extend([f"[{out_type}]", res_out, f"[/{out_type}]"])
                if out_type == "stderr":
                    log.error(f"ipython tool error: ↓\n{res_out}")

        message = ToolResponseMessage(
            call_id=tool_call.call_id,
            tool_name=tool_call.tool_name,
            content="\n".join(pieces),
        )
        return [message]

@@ -1,42 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import List

from llama_stack.apis.inference import Message
from llama_stack.apis.safety import Safety

from ..safety import ShieldRunnerMixin
from .builtin import BaseTool


class SafeTool(BaseTool, ShieldRunnerMixin):
    """A tool that makes other tools safety enabled"""

    def __init__(
        self,
        tool: BaseTool,
        safety_api: Safety,
        input_shields: List[str] = None,
        output_shields: List[str] = None,
    ):
        self._tool = tool
        ShieldRunnerMixin.__init__(
            self, safety_api, input_shields=input_shields, output_shields=output_shields
        )

    def get_name(self) -> str:
        return self._tool.get_name()

    async def run(self, messages: List[Message]) -> List[Message]:
        if self.input_shields:
            await self.run_multiple_shields(messages, self.input_shields)
        # run the underlying tool
        res = await self._tool.run(messages)
        if self.output_shields:
            await self.run_multiple_shields(messages, self.output_shields)

        return res
@@ -0,0 +1,16 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .code_interpreter import CodeInterpreterToolRuntimeImpl
from .config import CodeInterpreterToolConfig

__all__ = ["CodeInterpreterToolConfig", "CodeInterpreterToolRuntimeImpl"]


async def get_provider_impl(config: CodeInterpreterToolConfig, _deps):
    impl = CodeInterpreterToolRuntimeImpl(config)
    await impl.initialize()
    return impl

@@ -0,0 +1,75 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.


import logging
import tempfile
from typing import Any, Dict, List, Optional

from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import (
    Tool,
    ToolDef,
    ToolInvocationResult,
    ToolParameter,
    ToolRuntime,
)
from llama_stack.providers.datatypes import ToolsProtocolPrivate

from .code_execution import CodeExecutionContext, CodeExecutionRequest, CodeExecutor
from .config import CodeInterpreterToolConfig

log = logging.getLogger(__name__)


class CodeInterpreterToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
    def __init__(self, config: CodeInterpreterToolConfig):
        self.config = config
        ctx = CodeExecutionContext(
            matplotlib_dump_dir=tempfile.mkdtemp(),
        )
        self.code_executor = CodeExecutor(ctx)

    async def initialize(self):
        pass

    async def register_tool(self, tool: Tool):
        pass

    async def unregister_tool(self, tool_id: str) -> None:
        return

    async def list_runtime_tools(
        self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
    ) -> List[ToolDef]:
        return [
            ToolDef(
                name="code_interpreter",
                description="Execute code",
                parameters=[
                    ToolParameter(
                        name="code",
                        description="The code to execute",
                        parameter_type="string",
                    ),
                ],
            )
        ]

    async def invoke_tool(
        self, tool_name: str, args: Dict[str, Any]
    ) -> ToolInvocationResult:
        script = args["code"]
        req = CodeExecutionRequest(scripts=[script])
        res = self.code_executor.execute(req)
        pieces = [res["process_status"]]
        for out_type in ["stdout", "stderr"]:
            res_out = res[out_type]
            if res_out != "":
                pieces.extend([f"[{out_type}]", res_out, f"[/{out_type}]"])
                if out_type == "stderr":
                    log.error(f"ipython tool error: ↓\n{res_out}")
        return ToolInvocationResult(content="\n".join(pieces))
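A minimal driver for this runtime, assuming only the classes defined above and no server wiring:

import asyncio

async def main():
    runtime = CodeInterpreterToolRuntimeImpl(CodeInterpreterToolConfig())
    await runtime.initialize()
    result = await runtime.invoke_tool(
        tool_name="code_interpreter",
        args={"code": "print(2 + 2)"},
    )
    # content is the process status followed by tagged stdout/stderr sections
    print(result.content)

asyncio.run(main())
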
@@ -3,3 +3,9 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from pydantic import BaseModel


class CodeInterpreterToolConfig(BaseModel):
    pass

llama_stack/providers/inline/tool_runtime/memory/__init__.py (new file, 20 lines)

@@ -0,0 +1,20 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any, Dict

from llama_stack.providers.datatypes import Api

from .config import MemoryToolRuntimeConfig
from .memory import MemoryToolRuntimeImpl


async def get_provider_impl(config: MemoryToolRuntimeConfig, deps: Dict[str, Any]):
    impl = MemoryToolRuntimeImpl(
        config, deps[Api.memory], deps[Api.memory_banks], deps[Api.inference]
    )
    await impl.initialize()
    return impl

llama_stack/providers/inline/tool_runtime/memory/config.py (new file, 90 lines)

@@ -0,0 +1,90 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from enum import Enum
from typing import Annotated, List, Literal, Union

from pydantic import BaseModel, Field


class _MemoryBankConfigCommon(BaseModel):
    bank_id: str


class VectorMemoryBankConfig(_MemoryBankConfigCommon):
    type: Literal["vector"] = "vector"


class KeyValueMemoryBankConfig(_MemoryBankConfigCommon):
    type: Literal["keyvalue"] = "keyvalue"
    keys: List[str]  # what keys to focus on


class KeywordMemoryBankConfig(_MemoryBankConfigCommon):
    type: Literal["keyword"] = "keyword"


class GraphMemoryBankConfig(_MemoryBankConfigCommon):
    type: Literal["graph"] = "graph"
    entities: List[str]  # what entities to focus on


MemoryBankConfig = Annotated[
    Union[
        VectorMemoryBankConfig,
        KeyValueMemoryBankConfig,
        KeywordMemoryBankConfig,
        GraphMemoryBankConfig,
    ],
    Field(discriminator="type"),
]


class MemoryQueryGenerator(Enum):
    default = "default"
    llm = "llm"
    custom = "custom"


class DefaultMemoryQueryGeneratorConfig(BaseModel):
    type: Literal[MemoryQueryGenerator.default.value] = (
        MemoryQueryGenerator.default.value
    )
    sep: str = " "


class LLMMemoryQueryGeneratorConfig(BaseModel):
    type: Literal[MemoryQueryGenerator.llm.value] = MemoryQueryGenerator.llm.value
    model: str
    template: str


class CustomMemoryQueryGeneratorConfig(BaseModel):
    type: Literal[MemoryQueryGenerator.custom.value] = MemoryQueryGenerator.custom.value


MemoryQueryGeneratorConfig = Annotated[
    Union[
        DefaultMemoryQueryGeneratorConfig,
        LLMMemoryQueryGeneratorConfig,
        CustomMemoryQueryGeneratorConfig,
    ],
    Field(discriminator="type"),
]


class MemoryToolConfig(BaseModel):
    memory_bank_configs: List[MemoryBankConfig] = Field(default_factory=list)


class MemoryToolRuntimeConfig(BaseModel):
    # This config defines how a query is generated using the messages
    # for memory bank retrieval.
    query_generator_config: MemoryQueryGeneratorConfig = Field(
        default=DefaultMemoryQueryGeneratorConfig()
    )
    max_tokens_in_context: int = 4096
    max_chunks: int = 5
@@ -4,25 +4,29 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.


from typing import List

from jinja2 import Template
from pydantic import BaseModel

-from llama_stack.apis.agents import (
+from llama_stack.apis.common.content_types import InterleavedContent
+from llama_stack.apis.inference import UserMessage
+from llama_stack.providers.utils.inference.prompt_adapter import (
+    interleaved_content_as_str,
+)
+
+from .config import (
    DefaultMemoryQueryGeneratorConfig,
    LLMMemoryQueryGeneratorConfig,
    MemoryQueryGenerator,
    MemoryQueryGeneratorConfig,
)
-from llama_stack.apis.inference import Message, UserMessage
-from llama_stack.providers.utils.inference.prompt_adapter import (
-    interleaved_content_as_str,
-)


async def generate_rag_query(
    config: MemoryQueryGeneratorConfig,
-    messages: List[Message],
+    messages: List[InterleavedContent],
    **kwargs,
):
    """

@@ -40,21 +44,26 @@ async def generate_rag_query(

async def default_rag_query_generator(
    config: DefaultMemoryQueryGeneratorConfig,
-    messages: List[Message],
+    messages: List[InterleavedContent],
    **kwargs,
):
-    return config.sep.join(interleaved_content_as_str(m.content) for m in messages)
+    return config.sep.join(interleaved_content_as_str(m) for m in messages)


async def llm_rag_query_generator(
    config: LLMMemoryQueryGeneratorConfig,
-    messages: List[Message],
+    messages: List[InterleavedContent],
    **kwargs,
):
    assert "inference_api" in kwargs, "LLMRAGQueryGenerator needs inference_api"
    inference_api = kwargs["inference_api"]

-    m_dict = {"messages": [m.model_dump() for m in messages]}
+    m_dict = {
+        "messages": [
+            message.model_dump() if isinstance(message, BaseModel) else message
+            for message in messages
+        ]
+    }

    template = Template(config.template)
    content = template.render(m_dict)

llama_stack/providers/inline/tool_runtime/memory/memory.py (new file, 146 lines)

@@ -0,0 +1,146 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import asyncio
import logging
import secrets
import string
from typing import Any, Dict, List, Optional

from llama_stack.apis.common.content_types import URL
from llama_stack.apis.inference import Inference, InterleavedContent
from llama_stack.apis.memory import Memory, QueryDocumentsResponse
from llama_stack.apis.memory_banks import MemoryBanks
from llama_stack.apis.tools import (
    ToolDef,
    ToolInvocationResult,
    ToolParameter,
    ToolRuntime,
)
from llama_stack.providers.datatypes import ToolsProtocolPrivate
from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content

from .config import MemoryToolConfig, MemoryToolRuntimeConfig
from .context_retriever import generate_rag_query

log = logging.getLogger(__name__)


def make_random_string(length: int = 8):
    return "".join(
        secrets.choice(string.ascii_letters + string.digits) for _ in range(length)
    )


class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
    def __init__(
        self,
        config: MemoryToolRuntimeConfig,
        memory_api: Memory,
        memory_banks_api: MemoryBanks,
        inference_api: Inference,
    ):
        self.config = config
        self.memory_api = memory_api
        self.memory_banks_api = memory_banks_api
        self.inference_api = inference_api

    async def initialize(self):
        pass

    async def list_runtime_tools(
        self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
    ) -> List[ToolDef]:
        return [
            ToolDef(
                name="query_memory",
                description="Retrieve context from memory",
                parameters=[
                    ToolParameter(
                        name="messages",
                        description="The input messages to search for",
                        parameter_type="array",
                    ),
                ],
            )
        ]

    async def _retrieve_context(
        self, input_messages: List[InterleavedContent], bank_ids: List[str]
    ) -> Optional[List[InterleavedContent]]:
        if not bank_ids:
            return None
        query = await generate_rag_query(
            self.config.query_generator_config,
            input_messages,
            inference_api=self.inference_api,
        )
        tasks = [
            self.memory_api.query_documents(
                bank_id=bank_id,
                query=query,
                params={
                    "max_chunks": self.config.max_chunks,
                },
            )
            for bank_id in bank_ids
        ]
        results: List[QueryDocumentsResponse] = await asyncio.gather(*tasks)
        chunks = [c for r in results for c in r.chunks]
        scores = [s for r in results for s in r.scores]

        if not chunks:
            return None

        # sort by score
        chunks, scores = zip(
            *sorted(zip(chunks, scores), key=lambda x: x[1], reverse=True)
        )

        tokens = 0
        picked = []
        for c in chunks[: self.config.max_chunks]:
            tokens += c.token_count
            if tokens > self.config.max_tokens_in_context:
                log.error(
                    f"Using {len(picked)} chunks; reached max tokens in context: {tokens}",
                )
                break
            picked.append(f"id:{c.document_id}; content:{c.content}")

        return [
            "Here are the retrieved documents for relevant context:\n=== START-RETRIEVED-CONTEXT ===\n",
            *picked,
            "\n=== END-RETRIEVED-CONTEXT ===\n",
        ]

    async def invoke_tool(
        self, tool_name: str, args: Dict[str, Any]
    ) -> ToolInvocationResult:
        tool = await self.tool_store.get_tool(tool_name)
        tool_group = await self.tool_store.get_tool_group(tool.toolgroup_id)
        final_args = tool_group.args or {}
        final_args.update(args)
        config = MemoryToolConfig()
        if tool.metadata and tool.metadata.get("config") is not None:
            config = MemoryToolConfig(**tool.metadata["config"])
        if "memory_bank_ids" in final_args:
            bank_ids = final_args["memory_bank_ids"]
        else:
            bank_ids = [
                bank_config.bank_id for bank_config in config.memory_bank_configs
            ]
        if "messages" not in final_args:
            raise ValueError("messages are required")
        context = await self._retrieve_context(
            final_args["messages"],
            bank_ids,
        )
        if context is None:
            context = []
        return ToolInvocationResult(
            content=concat_interleaved_content(context), error_code=0
        )

@@ -35,6 +35,8 @@ def available_providers() -> List[ProviderSpec]:
                Api.safety,
                Api.memory,
                Api.memory_banks,
+                Api.tool_runtime,
+                Api.tool_groups,
            ],
        ),
        remote_provider_spec(

@@ -19,11 +19,58 @@ def available_providers() -> List[ProviderSpec]:
    return [
        InlineProviderSpec(
            api=Api.tool_runtime,
-            provider_type="inline::brave-search",
+            provider_type="inline::memory-runtime",
            pip_packages=[],
-            module="llama_stack.providers.inline.tool_runtime.brave_search",
-            config_class="llama_stack.providers.inline.tool_runtime.brave_search.config.BraveSearchToolConfig",
-            provider_data_validator="llama_stack.providers.inline.tool_runtime.brave_search.BraveSearchToolProviderDataValidator",
+            module="llama_stack.providers.inline.tool_runtime.memory",
+            config_class="llama_stack.providers.inline.tool_runtime.memory.config.MemoryToolRuntimeConfig",
+            api_dependencies=[Api.memory, Api.memory_banks, Api.inference],
        ),
+        InlineProviderSpec(
+            api=Api.tool_runtime,
+            provider_type="inline::code-interpreter",
+            pip_packages=[],
+            module="llama_stack.providers.inline.tool_runtime.code_interpreter",
+            config_class="llama_stack.providers.inline.tool_runtime.code_interpreter.config.CodeInterpreterToolConfig",
+        ),
+        remote_provider_spec(
+            api=Api.tool_runtime,
+            adapter=AdapterSpec(
+                adapter_type="brave-search",
+                module="llama_stack.providers.remote.tool_runtime.brave_search",
+                config_class="llama_stack.providers.remote.tool_runtime.brave_search.config.BraveSearchToolConfig",
+                pip_packages=["requests"],
+                provider_data_validator="llama_stack.providers.remote.tool_runtime.brave_search.BraveSearchToolProviderDataValidator",
+            ),
+        ),
+        remote_provider_spec(
+            api=Api.tool_runtime,
+            adapter=AdapterSpec(
+                adapter_type="bing-search",
+                module="llama_stack.providers.remote.tool_runtime.bing_search",
+                config_class="llama_stack.providers.remote.tool_runtime.bing_search.config.BingSearchToolConfig",
+                pip_packages=["requests"],
+                provider_data_validator="llama_stack.providers.remote.tool_runtime.bing_search.BingSearchToolProviderDataValidator",
+            ),
+        ),
+        remote_provider_spec(
+            api=Api.tool_runtime,
+            adapter=AdapterSpec(
+                adapter_type="tavily-search",
+                module="llama_stack.providers.remote.tool_runtime.tavily_search",
+                config_class="llama_stack.providers.remote.tool_runtime.tavily_search.config.TavilySearchToolConfig",
+                pip_packages=["requests"],
+                provider_data_validator="llama_stack.providers.remote.tool_runtime.tavily_search.TavilySearchToolProviderDataValidator",
+            ),
+        ),
+        remote_provider_spec(
+            api=Api.tool_runtime,
+            adapter=AdapterSpec(
+                adapter_type="wolfram-alpha",
+                module="llama_stack.providers.remote.tool_runtime.wolfram_alpha",
+                config_class="llama_stack.providers.remote.tool_runtime.wolfram_alpha.config.WolframAlphaToolConfig",
+                pip_packages=["requests"],
+                provider_data_validator="llama_stack.providers.remote.tool_runtime.wolfram_alpha.WolframAlphaToolProviderDataValidator",
+            ),
+        ),
        remote_provider_spec(
            api=Api.tool_runtime,
|
|
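These provider specs are what a distribution's tool_runtime section resolves against; the run configs later in this diff bind them to builtin:: tool group ids. A hedged sketch of declaring such a binding in Python, using the same ToolGroupInput shape the fixtures and templates below rely on (per the PR description, groups can also carry args that get injected into every tool call; whether a group consumes them depends on the tool):

# Hedged sketch: bind the builtin web_search tool group to the tavily-search
# provider. ToolGroupInput mirrors the fixtures later in this diff.
from llama_stack.apis.tools import ToolGroupInput

web_search_group = ToolGroupInput(
    toolgroup_id="builtin::web_search",
    provider_id="tavily-search",
)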
@@ -7,11 +7,8 @@
 from typing import AsyncGenerator, List, Optional, Union

 from llama_models.datatypes import CoreModelId
-
 from llama_models.llama3.api.chat_format import ChatFormat
-
 from llama_models.llama3.api.tokenizer import Tokenizer
-
 from together import Together

 from llama_stack.apis.common.content_types import InterleavedContent

@@ -53,7 +50,6 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
 from .config import TogetherImplConfig

-
 MODEL_ALIASES = [
     build_model_alias(
         "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
llama_stack/providers/remote/tool_runtime/bing_search/__init__.py (new file)
@@ -0,0 +1,21 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .bing_search import BingSearchToolRuntimeImpl
from .config import BingSearchToolConfig

__all__ = ["BingSearchToolConfig", "BingSearchToolRuntimeImpl"]
from pydantic import BaseModel


class BingSearchToolProviderDataValidator(BaseModel):
    api_key: str


async def get_adapter_impl(config: BingSearchToolConfig, _deps):
    impl = BingSearchToolRuntimeImpl(config)
    await impl.initialize()
    return impl
llama_stack/providers/remote/tool_runtime/bing_search/bing_search.py (new file)
@@ -0,0 +1,114 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import json
from typing import Any, Dict, List, Optional

import requests

from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import (
    Tool,
    ToolDef,
    ToolInvocationResult,
    ToolParameter,
    ToolRuntime,
)
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.datatypes import ToolsProtocolPrivate

from .config import BingSearchToolConfig


class BingSearchToolRuntimeImpl(
    ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData
):
    def __init__(self, config: BingSearchToolConfig):
        self.config = config
        self.url = "https://api.bing.microsoft.com/v7.0/search"

    async def initialize(self):
        pass

    async def register_tool(self, tool: Tool):
        pass

    async def unregister_tool(self, tool_id: str) -> None:
        return

    def _get_api_key(self) -> str:
        if self.config.api_key:
            return self.config.api_key

        provider_data = self.get_request_provider_data()
        if provider_data is None or not provider_data.api_key:
            raise ValueError(
                'Pass Bing Search API Key in the header X-LlamaStack-ProviderData as { "api_key": <your api key>}'
            )
        return provider_data.api_key

    async def list_runtime_tools(
        self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
    ) -> List[ToolDef]:
        return [
            ToolDef(
                name="web_search",
                description="Search the web using Bing Search API",
                parameters=[
                    ToolParameter(
                        name="query",
                        description="The query to search for",
                        parameter_type="string",
                    )
                ],
            )
        ]

    async def invoke_tool(
        self, tool_name: str, args: Dict[str, Any]
    ) -> ToolInvocationResult:
        api_key = self._get_api_key()
        headers = {
            "Ocp-Apim-Subscription-Key": api_key,
        }
        params = {
            "count": self.config.top_k,
            "textDecorations": True,
            "textFormat": "HTML",
            "q": args["query"],
        }

        response = requests.get(
            url=self.url,
            params=params,
            headers=headers,
        )
        response.raise_for_status()

        return ToolInvocationResult(
            content=json.dumps(self._clean_response(response.json()))
        )

    def _clean_response(self, search_response):
        clean_response = []
        query = search_response["queryContext"]["originalQuery"]
        if "webPages" in search_response:
            pages = search_response["webPages"]["value"]
            for p in pages:
                selected_keys = {"name", "url", "snippet"}
                clean_response.append(
                    {k: v for k, v in p.items() if k in selected_keys}
                )
        if "news" in search_response:
            clean_news = []
            news = search_response["news"]["value"]
            for n in news:
                selected_keys = {"name", "url", "description"}
                clean_news.append({k: v for k, v in n.items() if k in selected_keys})

            clean_response.append(clean_news)

        return {"query": query, "top_k": clean_response}
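The search runtimes accept their API key either from config or per request via the X-LlamaStack-ProviderData header, as the error message in _get_api_key spells out. A hedged sketch of supplying it per request against a running stack; the HTTP route and port here are illustrative, not a verified endpoint:

# Hedged sketch: pass the provider API key per request instead of baking it
# into the run config. The header name comes from the error message above;
# the exact route is an assumption for illustration.
import json
import requests

headers = {"X-LlamaStack-ProviderData": json.dumps({"api_key": "<your api key>"})}
resp = requests.post(
    "http://localhost:5000/tool-runtime/invoke",  # illustrative route
    headers=headers,
    json={"tool_name": "web_search", "args": {"query": "llama stack"}},
)
print(resp.json())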
llama_stack/providers/remote/tool_runtime/bing_search/config.py (new file)
@@ -0,0 +1,16 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Optional

from pydantic import BaseModel


class BingSearchToolConfig(BaseModel):
    """Configuration for Bing Search Tool Runtime"""

    api_key: Optional[str] = None
    top_k: int = 3
@@ -14,7 +14,7 @@ class BraveSearchToolProviderDataValidator(BaseModel):
     api_key: str


-async def get_provider_impl(config: BraveSearchToolConfig, _deps):
+async def get_adapter_impl(config: BraveSearchToolConfig, _deps):
     impl = BraveSearchToolRuntimeImpl(config)
     await impl.initialize()
     return impl
@@ -4,11 +4,19 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Any, Dict, List
+from typing import Any, Dict, List, Optional

 import requests
 from llama_models.llama3.api.datatypes import BuiltinTool

-from llama_stack.apis.tools import Tool, ToolGroupDef, ToolInvocationResult, ToolRuntime
+from llama_stack.apis.common.content_types import URL
+from llama_stack.apis.tools import (
+    Tool,
+    ToolDef,
+    ToolInvocationResult,
+    ToolParameter,
+    ToolRuntime,
+)
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
 from llama_stack.providers.datatypes import ToolsProtocolPrivate

@@ -25,8 +33,7 @@ class BraveSearchToolRuntimeImpl(
         pass

     async def register_tool(self, tool: Tool):
-        if tool.identifier != "brave_search":
-            raise ValueError(f"Tool identifier {tool.identifier} is not supported")
+        pass

     async def unregister_tool(self, tool_id: str) -> None:
         return

@@ -42,8 +49,23 @@ class BraveSearchToolRuntimeImpl(
             )
         return provider_data.api_key

-    async def discover_tools(self, tool_group: ToolGroupDef) -> List[Tool]:
-        raise NotImplementedError("Brave search tool group not supported")
+    async def list_runtime_tools(
+        self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
+    ) -> List[ToolDef]:
+        return [
+            ToolDef(
+                name="web_search",
+                description="Search the web for information",
+                parameters=[
+                    ToolParameter(
+                        name="query",
+                        description="The query to search for",
+                        parameter_type="string",
+                    )
+                ],
+                built_in_type=BuiltinTool.brave_search,
+            )
+        ]

     async def invoke_tool(
         self, tool_name: str, args: Dict[str, Any]
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Optional
+from typing import Any, Dict, Optional

 from pydantic import BaseModel, Field

@@ -18,3 +18,10 @@ class BraveSearchToolConfig(BaseModel):
         default=3,
         description="The maximum number of results to return",
     )
+
+    @classmethod
+    def sample_run_config(cls, __distro_dir__: str) -> Dict[str, Any]:
+        return {
+            "api_key": "${env.BRAVE_SEARCH_API_KEY:}",
+            "max_results": 3,
+        }
@@ -4,22 +4,21 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Any, Dict, List
+from typing import Any, Dict, List, Optional
 from urllib.parse import urlparse

+from mcp import ClientSession
+from mcp.client.sse import sse_client
+
+from llama_stack.apis.common.content_types import URL
 from llama_stack.apis.tools import (
-    MCPToolGroupDef,
     ToolDef,
-    ToolGroupDef,
     ToolInvocationResult,
     ToolParameter,
     ToolRuntime,
 )
 from llama_stack.providers.datatypes import ToolsProtocolPrivate

-from mcp import ClientSession
-from mcp.client.sse import sse_client
-
 from .config import ModelContextProtocolConfig

@@ -30,12 +29,14 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
     async def initialize(self):
         pass

-    async def discover_tools(self, tool_group: ToolGroupDef) -> List[ToolDef]:
-        if not isinstance(tool_group, MCPToolGroupDef):
-            raise ValueError(f"Unsupported tool group type: {type(tool_group)}")
+    async def list_runtime_tools(
+        self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
+    ) -> List[ToolDef]:
+        if mcp_endpoint is None:
+            raise ValueError("mcp_endpoint is required")

         tools = []
-        async with sse_client(tool_group.endpoint.uri) as streams:
+        async with sse_client(mcp_endpoint.uri) as streams:
             async with ClientSession(*streams) as session:
                 await session.initialize()
                 tools_result = await session.list_tools()

@@ -57,7 +58,7 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
                         description=tool.description,
                         parameters=parameters,
                         metadata={
-                            "endpoint": tool_group.endpoint.uri,
+                            "endpoint": mcp_endpoint.uri,
                         },
                     )
                 )
llama_stack/providers/remote/tool_runtime/tavily_search/__init__.py (new file)
@@ -0,0 +1,20 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from pydantic import BaseModel

from .config import TavilySearchToolConfig
from .tavily_search import TavilySearchToolRuntimeImpl


class TavilySearchToolProviderDataValidator(BaseModel):
    api_key: str


async def get_adapter_impl(config: TavilySearchToolConfig, _deps):
    impl = TavilySearchToolRuntimeImpl(config)
    await impl.initialize()
    return impl
llama_stack/providers/remote/tool_runtime/tavily_search/config.py (new file)
@@ -0,0 +1,27 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any, Dict, Optional

from pydantic import BaseModel, Field


class TavilySearchToolConfig(BaseModel):
    api_key: Optional[str] = Field(
        default=None,
        description="The Tavily Search API Key",
    )
    max_results: int = Field(
        default=3,
        description="The maximum number of results to return",
    )

    @classmethod
    def sample_run_config(cls, __distro_dir__: str) -> Dict[str, Any]:
        return {
            "api_key": "${env.TAVILY_SEARCH_API_KEY:}",
            "max_results": 3,
        }
llama_stack/providers/remote/tool_runtime/tavily_search/tavily_search.py (new file)
@@ -0,0 +1,83 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import json
from typing import Any, Dict, List, Optional

import requests

from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import (
    Tool,
    ToolDef,
    ToolInvocationResult,
    ToolParameter,
    ToolRuntime,
)
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.datatypes import ToolsProtocolPrivate

from .config import TavilySearchToolConfig


class TavilySearchToolRuntimeImpl(
    ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData
):
    def __init__(self, config: TavilySearchToolConfig):
        self.config = config

    async def initialize(self):
        pass

    async def register_tool(self, tool: Tool):
        pass

    async def unregister_tool(self, tool_id: str) -> None:
        return

    def _get_api_key(self) -> str:
        if self.config.api_key:
            return self.config.api_key

        provider_data = self.get_request_provider_data()
        if provider_data is None or not provider_data.api_key:
            raise ValueError(
                'Pass Search provider\'s API Key in the header X-LlamaStack-ProviderData as { "api_key": <your api key>}'
            )
        return provider_data.api_key

    async def list_runtime_tools(
        self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
    ) -> List[ToolDef]:
        return [
            ToolDef(
                name="web_search",
                description="Search the web for information",
                parameters=[
                    ToolParameter(
                        name="query",
                        description="The query to search for",
                        parameter_type="string",
                    )
                ],
            )
        ]

    async def invoke_tool(
        self, tool_name: str, args: Dict[str, Any]
    ) -> ToolInvocationResult:
        api_key = self._get_api_key()
        response = requests.post(
            "https://api.tavily.com/search",
            json={"api_key": api_key, "query": args["query"]},
        )

        return ToolInvocationResult(
            content=json.dumps(self._clean_tavily_response(response.json()))
        )

    def _clean_tavily_response(self, search_response, top_k=3):
        return {"query": search_response["query"], "top_k": search_response["results"]}
llama_stack/providers/remote/tool_runtime/wolfram_alpha/__init__.py (new file)
@@ -0,0 +1,22 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from pydantic import BaseModel

from .config import WolframAlphaToolConfig
from .wolfram_alpha import WolframAlphaToolRuntimeImpl

__all__ = ["WolframAlphaToolConfig", "WolframAlphaToolRuntimeImpl"]


class WolframAlphaToolProviderDataValidator(BaseModel):
    api_key: str


async def get_adapter_impl(config: WolframAlphaToolConfig, _deps):
    impl = WolframAlphaToolRuntimeImpl(config)
    await impl.initialize()
    return impl
llama_stack/providers/remote/tool_runtime/wolfram_alpha/config.py (new file)
@@ -0,0 +1,15 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Optional

from pydantic import BaseModel


class WolframAlphaToolConfig(BaseModel):
    """Configuration for WolframAlpha Tool Runtime"""

    api_key: Optional[str] = None
llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py (new file)
@@ -0,0 +1,146 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import json
from typing import Any, Dict, List, Optional

import requests

from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import (
    Tool,
    ToolDef,
    ToolInvocationResult,
    ToolParameter,
    ToolRuntime,
)
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.datatypes import ToolsProtocolPrivate

from .config import WolframAlphaToolConfig


class WolframAlphaToolRuntimeImpl(
    ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData
):
    def __init__(self, config: WolframAlphaToolConfig):
        self.config = config
        self.url = "https://api.wolframalpha.com/v2/query"

    async def initialize(self):
        pass

    async def register_tool(self, tool: Tool):
        pass

    async def unregister_tool(self, tool_id: str) -> None:
        return

    def _get_api_key(self) -> str:
        if self.config.api_key:
            return self.config.api_key

        provider_data = self.get_request_provider_data()
        if provider_data is None or not provider_data.api_key:
            raise ValueError(
                'Pass WolframAlpha API Key in the header X-LlamaStack-ProviderData as { "api_key": <your api key>}'
            )
        return provider_data.api_key

    async def list_runtime_tools(
        self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
    ) -> List[ToolDef]:
        return [
            ToolDef(
                name="wolfram_alpha",
                description="Query WolframAlpha for computational knowledge",
                parameters=[
                    ToolParameter(
                        name="query",
                        description="The query to compute",
                        parameter_type="string",
                    )
                ],
            )
        ]

    async def invoke_tool(
        self, tool_name: str, args: Dict[str, Any]
    ) -> ToolInvocationResult:
        api_key = self._get_api_key()
        params = {
            "input": args["query"],
            "appid": api_key,
            "format": "plaintext",
            "output": "json",
        }
        response = requests.get(
            self.url,
            params=params,
        )

        return ToolInvocationResult(
            content=json.dumps(self._clean_wolfram_alpha_response(response.json()))
        )

    def _clean_wolfram_alpha_response(self, wa_response):
        remove = {
            "queryresult": [
                "datatypes",
                "error",
                "timedout",
                "timedoutpods",
                "numpods",
                "timing",
                "parsetiming",
                "parsetimedout",
                "recalculate",
                "id",
                "host",
                "server",
                "related",
                "version",
                {
                    "pods": [
                        "scanner",
                        "id",
                        "error",
                        "expressiontypes",
                        "states",
                        "infos",
                        "position",
                        "numsubpods",
                    ]
                },
                "assumptions",
            ],
        }
        for main_key in remove:
            for key_to_remove in remove[main_key]:
                try:
                    if key_to_remove == "assumptions":
                        if "assumptions" in wa_response[main_key]:
                            del wa_response[main_key][key_to_remove]
                    if isinstance(key_to_remove, dict):
                        for sub_key in key_to_remove:
                            if sub_key == "pods":
                                for i in range(len(wa_response[main_key][sub_key])):
                                    if (
                                        wa_response[main_key][sub_key][i]["title"]
                                        == "Result"
                                    ):
                                        del wa_response[main_key][sub_key][i + 1 :]
                                        break
                            sub_items = wa_response[main_key][sub_key]
                            for i in range(len(sub_items)):
                                for sub_key_to_remove in key_to_remove[sub_key]:
                                    if sub_key_to_remove in sub_items[i]:
                                        del sub_items[i][sub_key_to_remove]
                    elif key_to_remove in wa_response[main_key]:
                        del wa_response[main_key][key_to_remove]
                except KeyError:
                    pass
        return wa_response
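Like the search runtimes, the WolframAlpha runtime exposes a single ToolDef and strips layout-heavy keys from the raw API response before returning it. A hedged sketch of exercising it directly, assuming a tool_runtime impl from a constructed test stack (this mirrors the test added later in this diff):

# Hedged sketch: exercise the wolfram_alpha tool via the tool runtime API.
# `tools_stack` is assumed to be the stack built in tools/fixtures.py below.
from llama_stack.providers.datatypes import Api

tools_impl = tools_stack.impls[Api.tool_runtime]
result = await tools_impl.invoke_tool(
    tool_name="wolfram_alpha", args={"query": "What is the square root of 16?"}
)
print(result.content)  # JSON string of the cleaned queryresult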
@@ -7,13 +7,12 @@
 import pytest

 from ..conftest import get_provider_fixture_overrides
-
 from ..inference.fixtures import INFERENCE_FIXTURES
 from ..memory.fixtures import MEMORY_FIXTURES
 from ..safety.fixtures import SAFETY_FIXTURES, safety_model_from_shield
+from ..tools.fixtures import TOOL_RUNTIME_FIXTURES
 from .fixtures import AGENTS_FIXTURES


 DEFAULT_PROVIDER_COMBINATIONS = [
     pytest.param(
         {

@@ -21,6 +20,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
             "safety": "llama_guard",
             "memory": "faiss",
             "agents": "meta_reference",
+            "tool_runtime": "memory_and_search",
         },
         id="meta_reference",
         marks=pytest.mark.meta_reference,

@@ -31,6 +31,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
             "safety": "llama_guard",
             "memory": "faiss",
             "agents": "meta_reference",
+            "tool_runtime": "memory_and_search",
         },
         id="ollama",
         marks=pytest.mark.ollama,

@@ -42,6 +43,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
             # make this work with Weaviate which is what the together distro supports
             "memory": "faiss",
             "agents": "meta_reference",
+            "tool_runtime": "memory_and_search",
         },
         id="together",
         marks=pytest.mark.together,

@@ -52,6 +54,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
             "safety": "llama_guard",
             "memory": "faiss",
             "agents": "meta_reference",
+            "tool_runtime": "memory_and_search",
         },
         id="fireworks",
         marks=pytest.mark.fireworks,

@@ -62,6 +65,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
             "safety": "remote",
             "memory": "remote",
             "agents": "remote",
+            "tool_runtime": "memory_and_search",
         },
         id="remote",
         marks=pytest.mark.remote,

@@ -117,6 +121,7 @@ def pytest_generate_tests(metafunc):
         "safety": SAFETY_FIXTURES,
         "memory": MEMORY_FIXTURES,
         "agents": AGENTS_FIXTURES,
+        "tool_runtime": TOOL_RUNTIME_FIXTURES,
     }
     combinations = (
         get_provider_fixture_overrides(metafunc.config, available_fixtures)
@@ -11,13 +11,12 @@ import pytest_asyncio

 from llama_stack.apis.models import ModelInput, ModelType
 from llama_stack.distribution.datatypes import Api, Provider
-
 from llama_stack.providers.inline.agents.meta_reference import (
     MetaReferenceAgentsImplConfig,
 )

 from llama_stack.providers.tests.resolver import construct_stack_for_test
 from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig

 from ..conftest import ProviderFixture, remote_stack_fixture

@@ -59,12 +58,18 @@ AGENTS_FIXTURES = ["meta_reference", "remote"]

 @pytest_asyncio.fixture(scope="session")
-async def agents_stack(request, inference_model, safety_shield):
+async def agents_stack(
+    request,
+    inference_model,
+    safety_shield,
+    tool_group_input_memory,
+    tool_group_input_tavily_search,
+):
     fixture_dict = request.param

     providers = {}
     provider_data = {}
-    for key in ["inference", "safety", "memory", "agents"]:
+    for key in ["inference", "safety", "memory", "agents", "tool_runtime"]:
         fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
         providers[key] = fixture.providers
         if key == "inference":

@@ -113,10 +118,11 @@ async def agents_stack(request, inference_model, safety_shield):
     )

     test_stack = await construct_stack_for_test(
-        [Api.agents, Api.inference, Api.safety, Api.memory],
+        [Api.agents, Api.inference, Api.safety, Api.memory, Api.tool_runtime],
         providers,
         provider_data,
         models=models,
         shields=[safety_shield] if safety_shield else [],
+        tool_groups=[tool_group_input_memory, tool_group_input_tavily_search],
     )
     return test_stack
@@ -5,22 +5,17 @@
 # the root directory of this source tree.

 import os
-from typing import Dict, List

 import pytest
 from llama_models.llama3.api.datatypes import BuiltinTool

 from llama_stack.apis.agents import (
     AgentConfig,
-    AgentTool,
     AgentTurnResponseEventType,
     AgentTurnResponseStepCompletePayload,
     AgentTurnResponseStreamChunk,
     AgentTurnResponseTurnCompletePayload,
-    Attachment,
-    MemoryToolDefinition,
-    SearchEngineType,
-    SearchToolDefinition,
+    Document,
     ShieldCallStep,
     StepType,
     ToolChoice,

@@ -35,7 +30,6 @@ from llama_stack.providers.datatypes import Api
 #
 # pytest -v -s llama_stack/providers/tests/agents/test_agents.py
 # -m "meta_reference"

 from .fixtures import pick_inference_model
-
 from .utils import create_agent_session

@@ -51,7 +45,7 @@ def common_params(inference_model):
         sampling_params=SamplingParams(temperature=0.7, top_p=0.95),
         input_shields=[],
         output_shields=[],
-        tools=[],
+        toolgroups=[],
         max_infer_iters=5,
     )

@@ -88,73 +82,6 @@ def query_attachment_messages():
     ]


-async def create_agent_turn_with_search_tool(
-    agents_stack: Dict[str, object],
-    search_query_messages: List[object],
-    common_params: Dict[str, str],
-    search_tool_definition: SearchToolDefinition,
-) -> None:
-    """
-    Create an agent turn with a search tool.
-
-    Args:
-        agents_stack (Dict[str, object]): The agents stack.
-        search_query_messages (List[object]): The search query messages.
-        common_params (Dict[str, str]): The common parameters.
-        search_tool_definition (SearchToolDefinition): The search tool definition.
-    """
-
-    # Create an agent with the search tool
-    agent_config = AgentConfig(
-        **{
-            **common_params,
-            "tools": [search_tool_definition],
-        }
-    )
-
-    agent_id, session_id = await create_agent_session(
-        agents_stack.impls[Api.agents], agent_config
-    )
-    turn_request = dict(
-        agent_id=agent_id,
-        session_id=session_id,
-        messages=search_query_messages,
-        stream=True,
-    )
-
-    turn_response = [
-        chunk
-        async for chunk in await agents_stack.impls[Api.agents].create_agent_turn(
-            **turn_request
-        )
-    ]
-
-    assert len(turn_response) > 0
-    assert all(
-        isinstance(chunk, AgentTurnResponseStreamChunk) for chunk in turn_response
-    )
-
-    check_event_types(turn_response)
-
-    # Check for tool execution events
-    tool_execution_events = [
-        chunk
-        for chunk in turn_response
-        if isinstance(chunk.event.payload, AgentTurnResponseStepCompletePayload)
-        and chunk.event.payload.step_details.step_type == StepType.tool_execution.value
-    ]
-    assert len(tool_execution_events) > 0, "No tool execution events found"
-
-    # Check the tool execution details
-    tool_execution = tool_execution_events[0].event.payload.step_details
-    assert isinstance(tool_execution, ToolExecutionStep)
-    assert len(tool_execution.tool_calls) > 0
-    assert tool_execution.tool_calls[0].tool_name == BuiltinTool.brave_search
-    assert len(tool_execution.tool_responses) > 0
-
-    check_turn_complete_event(turn_response, session_id, search_query_messages)
-
-
 class TestAgents:
     @pytest.mark.asyncio
     async def test_agent_turns_with_safety(

@@ -227,7 +154,7 @@ class TestAgents:
         check_turn_complete_event(turn_response, session_id, sample_messages)

     @pytest.mark.asyncio
-    async def test_rag_agent_as_attachments(
+    async def test_rag_agent(
         self,
         agents_stack,
         attachment_message,

@@ -243,29 +170,17 @@ class TestAgents:
             "qat_finetune.rst",
             "lora_finetune.rst",
         ]

-        attachments = [
-            Attachment(
+        documents = [
+            Document(
                 content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
                 mime_type="text/plain",
             )
             for i, url in enumerate(urls)
         ]

         agent_config = AgentConfig(
             **{
                 **common_params,
-                "tools": [
-                    MemoryToolDefinition(
-                        memory_bank_configs=[],
-                        query_generator_config={
-                            "type": "default",
-                            "sep": " ",
-                        },
-                        max_tokens_in_context=4096,
-                        max_chunks=10,
-                    ),
-                ],
+                "toolgroups": ["builtin::memory"],
                 "tool_choice": ToolChoice.auto,
             }
         )

@@ -275,7 +190,7 @@ class TestAgents:
             agent_id=agent_id,
             session_id=session_id,
             messages=attachment_message,
-            attachments=attachments,
+            documents=documents,
             stream=True,
         )
         turn_response = [

@@ -298,22 +213,6 @@ class TestAgents:

         assert len(turn_response) > 0

-    @pytest.mark.asyncio
-    async def test_create_agent_turn_with_brave_search(
-        self, agents_stack, search_query_messages, common_params
-    ):
-        if "BRAVE_SEARCH_API_KEY" not in os.environ:
-            pytest.skip("BRAVE_SEARCH_API_KEY not set, skipping test")
-
-        search_tool_definition = SearchToolDefinition(
-            type=AgentTool.brave_search.value,
-            api_key=os.environ["BRAVE_SEARCH_API_KEY"],
-            engine=SearchEngineType.brave,
-        )
-        await create_agent_turn_with_search_tool(
-            agents_stack, search_query_messages, common_params, search_tool_definition
-        )
-
     @pytest.mark.asyncio
     async def test_create_agent_turn_with_tavily_search(
         self, agents_stack, search_query_messages, common_params

@@ -321,14 +220,57 @@ class TestAgents:
         if "TAVILY_SEARCH_API_KEY" not in os.environ:
             pytest.skip("TAVILY_SEARCH_API_KEY not set, skipping test")

-        search_tool_definition = SearchToolDefinition(
-            type=AgentTool.brave_search.value,  # place holder only
-            api_key=os.environ["TAVILY_SEARCH_API_KEY"],
-            engine=SearchEngineType.tavily,
+        # Create an agent with the toolgroup
+        agent_config = AgentConfig(
+            **{
+                **common_params,
+                "toolgroups": ["builtin::web_search"],
+            }
         )
-        await create_agent_turn_with_search_tool(
-            agents_stack, search_query_messages, common_params, search_tool_definition
+
+        agent_id, session_id = await create_agent_session(
+            agents_stack.impls[Api.agents], agent_config
         )
+        turn_request = dict(
+            agent_id=agent_id,
+            session_id=session_id,
+            messages=search_query_messages,
+            stream=True,
+        )
+
+        turn_response = [
+            chunk
+            async for chunk in await agents_stack.impls[Api.agents].create_agent_turn(
+                **turn_request
+            )
+        ]
+
+        assert len(turn_response) > 0
+        assert all(
+            isinstance(chunk, AgentTurnResponseStreamChunk) for chunk in turn_response
+        )
+
+        check_event_types(turn_response)
+
+        # Check for tool execution events
+        tool_execution_events = [
+            chunk
+            for chunk in turn_response
+            if isinstance(chunk.event.payload, AgentTurnResponseStepCompletePayload)
+            and chunk.event.payload.step_details.step_type
+            == StepType.tool_execution.value
+        ]
+        assert len(tool_execution_events) > 0, "No tool execution events found"
+
+        # Check the tool execution details
+        tool_execution = tool_execution_events[0].event.payload.step_details
+        assert isinstance(tool_execution, ToolExecutionStep)
+        assert len(tool_execution.tool_calls) > 0
+        actual_tool_name = tool_execution.tool_calls[0].tool_name
+        assert actual_tool_name == BuiltinTool.brave_search
+        assert len(tool_execution.tool_responses) > 0
+
+        check_turn_complete_event(turn_response, session_id, search_query_messages)


 def check_event_types(turn_response):
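With toolgroups on AgentConfig, wiring a search-capable agent collapses to naming the group: tool definitions flow to the model, and invocation goes through the tool runtime. A hedged sketch of the config shape (fields mirror common_params and the test above; model name and enable_session_persistence are illustrative — see the notebook linked in the PR description for the canonical flow):

# Hedged sketch: an agent that uses the builtin web_search tool group.
agent_config = AgentConfig(
    model="meta-llama/Llama-3.1-8B-Instruct",  # illustrative model id
    instructions="You are a helpful assistant.",
    toolgroups=["builtin::web_search"],
    tool_choice=ToolChoice.auto,
    input_shields=[],
    output_shields=[],
    enable_session_persistence=False,
)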
@@ -157,4 +157,5 @@ pytest_plugins = [
     "llama_stack.providers.tests.scoring.fixtures",
     "llama_stack.providers.tests.eval.fixtures",
     "llama_stack.providers.tests.post_training.fixtures",
+    "llama_stack.providers.tests.tools.fixtures",
 ]
@@ -19,6 +19,7 @@ from llama_stack.providers.remote.memory.pgvector import PGVectorConfig
 from llama_stack.providers.remote.memory.weaviate import WeaviateConfig
 from llama_stack.providers.tests.resolver import construct_stack_for_test
+from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig

 from ..conftest import ProviderFixture, remote_stack_fixture
 from ..env import get_env_or_fail
@@ -16,7 +16,7 @@ from llama_stack.apis.memory_banks import MemoryBankInput
 from llama_stack.apis.models import ModelInput
 from llama_stack.apis.scoring_functions import ScoringFnInput
 from llama_stack.apis.shields import ShieldInput
-
+from llama_stack.apis.tools import ToolGroupInput
 from llama_stack.distribution.build import print_pip_install_help
 from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
 from llama_stack.distribution.datatypes import Provider, StackRunConfig

@@ -43,6 +43,7 @@ async def construct_stack_for_test(
     datasets: Optional[List[DatasetInput]] = None,
     scoring_fns: Optional[List[ScoringFnInput]] = None,
     eval_tasks: Optional[List[EvalTaskInput]] = None,
+    tool_groups: Optional[List[ToolGroupInput]] = None,
 ) -> TestStack:
     sqlite_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
     run_config = dict(

@@ -56,6 +57,7 @@ async def construct_stack_for_test(
         datasets=datasets or [],
         scoring_fns=scoring_fns or [],
         eval_tasks=eval_tasks or [],
+        tool_groups=tool_groups or [],
     )
     run_config = parse_and_maybe_upgrade_config(run_config)
     try:
llama_stack/providers/tests/tools/conftest.py (new file, 65 lines)
@@ -0,0 +1,65 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import pytest

from ..conftest import get_provider_fixture_overrides
from ..inference.fixtures import INFERENCE_FIXTURES
from ..memory.fixtures import MEMORY_FIXTURES
from ..safety.fixtures import SAFETY_FIXTURES
from .fixtures import TOOL_RUNTIME_FIXTURES

DEFAULT_PROVIDER_COMBINATIONS = [
    pytest.param(
        {
            "inference": "together",
            "safety": "llama_guard",
            "memory": "faiss",
            "tool_runtime": "memory_and_search",
        },
        id="together",
        marks=pytest.mark.together,
    ),
]


def pytest_configure(config):
    for mark in ["together"]:
        config.addinivalue_line(
            "markers",
            f"{mark}: marks tests as {mark} specific",
        )


def pytest_addoption(parser):
    parser.addoption(
        "--inference-model",
        action="store",
        default="meta-llama/Llama-3.2-3B-Instruct",
        help="Specify the inference model to use for testing",
    )
    parser.addoption(
        "--safety-shield",
        action="store",
        default="meta-llama/Llama-Guard-3-1B",
        help="Specify the safety shield to use for testing",
    )


def pytest_generate_tests(metafunc):
    if "tools_stack" in metafunc.fixturenames:
        available_fixtures = {
            "inference": INFERENCE_FIXTURES,
            "safety": SAFETY_FIXTURES,
            "memory": MEMORY_FIXTURES,
            "tool_runtime": TOOL_RUNTIME_FIXTURES,
        }
        combinations = (
            get_provider_fixture_overrides(metafunc.config, available_fixtures)
            or DEFAULT_PROVIDER_COMBINATIONS
        )
        print(combinations)
        metafunc.parametrize("tools_stack", combinations, indirect=True)
llama_stack/providers/tests/tools/fixtures.py (new file, 130 lines)
@@ -0,0 +1,130 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import os

import pytest
import pytest_asyncio

from llama_stack.apis.models import ModelInput, ModelType
from llama_stack.apis.tools import ToolGroupInput
from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.tests.resolver import construct_stack_for_test

from ..conftest import ProviderFixture


@pytest.fixture(scope="session")
def tool_runtime_memory_and_search() -> ProviderFixture:
    return ProviderFixture(
        providers=[
            Provider(
                provider_id="memory-runtime",
                provider_type="inline::memory-runtime",
                config={},
            ),
            Provider(
                provider_id="tavily-search",
                provider_type="remote::tavily-search",
                config={
                    "api_key": os.environ["TAVILY_SEARCH_API_KEY"],
                },
            ),
            Provider(
                provider_id="wolfram-alpha",
                provider_type="remote::wolfram-alpha",
                config={
                    "api_key": os.environ["WOLFRAM_ALPHA_API_KEY"],
                },
            ),
        ],
    )


@pytest.fixture(scope="session")
def tool_group_input_memory() -> ToolGroupInput:
    return ToolGroupInput(
        toolgroup_id="builtin::memory",
        provider_id="memory-runtime",
    )


@pytest.fixture(scope="session")
def tool_group_input_tavily_search() -> ToolGroupInput:
    return ToolGroupInput(
        toolgroup_id="builtin::web_search",
        provider_id="tavily-search",
    )


@pytest.fixture(scope="session")
def tool_group_input_wolfram_alpha() -> ToolGroupInput:
    return ToolGroupInput(
        toolgroup_id="builtin::wolfram_alpha",
        provider_id="wolfram-alpha",
    )


TOOL_RUNTIME_FIXTURES = ["memory_and_search"]


@pytest_asyncio.fixture(scope="session")
async def tools_stack(
    request,
    inference_model,
    tool_group_input_memory,
    tool_group_input_tavily_search,
    tool_group_input_wolfram_alpha,
):
    fixture_dict = request.param

    providers = {}
    provider_data = {}
    for key in ["inference", "memory", "tool_runtime"]:
        fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
        providers[key] = fixture.providers
        if key == "inference":
            providers[key].append(
                Provider(
                    provider_id="tools_memory_provider",
                    provider_type="inline::sentence-transformers",
                    config={},
                )
            )
        if fixture.provider_data:
            provider_data.update(fixture.provider_data)
    inference_models = (
        inference_model if isinstance(inference_model, list) else [inference_model]
    )
    models = [
        ModelInput(
            model_id=model,
            model_type=ModelType.llm,
            provider_id=providers["inference"][0].provider_id,
        )
        for model in inference_models
    ]
    models.append(
        ModelInput(
            model_id="all-MiniLM-L6-v2",
            model_type=ModelType.embedding,
            provider_id="tools_memory_provider",
            metadata={"embedding_dimension": 384},
        )
    )

    test_stack = await construct_stack_for_test(
        [Api.tool_groups, Api.inference, Api.memory, Api.tool_runtime],
        providers,
        provider_data,
        models=models,
        tool_groups=[
            tool_group_input_tavily_search,
            tool_group_input_wolfram_alpha,
            tool_group_input_memory,
        ],
    )
    return test_stack
llama_stack/providers/tests/tools/test_tools.py (new file, 127 lines)
@@ -0,0 +1,127 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import os

import pytest

from llama_stack.apis.inference import UserMessage
from llama_stack.apis.memory import MemoryBankDocument
from llama_stack.apis.memory_banks import VectorMemoryBankParams
from llama_stack.apis.tools import ToolInvocationResult
from llama_stack.providers.datatypes import Api


@pytest.fixture
def sample_search_query():
    return "What are the latest developments in quantum computing?"


@pytest.fixture
def sample_wolfram_alpha_query():
    return "What is the square root of 16?"


@pytest.fixture
def sample_documents():
    urls = [
        "memory_optimizations.rst",
        "chat.rst",
        "llama3.rst",
        "datasets.rst",
        "qat_finetune.rst",
        "lora_finetune.rst",
    ]
    return [
        MemoryBankDocument(
            document_id=f"num-{i}",
            content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
            mime_type="text/plain",
            metadata={},
        )
        for i, url in enumerate(urls)
    ]


class TestTools:
    @pytest.mark.asyncio
    async def test_web_search_tool(self, tools_stack, sample_search_query):
        """Test the web search tool functionality."""
        if "TAVILY_SEARCH_API_KEY" not in os.environ:
            pytest.skip("TAVILY_SEARCH_API_KEY not set, skipping test")

        tools_impl = tools_stack.impls[Api.tool_runtime]

        # Execute the tool
        response = await tools_impl.invoke_tool(
            tool_name="web_search", args={"query": sample_search_query}
        )

        # Verify the response
        assert isinstance(response, ToolInvocationResult)
        assert response.content is not None
        assert len(response.content) > 0
        assert isinstance(response.content, str)

    @pytest.mark.asyncio
    async def test_wolfram_alpha_tool(self, tools_stack, sample_wolfram_alpha_query):
        """Test the wolfram alpha tool functionality."""
        if "WOLFRAM_ALPHA_API_KEY" not in os.environ:
            pytest.skip("WOLFRAM_ALPHA_API_KEY not set, skipping test")

        tools_impl = tools_stack.impls[Api.tool_runtime]

        response = await tools_impl.invoke_tool(
            tool_name="wolfram_alpha", args={"query": sample_wolfram_alpha_query}
        )

        # Verify the response
        assert isinstance(response, ToolInvocationResult)
        assert response.content is not None
        assert len(response.content) > 0
        assert isinstance(response.content, str)

    @pytest.mark.asyncio
    async def test_memory_tool(self, tools_stack, sample_documents):
        """Test the memory tool functionality."""
        memory_banks_impl = tools_stack.impls[Api.memory_banks]
        memory_impl = tools_stack.impls[Api.memory]
        tools_impl = tools_stack.impls[Api.tool_runtime]

        # Register memory bank
        await memory_banks_impl.register_memory_bank(
            memory_bank_id="test_bank",
            params=VectorMemoryBankParams(
                embedding_model="all-MiniLM-L6-v2",
                chunk_size_in_tokens=512,
                overlap_size_in_tokens=64,
            ),
            provider_id="faiss",
        )

        # Insert documents into memory
        await memory_impl.insert_documents(
            bank_id="test_bank",
            documents=sample_documents,
        )

        # Execute the memory tool
        response = await tools_impl.invoke_tool(
            tool_name="memory",
            args={
                "messages": [
                    UserMessage(
                        content="What are the main topics covered in the documentation?",
                    )
                ],
                "memory_bank_ids": ["test_bank"],
            },
        )

        # Verify the response
        assert isinstance(response, ToolInvocationResult)
        assert response.content is not None
        assert len(response.content) > 0
@@ -14,7 +14,6 @@ from typing import List, Optional, Tuple, Union

 import httpx
 from llama_models.datatypes import is_multimodal, ModelFamily
-
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.datatypes import (
     RawContent,

@@ -41,7 +40,6 @@ from llama_stack.apis.common.content_types import (
     InterleavedContentItem,
     TextContentItem,
 )
-
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
     CompletionRequest,

@@ -52,7 +50,6 @@ from llama_stack.apis.inference import (
     ToolChoice,
     UserMessage,
 )
-
 from llama_stack.providers.utils.inference import supported_inference_models

 log = logging.getLogger(__name__)
@@ -9,8 +9,7 @@ from pathlib import Path
 from llama_models.sku_list import all_registered_models

 from llama_stack.apis.models import ModelInput
-from llama_stack.distribution.datatypes import Provider
-
+from llama_stack.distribution.datatypes import Provider, ToolGroupInput
 from llama_stack.providers.inline.memory.faiss.config import FaissImplConfig
 from llama_stack.providers.remote.inference.bedrock.bedrock import MODEL_ALIASES
 from llama_stack.templates.template import DistributionTemplate, RunConfigSettings

@@ -26,6 +25,12 @@ def get_distribution_template() -> DistributionTemplate:
         "eval": ["inline::meta-reference"],
         "datasetio": ["remote::huggingface", "inline::localfs"],
         "scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"],
+        "tool_runtime": [
+            "remote::brave-search",
+            "remote::tavily-search",
+            "inline::code-interpreter",
+            "inline::memory-runtime",
+        ],
     }
     name = "bedrock"
     memory_provider = Provider(

@@ -46,6 +51,20 @@ def get_distribution_template() -> DistributionTemplate:
         )
         for m in MODEL_ALIASES
     ]
+    default_tool_groups = [
+        ToolGroupInput(
+            toolgroup_id="builtin::websearch",
+            provider_id="tavily-search",
+        ),
+        ToolGroupInput(
+            toolgroup_id="builtin::memory",
+            provider_id="memory-runtime",
+        ),
+        ToolGroupInput(
+            toolgroup_id="builtin::code_interpreter",
+            provider_id="code-interpreter",
+        ),
+    ]

     return DistributionTemplate(
         name=name,

@@ -61,6 +80,7 @@ def get_distribution_template() -> DistributionTemplate:
                 "memory": [memory_provider],
             },
             default_models=default_models,
+            default_tool_groups=default_tool_groups,
         ),
     },
     run_config_env_vars={
@@ -2,7 +2,6 @@ version: '2'
 name: bedrock
 distribution_spec:
   description: Use AWS Bedrock for running LLM inference and safety
-  docker_image: null
   providers:
     inference:
     - remote::bedrock

@@ -25,4 +24,9 @@ distribution_spec:
     - inline::basic
     - inline::llm-as-judge
    - inline::braintrust
+    tool_runtime:
+    - remote::brave-search
+    - remote::tavily-search
+    - inline::code-interpreter
+    - inline::memory-runtime
 image_type: conda
@@ -1,6 +1,5 @@
 version: '2'
 image_name: bedrock
-docker_image: null
 conda_env: bedrock
 apis:
 - agents

@@ -11,6 +10,7 @@ apis:
 - safety
 - scoring
 - telemetry
+- tool_runtime
 providers:
   inference:
   - provider_id: bedrock

@@ -65,8 +65,24 @@ providers:
     provider_type: inline::braintrust
     config:
       openai_api_key: ${env.OPENAI_API_KEY:}
+  tool_runtime:
+  - provider_id: brave-search
+    provider_type: remote::brave-search
+    config:
+      api_key: ${env.BRAVE_SEARCH_API_KEY:}
+      max_results: 3
+  - provider_id: tavily-search
+    provider_type: remote::tavily-search
+    config:
+      api_key: ${env.TAVILY_SEARCH_API_KEY:}
+      max_results: 3
+  - provider_id: code-interpreter
+    provider_type: inline::code-interpreter
+    config: {}
+  - provider_id: memory-runtime
+    provider_type: inline::memory-runtime
+    config: {}
 metadata_store:
   namespace: null
   type: sqlite
   db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/bedrock}/registry.db
 models:

@@ -90,3 +106,10 @@ memory_banks: []
 datasets: []
 scoring_fns: []
 eval_tasks: []
+tool_groups:
+- toolgroup_id: builtin::websearch
+  provider_id: tavily-search
+- toolgroup_id: builtin::memory
+  provider_id: memory-runtime
+- toolgroup_id: builtin::code_interpreter
+  provider_id: code-interpreter
@@ -2,7 +2,6 @@ version: '2'
 name: cerebras
 distribution_spec:
   description: Use Cerebras for running LLM inference
-  docker_image: null
   providers:
     inference:
     - remote::cerebras

@@ -14,4 +13,9 @@ distribution_spec:
     - inline::meta-reference
     telemetry:
     - inline::meta-reference
+    tool_runtime:
+    - remote::brave-search
+    - remote::tavily-search
+    - inline::code-interpreter
+    - inline::memory-runtime
 image_type: conda
@@ -9,8 +9,12 @@ from pathlib import Path
 from llama_models.sku_list import all_registered_models

 from llama_stack.apis.models.models import ModelType
-
-from llama_stack.distribution.datatypes import ModelInput, Provider, ShieldInput
+from llama_stack.distribution.datatypes import (
+    ModelInput,
+    Provider,
+    ShieldInput,
+    ToolGroupInput,
+)
 from llama_stack.providers.inline.inference.sentence_transformers import (
     SentenceTransformersInferenceConfig,
 )

@@ -26,6 +30,12 @@ def get_distribution_template() -> DistributionTemplate:
         "memory": ["inline::meta-reference"],
         "agents": ["inline::meta-reference"],
         "telemetry": ["inline::meta-reference"],
+        "tool_runtime": [
+            "remote::brave-search",
+            "remote::tavily-search",
+            "inline::code-interpreter",
+            "inline::memory-runtime",
+        ],
     }

     inference_provider = Provider(

@@ -58,6 +68,20 @@ def get_distribution_template() -> DistributionTemplate:
             "embedding_dimension": 384,
         },
     )
+    default_tool_groups = [
+        ToolGroupInput(
+            toolgroup_id="builtin::websearch",
+            provider_id="tavily-search",
+        ),
+        ToolGroupInput(
+            toolgroup_id="builtin::memory",
+            provider_id="memory-runtime",
+        ),
+        ToolGroupInput(
+            toolgroup_id="builtin::code_interpreter",
+            provider_id="code-interpreter",
+        ),
+    ]

     return DistributionTemplate(
         name="cerebras",

@@ -74,6 +98,7 @@ def get_distribution_template() -> DistributionTemplate:
             },
             default_models=default_models + [embedding_model],
             default_shields=[ShieldInput(shield_id="meta-llama/Llama-Guard-3-8B")],
+            default_tool_groups=default_tool_groups,
         ),
     },
     run_config_env_vars={
@@ -1,6 +1,5 @@
 version: '2'
 image_name: cerebras
-docker_image: null
 conda_env: cerebras
 apis:
 - agents

@@ -8,6 +7,7 @@ apis:
 - memory
 - safety
 - telemetry
+- tool_runtime
 providers:
   inference:
   - provider_id: cerebras

@@ -45,8 +45,24 @@ providers:
       service_name: ${env.OTEL_SERVICE_NAME:llama-stack}
       sinks: ${env.TELEMETRY_SINKS:console,sqlite}
       sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/cerebras/trace_store.db}
+  tool_runtime:
+  - provider_id: brave-search
+    provider_type: remote::brave-search
+    config:
+      api_key: ${env.BRAVE_SEARCH_API_KEY:}
+      max_results: 3
+  - provider_id: tavily-search
+    provider_type: remote::tavily-search
+    config:
+      api_key: ${env.TAVILY_SEARCH_API_KEY:}
+      max_results: 3
+  - provider_id: code-interpreter
+    provider_type: inline::code-interpreter
+    config: {}
+  - provider_id: memory-runtime
+    provider_type: inline::memory-runtime
+    config: {}
 metadata_store:
   namespace: null
   type: sqlite
   db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/cerebras}/registry.db
 models:

@@ -64,14 +80,17 @@ models:
     embedding_dimension: 384
   model_id: all-MiniLM-L6-v2
   provider_id: sentence-transformers
-  provider_model_id: null
   model_type: embedding
 shields:
-- params: null
-  shield_id: meta-llama/Llama-Guard-3-8B
-  provider_id: null
-  provider_shield_id: null
+- shield_id: meta-llama/Llama-Guard-3-8B
 memory_banks: []
 datasets: []
 scoring_fns: []
 eval_tasks: []
+tool_groups:
+- toolgroup_id: builtin::websearch
+  provider_id: tavily-search
+- toolgroup_id: builtin::memory
+  provider_id: memory-runtime
+- toolgroup_id: builtin::code_interpreter
+  provider_id: code-interpreter
@@ -2,7 +2,6 @@ version: '2'
 name: fireworks
 distribution_spec:
   description: Use Fireworks.AI for running LLM inference
-  docker_image: null
   providers:
     inference:
     - remote::fireworks

@@ -25,4 +24,9 @@ distribution_spec:
     - inline::basic
     - inline::llm-as-judge
     - inline::braintrust
+    tool_runtime:
+    - remote::brave-search
+    - remote::tavily-search
+    - inline::code-interpreter
+    - inline::memory-runtime
 image_type: conda
@@ -9,8 +9,12 @@ from pathlib import Path
 from llama_models.sku_list import all_registered_models

 from llama_stack.apis.models.models import ModelType
-
-from llama_stack.distribution.datatypes import ModelInput, Provider, ShieldInput
+from llama_stack.distribution.datatypes import (
+    ModelInput,
+    Provider,
+    ShieldInput,
+    ToolGroupInput,
+)
 from llama_stack.providers.inline.inference.sentence_transformers import (
     SentenceTransformersInferenceConfig,
 )

@@ -30,6 +34,12 @@ def get_distribution_template() -> DistributionTemplate:
         "eval": ["inline::meta-reference"],
         "datasetio": ["remote::huggingface", "inline::localfs"],
         "scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"],
+        "tool_runtime": [
+            "remote::brave-search",
+            "remote::tavily-search",
+            "inline::code-interpreter",
+            "inline::memory-runtime",
+        ],
     }

     name = "fireworks"

@@ -69,6 +79,20 @@ def get_distribution_template() -> DistributionTemplate:
             "embedding_dimension": 384,
         },
     )
+    default_tool_groups = [
+        ToolGroupInput(
+            toolgroup_id="builtin::websearch",
+            provider_id="tavily-search",
+        ),
+        ToolGroupInput(
+            toolgroup_id="builtin::memory",
+            provider_id="memory-runtime",
+        ),
+        ToolGroupInput(
+            toolgroup_id="builtin::code_interpreter",
+            provider_id="code-interpreter",
+        ),
+    ]

     return DistributionTemplate(
         name=name,

@@ -86,6 +110,7 @@ def get_distribution_template() -> DistributionTemplate:
             },
             default_models=default_models + [embedding_model],
             default_shields=[ShieldInput(shield_id="meta-llama/Llama-Guard-3-8B")],
+            default_tool_groups=default_tool_groups,
         ),
     },
     run_config_env_vars={
@@ -1,6 +1,5 @@
 version: '2'
 image_name: fireworks
-docker_image: null
 conda_env: fireworks
 apis:
 - agents
@@ -11,6 +10,7 @@ apis:
 - safety
 - scoring
 - telemetry
+- tool_runtime
 providers:
   inference:
   - provider_id: fireworks
@@ -70,8 +70,24 @@ providers:
     provider_type: inline::braintrust
     config:
       openai_api_key: ${env.OPENAI_API_KEY:}
+  tool_runtime:
+  - provider_id: brave-search
+    provider_type: remote::brave-search
+    config:
+      api_key: ${env.BRAVE_SEARCH_API_KEY:}
+      max_results: 3
+  - provider_id: tavily-search
+    provider_type: remote::tavily-search
+    config:
+      api_key: ${env.TAVILY_SEARCH_API_KEY:}
+      max_results: 3
+  - provider_id: code-interpreter
+    provider_type: inline::code-interpreter
+    config: {}
+  - provider_id: memory-runtime
+    provider_type: inline::memory-runtime
+    config: {}
 metadata_store:
   namespace: null
   type: sqlite
   db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/fireworks}/registry.db
 models:
@@ -129,14 +145,17 @@ models:
     embedding_dimension: 384
   model_id: all-MiniLM-L6-v2
   provider_id: sentence-transformers
-  provider_model_id: null
   model_type: embedding
 shields:
-- params: null
-  shield_id: meta-llama/Llama-Guard-3-8B
-  provider_id: null
-  provider_shield_id: null
+- shield_id: meta-llama/Llama-Guard-3-8B
 memory_banks: []
 datasets: []
 scoring_fns: []
 eval_tasks: []
+tool_groups:
+- toolgroup_id: builtin::websearch
+  provider_id: tavily-search
+- toolgroup_id: builtin::memory
+  provider_id: memory-runtime
+- toolgroup_id: builtin::code_interpreter
+  provider_id: code-interpreter

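With the `tool_runtime` providers above in place, a tool can also be invoked directly as a resource, outside any agent turn. A sketch, assuming an `invoke_tool` entry point on the client and a `web_search` tool name exposed by the search providers; both are assumptions:

```
# Sketch (assumed client API): invoke a tool directly through the
# tool runtime, without going through an agent.
result = client.tool_runtime.invoke_tool(
    tool_name="web_search",  # assumed builtin tool name
    kwargs={"query": "latest Llama Stack release"},
)
print(result.content)
```
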
@@ -2,7 +2,6 @@ version: '2'
 name: hf-endpoint
 distribution_spec:
   description: Use (an external) Hugging Face Inference Endpoint for running LLM inference
-  docker_image: null
   providers:
     inference:
     - remote::hf::endpoint
@@ -25,4 +24,9 @@ distribution_spec:
     - inline::basic
     - inline::llm-as-judge
     - inline::braintrust
+    tool_runtime:
+    - remote::brave-search
+    - remote::tavily-search
+    - inline::code-interpreter
+    - inline::memory-runtime
 image_type: conda

@@ -5,7 +5,12 @@
 # the root directory of this source tree.
 
 from llama_stack.apis.models.models import ModelType
-from llama_stack.distribution.datatypes import ModelInput, Provider, ShieldInput
+from llama_stack.distribution.datatypes import (
+    ModelInput,
+    Provider,
+    ShieldInput,
+    ToolGroupInput,
+)
 from llama_stack.providers.inline.inference.sentence_transformers import (
     SentenceTransformersInferenceConfig,
 )
@@ -24,6 +29,12 @@ def get_distribution_template() -> DistributionTemplate:
         "eval": ["inline::meta-reference"],
         "datasetio": ["remote::huggingface", "inline::localfs"],
         "scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"],
+        "tool_runtime": [
+            "remote::brave-search",
+            "remote::tavily-search",
+            "inline::code-interpreter",
+            "inline::memory-runtime",
+        ],
     }
     name = "hf-endpoint"
     inference_provider = Provider(
@@ -58,6 +69,20 @@ def get_distribution_template() -> DistributionTemplate:
             "embedding_dimension": 384,
         },
     )
+    default_tool_groups = [
+        ToolGroupInput(
+            toolgroup_id="builtin::websearch",
+            provider_id="tavily-search",
+        ),
+        ToolGroupInput(
+            toolgroup_id="builtin::memory",
+            provider_id="memory-runtime",
+        ),
+        ToolGroupInput(
+            toolgroup_id="builtin::code_interpreter",
+            provider_id="code-interpreter",
+        ),
+    ]
 
     return DistributionTemplate(
         name=name,
@@ -74,6 +99,7 @@ def get_distribution_template() -> DistributionTemplate:
                     "memory": [memory_provider],
                 },
                 default_models=[inference_model, embedding_model],
+                default_tool_groups=default_tool_groups,
             ),
             "run-with-safety.yaml": RunConfigSettings(
                 provider_overrides={
@@ -96,6 +122,7 @@ def get_distribution_template() -> DistributionTemplate:
                     embedding_model,
                 ],
                 default_shields=[ShieldInput(shield_id="${env.SAFETY_MODEL}")],
+                default_tool_groups=default_tool_groups,
             ),
         },
         run_config_env_vars={

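The templates register tool groups without args; per-agent args can then scope a group at agent-creation time, for example pinning the memory tool to one bank. A sketch, assuming the agent config accepts a dict form of a tool group reference; field names and the bank id are placeholders:

```
# Sketch (assumed config shape): scope the memory tool group to a
# specific bank for this agent only.
agent_config = {
    "model": "meta-llama/Llama-3.1-8B-Instruct",  # placeholder
    "instructions": "Answer from the attached memory bank.",
    "toolgroups": [
        # dict form: group name plus args injected into its tool calls
        {"name": "builtin::memory", "args": {"memory_bank_ids": ["docs_bank"]}},
    ],
}
```
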
@@ -1,6 +1,5 @@
 version: '2'
 image_name: hf-endpoint
-docker_image: null
 conda_env: hf-endpoint
 apis:
 - agents
@@ -11,6 +10,7 @@ apis:
 - safety
 - scoring
 - telemetry
+- tool_runtime
 providers:
   inference:
   - provider_id: hf-endpoint
@@ -75,33 +75,50 @@ providers:
     provider_type: inline::braintrust
     config:
       openai_api_key: ${env.OPENAI_API_KEY:}
+  tool_runtime:
+  - provider_id: brave-search
+    provider_type: remote::brave-search
+    config:
+      api_key: ${env.BRAVE_SEARCH_API_KEY:}
+      max_results: 3
+  - provider_id: tavily-search
+    provider_type: remote::tavily-search
+    config:
+      api_key: ${env.TAVILY_SEARCH_API_KEY:}
+      max_results: 3
+  - provider_id: code-interpreter
+    provider_type: inline::code-interpreter
+    config: {}
+  - provider_id: memory-runtime
+    provider_type: inline::memory-runtime
+    config: {}
 metadata_store:
   namespace: null
   type: sqlite
   db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-endpoint}/registry.db
 models:
 - metadata: {}
   model_id: ${env.INFERENCE_MODEL}
   provider_id: hf-endpoint
-  provider_model_id: null
   model_type: llm
 - metadata: {}
   model_id: ${env.SAFETY_MODEL}
   provider_id: hf-endpoint-safety
-  provider_model_id: null
   model_type: llm
 - metadata:
     embedding_dimension: 384
   model_id: all-MiniLM-L6-v2
   provider_id: sentence-transformers
-  provider_model_id: null
   model_type: embedding
 shields:
-- params: null
-  shield_id: ${env.SAFETY_MODEL}
-  provider_id: null
-  provider_shield_id: null
+- shield_id: ${env.SAFETY_MODEL}
 memory_banks: []
 datasets: []
 scoring_fns: []
 eval_tasks: []
+tool_groups:
+- toolgroup_id: builtin::websearch
+  provider_id: tavily-search
+- toolgroup_id: builtin::memory
+  provider_id: memory-runtime
+- toolgroup_id: builtin::code_interpreter
+  provider_id: code-interpreter

@@ -1,6 +1,5 @@
 version: '2'
 image_name: hf-endpoint
-docker_image: null
 conda_env: hf-endpoint
 apis:
 - agents
@@ -11,6 +10,7 @@ apis:
 - safety
 - scoring
 - telemetry
+- tool_runtime
 providers:
   inference:
   - provider_id: hf-endpoint
@@ -70,24 +70,45 @@ providers:
     provider_type: inline::braintrust
     config:
       openai_api_key: ${env.OPENAI_API_KEY:}
+  tool_runtime:
+  - provider_id: brave-search
+    provider_type: remote::brave-search
+    config:
+      api_key: ${env.BRAVE_SEARCH_API_KEY:}
+      max_results: 3
+  - provider_id: tavily-search
+    provider_type: remote::tavily-search
+    config:
+      api_key: ${env.TAVILY_SEARCH_API_KEY:}
+      max_results: 3
+  - provider_id: code-interpreter
+    provider_type: inline::code-interpreter
+    config: {}
+  - provider_id: memory-runtime
+    provider_type: inline::memory-runtime
+    config: {}
 metadata_store:
   namespace: null
   type: sqlite
   db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-endpoint}/registry.db
 models:
 - metadata: {}
   model_id: ${env.INFERENCE_MODEL}
   provider_id: hf-endpoint
-  provider_model_id: null
   model_type: llm
 - metadata:
     embedding_dimension: 384
   model_id: all-MiniLM-L6-v2
   provider_id: sentence-transformers
-  provider_model_id: null
   model_type: embedding
 shields: []
 memory_banks: []
 datasets: []
 scoring_fns: []
 eval_tasks: []
+tool_groups:
+- toolgroup_id: builtin::websearch
+  provider_id: tavily-search
+- toolgroup_id: builtin::memory
+  provider_id: memory-runtime
+- toolgroup_id: builtin::code_interpreter
+  provider_id: code-interpreter

@@ -2,7 +2,6 @@ version: '2'
 name: hf-serverless
 distribution_spec:
   description: Use (an external) Hugging Face Inference Endpoint for running LLM inference
-  docker_image: null
   providers:
     inference:
     - remote::hf::serverless
@@ -25,4 +24,9 @@ distribution_spec:
     - inline::basic
     - inline::llm-as-judge
     - inline::braintrust
+    tool_runtime:
+    - remote::brave-search
+    - remote::tavily-search
+    - inline::code-interpreter
+    - inline::memory-runtime
 image_type: conda

@@ -5,7 +5,12 @@
 # the root directory of this source tree.
 
 from llama_stack.apis.models.models import ModelType
-from llama_stack.distribution.datatypes import ModelInput, Provider, ShieldInput
+from llama_stack.distribution.datatypes import (
+    ModelInput,
+    Provider,
+    ShieldInput,
+    ToolGroupInput,
+)
 from llama_stack.providers.inline.inference.sentence_transformers import (
     SentenceTransformersInferenceConfig,
 )
@@ -24,6 +29,12 @@ def get_distribution_template() -> DistributionTemplate:
         "eval": ["inline::meta-reference"],
         "datasetio": ["remote::huggingface", "inline::localfs"],
         "scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"],
+        "tool_runtime": [
+            "remote::brave-search",
+            "remote::tavily-search",
+            "inline::code-interpreter",
+            "inline::memory-runtime",
+        ],
     }
 
     name = "hf-serverless"
@@ -59,6 +70,20 @@ def get_distribution_template() -> DistributionTemplate:
             "embedding_dimension": 384,
         },
     )
+    default_tool_groups = [
+        ToolGroupInput(
+            toolgroup_id="builtin::websearch",
+            provider_id="tavily-search",
+        ),
+        ToolGroupInput(
+            toolgroup_id="builtin::memory",
+            provider_id="memory-runtime",
+        ),
+        ToolGroupInput(
+            toolgroup_id="builtin::code_interpreter",
+            provider_id="code-interpreter",
+        ),
+    ]
 
     return DistributionTemplate(
         name=name,
@@ -97,6 +122,7 @@ def get_distribution_template() -> DistributionTemplate:
                     embedding_model,
                 ],
                 default_shields=[ShieldInput(shield_id="${env.SAFETY_MODEL}")],
+                default_tool_groups=default_tool_groups,
             ),
         },
         run_config_env_vars={

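For reference, the `ToolGroupInput` used throughout these templates only needs to carry the identifiers seen above. A rough shape, reconstructed from usage in this diff rather than from the actual datatype definition:

```
# Sketch (reconstructed from usage, not the real definition): the
# fields these templates pass to ToolGroupInput.
from dataclasses import dataclass
from typing import Any, Dict, Optional

@dataclass
class ToolGroupInputSketch:
    toolgroup_id: str                       # e.g. "builtin::websearch"
    provider_id: str                        # e.g. "tavily-search"
    args: Optional[Dict[str, Any]] = None   # optional injected call args
```
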
@@ -1,6 +1,5 @@
 version: '2'
 image_name: hf-serverless
-docker_image: null
 conda_env: hf-serverless
 apis:
 - agents
@@ -11,6 +10,7 @@ apis:
 - safety
 - scoring
 - telemetry
+- tool_runtime
 providers:
   inference:
   - provider_id: hf-serverless
@@ -75,33 +75,50 @@ providers:
     provider_type: inline::braintrust
     config:
       openai_api_key: ${env.OPENAI_API_KEY:}
+  tool_runtime:
+  - provider_id: brave-search
+    provider_type: remote::brave-search
+    config:
+      api_key: ${env.BRAVE_SEARCH_API_KEY:}
+      max_results: 3
+  - provider_id: tavily-search
+    provider_type: remote::tavily-search
+    config:
+      api_key: ${env.TAVILY_SEARCH_API_KEY:}
+      max_results: 3
+  - provider_id: code-interpreter
+    provider_type: inline::code-interpreter
+    config: {}
+  - provider_id: memory-runtime
+    provider_type: inline::memory-runtime
+    config: {}
 metadata_store:
   namespace: null
   type: sqlite
   db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-serverless}/registry.db
 models:
 - metadata: {}
   model_id: ${env.INFERENCE_MODEL}
   provider_id: hf-serverless
-  provider_model_id: null
   model_type: llm
 - metadata: {}
   model_id: ${env.SAFETY_MODEL}
   provider_id: hf-serverless-safety
-  provider_model_id: null
   model_type: llm
 - metadata:
     embedding_dimension: 384
   model_id: all-MiniLM-L6-v2
   provider_id: sentence-transformers
-  provider_model_id: null
   model_type: embedding
 shields:
-- params: null
-  shield_id: ${env.SAFETY_MODEL}
-  provider_id: null
-  provider_shield_id: null
+- shield_id: ${env.SAFETY_MODEL}
 memory_banks: []
 datasets: []
 scoring_fns: []
 eval_tasks: []
+tool_groups:
+- toolgroup_id: builtin::websearch
+  provider_id: tavily-search
+- toolgroup_id: builtin::memory
+  provider_id: memory-runtime
+- toolgroup_id: builtin::code_interpreter
+  provider_id: code-interpreter

@@ -1,6 +1,5 @@
 version: '2'
 image_name: hf-serverless
-docker_image: null
 conda_env: hf-serverless
 apis:
 - agents
@@ -11,6 +10,7 @@ apis:
 - safety
 - scoring
 - telemetry
+- tool_runtime
 providers:
   inference:
   - provider_id: hf-serverless
@@ -70,24 +70,39 @@ providers:
     provider_type: inline::braintrust
     config:
       openai_api_key: ${env.OPENAI_API_KEY:}
+  tool_runtime:
+  - provider_id: brave-search
+    provider_type: remote::brave-search
+    config:
+      api_key: ${env.BRAVE_SEARCH_API_KEY:}
+      max_results: 3
+  - provider_id: tavily-search
+    provider_type: remote::tavily-search
+    config:
+      api_key: ${env.TAVILY_SEARCH_API_KEY:}
+      max_results: 3
+  - provider_id: code-interpreter
+    provider_type: inline::code-interpreter
+    config: {}
+  - provider_id: memory-runtime
+    provider_type: inline::memory-runtime
+    config: {}
 metadata_store:
   namespace: null
   type: sqlite
   db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-serverless}/registry.db
 models:
 - metadata: {}
   model_id: ${env.INFERENCE_MODEL}
   provider_id: hf-serverless
-  provider_model_id: null
   model_type: llm
 - metadata:
     embedding_dimension: 384
   model_id: all-MiniLM-L6-v2
   provider_id: sentence-transformers
-  provider_model_id: null
   model_type: embedding
 shields: []
 memory_banks: []
 datasets: []
 scoring_fns: []
 eval_tasks: []
+tool_groups: []

@@ -2,7 +2,6 @@ version: '2'
 name: meta-reference-gpu
 distribution_spec:
   description: Use Meta Reference for running LLM inference
-  docker_image: null
   providers:
     inference:
     - inline::meta-reference
@@ -25,4 +24,9 @@ distribution_spec:
     - inline::basic
     - inline::llm-as-judge
     - inline::braintrust
+    tool_runtime:
+    - remote::brave-search
+    - remote::tavily-search
+    - inline::code-interpreter
+    - inline::memory-runtime
 image_type: conda

@@ -7,8 +7,12 @@
 from pathlib import Path
 
 from llama_stack.apis.models.models import ModelType
-
-from llama_stack.distribution.datatypes import ModelInput, Provider, ShieldInput
+from llama_stack.distribution.datatypes import (
+    ModelInput,
+    Provider,
+    ShieldInput,
+    ToolGroupInput,
+)
 from llama_stack.providers.inline.inference.meta_reference import (
     MetaReferenceInferenceConfig,
 )
@@ -29,6 +33,12 @@ def get_distribution_template() -> DistributionTemplate:
         "eval": ["inline::meta-reference"],
         "datasetio": ["remote::huggingface", "inline::localfs"],
         "scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"],
+        "tool_runtime": [
+            "remote::brave-search",
+            "remote::tavily-search",
+            "inline::code-interpreter",
+            "inline::memory-runtime",
+        ],
     }
     name = "meta-reference-gpu"
     inference_provider = Provider(
@@ -66,6 +76,20 @@ def get_distribution_template() -> DistributionTemplate:
         model_id="${env.SAFETY_MODEL}",
         provider_id="meta-reference-safety",
     )
+    default_tool_groups = [
+        ToolGroupInput(
+            toolgroup_id="builtin::websearch",
+            provider_id="tavily-search",
+        ),
+        ToolGroupInput(
+            toolgroup_id="builtin::memory",
+            provider_id="memory-runtime",
+        ),
+        ToolGroupInput(
+            toolgroup_id="builtin::code_interpreter",
+            provider_id="code-interpreter",
+        ),
+    ]
 
     return DistributionTemplate(
         name=name,
@@ -104,6 +128,7 @@ def get_distribution_template() -> DistributionTemplate:
                     embedding_model,
                 ],
                 default_shields=[ShieldInput(shield_id="${env.SAFETY_MODEL}")],
+                default_tool_groups=default_tool_groups,
             ),
         },
         run_config_env_vars={

@@ -1,6 +1,5 @@
 version: '2'
 image_name: meta-reference-gpu
-docker_image: null
 conda_env: meta-reference-gpu
 apis:
 - agents
@@ -11,6 +10,7 @@ apis:
 - safety
 - scoring
 - telemetry
+- tool_runtime
 providers:
   inference:
   - provider_id: meta-reference-inference
@@ -77,33 +77,50 @@ providers:
     provider_type: inline::braintrust
     config:
       openai_api_key: ${env.OPENAI_API_KEY:}
+  tool_runtime:
+  - provider_id: brave-search
+    provider_type: remote::brave-search
+    config:
+      api_key: ${env.BRAVE_SEARCH_API_KEY:}
+      max_results: 3
+  - provider_id: tavily-search
+    provider_type: remote::tavily-search
+    config:
+      api_key: ${env.TAVILY_SEARCH_API_KEY:}
+      max_results: 3
+  - provider_id: code-interpreter
+    provider_type: inline::code-interpreter
+    config: {}
+  - provider_id: memory-runtime
+    provider_type: inline::memory-runtime
+    config: {}
 metadata_store:
   namespace: null
   type: sqlite
   db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/registry.db
 models:
 - metadata: {}
   model_id: ${env.INFERENCE_MODEL}
   provider_id: meta-reference-inference
-  provider_model_id: null
   model_type: llm
 - metadata: {}
   model_id: ${env.SAFETY_MODEL}
   provider_id: meta-reference-safety
-  provider_model_id: null
   model_type: llm
 - metadata:
     embedding_dimension: 384
   model_id: all-MiniLM-L6-v2
   provider_id: sentence-transformers
-  provider_model_id: null
   model_type: embedding
 shields:
-- params: null
-  shield_id: ${env.SAFETY_MODEL}
-  provider_id: null
-  provider_shield_id: null
+- shield_id: ${env.SAFETY_MODEL}
 memory_banks: []
 datasets: []
 scoring_fns: []
 eval_tasks: []
+tool_groups:
+- toolgroup_id: builtin::websearch
+  provider_id: tavily-search
+- toolgroup_id: builtin::memory
+  provider_id: memory-runtime
+- toolgroup_id: builtin::code_interpreter
+  provider_id: code-interpreter

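Since attachments are now documents, a safety-enabled stack like the one above consumes them on a per-turn basis. A sketch, assuming the turn API accepts a `documents` list with `content`/`mime_type` fields; the ids and URL are placeholders:

```
# Sketch (assumed client API): pass documents on a turn rather than
# the old attachments field.
response = client.agents.turn.create(
    agent_id=agent_id,      # from a previously created agent
    session_id=session_id,  # from a previously created session
    messages=[{"role": "user", "content": "Summarize this file."}],
    documents=[
        {"content": "https://example.com/report.txt", "mime_type": "text/plain"},
    ],
    stream=False,
)
```
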
@@ -1,6 +1,5 @@
 version: '2'
 image_name: meta-reference-gpu
-docker_image: null
 conda_env: meta-reference-gpu
 apis:
 - agents
@@ -11,6 +10,7 @@ apis:
 - safety
 - scoring
 - telemetry
+- tool_runtime
 providers:
   inference:
   - provider_id: meta-reference-inference
@@ -71,24 +71,39 @@ providers:
     provider_type: inline::braintrust
     config:
       openai_api_key: ${env.OPENAI_API_KEY:}
+  tool_runtime:
+  - provider_id: brave-search
+    provider_type: remote::brave-search
+    config:
+      api_key: ${env.BRAVE_SEARCH_API_KEY:}
+      max_results: 3
+  - provider_id: tavily-search
+    provider_type: remote::tavily-search
+    config:
+      api_key: ${env.TAVILY_SEARCH_API_KEY:}
+      max_results: 3
+  - provider_id: code-interpreter
+    provider_type: inline::code-interpreter
+    config: {}
+  - provider_id: memory-runtime
+    provider_type: inline::memory-runtime
+    config: {}
 metadata_store:
   namespace: null
   type: sqlite
   db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/registry.db
 models:
 - metadata: {}
   model_id: ${env.INFERENCE_MODEL}
   provider_id: meta-reference-inference
-  provider_model_id: null
   model_type: llm
 - metadata:
     embedding_dimension: 384
   model_id: all-MiniLM-L6-v2
   provider_id: sentence-transformers
-  provider_model_id: null
   model_type: embedding
 shields: []
 memory_banks: []
 datasets: []
 scoring_fns: []
 eval_tasks: []
+tool_groups: []

@@ -2,7 +2,6 @@ version: '2'
 name: meta-reference-quantized-gpu
 distribution_spec:
   description: Use Meta Reference with fp8, int4 quantization for running LLM inference
-  docker_image: null
   providers:
     inference:
     - inline::meta-reference-quantized
@@ -25,4 +24,9 @@ distribution_spec:
     - inline::basic
     - inline::llm-as-judge
     - inline::braintrust
+    tool_runtime:
+    - remote::brave-search
+    - remote::tavily-search
+    - inline::code-interpreter
+    - inline::memory-runtime
 image_type: conda

@@ -7,8 +7,7 @@
 from pathlib import Path
 
 from llama_stack.apis.models.models import ModelType
-
-from llama_stack.distribution.datatypes import ModelInput, Provider
+from llama_stack.distribution.datatypes import ModelInput, Provider, ToolGroupInput
 from llama_stack.providers.inline.inference.meta_reference import (
     MetaReferenceQuantizedInferenceConfig,
 )
@@ -29,7 +28,27 @@ def get_distribution_template() -> DistributionTemplate:
         "eval": ["inline::meta-reference"],
         "datasetio": ["remote::huggingface", "inline::localfs"],
         "scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"],
+        "tool_runtime": [
+            "remote::brave-search",
+            "remote::tavily-search",
+            "inline::code-interpreter",
+            "inline::memory-runtime",
+        ],
     }
+    default_tool_groups = [
+        ToolGroupInput(
+            toolgroup_id="builtin::websearch",
+            provider_id="tavily-search",
+        ),
+        ToolGroupInput(
+            toolgroup_id="builtin::memory",
+            provider_id="memory-runtime",
+        ),
+        ToolGroupInput(
+            toolgroup_id="builtin::code_interpreter",
+            provider_id="code-interpreter",
+        ),
+    ]
     name = "meta-reference-quantized-gpu"
     inference_provider = Provider(
         provider_id="meta-reference-inference",
@@ -76,6 +95,7 @@ def get_distribution_template() -> DistributionTemplate:
                     "memory": [memory_provider],
                 },
                 default_models=[inference_model, embedding_model],
+                default_tool_groups=default_tool_groups,
             ),
         },
         run_config_env_vars={

@@ -1,6 +1,5 @@
 version: '2'
 image_name: meta-reference-quantized-gpu
-docker_image: null
 conda_env: meta-reference-quantized-gpu
 apis:
 - agents
@@ -11,6 +10,7 @@ apis:
 - safety
 - scoring
 - telemetry
+- tool_runtime
 providers:
   inference:
   - provider_id: meta-reference-inference
@@ -73,24 +73,45 @@ providers:
     provider_type: inline::braintrust
     config:
       openai_api_key: ${env.OPENAI_API_KEY:}
+  tool_runtime:
+  - provider_id: brave-search
+    provider_type: remote::brave-search
+    config:
+      api_key: ${env.BRAVE_SEARCH_API_KEY:}
+      max_results: 3
+  - provider_id: tavily-search
+    provider_type: remote::tavily-search
+    config:
+      api_key: ${env.TAVILY_SEARCH_API_KEY:}
+      max_results: 3
+  - provider_id: code-interpreter
+    provider_type: inline::code-interpreter
+    config: {}
+  - provider_id: memory-runtime
+    provider_type: inline::memory-runtime
+    config: {}
 metadata_store:
   namespace: null
   type: sqlite
   db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-quantized-gpu}/registry.db
 models:
 - metadata: {}
   model_id: ${env.INFERENCE_MODEL}
   provider_id: meta-reference-inference
-  provider_model_id: null
   model_type: llm
 - metadata:
     embedding_dimension: 384
   model_id: all-MiniLM-L6-v2
   provider_id: sentence-transformers
-  provider_model_id: null
   model_type: embedding
 shields: []
 memory_banks: []
 datasets: []
 scoring_fns: []
 eval_tasks: []
+tool_groups:
+- toolgroup_id: builtin::websearch
+  provider_id: tavily-search
+- toolgroup_id: builtin::memory
+  provider_id: memory-runtime
+- toolgroup_id: builtin::code_interpreter
+  provider_id: code-interpreter

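All of these run configs lean on the `${env.NAME:default}` placeholder syntax; an empty default, as in `${env.BRAVE_SEARCH_API_KEY:}`, resolves to an empty string. An illustrative resolver capturing the assumed semantics for the defaulted form only:

```
# Sketch: assumed resolution rule for ${env.NAME:default} placeholders.
# The bare ${env.NAME} form (no default) is not handled here and is
# assumed to require the variable to be set.
import os
import re

_PLACEHOLDER = re.compile(r"\$\{env\.([A-Za-z_][A-Za-z0-9_]*):([^}]*)\}")

def resolve_env_placeholders(text: str) -> str:
    return _PLACEHOLDER.sub(lambda m: os.environ.get(m.group(1), m.group(2)), text)

print(resolve_env_placeholders(
    "db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-quantized-gpu}/registry.db"
))
```
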
@@ -2,7 +2,6 @@ version: '2'
 name: ollama
 distribution_spec:
   description: Use (an external) Ollama server for running LLM inference
-  docker_image: null
   providers:
     inference:
     - remote::ollama
@@ -25,4 +24,9 @@ distribution_spec:
     - inline::basic
     - inline::llm-as-judge
     - inline::braintrust
+    tool_runtime:
+    - remote::brave-search
+    - remote::tavily-search
+    - inline::code-interpreter
+    - inline::memory-runtime
 image_type: conda

@@ -7,8 +7,12 @@
 from pathlib import Path
 
 from llama_stack.apis.models.models import ModelType
-
-from llama_stack.distribution.datatypes import ModelInput, Provider, ShieldInput
+from llama_stack.distribution.datatypes import (
+    ModelInput,
+    Provider,
+    ShieldInput,
+    ToolGroupInput,
+)
 from llama_stack.providers.inline.inference.sentence_transformers import (
     SentenceTransformersInferenceConfig,
 )
@@ -27,6 +31,12 @@ def get_distribution_template() -> DistributionTemplate:
         "eval": ["inline::meta-reference"],
         "datasetio": ["remote::huggingface", "inline::localfs"],
         "scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"],
+        "tool_runtime": [
+            "remote::brave-search",
+            "remote::tavily-search",
+            "inline::code-interpreter",
+            "inline::memory-runtime",
+        ],
     }
     name = "ollama"
     inference_provider = Provider(
@@ -61,6 +71,20 @@ def get_distribution_template() -> DistributionTemplate:
             "embedding_dimension": 384,
         },
     )
+    default_tool_groups = [
+        ToolGroupInput(
+            toolgroup_id="builtin::websearch",
+            provider_id="tavily-search",
+        ),
+        ToolGroupInput(
+            toolgroup_id="builtin::memory",
+            provider_id="memory-runtime",
+        ),
+        ToolGroupInput(
+            toolgroup_id="builtin::code_interpreter",
+            provider_id="code-interpreter",
+        ),
+    ]
 
     return DistributionTemplate(
         name=name,
@@ -92,6 +116,7 @@ def get_distribution_template() -> DistributionTemplate:
                     embedding_model,
                 ],
                 default_shields=[ShieldInput(shield_id="${env.SAFETY_MODEL}")],
+                default_tool_groups=default_tool_groups,
             ),
         },
         run_config_env_vars={

@@ -1,6 +1,5 @@
 version: '2'
 image_name: ollama
-docker_image: null
 conda_env: ollama
 apis:
 - agents
@@ -11,6 +10,7 @@ apis:
 - safety
 - scoring
 - telemetry
+- tool_runtime
 providers:
   inference:
   - provider_id: ollama
@@ -69,33 +69,50 @@ providers:
     provider_type: inline::braintrust
     config:
       openai_api_key: ${env.OPENAI_API_KEY:}
+  tool_runtime:
+  - provider_id: brave-search
+    provider_type: remote::brave-search
+    config:
+      api_key: ${env.BRAVE_SEARCH_API_KEY:}
+      max_results: 3
+  - provider_id: tavily-search
+    provider_type: remote::tavily-search
+    config:
+      api_key: ${env.TAVILY_SEARCH_API_KEY:}
+      max_results: 3
+  - provider_id: code-interpreter
+    provider_type: inline::code-interpreter
+    config: {}
+  - provider_id: memory-runtime
+    provider_type: inline::memory-runtime
+    config: {}
 metadata_store:
   namespace: null
   type: sqlite
   db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/registry.db
 models:
 - metadata: {}
   model_id: ${env.INFERENCE_MODEL}
   provider_id: ollama
-  provider_model_id: null
   model_type: llm
 - metadata: {}
   model_id: ${env.SAFETY_MODEL}
   provider_id: ollama
-  provider_model_id: null
   model_type: llm
 - metadata:
     embedding_dimension: 384
   model_id: all-MiniLM-L6-v2
   provider_id: sentence-transformers
-  provider_model_id: null
   model_type: embedding
 shields:
-- params: null
-  shield_id: ${env.SAFETY_MODEL}
-  provider_id: null
-  provider_shield_id: null
+- shield_id: ${env.SAFETY_MODEL}
 memory_banks: []
 datasets: []
 scoring_fns: []
 eval_tasks: []
+tool_groups:
+- toolgroup_id: builtin::websearch
+  provider_id: tavily-search
+- toolgroup_id: builtin::memory
+  provider_id: memory-runtime
+- toolgroup_id: builtin::code_interpreter
+  provider_id: code-interpreter

@@ -1,6 +1,5 @@
 version: '2'
 image_name: ollama
-docker_image: null
 conda_env: ollama
 apis:
 - agents
@@ -11,6 +10,7 @@ apis:
 - safety
 - scoring
 - telemetry
+- tool_runtime
 providers:
   inference:
   - provider_id: ollama
@@ -69,24 +69,39 @@ providers:
     provider_type: inline::braintrust
     config:
       openai_api_key: ${env.OPENAI_API_KEY:}
+  tool_runtime:
+  - provider_id: brave-search
+    provider_type: remote::brave-search
+    config:
+      api_key: ${env.BRAVE_SEARCH_API_KEY:}
+      max_results: 3
+  - provider_id: tavily-search
+    provider_type: remote::tavily-search
+    config:
+      api_key: ${env.TAVILY_SEARCH_API_KEY:}
+      max_results: 3
+  - provider_id: code-interpreter
+    provider_type: inline::code-interpreter
+    config: {}
+  - provider_id: memory-runtime
+    provider_type: inline::memory-runtime
+    config: {}
 metadata_store:
   namespace: null
   type: sqlite
   db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/registry.db
 models:
 - metadata: {}
   model_id: ${env.INFERENCE_MODEL}
   provider_id: ollama
-  provider_model_id: null
   model_type: llm
 - metadata:
     embedding_dimension: 384
   model_id: all-MiniLM-L6-v2
   provider_id: sentence-transformers
-  provider_model_id: null
   model_type: embedding
 shields: []
 memory_banks: []
 datasets: []
 scoring_fns: []
 eval_tasks: []
+tool_groups: []

@@ -2,7 +2,6 @@ version: '2'
 name: remote-vllm
 distribution_spec:
   description: Use (an external) vLLM server for running LLM inference
-  docker_image: null
   providers:
     inference:
     - remote::vllm
@@ -16,4 +15,9 @@ distribution_spec:
     - inline::meta-reference
     telemetry:
     - inline::meta-reference
+    tool_runtime:
+    - remote::brave-search
+    - remote::tavily-search
+    - inline::code-interpreter
+    - inline::memory-runtime
 image_type: conda

Some files were not shown because too many files have changed in this diff.