feat(anthropic.py): adds tool calling support

Krrish Dholakia 2024-03-04 10:42:28 -08:00
parent 1c40282627
commit ae82b3f31a
3 changed files with 89 additions and 4 deletions


@@ -6,7 +6,11 @@ import time
from typing import Callable, Optional
from litellm.utils import ModelResponse, Usage
import litellm
from .prompt_templates.factory import (
    prompt_factory,
    custom_prompt,
    construct_tool_use_system_prompt,
)
import httpx
@@ -142,6 +146,15 @@ def completion(
        ):  # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
            optional_params[k] = v

    ## Handle Tool Calling
    if "tools" in optional_params:
        tool_calling_system_prompt = construct_tool_use_system_prompt(
            tools=optional_params["tools"]
        )
        optional_params["system"] = (
            optional_params.get("system", "\n") + tool_calling_system_prompt
        )  # add the anthropic tool calling prompt to the system prompt

    data = {
        "model": model,
        "messages": messages,


@@ -390,7 +390,7 @@ def format_prompt_togetherai(messages, prompt_format, chat_template):
    return prompt


### ANTHROPIC ###


def anthropic_pt(
@@ -424,6 +424,62 @@ def anthropic_pt(
    return prompt


def construct_format_parameters_prompt(parameters: dict):
    parameter_str = "<parameter>\n"
    for k, v in parameters.items():
        parameter_str += f"<{k}>"
        parameter_str += f"{v}"
        parameter_str += f"</{k}>"
    parameter_str += "\n</parameter>"
    return parameter_str

def construct_format_tool_for_claude_prompt(name, description, parameters):
    constructed_prompt = (
        "<tool_description>\n"
        f"<tool_name>{name}</tool_name>\n"
        "<description>\n"
        f"{description}\n"
        "</description>\n"
        "<parameters>\n"
        f"{construct_format_parameters_prompt(parameters)}\n"
        "</parameters>\n"
        "</tool_description>"
    )
    return constructed_prompt

def construct_tool_use_system_prompt(
    tools,
):  # from https://github.com/anthropics/anthropic-cookbook/blob/main/function_calling/function_calling.ipynb
    tool_str_list = []
    for tool in tools:
        tool_str = construct_format_tool_for_claude_prompt(
            tool["function"]["name"],
            tool["function"].get("description", ""),
            tool["function"].get("parameters", {}),
        )
        tool_str_list.append(tool_str)
    tool_use_system_prompt = (
        "In this environment you have access to a set of tools you can use to answer the user's question.\n"
        "\n"
        "You may call them like this:\n"
        "<function_calls>\n"
        "<invoke>\n"
        "<tool_name>$TOOL_NAME</tool_name>\n"
        "<parameters>\n"
        "<$PARAMETER_NAME>$PARAMETER_VALUE</$PARAMETER_NAME>\n"
        "...\n"
        "</parameters>\n"
        "</invoke>\n"
        "</function_calls>\n"
        "\n"
        "Here are the tools available:\n"
        "<tools>\n" + "\n".join([tool_str for tool_str in tool_str_list]) + "\n</tools>"
    )
    return tool_use_system_prompt
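
A minimal sketch (not part of the diff) of how these helpers consume an OpenAI-style tool definition. The tool is hypothetical, and its parameters are given as a flat name-to-description dict for readability; the completion() change above simply forwards whatever parameters dict the caller supplied.

# hypothetical tool in the OpenAI function-calling shape
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_time",
            "description": "Return the current time for a timezone",
            "parameters": {"timezone": "IANA timezone name, e.g. Europe/London"},
        },
    }
]

system_prompt = construct_tool_use_system_prompt(tools=tools)
# system_prompt ends with a block along these lines:
#   <tools>
#   <tool_description>
#   <tool_name>get_time</tool_name>
#   <description>
#   Return the current time for a timezone
#   </description>
#   <parameters>
#   <parameter>
#   <timezone>IANA timezone name, e.g. Europe/London</timezone>
#   </parameter>
#   </parameters>
#   </tool_description>
#   </tools>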

def anthropic_messages_pt(messages: list):
    """
    format messages for anthropic
@@ -464,6 +520,9 @@ def anthropic_messages_pt(messages: list):
    return new_messages


###


def amazon_titan_pt(
    messages: list,
):  # format - https://github.com/BerriAI/litellm/issues/1896
@@ -690,6 +749,8 @@ def prompt_factory(
if custom_llm_provider == "ollama": if custom_llm_provider == "ollama":
return ollama_pt(model=model, messages=messages) return ollama_pt(model=model, messages=messages)
elif custom_llm_provider == "anthropic": elif custom_llm_provider == "anthropic":
if model == "claude-instant-1" or model == "claude-2.1":
return anthropic_pt(messages=messages)
return anthropic_messages_pt(messages=messages) return anthropic_messages_pt(messages=messages)
elif custom_llm_provider == "together_ai": elif custom_llm_provider == "together_ai":
prompt_format, chat_template = get_model_info(token=api_key, model=model) prompt_format, chat_template = get_model_info(token=api_key, model=model)
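
A rough sketch (not part of the diff) of the resulting routing, assuming prompt_factory accepts model, messages, and custom_llm_provider as keyword arguments; the claude-3 model name is only an example.

messages = [{"role": "user", "content": "What's the capital of France?"}]

# legacy completion-style models keep the Human:/Assistant: string prompt
prompt_factory(model="claude-instant-1", messages=messages, custom_llm_provider="anthropic")

# other Anthropic models (e.g. a claude-3 model) go through the messages formatter
prompt_factory(model="claude-3-opus-20240229", messages=messages, custom_llm_provider="anthropic")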


@@ -4106,6 +4106,7 @@ def get_optional_params(
and custom_llm_provider != "anyscale" and custom_llm_provider != "anyscale"
and custom_llm_provider != "together_ai" and custom_llm_provider != "together_ai"
and custom_llm_provider != "mistral" and custom_llm_provider != "mistral"
and custom_llm_provider != "anthropic"
): ):
if custom_llm_provider == "ollama" or custom_llm_provider == "ollama_chat": if custom_llm_provider == "ollama" or custom_llm_provider == "ollama_chat":
# ollama actually supports json output # ollama actually supports json output
@@ -4186,7 +4187,15 @@
    ## raise exception if provider doesn't support passed in param
    if custom_llm_provider == "anthropic":
        ## check if unsupported param passed in
        supported_params = [
            "stream",
            "stop",
            "temperature",
            "top_p",
            "max_tokens",
            "tools",
            "tool_choice",
        ]
        _check_valid_arg(supported_params=supported_params)
        # handle anthropic params
        if stream:
@@ -4201,6 +4210,8 @@
optional_params["top_p"] = top_p optional_params["top_p"] = top_p
if max_tokens is not None: if max_tokens is not None:
optional_params["max_tokens"] = max_tokens optional_params["max_tokens"] = max_tokens
if tools is not None:
optional_params["tools"] = tools
elif custom_llm_provider == "cohere": elif custom_llm_provider == "cohere":
## check if unsupported param passed in ## check if unsupported param passed in
supported_params = [ supported_params = [
@@ -9704,4 +9715,4 @@ def _get_base_model_from_metadata(model_call_details=None):
    base_model = model_info.get("base_model", None)
    if base_model is not None:
        return base_model
    return None
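
Taken together, the three files let an OpenAI-style tools argument flow through litellm.completion for Anthropic models instead of being rejected as an unsupported param. A hedged end-to-end sketch follows; the model name and tool are illustrative, and ANTHROPIC_API_KEY is assumed to be set in the environment.

import litellm

response = litellm.completion(
    model="claude-3-opus-20240229",  # illustrative Anthropic model name
    messages=[{"role": "user", "content": "What's the weather in Boston?"}],
    tools=[
        {
            "type": "function",
            "function": {
                "name": "get_current_weather",
                "description": "Get the current weather for a city",
                "parameters": {
                    "type": "object",
                    "properties": {"location": {"type": "string"}},
                    "required": ["location"],
                },
            },
        }
    ],
)
print(response.choices[0].message.content)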