feat - router add abatch_completion

This commit is contained in:
Ishaan Jaff 2024-05-28 22:19:33 -07:00
parent 6bf6059b3e
commit 473ec66b84
2 changed files with 96 additions and 13 deletions

View file

@ -356,7 +356,8 @@ class Router:
raise ValueError(f"Item '{fallback_dict}' is not a dictionary.")
if len(fallback_dict) != 1:
raise ValueError(
f"Dictionary '{fallback_dict}' must have exactly one key, but has {len(fallback_dict)} keys.")
f"Dictionary '{fallback_dict}' must have exactly one key, but has {len(fallback_dict)} keys."
)
def routing_strategy_init(self, routing_strategy: str, routing_strategy_args: dict):
if routing_strategy == "least-busy":
@ -662,12 +663,17 @@ class Router:
raise e
async def abatch_completion(
self, models: List[str], messages: List[Dict[str, str]], **kwargs
self,
models: List[str],
messages: Union[List[Dict[str, str]], List[List[Dict[str, str]]]],
**kwargs,
):
"""
Async Batch Completion - Batch Process 1 request to multiple model_group on litellm.Router
Use this for sending the same request to N models
Async Batch Completion. Used for 2 scenarios:
1. Batch Process 1 request to N models on litellm.Router. Pass messages as List[Dict[str, str]] to use this
2. Batch Process N requests to M models on litellm.Router. Pass messages as List[List[Dict[str, str]]] to use this
"""
############## Helpers for async completion ##################
async def _async_completion_no_exceptions(
model: str, messages: List[Dict[str, str]], **kwargs
@ -680,17 +686,50 @@ class Router:
except Exception as e:
return e
_tasks = []
for model in models:
# add each task but if the task fails
_tasks.append(
_async_completion_no_exceptions(
model=model, messages=messages, **kwargs
async def _async_completion_no_exceptions_return_idx(
model: str,
messages: List[Dict[str, str]],
idx: int, # index of message this response corresponds to
**kwargs,
):
"""
Wrapper around self.async_completion that catches exceptions and returns them as a result
"""
try:
return (
await self.acompletion(model=model, messages=messages, **kwargs),
idx,
)
)
except Exception as e:
return e, idx
response = await asyncio.gather(*_tasks)
return response
############## Helpers for async completion ##################
if isinstance(messages, list) and all(isinstance(m, dict) for m in messages):
_tasks = []
for model in models:
# add each task but if the task fails
_tasks.append(_async_completion_no_exceptions(model=model, messages=messages, **kwargs)) # type: ignore
response = await asyncio.gather(*_tasks)
return response
elif isinstance(messages, list) and all(isinstance(m, list) for m in messages):
_tasks = []
for idx, message in enumerate(messages):
for model in models:
# Request Number X, Model Number Y
_tasks.append(
_async_completion_no_exceptions_return_idx(
model=model, idx=idx, messages=message, **kwargs # type: ignore
)
)
responses = await asyncio.gather(*_tasks)
final_responses: List[List[Any]] = [[] for _ in range(len(messages))]
for response in responses:
if isinstance(response, tuple):
final_responses[response[1]].append(response[0])
else:
final_responses[0].append(response)
return final_responses
async def abatch_completion_one_model_multiple_requests(
self, model: str, messages: List[List[Dict[str, str]]], **kwargs