Merge branch 'main' into litellm_dev_11_13_2024

Krish Dholakia 2024-11-15 11:18:02 +05:30 committed by GitHub
commit 1dcbfda202
76 changed files with 2836 additions and 560 deletions

View file

@ -305,6 +305,36 @@ Step 4: Submit a PR with your changes! 🚀
- push your fork to your GitHub repo
- submit a PR from there
### Building LiteLLM Docker Image
Follow these instructions if you want to build / run the LiteLLM Docker Image yourself.
Step 1: Clone the repo
```
git clone https://github.com/BerriAI/litellm.git
```
Step 2: Build the Docker Image
Build using Dockerfile.non_root
```
docker build -f docker/Dockerfile.non_root -t litellm_test_image .
```
Step 3: Run the Docker Image
Make sure your proxy config file (`proxy_config.yaml` in the command below) is present in the current directory. This is your litellm proxy config file; it is mounted into the container as `/app/config.yaml`.
```
docker run \
-v $(pwd)/proxy_config.yaml:/app/config.yaml \
-e DATABASE_URL="postgresql://xxxxxxxx" \
-e LITELLM_MASTER_KEY="sk-1234" \
-p 4000:4000 \
litellm_test_image \
--config /app/config.yaml --detailed_debug
```
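For reference, a minimal sketch of what the mounted config might contain; the model entry and keys are placeholders, not values from this repo:
```yaml
model_list:
  - model_name: gpt-4o                   # any alias you want to expose
    litellm_params:
      model: openai/gpt-4o               # provider/model
      api_key: os.environ/OPENAI_API_KEY

general_settings:
  master_key: sk-1234                    # matches LITELLM_MASTER_KEY above
```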
# Enterprise
For companies that need better security, user management and professional support

View file

@ -13,18 +13,18 @@ spec:
spec:
containers:
- name: prisma-migrations
image: ghcr.io/berriai/litellm-database:main-latest
command: ["python", "litellm/proxy/prisma_migration.py"]
workingDir: "/app"
env:
{{- if .Values.db.useExisting }}
- name: DATABASE_URL
value: {{ .Values.db.url | quote }}
{{- else }}
- name: DATABASE_URL
value: postgresql://{{ .Values.postgresql.auth.username }}:{{ .Values.postgresql.auth.password }}@{{ .Release.Name }}-postgresql/{{ .Values.postgresql.auth.database }}
{{- end }}
- name: DISABLE_SCHEMA_UPDATE
value: "false" # always run the migration from the Helm PreSync hook, override the value set
restartPolicy: OnFailure
backoffLimit: {{ .Values.migrationJob.backoffLimit }}
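For context, which branch of the template renders is driven by the chart's `db` values; a minimal sketch (the URL is a placeholder):
```yaml
db:
  useExisting: true    # render DATABASE_URL from db.url
  url: postgresql://<user>:<password>@<host>/<dbname>
  # useExisting: false would instead build the URL from the bundled postgresql subchart values
```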

View file

@ -75,6 +75,7 @@ Works for:
- Google AI Studio - Gemini models
- Vertex AI models (Gemini + Anthropic)
- Bedrock Models
- Anthropic API Models
<Tabs>
<TabItem value="sdk" label="SDK">

View file

@ -93,7 +93,7 @@ curl http://0.0.0.0:4000/v1/chat/completions \
## Check Model Support
Call `litellm.get_model_info` to check if a model/provider supports `prefix`.
<Tabs>
<TabItem value="sdk" label="SDK">
@ -116,4 +116,4 @@ curl -X GET 'http://0.0.0.0:4000/v1/model/info' \
-H 'Authorization: Bearer $LITELLM_KEY' \
```
</TabItem>
</Tabs>
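A minimal sketch of the SDK-side check (assuming `deepseek/deepseek-chat` as an example model and that the returned info includes `supported_openai_params`):
```python
from litellm import get_model_info

info = get_model_info(model="deepseek/deepseek-chat")
# the model info dict lists the OpenAI params LiteLLM will translate for this model
assert "prefix" in (info.get("supported_openai_params") or [])
```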

View file

@ -957,3 +957,69 @@ curl http://0.0.0.0:4000/v1/chat/completions \
```
</TabItem>
</Tabs>
## Usage - passing 'user_id' to Anthropic
LiteLLM translates the OpenAI `user` param to Anthropic's `metadata[user_id]` param.
<Tabs>
<TabItem value="sdk" label="SDK">
```python
response = completion(
model="claude-3-5-sonnet-20240620",
messages=messages,
user="user_123",
)
```
</TabItem>
<TabItem value="proxy" label="PROXY">
1. Setup config.yaml
```yaml
model_list:
- model_name: claude-3-5-sonnet-20240620
litellm_params:
model: anthropic/claude-3-5-sonnet-20240620
api_key: os.environ/ANTHROPIC_API_KEY
```
2. Start Proxy
```
litellm --config /path/to/config.yaml
```
3. Test it!
```bash
curl http://0.0.0.0:4000/v1/chat/completions \
-H "Content-Type: application/json" \
-H "Authorization: Bearer <YOUR-LITELLM-KEY>" \
-d '{
"model": "claude-3-5-sonnet-20240620",
"messages": [{"role": "user", "content": "What is Anthropic?"}],
"user": "user_123"
}'
```
</TabItem>
</Tabs>
## All Supported OpenAI Params
```
"stream",
"stop",
"temperature",
"top_p",
"max_tokens",
"max_completion_tokens",
"tools",
"tool_choice",
"extra_headers",
"parallel_tool_calls",
"response_format",
"user"
```
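The same list can also be fetched programmatically; a small sketch using `litellm.get_supported_openai_params` (the model name is only an example):
```python
from litellm import get_supported_openai_params

params = get_supported_openai_params(
    model="claude-3-5-sonnet-20240620",
    custom_llm_provider="anthropic",
)
print(params)  # e.g. ["stream", "stop", "temperature", ...]
```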

View file

@ -37,7 +37,7 @@ os.environ["HUGGINGFACE_API_KEY"] = "huggingface_api_key"
messages = [{ "content": "There's a llama in my garden 😱 What should I do?","role": "user"}]
# e.g. Call 'https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct' from Serverless Inference API
response = completion(
model="huggingface/meta-llama/Meta-Llama-3.1-8B-Instruct",
messages=[{ "content": "Hello, how are you?","role": "user"}],
stream=True
@ -165,14 +165,14 @@ Steps to use
```python
import os
from litellm import completion
os.environ["HUGGINGFACE_API_KEY"] = ""
# TGI model: Call https://huggingface.co/glaiveai/glaive-coder-7b
# add the 'huggingface/' prefix to the model to set huggingface as the provider
# set api base to your deployed api endpoint from hugging face
response = completion(
model="huggingface/glaiveai/glaive-coder-7b",
messages=[{ "content": "Hello, how are you?","role": "user"}],
api_base="https://wjiegasee9bmqke2.us-east-1.aws.endpoints.huggingface.cloud"
@ -383,6 +383,8 @@ def default_pt(messages):
#### Custom prompt templates
```python
import litellm
# Create your own custom prompt template
litellm.register_prompt_template(
model="togethercomputer/LLaMA-2-7B-32K",

View file

@ -1,6 +1,13 @@
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
# Jina AI
https://jina.ai/embeddings/
Supported endpoints:
- /embeddings
- /rerank
## API Key
```python
# env variable
@ -8,6 +15,10 @@ os.environ['JINA_AI_API_KEY']
```
## Sample Usage - Embedding
<Tabs>
<TabItem value="sdk" label="SDK">
```python
from litellm import embedding
import os
@ -19,6 +30,142 @@ response = embedding(
)
print(response)
```
</TabItem>
<TabItem value="proxy" label="PROXY">
1. Add to config.yaml
```yaml
model_list:
- model_name: embedding-model
litellm_params:
model: jina_ai/jina-embeddings-v3
api_key: os.environ/JINA_AI_API_KEY
```
2. Start proxy
```bash
litellm --config /path/to/config.yaml
# RUNNING on http://0.0.0.0:4000/
```
3. Test it!
```bash
curl -L -X POST 'http://0.0.0.0:4000/embeddings' \
-H 'Authorization: Bearer sk-1234' \
-H 'Content-Type: application/json' \
-d '{"input": ["hello world"], "model": "embedding-model"}'
```
</TabItem>
</Tabs>
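Since the proxy exposes an OpenAI-compatible `/embeddings` route, the OpenAI SDK can be pointed at it as well; a sketch using the proxy defaults above:
```python
import openai

client = openai.OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")

response = client.embeddings.create(
    model="embedding-model",  # the model_name from config.yaml
    input=["hello world"],
)
print(response)
```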
## Sample Usage - Rerank
<Tabs>
<TabItem value="sdk" label="SDK">
```python
from litellm import rerank
import os
os.environ["JINA_AI_API_KEY"] = "sk-..."
query = "What is the capital of the United States?"
documents = [
"Carson City is the capital city of the American state of Nevada.",
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.",
"Washington, D.C. is the capital of the United States.",
"Capital punishment has existed in the United States since before it was a country.",
]
response = rerank(
model="jina_ai/jina-reranker-v2-base-multilingual",
query=query,
documents=documents,
top_n=3,
)
print(response)
```
</TabItem>
<TabItem value="proxy" label="PROXY">
1. Add to config.yaml
```yaml
model_list:
- model_name: rerank-model
litellm_params:
model: jina_ai/jina-reranker-v2-base-multilingual
api_key: os.environ/JINA_AI_API_KEY
```
2. Start proxy
```bash
litellm --config /path/to/config.yaml
```
3. Test it!
```bash
curl -L -X POST 'http://0.0.0.0:4000/rerank' \
-H 'Authorization: Bearer sk-1234' \
-H 'Content-Type: application/json' \
-d '{
"model": "rerank-model",
"query": "What is the capital of the United States?",
"documents": [
"Carson City is the capital city of the American state of Nevada.",
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.",
"Washington, D.C. is the capital of the United States.",
"Capital punishment has existed in the United States since before it was a country."
],
"top_n": 3
}'
```
</TabItem>
</Tabs>
## Supported Models
All models listed here https://jina.ai/embeddings/ are supported
## Supported Optional Rerank Parameters
All cohere rerank parameters are supported.
## Supported Optional Embeddings Parameters
```
dimensions
```
## Provider-specific parameters
Pass any Jina AI-specific parameters as keyword arguments to the `embedding` or `rerank` function, e.g.
<Tabs>
<TabItem value="sdk" label="SDK">
```python
response = embedding(
model="jina_ai/jina-embeddings-v3",
input=["good morning from litellm"],
dimensions=1536,
my_custom_param="my_custom_value", # any other jina ai specific parameters
)
```
</TabItem>
<TabItem value="proxy" label="PROXY">
```bash
curl -L -X POST 'http://0.0.0.0:4000/embeddings' \
-H 'Authorization: Bearer sk-1234' \
-H 'Content-Type: application/json' \
-d '{"input": ["good morning from litellm"], "model": "jina_ai/jina-embeddings-v3", "dimensions": 1536, "my_custom_param": "my_custom_value"}'
```
</TabItem>
</Tabs>

View file

@ -1562,6 +1562,10 @@ curl http://0.0.0.0:4000/v1/chat/completions \
## **Embedding Models**
#### Usage - Embedding
<Tabs>
<TabItem value="sdk" label="SDK">
```python
import litellm
from litellm import embedding
@ -1574,6 +1578,49 @@ response = embedding(
)
print(response)
```
</TabItem>
<TabItem value="proxy" label="LiteLLM PROXY">
1. Add model to config.yaml
```yaml
model_list:
- model_name: snowflake-arctic-embed-m-long-1731622468876
litellm_params:
model: vertex_ai/<your-model-id>
vertex_project: "adroit-crow-413218"
vertex_location: "us-central1"
vertex_credentials: adroit-crow-413218-a956eef1a2a8.json
litellm_settings:
drop_params: True
```
2. Start Proxy
```
$ litellm --config /path/to/config.yaml
```
3. Make a request using the OpenAI Python SDK or Langchain Python SDK
```python
import openai
client = openai.OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")
response = client.embeddings.create(
model="snowflake-arctic-embed-m-long-1731622468876",
input = ["good morning from litellm", "this is another item"],
)
print(response)
```
</TabItem>
</Tabs>
#### Supported Embedding Models
All models listed [here](https://github.com/BerriAI/litellm/blob/57f37f743886a0249f630a6792d49dffc2c5d9b7/model_prices_and_context_window.json#L835) are supported
@ -1589,6 +1636,7 @@ All models listed [here](https://github.com/BerriAI/litellm/blob/57f37f743886a02
| textembedding-gecko@003 | `embedding(model="vertex_ai/textembedding-gecko@003", input)` |
| text-embedding-preview-0409 | `embedding(model="vertex_ai/text-embedding-preview-0409", input)` |
| text-multilingual-embedding-preview-0409 | `embedding(model="vertex_ai/text-multilingual-embedding-preview-0409", input)` |
| Fine-tuned OR Custom Embedding models | `embedding(model="vertex_ai/<your-model-id>", input)` |
### Supported OpenAI (Unified) Params

View file

@ -791,9 +791,9 @@ general_settings:
| store_model_in_db | boolean | If true, allows `/model/new` endpoint to store model information in db. Endpoint disabled by default. [Doc on `/model/new` endpoint](./model_management.md#create-a-new-model) |
| max_request_size_mb | int | The maximum size for requests in MB. Requests above this size will be rejected. |
| max_response_size_mb | int | The maximum size for responses in MB. LLM Responses above this size will not be sent. |
| proxy_budget_rescheduler_min_time | int | The minimum time (in seconds) to wait before checking db for budget resets. **Default is 597 seconds** |
| proxy_budget_rescheduler_max_time | int | The maximum time (in seconds) to wait before checking db for budget resets. **Default is 605 seconds** |
| proxy_batch_write_at | int | Time (in seconds) to wait before batch writing spend logs to the db. **Default is 10 seconds** |
| alerting_args | dict | Args for Slack Alerting [Doc on Slack Alerting](./alerting.md) |
| custom_key_generate | str | Custom function for key generation [Doc on custom key generation](./virtual_keys.md#custom--key-generate) |
| allowed_ips | List[str] | List of IPs allowed to access the proxy. If not set, all IPs are allowed. |
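For reference, a sketch of how these settings sit under `general_settings` in the proxy config; the budget rescheduler and batch-write values shown are the defaults from the table, the request size is an arbitrary example:
```yaml
general_settings:
  proxy_budget_rescheduler_min_time: 597   # seconds
  proxy_budget_rescheduler_max_time: 605   # seconds
  proxy_batch_write_at: 10                 # seconds
  max_request_size_mb: 16                  # example value, no default documented here
```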

View file

@ -66,10 +66,16 @@ Removes any field with `user_api_key_*` from metadata.
Found under `kwargs["standard_logging_object"]`. This is a standard payload, logged for every response.
```python
class StandardLoggingPayload(TypedDict):
id: str
trace_id: str # Trace multiple LLM calls belonging to same overall request (e.g. fallbacks/retries)
call_type: str
response_cost: float
response_cost_failure_debug_info: Optional[
StandardLoggingModelCostFailureDebugInformation
]
status: StandardLoggingPayloadStatus
total_tokens: int
prompt_tokens: int
completion_tokens: int
@ -84,13 +90,13 @@ class StandardLoggingPayload(TypedDict):
metadata: StandardLoggingMetadata
cache_hit: Optional[bool]
cache_key: Optional[str]
saved_cache_cost: float
request_tags: list
end_user: Optional[str]
requester_ip_address: Optional[str]
messages: Optional[Union[str, list, dict]]
response: Optional[Union[str, list, dict]]
error_str: Optional[str]
model_parameters: dict
hidden_params: StandardLoggingHiddenParams
@ -99,12 +105,47 @@ class StandardLoggingHiddenParams(TypedDict):
cache_key: Optional[str]
api_base: Optional[str]
response_cost: Optional[str]
additional_headers: Optional[StandardLoggingAdditionalHeaders]

class StandardLoggingAdditionalHeaders(TypedDict, total=False):
x_ratelimit_limit_requests: int
x_ratelimit_limit_tokens: int
x_ratelimit_remaining_requests: int
x_ratelimit_remaining_tokens: int

class StandardLoggingMetadata(StandardLoggingUserAPIKeyMetadata):
"""
Specific metadata k,v pairs logged to integration for easier cost tracking
"""
spend_logs_metadata: Optional[
dict
] # special param to log k,v pairs to spendlogs for a call
requester_ip_address: Optional[str]
requester_metadata: Optional[dict]

class StandardLoggingModelInformation(TypedDict):
model_map_key: str
model_map_value: Optional[ModelInfo]

StandardLoggingPayloadStatus = Literal["success", "failure"]

class StandardLoggingModelCostFailureDebugInformation(TypedDict, total=False):
"""
Debug information, if cost tracking fails.
Avoid logging sensitive information like response or optional params
"""
error_str: Required[str]
traceback_str: Required[str]
model: str
cache_hit: Optional[bool]
custom_llm_provider: Optional[str]
base_model: Optional[str]
call_type: str
custom_pricing: Optional[bool]
```
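A minimal sketch of reading this payload from a custom success callback (callback name and model are placeholders; `mock_response` avoids a real provider call):
```python
import litellm
from litellm import completion

def log_standard_payload(kwargs, response_obj, start_time, end_time):
    payload = kwargs.get("standard_logging_object") or {}
    print(payload.get("id"), payload.get("status"), payload.get("response_cost"))

litellm.success_callback = [log_standard_payload]

completion(
    model="gpt-4o-mini",
    messages=[{"role": "user", "content": "hi"}],
    mock_response="hello",
)
```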
## Langfuse

View file

@ -1,5 +1,6 @@
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
import Image from '@theme/IdealImage';
# ⚡ Best Practices for Production
@ -112,7 +113,35 @@ general_settings:
disable_spend_logs: True
```
## 7. Use Helm PreSync Hook for Database Migrations [BETA]
To ensure only one service manages database migrations, use our [Helm PreSync hook for Database Migrations](https://github.com/BerriAI/litellm/blob/main/deploy/charts/litellm-helm/templates/migrations-job.yaml). This ensures migrations are handled during `helm upgrade` or `helm install`, while LiteLLM pods explicitly disable migrations.
1. **Helm PreSync Hook**:
- The Helm PreSync hook is configured in the chart to run database migrations during deployments.
- The hook always sets `DISABLE_SCHEMA_UPDATE=false`, ensuring migrations are executed reliably.
Reference settings to set in `values.yaml` when deploying via ArgoCD:
```yaml
db:
useExisting: true # use existing Postgres DB
url: postgresql://ishaanjaffer0324:3rnwpOBau6hT@ep-withered-mud-a5dkdpke.us-east-2.aws.neon.tech/test-argo-cd?sslmode=require # url of existing Postgres DB
```
2. **LiteLLM Pods**:
- Set `DISABLE_SCHEMA_UPDATE=true` in LiteLLM pod configurations to prevent them from running migrations.
Example configuration for LiteLLM pod:
```yaml
env:
- name: DISABLE_SCHEMA_UPDATE
value: "true"
```
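A sketch of rolling this out with Helm directly; the chart path assumes a checkout of the litellm repo, and ArgoCD users would carry the same values in their Application spec:
```bash
helm upgrade --install litellm ./deploy/charts/litellm-helm \
  -f values.yaml   # contains the db.useExisting / db.url settings above
```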
## 8. Set LiteLLM Salt Key
If you plan on using the DB, set a salt key for encrypting/decrypting variables in the DB.

View file

@ -748,4 +748,19 @@ curl -L -X POST 'http://0.0.0.0:4000/v1/chat/completions' \
"max_tokens": 300,
"mock_testing_fallbacks": true
}'
```
### Disable Fallbacks per key
You can disable fallbacks per key by setting `disable_fallbacks: true` in your key metadata.
```bash
curl -L -X POST 'http://0.0.0.0:4000/key/generate' \
-H 'Authorization: Bearer sk-1234' \
-H 'Content-Type: application/json' \
-d '{
"metadata": {
"disable_fallbacks": true
}
}'
```
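Requests made with the returned key should then surface the provider error instead of falling back; a sketch, with `<GENERATED-KEY>` and the model name as placeholders:
```bash
curl -L -X POST 'http://0.0.0.0:4000/v1/chat/completions' \
-H 'Authorization: Bearer <GENERATED-KEY>' \
-H 'Content-Type: application/json' \
-d '{
    "model": "my-model-group",
    "messages": [{"role": "user", "content": "hi"}],
    "mock_testing_fallbacks": true
}'
```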

View file

@ -113,4 +113,5 @@ curl http://0.0.0.0:4000/rerank \
|-------------|--------------------|
| Cohere | [Usage](#quick-start) |
| Together AI| [Usage](../docs/providers/togetherai) |
| Azure AI| [Usage](../docs/providers/azure_ai) |
| Jina AI| [Usage](../docs/providers/jina_ai) |

View file

@ -1,3 +1,6 @@
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
# Secret Manager
LiteLLM supports reading secrets from Azure Key Vault, Google Secret Manager
@ -59,14 +62,35 @@ os.environ["AWS_REGION_NAME"] = "" # us-east-1, us-east-2, us-west-1, us-west-2
```
2. Enable AWS Secret Manager in config.
<Tabs>
<TabItem value="read_only" label="Read Keys from AWS Secret Manager">
```yaml
general_settings:
master_key: os.environ/litellm_master_key
key_management_system: "aws_secret_manager" # 👈 KEY CHANGE
key_management_settings:
hosted_keys: ["litellm_master_key"] # 👈 Specify which env keys you stored on AWS
```
</TabItem>
<TabItem value="write_only" label="Write Virtual Keys to AWS Secret Manager">
This will only store virtual keys in AWS Secret Manager. No keys will be read from AWS Secret Manager.
```yaml
general_settings:
key_management_system: "aws_secret_manager" # 👈 KEY CHANGE
key_management_settings:
store_virtual_keys: true
access_mode: "write_only" # Literal["read_only", "write_only", "read_and_write"]
```
</TabItem>
</Tabs>
3. Run proxy
```bash
@ -181,16 +205,14 @@ litellm --config /path/to/config.yaml
Use encrypted keys from Google KMS on the proxy
Step 1. Add keys to env
```
export GOOGLE_APPLICATION_CREDENTIALS="/path/to/credentials.json"
export GOOGLE_KMS_RESOURCE_NAME="projects/*/locations/*/keyRings/*/cryptoKeys/*"
export PROXY_DATABASE_URL_ENCRYPTED=b'\n$\x00D\xac\xb4/\x8e\xc...'
```
Step 2: Update Config
```yaml
general_settings:
@ -199,7 +221,7 @@ general_settings:
master_key: sk-1234
```
Step 3: Start + test proxy
```
$ litellm --config /path/to/config.yaml
@ -215,3 +237,17 @@ $ litellm --test
<!--
## .env Files
If no secret manager client is specified, Litellm automatically uses the `.env` file to manage sensitive data. -->
## All Secret Manager Settings
All settings related to secret management
```yaml
general_settings:
key_management_system: "aws_secret_manager" # REQUIRED
key_management_settings:
store_virtual_keys: true # OPTIONAL. Defaults to False, when True will store virtual keys in secret manager
access_mode: "write_only" # OPTIONAL. Literal["read_only", "write_only", "read_and_write"]. Defaults to "read_only"
hosted_keys: ["litellm_master_key"] # OPTIONAL. Specify which env keys you stored on AWS
```

View file

@ -305,7 +305,7 @@ secret_manager_client: Optional[Any] = (
)
_google_kms_resource_name: Optional[str] = None
_key_management_system: Optional[KeyManagementSystem] = None
_key_management_settings: KeyManagementSettings = KeyManagementSettings()
#### PII MASKING ####
output_parse_pii: bool = False
#############################################
@ -962,6 +962,8 @@ from .utils import (
supports_response_schema,
supports_parallel_function_calling,
supports_vision,
supports_audio_input,
supports_audio_output,
supports_system_messages,
get_litellm_params,
acreate,
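The newly exported helpers can be used like the other `supports_*` checks; a sketch that assumes they share the signature of `litellm.supports_vision` (the model name is only an example):
```python
import litellm

print(litellm.supports_audio_input(model="gpt-4o-audio-preview"))
print(litellm.supports_audio_output(model="gpt-4o-audio-preview"))
```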

View file

@ -46,6 +46,9 @@ from litellm.llms.OpenAI.cost_calculation import (
from litellm.llms.OpenAI.cost_calculation import cost_per_token as openai_cost_per_token
from litellm.llms.OpenAI.cost_calculation import cost_router as openai_cost_router
from litellm.llms.together_ai.cost_calculator import get_model_params_and_category
from litellm.llms.vertex_ai_and_google_ai_studio.image_generation.cost_calculator import (
cost_calculator as vertex_ai_image_cost_calculator,
)
from litellm.types.llms.openai import HttpxBinaryResponseContent
from litellm.types.rerank import RerankResponse
from litellm.types.router import SPECIAL_MODEL_INFO_PARAMS
@ -667,9 +670,11 @@ def completion_cost( # noqa: PLR0915
):
### IMAGE GENERATION COST CALCULATION ###
if custom_llm_provider == "vertex_ai":
if isinstance(completion_response, ImageResponse):
return vertex_ai_image_cost_calculator(
model=model,
image_response=completion_response,
)
elif custom_llm_provider == "bedrock":
if isinstance(completion_response, ImageResponse):
return bedrock_image_cost_calculator(

View file

@ -239,7 +239,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"ContextWindowExceededError: {exception_provider} - {message}",
llm_provider=custom_llm_provider,
model=model,
response=getattr(original_exception, "response", None),
litellm_debug_info=extra_information,
)
elif (
@ -251,7 +251,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"{exception_provider} - {message}",
llm_provider=custom_llm_provider,
model=model,
response=getattr(original_exception, "response", None),
litellm_debug_info=extra_information,
)
elif "A timeout occurred" in error_str:
@ -271,7 +271,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"ContentPolicyViolationError: {exception_provider} - {message}",
llm_provider=custom_llm_provider,
model=model,
response=getattr(original_exception, "response", None),
litellm_debug_info=extra_information,
)
elif (
@ -283,7 +283,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"{exception_provider} - {message}",
llm_provider=custom_llm_provider,
model=model,
response=getattr(original_exception, "response", None),
litellm_debug_info=extra_information,
)
elif "Web server is returning an unknown error" in error_str:
@ -299,7 +299,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"RateLimitError: {exception_provider} - {message}",
model=model,
llm_provider=custom_llm_provider,
response=getattr(original_exception, "response", None),
litellm_debug_info=extra_information,
)
elif (
@ -311,7 +311,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"AuthenticationError: {exception_provider} - {message}",
llm_provider=custom_llm_provider,
model=model,
response=getattr(original_exception, "response", None),
litellm_debug_info=extra_information,
)
elif "Mistral API raised a streaming error" in error_str:
@ -335,7 +335,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"{exception_provider} - {message}",
llm_provider=custom_llm_provider,
model=model,
response=getattr(original_exception, "response", None),
litellm_debug_info=extra_information,
)
elif original_exception.status_code == 401:
@ -344,7 +344,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"AuthenticationError: {exception_provider} - {message}",
llm_provider=custom_llm_provider,
model=model,
response=getattr(original_exception, "response", None),
litellm_debug_info=extra_information,
)
elif original_exception.status_code == 404:
@ -353,7 +353,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"NotFoundError: {exception_provider} - {message}",
model=model,
llm_provider=custom_llm_provider,
response=getattr(original_exception, "response", None),
litellm_debug_info=extra_information,
)
elif original_exception.status_code == 408:
@ -516,7 +516,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"ReplicateException - {error_str}",
llm_provider="replicate",
model=model,
response=getattr(original_exception, "response", None),
)
elif "input is too long" in error_str:
exception_mapping_worked = True
@ -524,7 +524,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"ReplicateException - {error_str}",
model=model,
llm_provider="replicate",
response=getattr(original_exception, "response", None),
)
elif exception_type == "ModelError":
exception_mapping_worked = True
@ -532,7 +532,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"ReplicateException - {error_str}",
model=model,
llm_provider="replicate",
response=getattr(original_exception, "response", None),
)
elif "Request was throttled" in error_str:
exception_mapping_worked = True
@ -540,7 +540,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"ReplicateException - {error_str}",
llm_provider="replicate",
model=model,
response=getattr(original_exception, "response", None),
)
elif hasattr(original_exception, "status_code"):
if original_exception.status_code == 401:
@ -549,7 +549,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"ReplicateException - {original_exception.message}",
llm_provider="replicate",
model=model,
response=getattr(original_exception, "response", None),
)
elif (
original_exception.status_code == 400
@ -560,7 +560,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"ReplicateException - {original_exception.message}",
model=model,
llm_provider="replicate",
response=getattr(original_exception, "response", None),
)
elif original_exception.status_code == 422:
exception_mapping_worked = True
@ -568,7 +568,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"ReplicateException - {original_exception.message}",
model=model,
llm_provider="replicate",
response=getattr(original_exception, "response", None),
)
elif original_exception.status_code == 408:
exception_mapping_worked = True
@ -583,7 +583,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"ReplicateException - {original_exception.message}",
llm_provider="replicate",
model=model,
response=getattr(original_exception, "response", None),
)
elif original_exception.status_code == 429:
exception_mapping_worked = True
@ -591,7 +591,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"ReplicateException - {original_exception.message}",
llm_provider="replicate",
model=model,
response=getattr(original_exception, "response", None),
)
elif original_exception.status_code == 500:
exception_mapping_worked = True
@ -599,7 +599,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"ReplicateException - {original_exception.message}",
llm_provider="replicate",
model=model,
response=getattr(original_exception, "response", None),
)
exception_mapping_worked = True
raise APIError(
@ -631,7 +631,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"{custom_llm_provider}Exception: Authentication Error - {error_str}",
llm_provider=custom_llm_provider,
model=model,
response=getattr(original_exception, "response", None),
litellm_debug_info=extra_information,
)
elif "token_quota_reached" in error_str:
@ -640,7 +640,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"{custom_llm_provider}Exception: Rate Limit Errror - {error_str}",
llm_provider=custom_llm_provider,
model=model,
response=getattr(original_exception, "response", None),
)
elif (
"The server received an invalid response from an upstream server."
@ -750,7 +750,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"BedrockException - {error_str}\n. Enable 'litellm.modify_params=True' (for PROXY do: `litellm_settings::modify_params: True`) to insert a dummy assistant message and fix this error.",
model=model,
llm_provider="bedrock",
response=getattr(original_exception, "response", None),
)
elif "Malformed input request" in error_str:
exception_mapping_worked = True
@ -758,7 +758,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"BedrockException - {error_str}",
model=model,
llm_provider="bedrock",
response=getattr(original_exception, "response", None),
)
elif "A conversation must start with a user message." in error_str:
exception_mapping_worked = True
@ -766,7 +766,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"BedrockException - {error_str}\n. Pass in default user message via `completion(..,user_continue_message=)` or enable `litellm.modify_params=True`.\nFor Proxy: do via `litellm_settings::modify_params: True` or user_continue_message under `litellm_params`",
model=model,
llm_provider="bedrock",
response=getattr(original_exception, "response", None),
)
elif (
"Unable to locate credentials" in error_str
@ -778,7 +778,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"BedrockException Invalid Authentication - {error_str}",
model=model,
llm_provider="bedrock",
response=getattr(original_exception, "response", None),
)
elif "AccessDeniedException" in error_str:
exception_mapping_worked = True
@ -786,7 +786,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"BedrockException PermissionDeniedError - {error_str}",
model=model,
llm_provider="bedrock",
response=getattr(original_exception, "response", None),
)
elif (
"throttlingException" in error_str
@ -797,7 +797,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"BedrockException: Rate Limit Error - {error_str}",
model=model,
llm_provider="bedrock",
response=getattr(original_exception, "response", None),
)
elif (
"Connect timeout on endpoint URL" in error_str
@ -836,7 +836,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"BedrockException - {original_exception.message}",
llm_provider="bedrock",
model=model,
response=getattr(original_exception, "response", None),
)
elif original_exception.status_code == 400:
exception_mapping_worked = True
@ -844,7 +844,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"BedrockException - {original_exception.message}",
llm_provider="bedrock",
model=model,
response=getattr(original_exception, "response", None),
)
elif original_exception.status_code == 404:
exception_mapping_worked = True
@ -852,7 +852,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"BedrockException - {original_exception.message}",
llm_provider="bedrock",
model=model,
response=getattr(original_exception, "response", None),
)
elif original_exception.status_code == 408:
exception_mapping_worked = True
@ -868,7 +868,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"BedrockException - {original_exception.message}",
model=model,
llm_provider=custom_llm_provider,
response=getattr(original_exception, "response", None),
litellm_debug_info=extra_information,
)
elif original_exception.status_code == 429:
@ -877,7 +877,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"BedrockException - {original_exception.message}",
model=model,
llm_provider=custom_llm_provider,
response=getattr(original_exception, "response", None),
litellm_debug_info=extra_information,
)
elif original_exception.status_code == 503:
@ -886,7 +886,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"BedrockException - {original_exception.message}",
model=model,
llm_provider=custom_llm_provider,
response=getattr(original_exception, "response", None),
litellm_debug_info=extra_information,
)
elif original_exception.status_code == 504: # gateway timeout error
@ -907,7 +907,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"litellm.BadRequestError: SagemakerException - {error_str}",
model=model,
llm_provider="sagemaker",
response=getattr(original_exception, "response", None),
)
elif (
"Input validation error: `best_of` must be > 0 and <= 2"
@ -918,7 +918,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message="SagemakerException - the value of 'n' must be > 0 and <= 2 for sagemaker endpoints",
model=model,
llm_provider="sagemaker",
response=getattr(original_exception, "response", None),
)
elif (
"`inputs` tokens + `max_new_tokens` must be <=" in error_str
@ -929,7 +929,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"SagemakerException - {error_str}",
model=model,
llm_provider="sagemaker",
response=getattr(original_exception, "response", None),
)
elif hasattr(original_exception, "status_code"):
if original_exception.status_code == 500:
@ -951,7 +951,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"SagemakerException - {original_exception.message}",
llm_provider=custom_llm_provider,
model=model,
response=getattr(original_exception, "response", None),
)
elif original_exception.status_code == 400:
exception_mapping_worked = True
@ -959,7 +959,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"SagemakerException - {original_exception.message}",
llm_provider=custom_llm_provider,
model=model,
response=getattr(original_exception, "response", None),
)
elif original_exception.status_code == 404:
exception_mapping_worked = True
@ -967,7 +967,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"SagemakerException - {original_exception.message}",
llm_provider=custom_llm_provider,
model=model,
response=getattr(original_exception, "response", None),
)
elif original_exception.status_code == 408:
exception_mapping_worked = True
@ -986,7 +986,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"SagemakerException - {original_exception.message}",
model=model,
llm_provider=custom_llm_provider,
response=getattr(original_exception, "response", None),
litellm_debug_info=extra_information,
)
elif original_exception.status_code == 429:
@ -995,7 +995,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"SagemakerException - {original_exception.message}",
model=model,
llm_provider=custom_llm_provider,
response=getattr(original_exception, "response", None),
litellm_debug_info=extra_information,
)
elif original_exception.status_code == 503:
@ -1004,7 +1004,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"SagemakerException - {original_exception.message}",
model=model,
llm_provider=custom_llm_provider,
response=getattr(original_exception, "response", None),
litellm_debug_info=extra_information,
)
elif original_exception.status_code == 504: # gateway timeout error
@ -1217,7 +1217,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message="GeminiException - Invalid api key",
model=model,
llm_provider="palm",
response=getattr(original_exception, "response", None),
)
if (
"504 Deadline expired before operation could complete." in error_str
@ -1235,7 +1235,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"GeminiException - {error_str}",
model=model,
llm_provider="palm",
response=getattr(original_exception, "response", None),
)
if (
"500 An internal error has occurred." in error_str
@ -1262,7 +1262,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"GeminiException - {error_str}",
model=model,
llm_provider="palm",
response=getattr(original_exception, "response", None),
)
# Dailed: Error occurred: 400 Request payload size exceeds the limit: 20000 bytes
elif custom_llm_provider == "cloudflare":
@ -1272,7 +1272,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"Cloudflare Exception - {original_exception.message}",
llm_provider="cloudflare",
model=model,
response=getattr(original_exception, "response", None),
)
if "must have required property" in error_str:
exception_mapping_worked = True
@ -1280,7 +1280,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"Cloudflare Exception - {original_exception.message}",
llm_provider="cloudflare",
model=model,
response=getattr(original_exception, "response", None),
)
elif (
custom_llm_provider == "cohere" or custom_llm_provider == "cohere_chat"
@ -1294,7 +1294,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"CohereException - {original_exception.message}",
llm_provider="cohere",
model=model,
response=getattr(original_exception, "response", None),
)
elif "too many tokens" in error_str:
exception_mapping_worked = True
@ -1302,7 +1302,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"CohereException - {original_exception.message}",
model=model,
llm_provider="cohere",
response=getattr(original_exception, "response", None),
)
elif hasattr(original_exception, "status_code"):
if (
@ -1314,7 +1314,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"CohereException - {original_exception.message}",
llm_provider="cohere",
model=model,
response=getattr(original_exception, "response", None),
)
elif original_exception.status_code == 408:
exception_mapping_worked = True
@ -1329,7 +1329,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"CohereException - {original_exception.message}",
llm_provider="cohere",
model=model,
response=getattr(original_exception, "response", None),
)
elif (
"CohereConnectionError" in exception_type
@ -1339,7 +1339,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"CohereException - {original_exception.message}",
llm_provider="cohere",
model=model,
response=getattr(original_exception, "response", None),
)
elif "invalid type:" in error_str:
exception_mapping_worked = True
@ -1347,7 +1347,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"CohereException - {original_exception.message}",
llm_provider="cohere",
model=model,
response=getattr(original_exception, "response", None),
)
elif "Unexpected server error" in error_str:
exception_mapping_worked = True
@ -1355,7 +1355,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"CohereException - {original_exception.message}",
llm_provider="cohere",
model=model,
response=getattr(original_exception, "response", None),
)
else:
if hasattr(original_exception, "status_code"):
@ -1375,7 +1375,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=error_str,
model=model,
llm_provider="huggingface",
response=getattr(original_exception, "response", None),
)
elif "A valid user token is required" in error_str:
exception_mapping_worked = True
@ -1383,7 +1383,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=error_str, message=error_str,
llm_provider="huggingface", llm_provider="huggingface",
model=model, model=model,
response=original_exception.response, response=getattr(original_exception, "response", None),
) )
elif "Rate limit reached" in error_str: elif "Rate limit reached" in error_str:
exception_mapping_worked = True exception_mapping_worked = True
@ -1391,7 +1391,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=error_str, message=error_str,
llm_provider="huggingface", llm_provider="huggingface",
model=model, model=model,
response=original_exception.response, response=getattr(original_exception, "response", None),
) )
if hasattr(original_exception, "status_code"): if hasattr(original_exception, "status_code"):
if original_exception.status_code == 401: if original_exception.status_code == 401:
@ -1400,7 +1400,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"HuggingfaceException - {original_exception.message}", message=f"HuggingfaceException - {original_exception.message}",
llm_provider="huggingface", llm_provider="huggingface",
model=model, model=model,
response=original_exception.response, response=getattr(original_exception, "response", None),
) )
elif original_exception.status_code == 400: elif original_exception.status_code == 400:
exception_mapping_worked = True exception_mapping_worked = True
@ -1408,7 +1408,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"HuggingfaceException - {original_exception.message}", message=f"HuggingfaceException - {original_exception.message}",
model=model, model=model,
llm_provider="huggingface", llm_provider="huggingface",
response=original_exception.response, response=getattr(original_exception, "response", None),
) )
elif original_exception.status_code == 408: elif original_exception.status_code == 408:
exception_mapping_worked = True exception_mapping_worked = True
@ -1423,7 +1423,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"HuggingfaceException - {original_exception.message}", message=f"HuggingfaceException - {original_exception.message}",
llm_provider="huggingface", llm_provider="huggingface",
model=model, model=model,
response=original_exception.response, response=getattr(original_exception, "response", None),
) )
elif original_exception.status_code == 503: elif original_exception.status_code == 503:
exception_mapping_worked = True exception_mapping_worked = True
@ -1431,7 +1431,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"HuggingfaceException - {original_exception.message}", message=f"HuggingfaceException - {original_exception.message}",
llm_provider="huggingface", llm_provider="huggingface",
model=model, model=model,
response=original_exception.response, response=getattr(original_exception, "response", None),
) )
else: else:
exception_mapping_worked = True exception_mapping_worked = True
@ -1450,7 +1450,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"AI21Exception - {original_exception.message}", message=f"AI21Exception - {original_exception.message}",
model=model, model=model,
llm_provider="ai21", llm_provider="ai21",
response=original_exception.response, response=getattr(original_exception, "response", None),
) )
if "Bad or missing API token." in original_exception.message: if "Bad or missing API token." in original_exception.message:
exception_mapping_worked = True exception_mapping_worked = True
@ -1458,7 +1458,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"AI21Exception - {original_exception.message}", message=f"AI21Exception - {original_exception.message}",
model=model, model=model,
llm_provider="ai21", llm_provider="ai21",
response=original_exception.response, response=getattr(original_exception, "response", None),
) )
if hasattr(original_exception, "status_code"): if hasattr(original_exception, "status_code"):
if original_exception.status_code == 401: if original_exception.status_code == 401:
@ -1467,7 +1467,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"AI21Exception - {original_exception.message}", message=f"AI21Exception - {original_exception.message}",
llm_provider="ai21", llm_provider="ai21",
model=model, model=model,
response=original_exception.response, response=getattr(original_exception, "response", None),
) )
elif original_exception.status_code == 408: elif original_exception.status_code == 408:
exception_mapping_worked = True exception_mapping_worked = True
@ -1482,7 +1482,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"AI21Exception - {original_exception.message}", message=f"AI21Exception - {original_exception.message}",
model=model, model=model,
llm_provider="ai21", llm_provider="ai21",
response=original_exception.response, response=getattr(original_exception, "response", None),
) )
elif original_exception.status_code == 429: elif original_exception.status_code == 429:
exception_mapping_worked = True exception_mapping_worked = True
@ -1490,7 +1490,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"AI21Exception - {original_exception.message}", message=f"AI21Exception - {original_exception.message}",
llm_provider="ai21", llm_provider="ai21",
model=model, model=model,
response=original_exception.response, response=getattr(original_exception, "response", None),
) )
else: else:
exception_mapping_worked = True exception_mapping_worked = True
@ -1509,7 +1509,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"NLPCloudException - {error_str}", message=f"NLPCloudException - {error_str}",
model=model, model=model,
llm_provider="nlp_cloud", llm_provider="nlp_cloud",
response=original_exception.response, response=getattr(original_exception, "response", None),
) )
elif "value is not a valid" in error_str: elif "value is not a valid" in error_str:
exception_mapping_worked = True exception_mapping_worked = True
@ -1517,7 +1517,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"NLPCloudException - {error_str}", message=f"NLPCloudException - {error_str}",
model=model, model=model,
llm_provider="nlp_cloud", llm_provider="nlp_cloud",
response=original_exception.response, response=getattr(original_exception, "response", None),
) )
else: else:
exception_mapping_worked = True exception_mapping_worked = True
@ -1542,7 +1542,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"NLPCloudException - {original_exception.message}", message=f"NLPCloudException - {original_exception.message}",
llm_provider="nlp_cloud", llm_provider="nlp_cloud",
model=model, model=model,
response=original_exception.response, response=getattr(original_exception, "response", None),
) )
elif ( elif (
original_exception.status_code == 401 original_exception.status_code == 401
@ -1553,7 +1553,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"NLPCloudException - {original_exception.message}", message=f"NLPCloudException - {original_exception.message}",
llm_provider="nlp_cloud", llm_provider="nlp_cloud",
model=model, model=model,
response=original_exception.response, response=getattr(original_exception, "response", None),
) )
elif ( elif (
original_exception.status_code == 522 original_exception.status_code == 522
@ -1574,7 +1574,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"NLPCloudException - {original_exception.message}", message=f"NLPCloudException - {original_exception.message}",
llm_provider="nlp_cloud", llm_provider="nlp_cloud",
model=model, model=model,
response=original_exception.response, response=getattr(original_exception, "response", None),
) )
elif ( elif (
original_exception.status_code == 500 original_exception.status_code == 500
@ -1597,7 +1597,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"NLPCloudException - {original_exception.message}", message=f"NLPCloudException - {original_exception.message}",
model=model, model=model,
llm_provider="nlp_cloud", llm_provider="nlp_cloud",
response=original_exception.response, response=getattr(original_exception, "response", None),
) )
else: else:
exception_mapping_worked = True exception_mapping_worked = True
@ -1623,7 +1623,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"TogetherAIException - {error_response['error']}", message=f"TogetherAIException - {error_response['error']}",
model=model, model=model,
llm_provider="together_ai", llm_provider="together_ai",
response=original_exception.response, response=getattr(original_exception, "response", None),
) )
elif ( elif (
"error" in error_response "error" in error_response
@ -1634,7 +1634,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"TogetherAIException - {error_response['error']}", message=f"TogetherAIException - {error_response['error']}",
llm_provider="together_ai", llm_provider="together_ai",
model=model, model=model,
response=original_exception.response, response=getattr(original_exception, "response", None),
) )
elif ( elif (
"error" in error_response "error" in error_response
@ -1645,7 +1645,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"TogetherAIException - {error_response['error']}", message=f"TogetherAIException - {error_response['error']}",
model=model, model=model,
llm_provider="together_ai", llm_provider="together_ai",
response=original_exception.response, response=getattr(original_exception, "response", None),
) )
elif "A timeout occurred" in error_str: elif "A timeout occurred" in error_str:
exception_mapping_worked = True exception_mapping_worked = True
@ -1664,7 +1664,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"TogetherAIException - {error_response['error']}", message=f"TogetherAIException - {error_response['error']}",
model=model, model=model,
llm_provider="together_ai", llm_provider="together_ai",
response=original_exception.response, response=getattr(original_exception, "response", None),
) )
elif ( elif (
"error_type" in error_response "error_type" in error_response
@ -1675,7 +1675,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"TogetherAIException - {error_response['error']}", message=f"TogetherAIException - {error_response['error']}",
model=model, model=model,
llm_provider="together_ai", llm_provider="together_ai",
response=original_exception.response, response=getattr(original_exception, "response", None),
) )
if hasattr(original_exception, "status_code"): if hasattr(original_exception, "status_code"):
if original_exception.status_code == 408: if original_exception.status_code == 408:
@ -1691,7 +1691,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"TogetherAIException - {error_response['error']}", message=f"TogetherAIException - {error_response['error']}",
model=model, model=model,
llm_provider="together_ai", llm_provider="together_ai",
response=original_exception.response, response=getattr(original_exception, "response", None),
) )
elif original_exception.status_code == 429: elif original_exception.status_code == 429:
exception_mapping_worked = True exception_mapping_worked = True
@ -1699,7 +1699,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"TogetherAIException - {original_exception.message}", message=f"TogetherAIException - {original_exception.message}",
llm_provider="together_ai", llm_provider="together_ai",
model=model, model=model,
response=original_exception.response, response=getattr(original_exception, "response", None),
) )
elif original_exception.status_code == 524: elif original_exception.status_code == 524:
exception_mapping_worked = True exception_mapping_worked = True
@ -1727,7 +1727,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"AlephAlphaException - {original_exception.message}", message=f"AlephAlphaException - {original_exception.message}",
llm_provider="aleph_alpha", llm_provider="aleph_alpha",
model=model, model=model,
response=original_exception.response, response=getattr(original_exception, "response", None),
) )
elif "InvalidToken" in error_str or "No token provided" in error_str: elif "InvalidToken" in error_str or "No token provided" in error_str:
exception_mapping_worked = True exception_mapping_worked = True
@ -1735,7 +1735,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"AlephAlphaException - {original_exception.message}", message=f"AlephAlphaException - {original_exception.message}",
llm_provider="aleph_alpha", llm_provider="aleph_alpha",
model=model, model=model,
response=original_exception.response, response=getattr(original_exception, "response", None),
) )
elif hasattr(original_exception, "status_code"): elif hasattr(original_exception, "status_code"):
verbose_logger.debug( verbose_logger.debug(
@ -1754,7 +1754,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"AlephAlphaException - {original_exception.message}", message=f"AlephAlphaException - {original_exception.message}",
llm_provider="aleph_alpha", llm_provider="aleph_alpha",
model=model, model=model,
response=original_exception.response, response=getattr(original_exception, "response", None),
) )
elif original_exception.status_code == 429: elif original_exception.status_code == 429:
exception_mapping_worked = True exception_mapping_worked = True
@ -1762,7 +1762,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"AlephAlphaException - {original_exception.message}", message=f"AlephAlphaException - {original_exception.message}",
llm_provider="aleph_alpha", llm_provider="aleph_alpha",
model=model, model=model,
response=original_exception.response, response=getattr(original_exception, "response", None),
) )
elif original_exception.status_code == 500: elif original_exception.status_code == 500:
exception_mapping_worked = True exception_mapping_worked = True
@ -1770,7 +1770,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"AlephAlphaException - {original_exception.message}", message=f"AlephAlphaException - {original_exception.message}",
llm_provider="aleph_alpha", llm_provider="aleph_alpha",
model=model, model=model,
response=original_exception.response, response=getattr(original_exception, "response", None),
) )
raise original_exception raise original_exception
raise original_exception raise original_exception
@ -1787,7 +1787,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"OllamaException: Invalid Model/Model not loaded - {original_exception}", message=f"OllamaException: Invalid Model/Model not loaded - {original_exception}",
model=model, model=model,
llm_provider="ollama", llm_provider="ollama",
response=original_exception.response, response=getattr(original_exception, "response", None),
) )
elif "Failed to establish a new connection" in error_str: elif "Failed to establish a new connection" in error_str:
exception_mapping_worked = True exception_mapping_worked = True
@ -1795,7 +1795,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"OllamaException: {original_exception}", message=f"OllamaException: {original_exception}",
llm_provider="ollama", llm_provider="ollama",
model=model, model=model,
response=original_exception.response, response=getattr(original_exception, "response", None),
) )
elif "Invalid response object from API" in error_str: elif "Invalid response object from API" in error_str:
exception_mapping_worked = True exception_mapping_worked = True
@ -1803,7 +1803,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"OllamaException: {original_exception}", message=f"OllamaException: {original_exception}",
llm_provider="ollama", llm_provider="ollama",
model=model, model=model,
response=original_exception.response, response=getattr(original_exception, "response", None),
) )
elif "Read timed out" in error_str: elif "Read timed out" in error_str:
exception_mapping_worked = True exception_mapping_worked = True
@ -1837,6 +1837,7 @@ def exception_type( # type: ignore # noqa: PLR0915
llm_provider="azure", llm_provider="azure",
model=model, model=model,
litellm_debug_info=extra_information, litellm_debug_info=extra_information,
response=getattr(original_exception, "response", None),
) )
elif "This model's maximum context length is" in error_str: elif "This model's maximum context length is" in error_str:
exception_mapping_worked = True exception_mapping_worked = True
@ -1845,6 +1846,7 @@ def exception_type( # type: ignore # noqa: PLR0915
llm_provider="azure", llm_provider="azure",
model=model, model=model,
litellm_debug_info=extra_information, litellm_debug_info=extra_information,
response=getattr(original_exception, "response", None),
) )
elif "DeploymentNotFound" in error_str: elif "DeploymentNotFound" in error_str:
exception_mapping_worked = True exception_mapping_worked = True
@ -1853,6 +1855,7 @@ def exception_type( # type: ignore # noqa: PLR0915
llm_provider="azure", llm_provider="azure",
model=model, model=model,
litellm_debug_info=extra_information, litellm_debug_info=extra_information,
response=getattr(original_exception, "response", None),
) )
elif ( elif (
( (
@ -1873,6 +1876,7 @@ def exception_type( # type: ignore # noqa: PLR0915
llm_provider="azure", llm_provider="azure",
model=model, model=model,
litellm_debug_info=extra_information, litellm_debug_info=extra_information,
response=getattr(original_exception, "response", None),
) )
elif "invalid_request_error" in error_str: elif "invalid_request_error" in error_str:
exception_mapping_worked = True exception_mapping_worked = True
@ -1881,6 +1885,7 @@ def exception_type( # type: ignore # noqa: PLR0915
llm_provider="azure", llm_provider="azure",
model=model, model=model,
litellm_debug_info=extra_information, litellm_debug_info=extra_information,
response=getattr(original_exception, "response", None),
) )
elif ( elif (
"The api_key client option must be set either by passing api_key to the client or by setting" "The api_key client option must be set either by passing api_key to the client or by setting"
@ -1892,6 +1897,7 @@ def exception_type( # type: ignore # noqa: PLR0915
llm_provider=custom_llm_provider, llm_provider=custom_llm_provider,
model=model, model=model,
litellm_debug_info=extra_information, litellm_debug_info=extra_information,
response=getattr(original_exception, "response", None),
) )
elif "Connection error" in error_str: elif "Connection error" in error_str:
exception_mapping_worked = True exception_mapping_worked = True
@ -1910,6 +1916,7 @@ def exception_type( # type: ignore # noqa: PLR0915
llm_provider="azure", llm_provider="azure",
model=model, model=model,
litellm_debug_info=extra_information, litellm_debug_info=extra_information,
response=getattr(original_exception, "response", None),
) )
elif original_exception.status_code == 401: elif original_exception.status_code == 401:
exception_mapping_worked = True exception_mapping_worked = True
@ -1918,6 +1925,7 @@ def exception_type( # type: ignore # noqa: PLR0915
llm_provider="azure", llm_provider="azure",
model=model, model=model,
litellm_debug_info=extra_information, litellm_debug_info=extra_information,
response=getattr(original_exception, "response", None),
) )
elif original_exception.status_code == 408: elif original_exception.status_code == 408:
exception_mapping_worked = True exception_mapping_worked = True
@ -1934,6 +1942,7 @@ def exception_type( # type: ignore # noqa: PLR0915
model=model, model=model,
llm_provider="azure", llm_provider="azure",
litellm_debug_info=extra_information, litellm_debug_info=extra_information,
response=getattr(original_exception, "response", None),
) )
elif original_exception.status_code == 429: elif original_exception.status_code == 429:
exception_mapping_worked = True exception_mapping_worked = True
@ -1942,6 +1951,7 @@ def exception_type( # type: ignore # noqa: PLR0915
model=model, model=model,
llm_provider="azure", llm_provider="azure",
litellm_debug_info=extra_information, litellm_debug_info=extra_information,
response=getattr(original_exception, "response", None),
) )
elif original_exception.status_code == 503: elif original_exception.status_code == 503:
exception_mapping_worked = True exception_mapping_worked = True
@ -1950,6 +1960,7 @@ def exception_type( # type: ignore # noqa: PLR0915
model=model, model=model,
llm_provider="azure", llm_provider="azure",
litellm_debug_info=extra_information, litellm_debug_info=extra_information,
response=getattr(original_exception, "response", None),
) )
elif original_exception.status_code == 504: # gateway timeout error elif original_exception.status_code == 504: # gateway timeout error
exception_mapping_worked = True exception_mapping_worked = True
@ -1989,7 +2000,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"{exception_provider} - {error_str}", message=f"{exception_provider} - {error_str}",
llm_provider=custom_llm_provider, llm_provider=custom_llm_provider,
model=model, model=model,
response=original_exception.response, response=getattr(original_exception, "response", None),
litellm_debug_info=extra_information, litellm_debug_info=extra_information,
) )
elif original_exception.status_code == 401: elif original_exception.status_code == 401:
@ -1998,7 +2009,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"AuthenticationError: {exception_provider} - {error_str}", message=f"AuthenticationError: {exception_provider} - {error_str}",
llm_provider=custom_llm_provider, llm_provider=custom_llm_provider,
model=model, model=model,
response=original_exception.response, response=getattr(original_exception, "response", None),
litellm_debug_info=extra_information, litellm_debug_info=extra_information,
) )
elif original_exception.status_code == 404: elif original_exception.status_code == 404:
@ -2007,7 +2018,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"NotFoundError: {exception_provider} - {error_str}", message=f"NotFoundError: {exception_provider} - {error_str}",
model=model, model=model,
llm_provider=custom_llm_provider, llm_provider=custom_llm_provider,
response=original_exception.response, response=getattr(original_exception, "response", None),
litellm_debug_info=extra_information, litellm_debug_info=extra_information,
) )
elif original_exception.status_code == 408: elif original_exception.status_code == 408:
@ -2024,7 +2035,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"BadRequestError: {exception_provider} - {error_str}", message=f"BadRequestError: {exception_provider} - {error_str}",
model=model, model=model,
llm_provider=custom_llm_provider, llm_provider=custom_llm_provider,
response=original_exception.response, response=getattr(original_exception, "response", None),
litellm_debug_info=extra_information, litellm_debug_info=extra_information,
) )
elif original_exception.status_code == 429: elif original_exception.status_code == 429:
@ -2033,7 +2044,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"RateLimitError: {exception_provider} - {error_str}", message=f"RateLimitError: {exception_provider} - {error_str}",
model=model, model=model,
llm_provider=custom_llm_provider, llm_provider=custom_llm_provider,
response=original_exception.response, response=getattr(original_exception, "response", None),
litellm_debug_info=extra_information, litellm_debug_info=extra_information,
) )
elif original_exception.status_code == 503: elif original_exception.status_code == 503:
@ -2042,7 +2053,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"ServiceUnavailableError: {exception_provider} - {error_str}", message=f"ServiceUnavailableError: {exception_provider} - {error_str}",
model=model, model=model,
llm_provider=custom_llm_provider, llm_provider=custom_llm_provider,
response=original_exception.response, response=getattr(original_exception, "response", None),
litellm_debug_info=extra_information, litellm_debug_info=extra_information,
) )
elif original_exception.status_code == 504: # gateway timeout error elif original_exception.status_code == 504: # gateway timeout error
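The change running through this file replaces direct `original_exception.response` access with `getattr(original_exception, "response", None)`, so the exception mapper no longer raises `AttributeError` when a provider exception carries no `response` attribute. A minimal sketch of the difference (the exception class below is illustrative, not a LiteLLM type):

```
class ProviderError(Exception):
    """Hypothetical provider exception that has a message but no `response` attribute."""

    def __init__(self, message):
        super().__init__(message)
        self.message = message


err = ProviderError("CohereException - too many tokens")

# old pattern: err.response would raise AttributeError here
# new pattern: degrades gracefully to None and the mapped exception still gets built
response = getattr(err, "response", None)
print(response)  # None
```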

View file

@ -202,6 +202,7 @@ class Logging:
start_time, start_time,
litellm_call_id: str, litellm_call_id: str,
function_id: str, function_id: str,
litellm_trace_id: Optional[str] = None,
dynamic_input_callbacks: Optional[ dynamic_input_callbacks: Optional[
List[Union[str, Callable, CustomLogger]] List[Union[str, Callable, CustomLogger]]
] = None, ] = None,
@ -239,6 +240,7 @@ class Logging:
self.start_time = start_time # log the call start time self.start_time = start_time # log the call start time
self.call_type = call_type self.call_type = call_type
self.litellm_call_id = litellm_call_id self.litellm_call_id = litellm_call_id
self.litellm_trace_id = litellm_trace_id
self.function_id = function_id self.function_id = function_id
self.streaming_chunks: List[Any] = [] # for generating complete stream response self.streaming_chunks: List[Any] = [] # for generating complete stream response
self.sync_streaming_chunks: List[Any] = ( self.sync_streaming_chunks: List[Any] = (
@ -275,6 +277,11 @@ class Logging:
self.completion_start_time: Optional[datetime.datetime] = None self.completion_start_time: Optional[datetime.datetime] = None
self._llm_caching_handler: Optional[LLMCachingHandler] = None self._llm_caching_handler: Optional[LLMCachingHandler] = None
self.model_call_details = {
"litellm_trace_id": litellm_trace_id,
"litellm_call_id": litellm_call_id,
}
def process_dynamic_callbacks(self): def process_dynamic_callbacks(self):
""" """
Initializes CustomLogger compatible callbacks in self.dynamic_* callbacks Initializes CustomLogger compatible callbacks in self.dynamic_* callbacks
@ -382,21 +389,23 @@ class Logging:
self.logger_fn = litellm_params.get("logger_fn", None) self.logger_fn = litellm_params.get("logger_fn", None)
verbose_logger.debug(f"self.optional_params: {self.optional_params}") verbose_logger.debug(f"self.optional_params: {self.optional_params}")
self.model_call_details = { self.model_call_details.update(
"model": self.model, {
"messages": self.messages, "model": self.model,
"optional_params": self.optional_params, "messages": self.messages,
"litellm_params": self.litellm_params, "optional_params": self.optional_params,
"start_time": self.start_time, "litellm_params": self.litellm_params,
"stream": self.stream, "start_time": self.start_time,
"user": user, "stream": self.stream,
"call_type": str(self.call_type), "user": user,
"litellm_call_id": self.litellm_call_id, "call_type": str(self.call_type),
"completion_start_time": self.completion_start_time, "litellm_call_id": self.litellm_call_id,
"standard_callback_dynamic_params": self.standard_callback_dynamic_params, "completion_start_time": self.completion_start_time,
**self.optional_params, "standard_callback_dynamic_params": self.standard_callback_dynamic_params,
**additional_params, **self.optional_params,
} **additional_params,
}
)
## check if stream options is set ## - used by CustomStreamWrapper for easy instrumentation ## check if stream options is set ## - used by CustomStreamWrapper for easy instrumentation
if "stream_options" in additional_params: if "stream_options" in additional_params:
@ -2823,6 +2832,7 @@ def get_standard_logging_object_payload(
payload: StandardLoggingPayload = StandardLoggingPayload( payload: StandardLoggingPayload = StandardLoggingPayload(
id=str(id), id=str(id),
trace_id=kwargs.get("litellm_trace_id"), # type: ignore
call_type=call_type or "", call_type=call_type or "",
cache_hit=cache_hit, cache_hit=cache_hit,
status=status, status=status,
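`litellm_trace_id` is now accepted by `Logging.__init__`, seeded into `model_call_details`, and surfaced as `trace_id` on the standard logging payload. A rough sketch of how a caller-supplied trace id could be observed in a success callback; the `standard_logging_object` key and the callback signature are assumptions based on LiteLLM's custom-callback conventions, not shown in this diff:

```
import uuid
import litellm

trace_id = str(uuid.uuid4())

def log_success(kwargs, response_obj, start_time, end_time):
    # the StandardLoggingPayload built above should echo the caller's trace id
    payload = kwargs.get("standard_logging_object") or {}
    print("trace_id:", payload.get("trace_id"))

litellm.success_callback = [log_success]

litellm.completion(
    model="gpt-4",
    messages=[{"role": "user", "content": "hi"}],
    litellm_trace_id=trace_id,  # picked up via kwargs.get("litellm_trace_id")
)
```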

View file

@ -44,7 +44,9 @@ from litellm.types.llms.openai import (
ChatCompletionToolCallFunctionChunk, ChatCompletionToolCallFunctionChunk,
ChatCompletionUsageBlock, ChatCompletionUsageBlock,
) )
from litellm.types.utils import GenericStreamingChunk, PromptTokensDetailsWrapper from litellm.types.utils import GenericStreamingChunk
from litellm.types.utils import Message as LitellmMessage
from litellm.types.utils import PromptTokensDetailsWrapper
from litellm.utils import CustomStreamWrapper, ModelResponse, Usage from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
from ...base import BaseLLM from ...base import BaseLLM
@ -94,6 +96,7 @@ async def make_call(
messages: list, messages: list,
logging_obj, logging_obj,
timeout: Optional[Union[float, httpx.Timeout]], timeout: Optional[Union[float, httpx.Timeout]],
json_mode: bool,
) -> Tuple[Any, httpx.Headers]: ) -> Tuple[Any, httpx.Headers]:
if client is None: if client is None:
client = litellm.module_level_aclient client = litellm.module_level_aclient
@ -119,7 +122,9 @@ async def make_call(
raise AnthropicError(status_code=500, message=str(e)) raise AnthropicError(status_code=500, message=str(e))
completion_stream = ModelResponseIterator( completion_stream = ModelResponseIterator(
streaming_response=response.aiter_lines(), sync_stream=False streaming_response=response.aiter_lines(),
sync_stream=False,
json_mode=json_mode,
) )
# LOGGING # LOGGING
@ -142,6 +147,7 @@ def make_sync_call(
messages: list, messages: list,
logging_obj, logging_obj,
timeout: Optional[Union[float, httpx.Timeout]], timeout: Optional[Union[float, httpx.Timeout]],
json_mode: bool,
) -> Tuple[Any, httpx.Headers]: ) -> Tuple[Any, httpx.Headers]:
if client is None: if client is None:
client = litellm.module_level_client # re-use a module level client client = litellm.module_level_client # re-use a module level client
@ -175,7 +181,7 @@ def make_sync_call(
) )
completion_stream = ModelResponseIterator( completion_stream = ModelResponseIterator(
streaming_response=response.iter_lines(), sync_stream=True streaming_response=response.iter_lines(), sync_stream=True, json_mode=json_mode
) )
# LOGGING # LOGGING
@ -270,11 +276,12 @@ class AnthropicChatCompletion(BaseLLM):
"arguments" "arguments"
) )
if json_mode_content_str is not None: if json_mode_content_str is not None:
args = json.loads(json_mode_content_str) _converted_message = self._convert_tool_response_to_message(
values: Optional[dict] = args.get("values") tool_calls=tool_calls,
if values is not None: )
_message = litellm.Message(content=json.dumps(values)) if _converted_message is not None:
completion_response["stop_reason"] = "stop" completion_response["stop_reason"] = "stop"
_message = _converted_message
model_response.choices[0].message = _message # type: ignore model_response.choices[0].message = _message # type: ignore
model_response._hidden_params["original_response"] = completion_response[ model_response._hidden_params["original_response"] = completion_response[
"content" "content"
@ -318,6 +325,37 @@ class AnthropicChatCompletion(BaseLLM):
model_response._hidden_params = _hidden_params model_response._hidden_params = _hidden_params
return model_response return model_response
@staticmethod
def _convert_tool_response_to_message(
tool_calls: List[ChatCompletionToolCallChunk],
) -> Optional[LitellmMessage]:
"""
In JSON mode, the Anthropic API returns the JSON schema as a tool call; we convert it to a message to follow the OpenAI format
"""
## HANDLE JSON MODE - anthropic returns single function call
json_mode_content_str: Optional[str] = tool_calls[0]["function"].get(
"arguments"
)
try:
if json_mode_content_str is not None:
args = json.loads(json_mode_content_str)
if (
isinstance(args, dict)
and (values := args.get("values")) is not None
):
_message = litellm.Message(content=json.dumps(values))
return _message
else:
# often, the `values` key is not present in the tool response
# relevant issue: https://github.com/BerriAI/litellm/issues/6741
_message = litellm.Message(content=json.dumps(args))
return _message
except json.JSONDecodeError:
# a JSON decode error can occur; return the original tool response string
return litellm.Message(content=json_mode_content_str)
return None
async def acompletion_stream_function( async def acompletion_stream_function(
self, self,
model: str, model: str,
@ -334,6 +372,7 @@ class AnthropicChatCompletion(BaseLLM):
stream, stream,
_is_function_call, _is_function_call,
data: dict, data: dict,
json_mode: bool,
optional_params=None, optional_params=None,
litellm_params=None, litellm_params=None,
logger_fn=None, logger_fn=None,
@ -350,6 +389,7 @@ class AnthropicChatCompletion(BaseLLM):
messages=messages, messages=messages,
logging_obj=logging_obj, logging_obj=logging_obj,
timeout=timeout, timeout=timeout,
json_mode=json_mode,
) )
streamwrapper = CustomStreamWrapper( streamwrapper = CustomStreamWrapper(
completion_stream=completion_stream, completion_stream=completion_stream,
@ -440,8 +480,8 @@ class AnthropicChatCompletion(BaseLLM):
logging_obj, logging_obj,
optional_params: dict, optional_params: dict,
timeout: Union[float, httpx.Timeout], timeout: Union[float, httpx.Timeout],
litellm_params: dict,
acompletion=None, acompletion=None,
litellm_params=None,
logger_fn=None, logger_fn=None,
headers={}, headers={},
client=None, client=None,
@ -464,6 +504,7 @@ class AnthropicChatCompletion(BaseLLM):
model=model, model=model,
messages=messages, messages=messages,
optional_params=optional_params, optional_params=optional_params,
litellm_params=litellm_params,
headers=headers, headers=headers,
_is_function_call=_is_function_call, _is_function_call=_is_function_call,
is_vertex_request=is_vertex_request, is_vertex_request=is_vertex_request,
@ -500,6 +541,7 @@ class AnthropicChatCompletion(BaseLLM):
optional_params=optional_params, optional_params=optional_params,
stream=stream, stream=stream,
_is_function_call=_is_function_call, _is_function_call=_is_function_call,
json_mode=json_mode,
litellm_params=litellm_params, litellm_params=litellm_params,
logger_fn=logger_fn, logger_fn=logger_fn,
headers=headers, headers=headers,
@ -547,6 +589,7 @@ class AnthropicChatCompletion(BaseLLM):
messages=messages, messages=messages,
logging_obj=logging_obj, logging_obj=logging_obj,
timeout=timeout, timeout=timeout,
json_mode=json_mode,
) )
return CustomStreamWrapper( return CustomStreamWrapper(
completion_stream=completion_stream, completion_stream=completion_stream,
@ -605,11 +648,14 @@ class AnthropicChatCompletion(BaseLLM):
class ModelResponseIterator: class ModelResponseIterator:
def __init__(self, streaming_response, sync_stream: bool): def __init__(
self, streaming_response, sync_stream: bool, json_mode: Optional[bool] = False
):
self.streaming_response = streaming_response self.streaming_response = streaming_response
self.response_iterator = self.streaming_response self.response_iterator = self.streaming_response
self.content_blocks: List[ContentBlockDelta] = [] self.content_blocks: List[ContentBlockDelta] = []
self.tool_index = -1 self.tool_index = -1
self.json_mode = json_mode
def check_empty_tool_call_args(self) -> bool: def check_empty_tool_call_args(self) -> bool:
""" """
@ -771,6 +817,8 @@ class ModelResponseIterator:
status_code=500, # it looks like Anthropic API does not return a status code in the chunk error - default to 500 status_code=500, # it looks like Anthropic API does not return a status code in the chunk error - default to 500
) )
text, tool_use = self._handle_json_mode_chunk(text=text, tool_use=tool_use)
returned_chunk = GenericStreamingChunk( returned_chunk = GenericStreamingChunk(
text=text, text=text,
tool_use=tool_use, tool_use=tool_use,
@ -785,6 +833,34 @@ class ModelResponseIterator:
except json.JSONDecodeError: except json.JSONDecodeError:
raise ValueError(f"Failed to decode JSON from chunk: {chunk}") raise ValueError(f"Failed to decode JSON from chunk: {chunk}")
def _handle_json_mode_chunk(
self, text: str, tool_use: Optional[ChatCompletionToolCallChunk]
) -> Tuple[str, Optional[ChatCompletionToolCallChunk]]:
"""
If JSON mode is enabled, convert the tool call to a message.
Anthropic returns the JSON schema as part of the tool call;
OpenAI returns it as part of the content. This handles placing it in the content.
Args:
text: str
tool_use: Optional[ChatCompletionToolCallChunk]
Returns:
Tuple[str, Optional[ChatCompletionToolCallChunk]]
text: The text to use in the content
tool_use: The ChatCompletionToolCallChunk to use in the chunk response
"""
if self.json_mode is True and tool_use is not None:
message = AnthropicChatCompletion._convert_tool_response_to_message(
tool_calls=[tool_use]
)
if message is not None:
text = message.content or ""
tool_use = None
return text, tool_use
# Sync iterator # Sync iterator
def __iter__(self): def __iter__(self):
return self return self
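`_convert_tool_response_to_message` (and, per chunk, `_handle_json_mode_chunk`) moves the JSON schema that Anthropic returns as a tool call into plain message content. A standalone mirror of the conversion rules, for illustration only, not the LiteLLM helper itself:

```
import json

def convert(arguments: str) -> str:
    """Mirror of the JSON-mode conversion rules above (illustration only)."""
    try:
        args = json.loads(arguments)
        if isinstance(args, dict) and (values := args.get("values")) is not None:
            return json.dumps(values)  # schema wrapped under "values"
        return json.dumps(args)        # no "values" key (issue #6741) - return args as-is
    except json.JSONDecodeError:
        return arguments               # undecodable - pass the raw string through

print(convert(json.dumps({"values": {"city": "Paris"}})))  # {"city": "Paris"}
print(convert(json.dumps({"city": "Paris"})))              # {"city": "Paris"}
```

In streaming JSON mode the same string ends up in the chunk's `text` field and `tool_use` is set to `None`.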

View file

@ -91,6 +91,7 @@ class AnthropicConfig:
"extra_headers", "extra_headers",
"parallel_tool_calls", "parallel_tool_calls",
"response_format", "response_format",
"user",
] ]
def get_cache_control_headers(self) -> dict: def get_cache_control_headers(self) -> dict:
@ -246,6 +247,28 @@ class AnthropicConfig:
anthropic_tools.append(new_tool) anthropic_tools.append(new_tool)
return anthropic_tools return anthropic_tools
def _map_stop_sequences(
self, stop: Optional[Union[str, List[str]]]
) -> Optional[List[str]]:
new_stop: Optional[List[str]] = None
if isinstance(stop, str):
if (
stop == "\n"
) and litellm.drop_params is True: # anthropic doesn't allow whitespace characters as stop-sequences
return new_stop
new_stop = [stop]
elif isinstance(stop, list):
new_v = []
for v in stop:
if (
v == "\n"
) and litellm.drop_params is True: # anthropic doesn't allow whitespace characters as stop-sequences
continue
new_v.append(v)
if len(new_v) > 0:
new_stop = new_v
return new_stop
def map_openai_params( def map_openai_params(
self, self,
non_default_params: dict, non_default_params: dict,
@ -271,26 +294,10 @@ class AnthropicConfig:
optional_params["tool_choice"] = _tool_choice optional_params["tool_choice"] = _tool_choice
if param == "stream" and value is True: if param == "stream" and value is True:
optional_params["stream"] = value optional_params["stream"] = value
if param == "stop": if param == "stop" and (isinstance(value, str) or isinstance(value, list)):
if isinstance(value, str): _value = self._map_stop_sequences(value)
if ( if _value is not None:
value == "\n" optional_params["stop_sequences"] = _value
) and litellm.drop_params is True: # anthropic doesn't allow whitespace characters as stop-sequences
continue
value = [value]
elif isinstance(value, list):
new_v = []
for v in value:
if (
v == "\n"
) and litellm.drop_params is True: # anthropic doesn't allow whitespace characters as stop-sequences
continue
new_v.append(v)
if len(new_v) > 0:
value = new_v
else:
continue
optional_params["stop_sequences"] = value
if param == "temperature": if param == "temperature":
optional_params["temperature"] = value optional_params["temperature"] = value
if param == "top_p": if param == "top_p":
@ -314,7 +321,8 @@ class AnthropicConfig:
optional_params["tools"] = [_tool] optional_params["tools"] = [_tool]
optional_params["tool_choice"] = _tool_choice optional_params["tool_choice"] = _tool_choice
optional_params["json_mode"] = True optional_params["json_mode"] = True
if param == "user":
optional_params["metadata"] = {"user_id": value}
## VALIDATE REQUEST ## VALIDATE REQUEST
""" """
Anthropic doesn't support tool calling without `tools=` param specified. Anthropic doesn't support tool calling without `tools=` param specified.
@ -465,6 +473,7 @@ class AnthropicConfig:
model: str, model: str,
messages: List[AllMessageValues], messages: List[AllMessageValues],
optional_params: dict, optional_params: dict,
litellm_params: dict,
headers: dict, headers: dict,
_is_function_call: bool, _is_function_call: bool,
is_vertex_request: bool, is_vertex_request: bool,
@ -502,6 +511,15 @@ class AnthropicConfig:
if "tools" in optional_params: if "tools" in optional_params:
_is_function_call = True _is_function_call = True
## Handle user_id in metadata
_litellm_metadata = litellm_params.get("metadata", None)
if (
_litellm_metadata
and isinstance(_litellm_metadata, dict)
and "user_id" in _litellm_metadata
):
optional_params["metadata"] = {"user_id": _litellm_metadata["user_id"]}
data = { data = {
"messages": anthropic_messages, "messages": anthropic_messages,
**optional_params, **optional_params,
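Net effect of the `AnthropicConfig` changes above: the OpenAI `user` param and a caller-supplied `metadata.user_id` both land in Anthropic's `metadata`, and stop strings are mapped through `_map_stop_sequences`, which drops bare newlines when `litellm.drop_params` is enabled. A shape-only sketch (values are illustrative):

```
openai_params = {
    "user": "user-1234",
    "stop": ["\n", "END"],  # "\n" is dropped when litellm.drop_params is True
}

# after map_openai_params / transform_request:
#   optional_params["metadata"]       == {"user_id": "user-1234"}
#   optional_params["stop_sequences"] == ["END"]
```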

View file

@ -53,9 +53,15 @@ class AmazonStability3Config:
sd3-medium sd3-medium
sd3.5-large sd3.5-large
sd3.5-large-turbo sd3.5-large-turbo
Stability ultra models
stable-image-ultra-v1
""" """
if model and ("sd3" in model or "sd3.5" in model): if model:
return True if "sd3" in model or "sd3.5" in model:
return True
if "stable-image-ultra-v1" in model:
return True
return False return False
@classmethod @classmethod

View file

@ -8,3 +8,4 @@ class httpxSpecialProvider(str, Enum):
GuardrailCallback = "guardrail_callback" GuardrailCallback = "guardrail_callback"
Caching = "caching" Caching = "caching"
Oauth2Check = "oauth2_check" Oauth2Check = "oauth2_check"
SecretManager = "secret_manager"

View file

@ -76,4 +76,4 @@ class JinaAIEmbeddingConfig:
or get_secret_str("JINA_AI_API_KEY") or get_secret_str("JINA_AI_API_KEY")
or get_secret_str("JINA_AI_TOKEN") or get_secret_str("JINA_AI_TOKEN")
) )
return LlmProviders.OPENAI_LIKE.value, api_base, dynamic_api_key return LlmProviders.JINA_AI.value, api_base, dynamic_api_key

View file

@ -0,0 +1,96 @@
"""
Rerank API
LiteLLM supports the rerank API format; no parameter transformation occurs
"""
import uuid
from typing import Any, Dict, List, Optional, Union
import httpx
from pydantic import BaseModel
import litellm
from litellm.llms.base import BaseLLM
from litellm.llms.custom_httpx.http_handler import (
_get_httpx_client,
get_async_httpx_client,
)
from litellm.llms.jina_ai.rerank.transformation import JinaAIRerankConfig
from litellm.types.rerank import RerankRequest, RerankResponse
class JinaAIRerank(BaseLLM):
def rerank(
self,
model: str,
api_key: str,
query: str,
documents: List[Union[str, Dict[str, Any]]],
top_n: Optional[int] = None,
rank_fields: Optional[List[str]] = None,
return_documents: Optional[bool] = True,
max_chunks_per_doc: Optional[int] = None,
_is_async: Optional[bool] = False,
) -> RerankResponse:
client = _get_httpx_client()
request_data = RerankRequest(
model=model,
query=query,
top_n=top_n,
documents=documents,
rank_fields=rank_fields,
return_documents=return_documents,
)
# exclude None values from request_data
request_data_dict = request_data.dict(exclude_none=True)
if _is_async:
return self.async_rerank(request_data_dict, api_key) # type: ignore # Call async method
response = client.post(
"https://api.jina.ai/v1/rerank",
headers={
"accept": "application/json",
"content-type": "application/json",
"authorization": f"Bearer {api_key}",
},
json=request_data_dict,
)
if response.status_code != 200:
raise Exception(response.text)
_json_response = response.json()
return JinaAIRerankConfig()._transform_response(_json_response)
async def async_rerank( # New async method
self,
request_data_dict: Dict[str, Any],
api_key: str,
) -> RerankResponse:
client = get_async_httpx_client(
llm_provider=litellm.LlmProviders.JINA_AI
) # Use async client
response = await client.post(
"https://api.jina.ai/v1/rerank",
headers={
"accept": "application/json",
"content-type": "application/json",
"authorization": f"Bearer {api_key}",
},
json=request_data_dict,
)
if response.status_code != 200:
raise Exception(response.text)
_json_response = response.json()
return JinaAIRerankConfig()._transform_response(_json_response)
pass
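A hedged usage sketch for the new Jina AI rerank handler, assuming the top-level `litellm.rerank` entrypoint routes `jina_ai/` models here; the API key placeholder and document strings are illustrative:

```
import os
import litellm

os.environ["JINA_AI_API_KEY"] = "jina_..."  # placeholder

response = litellm.rerank(
    model="jina_ai/jina-reranker-v2-base-multilingual",
    query="What is the capital of France?",
    documents=["Paris is the capital of France.", "Berlin is the capital of Germany."],
    top_n=1,
)
print(response.results)
```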

View file

@ -0,0 +1,36 @@
"""
Transformation logic from Cohere's /v1/rerank format to Jina AI's `/v1/rerank` format.
Why a separate file? To make it easy to see how the transformation works
Docs - https://jina.ai/reranker
"""
import uuid
from typing import List, Optional
from litellm.types.rerank import (
RerankBilledUnits,
RerankResponse,
RerankResponseMeta,
RerankTokens,
)
class JinaAIRerankConfig:
def _transform_response(self, response: dict) -> RerankResponse:
_billed_units = RerankBilledUnits(**response.get("usage", {}))
_tokens = RerankTokens(**response.get("usage", {}))
rerank_meta = RerankResponseMeta(billed_units=_billed_units, tokens=_tokens)
_results: Optional[List[dict]] = response.get("results")
if _results is None:
raise ValueError(f"No results found in the response={response}")
return RerankResponse(
id=response.get("id") or str(uuid.uuid4()),
results=_results,
meta=rerank_meta,
) # Return response
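What `_transform_response` expects and produces, using a hand-written response body in Jina's rerank shape; the `usage` numbers and id are made up, and the resulting `meta` fields depend on how `RerankBilledUnits` / `RerankTokens` consume the usage block:

```
from litellm.llms.jina_ai.rerank.transformation import JinaAIRerankConfig

raw = {
    "id": "rerank-abc123",
    "usage": {"total_tokens": 15},
    "results": [{"index": 0, "relevance_score": 0.92}],
}

rerank_response = JinaAIRerankConfig()._transform_response(raw)
# rerank_response.results -> the "results" list above, unchanged
# rerank_response.meta    -> billed_units / tokens derived from the "usage" block
```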

View file

@ -185,6 +185,8 @@ class OllamaConfig:
"name": "mistral" "name": "mistral"
}' }'
""" """
if model.startswith("ollama/") or model.startswith("ollama_chat/"):
model = model.split("/", 1)[1]
api_base = get_secret_str("OLLAMA_API_BASE") or "http://localhost:11434" api_base = get_secret_str("OLLAMA_API_BASE") or "http://localhost:11434"
try: try:

View file

@ -15,7 +15,14 @@ from litellm.llms.custom_httpx.http_handler import (
_get_httpx_client, _get_httpx_client,
get_async_httpx_client, get_async_httpx_client,
) )
from litellm.types.rerank import RerankRequest, RerankResponse from litellm.llms.together_ai.rerank.transformation import TogetherAIRerankConfig
from litellm.types.rerank import (
RerankBilledUnits,
RerankRequest,
RerankResponse,
RerankResponseMeta,
RerankTokens,
)
class TogetherAIRerank(BaseLLM): class TogetherAIRerank(BaseLLM):
@ -65,13 +72,7 @@ class TogetherAIRerank(BaseLLM):
_json_response = response.json() _json_response = response.json()
response = RerankResponse( return TogetherAIRerankConfig()._transform_response(_json_response)
id=_json_response.get("id"),
results=_json_response.get("results"),
meta=_json_response.get("meta") or {},
)
return response
async def async_rerank( # New async method async def async_rerank( # New async method
self, self,
@ -97,10 +98,4 @@ class TogetherAIRerank(BaseLLM):
_json_response = response.json() _json_response = response.json()
return RerankResponse( return TogetherAIRerankConfig()._transform_response(_json_response)
id=_json_response.get("id"),
results=_json_response.get("results"),
meta=_json_response.get("meta") or {},
) # Return response
pass

View file

@ -0,0 +1,34 @@
"""
Transformation logic from Cohere's /v1/rerank format to Together AI's `/v1/rerank` format.
Why a separate file? To make it easy to see how the transformation works
"""
import uuid
from typing import List, Optional
from litellm.types.rerank import (
RerankBilledUnits,
RerankResponse,
RerankResponseMeta,
RerankTokens,
)
class TogetherAIRerankConfig:
def _transform_response(self, response: dict) -> RerankResponse:
_billed_units = RerankBilledUnits(**response.get("usage", {}))
_tokens = RerankTokens(**response.get("usage", {}))
rerank_meta = RerankResponseMeta(billed_units=_billed_units, tokens=_tokens)
_results: Optional[List[dict]] = response.get("results")
if _results is None:
raise ValueError(f"No results found in the response={response}")
return RerankResponse(
id=response.get("id") or str(uuid.uuid4()),
results=_results,
meta=rerank_meta,
) # Return response

View file

@ -89,6 +89,9 @@ def _get_vertex_url(
elif mode == "embedding": elif mode == "embedding":
endpoint = "predict" endpoint = "predict"
url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}" url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}"
if model.isdigit():
# https://us-central1-aiplatform.googleapis.com/v1/projects/$PROJECT_ID/locations/us-central1/endpoints/$ENDPOINT_ID:predict
url = f"https://{vertex_location}-aiplatform.googleapis.com/{vertex_api_version}/projects/{vertex_project}/locations/{vertex_location}/endpoints/{model}:{endpoint}"
if not url or not endpoint: if not url or not endpoint:
raise ValueError(f"Unable to get vertex url/endpoint for mode: {mode}") raise ValueError(f"Unable to get vertex url/endpoint for mode: {mode}")
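Concretely, a numeric model id is treated as a deployed endpoint id and routed to the `endpoints/...:predict` URL instead of the publisher model URL. Project, location, and endpoint id below are placeholders:

```
vertex_location = "us-central1"
vertex_project = "my-project"      # placeholder
vertex_api_version = "v1"
endpoint = "predict"

model = "1234567890"               # numeric => fine-tuned / deployed endpoint id
if model.isdigit():
    url = (
        f"https://{vertex_location}-aiplatform.googleapis.com/{vertex_api_version}"
        f"/projects/{vertex_project}/locations/{vertex_location}/endpoints/{model}:{endpoint}"
    )
    print(url)
# https://us-central1-aiplatform.googleapis.com/v1/projects/my-project/locations/us-central1/endpoints/1234567890:predict
```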

View file

@ -0,0 +1,25 @@
"""
Vertex AI Image Generation Cost Calculator
"""
from typing import Optional
import litellm
from litellm.types.utils import ImageResponse
def cost_calculator(
model: str,
image_response: ImageResponse,
) -> float:
"""
Vertex AI Image Generation Cost Calculator
"""
_model_info = litellm.get_model_info(
model=model,
custom_llm_provider="vertex_ai",
)
output_cost_per_image: float = _model_info.get("output_cost_per_image") or 0.0
num_images: int = len(image_response.data)
return output_cost_per_image * num_images
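With the pricing key renamed to `output_cost_per_image` (see the model_prices changes below), the calculator is just price-per-image times image count, e.g. for `vertex_ai/imagen-3.0-generate-001` at $0.04 per image:

```
output_cost_per_image = 0.04  # vertex_ai/imagen-3.0-generate-001, from model_prices
num_images = 3                # len(image_response.data)
print(f"${output_cost_per_image * num_images:.2f}")  # $0.12
```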

View file

@ -96,7 +96,7 @@ class VertexEmbedding(VertexBase):
headers = self.set_headers(auth_header=auth_header, extra_headers=extra_headers) headers = self.set_headers(auth_header=auth_header, extra_headers=extra_headers)
vertex_request: VertexEmbeddingRequest = ( vertex_request: VertexEmbeddingRequest = (
litellm.vertexAITextEmbeddingConfig.transform_openai_request_to_vertex_embedding_request( litellm.vertexAITextEmbeddingConfig.transform_openai_request_to_vertex_embedding_request(
input=input, optional_params=optional_params input=input, optional_params=optional_params, model=model
) )
) )
@ -188,7 +188,7 @@ class VertexEmbedding(VertexBase):
headers = self.set_headers(auth_header=auth_header, extra_headers=extra_headers) headers = self.set_headers(auth_header=auth_header, extra_headers=extra_headers)
vertex_request: VertexEmbeddingRequest = ( vertex_request: VertexEmbeddingRequest = (
litellm.vertexAITextEmbeddingConfig.transform_openai_request_to_vertex_embedding_request( litellm.vertexAITextEmbeddingConfig.transform_openai_request_to_vertex_embedding_request(
input=input, optional_params=optional_params input=input, optional_params=optional_params, model=model
) )
) )

View file

@ -101,11 +101,16 @@ class VertexAITextEmbeddingConfig(BaseModel):
return optional_params return optional_params
def transform_openai_request_to_vertex_embedding_request( def transform_openai_request_to_vertex_embedding_request(
self, input: Union[list, str], optional_params: dict self, input: Union[list, str], optional_params: dict, model: str
) -> VertexEmbeddingRequest: ) -> VertexEmbeddingRequest:
""" """
Transforms an openai request to a vertex embedding request. Transforms an openai request to a vertex embedding request.
""" """
if model.isdigit():
return self._transform_openai_request_to_fine_tuned_embedding_request(
input, optional_params, model
)
vertex_request: VertexEmbeddingRequest = VertexEmbeddingRequest() vertex_request: VertexEmbeddingRequest = VertexEmbeddingRequest()
vertex_text_embedding_input_list: List[TextEmbeddingInput] = [] vertex_text_embedding_input_list: List[TextEmbeddingInput] = []
task_type: Optional[TaskType] = optional_params.get("task_type") task_type: Optional[TaskType] = optional_params.get("task_type")
@ -125,6 +130,47 @@ class VertexAITextEmbeddingConfig(BaseModel):
return vertex_request return vertex_request
def _transform_openai_request_to_fine_tuned_embedding_request(
self, input: Union[list, str], optional_params: dict, model: str
) -> VertexEmbeddingRequest:
"""
Transforms an openai request to a vertex fine-tuned embedding request.
Vertex Doc: https://console.cloud.google.com/vertex-ai/model-garden?hl=en&project=adroit-crow-413218&pageState=(%22galleryStateKey%22:(%22f%22:(%22g%22:%5B%5D,%22o%22:%5B%5D),%22s%22:%22%22))
Sample Request:
```json
{
"instances" : [
{
"inputs": "How would the Future of AI in 10 Years look?",
"parameters": {
"max_new_tokens": 128,
"temperature": 1.0,
"top_p": 0.9,
"top_k": 10
}
}
]
}
```
"""
vertex_request: VertexEmbeddingRequest = VertexEmbeddingRequest()
vertex_text_embedding_input_list: List[TextEmbeddingFineTunedInput] = []
if isinstance(input, str):
input = [input] # Convert single string to list for uniform processing
for text in input:
embedding_input = TextEmbeddingFineTunedInput(inputs=text)
vertex_text_embedding_input_list.append(embedding_input)
vertex_request["instances"] = vertex_text_embedding_input_list
vertex_request["parameters"] = TextEmbeddingFineTunedParameters(
**optional_params
)
return vertex_request
def create_embedding_input( def create_embedding_input(
self, self,
content: str, content: str,
@ -157,6 +203,11 @@ class VertexAITextEmbeddingConfig(BaseModel):
""" """
Transforms a vertex embedding response to an openai response. Transforms a vertex embedding response to an openai response.
""" """
if model.isdigit():
return self._transform_vertex_response_to_openai_for_fine_tuned_models(
response, model, model_response
)
_predictions = response["predictions"] _predictions = response["predictions"]
embedding_response = [] embedding_response = []
@ -181,3 +232,35 @@ class VertexAITextEmbeddingConfig(BaseModel):
) )
setattr(model_response, "usage", usage) setattr(model_response, "usage", usage)
return model_response return model_response
def _transform_vertex_response_to_openai_for_fine_tuned_models(
self, response: dict, model: str, model_response: litellm.EmbeddingResponse
) -> litellm.EmbeddingResponse:
"""
Transforms a vertex fine-tuned model embedding response to an openai response format.
"""
_predictions = response["predictions"]
embedding_response = []
# For fine-tuned models, we don't get token counts in the response
input_tokens = 0
for idx, embedding_values in enumerate(_predictions):
embedding_response.append(
{
"object": "embedding",
"index": idx,
"embedding": embedding_values[
0
], # The embedding values are nested one level deeper
}
)
model_response.object = "list"
model_response.data = embedding_response
model_response.model = model
usage = Usage(
prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
)
setattr(model_response, "usage", usage)
return model_response
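Shape sketch for the fine-tuned path: Vertex returns each embedding nested one level deep under `predictions`, and usage is reported as zero because the endpoint response carries no token counts. Numbers below are made up:

```
response = {"predictions": [[[0.01, -0.02, 0.03]], [[0.04, 0.05, -0.06]]]}

data = [
    {"object": "embedding", "index": idx, "embedding": values[0]}
    for idx, values in enumerate(response["predictions"])
]
# data[0]["embedding"] -> [0.01, -0.02, 0.03]
# usage -> prompt_tokens=0, completion_tokens=0, total_tokens=0
```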

View file

@ -23,14 +23,27 @@ class TextEmbeddingInput(TypedDict, total=False):
title: Optional[str] title: Optional[str]
# Fine-tuned models require a different input format
# Ref: https://console.cloud.google.com/vertex-ai/model-garden?hl=en&project=adroit-crow-413218&pageState=(%22galleryStateKey%22:(%22f%22:(%22g%22:%5B%5D,%22o%22:%5B%5D),%22s%22:%22%22))
class TextEmbeddingFineTunedInput(TypedDict, total=False):
inputs: str
class TextEmbeddingFineTunedParameters(TypedDict, total=False):
max_new_tokens: Optional[int]
temperature: Optional[float]
top_p: Optional[float]
top_k: Optional[int]
class EmbeddingParameters(TypedDict, total=False): class EmbeddingParameters(TypedDict, total=False):
auto_truncate: Optional[bool] auto_truncate: Optional[bool]
output_dimensionality: Optional[int] output_dimensionality: Optional[int]
class VertexEmbeddingRequest(TypedDict, total=False): class VertexEmbeddingRequest(TypedDict, total=False):
instances: List[TextEmbeddingInput] instances: Union[List[TextEmbeddingInput], List[TextEmbeddingFineTunedInput]]
parameters: Optional[EmbeddingParameters] parameters: Optional[Union[EmbeddingParameters, TextEmbeddingFineTunedParameters]]
# Example usage: # Example usage:

View file

@ -1066,6 +1066,7 @@ def completion( # type: ignore # noqa: PLR0915
azure_ad_token_provider=kwargs.get("azure_ad_token_provider"), azure_ad_token_provider=kwargs.get("azure_ad_token_provider"),
user_continue_message=kwargs.get("user_continue_message"), user_continue_message=kwargs.get("user_continue_message"),
base_model=base_model, base_model=base_model,
litellm_trace_id=kwargs.get("litellm_trace_id"),
) )
logging.update_environment_variables( logging.update_environment_variables(
model=model, model=model,
@ -3455,7 +3456,7 @@ def embedding( # noqa: PLR0915
client=client, client=client,
aembedding=aembedding, aembedding=aembedding,
) )
elif custom_llm_provider == "openai_like": elif custom_llm_provider == "openai_like" or custom_llm_provider == "jina_ai":
api_base = ( api_base = (
api_base or litellm.api_base or get_secret_str("OPENAI_LIKE_API_BASE") api_base or litellm.api_base or get_secret_str("OPENAI_LIKE_API_BASE")
) )
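Jina AI embeddings now reuse the `openai_like` embedding path. A hedged usage sketch; the embedding model name is illustrative and `JINA_AI_API_KEY` is resolved via `JinaAIEmbeddingConfig` shown earlier:

```
import litellm

embedding_response = litellm.embedding(
    model="jina_ai/jina-embeddings-v3",  # illustrative model name
    input=["hello world"],
)
print(len(embedding_response.data[0]["embedding"]))
```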

View file

@ -2986,19 +2986,19 @@
"supports_function_calling": true "supports_function_calling": true
}, },
"vertex_ai/imagegeneration@006": { "vertex_ai/imagegeneration@006": {
"cost_per_image": 0.020, "output_cost_per_image": 0.020,
"litellm_provider": "vertex_ai-image-models", "litellm_provider": "vertex_ai-image-models",
"mode": "image_generation", "mode": "image_generation",
"source": "https://cloud.google.com/vertex-ai/generative-ai/pricing" "source": "https://cloud.google.com/vertex-ai/generative-ai/pricing"
}, },
"vertex_ai/imagen-3.0-generate-001": { "vertex_ai/imagen-3.0-generate-001": {
"cost_per_image": 0.04, "output_cost_per_image": 0.04,
"litellm_provider": "vertex_ai-image-models", "litellm_provider": "vertex_ai-image-models",
"mode": "image_generation", "mode": "image_generation",
"source": "https://cloud.google.com/vertex-ai/generative-ai/pricing" "source": "https://cloud.google.com/vertex-ai/generative-ai/pricing"
}, },
"vertex_ai/imagen-3.0-fast-generate-001": { "vertex_ai/imagen-3.0-fast-generate-001": {
"cost_per_image": 0.02, "output_cost_per_image": 0.02,
"litellm_provider": "vertex_ai-image-models", "litellm_provider": "vertex_ai-image-models",
"mode": "image_generation", "mode": "image_generation",
"source": "https://cloud.google.com/vertex-ai/generative-ai/pricing" "source": "https://cloud.google.com/vertex-ai/generative-ai/pricing"
@ -5620,6 +5620,13 @@
"litellm_provider": "bedrock", "litellm_provider": "bedrock",
"mode": "image_generation" "mode": "image_generation"
}, },
"stability.stable-image-ultra-v1:0": {
"max_tokens": 77,
"max_input_tokens": 77,
"output_cost_per_image": 0.14,
"litellm_provider": "bedrock",
"mode": "image_generation"
},
"sagemaker/meta-textgeneration-llama-2-7b": { "sagemaker/meta-textgeneration-llama-2-7b": {
"max_tokens": 4096, "max_tokens": 4096,
"max_input_tokens": 4096, "max_input_tokens": 4096,

View file

@ -1,122 +1,15 @@
model_list: model_list:
- model_name: "*" # GPT-4 Turbo Models
litellm_params:
model: claude-3-5-sonnet-20240620
api_key: os.environ/ANTHROPIC_API_KEY
- model_name: claude-3-5-sonnet-aihubmix
litellm_params:
model: openai/claude-3-5-sonnet-20240620
input_cost_per_token: 0.000003 # 3$/M
output_cost_per_token: 0.000015 # 15$/M
api_base: "https://exampleopenaiendpoint-production.up.railway.app"
api_key: my-fake-key
- model_name: fake-openai-endpoint-2
litellm_params:
model: openai/my-fake-model
api_key: my-fake-key
api_base: https://exampleopenaiendpoint-production.up.railway.app/
stream_timeout: 0.001
timeout: 1
rpm: 1
- model_name: fake-openai-endpoint
litellm_params:
model: openai/my-fake-model
api_key: my-fake-key
api_base: https://exampleopenaiendpoint-production.up.railway.app/
## bedrock chat completions
- model_name: "*anthropic.claude*"
litellm_params:
model: bedrock/*anthropic.claude*
aws_access_key_id: os.environ/BEDROCK_AWS_ACCESS_KEY_ID
aws_secret_access_key: os.environ/BEDROCK_AWS_SECRET_ACCESS_KEY
aws_region_name: os.environ/AWS_REGION_NAME
guardrailConfig:
"guardrailIdentifier": "h4dsqwhp6j66"
"guardrailVersion": "2"
"trace": "enabled"
## bedrock embeddings
- model_name: "*amazon.titan-embed-*"
litellm_params:
model: bedrock/amazon.titan-embed-*
aws_access_key_id: os.environ/BEDROCK_AWS_ACCESS_KEY_ID
aws_secret_access_key: os.environ/BEDROCK_AWS_SECRET_ACCESS_KEY
aws_region_name: os.environ/AWS_REGION_NAME
- model_name: "*cohere.embed-*"
litellm_params:
model: bedrock/cohere.embed-*
aws_access_key_id: os.environ/BEDROCK_AWS_ACCESS_KEY_ID
aws_secret_access_key: os.environ/BEDROCK_AWS_SECRET_ACCESS_KEY
aws_region_name: os.environ/AWS_REGION_NAME
- model_name: "bedrock/*"
litellm_params:
model: bedrock/*
aws_access_key_id: os.environ/BEDROCK_AWS_ACCESS_KEY_ID
aws_secret_access_key: os.environ/BEDROCK_AWS_SECRET_ACCESS_KEY
aws_region_name: os.environ/AWS_REGION_NAME
- model_name: gpt-4 - model_name: gpt-4
litellm_params: litellm_params:
model: azure/chatgpt-v-2 model: gpt-4
api_base: https://openai-gpt-4-test-v-1.openai.azure.com/ - model_name: rerank-model
api_version: "2023-05-15" litellm_params:
api_key: os.environ/AZURE_API_KEY # The `os.environ/` prefix tells litellm to read this from the env. See https://docs.litellm.ai/docs/simple_proxy#load-api-keys-from-vault model: jina_ai/jina-reranker-v2-base-multilingual
rpm: 480
timeout: 300
stream_timeout: 60
litellm_settings:
fallbacks: [{ "claude-3-5-sonnet-20240620": ["claude-3-5-sonnet-aihubmix"] }]
# callbacks: ["otel", "prometheus"]
default_redis_batch_cache_expiry: 10
# default_team_settings:
# - team_id: "dbe2f686-a686-4896-864a-4c3924458709"
# success_callback: ["langfuse"]
# langfuse_public_key: os.environ/LANGFUSE_PUB_KEY_1 # Project 1
# langfuse_secret: os.environ/LANGFUSE_PRIVATE_KEY_1 # Project 1
# litellm_settings:
# cache: True
# cache_params:
# type: redis
# # disable caching on the actual API call
# supported_call_types: []
# # see https://docs.litellm.ai/docs/proxy/prod#3-use-redis-porthost-password-not-redis_url
# host: os.environ/REDIS_HOST
# port: os.environ/REDIS_PORT
# password: os.environ/REDIS_PASSWORD
# # see https://docs.litellm.ai/docs/proxy/caching#turn-on-batch_redis_requests
# # see https://docs.litellm.ai/docs/proxy/prometheus
# callbacks: ['otel']
# # router_settings: router_settings:
# # routing_strategy: latency-based-routing model_group_alias:
# # routing_strategy_args: "gpt-4-turbo": # Aliased model name
# # # only assign 40% of traffic to the fastest deployment to avoid overloading it model: "gpt-4" # Actual model name in 'model_list'
# # lowest_latency_buffer: 0.4 hidden: true
# # # consider last five minutes of calls for latency calculation
# # ttl: 300
# # redis_host: os.environ/REDIS_HOST
# # redis_port: os.environ/REDIS_PORT
# # redis_password: os.environ/REDIS_PASSWORD
# # # see https://docs.litellm.ai/docs/proxy/prod#1-use-this-configyaml
# # general_settings:
# # master_key: os.environ/LITELLM_MASTER_KEY
# # database_url: os.environ/DATABASE_URL
# # disable_master_key_return: true
# # # alerting: ['slack', 'email']
# # alerting: ['email']
# # # Batch write spend updates every 60s
# # proxy_batch_write_at: 60
# # # see https://docs.litellm.ai/docs/proxy/caching#advanced---user-api-key-cache-ttl
# # # our api keys rarely change
# # user_api_key_cache_ttl: 3600
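The replacement config is intentionally small: two model groups plus a `model_group_alias` entry whose `hidden: true` flag is meant to keep the `gpt-4-turbo` alias out of model listings while still routing requests to `gpt-4`. A hedged client-side sketch against a proxy running this config (address and key are placeholders):

```
import openai

client = openai.OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")

# the hidden alias is still accepted and resolves to the `gpt-4` model group
response = client.chat.completions.create(
    model="gpt-4-turbo",
    messages=[{"role": "user", "content": "ping"}],
)
print(response.choices[0].message.content)
```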

View file

@ -1128,7 +1128,16 @@ class KeyManagementSystem(enum.Enum):
class KeyManagementSettings(LiteLLMBase): class KeyManagementSettings(LiteLLMBase):
hosted_keys: List hosted_keys: Optional[List] = None
store_virtual_keys: Optional[bool] = False
"""
If True, virtual keys created by litellm will be stored in the secret manager
"""
access_mode: Literal["read_only", "write_only", "read_and_write"] = "read_only"
"""
Access mode for the secret manager; when set to "write_only", the secret manager is used only for writing secrets, never for reads
"""
class TeamDefaultSettings(LiteLLMBase): class TeamDefaultSettings(LiteLLMBase):

View file

@ -8,6 +8,7 @@ Run checks for:
2. If user is in budget 2. If user is in budget
3. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget 3. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget
""" """
import time import time
import traceback import traceback
from datetime import datetime from datetime import datetime

View file

@ -0,0 +1,267 @@
import asyncio
import json
import uuid
from datetime import datetime, timezone
from typing import Any, List, Optional
from fastapi import status
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import (
GenerateKeyRequest,
KeyManagementSystem,
KeyRequest,
LiteLLM_AuditLogs,
LiteLLM_VerificationToken,
LitellmTableNames,
ProxyErrorTypes,
ProxyException,
UpdateKeyRequest,
UserAPIKeyAuth,
WebhookEvent,
)
class KeyManagementEventHooks:
@staticmethod
async def async_key_generated_hook(
data: GenerateKeyRequest,
response: dict,
user_api_key_dict: UserAPIKeyAuth,
litellm_changed_by: Optional[str] = None,
):
"""
Hook that runs after a successful /key/generate request
Handles the following:
- Sending Email with Key Details
- Storing Audit Logs for key generation
- Storing the generated key in the secret manager (when store_virtual_keys is enabled)
"""
from litellm.proxy.management_helpers.audit_logs import (
create_audit_log_for_update,
)
from litellm.proxy.proxy_server import (
general_settings,
litellm_proxy_admin_name,
proxy_logging_obj,
)
if data.send_invite_email is True:
await KeyManagementEventHooks._send_key_created_email(response)
# Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True
if litellm.store_audit_logs is True:
_updated_values = json.dumps(response, default=str)
asyncio.create_task(
create_audit_log_for_update(
request_data=LiteLLM_AuditLogs(
id=str(uuid.uuid4()),
updated_at=datetime.now(timezone.utc),
changed_by=litellm_changed_by
or user_api_key_dict.user_id
or litellm_proxy_admin_name,
changed_by_api_key=user_api_key_dict.api_key,
table_name=LitellmTableNames.KEY_TABLE_NAME,
object_id=response.get("token_id", ""),
action="created",
updated_values=_updated_values,
before_value=None,
)
)
)
# store the generated key in the secret manager
await KeyManagementEventHooks._store_virtual_key_in_secret_manager(
secret_name=data.key_alias or f"virtual-key-{uuid.uuid4()}",
secret_token=response.get("token", ""),
)
@staticmethod
async def async_key_updated_hook(
data: UpdateKeyRequest,
existing_key_row: Any,
response: Any,
user_api_key_dict: UserAPIKeyAuth,
litellm_changed_by: Optional[str] = None,
):
"""
Post /key/update processing hook
Handles the following:
- Storing Audit Logs for key update
"""
from litellm.proxy.management_helpers.audit_logs import (
create_audit_log_for_update,
)
from litellm.proxy.proxy_server import litellm_proxy_admin_name
# Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True
if litellm.store_audit_logs is True:
_updated_values = json.dumps(data.json(exclude_none=True), default=str)
_before_value = existing_key_row.json(exclude_none=True)
_before_value = json.dumps(_before_value, default=str)
asyncio.create_task(
create_audit_log_for_update(
request_data=LiteLLM_AuditLogs(
id=str(uuid.uuid4()),
updated_at=datetime.now(timezone.utc),
changed_by=litellm_changed_by
or user_api_key_dict.user_id
or litellm_proxy_admin_name,
changed_by_api_key=user_api_key_dict.api_key,
table_name=LitellmTableNames.KEY_TABLE_NAME,
object_id=data.key,
action="updated",
updated_values=_updated_values,
before_value=_before_value,
)
)
)
pass
@staticmethod
async def async_key_deleted_hook(
data: KeyRequest,
keys_being_deleted: List[LiteLLM_VerificationToken],
response: dict,
user_api_key_dict: UserAPIKeyAuth,
litellm_changed_by: Optional[str] = None,
):
"""
Post /key/delete processing hook
Handles the following:
- Storing Audit Logs for key deletion
"""
from litellm.proxy.management_helpers.audit_logs import (
create_audit_log_for_update,
)
from litellm.proxy.proxy_server import litellm_proxy_admin_name, prisma_client
# Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True
# we do this after the first for loop, since first for loop is for validation. we only want this inserted after validation passes
if litellm.store_audit_logs is True:
# make an audit log for each team deleted
for key in data.keys:
key_row = await prisma_client.get_data( # type: ignore
token=key, table_name="key", query_type="find_unique"
)
if key_row is None:
raise ProxyException(
message=f"Key {key} not found",
type=ProxyErrorTypes.bad_request_error,
param="key",
code=status.HTTP_404_NOT_FOUND,
)
key_row = key_row.json(exclude_none=True)
_key_row = json.dumps(key_row, default=str)
asyncio.create_task(
create_audit_log_for_update(
request_data=LiteLLM_AuditLogs(
id=str(uuid.uuid4()),
updated_at=datetime.now(timezone.utc),
changed_by=litellm_changed_by
or user_api_key_dict.user_id
or litellm_proxy_admin_name,
changed_by_api_key=user_api_key_dict.api_key,
table_name=LitellmTableNames.KEY_TABLE_NAME,
object_id=key,
action="deleted",
updated_values="{}",
before_value=_key_row,
)
)
)
# delete the keys from the secret manager
await KeyManagementEventHooks._delete_virtual_keys_from_secret_manager(
keys_being_deleted=keys_being_deleted
)
pass
@staticmethod
async def _store_virtual_key_in_secret_manager(secret_name: str, secret_token: str):
"""
Store a virtual key in the secret manager
Args:
secret_name: Name of the virtual key
secret_token: Value of the virtual key (example: sk-1234)
"""
if litellm._key_management_settings is not None:
if litellm._key_management_settings.store_virtual_keys is True:
from litellm.secret_managers.aws_secret_manager_v2 import (
AWSSecretsManagerV2,
)
# store the key in the secret manager
if (
litellm._key_management_system
== KeyManagementSystem.AWS_SECRET_MANAGER
and isinstance(litellm.secret_manager_client, AWSSecretsManagerV2)
):
await litellm.secret_manager_client.async_write_secret(
secret_name=secret_name,
secret_value=secret_token,
)
@staticmethod
async def _delete_virtual_keys_from_secret_manager(
keys_being_deleted: List[LiteLLM_VerificationToken],
):
"""
Deletes virtual keys from the secret manager
Args:
keys_being_deleted: List of keys being deleted, this is passed down from the /key/delete operation
"""
if litellm._key_management_settings is not None:
if litellm._key_management_settings.store_virtual_keys is True:
from litellm.secret_managers.aws_secret_manager_v2 import (
AWSSecretsManagerV2,
)
if isinstance(litellm.secret_manager_client, AWSSecretsManagerV2):
for key in keys_being_deleted:
if key.key_alias is not None:
await litellm.secret_manager_client.async_delete_secret(
secret_name=key.key_alias
)
else:
verbose_proxy_logger.warning(
f"KeyManagementEventHooks._delete_virtual_key_from_secret_manager: Key alias not found for key {key.token}. Skipping deletion from secret manager."
)
@staticmethod
async def _send_key_created_email(response: dict):
from litellm.proxy.proxy_server import general_settings, proxy_logging_obj
if "email" not in general_settings.get("alerting", []):
raise ValueError(
"Email alerting not setup on config.yaml. Please set `alerting=['email']. \nDocs: https://docs.litellm.ai/docs/proxy/email`"
)
event = WebhookEvent(
event="key_created",
event_group="key",
event_message="API Key Created",
token=response.get("token", ""),
spend=response.get("spend", 0.0),
max_budget=response.get("max_budget", 0.0),
user_id=response.get("user_id", None),
team_id=response.get("team_id", "Default Team"),
key_alias=response.get("key_alias", None),
)
# If user configured email alerting - send an Email letting their end-user know the key was created
asyncio.create_task(
proxy_logging_obj.slack_alerting_instance.send_key_created_or_user_invited_email(
webhook_event=event,
)
)

View file

@ -274,6 +274,51 @@ class LiteLLMProxyRequestSetup:
) )
return user_api_key_logged_metadata return user_api_key_logged_metadata
@staticmethod
def add_key_level_controls(
key_metadata: dict, data: dict, _metadata_variable_name: str
):
data = data.copy()
if "cache" in key_metadata:
data["cache"] = {}
if isinstance(key_metadata["cache"], dict):
for k, v in key_metadata["cache"].items():
if k in SupportedCacheControls:
data["cache"][k] = v
## KEY-LEVEL SPEND LOGS / TAGS
if "tags" in key_metadata and key_metadata["tags"] is not None:
if "tags" in data[_metadata_variable_name] and isinstance(
data[_metadata_variable_name]["tags"], list
):
data[_metadata_variable_name]["tags"].extend(key_metadata["tags"])
else:
data[_metadata_variable_name]["tags"] = key_metadata["tags"]
if "spend_logs_metadata" in key_metadata and isinstance(
key_metadata["spend_logs_metadata"], dict
):
if "spend_logs_metadata" in data[_metadata_variable_name] and isinstance(
data[_metadata_variable_name]["spend_logs_metadata"], dict
):
for key, value in key_metadata["spend_logs_metadata"].items():
if (
key not in data[_metadata_variable_name]["spend_logs_metadata"]
): # don't override k-v pair sent by request (user request)
data[_metadata_variable_name]["spend_logs_metadata"][
key
] = value
else:
data[_metadata_variable_name]["spend_logs_metadata"] = key_metadata[
"spend_logs_metadata"
]
## KEY-LEVEL DISABLE FALLBACKS
if "disable_fallbacks" in key_metadata and isinstance(
key_metadata["disable_fallbacks"], bool
):
data["disable_fallbacks"] = key_metadata["disable_fallbacks"]
return data
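A self-contained sketch of what the helper defined just above does to a request body (all values invented for illustration):

```
key_metadata = {
    "tags": ["team-a"],
    "spend_logs_metadata": {"cost_center": "research"},
    "disable_fallbacks": True,
}
data = {"model": "gpt-4", "metadata": {"tags": ["from-request"]}}

data = LiteLLMProxyRequestSetup.add_key_level_controls(
    key_metadata=key_metadata,
    data=data,
    _metadata_variable_name="metadata",
)
# data["metadata"]["tags"]                -> ["from-request", "team-a"]
# data["metadata"]["spend_logs_metadata"] -> {"cost_center": "research"}
# data["disable_fallbacks"]               -> True
```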
async def add_litellm_data_to_request( # noqa: PLR0915 async def add_litellm_data_to_request( # noqa: PLR0915
data: dict, data: dict,
@ -389,37 +434,11 @@ async def add_litellm_data_to_request( # noqa: PLR0915
### KEY-LEVEL Controls ### KEY-LEVEL Controls
key_metadata = user_api_key_dict.metadata key_metadata = user_api_key_dict.metadata
if "cache" in key_metadata: data = LiteLLMProxyRequestSetup.add_key_level_controls(
data["cache"] = {} key_metadata=key_metadata,
if isinstance(key_metadata["cache"], dict): data=data,
for k, v in key_metadata["cache"].items(): _metadata_variable_name=_metadata_variable_name,
if k in SupportedCacheControls: )
data["cache"][k] = v
## KEY-LEVEL SPEND LOGS / TAGS
if "tags" in key_metadata and key_metadata["tags"] is not None:
if "tags" in data[_metadata_variable_name] and isinstance(
data[_metadata_variable_name]["tags"], list
):
data[_metadata_variable_name]["tags"].extend(key_metadata["tags"])
else:
data[_metadata_variable_name]["tags"] = key_metadata["tags"]
if "spend_logs_metadata" in key_metadata and isinstance(
key_metadata["spend_logs_metadata"], dict
):
if "spend_logs_metadata" in data[_metadata_variable_name] and isinstance(
data[_metadata_variable_name]["spend_logs_metadata"], dict
):
for key, value in key_metadata["spend_logs_metadata"].items():
if (
key not in data[_metadata_variable_name]["spend_logs_metadata"]
): # don't override k-v pair sent by request (user request)
data[_metadata_variable_name]["spend_logs_metadata"][key] = value
else:
data[_metadata_variable_name]["spend_logs_metadata"] = key_metadata[
"spend_logs_metadata"
]
## TEAM-LEVEL SPEND LOGS/TAGS ## TEAM-LEVEL SPEND LOGS/TAGS
team_metadata = user_api_key_dict.team_metadata or {} team_metadata = user_api_key_dict.team_metadata or {}
if "tags" in team_metadata and team_metadata["tags"] is not None: if "tags" in team_metadata and team_metadata["tags"] is not None:

View file

@ -17,7 +17,7 @@ import secrets
import traceback import traceback
import uuid import uuid
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from typing import List, Optional from typing import List, Optional, Tuple
import fastapi import fastapi
from fastapi import APIRouter, Depends, Header, HTTPException, Query, Request, status from fastapi import APIRouter, Depends, Header, HTTPException, Query, Request, status
@ -31,6 +31,7 @@ from litellm.proxy.auth.auth_checks import (
get_key_object, get_key_object,
) )
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.hooks.key_management_event_hooks import KeyManagementEventHooks
from litellm.proxy.management_helpers.utils import management_endpoint_wrapper from litellm.proxy.management_helpers.utils import management_endpoint_wrapper
from litellm.proxy.utils import _duration_in_seconds, _hash_token_if_needed from litellm.proxy.utils import _duration_in_seconds, _hash_token_if_needed
from litellm.secret_managers.main import get_secret from litellm.secret_managers.main import get_secret
@ -234,50 +235,14 @@ async def generate_key_fn( # noqa: PLR0915
data.soft_budget data.soft_budget
) # include the user-input soft budget in the response ) # include the user-input soft budget in the response
if data.send_invite_email is True: asyncio.create_task(
if "email" not in general_settings.get("alerting", []): KeyManagementEventHooks.async_key_generated_hook(
raise ValueError( data=data,
"Email alerting not setup on config.yaml. Please set `alerting=['email']. \nDocs: https://docs.litellm.ai/docs/proxy/email`" response=response,
) user_api_key_dict=user_api_key_dict,
event = WebhookEvent( litellm_changed_by=litellm_changed_by,
event="key_created",
event_group="key",
event_message="API Key Created",
token=response.get("token", ""),
spend=response.get("spend", 0.0),
max_budget=response.get("max_budget", 0.0),
user_id=response.get("user_id", None),
team_id=response.get("team_id", "Default Team"),
key_alias=response.get("key_alias", None),
)
# If user configured email alerting - send an Email letting their end-user know the key was created
asyncio.create_task(
proxy_logging_obj.slack_alerting_instance.send_key_created_or_user_invited_email(
webhook_event=event,
)
)
# Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True
if litellm.store_audit_logs is True:
_updated_values = json.dumps(response, default=str)
asyncio.create_task(
create_audit_log_for_update(
request_data=LiteLLM_AuditLogs(
id=str(uuid.uuid4()),
updated_at=datetime.now(timezone.utc),
changed_by=litellm_changed_by
or user_api_key_dict.user_id
or litellm_proxy_admin_name,
changed_by_api_key=user_api_key_dict.api_key,
table_name=LitellmTableNames.KEY_TABLE_NAME,
object_id=response.get("token_id", ""),
action="created",
updated_values=_updated_values,
before_value=None,
)
)
) )
)
return GenerateKeyResponse(**response) return GenerateKeyResponse(**response)
except Exception as e: except Exception as e:
@ -407,30 +372,15 @@ async def update_key_fn(
proxy_logging_obj=proxy_logging_obj, proxy_logging_obj=proxy_logging_obj,
) )
# Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True asyncio.create_task(
if litellm.store_audit_logs is True: KeyManagementEventHooks.async_key_updated_hook(
_updated_values = json.dumps(data_json, default=str) data=data,
existing_key_row=existing_key_row,
_before_value = existing_key_row.json(exclude_none=True) response=response,
_before_value = json.dumps(_before_value, default=str) user_api_key_dict=user_api_key_dict,
litellm_changed_by=litellm_changed_by,
asyncio.create_task(
create_audit_log_for_update(
request_data=LiteLLM_AuditLogs(
id=str(uuid.uuid4()),
updated_at=datetime.now(timezone.utc),
changed_by=litellm_changed_by
or user_api_key_dict.user_id
or litellm_proxy_admin_name,
changed_by_api_key=user_api_key_dict.api_key,
table_name=LitellmTableNames.KEY_TABLE_NAME,
object_id=data.key,
action="updated",
updated_values=_updated_values,
before_value=_before_value,
)
)
) )
)
if response is None: if response is None:
raise ValueError("Failed to update key got response = None") raise ValueError("Failed to update key got response = None")
@ -496,6 +446,9 @@ async def delete_key_fn(
user_custom_key_generate, user_custom_key_generate,
) )
if prisma_client is None:
raise Exception("Not connected to DB!")
keys = data.keys keys = data.keys
if len(keys) == 0: if len(keys) == 0:
raise ProxyException( raise ProxyException(
@ -516,45 +469,7 @@ async def delete_key_fn(
): ):
user_id = None # unless they're admin user_id = None # unless they're admin
# Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True number_deleted_keys, _keys_being_deleted = await delete_verification_token(
# we do this after the first for loop, since first for loop is for validation. we only want this inserted after validation passes
if litellm.store_audit_logs is True:
# make an audit log for each team deleted
for key in data.keys:
key_row = await prisma_client.get_data( # type: ignore
token=key, table_name="key", query_type="find_unique"
)
if key_row is None:
raise ProxyException(
message=f"Key {key} not found",
type=ProxyErrorTypes.bad_request_error,
param="key",
code=status.HTTP_404_NOT_FOUND,
)
key_row = key_row.json(exclude_none=True)
_key_row = json.dumps(key_row, default=str)
asyncio.create_task(
create_audit_log_for_update(
request_data=LiteLLM_AuditLogs(
id=str(uuid.uuid4()),
updated_at=datetime.now(timezone.utc),
changed_by=litellm_changed_by
or user_api_key_dict.user_id
or litellm_proxy_admin_name,
changed_by_api_key=user_api_key_dict.api_key,
table_name=LitellmTableNames.KEY_TABLE_NAME,
object_id=key,
action="deleted",
updated_values="{}",
before_value=_key_row,
)
)
)
number_deleted_keys = await delete_verification_token(
tokens=keys, user_id=user_id tokens=keys, user_id=user_id
) )
if number_deleted_keys is None: if number_deleted_keys is None:
@ -588,6 +503,16 @@ async def delete_key_fn(
f"/keys/delete - cache after delete: {user_api_key_cache.in_memory_cache.cache_dict}" f"/keys/delete - cache after delete: {user_api_key_cache.in_memory_cache.cache_dict}"
) )
asyncio.create_task(
KeyManagementEventHooks.async_key_deleted_hook(
data=data,
keys_being_deleted=_keys_being_deleted,
user_api_key_dict=user_api_key_dict,
litellm_changed_by=litellm_changed_by,
response=number_deleted_keys,
)
)
return {"deleted_keys": keys} return {"deleted_keys": keys}
except Exception as e: except Exception as e:
if isinstance(e, HTTPException): if isinstance(e, HTTPException):
@ -1026,11 +951,35 @@ async def generate_key_helper_fn( # noqa: PLR0915
return key_data return key_data
async def delete_verification_token(tokens: List, user_id: Optional[str] = None): async def delete_verification_token(
tokens: List, user_id: Optional[str] = None
) -> Tuple[Optional[Dict], List[LiteLLM_VerificationToken]]:
"""
Helper that deletes the list of tokens from the database
Args:
tokens: List of tokens to delete
user_id: Optional user_id to filter by
Returns:
Tuple[Optional[Dict], List[LiteLLM_VerificationToken]]:
Optional[Dict]:
- Number of deleted tokens
List[LiteLLM_VerificationToken]:
- List of keys being deleted. Each entry carries the key_alias, token, and user_id of the deleted key;
it is passed down to the KeyManagementEventHooks to remove the keys from the secret manager and write audit logs
"""
from litellm.proxy.proxy_server import litellm_proxy_admin_name, prisma_client from litellm.proxy.proxy_server import litellm_proxy_admin_name, prisma_client
try: try:
if prisma_client: if prisma_client:
tokens = [_hash_token_if_needed(token=key) for key in tokens]
_keys_being_deleted = (
await prisma_client.db.litellm_verificationtoken.find_many(
where={"token": {"in": tokens}}
)
)
# Assuming 'db' is your Prisma Client instance # Assuming 'db' is your Prisma Client instance
# check if admin making request - don't filter by user-id # check if admin making request - don't filter by user-id
if user_id == litellm_proxy_admin_name: if user_id == litellm_proxy_admin_name:
@ -1060,7 +1009,7 @@ async def delete_verification_token(tokens: List, user_id: Optional[str] = None)
) )
verbose_proxy_logger.debug(traceback.format_exc()) verbose_proxy_logger.debug(traceback.format_exc())
raise e raise e
return deleted_tokens return deleted_tokens, _keys_being_deleted
@router.post( @router.post(

View file

@ -265,7 +265,6 @@ def run_server( # noqa: PLR0915
ProxyConfig, ProxyConfig,
app, app,
load_aws_kms, load_aws_kms,
load_aws_secret_manager,
load_from_azure_key_vault, load_from_azure_key_vault,
load_google_kms, load_google_kms,
save_worker_config, save_worker_config,
@ -278,7 +277,6 @@ def run_server( # noqa: PLR0915
ProxyConfig, ProxyConfig,
app, app,
load_aws_kms, load_aws_kms,
load_aws_secret_manager,
load_from_azure_key_vault, load_from_azure_key_vault,
load_google_kms, load_google_kms,
save_worker_config, save_worker_config,
@ -295,7 +293,6 @@ def run_server( # noqa: PLR0915
ProxyConfig, ProxyConfig,
app, app,
load_aws_kms, load_aws_kms,
load_aws_secret_manager,
load_from_azure_key_vault, load_from_azure_key_vault,
load_google_kms, load_google_kms,
save_worker_config, save_worker_config,
@ -559,8 +556,14 @@ def run_server( # noqa: PLR0915
key_management_system key_management_system
== KeyManagementSystem.AWS_SECRET_MANAGER.value # noqa: F405 == KeyManagementSystem.AWS_SECRET_MANAGER.value # noqa: F405
): ):
from litellm.secret_managers.aws_secret_manager_v2 import (
AWSSecretsManagerV2,
)
### LOAD FROM AWS SECRET MANAGER ### ### LOAD FROM AWS SECRET MANAGER ###
load_aws_secret_manager(use_aws_secret_manager=True) AWSSecretsManagerV2.load_aws_secret_manager(
use_aws_secret_manager=True
)
elif key_management_system == KeyManagementSystem.AWS_KMS.value: elif key_management_system == KeyManagementSystem.AWS_KMS.value:
load_aws_kms(use_aws_kms=True) load_aws_kms(use_aws_kms=True)
elif ( elif (

View file

@ -7,6 +7,8 @@ model_list:
litellm_settings: general_settings:
callbacks: ["gcs_bucket"] key_management_system: "aws_secret_manager"
key_management_settings:
store_virtual_keys: true
access_mode: "write_only"
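With this config the proxy should write newly generated virtual keys to AWS Secrets Manager but keep reading secrets from the environment, since "write_only" fails the read gate added later in this diff. A rough sketch of the observable state (not a test from the repo):

```
import litellm
from litellm.secret_managers.main import _should_read_secret_from_secret_manager

# after the proxy has loaded the config above
assert litellm._key_management_settings.store_virtual_keys is True
assert litellm._key_management_settings.access_mode == "write_only"

# reads bypass the secret manager in write_only mode
assert _should_read_secret_from_secret_manager() is False
```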

View file

@ -245,10 +245,7 @@ from litellm.router import (
from litellm.router import ModelInfo as RouterModelInfo from litellm.router import ModelInfo as RouterModelInfo
from litellm.router import updateDeployment from litellm.router import updateDeployment
from litellm.scheduler import DefaultPriorities, FlowItem, Scheduler from litellm.scheduler import DefaultPriorities, FlowItem, Scheduler
from litellm.secret_managers.aws_secret_manager import ( from litellm.secret_managers.aws_secret_manager import load_aws_kms
load_aws_kms,
load_aws_secret_manager,
)
from litellm.secret_managers.google_kms import load_google_kms from litellm.secret_managers.google_kms import load_google_kms
from litellm.secret_managers.main import ( from litellm.secret_managers.main import (
get_secret, get_secret,
@ -1825,8 +1822,13 @@ class ProxyConfig:
key_management_system key_management_system
== KeyManagementSystem.AWS_SECRET_MANAGER.value # noqa: F405 == KeyManagementSystem.AWS_SECRET_MANAGER.value # noqa: F405
): ):
### LOAD FROM AWS SECRET MANAGER ### from litellm.secret_managers.aws_secret_manager_v2 import (
load_aws_secret_manager(use_aws_secret_manager=True) AWSSecretsManagerV2,
)
AWSSecretsManagerV2.load_aws_secret_manager(
use_aws_secret_manager=True
)
elif key_management_system == KeyManagementSystem.AWS_KMS.value: elif key_management_system == KeyManagementSystem.AWS_KMS.value:
load_aws_kms(use_aws_kms=True) load_aws_kms(use_aws_kms=True)
elif ( elif (

View file

@ -8,7 +8,8 @@ from litellm._logging import verbose_logger
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.azure_ai.rerank import AzureAIRerank from litellm.llms.azure_ai.rerank import AzureAIRerank
from litellm.llms.cohere.rerank import CohereRerank from litellm.llms.cohere.rerank import CohereRerank
from litellm.llms.together_ai.rerank import TogetherAIRerank from litellm.llms.jina_ai.rerank.handler import JinaAIRerank
from litellm.llms.together_ai.rerank.handler import TogetherAIRerank
from litellm.secret_managers.main import get_secret from litellm.secret_managers.main import get_secret
from litellm.types.rerank import RerankRequest, RerankResponse from litellm.types.rerank import RerankRequest, RerankResponse
from litellm.types.router import * from litellm.types.router import *
@ -19,6 +20,7 @@ from litellm.utils import client, exception_type, supports_httpx_timeout
cohere_rerank = CohereRerank() cohere_rerank = CohereRerank()
together_rerank = TogetherAIRerank() together_rerank = TogetherAIRerank()
azure_ai_rerank = AzureAIRerank() azure_ai_rerank = AzureAIRerank()
jina_ai_rerank = JinaAIRerank()
################################################# #################################################
@ -247,7 +249,23 @@ def rerank(
api_key=api_key, api_key=api_key,
_is_async=_is_async, _is_async=_is_async,
) )
elif _custom_llm_provider == "jina_ai":
if dynamic_api_key is None:
raise ValueError(
"Jina AI API key is required, please set 'JINA_AI_API_KEY' in your environment"
)
response = jina_ai_rerank.rerank(
model=model,
api_key=dynamic_api_key,
query=query,
documents=documents,
top_n=top_n,
rank_fields=rank_fields,
return_documents=return_documents,
max_chunks_per_doc=max_chunks_per_doc,
_is_async=_is_async,
)
else: else:
raise ValueError(f"Unsupported provider: {_custom_llm_provider}") raise ValueError(f"Unsupported provider: {_custom_llm_provider}")

View file

@ -679,9 +679,8 @@ class Router:
kwargs["model"] = model kwargs["model"] = model
kwargs["messages"] = messages kwargs["messages"] = messages
kwargs["original_function"] = self._completion kwargs["original_function"] = self._completion
kwargs.get("request_timeout", self.timeout) self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs)
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
kwargs.setdefault("metadata", {}).update({"model_group": model})
response = self.function_with_fallbacks(**kwargs) response = self.function_with_fallbacks(**kwargs)
return response return response
except Exception as e: except Exception as e:
@ -783,8 +782,7 @@ class Router:
kwargs["stream"] = stream kwargs["stream"] = stream
kwargs["original_function"] = self._acompletion kwargs["original_function"] = self._acompletion
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries) kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs)
kwargs.setdefault("metadata", {}).update({"model_group": model})
request_priority = kwargs.get("priority") or self.default_priority request_priority = kwargs.get("priority") or self.default_priority
@ -948,6 +946,17 @@ class Router:
self.fail_calls[model_name] += 1 self.fail_calls[model_name] += 1
raise e raise e
def _update_kwargs_before_fallbacks(self, model: str, kwargs: dict) -> None:
"""
Adds/updates to kwargs:
- num_retries
- litellm_trace_id
- metadata
"""
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
kwargs.setdefault("litellm_trace_id", str(uuid.uuid4()))
kwargs.setdefault("metadata", {}).update({"model_group": model})
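The effect is easiest to see on a bare kwargs dict; a sketch with an illustrative two-retry router:

```
from litellm import Router

router = Router(
    model_list=[{"model_name": "gpt-4", "litellm_params": {"model": "gpt-4"}}],
    num_retries=2,
)
kwargs: dict = {}
router._update_kwargs_before_fallbacks(model="gpt-4", kwargs=kwargs)
# kwargs == {
#     "num_retries": 2,
#     "litellm_trace_id": "<uuid4 shared by any fallback/retry attempts>",
#     "metadata": {"model_group": "gpt-4"},
# }
```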
def _update_kwargs_with_default_litellm_params(self, kwargs: dict) -> None: def _update_kwargs_with_default_litellm_params(self, kwargs: dict) -> None:
""" """
Adds default litellm params to kwargs, if set. Adds default litellm params to kwargs, if set.
@ -1511,9 +1520,7 @@ class Router:
kwargs["model"] = model kwargs["model"] = model
kwargs["file"] = file kwargs["file"] = file
kwargs["original_function"] = self._atranscription kwargs["original_function"] = self._atranscription
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries) self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs)
kwargs.get("request_timeout", self.timeout)
kwargs.setdefault("metadata", {}).update({"model_group": model})
response = await self.async_function_with_fallbacks(**kwargs) response = await self.async_function_with_fallbacks(**kwargs)
return response return response
@ -1688,9 +1695,7 @@ class Router:
kwargs["model"] = model kwargs["model"] = model
kwargs["input"] = input kwargs["input"] = input
kwargs["original_function"] = self._arerank kwargs["original_function"] = self._arerank
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries) self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs)
kwargs.get("request_timeout", self.timeout)
kwargs.setdefault("metadata", {}).update({"model_group": model})
response = await self.async_function_with_fallbacks(**kwargs) response = await self.async_function_with_fallbacks(**kwargs)
@ -1839,9 +1844,7 @@ class Router:
kwargs["model"] = model kwargs["model"] = model
kwargs["prompt"] = prompt kwargs["prompt"] = prompt
kwargs["original_function"] = self._atext_completion kwargs["original_function"] = self._atext_completion
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries) self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs)
kwargs.get("request_timeout", self.timeout)
kwargs.setdefault("metadata", {}).update({"model_group": model})
response = await self.async_function_with_fallbacks(**kwargs) response = await self.async_function_with_fallbacks(**kwargs)
return response return response
@ -2112,9 +2115,7 @@ class Router:
kwargs["model"] = model kwargs["model"] = model
kwargs["input"] = input kwargs["input"] = input
kwargs["original_function"] = self._aembedding kwargs["original_function"] = self._aembedding
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries) self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs)
kwargs.get("request_timeout", self.timeout)
kwargs.setdefault("metadata", {}).update({"model_group": model})
response = await self.async_function_with_fallbacks(**kwargs) response = await self.async_function_with_fallbacks(**kwargs)
return response return response
except Exception as e: except Exception as e:
@ -2609,6 +2610,7 @@ class Router:
If it fails after num_retries, fall back to another model group If it fails after num_retries, fall back to another model group
""" """
model_group: Optional[str] = kwargs.get("model") model_group: Optional[str] = kwargs.get("model")
disable_fallbacks: Optional[bool] = kwargs.pop("disable_fallbacks", False)
fallbacks: Optional[List] = kwargs.get("fallbacks", self.fallbacks) fallbacks: Optional[List] = kwargs.get("fallbacks", self.fallbacks)
context_window_fallbacks: Optional[List] = kwargs.get( context_window_fallbacks: Optional[List] = kwargs.get(
"context_window_fallbacks", self.context_window_fallbacks "context_window_fallbacks", self.context_window_fallbacks
@ -2616,6 +2618,7 @@ class Router:
content_policy_fallbacks: Optional[List] = kwargs.get( content_policy_fallbacks: Optional[List] = kwargs.get(
"content_policy_fallbacks", self.content_policy_fallbacks "content_policy_fallbacks", self.content_policy_fallbacks
) )
try: try:
self._handle_mock_testing_fallbacks( self._handle_mock_testing_fallbacks(
kwargs=kwargs, kwargs=kwargs,
@ -2635,7 +2638,7 @@ class Router:
original_model_group: Optional[str] = kwargs.get("model") # type: ignore original_model_group: Optional[str] = kwargs.get("model") # type: ignore
fallback_failure_exception_str = "" fallback_failure_exception_str = ""
if original_model_group is None: if disable_fallbacks is True or original_model_group is None:
raise e raise e
input_kwargs = { input_kwargs = {

View file

@ -23,28 +23,6 @@ def validate_environment():
raise ValueError("Missing required environment variable - AWS_REGION_NAME") raise ValueError("Missing required environment variable - AWS_REGION_NAME")
def load_aws_secret_manager(use_aws_secret_manager: Optional[bool]):
if use_aws_secret_manager is None or use_aws_secret_manager is False:
return
try:
import boto3
from botocore.exceptions import ClientError
validate_environment()
# Create a Secrets Manager client
session = boto3.session.Session() # type: ignore
client = session.client(
service_name="secretsmanager", region_name=os.getenv("AWS_REGION_NAME")
)
litellm.secret_manager_client = client
litellm._key_management_system = KeyManagementSystem.AWS_SECRET_MANAGER
except Exception as e:
raise e
def load_aws_kms(use_aws_kms: Optional[bool]): def load_aws_kms(use_aws_kms: Optional[bool]):
if use_aws_kms is None or use_aws_kms is False: if use_aws_kms is None or use_aws_kms is False:
return return

View file

@ -0,0 +1,310 @@
"""
This is a file for the AWS Secret Manager Integration
Handles Async Operations for:
- Read Secret
- Write Secret
- Delete Secret
Relevant issue: https://github.com/BerriAI/litellm/issues/1883
Requires:
* `os.environ["AWS_REGION_NAME"]`
* `pip install boto3>=1.28.57`
"""
import ast
import asyncio
import base64
import json
import os
import re
import sys
from typing import Any, Dict, Optional, Union
import httpx
import litellm
from litellm._logging import verbose_logger
from litellm.llms.base_aws_llm import BaseAWSLLM
from litellm.llms.custom_httpx.http_handler import (
_get_httpx_client,
get_async_httpx_client,
)
from litellm.llms.custom_httpx.types import httpxSpecialProvider
from litellm.proxy._types import KeyManagementSystem
class AWSSecretsManagerV2(BaseAWSLLM):
@classmethod
def validate_environment(cls):
if "AWS_REGION_NAME" not in os.environ:
raise ValueError("Missing required environment variable - AWS_REGION_NAME")
@classmethod
def load_aws_secret_manager(cls, use_aws_secret_manager: Optional[bool]):
"""
Initializes AWSSecretsManagerV2, setting litellm.secret_manager_client = AWSSecretsManagerV2() and litellm._key_management_system = KeyManagementSystem.AWS_SECRET_MANAGER
"""
if use_aws_secret_manager is None or use_aws_secret_manager is False:
return
try:
import boto3
cls.validate_environment()
litellm.secret_manager_client = cls()
litellm._key_management_system = KeyManagementSystem.AWS_SECRET_MANAGER
except Exception as e:
raise e
async def async_read_secret(
self,
secret_name: str,
optional_params: Optional[dict] = None,
timeout: Optional[Union[float, httpx.Timeout]] = None,
) -> Optional[str]:
"""
Async function to read a secret from AWS Secrets Manager
Returns:
Optional[str]: Secret value, or None if the secret could not be read
Raises:
ValueError: If the request times out
"""
endpoint_url, headers, body = self._prepare_request(
action="GetSecretValue",
secret_name=secret_name,
optional_params=optional_params,
)
async_client = get_async_httpx_client(
llm_provider=httpxSpecialProvider.SecretManager,
params={"timeout": timeout},
)
try:
response = await async_client.post(
url=endpoint_url, headers=headers, data=body.decode("utf-8")
)
response.raise_for_status()
return response.json()["SecretString"]
except httpx.TimeoutException:
raise ValueError("Timeout error occurred")
except Exception as e:
verbose_logger.exception(
"Error reading secret from AWS Secrets Manager: %s", str(e)
)
return None
def sync_read_secret(
self,
secret_name: str,
optional_params: Optional[dict] = None,
timeout: Optional[Union[float, httpx.Timeout]] = None,
) -> Optional[str]:
"""
Sync function to read a secret from AWS Secrets Manager
Done for backwards compatibility with existing codebase, since get_secret is a sync function
"""
# self._prepare_request uses these env vars, we cannot read them from AWS Secrets Manager. If we do we'd get stuck in an infinite loop
if secret_name in [
"AWS_ACCESS_KEY_ID",
"AWS_SECRET_ACCESS_KEY",
"AWS_REGION_NAME",
"AWS_REGION",
"AWS_BEDROCK_RUNTIME_ENDPOINT",
]:
return os.getenv(secret_name)
endpoint_url, headers, body = self._prepare_request(
action="GetSecretValue",
secret_name=secret_name,
optional_params=optional_params,
)
sync_client = _get_httpx_client(
params={"timeout": timeout},
)
try:
response = sync_client.post(
url=endpoint_url, headers=headers, data=body.decode("utf-8")
)
response.raise_for_status()
return response.json()["SecretString"]
except httpx.TimeoutException:
raise ValueError("Timeout error occurred")
except Exception as e:
verbose_logger.exception(
"Error reading secret from AWS Secrets Manager: %s", str(e)
)
return None
async def async_write_secret(
self,
secret_name: str,
secret_value: str,
description: Optional[str] = None,
client_request_token: Optional[str] = None,
optional_params: Optional[dict] = None,
timeout: Optional[Union[float, httpx.Timeout]] = None,
) -> dict:
"""
Async function to write a secret to AWS Secrets Manager
Args:
secret_name: Name of the secret
secret_value: Value to store (can be a JSON string)
description: Optional description for the secret
client_request_token: Optional unique identifier to ensure idempotency
optional_params: Additional AWS parameters
timeout: Request timeout
"""
import uuid
# Prepare the request data
data = {"Name": secret_name, "SecretString": secret_value}
if description:
data["Description"] = description
data["ClientRequestToken"] = str(uuid.uuid4())
endpoint_url, headers, body = self._prepare_request(
action="CreateSecret",
secret_name=secret_name,
secret_value=secret_value,
optional_params=optional_params,
request_data=data, # Pass the complete request data
)
async_client = get_async_httpx_client(
llm_provider=httpxSpecialProvider.SecretManager,
params={"timeout": timeout},
)
try:
response = await async_client.post(
url=endpoint_url, headers=headers, data=body.decode("utf-8")
)
response.raise_for_status()
return response.json()
except httpx.HTTPStatusError as err:
raise ValueError(f"HTTP error occurred: {err.response.text}")
except httpx.TimeoutException:
raise ValueError("Timeout error occurred")
async def async_delete_secret(
self,
secret_name: str,
recovery_window_in_days: Optional[int] = 7,
optional_params: Optional[dict] = None,
timeout: Optional[Union[float, httpx.Timeout]] = None,
) -> dict:
"""
Async function to delete a secret from AWS Secrets Manager
Args:
secret_name: Name of the secret to delete
recovery_window_in_days: Number of days before permanent deletion (default: 7)
optional_params: Additional AWS parameters
timeout: Request timeout
Returns:
dict: Response from AWS Secrets Manager containing deletion details
"""
# Prepare the request data
data = {
"SecretId": secret_name,
"RecoveryWindowInDays": recovery_window_in_days,
}
endpoint_url, headers, body = self._prepare_request(
action="DeleteSecret",
secret_name=secret_name,
optional_params=optional_params,
request_data=data,
)
async_client = get_async_httpx_client(
llm_provider=httpxSpecialProvider.SecretManager,
params={"timeout": timeout},
)
try:
response = await async_client.post(
url=endpoint_url, headers=headers, data=body.decode("utf-8")
)
response.raise_for_status()
return response.json()
except httpx.HTTPStatusError as err:
raise ValueError(f"HTTP error occurred: {err.response.text}")
except httpx.TimeoutException:
raise ValueError("Timeout error occurred")
def _prepare_request(
self,
action: str,  # e.g. "GetSecretValue", "CreateSecret", "DeleteSecret"
secret_name: str,
secret_value: Optional[str] = None,
optional_params: Optional[dict] = None,
request_data: Optional[dict] = None,
) -> tuple[str, Any, bytes]:
"""Prepare the AWS Secrets Manager request"""
try:
import boto3
from botocore.auth import SigV4Auth
from botocore.awsrequest import AWSRequest
from botocore.credentials import Credentials
except ImportError:
raise ImportError("Missing boto3 to call AWS Secrets Manager. Run 'pip install boto3'.")
optional_params = optional_params or {}
boto3_credentials_info = self._get_boto_credentials_from_optional_params(
optional_params
)
# Get endpoint
_, endpoint_url = self.get_runtime_endpoint(
api_base=None,
aws_bedrock_runtime_endpoint=boto3_credentials_info.aws_bedrock_runtime_endpoint,
aws_region_name=boto3_credentials_info.aws_region_name,
)
endpoint_url = endpoint_url.replace("bedrock-runtime", "secretsmanager")
# Use provided request_data if available, otherwise build default data
if request_data:
data = request_data
else:
data = {"SecretId": secret_name}
if secret_value and action == "PutSecretValue":
data["SecretString"] = secret_value
body = json.dumps(data).encode("utf-8")
headers = {
"Content-Type": "application/x-amz-json-1.1",
"X-Amz-Target": f"secretsmanager.{action}",
}
# Sign request
request = AWSRequest(
method="POST", url=endpoint_url, data=body, headers=headers
)
SigV4Auth(
boto3_credentials_info.credentials,
"secretsmanager",
boto3_credentials_info.aws_region_name,
).add_auth(request)
prepped = request.prepare()
return endpoint_url, prepped.headers, body
# if __name__ == "__main__":
# print("loading aws secret manager v2")
# aws_secret_manager_v2 = AWSSecretsManagerV2()
# print("writing secret to aws secret manager v2")
# asyncio.run(aws_secret_manager_v2.async_write_secret(secret_name="test_secret_3", secret_value="test_value_2"))
# print("reading secret from aws secret manager v2")

View file

@ -5,7 +5,7 @@ import json
import os import os
import sys import sys
import traceback import traceback
from typing import Any, Optional, Union from typing import TYPE_CHECKING, Any, Optional, Union
import httpx import httpx
from dotenv import load_dotenv from dotenv import load_dotenv
@ -198,7 +198,10 @@ def get_secret( # noqa: PLR0915
raise ValueError("Unsupported OIDC provider") raise ValueError("Unsupported OIDC provider")
try: try:
if litellm.secret_manager_client is not None: if (
_should_read_secret_from_secret_manager()
and litellm.secret_manager_client is not None
):
try: try:
client = litellm.secret_manager_client client = litellm.secret_manager_client
key_manager = "local" key_manager = "local"
@ -207,7 +210,8 @@ def get_secret( # noqa: PLR0915
if key_management_settings is not None: if key_management_settings is not None:
if ( if (
secret_name not in key_management_settings.hosted_keys key_management_settings.hosted_keys is not None
and secret_name not in key_management_settings.hosted_keys
): # allow user to specify which keys to check in hosted key manager ): # allow user to specify which keys to check in hosted key manager
key_manager = "local" key_manager = "local"
@ -268,25 +272,13 @@ def get_secret( # noqa: PLR0915
if isinstance(secret, str): if isinstance(secret, str):
secret = secret.strip() secret = secret.strip()
elif key_manager == KeyManagementSystem.AWS_SECRET_MANAGER.value: elif key_manager == KeyManagementSystem.AWS_SECRET_MANAGER.value:
try: from litellm.secret_managers.aws_secret_manager_v2 import (
get_secret_value_response = client.get_secret_value( AWSSecretsManagerV2,
SecretId=secret_name )
)
print_verbose(
f"get_secret_value_response: {get_secret_value_response}"
)
except Exception as e:
print_verbose(f"An error occurred - {str(e)}")
# For a list of exceptions thrown, see
# https://docs.aws.amazon.com/secretsmanager/latest/apireference/API_GetSecretValue.html
raise e
# assume there is 1 secret per secret_name if isinstance(client, AWSSecretsManagerV2):
secret_dict = json.loads(get_secret_value_response["SecretString"]) secret = client.sync_read_secret(secret_name=secret_name)
print_verbose(f"secret_dict: {secret_dict}") print_verbose(f"get_secret_value_response: {secret}")
for k, v in secret_dict.items():
secret = v
print_verbose(f"secret: {secret}")
elif key_manager == KeyManagementSystem.GOOGLE_SECRET_MANAGER.value: elif key_manager == KeyManagementSystem.GOOGLE_SECRET_MANAGER.value:
try: try:
secret = client.get_secret_from_google_secret_manager( secret = client.get_secret_from_google_secret_manager(
@ -332,3 +324,21 @@ def get_secret( # noqa: PLR0915
return default_value return default_value
else: else:
raise e raise e
def _should_read_secret_from_secret_manager() -> bool:
"""
Returns True if the secret manager should be used to read the secret, False otherwise
- If the secret manager client is not set, return False
- If the `_key_management_settings` access mode is "read_only" or "read_and_write", return True
- Otherwise, return False
"""
if litellm.secret_manager_client is not None:
if litellm._key_management_settings is not None:
if (
litellm._key_management_settings.access_mode == "read_only"
or litellm._key_management_settings.access_mode == "read_and_write"
):
return True
return False
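Putting the pieces together, a read-capable setup flows through the new AWS branch above; a rough sketch with hypothetical values (hosted_keys left unset so every secret is eligible):

```
import litellm
from litellm.proxy._types import KeyManagementSettings
from litellm.secret_managers.aws_secret_manager_v2 import AWSSecretsManagerV2
from litellm.secret_managers.main import get_secret

# hypothetical setup: AWS Secrets Manager configured for reads
AWSSecretsManagerV2.load_aws_secret_manager(use_aws_secret_manager=True)
litellm._key_management_settings = KeyManagementSettings(access_mode="read_only")

# resolves via AWSSecretsManagerV2.sync_read_secret under the hood
openai_key = get_secret("OPENAI_API_KEY")
```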

View file

@ -7,6 +7,7 @@ https://docs.cohere.com/reference/rerank
from typing import List, Optional, Union from typing import List, Optional, Union
from pydantic import BaseModel, PrivateAttr from pydantic import BaseModel, PrivateAttr
from typing_extensions import TypedDict
class RerankRequest(BaseModel): class RerankRequest(BaseModel):
@ -19,10 +20,26 @@ class RerankRequest(BaseModel):
max_chunks_per_doc: Optional[int] = None max_chunks_per_doc: Optional[int] = None
class RerankBilledUnits(TypedDict, total=False):
search_units: int
total_tokens: int
class RerankTokens(TypedDict, total=False):
input_tokens: int
output_tokens: int
class RerankResponseMeta(TypedDict, total=False):
api_version: dict
billed_units: RerankBilledUnits
tokens: RerankTokens
class RerankResponse(BaseModel): class RerankResponse(BaseModel):
id: str id: str
results: List[dict] # Contains index and relevance_score results: List[dict] # Contains index and relevance_score
meta: Optional[dict] = None # Contains api_version and billed_units meta: Optional[RerankResponseMeta] = None # Contains api_version and billed_units
# Define private attributes using PrivateAttr # Define private attributes using PrivateAttr
_hidden_params: dict = PrivateAttr(default_factory=dict) _hidden_params: dict = PrivateAttr(default_factory=dict)
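A short sketch of the richer meta payload the response can now carry (numbers invented):

```
from litellm.types.rerank import (
    RerankBilledUnits,
    RerankResponse,
    RerankResponseMeta,
    RerankTokens,
)

meta = RerankResponseMeta(
    api_version={"version": "2"},
    billed_units=RerankBilledUnits(search_units=1),
    tokens=RerankTokens(input_tokens=50, output_tokens=0),
)
response = RerankResponse(
    id="rerank-123",
    results=[{"index": 0, "relevance_score": 0.92}],
    meta=meta,
)
```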

View file

@ -150,6 +150,8 @@ class GenericLiteLLMParams(BaseModel):
max_retries: Optional[int] = None max_retries: Optional[int] = None
organization: Optional[str] = None # for openai orgs organization: Optional[str] = None # for openai orgs
configurable_clientside_auth_params: CONFIGURABLE_CLIENTSIDE_AUTH_PARAMS = None configurable_clientside_auth_params: CONFIGURABLE_CLIENTSIDE_AUTH_PARAMS = None
## LOGGING PARAMS ##
litellm_trace_id: Optional[str] = None
## UNIFIED PROJECT/REGION ## ## UNIFIED PROJECT/REGION ##
region_name: Optional[str] = None region_name: Optional[str] = None
## VERTEX AI ## ## VERTEX AI ##
@ -186,6 +188,8 @@ class GenericLiteLLMParams(BaseModel):
None # timeout when making stream=True calls, if str, pass in as os.environ/ None # timeout when making stream=True calls, if str, pass in as os.environ/
), ),
organization: Optional[str] = None, # for openai orgs organization: Optional[str] = None, # for openai orgs
## LOGGING PARAMS ##
litellm_trace_id: Optional[str] = None,
## UNIFIED PROJECT/REGION ## ## UNIFIED PROJECT/REGION ##
region_name: Optional[str] = None, region_name: Optional[str] = None,
## VERTEX AI ## ## VERTEX AI ##

View file

@ -1334,6 +1334,7 @@ class ResponseFormatChunk(TypedDict, total=False):
all_litellm_params = [ all_litellm_params = [
"metadata", "metadata",
"litellm_trace_id",
"tags", "tags",
"acompletion", "acompletion",
"aimg_generation", "aimg_generation",
@ -1523,6 +1524,7 @@ StandardLoggingPayloadStatus = Literal["success", "failure"]
class StandardLoggingPayload(TypedDict): class StandardLoggingPayload(TypedDict):
id: str id: str
trace_id: str # Trace multiple LLM calls belonging to same overall request (e.g. fallbacks/retries)
call_type: str call_type: str
response_cost: float response_cost: float
response_cost_failure_debug_info: Optional[ response_cost_failure_debug_info: Optional[

View file

@ -527,6 +527,7 @@ def function_setup( # noqa: PLR0915
messages=messages, messages=messages,
stream=stream, stream=stream,
litellm_call_id=kwargs["litellm_call_id"], litellm_call_id=kwargs["litellm_call_id"],
litellm_trace_id=kwargs.get("litellm_trace_id"),
function_id=function_id or "", function_id=function_id or "",
call_type=call_type, call_type=call_type,
start_time=start_time, start_time=start_time,
@ -2056,6 +2057,7 @@ def get_litellm_params(
azure_ad_token_provider=None, azure_ad_token_provider=None,
user_continue_message=None, user_continue_message=None,
base_model=None, base_model=None,
litellm_trace_id=None,
): ):
litellm_params = { litellm_params = {
"acompletion": acompletion, "acompletion": acompletion,
@ -2084,6 +2086,7 @@ def get_litellm_params(
"user_continue_message": user_continue_message, "user_continue_message": user_continue_message,
"base_model": base_model "base_model": base_model
or _get_base_model_from_litellm_call_metadata(metadata=metadata), or _get_base_model_from_litellm_call_metadata(metadata=metadata),
"litellm_trace_id": litellm_trace_id,
} }
return litellm_params return litellm_params
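Combined with the router change earlier in this diff, callers can also pin the trace id themselves so retries and fallbacks are grouped under one trace_id in the logging payload; a hedged sketch:

```
import uuid

import litellm

trace_id = str(uuid.uuid4())
response = litellm.completion(
    model="gpt-4o-mini",
    messages=[{"role": "user", "content": "hi"}],
    litellm_trace_id=trace_id,  # surfaces as StandardLoggingPayload.trace_id
)
```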

View file

@ -2986,19 +2986,19 @@
"supports_function_calling": true "supports_function_calling": true
}, },
"vertex_ai/imagegeneration@006": { "vertex_ai/imagegeneration@006": {
"cost_per_image": 0.020, "output_cost_per_image": 0.020,
"litellm_provider": "vertex_ai-image-models", "litellm_provider": "vertex_ai-image-models",
"mode": "image_generation", "mode": "image_generation",
"source": "https://cloud.google.com/vertex-ai/generative-ai/pricing" "source": "https://cloud.google.com/vertex-ai/generative-ai/pricing"
}, },
"vertex_ai/imagen-3.0-generate-001": { "vertex_ai/imagen-3.0-generate-001": {
"cost_per_image": 0.04, "output_cost_per_image": 0.04,
"litellm_provider": "vertex_ai-image-models", "litellm_provider": "vertex_ai-image-models",
"mode": "image_generation", "mode": "image_generation",
"source": "https://cloud.google.com/vertex-ai/generative-ai/pricing" "source": "https://cloud.google.com/vertex-ai/generative-ai/pricing"
}, },
"vertex_ai/imagen-3.0-fast-generate-001": { "vertex_ai/imagen-3.0-fast-generate-001": {
"cost_per_image": 0.02, "output_cost_per_image": 0.02,
"litellm_provider": "vertex_ai-image-models", "litellm_provider": "vertex_ai-image-models",
"mode": "image_generation", "mode": "image_generation",
"source": "https://cloud.google.com/vertex-ai/generative-ai/pricing" "source": "https://cloud.google.com/vertex-ai/generative-ai/pricing"
@ -5620,6 +5620,13 @@
"litellm_provider": "bedrock", "litellm_provider": "bedrock",
"mode": "image_generation" "mode": "image_generation"
}, },
"stability.stable-image-ultra-v1:0": {
"max_tokens": 77,
"max_input_tokens": 77,
"output_cost_per_image": 0.14,
"litellm_provider": "bedrock",
"mode": "image_generation"
},
"sagemaker/meta-textgeneration-llama-2-7b": { "sagemaker/meta-textgeneration-llama-2-7b": {
"max_tokens": 4096, "max_tokens": 4096,
"max_input_tokens": 4096, "max_input_tokens": 4096,

View file

@ -1,6 +1,6 @@
[tool.poetry] [tool.poetry]
name = "litellm" name = "litellm"
version = "1.52.6" version = "1.52.9"
description = "Library to easily interface with LLM API providers" description = "Library to easily interface with LLM API providers"
authors = ["BerriAI"] authors = ["BerriAI"]
license = "MIT" license = "MIT"
@ -91,7 +91,7 @@ requires = ["poetry-core", "wheel"]
build-backend = "poetry.core.masonry.api" build-backend = "poetry.core.masonry.api"
[tool.commitizen] [tool.commitizen]
version = "1.52.6" version = "1.52.9"
version_files = [ version_files = [
"pyproject.toml:^version" "pyproject.toml:^version"
] ]

View file

@ -13,8 +13,11 @@ sys.path.insert(
import litellm import litellm
from litellm.exceptions import BadRequestError from litellm.exceptions import BadRequestError
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.utils import CustomStreamWrapper from litellm.utils import (
CustomStreamWrapper,
get_supported_openai_params,
get_optional_params,
)
# test_example.py # test_example.py
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
@ -45,6 +48,9 @@ class BaseLLMChatTest(ABC):
) )
assert response is not None assert response is not None
# for OpenAI the content contains the JSON schema, so we need to assert that the content is not None
assert response.choices[0].message.content is not None
def test_message_with_name(self): def test_message_with_name(self):
base_completion_call_args = self.get_base_completion_call_args() base_completion_call_args = self.get_base_completion_call_args()
messages = [ messages = [
@ -79,6 +85,49 @@ class BaseLLMChatTest(ABC):
print(response) print(response)
# OpenAI guarantees that the JSON schema is returned in the content
# relevant issue: https://github.com/BerriAI/litellm/issues/6741
assert response.choices[0].message.content is not None
def test_json_response_format_stream(self):
"""
Test that the JSON response format with streaming is supported by the LLM API
"""
base_completion_call_args = self.get_base_completion_call_args()
litellm.set_verbose = True
messages = [
{
"role": "system",
"content": "Your output should be a JSON object with no additional properties. ",
},
{
"role": "user",
"content": "Respond with this in json. city=San Francisco, state=CA, weather=sunny, temp=60",
},
]
response = litellm.completion(
**base_completion_call_args,
messages=messages,
response_format={"type": "json_object"},
stream=True,
)
print(response)
content = ""
for chunk in response:
content += chunk.choices[0].delta.content or ""
print("content=", content)
# OpenAI guarantees that the JSON schema is returned in the content
# relevant issue: https://github.com/BerriAI/litellm/issues/6741
# we need to assert that the JSON schema was returned in the content, (for Anthropic we were returning it as part of the tool call)
assert content is not None
assert len(content) > 0
@pytest.fixture @pytest.fixture
def pdf_messages(self): def pdf_messages(self):
import base64 import base64
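
The streaming JSON-mode check above is added to the shared `BaseLLMChatTest` suite, so every provider test class that subclasses it inherits the test automatically. A hypothetical subclass might look like the sketch below; it assumes `get_base_completion_call_args` is the only abstract hook that needs overriding here, and the model string is purely illustrative:

```
from base_llm_unit_tests import BaseLLMChatTest


class TestMyProviderCompletion(BaseLLMChatTest):
    def get_base_completion_call_args(self) -> dict:
        # whatever kwargs litellm.completion needs for this provider
        return {"model": "openai/gpt-4o-mini"}
```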

View file

@ -0,0 +1,115 @@
import asyncio
import httpx
import json
import pytest
import sys
from typing import Any, Dict, List
from unittest.mock import MagicMock, Mock, patch
import os
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import litellm
from litellm.exceptions import BadRequestError
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.utils import (
CustomStreamWrapper,
get_supported_openai_params,
get_optional_params,
)
# test_example.py
from abc import ABC, abstractmethod
def assert_response_shape(response, custom_llm_provider):
expected_response_shape = {"id": str, "results": list, "meta": dict}
expected_results_shape = {"index": int, "relevance_score": float}
expected_meta_shape = {"api_version": dict, "billed_units": dict}
expected_api_version_shape = {"version": str}
expected_billed_units_shape = {"search_units": int}
assert isinstance(response.id, expected_response_shape["id"])
assert isinstance(response.results, expected_response_shape["results"])
for result in response.results:
assert isinstance(result["index"], expected_results_shape["index"])
assert isinstance(
result["relevance_score"], expected_results_shape["relevance_score"]
)
assert isinstance(response.meta, expected_response_shape["meta"])
if custom_llm_provider == "cohere":
assert isinstance(
response.meta["api_version"], expected_meta_shape["api_version"]
)
assert isinstance(
response.meta["api_version"]["version"],
expected_api_version_shape["version"],
)
assert isinstance(
response.meta["billed_units"], expected_meta_shape["billed_units"]
)
assert isinstance(
response.meta["billed_units"]["search_units"],
expected_billed_units_shape["search_units"],
)
class BaseLLMRerankTest(ABC):
"""
Abstract base test class that enforces a common test across all test classes.
"""
@abstractmethod
def get_base_rerank_call_args(self) -> dict:
"""Must return the base rerank call args"""
pass
@abstractmethod
def get_custom_llm_provider(self) -> litellm.LlmProviders:
"""Must return the custom llm provider"""
pass
@pytest.mark.asyncio()
@pytest.mark.parametrize("sync_mode", [True, False])
async def test_basic_rerank(self, sync_mode):
rerank_call_args = self.get_base_rerank_call_args()
custom_llm_provider = self.get_custom_llm_provider()
if sync_mode is True:
response = litellm.rerank(
**rerank_call_args,
query="hello",
documents=["hello", "world"],
top_n=3,
)
print("re rank response: ", response)
assert response.id is not None
assert response.results is not None
assert_response_shape(
response=response, custom_llm_provider=custom_llm_provider.value
)
else:
response = await litellm.arerank(
**rerank_call_args,
query="hello",
documents=["hello", "world"],
top_n=3,
)
print("async re rank response: ", response)
assert response.id is not None
assert response.results is not None
assert_response_shape(
response=response, custom_llm_provider=custom_llm_provider.value
)
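
This new shared rerank suite drives `litellm.rerank` / `litellm.arerank` and checks the Cohere-style response shape. Outside the test harness, a direct call looks roughly like this (the Cohere model name is just an example; any rerank-capable provider slots into the same signature):

```
import litellm

response = litellm.rerank(
    model="cohere/rerank-english-v3.0",
    query="What is the capital of France?",
    documents=["Paris is the capital of France.", "Berlin is in Germany."],
    top_n=1,
)
print(response.results[0]["relevance_score"])
```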

View file

@ -33,8 +33,10 @@ from litellm import (
) )
from litellm.adapters.anthropic_adapter import anthropic_adapter from litellm.adapters.anthropic_adapter import anthropic_adapter
from litellm.types.llms.anthropic import AnthropicResponse from litellm.types.llms.anthropic import AnthropicResponse
from litellm.types.utils import GenericStreamingChunk, ChatCompletionToolCallChunk
from litellm.types.llms.openai import ChatCompletionToolCallFunctionChunk
from litellm.llms.anthropic.common_utils import process_anthropic_headers from litellm.llms.anthropic.common_utils import process_anthropic_headers
from litellm.llms.anthropic.chat.handler import AnthropicChatCompletion
from httpx import Headers from httpx import Headers
from base_llm_unit_tests import BaseLLMChatTest from base_llm_unit_tests import BaseLLMChatTest
@ -694,3 +696,91 @@ class TestAnthropicCompletion(BaseLLMChatTest):
assert _document_validation["type"] == "document" assert _document_validation["type"] == "document"
assert _document_validation["source"]["media_type"] == "application/pdf" assert _document_validation["source"]["media_type"] == "application/pdf"
assert _document_validation["source"]["type"] == "base64" assert _document_validation["source"]["type"] == "base64"
def test_convert_tool_response_to_message_with_values():
"""Test converting a tool response with 'values' key to a message"""
tool_calls = [
ChatCompletionToolCallChunk(
id="test_id",
type="function",
function=ChatCompletionToolCallFunctionChunk(
name="json_tool_call",
arguments='{"values": {"name": "John", "age": 30}}',
),
index=0,
)
]
message = AnthropicChatCompletion._convert_tool_response_to_message(
tool_calls=tool_calls
)
assert message is not None
assert message.content == '{"name": "John", "age": 30}'
def test_convert_tool_response_to_message_without_values():
"""
Test converting a tool response without 'values' key to a message
The Anthropic API returns the JSON schema in the tool call, while the OpenAI spec expects it in the message. This test ensures the tool call is converted to a message correctly.
Relevant issue: https://github.com/BerriAI/litellm/issues/6741
"""
tool_calls = [
ChatCompletionToolCallChunk(
id="test_id",
type="function",
function=ChatCompletionToolCallFunctionChunk(
name="json_tool_call", arguments='{"name": "John", "age": 30}'
),
index=0,
)
]
message = AnthropicChatCompletion._convert_tool_response_to_message(
tool_calls=tool_calls
)
assert message is not None
assert message.content == '{"name": "John", "age": 30}'
def test_convert_tool_response_to_message_invalid_json():
"""Test converting a tool response with invalid JSON"""
tool_calls = [
ChatCompletionToolCallChunk(
id="test_id",
type="function",
function=ChatCompletionToolCallFunctionChunk(
name="json_tool_call", arguments="invalid json"
),
index=0,
)
]
message = AnthropicChatCompletion._convert_tool_response_to_message(
tool_calls=tool_calls
)
assert message is not None
assert message.content == "invalid json"
def test_convert_tool_response_to_message_no_arguments():
"""Test converting a tool response with no arguments"""
tool_calls = [
ChatCompletionToolCallChunk(
id="test_id",
type="function",
function=ChatCompletionToolCallFunctionChunk(name="json_tool_call"),
index=0,
)
]
message = AnthropicChatCompletion._convert_tool_response_to_message(
tool_calls=tool_calls
)
assert message is None
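
These tests pin down how Anthropic's `json_tool_call` tool call is folded back into an OpenAI-style message. A rough restatement of that behavior (not the shipped implementation) is:

```
import json
from typing import List, Optional


def convert_tool_response_to_message_sketch(tool_calls: List[dict]) -> Optional[str]:
    """Mirror the behavior asserted above for _convert_tool_response_to_message."""
    if not tool_calls:
        return None
    arguments = tool_calls[0].get("function", {}).get("arguments")
    if arguments is None:
        return None                      # no arguments -> no synthetic message
    try:
        parsed = json.loads(arguments)
    except json.JSONDecodeError:
        return arguments                 # invalid JSON is passed through as-is
    if isinstance(parsed, dict) and "values" in parsed:
        return json.dumps(parsed["values"])  # unwrap Anthropic's "values" wrapper
    return arguments                     # already the bare JSON payload
```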

View file

@ -0,0 +1,23 @@
import json
import os
import sys
from datetime import datetime
from unittest.mock import AsyncMock
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
from base_rerank_unit_tests import BaseLLMRerankTest
import litellm
class TestJinaAI(BaseLLMRerankTest):
def get_custom_llm_provider(self) -> litellm.LlmProviders:
return litellm.LlmProviders.JINA_AI
def get_base_rerank_call_args(self) -> dict:
return {
"model": "jina_ai/jina-reranker-v2-base-multilingual",
}

View file

@ -923,9 +923,22 @@ def test_watsonx_text_top_k():
assert optional_params["top_k"] == 10 assert optional_params["top_k"] == 10
def test_together_ai_model_params(): def test_together_ai_model_params():
optional_params = get_optional_params( optional_params = get_optional_params(
model="together_ai", custom_llm_provider="together_ai", logprobs=1 model="together_ai", custom_llm_provider="together_ai", logprobs=1
) )
print(optional_params) print(optional_params)
assert optional_params["logprobs"] == 1 assert optional_params["logprobs"] == 1
def test_forward_user_param():
from litellm.utils import get_supported_openai_params, get_optional_params
model = "claude-3-5-sonnet-20240620"
optional_params = get_optional_params(
model=model,
user="test_user",
custom_llm_provider="anthropic",
)
assert optional_params["metadata"]["user_id"] == "test_user"

View file

@ -16,6 +16,7 @@ import pytest
import litellm import litellm
from litellm import get_optional_params from litellm import get_optional_params
from litellm.llms.custom_httpx.http_handler import HTTPHandler from litellm.llms.custom_httpx.http_handler import HTTPHandler
import httpx
def test_completion_pydantic_obj_2(): def test_completion_pydantic_obj_2():
@ -1317,3 +1318,39 @@ def test_image_completion_request(image_url):
mock_post.assert_called_once() mock_post.assert_called_once()
print("mock_post.call_args.kwargs['json']", mock_post.call_args.kwargs["json"]) print("mock_post.call_args.kwargs['json']", mock_post.call_args.kwargs["json"])
assert mock_post.call_args.kwargs["json"] == expected_request_body assert mock_post.call_args.kwargs["json"] == expected_request_body
@pytest.mark.parametrize(
"model, expected_url",
[
(
"textembedding-gecko@001",
"https://us-central1-aiplatform.googleapis.com/v1/projects/project-id/locations/us-central1/publishers/google/models/textembedding-gecko@001:predict",
),
(
"123456789",
"https://us-central1-aiplatform.googleapis.com/v1/projects/project-id/locations/us-central1/endpoints/123456789:predict",
),
],
)
def test_vertex_embedding_url(model, expected_url):
"""
Test URL generation for embedding models, including numeric model IDs (fine-tuned models).
Relevant issue: https://github.com/BerriAI/litellm/issues/6482
When a fine-tuned embedding model is used, the URL is different from the standard one.
"""
from litellm.llms.vertex_ai_and_google_ai_studio.common_utils import _get_vertex_url
url, endpoint = _get_vertex_url(
mode="embedding",
model=model,
stream=False,
vertex_project="project-id",
vertex_location="us-central1",
vertex_api_version="v1",
)
assert url == expected_url
assert endpoint == "predict"
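
For fine-tuned Vertex AI embedding models the model name is a numeric endpoint ID, and the URL test above checks that litellm routes those to `.../endpoints/<id>:predict`. A hedged usage sketch (project, location, and IDs are placeholders):

```
import litellm

response = litellm.embedding(
    model="vertex_ai/123456789",        # numeric ID -> fine-tuned endpoint
    input=["good morning from litellm"],
    vertex_project="project-id",
    vertex_location="us-central1",
)
```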

View file

@ -18,6 +18,8 @@ import json
import os import os
import tempfile import tempfile
from unittest.mock import AsyncMock, MagicMock, patch from unittest.mock import AsyncMock, MagicMock, patch
from respx import MockRouter
import httpx
import pytest import pytest
@ -973,6 +975,7 @@ async def test_partner_models_httpx(model, sync_mode):
data = { data = {
"model": model, "model": model,
"messages": messages, "messages": messages,
"timeout": 10,
} }
if sync_mode: if sync_mode:
response = litellm.completion(**data) response = litellm.completion(**data)
@ -986,6 +989,8 @@ async def test_partner_models_httpx(model, sync_mode):
assert isinstance(response._hidden_params["response_cost"], float) assert isinstance(response._hidden_params["response_cost"], float)
except litellm.RateLimitError as e: except litellm.RateLimitError as e:
pass pass
except litellm.Timeout as e:
pass
except litellm.InternalServerError as e: except litellm.InternalServerError as e:
pass pass
except Exception as e: except Exception as e:
@ -3051,3 +3056,70 @@ def test_custom_api_base(api_base):
assert url == api_base + ":" assert url == api_base + ":"
else: else:
assert url == test_endpoint assert url == test_endpoint
@pytest.mark.asyncio
@pytest.mark.respx
async def test_vertexai_embedding_finetuned(respx_mock: MockRouter):
"""
Tests that:
- Request URL and body are correctly formatted for Vertex AI embeddings
- Response is properly parsed into litellm's embedding response format
"""
load_vertex_ai_credentials()
litellm.set_verbose = True
# Test input
input_text = ["good morning from litellm", "this is another item"]
# Expected request/response
expected_url = "https://us-central1-aiplatform.googleapis.com/v1/projects/633608382793/locations/us-central1/endpoints/1004708436694269952:predict"
expected_request = {
"instances": [
{"inputs": "good morning from litellm"},
{"inputs": "this is another item"},
],
"parameters": {},
}
mock_response = {
"predictions": [
[[-0.000431762, -0.04416759, -0.03443353]], # Truncated embedding vector
[[-0.000431762, -0.04416759, -0.03443353]], # Truncated embedding vector
],
"deployedModelId": "2275167734310371328",
"model": "projects/633608382793/locations/us-central1/models/snowflake-arctic-embed-m-long-1731622468876",
"modelDisplayName": "snowflake-arctic-embed-m-long-1731622468876",
"modelVersionId": "1",
}
# Setup mock request
mock_request = respx_mock.post(expected_url).mock(
return_value=httpx.Response(200, json=mock_response)
)
# Make request
response = await litellm.aembedding(
vertex_project="633608382793",
model="vertex_ai/1004708436694269952",
input=input_text,
)
# Assert request was made correctly
assert mock_request.called
request_body = json.loads(mock_request.calls[0].request.content)
print("\n\nrequest_body", request_body)
print("\n\nexpected_request", expected_request)
assert request_body == expected_request
# Assert response structure
assert response is not None
assert hasattr(response, "data")
assert len(response.data) == len(input_text)
# Assert embedding structure
for embedding in response.data:
assert "embedding" in embedding
assert isinstance(embedding["embedding"], list)
assert len(embedding["embedding"]) > 0
assert all(isinstance(x, float) for x in embedding["embedding"])

View file

@ -0,0 +1,139 @@
# What is this?
## Unit tests for litellm's AWS Secrets Manager v2 integration
import asyncio
import os
import sys
import traceback
from dotenv import load_dotenv
import litellm.types
import litellm.types.utils
load_dotenv()
import io
import sys
import os
# Ensure the project root is in the Python path
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../../..")))
print("Python Path:", sys.path)
print("Current Working Directory:", os.getcwd())
from typing import Optional
from unittest.mock import MagicMock, patch
import pytest
import uuid
import json
from litellm.secret_managers.aws_secret_manager_v2 import AWSSecretsManagerV2
def check_aws_credentials():
"""Helper function to check if AWS credentials are set"""
required_vars = ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_REGION_NAME"]
missing_vars = [var for var in required_vars if not os.getenv(var)]
if missing_vars:
pytest.skip(f"Missing required AWS credentials: {', '.join(missing_vars)}")
@pytest.mark.asyncio
async def test_write_and_read_simple_secret():
"""Test writing and reading a simple string secret"""
check_aws_credentials()
secret_manager = AWSSecretsManagerV2()
test_secret_name = f"litellm_test_{uuid.uuid4().hex[:8]}"
test_secret_value = "test_value_123"
try:
# Write secret
write_response = await secret_manager.async_write_secret(
secret_name=test_secret_name,
secret_value=test_secret_value,
description="LiteLLM Test Secret",
)
print("Write Response:", write_response)
assert write_response is not None
assert "ARN" in write_response
assert "Name" in write_response
assert write_response["Name"] == test_secret_name
# Read secret back
read_value = await secret_manager.async_read_secret(
secret_name=test_secret_name
)
print("Read Value:", read_value)
assert read_value == test_secret_value
finally:
# Cleanup: Delete the secret
delete_response = await secret_manager.async_delete_secret(
secret_name=test_secret_name
)
print("Delete Response:", delete_response)
assert delete_response is not None
@pytest.mark.asyncio
async def test_write_and_read_json_secret():
"""Test writing and reading a JSON structured secret"""
check_aws_credentials()
secret_manager = AWSSecretsManagerV2()
test_secret_name = f"litellm_test_{uuid.uuid4().hex[:8]}_json"
test_secret_value = {
"api_key": "test_key",
"model": "gpt-4",
"temperature": 0.7,
"metadata": {"team": "ml", "project": "litellm"},
}
try:
# Write JSON secret
write_response = await secret_manager.async_write_secret(
secret_name=test_secret_name,
secret_value=json.dumps(test_secret_value),
description="LiteLLM JSON Test Secret",
)
print("Write Response:", write_response)
# Read and parse JSON secret
read_value = await secret_manager.async_read_secret(
secret_name=test_secret_name
)
parsed_value = json.loads(read_value)
print("Read Value:", read_value)
assert parsed_value == test_secret_value
assert parsed_value["api_key"] == "test_key"
assert parsed_value["metadata"]["team"] == "ml"
finally:
# Cleanup: Delete the secret
delete_response = await secret_manager.async_delete_secret(
secret_name=test_secret_name
)
print("Delete Response:", delete_response)
assert delete_response is not None
@pytest.mark.asyncio
async def test_read_nonexistent_secret():
"""Test reading a secret that doesn't exist"""
check_aws_credentials()
secret_manager = AWSSecretsManagerV2()
nonexistent_secret = f"litellm_nonexistent_{uuid.uuid4().hex}"
response = await secret_manager.async_read_secret(secret_name=nonexistent_secret)
assert response is None
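
Outside of pytest, the same `AWSSecretsManagerV2` round-trip these tests exercise can be run as a small script; this sketch assumes `AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY`, and `AWS_REGION_NAME` are set, and the secret name/value are placeholders:

```
import asyncio

from litellm.secret_managers.aws_secret_manager_v2 import AWSSecretsManagerV2


async def main():
    sm = AWSSecretsManagerV2()
    await sm.async_write_secret(
        secret_name="litellm_demo_secret",
        secret_value="test_value_123",  # placeholder value
    )
    print(await sm.async_read_secret(secret_name="litellm_demo_secret"))
    await sm.async_delete_secret(secret_name="litellm_demo_secret")


asyncio.run(main())
```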

View file

@ -24,7 +24,7 @@ from litellm import RateLimitError, Timeout, completion, completion_cost, embedd
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.llms.prompt_templates.factory import anthropic_messages_pt from litellm.llms.prompt_templates.factory import anthropic_messages_pt
# litellm.num_retries = 3 # litellm.num_retries=3
litellm.cache = None litellm.cache = None
litellm.success_callback = [] litellm.success_callback = []

View file

@ -10,7 +10,7 @@ import os
sys.path.insert( sys.path.insert(
0, os.path.abspath("../..") 0, os.path.abspath("../..")
) # Adds the parent directory to the system path ) # Adds the parent directory to the system-path
from typing import Literal from typing import Literal
import pytest import pytest

View file

@ -1624,3 +1624,55 @@ async def test_standard_logging_payload_stream_usage(sync_mode):
print(f"standard_logging_object usage: {built_response.usage}") print(f"standard_logging_object usage: {built_response.usage}")
except litellm.InternalServerError: except litellm.InternalServerError:
pass pass
def test_standard_logging_retries():
"""
Test that a consistent trace_id is logged across retries, so we know if a request was retried.
"""
from litellm.types.utils import StandardLoggingPayload
from litellm.router import Router
customHandler = CompletionCustomHandler()
litellm.callbacks = [customHandler]
router = Router(
model_list=[
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {
"model": "openai/gpt-3.5-turbo",
"api_key": "test-api-key",
},
}
]
)
with patch.object(
customHandler, "log_failure_event", new=MagicMock()
) as mock_client:
try:
router.completion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hey, how's it going?"}],
num_retries=1,
mock_response="litellm.RateLimitError",
)
except litellm.RateLimitError:
pass
assert mock_client.call_count == 2
assert (
mock_client.call_args_list[0].kwargs["kwargs"]["standard_logging_object"][
"trace_id"
]
is not None
)
assert (
mock_client.call_args_list[0].kwargs["kwargs"]["standard_logging_object"][
"trace_id"
]
== mock_client.call_args_list[1].kwargs["kwargs"][
"standard_logging_object"
]["trace_id"]
)
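
`test_standard_logging_retries` asserts that the same `trace_id` shows up in the standard logging payload of every attempt, which is what lets downstream logging tie retries together. A minimal custom callback that surfaces it might look like this (hypothetical handler; the field access mirrors the assertions above):

```
import litellm
from litellm.integrations.custom_logger import CustomLogger


class TraceIdLogger(CustomLogger):
    def log_failure_event(self, kwargs, response_obj, start_time, end_time):
        slp = kwargs.get("standard_logging_object") or {}
        print("trace_id:", slp.get("trace_id"))


litellm.callbacks = [TraceIdLogger()]
```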

View file

@ -58,6 +58,7 @@ async def test_content_policy_exception_azure():
except litellm.ContentPolicyViolationError as e: except litellm.ContentPolicyViolationError as e:
print("caught a content policy violation error! Passed") print("caught a content policy violation error! Passed")
print("exception", e) print("exception", e)
assert e.response is not None
assert e.litellm_debug_info is not None assert e.litellm_debug_info is not None
assert isinstance(e.litellm_debug_info, str) assert isinstance(e.litellm_debug_info, str)
assert len(e.litellm_debug_info) > 0 assert len(e.litellm_debug_info) > 0
@ -1152,3 +1153,24 @@ async def test_exception_with_headers_httpx(
if exception_raised is False: if exception_raised is False:
print(resp) print(resp)
assert exception_raised assert exception_raised
@pytest.mark.asyncio
@pytest.mark.parametrize("model", ["azure/chatgpt-v-2", "openai/gpt-3.5-turbo"])
async def test_bad_request_error_contains_httpx_response(model):
"""
Test that the BadRequestError contains the httpx response
Relevant issue: https://github.com/BerriAI/litellm/issues/6732
"""
try:
await litellm.acompletion(
model=model,
messages=[{"role": "user", "content": "Hello world"}],
bad_arg="bad_arg",
)
pytest.fail("Expected to raise BadRequestError")
except litellm.BadRequestError as e:
print("e.response", e.response)
print("vars(e.response)", vars(e.response))
assert e.response is not None
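
The new exception test checks that the original `httpx` response survives on `BadRequestError` (issue #6732). In application code that means error handlers can inspect it directly; a hedged example:

```
import litellm

try:
    litellm.completion(
        model="openai/gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Hello"}],
        bad_arg="bad_arg",  # intentionally invalid param, triggers a 400
    )
except litellm.BadRequestError as e:
    # e.response is the underlying httpx.Response, so status/headers are available
    print(e.response.status_code)
```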

View file

@ -157,7 +157,7 @@ def test_get_llm_provider_jina_ai():
model, custom_llm_provider, dynamic_api_key, api_base = litellm.get_llm_provider( model, custom_llm_provider, dynamic_api_key, api_base = litellm.get_llm_provider(
model="jina_ai/jina-embeddings-v3", model="jina_ai/jina-embeddings-v3",
) )
assert custom_llm_provider == "openai_like" assert custom_llm_provider == "jina_ai"
assert api_base == "https://api.jina.ai/v1" assert api_base == "https://api.jina.ai/v1"
assert model == "jina-embeddings-v3" assert model == "jina-embeddings-v3"

View file

@ -89,11 +89,16 @@ def test_get_model_info_ollama_chat():
"template": "tools", "template": "tools",
} }
), ),
): ) as mock_client:
info = OllamaConfig().get_model_info("mistral") info = OllamaConfig().get_model_info("mistral")
print("info", info)
assert info["supports_function_calling"] is True assert info["supports_function_calling"] is True
info = get_model_info("ollama/mistral") info = get_model_info("ollama/mistral")
print("info", info)
assert info["supports_function_calling"] is True assert info["supports_function_calling"] is True
mock_client.assert_called()
print(mock_client.call_args.kwargs)
assert mock_client.call_args.kwargs["json"]["name"] == "mistral"

View file

@ -1138,9 +1138,9 @@ async def test_router_content_policy_fallbacks(
router = Router( router = Router(
model_list=[ model_list=[
{ {
"model_name": "claude-2", "model_name": "claude-2.1",
"litellm_params": { "litellm_params": {
"model": "claude-2", "model": "claude-2.1",
"api_key": "", "api_key": "",
"mock_response": mock_response, "mock_response": mock_response,
}, },
@ -1164,7 +1164,7 @@ async def test_router_content_policy_fallbacks(
{ {
"model_name": "my-general-model", "model_name": "my-general-model",
"litellm_params": { "litellm_params": {
"model": "claude-2", "model": "claude-2.1",
"api_key": "", "api_key": "",
"mock_response": Exception("Should not have called this."), "mock_response": Exception("Should not have called this."),
}, },
@ -1172,14 +1172,14 @@ async def test_router_content_policy_fallbacks(
{ {
"model_name": "my-context-window-model", "model_name": "my-context-window-model",
"litellm_params": { "litellm_params": {
"model": "claude-2", "model": "claude-2.1",
"api_key": "", "api_key": "",
"mock_response": Exception("Should not have called this."), "mock_response": Exception("Should not have called this."),
}, },
}, },
], ],
content_policy_fallbacks=( content_policy_fallbacks=(
[{"claude-2": ["my-fallback-model"]}] [{"claude-2.1": ["my-fallback-model"]}]
if fallback_type == "model-specific" if fallback_type == "model-specific"
else None else None
), ),
@ -1190,12 +1190,12 @@ async def test_router_content_policy_fallbacks(
if sync_mode is True: if sync_mode is True:
response = router.completion( response = router.completion(
model="claude-2", model="claude-2.1",
messages=[{"role": "user", "content": "Hey, how's it going?"}], messages=[{"role": "user", "content": "Hey, how's it going?"}],
) )
else: else:
response = await router.acompletion( response = await router.acompletion(
model="claude-2", model="claude-2.1",
messages=[{"role": "user", "content": "Hey, how's it going?"}], messages=[{"role": "user", "content": "Hey, how's it going?"}],
) )
@ -1455,3 +1455,46 @@ async def test_router_fallbacks_default_and_model_specific_fallbacks(sync_mode):
assert isinstance( assert isinstance(
exc_info.value, litellm.AuthenticationError exc_info.value, litellm.AuthenticationError
), f"Expected AuthenticationError, but got {type(exc_info.value).__name__}" ), f"Expected AuthenticationError, but got {type(exc_info.value).__name__}"
@pytest.mark.asyncio
async def test_router_disable_fallbacks_dynamically():
from litellm.router import run_async_fallback
router = Router(
model_list=[
{
"model_name": "bad-model",
"litellm_params": {
"model": "openai/my-bad-model",
"api_key": "my-bad-api-key",
},
},
{
"model_name": "good-model",
"litellm_params": {
"model": "gpt-4o",
"api_key": os.getenv("OPENAI_API_KEY"),
},
},
],
fallbacks=[{"bad-model": ["good-model"]}],
default_fallbacks=["good-model"],
)
with patch.object(
router,
"log_retry",
new=MagicMock(return_value=None),
) as mock_client:
try:
resp = await router.acompletion(
model="bad-model",
messages=[{"role": "user", "content": "Hey, how's it going?"}],
disable_fallbacks=True,
)
print(resp)
except Exception as e:
print(e)
mock_client.assert_not_called()

View file

@ -14,6 +14,7 @@ from litellm.router import Deployment, LiteLLM_Params, ModelInfo
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from collections import defaultdict from collections import defaultdict
from dotenv import load_dotenv from dotenv import load_dotenv
from unittest.mock import patch, MagicMock, AsyncMock
load_dotenv() load_dotenv()
@ -83,3 +84,93 @@ def test_returned_settings():
except Exception: except Exception:
print(traceback.format_exc()) print(traceback.format_exc())
pytest.fail("An error occurred - " + traceback.format_exc()) pytest.fail("An error occurred - " + traceback.format_exc())
from litellm.types.utils import CallTypes
def test_update_kwargs_before_fallbacks_unit_test():
router = Router(
model_list=[
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {
"model": "azure/chatgpt-v-2",
"api_key": "bad-key",
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE"),
},
}
],
)
kwargs = {"messages": [{"role": "user", "content": "write 1 sentence poem"}]}
router._update_kwargs_before_fallbacks(
model="gpt-3.5-turbo",
kwargs=kwargs,
)
assert kwargs["litellm_trace_id"] is not None
@pytest.mark.parametrize(
"call_type",
[
CallTypes.acompletion,
CallTypes.atext_completion,
CallTypes.aembedding,
CallTypes.arerank,
CallTypes.atranscription,
],
)
@pytest.mark.asyncio
async def test_update_kwargs_before_fallbacks(call_type):
router = Router(
model_list=[
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {
"model": "azure/chatgpt-v-2",
"api_key": "bad-key",
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE"),
},
}
],
)
if call_type.value.startswith("a"):
with patch.object(router, "async_function_with_fallbacks") as mock_client:
if call_type.value == "acompletion":
input_kwarg = {
"messages": [{"role": "user", "content": "Hello, how are you?"}],
}
elif (
call_type.value == "atext_completion"
or call_type.value == "aimage_generation"
):
input_kwarg = {
"prompt": "Hello, how are you?",
}
elif call_type.value == "aembedding" or call_type.value == "arerank":
input_kwarg = {
"input": "Hello, how are you?",
}
elif call_type.value == "atranscription":
input_kwarg = {
"file": "path/to/file",
}
else:
input_kwarg = {}
await getattr(router, call_type.value)(
model="gpt-3.5-turbo",
**input_kwarg,
)
mock_client.assert_called_once()
print(mock_client.call_args.kwargs)
assert mock_client.call_args.kwargs["litellm_trace_id"] is not None

View file

@ -15,22 +15,29 @@ sys.path.insert(
0, os.path.abspath("../..") 0, os.path.abspath("../..")
) # Adds the parent directory to the system path ) # Adds the parent directory to the system path
import pytest import pytest
import litellm
from litellm.llms.AzureOpenAI.azure import get_azure_ad_token_from_oidc from litellm.llms.AzureOpenAI.azure import get_azure_ad_token_from_oidc
from litellm.llms.bedrock.chat import BedrockConverseLLM, BedrockLLM from litellm.llms.bedrock.chat import BedrockConverseLLM, BedrockLLM
from litellm.secret_managers.aws_secret_manager import load_aws_secret_manager from litellm.secret_managers.aws_secret_manager_v2 import AWSSecretsManagerV2
from litellm.secret_managers.main import get_secret from litellm.secret_managers.main import (
get_secret,
_should_read_secret_from_secret_manager,
)
@pytest.mark.skip(reason="AWS Suspended Account")
def test_aws_secret_manager(): def test_aws_secret_manager():
load_aws_secret_manager(use_aws_secret_manager=True) import json
AWSSecretsManagerV2.load_aws_secret_manager(use_aws_secret_manager=True)
secret_val = get_secret("litellm_master_key") secret_val = get_secret("litellm_master_key")
print(f"secret_val: {secret_val}") print(f"secret_val: {secret_val}")
assert secret_val == "sk-1234" # cast json to dict
secret_val = json.loads(secret_val)
assert secret_val["litellm_master_key"] == "sk-1234"
def redact_oidc_signature(secret_val): def redact_oidc_signature(secret_val):
@ -240,3 +247,71 @@ def test_google_secret_manager_read_in_memory():
) )
print("secret_val: {}".format(secret_val)) print("secret_val: {}".format(secret_val))
assert secret_val == "lite-llm" assert secret_val == "lite-llm"
def test_should_read_secret_from_secret_manager():
"""
Test that _should_read_secret_from_secret_manager returns correct values based on access mode
"""
from litellm.proxy._types import KeyManagementSettings
# Test when secret manager client is None
litellm.secret_manager_client = None
litellm._key_management_settings = KeyManagementSettings()
assert _should_read_secret_from_secret_manager() is False
# Test with secret manager client and read_only access
litellm.secret_manager_client = "dummy_client"
litellm._key_management_settings = KeyManagementSettings(access_mode="read_only")
assert _should_read_secret_from_secret_manager() is True
# Test with secret manager client and read_and_write access
litellm._key_management_settings = KeyManagementSettings(
access_mode="read_and_write"
)
assert _should_read_secret_from_secret_manager() is True
# Test with secret manager client and write_only access
litellm._key_management_settings = KeyManagementSettings(access_mode="write_only")
assert _should_read_secret_from_secret_manager() is False
# Reset global variables
litellm.secret_manager_client = None
litellm._key_management_settings = KeyManagementSettings()
def test_get_secret_with_access_mode():
"""
Test that get_secret respects access mode settings
"""
from litellm.proxy._types import KeyManagementSettings
# Set up test environment
test_secret_name = "TEST_SECRET_KEY"
test_secret_value = "test_secret_value"
os.environ[test_secret_name] = test_secret_value
# Test with write_only access (should read from os.environ)
litellm.secret_manager_client = "dummy_client"
litellm._key_management_settings = KeyManagementSettings(access_mode="write_only")
assert get_secret(test_secret_name) == test_secret_value
# Test with no KeyManagementSettings but secret_manager_client set
litellm.secret_manager_client = "dummy_client"
litellm._key_management_settings = KeyManagementSettings()
assert _should_read_secret_from_secret_manager() is True
# Test with read_only access
litellm._key_management_settings = KeyManagementSettings(access_mode="read_only")
assert _should_read_secret_from_secret_manager() is True
# Test with read_and_write access
litellm._key_management_settings = KeyManagementSettings(
access_mode="read_and_write"
)
assert _should_read_secret_from_secret_manager() is True
# Reset global variables
litellm.secret_manager_client = None
litellm._key_management_settings = KeyManagementSettings()
del os.environ[test_secret_name]
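
The access-mode tests above encode a simple rule: secrets are only read from the configured secret manager when a client is set and the access mode allows reads. A condensed restatement of that gate (mirroring the tests, not the shipped function):

```
def should_read_from_secret_manager(client, access_mode: str) -> bool:
    # write_only keys are still written to the manager, but reads fall back to os.environ
    if client is None:
        return False
    return access_mode in ("read_only", "read_and_write")
```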

View file

@ -184,12 +184,11 @@ def test_stream_chunk_builder_litellm_usage_chunks():
{"role": "assistant", "content": "uhhhh\n\n\nhmmmm.....\nthinking....\n"}, {"role": "assistant", "content": "uhhhh\n\n\nhmmmm.....\nthinking....\n"},
{"role": "user", "content": "\nI am waiting...\n\n...\n"}, {"role": "user", "content": "\nI am waiting...\n\n...\n"},
] ]
# make a regular gemini call
usage: litellm.Usage = Usage( usage: litellm.Usage = Usage(
completion_tokens=64, completion_tokens=27,
prompt_tokens=55, prompt_tokens=55,
total_tokens=119, total_tokens=82,
completion_tokens_details=None, completion_tokens_details=None,
prompt_tokens_details=None, prompt_tokens_details=None,
) )

View file

@ -718,7 +718,7 @@ async def test_acompletion_claude_2_stream():
try: try:
litellm.set_verbose = True litellm.set_verbose = True
response = await litellm.acompletion( response = await litellm.acompletion(
model="claude-2", model="claude-2.1",
messages=[{"role": "user", "content": "hello from litellm"}], messages=[{"role": "user", "content": "hello from litellm"}],
stream=True, stream=True,
) )
@ -3274,7 +3274,7 @@ def test_completion_claude_3_function_call_with_streaming():
], # "claude-3-opus-20240229" ], # "claude-3-opus-20240229"
) # ) #
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_acompletion_claude_3_function_call_with_streaming(model): async def test_acompletion_function_call_with_streaming(model):
litellm.set_verbose = True litellm.set_verbose = True
tools = [ tools = [
{ {
@ -3335,6 +3335,8 @@ async def test_acompletion_claude_3_function_call_with_streaming(model):
# raise Exception("it worked! ") # raise Exception("it worked! ")
except litellm.InternalServerError as e: except litellm.InternalServerError as e:
pytest.skip(f"InternalServerError - {str(e)}") pytest.skip(f"InternalServerError - {str(e)}")
except litellm.ServiceUnavailableError:
pass
except Exception as e: except Exception as e:
pytest.fail(f"Error occurred: {e}") pytest.fail(f"Error occurred: {e}")

View file

@ -3451,3 +3451,90 @@ async def test_user_api_key_auth_db_unavailable_not_allowed():
request=request, request=request,
api_key="Bearer sk-123456789", api_key="Bearer sk-123456789",
) )
## E2E Virtual Key + Secret Manager Tests #########################################
@pytest.mark.asyncio
async def test_key_generate_with_secret_manager_call(prisma_client):
"""
Generate a key
assert it exists in the secret manager
delete the key
assert it is deleted from the secret manager
"""
from litellm.secret_managers.aws_secret_manager_v2 import AWSSecretsManagerV2
from litellm.proxy._types import KeyManagementSystem, KeyManagementSettings
litellm.set_verbose = True
#### Test Setup ############################################################
aws_secret_manager_client = AWSSecretsManagerV2()
litellm.secret_manager_client = aws_secret_manager_client
litellm._key_management_system = KeyManagementSystem.AWS_SECRET_MANAGER
litellm._key_management_settings = KeyManagementSettings(
store_virtual_keys=True,
)
general_settings = {
"key_management_system": "aws_secret_manager",
"key_management_settings": {
"store_virtual_keys": True,
},
}
setattr(litellm.proxy.proxy_server, "general_settings", general_settings)
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
await litellm.proxy.proxy_server.prisma_client.connect()
############################################################################
# generate new key
key_alias = f"test_alias_secret_manager_key-{uuid.uuid4()}"
spend = 100
max_budget = 400
models = ["fake-openai-endpoint"]
new_key = await generate_key_fn(
data=GenerateKeyRequest(
key_alias=key_alias, spend=spend, max_budget=max_budget, models=models
),
user_api_key_dict=UserAPIKeyAuth(
user_role=LitellmUserRoles.PROXY_ADMIN,
api_key="sk-1234",
user_id="1234",
),
)
generated_key = new_key.key
print(generated_key)
await asyncio.sleep(2)
# read from the secret manager
result = await aws_secret_manager_client.async_read_secret(secret_name=key_alias)
# Assert the correct key is stored in the secret manager
print("response from AWS Secret Manager")
print(result)
assert result == generated_key
# delete the key
await delete_key_fn(
data=KeyRequest(keys=[generated_key]),
user_api_key_dict=UserAPIKeyAuth(
user_role=LitellmUserRoles.PROXY_ADMIN, api_key="sk-1234", user_id="1234"
),
)
await asyncio.sleep(2)
# Assert the key is deleted from the secret manager
result = await aws_secret_manager_client.async_read_secret(secret_name=key_alias)
assert result is None
# cleanup
setattr(litellm.proxy.proxy_server, "general_settings", {})
################################################################################

View file

@ -1500,6 +1500,31 @@ async def test_add_callback_via_key_litellm_pre_call_utils(
assert new_data["failure_callback"] == expected_failure_callbacks assert new_data["failure_callback"] == expected_failure_callbacks
@pytest.mark.asyncio
@pytest.mark.parametrize(
"disable_fallbacks_set",
[
True,
False,
],
)
async def test_disable_fallbacks_by_key(disable_fallbacks_set):
from litellm.proxy.litellm_pre_call_utils import LiteLLMProxyRequestSetup
key_metadata = {"disable_fallbacks": disable_fallbacks_set}
existing_data = {
"model": "azure/chatgpt-v-2",
"messages": [{"role": "user", "content": "write 1 sentence poem"}],
}
data = LiteLLMProxyRequestSetup.add_key_level_controls(
key_metadata=key_metadata,
data=existing_data,
_metadata_variable_name="metadata",
)
assert data["disable_fallbacks"] == disable_fallbacks_set
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize( @pytest.mark.parametrize(
"callback_type, expected_success_callbacks, expected_failure_callbacks", "callback_type, expected_success_callbacks, expected_failure_callbacks",