From 030bd220785765638b57d929a8e05b4be5303e45 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Mon, 4 Dec 2023 18:32:47 -0800 Subject: [PATCH] feat(proxy_server.py): allow user to override api key auth --- litellm/proxy/custom_auth.py | 14 ++ litellm/proxy/proxy_server.py | 158 +++++------------- litellm/proxy/types.py | 70 ++++++++ litellm/proxy/utils.py | 60 ++++++- litellm/tests/test_configs/custom_auth.py | 14 ++ .../tests/{ => test_configs}/test_config.yaml | 0 .../test_configs/test_config_custom_auth.yaml | 11 ++ litellm/tests/test_proxy_custom_auth.py | 63 +++++++ litellm/tests/test_proxy_server_keys.py | 2 +- 9 files changed, 274 insertions(+), 118 deletions(-) create mode 100644 litellm/proxy/custom_auth.py create mode 100644 litellm/proxy/types.py create mode 100644 litellm/tests/test_configs/custom_auth.py rename litellm/tests/{ => test_configs}/test_config.yaml (100%) create mode 100644 litellm/tests/test_configs/test_config_custom_auth.yaml create mode 100644 litellm/tests/test_proxy_custom_auth.py diff --git a/litellm/proxy/custom_auth.py b/litellm/proxy/custom_auth.py new file mode 100644 index 000000000..0cce561ca --- /dev/null +++ b/litellm/proxy/custom_auth.py @@ -0,0 +1,14 @@ +from litellm.proxy.types import UserAPIKeyAuth +from fastapi import Request +from dotenv import load_dotenv +import os + +load_dotenv() +async def user_api_key_auth(request: Request, api_key: str) -> UserAPIKeyAuth: + try: + modified_master_key = f"{os.getenv('PROXY_MASTER_KEY')}-1234" + if api_key == modified_master_key: + return UserAPIKeyAuth(api_key=api_key) + raise Exception + except: + raise Exception \ No newline at end of file diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 8e9ddc9fa..6f8e0f6ab 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -92,12 +92,16 @@ def generate_feedback_box(): import litellm from litellm.proxy.utils import ( - PrismaClient + PrismaClient, + get_instance_fn ) +import pydantic +from litellm.proxy.types import * from litellm.caching import DualCache litellm.suppress_debug_info = True from fastapi import FastAPI, Request, HTTPException, status, Depends, BackgroundTasks from fastapi.routing import APIRouter +from fastapi.security import OAuth2PasswordBearer from fastapi.encoders import jsonable_encoder from fastapi.responses import StreamingResponse, FileResponse, ORJSONResponse from fastapi.middleware.cors import CORSMiddleware @@ -163,70 +167,8 @@ def log_input_output(request, response, custom_logger=None): return True from typing import Dict -from pydantic import BaseModel -######### Request Class Definition ###### -class ProxyChatCompletionRequest(BaseModel): - model: str - messages: List[Dict[str, str]] - temperature: Optional[float] = None - top_p: Optional[float] = None - n: Optional[int] = None - stream: Optional[bool] = None - stop: Optional[List[str]] = None - max_tokens: Optional[int] = None - presence_penalty: Optional[float] = None - frequency_penalty: Optional[float] = None - logit_bias: Optional[Dict[str, float]] = None - user: Optional[str] = None - response_format: Optional[Dict[str, str]] = None - seed: Optional[int] = None - tools: Optional[List[str]] = None - tool_choice: Optional[str] = None - functions: Optional[List[str]] = None # soon to be deprecated - function_call: Optional[str] = None # soon to be deprecated - - # Optional LiteLLM params - caching: Optional[bool] = None - api_base: Optional[str] = None - api_version: Optional[str] = None - api_key: Optional[str] = None - 
num_retries: Optional[int] = None - context_window_fallback_dict: Optional[Dict[str, str]] = None - fallbacks: Optional[List[str]] = None - metadata: Optional[Dict[str, str]] = {} - deployment_id: Optional[str] = None - request_timeout: Optional[int] = None - - class Config: - extra='allow' # allow params not defined here, these fall in litellm.completion(**kwargs) - -class ModelParams(BaseModel): - model_name: str - litellm_params: dict - model_info: Optional[dict] - class Config: - protected_namespaces = () - -class GenerateKeyRequest(BaseModel): - duration: str = "1h" - models: list = [] - aliases: dict = {} - config: dict = {} - spend: int = 0 - user_id: Optional[str] = None - -class GenerateKeyResponse(BaseModel): - key: str - expires: datetime - user_id: str - -class _DeleteKeyObject(BaseModel): - key: str - -class DeleteKeyRequest(BaseModel): - keys: List[_DeleteKeyObject] - +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") user_api_base = None user_model = None user_debug = False @@ -249,6 +191,7 @@ master_key = None otel_logging = False prisma_client: Optional[PrismaClient] = None user_api_key_cache = DualCache() +user_custom_auth = None ### REDIS QUEUE ### async_result = None celery_app_conn = None @@ -268,21 +211,21 @@ def usage_telemetry( target=litellm.utils.litellm_telemetry, args=(data,), daemon=True ).start() -api_key_header = APIKeyHeader(name="Authorization", auto_error=False) -async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(api_key_header)): - global master_key, prisma_client, llm_model_list - print(f"master_key - {master_key}; api_key - {api_key}") - if master_key is None: - if isinstance(api_key, str): - return { - "api_key": api_key.replace("Bearer ", "") - } - else: - return { - "api_key": api_key - } + +async def user_api_key_auth(request: Request, api_key: str = Depends(oauth2_scheme)) -> UserAPIKeyAuth: + global master_key, prisma_client, llm_model_list, user_custom_auth try: + ### USER-DEFINED AUTH FUNCTION ### + if user_custom_auth: + response = await user_custom_auth(request=request, api_key=api_key) + return UserAPIKeyAuth.model_validate(response) + + if master_key is None: + if isinstance(api_key, str): + return UserAPIKeyAuth(api_key=api_key.replace("Bearer ", "")) + else: + return UserAPIKeyAuth() if api_key is None: raise Exception("No api key passed in.") route = request.url.path @@ -290,9 +233,7 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap # note: never string compare api keys, this is vulenerable to a time attack. 
Use secrets.compare_digest instead is_master_key_valid = secrets.compare_digest(api_key, master_key) or secrets.compare_digest(api_key, "Bearer " + master_key) if is_master_key_valid: - return { - "api_key": master_key - } + return UserAPIKeyAuth(api_key=master_key) if (route == "/key/generate" or route == "/key/delete" or route == "/key/info") and not is_master_key_valid: raise Exception(f"If master key is set, only master key can be used to generate, delete or get info for new keys") @@ -318,7 +259,7 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap return_dict = {"api_key": valid_token.token} if valid_token.user_id: return_dict["user_id"] = valid_token.user_id - return return_dict + return UserAPIKeyAuth(**return_dict) else: data = await request.json() model = data.get("model", None) @@ -329,14 +270,14 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap return_dict = {"api_key": valid_token.token} if valid_token.user_id: return_dict["user_id"] = valid_token.user_id - return return_dict + return UserAPIKeyAuth(**return_dict) else: raise Exception(f"Invalid token") except Exception as e: print(f"An exception occurred - {traceback.format_exc()}") raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, - detail={"error": "invalid user key"}, + detail="invalid user key", ) def prisma_setup(database_url: Optional[str]): @@ -464,7 +405,7 @@ def run_ollama_serve(): """) def load_router_config(router: Optional[litellm.Router], config_file_path: str): - global master_key, user_config_file_path, otel_logging + global master_key, user_config_file_path, otel_logging, user_custom_auth config = {} try: if os.path.exists(config_file_path): @@ -499,7 +440,6 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str): ### LOAD FROM AZURE KEY VAULT ### use_azure_key_vault = general_settings.get("use_azure_key_vault", False) load_from_azure_key_vault(use_azure_key_vault=use_azure_key_vault) - ### CONNECT TO DATABASE ### database_url = general_settings.get("database_url", None) if database_url and database_url.startswith("os.environ/"): @@ -514,12 +454,14 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str): master_key = general_settings.get("master_key", None) if master_key and master_key.startswith("os.environ/"): master_key = litellm.get_secret(master_key) - #### OpenTelemetry Logging (OTEL) ######## otel_logging = general_settings.get("otel", False) if otel_logging == True: print("\nOpenTelemetry Logging Activated") - + ### CUSTOM API KEY AUTH ### + custom_auth = general_settings.get("custom_auth", None) + if custom_auth: + user_custom_auth = get_instance_fn(value=custom_auth, config_file_path=config_file_path) ## LITELLM MODULE SETTINGS (e.g. litellm.drop_params=True,..) 
litellm_settings = config.get('litellm_settings', None) if litellm_settings: @@ -549,23 +491,7 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str): password=cache_password ) elif key == "callbacks": - print(f"{blue_color_code}\nSetting custom callbacks on Proxy") - passed_module, instance_name = value.split(".") - - # Dynamically import the module - module = importlib.import_module(passed_module) - # Get the instance from the module - instance = getattr(module, instance_name) - - methods = [method for method in dir(instance) if callable(getattr(instance, method))] - # Print the methods - print("Methods in the custom callbacks instance:") - for method in methods: - print(method) - - litellm.callbacks = [instance] - print() - + litellm.callbacks = [get_instance_fn(value=value)] else: setattr(litellm, key, value) @@ -844,7 +770,7 @@ def model_list(): @router.post("/v1/completions", dependencies=[Depends(user_api_key_auth)]) @router.post("/completions", dependencies=[Depends(user_api_key_auth)]) @router.post("/engines/{model:path}/completions", dependencies=[Depends(user_api_key_auth)]) -async def completion(request: Request, model: Optional[str] = None, user_api_key_dict: dict = Depends(user_api_key_auth)): +async def completion(request: Request, model: Optional[str] = None, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth)): try: body = await request.body() body_str = body.decode() @@ -853,7 +779,7 @@ async def completion(request: Request, model: Optional[str] = None, user_api_key except: data = json.loads(body_str) - data["user"] = user_api_key_dict.get("user_id", None) + data["user"] = user_api_key_dict.user_id data["model"] = ( general_settings.get("completion_model", None) # server default or user_model # model name passed via cli args @@ -864,9 +790,9 @@ async def completion(request: Request, model: Optional[str] = None, user_api_key data["model"] = user_model data["call_type"] = "text_completion" if "metadata" in data: - data["metadata"]["user_api_key"] = user_api_key_dict["api_key"] + data["metadata"]["user_api_key"] = user_api_key_dict.api_key else: - data["metadata"] = {"user_api_key": user_api_key_dict["api_key"]} + data["metadata"] = {"user_api_key": user_api_key_dict.api_key} return litellm_completion( **data @@ -888,7 +814,7 @@ async def completion(request: Request, model: Optional[str] = None, user_api_key @router.post("/v1/chat/completions", dependencies=[Depends(user_api_key_auth)], tags=["chat/completions"]) @router.post("/chat/completions", dependencies=[Depends(user_api_key_auth)], tags=["chat/completions"]) @router.post("/openai/deployments/{model:path}/chat/completions", dependencies=[Depends(user_api_key_auth)], tags=["chat/completions"]) # azure compatible endpoint -async def chat_completion(request: Request, model: Optional[str] = None, user_api_key_dict: dict = Depends(user_api_key_auth), background_tasks: BackgroundTasks = BackgroundTasks()): +async def chat_completion(request: Request, model: Optional[str] = None, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), background_tasks: BackgroundTasks = BackgroundTasks()): global general_settings, user_debug try: data = {} @@ -905,13 +831,13 @@ async def chat_completion(request: Request, model: Optional[str] = None, user_ap # users can pass in 'user' param to /chat/completions. 
Don't override it if data.get("user", None) is None: # if users are using user_api_key_auth, set `user` in `data` - data["user"] = user_api_key_dict.get("user_id", None) + data["user"] = user_api_key_dict.user_id if "metadata" in data: - data["metadata"]["user_api_key"] = user_api_key_dict["api_key"] + data["metadata"]["user_api_key"] = user_api_key_dict.api_key data["metadata"]["headers"] = request.headers else: - data["metadata"] = {"user_api_key": user_api_key_dict["api_key"]} + data["metadata"] = {"user_api_key": user_api_key_dict.api_key} data["metadata"]["headers"] = request.headers global user_temperature, user_request_timeout, user_max_tokens, user_api_base # override with user settings, these are params passed via cli @@ -962,14 +888,14 @@ async def chat_completion(request: Request, model: Optional[str] = None, user_ap @router.post("/v1/embeddings", dependencies=[Depends(user_api_key_auth)], response_class=ORJSONResponse) @router.post("/embeddings", dependencies=[Depends(user_api_key_auth)], response_class=ORJSONResponse) -async def embeddings(request: Request, user_api_key_dict: dict = Depends(user_api_key_auth), background_tasks: BackgroundTasks = BackgroundTasks()): +async def embeddings(request: Request, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), background_tasks: BackgroundTasks = BackgroundTasks()): try: # Use orjson to parse JSON data, orjson speeds up requests significantly body = await request.body() data = orjson.loads(body) - data["user"] = user_api_key_dict.get("user_id", None) + data["user"] = user_api_key_dict.user_id data["model"] = ( general_settings.get("embedding_model", None) # server default or user_model # model name passed via cli args @@ -978,9 +904,9 @@ async def embeddings(request: Request, user_api_key_dict: dict = Depends(user_ap if user_model: data["model"] = user_model if "metadata" in data: - data["metadata"]["user_api_key"] = user_api_key_dict["api_key"] + data["metadata"]["user_api_key"] = user_api_key_dict.api_key else: - data["metadata"] = {"user_api_key": user_api_key_dict["api_key"]} + data["metadata"] = {"user_api_key": user_api_key_dict.api_key} ## ROUTE TO CORRECT ENDPOINT ## router_model_names = [m["model_name"] for m in llm_model_list] if llm_model_list is not None else [] diff --git a/litellm/proxy/types.py b/litellm/proxy/types.py new file mode 100644 index 000000000..fbee0732b --- /dev/null +++ b/litellm/proxy/types.py @@ -0,0 +1,70 @@ +from pydantic import BaseModel +from typing import Optional, List, Union, Dict +from datetime import datetime + +######### Request Class Definition ###### +class ProxyChatCompletionRequest(BaseModel): + model: str + messages: List[Dict[str, str]] + temperature: Optional[float] = None + top_p: Optional[float] = None + n: Optional[int] = None + stream: Optional[bool] = None + stop: Optional[List[str]] = None + max_tokens: Optional[int] = None + presence_penalty: Optional[float] = None + frequency_penalty: Optional[float] = None + logit_bias: Optional[Dict[str, float]] = None + user: Optional[str] = None + response_format: Optional[Dict[str, str]] = None + seed: Optional[int] = None + tools: Optional[List[str]] = None + tool_choice: Optional[str] = None + functions: Optional[List[str]] = None # soon to be deprecated + function_call: Optional[str] = None # soon to be deprecated + + # Optional LiteLLM params + caching: Optional[bool] = None + api_base: Optional[str] = None + api_version: Optional[str] = None + api_key: Optional[str] = None + num_retries: Optional[int] = None + 
context_window_fallback_dict: Optional[Dict[str, str]] = None + fallbacks: Optional[List[str]] = None + metadata: Optional[Dict[str, str]] = {} + deployment_id: Optional[str] = None + request_timeout: Optional[int] = None + + class Config: + extra='allow' # allow params not defined here, these fall in litellm.completion(**kwargs) + +class ModelParams(BaseModel): + model_name: str + litellm_params: dict + model_info: Optional[dict] + class Config: + protected_namespaces = () + +class GenerateKeyRequest(BaseModel): + duration: str = "1h" + models: list = [] + aliases: dict = {} + config: dict = {} + spend: int = 0 + user_id: Optional[str] = None + +class GenerateKeyResponse(BaseModel): + key: str + expires: datetime + user_id: str + +class _DeleteKeyObject(BaseModel): + key: str + +class DeleteKeyRequest(BaseModel): + keys: List[_DeleteKeyObject] + + +class UserAPIKeyAuth(BaseModel): # the expected response object for user api key auth + api_key: Optional[str] = None + user_id: Optional[str] = None \ No newline at end of file diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index 1ea7f47a0..5b2039543 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -1,6 +1,7 @@ from typing import Optional, List, Any -import os, subprocess, hashlib +import os, subprocess, hashlib, importlib +### DB CONNECTOR ### class PrismaClient: def __init__(self, database_url: str): print("LiteLLM: DATABASE_URL Set in config, trying to 'pip install prisma'") @@ -95,3 +96,60 @@ class PrismaClient: async def disconnect(self): await self.db.disconnect() +# ### CUSTOM FILE ### +# def get_instance_fn(value: str, config_file_path: Optional[str]=None): +# try: +# # Split the path by dots to separate module from instance +# parts = value.split(".") +# # The module path is all but the last part, and the instance is the last part +# module_path = ".".join(parts[:-1]) +# instance_name = parts[-1] + +# if config_file_path is not None: +# directory = os.path.dirname(config_file_path) +# module_path = os.path.join(directory, module_path) +# # Dynamically import the module +# module = importlib.import_module(module_path) + +# # Get the instance from the module +# instance = getattr(module, instance_name) + +# return instance +# except ImportError as e: +# print(e) +# raise ImportError(f"Could not import file at {value}") + +def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any: + try: + print(f"value: {value}") + # Split the path by dots to separate module from instance + parts = value.split(".") + + # The module path is all but the last part, and the instance_name is the last part + module_name = ".".join(parts[:-1]) + instance_name = parts[-1] + + # If config_file_path is provided, use it to determine the module spec and load the module + if config_file_path is not None: + directory = os.path.dirname(config_file_path) + module_file_path = os.path.join(directory, *module_name.split('.')) + module_file_path += '.py' + + spec = importlib.util.spec_from_file_location(module_name, module_file_path) + if spec is None: + raise ImportError(f"Could not find a module specification for {module_file_path}") + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + else: + # Dynamically import the module + module = importlib.import_module(module_name) + + # Get the instance from the module + instance = getattr(module, instance_name) + + return instance + except ImportError as e: + # Print the error message for easier debugging + print(e) + # Re-raise the exception 
with a user-friendly message + raise ImportError(f"Could not import {instance_name} from {module_name}") from e \ No newline at end of file diff --git a/litellm/tests/test_configs/custom_auth.py b/litellm/tests/test_configs/custom_auth.py new file mode 100644 index 000000000..f9de3a97a --- /dev/null +++ b/litellm/tests/test_configs/custom_auth.py @@ -0,0 +1,14 @@ +from litellm.proxy.types import UserAPIKeyAuth +from fastapi import Request +from dotenv import load_dotenv +import os + +load_dotenv() +async def user_api_key_auth(request: Request, api_key: str) -> UserAPIKeyAuth: + try: + print(f"api_key: {api_key}") + if api_key == f"{os.getenv('PROXY_MASTER_KEY')}-1234": + return UserAPIKeyAuth(api_key=api_key) + raise Exception + except: + raise Exception \ No newline at end of file diff --git a/litellm/tests/test_config.yaml b/litellm/tests/test_configs/test_config.yaml similarity index 100% rename from litellm/tests/test_config.yaml rename to litellm/tests/test_configs/test_config.yaml diff --git a/litellm/tests/test_configs/test_config_custom_auth.yaml b/litellm/tests/test_configs/test_config_custom_auth.yaml new file mode 100644 index 000000000..33088bd1c --- /dev/null +++ b/litellm/tests/test_configs/test_config_custom_auth.yaml @@ -0,0 +1,11 @@ +model_list: + - model_name: "openai-model" + litellm_params: + model: "gpt-3.5-turbo" + +litellm_settings: + drop_params: True + set_verbose: True + +general_settings: + custom_auth: custom_auth.user_api_key_auth \ No newline at end of file diff --git a/litellm/tests/test_proxy_custom_auth.py b/litellm/tests/test_proxy_custom_auth.py new file mode 100644 index 000000000..fa1b5f6dd --- /dev/null +++ b/litellm/tests/test_proxy_custom_auth.py @@ -0,0 +1,63 @@ +import sys, os +import traceback +from dotenv import load_dotenv + +load_dotenv() +import os, io + +# this file is to test litellm/proxy + +sys.path.insert( + 0, os.path.abspath("../..") +) # Adds the parent directory to the system path +import pytest +import litellm +from litellm import embedding, completion, completion_cost, Timeout +from litellm import RateLimitError + +# test /chat/completion request to the proxy +from fastapi.testclient import TestClient +from fastapi import FastAPI +from litellm.proxy.proxy_server import router, save_worker_config, startup_event # Replace with the actual module where your FastAPI router is defined +filepath = os.path.dirname(os.path.abspath(__file__)) +config_fp = f"{filepath}/test_configs/test_config_custom_auth.yaml" +save_worker_config(config=config_fp, model=None, alias=None, api_base=None, api_version=None, debug=False, temperature=None, max_tokens=None, request_timeout=600, max_budget=None, telemetry=False, drop_params=True, add_function_to_prompt=False, headers=None, save=False, use_queue=False) +app = FastAPI() +app.include_router(router) # Include your router in the test app +@app.on_event("startup") +async def wrapper_startup_event(): + await startup_event() + +# Here you create a fixture that will be used by your tests +# Make sure the fixture returns TestClient(app) +@pytest.fixture(autouse=True) +def client(): + with TestClient(app) as client: + yield client + +def test_custom_auth(client): + try: + # Your test data + test_data = { + "model": "openai-model", + "messages": [ + { + "role": "user", + "content": "hi" + }, + ], + "max_tokens": 10, + } + # Your bearer token + token = os.getenv("PROXY_MASTER_KEY") + + headers = { + "Authorization": f"Bearer {token}" + } + response = client.post("/chat/completions", json=test_data, 
headers=headers) + print(f"response: {response.text}") + assert response.status_code == 401 + result = response.json() + print(f"Received response: {result}") + except Exception as e: + pytest.fail(f"LiteLLM Proxy test failed. Exception - {e}") \ No newline at end of file diff --git a/litellm/tests/test_proxy_server_keys.py b/litellm/tests/test_proxy_server_keys.py index a2dd396c0..fb0ec2f3c 100644 --- a/litellm/tests/test_proxy_server_keys.py +++ b/litellm/tests/test_proxy_server_keys.py @@ -25,7 +25,7 @@ from fastapi.testclient import TestClient from fastapi import FastAPI from litellm.proxy.proxy_server import router, save_worker_config, startup_event # Replace with the actual module where your FastAPI router is defined filepath = os.path.dirname(os.path.abspath(__file__)) -config_fp = f"{filepath}/test_config.yaml" +config_fp = f"{filepath}/test_configs/test_config.yaml" save_worker_config(config=config_fp, model=None, alias=None, api_base=None, api_version=None, debug=False, temperature=None, max_tokens=None, request_timeout=600, max_budget=None, telemetry=False, drop_params=True, add_function_to_prompt=False, headers=None, save=False, use_queue=False) app = FastAPI() app.include_router(router) # Include your router in the test app
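
For reference, the custom auth hook introduced in this patch is enabled via the general_settings.custom_auth key in the proxy config and must resolve to an async callable that receives the incoming Request and the raw API key and returns a UserAPIKeyAuth object; any exception it raises is surfaced by the proxy as a 401 "invalid user key" error. A minimal sketch, assuming a hypothetical module my_custom_auth.py placed in the same directory as the config file:

    # my_custom_auth.py (hypothetical example module)
    #
    # the proxy config would reference it as:
    #   general_settings:
    #     custom_auth: my_custom_auth.user_api_key_auth
    from fastapi import Request
    from litellm.proxy.types import UserAPIKeyAuth

    async def user_api_key_auth(request: Request, api_key: str) -> UserAPIKeyAuth:
        # Replace this check with a lookup against your own key store or identity provider.
        if api_key == "sk-my-expected-key":  # hypothetical static key, for illustration only
            return UserAPIKeyAuth(api_key=api_key, user_id="example-user")
        # Raising here causes the proxy to reject the request with HTTP 401.
        raise Exception("invalid custom auth key")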