mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-23 00:12:24 +00:00
working end to end client sdk tests with custom tools
This commit is contained in:
parent
1a66ddc1b5
commit
4dd2f4c363
5 changed files with 304 additions and 149 deletions
|
|
@ -18,13 +18,11 @@ from typing import (
|
|||
runtime_checkable,
|
||||
)
|
||||
|
||||
from llama_models.llama3.api.datatypes import ToolParamDefinition
|
||||
from llama_models.schema_utils import json_schema_type, webmethod
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from llama_stack.apis.common.content_types import URL, InterleavedContent
|
||||
from llama_stack.apis.common.deployment_types import RestAPIExecutionConfig
|
||||
from llama_stack.apis.inference import (
|
||||
CompletionMessage,
|
||||
SamplingParams,
|
||||
|
|
@ -140,6 +138,7 @@ class AgentConfigCommon(BaseModel):
|
|||
input_shields: Optional[List[str]] = Field(default_factory=list)
|
||||
output_shields: Optional[List[str]] = Field(default_factory=list)
|
||||
available_tools: Optional[List[str]] = Field(default_factory=list)
|
||||
custom_tools: Optional[List[CustomToolDef]] = Field(default_factory=list)
|
||||
preprocessing_tools: Optional[List[str]] = Field(default_factory=list)
|
||||
tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = Field(
|
||||
|
|
|
|||
|
|
@ -400,6 +400,10 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
output_attachments = []
|
||||
|
||||
n_iter = 0
|
||||
# Build a map of custom tools to their definitions for faster lookup
|
||||
custom_tools = {}
|
||||
for tool in self.agent_config.custom_tools:
|
||||
custom_tools[tool.name] = tool
|
||||
while True:
|
||||
msg = input_messages[-1]
|
||||
|
||||
|
|
@ -530,6 +534,9 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
else:
|
||||
log.info(f"{str(message)}")
|
||||
tool_call = message.tool_calls[0]
|
||||
if tool_call.tool_name in custom_tools:
|
||||
yield message
|
||||
return
|
||||
|
||||
step_id = str(uuid.uuid4())
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
|
|
@ -619,6 +626,22 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
|
||||
async def _get_tools(self) -> List[ToolDefinition]:
|
||||
ret = []
|
||||
for tool in self.agent_config.custom_tools:
|
||||
params = {}
|
||||
for param in tool.parameters:
|
||||
params[param.name] = ToolParamDefinition(
|
||||
param_type=param.parameter_type,
|
||||
description=param.description,
|
||||
required=param.required,
|
||||
default=param.default,
|
||||
)
|
||||
ret.append(
|
||||
ToolDefinition(
|
||||
tool_name=tool.name,
|
||||
description=tool.description,
|
||||
parameters=params,
|
||||
)
|
||||
)
|
||||
for tool_name in self.agent_config.available_tools:
|
||||
tool = await self.tool_groups_api.get_tool(tool_name)
|
||||
if tool.built_in_type:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue