allow users to provide a tool

This commit is contained in:
Matthew Farrellee 2024-11-21 06:46:26 -05:00
parent 8a35dc8b0e
commit 914cea8939

View file

@ -9,10 +9,12 @@ import warnings
from typing import Any, AsyncGenerator, Dict, Generator, List, Optional
from llama_models.llama3.api.datatypes import (
BuiltinTool,
CompletionMessage,
StopReason,
TokenLogProbs,
ToolCall,
ToolDefinition,
)
from openai import AsyncStream
from openai.types.chat import ChatCompletionChunk as OpenAIChatCompletionChunk
@ -36,6 +38,82 @@ from llama_stack.apis.inference import (
)
def _convert_tooldef_to_openai_tool(tool: ToolDefinition) -> dict:
"""
Convert a ToolDefinition to an OpenAI API-compatible dictionary.
ToolDefinition:
tool_name: str | BuiltinTool
description: Optional[str]
parameters: Optional[Dict[str, ToolParamDefinition]]
ToolParamDefinition:
param_type: str
description: Optional[str]
required: Optional[bool]
default: Optional[Any]
OpenAI spec -
{
"type": "function",
"function": {
"name": tool_name,
"description": description,
"parameters": {
"type": "object",
"properties": {
param_name: {
"type": param_type,
"description": description,
"default": default,
},
...
},
"required": [param_name, ...],
},
},
}
"""
out = {
"type": "function",
"function": {},
}
function = out["function"]
if isinstance(tool.tool_name, BuiltinTool):
function.update(name=tool.tool_name.value) # TODO(mf): is this sufficient?
else:
function.update(name=tool.tool_name)
if tool.description:
function.update(description=tool.description)
if tool.parameters:
parameters = {
"type": "object",
"properties": {},
}
properties = parameters["properties"]
required = []
for param_name, param in tool.parameters.items():
properties[param_name] = {"type": param.param_type}
if param.description:
properties[param_name].update(description=param.description)
if param.default:
properties[param_name].update(default=param.default)
if param.required:
required.append(param_name)
if required:
parameters.update(required=required)
function.update(parameters=parameters)
return out
def _convert_message(message: Message) -> Dict:
"""
Convert a Message to an OpenAI API-compatible dictionary.
@ -90,7 +168,9 @@ def convert_chat_completion_request(
)
if request.tools:
payload.update(tools=request.tools)
payload.update(
tools=[_convert_tooldef_to_openai_tool(tool) for tool in request.tools]
)
if request.tool_choice:
payload.update(
tool_choice=request.tool_choice.value