feat(proxy_server.py): tracking spend per api key

Krrish Dholakia 2023-11-24 15:14:06 -08:00
parent 74dcf6c95d
commit 32cdd0a613
2 changed files with 132 additions and 22 deletions

proxy_server.py

@@ -149,14 +149,18 @@ api_key_header = APIKeyHeader(name="Authorization", auto_error=False)
 async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(api_key_header)):
     global master_key, prisma_client, llm_model_list
     if master_key is None:
-        return
+        return {
+            "api_key": None
+        }
     try:
         route = request.url.path
         # note: never string compare api keys, this is vulnerable to a timing attack. Use secrets.compare_digest instead
         is_master_key_valid = secrets.compare_digest(api_key, master_key) or secrets.compare_digest(api_key, "Bearer " + master_key)
         if is_master_key_valid:
-            return
+            return {
+                "api_key": master_key
+            }
         if (route == "/key/generate" or route == "/key/delete") and not is_master_key_valid:
             raise Exception(f"If master key is set, only master key can be used to generate new keys")
@@ -186,7 +190,9 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(api_key_header)):
                 llm_model_list = model_list
                 print("\n new llm router model list", llm_model_list)
             if len(valid_token.models) == 0: # assume an empty model list means all models are allowed to be called
-                return
+                return {
+                    "api_key": valid_token.token
+                }
             else:
                 data = await request.json()
                 model = data.get("model", None)
@@ -194,7 +200,9 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(api_key_header)):
                     model = litellm.model_alias_map[model]
                 if model and model not in valid_token.models:
                     raise Exception(f"Token not allowed to access model")
-                return
+                return {
+                    "api_key": valid_token.token
+                }
         else:
             raise Exception(f"Invalid token")
     except Exception as e:
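
Returning a dict here (instead of a bare return) is what lets the endpoints further down pull the caller's key back out through FastAPI's dependency injection. A minimal, standalone sketch of that pattern, with illustrative route and function names rather than the proxy's real ones:

from fastapi import Depends, FastAPI, Request

app = FastAPI()

async def fake_auth(request: Request) -> dict:
    # stand-in for user_api_key_auth: always resolve to a dict so callers can index it
    return {"api_key": "sk-example"}

@app.post("/demo/chat/completions")
async def demo_endpoint(request: Request, user_api_key_dict: dict = Depends(fake_auth)):
    # FastAPI resolves the dependency per request and injects the returned dict
    return {"metadata": {"user_api_key": user_api_key_dict["api_key"]}}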
@@ -231,6 +239,83 @@ def celery_setup(use_queue: bool):
         async_result = AsyncResult
         celery_app_conn = celery_app

+def cost_tracking():
+    global prisma_client, master_key
+    if prisma_client is not None and master_key is not None:
+        if isinstance(litellm.success_callback, list):
+            print("setting litellm success callback to track cost")
+            if (track_cost_callback) not in litellm.success_callback: # type: ignore
+                litellm.success_callback.append(track_cost_callback) # type: ignore
+        else:
+            litellm.success_callback = track_cost_callback # type: ignore
+
+def track_cost_callback(
+    kwargs,                                       # kwargs to completion
+    completion_response: litellm.ModelResponse,   # response from completion
+    start_time = None,
+    end_time = None,                              # start/end time for completion
+):
+    try:
+        # init logging config
+        print("in custom callback tracking cost", llm_model_list)
+        # check if it has collected an entire stream response
+        if "complete_streaming_response" in kwargs:
+            # for tracking streaming cost we pass the "messages" and the output_text to litellm.completion_cost
+            completion_response = kwargs["complete_streaming_response"]
+            input_text = kwargs["messages"]
+            output_text = completion_response["choices"][0]["message"]["content"]
+            response_cost = litellm.completion_cost(
+                model = kwargs["model"],
+                messages = input_text,
+                completion = output_text
+            )
+            print("streaming response_cost", response_cost)
+        # for non streaming responses
+        else:
+            # we pass the completion_response obj
+            if kwargs["stream"] != True:
+                input_text = kwargs.get("messages", "")
+                if isinstance(input_text, list):
+                    input_text = "".join(m["content"] for m in input_text)
+                response_cost = litellm.completion_cost(completion_response=completion_response, completion=input_text)
+                print("regular response_cost", response_cost)
+        print(f"metadata in kwargs: {kwargs}")
+        user_api_key = kwargs["litellm_params"]["metadata"].get("user_api_key", None)
+        if user_api_key:
+            asyncio.run(update_prisma_database(token=user_api_key, response_cost=response_cost))
+    except Exception as e:
+        print(f"error in tracking cost callback - {str(e)}")
+
+async def update_prisma_database(token, response_cost):
+    global prisma_client
+    try:
+        print(f"Enters prisma db call, token: {token}")
+        # Fetch the existing cost for the given token
+        existing_spend = await prisma_client.litellm_verificationtoken.find_unique(
+            where={
+                "token": token
+            }
+        )
+        print(f"existing spend: {existing_spend}")
+        # Calculate the new cost by adding the existing cost and response_cost
+        new_spend = existing_spend.spend + response_cost
+        print(f"new cost: {new_spend}")
+        # Update the cost column for the given token
+        await prisma_client.litellm_verificationtoken.update(
+            where={
+                "token": token
+            },
+            data={
+                "spend": new_spend
+            }
+        )
+        print(f"Prisma database updated for token {token}. New cost: {new_spend}")
+    except Exception as e:
+        print(f"Error updating Prisma database: {traceback.format_exc()}")
+        pass
+
 def run_ollama_serve():
     try:
         command = ['ollama', 'serve']
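
For reference, cost_tracking() relies on litellm's custom-callback contract: any callable placed in litellm.success_callback is invoked after a successful call with (kwargs, completion_response, start_time, end_time), and a metadata dict passed to litellm.completion surfaces under kwargs["litellm_params"]["metadata"], which is where track_cost_callback looks for user_api_key. A minimal sketch of that contract outside the proxy; the model name, key, and metadata values are placeholders, and an OpenAI key is assumed to be configured:

import litellm

def log_spend(kwargs, completion_response, start_time=None, end_time=None):
    # same signature that cost_tracking() registers for track_cost_callback
    metadata = kwargs.get("litellm_params", {}).get("metadata") or {}
    cost = litellm.completion_cost(completion_response=completion_response)
    print("user_api_key:", metadata.get("user_api_key"), "cost:", cost)

litellm.success_callback = [log_spend]

litellm.completion(
    model="gpt-3.5-turbo",  # placeholder; assumes OPENAI_API_KEY is set
    messages=[{"role": "user", "content": "hi"}],
    metadata={"user_api_key": "sk-example"},  # shows up in kwargs["litellm_params"]["metadata"]
)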
@@ -272,15 +357,8 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
         ### CONNECT TO DATABASE ###
         database_url = general_settings.get("database_url", None)
         prisma_setup(database_url=database_url)
-        ## Cost Tracking for master key + auth setup ##
-        if master_key is not None:
-            if isinstance(litellm.success_callback, list):
-                import utils
-                print("setting litellm success callback to track cost")
-                if (utils.track_cost_callback) not in litellm.success_callback: # type: ignore
-                    litellm.success_callback.append(utils.track_cost_callback) # type: ignore
-            else:
-                litellm.success_callback = utils.track_cost_callback # type: ignore
+        ## COST TRACKING ##
+        cost_tracking()
         ### START REDIS QUEUE ###
         use_queue = general_settings.get("use_queue", False)
         celery_setup(use_queue=use_queue)
@@ -386,12 +464,10 @@ async def delete_verification_token(tokens: List[str]):
         raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
     return deleted_tokens

 async def generate_key_cli_task(duration_str):
     task = asyncio.create_task(generate_key_helper_fn(duration_str=duration_str))
     await task

 def save_worker_config(**data):
     import json
     os.environ["WORKER_CONFIG"] = json.dumps(data)
@@ -487,7 +563,6 @@ def data_generator(response):
         except:
             yield f"data: {json.dumps(chunk)}\n\n"

 def litellm_completion(*args, **kwargs):
     global user_temperature, user_request_timeout, user_max_tokens, user_api_base
     call_type = kwargs.pop("call_type")
@@ -572,7 +647,7 @@ def model_list():
 @router.post("/v1/completions", dependencies=[Depends(user_api_key_auth)])
 @router.post("/completions", dependencies=[Depends(user_api_key_auth)])
 @router.post("/engines/{model:path}/completions", dependencies=[Depends(user_api_key_auth)])
-async def completion(request: Request, model: Optional[str] = None):
+async def completion(request: Request, model: Optional[str] = None, user_api_key_dict: dict = Depends(user_api_key_auth)):
     try:
         body = await request.body()
         body_str = body.decode()
@@ -589,6 +664,7 @@ async def completion(request: Request, model: Optional[str] = None):
         if user_model:
             data["model"] = user_model
         data["call_type"] = "text_completion"
+        data["metadata"] = {"user_api_key": user_api_key_dict["api_key"]}
         return litellm_completion(
             **data
         )
@@ -609,7 +685,7 @@ async def completion(request: Request, model: Optional[str] = None):
 @router.post("/v1/chat/completions", dependencies=[Depends(user_api_key_auth)])
 @router.post("/chat/completions", dependencies=[Depends(user_api_key_auth)])
 @router.post("/openai/deployments/{model:path}/chat/completions", dependencies=[Depends(user_api_key_auth)]) # azure compatible endpoint
-async def chat_completion(request: Request, model: Optional[str] = None):
+async def chat_completion(request: Request, model: Optional[str] = None, user_api_key_dict: dict = Depends(user_api_key_auth)):
     global general_settings
     try:
         body = await request.body()
@@ -626,6 +702,7 @@ async def chat_completion(request: Request, model: Optional[str] = None):
             or data["model"] # default passed in http request
         )
         data["call_type"] = "chat_completion"
+        data["metadata"] = {"user_api_key": user_api_key_dict["api_key"]}
         return litellm_completion(
             **data
         )
@@ -644,7 +721,7 @@ async def chat_completion(request: Request, model: Optional[str] = None):
 @router.post("/v1/embeddings", dependencies=[Depends(user_api_key_auth)])
 @router.post("/embeddings", dependencies=[Depends(user_api_key_auth)])
-async def embeddings(request: Request):
+async def embeddings(request: Request, user_api_key_dict: dict = Depends(user_api_key_auth)):
     try:
         data = await request.json()
         print(f"data: {data}")
@@ -655,7 +732,7 @@ async def embeddings(request: Request):
         )
         if user_model:
             data["model"] = user_model
+        data["metadata"] = {"user_api_key": user_api_key_dict["api_key"]}
         ## ROUTE TO CORRECT ENDPOINT ##
         router_model_names = [m["model_name"] for m in llm_model_list] if llm_model_list is not None else []
         if llm_router is not None and data["model"] in router_model_names: # model in router model list
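
Taken together, the proxy_server.py changes wire one request end to end: user_api_key_auth validates the Authorization header and returns the key, the endpoint copies it into data["metadata"], litellm_completion forwards it, and track_cost_callback reads it back and records the spend. A sketch of exercising that path against a running proxy; the base URL, port, and key are placeholders for whatever the deployment actually uses:

import requests

resp = requests.post(
    "http://0.0.0.0:8000/chat/completions",                   # placeholder proxy address
    headers={"Authorization": "Bearer sk-my-generated-key"},  # checked by user_api_key_auth
    json={
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "hello"}],
    },
    timeout=60,
)
print(resp.json())
# on success, track_cost_callback picks up user_api_key from litellm_params["metadata"]
# and update_prisma_database adds the computed cost to that token's spend column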

utils.py

@@ -32,7 +32,40 @@ def track_cost_callback(
         else:
             # we pass the completion_response obj
             if kwargs["stream"] != True:
-                response_cost = litellm.completion_cost(completion_response=completion_response)
+                input_text = kwargs.get("messages", "")
+                if isinstance(input_text, list):
+                    input_text = "".join(m["content"] for m in input_text)
+                response_cost = litellm.completion_cost(completion_response=completion_response, completion=input_text)
                 print("regular response_cost", response_cost)
     except:
         pass
+
+def update_prisma_database(token, response_cost):
+    try:
+        # Import your Prisma client
+        from your_prisma_module import prisma
+
+        # Fetch the existing cost for the given token
+        existing_cost = prisma.LiteLLM_VerificationToken.find_unique(
+            where={
+                "token": token
+            }
+        ).cost
+
+        # Calculate the new cost by adding the existing cost and response_cost
+        new_cost = existing_cost + response_cost
+
+        # Update the cost column for the given token
+        prisma_liteLLM_VerificationToken = prisma.LiteLLM_VerificationToken.update(
+            where={
+                "token": token
+            },
+            data={
+                "cost": new_cost
+            }
+        )
+        print(f"Prisma database updated for token {token}. New cost: {new_cost}")
+    except Exception as e:
+        print(f"Error updating Prisma database: {e}")
+        pass
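
One design note on the update logic in both files: it reads the current spend and then writes back the sum, so two requests finishing at the same time can both read the same value and one increment can be lost. If prisma_client is the generated prisma-client-py client, as the accessor style in the proxy_server.py hunk suggests, the same update could be expressed as a single atomic increment. A sketch under that assumption, not what this commit actually does:

async def update_spend_atomically(prisma_client, token: str, response_cost: float):
    # atomic add on the spend column; avoids the read-modify-write race described above
    await prisma_client.litellm_verificationtoken.update(
        where={"token": token},
        data={"spend": {"increment": response_cost}},
    )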