Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-25 18:54:30 +00:00
feat - router add abatch_completion

commit 473ec66b84 (parent 6bf6059b3e)
2 changed files with 96 additions and 13 deletions
@@ -356,7 +356,8 @@ class Router:
                 raise ValueError(f"Item '{fallback_dict}' is not a dictionary.")
             if len(fallback_dict) != 1:
                 raise ValueError(
-                    f"Dictionary '{fallback_dict}' must have exactly one key, but has {len(fallback_dict)} keys.")
+                    f"Dictionary '{fallback_dict}' must have exactly one key, but has {len(fallback_dict)} keys."
+                )
 
     def routing_strategy_init(self, routing_strategy: str, routing_strategy_args: dict):
         if routing_strategy == "least-busy":
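For context on the hunk above: litellm fallbacks are declared as a list of single-key dictionaries, each mapping one model group to its list of fallback groups, and the reformatted ValueError fires when an entry carries more than one key. A minimal sketch of shapes that pass and fail this check; the model-group names are placeholders, not part of this commit:

# Hypothetical fallback configs illustrating the single-key rule enforced above.
valid_fallbacks = [
    {"gpt-4": ["gpt-3.5-turbo"]},  # exactly one key: passes validation
]
invalid_fallbacks = [
    {"gpt-4": ["gpt-3.5-turbo"], "claude-3": ["gpt-4"]},  # two keys: raises ValueError
]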
@@ -662,12 +663,17 @@ class Router:
             raise e
 
     async def abatch_completion(
-        self, models: List[str], messages: List[Dict[str, str]], **kwargs
+        self,
+        models: List[str],
+        messages: Union[List[Dict[str, str]], List[List[Dict[str, str]]]],
+        **kwargs,
     ):
         """
-        Async Batch Completion - Batch Process 1 request to multiple model_group on litellm.Router
-        Use this for sending the same request to N models
+        Async Batch Completion. Used for 2 scenarios:
+        1. Batch Process 1 request to N models on litellm.Router. Pass messages as List[Dict[str, str]] to use this
+        2. Batch Process N requests to M models on litellm.Router. Pass messages as List[List[Dict[str, str]]] to use this
         """
+        ############## Helpers for async completion ##################
 
         async def _async_completion_no_exceptions(
             model: str, messages: List[Dict[str, str]], **kwargs
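The widened signature above accepts either one conversation (List[Dict[str, str]]) or a batch of conversations (List[List[Dict[str, str]]]). A minimal usage sketch against the new API, assuming a Router configured with two model groups; the group names and litellm_params are placeholders, not part of this commit:

import asyncio
from litellm import Router

router = Router(
    model_list=[
        {"model_name": "gpt-3.5-turbo", "litellm_params": {"model": "gpt-3.5-turbo"}},
        {"model_name": "groq-llama", "litellm_params": {"model": "groq/llama3-8b-8192"}},
    ]
)

async def main():
    # Scenario 1: fan one request out to N model groups -> flat list of N responses.
    flat = await router.abatch_completion(
        models=["gpt-3.5-turbo", "groq-llama"],
        messages=[{"role": "user", "content": "hey, how's it going?"}],
    )
    # Scenario 2: fan N requests out to M model groups -> one list of M responses per request.
    grouped = await router.abatch_completion(
        models=["gpt-3.5-turbo", "groq-llama"],
        messages=[
            [{"role": "user", "content": "what is the capital of France?"}],
            [{"role": "user", "content": "describe litellm in one line"}],
        ],
    )

asyncio.run(main())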
@@ -680,17 +686,50 @@ class Router:
             except Exception as e:
                 return e
 
-        _tasks = []
-        for model in models:
-            # add each task but if the task fails
-            _tasks.append(
-                _async_completion_no_exceptions(
-                    model=model, messages=messages, **kwargs
-                )
-            )
-
-        response = await asyncio.gather(*_tasks)
-        return response
+        async def _async_completion_no_exceptions_return_idx(
+            model: str,
+            messages: List[Dict[str, str]],
+            idx: int,  # index of message this response corresponds to
+            **kwargs,
+        ):
+            """
+            Wrapper around self.async_completion that catches exceptions and returns them as a result
+            """
+            try:
+                return (
+                    await self.acompletion(model=model, messages=messages, **kwargs),
+                    idx,
+                )
+            except Exception as e:
+                return e, idx
+
+        ############## Helpers for async completion ##################
+
+        if isinstance(messages, list) and all(isinstance(m, dict) for m in messages):
+            _tasks = []
+            for model in models:
+                # add each task but if the task fails
+                _tasks.append(_async_completion_no_exceptions(model=model, messages=messages, **kwargs))  # type: ignore
+            response = await asyncio.gather(*_tasks)
+            return response
+        elif isinstance(messages, list) and all(isinstance(m, list) for m in messages):
+            _tasks = []
+            for idx, message in enumerate(messages):
+                for model in models:
+                    # Request Number X, Model Number Y
+                    _tasks.append(
+                        _async_completion_no_exceptions_return_idx(
+                            model=model, idx=idx, messages=message, **kwargs  # type: ignore
+                        )
+                    )
+            responses = await asyncio.gather(*_tasks)
+            final_responses: List[List[Any]] = [[] for _ in range(len(messages))]
+            for response in responses:
+                if isinstance(response, tuple):
+                    final_responses[response[1]].append(response[0])
+                else:
+                    final_responses[0].append(response)
+            return final_responses
 
     async def abatch_completion_one_model_multiple_requests(
         self, model: str, messages: List[List[Dict[str, str]]], **kwargs
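The elif branch above relies on a small pattern worth noting: each coroutine is tagged with the index of the request it serves, so the flat list returned by asyncio.gather can be regrouped per request, and exceptions are returned as results instead of cancelling the whole batch. A self-contained sketch of the same pattern; the call_model helper is a stand-in for self.acompletion, not part of this commit:

import asyncio
from typing import Any, List

async def call_model(model: str, prompt: str) -> str:
    # Placeholder for a real completion call such as router.acompletion.
    await asyncio.sleep(0)
    return f"{model}: {prompt}"

async def fan_out(models: List[str], prompts: List[str]) -> List[List[Any]]:
    async def tagged(model: str, prompt: str, idx: int):
        # Return (result, idx) so results can be regrouped after gather().
        try:
            return await call_model(model, prompt), idx
        except Exception as e:
            return e, idx  # exceptions become results instead of failing the batch

    tasks = [
        tagged(model, prompt, idx)
        for idx, prompt in enumerate(prompts)
        for model in models
    ]
    grouped: List[List[Any]] = [[] for _ in prompts]
    for result, idx in await asyncio.gather(*tasks):
        grouped[idx].append(result)
    return grouped

print(asyncio.run(fan_out(["a", "b"], ["p1", "p2"])))
# [['a: p1', 'b: p1'], ['a: p2', 'b: p2']]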