Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-24 10:14:26 +00:00)

feat(proxy_server.py): tracking spend per api key

Commit 4f22e7de18 (parent 2b52e6995c)
2 changed files with 132 additions and 22 deletions
@@ -149,14 +149,18 @@ api_key_header = APIKeyHeader(name="Authorization", auto_error=False)
 async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(api_key_header)):
     global master_key, prisma_client, llm_model_list
     if master_key is None:
-        return
+        return {
+            "api_key": None
+        }
     try:
         route = request.url.path

         # note: never string compare api keys, this is vulnerable to a timing attack. Use secrets.compare_digest instead
         is_master_key_valid = secrets.compare_digest(api_key, master_key) or secrets.compare_digest(api_key, "Bearer " + master_key)
         if is_master_key_valid:
-            return
+            return {
+                "api_key": master_key
+            }

         if (route == "/key/generate" or route == "/key/delete") and not is_master_key_valid:
             raise Exception(f"If master key is set, only master key can be used to generate new keys")
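The comment in this hunk is the motivation for `secrets.compare_digest`: a plain `==` on strings can return as soon as the first characters differ, so response latency leaks information about the key. A minimal, standalone illustration (not part of this commit; the key value is hypothetical):

```python
# Illustrative only. Constant-time comparison of a presented API key against
# a master key, using only the standard library.
import secrets

MASTER_KEY = "sk-1234"  # hypothetical key, for illustration only

def is_master_key(presented_key: str) -> bool:
    # compare_digest takes (nearly) the same time no matter where the strings
    # first differ, so an attacker cannot recover the key from timing.
    return secrets.compare_digest(presented_key, MASTER_KEY) or \
        secrets.compare_digest(presented_key, "Bearer " + MASTER_KEY)

print(is_master_key("Bearer sk-1234"))  # True
print(is_master_key("sk-9999"))         # False
```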
@@ -186,7 +190,9 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(api_key_header)):
                 llm_model_list = model_list
                 print("\n new llm router model list", llm_model_list)
             if len(valid_token.models) == 0: # assume an empty model list means all models are allowed to be called
-                return
+                return {
+                    "api_key": valid_token.token
+                }
             else:
                 data = await request.json()
                 model = data.get("model", None)
@@ -194,7 +200,9 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(api_key_header)):
                     model = litellm.model_alias_map[model]
                 if model and model not in valid_token.models:
                     raise Exception(f"Token not allowed to access model")
-            return
+            return {
+                "api_key": valid_token.token
+            }
         else:
             raise Exception(f"Invalid token")
     except Exception as e:
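The two hunks above enforce a per-token model allow-list: an empty `models` list on the token means any model may be called, otherwise the requested model (after alias resolution) must appear in the list. A standalone sketch of that rule; the alias map and token values here are hypothetical stand-ins, not the proxy's state:

```python
# Illustrative sketch of the allow-list check performed in user_api_key_auth.
model_alias_map = {"gpt-4-alias": "gpt-4"}

def is_model_allowed(requested_model: str, token_models: list) -> bool:
    if requested_model in model_alias_map:
        requested_model = model_alias_map[requested_model]
    # An empty list on the token means every model may be called.
    if len(token_models) == 0:
        return True
    return requested_model in token_models

print(is_model_allowed("gpt-4-alias", []))         # True: no restriction
print(is_model_allowed("gpt-4-alias", ["gpt-4"]))  # True: alias resolves first
print(is_model_allowed("claude-2", ["gpt-4"]))     # False: not on the allow-list
```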
@@ -231,6 +239,83 @@ def celery_setup(use_queue: bool):
         async_result = AsyncResult
         celery_app_conn = celery_app

+def cost_tracking():
+    global prisma_client, master_key
+    if prisma_client is not None and master_key is not None:
+        if isinstance(litellm.success_callback, list):
+            print("setting litellm success callback to track cost")
+            if (track_cost_callback) not in litellm.success_callback: # type: ignore
+                litellm.success_callback.append(track_cost_callback) # type: ignore
+        else:
+            litellm.success_callback = track_cost_callback # type: ignore
+
+def track_cost_callback(
+    kwargs,  # kwargs to completion
+    completion_response: litellm.ModelResponse,  # response from completion
+    start_time = None,
+    end_time = None,  # start/end time for completion
+):
+    try:
+        # init logging config
+        print("in custom callback tracking cost", llm_model_list)
+        # check if it has collected an entire stream response
+        if "complete_streaming_response" in kwargs:
+            # for tracking streaming cost we pass the "messages" and the output_text to litellm.completion_cost
+            completion_response = kwargs["complete_streaming_response"]
+            input_text = kwargs["messages"]
+            output_text = completion_response["choices"][0]["message"]["content"]
+            response_cost = litellm.completion_cost(
+                model = kwargs["model"],
+                messages = input_text,
+                completion = output_text
+            )
+            print("streaming response_cost", response_cost)
+        # for non streaming responses
+        else:
+            # we pass the completion_response obj
+            if kwargs["stream"] != True:
+                input_text = kwargs.get("messages", "")
+                if isinstance(input_text, list):
+                    input_text = "".join(m["content"] for m in input_text)
+                response_cost = litellm.completion_cost(completion_response=completion_response, completion=input_text)
+                print("regular response_cost", response_cost)
+        print(f"metadata in kwargs: {kwargs}")
+        user_api_key = kwargs["litellm_params"]["metadata"].get("user_api_key", None)
+        if user_api_key:
+            asyncio.run(update_prisma_database(token=user_api_key, response_cost=response_cost))
+    except Exception as e:
+        print(f"error in tracking cost callback - {str(e)}")
+
+async def update_prisma_database(token, response_cost):
+    global prisma_client
+    try:
+        print(f"Enters prisma db call, token: {token}")
+        # Fetch the existing cost for the given token
+        existing_spend = await prisma_client.litellm_verificationtoken.find_unique(
+            where={
+                "token": token
+            }
+        )
+        print(f"existing spend: {existing_spend}")
+        # Calculate the new cost by adding the existing cost and response_cost
+        new_spend = existing_spend.spend + response_cost
+
+        print(f"new cost: {new_spend}")
+        # Update the cost column for the given token
+        await prisma_client.litellm_verificationtoken.update(
+            where={
+                "token": token
+            },
+            data={
+                "spend": new_spend
+            }
+        )
+        print(f"Prisma database updated for token {token}. New cost: {new_spend}")
+
+    except Exception as e:
+        print(f"Error updating Prisma database: {traceback.format_exc()}")
+        pass
+
 def run_ollama_serve():
     try:
         command = ['ollama', 'serve']
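`cost_tracking()` wires `track_cost_callback` into litellm's success-callback list, so it fires after every successful completion and can attribute the estimated spend to the calling key. A minimal, hedged sketch of that hook in isolation (it assumes a provider key is configured; the model name and metadata value are placeholders, not the proxy's configuration):

```python
# Minimal sketch of litellm's success-callback hook used above.
import litellm

def my_track_cost_callback(kwargs, completion_response, start_time, end_time):
    # completion_cost estimates spend from the model and token usage.
    cost = litellm.completion_cost(completion_response=completion_response)
    metadata = (kwargs.get("litellm_params") or {}).get("metadata") or {}
    print(f"key={metadata.get('user_api_key')} cost=${cost}")

litellm.success_callback = [my_track_cost_callback]

litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hello"}],
    metadata={"user_api_key": "example-token"},  # surfaces in the callback's kwargs
)
```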
@@ -272,15 +357,8 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
         ### CONNECT TO DATABASE ###
         database_url = general_settings.get("database_url", None)
         prisma_setup(database_url=database_url)
-        ## Cost Tracking for master key + auth setup ##
-        if master_key is not None:
-            if isinstance(litellm.success_callback, list):
-                import utils
-                print("setting litellm success callback to track cost")
-                if (utils.track_cost_callback) not in litellm.success_callback: # type: ignore
-                    litellm.success_callback.append(utils.track_cost_callback) # type: ignore
-            else:
-                litellm.success_callback = utils.track_cost_callback # type: ignore
+        ## COST TRACKING ##
+        cost_tracking()
         ### START REDIS QUEUE ###
         use_queue = general_settings.get("use_queue", False)
         celery_setup(use_queue=use_queue)
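This hunk replaces the inline callback wiring with the new `cost_tracking()` helper. The membership check inside that helper keeps registration idempotent if the config-loading path runs more than once. Illustrative only, with hypothetical names:

```python
# Why cost_tracking() checks membership before appending: re-running the
# registration must not attach the same callback twice.
success_callbacks = []

def track_cost_callback(*args, **kwargs):
    pass  # placeholder body

def register_cost_tracking():
    if track_cost_callback not in success_callbacks:
        success_callbacks.append(track_cost_callback)

register_cost_tracking()
register_cost_tracking()  # second call is a no-op
print(len(success_callbacks))  # -> 1
```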
@@ -386,12 +464,10 @@ async def delete_verification_token(tokens: List[str]):
         raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
     return deleted_tokens


 async def generate_key_cli_task(duration_str):
     task = asyncio.create_task(generate_key_helper_fn(duration_str=duration_str))
     await task


 def save_worker_config(**data):
     import json
     os.environ["WORKER_CONFIG"] = json.dumps(data)
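`save_worker_config` serializes the CLI options into a single environment variable so a separately spawned worker process can pick them up. A small round-trip sketch; the reading side is an assumed counterpart added here for illustration and is not part of this commit:

```python
# Sketch of the WORKER_CONFIG round trip via an environment variable.
import json
import os

def save_worker_config(**data):
    os.environ["WORKER_CONFIG"] = json.dumps(data)

def load_worker_config() -> dict:
    return json.loads(os.environ.get("WORKER_CONFIG", "{}"))

save_worker_config(model="gpt-3.5-turbo", use_queue=False)
print(load_worker_config())  # {'model': 'gpt-3.5-turbo', 'use_queue': False}
```

Note that environment variables set this way are only visible to processes spawned afterwards, which fits the "set config, then start workers" flow.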
@@ -487,7 +563,6 @@ def data_generator(response):
         except:
             yield f"data: {json.dumps(chunk)}\n\n"


 def litellm_completion(*args, **kwargs):
     global user_temperature, user_request_timeout, user_max_tokens, user_api_base
     call_type = kwargs.pop("call_type")
@@ -572,7 +647,7 @@ def model_list():
 @router.post("/v1/completions", dependencies=[Depends(user_api_key_auth)])
 @router.post("/completions", dependencies=[Depends(user_api_key_auth)])
 @router.post("/engines/{model:path}/completions", dependencies=[Depends(user_api_key_auth)])
-async def completion(request: Request, model: Optional[str] = None):
+async def completion(request: Request, model: Optional[str] = None, user_api_key_dict: dict = Depends(user_api_key_auth)):
     try:
         body = await request.body()
         body_str = body.decode()
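The endpoints keep `user_api_key_auth` as a route dependency and additionally receive its return value through a parameter, so the handler can attach the caller's key to the request metadata. A standalone FastAPI sketch of that pattern; the app, dependency body, and route below are illustrative, not the proxy's code:

```python
# A dependency that returns a dict, injected into the handler as a parameter.
from fastapi import Depends, FastAPI, Request

app = FastAPI()

async def user_api_key_auth(request: Request) -> dict:
    # The real dependency validates the key; here we just echo the header.
    return {"api_key": request.headers.get("Authorization")}

@app.post("/chat/completions", dependencies=[Depends(user_api_key_auth)])
async def chat_completion(request: Request, user_api_key_dict: dict = Depends(user_api_key_auth)):
    # FastAPI caches dependency results within a request, so listing the
    # dependency both places does not run it twice.
    return {"caller_key": user_api_key_dict["api_key"]}
```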
@@ -589,6 +664,7 @@ async def completion(request: Request, model: Optional[str] = None):
         if user_model:
             data["model"] = user_model
         data["call_type"] = "text_completion"
+        data["metadata"] = {"user_api_key": user_api_key_dict["api_key"]}
         return litellm_completion(
             **data
         )
@@ -609,7 +685,7 @@ async def completion(request: Request, model: Optional[str] = None):
 @router.post("/v1/chat/completions", dependencies=[Depends(user_api_key_auth)])
 @router.post("/chat/completions", dependencies=[Depends(user_api_key_auth)])
 @router.post("/openai/deployments/{model:path}/chat/completions", dependencies=[Depends(user_api_key_auth)]) # azure compatible endpoint
-async def chat_completion(request: Request, model: Optional[str] = None):
+async def chat_completion(request: Request, model: Optional[str] = None, user_api_key_dict: dict = Depends(user_api_key_auth)):
     global general_settings
     try:
         body = await request.body()
@@ -626,6 +702,7 @@ async def chat_completion(request: Request, model: Optional[str] = None):
             or data["model"] # default passed in http request
         )
         data["call_type"] = "chat_completion"
+        data["metadata"] = {"user_api_key": user_api_key_dict["api_key"]}
         return litellm_completion(
             **data
         )
@@ -644,7 +721,7 @@ async def chat_completion(request: Request, model: Optional[str] = None):

 @router.post("/v1/embeddings", dependencies=[Depends(user_api_key_auth)])
 @router.post("/embeddings", dependencies=[Depends(user_api_key_auth)])
-async def embeddings(request: Request):
+async def embeddings(request: Request, user_api_key_dict: dict = Depends(user_api_key_auth)):
     try:
         data = await request.json()
         print(f"data: {data}")
@@ -655,7 +732,7 @@ async def embeddings(request: Request):
         )
         if user_model:
             data["model"] = user_model
-
+        data["metadata"] = {"user_api_key": user_api_key_dict["api_key"]}
         ## ROUTE TO CORRECT ENDPOINT ##
         router_model_names = [m["model_name"] for m in llm_model_list] if llm_model_list is not None else []
         if llm_router is not None and data["model"] in router_model_names: # model in router model list
Second changed file:

@@ -32,7 +32,40 @@ def track_cost_callback(
         else:
             # we pass the completion_response obj
             if kwargs["stream"] != True:
-                response_cost = litellm.completion_cost(completion_response=completion_response)
+                input_text = kwargs.get("messages", "")
+                if isinstance(input_text, list):
+                    input_text = "".join(m["content"] for m in input_text)
+                response_cost = litellm.completion_cost(completion_response=completion_response, completion=input_text)
                 print("regular response_cost", response_cost)
     except:
         pass
     pass

+def update_prisma_database(token, response_cost):
+    try:
+        # Import your Prisma client
+        from your_prisma_module import prisma
+
+        # Fetch the existing cost for the given token
+        existing_cost = prisma.LiteLLM_VerificationToken.find_unique(
+            where={
+                "token": token
+            }
+        ).cost
+
+        # Calculate the new cost by adding the existing cost and response_cost
+        new_cost = existing_cost + response_cost
+
+        # Update the cost column for the given token
+        prisma_liteLLM_VerificationToken = prisma.LiteLLM_VerificationToken.update(
+            where={
+                "token": token
+            },
+            data={
+                "cost": new_cost
+            }
+        )
+        print(f"Prisma database updated for token {token}. New cost: {new_cost}")
+
+    except Exception as e:
+        print(f"Error updating Prisma database: {e}")
+        pass