diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index 6f702d304..e5853e3dc 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -691,7 +691,7 @@ class PrismaClient: finally: os.chdir(original_dir) # Now you can import the Prisma Client - from prisma import Prisma + from prisma import Prisma # type: ignore self.db = Prisma() # Client to connect to Prisma db diff --git a/litellm/tests/test_optional_params.py b/litellm/tests/test_optional_params.py index 0741b6fbe..be14e4feb 100644 --- a/litellm/tests/test_optional_params.py +++ b/litellm/tests/test_optional_params.py @@ -1,20 +1,25 @@ #### What this tests #### # This tests if get_optional_params works as expected -import sys, os, time, inspect, asyncio, traceback +import asyncio +import inspect +import os +import sys +import time +import traceback + import pytest sys.path.insert(0, os.path.abspath("../..")) +from unittest.mock import MagicMock, patch + import litellm -from litellm.utils import get_optional_params_embeddings, get_optional_params -from litellm.llms.prompt_templates.factory import ( - map_system_message_pt, -) -from unittest.mock import patch, MagicMock +from litellm.llms.prompt_templates.factory import map_system_message_pt from litellm.types.completion import ( - ChatCompletionUserMessageParam, - ChatCompletionSystemMessageParam, ChatCompletionMessageParam, + ChatCompletionSystemMessageParam, + ChatCompletionUserMessageParam, ) +from litellm.utils import get_optional_params, get_optional_params_embeddings ## get_optional_params_embeddings ### Models: OpenAI, Azure, Bedrock @@ -286,3 +291,45 @@ def test_dynamic_drop_params_e2e(): mock_response.assert_called_once() print(mock_response.call_args.kwargs["data"]) assert "response_format" not in mock_response.call_args.kwargs["data"] + + +@pytest.mark.parametrize("drop_params", [True, False, None]) +def test_dynamic_drop_additional_params(drop_params): + """ + Make a call to cohere, dropping 'response_format' specifically + """ + if drop_params is True: + optional_params = litellm.utils.get_optional_params( + model="command-r", + custom_llm_provider="cohere", + response_format="json", + additional_drop_params=["response_format"], + ) + else: + try: + optional_params = litellm.utils.get_optional_params( + model="command-r", + custom_llm_provider="cohere", + response_format="json", + ) + pytest.fail("Expected to fail") + except Exception as e: + pass + + +def test_dynamic_drop_additional_params_e2e(): + with patch("requests.post", new=MagicMock()) as mock_response: + try: + response = litellm.completion( + model="command-r", + messages=[{"role": "user", "content": "Hey, how's it going?"}], + response_format={"key": "value"}, + additional_drop_params=["response_format"], + ) + except Exception as e: + pass + + mock_response.assert_called_once() + print(mock_response.call_args.kwargs["data"]) + assert "response_format" not in mock_response.call_args.kwargs["data"] + assert "additional_drop_params" not in mock_response.call_args.kwargs["data"] diff --git a/litellm/utils.py b/litellm/utils.py index 852c1d6a8..37de56692 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -2293,6 +2293,7 @@ def get_optional_params( extra_headers=None, api_version=None, drop_params=None, + additional_drop_params=None, **kwargs, ): # retrieve all parameters passed to the function @@ -2309,7 +2310,6 @@ def get_optional_params( k.startswith("vertex_") and custom_llm_provider != "vertex_ai" ): # allow dynamically setting vertex ai init logic continue - passed_params[k] = v optional_params: Dict = {} @@ -2365,7 +2365,19 @@ def get_optional_params( "extra_headers": None, "api_version": None, "drop_params": None, + "additional_drop_params": None, } + + def _should_drop_param(k, additional_drop_params) -> bool: + if ( + additional_drop_params is not None + and isinstance(additional_drop_params, list) + and k in additional_drop_params + ): + return True # allow user to drop specific params for a model - e.g. vllm - logit bias + + return False + # filter out those parameters that were passed with non-default values non_default_params = { k: v @@ -2375,8 +2387,11 @@ def get_optional_params( and k != "custom_llm_provider" and k != "api_version" and k != "drop_params" + and k != "additional_drop_params" and k in default_params and v != default_params[k] + and _should_drop_param(k=k, additional_drop_params=additional_drop_params) + is False ) }