build(litellm_server/utils.py): add support for general settings + num retries as a module variable

This commit is contained in:
Krrish Dholakia 2023-11-02 20:56:33 -07:00
parent 3f1b4c0759
commit e3a1c58dd9
7 changed files with 52 additions and 33 deletions

View file

@ -5,10 +5,11 @@ from fastapi.responses import StreamingResponse, FileResponse
from fastapi.middleware.cors import CORSMiddleware
import json, sys
from typing import Optional
# sys.path.insert(
# 0, os.path.abspath("../")
# ) # Adds the parent directory to the system path - for litellm local dev
sys.path.insert(
0, os.path.abspath("../")
) # Adds the parent directory to the system path - for litellm local dev
import litellm
print(f"litellm: {litellm}")
try:
from utils import set_callbacks, load_router_config, print_verbose
except ImportError:
@ -30,14 +31,15 @@ app.add_middleware(
#### GLOBAL VARIABLES ####
llm_router: Optional[litellm.Router] = None
llm_model_list: Optional[list] = None
server_settings: Optional[dict] = None
set_callbacks() # sets litellm callbacks for logging if they exist in the environment
if "CONFIG_FILE_PATH" in os.environ:
print(f"CONFIG FILE DETECTED")
llm_router, llm_model_list = load_router_config(router=llm_router, config_file_path=os.getenv("CONFIG_FILE_PATH"))
llm_router, llm_model_list, server_settings = load_router_config(router=llm_router, config_file_path=os.getenv("CONFIG_FILE_PATH"))
else:
llm_router, llm_model_list = load_router_config(router=llm_router)
llm_router, llm_model_list, server_settings = load_router_config(router=llm_router)
#### API ENDPOINTS ####
@router.get("/v1/models")
@router.get("/models") # if project requires model list
@ -100,27 +102,31 @@ async def embedding(request: Request):
@router.post("/chat/completions")
@router.post("/openai/deployments/{model:path}/chat/completions") # azure compatible endpoint
async def chat_completion(request: Request, model: Optional[str] = None):
global llm_model_list
global llm_model_list, server_settings
try:
data = await request.json()
if model:
data["model"] = model
print(f"data: {data}")
data["model"] = (
server_settings.get("completion_model", None) # server default
or model # model passed in url
or data["model"] # default passed in
)
## CHECK KEYS ##
# default to always using the "ENV" variables, only if AUTH_STRATEGY==DYNAMIC then reads headers
env_validation = litellm.validate_environment(model=data["model"])
if (env_validation['keys_in_environment'] is False or os.getenv("AUTH_STRATEGY", None) == "DYNAMIC") and ("authorization" in request.headers or "api-key" in request.headers): # if users pass LLM api keys as part of header
if "authorization" in request.headers:
api_key = request.headers.get("authorization")
elif "api-key" in request.headers:
api_key = request.headers.get("api-key")
print(f"api_key in headers: {api_key}")
if " " in api_key:
api_key = api_key.split(" ")[1]
print(f"api_key split: {api_key}")
if len(api_key) > 0:
api_key = api_key
data["api_key"] = api_key
print(f"api_key in data: {api_key}")
# env_validation = litellm.validate_environment(model=data["model"])
# if (env_validation['keys_in_environment'] is False or os.getenv("AUTH_STRATEGY", None) == "DYNAMIC") and ("authorization" in request.headers or "api-key" in request.headers): # if users pass LLM api keys as part of header
# if "authorization" in request.headers:
# api_key = request.headers.get("authorization")
# elif "api-key" in request.headers:
# api_key = request.headers.get("api-key")
# print(f"api_key in headers: {api_key}")
# if " " in api_key:
# api_key = api_key.split(" ")[1]
# print(f"api_key split: {api_key}")
# if len(api_key) > 0:
# api_key = api_key
# data["api_key"] = api_key
# print(f"api_key in data: {api_key}")
## CHECK CONFIG ##
if llm_model_list and data["model"] in [m["model_name"] for m in llm_model_list]:
for m in llm_model_list:
@ -133,13 +139,14 @@ async def chat_completion(request: Request, model: Optional[str] = None):
)
if 'stream' in data and data['stream'] == True: # use generate_responses to stream responses
return StreamingResponse(data_generator(response), media_type='text/event-stream')
print(f"response: {response}")
return response
except Exception as e:
error_traceback = traceback.format_exc()
print(f"{error_traceback}")
error_msg = f"{str(e)}\n\n{error_traceback}"
return {"error": error_msg}
# raise HTTPException(status_code=500, detail=error_msg)
# return {"error": error_msg}
raise HTTPException(status_code=500, detail=error_msg)
@router.post("/router/completions")
async def router_completion(request: Request):