forked from phoenix/litellm-mirror
Merge pull request #4313 from BerriAI/litellm_drop_specific_params
fix(utils.py): allow dropping specific openai params
This commit is contained in:
commit
c8a40eca05
3 changed files with 72 additions and 10 deletions
|
@ -691,7 +691,7 @@ class PrismaClient:
|
||||||
finally:
|
finally:
|
||||||
os.chdir(original_dir)
|
os.chdir(original_dir)
|
||||||
# Now you can import the Prisma Client
|
# Now you can import the Prisma Client
|
||||||
from prisma import Prisma
|
from prisma import Prisma # type: ignore
|
||||||
|
|
||||||
self.db = Prisma() # Client to connect to Prisma db
|
self.db = Prisma() # Client to connect to Prisma db
|
||||||
|
|
||||||
|
|
|
@ -1,20 +1,25 @@
|
||||||
#### What this tests ####
|
#### What this tests ####
|
||||||
# This tests if get_optional_params works as expected
|
# This tests if get_optional_params works as expected
|
||||||
import sys, os, time, inspect, asyncio, traceback
|
import asyncio
|
||||||
|
import inspect
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
import traceback
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
sys.path.insert(0, os.path.abspath("../.."))
|
sys.path.insert(0, os.path.abspath("../.."))
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
import litellm
|
import litellm
|
||||||
from litellm.utils import get_optional_params_embeddings, get_optional_params
|
from litellm.llms.prompt_templates.factory import map_system_message_pt
|
||||||
from litellm.llms.prompt_templates.factory import (
|
|
||||||
map_system_message_pt,
|
|
||||||
)
|
|
||||||
from unittest.mock import patch, MagicMock
|
|
||||||
from litellm.types.completion import (
|
from litellm.types.completion import (
|
||||||
ChatCompletionUserMessageParam,
|
|
||||||
ChatCompletionSystemMessageParam,
|
|
||||||
ChatCompletionMessageParam,
|
ChatCompletionMessageParam,
|
||||||
|
ChatCompletionSystemMessageParam,
|
||||||
|
ChatCompletionUserMessageParam,
|
||||||
)
|
)
|
||||||
|
from litellm.utils import get_optional_params, get_optional_params_embeddings
|
||||||
|
|
||||||
## get_optional_params_embeddings
|
## get_optional_params_embeddings
|
||||||
### Models: OpenAI, Azure, Bedrock
|
### Models: OpenAI, Azure, Bedrock
|
||||||
|
@ -286,3 +291,45 @@ def test_dynamic_drop_params_e2e():
|
||||||
mock_response.assert_called_once()
|
mock_response.assert_called_once()
|
||||||
print(mock_response.call_args.kwargs["data"])
|
print(mock_response.call_args.kwargs["data"])
|
||||||
assert "response_format" not in mock_response.call_args.kwargs["data"]
|
assert "response_format" not in mock_response.call_args.kwargs["data"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("drop_params", [True, False, None])
|
||||||
|
def test_dynamic_drop_additional_params(drop_params):
|
||||||
|
"""
|
||||||
|
Make a call to cohere, dropping 'response_format' specifically
|
||||||
|
"""
|
||||||
|
if drop_params is True:
|
||||||
|
optional_params = litellm.utils.get_optional_params(
|
||||||
|
model="command-r",
|
||||||
|
custom_llm_provider="cohere",
|
||||||
|
response_format="json",
|
||||||
|
additional_drop_params=["response_format"],
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
optional_params = litellm.utils.get_optional_params(
|
||||||
|
model="command-r",
|
||||||
|
custom_llm_provider="cohere",
|
||||||
|
response_format="json",
|
||||||
|
)
|
||||||
|
pytest.fail("Expected to fail")
|
||||||
|
except Exception as e:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def test_dynamic_drop_additional_params_e2e():
|
||||||
|
with patch("requests.post", new=MagicMock()) as mock_response:
|
||||||
|
try:
|
||||||
|
response = litellm.completion(
|
||||||
|
model="command-r",
|
||||||
|
messages=[{"role": "user", "content": "Hey, how's it going?"}],
|
||||||
|
response_format={"key": "value"},
|
||||||
|
additional_drop_params=["response_format"],
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
pass
|
||||||
|
|
||||||
|
mock_response.assert_called_once()
|
||||||
|
print(mock_response.call_args.kwargs["data"])
|
||||||
|
assert "response_format" not in mock_response.call_args.kwargs["data"]
|
||||||
|
assert "additional_drop_params" not in mock_response.call_args.kwargs["data"]
|
||||||
|
|
|
@ -2293,6 +2293,7 @@ def get_optional_params(
|
||||||
extra_headers=None,
|
extra_headers=None,
|
||||||
api_version=None,
|
api_version=None,
|
||||||
drop_params=None,
|
drop_params=None,
|
||||||
|
additional_drop_params=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
# retrieve all parameters passed to the function
|
# retrieve all parameters passed to the function
|
||||||
|
@ -2309,7 +2310,6 @@ def get_optional_params(
|
||||||
k.startswith("vertex_") and custom_llm_provider != "vertex_ai"
|
k.startswith("vertex_") and custom_llm_provider != "vertex_ai"
|
||||||
): # allow dynamically setting vertex ai init logic
|
): # allow dynamically setting vertex ai init logic
|
||||||
continue
|
continue
|
||||||
|
|
||||||
passed_params[k] = v
|
passed_params[k] = v
|
||||||
|
|
||||||
optional_params: Dict = {}
|
optional_params: Dict = {}
|
||||||
|
@ -2365,7 +2365,19 @@ def get_optional_params(
|
||||||
"extra_headers": None,
|
"extra_headers": None,
|
||||||
"api_version": None,
|
"api_version": None,
|
||||||
"drop_params": None,
|
"drop_params": None,
|
||||||
|
"additional_drop_params": None,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def _should_drop_param(k, additional_drop_params) -> bool:
|
||||||
|
if (
|
||||||
|
additional_drop_params is not None
|
||||||
|
and isinstance(additional_drop_params, list)
|
||||||
|
and k in additional_drop_params
|
||||||
|
):
|
||||||
|
return True # allow user to drop specific params for a model - e.g. vllm - logit bias
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
# filter out those parameters that were passed with non-default values
|
# filter out those parameters that were passed with non-default values
|
||||||
non_default_params = {
|
non_default_params = {
|
||||||
k: v
|
k: v
|
||||||
|
@ -2375,8 +2387,11 @@ def get_optional_params(
|
||||||
and k != "custom_llm_provider"
|
and k != "custom_llm_provider"
|
||||||
and k != "api_version"
|
and k != "api_version"
|
||||||
and k != "drop_params"
|
and k != "drop_params"
|
||||||
|
and k != "additional_drop_params"
|
||||||
and k in default_params
|
and k in default_params
|
||||||
and v != default_params[k]
|
and v != default_params[k]
|
||||||
|
and _should_drop_param(k=k, additional_drop_params=additional_drop_params)
|
||||||
|
is False
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue