mirror of https://github.com/BerriAI/litellm.git
feat(router.py): support fastest response batch completion call
returns fastest response. cancels others.
parent 3558f06de2
commit ecd182eb6a
2 changed files with 102 additions and 17 deletions
@@ -356,7 +356,8 @@ class Router:
                 raise ValueError(f"Item '{fallback_dict}' is not a dictionary.")
             if len(fallback_dict) != 1:
                 raise ValueError(
-                    f"Dictionary '{fallback_dict}' must have exactly one key, but has {len(fallback_dict)} keys.")
+                    f"Dictionary '{fallback_dict}' must have exactly one key, but has {len(fallback_dict)} keys."
+                )
 
     def routing_strategy_init(self, routing_strategy: str, routing_strategy_args: dict):
         if routing_strategy == "least-busy":
@@ -737,6 +738,76 @@ class Router:
         response = await asyncio.gather(*_tasks)
         return response
 
+    # fmt: off
+
+    @overload
+    async def abatch_completion_fastest_response(
+        self, models: List[str], messages: List[Dict[str, str]], stream: Literal[True], **kwargs
+    ) -> CustomStreamWrapper:
+        ...
+
+
+
+    @overload
+    async def abatch_completion_fastest_response(
+        self, models: List[str], messages: List[Dict[str, str]], stream: Literal[False] = False, **kwargs
+    ) -> ModelResponse:
+        ...
+
+    # fmt: on
+
+    async def abatch_completion_fastest_response(
+        self,
+        models: List[str],
+        messages: List[Dict[str, str]],
+        stream: bool = False,
+        **kwargs,
+    ):
+        """Send 1 completion call to many models: Return Fastest Response."""
+
+        async def _async_completion_no_exceptions(
+            model: str, messages: List[Dict[str, str]], **kwargs
+        ):
+            """
+            Wrapper around self.acompletion that catches exceptions and returns them as a result
+            """
+            try:
+                return await self.acompletion(model=model, messages=messages, **kwargs)
+            except Exception as e:
+                return e
+
+        _tasks = []
+        pending_tasks = []  # type: ignore
+
+        async def check_response(task):
+            nonlocal pending_tasks
+            result = await task
+            if isinstance(result, (ModelResponse, CustomStreamWrapper)):
+                # If a desired response is received, cancel all other pending tasks
+                for t in pending_tasks:
+                    t.cancel()
+                return result
+            else:
+                try:
+                    pending_tasks.remove(task)
+                except Exception as e:
+                    pass
+
+        for model in models:
+            task = asyncio.create_task(
+                _async_completion_no_exceptions(
+                    model=model, messages=messages, **kwargs
+                )
+            )
+            task.add_done_callback(check_response)
+            _tasks.append(task)
+            pending_tasks.append(task)
+
+        responses = await asyncio.gather(*_tasks, return_exceptions=True)
+        if isinstance(responses[0], Exception):
+            raise responses[0]
+        return responses[0]  # return first value from list
+
     def image_generation(self, prompt: str, model: str, **kwargs):
         try:
             kwargs["model"] = model
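For illustration, a minimal sketch of how the new method could be called. The Router configuration below is hypothetical (model names and parameters are placeholders, not part of this commit), and it assumes provider API keys are available in the environment:

import asyncio

from litellm import Router

# Hypothetical configuration; model names and params are placeholders.
router = Router(
    model_list=[
        {"model_name": "gpt-3.5-turbo", "litellm_params": {"model": "gpt-3.5-turbo"}},
        {"model_name": "claude-3-haiku", "litellm_params": {"model": "claude-3-haiku-20240307"}},
    ]
)

async def main():
    # Fan the same request out to both models; the fastest successful
    # response wins and the remaining calls are cancelled.
    response = await router.abatch_completion_fastest_response(
        models=["gpt-3.5-turbo", "claude-3-haiku"],
        messages=[{"role": "user", "content": "ping"}],
    )
    print(response)

asyncio.run(main())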
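The paired @overload declarations let type checkers narrow the return type from the stream argument: Literal[True] maps to the streaming wrapper, Literal[False] to a regular response. A standalone sketch of that typing pattern, using made-up Response/Stream stand-ins rather than litellm's own classes:

import asyncio
from typing import List, Literal, Union, overload

class Response:  # stand-in for ModelResponse
    pass

class Stream:  # stand-in for CustomStreamWrapper
    pass

@overload
async def fastest(models: List[str], stream: Literal[True]) -> Stream: ...
@overload
async def fastest(models: List[str], stream: Literal[False] = False) -> Response: ...

async def fastest(models: List[str], stream: bool = False) -> Union[Stream, Response]:
    # Single runtime implementation; the overloads above exist only for
    # static analysis, so callers see Stream when stream=True and
    # Response otherwise, without casting.
    return Stream() if stream else Response()

async def demo() -> None:
    r = await fastest(["a", "b"])               # type checkers infer Response
    s = await fastest(["a", "b"], stream=True)  # type checkers infer Stream
    print(type(r).__name__, type(s).__name__)

asyncio.run(demo())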
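The underlying idea is a race: start one task per model, keep the first result that finishes, and cancel the rest. A minimal self-contained sketch of that pattern with a fake _call helper (this uses asyncio.wait with FIRST_COMPLETED rather than the task-callback approach in the diff):

import asyncio
from typing import List

async def _call(model: str, prompt: str) -> str:
    """Fake provider call; latency differs per model."""
    await asyncio.sleep(0.1 if model == "fast-model" else 1.0)
    return f"{model}: {prompt}"

async def fastest_response(models: List[str], prompt: str) -> str:
    # One task per model; wait for the first to complete,
    # then cancel the others so no work is wasted.
    tasks = [asyncio.create_task(_call(m, prompt)) for m in models]
    done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
    for task in pending:
        task.cancel()
    return next(iter(done)).result()

print(asyncio.run(fastest_response(["fast-model", "slow-model"], "ping")))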