feat(anthropic.py): support anthropic 'tool_choice' param

Closes https://github.com/BerriAI/litellm/issues/3752
This commit is contained in:
Krrish Dholakia 2024-05-21 17:50:44 -07:00
parent 9f4c04dce3
commit 4795c56f84
4 changed files with 23 additions and 3 deletions

View file

@ -10,6 +10,7 @@ from .prompt_templates.factory import prompt_factory, custom_prompt
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
from .base import BaseLLM
import httpx # type: ignore
from litellm.types.llms.anthropic import AnthropicMessagesToolChoice
class AnthropicConstants(Enum):
@ -102,6 +103,17 @@ class AnthropicConfig:
optional_params["max_tokens"] = value
if param == "tools":
optional_params["tools"] = value
if param == "tool_choice":
_tool_choice: Optional[AnthropicMessagesToolChoice] = None
if value == "auto":
_tool_choice = {"type": "auto"}
elif value == "required":
_tool_choice = {"type": "any"}
elif isinstance(value, dict):
_tool_choice = {"type": "tool", "name": value["function"]["name"]}
if _tool_choice is not None:
optional_params["tool_choice"] = _tool_choice
if param == "stream" and value == True:
optional_params["stream"] = value
if param == "stop":

View file

@ -486,7 +486,7 @@ def completion(
response_format: Optional[dict] = None,
seed: Optional[int] = None,
tools: Optional[List] = None,
tool_choice: Optional[str] = None,
tool_choice: Optional[Union[str, dict]] = None,
logprobs: Optional[bool] = None,
top_logprobs: Optional[int] = None,
deployment_id=None,

View file

@ -278,9 +278,12 @@ def test_completion_claude_3_function_call():
model="anthropic/claude-3-opus-20240229",
messages=messages,
tools=tools,
tool_choice={"type": "tool", "name": "get_weather"},
extra_headers={"anthropic-beta": "tools-2024-05-16"},
tool_choice={
"type": "function",
"function": {"name": "get_current_weather"},
},
)
# Add any assertions, here to check response args
print(response)
assert isinstance(response.choices[0].message.tool_calls[0].function.name, str)

View file

@ -5,6 +5,11 @@ from pydantic import BaseModel, validator
from typing_extensions import Literal, Required, TypedDict
class AnthropicMessagesToolChoice(TypedDict, total=False):
type: Required[Literal["auto", "any", "tool"]]
name: str
class AnthopicMessagesAssistantMessageTextContentParam(TypedDict, total=False):
type: Required[Literal["text"]]