mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
fix(proxy_server.py): track team spend for cached team object
fixes issue where team budgets for jwt tokens weren't asserted
This commit is contained in:
parent
5ad095ad9d
commit
6558abf845
5 changed files with 316 additions and 156 deletions
|
@ -73,16 +73,12 @@ assistant_settings:
|
||||||
|
|
||||||
router_settings:
|
router_settings:
|
||||||
enable_pre_call_checks: true
|
enable_pre_call_checks: true
|
||||||
|
|
||||||
litellm_settings:
|
|
||||||
callbacks: ["lago"]
|
|
||||||
success_callback: ["langfuse"]
|
|
||||||
failure_callback: ["langfuse"]
|
|
||||||
cache: true
|
|
||||||
json_logs: true
|
|
||||||
|
|
||||||
general_settings:
|
general_settings:
|
||||||
alerting: ["slack"]
|
alerting: ["slack"]
|
||||||
|
enable_jwt_auth: True
|
||||||
|
litellm_jwtauth:
|
||||||
|
team_id_jwt_field: "client_id"
|
||||||
# key_management_system: "aws_kms"
|
# key_management_system: "aws_kms"
|
||||||
# key_management_settings:
|
# key_management_settings:
|
||||||
# hosted_keys: ["LITELLM_MASTER_KEY"]
|
# hosted_keys: ["LITELLM_MASTER_KEY"]
|
||||||
|
|
|
@ -8,21 +8,22 @@ Run checks for:
|
||||||
2. If user is in budget
|
2. If user is in budget
|
||||||
3. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget
|
3. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget
|
||||||
"""
|
"""
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import TYPE_CHECKING, Any, Literal, Optional
|
||||||
|
|
||||||
|
import litellm
|
||||||
|
from litellm.caching import DualCache
|
||||||
from litellm.proxy._types import (
|
from litellm.proxy._types import (
|
||||||
LiteLLM_UserTable,
|
|
||||||
LiteLLM_EndUserTable,
|
LiteLLM_EndUserTable,
|
||||||
LiteLLM_JWTAuth,
|
LiteLLM_JWTAuth,
|
||||||
LiteLLM_TeamTable,
|
|
||||||
LiteLLMRoutes,
|
|
||||||
LiteLLM_OrganizationTable,
|
LiteLLM_OrganizationTable,
|
||||||
|
LiteLLM_TeamTable,
|
||||||
|
LiteLLM_UserTable,
|
||||||
|
LiteLLMRoutes,
|
||||||
LitellmUserRoles,
|
LitellmUserRoles,
|
||||||
)
|
)
|
||||||
from typing import Optional, Literal, TYPE_CHECKING, Any
|
|
||||||
from litellm.proxy.utils import PrismaClient, ProxyLogging, log_to_opentelemetry
|
from litellm.proxy.utils import PrismaClient, ProxyLogging, log_to_opentelemetry
|
||||||
from litellm.caching import DualCache
|
|
||||||
import litellm
|
|
||||||
from litellm.types.services import ServiceLoggerPayload, ServiceTypes
|
from litellm.types.services import ServiceLoggerPayload, ServiceTypes
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from opentelemetry.trace import Span as _Span
|
from opentelemetry.trace import Span as _Span
|
||||||
|
@ -110,7 +111,7 @@ def common_checks(
|
||||||
# Enterprise ONLY Feature
|
# Enterprise ONLY Feature
|
||||||
# we already validate if user is premium_user when reading the config
|
# we already validate if user is premium_user when reading the config
|
||||||
# Add an extra premium_usercheck here too, just incase
|
# Add an extra premium_usercheck here too, just incase
|
||||||
from litellm.proxy.proxy_server import premium_user, CommonProxyErrors
|
from litellm.proxy.proxy_server import CommonProxyErrors, premium_user
|
||||||
|
|
||||||
if premium_user is not True:
|
if premium_user is not True:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
@ -364,7 +365,8 @@ async def get_team_object(
|
||||||
)
|
)
|
||||||
|
|
||||||
# check if in cache
|
# check if in cache
|
||||||
cached_team_obj = await user_api_key_cache.async_get_cache(key=team_id)
|
key = "team_id:{}".format(team_id)
|
||||||
|
cached_team_obj = await user_api_key_cache.async_get_cache(key=key)
|
||||||
if cached_team_obj is not None:
|
if cached_team_obj is not None:
|
||||||
if isinstance(cached_team_obj, dict):
|
if isinstance(cached_team_obj, dict):
|
||||||
return LiteLLM_TeamTable(**cached_team_obj)
|
return LiteLLM_TeamTable(**cached_team_obj)
|
||||||
|
@ -381,7 +383,7 @@ async def get_team_object(
|
||||||
|
|
||||||
_response = LiteLLM_TeamTable(**response.dict())
|
_response = LiteLLM_TeamTable(**response.dict())
|
||||||
# save the team object to cache
|
# save the team object to cache
|
||||||
await user_api_key_cache.async_set_cache(key=response.team_id, value=_response)
|
await user_api_key_cache.async_set_cache(key=key, value=_response)
|
||||||
|
|
||||||
return _response
|
return _response
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
|
@ -120,7 +120,7 @@ async def user_api_key_auth(
|
||||||
)
|
)
|
||||||
### USER-DEFINED AUTH FUNCTION ###
|
### USER-DEFINED AUTH FUNCTION ###
|
||||||
if user_custom_auth is not None:
|
if user_custom_auth is not None:
|
||||||
response = await user_custom_auth(request=request, api_key=api_key)
|
response = await user_custom_auth(request=request, api_key=api_key) # type: ignore
|
||||||
return UserAPIKeyAuth.model_validate(response)
|
return UserAPIKeyAuth.model_validate(response)
|
||||||
|
|
||||||
### LITELLM-DEFINED AUTH FUNCTION ###
|
### LITELLM-DEFINED AUTH FUNCTION ###
|
||||||
|
@ -140,7 +140,7 @@ async def user_api_key_auth(
|
||||||
# check if public endpoint
|
# check if public endpoint
|
||||||
return UserAPIKeyAuth(user_role=LitellmUserRoles.INTERNAL_USER_VIEW_ONLY)
|
return UserAPIKeyAuth(user_role=LitellmUserRoles.INTERNAL_USER_VIEW_ONLY)
|
||||||
|
|
||||||
if general_settings.get("enable_jwt_auth", False) == True:
|
if general_settings.get("enable_jwt_auth", False) is True:
|
||||||
is_jwt = jwt_handler.is_jwt(token=api_key)
|
is_jwt = jwt_handler.is_jwt(token=api_key)
|
||||||
verbose_proxy_logger.debug("is_jwt: %s", is_jwt)
|
verbose_proxy_logger.debug("is_jwt: %s", is_jwt)
|
||||||
if is_jwt:
|
if is_jwt:
|
||||||
|
@ -177,7 +177,7 @@ async def user_api_key_auth(
|
||||||
token=jwt_valid_token, default_value=None
|
token=jwt_valid_token, default_value=None
|
||||||
)
|
)
|
||||||
|
|
||||||
if team_id is None and jwt_handler.is_required_team_id() == True:
|
if team_id is None and jwt_handler.is_required_team_id() is True:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
f"No team id passed in. Field checked in jwt token - '{jwt_handler.litellm_jwtauth.team_id_jwt_field}'"
|
f"No team id passed in. Field checked in jwt token - '{jwt_handler.litellm_jwtauth.team_id_jwt_field}'"
|
||||||
)
|
)
|
||||||
|
@ -190,7 +190,7 @@ async def user_api_key_auth(
|
||||||
user_route=route,
|
user_route=route,
|
||||||
litellm_proxy_roles=jwt_handler.litellm_jwtauth,
|
litellm_proxy_roles=jwt_handler.litellm_jwtauth,
|
||||||
)
|
)
|
||||||
if is_allowed == False:
|
if is_allowed is False:
|
||||||
allowed_routes = jwt_handler.litellm_jwtauth.team_allowed_routes # type: ignore
|
allowed_routes = jwt_handler.litellm_jwtauth.team_allowed_routes # type: ignore
|
||||||
actual_routes = get_actual_routes(allowed_routes=allowed_routes)
|
actual_routes = get_actual_routes(allowed_routes=allowed_routes)
|
||||||
raise Exception(
|
raise Exception(
|
||||||
|
|
|
@ -1,12 +1,26 @@
|
||||||
import sys, os, platform, time, copy, re, asyncio, inspect
|
import ast
|
||||||
import threading, ast
|
import asyncio
|
||||||
import shutil, random, traceback, requests
|
import copy
|
||||||
from datetime import datetime, timedelta, timezone
|
import hashlib
|
||||||
from typing import Optional, List, Callable, get_args, Set, Any, TYPE_CHECKING
|
|
||||||
import secrets, subprocess
|
|
||||||
import hashlib, uuid
|
|
||||||
import warnings
|
|
||||||
import importlib
|
import importlib
|
||||||
|
import inspect
|
||||||
|
import os
|
||||||
|
import platform
|
||||||
|
import random
|
||||||
|
import re
|
||||||
|
import secrets
|
||||||
|
import shutil
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
import traceback
|
||||||
|
import uuid
|
||||||
|
import warnings
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Set, get_args
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from opentelemetry.trace import Span as _Span
|
from opentelemetry.trace import Span as _Span
|
||||||
|
@ -34,11 +48,12 @@ sys.path.insert(
|
||||||
) # Adds the parent directory to the system path - for litellm local dev
|
) # Adds the parent directory to the system path - for litellm local dev
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import fastapi
|
|
||||||
import backoff
|
|
||||||
import yaml # type: ignore
|
|
||||||
import orjson
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
import backoff
|
||||||
|
import fastapi
|
||||||
|
import orjson
|
||||||
|
import yaml # type: ignore
|
||||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
raise ImportError(f"Missing dependency {e}. Run `pip install 'litellm[proxy]'`")
|
raise ImportError(f"Missing dependency {e}. Run `pip install 'litellm[proxy]'`")
|
||||||
|
@ -85,60 +100,31 @@ def generate_feedback_box():
|
||||||
print() # noqa
|
print() # noqa
|
||||||
|
|
||||||
|
|
||||||
import litellm
|
|
||||||
from litellm.types.llms.openai import (
|
|
||||||
HttpxBinaryResponseContent,
|
|
||||||
)
|
|
||||||
from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request
|
|
||||||
from litellm.proxy.utils import (
|
|
||||||
PrismaClient,
|
|
||||||
DBClient,
|
|
||||||
get_instance_fn,
|
|
||||||
ProxyLogging,
|
|
||||||
_cache_user_row,
|
|
||||||
send_email,
|
|
||||||
get_logging_payload,
|
|
||||||
reset_budget,
|
|
||||||
hash_token,
|
|
||||||
html_form,
|
|
||||||
missing_keys_html_form,
|
|
||||||
_is_valid_team_configs,
|
|
||||||
_is_projected_spend_over_limit,
|
|
||||||
_get_projected_spend_over_limit,
|
|
||||||
update_spend,
|
|
||||||
encrypt_value,
|
|
||||||
decrypt_value,
|
|
||||||
get_error_message_str,
|
|
||||||
)
|
|
||||||
from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
|
|
||||||
|
|
||||||
from litellm import (
|
|
||||||
CreateBatchRequest,
|
|
||||||
RetrieveBatchRequest,
|
|
||||||
ListBatchRequest,
|
|
||||||
CancelBatchRequest,
|
|
||||||
CreateFileRequest,
|
|
||||||
)
|
|
||||||
from litellm.proxy.secret_managers.google_kms import load_google_kms
|
|
||||||
from litellm.proxy.secret_managers.aws_secret_manager import (
|
|
||||||
load_aws_secret_manager,
|
|
||||||
load_aws_kms,
|
|
||||||
)
|
|
||||||
import pydantic
|
import pydantic
|
||||||
from litellm.proxy._types import *
|
|
||||||
from litellm.caching import DualCache, RedisCache
|
import litellm
|
||||||
from litellm.proxy.health_check import perform_health_check
|
from litellm import (
|
||||||
from litellm.router import (
|
CancelBatchRequest,
|
||||||
LiteLLM_Params,
|
CreateBatchRequest,
|
||||||
Deployment,
|
CreateFileRequest,
|
||||||
updateDeployment,
|
ListBatchRequest,
|
||||||
ModelGroupInfo,
|
RetrieveBatchRequest,
|
||||||
AssistantsTypedDict,
|
|
||||||
)
|
)
|
||||||
from litellm.router import ModelInfo as RouterModelInfo
|
from litellm._logging import verbose_proxy_logger, verbose_router_logger
|
||||||
from litellm._logging import (
|
from litellm.caching import DualCache, RedisCache
|
||||||
verbose_router_logger,
|
from litellm.exceptions import RejectedRequestError
|
||||||
verbose_proxy_logger,
|
from litellm.integrations.slack_alerting import SlackAlerting, SlackAlertingArgs
|
||||||
|
from litellm.llms.custom_httpx.httpx_handler import HTTPHandler
|
||||||
|
from litellm.proxy._types import *
|
||||||
|
from litellm.proxy.auth.auth_checks import (
|
||||||
|
allowed_routes_check,
|
||||||
|
common_checks,
|
||||||
|
get_actual_routes,
|
||||||
|
get_end_user_object,
|
||||||
|
get_org_object,
|
||||||
|
get_team_object,
|
||||||
|
get_user_object,
|
||||||
|
log_to_opentelemetry,
|
||||||
)
|
)
|
||||||
from litellm.proxy.auth.handle_jwt import JWTHandler
|
from litellm.proxy.auth.handle_jwt import JWTHandler
|
||||||
from litellm.proxy.auth.litellm_license import LicenseCheck
|
from litellm.proxy.auth.litellm_license import LicenseCheck
|
||||||
|
@ -148,78 +134,105 @@ from litellm.proxy.auth.model_checks import (
|
||||||
get_team_models,
|
get_team_models,
|
||||||
)
|
)
|
||||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||||
from litellm.proxy.hooks.prompt_injection_detection import (
|
|
||||||
_OPTIONAL_PromptInjectionDetection,
|
|
||||||
)
|
|
||||||
from litellm.proxy.auth.auth_checks import (
|
|
||||||
common_checks,
|
|
||||||
get_end_user_object,
|
|
||||||
get_org_object,
|
|
||||||
get_team_object,
|
|
||||||
get_user_object,
|
|
||||||
allowed_routes_check,
|
|
||||||
get_actual_routes,
|
|
||||||
log_to_opentelemetry,
|
|
||||||
)
|
|
||||||
from litellm.llms.custom_httpx.httpx_handler import HTTPHandler
|
|
||||||
from litellm.exceptions import RejectedRequestError
|
|
||||||
from litellm.integrations.slack_alerting import SlackAlertingArgs, SlackAlerting
|
|
||||||
from litellm.scheduler import Scheduler, FlowItem, DefaultPriorities
|
|
||||||
|
|
||||||
## Import All Misc routes here ##
|
## Import All Misc routes here ##
|
||||||
from litellm.proxy.caching_routes import router as caching_router
|
from litellm.proxy.caching_routes import router as caching_router
|
||||||
from litellm.proxy.management_endpoints.team_endpoints import router as team_router
|
from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
|
||||||
from litellm.proxy.spend_reporting_endpoints.spend_management_endpoints import (
|
from litellm.proxy.health_check import perform_health_check
|
||||||
router as spend_management_router,
|
from litellm.proxy.health_endpoints._health_endpoints import router as health_router
|
||||||
|
from litellm.proxy.hooks.prompt_injection_detection import (
|
||||||
|
_OPTIONAL_PromptInjectionDetection,
|
||||||
|
)
|
||||||
|
from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request
|
||||||
|
from litellm.proxy.management_endpoints.internal_user_endpoints import (
|
||||||
|
router as internal_user_router,
|
||||||
|
)
|
||||||
|
from litellm.proxy.management_endpoints.internal_user_endpoints import user_update
|
||||||
|
from litellm.proxy.management_endpoints.key_management_endpoints import (
|
||||||
|
_duration_in_seconds,
|
||||||
|
delete_verification_token,
|
||||||
|
generate_key_helper_fn,
|
||||||
)
|
)
|
||||||
from litellm.proxy.management_endpoints.key_management_endpoints import (
|
from litellm.proxy.management_endpoints.key_management_endpoints import (
|
||||||
router as key_management_router,
|
router as key_management_router,
|
||||||
_duration_in_seconds,
|
|
||||||
generate_key_helper_fn,
|
|
||||||
delete_verification_token,
|
|
||||||
)
|
)
|
||||||
from litellm.proxy.management_endpoints.internal_user_endpoints import (
|
from litellm.proxy.management_endpoints.team_endpoints import router as team_router
|
||||||
router as internal_user_router,
|
from litellm.proxy.secret_managers.aws_secret_manager import (
|
||||||
user_update,
|
load_aws_kms,
|
||||||
|
load_aws_secret_manager,
|
||||||
)
|
)
|
||||||
from litellm.proxy.health_endpoints._health_endpoints import router as health_router
|
from litellm.proxy.secret_managers.google_kms import load_google_kms
|
||||||
|
from litellm.proxy.spend_reporting_endpoints.spend_management_endpoints import (
|
||||||
|
router as spend_management_router,
|
||||||
|
)
|
||||||
|
from litellm.proxy.utils import (
|
||||||
|
DBClient,
|
||||||
|
PrismaClient,
|
||||||
|
ProxyLogging,
|
||||||
|
_cache_user_row,
|
||||||
|
_get_projected_spend_over_limit,
|
||||||
|
_is_projected_spend_over_limit,
|
||||||
|
_is_valid_team_configs,
|
||||||
|
decrypt_value,
|
||||||
|
encrypt_value,
|
||||||
|
get_error_message_str,
|
||||||
|
get_instance_fn,
|
||||||
|
get_logging_payload,
|
||||||
|
hash_token,
|
||||||
|
html_form,
|
||||||
|
missing_keys_html_form,
|
||||||
|
reset_budget,
|
||||||
|
send_email,
|
||||||
|
update_spend,
|
||||||
|
)
|
||||||
|
from litellm.router import (
|
||||||
|
AssistantsTypedDict,
|
||||||
|
Deployment,
|
||||||
|
LiteLLM_Params,
|
||||||
|
ModelGroupInfo,
|
||||||
|
)
|
||||||
|
from litellm.router import ModelInfo as RouterModelInfo
|
||||||
|
from litellm.router import updateDeployment
|
||||||
|
from litellm.scheduler import DefaultPriorities, FlowItem, Scheduler
|
||||||
|
from litellm.types.llms.openai import HttpxBinaryResponseContent
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from litellm._version import version
|
from litellm._version import version
|
||||||
except:
|
except:
|
||||||
version = "0.0.0"
|
version = "0.0.0"
|
||||||
litellm.suppress_debug_info = True
|
litellm.suppress_debug_info = True
|
||||||
from fastapi import (
|
|
||||||
FastAPI,
|
|
||||||
Request,
|
|
||||||
HTTPException,
|
|
||||||
status,
|
|
||||||
Path,
|
|
||||||
Depends,
|
|
||||||
Header,
|
|
||||||
Response,
|
|
||||||
Form,
|
|
||||||
UploadFile,
|
|
||||||
File,
|
|
||||||
)
|
|
||||||
from fastapi.routing import APIRouter
|
|
||||||
from fastapi.security import OAuth2PasswordBearer
|
|
||||||
from fastapi.encoders import jsonable_encoder
|
|
||||||
from fastapi.responses import (
|
|
||||||
StreamingResponse,
|
|
||||||
FileResponse,
|
|
||||||
ORJSONResponse,
|
|
||||||
JSONResponse,
|
|
||||||
)
|
|
||||||
from fastapi.openapi.utils import get_openapi
|
|
||||||
from fastapi.responses import RedirectResponse
|
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
|
||||||
from fastapi.staticfiles import StaticFiles
|
|
||||||
from fastapi.security.api_key import APIKeyHeader
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
|
from fastapi import (
|
||||||
|
Depends,
|
||||||
|
FastAPI,
|
||||||
|
File,
|
||||||
|
Form,
|
||||||
|
Header,
|
||||||
|
HTTPException,
|
||||||
|
Path,
|
||||||
|
Request,
|
||||||
|
Response,
|
||||||
|
UploadFile,
|
||||||
|
status,
|
||||||
|
)
|
||||||
|
from fastapi.encoders import jsonable_encoder
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
from fastapi.openapi.utils import get_openapi
|
||||||
|
from fastapi.responses import (
|
||||||
|
FileResponse,
|
||||||
|
JSONResponse,
|
||||||
|
ORJSONResponse,
|
||||||
|
RedirectResponse,
|
||||||
|
StreamingResponse,
|
||||||
|
)
|
||||||
|
from fastapi.routing import APIRouter
|
||||||
|
from fastapi.security import OAuth2PasswordBearer
|
||||||
|
from fastapi.security.api_key import APIKeyHeader
|
||||||
|
from fastapi.staticfiles import StaticFiles
|
||||||
|
|
||||||
# import enterprise folder
|
# import enterprise folder
|
||||||
try:
|
try:
|
||||||
# when using litellm cli
|
# when using litellm cli
|
||||||
|
@ -488,8 +501,8 @@ def load_from_azure_key_vault(use_azure_key_vault: bool = False):
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from azure.keyvault.secrets import SecretClient
|
|
||||||
from azure.identity import ClientSecretCredential
|
from azure.identity import ClientSecretCredential
|
||||||
|
from azure.keyvault.secrets import SecretClient
|
||||||
|
|
||||||
# Set your Azure Key Vault URI
|
# Set your Azure Key Vault URI
|
||||||
KVUri = os.getenv("AZURE_KEY_VAULT_URI", None)
|
KVUri = os.getenv("AZURE_KEY_VAULT_URI", None)
|
||||||
|
@ -655,6 +668,7 @@ async def _PROXY_track_cost_callback(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
end_user_id=end_user_id,
|
end_user_id=end_user_id,
|
||||||
response_cost=response_cost,
|
response_cost=response_cost,
|
||||||
|
team_id=team_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
await proxy_logging_obj.slack_alerting_instance.customer_spend_alert(
|
await proxy_logging_obj.slack_alerting_instance.customer_spend_alert(
|
||||||
|
@ -897,6 +911,7 @@ async def update_cache(
|
||||||
token: Optional[str],
|
token: Optional[str],
|
||||||
user_id: Optional[str],
|
user_id: Optional[str],
|
||||||
end_user_id: Optional[str],
|
end_user_id: Optional[str],
|
||||||
|
team_id: Optional[str],
|
||||||
response_cost: Optional[float],
|
response_cost: Optional[float],
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
|
@ -993,6 +1008,7 @@ async def update_cache(
|
||||||
existing_spend_obj.spend = new_spend
|
existing_spend_obj.spend = new_spend
|
||||||
user_api_key_cache.set_cache(key=hashed_token, value=existing_spend_obj)
|
user_api_key_cache.set_cache(key=hashed_token, value=existing_spend_obj)
|
||||||
|
|
||||||
|
### UPDATE USER SPEND ###
|
||||||
async def _update_user_cache():
|
async def _update_user_cache():
|
||||||
## UPDATE CACHE FOR USER ID + GLOBAL PROXY
|
## UPDATE CACHE FOR USER ID + GLOBAL PROXY
|
||||||
user_ids = [user_id]
|
user_ids = [user_id]
|
||||||
|
@ -1054,6 +1070,7 @@ async def update_cache(
|
||||||
f"An error occurred updating user cache: {str(e)}\n\n{traceback.format_exc()}"
|
f"An error occurred updating user cache: {str(e)}\n\n{traceback.format_exc()}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
### UPDATE END-USER SPEND ###
|
||||||
async def _update_end_user_cache():
|
async def _update_end_user_cache():
|
||||||
if end_user_id is None or response_cost is None:
|
if end_user_id is None or response_cost is None:
|
||||||
return
|
return
|
||||||
|
@ -1102,14 +1119,59 @@ async def update_cache(
|
||||||
f"An error occurred updating end user cache: {str(e)}\n\n{traceback.format_exc()}"
|
f"An error occurred updating end user cache: {str(e)}\n\n{traceback.format_exc()}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
### UPDATE TEAM SPEND ###
|
||||||
|
async def _update_team_cache():
|
||||||
|
if team_id is None or response_cost is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
_id = "team_id:{}".format(team_id)
|
||||||
|
try:
|
||||||
|
# Fetch the existing cost for the given user
|
||||||
|
existing_spend_obj: Optional[LiteLLM_TeamTable] = (
|
||||||
|
await user_api_key_cache.async_get_cache(key=_id)
|
||||||
|
)
|
||||||
|
if existing_spend_obj is None:
|
||||||
|
return
|
||||||
|
verbose_proxy_logger.debug(
|
||||||
|
f"_update_team_db: existing spend: {existing_spend_obj}; response_cost: {response_cost}"
|
||||||
|
)
|
||||||
|
if existing_spend_obj is None:
|
||||||
|
existing_spend: Optional[float] = 0.0
|
||||||
|
else:
|
||||||
|
if isinstance(existing_spend_obj, dict):
|
||||||
|
existing_spend = existing_spend_obj["spend"]
|
||||||
|
else:
|
||||||
|
existing_spend = existing_spend_obj.spend
|
||||||
|
|
||||||
|
if existing_spend is None:
|
||||||
|
existing_spend = 0.0
|
||||||
|
# Calculate the new cost by adding the existing cost and response_cost
|
||||||
|
new_spend = existing_spend + response_cost
|
||||||
|
|
||||||
|
# Update the cost column for the given user
|
||||||
|
if isinstance(existing_spend_obj, dict):
|
||||||
|
existing_spend_obj["spend"] = new_spend
|
||||||
|
user_api_key_cache.set_cache(key=_id, value=existing_spend_obj)
|
||||||
|
else:
|
||||||
|
existing_spend_obj.spend = new_spend
|
||||||
|
user_api_key_cache.set_cache(key=_id, value=existing_spend_obj)
|
||||||
|
except Exception as e:
|
||||||
|
verbose_proxy_logger.error(
|
||||||
|
f"An error occurred updating end user cache: {str(e)}\n\n{traceback.format_exc()}"
|
||||||
|
)
|
||||||
|
|
||||||
if token is not None and response_cost is not None:
|
if token is not None and response_cost is not None:
|
||||||
asyncio.create_task(_update_key_cache(token=token, response_cost=response_cost))
|
asyncio.create_task(_update_key_cache(token=token, response_cost=response_cost))
|
||||||
|
|
||||||
asyncio.create_task(_update_user_cache())
|
if user_id is not None:
|
||||||
|
asyncio.create_task(_update_user_cache())
|
||||||
|
|
||||||
if end_user_id is not None:
|
if end_user_id is not None:
|
||||||
asyncio.create_task(_update_end_user_cache())
|
asyncio.create_task(_update_end_user_cache())
|
||||||
|
|
||||||
|
if team_id is not None:
|
||||||
|
asyncio.create_task(_update_team_cache())
|
||||||
|
|
||||||
|
|
||||||
def run_ollama_serve():
|
def run_ollama_serve():
|
||||||
try:
|
try:
|
||||||
|
@ -2297,25 +2359,27 @@ async def initialize(
|
||||||
user_model = model
|
user_model = model
|
||||||
user_debug = debug
|
user_debug = debug
|
||||||
if debug == True: # this needs to be first, so users can see Router init debugg
|
if debug == True: # this needs to be first, so users can see Router init debugg
|
||||||
from litellm._logging import (
|
|
||||||
verbose_router_logger,
|
|
||||||
verbose_proxy_logger,
|
|
||||||
verbose_logger,
|
|
||||||
)
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
from litellm._logging import (
|
||||||
|
verbose_logger,
|
||||||
|
verbose_proxy_logger,
|
||||||
|
verbose_router_logger,
|
||||||
|
)
|
||||||
|
|
||||||
# this must ALWAYS remain logging.INFO, DO NOT MODIFY THIS
|
# this must ALWAYS remain logging.INFO, DO NOT MODIFY THIS
|
||||||
verbose_logger.setLevel(level=logging.INFO) # sets package logs to info
|
verbose_logger.setLevel(level=logging.INFO) # sets package logs to info
|
||||||
verbose_router_logger.setLevel(level=logging.INFO) # set router logs to info
|
verbose_router_logger.setLevel(level=logging.INFO) # set router logs to info
|
||||||
verbose_proxy_logger.setLevel(level=logging.INFO) # set proxy logs to info
|
verbose_proxy_logger.setLevel(level=logging.INFO) # set proxy logs to info
|
||||||
if detailed_debug == True:
|
if detailed_debug == True:
|
||||||
from litellm._logging import (
|
|
||||||
verbose_router_logger,
|
|
||||||
verbose_proxy_logger,
|
|
||||||
verbose_logger,
|
|
||||||
)
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
from litellm._logging import (
|
||||||
|
verbose_logger,
|
||||||
|
verbose_proxy_logger,
|
||||||
|
verbose_router_logger,
|
||||||
|
)
|
||||||
|
|
||||||
verbose_logger.setLevel(level=logging.DEBUG) # set package log to debug
|
verbose_logger.setLevel(level=logging.DEBUG) # set package log to debug
|
||||||
verbose_router_logger.setLevel(level=logging.DEBUG) # set router logs to debug
|
verbose_router_logger.setLevel(level=logging.DEBUG) # set router logs to debug
|
||||||
verbose_proxy_logger.setLevel(level=logging.DEBUG) # set proxy logs to debug
|
verbose_proxy_logger.setLevel(level=logging.DEBUG) # set proxy logs to debug
|
||||||
|
@ -2324,9 +2388,10 @@ async def initialize(
|
||||||
litellm_log_setting = os.environ.get("LITELLM_LOG", "")
|
litellm_log_setting = os.environ.get("LITELLM_LOG", "")
|
||||||
if litellm_log_setting != None:
|
if litellm_log_setting != None:
|
||||||
if litellm_log_setting.upper() == "INFO":
|
if litellm_log_setting.upper() == "INFO":
|
||||||
from litellm._logging import verbose_router_logger, verbose_proxy_logger
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
from litellm._logging import verbose_proxy_logger, verbose_router_logger
|
||||||
|
|
||||||
# this must ALWAYS remain logging.INFO, DO NOT MODIFY THIS
|
# this must ALWAYS remain logging.INFO, DO NOT MODIFY THIS
|
||||||
|
|
||||||
verbose_router_logger.setLevel(
|
verbose_router_logger.setLevel(
|
||||||
|
@ -2336,9 +2401,10 @@ async def initialize(
|
||||||
level=logging.INFO
|
level=logging.INFO
|
||||||
) # set proxy logs to info
|
) # set proxy logs to info
|
||||||
elif litellm_log_setting.upper() == "DEBUG":
|
elif litellm_log_setting.upper() == "DEBUG":
|
||||||
from litellm._logging import verbose_router_logger, verbose_proxy_logger
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
from litellm._logging import verbose_proxy_logger, verbose_router_logger
|
||||||
|
|
||||||
verbose_router_logger.setLevel(
|
verbose_router_logger.setLevel(
|
||||||
level=logging.DEBUG
|
level=logging.DEBUG
|
||||||
) # set router logs to info
|
) # set router logs to info
|
||||||
|
@ -7036,7 +7102,7 @@ async def google_login(request: Request):
|
||||||
with microsoft_sso:
|
with microsoft_sso:
|
||||||
return await microsoft_sso.get_login_redirect()
|
return await microsoft_sso.get_login_redirect()
|
||||||
elif generic_client_id is not None:
|
elif generic_client_id is not None:
|
||||||
from fastapi_sso.sso.generic import create_provider, DiscoveryDocument
|
from fastapi_sso.sso.generic import DiscoveryDocument, create_provider
|
||||||
|
|
||||||
generic_client_secret = os.getenv("GENERIC_CLIENT_SECRET", None)
|
generic_client_secret = os.getenv("GENERIC_CLIENT_SECRET", None)
|
||||||
generic_scope = os.getenv("GENERIC_SCOPE", "openid email profile").split(" ")
|
generic_scope = os.getenv("GENERIC_SCOPE", "openid email profile").split(" ")
|
||||||
|
@ -7611,7 +7677,7 @@ async def auth_callback(request: Request):
|
||||||
result = await microsoft_sso.verify_and_process(request)
|
result = await microsoft_sso.verify_and_process(request)
|
||||||
elif generic_client_id is not None:
|
elif generic_client_id is not None:
|
||||||
# make generic sso provider
|
# make generic sso provider
|
||||||
from fastapi_sso.sso.generic import create_provider, DiscoveryDocument, OpenID
|
from fastapi_sso.sso.generic import DiscoveryDocument, OpenID, create_provider
|
||||||
|
|
||||||
generic_client_secret = os.getenv("GENERIC_CLIENT_SECRET", None)
|
generic_client_secret = os.getenv("GENERIC_CLIENT_SECRET", None)
|
||||||
generic_scope = os.getenv("GENERIC_SCOPE", "openid email profile").split(" ")
|
generic_scope = os.getenv("GENERIC_SCOPE", "openid email profile").split(" ")
|
||||||
|
|
|
@ -18,6 +18,7 @@ sys.path.insert(
|
||||||
0, os.path.abspath("../..")
|
0, os.path.abspath("../..")
|
||||||
) # Adds the parent directory to the system path
|
) # Adds the parent directory to the system path
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
|
@ -26,6 +27,7 @@ from litellm.caching import DualCache
|
||||||
from litellm.proxy._types import LiteLLM_JWTAuth, LiteLLMRoutes
|
from litellm.proxy._types import LiteLLM_JWTAuth, LiteLLMRoutes
|
||||||
from litellm.proxy.auth.handle_jwt import JWTHandler
|
from litellm.proxy.auth.handle_jwt import JWTHandler
|
||||||
from litellm.proxy.management_endpoints.team_endpoints import new_team
|
from litellm.proxy.management_endpoints.team_endpoints import new_team
|
||||||
|
from litellm.proxy.proxy_server import chat_completion
|
||||||
|
|
||||||
public_key = {
|
public_key = {
|
||||||
"kty": "RSA",
|
"kty": "RSA",
|
||||||
|
@ -220,6 +222,70 @@ def prisma_client():
|
||||||
return prisma_client
|
return prisma_client
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def team_token_tuple():
|
||||||
|
import json
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
import jwt
|
||||||
|
from cryptography.hazmat.backends import default_backend
|
||||||
|
from cryptography.hazmat.primitives import serialization
|
||||||
|
from cryptography.hazmat.primitives.asymmetric import rsa
|
||||||
|
from fastapi import Request
|
||||||
|
from starlette.datastructures import URL
|
||||||
|
|
||||||
|
import litellm
|
||||||
|
from litellm.proxy._types import NewTeamRequest, UserAPIKeyAuth
|
||||||
|
from litellm.proxy.proxy_server import user_api_key_auth
|
||||||
|
|
||||||
|
# Generate a private / public key pair using RSA algorithm
|
||||||
|
key = rsa.generate_private_key(
|
||||||
|
public_exponent=65537, key_size=2048, backend=default_backend()
|
||||||
|
)
|
||||||
|
# Get private key in PEM format
|
||||||
|
private_key = key.private_bytes(
|
||||||
|
encoding=serialization.Encoding.PEM,
|
||||||
|
format=serialization.PrivateFormat.PKCS8,
|
||||||
|
encryption_algorithm=serialization.NoEncryption(),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get public key in PEM format
|
||||||
|
public_key = key.public_key().public_bytes(
|
||||||
|
encoding=serialization.Encoding.PEM,
|
||||||
|
format=serialization.PublicFormat.SubjectPublicKeyInfo,
|
||||||
|
)
|
||||||
|
|
||||||
|
public_key_obj = serialization.load_pem_public_key(
|
||||||
|
public_key, backend=default_backend()
|
||||||
|
)
|
||||||
|
|
||||||
|
# Convert RSA public key object to JWK (JSON Web Key)
|
||||||
|
public_jwk = json.loads(jwt.algorithms.RSAAlgorithm.to_jwk(public_key_obj))
|
||||||
|
|
||||||
|
# VALID TOKEN
|
||||||
|
## GENERATE A TOKEN
|
||||||
|
# Assuming the current time is in UTC
|
||||||
|
expiration_time = int((datetime.utcnow() + timedelta(minutes=10)).timestamp())
|
||||||
|
|
||||||
|
team_id = f"team123_{uuid.uuid4()}"
|
||||||
|
payload = {
|
||||||
|
"sub": "user123",
|
||||||
|
"exp": expiration_time, # set the token to expire in 10 minutes
|
||||||
|
"scope": "litellm_team",
|
||||||
|
"client_id": team_id,
|
||||||
|
"aud": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Generate the JWT token
|
||||||
|
# But before, you should convert bytes to string
|
||||||
|
private_key_str = private_key.decode("utf-8")
|
||||||
|
|
||||||
|
## team token
|
||||||
|
token = jwt.encode(payload, private_key_str, algorithm="RS256")
|
||||||
|
|
||||||
|
return team_id, token, public_jwk
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("audience", [None, "litellm-proxy"])
|
@pytest.mark.parametrize("audience", [None, "litellm-proxy"])
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_team_token_output(prisma_client, audience):
|
async def test_team_token_output(prisma_client, audience):
|
||||||
|
@ -750,3 +816,33 @@ async def test_allowed_routes_admin(prisma_client, audience):
|
||||||
result = await user_api_key_auth(request=request, api_key=bearer_token)
|
result = await user_api_key_auth(request=request, api_key=bearer_token)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_team_cache_update_called():
|
||||||
|
import litellm
|
||||||
|
from litellm.proxy.proxy_server import user_api_key_cache
|
||||||
|
|
||||||
|
# Use setattr to replace the method on the user_api_key_cache object
|
||||||
|
cache = DualCache()
|
||||||
|
|
||||||
|
setattr(
|
||||||
|
litellm.proxy.proxy_server,
|
||||||
|
"user_api_key_cache",
|
||||||
|
cache,
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch.object(cache, "async_get_cache", new=AsyncMock()) as mock_call_cache:
|
||||||
|
cache.async_get_cache = mock_call_cache
|
||||||
|
# Call the function under test
|
||||||
|
await litellm.proxy.proxy_server.update_cache(
|
||||||
|
token=None, user_id=None, end_user_id=None, team_id="1234", response_cost=20
|
||||||
|
) # type: ignore
|
||||||
|
|
||||||
|
await asyncio.sleep(3)
|
||||||
|
mock_call_cache.assert_awaited_once()
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue