fix(acompletion): support client side timeouts + raise exceptions correctly for async calls

Krrish Dholakia 2023-11-17 15:39:39 -08:00
parent 43ee581d46
commit 02ed97d0b2
8 changed files with 142 additions and 81 deletions
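The caller-facing effect of this change: `completion()` and `acompletion()` accept a `timeout` (in seconds), forward it to the underlying OpenAI/Azure clients, and surface expiry as `litellm.Timeout`. A minimal sketch of that calling pattern, assuming an `OPENAI_API_KEY` is set; the values are illustrative, not part of this commit:

import asyncio
import litellm
from litellm import acompletion

async def main():
    try:
        response = await acompletion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "Hello, how are you?"}],
            timeout=5,  # seconds; now forwarded to the underlying OpenAI/Azure client
        )
        print(response)
    except litellm.Timeout:
        # raised when the client-side timeout trips, instead of an unmapped httpx error
        print("request timed out client-side")

asyncio.run(main())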

View file

@@ -1,4 +1,4 @@
from typing import Optional, Union
from typing import Optional, Union, Any
import types, requests
from .base import BaseLLM
from litellm.utils import ModelResponse, Choices, Message, CustomStreamWrapper, convert_to_model_response_object
@@ -98,6 +98,7 @@ class AzureChatCompletion(BaseLLM):
api_type: str,
azure_ad_token: str,
print_verbose: Callable,
timeout,
logging_obj,
optional_params,
litellm_params,
@@ -129,13 +130,13 @@ class AzureChatCompletion(BaseLLM):
)
if acompletion is True:
if optional_params.get("stream", False):
return self.async_streaming(logging_obj=logging_obj, api_base=api_base, data=data, model=model, api_key=api_key, api_version=api_version, azure_ad_token=azure_ad_token)
return self.async_streaming(logging_obj=logging_obj, api_base=api_base, data=data, model=model, api_key=api_key, api_version=api_version, azure_ad_token=azure_ad_token, timeout=timeout)
else:
return self.acompletion(api_base=api_base, data=data, model_response=model_response, api_key=api_key, api_version=api_version, model=model, azure_ad_token=azure_ad_token)
return self.acompletion(api_base=api_base, data=data, model_response=model_response, api_key=api_key, api_version=api_version, model=model, azure_ad_token=azure_ad_token, timeout=timeout)
elif "stream" in optional_params and optional_params["stream"] == True:
return self.streaming(logging_obj=logging_obj, api_base=api_base, data=data, model=model, api_key=api_key, api_version=api_version, azure_ad_token=azure_ad_token)
return self.streaming(logging_obj=logging_obj, api_base=api_base, data=data, model=model, api_key=api_key, api_version=api_version, azure_ad_token=azure_ad_token, timeout=timeout)
else:
azure_client = AzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=api_base, azure_deployment=model, azure_ad_token=azure_ad_token, http_client=litellm.client_session)
azure_client = AzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=api_base, azure_deployment=model, azure_ad_token=azure_ad_token, http_client=litellm.client_session, timeout=timeout)
response = azure_client.chat.completions.create(**data) # type: ignore
return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response)
except AzureOpenAIError as e:
@@ -150,16 +151,18 @@ class AzureChatCompletion(BaseLLM):
model: str,
api_base: str,
data: dict,
timeout: Any,
model_response: ModelResponse,
azure_ad_token: Optional[str]=None, ):
response = None
try:
azure_client = AsyncAzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=api_base, azure_deployment=model, azure_ad_token=azure_ad_token, http_client=litellm.aclient_session)
azure_client = AsyncAzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=api_base, azure_deployment=model, azure_ad_token=azure_ad_token, http_client=litellm.aclient_session, timeout=timeout)
response = await azure_client.chat.completions.create(**data)
return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response)
except Exception as e:
if isinstance(e,httpx.TimeoutException):
raise AzureOpenAIError(status_code=500, message="Request Timeout Error")
elif response and hasattr(response, "text"):
elif response is not None and hasattr(response, "text"):
raise AzureOpenAIError(status_code=500, message=f"{str(e)}\n\nOriginal Response: {response.text}")
else:
raise AzureOpenAIError(status_code=500, message=f"{str(e)}")
@@ -171,9 +174,10 @@ class AzureChatCompletion(BaseLLM):
api_version: str,
data: dict,
model: str,
timeout: Any,
azure_ad_token: Optional[str]=None,
):
azure_client = AzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=api_base, azure_deployment=model, azure_ad_token=azure_ad_token, http_client=litellm.client_session)
azure_client = AzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=api_base, azure_deployment=model, azure_ad_token=azure_ad_token, http_client=litellm.client_session, timeout=timeout)
response = azure_client.chat.completions.create(**data)
streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="azure",logging_obj=logging_obj)
for transformed_chunk in streamwrapper:
@@ -186,8 +190,9 @@ class AzureChatCompletion(BaseLLM):
api_version: str,
data: dict,
model: str,
timeout: Any,
azure_ad_token: Optional[str]=None):
azure_client = AsyncAzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=api_base, azure_deployment=model, azure_ad_token=azure_ad_token, http_client=litellm.aclient_session)
azure_client = AsyncAzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=api_base, azure_deployment=model, azure_ad_token=azure_ad_token, http_client=litellm.aclient_session, timeout=timeout)
response = await azure_client.chat.completions.create(**data)
streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="azure",logging_obj=logging_obj)
async for transformed_chunk in streamwrapper:
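The async Azure path above normalizes errors before they leave the coroutine: `httpx.TimeoutException` becomes an `AzureOpenAIError` carrying "Request Timeout Error", which `exception_type` in utils.py later converts into `litellm.Timeout`. A condensed sketch of just that mapping shape; `_azure_call_with_mapping` is a hypothetical name, and `AzureOpenAIError` is assumed to be in scope as it is in this module (the real logic lives inside `acompletion` above):

import httpx

async def _azure_call_with_mapping(azure_client, data: dict):
    response = None
    try:
        response = await azure_client.chat.completions.create(**data)
        return response
    except Exception as e:
        if isinstance(e, httpx.TimeoutException):
            # normalized message; utils.exception_type re-raises this as litellm.Timeout
            raise AzureOpenAIError(status_code=500, message="Request Timeout Error")
        elif response is not None and hasattr(response, "text"):
            raise AzureOpenAIError(status_code=500, message=f"{e}\n\nOriginal Response: {response.text}")
        else:
            raise AzureOpenAIError(status_code=500, message=str(e))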

View file

@@ -1,4 +1,4 @@
from typing import Optional, Union
from typing import Optional, Union, Any
import types, time, json
import httpx
from .base import BaseLLM
@@ -161,6 +161,7 @@ class OpenAIChatCompletion(BaseLLM):
def completion(self,
model_response: ModelResponse,
timeout: Any,
model: Optional[str]=None,
messages: Optional[list]=None,
print_verbose: Optional[Callable]=None,
@@ -180,6 +181,9 @@ class OpenAIChatCompletion(BaseLLM):
if model is None or messages is None:
raise OpenAIError(status_code=422, message=f"Missing model or messages")
if not isinstance(timeout, float):
raise OpenAIError(status_code=422, message=f"Timeout needs to be a float")
for _ in range(2): # if call fails due to alternating messages, retry with reformatted message
data = {
"model": model,
@@ -197,13 +201,13 @@ class OpenAIChatCompletion(BaseLLM):
try:
if acompletion is True:
if optional_params.get("stream", False):
return self.async_streaming(logging_obj=logging_obj, data=data, model=model, api_base=api_base, api_key=api_key)
return self.async_streaming(logging_obj=logging_obj, data=data, model=model, api_base=api_base, api_key=api_key, timeout=timeout)
else:
return self.acompletion(data=data, model_response=model_response, api_base=api_base, api_key=api_key)
return self.acompletion(data=data, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout)
elif optional_params.get("stream", False):
return self.streaming(logging_obj=logging_obj, data=data, model=model, api_base=api_base, api_key=api_key)
return self.streaming(logging_obj=logging_obj, data=data, model=model, api_base=api_base, api_key=api_key, timeout=timeout)
else:
openai_client = OpenAI(api_key=api_key, base_url=api_base, http_client=litellm.client_session)
openai_client = OpenAI(api_key=api_key, base_url=api_base, http_client=litellm.client_session, timeout=timeout)
response = openai_client.chat.completions.create(**data) # type: ignore
return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response)
except Exception as e:
@@ -234,27 +238,32 @@ class OpenAIChatCompletion(BaseLLM):
async def acompletion(self,
data: dict,
model_response: ModelResponse,
timeout: float,
api_key: Optional[str]=None,
api_base: Optional[str]=None):
response = None
try:
openai_aclient = AsyncOpenAI(api_key=api_key, base_url=api_base, http_client=litellm.aclient_session)
openai_aclient = AsyncOpenAI(api_key=api_key, base_url=api_base, http_client=litellm.aclient_session, timeout=timeout)
response = await openai_aclient.chat.completions.create(**data)
return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response)
except Exception as e:
if response and hasattr(response, "text"):
raise OpenAIError(status_code=500, message=f"{str(e)}\n\nOriginal Response: {response.text}")
else:
if type(e).__name__ == "ReadTimeout":
raise OpenAIError(status_code=408, message=f"{type(e).__name__}")
else:
raise OpenAIError(status_code=500, message=f"{str(e)}")
def streaming(self,
logging_obj,
timeout: float,
data: dict,
model: str,
api_key: Optional[str]=None,
api_base: Optional[str]=None
):
openai_client = OpenAI(api_key=api_key, base_url=api_base, http_client=litellm.client_session)
openai_client = OpenAI(api_key=api_key, base_url=api_base, http_client=litellm.client_session, timeout=timeout)
response = openai_client.chat.completions.create(**data)
streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="openai",logging_obj=logging_obj)
for transformed_chunk in streamwrapper:
@@ -262,15 +271,26 @@ class OpenAIChatCompletion(BaseLLM):
async def async_streaming(self,
logging_obj,
timeout: float,
data: dict,
model: str,
api_key: Optional[str]=None,
api_base: Optional[str]=None):
openai_aclient = AsyncOpenAI(api_key=api_key, base_url=api_base, http_client=litellm.aclient_session)
response = None
try:
openai_aclient = AsyncOpenAI(api_key=api_key, base_url=api_base, http_client=litellm.aclient_session, timeout=timeout)
response = await openai_aclient.chat.completions.create(**data)
streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="openai",logging_obj=logging_obj)
async for transformed_chunk in streamwrapper:
yield transformed_chunk
except Exception as e: # need to exception handle here. async exceptions don't get caught in sync functions.
if response is not None and hasattr(response, "text"):
raise OpenAIError(status_code=500, message=f"{str(e)}\n\nOriginal Response: {response.text}")
else:
if type(e).__name__ == "ReadTimeout":
raise OpenAIError(status_code=408, message=f"{type(e).__name__}")
else:
raise OpenAIError(status_code=500, message=f"{str(e)}")
def embedding(self,
model: str,
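The new try/except inside `async_streaming` matters because an async generator's body only runs when the caller iterates: errors from the provider surface inside `async for`, not at call time, so the mapping has to happen there. A stand-alone illustration of that timing (hypothetical names, not litellm code):

import asyncio

async def flaky_stream():
    raise RuntimeError("upstream failure")
    yield  # presence of yield makes this an async generator; never reached

async def wrapped_stream():
    try:
        async for chunk in flaky_stream():
            yield chunk
    except RuntimeError as e:
        # translate at iteration time, where the consumer's `async for` can catch it
        raise ValueError(f"mapped provider error: {e}")

async def main():
    stream = wrapped_stream()      # nothing runs yet, so no error here
    try:
        async for _ in stream:     # the error only surfaces once iteration starts
            pass
    except ValueError as e:
        print(e)

asyncio.run(main())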

View file

@@ -10,7 +10,6 @@
import os, openai, sys, json, inspect, uuid, datetime, threading
from typing import Any
from functools import partial
import dotenv, traceback, random, asyncio, time, contextvars
from copy import deepcopy
import httpx
@@ -18,7 +17,6 @@ import litellm
from litellm import ( # type: ignore
client,
exception_type,
timeout,
get_optional_params,
get_litellm_params,
Logging,
@@ -176,28 +174,28 @@ async def acompletion(*args, **kwargs):
init_response = completion(*args, **kwargs)
if isinstance(init_response, dict) or isinstance(init_response, ModelResponse): ## CACHING SCENARIO
response = init_response
else:
elif asyncio.iscoroutine(init_response):
response = await init_response
else:
# Call the synchronous function using run_in_executor
response = await loop.run_in_executor(None, func_with_context)
if kwargs.get("stream", False): # return an async generator
# do not change this
# for stream = True, always return an async generator
# See OpenAI acreate https://github.com/openai/openai-python/blob/5d50e9e3b39540af782ca24e65c290343d86e1a9/openai/api_resources/abstract/engine_api_resource.py#L193
# return response
return(
line
async for line in response
)
return _async_streaming(response=response, model=model, custom_llm_provider=custom_llm_provider, args=args)
else:
return response
except Exception as e:
## Map to OpenAI Exception
raise exception_type(
model=model, custom_llm_provider=custom_llm_provider, original_exception=e, completion_kwargs=args,
)
async def _async_streaming(response, model, custom_llm_provider, args):
try:
async for line in response:
yield line
except Exception as e:
raise exception_type(
model=model, custom_llm_provider=custom_llm_provider, original_exception=e, completion_kwargs=args,
)
def mock_completion(model: str, messages: List, stream: Optional[bool] = False, mock_response: str = "This is a mock request", **kwargs):
"""
@@ -245,6 +243,7 @@ def completion(
messages: List = [],
functions: List = [],
function_call: str = "", # optional params
timeout: Union[float, int] = 600.0,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
n: Optional[int] = None,
@@ -261,7 +260,6 @@ def completion(
tools: Optional[List] = None,
tool_choice: Optional[str] = None,
deployment_id = None,
# set api_base, api_version, api_key
base_url: Optional[str] = None,
api_version: Optional[str] = None,
@@ -298,9 +296,8 @@ def completion(
LITELLM Specific Params
mock_response (str, optional): If provided, return a mock completion response for testing or debugging purposes (default is None).
force_timeout (int, optional): The maximum execution time in seconds for the completion request (default is 600).
custom_llm_provider (str, optional): Used for Non-OpenAI LLMs, Example usage for bedrock, set model="amazon.titan-tg1-large" and custom_llm_provider="bedrock"
num_retries (int, optional): The number of retries to attempt (default is 0).
max_retries (int, optional): The number of retries to attempt (default is 0).
Returns:
ModelResponse: A response object containing the generated completion and associated metadata.
@@ -314,7 +311,7 @@ def completion(
api_base = kwargs.get('api_base', None)
return_async = kwargs.get('return_async', False)
mock_response = kwargs.get('mock_response', None)
force_timeout= kwargs.get('force_timeout', 600)
force_timeout= kwargs.get('force_timeout', 600) ## deprecated
logger_fn = kwargs.get('logger_fn', None)
verbose = kwargs.get('verbose', False)
custom_llm_provider = kwargs.get('custom_llm_provider', None)
@@ -338,8 +335,10 @@ def completion(
litellm_params = ["metadata", "acompletion", "caching", "return_async", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "max_retries"]
default_params = openai_params + litellm_params
non_default_params = {k: v for k,v in kwargs.items() if k not in default_params} # model-specific params - pass them straight to the model/provider
if mock_response:
return mock_completion(model, messages, stream=stream, mock_response=mock_response)
timeout = float(timeout)
try:
if base_url:
api_base = base_url
@@ -486,7 +485,8 @@ def completion(
litellm_params=litellm_params,
logger_fn=logger_fn,
logging_obj=logging,
acompletion=acompletion
acompletion=acompletion,
timeout=timeout
)
## LOGGING
@@ -552,7 +552,8 @@ def completion(
logging_obj=logging,
optional_params=optional_params,
litellm_params=litellm_params,
logger_fn=logger_fn
logger_fn=logger_fn,
timeout=timeout
)
except Exception as e:
## LOGGING - log the original exception returned
@@ -1399,6 +1400,24 @@ def completion_with_retries(*args, **kwargs):
retryer = tenacity.Retrying(wait=tenacity.wait_exponential(multiplier=1, max=10), stop=tenacity.stop_after_attempt(num_retries), reraise=True)
return retryer(original_function, *args, **kwargs)
async def acompletion_with_retries(*args, **kwargs):
"""
Executes a litellm.completion() with 3 retries
"""
try:
import tenacity
except Exception as e:
raise Exception(f"tenacity import failed please run `pip install tenacity`. Error{e}")
num_retries = kwargs.pop("num_retries", 3)
retry_strategy = kwargs.pop("retry_strategy", "constant_retry")
original_function = kwargs.pop("original_function", completion)
if retry_strategy == "constant_retry":
retryer = tenacity.Retrying(stop=tenacity.stop_after_attempt(num_retries), reraise=True)
elif retry_strategy == "exponential_backoff_retry":
retryer = tenacity.Retrying(wait=tenacity.wait_exponential(multiplier=1, max=10), stop=tenacity.stop_after_attempt(num_retries), reraise=True)
return await retryer(original_function, *args, **kwargs)
def batch_completion(
@@ -1639,9 +1658,6 @@ async def aembedding(*args, **kwargs):
return response
@client
@timeout( # type: ignore
60
) ## set timeouts, in case calls hang (e.g. Azure) - default is 60s, override with `force_timeout`
def embedding(
model,
input=[],
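Taken together, the changes in this file mean a streaming `acompletion()` call hands back the `_async_streaming` generator, whose iteration errors are re-mapped through `exception_type`, while non-streaming calls are awaited directly or run in an executor. A minimal sketch of the streaming caller side, assuming `OPENAI_API_KEY` is set; the timeout value is illustrative:

import asyncio
import litellm
from litellm import acompletion

async def main():
    try:
        stream = await acompletion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "Hello, how are you?"}],
            stream=True,
            timeout=5,
        )
        async for chunk in stream:  # errors raised mid-stream arrive here, already mapped
            print(chunk)
    except litellm.Timeout:
        print("streaming request timed out")

asyncio.run(main())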

View file

@@ -26,4 +26,3 @@ model_list:
litellm_settings:
drop_params: True
success_callback: ["langfuse"] # https://docs.litellm.ai/docs/observability/langfuse_integration

View file

@@ -33,7 +33,7 @@ class Router:
redis_port: Optional[int] = None,
redis_password: Optional[str] = None,
cache_responses: bool = False,
num_retries: Optional[int] = None,
num_retries: int = 0,
timeout: float = 600,
default_litellm_params = {}, # default params for Router.chat.completion.create
routing_strategy: Literal["simple-shuffle", "least-busy"] = "simple-shuffle") -> None:
@@ -42,12 +42,13 @@ class Router:
self.set_model_list(model_list)
self.healthy_deployments: List = self.model_list
if num_retries:
self.num_retries = num_retries
self.chat = litellm.Chat(params=default_litellm_params)
litellm.request_timeout = timeout
self.default_litellm_params = {
"timeout": timeout
}
self.routing_strategy = routing_strategy
### HEALTH CHECK THREAD ###
if self.routing_strategy == "least-busy":
@@ -222,6 +223,9 @@ class Router:
# pick the one that is available (lowest TPM/RPM)
deployment = self.get_available_deployment(model=model, messages=messages)
data = deployment["litellm_params"]
for k, v in self.default_litellm_params.items():
if k not in data: # prioritize model-specific params > default router params
data[k] = v
return litellm.completion(**{**data, "messages": messages, "caching": self.cache_responses, **kwargs})
@@ -234,16 +238,20 @@ class Router:
try:
deployment = self.get_available_deployment(model=model, messages=messages)
data = deployment["litellm_params"]
for k, v in self.default_litellm_params.items():
if k not in data: # prioritize model-specific params > default router params
data[k] = v
response = await litellm.acompletion(**{**data, "messages": messages, "caching": self.cache_responses, **kwargs})
if inspect.iscoroutinefunction(response): # async errors are often returned as coroutines
response = await response
return response
except Exception as e:
if self.num_retries > 0:
kwargs["model"] = model
kwargs["messages"] = messages
kwargs["original_exception"] = e
kwargs["original_function"] = self.acompletion
return await self.async_function_with_retries(**kwargs)
else:
raise e
def text_completion(self,
model: str,
@@ -258,6 +266,9 @@ class Router:
deployment = self.get_available_deployment(model=model, messages=messages)
data = deployment["litellm_params"]
for k, v in self.default_litellm_params.items():
if k not in data: # prioritize model-specific params > default router params
data[k] = v
# call via litellm.completion()
return litellm.text_completion(**{**data, "prompt": prompt, "caching": self.cache_responses, **kwargs}) # type: ignore
@@ -270,6 +281,9 @@ class Router:
deployment = self.get_available_deployment(model=model, input=input)
data = deployment["litellm_params"]
for k, v in self.default_litellm_params.items():
if k not in data: # prioritize model-specific params > default router params
data[k] = v
# call via litellm.embedding()
return litellm.embedding(**{**data, "input": input, "caching": self.cache_responses, **kwargs})
@@ -282,4 +296,7 @@ class Router:
deployment = self.get_available_deployment(model=model, input=input)
data = deployment["litellm_params"]
for k, v in self.default_litellm_params.items():
if k not in data: # prioritize model-specific params > default router params
data[k] = v
return await litellm.aembedding(**{**data, "input": input, "caching": self.cache_responses, **kwargs})
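The loop repeated in each Router method above implements one rule: router-level defaults such as `timeout` only fill in keys the deployment's `litellm_params` doesn't already set. A small self-contained sketch of that merge, with illustrative values:

# Router-level defaults fill gaps; deployment-specific litellm_params always win.
default_litellm_params = {"timeout": 600}
deployment_litellm_params = {"model": "azure/chatgpt-v-2", "timeout": 10}

data = dict(deployment_litellm_params)
for k, v in default_litellm_params.items():
    if k not in data:  # prioritize model-specific params over router defaults
        data[k] = v

assert data["timeout"] == 10  # the deployment's own timeout is kept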

View file

@@ -5,8 +5,6 @@ import sys, os
import pytest
import traceback
import asyncio, logging
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
sys.path.insert(
0, os.path.abspath("../..")
@@ -20,7 +18,8 @@ def test_sync_response():
user_message = "Hello, how are you?"
messages = [{"content": user_message, "role": "user"}]
try:
response = completion(model="gpt-3.5-turbo", messages=messages, api_key=os.environ["OPENAI_API_KEY"])
response = completion(model="gpt-3.5-turbo", messages=messages, timeout=5)
print(f"response: {response}")
except Exception as e:
pytest.fail(f"An exception occurred: {e}")
# test_sync_response()
@@ -30,7 +29,7 @@ def test_sync_response_anyscale():
user_message = "Hello, how are you?"
messages = [{"content": user_message, "role": "user"}]
try:
response = completion(model="anyscale/mistralai/Mistral-7B-Instruct-v0.1", messages=messages)
response = completion(model="anyscale/mistralai/Mistral-7B-Instruct-v0.1", messages=messages, timeout=5)
except Exception as e:
pytest.fail(f"An exception occurred: {e}")
@@ -43,10 +42,13 @@ def test_async_response_openai():
user_message = "Hello, how are you?"
messages = [{"content": user_message, "role": "user"}]
try:
response = await acompletion(model="gpt-3.5-turbo", messages=messages)
response = await acompletion(model="gpt-3.5-turbo", messages=messages, timeout=5)
print(f"response: {response}")
except litellm.Timeout as e:
pass
except Exception as e:
pytest.fail(f"An exception occurred: {e}")
print(e)
asyncio.run(test_get_response())
@@ -56,17 +58,18 @@ def test_async_response_azure():
import asyncio
litellm.set_verbose = True
async def test_get_response():
user_message = "Hello, how are you?"
user_message = "What do you know?"
messages = [{"content": user_message, "role": "user"}]
try:
response = await acompletion(model="azure/chatgpt-v-2", messages=messages)
response = await acompletion(model="azure/chatgpt-v-2", messages=messages, timeout=5)
print(f"response: {response}")
except litellm.Timeout as e:
pass
except Exception as e:
pytest.fail(f"An exception occurred: {e}")
asyncio.run(test_get_response())
def test_async_anyscale_response():
import asyncio
litellm.set_verbose = True
@@ -74,9 +77,11 @@ def test_async_anyscale_response():
user_message = "Hello, how are you?"
messages = [{"content": user_message, "role": "user"}]
try:
response = await acompletion(model="anyscale/mistralai/Mistral-7B-Instruct-v0.1", messages=messages)
response = await acompletion(model="anyscale/mistralai/Mistral-7B-Instruct-v0.1", messages=messages, timeout=5)
# response = await response
print(f"response: {response}")
except litellm.Timeout as e:
pass
except Exception as e:
pytest.fail(f"An exception occurred: {e}")
@@ -91,7 +96,7 @@ def test_get_response_streaming():
messages = [{"content": user_message, "role": "user"}]
try:
litellm.set_verbose = True
response = await acompletion(model="gpt-3.5-turbo", messages=messages, stream=True)
response = await acompletion(model="gpt-3.5-turbo", messages=messages, stream=True, timeout=5)
print(type(response))
import inspect
@@ -108,12 +113,12 @@ def test_get_response_streaming():
assert isinstance(output, str), "output needs to be of type str"
assert len(output) > 0, "Length of output needs to be greater than 0."
print(f'output: {output}')
except litellm.Timeout as e:
pass
except Exception as e:
pytest.fail(f"An exception occurred: {e}")
return response
asyncio.run(test_async_call())
# test_get_response_streaming()
def test_get_response_non_openai_streaming():
@@ -123,7 +128,7 @@ def test_get_response_non_openai_streaming():
user_message = "Hello, how are you?"
messages = [{"content": user_message, "role": "user"}]
try:
response = await acompletion(model="anyscale/mistralai/Mistral-7B-Instruct-v0.1", messages=messages, stream=True)
response = await acompletion(model="anyscale/mistralai/Mistral-7B-Instruct-v0.1", messages=messages, stream=True, timeout=5)
print(type(response))
import inspect
@@ -140,10 +145,11 @@ def test_get_response_non_openai_streaming():
assert output is not None, "output cannot be None."
assert isinstance(output, str), "output needs to be of type str"
assert len(output) > 0, "Length of output needs to be greater than 0."
except litellm.Timeout as e:
pass
except Exception as e:
pytest.fail(f"An exception occurred: {e}")
return response
asyncio.run(test_async_call())
# test_get_response_non_openai_streaming()
test_get_response_non_openai_streaming()
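These tests treat `litellm.Timeout` as an acceptable outcome when a small timeout trips. A companion test that requires the timeout to fire could look like the hypothetical sketch below (not part of this commit; the very small timeout exists only to trigger the error path):

import asyncio
import pytest
import litellm
from litellm import acompletion

def test_acompletion_raises_timeout():
    async def call():
        await acompletion(
            model="gpt-3.5-turbo",
            messages=[{"content": "Hello, how are you?", "role": "user"}],
            timeout=0.001,  # small enough to force the timeout path
        )
    with pytest.raises(litellm.Timeout):
        asyncio.run(call())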

View file

@@ -242,11 +242,11 @@ def test_acompletion_on_router():
]
messages = [
{"role": "user", "content": "What is the weather like in Boston?"}
{"role": "user", "content": "What is the weather like in SF?"}
]
async def get_response():
router = Router(model_list=model_list, redis_host=os.environ["REDIS_HOST"], redis_password=os.environ["REDIS_PASSWORD"], redis_port=os.environ["REDIS_PORT"], cache_responses=True, timeout=10)
router = Router(model_list=model_list, redis_host=os.environ["REDIS_HOST"], redis_password=os.environ["REDIS_PASSWORD"], redis_port=os.environ["REDIS_PORT"], cache_responses=True, timeout=0.1)
response1 = await router.acompletion(model="gpt-3.5-turbo", messages=messages)
print(f"response1: {response1}")
response2 = await router.acompletion(model="gpt-3.5-turbo", messages=messages)

View file

@@ -18,7 +18,7 @@ import tiktoken
import uuid
import aiohttp
import logging
import asyncio, httpx
import asyncio, httpx, inspect
import copy
from tokenizers import Tokenizer
from dataclasses import (
@@ -1047,7 +1047,6 @@ def exception_logging(
# make it easy to log if completion/embedding runs succeeded or failed + see what happened | Non-Blocking
def client(original_function):
global liteDebuggerClient, get_all_keys
import inspect
def function_setup(
start_time, *args, **kwargs
): # just run once to check if user wants to send their data anywhere - PostHog/Sentry/Slack/etc.
@@ -1333,13 +1332,13 @@ def client(original_function):
kwargs["retry_strategy"] = "exponential_backoff_retry"
elif (isinstance(e, openai.APIError)): # generic api error
kwargs["retry_strategy"] = "constant_retry"
return litellm.completion_with_retries(*args, **kwargs)
return await litellm.acompletion_with_retries(*args, **kwargs)
elif isinstance(e, litellm.exceptions.ContextWindowExceededError) and context_window_fallback_dict and model in context_window_fallback_dict:
if len(args) > 0:
args[0] = context_window_fallback_dict[model]
else:
kwargs["model"] = context_window_fallback_dict[model]
return original_function(*args, **kwargs)
return await original_function(*args, **kwargs)
traceback_exception = traceback.format_exc()
crash_reporting(*args, **kwargs, exception=traceback_exception)
end_time = datetime.datetime.now()
@@ -3370,7 +3369,7 @@ def exception_type(
else:
exception_type = ""
if "Request Timeout Error" in error_str:
if "Request Timeout Error" in error_str or "Request timed out" in error_str:
exception_mapping_worked = True
raise Timeout(
message=f"APITimeoutError - Request timed out",
@@ -3411,7 +3410,6 @@ def exception_type(
message=f"OpenAIException - {original_exception.message}",
model=model,
llm_provider="openai",
request=original_exception.request
)
if original_exception.status_code == 422:
exception_mapping_worked = True
@@ -4229,7 +4227,7 @@ def exception_type(
llm_provider=custom_llm_provider,
response=original_exception.response
)
else: # ensure generic errors always return APIConnectionError
else: # ensure generic errors always return APIConnectionError=
exception_mapping_worked = True
if hasattr(original_exception, "request"):
raise APIConnectionError(
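With the broadened match above ("Request Timeout Error" or "Request timed out"), timeout messages from either the normalized provider errors or the HTTP layer itself are re-raised as `litellm.Timeout`. A sketch of that mapping using names from this diff; it assumes the early timeout-string check shown above runs before the provider-specific branches, and the `httpx.ReadTimeout` message is constructed here only for illustration:

import httpx
import litellm
from litellm import exception_type

try:
    exception_type(
        model="gpt-3.5-turbo",
        custom_llm_provider="openai",
        original_exception=httpx.ReadTimeout("Request timed out."),
        completion_kwargs={},
    )
except litellm.Timeout as e:
    print(f"mapped to litellm.Timeout: {e}")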