# litellm/litellm_server/main.py

import os, traceback
from fastapi import FastAPI, Request, HTTPException
from fastapi.routing import APIRouter
from fastapi.responses import StreamingResponse, FileResponse
from fastapi.middleware.cors import CORSMiddleware
import json, sys
from typing import Optional
sys.path.insert(
    0, os.path.abspath("../")
)  # Adds the parent directory to the system path - for litellm local dev
import litellm
print(f"litellm: {litellm}")
try:
    from utils import set_callbacks, load_router_config, print_verbose
except ImportError:
    from litellm_server.utils import set_callbacks, load_router_config, print_verbose
import dotenv
dotenv.load_dotenv() # load env variables
app = FastAPI(docs_url="/", title="LiteLLM API")
router = APIRouter()
origins = ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
#### GLOBAL VARIABLES ####
llm_router: Optional[litellm.Router] = None
llm_model_list: Optional[list] = None
server_settings: Optional[dict] = None
set_callbacks() # sets litellm callbacks for logging if they exist in the environment
if "CONFIG_FILE_PATH" in os.environ:
    print("CONFIG FILE DETECTED")
    llm_router, llm_model_list, server_settings = load_router_config(
        router=llm_router, config_file_path=os.getenv("CONFIG_FILE_PATH")
    )
else:
    llm_router, llm_model_list, server_settings = load_router_config(router=llm_router)
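# A minimal config.yaml sketch for CONFIG_FILE_PATH (illustrative only - the
# exact schema is defined by utils.load_router_config). The model_list entries
# mirror the model_name / litellm_params records looked up in /chat/completions
# below; any other field values here are assumptions:
#
#   model_list:
#     - model_name: gpt-3.5-turbo
#       litellm_params:
#         model: gpt-3.5-turbo
#     - model_name: azure-gpt
#       litellm_params:
#         model: azure/my-deployment
#         api_base: https://my-endpoint.openai.azure.com
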
#### API ENDPOINTS ####
@router.get("/v1/models")
@router.get("/models") # if project requires model list
def model_list():
    all_models = litellm.utils.get_valid_models()
    if llm_model_list:
        all_models += llm_model_list
    return dict(
        data=[
            {
                "id": model,
                "object": "model",
                "created": 1677610602,
                "owned_by": "openai",
            }
            for model in all_models
        ],
        object="list",
    )

# for streaming
def data_generator(response):
    print("inside generator")
    for chunk in response:
        print(f"returned chunk: {chunk}")
        yield f"data: {json.dumps(chunk)}\n\n"

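# Each yielded chunk is framed as a Server-Sent Events message ("data: <json>"
# followed by a blank line). A client-side sketch for consuming the stream
# (the `requests` usage and the url / payload names are illustrative assumptions):
#
#   import requests
#   with requests.post(url, json={**payload, "stream": True}, stream=True) as r:
#       for line in r.iter_lines():
#           if line.startswith(b"data: "):
#               print(line[len(b"data: "):].decode())
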
@router.post("/v1/completions")
@router.post("/completions")
async def completion(request: Request):
    data = await request.json()
    response = litellm.completion(**data)
    if "stream" in data and data["stream"] == True:  # stream the response via data_generator
        return StreamingResponse(data_generator(response), media_type="text/event-stream")
    return response

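# Example (sketch): the request body is forwarded as-is to litellm.completion,
# so it needs `model` and `messages`; localhost:8000 is an assumption:
#
#   import requests
#   resp = requests.post(
#       "http://localhost:8000/v1/completions",
#       json={"model": "gpt-3.5-turbo",
#             "messages": [{"role": "user", "content": "Say hi"}]},
#   )
#   print(resp.json())
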
@router.post("/v1/embeddings")
@router.post("/embeddings")
async def embedding(request: Request):
    try:
        data = await request.json()
        # default to the server's ENV variables; only read keys from headers when AUTH_STRATEGY == "DYNAMIC"
        if os.getenv("AUTH_STRATEGY", None) == "DYNAMIC" and "authorization" in request.headers:
            # users passed an LLM api key as part of the request header
            api_key = request.headers.get("authorization")
            api_key = api_key.replace("Bearer", "").strip()
            if len(api_key) > 0:
                data["api_key"] = api_key
        response = litellm.embedding(**data)
        return response
    except Exception as e:
        error_traceback = traceback.format_exc()
        error_msg = f"{str(e)}\n\n{error_traceback}"
        return {"error": error_msg}

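# Example (sketch): with AUTH_STRATEGY=DYNAMIC set on the server, a per-request
# key can be sent via the Authorization header (otherwise keys come from the
# server's environment). Host/port and the placeholder key are assumptions:
#
#   import requests
#   resp = requests.post(
#       "http://localhost:8000/v1/embeddings",
#       json={"model": "text-embedding-ada-002", "input": ["hello world"]},
#       headers={"Authorization": "Bearer sk-..."},
#   )
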
@router.post("/v1/chat/completions")
@router.post("/chat/completions")
@router.post("/openai/deployments/{model:path}/chat/completions") # azure compatible endpoint
async def chat_completion(request: Request, model: Optional[str] = None):
    global llm_model_list, server_settings
    try:
        data = await request.json()
        print(f"data: {data}")
        server_model = server_settings.get("completion_model", None) if server_settings else None
        data["model"] = server_model or model or data["model"]
        ## CHECK KEYS ##
        # default to the server's ENV variables; only read keys from headers when AUTH_STRATEGY == "DYNAMIC"
        # env_validation = litellm.validate_environment(model=data["model"])
        # if (env_validation['keys_in_environment'] is False or os.getenv("AUTH_STRATEGY", None) == "DYNAMIC") and ("authorization" in request.headers or "api-key" in request.headers): # if users pass LLM api keys as part of header
        #     if "authorization" in request.headers:
        #         api_key = request.headers.get("authorization")
        #     elif "api-key" in request.headers:
        #         api_key = request.headers.get("api-key")
        #     print(f"api_key in headers: {api_key}")
        #     if " " in api_key:
        #         api_key = api_key.split(" ")[1]
        #     print(f"api_key split: {api_key}")
        #     if len(api_key) > 0:
        #         data["api_key"] = api_key
        #     print(f"api_key in data: {api_key}")
        ## CHECK CONFIG ##
        if llm_model_list and data["model"] in [m["model_name"] for m in llm_model_list]:
            # apply the matching config entry's litellm_params (e.g. api_base, api_key) to this request
            for m in llm_model_list:
                if data["model"] == m["model_name"]:
                    for key, value in m["litellm_params"].items():
                        data[key] = value
                    break
        response = litellm.completion(**data)
        if "stream" in data and data["stream"] == True:  # stream the response via data_generator
            return StreamingResponse(data_generator(response), media_type="text/event-stream")
        print(f"response: {response}")
        return response
    except Exception as e:
        error_traceback = traceback.format_exc()
        print(f"{error_traceback}")
        error_msg = f"{str(e)}\n\n{error_traceback}"
        # return {"error": error_msg}
        raise HTTPException(status_code=500, detail=error_msg)

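# Example (sketch): the Azure-compatible route takes the model from the URL
# path, so clients built for Azure OpenAI deployments can point at this server
# unchanged. Host/port and the deployment name are assumptions:
#
#   import requests
#   resp = requests.post(
#       "http://localhost:8000/openai/deployments/gpt-35-turbo/chat/completions",
#       json={"messages": [{"role": "user", "content": "Hello"}]},
#   )
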
@router.post("/router/completions")
async def router_completion(request: Request):
    global llm_router
    try:
        data = await request.json()
        if "model_list" in data:
            llm_router = litellm.Router(model_list=data.pop("model_list"))
        if llm_router is None:
            raise Exception("Save model list via config.yaml. Eg.: `docker build -t myapp --build-arg CONFIG_FILE=myconfig.yaml .` or pass it in as model_list=[..] as part of the request body")
        # openai.ChatCompletion.create replacement - forward the request body to the router
        response = await llm_router.acompletion(**data)
        if "stream" in data and data["stream"] == True:  # stream the response via data_generator
            return StreamingResponse(data_generator(response), media_type="text/event-stream")
        return response
    except Exception as e:
        error_traceback = traceback.format_exc()
        error_msg = f"{str(e)}\n\n{error_traceback}"
        return {"error": error_msg}

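# Example (sketch): supplying a model_list inline instead of via config.yaml;
# the deployment entries and host/port below are illustrative assumptions:
#
#   import requests
#   resp = requests.post(
#       "http://localhost:8000/router/completions",
#       json={
#           "model_list": [
#               {"model_name": "gpt-3.5-turbo",
#                "litellm_params": {"model": "gpt-3.5-turbo"}},
#           ],
#           "model": "gpt-3.5-turbo",
#           "messages": [{"role": "user", "content": "Hello"}],
#       },
#   )
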
@router.post("/router/embedding")
async def router_embedding(request: Request):
    global llm_router
    try:
        data = await request.json()
        if "model_list" in data:
            llm_router = litellm.Router(model_list=data.pop("model_list"))
        if llm_router is None:
            raise Exception("Save model list via config.yaml. Eg.: `docker build -t myapp --build-arg CONFIG_FILE=myconfig.yaml .` or pass it in as model_list=[..] as part of the request body")
        # embeddings take `input` (not chat `messages`) - forward the request body to the router
        response = await llm_router.aembedding(**data)
        if "stream" in data and data["stream"] == True:  # stream the response via data_generator
            return StreamingResponse(data_generator(response), media_type="text/event-stream")
        return response
    except Exception as e:
        error_traceback = traceback.format_exc()
        error_msg = f"{str(e)}\n\n{error_traceback}"
        return {"error": error_msg}

@router.get("/")
async def home(request: Request):
    return "LiteLLM: RUNNING"

app.include_router(router)
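
# To run this app locally (a sketch; the import string and port are assumptions
# based on this file's location, not something set in this file):
#
#   import uvicorn
#   uvicorn.run("litellm_server.main:app", host="0.0.0.0", port=8000)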