docs(simple_proxy.md): adding token based auth to docs

This commit is contained in:
Krrish Dholakia 2023-11-18 17:34:02 -08:00
parent 4b110b3fa4
commit eefa66e8f0
3 changed files with 88 additions and 20 deletions

View file

@ -159,7 +159,6 @@ async def user_api_key_auth(request: Request):
api_key = await oauth2_scheme(request=request)
if api_key == master_key:
return
print(f"prisma_client: {prisma_client}")
if prisma_client:
valid_token = await prisma_client.litellm_verificationtoken.find_first(
where={
@ -167,11 +166,17 @@ async def user_api_key_auth(request: Request):
"expires": {"gte": datetime.utcnow()} # Check if the token is not expired
}
)
print(f"valid_token: {valid_token}")
if valid_token:
return
if len(valid_token.models) == 0: # assume an empty model list means all models are allowed to be called
return
else:
data = await request.json()
model = data.get("model", None)
if model and model not in valid_token.models:
raise Exception(f"Token not allowed to access model")
return
else:
raise Exception
raise Exception(f"Invalid token")
except Exception as e:
print(f"An exception occurred - {e}")
raise HTTPException(
@ -231,14 +236,17 @@ def save_params_to_config(data: dict):
with open(user_config_path, "wb") as f:
tomli_w.dump(config, f)
def prisma_setup():
def prisma_setup(database_url: Optional[str]):
global prisma_client
subprocess.run(['pip', 'install', 'prisma'])
subprocess.run(['python3', '-m', 'pip', 'install', 'prisma'])
subprocess.run(['prisma', 'db', 'push'])
# Now you can import the Prisma Client
from prisma import Client
prisma_client = Client()
if database_url:
import os
os.environ["DATABASE_URL"] = database_url
subprocess.run(['pip', 'install', 'prisma'])
subprocess.run(['python3', '-m', 'pip', 'install', 'prisma'])
subprocess.run(['prisma', 'db', 'push'])
# Now you can import the Prisma Client
from prisma import Client
prisma_client = Client()
def load_router_config(router: Optional[litellm.Router], config_file_path: str):
@ -261,16 +269,15 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
if environment_variables:
for key, value in environment_variables.items():
os.environ[key] = value
### CONNECT TO DATABASE ###
if key == "DATABASE_URL":
prisma_setup()
## GENERAL SERVER SETTINGS (e.g. master key,..)
general_settings = config.get("general_settings", None)
if general_settings:
### MASTER KEY ###
master_key = general_settings.get("master_key", None)
### CONNECT TO DATABASE ###
database_url = general_settings.get("database_url", None)
prisma_setup(database_url=database_url)
## LITELLM MODULE SETTINGS (e.g. litellm.drop_params=True,..)
@ -290,7 +297,7 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
return router, model_list, server_settings
async def generate_key_helper_fn(duration_str: str):
async def generate_key_helper_fn(duration_str: str, models: Optional[list]):
token = f"sk-{secrets.token_urlsafe(16)}"
def _duration_in_seconds(duration: str):
match = re.match(r"(\d+)([smhd]?)", duration)
@ -318,7 +325,8 @@ async def generate_key_helper_fn(duration_str: str):
# Create a new verification token (you may want to enhance this logic based on your needs)
verification_token_data = {
"token": token,
"expires": expires
"expires": expires,
"models": models
}
new_verification_token = await db.litellm_verificationtoken.create( # type: ignore
{**verification_token_data}
@ -673,8 +681,15 @@ async def generate_key_fn(request: Request):
data = await request.json()
duration_str = data.get("duration", "1h") # Default to 1 hour if duration is not provided
response = await generate_key_helper_fn(duration_str=duration_str)
return {"token": response["token"], "expires": response["expires"]}
models = data.get("models", []) # Default to an empty list (meaning allow token to call all models)
if isinstance(models, list):
response = await generate_key_helper_fn(duration_str=duration_str, models=models)
return {"key": response["token"], "expires": response["expires"]}
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={"error": "models param must be a list"},
)
@router.get("/ollama_logs", dependencies=[Depends(user_api_key_auth)])