From 788b06a33c070ca65d0d76e9d05db8e2a398981f Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Wed, 7 Aug 2024 11:14:05 -0700
Subject: [PATCH] fix(utils.py): support deepseek tool calling

Fixes https://github.com/BerriAI/litellm/issues/5081
---
 litellm/tests/test_completion.py | 23 +++++++++++++++++++++--
 litellm/utils.py                 | 24 ++++++++----------------
 2 files changed, 29 insertions(+), 18 deletions(-)

diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py
index eec163f26a..aee2068ddf 100644
--- a/litellm/tests/test_completion.py
+++ b/litellm/tests/test_completion.py
@@ -4085,9 +4085,28 @@ async def test_acompletion_gemini():
 def test_completion_deepseek():
     litellm.set_verbose = True
     model_name = "deepseek/deepseek-chat"
-    messages = [{"role": "user", "content": "Hey, how's it going?"}]
+    tools = [
+        {
+            "type": "function",
+            "function": {
+                "name": "get_weather",
+                "description": "Get weather of a location, the user should supply a location first",
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        "location": {
+                            "type": "string",
+                            "description": "The city and state, e.g. San Francisco, CA",
+                        }
+                    },
+                    "required": ["location"],
+                },
+            },
+        },
+    ]
+    messages = [{"role": "user", "content": "How's the weather in Hangzhou?"}]
     try:
-        response = completion(model=model_name, messages=messages)
+        response = completion(model=model_name, messages=messages, tools=tools)
         # Add any assertions here to check the response
         print(response)
     except litellm.APIError as e:
diff --git a/litellm/utils.py b/litellm/utils.py
index 20beb47dc2..e1a686eaf7 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -3536,22 +3536,11 @@ def get_optional_params(
         )
         _check_valid_arg(supported_params=supported_params)
 
-        if frequency_penalty is not None:
-            optional_params["frequency_penalty"] = frequency_penalty
-        if max_tokens is not None:
-            optional_params["max_tokens"] = max_tokens
-        if presence_penalty is not None:
-            optional_params["presence_penalty"] = presence_penalty
-        if stop is not None:
-            optional_params["stop"] = stop
-        if stream is not None:
-            optional_params["stream"] = stream
-        if temperature is not None:
-            optional_params["temperature"] = temperature
-        if logprobs is not None:
-            optional_params["logprobs"] = logprobs
-        if top_logprobs is not None:
-            optional_params["top_logprobs"] = top_logprobs
+        optional_params = litellm.OpenAIConfig().map_openai_params(
+            non_default_params=non_default_params,
+            optional_params=optional_params,
+            model=model,
+        )
     elif custom_llm_provider == "openrouter":
         supported_params = get_supported_openai_params(
             model=model, custom_llm_provider=custom_llm_provider
@@ -4141,12 +4130,15 @@ def get_supported_openai_params(
             "frequency_penalty",
             "max_tokens",
             "presence_penalty",
+            "response_format",
             "stop",
             "stream",
             "temperature",
             "top_p",
             "logprobs",
             "top_logprobs",
+            "tools",
+            "tool_choice",
         ]
     elif custom_llm_provider == "cohere":
        return [
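
Usage sketch (illustrative note, not part of the applied patch): with the deepseek branch now delegating to litellm.OpenAIConfig().map_openai_params and with "tools"/"tool_choice" listed as supported params, an OpenAI-style tool-calling request should pass through unchanged. A minimal sketch, assuming a DEEPSEEK_API_KEY environment variable is set and reusing the get_weather tool definition from the test above; the exact response inspection follows litellm's OpenAI-compatible response shape:

    import litellm

    # Same tool definition as in test_completion_deepseek above.
    tools = [
        {
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "Get weather of a location, the user should supply a location first",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "location": {
                            "type": "string",
                            "description": "The city and state, e.g. San Francisco, CA",
                        }
                    },
                    "required": ["location"],
                },
            },
        },
    ]

    # deepseek/* models hit DeepSeek's OpenAI-compatible endpoint; after this patch,
    # tools/tool_choice are forwarded instead of being rejected as unsupported params.
    response = litellm.completion(
        model="deepseek/deepseek-chat",
        messages=[{"role": "user", "content": "How's the weather in Hangzhou?"}],
        tools=tools,
        tool_choice="auto",
    )
    print(response.choices[0].message.tool_calls)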