From eefa66e8f0212d8fb89e87b5c8be7b73f614b4ba Mon Sep 17 00:00:00 2001
From: Krrish Dholakia 
Date: Sat, 18 Nov 2023 17:34:02 -0800
Subject: [PATCH] docs(simple_proxy.md): adding token based auth to docs

---
 docs/my-website/docs/simple_proxy.md | 54 +++++++++++++++++++++++++++-
 litellm/proxy/proxy_server.py | 53 +++++++++++++++++----------
 litellm/proxy/schema.prisma | 1 +
 3 files changed, 88 insertions(+), 20 deletions(-)

diff --git a/docs/my-website/docs/simple_proxy.md b/docs/my-website/docs/simple_proxy.md
index aea529c1e..80dcdf68d 100644
--- a/docs/my-website/docs/simple_proxy.md
+++ b/docs/my-website/docs/simple_proxy.md
@@ -588,7 +588,7 @@ general_settings:
   master_key: sk-1234 # [OPTIONAL] Only use this if you to require all calls to contain this key (Authorization: Bearer sk-1234)
 ```

-### Multiple Models - Quick Start
+### Multiple Models

 Here's how you can use multiple llms with one proxy `config.yaml`.

@@ -638,8 +638,60 @@ curl --location 'http://0.0.0.0:8000/chat/completions' \
 '
 ```

+### Managing Auth - Virtual Keys
+
+Grant others temporary access to your proxy, with keys that expire after a set duration.
+
+Requirements:
+
+- Need a postgres database (e.g. [Supabase](https://supabase.com/))
+
+You can then generate temporary keys by hitting the `/key/generate` endpoint.
+
+[**See code**](https://github.com/BerriAI/litellm/blob/7a669a36d2689c7f7890bc9c93e04ff3c2641299/litellm/proxy/proxy_server.py#L672)
+
+**Step 1: Save postgres db url**
+
+```yaml
+model_list:
+  - model_name: gpt-4
+    litellm_params:
+      model: ollama/llama2
+  - model_name: gpt-3.5-turbo
+    litellm_params:
+      model: ollama/llama2
+
+general_settings:
+  master_key: sk-1234 # [OPTIONAL] if set all calls to proxy will require either this key or a valid generated token
+  database_url: "postgresql://<user>:<password>@<host>:<port>/<dbname>"
+```
+
+**Step 2: Start litellm**
+
+```bash
+litellm --config /path/to/config.yaml
+```
+
+**Step 3: Generate temporary keys**
+
+```curl
+curl 'http://0.0.0.0:8000/key/generate' \
+--header 'Authorization: Bearer sk-1234' \
+--data '{"models": ["gpt-3.5-turbo", "gpt-4", "claude-2"], "duration": "20m"}'
+```
+
+- `models`: *list or null (optional)* - Specify the models a token has access to. If null, then the token has access to all models on the server.
+
+- `duration`: *str or null (optional)* - Specify the length of time the token is valid for. If null, default is set to 1 hour. You can set duration as seconds ("30s"), minutes ("30m"), hours ("30h"), days ("30d").
+
+Expected response:
+
+```python
+{
+    "key": "sk-kdEXbIqZRwEeEiHwdg7sFA", # Bearer token
+    "expires": "2023-11-19T01:38:25.838000+00:00" # datetime object
+}
+```
 ### Save Model-specific params (API Base, API Keys, Temperature, Headers etc.) 
 You can use the config to save model-specific information like api_base, api_key, temperature, max_tokens, etc. 
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index f902ffb63..dbefc5d09 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -159,7 +159,6 @@ async def user_api_key_auth(request: Request): api_key = await oauth2_scheme(request=request) if api_key == master_key: return - print(f"prisma_client: {prisma_client}") if prisma_client: valid_token = await prisma_client.litellm_verificationtoken.find_first( where={ @@ -167,11 +166,17 @@ async def user_api_key_auth(request: Request): "expires": {"gte": datetime.utcnow()} # Check if the token is not expired } ) - print(f"valid_token: {valid_token}") if valid_token: - return + if len(valid_token.models) == 0: # assume an empty model list means all models are allowed to be called + return + else: + data = await request.json() + model = data.get("model", None) + if model and model not in valid_token.models: + raise Exception(f"Token not allowed to access model") + return else: - raise Exception + raise Exception(f"Invalid token") except Exception as e: print(f"An exception occurred - {e}") raise HTTPException( @@ -231,14 +236,17 @@ def save_params_to_config(data: dict): with open(user_config_path, "wb") as f: tomli_w.dump(config, f) -def prisma_setup(): +def prisma_setup(database_url: Optional[str]): global prisma_client - subprocess.run(['pip', 'install', 'prisma']) - subprocess.run(['python3', '-m', 'pip', 'install', 'prisma']) - subprocess.run(['prisma', 'db', 'push']) - # Now you can import the Prisma Client - from prisma import Client - prisma_client = Client() + if database_url: + import os + os.environ["DATABASE_URL"] = database_url + subprocess.run(['pip', 'install', 'prisma']) + subprocess.run(['python3', '-m', 'pip', 'install', 'prisma']) + subprocess.run(['prisma', 'db', 'push']) + # Now you can import the Prisma Client + from prisma import Client + prisma_client = Client() def load_router_config(router: Optional[litellm.Router], config_file_path: str): @@ 
-261,16 +269,15 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str): if environment_variables: for key, value in environment_variables.items(): os.environ[key] = value - ### CONNECT TO DATABASE ### - if key == "DATABASE_URL": - prisma_setup() - ## GENERAL SERVER SETTINGS (e.g. master key,..) general_settings = config.get("general_settings", None) if general_settings: ### MASTER KEY ### master_key = general_settings.get("master_key", None) + ### CONNECT TO DATABASE ### + database_url = general_settings.get("database_url", None) + prisma_setup(database_url=database_url) ## LITELLM MODULE SETTINGS (e.g. litellm.drop_params=True,..) @@ -290,7 +297,7 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str): return router, model_list, server_settings -async def generate_key_helper_fn(duration_str: str): +async def generate_key_helper_fn(duration_str: str, models: Optional[list]): token = f"sk-{secrets.token_urlsafe(16)}" def _duration_in_seconds(duration: str): match = re.match(r"(\d+)([smhd]?)", duration) @@ -318,7 +325,8 @@ async def generate_key_helper_fn(duration_str: str): # Create a new verification token (you may want to enhance this logic based on your needs) verification_token_data = { "token": token, - "expires": expires + "expires": expires, + "models": models } new_verification_token = await db.litellm_verificationtoken.create( # type: ignore {**verification_token_data} @@ -673,8 +681,15 @@ async def generate_key_fn(request: Request): data = await request.json() duration_str = data.get("duration", "1h") # Default to 1 hour if duration is not provided - response = await generate_key_helper_fn(duration_str=duration_str) - return {"token": response["token"], "expires": response["expires"]} + models = data.get("models", []) # Default to an empty list (meaning allow token to call all models) + if isinstance(models, list): + response = await generate_key_helper_fn(duration_str=duration_str, models=models) + 
return {"key": response["token"], "expires": response["expires"]} + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail={"error": "models param must be a list"}, + ) @router.get("/ollama_logs", dependencies=[Depends(user_api_key_auth)]) diff --git a/litellm/proxy/schema.prisma b/litellm/proxy/schema.prisma index d4b603db1..4d2837a8c 100644 --- a/litellm/proxy/schema.prisma +++ b/litellm/proxy/schema.prisma @@ -11,4 +11,5 @@ generator client { model LiteLLM_VerificationToken { token String @unique expires DateTime + models String[] } \ No newline at end of file