mirror of https://github.com/BerriAI/litellm.git
synced 2025-04-25 18:54:30 +00:00
fix(acompletion): support client side timeouts + raise exceptions correctly for async calls
This commit is contained in:
parent c4e53aa77b
commit 0ab6b2451d
8 changed files with 142 additions and 81 deletions
```diff
@@ -10,7 +10,6 @@
 import os, openai, sys, json, inspect, uuid, datetime, threading
 from typing import Any
 from functools import partial
 import dotenv, traceback, random, asyncio, time, contextvars
 from copy import deepcopy
 import httpx
```
```diff
@@ -18,7 +17,6 @@ import litellm
 from litellm import (  # type: ignore
     client,
     exception_type,
-    timeout,
     get_optional_params,
     get_litellm_params,
     Logging,
```
```diff
@@ -176,28 +174,28 @@ async def acompletion(*args, **kwargs):
             init_response = completion(*args, **kwargs)
             if isinstance(init_response, dict) or isinstance(init_response, ModelResponse): ## CACHING SCENARIO
                 response = init_response
-            else:
+            elif asyncio.iscoroutine(init_response):
                 response = await init_response
         else:
             # Call the synchronous function using run_in_executor
             response = await loop.run_in_executor(None, func_with_context)
         if kwargs.get("stream", False): # return an async generator
-            # do not change this
-            # for stream = True, always return an async generator
-            # See OpenAI acreate https://github.com/openai/openai-python/blob/5d50e9e3b39540af782ca24e65c290343d86e1a9/openai/api_resources/abstract/engine_api_resource.py#L193
-            # return response
-            return(
-                line
-                async for line in response
-            )
+            return _async_streaming(response=response, model=model, custom_llm_provider=custom_llm_provider, args=args)
         else:
             return response
+    except Exception as e:
+        ## Map to OpenAI Exception
+        raise exception_type(
+            model=model, custom_llm_provider=custom_llm_provider, original_exception=e, completion_kwargs=args,
+        )
+
+async def _async_streaming(response, model, custom_llm_provider, args):
+    try:
+        async for line in response:
+            yield line
+    except Exception as e:
+        raise exception_type(
+            model=model, custom_llm_provider=custom_llm_provider, original_exception=e, completion_kwargs=args,
+        )

 def mock_completion(model: str, messages: List, stream: Optional[bool] = False, mock_response: str = "This is a mock request", **kwargs):
     """
```
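The reworked `acompletion` above dispatches on what `completion()` hands back: a ready value (the caching scenario), a coroutine (providers with native async clients), or a blocking call that must be off-loaded to a thread. A minimal, self-contained sketch of that dispatch, with a hypothetical `fake_completion` standing in for litellm's `completion`:

```python
import asyncio
from functools import partial

def fake_completion(**kwargs):
    # Hypothetical stand-in: the real completion() may return a dict
    # (cache hit), a coroutine (native async client), or simply block.
    return {"choices": ["cached result"]}

async def acompletion_sketch(native_async: bool = True, **kwargs):
    loop = asyncio.get_event_loop()
    if native_async:
        # Native async providers: completion() returns either a finished
        # value (caching scenario) or a coroutine still to be awaited.
        init_response = fake_completion(**kwargs)
        if asyncio.iscoroutine(init_response):
            response = await init_response
        else:
            response = init_response
    else:
        # Purely synchronous providers: run the blocking call in a
        # worker thread so the event loop stays responsive.
        func = partial(fake_completion, **kwargs)
        response = await loop.run_in_executor(None, func)
    return response

print(asyncio.run(acompletion_sketch()))
```

Wrapping the stream in `_async_streaming` rather than returning a bare generator expression gives the `async for` loop a place to catch provider errors and route them through `exception_type` — the "raise exceptions correctly for async calls" half of the commit message.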
```diff
@@ -245,6 +243,7 @@ def completion(
     messages: List = [],
     functions: List = [],
     function_call: str = "", # optional params
+    timeout: Union[float, int] = 600.0,
     temperature: Optional[float] = None,
     top_p: Optional[float] = None,
     n: Optional[int] = None,
```
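With the new keyword in the signature, callers can bound a request from the client side; the default is 600.0 seconds. A usage sketch — the model name and credential setup are illustrative, not part of this diff:

```python
import litellm

# Bound the whole request to 10 seconds on the client side. A later hunk
# coerces the value with `timeout = float(timeout)`, so ints are fine too.
response = litellm.completion(
    model="gpt-3.5-turbo",  # illustrative model name
    messages=[{"role": "user", "content": "ping"}],
    timeout=10,
)
```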
```diff
@@ -261,7 +260,6 @@ def completion(
     tools: Optional[List] = None,
     tool_choice: Optional[str] = None,
     deployment_id = None,
     # set api_base, api_version, api_key
     base_url: Optional[str] = None,
     api_version: Optional[str] = None,
```
```diff
@@ -298,9 +296,8 @@ def completion(
 
     LITELLM Specific Params
     mock_response (str, optional): If provided, return a mock completion response for testing or debugging purposes (default is None).
-    force_timeout (int, optional): The maximum execution time in seconds for the completion request (default is 600).
     custom_llm_provider (str, optional): Used for Non-OpenAI LLMs, Example usage for bedrock, set model="amazon.titan-tg1-large" and custom_llm_provider="bedrock"
     num_retries (int, optional): The number of retries to attempt (default is 0).
     max_retries (int, optional): The number of retries to attempt (default is 0).
     Returns:
         ModelResponse: A response object containing the generated completion and associated metadata.
```
```diff
@@ -314,7 +311,7 @@ def completion(
     api_base = kwargs.get('api_base', None)
     return_async = kwargs.get('return_async', False)
     mock_response = kwargs.get('mock_response', None)
-    force_timeout = kwargs.get('force_timeout', 600)
+    force_timeout = kwargs.get('force_timeout', 600) ## deprecated
     logger_fn = kwargs.get('logger_fn', None)
     verbose = kwargs.get('verbose', False)
    custom_llm_provider = kwargs.get('custom_llm_provider', None)
```
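`force_timeout` is still read from kwargs but now flagged deprecated in favor of `timeout`. For call sites migrating between the two, a shim like the following would work; the helper name is hypothetical, not part of litellm:

```python
def normalize_timeout(kwargs: dict) -> dict:
    # Hypothetical helper: prefer the new `timeout` kwarg, falling back to
    # the deprecated `force_timeout` so older call sites keep working.
    if "timeout" not in kwargs and "force_timeout" in kwargs:
        kwargs["timeout"] = float(kwargs.pop("force_timeout"))
    return kwargs

print(normalize_timeout({"force_timeout": 30}))  # {'timeout': 30.0}
```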
```diff
@@ -338,8 +335,10 @@ def completion(
     litellm_params = ["metadata", "acompletion", "caching", "return_async", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "max_retries"]
     default_params = openai_params + litellm_params
     non_default_params = {k: v for k, v in kwargs.items() if k not in default_params} # model-specific params - pass them straight to the model/provider
 
     if mock_response:
         return mock_completion(model, messages, stream=stream, mock_response=mock_response)
+    timeout = float(timeout)
     try:
         if base_url:
             api_base = base_url
```
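The `non_default_params` comprehension is what lets provider-specific kwargs (say, `top_k` for a non-OpenAI model) pass through untouched. A reduced sketch of the same pattern, with abridged parameter lists:

```python
# Abridged versions of the lists in the hunk above.
openai_params = ["model", "messages", "temperature", "top_p", "stream", "timeout"]
litellm_params = ["metadata", "mock_response", "num_retries", "max_retries"]

def split_non_default_params(kwargs: dict) -> dict:
    # Anything not recognized as an OpenAI or litellm argument is treated
    # as model-specific and forwarded straight to the provider.
    default_params = openai_params + litellm_params
    return {k: v for k, v in kwargs.items() if k not in default_params}

print(split_non_default_params({"model": "command-nightly", "top_k": 40}))  # {'top_k': 40}
```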
```diff
@@ -486,7 +485,8 @@ def completion(
             litellm_params=litellm_params,
             logger_fn=logger_fn,
             logging_obj=logging,
-            acompletion=acompletion
+            acompletion=acompletion,
+            timeout=timeout
         )
 
         ## LOGGING
```
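This hunk and the next thread `timeout` into the per-provider handler calls; how each provider layer enforces the deadline is out of frame in this diff. One generic way to bound any awaitable on the client side — shown purely as an illustration, not necessarily litellm's mechanism — is `asyncio.wait_for`:

```python
import asyncio

async def slow_provider():
    await asyncio.sleep(5)  # stands in for a hung provider call
    return "done"

async def main():
    try:
        # Cancels the pending call and raises TimeoutError after 0.1s.
        return await asyncio.wait_for(slow_provider(), timeout=0.1)
    except asyncio.TimeoutError:
        return "client-side timeout hit"

print(asyncio.run(main()))
```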
```diff
@@ -552,7 +552,8 @@ def completion(
             logging_obj=logging,
             optional_params=optional_params,
             litellm_params=litellm_params,
-            logger_fn=logger_fn
+            logger_fn=logger_fn,
+            timeout=timeout
         )
     except Exception as e:
         ## LOGGING - log the original exception returned
```
```diff
@@ -1399,6 +1400,24 @@ def completion_with_retries(*args, **kwargs):
         retryer = tenacity.Retrying(wait=tenacity.wait_exponential(multiplier=1, max=10), stop=tenacity.stop_after_attempt(num_retries), reraise=True)
     return retryer(original_function, *args, **kwargs)
 
+async def acompletion_with_retries(*args, **kwargs):
+    """
+    Executes a litellm.completion() call with up to 3 retries.
+    """
+    try:
+        import tenacity
+    except Exception as e:
+        raise Exception(f"tenacity import failed, please run `pip install tenacity`. Error: {e}")
+
+    num_retries = kwargs.pop("num_retries", 3)
+    retry_strategy = kwargs.pop("retry_strategy", "constant_retry")
+    original_function = kwargs.pop("original_function", completion)
+    if retry_strategy == "constant_retry":
+        retryer = tenacity.Retrying(stop=tenacity.stop_after_attempt(num_retries), reraise=True)
+    elif retry_strategy == "exponential_backoff_retry":
+        retryer = tenacity.Retrying(wait=tenacity.wait_exponential(multiplier=1, max=10), stop=tenacity.stop_after_attempt(num_retries), reraise=True)
+    return await retryer(original_function, *args, **kwargs)
+
 
 def batch_completion(
```
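One thing worth flagging: `tenacity.Retrying` is the synchronous retryer, so awaiting its result only works if the wrapped function itself hands back an awaitable. tenacity ships a coroutine-aware `AsyncRetrying` for exactly this case; a sketch of what that would look like (not the committed code — function names here are illustrative):

```python
import asyncio
import tenacity

async def flaky_call():
    return "ok"  # stands in for an async completion call

async def acompletion_with_retries_sketch(async_fn, *args, num_retries: int = 3, **kwargs):
    # AsyncRetrying awaits async_fn inside the retry loop, so exceptions
    # raised by the coroutine actually trigger retries.
    retryer = tenacity.AsyncRetrying(
        wait=tenacity.wait_exponential(multiplier=1, max=10),
        stop=tenacity.stop_after_attempt(num_retries),
        reraise=True,
    )
    return await retryer(async_fn, *args, **kwargs)

print(asyncio.run(acompletion_with_retries_sketch(flaky_call)))
```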
```diff
@@ -1639,9 +1658,6 @@ async def aembedding(*args, **kwargs):
     return response
 
 @client
-@timeout(  # type: ignore
-    60
-)  ## set timeouts, in case calls hang (e.g. Azure) - default is 60s, override with `force_timeout`
 def embedding(
     model,
     input=[],
```
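Dropping the `@timeout(60)` decorator from `embedding` is consistent with the rest of the commit: deadlines move from a wrapper that races the call against a timer into values handed to the HTTP layer itself. Since the module imports httpx, a transport-level timeout of the kind this design enables looks roughly like this (values illustrative, not taken from the diff):

```python
import httpx

# A request deadline enforced by the transport rather than by a decorator:
# 600s overall budget, but fail fast (5s) if the connection can't be opened.
client = httpx.Client(timeout=httpx.Timeout(600.0, connect=5.0))
```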