diff --git a/litellm/proxy/_super_secret_config.yaml b/litellm/proxy/_super_secret_config.yaml
index b5946afd6d..c19076f299 100644
--- a/litellm/proxy/_super_secret_config.yaml
+++ b/litellm/proxy/_super_secret_config.yaml
@@ -73,16 +73,12 @@ assistant_settings:
 
 router_settings:
   enable_pre_call_checks: true
-
-litellm_settings:
-  callbacks: ["lago"]
-  success_callback: ["langfuse"]
-  failure_callback: ["langfuse"]
-  cache: true
-  json_logs: true
-
+
 general_settings:
   alerting: ["slack"]
+  enable_jwt_auth: True
+  litellm_jwtauth:
+    team_id_jwt_field: "client_id"
   # key_management_system: "aws_kms"
   # key_management_settings:
   #   hosted_keys: ["LITELLM_MASTER_KEY"]
diff --git a/litellm/proxy/auth/auth_checks.py b/litellm/proxy/auth/auth_checks.py
index e32d56706a..9c3a79f583 100644
--- a/litellm/proxy/auth/auth_checks.py
+++ b/litellm/proxy/auth/auth_checks.py
@@ -8,21 +8,22 @@ Run checks for:
 2. If user is in budget
 3. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget
 """
+from datetime import datetime
+from typing import TYPE_CHECKING, Any, Literal, Optional
+
+import litellm
+from litellm.caching import DualCache
 from litellm.proxy._types import (
-    LiteLLM_UserTable,
     LiteLLM_EndUserTable,
     LiteLLM_JWTAuth,
-    LiteLLM_TeamTable,
-    LiteLLMRoutes,
     LiteLLM_OrganizationTable,
+    LiteLLM_TeamTable,
+    LiteLLM_UserTable,
+    LiteLLMRoutes,
     LitellmUserRoles,
 )
-from typing import Optional, Literal, TYPE_CHECKING, Any
 from litellm.proxy.utils import PrismaClient, ProxyLogging, log_to_opentelemetry
-from litellm.caching import DualCache
-import litellm
 from litellm.types.services import ServiceLoggerPayload, ServiceTypes
-from datetime import datetime
 
 if TYPE_CHECKING:
     from opentelemetry.trace import Span as _Span
@@ -110,7 +111,7 @@ def common_checks(
         # Enterprise ONLY Feature
         # we already validate if user is premium_user when reading the config
         # Add an extra premium_usercheck here too, just incase
-        from litellm.proxy.proxy_server import premium_user, CommonProxyErrors
+        from litellm.proxy.proxy_server import CommonProxyErrors, premium_user
 
         if premium_user is not True:
             raise ValueError(
@@ -364,7 +365,8 @@ async def get_team_object(
         )
 
     # check if in cache
-    cached_team_obj = await user_api_key_cache.async_get_cache(key=team_id)
+    key = "team_id:{}".format(team_id)
+    cached_team_obj = await user_api_key_cache.async_get_cache(key=key)
     if cached_team_obj is not None:
         if isinstance(cached_team_obj, dict):
             return LiteLLM_TeamTable(**cached_team_obj)
@@ -381,7 +383,7 @@ async def get_team_object(
         _response = LiteLLM_TeamTable(**response.dict())
 
         # save the team object to cache
-        await user_api_key_cache.async_set_cache(key=response.team_id, value=_response)
+        await user_api_key_cache.async_set_cache(key=key, value=_response)
 
         return _response
     except Exception as e:
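Note on the `get_team_object` change above: team objects are now cached under a namespaced key, `"team_id:<team_id>"`, instead of the bare team id, and the same `key` variable is reused on the write path so reads and writes can no longer drift apart. A minimal sketch of the convention, using a standalone `DualCache` (the proxy itself uses its module-level `user_api_key_cache`; the default in-memory construction here is an assumption):

```python
import asyncio

from litellm.caching import DualCache


async def main():
    cache = DualCache()  # in-memory only when no Redis is configured

    team_id = "team-abc"
    key = "team_id:{}".format(team_id)  # same prefix get_team_object now uses

    # reader and writer must agree on the namespaced key; writing under one
    # key (e.g. response.team_id) while reading another silently misses
    await cache.async_set_cache(key=key, value={"team_id": team_id, "spend": 0.0})
    print(await cache.async_get_cache(key=key))


asyncio.run(main())
```

The `_update_team_cache` helper added to `proxy_server.py` later in this diff builds the identical `"team_id:{}"` key, which appears to be the point of aligning both sides on one prefix.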
diff --git a/litellm/proxy/auth/user_api_key_auth.py b/litellm/proxy/auth/user_api_key_auth.py
index be3d801e72..9ab76b8d84 100644
--- a/litellm/proxy/auth/user_api_key_auth.py
+++ b/litellm/proxy/auth/user_api_key_auth.py
@@ -120,7 +120,7 @@ async def user_api_key_auth(
             )
         ### USER-DEFINED AUTH FUNCTION ###
         if user_custom_auth is not None:
-            response = await user_custom_auth(request=request, api_key=api_key)
+            response = await user_custom_auth(request=request, api_key=api_key)  # type: ignore
             return UserAPIKeyAuth.model_validate(response)
 
         ### LITELLM-DEFINED AUTH FUNCTION ###
@@ -140,7 +140,7 @@ async def user_api_key_auth(
             # check if public endpoint
             return UserAPIKeyAuth(user_role=LitellmUserRoles.INTERNAL_USER_VIEW_ONLY)
 
-        if general_settings.get("enable_jwt_auth", False) == True:
+        if general_settings.get("enable_jwt_auth", False) is True:
            is_jwt = jwt_handler.is_jwt(token=api_key)
            verbose_proxy_logger.debug("is_jwt: %s", is_jwt)
            if is_jwt:
@@ -177,7 +177,7 @@ async def user_api_key_auth(
                     token=jwt_valid_token, default_value=None
                 )
 
-                if team_id is None and jwt_handler.is_required_team_id() == True:
+                if team_id is None and jwt_handler.is_required_team_id() is True:
                     raise Exception(
                         f"No team id passed in. Field checked in jwt token - '{jwt_handler.litellm_jwtauth.team_id_jwt_field}'"
                     )
@@ -190,7 +190,7 @@ async def user_api_key_auth(
                         user_route=route,
                         litellm_proxy_roles=jwt_handler.litellm_jwtauth,
                     )
-                    if is_allowed == False:
+                    if is_allowed is False:
                         allowed_routes = jwt_handler.litellm_jwtauth.team_allowed_routes  # type: ignore
                         actual_routes = get_actual_routes(allowed_routes=allowed_routes)
                         raise Exception(
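The `== True` → `is True` substitutions in `user_api_key_auth.py` are slightly stricter, not just cosmetic: `==` accepts any value that compares equal to `True` (such as `1`), while `is` requires the `True` singleton itself. A quick illustration:

```python
# a config value that arrived as an int rather than a bool,
# e.g. from loosely-typed JSON or an env-var conversion
enable_jwt_auth = 1

print(enable_jwt_auth == True)  # True  - value equality, truthy ints pass
print(enable_jwt_auth is True)  # False - identity check against the singleton

enable_jwt_auth = True
print(enable_jwt_auth is True)  # True
```

With `is True`, a non-bool truthy value in `general_settings` now leaves JWT auth disabled rather than silently enabling it.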
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index d36d2e157b..f50c138c25 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -1,12 +1,26 @@
-import sys, os, platform, time, copy, re, asyncio, inspect
-import threading, ast
-import shutil, random, traceback, requests
-from datetime import datetime, timedelta, timezone
-from typing import Optional, List, Callable, get_args, Set, Any, TYPE_CHECKING
-import secrets, subprocess
-import hashlib, uuid
-import warnings
+import ast
+import asyncio
+import copy
+import hashlib
 import importlib
+import inspect
+import os
+import platform
+import random
+import re
+import secrets
+import shutil
+import subprocess
+import sys
+import threading
+import time
+import traceback
+import uuid
+import warnings
+from datetime import datetime, timedelta, timezone
+from typing import TYPE_CHECKING, Any, Callable, List, Optional, Set, get_args
+
+import requests
 
 if TYPE_CHECKING:
     from opentelemetry.trace import Span as _Span
@@ -34,11 +48,12 @@ sys.path.insert(
 )  # Adds the parent directory to the system path - for litellm local dev
 
 try:
-    import fastapi
-    import backoff
-    import yaml  # type: ignore
-    import orjson
     import logging
+
+    import backoff
+    import fastapi
+    import orjson
+    import yaml  # type: ignore
     from apscheduler.schedulers.asyncio import AsyncIOScheduler
 except ImportError as e:
     raise ImportError(f"Missing dependency {e}. Run `pip install 'litellm[proxy]'`")
@@ -85,60 +100,31 @@ def generate_feedback_box():
     print()  # noqa
 
 
-import litellm
-from litellm.types.llms.openai import (
-    HttpxBinaryResponseContent,
-)
-from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request
-from litellm.proxy.utils import (
-    PrismaClient,
-    DBClient,
-    get_instance_fn,
-    ProxyLogging,
-    _cache_user_row,
-    send_email,
-    get_logging_payload,
-    reset_budget,
-    hash_token,
-    html_form,
-    missing_keys_html_form,
-    _is_valid_team_configs,
-    _is_projected_spend_over_limit,
-    _get_projected_spend_over_limit,
-    update_spend,
-    encrypt_value,
-    decrypt_value,
-    get_error_message_str,
-)
-from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
-
-from litellm import (
-    CreateBatchRequest,
-    RetrieveBatchRequest,
-    ListBatchRequest,
-    CancelBatchRequest,
-    CreateFileRequest,
-)
-from litellm.proxy.secret_managers.google_kms import load_google_kms
-from litellm.proxy.secret_managers.aws_secret_manager import (
-    load_aws_secret_manager,
-    load_aws_kms,
-)
 import pydantic
-from litellm.proxy._types import *
-from litellm.caching import DualCache, RedisCache
-from litellm.proxy.health_check import perform_health_check
-from litellm.router import (
-    LiteLLM_Params,
-    Deployment,
-    updateDeployment,
-    ModelGroupInfo,
-    AssistantsTypedDict,
+
+import litellm
+from litellm import (
+    CancelBatchRequest,
+    CreateBatchRequest,
+    CreateFileRequest,
+    ListBatchRequest,
+    RetrieveBatchRequest,
 )
-from litellm.router import ModelInfo as RouterModelInfo
-from litellm._logging import (
-    verbose_router_logger,
-    verbose_proxy_logger,
+from litellm._logging import verbose_proxy_logger, verbose_router_logger
+from litellm.caching import DualCache, RedisCache
+from litellm.exceptions import RejectedRequestError
+from litellm.integrations.slack_alerting import SlackAlerting, SlackAlertingArgs
+from litellm.llms.custom_httpx.httpx_handler import HTTPHandler
+from litellm.proxy._types import *
+from litellm.proxy.auth.auth_checks import (
+    allowed_routes_check,
+    common_checks,
+    get_actual_routes,
+    get_end_user_object,
+    get_org_object,
+    get_team_object,
+    get_user_object,
+    log_to_opentelemetry,
 )
 from litellm.proxy.auth.handle_jwt import JWTHandler
 from litellm.proxy.auth.litellm_license import LicenseCheck
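The wholesale import reshuffling in these hunks matches isort's conventions: stdlib, third-party, and first-party groups separated by blank lines, each group alphabetized, with the members of `from x import (...)` blocks sorted as well. Whether the repo enforces this in CI is not shown here; a sketch of reproducing the effect with isort's Python API, assuming isort >= 5 is installed:

```python
import isort

messy = (
    "from litellm.caching import DualCache\n"
    "import litellm\n"
    "from datetime import datetime\n"
    "import requests\n"
)

# isort.code() returns the re-grouped, alphabetized block -
# the same shape as the rewritten imports above
print(isort.code(messy))
```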
@@ -148,78 +134,105 @@ from litellm.proxy.auth.model_checks import (
     get_team_models,
 )
 from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
-from litellm.proxy.hooks.prompt_injection_detection import (
-    _OPTIONAL_PromptInjectionDetection,
-)
-from litellm.proxy.auth.auth_checks import (
-    common_checks,
-    get_end_user_object,
-    get_org_object,
-    get_team_object,
-    get_user_object,
-    allowed_routes_check,
-    get_actual_routes,
-    log_to_opentelemetry,
-)
-from litellm.llms.custom_httpx.httpx_handler import HTTPHandler
-from litellm.exceptions import RejectedRequestError
-from litellm.integrations.slack_alerting import SlackAlertingArgs, SlackAlerting
-from litellm.scheduler import Scheduler, FlowItem, DefaultPriorities
 
 ## Import All Misc routes here ##
 from litellm.proxy.caching_routes import router as caching_router
-from litellm.proxy.management_endpoints.team_endpoints import router as team_router
-from litellm.proxy.spend_reporting_endpoints.spend_management_endpoints import (
-    router as spend_management_router,
+from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
+from litellm.proxy.health_check import perform_health_check
+from litellm.proxy.health_endpoints._health_endpoints import router as health_router
+from litellm.proxy.hooks.prompt_injection_detection import (
+    _OPTIONAL_PromptInjectionDetection,
+)
+from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request
+from litellm.proxy.management_endpoints.internal_user_endpoints import (
+    router as internal_user_router,
+)
+from litellm.proxy.management_endpoints.internal_user_endpoints import user_update
+from litellm.proxy.management_endpoints.key_management_endpoints import (
+    _duration_in_seconds,
+    delete_verification_token,
+    generate_key_helper_fn,
 )
 from litellm.proxy.management_endpoints.key_management_endpoints import (
     router as key_management_router,
-    _duration_in_seconds,
-    generate_key_helper_fn,
-    delete_verification_token,
 )
-from litellm.proxy.management_endpoints.internal_user_endpoints import (
-    router as internal_user_router,
-    user_update,
+from litellm.proxy.management_endpoints.team_endpoints import router as team_router
+from litellm.proxy.secret_managers.aws_secret_manager import (
+    load_aws_kms,
+    load_aws_secret_manager,
 )
-from litellm.proxy.health_endpoints._health_endpoints import router as health_router
+from litellm.proxy.secret_managers.google_kms import load_google_kms
+from litellm.proxy.spend_reporting_endpoints.spend_management_endpoints import (
+    router as spend_management_router,
+)
+from litellm.proxy.utils import (
+    DBClient,
+    PrismaClient,
+    ProxyLogging,
+    _cache_user_row,
+    _get_projected_spend_over_limit,
+    _is_projected_spend_over_limit,
+    _is_valid_team_configs,
+    decrypt_value,
+    encrypt_value,
+    get_error_message_str,
+    get_instance_fn,
+    get_logging_payload,
+    hash_token,
+    html_form,
+    missing_keys_html_form,
+    reset_budget,
+    send_email,
+    update_spend,
+)
+from litellm.router import (
+    AssistantsTypedDict,
+    Deployment,
+    LiteLLM_Params,
+    ModelGroupInfo,
+)
+from litellm.router import ModelInfo as RouterModelInfo
+from litellm.router import updateDeployment
+from litellm.scheduler import DefaultPriorities, FlowItem, Scheduler
+from litellm.types.llms.openai import HttpxBinaryResponseContent
 
 try:
     from litellm._version import version
 except:
     version = "0.0.0"
 litellm.suppress_debug_info = True
-from fastapi import (
-    FastAPI,
-    Request,
-    HTTPException,
-    status,
-    Path,
-    Depends,
-    Header,
-    Response,
-    Form,
-    UploadFile,
-    File,
-)
-from fastapi.routing import APIRouter
-from fastapi.security import OAuth2PasswordBearer
-from fastapi.encoders import jsonable_encoder
-from fastapi.responses import (
-    StreamingResponse,
-    FileResponse,
-    ORJSONResponse,
-    JSONResponse,
-)
-from fastapi.openapi.utils import get_openapi
-from fastapi.responses import RedirectResponse
-from fastapi.middleware.cors import CORSMiddleware
-from fastapi.staticfiles import StaticFiles
-from fastapi.security.api_key import APIKeyHeader
 import json
 import logging
 from typing import Union
 
+from fastapi import (
+    Depends,
+    FastAPI,
+    File,
+    Form,
+    Header,
+    HTTPException,
+    Path,
+    Request,
+    Response,
+    UploadFile,
+    status,
+)
+from fastapi.encoders import jsonable_encoder
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.openapi.utils import get_openapi
+from fastapi.responses import (
+    FileResponse,
+    JSONResponse,
+    ORJSONResponse,
+    RedirectResponse,
+    StreamingResponse,
+)
+from fastapi.routing import APIRouter
+from fastapi.security import OAuth2PasswordBearer
+from fastapi.security.api_key import APIKeyHeader
+from fastapi.staticfiles import StaticFiles
+
 # import enterprise folder
 try:
     # when using litellm cli
@@ -488,8 +501,8 @@ def load_from_azure_key_vault(use_azure_key_vault: bool = False):
         return
 
     try:
-        from azure.keyvault.secrets import SecretClient
         from azure.identity import ClientSecretCredential
+        from azure.keyvault.secrets import SecretClient
 
         # Set your Azure Key Vault URI
         KVUri = os.getenv("AZURE_KEY_VAULT_URI", None)
@@ -655,6 +668,7 @@ async def _PROXY_track_cost_callback(
                     user_id=user_id,
                     end_user_id=end_user_id,
                     response_cost=response_cost,
+                    team_id=team_id,
                 )
 
                 await proxy_logging_obj.slack_alerting_instance.customer_spend_alert(
@@ -897,6 +911,7 @@ async def update_cache(
     token: Optional[str],
     user_id: Optional[str],
     end_user_id: Optional[str],
+    team_id: Optional[str],
     response_cost: Optional[float],
 ):
     """
@@ -993,6 +1008,7 @@ async def update_cache(
             existing_spend_obj.spend = new_spend
             user_api_key_cache.set_cache(key=hashed_token, value=existing_spend_obj)
 
+    ### UPDATE USER SPEND ###
     async def _update_user_cache():
         ## UPDATE CACHE FOR USER ID + GLOBAL PROXY
         user_ids = [user_id]
@@ -1054,6 +1070,7 @@ async def update_cache(
                 f"An error occurred updating user cache: {str(e)}\n\n{traceback.format_exc()}"
             )
 
+    ### UPDATE END-USER SPEND ###
     async def _update_end_user_cache():
         if end_user_id is None or response_cost is None:
             return
@@ -1102,14 +1119,59 @@ async def update_cache(
                 f"An error occurred updating end user cache: {str(e)}\n\n{traceback.format_exc()}"
             )
 
+    ### UPDATE TEAM SPEND ###
+    async def _update_team_cache():
+        if team_id is None or response_cost is None:
+            return
+
+        _id = "team_id:{}".format(team_id)
+        try:
+            # Fetch the existing spend for the given team
+            existing_spend_obj: Optional[LiteLLM_TeamTable] = (
+                await user_api_key_cache.async_get_cache(key=_id)
+            )
+            if existing_spend_obj is None:
+                return
+            verbose_proxy_logger.debug(
+                f"_update_team_cache: existing spend: {existing_spend_obj}; response_cost: {response_cost}"
+            )
+            if existing_spend_obj is None:
+                existing_spend: Optional[float] = 0.0
+            else:
+                if isinstance(existing_spend_obj, dict):
+                    existing_spend = existing_spend_obj["spend"]
+                else:
+                    existing_spend = existing_spend_obj.spend
+
+            if existing_spend is None:
+                existing_spend = 0.0
+            # Calculate the new cost by adding the existing cost and response_cost
+            new_spend = existing_spend + response_cost
+
+            # Update the spend for the given team
+            if isinstance(existing_spend_obj, dict):
+                existing_spend_obj["spend"] = new_spend
+                user_api_key_cache.set_cache(key=_id, value=existing_spend_obj)
+            else:
+                existing_spend_obj.spend = new_spend
+                user_api_key_cache.set_cache(key=_id, value=existing_spend_obj)
+        except Exception as e:
+            verbose_proxy_logger.error(
+                f"An error occurred updating team cache: {str(e)}\n\n{traceback.format_exc()}"
+            )
+
     if token is not None and response_cost is not None:
         asyncio.create_task(_update_key_cache(token=token, response_cost=response_cost))
 
-    asyncio.create_task(_update_user_cache())
+    if user_id is not None:
+        asyncio.create_task(_update_user_cache())
 
     if end_user_id is not None:
         asyncio.create_task(_update_end_user_cache())
 
+    if team_id is not None:
+        asyncio.create_task(_update_team_cache())
+
 
 def run_ollama_serve():
     try:
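`_update_team_cache` above repeats the read-modify-write shape of the key/user/end-user helpers: fetch the cached object under the `team_id:`-prefixed key, add `response_cost` to its `spend`, and write it back. Note the early return on a cache miss: if the team was never cached, the in-memory spend update is skipped entirely (the database update path is separate and not shown in this diff). A stripped-down sketch of the pattern; the real helper also tolerates object-shaped cache entries and logs instead of raising:

```python
import asyncio

from litellm.caching import DualCache


async def add_team_spend(cache: DualCache, team_id: str, response_cost: float) -> None:
    key = "team_id:{}".format(team_id)
    team_obj = await cache.async_get_cache(key=key)
    if team_obj is None:
        return  # cache miss: nothing to increment, same early return as above
    team_obj["spend"] = (team_obj.get("spend") or 0.0) + response_cost
    cache.set_cache(key=key, value=team_obj)


async def main():
    cache = DualCache()
    await cache.async_set_cache(key="team_id:team-abc", value={"spend": 1.0})
    await add_team_spend(cache, "team-abc", 0.25)
    print(await cache.async_get_cache(key="team_id:team-abc"))  # {'spend': 1.25}


asyncio.run(main())
```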
@@ -2297,25 +2359,27 @@ async def initialize(
     user_model = model
     user_debug = debug
     if debug == True:  # this needs to be first, so users can see Router init debugg
-        from litellm._logging import (
-            verbose_router_logger,
-            verbose_proxy_logger,
-            verbose_logger,
-        )
         import logging
 
+        from litellm._logging import (
+            verbose_logger,
+            verbose_proxy_logger,
+            verbose_router_logger,
+        )
+
         # this must ALWAYS remain logging.INFO, DO NOT MODIFY THIS
         verbose_logger.setLevel(level=logging.INFO)  # sets package logs to info
         verbose_router_logger.setLevel(level=logging.INFO)  # set router logs to info
         verbose_proxy_logger.setLevel(level=logging.INFO)  # set proxy logs to info
     if detailed_debug == True:
-        from litellm._logging import (
-            verbose_router_logger,
-            verbose_proxy_logger,
-            verbose_logger,
-        )
         import logging
 
+        from litellm._logging import (
+            verbose_logger,
+            verbose_proxy_logger,
+            verbose_router_logger,
+        )
+
         verbose_logger.setLevel(level=logging.DEBUG)  # set package log to debug
         verbose_router_logger.setLevel(level=logging.DEBUG)  # set router logs to debug
         verbose_proxy_logger.setLevel(level=logging.DEBUG)  # set proxy logs to debug
@@ -2324,9 +2388,10 @@ async def initialize(
     litellm_log_setting = os.environ.get("LITELLM_LOG", "")
     if litellm_log_setting != None:
         if litellm_log_setting.upper() == "INFO":
-            from litellm._logging import verbose_router_logger, verbose_proxy_logger
             import logging
 
+            from litellm._logging import verbose_proxy_logger, verbose_router_logger
+
             # this must ALWAYS remain logging.INFO, DO NOT MODIFY THIS
 
             verbose_router_logger.setLevel(
                 level=logging.INFO
             )  # set router logs to info
@@ -2336,9 +2401,10 @@ async def initialize(
                 level=logging.INFO
             )  # set proxy logs to info
         elif litellm_log_setting.upper() == "DEBUG":
-            from litellm._logging import verbose_router_logger, verbose_proxy_logger
             import logging
 
+            from litellm._logging import verbose_proxy_logger, verbose_router_logger
+
             verbose_router_logger.setLevel(
                 level=logging.DEBUG
             )  # set router logs to info
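These `initialize` hunks only reorder imports, but they surround the env-var path for controlling verbosity: `LITELLM_LOG` is read case-insensitively and mapped onto the router and proxy loggers. A self-contained sketch of that dispatch; the names come straight from the hunks above, and any behavior beyond them is an assumption:

```python
import logging
import os

from litellm._logging import verbose_proxy_logger, verbose_router_logger

# mirror of the LITELLM_LOG handling inside initialize()
litellm_log_setting = os.environ.get("LITELLM_LOG", "")
if litellm_log_setting.upper() == "INFO":
    verbose_router_logger.setLevel(level=logging.INFO)
    verbose_proxy_logger.setLevel(level=logging.INFO)
elif litellm_log_setting.upper() == "DEBUG":
    verbose_router_logger.setLevel(level=logging.DEBUG)
    verbose_proxy_logger.setLevel(level=logging.DEBUG)
```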
@@ -7036,7 +7102,7 @@ async def google_login(request: Request):
         with microsoft_sso:
             return await microsoft_sso.get_login_redirect()
     elif generic_client_id is not None:
-        from fastapi_sso.sso.generic import create_provider, DiscoveryDocument
+        from fastapi_sso.sso.generic import DiscoveryDocument, create_provider
 
         generic_client_secret = os.getenv("GENERIC_CLIENT_SECRET", None)
         generic_scope = os.getenv("GENERIC_SCOPE", "openid email profile").split(" ")
@@ -7611,7 +7677,7 @@ async def auth_callback(request: Request):
         result = await microsoft_sso.verify_and_process(request)
     elif generic_client_id is not None:
         # make generic sso provider
-        from fastapi_sso.sso.generic import create_provider, DiscoveryDocument, OpenID
+        from fastapi_sso.sso.generic import DiscoveryDocument, OpenID, create_provider
 
         generic_client_secret = os.getenv("GENERIC_CLIENT_SECRET", None)
         generic_scope = os.getenv("GENERIC_SCOPE", "openid email profile").split(" ")
diff --git a/litellm/tests/test_jwt.py b/litellm/tests/test_jwt.py
index 960b85af59..72d4d7b1bf 100644
--- a/litellm/tests/test_jwt.py
+++ b/litellm/tests/test_jwt.py
@@ -18,6 +18,7 @@ sys.path.insert(
     0, os.path.abspath("../..")
 )  # Adds the parent directory to the system path
 from datetime import datetime, timedelta
+from unittest.mock import AsyncMock, MagicMock, patch
 
 import pytest
 from fastapi import Request
@@ -26,6 +27,7 @@ from litellm.caching import DualCache
 from litellm.proxy._types import LiteLLM_JWTAuth, LiteLLMRoutes
 from litellm.proxy.auth.handle_jwt import JWTHandler
 from litellm.proxy.management_endpoints.team_endpoints import new_team
+from litellm.proxy.proxy_server import chat_completion
 
 public_key = {
     "kty": "RSA",
@@ -220,6 +222,70 @@ def prisma_client():
     return prisma_client
 
 
+@pytest.fixture
+def team_token_tuple():
+    import json
+    import uuid
+
+    import jwt
+    from cryptography.hazmat.backends import default_backend
+    from cryptography.hazmat.primitives import serialization
+    from cryptography.hazmat.primitives.asymmetric import rsa
+    from fastapi import Request
+    from starlette.datastructures import URL
+
+    import litellm
+    from litellm.proxy._types import NewTeamRequest, UserAPIKeyAuth
+    from litellm.proxy.proxy_server import user_api_key_auth
+
+    # Generate a private / public key pair using the RSA algorithm
+    key = rsa.generate_private_key(
+        public_exponent=65537, key_size=2048, backend=default_backend()
+    )
+    # Get the private key in PEM format
+    private_key = key.private_bytes(
+        encoding=serialization.Encoding.PEM,
+        format=serialization.PrivateFormat.PKCS8,
+        encryption_algorithm=serialization.NoEncryption(),
+    )
+
+    # Get the public key in PEM format
+    public_key = key.public_key().public_bytes(
+        encoding=serialization.Encoding.PEM,
+        format=serialization.PublicFormat.SubjectPublicKeyInfo,
+    )
+
+    public_key_obj = serialization.load_pem_public_key(
+        public_key, backend=default_backend()
+    )
+
+    # Convert the RSA public key object to a JWK (JSON Web Key)
+    public_jwk = json.loads(jwt.algorithms.RSAAlgorithm.to_jwk(public_key_obj))
+
+    # VALID TOKEN
+    ## GENERATE A TOKEN
+    # Assuming the current time is in UTC
+    expiration_time = int((datetime.utcnow() + timedelta(minutes=10)).timestamp())
+
+    team_id = f"team123_{uuid.uuid4()}"
+    payload = {
+        "sub": "user123",
+        "exp": expiration_time,  # set the token to expire in 10 minutes
+        "scope": "litellm_team",
+        "client_id": team_id,
+        "aud": None,
+    }
+
+    # Generate the JWT token
+    # jwt.encode expects the key as a string, so decode the PEM bytes first
+    private_key_str = private_key.decode("utf-8")
+
+    ## team token
+    token = jwt.encode(payload, private_key_str, algorithm="RS256")
+
+    return team_id, token, public_jwk
+
+
 @pytest.mark.parametrize("audience", [None, "litellm-proxy"])
 @pytest.mark.asyncio
 async def test_team_token_output(prisma_client, audience):
@@ -750,3 +816,33 @@ async def test_allowed_routes_admin(prisma_client, audience):
         result = await user_api_key_auth(request=request, api_key=bearer_token)
     except Exception as e:
         raise e
+
+
+from unittest.mock import AsyncMock
+
+import pytest
+
+
+@pytest.mark.asyncio
+async def test_team_cache_update_called():
+    import litellm
+    from litellm.proxy.proxy_server import user_api_key_cache
+
+    # Use setattr to swap a fresh DualCache onto the proxy_server module
+    cache = DualCache()
+
+    setattr(
+        litellm.proxy.proxy_server,
+        "user_api_key_cache",
+        cache,
+    )
+
+    with patch.object(cache, "async_get_cache", new=AsyncMock()) as mock_call_cache:
+        cache.async_get_cache = mock_call_cache
+        # Call the function under test
+        await litellm.proxy.proxy_server.update_cache(
+            token=None, user_id=None, end_user_id=None, team_id="1234", response_cost=20
+        )  # type: ignore
+
+        await asyncio.sleep(3)
+        mock_call_cache.assert_awaited_once()
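The new `team_token_tuple` fixture mints a fresh RSA keypair, exposes the public half as a JWK, and signs a token whose `client_id` claim carries the team id — the same claim `team_id_jwt_field: "client_id"` points the proxy at in the config change at the top of this diff. No test in this excerpt consumes the fixture yet; a hypothetical consumer (the test below is illustrative only, not part of the PR) might start like this:

```python
import json

import pytest


@pytest.mark.asyncio
async def test_team_token_decodes(team_token_tuple):
    import jwt

    team_id, token, public_jwk = team_token_tuple

    # verify the token against the JWK the fixture returned
    public_key = jwt.algorithms.RSAAlgorithm.from_jwk(json.dumps(public_jwk))
    payload = jwt.decode(
        token,
        public_key,
        algorithms=["RS256"],
        options={"verify_aud": False},  # the fixture sets "aud" to None
    )

    # team_id_jwt_field is configured as "client_id", so this is the
    # claim the proxy reads the team id from
    assert payload["client_id"] == team_id
    assert payload["scope"] == "litellm_team"
```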