diff --git a/litellm/llms/anthropic.py b/litellm/llms/anthropic.py
index 184bc7153f..ce413be65a 100644
--- a/litellm/llms/anthropic.py
+++ b/litellm/llms/anthropic.py
@@ -6,7 +6,11 @@ import time
 from typing import Callable, Optional
 from litellm.utils import ModelResponse, Usage
 import litellm
-from .prompt_templates.factory import prompt_factory, custom_prompt
+from .prompt_templates.factory import (
+    prompt_factory,
+    custom_prompt,
+    construct_tool_use_system_prompt,
+)
 import httpx
 
 
@@ -142,6 +146,15 @@ def completion(
         ):  # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
             optional_params[k] = v
 
+    ## Handle Tool Calling
+    if "tools" in optional_params:
+        tool_calling_system_prompt = construct_tool_use_system_prompt(
+            tools=optional_params["tools"]
+        )
+        optional_params["system"] = (
+            optional_params.get("system", "\n") + tool_calling_system_prompt
+        )  # add the anthropic tool calling prompt to the system prompt
+
     data = {
         "model": model,
         "messages": messages,
diff --git a/litellm/llms/prompt_templates/factory.py b/litellm/llms/prompt_templates/factory.py
index 6b0deb2ee7..cc75237e05 100644
--- a/litellm/llms/prompt_templates/factory.py
+++ b/litellm/llms/prompt_templates/factory.py
@@ -390,7 +390,7 @@ def format_prompt_togetherai(messages, prompt_format, chat_template):
     return prompt
 
 
-###
+### ANTHROPIC ###
 
 
 def anthropic_pt(
@@ -424,6 +424,62 @@ def anthropic_pt(
     return prompt
 
 
+def construct_format_parameters_prompt(parameters: dict):
+    parameter_str = "<parameter>\n"
+    for k, v in parameters.items():
+        parameter_str += f"<{k}>"
+        parameter_str += f"{v}"
+        parameter_str += f"</{k}>"
+    parameter_str += "\n</parameter>"
+    return parameter_str
+
+
+def construct_format_tool_for_claude_prompt(name, description, parameters):
+    constructed_prompt = (
+        "<tool_description>\n"
+        f"<tool_name>{name}</tool_name>\n"
+        "<description>\n"
+        f"{description}\n"
+        "</description>\n"
+        "<parameters>\n"
+        f"{construct_format_parameters_prompt(parameters)}\n"
+        "</parameters>\n"
+        "</tool_description>"
+    )
+    return constructed_prompt
+
+
+def construct_tool_use_system_prompt(
+    tools,
+):  # from https://github.com/anthropics/anthropic-cookbook/blob/main/function_calling/function_calling.ipynb
+    tool_str_list = []
+    for tool in tools:
+        tool_str = construct_format_tool_for_claude_prompt(
+            tool["function"]["name"],
+            tool["function"].get("description", ""),
+            tool["function"].get("parameters", {}),
+        )
+        tool_str_list.append(tool_str)
+    tool_use_system_prompt = (
+        "In this environment you have access to a set of tools you can use to answer the user's question.\n"
+        "\n"
+        "You may call them like this:\n"
+        "<function_calls>\n"
+        "<invoke>\n"
+        "<tool_name>$TOOL_NAME</tool_name>\n"
+        "<parameters>\n"
+        "<$PARAMETER_NAME>$PARAMETER_VALUE</$PARAMETER_NAME>\n"
+        "...\n"
+        "</parameters>\n"
+        "</invoke>\n"
+        "</function_calls>\n"
+        "\n"
+        "Here are the tools available:\n"
+        "<tools>\n"
+        + "\n".join([tool_str for tool_str in tool_str_list])
+        + "\n</tools>"
+    )
+    return tool_use_system_prompt
+
+
 def anthropic_messages_pt(messages: list):
     """
     format messages for anthropic
@@ -464,6 +520,9 @@ def anthropic_messages_pt(messages: list):
     return new_messages
 
 
+###
+
+
 def amazon_titan_pt(
     messages: list,
 ):  # format - https://github.com/BerriAI/litellm/issues/1896
@@ -690,6 +749,8 @@ def prompt_factory(
     if custom_llm_provider == "ollama":
         return ollama_pt(model=model, messages=messages)
     elif custom_llm_provider == "anthropic":
+        if model == "claude-instant-1" or model == "claude-2.1":
+            return anthropic_pt(messages=messages)
         return anthropic_messages_pt(messages=messages)
     elif custom_llm_provider == "together_ai":
         prompt_format, chat_template = get_model_info(token=api_key, model=model)
diff --git a/litellm/utils.py b/litellm/utils.py
index 233fd6bae7..69f324589a 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -4106,6 +4106,7 @@ def get_optional_params(
         and custom_llm_provider != "anyscale"
         and custom_llm_provider != "together_ai"
         and custom_llm_provider != "mistral"
+        and custom_llm_provider != "anthropic"
     ):
         if custom_llm_provider == "ollama" or custom_llm_provider == "ollama_chat":
             # ollama actually supports json output
@@ -4186,7 +4187,15 @@ def get_optional_params(
     ## raise exception if provider doesn't support passed in param
     if custom_llm_provider == "anthropic":
         ## check if unsupported param passed in
-        supported_params = ["stream", "stop", "temperature", "top_p", "max_tokens"]
+        supported_params = [
+            "stream",
+            "stop",
+            "temperature",
+            "top_p",
+            "max_tokens",
+            "tools",
+            "tool_choice",
+        ]
         _check_valid_arg(supported_params=supported_params)
         # handle anthropic params
         if stream:
@@ -4201,6 +4210,8 @@ def get_optional_params(
             optional_params["top_p"] = top_p
         if max_tokens is not None:
             optional_params["max_tokens"] = max_tokens
+        if tools is not None:
+            optional_params["tools"] = tools
     elif custom_llm_provider == "cohere":
         ## check if unsupported param passed in
         supported_params = [
@@ -9704,4 +9715,4 @@ def _get_base_model_from_metadata(model_call_details=None):
             base_model = model_info.get("base_model", None)
             if base_model is not None:
                 return base_model
-    return None
\ No newline at end of file
+    return None
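
Note: the sketch below shows roughly what the new factory.py helpers emit for a single OpenAI-style tool definition. It is illustrative only: the get_current_weather tool and its parameters are made up and are not part of this diff; only construct_tool_use_system_prompt itself comes from the change above.

# Illustrative sketch: a made-up OpenAI-style tool definition fed through the new helper.
from litellm.llms.prompt_templates.factory import construct_tool_use_system_prompt

tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",  # hypothetical tool name
            "description": "Get the current weather for a city",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }
]

# Returns the <tools>/<tool_description> system prompt that completion() appends
# to optional_params["system"] whenever "tools" is passed for an anthropic model.
print(construct_tool_use_system_prompt(tools=tools))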
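
With the utils.py change, "tools" (and "tool_choice") are accepted for anthropic and "tools" is forwarded in optional_params, so a caller can pass an OpenAI-style tool list straight to litellm.completion. The snippet below is a minimal usage sketch, assuming ANTHROPIC_API_KEY is set in the environment; the model name and tool definition are illustrative, not prescriptive.

# Minimal usage sketch (assumes ANTHROPIC_API_KEY is set; model and tool are illustrative).
import litellm

response = litellm.completion(
    model="claude-2.1",
    messages=[{"role": "user", "content": "What's the weather in Boston?"}],
    tools=[
        {
            "type": "function",
            "function": {
                "name": "get_current_weather",  # hypothetical tool
                "description": "Get the current weather for a city",
                "parameters": {
                    "type": "object",
                    "properties": {"city": {"type": "string"}},
                    "required": ["city"],
                },
            },
        }
    ],
)
# The tool-use instructions are folded into the system prompt, so the model's reply
# may contain an XML <function_calls> block for the caller to parse.
print(response)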