feat(router.py): support fastest response batch completion call

Sends one completion request to multiple model deployments, returns the fastest response, and cancels the remaining in-flight calls.
Krrish Dholakia 2024-05-28 19:44:41 -07:00
parent 7b565271e2
commit 3676c00235
2 changed files with 102 additions and 17 deletions

litellm/router.py

@@ -356,7 +356,8 @@ class Router:
                 raise ValueError(f"Item '{fallback_dict}' is not a dictionary.")
             if len(fallback_dict) != 1:
                 raise ValueError(
-                    f"Dictionary '{fallback_dict}' must have exactly one key, but has {len(fallback_dict)} keys.")
+                    f"Dictionary '{fallback_dict}' must have exactly one key, but has {len(fallback_dict)} keys."
+                )
 
     def routing_strategy_init(self, routing_strategy: str, routing_strategy_args: dict):
         if routing_strategy == "least-busy":
@@ -737,6 +738,76 @@ class Router:
         response = await asyncio.gather(*_tasks)
         return response
 
+    # fmt: off
+
+    @overload
+    async def abatch_completion_fastest_response(
+        self, models: List[str], messages: List[Dict[str, str]], stream: Literal[True], **kwargs
+    ) -> CustomStreamWrapper:
+        ...
+
+
+    @overload
+    async def abatch_completion_fastest_response(
+        self, models: List[str], messages: List[Dict[str, str]], stream: Literal[False] = False, **kwargs
+    ) -> ModelResponse:
+        ...
+
+    # fmt: on
+
+    async def abatch_completion_fastest_response(
+        self,
+        models: List[str],
+        messages: List[Dict[str, str]],
+        stream: bool = False,
+        **kwargs,
+    ):
+        """Send 1 completion call to many models: Return Fastest Response."""
+
+        async def _async_completion_no_exceptions(
+            model: str, messages: List[Dict[str, str]], **kwargs
+        ):
+            """
+            Wrapper around self.acompletion that catches exceptions and returns them as a result
+            """
+            try:
+                return await self.acompletion(model=model, messages=messages, **kwargs)
+            except Exception as e:
+                return e
+
+        _tasks = []
+        pending_tasks = []  # type: ignore
+
+        async def check_response(task):
+            nonlocal pending_tasks
+            result = await task
+            if isinstance(result, (ModelResponse, CustomStreamWrapper)):
+                # If a desired response is received, cancel all other pending tasks
+                for t in pending_tasks:
+                    t.cancel()
+                return result
+            else:
+                try:
+                    pending_tasks.remove(task)
+                except Exception as e:
+                    pass
+
+        for model in models:
+            task = asyncio.create_task(
+                _async_completion_no_exceptions(
+                    model=model, messages=messages, **kwargs
+                )
+            )
+            task.add_done_callback(check_response)
+            _tasks.append(task)
+            pending_tasks.append(task)
+
+        responses = await asyncio.gather(*_tasks, return_exceptions=True)
+
+        if isinstance(responses[0], Exception):
+            raise responses[0]
+        return responses[0]  # return first value from list
+
     def image_generation(self, prompt: str, model: str, **kwargs):
         try:
             kwargs["model"] = model
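For reference, a minimal usage sketch of the new method (not part of this commit). The model_list entries, the groq model string, and the API keys expected in the environment are assumptions for illustration; the actual Router configuration used by the test below is not shown in this diff.

import asyncio
import litellm

# Assumed deployment config for illustration only (requires OPENAI_API_KEY / GROQ_API_KEY).
router = litellm.Router(
    model_list=[
        {"model_name": "gpt-3.5-turbo", "litellm_params": {"model": "gpt-3.5-turbo"}},
        {"model_name": "groq-llama", "litellm_params": {"model": "groq/llama3-8b-8192"}},
    ]
)

async def main():
    # One prompt fanned out to both deployments; the first ModelResponse wins
    # and the slower in-flight call is cancelled.
    response = await router.abatch_completion_fastest_response(
        models=["gpt-3.5-turbo", "groq-llama"],
        messages=[{"role": "user", "content": "is litellm becoming a better product ?"}],
        max_tokens=15,
    )
    print(response)

asyncio.run(main())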


@@ -19,8 +19,9 @@ import os, httpx
 load_dotenv()
 
 
+@pytest.mark.parametrize("mode", ["all_responses", "fastest_response"])
 @pytest.mark.asyncio
-async def test_batch_completion_multiple_models():
+async def test_batch_completion_multiple_models(mode):
     litellm.set_verbose = True
 
     router = litellm.Router(
@@ -40,21 +41,34 @@ async def test_batch_completion_multiple_models():
         ]
     )
 
-    response = await router.abatch_completion(
-        models=["gpt-3.5-turbo", "groq-llama"],
-        messages=[
-            {"role": "user", "content": "is litellm becoming a better product ?"}
-        ],
-        max_tokens=15,
-    )
+    if mode == "all_responses":
+        response = await router.abatch_completion(
+            models=["gpt-3.5-turbo", "groq-llama"],
+            messages=[
+                {"role": "user", "content": "is litellm becoming a better product ?"}
+            ],
+            max_tokens=15,
+        )
 
-    print(response)
-    assert len(response) == 2
+        print(response)
+        assert len(response) == 2
 
-    models_in_responses = []
-    for individual_response in response:
-        _model = individual_response["model"]
-        models_in_responses.append(_model)
+        models_in_responses = []
+        for individual_response in response:
+            _model = individual_response["model"]
+            models_in_responses.append(_model)
 
-    # assert both models are different
-    assert models_in_responses[0] != models_in_responses[1]
+        # assert both models are different
+        assert models_in_responses[0] != models_in_responses[1]
+    elif mode == "fastest_response":
+        from openai.types.chat.chat_completion import ChatCompletion
+
+        response = await router.abatch_completion_fastest_response(
+            models=["gpt-3.5-turbo", "groq-llama"],
+            messages=[
+                {"role": "user", "content": "is litellm becoming a better product ?"}
+            ],
+            max_tokens=15,
+        )
+
+        ChatCompletion.model_validate(response.model_dump(), strict=True)
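The pattern the new method is built around, race several calls, keep the first result, cancel the rest, can also be written directly with asyncio.wait and FIRST_COMPLETED. The sketch below is an independent illustration, not code from this commit, and unlike the Router method it does not skip failed calls: the first task to finish wins even if it raised. All names in it are made up.

import asyncio

async def fastest_response(coros):
    """Run coroutines concurrently, return the first finisher's result, cancel the rest."""
    tasks = [asyncio.create_task(c) for c in coros]
    done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
    for task in pending:
        task.cancel()
    # Surface the winner's result (or re-raise its exception).
    return done.pop().result()

async def demo():
    async def reply(name: str, delay: float) -> str:
        await asyncio.sleep(delay)
        return f"{name} answered"

    # "fast-model" finishes first, so the "slow-model" task gets cancelled.
    print(await fastest_response([reply("fast-model", 0.1), reply("slow-model", 1.0)]))

asyncio.run(demo())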