diff --git a/litellm/integrations/custom_guardrail.py b/litellm/integrations/custom_guardrail.py
index 3053a4ad1f..816b024c72 100644
--- a/litellm/integrations/custom_guardrail.py
+++ b/litellm/integrations/custom_guardrail.py
@@ -29,6 +29,7 @@ class CustomGuardrail(CustomLogger):
def should_run_guardrail(self, data, event_type: GuardrailEventHooks) -> bool:
metadata = data.get("metadata") or {}
requested_guardrails = metadata.get("guardrails") or []
+
verbose_logger.debug(
"inside should_run_guardrail for guardrail=%s event_type= %s guardrail_supported_event_hooks= %s requested_guardrails= %s",
self.guardrail_name,
diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py
index fed1cc2863..e2fad98387 100644
--- a/litellm/litellm_core_utils/litellm_logging.py
+++ b/litellm/litellm_core_utils/litellm_logging.py
@@ -969,10 +969,10 @@ class Logging:
),
result=result,
)
-
## LOGGING HOOK ##
for callback in callbacks:
if isinstance(callback, CustomLogger):
+
self.model_call_details, result = callback.logging_hook(
kwargs=self.model_call_details,
result=result,
diff --git a/litellm/proxy/_experimental/out/404.html b/litellm/proxy/_experimental/out/404.html
deleted file mode 100644
index 6656e5452d..0000000000
--- a/litellm/proxy/_experimental/out/404.html
+++ /dev/null
@@ -1 +0,0 @@
-404: This page could not be found.LiteLLM Dashboard404This page could not be found.
\ No newline at end of file
diff --git a/litellm/proxy/_experimental/out/model_hub.html b/litellm/proxy/_experimental/out/model_hub.html
deleted file mode 100644
index ffebd4c6e2..0000000000
--- a/litellm/proxy/_experimental/out/model_hub.html
+++ /dev/null
@@ -1 +0,0 @@
-LiteLLM Dashboard
\ No newline at end of file
diff --git a/litellm/proxy/_experimental/out/onboarding.html b/litellm/proxy/_experimental/out/onboarding.html
deleted file mode 100644
index f59bdc9d5d..0000000000
--- a/litellm/proxy/_experimental/out/onboarding.html
+++ /dev/null
@@ -1 +0,0 @@
-LiteLLM Dashboard
\ No newline at end of file
diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml
index 599a1bf230..84075f53e0 100644
--- a/litellm/proxy/_new_secret_config.yaml
+++ b/litellm/proxy/_new_secret_config.yaml
@@ -1,48 +1,30 @@
model_list:
- # GPT-4 Turbo Models
- - model_name: gpt-4
- litellm_params:
- model: gpt-4
- rpm: 1
- - model_name: gpt-4
+ - model_name: gpt-3.5-turbo
litellm_params:
model: azure/chatgpt-v-2
api_key: os.environ/AZURE_API_KEY
api_base: os.environ/AZURE_API_BASE
- - model_name: rerank-model
+ temperature: 0.2
+
+guardrails:
+ - guardrail_name: "presidio-log-guard"
litellm_params:
- model: jina_ai/jina-reranker-v2-base-multilingual
- - model_name: anthropic-vertex
- litellm_params:
- model: vertex_ai/claude-3-5-sonnet-v2
- vertex_ai_project: "adroit-crow-413218"
- vertex_ai_location: "us-east5"
- - model_name: openai-gpt-4o-realtime-audio
- litellm_params:
- model: openai/gpt-4o-realtime-preview-2024-10-01
- api_key: os.environ/OPENAI_API_KEY
- - model_name: openai/*
- litellm_params:
- model: openai/*
- api_key: os.environ/OPENAI_API_KEY
- - model_name: openai/*
- litellm_params:
- model: openai/*
- api_key: os.environ/OPENAI_API_KEY
- model_info:
- access_groups: ["public-openai-models"]
- - model_name: openai/gpt-4o
- litellm_params:
- model: openai/gpt-4o
- api_key: os.environ/OPENAI_API_KEY
- model_info:
- access_groups: ["private-openai-models"]
-
-router_settings:
- # routing_strategy: usage-based-routing-v2
- #redis_url: "os.environ/REDIS_URL"
- redis_host: "os.environ/REDIS_HOST"
- redis_port: "os.environ/REDIS_PORT"
+ guardrail: presidio
+ mode: "logging_only"
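+      # "logging_only" applies the guardrail only to the payload sent to logging callbacks; the live request/response is not modified
+      # mock_redacted_text mimics a Presidio anonymizer response so this config can be exercised without a running Presidio service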
+ mock_redacted_text:
+        text: "hello world, my name is <PERSON>. My number is: <PHONE_NUMBER>"
+ items:
+ - start: 48
+ end: 62
+ entity_type: PHONE_NUMBER
+            text: "<PHONE_NUMBER>"
+ operator: replace
+ - start: 24
+ end: 32
+ entity_type: PERSON
+            text: "<PERSON>"
+ operator: replace
litellm_settings:
- callbacks: ["datadog"]
\ No newline at end of file
+ set_verbose: true
+ success_callback: ["langfuse"]
\ No newline at end of file
diff --git a/litellm/proxy/auth/handle_jwt.py b/litellm/proxy/auth/handle_jwt.py
index 49d6c3520e..bfcde8fb1e 100644
--- a/litellm/proxy/auth/handle_jwt.py
+++ b/litellm/proxy/auth/handle_jwt.py
@@ -36,16 +36,19 @@ class JWTHandler:
self,
) -> None:
self.http_handler = HTTPHandler()
+        self.leeway = 0  # clock-skew tolerance (seconds) applied when decoding JWTs
def update_environment(
self,
prisma_client: Optional[PrismaClient],
user_api_key_cache: DualCache,
litellm_jwtauth: LiteLLM_JWTAuth,
+ leeway: int = 0,
) -> None:
self.prisma_client = prisma_client
self.user_api_key_cache = user_api_key_cache
self.litellm_jwtauth = litellm_jwtauth
+ self.leeway = leeway
def is_jwt(self, token: str):
parts = token.split(".")
@@ -271,6 +274,7 @@ class JWTHandler:
algorithms=algorithms,
options=decode_options,
audience=audience,
+ leeway=self.leeway, # allow testing of expired tokens
)
return payload
diff --git a/litellm/proxy/auth/user_api_key_auth.py b/litellm/proxy/auth/user_api_key_auth.py
index e9f2fc6d0c..ae6d4774be 100644
--- a/litellm/proxy/auth/user_api_key_auth.py
+++ b/litellm/proxy/auth/user_api_key_auth.py
@@ -562,6 +562,7 @@ async def user_api_key_auth( # noqa: PLR0915
user_id=user_id,
org_id=org_id,
parent_otel_span=parent_otel_span,
+ end_user_id=end_user_id,
)
#### ELSE ####
## CHECK PASS-THROUGH ENDPOINTS ##
diff --git a/litellm/proxy/common_utils/callback_utils.py b/litellm/proxy/common_utils/callback_utils.py
index 40f66e90b6..fa7208d3c0 100644
--- a/litellm/proxy/common_utils/callback_utils.py
+++ b/litellm/proxy/common_utils/callback_utils.py
@@ -48,7 +48,7 @@ def initialize_callbacks_on_proxy( # noqa: PLR0915
imported_list.append(open_telemetry_logger)
setattr(proxy_server, "open_telemetry_logger", open_telemetry_logger)
elif isinstance(callback, str) and callback == "presidio":
- from litellm.proxy.hooks.presidio_pii_masking import (
+ from litellm.proxy.guardrails.guardrail_hooks.presidio import (
_OPTIONAL_PresidioPIIMasking,
)
diff --git a/litellm/proxy/guardrails/guardrail_hooks/presidio.py b/litellm/proxy/guardrails/guardrail_hooks/presidio.py
index da53e4a8ae..384b2cb999 100644
--- a/litellm/proxy/guardrails/guardrail_hooks/presidio.py
+++ b/litellm/proxy/guardrails/guardrail_hooks/presidio.py
@@ -24,6 +24,7 @@ from litellm._logging import verbose_proxy_logger
from litellm.caching.caching import DualCache
from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.proxy._types import UserAPIKeyAuth
+from litellm.types.guardrails import GuardrailEventHooks
from litellm.utils import (
EmbeddingResponse,
ImageResponse,
@@ -54,8 +55,13 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail):
presidio_anonymizer_api_base: Optional[str] = None,
output_parse_pii: Optional[bool] = False,
presidio_ad_hoc_recognizers: Optional[str] = None,
+ logging_only: Optional[bool] = None,
**kwargs,
):
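+        # when logging_only is set, register this guardrail under the logging_only
+        # event hook so it only runs on the logged payload, never on the live request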
+ if logging_only is True:
+ self.logging_only = True
+ kwargs["event_hook"] = GuardrailEventHooks.logging_only
+ super().__init__(**kwargs)
self.pii_tokens: dict = (
{}
) # mapping of PII token to original text - only used with Presidio `replace` operation
@@ -84,8 +90,6 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail):
presidio_anonymizer_api_base=presidio_anonymizer_api_base,
)
- super().__init__(**kwargs)
-
def validate_environment(
self,
presidio_analyzer_api_base: Optional[str] = None,
@@ -245,10 +249,44 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail):
verbose_proxy_logger.info(
f"Presidio PII Masking: Redacted pii message: {data['messages']}"
)
+ data["messages"] = messages
return data
except Exception as e:
raise e
+ def logging_hook(
+ self, kwargs: dict, result: Any, call_type: str
+ ) -> Tuple[dict, Any]:
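+        """
+        Sync counterpart of async_logging_hook: masks PII in the request before it is
+        logged to callbacks (langfuse, datadog, etc.), bridging into the async
+        implementation even when called from within a running event loop.
+        """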
+ import threading
+ from concurrent.futures import ThreadPoolExecutor
+
+ def run_in_new_loop():
+ """Run the coroutine in a new event loop within this thread."""
+ new_loop = asyncio.new_event_loop()
+ try:
+ asyncio.set_event_loop(new_loop)
+ return new_loop.run_until_complete(
+ self.async_logging_hook(
+ kwargs=kwargs, result=result, call_type=call_type
+ )
+ )
+ finally:
+ new_loop.close()
+ asyncio.set_event_loop(None)
+
+ try:
+ # First, try to get the current event loop
+ _ = asyncio.get_running_loop()
+ # If we're already in an event loop, run in a separate thread
+ # to avoid nested event loop issues
+ with ThreadPoolExecutor(max_workers=1) as executor:
+ future = executor.submit(run_in_new_loop)
+ return future.result()
+
+ except RuntimeError:
+ # No running event loop, we can safely run in this thread
+ return run_in_new_loop()
+
async def async_logging_hook(
self, kwargs: dict, result: Any, call_type: str
) -> Tuple[dict, Any]:
@@ -304,7 +342,8 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail):
verbose_proxy_logger.debug(
f"PII Masking Args: self.output_parse_pii={self.output_parse_pii}; type of response={type(response)}"
)
- if self.output_parse_pii is False:
+
+ if self.output_parse_pii is False and litellm.output_parse_pii is False:
return response
if isinstance(response, ModelResponse) and not isinstance(
diff --git a/litellm/proxy/hooks/presidio_pii_masking.py b/litellm/proxy/hooks/presidio_pii_masking.py
deleted file mode 100644
index 603e075620..0000000000
--- a/litellm/proxy/hooks/presidio_pii_masking.py
+++ /dev/null
@@ -1,349 +0,0 @@
-# +-----------------------------------------------+
-# | |
-# | PII Masking |
-# | with Microsoft Presidio |
-# | https://github.com/BerriAI/litellm/issues/ |
-# +-----------------------------------------------+
-#
-# Tell us how we can improve! - Krrish & Ishaan
-
-
-import asyncio
-import json
-import traceback
-import uuid
-from typing import Any, List, Optional, Tuple, Union
-
-import aiohttp
-from fastapi import HTTPException
-
-import litellm # noqa: E401
-from litellm._logging import verbose_proxy_logger
-from litellm.caching.caching import DualCache
-from litellm.integrations.custom_logger import CustomLogger
-from litellm.proxy._types import UserAPIKeyAuth
-from litellm.utils import (
- EmbeddingResponse,
- ImageResponse,
- ModelResponse,
- StreamingChoices,
- get_formatted_prompt,
-)
-
-
-class _OPTIONAL_PresidioPIIMasking(CustomLogger):
- user_api_key_cache = None
- ad_hoc_recognizers = None
-
- # Class variables or attributes
- def __init__(
- self,
- logging_only: Optional[bool] = None,
- mock_testing: bool = False,
- mock_redacted_text: Optional[dict] = None,
- ):
- self.pii_tokens: dict = (
- {}
- ) # mapping of PII token to original text - only used with Presidio `replace` operation
-
- self.mock_redacted_text = mock_redacted_text
- self.logging_only = logging_only
- if mock_testing is True: # for testing purposes only
- return
-
- ad_hoc_recognizers = litellm.presidio_ad_hoc_recognizers
- if ad_hoc_recognizers is not None:
- try:
- with open(ad_hoc_recognizers, "r") as file:
- self.ad_hoc_recognizers = json.load(file)
- except FileNotFoundError:
- raise Exception(f"File not found. file_path={ad_hoc_recognizers}")
- except json.JSONDecodeError as e:
- raise Exception(
- f"Error decoding JSON file: {str(e)}, file_path={ad_hoc_recognizers}"
- )
- except Exception as e:
- raise Exception(
- f"An error occurred: {str(e)}, file_path={ad_hoc_recognizers}"
- )
-
- self.validate_environment()
-
- def validate_environment(self):
- self.presidio_analyzer_api_base: Optional[str] = litellm.get_secret(
- "PRESIDIO_ANALYZER_API_BASE", None
- ) # type: ignore
- self.presidio_anonymizer_api_base: Optional[str] = litellm.get_secret(
- "PRESIDIO_ANONYMIZER_API_BASE", None
- ) # type: ignore
-
- if self.presidio_analyzer_api_base is None:
- raise Exception("Missing `PRESIDIO_ANALYZER_API_BASE` from environment")
- if not self.presidio_analyzer_api_base.endswith("/"):
- self.presidio_analyzer_api_base += "/"
- if not (
- self.presidio_analyzer_api_base.startswith("http://")
- or self.presidio_analyzer_api_base.startswith("https://")
- ):
- # add http:// if unset, assume communicating over private network - e.g. render
- self.presidio_analyzer_api_base = (
- "http://" + self.presidio_analyzer_api_base
- )
-
- if self.presidio_anonymizer_api_base is None:
- raise Exception("Missing `PRESIDIO_ANONYMIZER_API_BASE` from environment")
- if not self.presidio_anonymizer_api_base.endswith("/"):
- self.presidio_anonymizer_api_base += "/"
- if not (
- self.presidio_anonymizer_api_base.startswith("http://")
- or self.presidio_anonymizer_api_base.startswith("https://")
- ):
- # add http:// if unset, assume communicating over private network - e.g. render
- self.presidio_anonymizer_api_base = (
- "http://" + self.presidio_anonymizer_api_base
- )
-
- def print_verbose(self, print_statement):
- try:
- verbose_proxy_logger.debug(print_statement)
- if litellm.set_verbose:
- print(print_statement) # noqa
- except Exception:
- pass
-
- async def check_pii(self, text: str, output_parse_pii: bool) -> str:
- """
- [TODO] make this more performant for high-throughput scenario
- """
- try:
- async with aiohttp.ClientSession() as session:
- if self.mock_redacted_text is not None:
- redacted_text = self.mock_redacted_text
- else:
- # Make the first request to /analyze
- analyze_url = f"{self.presidio_analyzer_api_base}analyze"
- verbose_proxy_logger.debug("Making request to: %s", analyze_url)
- analyze_payload = {"text": text, "language": "en"}
- if self.ad_hoc_recognizers is not None:
- analyze_payload["ad_hoc_recognizers"] = self.ad_hoc_recognizers
- redacted_text = None
- async with session.post(
- analyze_url, json=analyze_payload
- ) as response:
- analyze_results = await response.json()
-
- # Make the second request to /anonymize
- anonymize_url = f"{self.presidio_anonymizer_api_base}anonymize"
- verbose_proxy_logger.debug("Making request to: %s", anonymize_url)
- anonymize_payload = {
- "text": text,
- "analyzer_results": analyze_results,
- }
-
- async with session.post(
- anonymize_url, json=anonymize_payload
- ) as response:
- redacted_text = await response.json()
-
- new_text = text
- if redacted_text is not None:
- verbose_proxy_logger.debug("redacted_text: %s", redacted_text)
- for item in redacted_text["items"]:
- start = item["start"]
- end = item["end"]
- replacement = item["text"] # replacement token
- if item["operator"] == "replace" and output_parse_pii is True:
- # check if token in dict
- # if exists, add a uuid to the replacement token for swapping back to the original text in llm response output parsing
- if replacement in self.pii_tokens:
- replacement = replacement + str(uuid.uuid4())
-
- self.pii_tokens[replacement] = new_text[
- start:end
- ] # get text it'll replace
-
- new_text = new_text[:start] + replacement + new_text[end:]
- return redacted_text["text"]
- else:
- raise Exception(f"Invalid anonymizer response: {redacted_text}")
- except Exception as e:
- verbose_proxy_logger.error(
- "litellm.proxy.hooks.presidio_pii_masking.py::async_pre_call_hook(): Exception occured - {}".format(
- str(e)
- )
- )
- verbose_proxy_logger.debug(traceback.format_exc())
- raise e
-
- async def async_pre_call_hook(
- self,
- user_api_key_dict: UserAPIKeyAuth,
- cache: DualCache,
- data: dict,
- call_type: str,
- ):
- """
- - Check if request turned off pii
- - Check if user allowed to turn off pii (key permissions -> 'allow_pii_controls')
-
- - Take the request data
- - Call /analyze -> get the results
- - Call /anonymize w/ the analyze results -> get the redacted text
-
- For multiple messages in /chat/completions, we'll need to call them in parallel.
- """
- try:
- if (
- self.logging_only is True
- ): # only modify the logging obj data (done by async_logging_hook)
- return data
- permissions = user_api_key_dict.permissions
- output_parse_pii = permissions.get(
- "output_parse_pii", litellm.output_parse_pii
- ) # allow key to turn on/off output parsing for pii
- no_pii = permissions.get(
- "no-pii", None
- ) # allow key to turn on/off pii masking (if user is allowed to set pii controls, then they can override the key defaults)
-
- if no_pii is None:
- # check older way of turning on/off pii
- no_pii = not permissions.get("pii", True)
-
- content_safety = data.get("content_safety", None)
- verbose_proxy_logger.debug("content_safety: %s", content_safety)
- ## Request-level turn on/off PII controls ##
- if content_safety is not None and isinstance(content_safety, dict):
- # pii masking ##
- if (
- content_safety.get("no-pii", None) is not None
- and content_safety.get("no-pii") is True
- ):
- # check if user allowed to turn this off
- if permissions.get("allow_pii_controls", False) is False:
- raise HTTPException(
- status_code=400,
- detail={
- "error": "Not allowed to set PII controls per request"
- },
- )
- else: # user allowed to turn off pii masking
- no_pii = content_safety.get("no-pii")
- if not isinstance(no_pii, bool):
- raise HTTPException(
- status_code=400,
- detail={"error": "no_pii needs to be a boolean value"},
- )
- ## pii output parsing ##
- if content_safety.get("output_parse_pii", None) is not None:
- # check if user allowed to turn this off
- if permissions.get("allow_pii_controls", False) is False:
- raise HTTPException(
- status_code=400,
- detail={
- "error": "Not allowed to set PII controls per request"
- },
- )
- else: # user allowed to turn on/off pii output parsing
- output_parse_pii = content_safety.get("output_parse_pii")
- if not isinstance(output_parse_pii, bool):
- raise HTTPException(
- status_code=400,
- detail={
- "error": "output_parse_pii needs to be a boolean value"
- },
- )
-
- if no_pii is True: # turn off pii masking
- return data
-
- if call_type == "completion": # /chat/completions requests
- messages = data["messages"]
- tasks = []
-
- for m in messages:
- if isinstance(m["content"], str):
- tasks.append(
- self.check_pii(
- text=m["content"], output_parse_pii=output_parse_pii
- )
- )
- responses = await asyncio.gather(*tasks)
- for index, r in enumerate(responses):
- if isinstance(messages[index]["content"], str):
- messages[index][
- "content"
- ] = r # replace content with redacted string
- verbose_proxy_logger.info(
- f"Presidio PII Masking: Redacted pii message: {data['messages']}"
- )
- return data
- except Exception as e:
- verbose_proxy_logger.info(
- "An error occurred -",
- )
- raise e
-
- async def async_logging_hook(
- self, kwargs: dict, result: Any, call_type: str
- ) -> Tuple[dict, Any]:
- """
- Masks the input before logging to langfuse, datadog, etc.
- """
- if (
- call_type == "completion" or call_type == "acompletion"
- ): # /chat/completions requests
- messages: Optional[List] = kwargs.get("messages", None)
- tasks = []
-
- if messages is None:
- return kwargs, result
-
- for m in messages:
- text_str = ""
- if m["content"] is None:
- continue
- if isinstance(m["content"], str):
- text_str = m["content"]
- tasks.append(
- self.check_pii(text=text_str, output_parse_pii=False)
- ) # need to pass separately b/c presidio has context window limits
- responses = await asyncio.gather(*tasks)
- for index, r in enumerate(responses):
- if isinstance(messages[index]["content"], str):
- messages[index][
- "content"
- ] = r # replace content with redacted string
- verbose_proxy_logger.info(
- f"Presidio PII Masking: Redacted pii message: {messages}"
- )
- kwargs["messages"] = messages
-
- return kwargs, result
-
- async def async_post_call_success_hook(
- self,
- data: dict,
- user_api_key_dict: UserAPIKeyAuth,
- response: Union[ModelResponse, EmbeddingResponse, ImageResponse],
- ):
- """
- Output parse the response object to replace the masked tokens with user sent values
- """
- verbose_proxy_logger.debug(
- f"PII Masking Args: litellm.output_parse_pii={litellm.output_parse_pii}; type of response={type(response)}"
- )
- if litellm.output_parse_pii is False:
- return response
-
- if isinstance(response, ModelResponse) and not isinstance(
- response.choices[0], StreamingChoices
- ): # /chat/completions requests
- if isinstance(response.choices[0].message.content, str):
- verbose_proxy_logger.debug(
- f"self.pii_tokens: {self.pii_tokens}; initial response: {response.choices[0].message.content}"
- )
- for key, value in self.pii_tokens.items():
- response.choices[0].message.content = response.choices[
- 0
- ].message.content.replace(key, value)
- return response
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index a72297ac3e..93df33d757 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -5987,6 +5987,39 @@ async def new_budget(
return response
+@router.post(
+ "/budget/update",
+ tags=["budget management"],
+ dependencies=[Depends(user_api_key_auth)],
+)
+async def update_budget(
+ budget_obj: BudgetNew,
+ user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
+):
+ """
+    Update an existing budget object. Can apply this to teams, orgs, end-users, keys.
+ """
+ global prisma_client
+
+ if prisma_client is None:
+ raise HTTPException(
+ status_code=500,
+ detail={"error": CommonProxyErrors.db_not_connected_error.value},
+ )
+ if budget_obj.budget_id is None:
+ raise HTTPException(status_code=400, detail={"error": "budget_id is required"})
+
+ response = await prisma_client.db.litellm_budgettable.update(
+ where={"budget_id": budget_obj.budget_id},
+ data={
+ **budget_obj.model_dump(exclude_none=True), # type: ignore
+ "updated_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
+ }, # type: ignore
+ )
+
+ return response
+
+
@router.post(
"/budget/info",
tags=["budget management"],
diff --git a/litellm/router.py b/litellm/router.py
index 3f21597945..34143bd500 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -83,6 +83,9 @@ from litellm.router_utils.fallback_event_handlers import (
run_async_fallback,
run_sync_fallback,
)
+from litellm.router_utils.get_retry_from_policy import (
+ get_num_retries_from_retry_policy as _get_num_retries_from_retry_policy,
+)
from litellm.router_utils.handle_error import (
async_raise_no_deployment_exception,
send_llm_exception_alert,
@@ -5609,53 +5612,12 @@ class Router:
def get_num_retries_from_retry_policy(
self, exception: Exception, model_group: Optional[str] = None
):
- """
- BadRequestErrorRetries: Optional[int] = None
- AuthenticationErrorRetries: Optional[int] = None
- TimeoutErrorRetries: Optional[int] = None
- RateLimitErrorRetries: Optional[int] = None
- ContentPolicyViolationErrorRetries: Optional[int] = None
- """
- # if we can find the exception then in the retry policy -> return the number of retries
- retry_policy: Optional[RetryPolicy] = self.retry_policy
-
- if (
- self.model_group_retry_policy is not None
- and model_group is not None
- and model_group in self.model_group_retry_policy
- ):
- retry_policy = self.model_group_retry_policy.get(model_group, None) # type: ignore
-
- if retry_policy is None:
- return None
- if isinstance(retry_policy, dict):
- retry_policy = RetryPolicy(**retry_policy)
-
- if (
- isinstance(exception, litellm.BadRequestError)
- and retry_policy.BadRequestErrorRetries is not None
- ):
- return retry_policy.BadRequestErrorRetries
- if (
- isinstance(exception, litellm.AuthenticationError)
- and retry_policy.AuthenticationErrorRetries is not None
- ):
- return retry_policy.AuthenticationErrorRetries
- if (
- isinstance(exception, litellm.Timeout)
- and retry_policy.TimeoutErrorRetries is not None
- ):
- return retry_policy.TimeoutErrorRetries
- if (
- isinstance(exception, litellm.RateLimitError)
- and retry_policy.RateLimitErrorRetries is not None
- ):
- return retry_policy.RateLimitErrorRetries
- if (
- isinstance(exception, litellm.ContentPolicyViolationError)
- and retry_policy.ContentPolicyViolationErrorRetries is not None
- ):
- return retry_policy.ContentPolicyViolationErrorRetries
+ return _get_num_retries_from_retry_policy(
+ exception=exception,
+ model_group=model_group,
+ model_group_retry_policy=self.model_group_retry_policy,
+ retry_policy=self.retry_policy,
+ )
def get_allowed_fails_from_policy(self, exception: Exception):
"""
diff --git a/litellm/router_utils/get_retry_from_policy.py b/litellm/router_utils/get_retry_from_policy.py
new file mode 100644
index 0000000000..48df43ef81
--- /dev/null
+++ b/litellm/router_utils/get_retry_from_policy.py
@@ -0,0 +1,71 @@
+"""
+Get num retries for an exception.
+
+- Account for retry policy by exception type.
+"""
+
+from typing import Dict, Optional, Union
+
+from litellm.exceptions import (
+ AuthenticationError,
+ BadRequestError,
+ ContentPolicyViolationError,
+ RateLimitError,
+ Timeout,
+)
+from litellm.types.router import RetryPolicy
+
+
+def get_num_retries_from_retry_policy(
+ exception: Exception,
+ retry_policy: Optional[Union[RetryPolicy, dict]] = None,
+ model_group: Optional[str] = None,
+ model_group_retry_policy: Optional[Dict[str, RetryPolicy]] = None,
+):
+ """
+ BadRequestErrorRetries: Optional[int] = None
+ AuthenticationErrorRetries: Optional[int] = None
+ TimeoutErrorRetries: Optional[int] = None
+ RateLimitErrorRetries: Optional[int] = None
+ ContentPolicyViolationErrorRetries: Optional[int] = None
+ """
+    # if the exception type is found in the retry policy -> return its configured retry count
+
+ if (
+ model_group_retry_policy is not None
+ and model_group is not None
+ and model_group in model_group_retry_policy
+ ):
+ retry_policy = model_group_retry_policy.get(model_group, None) # type: ignore
+
+ if retry_policy is None:
+ return None
+ if isinstance(retry_policy, dict):
+ retry_policy = RetryPolicy(**retry_policy)
+
+ if (
+ isinstance(exception, BadRequestError)
+ and retry_policy.BadRequestErrorRetries is not None
+ ):
+ return retry_policy.BadRequestErrorRetries
+ if (
+ isinstance(exception, AuthenticationError)
+ and retry_policy.AuthenticationErrorRetries is not None
+ ):
+ return retry_policy.AuthenticationErrorRetries
+ if isinstance(exception, Timeout) and retry_policy.TimeoutErrorRetries is not None:
+ return retry_policy.TimeoutErrorRetries
+ if (
+ isinstance(exception, RateLimitError)
+ and retry_policy.RateLimitErrorRetries is not None
+ ):
+ return retry_policy.RateLimitErrorRetries
+ if (
+ isinstance(exception, ContentPolicyViolationError)
+ and retry_policy.ContentPolicyViolationErrorRetries is not None
+ ):
+ return retry_policy.ContentPolicyViolationErrorRetries
+
+
+def reset_retry_policy() -> RetryPolicy:
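+    """Return a fresh RetryPolicy (all retry counts unset); used to clear kwargs["retry_policy"] after the retry count is computed so retries don't loop forever."""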
+ return RetryPolicy()
diff --git a/litellm/utils.py b/litellm/utils.py
index fe43421732..f11368fc41 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -97,6 +97,10 @@ from litellm.litellm_core_utils.rules import Rules
from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper
from litellm.litellm_core_utils.token_counter import get_modified_max_tokens
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+from litellm.router_utils.get_retry_from_policy import (
+ get_num_retries_from_retry_policy,
+ reset_retry_policy,
+)
from litellm.secret_managers.main import get_secret
from litellm.types.llms.openai import (
AllMessageValues,
@@ -918,6 +922,14 @@ def client(original_function): # noqa: PLR0915
num_retries = (
kwargs.get("num_retries", None) or litellm.num_retries or None
)
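+                    # an explicit per-request retry_policy overrides num_retries for this exception type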
+ if kwargs.get("retry_policy", None):
+ num_retries = get_num_retries_from_retry_policy(
+ exception=e,
+ retry_policy=kwargs.get("retry_policy"),
+ )
+ kwargs["retry_policy"] = (
+ reset_retry_policy()
+ ) # prevent infinite loops
litellm.num_retries = (
None # set retries to None to prevent infinite loops
)
@@ -1137,6 +1149,13 @@ def client(original_function): # noqa: PLR0915
num_retries = (
kwargs.get("num_retries", None) or litellm.num_retries or None
)
+ if kwargs.get("retry_policy", None):
+ num_retries = get_num_retries_from_retry_policy(
+ exception=e,
+ retry_policy=kwargs.get("retry_policy"),
+ )
+ kwargs["retry_policy"] = reset_retry_policy()
+
litellm.num_retries = (
None # set retries to None to prevent infinite loops
)
diff --git a/tests/local_testing/test_completion_with_retries.py b/tests/local_testing/test_completion_with_retries.py
index e59d1d6e13..efb66c40c6 100644
--- a/tests/local_testing/test_completion_with_retries.py
+++ b/tests/local_testing/test_completion_with_retries.py
@@ -61,3 +61,68 @@ def test_completion_with_0_num_retries():
except Exception as e:
print("exception", e)
pass
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("sync_mode", [True, False])
+async def test_completion_with_retry_policy(sync_mode):
+ from unittest.mock import patch, MagicMock, AsyncMock
+ from litellm.types.router import RetryPolicy
+
+ retry_number = 1
+ retry_policy = RetryPolicy(
+        ContentPolicyViolationErrorRetries=retry_number,  # run 1 retry for ContentPolicyViolationErrors
+ AuthenticationErrorRetries=0, # run 0 retries for AuthenticationErrorRetries
+ )
+
+ target_function = "completion_with_retries"
+
+ with patch.object(litellm, target_function) as mock_completion_with_retries:
+ data = {
+ "model": "azure/gpt-3.5-turbo",
+ "messages": [{"gm": "vibe", "role": "user"}],
+ "retry_policy": retry_policy,
+ "mock_response": "Exception: content_filter_policy",
+ }
+ try:
+ if sync_mode:
+ completion(**data)
+ else:
+ await completion(**data)
+ except Exception as e:
+ print(e)
+
+ mock_completion_with_retries.assert_called_once()
+ assert (
+ mock_completion_with_retries.call_args.kwargs["num_retries"] == retry_number
+ )
+ assert retry_policy.ContentPolicyViolationErrorRetries == retry_number
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("sync_mode", [True, False])
+async def test_completion_with_retry_policy_no_error(sync_mode):
+ """
+ Test that the completion function does not throw an error when the retry policy is set
+ """
+ from unittest.mock import patch, MagicMock, AsyncMock
+ from litellm.types.router import RetryPolicy
+
+ retry_number = 1
+ retry_policy = RetryPolicy(
+        ContentPolicyViolationErrorRetries=retry_number,  # run 1 retry for ContentPolicyViolationErrors
+ AuthenticationErrorRetries=0, # run 0 retries for AuthenticationErrorRetries
+ )
+
+ data = {
+ "model": "gpt-3.5-turbo",
+ "messages": [{"gm": "vibe", "role": "user"}],
+ "retry_policy": retry_policy,
+ }
+ try:
+ if sync_mode:
+ completion(**data)
+ else:
+ await completion(**data)
+ except Exception as e:
+ print(e)
diff --git a/tests/local_testing/test_presidio_masking.py b/tests/local_testing/test_presidio_masking.py
index 0f96da3348..c9d1adc9b2 100644
--- a/tests/local_testing/test_presidio_masking.py
+++ b/tests/local_testing/test_presidio_masking.py
@@ -19,12 +19,13 @@ sys.path.insert(
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
-
import litellm
from litellm import Router, mock_completion
from litellm.caching.caching import DualCache
from litellm.proxy._types import UserAPIKeyAuth
-from litellm.proxy.hooks.presidio_pii_masking import _OPTIONAL_PresidioPIIMasking
+from litellm.proxy.guardrails.guardrail_hooks.presidio import (
+ _OPTIONAL_PresidioPIIMasking,
+)
from litellm.proxy.utils import ProxyLogging
@@ -63,6 +64,7 @@ async def test_output_parsing():
- have presidio pii masking - output parse message
- assert that no masked tokens are in the input message
"""
+ litellm.set_verbose = True
litellm.output_parse_pii = True
pii_masking = _OPTIONAL_PresidioPIIMasking(mock_testing=True)
@@ -206,6 +208,8 @@ async def test_presidio_pii_masking_input_b():
@pytest.mark.asyncio
async def test_presidio_pii_masking_logging_output_only_no_pre_api_hook():
+ from litellm.types.guardrails import GuardrailEventHooks
+
pii_masking = _OPTIONAL_PresidioPIIMasking(
logging_only=True,
mock_testing=True,
@@ -223,22 +227,29 @@ async def test_presidio_pii_masking_logging_output_only_no_pre_api_hook():
}
]
- new_data = await pii_masking.async_pre_call_hook(
- user_api_key_dict=user_api_key_dict,
- cache=local_cache,
- data={"messages": test_messages},
- call_type="completion",
+ assert (
+ pii_masking.should_run_guardrail(
+ data={"messages": test_messages},
+ event_type=GuardrailEventHooks.pre_call,
+ )
+ is False
)
- assert "Jane Doe" in new_data["messages"][0]["content"]
-
+@pytest.mark.parametrize(
+ "sync_mode",
+ [True, False],
+)
@pytest.mark.asyncio
-async def test_presidio_pii_masking_logging_output_only_logged_response():
+async def test_presidio_pii_masking_logging_output_only_logged_response(sync_mode):
+ import litellm
+
+ litellm.set_verbose = True
pii_masking = _OPTIONAL_PresidioPIIMasking(
logging_only=True,
mock_testing=True,
mock_redacted_text=input_b_anonymizer_results,
+ guardrail_name="presidio",
)
test_messages = [
@@ -247,15 +258,26 @@ async def test_presidio_pii_masking_logging_output_only_logged_response():
"content": "My name is Jane Doe, who are you? Say my name in your response",
}
]
- with patch.object(
- pii_masking, "async_log_success_event", new=AsyncMock()
- ) as mock_call:
- litellm.callbacks = [pii_masking]
- response = await litellm.acompletion(
- model="gpt-3.5-turbo", messages=test_messages, mock_response="Hi Peter!"
- )
- await asyncio.sleep(3)
+ if sync_mode:
+ target_function = "log_success_event"
+ mock_call = MagicMock()
+ else:
+ target_function = "async_log_success_event"
+ mock_call = AsyncMock()
+
+ with patch.object(pii_masking, target_function, new=mock_call) as mock_call:
+ litellm.callbacks = [pii_masking]
+ if sync_mode:
+ response = litellm.completion(
+ model="gpt-3.5-turbo", messages=test_messages, mock_response="Hi Peter!"
+ )
+ time.sleep(3)
+ else:
+ response = await litellm.acompletion(
+ model="gpt-3.5-turbo", messages=test_messages, mock_response="Hi Peter!"
+ )
+ await asyncio.sleep(3)
assert response.choices[0].message.content == "Hi Peter!" # type: ignore
@@ -275,8 +297,13 @@ async def test_presidio_pii_masking_logging_output_only_logged_response_guardrai
import litellm
from litellm.proxy.guardrails.init_guardrails import initialize_guardrails
- from litellm.types.guardrails import GuardrailItem, GuardrailItemSpec
+ from litellm.types.guardrails import (
+ GuardrailItem,
+ GuardrailItemSpec,
+ GuardrailEventHooks,
+ )
+ litellm.set_verbose = True
os.environ["PRESIDIO_ANALYZER_API_BASE"] = "http://localhost:5002"
os.environ["PRESIDIO_ANONYMIZER_API_BASE"] = "http://localhost:5001"
@@ -303,10 +330,15 @@ async def test_presidio_pii_masking_logging_output_only_logged_response_guardrai
pii_masking_obj: Optional[_OPTIONAL_PresidioPIIMasking] = None
for callback in litellm.callbacks:
+ print(f"CALLBACK: {callback}")
if isinstance(callback, _OPTIONAL_PresidioPIIMasking):
pii_masking_obj = callback
assert pii_masking_obj is not None
assert hasattr(pii_masking_obj, "logging_only")
- assert pii_masking_obj.logging_only is True
+ assert pii_masking_obj.event_hook == GuardrailEventHooks.logging_only
+
+ assert pii_masking_obj.should_run_guardrail(
+ data={}, event_type=GuardrailEventHooks.logging_only
+ )
diff --git a/tests/proxy_unit_tests/test_jwt.py b/tests/proxy_unit_tests/test_jwt.py
index c07394962f..c68878dcbe 100644
--- a/tests/proxy_unit_tests/test_jwt.py
+++ b/tests/proxy_unit_tests/test_jwt.py
@@ -1026,3 +1026,89 @@ def test_get_public_key_from_jwk_url():
assert public_key is not None
assert public_key == jwk_response[0]
+
+
+@pytest.mark.asyncio
+async def test_end_user_jwt_auth(monkeypatch):
+ import litellm
+ from litellm.proxy.auth.handle_jwt import JWTHandler
+ from litellm.caching import DualCache
+ from litellm.proxy._types import LiteLLM_JWTAuth
+ from litellm.proxy.proxy_server import user_api_key_auth
+
+ monkeypatch.delenv("JWT_AUDIENCE", None)
+ jwt_handler = JWTHandler()
+
+ litellm_jwtauth = LiteLLM_JWTAuth(
+ end_user_id_jwt_field="sub",
+ )
+
+ cache = DualCache()
+
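+    # JWKS public keys; the first entry's "kid" matches the kid in the hard-coded test token's header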
+ keys = [
+ {
+ "kid": "d-1733370597545",
+ "alg": "RS256",
+ "kty": "RSA",
+ "use": "sig",
+ "n": "j5Ik60FJSUIPMVdMSU8vhYvyPPH7rUsUllNI0BfBlIkgEYFk2mg4KX1XDQc6mcKFjbq9k_7TSkHWKnfPhNkkb0MdmZLKbwTrmte2k8xWDxp-rSmZpIJwC1zuPDx5joLHBgIb09-K2cPL2rvwzP75WtOr_QLXBerHAbXx8cOdI7mrSRWJ9iXbKv_pLDnZHnGNld75tztb8nCtgrywkF010jGi1xxaT8UKsTvK-QkIBkYI6m6WR9LMnG2OZm-ExuvNPUenfYUsqnynPF4SMNZwyQqJfavSLKI8uMzB2s9pcbx5HfQwIOYhMlgBHjhdDn2IUSnXSJqSsN6RQO18M2rgPQ",
+ "e": "AQAB",
+ },
+ {
+ "kid": "s-f836dd32-ef71-426a-8804-946a7f230bc9",
+ "alg": "RS256",
+ "kty": "RSA",
+ "use": "sig",
+ "n": "2A5-ZA18YKn7M4OtxsfXBc3Z7n2WyHTxbK4GEBlmD9T9TDr4sbJaI4oHfTvzsAC3H2r2YkASzrCISXMXQJjLHoeLgDVcKs8qTdLj7K5FNT9fA0kU9ayUjSGrqkz57SG7oNf9Wp__Qa-H-bs6Z8_CEfBy0JA9QSHUfrdOXp4vCB_qLn6DE0DJH9ELAq_0nktVQk_oxlvXlGtVZSZe31mNNgiD__RJMogf-SIFcYOkMLVGTTEBYiCk1mHxXS6oJZaVSWiBgHzu5wkra5AfQLUVelQaupT5H81hFPmiceEApf_2DacnqqRV4-Nl8sjhJtuTXiprVS2Z5r2pOMz_kVGNgw",
+ "e": "AQAB",
+ },
+ ]
+
+ cache.set_cache(
+ key="litellm_jwt_auth_keys",
+ value=keys,
+ )
+
+ jwt_handler.update_environment(
+ prisma_client=None,
+ user_api_key_cache=cache,
+ litellm_jwtauth=litellm_jwtauth,
+        leeway=100000000000000,  # very large leeway so the expired hard-coded test token still validates
+ )
+
+ token = "eyJraWQiOiJkLTE3MzMzNzA1OTc1NDUiLCJ0eXAiOiJKV1QiLCJ2ZXJzaW9uIjoiNCIsImFsZyI6IlJTMjU2In0.eyJpYXQiOjE3MzM1MDcyNzcsImV4cCI6MTczMzUwNzg3Nywic3ViIjoiODFiM2U1MmEtNjdhNi00ZWZiLTk2NDUtNzA1MjdlMTAxNDc5IiwidElkIjoicHVibGljIiwic2Vzc2lvbkhhbmRsZSI6Ijg4M2Y4YWFmLWUwOTEtNGE1Ny04YTJhLTRiMjcwMmZhZjMzYyIsInJlZnJlc2hUb2tlbkhhc2gxIjoiNDVhNDRhYjlmOTMwMGQyOTY4ZjkxODZkYWQ0YTcwY2QwNjk2YzBiNTBmZmUxZmQ4ZTM2YzU1NGU0MWE4ODU0YiIsInBhcmVudFJlZnJlc2hUb2tlbkhhc2gxIjpudWxsLCJhbnRpQ3NyZlRva2VuIjpudWxsLCJpc3MiOiJodHRwOi8vbG9jYWxob3N0OjMwMDEvYXV0aC9zdCIsImxlZ2FjeV9jb21wYW55X2lkIjoxNTI0OTM5LCJsZWdhY3lfaWQiOjM5NzAyNzk1LCJzY29wZSI6WyJza2lsbF91cCJdLCJzdC1ldiI6eyJ0IjoxNzMzNTA3Mjc4NzAwLCJ2IjpmYWxzZX19.XlYrT6dRIjaZKkJtdr7C_UuxajFRbNpA9BnIsny3rxiPVyS8rhIBwxW12tZwgttRywmXrXK-msowFhWU4XdL5Qfe4lwZb2HTbDeGiQPvQTlOjWWYMhgCoKdPtjCQsAcW45rg7aQ0p42JFQPoAQa8AnGfxXpgx2vSR7njiZ3ZZyHerDdKQHyIGSFVOxoK0TgR-hxBVY__Wjg8UTKgKSz9KU_uwnPgpe2DeYmP-LTK2oeoygsVRmbldY_GrrcRe3nqYcUfFkxSs0FSsoSv35jIxiptXfCjhEB1Y5eaJhHEjlYlP2rw98JysYxjO2rZbAdUpL3itPeo3T2uh1NZr_lArw"
+
+ response = await jwt_handler.auth_jwt(token=token)
+
+ assert response is not None
+
+ end_user_id = jwt_handler.get_end_user_id(
+ token=response,
+ default_value=None,
+ )
+
+ assert end_user_id is not None
+
+ ## CHECK USER API KEY AUTH ##
+ from starlette.datastructures import URL
+
+ bearer_token = "Bearer " + token
+
+ request = Request(scope={"type": "http"})
+ request._url = URL(url="/chat/completions")
+
+    ## AUTH IN WITH THE JWT BEARER TOKEN ##
+    # end_user_id should be read from the token's 'sub' claim
+ setattr(
+ litellm.proxy.proxy_server,
+ "general_settings",
+ {
+ "enable_jwt_auth": True,
+ },
+ )
+ setattr(litellm.proxy.proxy_server, "prisma_client", {})
+ setattr(litellm.proxy.proxy_server, "jwt_handler", jwt_handler)
+ result = await user_api_key_auth(request=request, api_key=bearer_token)
+ assert (
+ result.end_user_id == "81b3e52a-67a6-4efb-9645-70527e101479"
+ ) # jwt token decoded sub value
diff --git a/ui/litellm-dashboard/src/components/budgets/budget_panel.tsx b/ui/litellm-dashboard/src/components/budgets/budget_panel.tsx
index edad680b28..604134575d 100644
--- a/ui/litellm-dashboard/src/components/budgets/budget_panel.tsx
+++ b/ui/litellm-dashboard/src/components/budgets/budget_panel.tsx
@@ -49,6 +49,7 @@ export interface budgetItem {
max_budget: string | null;
rpm_limit: number | null;
tpm_limit: number | null;
+ updated_at: string;
}
const BudgetPanel: React.FC = ({ accessToken }) => {
@@ -65,13 +66,17 @@ const BudgetPanel: React.FC = ({ accessToken }) => {
});
}, [accessToken]);
-
const handleEditCall = async (budget_id: string, index: number) => {
+ console.log("budget_id", budget_id)
if (accessToken == null) {
return;
}
- setSelectedBudget(budgetList[index])
- setIsEditModalVisible(true)
+ // Find the budget first
+ const budget = budgetList.find(budget => budget.budget_id === budget_id) || null;
+
+ // Update state and show modal after state is updated
+ setSelectedBudget(budget);
+ setIsEditModalVisible(true);
};
const handleDeleteCall = async (budget_id: string, index: number) => {
@@ -90,6 +95,15 @@ const BudgetPanel: React.FC = ({ accessToken }) => {
message.success("Budget Deleted.");
};
+ const handleUpdateCall = async () => {
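+    // re-fetch the budget list so the table reflects the latest values after an edit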
+ if (accessToken == null) {
+ return;
+ }
+ getBudgetList(accessToken).then((data) => {
+ setBudgetList(data);
+ });
+ }
+
return (