mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 18:54:30 +00:00
Merge pull request #5288 from BerriAI/litellm_aporia_refactor
[Feat] V2 aporia guardrails litellm
This commit is contained in:
commit
c82714757a
33 changed files with 1078 additions and 337 deletions
|
@ -149,6 +149,7 @@ from litellm.proxy.common_utils.admin_ui_utils import (
|
|||
show_missing_vars_in_env,
|
||||
)
|
||||
from litellm.proxy.common_utils.callback_utils import (
|
||||
get_applied_guardrails_header,
|
||||
get_remaining_tokens_and_requests_from_request_data,
|
||||
initialize_callbacks_on_proxy,
|
||||
)
|
||||
|
@ -168,7 +169,10 @@ from litellm.proxy.common_utils.openai_endpoint_utils import (
|
|||
)
|
||||
from litellm.proxy.fine_tuning_endpoints.endpoints import router as fine_tuning_router
|
||||
from litellm.proxy.fine_tuning_endpoints.endpoints import set_fine_tuning_config
|
||||
from litellm.proxy.guardrails.init_guardrails import initialize_guardrails
|
||||
from litellm.proxy.guardrails.init_guardrails import (
|
||||
init_guardrails_v2,
|
||||
initialize_guardrails,
|
||||
)
|
||||
from litellm.proxy.health_check import perform_health_check
|
||||
from litellm.proxy.health_endpoints._health_endpoints import router as health_router
|
||||
from litellm.proxy.hooks.prompt_injection_detection import (
|
||||
|
@ -539,6 +543,10 @@ def get_custom_headers(
|
|||
)
|
||||
headers.update(remaining_tokens_header)
|
||||
|
||||
applied_guardrails = get_applied_guardrails_header(request_data)
|
||||
if applied_guardrails:
|
||||
headers.update(applied_guardrails)
|
||||
|
||||
try:
|
||||
return {
|
||||
key: value for key, value in headers.items() if value not in exclude_values
|
||||
|
@ -1937,6 +1945,11 @@ class ProxyConfig:
|
|||
async_only_mode=True # only init async clients
|
||||
),
|
||||
) # type:ignore
|
||||
|
||||
# Guardrail settings
|
||||
guardrails_v2 = config.get("guardrails", None)
|
||||
if guardrails_v2:
|
||||
init_guardrails_v2(all_guardrails=guardrails_v2)
|
||||
return router, router.get_model_list(), general_settings
|
||||
|
||||
def get_model_info_with_id(self, model, db_model=False) -> RouterModelInfo:
|
||||
|
@ -3139,7 +3152,7 @@ async def chat_completion(
|
|||
|
||||
### CALL HOOKS ### - modify outgoing data
|
||||
response = await proxy_logging_obj.post_call_success_hook(
|
||||
user_api_key_dict=user_api_key_dict, response=response
|
||||
data=data, user_api_key_dict=user_api_key_dict, response=response
|
||||
)
|
||||
|
||||
hidden_params = (
|
||||
|
@ -3353,6 +3366,11 @@ async def completion(
|
|||
media_type="text/event-stream",
|
||||
headers=custom_headers,
|
||||
)
|
||||
### CALL HOOKS ### - modify outgoing data
|
||||
response = await proxy_logging_obj.post_call_success_hook(
|
||||
data=data, user_api_key_dict=user_api_key_dict, response=response
|
||||
)
|
||||
|
||||
fastapi_response.headers.update(
|
||||
get_custom_headers(
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue