fix(router.py): allow passing httpx.timeout to timeout param in router

Closes https://github.com/BerriAI/litellm/issues/3162
This commit is contained in:
Krrish Dholakia 2024-04-26 14:56:58 -07:00
parent e1c643ef69
commit 08e36547d6
2 changed files with 36 additions and 2 deletions

View file

@@ -14,10 +14,41 @@ from litellm.router import Deployment, LiteLLM_Params, ModelInfo
from concurrent.futures import ThreadPoolExecutor
from collections import defaultdict
from dotenv import load_dotenv
import os, httpx

load_dotenv()
@pytest.mark.parametrize(
    "timeout", [10, 1.0, httpx.Timeout(timeout=300.0, connect=20.0)]
)
def test_router_timeout_init(timeout):
    """
    Router must accept int, float, or httpx.Timeout for a deployment's
    `timeout` litellm_param without raising on init or on a completion call.

    related issue - https://github.com/BerriAI/litellm/issues/3162
    """
    # Single Azure deployment; credentials are read from the environment.
    deployment = {
        "model_name": "test-model",
        "litellm_params": {
            "model": "azure/chatgpt-v-2",
            "api_key": os.getenv("AZURE_API_KEY"),
            "api_base": os.getenv("AZURE_API_BASE"),
            "api_version": os.getenv("AZURE_API_VERSION"),
            "timeout": timeout,
        },
    }
    router = Router(model_list=[deployment])

    # Smoke test: if the timeout value were rejected (e.g. by pydantic
    # validation of httpx.Timeout), this call would raise.
    router.completion(
        model="test-model", messages=[{"role": "user", "content": "Hey!"}]
    )
def test_exception_raising():
    # this tests if the router raises an exception when invalid params are set
    # in this test both deployments have bad keys - Keep this test. It validates if the router raises the most recent exception

View file

@@ -1,5 +1,5 @@
from typing import List, Optional, Union, Dict, Tuple, Literal
import httpx
from pydantic import BaseModel, validator
from .completion import CompletionRequest
from .embedding import EmbeddingRequest
@@ -104,7 +104,9 @@ class LiteLLM_Params(BaseModel):
    api_key: Optional[str] = None
    api_base: Optional[str] = None
    api_version: Optional[str] = None
    timeout: Optional[Union[float, str, httpx.Timeout]] = (
        None  # if str, pass in as os.environ/
    )
    stream_timeout: Optional[Union[float, str]] = (
        None  # timeout when making stream=True calls, if str, pass in as os.environ/
    )
@@ -154,6 +156,7 @@ class LiteLLM_Params(BaseModel):
    class Config:
        extra = "allow"
        arbitrary_types_allowed = True

    def __contains__(self, key):
        # Define custom behavior for the 'in' operator