Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 19:24:27 +00:00)

feat(proxy_server.py): tracking spend per api key

Parent: 74dcf6c95d
Commit: 32cdd0a613
2 changed files with 132 additions and 22 deletions

litellm/proxy/proxy_server.py
@@ -149,14 +149,18 @@ api_key_header = APIKeyHeader(name="Authorization", auto_error=False)
 async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(api_key_header)):
     global master_key, prisma_client, llm_model_list
     if master_key is None:
-        return
+        return {
+            "api_key": None
+        }
     try:
         route = request.url.path

         # note: never string compare api keys, this is vulenerable to a time attack. Use secrets.compare_digest instead
         is_master_key_valid = secrets.compare_digest(api_key, master_key) or secrets.compare_digest(api_key, "Bearer " + master_key)
         if is_master_key_valid:
-            return
+            return {
+                "api_key": master_key
+            }

         if (route == "/key/generate" or route == "/key/delete") and not is_master_key_valid:
             raise Exception(f"If master key is set, only master key can be used to generate new keys")
@@ -186,7 +190,9 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(api_key_header)):
             llm_model_list = model_list
             print("\n new llm router model list", llm_model_list)
             if len(valid_token.models) == 0: # assume an empty model list means all models are allowed to be called
-                return
+                return {
+                    "api_key": valid_token.token
+                }
             else:
                 data = await request.json()
                 model = data.get("model", None)
@@ -194,7 +200,9 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(api_key_header)):
                     model = litellm.model_alias_map[model]
                 if model and model not in valid_token.models:
                     raise Exception(f"Token not allowed to access model")
-            return
+            return {
+                "api_key": valid_token.token
+            }
         else:
             raise Exception(f"Invalid token")
     except Exception as e:
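Note: the three hunks above change user_api_key_auth from returning nothing to returning a dict. This matters because FastAPI hands a dependency's return value to any endpoint parameter declared with Depends, which is how the endpoints later in this diff receive the caller's key. A minimal sketch of that pattern (the route and names here are illustrative, not litellm's):

from fastapi import Depends, FastAPI, Request

app = FastAPI()

async def toy_api_key_auth(request: Request) -> dict:
    # Whatever a dependency returns is injected into any endpoint
    # parameter declared as Depends(toy_api_key_auth).
    return {"api_key": request.headers.get("Authorization")}

@app.post("/demo")
async def demo(user_api_key_dict: dict = Depends(toy_api_key_auth)):
    # The endpoint can now attach the caller's key to the request's
    # metadata, which is how spend gets attributed per key below.
    return {"seen_key": user_api_key_dict["api_key"]}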
@@ -231,6 +239,83 @@ def celery_setup(use_queue: bool):
         async_result = AsyncResult
         celery_app_conn = celery_app

+def cost_tracking():
+    global prisma_client, master_key
+    if prisma_client is not None and master_key is not None:
+        if isinstance(litellm.success_callback, list):
+            print("setting litellm success callback to track cost")
+            if (track_cost_callback) not in litellm.success_callback: # type: ignore
+                litellm.success_callback.append(track_cost_callback) # type: ignore
+        else:
+            litellm.success_callback = track_cost_callback # type: ignore
+
+def track_cost_callback(
+    kwargs, # kwargs to completion
+    completion_response: litellm.ModelResponse, # response from completion
+    start_time = None,
+    end_time = None, # start/end time for completion
+):
+    try:
+        # init logging config
+        print("in custom callback tracking cost", llm_model_list)
+        # check if it has collected an entire stream response
+        if "complete_streaming_response" in kwargs:
+            # for tracking streaming cost we pass the "messages" and the output_text to litellm.completion_cost
+            completion_response=kwargs["complete_streaming_response"]
+            input_text = kwargs["messages"]
+            output_text = completion_response["choices"][0]["message"]["content"]
+            response_cost = litellm.completion_cost(
+                model = kwargs["model"],
+                messages = input_text,
+                completion=output_text
+            )
+            print("streaming response_cost", response_cost)
+        # for non streaming responses
+        else:
+            # we pass the completion_response obj
+            if kwargs["stream"] != True:
+                input_text = kwargs.get("messages", "")
+                if isinstance(input_text, list):
+                    input_text = "".join(m["content"] for m in input_text)
+                response_cost = litellm.completion_cost(completion_response=completion_response, completion=input_text)
+                print("regular response_cost", response_cost)
+        print(f"metadata in kwargs: {kwargs}")
+        user_api_key = kwargs["litellm_params"]["metadata"].get("user_api_key", None)
+        if user_api_key:
+            asyncio.run(update_prisma_database(token=user_api_key, response_cost=response_cost))
+    except Exception as e:
+        print(f"error in tracking cost callback - {str(e)}")
+
+async def update_prisma_database(token, response_cost):
+    global prisma_client
+    try:
+        print(f"Enters prisma db call, token: {token}")
+        # Fetch the existing cost for the given token
+        existing_spend = await prisma_client.litellm_verificationtoken.find_unique(
+            where={
+                "token": token
+            }
+        )
+        print(f"existing spend: {existing_spend}")
+        # Calculate the new cost by adding the existing cost and response_cost
+        new_spend = existing_spend.spend + response_cost
+
+        print(f"new cost: {new_spend}")
+        # Update the cost column for the given token
+        await prisma_client.litellm_verificationtoken.update(
+            where={
+                "token": token
+            },
+            data={
+                "spend": new_spend
+            }
+        )
+        print(f"Prisma database updated for token {token}. New cost: {new_spend}")
+
+    except Exception as e:
+        print(f"Error updating Prisma database: {traceback.format_exc()}")
+        pass
+
 def run_ollama_serve():
     try:
         command = ['ollama', 'serve']
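Note: the new track_cost_callback leans on two litellm facilities: completion_cost, which can price either a full response object or raw message/completion text, and success_callback, whose entries are invoked after each successful call with (kwargs, completion_response, start_time, end_time). A hedged sketch of both (the model name and texts are placeholders):

import litellm

# Price a call from raw text; the streaming branch above does this
# because a streamed response is reassembled from chunks.
cost = litellm.completion_cost(
    model="gpt-3.5-turbo",  # placeholder model
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
    completion="I'm doing well, thank you!",
)
print(f"estimated cost: ${cost:.6f}")

# Register a custom success callback; litellm calls it with
# (kwargs, completion_response, start_time, end_time).
def log_cost(kwargs, completion_response, start_time, end_time):
    print("cost:", litellm.completion_cost(completion_response=completion_response))

litellm.success_callback = [log_cost]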
@@ -272,15 +357,8 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
         ### CONNECT TO DATABASE ###
         database_url = general_settings.get("database_url", None)
         prisma_setup(database_url=database_url)
-        ## Cost Tracking for master key + auth setup ##
-        if master_key is not None:
-            if isinstance(litellm.success_callback, list):
-                import utils
-                print("setting litellm success callback to track cost")
-                if (utils.track_cost_callback) not in litellm.success_callback: # type: ignore
-                    litellm.success_callback.append(utils.track_cost_callback) # type: ignore
-            else:
-                litellm.success_callback = utils.track_cost_callback # type: ignore
+        ## COST TRACKING ##
+        cost_tracking()
         ### START REDIS QUEUE ###
         use_queue = general_settings.get("use_queue", False)
         celery_setup(use_queue=use_queue)
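Note: the refactor above collapses the inline callback wiring into a single cost_tracking() call. The design point is idempotent registration: reloading the config must not append the callback twice, or each response cost would be recorded twice. A minimal sketch of that guard, assuming the same litellm.success_callback list:

import litellm

def register_once(callback) -> None:
    # Append only if absent, so repeated config loads do not stack
    # duplicate callbacks (each duplicate would record the same
    # response cost again).
    if isinstance(litellm.success_callback, list):
        if callback not in litellm.success_callback:
            litellm.success_callback.append(callback)
    else:
        litellm.success_callback = [callback]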
@@ -386,12 +464,10 @@ async def delete_verification_token(tokens: List[str]):
         raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
     return deleted_tokens

-
 async def generate_key_cli_task(duration_str):
     task = asyncio.create_task(generate_key_helper_fn(duration_str=duration_str))
     await task

-
 def save_worker_config(**data):
     import json
     os.environ["WORKER_CONFIG"] = json.dumps(data)
@@ -487,7 +563,6 @@ def data_generator(response):
         except:
             yield f"data: {json.dumps(chunk)}\n\n"

-
 def litellm_completion(*args, **kwargs):
     global user_temperature, user_request_timeout, user_max_tokens, user_api_base
     call_type = kwargs.pop("call_type")
@@ -572,7 +647,7 @@ def model_list():
 @router.post("/v1/completions", dependencies=[Depends(user_api_key_auth)])
 @router.post("/completions", dependencies=[Depends(user_api_key_auth)])
 @router.post("/engines/{model:path}/completions", dependencies=[Depends(user_api_key_auth)])
-async def completion(request: Request, model: Optional[str] = None):
+async def completion(request: Request, model: Optional[str] = None, user_api_key_dict: dict = Depends(user_api_key_auth)):
     try:
         body = await request.body()
         body_str = body.decode()
@@ -589,6 +664,7 @@ async def completion(request: Request, model: Optional[str] = None):
         if user_model:
             data["model"] = user_model
         data["call_type"] = "text_completion"
+        data["metadata"] = {"user_api_key": user_api_key_dict["api_key"]}
         return litellm_completion(
             **data
         )
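Note: the added data["metadata"] line is the link between the endpoint and the cost callback: litellm forwards a metadata argument into kwargs["litellm_params"]["metadata"], where track_cost_callback reads user_api_key back out. A sketch of that round trip, assuming a configured provider key (the model and key values are placeholders):

import litellm

def read_metadata(kwargs, completion_response, start_time, end_time):
    # litellm surfaces caller-supplied metadata under litellm_params.
    metadata = kwargs["litellm_params"].get("metadata") or {}
    print("spend belongs to key:", metadata.get("user_api_key"))

litellm.success_callback = [read_metadata]

litellm.completion(
    model="gpt-3.5-turbo",                     # placeholder model
    messages=[{"role": "user", "content": "hi"}],
    metadata={"user_api_key": "sk-fake-key"},  # placeholder key
)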
@@ -609,7 +685,7 @@ async def completion(request: Request, model: Optional[str] = None):
 @router.post("/v1/chat/completions", dependencies=[Depends(user_api_key_auth)])
 @router.post("/chat/completions", dependencies=[Depends(user_api_key_auth)])
 @router.post("/openai/deployments/{model:path}/chat/completions", dependencies=[Depends(user_api_key_auth)]) # azure compatible endpoint
-async def chat_completion(request: Request, model: Optional[str] = None):
+async def chat_completion(request: Request, model: Optional[str] = None, user_api_key_dict: dict = Depends(user_api_key_auth)):
     global general_settings
     try:
         body = await request.body()
@@ -626,6 +702,7 @@ async def chat_completion(request: Request, model: Optional[str] = None):
             or data["model"] # default passed in http request
         )
         data["call_type"] = "chat_completion"
+        data["metadata"] = {"user_api_key": user_api_key_dict["api_key"]}
         return litellm_completion(
             **data
         )
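Note: with completion and chat_completion both stamping the caller's key into metadata, any authenticated request has its cost written back to that key's row. A hedged client-side example; the host, port, and key are assumptions:

import requests

resp = requests.post(
    "http://0.0.0.0:8000/chat/completions",       # assumed local proxy address
    headers={"Authorization": "Bearer sk-1234"},  # a key minted via /key/generate
    json={
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "what llm are you"}],
    },
)
print(resp.json())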
@@ -644,7 +721,7 @@ async def chat_completion(request: Request, model: Optional[str] = None):

 @router.post("/v1/embeddings", dependencies=[Depends(user_api_key_auth)])
 @router.post("/embeddings", dependencies=[Depends(user_api_key_auth)])
-async def embeddings(request: Request):
+async def embeddings(request: Request, user_api_key_dict: dict = Depends(user_api_key_auth)):
     try:
         data = await request.json()
         print(f"data: {data}")
@@ -655,7 +732,7 @@ async def embeddings(request: Request):
         )
         if user_model:
             data["model"] = user_model
+        data["metadata"] = {"user_api_key": user_api_key_dict["api_key"]}
         ## ROUTE TO CORRECT ENDPOINT ##
         router_model_names = [m["model_name"] for m in llm_model_list] if llm_model_list is not None else []
         if llm_router is not None and data["model"] in router_model_names: # model in router model list
(second changed file)
@@ -32,7 +32,40 @@ def track_cost_callback(
         else:
             # we pass the completion_response obj
             if kwargs["stream"] != True:
-                response_cost = litellm.completion_cost(completion_response=completion_response)
+                input_text = kwargs.get("messages", "")
+                if isinstance(input_text, list):
+                    input_text = "".join(m["content"] for m in input_text)
+                response_cost = litellm.completion_cost(completion_response=completion_response, completion=input_text)
                 print("regular response_cost", response_cost)
     except:
         pass
+
+def update_prisma_database(token, response_cost):
+    try:
+        # Import your Prisma client
+        from your_prisma_module import prisma
+
+        # Fetch the existing cost for the given token
+        existing_cost = prisma.LiteLLM_VerificationToken.find_unique(
+            where={
+                "token": token
+            }
+        ).cost
+
+        # Calculate the new cost by adding the existing cost and response_cost
+        new_cost = existing_cost + response_cost
+
+        # Update the cost column for the given token
+        prisma_liteLLM_VerificationToken = prisma.LiteLLM_VerificationToken.update(
+            where={
+                "token": token
+            },
+            data={
+                "cost": new_cost
+            }
+        )
+        print(f"Prisma database updated for token {token}. New cost: {new_cost}")
+
+    except Exception as e:
+        print(f"Error updating Prisma database: {e}")
+        pass
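Note: one caveat with the read-then-write pattern in update_prisma_database: two concurrent requests can read the same spend, and the later write silently drops one cost. Prisma's atomic number operations sidestep this; a sketch against the same table, assuming the Prisma Client Python API used in the first file:

async def increment_spend(prisma_client, token: str, response_cost: float) -> None:
    # Let the database do the addition atomically instead of fetching
    # spend, adding in Python, and writing the sum back.
    await prisma_client.litellm_verificationtoken.update(
        where={"token": token},
        data={"spend": {"increment": response_cost}},
    )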