feat(router.py): support fastest response batch completion call

returns fastest response. cancels others.
This commit is contained in:
Krrish Dholakia 2024-05-28 19:44:41 -07:00
parent 3558f06de2
commit ecd182eb6a
2 changed files with 102 additions and 17 deletions

View file

@ -356,7 +356,8 @@ class Router:
raise ValueError(f"Item '{fallback_dict}' is not a dictionary.")
if len(fallback_dict) != 1:
raise ValueError(
f"Dictionary '{fallback_dict}' must have exactly one key, but has {len(fallback_dict)} keys.")
f"Dictionary '{fallback_dict}' must have exactly one key, but has {len(fallback_dict)} keys."
)
def routing_strategy_init(self, routing_strategy: str, routing_strategy_args: dict):
if routing_strategy == "least-busy":
@ -737,6 +738,76 @@ class Router:
response = await asyncio.gather(*_tasks)
return response
# fmt: off
@overload
async def abatch_completion_fastest_response(
self, models: List[str], messages: List[Dict[str, str]], stream: Literal[True], **kwargs
) -> CustomStreamWrapper:
...
@overload
async def abatch_completion_fastest_response(
self, models: List[str], messages: List[Dict[str, str]], stream: Literal[False] = False, **kwargs
) -> ModelResponse:
...
# fmt: on
async def abatch_completion_fastest_response(
self,
models: List[str],
messages: List[Dict[str, str]],
stream: bool = False,
**kwargs,
):
"""Send 1 completion call to many models: Return Fastest Response."""
async def _async_completion_no_exceptions(
model: str, messages: List[Dict[str, str]], **kwargs
):
"""
Wrapper around self.async_completion that catches exceptions and returns them as a result
"""
try:
return await self.acompletion(model=model, messages=messages, **kwargs)
except Exception as e:
return e
_tasks = []
pending_tasks = [] # type: ignore
async def check_response(task):
nonlocal pending_tasks
result = await task
if isinstance(result, (ModelResponse, CustomStreamWrapper)):
# If a desired response is received, cancel all other pending tasks
for t in pending_tasks:
t.cancel()
return result
else:
try:
pending_tasks.remove(task)
except Exception as e:
pass
for model in models:
task = asyncio.create_task(
_async_completion_no_exceptions(
model=model, messages=messages, **kwargs
)
)
task.add_done_callback(check_response)
_tasks.append(task)
pending_tasks.append(task)
responses = await asyncio.gather(*_tasks, return_exceptions=True)
if isinstance(responses[0], Exception):
raise responses[0]
return responses[0] # return first value from list
def image_generation(self, prompt: str, model: str, **kwargs):
try:
kwargs["model"] = model