Litellm dev 12 06 2024 (#7067)

* fix(edit_budget_modal.tsx): call `/budget/update` endpoint instead of `/budget/new`

allows updating an existing budget in the UI
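
A rough sketch of hitting the new endpoint directly (not taken from the diff; the proxy URL http://localhost:4000, the master key sk-1234, and the budget id my-budget are placeholders):

    import requests

    # `budget_id` is required by /budget/update; other BudgetNew fields are optional overrides
    resp = requests.post(
        "http://localhost:4000/budget/update",
        headers={
            "Authorization": "Bearer sk-1234",
            "Content-Type": "application/json",
        },
        json={"budget_id": "my-budget", "max_budget": 10.0},
    )
    resp.raise_for_status()
    print(resp.json())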

* fix(user_api_key_auth.py): support cost tracking for end user via jwt field
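
A minimal sketch of the JWT wiring, adapted from the new test further down: `end_user_id_jwt_field` names the claim (here "sub") whose value is treated as the end user for cost tracking.

    from litellm.caching import DualCache
    from litellm.proxy._types import LiteLLM_JWTAuth
    from litellm.proxy.auth.handle_jwt import JWTHandler

    jwt_handler = JWTHandler()
    jwt_handler.update_environment(
        prisma_client=None,
        user_api_key_cache=DualCache(),
        litellm_jwtauth=LiteLLM_JWTAuth(end_user_id_jwt_field="sub"),
        leeway=0,  # new kwarg, forwarded to jwt.decode(); raise it in tests to accept expired tokens
    )

    # after `payload = await jwt_handler.auth_jwt(token=raw_jwt)`:
    #   end_user_id = jwt_handler.get_end_user_id(token=payload, default_value=None)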

* fix(presidio.py): support pii masking on sync logging callbacks

enables masking before logging to langfuse
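
A hedged sketch of the sync path, based on the updated presidio test below; it assumes Presidio analyzer/anonymizer services are reachable via PRESIDIO_ANALYZER_API_BASE / PRESIDIO_ANONYMIZER_API_BASE.

    import litellm
    from litellm.proxy.guardrails.guardrail_hooks.presidio import (
        _OPTIONAL_PresidioPIIMasking,
    )

    # logging_only=True leaves the actual LLM request untouched; only the copy handed
    # to logging callbacks (Langfuse, Datadog, ...) gets masked.
    pii_masking = _OPTIONAL_PresidioPIIMasking(
        logging_only=True,
        guardrail_name="presidio",
    )
    litellm.callbacks = [pii_masking]
    litellm.success_callback = ["langfuse"]  # the masked messages are what get logged here

    response = litellm.completion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "My name is Jane Doe, who are you?"}],
        mock_response="Hi there!",  # avoids a real LLM call
    )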

* feat(utils.py): support retry policy logic inside '.completion()'

Fixes https://github.com/BerriAI/litellm/issues/6623
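
Rough usage sketch, mirroring the new test added in this PR: a `retry_policy` can now be passed straight to `litellm.completion()` / `acompletion()`, and the retry count is picked per exception type.

    import litellm
    from litellm.types.router import RetryPolicy

    retry_policy = RetryPolicy(
        ContentPolicyViolationErrorRetries=2,  # retry content-filter errors twice
        AuthenticationErrorRetries=0,          # never retry auth errors
    )

    response = litellm.completion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "hello"}],
        retry_policy=retry_policy,
        mock_response="hi!",  # keeps the sketch offline
    )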

* fix(utils.py): support retries via retry policy on the async path as well

* fix(handle_jwt.py): set default leeway value

* test: fix test to handle jwt audience claim
Krish Dholakia 2024-12-06 22:44:18 -08:00 committed by GitHub
parent ba1e6fe7b7
commit e4493248ae
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
21 changed files with 498 additions and 498 deletions

View file

@@ -29,6 +29,7 @@ class CustomGuardrail(CustomLogger):
def should_run_guardrail(self, data, event_type: GuardrailEventHooks) -> bool:
metadata = data.get("metadata") or {}
requested_guardrails = metadata.get("guardrails") or []
verbose_logger.debug(
"inside should_run_guardrail for guardrail=%s event_type= %s guardrail_supported_event_hooks= %s requested_guardrails= %s",
self.guardrail_name,

View file

@@ -969,10 +969,10 @@ class Logging:
),
result=result,
)
## LOGGING HOOK ##
for callback in callbacks:
if isinstance(callback, CustomLogger):
self.model_call_details, result = callback.logging_hook(
kwargs=self.model_call_details,
result=result,

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View file

@@ -1,48 +1,30 @@
model_list:
# GPT-4 Turbo Models
- model_name: gpt-4
litellm_params:
model: gpt-4
rpm: 1
- model_name: gpt-4
- model_name: gpt-3.5-turbo
litellm_params:
model: azure/chatgpt-v-2
api_key: os.environ/AZURE_API_KEY
api_base: os.environ/AZURE_API_BASE
- model_name: rerank-model
temperature: 0.2
guardrails:
- guardrail_name: "presidio-log-guard"
litellm_params:
model: jina_ai/jina-reranker-v2-base-multilingual
- model_name: anthropic-vertex
litellm_params:
model: vertex_ai/claude-3-5-sonnet-v2
vertex_ai_project: "adroit-crow-413218"
vertex_ai_location: "us-east5"
- model_name: openai-gpt-4o-realtime-audio
litellm_params:
model: openai/gpt-4o-realtime-preview-2024-10-01
api_key: os.environ/OPENAI_API_KEY
- model_name: openai/*
litellm_params:
model: openai/*
api_key: os.environ/OPENAI_API_KEY
- model_name: openai/*
litellm_params:
model: openai/*
api_key: os.environ/OPENAI_API_KEY
model_info:
access_groups: ["public-openai-models"]
- model_name: openai/gpt-4o
litellm_params:
model: openai/gpt-4o
api_key: os.environ/OPENAI_API_KEY
model_info:
access_groups: ["private-openai-models"]
router_settings:
# routing_strategy: usage-based-routing-v2
#redis_url: "os.environ/REDIS_URL"
redis_host: "os.environ/REDIS_HOST"
redis_port: "os.environ/REDIS_PORT"
guardrail: presidio
mode: "logging_only"
mock_redacted_text:
text: "hello world, my name is <PERSON>. My number is: <PHONE_NUMBER>"
items:
- start: 48
end: 62
entity_type: PHONE_NUMBER
text: "<PHONE_NUMBER>"
operator: replace
- start: 24
end: 32
entity_type: PERSON
text: "<PERSON>"
operator: replace
litellm_settings:
callbacks: ["datadog"]
set_verbose: true
success_callback: ["langfuse"]

View file

@@ -36,16 +36,19 @@ class JWTHandler:
self,
) -> None:
self.http_handler = HTTPHandler()
self.leeway = 0
def update_environment(
self,
prisma_client: Optional[PrismaClient],
user_api_key_cache: DualCache,
litellm_jwtauth: LiteLLM_JWTAuth,
leeway: int = 0,
) -> None:
self.prisma_client = prisma_client
self.user_api_key_cache = user_api_key_cache
self.litellm_jwtauth = litellm_jwtauth
self.leeway = leeway
def is_jwt(self, token: str):
parts = token.split(".")
@@ -271,6 +274,7 @@ class JWTHandler:
algorithms=algorithms,
options=decode_options,
audience=audience,
leeway=self.leeway, # allow testing of expired tokens
)
return payload

View file

@@ -562,6 +562,7 @@ async def user_api_key_auth( # noqa: PLR0915
user_id=user_id,
org_id=org_id,
parent_otel_span=parent_otel_span,
end_user_id=end_user_id,
)
#### ELSE ####
## CHECK PASS-THROUGH ENDPOINTS ##

View file

@@ -48,7 +48,7 @@ def initialize_callbacks_on_proxy( # noqa: PLR0915
imported_list.append(open_telemetry_logger)
setattr(proxy_server, "open_telemetry_logger", open_telemetry_logger)
elif isinstance(callback, str) and callback == "presidio":
from litellm.proxy.hooks.presidio_pii_masking import (
from litellm.proxy.guardrails.guardrail_hooks.presidio import (
_OPTIONAL_PresidioPIIMasking,
)

View file

@@ -24,6 +24,7 @@ from litellm._logging import verbose_proxy_logger
from litellm.caching.caching import DualCache
from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.proxy._types import UserAPIKeyAuth
from litellm.types.guardrails import GuardrailEventHooks
from litellm.utils import (
EmbeddingResponse,
ImageResponse,
@@ -54,8 +55,13 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail):
presidio_anonymizer_api_base: Optional[str] = None,
output_parse_pii: Optional[bool] = False,
presidio_ad_hoc_recognizers: Optional[str] = None,
logging_only: Optional[bool] = None,
**kwargs,
):
if logging_only is True:
self.logging_only = True
kwargs["event_hook"] = GuardrailEventHooks.logging_only
super().__init__(**kwargs)
self.pii_tokens: dict = (
{}
) # mapping of PII token to original text - only used with Presidio `replace` operation
@@ -84,8 +90,6 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail):
presidio_anonymizer_api_base=presidio_anonymizer_api_base,
)
super().__init__(**kwargs)
def validate_environment(
self,
presidio_analyzer_api_base: Optional[str] = None,
@@ -245,10 +249,44 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail):
verbose_proxy_logger.info(
f"Presidio PII Masking: Redacted pii message: {data['messages']}"
)
data["messages"] = messages
return data
except Exception as e:
raise e
def logging_hook(
self, kwargs: dict, result: Any, call_type: str
) -> Tuple[dict, Any]:
import threading
from concurrent.futures import ThreadPoolExecutor
def run_in_new_loop():
"""Run the coroutine in a new event loop within this thread."""
new_loop = asyncio.new_event_loop()
try:
asyncio.set_event_loop(new_loop)
return new_loop.run_until_complete(
self.async_logging_hook(
kwargs=kwargs, result=result, call_type=call_type
)
)
finally:
new_loop.close()
asyncio.set_event_loop(None)
try:
# First, try to get the current event loop
_ = asyncio.get_running_loop()
# If we're already in an event loop, run in a separate thread
# to avoid nested event loop issues
with ThreadPoolExecutor(max_workers=1) as executor:
future = executor.submit(run_in_new_loop)
return future.result()
except RuntimeError:
# No running event loop, we can safely run in this thread
return run_in_new_loop()
async def async_logging_hook(
self, kwargs: dict, result: Any, call_type: str
) -> Tuple[dict, Any]:
@@ -304,7 +342,8 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail):
verbose_proxy_logger.debug(
f"PII Masking Args: self.output_parse_pii={self.output_parse_pii}; type of response={type(response)}"
)
if self.output_parse_pii is False:
if self.output_parse_pii is False and litellm.output_parse_pii is False:
return response
if isinstance(response, ModelResponse) and not isinstance(

View file

@@ -1,349 +0,0 @@
# +-----------------------------------------------+
# | |
# | PII Masking |
# | with Microsoft Presidio |
# | https://github.com/BerriAI/litellm/issues/ |
# +-----------------------------------------------+
#
# Tell us how we can improve! - Krrish & Ishaan
import asyncio
import json
import traceback
import uuid
from typing import Any, List, Optional, Tuple, Union
import aiohttp
from fastapi import HTTPException
import litellm # noqa: E401
from litellm._logging import verbose_proxy_logger
from litellm.caching.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import UserAPIKeyAuth
from litellm.utils import (
EmbeddingResponse,
ImageResponse,
ModelResponse,
StreamingChoices,
get_formatted_prompt,
)
class _OPTIONAL_PresidioPIIMasking(CustomLogger):
user_api_key_cache = None
ad_hoc_recognizers = None
# Class variables or attributes
def __init__(
self,
logging_only: Optional[bool] = None,
mock_testing: bool = False,
mock_redacted_text: Optional[dict] = None,
):
self.pii_tokens: dict = (
{}
) # mapping of PII token to original text - only used with Presidio `replace` operation
self.mock_redacted_text = mock_redacted_text
self.logging_only = logging_only
if mock_testing is True: # for testing purposes only
return
ad_hoc_recognizers = litellm.presidio_ad_hoc_recognizers
if ad_hoc_recognizers is not None:
try:
with open(ad_hoc_recognizers, "r") as file:
self.ad_hoc_recognizers = json.load(file)
except FileNotFoundError:
raise Exception(f"File not found. file_path={ad_hoc_recognizers}")
except json.JSONDecodeError as e:
raise Exception(
f"Error decoding JSON file: {str(e)}, file_path={ad_hoc_recognizers}"
)
except Exception as e:
raise Exception(
f"An error occurred: {str(e)}, file_path={ad_hoc_recognizers}"
)
self.validate_environment()
def validate_environment(self):
self.presidio_analyzer_api_base: Optional[str] = litellm.get_secret(
"PRESIDIO_ANALYZER_API_BASE", None
) # type: ignore
self.presidio_anonymizer_api_base: Optional[str] = litellm.get_secret(
"PRESIDIO_ANONYMIZER_API_BASE", None
) # type: ignore
if self.presidio_analyzer_api_base is None:
raise Exception("Missing `PRESIDIO_ANALYZER_API_BASE` from environment")
if not self.presidio_analyzer_api_base.endswith("/"):
self.presidio_analyzer_api_base += "/"
if not (
self.presidio_analyzer_api_base.startswith("http://")
or self.presidio_analyzer_api_base.startswith("https://")
):
# add http:// if unset, assume communicating over private network - e.g. render
self.presidio_analyzer_api_base = (
"http://" + self.presidio_analyzer_api_base
)
if self.presidio_anonymizer_api_base is None:
raise Exception("Missing `PRESIDIO_ANONYMIZER_API_BASE` from environment")
if not self.presidio_anonymizer_api_base.endswith("/"):
self.presidio_anonymizer_api_base += "/"
if not (
self.presidio_anonymizer_api_base.startswith("http://")
or self.presidio_anonymizer_api_base.startswith("https://")
):
# add http:// if unset, assume communicating over private network - e.g. render
self.presidio_anonymizer_api_base = (
"http://" + self.presidio_anonymizer_api_base
)
def print_verbose(self, print_statement):
try:
verbose_proxy_logger.debug(print_statement)
if litellm.set_verbose:
print(print_statement) # noqa
except Exception:
pass
async def check_pii(self, text: str, output_parse_pii: bool) -> str:
"""
[TODO] make this more performant for high-throughput scenario
"""
try:
async with aiohttp.ClientSession() as session:
if self.mock_redacted_text is not None:
redacted_text = self.mock_redacted_text
else:
# Make the first request to /analyze
analyze_url = f"{self.presidio_analyzer_api_base}analyze"
verbose_proxy_logger.debug("Making request to: %s", analyze_url)
analyze_payload = {"text": text, "language": "en"}
if self.ad_hoc_recognizers is not None:
analyze_payload["ad_hoc_recognizers"] = self.ad_hoc_recognizers
redacted_text = None
async with session.post(
analyze_url, json=analyze_payload
) as response:
analyze_results = await response.json()
# Make the second request to /anonymize
anonymize_url = f"{self.presidio_anonymizer_api_base}anonymize"
verbose_proxy_logger.debug("Making request to: %s", anonymize_url)
anonymize_payload = {
"text": text,
"analyzer_results": analyze_results,
}
async with session.post(
anonymize_url, json=anonymize_payload
) as response:
redacted_text = await response.json()
new_text = text
if redacted_text is not None:
verbose_proxy_logger.debug("redacted_text: %s", redacted_text)
for item in redacted_text["items"]:
start = item["start"]
end = item["end"]
replacement = item["text"] # replacement token
if item["operator"] == "replace" and output_parse_pii is True:
# check if token in dict
# if exists, add a uuid to the replacement token for swapping back to the original text in llm response output parsing
if replacement in self.pii_tokens:
replacement = replacement + str(uuid.uuid4())
self.pii_tokens[replacement] = new_text[
start:end
] # get text it'll replace
new_text = new_text[:start] + replacement + new_text[end:]
return redacted_text["text"]
else:
raise Exception(f"Invalid anonymizer response: {redacted_text}")
except Exception as e:
verbose_proxy_logger.error(
"litellm.proxy.hooks.presidio_pii_masking.py::async_pre_call_hook(): Exception occured - {}".format(
str(e)
)
)
verbose_proxy_logger.debug(traceback.format_exc())
raise e
async def async_pre_call_hook(
self,
user_api_key_dict: UserAPIKeyAuth,
cache: DualCache,
data: dict,
call_type: str,
):
"""
- Check if request turned off pii
- Check if user allowed to turn off pii (key permissions -> 'allow_pii_controls')
- Take the request data
- Call /analyze -> get the results
- Call /anonymize w/ the analyze results -> get the redacted text
For multiple messages in /chat/completions, we'll need to call them in parallel.
"""
try:
if (
self.logging_only is True
): # only modify the logging obj data (done by async_logging_hook)
return data
permissions = user_api_key_dict.permissions
output_parse_pii = permissions.get(
"output_parse_pii", litellm.output_parse_pii
) # allow key to turn on/off output parsing for pii
no_pii = permissions.get(
"no-pii", None
) # allow key to turn on/off pii masking (if user is allowed to set pii controls, then they can override the key defaults)
if no_pii is None:
# check older way of turning on/off pii
no_pii = not permissions.get("pii", True)
content_safety = data.get("content_safety", None)
verbose_proxy_logger.debug("content_safety: %s", content_safety)
## Request-level turn on/off PII controls ##
if content_safety is not None and isinstance(content_safety, dict):
# pii masking ##
if (
content_safety.get("no-pii", None) is not None
and content_safety.get("no-pii") is True
):
# check if user allowed to turn this off
if permissions.get("allow_pii_controls", False) is False:
raise HTTPException(
status_code=400,
detail={
"error": "Not allowed to set PII controls per request"
},
)
else: # user allowed to turn off pii masking
no_pii = content_safety.get("no-pii")
if not isinstance(no_pii, bool):
raise HTTPException(
status_code=400,
detail={"error": "no_pii needs to be a boolean value"},
)
## pii output parsing ##
if content_safety.get("output_parse_pii", None) is not None:
# check if user allowed to turn this off
if permissions.get("allow_pii_controls", False) is False:
raise HTTPException(
status_code=400,
detail={
"error": "Not allowed to set PII controls per request"
},
)
else: # user allowed to turn on/off pii output parsing
output_parse_pii = content_safety.get("output_parse_pii")
if not isinstance(output_parse_pii, bool):
raise HTTPException(
status_code=400,
detail={
"error": "output_parse_pii needs to be a boolean value"
},
)
if no_pii is True: # turn off pii masking
return data
if call_type == "completion": # /chat/completions requests
messages = data["messages"]
tasks = []
for m in messages:
if isinstance(m["content"], str):
tasks.append(
self.check_pii(
text=m["content"], output_parse_pii=output_parse_pii
)
)
responses = await asyncio.gather(*tasks)
for index, r in enumerate(responses):
if isinstance(messages[index]["content"], str):
messages[index][
"content"
] = r # replace content with redacted string
verbose_proxy_logger.info(
f"Presidio PII Masking: Redacted pii message: {data['messages']}"
)
return data
except Exception as e:
verbose_proxy_logger.info(
"An error occurred -",
)
raise e
async def async_logging_hook(
self, kwargs: dict, result: Any, call_type: str
) -> Tuple[dict, Any]:
"""
Masks the input before logging to langfuse, datadog, etc.
"""
if (
call_type == "completion" or call_type == "acompletion"
): # /chat/completions requests
messages: Optional[List] = kwargs.get("messages", None)
tasks = []
if messages is None:
return kwargs, result
for m in messages:
text_str = ""
if m["content"] is None:
continue
if isinstance(m["content"], str):
text_str = m["content"]
tasks.append(
self.check_pii(text=text_str, output_parse_pii=False)
) # need to pass separately b/c presidio has context window limits
responses = await asyncio.gather(*tasks)
for index, r in enumerate(responses):
if isinstance(messages[index]["content"], str):
messages[index][
"content"
] = r # replace content with redacted string
verbose_proxy_logger.info(
f"Presidio PII Masking: Redacted pii message: {messages}"
)
kwargs["messages"] = messages
return kwargs, result
async def async_post_call_success_hook(
self,
data: dict,
user_api_key_dict: UserAPIKeyAuth,
response: Union[ModelResponse, EmbeddingResponse, ImageResponse],
):
"""
Output parse the response object to replace the masked tokens with user sent values
"""
verbose_proxy_logger.debug(
f"PII Masking Args: litellm.output_parse_pii={litellm.output_parse_pii}; type of response={type(response)}"
)
if litellm.output_parse_pii is False:
return response
if isinstance(response, ModelResponse) and not isinstance(
response.choices[0], StreamingChoices
): # /chat/completions requests
if isinstance(response.choices[0].message.content, str):
verbose_proxy_logger.debug(
f"self.pii_tokens: {self.pii_tokens}; initial response: {response.choices[0].message.content}"
)
for key, value in self.pii_tokens.items():
response.choices[0].message.content = response.choices[
0
].message.content.replace(key, value)
return response

View file

@@ -5987,6 +5987,39 @@ async def new_budget(
return response
@router.post(
"/budget/update",
tags=["budget management"],
dependencies=[Depends(user_api_key_auth)],
)
async def update_budget(
budget_obj: BudgetNew,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
Update an existing budget object. Can apply this to teams, orgs, end-users, keys.
"""
global prisma_client
if prisma_client is None:
raise HTTPException(
status_code=500,
detail={"error": CommonProxyErrors.db_not_connected_error.value},
)
if budget_obj.budget_id is None:
raise HTTPException(status_code=400, detail={"error": "budget_id is required"})
response = await prisma_client.db.litellm_budgettable.update(
where={"budget_id": budget_obj.budget_id},
data={
**budget_obj.model_dump(exclude_none=True), # type: ignore
"updated_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
}, # type: ignore
)
return response
@router.post(
"/budget/info",
tags=["budget management"],

View file

@@ -83,6 +83,9 @@ from litellm.router_utils.fallback_event_handlers import (
run_async_fallback,
run_sync_fallback,
)
from litellm.router_utils.get_retry_from_policy import (
get_num_retries_from_retry_policy as _get_num_retries_from_retry_policy,
)
from litellm.router_utils.handle_error import (
async_raise_no_deployment_exception,
send_llm_exception_alert,
@@ -5609,53 +5612,12 @@ class Router:
def get_num_retries_from_retry_policy(
self, exception: Exception, model_group: Optional[str] = None
):
"""
BadRequestErrorRetries: Optional[int] = None
AuthenticationErrorRetries: Optional[int] = None
TimeoutErrorRetries: Optional[int] = None
RateLimitErrorRetries: Optional[int] = None
ContentPolicyViolationErrorRetries: Optional[int] = None
"""
# if we can find the exception then in the retry policy -> return the number of retries
retry_policy: Optional[RetryPolicy] = self.retry_policy
if (
self.model_group_retry_policy is not None
and model_group is not None
and model_group in self.model_group_retry_policy
):
retry_policy = self.model_group_retry_policy.get(model_group, None) # type: ignore
if retry_policy is None:
return None
if isinstance(retry_policy, dict):
retry_policy = RetryPolicy(**retry_policy)
if (
isinstance(exception, litellm.BadRequestError)
and retry_policy.BadRequestErrorRetries is not None
):
return retry_policy.BadRequestErrorRetries
if (
isinstance(exception, litellm.AuthenticationError)
and retry_policy.AuthenticationErrorRetries is not None
):
return retry_policy.AuthenticationErrorRetries
if (
isinstance(exception, litellm.Timeout)
and retry_policy.TimeoutErrorRetries is not None
):
return retry_policy.TimeoutErrorRetries
if (
isinstance(exception, litellm.RateLimitError)
and retry_policy.RateLimitErrorRetries is not None
):
return retry_policy.RateLimitErrorRetries
if (
isinstance(exception, litellm.ContentPolicyViolationError)
and retry_policy.ContentPolicyViolationErrorRetries is not None
):
return retry_policy.ContentPolicyViolationErrorRetries
return _get_num_retries_from_retry_policy(
exception=exception,
model_group=model_group,
model_group_retry_policy=self.model_group_retry_policy,
retry_policy=self.retry_policy,
)
def get_allowed_fails_from_policy(self, exception: Exception):
"""

View file

@@ -0,0 +1,71 @@
"""
Get num retries for an exception.
- Account for retry policy by exception type.
"""
from typing import Dict, Optional, Union
from litellm.exceptions import (
AuthenticationError,
BadRequestError,
ContentPolicyViolationError,
RateLimitError,
Timeout,
)
from litellm.types.router import RetryPolicy
def get_num_retries_from_retry_policy(
exception: Exception,
retry_policy: Optional[Union[RetryPolicy, dict]] = None,
model_group: Optional[str] = None,
model_group_retry_policy: Optional[Dict[str, RetryPolicy]] = None,
):
"""
BadRequestErrorRetries: Optional[int] = None
AuthenticationErrorRetries: Optional[int] = None
TimeoutErrorRetries: Optional[int] = None
RateLimitErrorRetries: Optional[int] = None
ContentPolicyViolationErrorRetries: Optional[int] = None
"""
# if we can find the exception then in the retry policy -> return the number of retries
if (
model_group_retry_policy is not None
and model_group is not None
and model_group in model_group_retry_policy
):
retry_policy = model_group_retry_policy.get(model_group, None) # type: ignore
if retry_policy is None:
return None
if isinstance(retry_policy, dict):
retry_policy = RetryPolicy(**retry_policy)
if (
isinstance(exception, BadRequestError)
and retry_policy.BadRequestErrorRetries is not None
):
return retry_policy.BadRequestErrorRetries
if (
isinstance(exception, AuthenticationError)
and retry_policy.AuthenticationErrorRetries is not None
):
return retry_policy.AuthenticationErrorRetries
if isinstance(exception, Timeout) and retry_policy.TimeoutErrorRetries is not None:
return retry_policy.TimeoutErrorRetries
if (
isinstance(exception, RateLimitError)
and retry_policy.RateLimitErrorRetries is not None
):
return retry_policy.RateLimitErrorRetries
if (
isinstance(exception, ContentPolicyViolationError)
and retry_policy.ContentPolicyViolationErrorRetries is not None
):
return retry_policy.ContentPolicyViolationErrorRetries
def reset_retry_policy() -> RetryPolicy:
return RetryPolicy()

View file

@@ -97,6 +97,10 @@ from litellm.litellm_core_utils.rules import Rules
from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper
from litellm.litellm_core_utils.token_counter import get_modified_max_tokens
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.router_utils.get_retry_from_policy import (
get_num_retries_from_retry_policy,
reset_retry_policy,
)
from litellm.secret_managers.main import get_secret
from litellm.types.llms.openai import (
AllMessageValues,
@@ -918,6 +922,14 @@ def client(original_function): # noqa: PLR0915
num_retries = (
kwargs.get("num_retries", None) or litellm.num_retries or None
)
if kwargs.get("retry_policy", None):
num_retries = get_num_retries_from_retry_policy(
exception=e,
retry_policy=kwargs.get("retry_policy"),
)
kwargs["retry_policy"] = (
reset_retry_policy()
) # prevent infinite loops
litellm.num_retries = (
None # set retries to None to prevent infinite loops
)
@@ -1137,6 +1149,13 @@ def client(original_function): # noqa: PLR0915
num_retries = (
kwargs.get("num_retries", None) or litellm.num_retries or None
)
if kwargs.get("retry_policy", None):
num_retries = get_num_retries_from_retry_policy(
exception=e,
retry_policy=kwargs.get("retry_policy"),
)
kwargs["retry_policy"] = reset_retry_policy()
litellm.num_retries = (
None # set retries to None to prevent infinite loops
)

View file

@@ -61,3 +61,68 @@ def test_completion_with_0_num_retries():
except Exception as e:
print("exception", e)
pass
@pytest.mark.asyncio
@pytest.mark.parametrize("sync_mode", [True, False])
async def test_completion_with_retry_policy(sync_mode):
from unittest.mock import patch, MagicMock, AsyncMock
from litellm.types.router import RetryPolicy
retry_number = 1
retry_policy = RetryPolicy(
ContentPolicyViolationErrorRetries=retry_number,  # run `retry_number` retries for ContentPolicyViolationErrors
AuthenticationErrorRetries=0, # run 0 retries for AuthenticationErrorRetries
)
target_function = "completion_with_retries"
with patch.object(litellm, target_function) as mock_completion_with_retries:
data = {
"model": "azure/gpt-3.5-turbo",
"messages": [{"gm": "vibe", "role": "user"}],
"retry_policy": retry_policy,
"mock_response": "Exception: content_filter_policy",
}
try:
if sync_mode:
completion(**data)
else:
await completion(**data)
except Exception as e:
print(e)
mock_completion_with_retries.assert_called_once()
assert (
mock_completion_with_retries.call_args.kwargs["num_retries"] == retry_number
)
assert retry_policy.ContentPolicyViolationErrorRetries == retry_number
@pytest.mark.asyncio
@pytest.mark.parametrize("sync_mode", [True, False])
async def test_completion_with_retry_policy_no_error(sync_mode):
"""
Test that the completion function does not throw an error when the retry policy is set
"""
from unittest.mock import patch, MagicMock, AsyncMock
from litellm.types.router import RetryPolicy
retry_number = 1
retry_policy = RetryPolicy(
ContentPolicyViolationErrorRetries=retry_number,  # run `retry_number` retries for ContentPolicyViolationErrors
AuthenticationErrorRetries=0, # run 0 retries for AuthenticationErrorRetries
)
data = {
"model": "gpt-3.5-turbo",
"messages": [{"gm": "vibe", "role": "user"}],
"retry_policy": retry_policy,
}
try:
if sync_mode:
completion(**data)
else:
await completion(**data)
except Exception as e:
print(e)

View file

@@ -19,12 +19,13 @@ sys.path.insert(
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import litellm
from litellm import Router, mock_completion
from litellm.caching.caching import DualCache
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.hooks.presidio_pii_masking import _OPTIONAL_PresidioPIIMasking
from litellm.proxy.guardrails.guardrail_hooks.presidio import (
_OPTIONAL_PresidioPIIMasking,
)
from litellm.proxy.utils import ProxyLogging
@@ -63,6 +64,7 @@ async def test_output_parsing():
- have presidio pii masking - output parse message
- assert that no masked tokens are in the input message
"""
litellm.set_verbose = True
litellm.output_parse_pii = True
pii_masking = _OPTIONAL_PresidioPIIMasking(mock_testing=True)
@@ -206,6 +208,8 @@ async def test_presidio_pii_masking_input_b():
@pytest.mark.asyncio
async def test_presidio_pii_masking_logging_output_only_no_pre_api_hook():
from litellm.types.guardrails import GuardrailEventHooks
pii_masking = _OPTIONAL_PresidioPIIMasking(
logging_only=True,
mock_testing=True,
@@ -223,22 +227,29 @@ async def test_presidio_pii_masking_logging_output_only_no_pre_api_hook():
}
]
new_data = await pii_masking.async_pre_call_hook(
user_api_key_dict=user_api_key_dict,
cache=local_cache,
data={"messages": test_messages},
call_type="completion",
assert (
pii_masking.should_run_guardrail(
data={"messages": test_messages},
event_type=GuardrailEventHooks.pre_call,
)
is False
)
assert "Jane Doe" in new_data["messages"][0]["content"]
@pytest.mark.parametrize(
"sync_mode",
[True, False],
)
@pytest.mark.asyncio
async def test_presidio_pii_masking_logging_output_only_logged_response():
async def test_presidio_pii_masking_logging_output_only_logged_response(sync_mode):
import litellm
litellm.set_verbose = True
pii_masking = _OPTIONAL_PresidioPIIMasking(
logging_only=True,
mock_testing=True,
mock_redacted_text=input_b_anonymizer_results,
guardrail_name="presidio",
)
test_messages = [
@@ -247,15 +258,26 @@ async def test_presidio_pii_masking_logging_output_only_logged_response():
"content": "My name is Jane Doe, who are you? Say my name in your response",
}
]
with patch.object(
pii_masking, "async_log_success_event", new=AsyncMock()
) as mock_call:
litellm.callbacks = [pii_masking]
response = await litellm.acompletion(
model="gpt-3.5-turbo", messages=test_messages, mock_response="Hi Peter!"
)
await asyncio.sleep(3)
if sync_mode:
target_function = "log_success_event"
mock_call = MagicMock()
else:
target_function = "async_log_success_event"
mock_call = AsyncMock()
with patch.object(pii_masking, target_function, new=mock_call) as mock_call:
litellm.callbacks = [pii_masking]
if sync_mode:
response = litellm.completion(
model="gpt-3.5-turbo", messages=test_messages, mock_response="Hi Peter!"
)
time.sleep(3)
else:
response = await litellm.acompletion(
model="gpt-3.5-turbo", messages=test_messages, mock_response="Hi Peter!"
)
await asyncio.sleep(3)
assert response.choices[0].message.content == "Hi Peter!" # type: ignore
@@ -275,8 +297,13 @@ async def test_presidio_pii_masking_logging_output_only_logged_response_guardrai
import litellm
from litellm.proxy.guardrails.init_guardrails import initialize_guardrails
from litellm.types.guardrails import GuardrailItem, GuardrailItemSpec
from litellm.types.guardrails import (
GuardrailItem,
GuardrailItemSpec,
GuardrailEventHooks,
)
litellm.set_verbose = True
os.environ["PRESIDIO_ANALYZER_API_BASE"] = "http://localhost:5002"
os.environ["PRESIDIO_ANONYMIZER_API_BASE"] = "http://localhost:5001"
@@ -303,10 +330,15 @@ async def test_presidio_pii_masking_logging_output_only_logged_response_guardrai
pii_masking_obj: Optional[_OPTIONAL_PresidioPIIMasking] = None
for callback in litellm.callbacks:
print(f"CALLBACK: {callback}")
if isinstance(callback, _OPTIONAL_PresidioPIIMasking):
pii_masking_obj = callback
assert pii_masking_obj is not None
assert hasattr(pii_masking_obj, "logging_only")
assert pii_masking_obj.logging_only is True
assert pii_masking_obj.event_hook == GuardrailEventHooks.logging_only
assert pii_masking_obj.should_run_guardrail(
data={}, event_type=GuardrailEventHooks.logging_only
)

View file

@@ -1026,3 +1026,89 @@ def test_get_public_key_from_jwk_url():
assert public_key is not None
assert public_key == jwk_response[0]
@pytest.mark.asyncio
async def test_end_user_jwt_auth(monkeypatch):
import litellm
from litellm.proxy.auth.handle_jwt import JWTHandler
from litellm.caching import DualCache
from litellm.proxy._types import LiteLLM_JWTAuth
from litellm.proxy.proxy_server import user_api_key_auth
monkeypatch.delenv("JWT_AUDIENCE", None)
jwt_handler = JWTHandler()
litellm_jwtauth = LiteLLM_JWTAuth(
end_user_id_jwt_field="sub",
)
cache = DualCache()
keys = [
{
"kid": "d-1733370597545",
"alg": "RS256",
"kty": "RSA",
"use": "sig",
"n": "j5Ik60FJSUIPMVdMSU8vhYvyPPH7rUsUllNI0BfBlIkgEYFk2mg4KX1XDQc6mcKFjbq9k_7TSkHWKnfPhNkkb0MdmZLKbwTrmte2k8xWDxp-rSmZpIJwC1zuPDx5joLHBgIb09-K2cPL2rvwzP75WtOr_QLXBerHAbXx8cOdI7mrSRWJ9iXbKv_pLDnZHnGNld75tztb8nCtgrywkF010jGi1xxaT8UKsTvK-QkIBkYI6m6WR9LMnG2OZm-ExuvNPUenfYUsqnynPF4SMNZwyQqJfavSLKI8uMzB2s9pcbx5HfQwIOYhMlgBHjhdDn2IUSnXSJqSsN6RQO18M2rgPQ",
"e": "AQAB",
},
{
"kid": "s-f836dd32-ef71-426a-8804-946a7f230bc9",
"alg": "RS256",
"kty": "RSA",
"use": "sig",
"n": "2A5-ZA18YKn7M4OtxsfXBc3Z7n2WyHTxbK4GEBlmD9T9TDr4sbJaI4oHfTvzsAC3H2r2YkASzrCISXMXQJjLHoeLgDVcKs8qTdLj7K5FNT9fA0kU9ayUjSGrqkz57SG7oNf9Wp__Qa-H-bs6Z8_CEfBy0JA9QSHUfrdOXp4vCB_qLn6DE0DJH9ELAq_0nktVQk_oxlvXlGtVZSZe31mNNgiD__RJMogf-SIFcYOkMLVGTTEBYiCk1mHxXS6oJZaVSWiBgHzu5wkra5AfQLUVelQaupT5H81hFPmiceEApf_2DacnqqRV4-Nl8sjhJtuTXiprVS2Z5r2pOMz_kVGNgw",
"e": "AQAB",
},
]
cache.set_cache(
key="litellm_jwt_auth_keys",
value=keys,
)
jwt_handler.update_environment(
prisma_client=None,
user_api_key_cache=cache,
litellm_jwtauth=litellm_jwtauth,
leeway=100000000000000,
)
token = "eyJraWQiOiJkLTE3MzMzNzA1OTc1NDUiLCJ0eXAiOiJKV1QiLCJ2ZXJzaW9uIjoiNCIsImFsZyI6IlJTMjU2In0.eyJpYXQiOjE3MzM1MDcyNzcsImV4cCI6MTczMzUwNzg3Nywic3ViIjoiODFiM2U1MmEtNjdhNi00ZWZiLTk2NDUtNzA1MjdlMTAxNDc5IiwidElkIjoicHVibGljIiwic2Vzc2lvbkhhbmRsZSI6Ijg4M2Y4YWFmLWUwOTEtNGE1Ny04YTJhLTRiMjcwMmZhZjMzYyIsInJlZnJlc2hUb2tlbkhhc2gxIjoiNDVhNDRhYjlmOTMwMGQyOTY4ZjkxODZkYWQ0YTcwY2QwNjk2YzBiNTBmZmUxZmQ4ZTM2YzU1NGU0MWE4ODU0YiIsInBhcmVudFJlZnJlc2hUb2tlbkhhc2gxIjpudWxsLCJhbnRpQ3NyZlRva2VuIjpudWxsLCJpc3MiOiJodHRwOi8vbG9jYWxob3N0OjMwMDEvYXV0aC9zdCIsImxlZ2FjeV9jb21wYW55X2lkIjoxNTI0OTM5LCJsZWdhY3lfaWQiOjM5NzAyNzk1LCJzY29wZSI6WyJza2lsbF91cCJdLCJzdC1ldiI6eyJ0IjoxNzMzNTA3Mjc4NzAwLCJ2IjpmYWxzZX19.XlYrT6dRIjaZKkJtdr7C_UuxajFRbNpA9BnIsny3rxiPVyS8rhIBwxW12tZwgttRywmXrXK-msowFhWU4XdL5Qfe4lwZb2HTbDeGiQPvQTlOjWWYMhgCoKdPtjCQsAcW45rg7aQ0p42JFQPoAQa8AnGfxXpgx2vSR7njiZ3ZZyHerDdKQHyIGSFVOxoK0TgR-hxBVY__Wjg8UTKgKSz9KU_uwnPgpe2DeYmP-LTK2oeoygsVRmbldY_GrrcRe3nqYcUfFkxSs0FSsoSv35jIxiptXfCjhEB1Y5eaJhHEjlYlP2rw98JysYxjO2rZbAdUpL3itPeo3T2uh1NZr_lArw"
response = await jwt_handler.auth_jwt(token=token)
assert response is not None
end_user_id = jwt_handler.get_end_user_id(
token=response,
default_value=None,
)
assert end_user_id is not None
## CHECK USER API KEY AUTH ##
from starlette.datastructures import URL
bearer_token = "Bearer " + token
request = Request(scope={"type": "http"})
request._url = URL(url="/chat/completions")
## 1. INITIAL TEAM CALL - should fail
# use generated key to auth in
setattr(
litellm.proxy.proxy_server,
"general_settings",
{
"enable_jwt_auth": True,
},
)
setattr(litellm.proxy.proxy_server, "prisma_client", {})
setattr(litellm.proxy.proxy_server, "jwt_handler", jwt_handler)
result = await user_api_key_auth(request=request, api_key=bearer_token)
assert (
result.end_user_id == "81b3e52a-67a6-4efb-9645-70527e101479"
) # jwt token decoded sub value

View file

@@ -49,6 +49,7 @@ export interface budgetItem {
max_budget: string | null;
rpm_limit: number | null;
tpm_limit: number | null;
updated_at: string;
}
const BudgetPanel: React.FC<BudgetSettingsPageProps> = ({ accessToken }) => {
@@ -65,13 +66,17 @@ const BudgetPanel: React.FC<BudgetSettingsPageProps> = ({ accessToken }) => {
});
}, [accessToken]);
const handleEditCall = async (budget_id: string, index: number) => {
console.log("budget_id", budget_id)
if (accessToken == null) {
return;
}
setSelectedBudget(budgetList[index])
setIsEditModalVisible(true)
// Find the budget first
const budget = budgetList.find(budget => budget.budget_id === budget_id) || null;
// Update state and show modal after state is updated
setSelectedBudget(budget);
setIsEditModalVisible(true);
};
const handleDeleteCall = async (budget_id: string, index: number) => {
@@ -90,6 +95,15 @@ const BudgetPanel: React.FC<BudgetSettingsPageProps> = ({ accessToken }) => {
message.success("Budget Deleted.");
};
const handleUpdateCall = async () => {
if (accessToken == null) {
return;
}
getBudgetList(accessToken).then((data) => {
setBudgetList(data);
});
}
return (
<div className="w-full mx-auto flex-auto overflow-y-auto m-8 p-2">
<Button
@@ -113,6 +127,7 @@ const BudgetPanel: React.FC<BudgetSettingsPageProps> = ({ accessToken }) => {
setIsModalVisible={setIsEditModalVisible}
setBudgetList={setBudgetList}
existingBudget={selectedBudget}
handleUpdateCall={handleUpdateCall}
/>
}
<Card>
@@ -128,30 +143,27 @@ const BudgetPanel: React.FC<BudgetSettingsPageProps> = ({ accessToken }) => {
</TableHead>
<TableBody>
{budgetList.map((value: budgetItem, index: number) => (
<TableRow key={index}>
<TableCell>{value.budget_id}</TableCell>
<TableCell>
{value.max_budget ? value.max_budget : "n/a"}
</TableCell>
<TableCell>
{value.tpm_limit ? value.tpm_limit : "n/a"}
</TableCell>
<TableCell>
{value.rpm_limit ? value.rpm_limit : "n/a"}
</TableCell>
<Icon
icon={PencilAltIcon}
size="sm"
onClick={() => handleEditCall(value.budget_id, index)}
/>
<Icon
icon={TrashIcon}
size="sm"
onClick={() => handleDeleteCall(value.budget_id, index)}
/>
</TableRow>
))}
{budgetList
.slice() // Creates a shallow copy to avoid mutating the original array
.sort((a, b) => new Date(b.updated_at).getTime() - new Date(a.updated_at).getTime()) // Sort by updated_at in descending order
.map((value: budgetItem, index: number) => (
<TableRow key={index}>
<TableCell>{value.budget_id}</TableCell>
<TableCell>{value.max_budget ? value.max_budget : "n/a"}</TableCell>
<TableCell>{value.tpm_limit ? value.tpm_limit : "n/a"}</TableCell>
<TableCell>{value.rpm_limit ? value.rpm_limit : "n/a"}</TableCell>
<Icon
icon={PencilAltIcon}
size="sm"
onClick={() => handleEditCall(value.budget_id, index)}
/>
<Icon
icon={TrashIcon}
size="sm"
onClick={() => handleDeleteCall(value.budget_id, index)}
/>
</TableRow>
))}
</TableBody>
</Table>
</Card>

View file

@@ -1,4 +1,4 @@
import React from "react";
import React, { useEffect } from "react";
import {
Button,
TextInput,
@@ -17,7 +17,7 @@ import {
Select,
message,
} from "antd";
import { budgetCreateCall } from "../networking";
import { budgetUpdateCall } from "../networking";
import { budgetItem } from "./budget_panel";
interface BudgetModalProps {
@@ -26,15 +26,23 @@ interface BudgetModalProps {
setIsModalVisible: React.Dispatch<React.SetStateAction<boolean>>;
setBudgetList: React.Dispatch<React.SetStateAction<any[]>>;
existingBudget: budgetItem
handleUpdateCall: () => void
}
const EditBudgetModal: React.FC<BudgetModalProps> = ({
isModalVisible,
accessToken,
setIsModalVisible,
setBudgetList,
existingBudget
existingBudget,
handleUpdateCall
}) => {
console.log("existingBudget", existingBudget)
const [form] = Form.useForm();
useEffect(() => {
form.setFieldsValue(existingBudget);
}, [existingBudget, form]);
const handleOk = () => {
setIsModalVisible(false);
form.resetFields();
@@ -51,14 +59,14 @@ const EditBudgetModal: React.FC<BudgetModalProps> = ({
}
try {
message.info("Making API Call");
// setIsModalVisible(true);
const response = await budgetCreateCall(accessToken, formValues);
console.log("key create Response:", response);
setIsModalVisible(true);
const response = await budgetUpdateCall(accessToken, formValues);
setBudgetList((prevData) =>
prevData ? [...prevData, response] : [response]
); // Check if prevData is null
message.success("API Key Created");
message.success("Budget Updated");
form.resetFields();
handleUpdateCall();
} catch (error) {
console.error("Error creating the key:", error);
message.error(`Error creating the key: ${error}`, 20);

View file

@@ -255,6 +255,43 @@ export const budgetCreateCall = async (
}
};
export const budgetUpdateCall = async (
accessToken: string,
formValues: Record<string, any> // Assuming formValues is an object
) => {
try {
console.log("Form Values in budgetUpdateCall:", formValues); // Log the form values before making the API call
console.log("Form Values after check:", formValues);
const url = proxyBaseUrl ? `${proxyBaseUrl}/budget/update` : `/budget/update`;
const response = await fetch(url, {
method: "POST",
headers: {
[globalLitellmHeaderName]: `Bearer ${accessToken}`,
"Content-Type": "application/json",
},
body: JSON.stringify({
...formValues, // Include formValues in the request body
}),
});
if (!response.ok) {
const errorData = await response.text();
handleError(errorData);
console.error("Error response from the server:", errorData);
throw new Error("Network response was not ok");
}
const data = await response.json();
console.log("API Response:", data);
return data;
// Handle success - you might want to update some state or UI based on the updated budget
} catch (error) {
console.error("Failed to create key:", error);
throw error;
}
};
export const invitationCreateCall = async (
accessToken: string,
userID: string // Assuming formValues is an object