Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-26 11:14:04 +00:00
feat(router.py): support fastest response batch completion call

Sends one completion request to several models, returns the fastest response, and cancels the remaining in-flight calls.
This commit is contained in:
parent 7b565271e2
commit 3676c00235

2 changed files with 102 additions and 17 deletions
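For context, the new router method can be exercised roughly like this (a minimal sketch, not part of the commit; the two-deployment config, the groq model id, and the prompt are assumptions):

    import asyncio
    import litellm

    # Hypothetical deployment config; any two deployments from your own
    # model_list would work the same way.
    router = litellm.Router(
        model_list=[
            {"model_name": "gpt-3.5-turbo", "litellm_params": {"model": "gpt-3.5-turbo"}},
            {"model_name": "groq-llama", "litellm_params": {"model": "groq/llama3-8b-8192"}},
        ]
    )

    async def main():
        # Fan the same prompt out to both deployments; the first usable
        # response is returned and the slower call is cancelled.
        response = await router.abatch_completion_fastest_response(
            models=["gpt-3.5-turbo", "groq-llama"],
            messages=[{"role": "user", "content": "hello"}],
            max_tokens=15,
        )
        print(response)

    asyncio.run(main())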
router.py:

@@ -356,7 +356,8 @@ class Router:
                     raise ValueError(f"Item '{fallback_dict}' is not a dictionary.")
                 if len(fallback_dict) != 1:
                     raise ValueError(
-                        f"Dictionary '{fallback_dict}' must have exactly one key, but has {len(fallback_dict)} keys.")
+                        f"Dictionary '{fallback_dict}' must have exactly one key, but has {len(fallback_dict)} keys."
+                    )
 
     def routing_strategy_init(self, routing_strategy: str, routing_strategy_args: dict):
         if routing_strategy == "least-busy":
@@ -737,6 +738,76 @@ class Router:
             response = await asyncio.gather(*_tasks)
             return response
 
+    # fmt: off
+
+    @overload
+    async def abatch_completion_fastest_response(
+        self, models: List[str], messages: List[Dict[str, str]], stream: Literal[True], **kwargs
+    ) -> CustomStreamWrapper:
+        ...
+
+    @overload
+    async def abatch_completion_fastest_response(
+        self, models: List[str], messages: List[Dict[str, str]], stream: Literal[False] = False, **kwargs
+    ) -> ModelResponse:
+        ...
+
+    # fmt: on
+
+    async def abatch_completion_fastest_response(
+        self,
+        models: List[str],
+        messages: List[Dict[str, str]],
+        stream: bool = False,
+        **kwargs,
+    ):
+        """Send 1 completion call to many models: Return Fastest Response."""
+
+        async def _async_completion_no_exceptions(
+            model: str, messages: List[Dict[str, str]], **kwargs
+        ):
+            """
+            Wrapper around self.acompletion that catches exceptions and returns them as a result.
+            """
+            try:
+                return await self.acompletion(model=model, messages=messages, **kwargs)
+            except Exception as e:
+                return e
+
+        _tasks = []
+        pending_tasks = []  # type: ignore
+
+        async def check_response(task):
+            nonlocal pending_tasks
+            result = await task
+            if isinstance(result, (ModelResponse, CustomStreamWrapper)):
+                # If a desired response is received, cancel all other pending tasks
+                for t in pending_tasks:
+                    t.cancel()
+                return result
+            else:
+                try:
+                    pending_tasks.remove(task)
+                except Exception:
+                    pass
+
+        for model in models:
+            task = asyncio.create_task(
+                _async_completion_no_exceptions(
+                    model=model, messages=messages, **kwargs
+                )
+            )
+            task.add_done_callback(check_response)
+            _tasks.append(task)
+            pending_tasks.append(task)
+
+        responses = await asyncio.gather(*_tasks, return_exceptions=True)
+        if isinstance(responses[0], Exception):
+            raise responses[0]
+        return responses[0]  # return first value from list
+
     def image_generation(self, prompt: str, model: str, **kwargs):
         try:
             kwargs["model"] = model
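The core pattern in abatch_completion_fastest_response (race several coroutines, cancel the losers) can be seen in isolation below. This is a self-contained asyncio sketch, not litellm code; it uses asyncio.wait with FIRST_COMPLETED instead of the commit's done-callback bookkeeping, and call_model is a made-up stand-in:

    import asyncio
    import random

    async def call_model(name: str) -> str:
        # Stand-in for one provider call with variable latency.
        await asyncio.sleep(random.uniform(0.1, 1.0))
        return f"response from {name}"

    async def fastest_response(names: list) -> str:
        tasks = [asyncio.create_task(call_model(n)) for n in names]
        # Return as soon as the first task finishes; cancel the rest.
        done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
        for task in pending:
            task.cancel()
        return done.pop().result()

    print(asyncio.run(fastest_response(["gpt-3.5-turbo", "groq-llama"])))

One reason the commit does not simply take the first task to finish is visible in _async_completion_no_exceptions: a provider error is returned as a value rather than raised, so a fast failure does not win the race; the callback removes the failed task from pending_tasks and the remaining calls keep running.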
Second changed file (a pytest test module):

@@ -19,8 +19,9 @@ import os, httpx
 load_dotenv()
 
 
+@pytest.mark.parametrize("mode", ["all_responses", "fastest_response"])
 @pytest.mark.asyncio
-async def test_batch_completion_multiple_models():
+async def test_batch_completion_multiple_models(mode):
     litellm.set_verbose = True
 
     router = litellm.Router(
@@ -40,21 +41,34 @@ async def test_batch_completion_multiple_models():
         ]
     )
 
-    response = await router.abatch_completion(
-        models=["gpt-3.5-turbo", "groq-llama"],
-        messages=[
-            {"role": "user", "content": "is litellm becoming a better product ?"}
-        ],
-        max_tokens=15,
-    )
-
-    print(response)
-    assert len(response) == 2
-
-    models_in_responses = []
-    for individual_response in response:
-        _model = individual_response["model"]
-        models_in_responses.append(_model)
-
-    # assert both models are different
-    assert models_in_responses[0] != models_in_responses[1]
+    if mode == "all_responses":
+        response = await router.abatch_completion(
+            models=["gpt-3.5-turbo", "groq-llama"],
+            messages=[
+                {"role": "user", "content": "is litellm becoming a better product ?"}
+            ],
+            max_tokens=15,
+        )
+
+        print(response)
+        assert len(response) == 2
+
+        models_in_responses = []
+        for individual_response in response:
+            _model = individual_response["model"]
+            models_in_responses.append(_model)
+
+        # assert both models are different
+        assert models_in_responses[0] != models_in_responses[1]
+    elif mode == "fastest_response":
+        from openai.types.chat.chat_completion import ChatCompletion
+
+        response = await router.abatch_completion_fastest_response(
+            models=["gpt-3.5-turbo", "groq-llama"],
+            messages=[
+                {"role": "user", "content": "is litellm becoming a better product ?"}
+            ],
+            max_tokens=15,
+        )
+
+        ChatCompletion.model_validate(response.model_dump(), strict=True)
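The strict=True validation at the end is what asserts the fastest-response result is a complete OpenAI-shape chat completion. As a rough illustration of what it catches (a hypothetical payload, not from the test):

    from openai.types.chat.chat_completion import ChatCompletion

    # A half-formed response is missing required fields such as "choices",
    # so strict validation raises and the test would fail.
    bad_payload = {"id": "chatcmpl-123", "object": "chat.completion"}
    try:
        ChatCompletion.model_validate(bad_payload, strict=True)
    except Exception as e:
        print(type(e).__name__)  # ValidationError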