import asyncio
from typing import TYPE_CHECKING, Any

from litellm.utils import calculate_max_parallel_requests

if TYPE_CHECKING:
    from litellm.router import Router as _Router

    LitellmRouter = _Router
else:
    LitellmRouter = Any


class InitalizeCachedClient:
    @staticmethod
    def set_max_parallel_requests_client(
        litellm_router_instance: LitellmRouter, model: dict
    ):
        # Read the per-deployment rate-limit settings from the model config.
        litellm_params = model.get("litellm_params", {})
        model_id = model["model_info"]["id"]
        rpm = litellm_params.get("rpm", None)
        tpm = litellm_params.get("tpm", None)
        max_parallel_requests = litellm_params.get("max_parallel_requests", None)
        # Derive a concurrency cap from rpm/tpm when one isn't set explicitly.
        calculated_max_parallel_requests = calculate_max_parallel_requests(
            rpm=rpm,
            max_parallel_requests=max_parallel_requests,
            tpm=tpm,
            default_max_parallel_requests=litellm_router_instance.default_max_parallel_requests,
        )
        if calculated_max_parallel_requests:
            # Cache one semaphore per deployment so concurrent calls can be throttled.
            semaphore = asyncio.Semaphore(calculated_max_parallel_requests)
            cache_key = f"{model_id}_max_parallel_requests_client"
            litellm_router_instance.cache.set_cache(
                key=cache_key,
                value=semaphore,
                local_only=True,  # semaphores aren't serializable; keep them in-memory
            )
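

# --- Usage sketch (illustrative, not part of the original file) ---
# A minimal example of how the cached semaphore might be consumed on the
# request path. Assumptions: the router cache exposes a
# get_cache(key=..., local_only=...) read counterpart to the set_cache call
# above, and `_call_deployment` is a hypothetical stand-in for the real
# request coroutine.

async def _call_deployment() -> str:
    # Hypothetical placeholder for the actual deployment call.
    await asyncio.sleep(0)
    return "response"


async def call_with_limit(litellm_router_instance: LitellmRouter, model_id: str):
    semaphore = litellm_router_instance.cache.get_cache(
        key=f"{model_id}_max_parallel_requests_client",
        local_only=True,
    )
    if semaphore is None:
        # No concurrency cap was configured for this deployment.
        return await _call_deployment()
    async with semaphore:
        # At most `calculated_max_parallel_requests` coroutines run here at once.
        return await _call_deployment()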