Merge pull request #5288 from BerriAI/litellm_aporia_refactor

[Feat] V2 aporia guardrails litellm
This commit is contained in:
Ishaan Jaff 2024-08-19 20:41:45 -07:00 committed by GitHub
commit c7b3978655
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
33 changed files with 1078 additions and 337 deletions

View file

@ -149,6 +149,7 @@ from litellm.proxy.common_utils.admin_ui_utils import (
show_missing_vars_in_env,
)
from litellm.proxy.common_utils.callback_utils import (
get_applied_guardrails_header,
get_remaining_tokens_and_requests_from_request_data,
initialize_callbacks_on_proxy,
)
@ -168,7 +169,10 @@ from litellm.proxy.common_utils.openai_endpoint_utils import (
)
from litellm.proxy.fine_tuning_endpoints.endpoints import router as fine_tuning_router
from litellm.proxy.fine_tuning_endpoints.endpoints import set_fine_tuning_config
from litellm.proxy.guardrails.init_guardrails import initialize_guardrails
from litellm.proxy.guardrails.init_guardrails import (
init_guardrails_v2,
initialize_guardrails,
)
from litellm.proxy.health_check import perform_health_check
from litellm.proxy.health_endpoints._health_endpoints import router as health_router
from litellm.proxy.hooks.prompt_injection_detection import (
@ -539,6 +543,10 @@ def get_custom_headers(
)
headers.update(remaining_tokens_header)
applied_guardrails = get_applied_guardrails_header(request_data)
if applied_guardrails:
headers.update(applied_guardrails)
try:
return {
key: value for key, value in headers.items() if value not in exclude_values
@ -1937,6 +1945,11 @@ class ProxyConfig:
async_only_mode=True # only init async clients
),
) # type:ignore
# Guardrail settings
guardrails_v2 = config.get("guardrails", None)
if guardrails_v2:
init_guardrails_v2(all_guardrails=guardrails_v2)
return router, router.get_model_list(), general_settings
def get_model_info_with_id(self, model, db_model=False) -> RouterModelInfo:
@ -3139,7 +3152,7 @@ async def chat_completion(
### CALL HOOKS ### - modify outgoing data
response = await proxy_logging_obj.post_call_success_hook(
user_api_key_dict=user_api_key_dict, response=response
data=data, user_api_key_dict=user_api_key_dict, response=response
)
hidden_params = (
@ -3353,6 +3366,11 @@ async def completion(
media_type="text/event-stream",
headers=custom_headers,
)
### CALL HOOKS ### - modify outgoing data
response = await proxy_logging_obj.post_call_success_hook(
data=data, user_api_key_dict=user_api_key_dict, response=response
)
fastapi_response.headers.update(
get_custom_headers(
user_api_key_dict=user_api_key_dict,