build(litellm_server/main.py): accept config file in /chat/completions

This commit is contained in:
Krrish Dholakia 2023-10-27 10:46:25 -07:00
parent dca28667fa
commit cd9b671cfe
3 changed files with 8 additions and 8 deletions

View file

@ -1,11 +1,5 @@
FROM python:3.10
# Define a build argument for the config file path
ARG CONFIG_FILE
# Copy the custom config file (if provided) into the Docker image
COPY $CONFIG_FILE /app/config.yaml
COPY . /app
WORKDIR /app
RUN pip install -r requirements.txt

View file

@ -26,9 +26,10 @@ app.add_middleware(
)
#### GLOBAL VARIABLES ####
llm_router: Optional[litellm.Router] = None
llm_model_list: Optional[list] = None
set_callbacks() # sets litellm callbacks for logging if they exist in the environment
llm_router = load_router_config(router=llm_router)
llm_router, llm_model_list = load_router_config(router=llm_router)
#### API ENDPOINTS ####
@router.get("/v1/models")
@router.get("/models") # if project requires model list
@ -88,8 +89,10 @@ async def embedding(request: Request):
@router.post("/v1/chat/completions")
@router.post("/chat/completions")
async def chat_completion(request: Request):
global llm_model_list
try:
data = await request.json()
## CHECK KEYS ##
# default to the API keys set in the environment; read the key from the "authorization" header only when the environment has no key for this model or AUTH_STRATEGY == "DYNAMIC"
env_validation = litellm.validate_environment(model=data["model"])
if (env_validation['keys_in_environment'] is False or os.getenv("AUTH_STRATEGY", None) == "DYNAMIC") and "authorization" in request.headers: # if users pass LLM api keys as part of header
@ -98,6 +101,9 @@ async def chat_completion(request: Request):
if len(api_key) > 0:
api_key = api_key
data["api_key"] = api_key
## CHECK CONFIG ##
if llm_model_list and data["model"] in [m["model_name"] for m in llm_model_list]:
return await router_completion(request=request)
response = litellm.completion(
**data
)

View file

@ -67,4 +67,4 @@ def load_router_config(router: Optional[litellm.Router]):
for key, value in environment_variables.items():
os.environ[key] = value
return router
return router, model_list