mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 18:54:30 +00:00
Litellm dev 12 23 2024 p1 (#7383)
* feat(guardrails_endpoint.py): new `/guardrails/list` endpoint Allow users to view what the available guardrails are * docs: document new `/guardrails/list` endpoint * docs(enterprise.md): update docs * fix(openai/transcription/handler.py): support cost tracking on vtt + srt formats * fix(openai/transcriptions/handler.py): default to 'verbose_json' response format if 'text' or 'json' response_format received. ensures 'duration' param is received for all audio transcription requests * fix: fix linting errors * fix: remove unused import
This commit is contained in:
parent
564ecc728d
commit
db59e08958
11 changed files with 169 additions and 51 deletions
|
@ -9,9 +9,9 @@ Deploy managed LiteLLM Proxy within your VPC.
|
||||||
|
|
||||||
Includes all enterprise features.
|
Includes all enterprise features.
|
||||||
|
|
||||||
[**View AWS Marketplace Listing**](https://aws.amazon.com/marketplace/pp/prodview-gdm3gswgjhgjo?sr=0-1&ref_=beagle&applicationId=AWSMPContessa)
|
[**Procurement available via AWS / Azure Marketplace**](./data_security.md#legalcompliance-faqs)
|
||||||
|
|
||||||
[**Get early access**](https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat)
|
[**Get 7 day trial key**](https://www.litellm.ai/#trial)
|
||||||
|
|
||||||
|
|
||||||
This covers:
|
This covers:
|
||||||
|
@ -44,6 +44,9 @@ This covers:
|
||||||
- ✅ [Custom Branding + Routes on Swagger Docs](./proxy/enterprise#swagger-docs---custom-routes--branding)
|
- ✅ [Custom Branding + Routes on Swagger Docs](./proxy/enterprise#swagger-docs---custom-routes--branding)
|
||||||
- ✅ [Public Model Hub](../docs/proxy/enterprise.md#public-model-hub)
|
- ✅ [Public Model Hub](../docs/proxy/enterprise.md#public-model-hub)
|
||||||
- ✅ [Custom Email Branding](../docs/proxy/email.md#customizing-email-branding)
|
- ✅ [Custom Email Branding](../docs/proxy/email.md#customizing-email-branding)
|
||||||
|
- **Guardrails**
|
||||||
|
- ✅ [Setting team/key based guardrails](./proxy/guardrails/quick_start.md#-control-guardrails-per-project-api-key)
|
||||||
|
- ✅ [API endpoint listing available guardrails](./proxy/guardrails/bedrock.md#list-guardrails)
|
||||||
- ✅ **Feature Prioritization**
|
- ✅ **Feature Prioritization**
|
||||||
- ✅ **Custom Integrations**
|
- ✅ **Custom Integrations**
|
||||||
- ✅ **Professional Support - Dedicated discord + slack**
|
- ✅ **Professional Support - Dedicated discord + slack**
|
||||||
|
|
|
@ -4,6 +4,8 @@ import TabItem from '@theme/TabItem';
|
||||||
|
|
||||||
# Bedrock
|
# Bedrock
|
||||||
|
|
||||||
|
LiteLLM supports Bedrock guardrails via the [Bedrock ApplyGuardrail API](https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ApplyGuardrail.html).
|
||||||
|
|
||||||
## Quick Start
|
## Quick Start
|
||||||
### 1. Define Guardrails on your LiteLLM config.yaml
|
### 1. Define Guardrails on your LiteLLM config.yaml
|
||||||
|
|
||||||
|
@ -56,7 +58,7 @@ curl -i http://localhost:4000/v1/chat/completions \
|
||||||
"messages": [
|
"messages": [
|
||||||
{"role": "user", "content": "hi my email is ishaan@berri.ai"}
|
{"role": "user", "content": "hi my email is ishaan@berri.ai"}
|
||||||
],
|
],
|
||||||
"guardrails": ["bedrock-guard"]
|
"guardrails": ["bedrock-pre-guard"]
|
||||||
}'
|
}'
|
||||||
```
|
```
|
||||||
|
|
||||||
|
@ -124,7 +126,7 @@ curl -i http://localhost:4000/v1/chat/completions \
|
||||||
"messages": [
|
"messages": [
|
||||||
{"role": "user", "content": "hi what is the weather"}
|
{"role": "user", "content": "hi what is the weather"}
|
||||||
],
|
],
|
||||||
"guardrails": ["bedrock-guard"]
|
"guardrails": ["bedrock-pre-guard"]
|
||||||
}'
|
}'
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
|
@ -236,3 +236,20 @@ Expect to NOT see `+1 412-612-9992` in your server logs on your callback.
|
||||||
The `pii_masking` guardrail ran on this request because api key=sk-jNm1Zar7XfNdZXp49Z1kSQ has `"permissions": {"pii_masking": true}`
|
The `pii_masking` guardrail ran on this request because api key=sk-jNm1Zar7XfNdZXp49Z1kSQ has `"permissions": {"pii_masking": true}`
|
||||||
:::
|
:::
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
### ✨ List guardrails
|
||||||
|
|
||||||
|
Show available guardrails on the proxy server. This makes it easier for developers to know what guardrails are available / can be used.
|
||||||
|
|
||||||
|
```shell
|
||||||
|
curl -X GET 'http://0.0.0.0:4000/guardrails/list'
|
||||||
|
```
|
||||||
|
|
||||||
|
Expected response
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"guardrails": ["aporia-pre-guard", "aporia-post-guard"]
|
||||||
|
}
|
||||||
|
```
|
|
@ -512,6 +512,7 @@ def completion_cost( # noqa: PLR0915
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
call_type = _infer_call_type(call_type, completion_response) or "completion"
|
call_type = _infer_call_type(call_type, completion_response) or "completion"
|
||||||
|
|
||||||
if (
|
if (
|
||||||
(call_type == "aimage_generation" or call_type == "image_generation")
|
(call_type == "aimage_generation" or call_type == "image_generation")
|
||||||
and model is not None
|
and model is not None
|
||||||
|
|
|
@ -789,11 +789,15 @@ class Logging(LiteLLMLoggingBaseClass):
|
||||||
"prompt": prompt,
|
"prompt": prompt,
|
||||||
}
|
}
|
||||||
except Exception as e: # error creating kwargs for cost calculation
|
except Exception as e: # error creating kwargs for cost calculation
|
||||||
|
debug_info = StandardLoggingModelCostFailureDebugInformation(
|
||||||
|
error_str=str(e),
|
||||||
|
traceback_str=traceback.format_exc(),
|
||||||
|
)
|
||||||
|
verbose_logger.debug(
|
||||||
|
f"response_cost_failure_debug_information: {debug_info}"
|
||||||
|
)
|
||||||
self.model_call_details["response_cost_failure_debug_information"] = (
|
self.model_call_details["response_cost_failure_debug_information"] = (
|
||||||
StandardLoggingModelCostFailureDebugInformation(
|
debug_info
|
||||||
error_str=str(e),
|
|
||||||
traceback_str=traceback.format_exc(),
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
@ -803,19 +807,23 @@ class Logging(LiteLLMLoggingBaseClass):
|
||||||
)
|
)
|
||||||
return response_cost
|
return response_cost
|
||||||
except Exception as e: # error calculating cost
|
except Exception as e: # error calculating cost
|
||||||
|
debug_info = StandardLoggingModelCostFailureDebugInformation(
|
||||||
|
error_str=str(e),
|
||||||
|
traceback_str=traceback.format_exc(),
|
||||||
|
model=response_cost_calculator_kwargs["model"],
|
||||||
|
cache_hit=response_cost_calculator_kwargs["cache_hit"],
|
||||||
|
custom_llm_provider=response_cost_calculator_kwargs[
|
||||||
|
"custom_llm_provider"
|
||||||
|
],
|
||||||
|
base_model=response_cost_calculator_kwargs["base_model"],
|
||||||
|
call_type=response_cost_calculator_kwargs["call_type"],
|
||||||
|
custom_pricing=response_cost_calculator_kwargs["custom_pricing"],
|
||||||
|
)
|
||||||
|
verbose_logger.debug(
|
||||||
|
f"response_cost_failure_debug_information: {debug_info}"
|
||||||
|
)
|
||||||
self.model_call_details["response_cost_failure_debug_information"] = (
|
self.model_call_details["response_cost_failure_debug_information"] = (
|
||||||
StandardLoggingModelCostFailureDebugInformation(
|
debug_info
|
||||||
error_str=str(e),
|
|
||||||
traceback_str=traceback.format_exc(),
|
|
||||||
model=response_cost_calculator_kwargs["model"],
|
|
||||||
cache_hit=response_cost_calculator_kwargs["cache_hit"],
|
|
||||||
custom_llm_provider=response_cost_calculator_kwargs[
|
|
||||||
"custom_llm_provider"
|
|
||||||
],
|
|
||||||
base_model=response_cost_calculator_kwargs["base_model"],
|
|
||||||
call_type=response_cost_calculator_kwargs["call_type"],
|
|
||||||
custom_pricing=response_cost_calculator_kwargs["custom_pricing"],
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
|
@ -8,7 +8,11 @@ import litellm
|
||||||
from litellm.litellm_core_utils.audio_utils.utils import get_audio_file_name
|
from litellm.litellm_core_utils.audio_utils.utils import get_audio_file_name
|
||||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||||
from litellm.types.utils import FileTypes
|
from litellm.types.utils import FileTypes
|
||||||
from litellm.utils import TranscriptionResponse, convert_to_model_response_object
|
from litellm.utils import (
|
||||||
|
TranscriptionResponse,
|
||||||
|
convert_to_model_response_object,
|
||||||
|
extract_duration_from_srt_or_vtt,
|
||||||
|
)
|
||||||
|
|
||||||
from ..openai import OpenAIChatCompletion
|
from ..openai import OpenAIChatCompletion
|
||||||
|
|
||||||
|
@ -27,18 +31,15 @@ class OpenAIAudioTranscription(OpenAIChatCompletion):
|
||||||
- call openai_aclient.audio.transcriptions.create by default
|
- call openai_aclient.audio.transcriptions.create by default
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
if litellm.return_response_headers is True:
|
raw_response = (
|
||||||
raw_response = (
|
await openai_aclient.audio.transcriptions.with_raw_response.create(
|
||||||
await openai_aclient.audio.transcriptions.with_raw_response.create(
|
**data, timeout=timeout
|
||||||
**data, timeout=timeout
|
)
|
||||||
)
|
) # type: ignore
|
||||||
) # type: ignore
|
headers = dict(raw_response.headers)
|
||||||
headers = dict(raw_response.headers)
|
response = raw_response.parse()
|
||||||
response = raw_response.parse()
|
|
||||||
return headers, response
|
return headers, response
|
||||||
else:
|
|
||||||
response = await openai_aclient.audio.transcriptions.create(**data, timeout=timeout) # type: ignore
|
|
||||||
return None, response
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
@ -84,6 +85,14 @@ class OpenAIAudioTranscription(OpenAIChatCompletion):
|
||||||
atranscription: bool = False,
|
atranscription: bool = False,
|
||||||
) -> TranscriptionResponse:
|
) -> TranscriptionResponse:
|
||||||
data = {"model": model, "file": audio_file, **optional_params}
|
data = {"model": model, "file": audio_file, **optional_params}
|
||||||
|
|
||||||
|
if "response_format" not in data or (
|
||||||
|
data["response_format"] == "text" or data["response_format"] == "json"
|
||||||
|
):
|
||||||
|
data["response_format"] = (
|
||||||
|
"verbose_json" # ensures 'duration' is received - used for cost calculation
|
||||||
|
)
|
||||||
|
|
||||||
if atranscription is True:
|
if atranscription is True:
|
||||||
return self.async_audio_transcriptions( # type: ignore
|
return self.async_audio_transcriptions( # type: ignore
|
||||||
audio_file=audio_file,
|
audio_file=audio_file,
|
||||||
|
@ -178,7 +187,9 @@ class OpenAIAudioTranscription(OpenAIChatCompletion):
|
||||||
if isinstance(response, BaseModel):
|
if isinstance(response, BaseModel):
|
||||||
stringified_response = response.model_dump()
|
stringified_response = response.model_dump()
|
||||||
else:
|
else:
|
||||||
|
duration = extract_duration_from_srt_or_vtt(response)
|
||||||
stringified_response = TranscriptionResponse(text=response).model_dump()
|
stringified_response = TranscriptionResponse(text=response).model_dump()
|
||||||
|
stringified_response["duration"] = duration
|
||||||
## LOGGING
|
## LOGGING
|
||||||
logging_obj.post_call(
|
logging_obj.post_call(
|
||||||
input=get_audio_file_name(audio_file),
|
input=get_audio_file_name(audio_file),
|
||||||
|
|
|
@ -1,17 +1,8 @@
|
||||||
model_list:
|
model_list:
|
||||||
- model_name: gpt-3.5-turbo
|
- model_name: whisper
|
||||||
litellm_params:
|
litellm_params:
|
||||||
model: azure/chatgpt-v-2
|
model: whisper-1
|
||||||
api_key: os.environ/AZURE_API_KEY
|
|
||||||
api_base: os.environ/AZURE_API_BASE
|
|
||||||
temperature: 0.2
|
|
||||||
model_info:
|
|
||||||
access_groups: ["default"]
|
|
||||||
- model_name: gpt-4o
|
|
||||||
litellm_params:
|
|
||||||
model: openai/gpt-4o
|
|
||||||
api_key: os.environ/OPENAI_API_KEY
|
api_key: os.environ/OPENAI_API_KEY
|
||||||
num_retries: 3
|
model_info:
|
||||||
|
mode: audio_transcription
|
||||||
litellm_settings:
|
|
||||||
success_callback: ["langfuse"]
|
|
50
litellm/proxy/guardrails/guardrail_endpoints.py
Normal file
50
litellm/proxy/guardrails/guardrail_endpoints.py
Normal file
|
@ -0,0 +1,50 @@
|
||||||
|
"""
|
||||||
|
CRUD ENDPOINTS FOR GUARDRAILS
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Dict, List, Optional, cast
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
|
|
||||||
|
from litellm.proxy._types import CommonProxyErrors
|
||||||
|
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||||
|
|
||||||
|
#### GUARDRAILS ENDPOINTS ####
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
def _get_guardrail_names_from_config(guardrails_config: List[Dict]) -> List[str]:
|
||||||
|
return [guardrail["guardrail_name"] for guardrail in guardrails_config]
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/guardrails/list",
|
||||||
|
tags=["Guardrails"],
|
||||||
|
dependencies=[Depends(user_api_key_auth)],
|
||||||
|
)
|
||||||
|
async def list_guardrails():
|
||||||
|
"""
|
||||||
|
List the guardrails that are available on the proxy server
|
||||||
|
"""
|
||||||
|
from litellm.proxy.proxy_server import premium_user, proxy_config
|
||||||
|
|
||||||
|
if not premium_user:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail={
|
||||||
|
"error": CommonProxyErrors.not_premium_user.value,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
config = proxy_config.config
|
||||||
|
|
||||||
|
_guardrails_config = cast(Optional[list[dict]], config.get("guardrails"))
|
||||||
|
|
||||||
|
if _guardrails_config is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail={"error": "No guardrails found in config"},
|
||||||
|
)
|
||||||
|
|
||||||
|
return _get_guardrail_names_from_config(config["guardrails"])
|
|
@ -167,6 +167,7 @@ from litellm.proxy.common_utils.proxy_state import ProxyState
|
||||||
from litellm.proxy.common_utils.swagger_utils import ERROR_RESPONSES
|
from litellm.proxy.common_utils.swagger_utils import ERROR_RESPONSES
|
||||||
from litellm.proxy.fine_tuning_endpoints.endpoints import router as fine_tuning_router
|
from litellm.proxy.fine_tuning_endpoints.endpoints import router as fine_tuning_router
|
||||||
from litellm.proxy.fine_tuning_endpoints.endpoints import set_fine_tuning_config
|
from litellm.proxy.fine_tuning_endpoints.endpoints import set_fine_tuning_config
|
||||||
|
from litellm.proxy.guardrails.guardrail_endpoints import router as guardrails_router
|
||||||
from litellm.proxy.guardrails.init_guardrails import (
|
from litellm.proxy.guardrails.init_guardrails import (
|
||||||
init_guardrails_v2,
|
init_guardrails_v2,
|
||||||
initialize_guardrails,
|
initialize_guardrails,
|
||||||
|
@ -4241,12 +4242,11 @@ async def audio_transcriptions(
|
||||||
await proxy_logging_obj.post_call_failure_hook(
|
await proxy_logging_obj.post_call_failure_hook(
|
||||||
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
|
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
|
||||||
)
|
)
|
||||||
verbose_proxy_logger.error(
|
verbose_proxy_logger.exception(
|
||||||
"litellm.proxy.proxy_server.audio_transcription(): Exception occured - {}".format(
|
"litellm.proxy.proxy_server.audio_transcription(): Exception occured - {}".format(
|
||||||
str(e)
|
str(e)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
verbose_proxy_logger.debug(traceback.format_exc())
|
|
||||||
if isinstance(e, HTTPException):
|
if isinstance(e, HTTPException):
|
||||||
raise ProxyException(
|
raise ProxyException(
|
||||||
message=getattr(e, "message", str(e.detail)),
|
message=getattr(e, "message", str(e.detail)),
|
||||||
|
@ -9219,6 +9219,7 @@ app.include_router(customer_router)
|
||||||
app.include_router(spend_management_router)
|
app.include_router(spend_management_router)
|
||||||
app.include_router(caching_router)
|
app.include_router(caching_router)
|
||||||
app.include_router(analytics_router)
|
app.include_router(analytics_router)
|
||||||
|
app.include_router(guardrails_router)
|
||||||
app.include_router(debugging_endpoints_router)
|
app.include_router(debugging_endpoints_router)
|
||||||
app.include_router(ui_crud_endpoints_router)
|
app.include_router(ui_crud_endpoints_router)
|
||||||
app.include_router(openai_files_router)
|
app.include_router(openai_files_router)
|
||||||
|
|
|
@ -1,9 +1,10 @@
|
||||||
#### What this does ####
|
#### What this does ####
|
||||||
# identifies lowest tpm deployment
|
# identifies lowest tpm deployment
|
||||||
import random
|
import random
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
import litellm
|
import litellm
|
||||||
from litellm import token_counter
|
from litellm import token_counter
|
||||||
|
@ -278,13 +279,18 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
|
||||||
"model_group", None
|
"model_group", None
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if isinstance(response_obj, BaseModel) and not hasattr(
|
||||||
|
response_obj, "usage"
|
||||||
|
):
|
||||||
|
return
|
||||||
|
|
||||||
id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
|
id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
|
||||||
if model_group is None or id is None:
|
if model_group is None or id is None:
|
||||||
return
|
return
|
||||||
elif isinstance(id, int):
|
elif isinstance(id, int):
|
||||||
id = str(id)
|
id = str(id)
|
||||||
|
|
||||||
total_tokens = response_obj["usage"]["total_tokens"]
|
total_tokens = cast(dict, response_obj)["usage"]["total_tokens"]
|
||||||
|
|
||||||
# ------------
|
# ------------
|
||||||
# Setup values
|
# Setup values
|
||||||
|
|
|
@ -6316,3 +6316,31 @@ def is_prompt_caching_valid_prompt(
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
verbose_logger.error(f"Error in is_prompt_caching_valid_prompt: {e}")
|
verbose_logger.error(f"Error in is_prompt_caching_valid_prompt: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def extract_duration_from_srt_or_vtt(srt_or_vtt_content: str) -> Optional[float]:
|
||||||
|
"""
|
||||||
|
Extracts the total duration (in seconds) from SRT or VTT content.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
srt_or_vtt_content (str): The content of an SRT or VTT file as a string.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Optional[float]: The total duration in seconds, or None if no timestamps are found.
|
||||||
|
"""
|
||||||
|
# Regular expression to match timestamps in the format "hh:mm:ss,ms" or "hh:mm:ss.ms"
|
||||||
|
timestamp_pattern = r"(\d{2}):(\d{2}):(\d{2})[.,](\d{3})"
|
||||||
|
|
||||||
|
timestamps = re.findall(timestamp_pattern, srt_or_vtt_content)
|
||||||
|
|
||||||
|
if not timestamps:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Convert timestamps to seconds and find the max (end time)
|
||||||
|
durations = []
|
||||||
|
for match in timestamps:
|
||||||
|
hours, minutes, seconds, milliseconds = map(int, match)
|
||||||
|
total_seconds = hours * 3600 + minutes * 60 + seconds + milliseconds / 1000.0
|
||||||
|
durations.append(total_seconds)
|
||||||
|
|
||||||
|
return max(durations) if durations else None
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue