forked from phoenix/litellm-mirror
fix(proxy_server.py): hash keys
This commit is contained in:
parent
a4c9e18eb5
commit
6b1b1b82cf
5 changed files with 135 additions and 132 deletions
|
@ -87,7 +87,7 @@ const sidebars = {
|
|||
},
|
||||
{
|
||||
type: "category",
|
||||
label: "💥 OpenAI Proxy",
|
||||
label: "💥 OpenAI Proxy Server",
|
||||
link: {
|
||||
type: 'generated-index',
|
||||
title: '💥 OpenAI Proxy Server',
|
||||
|
|
|
@ -4,6 +4,7 @@ import shutil, random, traceback, requests
|
|||
from datetime import datetime, timedelta
|
||||
from typing import Optional, List
|
||||
import secrets, subprocess
|
||||
import hashlib, uuid
|
||||
import warnings
|
||||
messages: list = []
|
||||
sys.path.insert(
|
||||
|
@ -89,6 +90,9 @@ def generate_feedback_box():
|
|||
print()
|
||||
|
||||
import litellm
|
||||
from litellm.proxy.utils import (
|
||||
PrismaClient
|
||||
)
|
||||
from litellm.caching import DualCache
|
||||
litellm.suppress_debug_info = True
|
||||
from fastapi import FastAPI, Request, HTTPException, status, Depends, BackgroundTasks
|
||||
|
@ -204,10 +208,19 @@ class GenerateKeyRequest(BaseModel):
|
|||
aliases: dict = {}
|
||||
config: dict = {}
|
||||
spend: int = 0
|
||||
user_id: Optional[str]
|
||||
|
||||
class GenerateKeyResponse(BaseModel):
    # Response body for /key/generate.
    key: str       # the newly generated api key
    expires: str   # timestamp after which the key stops working
    user_id: str   # id of the user the key was issued to


class _DeleteKeyObject(BaseModel):
    # One key slated for deletion.
    key: str


class DeleteKeyRequest(BaseModel):
    # Request body for /key/delete: the list of keys to remove.
    keys: List[_DeleteKeyObject]
|
||||
|
||||
|
||||
user_api_base = None
|
||||
user_model = None
|
||||
|
@ -229,7 +242,7 @@ log_file = "api_log.json"
|
|||
worker_config = None
|
||||
master_key = None
|
||||
otel_logging = False
|
||||
prisma_client = None
|
||||
prisma_client: Optional[PrismaClient] = None
|
||||
user_api_key_cache = DualCache()
|
||||
### REDIS QUEUE ###
|
||||
async_result = None
|
||||
|
@ -277,13 +290,7 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
|
|||
if valid_token is None and "Bearer " in api_key:
|
||||
## check db
|
||||
cleaned_api_key = api_key[len("Bearer "):]
|
||||
valid_token = await prisma_client.litellm_verificationtoken.find_first(
|
||||
where={
|
||||
"token": cleaned_api_key,
|
||||
"expires": {"gte": datetime.utcnow()} # Check if the token is not expired
|
||||
}
|
||||
)
|
||||
## save to cache for 60s
|
||||
valid_token = await prisma_client.get_data(token=cleaned_api_key, expires=datetime.utcnow())
|
||||
user_api_key_cache.set_cache(key=api_key, value=valid_token, ttl=60)
|
||||
elif valid_token is not None:
|
||||
print(f"API Key Cache Hit!")
|
||||
|
@ -321,14 +328,7 @@ def prisma_setup(database_url: Optional[str]):
|
|||
global prisma_client
|
||||
if database_url:
|
||||
try:
|
||||
import os
|
||||
print("LiteLLM: DATABASE_URL Set in config, trying to 'pip install prisma'")
|
||||
os.environ["DATABASE_URL"] = database_url
|
||||
subprocess.run(['prisma', 'generate'])
|
||||
subprocess.run(['prisma', 'db', 'push', '--accept-data-loss']) # this looks like a weird edge case when prisma just wont start on render. we need to have the --accept-data-loss
|
||||
# Now you can import the Prisma Client
|
||||
from prisma import Client
|
||||
prisma_client = Client()
|
||||
prisma_client = PrismaClient(database_url=database_url)
|
||||
except Exception as e:
|
||||
print("Error when initializing prisma, Ensure you run pip install prisma", e)
|
||||
|
||||
|
@ -388,6 +388,7 @@ def track_cost_callback(
|
|||
start_time = None,
|
||||
end_time = None, # start/end time for completion
|
||||
):
|
||||
global prisma_client
|
||||
try:
|
||||
# check if it has collected an entire stream response
|
||||
if "complete_streaming_response" in kwargs:
|
||||
|
@ -415,46 +416,41 @@ def track_cost_callback(
|
|||
# Create new event loop for async function execution in the new thread
|
||||
new_loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(new_loop)
|
||||
|
||||
try:
|
||||
# Run the async function using the newly created event loop
|
||||
new_loop.run_until_complete(update_prisma_database(user_api_key, response_cost))
|
||||
existing_spend_obj = new_loop.run_until_complete(prisma_client.get_data(token=user_api_key))
|
||||
if existing_spend_obj is None:
|
||||
existing_spend = 0
|
||||
else:
|
||||
existing_spend = existing_spend_obj.spend
|
||||
# Calculate the new cost by adding the existing cost and response_cost
|
||||
new_spend = existing_spend + response_cost
|
||||
print(f"new cost: {new_spend}")
|
||||
# Update the cost column for the given token
|
||||
new_loop.run_until_complete(prisma_client.update_data(token=user_api_key, data={"spend": new_spend}))
|
||||
print(f"Prisma database updated for token {user_api_key}. New cost: {new_spend}")
|
||||
except Exception as e:
|
||||
print(f"error in tracking cost callback - {str(e)}")
|
||||
finally:
|
||||
# Close the event loop after the task is done
|
||||
new_loop.close()
|
||||
# Ensure that there's no event loop set in this thread, which could interfere with future asyncio calls
|
||||
asyncio.set_event_loop(None)
|
||||
print(f"error in creating async loop - {str(e)}")
|
||||
except Exception as e:
|
||||
print(f"error in tracking cost callback - {str(e)}")
|
||||
|
||||
async def update_prisma_database(token, response_cost):
|
||||
global prisma_client
|
||||
|
||||
try:
|
||||
print(f"Enters prisma db call, token: {token}")
|
||||
# Fetch the existing cost for the given token
|
||||
existing_spend = await prisma_client.litellm_verificationtoken.find_unique(
|
||||
where={
|
||||
"token": token
|
||||
}
|
||||
)
|
||||
print(f"existing spend: {existing_spend}")
|
||||
if existing_spend is None:
|
||||
existing_spend_obj = await prisma_client.get_data(token=token)
|
||||
print(f"existing spend: {existing_spend_obj}")
|
||||
if existing_spend_obj is None:
|
||||
existing_spend = 0
|
||||
else:
|
||||
existing_spend = existing_spend_obj.spend
|
||||
# Calculate the new cost by adding the existing cost and response_cost
|
||||
new_spend = existing_spend.spend + response_cost
|
||||
new_spend = existing_spend + response_cost
|
||||
|
||||
print(f"new cost: {new_spend}")
|
||||
# Update the cost column for the given token
|
||||
await prisma_client.litellm_verificationtoken.update(
|
||||
where={
|
||||
"token": token
|
||||
},
|
||||
data={
|
||||
"spend": new_spend
|
||||
}
|
||||
)
|
||||
await prisma_client.update_data(token=token, data={"spend": new_spend})
|
||||
print(f"Prisma database updated for token {token}. New cost: {new_spend}")
|
||||
|
||||
except Exception as e:
|
||||
|
@ -569,9 +565,11 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
|
|||
run_ollama_serve()
|
||||
return router, model_list, general_settings
|
||||
|
||||
async def generate_key_helper_fn(duration_str: Optional[str], models: list, aliases: dict, config: dict, spend: float, token: Optional[str]):
|
||||
async def generate_key_helper_fn(duration_str: Optional[str], models: list, aliases: dict, config: dict, spend: float, token: Optional[str], user_id: Optional[str]=None):
|
||||
if token is None:
|
||||
token = f"sk-{secrets.token_urlsafe(16)}"
|
||||
|
||||
|
||||
def _duration_in_seconds(duration: str):
|
||||
match = re.match(r"(\d+)([smhd]?)", duration)
|
||||
if not match:
|
||||
|
@ -599,8 +597,8 @@ async def generate_key_helper_fn(duration_str: Optional[str], models: list, alia
|
|||
|
||||
aliases_json = json.dumps(aliases)
|
||||
config_json = json.dumps(config)
|
||||
user_id = user_id or str(uuid.uuid4())
|
||||
try:
|
||||
db = prisma_client
|
||||
# Create a new verification token (you may want to enhance this logic based on your needs)
|
||||
verification_token_data = {
|
||||
"token": token,
|
||||
|
@ -608,30 +606,21 @@ async def generate_key_helper_fn(duration_str: Optional[str], models: list, alia
|
|||
"models": models,
|
||||
"aliases": aliases_json,
|
||||
"config": config_json,
|
||||
"spend": spend
|
||||
"spend": spend,
|
||||
"user_id": user_id
|
||||
}
|
||||
new_verification_token = await db.litellm_verificationtoken.upsert( # type: ignore
|
||||
where={
|
||||
'token': token,
|
||||
},
|
||||
data={
|
||||
"create": {**verification_token_data}, #type: ignore
|
||||
"update": {} # don't do anything if it already exists
|
||||
}
|
||||
)
|
||||
new_verification_token = await prisma_client.insert_data(data=verification_token_data)
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||||
return {"token": new_verification_token.token, "expires": new_verification_token.expires}
|
||||
return {"token": new_verification_token.token, "expires": new_verification_token.expires, "user_id": user_id}
|
||||
|
||||
async def delete_verification_token(tokens: List[str]):
|
||||
async def delete_verification_token(tokens: List):
|
||||
global prisma_client
|
||||
try:
|
||||
if prisma_client:
|
||||
# Assuming 'db' is your Prisma Client instance
|
||||
deleted_tokens = await prisma_client.litellm_verificationtoken.delete_many(
|
||||
where={"token": {"in": tokens}}
|
||||
)
|
||||
deleted_tokens = await prisma_client.delete_data(tokens=tokens)
|
||||
else:
|
||||
raise Exception
|
||||
except Exception as e:
|
||||
|
@ -980,7 +969,9 @@ async def embeddings(request: Request, user_api_key_dict: dict = Depends(user_ap
|
|||
@router.post("/key/generate", tags=["key management"], dependencies=[Depends(user_api_key_auth)], response_model=GenerateKeyResponse)
|
||||
async def generate_key_fn(request: Request, data: GenerateKeyRequest):
|
||||
"""
|
||||
Generate an API key based on the provided data.
|
||||
Generate an API key based on the provided data.
|
||||
|
||||
Docs: https://docs.litellm.ai/docs/proxy/virtual_keys
|
||||
|
||||
Parameters:
|
||||
- duration: Optional[str] - Specify the length of time the token is valid for. You can set duration as seconds ("30s"), minutes ("30m"), hours ("30h"), days ("30d"). **(Default is set to 1 hour.)**
|
||||
|
@ -1000,9 +991,10 @@ async def generate_key_fn(request: Request, data: GenerateKeyRequest):
|
|||
aliases = data.aliases # Default to an empty dict (no alias mappings, on top of anything in the config.yaml model_list)
|
||||
config = data.config
|
||||
spend = data.spend
|
||||
user_id = data.user_id
|
||||
if isinstance(models, list):
|
||||
response = await generate_key_helper_fn(duration_str=duration_str, models=models, aliases=aliases, config=config, spend=spend)
|
||||
return GenerateKeyResponse(key=response["token"], expires=response["expires"])
|
||||
response = await generate_key_helper_fn(duration_str=duration_str, models=models, aliases=aliases, config=config, spend=spend, user_id=user_id)
|
||||
return GenerateKeyResponse(key=response["token"], expires=response["expires"], user_id=response["user_id"])
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
|
@ -1010,7 +1002,7 @@ async def generate_key_fn(request: Request, data: GenerateKeyRequest):
|
|||
)
|
||||
|
||||
@router.post("/key/delete", tags=["key management"], dependencies=[Depends(user_api_key_auth)])
|
||||
async def delete_key_fn(request: Request):
|
||||
async def delete_key_fn(request: Request, data: DeleteKeyRequest):
|
||||
try:
|
||||
data = await request.json()
|
||||
|
||||
|
@ -1037,12 +1029,7 @@ async def info_key_fn(key: str = fastapi.Query(..., description="Key in the requ
|
|||
try:
|
||||
if prisma_client is None:
|
||||
raise Exception(f"Database not connected. Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys")
|
||||
key_info = await prisma_client.litellm_verificationtoken.find_unique(
|
||||
where={
|
||||
"token": key
|
||||
}
|
||||
)
|
||||
|
||||
key_info = await prisma_client.get_data(token=key)
|
||||
return {"key": key, "info": key_info}
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
|
|
|
@ -15,4 +15,5 @@ model LiteLLM_VerificationToken {
|
|||
models String[]
|
||||
aliases Json @default("{}")
|
||||
config Json @default("{}")
|
||||
user_id String?
|
||||
}
|
|
@ -5,7 +5,7 @@ import traceback
|
|||
|
||||
|
||||
litellm_client = AsyncOpenAI(
|
||||
api_key="test",
|
||||
api_key="sk-1234",
|
||||
base_url="http://0.0.0.0:8000"
|
||||
)
|
||||
|
||||
|
@ -17,7 +17,6 @@ async def litellm_completion():
|
|||
model="gpt-3.5-turbo",
|
||||
messages=[{"role": "user", "content": f"This is a test: {uuid.uuid4()}"}],
|
||||
)
|
||||
print(response)
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
|
|
|
@ -1,71 +1,87 @@
|
|||
import litellm
|
||||
from litellm import ModelResponse
|
||||
from proxy_server import llm_model_list
|
||||
from typing import Optional
|
||||
from typing import Optional, List, Any
|
||||
import os, subprocess, hashlib
|
||||
|
||||
def track_cost_callback(
|
||||
kwargs, # kwargs to completion
|
||||
completion_response: ModelResponse, # response from completion
|
||||
start_time = None,
|
||||
end_time = None, # start/end time for completion
|
||||
):
|
||||
try:
|
||||
# init logging config
|
||||
print("in custom callback tracking cost", llm_model_list)
|
||||
if "azure" in kwargs["model"]:
|
||||
# for azure cost tracking, we check the provided model list in the config.yaml
|
||||
# we need to map azure/chatgpt-deployment to -> azure/gpt-3.5-turbo
|
||||
pass
|
||||
# check if it has collected an entire stream response
|
||||
if "complete_streaming_response" in kwargs:
|
||||
# for tracking streaming cost we pass the "messages" and the output_text to litellm.completion_cost
|
||||
completion_response=kwargs["complete_streaming_response"]
|
||||
input_text = kwargs["messages"]
|
||||
output_text = completion_response["choices"][0]["message"]["content"]
|
||||
response_cost = litellm.completion_cost(
|
||||
model = kwargs["model"],
|
||||
messages = input_text,
|
||||
completion=output_text
|
||||
class PrismaClient:
|
||||
def __init__(self, database_url: str):
    """
    Prepare and connect a Prisma client bound to *database_url*.

    Runs `prisma generate` and `prisma db push` so the generated client
    code matches the schema before it is imported.
    """
    print("LiteLLM: DATABASE_URL Set in config, trying to 'pip install prisma'")
    os.environ["DATABASE_URL"] = database_url
    subprocess.run(['prisma', 'generate'])
    # --accept-data-loss works around prisma refusing to start on some hosts (e.g. render)
    subprocess.run(['prisma', 'db', 'push', '--accept-data-loss'])
    # import deferred until after `prisma generate` has produced the module
    from prisma import Client
    self.db = Client()  # Client to connect to Prisma db
||||
|
||||
def hash_token(self, token: str):
    """Return the SHA-256 hex digest of *token* (keys are stored hashed, never in plaintext)."""
    return hashlib.sha256(token.encode()).hexdigest()
|
||||
|
||||
async def get_data(self, token: str, expires: Optional[Any]=None):
    """
    Look up a verification-token row by the hash of *token*.

    When *expires* is supplied, only a row whose expiry is >= *expires*
    (i.e. still valid at that moment) is returned; otherwise the row is
    matched on the hashed token alone. Returns the row, or None.
    """
    hashed_token = self.hash_token(token=token)
    if expires:
        # token must exist AND not be expired at `expires`
        return await self.db.litellm_verificationtoken.find_first(
            where={
                "token": hashed_token,
                "expires": {"gte": expires}  # Check if the token is not expired
            }
        )
    # no expiry check requested — unique lookup on the hashed token
    return await self.db.litellm_verificationtoken.find_unique(
        where={"token": hashed_token}
    )
|
||||
|
||||
def update_prisma_database(token, response_cost):
|
||||
try:
|
||||
# Import your Prisma client
|
||||
from your_prisma_module import prisma
|
||||
async def insert_data(self, data: dict):
    """
    Add a key to the database. If it already exists, do nothing.
    """
    hashed_token = self.hash_token(token=data["token"])
    # keys are persisted hashed, never in plaintext
    data["token"] = hashed_token
    print(f"passed in data: {data}; hashed_token: {hashed_token}")

    # upsert with an empty "update" clause == insert-if-absent
    new_verification_token = await self.db.litellm_verificationtoken.upsert(  # type: ignore
        where={
            'token': hashed_token,
        },
        data={
            "create": {**data},  # type: ignore
            "update": {}  # don't do anything if it already exists
        }
    )
    return new_verification_token
|
||||
|
||||
async def update_data(self, token: str, data: dict):
    """
    Update an existing verification-token row.

    The row is located by the SHA-256 hash of *token* (keys are stored
    hashed). Returns {"token": <plaintext token>, "data": <fields written>},
    where the written fields carry the hashed token.
    """
    hashed_token = self.hash_token(token=token)
    # Build the payload on a copy so the caller's dict is not mutated
    # (the original wrote the hashed token back into the argument).
    db_data = {**data, "token": hashed_token}
    await self.db.litellm_verificationtoken.update(
        where={
            "token": hashed_token
        },
        data=db_data
    )
    return {"token": token, "data": db_data}
|
||||
|
||||
async def delete_data(self, tokens: List):
    """
    Allow user to delete a key(s)

    Every row whose hashed token appears in *tokens* (hashed here) is removed.
    """
    hashed = [self.hash_token(token=t) for t in tokens]
    await self.db.litellm_verificationtoken.delete_many(
        where={"token": {"in": hashed}}
    )
    return {"deleted_keys": tokens}
|
||||
|
||||
async def connect(self):
    """Open the underlying Prisma database connection."""
    await self.db.connect()

async def disconnect(self):
    """Close the underlying Prisma database connection."""
    await self.db.disconnect()
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue