forked from phoenix/litellm-mirror
test(test_completion.py): reintegrate testing for huggingface tgi + non-tgi
This commit is contained in:
parent 781d5888c3
commit c17f221b89

3 changed files with 218 additions and 58 deletions
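Note: the hunks reproduced below cover only the parameter-mapping utilities; test_completion.py is among the 3 changed files but its hunks are not shown here. As a rough, hedged sketch of what the reintegrated huggingface TGI / non-TGI completion tests might look like — the model IDs, api_base URL, and test names are illustrative assumptions, not taken from the commit:

# Hypothetical sketch, not taken from this commit's diff: possible shape of the
# reintegrated huggingface TGI / non-TGI completion tests.
# Model IDs, the api_base URL, and test names are illustrative assumptions.
import pytest

import litellm


def test_completion_huggingface_tgi():
    # TGI endpoint: OpenAI-style params (max_tokens, temperature) must be
    # translated to TGI's max_new_tokens / strictly positive temperature.
    try:
        response = litellm.completion(
            model="huggingface/HuggingFaceH4/zephyr-7b-beta",
            messages=[{"role": "user", "content": "Hello, how are you?"}],
            max_tokens=10,
            temperature=0.0,  # TGI rejects temperature == 0; the mapping should handle it
        )
        print(response)
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")


def test_completion_huggingface_non_tgi():
    # Non-TGI endpoint behind a custom api_base (placeholder URL).
    try:
        response = litellm.completion(
            model="huggingface/facebook/blenderbot-400M-distill",
            messages=[{"role": "user", "content": "Hello, how are you?"}],
            api_base="https://my-endpoint.endpoints.huggingface.cloud",
            max_tokens=10,
        )
        print(response)
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")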
@@ -4840,6 +4840,7 @@ def get_optional_params_embeddings(
 def get_optional_params(
     # use the openai defaults
     # https://platform.openai.com/docs/api-reference/chat/create
+    model: str,
     functions=None,
     function_call=None,
     temperature=None,
@@ -4853,7 +4854,6 @@ def get_optional_params(
     frequency_penalty=None,
     logit_bias=None,
     user=None,
-    model=None,
     custom_llm_provider="",
     response_format=None,
     seed=None,
@@ -4882,7 +4882,7 @@ def get_optional_params(
 
         passed_params[k] = v
 
-    optional_params = {}
+    optional_params: Dict = {}
 
     common_auth_dict = litellm.common_cloud_provider_auth_params
     if custom_llm_provider in common_auth_dict["providers"]:
@@ -5156,41 +5156,9 @@ def get_optional_params(
             model=model, custom_llm_provider=custom_llm_provider
         )
         _check_valid_arg(supported_params=supported_params)
-        # temperature, top_p, n, stream, stop, max_tokens, n, presence_penalty default to None
-        if temperature is not None:
-            if temperature == 0.0 or temperature == 0:
-                # hugging face exception raised when temp==0
-                # Failed: Error occurred: HuggingfaceException - Input validation error: `temperature` must be strictly positive
-                temperature = 0.01
-            optional_params["temperature"] = temperature
-        if top_p is not None:
-            optional_params["top_p"] = top_p
-        if n is not None:
-            optional_params["best_of"] = n
-            optional_params["do_sample"] = (
-                True  # Need to sample if you want best of for hf inference endpoints
-            )
-        if stream is not None:
-            optional_params["stream"] = stream
-        if stop is not None:
-            optional_params["stop"] = stop
-        if max_tokens is not None:
-            # HF TGI raises the following exception when max_new_tokens==0
-            # Failed: Error occurred: HuggingfaceException - Input validation error: `max_new_tokens` must be strictly positive
-            if max_tokens == 0:
-                max_tokens = 1
-            optional_params["max_new_tokens"] = max_tokens
-        if n is not None:
-            optional_params["best_of"] = n
-        if presence_penalty is not None:
-            optional_params["repetition_penalty"] = presence_penalty
-        if "echo" in passed_params:
-            # https://huggingface.co/docs/huggingface_hub/main/en/package_reference/inference_client#huggingface_hub.InferenceClient.text_generation.decoder_input_details
-            # Return the decoder input token logprobs and ids. You must set details=True as well for it to be taken into account. Defaults to False
-            optional_params["decoder_input_details"] = special_params["echo"]
-            passed_params.pop(
-                "echo", None
-            )  # since we handle translating echo, we should not send it to TGI request
+        optional_params = litellm.HuggingfaceConfig().map_openai_params(
+            non_default_params=non_default_params, optional_params=optional_params
+        )
     elif custom_llm_provider == "together_ai":
         ## check if unsupported param passed in
         supported_params = get_supported_openai_params(
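For reference, a minimal sketch of the new code path in isolation, using the exact keyword arguments shown in the hunk above. The sample parameter values are assumptions, and whether the clamping behaviour of the deleted inline code (temperature == 0 -> 0.01, max_tokens == 0 -> 1) is preserved inside HuggingfaceConfig is not visible in this diff.

# Illustrative only (not part of the diff): invoking the provider config mapping
# directly, with the same keyword arguments used in the hunk above.
# The sample values are assumptions.
import litellm

non_default_params = {"temperature": 0.0, "max_tokens": 0, "stream": True}
optional_params: dict = {}

optional_params = litellm.HuggingfaceConfig().map_openai_params(
    non_default_params=non_default_params, optional_params=optional_params
)
# Assumed outcome: OpenAI names translated to TGI names, e.g.
# max_tokens -> max_new_tokens, with TGI's strictly-positive constraints applied.
print(optional_params)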
@@ -6150,7 +6118,7 @@ def get_supported_openai_params(model: str, custom_llm_provider: str):
             "seed",
         ]
     elif custom_llm_provider == "huggingface":
-        return ["stream", "temperature", "max_tokens", "top_p", "stop", "n"]
+        return litellm.HuggingfaceConfig().get_supported_openai_params()
     elif custom_llm_provider == "together_ai":
         return [
             "stream",
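A brief usage sketch of the updated lookup, assuming the function is importable from litellm.utils and using an arbitrary example model name (an assumption, not from the diff):

# Usage sketch; the model name is an arbitrary example.
from litellm.utils import get_supported_openai_params

supported = get_supported_openai_params(
    model="HuggingFaceH4/zephyr-7b-beta", custom_llm_provider="huggingface"
)
# Previously this branch returned the hard-coded list
# ["stream", "temperature", "max_tokens", "top_p", "stop", "n"];
# it now defers to litellm.HuggingfaceConfig().get_supported_openai_params().
print(supported)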