(feat) proxy - support dynamic timeout per request

ishaan-jaff 2023-12-30 10:55:42 +05:30
parent 459ba5b45e
commit 2f4cd3b569
2 changed files with 22 additions and 10 deletions
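Rough usage sketch (not part of the diff): with this change, the timeout passed to a completion() call is forwarded through get_optional_params(), so each request can carry its own timeout. The model name and timeout values below are made up for illustration.

import litellm

# Per-request timeouts: each call can carry its own value, which
# completion() now forwards via get_optional_params(timeout=...).
quick = litellm.completion(
    model="gpt-3.5-turbo",  # illustrative model
    messages=[{"role": "user", "content": "ping"}],
    timeout=5,    # short budget for a latency-sensitive call
)

slow = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Summarize this long report ..."}],
    timeout=120,  # longer budget for a heavier request
)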

View file

@@ -547,10 +547,6 @@ def completion(
         model_api_key = get_api_key(
             llm_provider=custom_llm_provider, dynamic_api_key=api_key
         )  # get the api key from the environment if required for the model
-        if model_api_key and "sk-litellm" in model_api_key:
-            api_base = "https://proxy.litellm.ai"
-            custom_llm_provider = "openai"
-            api_key = model_api_key
         if dynamic_api_key is not None:
             api_key = dynamic_api_key
@@ -578,6 +574,7 @@ def completion(
             max_retries=max_retries,
             logprobs=logprobs,
             top_logprobs=top_logprobs,
+            timeout=timeout,
             **non_default_params,
         )

View file

@@ -2898,6 +2898,7 @@ def get_optional_params(
     max_retries=None,
     logprobs=None,
     top_logprobs=None,
+    timeout=None,
     **kwargs,
 ):
     # retrieve all parameters passed to the function
@@ -2927,6 +2928,7 @@ def get_optional_params(
         "max_retries": None,
         "logprobs": None,
         "top_logprobs": None,
+        "timeout": None,
     }
     # filter out those parameters that were passed with non-default values
     non_default_params = {
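The hunk above is cut off right after non_default_params = {. A minimal sketch of the filtering the comment describes, assuming a straightforward comparison against default_params (the actual comprehension body is not shown in this diff):

# Sketch only: keep parameters whose passed value differs from the default.
# Because "timeout" is now in default_params, an explicit timeout survives
# this filter and reaches the provider-specific handling below.
non_default_params = {
    k: v
    for k, v in passed_params.items()
    if k in default_params and v != default_params[k]
}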
@@ -3576,16 +3578,26 @@ def get_optional_params(
             "max_tokens",
             "stop",
             "frequency_penalty",
-            "presence_penalty"
+            "presence_penalty",
         ]
-        if model in ["mistralai/Mistral-7B-Instruct-v0.1", "mistralai/Mixtral-8x7B-Instruct-v0.1"]:
-            supported_params += ["functions", "function_call", "tools", "tool_choice", "response_format"]
+        if model in [
+            "mistralai/Mistral-7B-Instruct-v0.1",
+            "mistralai/Mixtral-8x7B-Instruct-v0.1",
+        ]:
+            supported_params += [
+                "functions",
+                "function_call",
+                "tools",
+                "tool_choice",
+                "response_format",
+            ]
         _check_valid_arg(supported_params=supported_params)
         optional_params = non_default_params
         if temperature is not None:
-            if (
-                temperature == 0 and model in ["mistralai/Mistral-7B-Instruct-v0.1", "mistralai/Mixtral-8x7B-Instruct-v0.1"]
-            ):  # this model does no support temperature == 0
+            if temperature == 0 and model in [
+                "mistralai/Mistral-7B-Instruct-v0.1",
+                "mistralai/Mixtral-8x7B-Instruct-v0.1",
+            ]:  # this model does no support temperature == 0
                 temperature = 0.0001  # close to 0
             optional_params["temperature"] = temperature
         if top_p:
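A small standalone illustration of the reformatted guard above; the helper name is hypothetical and simply restates the hunk's logic outside of get_optional_params:

MODELS_WITHOUT_ZERO_TEMPERATURE = [
    "mistralai/Mistral-7B-Instruct-v0.1",
    "mistralai/Mixtral-8x7B-Instruct-v0.1",
]

def clamp_temperature(model: str, temperature: float) -> float:
    # These models do not accept an exact temperature of 0, so nudge it
    # to a near-zero value, mirroring the guard in get_optional_params.
    if temperature == 0 and model in MODELS_WITHOUT_ZERO_TEMPERATURE:
        return 0.0001  # close to 0
    return temperature

assert clamp_temperature("mistralai/Mixtral-8x7B-Instruct-v0.1", 0) == 0.0001
assert clamp_temperature("mistralai/Mixtral-8x7B-Instruct-v0.1", 0.7) == 0.7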
@@ -3710,6 +3722,7 @@ def get_optional_params(
             "max_retries",
             "logprobs",
             "top_logprobs",
+            "timeout",
         ]
         _check_valid_arg(supported_params=supported_params)
         if functions is not None:
@@ -3750,6 +3763,8 @@ def get_optional_params(
             optional_params["logprobs"] = logprobs
         if top_logprobs is not None:
             optional_params["top_logprobs"] = top_logprobs
+        if timeout is not None:
+            optional_params["timeout"] = timeout
     # if user passed in non-default kwargs for specific providers/models, pass them along
     for k in passed_params.keys():
         if k not in default_params.keys():
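Putting the two files together, a minimal sketch of the new plumbing (a deliberate simplification of the real functions; how the provider call consumes optional_params["timeout"] is outside this commit):

# Simplified stand-in for get_optional_params(): the timeout is recorded
# only when the caller actually set one.
def get_optional_params_sketch(timeout=None, **kwargs):
    optional_params = dict(kwargs)
    if timeout is not None:
        optional_params["timeout"] = timeout
    return optional_params

print(get_optional_params_sketch(timeout=10))  # {'timeout': 10}
print(get_optional_params_sketch())            # {} -- no timeout key when unset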