Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-25 18:54:30 +00:00
fix(router.py): simplify scheduler
Move the scheduler's poll-and-queue logic into the Router class, making it easier to use.
Parent: 27087f60c2
Commit: 7715267989
5 changed files with 177 additions and 131 deletions
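The headline change is that the Router now owns the scheduler: a Scheduler is created in __init__, and a new schedule_acompletion method enqueues a prioritized request, polls until it may run, and then calls acompletion. Below is a minimal caller-side sketch, assuming the usual Router model_list format; the model names, the timeout value, and the convention that a lower priority number is served first are illustrative assumptions, not taken from this diff.

    import asyncio
    from litellm import Router

    router = Router(
        model_list=[
            {
                "model_name": "gpt-3.5-turbo",  # model group name referenced below
                "litellm_params": {"model": "gpt-3.5-turbo"},
            }
        ],
        polling_interval=0.03,  # new in this commit: how often schedule_acompletion polls the queue
        timeout=10,  # schedule_acompletion gives up and raises litellm.Timeout after this many seconds
    )

    async def main():
        response = await router.schedule_acompletion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "Hey, how's it going?"}],
            priority=0,  # assumption: lower number = higher priority
        )
        print(response)

    asyncio.run(main())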
@@ -62,6 +62,7 @@ from litellm.types.llms.openai import (
     Run,
     AssistantToolParam,
 )
+from litellm.scheduler import Scheduler, FlowItem
 from typing import Iterable
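The newly imported Scheduler and FlowItem are the primitives whose polling loop this commit moves into the Router. For orientation, here is a sketch of that flow in isolation, using only calls that appear elsewhere in this diff (Scheduler(polling_interval=...), add_request, poll, scheduler.polling_interval, and the FlowItem fields); the hard-coded healthy-deployment list stands in for the router's real health check and is purely illustrative.

    import asyncio
    import uuid

    from litellm.scheduler import Scheduler, FlowItem

    async def wait_for_turn():
        scheduler = Scheduler(polling_interval=0.03)

        # Enqueue a prioritized request for a model group.
        item = FlowItem(
            priority=0,
            request_id=str(uuid.uuid4()),
            model_name="gpt-3.5-turbo",
        )
        await scheduler.add_request(request=item)

        # Poll until the scheduler says this request may be sent.
        healthy_deployments = [{"model_name": "gpt-3.5-turbo"}]  # placeholder for a real health check
        while not await scheduler.poll(
            id=item.request_id,
            model_name=item.model_name,
            health_deployments=healthy_deployments,
        ):
            await asyncio.sleep(scheduler.polling_interval)

        print("request may be sent")

    asyncio.run(wait_for_turn())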
@@ -87,6 +88,8 @@ class Router:
             List[tuple]
         ] = None,  # if you want to cache across model groups
         client_ttl: int = 3600,  # ttl for cached clients - will re-initialize after this time in seconds
+        ## SCHEDULER ##
+        polling_interval: Optional[float] = None,
         ## RELIABILITY ##
         num_retries: Optional[int] = None,
         timeout: Optional[float] = None,
@@ -141,7 +144,8 @@ class Router:
             cache_kwargs (dict): Additional kwargs to pass to RedisCache. Defaults to {}.
             caching_groups (Optional[List[tuple]]): List of model groups for caching across model groups. Defaults to None.
             client_ttl (int): Time-to-live for cached clients in seconds. Defaults to 3600.
-            num_retries (int): Number of retries for failed requests. Defaults to 0.
+            polling_interval: (Optional[float]): frequency of polling queue. Only for '.scheduler_acompletion()'. Default is 3ms.
+            num_retries (Optional[int]): Number of retries for failed requests. Defaults to 2.
             timeout (Optional[float]): Timeout for requests. Defaults to None.
             default_litellm_params (dict): Default parameters for Router.chat.completion.create. Defaults to {}.
             set_verbose (bool): Flag to set verbose mode. Defaults to False.
@@ -208,6 +212,8 @@ class Router:
             []
         )  # names of models under litellm_params. ex. azure/chatgpt-v-2
         self.deployment_latency_map = {}
+        ### SCHEDULER ###
+        self.scheduler = Scheduler(polling_interval=polling_interval)
         ### CACHING ###
         cache_type: Literal["local", "redis"] = "local"  # default to an in-memory cache
         redis_cache = None
@@ -533,11 +539,17 @@ class Router:
     ) -> ModelResponse:
         ...

+    @overload
+    async def acompletion(
+        self, model: str, messages: List[Dict[str, str]], stream: Union[Literal[True], Literal[False]] = False, **kwargs
+    ) -> Union[CustomStreamWrapper, ModelResponse]:
+        ...
+
     # fmt: on

     # The actual implementation of the function
     async def acompletion(
-        self, model: str, messages: List[Dict[str, str]], stream=False, **kwargs
+        self, model: str, messages: List[Dict[str, str]], stream: bool = False, **kwargs
     ):
         try:
             kwargs["model"] = model
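The added overload pairs with the existing Literal[False] and Literal[True] overloads (the latter is outside the visible context) so type checkers can narrow acompletion's return type from the stream argument, while the implementation signature now spells out stream: bool = False. A small caller-side sketch of what that buys; the import paths for the return types are assumptions, not taken from this diff.

    from litellm.utils import CustomStreamWrapper, ModelResponse

    async def call_both_ways(router, messages):
        # stream omitted / False: type checkers see a ModelResponse
        response: ModelResponse = await router.acompletion(
            model="gpt-3.5-turbo", messages=messages
        )

        # stream=True: type checkers see a CustomStreamWrapper and the chunks are iterated
        stream: CustomStreamWrapper = await router.acompletion(
            model="gpt-3.5-turbo", messages=messages, stream=True
        )
        async for chunk in stream:
            print(chunk)

        return response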
@@ -905,6 +917,81 @@ class Router:
         # If we exit the loop without returning, all tasks failed
         raise Exception("All tasks failed")

+    ### SCHEDULER ###
+
+    # fmt: off
+
+    @overload
+    async def schedule_acompletion(
+        self, model: str, messages: List[Dict[str, str]], priority: int, stream: Literal[False] = False, **kwargs
+    ) -> ModelResponse:
+        ...
+
+    @overload
+    async def schedule_acompletion(
+        self, model: str, messages: List[Dict[str, str]], priority: int, stream: Literal[True], **kwargs
+    ) -> CustomStreamWrapper:
+        ...
+
+    # fmt: on
+
+    async def schedule_acompletion(
+        self,
+        model: str,
+        messages: List[Dict[str, str]],
+        priority: int,
+        stream=False,
+        **kwargs,
+    ):
+        ### FLOW ITEM ###
+        _request_id = str(uuid.uuid4())
+        item = FlowItem(
+            priority=priority,  # 👈 SET PRIORITY FOR REQUEST
+            request_id=_request_id,  # 👈 SET REQUEST ID
+            model_name="gpt-3.5-turbo",  # 👈 SAME as 'Router'
+        )
+        ### [fin] ###
+
+        ## ADDS REQUEST TO QUEUE ##
+        await self.scheduler.add_request(request=item)
+
+        ## POLL QUEUE
+        end_time = time.time() + self.timeout
+        curr_time = time.time()
+        poll_interval = self.scheduler.polling_interval  # poll every 3ms
+        make_request = False
+
+        while curr_time < end_time:
+            _healthy_deployments = await self._async_get_healthy_deployments(
+                model=model
+            )
+            make_request = await self.scheduler.poll(  ## POLL QUEUE ## - returns 'True' if there's healthy deployments OR if request is at top of queue
+                id=item.request_id,
+                model_name=item.model_name,
+                health_deployments=_healthy_deployments,
+            )
+            if make_request:  ## IF TRUE -> MAKE REQUEST
+                break
+            else:  ## ELSE -> loop till default_timeout
+                await asyncio.sleep(poll_interval)
+                curr_time = time.time()
+
+        if make_request:
+            try:
+                _response = await self.acompletion(
+                    model=model, messages=messages, stream=stream, **kwargs
+                )
+                return _response
+            except Exception as e:
+                setattr(e, "priority", priority)
+                raise e
+        else:
+            raise litellm.Timeout(
+                message="Request timed out while polling queue",
+                model=model,
+                llm_provider="openai",
+            )
+
     def image_generation(self, prompt: str, model: str, **kwargs):
         try:
             kwargs["model"] = model
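From the caller's side, the new method has two failure modes visible above: if the queue never clears before self.timeout it raises litellm.Timeout, and if the underlying acompletion call fails, the exception is re-raised with the request's priority attached via setattr. A hedged handling sketch, with the router configured as in the earlier example:

    import litellm

    async def scheduled_call(router, messages, priority=0):
        try:
            return await router.schedule_acompletion(
                model="gpt-3.5-turbo",
                messages=messages,
                priority=priority,
            )
        except litellm.Timeout:
            # The request never left the queue before router.timeout elapsed.
            return None
        except Exception as e:
            # Other failures carry the priority the request was scheduled with.
            print(f"request with priority {getattr(e, 'priority', None)} failed: {e}")
            raise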