(add request_timeout) as param to proxy_server

This commit is contained in:
ishaan-jaff 2023-10-20 11:55:41 -07:00
parent fca20c0847
commit 63f9f63ba4
3 changed files with 14 additions and 5 deletions

View file

@ -114,6 +114,7 @@ user_api_base = None
user_model = None
user_debug = False
user_max_tokens = None
user_request_timeout = None
user_temperature = None
user_telemetry = True
user_config = None
@ -312,6 +313,7 @@ def initialize(
debug,
temperature,
max_tokens,
request_timeout,
max_budget,
telemetry,
drop_params,
@ -319,7 +321,7 @@ def initialize(
headers,
save,
):
global user_model, user_api_base, user_debug, user_max_tokens, user_temperature, user_telemetry, user_headers
global user_model, user_api_base, user_debug, user_max_tokens, user_request_timeout, user_temperature, user_telemetry, user_headers
user_model = model
user_debug = debug
load_config()
@ -340,6 +342,9 @@ def initialize(
if temperature: # model-specific param
user_temperature = temperature
dynamic_config[user_model]["temperature"] = temperature
if request_timeout:
user_request_timeout = request_timeout
dynamic_config[user_model]["request_timeout"] = request_timeout
if alias: # model-specific param
dynamic_config[user_model]["alias"] = alias
if drop_params == True: # litellm-specific param
@ -504,7 +509,7 @@ async def completion(request: Request):
data = await request.json()
return litellm_completion(data=data, type="completion", user_model=user_model, user_temperature=user_temperature,
user_max_tokens=user_max_tokens, user_api_base=user_api_base, user_headers=user_headers,
user_debug=user_debug, model_router=model_router)
user_debug=user_debug, model_router=model_router, user_request_timeout=user_request_timeout)
@router.post("/v1/chat/completions")
@ -514,7 +519,7 @@ async def chat_completion(request: Request):
print_verbose(f"data passed in: {data}")
return litellm_completion(data, type="chat_completion", user_model=user_model,
user_temperature=user_temperature, user_max_tokens=user_max_tokens,
user_api_base=user_api_base, user_headers=user_headers, user_debug=user_debug, model_router=model_router)
user_api_base=user_api_base, user_headers=user_headers, user_debug=user_debug, model_router=model_router, user_request_timeout=user_request_timeout)
def print_cost_logs():