fix(proxy_server.py): hash keys

This commit is contained in:
Krrish Dholakia 2023-12-02 19:24:58 -08:00
parent 2e24c8275a
commit 6015bff80b
5 changed files with 135 additions and 132 deletions

View file

@ -4,6 +4,7 @@ import shutil, random, traceback, requests
from datetime import datetime, timedelta
from typing import Optional, List
import secrets, subprocess
import hashlib, uuid
import warnings
messages: list = []
sys.path.insert(
@ -89,6 +90,9 @@ def generate_feedback_box():
print()
import litellm
from litellm.proxy.utils import (
PrismaClient
)
from litellm.caching import DualCache
litellm.suppress_debug_info = True
from fastapi import FastAPI, Request, HTTPException, status, Depends, BackgroundTasks
@ -204,10 +208,19 @@ class GenerateKeyRequest(BaseModel):
aliases: dict = {}
config: dict = {}
spend: int = 0
user_id: Optional[str]
class GenerateKeyResponse(BaseModel):
key: str
expires: str
user_id: str
class _DeleteKeyObject(BaseModel):
key: str
class DeleteKeyRequest(BaseModel):
keys: List[_DeleteKeyObject]
user_api_base = None
user_model = None
@ -229,7 +242,7 @@ log_file = "api_log.json"
worker_config = None
master_key = None
otel_logging = False
prisma_client = None
prisma_client: Optional[PrismaClient] = None
user_api_key_cache = DualCache()
### REDIS QUEUE ###
async_result = None
@ -277,13 +290,7 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
if valid_token is None and "Bearer " in api_key:
## check db
cleaned_api_key = api_key[len("Bearer "):]
valid_token = await prisma_client.litellm_verificationtoken.find_first(
where={
"token": cleaned_api_key,
"expires": {"gte": datetime.utcnow()} # Check if the token is not expired
}
)
## save to cache for 60s
valid_token = await prisma_client.get_data(token=cleaned_api_key, expires=datetime.utcnow())
user_api_key_cache.set_cache(key=api_key, value=valid_token, ttl=60)
elif valid_token is not None:
print(f"API Key Cache Hit!")
@ -321,14 +328,7 @@ def prisma_setup(database_url: Optional[str]):
global prisma_client
if database_url:
try:
import os
print("LiteLLM: DATABASE_URL Set in config, trying to 'pip install prisma'")
os.environ["DATABASE_URL"] = database_url
subprocess.run(['prisma', 'generate'])
subprocess.run(['prisma', 'db', 'push', '--accept-data-loss']) # this looks like a weird edge case when prisma just wont start on render. we need to have the --accept-data-loss
# Now you can import the Prisma Client
from prisma import Client
prisma_client = Client()
prisma_client = PrismaClient(database_url=database_url)
except Exception as e:
print("Error when initializing prisma, Ensure you run pip install prisma", e)
@ -388,6 +388,7 @@ def track_cost_callback(
start_time = None,
end_time = None, # start/end time for completion
):
global prisma_client
try:
# check if it has collected an entire stream response
if "complete_streaming_response" in kwargs:
@ -415,46 +416,41 @@ def track_cost_callback(
# Create new event loop for async function execution in the new thread
new_loop = asyncio.new_event_loop()
asyncio.set_event_loop(new_loop)
try:
# Run the async function using the newly created event loop
new_loop.run_until_complete(update_prisma_database(user_api_key, response_cost))
existing_spend_obj = new_loop.run_until_complete(prisma_client.get_data(token=user_api_key))
if existing_spend_obj is None:
existing_spend = 0
else:
existing_spend = existing_spend_obj.spend
# Calculate the new cost by adding the existing cost and response_cost
new_spend = existing_spend + response_cost
print(f"new cost: {new_spend}")
# Update the cost column for the given token
new_loop.run_until_complete(prisma_client.update_data(token=user_api_key, data={"spend": new_spend}))
print(f"Prisma database updated for token {user_api_key}. New cost: {new_spend}")
except Exception as e:
print(f"error in tracking cost callback - {str(e)}")
finally:
# Close the event loop after the task is done
new_loop.close()
# Ensure that there's no event loop set in this thread, which could interfere with future asyncio calls
asyncio.set_event_loop(None)
print(f"error in creating async loop - {str(e)}")
except Exception as e:
print(f"error in tracking cost callback - {str(e)}")
async def update_prisma_database(token, response_cost):
global prisma_client
try:
print(f"Enters prisma db call, token: {token}")
# Fetch the existing cost for the given token
existing_spend = await prisma_client.litellm_verificationtoken.find_unique(
where={
"token": token
}
)
print(f"existing spend: {existing_spend}")
if existing_spend is None:
existing_spend_obj = await prisma_client.get_data(token=token)
print(f"existing spend: {existing_spend_obj}")
if existing_spend_obj is None:
existing_spend = 0
else:
existing_spend = existing_spend_obj.spend
# Calculate the new cost by adding the existing cost and response_cost
new_spend = existing_spend.spend + response_cost
new_spend = existing_spend + response_cost
print(f"new cost: {new_spend}")
# Update the cost column for the given token
await prisma_client.litellm_verificationtoken.update(
where={
"token": token
},
data={
"spend": new_spend
}
)
await prisma_client.update_data(token=token, data={"spend": new_spend})
print(f"Prisma database updated for token {token}. New cost: {new_spend}")
except Exception as e:
@ -569,9 +565,11 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
run_ollama_serve()
return router, model_list, general_settings
async def generate_key_helper_fn(duration_str: Optional[str], models: list, aliases: dict, config: dict, spend: float, token: Optional[str]):
async def generate_key_helper_fn(duration_str: Optional[str], models: list, aliases: dict, config: dict, spend: float, token: Optional[str], user_id: Optional[str]=None):
if token is None:
token = f"sk-{secrets.token_urlsafe(16)}"
def _duration_in_seconds(duration: str):
match = re.match(r"(\d+)([smhd]?)", duration)
if not match:
@ -599,8 +597,8 @@ async def generate_key_helper_fn(duration_str: Optional[str], models: list, alia
aliases_json = json.dumps(aliases)
config_json = json.dumps(config)
user_id = user_id or str(uuid.uuid4())
try:
db = prisma_client
# Create a new verification token (you may want to enhance this logic based on your needs)
verification_token_data = {
"token": token,
@ -608,30 +606,21 @@ async def generate_key_helper_fn(duration_str: Optional[str], models: list, alia
"models": models,
"aliases": aliases_json,
"config": config_json,
"spend": spend
"spend": spend,
"user_id": user_id
}
new_verification_token = await db.litellm_verificationtoken.upsert( # type: ignore
where={
'token': token,
},
data={
"create": {**verification_token_data}, #type: ignore
"update": {} # don't do anything if it already exists
}
)
new_verification_token = await prisma_client.insert_data(data=verification_token_data)
except Exception as e:
traceback.print_exc()
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
return {"token": new_verification_token.token, "expires": new_verification_token.expires}
return {"token": new_verification_token.token, "expires": new_verification_token.expires, "user_id": user_id}
async def delete_verification_token(tokens: List[str]):
async def delete_verification_token(tokens: List):
global prisma_client
try:
if prisma_client:
# Assuming 'db' is your Prisma Client instance
deleted_tokens = await prisma_client.litellm_verificationtoken.delete_many(
where={"token": {"in": tokens}}
)
deleted_tokens = await prisma_client.delete_data(tokens=tokens)
else:
raise Exception
except Exception as e:
@ -980,7 +969,9 @@ async def embeddings(request: Request, user_api_key_dict: dict = Depends(user_ap
@router.post("/key/generate", tags=["key management"], dependencies=[Depends(user_api_key_auth)], response_model=GenerateKeyResponse)
async def generate_key_fn(request: Request, data: GenerateKeyRequest):
"""
Generate an API key based on the provided data.
Generate an API key based on the provided data.
Docs: https://docs.litellm.ai/docs/proxy/virtual_keys
Parameters:
- duration: Optional[str] - Specify the length of time the token is valid for. You can set duration as seconds ("30s"), minutes ("30m"), hours ("30h"), days ("30d"). **(Default is set to 1 hour.)**
@ -1000,9 +991,10 @@ async def generate_key_fn(request: Request, data: GenerateKeyRequest):
aliases = data.aliases # Default to an empty dict (no alias mappings, on top of anything in the config.yaml model_list)
config = data.config
spend = data.spend
user_id = data.user_id
if isinstance(models, list):
response = await generate_key_helper_fn(duration_str=duration_str, models=models, aliases=aliases, config=config, spend=spend)
return GenerateKeyResponse(key=response["token"], expires=response["expires"])
response = await generate_key_helper_fn(duration_str=duration_str, models=models, aliases=aliases, config=config, spend=spend, user_id=user_id)
return GenerateKeyResponse(key=response["token"], expires=response["expires"], user_id=response["user_id"])
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
@ -1010,7 +1002,7 @@ async def generate_key_fn(request: Request, data: GenerateKeyRequest):
)
@router.post("/key/delete", tags=["key management"], dependencies=[Depends(user_api_key_auth)])
async def delete_key_fn(request: Request):
async def delete_key_fn(request: Request, data: DeleteKeyRequest):
try:
data = await request.json()
@ -1037,12 +1029,7 @@ async def info_key_fn(key: str = fastapi.Query(..., description="Key in the requ
try:
if prisma_client is None:
raise Exception(f"Database not connected. Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys")
key_info = await prisma_client.litellm_verificationtoken.find_unique(
where={
"token": key
}
)
key_info = await prisma_client.get_data(token=key)
return {"key": key, "info": key_info}
except Exception as e:
raise HTTPException(