Merge pull request #3571 from BerriAI/litellm_hf_classifier_support

Huggingface classifier support
Krish Dholakia 2024-05-10 17:54:27 -07:00 committed by GitHub
commit 1aa567f3b5
6 changed files with 415 additions and 64 deletions


@@ -4871,6 +4871,7 @@ def get_optional_params_embeddings(
 def get_optional_params(
     # use the openai defaults
     # https://platform.openai.com/docs/api-reference/chat/create
+    model: str,
     functions=None,
     function_call=None,
     temperature=None,
@@ -4884,7 +4885,6 @@ def get_optional_params(
     frequency_penalty=None,
     logit_bias=None,
     user=None,
-    model=None,
     custom_llm_provider="",
     response_format=None,
     seed=None,
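With `model` promoted to a required, typed parameter (and the old `model=None` keyword default removed), callers now always supply the model name. A minimal sketch of a call against the updated signature, assuming `get_optional_params` is importable from `litellm.utils` and using a placeholder model name:

```python
from litellm.utils import get_optional_params

# Sketch: `model` is now a required argument; the remaining OpenAI-style
# parameters keep the keyword defaults shown in the signature above.
optional_params = get_optional_params(
    model="HuggingFaceH4/zephyr-7b-beta",  # placeholder model name
    custom_llm_provider="huggingface",
    temperature=0.7,
    max_tokens=256,
)
print(optional_params)
```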
@@ -4913,7 +4913,7 @@ def get_optional_params(
         passed_params[k] = v
-    optional_params = {}
+    optional_params: Dict = {}
     common_auth_dict = litellm.common_cloud_provider_auth_params
     if custom_llm_provider in common_auth_dict["providers"]:
@@ -5187,41 +5187,9 @@ def get_optional_params(
             model=model, custom_llm_provider=custom_llm_provider
         )
         _check_valid_arg(supported_params=supported_params)
-        # temperature, top_p, n, stream, stop, max_tokens, n, presence_penalty default to None
-        if temperature is not None:
-            if temperature == 0.0 or temperature == 0:
-                # hugging face exception raised when temp==0
-                # Failed: Error occurred: HuggingfaceException - Input validation error: `temperature` must be strictly positive
-                temperature = 0.01
-            optional_params["temperature"] = temperature
-        if top_p is not None:
-            optional_params["top_p"] = top_p
-        if n is not None:
-            optional_params["best_of"] = n
-            optional_params["do_sample"] = (
-                True  # Need to sample if you want best of for hf inference endpoints
-            )
-        if stream is not None:
-            optional_params["stream"] = stream
-        if stop is not None:
-            optional_params["stop"] = stop
-        if max_tokens is not None:
-            # HF TGI raises the following exception when max_new_tokens==0
-            # Failed: Error occurred: HuggingfaceException - Input validation error: `max_new_tokens` must be strictly positive
-            if max_tokens == 0:
-                max_tokens = 1
-            optional_params["max_new_tokens"] = max_tokens
-        if n is not None:
-            optional_params["best_of"] = n
-        if presence_penalty is not None:
-            optional_params["repetition_penalty"] = presence_penalty
-        if "echo" in passed_params:
-            # https://huggingface.co/docs/huggingface_hub/main/en/package_reference/inference_client#huggingface_hub.InferenceClient.text_generation.decoder_input_details
-            # Return the decoder input token logprobs and ids. You must set details=True as well for it to be taken into account. Defaults to False
-            optional_params["decoder_input_details"] = special_params["echo"]
-            passed_params.pop(
-                "echo", None
-            )  # since we handle translating echo, we should not send it to TGI request
+        optional_params = litellm.HuggingfaceConfig().map_openai_params(
+            non_default_params=non_default_params, optional_params=optional_params
+        )
     elif custom_llm_provider == "together_ai":
         ## check if unsupported param passed in
         supported_params = get_supported_openai_params(
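The inline Hugging Face mapping removed in this hunk now lives behind `litellm.HuggingfaceConfig().map_openai_params`. Below is a minimal standalone sketch of an equivalent translation, reconstructed from the deleted lines; the function name `map_hf_openai_params` is hypothetical, and the real config class may differ in details:

```python
from typing import Dict


def map_hf_openai_params(non_default_params: Dict, optional_params: Dict) -> Dict:
    """Sketch of the OpenAI -> Hugging Face TGI parameter translation that the
    deleted inline block performed. litellm.HuggingfaceConfig().map_openai_params
    is the real implementation and may differ."""
    for param, value in non_default_params.items():
        if param == "temperature":
            # TGI rejects temperature == 0 ("must be strictly positive")
            optional_params["temperature"] = 0.01 if value in (0, 0.0) else value
        elif param == "top_p":
            optional_params["top_p"] = value
        elif param == "n":
            optional_params["best_of"] = value
            # sampling must be enabled to use best_of on HF inference endpoints
            optional_params["do_sample"] = True
        elif param == "stream":
            optional_params["stream"] = value
        elif param == "stop":
            optional_params["stop"] = value
        elif param == "max_tokens":
            # TGI rejects max_new_tokens == 0
            optional_params["max_new_tokens"] = value if value != 0 else 1
        elif param == "presence_penalty":
            optional_params["repetition_penalty"] = value
        elif param == "echo":
            # maps to decoder_input_details; details=True must also be set on the request
            optional_params["decoder_input_details"] = value
    return optional_params
```

Centralizing this translation in the provider config keeps `get_optional_params` free of per-provider special cases.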
@@ -6181,7 +6149,7 @@ def get_supported_openai_params(model: str, custom_llm_provider: str):
             "seed",
         ]
     elif custom_llm_provider == "huggingface":
-        return ["stream", "temperature", "max_tokens", "top_p", "stop", "n"]
+        return litellm.HuggingfaceConfig().get_supported_openai_params()
     elif custom_llm_provider == "together_ai":
         return [
             "stream",