forked from phoenix/litellm-mirror
Merge pull request #3889 from BerriAI/litellm_router_batch_n_models_m_messages
feat - router add abatch_completion - N Models, M Messages
This commit is contained in:
commit
1b53a1e98d
2 changed files with 119 additions and 13 deletions
|
@ -356,7 +356,8 @@ class Router:
|
|||
raise ValueError(f"Item '{fallback_dict}' is not a dictionary.")
|
||||
if len(fallback_dict) != 1:
|
||||
raise ValueError(
|
||||
f"Dictionary '{fallback_dict}' must have exactly one key, but has {len(fallback_dict)} keys.")
|
||||
f"Dictionary '{fallback_dict}' must have exactly one key, but has {len(fallback_dict)} keys."
|
||||
)
|
||||
|
||||
def routing_strategy_init(self, routing_strategy: str, routing_strategy_args: dict):
|
||||
if routing_strategy == "least-busy":
|
||||
|
@ -662,12 +663,40 @@ class Router:
|
|||
raise e
|
||||
|
||||
async def abatch_completion(
|
||||
self, models: List[str], messages: List[Dict[str, str]], **kwargs
|
||||
self,
|
||||
models: List[str],
|
||||
messages: Union[List[Dict[str, str]], List[List[Dict[str, str]]]],
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Async Batch Completion - Batch Process 1 request to multiple model_group on litellm.Router
|
||||
Use this for sending the same request to N models
|
||||
Async Batch Completion. Used for 2 scenarios:
|
||||
1. Batch Process 1 request to N models on litellm.Router. Pass messages as List[Dict[str, str]] to use this
|
||||
2. Batch Process N requests to M models on litellm.Router. Pass messages as List[List[Dict[str, str]]] to use this
|
||||
|
||||
Example Request for 1 request to N models:
|
||||
```
|
||||
response = await router.abatch_completion(
|
||||
models=["gpt-3.5-turbo", "groq-llama"],
|
||||
messages=[
|
||||
{"role": "user", "content": "is litellm becoming a better product ?"}
|
||||
],
|
||||
max_tokens=15,
|
||||
)
|
||||
```
|
||||
|
||||
|
||||
Example Request for N requests to M models:
|
||||
```
|
||||
response = await router.abatch_completion(
|
||||
models=["gpt-3.5-turbo", "groq-llama"],
|
||||
messages=[
|
||||
[{"role": "user", "content": "is litellm becoming a better product ?"}],
|
||||
[{"role": "user", "content": "who is this"}],
|
||||
],
|
||||
)
|
||||
```
|
||||
"""
|
||||
############## Helpers for async completion ##################
|
||||
|
||||
async def _async_completion_no_exceptions(
|
||||
model: str, messages: List[Dict[str, str]], **kwargs
|
||||
|
@ -680,17 +709,50 @@ class Router:
|
|||
except Exception as e:
|
||||
return e
|
||||
|
||||
async def _async_completion_no_exceptions_return_idx(
|
||||
model: str,
|
||||
messages: List[Dict[str, str]],
|
||||
idx: int, # index of message this response corresponds to
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Wrapper around self.async_completion that catches exceptions and returns them as a result
|
||||
"""
|
||||
try:
|
||||
return (
|
||||
await self.acompletion(model=model, messages=messages, **kwargs),
|
||||
idx,
|
||||
)
|
||||
except Exception as e:
|
||||
return e, idx
|
||||
|
||||
############## Helpers for async completion ##################
|
||||
|
||||
if isinstance(messages, list) and all(isinstance(m, dict) for m in messages):
|
||||
_tasks = []
|
||||
for model in models:
|
||||
# add each task but if the task fails
|
||||
_tasks.append(
|
||||
_async_completion_no_exceptions(
|
||||
model=model, messages=messages, **kwargs
|
||||
)
|
||||
)
|
||||
|
||||
_tasks.append(_async_completion_no_exceptions(model=model, messages=messages, **kwargs)) # type: ignore
|
||||
response = await asyncio.gather(*_tasks)
|
||||
return response
|
||||
elif isinstance(messages, list) and all(isinstance(m, list) for m in messages):
|
||||
_tasks = []
|
||||
for idx, message in enumerate(messages):
|
||||
for model in models:
|
||||
# Request Number X, Model Number Y
|
||||
_tasks.append(
|
||||
_async_completion_no_exceptions_return_idx(
|
||||
model=model, idx=idx, messages=message, **kwargs # type: ignore
|
||||
)
|
||||
)
|
||||
responses = await asyncio.gather(*_tasks)
|
||||
final_responses: List[List[Any]] = [[] for _ in range(len(messages))]
|
||||
for response in responses:
|
||||
if isinstance(response, tuple):
|
||||
final_responses[response[1]].append(response[0])
|
||||
else:
|
||||
final_responses[0].append(response)
|
||||
return final_responses
|
||||
|
||||
async def abatch_completion_one_model_multiple_requests(
|
||||
self, model: str, messages: List[List[Dict[str, str]]], **kwargs
|
||||
|
|
|
@ -58,3 +58,47 @@ async def test_batch_completion_multiple_models():
|
|||
|
||||
# assert both models are different
|
||||
assert models_in_responses[0] != models_in_responses[1]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_batch_completion_multiple_models_multiple_messages():
|
||||
litellm.set_verbose = True
|
||||
|
||||
router = litellm.Router(
|
||||
model_list=[
|
||||
{
|
||||
"model_name": "gpt-3.5-turbo",
|
||||
"litellm_params": {
|
||||
"model": "gpt-3.5-turbo",
|
||||
},
|
||||
},
|
||||
{
|
||||
"model_name": "groq-llama",
|
||||
"litellm_params": {
|
||||
"model": "groq/llama3-8b-8192",
|
||||
},
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
response = await router.abatch_completion(
|
||||
models=["gpt-3.5-turbo", "groq-llama"],
|
||||
messages=[
|
||||
[{"role": "user", "content": "is litellm becoming a better product ?"}],
|
||||
[{"role": "user", "content": "who is this"}],
|
||||
],
|
||||
max_tokens=15,
|
||||
)
|
||||
|
||||
print("response from batches =", response)
|
||||
assert len(response) == 2
|
||||
assert len(response[0]) == 2
|
||||
assert isinstance(response[0][0], litellm.ModelResponse)
|
||||
|
||||
# models_in_responses = []
|
||||
# for individual_response in response:
|
||||
# _model = individual_response["model"]
|
||||
# models_in_responses.append(_model)
|
||||
|
||||
# # assert both models are different
|
||||
# assert models_in_responses[0] != models_in_responses[1]
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue