working end to end client sdk tests

This commit is contained in:
Dinesh Yeduguru 2024-12-23 11:04:24 -08:00
parent 2ad67529ef
commit 0155700ea6
5 changed files with 947 additions and 1039 deletions

File diff suppressed because it is too large Load diff

View file

@ -17,6 +17,10 @@ components:
AgentConfig:
additionalProperties: false
properties:
available_tools:
items:
type: string
type: array
enable_session_persistence:
type: boolean
input_shields:
@ -34,6 +38,10 @@ components:
items:
type: string
type: array
preprocessing_tools:
items:
type: string
type: array
sampling_params:
$ref: '#/components/schemas/SamplingParams'
tool_choice:
@ -42,16 +50,6 @@ components:
tool_prompt_format:
$ref: '#/components/schemas/ToolPromptFormat'
default: json
tools:
items:
oneOf:
- $ref: '#/components/schemas/SearchToolDefinition'
- $ref: '#/components/schemas/WolframAlphaToolDefinition'
- $ref: '#/components/schemas/PhotogenToolDefinition'
- $ref: '#/components/schemas/CodeInterpreterToolDefinition'
- $ref: '#/components/schemas/FunctionCallToolDefinition'
- $ref: '#/components/schemas/MemoryToolDefinition'
type: array
required:
- max_infer_iters
- model
@ -490,30 +488,6 @@ components:
type: object
Checkpoint:
description: Checkpoint created during training runs
CodeInterpreterToolDefinition:
additionalProperties: false
properties:
enable_inline_code_execution:
default: true
type: boolean
input_shields:
items:
type: string
type: array
output_shields:
items:
type: string
type: array
remote_execution:
$ref: '#/components/schemas/RestAPIExecutionConfig'
type:
const: code_interpreter
default: code_interpreter
type: string
required:
- type
- enable_inline_code_execution
type: object
CompletionMessage:
additionalProperties: false
properties:
@ -729,6 +703,14 @@ components:
- agent_id
- session_id
type: object
DiscoverToolsRequest:
additionalProperties: false
properties:
tool_group:
$ref: '#/components/schemas/ToolGroupDef'
required:
- tool_group
type: object
EfficiencyConfig:
additionalProperties: false
properties:
@ -862,37 +844,6 @@ components:
- scoring_functions
- task_config
type: object
FunctionCallToolDefinition:
additionalProperties: false
properties:
description:
type: string
function_name:
type: string
input_shields:
items:
type: string
type: array
output_shields:
items:
type: string
type: array
parameters:
additionalProperties:
$ref: '#/components/schemas/ToolParamDefinition'
type: object
remote_execution:
$ref: '#/components/schemas/RestAPIExecutionConfig'
type:
const: function_call
default: function_call
type: string
required:
- type
- function_name
- description
- parameters
type: object
GetAgentsSessionRequest:
additionalProperties: false
properties:
@ -1017,6 +968,25 @@ components:
oneOf:
- $ref: '#/components/schemas/ImageContentItem'
- $ref: '#/components/schemas/TextContentItem'
InvokeToolRequest:
additionalProperties: false
properties:
args:
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
type: object
tool_name:
type: string
required:
- tool_name
- args
type: object
Job:
additionalProperties: false
properties:
@ -1190,6 +1160,21 @@ components:
- rank
- alpha
type: object
MCPToolGroupDef:
additionalProperties: false
properties:
endpoint:
$ref: '#/components/schemas/URL'
type:
const: model_context_protocol
default: model_context_protocol
type: string
required:
- type
- endpoint
title: A tool group that is defined by in a model context protocol server. Refer
to https://modelcontextprotocol.io/docs/concepts/tools for more information.
type: object
MemoryBankDocument:
additionalProperties: false
properties:
@ -1250,135 +1235,6 @@ components:
- memory_bank_ids
- inserted_context
type: object
MemoryToolDefinition:
additionalProperties: false
properties:
input_shields:
items:
type: string
type: array
max_chunks:
default: 10
type: integer
max_tokens_in_context:
default: 4096
type: integer
memory_bank_configs:
items:
oneOf:
- additionalProperties: false
properties:
bank_id:
type: string
type:
const: vector
default: vector
type: string
required:
- bank_id
- type
type: object
- additionalProperties: false
properties:
bank_id:
type: string
keys:
items:
type: string
type: array
type:
const: keyvalue
default: keyvalue
type: string
required:
- bank_id
- type
- keys
type: object
- additionalProperties: false
properties:
bank_id:
type: string
type:
const: keyword
default: keyword
type: string
required:
- bank_id
- type
type: object
- additionalProperties: false
properties:
bank_id:
type: string
entities:
items:
type: string
type: array
type:
const: graph
default: graph
type: string
required:
- bank_id
- type
- entities
type: object
type: array
output_shields:
items:
type: string
type: array
query_generator_config:
oneOf:
- additionalProperties: false
properties:
sep:
default: ' '
type: string
type:
const: default
default: default
type: string
required:
- type
- sep
type: object
- additionalProperties: false
properties:
model:
type: string
template:
type: string
type:
const: llm
default: llm
type: string
required:
- type
- model
- template
type: object
- additionalProperties: false
properties:
type:
const: custom
default: custom
type: string
required:
- type
type: object
type:
const: memory
default: memory
type: string
required:
- type
- memory_bank_configs
- query_generator_config
- max_tokens_in_context
- max_chunks
type: object
Message:
oneOf:
- $ref: '#/components/schemas/UserMessage'
@ -1621,26 +1477,6 @@ components:
required:
- type
type: object
PhotogenToolDefinition:
additionalProperties: false
properties:
input_shields:
items:
type: string
type: array
output_shields:
items:
type: string
type: array
remote_execution:
$ref: '#/components/schemas/RestAPIExecutionConfig'
type:
const: photogen
default: photogen
type: string
required:
- type
type: object
PostTrainingJob:
additionalProperties: false
properties:
@ -2039,6 +1875,19 @@ components:
required:
- shield_id
type: object
RegisterToolGroupRequest:
additionalProperties: false
properties:
provider_id:
type: string
tool_group:
$ref: '#/components/schemas/ToolGroupDef'
tool_group_id:
type: string
required:
- tool_group_id
- tool_group
type: object
ResponseFormat:
oneOf:
- additionalProperties: false
@ -2081,54 +1930,6 @@ components:
- type
- bnf
type: object
RestAPIExecutionConfig:
additionalProperties: false
properties:
body:
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
type: object
headers:
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
type: object
method:
$ref: '#/components/schemas/RestAPIMethod'
params:
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
type: object
url:
$ref: '#/components/schemas/URL'
required:
- url
- method
type: object
RestAPIMethod:
enum:
- GET
- POST
- PUT
- DELETE
type: string
RouteInfo:
additionalProperties: false
properties:
@ -2399,37 +2200,6 @@ components:
- score_rows
- aggregated_results
type: object
SearchToolDefinition:
additionalProperties: false
properties:
api_key:
type: string
engine:
default: brave
enum:
- bing
- brave
- tavily
type: string
input_shields:
items:
type: string
type: array
output_shields:
items:
type: string
type: array
remote_execution:
$ref: '#/components/schemas/RestAPIExecutionConfig'
type:
const: brave_search
default: brave_search
type: string
required:
- type
- api_key
- engine
type: object
Session:
additionalProperties: false
properties:
@ -2784,6 +2554,48 @@ components:
required:
- logprobs_by_token
type: object
Tool:
additionalProperties: false
properties:
description:
type: string
identifier:
type: string
metadata:
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
type: object
parameters:
items:
$ref: '#/components/schemas/ToolParameter'
type: array
provider_id:
type: string
provider_resource_id:
type: string
tool_group:
type: string
tool_prompt_format:
$ref: '#/components/schemas/ToolPromptFormat'
default: json
type:
const: tool
default: tool
type: string
required:
- identifier
- provider_resource_id
- type
- tool_group
- description
- parameters
type: object
ToolCall:
additionalProperties: false
properties:
@ -2848,6 +2660,36 @@ components:
- auto
- required
type: string
ToolDef:
additionalProperties: false
properties:
description:
type: string
metadata:
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
type: object
name:
type: string
parameters:
items:
$ref: '#/components/schemas/ToolParameter'
type: array
tool_prompt_format:
$ref: '#/components/schemas/ToolPromptFormat'
default: json
required:
- name
- description
- parameters
- metadata
type: object
ToolDefinition:
additionalProperties: false
properties:
@ -2896,6 +2738,41 @@ components:
- tool_calls
- tool_responses
type: object
ToolGroup:
additionalProperties: false
properties:
identifier:
type: string
provider_id:
type: string
provider_resource_id:
type: string
type:
const: tool_group
default: tool_group
type: string
required:
- identifier
- provider_resource_id
- provider_id
- type
type: object
ToolGroupDef:
oneOf:
- $ref: '#/components/schemas/MCPToolGroupDef'
- $ref: '#/components/schemas/UserDefinedToolGroupDef'
ToolInvocationResult:
additionalProperties: false
properties:
content:
$ref: '#/components/schemas/InterleavedContent'
error_code:
type: integer
error_message:
type: string
required:
- content
type: object
ToolParamDefinition:
additionalProperties: false
properties:
@ -2917,6 +2794,31 @@ components:
required:
- param_type
type: object
ToolParameter:
additionalProperties: false
properties:
default:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
description:
type: string
name:
type: string
parameter_type:
type: string
required:
type: boolean
required:
- name
- parameter_type
- description
- required
type: object
ToolPromptFormat:
description: "`json` --\n Refers to the json format for calling tools.\n\
\ The json format takes the form like\n {\n \"type\": \"function\"\
@ -3091,6 +2993,14 @@ components:
required:
- model_id
type: object
UnregisterToolGroupRequest:
additionalProperties: false
properties:
tool_group_id:
type: string
required:
- tool_group_id
type: object
UnstructuredLogEvent:
additionalProperties: false
properties:
@ -3127,6 +3037,21 @@ components:
- message
- severity
type: object
UserDefinedToolGroupDef:
additionalProperties: false
properties:
tools:
items:
$ref: '#/components/schemas/ToolDef'
type: array
type:
const: user_defined
default: user_defined
type: string
required:
- type
- tools
type: object
UserMessage:
additionalProperties: false
properties:
@ -3209,29 +3134,6 @@ components:
- warn
- error
type: string
WolframAlphaToolDefinition:
additionalProperties: false
properties:
api_key:
type: string
input_shields:
items:
type: string
type: array
output_shields:
items:
type: string
type: array
remote_execution:
$ref: '#/components/schemas/RestAPIExecutionConfig'
type:
const: wolfram_alpha
default: wolfram_alpha
type: string
required:
- type
- api_key
type: object
info:
description: "This is the specification of the Llama Stack that provides\n \
\ a set of endpoints and their corresponding interfaces that are tailored\
@ -4869,9 +4771,6 @@ tags:
<SchemaDefinition schemaRef="#/components/schemas/Checkpoint" />'
name: Checkpoint
- description: <SchemaDefinition schemaRef="#/components/schemas/CodeInterpreterToolDefinition"
/>
name: CodeInterpreterToolDefinition
- description: <SchemaDefinition schemaRef="#/components/schemas/CompletionMessage"
/>
name: CompletionMessage
@ -4913,6 +4812,9 @@ tags:
- description: <SchemaDefinition schemaRef="#/components/schemas/DeleteAgentsSessionRequest"
/>
name: DeleteAgentsSessionRequest
- description: <SchemaDefinition schemaRef="#/components/schemas/DiscoverToolsRequest"
/>
name: DiscoverToolsRequest
- description: <SchemaDefinition schemaRef="#/components/schemas/EfficiencyConfig"
/>
name: EfficiencyConfig
@ -4932,9 +4834,6 @@ tags:
- description: <SchemaDefinition schemaRef="#/components/schemas/EvaluateRowsRequest"
/>
name: EvaluateRowsRequest
- description: <SchemaDefinition schemaRef="#/components/schemas/FunctionCallToolDefinition"
/>
name: FunctionCallToolDefinition
- description: <SchemaDefinition schemaRef="#/components/schemas/GetAgentsSessionRequest"
/>
name: GetAgentsSessionRequest
@ -4965,6 +4864,9 @@ tags:
- description: <SchemaDefinition schemaRef="#/components/schemas/InterleavedContentItem"
/>
name: InterleavedContentItem
- description: <SchemaDefinition schemaRef="#/components/schemas/InvokeToolRequest"
/>
name: InvokeToolRequest
- description: <SchemaDefinition schemaRef="#/components/schemas/Job" />
name: Job
- description: <SchemaDefinition schemaRef="#/components/schemas/JobCancelRequest"
@ -4995,6 +4897,12 @@ tags:
- description: <SchemaDefinition schemaRef="#/components/schemas/LoraFinetuningConfig"
/>
name: LoraFinetuningConfig
- description: 'A tool group that is defined by in a model context protocol server.
Refer to https://modelcontextprotocol.io/docs/concepts/tools for more information.
<SchemaDefinition schemaRef="#/components/schemas/MCPToolGroupDef" />'
name: MCPToolGroupDef
- name: Memory
- description: <SchemaDefinition schemaRef="#/components/schemas/MemoryBankDocument"
/>
@ -5003,9 +4911,6 @@ tags:
- description: <SchemaDefinition schemaRef="#/components/schemas/MemoryRetrievalStep"
/>
name: MemoryRetrievalStep
- description: <SchemaDefinition schemaRef="#/components/schemas/MemoryToolDefinition"
/>
name: MemoryToolDefinition
- description: <SchemaDefinition schemaRef="#/components/schemas/Message" />
name: Message
- description: <SchemaDefinition schemaRef="#/components/schemas/MetricEvent" />
@ -5027,9 +4932,6 @@ tags:
name: PaginatedRowsResult
- description: <SchemaDefinition schemaRef="#/components/schemas/ParamType" />
name: ParamType
- description: <SchemaDefinition schemaRef="#/components/schemas/PhotogenToolDefinition"
/>
name: PhotogenToolDefinition
- name: PostTraining (Coming Soon)
- description: <SchemaDefinition schemaRef="#/components/schemas/PostTrainingJob"
/>
@ -5092,13 +4994,11 @@ tags:
- description: <SchemaDefinition schemaRef="#/components/schemas/RegisterShieldRequest"
/>
name: RegisterShieldRequest
- description: <SchemaDefinition schemaRef="#/components/schemas/RegisterToolGroupRequest"
/>
name: RegisterToolGroupRequest
- description: <SchemaDefinition schemaRef="#/components/schemas/ResponseFormat" />
name: ResponseFormat
- description: <SchemaDefinition schemaRef="#/components/schemas/RestAPIExecutionConfig"
/>
name: RestAPIExecutionConfig
- description: <SchemaDefinition schemaRef="#/components/schemas/RestAPIMethod" />
name: RestAPIMethod
- description: <SchemaDefinition schemaRef="#/components/schemas/RouteInfo" />
name: RouteInfo
- description: <SchemaDefinition schemaRef="#/components/schemas/RunEvalRequest" />
@ -5137,9 +5037,6 @@ tags:
- name: ScoringFunctions
- description: <SchemaDefinition schemaRef="#/components/schemas/ScoringResult" />
name: ScoringResult
- description: <SchemaDefinition schemaRef="#/components/schemas/SearchToolDefinition"
/>
name: SearchToolDefinition
- description: 'A single session of an interaction with an Agentic System.
@ -5191,6 +5088,8 @@ tags:
name: TextContentItem
- description: <SchemaDefinition schemaRef="#/components/schemas/TokenLogProbs" />
name: TokenLogProbs
- description: <SchemaDefinition schemaRef="#/components/schemas/Tool" />
name: Tool
- description: <SchemaDefinition schemaRef="#/components/schemas/ToolCall" />
name: ToolCall
- description: <SchemaDefinition schemaRef="#/components/schemas/ToolCallDelta" />
@ -5200,14 +5099,26 @@ tags:
name: ToolCallParseStatus
- description: <SchemaDefinition schemaRef="#/components/schemas/ToolChoice" />
name: ToolChoice
- description: <SchemaDefinition schemaRef="#/components/schemas/ToolDef" />
name: ToolDef
- description: <SchemaDefinition schemaRef="#/components/schemas/ToolDefinition" />
name: ToolDefinition
- description: <SchemaDefinition schemaRef="#/components/schemas/ToolExecutionStep"
/>
name: ToolExecutionStep
- description: <SchemaDefinition schemaRef="#/components/schemas/ToolGroup" />
name: ToolGroup
- description: <SchemaDefinition schemaRef="#/components/schemas/ToolGroupDef" />
name: ToolGroupDef
- name: ToolGroups
- description: <SchemaDefinition schemaRef="#/components/schemas/ToolInvocationResult"
/>
name: ToolInvocationResult
- description: <SchemaDefinition schemaRef="#/components/schemas/ToolParamDefinition"
/>
name: ToolParamDefinition
- description: <SchemaDefinition schemaRef="#/components/schemas/ToolParameter" />
name: ToolParameter
- description: "This Enum refers to the prompt format for calling custom / zero shot\
\ tools\n\n`json` --\n Refers to the json format for calling tools.\n The\
\ json format takes the form like\n {\n \"type\": \"function\",\n \
@ -5224,6 +5135,7 @@ tags:
- description: <SchemaDefinition schemaRef="#/components/schemas/ToolResponseMessage"
/>
name: ToolResponseMessage
- name: ToolRuntime
- description: <SchemaDefinition schemaRef="#/components/schemas/Trace" />
name: Trace
- description: <SchemaDefinition schemaRef="#/components/schemas/TrainingConfig" />
@ -5244,9 +5156,15 @@ tags:
- description: <SchemaDefinition schemaRef="#/components/schemas/UnregisterModelRequest"
/>
name: UnregisterModelRequest
- description: <SchemaDefinition schemaRef="#/components/schemas/UnregisterToolGroupRequest"
/>
name: UnregisterToolGroupRequest
- description: <SchemaDefinition schemaRef="#/components/schemas/UnstructuredLogEvent"
/>
name: UnstructuredLogEvent
- description: <SchemaDefinition schemaRef="#/components/schemas/UserDefinedToolGroupDef"
/>
name: UserDefinedToolGroupDef
- description: <SchemaDefinition schemaRef="#/components/schemas/UserMessage" />
name: UserMessage
- description: <SchemaDefinition schemaRef="#/components/schemas/VectorMemoryBank"
@ -5259,9 +5177,6 @@ tags:
name: VersionInfo
- description: <SchemaDefinition schemaRef="#/components/schemas/ViolationLevel" />
name: ViolationLevel
- description: <SchemaDefinition schemaRef="#/components/schemas/WolframAlphaToolDefinition"
/>
name: WolframAlphaToolDefinition
x-tagGroups:
- name: Operations
tags:
@ -5283,6 +5198,8 @@ x-tagGroups:
- Shields
- SyntheticDataGeneration (Coming Soon)
- Telemetry
- ToolGroups
- ToolRuntime
- name: Types
tags:
- AgentCandidate
@ -5315,7 +5232,6 @@ x-tagGroups:
- ChatCompletionResponseEventType
- ChatCompletionResponseStreamChunk
- Checkpoint
- CodeInterpreterToolDefinition
- CompletionMessage
- CompletionRequest
- CompletionResponse
@ -5328,13 +5244,13 @@ x-tagGroups:
- Dataset
- DeleteAgentsRequest
- DeleteAgentsSessionRequest
- DiscoverToolsRequest
- EfficiencyConfig
- EmbeddingsRequest
- EmbeddingsResponse
- EvalTask
- EvaluateResponse
- EvaluateRowsRequest
- FunctionCallToolDefinition
- GetAgentsSessionRequest
- GetSpanTreeRequest
- GraphMemoryBank
@ -5345,6 +5261,7 @@ x-tagGroups:
- InsertDocumentsRequest
- InterleavedContent
- InterleavedContentItem
- InvokeToolRequest
- Job
- JobCancelRequest
- JobStatus
@ -5356,9 +5273,9 @@ x-tagGroups:
- LogEventRequest
- LogSeverity
- LoraFinetuningConfig
- MCPToolGroupDef
- MemoryBankDocument
- MemoryRetrievalStep
- MemoryToolDefinition
- Message
- MetricEvent
- Model
@ -5368,7 +5285,6 @@ x-tagGroups:
- OptimizerType
- PaginatedRowsResult
- ParamType
- PhotogenToolDefinition
- PostTrainingJob
- PostTrainingJobArtifactsResponse
- PostTrainingJobStatusResponse
@ -5388,9 +5304,8 @@ x-tagGroups:
- RegisterModelRequest
- RegisterScoringFunctionRequest
- RegisterShieldRequest
- RegisterToolGroupRequest
- ResponseFormat
- RestAPIExecutionConfig
- RestAPIMethod
- RouteInfo
- RunEvalRequest
- RunShieldRequest
@ -5405,7 +5320,6 @@ x-tagGroups:
- ScoreResponse
- ScoringFn
- ScoringResult
- SearchToolDefinition
- Session
- Shield
- ShieldCallStep
@ -5422,13 +5336,19 @@ x-tagGroups:
- SystemMessage
- TextContentItem
- TokenLogProbs
- Tool
- ToolCall
- ToolCallDelta
- ToolCallParseStatus
- ToolChoice
- ToolDef
- ToolDefinition
- ToolExecutionStep
- ToolGroup
- ToolGroupDef
- ToolInvocationResult
- ToolParamDefinition
- ToolParameter
- ToolPromptFormat
- ToolResponse
- ToolResponseMessage
@ -5439,10 +5359,11 @@ x-tagGroups:
- UnregisterDatasetRequest
- UnregisterMemoryBankRequest
- UnregisterModelRequest
- UnregisterToolGroupRequest
- UnstructuredLogEvent
- UserDefinedToolGroupDef
- UserMessage
- VectorMemoryBank
- VectorMemoryBankParams
- VersionInfo
- ViolationLevel
- WolframAlphaToolDefinition

View file

@ -74,12 +74,14 @@ ToolGroupDef = register_schema(
)
@json_schema_type
class ToolGroupInput(BaseModel):
tool_group_id: str
tool_group: ToolGroupDef
provider_id: Optional[str] = None
@json_schema_type
class ToolGroup(Resource):
type: Literal[ResourceType.tool_group.value] = ResourceType.tool_group.value

View file

@ -33,6 +33,7 @@ from llama_stack.apis.scoring_functions import ScoringFunctions
from llama_stack.apis.shields import Shields
from llama_stack.apis.synthetic_data_generation import SyntheticDataGeneration
from llama_stack.apis.telemetry import Telemetry
from llama_stack.apis.tools import ToolGroups, ToolRuntime
from llama_stack.distribution.datatypes import StackRunConfig
from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.resolver import ProviderRegistry, resolve_impls
@ -63,6 +64,8 @@ class LlamaStack(
Models,
Shields,
Inspect,
ToolGroups,
ToolRuntime,
):
pass

View file

@ -4,78 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
from typing import Dict, List
from uuid import uuid4
import pytest
from llama_stack.providers.tests.env import get_env_or_fail
from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client.lib.agents.custom_tool import CustomTool
from llama_stack_client.lib.agents.event_logger import EventLogger
from llama_stack_client.types import CompletionMessage, ToolResponseMessage
from llama_stack_client.types.agent_create_params import AgentConfig
from llama_stack_client.types.tool_param_definition_param import (
ToolParamDefinitionParam,
)
class TestCustomTool(CustomTool):
"""Tool to give boiling point of a liquid
Returns the correct value for water in Celcius and Fahrenheit
and returns -1 for other liquids
"""
def run(self, messages: List[CompletionMessage]) -> List[ToolResponseMessage]:
assert len(messages) == 1, "Expected single message"
message = messages[0]
tool_call = message.tool_calls[0]
try:
response = self.run_impl(**tool_call.arguments)
response_str = json.dumps(response, ensure_ascii=False)
except Exception as e:
response_str = f"Error when running tool: {e}"
message = ToolResponseMessage(
call_id=tool_call.call_id,
tool_name=tool_call.tool_name,
content=response_str,
role="ipython",
)
return [message]
def get_name(self) -> str:
return "get_boiling_point"
def get_description(self) -> str:
return "Get the boiling point of a imaginary liquids (eg. polyjuice)"
def get_params_definition(self) -> Dict[str, ToolParamDefinitionParam]:
return {
"liquid_name": ToolParamDefinitionParam(
param_type="string", description="The name of the liquid", required=True
),
"celcius": ToolParamDefinitionParam(
param_type="boolean",
description="Whether to return the boiling point in Celcius",
required=False,
),
}
def run_impl(self, liquid_name: str, celcius: bool = True) -> int:
if liquid_name.lower() == "polyjuice":
if celcius:
return -100
else:
return -212
else:
return -1
@pytest.fixture(scope="session")
@ -151,12 +85,8 @@ def test_agent_simple(llama_stack_client, agent_config):
def test_builtin_tool_brave_search(llama_stack_client, agent_config):
agent_config = {
**agent_config,
"tools": [
{
"type": "brave_search",
"engine": "brave",
"api_key": get_env_or_fail("BRAVE_SEARCH_API_KEY"),
}
"available_tools": [
"brave_search",
],
}
print(f"Agent Config: {agent_config}")
@ -167,7 +97,7 @@ def test_builtin_tool_brave_search(llama_stack_client, agent_config):
messages=[
{
"role": "user",
"content": "Search the web and tell me who the 44th president of the United States was. Please use tools",
"content": "Search the web and tell me who the current CEO of Meta is.",
}
],
session_id=session_id,
@ -178,92 +108,5 @@ def test_builtin_tool_brave_search(llama_stack_client, agent_config):
assert "tool_execution>" in logs_str
assert "Tool:brave_search Response:" in logs_str
assert "obama" in logs_str.lower()
if len(agent_config["input_shields"]) > 0:
assert "mark zuckerberg" in logs_str.lower()
assert "No Violation" in logs_str
def test_builtin_tool_code_execution(llama_stack_client, agent_config):
agent_config = {
**agent_config,
"tools": [
{
"type": "code_interpreter",
}
],
}
agent = Agent(llama_stack_client, agent_config)
session_id = agent.create_session(f"test-session-{uuid4()}")
response = agent.create_turn(
messages=[
{
"role": "user",
"content": "Write code to answer the question: What is the 100th prime number?",
},
],
session_id=session_id,
)
logs = [str(log) for log in EventLogger().log(response) if log is not None]
logs_str = "".join(logs)
if "Tool:code_interpreter Response" not in logs_str:
assert len(logs_str) > 0
pytest.skip("code_interpreter not called by model")
assert "Tool:code_interpreter Response" in logs_str
if "No such file or directory: 'bwrap'" in logs_str:
assert "prime" in logs_str
pytest.skip("`bwrap` is not available on this platform")
else:
assert "541" in logs_str
def test_custom_tool(llama_stack_client, agent_config):
agent_config = {
**agent_config,
"model": "meta-llama/Llama-3.2-3B-Instruct",
"tools": [
{
"type": "brave_search",
"engine": "brave",
"api_key": get_env_or_fail("BRAVE_SEARCH_API_KEY"),
},
{
"function_name": "get_boiling_point",
"description": "Get the boiling point of a imaginary liquids (eg. polyjuice)",
"parameters": {
"liquid_name": {
"param_type": "str",
"description": "The name of the liquid",
"required": True,
},
"celcius": {
"param_type": "boolean",
"description": "Whether to return the boiling point in Celcius",
"required": False,
},
},
"type": "function_call",
},
],
"tool_prompt_format": "python_list",
}
agent = Agent(llama_stack_client, agent_config, custom_tools=(TestCustomTool(),))
session_id = agent.create_session(f"test-session-{uuid4()}")
response = agent.create_turn(
messages=[
{
"role": "user",
"content": "What is the boiling point of polyjuice?",
},
],
session_id=session_id,
)
logs = [str(log) for log in EventLogger().log(response) if log is not None]
logs_str = "".join(logs)
assert "-100" in logs_str
assert "CustomTool" in logs_str