Commit a4e415b23c (parent 74aa230eac) in BerriAI/litellm
fix(router.py): initial commit for semaphores on router
3 changed files with 66 additions and 7 deletions
@@ -5,7 +5,7 @@ model_list:
       api_key: my-fake-key
       api_base: https://openai-function-calling-workers.tasslexyz.workers.dev/
       stream_timeout: 0.001
-      rpm: 10
+      rpm: 1
   - litellm_params:
       model: azure/chatgpt-v-2
       api_base: os.environ/AZURE_API_BASE
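For reference, `rpm` here is the per-deployment requests-per-minute cap that the new semaphore logic checks against. A minimal sketch of the same deployment expressed as the Python `model_list` that `litellm.Router` accepts (the `model_name` and `model` values are illustrative placeholders, not from this diff):

```python
# Sketch: the YAML deployment above as a Python model_list entry.
# "model_name" and "model" are illustrative placeholders.
model_list = [
    {
        "model_name": "fake-openai-endpoint",
        "litellm_params": {
            "model": "openai/fake-model",
            "api_key": "my-fake-key",
            "api_base": "https://openai-function-calling-workers.tasslexyz.workers.dev/",
            "stream_timeout": 0.001,
            "rpm": 1,  # at most one request per minute to this deployment
        },
    },
]
```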
@@ -1836,6 +1836,9 @@ async def _run_background_health_check():
         await asyncio.sleep(health_check_interval)
 
 
+semaphore = asyncio.Semaphore(1)
+
+
 class ProxyConfig:
     """
     Abstraction class on top of config loading/updating logic. Gives us one place to control all config updating logic.
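`asyncio.Semaphore(1)` hands out a single permit, so coroutines that enter it with `async with` run one at a time; that is the concurrency primitive this commit threads through the proxy. A self-contained illustration (not proxy code):

```python
import asyncio

semaphore = asyncio.Semaphore(1)  # one permit: holders run strictly one at a time

async def worker(i: int) -> None:
    async with semaphore:  # the second worker waits until the first releases
        print(f"worker {i} acquired")
        await asyncio.sleep(0.1)
        print(f"worker {i} released")

async def main() -> None:
    await asyncio.gather(worker(1), worker(2))

asyncio.run(main())  # acquire/release pairs print in sequence, never interleaved
```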
@@ -2425,8 +2428,7 @@ class ProxyConfig:
         for k, v in router_settings.items():
             if k in available_args:
                 router_params[k] = v
-
-        router = litellm.Router(**router_params)  # type:ignore
+        router = litellm.Router(**router_params, semaphore=semaphore)  # type:ignore
         return router, model_list, general_settings
 
     async def add_deployment(
@@ -3421,6 +3423,7 @@ async def chat_completion(
 ):
     global general_settings, user_debug, proxy_logging_obj, llm_model_list
     try:
+        # async with llm_router.sem
         data = {}
         body = await request.body()
         body_str = body.decode()
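The added line is only a commented-out stub, but it signals the intended shape: gate the endpoint's work behind the router's semaphore. A hedged sketch of that pattern in isolation (the helper and its names are assumptions, not code from this commit):

```python
import asyncio
from typing import Any, Awaitable, Callable, Optional

# Hypothetical helper mirroring the stubbed-out idea above: run a call
# under a semaphore when one is configured, otherwise run it directly.
async def call_with_optional_semaphore(
    semaphore: Optional[asyncio.Semaphore],
    do_call: Callable[[], Awaitable[Any]],
) -> Any:
    if semaphore is not None:
        async with semaphore:
            return await do_call()
    return await do_call()
```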
@@ -3525,7 +3528,9 @@ async def chat_completion(
         tasks = []
         tasks.append(
             proxy_logging_obj.during_call_hook(
-                data=data, user_api_key_dict=user_api_key_dict, call_type="completion"
+                data=data,
+                user_api_key_dict=user_api_key_dict,
+                call_type="completion",
             )
         )
 
@@ -78,6 +78,7 @@ class Router:
             "latency-based-routing",
         ] = "simple-shuffle",
         routing_strategy_args: dict = {},  # just for latency-based routing
+        semaphore: Optional[asyncio.Semaphore] = None,
     ) -> None:
         """
         Initialize the Router class with the given parameters for caching, reliability, and routing strategy.
@@ -143,6 +144,8 @@ class Router:
         router = Router(model_list=model_list, fallbacks=[{"azure-gpt-3.5-turbo": "openai-gpt-3.5-turbo"}])
         ```
         """
+        if semaphore:
+            self.semaphore = semaphore
         self.set_verbose = set_verbose
         self.debug_level = debug_level
         self.enable_pre_call_checks = enable_pre_call_checks
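One thing to note: with this guard, `self.semaphore` is defined only when a semaphore is passed in, while the new code in `_acompletion` (below) reads `self.semaphore` unconditionally, so a `Router` built without one would hit an `AttributeError` there. A one-line sketch of an unconditional default, offered as a suggestion rather than part of the commit:

```python
# Sketch: drop-in replacement for the guard above, inside Router.__init__.
# Always define the attribute so the later unconditional read is safe.
self.semaphore = semaphore if isinstance(semaphore, asyncio.Semaphore) else None
```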
@@ -409,11 +412,18 @@ class Router:
             raise e
 
     async def _acompletion(self, model: str, messages: List[Dict[str, str]], **kwargs):
+        """
+        - Get an available deployment
+        - call it with a semaphore over the call
+        - semaphore specific to its rpm
+        - in the semaphore, make a check against its local rpm before running
+        """
         model_name = None
         try:
             verbose_router_logger.debug(
                 f"Inside _acompletion()- model: {model}; kwargs: {kwargs}"
             )
 
             deployment = await self.async_get_available_deployment(
                 model=model,
                 messages=messages,
@@ -443,6 +453,7 @@ class Router:
             potential_model_client = self._get_client(
                 deployment=deployment, kwargs=kwargs, client_type="async"
             )
 
             # check if provided keys == client keys #
             dynamic_api_key = kwargs.get("api_key", None)
             if (
@@ -465,7 +476,7 @@ class Router:
                 )  # this uses default_litellm_params when nothing is set
             )
 
-            response = await litellm.acompletion(
+            _response = litellm.acompletion(
                 **{
                     **data,
                     "messages": messages,
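Dropping the `await` here is deliberate: calling an async function only builds a coroutine object, and nothing executes until it is awaited, which now happens further down, after the semaphore is acquired and the rpm check passes. A self-contained illustration of that behavior:

```python
import asyncio

async def fetch() -> str:
    print("request started")
    return "ok"

async def main() -> None:
    pending = fetch()       # no output yet: this only creates a coroutine object
    print("created, not yet run")
    result = await pending  # "request started" prints only now
    print(result)

asyncio.run(main())
```

One side effect worth knowing: if the rate-limit check raises before `_response` is awaited, the coroutine never runs and Python emits a "coroutine was never awaited" RuntimeWarning when it is garbage-collected.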
@@ -475,6 +486,30 @@ class Router:
                     **kwargs,
                 }
             )
+
+            rpm_semaphore = self.semaphore
+
+            if rpm_semaphore is not None and isinstance(
+                rpm_semaphore, asyncio.Semaphore
+            ):
+                async with rpm_semaphore:
+                    """
+                    - Check against in-memory tpm/rpm limits before making the call
+                    """
+                    dt = get_utc_datetime()
+                    current_minute = dt.strftime("%H-%M")
+                    id = kwargs["model_info"]["id"]
+                    rpm_key = "{}:rpm:{}".format(id, current_minute)
+                    curr_rpm = await self.cache.async_get_cache(key=rpm_key)
+                    if (
+                        curr_rpm is not None and curr_rpm >= data["rpm"]
+                    ):  # >= b/c the initial count is 0
+                        raise Exception("Rate Limit error")
+                    await self.cache.async_increment_cache(key=rpm_key, value=1)
+                    response = await _response
+            else:
+                response = await _response
+
             self.success_calls[model_name] += 1
             verbose_router_logger.info(
                 f"litellm.acompletion(model={model_name})\033[32m 200 OK\033[0m"
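Stripped of the router plumbing, the in-memory check amounts to: key a counter by deployment id plus the current minute, reject once the counter reaches the deployment's `rpm`, otherwise increment and proceed. A self-contained sketch with a plain dict standing in for the router's cache (the real code uses `get_utc_datetime` and async cache methods):

```python
import asyncio
from datetime import datetime, timezone

counters: dict[str, int] = {}  # stand-in for the router's in-memory cache

async def check_and_increment_rpm(deployment_id: str, rpm_limit: int) -> None:
    """Raise once a deployment has used up its requests for the current minute."""
    current_minute = datetime.now(timezone.utc).strftime("%H-%M")
    rpm_key = "{}:rpm:{}".format(deployment_id, current_minute)
    curr_rpm = counters.get(rpm_key)
    if curr_rpm is not None and curr_rpm >= rpm_limit:  # >= b/c the count starts at 0
        raise Exception("Rate Limit error")
    counters[rpm_key] = (curr_rpm or 0) + 1

# With rpm_limit=1, the first call in a minute passes and the second raises.
```

Because the real implementation awaits between the cache read and the increment, the surrounding semaphore is what keeps the check-then-increment sequence atomic across concurrent calls.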
@@ -1680,12 +1715,27 @@ class Router:
 
     def set_client(self, model: dict):
         """
-        Initializes Azure/OpenAI clients. Stores them in cache, b/c of this - https://github.com/BerriAI/litellm/issues/1278
+        - Initializes Azure/OpenAI clients. Stores them in cache, b/c of this - https://github.com/BerriAI/litellm/issues/1278
+        - Initializes Semaphore for client w/ rpm. Stores them in cache. b/c of this - https://github.com/BerriAI/litellm/issues/2994
         """
         client_ttl = self.client_ttl
         litellm_params = model.get("litellm_params", {})
         model_name = litellm_params.get("model")
         model_id = model["model_info"]["id"]
+        # ### IF RPM SET - initialize a semaphore ###
+        # rpm = litellm_params.get("rpm", None)
+        # print(f"rpm: {rpm}")
+        # if rpm:
+        #     semaphore = asyncio.Semaphore(rpm)
+        #     cache_key = f"{model_id}_rpm_client"
+        #     self.cache.set_cache(
+        #         key=cache_key,
+        #         value=semaphore,
+        #         local_only=True,
+        #     )
+
+        #     print("STORES SEMAPHORE IN CACHE")
+
         #### for OpenAI / Azure we need to initialize the Client for High Traffic ########
         custom_llm_provider = litellm_params.get("custom_llm_provider")
         custom_llm_provider = custom_llm_provider or model_name.split("/", 1)[0] or ""
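The per-deployment initialization above is committed commented-out. Uncommented and tidied, the idea would look roughly like this sketch, assuming the rpm-sized-semaphore design the comments describe, with a plain dict in place of `self.cache`:

```python
import asyncio

local_cache: dict[str, asyncio.Semaphore] = {}  # stand-in for self.cache

def init_rpm_semaphore(model_id: str, litellm_params: dict) -> None:
    """Sketch of the commented-out block: one semaphore per deployment,
    sized by its rpm and stored under a per-model cache key."""
    rpm = litellm_params.get("rpm", None)
    if rpm:
        # At most `rpm` concurrent holders; the minute-window counter in
        # _acompletion would still do the actual per-minute accounting.
        local_cache[f"{model_id}_rpm_client"] = asyncio.Semaphore(rpm)
```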
@@ -2275,7 +2325,11 @@ class Router:
         The appropriate client based on the given client_type and kwargs.
         """
         model_id = deployment["model_info"]["id"]
-        if client_type == "async":
+        if client_type == "rpm_client":
+            cache_key = "{}_rpm_client".format(model_id)
+            client = self.cache.get_cache(key=cache_key, local_only=True)
+            return client
+        elif client_type == "async":
             if kwargs.get("stream") == True:
                 cache_key = f"{model_id}_stream_async_client"
                 client = self.cache.get_cache(key=cache_key, local_only=True)
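With this branch in place, a per-deployment semaphore can be fetched like any other cached client. A hedged usage sketch (`router`, `deployment`, and `do_call` are assumed to exist; the deployment dict needs `deployment["model_info"]["id"]`):

```python
# Hypothetical call site for the new "rpm_client" branch.
async def call_under_deployment_semaphore(router, deployment, do_call):
    rpm_semaphore = router._get_client(
        deployment=deployment,
        kwargs={},
        client_type="rpm_client",  # the branch added in this commit
    )
    if rpm_semaphore is not None:
        async with rpm_semaphore:  # bounded by this deployment's rpm
            return await do_call()
    return await do_call()
```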