fix(router.py): set initial value of `default_litellm_params` to `None`

This commit is contained in:
Krrish Dholakia 2024-04-27 17:22:50 -07:00
parent d9e0d7ce52
commit ec19c1654b
2 changed files with 12 additions and 6 deletions

View file

@ -72,7 +72,9 @@ class Router:
## RELIABILITY ## ## RELIABILITY ##
num_retries: Optional[int] = None, num_retries: Optional[int] = None,
timeout: Optional[float] = None, timeout: Optional[float] = None,
default_litellm_params={}, # default params for Router.chat.completion.create default_litellm_params: Optional[
dict
] = None, # default params for Router.chat.completion.create
default_max_parallel_requests: Optional[int] = None, default_max_parallel_requests: Optional[int] = None,
set_verbose: bool = False, set_verbose: bool = False,
debug_level: Literal["DEBUG", "INFO"] = "INFO", debug_level: Literal["DEBUG", "INFO"] = "INFO",
@ -158,6 +160,7 @@ class Router:
router = Router(model_list=model_list, fallbacks=[{"azure-gpt-3.5-turbo": "openai-gpt-3.5-turbo"}]) router = Router(model_list=model_list, fallbacks=[{"azure-gpt-3.5-turbo": "openai-gpt-3.5-turbo"}])
``` ```
""" """
if semaphore: if semaphore:
self.semaphore = semaphore self.semaphore = semaphore
self.set_verbose = set_verbose self.set_verbose = set_verbose
@ -260,6 +263,7 @@ class Router:
) # dict to store aliases for router, ex. {"gpt-4": "gpt-3.5-turbo"}, all requests with gpt-4 -> get routed to gpt-3.5-turbo group ) # dict to store aliases for router, ex. {"gpt-4": "gpt-3.5-turbo"}, all requests with gpt-4 -> get routed to gpt-3.5-turbo group
# make Router.chat.completions.create compatible for openai.chat.completions.create # make Router.chat.completions.create compatible for openai.chat.completions.create
default_litellm_params = default_litellm_params or {}
self.chat = litellm.Chat(params=default_litellm_params, router_obj=self) self.chat = litellm.Chat(params=default_litellm_params, router_obj=self)
# default litellm args # default litellm args
@ -475,6 +479,7 @@ class Router:
) )
kwargs["model_info"] = deployment.get("model_info", {}) kwargs["model_info"] = deployment.get("model_info", {})
data = deployment["litellm_params"].copy() data = deployment["litellm_params"].copy()
model_name = data["model"] model_name = data["model"]
for k, v in self.default_litellm_params.items(): for k, v in self.default_litellm_params.items():
if ( if (

View file

@ -89,15 +89,15 @@ def test_router_timeouts():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_router_timeouts_bedrock(): async def test_router_timeouts_bedrock():
import openai import openai, uuid
# Model list for OpenAI and Anthropic models # Model list for OpenAI and Anthropic models
model_list = [ _model_list = [
{ {
"model_name": "bedrock", "model_name": "bedrock",
"litellm_params": { "litellm_params": {
"model": "bedrock/anthropic.claude-instant-v1", "model": "bedrock/anthropic.claude-instant-v1",
"timeout": 0.001, "timeout": 0.00001,
}, },
"tpm": 80000, "tpm": 80000,
}, },
@ -105,17 +105,18 @@ async def test_router_timeouts_bedrock():
# Configure router # Configure router
router = Router( router = Router(
model_list=model_list, model_list=_model_list,
routing_strategy="usage-based-routing", routing_strategy="usage-based-routing",
debug_level="DEBUG", debug_level="DEBUG",
set_verbose=True, set_verbose=True,
num_retries=0,
) )
litellm.set_verbose = True litellm.set_verbose = True
try: try:
response = await router.acompletion( response = await router.acompletion(
model="bedrock", model="bedrock",
messages=[{"role": "user", "content": "hello, who are u"}], messages=[{"role": "user", "content": f"hello, who are u {uuid.uuid4()}"}],
) )
print(response) print(response)
pytest.fail("Did not raise error `openai.APITimeoutError`") pytest.fail("Did not raise error `openai.APITimeoutError`")