Merge pull request #3585 from BerriAI/litellm_router_batch_comp
[Litellm Proxy + litellm.Router] - Pass the same message/prompt to N models
commit bf909a89f8
7 changed files with 213 additions and 2 deletions
@ -4,6 +4,12 @@ LiteLLM allows you to:

* Send 1 completion call to many models: Return Fastest Response
* Send 1 completion call to many models: Return All Responses

:::info

Trying to do batch completion on LiteLLM Proxy? Go here: https://docs.litellm.ai/docs/proxy/user_keys#beta-batch-completions---pass-model-as-list

:::

## Send multiple completion calls to 1 model

In the batch_completion method, you provide a list of `messages` where each sub-list of messages is passed to `litellm.completion()`, allowing you to process multiple prompts efficiently in a single API call.
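For reference, a minimal sketch of such a call with the Python SDK; the keyword arguments follow the pattern used in the rest of these docs, and the model name and prompts are placeholders:

```python
import litellm

# Each inner list is one independent conversation; batch_completion sends every
# conversation to the same model and returns one response per conversation.
responses = litellm.batch_completion(
    model="gpt-3.5-turbo",  # placeholder model name
    messages=[
        [{"role": "user", "content": "what is litellm?"}],
        [{"role": "user", "content": "write a one-line haiku about APIs"}],
    ],
    max_tokens=20,
)

for response in responses:
    print(response.choices[0].message.content)
```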
@ -365,6 +365,90 @@ curl --location 'http://0.0.0.0:4000/moderations' \

## Advanced

### (BETA) Batch Completions - pass `model` as List

Use this when you want to send 1 request to N Models

#### Expected Request Format

This same request will be sent to the following model groups on the [litellm proxy config.yaml](https://docs.litellm.ai/docs/proxy/configs)
- `model_name="llama3"`
- `model_name="gpt-3.5-turbo"`

```shell
curl --location 'http://localhost:4000/chat/completions' \
--header 'Authorization: Bearer sk-1234' \
--header 'Content-Type: application/json' \
--data '{
    "model": ["llama3", "gpt-3.5-turbo"],
    "max_tokens": 10,
    "user": "litellm2",
    "messages": [
        {
            "role": "user",
            "content": "is litellm getting better"
        }
    ]
}'
```

#### Expected Response Format

Get a list of responses when `model` is passed as a list

```json
[
  {
    "id": "chatcmpl-3dbd5dd8-7c82-4ca3-bf1f-7c26f497cf2b",
    "choices": [
      {
        "finish_reason": "length",
        "index": 0,
        "message": {
          "content": "The Elder Scrolls IV: Oblivion!\n\nReleased",
          "role": "assistant"
        }
      }
    ],
    "created": 1715459876,
    "model": "groq/llama3-8b-8192",
    "object": "chat.completion",
    "system_fingerprint": "fp_179b0f92c9",
    "usage": {
      "completion_tokens": 10,
      "prompt_tokens": 12,
      "total_tokens": 22
    }
  },
  {
    "id": "chatcmpl-9NnldUfFLmVquFHSX4yAtjCw8PGei",
    "choices": [
      {
        "finish_reason": "length",
        "index": 0,
        "message": {
          "content": "TES4 could refer to The Elder Scrolls IV:",
          "role": "assistant"
        }
      }
    ],
    "created": 1715459877,
    "model": "gpt-3.5-turbo-0125",
    "object": "chat.completion",
    "system_fingerprint": null,
    "usage": {
      "completion_tokens": 10,
      "prompt_tokens": 9,
      "total_tokens": 19
    }
  }
]
```
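For completeness, the same batch request can also be sent from Python. This is a sketch only, assuming the proxy from the curl example above (localhost:4000, `sk-1234` key) and the third-party `requests` library:

```python
import requests

# Same payload as the curl example: one prompt, fanned out to two model groups.
resp = requests.post(
    "http://localhost:4000/chat/completions",
    headers={
        "Authorization": "Bearer sk-1234",
        "Content-Type": "application/json",
    },
    json={
        "model": ["llama3", "gpt-3.5-turbo"],
        "max_tokens": 10,
        "messages": [{"role": "user", "content": "is litellm getting better"}],
    },
)

# When `model` is a list, the proxy returns a JSON array with one completion per model.
for completion in resp.json():
    print(completion["model"], "->", completion["choices"][0]["message"]["content"])
```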

### Pass User LLM API Keys, Fallbacks

Allow your end-users to pass their model list, api base, OpenAI API key (any LiteLLM supported provider) to make requests
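As a rough client-side sketch (not the full feature, which also covers model lists, api bases and fallbacks): the OpenAI SDK is pointed at the proxy, and an `api_key` field is added to the request body via `extra_body`, which is the field the proxy-side change later in this diff checks for. The URLs and keys below are placeholders:

```python
import openai

# Point the OpenAI SDK at the LiteLLM proxy (placeholder proxy URL and virtual key).
client = openai.OpenAI(api_key="sk-1234", base_url="http://localhost:4000")

response = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "is litellm getting better"}],
    # extra_body fields are merged into the JSON body; an "api_key" entry makes
    # the proxy call the provider directly with the end-user's own key.
    extra_body={"api_key": "sk-my-own-openai-key"},  # placeholder end-user key
)
print(response.choices[0].message.content)
```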
@ -4,6 +4,12 @@ model_list:
      model: openai/fake
      api_key: fake-key
      api_base: https://exampleopenaiendpoint-production.up.railway.app/
  - model_name: llama3
    litellm_params:
      model: groq/llama3-8b-8192
  - model_name: gpt-3.5-turbo
    litellm_params:
      model: gpt-3.5-turbo
  - model_name: "*"
    litellm_params:
      model: openai/*
@ -3656,7 +3656,7 @@ async def chat_completion(
    ### MODEL ALIAS MAPPING ###
    # check if model name in model alias map
    # get the actual model name
-   if data["model"] in litellm.model_alias_map:
+   if isinstance(data["model"], str) and data["model"] in litellm.model_alias_map:
        data["model"] = litellm.model_alias_map[data["model"]]

    ## LOGGING OBJECT ## - initialize logging object for logging success/failure events for call
@ -3690,6 +3690,9 @@ async def chat_completion(
    # skip router if user passed their key
    if "api_key" in data:
        tasks.append(litellm.acompletion(**data))
    elif isinstance(data["model"], list) and llm_router is not None:
        _models = data.pop("model")
        tasks.append(llm_router.abatch_completion(models=_models, **data))
    elif "user_config" in data:
        # initialize a new router instance. make request using this Router
        router_config = data.pop("user_config")
@ -606,6 +606,33 @@ class Router:
            self.fail_calls[model_name] += 1
            raise e

    async def abatch_completion(
        self, models: List[str], messages: List[Dict[str, str]], **kwargs
    ):

        async def _async_completion_no_exceptions(
            model: str, messages: List[Dict[str, str]], **kwargs
        ):
            """
            Wrapper around self.acompletion that catches exceptions and returns them as a result
            """
            try:
                return await self.acompletion(model=model, messages=messages, **kwargs)
            except Exception as e:
                return e

        _tasks = []
        for model in models:
            # add one task per model; if a task fails, the wrapper above returns
            # the exception as its result instead of raising
            _tasks.append(
                _async_completion_no_exceptions(
                    model=model, messages=messages, **kwargs
                )
            )

        response = await asyncio.gather(*_tasks)
        return response
    def image_generation(self, prompt: str, model: str, **kwargs):
        try:
            kwargs["model"] = model
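As a quick illustration of the new `Router.abatch_completion` method above, a minimal usage sketch that mirrors the test added later in this PR (the model names and prompt are taken from that test):

```python
import asyncio
import litellm

router = litellm.Router(
    model_list=[
        {"model_name": "gpt-3.5-turbo", "litellm_params": {"model": "gpt-3.5-turbo"}},
        {"model_name": "groq-llama", "litellm_params": {"model": "groq/llama3-8b-8192"}},
    ]
)

async def main():
    # One prompt, fanned out to both model groups; a per-model failure comes back
    # in the result list as the exception object instead of being raised.
    responses = await router.abatch_completion(
        models=["gpt-3.5-turbo", "groq-llama"],
        messages=[{"role": "user", "content": "is litellm becoming a better product ?"}],
        max_tokens=15,
    )
    for r in responses:
        print(r)

asyncio.run(main())
```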
litellm/tests/test_router_batch_completion.py (new file, 60 lines)
@ -0,0 +1,60 @@
#### What this tests ####
# This tests litellm router with batch completion

import sys, os, time, openai
import traceback, asyncio
import pytest

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import litellm
from litellm import Router
from litellm.router import Deployment, LiteLLM_Params, ModelInfo
from concurrent.futures import ThreadPoolExecutor
from collections import defaultdict
from dotenv import load_dotenv
import os, httpx

load_dotenv()


@pytest.mark.asyncio
async def test_batch_completion_multiple_models():
    litellm.set_verbose = True

    router = litellm.Router(
        model_list=[
            {
                "model_name": "gpt-3.5-turbo",
                "litellm_params": {
                    "model": "gpt-3.5-turbo",
                },
            },
            {
                "model_name": "groq-llama",
                "litellm_params": {
                    "model": "groq/llama3-8b-8192",
                },
            },
        ]
    )

    response = await router.abatch_completion(
        models=["gpt-3.5-turbo", "groq-llama"],
        messages=[
            {"role": "user", "content": "is litellm becoming a better product ?"}
        ],
        max_tokens=15,
    )

    print(response)
    assert len(response) == 2

    models_in_responses = []
    for individual_response in response:
        _model = individual_response["model"]
        models_in_responses.append(_model)

    # assert both models are different
    assert models_in_responses[0] != models_in_responses[1]

@ -4,6 +4,7 @@ import pytest
import asyncio
import aiohttp, openai
from openai import OpenAI, AsyncOpenAI
from typing import Optional, List, Union


def response_header_check(response):
@ -71,7 +72,7 @@ async def new_user(session):
    return await response.json()


-async def chat_completion(session, key, model="gpt-4"):
+async def chat_completion(session, key, model: Union[str, List] = "gpt-4"):
    url = "http://0.0.0.0:4000/chat/completions"
    headers = {
        "Authorization": f"Bearer {key}",
@ -409,3 +410,27 @@ async def test_openai_wildcard_chat_completion():

        # call chat/completions with a model that the key was not created for + the model is not on the config.yaml
        await chat_completion(session=session, key=key, model="gpt-3.5-turbo-0125")


@pytest.mark.asyncio
async def test_batch_chat_completions():
    """
    - Make chat completion call with `model` passed as a list and expect one response per model
    """
    async with aiohttp.ClientSession() as session:

        # pass `model` as a list so the proxy fans the same request out to both model groups
        response = await chat_completion(
            session=session,
            key="sk-1234",
            model=[
                "gpt-3.5-turbo",
                "fake-openai-endpoint",
            ],
        )

        print(f"response: {response}")

        assert len(response) == 2
        assert isinstance(response, list)