Merge pull request #2996 from BerriAI/litellm_semaphores

fix(router.py): initial commit for semaphores on router
Krish Dholakia 2024-04-12 23:23:36 -07:00 committed by GitHub
commit fd7760d3db
8 changed files with 198 additions and 41 deletions


@@ -189,6 +189,9 @@ jobs:
       -p 4000:4000 \
       -e DATABASE_URL=$PROXY_DOCKER_DB_URL \
       -e AZURE_API_KEY=$AZURE_API_KEY \
+      -e REDIS_HOST=$REDIS_HOST \
+      -e REDIS_PASSWORD=$REDIS_PASSWORD \
+      -e REDIS_PORT=$REDIS_PORT \
       -e AZURE_FRANCE_API_KEY=$AZURE_FRANCE_API_KEY \
       -e AZURE_EUROPE_API_KEY=$AZURE_EUROPE_API_KEY \
      -e AWS_ACCESS_KEY_ID=$AWS_ACCESS_KEY_ID \


@@ -98,11 +98,12 @@ class InMemoryCache(BaseCache):
             return_val.append(val)
         return return_val

-    async def async_increment(self, key, value: int, **kwargs):
+    async def async_increment(self, key, value: int, **kwargs) -> int:
         # get the value
         init_value = await self.async_get_cache(key=key) or 0
         value = init_value + value
         await self.async_set_cache(key, value, **kwargs)
+        return value

     def flush_cache(self):
         self.cache_dict.clear()
@@ -266,11 +267,12 @@ class RedisCache(BaseCache):
         if len(self.redis_batch_writing_buffer) >= self.redis_flush_size:
             await self.flush_cache_buffer()

-    async def async_increment(self, key, value: int, **kwargs):
+    async def async_increment(self, key, value: int, **kwargs) -> int:
         _redis_client = self.init_async_client()
         try:
             async with _redis_client as redis_client:
-                await redis_client.incr(name=key, amount=value)
+                result = await redis_client.incr(name=key, amount=value)
+                return result
         except Exception as e:
             verbose_logger.error(
                 "LiteLLM Redis Caching: async async_increment() - Got exception from REDIS %s, Writing value=%s",
@@ -278,6 +280,7 @@ class RedisCache(BaseCache):
                 value,
             )
             traceback.print_exc()
+            raise e

     async def flush_cache_buffer(self):
         print_verbose(
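
Redis's INCRBY is atomic on the server and returns the post-increment value, which is what lets `async_increment` hand back an authoritative count instead of doing a separate read. A standalone sketch with redis-py's asyncio client, assuming a Redis instance on localhost (not part of this diff):

```python
import asyncio
import redis.asyncio as redis  # redis-py >= 4.2 ships the asyncio client

async def main():
    client = redis.Redis(host="localhost", port=6379)
    # INCRBY is atomic on the server and returns the post-increment value,
    # so concurrent callers each observe a distinct, monotonically increasing count.
    new_value = await client.incr("gpt-3.5-turbo:rpm:12-05", amount=1)
    print(new_value)
    await client.aclose()  # redis-py >= 5; older versions use close()

asyncio.run(main())
```
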
@@ -1076,21 +1079,29 @@ class DualCache(BaseCache):
     async def async_increment_cache(
         self, key, value: int, local_only: bool = False, **kwargs
-    ):
+    ) -> int:
         """
         Key - the key in cache
         Value - int - the value you want to increment by
+        Returns - int - the incremented value
         """
         try:
+            result: int = value
             if self.in_memory_cache is not None:
-                await self.in_memory_cache.async_increment(key, value, **kwargs)
+                result = await self.in_memory_cache.async_increment(
+                    key, value, **kwargs
+                )

             if self.redis_cache is not None and local_only == False:
-                await self.redis_cache.async_increment(key, value, **kwargs)
+                result = await self.redis_cache.async_increment(key, value, **kwargs)
+
+            return result
         except Exception as e:
             print_verbose(f"LiteLLM Cache: Excepton async add_cache: {str(e)}")
             traceback.print_exc()
+            raise e

     def flush_cache(self):
         if self.in_memory_cache is not None:
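
With both layers now returning the new count, callers such as the rpm checker further down can use the return value directly instead of issuing a separate get. A usage sketch, assuming `DualCache()` with no arguments falls back to an in-memory layer only, as elsewhere in this codebase:

```python
import asyncio
from litellm.caching import DualCache

async def main():
    cache = DualCache()  # no redis_cache configured, so only the in-memory layer is used
    key = "my-model-group:rpm:12-05"
    first = await cache.async_increment_cache(key=key, value=1)
    second = await cache.async_increment_cache(key=key, value=1)
    print(first, second)  # expected: 1 2

asyncio.run(main())
```
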


@@ -1836,6 +1836,9 @@ async def _run_background_health_check():
         await asyncio.sleep(health_check_interval)


+semaphore = asyncio.Semaphore(1)
+
+
 class ProxyConfig:
     """
     Abstraction class on top of config loading/updating logic. Gives us one place to control all config updating logic.
@@ -2425,8 +2428,7 @@ class ProxyConfig:
             for k, v in router_settings.items():
                 if k in available_args:
                     router_params[k] = v
-
-            router = litellm.Router(**router_params)  # type:ignore
+            router = litellm.Router(**router_params, semaphore=semaphore)  # type:ignore
         return router, model_list, general_settings

     async def add_deployment(
@@ -3421,6 +3423,7 @@ async def chat_completion(
 ):
     global general_settings, user_debug, proxy_logging_obj, llm_model_list
     try:
+        # async with llm_router.sem
         data = {}
         body = await request.body()
         body_str = body.decode()
@@ -3525,7 +3528,9 @@ async def chat_completion(
         tasks = []
         tasks.append(
             proxy_logging_obj.during_call_hook(
-                data=data, user_api_key_dict=user_api_key_dict, call_type="completion"
+                data=data,
+                user_api_key_dict=user_api_key_dict,
+                call_type="completion",
             )
         )
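
The module-level `semaphore = asyncio.Semaphore(1)` is created and passed to the Router, but the commented-out `# async with llm_router.sem` shows the gating is not actually wired into `chat_completion` yet. For reference, the pattern such a guard would follow looks roughly like this (a generic asyncio sketch, not code from this PR):

```python
import asyncio

semaphore = asyncio.Semaphore(1)  # at most one request inside the guarded section at a time

async def handle_request(i: int) -> str:
    async with semaphore:
        # Other requests queue on the semaphore instead of running concurrently.
        await asyncio.sleep(1)    # stand-in for the actual proxied LLM call
        return f"request {i} done"

async def main():
    results = await asyncio.gather(*(handle_request(i) for i in range(3)))
    print(results)                # takes ~3s total because Semaphore(1) serializes the section

asyncio.run(main())
```
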


@@ -30,7 +30,7 @@ from litellm.utils import ModelResponse, CustomStreamWrapper, get_utc_datetime
 import copy
 from litellm._logging import verbose_router_logger
 import logging
-from litellm.types.router import Deployment, ModelInfo, LiteLLM_Params
+from litellm.types.router import Deployment, ModelInfo, LiteLLM_Params, RouterErrors


 class Router:
@@ -78,6 +78,7 @@ class Router:
             "latency-based-routing",
         ] = "simple-shuffle",
         routing_strategy_args: dict = {},  # just for latency-based routing
+        semaphore: Optional[asyncio.Semaphore] = None,
     ) -> None:
         """
         Initialize the Router class with the given parameters for caching, reliability, and routing strategy.
@@ -143,6 +144,8 @@ class Router:
         router = Router(model_list=model_list, fallbacks=[{"azure-gpt-3.5-turbo": "openai-gpt-3.5-turbo"}])
         ```
         """
+        if semaphore:
+            self.semaphore = semaphore
         self.set_verbose = set_verbose
         self.debug_level = debug_level
         self.enable_pre_call_checks = enable_pre_call_checks
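
The new `semaphore` keyword lets a caller (here, the proxy) hand the Router a shared `asyncio.Semaphore` at construction time; note the attribute is only set when one is passed. A rough construction example, where the model name, API key, and rpm value are placeholders rather than values from this PR:

```python
import asyncio
from litellm import Router

model_list = [
    {
        "model_name": "gpt-3.5-turbo",    # model group name callers route to
        "litellm_params": {
            "model": "gpt-3.5-turbo",     # underlying deployment (illustrative)
            "api_key": "sk-placeholder",  # placeholder key
            "rpm": 100,                   # per-deployment rpm; set_client() sizes a semaphore from this
        },
    }
]

router = Router(
    model_list=model_list,
    routing_strategy="usage-based-routing-v2",
    semaphore=asyncio.Semaphore(1),       # new in this PR; stored as router.semaphore
)
```
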
@@ -409,11 +412,18 @@ class Router:
             raise e

     async def _acompletion(self, model: str, messages: List[Dict[str, str]], **kwargs):
+        """
+        - Get an available deployment
+        - call it with a semaphore over the call
+        - semaphore specific to it's rpm
+        - in the semaphore, make a check against it's local rpm before running
+        """
         model_name = None
         try:
             verbose_router_logger.debug(
                 f"Inside _acompletion()- model: {model}; kwargs: {kwargs}"
             )
+
             deployment = await self.async_get_available_deployment(
                 model=model,
                 messages=messages,
@@ -443,6 +453,7 @@ class Router:
             potential_model_client = self._get_client(
                 deployment=deployment, kwargs=kwargs, client_type="async"
             )
+
             # check if provided keys == client keys #
             dynamic_api_key = kwargs.get("api_key", None)
             if (
@@ -465,7 +476,7 @@ class Router:
                 )  # this uses default_litellm_params when nothing is set
             )

-            response = await litellm.acompletion(
+            _response = litellm.acompletion(
                 **{
                     **data,
                     "messages": messages,
@@ -475,6 +486,25 @@ class Router:
                     **kwargs,
                 }
             )
+
+            rpm_semaphore = self._get_client(
+                deployment=deployment, kwargs=kwargs, client_type="rpm_client"
+            )
+
+            if (
+                rpm_semaphore is not None
+                and isinstance(rpm_semaphore, asyncio.Semaphore)
+                and self.routing_strategy == "usage-based-routing-v2"
+            ):
+                async with rpm_semaphore:
+                    """
+                    - Check rpm limits before making the call
+                    """
+                    await self.lowesttpm_logger_v2.pre_call_rpm_check(deployment)
+                    response = await _response
+            else:
+                response = await _response
+
             self.success_calls[model_name] += 1
             verbose_router_logger.info(
                 f"litellm.acompletion(model={model_name})\033[32m 200 OK\033[0m"
@@ -1265,6 +1295,8 @@ class Router:
                     min_timeout=self.retry_after,
                 )
                 await asyncio.sleep(timeout)
+            elif RouterErrors.user_defined_ratelimit_error.value in str(e):
+                raise e  # don't wait to retry if deployment hits user-defined rate-limit
             elif hasattr(original_exception, "status_code") and litellm._should_retry(
                 status_code=original_exception.status_code
             ):
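
A deployment that trips its user-defined rpm limit is re-raised immediately instead of being retried after a backoff sleep; the match is a substring test because the limit error carries the `RouterErrors` marker string in its body. A small sketch of that fail-fast check (generic exception type, not litellm's `RateLimitError`):

```python
import enum

class RouterErrors(enum.Enum):
    user_defined_ratelimit_error = "Deployment over user-defined ratelimit."

def should_fail_fast(exc: Exception) -> bool:
    # Mirrors the router's check: the marker string is embedded in the error text,
    # so a plain substring test is enough to skip the retry/backoff path.
    return RouterErrors.user_defined_ratelimit_error.value in str(exc)

print(should_fail_fast(RuntimeError("429: Deployment over user-defined ratelimit. rpm limit=10")))  # True
print(should_fail_fast(RuntimeError("502 bad gateway")))  # False
```
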
@@ -1680,12 +1712,26 @@ class Router:
     def set_client(self, model: dict):
         """
-        Initializes Azure/OpenAI clients. Stores them in cache, b/c of this - https://github.com/BerriAI/litellm/issues/1278
+        - Initializes Azure/OpenAI clients. Stores them in cache, b/c of this - https://github.com/BerriAI/litellm/issues/1278
+        - Initializes Semaphore for client w/ rpm. Stores them in cache. b/c of this - https://github.com/BerriAI/litellm/issues/2994
         """
         client_ttl = self.client_ttl
         litellm_params = model.get("litellm_params", {})
         model_name = litellm_params.get("model")
         model_id = model["model_info"]["id"]
+        # ### IF RPM SET - initialize a semaphore ###
+        rpm = litellm_params.get("rpm", None)
+        if rpm:
+            semaphore = asyncio.Semaphore(rpm)
+            cache_key = f"{model_id}_rpm_client"
+            self.cache.set_cache(
+                key=cache_key,
+                value=semaphore,
+                local_only=True,
+            )
+            # print("STORES SEMAPHORE IN CACHE")
+
         #### for OpenAI / Azure we need to initalize the Client for High Traffic ########
         custom_llm_provider = litellm_params.get("custom_llm_provider")
         custom_llm_provider = custom_llm_provider or model_name.split("/", 1)[0] or ""
@@ -2275,7 +2321,11 @@ class Router:
             The appropriate client based on the given client_type and kwargs.
         """
         model_id = deployment["model_info"]["id"]
-        if client_type == "async":
+        if client_type == "rpm_client":
+            cache_key = "{}_rpm_client".format(model_id)
+            client = self.cache.get_cache(key=cache_key, local_only=True)
+            return client
+        elif client_type == "async":
             if kwargs.get("stream") == True:
                 cache_key = f"{model_id}_stream_async_client"
                 client = self.cache.get_cache(key=cache_key, local_only=True)
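
Together with the `set_client()` change above, each deployment that declares an `rpm` gets an `asyncio.Semaphore(rpm)` stored under `{model_id}_rpm_client`, and `_get_client(..., client_type="rpm_client")` looks it back up from the local cache. A stripped-down sketch of that store/lookup round trip, using a plain dict in place of the router cache:

```python
import asyncio
from typing import Optional

_local_cache: dict[str, asyncio.Semaphore] = {}  # stand-in for the router's in-memory cache

def set_rpm_client(model_id: str, rpm: Optional[int]) -> None:
    # Mirrors set_client(): only deployments that declare an rpm get a semaphore, sized to that rpm.
    if rpm:
        _local_cache[f"{model_id}_rpm_client"] = asyncio.Semaphore(rpm)

def get_rpm_client(model_id: str) -> Optional[asyncio.Semaphore]:
    # Mirrors _get_client(client_type="rpm_client"): returns None when no rpm was configured.
    return _local_cache.get(f"{model_id}_rpm_client")

set_rpm_client("my-deployment-id", rpm=100)
print(get_rpm_client("my-deployment-id"))       # <asyncio.locks.Semaphore ... [unlocked, value:100]>
print(get_rpm_client("deployment-without-rpm"))  # None
```
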
@@ -2328,6 +2378,7 @@ class Router:
         Filter out model in model group, if:

         - model context window < message length
+        - filter models above rpm limits
         - [TODO] function call and model doesn't support function calling
         """
         verbose_router_logger.debug(
@@ -2352,7 +2403,7 @@ class Router:
             rpm_key = f"{model}:rpm:{current_minute}"
             model_group_cache = (
                 self.cache.get_cache(key=rpm_key, local_only=True) or {}
-            )  # check the redis + in-memory cache used by lowest_latency and usage-based routing. Only check the local cache.
+            )  # check the in-memory cache used by lowest_latency and usage-based routing. Only check the local cache.
             for idx, deployment in enumerate(_returned_deployments):
                 # see if we have the info for this model
                 try:
@@ -2388,6 +2439,7 @@ class Router:
                     self.cache.get_cache(key=model_id, local_only=True) or 0
                 )
                 ### get usage based cache ###
+                if isinstance(model_group_cache, dict):
                     model_group_cache[model_id] = model_group_cache.get(model_id, 0)

                 current_request = max(
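
This hunk is cut off at the hunk boundary, but the shape of the new pre-call filter is visible: the per-deployment request count is taken as the max of the local counter and the shared usage cache, and deployments at or over their rpm are filtered out. A hedged reconstruction of just that comparison (the exact expression in the router may differ):

```python
def over_rpm_limit(local_count: int, model_group_cache: dict, model_id: str, rpm: float) -> bool:
    # Take the higher of the purely local counter and the shared usage-cache entry for this deployment.
    current_request = max(local_count, model_group_cache.get(model_id, 0))
    # The exact comparison in the router may differ; the idea is to skip deployments already at their rpm.
    return current_request >= rpm

print(over_rpm_limit(local_count=3, model_group_cache={"deployment-1": 10}, model_id="deployment-1", rpm=10))  # True
print(over_rpm_limit(local_count=3, model_group_cache={}, model_id="deployment-1", rpm=10))                    # False
```
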


@@ -7,12 +7,14 @@ import datetime as datetime_og
 from datetime import datetime

 dotenv.load_dotenv()  # Loading env variables using dotenv
-import traceback, asyncio
+import traceback, asyncio, httpx
+import litellm
 from litellm import token_counter
 from litellm.caching import DualCache
 from litellm.integrations.custom_logger import CustomLogger
 from litellm._logging import verbose_router_logger
 from litellm.utils import print_verbose, get_utc_datetime
+from litellm.types.router import RouterErrors


 class LowestTPMLoggingHandler_v2(CustomLogger):
@@ -37,6 +39,86 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
         self.router_cache = router_cache
         self.model_list = model_list

+    async def pre_call_rpm_check(self, deployment: dict) -> dict:
+        """
+        Pre-call check + update model rpm
+        - Used inside semaphore
+        - raise rate limit error if deployment over limit
+
+        Why? solves concurrency issue - https://github.com/BerriAI/litellm/issues/2994
+
+        Returns - deployment
+
+        Raises - RateLimitError if deployment over defined RPM limit
+        """
+        try:
+            # ------------
+            # Setup values
+            # ------------
+            dt = get_utc_datetime()
+            current_minute = dt.strftime("%H-%M")
+            model_group = deployment.get("model_name", "")
+            rpm_key = f"{model_group}:rpm:{current_minute}"
+            local_result = await self.router_cache.async_get_cache(
+                key=rpm_key, local_only=True
+            )  # check local result first
+
+            deployment_rpm = None
+            if deployment_rpm is None:
+                deployment_rpm = deployment.get("rpm")
+            if deployment_rpm is None:
+                deployment_rpm = deployment.get("litellm_params", {}).get("rpm")
+            if deployment_rpm is None:
+                deployment_rpm = deployment.get("model_info", {}).get("rpm")
+            if deployment_rpm is None:
+                deployment_rpm = float("inf")
+
+            if local_result is not None and local_result >= deployment_rpm:
+                raise litellm.RateLimitError(
+                    message="Deployment over defined rpm limit={}. current usage={}".format(
+                        deployment_rpm, local_result
+                    ),
+                    llm_provider="",
+                    model=deployment.get("litellm_params", {}).get("model"),
+                    response=httpx.Response(
+                        status_code=429,
+                        content="{} rpm limit={}. current usage={}".format(
+                            RouterErrors.user_defined_ratelimit_error.value,
+                            deployment_rpm,
+                            local_result,
+                        ),
+                        request=httpx.Request(method="tpm_rpm_limits", url="https://github.com/BerriAI/litellm"),  # type: ignore
+                    ),
+                )
+            else:
+                # if local result below limit, check redis ## prevent unnecessary redis checks
+                result = await self.router_cache.async_increment_cache(
+                    key=rpm_key, value=1
+                )
+                if result is not None and result > deployment_rpm:
+                    raise litellm.RateLimitError(
+                        message="Deployment over defined rpm limit={}. current usage={}".format(
+                            deployment_rpm, result
+                        ),
+                        llm_provider="",
+                        model=deployment.get("litellm_params", {}).get("model"),
+                        response=httpx.Response(
+                            status_code=429,
+                            content="{} rpm limit={}. current usage={}".format(
+                                RouterErrors.user_defined_ratelimit_error.value,
+                                deployment_rpm,
+                                result,
+                            ),
+                            request=httpx.Request(method="tpm_rpm_limits", url="https://github.com/BerriAI/litellm"),  # type: ignore
+                        ),
+                    )
+
+            return deployment
+        except Exception as e:
+            if isinstance(e, litellm.RateLimitError):
+                raise e
+            return deployment  # don't fail calls if eg. redis fails to connect
+
     def log_success_event(self, kwargs, response_obj, start_time, end_time):
         try:
             """
@@ -91,7 +173,7 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
     async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
         try:
             """
-            Update TPM/RPM usage on success
+            Update TPM usage on success
             """
             if kwargs["litellm_params"].get("metadata") is None:
                 pass
@@ -117,8 +199,6 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
             )  # use the same timezone regardless of system clock

             tpm_key = f"{id}:tpm:{current_minute}"
-            rpm_key = f"{id}:rpm:{current_minute}"
-
             # ------------
             # Update usage
             # ------------
@@ -128,8 +208,6 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
             await self.router_cache.async_increment_cache(
                 key=tpm_key, value=total_tokens
             )
-            ## RPM
-            await self.router_cache.async_increment_cache(key=rpm_key, value=1)

             ### TESTING ###
             if self.test_flag:


@@ -3,7 +3,7 @@ from typing import List, Optional, Union, Dict, Tuple, Literal
 from pydantic import BaseModel, validator
 from .completion import CompletionRequest
 from .embedding import EmbeddingRequest
-import uuid
+import uuid, enum


 class ModelConfig(BaseModel):
@@ -166,3 +166,11 @@ class Deployment(BaseModel):
     def __setitem__(self, key, value):
         # Allow dictionary-style assignment of attributes
         setattr(self, key, value)
+
+
+class RouterErrors(enum.Enum):
+    """
+    Enum for router specific errors with common codes
+    """
+
+    user_defined_ratelimit_error = "Deployment over user-defined ratelimit."


@@ -67,12 +67,12 @@ litellm_settings:
   telemetry: False
   context_window_fallbacks: [{"gpt-3.5-turbo": ["gpt-3.5-turbo-large"]}]

-# router_settings:
-#   routing_strategy: usage-based-routing-v2
-#   redis_host: os.environ/REDIS_HOST
-#   redis_password: os.environ/REDIS_PASSWORD
-#   redis_port: os.environ/REDIS_PORT
-#   enable_pre_call_checks: true
+router_settings:
+  routing_strategy: usage-based-routing-v2
+  redis_host: os.environ/REDIS_HOST
+  redis_password: os.environ/REDIS_PASSWORD
+  redis_port: os.environ/REDIS_PORT
+  enable_pre_call_checks: true

 general_settings:
   master_key: sk-1234 # [OPTIONAL] Use to enforce auth on proxy. See - https://docs.litellm.ai/docs/proxy/virtual_keys


@@ -194,7 +194,7 @@ async def test_chat_completion():
             await chat_completion(session=session, key=key_2)


-@pytest.mark.skip(reason="Local test. Proxy not concurrency safe yet. WIP.")
+# @pytest.mark.skip(reason="Local test. Proxy not concurrency safe yet. WIP.")
 @pytest.mark.asyncio
 async def test_chat_completion_ratelimit():
     """