Merge pull request #2996 from BerriAI/litellm_semaphores
fix(router.py): initial commit for semaphores on router
commit fd7760d3db
8 changed files with 198 additions and 41 deletions
@@ -189,6 +189,9 @@ jobs:
           -p 4000:4000 \
           -e DATABASE_URL=$PROXY_DOCKER_DB_URL \
           -e AZURE_API_KEY=$AZURE_API_KEY \
+          -e REDIS_HOST=$REDIS_HOST \
+          -e REDIS_PASSWORD=$REDIS_PASSWORD \
+          -e REDIS_PORT=$REDIS_PORT \
          -e AZURE_FRANCE_API_KEY=$AZURE_FRANCE_API_KEY \
          -e AZURE_EUROPE_API_KEY=$AZURE_EUROPE_API_KEY \
          -e AWS_ACCESS_KEY_ID=$AWS_ACCESS_KEY_ID \
@@ -98,11 +98,12 @@ class InMemoryCache(BaseCache):
                 return_val.append(val)
         return return_val

-    async def async_increment(self, key, value: int, **kwargs):
+    async def async_increment(self, key, value: int, **kwargs) -> int:
         # get the value
         init_value = await self.async_get_cache(key=key) or 0
         value = init_value + value
         await self.async_set_cache(key, value, **kwargs)
+        return value

     def flush_cache(self):
         self.cache_dict.clear()
@@ -266,11 +267,12 @@ class RedisCache(BaseCache):
         if len(self.redis_batch_writing_buffer) >= self.redis_flush_size:
             await self.flush_cache_buffer()

-    async def async_increment(self, key, value: int, **kwargs):
+    async def async_increment(self, key, value: int, **kwargs) -> int:
         _redis_client = self.init_async_client()
         try:
             async with _redis_client as redis_client:
-                await redis_client.incr(name=key, amount=value)
+                result = await redis_client.incr(name=key, amount=value)
+                return result
         except Exception as e:
             verbose_logger.error(
                 "LiteLLM Redis Caching: async async_increment() - Got exception from REDIS %s, Writing value=%s",
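A note on the RedisCache change: Redis INCRBY is atomic and already returns the post-increment value, which is what `async_increment` now propagates to its caller. A minimal standalone sketch of the same call using redis-py's asyncio client (assumes redis-py 5.x and a Redis server on localhost:6379; the key name is illustrative):

```python
import asyncio

import redis.asyncio as redis  # redis-py asyncio client


async def main() -> None:
    client = redis.Redis(host="localhost", port=6379)
    try:
        # INCRBY is atomic and returns the new value, which the updated
        # async_increment now hands back to its caller
        new_value = await client.incr(name="gpt-3.5-turbo:rpm:10-42", amount=1)
        print("current usage:", new_value)
    finally:
        await client.aclose()  # redis-py 5.x; older versions use close()


asyncio.run(main())
```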
@@ -278,6 +280,7 @@ class RedisCache(BaseCache):
                 value,
             )
             traceback.print_exc()
+            raise e

     async def flush_cache_buffer(self):
         print_verbose(
@@ -1076,21 +1079,29 @@ class DualCache(BaseCache):

     async def async_increment_cache(
         self, key, value: int, local_only: bool = False, **kwargs
-    ):
+    ) -> int:
         """
         Key - the key in cache

         Value - int - the value you want to increment by

+        Returns - int - the incremented value
         """
         try:
+            result: int = value
             if self.in_memory_cache is not None:
-                await self.in_memory_cache.async_increment(key, value, **kwargs)
+                result = await self.in_memory_cache.async_increment(
+                    key, value, **kwargs
+                )

             if self.redis_cache is not None and local_only == False:
-                await self.redis_cache.async_increment(key, value, **kwargs)
+                result = await self.redis_cache.async_increment(key, value, **kwargs)

+            return result
         except Exception as e:
             print_verbose(f"LiteLLM Cache: Excepton async add_cache: {str(e)}")
             traceback.print_exc()
+            raise e

     def flush_cache(self):
         if self.in_memory_cache is not None:
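Returning the incremented value from `async_increment_cache` is what lets callers compare current usage against a limit without a second cache read. A standalone sketch of the same dual-layer pattern (bump a local counter, optionally bump a shared one, hand the result back), using plain dicts instead of litellm's caches; all names here are illustrative:

```python
import asyncio
from typing import Dict, Optional


class TinyDualCounter:
    """Toy stand-in for the dual-cache increment pattern: local dict first, optional shared store second."""

    def __init__(self, shared: Optional[Dict[str, int]] = None) -> None:
        self.local: Dict[str, int] = {}
        self.shared = shared  # pretend this is the Redis-backed counter

    async def async_increment(self, key: str, value: int) -> int:
        # local bump
        result = self.local.get(key, 0) + value
        self.local[key] = result
        # shared bump; the shared view wins, like the DualCache change above
        if self.shared is not None:
            self.shared[key] = self.shared.get(key, 0) + value
            result = self.shared[key]
        return result


async def main() -> None:
    counter = TinyDualCounter(shared={})
    rpm_limit = 3
    for i in range(5):
        used = await counter.async_increment("gpt-3.5-turbo:rpm:10-42", 1)
        print(f"request {i}: usage={used}", "OVER LIMIT" if used > rpm_limit else "ok")


asyncio.run(main())
```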
@@ -1836,6 +1836,9 @@ async def _run_background_health_check():
         await asyncio.sleep(health_check_interval)


+semaphore = asyncio.Semaphore(1)
+
+
 class ProxyConfig:
     """
     Abstraction class on top of config loading/updating logic. Gives us one place to control all config updating logic.
@@ -2425,8 +2428,7 @@ class ProxyConfig:
         for k, v in router_settings.items():
             if k in available_args:
                 router_params[k] = v
-        router = litellm.Router(**router_params)  # type:ignore
+        router = litellm.Router(**router_params, semaphore=semaphore)  # type:ignore
         return router, model_list, general_settings

     async def add_deployment(
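The module-level `asyncio.Semaphore(1)` added above is simply threaded through to the Router constructor here. A hedged sketch of that wiring outside the proxy (the model list and api_key below are placeholders, not part of this commit):

```python
import asyncio

import litellm

# a process-wide semaphore, mirroring the module-level one added in the proxy
semaphore = asyncio.Semaphore(1)

router_params = {
    "model_list": [
        {
            "model_name": "gpt-3.5-turbo",  # placeholder model group
            "litellm_params": {
                "model": "gpt-3.5-turbo",
                "api_key": "sk-placeholder",  # placeholder key
                "rpm": 60,  # rpm is what set_client later sizes the per-deployment semaphore on
            },
        }
    ],
}

# same call shape as the updated ProxyConfig code path
router = litellm.Router(**router_params, semaphore=semaphore)  # type: ignore
```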
@@ -3421,6 +3423,7 @@ async def chat_completion(
 ):
     global general_settings, user_debug, proxy_logging_obj, llm_model_list
     try:
+        # async with llm_router.sem
         data = {}
         body = await request.body()
         body_str = body.decode()
@@ -3525,7 +3528,9 @@ async def chat_completion(
         tasks = []
         tasks.append(
             proxy_logging_obj.during_call_hook(
-                data=data, user_api_key_dict=user_api_key_dict, call_type="completion"
+                data=data,
+                user_api_key_dict=user_api_key_dict,
+                call_type="completion",
             )
         )

@@ -30,7 +30,7 @@ from litellm.utils import ModelResponse, CustomStreamWrapper, get_utc_datetime
 import copy
 from litellm._logging import verbose_router_logger
 import logging
-from litellm.types.router import Deployment, ModelInfo, LiteLLM_Params
+from litellm.types.router import Deployment, ModelInfo, LiteLLM_Params, RouterErrors


 class Router:
@@ -78,6 +78,7 @@ class Router:
             "latency-based-routing",
         ] = "simple-shuffle",
         routing_strategy_args: dict = {},  # just for latency-based routing
+        semaphore: Optional[asyncio.Semaphore] = None,
     ) -> None:
         """
         Initialize the Router class with the given parameters for caching, reliability, and routing strategy.
@@ -143,6 +144,8 @@ class Router:
         router = Router(model_list=model_list, fallbacks=[{"azure-gpt-3.5-turbo": "openai-gpt-3.5-turbo"}])
         ```
         """
+        if semaphore:
+            self.semaphore = semaphore
         self.set_verbose = set_verbose
         self.debug_level = debug_level
         self.enable_pre_call_checks = enable_pre_call_checks
@@ -409,11 +412,18 @@ class Router:
             raise e

     async def _acompletion(self, model: str, messages: List[Dict[str, str]], **kwargs):
+        """
+        - Get an available deployment
+        - call it with a semaphore over the call
+        - semaphore specific to it's rpm
+        - in the semaphore, make a check against it's local rpm before running
+        """
         model_name = None
         try:
             verbose_router_logger.debug(
                 f"Inside _acompletion()- model: {model}; kwargs: {kwargs}"
             )

             deployment = await self.async_get_available_deployment(
                 model=model,
                 messages=messages,
@@ -443,6 +453,7 @@ class Router:
             potential_model_client = self._get_client(
                 deployment=deployment, kwargs=kwargs, client_type="async"
             )

             # check if provided keys == client keys #
             dynamic_api_key = kwargs.get("api_key", None)
             if (
@@ -465,7 +476,7 @@ class Router:
                 ) # this uses default_litellm_params when nothing is set
             )

-            response = await litellm.acompletion(
+            _response = litellm.acompletion(
                 **{
                     **data,
                     "messages": messages,
@@ -475,6 +486,25 @@ class Router:
                     **kwargs,
                 }
             )

+            rpm_semaphore = self._get_client(
+                deployment=deployment, kwargs=kwargs, client_type="rpm_client"
+            )
+
+            if (
+                rpm_semaphore is not None
+                and isinstance(rpm_semaphore, asyncio.Semaphore)
+                and self.routing_strategy == "usage-based-routing-v2"
+            ):
+                async with rpm_semaphore:
+                    """
+                    - Check rpm limits before making the call
+                    """
+                    await self.lowesttpm_logger_v2.pre_call_rpm_check(deployment)
+                    response = await _response
+            else:
+                response = await _response
+
             self.success_calls[model_name] += 1
             verbose_router_logger.info(
                 f"litellm.acompletion(model={model_name})\033[32m 200 OK\033[0m"
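The key pattern in this hunk: build the `litellm.acompletion` coroutine first, then decide whether to await it inside the deployment's rpm semaphore, running the rpm check while the semaphore is held. A simplified standalone sketch of that acquire-check-await flow (the fake call, usage dict, and limit are stand-ins, not the Router's real objects):

```python
import asyncio
from typing import Any, Coroutine, Optional


async def fake_llm_call(deployment_id: str) -> str:
    await asyncio.sleep(0.1)  # stand-in for litellm.acompletion(...)
    return f"response from {deployment_id}"


async def pre_call_rpm_check(usage: dict, deployment_id: str, rpm_limit: int) -> None:
    # bump the per-deployment counter and fail if the limit is exceeded
    usage[deployment_id] = usage.get(deployment_id, 0) + 1
    if usage[deployment_id] > rpm_limit:
        raise RuntimeError(f"{deployment_id} over rpm limit={rpm_limit}")


async def call_with_semaphore(
    pending: Coroutine[Any, Any, str],
    rpm_semaphore: Optional[asyncio.Semaphore],
    usage: dict,
    deployment_id: str,
    rpm_limit: int,
) -> str:
    # mirror the router logic: only gate the call when a semaphore exists for this deployment
    if rpm_semaphore is not None:
        async with rpm_semaphore:
            # check rpm limits before making the call, while holding the semaphore
            await pre_call_rpm_check(usage, deployment_id, rpm_limit)
            return await pending
    return await pending


async def main() -> None:
    usage: dict = {}
    sem = asyncio.Semaphore(2)  # at most 2 in-flight calls for this deployment
    calls = [
        call_with_semaphore(fake_llm_call("azure/gpt-35"), sem, usage, "azure/gpt-35", rpm_limit=3)
        for _ in range(3)
    ]
    print(await asyncio.gather(*calls))


asyncio.run(main())
```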
@@ -1265,6 +1295,8 @@ class Router:
                         min_timeout=self.retry_after,
                     )
                     await asyncio.sleep(timeout)
+                elif RouterErrors.user_defined_ratelimit_error.value in str(e):
+                    raise e  # don't wait to retry if deployment hits user-defined rate-limit
                 elif hasattr(original_exception, "status_code") and litellm._should_retry(
                     status_code=original_exception.status_code
                 ):
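Retries now short-circuit when the exception message carries the RouterErrors marker, so a deployment that is over its user-defined rpm is re-raised immediately instead of being retried after a backoff sleep. A condensed sketch of that decision (the enum mirrors the one added in types/router.py; the retry helper itself is illustrative):

```python
import asyncio
import enum


class RouterErrors(enum.Enum):
    # same marker string as the enum added to litellm/types/router.py in this commit
    user_defined_ratelimit_error = "Deployment over user-defined ratelimit."


async def maybe_retry(e: Exception, retry_after: float = 0.25) -> None:
    if RouterErrors.user_defined_ratelimit_error.value in str(e):
        raise e  # don't wait to retry if the deployment hit a user-defined rate-limit
    await asyncio.sleep(retry_after)  # otherwise back off before the next attempt
    print("retrying...")


async def main() -> None:
    try:
        await maybe_retry(Exception("Deployment over user-defined ratelimit. rpm limit=10"))
    except Exception as e:
        print("not retrying:", e)


asyncio.run(main())
```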
@@ -1680,12 +1712,26 @@ class Router:

     def set_client(self, model: dict):
         """
-        Initializes Azure/OpenAI clients. Stores them in cache, b/c of this - https://github.com/BerriAI/litellm/issues/1278
+        - Initializes Azure/OpenAI clients. Stores them in cache, b/c of this - https://github.com/BerriAI/litellm/issues/1278
+        - Initializes Semaphore for client w/ rpm. Stores them in cache. b/c of this - https://github.com/BerriAI/litellm/issues/2994
         """
         client_ttl = self.client_ttl
         litellm_params = model.get("litellm_params", {})
         model_name = litellm_params.get("model")
         model_id = model["model_info"]["id"]
+        # ### IF RPM SET - initialize a semaphore ###
+        rpm = litellm_params.get("rpm", None)
+        if rpm:
+            semaphore = asyncio.Semaphore(rpm)
+            cache_key = f"{model_id}_rpm_client"
+            self.cache.set_cache(
+                key=cache_key,
+                value=semaphore,
+                local_only=True,
+            )
+
+            # print("STORES SEMAPHORE IN CACHE")
+
         #### for OpenAI / Azure we need to initalize the Client for High Traffic ########
         custom_llm_provider = litellm_params.get("custom_llm_provider")
         custom_llm_provider = custom_llm_provider or model_name.split("/", 1)[0] or ""
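Each deployment that declares an `rpm` now gets its own `asyncio.Semaphore(rpm)` stored in the local cache under `{model_id}_rpm_client`, the key that `_get_client(..., client_type="rpm_client")` later reads. A minimal stand-in using a plain dict instead of the router's cache (helper names and the sample deployment are illustrative):

```python
import asyncio
from typing import Dict, Optional

local_cache: Dict[str, object] = {}  # stand-in for the router's in-memory cache


def set_rpm_semaphore(model_id: str, litellm_params: dict) -> None:
    # if rpm is set on the deployment, create a semaphore sized to it
    rpm: Optional[int] = litellm_params.get("rpm")
    if rpm:
        local_cache[f"{model_id}_rpm_client"] = asyncio.Semaphore(rpm)


def get_rpm_semaphore(model_id: str) -> Optional[asyncio.Semaphore]:
    # same "{model_id}_rpm_client" key convention used for client_type="rpm_client"
    client = local_cache.get(f"{model_id}_rpm_client")
    return client if isinstance(client, asyncio.Semaphore) else None


set_rpm_semaphore("my-deployment-id", {"model": "azure/gpt-35-turbo", "rpm": 100})
print(get_rpm_semaphore("my-deployment-id"))  # an asyncio.Semaphore sized to the deployment's rpm
```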
@@ -2275,7 +2321,11 @@ class Router:
             The appropriate client based on the given client_type and kwargs.
         """
         model_id = deployment["model_info"]["id"]
-        if client_type == "async":
+        if client_type == "rpm_client":
+            cache_key = "{}_rpm_client".format(model_id)
+            client = self.cache.get_cache(key=cache_key, local_only=True)
+            return client
+        elif client_type == "async":
             if kwargs.get("stream") == True:
                 cache_key = f"{model_id}_stream_async_client"
                 client = self.cache.get_cache(key=cache_key, local_only=True)
@@ -2328,6 +2378,7 @@ class Router:
         Filter out model in model group, if:

         - model context window < message length
+        - filter models above rpm limits
         - [TODO] function call and model doesn't support function calling
         """
         verbose_router_logger.debug(
@@ -2352,7 +2403,7 @@ class Router:
         rpm_key = f"{model}:rpm:{current_minute}"
         model_group_cache = (
             self.cache.get_cache(key=rpm_key, local_only=True) or {}
-        )  # check the redis + in-memory cache used by lowest_latency and usage-based routing. Only check the local cache.
+        )  # check the in-memory cache used by lowest_latency and usage-based routing. Only check the local cache.
         for idx, deployment in enumerate(_returned_deployments):
             # see if we have the info for this model
             try:
@@ -2388,6 +2439,7 @@ class Router:
                     self.cache.get_cache(key=model_id, local_only=True) or 0
                 )
                 ### get usage based cache ###
+                if isinstance(model_group_cache, dict):
                     model_group_cache[model_id] = model_group_cache.get(model_id, 0)

                 current_request = max(
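The pre-call filtering also drops deployments whose tracked request count for the current minute has reached their rpm limit. A simplified, standalone sketch of that filtering step (the helper and sample data are illustrative; the cache shape, a dict of per-deployment counts, mirrors the `model_group_cache` handling above):

```python
from typing import Dict, List


def filter_deployments_over_rpm(
    deployments: List[dict],
    model_group_cache: Dict[str, int],
) -> List[dict]:
    """Keep only deployments whose current-minute request count is below their rpm limit."""
    allowed = []
    for deployment in deployments:
        model_id = deployment["model_info"]["id"]
        rpm_limit = deployment.get("litellm_params", {}).get("rpm", float("inf"))
        current_requests = model_group_cache.get(model_id, 0)
        if current_requests < rpm_limit:
            allowed.append(deployment)
    return allowed


deployments = [
    {"model_info": {"id": "a"}, "litellm_params": {"model": "azure/gpt-35", "rpm": 10}},
    {"model_info": {"id": "b"}, "litellm_params": {"model": "openai/gpt-3.5-turbo", "rpm": 2}},
]
print(filter_deployments_over_rpm(deployments, {"a": 3, "b": 2}))  # deployment "b" is dropped
```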
@@ -7,12 +7,14 @@ import datetime as datetime_og
 from datetime import datetime

 dotenv.load_dotenv()  # Loading env variables using dotenv
-import traceback, asyncio
+import traceback, asyncio, httpx
+import litellm
 from litellm import token_counter
 from litellm.caching import DualCache
 from litellm.integrations.custom_logger import CustomLogger
 from litellm._logging import verbose_router_logger
 from litellm.utils import print_verbose, get_utc_datetime
+from litellm.types.router import RouterErrors


 class LowestTPMLoggingHandler_v2(CustomLogger):
@@ -37,6 +39,86 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
         self.router_cache = router_cache
         self.model_list = model_list

+    async def pre_call_rpm_check(self, deployment: dict) -> dict:
+        """
+        Pre-call check + update model rpm
+        - Used inside semaphore
+        - raise rate limit error if deployment over limit
+
+        Why? solves concurrency issue - https://github.com/BerriAI/litellm/issues/2994
+
+        Returns - deployment
+
+        Raises - RateLimitError if deployment over defined RPM limit
+        """
+        try:
+            # ------------
+            # Setup values
+            # ------------
+            dt = get_utc_datetime()
+            current_minute = dt.strftime("%H-%M")
+            model_group = deployment.get("model_name", "")
+            rpm_key = f"{model_group}:rpm:{current_minute}"
+            local_result = await self.router_cache.async_get_cache(
+                key=rpm_key, local_only=True
+            )  # check local result first
+
+            deployment_rpm = None
+            if deployment_rpm is None:
+                deployment_rpm = deployment.get("rpm")
+            if deployment_rpm is None:
+                deployment_rpm = deployment.get("litellm_params", {}).get("rpm")
+            if deployment_rpm is None:
+                deployment_rpm = deployment.get("model_info", {}).get("rpm")
+            if deployment_rpm is None:
+                deployment_rpm = float("inf")
+
+            if local_result is not None and local_result >= deployment_rpm:
+                raise litellm.RateLimitError(
+                    message="Deployment over defined rpm limit={}. current usage={}".format(
+                        deployment_rpm, local_result
+                    ),
+                    llm_provider="",
+                    model=deployment.get("litellm_params", {}).get("model"),
+                    response=httpx.Response(
+                        status_code=429,
+                        content="{} rpm limit={}. current usage={}".format(
+                            RouterErrors.user_defined_ratelimit_error.value,
+                            deployment_rpm,
+                            local_result,
+                        ),
+                        request=httpx.Request(method="tpm_rpm_limits", url="https://github.com/BerriAI/litellm"),  # type: ignore
+                    ),
+                )
+            else:
+                # if local result below limit, check redis ## prevent unnecessary redis checks
+                result = await self.router_cache.async_increment_cache(
+                    key=rpm_key, value=1
+                )
+                if result is not None and result > deployment_rpm:
+                    raise litellm.RateLimitError(
+                        message="Deployment over defined rpm limit={}. current usage={}".format(
+                            deployment_rpm, result
+                        ),
+                        llm_provider="",
+                        model=deployment.get("litellm_params", {}).get("model"),
+                        response=httpx.Response(
+                            status_code=429,
+                            content="{} rpm limit={}. current usage={}".format(
+                                RouterErrors.user_defined_ratelimit_error.value,
+                                deployment_rpm,
+                                result,
+                            ),
+                            request=httpx.Request(method="tpm_rpm_limits", url="https://github.com/BerriAI/litellm"),  # type: ignore
+                        ),
+                    )
+            return deployment
+        except Exception as e:
+            if isinstance(e, litellm.RateLimitError):
+                raise e
+            return deployment  # don't fail calls if eg. redis fails to connect
+
     def log_success_event(self, kwargs, response_obj, start_time, end_time):
         try:
             """
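`pre_call_rpm_check` reads the local counter first to avoid a network round-trip, and only then increments the shared counter and re-checks the returned value. A condensed standalone sketch of that two-step flow, with plain dicts standing in for the router cache and a plain exception standing in for `litellm.RateLimitError`:

```python
import asyncio
from typing import Dict


class RateLimitError(Exception):
    pass


local_cache: Dict[str, int] = {}   # fast, per-process view
shared_cache: Dict[str, int] = {}  # stand-in for the Redis-backed counter


async def pre_call_rpm_check(rpm_key: str, deployment_rpm: float) -> None:
    # 1. cheap local check first, to prevent unnecessary shared-cache traffic
    local_result = local_cache.get(rpm_key)
    if local_result is not None and local_result >= deployment_rpm:
        raise RateLimitError(f"over rpm limit={deployment_rpm}, current usage={local_result}")

    # 2. increment the shared counter and check the value it hands back
    shared_cache[rpm_key] = shared_cache.get(rpm_key, 0) + 1
    local_cache[rpm_key] = shared_cache[rpm_key]
    if shared_cache[rpm_key] > deployment_rpm:
        raise RateLimitError(f"over rpm limit={deployment_rpm}, current usage={shared_cache[rpm_key]}")


async def main() -> None:
    for _ in range(4):
        try:
            await pre_call_rpm_check("gpt-3.5-turbo:rpm:10-42", deployment_rpm=3)
            print("allowed")
        except RateLimitError as e:
            print("rejected:", e)


asyncio.run(main())
```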
@@ -91,7 +173,7 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
     async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
         try:
             """
-            Update TPM/RPM usage on success
+            Update TPM usage on success
             """
             if kwargs["litellm_params"].get("metadata") is None:
                 pass
@@ -117,8 +199,6 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
             ) # use the same timezone regardless of system clock

             tpm_key = f"{id}:tpm:{current_minute}"
-            rpm_key = f"{id}:rpm:{current_minute}"
-
             # ------------
             # Update usage
             # ------------
@@ -128,8 +208,6 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
             await self.router_cache.async_increment_cache(
                 key=tpm_key, value=total_tokens
             )
-            ## RPM
-            await self.router_cache.async_increment_cache(key=rpm_key, value=1)

             ### TESTING ###
             if self.test_flag:
@@ -3,7 +3,7 @@ from typing import List, Optional, Union, Dict, Tuple, Literal
 from pydantic import BaseModel, validator
 from .completion import CompletionRequest
 from .embedding import EmbeddingRequest
-import uuid
+import uuid, enum


 class ModelConfig(BaseModel):
@@ -166,3 +166,11 @@ class Deployment(BaseModel):
     def __setitem__(self, key, value):
         # Allow dictionary-style assignment of attributes
         setattr(self, key, value)
+
+
+class RouterErrors(enum.Enum):
+    """
+    Enum for router specific errors with common codes
+    """
+
+    user_defined_ratelimit_error = "Deployment over user-defined ratelimit."
@@ -67,12 +67,12 @@ litellm_settings:
   telemetry: False
   context_window_fallbacks: [{"gpt-3.5-turbo": ["gpt-3.5-turbo-large"]}]

-# router_settings:
-#   routing_strategy: usage-based-routing-v2
-#   redis_host: os.environ/REDIS_HOST
-#   redis_password: os.environ/REDIS_PASSWORD
-#   redis_port: os.environ/REDIS_PORT
-#   enable_pre_call_checks: true
+router_settings:
+  routing_strategy: usage-based-routing-v2
+  redis_host: os.environ/REDIS_HOST
+  redis_password: os.environ/REDIS_PASSWORD
+  redis_port: os.environ/REDIS_PORT
+  enable_pre_call_checks: true

 general_settings:
   master_key: sk-1234 # [OPTIONAL] Use to enforce auth on proxy. See - https://docs.litellm.ai/docs/proxy/virtual_keys
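Each key under `router_settings` is forwarded to `litellm.Router(...)` by the ProxyConfig code shown earlier, so the uncommented block is roughly equivalent to constructing the router directly. A hedged sketch of that equivalent call (model list, key, and Redis credentials are placeholders):

```python
import litellm

router = litellm.Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",  # placeholder deployment
            "litellm_params": {"model": "gpt-3.5-turbo", "api_key": "sk-placeholder", "rpm": 100},
        }
    ],
    routing_strategy="usage-based-routing-v2",
    redis_host="localhost",     # os.environ/REDIS_HOST in the proxy config
    redis_password="password",  # os.environ/REDIS_PASSWORD
    redis_port=6379,            # os.environ/REDIS_PORT
    enable_pre_call_checks=True,
)
```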
@@ -194,7 +194,7 @@ async def test_chat_completion():
             await chat_completion(session=session, key=key_2)


-@pytest.mark.skip(reason="Local test. Proxy not concurrency safe yet. WIP.")
+# @pytest.mark.skip(reason="Local test. Proxy not concurrency safe yet. WIP.")
 @pytest.mark.asyncio
 async def test_chat_completion_ratelimit():
     """
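With the skip decorator commented out, `test_chat_completion_ratelimit` can drive concurrent traffic at a deployment with a low rpm and expect some requests to be rejected. A rough sketch of that kind of load against a locally running proxy (the base URL, master key, and request count are assumptions matching the sample config above, not the test's actual contents):

```python
import asyncio

import httpx

PROXY_BASE_URL = "http://0.0.0.0:4000"  # assumes the proxy is running locally
API_KEY = "sk-1234"                     # the master_key from the sample config above


async def one_request(client: httpx.AsyncClient) -> int:
    # OpenAI-compatible chat endpoint exposed by the proxy
    response = await client.post(
        f"{PROXY_BASE_URL}/chat/completions",
        headers={"Authorization": f"Bearer {API_KEY}"},
        json={
            "model": "gpt-3.5-turbo",
            "messages": [{"role": "user", "content": "hi"}],
        },
    )
    return response.status_code


async def main() -> None:
    async with httpx.AsyncClient(timeout=60) as client:
        statuses = await asyncio.gather(*(one_request(client) for _ in range(10)))
    print("status codes:", statuses)  # expect a mix of 200s and 429s once the rpm limit is hit


asyncio.run(main())
```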