Merge pull request #4313 from BerriAI/litellm_drop_specific_params

fix(utils.py): allow dropping specific openai params

commit c8a40eca05
3 changed files with 72 additions and 10 deletions
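litellm already supports drop_params=True, which silently drops every OpenAI-style parameter the target provider does not understand. This PR adds additional_drop_params, a per-call list for dropping only the named parameters while everything else keeps failing loudly. A minimal usage sketch, mirroring the e2e test added below (the response_format value is just the placeholder the test uses):

import litellm

# Cohere's command-r does not accept OpenAI's response_format parameter;
# listing it in additional_drop_params strips it from this one call only.
response = litellm.completion(
    model="command-r",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
    response_format={"key": "value"},
    additional_drop_params=["response_format"],
)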
@@ -691,7 +691,7 @@ class PrismaClient:
         finally:
             os.chdir(original_dir)
         # Now you can import the Prisma Client
-        from prisma import Prisma
+        from prisma import Prisma  # type: ignore

         self.db = Prisma()  # Client to connect to Prisma db
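The only change in this hunk is the new # type: ignore on the import: the prisma client module is generated on the fly (hence the os.chdir dance above), so a static type checker presumably cannot resolve it at lint time.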
@@ -1,20 +1,25 @@
 #### What this tests ####
 # This tests if get_optional_params works as expected
-import sys, os, time, inspect, asyncio, traceback
+import asyncio
+import inspect
+import os
+import sys
+import time
+import traceback
+
 import pytest

 sys.path.insert(0, os.path.abspath("../.."))
+from unittest.mock import MagicMock, patch
+
 import litellm
-from litellm.utils import get_optional_params_embeddings, get_optional_params
-from litellm.llms.prompt_templates.factory import (
-    map_system_message_pt,
-)
-from unittest.mock import patch, MagicMock
+from litellm.llms.prompt_templates.factory import map_system_message_pt
 from litellm.types.completion import (
-    ChatCompletionUserMessageParam,
-    ChatCompletionSystemMessageParam,
     ChatCompletionMessageParam,
+    ChatCompletionSystemMessageParam,
+    ChatCompletionUserMessageParam,
 )
+from litellm.utils import get_optional_params, get_optional_params_embeddings

 ## get_optional_params_embeddings
 ### Models: OpenAI, Azure, Bedrock
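The hunk above is mechanical: the import block is rewritten in what looks like isort style (stdlib, third-party, then first-party, each group alphabetized), and the duplicated prompt_templates.factory import is collapsed into one line. No behavior changes.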
@@ -286,3 +291,45 @@ def test_dynamic_drop_params_e2e():
     mock_response.assert_called_once()
     print(mock_response.call_args.kwargs["data"])
     assert "response_format" not in mock_response.call_args.kwargs["data"]
+
+
+@pytest.mark.parametrize("drop_params", [True, False, None])
+def test_dynamic_drop_additional_params(drop_params):
+    """
+    Make a call to cohere, dropping 'response_format' specifically
+    """
+    if drop_params is True:
+        optional_params = litellm.utils.get_optional_params(
+            model="command-r",
+            custom_llm_provider="cohere",
+            response_format="json",
+            additional_drop_params=["response_format"],
+        )
+    else:
+        try:
+            optional_params = litellm.utils.get_optional_params(
+                model="command-r",
+                custom_llm_provider="cohere",
+                response_format="json",
+            )
+            pytest.fail("Expected to fail")
+        except Exception as e:
+            pass
+
+
+def test_dynamic_drop_additional_params_e2e():
+    with patch("requests.post", new=MagicMock()) as mock_response:
+        try:
+            response = litellm.completion(
+                model="command-r",
+                messages=[{"role": "user", "content": "Hey, how's it going?"}],
+                response_format={"key": "value"},
+                additional_drop_params=["response_format"],
+            )
+        except Exception as e:
+            pass
+
+        mock_response.assert_called_once()
+        print(mock_response.call_args.kwargs["data"])
+        assert "response_format" not in mock_response.call_args.kwargs["data"]
+        assert "additional_drop_params" not in mock_response.call_args.kwargs["data"]
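Two details worth noticing in these tests. In the parametrized test, both drop_params=False and drop_params=None fall into the else branch: passing response_format to command-r is expected to raise unless it is explicitly dropped, since Cohere does not accept that parameter. In the e2e test, the final assertion also checks that additional_drop_params itself never leaks into the outgoing request body. To run just these two tests (the test file path is assumed from the usual litellm layout):

pytest litellm/tests/test_optional_params.py -k "drop_additional_params"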
@@ -2293,6 +2293,7 @@ def get_optional_params(
     extra_headers=None,
     api_version=None,
     drop_params=None,
+    additional_drop_params=None,
     **kwargs,
 ):
     # retrieve all parameters passed to the function
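Declaring additional_drop_params as an explicit keyword, instead of letting it land in **kwargs, is what allows the code below to recognize it, exclude it from non_default_params, and keep it out of the provider payload.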
@@ -2309,7 +2310,6 @@ def get_optional_params(
             k.startswith("vertex_") and custom_llm_provider != "vertex_ai"
         ):  # allow dynamically setting vertex ai init logic
             continue
-
         passed_params[k] = v

     optional_params: Dict = {}
@@ -2365,7 +2365,19 @@ def get_optional_params(
         "extra_headers": None,
         "api_version": None,
         "drop_params": None,
+        "additional_drop_params": None,
     }
+
+    def _should_drop_param(k, additional_drop_params) -> bool:
+        if (
+            additional_drop_params is not None
+            and isinstance(additional_drop_params, list)
+            and k in additional_drop_params
+        ):
+            return True  # allow user to drop specific params for a model - e.g. vllm - logit bias
+
+        return False
+
     # filter out those parameters that were passed with non-default values
     non_default_params = {
         k: v
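The nested helper is deliberately defensive: additional_drop_params is honored only when it is an actual list, so a stray string or None silently disables the feature instead of misfiring. A self-contained sketch of the same filtering logic, using shortened hypothetical names rather than the library's real internals:

from typing import Any, Dict, List, Optional


def should_drop(key: str, drop: Optional[List[str]]) -> bool:
    # Mirrors _should_drop_param: only an actual list is honored,
    # and only for keys explicitly named in it.
    return isinstance(drop, list) and key in drop


passed: Dict[str, Any] = {
    "temperature": 0.2,
    "response_format": "json",
    "logit_bias": {42: -100},
}

filtered = {
    k: v
    for k, v in passed.items()
    if not should_drop(k, ["response_format", "logit_bias"])
}
print(filtered)  # -> {'temperature': 0.2}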
@@ -2375,8 +2387,11 @@ def get_optional_params(
         and k != "custom_llm_provider"
         and k != "api_version"
         and k != "drop_params"
+        and k != "additional_drop_params"
         and k in default_params
         and v != default_params[k]
+        and _should_drop_param(k=k, additional_drop_params=additional_drop_params)
+        is False
     )
 }
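Net effect: any key named in additional_drop_params never enters non_default_params, so it is removed before any provider-specific mapping runs, which is exactly what the cohere tests assert. The `... is False` comparison is a stylistic equivalent of wrapping the helper call in `not`.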