Merge pull request #3889 from BerriAI/litellm_router_batch_n_models_m_messages

feat - router add abatch_completion - N Models, M Messages
Ishaan Jaff 2024-05-28 22:29:34 -07:00 committed by GitHub
commit 1b53a1e98d
2 changed files with 119 additions and 13 deletions
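
At a glance, the new N-models x M-messages mode fans every message list out to every model group and returns one list of responses per message list. A minimal usage sketch of the new call (model names and prompts mirror the docstring and test in this diff; the sketch itself is not part of the change):

```
import asyncio

import litellm

router = litellm.Router(
    model_list=[
        {"model_name": "gpt-3.5-turbo", "litellm_params": {"model": "gpt-3.5-turbo"}},
        {"model_name": "groq-llama", "litellm_params": {"model": "groq/llama3-8b-8192"}},
    ]
)


async def main():
    # N requests x M models: pass messages as a list of message lists
    responses = await router.abatch_completion(
        models=["gpt-3.5-turbo", "groq-llama"],
        messages=[
            [{"role": "user", "content": "is litellm becoming a better product ?"}],
            [{"role": "user", "content": "who is this"}],
        ],
        max_tokens=15,
    )
    # responses[i] holds one entry per model for the i-th message list
    print(len(responses), len(responses[0]))  # expect: 2 2


asyncio.run(main())
```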


@@ -356,7 +356,8 @@ class Router:
                 raise ValueError(f"Item '{fallback_dict}' is not a dictionary.")
             if len(fallback_dict) != 1:
                 raise ValueError(
-                    f"Dictionary '{fallback_dict}' must have exactly one key, but has {len(fallback_dict)} keys.")
+                    f"Dictionary '{fallback_dict}' must have exactly one key, but has {len(fallback_dict)} keys."
+                )
 
     def routing_strategy_init(self, routing_strategy: str, routing_strategy_args: dict):
         if routing_strategy == "least-busy":
@@ -662,12 +663,40 @@ class Router:
             raise e
 
     async def abatch_completion(
-        self, models: List[str], messages: List[Dict[str, str]], **kwargs
+        self,
+        models: List[str],
+        messages: Union[List[Dict[str, str]], List[List[Dict[str, str]]]],
+        **kwargs,
     ):
         """
-        Async Batch Completion - Batch Process 1 request to multiple model_group on litellm.Router
-        Use this for sending the same request to N models
+        Async Batch Completion. Used for 2 scenarios:
+        1. Batch Process 1 request to N models on litellm.Router. Pass messages as List[Dict[str, str]] to use this
+        2. Batch Process N requests to M models on litellm.Router. Pass messages as List[List[Dict[str, str]]] to use this
+
+        Example Request for 1 request to N models:
+        ```
+            response = await router.abatch_completion(
+                models=["gpt-3.5-turbo", "groq-llama"],
+                messages=[
+                    {"role": "user", "content": "is litellm becoming a better product ?"}
+                ],
+                max_tokens=15,
+            )
+        ```
+
+        Example Request for N requests to M models:
+        ```
+            response = await router.abatch_completion(
+                models=["gpt-3.5-turbo", "groq-llama"],
+                messages=[
+                    [{"role": "user", "content": "is litellm becoming a better product ?"}],
+                    [{"role": "user", "content": "who is this"}],
+                ],
+            )
+        ```
         """
+        ############## Helpers for async completion ##################
 
         async def _async_completion_no_exceptions(
             model: str, messages: List[Dict[str, str]], **kwargs
@@ -680,17 +709,50 @@ class Router:
             except Exception as e:
                 return e
 
-        _tasks = []
-        for model in models:
-            # add each task but if the task fails
-            _tasks.append(
-                _async_completion_no_exceptions(
-                    model=model, messages=messages, **kwargs
-                )
-            )
+        async def _async_completion_no_exceptions_return_idx(
+            model: str,
+            messages: List[Dict[str, str]],
+            idx: int,  # index of message this response corresponds to
+            **kwargs,
+        ):
+            """
+            Wrapper around self.async_completion that catches exceptions and returns them as a result
+            """
+            try:
+                return (
+                    await self.acompletion(model=model, messages=messages, **kwargs),
+                    idx,
+                )
+            except Exception as e:
+                return e, idx
 
-        response = await asyncio.gather(*_tasks)
-        return response
+        ############## Helpers for async completion ##################
+
+        if isinstance(messages, list) and all(isinstance(m, dict) for m in messages):
+            _tasks = []
+            for model in models:
+                # add each task but if the task fails
+                _tasks.append(_async_completion_no_exceptions(model=model, messages=messages, **kwargs))  # type: ignore
+            response = await asyncio.gather(*_tasks)
+            return response
+        elif isinstance(messages, list) and all(isinstance(m, list) for m in messages):
+            _tasks = []
+            for idx, message in enumerate(messages):
+                for model in models:
+                    # Request Number X, Model Number Y
+                    _tasks.append(
+                        _async_completion_no_exceptions_return_idx(
+                            model=model, idx=idx, messages=message, **kwargs  # type: ignore
+                        )
+                    )
+            responses = await asyncio.gather(*_tasks)
+            final_responses: List[List[Any]] = [[] for _ in range(len(messages))]
+            for response in responses:
+                if isinstance(response, tuple):
+                    final_responses[response[1]].append(response[0])
+                else:
+                    final_responses[0].append(response)
+            return final_responses
 
     async def abatch_completion_one_model_multiple_requests(
         self, model: str, messages: List[List[Dict[str, str]]], **kwargs
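
The N x M branch above tags each coroutine's result with the index of the message list it belongs to, then buckets results after asyncio.gather, so per-request grouping survives out-of-order completion. A standalone sketch of that pattern, assuming plain asyncio with stand-in names (no litellm calls):

```
import asyncio
from typing import Any, List


async def _call(model: str, prompt: str) -> str:
    # stand-in for an async completion call
    await asyncio.sleep(0)
    return f"{model}:{prompt}"


async def _call_with_idx(model: str, prompt: str, idx: int):
    # return (result, idx) so the caller knows which request this belongs to;
    # exceptions are returned, not raised, so one failure doesn't sink the batch
    try:
        return await _call(model, prompt), idx
    except Exception as e:
        return e, idx


async def fan_out(models: List[str], prompts: List[str]) -> List[List[Any]]:
    tasks = [
        _call_with_idx(model, prompt, idx)
        for idx, prompt in enumerate(prompts)
        for model in models
    ]
    results = await asyncio.gather(*tasks)  # gather preserves task order
    grouped: List[List[Any]] = [[] for _ in prompts]
    for result, idx in results:
        grouped[idx].append(result)
    return grouped


print(asyncio.run(fan_out(["a", "b"], ["x", "y"])))
# [['a:x', 'b:x'], ['a:y', 'b:y']]
```

Returning exceptions as values instead of raising them is the same design choice as the `_async_completion_no_exceptions*` helpers in the diff: a single failing model call does not cancel the rest of the gathered batch.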


@@ -58,3 +58,47 @@ async def test_batch_completion_multiple_models():
     # assert both models are different
     assert models_in_responses[0] != models_in_responses[1]
+
+
+@pytest.mark.asyncio
+async def test_batch_completion_multiple_models_multiple_messages():
+    litellm.set_verbose = True
+
+    router = litellm.Router(
+        model_list=[
+            {
+                "model_name": "gpt-3.5-turbo",
+                "litellm_params": {
+                    "model": "gpt-3.5-turbo",
+                },
+            },
+            {
+                "model_name": "groq-llama",
+                "litellm_params": {
+                    "model": "groq/llama3-8b-8192",
+                },
+            },
+        ]
+    )
+
+    response = await router.abatch_completion(
+        models=["gpt-3.5-turbo", "groq-llama"],
+        messages=[
+            [{"role": "user", "content": "is litellm becoming a better product ?"}],
+            [{"role": "user", "content": "who is this"}],
+        ],
+        max_tokens=15,
+    )
+
+    print("response from batches =", response)
+
+    assert len(response) == 2
+    assert len(response[0]) == 2
+    assert isinstance(response[0][0], litellm.ModelResponse)
+
+    # models_in_responses = []
+    # for individual_response in response:
+    #     _model = individual_response["model"]
+    #     models_in_responses.append(_model)
+
+    # # assert both models are different
+    # assert models_in_responses[0] != models_in_responses[1]
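
One caveat for callers, visible in the helpers above: a failed call is returned as the exception object inside the nested list rather than raised, so it is worth filtering each inner list before reading response fields. A hedged consumption sketch (response shape follows the test above; the `.choices[0].message.content` access assumes litellm's OpenAI-style ModelResponse):

```
for request_idx, per_model_responses in enumerate(response):
    for item in per_model_responses:
        if isinstance(item, Exception):
            # failed model calls surface as the exception object itself
            print(f"request {request_idx} hit an error: {item}")
        else:
            print(f"request {request_idx} answer:", item.choices[0].message.content)
```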