diff --git a/litellm/main.py b/litellm/main.py
index a76ef64a13..307659c8a2 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -428,7 +428,7 @@ def mock_completion(
     model: str,
     messages: List,
     stream: Optional[bool] = False,
-    mock_response: Union[str, Exception] = "This is a mock request",
+    mock_response: Union[str, Exception, dict] = "This is a mock request",
     mock_tool_calls: Optional[List] = None,
     logging=None,
     custom_llm_provider=None,
@@ -477,6 +477,9 @@ def mock_completion(
         if time_delay is not None:
             time.sleep(time_delay)
 
+        if isinstance(mock_response, dict):
+            return ModelResponse(**mock_response)
+
         model_response = ModelResponse(stream=stream)
         if stream is True:
             # don't try to access stream object,
diff --git a/litellm/proxy/hooks/dynamic_rate_limiter.py b/litellm/proxy/hooks/dynamic_rate_limiter.py
index dc29597f6e..dfe9632155 100644
--- a/litellm/proxy/hooks/dynamic_rate_limiter.py
+++ b/litellm/proxy/hooks/dynamic_rate_limiter.py
@@ -80,25 +80,35 @@ class _PROXY_DynamicRateLimitHandler(CustomLogger):
 
         Returns
         - Tuple[available_tpm, model_tpm, active_projects]
          - available_tpm: int or null
-          - model_tpm: int or null. If available tpm is int, then this will be too.
+          - remaining_model_tpm: int or null. If available tpm is int, then this will be too.
          - active_projects: int or null
         """
         active_projects = await self.internal_usage_cache.async_get_cache(model=model)
+        current_model_tpm: Optional[int] = await self.llm_router.get_model_group_usage(
+            model_group=model
+        )
         model_group_info: Optional[ModelGroupInfo] = (
             self.llm_router.get_model_group_info(model_group=model)
         )
+        total_model_tpm: Optional[int] = None
+        if model_group_info is not None and model_group_info.tpm is not None:
+            total_model_tpm = model_group_info.tpm
+
+        remaining_model_tpm: Optional[int] = None
+        if total_model_tpm is not None and current_model_tpm is not None:
+            remaining_model_tpm = total_model_tpm - current_model_tpm
+        elif total_model_tpm is not None:
+            remaining_model_tpm = total_model_tpm
 
         available_tpm: Optional[int] = None
-        model_tpm: Optional[int] = None
-        if model_group_info is not None and model_group_info.tpm is not None:
-            model_tpm = model_group_info.tpm
+        if remaining_model_tpm is not None:
             if active_projects is not None:
-                available_tpm = int(model_group_info.tpm / active_projects)
+                available_tpm = int(remaining_model_tpm / active_projects)
             else:
-                available_tpm = model_group_info.tpm
+                available_tpm = remaining_model_tpm
 
-        return available_tpm, model_tpm, active_projects
+        return available_tpm, remaining_model_tpm, active_projects
 
     async def async_pre_call_hook(
         self,
@@ -121,6 +131,7 @@ class _PROXY_DynamicRateLimitHandler(CustomLogger):
         - Check if tpm available
         - Raise RateLimitError if no tpm available
         """
+
         if "model" in data:
             available_tpm, model_tpm, active_projects = await self.check_available_tpm(
                 model=data["model"]
diff --git a/litellm/router.py b/litellm/router.py
index 87890ebffc..9f0884ab80 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -11,6 +11,7 @@ import asyncio
 import concurrent
 import copy
 import datetime as datetime_og
+import enum
 import hashlib
 import inspect
 import json
@@ -90,6 +91,10 @@ from litellm.utils import (
 )
 
 
+class RoutingArgs(enum.Enum):
+    ttl = 60  # 1min (RPM/TPM expire key)
+
+
 class Router:
     model_names: List = []
     cache_responses: Optional[bool] = False
@@ -387,6 +392,11 @@ class Router:
             routing_strategy=routing_strategy,
             routing_strategy_args=routing_strategy_args,
         )
+        ## USAGE TRACKING ##
+        if isinstance(litellm._async_success_callback, list):
+            litellm._async_success_callback.append(self.deployment_callback_on_success)
+        else:
+            litellm._async_success_callback.append(self.deployment_callback_on_success)
         ## COOLDOWNS ##
         if isinstance(litellm.failure_callback, list):
             litellm.failure_callback.append(self.deployment_callback_on_failure)
@@ -2636,13 +2646,69 @@ class Router:
                        time.sleep(_timeout)
 
             if type(original_exception) in litellm.LITELLM_EXCEPTION_TYPES:
-                original_exception.max_retries = num_retries
-                original_exception.num_retries = current_attempt
+                setattr(original_exception, "max_retries", num_retries)
+                setattr(original_exception, "num_retries", current_attempt)
 
             raise original_exception
 
     ### HELPER FUNCTIONS
 
+    async def deployment_callback_on_success(
+        self,
+        kwargs,  # kwargs to completion
+        completion_response,  # response from completion
+        start_time,
+        end_time,  # start/end time
+    ):
+        """
+        Track remaining tpm/rpm quota for model in model_list
+        """
+        try:
+            """
+            Update TPM usage on success
+            """
+            if kwargs["litellm_params"].get("metadata") is None:
+                pass
+            else:
+                model_group = kwargs["litellm_params"]["metadata"].get(
+                    "model_group", None
+                )
+
+                id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
+                if model_group is None or id is None:
+                    return
+                elif isinstance(id, int):
+                    id = str(id)
+
+                total_tokens = completion_response["usage"]["total_tokens"]
+
+                # ------------
+                # Setup values
+                # ------------
+                dt = get_utc_datetime()
+                current_minute = dt.strftime(
+                    "%H-%M"
+                )  # use the same timezone regardless of system clock
+
+                tpm_key = f"global_router:{id}:tpm:{current_minute}"
+                # ------------
+                # Update usage
+                # ------------
+                # update cache
+
+                ## TPM
+                await self.cache.async_increment_cache(
+                    key=tpm_key, value=total_tokens, ttl=RoutingArgs.ttl.value
+                )
+
+        except Exception as e:
+            verbose_router_logger.error(
+                "litellm.proxy.hooks.prompt_injection_detection.py::async_pre_call_hook(): Exception occured - {}\n{}".format(
+                    str(e), traceback.format_exc()
+                )
+            )
+            pass
+
     def deployment_callback_on_failure(
         self,
         kwargs,  # kwargs to completion
@@ -3963,6 +4029,35 @@ class Router:
 
         return model_group_info
 
+    async def get_model_group_usage(self, model_group: str) -> Optional[int]:
+        """
+        Returns remaining tpm quota for model group
+        """
+        dt = get_utc_datetime()
+        current_minute = dt.strftime(
+            "%H-%M"
+        )  # use the same timezone regardless of system clock
+        tpm_keys: List[str] = []
+        for model in self.model_list:
+            if "model_name" in model and model["model_name"] == model_group:
+                tpm_keys.append(
+                    f"global_router:{model['model_info']['id']}:tpm:{current_minute}"
+                )
+
+        ## TPM
+        tpm_usage_list: Optional[List] = await self.cache.async_batch_get_cache(
+            keys=tpm_keys
+        )
+        tpm_usage: Optional[int] = None
+        if tpm_usage_list is not None:
+            for t in tpm_usage_list:
+                if isinstance(t, int):
+                    if tpm_usage is None:
+                        tpm_usage = 0
+                    tpm_usage += t
+
+        return tpm_usage
+
     def get_model_ids(self) -> List[str]:
         """
         Returns list of model id's.
@@ -4890,7 +4985,7 @@ class Router:
     def reset(self):
         ## clean up on close
         litellm.success_callback = []
-        litellm.__async_success_callback = []
+        litellm._async_success_callback = []
         litellm.failure_callback = []
         litellm._async_failure_callback = []
         self.retry_policy = None
diff --git a/litellm/tests/test_dynamic_rate_limit_handler.py b/litellm/tests/test_dynamic_rate_limit_handler.py
index 71e3ac5359..df9e258810 100644
--- a/litellm/tests/test_dynamic_rate_limit_handler.py
+++ b/litellm/tests/test_dynamic_rate_limit_handler.py
@@ -8,7 +8,7 @@ import time
 import traceback
 import uuid
 from datetime import datetime
-from typing import Tuple
+from typing import Optional, Tuple
 
 from dotenv import load_dotenv
 
@@ -40,6 +40,40 @@ def dynamic_rate_limit_handler() -> DynamicRateLimitHandler:
     return DynamicRateLimitHandler(internal_usage_cache=internal_cache)
 
 
+@pytest.fixture
+def mock_response() -> litellm.ModelResponse:
+    return litellm.ModelResponse(
+        **{
+            "id": "chatcmpl-abc123",
+            "object": "chat.completion",
+            "created": 1699896916,
+            "model": "gpt-3.5-turbo-0125",
+            "choices": [
+                {
+                    "index": 0,
+                    "message": {
+                        "role": "assistant",
+                        "content": None,
+                        "tool_calls": [
+                            {
+                                "id": "call_abc123",
+                                "type": "function",
+                                "function": {
+                                    "name": "get_current_weather",
+                                    "arguments": '{\n"location": "Boston, MA"\n}',
+                                },
+                            }
+                        ],
+                    },
+                    "logprobs": None,
+                    "finish_reason": "tool_calls",
+                }
+            ],
+            "usage": {"prompt_tokens": 5, "completion_tokens": 5, "total_tokens": 10},
+        }
+    )
+
+
 @pytest.mark.parametrize("num_projects", [1, 2, 100])
 @pytest.mark.asyncio
 async def test_available_tpm(num_projects, dynamic_rate_limit_handler):
@@ -76,3 +110,65 @@ async def test_available_tpm(num_projects, dynamic_rate_limit_handler):
     expected_availability = int(model_tpm / num_projects)
 
     assert availability == expected_availability
+
+
+@pytest.mark.asyncio
+async def test_base_case(dynamic_rate_limit_handler, mock_response):
+    """
+    If just 1 active project
+
+    it should get all the quota
+
+    = allow request to go through
+        - update token usage
+        - exhaust all tpm with just 1 project
+    """
+    model = "my-fake-model"
+    model_tpm = 50
+    setattr(
+        mock_response,
+        "usage",
+        litellm.Usage(prompt_tokens=5, completion_tokens=5, total_tokens=10),
+    )
+
+    llm_router = Router(
+        model_list=[
+            {
+                "model_name": model,
+                "litellm_params": {
+                    "model": "gpt-3.5-turbo",
+                    "api_key": "my-key",
+                    "api_base": "my-base",
+                    "tpm": model_tpm,
+                    "mock_response": mock_response,
+                },
+            }
+        ]
+    )
+    dynamic_rate_limit_handler.update_variables(llm_router=llm_router)
+
+    prev_availability: Optional[int] = None
+    for _ in range(5):
+        # check availability
+        availability, _, _ = await dynamic_rate_limit_handler.check_available_tpm(
+            model=model
+        )
+
+        ## assert availability updated
+        if prev_availability is not None and availability is not None:
+            assert availability == prev_availability - 10
+
+        print(
+            "prev_availability={}, availability={}".format(
+                prev_availability, availability
+            )
+        )
+
+        prev_availability = availability
+
+        # make call
+        await llm_router.acompletion(
+            model=model, messages=[{"role": "user", "content": "hey!"}]
+        )
+
+        await asyncio.sleep(3)
diff --git a/litellm/tests/test_router.py b/litellm/tests/test_router.py
index d2037dc59e..55cf250e74 100644
--- a/litellm/tests/test_router.py
+++ b/litellm/tests/test_router.py
@@ -1730,3 +1730,96 @@ async def test_router_text_completion_client():
         print(responses)
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
+
+
+@pytest.fixture
+def mock_response() -> litellm.ModelResponse:
+    return litellm.ModelResponse(
+        **{
+            "id": "chatcmpl-abc123",
+            "object": "chat.completion",
+            "created": 1699896916,
+            "model": "gpt-3.5-turbo-0125",
+            "choices": [
+                {
+                    "index": 0,
+                    "message": {
+                        "role": "assistant",
+                        "content": None,
+                        "tool_calls": [
+                            {
+                                "id": "call_abc123",
+                                "type": "function",
+                                "function": {
+                                    "name": "get_current_weather",
+                                    "arguments": '{\n"location": "Boston, MA"\n}',
+                                },
+                            }
+                        ],
+                    },
+                    "logprobs": None,
+                    "finish_reason": "tool_calls",
+                }
+            ],
+            "usage": {"prompt_tokens": 5, "completion_tokens": 5, "total_tokens": 10},
+        }
+    )
+
+
+@pytest.mark.asyncio
+async def test_router_model_usage(mock_response):
+    model = "my-fake-model"
+    model_tpm = 100
+    setattr(
+        mock_response,
+        "usage",
+        litellm.Usage(prompt_tokens=5, completion_tokens=5, total_tokens=10),
+    )
+
+    print(f"mock_response: {mock_response}")
+    model_tpm = 100
+    llm_router = Router(
+        model_list=[
+            {
+                "model_name": model,
+                "litellm_params": {
+                    "model": "gpt-3.5-turbo",
+                    "api_key": "my-key",
+                    "api_base": "my-base",
+                    "tpm": model_tpm,
+                    "mock_response": mock_response,
+                },
+            }
+        ]
+    )
+
+    allowed_fails = 1  # allow for changing b/w minutes
+
+    for _ in range(2):
+        try:
+            _ = await llm_router.acompletion(
+                model=model, messages=[{"role": "user", "content": "Hey!"}]
+            )
+            await asyncio.sleep(3)
+
+            initial_usage = await llm_router.get_model_group_usage(model_group=model)
+
+            # completion call - 10 tokens
+            _ = await llm_router.acompletion(
+                model=model, messages=[{"role": "user", "content": "Hey!"}]
+            )
+
+            await asyncio.sleep(3)
+            updated_usage = await llm_router.get_model_group_usage(model_group=model)
+
+            assert updated_usage == initial_usage + 10  # type: ignore
+            break
+        except Exception as e:
+            if allowed_fails > 0:
+                print(
+                    f"Decrementing allowed_fails: {allowed_fails}.\nReceived error - {str(e)}"
+                )
+                allowed_fails -= 1
+            else:
+                print(f"allowed_fails: {allowed_fails}")
+                raise e
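
Note (not part of the patch): a minimal sketch of the quota arithmetic the patch introduces, with an illustrative helper name that does not exist in litellm. check_available_tpm() now subtracts the usage tracked under the per-minute "global_router:{id}:tpm:{HH-MM}" cache keys from the model group's configured TPM, then splits the remainder evenly across active projects.

from typing import Optional


def available_tpm_per_project(
    total_model_tpm: Optional[int],
    used_tpm_this_minute: Optional[int],
    active_projects: Optional[int],
) -> Optional[int]:
    # Illustrative-only helper mirroring check_available_tpm():
    # remaining quota for the current minute, divided evenly across
    # the projects currently using the model group.
    if total_model_tpm is None:
        return None
    remaining = total_model_tpm - (used_tpm_this_minute or 0)
    if active_projects is None:
        return remaining
    return int(remaining / active_projects)


# With the values from test_base_case: tpm=50, 10 tokens already used, 1 project -> 40
assert available_tpm_per_project(50, 10, 1) == 40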