mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 04:04:14 +00:00
103 lines
3.1 KiB
Python
103 lines
3.1 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
import json
|
|
|
|
from abc import abstractmethod
|
|
from typing import Dict, List
|
|
|
|
from llama_models.llama3_1.api.datatypes import * # noqa: F403
|
|
from llama_toolchain.agentic_system.api import * # noqa: F403
|
|
|
|
from .builtin import interpret_content_as_attachment
|
|
|
|
|
|
class CustomTool:
|
|
"""
|
|
Developers can define their custom tools that models can use
|
|
by extending this class.
|
|
|
|
Developers need to provide
|
|
- name
|
|
- description
|
|
- params_definition
|
|
- implement tool's behavior in `run_impl` method
|
|
|
|
NOTE: The return of the `run` method needs to be json serializable
|
|
"""
|
|
|
|
@abstractmethod
|
|
def get_name(self) -> str:
|
|
raise NotImplementedError
|
|
|
|
@abstractmethod
|
|
def get_description(self) -> str:
|
|
raise NotImplementedError
|
|
|
|
@abstractmethod
|
|
def get_params_definition(self) -> Dict[str, ToolParamDefinition]:
|
|
raise NotImplementedError
|
|
|
|
def get_instruction_string(self) -> str:
|
|
return f"Use the function '{self.get_name()}' to: {self.get_description()}"
|
|
|
|
def parameters_for_system_prompt(self) -> str:
|
|
return json.dumps(
|
|
{
|
|
"name": self.get_name(),
|
|
"description": self.get_description(),
|
|
"parameters": {
|
|
name: definition.__dict__
|
|
for name, definition in self.get_params_definition().items()
|
|
},
|
|
}
|
|
)
|
|
|
|
def get_tool_definition(self) -> AgenticSystemToolDefinition:
|
|
return AgenticSystemToolDefinition(
|
|
tool_name=self.get_name(),
|
|
description=self.get_description(),
|
|
parameters=self.get_params_definition(),
|
|
)
|
|
|
|
@abstractmethod
|
|
async def run(self, messages: List[Message]) -> List[Message]:
|
|
raise NotImplementedError
|
|
|
|
|
|
class SingleMessageCustomTool(CustomTool):
|
|
"""
|
|
Helper class to handle custom tools that take a single message
|
|
Extending this class and implementing the `run_impl` method will
|
|
allow for the tool be called by the model and the necessary plumbing.
|
|
"""
|
|
|
|
async def run(self, messages: List[CompletionMessage]) -> List[ToolResponseMessage]:
|
|
assert len(messages) == 1, "Expected single message"
|
|
|
|
message = messages[0]
|
|
|
|
tool_call = message.tool_calls[0]
|
|
|
|
try:
|
|
response = await self.run_impl(**tool_call.arguments)
|
|
response_str = json.dumps(response, ensure_ascii=False)
|
|
except Exception as e:
|
|
response_str = f"Error when running tool: {e}"
|
|
|
|
message = ToolResponseMessage(
|
|
call_id=tool_call.call_id,
|
|
tool_name=tool_call.tool_name,
|
|
content=response_str,
|
|
)
|
|
if attachment := interpret_content_as_attachment(response_str):
|
|
message.content = attachment
|
|
|
|
return [message]
|
|
|
|
@abstractmethod
|
|
async def run_impl(self, *args, **kwargs):
|
|
raise NotImplementedError()
|