fix(router.py): initial commit for semaphores on router

Krrish Dholakia 2024-04-12 17:59:05 -07:00
parent 74aa230eac
commit a4e415b23c
3 changed files with 66 additions and 7 deletions

litellm/proxy/proxy_config.yaml

@@ -5,7 +5,7 @@ model_list:
       api_key: my-fake-key
       api_base: https://openai-function-calling-workers.tasslexyz.workers.dev/
       stream_timeout: 0.001
-      rpm: 10
+      rpm: 1
   - litellm_params:
       model: azure/chatgpt-v-2
       api_base: os.environ/AZURE_API_BASE

litellm/proxy/proxy_server.py

@@ -1836,6 +1836,9 @@ async def _run_background_health_check():
         await asyncio.sleep(health_check_interval)
+semaphore = asyncio.Semaphore(1)
 class ProxyConfig:
     """
     Abstraction class on top of config loading/updating logic. Gives us one place to control all config updating logic.
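
Note on the hunk above: `semaphore = asyncio.Semaphore(1)` creates a module-level semaphore with a count of 1, i.e. a global lock that admits one coroutine at a time. A minimal sketch of that behavior (illustrative names, not litellm code):

```python
import asyncio

# A shared Semaphore(1) serializes the guarded section: only one
# coroutine may hold it at a time, the rest queue on acquire.
semaphore = asyncio.Semaphore(1)

async def guarded_call(i: int) -> None:
    async with semaphore:
        print(f"call {i} entered")
        await asyncio.sleep(0.1)  # stand-in for the actual LLM request
        print(f"call {i} exited")

async def main() -> None:
    await asyncio.gather(*(guarded_call(i) for i in range(3)))

asyncio.run(main())
```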
@@ -2425,8 +2428,7 @@ class ProxyConfig:
             for k, v in router_settings.items():
                 if k in available_args:
                     router_params[k] = v
-        router = litellm.Router(**router_params) # type:ignore
+        router = litellm.Router(**router_params, semaphore=semaphore) # type:ignore
         return router, model_list, general_settings
     async def add_deployment(
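
The loader then hands the shared semaphore to the router as an extra keyword next to `**router_params`. A small sketch of the one subtlety in that call shape, with a hypothetical `make_router` standing in for `litellm.Router`: an explicit keyword only combines with an unpacked dict when the dict does not already contain that key.

```python
import asyncio

def make_router(**router_params):  # hypothetical stand-in for litellm.Router
    return router_params

shared = asyncio.Semaphore(1)
make_router(num_retries=2, semaphore=shared)  # fine

try:
    # the collision to avoid: router_settings must never carry its own
    # "semaphore" entry once the explicit keyword is added
    make_router(**{"semaphore": None}, semaphore=shared)
except TypeError as e:
    print(e)  # got multiple values for keyword argument 'semaphore'
```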
@@ -3421,6 +3423,7 @@ async def chat_completion(
 ):
     global general_settings, user_debug, proxy_logging_obj, llm_model_list
     try:
+        # async with llm_router.sem
         data = {}
         body = await request.body()
         body_str = body.decode()
@@ -3525,7 +3528,9 @@
         tasks = []
         tasks.append(
             proxy_logging_obj.during_call_hook(
-                data=data, user_api_key_dict=user_api_key_dict, call_type="completion"
+                data=data,
+                user_api_key_dict=user_api_key_dict,
+                call_type="completion",
             )
         )

litellm/router.py

@@ -78,6 +78,7 @@ class Router:
             "latency-based-routing",
         ] = "simple-shuffle",
         routing_strategy_args: dict = {}, # just for latency-based routing
+        semaphore: Optional[asyncio.Semaphore] = None,
     ) -> None:
         """
         Initialize the Router class with the given parameters for caching, reliability, and routing strategy.
@@ -143,6 +144,8 @@
         router = Router(model_list=model_list, fallbacks=[{"azure-gpt-3.5-turbo": "openai-gpt-3.5-turbo"}])
         ```
         """
+        if semaphore:
+            self.semaphore = semaphore
         self.set_verbose = set_verbose
         self.debug_level = debug_level
         self.enable_pre_call_checks = enable_pre_call_checks
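
One observation on the guarded assignment above: since `self.semaphore` is only set when a semaphore is passed in, reading the attribute later can raise `AttributeError` for routers built without one. A minimal sketch of the unconditional-assignment alternative, using a simplified stand-in class rather than litellm's actual `Router`:

```python
import asyncio
from typing import Optional

class RouterSketch:
    def __init__(self, semaphore: Optional[asyncio.Semaphore] = None) -> None:
        # assigning unconditionally means later reads return None instead
        # of raising AttributeError when no semaphore was provided
        self.semaphore = semaphore

assert RouterSketch().semaphore is None
assert RouterSketch(semaphore=asyncio.Semaphore(1)).semaphore is not None
```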
@@ -409,11 +412,18 @@
             raise e
     async def _acompletion(self, model: str, messages: List[Dict[str, str]], **kwargs):
+        """
+        - Get an available deployment
+        - call it with a semaphore over the call
+        - semaphore specific to its rpm
+        - in the semaphore, make a check against its local rpm before running
+        """
         model_name = None
         try:
             verbose_router_logger.debug(
                 f"Inside _acompletion()- model: {model}; kwargs: {kwargs}"
             )
             deployment = await self.async_get_available_deployment(
                 model=model,
                 messages=messages,
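
The new docstring lays out the intended flow: pick a deployment, then make the call while holding a semaphore, checking the deployment's local rpm budget inside the critical section. A runnable distillation of that control flow, where `get_deployment`, `check_and_increment_rpm`, and `make_call` are hypothetical stand-ins:

```python
import asyncio
from typing import Any, Dict, List, Optional

# Hypothetical stand-ins so the sketch runs end to end.
async def get_deployment(model: str, messages: List[Dict[str, str]]) -> str:
    return f"{model}-deployment"

async def check_and_increment_rpm(deployment: str) -> None:
    pass  # raise here when the per-minute budget is exhausted

async def make_call(deployment: str, messages: List[Dict[str, str]]) -> str:
    return f"response from {deployment}"

async def acompletion_flow(
    model: str,
    messages: List[Dict[str, str]],
    semaphore: Optional[asyncio.Semaphore],
) -> Any:
    deployment = await get_deployment(model, messages)  # 1. pick a deployment
    if semaphore is None:
        return await make_call(deployment, messages)    # no cap configured
    async with semaphore:                               # 2. hold the semaphore
        await check_and_increment_rpm(deployment)       # 3. check rpm budget
        return await make_call(deployment, messages)    # 4. make the call

print(asyncio.run(acompletion_flow("gpt-4", [], asyncio.Semaphore(1))))
```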
@@ -443,6 +453,7 @@
             potential_model_client = self._get_client(
                 deployment=deployment, kwargs=kwargs, client_type="async"
             )
+            # check if provided keys == client keys #
             dynamic_api_key = kwargs.get("api_key", None)
             if (
@@ -465,7 +476,7 @@
                 ) # this uses default_litellm_params when nothing is set
             )
-            response = await litellm.acompletion(
+            _response = litellm.acompletion(
                 **{
                     **data,
                     "messages": messages,
@@ -475,6 +486,30 @@
                     **kwargs,
                 }
             )
+            rpm_semaphore = self.semaphore
+            if rpm_semaphore is not None and isinstance(
+                rpm_semaphore, asyncio.Semaphore
+            ):
+                async with rpm_semaphore:
+                    """
+                    - Check against in-memory tpm/rpm limits before making the call
+                    """
+                    dt = get_utc_datetime()
+                    current_minute = dt.strftime("%H-%M")
+                    id = kwargs["model_info"]["id"]
+                    rpm_key = "{}:rpm:{}".format(id, current_minute)
+                    curr_rpm = await self.cache.async_get_cache(key=rpm_key)
+                    if (
+                        curr_rpm is not None and curr_rpm >= data["rpm"]
+                    ): # >= b/c the initial count is 0
+                        raise Exception("Rate Limit error")
+                    await self.cache.async_increment_cache(key=rpm_key, value=1)
+                    response = await _response
+            else:
+                response = await _response
             self.success_calls[model_name] += 1
             verbose_router_logger.info(
                 f"litellm.acompletion(model={model_name})\033[32m 200 OK\033[0m"
@@ -1680,12 +1715,27 @@
     def set_client(self, model: dict):
         """
-        Initializes Azure/OpenAI clients. Stores them in cache, b/c of this - https://github.com/BerriAI/litellm/issues/1278
+        - Initializes Azure/OpenAI clients. Stores them in cache, b/c of this - https://github.com/BerriAI/litellm/issues/1278
+        - Initializes Semaphore for client w/ rpm. Stores them in cache. b/c of this - https://github.com/BerriAI/litellm/issues/2994
         """
         client_ttl = self.client_ttl
         litellm_params = model.get("litellm_params", {})
         model_name = litellm_params.get("model")
         model_id = model["model_info"]["id"]
+        # ### IF RPM SET - initialize a semaphore ###
+        # rpm = litellm_params.get("rpm", None)
+        # print(f"rpm: {rpm}")
+        # if rpm:
+        #     semaphore = asyncio.Semaphore(rpm)
+        #     cache_key = f"{model_id}_rpm_client"
+        #     self.cache.set_cache(
+        #         key=cache_key,
+        #         value=semaphore,
+        #         local_only=True,
+        #     )
+        #     print("STORES SEMAPHORE IN CACHE")
         #### for OpenAI / Azure we need to initialize the Client for High Traffic ########
         custom_llm_provider = litellm_params.get("custom_llm_provider")
         custom_llm_provider = custom_llm_provider or model_name.split("/", 1)[0] or ""
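
The commented-out block above sketches the per-deployment design tracked in issue #2994: one `asyncio.Semaphore` sized to the deployment's `rpm`, stored under `f"{model_id}_rpm_client"` with `local_only=True`, which fits because asyncio primitives cannot be shared across processes. A hedged distillation with a plain dict standing in for the router's cache:

```python
import asyncio
from typing import Dict, Optional

local_cache: Dict[str, asyncio.Semaphore] = {}  # stand-in for the router cache

def init_rpm_semaphore(model_id: str, rpm: Optional[int]) -> None:
    if rpm:  # only deployments with an rpm limit get a semaphore
        local_cache[f"{model_id}_rpm_client"] = asyncio.Semaphore(rpm)

init_rpm_semaphore("model-abc123", rpm=10)  # hypothetical deployment id
print(local_cache)  # {'model-abc123_rpm_client': <asyncio.locks.Semaphore ...>}
```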
@@ -2275,7 +2325,11 @@
         The appropriate client based on the given client_type and kwargs.
         """
         model_id = deployment["model_info"]["id"]
-        if client_type == "async":
+        if client_type == "rpm_client":
+            cache_key = "{}_rpm_client".format(model_id)
+            client = self.cache.get_cache(key=cache_key, local_only=True)
+            return client
+        elif client_type == "async":
             if kwargs.get("stream") == True:
                 cache_key = f"{model_id}_stream_async_client"
                 client = self.cache.get_cache(key=cache_key, local_only=True)
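
To show the new `"rpm_client"` branch in use, here is a hedged sketch of fetching the per-deployment semaphore back out of `_get_client` and holding it around a call; `FakeRouter` is a hypothetical stand-in exposing the same contract as the method above:

```python
import asyncio
from typing import Any, Dict

class FakeRouter:
    """Hypothetical stand-in exposing the same _get_client contract."""

    def __init__(self) -> None:
        self._cache = {"model-abc123_rpm_client": asyncio.Semaphore(10)}

    def _get_client(self, deployment: Dict[str, Any], kwargs: dict, client_type: str):
        if client_type == "rpm_client":
            model_id = deployment["model_info"]["id"]
            return self._cache.get("{}_rpm_client".format(model_id))
        return None

async def call_with_rpm_cap(router: Any, deployment: Dict[str, Any]) -> str:
    rpm_semaphore = router._get_client(
        deployment=deployment, kwargs={}, client_type="rpm_client"
    )
    if isinstance(rpm_semaphore, asyncio.Semaphore):
        async with rpm_semaphore:  # hold the per-deployment cap during the call
            return "called under rpm cap"
    return "called without cap"  # no semaphore initialized for this deployment

deployment = {"model_info": {"id": "model-abc123"}}
print(asyncio.run(call_with_rpm_cap(FakeRouter(), deployment)))
```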