fix get optional params

Krrish Dholakia 2023-10-02 12:02:45 -07:00
parent 8f1b88c40b
commit 5a19ee1a71
10 changed files with 93 additions and 75 deletions
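Summary of the change in utils.py below: the per-provider checks in get_optional_params switch from comparisons against the OpenAI defaults (e.g. temperature != 1, max_tokens != float("inf")) to plain truthiness checks (e.g. if temperature:), the Hugging Face return_full_text/details handling is dropped, and a few debug prints are added. A minimal standalone sketch of the two check styles, using temperature as a stand-in; this is illustrative only and not LiteLLM's actual get_optional_params signature:

# Illustrative sketch only: the parameter name mirrors the diff, but this is
# not LiteLLM's real get_optional_params.
def old_style(temperature=1):
    # removed lines: forward a value only when it differs from the OpenAI default
    optional_params = {}
    if temperature != 1:
        optional_params["temperature"] = temperature
    return optional_params

def new_style(temperature=1):
    # added lines: forward any truthy value
    optional_params = {}
    if temperature:
        optional_params["temperature"] = temperature
    return optional_params

print(old_style(temperature=0.7))  # {'temperature': 0.7}
print(new_style(temperature=0.7))  # {'temperature': 0.7}
print(old_style(temperature=0))    # {'temperature': 0} -- 0 differs from the default of 1
print(new_style(temperature=0))    # {}                 -- 0 is falsy, so the truthy check drops it
print(old_style())                 # {}                 -- the default itself is never forwarded
print(new_style())                 # {'temperature': 1} -- a truthy default is forwarded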


@@ -977,6 +977,9 @@ def get_optional_params( # use the openai defaults
 raise ValueError("LiteLLM.Exception: Function calling is not supported by this provider")
 def _check_valid_arg(supported_params):
+print(f"checking params for {model}")
+print(f"params passed in {passed_params}")
+print(f"non-default params passed in {non_default_params}")
 unsupported_params = [k for k in non_default_params.keys() if k not in supported_params]
 if unsupported_params:
 raise ValueError("LiteLLM.Exception: Unsupported parameters passed: {}".format(', '.join(unsupported_params)))
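For context, _check_valid_arg (the helper the prints above were added to) implements a simple guard: collect the caller's non-default parameters, then raise if any of them is missing from the current provider's supported list. A standalone sketch of that pattern with hypothetical inputs; the non-default params dict is assumed to be computed upstream, as in the real code:

# Standalone sketch of the unsupported-parameter guard (not LiteLLM's internals).
def check_valid_args(non_default_params, supported_params):
    unsupported_params = [k for k in non_default_params.keys() if k not in supported_params]
    if unsupported_params:
        raise ValueError("LiteLLM.Exception: Unsupported parameters passed: {}".format(', '.join(unsupported_params)))

try:
    # 'logit_bias' is not in this hypothetical provider's supported list, so this raises.
    check_valid_args(
        non_default_params={"temperature": 0.7, "logit_bias": {123: 5}},
        supported_params=["stream", "stop", "temperature", "top_p", "max_tokens"],
    )
except ValueError as err:
    print(err)  # LiteLLM.Exception: Unsupported parameters passed: logit_bias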
@@ -990,15 +993,14 @@ def get_optional_params( # use the openai defaults
 # handle anthropic params
 if stream:
 optional_params["stream"] = stream
-if stop != None:
+if stop:
 optional_params["stop_sequences"] = stop
-if temperature != 1:
+if temperature:
 optional_params["temperature"] = temperature
-if top_p != 1:
+if top_p:
 optional_params["top_p"] = top_p
-if max_tokens != float("inf"):
+if max_tokens:
 optional_params["max_tokens_to_sample"] = max_tokens
 return optional_params
 elif custom_llm_provider == "cohere":
 ## check if unsupported param passed in
 supported_params = ["stream", "temperature", "max_tokens", "logit_bias"]
@@ -1006,13 +1008,12 @@ def get_optional_params( # use the openai defaults
 # handle cohere params
 if stream:
 optional_params["stream"] = stream
-if temperature != 1:
+if temperature:
 optional_params["temperature"] = temperature
-if max_tokens != float("inf"):
+if max_tokens:
 optional_params["max_tokens"] = max_tokens
 if logit_bias != {}:
 optional_params["logit_bias"] = logit_bias
 return optional_params
 elif custom_llm_provider == "replicate":
 ## check if unsupported param passed in
 supported_params = ["stream", "temperature", "max_tokens", "top_p", "stop"]
@@ -1021,39 +1022,37 @@ def get_optional_params( # use the openai defaults
 if stream:
 optional_params["stream"] = stream
 return optional_params
-if max_tokens != float("inf"):
+if max_tokens:
 if "vicuna" in model or "flan" in model:
 optional_params["max_length"] = max_tokens
 else:
 optional_params["max_new_tokens"] = max_tokens
-if temperature != 1:
+if temperature:
 optional_params["temperature"] = temperature
-if top_p != 1:
+if top_p:
 optional_params["top_p"] = top_p
-if stop != None:
+if stop:
 optional_params["stop_sequences"] = stop
 elif custom_llm_provider == "huggingface":
 ## check if unsupported param passed in
-supported_params = ["stream", "temperature", "max_tokens", "top_p", "stop", "return_full_text", "details"]
+supported_params = ["stream", "temperature", "max_tokens", "top_p", "stop",]
 _check_valid_arg(supported_params=supported_params)
-if temperature != 1:
+if temperature:
 optional_params["temperature"] = temperature
-if top_p != 1:
+if top_p:
 optional_params["top_p"] = top_p
-if n != 1:
+if n:
 optional_params["best_of"] = n
 optional_params["do_sample"] = True # need to sample if you want best of for hf inference endpoints
 if stream:
 optional_params["stream"] = stream
-if stop != None:
+if stop:
 optional_params["stop"] = stop
-if max_tokens != float("inf"):
+if max_tokens:
 optional_params["max_new_tokens"] = max_tokens
-if presence_penalty != 0:
+if presence_penalty:
 optional_params["repetition_penalty"] = presence_penalty
-optional_params["return_full_text"] = return_full_text
-optional_params["details"] = True
 elif custom_llm_provider == "together_ai":
 ## check if unsupported param passed in
 supported_params = ["stream", "temperature", "max_tokens", "top_p", "stop", "frequency_penalty"]
@@ -1061,24 +1060,24 @@ def get_optional_params( # use the openai defaults
 if stream:
 optional_params["stream_tokens"] = stream
-if temperature != 1:
+if temperature:
 optional_params["temperature"] = temperature
-if top_p != 1:
+if top_p:
 optional_params["top_p"] = top_p
-if max_tokens != float("inf"):
+if max_tokens:
 optional_params["max_tokens"] = max_tokens
-if frequency_penalty != 0:
+if frequency_penalty:
 optional_params["frequency_penalty"] = frequency_penalty # TODO: Check if should be repetition penalty
-if stop != None:
+if stop:
 optional_params["stop"] = stop #TG AI expects a list, example ["\n\n\n\n","<|endoftext|>"]
 elif custom_llm_provider == "palm":
 ## check if unsupported param passed in
 supported_params = ["temperature", "top_p"]
 _check_valid_arg(supported_params=supported_params)
-if temperature != 1:
+if temperature:
 optional_params["temperature"] = temperature
-if top_p != 1:
+if top_p:
 optional_params["top_p"] = top_p
 elif (
 custom_llm_provider == "vertex_ai"
@@ -1087,13 +1086,13 @@ def get_optional_params( # use the openai defaults
 supported_params = ["temperature", "top_p", "max_tokens", "stream"]
 _check_valid_arg(supported_params=supported_params)
-if temperature != 1:
+if temperature:
 optional_params["temperature"] = temperature
-if top_p != 1:
+if top_p:
 optional_params["top_p"] = top_p
 if stream:
 optional_params["stream"] = stream
-if max_tokens != float("inf"):
+if max_tokens:
 optional_params["max_output_tokens"] = max_tokens
 elif custom_llm_provider == "sagemaker":
 if "llama-2" in model:
@@ -1108,11 +1107,11 @@ def get_optional_params( # use the openai defaults
 supported_params = ["temperature", "max_tokens"]
 _check_valid_arg(supported_params=supported_params)
-if max_tokens != float("inf"):
+if max_tokens:
 optional_params["max_new_tokens"] = max_tokens
-if temperature != 1:
+if temperature:
 optional_params["temperature"] = temperature
-if top_p != 1:
+if top_p:
 optional_params["top_p"] = top_p
 else:
 ## check if unsupported param passed in
@@ -1124,92 +1123,90 @@ def get_optional_params( # use the openai defaults
 _check_valid_arg(supported_params=supported_params)
 # params "maxTokens":200,"temperature":0,"topP":250,"stop_sequences":[],
 # https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=j2-ultra
-if max_tokens != float("inf"):
+if max_tokens:
 optional_params["maxTokens"] = max_tokens
-if temperature != 1:
+if temperature:
 optional_params["temperature"] = temperature
-if stop != None:
+if stop:
 optional_params["stop_sequences"] = stop
-if top_p != 1:
+if top_p:
 optional_params["topP"] = top_p
 elif "anthropic" in model:
 supported_params = ["max_tokens", "temperature", "stop", "top_p"]
 _check_valid_arg(supported_params=supported_params)
 # anthropic params on bedrock
 # \"max_tokens_to_sample\":300,\"temperature\":0.5,\"top_p\":1,\"stop_sequences\":[\"\\\\n\\\\nHuman:\"]}"
-if max_tokens != float("inf"):
+if max_tokens:
 optional_params["max_tokens_to_sample"] = max_tokens
 else:
 optional_params["max_tokens_to_sample"] = 256 # anthropic fails without max_tokens_to_sample
-if temperature != 1:
+if temperature:
 optional_params["temperature"] = temperature
-if top_p != 1:
+if top_p:
 optional_params["top_p"] = top_p
-if stop != None:
+if stop:
 optional_params["stop_sequences"] = stop
 elif "amazon" in model: # amazon titan llms
 supported_params = ["max_tokens", "temperature", "stop", "top_p"]
 _check_valid_arg(supported_params=supported_params)
 # see https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=titan-large
-if max_tokens != float("inf"):
+if max_tokens:
 optional_params["maxTokenCount"] = max_tokens
-if temperature != 1:
+if temperature:
 optional_params["temperature"] = temperature
-if stop != None:
+if stop:
 optional_params["stopSequences"] = stop
-if top_p != 1:
+if top_p:
 optional_params["topP"] = top_p
 elif model in litellm.aleph_alpha_models:
 supported_params = ["max_tokens", "stream", "top_p", "temperature", "presence_penalty", "frequency_penalty", "n", "stop"]
 _check_valid_arg(supported_params=supported_params)
-if max_tokens != float("inf"):
+if max_tokens:
 optional_params["maximum_tokens"] = max_tokens
 if stream:
 optional_params["stream"] = stream
-if temperature != 1:
+if temperature:
 optional_params["temperature"] = temperature
-if top_p != 1:
+if top_p:
 optional_params["top_p"] = top_p
-if presence_penalty != 0:
+if presence_penalty:
 optional_params["presence_penalty"] = presence_penalty
-if frequency_penalty != 0:
+if frequency_penalty:
 optional_params["frequency_penalty"] = frequency_penalty
-if n != 1:
+if n:
 optional_params["n"] = n
-if stop != None:
+if stop:
 optional_params["stop_sequences"] = stop
 elif model in litellm.nlp_cloud_models or custom_llm_provider == "nlp_cloud":
 supported_params = ["max_tokens", "stream", "temperature", "top_p", "presence_penalty", "frequency_penalty", "n", "stop"]
 _check_valid_arg(supported_params=supported_params)
-if max_tokens != float("inf"):
+if max_tokens:
 optional_params["max_length"] = max_tokens
 if stream:
 optional_params["stream"] = stream
-if temperature != 1:
+if temperature:
 optional_params["temperature"] = temperature
-if top_p != 1:
+if top_p:
 optional_params["top_p"] = top_p
-if presence_penalty != 0:
+if presence_penalty:
 optional_params["presence_penalty"] = presence_penalty
-if frequency_penalty != 0:
+if frequency_penalty:
 optional_params["frequency_penalty"] = frequency_penalty
-if n != 1:
+if n:
 optional_params["num_return_sequences"] = n
-if stop != None:
+if stop:
 optional_params["stop_sequences"] = stop
 elif model in litellm.petals_models or custom_llm_provider == "petals":
 supported_params = ["max_tokens", "temperature", "top_p"]
 _check_valid_arg(supported_params=supported_params)
 # max_new_tokens=1,temperature=0.9, top_p=0.6
-if max_tokens != float("inf"):
+if max_tokens:
 optional_params["max_new_tokens"] = max_tokens
 else:
 optional_params["max_new_tokens"] = 256 # petals always needs max_new_tokens
-if temperature != 1:
+if temperature:
 optional_params["temperature"] = temperature
-if top_p != 1:
+if top_p:
 optional_params["top_p"] = top_p
 else: # assume passing in params for openai/azure openai
 supported_params = ["functions", "function_call", "temperature", "top_p", "n", "stream", "stop", "max_tokens", "presence_penalty", "logit_bias", "user", "deployment_id"]
@@ -1219,6 +1216,7 @@ def get_optional_params( # use the openai defaults
 for k in passed_params.keys():
 if k not in default_params.keys():
 optional_params[k] = passed_params[k]
+print(f"final params going to model: {optional_params}")
 return optional_params
 def get_llm_provider(model: str, custom_llm_provider: Optional[str] = None):
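For reference, the loop in the final hunk forwards any caller-supplied parameter that is not among the tracked default params. A standalone illustration with made-up values (the actual default_params and passed_params are built earlier in get_optional_params):

# Standalone illustration of the pass-through at the end of get_optional_params:
# anything passed that is not an OpenAI-style default param is forwarded untouched.
default_params = {"temperature": 1, "top_p": 1, "stream": False, "max_tokens": float("inf")}
passed_params = {"temperature": 0.7, "stream": True, "top_k": 40}  # top_k: provider-specific extra

optional_params = {}
for k in passed_params.keys():
    if k not in default_params.keys():
        optional_params[k] = passed_params[k]

print(optional_params)  # {'top_k': 40}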