Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 10:44:24 +00:00)
build(openai_proxy/main.py): adding support for routing between multiple azure deployments
This commit is contained in:
parent f208a1231b
commit b9a4bfc054
15 changed files with 159 additions and 1 deletion
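The router endpoints added below accept an optional model_list in the request body and load-balance a single model alias (e.g. gpt-3.5-turbo) across several Azure/OpenAI deployments. A minimal client-side sketch of the new feature, assuming the proxy is running locally; the base URL, deployment names, and environment variable values are placeholders, not part of this commit:

import os
import requests

# Hypothetical local proxy URL; adjust to wherever openai_proxy is served.
PROXY_BASE = "http://0.0.0.0:8000"

payload = {
    "model": "gpt-3.5-turbo",
    "messages": [{"role": "user", "content": "Hey, how's it going?"}],
    # One alias, several deployments: the router balances across them by tpm/rpm.
    "model_list": [
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {
                "model": "azure/chatgpt-v-2",  # illustrative Azure deployment name
                "api_key": os.getenv("AZURE_API_KEY"),
                "api_version": os.getenv("AZURE_API_VERSION"),
                "api_base": os.getenv("AZURE_API_BASE"),
            },
            "tpm": 240000,
            "rpm": 1800,
        },
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {"model": "gpt-3.5-turbo", "api_key": os.getenv("OPENAI_API_KEY")},
            "tpm": 1000000,
            "rpm": 9000,
        },
    ],
}

resp = requests.post(f"{PROXY_BASE}/router/completions", json=payload)
print(resp.json())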
openai_proxy/__init__.py  (new file, +2)
@@ -0,0 +1,2 @@
+from .main import *
+from .utils import *
openai_proxy/config  (new file, +0)

openai_proxy/main.py
@@ -5,7 +5,8 @@ from fastapi.responses import StreamingResponse, FileResponse
 from fastapi.middleware.cors import CORSMiddleware
 import json
 import os
-from utils import set_callbacks
+from typing import Optional
+from openai_proxy.utils import set_callbacks, load_router_config
 import dotenv
 dotenv.load_dotenv() # load env variables

@@ -20,7 +21,11 @@ app.add_middleware(
     allow_methods=["*"],
     allow_headers=["*"],
 )
+
+#### GLOBAL VARIABLES ####
+llm_router: Optional[litellm.Router] = None
 set_callbacks() # sets litellm callbacks for logging if they exist in the environment
+llm_router = load_router_config(router=llm_router)
 #### API ENDPOINTS ####
 @router.post("/v1/models")
 @router.get("/models") # if project requires model list
@@ -101,6 +106,48 @@ async def chat_completion(request: Request):
return {"error": error_msg}
|
return {"error": error_msg}
|
||||||
# raise HTTPException(status_code=500, detail=error_msg)
|
# raise HTTPException(status_code=500, detail=error_msg)
|
||||||
|
|
||||||
|
@router.post("/router/completions")
|
||||||
|
async def router_completion(request: Request):
|
||||||
|
global llm_router
|
||||||
|
try:
|
||||||
|
data = await request.json()
|
||||||
|
if "model_list" in data:
|
||||||
|
llm_router = litellm.Router(model_list=data["model_list"])
|
||||||
|
if llm_router is None:
|
||||||
|
raise Exception("Save model list via config.yaml. Eg.: ` docker build -t myapp --build-arg CONFIG_FILE=myconfig.yaml .` or pass it in as model_list=[..] as part of the request body")
|
||||||
|
|
||||||
|
# openai.ChatCompletion.create replacement
|
||||||
|
response = await llm_router.acompletion(model="gpt-3.5-turbo",
|
||||||
|
messages=[{"role": "user", "content": "Hey, how's it going?"}])
|
||||||
|
|
||||||
|
if 'stream' in data and data['stream'] == True: # use generate_responses to stream responses
|
||||||
|
return StreamingResponse(data_generator(response), media_type='text/event-stream')
|
||||||
|
return response
|
||||||
|
except Exception as e:
|
||||||
|
error_traceback = traceback.format_exc()
|
||||||
|
error_msg = f"{str(e)}\n\n{error_traceback}"
|
||||||
|
return {"error": error_msg}
|
||||||
|
|
||||||
|
@router.post("/router/embedding")
|
||||||
|
async def router_embedding(request: Request):
|
||||||
|
global llm_router
|
||||||
|
try:
|
||||||
|
if llm_router is None:
|
||||||
|
raise Exception("Save model list via config.yaml. Eg.: ` docker build -t myapp --build-arg CONFIG_FILE=myconfig.yaml .`")
|
||||||
|
|
||||||
|
data = await request.json()
|
||||||
|
# openai.ChatCompletion.create replacement
|
||||||
|
response = await llm_router.aembedding(model="gpt-3.5-turbo",
|
||||||
|
messages=[{"role": "user", "content": "Hey, how's it going?"}])
|
||||||
|
|
||||||
|
if 'stream' in data and data['stream'] == True: # use generate_responses to stream responses
|
||||||
|
return StreamingResponse(data_generator(response), media_type='text/event-stream')
|
||||||
|
return response
|
||||||
|
except Exception as e:
|
||||||
|
error_traceback = traceback.format_exc()
|
||||||
|
error_msg = f"{str(e)}\n\n{error_traceback}"
|
||||||
|
return {"error": error_msg}
|
||||||
|
|
||||||
@router.get("/")
|
@router.get("/")
|
||||||
async def home(request: Request):
|
async def home(request: Request):
|
||||||
return "LiteLLM: RUNNING"
|
return "LiteLLM: RUNNING"
|
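The two endpoints above are thin wrappers around litellm.Router, so the same routing can be exercised directly against the library. A short sketch mirroring the acompletion call in router_completion; the deployment name and keys are placeholders, and more deployments can be registered under the same alias exactly as in the test below:

import asyncio, os
import litellm

# Deployments registered under one alias; Router picks between them per its tpm/rpm limits.
router = litellm.Router(model_list=[
    {
        "model_name": "gpt-3.5-turbo",
        "litellm_params": {
            "model": "azure/chatgpt-v-2",  # placeholder Azure deployment
            "api_key": os.getenv("AZURE_API_KEY"),
            "api_version": os.getenv("AZURE_API_VERSION"),
            "api_base": os.getenv("AZURE_API_BASE"),
        },
        "tpm": 240000,
        "rpm": 1800,
    },
    {
        "model_name": "gpt-3.5-turbo",
        "litellm_params": {"model": "gpt-3.5-turbo", "api_key": os.getenv("OPENAI_API_KEY")},
        "tpm": 1000000,
        "rpm": 9000,
    },
])

async def main():
    response = await router.acompletion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Hey, how's it going?"}],
    )
    print(response)

asyncio.run(main())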
openai_proxy/tests/test_router.py  (new file, +59)
@@ -0,0 +1,59 @@
+#### What this tests ####
+# This tests the proxy's /router/completions endpoint with a multi-deployment model_list
+
+import sys, os
+import traceback, asyncio
+import pytest
+from fastapi.testclient import TestClient
+from fastapi import Request
+sys.path.insert(
+    0, os.path.abspath("../..")
+) # Adds the parent directory to the system path
+from openai_proxy import app
+
+
+def test_router_completion():
+    client = TestClient(app)
+    data = {
+        "model": "gpt-3.5-turbo",
+        "messages": [{"role": "user", "content": "Hey, how's it going?"}],
+        "model_list": [{ # list of model deployments
+            "model_name": "gpt-3.5-turbo", # openai model name
+            "litellm_params": { # params for litellm completion/embedding call
+                "model": "azure/chatgpt-v-2",
+                "api_key": os.getenv("AZURE_API_KEY"),
+                "api_version": os.getenv("AZURE_API_VERSION"),
+                "api_base": os.getenv("AZURE_API_BASE")
+            },
+            "tpm": 240000,
+            "rpm": 1800
+        }, {
+            "model_name": "gpt-3.5-turbo", # openai model name
+            "litellm_params": { # params for litellm completion/embedding call
+                "model": "azure/chatgpt-functioncalling",
+                "api_key": os.getenv("AZURE_API_KEY"),
+                "api_version": os.getenv("AZURE_API_VERSION"),
+                "api_base": os.getenv("AZURE_API_BASE")
+            },
+            "tpm": 240000,
+            "rpm": 1800
+        }, {
+            "model_name": "gpt-3.5-turbo", # openai model name
+            "litellm_params": { # params for litellm completion/embedding call
+                "model": "gpt-3.5-turbo",
+                "api_key": os.getenv("OPENAI_API_KEY"),
+            },
+            "tpm": 1000000,
+            "rpm": 9000
+        }]
+    }
+
+    response = client.post("/router/completions", json=data)
+    print(f"response: {response.text}")
+    assert response.status_code == 200
+
+    response_data = response.json()
+    # Perform assertions on the response data
+    assert isinstance(response_data['choices'][0]['message']['content'], str)
+
+test_router_completion()
openai_proxy/utils.py
@@ -1,5 +1,7 @@
 import os, litellm
+import yaml
 import dotenv
+from typing import Optional
 dotenv.load_dotenv() # load env variables

 def set_callbacks():
@@ -21,5 +23,25 @@ def set_callbacks():
         litellm.cache = Cache(type="redis", host=os.getenv("REDIS_HOST"), port=os.getenv("REDIS_PORT"), password=os.getenv("REDIS_PASSWORD"))
+
+
+def load_router_config(router: Optional[litellm.Router]):
+    config = {}
+    config_file = 'config.yaml'
+
+    if os.path.exists(config_file):
+        with open(config_file, 'r') as file:
+            config = yaml.safe_load(file)
+    else:
+        print(f"Config file '{config_file}' not found.")
+
+    ## MODEL LIST
+    model_list = config.get('model_list', None)
+    if model_list:
+        router = litellm.Router(model_list=model_list)
+
+    ## ENVIRONMENT VARIABLES
+    environment_variables = config.get('environment_variables', None)
+    if environment_variables:
+        for key, value in environment_variables.items():
+            os.environ[key] = value
+
+    return router
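load_router_config reads a config.yaml from the working directory (if present), builds a litellm.Router from its model_list, and exports any environment_variables it finds. A small sketch of the startup flow main.py follows, assuming a config.yaml shaped like the template below sits next to the process:

from typing import Optional
import litellm
from openai_proxy.utils import load_router_config

llm_router: Optional[litellm.Router] = None
llm_router = load_router_config(router=llm_router)  # Router if config.yaml had a model_list, else None

if llm_router is None:
    print("No model_list found in config.yaml; clients must pass model_list in the request body.")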
router_config_template.yaml  (new file, +28)
@@ -0,0 +1,28 @@
+model_list:
+  - model_name: gpt-3.5-turbo
+    litellm_params:
+      model: azure/chatgpt-v-2
+      api_key: your_azure_api_key
+      api_version: your_azure_api_version
+      api_base: your_azure_api_base
+    tpm: 240000 # REPLACE with your azure deployment tpm
+    rpm: 1800 # REPLACE with your azure deployment rpm
+  - model_name: gpt-3.5-turbo
+    litellm_params:
+      model: azure/chatgpt-functioncalling
+      api_key: your_azure_api_key
+      api_version: your_azure_api_version
+      api_base: your_azure_api_base
+    tpm: 240000
+    rpm: 1800
+  - model_name: gpt-3.5-turbo
+    litellm_params:
+      model: gpt-3.5-turbo
+      api_key: your_openai_api_key
+    tpm: 1000000 # REPLACE with your openai tpm
+    rpm: 9000 # REPLACE with your openai rpm
+
+environment_variables:
+  REDIS_HOST: your_redis_host
+  REDIS_PASSWORD: your_redis_password
+  REDIS_PORT: your_redis_port
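As the error messages in the new endpoints suggest, this template can be copied to config.yaml so that load_router_config picks it up at startup, or baked into the image with ` docker build -t myapp --build-arg CONFIG_FILE=myconfig.yaml .` (assuming the project's Dockerfile copies CONFIG_FILE into place as config.yaml).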