mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-03 17:29:01 +00:00
rename UserDefinedToolDef to ToolDef
This commit is contained in:
parent
db0b2a60c1
commit
e3775eb6f6
8 changed files with 180 additions and 322 deletions
|
@ -3714,7 +3714,7 @@
|
||||||
"client_tools": {
|
"client_tools": {
|
||||||
"type": "array",
|
"type": "array",
|
||||||
"items": {
|
"items": {
|
||||||
"$ref": "#/components/schemas/UserDefinedToolDef"
|
"$ref": "#/components/schemas/ToolDef"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"tool_choice": {
|
"tool_choice": {
|
||||||
|
@ -3792,60 +3792,9 @@
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"ToolParameter": {
|
"ToolDef": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
"name": {
|
|
||||||
"type": "string"
|
|
||||||
},
|
|
||||||
"parameter_type": {
|
|
||||||
"type": "string"
|
|
||||||
},
|
|
||||||
"description": {
|
|
||||||
"type": "string"
|
|
||||||
},
|
|
||||||
"required": {
|
|
||||||
"type": "boolean"
|
|
||||||
},
|
|
||||||
"default": {
|
|
||||||
"oneOf": [
|
|
||||||
{
|
|
||||||
"type": "null"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "boolean"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "number"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "string"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "array"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "object"
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"additionalProperties": false,
|
|
||||||
"required": [
|
|
||||||
"name",
|
|
||||||
"parameter_type",
|
|
||||||
"description",
|
|
||||||
"required"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"UserDefinedToolDef": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"type": {
|
|
||||||
"type": "string",
|
|
||||||
"const": "user_defined",
|
|
||||||
"default": "user_defined"
|
|
||||||
},
|
|
||||||
"name": {
|
"name": {
|
||||||
"type": "string"
|
"type": "string"
|
||||||
},
|
},
|
||||||
|
@ -3890,11 +3839,53 @@
|
||||||
},
|
},
|
||||||
"additionalProperties": false,
|
"additionalProperties": false,
|
||||||
"required": [
|
"required": [
|
||||||
"type",
|
"name"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"ToolParameter": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"name": {
|
||||||
|
"type": "string"
|
||||||
|
},
|
||||||
|
"parameter_type": {
|
||||||
|
"type": "string"
|
||||||
|
},
|
||||||
|
"description": {
|
||||||
|
"type": "string"
|
||||||
|
},
|
||||||
|
"required": {
|
||||||
|
"type": "boolean"
|
||||||
|
},
|
||||||
|
"default": {
|
||||||
|
"oneOf": [
|
||||||
|
{
|
||||||
|
"type": "null"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "boolean"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "number"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "string"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "array"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "object"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"additionalProperties": false,
|
||||||
|
"required": [
|
||||||
"name",
|
"name",
|
||||||
|
"parameter_type",
|
||||||
"description",
|
"description",
|
||||||
"parameters",
|
"required"
|
||||||
"metadata"
|
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"CreateAgentRequest": {
|
"CreateAgentRequest": {
|
||||||
|
@ -4589,49 +4580,6 @@
|
||||||
"session_id"
|
"session_id"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"BuiltInToolDef": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"type": {
|
|
||||||
"type": "string",
|
|
||||||
"const": "built_in",
|
|
||||||
"default": "built_in"
|
|
||||||
},
|
|
||||||
"built_in_type": {
|
|
||||||
"$ref": "#/components/schemas/BuiltinTool"
|
|
||||||
},
|
|
||||||
"metadata": {
|
|
||||||
"type": "object",
|
|
||||||
"additionalProperties": {
|
|
||||||
"oneOf": [
|
|
||||||
{
|
|
||||||
"type": "null"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "boolean"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "number"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "string"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "array"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "object"
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"additionalProperties": false,
|
|
||||||
"required": [
|
|
||||||
"type",
|
|
||||||
"built_in_type"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"MCPToolGroupDef": {
|
"MCPToolGroupDef": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
|
@ -4651,16 +4599,6 @@
|
||||||
],
|
],
|
||||||
"title": "A tool group that is defined by in a model context protocol server. Refer to https://modelcontextprotocol.io/docs/concepts/tools for more information."
|
"title": "A tool group that is defined by in a model context protocol server. Refer to https://modelcontextprotocol.io/docs/concepts/tools for more information."
|
||||||
},
|
},
|
||||||
"ToolDef": {
|
|
||||||
"oneOf": [
|
|
||||||
{
|
|
||||||
"$ref": "#/components/schemas/UserDefinedToolDef"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"$ref": "#/components/schemas/BuiltInToolDef"
|
|
||||||
}
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"ToolGroupDef": {
|
"ToolGroupDef": {
|
||||||
"oneOf": [
|
"oneOf": [
|
||||||
{
|
{
|
||||||
|
@ -7436,7 +7374,7 @@
|
||||||
"tool_group_id": {
|
"tool_group_id": {
|
||||||
"type": "string"
|
"type": "string"
|
||||||
},
|
},
|
||||||
"tool_group": {
|
"tool_group_def": {
|
||||||
"$ref": "#/components/schemas/ToolGroupDef"
|
"$ref": "#/components/schemas/ToolGroupDef"
|
||||||
},
|
},
|
||||||
"provider_id": {
|
"provider_id": {
|
||||||
|
@ -7446,7 +7384,7 @@
|
||||||
"additionalProperties": false,
|
"additionalProperties": false,
|
||||||
"required": [
|
"required": [
|
||||||
"tool_group_id",
|
"tool_group_id",
|
||||||
"tool_group"
|
"tool_group_def"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"RunEvalRequest": {
|
"RunEvalRequest": {
|
||||||
|
@ -8098,10 +8036,6 @@
|
||||||
"name": "BenchmarkEvalTaskConfig",
|
"name": "BenchmarkEvalTaskConfig",
|
||||||
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/BenchmarkEvalTaskConfig\" />"
|
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/BenchmarkEvalTaskConfig\" />"
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"name": "BuiltInToolDef",
|
|
||||||
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/BuiltInToolDef\" />"
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"name": "BuiltinTool",
|
"name": "BuiltinTool",
|
||||||
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/BuiltinTool\" />"
|
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/BuiltinTool\" />"
|
||||||
|
@ -8708,10 +8642,6 @@
|
||||||
"name": "UnstructuredLogEvent",
|
"name": "UnstructuredLogEvent",
|
||||||
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/UnstructuredLogEvent\" />"
|
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/UnstructuredLogEvent\" />"
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"name": "UserDefinedToolDef",
|
|
||||||
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/UserDefinedToolDef\" />"
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"name": "UserDefinedToolGroupDef",
|
"name": "UserDefinedToolGroupDef",
|
||||||
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/UserDefinedToolGroupDef\" />"
|
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/UserDefinedToolGroupDef\" />"
|
||||||
|
@ -8792,7 +8722,6 @@
|
||||||
"BatchCompletionRequest",
|
"BatchCompletionRequest",
|
||||||
"BatchCompletionResponse",
|
"BatchCompletionResponse",
|
||||||
"BenchmarkEvalTaskConfig",
|
"BenchmarkEvalTaskConfig",
|
||||||
"BuiltInToolDef",
|
|
||||||
"BuiltinTool",
|
"BuiltinTool",
|
||||||
"CancelTrainingJobRequest",
|
"CancelTrainingJobRequest",
|
||||||
"ChatCompletionRequest",
|
"ChatCompletionRequest",
|
||||||
|
@ -8931,7 +8860,6 @@
|
||||||
"UnregisterModelRequest",
|
"UnregisterModelRequest",
|
||||||
"UnregisterToolGroupRequest",
|
"UnregisterToolGroupRequest",
|
||||||
"UnstructuredLogEvent",
|
"UnstructuredLogEvent",
|
||||||
"UserDefinedToolDef",
|
|
||||||
"UserDefinedToolGroupDef",
|
"UserDefinedToolGroupDef",
|
||||||
"UserMessage",
|
"UserMessage",
|
||||||
"VectorMemoryBank",
|
"VectorMemoryBank",
|
||||||
|
|
|
@ -19,7 +19,7 @@ components:
|
||||||
properties:
|
properties:
|
||||||
client_tools:
|
client_tools:
|
||||||
items:
|
items:
|
||||||
$ref: '#/components/schemas/UserDefinedToolDef'
|
$ref: '#/components/schemas/ToolDef'
|
||||||
type: array
|
type: array
|
||||||
enable_session_persistence:
|
enable_session_persistence:
|
||||||
type: boolean
|
type: boolean
|
||||||
|
@ -396,29 +396,6 @@ components:
|
||||||
- type
|
- type
|
||||||
- eval_candidate
|
- eval_candidate
|
||||||
type: object
|
type: object
|
||||||
BuiltInToolDef:
|
|
||||||
additionalProperties: false
|
|
||||||
properties:
|
|
||||||
built_in_type:
|
|
||||||
$ref: '#/components/schemas/BuiltinTool'
|
|
||||||
metadata:
|
|
||||||
additionalProperties:
|
|
||||||
oneOf:
|
|
||||||
- type: 'null'
|
|
||||||
- type: boolean
|
|
||||||
- type: number
|
|
||||||
- type: string
|
|
||||||
- type: array
|
|
||||||
- type: object
|
|
||||||
type: object
|
|
||||||
type:
|
|
||||||
const: built_in
|
|
||||||
default: built_in
|
|
||||||
type: string
|
|
||||||
required:
|
|
||||||
- type
|
|
||||||
- built_in_type
|
|
||||||
type: object
|
|
||||||
BuiltinTool:
|
BuiltinTool:
|
||||||
enum:
|
enum:
|
||||||
- brave_search
|
- brave_search
|
||||||
|
@ -1929,13 +1906,13 @@ components:
|
||||||
properties:
|
properties:
|
||||||
provider_id:
|
provider_id:
|
||||||
type: string
|
type: string
|
||||||
tool_group:
|
tool_group_def:
|
||||||
$ref: '#/components/schemas/ToolGroupDef'
|
$ref: '#/components/schemas/ToolGroupDef'
|
||||||
tool_group_id:
|
tool_group_id:
|
||||||
type: string
|
type: string
|
||||||
required:
|
required:
|
||||||
- tool_group_id
|
- tool_group_id
|
||||||
- tool_group
|
- tool_group_def
|
||||||
type: object
|
type: object
|
||||||
ResponseFormat:
|
ResponseFormat:
|
||||||
oneOf:
|
oneOf:
|
||||||
|
@ -2716,9 +2693,32 @@ components:
|
||||||
- required
|
- required
|
||||||
type: string
|
type: string
|
||||||
ToolDef:
|
ToolDef:
|
||||||
oneOf:
|
additionalProperties: false
|
||||||
- $ref: '#/components/schemas/UserDefinedToolDef'
|
properties:
|
||||||
- $ref: '#/components/schemas/BuiltInToolDef'
|
description:
|
||||||
|
type: string
|
||||||
|
metadata:
|
||||||
|
additionalProperties:
|
||||||
|
oneOf:
|
||||||
|
- type: 'null'
|
||||||
|
- type: boolean
|
||||||
|
- type: number
|
||||||
|
- type: string
|
||||||
|
- type: array
|
||||||
|
- type: object
|
||||||
|
type: object
|
||||||
|
name:
|
||||||
|
type: string
|
||||||
|
parameters:
|
||||||
|
items:
|
||||||
|
$ref: '#/components/schemas/ToolParameter'
|
||||||
|
type: array
|
||||||
|
tool_prompt_format:
|
||||||
|
$ref: '#/components/schemas/ToolPromptFormat'
|
||||||
|
default: json
|
||||||
|
required:
|
||||||
|
- name
|
||||||
|
type: object
|
||||||
ToolDefinition:
|
ToolDefinition:
|
||||||
additionalProperties: false
|
additionalProperties: false
|
||||||
properties:
|
properties:
|
||||||
|
@ -3087,41 +3087,6 @@ components:
|
||||||
- message
|
- message
|
||||||
- severity
|
- severity
|
||||||
type: object
|
type: object
|
||||||
UserDefinedToolDef:
|
|
||||||
additionalProperties: false
|
|
||||||
properties:
|
|
||||||
description:
|
|
||||||
type: string
|
|
||||||
metadata:
|
|
||||||
additionalProperties:
|
|
||||||
oneOf:
|
|
||||||
- type: 'null'
|
|
||||||
- type: boolean
|
|
||||||
- type: number
|
|
||||||
- type: string
|
|
||||||
- type: array
|
|
||||||
- type: object
|
|
||||||
type: object
|
|
||||||
name:
|
|
||||||
type: string
|
|
||||||
parameters:
|
|
||||||
items:
|
|
||||||
$ref: '#/components/schemas/ToolParameter'
|
|
||||||
type: array
|
|
||||||
tool_prompt_format:
|
|
||||||
$ref: '#/components/schemas/ToolPromptFormat'
|
|
||||||
default: json
|
|
||||||
type:
|
|
||||||
const: user_defined
|
|
||||||
default: user_defined
|
|
||||||
type: string
|
|
||||||
required:
|
|
||||||
- type
|
|
||||||
- name
|
|
||||||
- description
|
|
||||||
- parameters
|
|
||||||
- metadata
|
|
||||||
type: object
|
|
||||||
UserDefinedToolGroupDef:
|
UserDefinedToolGroupDef:
|
||||||
additionalProperties: false
|
additionalProperties: false
|
||||||
properties:
|
properties:
|
||||||
|
@ -4823,8 +4788,6 @@ tags:
|
||||||
- description: <SchemaDefinition schemaRef="#/components/schemas/BenchmarkEvalTaskConfig"
|
- description: <SchemaDefinition schemaRef="#/components/schemas/BenchmarkEvalTaskConfig"
|
||||||
/>
|
/>
|
||||||
name: BenchmarkEvalTaskConfig
|
name: BenchmarkEvalTaskConfig
|
||||||
- description: <SchemaDefinition schemaRef="#/components/schemas/BuiltInToolDef" />
|
|
||||||
name: BuiltInToolDef
|
|
||||||
- description: <SchemaDefinition schemaRef="#/components/schemas/BuiltinTool" />
|
- description: <SchemaDefinition schemaRef="#/components/schemas/BuiltinTool" />
|
||||||
name: BuiltinTool
|
name: BuiltinTool
|
||||||
- description: <SchemaDefinition schemaRef="#/components/schemas/CancelTrainingJobRequest"
|
- description: <SchemaDefinition schemaRef="#/components/schemas/CancelTrainingJobRequest"
|
||||||
|
@ -5251,9 +5214,6 @@ tags:
|
||||||
- description: <SchemaDefinition schemaRef="#/components/schemas/UnstructuredLogEvent"
|
- description: <SchemaDefinition schemaRef="#/components/schemas/UnstructuredLogEvent"
|
||||||
/>
|
/>
|
||||||
name: UnstructuredLogEvent
|
name: UnstructuredLogEvent
|
||||||
- description: <SchemaDefinition schemaRef="#/components/schemas/UserDefinedToolDef"
|
|
||||||
/>
|
|
||||||
name: UserDefinedToolDef
|
|
||||||
- description: <SchemaDefinition schemaRef="#/components/schemas/UserDefinedToolGroupDef"
|
- description: <SchemaDefinition schemaRef="#/components/schemas/UserDefinedToolGroupDef"
|
||||||
/>
|
/>
|
||||||
name: UserDefinedToolGroupDef
|
name: UserDefinedToolGroupDef
|
||||||
|
@ -5316,7 +5276,6 @@ x-tagGroups:
|
||||||
- BatchCompletionRequest
|
- BatchCompletionRequest
|
||||||
- BatchCompletionResponse
|
- BatchCompletionResponse
|
||||||
- BenchmarkEvalTaskConfig
|
- BenchmarkEvalTaskConfig
|
||||||
- BuiltInToolDef
|
|
||||||
- BuiltinTool
|
- BuiltinTool
|
||||||
- CancelTrainingJobRequest
|
- CancelTrainingJobRequest
|
||||||
- ChatCompletionRequest
|
- ChatCompletionRequest
|
||||||
|
@ -5455,7 +5414,6 @@ x-tagGroups:
|
||||||
- UnregisterModelRequest
|
- UnregisterModelRequest
|
||||||
- UnregisterToolGroupRequest
|
- UnregisterToolGroupRequest
|
||||||
- UnstructuredLogEvent
|
- UnstructuredLogEvent
|
||||||
- UserDefinedToolDef
|
|
||||||
- UserDefinedToolGroupDef
|
- UserDefinedToolGroupDef
|
||||||
- UserMessage
|
- UserMessage
|
||||||
- VectorMemoryBank
|
- VectorMemoryBank
|
||||||
|
|
|
@ -36,7 +36,7 @@ from llama_stack.apis.inference import (
|
||||||
)
|
)
|
||||||
from llama_stack.apis.memory import MemoryBank
|
from llama_stack.apis.memory import MemoryBank
|
||||||
from llama_stack.apis.safety import SafetyViolation
|
from llama_stack.apis.safety import SafetyViolation
|
||||||
from llama_stack.apis.tools import UserDefinedToolDef
|
from llama_stack.apis.tools import ToolDef
|
||||||
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
||||||
|
|
||||||
|
|
||||||
|
@ -157,7 +157,7 @@ class AgentConfigCommon(BaseModel):
|
||||||
input_shields: Optional[List[str]] = Field(default_factory=list)
|
input_shields: Optional[List[str]] = Field(default_factory=list)
|
||||||
output_shields: Optional[List[str]] = Field(default_factory=list)
|
output_shields: Optional[List[str]] = Field(default_factory=list)
|
||||||
tools: Optional[List[AgentTool]] = Field(default_factory=list)
|
tools: Optional[List[AgentTool]] = Field(default_factory=list)
|
||||||
client_tools: Optional[List[UserDefinedToolDef]] = Field(default_factory=list)
|
client_tools: Optional[List[ToolDef]] = Field(default_factory=list)
|
||||||
tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
|
tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
|
||||||
tool_prompt_format: Optional[ToolPromptFormat] = Field(
|
tool_prompt_format: Optional[ToolPromptFormat] = Field(
|
||||||
default=ToolPromptFormat.json
|
default=ToolPromptFormat.json
|
||||||
|
|
|
@ -48,30 +48,16 @@ class Tool(Resource):
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class UserDefinedToolDef(BaseModel):
|
class ToolDef(BaseModel):
|
||||||
type: Literal["user_defined"] = "user_defined"
|
|
||||||
name: str
|
name: str
|
||||||
description: str
|
description: Optional[str] = None
|
||||||
parameters: List[ToolParameter]
|
parameters: Optional[List[ToolParameter]] = None
|
||||||
metadata: Dict[str, Any]
|
metadata: Optional[Dict[str, Any]] = None
|
||||||
tool_prompt_format: Optional[ToolPromptFormat] = Field(
|
tool_prompt_format: Optional[ToolPromptFormat] = Field(
|
||||||
default=ToolPromptFormat.json
|
default=ToolPromptFormat.json
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class BuiltInToolDef(BaseModel):
|
|
||||||
type: Literal["built_in"] = "built_in"
|
|
||||||
built_in_type: BuiltinTool
|
|
||||||
metadata: Optional[Dict[str, Any]] = None
|
|
||||||
|
|
||||||
|
|
||||||
ToolDef = register_schema(
|
|
||||||
Annotated[Union[UserDefinedToolDef, BuiltInToolDef], Field(discriminator="type")],
|
|
||||||
name="ToolDef",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class MCPToolGroupDef(BaseModel):
|
class MCPToolGroupDef(BaseModel):
|
||||||
"""
|
"""
|
||||||
|
@ -100,7 +86,7 @@ ToolGroupDef = register_schema(
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class ToolGroupInput(BaseModel):
|
class ToolGroupInput(BaseModel):
|
||||||
tool_group_id: str
|
tool_group_id: str
|
||||||
tool_group: ToolGroupDef
|
tool_group_def: ToolGroupDef
|
||||||
provider_id: Optional[str] = None
|
provider_id: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
@ -127,7 +113,7 @@ class ToolGroups(Protocol):
|
||||||
async def register_tool_group(
|
async def register_tool_group(
|
||||||
self,
|
self,
|
||||||
tool_group_id: str,
|
tool_group_id: str,
|
||||||
tool_group: ToolGroupDef,
|
tool_group_def: ToolGroupDef,
|
||||||
provider_id: Optional[str] = None,
|
provider_id: Optional[str] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Register a tool group"""
|
"""Register a tool group"""
|
||||||
|
|
|
@ -27,15 +27,12 @@ from llama_stack.apis.scoring_functions import (
|
||||||
)
|
)
|
||||||
from llama_stack.apis.shields import Shield, Shields
|
from llama_stack.apis.shields import Shield, Shields
|
||||||
from llama_stack.apis.tools import (
|
from llama_stack.apis.tools import (
|
||||||
BuiltInToolDef,
|
|
||||||
MCPToolGroupDef,
|
MCPToolGroupDef,
|
||||||
Tool,
|
Tool,
|
||||||
ToolGroup,
|
ToolGroup,
|
||||||
ToolGroupDef,
|
ToolGroupDef,
|
||||||
ToolGroups,
|
ToolGroups,
|
||||||
ToolHost,
|
ToolHost,
|
||||||
ToolPromptFormat,
|
|
||||||
UserDefinedToolDef,
|
|
||||||
UserDefinedToolGroupDef,
|
UserDefinedToolGroupDef,
|
||||||
)
|
)
|
||||||
from llama_stack.distribution.datatypes import (
|
from llama_stack.distribution.datatypes import (
|
||||||
|
@ -514,7 +511,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
||||||
async def register_tool_group(
|
async def register_tool_group(
|
||||||
self,
|
self,
|
||||||
tool_group_id: str,
|
tool_group_id: str,
|
||||||
tool_group: ToolGroupDef,
|
tool_group_def: ToolGroupDef,
|
||||||
provider_id: Optional[str] = None,
|
provider_id: Optional[str] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
tools = []
|
tools = []
|
||||||
|
@ -528,47 +525,31 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
||||||
provider_id = list(self.impls_by_provider_id.keys())[0]
|
provider_id = list(self.impls_by_provider_id.keys())[0]
|
||||||
|
|
||||||
# parse tool group to the type if dict
|
# parse tool group to the type if dict
|
||||||
tool_group = TypeAdapter(ToolGroupDef).validate_python(tool_group)
|
tool_group_def = TypeAdapter(ToolGroupDef).validate_python(tool_group_def)
|
||||||
if isinstance(tool_group, MCPToolGroupDef):
|
if isinstance(tool_group_def, MCPToolGroupDef):
|
||||||
tool_defs = await self.impls_by_provider_id[provider_id].discover_tools(
|
tool_defs = await self.impls_by_provider_id[provider_id].discover_tools(
|
||||||
tool_group
|
tool_group_def
|
||||||
)
|
)
|
||||||
tool_host = ToolHost.model_context_protocol
|
tool_host = ToolHost.model_context_protocol
|
||||||
elif isinstance(tool_group, UserDefinedToolGroupDef):
|
elif isinstance(tool_group_def, UserDefinedToolGroupDef):
|
||||||
tool_defs = tool_group.tools
|
tool_defs = tool_group_def.tools
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown tool group: {tool_group}")
|
raise ValueError(f"Unknown tool group: {tool_group_def}")
|
||||||
|
|
||||||
for tool_def in tool_defs:
|
for tool_def in tool_defs:
|
||||||
if isinstance(tool_def, UserDefinedToolDef):
|
tools.append(
|
||||||
tools.append(
|
Tool(
|
||||||
Tool(
|
identifier=tool_def.name,
|
||||||
identifier=tool_def.name,
|
tool_group=tool_group_id,
|
||||||
tool_group=tool_group_id,
|
description=tool_def.description or "",
|
||||||
description=tool_def.description,
|
parameters=tool_def.parameters or [],
|
||||||
parameters=tool_def.parameters,
|
provider_id=provider_id,
|
||||||
provider_id=provider_id,
|
tool_prompt_format=tool_def.tool_prompt_format,
|
||||||
tool_prompt_format=tool_def.tool_prompt_format,
|
provider_resource_id=tool_def.name,
|
||||||
provider_resource_id=tool_def.name,
|
metadata=tool_def.metadata,
|
||||||
metadata=tool_def.metadata,
|
tool_host=tool_host,
|
||||||
tool_host=tool_host,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
elif isinstance(tool_def, BuiltInToolDef):
|
|
||||||
tools.append(
|
|
||||||
Tool(
|
|
||||||
identifier=tool_def.built_in_type.value,
|
|
||||||
tool_group=tool_group_id,
|
|
||||||
built_in_type=tool_def.built_in_type,
|
|
||||||
description="",
|
|
||||||
parameters=[],
|
|
||||||
provider_id=provider_id,
|
|
||||||
tool_prompt_format=ToolPromptFormat.json,
|
|
||||||
provider_resource_id=tool_def.built_in_type.value,
|
|
||||||
metadata=tool_def.metadata,
|
|
||||||
tool_host=tool_host,
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
)
|
||||||
for tool in tools:
|
for tool in tools:
|
||||||
existing_tool = await self.get_tool(tool.identifier)
|
existing_tool = await self.get_tool(tool.identifier)
|
||||||
# Compare existing and new object if one exists
|
# Compare existing and new object if one exists
|
||||||
|
|
|
@ -387,7 +387,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
extra_args = tool_args.get("memory", {})
|
extra_args = tool_args.get("memory", {})
|
||||||
args = {
|
tool_args = {
|
||||||
# Query memory with the last message's content
|
# Query memory with the last message's content
|
||||||
"query": input_messages[-1],
|
"query": input_messages[-1],
|
||||||
**extra_args,
|
**extra_args,
|
||||||
|
@ -396,8 +396,8 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
session_info = await self.storage.get_session_info(session_id)
|
session_info = await self.storage.get_session_info(session_id)
|
||||||
# if the session has a memory bank id, let the memory tool use it
|
# if the session has a memory bank id, let the memory tool use it
|
||||||
if session_info.memory_bank_id:
|
if session_info.memory_bank_id:
|
||||||
args["memory_bank_id"] = session_info.memory_bank_id
|
tool_args["memory_bank_id"] = session_info.memory_bank_id
|
||||||
serialized_args = tracing.serialize_value(args)
|
serialized_args = tracing.serialize_value(tool_args)
|
||||||
yield AgentTurnResponseStreamChunk(
|
yield AgentTurnResponseStreamChunk(
|
||||||
event=AgentTurnResponseEvent(
|
event=AgentTurnResponseEvent(
|
||||||
payload=AgentTurnResponseStepProgressPayload(
|
payload=AgentTurnResponseStepProgressPayload(
|
||||||
|
@ -416,7 +416,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
)
|
)
|
||||||
result = await self.tool_runtime_api.invoke_tool(
|
result = await self.tool_runtime_api.invoke_tool(
|
||||||
tool_name="memory",
|
tool_name="memory",
|
||||||
args=args,
|
args=tool_args,
|
||||||
)
|
)
|
||||||
|
|
||||||
yield AgentTurnResponseStreamChunk(
|
yield AgentTurnResponseStreamChunk(
|
||||||
|
@ -482,11 +482,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
async for chunk in await self.inference_api.chat_completion(
|
async for chunk in await self.inference_api.chat_completion(
|
||||||
self.agent_config.model,
|
self.agent_config.model,
|
||||||
input_messages,
|
input_messages,
|
||||||
tools=[
|
tools=[tool for tool in tool_defs.values()],
|
||||||
tool
|
|
||||||
for tool in tool_defs.values()
|
|
||||||
if tool.tool_name != "memory"
|
|
||||||
],
|
|
||||||
tool_prompt_format=self.agent_config.tool_prompt_format,
|
tool_prompt_format=self.agent_config.tool_prompt_format,
|
||||||
stream=True,
|
stream=True,
|
||||||
sampling_params=sampling_params,
|
sampling_params=sampling_params,
|
||||||
|
@ -728,10 +724,17 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
tool_def = await self.tool_groups_api.get_tool(tool_name)
|
tool_def = await self.tool_groups_api.get_tool(tool_name)
|
||||||
|
if tool_def is None:
|
||||||
|
raise ValueError(f"Tool {tool_name} not found")
|
||||||
|
|
||||||
if tool_def.built_in_type:
|
if tool_def.identifier.startswith("builtin::"):
|
||||||
ret[tool_def.built_in_type] = ToolDefinition(
|
built_in_type = tool_def.identifier[len("builtin::") :]
|
||||||
tool_name=tool_def.built_in_type
|
if built_in_type == "web_search":
|
||||||
|
built_in_type = "brave_search"
|
||||||
|
if built_in_type not in BuiltinTool.__members__:
|
||||||
|
raise ValueError(f"Unknown built-in tool: {built_in_type}")
|
||||||
|
ret[built_in_type] = ToolDefinition(
|
||||||
|
tool_name=BuiltinTool(built_in_type)
|
||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
@ -759,52 +762,52 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
tool_defs: Dict[str, ToolDefinition],
|
tool_defs: Dict[str, ToolDefinition],
|
||||||
) -> None:
|
) -> None:
|
||||||
memory_tool = tool_defs.get("memory", None)
|
memory_tool = tool_defs.get("memory", None)
|
||||||
code_interpreter_tool = tool_defs.get(BuiltinTool.code_interpreter, None)
|
code_interpreter_tool = tool_defs.get("code_interpreter", None)
|
||||||
if documents:
|
content_items = []
|
||||||
content_items = []
|
url_items = []
|
||||||
url_items = []
|
pattern = re.compile("^(https?://|file://|data:)")
|
||||||
pattern = re.compile("^(https?://|file://|data:)")
|
for d in documents:
|
||||||
for d in documents:
|
if isinstance(d.content, URL):
|
||||||
if isinstance(d.content, URL):
|
url_items.append(d.content)
|
||||||
url_items.append(d.content)
|
elif pattern.match(d.content):
|
||||||
elif pattern.match(d.content):
|
url_items.append(URL(uri=d.content))
|
||||||
url_items.append(URL(uri=d.content))
|
|
||||||
else:
|
|
||||||
content_items.append(d)
|
|
||||||
|
|
||||||
# Save the contents to a tempdir and use its path as a URL if code interpreter is present
|
|
||||||
if code_interpreter_tool:
|
|
||||||
for c in content_items:
|
|
||||||
temp_file_path = os.path.join(
|
|
||||||
self.tempdir, f"{make_random_string()}.txt"
|
|
||||||
)
|
|
||||||
with open(temp_file_path, "w") as temp_file:
|
|
||||||
temp_file.write(c.content)
|
|
||||||
url_items.append(URL(uri=f"file://{temp_file_path}"))
|
|
||||||
|
|
||||||
if memory_tool and code_interpreter_tool:
|
|
||||||
# if both memory and code_interpreter are available, we download the URLs
|
|
||||||
# and attach the data to the last message.
|
|
||||||
msg = await attachment_message(self.tempdir, url_items)
|
|
||||||
input_messages.append(msg)
|
|
||||||
# Since memory is present, add all the data to the memory bank
|
|
||||||
await self.add_to_session_memory_bank(session_id, documents)
|
|
||||||
elif code_interpreter_tool:
|
|
||||||
# if only code_interpreter is available, we download the URLs to a tempdir
|
|
||||||
# and attach the path to them as a message to inference with the
|
|
||||||
# assumption that the model invokes the code_interpreter tool with the path
|
|
||||||
msg = await attachment_message(self.tempdir, url_items)
|
|
||||||
input_messages.append(msg)
|
|
||||||
elif memory_tool:
|
|
||||||
# if only memory is available, we load the data from the URLs and content items to the memory bank
|
|
||||||
await self.add_to_session_memory_bank(session_id, documents)
|
|
||||||
else:
|
else:
|
||||||
# if no memory or code_interpreter tool is available,
|
content_items.append(d)
|
||||||
# we try to load the data from the URLs and content items as a message to inference
|
|
||||||
# and add it to the last message's context
|
# Save the contents to a tempdir and use its path as a URL if code interpreter is present
|
||||||
input_messages[-1].context = content_items + await load_data_from_urls(
|
if code_interpreter_tool:
|
||||||
url_items
|
for c in content_items:
|
||||||
|
temp_file_path = os.path.join(
|
||||||
|
self.tempdir, f"{make_random_string()}.txt"
|
||||||
)
|
)
|
||||||
|
with open(temp_file_path, "w") as temp_file:
|
||||||
|
temp_file.write(c.content)
|
||||||
|
url_items.append(URL(uri=f"file://{temp_file_path}"))
|
||||||
|
|
||||||
|
if memory_tool and code_interpreter_tool:
|
||||||
|
# if both memory and code_interpreter are available, we download the URLs
|
||||||
|
# and attach the data to the last message.
|
||||||
|
msg = await attachment_message(self.tempdir, url_items)
|
||||||
|
input_messages.append(msg)
|
||||||
|
# Since memory is present, add all the data to the memory bank
|
||||||
|
await self.add_to_session_memory_bank(session_id, documents)
|
||||||
|
elif code_interpreter_tool:
|
||||||
|
# if only code_interpreter is available, we download the URLs to a tempdir
|
||||||
|
# and attach the path to them as a message to inference with the
|
||||||
|
# assumption that the model invokes the code_interpreter tool with the path
|
||||||
|
msg = await attachment_message(self.tempdir, url_items)
|
||||||
|
input_messages.append(msg)
|
||||||
|
elif memory_tool:
|
||||||
|
# if only memory is available, we load the data from the URLs and content items to the memory bank
|
||||||
|
await self.add_to_session_memory_bank(session_id, documents)
|
||||||
|
else:
|
||||||
|
# if no memory or code_interpreter tool is available,
|
||||||
|
# we try to load the data from the URLs and content items as a message to inference
|
||||||
|
# and add it to the last message's context
|
||||||
|
input_messages[-1].context = "\n".join(
|
||||||
|
[doc.content for doc in content_items]
|
||||||
|
+ await load_data_from_urls(url_items)
|
||||||
|
)
|
||||||
|
|
||||||
async def _ensure_memory_bank(self, session_id: str) -> str:
|
async def _ensure_memory_bank(self, session_id: str) -> str:
|
||||||
session_info = await self.storage.get_session_info(session_id)
|
session_info = await self.storage.get_session_info(session_id)
|
||||||
|
@ -909,7 +912,10 @@ async def execute_tool_call_maybe(
|
||||||
tool_call = message.tool_calls[0]
|
tool_call = message.tool_calls[0]
|
||||||
name = tool_call.tool_name
|
name = tool_call.tool_name
|
||||||
if isinstance(name, BuiltinTool):
|
if isinstance(name, BuiltinTool):
|
||||||
name = name.value
|
if name == BuiltinTool.brave_search:
|
||||||
|
name = "builtin::web_search"
|
||||||
|
else:
|
||||||
|
name = "builtin::" + name.value
|
||||||
result = await tool_runtime_api.invoke_tool(
|
result = await tool_runtime_api.invoke_tool(
|
||||||
tool_name=name,
|
tool_name=name,
|
||||||
args=dict(
|
args=dict(
|
||||||
|
|
|
@ -30,8 +30,7 @@ class CodeInterpreterToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def register_tool(self, tool: Tool):
|
async def register_tool(self, tool: Tool):
|
||||||
if tool.identifier != "code_interpreter":
|
pass
|
||||||
raise ValueError(f"Tool identifier {tool.identifier} is not supported")
|
|
||||||
|
|
||||||
async def unregister_tool(self, tool_id: str) -> None:
|
async def unregister_tool(self, tool_id: str) -> None:
|
||||||
return
|
return
|
||||||
|
|
|
@ -17,7 +17,7 @@ from llama_stack_client.types.agent_create_params import AgentConfig
|
||||||
from llama_stack_client.types.agents.turn_create_params import Document as AgentDocument
|
from llama_stack_client.types.agents.turn_create_params import Document as AgentDocument
|
||||||
from llama_stack_client.types.memory_insert_params import Document
|
from llama_stack_client.types.memory_insert_params import Document
|
||||||
from llama_stack_client.types.shared.completion_message import CompletionMessage
|
from llama_stack_client.types.shared.completion_message import CompletionMessage
|
||||||
from llama_stack_client.types.tool_def_param import UserDefinedToolDefParameter
|
from llama_stack_client.types.tool_def_param import Parameter
|
||||||
|
|
||||||
|
|
||||||
class TestClientTool(ClientTool):
|
class TestClientTool(ClientTool):
|
||||||
|
@ -53,15 +53,15 @@ class TestClientTool(ClientTool):
|
||||||
def get_description(self) -> str:
|
def get_description(self) -> str:
|
||||||
return "Get the boiling point of imaginary liquids (eg. polyjuice)"
|
return "Get the boiling point of imaginary liquids (eg. polyjuice)"
|
||||||
|
|
||||||
def get_params_definition(self) -> Dict[str, UserDefinedToolDefParameter]:
|
def get_params_definition(self) -> Dict[str, Parameter]:
|
||||||
return {
|
return {
|
||||||
"liquid_name": UserDefinedToolDefParameter(
|
"liquid_name": Parameter(
|
||||||
name="liquid_name",
|
name="liquid_name",
|
||||||
parameter_type="string",
|
parameter_type="string",
|
||||||
description="The name of the liquid",
|
description="The name of the liquid",
|
||||||
required=True,
|
required=True,
|
||||||
),
|
),
|
||||||
"celcius": UserDefinedToolDefParameter(
|
"celcius": Parameter(
|
||||||
name="celcius",
|
name="celcius",
|
||||||
parameter_type="boolean",
|
parameter_type="boolean",
|
||||||
description="Whether to return the boiling point in Celcius",
|
description="Whether to return the boiling point in Celcius",
|
||||||
|
@ -149,11 +149,11 @@ def test_agent_simple(llama_stack_client, agent_config):
|
||||||
assert "I can't" in logs_str
|
assert "I can't" in logs_str
|
||||||
|
|
||||||
|
|
||||||
def test_builtin_tool_brave_search(llama_stack_client, agent_config):
|
def test_builtin_tool_web_search(llama_stack_client, agent_config):
|
||||||
agent_config = {
|
agent_config = {
|
||||||
**agent_config,
|
**agent_config,
|
||||||
"tools": [
|
"tools": [
|
||||||
"brave_search",
|
"builtin::web_search",
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
agent = Agent(llama_stack_client, agent_config)
|
agent = Agent(llama_stack_client, agent_config)
|
||||||
|
@ -182,7 +182,7 @@ def test_builtin_tool_code_execution(llama_stack_client, agent_config):
|
||||||
agent_config = {
|
agent_config = {
|
||||||
**agent_config,
|
**agent_config,
|
||||||
"tools": [
|
"tools": [
|
||||||
"code_interpreter",
|
"builtin::code_interpreter",
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
agent = Agent(llama_stack_client, agent_config)
|
agent = Agent(llama_stack_client, agent_config)
|
||||||
|
@ -209,9 +209,9 @@ def test_code_execution(llama_stack_client):
|
||||||
model="meta-llama/Llama-3.1-70B-Instruct",
|
model="meta-llama/Llama-3.1-70B-Instruct",
|
||||||
instructions="You are a helpful assistant",
|
instructions="You are a helpful assistant",
|
||||||
tools=[
|
tools=[
|
||||||
"code_interpreter",
|
"builtin::code_interpreter",
|
||||||
],
|
],
|
||||||
tool_choice="auto",
|
tool_choice="required",
|
||||||
input_shields=[],
|
input_shields=[],
|
||||||
output_shields=[],
|
output_shields=[],
|
||||||
enable_session_persistence=False,
|
enable_session_persistence=False,
|
||||||
|
@ -242,7 +242,7 @@ def test_code_execution(llama_stack_client):
|
||||||
)
|
)
|
||||||
logs = [str(log) for log in EventLogger().log(response) if log is not None]
|
logs = [str(log) for log in EventLogger().log(response) if log is not None]
|
||||||
logs_str = "".join(logs)
|
logs_str = "".join(logs)
|
||||||
print(logs_str)
|
assert "Tool:code_interpreter" in logs_str
|
||||||
|
|
||||||
|
|
||||||
def test_custom_tool(llama_stack_client, agent_config):
|
def test_custom_tool(llama_stack_client, agent_config):
|
||||||
|
@ -250,7 +250,7 @@ def test_custom_tool(llama_stack_client, agent_config):
|
||||||
agent_config = {
|
agent_config = {
|
||||||
**agent_config,
|
**agent_config,
|
||||||
"model": "meta-llama/Llama-3.2-3B-Instruct",
|
"model": "meta-llama/Llama-3.2-3B-Instruct",
|
||||||
"tools": ["brave_search"],
|
"tools": ["builtin::web_search"],
|
||||||
"client_tools": [client_tool.get_tool_definition()],
|
"client_tools": [client_tool.get_tool_definition()],
|
||||||
"tool_prompt_format": "python_list",
|
"tool_prompt_format": "python_list",
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue