fix(acompletion): support client-side timeouts + raise exceptions correctly for async calls

Krrish Dholakia 2023-11-17 15:39:39 -08:00
parent 43ee581d46
commit 02ed97d0b2
8 changed files with 142 additions and 81 deletions
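For orientation, a minimal usage sketch (not part of the diff) of the behavior this commit adds, modeled on the tests changed below: timeout is passed per call, in seconds, and a call that exceeds it should surface as litellm.Timeout.

# Sketch based on this commit's tests: client-side timeout on an async call.
import asyncio
import litellm
from litellm import acompletion

async def main():
    messages = [{"content": "Hello, how are you?", "role": "user"}]
    try:
        response = await acompletion(model="gpt-3.5-turbo", messages=messages, timeout=5)
        print(f"response: {response}")
    except litellm.Timeout:
        print("request exceeded the client-side timeout")

asyncio.run(main())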

View file

@@ -1,4 +1,4 @@
-from typing import Optional, Union
+from typing import Optional, Union, Any
 import types, requests
 from .base import BaseLLM
 from litellm.utils import ModelResponse, Choices, Message, CustomStreamWrapper, convert_to_model_response_object
@@ -98,6 +98,7 @@ class AzureChatCompletion(BaseLLM):
            api_type: str,
            azure_ad_token: str,
            print_verbose: Callable,
+           timeout,
            logging_obj,
            optional_params,
            litellm_params,
@@ -129,13 +130,13 @@ class AzureChatCompletion(BaseLLM):
             )
             if acompletion is True:
                 if optional_params.get("stream", False):
-                    return self.async_streaming(logging_obj=logging_obj, api_base=api_base, data=data, model=model, api_key=api_key, api_version=api_version, azure_ad_token=azure_ad_token)
+                    return self.async_streaming(logging_obj=logging_obj, api_base=api_base, data=data, model=model, api_key=api_key, api_version=api_version, azure_ad_token=azure_ad_token, timeout=timeout)
                 else:
-                    return self.acompletion(api_base=api_base, data=data, model_response=model_response, api_key=api_key, api_version=api_version, model=model, azure_ad_token=azure_ad_token)
+                    return self.acompletion(api_base=api_base, data=data, model_response=model_response, api_key=api_key, api_version=api_version, model=model, azure_ad_token=azure_ad_token, timeout=timeout)
             elif "stream" in optional_params and optional_params["stream"] == True:
-                return self.streaming(logging_obj=logging_obj, api_base=api_base, data=data, model=model, api_key=api_key, api_version=api_version, azure_ad_token=azure_ad_token)
+                return self.streaming(logging_obj=logging_obj, api_base=api_base, data=data, model=model, api_key=api_key, api_version=api_version, azure_ad_token=azure_ad_token, timeout=timeout)
             else:
-                azure_client = AzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=api_base, azure_deployment=model, azure_ad_token=azure_ad_token, http_client=litellm.client_session)
+                azure_client = AzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=api_base, azure_deployment=model, azure_ad_token=azure_ad_token, http_client=litellm.client_session, timeout=timeout)
                 response = azure_client.chat.completions.create(**data) # type: ignore
                 return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response)
         except AzureOpenAIError as e:
@@ -150,16 +151,18 @@ class AzureChatCompletion(BaseLLM):
            model: str,
            api_base: str,
            data: dict,
+           timeout: Any,
            model_response: ModelResponse,
            azure_ad_token: Optional[str]=None, ):
+        response = None
         try:
-            azure_client = AsyncAzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=api_base, azure_deployment=model, azure_ad_token=azure_ad_token, http_client=litellm.aclient_session)
+            azure_client = AsyncAzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=api_base, azure_deployment=model, azure_ad_token=azure_ad_token, http_client=litellm.aclient_session, timeout=timeout)
             response = await azure_client.chat.completions.create(**data)
             return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response)
         except Exception as e:
             if isinstance(e,httpx.TimeoutException):
                 raise AzureOpenAIError(status_code=500, message="Request Timeout Error")
-            elif response and hasattr(response, "text"):
+            elif response is not None and hasattr(response, "text"):
                 raise AzureOpenAIError(status_code=500, message=f"{str(e)}\n\nOriginal Response: {response.text}")
             else:
                 raise AzureOpenAIError(status_code=500, message=f"{str(e)}")
@@ -171,9 +174,10 @@ class AzureChatCompletion(BaseLLM):
            api_version: str,
            data: dict,
            model: str,
+           timeout: Any,
            azure_ad_token: Optional[str]=None,
     ):
-        azure_client = AzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=api_base, azure_deployment=model, azure_ad_token=azure_ad_token, http_client=litellm.client_session)
+        azure_client = AzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=api_base, azure_deployment=model, azure_ad_token=azure_ad_token, http_client=litellm.client_session, timeout=timeout)
         response = azure_client.chat.completions.create(**data)
         streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="azure",logging_obj=logging_obj)
         for transformed_chunk in streamwrapper:
@@ -186,8 +190,9 @@ class AzureChatCompletion(BaseLLM):
            api_version: str,
            data: dict,
            model: str,
+           timeout: Any,
            azure_ad_token: Optional[str]=None):
-        azure_client = AsyncAzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=api_base, azure_deployment=model, azure_ad_token=azure_ad_token, http_client=litellm.aclient_session)
+        azure_client = AsyncAzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=api_base, azure_deployment=model, azure_ad_token=azure_ad_token, http_client=litellm.aclient_session, timeout=timeout)
        response = await azure_client.chat.completions.create(**data)
        streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="azure",logging_obj=logging_obj)
        async for transformed_chunk in streamwrapper:

View file

@@ -1,4 +1,4 @@
-from typing import Optional, Union
+from typing import Optional, Union, Any
 import types, time, json
 import httpx
 from .base import BaseLLM
@@ -161,6 +161,7 @@ class OpenAIChatCompletion(BaseLLM):
     def completion(self,
            model_response: ModelResponse,
+           timeout: Any,
            model: Optional[str]=None,
            messages: Optional[list]=None,
            print_verbose: Optional[Callable]=None,
@@ -180,6 +181,9 @@ class OpenAIChatCompletion(BaseLLM):
         if model is None or messages is None:
             raise OpenAIError(status_code=422, message=f"Missing model or messages")
+        if not isinstance(timeout, float):
+            raise OpenAIError(status_code=422, message=f"Timeout needs to be a float")
         for _ in range(2): # if call fails due to alternating messages, retry with reformatted message
             data = {
                 "model": model,
@@ -197,13 +201,13 @@ class OpenAIChatCompletion(BaseLLM):
             try:
                 if acompletion is True:
                     if optional_params.get("stream", False):
-                        return self.async_streaming(logging_obj=logging_obj, data=data, model=model, api_base=api_base, api_key=api_key)
+                        return self.async_streaming(logging_obj=logging_obj, data=data, model=model, api_base=api_base, api_key=api_key, timeout=timeout)
                     else:
-                        return self.acompletion(data=data, model_response=model_response, api_base=api_base, api_key=api_key)
+                        return self.acompletion(data=data, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout)
                 elif optional_params.get("stream", False):
-                    return self.streaming(logging_obj=logging_obj, data=data, model=model, api_base=api_base, api_key=api_key)
+                    return self.streaming(logging_obj=logging_obj, data=data, model=model, api_base=api_base, api_key=api_key, timeout=timeout)
                 else:
-                    openai_client = OpenAI(api_key=api_key, base_url=api_base, http_client=litellm.client_session)
+                    openai_client = OpenAI(api_key=api_key, base_url=api_base, http_client=litellm.client_session, timeout=timeout)
                     response = openai_client.chat.completions.create(**data) # type: ignore
                     return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response)
             except Exception as e:
@@ -234,27 +238,32 @@ class OpenAIChatCompletion(BaseLLM):
     async def acompletion(self,
            data: dict,
            model_response: ModelResponse,
+           timeout: float,
            api_key: Optional[str]=None,
            api_base: Optional[str]=None):
         response = None
         try:
-            openai_aclient = AsyncOpenAI(api_key=api_key, base_url=api_base, http_client=litellm.aclient_session)
+            openai_aclient = AsyncOpenAI(api_key=api_key, base_url=api_base, http_client=litellm.aclient_session, timeout=timeout)
             response = await openai_aclient.chat.completions.create(**data)
             return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response)
         except Exception as e:
             if response and hasattr(response, "text"):
                 raise OpenAIError(status_code=500, message=f"{str(e)}\n\nOriginal Response: {response.text}")
             else:
-                raise OpenAIError(status_code=500, message=f"{str(e)}")
+                if type(e).__name__ == "ReadTimeout":
+                    raise OpenAIError(status_code=408, message=f"{type(e).__name__}")
+                else:
+                    raise OpenAIError(status_code=500, message=f"{str(e)}")

     def streaming(self,
            logging_obj,
+           timeout: float,
            data: dict,
            model: str,
            api_key: Optional[str]=None,
            api_base: Optional[str]=None
     ):
-        openai_client = OpenAI(api_key=api_key, base_url=api_base, http_client=litellm.client_session)
+        openai_client = OpenAI(api_key=api_key, base_url=api_base, http_client=litellm.client_session, timeout=timeout)
         response = openai_client.chat.completions.create(**data)
         streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="openai",logging_obj=logging_obj)
         for transformed_chunk in streamwrapper:
@@ -262,15 +271,26 @@ class OpenAIChatCompletion(BaseLLM):
     async def async_streaming(self,
            logging_obj,
+           timeout: float,
            data: dict,
            model: str,
            api_key: Optional[str]=None,
            api_base: Optional[str]=None):
-        openai_aclient = AsyncOpenAI(api_key=api_key, base_url=api_base, http_client=litellm.aclient_session)
-        response = await openai_aclient.chat.completions.create(**data)
-        streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="openai",logging_obj=logging_obj)
-        async for transformed_chunk in streamwrapper:
-            yield transformed_chunk
+        response = None
+        try:
+            openai_aclient = AsyncOpenAI(api_key=api_key, base_url=api_base, http_client=litellm.aclient_session, timeout=timeout)
+            response = await openai_aclient.chat.completions.create(**data)
+            streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="openai",logging_obj=logging_obj)
+            async for transformed_chunk in streamwrapper:
+                yield transformed_chunk
+        except Exception as e: # need to exception handle here. async exceptions don't get caught in sync functions.
+            if response is not None and hasattr(response, "text"):
+                raise OpenAIError(status_code=500, message=f"{str(e)}\n\nOriginal Response: {response.text}")
+            else:
+                if type(e).__name__ == "ReadTimeout":
+                    raise OpenAIError(status_code=408, message=f"{type(e).__name__}")
+                else:
+                    raise OpenAIError(status_code=500, message=f"{str(e)}")

     def embedding(self,
            model: str,

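The async streaming path above now wraps stream creation and iteration in try/except, so a ReadTimeout raised while the generator is consumed becomes an OpenAIError with status 408 instead of escaping the event loop uncaught. From the caller's side that looks roughly like the streaming tests later in this diff (a sketch, not part of the commit):

# Sketch: stream=True returns an async generator; timeouts while iterating it
# should also be mapped to litellm.Timeout by the exception handling added here.
import asyncio
import litellm
from litellm import acompletion

async def stream_chat():
    messages = [{"content": "Hello, how are you?", "role": "user"}]
    try:
        response = await acompletion(model="gpt-3.5-turbo", messages=messages, stream=True, timeout=5)
        async for chunk in response:
            print(chunk)
    except litellm.Timeout:
        print("streaming request timed out")

asyncio.run(stream_chat())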
View file

@@ -10,7 +10,6 @@
 import os, openai, sys, json, inspect, uuid, datetime, threading
 from typing import Any
 from functools import partial
 import dotenv, traceback, random, asyncio, time, contextvars
 from copy import deepcopy
 import httpx
@@ -18,7 +17,6 @@ import litellm
 from litellm import (  # type: ignore
     client,
     exception_type,
-    timeout,
     get_optional_params,
     get_litellm_params,
     Logging,
@@ -176,28 +174,28 @@ async def acompletion(*args, **kwargs):
             init_response = completion(*args, **kwargs)
             if isinstance(init_response, dict) or isinstance(init_response, ModelResponse): ## CACHING SCENARIO
                 response = init_response
-            else:
+            elif asyncio.iscoroutine(init_response):
                 response = await init_response
         else:
             # Call the synchronous function using run_in_executor
             response = await loop.run_in_executor(None, func_with_context)
         if kwargs.get("stream", False): # return an async generator
-            # do not change this
-            # for stream = True, always return an async generator
-            # See OpenAI acreate https://github.com/openai/openai-python/blob/5d50e9e3b39540af782ca24e65c290343d86e1a9/openai/api_resources/abstract/engine_api_resource.py#L193
-            # return response
-            return(
-                line
-                async for line in response
-            )
+            return _async_streaming(response=response, model=model, custom_llm_provider=custom_llm_provider, args=args)
         else:
             return response
     except Exception as e:
+        ## Map to OpenAI Exception
         raise exception_type(
                 model=model, custom_llm_provider=custom_llm_provider, original_exception=e, completion_kwargs=args,
             )

+async def _async_streaming(response, model, custom_llm_provider, args):
+    try:
+        async for line in response:
+            yield line
+    except Exception as e:
+        raise exception_type(
+                model=model, custom_llm_provider=custom_llm_provider, original_exception=e, completion_kwargs=args,
+            )

 def mock_completion(model: str, messages: List, stream: Optional[bool] = False, mock_response: str = "This is a mock request", **kwargs):
     """
@@ -245,6 +243,7 @@ def completion(
     messages: List = [],
     functions: List = [],
     function_call: str = "", # optional params
+    timeout: Union[float, int] = 600.0,
     temperature: Optional[float] = None,
     top_p: Optional[float] = None,
     n: Optional[int] = None,
@@ -261,7 +260,6 @@ def completion(
     tools: Optional[List] = None,
     tool_choice: Optional[str] = None,
     deployment_id = None,
     # set api_base, api_version, api_key
     base_url: Optional[str] = None,
     api_version: Optional[str] = None,
@@ -298,9 +296,8 @@ def completion(
         LITELLM Specific Params
         mock_response (str, optional): If provided, return a mock completion response for testing or debugging purposes (default is None).
-        force_timeout (int, optional): The maximum execution time in seconds for the completion request (default is 600).
         custom_llm_provider (str, optional): Used for Non-OpenAI LLMs, Example usage for bedrock, set model="amazon.titan-tg1-large" and custom_llm_provider="bedrock"
-        num_retries (int, optional): The number of retries to attempt (default is 0).
+        max_retries (int, optional): The number of retries to attempt (default is 0).
     Returns:
         ModelResponse: A response object containing the generated completion and associated metadata.
@@ -314,7 +311,7 @@ def completion(
     api_base = kwargs.get('api_base', None)
     return_async = kwargs.get('return_async', False)
     mock_response = kwargs.get('mock_response', None)
-    force_timeout= kwargs.get('force_timeout', 600)
+    force_timeout= kwargs.get('force_timeout', 600) ## deprecated
     logger_fn = kwargs.get('logger_fn', None)
     verbose = kwargs.get('verbose', False)
     custom_llm_provider = kwargs.get('custom_llm_provider', None)
@@ -338,8 +335,10 @@ def completion(
     litellm_params = ["metadata", "acompletion", "caching", "return_async", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "max_retries"]
     default_params = openai_params + litellm_params
     non_default_params = {k: v for k,v in kwargs.items() if k not in default_params} # model-specific params - pass them straight to the model/provider
     if mock_response:
         return mock_completion(model, messages, stream=stream, mock_response=mock_response)
+    timeout = float(timeout)
     try:
         if base_url:
             api_base = base_url
@@ -486,7 +485,8 @@ def completion(
                 litellm_params=litellm_params,
                 logger_fn=logger_fn,
                 logging_obj=logging,
-                acompletion=acompletion
+                acompletion=acompletion,
+                timeout=timeout
             )
             ## LOGGING
@@ -552,7 +552,8 @@ def completion(
                 logging_obj=logging,
                 optional_params=optional_params,
                 litellm_params=litellm_params,
-                logger_fn=logger_fn
+                logger_fn=logger_fn,
+                timeout=timeout
             )
         except Exception as e:
             ## LOGGING - log the original exception returned
@@ -1399,6 +1400,24 @@ def completion_with_retries(*args, **kwargs):
         retryer = tenacity.Retrying(wait=tenacity.wait_exponential(multiplier=1, max=10), stop=tenacity.stop_after_attempt(num_retries), reraise=True)
     return retryer(original_function, *args, **kwargs)

+async def acompletion_with_retries(*args, **kwargs):
+    """
+    Executes a litellm.completion() with 3 retries
+    """
+    try:
+        import tenacity
+    except Exception as e:
+        raise Exception(f"tenacity import failed please run `pip install tenacity`. Error{e}")
+
+    num_retries = kwargs.pop("num_retries", 3)
+    retry_strategy = kwargs.pop("retry_strategy", "constant_retry")
+    original_function = kwargs.pop("original_function", completion)
+    if retry_strategy == "constant_retry":
+        retryer = tenacity.Retrying(stop=tenacity.stop_after_attempt(num_retries), reraise=True)
+    elif retry_strategy == "exponential_backoff_retry":
+        retryer = tenacity.Retrying(wait=tenacity.wait_exponential(multiplier=1, max=10), stop=tenacity.stop_after_attempt(num_retries), reraise=True)
+    return await retryer(original_function, *args, **kwargs)

 def batch_completion(
@@ -1639,9 +1658,6 @@ async def aembedding(*args, **kwargs):
     return response

 @client
-@timeout(  # type: ignore
-    60
-)  ## set timeouts, in case calls hang (e.g. Azure) - default is 60s, override with `force_timeout`
 def embedding(
     model,
     input=[],

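The new acompletion_with_retries above mirrors completion_with_retries and is what the @client wrapper in utils.py now awaits on retryable errors (see that diff further down). A hedged sketch of a direct call, assuming the helper is importable from the litellm package like its sync counterpart; original_function, num_retries, and retry_strategy are the kwargs it pops:

# Hedged sketch, not part of the commit: invoking the async retry helper directly.
# Assumes litellm re-exports acompletion_with_retries like completion_with_retries.
import asyncio
import litellm

async def main():
    messages = [{"content": "Hello, how are you?", "role": "user"}]
    response = await litellm.acompletion_with_retries(
        model="gpt-3.5-turbo",
        messages=messages,
        original_function=litellm.acompletion,  # retry the async entrypoint
        num_retries=3,
        retry_strategy="exponential_backoff_retry",
    )
    print(f"response: {response}")

asyncio.run(main())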
View file

@@ -26,4 +26,3 @@ model_list:
 litellm_settings:
   drop_params: True
   success_callback: ["langfuse"] # https://docs.litellm.ai/docs/observability/langfuse_integration

View file

@@ -33,7 +33,7 @@ class Router:
                 redis_port: Optional[int] = None,
                 redis_password: Optional[str] = None,
                 cache_responses: bool = False,
-                num_retries: Optional[int] = None,
+                num_retries: int = 0,
                 timeout: float = 600,
                 default_litellm_params = {}, # default params for Router.chat.completion.create
                 routing_strategy: Literal["simple-shuffle", "least-busy"] = "simple-shuffle") -> None:
@@ -42,12 +42,13 @@ class Router:
         self.set_model_list(model_list)
         self.healthy_deployments: List = self.model_list
-        if num_retries:
-            self.num_retries = num_retries
+        self.num_retries = num_retries
         self.chat = litellm.Chat(params=default_litellm_params)
-        litellm.request_timeout = timeout
+        self.default_litellm_params = {
+            "timeout": timeout
+        }
         self.routing_strategy = routing_strategy
         ### HEALTH CHECK THREAD ###
         if self.routing_strategy == "least-busy":
@@ -222,6 +223,9 @@ class Router:
         # pick the one that is available (lowest TPM/RPM)
         deployment = self.get_available_deployment(model=model, messages=messages)
         data = deployment["litellm_params"]
+        for k, v in self.default_litellm_params.items():
+            if k not in data: # prioritize model-specific params > default router params
+                data[k] = v
         return litellm.completion(**{**data, "messages": messages, "caching": self.cache_responses, **kwargs})
@@ -234,16 +238,20 @@ class Router:
         try:
             deployment = self.get_available_deployment(model=model, messages=messages)
             data = deployment["litellm_params"]
+            for k, v in self.default_litellm_params.items():
+                if k not in data: # prioritize model-specific params > default router params
+                    data[k] = v
             response = await litellm.acompletion(**{**data, "messages": messages, "caching": self.cache_responses, **kwargs})
+            if inspect.iscoroutinefunction(response): # async errors are often returned as coroutines
+                response = await response
             return response
         except Exception as e:
+            if self.num_retries > 0:
                 kwargs["model"] = model
                 kwargs["messages"] = messages
                 kwargs["original_exception"] = e
                 kwargs["original_function"] = self.acompletion
                 return await self.async_function_with_retries(**kwargs)
+            else:
+                raise e

     def text_completion(self,
                 model: str,
@@ -258,6 +266,9 @@ class Router:
         deployment = self.get_available_deployment(model=model, messages=messages)
         data = deployment["litellm_params"]
+        for k, v in self.default_litellm_params.items():
+            if k not in data: # prioritize model-specific params > default router params
+                data[k] = v
         # call via litellm.completion()
         return litellm.text_completion(**{**data, "prompt": prompt, "caching": self.cache_responses, **kwargs}) # type: ignore
@@ -270,6 +281,9 @@ class Router:
         deployment = self.get_available_deployment(model=model, input=input)
         data = deployment["litellm_params"]
+        for k, v in self.default_litellm_params.items():
+            if k not in data: # prioritize model-specific params > default router params
+                data[k] = v
         # call via litellm.embedding()
         return litellm.embedding(**{**data, "input": input, "caching": self.cache_responses, **kwargs})
@@ -282,4 +296,7 @@ class Router:
         deployment = self.get_available_deployment(model=model, input=input)
         data = deployment["litellm_params"]
+        for k, v in self.default_litellm_params.items():
+            if k not in data: # prioritize model-specific params > default router params
+                data[k] = v
         return await litellm.aembedding(**{**data, "input": input, "caching": self.cache_responses, **kwargs})

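The Router changes above inject the router-level timeout into each deployment's litellm_params (unless the deployment sets its own) and retry only when num_retries > 0. A rough sketch of the updated constructor in use, following the router test further down; the deployment entry here is a placeholder:

# Sketch: router-level timeout becomes a default litellm param for every deployment;
# num_retries now defaults to 0 (fail fast) unless explicitly set.
import asyncio
import os
from litellm import Router

model_list = [
    {
        "model_name": "gpt-3.5-turbo",  # alias callers use
        "litellm_params": {
            "model": "azure/chatgpt-v-2",  # placeholder deployment
            "api_key": os.environ.get("AZURE_API_KEY"),
            "api_base": os.environ.get("AZURE_API_BASE"),
            "api_version": os.environ.get("AZURE_API_VERSION"),
        },
    },
]

async def main():
    router = Router(model_list=model_list, num_retries=2, timeout=10)
    messages = [{"role": "user", "content": "What is the weather like in SF?"}]
    response = await router.acompletion(model="gpt-3.5-turbo", messages=messages)
    print(f"response: {response}")

asyncio.run(main())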
View file

@@ -5,8 +5,6 @@ import sys, os
 import pytest
 import traceback
 import asyncio, logging
-logger = logging.getLogger(__name__)
-logger.setLevel(logging.DEBUG)
 sys.path.insert(
     0, os.path.abspath("../..")
@@ -20,7 +18,8 @@ def test_sync_response():
     user_message = "Hello, how are you?"
     messages = [{"content": user_message, "role": "user"}]
     try:
-        response = completion(model="gpt-3.5-turbo", messages=messages, api_key=os.environ["OPENAI_API_KEY"])
+        response = completion(model="gpt-3.5-turbo", messages=messages, timeout=5)
+        print(f"response: {response}")
     except Exception as e:
         pytest.fail(f"An exception occurred: {e}")
 # test_sync_response()
@@ -30,7 +29,7 @@ def test_sync_response_anyscale():
     user_message = "Hello, how are you?"
     messages = [{"content": user_message, "role": "user"}]
     try:
-        response = completion(model="anyscale/mistralai/Mistral-7B-Instruct-v0.1", messages=messages)
+        response = completion(model="anyscale/mistralai/Mistral-7B-Instruct-v0.1", messages=messages, timeout=5)
     except Exception as e:
         pytest.fail(f"An exception occurred: {e}")
@@ -43,10 +42,13 @@ def test_async_response_openai():
         user_message = "Hello, how are you?"
         messages = [{"content": user_message, "role": "user"}]
         try:
-            response = await acompletion(model="gpt-3.5-turbo", messages=messages)
+            response = await acompletion(model="gpt-3.5-turbo", messages=messages, timeout=5)
             print(f"response: {response}")
+        except litellm.Timeout as e:
+            pass
         except Exception as e:
             pytest.fail(f"An exception occurred: {e}")
+            print(e)
     asyncio.run(test_get_response())
@@ -56,17 +58,18 @@ def test_async_response_azure():
     import asyncio
     litellm.set_verbose = True
     async def test_get_response():
-        user_message = "Hello, how are you?"
+        user_message = "What do you know?"
         messages = [{"content": user_message, "role": "user"}]
         try:
-            response = await acompletion(model="azure/chatgpt-v-2", messages=messages)
+            response = await acompletion(model="azure/chatgpt-v-2", messages=messages, timeout=5)
             print(f"response: {response}")
+        except litellm.Timeout as e:
+            pass
         except Exception as e:
             pytest.fail(f"An exception occurred: {e}")

     asyncio.run(test_get_response())

 def test_async_anyscale_response():
     import asyncio
     litellm.set_verbose = True
@@ -74,9 +77,11 @@ def test_async_anyscale_response():
         user_message = "Hello, how are you?"
         messages = [{"content": user_message, "role": "user"}]
         try:
-            response = await acompletion(model="anyscale/mistralai/Mistral-7B-Instruct-v0.1", messages=messages)
+            response = await acompletion(model="anyscale/mistralai/Mistral-7B-Instruct-v0.1", messages=messages, timeout=5)
             # response = await response
             print(f"response: {response}")
+        except litellm.Timeout as e:
+            pass
         except Exception as e:
             pytest.fail(f"An exception occurred: {e}")
@@ -91,7 +96,7 @@ def test_get_response_streaming():
         messages = [{"content": user_message, "role": "user"}]
         try:
             litellm.set_verbose = True
-            response = await acompletion(model="gpt-3.5-turbo", messages=messages, stream=True)
+            response = await acompletion(model="gpt-3.5-turbo", messages=messages, stream=True, timeout=5)
             print(type(response))

             import inspect
@@ -108,12 +113,12 @@ def test_get_response_streaming():
             assert isinstance(output, str), "output needs to be of type str"
             assert len(output) > 0, "Length of output needs to be greater than 0."
             print(f'output: {output}')
+        except litellm.Timeout as e:
+            pass
         except Exception as e:
             pytest.fail(f"An exception occurred: {e}")
-        return response

     asyncio.run(test_async_call())
 # test_get_response_streaming()

 def test_get_response_non_openai_streaming():
@@ -123,7 +128,7 @@ def test_get_response_non_openai_streaming():
         user_message = "Hello, how are you?"
         messages = [{"content": user_message, "role": "user"}]
         try:
-            response = await acompletion(model="anyscale/mistralai/Mistral-7B-Instruct-v0.1", messages=messages, stream=True)
+            response = await acompletion(model="anyscale/mistralai/Mistral-7B-Instruct-v0.1", messages=messages, stream=True, timeout=5)
             print(type(response))

             import inspect
@@ -140,10 +145,11 @@ def test_get_response_non_openai_streaming():
             assert output is not None, "output cannot be None."
             assert isinstance(output, str), "output needs to be of type str"
             assert len(output) > 0, "Length of output needs to be greater than 0."
+        except litellm.Timeout as e:
+            pass
         except Exception as e:
             pytest.fail(f"An exception occurred: {e}")
         return response
     asyncio.run(test_async_call())
-# test_get_response_non_openai_streaming()
+test_get_response_non_openai_streaming()

View file

@@ -242,11 +242,11 @@ def test_acompletion_on_router():
     ]
     messages = [
-        {"role": "user", "content": "What is the weather like in Boston?"}
+        {"role": "user", "content": "What is the weather like in SF?"}
     ]
     async def get_response():
-        router = Router(model_list=model_list, redis_host=os.environ["REDIS_HOST"], redis_password=os.environ["REDIS_PASSWORD"], redis_port=os.environ["REDIS_PORT"], cache_responses=True, timeout=10)
+        router = Router(model_list=model_list, redis_host=os.environ["REDIS_HOST"], redis_password=os.environ["REDIS_PASSWORD"], redis_port=os.environ["REDIS_PORT"], cache_responses=True, timeout=0.1)
         response1 = await router.acompletion(model="gpt-3.5-turbo", messages=messages)
         print(f"response1: {response1}")
         response2 = await router.acompletion(model="gpt-3.5-turbo", messages=messages)

View file

@@ -18,7 +18,7 @@ import tiktoken
 import uuid
 import aiohttp
 import logging
-import asyncio, httpx
+import asyncio, httpx, inspect
 import copy
 from tokenizers import Tokenizer
 from dataclasses import (
@@ -1047,7 +1047,6 @@
 # make it easy to log if completion/embedding runs succeeded or failed + see what happened | Non-Blocking
 def client(original_function):
     global liteDebuggerClient, get_all_keys
-    import inspect
     def function_setup(
         start_time, *args, **kwargs
     ): # just run once to check if user wants to send their data anywhere - PostHog/Sentry/Slack/etc.
@@ -1333,13 +1332,13 @@ def client(original_function):
                     kwargs["retry_strategy"] = "exponential_backoff_retry"
                 elif (isinstance(e, openai.APIError)): # generic api error
                     kwargs["retry_strategy"] = "constant_retry"
-                return litellm.completion_with_retries(*args, **kwargs)
+                return await litellm.acompletion_with_retries(*args, **kwargs)
             elif isinstance(e, litellm.exceptions.ContextWindowExceededError) and context_window_fallback_dict and model in context_window_fallback_dict:
                 if len(args) > 0:
                     args[0] = context_window_fallback_dict[model]
                 else:
                     kwargs["model"] = context_window_fallback_dict[model]
-                return original_function(*args, **kwargs)
+                return await original_function(*args, **kwargs)
             traceback_exception = traceback.format_exc()
             crash_reporting(*args, **kwargs, exception=traceback_exception)
             end_time = datetime.datetime.now()
@@ -3370,7 +3369,7 @@ def exception_type(
             else:
                 exception_type = ""
-            if "Request Timeout Error" in error_str:
+            if "Request Timeout Error" in error_str or "Request timed out" in error_str:
                 exception_mapping_worked = True
                 raise Timeout(
                     message=f"APITimeoutError - Request timed out",
@@ -3411,7 +3410,6 @@ def exception_type(
                         message=f"OpenAIException - {original_exception.message}",
                         model=model,
                         llm_provider="openai",
-                        request=original_exception.request
                     )
                 if original_exception.status_code == 422:
                     exception_mapping_worked = True
@@ -4229,7 +4227,7 @@ def exception_type(
                     llm_provider=custom_llm_provider,
                     response=original_exception.response
                 )
-        else: # ensure generic errors always return APIConnectionError
+        else: # ensure generic errors always return APIConnectionError=
             exception_mapping_worked = True
             if hasattr(original_exception, "request"):
                 raise APIConnectionError(