Merge pull request #3585 from BerriAI/litellm_router_batch_comp

[Litellm Proxy + litellm.Router] - Pass the same message/prompt to N models
Ishaan Jaff 2024-05-11 13:51:45 -07:00 committed by GitHub
commit bf909a89f8
7 changed files with 213 additions and 2 deletions


@ -4,6 +4,12 @@ LiteLLM allows you to:
* Send 1 completion call to many models: Return Fastest Response
* Send 1 completion call to many models: Return All Responses
:::info
Trying to do batch completion on LiteLLM Proxy? Go here: https://docs.litellm.ai/docs/proxy/user_keys#beta-batch-completions---pass-model-as-list
:::
## Send multiple completion calls to 1 model
In the batch_completion method, you provide a list of `messages` where each sub-list of messages is passed to `litellm.completion()`, allowing you to process multiple prompts efficiently in a single API call.
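A minimal sketch of what this looks like; the model name and the two prompts below are placeholders:

```python
import litellm

# each sub-list of messages becomes its own call to litellm.completion()
responses = litellm.batch_completion(
    model="gpt-3.5-turbo",  # placeholder model
    messages=[
        [{"role": "user", "content": "good morning?"}],
        [{"role": "user", "content": "what is the weather in SF?"}],
    ],
    max_tokens=10,
)
print(responses)  # one response object per prompt, in the same order
```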


@ -365,6 +365,90 @@ curl --location 'http://0.0.0.0:4000/moderations' \
## Advanced
### (BETA) Batch Completions - pass `model` as List
Use this when you want to send the same request to N models.
#### Expected Request Format
This request will be sent to the following model groups defined on the [litellm proxy config.yaml](https://docs.litellm.ai/docs/proxy/configs):
- `model_name="llama3"`
- `model_name="gpt-3.5-turbo"`
```shell
curl --location 'http://localhost:4000/chat/completions' \
--header 'Authorization: Bearer sk-1234' \
--header 'Content-Type: application/json' \
--data '{
"model": ["llama3", "gpt-3.5-turbo"],
"max_tokens": 10,
"user": "litellm2",
"messages": [
{
"role": "user",
"content": "is litellm getting better"
}
]
}'
```
#### Expected Response Format
Get a list of responses when `model` is passed as a list
```json
[
  {
    "id": "chatcmpl-3dbd5dd8-7c82-4ca3-bf1f-7c26f497cf2b",
    "choices": [
      {
        "finish_reason": "length",
        "index": 0,
        "message": {
          "content": "The Elder Scrolls IV: Oblivion!\n\nReleased",
          "role": "assistant"
        }
      }
    ],
    "created": 1715459876,
    "model": "groq/llama3-8b-8192",
    "object": "chat.completion",
    "system_fingerprint": "fp_179b0f92c9",
    "usage": {
      "completion_tokens": 10,
      "prompt_tokens": 12,
      "total_tokens": 22
    }
  },
  {
    "id": "chatcmpl-9NnldUfFLmVquFHSX4yAtjCw8PGei",
    "choices": [
      {
        "finish_reason": "length",
        "index": 0,
        "message": {
          "content": "TES4 could refer to The Elder Scrolls IV:",
          "role": "assistant"
        }
      }
    ],
    "created": 1715459877,
    "model": "gpt-3.5-turbo-0125",
    "object": "chat.completion",
    "system_fingerprint": null,
    "usage": {
      "completion_tokens": 10,
      "prompt_tokens": 9,
      "total_tokens": 19
    }
  }
]
```
### Pass User LLM API Keys, Fallbacks
Allow your end-users to pass their own model list, API base, and OpenAI API key (for any LiteLLM-supported provider) when making requests.
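A minimal sketch of what such a request could look like, assuming (as the proxy change further down suggests) that `api_key` is read directly from the request body; the proxy URL, proxy key, and end-user key below are placeholders:

```python
import openai

# placeholder proxy URL and proxy key
client = openai.OpenAI(api_key="sk-1234", base_url="http://localhost:4000")

response = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "what llm are you"}],
    # assumption: the proxy picks up "api_key" from the request body and uses it for the upstream call
    extra_body={"api_key": "sk-my-end-users-openai-key"},
)
print(response)
```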


@ -4,6 +4,12 @@ model_list:
      model: openai/fake
      api_key: fake-key
      api_base: https://exampleopenaiendpoint-production.up.railway.app/
  - model_name: llama3
    litellm_params:
      model: groq/llama3-8b-8192
  - model_name: gpt-3.5-turbo
    litellm_params:
      model: gpt-3.5-turbo
  - model_name: "*"
    litellm_params:
      model: openai/*


@ -3656,7 +3656,7 @@ async def chat_completion(
### MODEL ALIAS MAPPING ###
# check if model name in model alias map
# get the actual model name
if data["model"] in litellm.model_alias_map:
if isinstance(data["model"], str) and data["model"] in litellm.model_alias_map:
data["model"] = litellm.model_alias_map[data["model"]]
## LOGGING OBJECT ## - initialize logging object for logging success/failure events for call
@ -3690,6 +3690,9 @@ async def chat_completion(
        # skip router if user passed their key
        if "api_key" in data:
            tasks.append(litellm.acompletion(**data))
        elif isinstance(data["model"], list) and llm_router is not None:
            _models = data.pop("model")
            tasks.append(llm_router.abatch_completion(models=_models, **data))
        elif "user_config" in data:
            # initialize a new router instance. make request using this Router
            router_config = data.pop("user_config")


@ -606,6 +606,33 @@ class Router:
            self.fail_calls[model_name] += 1
            raise e

    async def abatch_completion(
        self, models: List[str], messages: List[Dict[str, str]], **kwargs
    ):

        async def _async_completion_no_exceptions(
            model: str, messages: List[Dict[str, str]], **kwargs
        ):
            """
            Wrapper around self.acompletion that catches exceptions and returns them as a result
            """
            try:
                return await self.acompletion(model=model, messages=messages, **kwargs)
            except Exception as e:
                return e

        _tasks = []
        for model in models:
            # add each task; if the call for a model fails, the exception is returned as its result
            _tasks.append(
                _async_completion_no_exceptions(
                    model=model, messages=messages, **kwargs
                )
            )

        response = await asyncio.gather(*_tasks)
        return response

    def image_generation(self, prompt: str, model: str, **kwargs):
        try:
            kwargs["model"] = model

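Because `_async_completion_no_exceptions` returns exceptions instead of raising them, a caller can split the gathered results into successes and failures. A minimal sketch, with a router configuration mirroring the test below and placeholder model names:

```python
import asyncio
import litellm

router = litellm.Router(
    model_list=[
        {"model_name": "gpt-3.5-turbo", "litellm_params": {"model": "gpt-3.5-turbo"}},
        {"model_name": "groq-llama", "litellm_params": {"model": "groq/llama3-8b-8192"}},
    ]
)

async def main():
    results = await router.abatch_completion(
        models=["gpt-3.5-turbo", "groq-llama"],
        messages=[{"role": "user", "content": "is litellm getting better"}],
        max_tokens=15,
    )
    # each entry is either a response or the exception raised for that model
    successes = [r for r in results if not isinstance(r, Exception)]
    failures = [r for r in results if isinstance(r, Exception)]
    print(f"{len(successes)} succeeded, {len(failures)} failed")

asyncio.run(main())
```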

@ -0,0 +1,60 @@
#### What this tests ####
# This tests litellm router with batch completion

import sys, os, time, openai
import traceback, asyncio
import pytest

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import litellm
from litellm import Router
from litellm.router import Deployment, LiteLLM_Params, ModelInfo
from concurrent.futures import ThreadPoolExecutor
from collections import defaultdict
from dotenv import load_dotenv
import os, httpx

load_dotenv()


@pytest.mark.asyncio
async def test_batch_completion_multiple_models():
    litellm.set_verbose = True

    router = litellm.Router(
        model_list=[
            {
                "model_name": "gpt-3.5-turbo",
                "litellm_params": {
                    "model": "gpt-3.5-turbo",
                },
            },
            {
                "model_name": "groq-llama",
                "litellm_params": {
                    "model": "groq/llama3-8b-8192",
                },
            },
        ]
    )

    response = await router.abatch_completion(
        models=["gpt-3.5-turbo", "groq-llama"],
        messages=[
            {"role": "user", "content": "is litellm becoming a better product ?"}
        ],
        max_tokens=15,
    )

    print(response)
    assert len(response) == 2

    models_in_responses = []
    for individual_response in response:
        _model = individual_response["model"]
        models_in_responses.append(_model)

    # assert both models are different
    assert models_in_responses[0] != models_in_responses[1]


@ -4,6 +4,7 @@ import pytest
import asyncio
import aiohttp, openai
from openai import OpenAI, AsyncOpenAI
from typing import Optional, List, Union
def response_header_check(response):
@ -71,7 +72,7 @@ async def new_user(session):
return await response.json()
async def chat_completion(session, key, model="gpt-4"):
async def chat_completion(session, key, model: Union[str, List] = "gpt-4"):
url = "http://0.0.0.0:4000/chat/completions"
headers = {
"Authorization": f"Bearer {key}",
@ -409,3 +410,27 @@ async def test_openai_wildcard_chat_completion():
        # call chat/completions with a model that the key was not created for + the model is not on the config.yaml
        await chat_completion(session=session, key=key, model="gpt-3.5-turbo-0125")


@pytest.mark.asyncio
async def test_batch_chat_completions():
    """
    - Make a chat completion call passing `model` as a list of models
    """
    async with aiohttp.ClientSession() as session:

        # call chat/completions with a list of models
        response = await chat_completion(
            session=session,
            key="sk-1234",
            model=[
                "gpt-3.5-turbo",
                "fake-openai-endpoint",
            ],
        )

        print(f"response: {response}")

        assert len(response) == 2
        assert isinstance(response, list)