Merge pull request #4313 from BerriAI/litellm_drop_specific_params

fix(utils.py): allow dropping specific openai params
This commit is contained in:
Krish Dholakia 2024-06-20 15:15:19 -07:00 committed by GitHub
commit c8a40eca05
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 72 additions and 10 deletions

View file

@ -691,7 +691,7 @@ class PrismaClient:
finally:
os.chdir(original_dir)
# Now you can import the Prisma Client
from prisma import Prisma
from prisma import Prisma # type: ignore
self.db = Prisma() # Client to connect to Prisma db

View file

@ -1,20 +1,25 @@
#### What this tests ####
# This tests if get_optional_params works as expected
import sys, os, time, inspect, asyncio, traceback
import asyncio
import inspect
import os
import sys
import time
import traceback
import pytest
sys.path.insert(0, os.path.abspath("../.."))
from unittest.mock import MagicMock, patch
import litellm
from litellm.utils import get_optional_params_embeddings, get_optional_params
from litellm.llms.prompt_templates.factory import (
map_system_message_pt,
)
from unittest.mock import patch, MagicMock
from litellm.llms.prompt_templates.factory import map_system_message_pt
from litellm.types.completion import (
ChatCompletionUserMessageParam,
ChatCompletionSystemMessageParam,
ChatCompletionMessageParam,
ChatCompletionSystemMessageParam,
ChatCompletionUserMessageParam,
)
from litellm.utils import get_optional_params, get_optional_params_embeddings
## get_optional_params_embeddings
### Models: OpenAI, Azure, Bedrock
@ -286,3 +291,45 @@ def test_dynamic_drop_params_e2e():
mock_response.assert_called_once()
print(mock_response.call_args.kwargs["data"])
assert "response_format" not in mock_response.call_args.kwargs["data"]
@pytest.mark.parametrize("drop_params", [True, False, None])
def test_dynamic_drop_additional_params(drop_params):
"""
Make a call to cohere, dropping 'response_format' specifically
"""
if drop_params is True:
optional_params = litellm.utils.get_optional_params(
model="command-r",
custom_llm_provider="cohere",
response_format="json",
additional_drop_params=["response_format"],
)
else:
try:
optional_params = litellm.utils.get_optional_params(
model="command-r",
custom_llm_provider="cohere",
response_format="json",
)
pytest.fail("Expected to fail")
except Exception as e:
pass
def test_dynamic_drop_additional_params_e2e():
with patch("requests.post", new=MagicMock()) as mock_response:
try:
response = litellm.completion(
model="command-r",
messages=[{"role": "user", "content": "Hey, how's it going?"}],
response_format={"key": "value"},
additional_drop_params=["response_format"],
)
except Exception as e:
pass
mock_response.assert_called_once()
print(mock_response.call_args.kwargs["data"])
assert "response_format" not in mock_response.call_args.kwargs["data"]
assert "additional_drop_params" not in mock_response.call_args.kwargs["data"]

View file

@ -2293,6 +2293,7 @@ def get_optional_params(
extra_headers=None,
api_version=None,
drop_params=None,
additional_drop_params=None,
**kwargs,
):
# retrieve all parameters passed to the function
@ -2309,7 +2310,6 @@ def get_optional_params(
k.startswith("vertex_") and custom_llm_provider != "vertex_ai"
): # allow dynamically setting vertex ai init logic
continue
passed_params[k] = v
optional_params: Dict = {}
@ -2365,7 +2365,19 @@ def get_optional_params(
"extra_headers": None,
"api_version": None,
"drop_params": None,
"additional_drop_params": None,
}
def _should_drop_param(k, additional_drop_params) -> bool:
if (
additional_drop_params is not None
and isinstance(additional_drop_params, list)
and k in additional_drop_params
):
return True # allow user to drop specific params for a model - e.g. vllm - logit bias
return False
# filter out those parameters that were passed with non-default values
non_default_params = {
k: v
@ -2375,8 +2387,11 @@ def get_optional_params(
and k != "custom_llm_provider"
and k != "api_version"
and k != "drop_params"
and k != "additional_drop_params"
and k in default_params
and v != default_params[k]
and _should_drop_param(k=k, additional_drop_params=additional_drop_params)
is False
)
}