Merge pull request #2996 from BerriAI/litellm_semaphores

fix(router.py): initial commit for semaphores on router
Krish Dholakia 2024-04-12 23:23:36 -07:00 committed by GitHub
commit fd7760d3db
8 changed files with 198 additions and 41 deletions

View file

@@ -189,6 +189,9 @@ jobs:
-p 4000:4000 \
-e DATABASE_URL=$PROXY_DOCKER_DB_URL \
-e AZURE_API_KEY=$AZURE_API_KEY \
-e REDIS_HOST=$REDIS_HOST \
-e REDIS_PASSWORD=$REDIS_PASSWORD \
-e REDIS_PORT=$REDIS_PORT \
-e AZURE_FRANCE_API_KEY=$AZURE_FRANCE_API_KEY \
-e AZURE_EUROPE_API_KEY=$AZURE_EUROPE_API_KEY \
-e AWS_ACCESS_KEY_ID=$AWS_ACCESS_KEY_ID \

View file

@@ -98,11 +98,12 @@ class InMemoryCache(BaseCache):
return_val.append(val)
return return_val
async def async_increment(self, key, value: int, **kwargs):
async def async_increment(self, key, value: int, **kwargs) -> int:
# get the value
init_value = await self.async_get_cache(key=key) or 0
value = init_value + value
await self.async_set_cache(key, value, **kwargs)
return value
def flush_cache(self):
self.cache_dict.clear()
@@ -266,11 +267,12 @@ class RedisCache(BaseCache):
if len(self.redis_batch_writing_buffer) >= self.redis_flush_size:
await self.flush_cache_buffer()
async def async_increment(self, key, value: int, **kwargs):
async def async_increment(self, key, value: int, **kwargs) -> int:
_redis_client = self.init_async_client()
try:
async with _redis_client as redis_client:
await redis_client.incr(name=key, amount=value)
result = await redis_client.incr(name=key, amount=value)
return result
except Exception as e:
verbose_logger.error(
"LiteLLM Redis Caching: async async_increment() - Got exception from REDIS %s, Writing value=%s",
@@ -278,6 +280,7 @@ class RedisCache(BaseCache):
value,
)
traceback.print_exc()
raise e
async def flush_cache_buffer(self):
print_verbose(
@@ -1076,21 +1079,29 @@ class DualCache(BaseCache):
async def async_increment_cache(
self, key, value: int, local_only: bool = False, **kwargs
):
) -> int:
"""
Key - the key in cache
Value - int - the value you want to increment by
Returns - int - the incremented value
"""
try:
result: int = value
if self.in_memory_cache is not None:
await self.in_memory_cache.async_increment(key, value, **kwargs)
result = await self.in_memory_cache.async_increment(
key, value, **kwargs
)
if self.redis_cache is not None and local_only == False:
await self.redis_cache.async_increment(key, value, **kwargs)
result = await self.redis_cache.async_increment(key, value, **kwargs)
return result
except Exception as e:
print_verbose(f"LiteLLM Cache: Exception async add_cache: {str(e)}")
traceback.print_exc()
raise e
def flush_cache(self):
if self.in_memory_cache is not None:

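The caching change above is purely about return values: async_increment on both cache layers now returns the incremented counter, and DualCache returns the Redis result when a Redis cache is configured. A minimal, self-contained sketch of that pattern, using plain in-memory counters as stand-ins for litellm's InMemoryCache/RedisCache:

```python
# Sketch only: InMemoryCounter/DualCounter are illustrative stand-ins,
# not the litellm cache classes.
import asyncio
from typing import Dict, Optional

class InMemoryCounter:
    def __init__(self) -> None:
        self._store: Dict[str, int] = {}

    async def async_increment(self, key: str, value: int) -> int:
        # read-modify-write on the local dict; return the new value
        new_value = self._store.get(key, 0) + value
        self._store[key] = new_value
        return new_value

class DualCounter:
    """Increment both layers; the shared layer's value wins when it exists."""

    def __init__(self, local: InMemoryCounter, shared: Optional[InMemoryCounter] = None) -> None:
        self.local = local
        self.shared = shared

    async def async_increment(self, key: str, value: int) -> int:
        result = await self.local.async_increment(key, value)
        if self.shared is not None:
            result = await self.shared.async_increment(key, value)
        return result

async def main() -> None:
    cache = DualCounter(local=InMemoryCounter())
    print(await cache.async_increment("gpt-4:rpm:23-10", 1))  # 1
    print(await cache.async_increment("gpt-4:rpm:23-10", 1))  # 2

asyncio.run(main())
```

Returning the value matters because the rate-limit check further down needs to compare the post-increment counter against the deployment's rpm without a second cache read.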
View file

@@ -1836,6 +1836,9 @@ async def _run_background_health_check():
await asyncio.sleep(health_check_interval)
semaphore = asyncio.Semaphore(1)
class ProxyConfig:
"""
Abstraction class on top of config loading/updating logic. Gives us one place to control all config updating logic.
@@ -2425,8 +2428,7 @@ class ProxyConfig:
for k, v in router_settings.items():
if k in available_args:
router_params[k] = v
router = litellm.Router(**router_params) # type:ignore
router = litellm.Router(**router_params, semaphore=semaphore) # type:ignore
return router, model_list, general_settings
async def add_deployment(
@@ -3421,6 +3423,7 @@ async def chat_completion(
):
global general_settings, user_debug, proxy_logging_obj, llm_model_list
try:
# async with llm_router.sem
data = {}
body = await request.body()
body_str = body.decode()
@@ -3525,7 +3528,9 @@
tasks = []
tasks.append(
proxy_logging_obj.during_call_hook(
data=data, user_api_key_dict=user_api_key_dict, call_type="completion"
data=data,
user_api_key_dict=user_api_key_dict,
call_type="completion",
)
)

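For orientation, the proxy change boils down to: create one module-level asyncio.Semaphore and hand it to the Router constructor. A rough sketch of that wiring, with a stand-in router class (not litellm.Router):

```python
# Sketch of the wiring only; SimpleRouter stands in for litellm.Router.
import asyncio
from typing import List, Optional

semaphore = asyncio.Semaphore(1)  # module-level, shared across the proxy process

class SimpleRouter:
    def __init__(
        self,
        model_list: List[dict],
        semaphore: Optional[asyncio.Semaphore] = None,
    ) -> None:
        if semaphore:
            self.semaphore = semaphore
        self.model_list = model_list

router = SimpleRouter(model_list=[], semaphore=semaphore)
print(hasattr(router, "semaphore"))  # True
```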
View file

@@ -30,7 +30,7 @@ from litellm.utils import ModelResponse, CustomStreamWrapper, get_utc_datetime
import copy
from litellm._logging import verbose_router_logger
import logging
from litellm.types.router import Deployment, ModelInfo, LiteLLM_Params
from litellm.types.router import Deployment, ModelInfo, LiteLLM_Params, RouterErrors
class Router:
@@ -78,6 +78,7 @@ class Router:
"latency-based-routing",
] = "simple-shuffle",
routing_strategy_args: dict = {}, # just for latency-based routing
semaphore: Optional[asyncio.Semaphore] = None,
) -> None:
"""
Initialize the Router class with the given parameters for caching, reliability, and routing strategy.
@@ -143,6 +144,8 @@ class Router:
router = Router(model_list=model_list, fallbacks=[{"azure-gpt-3.5-turbo": "openai-gpt-3.5-turbo"}])
```
"""
if semaphore:
self.semaphore = semaphore
self.set_verbose = set_verbose
self.debug_level = debug_level
self.enable_pre_call_checks = enable_pre_call_checks
@@ -409,11 +412,18 @@ class Router:
raise e
async def _acompletion(self, model: str, messages: List[Dict[str, str]], **kwargs):
"""
- Get an available deployment
- Call it, with the call wrapped in a semaphore
- The semaphore is specific to the deployment's rpm
- Inside the semaphore, check against its local rpm count before running
"""
model_name = None
try:
verbose_router_logger.debug(
f"Inside _acompletion()- model: {model}; kwargs: {kwargs}"
)
deployment = await self.async_get_available_deployment(
model=model,
messages=messages,
@@ -443,6 +453,7 @@ class Router:
potential_model_client = self._get_client(
deployment=deployment, kwargs=kwargs, client_type="async"
)
# check if provided keys == client keys #
dynamic_api_key = kwargs.get("api_key", None)
if (
@@ -465,7 +476,7 @@ class Router:
) # this uses default_litellm_params when nothing is set
)
response = await litellm.acompletion(
_response = litellm.acompletion(
**{
**data,
"messages": messages,
@@ -475,6 +486,25 @@ class Router:
**kwargs,
}
)
rpm_semaphore = self._get_client(
deployment=deployment, kwargs=kwargs, client_type="rpm_client"
)
if (
rpm_semaphore is not None
and isinstance(rpm_semaphore, asyncio.Semaphore)
and self.routing_strategy == "usage-based-routing-v2"
):
async with rpm_semaphore:
"""
- Check rpm limits before making the call
"""
await self.lowesttpm_logger_v2.pre_call_rpm_check(deployment)
response = await _response
else:
response = await _response
self.success_calls[model_name] += 1
verbose_router_logger.info(
f"litellm.acompletion(model={model_name})\033[32m 200 OK\033[0m"
@@ -1265,6 +1295,8 @@ class Router:
min_timeout=self.retry_after,
)
await asyncio.sleep(timeout)
elif RouterErrors.user_defined_ratelimit_error.value in str(e):
raise e # don't wait to retry if deployment hits user-defined rate-limit
elif hasattr(original_exception, "status_code") and litellm._should_retry(
status_code=original_exception.status_code
):
@@ -1680,12 +1712,26 @@ class Router:
def set_client(self, model: dict):
"""
Initializes Azure/OpenAI clients. Stores them in cache, b/c of this - https://github.com/BerriAI/litellm/issues/1278
- Initializes Azure/OpenAI clients. Stores them in cache, b/c of this - https://github.com/BerriAI/litellm/issues/1278
- Initializes Semaphore for client w/ rpm. Stores them in cache. b/c of this - https://github.com/BerriAI/litellm/issues/2994
"""
client_ttl = self.client_ttl
litellm_params = model.get("litellm_params", {})
model_name = litellm_params.get("model")
model_id = model["model_info"]["id"]
# ### IF RPM SET - initialize a semaphore ###
rpm = litellm_params.get("rpm", None)
if rpm:
semaphore = asyncio.Semaphore(rpm)
cache_key = f"{model_id}_rpm_client"
self.cache.set_cache(
key=cache_key,
value=semaphore,
local_only=True,
)
# print("STORES SEMAPHORE IN CACHE")
#### for OpenAI / Azure we need to initialize the Client for High Traffic ########
custom_llm_provider = litellm_params.get("custom_llm_provider")
custom_llm_provider = custom_llm_provider or model_name.split("/", 1)[0] or ""
@@ -2275,7 +2321,11 @@ class Router:
The appropriate client based on the given client_type and kwargs.
"""
model_id = deployment["model_info"]["id"]
if client_type == "async":
if client_type == "rpm_client":
cache_key = "{}_rpm_client".format(model_id)
client = self.cache.get_cache(key=cache_key, local_only=True)
return client
elif client_type == "async":
if kwargs.get("stream") == True:
cache_key = f"{model_id}_stream_async_client"
client = self.cache.get_cache(key=cache_key, local_only=True)
@@ -2328,6 +2378,7 @@ class Router:
Filter out model in model group, if:
- model context window < message length
- model is above its rpm limit
- [TODO] function call and model doesn't support function calling
"""
verbose_router_logger.debug(
@@ -2352,7 +2403,7 @@ class Router:
rpm_key = f"{model}:rpm:{current_minute}"
model_group_cache = (
self.cache.get_cache(key=rpm_key, local_only=True) or {}
) # check the redis + in-memory cache used by lowest_latency and usage-based routing. Only check the local cache.
) # check the in-memory cache used by lowest_latency and usage-based routing. Only check the local cache.
for idx, deployment in enumerate(_returned_deployments):
# see if we have the info for this model
try:
@@ -2388,6 +2439,7 @@ class Router:
self.cache.get_cache(key=model_id, local_only=True) or 0
)
### get usage based cache ###
if isinstance(model_group_cache, dict):
model_group_cache[model_id] = model_group_cache.get(model_id, 0)
current_request = max(

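Tying the router pieces together: set_client() creates an asyncio.Semaphore sized to the deployment's rpm and caches it under `{model_id}_rpm_client`; _get_client() retrieves it with client_type="rpm_client"; _acompletion() builds the completion coroutine, acquires the semaphore, runs the rpm pre-call check inside it, and only then awaits the call. A simplified, self-contained sketch of that flow (the Deployment object, check, and fake completion below are stand-ins for the litellm internals, not the real classes):

```python
# Sketch only: Deployment, pre_call_rpm_check, and fake_completion are
# illustrative stand-ins for the router's cache-backed internals.
import asyncio
from typing import Optional

class RateLimitError(Exception):
    pass

class Deployment:
    def __init__(self, model_id: str, rpm: Optional[int]) -> None:
        self.model_id = model_id
        self.rpm = rpm
        # semaphore sized to the deployment's rpm, created once per deployment
        self.rpm_semaphore = asyncio.Semaphore(rpm) if rpm else None
        self.requests_this_minute = 0

async def pre_call_rpm_check(deployment: Deployment) -> None:
    # stand-in for the cache-backed check; raise if the deployment is over its limit
    if deployment.rpm is not None and deployment.requests_this_minute >= deployment.rpm:
        raise RateLimitError(f"{deployment.model_id} over rpm limit={deployment.rpm}")
    deployment.requests_this_minute += 1

async def fake_completion(deployment: Deployment, prompt: str) -> str:
    await asyncio.sleep(0)  # placeholder for the real LLM call
    return f"{deployment.model_id}: ok"

async def call_deployment(deployment: Deployment, prompt: str) -> str:
    # build the coroutine first, await it only after the rpm check passes
    _response = fake_completion(deployment, prompt)
    if deployment.rpm_semaphore is None:
        return await _response
    async with deployment.rpm_semaphore:
        try:
            await pre_call_rpm_check(deployment)
        except RateLimitError:
            _response.close()  # drop the never-awaited coroutine
            raise
        return await _response

async def main() -> None:
    dep = Deployment("azure-gpt-35", rpm=2)
    results = await asyncio.gather(
        *[call_deployment(dep, "hi") for _ in range(3)], return_exceptions=True
    )
    print(results)  # two responses, one RateLimitError

asyncio.run(main())
```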
View file

@@ -7,12 +7,14 @@ import datetime as datetime_og
from datetime import datetime
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback, asyncio
import traceback, asyncio, httpx
import litellm
from litellm import token_counter
from litellm.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger
from litellm._logging import verbose_router_logger
from litellm.utils import print_verbose, get_utc_datetime
from litellm.types.router import RouterErrors
class LowestTPMLoggingHandler_v2(CustomLogger):
@@ -37,6 +39,86 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
self.router_cache = router_cache
self.model_list = model_list
async def pre_call_rpm_check(self, deployment: dict) -> dict:
"""
Pre-call check + update model rpm
- Used inside the semaphore
- Raises a rate limit error if the deployment is over its limit
Why? Solves the concurrency issue - https://github.com/BerriAI/litellm/issues/2994
Returns - deployment
Raises - RateLimitError if the deployment is over its defined RPM limit
"""
try:
# ------------
# Setup values
# ------------
dt = get_utc_datetime()
current_minute = dt.strftime("%H-%M")
model_group = deployment.get("model_name", "")
rpm_key = f"{model_group}:rpm:{current_minute}"
local_result = await self.router_cache.async_get_cache(
key=rpm_key, local_only=True
) # check local result first
deployment_rpm = None
if deployment_rpm is None:
deployment_rpm = deployment.get("rpm")
if deployment_rpm is None:
deployment_rpm = deployment.get("litellm_params", {}).get("rpm")
if deployment_rpm is None:
deployment_rpm = deployment.get("model_info", {}).get("rpm")
if deployment_rpm is None:
deployment_rpm = float("inf")
if local_result is not None and local_result >= deployment_rpm:
raise litellm.RateLimitError(
message="Deployment over defined rpm limit={}. current usage={}".format(
deployment_rpm, local_result
),
llm_provider="",
model=deployment.get("litellm_params", {}).get("model"),
response=httpx.Response(
status_code=429,
content="{} rpm limit={}. current usage={}".format(
RouterErrors.user_defined_ratelimit_error.value,
deployment_rpm,
local_result,
),
request=httpx.Request(method="tpm_rpm_limits", url="https://github.com/BerriAI/litellm"), # type: ignore
),
)
else:
# if local result below limit, check redis ## prevent unnecessary redis checks
result = await self.router_cache.async_increment_cache(
key=rpm_key, value=1
)
if result is not None and result > deployment_rpm:
raise litellm.RateLimitError(
message="Deployment over defined rpm limit={}. current usage={}".format(
deployment_rpm, result
),
llm_provider="",
model=deployment.get("litellm_params", {}).get("model"),
response=httpx.Response(
status_code=429,
content="{} rpm limit={}. current usage={}".format(
RouterErrors.user_defined_ratelimit_error.value,
deployment_rpm,
result,
),
request=httpx.Request(method="tpm_rpm_limits", url="https://github.com/BerriAI/litellm"), # type: ignore
),
)
return deployment
except Exception as e:
if isinstance(e, litellm.RateLimitError):
raise e
return deployment # don't fail calls if eg. redis fails to connect
def log_success_event(self, kwargs, response_obj, start_time, end_time):
try:
"""
@@ -91,7 +173,7 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
try:
"""
Update TPM/RPM usage on success
Update TPM usage on success
"""
if kwargs["litellm_params"].get("metadata") is None:
pass
@@ -117,8 +199,6 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
) # use the same timezone regardless of system clock
tpm_key = f"{id}:tpm:{current_minute}"
rpm_key = f"{id}:rpm:{current_minute}"
# ------------
# Update usage
# ------------
@@ -128,8 +208,6 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
await self.router_cache.async_increment_cache(
key=tpm_key, value=total_tokens
)
## RPM
await self.router_cache.async_increment_cache(key=rpm_key, value=1)
### TESTING ###
if self.test_flag:

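The reason pre_call_rpm_check runs inside the semaphore (issue #2994): without serialization, N concurrent requests can all read the same counter value before any of them increments it, so all of them pass the limit check. A toy sketch of that race and the fix, using a plain integer in place of the router cache and a single-permit semaphore purely to show the serialization effect:

```python
# Toy demonstration only: a global int stands in for the router cache's
# per-minute rpm counter.
import asyncio
from typing import Optional

LIMIT = 5
counter = 0

async def check_and_increment() -> bool:
    global counter
    if counter >= LIMIT:
        return False          # the real handler raises RateLimitError here
    await asyncio.sleep(0)    # simulates the await on the cache read/write
    counter += 1
    return True

async def run(n: int, semaphore: Optional[asyncio.Semaphore]) -> int:
    global counter
    counter = 0

    async def one() -> bool:
        if semaphore is None:
            return await check_and_increment()
        async with semaphore:  # serialize check + increment, as the router does
            return await check_and_increment()

    results = await asyncio.gather(*[one() for _ in range(n)])
    return sum(results)

async def main() -> None:
    allowed_without = await run(20, semaphore=None)
    allowed_with = await run(20, semaphore=asyncio.Semaphore(1))
    print(allowed_without, allowed_with)  # e.g. 20 5 -- the race lets everything through

asyncio.run(main())
```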
View file

@@ -3,7 +3,7 @@ from typing import List, Optional, Union, Dict, Tuple, Literal
from pydantic import BaseModel, validator
from .completion import CompletionRequest
from .embedding import EmbeddingRequest
import uuid
import uuid, enum
class ModelConfig(BaseModel):
@@ -166,3 +166,11 @@ class Deployment(BaseModel):
def __setitem__(self, key, value):
# Allow dictionary-style assignment of attributes
setattr(self, key, value)
class RouterErrors(enum.Enum):
"""
Enum for router-specific errors with common codes
"""
user_defined_ratelimit_error = "Deployment over user-defined ratelimit."

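The new RouterErrors enum gives the retry path a stable string to match on (see the user_defined_ratelimit_error branch added in the router retry logic above), so user-defined rate-limit failures fail fast instead of waiting out a retry backoff. A small sketch of that matching pattern; the should_wait_and_retry helper is illustrative, not part of the PR:

```python
import enum

class RouterErrors(enum.Enum):
    """Enum for router-specific errors with common codes."""

    user_defined_ratelimit_error = "Deployment over user-defined ratelimit."

def should_wait_and_retry(exc: Exception) -> bool:
    # fail fast on user-defined rate limits; other errors may still be retried
    if RouterErrors.user_defined_ratelimit_error.value in str(exc):
        return False
    return True

err = Exception(
    "Deployment over user-defined ratelimit. rpm limit=10. current usage=11"
)
print(should_wait_and_retry(err))  # False
```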
View file

@@ -67,12 +67,12 @@ litellm_settings:
telemetry: False
context_window_fallbacks: [{"gpt-3.5-turbo": ["gpt-3.5-turbo-large"]}]
# router_settings:
# routing_strategy: usage-based-routing-v2
# redis_host: os.environ/REDIS_HOST
# redis_password: os.environ/REDIS_PASSWORD
# redis_port: os.environ/REDIS_PORT
# enable_pre_call_checks: true
router_settings:
routing_strategy: usage-based-routing-v2
redis_host: os.environ/REDIS_HOST
redis_password: os.environ/REDIS_PASSWORD
redis_port: os.environ/REDIS_PORT
enable_pre_call_checks: true
general_settings:
master_key: sk-1234 # [OPTIONAL] Use to enforce auth on proxy. See - https://docs.litellm.ai/docs/proxy/virtual_keys

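Uncommenting router_settings here turns on usage-based-routing-v2 with Redis-backed counters for the proxy. For SDK users, a hedged sketch of the roughly equivalent Router construction (the deployment name, keys, and rpm value below are placeholders, not values from this PR):

```python
import os
import litellm

# Rough SDK-side equivalent of the router_settings block above.
# All deployment values are placeholders.
router = litellm.Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {
                "model": "azure/my-azure-deployment",       # placeholder deployment
                "api_key": os.environ.get("AZURE_API_KEY"),
                "api_base": os.environ.get("AZURE_API_BASE"),
                "rpm": 100,  # per-deployment rpm -> set_client() also creates an rpm semaphore
            },
        }
    ],
    routing_strategy="usage-based-routing-v2",
    redis_host=os.environ.get("REDIS_HOST"),
    redis_password=os.environ.get("REDIS_PASSWORD"),
    redis_port=int(os.environ.get("REDIS_PORT", "6379")),
    enable_pre_call_checks=True,
)
```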
View file

@@ -194,7 +194,7 @@ async def test_chat_completion():
await chat_completion(session=session, key=key_2)
@pytest.mark.skip(reason="Local test. Proxy not concurrency safe yet. WIP.")
# @pytest.mark.skip(reason="Local test. Proxy not concurrency safe yet. WIP.")
@pytest.mark.asyncio
async def test_chat_completion_ratelimit():
"""