forked from phoenix/litellm-mirror

commit 7d3244c012: Merge branch 'main' into litellm_auto_create_user_fix

60 changed files with 1237 additions and 215 deletions
@@ -97,7 +97,6 @@ from litellm.proxy.utils import (
     _is_projected_spend_over_limit,
     _get_projected_spend_over_limit,
     update_spend,
-    monitor_spend_list,
 )
 from litellm.proxy.secret_managers.google_kms import load_google_kms
 from litellm.proxy.secret_managers.aws_secret_manager import load_aws_secret_manager
@@ -118,6 +117,7 @@ from litellm.proxy.auth.auth_checks import (
     allowed_routes_check,
     get_actual_routes,
 )
+from litellm.llms.custom_httpx.httpx_handler import HTTPHandler

 try:
     from litellm._version import version
@@ -130,7 +130,6 @@ from fastapi import (
     HTTPException,
     status,
     Depends,
-    BackgroundTasks,
     Header,
     Response,
     Form,
@@ -305,6 +304,8 @@ proxy_logging_obj = ProxyLogging(user_api_key_cache=user_api_key_cache)
 async_result = None
 celery_app_conn = None
 celery_fn = None # Redis Queue for handling requests
+### DB WRITER ###
+db_writer_client: Optional[HTTPHandler] = None
 ### logger ###

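This commit's central change is a dedicated async HTTP client for spend writes: instead of each request handler writing to the database inline, updates are buffered and shipped by `db_writer_client`. A minimal sketch of the pattern, assuming a thin httpx wrapper in the spirit of litellm's HTTPHandler (the `SimpleHTTPHandler` name, the buffer, and `flush_spend_logs` below are illustrative, not litellm's actual implementation):

```python
import os
from typing import Any, Dict, List

import httpx


class SimpleHTTPHandler:
    """Illustrative stand-in for litellm's HTTPHandler: one shared async client."""

    def __init__(self) -> None:
        self.client = httpx.AsyncClient()

    async def post(self, url: str, json: Any = None) -> httpx.Response:
        return await self.client.post(url, json=json)

    async def close(self) -> None:
        await self.client.aclose()


# Hypothetical buffer, mirroring prisma_client.spend_log_transactons in the diff.
spend_log_buffer: List[Dict[str, Any]] = []


async def flush_spend_logs(writer: SimpleHTTPHandler) -> None:
    """Drain the buffer with one HTTP call instead of one DB write per request."""
    if not spend_log_buffer:
        return
    batch = spend_log_buffer.copy()
    spend_log_buffer.clear()
    await writer.post(os.environ["SPEND_LOGS_URL"], json=batch)
```

The startup and shutdown hunks later in this diff show the client being constructed and closed around the app lifecycle.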
@@ -440,6 +441,8 @@ async def user_api_key_auth(
                    request_body=request_data,
                    team_object=team_object,
                    end_user_object=end_user_object,
+                   general_settings=general_settings,
+                   route=route,
                )
                # save user object in cache
                await user_api_key_cache.async_set_cache(
@@ -867,6 +870,23 @@ async def user_api_key_auth(
                    f"ExceededTokenBudget: Current Team Spend: {valid_token.team_spend}; Max Budget for Team: {valid_token.team_max_budget}"
                )

+           # Check 8: Additional Common Checks across jwt + key auth
+           _team_obj = LiteLLM_TeamTable(
+               team_id=valid_token.team_id,
+               max_budget=valid_token.team_max_budget,
+               spend=valid_token.team_spend,
+               tpm_limit=valid_token.team_tpm_limit,
+               rpm_limit=valid_token.team_rpm_limit,
+               blocked=valid_token.team_blocked,
+               models=valid_token.team_models,
+           )
+           _ = common_checks(
+               request_body=request_data,
+               team_object=_team_obj,
+               end_user_object=None,
+               general_settings=general_settings,
+               route=route,
+           )
            # Token passed all checks
            api_key = valid_token.token

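Check 8 routes key-based auth through the same `common_checks` already used by JWT auth, by packing the token's team fields into a `LiteLLM_TeamTable`. The field names below mirror the diff; the check body itself is a hedged sketch of what a consolidated team check could enforce, not litellm's exact logic:

```python
from typing import List, Optional

from pydantic import BaseModel


class TeamObjectSketch(BaseModel):
    """Fields mirror the LiteLLM_TeamTable construction in the diff."""

    team_id: Optional[str] = None
    max_budget: Optional[float] = None
    spend: float = 0.0
    tpm_limit: Optional[int] = None
    rpm_limit: Optional[int] = None
    blocked: bool = False
    models: List[str] = []


def common_checks_sketch(request_body: dict, team_object: TeamObjectSketch) -> bool:
    """Illustrative shared checks across jwt + key auth (assumed behavior)."""
    if team_object.blocked:
        raise Exception(f"Team={team_object.team_id} is blocked.")
    if team_object.max_budget is not None and team_object.spend >= team_object.max_budget:
        raise Exception(
            f"ExceededTokenBudget: spend={team_object.spend}; max budget={team_object.max_budget}"
        )
    model = request_body.get("model")
    # Assumption: an empty model list means every model is allowed.
    if team_object.models and model not in team_object.models:
        raise Exception(f"Model={model} not in team models.")
    return True
```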
@@ -1233,10 +1253,11 @@ async def update_database(
            user_ids.append(litellm_proxy_budget_name)
            ### KEY CHANGE ###
            for _id in user_ids:
-               prisma_client.user_list_transactons[_id] = (
-                   response_cost
-                   + prisma_client.user_list_transactons.get(_id, 0)
-               )
+               if _id is not None:
+                   prisma_client.user_list_transactons[_id] = (
+                       response_cost
+                       + prisma_client.user_list_transactons.get(_id, 0)
+                   )
            if end_user_id is not None:
                prisma_client.end_user_list_transactons[end_user_id] = (
                    response_cost
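The key change buffers spend per user in an in-memory dict so the scheduled `update_spend` job can issue one aggregated write per user, and the new `if _id is not None` guard keeps requests without a user id from creating a spurious bucket. The accumulation idiom in isolation (the variable is spelled `user_list_transactons` as in the source):

```python
from typing import Dict, Optional

user_list_transactons: Dict[str, float] = {}  # spelled as in the litellm source


def record_spend(user_id: Optional[str], response_cost: float) -> None:
    # The None-guard added by this commit: skip requests with no user id.
    if user_id is not None:
        user_list_transactons[user_id] = (
            response_cost + user_list_transactons.get(user_id, 0)
        )


record_spend("user-a", 0.25)
record_spend("user-a", 0.50)
record_spend(None, 0.10)  # ignored instead of keying on None
assert user_list_transactons == {"user-a": 0.75}
```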
@@ -1364,7 +1385,16 @@ async def update_database(
                )

            payload["spend"] = response_cost
-           if prisma_client is not None:
+           if (
+               os.getenv("SPEND_LOGS_URL", None) is not None
+               and prisma_client is not None
+           ):
+               if isinstance(payload["startTime"], datetime):
+                   payload["startTime"] = payload["startTime"].isoformat()
+               if isinstance(payload["endTime"], datetime):
+                   payload["endTime"] = payload["endTime"].isoformat()
+               prisma_client.spend_log_transactons.append(payload)
+           elif prisma_client is not None:
                await prisma_client.insert_data(data=payload, table_name="spend")
        except Exception as e:
            verbose_proxy_logger.debug(
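The new branch queues spend-log payloads for the HTTP writer when `SPEND_LOGS_URL` is set, and falls back to a direct Prisma insert otherwise. The `isoformat()` conversion matters because `datetime` objects are not JSON-serializable, while Prisma accepts them natively. The normalization in isolation:

```python
import json
from datetime import datetime, timezone

# Illustrative payload; field names follow the diff.
payload = {
    "startTime": datetime.now(timezone.utc),
    "endTime": datetime.now(timezone.utc),
    "spend": 0.75,
}

# Same normalization as the diff: only convert fields still held as datetime.
for field in ("startTime", "endTime"):
    if isinstance(payload[field], datetime):
        payload[field] = payload[field].isoformat()

json.dumps(payload)  # now JSON-safe for the HTTP writer
```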
@@ -2615,11 +2645,7 @@ async def async_data_generator(response, user_api_key_dict):
        verbose_proxy_logger.debug(
            f"\033[1;31mAn error occurred: {e}\n\n Debug this by setting `--debug`, e.g. `litellm --model gpt-3.5-turbo --debug`"
        )
-       router_model_names = (
-           [m["model_name"] for m in llm_model_list]
-           if llm_model_list is not None
-           else []
-       )
+       router_model_names = llm_router.model_names if llm_router is not None else []
        if user_debug:
            traceback.print_exc()

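This swap, repeated in the completion, chat, embeddings, image, audio, and moderation handlers below, replaces a per-request list comprehension over `llm_model_list` with the router's precomputed `model_names` attribute; the chat_completion hunk's comment spells out the intent: the lookup must stay constant time because it runs on every call. Schematically (the `RouterSketch` class is a stand-in for `litellm.Router`, which the diff shows exposing `model_names`):

```python
from typing import List, Optional


class RouterSketch:
    """Stand-in for litellm.Router: model_names is computed once at init."""

    def __init__(self, model_list: List[dict]) -> None:
        self.model_names: List[str] = [m["model_name"] for m in model_list]


llm_model_list: Optional[List[dict]] = [
    {"model_name": "gpt-3.5-turbo"},
    {"model_name": "claude-2"},
]
llm_router: Optional[RouterSketch] = RouterSketch(llm_model_list)

# Before: rebuilt on every request, O(n) in the number of deployments.
per_request = (
    [m["model_name"] for m in llm_model_list] if llm_model_list is not None else []
)

# After: a constant-time attribute read of the precomputed list.
cached = llm_router.model_names if llm_router is not None else []

assert per_request == cached == ["gpt-3.5-turbo", "claude-2"]
```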
@@ -2678,7 +2704,7 @@ def on_backoff(details):

 @router.on_event("startup")
 async def startup_event():
-    global prisma_client, master_key, use_background_health_checks, llm_router, llm_model_list, general_settings, proxy_budget_rescheduler_min_time, proxy_budget_rescheduler_max_time, litellm_proxy_admin_name
+    global prisma_client, master_key, use_background_health_checks, llm_router, llm_model_list, general_settings, proxy_budget_rescheduler_min_time, proxy_budget_rescheduler_max_time, litellm_proxy_admin_name, db_writer_client
     import json

     ### LOAD MASTER KEY ###
@@ -2711,6 +2737,8 @@ async def startup_event():
    ## COST TRACKING ##
    cost_tracking()

+   db_writer_client = HTTPHandler()
+
    proxy_logging_obj._init_litellm_callbacks() # INITIALIZE LITELLM CALLBACKS ON SERVER STARTUP <- do this to catch any logging errors on startup, not when calls are being made

    ## JWT AUTH ##
@@ -2821,7 +2849,7 @@ async def startup_event():
            update_spend,
            "interval",
            seconds=batch_writing_interval,
-           args=[prisma_client],
+           args=[prisma_client, db_writer_client],
        )
        scheduler.start()

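The scheduler hunk threads the new writer client into the recurring `update_spend` job; `add_job(func, "interval", seconds=..., args=[...])` is standard APScheduler usage. A self-contained sketch of the same wiring (the job body is a placeholder, and `batch_writing_interval` stands in for the proxy setting of that name):

```python
import asyncio

from apscheduler.schedulers.asyncio import AsyncIOScheduler

batch_writing_interval = 10  # seconds; placeholder for the proxy setting


async def update_spend(prisma_client, db_writer_client) -> None:
    # Placeholder body: the real job drains the buffered spend transactions.
    print("flushing buffered spend transactions...")


async def main() -> None:
    scheduler = AsyncIOScheduler()
    scheduler.add_job(
        update_spend,
        "interval",
        seconds=batch_writing_interval,
        args=[None, None],  # stand-ins for prisma_client, db_writer_client
    )
    scheduler.start()
    await asyncio.sleep(25)  # let the job fire a couple of times


asyncio.run(main())
```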
@@ -2881,7 +2909,6 @@ async def completion(
    fastapi_response: Response,
    model: Optional[str] = None,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
-   background_tasks: BackgroundTasks = BackgroundTasks(),
 ):
    global user_temperature, user_request_timeout, user_max_tokens, user_api_base
    try:
@@ -2943,11 +2970,7 @@ async def completion(
        start_time = time.time()

        ### ROUTE THE REQUESTs ###
-       router_model_names = (
-           [m["model_name"] for m in llm_model_list]
-           if llm_model_list is not None
-           else []
-       )
+       router_model_names = llm_router.model_names if llm_router is not None else []
        # skip router if user passed their key
        if "api_key" in data:
            response = await litellm.atext_completion(**data)
@@ -3047,7 +3070,6 @@ async def chat_completion(
    fastapi_response: Response,
    model: Optional[str] = None,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
-   background_tasks: BackgroundTasks = BackgroundTasks(),
 ):
    global general_settings, user_debug, proxy_logging_obj, llm_model_list
    try:
@@ -3161,11 +3183,8 @@ async def chat_completion(
        start_time = time.time()

        ### ROUTE THE REQUEST ###
-       router_model_names = (
-           [m["model_name"] for m in llm_model_list]
-           if llm_model_list is not None
-           else []
-       )
+       # Do not change this - it should be a constant time fetch - ALWAYS
+       router_model_names = llm_router.model_names if llm_router is not None else []
        # skip router if user passed their key
        if "api_key" in data:
            tasks.append(litellm.acompletion(**data))
@@ -3238,11 +3257,7 @@ async def chat_completion(
        verbose_proxy_logger.debug(
            f"\033[1;31mAn error occurred: {e}\n\n Debug this by setting `--debug`, e.g. `litellm --model gpt-3.5-turbo --debug`"
        )
-       router_model_names = (
-           [m["model_name"] for m in llm_model_list]
-           if llm_model_list is not None
-           else []
-       )
+       router_model_names = llm_router.model_names if llm_router is not None else []
        if user_debug:
            traceback.print_exc()

@@ -3284,7 +3299,6 @@ async def embeddings(
    request: Request,
    model: Optional[str] = None,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
-   background_tasks: BackgroundTasks = BackgroundTasks(),
 ):
    global proxy_logging_obj
    try:
@@ -3350,11 +3364,7 @@ async def embeddings(
        if data["model"] in litellm.model_alias_map:
            data["model"] = litellm.model_alias_map[data["model"]]

-       router_model_names = (
-           [m["model_name"] for m in llm_model_list]
-           if llm_model_list is not None
-           else []
-       )
+       router_model_names = llm_router.model_names if llm_router is not None else []
        if (
            "input" in data
            and isinstance(data["input"], list)
@@ -3460,7 +3470,6 @@ async def embeddings(
 async def image_generation(
    request: Request,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
-   background_tasks: BackgroundTasks = BackgroundTasks(),
 ):
    global proxy_logging_obj
    try:
@@ -3526,11 +3535,7 @@ async def image_generation(
        if data["model"] in litellm.model_alias_map:
            data["model"] = litellm.model_alias_map[data["model"]]

-       router_model_names = (
-           [m["model_name"] for m in llm_model_list]
-           if llm_model_list is not None
-           else []
-       )
+       router_model_names = llm_router.model_names if llm_router is not None else []

        ### CALL HOOKS ### - modify incoming data / reject request before calling the model
        data = await proxy_logging_obj.pre_call_hook(
@@ -3674,11 +3679,7 @@ async def audio_transcriptions(
            **data,
        } # add the team-specific configs to the completion call

-       router_model_names = (
-           [m["model_name"] for m in llm_model_list]
-           if llm_model_list is not None
-           else []
-       )
+       router_model_names = llm_router.model_names if llm_router is not None else []

        assert (
            file.filename is not None
@@ -3843,11 +3844,7 @@ async def moderations(
            **data,
        } # add the team-specific configs to the completion call

-       router_model_names = (
-           [m["model_name"] for m in llm_model_list]
-           if llm_model_list is not None
-           else []
-       )
+       router_model_names = llm_router.model_names if llm_router is not None else []

        ### CALL HOOKS ### - modify incoming data / reject request before calling the model
        data = await proxy_logging_obj.pre_call_hook(
@@ -4353,7 +4350,7 @@ async def info_key_fn(

 @router.get(
    "/spend/keys",
-   tags=["budget & spend Tracking"],
+   tags=["Budget & Spend Tracking"],
    dependencies=[Depends(user_api_key_auth)],
 )
 async def spend_key_fn():
@@ -4385,7 +4382,7 @@ async def spend_key_fn():

 @router.get(
    "/spend/users",
-   tags=["budget & spend Tracking"],
+   tags=["Budget & Spend Tracking"],
    dependencies=[Depends(user_api_key_auth)],
 )
 async def spend_user_fn(
@@ -4437,7 +4434,7 @@ async def spend_user_fn(

 @router.get(
    "/spend/tags",
-   tags=["budget & spend Tracking"],
+   tags=["Budget & Spend Tracking"],
    dependencies=[Depends(user_api_key_auth)],
    responses={
        200: {"model": List[LiteLLM_SpendLogs]},
@@ -4510,6 +4507,77 @@ async def view_spend_tags(
    )


+@router.post(
+    "/spend/calculate",
+    tags=["Budget & Spend Tracking"],
+    dependencies=[Depends(user_api_key_auth)],
+    responses={
+        200: {
+            "cost": {
+                "description": "The calculated cost",
+                "example": 0.0,
+                "type": "float",
+            }
+        }
+    },
+)
+async def calculate_spend(request: Request):
+    """
+    Accepts all the params of completion_cost.
+
+    Calculate spend **before** making call:
+
+    ```
+    curl --location 'http://localhost:4000/spend/calculate'
+    --header 'Authorization: Bearer sk-1234'
+    --header 'Content-Type: application/json'
+    --data '{
+        "model": "anthropic.claude-v2",
+        "messages": [{"role": "user", "content": "Hey, how'''s it going?"}]
+    }'
+    ```
+
+    Calculate spend **after** making call:
+
+    ```
+    curl --location 'http://localhost:4000/spend/calculate'
+    --header 'Authorization: Bearer sk-1234'
+    --header 'Content-Type: application/json'
+    --data '{
+        "completion_response": {
+            "id": "chatcmpl-123",
+            "object": "chat.completion",
+            "created": 1677652288,
+            "model": "gpt-3.5-turbo-0125",
+            "system_fingerprint": "fp_44709d6fcb",
+            "choices": [{
+                "index": 0,
+                "message": {
+                    "role": "assistant",
+                    "content": "Hello there, how may I assist you today?"
+                },
+                "logprobs": null,
+                "finish_reason": "stop"
+            }],
+            "usage": {
+                "prompt_tokens": 9,
+                "completion_tokens": 12,
+                "total_tokens": 21
+            }
+        }
+    }'
+    ```
+    """
+    from litellm import completion_cost
+
+    data = await request.json()
+    if "completion_response" in data:
+        data["completion_response"] = litellm.ModelResponse(
+            **data["completion_response"]
+        )
+    return {"cost": completion_cost(**data)}


 @router.get(
    "/spend/logs",
    tags=["Budget & Spend Tracking"],
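Since the endpoint accepts any `completion_cost` kwargs as JSON, the docstring's pre-call curl example translates directly to Python. A hedged equivalent (URL and key mirror the docstring's placeholders):

```python
import httpx

resp = httpx.post(
    "http://localhost:4000/spend/calculate",
    headers={"Authorization": "Bearer sk-1234"},
    json={
        "model": "anthropic.claude-v2",
        "messages": [{"role": "user", "content": "Hey, how's it going?"}],
    },
)
print(resp.json())  # e.g. {"cost": ...} - depends on the model's pricing
```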
@@ -5240,6 +5308,7 @@ async def user_info(
    user_info = {"spend": spend}

    ## REMOVE HASHED TOKEN INFO before returning ##
+   returned_keys = []
    for key in keys:
        try:
            key = key.model_dump() # noqa
@@ -5248,10 +5317,24 @@ async def user_info(
            key = key.dict()
        key.pop("token", None)

+       if (
+           "team_id" in key
+           and key["team_id"] is not None
+           and key["team_id"] != "litellm-dashboard"
+       ):
+           team_info = await prisma_client.get_data(
+               team_id=key["team_id"], table_name="team"
+           )
+           team_alias = getattr(team_info, "team_alias", None)
+           key["team_alias"] = team_alias
+       else:
+           key["team_alias"] = "None"
+       returned_keys.append(key)
+
    response_data = {
        "user_id": user_id,
        "user_info": user_info,
-       "keys": keys,
+       "keys": returned_keys,
        "teams": team_list,
    }
    return response_data
@@ -5639,7 +5722,7 @@ async def new_team(
        raise HTTPException(
            status_code=400,
            detail={
-               "error": f"tpm limit higher than user max. User tpm limit={user_api_key_dict.tpm_limit}"
+               "error": f"tpm limit higher than user max. User tpm limit={user_api_key_dict.tpm_limit}. User role={user_api_key_dict.user_role}"
            },
        )

@@ -5651,7 +5734,7 @@ async def new_team(
        raise HTTPException(
            status_code=400,
            detail={
-               "error": f"rpm limit higher than user max. User rpm limit={user_api_key_dict.rpm_limit}"
+               "error": f"rpm limit higher than user max. User rpm limit={user_api_key_dict.rpm_limit}. User role={user_api_key_dict.user_role}"
            },
        )

@@ -5663,7 +5746,7 @@ async def new_team(
        raise HTTPException(
            status_code=400,
            detail={
-               "error": f"max budget higher than user max. User max budget={user_api_key_dict.max_budget}"
+               "error": f"max budget higher than user max. User max budget={user_api_key_dict.max_budget}. User role={user_api_key_dict.user_role}"
            },
        )

@@ -5673,7 +5756,7 @@ async def new_team(
        raise HTTPException(
            status_code=400,
            detail={
-               "error": f"Model not in allowed user models. User allowed models={user_api_key_dict.models}"
+               "error": f"Model not in allowed user models. User allowed models={user_api_key_dict.models}. User id={user_api_key_dict.user_id}"
            },
        )

@@ -6170,7 +6253,7 @@ async def block_team(
        raise Exception("No DB Connected.")

    record = await prisma_client.db.litellm_teamtable.update(
-       where={"team_id": data.team_id}, data={"blocked": True}
+       where={"team_id": data.team_id}, data={"blocked": True} # type: ignore
    )

    return record
@@ -6192,7 +6275,7 @@ async def unblock_team(
        raise Exception("No DB Connected.")

    record = await prisma_client.db.litellm_teamtable.update(
-       where={"team_id": data.team_id}, data={"blocked": False}
+       where={"team_id": data.team_id}, data={"blocked": False} # type: ignore
    )

    return record
@@ -6795,7 +6878,6 @@ async def async_queue_request(
    request: Request,
    model: Optional[str] = None,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
-   background_tasks: BackgroundTasks = BackgroundTasks(),
 ):
    global general_settings, user_debug, proxy_logging_obj
    """
@@ -7058,6 +7140,13 @@ async def login(request: Request):
    except ImportError:
        subprocess.run(["pip", "install", "python-multipart"])
    global master_key
+   if master_key is None:
+       raise ProxyException(
+           message="Master Key not set for Proxy. Please set Master Key to use Admin UI. Set `LITELLM_MASTER_KEY` in .env or set general_settings:master_key in config.yaml. https://docs.litellm.ai/docs/proxy/virtual_keys. If set, use `--detailed_debug` to debug issue.",
+           type="auth_error",
+           param="master_key",
+           code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+       )
    form = await request.form()
    username = str(form.get("username"))
    password = str(form.get("password"))
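The new guard makes the Admin UI fail fast with a 500 when no master key is configured, instead of surfacing a confusing error after login. The two configuration paths named in the message, sketched (the key value is a placeholder):

```python
import os

# Option 1: environment variable, e.g. in the proxy's .env file:
#   LITELLM_MASTER_KEY="sk-1234"
os.environ["LITELLM_MASTER_KEY"] = "sk-1234"  # placeholder value

# Option 2: config.yaml equivalent, per the error message:
#   general_settings:
#     master_key: sk-1234
```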
@@ -7997,6 +8086,8 @@ async def shutdown_event():

    await jwt_handler.close()

+   if db_writer_client is not None:
+       await db_writer_client.close()
    ## RESET CUSTOM VARIABLES ##
    cleanup_router_config_variables()
