Forked from phoenix/litellm-mirror
Merge branch 'main' into litellm_allow_using_azure_ad_token_auth
Commit 228252b92d
33 changed files with 802 additions and 84 deletions
@@ -321,6 +321,9 @@ jobs:
          -e APORIA_API_BASE_2=$APORIA_API_BASE_2 \
          -e APORIA_API_KEY_2=$APORIA_API_KEY_2 \
          -e APORIA_API_BASE_1=$APORIA_API_BASE_1 \
          -e AWS_ACCESS_KEY_ID=$AWS_ACCESS_KEY_ID \
          -e AWS_SECRET_ACCESS_KEY=$AWS_SECRET_ACCESS_KEY \
          -e AWS_REGION_NAME=$AWS_REGION_NAME \
          -e APORIA_API_KEY_1=$APORIA_API_KEY_1 \
          --name my-app \
          -v $(pwd)/litellm/proxy/example_config_yaml/otel_test_config.yaml:/app/config.yaml \
docs/my-website/docs/projects/dbally.md (new file, 3 lines)

@@ -0,0 +1,3 @@
Efficient, consistent and secure library for querying structured data with natural language. Query any database with over 100 LLMs ❤️ 🚅.

🔗 [GitHub](https://github.com/deepsense-ai/db-ally)
@@ -307,8 +307,9 @@ LiteLLM supports **ALL** azure ai models. Here's a few examples:

| Model Name            | Function Call                                                 |
|-----------------------|---------------------------------------------------------------|
| Cohere command-r-plus | `completion(model="azure/command-r-plus", messages)`          |
| Cohere command-r      | `completion(model="azure/command-r", messages)`               |
| mistral-large-latest  | `completion(model="azure/mistral-large-latest", messages)`    |
| Cohere command-r-plus | `completion(model="azure_ai/command-r-plus", messages)`       |
| Cohere command-r      | `completion(model="azure_ai/command-r", messages)`            |
| mistral-large-latest  | `completion(model="azure_ai/mistral-large-latest", messages)` |
| AI21-Jamba-Instruct   | `completion(model="azure_ai/ai21-jamba-instruct", messages)`  |
@@ -727,6 +727,7 @@ general_settings:
    "completion_model": "string",
    "disable_spend_logs": "boolean",  # turn off writing each transaction to the db
    "disable_master_key_return": "boolean",  # turn off returning master key on UI (checked on '/user/info' endpoint)
    "disable_retry_on_max_parallel_request_limit_error": "boolean",  # turn off retries when max parallel request limit is reached
    "disable_reset_budget": "boolean",  # turn off reset budget scheduled task
    "disable_adding_master_key_hash_to_db": "boolean",  # turn off storing master key hash in db, for spend tracking
    "enable_jwt_auth": "boolean",  # allow proxy admin to auth in via jwt tokens with 'litellm_proxy_admin' in claims

@@ -751,7 +752,8 @@ general_settings:
    },
    "otel": true,
    "custom_auth": "string",
    "max_parallel_requests": 0,
    "max_parallel_requests": 0,  # the max parallel requests allowed per deployment
    "global_max_parallel_requests": 0,  # the max parallel requests allowed on the proxy all up
    "infer_model_from_keys": true,
    "background_health_checks": true,
    "health_check_interval": 300,
docs/my-website/docs/proxy/guardrails/bedrock.md (new file, 135 lines)

@@ -0,0 +1,135 @@
import Image from '@theme/IdealImage';
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';

# Bedrock

## Quick Start
### 1. Define Guardrails on your LiteLLM config.yaml

Define your guardrails under the `guardrails` section

```yaml
model_list:
  - model_name: gpt-3.5-turbo
    litellm_params:
      model: openai/gpt-3.5-turbo
      api_key: os.environ/OPENAI_API_KEY

guardrails:
  - guardrail_name: "bedrock-pre-guard"
    litellm_params:
      guardrail: bedrock  # supported values: "aporia", "bedrock", "lakera"
      mode: "during_call"
      guardrailIdentifier: ff6ujrregl1q  # your guardrail ID on bedrock
      guardrailVersion: "DRAFT"          # your guardrail version on bedrock
```

#### Supported values for `mode`

- `pre_call` Run **before** LLM call, on **input**
- `post_call` Run **after** LLM call, on **input & output**
- `during_call` Run **during** LLM call, on **input**. Same as `pre_call` but runs in parallel with the LLM call. The response is not returned until the guardrail check completes

### 2. Start LiteLLM Gateway

```shell
litellm --config config.yaml --detailed_debug
```

### 3. Test request

**[Langchain, OpenAI SDK Usage Examples](../proxy/user_keys#request-format)**

<Tabs>
<TabItem label="Unsuccessful call" value="not-allowed">

Expect this to fail since `ishaan@berri.ai` in the request is PII

```shell
curl -i http://localhost:4000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -H "Authorization: Bearer sk-npnwjPQciVRok5yNZgKmFQ" \
  -d '{
    "model": "gpt-3.5-turbo",
    "messages": [
      {"role": "user", "content": "hi my email is ishaan@berri.ai"}
    ],
    "guardrails": ["bedrock-pre-guard"]
  }'
```

Expected response on failure

```shell
{
  "error": {
    "message": {
      "error": "Violated guardrail policy",
      "bedrock_guardrail_response": {
        "action": "GUARDRAIL_INTERVENED",
        "assessments": [
          {
            "topicPolicy": {
              "topics": [
                {
                  "action": "BLOCKED",
                  "name": "Coffee",
                  "type": "DENY"
                }
              ]
            }
          }
        ],
        "blockedResponse": "Sorry, the model cannot answer this question. coffee guardrail applied ",
        "output": [
          {
            "text": "Sorry, the model cannot answer this question. coffee guardrail applied "
          }
        ],
        "outputs": [
          {
            "text": "Sorry, the model cannot answer this question. coffee guardrail applied "
          }
        ],
        "usage": {
          "contentPolicyUnits": 0,
          "contextualGroundingPolicyUnits": 0,
          "sensitiveInformationPolicyFreeUnits": 0,
          "sensitiveInformationPolicyUnits": 0,
          "topicPolicyUnits": 1,
          "wordPolicyUnits": 0
        }
      }
    },
    "type": "None",
    "param": "None",
    "code": "400"
  }
}
```

</TabItem>

<TabItem label="Successful Call" value="allowed">

```shell
curl -i http://localhost:4000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -H "Authorization: Bearer sk-npnwjPQciVRok5yNZgKmFQ" \
  -d '{
    "model": "gpt-3.5-turbo",
    "messages": [
      {"role": "user", "content": "hi what is the weather"}
    ],
    "guardrails": ["bedrock-pre-guard"]
  }'
```

</TabItem>

</Tabs>
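For readers following the new doc, here is a minimal sketch of the same test request through the OpenAI Python SDK. It assumes the proxy is running on http://localhost:4000 with the virtual key used in the curl examples; the `guardrails` field is a LiteLLM-proxy extension, so it is passed via `extra_body`.

```python
# Hedged sketch: the "Unsuccessful call" request via the OpenAI Python SDK.
# Assumes a locally running LiteLLM proxy and the virtual key from the curl example.
from openai import OpenAI

client = OpenAI(
    api_key="sk-npnwjPQciVRok5yNZgKmFQ",  # virtual key from the example above
    base_url="http://localhost:4000",
)

response = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hi my email is ishaan@berri.ai"}],
    extra_body={"guardrails": ["bedrock-pre-guard"]},  # non-OpenAI field forwarded to the proxy
)
print(response)
```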
@@ -68,6 +68,15 @@ http://localhost:4000/metrics
| `litellm_total_tokens`                   | input + output tokens per `"user", "key", "model", "team", "end-user"`             |
| `litellm_llm_api_failed_requests_metric` | Number of failed LLM API requests per `"user", "key", "model", "team", "end-user"` |

### Request Latency Metrics

| Metric Name                            | Description                                                                                                    |
|----------------------------------------|----------------------------------------------------------------------------------------------------------------|
| `litellm_request_total_latency_metric` | Total latency (seconds) for a request to LiteLLM Proxy Server - tracked for labels `litellm_call_id`, `model` |
| `litellm_llm_api_latency_metric`       | latency (seconds) for just the LLM API call - tracked for labels `litellm_call_id`, `model`                   |

### LLM API / Provider Metrics

| Metric Name | Description |
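A small sketch of how the two new latency series can be pulled out of the `/metrics` endpoint documented above (assumption: the proxy is running locally with Prometheus logging enabled; `requests` is used only for illustration):

```python
# Hedged sketch: scrape the proxy's Prometheus endpoint and print the new latency series.
import requests

metrics_text = requests.get("http://localhost:4000/metrics").text
for line in metrics_text.splitlines():
    # histogram series are exposed as *_bucket, *_sum and *_count, all sharing these prefixes
    if line.startswith(("litellm_request_total_latency_metric", "litellm_llm_api_latency_metric")):
        print(line)
```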
@@ -54,7 +54,7 @@ const sidebars = {
      {
        type: "category",
        label: "🛡️ [Beta] Guardrails",
        items: ["proxy/guardrails/quick_start", "proxy/guardrails/aporia_api", "proxy/guardrails/lakera_ai"],
        items: ["proxy/guardrails/quick_start", "proxy/guardrails/aporia_api", "proxy/guardrails/lakera_ai", "proxy/guardrails/bedrock"],
      },
      {
        type: "category",

@@ -292,6 +292,7 @@ const sidebars = {
      items: [
        "projects/Docq.AI",
        "projects/OpenInterpreter",
        "projects/dbally",
        "projects/FastREPL",
        "projects/PROMPTMETHEUS",
        "projects/Codium PR Agent",
litellm-js/spend-logs/package-lock.json (generated, 8 lines changed)

@@ -6,7 +6,7 @@
    "": {
      "dependencies": {
        "@hono/node-server": "^1.10.1",
        "hono": "^4.2.7"
        "hono": "^4.5.8"
      },
      "devDependencies": {
        "@types/node": "^20.11.17",

@@ -463,9 +463,9 @@
      }
    },
    "node_modules/hono": {
      "version": "4.2.7",
      "resolved": "https://registry.npmjs.org/hono/-/hono-4.2.7.tgz",
      "integrity": "sha512-k1xHi86tJnRIVvqhFMBDGFKJ8r5O+bEsT4P59ZK59r0F300Xd910/r237inVfuT/VmE86RQQffX4OYNda6dLXw==",
      "version": "4.5.8",
      "resolved": "https://registry.npmjs.org/hono/-/hono-4.5.8.tgz",
      "integrity": "sha512-pqpSlcdqGkpTTRpLYU1PnCz52gVr0zVR9H5GzMyJWuKQLLEBQxh96q45QizJ2PPX8NATtz2mu31/PKW/Jt+90Q==",
      "engines": {
        "node": ">=16.0.0"
      }
@@ -4,7 +4,7 @@
  },
  "dependencies": {
    "@hono/node-server": "^1.10.1",
    "hono": "^4.2.7"
    "hono": "^4.5.8"
  },
  "devDependencies": {
    "@types/node": "^20.11.17",
@@ -7,13 +7,17 @@
#
# Thank you users! We ❤️ you! - Krrish & Ishaan

import inspect

# s/o [@Frank Colson](https://www.linkedin.com/in/frank-colson-422b9b183/) for this redis implementation
import os
import inspect
import redis, litellm  # type: ignore
import redis.asyncio as async_redis  # type: ignore
from typing import List, Optional

import redis  # type: ignore
import redis.asyncio as async_redis  # type: ignore

import litellm


def _get_redis_kwargs():
    arg_spec = inspect.getfullargspec(redis.Redis)

@@ -51,6 +55,19 @@ def _get_redis_url_kwargs(client=None):
    return available_args


def _get_redis_cluster_kwargs(client=None):
    if client is None:
        client = redis.Redis.from_url
    arg_spec = inspect.getfullargspec(redis.RedisCluster)

    # Only allow primitive arguments
    exclude_args = {"self", "connection_pool", "retry", "host", "port", "startup_nodes"}

    available_args = [x for x in arg_spec.args if x not in exclude_args]

    return available_args


def _get_redis_env_kwarg_mapping():
    PREFIX = "REDIS_"

@@ -124,6 +141,22 @@ def get_redis_client(**env_overrides):
                url_kwargs[arg] = redis_kwargs[arg]

        return redis.Redis.from_url(**url_kwargs)

    if "startup_nodes" in redis_kwargs:
        from redis.cluster import ClusterNode

        args = _get_redis_cluster_kwargs()
        cluster_kwargs = {}
        for arg in redis_kwargs:
            if arg in args:
                cluster_kwargs[arg] = redis_kwargs[arg]

        new_startup_nodes: List[ClusterNode] = []

        for item in redis_kwargs["startup_nodes"]:
            new_startup_nodes.append(ClusterNode(**item))
        redis_kwargs.pop("startup_nodes")
        return redis.RedisCluster(startup_nodes=new_startup_nodes, **cluster_kwargs)
    return redis.Redis(**redis_kwargs)

@@ -143,6 +176,24 @@ def get_redis_async_client(**env_overrides):
            )
        return async_redis.Redis.from_url(**url_kwargs)

    if "startup_nodes" in redis_kwargs:
        from redis.cluster import ClusterNode

        args = _get_redis_cluster_kwargs()
        cluster_kwargs = {}
        for arg in redis_kwargs:
            if arg in args:
                cluster_kwargs[arg] = redis_kwargs[arg]

        new_startup_nodes: List[ClusterNode] = []

        for item in redis_kwargs["startup_nodes"]:
            new_startup_nodes.append(ClusterNode(**item))
        redis_kwargs.pop("startup_nodes")
        return async_redis.RedisCluster(
            startup_nodes=new_startup_nodes, **cluster_kwargs
        )

    return async_redis.Redis(
        socket_timeout=5,
        **redis_kwargs,

@@ -160,4 +211,5 @@ def get_redis_connection_pool(**env_overrides):
        connection_class = async_redis.SSLConnection
        redis_kwargs.pop("ssl", None)
        redis_kwargs["connection_class"] = connection_class
    redis_kwargs.pop("startup_nodes", None)
    return async_redis.BlockingConnectionPool(timeout=5, **redis_kwargs)
@@ -203,6 +203,7 @@ class RedisCache(BaseCache):
        password=None,
        redis_flush_size=100,
        namespace: Optional[str] = None,
        startup_nodes: Optional[List] = None,  # for redis-cluster
        **kwargs,
    ):
        import redis

@@ -218,7 +219,8 @@ class RedisCache(BaseCache):
            redis_kwargs["port"] = port
        if password is not None:
            redis_kwargs["password"] = password

        if startup_nodes is not None:
            redis_kwargs["startup_nodes"] = startup_nodes
        ### HEALTH MONITORING OBJECT ###
        if kwargs.get("service_logger_obj", None) is not None and isinstance(
            kwargs["service_logger_obj"], ServiceLogging

@@ -246,7 +248,7 @@ class RedisCache(BaseCache):
        ### ASYNC HEALTH PING ###
        try:
            # asyncio.get_running_loop().create_task(self.ping())
            result = asyncio.get_running_loop().create_task(self.ping())
            _ = asyncio.get_running_loop().create_task(self.ping())
        except Exception as e:
            if "no running event loop" in str(e):
                verbose_logger.debug(

@@ -2123,6 +2125,7 @@ class Cache:
        redis_semantic_cache_use_async=False,
        redis_semantic_cache_embedding_model="text-embedding-ada-002",
        redis_flush_size=None,
        redis_startup_nodes: Optional[List] = None,
        disk_cache_dir=None,
        qdrant_api_base: Optional[str] = None,
        qdrant_api_key: Optional[str] = None,

@@ -2155,7 +2158,12 @@ class Cache:
        """
        if type == "redis":
            self.cache: BaseCache = RedisCache(
                host, port, password, redis_flush_size, **kwargs
                host,
                port,
                password,
                redis_flush_size,
                startup_nodes=redis_startup_nodes,
                **kwargs,
            )
        elif type == "redis-semantic":
            self.cache = RedisSemanticCache(
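A hedged sketch of how the new `redis_startup_nodes` plumbing can be used from client code (it mirrors the unit test added later in this diff; the host and port are placeholders for a locally running Redis cluster):

```python
# Hedged sketch: point LiteLLM's cache at a Redis cluster instead of a single node.
import litellm

litellm.cache = litellm.Cache(
    type="redis",
    redis_startup_nodes=[{"host": "127.0.0.1", "port": "7001"}],  # placeholder cluster node
)
```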
@@ -60,6 +60,25 @@ class PrometheusLogger(CustomLogger):
            ],
        )

        # request latency metrics
        self.litellm_request_total_latency_metric = Histogram(
            "litellm_request_total_latency_metric",
            "Total latency (seconds) for a request to LiteLLM",
            labelnames=[
                "model",
                "litellm_call_id",
            ],
        )

        self.litellm_llm_api_latency_metric = Histogram(
            "litellm_llm_api_latency_metric",
            "Total latency (seconds) for a models LLM API call",
            labelnames=[
                "model",
                "litellm_call_id",
            ],
        )

        # Counter for spend
        self.litellm_spend_metric = Counter(
            "litellm_spend_metric",

@@ -103,8 +122,6 @@ class PrometheusLogger(CustomLogger):
            "Remaining budget for api key",
            labelnames=["hashed_api_key", "api_key_alias"],
        )
        # Litellm-Enterprise Metrics
        if premium_user is True:

        ########################################
        # LiteLLM Virtual API KEY metrics

@@ -123,6 +140,9 @@ class PrometheusLogger(CustomLogger):
            labelnames=["hashed_api_key", "api_key_alias", "model"],
        )

        # Litellm-Enterprise Metrics
        if premium_user is True:

            ########################################
            # LLM API Deployment Metrics / analytics
            ########################################

@@ -328,6 +348,25 @@ class PrometheusLogger(CustomLogger):
            user_api_key, user_api_key_alias, model_group
        ).set(remaining_tokens)

        # latency metrics
        total_time: timedelta = kwargs.get("end_time") - kwargs.get("start_time")
        total_time_seconds = total_time.total_seconds()
        api_call_total_time: timedelta = kwargs.get("end_time") - kwargs.get(
            "api_call_start_time"
        )

        api_call_total_time_seconds = api_call_total_time.total_seconds()

        litellm_call_id = kwargs.get("litellm_call_id")

        self.litellm_request_total_latency_metric.labels(
            model, litellm_call_id
        ).observe(total_time_seconds)

        self.litellm_llm_api_latency_metric.labels(model, litellm_call_id).observe(
            api_call_total_time_seconds
        )

        # set x-ratelimit headers
        if premium_user is True:
            self.set_llm_deployment_success_metrics(
@@ -354,6 +354,8 @@ class Logging:
                        str(e)
                    )
                )

        self.model_call_details["api_call_start_time"] = datetime.datetime.now()
        # Input Integration Logging -> If you want to log the fact that an attempt to call the model was made
        callbacks = litellm.input_callback + self.dynamic_input_callbacks
        for callback in callbacks:
@@ -84,6 +84,7 @@ BEDROCK_CONVERSE_MODELS = [
    "meta.llama3-1-8b-instruct-v1:0",
    "meta.llama3-1-70b-instruct-v1:0",
    "meta.llama3-1-405b-instruct-v1:0",
    "meta.llama3-70b-instruct-v1:0",
    "mistral.mistral-large-2407-v1:0",
]

@@ -1480,7 +1481,7 @@ class BedrockConverseLLM(BaseAWSLLM):
        optional_params: dict,
        acompletion: bool,
        timeout: Optional[Union[float, httpx.Timeout]],
        litellm_params=None,
        litellm_params: dict,
        logger_fn=None,
        extra_headers: Optional[dict] = None,
        client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,

@@ -1596,6 +1597,14 @@ class BedrockConverseLLM(BaseAWSLLM):
        supported_tool_call_params = ["tools", "tool_choice"]
        supported_guardrail_params = ["guardrailConfig"]
        ## TRANSFORMATION ##

        bedrock_messages: List[MessageBlock] = _bedrock_converse_messages_pt(
            messages=messages,
            model=model,
            llm_provider="bedrock_converse",
            user_continue_message=litellm_params.pop("user_continue_message", None),
        )

        # send all model-specific params in 'additional_request_params'
        for k, v in inference_params.items():
            if (

@@ -1608,11 +1617,6 @@ class BedrockConverseLLM(BaseAWSLLM):
        for key in additional_request_keys:
            inference_params.pop(key, None)

        bedrock_messages: List[MessageBlock] = _bedrock_converse_messages_pt(
            messages=messages,
            model=model,
            llm_provider="bedrock_converse",
        )
        bedrock_tools: List[ToolBlock] = _bedrock_tools_pt(
            inference_params.pop("tools", [])
        )
@@ -124,12 +124,14 @@ class CohereConfig:
        }


def validate_environment(api_key):
    headers = {
def validate_environment(api_key, headers: dict):
    headers.update(
        {
            "Request-Source": "unspecified:litellm",
            "accept": "application/json",
            "content-type": "application/json",
        }
    )
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"
    return headers

@@ -144,11 +146,12 @@ def completion(
    encoding,
    api_key,
    logging_obj,
    headers: dict,
    optional_params=None,
    litellm_params=None,
    logger_fn=None,
):
    headers = validate_environment(api_key)
    headers = validate_environment(api_key, headers=headers)
    completion_url = api_base
    model = model
    prompt = " ".join(message["content"] for message in messages)

@@ -338,13 +341,14 @@ def embedding(
    model_response: litellm.EmbeddingResponse,
    logging_obj: LiteLLMLoggingObj,
    optional_params: dict,
    headers: dict,
    encoding: Any,
    api_key: Optional[str] = None,
    aembedding: Optional[bool] = None,
    timeout: Union[float, httpx.Timeout] = httpx.Timeout(None),
    client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
):
    headers = validate_environment(api_key)
    headers = validate_environment(api_key, headers=headers)
    embed_url = "https://api.cohere.ai/v1/embed"
    model = model
    data = {"model": model, "texts": input, **optional_params}
@@ -116,12 +116,14 @@ class CohereChatConfig:
        }


def validate_environment(api_key):
    headers = {
def validate_environment(api_key, headers: dict):
    headers.update(
        {
            "Request-Source": "unspecified:litellm",
            "accept": "application/json",
            "content-type": "application/json",
        }
    )
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"
    return headers

@@ -203,13 +205,14 @@ def completion(
    model_response: ModelResponse,
    print_verbose: Callable,
    optional_params: dict,
    headers: dict,
    encoding,
    api_key,
    logging_obj,
    litellm_params=None,
    logger_fn=None,
):
    headers = validate_environment(api_key)
    headers = validate_environment(api_key, headers=headers)
    completion_url = api_base
    model = model
    most_recent_message, chat_history = cohere_messages_pt_v2(
@@ -38,6 +38,18 @@ def prompt_injection_detection_default_pt():

BAD_MESSAGE_ERROR_STR = "Invalid Message "

# used to interweave user messages, to ensure user/assistant alternating
DEFAULT_USER_CONTINUE_MESSAGE = {
    "role": "user",
    "content": "Please continue.",
}  # similar to autogen. Only used if `litellm.modify_params=True`.

# used to interweave assistant messages, to ensure user/assistant alternating
DEFAULT_ASSISTANT_CONTINUE_MESSAGE = {
    "role": "assistant",
    "content": "Please continue.",
}  # similar to autogen. Only used if `litellm.modify_params=True`.


def map_system_message_pt(messages: list) -> list:
    """

@@ -2254,6 +2266,7 @@ def _bedrock_converse_messages_pt(
    messages: List,
    model: str,
    llm_provider: str,
    user_continue_message: Optional[dict] = None,
) -> List[BedrockMessageBlock]:
    """
    Converts given messages from OpenAI format to Bedrock format

@@ -2264,6 +2277,21 @@ def _bedrock_converse_messages_pt(

    contents: List[BedrockMessageBlock] = []
    msg_i = 0

    # if initial message is assistant message
    if messages[0].get("role") is not None and messages[0]["role"] == "assistant":
        if user_continue_message is not None:
            messages.insert(0, user_continue_message)
        elif litellm.modify_params:
            messages.insert(0, DEFAULT_USER_CONTINUE_MESSAGE)

    # if final message is assistant message
    if messages[-1].get("role") is not None and messages[-1]["role"] == "assistant":
        if user_continue_message is not None:
            messages.append(user_continue_message)
        elif litellm.modify_params:
            messages.append(DEFAULT_USER_CONTINUE_MESSAGE)

    while msg_i < len(messages):
        user_content: List[BedrockContentBlock] = []
        init_msg_i = msg_i

@@ -2344,6 +2372,7 @@ def _bedrock_converse_messages_pt(
                model=model,
                llm_provider=llm_provider,
            )

    return contents
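A hedged sketch of the new `user_continue_message` parameter from the caller's side. The model id is one of the Bedrock Converse models listed earlier in this diff; when the parameter is omitted, the same behavior falls back to `DEFAULT_USER_CONTINUE_MESSAGE` only if `litellm.modify_params=True`.

```python
# Hedged sketch: the conversation ends on an assistant turn, which Bedrock's
# Converse API rejects; user_continue_message appends a trailing user turn.
from litellm import completion

response = completion(
    model="bedrock/mistral.mistral-large-2407-v1:0",  # example Bedrock model id from this diff
    messages=[
        {"role": "user", "content": "hey, how's it going?"},
        {"role": "assistant", "content": "Doing well! Anything else?"},
    ],
    user_continue_message={"role": "user", "content": "Please continue."},
)
```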
@@ -944,6 +944,8 @@ def completion(
            cooldown_time=cooldown_time,
            text_completion=kwargs.get("text_completion"),
            azure_ad_token_provider=kwargs.get("azure_ad_token_provider"),
            user_continue_message=kwargs.get("user_continue_message"),
        )
        logging.update_environment_variables(
            model=model,

@@ -1635,6 +1637,13 @@ def completion(
            or "https://api.cohere.ai/v1/generate"
        )

        headers = headers or litellm.headers or {}
        if headers is None:
            headers = {}

        if extra_headers is not None:
            headers.update(extra_headers)

        model_response = cohere.completion(
            model=model,
            messages=messages,

@@ -1645,6 +1654,7 @@ def completion(
            litellm_params=litellm_params,
            logger_fn=logger_fn,
            encoding=encoding,
            headers=headers,
            api_key=cohere_key,
            logging_obj=logging,  # model call logging done inside the class as we make need to modify I/O to fit aleph alpha's requirements
        )

@@ -1675,6 +1685,13 @@ def completion(
            or "https://api.cohere.ai/v1/chat"
        )

        headers = headers or litellm.headers or {}
        if headers is None:
            headers = {}

        if extra_headers is not None:
            headers.update(extra_headers)

        model_response = cohere_chat.completion(
            model=model,
            messages=messages,

@@ -1683,6 +1700,7 @@ def completion(
            print_verbose=print_verbose,
            optional_params=optional_params,
            litellm_params=litellm_params,
            headers=headers,
            logger_fn=logger_fn,
            encoding=encoding,
            api_key=cohere_key,

@@ -2289,7 +2307,7 @@ def completion(
            model_response=model_response,
            print_verbose=print_verbose,
            optional_params=optional_params,
            litellm_params=litellm_params,
            litellm_params=litellm_params,  # type: ignore
            logger_fn=logger_fn,
            encoding=encoding,
            logging_obj=logging,

@@ -3159,6 +3177,7 @@ def embedding(
    encoding_format = kwargs.get("encoding_format", None)
    proxy_server_request = kwargs.get("proxy_server_request", None)
    aembedding = kwargs.get("aembedding", None)
    extra_headers = kwargs.get("extra_headers", None)
    ### CUSTOM MODEL COST ###
    input_cost_per_token = kwargs.get("input_cost_per_token", None)
    output_cost_per_token = kwargs.get("output_cost_per_token", None)

@@ -3234,6 +3253,7 @@ def embedding(
        "tenant_id",
        "client_id",
        "client_secret",
        "extra_headers",
    ]
    default_params = openai_params + litellm_params
    non_default_params = {

@@ -3297,7 +3317,7 @@ def embedding(
            "cooldown_time": cooldown_time,
        },
    )
    if azure == True or custom_llm_provider == "azure":
    if azure is True or custom_llm_provider == "azure":
        # azure configs
        api_type = get_secret("AZURE_API_TYPE") or "azure"

@@ -3403,12 +3423,18 @@ def embedding(
            or get_secret("CO_API_KEY")
            or litellm.api_key
        )

        if extra_headers is not None and isinstance(extra_headers, dict):
            headers = extra_headers
        else:
            headers = {}
        response = cohere.embedding(
            model=model,
            input=input,
            optional_params=optional_params,
            encoding=encoding,
            api_key=cohere_key,  # type: ignore
            headers=headers,
            logging_obj=logging,
            model_response=EmbeddingResponse(),
            aembedding=aembedding,
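A hedged sketch of the Cohere `extra_headers` pass-through these hunks add (the header name is copied from the test later in this diff; any extra HTTP header should flow through the same way):

```python
# Hedged sketch: custom headers are merged into the Cohere request headers
# via validate_environment(api_key, headers=headers).
import litellm

response = litellm.completion(
    model="command-r",
    messages=[{"role": "user", "content": "hello from litellm"}],
    extra_headers={"Helicone-Property-Locale": "ko"},  # forwarded to the Cohere HTTP request
)
```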
@@ -2,12 +2,3 @@ model_list:
  - model_name: "*"
    litellm_params:
      model: "*"

litellm_settings:
  success_callback: ["s3"]
  cache: true
  s3_callback_params:
    s3_bucket_name: mytestbucketlitellm # AWS Bucket Name for S3
    s3_region_name: us-west-2 # AWS Region Name for S3
    s3_aws_access_key_id: os.environ/AWS_ACCESS_KEY_ID # us os.environ/<variable name> to pass environment variables. This is AWS Access Key ID for S3
    s3_aws_secret_access_key: os.environ/AWS_SECRET_ACCESS_KEY # AWS Secret Access Key for S3
@@ -66,7 +66,7 @@ def common_checks(
        raise Exception(
            f"Team={team_object.team_id} is blocked. Update via `/team/unblock` if your admin."
        )
    # 2. If user can call model
    # 2. If team can call model
    if (
        _model is not None
        and team_object is not None

@@ -74,7 +74,11 @@ def common_checks(
        and _model not in team_object.models
    ):
        # this means the team has access to all models on the proxy
        if "all-proxy-models" in team_object.models:
        if (
            "all-proxy-models" in team_object.models
            or "*" in team_object.models
            or "openai/*" in team_object.models
        ):
            # this means the team has access to all models on the proxy
            pass
        # check if the team model is an access_group
@@ -22,3 +22,9 @@ guardrails:
      mode: "post_call"
      api_key: os.environ/APORIA_API_KEY_2
      api_base: os.environ/APORIA_API_BASE_2
  - guardrail_name: "bedrock-pre-guard"
    litellm_params:
      guardrail: bedrock  # supported values: "aporia", "bedrock", "lakera"
      mode: "pre_call"
      guardrailIdentifier: ff6ujrregl1q
      guardrailVersion: "DRAFT"
litellm/proxy/guardrails/guardrail_hooks/bedrock_guardrails.py (new file, 289 lines)

@@ -0,0 +1,289 @@
# +-------------------------------------------------------------+
#
# Use Bedrock Guardrails for your LLM calls
#
# +-------------------------------------------------------------+
# Thank you users! We ❤️ you! - Krrish & Ishaan

import os
import sys

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import asyncio
import json
import sys
import traceback
import uuid
from datetime import datetime
from typing import Any, Dict, List, Literal, Optional, Union

import aiohttp
import httpx
from fastapi import HTTPException

import litellm
from litellm import get_secret
from litellm._logging import verbose_proxy_logger
from litellm.caching import DualCache
from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.litellm_core_utils.logging_utils import (
    convert_litellm_response_object_to_str,
)
from litellm.llms.base_aws_llm import BaseAWSLLM
from litellm.llms.custom_httpx.http_handler import (
    AsyncHTTPHandler,
    _get_async_httpx_client,
)
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.guardrails.guardrail_helpers import should_proceed_based_on_metadata
from litellm.types.guardrails import (
    BedrockContentItem,
    BedrockRequest,
    BedrockTextContent,
    GuardrailEventHooks,
)

GUARDRAIL_NAME = "bedrock"


class BedrockGuardrail(CustomGuardrail, BaseAWSLLM):
    def __init__(
        self,
        guardrailIdentifier: Optional[str] = None,
        guardrailVersion: Optional[str] = None,
        **kwargs,
    ):
        self.async_handler = _get_async_httpx_client()
        self.guardrailIdentifier = guardrailIdentifier
        self.guardrailVersion = guardrailVersion

        # store kwargs as optional_params
        self.optional_params = kwargs

        super().__init__(**kwargs)

    def convert_to_bedrock_format(
        self,
        messages: Optional[List[Dict[str, str]]] = None,
        response: Optional[Union[Any, litellm.ModelResponse]] = None,
    ) -> BedrockRequest:
        bedrock_request: BedrockRequest = BedrockRequest(source="INPUT")
        bedrock_request_content: List[BedrockContentItem] = []

        if messages:
            for message in messages:
                content = message.get("content")
                if isinstance(content, str):
                    bedrock_content_item = BedrockContentItem(
                        text=BedrockTextContent(text=content)
                    )
                    bedrock_request_content.append(bedrock_content_item)

            bedrock_request["content"] = bedrock_request_content
        if response:
            bedrock_request["source"] = "OUTPUT"
            if isinstance(response, litellm.ModelResponse):
                for choice in response.choices:
                    if isinstance(choice, litellm.Choices):
                        if choice.message.content and isinstance(
                            choice.message.content, str
                        ):
                            bedrock_content_item = BedrockContentItem(
                                text=BedrockTextContent(text=choice.message.content)
                            )
                            bedrock_request_content.append(bedrock_content_item)
                bedrock_request["content"] = bedrock_request_content
        return bedrock_request

    #### CALL HOOKS - proxy only ####
    def _load_credentials(
        self,
    ):
        try:
            from botocore.credentials import Credentials
        except ImportError as e:
            raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
        ## CREDENTIALS ##
        # pop aws_secret_access_key, aws_access_key_id, aws_session_token, aws_region_name from kwargs, since completion calls fail with them
        aws_secret_access_key = self.optional_params.pop("aws_secret_access_key", None)
        aws_access_key_id = self.optional_params.pop("aws_access_key_id", None)
        aws_session_token = self.optional_params.pop("aws_session_token", None)
        aws_region_name = self.optional_params.pop("aws_region_name", None)
        aws_role_name = self.optional_params.pop("aws_role_name", None)
        aws_session_name = self.optional_params.pop("aws_session_name", None)
        aws_profile_name = self.optional_params.pop("aws_profile_name", None)
        aws_bedrock_runtime_endpoint = self.optional_params.pop(
            "aws_bedrock_runtime_endpoint", None
        )  # https://bedrock-runtime.{region_name}.amazonaws.com
        aws_web_identity_token = self.optional_params.pop(
            "aws_web_identity_token", None
        )
        aws_sts_endpoint = self.optional_params.pop("aws_sts_endpoint", None)

        ### SET REGION NAME ###
        if aws_region_name is None:
            # check env #
            litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)

            if litellm_aws_region_name is not None and isinstance(
                litellm_aws_region_name, str
            ):
                aws_region_name = litellm_aws_region_name

            standard_aws_region_name = get_secret("AWS_REGION", None)
            if standard_aws_region_name is not None and isinstance(
                standard_aws_region_name, str
            ):
                aws_region_name = standard_aws_region_name

            if aws_region_name is None:
                aws_region_name = "us-west-2"

        credentials: Credentials = self.get_credentials(
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
            aws_session_token=aws_session_token,
            aws_region_name=aws_region_name,
            aws_session_name=aws_session_name,
            aws_profile_name=aws_profile_name,
            aws_role_name=aws_role_name,
            aws_web_identity_token=aws_web_identity_token,
            aws_sts_endpoint=aws_sts_endpoint,
        )
        return credentials, aws_region_name

    def _prepare_request(
        self,
        credentials,
        data: BedrockRequest,
        optional_params: dict,
        aws_region_name: str,
        extra_headers: Optional[dict] = None,
    ):
        try:
            import boto3
            from botocore.auth import SigV4Auth
            from botocore.awsrequest import AWSRequest
            from botocore.credentials import Credentials
        except ImportError as e:
            raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")

        sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name)
        api_base = f"https://bedrock-runtime.{aws_region_name}.amazonaws.com/guardrail/{self.guardrailIdentifier}/version/{self.guardrailVersion}/apply"

        encoded_data = json.dumps(data).encode("utf-8")
        headers = {"Content-Type": "application/json"}
        if extra_headers is not None:
            headers = {"Content-Type": "application/json", **extra_headers}

        request = AWSRequest(
            method="POST", url=api_base, data=encoded_data, headers=headers
        )
        sigv4.add_auth(request)
        prepped_request = request.prepare()

        return prepped_request

    async def make_bedrock_api_request(
        self, kwargs: dict, response: Optional[Union[Any, litellm.ModelResponse]] = None
    ):

        credentials, aws_region_name = self._load_credentials()
        request_data: BedrockRequest = self.convert_to_bedrock_format(
            messages=kwargs.get("messages"), response=response
        )
        prepared_request = self._prepare_request(
            credentials=credentials,
            data=request_data,
            optional_params=self.optional_params,
            aws_region_name=aws_region_name,
        )
        verbose_proxy_logger.debug(
            "Bedrock AI request body: %s, url %s, headers: %s",
            request_data,
            prepared_request.url,
            prepared_request.headers,
        )
        _json_data = json.dumps(request_data)  # type: ignore
        response = await self.async_handler.post(
            url=prepared_request.url,
            json=request_data,  # type: ignore
            headers=prepared_request.headers,
        )
        verbose_proxy_logger.debug("Bedrock AI response: %s", response.text)
        if response.status_code == 200:
            # check if the response was flagged
            _json_response = response.json()
            if _json_response.get("action") == "GUARDRAIL_INTERVENED":
                raise HTTPException(
                    status_code=400,
                    detail={
                        "error": "Violated guardrail policy",
                        "bedrock_guardrail_response": _json_response,
                    },
                )
        else:
            verbose_proxy_logger.error(
                "Bedrock AI: error in response. Status code: %s, response: %s",
                response.status_code,
                response.text,
            )

    async def async_moderation_hook(  ### 👈 KEY CHANGE ###
        self,
        data: dict,
        user_api_key_dict: UserAPIKeyAuth,
        call_type: Literal["completion", "embeddings", "image_generation"],
    ):
        from litellm.proxy.common_utils.callback_utils import (
            add_guardrail_to_applied_guardrails_header,
        )
        from litellm.types.guardrails import GuardrailEventHooks

        event_type: GuardrailEventHooks = GuardrailEventHooks.during_call
        if self.should_run_guardrail(data=data, event_type=event_type) is not True:
            return

        new_messages: Optional[List[dict]] = data.get("messages")
        if new_messages is not None:
            await self.make_bedrock_api_request(kwargs=data)
            add_guardrail_to_applied_guardrails_header(
                request_data=data, guardrail_name=self.guardrail_name
            )
        else:
            verbose_proxy_logger.warning(
                "Bedrock AI: not running guardrail. No messages in data"
            )
            pass

    async def async_post_call_success_hook(
        self,
        data: dict,
        user_api_key_dict: UserAPIKeyAuth,
        response,
    ):
        from litellm.proxy.common_utils.callback_utils import (
            add_guardrail_to_applied_guardrails_header,
        )
        from litellm.types.guardrails import GuardrailEventHooks

        if (
            self.should_run_guardrail(
                data=data, event_type=GuardrailEventHooks.post_call
            )
            is not True
        ):
            return

        new_messages: Optional[List[dict]] = data.get("messages")
        if new_messages is not None:
            await self.make_bedrock_api_request(kwargs=data, response=response)
            add_guardrail_to_applied_guardrails_header(
                request_data=data, guardrail_name=self.guardrail_name
            )
        else:
            verbose_proxy_logger.warning(
                "Bedrock AI: not running guardrail. No messages in data"
            )
@@ -96,8 +96,10 @@ def init_guardrails_v2(all_guardrails: dict):
            litellm_params = LitellmParams(
                guardrail=litellm_params_data["guardrail"],
                mode=litellm_params_data["mode"],
                api_key=litellm_params_data["api_key"],
                api_base=litellm_params_data["api_base"],
                api_key=litellm_params_data.get("api_key"),
                api_base=litellm_params_data.get("api_base"),
                guardrailIdentifier=litellm_params_data.get("guardrailIdentifier"),
                guardrailVersion=litellm_params_data.get("guardrailVersion"),
            )

            if (

@@ -134,6 +136,18 @@ def init_guardrails_v2(all_guardrails: dict):
                event_hook=litellm_params["mode"],
            )
            litellm.callbacks.append(_aporia_callback)  # type: ignore
        if litellm_params["guardrail"] == "bedrock":
            from litellm.proxy.guardrails.guardrail_hooks.bedrock_guardrails import (
                BedrockGuardrail,
            )

            _bedrock_callback = BedrockGuardrail(
                guardrail_name=guardrail["guardrail_name"],
                event_hook=litellm_params["mode"],
                guardrailIdentifier=litellm_params["guardrailIdentifier"],
                guardrailVersion=litellm_params["guardrailVersion"],
            )
            litellm.callbacks.append(_bedrock_callback)  # type: ignore
        elif litellm_params["guardrail"] == "lakera":
            from litellm.proxy.guardrails.guardrail_hooks.lakera_ai import (
                lakeraAI_Moderation,
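For orientation, a hedged sketch of what `init_guardrails_v2` effectively does for a `guardrail: bedrock` config entry (values are taken from the example config in this PR; in practice this wiring happens through the proxy config rather than by hand):

```python
# Hedged sketch: manual equivalent of the bedrock branch above.
import litellm
from litellm.proxy.guardrails.guardrail_hooks.bedrock_guardrails import BedrockGuardrail

_bedrock_callback = BedrockGuardrail(
    guardrail_name="bedrock-pre-guard",
    event_hook="during_call",
    guardrailIdentifier="ff6ujrregl1q",  # guardrail ID from the example config
    guardrailVersion="DRAFT",
)
litellm.callbacks.append(_bedrock_callback)
```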
@@ -1,5 +1,5 @@
model_list:
  - model_name: gpt-4
  - model_name: fake-openai-endpoint
    litellm_params:
      model: azure/chatgpt-v-2
      api_base: https://openai-gpt-4-test-v-1.openai.azure.com/

@@ -9,13 +9,9 @@ model_list:
      client_secret: os.environ/AZURE_CLIENT_SECRET

guardrails:
  - guardrail_name: "lakera-pre-guard"
  - guardrail_name: "bedrock-pre-guard"
    litellm_params:
      guardrail: lakera  # supported values: "aporia", "bedrock", "lakera"
      mode: "during_call"
      api_key: os.environ/LAKERA_API_KEY
      api_base: os.environ/LAKERA_API_BASE
      category_thresholds:
        prompt_injection: 0.1
        jailbreak: 0.1

      guardrail: bedrock  # supported values: "aporia", "bedrock", "lakera"
      mode: "post_call"
      guardrailIdentifier: ff6ujrregl1q
      guardrailVersion: "DRAFT"
@@ -1588,7 +1588,7 @@ class ProxyConfig:
                    verbose_proxy_logger.debug(  # noqa
                        f"{blue_color_code}Set Cache on LiteLLM Proxy: {vars(litellm.cache.cache)}{reset_color_code}"
                    )
                elif key == "cache" and value == False:
                elif key == "cache" and value is False:
                    pass
                elif key == "guardrails":
                    if premium_user is not True:

@@ -2672,6 +2672,13 @@ def giveup(e):
        and isinstance(e.message, str)
        and "Max parallel request limit reached" in e.message
    )

    if (
        general_settings.get("disable_retry_on_max_parallel_request_limit_error")
        is True
    ):
        return True  # giveup if queuing max parallel request limits is disabled

    if result:
        verbose_proxy_logger.info(json.dumps({"event": "giveup", "exception": str(e)}))
    return result
@@ -277,7 +277,8 @@ class Router:
            "local"  # default to an in-memory cache
        )
        redis_cache = None
        cache_config = {}
        cache_config: Dict[str, Any] = {}

        self.client_ttl = client_ttl
        if redis_url is not None or (
            redis_host is not None
@@ -738,8 +738,9 @@ def test_bedrock_system_prompt(system, model):
        "temperature": 0.3,
        "messages": [
            {"role": "system", "content": system},
            {"role": "user", "content": "hey, how's it going?"},
            {"role": "assistant", "content": "hey, how's it going?"},
        ],
        "user_continue_message": {"role": "user", "content": "Be a good bot!"},
    }
    response: ModelResponse = completion(
        model="bedrock/{}".format(model),
@@ -804,6 +804,38 @@ def test_redis_cache_completion_stream():
# test_redis_cache_completion_stream()


@pytest.mark.skip(reason="Local test. Requires running redis cluster locally.")
@pytest.mark.asyncio
async def test_redis_cache_cluster_init_unit_test():
    try:
        from redis.asyncio import RedisCluster as AsyncRedisCluster
        from redis.cluster import RedisCluster

        from litellm.caching import RedisCache

        litellm.set_verbose = True

        # List of startup nodes
        startup_nodes = [
            {"host": "127.0.0.1", "port": "7001"},
        ]

        resp = RedisCache(startup_nodes=startup_nodes)

        assert isinstance(resp.redis_client, RedisCluster)
        assert isinstance(resp.init_async_client(), AsyncRedisCluster)

        resp = litellm.Cache(type="redis", redis_startup_nodes=startup_nodes)

        assert isinstance(resp.cache, RedisCache)
        assert isinstance(resp.cache.redis_client, RedisCluster)
        assert isinstance(resp.cache.init_async_client(), AsyncRedisCluster)

    except Exception as e:
        print(f"{str(e)}\n\n{traceback.format_exc()}")
        raise e


@pytest.mark.asyncio
async def test_redis_cache_acompletion_stream():
    try:
@@ -3653,6 +3653,7 @@ def test_completion_cohere():
        response = completion(
            model="command-r",
            messages=messages,
            extra_headers={"Helicone-Property-Locale": "ko"},
        )
        print(response)
    except Exception as e:
@@ -1,5 +1,5 @@
from enum import Enum
from typing import Dict, List, Optional, TypedDict
from typing import Dict, List, Literal, Optional, TypedDict

from pydantic import BaseModel, ConfigDict
from typing_extensions import Required, TypedDict

@@ -76,8 +76,14 @@ class LitellmParams(TypedDict, total=False):
    mode: str
    api_key: str
    api_base: Optional[str]

    # Lakera specific params
    category_thresholds: Optional[LakeraCategoryThresholds]

    # Bedrock specific params
    guardrailIdentifier: Optional[str]
    guardrailVersion: Optional[str]


class Guardrail(TypedDict):
    guardrail_name: str

@@ -92,3 +98,16 @@ class GuardrailEventHooks(str, Enum):
    pre_call = "pre_call"
    post_call = "post_call"
    during_call = "during_call"


class BedrockTextContent(TypedDict, total=False):
    text: str


class BedrockContentItem(TypedDict, total=False):
    text: BedrockTextContent


class BedrockRequest(TypedDict, total=False):
    source: Literal["INPUT", "OUTPUT"]
    content: List[BedrockContentItem]
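A short sketch of the payload shape these new TypedDicts describe; this mirrors what `convert_to_bedrock_format` builds for a single user message before posting it to the guardrail `apply` endpoint:

```python
# Hedged sketch: a BedrockRequest as produced for one user message.
from litellm.types.guardrails import BedrockContentItem, BedrockRequest, BedrockTextContent

request: BedrockRequest = BedrockRequest(
    source="INPUT",  # "OUTPUT" when checking a model response instead
    content=[BedrockContentItem(text=BedrockTextContent(text="hi my email is ishaan@berri.ai"))],
)
```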
@@ -1120,6 +1120,7 @@ all_litellm_params = [
    "tenant_id",
    "client_id",
    "client_secret",
    "user_continue_message",
]


@@ -2324,6 +2324,7 @@ def get_litellm_params(
    cooldown_time=None,
    text_completion=None,
    azure_ad_token_provider=None,
    user_continue_message=None,
):
    litellm_params = {
        "acompletion": acompletion,

@@ -2349,6 +2350,7 @@ def get_litellm_params(
        "cooldown_time": cooldown_time,
        "text_completion": text_completion,
        "azure_ad_token_provider": azure_ad_token_provider,
        "user_continue_message": user_continue_message,
    }

    return litellm_params

@@ -4221,6 +4223,7 @@ def get_supported_openai_params(
            "presence_penalty",
            "stop",
            "n",
            "extra_headers",
        ]
    elif custom_llm_provider == "cohere_chat":
        return [

@@ -4235,6 +4238,7 @@ def get_supported_openai_params(
            "tools",
            "tool_choice",
            "seed",
            "extra_headers",
        ]
    elif custom_llm_provider == "maritalk":
        return [

@@ -7123,6 +7127,14 @@ def exception_type(
                            llm_provider="bedrock",
                            response=original_exception.response,
                        )
                    elif "A conversation must start with a user message." in error_str:
                        exception_mapping_worked = True
                        raise BadRequestError(
                            message=f"BedrockException - {error_str}\n. Pass in default user message via `completion(..,user_continue_message=)` or enable `litellm.modify_params=True`.\nFor Proxy: do via `litellm_settings::modify_params: True` or user_continue_message under `litellm_params`",
                            model=model,
                            llm_provider="bedrock",
                            response=original_exception.response,
                        )
                    elif (
                        "Unable to locate credentials" in error_str
                        or "The security token included in the request is invalid"
@@ -144,6 +144,7 @@ async def test_no_llm_guard_triggered():

    assert "x-litellm-applied-guardrails" not in headers


@pytest.mark.asyncio
async def test_guardrails_with_api_key_controls():
    """

@@ -194,3 +195,25 @@ async def test_guardrails_with_api_key_controls():
        except Exception as e:
            print(e)
            assert "Aporia detected and blocked PII" in str(e)


@pytest.mark.asyncio
async def test_bedrock_guardrail_triggered():
    """
    - Tests a request where our bedrock guardrail should be triggered
    - Assert that the guardrails applied are returned in the response headers
    """
    async with aiohttp.ClientSession() as session:
        try:
            response, headers = await chat_completion(
                session,
                "sk-1234",
                model="fake-openai-endpoint",
                messages=[{"role": "user", "content": f"Hello do you like coffee?"}],
                guardrails=["bedrock-pre-guard"],
            )
            pytest.fail("Should have thrown an exception")
        except Exception as e:
            print(e)
            assert "GUARDRAIL_INTERVENED" in str(e)
            assert "Violated guardrail policy" in str(e)