Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 02:34:29 +00:00)

fixing optional param mapping

This commit is contained in:
parent 7cec308a2c
commit 8b60d797e1

6 changed files with 32 additions and 21 deletions
@@ -1072,13 +1072,15 @@ def get_optional_params( # use the openai defaults
         optional_params["stop"] = stop # TG AI expects a list, example ["\n\n\n\n","<|endoftext|>"]
     elif custom_llm_provider == "palm":
         ## check if unsupported param passed in
-        supported_params = ["temperature", "top_p"]
+        supported_params = ["temperature", "top_p", "stream"]
         _check_valid_arg(supported_params=supported_params)

         if temperature:
             optional_params["temperature"] = temperature
         if top_p:
             optional_params["top_p"] = top_p
+        if stream:
+            optional_params["stream"] = stream
     elif (
         custom_llm_provider == "vertex_ai"
     ):
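For orientation, the guard above works by whitelisting the OpenAI-style arguments each provider understands and raising on anything else; adding "stream" to the palm whitelist is what lets a streaming request through to the mapping below. A minimal sketch of that pattern (check_valid_arg is a hypothetical stand-in for _check_valid_arg, whose internals are not shown in this diff):

# Minimal sketch of the whitelist pattern used above. check_valid_arg is a
# hypothetical stand-in for _check_valid_arg; the real helper's internals are
# not part of this diff.
def check_valid_arg(supported_params, passed_params):
    unsupported = [k for k in passed_params if k not in supported_params]
    if unsupported:
        raise ValueError(f"unsupported params for this provider: {unsupported}")

def map_palm_params(temperature=None, top_p=None, stream=False):
    # mirrors the truthiness checks in the diff: falsy values are simply skipped
    passed = {k: v for k, v in
              {"temperature": temperature, "top_p": top_p, "stream": stream}.items() if v}
    check_valid_arg(["temperature", "top_p", "stream"], passed)
    return dict(passed)

# map_palm_params(temperature=0.7, stream=True) -> {'temperature': 0.7, 'stream': True}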
@@ -1104,7 +1106,7 @@ def get_optional_params( # use the openai defaults
             return_full_text: If True, input text will be part of the output generated text. If specified, it must be boolean. The default value for it is False.
             """
             ## check if unsupported param passed in
-            supported_params = ["temperature", "max_tokens"]
+            supported_params = ["temperature", "max_tokens", "stream"]
             _check_valid_arg(supported_params=supported_params)

             if max_tokens:
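The practical effect of whitelisting "stream" is that callers can now request token streaming for these providers without tripping the unsupported-param check. A hedged usage sketch (the model string is a placeholder and the chunk handling is assumed from litellm's general streaming behaviour, not taken from this diff):

import litellm

# Placeholder model name; substitute a provider/model that litellm routes to
# one of the branches touched above.
response = litellm.completion(
    model="huggingface/<your-tgi-model>",
    messages=[{"role": "user", "content": "Hello!"}],
    stream=True,  # previously rejected by _check_valid_arg for this provider
)
for chunk in response:
    print(chunk)  # chunks arrive incrementally when streaming is enabled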
@@ -1113,13 +1115,15 @@ def get_optional_params( # use the openai defaults
                 optional_params["temperature"] = temperature
             if top_p:
                 optional_params["top_p"] = top_p
+            if stream:
+                optional_params["stream"] = stream
         else:
             ## check if unsupported param passed in
             supported_params = []
             _check_valid_arg(supported_params=supported_params)
     elif custom_llm_provider == "bedrock":
         if "ai21" in model:
-            supported_params = ["max_tokens", "temperature", "stop", "top_p"]
+            supported_params = ["max_tokens", "temperature", "stop", "top_p", "stream"]
             _check_valid_arg(supported_params=supported_params)
             # params "maxTokens":200,"temperature":0,"topP":250,"stop_sequences":[],
             # https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=j2-ultra
@@ -1131,8 +1135,10 @@ def get_optional_params( # use the openai defaults
                 optional_params["stop_sequences"] = stop
             if top_p:
                 optional_params["topP"] = top_p
+            if stream:
+                optional_params["stream"] = stream
         elif "anthropic" in model:
-            supported_params = ["max_tokens", "temperature", "stop", "top_p"]
+            supported_params = ["max_tokens", "temperature", "stop", "top_p", "stream"]
             _check_valid_arg(supported_params=supported_params)
             # anthropic params on bedrock
             # \"max_tokens_to_sample\":300,\"temperature\":0.5,\"top_p\":1,\"stop_sequences\":[\"\\\\n\\\\nHuman:\"]}"
@@ -1146,8 +1152,10 @@ def get_optional_params( # use the openai defaults
                 optional_params["top_p"] = top_p
             if stop:
                 optional_params["stop_sequences"] = stop
+            if stream:
+                optional_params["stream"] = stream
         elif "amazon" in model: # amazon titan llms
-            supported_params = ["max_tokens", "temperature", "stop", "top_p"]
+            supported_params = ["max_tokens", "temperature", "stop", "top_p", "stream"]
             _check_valid_arg(supported_params=supported_params)
             # see https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=titan-large
             if max_tokens:
@@ -1158,7 +1166,8 @@ def get_optional_params( # use the openai defaults
                 optional_params["stopSequences"] = stop
             if top_p:
                 optional_params["topP"] = top_p
-
+            if stream:
+                optional_params["stream"] = stream
     elif model in litellm.aleph_alpha_models:
         supported_params = ["max_tokens", "stream", "top_p", "temperature", "presence_penalty", "frequency_penalty", "n", "stop"]
         _check_valid_arg(supported_params=supported_params)
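These bedrock branches also show why the mapping has to be per sub-provider: the same OpenAI-style argument lands under a differently spelled key depending on the model family. A rough summary of the translations visible above (illustrative only; keys not shown in this diff are omitted):

# Key translations as they appear in the branches above; this table is a
# reading aid, not an exhaustive or authoritative mapping.
BEDROCK_PARAM_KEYS = {
    "ai21": {"top_p": "topP", "stop": "stop_sequences", "stream": "stream"},
    "anthropic": {"top_p": "top_p", "stop": "stop_sequences", "stream": "stream",
                  "max_tokens": "max_tokens_to_sample"},  # per the inline comment
    "amazon": {"top_p": "topP", "stop": "stopSequences", "stream": "stream"},
}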
@@ -3431,13 +3440,15 @@ def completion_with_split_tests(models={}, messages=[], use_client=False, overri
         return litellm.completion(model=selected_llm, messages=messages, use_client=use_client, **kwargs)

 def completion_with_fallbacks(**kwargs):
+    print(f"kwargs inside completion_with_fallbacks: {kwargs}")
+    nested_kwargs = kwargs.pop("kwargs")
     response = None
     rate_limited_models = set()
     model_expiration_times = {}
     start_time = time.time()
     original_model = kwargs["model"]
-    fallbacks = [kwargs["model"]] + kwargs["fallbacks"]
-    del kwargs["fallbacks"] # remove fallbacks so it's not recursive
+    fallbacks = [kwargs["model"]] + nested_kwargs["fallbacks"]
+    del nested_kwargs["fallbacks"] # remove fallbacks so it's not recursive

     while response == None and time.time() - start_time < 45:
         for model in fallbacks:
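The fix in this hunk is that the fallback list is now read from the nested litellm-specific kwargs (popped into nested_kwargs) instead of the top-level OpenAI-style kwargs. A sketch of the calling convention this implies, inferred from the diff rather than documented behaviour (model names are illustrative):

# Inferred input shape for completion_with_fallbacks after this change:
# OpenAI-style params stay at the top level, litellm-specific options
# (including "fallbacks") arrive nested under the "kwargs" key.
call_kwargs = {
    "model": "gpt-3.5-turbo",
    "messages": [{"role": "user", "content": "Hello!"}],
    "kwargs": {"fallbacks": ["claude-instant-1", "command-nightly"]},
}
# completion_with_fallbacks(**call_kwargs) would try "gpt-3.5-turbo" first,
# then the listed fallbacks, retrying for up to ~45 seconds.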
@@ -3466,8 +3477,10 @@ def completion_with_fallbacks(**kwargs):
                 if kwargs.get("model"):
                     del kwargs["model"]

+                print(f"trying to make completion call with model: {model}")
+                kwargs = {**kwargs, **nested_kwargs} # combine the openai + litellm params at the same level
                 response = litellm.completion(**kwargs, model=model)

                 print(f"response: {response}")
                 if response != None:
                     return response
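The added merge puts the nested litellm params back alongside the OpenAI params right before the completion call. Because the right-hand operand of a dict unpacking merge wins on key collisions, a value in nested_kwargs overrides a top-level one; a tiny illustration (the keys are illustrative):

openai_params = {"temperature": 0.2, "max_tokens": 100}
litellm_params = {"max_tokens": 50, "custom_llm_provider": "bedrock"}  # illustrative keys

merged = {**openai_params, **litellm_params}
# later entries win on collision, so the nested value overrides:
# {'temperature': 0.2, 'max_tokens': 50, 'custom_llm_provider': 'bedrock'}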