fix(acompletion): support client side timeouts + raise exceptions correctly for async calls

Krrish Dholakia 2023-11-17 15:39:39 -08:00
parent c4e53aa77b
commit 0ab6b2451d
8 changed files with 142 additions and 81 deletions


@@ -10,7 +10,6 @@
import os, openai, sys, json, inspect, uuid, datetime, threading
from typing import Any
from functools import partial
import dotenv, traceback, random, asyncio, time, contextvars
from copy import deepcopy
import httpx
@@ -18,7 +17,6 @@ import litellm
from litellm import ( # type: ignore
client,
exception_type,
timeout,
get_optional_params,
get_litellm_params,
Logging,
@@ -176,28 +174,28 @@ async def acompletion(*args, **kwargs):
init_response = completion(*args, **kwargs)
if isinstance(init_response, dict) or isinstance(init_response, ModelResponse): ## CACHING SCENARIO
response = init_response
else:
elif asyncio.iscoroutine(init_response):
response = await init_response
else:
# Call the synchronous function using run_in_executor
response = await loop.run_in_executor(None, func_with_context)
if kwargs.get("stream", False): # return an async generator
# do not change this
# for stream = True, always return an async generator
# See OpenAI acreate https://github.com/openai/openai-python/blob/5d50e9e3b39540af782ca24e65c290343d86e1a9/openai/api_resources/abstract/engine_api_resource.py#L193
# return response
return(
line
async for line in response
)
return _async_streaming(response=response, model=model, custom_llm_provider=custom_llm_provider, args=args)
else:
return response
except Exception as e:
## Map to OpenAI Exception
raise exception_type(
model=model, custom_llm_provider=custom_llm_provider, original_exception=e, completion_kwargs=args,
)
async def _async_streaming(response, model, custom_llm_provider, args):
try:
async for line in response:
yield line
except Exception as e:
raise exception_type(
model=model, custom_llm_provider=custom_llm_provider, original_exception=e, completion_kwargs=args,
)
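Note: a minimal caller-side sketch of the revised async flow; the model name, messages, and timeout value are illustrative, not part of the commit. Non-streaming calls take the coroutine/executor branch above, while stream=True now returns the _async_streaming generator so errors are mapped for async callers too.

import asyncio
import litellm

async def main():
    try:
        # The client-side timeout now travels with the request instead of
        # relying on the removed @timeout decorator.
        response = await litellm.acompletion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "Hey!"}],
            timeout=5,
        )
        print(response)

        # stream=True yields the _async_streaming generator defined above.
        stream = await litellm.acompletion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "Hey!"}],
            stream=True,
        )
        async for chunk in stream:
            print(chunk)
    except Exception as e:
        # Async errors are now routed through exception_type, so callers
        # see the same mapped OpenAI-style exceptions as sync calls.
        print(f"mapped exception: {type(e).__name__}: {e}")

asyncio.run(main())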
def mock_completion(model: str, messages: List, stream: Optional[bool] = False, mock_response: str = "This is a mock request", **kwargs):
"""
@@ -245,6 +243,7 @@ def completion(
messages: List = [],
functions: List = [],
function_call: str = "", # optional params
timeout: Union[float, int] = 600.0,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
n: Optional[int] = None,
@@ -261,7 +260,6 @@ def completion(
tools: Optional[List] = None,
tool_choice: Optional[str] = None,
deployment_id = None,
# set api_base, api_version, api_key
base_url: Optional[str] = None,
api_version: Optional[str] = None,
@@ -298,9 +296,8 @@ def completion(
LITELLM Specific Params
mock_response (str, optional): If provided, return a mock completion response for testing or debugging purposes (default is None).
force_timeout (int, optional): The maximum execution time in seconds for the completion request (default is 600).
custom_llm_provider (str, optional): Used for Non-OpenAI LLMs, Example usage for bedrock, set model="amazon.titan-tg1-large" and custom_llm_provider="bedrock"
num_retries (int, optional): The number of retries to attempt (default is 0).
max_retries (int, optional): The number of retries to attempt (default is 0).
Returns:
ModelResponse: A response object containing the generated completion and associated metadata.
@@ -314,7 +311,7 @@ def completion(
api_base = kwargs.get('api_base', None)
return_async = kwargs.get('return_async', False)
mock_response = kwargs.get('mock_response', None)
force_timeout= kwargs.get('force_timeout', 600)
force_timeout= kwargs.get('force_timeout', 600) ## deprecated
logger_fn = kwargs.get('logger_fn', None)
verbose = kwargs.get('verbose', False)
custom_llm_provider = kwargs.get('custom_llm_provider', None)
@@ -338,8 +335,10 @@ def completion(
litellm_params = ["metadata", "acompletion", "caching", "return_async", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "max_retries"]
default_params = openai_params + litellm_params
non_default_params = {k: v for k,v in kwargs.items() if k not in default_params} # model-specific params - pass them straight to the model/provider
if mock_response:
return mock_completion(model, messages, stream=stream, mock_response=mock_response)
timeout = float(timeout)
try:
if base_url:
api_base = base_url
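Note: because completion() now normalises the new keyword with timeout = float(timeout), callers can pass either an int or a float; a small sketch (model and message are placeholders):

import litellm

msgs = [{"role": "user", "content": "ping"}]

# Equivalent after the float(...) coercion inside completion():
r1 = litellm.completion(model="gpt-3.5-turbo", messages=msgs, timeout=30)
r2 = litellm.completion(model="gpt-3.5-turbo", messages=msgs, timeout=30.0)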
@@ -486,7 +485,8 @@ def completion(
litellm_params=litellm_params,
logger_fn=logger_fn,
logging_obj=logging,
acompletion=acompletion
acompletion=acompletion,
timeout=timeout
)
## LOGGING
@@ -552,7 +552,8 @@ def completion(
logging_obj=logging,
optional_params=optional_params,
litellm_params=litellm_params,
logger_fn=logger_fn
logger_fn=logger_fn,
timeout=timeout
)
except Exception as e:
## LOGGING - log the original exception returned
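Note: the two hunks above thread the new timeout into the provider handlers. A generic sketch of the pattern (this is not litellm's actual handler code; the function and endpoint names are hypothetical), assuming an httpx-based client:

import httpx

def provider_completion(payload: dict, api_base: str, timeout: float = 600.0):
    # httpx enforces the deadline client-side, so a hung provider raises
    # httpx.TimeoutException instead of blocking the caller forever.
    with httpx.Client(timeout=timeout) as client:
        resp = client.post(f"{api_base}/chat/completions", json=payload)
        resp.raise_for_status()
        return resp.json()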
@@ -1399,6 +1400,24 @@ def completion_with_retries(*args, **kwargs):
retryer = tenacity.Retrying(wait=tenacity.wait_exponential(multiplier=1, max=10), stop=tenacity.stop_after_attempt(num_retries), reraise=True)
return retryer(original_function, *args, **kwargs)
async def acompletion_with_retries(*args, **kwargs):
"""
Executes a litellm.completion() with 3 retries
"""
try:
import tenacity
except Exception as e:
raise Exception(f"tenacity import failed please run `pip install tenacity`. Error{e}")
num_retries = kwargs.pop("num_retries", 3)
retry_strategy = kwargs.pop("retry_strategy", "constant_retry")
original_function = kwargs.pop("original_function", completion)
if retry_strategy == "constant_retry":
retryer = tenacity.Retrying(stop=tenacity.stop_after_attempt(num_retries), reraise=True)
elif retry_strategy == "exponential_backoff_retry":
retryer = tenacity.Retrying(wait=tenacity.wait_exponential(multiplier=1, max=10), stop=tenacity.stop_after_attempt(num_retries), reraise=True)
return await retryer(original_function, *args, **kwargs)
def batch_completion(
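Note: as written, acompletion_with_retries awaits whatever tenacity.Retrying returns, and the default original_function is the synchronous completion. An async-aware variant would use tenacity.AsyncRetrying so that exceptions raised while awaiting are actually retried; a sketch under that assumption (model/messages supplied by the caller):

import tenacity
import litellm

async def acompletion_with_backoff(**kwargs):
    # AsyncRetrying awaits the wrapped coroutine on each attempt, so a
    # failure during the await triggers the exponential backoff.
    retryer = tenacity.AsyncRetrying(
        wait=tenacity.wait_exponential(multiplier=1, max=10),
        stop=tenacity.stop_after_attempt(3),
        reraise=True,
    )
    return await retryer(litellm.acompletion, **kwargs)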
@@ -1639,9 +1658,6 @@ async def aembedding(*args, **kwargs):
return response
@client
@timeout( # type: ignore
60
) ## set timeouts, in case calls hang (e.g. Azure) - default is 60s, override with `force_timeout`
def embedding(
model,
input=[],
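Note: with the @timeout decorator removed, embedding() no longer enforces its own 60-second ceiling. Callers who relied on it can impose a deadline themselves; one hedged option, using the async path and asyncio.wait_for (model name illustrative):

import asyncio
import litellm

async def embed_with_deadline(text: str, deadline: float = 60.0):
    # wait_for cancels the task past the deadline, roughly mirroring the
    # decorator's old default-60s behaviour.
    return await asyncio.wait_for(
        litellm.aembedding(model="text-embedding-ada-002", input=[text]),
        timeout=deadline,
    )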