diff --git a/litellm/proxy/llm.py b/litellm/proxy/llm.py
index 1d51f1666b..7e467c4d2f 100644
--- a/litellm/proxy/llm.py
+++ b/litellm/proxy/llm.py
@@ -111,6 +111,7 @@ def litellm_completion(data: Dict,
                        user_model: Optional[str],
                        user_temperature: Optional[str],
                        user_max_tokens: Optional[int],
+                       user_request_timeout: Optional[int],
                        user_api_base: Optional[str],
                        user_headers: Optional[dict],
                        user_debug: bool,
@@ -123,6 +124,8 @@ def litellm_completion(data: Dict,
     # override with user settings
     if user_temperature:
         data["temperature"] = user_temperature
+    if user_request_timeout:
+        data["request_timeout"] = user_request_timeout
     if user_max_tokens:
         data["max_tokens"] = user_max_tokens
     if user_api_base:
diff --git a/litellm/proxy/proxy_cli.py b/litellm/proxy/proxy_cli.py
index 96e089caaa..2e2359b137 100644
--- a/litellm/proxy/proxy_cli.py
+++ b/litellm/proxy/proxy_cli.py
@@ -95,6 +95,7 @@ def is_port_in_use(port):
 @click.option('--debug', default=False, is_flag=True, type=bool, help='To debug the input')
 @click.option('--temperature', default=None, type=float, help='Set temperature for the model')
 @click.option('--max_tokens', default=None, type=int, help='Set max tokens for the model')
+@click.option('--request_timeout', default=600, type=int, help='Set timeout in seconds for completion calls')
 @click.option('--drop_params', is_flag=True, help='Drop any unmapped params')
 @click.option('--create_proxy', is_flag=True, help='Creates a local OpenAI-compatible server template')
 @click.option('--add_function_to_prompt', is_flag=True, help='If function passed but unsupported, pass it as prompt')
@@ -106,7 +107,7 @@ def is_port_in_use(port):
 @click.option('--test', flag_value=True, help='proxy chat completions url to make a test request to')
 @click.option('--local', is_flag=True, default=False, help='for local debugging')
 @click.option('--cost', is_flag=True, default=False, help='for viewing cost logs')
-def run_server(host, port, api_base, api_version, model, alias, add_key, headers, save, debug, temperature, max_tokens, drop_params, create_proxy, add_function_to_prompt, config, file, max_budget, telemetry, logs, test, local, cost):
+def run_server(host, port, api_base, api_version, model, alias, add_key, headers, save, debug, temperature, max_tokens, request_timeout, drop_params, create_proxy, add_function_to_prompt, config, file, max_budget, telemetry, logs, test, local, cost):
     global feature_telemetry
     args = locals()
     if local:
@@ -198,7 +199,7 @@ def run_server(host, port, api_base, api_version, model, alias, add_key, headers
     else:
         if headers:
             headers = json.loads(headers)
-        initialize(model=model, alias=alias, api_base=api_base, api_version=api_version, debug=debug, temperature=temperature, max_tokens=max_tokens, max_budget=max_budget, telemetry=telemetry, drop_params=drop_params, add_function_to_prompt=add_function_to_prompt, headers=headers, save=save)
+        initialize(model=model, alias=alias, api_base=api_base, api_version=api_version, debug=debug, temperature=temperature, max_tokens=max_tokens, request_timeout=request_timeout, max_budget=max_budget, telemetry=telemetry, drop_params=drop_params, add_function_to_prompt=add_function_to_prompt, headers=headers, save=save)
     try:
         import uvicorn
     except:
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index f82177418c..854e17bed9 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -114,6 +114,7 @@ user_api_base = None
 user_model = None
 user_debug = False
 user_max_tokens = None
+user_request_timeout = None
 user_temperature = None
 user_telemetry = True
 user_config = None
@@ -312,6 +313,7 @@ def initialize(
     debug,
     temperature,
     max_tokens,
+    request_timeout,
     max_budget,
     telemetry,
     drop_params,
@@ -319,7 +321,7 @@ def initialize(
     headers,
     save,
 ):
-    global user_model, user_api_base, user_debug, user_max_tokens, user_temperature, user_telemetry, user_headers
+    global user_model, user_api_base, user_debug, user_max_tokens, user_request_timeout, user_temperature, user_telemetry, user_headers
     user_model = model
     user_debug = debug
     load_config()
@@ -340,6 +342,9 @@ def initialize(
     if temperature: # model-specific param
         user_temperature = temperature
         dynamic_config[user_model]["temperature"] = temperature
+    if request_timeout:
+        user_request_timeout = request_timeout
+        dynamic_config[user_model]["request_timeout"] = request_timeout
     if alias: # model-specific param
         dynamic_config[user_model]["alias"] = alias
     if drop_params == True: # litellm-specific param
@@ -504,7 +509,7 @@ async def completion(request: Request):
     data = await request.json()
     return litellm_completion(data=data, type="completion", user_model=user_model, user_temperature=user_temperature,
                               user_max_tokens=user_max_tokens, user_api_base=user_api_base, user_headers=user_headers,
-                              user_debug=user_debug, model_router=model_router)
+                              user_debug=user_debug, model_router=model_router, user_request_timeout=user_request_timeout)
 
 
 @router.post("/v1/chat/completions")
@@ -514,7 +519,7 @@ async def chat_completion(request: Request):
     data = await request.json()
     print_verbose(f"data passed in: {data}")
     return litellm_completion(data, type="chat_completion", user_model=user_model, user_temperature=user_temperature,
                               user_max_tokens=user_max_tokens,
-                              user_api_base=user_api_base, user_headers=user_headers, user_debug=user_debug, model_router=model_router)
+                              user_api_base=user_api_base, user_headers=user_headers, user_debug=user_debug, model_router=model_router, user_request_timeout=user_request_timeout)
 
 def print_cost_logs():
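For reviewers, a minimal usage sketch of the change (not part of the diff): the new `--request_timeout` flag flows from `run_server` through `initialize` into `litellm_completion`, which injects `request_timeout` into every proxied completion call. The host/port below are assumed defaults, not something this patch sets.

```python
# Start the proxy with the new per-call timeout (in seconds), e.g.:
#   litellm --model gpt-3.5-turbo --request_timeout 120
# Any request routed through the proxy then gets request_timeout=120
# added to the data dict passed to the underlying completion call.

import requests  # assumes the proxy is reachable at its default host/port

resp = requests.post(
    "http://0.0.0.0:8000/v1/chat/completions",  # assumed default address
    json={
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "Hello"}],
    },
)
print(resp.json())
```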