diff --git a/docs/my-website/docs/proxy/team_budgets.md b/docs/my-website/docs/proxy/team_budgets.md
index 9bfcb35d4c..9ab0c07866 100644
--- a/docs/my-website/docs/proxy/team_budgets.md
+++ b/docs/my-website/docs/proxy/team_budgets.md
@@ -152,3 +152,104 @@ litellm_remaining_team_budget_metric{team_alias="QA Prod Bot",team_id="de35b29e-
 ```
+### Dynamic TPM Allocation
+
+Prevent any single project from consuming more than its share of quota.
+
+Dynamically allocate TPM quota across API keys, based on the number of keys active in that minute.
+
+1. Setup config.yaml
+
+```yaml
+model_list:
+  - model_name: my-fake-model
+    litellm_params:
+      model: gpt-3.5-turbo
+      api_key: my-fake-key
+      mock_response: hello-world
+      tpm: 60
+
+litellm_settings:
+  callbacks: ["dynamic_rate_limiter"]
+
+general_settings:
+  master_key: sk-1234 # OR set `LITELLM_MASTER_KEY=".."` in your .env
+  database_url: postgres://.. # OR set `DATABASE_URL=".."` in your .env
+```
+
+2. Start proxy
+
+```bash
+litellm --config /path/to/config.yaml
+```
+
+3. Test it!
+
+```python
+"""
+- Run 2 concurrent keys calling the same model
+- model has 60 TPM
+- Mock response returns 30 total tokens / request
+- Each key will only be able to make 1 request per minute
+"""
+import requests
+from openai import OpenAI, RateLimitError
+
+def create_key(api_key: str, base_url: str):
+    response = requests.post(
+        url="{}/key/generate".format(base_url),
+        json={},
+        headers={
+            "Authorization": "Bearer {}".format(api_key)
+        }
+    )
+
+    _response = response.json()
+
+    return _response["key"]
+
+key_1 = create_key(api_key="sk-1234", base_url="http://0.0.0.0:4000")
+key_2 = create_key(api_key="sk-1234", base_url="http://0.0.0.0:4000")
+
+# call proxy with key 1 - works
+openai_client_1 = OpenAI(api_key=key_1, base_url="http://0.0.0.0:4000")
+
+response = openai_client_1.chat.completions.with_raw_response.create(
+    model="my-fake-model", messages=[{"role": "user", "content": "Hello world!"}],
+)
+
+print("Headers for call 1 - {}".format(response.headers))
+_response = response.parse()
+print("Total tokens for call - {}".format(_response.usage.total_tokens))
+
+
+# call proxy with key 2 - works
+openai_client_2 = OpenAI(api_key=key_2, base_url="http://0.0.0.0:4000")
+
+response = openai_client_2.chat.completions.with_raw_response.create(
+    model="my-fake-model", messages=[{"role": "user", "content": "Hello world!"}],
+)
+
+print("Headers for call 2 - {}".format(response.headers))
+_response = response.parse()
+print("Total tokens for call - {}".format(_response.usage.total_tokens))
+
+# call proxy with key 2 again - fails, this key's share of the TPM is exhausted
+try:
+    openai_client_2.chat.completions.with_raw_response.create(model="my-fake-model", messages=[{"role": "user", "content": "Hey, how's it going?"}])
+    raise Exception("This should have failed!")
+except RateLimitError as e:
+    print("This was rate limited b/c - {}".format(str(e)))
+
+```
+
+**Expected Response**
+
+```
+This was rate limited b/c - Error code: 429 - {'error': {'message': {'error': 'Key= over available TPM=0. 
Model TPM=0, Active keys=2'}, 'type': 'None', 'param': 'None', 'code': 429}} +``` \ No newline at end of file diff --git a/litellm/__init__.py b/litellm/__init__.py index a191d46bfd..f07ce88092 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -37,7 +37,9 @@ input_callback: List[Union[str, Callable]] = [] success_callback: List[Union[str, Callable]] = [] failure_callback: List[Union[str, Callable]] = [] service_callback: List[Union[str, Callable]] = [] -_custom_logger_compatible_callbacks_literal = Literal["lago", "openmeter", "logfire"] +_custom_logger_compatible_callbacks_literal = Literal[ + "lago", "openmeter", "logfire", "dynamic_rate_limiter" +] callbacks: List[Union[Callable, _custom_logger_compatible_callbacks_literal]] = [] _langfuse_default_tags: Optional[ List[ @@ -735,6 +737,7 @@ from .utils import ( client, exception_type, get_optional_params, + get_response_string, modify_integration, token_counter, create_pretrained_tokenizer, diff --git a/litellm/_service_logger.py b/litellm/_service_logger.py index f2edb403e9..0aa0312b80 100644 --- a/litellm/_service_logger.py +++ b/litellm/_service_logger.py @@ -1,11 +1,12 @@ -from datetime import datetime +from datetime import datetime, timedelta +from typing import TYPE_CHECKING, Any, Optional, Union + import litellm from litellm.proxy._types import UserAPIKeyAuth -from .types.services import ServiceTypes, ServiceLoggerPayload -from .integrations.prometheus_services import PrometheusServicesLogger + from .integrations.custom_logger import CustomLogger -from datetime import timedelta -from typing import Union, Optional, TYPE_CHECKING, Any +from .integrations.prometheus_services import PrometheusServicesLogger +from .types.services import ServiceLoggerPayload, ServiceTypes if TYPE_CHECKING: from opentelemetry.trace import Span as _Span @@ -53,8 +54,8 @@ class ServiceLogging(CustomLogger): call_type: str, duration: float, parent_otel_span: Optional[Span] = None, - start_time: Optional[datetime] = None, - end_time: Optional[datetime] = None, + start_time: Optional[Union[datetime, float]] = None, + end_time: Optional[Union[datetime, float]] = None, ): """ - For counting if the redis, postgres call is successful @@ -92,8 +93,8 @@ class ServiceLogging(CustomLogger): error: Union[str, Exception], call_type: str, parent_otel_span: Optional[Span] = None, - start_time: Optional[datetime] = None, - end_time: Optional[datetime] = None, + start_time: Optional[Union[datetime, float]] = None, + end_time: Optional[Union[float, datetime]] = None, ): """ - For counting if the redis, postgres call is unsuccessful diff --git a/litellm/caching.py b/litellm/caching.py index 6ac439e0f3..4fe9ace07f 100644 --- a/litellm/caching.py +++ b/litellm/caching.py @@ -14,6 +14,7 @@ import json import logging import time import traceback +from datetime import timedelta from typing import Any, BinaryIO, List, Literal, Optional, Union from openai._models import BaseModel as OpenAIObject @@ -92,9 +93,22 @@ class InMemoryCache(BaseCache): else: self.set_cache(key=cache_key, value=cache_value) + if time.time() - self.last_cleaned > self.default_ttl: asyncio.create_task(self.clean_up_in_memory_cache()) + async def async_set_cache_sadd(self, key, value: List, ttl: Optional[float]): + """ + Add value to set + """ + # get the value + init_value = self.get_cache(key=key) or set() + for val in value: + init_value.add(val) + self.set_cache(key, init_value, ttl=ttl) + return value + + def get_cache(self, key, **kwargs): if key in self.cache_dict: if key in self.ttl_dict: 
@@ -363,6 +377,7 @@ class RedisCache(BaseCache): start_time=start_time, end_time=end_time, parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs), + call_type="async_set_cache", ) ) # NON blocking - notify users Redis is throwing an exception @@ -482,6 +497,80 @@ class RedisCache(BaseCache): cache_value, ) + async def async_set_cache_sadd( + self, key, value: List, ttl: Optional[float], **kwargs + ): + start_time = time.time() + try: + _redis_client = self.init_async_client() + except Exception as e: + end_time = time.time() + _duration = end_time - start_time + asyncio.create_task( + self.service_logger_obj.async_service_failure_hook( + service=ServiceTypes.REDIS, + duration=_duration, + error=e, + start_time=start_time, + end_time=end_time, + parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs), + call_type="async_set_cache_sadd", + ) + ) + # NON blocking - notify users Redis is throwing an exception + verbose_logger.error( + "LiteLLM Redis Caching: async set() - Got exception from REDIS %s, Writing value=%s", + str(e), + value, + ) + raise e + + key = self.check_and_fix_namespace(key=key) + async with _redis_client as redis_client: + print_verbose( + f"Set ASYNC Redis Cache: key: {key}\nValue {value}\nttl={ttl}" + ) + try: + await redis_client.sadd(key, *value) + if ttl is not None: + _td = timedelta(seconds=ttl) + await redis_client.expire(key, _td) + print_verbose( + f"Successfully Set ASYNC Redis Cache SADD: key: {key}\nValue {value}\nttl={ttl}" + ) + end_time = time.time() + _duration = end_time - start_time + asyncio.create_task( + self.service_logger_obj.async_service_success_hook( + service=ServiceTypes.REDIS, + duration=_duration, + call_type="async_set_cache_sadd", + start_time=start_time, + end_time=end_time, + parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs), + ) + ) + except Exception as e: + end_time = time.time() + _duration = end_time - start_time + asyncio.create_task( + self.service_logger_obj.async_service_failure_hook( + service=ServiceTypes.REDIS, + duration=_duration, + error=e, + call_type="async_set_cache_sadd", + start_time=start_time, + end_time=end_time, + parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs), + ) + ) + # NON blocking - notify users Redis is throwing an exception + verbose_logger.error( + "LiteLLM Redis Caching: async set_cache_sadd() - Got exception from REDIS %s, Writing value=%s", + str(e), + value, + ) + async def batch_cache_write(self, key, value, **kwargs): print_verbose( f"in batch cache writing for redis buffer size={len(self.redis_batch_writing_buffer)}", @@ -1506,7 +1595,7 @@ class DualCache(BaseCache): key, value, **kwargs ) - if self.redis_cache is not None and local_only == False: + if self.redis_cache is not None and local_only is False: result = await self.redis_cache.async_increment(key, value, **kwargs) return result @@ -1515,6 +1604,38 @@ class DualCache(BaseCache): verbose_logger.debug(traceback.format_exc()) raise e + async def async_set_cache_sadd( + self, key, value: List, local_only: bool = False, **kwargs + ) -> None: + """ + Add value to a set + + Key - the key in cache + + Value - str - the value you want to add to the set + + Returns - None + """ + try: + if self.in_memory_cache is not None: + _ = await self.in_memory_cache.async_set_cache_sadd( + key, value, ttl=kwargs.get("ttl", None) + ) + + if self.redis_cache is not None and local_only is False: + _ = await self.redis_cache.async_set_cache_sadd( + key, value, ttl=kwargs.get("ttl", None) ** kwargs + ) + + return None + except Exception as 
e: + verbose_logger.error( + "LiteLLM Cache: Excepton async set_cache_sadd: {}\n{}".format( + str(e), traceback.format_exc() + ) + ) + raise e + def flush_cache(self): if self.in_memory_cache is not None: self.in_memory_cache.flush_cache() diff --git a/litellm/integrations/opentelemetry.py b/litellm/integrations/opentelemetry.py index e4daff1218..fa7be1d574 100644 --- a/litellm/integrations/opentelemetry.py +++ b/litellm/integrations/opentelemetry.py @@ -105,8 +105,8 @@ class OpenTelemetry(CustomLogger): self, payload: ServiceLoggerPayload, parent_otel_span: Optional[Span] = None, - start_time: Optional[datetime] = None, - end_time: Optional[datetime] = None, + start_time: Optional[Union[datetime, float]] = None, + end_time: Optional[Union[datetime, float]] = None, ): from datetime import datetime @@ -144,8 +144,8 @@ class OpenTelemetry(CustomLogger): self, payload: ServiceLoggerPayload, parent_otel_span: Optional[Span] = None, - start_time: Optional[datetime] = None, - end_time: Optional[datetime] = None, + start_time: Optional[Union[datetime, float]] = None, + end_time: Optional[Union[float, datetime]] = None, ): from datetime import datetime diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py index 1e36a40b7b..aa22b51534 100644 --- a/litellm/litellm_core_utils/litellm_logging.py +++ b/litellm/litellm_core_utils/litellm_logging.py @@ -19,7 +19,8 @@ from litellm import ( turn_off_message_logging, verbose_logger, ) -from litellm.caching import InMemoryCache, S3Cache + +from litellm.caching import InMemoryCache, S3Cache, DualCache from litellm.integrations.custom_logger import CustomLogger from litellm.litellm_core_utils.redact_messages import ( redact_message_input_output_from_logging, @@ -1899,7 +1900,11 @@ def set_callbacks(callback_list, function_id=None): def _init_custom_logger_compatible_class( logging_integration: litellm._custom_logger_compatible_callbacks_literal, -) -> Callable: + internal_usage_cache: Optional[DualCache], + llm_router: Optional[ + Any + ], # expect litellm.Router, but typing errors due to circular import +) -> CustomLogger: if logging_integration == "lago": for callback in _in_memory_loggers: if isinstance(callback, LagoLogger): @@ -1935,3 +1940,58 @@ def _init_custom_logger_compatible_class( _otel_logger = OpenTelemetry(config=otel_config) _in_memory_loggers.append(_otel_logger) return _otel_logger # type: ignore + elif logging_integration == "dynamic_rate_limiter": + from litellm.proxy.hooks.dynamic_rate_limiter import ( + _PROXY_DynamicRateLimitHandler, + ) + + for callback in _in_memory_loggers: + if isinstance(callback, _PROXY_DynamicRateLimitHandler): + return callback # type: ignore + + if internal_usage_cache is None: + raise Exception( + "Internal Error: Cache cannot be empty - internal_usage_cache={}".format( + internal_usage_cache + ) + ) + + dynamic_rate_limiter_obj = _PROXY_DynamicRateLimitHandler( + internal_usage_cache=internal_usage_cache + ) + + if llm_router is not None and isinstance(llm_router, litellm.Router): + dynamic_rate_limiter_obj.update_variables(llm_router=llm_router) + _in_memory_loggers.append(dynamic_rate_limiter_obj) + return dynamic_rate_limiter_obj # type: ignore + + +def get_custom_logger_compatible_class( + logging_integration: litellm._custom_logger_compatible_callbacks_literal, +) -> Optional[CustomLogger]: + if logging_integration == "lago": + for callback in _in_memory_loggers: + if isinstance(callback, LagoLogger): + return callback + elif logging_integration == 
"openmeter": + for callback in _in_memory_loggers: + if isinstance(callback, OpenMeterLogger): + return callback + elif logging_integration == "logfire": + if "LOGFIRE_TOKEN" not in os.environ: + raise ValueError("LOGFIRE_TOKEN not found in environment variables") + from litellm.integrations.opentelemetry import OpenTelemetry + + for callback in _in_memory_loggers: + if isinstance(callback, OpenTelemetry): + return callback # type: ignore + + elif logging_integration == "dynamic_rate_limiter": + from litellm.proxy.hooks.dynamic_rate_limiter import ( + _PROXY_DynamicRateLimitHandler, + ) + + for callback in _in_memory_loggers: + if isinstance(callback, _PROXY_DynamicRateLimitHandler): + return callback # type: ignore + return None diff --git a/litellm/main.py b/litellm/main.py index a76ef64a13..307659c8a2 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -428,7 +428,7 @@ def mock_completion( model: str, messages: List, stream: Optional[bool] = False, - mock_response: Union[str, Exception] = "This is a mock request", + mock_response: Union[str, Exception, dict] = "This is a mock request", mock_tool_calls: Optional[List] = None, logging=None, custom_llm_provider=None, @@ -477,6 +477,9 @@ def mock_completion( if time_delay is not None: time.sleep(time_delay) + if isinstance(mock_response, dict): + return ModelResponse(**mock_response) + model_response = ModelResponse(stream=stream) if stream is True: # don't try to access stream object, diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml index 866ca0ab0a..01f09ca02b 100644 --- a/litellm/proxy/_new_secret_config.yaml +++ b/litellm/proxy/_new_secret_config.yaml @@ -1,61 +1,10 @@ -environment_variables: - LANGFUSE_PUBLIC_KEY: Q6K8MQN6L7sPYSJiFKM9eNrETOx6V/FxVPup4FqdKsZK1hyR4gyanlQ2KHLg5D5afng99uIt0JCEQ2jiKF9UxFvtnb4BbJ4qpeceH+iK8v/bdg== - LANGFUSE_SECRET_KEY: 5xQ7KMa6YMLsm+H/Pf1VmlqWq1NON5IoCxABhkUBeSck7ftsj2CmpkL2ZwrxwrktgiTUBH+3gJYBX+XBk7lqOOUpvmiLjol/E5lCqq0M1CqLWA== - SLACK_WEBHOOK_URL: RJjhS0Hhz0/s07sCIf1OTXmTGodpK9L2K9p953Z+fOX0l2SkPFT6mB9+yIrLufmlwEaku5NNEBKy//+AG01yOd+7wV1GhK65vfj3B/gTN8t5cuVnR4vFxKY5Rx4eSGLtzyAs+aIBTp4GoNXDIjroCqfCjPkItEZWCg== -general_settings: - alerting: - - slack - alerting_threshold: 300 - database_connection_pool_limit: 100 - database_connection_timeout: 60 - disable_master_key_return: true - health_check_interval: 300 - proxy_batch_write_at: 60 - ui_access_mode: all - # master_key: sk-1234 -litellm_settings: - allowed_fails: 3 - failure_callback: - - prometheus - num_retries: 3 - service_callback: - - prometheus_system - success_callback: - - langfuse - - prometheus - - langsmith -model_list: -- litellm_params: - model: gpt-3.5-turbo - model_name: gpt-3.5-turbo -- litellm_params: - api_base: https://openai-function-calling-workers.tasslexyz.workers.dev/ - api_key: my-fake-key - model: openai/my-fake-model - stream_timeout: 0.001 - model_name: fake-openai-endpoint -- litellm_params: - api_base: https://openai-function-calling-workers.tasslexyz.workers.dev/ - api_key: my-fake-key - model: openai/my-fake-model-2 - stream_timeout: 0.001 - model_name: fake-openai-endpoint -- litellm_params: - api_base: os.environ/AZURE_API_BASE - api_key: os.environ/AZURE_API_KEY - api_version: 2023-07-01-preview - model: azure/chatgpt-v-2 - stream_timeout: 0.001 - model_name: azure-gpt-3.5 -- litellm_params: - api_key: os.environ/OPENAI_API_KEY - model: text-embedding-ada-002 - model_name: text-embedding-ada-002 -- litellm_params: - model: text-completion-openai/gpt-3.5-turbo-instruct - 
model_name: gpt-instruct -router_settings: - enable_pre_call_checks: true - redis_host: os.environ/REDIS_HOST - redis_password: os.environ/REDIS_PASSWORD - redis_port: os.environ/REDIS_PORT +model_list: + - model_name: my-fake-model + litellm_params: + model: gpt-3.5-turbo + api_key: my-fake-key + mock_response: hello-world + tpm: 60 + +litellm_settings: + callbacks: ["dynamic_rate_limiter"] \ No newline at end of file diff --git a/litellm/proxy/_super_secret_config.yaml b/litellm/proxy/_super_secret_config.yaml index ac0aaca5c7..04a4806c12 100644 --- a/litellm/proxy/_super_secret_config.yaml +++ b/litellm/proxy/_super_secret_config.yaml @@ -30,6 +30,7 @@ model_list: api_key: os.environ/AZURE_API_KEY api_version: 2024-02-15-preview model: azure/chatgpt-v-2 + tpm: 100 model_name: gpt-3.5-turbo - litellm_params: model: anthropic.claude-3-sonnet-20240229-v1:0 @@ -40,6 +41,7 @@ model_list: api_version: 2024-02-15-preview model: azure/chatgpt-v-2 drop_params: True + tpm: 100 model_name: gpt-3.5-turbo - model_name: tts litellm_params: @@ -67,8 +69,7 @@ model_list: max_input_tokens: 80920 litellm_settings: - success_callback: ["langfuse"] - failure_callback: ["langfuse"] + callbacks: ["dynamic_rate_limiter"] # default_team_settings: # - team_id: proj1 # success_callback: ["langfuse"] diff --git a/litellm/proxy/hooks/dynamic_rate_limiter.py b/litellm/proxy/hooks/dynamic_rate_limiter.py new file mode 100644 index 0000000000..95f0ccc13e --- /dev/null +++ b/litellm/proxy/hooks/dynamic_rate_limiter.py @@ -0,0 +1,205 @@ +# What is this? +## Allocates dynamic tpm/rpm quota for a project based on current traffic +## Tracks num active projects per minute + +import asyncio +import sys +import traceback +from datetime import datetime +from typing import List, Literal, Optional, Tuple, Union + +from fastapi import HTTPException + +import litellm +from litellm import ModelResponse, Router +from litellm._logging import verbose_proxy_logger +from litellm.caching import DualCache +from litellm.integrations.custom_logger import CustomLogger +from litellm.proxy._types import UserAPIKeyAuth +from litellm.types.router import ModelGroupInfo +from litellm.utils import get_utc_datetime + + +class DynamicRateLimiterCache: + """ + Thin wrapper on DualCache for this file. + + Track number of active projects calling a model. + """ + + def __init__(self, cache: DualCache) -> None: + self.cache = cache + self.ttl = 60 # 1 min ttl + + async def async_get_cache(self, model: str) -> Optional[int]: + dt = get_utc_datetime() + current_minute = dt.strftime("%H-%M") + key_name = "{}:{}".format(current_minute, model) + _response = await self.cache.async_get_cache(key=key_name) + response: Optional[int] = None + if _response is not None: + response = len(_response) + return response + + async def async_set_cache_sadd(self, model: str, value: List): + """ + Add value to set. 
+ + Parameters: + - model: str, the name of the model group + - value: str, the team id + + Returns: + - None + + Raises: + - Exception, if unable to connect to cache client (if redis caching enabled) + """ + try: + dt = get_utc_datetime() + current_minute = dt.strftime("%H-%M") + + key_name = "{}:{}".format(current_minute, model) + await self.cache.async_set_cache_sadd( + key=key_name, value=value, ttl=self.ttl + ) + except Exception as e: + verbose_proxy_logger.error( + "litellm.proxy.hooks.dynamic_rate_limiter.py::async_set_cache_sadd(): Exception occured - {}\n{}".format( + str(e), traceback.format_exc() + ) + ) + raise e + + +class _PROXY_DynamicRateLimitHandler(CustomLogger): + + # Class variables or attributes + def __init__(self, internal_usage_cache: DualCache): + self.internal_usage_cache = DynamicRateLimiterCache(cache=internal_usage_cache) + + def update_variables(self, llm_router: Router): + self.llm_router = llm_router + + async def check_available_tpm( + self, model: str + ) -> Tuple[Optional[int], Optional[int], Optional[int]]: + """ + For a given model, get its available tpm + + Returns + - Tuple[available_tpm, model_tpm, active_projects] + - available_tpm: int or null - always 0 or positive. + - remaining_model_tpm: int or null. If available tpm is int, then this will be too. + - active_projects: int or null + """ + active_projects = await self.internal_usage_cache.async_get_cache(model=model) + current_model_tpm: Optional[int] = await self.llm_router.get_model_group_usage( + model_group=model + ) + model_group_info: Optional[ModelGroupInfo] = ( + self.llm_router.get_model_group_info(model_group=model) + ) + total_model_tpm: Optional[int] = None + if model_group_info is not None and model_group_info.tpm is not None: + total_model_tpm = model_group_info.tpm + + remaining_model_tpm: Optional[int] = None + if total_model_tpm is not None and current_model_tpm is not None: + remaining_model_tpm = total_model_tpm - current_model_tpm + elif total_model_tpm is not None: + remaining_model_tpm = total_model_tpm + + available_tpm: Optional[int] = None + + if remaining_model_tpm is not None: + if active_projects is not None: + available_tpm = int(remaining_model_tpm / active_projects) + else: + available_tpm = remaining_model_tpm + + if available_tpm is not None and available_tpm < 0: + available_tpm = 0 + return available_tpm, remaining_model_tpm, active_projects + + async def async_pre_call_hook( + self, + user_api_key_dict: UserAPIKeyAuth, + cache: DualCache, + data: dict, + call_type: Literal[ + "completion", + "text_completion", + "embeddings", + "image_generation", + "moderation", + "audio_transcription", + ], + ) -> Optional[ + Union[Exception, str, dict] + ]: # raise exception if invalid, return a str for the user to receive - if rejected, or return a modified dictionary for passing into litellm + """ + - For a model group + - Check if tpm available + - Raise RateLimitError if no tpm available + """ + if "model" in data: + available_tpm, model_tpm, active_projects = await self.check_available_tpm( + model=data["model"] + ) + if available_tpm is not None and available_tpm == 0: + raise HTTPException( + status_code=429, + detail={ + "error": "Key={} over available TPM={}. 
Model TPM={}, Active keys={}".format( + user_api_key_dict.api_key, + available_tpm, + model_tpm, + active_projects, + ) + }, + ) + elif available_tpm is not None: + ## UPDATE CACHE WITH ACTIVE PROJECT + asyncio.create_task( + self.internal_usage_cache.async_set_cache_sadd( # this is a set + model=data["model"], # type: ignore + value=[user_api_key_dict.token or "default_key"], + ) + ) + return None + + async def async_post_call_success_hook( + self, user_api_key_dict: UserAPIKeyAuth, response + ): + try: + if isinstance(response, ModelResponse): + model_info = self.llm_router.get_model_info( + id=response._hidden_params["model_id"] + ) + assert ( + model_info is not None + ), "Model info for model with id={} is None".format( + response._hidden_params["model_id"] + ) + available_tpm, remaining_model_tpm, active_projects = ( + await self.check_available_tpm(model=model_info["model_name"]) + ) + response._hidden_params["additional_headers"] = { + "x-litellm-model_group": model_info["model_name"], + "x-ratelimit-remaining-litellm-project-tokens": available_tpm, + "x-ratelimit-remaining-model-tokens": remaining_model_tpm, + "x-ratelimit-current-active-projects": active_projects, + } + + return response + return await super().async_post_call_success_hook( + user_api_key_dict, response + ) + except Exception as e: + verbose_proxy_logger.error( + "litellm.proxy.hooks.dynamic_rate_limiter.py::async_post_call_success_hook(): Exception occured - {}\n{}".format( + str(e), traceback.format_exc() + ) + ) + return response diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 630aa3f3e1..4cac93b24f 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -433,6 +433,7 @@ def get_custom_headers( version: Optional[str] = None, model_region: Optional[str] = None, fastest_response_batch_completion: Optional[bool] = None, + **kwargs, ) -> dict: exclude_values = {"", None} headers = { @@ -448,6 +449,7 @@ def get_custom_headers( if fastest_response_batch_completion is not None else None ), + **{k: str(v) for k, v in kwargs.items()}, } try: return { @@ -2644,7 +2646,9 @@ async def startup_event(): redis_cache=redis_usage_cache ) # used by parallel request limiter for rate limiting keys across instances - proxy_logging_obj._init_litellm_callbacks() # INITIALIZE LITELLM CALLBACKS ON SERVER STARTUP <- do this to catch any logging errors on startup, not when calls are being made + proxy_logging_obj._init_litellm_callbacks( + llm_router=llm_router + ) # INITIALIZE LITELLM CALLBACKS ON SERVER STARTUP <- do this to catch any logging errors on startup, not when calls are being made if "daily_reports" in proxy_logging_obj.slack_alerting_instance.alert_types: asyncio.create_task( @@ -3061,6 +3065,14 @@ async def chat_completion( headers=custom_headers, ) + ### CALL HOOKS ### - modify outgoing data + response = await proxy_logging_obj.post_call_success_hook( + user_api_key_dict=user_api_key_dict, response=response + ) + + hidden_params = getattr(response, "_hidden_params", {}) or {} + additional_headers: dict = hidden_params.get("additional_headers", {}) or {} + fastapi_response.headers.update( get_custom_headers( user_api_key_dict=user_api_key_dict, @@ -3070,14 +3082,10 @@ async def chat_completion( version=version, model_region=getattr(user_api_key_dict, "allowed_model_region", ""), fastest_response_batch_completion=fastest_response_batch_completion, + **additional_headers, ) ) - ### CALL HOOKS ### - modify outgoing data - response = await 
proxy_logging_obj.post_call_success_hook( - user_api_key_dict=user_api_key_dict, response=response - ) - return response except RejectedRequestError as e: _data = e.request_data @@ -3116,11 +3124,10 @@ async def chat_completion( except Exception as e: data["litellm_status"] = "fail" # used for alerting verbose_proxy_logger.error( - "litellm.proxy.proxy_server.chat_completion(): Exception occured - {}".format( - get_error_message_str(e=e) + "litellm.proxy.proxy_server.chat_completion(): Exception occured - {}\n{}".format( + get_error_message_str(e=e), traceback.format_exc() ) ) - verbose_proxy_logger.debug(traceback.format_exc()) await proxy_logging_obj.post_call_failure_hook( user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data ) diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index 8afe679bd1..96aeb4a816 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -229,31 +229,32 @@ class ProxyLogging: if redis_cache is not None: self.internal_usage_cache.redis_cache = redis_cache - def _init_litellm_callbacks(self): - print_verbose("INITIALIZING LITELLM CALLBACKS!") + def _init_litellm_callbacks(self, llm_router: Optional[litellm.Router] = None): self.service_logging_obj = ServiceLogging() - litellm.callbacks.append(self.max_parallel_request_limiter) - litellm.callbacks.append(self.max_budget_limiter) - litellm.callbacks.append(self.cache_control_check) - litellm.callbacks.append(self.service_logging_obj) + litellm.callbacks.append(self.max_parallel_request_limiter) # type: ignore + litellm.callbacks.append(self.max_budget_limiter) # type: ignore + litellm.callbacks.append(self.cache_control_check) # type: ignore + litellm.callbacks.append(self.service_logging_obj) # type: ignore litellm.success_callback.append( self.slack_alerting_instance.response_taking_too_long_callback ) for callback in litellm.callbacks: if isinstance(callback, str): - callback = litellm.litellm_core_utils.litellm_logging._init_custom_logger_compatible_class( - callback + callback = litellm.litellm_core_utils.litellm_logging._init_custom_logger_compatible_class( # type: ignore + callback, + internal_usage_cache=self.internal_usage_cache, + llm_router=llm_router, ) if callback not in litellm.input_callback: - litellm.input_callback.append(callback) + litellm.input_callback.append(callback) # type: ignore if callback not in litellm.success_callback: - litellm.success_callback.append(callback) + litellm.success_callback.append(callback) # type: ignore if callback not in litellm.failure_callback: - litellm.failure_callback.append(callback) + litellm.failure_callback.append(callback) # type: ignore if callback not in litellm._async_success_callback: - litellm._async_success_callback.append(callback) + litellm._async_success_callback.append(callback) # type: ignore if callback not in litellm._async_failure_callback: - litellm._async_failure_callback.append(callback) + litellm._async_failure_callback.append(callback) # type: ignore if ( len(litellm.input_callback) > 0 @@ -301,10 +302,19 @@ class ProxyLogging: try: for callback in litellm.callbacks: - if isinstance(callback, CustomLogger) and "async_pre_call_hook" in vars( - callback.__class__ + _callback: Optional[CustomLogger] = None + if isinstance(callback, str): + _callback = litellm.litellm_core_utils.litellm_logging.get_custom_logger_compatible_class( + callback + ) + else: + _callback = callback # type: ignore + if ( + _callback is not None + and isinstance(_callback, CustomLogger) + and "async_pre_call_hook" in 
vars(_callback.__class__) ): - response = await callback.async_pre_call_hook( + response = await _callback.async_pre_call_hook( user_api_key_dict=user_api_key_dict, cache=self.call_details["user_api_key_cache"], data=data, @@ -574,8 +584,15 @@ class ProxyLogging: for callback in litellm.callbacks: try: - if isinstance(callback, CustomLogger): - await callback.async_post_call_failure_hook( + _callback: Optional[CustomLogger] = None + if isinstance(callback, str): + _callback = litellm.litellm_core_utils.litellm_logging.get_custom_logger_compatible_class( + callback + ) + else: + _callback = callback # type: ignore + if _callback is not None and isinstance(_callback, CustomLogger): + await _callback.async_post_call_failure_hook( user_api_key_dict=user_api_key_dict, original_exception=original_exception, ) @@ -596,8 +613,15 @@ class ProxyLogging: """ for callback in litellm.callbacks: try: - if isinstance(callback, CustomLogger): - await callback.async_post_call_success_hook( + _callback: Optional[CustomLogger] = None + if isinstance(callback, str): + _callback = litellm.litellm_core_utils.litellm_logging.get_custom_logger_compatible_class( + callback + ) + else: + _callback = callback # type: ignore + if _callback is not None and isinstance(_callback, CustomLogger): + await _callback.async_post_call_success_hook( user_api_key_dict=user_api_key_dict, response=response ) except Exception as e: @@ -615,14 +639,25 @@ class ProxyLogging: Covers: 1. /chat/completions """ - for callback in litellm.callbacks: - try: - if isinstance(callback, CustomLogger): - await callback.async_post_call_streaming_hook( - user_api_key_dict=user_api_key_dict, response=response - ) - except Exception as e: - raise e + response_str: Optional[str] = None + if isinstance(response, ModelResponse): + response_str = litellm.get_response_string(response_obj=response) + if response_str is not None: + for callback in litellm.callbacks: + try: + _callback: Optional[CustomLogger] = None + if isinstance(callback, str): + _callback = litellm.litellm_core_utils.litellm_logging.get_custom_logger_compatible_class( + callback + ) + else: + _callback = callback # type: ignore + if _callback is not None and isinstance(_callback, CustomLogger): + await _callback.async_post_call_streaming_hook( + user_api_key_dict=user_api_key_dict, response=response_str + ) + except Exception as e: + raise e return response async def post_call_streaming_hook( diff --git a/litellm/router.py b/litellm/router.py index 69000d6048..df783eab82 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -11,6 +11,7 @@ import asyncio import concurrent import copy import datetime as datetime_og +import enum import hashlib import inspect import json @@ -90,6 +91,10 @@ from litellm.utils import ( ) +class RoutingArgs(enum.Enum): + ttl = 60 # 1min (RPM/TPM expire key) + + class Router: model_names: List = [] cache_responses: Optional[bool] = False @@ -387,6 +392,11 @@ class Router: routing_strategy=routing_strategy, routing_strategy_args=routing_strategy_args, ) + ## USAGE TRACKING ## + if isinstance(litellm._async_success_callback, list): + litellm._async_success_callback.append(self.deployment_callback_on_success) + else: + litellm._async_success_callback.append(self.deployment_callback_on_success) ## COOLDOWNS ## if isinstance(litellm.failure_callback, list): litellm.failure_callback.append(self.deployment_callback_on_failure) @@ -2640,13 +2650,69 @@ class Router: time.sleep(_timeout) if type(original_exception) in litellm.LITELLM_EXCEPTION_TYPES: - 
original_exception.max_retries = num_retries - original_exception.num_retries = current_attempt + setattr(original_exception, "max_retries", num_retries) + setattr(original_exception, "num_retries", current_attempt) raise original_exception ### HELPER FUNCTIONS + async def deployment_callback_on_success( + self, + kwargs, # kwargs to completion + completion_response, # response from completion + start_time, + end_time, # start/end time + ): + """ + Track remaining tpm/rpm quota for model in model_list + """ + try: + """ + Update TPM usage on success + """ + if kwargs["litellm_params"].get("metadata") is None: + pass + else: + model_group = kwargs["litellm_params"]["metadata"].get( + "model_group", None + ) + + id = kwargs["litellm_params"].get("model_info", {}).get("id", None) + if model_group is None or id is None: + return + elif isinstance(id, int): + id = str(id) + + total_tokens = completion_response["usage"]["total_tokens"] + + # ------------ + # Setup values + # ------------ + dt = get_utc_datetime() + current_minute = dt.strftime( + "%H-%M" + ) # use the same timezone regardless of system clock + + tpm_key = f"global_router:{id}:tpm:{current_minute}" + # ------------ + # Update usage + # ------------ + # update cache + + ## TPM + await self.cache.async_increment_cache( + key=tpm_key, value=total_tokens, ttl=RoutingArgs.ttl.value + ) + + except Exception as e: + verbose_router_logger.error( + "litellm.proxy.hooks.prompt_injection_detection.py::async_pre_call_hook(): Exception occured - {}\n{}".format( + str(e), traceback.format_exc() + ) + ) + pass + def deployment_callback_on_failure( self, kwargs, # kwargs to completion @@ -3812,10 +3878,39 @@ class Router: model_group_info: Optional[ModelGroupInfo] = None + total_tpm: Optional[int] = None + total_rpm: Optional[int] = None + for model in self.model_list: if "model_name" in model and model["model_name"] == model_group: # model in model group found # litellm_params = LiteLLM_Params(**model["litellm_params"]) + # get model tpm + _deployment_tpm: Optional[int] = None + if _deployment_tpm is None: + _deployment_tpm = model.get("tpm", None) + if _deployment_tpm is None: + _deployment_tpm = model.get("litellm_params", {}).get("tpm", None) + if _deployment_tpm is None: + _deployment_tpm = model.get("model_info", {}).get("tpm", None) + + if _deployment_tpm is not None: + if total_tpm is None: + total_tpm = 0 + total_tpm += _deployment_tpm # type: ignore + # get model rpm + _deployment_rpm: Optional[int] = None + if _deployment_rpm is None: + _deployment_rpm = model.get("rpm", None) + if _deployment_rpm is None: + _deployment_rpm = model.get("litellm_params", {}).get("rpm", None) + if _deployment_rpm is None: + _deployment_rpm = model.get("model_info", {}).get("rpm", None) + + if _deployment_rpm is not None: + if total_rpm is None: + total_rpm = 0 + total_rpm += _deployment_rpm # type: ignore # get model info try: model_info = litellm.get_model_info(model=litellm_params.model) @@ -3929,8 +4024,44 @@ class Router: "supported_openai_params" ] + ## UPDATE WITH TOTAL TPM/RPM FOR MODEL GROUP + if total_tpm is not None and model_group_info is not None: + model_group_info.tpm = total_tpm + + if total_rpm is not None and model_group_info is not None: + model_group_info.rpm = total_rpm + return model_group_info + async def get_model_group_usage(self, model_group: str) -> Optional[int]: + """ + Returns remaining tpm quota for model group + """ + dt = get_utc_datetime() + current_minute = dt.strftime( + "%H-%M" + ) # use the same timezone regardless 
of system clock + tpm_keys: List[str] = [] + for model in self.model_list: + if "model_name" in model and model["model_name"] == model_group: + tpm_keys.append( + f"global_router:{model['model_info']['id']}:tpm:{current_minute}" + ) + + ## TPM + tpm_usage_list: Optional[List] = await self.cache.async_batch_get_cache( + keys=tpm_keys + ) + tpm_usage: Optional[int] = None + if tpm_usage_list is not None: + for t in tpm_usage_list: + if isinstance(t, int): + if tpm_usage is None: + tpm_usage = 0 + tpm_usage += t + + return tpm_usage + def get_model_ids(self) -> List[str]: """ Returns list of model id's. @@ -4858,7 +4989,7 @@ class Router: def reset(self): ## clean up on close litellm.success_callback = [] - litellm.__async_success_callback = [] + litellm._async_success_callback = [] litellm.failure_callback = [] litellm._async_failure_callback = [] self.retry_policy = None diff --git a/litellm/tests/test_dynamic_rate_limit_handler.py b/litellm/tests/test_dynamic_rate_limit_handler.py new file mode 100644 index 0000000000..c3fcca6a6b --- /dev/null +++ b/litellm/tests/test_dynamic_rate_limit_handler.py @@ -0,0 +1,486 @@ +# What is this? +## Unit tests for 'dynamic_rate_limiter.py` +import asyncio +import os +import random +import sys +import time +import traceback +import uuid +from datetime import datetime +from typing import Optional, Tuple + +from dotenv import load_dotenv + +load_dotenv() +import os + +sys.path.insert( + 0, os.path.abspath("../..") +) # Adds the parent directory to the system path +import pytest + +import litellm +from litellm import DualCache, Router +from litellm.proxy._types import UserAPIKeyAuth +from litellm.proxy.hooks.dynamic_rate_limiter import ( + _PROXY_DynamicRateLimitHandler as DynamicRateLimitHandler, +) + +""" +Basic test cases: + +- If 1 'active' project => give all tpm +- If 2 'active' projects => divide tpm in 2 +""" + + +@pytest.fixture +def dynamic_rate_limit_handler() -> DynamicRateLimitHandler: + internal_cache = DualCache() + return DynamicRateLimitHandler(internal_usage_cache=internal_cache) + + +@pytest.fixture +def mock_response() -> litellm.ModelResponse: + return litellm.ModelResponse( + **{ + "id": "chatcmpl-abc123", + "object": "chat.completion", + "created": 1699896916, + "model": "gpt-3.5-turbo-0125", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "call_abc123", + "type": "function", + "function": { + "name": "get_current_weather", + "arguments": '{\n"location": "Boston, MA"\n}', + }, + } + ], + }, + "logprobs": None, + "finish_reason": "tool_calls", + } + ], + "usage": {"prompt_tokens": 5, "completion_tokens": 5, "total_tokens": 10}, + } + ) + + +@pytest.fixture +def user_api_key_auth() -> UserAPIKeyAuth: + return UserAPIKeyAuth() + + +@pytest.mark.parametrize("num_projects", [1, 2, 100]) +@pytest.mark.asyncio +async def test_available_tpm(num_projects, dynamic_rate_limit_handler): + model = "my-fake-model" + ## SET CACHE W/ ACTIVE PROJECTS + projects = [str(uuid.uuid4()) for _ in range(num_projects)] + + await dynamic_rate_limit_handler.internal_usage_cache.async_set_cache_sadd( + model=model, value=projects + ) + + model_tpm = 100 + llm_router = Router( + model_list=[ + { + "model_name": model, + "litellm_params": { + "model": "gpt-3.5-turbo", + "api_key": "my-key", + "api_base": "my-base", + "tpm": model_tpm, + }, + } + ] + ) + dynamic_rate_limit_handler.update_variables(llm_router=llm_router) + + ## CHECK AVAILABLE TPM PER PROJECT + + availability, _, _ = await 
dynamic_rate_limit_handler.check_available_tpm( + model=model + ) + + expected_availability = int(model_tpm / num_projects) + + assert availability == expected_availability + + +@pytest.mark.asyncio +async def test_rate_limit_raised(dynamic_rate_limit_handler, user_api_key_auth): + """ + Unit test. Tests if rate limit error raised when quota exhausted. + """ + from fastapi import HTTPException + + model = "my-fake-model" + ## SET CACHE W/ ACTIVE PROJECTS + projects = [str(uuid.uuid4())] + + await dynamic_rate_limit_handler.internal_usage_cache.async_set_cache_sadd( + model=model, value=projects + ) + + model_tpm = 0 + llm_router = Router( + model_list=[ + { + "model_name": model, + "litellm_params": { + "model": "gpt-3.5-turbo", + "api_key": "my-key", + "api_base": "my-base", + "tpm": model_tpm, + }, + } + ] + ) + dynamic_rate_limit_handler.update_variables(llm_router=llm_router) + + ## CHECK AVAILABLE TPM PER PROJECT + + availability, _, _ = await dynamic_rate_limit_handler.check_available_tpm( + model=model + ) + + expected_availability = int(model_tpm / 1) + + assert availability == expected_availability + + ## CHECK if exception raised + + try: + await dynamic_rate_limit_handler.async_pre_call_hook( + user_api_key_dict=user_api_key_auth, + cache=DualCache(), + data={"model": model}, + call_type="completion", + ) + pytest.fail("Expected this to raise HTTPexception") + except HTTPException as e: + assert e.status_code == 429 # check if rate limit error raised + pass + + +@pytest.mark.asyncio +async def test_base_case(dynamic_rate_limit_handler, mock_response): + """ + If just 1 active project + + it should get all the quota + + = allow request to go through + - update token usage + - exhaust all tpm with just 1 project + - assert ratelimiterror raised at 100%+1 tpm + """ + model = "my-fake-model" + ## model tpm - 50 + model_tpm = 50 + ## tpm per request - 10 + setattr( + mock_response, + "usage", + litellm.Usage(prompt_tokens=5, completion_tokens=5, total_tokens=10), + ) + + llm_router = Router( + model_list=[ + { + "model_name": model, + "litellm_params": { + "model": "gpt-3.5-turbo", + "api_key": "my-key", + "api_base": "my-base", + "tpm": model_tpm, + "mock_response": mock_response, + }, + } + ] + ) + dynamic_rate_limit_handler.update_variables(llm_router=llm_router) + + prev_availability: Optional[int] = None + allowed_fails = 1 + for _ in range(5): + try: + # check availability + availability, _, _ = await dynamic_rate_limit_handler.check_available_tpm( + model=model + ) + + ## assert availability updated + if prev_availability is not None and availability is not None: + assert availability == prev_availability - 10 + + print( + "prev_availability={}, availability={}".format( + prev_availability, availability + ) + ) + + prev_availability = availability + + # make call + await llm_router.acompletion( + model=model, messages=[{"role": "user", "content": "hey!"}] + ) + + await asyncio.sleep(3) + except Exception: + if allowed_fails > 0: + allowed_fails -= 1 + else: + raise + + +@pytest.mark.asyncio +async def test_update_cache( + dynamic_rate_limit_handler, mock_response, user_api_key_auth +): + """ + Check if active project correctly updated + """ + model = "my-fake-model" + model_tpm = 50 + + llm_router = Router( + model_list=[ + { + "model_name": model, + "litellm_params": { + "model": "gpt-3.5-turbo", + "api_key": "my-key", + "api_base": "my-base", + "tpm": model_tpm, + "mock_response": mock_response, + }, + } + ] + ) + 
dynamic_rate_limit_handler.update_variables(llm_router=llm_router) + + ## INITIAL ACTIVE PROJECTS - ASSERT NONE + _, _, active_projects = await dynamic_rate_limit_handler.check_available_tpm( + model=model + ) + + assert active_projects is None + + ## MAKE CALL + await dynamic_rate_limit_handler.async_pre_call_hook( + user_api_key_dict=user_api_key_auth, + cache=DualCache(), + data={"model": model}, + call_type="completion", + ) + + await asyncio.sleep(2) + ## INITIAL ACTIVE PROJECTS - ASSERT 1 + _, _, active_projects = await dynamic_rate_limit_handler.check_available_tpm( + model=model + ) + + assert active_projects == 1 + + +@pytest.mark.parametrize("num_projects", [2]) +@pytest.mark.asyncio +async def test_multiple_projects( + dynamic_rate_limit_handler, mock_response, num_projects +): + """ + If 2 active project + + it should split 50% each + + - assert available tpm is 0 after 50%+1 tpm calls + """ + model = "my-fake-model" + model_tpm = 50 + total_tokens_per_call = 10 + step_tokens_per_call_per_project = total_tokens_per_call / num_projects + + available_tpm_per_project = int(model_tpm / num_projects) + + ## SET CACHE W/ ACTIVE PROJECTS + projects = [str(uuid.uuid4()) for _ in range(num_projects)] + await dynamic_rate_limit_handler.internal_usage_cache.async_set_cache_sadd( + model=model, value=projects + ) + + expected_runs = int(available_tpm_per_project / step_tokens_per_call_per_project) + + setattr( + mock_response, + "usage", + litellm.Usage( + prompt_tokens=5, completion_tokens=5, total_tokens=total_tokens_per_call + ), + ) + + llm_router = Router( + model_list=[ + { + "model_name": model, + "litellm_params": { + "model": "gpt-3.5-turbo", + "api_key": "my-key", + "api_base": "my-base", + "tpm": model_tpm, + "mock_response": mock_response, + }, + } + ] + ) + dynamic_rate_limit_handler.update_variables(llm_router=llm_router) + + prev_availability: Optional[int] = None + + print("expected_runs: {}".format(expected_runs)) + for i in range(expected_runs + 1): + # check availability + availability, _, _ = await dynamic_rate_limit_handler.check_available_tpm( + model=model + ) + + ## assert availability updated + if prev_availability is not None and availability is not None: + assert ( + availability == prev_availability - step_tokens_per_call_per_project + ), "Current Availability: Got={}, Expected={}, Step={}, Tokens per step={}, Initial model tpm={}".format( + availability, + prev_availability - 10, + i, + step_tokens_per_call_per_project, + model_tpm, + ) + + print( + "prev_availability={}, availability={}".format( + prev_availability, availability + ) + ) + + prev_availability = availability + + # make call + await llm_router.acompletion( + model=model, messages=[{"role": "user", "content": "hey!"}] + ) + + await asyncio.sleep(3) + + # check availability + availability, _, _ = await dynamic_rate_limit_handler.check_available_tpm( + model=model + ) + assert availability == 0 + + +@pytest.mark.parametrize("num_projects", [2]) +@pytest.mark.asyncio +async def test_multiple_projects_e2e( + dynamic_rate_limit_handler, mock_response, num_projects +): + """ + 2 parallel calls with different keys, same model + + If 2 active project + + it should split 50% each + + - assert available tpm is 0 after 50%+1 tpm calls + """ + model = "my-fake-model" + model_tpm = 50 + total_tokens_per_call = 10 + step_tokens_per_call_per_project = total_tokens_per_call / num_projects + + available_tpm_per_project = int(model_tpm / num_projects) + + ## SET CACHE W/ ACTIVE PROJECTS + projects = 
[str(uuid.uuid4()) for _ in range(num_projects)] + await dynamic_rate_limit_handler.internal_usage_cache.async_set_cache_sadd( + model=model, value=projects + ) + + expected_runs = int(available_tpm_per_project / step_tokens_per_call_per_project) + + setattr( + mock_response, + "usage", + litellm.Usage( + prompt_tokens=5, completion_tokens=5, total_tokens=total_tokens_per_call + ), + ) + + llm_router = Router( + model_list=[ + { + "model_name": model, + "litellm_params": { + "model": "gpt-3.5-turbo", + "api_key": "my-key", + "api_base": "my-base", + "tpm": model_tpm, + "mock_response": mock_response, + }, + } + ] + ) + dynamic_rate_limit_handler.update_variables(llm_router=llm_router) + + prev_availability: Optional[int] = None + + print("expected_runs: {}".format(expected_runs)) + for i in range(expected_runs + 1): + # check availability + availability, _, _ = await dynamic_rate_limit_handler.check_available_tpm( + model=model + ) + + ## assert availability updated + if prev_availability is not None and availability is not None: + assert ( + availability == prev_availability - step_tokens_per_call_per_project + ), "Current Availability: Got={}, Expected={}, Step={}, Tokens per step={}, Initial model tpm={}".format( + availability, + prev_availability - 10, + i, + step_tokens_per_call_per_project, + model_tpm, + ) + + print( + "prev_availability={}, availability={}".format( + prev_availability, availability + ) + ) + + prev_availability = availability + + # make call + await llm_router.acompletion( + model=model, messages=[{"role": "user", "content": "hey!"}] + ) + + await asyncio.sleep(3) + + # check availability + availability, _, _ = await dynamic_rate_limit_handler.check_available_tpm( + model=model + ) + assert availability == 0 diff --git a/litellm/tests/test_router.py b/litellm/tests/test_router.py index d2037dc59e..2e88143273 100644 --- a/litellm/tests/test_router.py +++ b/litellm/tests/test_router.py @@ -1730,3 +1730,99 @@ async def test_router_text_completion_client(): print(responses) except Exception as e: pytest.fail(f"Error occurred: {e}") + + +@pytest.fixture +def mock_response() -> litellm.ModelResponse: + return litellm.ModelResponse( + **{ + "id": "chatcmpl-abc123", + "object": "chat.completion", + "created": 1699896916, + "model": "gpt-3.5-turbo-0125", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "call_abc123", + "type": "function", + "function": { + "name": "get_current_weather", + "arguments": '{\n"location": "Boston, MA"\n}', + }, + } + ], + }, + "logprobs": None, + "finish_reason": "tool_calls", + } + ], + "usage": {"prompt_tokens": 5, "completion_tokens": 5, "total_tokens": 10}, + } + ) + + +@pytest.mark.asyncio +async def test_router_model_usage(mock_response): + """ + Test if tracking used model tpm works as expected + """ + model = "my-fake-model" + model_tpm = 100 + setattr( + mock_response, + "usage", + litellm.Usage(prompt_tokens=5, completion_tokens=5, total_tokens=10), + ) + + print(f"mock_response: {mock_response}") + model_tpm = 100 + llm_router = Router( + model_list=[ + { + "model_name": model, + "litellm_params": { + "model": "gpt-3.5-turbo", + "api_key": "my-key", + "api_base": "my-base", + "tpm": model_tpm, + "mock_response": mock_response, + }, + } + ] + ) + + allowed_fails = 1 # allow for changing b/w minutes + + for _ in range(2): + try: + _ = await llm_router.acompletion( + model=model, messages=[{"role": "user", "content": "Hey!"}] + ) + await asyncio.sleep(3) + + 
initial_usage = await llm_router.get_model_group_usage(model_group=model) + + # completion call - 10 tokens + _ = await llm_router.acompletion( + model=model, messages=[{"role": "user", "content": "Hey!"}] + ) + + await asyncio.sleep(3) + updated_usage = await llm_router.get_model_group_usage(model_group=model) + + assert updated_usage == initial_usage + 10 # type: ignore + break + except Exception as e: + if allowed_fails > 0: + print( + f"Decrementing allowed_fails: {allowed_fails}.\nReceived error - {str(e)}" + ) + allowed_fails -= 1 + else: + print(f"allowed_fails: {allowed_fails}") + raise e diff --git a/litellm/types/router.py b/litellm/types/router.py index 206216ef0c..7f043e4042 100644 --- a/litellm/types/router.py +++ b/litellm/types/router.py @@ -442,6 +442,8 @@ class ModelGroupInfo(BaseModel): "chat", "embedding", "completion", "image_generation", "audio_transcription" ] ] = Field(default="chat") + tpm: Optional[int] = None + rpm: Optional[int] = None supports_parallel_function_calling: bool = Field(default=False) supports_vision: bool = Field(default=False) supports_function_calling: bool = Field(default=False) diff --git a/litellm/utils.py b/litellm/utils.py index 831ae433cf..19d99ff59b 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -340,14 +340,15 @@ def function_setup( ) try: global callback_list, add_breadcrumb, user_logger_fn, Logging + function_id = kwargs["id"] if "id" in kwargs else None if len(litellm.callbacks) > 0: for callback in litellm.callbacks: # check if callback is a string - e.g. "lago", "openmeter" if isinstance(callback, str): - callback = litellm.litellm_core_utils.litellm_logging._init_custom_logger_compatible_class( - callback + callback = litellm.litellm_core_utils.litellm_logging._init_custom_logger_compatible_class( # type: ignore + callback, internal_usage_cache=None, llm_router=None ) if any( isinstance(cb, type(callback)) @@ -3895,12 +3896,16 @@ def get_formatted_prompt( def get_response_string(response_obj: ModelResponse) -> str: - _choices: List[Choices] = response_obj.choices # type: ignore + _choices: List[Union[Choices, StreamingChoices]] = response_obj.choices response_str = "" for choice in _choices: - if choice.message.content is not None: - response_str += choice.message.content + if isinstance(choice, Choices): + if choice.message.content is not None: + response_str += choice.message.content + elif isinstance(choice, StreamingChoices): + if choice.delta.content is not None: + response_str += choice.delta.content return response_str
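
For reference, the allocation rule that ties this change set together (`check_available_tpm` in `litellm/proxy/hooks/dynamic_rate_limiter.py` above) reduces to dividing the remaining model-group TPM across the keys active in the current minute, floored at zero. Below is a minimal standalone sketch of that arithmetic; the function name and the hard-coded numbers are illustrative only and are not part of the LiteLLM API.

```python
# Illustrative sketch only - mirrors the math in check_available_tpm; not LiteLLM code.
from typing import Optional, Tuple


def split_available_tpm(
    total_model_tpm: Optional[int],    # tpm configured for the model group
    current_model_tpm: Optional[int],  # tokens already consumed this minute
    active_projects: Optional[int],    # distinct keys/projects seen this minute
) -> Tuple[Optional[int], Optional[int], Optional[int]]:
    """Returns (available_tpm, remaining_model_tpm, active_projects), like the hook."""
    if total_model_tpm is None:
        return None, None, active_projects
    remaining = total_model_tpm - (current_model_tpm or 0)
    available = int(remaining / active_projects) if active_projects else remaining
    return max(0, available), remaining, active_projects


# Docs example: 60 TPM model, 2 active keys, each already made one 30-token call.
# Remaining TPM is 0, so the next request is rejected with a 429.
print(split_available_tpm(60, 60, 2))  # (0, 0, 2)
```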