forked from phoenix/litellm-mirror
fix(router.py): initial commit for semaphores on router
parent 74aa230eac
commit a4e415b23c
3 changed files with 66 additions and 7 deletions
@@ -5,7 +5,7 @@ model_list:
       api_key: my-fake-key
       api_base: https://openai-function-calling-workers.tasslexyz.workers.dev/
       stream_timeout: 0.001
-      rpm: 10
+      rpm: 1
   - litellm_params:
       model: azure/chatgpt-v-2
       api_base: os.environ/AZURE_API_BASE
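
The per-deployment `rpm` value set here is the budget the router's new rate-limit check reads (see the `_acompletion` hunk below); dropping it from 10 to 1 makes the limit trivial to hit while testing. A minimal sketch of the same deployment expressed as a Python `model_list` entry; the `model_name` alias and `model` value are assumptions, since this hunk does not show them:

```python
# Hypothetical Python equivalent of the YAML deployment above; the
# model_name alias and model value are assumed, not shown in the hunk.
model_list = [
    {
        "model_name": "fake-openai-endpoint",
        "litellm_params": {
            "model": "openai/my-fake-model",
            "api_key": "my-fake-key",
            "api_base": "https://openai-function-calling-workers.tasslexyz.workers.dev/",
            "stream_timeout": 0.001,
            "rpm": 1,  # per-minute budget the router checks before each call
        },
    },
]
```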
@@ -1836,6 +1836,9 @@ async def _run_background_health_check():
         await asyncio.sleep(health_check_interval)
 
 
+semaphore = asyncio.Semaphore(1)
+
+
 class ProxyConfig:
     """
     Abstraction class on top of config loading/updating logic. Gives us one place to control all config updating logic.
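
`asyncio.Semaphore(1)` behaves like an async mutex: only one coroutine can hold it at a time, so anything guarded by this module-level semaphore runs fully serialized across the proxy. A self-contained sketch of that behavior:

```python
import asyncio

# Self-contained sketch: Semaphore(1) admits one coroutine at a time,
# so the guarded sections below run one after another, never overlapped.
semaphore = asyncio.Semaphore(1)

async def worker(i: int) -> None:
    async with semaphore:
        print(f"worker {i} acquired")
        await asyncio.sleep(0.1)  # stand-in for an LLM call
        print(f"worker {i} released")

async def main() -> None:
    await asyncio.gather(*(worker(i) for i in range(3)))

asyncio.run(main())
```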
@@ -2425,8 +2428,7 @@ class ProxyConfig:
         for k, v in router_settings.items():
             if k in available_args:
                 router_params[k] = v
-        router = litellm.Router(**router_params)  # type:ignore
-
+        router = litellm.Router(**router_params, semaphore=semaphore)  # type:ignore
         return router, model_list, general_settings
 
     async def add_deployment(
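
One thing to note: because `semaphore` is passed as an explicit keyword alongside `**router_params`, a `semaphore` key inside `router_settings` would make this call raise `TypeError` for a duplicate keyword argument. A small sketch of that pitfall; `make_router` is a hypothetical stand-in for `litellm.Router`:

```python
# Sketch of the duplicate-keyword pitfall; make_router is a stand-in
# for litellm.Router.
def make_router(**router_params):
    return router_params

params = {"num_retries": 3, "semaphore": None}
try:
    make_router(**params, semaphore=object())
except TypeError as e:
    print(e)  # got multiple values for keyword argument 'semaphore'
```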
@@ -3421,6 +3423,7 @@ async def chat_completion(
 ):
     global general_settings, user_debug, proxy_logging_obj, llm_model_list
     try:
+        # async with llm_router.sem
         data = {}
         body = await request.body()
         body_str = body.decode()
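
The commented-out `# async with llm_router.sem` hints at an alternative placement: acquire the router's semaphore around the whole endpoint body instead of inside `_acompletion`. A hypothetical, self-contained sketch of that shape (the attribute name and handler signature are assumptions, not what this commit ships):

```python
import asyncio

class FakeRouter:
    # Stand-in for litellm.Router; only the semaphore attribute matters here.
    def __init__(self) -> None:
        self.semaphore = asyncio.Semaphore(1)

llm_router = FakeRouter()

async def chat_completion(payload: dict) -> dict:
    # Hypothetical: hold the router's semaphore for the full request.
    async with llm_router.semaphore:
        return {"echo": payload}

print(asyncio.run(chat_completion({"messages": []})))
```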
@@ -3525,7 +3528,9 @@ async def chat_completion(
         tasks = []
         tasks.append(
             proxy_logging_obj.during_call_hook(
-                data=data, user_api_key_dict=user_api_key_dict, call_type="completion"
+                data=data,
+                user_api_key_dict=user_api_key_dict,
+                call_type="completion",
             )
         )
 
@@ -78,6 +78,7 @@ class Router:
             "latency-based-routing",
         ] = "simple-shuffle",
         routing_strategy_args: dict = {},  # just for latency-based routing
+        semaphore: Optional[asyncio.Semaphore] = None,
     ) -> None:
         """
         Initialize the Router class with the given parameters for caching, reliability, and routing strategy.
@@ -143,6 +144,8 @@ class Router:
         router = Router(model_list=model_list, fallbacks=[{"azure-gpt-3.5-turbo": "openai-gpt-3.5-turbo"}])
         ```
         """
+        if semaphore:
+            self.semaphore = semaphore
         self.set_verbose = set_verbose
         self.debug_level = debug_level
         self.enable_pre_call_checks = enable_pre_call_checks
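
Note that `self.semaphore` is only assigned when a truthy semaphore is passed in, while `_acompletion` (below) reads `self.semaphore` unconditionally; a router constructed without one would raise `AttributeError`. A defensive sketch that always defines the attribute:

```python
import asyncio
from typing import Optional

# Defensive sketch: always bind the attribute so later reads of
# self.semaphore can never raise AttributeError.
class Router:
    def __init__(self, semaphore: Optional[asyncio.Semaphore] = None) -> None:
        self.semaphore = semaphore  # None means "no concurrency cap"

router = Router()
print(router.semaphore)  # None, not AttributeError
```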
@@ -409,11 +412,18 @@ class Router:
             raise e
 
     async def _acompletion(self, model: str, messages: List[Dict[str, str]], **kwargs):
+        """
+        - Get an available deployment
+        - call it with a semaphore over the call
+        - semaphore specific to its rpm
+        - in the semaphore, make a check against its local rpm before running
+        """
+        model_name = None
         try:
             verbose_router_logger.debug(
                 f"Inside _acompletion()- model: {model}; kwargs: {kwargs}"
             )
 
             deployment = await self.async_get_available_deployment(
                 model=model,
                 messages=messages,
@@ -443,6 +453,7 @@ class Router:
             potential_model_client = self._get_client(
                 deployment=deployment, kwargs=kwargs, client_type="async"
             )
 
+            # check if provided keys == client keys #
             dynamic_api_key = kwargs.get("api_key", None)
             if (
@@ -465,7 +476,7 @@ class Router:
                 )  # this uses default_litellm_params when nothing is set
             )
 
-            response = await litellm.acompletion(
+            _response = litellm.acompletion(
                 **{
                     **data,
                     "messages": messages,
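
Dropping the `await` here is deliberate: calling an `async def` function such as `litellm.acompletion` only constructs a coroutine object, and the request does not start until `_response` is awaited under the semaphore below. One wrinkle: if the rate-limit check raises first, the coroutine is never awaited and Python emits a `RuntimeWarning`. A minimal demonstration:

```python
import asyncio

async def acompletion() -> str:
    # Stand-in for litellm.acompletion; nothing runs until awaited.
    print("request started")
    return "response"

async def main() -> None:
    _response = acompletion()   # builds the coroutine; no output yet
    print("rate-limit check would run here")
    print(await _response)      # "request started" prints only now

asyncio.run(main())
```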
@@ -475,6 +486,30 @@ class Router:
                     **kwargs,
                 }
             )
+
+            rpm_semaphore = self.semaphore
+
+            if rpm_semaphore is not None and isinstance(
+                rpm_semaphore, asyncio.Semaphore
+            ):
+                async with rpm_semaphore:
+                    """
+                    - Check against in-memory tpm/rpm limits before making the call
+                    """
+                    dt = get_utc_datetime()
+                    current_minute = dt.strftime("%H-%M")
+                    id = kwargs["model_info"]["id"]
+                    rpm_key = "{}:rpm:{}".format(id, current_minute)
+                    curr_rpm = await self.cache.async_get_cache(key=rpm_key)
+                    if (
+                        curr_rpm is not None and curr_rpm >= data["rpm"]
+                    ):  # >= b/c the initial count is 0
+                        raise Exception("Rate Limit error")
+                    await self.cache.async_increment_cache(key=rpm_key, value=1)
+                    response = await _response
+            else:
+                response = await _response
+
             self.success_calls[model_name] += 1
             verbose_router_logger.info(
                 f"litellm.acompletion(model={model_name})\033[32m 200 OK\033[0m"
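
The check uses a minute-bucketed cache key (`{id}:rpm:{HH-MM}`), so the counter effectively resets each wall-clock minute, and holding the semaphore keeps the read-check-increment sequence from interleaving across requests. A self-contained sketch of the same pattern, with a plain dict standing in for the router cache and its `async_get_cache` / `async_increment_cache` calls:

```python
import asyncio
from datetime import datetime, timezone

# Dict stand-in for Router.cache (the real code uses async_get_cache /
# async_increment_cache on litellm's cache).
_counts: dict[str, int] = {}
_sem = asyncio.Semaphore(1)

async def guarded_call(coro_factory, deployment_id: str, rpm_limit: int):
    async with _sem:
        minute = datetime.now(timezone.utc).strftime("%H-%M")
        rpm_key = "{}:rpm:{}".format(deployment_id, minute)  # same key shape as the diff
        curr = _counts.get(rpm_key)
        if curr is not None and curr >= rpm_limit:  # block once the bucket hits the limit
            raise Exception("Rate Limit error")
        _counts[rpm_key] = (curr or 0) + 1
        return await coro_factory()

async def fake_llm() -> str:
    return "ok"

async def main() -> None:
    print(await guarded_call(fake_llm, "deploy-1", rpm_limit=2))
    print(await guarded_call(fake_llm, "deploy-1", rpm_limit=2))

asyncio.run(main())
```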
@@ -1680,12 +1715,27 @@ class Router:
 
     def set_client(self, model: dict):
         """
-        Initializes Azure/OpenAI clients. Stores them in cache, b/c of this - https://github.com/BerriAI/litellm/issues/1278
+        - Initializes Azure/OpenAI clients. Stores them in cache, b/c of this - https://github.com/BerriAI/litellm/issues/1278
+        - Initializes Semaphore for client w/ rpm. Stores them in cache. b/c of this - https://github.com/BerriAI/litellm/issues/2994
         """
         client_ttl = self.client_ttl
         litellm_params = model.get("litellm_params", {})
         model_name = litellm_params.get("model")
         model_id = model["model_info"]["id"]
+        # ### IF RPM SET - initialize a semaphore ###
+        # rpm = litellm_params.get("rpm", None)
+        # print(f"rpm: {rpm}")
+        # if rpm:
+        #     semaphore = asyncio.Semaphore(rpm)
+        #     cache_key = f"{model_id}_rpm_client"
+        #     self.cache.set_cache(
+        #         key=cache_key,
+        #         value=semaphore,
+        #         local_only=True,
+        #     )
+
+        #     print("STORES SEMAPHORE IN CACHE")
+
         #### for OpenAI / Azure we need to initalize the Client for High Traffic ########
         custom_llm_provider = litellm_params.get("custom_llm_provider")
         custom_llm_provider = custom_llm_provider or model_name.split("/", 1)[0] or ""
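
The commented-out block sketches where this is headed: one `asyncio.Semaphore(rpm)` per deployment, stored in the local cache under `f"{model_id}_rpm_client"`. Storing it `local_only=True` matters because a semaphore is process-local state and cannot be serialized into a shared backend such as Redis. A minimal sketch of that intended pattern, with a dict as the local cache:

```python
import asyncio

# Dict stand-in for the router's local cache; semaphores must stay
# in-process, which is why the real code passes local_only=True.
local_cache: dict[str, asyncio.Semaphore] = {}

def init_rpm_client(model_id: str, rpm: int | None) -> None:
    if rpm:
        local_cache[f"{model_id}_rpm_client"] = asyncio.Semaphore(rpm)

init_rpm_client("deployment-123", rpm=10)
print(sorted(local_cache))  # ['deployment-123_rpm_client']
```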
@@ -2275,7 +2325,11 @@ class Router:
         The appropriate client based on the given client_type and kwargs.
         """
         model_id = deployment["model_info"]["id"]
-        if client_type == "async":
+        if client_type == "rpm_client":
+            cache_key = "{}_rpm_client".format(model_id)
+            client = self.cache.get_cache(key=cache_key, local_only=True)
+            return client
+        elif client_type == "async":
             if kwargs.get("stream") == True:
                 cache_key = f"{model_id}_stream_async_client"
                 client = self.cache.get_cache(key=cache_key, local_only=True)
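
The new `rpm_client` branch is just a local-cache lookup that returns `None` when no semaphore was initialized for the deployment. A usage sketch of the retrieve-and-guard pattern, again with a dict standing in for the cache:

```python
import asyncio

local_cache = {"deployment-123_rpm_client": asyncio.Semaphore(10)}

async def call_with_rpm_client(model_id: str) -> str:
    # Mirrors _get_client(..., client_type="rpm_client"): a cache lookup
    # that may return None.
    sem = local_cache.get(f"{model_id}_rpm_client")
    if sem is None:
        return "no rpm client configured"
    async with sem:
        await asyncio.sleep(0)  # stand-in for the provider call
        return "ok"

print(asyncio.run(call_with_rpm_client("deployment-123")))
```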