Merge pull request #3887 from BerriAI/litellm_batch_completions

feat(router.py): support fastest response batch completion call
Krish Dholakia 2024-05-28 22:38:12 -07:00 committed by GitHub
commit 0114207b2e
7 changed files with 316 additions and 64 deletions


@@ -61,6 +61,7 @@ jobs:
pip install prometheus-client==0.20.0
pip install "pydantic==2.7.1"
pip install "diskcache==5.6.1"
pip install "Pillow==10.3.0"
- save_cache:
    paths:
      - ./venv


@@ -1,3 +1,6 @@
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';

# Batching Completion()
LiteLLM allows you to:
* Send many completion calls to 1 model
@@ -51,6 +54,9 @@ This makes parallel calls to the specified `models` and returns the first response

Use this to reduce latency

<Tabs>
<TabItem value="sdk" label="SDK">

### Example Code
```python
import litellm
@@ -68,8 +74,93 @@ response = batch_completion_models(
print(result)
```
</TabItem>
<TabItem value="proxy" label="PROXY">
[how to setup proxy config](#example-setup)
Just pass a comma-separated string of model names and the flag `fastest_response=True`.
<Tabs>
<TabItem value="curl" label="curl">
```bash
curl -X POST 'http://localhost:4000/chat/completions' \
-H 'Content-Type: application/json' \
-H 'Authorization: Bearer sk-1234' \
-d '{
"model": "gpt-4o, groq-llama", # 👈 Comma-separated models
"messages": [
{
"role": "user",
"content": "What's the weather like in Boston today?"
}
],
"stream": true,
"fastest_response": true # 👈 FLAG
}
'
```
</TabItem>
<TabItem value="openai" label="OpenAI SDK">
```python
import openai
client = openai.OpenAI(
api_key="anything",
base_url="http://0.0.0.0:4000"
)
# request sent to model set on litellm proxy, `litellm --model`
response = client.chat.completions.create(
model="gpt-4o, groq-llama", # 👈 Comma-separated models
messages = [
{
"role": "user",
"content": "this is a test request, write a short poem"
}
],
    extra_body={"fastest_response": True} # 👈 FLAG
)
print(response)
```
</TabItem>
</Tabs>
---
### Example Setup:
```yaml
model_list:
- model_name: groq-llama
litellm_params:
model: groq/llama3-8b-8192
api_key: os.environ/GROQ_API_KEY
- model_name: gpt-4o
litellm_params:
model: gpt-4o
api_key: os.environ/OPENAI_API_KEY
```
```bash
litellm --config /path/to/config.yaml
# RUNNING on http://0.0.0.0:4000
```
</TabItem>
</Tabs>
### Output
Returns the first response in OpenAI format. Cancels other LLM API calls.
```json
{
  "object": "chat.completion",
@@ -95,6 +186,7 @@ Returns the first response
}
```

## Send 1 completion call to many models: Return All Responses
This makes parallel calls to the specified models and returns all responses This makes parallel calls to the specified models and returns all responses
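For comparison with the fastest-response flow above, here is a minimal SDK sketch of the all-responses call (model names are illustrative and the relevant provider API keys are assumed to be set):

```python
import litellm

# Illustrative model names; waits for every model and returns one response per model.
responses = litellm.batch_completion_models_all_responses(
    models=["gpt-3.5-turbo", "groq/llama3-8b-8192"],
    messages=[{"role": "user", "content": "is litellm becoming a better product ?"}],
    max_tokens=15,
)
for response in responses:
    print(response)
```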


@@ -14,7 +14,6 @@ from functools import partial
import dotenv, traceback, random, asyncio, time, contextvars
from copy import deepcopy
import httpx
import litellm
from ._logging import verbose_logger
from litellm import (  # type: ignore
@@ -680,6 +679,7 @@ def completion(
        "region_name",
        "allowed_model_region",
        "model_config",
        "fastest_response",
    ]
    default_params = openai_params + litellm_params
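Listing `fastest_response` among the litellm-specific params keeps the flag out of the kwargs that get forwarded to the underlying provider call. A rough illustrative sketch of that filtering idea (not LiteLLM's actual implementation; the names below are hypothetical):

```python
# Hypothetical, abbreviated subset of litellm-only params for illustration.
LITELLM_ONLY_PARAMS = {"metadata", "model_config", "fastest_response"}

def split_kwargs(**kwargs):
    # Separate litellm-internal flags from kwargs destined for the provider API.
    internal = {k: v for k, v in kwargs.items() if k in LITELLM_ONLY_PARAMS}
    provider = {k: v for k, v in kwargs.items() if k not in LITELLM_ONLY_PARAMS}
    return internal, provider

internal, provider = split_kwargs(fastest_response=True, temperature=0.2)
assert "fastest_response" in internal and "fastest_response" not in provider
```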


@@ -36,7 +36,7 @@ model_list:
      api_base: https://my-endpoint-europe-berri-992.openai.azure.com/
      api_key: os.environ/AZURE_EUROPE_API_KEY
      model: azure/gpt-35-turbo
    model_name: gpt-3.5-turbo-fake-model
  - litellm_params:
      api_base: https://openai-gpt-4-test-v-1.openai.azure.com/
      api_key: os.environ/AZURE_API_KEY


@@ -423,6 +423,7 @@ def get_custom_headers(
    api_base: Optional[str] = None,
    version: Optional[str] = None,
    model_region: Optional[str] = None,
    fastest_response_batch_completion: Optional[bool] = None,
) -> dict:
    exclude_values = {"", None}
    headers = {
@@ -433,6 +434,11 @@ def get_custom_headers(
        "x-litellm-model-region": model_region,
        "x-litellm-key-tpm-limit": str(user_api_key_dict.tpm_limit),
        "x-litellm-key-rpm-limit": str(user_api_key_dict.rpm_limit),
        "x-litellm-fastest_response_batch_completion": (
            str(fastest_response_batch_completion)
            if fastest_response_batch_completion is not None
            else None
        ),
    }
    try:
        return {
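A minimal client-side sketch (hypothetical code, reusing the `sk-1234` key and local proxy from the docs example above) showing how a caller could read the new `x-litellm-fastest_response_batch_completion` response header:

```python
import httpx

# Hypothetical client call against the local proxy from the docs example above.
resp = httpx.post(
    "http://localhost:4000/chat/completions",
    headers={"Authorization": "Bearer sk-1234"},
    json={
        "model": "gpt-4o, groq-llama",
        "messages": [{"role": "user", "content": "hi"}],
        "fastest_response": True,
    },
    timeout=60,
)
# Header is only set when the fastest-response batch path was taken (value is "True").
print(resp.headers.get("x-litellm-fastest_response_batch_completion"))
```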
@@ -4043,9 +4049,15 @@ async def chat_completion(
        if "api_key" in data:
            tasks.append(litellm.acompletion(**data))
        elif "," in data["model"] and llm_router is not None:
            if (
                data.get("fastest_response", None) is not None
                and data["fastest_response"] == True
            ):
                tasks.append(llm_router.abatch_completion_fastest_response(**data))
            else:
                _models_csv_string = data.pop("model")
                _models = [model.strip() for model in _models_csv_string.split(",")]
                tasks.append(llm_router.abatch_completion(models=_models, **data))
        elif "user_config" in data:
            # initialize a new router instance. make request using this Router
            router_config = data.pop("user_config")
@@ -4095,6 +4107,9 @@ async def chat_completion(
        model_id = hidden_params.get("model_id", None) or ""
        cache_key = hidden_params.get("cache_key", None) or ""
        api_base = hidden_params.get("api_base", None) or ""
        fastest_response_batch_completion = hidden_params.get(
            "fastest_response_batch_completion", None
        )

        # Post Call Processing
        if llm_router is not None:
@@ -4111,6 +4126,7 @@ async def chat_completion(
                api_base=api_base,
                version=version,
                model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
                fastest_response_batch_completion=fastest_response_batch_completion,
            )
            selected_data_generator = select_data_generator(
                response=response,
@@ -4131,6 +4147,7 @@ async def chat_completion(
            api_base=api_base,
            version=version,
            model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
            fastest_response_batch_completion=fastest_response_batch_completion,
        )
    )
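For reference, a hedged client-side sketch of the streaming case this branch handles, using the OpenAI SDK against the example proxy setup (key and base URL are illustrative):

```python
import openai

# Illustrative values: key and base URL match the example proxy setup in the docs above.
client = openai.OpenAI(api_key="sk-1234", base_url="http://localhost:4000")

stream = client.chat.completions.create(
    model="gpt-4o, groq-llama",  # comma-separated models
    messages=[{"role": "user", "content": "write a short poem"}],
    stream=True,
    extra_body={"fastest_response": True},  # routed to the fastest-response batch path
)
for chunk in stream:
    if chunk.choices:
        print(chunk.choices[0].delta.content or "", end="")
```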


@@ -799,6 +799,101 @@ class Router:
        response = await asyncio.gather(*_tasks)
        return response

    # fmt: off

    @overload
    async def abatch_completion_fastest_response(
        self, model: str, messages: List[Dict[str, str]], stream: Literal[True], **kwargs
    ) -> CustomStreamWrapper:
        ...

    @overload
    async def abatch_completion_fastest_response(
        self, model: str, messages: List[Dict[str, str]], stream: Literal[False] = False, **kwargs
    ) -> ModelResponse:
        ...

    # fmt: on

    async def abatch_completion_fastest_response(
        self,
        model: str,
        messages: List[Dict[str, str]],
        stream: bool = False,
        **kwargs,
    ):
        """
        model - List of comma-separated model names. E.g. model="gpt-4, gpt-3.5-turbo"

        Returns fastest response from list of model names. OpenAI-compatible endpoint.
        """
        models = [m.strip() for m in model.split(",")]

        async def _async_completion_no_exceptions(
            model: str, messages: List[Dict[str, str]], stream: bool, **kwargs: Any
        ) -> Union[ModelResponse, CustomStreamWrapper, Exception]:
            """
            Wrapper around self.acompletion that catches exceptions and returns them as a result
            """
            try:
                return await self.acompletion(model=model, messages=messages, stream=stream, **kwargs)  # type: ignore
            except asyncio.CancelledError:
                verbose_router_logger.debug(
                    "Received 'task.cancel'. Cancelling call w/ model={}.".format(model)
                )
                raise
            except Exception as e:
                return e

        pending_tasks = []  # type: ignore

        async def check_response(task: asyncio.Task):
            nonlocal pending_tasks
            try:
                result = await task
                if isinstance(result, (ModelResponse, CustomStreamWrapper)):
                    verbose_router_logger.debug(
                        "Received successful response. Cancelling other LLM API calls."
                    )
                    # If a desired response is received, cancel all other pending tasks
                    for t in pending_tasks:
                        t.cancel()
                    return result
            except Exception:
                # Ignore exceptions, let the loop handle them
                pass
            finally:
                # Remove the task from pending tasks if it finishes
                try:
                    pending_tasks.remove(task)
                except KeyError:
                    pass

        for model in models:
            task = asyncio.create_task(
                _async_completion_no_exceptions(
                    model=model, messages=messages, stream=stream, **kwargs
                )
            )
            pending_tasks.append(task)

        # Await the first task to complete successfully
        while pending_tasks:
            done, pending_tasks = await asyncio.wait(  # type: ignore
                pending_tasks, return_when=asyncio.FIRST_COMPLETED
            )
            for completed_task in done:
                result = await check_response(completed_task)
                if result is not None:
                    # Return the first successful result
                    result._hidden_params["fastest_response_batch_completion"] = True
                    return result
        # If we exit the loop without returning, all tasks failed
        raise Exception("All tasks failed")

    def image_generation(self, prompt: str, model: str, **kwargs):
        try:
            kwargs["model"] = model
@@ -3608,7 +3703,6 @@ class Router:
        ## get healthy deployments
        ### get all deployments
        healthy_deployments = [m for m in self.model_list if m["model_name"] == model]
        if len(healthy_deployments) == 0:
            # check if the user sent in a deployment name instead
            healthy_deployments = [
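For orientation, a minimal usage sketch of the new `abatch_completion_fastest_response` method, mirroring the tests below (model names are examples; assumes `OPENAI_API_KEY` and `GROQ_API_KEY` are set):

```python
import asyncio
import litellm

# Deployments mirror the tests below; adjust model names and keys for your setup.
router = litellm.Router(
    model_list=[
        {"model_name": "gpt-3.5-turbo", "litellm_params": {"model": "gpt-3.5-turbo"}},
        {"model_name": "groq-llama", "litellm_params": {"model": "groq/llama3-8b-8192"}},
    ]
)

async def main():
    # Both deployments are called in parallel; the first successful response wins
    # and the remaining task is cancelled.
    response = await router.abatch_completion_fastest_response(
        model="gpt-3.5-turbo, groq-llama",
        messages=[{"role": "user", "content": "is litellm becoming a better product ?"}],
        max_tokens=15,
    )
    print(response)

asyncio.run(main())
```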


@@ -19,8 +19,9 @@ import os, httpx

load_dotenv()


@pytest.mark.parametrize("mode", ["all_responses", "fastest_response"])
@pytest.mark.asyncio
async def test_batch_completion_multiple_models(mode):
    litellm.set_verbose = True

    router = litellm.Router(
@@ -40,65 +41,112 @@ async def test_batch_completion_multiple_models():
        ]
    )

    if mode == "all_responses":
        response = await router.abatch_completion(
            models=["gpt-3.5-turbo", "groq-llama"],
            messages=[
                {"role": "user", "content": "is litellm becoming a better product ?"}
            ],
            max_tokens=15,
        )

        print(response)
        assert len(response) == 2

        models_in_responses = []
        for individual_response in response:
            _model = individual_response["model"]
            models_in_responses.append(_model)

        # assert both models are different
        assert models_in_responses[0] != models_in_responses[1]
    elif mode == "fastest_response":
        from openai.types.chat.chat_completion import ChatCompletion

        response = await router.abatch_completion_fastest_response(
            model="gpt-3.5-turbo, groq-llama",
            messages=[
                {"role": "user", "content": "is litellm becoming a better product ?"}
            ],
            max_tokens=15,
        )

        ChatCompletion.model_validate(response.model_dump(), strict=True)


@pytest.mark.asyncio
async def test_batch_completion_fastest_response_unit_test():
    """
    Unit test to confirm fastest response will always return the response which arrives earliest.

    2 models -> 1 is cached, the other is a real llm api call => assert cached response always returned
    """
    litellm.set_verbose = True

    router = litellm.Router(
        model_list=[
            {
                "model_name": "gpt-4",
                "litellm_params": {
                    "model": "gpt-4",
                },
                "model_info": {"id": "1"},
            },
            {
                "model_name": "gpt-3.5-turbo",
                "litellm_params": {
                    "model": "gpt-3.5-turbo",
                    "mock_response": "This is a fake response",
                },
                "model_info": {"id": "2"},
            },
        ]
    )

    response = await router.abatch_completion_fastest_response(
        model="gpt-4, gpt-3.5-turbo",
        messages=[
            {"role": "user", "content": "is litellm becoming a better product ?"}
        ],
        max_tokens=500,
    )

    assert response._hidden_params["model_id"] == "2"
    assert response.choices[0].message.content == "This is a fake response"

    print(f"response: {response}")


@pytest.mark.asyncio
async def test_batch_completion_fastest_response_streaming():
    litellm.set_verbose = True

    router = litellm.Router(
        model_list=[
            {
                "model_name": "gpt-3.5-turbo",
                "litellm_params": {
                    "model": "gpt-3.5-turbo",
                },
            },
            {
                "model_name": "groq-llama",
                "litellm_params": {
                    "model": "groq/llama3-8b-8192",
                },
            },
        ]
    )

    from openai.types.chat.chat_completion_chunk import ChatCompletionChunk

    response = await router.abatch_completion_fastest_response(
        model="gpt-3.5-turbo, groq-llama",
        messages=[
            {"role": "user", "content": "is litellm becoming a better product ?"}
        ],
        max_tokens=15,
        stream=True,
    )

    async for chunk in response:
        ChatCompletionChunk.model_validate(chunk.model_dump(), strict=True)


@pytest.mark.asyncio
async def test_batch_completion_multiple_models_multiple_messages():
    litellm.set_verbose = True

    router = litellm.Router(
        model_list=[
            {
                "model_name": "gpt-3.5-turbo",
                "litellm_params": {
                    "model": "gpt-3.5-turbo",
                },
            },
            {
                "model_name": "groq-llama",
                "litellm_params": {
                    "model": "groq/llama3-8b-8192",
                },
            },
        ]
    )

    response = await router.abatch_completion(
        models=["gpt-3.5-turbo", "groq-llama"],
        messages=[
            [{"role": "user", "content": "is litellm becoming a better product ?"}],
            [{"role": "user", "content": "who is this"}],
        ],
        max_tokens=15,
    )

    print("response from batches =", response)
    assert len(response) == 2
    assert len(response[0]) == 2
    assert isinstance(response[0][0], litellm.ModelResponse)

    # models_in_responses = []
    # for individual_response in response:
    #     _model = individual_response["model"]
    #     models_in_responses.append(_model)

    # # assert both models are different
    # assert models_in_responses[0] != models_in_responses[1]