From cd9b671cfe94e255cf5f35274cf7d539b740b7ba Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Fri, 27 Oct 2023 10:46:25 -0700 Subject: [PATCH] build(litellm_server/main.py): accept config file in /chat/completions --- Dockerfile | 6 ------ litellm_server/main.py | 8 +++++++- litellm_server/utils.py | 2 +- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/Dockerfile b/Dockerfile index 30d78eb185..179629c9a5 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,11 +1,5 @@ FROM python:3.10 -# Define a build argument for the config file path -ARG CONFIG_FILE - -# Copy the custom config file (if provided) into the Docker image -COPY $CONFIG_FILE /app/config.yaml - COPY . /app WORKDIR /app RUN pip install -r requirements.txt diff --git a/litellm_server/main.py b/litellm_server/main.py index 0aee45b478..cc29d96ce8 100644 --- a/litellm_server/main.py +++ b/litellm_server/main.py @@ -26,9 +26,10 @@ app.add_middleware( ) #### GLOBAL VARIABLES #### llm_router: Optional[litellm.Router] = None +llm_model_list: Optional[list] = None set_callbacks() # sets litellm callbacks for logging if they exist in the environment -llm_router = load_router_config(router=llm_router) +llm_router, llm_model_list = load_router_config(router=llm_router) #### API ENDPOINTS #### @router.get("/v1/models") @router.get("/models") # if project requires model list @@ -88,8 +89,10 @@ async def embedding(request: Request): @router.post("/v1/chat/completions") @router.post("/chat/completions") async def chat_completion(request: Request): + global llm_model_list try: data = await request.json() + ## CHECK KEYS ## # default to always using the "ENV" variables, only if AUTH_STRATEGY==DYNAMIC then reads headers env_validation = litellm.validate_environment(model=data["model"]) if (env_validation['keys_in_environment'] is False or os.getenv("AUTH_STRATEGY", None) == "DYNAMIC") and "authorization" in request.headers: # if users pass LLM api keys as part of header @@ -98,6 +101,9 @@ async def chat_completion(request: Request): if len(api_key) > 0: api_key = api_key data["api_key"] = api_key + ## CHECK CONFIG ## + if llm_model_list and data["model"] in [m["model_name"] for m in llm_model_list]: + return await router_completion(request=request) response = litellm.completion( **data ) diff --git a/litellm_server/utils.py b/litellm_server/utils.py index 5cb1bd06a7..8dee3df03f 100644 --- a/litellm_server/utils.py +++ b/litellm_server/utils.py @@ -67,4 +67,4 @@ def load_router_config(router: Optional[litellm.Router]): for key, value in environment_variables.items(): os.environ[key] = value - return router + return router, model_list