fix(proxy_server.py): hash keys

This commit is contained in:
Krrish Dholakia 2023-12-02 19:24:58 -08:00
parent a4c9e18eb5
commit 6b1b1b82cf
5 changed files with 135 additions and 132 deletions

View file

@ -87,7 +87,7 @@ const sidebars = {
},
{
type: "category",
label: "💥 OpenAI Proxy",
label: "💥 OpenAI Proxy Server",
link: {
type: 'generated-index',
title: '💥 OpenAI Proxy Server',

View file

@ -4,6 +4,7 @@ import shutil, random, traceback, requests
from datetime import datetime, timedelta
from typing import Optional, List
import secrets, subprocess
import hashlib, uuid
import warnings
messages: list = []
sys.path.insert(
@ -89,6 +90,9 @@ def generate_feedback_box():
print()
import litellm
from litellm.proxy.utils import (
PrismaClient
)
from litellm.caching import DualCache
litellm.suppress_debug_info = True
from fastapi import FastAPI, Request, HTTPException, status, Depends, BackgroundTasks
@ -204,10 +208,19 @@ class GenerateKeyRequest(BaseModel):
aliases: dict = {}
config: dict = {}
spend: int = 0
user_id: Optional[str]
class GenerateKeyResponse(BaseModel):
key: str
expires: str
user_id: str
class _DeleteKeyObject(BaseModel):
key: str
class DeleteKeyRequest(BaseModel):
keys: List[_DeleteKeyObject]
user_api_base = None
user_model = None
@ -229,7 +242,7 @@ log_file = "api_log.json"
worker_config = None
master_key = None
otel_logging = False
prisma_client = None
prisma_client: Optional[PrismaClient] = None
user_api_key_cache = DualCache()
### REDIS QUEUE ###
async_result = None
@ -277,13 +290,7 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
if valid_token is None and "Bearer " in api_key:
## check db
cleaned_api_key = api_key[len("Bearer "):]
valid_token = await prisma_client.litellm_verificationtoken.find_first(
where={
"token": cleaned_api_key,
"expires": {"gte": datetime.utcnow()} # Check if the token is not expired
}
)
## save to cache for 60s
valid_token = await prisma_client.get_data(token=cleaned_api_key, expires=datetime.utcnow())
user_api_key_cache.set_cache(key=api_key, value=valid_token, ttl=60)
elif valid_token is not None:
print(f"API Key Cache Hit!")
@ -321,14 +328,7 @@ def prisma_setup(database_url: Optional[str]):
global prisma_client
if database_url:
try:
import os
print("LiteLLM: DATABASE_URL Set in config, trying to 'pip install prisma'")
os.environ["DATABASE_URL"] = database_url
subprocess.run(['prisma', 'generate'])
subprocess.run(['prisma', 'db', 'push', '--accept-data-loss']) # this looks like a weird edge case when prisma just wont start on render. we need to have the --accept-data-loss
# Now you can import the Prisma Client
from prisma import Client
prisma_client = Client()
prisma_client = PrismaClient(database_url=database_url)
except Exception as e:
print("Error when initializing prisma, Ensure you run pip install prisma", e)
@ -388,6 +388,7 @@ def track_cost_callback(
start_time = None,
end_time = None, # start/end time for completion
):
global prisma_client
try:
# check if it has collected an entire stream response
if "complete_streaming_response" in kwargs:
@ -415,46 +416,41 @@ def track_cost_callback(
# Create new event loop for async function execution in the new thread
new_loop = asyncio.new_event_loop()
asyncio.set_event_loop(new_loop)
try:
# Run the async function using the newly created event loop
new_loop.run_until_complete(update_prisma_database(user_api_key, response_cost))
existing_spend_obj = new_loop.run_until_complete(prisma_client.get_data(token=user_api_key))
if existing_spend_obj is None:
existing_spend = 0
else:
existing_spend = existing_spend_obj.spend
# Calculate the new cost by adding the existing cost and response_cost
new_spend = existing_spend + response_cost
print(f"new cost: {new_spend}")
# Update the cost column for the given token
new_loop.run_until_complete(prisma_client.update_data(token=user_api_key, data={"spend": new_spend}))
print(f"Prisma database updated for token {user_api_key}. New cost: {new_spend}")
except Exception as e:
print(f"error in tracking cost callback - {str(e)}")
finally:
# Close the event loop after the task is done
new_loop.close()
# Ensure that there's no event loop set in this thread, which could interfere with future asyncio calls
asyncio.set_event_loop(None)
print(f"error in creating async loop - {str(e)}")
except Exception as e:
print(f"error in tracking cost callback - {str(e)}")
async def update_prisma_database(token, response_cost):
global prisma_client
try:
print(f"Enters prisma db call, token: {token}")
# Fetch the existing cost for the given token
existing_spend = await prisma_client.litellm_verificationtoken.find_unique(
where={
"token": token
}
)
print(f"existing spend: {existing_spend}")
if existing_spend is None:
existing_spend_obj = await prisma_client.get_data(token=token)
print(f"existing spend: {existing_spend_obj}")
if existing_spend_obj is None:
existing_spend = 0
else:
existing_spend = existing_spend_obj.spend
# Calculate the new cost by adding the existing cost and response_cost
new_spend = existing_spend.spend + response_cost
new_spend = existing_spend + response_cost
print(f"new cost: {new_spend}")
# Update the cost column for the given token
await prisma_client.litellm_verificationtoken.update(
where={
"token": token
},
data={
"spend": new_spend
}
)
await prisma_client.update_data(token=token, data={"spend": new_spend})
print(f"Prisma database updated for token {token}. New cost: {new_spend}")
except Exception as e:
@ -569,9 +565,11 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
run_ollama_serve()
return router, model_list, general_settings
async def generate_key_helper_fn(duration_str: Optional[str], models: list, aliases: dict, config: dict, spend: float, token: Optional[str]):
async def generate_key_helper_fn(duration_str: Optional[str], models: list, aliases: dict, config: dict, spend: float, token: Optional[str], user_id: Optional[str]=None):
if token is None:
token = f"sk-{secrets.token_urlsafe(16)}"
def _duration_in_seconds(duration: str):
match = re.match(r"(\d+)([smhd]?)", duration)
if not match:
@ -599,8 +597,8 @@ async def generate_key_helper_fn(duration_str: Optional[str], models: list, alia
aliases_json = json.dumps(aliases)
config_json = json.dumps(config)
user_id = user_id or str(uuid.uuid4())
try:
db = prisma_client
# Create a new verification token (you may want to enhance this logic based on your needs)
verification_token_data = {
"token": token,
@ -608,30 +606,21 @@ async def generate_key_helper_fn(duration_str: Optional[str], models: list, alia
"models": models,
"aliases": aliases_json,
"config": config_json,
"spend": spend
"spend": spend,
"user_id": user_id
}
new_verification_token = await db.litellm_verificationtoken.upsert( # type: ignore
where={
'token': token,
},
data={
"create": {**verification_token_data}, #type: ignore
"update": {} # don't do anything if it already exists
}
)
new_verification_token = await prisma_client.insert_data(data=verification_token_data)
except Exception as e:
traceback.print_exc()
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
return {"token": new_verification_token.token, "expires": new_verification_token.expires}
return {"token": new_verification_token.token, "expires": new_verification_token.expires, "user_id": user_id}
async def delete_verification_token(tokens: List[str]):
async def delete_verification_token(tokens: List):
global prisma_client
try:
if prisma_client:
# Assuming 'db' is your Prisma Client instance
deleted_tokens = await prisma_client.litellm_verificationtoken.delete_many(
where={"token": {"in": tokens}}
)
deleted_tokens = await prisma_client.delete_data(tokens=tokens)
else:
raise Exception
except Exception as e:
@ -980,7 +969,9 @@ async def embeddings(request: Request, user_api_key_dict: dict = Depends(user_ap
@router.post("/key/generate", tags=["key management"], dependencies=[Depends(user_api_key_auth)], response_model=GenerateKeyResponse)
async def generate_key_fn(request: Request, data: GenerateKeyRequest):
"""
Generate an API key based on the provided data.
Generate an API key based on the provided data.
Docs: https://docs.litellm.ai/docs/proxy/virtual_keys
Parameters:
- duration: Optional[str] - Specify the length of time the token is valid for. You can set duration as seconds ("30s"), minutes ("30m"), hours ("30h"), days ("30d"). **(Default is set to 1 hour.)**
@ -1000,9 +991,10 @@ async def generate_key_fn(request: Request, data: GenerateKeyRequest):
aliases = data.aliases # Default to an empty dict (no alias mappings, on top of anything in the config.yaml model_list)
config = data.config
spend = data.spend
user_id = data.user_id
if isinstance(models, list):
response = await generate_key_helper_fn(duration_str=duration_str, models=models, aliases=aliases, config=config, spend=spend)
return GenerateKeyResponse(key=response["token"], expires=response["expires"])
response = await generate_key_helper_fn(duration_str=duration_str, models=models, aliases=aliases, config=config, spend=spend, user_id=user_id)
return GenerateKeyResponse(key=response["token"], expires=response["expires"], user_id=response["user_id"])
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
@ -1010,7 +1002,7 @@ async def generate_key_fn(request: Request, data: GenerateKeyRequest):
)
@router.post("/key/delete", tags=["key management"], dependencies=[Depends(user_api_key_auth)])
async def delete_key_fn(request: Request):
async def delete_key_fn(request: Request, data: DeleteKeyRequest):
try:
data = await request.json()
@ -1037,12 +1029,7 @@ async def info_key_fn(key: str = fastapi.Query(..., description="Key in the requ
try:
if prisma_client is None:
raise Exception(f"Database not connected. Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys")
key_info = await prisma_client.litellm_verificationtoken.find_unique(
where={
"token": key
}
)
key_info = await prisma_client.get_data(token=key)
return {"key": key, "info": key_info}
except Exception as e:
raise HTTPException(

View file

@ -15,4 +15,5 @@ model LiteLLM_VerificationToken {
models String[]
aliases Json @default("{}")
config Json @default("{}")
user_id String?
}

View file

@ -5,7 +5,7 @@ import traceback
litellm_client = AsyncOpenAI(
api_key="test",
api_key="sk-1234",
base_url="http://0.0.0.0:8000"
)
@ -17,7 +17,6 @@ async def litellm_completion():
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": f"This is a test: {uuid.uuid4()}"}],
)
print(response)
return response
except Exception as e:

View file

@ -1,71 +1,87 @@
import litellm
from litellm import ModelResponse
from proxy_server import llm_model_list
from typing import Optional
from typing import Optional, List, Any
import os, subprocess, hashlib
def track_cost_callback(
kwargs, # kwargs to completion
completion_response: ModelResponse, # response from completion
start_time = None,
end_time = None, # start/end time for completion
):
try:
# init logging config
print("in custom callback tracking cost", llm_model_list)
if "azure" in kwargs["model"]:
# for azure cost tracking, we check the provided model list in the config.yaml
# we need to map azure/chatgpt-deployment to -> azure/gpt-3.5-turbo
pass
# check if it has collected an entire stream response
if "complete_streaming_response" in kwargs:
# for tracking streaming cost we pass the "messages" and the output_text to litellm.completion_cost
completion_response=kwargs["complete_streaming_response"]
input_text = kwargs["messages"]
output_text = completion_response["choices"][0]["message"]["content"]
response_cost = litellm.completion_cost(
model = kwargs["model"],
messages = input_text,
completion=output_text
class PrismaClient:
def __init__(self, database_url: str):
print("LiteLLM: DATABASE_URL Set in config, trying to 'pip install prisma'")
os.environ["DATABASE_URL"] = database_url
subprocess.run(['prisma', 'generate'])
subprocess.run(['prisma', 'db', 'push', '--accept-data-loss']) # this looks like a weird edge case when prisma just wont start on render. we need to have the --accept-data-loss
# Now you can import the Prisma Client
from prisma import Client
self.db = Client() #Client to connect to Prisma db
def hash_token(self, token: str):
# Hash the string using SHA-256
hashed_token = hashlib.sha256(token.encode()).hexdigest()
return hashed_token
async def get_data(self, token: str, expires: Optional[Any]=None):
hashed_token = self.hash_token(token=token)
if expires:
response = await self.db.litellm_verificationtoken.find_first(
where={
"token": hashed_token,
"expires": {"gte": expires} # Check if the token is not expired
}
)
else:
response = await self.db.litellm_verificationtoken.find_unique(
where={
"token": hashed_token
}
)
print("streaming response_cost", response_cost)
# for non streaming responses
else:
# we pass the completion_response obj
if kwargs["stream"] != True:
input_text = kwargs.get("messages", "")
if isinstance(input_text, list):
input_text = "".join(m["content"] for m in input_text)
response_cost = litellm.completion_cost(completion_response=completion_response, completion=input_text)
print("regular response_cost", response_cost)
except:
pass
return response
def update_prisma_database(token, response_cost):
try:
# Import your Prisma client
from your_prisma_module import prisma
async def insert_data(self, data: dict):
"""
Add a key to the database. If it already exists, do nothing.
"""
token = data["token"]
hashed_token = self.hash_token(token=token)
data["token"] = hashed_token
print(f"passed in data: {data}; hashed_token: {hashed_token}")
# Fetch the existing cost for the given token
existing_cost = prisma.LiteLLM_VerificationToken.find_unique(
new_verification_token = await self.db.litellm_verificationtoken.upsert( # type: ignore
where={
"token": token
}
).cost
# Calculate the new cost by adding the existing cost and response_cost
new_cost = existing_cost + response_cost
# Update the cost column for the given token
prisma_liteLLM_VerificationToken = prisma.LiteLLM_VerificationToken.update(
where={
"token": token
'token': hashed_token,
},
data={
"cost": new_cost
"create": {**data}, #type: ignore
"update": {} # don't do anything if it already exists
}
)
print(f"Prisma database updated for token {token}. New cost: {new_cost}")
except Exception as e:
print(f"Error updating Prisma database: {e}")
pass
return new_verification_token
async def update_data(self, token: str, data: dict):
"""
Update existing data
"""
hashed_token = self.hash_token(token=token)
data["token"] = hashed_token
await self.db.litellm_verificationtoken.update(
where={
"token": hashed_token
},
data={**data}
)
return {"token": token, "data": data}
async def delete_data(self, tokens: List):
"""
Allow user to delete a key(s)
"""
hashed_tokens = [self.hash_token(token=token) for token in tokens]
await self.db.litellm_verificationtoken.delete_many(
where={"token": {"in": hashed_tokens}}
)
return {"deleted_keys": tokens}
async def connect(self):
await self.db.connect()
async def disconnect(self):
await self.db.disconnect()