diff --git a/litellm/__init__.py b/litellm/__init__.py
index 16416b821..8b456987c 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -19,6 +19,7 @@ telemetry = True
 max_tokens = 256 # OpenAI Defaults
 drop_params = False
 retry = True
+request_timeout: float = 600
 api_key: Optional[str] = None
 openai_key: Optional[str] = None
 azure_key: Optional[str] = None
@@ -46,6 +47,7 @@ _current_cost = 0 # private variable, used if max budget is set
 error_logs: Dict = {}
 add_function_to_prompt: bool = False # if function calling not supported by api, append function call details to system prompt
 client_session: Optional[httpx.Client] = None
+aclient_session: Optional[httpx.AsyncClient] = None
 model_fallbacks: Optional[List] = None
 model_cost_map_url: str = "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json"
 num_retries: Optional[int] = None
diff --git a/litellm/exceptions.py b/litellm/exceptions.py
index 1ea15f299..941d79bd2 100644
--- a/litellm/exceptions.py
+++ b/litellm/exceptions.py
@@ -105,7 +105,7 @@ class APIError(APIError):  # type: ignore
         super().__init__(
             self.message,
             request=request,  # type: ignore
-            body=None,
+            body=None
         )
 
 # raised if an invalid request (not get, delete, put, post) is made
diff --git a/litellm/llms/base.py b/litellm/llms/base.py
index 5ed89fafa..d93b5a3f6 100644
--- a/litellm/llms/base.py
+++ b/litellm/llms/base.py
@@ -9,13 +9,25 @@ class BaseLLM:
         if litellm.client_session:
             _client_session = litellm.client_session
         else:
-            _client_session = httpx.Client(timeout=600)
+            _client_session = httpx.Client(timeout=litellm.request_timeout)
 
         return _client_session
 
+    def create_aclient_session(self):
+        if litellm.aclient_session:
+            _aclient_session = litellm.aclient_session
+        else:
+            _aclient_session = httpx.AsyncClient(timeout=litellm.request_timeout)
+
+        return _aclient_session
+
     def __exit__(self):
         if hasattr(self, '_client_session'):
             self._client_session.close()
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb):
+        if hasattr(self, '_aclient_session'):
+            await self._aclient_session.aclose()
 
     def validate_environment(self):  # set up the environment required to run the model
         pass
diff --git a/litellm/llms/openai.py b/litellm/llms/openai.py
index dda5df2f5..2cfe0efeb 100644
--- a/litellm/llms/openai.py
+++ b/litellm/llms/openai.py
@@ -154,10 +154,12 @@ class OpenAITextCompletionConfig():
 
 class OpenAIChatCompletion(BaseLLM):
     _client_session: httpx.Client
+    _aclient_session: httpx.AsyncClient
 
     def __init__(self) -> None:
         super().__init__()
         self._client_session = self.create_client_session()
+        self._aclient_session = self.create_aclient_session()
 
     def validate_environment(self, api_key):
         headers = {
@@ -251,15 +253,15 @@ class OpenAIChatCompletion(BaseLLM):
                           api_base: str,
                           data: dict, headers: dict,
                           model_response: ModelResponse):
-        async with httpx.AsyncClient(timeout=600) as client:
-            response = await client.post(api_base, json=data, headers=headers)
-            response_json = response.json()
-            if response.status_code != 200:
-                raise OpenAIError(status_code=response.status_code, message=response.text, request=response.request, response=response)
-
+        client = self._aclient_session
 
-            ## RESPONSE OBJECT
-            return convert_to_model_response_object(response_object=response_json, model_response_object=model_response)
+        response = await client.post(api_base, json=data, headers=headers)
+        response_json = response.json()
+        if response.status_code != 200:
+            raise OpenAIError(status_code=response.status_code, message=response.text, request=response.request, response=response)
+
+        ## RESPONSE OBJECT
+        return convert_to_model_response_object(response_object=response_json, model_response_object=model_response)
 
     def streaming(self,
                   logging_obj,
@@ -290,8 +292,7 @@ class OpenAIChatCompletion(BaseLLM):
                           headers: dict,
                           model_response: ModelResponse,
                           model: str):
-        client = httpx.AsyncClient()
-        async with client.stream(
+        async with self._aclient_session.stream(
                     url=f"{api_base}",
                     json=data,
                     headers=headers,
diff --git a/litellm/llms/vertex_ai.py b/litellm/llms/vertex_ai.py
index 5d15cdeff..df34a2549 100644
--- a/litellm/llms/vertex_ai.py
+++ b/litellm/llms/vertex_ai.py
@@ -6,11 +6,14 @@ import time
 from typing import Callable, Optional
 from litellm.utils import ModelResponse, Usage
 import litellm
+import httpx
 
 class VertexAIError(Exception):
     def __init__(self, status_code, message):
         self.status_code = status_code
         self.message = message
+        self.request = httpx.Request(method="POST", url="https://api.ai21.com/studio/v1/")
+        self.response = httpx.Response(status_code=status_code, request=self.request)
         super().__init__(
             self.message
         )  # Call the base class constructor with the parameters it needs
diff --git a/litellm/main.py b/litellm/main.py
index 41f9bd35a..2299a4570 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -108,6 +108,7 @@ class Completions():
             response = completion(model=model, messages=messages, **self.params)
             return response
 
+@client
 async def acompletion(*args, **kwargs):
     """
     Asynchronously executes a litellm.completion() call for any of litellm supported llms (example gpt-4, gpt-3.5-turbo, claude-2, command-nightly)
@@ -149,63 +150,51 @@ async def acompletion(*args, **kwargs):
     """
     loop = asyncio.get_event_loop()
     model = args[0] if len(args) > 0 else kwargs["model"]
-    messages = args[1] if len(args) > 1 else kwargs["messages"]
-    ### INITIALIZE LOGGING OBJECT ###
-    kwargs["litellm_call_id"] = str(uuid.uuid4())
-    start_time = datetime.datetime.now()
-    logging_obj = Logging(model=model, messages=messages, stream=kwargs.get("stream", False), litellm_call_id=kwargs["litellm_call_id"], function_id=kwargs.get("id", None), call_type="completion", start_time=start_time)
-    ### PASS ARGS TO COMPLETION ###
-    kwargs["litellm_logging_obj"] = logging_obj
     kwargs["acompletion"] = True
-    kwargs["model"] = model
-    kwargs["messages"] = messages
-    # Use a partial function to pass your keyword arguments
-    func = partial(completion, *args, **kwargs)
+    try:
+        # Use a partial function to pass your keyword arguments
+        func = partial(completion, *args, **kwargs)
 
-    # Add the context to the function
-    ctx = contextvars.copy_context()
-    func_with_context = partial(ctx.run, func)
+        # Add the context to the function
+        ctx = contextvars.copy_context()
+        func_with_context = partial(ctx.run, func)
 
-    _, custom_llm_provider, _, _ = get_llm_provider(model=model, api_base=kwargs.get("api_base", None))
+        _, custom_llm_provider, _, _ = get_llm_provider(model=model, api_base=kwargs.get("api_base", None))
 
-    if (custom_llm_provider == "openai"
-        or custom_llm_provider == "azure"
-        or custom_llm_provider == "custom_openai"
-        or custom_llm_provider == "text-completion-openai"): # currently implemented aiohttp calls for just azure and openai, soon all.
-        if kwargs.get("stream", False):
-            response = completion(*args, **kwargs)
-        else:
-            # Await normally
-            init_response = completion(*args, **kwargs)
-            if isinstance(init_response, dict) or isinstance(init_response, ModelResponse):
-                response = init_response
+        if (custom_llm_provider == "openai"
+            or custom_llm_provider == "azure"
+            or custom_llm_provider == "custom_openai"
+            or custom_llm_provider == "text-completion-openai"): # currently implemented aiohttp calls for just azure and openai, soon all.
+            if kwargs.get("stream", False):
+                response = completion(*args, **kwargs)
             else:
-                response = await init_response
-    else:
-        # Call the synchronous function using run_in_executor
-        response = await loop.run_in_executor(None, func_with_context)
-    if kwargs.get("stream", False): # return an async generator
-        # do not change this
-        # for stream = True, always return an async generator
-        # See OpenAI acreate https://github.com/openai/openai-python/blob/5d50e9e3b39540af782ca24e65c290343d86e1a9/openai/api_resources/abstract/engine_api_resource.py#L193
-        # return response
-        return(
-            line
-            async for line in response
-        )
-    else:
-        end_time = datetime.datetime.now()
-        # [OPTIONAL] ADD TO CACHE
-        if litellm.caching or litellm.caching_with_models or litellm.cache != None: # user init a cache object
-            litellm.cache.add_cache(response, *args, **kwargs)
+                # Await normally
+                init_response = completion(*args, **kwargs)
+                if isinstance(init_response, dict) or isinstance(init_response, ModelResponse): ## CACHING SCENARIO
+                    response = init_response
+                else:
+                    response = await init_response
+        else:
+            # Call the synchronous function using run_in_executor
+            response = await loop.run_in_executor(None, func_with_context)
+        if kwargs.get("stream", False): # return an async generator
+            # do not change this
+            # for stream = True, always return an async generator
+            # See OpenAI acreate https://github.com/openai/openai-python/blob/5d50e9e3b39540af782ca24e65c290343d86e1a9/openai/api_resources/abstract/engine_api_resource.py#L193
+            # return response
+            return(
+                line
+                async for line in response
+            )
+        else:
+            return response
+    except Exception as e:
+        ## Map to OpenAI Exception
+        raise exception_type(
+            model=model, custom_llm_provider=custom_llm_provider, original_exception=e, completion_kwargs=args,
+        )
 
-        # LOG SUCCESS
-        logging_obj.success_handler(response, start_time, end_time)
-        # RETURN RESULT
-        response._response_ms = (end_time - start_time).total_seconds() * 1000 # return response latency in ms like openai
-
-        return response
 
 def mock_completion(model: str, messages: List, stream: Optional[bool] = False, mock_response: str = "This is a mock request", **kwargs):
     """
@@ -1420,8 +1409,14 @@ def completion_with_retries(*args, **kwargs):
         raise Exception(f"tenacity import failed please run `pip install tenacity`. Error{e}")
 
     num_retries = kwargs.pop("num_retries", 3)
-    retryer = tenacity.Retrying(stop=tenacity.stop_after_attempt(num_retries), reraise=True)
-    return retryer(completion, *args, **kwargs)
+    retry_strategy = kwargs.pop("retry_strategy", "constant_retry")
+    original_function = kwargs.pop("original_function", completion)
+    if retry_strategy == "constant_retry":
+        retryer = tenacity.Retrying(stop=tenacity.stop_after_attempt(num_retries), reraise=True)
+    elif retry_strategy == "exponential_backoff_retry":
+        retryer = tenacity.Retrying(wait=tenacity.wait_exponential(multiplier=1, max=10), stop=tenacity.stop_after_attempt(num_retries), reraise=True)
+    return retryer(original_function, *args, **kwargs)
+
 
 def batch_completion(
diff --git a/litellm/tests/test_loadtest_router.py b/litellm/tests/test_loadtest_router.py
new file mode 100644
index 000000000..0d983172b
--- /dev/null
+++ b/litellm/tests/test_loadtest_router.py
@@ -0,0 +1,47 @@
+import sys, os
+import traceback
+from dotenv import load_dotenv
+import copy
+
+load_dotenv()
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the system path
+import asyncio
+from litellm import Router
+
+async def call_acompletion(semaphore, router: Router, input_data):
+    async with semaphore:
+        # Replace 'input_data' with appropriate parameters for acompletion
+        response = await router.acompletion(**input_data)
+        # Handle the response as needed
+        return response
+
+async def main():
+    # Initialize the Router
+    model_list= [{
+        "model_name": "gpt-3.5-turbo",
+        "litellm_params": {
+            "model": "gpt-3.5-turbo",
+            "api_key": os.getenv("OPENAI_API_KEY"),
+        },
+    }]
+    router = Router(model_list=model_list, num_retries=3)
+
+    # Create a semaphore with a capacity of 100
+    semaphore = asyncio.Semaphore(100)
+
+    # List to hold all task references
+    tasks = []
+
+    # Launch 1000 tasks
+    for _ in range(1000):
+        task = asyncio.create_task(call_acompletion(semaphore, router, {"model": "gpt-3.5-turbo", "messages": [{"role":"user", "content": "Hey, how's it going?"}]}))
+        tasks.append(task)
+
+    # Wait for all tasks to complete
+    responses = await asyncio.gather(*tasks)
+    # Process responses as needed
+
+# Run the main function
+asyncio.run(main())
diff --git a/litellm/tests/test_router.py b/litellm/tests/test_router.py
index 8b0d06d27..4a90d9399 100644
--- a/litellm/tests/test_router.py
+++ b/litellm/tests/test_router.py
@@ -317,7 +317,7 @@ def test_function_calling_on_router():
     except Exception as e:
         print(f"An exception occurred: {e}")
 
-test_function_calling_on_router()
+# test_function_calling_on_router()
 
 def test_aembedding_on_router():
     try:
diff --git a/litellm/utils.py b/litellm/utils.py
index b0f6ee49c..47741317f 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -456,6 +456,7 @@ from enum import Enum
 class CallTypes(Enum):
     embedding = 'embedding'
     completion = 'completion'
+    acompletion = 'acompletion'
 
 # Logging function -> log the exact model details + what's being sent | Non-Blocking
 class Logging:
@@ -984,7 +985,7 @@ def exception_logging(
 # make it easy to log if completion/embedding runs succeeded or failed + see what happened | Non-Blocking
 def client(original_function):
     global liteDebuggerClient, get_all_keys
-
+    import inspect
     def function_setup(
         start_time, *args, **kwargs
     ):  # just run once to check if user wants to send their data anywhere - PostHog/Sentry/Slack/etc.
@@ -1036,7 +1037,7 @@ def client(original_function):
             # INIT LOGGER - for user-specified integrations
             model = args[0] if len(args) > 0 else kwargs["model"]
             call_type = original_function.__name__
-            if call_type == CallTypes.completion.value:
+            if call_type == CallTypes.completion.value or call_type == CallTypes.acompletion.value:
                 if len(args) > 1:
                     messages = args[1]
                 elif kwargs.get("messages", None):
@@ -1183,7 +1184,107 @@ def client(original_function):
                 ):  # make it easy to get to the debugger logs if you've initialized it
                     e.message += f"\n Check the log in your dashboard - {liteDebuggerClient.dashboard_url}"
                 raise e
-    return wrapper
+
+    async def wrapper_async(*args, **kwargs):
+        start_time = datetime.datetime.now()
+        result = None
+        logging_obj = kwargs.get("litellm_logging_obj", None)
+
+        # only set litellm_call_id if its not in kwargs
+        if "litellm_call_id" not in kwargs:
+            kwargs["litellm_call_id"] = str(uuid.uuid4())
+        try:
+            model = args[0] if len(args) > 0 else kwargs["model"]
+        except:
+            raise ValueError("model param not passed in.")
+
+        try:
+            if logging_obj is None:
+                logging_obj = function_setup(start_time, *args, **kwargs)
+            kwargs["litellm_logging_obj"] = logging_obj
+
+            # [OPTIONAL] CHECK BUDGET
+            if litellm.max_budget:
+                if litellm._current_cost > litellm.max_budget:
+                    raise BudgetExceededError(current_cost=litellm._current_cost, max_budget=litellm.max_budget)
+
+            # [OPTIONAL] CHECK CACHE
+            print_verbose(f"litellm.cache: {litellm.cache}")
+            print_verbose(f"kwargs[caching]: {kwargs.get('caching', False)}; litellm.cache: {litellm.cache}")
+            # if caching is false, don't run this
+            if (kwargs.get("caching", None) is None and litellm.cache is not None) or kwargs.get("caching", False) == True: # allow users to control returning cached responses from the completion function
+                # checking cache
+                if (litellm.cache != None):
+                    print_verbose(f"Checking Cache")
+                    cached_result = litellm.cache.get_cache(*args, **kwargs)
+                    if cached_result != None:
+                        print_verbose(f"Cache Hit!")
+                        call_type = original_function.__name__
+                        if call_type == CallTypes.acompletion.value and isinstance(cached_result, dict):
+                            return convert_to_model_response_object(response_object=cached_result, model_response_object=ModelResponse())
+                        else:
+                            return cached_result
+            # MODEL CALL
+            result = original_function(*args, **kwargs)
+            end_time = datetime.datetime.now()
+            if "stream" in kwargs and kwargs["stream"] == True:
+                if "complete_response" in kwargs and kwargs["complete_response"] == True:
+                    chunks = []
+                    for idx, chunk in enumerate(result):
+                        chunks.append(chunk)
+                    return litellm.stream_chunk_builder(chunks)
+                else:
+                    return result
+            result = await result
+            # [OPTIONAL] ADD TO CACHE
+            if litellm.caching or litellm.caching_with_models or litellm.cache != None: # user init a cache object
+                litellm.cache.add_cache(result, *args, **kwargs)
+
+            # LOG SUCCESS - handle streaming success logging in the _next_ object, remove `handle_success` once it's deprecated
+            logging_obj.success_handler(result, start_time, end_time)
+            # RETURN RESULT
+            return result
+        except Exception as e:
+            call_type = original_function.__name__
+            if call_type == CallTypes.acompletion.value:
+                num_retries = (
+                    kwargs.get("num_retries", None)
+                    or litellm.num_retries
+                    or None
+                )
+                litellm.num_retries = None # set retries to None to prevent infinite loops
+                context_window_fallback_dict = kwargs.get("context_window_fallback_dict", {})
+
+                if num_retries:
+                    kwargs["num_retries"] = num_retries
+                    kwargs["original_function"] = original_function
+                    if (isinstance(e, openai.RateLimitError)): # rate limiting specific error
+                        kwargs["retry_strategy"] = "exponential_backoff_retry"
+                    elif (isinstance(e, openai.APIError)): # generic api error
+                        kwargs["retry_strategy"] = "constant_retry"
+                    return litellm.completion_with_retries(*args, **kwargs)
+                elif isinstance(e, litellm.exceptions.ContextWindowExceededError) and context_window_fallback_dict and model in context_window_fallback_dict:
+                    if len(args) > 0:
+                        args[0] = context_window_fallback_dict[model]
+                    else:
+                        kwargs["model"] = context_window_fallback_dict[model]
+                    return original_function(*args, **kwargs)
+            traceback_exception = traceback.format_exc()
+            crash_reporting(*args, **kwargs, exception=traceback_exception)
+            end_time = datetime.datetime.now()
+            # LOG FAILURE - handle streaming failure logging in the _next_ object, remove `handle_failure` once it's deprecated
+            if logging_obj:
+                threading.Thread(target=logging_obj.failure_handler, args=(e, traceback_exception, start_time, end_time)).start()
+            raise e
+
+    # Use httpx to determine if the original function is a coroutine
+    is_coroutine = inspect.iscoroutinefunction(original_function)
+
+    # Return the appropriate wrapper based on the original function type
+    if is_coroutine:
+        return wrapper_async
+    else:
+        return wrapper
 
 
 ####### USAGE CALCULATOR ################
@@ -3116,31 +3217,13 @@ def exception_type(
             print("LiteLLM.Info: If you need to debug this error, use `litellm.set_verbose=True'.")  # noqa
             print()  # noqa
     try:
-        if isinstance(original_exception, OriginalError):
-            # Handle the OpenAIError
-            exception_mapping_worked = True
-            if custom_llm_provider == "openrouter":
-                if original_exception.http_status == 413:
-                    raise BadRequestError(
-                        message=str(original_exception),
-                        model=model,
-                        llm_provider="openrouter"
-                    )
-                original_exception.llm_provider = "openrouter"
-            if "This model's maximum context length is" in original_exception._message:
-                raise ContextWindowExceededError(
-                    message=str(original_exception),
-                    model=model,
-                    llm_provider=original_exception.llm_provider
-                )
-            raise original_exception
-        elif model:
+        if model:
             error_str = str(original_exception)
             if isinstance(original_exception, BaseException):
                 exception_type = type(original_exception).__name__
            else:
                 exception_type = ""
-            if custom_llm_provider == "openai" or custom_llm_provider == "text-completion-openai":
+            if custom_llm_provider == "openai" or custom_llm_provider == "text-completion-openai" or custom_llm_provider == "custom_openai":
                 if "This model's maximum context length is" in error_str or "Request too large" in error_str:
                     exception_mapping_worked = True
                     raise ContextWindowExceededError(
@@ -3191,6 +3274,14 @@ def exception_type(
                         llm_provider="openai",
                         response=original_exception.response
                     )
+                elif original_exception.status_code == 503:
+                    exception_mapping_worked = True
+                    raise ServiceUnavailableError(
+                        message=f"OpenAIException - {original_exception.message}",
+                        model=model,
+                        llm_provider="openai",
+                        response=original_exception.response
+                    )
                 elif original_exception.status_code == 504: # gateway timeout error
                     exception_mapping_worked = True
                     raise Timeout(
@@ -3968,49 +4059,6 @@ def exception_type(
                         model=model,
                         request=original_exception.request
                     )
-        elif custom_llm_provider == "custom_openai" or custom_llm_provider == "maritalk":
-            if hasattr(original_exception, "status_code"):
-                exception_mapping_worked = True
-                if original_exception.status_code == 401:
-                    exception_mapping_worked = True
-                    raise AuthenticationError(
-                        message=f"CustomOpenAIException - {original_exception.message}",
-                        llm_provider="custom_openai",
-                        model=model
-                    )
-                elif original_exception.status_code == 408:
-                    exception_mapping_worked = True
-                    raise Timeout(
-                        message=f"CustomOpenAIException - {original_exception.message}",
-                        model=model,
-                        llm_provider="custom_openai",
-                        request=original_exception.request
-                    )
-                if original_exception.status_code == 422:
-                    exception_mapping_worked = True
-                    raise BadRequestError(
-                        message=f"CustomOpenAIException - {original_exception.message}",
-                        model=model,
-                        llm_provider="custom_openai",
-                        response=original_exception.response
-                    )
-                elif original_exception.status_code == 429:
-                    exception_mapping_worked = True
-                    raise RateLimitError(
-                        message=f"CustomOpenAIException - {original_exception.message}",
-                        model=model,
-                        llm_provider="custom_openai",
-                        response=original_exception.response
-                    )
-                else:
-                    exception_mapping_worked = True
-                    raise APIError(
-                        status_code=original_exception.status_code,
-                        message=f"CustomOpenAIException - {original_exception.message}",
-                        llm_provider="custom_openai",
-                        model=model,
-                        request=original_exception.request
-                    )
         if "BadRequestError.__init__() missing 1 required positional argument: 'param'" in str(original_exception): # deal with edge-case invalid request error bug in openai-python sdk
             exception_mapping_worked = True
             raise BadRequestError(