Mirror of https://github.com/BerriAI/litellm.git
Synced 2025-04-25 02:34:29 +00:00
Litellm dev 12 06 2024 (#7067)
* fix(edit_budget_modal.tsx): call the `/budget/update` endpoint instead of `/budget/new`, so an existing budget can be updated from the UI
* fix(user_api_key_auth.py): support cost tracking for end users via a JWT field
* fix(presidio.py): support PII masking on sync logging callbacks, enabling masking before logging to Langfuse
* feat(utils.py): support retry policy logic inside `.completion()`. Fixes https://github.com/BerriAI/litellm/issues/6623
* fix(utils.py): support retry-by-retry-policy on the async path as well
* fix(handle_jwt.py): set a default leeway value
* test: fix test to handle the JWT audience claim
This commit is contained in:
parent: ba1e6fe7b7
commit: e4493248ae

21 changed files with 498 additions and 498 deletions
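The headline change is retry-policy support directly inside `.completion()`. A minimal sketch of how the new behavior can be exercised (model name and policy values are illustrative; the `RetryPolicy` fields are the ones named in the diff below):

import litellm
from litellm.types.router import RetryPolicy

retry_policy = RetryPolicy(
    RateLimitErrorRetries=3,       # retry rate-limit errors up to 3 times
    AuthenticationErrorRetries=0,  # fail fast on bad credentials
)

# On a matching exception, the client wrapper resolves num_retries from the
# policy and re-dispatches via litellm.completion_with_retries (see the
# utils.py hunks below).
response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hello"}],
    retry_policy=retry_policy,
)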
@@ -29,6 +29,7 @@ class CustomGuardrail(CustomLogger):
    def should_run_guardrail(self, data, event_type: GuardrailEventHooks) -> bool:
        metadata = data.get("metadata") or {}
        requested_guardrails = metadata.get("guardrails") or []

        verbose_logger.debug(
            "inside should_run_guardrail for guardrail=%s event_type= %s guardrail_supported_event_hooks= %s requested_guardrails= %s",
            self.guardrail_name,
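For context, `should_run_guardrail` is what each guardrail hook calls before doing any work. A sketch of a `CustomGuardrail` subclass wired through it; the class and event-hook names come from the hunk above, while the request shape and full gating semantics are assumed from the debug log:

from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.types.guardrails import GuardrailEventHooks

class MyGuardrail(CustomGuardrail):
    async def async_pre_call_hook(self, user_api_key_dict, cache, data, call_type):
        # skip work unless this guardrail was requested for this event type
        if not self.should_run_guardrail(
            data=data, event_type=GuardrailEventHooks.pre_call
        ):
            return data
        # ... guardrail logic would go here ...
        return data

guard = MyGuardrail(
    guardrail_name="my-guard", event_hook=GuardrailEventHooks.pre_call
)
# metadata["guardrails"] is what should_run_guardrail inspects, per the hunk above
print(
    guard.should_run_guardrail(
        data={"metadata": {"guardrails": ["my-guard"]}},
        event_type=GuardrailEventHooks.pre_call,
    )
)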
@@ -969,10 +969,10 @@ class Logging:
                ),
                result=result,
            )

            ## LOGGING HOOK ##

            for callback in callbacks:
                if isinstance(callback, CustomLogger):
                    self.model_call_details, result = callback.logging_hook(
                        kwargs=self.model_call_details,
                        result=result,
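The loop above hands `model_call_details` and the result to every `CustomLogger` that implements `logging_hook`, letting callbacks rewrite the payload before it is logged. A minimal sketch of a conforming callback; the hook signature is taken from the hunk, while the redaction rule is an invented example:

from typing import Any, Tuple

from litellm.integrations.custom_logger import CustomLogger

class RedactingLogger(CustomLogger):
    def logging_hook(
        self, kwargs: dict, result: Any, call_type: str
    ) -> Tuple[dict, Any]:
        # mutate the logging payload before it reaches langfuse/datadog/etc.
        for message in kwargs.get("messages") or []:
            if isinstance(message.get("content"), str):
                message["content"] = message["content"].replace("secret", "[REDACTED]")
        return kwargs, result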
File diffs suppressed for three files because one or more lines are too long.
@@ -1,48 +1,30 @@
model_list:
  # GPT-4 Turbo Models
  - model_name: gpt-4
    litellm_params:
      model: gpt-4
      rpm: 1
  - model_name: gpt-4
  - model_name: gpt-3.5-turbo
    litellm_params:
      model: azure/chatgpt-v-2
      api_key: os.environ/AZURE_API_KEY
      api_base: os.environ/AZURE_API_BASE
      temperature: 0.2
  - model_name: rerank-model
    litellm_params:
      model: jina_ai/jina-reranker-v2-base-multilingual
  - model_name: anthropic-vertex
    litellm_params:
      model: vertex_ai/claude-3-5-sonnet-v2
      vertex_ai_project: "adroit-crow-413218"
      vertex_ai_location: "us-east5"
  - model_name: openai-gpt-4o-realtime-audio
    litellm_params:
      model: openai/gpt-4o-realtime-preview-2024-10-01
      api_key: os.environ/OPENAI_API_KEY
  - model_name: openai/*
    litellm_params:
      model: openai/*
      api_key: os.environ/OPENAI_API_KEY
  - model_name: openai/*
    litellm_params:
      model: openai/*
      api_key: os.environ/OPENAI_API_KEY
    model_info:
      access_groups: ["public-openai-models"]
  - model_name: openai/gpt-4o
    litellm_params:
      model: openai/gpt-4o
      api_key: os.environ/OPENAI_API_KEY
    model_info:
      access_groups: ["private-openai-models"]

guardrails:
  - guardrail_name: "presidio-log-guard"
    litellm_params:
      guardrail: presidio
      mode: "logging_only"
      mock_redacted_text:
        text: "hello world, my name is <PERSON>. My number is: <PHONE_NUMBER>"
        items:
          - start: 48
            end: 62
            entity_type: PHONE_NUMBER
            text: "<PHONE_NUMBER>"
            operator: replace
          - start: 24
            end: 32
            entity_type: PERSON
            text: "<PERSON>"
            operator: replace

router_settings:
  # routing_strategy: usage-based-routing-v2
  #redis_url: "os.environ/REDIS_URL"
  redis_host: "os.environ/REDIS_HOST"
  redis_port: "os.environ/REDIS_PORT"

litellm_settings:
  callbacks: ["datadog"]
  set_verbose: true
  success_callback: ["langfuse"]
@@ -36,16 +36,19 @@ class JWTHandler:
        self,
    ) -> None:
        self.http_handler = HTTPHandler()
        self.leeway = 0

    def update_environment(
        self,
        prisma_client: Optional[PrismaClient],
        user_api_key_cache: DualCache,
        litellm_jwtauth: LiteLLM_JWTAuth,
        leeway: int = 0,
    ) -> None:
        self.prisma_client = prisma_client
        self.user_api_key_cache = user_api_key_cache
        self.litellm_jwtauth = litellm_jwtauth
        self.leeway = leeway

    def is_jwt(self, token: str):
        parts = token.split(".")

@@ -271,6 +274,7 @@ class JWTHandler:
            algorithms=algorithms,
            options=decode_options,
            audience=audience,
            leeway=self.leeway,  # allow testing of expired tokens
        )
        return payload
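`leeway` is passed straight through to the JWT decode, so expiry checks tolerate that many seconds of clock skew, and a huge value effectively disables them for testing, as the inline comment notes. A standalone sketch of the same semantics with PyJWT:

import datetime

import jwt  # PyJWT

key = "test-secret"
expired_token = jwt.encode(
    {"exp": datetime.datetime.utcnow() - datetime.timedelta(minutes=5)},
    key,
    algorithm="HS256",
)

# a strict decode raises jwt.ExpiredSignatureError; leeway=600 accepts
# tokens that expired up to 10 minutes ago
payload = jwt.decode(expired_token, key, algorithms=["HS256"], leeway=600)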
@@ -562,6 +562,7 @@ async def user_api_key_auth(  # noqa: PLR0915
                user_id=user_id,
                org_id=org_id,
                parent_otel_span=parent_otel_span,
                end_user_id=end_user_id,
            )
        #### ELSE ####
        ## CHECK PASS-THROUGH ENDPOINTS ##
@@ -48,7 +48,7 @@ def initialize_callbacks_on_proxy(  # noqa: PLR0915
            imported_list.append(open_telemetry_logger)
            setattr(proxy_server, "open_telemetry_logger", open_telemetry_logger)
        elif isinstance(callback, str) and callback == "presidio":
            from litellm.proxy.hooks.presidio_pii_masking import (
            from litellm.proxy.guardrails.guardrail_hooks.presidio import (
                _OPTIONAL_PresidioPIIMasking,
            )
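Since this commit deletes `litellm/proxy/hooks/presidio_pii_masking.py` (full diff further down) and re-homes the class under the guardrail hooks package, any external code importing the old path needs the same one-line migration shown in the hunk:

# old import path (removed in this commit)
# from litellm.proxy.hooks.presidio_pii_masking import _OPTIONAL_PresidioPIIMasking

# new import path
from litellm.proxy.guardrails.guardrail_hooks.presidio import (
    _OPTIONAL_PresidioPIIMasking,
)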
@@ -24,6 +24,7 @@ from litellm._logging import verbose_proxy_logger
from litellm.caching.caching import DualCache
from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.proxy._types import UserAPIKeyAuth
from litellm.types.guardrails import GuardrailEventHooks
from litellm.utils import (
    EmbeddingResponse,
    ImageResponse,
@@ -54,8 +55,13 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail):
        presidio_anonymizer_api_base: Optional[str] = None,
        output_parse_pii: Optional[bool] = False,
        presidio_ad_hoc_recognizers: Optional[str] = None,
        logging_only: Optional[bool] = None,
        **kwargs,
    ):
        if logging_only is True:
            self.logging_only = True
            kwargs["event_hook"] = GuardrailEventHooks.logging_only
        super().__init__(**kwargs)
        self.pii_tokens: dict = (
            {}
        )  # mapping of PII token to original text - only used with Presidio `replace` operation
@@ -84,8 +90,6 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail):
            presidio_anonymizer_api_base=presidio_anonymizer_api_base,
        )

        super().__init__(**kwargs)

    def validate_environment(
        self,
        presidio_analyzer_api_base: Optional[str] = None,
@@ -245,10 +249,44 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail):
            verbose_proxy_logger.info(
                f"Presidio PII Masking: Redacted pii message: {data['messages']}"
            )
            data["messages"] = messages
            return data
        except Exception as e:
            raise e

    def logging_hook(
        self, kwargs: dict, result: Any, call_type: str
    ) -> Tuple[dict, Any]:
        import threading
        from concurrent.futures import ThreadPoolExecutor

        def run_in_new_loop():
            """Run the coroutine in a new event loop within this thread."""
            new_loop = asyncio.new_event_loop()
            try:
                asyncio.set_event_loop(new_loop)
                return new_loop.run_until_complete(
                    self.async_logging_hook(
                        kwargs=kwargs, result=result, call_type=call_type
                    )
                )
            finally:
                new_loop.close()
                asyncio.set_event_loop(None)

        try:
            # First, try to get the current event loop
            _ = asyncio.get_running_loop()
            # If we're already in an event loop, run in a separate thread
            # to avoid nested event loop issues
            with ThreadPoolExecutor(max_workers=1) as executor:
                future = executor.submit(run_in_new_loop)
                return future.result()

        except RuntimeError:
            # No running event loop, we can safely run in this thread
            return run_in_new_loop()

    async def async_logging_hook(
        self, kwargs: dict, result: Any, call_type: str
    ) -> Tuple[dict, Any]:
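The sync `logging_hook` above is a bridge: it reuses `async_logging_hook` whether or not the caller already sits inside an event loop, hopping to a fresh thread when one is running so loops never nest. The same pattern in isolation, runnable as-is:

import asyncio
from concurrent.futures import ThreadPoolExecutor

async def do_work() -> str:
    return "done"

def run_sync() -> str:
    def run_in_new_loop():
        new_loop = asyncio.new_event_loop()
        try:
            asyncio.set_event_loop(new_loop)
            return new_loop.run_until_complete(do_work())
        finally:
            new_loop.close()
            asyncio.set_event_loop(None)

    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # no loop running: safe to run the coroutine on this thread
        return run_in_new_loop()
    # already inside a loop: run on a separate thread to avoid nesting
    with ThreadPoolExecutor(max_workers=1) as executor:
        return executor.submit(run_in_new_loop).result()

print(run_sync())  # prints "done"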
@@ -304,7 +342,8 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail):
        verbose_proxy_logger.debug(
            f"PII Masking Args: self.output_parse_pii={self.output_parse_pii}; type of response={type(response)}"
        )
        if self.output_parse_pii is False:

        if self.output_parse_pii is False and litellm.output_parse_pii is False:
            return response

        if isinstance(response, ModelResponse) and not isinstance(
litellm/proxy/hooks/presidio_pii_masking.py (deleted)
@@ -1,349 +0,0 @@
# +-----------------------------------------------+
# |                                               |
# |              PII Masking                      |
# |        with Microsoft Presidio                |
# |   https://github.com/BerriAI/litellm/issues/  |
# +-----------------------------------------------+
#
#  Tell us how we can improve! - Krrish & Ishaan


import asyncio
import json
import traceback
import uuid
from typing import Any, List, Optional, Tuple, Union

import aiohttp
from fastapi import HTTPException

import litellm  # noqa: E401
from litellm._logging import verbose_proxy_logger
from litellm.caching.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import UserAPIKeyAuth
from litellm.utils import (
    EmbeddingResponse,
    ImageResponse,
    ModelResponse,
    StreamingChoices,
    get_formatted_prompt,
)


class _OPTIONAL_PresidioPIIMasking(CustomLogger):
    user_api_key_cache = None
    ad_hoc_recognizers = None

    # Class variables or attributes
    def __init__(
        self,
        logging_only: Optional[bool] = None,
        mock_testing: bool = False,
        mock_redacted_text: Optional[dict] = None,
    ):
        self.pii_tokens: dict = (
            {}
        )  # mapping of PII token to original text - only used with Presidio `replace` operation

        self.mock_redacted_text = mock_redacted_text
        self.logging_only = logging_only
        if mock_testing is True:  # for testing purposes only
            return

        ad_hoc_recognizers = litellm.presidio_ad_hoc_recognizers
        if ad_hoc_recognizers is not None:
            try:
                with open(ad_hoc_recognizers, "r") as file:
                    self.ad_hoc_recognizers = json.load(file)
            except FileNotFoundError:
                raise Exception(f"File not found. file_path={ad_hoc_recognizers}")
            except json.JSONDecodeError as e:
                raise Exception(
                    f"Error decoding JSON file: {str(e)}, file_path={ad_hoc_recognizers}"
                )
            except Exception as e:
                raise Exception(
                    f"An error occurred: {str(e)}, file_path={ad_hoc_recognizers}"
                )

        self.validate_environment()

    def validate_environment(self):
        self.presidio_analyzer_api_base: Optional[str] = litellm.get_secret(
            "PRESIDIO_ANALYZER_API_BASE", None
        )  # type: ignore
        self.presidio_anonymizer_api_base: Optional[str] = litellm.get_secret(
            "PRESIDIO_ANONYMIZER_API_BASE", None
        )  # type: ignore

        if self.presidio_analyzer_api_base is None:
            raise Exception("Missing `PRESIDIO_ANALYZER_API_BASE` from environment")
        if not self.presidio_analyzer_api_base.endswith("/"):
            self.presidio_analyzer_api_base += "/"
        if not (
            self.presidio_analyzer_api_base.startswith("http://")
            or self.presidio_analyzer_api_base.startswith("https://")
        ):
            # add http:// if unset, assume communicating over private network - e.g. render
            self.presidio_analyzer_api_base = (
                "http://" + self.presidio_analyzer_api_base
            )

        if self.presidio_anonymizer_api_base is None:
            raise Exception("Missing `PRESIDIO_ANONYMIZER_API_BASE` from environment")
        if not self.presidio_anonymizer_api_base.endswith("/"):
            self.presidio_anonymizer_api_base += "/"
        if not (
            self.presidio_anonymizer_api_base.startswith("http://")
            or self.presidio_anonymizer_api_base.startswith("https://")
        ):
            # add http:// if unset, assume communicating over private network - e.g. render
            self.presidio_anonymizer_api_base = (
                "http://" + self.presidio_anonymizer_api_base
            )

    def print_verbose(self, print_statement):
        try:
            verbose_proxy_logger.debug(print_statement)
            if litellm.set_verbose:
                print(print_statement)  # noqa
        except Exception:
            pass

    async def check_pii(self, text: str, output_parse_pii: bool) -> str:
        """
        [TODO] make this more performant for high-throughput scenario
        """
        try:
            async with aiohttp.ClientSession() as session:
                if self.mock_redacted_text is not None:
                    redacted_text = self.mock_redacted_text
                else:
                    # Make the first request to /analyze
                    analyze_url = f"{self.presidio_analyzer_api_base}analyze"
                    verbose_proxy_logger.debug("Making request to: %s", analyze_url)
                    analyze_payload = {"text": text, "language": "en"}
                    if self.ad_hoc_recognizers is not None:
                        analyze_payload["ad_hoc_recognizers"] = self.ad_hoc_recognizers
                    redacted_text = None
                    async with session.post(
                        analyze_url, json=analyze_payload
                    ) as response:
                        analyze_results = await response.json()

                    # Make the second request to /anonymize
                    anonymize_url = f"{self.presidio_anonymizer_api_base}anonymize"
                    verbose_proxy_logger.debug("Making request to: %s", anonymize_url)
                    anonymize_payload = {
                        "text": text,
                        "analyzer_results": analyze_results,
                    }

                    async with session.post(
                        anonymize_url, json=anonymize_payload
                    ) as response:
                        redacted_text = await response.json()

                new_text = text
                if redacted_text is not None:
                    verbose_proxy_logger.debug("redacted_text: %s", redacted_text)
                    for item in redacted_text["items"]:
                        start = item["start"]
                        end = item["end"]
                        replacement = item["text"]  # replacement token
                        if item["operator"] == "replace" and output_parse_pii is True:
                            # check if token in dict
                            # if exists, add a uuid to the replacement token for swapping back to the original text in llm response output parsing
                            if replacement in self.pii_tokens:
                                replacement = replacement + str(uuid.uuid4())

                            self.pii_tokens[replacement] = new_text[
                                start:end
                            ]  # get text it'll replace

                        new_text = new_text[:start] + replacement + new_text[end:]
                    return redacted_text["text"]
                else:
                    raise Exception(f"Invalid anonymizer response: {redacted_text}")
        except Exception as e:
            verbose_proxy_logger.error(
                "litellm.proxy.hooks.presidio_pii_masking.py::async_pre_call_hook(): Exception occured - {}".format(
                    str(e)
                )
            )
            verbose_proxy_logger.debug(traceback.format_exc())
            raise e

    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: dict,
        call_type: str,
    ):
        """
        - Check if request turned off pii
        - Check if user allowed to turn off pii (key permissions -> 'allow_pii_controls')

        - Take the request data
        - Call /analyze -> get the results
        - Call /anonymize w/ the analyze results -> get the redacted text

        For multiple messages in /chat/completions, we'll need to call them in parallel.
        """
        try:
            if (
                self.logging_only is True
            ):  # only modify the logging obj data (done by async_logging_hook)
                return data
            permissions = user_api_key_dict.permissions
            output_parse_pii = permissions.get(
                "output_parse_pii", litellm.output_parse_pii
            )  # allow key to turn on/off output parsing for pii
            no_pii = permissions.get(
                "no-pii", None
            )  # allow key to turn on/off pii masking (if user is allowed to set pii controls, then they can override the key defaults)

            if no_pii is None:
                # check older way of turning on/off pii
                no_pii = not permissions.get("pii", True)

            content_safety = data.get("content_safety", None)
            verbose_proxy_logger.debug("content_safety: %s", content_safety)
            ## Request-level turn on/off PII controls ##
            if content_safety is not None and isinstance(content_safety, dict):
                # pii masking ##
                if (
                    content_safety.get("no-pii", None) is not None
                    and content_safety.get("no-pii") is True
                ):
                    # check if user allowed to turn this off
                    if permissions.get("allow_pii_controls", False) is False:
                        raise HTTPException(
                            status_code=400,
                            detail={
                                "error": "Not allowed to set PII controls per request"
                            },
                        )
                    else:  # user allowed to turn off pii masking
                        no_pii = content_safety.get("no-pii")
                        if not isinstance(no_pii, bool):
                            raise HTTPException(
                                status_code=400,
                                detail={"error": "no_pii needs to be a boolean value"},
                            )
                ## pii output parsing ##
                if content_safety.get("output_parse_pii", None) is not None:
                    # check if user allowed to turn this off
                    if permissions.get("allow_pii_controls", False) is False:
                        raise HTTPException(
                            status_code=400,
                            detail={
                                "error": "Not allowed to set PII controls per request"
                            },
                        )
                    else:  # user allowed to turn on/off pii output parsing
                        output_parse_pii = content_safety.get("output_parse_pii")
                        if not isinstance(output_parse_pii, bool):
                            raise HTTPException(
                                status_code=400,
                                detail={
                                    "error": "output_parse_pii needs to be a boolean value"
                                },
                            )

            if no_pii is True:  # turn off pii masking
                return data

            if call_type == "completion":  # /chat/completions requests
                messages = data["messages"]
                tasks = []

                for m in messages:
                    if isinstance(m["content"], str):
                        tasks.append(
                            self.check_pii(
                                text=m["content"], output_parse_pii=output_parse_pii
                            )
                        )
                responses = await asyncio.gather(*tasks)
                for index, r in enumerate(responses):
                    if isinstance(messages[index]["content"], str):
                        messages[index][
                            "content"
                        ] = r  # replace content with redacted string
                verbose_proxy_logger.info(
                    f"Presidio PII Masking: Redacted pii message: {data['messages']}"
                )
            return data
        except Exception as e:
            verbose_proxy_logger.info(
                "An error occurred -",
            )
            raise e

    async def async_logging_hook(
        self, kwargs: dict, result: Any, call_type: str
    ) -> Tuple[dict, Any]:
        """
        Masks the input before logging to langfuse, datadog, etc.
        """
        if (
            call_type == "completion" or call_type == "acompletion"
        ):  # /chat/completions requests
            messages: Optional[List] = kwargs.get("messages", None)
            tasks = []

            if messages is None:
                return kwargs, result

            for m in messages:
                text_str = ""
                if m["content"] is None:
                    continue
                if isinstance(m["content"], str):
                    text_str = m["content"]
                    tasks.append(
                        self.check_pii(text=text_str, output_parse_pii=False)
                    )  # need to pass separately b/c presidio has context window limits
            responses = await asyncio.gather(*tasks)
            for index, r in enumerate(responses):
                if isinstance(messages[index]["content"], str):
                    messages[index][
                        "content"
                    ] = r  # replace content with redacted string
            verbose_proxy_logger.info(
                f"Presidio PII Masking: Redacted pii message: {messages}"
            )
            kwargs["messages"] = messages

        return kwargs, result

    async def async_post_call_success_hook(
        self,
        data: dict,
        user_api_key_dict: UserAPIKeyAuth,
        response: Union[ModelResponse, EmbeddingResponse, ImageResponse],
    ):
        """
        Output parse the response object to replace the masked tokens with user sent values
        """
        verbose_proxy_logger.debug(
            f"PII Masking Args: litellm.output_parse_pii={litellm.output_parse_pii}; type of response={type(response)}"
        )
        if litellm.output_parse_pii is False:
            return response

        if isinstance(response, ModelResponse) and not isinstance(
            response.choices[0], StreamingChoices
        ):  # /chat/completions requests
            if isinstance(response.choices[0].message.content, str):
                verbose_proxy_logger.debug(
                    f"self.pii_tokens: {self.pii_tokens}; initial response: {response.choices[0].message.content}"
                )
                for key, value in self.pii_tokens.items():
                    response.choices[0].message.content = response.choices[
                        0
                    ].message.content.replace(key, value)
        return response
@@ -5987,6 +5987,39 @@ async def new_budget(
    return response


@router.post(
    "/budget/update",
    tags=["budget management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def update_budget(
    budget_obj: BudgetNew,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Update an existing budget object. Can apply this to teams, orgs, end-users, keys.
    """
    global prisma_client

    if prisma_client is None:
        raise HTTPException(
            status_code=500,
            detail={"error": CommonProxyErrors.db_not_connected_error.value},
        )
    if budget_obj.budget_id is None:
        raise HTTPException(status_code=400, detail={"error": "budget_id is required"})

    response = await prisma_client.db.litellm_budgettable.update(
        where={"budget_id": budget_obj.budget_id},
        data={
            **budget_obj.model_dump(exclude_none=True),  # type: ignore
            "updated_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
        },  # type: ignore
    )

    return response


@router.post(
    "/budget/info",
    tags=["budget management"],
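A hedged sketch of calling the new endpoint from Python; the proxy URL and key are placeholders, and the payload mirrors `BudgetNew`, which at minimum must carry `budget_id` per the check above:

import requests

PROXY_BASE_URL = "http://localhost:4000"  # placeholder
LITELLM_MASTER_KEY = "sk-1234"  # placeholder

response = requests.post(
    f"{PROXY_BASE_URL}/budget/update",
    headers={
        "Authorization": f"Bearer {LITELLM_MASTER_KEY}",
        "Content-Type": "application/json",
    },
    # budget_id is required; any other BudgetNew fields are applied as updates
    json={"budget_id": "my-budget", "max_budget": 50.0},
)
response.raise_for_status()
print(response.json())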
@@ -83,6 +83,9 @@ from litellm.router_utils.fallback_event_handlers import (
    run_async_fallback,
    run_sync_fallback,
)
from litellm.router_utils.get_retry_from_policy import (
    get_num_retries_from_retry_policy as _get_num_retries_from_retry_policy,
)
from litellm.router_utils.handle_error import (
    async_raise_no_deployment_exception,
    send_llm_exception_alert,
@@ -5609,53 +5612,12 @@ class Router:
    def get_num_retries_from_retry_policy(
        self, exception: Exception, model_group: Optional[str] = None
    ):
        """
        BadRequestErrorRetries: Optional[int] = None
        AuthenticationErrorRetries: Optional[int] = None
        TimeoutErrorRetries: Optional[int] = None
        RateLimitErrorRetries: Optional[int] = None
        ContentPolicyViolationErrorRetries: Optional[int] = None
        """
        # if we can find the exception in the retry policy -> return the number of retries
        retry_policy: Optional[RetryPolicy] = self.retry_policy

        if (
            self.model_group_retry_policy is not None
            and model_group is not None
            and model_group in self.model_group_retry_policy
        ):
            retry_policy = self.model_group_retry_policy.get(model_group, None)  # type: ignore

        if retry_policy is None:
            return None
        if isinstance(retry_policy, dict):
            retry_policy = RetryPolicy(**retry_policy)

        if (
            isinstance(exception, litellm.BadRequestError)
            and retry_policy.BadRequestErrorRetries is not None
        ):
            return retry_policy.BadRequestErrorRetries
        if (
            isinstance(exception, litellm.AuthenticationError)
            and retry_policy.AuthenticationErrorRetries is not None
        ):
            return retry_policy.AuthenticationErrorRetries
        if (
            isinstance(exception, litellm.Timeout)
            and retry_policy.TimeoutErrorRetries is not None
        ):
            return retry_policy.TimeoutErrorRetries
        if (
            isinstance(exception, litellm.RateLimitError)
            and retry_policy.RateLimitErrorRetries is not None
        ):
            return retry_policy.RateLimitErrorRetries
        if (
            isinstance(exception, litellm.ContentPolicyViolationError)
            and retry_policy.ContentPolicyViolationErrorRetries is not None
        ):
            return retry_policy.ContentPolicyViolationErrorRetries
        return _get_num_retries_from_retry_policy(
            exception=exception,
            model_group=model_group,
            model_group_retry_policy=self.model_group_retry_policy,
            retry_policy=self.retry_policy,
        )

    def get_allowed_fails_from_policy(self, exception: Exception):
        """
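With the inline lookup gone, the Router and the bare SDK resolve retries through the same helper. A sketch of configuring this at the Router level, assuming `Router` accepts a `retry_policy` argument as the `self.retry_policy` usage above suggests; the deployment entry is a placeholder:

from litellm import Router
from litellm.types.router import RetryPolicy

router = Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {"model": "gpt-3.5-turbo"},  # placeholder deployment
        }
    ],
    retry_policy=RetryPolicy(
        TimeoutErrorRetries=2,
        RateLimitErrorRetries=3,
    ),
)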
litellm/router_utils/get_retry_from_policy.py (new file, 71 lines)
@@ -0,0 +1,71 @@
"""
Get num retries for an exception.

- Account for retry policy by exception type.
"""

from typing import Dict, Optional, Union

from litellm.exceptions import (
    AuthenticationError,
    BadRequestError,
    ContentPolicyViolationError,
    RateLimitError,
    Timeout,
)
from litellm.types.router import RetryPolicy


def get_num_retries_from_retry_policy(
    exception: Exception,
    retry_policy: Optional[Union[RetryPolicy, dict]] = None,
    model_group: Optional[str] = None,
    model_group_retry_policy: Optional[Dict[str, RetryPolicy]] = None,
):
    """
    BadRequestErrorRetries: Optional[int] = None
    AuthenticationErrorRetries: Optional[int] = None
    TimeoutErrorRetries: Optional[int] = None
    RateLimitErrorRetries: Optional[int] = None
    ContentPolicyViolationErrorRetries: Optional[int] = None
    """
    # if we can find the exception in the retry policy -> return the number of retries

    if (
        model_group_retry_policy is not None
        and model_group is not None
        and model_group in model_group_retry_policy
    ):
        retry_policy = model_group_retry_policy.get(model_group, None)  # type: ignore

    if retry_policy is None:
        return None
    if isinstance(retry_policy, dict):
        retry_policy = RetryPolicy(**retry_policy)

    if (
        isinstance(exception, BadRequestError)
        and retry_policy.BadRequestErrorRetries is not None
    ):
        return retry_policy.BadRequestErrorRetries
    if (
        isinstance(exception, AuthenticationError)
        and retry_policy.AuthenticationErrorRetries is not None
    ):
        return retry_policy.AuthenticationErrorRetries
    if isinstance(exception, Timeout) and retry_policy.TimeoutErrorRetries is not None:
        return retry_policy.TimeoutErrorRetries
    if (
        isinstance(exception, RateLimitError)
        and retry_policy.RateLimitErrorRetries is not None
    ):
        return retry_policy.RateLimitErrorRetries
    if (
        isinstance(exception, ContentPolicyViolationError)
        and retry_policy.ContentPolicyViolationErrorRetries is not None
    ):
        return retry_policy.ContentPolicyViolationErrorRetries


def reset_retry_policy() -> RetryPolicy:
    return RetryPolicy()
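The helper is pure lookup logic, so it can be exercised directly. A quick sketch; the exception constructor arguments are assumptions based on litellm's usual (message, llm_provider, model) signature:

from litellm.exceptions import RateLimitError
from litellm.router_utils.get_retry_from_policy import (
    get_num_retries_from_retry_policy,
)
from litellm.types.router import RetryPolicy

policy = RetryPolicy(RateLimitErrorRetries=4)
exc = RateLimitError(
    message="rate limited", llm_provider="openai", model="gpt-3.5-turbo"
)

# matched exception type -> configured retry count
assert get_num_retries_from_retry_policy(exception=exc, retry_policy=policy) == 4
# unmatched exception (or no policy at all) falls through to None
assert get_num_retries_from_retry_policy(exception=ValueError("x")) is None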
@@ -97,6 +97,10 @@ from litellm.litellm_core_utils.rules import Rules
from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper
from litellm.litellm_core_utils.token_counter import get_modified_max_tokens
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.router_utils.get_retry_from_policy import (
    get_num_retries_from_retry_policy,
    reset_retry_policy,
)
from litellm.secret_managers.main import get_secret
from litellm.types.llms.openai import (
    AllMessageValues,

@@ -918,6 +922,14 @@ def client(original_function):  # noqa: PLR0915
                num_retries = (
                    kwargs.get("num_retries", None) or litellm.num_retries or None
                )
                if kwargs.get("retry_policy", None):
                    num_retries = get_num_retries_from_retry_policy(
                        exception=e,
                        retry_policy=kwargs.get("retry_policy"),
                    )
                    kwargs["retry_policy"] = (
                        reset_retry_policy()
                    )  # prevent infinite loops
                litellm.num_retries = (
                    None  # set retries to None to prevent infinite loops
                )

@@ -1137,6 +1149,13 @@ def client(original_function):  # noqa: PLR0915
                num_retries = (
                    kwargs.get("num_retries", None) or litellm.num_retries or None
                )
                if kwargs.get("retry_policy", None):
                    num_retries = get_num_retries_from_retry_policy(
                        exception=e,
                        retry_policy=kwargs.get("retry_policy"),
                    )
                    kwargs["retry_policy"] = reset_retry_policy()

                litellm.num_retries = (
                    None  # set retries to None to prevent infinite loops
                )
@@ -61,3 +61,68 @@ def test_completion_with_0_num_retries():
    except Exception as e:
        print("exception", e)
        pass


@pytest.mark.asyncio
@pytest.mark.parametrize("sync_mode", [True, False])
async def test_completion_with_retry_policy(sync_mode):
    from unittest.mock import patch, MagicMock, AsyncMock
    from litellm.types.router import RetryPolicy

    retry_number = 1
    retry_policy = RetryPolicy(
        ContentPolicyViolationErrorRetries=retry_number,  # run 1 retry for ContentPolicyViolationErrors
        AuthenticationErrorRetries=0,  # run 0 retries for AuthenticationErrors
    )

    target_function = "completion_with_retries"

    with patch.object(litellm, target_function) as mock_completion_with_retries:
        data = {
            "model": "azure/gpt-3.5-turbo",
            "messages": [{"gm": "vibe", "role": "user"}],
            "retry_policy": retry_policy,
            "mock_response": "Exception: content_filter_policy",
        }
        try:
            if sync_mode:
                completion(**data)
            else:
                await completion(**data)
        except Exception as e:
            print(e)

        mock_completion_with_retries.assert_called_once()
        assert (
            mock_completion_with_retries.call_args.kwargs["num_retries"] == retry_number
        )
        assert retry_policy.ContentPolicyViolationErrorRetries == retry_number


@pytest.mark.asyncio
@pytest.mark.parametrize("sync_mode", [True, False])
async def test_completion_with_retry_policy_no_error(sync_mode):
    """
    Test that the completion function does not throw an error when the retry policy is set
    """
    from unittest.mock import patch, MagicMock, AsyncMock
    from litellm.types.router import RetryPolicy

    retry_number = 1
    retry_policy = RetryPolicy(
        ContentPolicyViolationErrorRetries=retry_number,  # run 1 retry for ContentPolicyViolationErrors
        AuthenticationErrorRetries=0,  # run 0 retries for AuthenticationErrors
    )

    data = {
        "model": "gpt-3.5-turbo",
        "messages": [{"gm": "vibe", "role": "user"}],
        "retry_policy": retry_policy,
    }
    try:
        if sync_mode:
            completion(**data)
        else:
            await completion(**data)
    except Exception as e:
        print(e)
@@ -19,12 +19,13 @@ sys.path.insert(
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

import litellm
from litellm import Router, mock_completion
from litellm.caching.caching import DualCache
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.hooks.presidio_pii_masking import _OPTIONAL_PresidioPIIMasking
from litellm.proxy.guardrails.guardrail_hooks.presidio import (
    _OPTIONAL_PresidioPIIMasking,
)
from litellm.proxy.utils import ProxyLogging
@@ -63,6 +64,7 @@ async def test_output_parsing():
    - have presidio pii masking - output parse message
    - assert that no masked tokens are in the input message
    """
    litellm.set_verbose = True
    litellm.output_parse_pii = True
    pii_masking = _OPTIONAL_PresidioPIIMasking(mock_testing=True)
@@ -206,6 +208,8 @@ async def test_presidio_pii_masking_input_b():

@pytest.mark.asyncio
async def test_presidio_pii_masking_logging_output_only_no_pre_api_hook():
    from litellm.types.guardrails import GuardrailEventHooks

    pii_masking = _OPTIONAL_PresidioPIIMasking(
        logging_only=True,
        mock_testing=True,
@@ -223,22 +227,29 @@
        }
    ]

    new_data = await pii_masking.async_pre_call_hook(
        user_api_key_dict=user_api_key_dict,
        cache=local_cache,
        data={"messages": test_messages},
        call_type="completion",
    assert (
        pii_masking.should_run_guardrail(
            data={"messages": test_messages},
            event_type=GuardrailEventHooks.pre_call,
        )
        is False
    )

    assert "Jane Doe" in new_data["messages"][0]["content"]


@pytest.mark.parametrize(
    "sync_mode",
    [True, False],
)
@pytest.mark.asyncio
async def test_presidio_pii_masking_logging_output_only_logged_response():
async def test_presidio_pii_masking_logging_output_only_logged_response(sync_mode):
    import litellm

    litellm.set_verbose = True
    pii_masking = _OPTIONAL_PresidioPIIMasking(
        logging_only=True,
        mock_testing=True,
        mock_redacted_text=input_b_anonymizer_results,
        guardrail_name="presidio",
    )

    test_messages = [
@@ -247,15 +258,26 @@
            "content": "My name is Jane Doe, who are you? Say my name in your response",
        }
    ]
    with patch.object(
        pii_masking, "async_log_success_event", new=AsyncMock()
    ) as mock_call:
        litellm.callbacks = [pii_masking]
        response = await litellm.acompletion(
            model="gpt-3.5-turbo", messages=test_messages, mock_response="Hi Peter!"
        )

        await asyncio.sleep(3)
    if sync_mode:
        target_function = "log_success_event"
        mock_call = MagicMock()
    else:
        target_function = "async_log_success_event"
        mock_call = AsyncMock()

    with patch.object(pii_masking, target_function, new=mock_call) as mock_call:
        litellm.callbacks = [pii_masking]
        if sync_mode:
            response = litellm.completion(
                model="gpt-3.5-turbo", messages=test_messages, mock_response="Hi Peter!"
            )
            time.sleep(3)
        else:
            response = await litellm.acompletion(
                model="gpt-3.5-turbo", messages=test_messages, mock_response="Hi Peter!"
            )
            await asyncio.sleep(3)

        assert response.choices[0].message.content == "Hi Peter!"  # type: ignore
@@ -275,8 +297,13 @@ async def test_presidio_pii_masking_logging_output_only_logged_response_guardrai

    import litellm
    from litellm.proxy.guardrails.init_guardrails import initialize_guardrails
    from litellm.types.guardrails import GuardrailItem, GuardrailItemSpec
    from litellm.types.guardrails import (
        GuardrailItem,
        GuardrailItemSpec,
        GuardrailEventHooks,
    )

    litellm.set_verbose = True
    os.environ["PRESIDIO_ANALYZER_API_BASE"] = "http://localhost:5002"
    os.environ["PRESIDIO_ANONYMIZER_API_BASE"] = "http://localhost:5001"
@@ -303,10 +330,15 @@ async def test_presidio_pii_masking_logging_output_only_logged_response_guardrai

    pii_masking_obj: Optional[_OPTIONAL_PresidioPIIMasking] = None
    for callback in litellm.callbacks:
        print(f"CALLBACK: {callback}")
        if isinstance(callback, _OPTIONAL_PresidioPIIMasking):
            pii_masking_obj = callback

    assert pii_masking_obj is not None

    assert hasattr(pii_masking_obj, "logging_only")
    assert pii_masking_obj.logging_only is True
    assert pii_masking_obj.event_hook == GuardrailEventHooks.logging_only

    assert pii_masking_obj.should_run_guardrail(
        data={}, event_type=GuardrailEventHooks.logging_only
    )
@@ -1026,3 +1026,89 @@ def test_get_public_key_from_jwk_url():

    assert public_key is not None
    assert public_key == jwk_response[0]


@pytest.mark.asyncio
async def test_end_user_jwt_auth(monkeypatch):
    import litellm
    from litellm.proxy.auth.handle_jwt import JWTHandler
    from litellm.caching import DualCache
    from litellm.proxy._types import LiteLLM_JWTAuth
    from litellm.proxy.proxy_server import user_api_key_auth

    monkeypatch.delenv("JWT_AUDIENCE", None)
    jwt_handler = JWTHandler()

    litellm_jwtauth = LiteLLM_JWTAuth(
        end_user_id_jwt_field="sub",
    )

    cache = DualCache()

    keys = [
        {
            "kid": "d-1733370597545",
            "alg": "RS256",
            "kty": "RSA",
            "use": "sig",
            "n": "j5Ik60FJSUIPMVdMSU8vhYvyPPH7rUsUllNI0BfBlIkgEYFk2mg4KX1XDQc6mcKFjbq9k_7TSkHWKnfPhNkkb0MdmZLKbwTrmte2k8xWDxp-rSmZpIJwC1zuPDx5joLHBgIb09-K2cPL2rvwzP75WtOr_QLXBerHAbXx8cOdI7mrSRWJ9iXbKv_pLDnZHnGNld75tztb8nCtgrywkF010jGi1xxaT8UKsTvK-QkIBkYI6m6WR9LMnG2OZm-ExuvNPUenfYUsqnynPF4SMNZwyQqJfavSLKI8uMzB2s9pcbx5HfQwIOYhMlgBHjhdDn2IUSnXSJqSsN6RQO18M2rgPQ",
            "e": "AQAB",
        },
        {
            "kid": "s-f836dd32-ef71-426a-8804-946a7f230bc9",
            "alg": "RS256",
            "kty": "RSA",
            "use": "sig",
            "n": "2A5-ZA18YKn7M4OtxsfXBc3Z7n2WyHTxbK4GEBlmD9T9TDr4sbJaI4oHfTvzsAC3H2r2YkASzrCISXMXQJjLHoeLgDVcKs8qTdLj7K5FNT9fA0kU9ayUjSGrqkz57SG7oNf9Wp__Qa-H-bs6Z8_CEfBy0JA9QSHUfrdOXp4vCB_qLn6DE0DJH9ELAq_0nktVQk_oxlvXlGtVZSZe31mNNgiD__RJMogf-SIFcYOkMLVGTTEBYiCk1mHxXS6oJZaVSWiBgHzu5wkra5AfQLUVelQaupT5H81hFPmiceEApf_2DacnqqRV4-Nl8sjhJtuTXiprVS2Z5r2pOMz_kVGNgw",
            "e": "AQAB",
        },
    ]

    cache.set_cache(
        key="litellm_jwt_auth_keys",
        value=keys,
    )

    jwt_handler.update_environment(
        prisma_client=None,
        user_api_key_cache=cache,
        litellm_jwtauth=litellm_jwtauth,
        leeway=100000000000000,
    )

    token = "eyJraWQiOiJkLTE3MzMzNzA1OTc1NDUiLCJ0eXAiOiJKV1QiLCJ2ZXJzaW9uIjoiNCIsImFsZyI6IlJTMjU2In0.eyJpYXQiOjE3MzM1MDcyNzcsImV4cCI6MTczMzUwNzg3Nywic3ViIjoiODFiM2U1MmEtNjdhNi00ZWZiLTk2NDUtNzA1MjdlMTAxNDc5IiwidElkIjoicHVibGljIiwic2Vzc2lvbkhhbmRsZSI6Ijg4M2Y4YWFmLWUwOTEtNGE1Ny04YTJhLTRiMjcwMmZhZjMzYyIsInJlZnJlc2hUb2tlbkhhc2gxIjoiNDVhNDRhYjlmOTMwMGQyOTY4ZjkxODZkYWQ0YTcwY2QwNjk2YzBiNTBmZmUxZmQ4ZTM2YzU1NGU0MWE4ODU0YiIsInBhcmVudFJlZnJlc2hUb2tlbkhhc2gxIjpudWxsLCJhbnRpQ3NyZlRva2VuIjpudWxsLCJpc3MiOiJodHRwOi8vbG9jYWxob3N0OjMwMDEvYXV0aC9zdCIsImxlZ2FjeV9jb21wYW55X2lkIjoxNTI0OTM5LCJsZWdhY3lfaWQiOjM5NzAyNzk1LCJzY29wZSI6WyJza2lsbF91cCJdLCJzdC1ldiI6eyJ0IjoxNzMzNTA3Mjc4NzAwLCJ2IjpmYWxzZX19.XlYrT6dRIjaZKkJtdr7C_UuxajFRbNpA9BnIsny3rxiPVyS8rhIBwxW12tZwgttRywmXrXK-msowFhWU4XdL5Qfe4lwZb2HTbDeGiQPvQTlOjWWYMhgCoKdPtjCQsAcW45rg7aQ0p42JFQPoAQa8AnGfxXpgx2vSR7njiZ3ZZyHerDdKQHyIGSFVOxoK0TgR-hxBVY__Wjg8UTKgKSz9KU_uwnPgpe2DeYmP-LTK2oeoygsVRmbldY_GrrcRe3nqYcUfFkxSs0FSsoSv35jIxiptXfCjhEB1Y5eaJhHEjlYlP2rw98JysYxjO2rZbAdUpL3itPeo3T2uh1NZr_lArw"

    response = await jwt_handler.auth_jwt(token=token)

    assert response is not None

    end_user_id = jwt_handler.get_end_user_id(
        token=response,
        default_value=None,
    )

    assert end_user_id is not None

    ## CHECK USER API KEY AUTH ##
    from starlette.datastructures import URL

    bearer_token = "Bearer " + token

    request = Request(scope={"type": "http"})
    request._url = URL(url="/chat/completions")

    ## 1. INITIAL TEAM CALL - should fail
    # use generated key to auth in
    setattr(
        litellm.proxy.proxy_server,
        "general_settings",
        {
            "enable_jwt_auth": True,
        },
    )
    setattr(litellm.proxy.proxy_server, "prisma_client", {})
    setattr(litellm.proxy.proxy_server, "jwt_handler", jwt_handler)
    result = await user_api_key_auth(request=request, api_key=bearer_token)
    assert (
        result.end_user_id == "81b3e52a-67a6-4efb-9645-70527e101479"
    )  # jwt token decoded sub value
@@ -49,6 +49,7 @@ export interface budgetItem {
  max_budget: string | null;
  rpm_limit: number | null;
  tpm_limit: number | null;
  updated_at: string;
}

const BudgetPanel: React.FC<BudgetSettingsPageProps> = ({ accessToken }) => {
@@ -65,13 +66,17 @@ const BudgetPanel: React.FC<BudgetSettingsPageProps> = ({ accessToken }) => {
    });
  }, [accessToken]);


  const handleEditCall = async (budget_id: string, index: number) => {
    console.log("budget_id", budget_id)
    if (accessToken == null) {
      return;
    }
    setSelectedBudget(budgetList[index])
    setIsEditModalVisible(true)
    // Find the budget first
    const budget = budgetList.find(budget => budget.budget_id === budget_id) || null;

    // Update state and show modal after state is updated
    setSelectedBudget(budget);
    setIsEditModalVisible(true);
  };

  const handleDeleteCall = async (budget_id: string, index: number) => {
@@ -90,6 +95,15 @@ const BudgetPanel: React.FC<BudgetSettingsPageProps> = ({ accessToken }) => {
    message.success("Budget Deleted.");
  };

  const handleUpdateCall = async () => {
    if (accessToken == null) {
      return;
    }
    getBudgetList(accessToken).then((data) => {
      setBudgetList(data);
    });
  }

  return (
    <div className="w-full mx-auto flex-auto overflow-y-auto m-8 p-2">
      <Button
@@ -113,6 +127,7 @@ const BudgetPanel: React.FC<BudgetSettingsPageProps> = ({ accessToken }) => {
          setIsModalVisible={setIsEditModalVisible}
          setBudgetList={setBudgetList}
          existingBudget={selectedBudget}
          handleUpdateCall={handleUpdateCall}
        />
      }
      <Card>
@@ -128,30 +143,27 @@ const BudgetPanel: React.FC<BudgetSettingsPageProps> = ({ accessToken }) => {
        </TableHead>

        <TableBody>
          {budgetList.map((value: budgetItem, index: number) => (
            <TableRow key={index}>
              <TableCell>{value.budget_id}</TableCell>
              <TableCell>
                {value.max_budget ? value.max_budget : "n/a"}
              </TableCell>
              <TableCell>
                {value.tpm_limit ? value.tpm_limit : "n/a"}
              </TableCell>
              <TableCell>
                {value.rpm_limit ? value.rpm_limit : "n/a"}
              </TableCell>
              <Icon
                icon={PencilAltIcon}
                size="sm"
                onClick={() => handleEditCall(value.budget_id, index)}
              />
              <Icon
                icon={TrashIcon}
                size="sm"
                onClick={() => handleDeleteCall(value.budget_id, index)}
              />
            </TableRow>
          ))}
          {budgetList
            .slice() // Creates a shallow copy to avoid mutating the original array
            .sort((a, b) => new Date(b.updated_at).getTime() - new Date(a.updated_at).getTime()) // Sort by updated_at in descending order
            .map((value: budgetItem, index: number) => (
              <TableRow key={index}>
                <TableCell>{value.budget_id}</TableCell>
                <TableCell>{value.max_budget ? value.max_budget : "n/a"}</TableCell>
                <TableCell>{value.tpm_limit ? value.tpm_limit : "n/a"}</TableCell>
                <TableCell>{value.rpm_limit ? value.rpm_limit : "n/a"}</TableCell>
                <Icon
                  icon={PencilAltIcon}
                  size="sm"
                  onClick={() => handleEditCall(value.budget_id, index)}
                />
                <Icon
                  icon={TrashIcon}
                  size="sm"
                  onClick={() => handleDeleteCall(value.budget_id, index)}
                />
              </TableRow>
            ))}
        </TableBody>
      </Table>
    </Card>
@@ -1,4 +1,4 @@
import React from "react";
import React, { useEffect } from "react";
import {
  Button,
  TextInput,
@@ -17,7 +17,7 @@ import {
  Select,
  message,
} from "antd";
import { budgetCreateCall } from "../networking";
import { budgetUpdateCall } from "../networking";
import { budgetItem } from "./budget_panel";

interface BudgetModalProps {
@@ -26,15 +26,23 @@ interface BudgetModalProps {
  setIsModalVisible: React.Dispatch<React.SetStateAction<boolean>>;
  setBudgetList: React.Dispatch<React.SetStateAction<any[]>>;
  existingBudget: budgetItem
  handleUpdateCall: () => void
}
const EditBudgetModal: React.FC<BudgetModalProps> = ({
  isModalVisible,
  accessToken,
  setIsModalVisible,
  setBudgetList,
  existingBudget
  existingBudget,
  handleUpdateCall
}) => {
  console.log("existingBudget", existingBudget)
  const [form] = Form.useForm();

  useEffect(() => {
    form.setFieldsValue(existingBudget);
  }, [existingBudget, form]);

  const handleOk = () => {
    setIsModalVisible(false);
    form.resetFields();
@@ -51,14 +59,14 @@ const EditBudgetModal: React.FC<BudgetModalProps> = ({
    }
    try {
      message.info("Making API Call");
      // setIsModalVisible(true);
      const response = await budgetCreateCall(accessToken, formValues);
      console.log("key create Response:", response);
      setIsModalVisible(true);
      const response = await budgetUpdateCall(accessToken, formValues);
      setBudgetList((prevData) =>
        prevData ? [...prevData, response] : [response]
      ); // Check if prevData is null
      message.success("API Key Created");
      message.success("Budget Updated");
      form.resetFields();
      handleUpdateCall();
    } catch (error) {
      console.error("Error creating the key:", error);
      message.error(`Error creating the key: ${error}`, 20);
@@ -255,6 +255,43 @@ export const budgetCreateCall = async (
  }
};

export const budgetUpdateCall = async (
  accessToken: string,
  formValues: Record<string, any> // Assuming formValues is an object
) => {
  try {
    console.log("Form Values in budgetUpdateCall:", formValues); // Log the form values before making the API call

    console.log("Form Values after check:", formValues);
    const url = proxyBaseUrl ? `${proxyBaseUrl}/budget/update` : `/budget/update`;
    const response = await fetch(url, {
      method: "POST",
      headers: {
        [globalLitellmHeaderName]: `Bearer ${accessToken}`,
        "Content-Type": "application/json",
      },
      body: JSON.stringify({
        ...formValues, // Include formValues in the request body
      }),
    });

    if (!response.ok) {
      const errorData = await response.text();
      handleError(errorData);
      console.error("Error response from the server:", errorData);
      throw new Error("Network response was not ok");
    }

    const data = await response.json();
    console.log("API Response:", data);
    return data;
    // Handle success - you might want to update some state or UI based on the created key
  } catch (error) {
    console.error("Failed to create key:", error);
    throw error;
  }
};

export const invitationCreateCall = async (
  accessToken: string,
  userID: string // Assuming formValues is an object