fix(acompletion): support client side timeouts + raise exceptions correctly for async calls

Krrish Dholakia 2023-11-17 15:39:39 -08:00
parent c4e53aa77b
commit 0ab6b2451d
8 changed files with 142 additions and 81 deletions


@@ -33,7 +33,7 @@ class Router:
                  redis_port: Optional[int] = None,
                  redis_password: Optional[str] = None,
                  cache_responses: bool = False,
-                 num_retries: Optional[int] = None,
+                 num_retries: int = 0,
                  timeout: float = 600,
                  default_litellm_params = {}, # default params for Router.chat.completion.create
                  routing_strategy: Literal["simple-shuffle", "least-busy"] = "simple-shuffle") -> None:
@@ -42,12 +42,13 @@ class Router:
         self.set_model_list(model_list)
         self.healthy_deployments: List = self.model_list
-        if num_retries:
-            self.num_retries = num_retries
+        self.num_retries = num_retries
         self.chat = litellm.Chat(params=default_litellm_params)
-        litellm.request_timeout = timeout
+        self.default_litellm_params = {
+            "timeout": timeout
+        }
         self.routing_strategy = routing_strategy
         ### HEALTH CHECK THREAD ###
         if self.routing_strategy == "least-busy":
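
What this changes for callers: the timeout is now stored on the router as a per-request default instead of being written to the global litellm.request_timeout, and num_retries defaults to 0 rather than None. A minimal construction sketch (the model list contents are illustrative, not taken from this diff):

    from litellm import Router

    model_list = [{
        "model_name": "gpt-3.5-turbo",                    # alias callers pass as `model`
        "litellm_params": {"model": "gpt-3.5-turbo"},      # forwarded to litellm.completion
    }]

    # timeout becomes a router-level default injected into every call;
    # with num_retries=0 failures propagate instead of being retried
    router = Router(model_list=model_list, timeout=30, num_retries=0)
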
@@ -222,6 +223,9 @@ class Router:
         # pick the one that is available (lowest TPM/RPM)
         deployment = self.get_available_deployment(model=model, messages=messages)
         data = deployment["litellm_params"]
+        for k, v in self.default_litellm_params.items():
+            if k not in data: # prioritize model-specific params > default router params
+                data[k] = v
         return litellm.completion(**{**data, "messages": messages, "caching": self.cache_responses, **kwargs})
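
The same three-line merge is repeated in every entry point below; its effect, shown on plain dicts as a sketch (values illustrative): router-level defaults only fill keys the chosen deployment has not set itself.

    default_litellm_params = {"timeout": 30}                 # router-wide default
    data = {"model": "azure/gpt-35-turbo", "timeout": 10}    # deployment's litellm_params

    for k, v in default_litellm_params.items():
        if k not in data:   # model-specific params take priority over router defaults
            data[k] = v

    assert data["timeout"] == 10   # the deployment's own timeout is kept
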
@@ -234,16 +238,20 @@ class Router:
         try:
             deployment = self.get_available_deployment(model=model, messages=messages)
             data = deployment["litellm_params"]
+            for k, v in self.default_litellm_params.items():
+                if k not in data: # prioritize model-specific params > default router params
+                    data[k] = v
             response = await litellm.acompletion(**{**data, "messages": messages, "caching": self.cache_responses, **kwargs})
-            if inspect.iscoroutinefunction(response): # async errors are often returned as coroutines
-                response = await response
             return response
         except Exception as e:
-            kwargs["model"] = model
-            kwargs["messages"] = messages
-            kwargs["original_exception"] = e
-            kwargs["original_function"] = self.acompletion
-            return await self.async_function_with_retries(**kwargs)
+            if self.num_retries > 0:
+                kwargs["model"] = model
+                kwargs["messages"] = messages
+                kwargs["original_exception"] = e
+                kwargs["original_function"] = self.acompletion
+                return await self.async_function_with_retries(**kwargs)
+            else:
+                raise e

     def text_completion(self,
                         model: str,
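
With the default num_retries=0, acompletion now re-raises the underlying error instead of routing it through async_function_with_retries, and error coroutines are no longer awaited silently. A hedged usage sketch (the deployment config and the forced failure are illustrative):

    import asyncio
    from litellm import Router

    router = Router(
        model_list=[{"model_name": "gpt-3.5-turbo",
                     "litellm_params": {"model": "gpt-3.5-turbo"}}],
        timeout=0.001,   # deliberately tiny client-side timeout to provoke a failure
        num_retries=0,
    )

    async def main():
        try:
            await router.acompletion(
                model="gpt-3.5-turbo",
                messages=[{"role": "user", "content": "hello"}],
            )
        except Exception as e:
            # with num_retries=0 the exception reaches the caller directly
            print(f"request failed: {type(e).__name__}")

    asyncio.run(main())
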
@@ -258,6 +266,9 @@ class Router:
         deployment = self.get_available_deployment(model=model, messages=messages)
         data = deployment["litellm_params"]
+        for k, v in self.default_litellm_params.items():
+            if k not in data: # prioritize model-specific params > default router params
+                data[k] = v
         # call via litellm.completion()
         return litellm.text_completion(**{**data, "prompt": prompt, "caching": self.cache_responses, **kwargs}) # type: ignore
@@ -270,6 +281,9 @@ class Router:
         deployment = self.get_available_deployment(model=model, input=input)
         data = deployment["litellm_params"]
+        for k, v in self.default_litellm_params.items():
+            if k not in data: # prioritize model-specific params > default router params
+                data[k] = v
         # call via litellm.embedding()
         return litellm.embedding(**{**data, "input": input, "caching": self.cache_responses, **kwargs})
@@ -282,4 +296,7 @@ class Router:
         deployment = self.get_available_deployment(model=model, input=input)
         data = deployment["litellm_params"]
+        for k, v in self.default_litellm_params.items():
+            if k not in data: # prioritize model-specific params > default router params
+                data[k] = v
         return await litellm.aembedding(**{**data, "input": input, "caching": self.cache_responses, **kwargs})
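
Because the merge skips keys already present in a deployment's litellm_params, a deployment can carry its own timeout that overrides the router-wide default; the same applies to embedding and aembedding above. A small sketch (model names and values illustrative):

    router = Router(
        model_list=[{
            "model_name": "text-embedding-ada-002",
            "litellm_params": {"model": "text-embedding-ada-002", "timeout": 5},  # wins over the router default
        }],
        timeout=600,   # router-wide default, used only where a deployment sets no timeout
    )

    response = router.embedding(model="text-embedding-ada-002", input=["hello world"])
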