Merge branch 'main' into litellm_aws_kms_fixes

This commit is contained in:
Krish Dholakia 2024-06-10 20:17:34 -07:00 committed by GitHub
commit f3feffc9d4
34 changed files with 1293 additions and 483 deletions

View file

@ -160,6 +160,7 @@ from litellm.proxy.auth.auth_checks import (
get_user_object,
allowed_routes_check,
get_actual_routes,
log_to_opentelemetry,
)
from litellm.llms.custom_httpx.httpx_handler import HTTPHandler
from litellm.exceptions import RejectedRequestError
@ -368,6 +369,11 @@ from typing import Dict
api_key_header = APIKeyHeader(
name="Authorization", auto_error=False, description="Bearer token"
)
azure_api_key_header = APIKeyHeader(
name="API-Key",
auto_error=False,
description="Some older versions of the openai Python package will send an API-Key header with just the API key ",
)
user_api_base = None
user_model = None
user_debug = False
@ -508,13 +514,19 @@ async def check_request_disconnection(request: Request, llm_api_call_task):
async def user_api_key_auth(
request: Request, api_key: str = fastapi.Security(api_key_header)
request: Request,
api_key: str = fastapi.Security(api_key_header),
azure_api_key_header: str = fastapi.Security(azure_api_key_header),
) -> UserAPIKeyAuth:
global master_key, prisma_client, llm_model_list, user_custom_auth, custom_db_client, general_settings, proxy_logging_obj
try:
if isinstance(api_key, str):
passed_in_key = api_key
api_key = _get_bearer_token(api_key=api_key)
elif isinstance(azure_api_key_header, str):
api_key = azure_api_key_header
parent_otel_span: Optional[Span] = None
if open_telemetry_logger is not None:
parent_otel_span = open_telemetry_logger.tracer.start_span(
@ -1495,7 +1507,7 @@ async def user_api_key_auth(
)
if valid_token is None:
# No token was found when looking up in the DB
raise Exception("Invalid token passed")
raise Exception("Invalid proxy server token passed")
if valid_token_dict is not None:
if user_id_information is not None and _is_user_proxy_admin(
user_id_information
@ -1528,6 +1540,14 @@ async def user_api_key_auth(
str(e)
)
)
# Log this exception to OTEL
if open_telemetry_logger is not None:
await open_telemetry_logger.async_post_call_failure_hook(
original_exception=e,
user_api_key_dict=UserAPIKeyAuth(parent_otel_span=parent_otel_span),
)
verbose_proxy_logger.debug(traceback.format_exc())
if isinstance(e, litellm.BudgetExceededError):
raise ProxyException(
@ -7803,6 +7823,10 @@ async def get_global_spend_report(
default=None,
description="Time till which to view spend",
),
group_by: Optional[Literal["team", "customer"]] = fastapi.Query(
default="team",
description="Group spend by internal team or customer",
),
):
"""
Get Daily Spend per Team, based on specific startTime and endTime. Per team, view usage by each key, model
@ -7849,69 +7873,130 @@ async def get_global_spend_report(
f"Database not connected. Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys"
)
# first get data from spend logs -> SpendByModelApiKey
# then read data from "SpendByModelApiKey" to format the response obj
sql_query = """
if group_by == "team":
# first get data from spend logs -> SpendByModelApiKey
# then read data from "SpendByModelApiKey" to format the response obj
sql_query = """
WITH SpendByModelApiKey AS (
SELECT
date_trunc('day', sl."startTime") AS group_by_day,
COALESCE(tt.team_alias, 'Unassigned Team') AS team_name,
sl.model,
sl.api_key,
SUM(sl.spend) AS model_api_spend,
SUM(sl.total_tokens) AS model_api_tokens
FROM
"LiteLLM_SpendLogs" sl
LEFT JOIN
"LiteLLM_TeamTable" tt
ON
sl.team_id = tt.team_id
WHERE
sl."startTime" BETWEEN $1::date AND $2::date
GROUP BY
date_trunc('day', sl."startTime"),
tt.team_alias,
sl.model,
sl.api_key
)
WITH SpendByModelApiKey AS (
SELECT
date_trunc('day', sl."startTime") AS group_by_day,
COALESCE(tt.team_alias, 'Unassigned Team') AS team_name,
sl.model,
sl.api_key,
SUM(sl.spend) AS model_api_spend,
SUM(sl.total_tokens) AS model_api_tokens
FROM
"LiteLLM_SpendLogs" sl
LEFT JOIN
"LiteLLM_TeamTable" tt
ON
sl.team_id = tt.team_id
WHERE
sl."startTime" BETWEEN $1::date AND $2::date
GROUP BY
date_trunc('day', sl."startTime"),
tt.team_alias,
sl.model,
sl.api_key
)
SELECT
group_by_day,
jsonb_agg(jsonb_build_object(
'team_name', team_name,
'total_spend', total_spend,
'metadata', metadata
)) AS teams
FROM (
SELECT
group_by_day,
team_name,
SUM(model_api_spend) AS total_spend,
jsonb_agg(jsonb_build_object(
'model', model,
'api_key', api_key,
'spend', model_api_spend,
'total_tokens', model_api_tokens
)) AS metadata
FROM
SpendByModelApiKey
GROUP BY
group_by_day,
team_name
) AS aggregated
GROUP BY
group_by_day
ORDER BY
group_by_day;
"""
db_response = await prisma_client.db.query_raw(
sql_query, start_date_obj, end_date_obj
)
if db_response is None:
return []
return db_response
elif group_by == "customer":
sql_query = """
WITH SpendByModelApiKey AS (
SELECT
date_trunc('day', sl."startTime") AS group_by_day,
sl.end_user AS customer,
sl.model,
sl.api_key,
SUM(sl.spend) AS model_api_spend,
SUM(sl.total_tokens) AS model_api_tokens
FROM
"LiteLLM_SpendLogs" sl
WHERE
sl."startTime" BETWEEN $1::date AND $2::date
GROUP BY
date_trunc('day', sl."startTime"),
customer,
sl.model,
sl.api_key
)
SELECT
group_by_day,
jsonb_agg(jsonb_build_object(
'team_name', team_name,
'customer', customer,
'total_spend', total_spend,
'metadata', metadata
)) AS teams
FROM (
SELECT
group_by_day,
team_name,
SUM(model_api_spend) AS total_spend,
jsonb_agg(jsonb_build_object(
'model', model,
'api_key', api_key,
'spend', model_api_spend,
'total_tokens', model_api_tokens
)) AS metadata
FROM
SpendByModelApiKey
GROUP BY
group_by_day,
team_name
) AS aggregated
)) AS customers
FROM
(
SELECT
group_by_day,
customer,
SUM(model_api_spend) AS total_spend,
jsonb_agg(jsonb_build_object(
'model', model,
'api_key', api_key,
'spend', model_api_spend,
'total_tokens', model_api_tokens
)) AS metadata
FROM
SpendByModelApiKey
GROUP BY
group_by_day,
customer
) AS aggregated
GROUP BY
group_by_day
ORDER BY
group_by_day;
"""
"""
db_response = await prisma_client.db.query_raw(
sql_query, start_date_obj, end_date_obj
)
if db_response is None:
return []
db_response = await prisma_client.db.query_raw(
sql_query, start_date_obj, end_date_obj
)
if db_response is None:
return []
return db_response
return db_response
except Exception as e:
raise HTTPException(