Merge branch 'main' into litellm_allow_using_azure_ad_token_auth

Ishaan Jaff 2024-08-22 18:21:24 -07:00 committed by GitHub
commit 228252b92d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
33 changed files with 802 additions and 84 deletions

View file

@ -321,6 +321,9 @@ jobs:
-e APORIA_API_BASE_2=$APORIA_API_BASE_2 \
-e APORIA_API_KEY_2=$APORIA_API_KEY_2 \
-e APORIA_API_BASE_1=$APORIA_API_BASE_1 \
-e AWS_ACCESS_KEY_ID=$AWS_ACCESS_KEY_ID \
-e AWS_SECRET_ACCESS_KEY=$AWS_SECRET_ACCESS_KEY \
-e AWS_REGION_NAME=$AWS_REGION_NAME \
-e APORIA_API_KEY_1=$APORIA_API_KEY_1 \
--name my-app \
-v $(pwd)/litellm/proxy/example_config_yaml/otel_test_config.yaml:/app/config.yaml \

View file

@ -0,0 +1,3 @@
Efficient, consistent and secure library for querying structured data with natural language. Query any database with over 100 LLMs ❤️ 🚅.
🔗 [GitHub](https://github.com/deepsense-ai/db-ally)

View file

@ -307,8 +307,9 @@ LiteLLM supports **ALL** Azure AI models. Here are a few examples:
| Model Name | Function Call |
|--------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| Cohere command-r-plus | `completion(model="azure/command-r-plus", messages)` |
| Cohere command-r | `completion(model="azure/command-r", messages)` |
| mistral-large-latest | `completion(model="azure/mistral-large-latest", messages)` |
| Cohere command-r-plus | `completion(model="azure_ai/command-r-plus", messages)` |
| Cohere command-r | `completion(model="azure_ai/command-r", messages)` |
| mistral-large-latest | `completion(model="azure_ai/mistral-large-latest", messages)` |
| AI21-Jamba-Instruct | `completion(model="azure_ai/ai21-jamba-instruct", messages)` |

View file

@ -727,6 +727,7 @@ general_settings:
"completion_model": "string",
"disable_spend_logs": "boolean", # turn off writing each transaction to the db
"disable_master_key_return": "boolean", # turn off returning master key on UI (checked on '/user/info' endpoint)
"disable_retry_on_max_parallel_request_limit_error": "boolean", # turn off retries when max parallel request limit is reached
"disable_reset_budget": "boolean", # turn off reset budget scheduled task
"disable_adding_master_key_hash_to_db": "boolean", # turn off storing master key hash in db, for spend tracking
"enable_jwt_auth": "boolean", # allow proxy admin to auth in via jwt tokens with 'litellm_proxy_admin' in claims
@ -751,7 +752,8 @@ general_settings:
},
"otel": true,
"custom_auth": "string",
"max_parallel_requests": 0,
"max_parallel_requests": 0, # the max parallel requests allowed per deployment
"global_max_parallel_requests": 0, # the max parallel requests allowed on the proxy all up
"infer_model_from_keys": true,
"background_health_checks": true,
"health_check_interval": 300,

View file

@ -0,0 +1,135 @@
import Image from '@theme/IdealImage';
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
# Bedrock
## Quick Start
### 1. Define Guardrails on your LiteLLM config.yaml
Define your guardrails under the `guardrails` section
```yaml
model_list:
- model_name: gpt-3.5-turbo
litellm_params:
model: openai/gpt-3.5-turbo
api_key: os.environ/OPENAI_API_KEY
guardrails:
- guardrail_name: "bedrock-pre-guard"
litellm_params:
guardrail: bedrock # supported values: "aporia", "bedrock", "lakera"
mode: "during_call"
guardrailIdentifier: ff6ujrregl1q # your guardrail ID on bedrock
guardrailVersion: "DRAFT" # your guardrail version on bedrock
```
#### Supported values for `mode`
- `pre_call` Run **before** LLM call, on **input**
- `post_call` Run **after** LLM call, on **input & output**
- `during_call` Run **during** LLM call, on **input**. Same as `pre_call` but runs in parallel with the LLM call. The response is not returned until the guardrail check completes
### 2. Start LiteLLM Gateway
```shell
litellm --config config.yaml --detailed_debug
```
### 3. Test request
**[Langchain, OpenAI SDK Usage Examples](../proxy/user_keys#request-format)**
<Tabs>
<TabItem label="Unsuccessful call" value = "not-allowed">
Expect this to fail since since `ishaan@berri.ai` in the request is PII
```shell
curl -i http://localhost:4000/v1/chat/completions \
-H "Content-Type: application/json" \
-H "Authorization: Bearer sk-npnwjPQciVRok5yNZgKmFQ" \
-d '{
"model": "gpt-3.5-turbo",
"messages": [
{"role": "user", "content": "hi my email is ishaan@berri.ai"}
],
"guardrails": ["bedrock-guard"]
}'
```
Expected response on failure
```shell
{
"error": {
"message": {
"error": "Violated guardrail policy",
"bedrock_guardrail_response": {
"action": "GUARDRAIL_INTERVENED",
"assessments": [
{
"topicPolicy": {
"topics": [
{
"action": "BLOCKED",
"name": "Coffee",
"type": "DENY"
}
]
}
}
],
"blockedResponse": "Sorry, the model cannot answer this question. coffee guardrail applied ",
"output": [
{
"text": "Sorry, the model cannot answer this question. coffee guardrail applied "
}
],
"outputs": [
{
"text": "Sorry, the model cannot answer this question. coffee guardrail applied "
}
],
"usage": {
"contentPolicyUnits": 0,
"contextualGroundingPolicyUnits": 0,
"sensitiveInformationPolicyFreeUnits": 0,
"sensitiveInformationPolicyUnits": 0,
"topicPolicyUnits": 1,
"wordPolicyUnits": 0
}
}
},
"type": "None",
"param": "None",
"code": "400"
}
}
```
</TabItem>
<TabItem label="Successful Call " value = "allowed">
```shell
curl -i http://localhost:4000/v1/chat/completions \
-H "Content-Type: application/json" \
-H "Authorization: Bearer sk-npnwjPQciVRok5yNZgKmFQ" \
-d '{
"model": "gpt-3.5-turbo",
"messages": [
{"role": "user", "content": "hi what is the weather"}
],
"guardrails": ["bedrock-guard"]
}'
```
</TabItem>
</Tabs>
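For completeness, the same request can also be made through the OpenAI Python SDK pointed at the gateway. This is a minimal sketch assuming the proxy from step 2 is running on `http://localhost:4000` and uses the example virtual key from the curl requests above; the `guardrails` field is LiteLLM-specific, so it is passed via `extra_body`.

```python
# Minimal sketch: call the gateway with the Bedrock guardrail enabled,
# using the OpenAI Python SDK (openai>=1.0). Key and port are the example
# values from the curl requests above.
import openai

client = openai.OpenAI(
    api_key="sk-npnwjPQciVRok5yNZgKmFQ",  # LiteLLM virtual key
    base_url="http://localhost:4000",
)

response = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hi what is the weather"}],
    extra_body={"guardrails": ["bedrock-pre-guard"]},  # LiteLLM-specific field
)
print(response.choices[0].message.content)
```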

View file

@ -68,6 +68,15 @@ http://localhost:4000/metrics
| `litellm_total_tokens` | input + output tokens per `"user", "key", "model", "team", "end-user"` |
| `litellm_llm_api_failed_requests_metric` | Number of failed LLM API requests per `"user", "key", "model", "team", "end-user"` |
### Request Latency Metrics
| Metric Name | Description |
|----------------------|--------------------------------------|
| `litellm_request_total_latency_metric` | Total latency (seconds) for a request to LiteLLM Proxy Server - tracked for labels `litellm_call_id`, `model` |
| `litellm_llm_api_latency_metric` | Latency (seconds) for just the LLM API call - tracked for labels `litellm_call_id`, `model` |
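As a rough illustration of what these two metrics track (not the proxy's internal wiring, which appears later in this diff), equivalent histograms can be declared with `prometheus_client` and observed once per request:

```python
# Illustrative sketch of the two latency histograms described in the table,
# using prometheus_client directly. The observed values are placeholders.
from prometheus_client import Histogram

litellm_request_total_latency_metric = Histogram(
    "litellm_request_total_latency_metric",
    "Total latency (seconds) for a request to LiteLLM Proxy Server",
    labelnames=["model", "litellm_call_id"],
)
litellm_llm_api_latency_metric = Histogram(
    "litellm_llm_api_latency_metric",
    "Latency (seconds) for just the LLM API call",
    labelnames=["model", "litellm_call_id"],
)

# Example observation for one request: total time vs. provider-only time.
litellm_request_total_latency_metric.labels("gpt-3.5-turbo", "call-123").observe(1.42)
litellm_llm_api_latency_metric.labels("gpt-3.5-turbo", "call-123").observe(1.05)
```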
### LLM API / Provider Metrics
| Metric Name | Description |

View file

@ -54,7 +54,7 @@ const sidebars = {
{
type: "category",
label: "🛡️ [Beta] Guardrails",
items: ["proxy/guardrails/quick_start", "proxy/guardrails/aporia_api", "proxy/guardrails/lakera_ai"],
items: ["proxy/guardrails/quick_start", "proxy/guardrails/aporia_api", "proxy/guardrails/lakera_ai", "proxy/guardrails/bedrock"],
},
{
type: "category",
@ -292,6 +292,7 @@ const sidebars = {
items: [
"projects/Docq.AI",
"projects/OpenInterpreter",
"projects/dbally",
"projects/FastREPL",
"projects/PROMPTMETHEUS",
"projects/Codium PR Agent",

View file

@ -6,7 +6,7 @@
"": {
"dependencies": {
"@hono/node-server": "^1.10.1",
"hono": "^4.2.7"
"hono": "^4.5.8"
},
"devDependencies": {
"@types/node": "^20.11.17",
@ -463,9 +463,9 @@
}
},
"node_modules/hono": {
"version": "4.2.7",
"resolved": "https://registry.npmjs.org/hono/-/hono-4.2.7.tgz",
"integrity": "sha512-k1xHi86tJnRIVvqhFMBDGFKJ8r5O+bEsT4P59ZK59r0F300Xd910/r237inVfuT/VmE86RQQffX4OYNda6dLXw==",
"version": "4.5.8",
"resolved": "https://registry.npmjs.org/hono/-/hono-4.5.8.tgz",
"integrity": "sha512-pqpSlcdqGkpTTRpLYU1PnCz52gVr0zVR9H5GzMyJWuKQLLEBQxh96q45QizJ2PPX8NATtz2mu31/PKW/Jt+90Q==",
"engines": {
"node": ">=16.0.0"
}

View file

@ -4,7 +4,7 @@
},
"dependencies": {
"@hono/node-server": "^1.10.1",
"hono": "^4.2.7"
"hono": "^4.5.8"
},
"devDependencies": {
"@types/node": "^20.11.17",

View file

@ -7,13 +7,17 @@
#
# Thank you users! We ❤️ you! - Krrish & Ishaan
import inspect
# s/o [@Frank Colson](https://www.linkedin.com/in/frank-colson-422b9b183/) for this redis implementation
import os
import inspect
import redis, litellm # type: ignore
import redis.asyncio as async_redis # type: ignore
from typing import List, Optional
import redis # type: ignore
import redis.asyncio as async_redis # type: ignore
import litellm
def _get_redis_kwargs():
arg_spec = inspect.getfullargspec(redis.Redis)
@ -51,6 +55,19 @@ def _get_redis_url_kwargs(client=None):
return available_args
def _get_redis_cluster_kwargs(client=None):
if client is None:
client = redis.Redis.from_url
arg_spec = inspect.getfullargspec(redis.RedisCluster)
# Only allow primitive arguments
exclude_args = {"self", "connection_pool", "retry", "host", "port", "startup_nodes"}
available_args = [x for x in arg_spec.args if x not in exclude_args]
return available_args
def _get_redis_env_kwarg_mapping():
PREFIX = "REDIS_"
@ -124,6 +141,22 @@ def get_redis_client(**env_overrides):
url_kwargs[arg] = redis_kwargs[arg]
return redis.Redis.from_url(**url_kwargs)
if "startup_nodes" in redis_kwargs:
from redis.cluster import ClusterNode
args = _get_redis_cluster_kwargs()
cluster_kwargs = {}
for arg in redis_kwargs:
if arg in args:
cluster_kwargs[arg] = redis_kwargs[arg]
new_startup_nodes: List[ClusterNode] = []
for item in redis_kwargs["startup_nodes"]:
new_startup_nodes.append(ClusterNode(**item))
redis_kwargs.pop("startup_nodes")
return redis.RedisCluster(startup_nodes=new_startup_nodes, **cluster_kwargs)
return redis.Redis(**redis_kwargs)
@ -143,6 +176,24 @@ def get_redis_async_client(**env_overrides):
)
return async_redis.Redis.from_url(**url_kwargs)
if "startup_nodes" in redis_kwargs:
from redis.cluster import ClusterNode
args = _get_redis_cluster_kwargs()
cluster_kwargs = {}
for arg in redis_kwargs:
if arg in args:
cluster_kwargs[arg] = redis_kwargs[arg]
new_startup_nodes: List[ClusterNode] = []
for item in redis_kwargs["startup_nodes"]:
new_startup_nodes.append(ClusterNode(**item))
redis_kwargs.pop("startup_nodes")
return async_redis.RedisCluster(
startup_nodes=new_startup_nodes, **cluster_kwargs
)
return async_redis.Redis(
socket_timeout=5,
**redis_kwargs,
@ -160,4 +211,5 @@ def get_redis_connection_pool(**env_overrides):
connection_class = async_redis.SSLConnection
redis_kwargs.pop("ssl", None)
redis_kwargs["connection_class"] = connection_class
redis_kwargs.pop("startup_nodes", None)
return async_redis.BlockingConnectionPool(timeout=5, **redis_kwargs)
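A hedged sketch of what the new cluster branch enables for callers; the module path `litellm._redis` and passing `startup_nodes` directly as a keyword override are assumptions based on this hunk, not documented API:

```python
# Sketch (assumptions noted above): build a Redis Cluster client via the
# helper changed in this diff. When "startup_nodes" is present in the kwargs,
# each dict is converted to a redis.cluster.ClusterNode and a RedisCluster
# client is returned instead of a plain redis.Redis.
from litellm._redis import get_redis_client  # assumed module path

startup_nodes = [
    {"host": "127.0.0.1", "port": "7001"},
    {"host": "127.0.0.1", "port": "7003"},
]

client = get_redis_client(startup_nodes=startup_nodes)
print(type(client))  # expected: redis.cluster.RedisCluster
```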

View file

@ -203,6 +203,7 @@ class RedisCache(BaseCache):
password=None,
redis_flush_size=100,
namespace: Optional[str] = None,
startup_nodes: Optional[List] = None, # for redis-cluster
**kwargs,
):
import redis
@ -218,7 +219,8 @@ class RedisCache(BaseCache):
redis_kwargs["port"] = port
if password is not None:
redis_kwargs["password"] = password
if startup_nodes is not None:
redis_kwargs["startup_nodes"] = startup_nodes
### HEALTH MONITORING OBJECT ###
if kwargs.get("service_logger_obj", None) is not None and isinstance(
kwargs["service_logger_obj"], ServiceLogging
@ -246,7 +248,7 @@ class RedisCache(BaseCache):
### ASYNC HEALTH PING ###
try:
# asyncio.get_running_loop().create_task(self.ping())
result = asyncio.get_running_loop().create_task(self.ping())
_ = asyncio.get_running_loop().create_task(self.ping())
except Exception as e:
if "no running event loop" in str(e):
verbose_logger.debug(
@ -2123,6 +2125,7 @@ class Cache:
redis_semantic_cache_use_async=False,
redis_semantic_cache_embedding_model="text-embedding-ada-002",
redis_flush_size=None,
redis_startup_nodes: Optional[List] = None,
disk_cache_dir=None,
qdrant_api_base: Optional[str] = None,
qdrant_api_key: Optional[str] = None,
@ -2155,7 +2158,12 @@ class Cache:
"""
if type == "redis":
self.cache: BaseCache = RedisCache(
host, port, password, redis_flush_size, **kwargs
host,
port,
password,
redis_flush_size,
startup_nodes=redis_startup_nodes,
**kwargs,
)
elif type == "redis-semantic":
self.cache = RedisSemanticCache(
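Based on the constructor changes above (and the unit test added later in this diff), redis-cluster caching can be enabled from application code roughly like this; host/port values are placeholders:

```python
# Sketch: enable the new redis-cluster backed cache. The redis_startup_nodes
# parameter is forwarded to RedisCache(startup_nodes=...) per the diff above.
import litellm

litellm.cache = litellm.Cache(
    type="redis",
    redis_startup_nodes=[
        {"host": "127.0.0.1", "port": "7001"},
        {"host": "127.0.0.1", "port": "7003"},
    ],
)
# Subsequent completion() calls read/write through the RedisCluster-backed cache.
```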

View file

@ -60,6 +60,25 @@ class PrometheusLogger(CustomLogger):
],
)
# request latency metrics
self.litellm_request_total_latency_metric = Histogram(
"litellm_request_total_latency_metric",
"Total latency (seconds) for a request to LiteLLM",
labelnames=[
"model",
"litellm_call_id",
],
)
self.litellm_llm_api_latency_metric = Histogram(
"litellm_llm_api_latency_metric",
"Total latency (seconds) for a models LLM API call",
labelnames=[
"model",
"litellm_call_id",
],
)
# Counter for spend
self.litellm_spend_metric = Counter(
"litellm_spend_metric",
@ -103,26 +122,27 @@ class PrometheusLogger(CustomLogger):
"Remaining budget for api key",
labelnames=["hashed_api_key", "api_key_alias"],
)
########################################
# LiteLLM Virtual API KEY metrics
########################################
# Remaining MODEL RPM limit for API Key
self.litellm_remaining_api_key_requests_for_model = Gauge(
"litellm_remaining_api_key_requests_for_model",
"Remaining Requests API Key can make for model (model based rpm limit on key)",
labelnames=["hashed_api_key", "api_key_alias", "model"],
)
# Remaining MODEL TPM limit for API Key
self.litellm_remaining_api_key_tokens_for_model = Gauge(
"litellm_remaining_api_key_tokens_for_model",
"Remaining Tokens API Key can make for model (model based tpm limit on key)",
labelnames=["hashed_api_key", "api_key_alias", "model"],
)
# Litellm-Enterprise Metrics
if premium_user is True:
########################################
# LiteLLM Virtual API KEY metrics
########################################
# Remaining MODEL RPM limit for API Key
self.litellm_remaining_api_key_requests_for_model = Gauge(
"litellm_remaining_api_key_requests_for_model",
"Remaining Requests API Key can make for model (model based rpm limit on key)",
labelnames=["hashed_api_key", "api_key_alias", "model"],
)
# Remaining MODEL TPM limit for API Key
self.litellm_remaining_api_key_tokens_for_model = Gauge(
"litellm_remaining_api_key_tokens_for_model",
"Remaining Tokens API Key can make for model (model based tpm limit on key)",
labelnames=["hashed_api_key", "api_key_alias", "model"],
)
########################################
# LLM API Deployment Metrics / analytics
########################################
@ -328,6 +348,25 @@ class PrometheusLogger(CustomLogger):
user_api_key, user_api_key_alias, model_group
).set(remaining_tokens)
# latency metrics
total_time: timedelta = kwargs.get("end_time") - kwargs.get("start_time")
total_time_seconds = total_time.total_seconds()
api_call_total_time: timedelta = kwargs.get("end_time") - kwargs.get(
"api_call_start_time"
)
api_call_total_time_seconds = api_call_total_time.total_seconds()
litellm_call_id = kwargs.get("litellm_call_id")
self.litellm_request_total_latency_metric.labels(
model, litellm_call_id
).observe(total_time_seconds)
self.litellm_llm_api_latency_metric.labels(model, litellm_call_id).observe(
api_call_total_time_seconds
)
# set x-ratelimit headers
if premium_user is True:
self.set_llm_deployment_success_metrics(

View file

@ -354,6 +354,8 @@ class Logging:
str(e)
)
)
self.model_call_details["api_call_start_time"] = datetime.datetime.now()
# Input Integration Logging -> If you want to log the fact that an attempt to call the model was made
callbacks = litellm.input_callback + self.dynamic_input_callbacks
for callback in callbacks:

View file

@ -84,6 +84,7 @@ BEDROCK_CONVERSE_MODELS = [
"meta.llama3-1-8b-instruct-v1:0",
"meta.llama3-1-70b-instruct-v1:0",
"meta.llama3-1-405b-instruct-v1:0",
"meta.llama3-70b-instruct-v1:0",
"mistral.mistral-large-2407-v1:0",
]
@ -1480,7 +1481,7 @@ class BedrockConverseLLM(BaseAWSLLM):
optional_params: dict,
acompletion: bool,
timeout: Optional[Union[float, httpx.Timeout]],
litellm_params=None,
litellm_params: dict,
logger_fn=None,
extra_headers: Optional[dict] = None,
client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,
@ -1596,6 +1597,14 @@ class BedrockConverseLLM(BaseAWSLLM):
supported_tool_call_params = ["tools", "tool_choice"]
supported_guardrail_params = ["guardrailConfig"]
## TRANSFORMATION ##
bedrock_messages: List[MessageBlock] = _bedrock_converse_messages_pt(
messages=messages,
model=model,
llm_provider="bedrock_converse",
user_continue_message=litellm_params.pop("user_continue_message", None),
)
# send all model-specific params in 'additional_request_params'
for k, v in inference_params.items():
if (
@ -1608,11 +1617,6 @@ class BedrockConverseLLM(BaseAWSLLM):
for key in additional_request_keys:
inference_params.pop(key, None)
bedrock_messages: List[MessageBlock] = _bedrock_converse_messages_pt(
messages=messages,
model=model,
llm_provider="bedrock_converse",
)
bedrock_tools: List[ToolBlock] = _bedrock_tools_pt(
inference_params.pop("tools", [])
)

View file

@ -124,12 +124,14 @@ class CohereConfig:
}
def validate_environment(api_key):
headers = {
"Request-Source": "unspecified:litellm",
"accept": "application/json",
"content-type": "application/json",
}
def validate_environment(api_key, headers: dict):
headers.update(
{
"Request-Source": "unspecified:litellm",
"accept": "application/json",
"content-type": "application/json",
}
)
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
return headers
@ -144,11 +146,12 @@ def completion(
encoding,
api_key,
logging_obj,
headers: dict,
optional_params=None,
litellm_params=None,
logger_fn=None,
):
headers = validate_environment(api_key)
headers = validate_environment(api_key, headers=headers)
completion_url = api_base
model = model
prompt = " ".join(message["content"] for message in messages)
@ -338,13 +341,14 @@ def embedding(
model_response: litellm.EmbeddingResponse,
logging_obj: LiteLLMLoggingObj,
optional_params: dict,
headers: dict,
encoding: Any,
api_key: Optional[str] = None,
aembedding: Optional[bool] = None,
timeout: Union[float, httpx.Timeout] = httpx.Timeout(None),
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
):
headers = validate_environment(api_key)
headers = validate_environment(api_key, headers=headers)
embed_url = "https://api.cohere.ai/v1/embed"
model = model
data = {"model": model, "texts": input, **optional_params}

View file

@ -116,12 +116,14 @@ class CohereChatConfig:
}
def validate_environment(api_key):
headers = {
"Request-Source": "unspecified:litellm",
"accept": "application/json",
"content-type": "application/json",
}
def validate_environment(api_key, headers: dict):
headers.update(
{
"Request-Source": "unspecified:litellm",
"accept": "application/json",
"content-type": "application/json",
}
)
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
return headers
@ -203,13 +205,14 @@ def completion(
model_response: ModelResponse,
print_verbose: Callable,
optional_params: dict,
headers: dict,
encoding,
api_key,
logging_obj,
litellm_params=None,
logger_fn=None,
):
headers = validate_environment(api_key)
headers = validate_environment(api_key, headers=headers)
completion_url = api_base
model = model
most_recent_message, chat_history = cohere_messages_pt_v2(

View file

@ -38,6 +38,18 @@ def prompt_injection_detection_default_pt():
BAD_MESSAGE_ERROR_STR = "Invalid Message "
# used to interweave user messages, to ensure user/assistant alternating
DEFAULT_USER_CONTINUE_MESSAGE = {
"role": "user",
"content": "Please continue.",
} # similar to autogen. Only used if `litellm.modify_params=True`.
# used to interweave assistant messages, to ensure user/assistant alternating
DEFAULT_ASSISTANT_CONTINUE_MESSAGE = {
"role": "assistant",
"content": "Please continue.",
} # similar to autogen. Only used if `litellm.modify_params=True`.
def map_system_message_pt(messages: list) -> list:
"""
@ -2254,6 +2266,7 @@ def _bedrock_converse_messages_pt(
messages: List,
model: str,
llm_provider: str,
user_continue_message: Optional[dict] = None,
) -> List[BedrockMessageBlock]:
"""
Converts given messages from OpenAI format to Bedrock format
@ -2264,6 +2277,21 @@ def _bedrock_converse_messages_pt(
contents: List[BedrockMessageBlock] = []
msg_i = 0
# if initial message is assistant message
if messages[0].get("role") is not None and messages[0]["role"] == "assistant":
if user_continue_message is not None:
messages.insert(0, user_continue_message)
elif litellm.modify_params:
messages.insert(0, DEFAULT_USER_CONTINUE_MESSAGE)
# if final message is assistant message
if messages[-1].get("role") is not None and messages[-1]["role"] == "assistant":
if user_continue_message is not None:
messages.append(user_continue_message)
elif litellm.modify_params:
messages.append(DEFAULT_USER_CONTINUE_MESSAGE)
while msg_i < len(messages):
user_content: List[BedrockContentBlock] = []
init_msg_i = msg_i
@ -2344,6 +2372,7 @@ def _bedrock_converse_messages_pt(
model=model,
llm_provider=llm_provider,
)
return contents
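The interweaving rule added above can be illustrated standalone; this is a behavioral sketch, not the library's implementation:

```python
# Behavioral sketch of the new rule: if the conversation starts or ends with an
# assistant message, insert a "continue" user message so Bedrock's required
# user/assistant alternation holds. Mirrors the user_continue_message /
# litellm.modify_params handling in _bedrock_converse_messages_pt.
DEFAULT_USER_CONTINUE_MESSAGE = {"role": "user", "content": "Please continue."}

def interweave(messages, user_continue_message=None, modify_params=False):
    filler = user_continue_message or (
        DEFAULT_USER_CONTINUE_MESSAGE if modify_params else None
    )
    if filler is None:
        return messages
    if messages and messages[0]["role"] == "assistant":
        messages = [filler] + messages
    if messages and messages[-1]["role"] == "assistant":
        messages = messages + [filler]
    return messages

print(interweave([{"role": "assistant", "content": "hi"}], modify_params=True))
# -> user "Please continue.", assistant "hi", user "Please continue."
```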

View file

@ -944,6 +944,8 @@ def completion(
cooldown_time=cooldown_time,
text_completion=kwargs.get("text_completion"),
azure_ad_token_provider=kwargs.get("azure_ad_token_provider"),
user_continue_message=kwargs.get("user_continue_message"),
)
logging.update_environment_variables(
model=model,
@ -1635,6 +1637,13 @@ def completion(
or "https://api.cohere.ai/v1/generate"
)
headers = headers or litellm.headers or {}
if headers is None:
headers = {}
if extra_headers is not None:
headers.update(extra_headers)
model_response = cohere.completion(
model=model,
messages=messages,
@ -1645,6 +1654,7 @@ def completion(
litellm_params=litellm_params,
logger_fn=logger_fn,
encoding=encoding,
headers=headers,
api_key=cohere_key,
logging_obj=logging, # model call logging done inside the class as we may need to modify I/O to fit aleph alpha's requirements
)
@ -1675,6 +1685,13 @@ def completion(
or "https://api.cohere.ai/v1/chat"
)
headers = headers or litellm.headers or {}
if headers is None:
headers = {}
if extra_headers is not None:
headers.update(extra_headers)
model_response = cohere_chat.completion(
model=model,
messages=messages,
@ -1683,6 +1700,7 @@ def completion(
print_verbose=print_verbose,
optional_params=optional_params,
litellm_params=litellm_params,
headers=headers,
logger_fn=logger_fn,
encoding=encoding,
api_key=cohere_key,
@ -2289,7 +2307,7 @@ def completion(
model_response=model_response,
print_verbose=print_verbose,
optional_params=optional_params,
litellm_params=litellm_params,
litellm_params=litellm_params, # type: ignore
logger_fn=logger_fn,
encoding=encoding,
logging_obj=logging,
@ -3159,6 +3177,7 @@ def embedding(
encoding_format = kwargs.get("encoding_format", None)
proxy_server_request = kwargs.get("proxy_server_request", None)
aembedding = kwargs.get("aembedding", None)
extra_headers = kwargs.get("extra_headers", None)
### CUSTOM MODEL COST ###
input_cost_per_token = kwargs.get("input_cost_per_token", None)
output_cost_per_token = kwargs.get("output_cost_per_token", None)
@ -3234,6 +3253,7 @@ def embedding(
"tenant_id",
"client_id",
"client_secret",
"extra_headers",
]
default_params = openai_params + litellm_params
non_default_params = {
@ -3297,7 +3317,7 @@ def embedding(
"cooldown_time": cooldown_time,
},
)
if azure == True or custom_llm_provider == "azure":
if azure is True or custom_llm_provider == "azure":
# azure configs
api_type = get_secret("AZURE_API_TYPE") or "azure"
@ -3403,12 +3423,18 @@ def embedding(
or get_secret("CO_API_KEY")
or litellm.api_key
)
if extra_headers is not None and isinstance(extra_headers, dict):
headers = extra_headers
else:
headers = {}
response = cohere.embedding(
model=model,
input=input,
optional_params=optional_params,
encoding=encoding,
api_key=cohere_key, # type: ignore
headers=headers,
logging_obj=logging,
model_response=EmbeddingResponse(),
aembedding=aembedding,

View file

@ -2,12 +2,3 @@ model_list:
- model_name: "*"
litellm_params:
model: "*"
litellm_settings:
success_callback: ["s3"]
cache: true
s3_callback_params:
s3_bucket_name: mytestbucketlitellm # AWS Bucket Name for S3
s3_region_name: us-west-2 # AWS Region Name for S3
s3_aws_access_key_id: os.environ/AWS_ACCESS_KEY_ID # use os.environ/<variable name> to pass environment variables. This is the AWS Access Key ID for S3
s3_aws_secret_access_key: os.environ/AWS_SECRET_ACCESS_KEY # AWS Secret Access Key for S3

View file

@ -66,7 +66,7 @@ def common_checks(
raise Exception(
f"Team={team_object.team_id} is blocked. Update via `/team/unblock` if your admin."
)
# 2. If user can call model
# 2. If team can call model
if (
_model is not None
and team_object is not None
@ -74,7 +74,11 @@ def common_checks(
and _model not in team_object.models
):
# this means the team has access to all models on the proxy
if "all-proxy-models" in team_object.models:
if (
"all-proxy-models" in team_object.models
or "*" in team_object.models
or "openai/*" in team_object.models
):
# this means the team has access to all models on the proxy
pass
# check if the team model is an access_group
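A condensed sketch of the widened check (access-group resolution omitted); the helper below is hypothetical and only restates the branch logic above:

```python
# Hypothetical helper restating the team-model check: "all-proxy-models", "*"
# and "openai/*" in a team's model list now all grant wildcard access.
WILDCARD_VALUES = {"all-proxy-models", "*", "openai/*"}

def team_can_call_model(model: str, team_models: list) -> bool:
    if not team_models:
        return True  # empty list = no restriction
    if model in team_models:
        return True
    if WILDCARD_VALUES & set(team_models):
        return True
    # the real check falls through to model access-group resolution here
    return False

assert team_can_call_model("gpt-4o", ["*"])
assert not team_can_call_model("gpt-4o", ["gpt-3.5-turbo"])
```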

View file

@ -22,3 +22,9 @@ guardrails:
mode: "post_call"
api_key: os.environ/APORIA_API_KEY_2
api_base: os.environ/APORIA_API_BASE_2
- guardrail_name: "bedrock-pre-guard"
litellm_params:
guardrail: bedrock # supported values: "aporia", "bedrock", "lakera"
mode: "pre_call"
guardrailIdentifier: ff6ujrregl1q
guardrailVersion: "DRAFT"

View file

@ -0,0 +1,289 @@
# +-------------------------------------------------------------+
#
# Use Bedrock Guardrails for your LLM calls
#
# +-------------------------------------------------------------+
# Thank you users! We ❤️ you! - Krrish & Ishaan
import os
import sys
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import asyncio
import json
import sys
import traceback
import uuid
from datetime import datetime
from typing import Any, Dict, List, Literal, Optional, Union
import aiohttp
import httpx
from fastapi import HTTPException
import litellm
from litellm import get_secret
from litellm._logging import verbose_proxy_logger
from litellm.caching import DualCache
from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.litellm_core_utils.logging_utils import (
convert_litellm_response_object_to_str,
)
from litellm.llms.base_aws_llm import BaseAWSLLM
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
_get_async_httpx_client,
)
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.guardrails.guardrail_helpers import should_proceed_based_on_metadata
from litellm.types.guardrails import (
BedrockContentItem,
BedrockRequest,
BedrockTextContent,
GuardrailEventHooks,
)
GUARDRAIL_NAME = "bedrock"
class BedrockGuardrail(CustomGuardrail, BaseAWSLLM):
def __init__(
self,
guardrailIdentifier: Optional[str] = None,
guardrailVersion: Optional[str] = None,
**kwargs,
):
self.async_handler = _get_async_httpx_client()
self.guardrailIdentifier = guardrailIdentifier
self.guardrailVersion = guardrailVersion
# store kwargs as optional_params
self.optional_params = kwargs
super().__init__(**kwargs)
def convert_to_bedrock_format(
self,
messages: Optional[List[Dict[str, str]]] = None,
response: Optional[Union[Any, litellm.ModelResponse]] = None,
) -> BedrockRequest:
bedrock_request: BedrockRequest = BedrockRequest(source="INPUT")
bedrock_request_content: List[BedrockContentItem] = []
if messages:
for message in messages:
content = message.get("content")
if isinstance(content, str):
bedrock_content_item = BedrockContentItem(
text=BedrockTextContent(text=content)
)
bedrock_request_content.append(bedrock_content_item)
bedrock_request["content"] = bedrock_request_content
if response:
bedrock_request["source"] = "OUTPUT"
if isinstance(response, litellm.ModelResponse):
for choice in response.choices:
if isinstance(choice, litellm.Choices):
if choice.message.content and isinstance(
choice.message.content, str
):
bedrock_content_item = BedrockContentItem(
text=BedrockTextContent(text=choice.message.content)
)
bedrock_request_content.append(bedrock_content_item)
bedrock_request["content"] = bedrock_request_content
return bedrock_request
#### CALL HOOKS - proxy only ####
def _load_credentials(
self,
):
try:
from botocore.credentials import Credentials
except ImportError as e:
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
## CREDENTIALS ##
# pop aws_secret_access_key, aws_access_key_id, aws_session_token, aws_region_name from kwargs, since completion calls fail with them
aws_secret_access_key = self.optional_params.pop("aws_secret_access_key", None)
aws_access_key_id = self.optional_params.pop("aws_access_key_id", None)
aws_session_token = self.optional_params.pop("aws_session_token", None)
aws_region_name = self.optional_params.pop("aws_region_name", None)
aws_role_name = self.optional_params.pop("aws_role_name", None)
aws_session_name = self.optional_params.pop("aws_session_name", None)
aws_profile_name = self.optional_params.pop("aws_profile_name", None)
aws_bedrock_runtime_endpoint = self.optional_params.pop(
"aws_bedrock_runtime_endpoint", None
) # https://bedrock-runtime.{region_name}.amazonaws.com
aws_web_identity_token = self.optional_params.pop(
"aws_web_identity_token", None
)
aws_sts_endpoint = self.optional_params.pop("aws_sts_endpoint", None)
### SET REGION NAME ###
if aws_region_name is None:
# check env #
litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
if litellm_aws_region_name is not None and isinstance(
litellm_aws_region_name, str
):
aws_region_name = litellm_aws_region_name
standard_aws_region_name = get_secret("AWS_REGION", None)
if standard_aws_region_name is not None and isinstance(
standard_aws_region_name, str
):
aws_region_name = standard_aws_region_name
if aws_region_name is None:
aws_region_name = "us-west-2"
credentials: Credentials = self.get_credentials(
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
aws_session_token=aws_session_token,
aws_region_name=aws_region_name,
aws_session_name=aws_session_name,
aws_profile_name=aws_profile_name,
aws_role_name=aws_role_name,
aws_web_identity_token=aws_web_identity_token,
aws_sts_endpoint=aws_sts_endpoint,
)
return credentials, aws_region_name
def _prepare_request(
self,
credentials,
data: BedrockRequest,
optional_params: dict,
aws_region_name: str,
extra_headers: Optional[dict] = None,
):
try:
import boto3
from botocore.auth import SigV4Auth
from botocore.awsrequest import AWSRequest
from botocore.credentials import Credentials
except ImportError as e:
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name)
api_base = f"https://bedrock-runtime.{aws_region_name}.amazonaws.com/guardrail/{self.guardrailIdentifier}/version/{self.guardrailVersion}/apply"
encoded_data = json.dumps(data).encode("utf-8")
headers = {"Content-Type": "application/json"}
if extra_headers is not None:
headers = {"Content-Type": "application/json", **extra_headers}
request = AWSRequest(
method="POST", url=api_base, data=encoded_data, headers=headers
)
sigv4.add_auth(request)
prepped_request = request.prepare()
return prepped_request
async def make_bedrock_api_request(
self, kwargs: dict, response: Optional[Union[Any, litellm.ModelResponse]] = None
):
credentials, aws_region_name = self._load_credentials()
request_data: BedrockRequest = self.convert_to_bedrock_format(
messages=kwargs.get("messages"), response=response
)
prepared_request = self._prepare_request(
credentials=credentials,
data=request_data,
optional_params=self.optional_params,
aws_region_name=aws_region_name,
)
verbose_proxy_logger.debug(
"Bedrock AI request body: %s, url %s, headers: %s",
request_data,
prepared_request.url,
prepared_request.headers,
)
_json_data = json.dumps(request_data) # type: ignore
response = await self.async_handler.post(
url=prepared_request.url,
json=request_data, # type: ignore
headers=prepared_request.headers,
)
verbose_proxy_logger.debug("Bedrock AI response: %s", response.text)
if response.status_code == 200:
# check if the response was flagged
_json_response = response.json()
if _json_response.get("action") == "GUARDRAIL_INTERVENED":
raise HTTPException(
status_code=400,
detail={
"error": "Violated guardrail policy",
"bedrock_guardrail_response": _json_response,
},
)
else:
verbose_proxy_logger.error(
"Bedrock AI: error in response. Status code: %s, response: %s",
response.status_code,
response.text,
)
async def async_moderation_hook( ### 👈 KEY CHANGE ###
self,
data: dict,
user_api_key_dict: UserAPIKeyAuth,
call_type: Literal["completion", "embeddings", "image_generation"],
):
from litellm.proxy.common_utils.callback_utils import (
add_guardrail_to_applied_guardrails_header,
)
from litellm.types.guardrails import GuardrailEventHooks
event_type: GuardrailEventHooks = GuardrailEventHooks.during_call
if self.should_run_guardrail(data=data, event_type=event_type) is not True:
return
new_messages: Optional[List[dict]] = data.get("messages")
if new_messages is not None:
await self.make_bedrock_api_request(kwargs=data)
add_guardrail_to_applied_guardrails_header(
request_data=data, guardrail_name=self.guardrail_name
)
else:
verbose_proxy_logger.warning(
"Bedrock AI: not running guardrail. No messages in data"
)
pass
async def async_post_call_success_hook(
self,
data: dict,
user_api_key_dict: UserAPIKeyAuth,
response,
):
from litellm.proxy.common_utils.callback_utils import (
add_guardrail_to_applied_guardrails_header,
)
from litellm.types.guardrails import GuardrailEventHooks
if (
self.should_run_guardrail(
data=data, event_type=GuardrailEventHooks.post_call
)
is not True
):
return
new_messages: Optional[List[dict]] = data.get("messages")
if new_messages is not None:
await self.make_bedrock_api_request(kwargs=data, response=response)
add_guardrail_to_applied_guardrails_header(
request_data=data, guardrail_name=self.guardrail_name
)
else:
verbose_proxy_logger.warning(
"Bedrock AI: not running guardrail. No messages in data"
)

View file

@ -96,8 +96,10 @@ def init_guardrails_v2(all_guardrails: dict):
litellm_params = LitellmParams(
guardrail=litellm_params_data["guardrail"],
mode=litellm_params_data["mode"],
api_key=litellm_params_data["api_key"],
api_base=litellm_params_data["api_base"],
api_key=litellm_params_data.get("api_key"),
api_base=litellm_params_data.get("api_base"),
guardrailIdentifier=litellm_params_data.get("guardrailIdentifier"),
guardrailVersion=litellm_params_data.get("guardrailVersion"),
)
if (
@ -134,6 +136,18 @@ def init_guardrails_v2(all_guardrails: dict):
event_hook=litellm_params["mode"],
)
litellm.callbacks.append(_aporia_callback) # type: ignore
if litellm_params["guardrail"] == "bedrock":
from litellm.proxy.guardrails.guardrail_hooks.bedrock_guardrails import (
BedrockGuardrail,
)
_bedrock_callback = BedrockGuardrail(
guardrail_name=guardrail["guardrail_name"],
event_hook=litellm_params["mode"],
guardrailIdentifier=litellm_params["guardrailIdentifier"],
guardrailVersion=litellm_params["guardrailVersion"],
)
litellm.callbacks.append(_bedrock_callback) # type: ignore
elif litellm_params["guardrail"] == "lakera":
from litellm.proxy.guardrails.guardrail_hooks.lakera_ai import (
lakeraAI_Moderation,

View file

@ -1,5 +1,5 @@
model_list:
- model_name: gpt-4
- model_name: fake-openai-endpoint
litellm_params:
model: azure/chatgpt-v-2
api_base: https://openai-gpt-4-test-v-1.openai.azure.com/
@ -9,13 +9,9 @@ model_list:
client_secret: os.environ/AZURE_CLIENT_SECRET
guardrails:
- guardrail_name: "lakera-pre-guard"
- guardrail_name: "bedrock-pre-guard"
litellm_params:
guardrail: lakera # supported values: "aporia", "bedrock", "lakera"
mode: "during_call"
api_key: os.environ/LAKERA_API_KEY
api_base: os.environ/LAKERA_API_BASE
category_thresholds:
prompt_injection: 0.1
jailbreak: 0.1
guardrail: bedrock # supported values: "aporia", "bedrock", "lakera"
mode: "post_call"
guardrailIdentifier: ff6ujrregl1q
guardrailVersion: "DRAFT"

View file

@ -1588,7 +1588,7 @@ class ProxyConfig:
verbose_proxy_logger.debug( # noqa
f"{blue_color_code}Set Cache on LiteLLM Proxy: {vars(litellm.cache.cache)}{reset_color_code}"
)
elif key == "cache" and value == False:
elif key == "cache" and value is False:
pass
elif key == "guardrails":
if premium_user is not True:
@ -2672,6 +2672,13 @@ def giveup(e):
and isinstance(e.message, str)
and "Max parallel request limit reached" in e.message
)
if (
general_settings.get("disable_retry_on_max_parallel_request_limit_error")
is True
):
return True # give up (no retry) when retrying on max parallel request limit errors is disabled
if result:
verbose_proxy_logger.info(json.dumps({"event": "giveup", "exception": str(e)}))
return result
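For context on how a `giveup` predicate takes effect, here is a minimal sketch with the `backoff` library; the `general_settings` dict below is a stand-in, not the proxy's real object:

```python
# Minimal sketch: a giveup predicate returning True tells backoff to stop
# retrying. The new flag short-circuits so max-parallel-limit errors are not
# retried when it is set. general_settings here is a stand-in dict.
import backoff

general_settings = {"disable_retry_on_max_parallel_request_limit_error": True}

def giveup(e: Exception) -> bool:
    if general_settings.get("disable_retry_on_max_parallel_request_limit_error") is True:
        return True
    return False  # stand-in for the proxy's real exception inspection

@backoff.on_exception(backoff.expo, Exception, max_tries=3, giveup=giveup)
def call_proxy():
    raise RuntimeError("Max parallel request limit reached")

# With the flag set, call_proxy() raises immediately instead of retrying.
```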

View file

@ -277,7 +277,8 @@ class Router:
"local" # default to an in-memory cache
)
redis_cache = None
cache_config = {}
cache_config: Dict[str, Any] = {}
self.client_ttl = client_ttl
if redis_url is not None or (
redis_host is not None

View file

@ -738,8 +738,9 @@ def test_bedrock_system_prompt(system, model):
"temperature": 0.3,
"messages": [
{"role": "system", "content": system},
{"role": "user", "content": "hey, how's it going?"},
{"role": "assistant", "content": "hey, how's it going?"},
],
"user_continue_message": {"role": "user", "content": "Be a good bot!"},
}
response: ModelResponse = completion(
model="bedrock/{}".format(model),

View file

@ -804,6 +804,38 @@ def test_redis_cache_completion_stream():
# test_redis_cache_completion_stream()
@pytest.mark.skip(reason="Local test. Requires running redis cluster locally.")
@pytest.mark.asyncio
async def test_redis_cache_cluster_init_unit_test():
try:
from redis.asyncio import RedisCluster as AsyncRedisCluster
from redis.cluster import RedisCluster
from litellm.caching import RedisCache
litellm.set_verbose = True
# List of startup nodes
startup_nodes = [
{"host": "127.0.0.1", "port": "7001"},
]
resp = RedisCache(startup_nodes=startup_nodes)
assert isinstance(resp.redis_client, RedisCluster)
assert isinstance(resp.init_async_client(), AsyncRedisCluster)
resp = litellm.Cache(type="redis", redis_startup_nodes=startup_nodes)
assert isinstance(resp.cache, RedisCache)
assert isinstance(resp.cache.redis_client, RedisCluster)
assert isinstance(resp.cache.init_async_client(), AsyncRedisCluster)
except Exception as e:
print(f"{str(e)}\n\n{traceback.format_exc()}")
raise e
@pytest.mark.asyncio
async def test_redis_cache_acompletion_stream():
try:

View file

@ -3653,6 +3653,7 @@ def test_completion_cohere():
response = completion(
model="command-r",
messages=messages,
extra_headers={"Helicone-Property-Locale": "ko"},
)
print(response)
except Exception as e:

View file

@ -1,5 +1,5 @@
from enum import Enum
from typing import Dict, List, Optional, TypedDict
from typing import Dict, List, Literal, Optional, TypedDict
from pydantic import BaseModel, ConfigDict
from typing_extensions import Required, TypedDict
@ -76,8 +76,14 @@ class LitellmParams(TypedDict, total=False):
mode: str
api_key: str
api_base: Optional[str]
# Lakera specific params
category_thresholds: Optional[LakeraCategoryThresholds]
# Bedrock specific params
guardrailIdentifier: Optional[str]
guardrailVersion: Optional[str]
class Guardrail(TypedDict):
guardrail_name: str
@ -92,3 +98,16 @@ class GuardrailEventHooks(str, Enum):
pre_call = "pre_call"
post_call = "post_call"
during_call = "during_call"
class BedrockTextContent(TypedDict, total=False):
text: str
class BedrockContentItem(TypedDict, total=False):
text: BedrockTextContent
class BedrockRequest(TypedDict, total=False):
source: Literal["INPUT", "OUTPUT"]
content: List[BedrockContentItem]
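For reference, constructing the request payload these TypedDicts describe looks like the following (values are illustrative):

```python
# Illustrative construction of a BedrockRequest for an input-side guardrail
# check, using the TypedDicts defined above.
from litellm.types.guardrails import (
    BedrockContentItem,
    BedrockRequest,
    BedrockTextContent,
)

request: BedrockRequest = BedrockRequest(
    source="INPUT",
    content=[
        BedrockContentItem(
            text=BedrockTextContent(text="hi my email is ishaan@berri.ai")
        )
    ],
)
```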

View file

@ -1120,6 +1120,7 @@ all_litellm_params = [
"tenant_id",
"client_id",
"client_secret",
"user_continue_message",
]

View file

@ -2324,6 +2324,7 @@ def get_litellm_params(
cooldown_time=None,
text_completion=None,
azure_ad_token_provider=None,
user_continue_message=None,
):
litellm_params = {
"acompletion": acompletion,
@ -2349,6 +2350,7 @@ def get_litellm_params(
"cooldown_time": cooldown_time,
"text_completion": text_completion,
"azure_ad_token_provider": azure_ad_token_provider,
"user_continue_message": user_continue_message,
}
return litellm_params
@ -4221,6 +4223,7 @@ def get_supported_openai_params(
"presence_penalty",
"stop",
"n",
"extra_headers",
]
elif custom_llm_provider == "cohere_chat":
return [
@ -4235,6 +4238,7 @@ def get_supported_openai_params(
"tools",
"tool_choice",
"seed",
"extra_headers",
]
elif custom_llm_provider == "maritalk":
return [
@ -7123,6 +7127,14 @@ def exception_type(
llm_provider="bedrock",
response=original_exception.response,
)
elif "A conversation must start with a user message." in error_str:
exception_mapping_worked = True
raise BadRequestError(
message=f"BedrockException - {error_str}\n. Pass in default user message via `completion(..,user_continue_message=)` or enable `litellm.modify_params=True`.\nFor Proxy: do via `litellm_settings::modify_params: True` or user_continue_message under `litellm_params`",
model=model,
llm_provider="bedrock",
response=original_exception.response,
)
elif (
"Unable to locate credentials" in error_str
or "The security token included in the request is invalid"

View file

@ -144,6 +144,7 @@ async def test_no_llm_guard_triggered():
assert "x-litellm-applied-guardrails" not in headers
@pytest.mark.asyncio
async def test_guardrails_with_api_key_controls():
"""
@ -194,3 +195,25 @@ async def test_guardrails_with_api_key_controls():
except Exception as e:
print(e)
assert "Aporia detected and blocked PII" in str(e)
@pytest.mark.asyncio
async def test_bedrock_guardrail_triggered():
"""
- Tests a request where our bedrock guardrail should be triggered
- Assert that the guardrails applied are returned in the response headers
"""
async with aiohttp.ClientSession() as session:
try:
response, headers = await chat_completion(
session,
"sk-1234",
model="fake-openai-endpoint",
messages=[{"role": "user", "content": f"Hello do you like coffee?"}],
guardrails=["bedrock-pre-guard"],
)
pytest.fail("Should have thrown an exception")
except Exception as e:
print(e)
assert "GUARDRAIL_INTERVENED" in str(e)
assert "Violated guardrail policy" in str(e)