Merge pull request #3585 from BerriAI/litellm_router_batch_comp
[Litellm Proxy + litellm.Router] - Pass the same message/prompt to N models
commit bf909a89f8
7 changed files with 213 additions and 2 deletions
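In short, `litellm.Router` gains an `abatch_completion` method that fans one prompt out to several models, and the proxy accepts `model` as a list. A minimal usage sketch pieced together from the router and test changes below (model names and the prompt are illustrative; provider API keys are assumed to be set in the environment):

```python
import asyncio
import litellm

# Two model groups on the router; the same prompt is sent to both.
router = litellm.Router(
    model_list=[
        {"model_name": "gpt-3.5-turbo", "litellm_params": {"model": "gpt-3.5-turbo"}},
        {"model_name": "groq-llama", "litellm_params": {"model": "groq/llama3-8b-8192"}},
    ]
)


async def main():
    # One call, N models: returns a list with one entry per model
    # (a response object, or the exception raised for that model).
    responses = await router.abatch_completion(
        models=["gpt-3.5-turbo", "groq-llama"],
        messages=[{"role": "user", "content": "is litellm becoming a better product ?"}],
        max_tokens=15,
    )
    for r in responses:
        print(r)


asyncio.run(main())
```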
@ -4,6 +4,12 @@ LiteLLM allows you to:

* Send 1 completion call to many models: Return Fastest Response
* Send 1 completion call to many models: Return All Responses

:::info

Trying to do batch completion on LiteLLM Proxy? Go here: https://docs.litellm.ai/docs/proxy/user_keys#beta-batch-completions---pass-model-as-list

:::

## Send multiple completion calls to 1 model

In the batch_completion method, you provide a list of `messages` where each sub-list of messages is passed to `litellm.completion()`, allowing you to process multiple prompts efficiently in a single API call.
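A minimal sketch of that call shape, using the `litellm.batch_completion` entry point described above (the prompts and `max_tokens` value are illustrative):

```python
import litellm

# One model, many prompts: each inner list is its own conversation and
# gets its own completion call; responses come back as a list.
responses = litellm.batch_completion(
    model="gpt-3.5-turbo",
    messages=[
        [{"role": "user", "content": "good morning?"}],
        [{"role": "user", "content": "what's the capital of France?"}],
    ],
    max_tokens=10,
)
for r in responses:
    print(r.choices[0].message.content)
```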
@ -365,6 +365,90 @@ curl --location 'http://0.0.0.0:4000/moderations' \

## Advanced

### (BETA) Batch Completions - pass `model` as List

Use this when you want to send 1 request to N models.

#### Expected Request Format

This same request will be sent to the following model groups on the [litellm proxy config.yaml](https://docs.litellm.ai/docs/proxy/configs):
- `model_name="llama3"`
- `model_name="gpt-3.5-turbo"`

```shell
curl --location 'http://localhost:4000/chat/completions' \
    --header 'Authorization: Bearer sk-1234' \
    --header 'Content-Type: application/json' \
    --data '{
        "model": ["llama3", "gpt-3.5-turbo"],
        "max_tokens": 10,
        "user": "litellm2",
        "messages": [
            {
                "role": "user",
                "content": "is litellm getting better"
            }
        ]
    }'
```
#### Expected Response Format

Get a list of responses when `model` is passed as a list

```json
[
  {
    "id": "chatcmpl-3dbd5dd8-7c82-4ca3-bf1f-7c26f497cf2b",
    "choices": [
      {
        "finish_reason": "length",
        "index": 0,
        "message": {
          "content": "The Elder Scrolls IV: Oblivion!\n\nReleased",
          "role": "assistant"
        }
      }
    ],
    "created": 1715459876,
    "model": "groq/llama3-8b-8192",
    "object": "chat.completion",
    "system_fingerprint": "fp_179b0f92c9",
    "usage": {
      "completion_tokens": 10,
      "prompt_tokens": 12,
      "total_tokens": 22
    }
  },
  {
    "id": "chatcmpl-9NnldUfFLmVquFHSX4yAtjCw8PGei",
    "choices": [
      {
        "finish_reason": "length",
        "index": 0,
        "message": {
          "content": "TES4 could refer to The Elder Scrolls IV:",
          "role": "assistant"
        }
      }
    ],
    "created": 1715459877,
    "model": "gpt-3.5-turbo-0125",
    "object": "chat.completion",
    "system_fingerprint": null,
    "usage": {
      "completion_tokens": 10,
      "prompt_tokens": 9,
      "total_tokens": 19
    }
  }
]
```
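The same request can also be issued from Python; a minimal sketch using `requests` against the proxy shown above (the URL and `sk-1234` key mirror the curl example and are illustrative):

```python
import requests

# `model` is a list, so the proxy fans the request out to both model groups
# and the response body is a JSON list with one completion per model.
resp = requests.post(
    "http://localhost:4000/chat/completions",
    headers={"Authorization": "Bearer sk-1234"},
    json={
        "model": ["llama3", "gpt-3.5-turbo"],
        "max_tokens": 10,
        "messages": [{"role": "user", "content": "is litellm getting better"}],
    },
)
for completion in resp.json():
    print(completion["model"], completion["choices"][0]["message"]["content"])
```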
### Pass User LLM API Keys, Fallbacks

Allow your end-users to pass their own model list, API base, and OpenAI API key (any LiteLLM-supported provider) to make requests.
@ -4,6 +4,12 @@ model_list:
      model: openai/fake
      api_key: fake-key
      api_base: https://exampleopenaiendpoint-production.up.railway.app/
  - model_name: llama3
    litellm_params:
      model: groq/llama3-8b-8192
  - model_name: gpt-3.5-turbo
    litellm_params:
      model: gpt-3.5-turbo
  - model_name: "*"
    litellm_params:
      model: openai/*
@ -3656,7 +3656,7 @@ async def chat_completion(
        ### MODEL ALIAS MAPPING ###
        # check if model name in model alias map
        # get the actual model name
        if data["model"] in litellm.model_alias_map:
        if isinstance(data["model"], str) and data["model"] in litellm.model_alias_map:
            data["model"] = litellm.model_alias_map[data["model"]]

        ## LOGGING OBJECT ## - initialize logging object for logging success/failure events for call

@ -3690,6 +3690,9 @@ async def chat_completion(
        # skip router if user passed their key
        if "api_key" in data:
            tasks.append(litellm.acompletion(**data))
        elif isinstance(data["model"], list) and llm_router is not None:
            _models = data.pop("model")
            tasks.append(llm_router.abatch_completion(models=_models, **data))
        elif "user_config" in data:
            # initialize a new router instance. make request using this Router
            router_config = data.pop("user_config")
@ -606,6 +606,33 @@ class Router:
            self.fail_calls[model_name] += 1
            raise e

    async def abatch_completion(
        self, models: List[str], messages: List[Dict[str, str]], **kwargs
    ):

        async def _async_completion_no_exceptions(
            model: str, messages: List[Dict[str, str]], **kwargs
        ):
            """
            Wrapper around self.acompletion that catches exceptions and returns them as a result
            """
            try:
                return await self.acompletion(model=model, messages=messages, **kwargs)
            except Exception as e:
                return e

        _tasks = []
        for model in models:
            # add each task; if a task fails, the exception is returned in the
            # gathered results instead of being raised
            _tasks.append(
                _async_completion_no_exceptions(
                    model=model, messages=messages, **kwargs
                )
            )

        response = await asyncio.gather(*_tasks)
        return response

    def image_generation(self, prompt: str, model: str, **kwargs):
        try:
            kwargs["model"] = model
litellm/tests/test_router_batch_completion.py (new file, 60 lines)
@ -0,0 +1,60 @@
#### What this tests ####
# This tests litellm router with batch completion

import sys, os, time, openai
import traceback, asyncio
import pytest

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import litellm
from litellm import Router
from litellm.router import Deployment, LiteLLM_Params, ModelInfo
from concurrent.futures import ThreadPoolExecutor
from collections import defaultdict
from dotenv import load_dotenv
import os, httpx

load_dotenv()


@pytest.mark.asyncio
async def test_batch_completion_multiple_models():
    litellm.set_verbose = True

    router = litellm.Router(
        model_list=[
            {
                "model_name": "gpt-3.5-turbo",
                "litellm_params": {
                    "model": "gpt-3.5-turbo",
                },
            },
            {
                "model_name": "groq-llama",
                "litellm_params": {
                    "model": "groq/llama3-8b-8192",
                },
            },
        ]
    )

    response = await router.abatch_completion(
        models=["gpt-3.5-turbo", "groq-llama"],
        messages=[
            {"role": "user", "content": "is litellm becoming a better product ?"}
        ],
        max_tokens=15,
    )

    print(response)
    assert len(response) == 2

    models_in_responses = []
    for individual_response in response:
        _model = individual_response["model"]
        models_in_responses.append(_model)

    # assert both models are different
    assert models_in_responses[0] != models_in_responses[1]
@ -4,6 +4,7 @@ import pytest
import asyncio
import aiohttp, openai
from openai import OpenAI, AsyncOpenAI
from typing import Optional, List, Union


def response_header_check(response):

@ -71,7 +72,7 @@ async def new_user(session):
    return await response.json()


async def chat_completion(session, key, model="gpt-4"):
async def chat_completion(session, key, model: Union[str, List] = "gpt-4"):
    url = "http://0.0.0.0:4000/chat/completions"
    headers = {
        "Authorization": f"Bearer {key}",

@ -409,3 +410,27 @@ async def test_openai_wildcard_chat_completion():

        # call chat/completions with a model that the key was not created for + the model is not on the config.yaml
        await chat_completion(session=session, key=key, model="gpt-3.5-turbo-0125")


@pytest.mark.asyncio
async def test_batch_chat_completions():
    """
    - Make chat completion call with `model` passed as a list of models
    """
    async with aiohttp.ClientSession() as session:

        # call chat/completions with `model` passed as a list
        response = await chat_completion(
            session=session,
            key="sk-1234",
            model=[
                "gpt-3.5-turbo",
                "fake-openai-endpoint",
            ],
        )

        print(f"response: {response}")

        assert len(response) == 2
        assert isinstance(response, list)