From 2f4cd3b5695ab67d94f17eea46fcc8ddb12806c2 Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Sat, 30 Dec 2023 10:55:42 +0530 Subject: [PATCH] (feat) proxy - support dynamic timeout per request --- litellm/main.py | 5 +---- litellm/utils.py | 27 +++++++++++++++++++++------ 2 files changed, 22 insertions(+), 10 deletions(-) diff --git a/litellm/main.py b/litellm/main.py index 8f00cbc88..0ee123b64 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -547,10 +547,6 @@ def completion( model_api_key = get_api_key( llm_provider=custom_llm_provider, dynamic_api_key=api_key ) # get the api key from the environment if required for the model - if model_api_key and "sk-litellm" in model_api_key: - api_base = "https://proxy.litellm.ai" - custom_llm_provider = "openai" - api_key = model_api_key if dynamic_api_key is not None: api_key = dynamic_api_key @@ -578,6 +574,7 @@ def completion( max_retries=max_retries, logprobs=logprobs, top_logprobs=top_logprobs, + timeout=timeout, **non_default_params, ) diff --git a/litellm/utils.py b/litellm/utils.py index 57ed9417a..0de060d9c 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -2898,6 +2898,7 @@ def get_optional_params( max_retries=None, logprobs=None, top_logprobs=None, + timeout=None, **kwargs, ): # retrieve all parameters passed to the function @@ -2927,6 +2928,7 @@ def get_optional_params( "max_retries": None, "logprobs": None, "top_logprobs": None, + "timeout": None, } # filter out those parameters that were passed with non-default values non_default_params = { @@ -3576,16 +3578,26 @@ def get_optional_params( "max_tokens", "stop", "frequency_penalty", - "presence_penalty" + "presence_penalty", ] - if model in ["mistralai/Mistral-7B-Instruct-v0.1", "mistralai/Mixtral-8x7B-Instruct-v0.1"]: - supported_params += ["functions", "function_call", "tools", "tool_choice", "response_format"] + if model in [ + "mistralai/Mistral-7B-Instruct-v0.1", + "mistralai/Mixtral-8x7B-Instruct-v0.1", + ]: + supported_params += [ + "functions", + "function_call", + "tools", + "tool_choice", + "response_format", + ] _check_valid_arg(supported_params=supported_params) optional_params = non_default_params if temperature is not None: - if ( - temperature == 0 and model in ["mistralai/Mistral-7B-Instruct-v0.1", "mistralai/Mixtral-8x7B-Instruct-v0.1"] - ): # this model does no support temperature == 0 + if temperature == 0 and model in [ + "mistralai/Mistral-7B-Instruct-v0.1", + "mistralai/Mixtral-8x7B-Instruct-v0.1", + ]: # this model does no support temperature == 0 temperature = 0.0001 # close to 0 optional_params["temperature"] = temperature if top_p: @@ -3710,6 +3722,7 @@ def get_optional_params( "max_retries", "logprobs", "top_logprobs", + "timeout", ] _check_valid_arg(supported_params=supported_params) if functions is not None: @@ -3750,6 +3763,8 @@ def get_optional_params( optional_params["logprobs"] = logprobs if top_logprobs is not None: optional_params["top_logprobs"] = top_logprobs + if timeout is not None: + optional_params["timeout"] = timeout # if user passed in non-default kwargs for specific providers/models, pass them along for k in passed_params.keys(): if k not in default_params.keys():