forked from phoenix/litellm-mirror
* ci(config.yml): add a 'check_code_quality' step Addresses https://github.com/BerriAI/litellm/issues/5991 * ci(config.yml): check why circle ci doesn't pick up this test * ci(config.yml): fix to run 'check_code_quality' tests * fix(__init__.py): fix unprotected import * fix(__init__.py): don't remove unused imports * build(ruff.toml): update ruff.toml to ignore unused imports * fix: fix: ruff + pyright - fix linting + type-checking errors * fix: fix linting errors * fix(lago.py): fix module init error * fix: fix linting errors * ci(config.yml): cd into correct dir for checks * fix(proxy_server.py): fix linting error * fix(utils.py): fix bare except causes ruff linting errors * fix: ruff - fix remaining linting errors * fix(clickhouse.py): use standard logging object * fix(__init__.py): fix unprotected import * fix: ruff - fix linting errors * fix: fix linting errors * ci(config.yml): cleanup code qa step (formatting handled in local_testing) * fix(_health_endpoints.py): fix ruff linting errors * ci(config.yml): just use ruff in check_code_quality pipeline for now * build(custom_guardrail.py): include missing file * style(embedding_handler.py): fix ruff check
220 lines
8 KiB
Python
220 lines
8 KiB
Python
# +-----------------------------------------------+
# |                                               |
# |           NOT PROXY BUDGET MANAGER            |
# | proxy budget manager is in proxy_server.py    |
# |                                               |
# +-----------------------------------------------+
#
#  Thank you users! We ❤️ you! - Krrish & Ishaan
|
|
|
|
import json
import logging
import os
import threading
import time
from typing import Literal, Optional, Union

import requests  # type: ignore

import litellm
from litellm.utils import ModelResponse
|
|
|
|
|
|
class BudgetManager:
    """Track per-user spend against a budget and persist it.

    State lives in ``self.user_dict``, a mapping of user id to a record that
    may contain ``total_budget``, ``duration`` (in days), ``created_at``,
    ``last_updated_at``, ``current_cost`` and ``model_cost``.

    Persistence depends on ``client_type``:

    * ``"local"``  - read/write ``user_cost.json`` in the working directory.
    * ``"hosted"`` - POST to ``<api_base>/get_budget`` and ``/set_budget``.

    NOTE(review): any other ``client_type`` loads nothing, so ``user_dict``
    is never assigned by ``load_data`` and ``save_data`` is a no-op.
    """

    def __init__(
        self,
        project_name: str,
        client_type: str = "local",
        api_base: Optional[str] = None,
        headers: Optional[dict] = None,
    ):
        """Configure the manager and load any previously stored budgets.

        Args:
            project_name: Sent as ``project_name`` to the hosted endpoints.
            client_type: ``"local"`` or ``"hosted"`` storage backend.
            api_base: Hosted API base URL; falls back to
                ``https://api.litellm.ai`` when falsy.
            headers: HTTP headers for hosted requests; falls back to a JSON
                content-type header when falsy.
        """
        self.client_type = client_type
        self.project_name = project_name
        self.api_base = api_base or "https://api.litellm.ai"
        self.headers = headers or {"Content-Type": "application/json"}
        ## load the data or init the initial dictionaries
        self.load_data()

    def print_verbose(self, print_statement):
        """Log ``print_statement`` at INFO level when ``litellm.set_verbose`` is truthy."""
        try:
            if litellm.set_verbose:
                logging.info(print_statement)
        except Exception:
            # Diagnostics are deliberately best-effort: a logging failure
            # must never break budget tracking.
            pass

    def load_data(self):
        """Populate ``self.user_dict`` from the configured backend."""
        if self.client_type == "local":
            # Check if user dict file exists
            if os.path.isfile("user_cost.json"):
                # Load the user dict
                with open("user_cost.json", "r") as json_file:
                    self.user_dict = json.load(json_file)
            else:
                self.print_verbose("User Dictionary not found!")
                self.user_dict = {}
            self.print_verbose(f"user dict from local: {self.user_dict}")
        elif self.client_type == "hosted":
            # Load the user_dict from hosted db
            url = self.api_base + "/get_budget"
            data = {"project_name": self.project_name}
            response = requests.post(url, headers=self.headers, json=data)
            response = response.json()
            if response["status"] == "error":
                # assume this means the user dict hasn't been stored yet
                self.user_dict = {}
            else:
                self.user_dict = response["data"]

    def create_budget(
        self,
        total_budget: float,
        user: str,
        duration: Optional[Literal["daily", "weekly", "monthly", "yearly"]] = None,
        created_at: Optional[float] = None,
    ):
        """Create (or overwrite) the budget record for ``user``.

        Args:
            total_budget: Budget ceiling stored under ``total_budget``.
            user: Key of the record in ``user_dict``.
            duration: Optional reset period. Stored in days
                (daily=1, weekly=7, monthly=28, yearly=365).
            created_at: Epoch seconds used for ``created_at`` and
                ``last_updated_at``. Defaults to the time of the call.

        Returns:
            The newly stored record for ``user``.

        Raises:
            ValueError: If ``duration`` is not ``None`` or one of the four
                supported periods. No record is written in that case.
        """
        days_per_period = {"daily": 1, "weekly": 7, "monthly": 28, "yearly": 365}

        # Validate before touching user_dict so a bad duration cannot leave a
        # half-created budget behind.
        if duration is not None and duration not in days_per_period:
            raise ValueError(
                """duration needs to be one of ["daily", "weekly", "monthly", "yearly"]"""
            )

        if duration is None:
            # NOTE(review): this path only updates the in-memory record; it is
            # not persisted until a later call triggers a save.
            self.user_dict[user] = {"total_budget": total_budget}
            return self.user_dict[user]

        # Resolve the timestamp per call. A `time.time()` default in the
        # signature would be evaluated once, when the function is defined.
        if created_at is None:
            created_at = time.time()

        self.user_dict[user] = {
            "total_budget": total_budget,
            "duration": days_per_period[duration],
            "created_at": created_at,
            "last_updated_at": created_at,
        }
        self._save_data_thread()  # [Non-Blocking] Update persistent storage without blocking execution
        return self.user_dict[user]

    def projected_cost(self, model: str, messages: list, user: str):
        """Return the user's current cost plus the prompt cost of ``messages``.

        The ``content`` of every message is concatenated, token-counted with
        ``litellm.token_counter`` and priced with ``litellm.cost_per_token``
        (zero completion tokens).
        """
        text = "".join(message["content"] for message in messages)
        prompt_tokens = litellm.token_counter(model=model, text=text)
        prompt_cost, _ = litellm.cost_per_token(
            model=model, prompt_tokens=prompt_tokens, completion_tokens=0
        )
        current_cost = self.user_dict[user].get("current_cost", 0)
        return prompt_cost + current_cost

    def get_total_budget(self, user: str):
        """Return the ``total_budget`` stored for ``user`` (KeyError if absent)."""
        return self.user_dict[user]["total_budget"]

    def update_cost(
        self,
        user: str,
        completion_obj: Optional["ModelResponse"] = None,
        model: Optional[str] = None,
        input_text: Optional[str] = None,
        output_text: Optional[str] = None,
    ):
        """Add the cost of one request to the user's running totals.

        The cost comes from ``model`` + ``input_text`` + ``output_text`` when
        all three are truthy, otherwise from ``completion_obj``.

        Returns:
            ``{"user": <updated record>}``.

        Raises:
            ValueError: If neither the text triple nor a completion object
                was supplied.
        """
        if model and input_text and output_text:
            prompt_tokens = litellm.token_counter(
                model=model, messages=[{"role": "user", "content": input_text}]
            )
            completion_tokens = litellm.token_counter(
                model=model, messages=[{"role": "user", "content": output_text}]
            )
            (
                prompt_tokens_cost_usd_dollar,
                completion_tokens_cost_usd_dollar,
            ) = litellm.cost_per_token(
                model=model,
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
            )
            cost = prompt_tokens_cost_usd_dollar + completion_tokens_cost_usd_dollar
        elif completion_obj:
            cost = litellm.completion_cost(completion_response=completion_obj)
            # The per-model breakdown is keyed by the model named in the response.
            model = completion_obj["model"]
        else:
            raise ValueError(
                "Either a chat completion object or the text response needs to be passed in. Learn more - https://docs.litellm.ai/docs/budget_manager"
            )

        record = self.user_dict[user]
        record["current_cost"] = cost + record.get("current_cost", 0)

        # Accumulate the per-model total, creating the breakdown on first use.
        per_model = record.setdefault("model_cost", {})
        per_model[model] = cost + per_model.get(model, 0)

        self._save_data_thread()  # [Non-Blocking] Update persistent storage without blocking execution
        return {"user": self.user_dict[user]}

    def get_current_cost(self, user):
        """Return the user's accumulated cost, or 0 if none is recorded."""
        return self.user_dict[user].get("current_cost", 0)

    def get_model_cost(self, user):
        """Return the user's per-model cost mapping.

        NOTE(review): falls back to the integer 0 (not an empty dict) when no
        breakdown has been recorded yet.
        """
        return self.user_dict[user].get("model_cost", 0)

    def is_valid_user(self, user: str) -> bool:
        """Return True when ``user`` has a record in ``user_dict``."""
        return user in self.user_dict

    def get_users(self):
        """Return the list of user ids that have a record."""
        return list(self.user_dict.keys())

    def reset_cost(self, user):
        """Zero the user's running cost and per-model breakdown (in memory only)."""
        self.user_dict[user]["current_cost"] = 0
        self.user_dict[user]["model_cost"] = {}
        return {"user": self.user_dict[user]}

    def reset_on_duration(self, user: str):
        """Reset the user's cost if their budget period has elapsed.

        Requires the record to hold ``duration`` (days) and
        ``last_updated_at``; a record created without a duration raises
        ``KeyError``.
        """
        # Get current and last-reset time
        last_updated_at = self.user_dict[user]["last_updated_at"]
        current_time = time.time()

        # Convert duration from days to seconds
        duration_in_seconds = self.user_dict[user]["duration"] * 24 * 60 * 60

        # Check if duration has elapsed
        if current_time - last_updated_at >= duration_in_seconds:
            # Reset cost and start a new period from now
            self.reset_cost(user)
            self.user_dict[user]["last_updated_at"] = current_time
            self._save_data_thread()  # Save the data

    def update_budget_all_users(self):
        """Run ``reset_on_duration`` for every user whose record has a duration."""
        for user in self.get_users():
            if "duration" in self.user_dict[user]:
                self.reset_on_duration(user)

    def _save_data_thread(self):
        """Start a background thread that runs ``save_data``."""
        thread = threading.Thread(
            target=self.save_data
        )  # [Non-Blocking]: saves data without blocking execution
        thread.start()

    def save_data(self):
        """Persist ``user_dict`` to the configured backend.

        Returns:
            ``{"status": "success"}`` for local storage, the decoded JSON
            response for hosted storage, and ``None`` for any other
            ``client_type``.
        """
        if self.client_type == "local":
            # save the user dict
            with open("user_cost.json", "w") as json_file:
                json.dump(
                    self.user_dict, json_file, indent=4
                )  # Indent for pretty formatting
            return {"status": "success"}
        elif self.client_type == "hosted":
            url = self.api_base + "/set_budget"
            data = {"project_name": self.project_name, "user_dict": self.user_dict}
            response = requests.post(url, headers=self.headers, json=data)
            response = response.json()
            return response
|