diff --git a/litellm/integrations/custom_guardrail.py b/litellm/integrations/custom_guardrail.py
index 3053a4ad1f..816b024c72 100644
--- a/litellm/integrations/custom_guardrail.py
+++ b/litellm/integrations/custom_guardrail.py
@@ -29,6 +29,7 @@ class CustomGuardrail(CustomLogger):
     def should_run_guardrail(self, data, event_type: GuardrailEventHooks) -> bool:
         metadata = data.get("metadata") or {}
         requested_guardrails = metadata.get("guardrails") or []
+
         verbose_logger.debug(
             "inside should_run_guardrail for guardrail=%s event_type= %s guardrail_supported_event_hooks= %s requested_guardrails= %s",
             self.guardrail_name,
diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py
index fed1cc2863..e2fad98387 100644
--- a/litellm/litellm_core_utils/litellm_logging.py
+++ b/litellm/litellm_core_utils/litellm_logging.py
@@ -969,10 +969,10 @@ class Logging:
                 ),
                 result=result,
             )
-
         ## LOGGING HOOK ##
         for callback in callbacks:
             if isinstance(callback, CustomLogger):
+
                 self.model_call_details, result = callback.logging_hook(
                     kwargs=self.model_call_details,
                     result=result,
diff --git a/litellm/proxy/_experimental/out/404.html b/litellm/proxy/_experimental/out/404.html
deleted file mode 100644
index 6656e5452d..0000000000
--- a/litellm/proxy/_experimental/out/404.html
+++ /dev/null
@@ -1 +0,0 @@
-404: This page could not be found.LiteLLM Dashboard
\ No newline at end of file
diff --git a/litellm/proxy/_experimental/out/model_hub.html b/litellm/proxy/_experimental/out/model_hub.html
deleted file mode 100644
index ffebd4c6e2..0000000000
--- a/litellm/proxy/_experimental/out/model_hub.html
+++ /dev/null
@@ -1 +0,0 @@
-LiteLLM Dashboard
\ No newline at end of file
diff --git a/litellm/proxy/_experimental/out/onboarding.html b/litellm/proxy/_experimental/out/onboarding.html
deleted file mode 100644
index f59bdc9d5d..0000000000
--- a/litellm/proxy/_experimental/out/onboarding.html
+++ /dev/null
@@ -1 +0,0 @@
-LiteLLM Dashboard
\ No newline at end of file
diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml
index 599a1bf230..84075f53e0 100644
--- a/litellm/proxy/_new_secret_config.yaml
+++ b/litellm/proxy/_new_secret_config.yaml
@@ -1,48 +1,30 @@
 model_list:
-  # GPT-4 Turbo Models
-  - model_name: gpt-4
-    litellm_params:
-      model: gpt-4
-      rpm: 1
-  - model_name: gpt-4
+  - model_name: gpt-3.5-turbo
     litellm_params:
       model: azure/chatgpt-v-2
       api_key: os.environ/AZURE_API_KEY
       api_base: os.environ/AZURE_API_BASE
-  - model_name: rerank-model
+      temperature: 0.2
+
+guardrails:
+  - guardrail_name: "presidio-log-guard"
     litellm_params:
-      model: jina_ai/jina-reranker-v2-base-multilingual
-  - model_name: anthropic-vertex
-    litellm_params:
-      model: vertex_ai/claude-3-5-sonnet-v2
-      vertex_ai_project: "adroit-crow-413218"
-      vertex_ai_location: "us-east5"
-  - model_name: openai-gpt-4o-realtime-audio
-    litellm_params:
-      model: openai/gpt-4o-realtime-preview-2024-10-01
-      api_key: os.environ/OPENAI_API_KEY
-  - model_name: openai/*
-    litellm_params:
-      model: openai/*
-      api_key: os.environ/OPENAI_API_KEY
-  - model_name: openai/*
-    litellm_params:
-      model: openai/*
-      api_key: os.environ/OPENAI_API_KEY
-    model_info:
-      access_groups: ["public-openai-models"]
-  - model_name: openai/gpt-4o
-    litellm_params:
-      model: openai/gpt-4o
-      api_key: os.environ/OPENAI_API_KEY
-    model_info:
-      access_groups: ["private-openai-models"]
-
-router_settings:
-  # routing_strategy: usage-based-routing-v2
-  #redis_url: "os.environ/REDIS_URL"
-  redis_host: "os.environ/REDIS_HOST"
-  redis_port: "os.environ/REDIS_PORT"
+      guardrail: presidio
+      mode: "logging_only"
+      mock_redacted_text:
+        text: "hello world, my name is <PERSON>. My number is: <PHONE_NUMBER>"
+        items:
+          - start: 48
+            end: 62
+            entity_type: PHONE_NUMBER
+            text: "<PHONE_NUMBER>"
+            operator: replace
+          - start: 24
+            end: 32
+            entity_type: PERSON
+            text: "<PERSON>"
+            operator: replace
 litellm_settings:
-  callbacks: ["datadog"]
\ No newline at end of file
+  set_verbose: true
+  success_callback: ["langfuse"]
\ No newline at end of file
diff --git a/litellm/proxy/auth/handle_jwt.py b/litellm/proxy/auth/handle_jwt.py
index 49d6c3520e..bfcde8fb1e 100644
--- a/litellm/proxy/auth/handle_jwt.py
+++ b/litellm/proxy/auth/handle_jwt.py
@@ -36,16 +36,19 @@ class JWTHandler:
         self,
     ) -> None:
         self.http_handler = HTTPHandler()
+        self.leeway = 0
 
     def update_environment(
         self,
         prisma_client: Optional[PrismaClient],
         user_api_key_cache: DualCache,
         litellm_jwtauth: LiteLLM_JWTAuth,
+        leeway: int = 0,
     ) -> None:
         self.prisma_client = prisma_client
         self.user_api_key_cache = user_api_key_cache
         self.litellm_jwtauth = litellm_jwtauth
+        self.leeway = leeway
 
     def is_jwt(self, token: str):
         parts = token.split(".")
@@ -271,6 +274,7 @@ class JWTHandler:
                 algorithms=algorithms,
                 options=decode_options,
                 audience=audience,
+                leeway=self.leeway,  # allow testing of expired tokens
             )
 
             return payload
diff --git a/litellm/proxy/auth/user_api_key_auth.py b/litellm/proxy/auth/user_api_key_auth.py
index e9f2fc6d0c..ae6d4774be 100644
--- a/litellm/proxy/auth/user_api_key_auth.py
+++ b/litellm/proxy/auth/user_api_key_auth.py
@@ -562,6 +562,7 @@ async def user_api_key_auth(  # noqa: PLR0915
                 user_id=user_id,
                 org_id=org_id,
                 parent_otel_span=parent_otel_span,
+                end_user_id=end_user_id,
             )
         #### ELSE ####
         ## CHECK PASS-THROUGH ENDPOINTS ##
diff --git a/litellm/proxy/common_utils/callback_utils.py b/litellm/proxy/common_utils/callback_utils.py
index 40f66e90b6..fa7208d3c0 100644
--- a/litellm/proxy/common_utils/callback_utils.py
+++ b/litellm/proxy/common_utils/callback_utils.py
@@ -48,7 +48,7 @@ def initialize_callbacks_on_proxy(  # noqa: PLR0915
             imported_list.append(open_telemetry_logger)
             setattr(proxy_server, "open_telemetry_logger", open_telemetry_logger)
         elif isinstance(callback, str) and callback == "presidio":
-            from litellm.proxy.hooks.presidio_pii_masking import (
+            from litellm.proxy.guardrails.guardrail_hooks.presidio import (
                 _OPTIONAL_PresidioPIIMasking,
             )
diff --git a/litellm/proxy/guardrails/guardrail_hooks/presidio.py b/litellm/proxy/guardrails/guardrail_hooks/presidio.py
index da53e4a8ae..384b2cb999 100644
--- a/litellm/proxy/guardrails/guardrail_hooks/presidio.py
+++ b/litellm/proxy/guardrails/guardrail_hooks/presidio.py
@@ -24,6 +24,7 @@ from litellm._logging import verbose_proxy_logger
 from litellm.caching.caching import DualCache
 from litellm.integrations.custom_guardrail import CustomGuardrail
 from litellm.proxy._types import UserAPIKeyAuth
+from litellm.types.guardrails import GuardrailEventHooks
 from litellm.utils import (
     EmbeddingResponse,
     ImageResponse,
@@ -54,8 +55,13 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail):
         presidio_anonymizer_api_base: Optional[str] = None,
         output_parse_pii: Optional[bool] = False,
         presidio_ad_hoc_recognizers: Optional[str] = None,
+        logging_only: Optional[bool] = None,
         **kwargs,
     ):
+        if logging_only is True:
+            self.logging_only = True
+            kwargs["event_hook"] = GuardrailEventHooks.logging_only
+        super().__init__(**kwargs)
         self.pii_tokens: dict = (
             {}
         )  # mapping of PII token to original text - only used with Presidio `replace` operation
@@ -84,8 +90,6 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail):
             presidio_anonymizer_api_base=presidio_anonymizer_api_base,
         )
 
-        super().__init__(**kwargs)
-
     def validate_environment(
         self,
         presidio_analyzer_api_base: Optional[str] = None,
@@ -245,10 +249,44 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail):
             verbose_proxy_logger.info(
                 f"Presidio PII Masking: Redacted pii message: {data['messages']}"
             )
+            data["messages"] = messages
             return data
         except Exception as e:
             raise e
 
+    def logging_hook(
+        self, kwargs: dict, result: Any, call_type: str
+    ) -> Tuple[dict, Any]:
+        import threading
+        from concurrent.futures import ThreadPoolExecutor
+
+        def run_in_new_loop():
+            """Run the coroutine in a new event loop within this thread."""
+            new_loop = asyncio.new_event_loop()
+            try:
+                asyncio.set_event_loop(new_loop)
+                return new_loop.run_until_complete(
+                    self.async_logging_hook(
+                        kwargs=kwargs, result=result, call_type=call_type
+                    )
+                )
+            finally:
+                new_loop.close()
+                asyncio.set_event_loop(None)
+
+        try:
+            # First, try to get the current event loop
+            _ = asyncio.get_running_loop()
+            # If we're already in an event loop, run in a separate thread
+            # to avoid nested event loop issues
+            with ThreadPoolExecutor(max_workers=1) as executor:
+                future = executor.submit(run_in_new_loop)
+                return future.result()
+
+        except RuntimeError:
+            # No running event loop, we can safely run in this thread
+            return run_in_new_loop()
+
     async def async_logging_hook(
         self, kwargs: dict, result: Any, call_type: str
     ) -> Tuple[dict, Any]:
@@ -304,7 +342,8 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail):
         verbose_proxy_logger.debug(
             f"PII Masking Args: self.output_parse_pii={self.output_parse_pii}; type of response={type(response)}"
         )
-        if self.output_parse_pii is False:
+
+        if self.output_parse_pii is False and litellm.output_parse_pii is False:
             return response
 
         if isinstance(response, ModelResponse) and not isinstance(
diff --git a/litellm/proxy/hooks/presidio_pii_masking.py b/litellm/proxy/hooks/presidio_pii_masking.py
deleted file mode 100644
index 603e075620..0000000000
--- a/litellm/proxy/hooks/presidio_pii_masking.py
+++ /dev/null
@@ -1,349 +0,0 @@
-# +-----------------------------------------------+
-# |                                               |
-# |           PII Masking                         |
-# |      with Microsoft Presidio                  |
-# |   https://github.com/BerriAI/litellm/issues/  |
-# +-----------------------------------------------+
-#
-#  Tell us how we can improve!
- Krrish & Ishaan - - -import asyncio -import json -import traceback -import uuid -from typing import Any, List, Optional, Tuple, Union - -import aiohttp -from fastapi import HTTPException - -import litellm # noqa: E401 -from litellm._logging import verbose_proxy_logger -from litellm.caching.caching import DualCache -from litellm.integrations.custom_logger import CustomLogger -from litellm.proxy._types import UserAPIKeyAuth -from litellm.utils import ( - EmbeddingResponse, - ImageResponse, - ModelResponse, - StreamingChoices, - get_formatted_prompt, -) - - -class _OPTIONAL_PresidioPIIMasking(CustomLogger): - user_api_key_cache = None - ad_hoc_recognizers = None - - # Class variables or attributes - def __init__( - self, - logging_only: Optional[bool] = None, - mock_testing: bool = False, - mock_redacted_text: Optional[dict] = None, - ): - self.pii_tokens: dict = ( - {} - ) # mapping of PII token to original text - only used with Presidio `replace` operation - - self.mock_redacted_text = mock_redacted_text - self.logging_only = logging_only - if mock_testing is True: # for testing purposes only - return - - ad_hoc_recognizers = litellm.presidio_ad_hoc_recognizers - if ad_hoc_recognizers is not None: - try: - with open(ad_hoc_recognizers, "r") as file: - self.ad_hoc_recognizers = json.load(file) - except FileNotFoundError: - raise Exception(f"File not found. file_path={ad_hoc_recognizers}") - except json.JSONDecodeError as e: - raise Exception( - f"Error decoding JSON file: {str(e)}, file_path={ad_hoc_recognizers}" - ) - except Exception as e: - raise Exception( - f"An error occurred: {str(e)}, file_path={ad_hoc_recognizers}" - ) - - self.validate_environment() - - def validate_environment(self): - self.presidio_analyzer_api_base: Optional[str] = litellm.get_secret( - "PRESIDIO_ANALYZER_API_BASE", None - ) # type: ignore - self.presidio_anonymizer_api_base: Optional[str] = litellm.get_secret( - "PRESIDIO_ANONYMIZER_API_BASE", None - ) # type: ignore - - if self.presidio_analyzer_api_base is None: - raise Exception("Missing `PRESIDIO_ANALYZER_API_BASE` from environment") - if not self.presidio_analyzer_api_base.endswith("/"): - self.presidio_analyzer_api_base += "/" - if not ( - self.presidio_analyzer_api_base.startswith("http://") - or self.presidio_analyzer_api_base.startswith("https://") - ): - # add http:// if unset, assume communicating over private network - e.g. render - self.presidio_analyzer_api_base = ( - "http://" + self.presidio_analyzer_api_base - ) - - if self.presidio_anonymizer_api_base is None: - raise Exception("Missing `PRESIDIO_ANONYMIZER_API_BASE` from environment") - if not self.presidio_anonymizer_api_base.endswith("/"): - self.presidio_anonymizer_api_base += "/" - if not ( - self.presidio_anonymizer_api_base.startswith("http://") - or self.presidio_anonymizer_api_base.startswith("https://") - ): - # add http:// if unset, assume communicating over private network - e.g. 
render - self.presidio_anonymizer_api_base = ( - "http://" + self.presidio_anonymizer_api_base - ) - - def print_verbose(self, print_statement): - try: - verbose_proxy_logger.debug(print_statement) - if litellm.set_verbose: - print(print_statement) # noqa - except Exception: - pass - - async def check_pii(self, text: str, output_parse_pii: bool) -> str: - """ - [TODO] make this more performant for high-throughput scenario - """ - try: - async with aiohttp.ClientSession() as session: - if self.mock_redacted_text is not None: - redacted_text = self.mock_redacted_text - else: - # Make the first request to /analyze - analyze_url = f"{self.presidio_analyzer_api_base}analyze" - verbose_proxy_logger.debug("Making request to: %s", analyze_url) - analyze_payload = {"text": text, "language": "en"} - if self.ad_hoc_recognizers is not None: - analyze_payload["ad_hoc_recognizers"] = self.ad_hoc_recognizers - redacted_text = None - async with session.post( - analyze_url, json=analyze_payload - ) as response: - analyze_results = await response.json() - - # Make the second request to /anonymize - anonymize_url = f"{self.presidio_anonymizer_api_base}anonymize" - verbose_proxy_logger.debug("Making request to: %s", anonymize_url) - anonymize_payload = { - "text": text, - "analyzer_results": analyze_results, - } - - async with session.post( - anonymize_url, json=anonymize_payload - ) as response: - redacted_text = await response.json() - - new_text = text - if redacted_text is not None: - verbose_proxy_logger.debug("redacted_text: %s", redacted_text) - for item in redacted_text["items"]: - start = item["start"] - end = item["end"] - replacement = item["text"] # replacement token - if item["operator"] == "replace" and output_parse_pii is True: - # check if token in dict - # if exists, add a uuid to the replacement token for swapping back to the original text in llm response output parsing - if replacement in self.pii_tokens: - replacement = replacement + str(uuid.uuid4()) - - self.pii_tokens[replacement] = new_text[ - start:end - ] # get text it'll replace - - new_text = new_text[:start] + replacement + new_text[end:] - return redacted_text["text"] - else: - raise Exception(f"Invalid anonymizer response: {redacted_text}") - except Exception as e: - verbose_proxy_logger.error( - "litellm.proxy.hooks.presidio_pii_masking.py::async_pre_call_hook(): Exception occured - {}".format( - str(e) - ) - ) - verbose_proxy_logger.debug(traceback.format_exc()) - raise e - - async def async_pre_call_hook( - self, - user_api_key_dict: UserAPIKeyAuth, - cache: DualCache, - data: dict, - call_type: str, - ): - """ - - Check if request turned off pii - - Check if user allowed to turn off pii (key permissions -> 'allow_pii_controls') - - - Take the request data - - Call /analyze -> get the results - - Call /anonymize w/ the analyze results -> get the redacted text - - For multiple messages in /chat/completions, we'll need to call them in parallel. 
- """ - try: - if ( - self.logging_only is True - ): # only modify the logging obj data (done by async_logging_hook) - return data - permissions = user_api_key_dict.permissions - output_parse_pii = permissions.get( - "output_parse_pii", litellm.output_parse_pii - ) # allow key to turn on/off output parsing for pii - no_pii = permissions.get( - "no-pii", None - ) # allow key to turn on/off pii masking (if user is allowed to set pii controls, then they can override the key defaults) - - if no_pii is None: - # check older way of turning on/off pii - no_pii = not permissions.get("pii", True) - - content_safety = data.get("content_safety", None) - verbose_proxy_logger.debug("content_safety: %s", content_safety) - ## Request-level turn on/off PII controls ## - if content_safety is not None and isinstance(content_safety, dict): - # pii masking ## - if ( - content_safety.get("no-pii", None) is not None - and content_safety.get("no-pii") is True - ): - # check if user allowed to turn this off - if permissions.get("allow_pii_controls", False) is False: - raise HTTPException( - status_code=400, - detail={ - "error": "Not allowed to set PII controls per request" - }, - ) - else: # user allowed to turn off pii masking - no_pii = content_safety.get("no-pii") - if not isinstance(no_pii, bool): - raise HTTPException( - status_code=400, - detail={"error": "no_pii needs to be a boolean value"}, - ) - ## pii output parsing ## - if content_safety.get("output_parse_pii", None) is not None: - # check if user allowed to turn this off - if permissions.get("allow_pii_controls", False) is False: - raise HTTPException( - status_code=400, - detail={ - "error": "Not allowed to set PII controls per request" - }, - ) - else: # user allowed to turn on/off pii output parsing - output_parse_pii = content_safety.get("output_parse_pii") - if not isinstance(output_parse_pii, bool): - raise HTTPException( - status_code=400, - detail={ - "error": "output_parse_pii needs to be a boolean value" - }, - ) - - if no_pii is True: # turn off pii masking - return data - - if call_type == "completion": # /chat/completions requests - messages = data["messages"] - tasks = [] - - for m in messages: - if isinstance(m["content"], str): - tasks.append( - self.check_pii( - text=m["content"], output_parse_pii=output_parse_pii - ) - ) - responses = await asyncio.gather(*tasks) - for index, r in enumerate(responses): - if isinstance(messages[index]["content"], str): - messages[index][ - "content" - ] = r # replace content with redacted string - verbose_proxy_logger.info( - f"Presidio PII Masking: Redacted pii message: {data['messages']}" - ) - return data - except Exception as e: - verbose_proxy_logger.info( - "An error occurred -", - ) - raise e - - async def async_logging_hook( - self, kwargs: dict, result: Any, call_type: str - ) -> Tuple[dict, Any]: - """ - Masks the input before logging to langfuse, datadog, etc. 
- """ - if ( - call_type == "completion" or call_type == "acompletion" - ): # /chat/completions requests - messages: Optional[List] = kwargs.get("messages", None) - tasks = [] - - if messages is None: - return kwargs, result - - for m in messages: - text_str = "" - if m["content"] is None: - continue - if isinstance(m["content"], str): - text_str = m["content"] - tasks.append( - self.check_pii(text=text_str, output_parse_pii=False) - ) # need to pass separately b/c presidio has context window limits - responses = await asyncio.gather(*tasks) - for index, r in enumerate(responses): - if isinstance(messages[index]["content"], str): - messages[index][ - "content" - ] = r # replace content with redacted string - verbose_proxy_logger.info( - f"Presidio PII Masking: Redacted pii message: {messages}" - ) - kwargs["messages"] = messages - - return kwargs, result - - async def async_post_call_success_hook( - self, - data: dict, - user_api_key_dict: UserAPIKeyAuth, - response: Union[ModelResponse, EmbeddingResponse, ImageResponse], - ): - """ - Output parse the response object to replace the masked tokens with user sent values - """ - verbose_proxy_logger.debug( - f"PII Masking Args: litellm.output_parse_pii={litellm.output_parse_pii}; type of response={type(response)}" - ) - if litellm.output_parse_pii is False: - return response - - if isinstance(response, ModelResponse) and not isinstance( - response.choices[0], StreamingChoices - ): # /chat/completions requests - if isinstance(response.choices[0].message.content, str): - verbose_proxy_logger.debug( - f"self.pii_tokens: {self.pii_tokens}; initial response: {response.choices[0].message.content}" - ) - for key, value in self.pii_tokens.items(): - response.choices[0].message.content = response.choices[ - 0 - ].message.content.replace(key, value) - return response diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index a72297ac3e..93df33d757 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -5987,6 +5987,39 @@ async def new_budget( return response +@router.post( + "/budget/update", + tags=["budget management"], + dependencies=[Depends(user_api_key_auth)], +) +async def update_budget( + budget_obj: BudgetNew, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + """ + Create a new budget object. Can apply this to teams, orgs, end-users, keys. 
+    """
+    global prisma_client
+
+    if prisma_client is None:
+        raise HTTPException(
+            status_code=500,
+            detail={"error": CommonProxyErrors.db_not_connected_error.value},
+        )
+    if budget_obj.budget_id is None:
+        raise HTTPException(status_code=400, detail={"error": "budget_id is required"})
+
+    response = await prisma_client.db.litellm_budgettable.update(
+        where={"budget_id": budget_obj.budget_id},
+        data={
+            **budget_obj.model_dump(exclude_none=True),  # type: ignore
+            "updated_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
+        },  # type: ignore
+    )
+
+    return response
+
+
 @router.post(
     "/budget/info",
     tags=["budget management"],
diff --git a/litellm/router.py b/litellm/router.py
index 3f21597945..34143bd500 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -83,6 +83,9 @@ from litellm.router_utils.fallback_event_handlers import (
     run_async_fallback,
     run_sync_fallback,
 )
+from litellm.router_utils.get_retry_from_policy import (
+    get_num_retries_from_retry_policy as _get_num_retries_from_retry_policy,
+)
 from litellm.router_utils.handle_error import (
     async_raise_no_deployment_exception,
     send_llm_exception_alert,
@@ -5609,53 +5612,12 @@ class Router:
     def get_num_retries_from_retry_policy(
         self, exception: Exception, model_group: Optional[str] = None
     ):
-        """
-        BadRequestErrorRetries: Optional[int] = None
-        AuthenticationErrorRetries: Optional[int] = None
-        TimeoutErrorRetries: Optional[int] = None
-        RateLimitErrorRetries: Optional[int] = None
-        ContentPolicyViolationErrorRetries: Optional[int] = None
-        """
-        # if we can find the exception then in the retry policy -> return the number of retries
-        retry_policy: Optional[RetryPolicy] = self.retry_policy
-
-        if (
-            self.model_group_retry_policy is not None
-            and model_group is not None
-            and model_group in self.model_group_retry_policy
-        ):
-            retry_policy = self.model_group_retry_policy.get(model_group, None)  # type: ignore
-
-        if retry_policy is None:
-            return None
-        if isinstance(retry_policy, dict):
-            retry_policy = RetryPolicy(**retry_policy)
-
-        if (
-            isinstance(exception, litellm.BadRequestError)
-            and retry_policy.BadRequestErrorRetries is not None
-        ):
-            return retry_policy.BadRequestErrorRetries
-        if (
-            isinstance(exception, litellm.AuthenticationError)
-            and retry_policy.AuthenticationErrorRetries is not None
-        ):
-            return retry_policy.AuthenticationErrorRetries
-        if (
-            isinstance(exception, litellm.Timeout)
-            and retry_policy.TimeoutErrorRetries is not None
-        ):
-            return retry_policy.TimeoutErrorRetries
-        if (
-            isinstance(exception, litellm.RateLimitError)
-            and retry_policy.RateLimitErrorRetries is not None
-        ):
-            return retry_policy.RateLimitErrorRetries
-        if (
-            isinstance(exception, litellm.ContentPolicyViolationError)
-            and retry_policy.ContentPolicyViolationErrorRetries is not None
-        ):
-            return retry_policy.ContentPolicyViolationErrorRetries
+        return _get_num_retries_from_retry_policy(
+            exception=exception,
+            model_group=model_group,
+            model_group_retry_policy=self.model_group_retry_policy,
+            retry_policy=self.retry_policy,
+        )
 
     def get_allowed_fails_from_policy(self, exception: Exception):
         """
diff --git a/litellm/router_utils/get_retry_from_policy.py b/litellm/router_utils/get_retry_from_policy.py
new file mode 100644
index 0000000000..48df43ef81
--- /dev/null
+++ b/litellm/router_utils/get_retry_from_policy.py
@@ -0,0 +1,71 @@
+"""
+Get num retries for an exception.
+
+- Account for retry policy by exception type.
+"""
+
+from typing import Dict, Optional, Union
+
+from litellm.exceptions import (
+    AuthenticationError,
+    BadRequestError,
+    ContentPolicyViolationError,
+    RateLimitError,
+    Timeout,
+)
+from litellm.types.router import RetryPolicy
+
+
+def get_num_retries_from_retry_policy(
+    exception: Exception,
+    retry_policy: Optional[Union[RetryPolicy, dict]] = None,
+    model_group: Optional[str] = None,
+    model_group_retry_policy: Optional[Dict[str, RetryPolicy]] = None,
+):
+    """
+    BadRequestErrorRetries: Optional[int] = None
+    AuthenticationErrorRetries: Optional[int] = None
+    TimeoutErrorRetries: Optional[int] = None
+    RateLimitErrorRetries: Optional[int] = None
+    ContentPolicyViolationErrorRetries: Optional[int] = None
+    """
+    # if we can find the exception then in the retry policy -> return the number of retries
+
+    if (
+        model_group_retry_policy is not None
+        and model_group is not None
+        and model_group in model_group_retry_policy
+    ):
+        retry_policy = model_group_retry_policy.get(model_group, None)  # type: ignore
+
+    if retry_policy is None:
+        return None
+    if isinstance(retry_policy, dict):
+        retry_policy = RetryPolicy(**retry_policy)
+
+    if (
+        isinstance(exception, BadRequestError)
+        and retry_policy.BadRequestErrorRetries is not None
+    ):
+        return retry_policy.BadRequestErrorRetries
+    if (
+        isinstance(exception, AuthenticationError)
+        and retry_policy.AuthenticationErrorRetries is not None
+    ):
+        return retry_policy.AuthenticationErrorRetries
+    if isinstance(exception, Timeout) and retry_policy.TimeoutErrorRetries is not None:
+        return retry_policy.TimeoutErrorRetries
+    if (
+        isinstance(exception, RateLimitError)
+        and retry_policy.RateLimitErrorRetries is not None
+    ):
+        return retry_policy.RateLimitErrorRetries
+    if (
+        isinstance(exception, ContentPolicyViolationError)
+        and retry_policy.ContentPolicyViolationErrorRetries is not None
+    ):
+        return retry_policy.ContentPolicyViolationErrorRetries
+
+
+def reset_retry_policy() -> RetryPolicy:
+    return RetryPolicy()
diff --git a/litellm/utils.py b/litellm/utils.py
index fe43421732..f11368fc41 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -97,6 +97,10 @@ from litellm.litellm_core_utils.rules import Rules
 from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper
 from litellm.litellm_core_utils.token_counter import get_modified_max_tokens
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+from litellm.router_utils.get_retry_from_policy import (
+    get_num_retries_from_retry_policy,
+    reset_retry_policy,
+)
 from litellm.secret_managers.main import get_secret
 from litellm.types.llms.openai import (
     AllMessageValues,
@@ -918,6 +922,14 @@ def client(original_function):  # noqa: PLR0915
                 num_retries = (
                     kwargs.get("num_retries", None) or litellm.num_retries or None
                 )
+                if kwargs.get("retry_policy", None):
+                    num_retries = get_num_retries_from_retry_policy(
+                        exception=e,
+                        retry_policy=kwargs.get("retry_policy"),
+                    )
+                    kwargs["retry_policy"] = (
+                        reset_retry_policy()
+                    )  # prevent infinite loops
                 litellm.num_retries = (
                     None  # set retries to None to prevent infinite loops
                 )
@@ -1137,6 +1149,13 @@ def client(original_function):  # noqa: PLR0915
                 num_retries = (
                     kwargs.get("num_retries", None) or litellm.num_retries or None
                 )
+                if kwargs.get("retry_policy", None):
+                    num_retries = get_num_retries_from_retry_policy(
+                        exception=e,
+                        retry_policy=kwargs.get("retry_policy"),
+                    )
+                    kwargs["retry_policy"] = reset_retry_policy()
+
                 litellm.num_retries = (
                     None  # set retries to None to prevent infinite loops
                )
diff --git a/tests/local_testing/test_completion_with_retries.py b/tests/local_testing/test_completion_with_retries.py
index e59d1d6e13..efb66c40c6 100644
--- a/tests/local_testing/test_completion_with_retries.py
+++ b/tests/local_testing/test_completion_with_retries.py
@@ -61,3 +61,68 @@ def test_completion_with_0_num_retries():
     except Exception as e:
         print("exception", e)
         pass
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("sync_mode", [True, False])
+async def test_completion_with_retry_policy(sync_mode):
+    from unittest.mock import patch, MagicMock, AsyncMock
+    from litellm.types.router import RetryPolicy
+
+    retry_number = 1
+    retry_policy = RetryPolicy(
+        ContentPolicyViolationErrorRetries=retry_number,  # run 1 retry for ContentPolicyViolationErrors
+        AuthenticationErrorRetries=0,  # run 0 retries for AuthenticationErrors
+    )
+
+    target_function = "completion_with_retries"
+
+    with patch.object(litellm, target_function) as mock_completion_with_retries:
+        data = {
+            "model": "azure/gpt-3.5-turbo",
+            "messages": [{"gm": "vibe", "role": "user"}],
+            "retry_policy": retry_policy,
+            "mock_response": "Exception: content_filter_policy",
+        }
+        try:
+            if sync_mode:
+                completion(**data)
+            else:
+                await completion(**data)
+        except Exception as e:
+            print(e)
+
+        mock_completion_with_retries.assert_called_once()
+        assert (
+            mock_completion_with_retries.call_args.kwargs["num_retries"] == retry_number
+        )
+        assert retry_policy.ContentPolicyViolationErrorRetries == retry_number
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("sync_mode", [True, False])
+async def test_completion_with_retry_policy_no_error(sync_mode):
+    """
+    Test that the completion function does not throw an error when the retry policy is set
+    """
+    from unittest.mock import patch, MagicMock, AsyncMock
+    from litellm.types.router import RetryPolicy
+
+    retry_number = 1
+    retry_policy = RetryPolicy(
+        ContentPolicyViolationErrorRetries=retry_number,  # run 1 retry for ContentPolicyViolationErrors
+        AuthenticationErrorRetries=0,  # run 0 retries for AuthenticationErrors
+    )
+
+    data = {
+        "model": "gpt-3.5-turbo",
+        "messages": [{"gm": "vibe", "role": "user"}],
+        "retry_policy": retry_policy,
+    }
+    try:
+        if sync_mode:
+            completion(**data)
+        else:
+            await completion(**data)
+    except Exception as e:
+        print(e)
diff --git a/tests/local_testing/test_presidio_masking.py b/tests/local_testing/test_presidio_masking.py
index 0f96da3348..c9d1adc9b2 100644
--- a/tests/local_testing/test_presidio_masking.py
+++ b/tests/local_testing/test_presidio_masking.py
@@ -19,12 +19,13 @@ sys.path.insert(
 from unittest.mock import AsyncMock, MagicMock, patch
 
 import pytest
-
 import litellm
 from litellm import Router, mock_completion
 from litellm.caching.caching import DualCache
 from litellm.proxy._types import UserAPIKeyAuth
-from litellm.proxy.hooks.presidio_pii_masking import _OPTIONAL_PresidioPIIMasking
+from litellm.proxy.guardrails.guardrail_hooks.presidio import (
+    _OPTIONAL_PresidioPIIMasking,
+)
 from litellm.proxy.utils import ProxyLogging
@@ -63,6 +64,7 @@ async def test_output_parsing():
     - have presidio pii masking - output parse message
     - assert that no masked tokens are in the input message
     """
+    litellm.set_verbose = True
     litellm.output_parse_pii = True
     pii_masking = _OPTIONAL_PresidioPIIMasking(mock_testing=True)
 
@@ -206,6 +208,8 @@ async def test_presidio_pii_masking_input_b():
 
 @pytest.mark.asyncio
 async def test_presidio_pii_masking_logging_output_only_no_pre_api_hook():
+    from litellm.types.guardrails import GuardrailEventHooks
+
     pii_masking = _OPTIONAL_PresidioPIIMasking(
         logging_only=True,
         mock_testing=True,
@@ -223,22 +227,29 @@ async def test_presidio_pii_masking_logging_output_only_no_pre_api_hook():
         }
     ]
 
-    new_data = await pii_masking.async_pre_call_hook(
-        user_api_key_dict=user_api_key_dict,
-        cache=local_cache,
-        data={"messages": test_messages},
-        call_type="completion",
+    assert (
+        pii_masking.should_run_guardrail(
+            data={"messages": test_messages},
+            event_type=GuardrailEventHooks.pre_call,
+        )
+        is False
     )
 
-    assert "Jane Doe" in new_data["messages"][0]["content"]
-
 
+@pytest.mark.parametrize(
+    "sync_mode",
+    [True, False],
+)
 @pytest.mark.asyncio
-async def test_presidio_pii_masking_logging_output_only_logged_response():
+async def test_presidio_pii_masking_logging_output_only_logged_response(sync_mode):
+    import litellm
+
+    litellm.set_verbose = True
     pii_masking = _OPTIONAL_PresidioPIIMasking(
         logging_only=True,
         mock_testing=True,
         mock_redacted_text=input_b_anonymizer_results,
+        guardrail_name="presidio",
     )
 
     test_messages = [
@@ -247,15 +258,26 @@ async def test_presidio_pii_masking_logging_output_only_logged_response():
             "content": "My name is Jane Doe, who are you? Say my name in your response",
         }
     ]
-    with patch.object(
-        pii_masking, "async_log_success_event", new=AsyncMock()
-    ) as mock_call:
-        litellm.callbacks = [pii_masking]
-        response = await litellm.acompletion(
-            model="gpt-3.5-turbo", messages=test_messages, mock_response="Hi Peter!"
-        )
-        await asyncio.sleep(3)
+    if sync_mode:
+        target_function = "log_success_event"
+        mock_call = MagicMock()
+    else:
+        target_function = "async_log_success_event"
+        mock_call = AsyncMock()
+
+    with patch.object(pii_masking, target_function, new=mock_call) as mock_call:
+        litellm.callbacks = [pii_masking]
+        if sync_mode:
+            response = litellm.completion(
+                model="gpt-3.5-turbo", messages=test_messages, mock_response="Hi Peter!"
+            )
+            time.sleep(3)
+        else:
+            response = await litellm.acompletion(
+                model="gpt-3.5-turbo", messages=test_messages, mock_response="Hi Peter!"
+            )
+            await asyncio.sleep(3)
 
     assert response.choices[0].message.content == "Hi Peter!"
# type: ignore @@ -275,8 +297,13 @@ async def test_presidio_pii_masking_logging_output_only_logged_response_guardrai import litellm from litellm.proxy.guardrails.init_guardrails import initialize_guardrails - from litellm.types.guardrails import GuardrailItem, GuardrailItemSpec + from litellm.types.guardrails import ( + GuardrailItem, + GuardrailItemSpec, + GuardrailEventHooks, + ) + litellm.set_verbose = True os.environ["PRESIDIO_ANALYZER_API_BASE"] = "http://localhost:5002" os.environ["PRESIDIO_ANONYMIZER_API_BASE"] = "http://localhost:5001" @@ -303,10 +330,15 @@ async def test_presidio_pii_masking_logging_output_only_logged_response_guardrai pii_masking_obj: Optional[_OPTIONAL_PresidioPIIMasking] = None for callback in litellm.callbacks: + print(f"CALLBACK: {callback}") if isinstance(callback, _OPTIONAL_PresidioPIIMasking): pii_masking_obj = callback assert pii_masking_obj is not None assert hasattr(pii_masking_obj, "logging_only") - assert pii_masking_obj.logging_only is True + assert pii_masking_obj.event_hook == GuardrailEventHooks.logging_only + + assert pii_masking_obj.should_run_guardrail( + data={}, event_type=GuardrailEventHooks.logging_only + ) diff --git a/tests/proxy_unit_tests/test_jwt.py b/tests/proxy_unit_tests/test_jwt.py index c07394962f..c68878dcbe 100644 --- a/tests/proxy_unit_tests/test_jwt.py +++ b/tests/proxy_unit_tests/test_jwt.py @@ -1026,3 +1026,89 @@ def test_get_public_key_from_jwk_url(): assert public_key is not None assert public_key == jwk_response[0] + + +@pytest.mark.asyncio +async def test_end_user_jwt_auth(monkeypatch): + import litellm + from litellm.proxy.auth.handle_jwt import JWTHandler + from litellm.caching import DualCache + from litellm.proxy._types import LiteLLM_JWTAuth + from litellm.proxy.proxy_server import user_api_key_auth + + monkeypatch.delenv("JWT_AUDIENCE", None) + jwt_handler = JWTHandler() + + litellm_jwtauth = LiteLLM_JWTAuth( + end_user_id_jwt_field="sub", + ) + + cache = DualCache() + + keys = [ + { + "kid": "d-1733370597545", + "alg": "RS256", + "kty": "RSA", + "use": "sig", + "n": "j5Ik60FJSUIPMVdMSU8vhYvyPPH7rUsUllNI0BfBlIkgEYFk2mg4KX1XDQc6mcKFjbq9k_7TSkHWKnfPhNkkb0MdmZLKbwTrmte2k8xWDxp-rSmZpIJwC1zuPDx5joLHBgIb09-K2cPL2rvwzP75WtOr_QLXBerHAbXx8cOdI7mrSRWJ9iXbKv_pLDnZHnGNld75tztb8nCtgrywkF010jGi1xxaT8UKsTvK-QkIBkYI6m6WR9LMnG2OZm-ExuvNPUenfYUsqnynPF4SMNZwyQqJfavSLKI8uMzB2s9pcbx5HfQwIOYhMlgBHjhdDn2IUSnXSJqSsN6RQO18M2rgPQ", + "e": "AQAB", + }, + { + "kid": "s-f836dd32-ef71-426a-8804-946a7f230bc9", + "alg": "RS256", + "kty": "RSA", + "use": "sig", + "n": "2A5-ZA18YKn7M4OtxsfXBc3Z7n2WyHTxbK4GEBlmD9T9TDr4sbJaI4oHfTvzsAC3H2r2YkASzrCISXMXQJjLHoeLgDVcKs8qTdLj7K5FNT9fA0kU9ayUjSGrqkz57SG7oNf9Wp__Qa-H-bs6Z8_CEfBy0JA9QSHUfrdOXp4vCB_qLn6DE0DJH9ELAq_0nktVQk_oxlvXlGtVZSZe31mNNgiD__RJMogf-SIFcYOkMLVGTTEBYiCk1mHxXS6oJZaVSWiBgHzu5wkra5AfQLUVelQaupT5H81hFPmiceEApf_2DacnqqRV4-Nl8sjhJtuTXiprVS2Z5r2pOMz_kVGNgw", + "e": "AQAB", + }, + ] + + cache.set_cache( + key="litellm_jwt_auth_keys", + value=keys, + ) + + jwt_handler.update_environment( + prisma_client=None, + user_api_key_cache=cache, + litellm_jwtauth=litellm_jwtauth, + leeway=100000000000000, + ) + + token = 
"eyJraWQiOiJkLTE3MzMzNzA1OTc1NDUiLCJ0eXAiOiJKV1QiLCJ2ZXJzaW9uIjoiNCIsImFsZyI6IlJTMjU2In0.eyJpYXQiOjE3MzM1MDcyNzcsImV4cCI6MTczMzUwNzg3Nywic3ViIjoiODFiM2U1MmEtNjdhNi00ZWZiLTk2NDUtNzA1MjdlMTAxNDc5IiwidElkIjoicHVibGljIiwic2Vzc2lvbkhhbmRsZSI6Ijg4M2Y4YWFmLWUwOTEtNGE1Ny04YTJhLTRiMjcwMmZhZjMzYyIsInJlZnJlc2hUb2tlbkhhc2gxIjoiNDVhNDRhYjlmOTMwMGQyOTY4ZjkxODZkYWQ0YTcwY2QwNjk2YzBiNTBmZmUxZmQ4ZTM2YzU1NGU0MWE4ODU0YiIsInBhcmVudFJlZnJlc2hUb2tlbkhhc2gxIjpudWxsLCJhbnRpQ3NyZlRva2VuIjpudWxsLCJpc3MiOiJodHRwOi8vbG9jYWxob3N0OjMwMDEvYXV0aC9zdCIsImxlZ2FjeV9jb21wYW55X2lkIjoxNTI0OTM5LCJsZWdhY3lfaWQiOjM5NzAyNzk1LCJzY29wZSI6WyJza2lsbF91cCJdLCJzdC1ldiI6eyJ0IjoxNzMzNTA3Mjc4NzAwLCJ2IjpmYWxzZX19.XlYrT6dRIjaZKkJtdr7C_UuxajFRbNpA9BnIsny3rxiPVyS8rhIBwxW12tZwgttRywmXrXK-msowFhWU4XdL5Qfe4lwZb2HTbDeGiQPvQTlOjWWYMhgCoKdPtjCQsAcW45rg7aQ0p42JFQPoAQa8AnGfxXpgx2vSR7njiZ3ZZyHerDdKQHyIGSFVOxoK0TgR-hxBVY__Wjg8UTKgKSz9KU_uwnPgpe2DeYmP-LTK2oeoygsVRmbldY_GrrcRe3nqYcUfFkxSs0FSsoSv35jIxiptXfCjhEB1Y5eaJhHEjlYlP2rw98JysYxjO2rZbAdUpL3itPeo3T2uh1NZr_lArw" + + response = await jwt_handler.auth_jwt(token=token) + + assert response is not None + + end_user_id = jwt_handler.get_end_user_id( + token=response, + default_value=None, + ) + + assert end_user_id is not None + + ## CHECK USER API KEY AUTH ## + from starlette.datastructures import URL + + bearer_token = "Bearer " + token + + request = Request(scope={"type": "http"}) + request._url = URL(url="/chat/completions") + + ## 1. INITIAL TEAM CALL - should fail + # use generated key to auth in + setattr( + litellm.proxy.proxy_server, + "general_settings", + { + "enable_jwt_auth": True, + }, + ) + setattr(litellm.proxy.proxy_server, "prisma_client", {}) + setattr(litellm.proxy.proxy_server, "jwt_handler", jwt_handler) + result = await user_api_key_auth(request=request, api_key=bearer_token) + assert ( + result.end_user_id == "81b3e52a-67a6-4efb-9645-70527e101479" + ) # jwt token decoded sub value diff --git a/ui/litellm-dashboard/src/components/budgets/budget_panel.tsx b/ui/litellm-dashboard/src/components/budgets/budget_panel.tsx index edad680b28..604134575d 100644 --- a/ui/litellm-dashboard/src/components/budgets/budget_panel.tsx +++ b/ui/litellm-dashboard/src/components/budgets/budget_panel.tsx @@ -49,6 +49,7 @@ export interface budgetItem { max_budget: string | null; rpm_limit: number | null; tpm_limit: number | null; + updated_at: string; } const BudgetPanel: React.FC = ({ accessToken }) => { @@ -65,13 +66,17 @@ const BudgetPanel: React.FC = ({ accessToken }) => { }); }, [accessToken]); - const handleEditCall = async (budget_id: string, index: number) => { + console.log("budget_id", budget_id) if (accessToken == null) { return; } - setSelectedBudget(budgetList[index]) - setIsEditModalVisible(true) + // Find the budget first + const budget = budgetList.find(budget => budget.budget_id === budget_id) || null; + + // Update state and show modal after state is updated + setSelectedBudget(budget); + setIsEditModalVisible(true); }; const handleDeleteCall = async (budget_id: string, index: number) => { @@ -90,6 +95,15 @@ const BudgetPanel: React.FC = ({ accessToken }) => { message.success("Budget Deleted."); }; + const handleUpdateCall = async () => { + if (accessToken == null) { + return; + } + getBudgetList(accessToken).then((data) => { + setBudgetList(data); + }); + } + return (