working end to end client sdk tests with custom tools

This commit is contained in:
Dinesh Yeduguru 2024-12-23 18:27:55 -08:00
parent 1a66ddc1b5
commit 4dd2f4c363
5 changed files with 304 additions and 149 deletions

View file

@ -3711,6 +3711,12 @@
"type": "string" "type": "string"
} }
}, },
"custom_tools": {
"type": "array",
"items": {
"$ref": "#/components/schemas/CustomToolDef"
}
},
"preprocessing_tools": { "preprocessing_tools": {
"type": "array", "type": "array",
"items": { "items": {
@ -3747,6 +3753,111 @@
"enable_session_persistence" "enable_session_persistence"
] ]
}, },
"CustomToolDef": {
"type": "object",
"properties": {
"type": {
"type": "string",
"const": "custom",
"default": "custom"
},
"name": {
"type": "string"
},
"description": {
"type": "string"
},
"parameters": {
"type": "array",
"items": {
"$ref": "#/components/schemas/ToolParameter"
}
},
"metadata": {
"type": "object",
"additionalProperties": {
"oneOf": [
{
"type": "null"
},
{
"type": "boolean"
},
{
"type": "number"
},
{
"type": "string"
},
{
"type": "array"
},
{
"type": "object"
}
]
}
},
"tool_prompt_format": {
"$ref": "#/components/schemas/ToolPromptFormat",
"default": "json"
}
},
"additionalProperties": false,
"required": [
"type",
"name",
"description",
"parameters",
"metadata"
]
},
"ToolParameter": {
"type": "object",
"properties": {
"name": {
"type": "string"
},
"parameter_type": {
"type": "string"
},
"description": {
"type": "string"
},
"required": {
"type": "boolean"
},
"default": {
"oneOf": [
{
"type": "null"
},
{
"type": "boolean"
},
{
"type": "number"
},
{
"type": "string"
},
{
"type": "array"
},
{
"type": "object"
}
]
}
},
"additionalProperties": false,
"required": [
"name",
"parameter_type",
"description",
"required"
]
},
"CreateAgentRequest": { "CreateAgentRequest": {
"type": "object", "type": "object",
"properties": { "properties": {
@ -4403,39 +4514,16 @@
"session_id" "session_id"
] ]
}, },
"MCPToolGroupDef": { "BuiltInToolDef": {
"type": "object", "type": "object",
"properties": { "properties": {
"type": { "type": {
"type": "string", "type": "string",
"const": "model_context_protocol", "const": "built_in",
"default": "model_context_protocol" "default": "built_in"
}, },
"endpoint": { "built_in_type": {
"$ref": "#/components/schemas/URL" "$ref": "#/components/schemas/BuiltinTool"
}
},
"additionalProperties": false,
"required": [
"type",
"endpoint"
],
"title": "A tool group that is defined by in a model context protocol server. Refer to https://modelcontextprotocol.io/docs/concepts/tools for more information."
},
"ToolDef": {
"type": "object",
"properties": {
"name": {
"type": "string"
},
"description": {
"type": "string"
},
"parameters": {
"type": "array",
"items": {
"$ref": "#/components/schemas/ToolParameter"
}
}, },
"metadata": { "metadata": {
"type": "object", "type": "object",
@ -4461,18 +4549,41 @@
} }
] ]
} }
},
"tool_prompt_format": {
"$ref": "#/components/schemas/ToolPromptFormat",
"default": "json"
} }
}, },
"additionalProperties": false, "additionalProperties": false,
"required": [ "required": [
"name", "type",
"description", "built_in_type"
"parameters", ]
"metadata" },
"MCPToolGroupDef": {
"type": "object",
"properties": {
"type": {
"type": "string",
"const": "model_context_protocol",
"default": "model_context_protocol"
},
"endpoint": {
"$ref": "#/components/schemas/URL"
}
},
"additionalProperties": false,
"required": [
"type",
"endpoint"
],
"title": "A tool group that is defined by in a model context protocol server. Refer to https://modelcontextprotocol.io/docs/concepts/tools for more information."
},
"ToolDef": {
"oneOf": [
{
"$ref": "#/components/schemas/CustomToolDef"
},
{
"$ref": "#/components/schemas/BuiltInToolDef"
}
] ]
}, },
"ToolGroupDef": { "ToolGroupDef": {
@ -4485,52 +4596,6 @@
} }
] ]
}, },
"ToolParameter": {
"type": "object",
"properties": {
"name": {
"type": "string"
},
"parameter_type": {
"type": "string"
},
"description": {
"type": "string"
},
"required": {
"type": "boolean"
},
"default": {
"oneOf": [
{
"type": "null"
},
{
"type": "boolean"
},
{
"type": "number"
},
{
"type": "string"
},
{
"type": "array"
},
{
"type": "object"
}
]
}
},
"additionalProperties": false,
"required": [
"name",
"parameter_type",
"description",
"required"
]
},
"UserDefinedToolGroupDef": { "UserDefinedToolGroupDef": {
"type": "object", "type": "object",
"properties": { "properties": {
@ -5797,6 +5862,9 @@
"tool_group": { "tool_group": {
"type": "string" "type": "string"
}, },
"tool_host": {
"$ref": "#/components/schemas/ToolHost"
},
"description": { "description": {
"type": "string" "type": "string"
}, },
@ -5806,6 +5874,9 @@
"$ref": "#/components/schemas/ToolParameter" "$ref": "#/components/schemas/ToolParameter"
} }
}, },
"built_in_type": {
"$ref": "#/components/schemas/BuiltinTool"
},
"metadata": { "metadata": {
"type": "object", "type": "object",
"additionalProperties": { "additionalProperties": {
@ -5840,12 +5911,22 @@
"required": [ "required": [
"identifier", "identifier",
"provider_resource_id", "provider_resource_id",
"provider_id",
"type", "type",
"tool_group", "tool_group",
"tool_host",
"description", "description",
"parameters" "parameters"
] ]
}, },
"ToolHost": {
"type": "string",
"enum": [
"distribution",
"client",
"model_context_protocol"
]
},
"ToolGroup": { "ToolGroup": {
"type": "object", "type": "object",
"properties": { "properties": {
@ -7942,6 +8023,10 @@
"name": "BenchmarkEvalTaskConfig", "name": "BenchmarkEvalTaskConfig",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/BenchmarkEvalTaskConfig\" />" "description": "<SchemaDefinition schemaRef=\"#/components/schemas/BenchmarkEvalTaskConfig\" />"
}, },
{
"name": "BuiltInToolDef",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/BuiltInToolDef\" />"
},
{ {
"name": "BuiltinTool", "name": "BuiltinTool",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/BuiltinTool\" />" "description": "<SchemaDefinition schemaRef=\"#/components/schemas/BuiltinTool\" />"
@ -8002,6 +8087,10 @@
"name": "CreateAgentTurnRequest", "name": "CreateAgentTurnRequest",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/CreateAgentTurnRequest\" />" "description": "<SchemaDefinition schemaRef=\"#/components/schemas/CreateAgentTurnRequest\" />"
}, },
{
"name": "CustomToolDef",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/CustomToolDef\" />"
},
{ {
"name": "DPOAlignmentConfig", "name": "DPOAlignmentConfig",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/DPOAlignmentConfig\" />" "description": "<SchemaDefinition schemaRef=\"#/components/schemas/DPOAlignmentConfig\" />"
@ -8481,6 +8570,10 @@
{ {
"name": "ToolGroups" "name": "ToolGroups"
}, },
{
"name": "ToolHost",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/ToolHost\" />"
},
{ {
"name": "ToolInvocationResult", "name": "ToolInvocationResult",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/ToolInvocationResult\" />" "description": "<SchemaDefinition schemaRef=\"#/components/schemas/ToolInvocationResult\" />"
@ -8624,6 +8717,7 @@
"BatchCompletionRequest", "BatchCompletionRequest",
"BatchCompletionResponse", "BatchCompletionResponse",
"BenchmarkEvalTaskConfig", "BenchmarkEvalTaskConfig",
"BuiltInToolDef",
"BuiltinTool", "BuiltinTool",
"CancelTrainingJobRequest", "CancelTrainingJobRequest",
"ChatCompletionRequest", "ChatCompletionRequest",
@ -8639,6 +8733,7 @@
"CreateAgentRequest", "CreateAgentRequest",
"CreateAgentSessionRequest", "CreateAgentSessionRequest",
"CreateAgentTurnRequest", "CreateAgentTurnRequest",
"CustomToolDef",
"DPOAlignmentConfig", "DPOAlignmentConfig",
"DataConfig", "DataConfig",
"Dataset", "Dataset",
@ -8746,6 +8841,7 @@
"ToolExecutionStep", "ToolExecutionStep",
"ToolGroup", "ToolGroup",
"ToolGroupDef", "ToolGroupDef",
"ToolHost",
"ToolInvocationResult", "ToolInvocationResult",
"ToolParamDefinition", "ToolParamDefinition",
"ToolParameter", "ToolParameter",

View file

@ -21,6 +21,10 @@ components:
items: items:
type: string type: string
type: array type: array
custom_tools:
items:
$ref: '#/components/schemas/CustomToolDef'
type: array
enable_session_persistence: enable_session_persistence:
type: boolean type: boolean
input_shields: input_shields:
@ -389,6 +393,29 @@ components:
- type - type
- eval_candidate - eval_candidate
type: object type: object
BuiltInToolDef:
additionalProperties: false
properties:
built_in_type:
$ref: '#/components/schemas/BuiltinTool'
metadata:
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
type: object
type:
const: built_in
default: built_in
type: string
required:
- type
- built_in_type
type: object
BuiltinTool: BuiltinTool:
enum: enum:
- brave_search - brave_search
@ -607,6 +634,41 @@ components:
- session_id - session_id
- messages - messages
type: object type: object
CustomToolDef:
additionalProperties: false
properties:
description:
type: string
metadata:
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
type: object
name:
type: string
parameters:
items:
$ref: '#/components/schemas/ToolParameter'
type: array
tool_prompt_format:
$ref: '#/components/schemas/ToolPromptFormat'
default: json
type:
const: custom
default: custom
type: string
required:
- type
- name
- description
- parameters
- metadata
type: object
DPOAlignmentConfig: DPOAlignmentConfig:
additionalProperties: false additionalProperties: false
properties: properties:
@ -2557,6 +2619,8 @@ components:
Tool: Tool:
additionalProperties: false additionalProperties: false
properties: properties:
built_in_type:
$ref: '#/components/schemas/BuiltinTool'
description: description:
type: string type: string
identifier: identifier:
@ -2581,6 +2645,8 @@ components:
type: string type: string
tool_group: tool_group:
type: string type: string
tool_host:
$ref: '#/components/schemas/ToolHost'
tool_prompt_format: tool_prompt_format:
$ref: '#/components/schemas/ToolPromptFormat' $ref: '#/components/schemas/ToolPromptFormat'
default: json default: json
@ -2591,8 +2657,10 @@ components:
required: required:
- identifier - identifier
- provider_resource_id - provider_resource_id
- provider_id
- type - type
- tool_group - tool_group
- tool_host
- description - description
- parameters - parameters
type: object type: object
@ -2661,35 +2729,9 @@ components:
- required - required
type: string type: string
ToolDef: ToolDef:
additionalProperties: false oneOf:
properties: - $ref: '#/components/schemas/CustomToolDef'
description: - $ref: '#/components/schemas/BuiltInToolDef'
type: string
metadata:
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
type: object
name:
type: string
parameters:
items:
$ref: '#/components/schemas/ToolParameter'
type: array
tool_prompt_format:
$ref: '#/components/schemas/ToolPromptFormat'
default: json
required:
- name
- description
- parameters
- metadata
type: object
ToolDefinition: ToolDefinition:
additionalProperties: false additionalProperties: false
properties: properties:
@ -2761,6 +2803,12 @@ components:
oneOf: oneOf:
- $ref: '#/components/schemas/MCPToolGroupDef' - $ref: '#/components/schemas/MCPToolGroupDef'
- $ref: '#/components/schemas/UserDefinedToolGroupDef' - $ref: '#/components/schemas/UserDefinedToolGroupDef'
ToolHost:
enum:
- distribution
- client
- model_context_protocol
type: string
ToolInvocationResult: ToolInvocationResult:
additionalProperties: false additionalProperties: false
properties: properties:
@ -4738,6 +4786,8 @@ tags:
- description: <SchemaDefinition schemaRef="#/components/schemas/BenchmarkEvalTaskConfig" - description: <SchemaDefinition schemaRef="#/components/schemas/BenchmarkEvalTaskConfig"
/> />
name: BenchmarkEvalTaskConfig name: BenchmarkEvalTaskConfig
- description: <SchemaDefinition schemaRef="#/components/schemas/BuiltInToolDef" />
name: BuiltInToolDef
- description: <SchemaDefinition schemaRef="#/components/schemas/BuiltinTool" /> - description: <SchemaDefinition schemaRef="#/components/schemas/BuiltinTool" />
name: BuiltinTool name: BuiltinTool
- description: <SchemaDefinition schemaRef="#/components/schemas/CancelTrainingJobRequest" - description: <SchemaDefinition schemaRef="#/components/schemas/CancelTrainingJobRequest"
@ -4797,6 +4847,8 @@ tags:
- description: <SchemaDefinition schemaRef="#/components/schemas/CreateAgentTurnRequest" - description: <SchemaDefinition schemaRef="#/components/schemas/CreateAgentTurnRequest"
/> />
name: CreateAgentTurnRequest name: CreateAgentTurnRequest
- description: <SchemaDefinition schemaRef="#/components/schemas/CustomToolDef" />
name: CustomToolDef
- description: <SchemaDefinition schemaRef="#/components/schemas/DPOAlignmentConfig" - description: <SchemaDefinition schemaRef="#/components/schemas/DPOAlignmentConfig"
/> />
name: DPOAlignmentConfig name: DPOAlignmentConfig
@ -5111,6 +5163,8 @@ tags:
- description: <SchemaDefinition schemaRef="#/components/schemas/ToolGroupDef" /> - description: <SchemaDefinition schemaRef="#/components/schemas/ToolGroupDef" />
name: ToolGroupDef name: ToolGroupDef
- name: ToolGroups - name: ToolGroups
- description: <SchemaDefinition schemaRef="#/components/schemas/ToolHost" />
name: ToolHost
- description: <SchemaDefinition schemaRef="#/components/schemas/ToolInvocationResult" - description: <SchemaDefinition schemaRef="#/components/schemas/ToolInvocationResult"
/> />
name: ToolInvocationResult name: ToolInvocationResult
@ -5224,6 +5278,7 @@ x-tagGroups:
- BatchCompletionRequest - BatchCompletionRequest
- BatchCompletionResponse - BatchCompletionResponse
- BenchmarkEvalTaskConfig - BenchmarkEvalTaskConfig
- BuiltInToolDef
- BuiltinTool - BuiltinTool
- CancelTrainingJobRequest - CancelTrainingJobRequest
- ChatCompletionRequest - ChatCompletionRequest
@ -5239,6 +5294,7 @@ x-tagGroups:
- CreateAgentRequest - CreateAgentRequest
- CreateAgentSessionRequest - CreateAgentSessionRequest
- CreateAgentTurnRequest - CreateAgentTurnRequest
- CustomToolDef
- DPOAlignmentConfig - DPOAlignmentConfig
- DataConfig - DataConfig
- Dataset - Dataset
@ -5346,6 +5402,7 @@ x-tagGroups:
- ToolExecutionStep - ToolExecutionStep
- ToolGroup - ToolGroup
- ToolGroupDef - ToolGroupDef
- ToolHost
- ToolInvocationResult - ToolInvocationResult
- ToolParamDefinition - ToolParamDefinition
- ToolParameter - ToolParameter

View file

@ -18,13 +18,11 @@ from typing import (
runtime_checkable, runtime_checkable,
) )
from llama_models.llama3.api.datatypes import ToolParamDefinition
from llama_models.schema_utils import json_schema_type, webmethod from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, ConfigDict, Field from pydantic import BaseModel, ConfigDict, Field
from typing_extensions import Annotated from typing_extensions import Annotated
from llama_stack.apis.common.content_types import URL, InterleavedContent from llama_stack.apis.common.content_types import URL, InterleavedContent
from llama_stack.apis.common.deployment_types import RestAPIExecutionConfig
from llama_stack.apis.inference import ( from llama_stack.apis.inference import (
CompletionMessage, CompletionMessage,
SamplingParams, SamplingParams,
@ -140,6 +138,7 @@ class AgentConfigCommon(BaseModel):
input_shields: Optional[List[str]] = Field(default_factory=list) input_shields: Optional[List[str]] = Field(default_factory=list)
output_shields: Optional[List[str]] = Field(default_factory=list) output_shields: Optional[List[str]] = Field(default_factory=list)
available_tools: Optional[List[str]] = Field(default_factory=list) available_tools: Optional[List[str]] = Field(default_factory=list)
custom_tools: Optional[List[CustomToolDef]] = Field(default_factory=list)
preprocessing_tools: Optional[List[str]] = Field(default_factory=list) preprocessing_tools: Optional[List[str]] = Field(default_factory=list)
tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto) tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
tool_prompt_format: Optional[ToolPromptFormat] = Field( tool_prompt_format: Optional[ToolPromptFormat] = Field(

View file

@ -400,6 +400,10 @@ class ChatAgent(ShieldRunnerMixin):
output_attachments = [] output_attachments = []
n_iter = 0 n_iter = 0
# Build a map of custom tools to their definitions for faster lookup
custom_tools = {}
for tool in self.agent_config.custom_tools:
custom_tools[tool.name] = tool
while True: while True:
msg = input_messages[-1] msg = input_messages[-1]
@ -530,6 +534,9 @@ class ChatAgent(ShieldRunnerMixin):
else: else:
log.info(f"{str(message)}") log.info(f"{str(message)}")
tool_call = message.tool_calls[0] tool_call = message.tool_calls[0]
if tool_call.tool_name in custom_tools:
yield message
return
step_id = str(uuid.uuid4()) step_id = str(uuid.uuid4())
yield AgentTurnResponseStreamChunk( yield AgentTurnResponseStreamChunk(
@ -619,6 +626,22 @@ class ChatAgent(ShieldRunnerMixin):
async def _get_tools(self) -> List[ToolDefinition]: async def _get_tools(self) -> List[ToolDefinition]:
ret = [] ret = []
for tool in self.agent_config.custom_tools:
params = {}
for param in tool.parameters:
params[param.name] = ToolParamDefinition(
param_type=param.parameter_type,
description=param.description,
required=param.required,
default=param.default,
)
ret.append(
ToolDefinition(
tool_name=tool.name,
description=tool.description,
parameters=params,
)
)
for tool_name in self.agent_config.available_tools: for tool_name in self.agent_config.available_tools:
tool = await self.tool_groups_api.get_tool(tool_name) tool = await self.tool_groups_api.get_tool(tool_name)
if tool.built_in_type: if tool.built_in_type:

View file

@ -9,16 +9,13 @@ from typing import Dict, List
from uuid import uuid4 from uuid import uuid4
import pytest import pytest
from llama_stack.providers.tests.env import get_env_or_fail
from llama_stack_client.lib.agents.agent import Agent from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client.lib.agents.custom_tool import CustomTool from llama_stack_client.lib.agents.custom_tool import CustomTool
from llama_stack_client.lib.agents.event_logger import EventLogger from llama_stack_client.lib.agents.event_logger import EventLogger
from llama_stack_client.types import CompletionMessage, ToolResponseMessage from llama_stack_client.types import ToolResponseMessage
from llama_stack_client.types.agent_create_params import AgentConfig from llama_stack_client.types.agent_create_params import AgentConfig
from llama_stack_client.types.tool_param_definition_param import ( from llama_stack_client.types.custom_tool_def import Parameter
ToolParamDefinitionParam, from llama_stack_client.types.shared.completion_message import CompletionMessage
)
class TestCustomTool(CustomTool): class TestCustomTool(CustomTool):
@ -54,13 +51,17 @@ class TestCustomTool(CustomTool):
def get_description(self) -> str: def get_description(self) -> str:
return "Get the boiling point of a imaginary liquids (eg. polyjuice)" return "Get the boiling point of a imaginary liquids (eg. polyjuice)"
def get_params_definition(self) -> Dict[str, ToolParamDefinitionParam]: def get_params_definition(self) -> Dict[str, Parameter]:
return { return {
"liquid_name": ToolParamDefinitionParam( "liquid_name": Parameter(
param_type="string", description="The name of the liquid", required=True name="liquid_name",
parameter_type="string",
description="The name of the liquid",
required=True,
), ),
"celcius": ToolParamDefinitionParam( "celcius": Parameter(
param_type="boolean", name="celcius",
parameter_type="boolean",
description="Whether to return the boiling point in Celcius", description="Whether to return the boiling point in Celcius",
required=False, required=False,
), ),
@ -203,37 +204,16 @@ def test_builtin_tool_code_execution(llama_stack_client, agent_config):
def test_custom_tool(llama_stack_client, agent_config): def test_custom_tool(llama_stack_client, agent_config):
custom_tool = TestCustomTool()
agent_config = { agent_config = {
**agent_config, **agent_config,
"model": "meta-llama/Llama-3.2-3B-Instruct", "model": "meta-llama/Llama-3.2-3B-Instruct",
"tools": [ "available_tools": ["brave_search"],
{ "custom_tools": [custom_tool.get_tool_definition()],
"type": "brave_search",
"engine": "brave",
"api_key": get_env_or_fail("BRAVE_SEARCH_API_KEY"),
},
{
"function_name": "get_boiling_point",
"description": "Get the boiling point of a imaginary liquids (eg. polyjuice)",
"parameters": {
"liquid_name": {
"param_type": "str",
"description": "The name of the liquid",
"required": True,
},
"celcius": {
"param_type": "boolean",
"description": "Whether to return the boiling point in Celcius",
"required": False,
},
},
"type": "function_call",
},
],
"tool_prompt_format": "python_list", "tool_prompt_format": "python_list",
} }
agent = Agent(llama_stack_client, agent_config, custom_tools=(TestCustomTool(),)) agent = Agent(llama_stack_client, agent_config, custom_tools=(custom_tool,))
session_id = agent.create_session(f"test-session-{uuid4()}") session_id = agent.create_session(f"test-session-{uuid4()}")
response = agent.create_turn( response = agent.create_turn(