feat(proxy_server.py): tracking spend per api key

This commit is contained in:
Krrish Dholakia 2023-11-24 15:14:06 -08:00
parent 2b52e6995c
commit 4f22e7de18
2 changed files with 132 additions and 22 deletions

View file

@ -149,14 +149,18 @@ api_key_header = APIKeyHeader(name="Authorization", auto_error=False)
async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(api_key_header)):
global master_key, prisma_client, llm_model_list
if master_key is None:
return
return {
"api_key": None
}
try:
route = request.url.path
# note: never string compare api keys, this is vulenerable to a time attack. Use secrets.compare_digest instead
is_master_key_valid = secrets.compare_digest(api_key, master_key) or secrets.compare_digest(api_key, "Bearer " + master_key)
if is_master_key_valid:
return
return {
"api_key": master_key
}
if (route == "/key/generate" or route == "/key/delete") and not is_master_key_valid:
raise Exception(f"If master key is set, only master key can be used to generate new keys")
@ -186,7 +190,9 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
llm_model_list = model_list
print("\n new llm router model list", llm_model_list)
if len(valid_token.models) == 0: # assume an empty model list means all models are allowed to be called
return
return {
"api_key": valid_token.token
}
else:
data = await request.json()
model = data.get("model", None)
@ -194,7 +200,9 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
model = litellm.model_alias_map[model]
if model and model not in valid_token.models:
raise Exception(f"Token not allowed to access model")
return
return {
"api_key": valid_token.token
}
else:
raise Exception(f"Invalid token")
except Exception as e:
@ -231,6 +239,83 @@ def celery_setup(use_queue: bool):
async_result = AsyncResult
celery_app_conn = celery_app
def cost_tracking():
global prisma_client, master_key
if prisma_client is not None and master_key is not None:
if isinstance(litellm.success_callback, list):
print("setting litellm success callback to track cost")
if (track_cost_callback) not in litellm.success_callback: # type: ignore
litellm.success_callback.append(track_cost_callback) # type: ignore
else:
litellm.success_callback = track_cost_callback # type: ignore
def track_cost_callback(
kwargs, # kwargs to completion
completion_response: litellm.ModelResponse, # response from completion
start_time = None,
end_time = None, # start/end time for completion
):
try:
# init logging config
print("in custom callback tracking cost", llm_model_list)
# check if it has collected an entire stream response
if "complete_streaming_response" in kwargs:
# for tracking streaming cost we pass the "messages" and the output_text to litellm.completion_cost
completion_response=kwargs["complete_streaming_response"]
input_text = kwargs["messages"]
output_text = completion_response["choices"][0]["message"]["content"]
response_cost = litellm.completion_cost(
model = kwargs["model"],
messages = input_text,
completion=output_text
)
print("streaming response_cost", response_cost)
# for non streaming responses
else:
# we pass the completion_response obj
if kwargs["stream"] != True:
input_text = kwargs.get("messages", "")
if isinstance(input_text, list):
input_text = "".join(m["content"] for m in input_text)
response_cost = litellm.completion_cost(completion_response=completion_response, completion=input_text)
print("regular response_cost", response_cost)
print(f"metadata in kwargs: {kwargs}")
user_api_key = kwargs["litellm_params"]["metadata"].get("user_api_key", None)
if user_api_key:
asyncio.run(update_prisma_database(token=user_api_key, response_cost=response_cost))
except Exception as e:
print(f"error in tracking cost callback - {str(e)}")
async def update_prisma_database(token, response_cost):
global prisma_client
try:
print(f"Enters prisma db call, token: {token}")
# Fetch the existing cost for the given token
existing_spend = await prisma_client.litellm_verificationtoken.find_unique(
where={
"token": token
}
)
print(f"existing spend: {existing_spend}")
# Calculate the new cost by adding the existing cost and response_cost
new_spend = existing_spend.spend + response_cost
print(f"new cost: {new_spend}")
# Update the cost column for the given token
await prisma_client.litellm_verificationtoken.update(
where={
"token": token
},
data={
"spend": new_spend
}
)
print(f"Prisma database updated for token {token}. New cost: {new_spend}")
except Exception as e:
print(f"Error updating Prisma database: {traceback.format_exc()}")
pass
def run_ollama_serve():
try:
command = ['ollama', 'serve']
@ -272,15 +357,8 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
### CONNECT TO DATABASE ###
database_url = general_settings.get("database_url", None)
prisma_setup(database_url=database_url)
## Cost Tracking for master key + auth setup ##
if master_key is not None:
if isinstance(litellm.success_callback, list):
import utils
print("setting litellm success callback to track cost")
if (utils.track_cost_callback) not in litellm.success_callback: # type: ignore
litellm.success_callback.append(utils.track_cost_callback) # type: ignore
else:
litellm.success_callback = utils.track_cost_callback # type: ignore
## COST TRACKING ##
cost_tracking()
### START REDIS QUEUE ###
use_queue = general_settings.get("use_queue", False)
celery_setup(use_queue=use_queue)
@ -386,12 +464,10 @@ async def delete_verification_token(tokens: List[str]):
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
return deleted_tokens
async def generate_key_cli_task(duration_str):
task = asyncio.create_task(generate_key_helper_fn(duration_str=duration_str))
await task
def save_worker_config(**data):
import json
os.environ["WORKER_CONFIG"] = json.dumps(data)
@ -487,7 +563,6 @@ def data_generator(response):
except:
yield f"data: {json.dumps(chunk)}\n\n"
def litellm_completion(*args, **kwargs):
global user_temperature, user_request_timeout, user_max_tokens, user_api_base
call_type = kwargs.pop("call_type")
@ -572,7 +647,7 @@ def model_list():
@router.post("/v1/completions", dependencies=[Depends(user_api_key_auth)])
@router.post("/completions", dependencies=[Depends(user_api_key_auth)])
@router.post("/engines/{model:path}/completions", dependencies=[Depends(user_api_key_auth)])
async def completion(request: Request, model: Optional[str] = None):
async def completion(request: Request, model: Optional[str] = None, user_api_key_dict: dict = Depends(user_api_key_auth)):
try:
body = await request.body()
body_str = body.decode()
@ -589,6 +664,7 @@ async def completion(request: Request, model: Optional[str] = None):
if user_model:
data["model"] = user_model
data["call_type"] = "text_completion"
data["metadata"] = {"user_api_key": user_api_key_dict["api_key"]}
return litellm_completion(
**data
)
@ -609,7 +685,7 @@ async def completion(request: Request, model: Optional[str] = None):
@router.post("/v1/chat/completions", dependencies=[Depends(user_api_key_auth)])
@router.post("/chat/completions", dependencies=[Depends(user_api_key_auth)])
@router.post("/openai/deployments/{model:path}/chat/completions", dependencies=[Depends(user_api_key_auth)]) # azure compatible endpoint
async def chat_completion(request: Request, model: Optional[str] = None):
async def chat_completion(request: Request, model: Optional[str] = None, user_api_key_dict: dict = Depends(user_api_key_auth)):
global general_settings
try:
body = await request.body()
@ -626,6 +702,7 @@ async def chat_completion(request: Request, model: Optional[str] = None):
or data["model"] # default passed in http request
)
data["call_type"] = "chat_completion"
data["metadata"] = {"user_api_key": user_api_key_dict["api_key"]}
return litellm_completion(
**data
)
@ -644,7 +721,7 @@ async def chat_completion(request: Request, model: Optional[str] = None):
@router.post("/v1/embeddings", dependencies=[Depends(user_api_key_auth)])
@router.post("/embeddings", dependencies=[Depends(user_api_key_auth)])
async def embeddings(request: Request):
async def embeddings(request: Request, user_api_key_dict: dict = Depends(user_api_key_auth)):
try:
data = await request.json()
print(f"data: {data}")
@ -655,7 +732,7 @@ async def embeddings(request: Request):
)
if user_model:
data["model"] = user_model
data["metadata"] = {"user_api_key": user_api_key_dict["api_key"]}
## ROUTE TO CORRECT ENDPOINT ##
router_model_names = [m["model_name"] for m in llm_model_list] if llm_model_list is not None else []
if llm_router is not None and data["model"] in router_model_names: # model in router model list

View file

@ -32,7 +32,40 @@ def track_cost_callback(
else:
# we pass the completion_response obj
if kwargs["stream"] != True:
response_cost = litellm.completion_cost(completion_response=completion_response)
input_text = kwargs.get("messages", "")
if isinstance(input_text, list):
input_text = "".join(m["content"] for m in input_text)
response_cost = litellm.completion_cost(completion_response=completion_response, completion=input_text)
print("regular response_cost", response_cost)
except:
pass
pass
def update_prisma_database(token, response_cost):
try:
# Import your Prisma client
from your_prisma_module import prisma
# Fetch the existing cost for the given token
existing_cost = prisma.LiteLLM_VerificationToken.find_unique(
where={
"token": token
}
).cost
# Calculate the new cost by adding the existing cost and response_cost
new_cost = existing_cost + response_cost
# Update the cost column for the given token
prisma_liteLLM_VerificationToken = prisma.LiteLLM_VerificationToken.update(
where={
"token": token
},
data={
"cost": new_cost
}
)
print(f"Prisma database updated for token {token}. New cost: {new_cost}")
except Exception as e:
print(f"Error updating Prisma database: {e}")
pass