fix(router.py): fix client initialization (resolve os.environ/ values for max_retries and store stream_timeout)

This commit is contained in:
Krrish Dholakia 2024-01-22 22:15:30 -08:00 committed by ishaan-jaff
parent 278cc75603
commit eb46ea8f8b
3 changed files with 33 additions and 7 deletions

View file

@@ -1,5 +1,5 @@
from typing import Optional, Union, Any
import types, time, json
import types, time, json, traceback
import httpx
from .base import BaseLLM
from litellm.utils import (
@@ -349,7 +349,7 @@ class OpenAIChatCompletion(BaseLLM):
if hasattr(e, "status_code"):
raise OpenAIError(status_code=e.status_code, message=str(e))
else:
raise OpenAIError(status_code=500, message=str(e))
raise OpenAIError(status_code=500, message=traceback.format_exc())
async def acompletion(
self,

View file

@@ -1521,13 +1521,13 @@ class Router:
):
stream_timeout_env_name = stream_timeout.replace("os.environ/", "")
stream_timeout = litellm.get_secret(stream_timeout_env_name)
litellm_params["stream_timeout"] = stream_timeout
max_retries = litellm_params.pop("max_retries", 2)
if isinstance(max_retries, str) and max_retries.startswith("os.environ/"):
max_retries_env_name = max_retries.replace("os.environ/", "")
max_retries = litellm.get_secret(max_retries_env_name)
litellm_params["max_retries"] = max_retries
if isinstance(max_retries, str):
if max_retries.startswith("os.environ/"):
max_retries_env_name = max_retries.replace("os.environ/", "")
max_retries = litellm.get_secret(max_retries_env_name)
max_retries = int(max_retries)
if "azure" in model_name:
if api_base is None:

View file

@@ -942,3 +942,29 @@ def test_reading_openai_keys_os_environ():
# test_reading_openai_keys_os_environ()
def test_router_timeout():
    """Verify Router honors a per-call timeout when litellm_params come from env vars.

    Every litellm_param (api_key, timeout, stream_timeout, max_retries) is an
    "os.environ/..." reference, so this also exercises the Router's secret
    resolution during client init. The call is given a 1-second timeout; whether
    it times out (expected) or returns early, it must finish within ~1.1 s.
    """
    model_list = [
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {
                "model": "gpt-3.5-turbo",
                # All values resolved from the environment at client-init time.
                "api_key": "os.environ/OPENAI_API_KEY",
                "timeout": "os.environ/AZURE_TIMEOUT",
                "stream_timeout": "os.environ/AZURE_STREAM_TIMEOUT",
                "max_retries": "os.environ/AZURE_MAX_RETRIES",
            },
        }
    ]
    router = Router(model_list=model_list)
    messages = [{"role": "user", "content": "Hey, how's it going?"}]
    # Use a monotonic clock for elapsed-time measurement: time.time() is
    # wall-clock and can jump (NTP adjustments), which would flake the assert.
    start_time = time.monotonic()
    try:
        router.completion(
            model="gpt-3.5-turbo", messages=messages, max_tokens=500, timeout=1
        )
    except litellm.exceptions.Timeout:
        # Timing out is the expected outcome; only the elapsed time matters.
        pass
    end_time = time.monotonic()
    # Small grace margin over the 1 s timeout for scheduling overhead.
    assert end_time - start_time < 1.1