fix(proxy_server.py): hash keys

This commit is contained in:
Krrish Dholakia 2023-12-02 19:24:58 -08:00
parent 2e24c8275a
commit 6015bff80b
5 changed files with 135 additions and 132 deletions

View file

@ -87,7 +87,7 @@ const sidebars = {
}, },
{ {
type: "category", type: "category",
label: "💥 OpenAI Proxy", label: "💥 OpenAI Proxy Server",
link: { link: {
type: 'generated-index', type: 'generated-index',
title: '💥 OpenAI Proxy Server', title: '💥 OpenAI Proxy Server',

View file

@ -4,6 +4,7 @@ import shutil, random, traceback, requests
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import Optional, List from typing import Optional, List
import secrets, subprocess import secrets, subprocess
import hashlib, uuid
import warnings import warnings
messages: list = [] messages: list = []
sys.path.insert( sys.path.insert(
@ -89,6 +90,9 @@ def generate_feedback_box():
print() print()
import litellm import litellm
from litellm.proxy.utils import (
PrismaClient
)
from litellm.caching import DualCache from litellm.caching import DualCache
litellm.suppress_debug_info = True litellm.suppress_debug_info = True
from fastapi import FastAPI, Request, HTTPException, status, Depends, BackgroundTasks from fastapi import FastAPI, Request, HTTPException, status, Depends, BackgroundTasks
@ -204,10 +208,19 @@ class GenerateKeyRequest(BaseModel):
aliases: dict = {} aliases: dict = {}
config: dict = {} config: dict = {}
spend: int = 0 spend: int = 0
user_id: Optional[str]
class GenerateKeyResponse(BaseModel): class GenerateKeyResponse(BaseModel):
key: str key: str
expires: str expires: str
user_id: str
class _DeleteKeyObject(BaseModel):
key: str
class DeleteKeyRequest(BaseModel):
keys: List[_DeleteKeyObject]
user_api_base = None user_api_base = None
user_model = None user_model = None
@ -229,7 +242,7 @@ log_file = "api_log.json"
worker_config = None worker_config = None
master_key = None master_key = None
otel_logging = False otel_logging = False
prisma_client = None prisma_client: Optional[PrismaClient] = None
user_api_key_cache = DualCache() user_api_key_cache = DualCache()
### REDIS QUEUE ### ### REDIS QUEUE ###
async_result = None async_result = None
@ -277,13 +290,7 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
if valid_token is None and "Bearer " in api_key: if valid_token is None and "Bearer " in api_key:
## check db ## check db
cleaned_api_key = api_key[len("Bearer "):] cleaned_api_key = api_key[len("Bearer "):]
valid_token = await prisma_client.litellm_verificationtoken.find_first( valid_token = await prisma_client.get_data(token=cleaned_api_key, expires=datetime.utcnow())
where={
"token": cleaned_api_key,
"expires": {"gte": datetime.utcnow()} # Check if the token is not expired
}
)
## save to cache for 60s
user_api_key_cache.set_cache(key=api_key, value=valid_token, ttl=60) user_api_key_cache.set_cache(key=api_key, value=valid_token, ttl=60)
elif valid_token is not None: elif valid_token is not None:
print(f"API Key Cache Hit!") print(f"API Key Cache Hit!")
@ -321,14 +328,7 @@ def prisma_setup(database_url: Optional[str]):
global prisma_client global prisma_client
if database_url: if database_url:
try: try:
import os prisma_client = PrismaClient(database_url=database_url)
print("LiteLLM: DATABASE_URL Set in config, trying to 'pip install prisma'")
os.environ["DATABASE_URL"] = database_url
subprocess.run(['prisma', 'generate'])
subprocess.run(['prisma', 'db', 'push', '--accept-data-loss']) # this looks like a weird edge case when prisma just wont start on render. we need to have the --accept-data-loss
# Now you can import the Prisma Client
from prisma import Client
prisma_client = Client()
except Exception as e: except Exception as e:
print("Error when initializing prisma, Ensure you run pip install prisma", e) print("Error when initializing prisma, Ensure you run pip install prisma", e)
@ -388,6 +388,7 @@ def track_cost_callback(
start_time = None, start_time = None,
end_time = None, # start/end time for completion end_time = None, # start/end time for completion
): ):
global prisma_client
try: try:
# check if it has collected an entire stream response # check if it has collected an entire stream response
if "complete_streaming_response" in kwargs: if "complete_streaming_response" in kwargs:
@ -415,46 +416,41 @@ def track_cost_callback(
# Create new event loop for async function execution in the new thread # Create new event loop for async function execution in the new thread
new_loop = asyncio.new_event_loop() new_loop = asyncio.new_event_loop()
asyncio.set_event_loop(new_loop) asyncio.set_event_loop(new_loop)
try: try:
# Run the async function using the newly created event loop # Run the async function using the newly created event loop
new_loop.run_until_complete(update_prisma_database(user_api_key, response_cost)) existing_spend_obj = new_loop.run_until_complete(prisma_client.get_data(token=user_api_key))
if existing_spend_obj is None:
existing_spend = 0
else:
existing_spend = existing_spend_obj.spend
# Calculate the new cost by adding the existing cost and response_cost
new_spend = existing_spend + response_cost
print(f"new cost: {new_spend}")
# Update the cost column for the given token
new_loop.run_until_complete(prisma_client.update_data(token=user_api_key, data={"spend": new_spend}))
print(f"Prisma database updated for token {user_api_key}. New cost: {new_spend}")
except Exception as e: except Exception as e:
print(f"error in tracking cost callback - {str(e)}") print(f"error in creating async loop - {str(e)}")
finally:
# Close the event loop after the task is done
new_loop.close()
# Ensure that there's no event loop set in this thread, which could interfere with future asyncio calls
asyncio.set_event_loop(None)
except Exception as e: except Exception as e:
print(f"error in tracking cost callback - {str(e)}") print(f"error in tracking cost callback - {str(e)}")
async def update_prisma_database(token, response_cost): async def update_prisma_database(token, response_cost):
global prisma_client
try: try:
print(f"Enters prisma db call, token: {token}") print(f"Enters prisma db call, token: {token}")
# Fetch the existing cost for the given token # Fetch the existing cost for the given token
existing_spend = await prisma_client.litellm_verificationtoken.find_unique( existing_spend_obj = await prisma_client.get_data(token=token)
where={ print(f"existing spend: {existing_spend_obj}")
"token": token if existing_spend_obj is None:
}
)
print(f"existing spend: {existing_spend}")
if existing_spend is None:
existing_spend = 0 existing_spend = 0
else:
existing_spend = existing_spend_obj.spend
# Calculate the new cost by adding the existing cost and response_cost # Calculate the new cost by adding the existing cost and response_cost
new_spend = existing_spend.spend + response_cost new_spend = existing_spend + response_cost
print(f"new cost: {new_spend}") print(f"new cost: {new_spend}")
# Update the cost column for the given token # Update the cost column for the given token
await prisma_client.litellm_verificationtoken.update( await prisma_client.update_data(token=token, data={"spend": new_spend})
where={
"token": token
},
data={
"spend": new_spend
}
)
print(f"Prisma database updated for token {token}. New cost: {new_spend}") print(f"Prisma database updated for token {token}. New cost: {new_spend}")
except Exception as e: except Exception as e:
@ -569,9 +565,11 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
run_ollama_serve() run_ollama_serve()
return router, model_list, general_settings return router, model_list, general_settings
async def generate_key_helper_fn(duration_str: Optional[str], models: list, aliases: dict, config: dict, spend: float, token: Optional[str]): async def generate_key_helper_fn(duration_str: Optional[str], models: list, aliases: dict, config: dict, spend: float, token: Optional[str], user_id: Optional[str]=None):
if token is None: if token is None:
token = f"sk-{secrets.token_urlsafe(16)}" token = f"sk-{secrets.token_urlsafe(16)}"
def _duration_in_seconds(duration: str): def _duration_in_seconds(duration: str):
match = re.match(r"(\d+)([smhd]?)", duration) match = re.match(r"(\d+)([smhd]?)", duration)
if not match: if not match:
@ -599,8 +597,8 @@ async def generate_key_helper_fn(duration_str: Optional[str], models: list, alia
aliases_json = json.dumps(aliases) aliases_json = json.dumps(aliases)
config_json = json.dumps(config) config_json = json.dumps(config)
user_id = user_id or str(uuid.uuid4())
try: try:
db = prisma_client
# Create a new verification token (you may want to enhance this logic based on your needs) # Create a new verification token (you may want to enhance this logic based on your needs)
verification_token_data = { verification_token_data = {
"token": token, "token": token,
@ -608,30 +606,21 @@ async def generate_key_helper_fn(duration_str: Optional[str], models: list, alia
"models": models, "models": models,
"aliases": aliases_json, "aliases": aliases_json,
"config": config_json, "config": config_json,
"spend": spend "spend": spend,
"user_id": user_id
} }
new_verification_token = await db.litellm_verificationtoken.upsert( # type: ignore new_verification_token = await prisma_client.insert_data(data=verification_token_data)
where={
'token': token,
},
data={
"create": {**verification_token_data}, #type: ignore
"update": {} # don't do anything if it already exists
}
)
except Exception as e: except Exception as e:
traceback.print_exc() traceback.print_exc()
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR) raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
return {"token": new_verification_token.token, "expires": new_verification_token.expires} return {"token": new_verification_token.token, "expires": new_verification_token.expires, "user_id": user_id}
async def delete_verification_token(tokens: List[str]): async def delete_verification_token(tokens: List):
global prisma_client global prisma_client
try: try:
if prisma_client: if prisma_client:
# Assuming 'db' is your Prisma Client instance # Assuming 'db' is your Prisma Client instance
deleted_tokens = await prisma_client.litellm_verificationtoken.delete_many( deleted_tokens = await prisma_client.delete_data(tokens=tokens)
where={"token": {"in": tokens}}
)
else: else:
raise Exception raise Exception
except Exception as e: except Exception as e:
@ -982,6 +971,8 @@ async def generate_key_fn(request: Request, data: GenerateKeyRequest):
""" """
Generate an API key based on the provided data. Generate an API key based on the provided data.
Docs: https://docs.litellm.ai/docs/proxy/virtual_keys
Parameters: Parameters:
- duration: Optional[str] - Specify the length of time the token is valid for. You can set duration as seconds ("30s"), minutes ("30m"), hours ("30h"), days ("30d"). **(Default is set to 1 hour.)** - duration: Optional[str] - Specify the length of time the token is valid for. You can set duration as seconds ("30s"), minutes ("30m"), hours ("30h"), days ("30d"). **(Default is set to 1 hour.)**
- models: Optional[list] - Model_name's a user is allowed to call. (if empty, key is allowed to call all models) - models: Optional[list] - Model_name's a user is allowed to call. (if empty, key is allowed to call all models)
@ -1000,9 +991,10 @@ async def generate_key_fn(request: Request, data: GenerateKeyRequest):
aliases = data.aliases # Default to an empty dict (no alias mappings, on top of anything in the config.yaml model_list) aliases = data.aliases # Default to an empty dict (no alias mappings, on top of anything in the config.yaml model_list)
config = data.config config = data.config
spend = data.spend spend = data.spend
user_id = data.user_id
if isinstance(models, list): if isinstance(models, list):
response = await generate_key_helper_fn(duration_str=duration_str, models=models, aliases=aliases, config=config, spend=spend) response = await generate_key_helper_fn(duration_str=duration_str, models=models, aliases=aliases, config=config, spend=spend, user_id=user_id)
return GenerateKeyResponse(key=response["token"], expires=response["expires"]) return GenerateKeyResponse(key=response["token"], expires=response["expires"], user_id=response["user_id"])
else: else:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
@ -1010,7 +1002,7 @@ async def generate_key_fn(request: Request, data: GenerateKeyRequest):
) )
@router.post("/key/delete", tags=["key management"], dependencies=[Depends(user_api_key_auth)]) @router.post("/key/delete", tags=["key management"], dependencies=[Depends(user_api_key_auth)])
async def delete_key_fn(request: Request): async def delete_key_fn(request: Request, data: DeleteKeyRequest):
try: try:
data = await request.json() data = await request.json()
@ -1037,12 +1029,7 @@ async def info_key_fn(key: str = fastapi.Query(..., description="Key in the requ
try: try:
if prisma_client is None: if prisma_client is None:
raise Exception(f"Database not connected. Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys") raise Exception(f"Database not connected. Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys")
key_info = await prisma_client.litellm_verificationtoken.find_unique( key_info = await prisma_client.get_data(token=key)
where={
"token": key
}
)
return {"key": key, "info": key_info} return {"key": key, "info": key_info}
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(

View file

@ -15,4 +15,5 @@ model LiteLLM_VerificationToken {
models String[] models String[]
aliases Json @default("{}") aliases Json @default("{}")
config Json @default("{}") config Json @default("{}")
user_id String?
} }

View file

@ -5,7 +5,7 @@ import traceback
litellm_client = AsyncOpenAI( litellm_client = AsyncOpenAI(
api_key="test", api_key="sk-1234",
base_url="http://0.0.0.0:8000" base_url="http://0.0.0.0:8000"
) )
@ -17,7 +17,6 @@ async def litellm_completion():
model="gpt-3.5-turbo", model="gpt-3.5-turbo",
messages=[{"role": "user", "content": f"This is a test: {uuid.uuid4()}"}], messages=[{"role": "user", "content": f"This is a test: {uuid.uuid4()}"}],
) )
print(response)
return response return response
except Exception as e: except Exception as e:

View file

@ -1,71 +1,87 @@
import litellm from typing import Optional, List, Any
from litellm import ModelResponse import os, subprocess, hashlib
from proxy_server import llm_model_list
from typing import Optional
class PrismaClient:
    """Async helper around the generated Prisma client for the
    ``litellm_verificationtoken`` table.

    Every method that accepts a token hashes it with SHA-256 (see
    ``hash_token``) before it is used in a query, so only the hex digest of a
    key is ever written to or looked up in the table.
    """

    def __init__(self, database_url: str):
        """Export ``DATABASE_URL``, run the prisma CLI and build the client.

        Args:
            database_url: Connection string exported as ``DATABASE_URL`` for
                the prisma CLI / generated client.
        """
        print("LiteLLM: DATABASE_URL Set in config, trying to 'pip install prisma'")
        os.environ["DATABASE_URL"] = database_url
        # NOTE(review): the return codes of both CLI calls are ignored, so a
        # failed generate/push only surfaces later (import or query time).
        subprocess.run(['prisma', 'generate'])
        subprocess.run(['prisma', 'db', 'push', '--accept-data-loss']) # this looks like a weird edge case when prisma just wont start on render. we need to have the --accept-data-loss
        # The generated client only exists after `prisma generate`, hence the
        # function-scope import.
        from prisma import Client
        self.db = Client()  # Client to connect to Prisma db

    def hash_token(self, token: str) -> str:
        """Return the SHA-256 hex digest of ``token`` (encoded with the
        default ``str.encode()`` codec, i.e. UTF-8)."""
        return hashlib.sha256(token.encode()).hexdigest()

    async def get_data(self, token: str, expires: Optional[Any] = None):
        """Fetch the row stored for ``token``.

        Args:
            token: Un-hashed key; it is hashed before the lookup.
            expires: When given, only a row whose ``expires`` column is
                greater than or equal to this value is returned
                (``find_first``); otherwise the row is looked up by token
                alone (``find_unique``).

        Returns:
            Whatever the Prisma query returns for the hashed token.
        """
        hashed_token = self.hash_token(token=token)
        table = self.db.litellm_verificationtoken
        if expires is not None:
            return await table.find_first(
                where={
                    "token": hashed_token,
                    "expires": {"gte": expires},  # Check if the token is not expired
                }
            )
        return await table.find_unique(where={"token": hashed_token})

    async def insert_data(self, data: dict):
        """
        Add a key to the database. If it already exists, do nothing.

        ``data`` must contain the un-hashed key under ``"token"``. The
        caller's dict is left untouched; a copy carrying the hashed token is
        what gets written.

        Returns:
            The result of the Prisma ``upsert`` call.
        """
        hashed_token = self.hash_token(token=data["token"])
        # Work on a copy so the caller's dict keeps its original token.
        record = {**data, "token": hashed_token}
        print(f"passed in data: {record}; hashed_token: {hashed_token}")
        return await self.db.litellm_verificationtoken.upsert(  # type: ignore
            where={
                'token': hashed_token,
            },
            data={
                "create": record,  # type: ignore
                "update": {},  # don't do anything if it already exists
            },
        )

    async def update_data(self, token: str, data: dict) -> dict:
        """
        Update existing data

        Args:
            token: Un-hashed key identifying the row to update.
            data: Column values to write. Not modified; a copy with the
                hashed token added under ``"token"`` is sent to Prisma.

        Returns:
            ``{"token": <un-hashed token>, "data": <values that were written>}``.
        """
        hashed_token = self.hash_token(token=token)
        # Copy instead of writing the hash into the caller's dict.
        values = {**data, "token": hashed_token}
        await self.db.litellm_verificationtoken.update(
            where={"token": hashed_token},
            data=values,
        )
        return {"token": token, "data": values}

    async def delete_data(self, tokens: List) -> dict:
        """
        Allow user to delete a key(s)

        Args:
            tokens: Un-hashed keys; each one is hashed before the delete.

        Returns:
            ``{"deleted_keys": tokens}`` echoing the keys that were passed in.
        """
        hashed_tokens = [self.hash_token(token=token) for token in tokens]
        await self.db.litellm_verificationtoken.delete_many(
            where={"token": {"in": hashed_tokens}}
        )
        return {"deleted_keys": tokens}

    async def connect(self) -> None:
        """Await the underlying client's ``connect()``."""
        await self.db.connect()

    async def disconnect(self) -> None:
        """Await the underlying client's ``disconnect()``."""
        await self.db.disconnect()