forked from phoenix/litellm-mirror
test(test_optional_params.py): unit tests for get_optional_params_embeddings()
This commit is contained in:
parent
e1bffe3de6
commit
51d62189f1
3 changed files with 36 additions and 7 deletions
|
@ -318,16 +318,18 @@ async def test_async_custom_handler_embedding_optional_param_bedrock():
|
||||||
|
|
||||||
but makes sure these are not sent to the non-openai/azure endpoint (raises errors).
|
but makes sure these are not sent to the non-openai/azure endpoint (raises errors).
|
||||||
"""
|
"""
|
||||||
|
litellm.drop_params = True
|
||||||
|
litellm.set_verbose = True
|
||||||
customHandler_optional_params = MyCustomHandler()
|
customHandler_optional_params = MyCustomHandler()
|
||||||
litellm.callbacks = [customHandler_optional_params]
|
litellm.callbacks = [customHandler_optional_params]
|
||||||
response = await litellm.aembedding(
|
response = await litellm.aembedding(
|
||||||
model="azure/azure-embedding-model",
|
model="bedrock/cohere.embed-multilingual-v3",
|
||||||
input = ["hello world"],
|
input = ["hello world"],
|
||||||
user = "John"
|
user = "John"
|
||||||
)
|
)
|
||||||
await asyncio.sleep(1) # success callback is async
|
await asyncio.sleep(1) # success callback is async
|
||||||
assert customHandler_optional_params.user == "John"
|
assert customHandler_optional_params.user == "John"
|
||||||
assert customHandler_optional_params.user == customHandler_optional_params.data_sent_to_api["user"]
|
assert "user" not in customHandler_optional_params.data_sent_to_api
|
||||||
|
|
||||||
|
|
||||||
def test_redis_cache_completion_stream():
|
def test_redis_cache_completion_stream():
|
||||||
|
|
27
litellm/tests/test_optional_params.py
Normal file
27
litellm/tests/test_optional_params.py
Normal file
|
@ -0,0 +1,27 @@
|
||||||
|
#### What this tests ####
|
||||||
|
# This tests if get_optional_params works as expected
|
||||||
|
import sys, os, time, inspect, asyncio, traceback
|
||||||
|
import pytest
|
||||||
|
sys.path.insert(0, os.path.abspath('../..'))
|
||||||
|
import litellm
|
||||||
|
from litellm.utils import get_optional_params_embeddings
|
||||||
|
## get_optional_params_embeddings
|
||||||
|
### Models: OpenAI, Azure, Bedrock
|
||||||
|
### Scenarios: w/ optional params + litellm.drop_params = True
|
||||||
|
|
||||||
|
def test_bedrock_optional_params_embeddings():
|
||||||
|
litellm.drop_params = True
|
||||||
|
optional_params = get_optional_params_embeddings(user="John", encoding_format=None, custom_llm_provider="bedrock")
|
||||||
|
assert len(optional_params) == 0
|
||||||
|
|
||||||
|
def test_openai_optional_params_embeddings():
|
||||||
|
litellm.drop_params = True
|
||||||
|
optional_params = get_optional_params_embeddings(user="John", encoding_format=None, custom_llm_provider="openai")
|
||||||
|
assert len(optional_params) == 1
|
||||||
|
assert optional_params["user"] == "John"
|
||||||
|
|
||||||
|
def test_azure_optional_params_embeddings():
|
||||||
|
litellm.drop_params = True
|
||||||
|
optional_params = get_optional_params_embeddings(user="John", encoding_format=None, custom_llm_provider="azure")
|
||||||
|
assert len(optional_params) == 1
|
||||||
|
assert optional_params["user"] == "John"
|
|
@ -2170,14 +2170,14 @@ def get_optional_params_embeddings(
|
||||||
}
|
}
|
||||||
|
|
||||||
non_default_params = {k: v for k, v in passed_params.items() if (k in default_params and v != default_params[k])}
|
non_default_params = {k: v for k, v in passed_params.items() if (k in default_params and v != default_params[k])}
|
||||||
|
|
||||||
## raise exception if non-default value passed for non-openai/azure embedding calls
|
## raise exception if non-default value passed for non-openai/azure embedding calls
|
||||||
if custom_llm_provider != "openai" and custom_llm_provider != "azure":
|
if custom_llm_provider != "openai" and custom_llm_provider != "azure":
|
||||||
if len(non_default_params.keys()) > 0:
|
if len(non_default_params.keys()) > 0:
|
||||||
if litellm.drop_params is True:
|
if litellm.drop_params is True: # drop the unsupported non-default values
|
||||||
for k in non_default_params.keys():
|
keys = list(non_default_params.keys())
|
||||||
passed_params.pop(k, None)
|
for k in keys:
|
||||||
return passed_params
|
non_default_params.pop(k, None)
|
||||||
|
return non_default_params
|
||||||
raise UnsupportedParamsError(status_code=500, message=f"Setting user/encoding format is not supported by {custom_llm_provider}. To drop it from the call, set `litellm.drop_params = True`.")
|
raise UnsupportedParamsError(status_code=500, message=f"Setting user/encoding format is not supported by {custom_llm_provider}. To drop it from the call, set `litellm.drop_params = True`.")
|
||||||
|
|
||||||
final_params = {**non_default_params, **kwargs}
|
final_params = {**non_default_params, **kwargs}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue