diff --git a/README.md b/README.md
index bfdba2fa3..5d3efe355 100644
--- a/README.md
+++ b/README.md
@@ -305,6 +305,36 @@ Step 4: Submit a PR with your changes! 🚀
- push your fork to your GitHub repo
- submit a PR from there
+### Building LiteLLM Docker Image
+
+Follow these instructions to build and run the LiteLLM Docker Image yourself.
+
+Step 1: Clone the repo
+
+```
+git clone https://github.com/BerriAI/litellm.git
+```
+
+Step 2: Build the Docker Image
+
+Build using `docker/Dockerfile.non_root`:
+```
+docker build -f docker/Dockerfile.non_root -t litellm_test_image .
+```
+
+Step 3: Run the Docker Image
+
+Make sure your LiteLLM proxy config file (`proxy_config.yaml` below) is present in the root directory. It is mounted into the container as `/app/config.yaml`.
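+
+A minimal proxy config might look like this (a sketch; substitute your own model names and API keys):
+
+```
+model_list:
+  - model_name: gpt-4o
+    litellm_params:
+      model: openai/gpt-4o
+      api_key: os.environ/OPENAI_API_KEY
+```
+
+Then run the image: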
+```
+docker run \
+ -v $(pwd)/proxy_config.yaml:/app/config.yaml \
+ -e DATABASE_URL="postgresql://xxxxxxxx" \
+ -e LITELLM_MASTER_KEY="sk-1234" \
+ -p 4000:4000 \
+ litellm_test_image \
+ --config /app/config.yaml --detailed_debug
+```
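+
+Once the container is up, send a test request to confirm it works (a sketch; use a model name from your config and the master key set above):
+
+```
+curl http://localhost:4000/v1/chat/completions \
+ -H "Content-Type: application/json" \
+ -H "Authorization: Bearer sk-1234" \
+ -d '{"model": "gpt-4o", "messages": [{"role": "user", "content": "ping"}]}'
+```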
+
# Enterprise
For companies that need better security, user management and professional support
diff --git a/deploy/charts/litellm-helm/templates/migrations-job.yaml b/deploy/charts/litellm-helm/templates/migrations-job.yaml
index fc1aacf16..010d2d1b5 100644
--- a/deploy/charts/litellm-helm/templates/migrations-job.yaml
+++ b/deploy/charts/litellm-helm/templates/migrations-job.yaml
@@ -13,18 +13,18 @@ spec:
spec:
containers:
- name: prisma-migrations
- image: "ghcr.io/berriai/litellm:main-stable"
+ image: ghcr.io/berriai/litellm-database:main-latest
command: ["python", "litellm/proxy/prisma_migration.py"]
workingDir: "/app"
env:
- {{- if .Values.db.deployStandalone }}
- - name: DATABASE_URL
- value: postgresql://{{ .Values.postgresql.auth.username }}:{{ .Values.postgresql.auth.password }}@{{ .Release.Name }}-postgresql/{{ .Values.postgresql.auth.database }}
- {{- else if .Values.db.useExisting }}
+ {{- if .Values.db.useExisting }}
- name: DATABASE_URL
value: {{ .Values.db.url | quote }}
+ {{- else }}
+ - name: DATABASE_URL
+ value: postgresql://{{ .Values.postgresql.auth.username }}:{{ .Values.postgresql.auth.password }}@{{ .Release.Name }}-postgresql/{{ .Values.postgresql.auth.database }}
{{- end }}
- name: DISABLE_SCHEMA_UPDATE
- value: "{{ .Values.migrationJob.disableSchemaUpdate }}"
+ value: "false" # always run the migration from the Helm PreSync hook, override the value set
restartPolicy: OnFailure
backoffLimit: {{ .Values.migrationJob.backoffLimit }}
diff --git a/docs/my-website/docs/completion/json_mode.md b/docs/my-website/docs/completion/json_mode.md
index a782bfb0a..51f76b7a6 100644
--- a/docs/my-website/docs/completion/json_mode.md
+++ b/docs/my-website/docs/completion/json_mode.md
@@ -75,6 +75,7 @@ Works for:
- Google AI Studio - Gemini models
- Vertex AI models (Gemini + Anthropic)
- Bedrock Models
+- Anthropic API Models
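+
+For example, JSON mode can be requested from an Anthropic model the same way as from the providers above (a minimal sketch; the model name is just an example):
+
+```python
+import os
+from litellm import completion
+
+os.environ["ANTHROPIC_API_KEY"] = "sk-ant-..."  # assumes your Anthropic key is set
+
+# a sketch: ask an Anthropic model for a JSON object response
+response = completion(
+    model="anthropic/claude-3-5-sonnet-20240620",
+    messages=[{"role": "user", "content": "List three fruits as a JSON object"}],
+    response_format={"type": "json_object"},
+)
+print(response.choices[0].message.content)
+```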
diff --git a/docs/my-website/docs/completion/prefix.md b/docs/my-website/docs/completion/prefix.md
index e3619a2a0..d413ad989 100644
--- a/docs/my-website/docs/completion/prefix.md
+++ b/docs/my-website/docs/completion/prefix.md
@@ -93,7 +93,7 @@ curl http://0.0.0.0:4000/v1/chat/completions \
## Check Model Support
-Call `litellm.get_model_info` to check if a model/provider supports `response_format`.
+Call `litellm.get_model_info` to check if a model/provider supports `prefix`.
@@ -116,4 +116,4 @@ curl -X GET 'http://0.0.0.0:4000/v1/model/info' \
-H 'Authorization: Bearer $LITELLM_KEY' \
```
-
\ No newline at end of file
+
diff --git a/docs/my-website/docs/providers/anthropic.md b/docs/my-website/docs/providers/anthropic.md
index 290e094d0..d4660b807 100644
--- a/docs/my-website/docs/providers/anthropic.md
+++ b/docs/my-website/docs/providers/anthropic.md
@@ -957,3 +957,69 @@ curl http://0.0.0.0:4000/v1/chat/completions \
```
+
+## Usage - passing 'user_id' to Anthropic
+
+LiteLLM translates the OpenAI `user` param to Anthropic's `metadata[user_id]` param.
+
+<Tabs>
+<TabItem value="sdk" label="SDK">
+
+```python
+response = completion(
+ model="claude-3-5-sonnet-20240620",
+ messages=messages,
+ user="user_123",
+)
+```
+
+</TabItem>
+<TabItem value="proxy" label="PROXY">
+
+1. Setup config.yaml
+
+```yaml
+model_list:
+ - model_name: claude-3-5-sonnet-20240620
+ litellm_params:
+ model: anthropic/claude-3-5-sonnet-20240620
+ api_key: os.environ/ANTHROPIC_API_KEY
+```
+
+2. Start Proxy
+
+```
+litellm --config /path/to/config.yaml
+```
+
+3. Test it!
+
+```bash
+curl http://0.0.0.0:4000/v1/chat/completions \
+ -H "Content-Type: application/json" \
+ -H "Authorization: Bearer " \
+ -d '{
+ "model": "claude-3-5-sonnet-20240620",
+ "messages": [{"role": "user", "content": "What is Anthropic?"}],
+ "user": "user_123"
+ }'
+```
+
+</TabItem>
+</Tabs>
+
+## All Supported OpenAI Params
+
+```
+"stream",
+"stop",
+"temperature",
+"top_p",
+"max_tokens",
+"max_completion_tokens",
+"tools",
+"tool_choice",
+"extra_headers",
+"parallel_tool_calls",
+"response_format",
+"user"
+```
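+
+For example, any of these can be passed straight through `completion()` (a minimal sketch; the values are placeholders):
+
+```python
+from litellm import completion
+
+# a sketch: a few of the supported OpenAI params in one call
+response = completion(
+    model="claude-3-5-sonnet-20240620",
+    messages=[{"role": "user", "content": "Hello"}],
+    max_tokens=100,
+    temperature=0.2,
+    stop=["User:"],
+    user="user_123",
+)
+```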
\ No newline at end of file
diff --git a/docs/my-website/docs/providers/huggingface.md b/docs/my-website/docs/providers/huggingface.md
index 4620a6c5d..5297a688b 100644
--- a/docs/my-website/docs/providers/huggingface.md
+++ b/docs/my-website/docs/providers/huggingface.md
@@ -37,7 +37,7 @@ os.environ["HUGGINGFACE_API_KEY"] = "huggingface_api_key"
messages = [{ "content": "There's a llama in my garden 😱 What should I do?","role": "user"}]
# e.g. Call 'https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct' from Serverless Inference API
-response = litellm.completion(
+response = completion(
model="huggingface/meta-llama/Meta-Llama-3.1-8B-Instruct",
messages=[{ "content": "Hello, how are you?","role": "user"}],
stream=True
@@ -165,14 +165,14 @@ Steps to use
```python
import os
-import litellm
+from litellm import completion
os.environ["HUGGINGFACE_API_KEY"] = ""
# TGI model: Call https://huggingface.co/glaiveai/glaive-coder-7b
# add the 'huggingface/' prefix to the model to set huggingface as the provider
# set api base to your deployed api endpoint from hugging face
-response = litellm.completion(
+response = completion(
model="huggingface/glaiveai/glaive-coder-7b",
messages=[{ "content": "Hello, how are you?","role": "user"}],
api_base="https://wjiegasee9bmqke2.us-east-1.aws.endpoints.huggingface.cloud"
@@ -383,6 +383,8 @@ def default_pt(messages):
#### Custom prompt templates
```python
+import litellm
+
# Create your own custom prompt template works
litellm.register_prompt_template(
model="togethercomputer/LLaMA-2-7B-32K",
diff --git a/docs/my-website/docs/providers/jina_ai.md b/docs/my-website/docs/providers/jina_ai.md
index 499cf6709..6c13dbf1a 100644
--- a/docs/my-website/docs/providers/jina_ai.md
+++ b/docs/my-website/docs/providers/jina_ai.md
@@ -1,6 +1,13 @@
+import Tabs from '@theme/Tabs';
+import TabItem from '@theme/TabItem';
+
# Jina AI
https://jina.ai/embeddings/
+Supported endpoints:
+- /embeddings
+- /rerank
+
## API Key
```python
# env variable
@@ -8,6 +15,10 @@ os.environ['JINA_AI_API_KEY']
```
## Sample Usage - Embedding
+
+<Tabs>
+<TabItem value="sdk" label="SDK">
+
```python
from litellm import embedding
import os
@@ -19,6 +30,142 @@ response = embedding(
)
print(response)
```
+
+</TabItem>
+<TabItem value="proxy" label="PROXY">
+
+1. Add to config.yaml
+```yaml
+model_list:
+ - model_name: embedding-model
+ litellm_params:
+ model: jina_ai/jina-embeddings-v3
+ api_key: os.environ/JINA_AI_API_KEY
+```
+
+2. Start proxy
+
+```bash
+litellm --config /path/to/config.yaml
+
+# RUNNING on http://0.0.0.0:4000/
+```
+
+3. Test it!
+
+```bash
+curl -L -X POST 'http://0.0.0.0:4000/embeddings' \
+-H 'Authorization: Bearer sk-1234' \
+-H 'Content-Type: application/json' \
+-d '{"input": ["hello world"], "model": "embedding-model"}'
+```
+
+</TabItem>
+</Tabs>
+
+## Sample Usage - Rerank
+
+<Tabs>
+<TabItem value="sdk" label="SDK">
+
+```python
+from litellm import rerank
+import os
+
+os.environ["JINA_AI_API_KEY"] = "sk-..."
+
+query = "What is the capital of the United States?"
+documents = [
+ "Carson City is the capital city of the American state of Nevada.",
+ "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.",
+ "Washington, D.C. is the capital of the United States.",
+ "Capital punishment has existed in the United States since before it was a country.",
+]
+
+response = rerank(
+ model="jina_ai/jina-reranker-v2-base-multilingual",
+ query=query,
+ documents=documents,
+ top_n=3,
+)
+print(response)
+```
+
+</TabItem>
+<TabItem value="proxy" label="PROXY">
+
+1. Add to config.yaml
+```yaml
+model_list:
+ - model_name: rerank-model
+ litellm_params:
+ model: jina_ai/jina-reranker-v2-base-multilingual
+ api_key: os.environ/JINA_AI_API_KEY
+```
+
+2. Start proxy
+
+```bash
+litellm --config /path/to/config.yaml
+```
+
+3. Test it!
+
+```bash
+curl -L -X POST 'http://0.0.0.0:4000/rerank' \
+-H 'Authorization: Bearer sk-1234' \
+-H 'Content-Type: application/json' \
+-d '{
+ "model": "rerank-model",
+ "query": "What is the capital of the United States?",
+ "documents": [
+ "Carson City is the capital city of the American state of Nevada.",
+ "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.",
+ "Washington, D.C. is the capital of the United States.",
+ "Capital punishment has existed in the United States since before it was a country."
+ ],
+ "top_n": 3
+}'
+```
+
+</TabItem>
+</Tabs>
+
## Supported Models
All models listed here https://jina.ai/embeddings/ are supported
+
+## Supported Optional Rerank Parameters
+
+All Cohere rerank parameters are supported.
+
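+For example, Cohere-style rerank params such as `return_documents` can be passed directly (a minimal sketch):
+
+```python
+from litellm import rerank
+import os
+
+os.environ["JINA_AI_API_KEY"] = "sk-..."
+
+# a sketch: passing a Cohere rerank param to a Jina AI reranker
+response = rerank(
+    model="jina_ai/jina-reranker-v2-base-multilingual",
+    query="What is the capital of the United States?",
+    documents=[
+        "Carson City is the capital city of the American state of Nevada.",
+        "Washington, D.C. is the capital of the United States.",
+    ],
+    top_n=1,
+    return_documents=True,
+)
+print(response)
+```
+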
+## Supported Optional Embeddings Parameters
+
+```
+dimensions
+```
+
+## Provider-specific parameters
+
+Pass any Jina AI-specific parameters as keyword arguments to the `embedding` or `rerank` functions, e.g.
+
+<Tabs>
+<TabItem value="sdk" label="SDK">
+
+```python
+response = embedding(
+ model="jina_ai/jina-embeddings-v3",
+ input=["good morning from litellm"],
+ dimensions=1536,
+ my_custom_param="my_custom_value", # any other jina ai specific parameters
+)
+```
+
+</TabItem>
+<TabItem value="proxy" label="PROXY">
+
+```bash
+curl -L -X POST 'http://0.0.0.0:4000/embeddings' \
+-H 'Authorization: Bearer sk-1234' \
+-H 'Content-Type: application/json' \
+-d '{"input": ["good morning from litellm"], "model": "jina_ai/jina-embeddings-v3", "dimensions": 1536, "my_custom_param": "my_custom_value"}'
+```
+
+</TabItem>
+</Tabs>
diff --git a/docs/my-website/docs/providers/vertex.md b/docs/my-website/docs/providers/vertex.md
index b69e8ee56..921db9e73 100644
--- a/docs/my-website/docs/providers/vertex.md
+++ b/docs/my-website/docs/providers/vertex.md
@@ -1562,6 +1562,10 @@ curl http://0.0.0.0:4000/v1/chat/completions \
## **Embedding Models**
#### Usage - Embedding
+
+<Tabs>
+<TabItem value="sdk" label="SDK">
+
```python
import litellm
from litellm import embedding
@@ -1574,6 +1578,49 @@ response = embedding(
)
print(response)
```
+
+</TabItem>
+<TabItem value="proxy" label="PROXY">
+
+1. Add model to config.yaml
+```yaml
+model_list:
+ - model_name: snowflake-arctic-embed-m-long-1731622468876
+ litellm_params:
+ model: vertex_ai/
+ vertex_project: "adroit-crow-413218"
+ vertex_location: "us-central1"
+ vertex_credentials: adroit-crow-413218-a956eef1a2a8.json
+
+litellm_settings:
+ drop_params: True
+```
+
+2. Start Proxy
+
+```
+$ litellm --config /path/to/config.yaml
+```
+
+3. Make a request using the OpenAI Python SDK (Langchain and other OpenAI-compatible clients work the same way against the proxy)
+
+```python
+import openai
+
+client = openai.OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")
+
+response = client.embeddings.create(
+ model="snowflake-arctic-embed-m-long-1731622468876",
+ input = ["good morning from litellm", "this is another item"],
+)
+
+print(response)
+```
+
+</TabItem>
+</Tabs>
+
#### Supported Embedding Models
All models listed [here](https://github.com/BerriAI/litellm/blob/57f37f743886a0249f630a6792d49dffc2c5d9b7/model_prices_and_context_window.json#L835) are supported
@@ -1589,6 +1636,7 @@ All models listed [here](https://github.com/BerriAI/litellm/blob/57f37f743886a02
| textembedding-gecko@003 | `embedding(model="vertex_ai/textembedding-gecko@003", input)` |
| text-embedding-preview-0409 | `embedding(model="vertex_ai/text-embedding-preview-0409", input)` |
| text-multilingual-embedding-preview-0409 | `embedding(model="vertex_ai/text-multilingual-embedding-preview-0409", input)` |
+| Fine-tuned OR Custom Embedding models | `embedding(model="vertex_ai/", input)` |
### Supported OpenAI (Unified) Params
diff --git a/docs/my-website/docs/proxy/configs.md b/docs/my-website/docs/proxy/configs.md
index c6b9f2d45..888f424b4 100644
--- a/docs/my-website/docs/proxy/configs.md
+++ b/docs/my-website/docs/proxy/configs.md
@@ -791,9 +791,9 @@ general_settings:
| store_model_in_db | boolean | If true, allows `/model/new` endpoint to store model information in db. Endpoint disabled by default. [Doc on `/model/new` endpoint](./model_management.md#create-a-new-model) |
| max_request_size_mb | int | The maximum size for requests in MB. Requests above this size will be rejected. |
| max_response_size_mb | int | The maximum size for responses in MB. LLM Responses above this size will not be sent. |
-| proxy_budget_rescheduler_min_time | int | The minimum time (in seconds) to wait before checking db for budget resets. |
-| proxy_budget_rescheduler_max_time | int | The maximum time (in seconds) to wait before checking db for budget resets. |
-| proxy_batch_write_at | int | Time (in seconds) to wait before batch writing spend logs to the db. |
+| proxy_budget_rescheduler_min_time | int | The minimum time (in seconds) to wait before checking db for budget resets. **Default is 597 seconds** |
+| proxy_budget_rescheduler_max_time | int | The maximum time (in seconds) to wait before checking db for budget resets. **Default is 605 seconds** |
+| proxy_batch_write_at | int | Time (in seconds) to wait before batch writing spend logs to the db. **Default is 10 seconds** |
| alerting_args | dict | Args for Slack Alerting [Doc on Slack Alerting](./alerting.md) |
| custom_key_generate | str | Custom function for key generation [Doc on custom key generation](./virtual_keys.md#custom--key-generate) |
| allowed_ips | List[str] | List of IPs allowed to access the proxy. If not set, all IPs are allowed. |
diff --git a/docs/my-website/docs/proxy/logging.md b/docs/my-website/docs/proxy/logging.md
index 5867a8f23..1bd1b6c4b 100644
--- a/docs/my-website/docs/proxy/logging.md
+++ b/docs/my-website/docs/proxy/logging.md
@@ -66,10 +66,16 @@ Removes any field with `user_api_key_*` from metadata.
Found under `kwargs["standard_logging_object"]`. This is a standard payload, logged for every response.
```python
+
class StandardLoggingPayload(TypedDict):
id: str
+ trace_id: str # Trace multiple LLM calls belonging to the same overall request (e.g. fallbacks/retries)
call_type: str
response_cost: float
+ response_cost_failure_debug_info: Optional[
+ StandardLoggingModelCostFailureDebugInformation
+ ]
+ status: StandardLoggingPayloadStatus
total_tokens: int
prompt_tokens: int
completion_tokens: int
@@ -84,13 +90,13 @@ class StandardLoggingPayload(TypedDict):
metadata: StandardLoggingMetadata
cache_hit: Optional[bool]
cache_key: Optional[str]
- saved_cache_cost: Optional[float]
- request_tags: list
+ saved_cache_cost: float
+ request_tags: list
end_user: Optional[str]
- requester_ip_address: Optional[str] # IP address of requester
- requester_metadata: Optional[dict] # metadata passed in request in the "metadata" field
+ requester_ip_address: Optional[str]
messages: Optional[Union[str, list, dict]]
response: Optional[Union[str, list, dict]]
+ error_str: Optional[str]
model_parameters: dict
hidden_params: StandardLoggingHiddenParams
@@ -99,12 +105,47 @@ class StandardLoggingHiddenParams(TypedDict):
cache_key: Optional[str]
api_base: Optional[str]
response_cost: Optional[str]
- additional_headers: Optional[dict]
+ additional_headers: Optional[StandardLoggingAdditionalHeaders]
+class StandardLoggingAdditionalHeaders(TypedDict, total=False):
+ x_ratelimit_limit_requests: int
+ x_ratelimit_limit_tokens: int
+ x_ratelimit_remaining_requests: int
+ x_ratelimit_remaining_tokens: int
+
+class StandardLoggingMetadata(StandardLoggingUserAPIKeyMetadata):
+ """
+ Specific metadata k,v pairs logged to integration for easier cost tracking
+ """
+
+ spend_logs_metadata: Optional[
+ dict
+ ] # special param to log k,v pairs to spendlogs for a call
+ requester_ip_address: Optional[str]
+ requester_metadata: Optional[dict]
class StandardLoggingModelInformation(TypedDict):
model_map_key: str
model_map_value: Optional[ModelInfo]
+
+
+StandardLoggingPayloadStatus = Literal["success", "failure"]
+
+class StandardLoggingModelCostFailureDebugInformation(TypedDict, total=False):
+ """
+ Debug information, if cost tracking fails.
+
+ Avoid logging sensitive information like response or optional params
+ """
+
+ error_str: Required[str]
+ traceback_str: Required[str]
+ model: str
+ cache_hit: Optional[bool]
+ custom_llm_provider: Optional[str]
+ base_model: Optional[str]
+ call_type: str
+ custom_pricing: Optional[bool]
```
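+
+For example, a custom callback can read this payload off `kwargs` (a minimal sketch, assuming a `CustomLogger` subclass registered as a callback):
+
+```python
+from litellm.integrations.custom_logger import CustomLogger
+
+class StandardPayloadLogger(CustomLogger):
+    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
+        # a sketch: the standard payload is attached to every request's kwargs
+        payload = kwargs.get("standard_logging_object") or {}
+        print(payload.get("id"), payload.get("status"), payload.get("response_cost"))
+```
+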
## Langfuse
diff --git a/docs/my-website/docs/proxy/prod.md b/docs/my-website/docs/proxy/prod.md
index 66c719e5d..32a6fceee 100644
--- a/docs/my-website/docs/proxy/prod.md
+++ b/docs/my-website/docs/proxy/prod.md
@@ -1,5 +1,6 @@
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
+import Image from '@theme/IdealImage';
# âš¡ Best Practices for Production
@@ -112,7 +113,35 @@ general_settings:
disable_spend_logs: True
```
-## 7. Set LiteLLM Salt Key
+## 7. Use Helm PreSync Hook for Database Migrations [BETA]
+
+To ensure only one service manages database migrations, use our [Helm PreSync hook for Database Migrations](https://github.com/BerriAI/litellm/blob/main/deploy/charts/litellm-helm/templates/migrations-job.yaml). This ensures migrations are handled during `helm upgrade` or `helm install`, while LiteLLM pods explicitly disable migrations.
+
+
+1. **Helm PreSync Hook**:
+ - The Helm PreSync hook is configured in the chart to run database migrations during deployments.
+ - The hook always sets `DISABLE_SCHEMA_UPDATE=false`, ensuring migrations are executed reliably.
+
+ Reference settings to set in ArgoCD for `values.yaml`:
+
+ ```yaml
+ db:
+ useExisting: true # use existing Postgres DB
+ url: postgresql://ishaanjaffer0324:3rnwpOBau6hT@ep-withered-mud-a5dkdpke.us-east-2.aws.neon.tech/test-argo-cd?sslmode=require # url of existing Postgres DB
+ ```
+
+2. **LiteLLM Pods**:
+ - Set `DISABLE_SCHEMA_UPDATE=true` in LiteLLM pod configurations to prevent them from running migrations.
+
+ Example configuration for LiteLLM pod:
+ ```yaml
+ env:
+ - name: DISABLE_SCHEMA_UPDATE
+ value: "true"
+ ```
+
+
+## 8. Set LiteLLM Salt Key
If you plan on using the DB, set a salt key for encrypting/decrypting variables in the DB.
diff --git a/docs/my-website/docs/proxy/reliability.md b/docs/my-website/docs/proxy/reliability.md
index 9a3ba4ec6..73f25f817 100644
--- a/docs/my-website/docs/proxy/reliability.md
+++ b/docs/my-website/docs/proxy/reliability.md
@@ -748,4 +748,19 @@ curl -L -X POST 'http://0.0.0.0:4000/v1/chat/completions' \
"max_tokens": 300,
"mock_testing_fallbacks": true
}'
+```
+
+### Disable Fallbacks per key
+
+You can disable fallbacks per key by setting `disable_fallbacks: true` in your key metadata.
+
+```bash
+curl -L -X POST 'http://0.0.0.0:4000/key/generate' \
+-H 'Authorization: Bearer sk-1234' \
+-H 'Content-Type: application/json' \
+-d '{
+ "metadata": {
+ "disable_fallbacks": true
+ }
+}'
```
\ No newline at end of file
diff --git a/docs/my-website/docs/rerank.md b/docs/my-website/docs/rerank.md
index 8179e6b81..d25b552fb 100644
--- a/docs/my-website/docs/rerank.md
+++ b/docs/my-website/docs/rerank.md
@@ -113,4 +113,5 @@ curl http://0.0.0.0:4000/rerank \
|-------------|--------------------|
| Cohere | [Usage](#quick-start) |
| Together AI| [Usage](../docs/providers/togetherai) |
-| Azure AI| [Usage](../docs/providers/azure_ai) |
\ No newline at end of file
+| Azure AI| [Usage](../docs/providers/azure_ai) |
+| Jina AI| [Usage](../docs/providers/jina_ai) |
\ No newline at end of file
diff --git a/docs/my-website/docs/secret.md b/docs/my-website/docs/secret.md
index db5ec6910..15480ea3d 100644
--- a/docs/my-website/docs/secret.md
+++ b/docs/my-website/docs/secret.md
@@ -1,3 +1,6 @@
+import Tabs from '@theme/Tabs';
+import TabItem from '@theme/TabItem';
+
# Secret Manager
LiteLLM supports reading secrets from Azure Key Vault, Google Secret Manager
@@ -59,14 +62,35 @@ os.environ["AWS_REGION_NAME"] = "" # us-east-1, us-east-2, us-west-1, us-west-2
```
2. Enable AWS Secret Manager in config.
+
+
+
+
```yaml
general_settings:
master_key: os.environ/litellm_master_key
key_management_system: "aws_secret_manager" # 👈 KEY CHANGE
key_management_settings:
hosted_keys: ["litellm_master_key"] # 👈 Specify which env keys you stored on AWS
+
```
+
+
+
+
+This will only store virtual keys in AWS Secret Manager. No keys will be read from AWS Secret Manager.
+
+```yaml
+general_settings:
+ key_management_system: "aws_secret_manager" # 👈 KEY CHANGE
+ key_management_settings:
+ store_virtual_keys: true
+ access_mode: "write_only" # Literal["read_only", "write_only", "read_and_write"]
+```
+
+
+
3. Run proxy
```bash
@@ -181,16 +205,14 @@ litellm --config /path/to/config.yaml
Use encrypted keys from Google KMS on the proxy
-### Usage with LiteLLM Proxy Server
-
-## Step 1. Add keys to env
+Step 1. Add keys to env
```
export GOOGLE_APPLICATION_CREDENTIALS="/path/to/credentials.json"
export GOOGLE_KMS_RESOURCE_NAME="projects/*/locations/*/keyRings/*/cryptoKeys/*"
export PROXY_DATABASE_URL_ENCRYPTED=b'\n$\x00D\xac\xb4/\x8e\xc...'
```
-## Step 2: Update Config
+Step 2: Update Config
```yaml
general_settings:
@@ -199,7 +221,7 @@ general_settings:
master_key: sk-1234
```
-## Step 3: Start + test proxy
+Step 3: Start + test proxy
```
$ litellm --config /path/to/config.yaml
@@ -215,3 +237,17 @@ $ litellm --test
+
+
+## All Secret Manager Settings
+
+All settings related to secret management
+
+```yaml
+general_settings:
+ key_management_system: "aws_secret_manager" # REQUIRED
+ key_management_settings:
+ store_virtual_keys: true # OPTIONAL. Defaults to False, when True will store virtual keys in secret manager
+ access_mode: "write_only" # OPTIONAL. Literal["read_only", "write_only", "read_and_write"]. Defaults to "read_only"
+ hosted_keys: ["litellm_master_key"] # OPTIONAL. Specify which env keys you stored on AWS
+```
\ No newline at end of file
diff --git a/litellm/__init__.py b/litellm/__init__.py
index e54117e11..edfe1a336 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -305,7 +305,7 @@ secret_manager_client: Optional[Any] = (
)
_google_kms_resource_name: Optional[str] = None
_key_management_system: Optional[KeyManagementSystem] = None
-_key_management_settings: Optional[KeyManagementSettings] = None
+_key_management_settings: KeyManagementSettings = KeyManagementSettings()
#### PII MASKING ####
output_parse_pii: bool = False
#############################################
@@ -962,6 +962,8 @@ from .utils import (
supports_response_schema,
supports_parallel_function_calling,
supports_vision,
+ supports_audio_input,
+ supports_audio_output,
supports_system_messages,
get_litellm_params,
acreate,
diff --git a/litellm/cost_calculator.py b/litellm/cost_calculator.py
index 0aa8a8e36..50bed6fe9 100644
--- a/litellm/cost_calculator.py
+++ b/litellm/cost_calculator.py
@@ -46,6 +46,9 @@ from litellm.llms.OpenAI.cost_calculation import (
from litellm.llms.OpenAI.cost_calculation import cost_per_token as openai_cost_per_token
from litellm.llms.OpenAI.cost_calculation import cost_router as openai_cost_router
from litellm.llms.together_ai.cost_calculator import get_model_params_and_category
+from litellm.llms.vertex_ai_and_google_ai_studio.image_generation.cost_calculator import (
+ cost_calculator as vertex_ai_image_cost_calculator,
+)
from litellm.types.llms.openai import HttpxBinaryResponseContent
from litellm.types.rerank import RerankResponse
from litellm.types.router import SPECIAL_MODEL_INFO_PARAMS
@@ -667,9 +670,11 @@ def completion_cost( # noqa: PLR0915
):
### IMAGE GENERATION COST CALCULATION ###
if custom_llm_provider == "vertex_ai":
- # https://cloud.google.com/vertex-ai/generative-ai/pricing
- # Vertex Charges Flat $0.20 per image
- return 0.020
+ if isinstance(completion_response, ImageResponse):
+ return vertex_ai_image_cost_calculator(
+ model=model,
+ image_response=completion_response,
+ )
elif custom_llm_provider == "bedrock":
if isinstance(completion_response, ImageResponse):
return bedrock_image_cost_calculator(
diff --git a/litellm/litellm_core_utils/exception_mapping_utils.py b/litellm/litellm_core_utils/exception_mapping_utils.py
index ca1de75be..3fb276611 100644
--- a/litellm/litellm_core_utils/exception_mapping_utils.py
+++ b/litellm/litellm_core_utils/exception_mapping_utils.py
@@ -239,7 +239,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"ContextWindowExceededError: {exception_provider} - {message}",
llm_provider=custom_llm_provider,
model=model,
- response=original_exception.response,
+ response=getattr(original_exception, "response", None),
litellm_debug_info=extra_information,
)
elif (
@@ -251,7 +251,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"{exception_provider} - {message}",
llm_provider=custom_llm_provider,
model=model,
- response=original_exception.response,
+ response=getattr(original_exception, "response", None),
litellm_debug_info=extra_information,
)
elif "A timeout occurred" in error_str:
@@ -271,7 +271,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"ContentPolicyViolationError: {exception_provider} - {message}",
llm_provider=custom_llm_provider,
model=model,
- response=original_exception.response,
+ response=getattr(original_exception, "response", None),
litellm_debug_info=extra_information,
)
elif (
@@ -283,7 +283,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"{exception_provider} - {message}",
llm_provider=custom_llm_provider,
model=model,
- response=original_exception.response,
+ response=getattr(original_exception, "response", None),
litellm_debug_info=extra_information,
)
elif "Web server is returning an unknown error" in error_str:
@@ -299,7 +299,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"RateLimitError: {exception_provider} - {message}",
model=model,
llm_provider=custom_llm_provider,
- response=original_exception.response,
+ response=getattr(original_exception, "response", None),
litellm_debug_info=extra_information,
)
elif (
@@ -311,7 +311,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"AuthenticationError: {exception_provider} - {message}",
llm_provider=custom_llm_provider,
model=model,
- response=original_exception.response,
+ response=getattr(original_exception, "response", None),
litellm_debug_info=extra_information,
)
elif "Mistral API raised a streaming error" in error_str:
@@ -335,7 +335,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"{exception_provider} - {message}",
llm_provider=custom_llm_provider,
model=model,
- response=original_exception.response,
+ response=getattr(original_exception, "response", None),
litellm_debug_info=extra_information,
)
elif original_exception.status_code == 401:
@@ -344,7 +344,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"AuthenticationError: {exception_provider} - {message}",
llm_provider=custom_llm_provider,
model=model,
- response=original_exception.response,
+ response=getattr(original_exception, "response", None),
litellm_debug_info=extra_information,
)
elif original_exception.status_code == 404:
@@ -353,7 +353,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"NotFoundError: {exception_provider} - {message}",
model=model,
llm_provider=custom_llm_provider,
- response=original_exception.response,
+ response=getattr(original_exception, "response", None),
litellm_debug_info=extra_information,
)
elif original_exception.status_code == 408:
@@ -516,7 +516,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"ReplicateException - {error_str}",
llm_provider="replicate",
model=model,
- response=original_exception.response,
+ response=getattr(original_exception, "response", None),
)
elif "input is too long" in error_str:
exception_mapping_worked = True
@@ -524,7 +524,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"ReplicateException - {error_str}",
model=model,
llm_provider="replicate",
- response=original_exception.response,
+ response=getattr(original_exception, "response", None),
)
elif exception_type == "ModelError":
exception_mapping_worked = True
@@ -532,7 +532,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"ReplicateException - {error_str}",
model=model,
llm_provider="replicate",
- response=original_exception.response,
+ response=getattr(original_exception, "response", None),
)
elif "Request was throttled" in error_str:
exception_mapping_worked = True
@@ -540,7 +540,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"ReplicateException - {error_str}",
llm_provider="replicate",
model=model,
- response=original_exception.response,
+ response=getattr(original_exception, "response", None),
)
elif hasattr(original_exception, "status_code"):
if original_exception.status_code == 401:
@@ -549,7 +549,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"ReplicateException - {original_exception.message}",
llm_provider="replicate",
model=model,
- response=original_exception.response,
+ response=getattr(original_exception, "response", None),
)
elif (
original_exception.status_code == 400
@@ -560,7 +560,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"ReplicateException - {original_exception.message}",
model=model,
llm_provider="replicate",
- response=original_exception.response,
+ response=getattr(original_exception, "response", None),
)
elif original_exception.status_code == 422:
exception_mapping_worked = True
@@ -568,7 +568,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"ReplicateException - {original_exception.message}",
model=model,
llm_provider="replicate",
- response=original_exception.response,
+ response=getattr(original_exception, "response", None),
)
elif original_exception.status_code == 408:
exception_mapping_worked = True
@@ -583,7 +583,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"ReplicateException - {original_exception.message}",
llm_provider="replicate",
model=model,
- response=original_exception.response,
+ response=getattr(original_exception, "response", None),
)
elif original_exception.status_code == 429:
exception_mapping_worked = True
@@ -591,7 +591,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"ReplicateException - {original_exception.message}",
llm_provider="replicate",
model=model,
- response=original_exception.response,
+ response=getattr(original_exception, "response", None),
)
elif original_exception.status_code == 500:
exception_mapping_worked = True
@@ -599,7 +599,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"ReplicateException - {original_exception.message}",
llm_provider="replicate",
model=model,
- response=original_exception.response,
+ response=getattr(original_exception, "response", None),
)
exception_mapping_worked = True
raise APIError(
@@ -631,7 +631,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"{custom_llm_provider}Exception: Authentication Error - {error_str}",
llm_provider=custom_llm_provider,
model=model,
- response=original_exception.response,
+ response=getattr(original_exception, "response", None),
litellm_debug_info=extra_information,
)
elif "token_quota_reached" in error_str:
@@ -640,7 +640,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"{custom_llm_provider}Exception: Rate Limit Errror - {error_str}",
llm_provider=custom_llm_provider,
model=model,
- response=original_exception.response,
+ response=getattr(original_exception, "response", None),
)
elif (
"The server received an invalid response from an upstream server."
@@ -750,7 +750,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"BedrockException - {error_str}\n. Enable 'litellm.modify_params=True' (for PROXY do: `litellm_settings::modify_params: True`) to insert a dummy assistant message and fix this error.",
model=model,
llm_provider="bedrock",
- response=original_exception.response,
+ response=getattr(original_exception, "response", None),
)
elif "Malformed input request" in error_str:
exception_mapping_worked = True
@@ -758,7 +758,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"BedrockException - {error_str}",
model=model,
llm_provider="bedrock",
- response=original_exception.response,
+ response=getattr(original_exception, "response", None),
)
elif "A conversation must start with a user message." in error_str:
exception_mapping_worked = True
@@ -766,7 +766,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"BedrockException - {error_str}\n. Pass in default user message via `completion(..,user_continue_message=)` or enable `litellm.modify_params=True`.\nFor Proxy: do via `litellm_settings::modify_params: True` or user_continue_message under `litellm_params`",
model=model,
llm_provider="bedrock",
- response=original_exception.response,
+ response=getattr(original_exception, "response", None),
)
elif (
"Unable to locate credentials" in error_str
@@ -778,7 +778,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"BedrockException Invalid Authentication - {error_str}",
model=model,
llm_provider="bedrock",
- response=original_exception.response,
+ response=getattr(original_exception, "response", None),
)
elif "AccessDeniedException" in error_str:
exception_mapping_worked = True
@@ -786,7 +786,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"BedrockException PermissionDeniedError - {error_str}",
model=model,
llm_provider="bedrock",
- response=original_exception.response,
+ response=getattr(original_exception, "response", None),
)
elif (
"throttlingException" in error_str
@@ -797,7 +797,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"BedrockException: Rate Limit Error - {error_str}",
model=model,
llm_provider="bedrock",
- response=original_exception.response,
+ response=getattr(original_exception, "response", None),
)
elif (
"Connect timeout on endpoint URL" in error_str
@@ -836,7 +836,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"BedrockException - {original_exception.message}",
llm_provider="bedrock",
model=model,
- response=original_exception.response,
+ response=getattr(original_exception, "response", None),
)
elif original_exception.status_code == 400:
exception_mapping_worked = True
@@ -844,7 +844,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"BedrockException - {original_exception.message}",
llm_provider="bedrock",
model=model,
- response=original_exception.response,
+ response=getattr(original_exception, "response", None),
)
elif original_exception.status_code == 404:
exception_mapping_worked = True
@@ -852,7 +852,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"BedrockException - {original_exception.message}",
llm_provider="bedrock",
model=model,
- response=original_exception.response,
+ response=getattr(original_exception, "response", None),
)
elif original_exception.status_code == 408:
exception_mapping_worked = True
@@ -868,7 +868,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"BedrockException - {original_exception.message}",
model=model,
llm_provider=custom_llm_provider,
- response=original_exception.response,
+ response=getattr(original_exception, "response", None),
litellm_debug_info=extra_information,
)
elif original_exception.status_code == 429:
@@ -877,7 +877,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"BedrockException - {original_exception.message}",
model=model,
llm_provider=custom_llm_provider,
- response=original_exception.response,
+ response=getattr(original_exception, "response", None),
litellm_debug_info=extra_information,
)
elif original_exception.status_code == 503:
@@ -886,7 +886,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"BedrockException - {original_exception.message}",
model=model,
llm_provider=custom_llm_provider,
- response=original_exception.response,
+ response=getattr(original_exception, "response", None),
litellm_debug_info=extra_information,
)
elif original_exception.status_code == 504: # gateway timeout error
@@ -907,7 +907,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"litellm.BadRequestError: SagemakerException - {error_str}",
model=model,
llm_provider="sagemaker",
- response=original_exception.response,
+ response=getattr(original_exception, "response", None),
)
elif (
"Input validation error: `best_of` must be > 0 and <= 2"
@@ -918,7 +918,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message="SagemakerException - the value of 'n' must be > 0 and <= 2 for sagemaker endpoints",
model=model,
llm_provider="sagemaker",
- response=original_exception.response,
+ response=getattr(original_exception, "response", None),
)
elif (
"`inputs` tokens + `max_new_tokens` must be <=" in error_str
@@ -929,7 +929,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"SagemakerException - {error_str}",
model=model,
llm_provider="sagemaker",
- response=original_exception.response,
+ response=getattr(original_exception, "response", None),
)
elif hasattr(original_exception, "status_code"):
if original_exception.status_code == 500:
@@ -951,7 +951,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"SagemakerException - {original_exception.message}",
llm_provider=custom_llm_provider,
model=model,
- response=original_exception.response,
+ response=getattr(original_exception, "response", None),
)
elif original_exception.status_code == 400:
exception_mapping_worked = True
@@ -959,7 +959,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"SagemakerException - {original_exception.message}",
llm_provider=custom_llm_provider,
model=model,
- response=original_exception.response,
+ response=getattr(original_exception, "response", None),
)
elif original_exception.status_code == 404:
exception_mapping_worked = True
@@ -967,7 +967,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"SagemakerException - {original_exception.message}",
llm_provider=custom_llm_provider,
model=model,
- response=original_exception.response,
+ response=getattr(original_exception, "response", None),
)
elif original_exception.status_code == 408:
exception_mapping_worked = True
@@ -986,7 +986,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"SagemakerException - {original_exception.message}",
model=model,
llm_provider=custom_llm_provider,
- response=original_exception.response,
+ response=getattr(original_exception, "response", None),
litellm_debug_info=extra_information,
)
elif original_exception.status_code == 429:
@@ -995,7 +995,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"SagemakerException - {original_exception.message}",
model=model,
llm_provider=custom_llm_provider,
- response=original_exception.response,
+ response=getattr(original_exception, "response", None),
litellm_debug_info=extra_information,
)
elif original_exception.status_code == 503:
@@ -1004,7 +1004,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"SagemakerException - {original_exception.message}",
model=model,
llm_provider=custom_llm_provider,
- response=original_exception.response,
+ response=getattr(original_exception, "response", None),
litellm_debug_info=extra_information,
)
elif original_exception.status_code == 504: # gateway timeout error
@@ -1217,7 +1217,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message="GeminiException - Invalid api key",
model=model,
llm_provider="palm",
- response=original_exception.response,
+ response=getattr(original_exception, "response", None),
)
if (
"504 Deadline expired before operation could complete." in error_str
@@ -1235,7 +1235,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"GeminiException - {error_str}",
model=model,
llm_provider="palm",
- response=original_exception.response,
+ response=getattr(original_exception, "response", None),
)
if (
"500 An internal error has occurred." in error_str
@@ -1262,7 +1262,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"GeminiException - {error_str}",
model=model,
llm_provider="palm",
- response=original_exception.response,
+ response=getattr(original_exception, "response", None),
)
# Dailed: Error occurred: 400 Request payload size exceeds the limit: 20000 bytes
elif custom_llm_provider == "cloudflare":
@@ -1272,7 +1272,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"Cloudflare Exception - {original_exception.message}",
llm_provider="cloudflare",
model=model,
- response=original_exception.response,
+ response=getattr(original_exception, "response", None),
)
if "must have required property" in error_str:
exception_mapping_worked = True
@@ -1280,7 +1280,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"Cloudflare Exception - {original_exception.message}",
llm_provider="cloudflare",
model=model,
- response=original_exception.response,
+ response=getattr(original_exception, "response", None),
)
elif (
custom_llm_provider == "cohere" or custom_llm_provider == "cohere_chat"
@@ -1294,7 +1294,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"CohereException - {original_exception.message}",
llm_provider="cohere",
model=model,
- response=original_exception.response,
+ response=getattr(original_exception, "response", None),
)
elif "too many tokens" in error_str:
exception_mapping_worked = True
@@ -1302,7 +1302,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"CohereException - {original_exception.message}",
model=model,
llm_provider="cohere",
- response=original_exception.response,
+ response=getattr(original_exception, "response", None),
)
elif hasattr(original_exception, "status_code"):
if (
@@ -1314,7 +1314,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"CohereException - {original_exception.message}",
llm_provider="cohere",
model=model,
- response=original_exception.response,
+ response=getattr(original_exception, "response", None),
)
elif original_exception.status_code == 408:
exception_mapping_worked = True
@@ -1329,7 +1329,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"CohereException - {original_exception.message}",
llm_provider="cohere",
model=model,
- response=original_exception.response,
+ response=getattr(original_exception, "response", None),
)
elif (
"CohereConnectionError" in exception_type
@@ -1339,7 +1339,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"CohereException - {original_exception.message}",
llm_provider="cohere",
model=model,
- response=original_exception.response,
+ response=getattr(original_exception, "response", None),
)
elif "invalid type:" in error_str:
exception_mapping_worked = True
@@ -1347,7 +1347,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"CohereException - {original_exception.message}",
llm_provider="cohere",
model=model,
- response=original_exception.response,
+ response=getattr(original_exception, "response", None),
)
elif "Unexpected server error" in error_str:
exception_mapping_worked = True
@@ -1355,7 +1355,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"CohereException - {original_exception.message}",
llm_provider="cohere",
model=model,
- response=original_exception.response,
+ response=getattr(original_exception, "response", None),
)
else:
if hasattr(original_exception, "status_code"):
@@ -1375,7 +1375,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=error_str,
model=model,
llm_provider="huggingface",
- response=original_exception.response,
+ response=getattr(original_exception, "response", None),
)
elif "A valid user token is required" in error_str:
exception_mapping_worked = True
@@ -1383,7 +1383,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=error_str,
llm_provider="huggingface",
model=model,
- response=original_exception.response,
+ response=getattr(original_exception, "response", None),
)
elif "Rate limit reached" in error_str:
exception_mapping_worked = True
@@ -1391,7 +1391,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=error_str,
llm_provider="huggingface",
model=model,
- response=original_exception.response,
+ response=getattr(original_exception, "response", None),
)
if hasattr(original_exception, "status_code"):
if original_exception.status_code == 401:
@@ -1400,7 +1400,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"HuggingfaceException - {original_exception.message}",
llm_provider="huggingface",
model=model,
- response=original_exception.response,
+ response=getattr(original_exception, "response", None),
)
elif original_exception.status_code == 400:
exception_mapping_worked = True
@@ -1408,7 +1408,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"HuggingfaceException - {original_exception.message}",
model=model,
llm_provider="huggingface",
- response=original_exception.response,
+ response=getattr(original_exception, "response", None),
)
elif original_exception.status_code == 408:
exception_mapping_worked = True
@@ -1423,7 +1423,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"HuggingfaceException - {original_exception.message}",
llm_provider="huggingface",
model=model,
- response=original_exception.response,
+ response=getattr(original_exception, "response", None),
)
elif original_exception.status_code == 503:
exception_mapping_worked = True
@@ -1431,7 +1431,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"HuggingfaceException - {original_exception.message}",
llm_provider="huggingface",
model=model,
- response=original_exception.response,
+ response=getattr(original_exception, "response", None),
)
else:
exception_mapping_worked = True
@@ -1450,7 +1450,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"AI21Exception - {original_exception.message}",
model=model,
llm_provider="ai21",
- response=original_exception.response,
+ response=getattr(original_exception, "response", None),
)
if "Bad or missing API token." in original_exception.message:
exception_mapping_worked = True
@@ -1458,7 +1458,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"AI21Exception - {original_exception.message}",
model=model,
llm_provider="ai21",
- response=original_exception.response,
+ response=getattr(original_exception, "response", None),
)
if hasattr(original_exception, "status_code"):
if original_exception.status_code == 401:
@@ -1467,7 +1467,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"AI21Exception - {original_exception.message}",
llm_provider="ai21",
model=model,
- response=original_exception.response,
+ response=getattr(original_exception, "response", None),
)
elif original_exception.status_code == 408:
exception_mapping_worked = True
@@ -1482,7 +1482,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"AI21Exception - {original_exception.message}",
model=model,
llm_provider="ai21",
- response=original_exception.response,
+ response=getattr(original_exception, "response", None),
)
elif original_exception.status_code == 429:
exception_mapping_worked = True
@@ -1490,7 +1490,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"AI21Exception - {original_exception.message}",
llm_provider="ai21",
model=model,
- response=original_exception.response,
+ response=getattr(original_exception, "response", None),
)
else:
exception_mapping_worked = True
@@ -1509,7 +1509,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"NLPCloudException - {error_str}",
model=model,
llm_provider="nlp_cloud",
- response=original_exception.response,
+ response=getattr(original_exception, "response", None),
)
elif "value is not a valid" in error_str:
exception_mapping_worked = True
@@ -1517,7 +1517,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"NLPCloudException - {error_str}",
model=model,
llm_provider="nlp_cloud",
- response=original_exception.response,
+ response=getattr(original_exception, "response", None),
)
else:
exception_mapping_worked = True
@@ -1542,7 +1542,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"NLPCloudException - {original_exception.message}",
llm_provider="nlp_cloud",
model=model,
- response=original_exception.response,
+ response=getattr(original_exception, "response", None),
)
elif (
original_exception.status_code == 401
@@ -1553,7 +1553,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"NLPCloudException - {original_exception.message}",
llm_provider="nlp_cloud",
model=model,
- response=original_exception.response,
+ response=getattr(original_exception, "response", None),
)
elif (
original_exception.status_code == 522
@@ -1574,7 +1574,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"NLPCloudException - {original_exception.message}",
llm_provider="nlp_cloud",
model=model,
- response=original_exception.response,
+ response=getattr(original_exception, "response", None),
)
elif (
original_exception.status_code == 500
@@ -1597,7 +1597,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"NLPCloudException - {original_exception.message}",
model=model,
llm_provider="nlp_cloud",
- response=original_exception.response,
+ response=getattr(original_exception, "response", None),
)
else:
exception_mapping_worked = True
@@ -1623,7 +1623,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"TogetherAIException - {error_response['error']}",
model=model,
llm_provider="together_ai",
- response=original_exception.response,
+ response=getattr(original_exception, "response", None),
)
elif (
"error" in error_response
@@ -1634,7 +1634,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"TogetherAIException - {error_response['error']}",
llm_provider="together_ai",
model=model,
- response=original_exception.response,
+ response=getattr(original_exception, "response", None),
)
elif (
"error" in error_response
@@ -1645,7 +1645,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"TogetherAIException - {error_response['error']}",
model=model,
llm_provider="together_ai",
- response=original_exception.response,
+ response=getattr(original_exception, "response", None),
)
elif "A timeout occurred" in error_str:
exception_mapping_worked = True
@@ -1664,7 +1664,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"TogetherAIException - {error_response['error']}",
model=model,
llm_provider="together_ai",
- response=original_exception.response,
+ response=getattr(original_exception, "response", None),
)
elif (
"error_type" in error_response
@@ -1675,7 +1675,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"TogetherAIException - {error_response['error']}",
model=model,
llm_provider="together_ai",
- response=original_exception.response,
+ response=getattr(original_exception, "response", None),
)
if hasattr(original_exception, "status_code"):
if original_exception.status_code == 408:
@@ -1691,7 +1691,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"TogetherAIException - {error_response['error']}",
model=model,
llm_provider="together_ai",
- response=original_exception.response,
+ response=getattr(original_exception, "response", None),
)
elif original_exception.status_code == 429:
exception_mapping_worked = True
@@ -1699,7 +1699,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"TogetherAIException - {original_exception.message}",
llm_provider="together_ai",
model=model,
- response=original_exception.response,
+ response=getattr(original_exception, "response", None),
)
elif original_exception.status_code == 524:
exception_mapping_worked = True
@@ -1727,7 +1727,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"AlephAlphaException - {original_exception.message}",
llm_provider="aleph_alpha",
model=model,
- response=original_exception.response,
+ response=getattr(original_exception, "response", None),
)
elif "InvalidToken" in error_str or "No token provided" in error_str:
exception_mapping_worked = True
@@ -1735,7 +1735,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"AlephAlphaException - {original_exception.message}",
llm_provider="aleph_alpha",
model=model,
- response=original_exception.response,
+ response=getattr(original_exception, "response", None),
)
elif hasattr(original_exception, "status_code"):
verbose_logger.debug(
@@ -1754,7 +1754,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"AlephAlphaException - {original_exception.message}",
llm_provider="aleph_alpha",
model=model,
- response=original_exception.response,
+ response=getattr(original_exception, "response", None),
)
elif original_exception.status_code == 429:
exception_mapping_worked = True
@@ -1762,7 +1762,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"AlephAlphaException - {original_exception.message}",
llm_provider="aleph_alpha",
model=model,
- response=original_exception.response,
+ response=getattr(original_exception, "response", None),
)
elif original_exception.status_code == 500:
exception_mapping_worked = True
@@ -1770,7 +1770,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"AlephAlphaException - {original_exception.message}",
llm_provider="aleph_alpha",
model=model,
- response=original_exception.response,
+ response=getattr(original_exception, "response", None),
)
raise original_exception
raise original_exception
@@ -1787,7 +1787,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"OllamaException: Invalid Model/Model not loaded - {original_exception}",
model=model,
llm_provider="ollama",
- response=original_exception.response,
+ response=getattr(original_exception, "response", None),
)
elif "Failed to establish a new connection" in error_str:
exception_mapping_worked = True
@@ -1795,7 +1795,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"OllamaException: {original_exception}",
llm_provider="ollama",
model=model,
- response=original_exception.response,
+ response=getattr(original_exception, "response", None),
)
elif "Invalid response object from API" in error_str:
exception_mapping_worked = True
@@ -1803,7 +1803,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"OllamaException: {original_exception}",
llm_provider="ollama",
model=model,
- response=original_exception.response,
+ response=getattr(original_exception, "response", None),
)
elif "Read timed out" in error_str:
exception_mapping_worked = True
@@ -1837,6 +1837,7 @@ def exception_type( # type: ignore # noqa: PLR0915
llm_provider="azure",
model=model,
litellm_debug_info=extra_information,
+ response=getattr(original_exception, "response", None),
)
elif "This model's maximum context length is" in error_str:
exception_mapping_worked = True
@@ -1845,6 +1846,7 @@ def exception_type( # type: ignore # noqa: PLR0915
llm_provider="azure",
model=model,
litellm_debug_info=extra_information,
+ response=getattr(original_exception, "response", None),
)
elif "DeploymentNotFound" in error_str:
exception_mapping_worked = True
@@ -1853,6 +1855,7 @@ def exception_type( # type: ignore # noqa: PLR0915
llm_provider="azure",
model=model,
litellm_debug_info=extra_information,
+ response=getattr(original_exception, "response", None),
)
elif (
(
@@ -1873,6 +1876,7 @@ def exception_type( # type: ignore # noqa: PLR0915
llm_provider="azure",
model=model,
litellm_debug_info=extra_information,
+ response=getattr(original_exception, "response", None),
)
elif "invalid_request_error" in error_str:
exception_mapping_worked = True
@@ -1881,6 +1885,7 @@ def exception_type( # type: ignore # noqa: PLR0915
llm_provider="azure",
model=model,
litellm_debug_info=extra_information,
+ response=getattr(original_exception, "response", None),
)
elif (
"The api_key client option must be set either by passing api_key to the client or by setting"
@@ -1892,6 +1897,7 @@ def exception_type( # type: ignore # noqa: PLR0915
llm_provider=custom_llm_provider,
model=model,
litellm_debug_info=extra_information,
+ response=getattr(original_exception, "response", None),
)
elif "Connection error" in error_str:
exception_mapping_worked = True
@@ -1910,6 +1916,7 @@ def exception_type( # type: ignore # noqa: PLR0915
llm_provider="azure",
model=model,
litellm_debug_info=extra_information,
+ response=getattr(original_exception, "response", None),
)
elif original_exception.status_code == 401:
exception_mapping_worked = True
@@ -1918,6 +1925,7 @@ def exception_type( # type: ignore # noqa: PLR0915
llm_provider="azure",
model=model,
litellm_debug_info=extra_information,
+ response=getattr(original_exception, "response", None),
)
elif original_exception.status_code == 408:
exception_mapping_worked = True
@@ -1934,6 +1942,7 @@ def exception_type( # type: ignore # noqa: PLR0915
model=model,
llm_provider="azure",
litellm_debug_info=extra_information,
+ response=getattr(original_exception, "response", None),
)
elif original_exception.status_code == 429:
exception_mapping_worked = True
@@ -1942,6 +1951,7 @@ def exception_type( # type: ignore # noqa: PLR0915
model=model,
llm_provider="azure",
litellm_debug_info=extra_information,
+ response=getattr(original_exception, "response", None),
)
elif original_exception.status_code == 503:
exception_mapping_worked = True
@@ -1950,6 +1960,7 @@ def exception_type( # type: ignore # noqa: PLR0915
model=model,
llm_provider="azure",
litellm_debug_info=extra_information,
+ response=getattr(original_exception, "response", None),
)
elif original_exception.status_code == 504: # gateway timeout error
exception_mapping_worked = True
@@ -1989,7 +2000,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"{exception_provider} - {error_str}",
llm_provider=custom_llm_provider,
model=model,
- response=original_exception.response,
+ response=getattr(original_exception, "response", None),
litellm_debug_info=extra_information,
)
elif original_exception.status_code == 401:
@@ -1998,7 +2009,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"AuthenticationError: {exception_provider} - {error_str}",
llm_provider=custom_llm_provider,
model=model,
- response=original_exception.response,
+ response=getattr(original_exception, "response", None),
litellm_debug_info=extra_information,
)
elif original_exception.status_code == 404:
@@ -2007,7 +2018,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"NotFoundError: {exception_provider} - {error_str}",
model=model,
llm_provider=custom_llm_provider,
- response=original_exception.response,
+ response=getattr(original_exception, "response", None),
litellm_debug_info=extra_information,
)
elif original_exception.status_code == 408:
@@ -2024,7 +2035,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"BadRequestError: {exception_provider} - {error_str}",
model=model,
llm_provider=custom_llm_provider,
- response=original_exception.response,
+ response=getattr(original_exception, "response", None),
litellm_debug_info=extra_information,
)
elif original_exception.status_code == 429:
@@ -2033,7 +2044,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"RateLimitError: {exception_provider} - {error_str}",
model=model,
llm_provider=custom_llm_provider,
- response=original_exception.response,
+ response=getattr(original_exception, "response", None),
litellm_debug_info=extra_information,
)
elif original_exception.status_code == 503:
@@ -2042,7 +2053,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"ServiceUnavailableError: {exception_provider} - {error_str}",
model=model,
llm_provider=custom_llm_provider,
- response=original_exception.response,
+ response=getattr(original_exception, "response", None),
litellm_debug_info=extra_information,
)
elif original_exception.status_code == 504: # gateway timeout error
diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py
index 66f91abf1..69d6adca4 100644
--- a/litellm/litellm_core_utils/litellm_logging.py
+++ b/litellm/litellm_core_utils/litellm_logging.py
@@ -202,6 +202,7 @@ class Logging:
start_time,
litellm_call_id: str,
function_id: str,
+ litellm_trace_id: Optional[str] = None,
dynamic_input_callbacks: Optional[
List[Union[str, Callable, CustomLogger]]
] = None,
@@ -239,6 +240,7 @@ class Logging:
self.start_time = start_time # log the call start time
self.call_type = call_type
self.litellm_call_id = litellm_call_id
+ self.litellm_trace_id = litellm_trace_id
self.function_id = function_id
self.streaming_chunks: List[Any] = [] # for generating complete stream response
self.sync_streaming_chunks: List[Any] = (
@@ -275,6 +277,11 @@ class Logging:
self.completion_start_time: Optional[datetime.datetime] = None
self._llm_caching_handler: Optional[LLMCachingHandler] = None
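+        # seed model_call_details with the trace/call ids so they are available to callbacks
+        # before update_environment_variables fills in the rest of the fields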
+ self.model_call_details = {
+ "litellm_trace_id": litellm_trace_id,
+ "litellm_call_id": litellm_call_id,
+ }
+
def process_dynamic_callbacks(self):
"""
Initializes CustomLogger compatible callbacks in self.dynamic_* callbacks
@@ -382,21 +389,23 @@ class Logging:
self.logger_fn = litellm_params.get("logger_fn", None)
verbose_logger.debug(f"self.optional_params: {self.optional_params}")
- self.model_call_details = {
- "model": self.model,
- "messages": self.messages,
- "optional_params": self.optional_params,
- "litellm_params": self.litellm_params,
- "start_time": self.start_time,
- "stream": self.stream,
- "user": user,
- "call_type": str(self.call_type),
- "litellm_call_id": self.litellm_call_id,
- "completion_start_time": self.completion_start_time,
- "standard_callback_dynamic_params": self.standard_callback_dynamic_params,
- **self.optional_params,
- **additional_params,
- }
+ self.model_call_details.update(
+ {
+ "model": self.model,
+ "messages": self.messages,
+ "optional_params": self.optional_params,
+ "litellm_params": self.litellm_params,
+ "start_time": self.start_time,
+ "stream": self.stream,
+ "user": user,
+ "call_type": str(self.call_type),
+ "litellm_call_id": self.litellm_call_id,
+ "completion_start_time": self.completion_start_time,
+ "standard_callback_dynamic_params": self.standard_callback_dynamic_params,
+ **self.optional_params,
+ **additional_params,
+ }
+ )
## check if stream options is set ## - used by CustomStreamWrapper for easy instrumentation
if "stream_options" in additional_params:
@@ -2823,6 +2832,7 @@ def get_standard_logging_object_payload(
payload: StandardLoggingPayload = StandardLoggingPayload(
id=str(id),
+ trace_id=kwargs.get("litellm_trace_id"), # type: ignore
call_type=call_type or "",
cache_hit=cache_hit,
status=status,
diff --git a/litellm/llms/anthropic/chat/handler.py b/litellm/llms/anthropic/chat/handler.py
index 2d119a28f..2952d54d5 100644
--- a/litellm/llms/anthropic/chat/handler.py
+++ b/litellm/llms/anthropic/chat/handler.py
@@ -44,7 +44,9 @@ from litellm.types.llms.openai import (
ChatCompletionToolCallFunctionChunk,
ChatCompletionUsageBlock,
)
-from litellm.types.utils import GenericStreamingChunk, PromptTokensDetailsWrapper
+from litellm.types.utils import GenericStreamingChunk
+from litellm.types.utils import Message as LitellmMessage
+from litellm.types.utils import PromptTokensDetailsWrapper
from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
from ...base import BaseLLM
@@ -94,6 +96,7 @@ async def make_call(
messages: list,
logging_obj,
timeout: Optional[Union[float, httpx.Timeout]],
+ json_mode: bool,
) -> Tuple[Any, httpx.Headers]:
if client is None:
client = litellm.module_level_aclient
@@ -119,7 +122,9 @@ async def make_call(
raise AnthropicError(status_code=500, message=str(e))
completion_stream = ModelResponseIterator(
- streaming_response=response.aiter_lines(), sync_stream=False
+ streaming_response=response.aiter_lines(),
+ sync_stream=False,
+ json_mode=json_mode,
)
# LOGGING
@@ -142,6 +147,7 @@ def make_sync_call(
messages: list,
logging_obj,
timeout: Optional[Union[float, httpx.Timeout]],
+ json_mode: bool,
) -> Tuple[Any, httpx.Headers]:
if client is None:
client = litellm.module_level_client # re-use a module level client
@@ -175,7 +181,7 @@ def make_sync_call(
)
completion_stream = ModelResponseIterator(
- streaming_response=response.iter_lines(), sync_stream=True
+ streaming_response=response.iter_lines(), sync_stream=True, json_mode=json_mode
)
# LOGGING
@@ -270,11 +276,12 @@ class AnthropicChatCompletion(BaseLLM):
"arguments"
)
if json_mode_content_str is not None:
- args = json.loads(json_mode_content_str)
- values: Optional[dict] = args.get("values")
- if values is not None:
- _message = litellm.Message(content=json.dumps(values))
+ _converted_message = self._convert_tool_response_to_message(
+ tool_calls=tool_calls,
+ )
+ if _converted_message is not None:
completion_response["stop_reason"] = "stop"
+ _message = _converted_message
model_response.choices[0].message = _message # type: ignore
model_response._hidden_params["original_response"] = completion_response[
"content"
@@ -318,6 +325,37 @@ class AnthropicChatCompletion(BaseLLM):
model_response._hidden_params = _hidden_params
return model_response
+ @staticmethod
+ def _convert_tool_response_to_message(
+ tool_calls: List[ChatCompletionToolCallChunk],
+ ) -> Optional[LitellmMessage]:
+ """
+        In JSON mode, the Anthropic API returns the JSON schema as a tool call; we convert it to a message to follow the OpenAI format.
+
+ """
+ ## HANDLE JSON MODE - anthropic returns single function call
+ json_mode_content_str: Optional[str] = tool_calls[0]["function"].get(
+ "arguments"
+ )
+ try:
+ if json_mode_content_str is not None:
+ args = json.loads(json_mode_content_str)
+ if (
+ isinstance(args, dict)
+ and (values := args.get("values")) is not None
+ ):
+ _message = litellm.Message(content=json.dumps(values))
+ return _message
+ else:
+                # often the `values` key is not present in the tool response
+ # relevant issue: https://github.com/BerriAI/litellm/issues/6741
+ _message = litellm.Message(content=json.dumps(args))
+ return _message
+ except json.JSONDecodeError:
+            # JSON decode errors do occur - return the original tool response string
+ return litellm.Message(content=json_mode_content_str)
+ return None
+
async def acompletion_stream_function(
self,
model: str,
@@ -334,6 +372,7 @@ class AnthropicChatCompletion(BaseLLM):
stream,
_is_function_call,
data: dict,
+ json_mode: bool,
optional_params=None,
litellm_params=None,
logger_fn=None,
@@ -350,6 +389,7 @@ class AnthropicChatCompletion(BaseLLM):
messages=messages,
logging_obj=logging_obj,
timeout=timeout,
+ json_mode=json_mode,
)
streamwrapper = CustomStreamWrapper(
completion_stream=completion_stream,
@@ -440,8 +480,8 @@ class AnthropicChatCompletion(BaseLLM):
logging_obj,
optional_params: dict,
timeout: Union[float, httpx.Timeout],
+ litellm_params: dict,
acompletion=None,
- litellm_params=None,
logger_fn=None,
headers={},
client=None,
@@ -464,6 +504,7 @@ class AnthropicChatCompletion(BaseLLM):
model=model,
messages=messages,
optional_params=optional_params,
+ litellm_params=litellm_params,
headers=headers,
_is_function_call=_is_function_call,
is_vertex_request=is_vertex_request,
@@ -500,6 +541,7 @@ class AnthropicChatCompletion(BaseLLM):
optional_params=optional_params,
stream=stream,
_is_function_call=_is_function_call,
+ json_mode=json_mode,
litellm_params=litellm_params,
logger_fn=logger_fn,
headers=headers,
@@ -547,6 +589,7 @@ class AnthropicChatCompletion(BaseLLM):
messages=messages,
logging_obj=logging_obj,
timeout=timeout,
+ json_mode=json_mode,
)
return CustomStreamWrapper(
completion_stream=completion_stream,
@@ -605,11 +648,14 @@ class AnthropicChatCompletion(BaseLLM):
class ModelResponseIterator:
- def __init__(self, streaming_response, sync_stream: bool):
+ def __init__(
+ self, streaming_response, sync_stream: bool, json_mode: Optional[bool] = False
+ ):
self.streaming_response = streaming_response
self.response_iterator = self.streaming_response
self.content_blocks: List[ContentBlockDelta] = []
self.tool_index = -1
+ self.json_mode = json_mode
def check_empty_tool_call_args(self) -> bool:
"""
@@ -771,6 +817,8 @@ class ModelResponseIterator:
status_code=500, # it looks like Anthropic API does not return a status code in the chunk error - default to 500
)
+ text, tool_use = self._handle_json_mode_chunk(text=text, tool_use=tool_use)
+
returned_chunk = GenericStreamingChunk(
text=text,
tool_use=tool_use,
@@ -785,6 +833,34 @@ class ModelResponseIterator:
except json.JSONDecodeError:
raise ValueError(f"Failed to decode JSON from chunk: {chunk}")
+ def _handle_json_mode_chunk(
+ self, text: str, tool_use: Optional[ChatCompletionToolCallChunk]
+ ) -> Tuple[str, Optional[ChatCompletionToolCallChunk]]:
+ """
+ If JSON mode is enabled, convert the tool call to a message.
+
+        Anthropic returns the JSON schema as part of the tool call, while
+        OpenAI returns it as part of the content; this places it in the content.
+
+ Args:
+ text: str
+ tool_use: Optional[ChatCompletionToolCallChunk]
+ Returns:
+ Tuple[str, Optional[ChatCompletionToolCallChunk]]
+
+ text: The text to use in the content
+ tool_use: The ChatCompletionToolCallChunk to use in the chunk response
+ """
+ if self.json_mode is True and tool_use is not None:
+ message = AnthropicChatCompletion._convert_tool_response_to_message(
+ tool_calls=[tool_use]
+ )
+ if message is not None:
+ text = message.content or ""
+ tool_use = None
+
+ return text, tool_use
+
# Sync iterator
def __iter__(self):
return self
diff --git a/litellm/llms/anthropic/chat/transformation.py b/litellm/llms/anthropic/chat/transformation.py
index e222d8721..28bd8d86f 100644
--- a/litellm/llms/anthropic/chat/transformation.py
+++ b/litellm/llms/anthropic/chat/transformation.py
@@ -91,6 +91,7 @@ class AnthropicConfig:
"extra_headers",
"parallel_tool_calls",
"response_format",
+ "user",
]
def get_cache_control_headers(self) -> dict:
@@ -246,6 +247,28 @@ class AnthropicConfig:
anthropic_tools.append(new_tool)
return anthropic_tools
+ def _map_stop_sequences(
+ self, stop: Optional[Union[str, List[str]]]
+ ) -> Optional[List[str]]:
+ new_stop: Optional[List[str]] = None
+ if isinstance(stop, str):
+ if (
+ stop == "\n"
+ ) and litellm.drop_params is True: # anthropic doesn't allow whitespace characters as stop-sequences
+ return new_stop
+ new_stop = [stop]
+ elif isinstance(stop, list):
+ new_v = []
+ for v in stop:
+ if (
+ v == "\n"
+ ) and litellm.drop_params is True: # anthropic doesn't allow whitespace characters as stop-sequences
+ continue
+ new_v.append(v)
+ if len(new_v) > 0:
+ new_stop = new_v
+ return new_stop
+
def map_openai_params(
self,
non_default_params: dict,
@@ -271,26 +294,10 @@ class AnthropicConfig:
optional_params["tool_choice"] = _tool_choice
if param == "stream" and value is True:
optional_params["stream"] = value
- if param == "stop":
- if isinstance(value, str):
- if (
- value == "\n"
- ) and litellm.drop_params is True: # anthropic doesn't allow whitespace characters as stop-sequences
- continue
- value = [value]
- elif isinstance(value, list):
- new_v = []
- for v in value:
- if (
- v == "\n"
- ) and litellm.drop_params is True: # anthropic doesn't allow whitespace characters as stop-sequences
- continue
- new_v.append(v)
- if len(new_v) > 0:
- value = new_v
- else:
- continue
- optional_params["stop_sequences"] = value
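+            # map OpenAI "stop" to Anthropic "stop_sequences", dropping "\n" values when litellm.drop_params is True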
+ if param == "stop" and (isinstance(value, str) or isinstance(value, list)):
+ _value = self._map_stop_sequences(value)
+ if _value is not None:
+ optional_params["stop_sequences"] = _value
if param == "temperature":
optional_params["temperature"] = value
if param == "top_p":
@@ -314,7 +321,8 @@ class AnthropicConfig:
optional_params["tools"] = [_tool]
optional_params["tool_choice"] = _tool_choice
optional_params["json_mode"] = True
-
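+            # Anthropic has no top-level "user" param - forward it via the request's metadata.user_id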
+ if param == "user":
+ optional_params["metadata"] = {"user_id": value}
## VALIDATE REQUEST
"""
Anthropic doesn't support tool calling without `tools=` param specified.
@@ -465,6 +473,7 @@ class AnthropicConfig:
model: str,
messages: List[AllMessageValues],
optional_params: dict,
+ litellm_params: dict,
headers: dict,
_is_function_call: bool,
is_vertex_request: bool,
@@ -502,6 +511,15 @@ class AnthropicConfig:
if "tools" in optional_params:
_is_function_call = True
+ ## Handle user_id in metadata
+ _litellm_metadata = litellm_params.get("metadata", None)
+ if (
+ _litellm_metadata
+ and isinstance(_litellm_metadata, dict)
+ and "user_id" in _litellm_metadata
+ ):
+ optional_params["metadata"] = {"user_id": _litellm_metadata["user_id"]}
+
data = {
"messages": anthropic_messages,
**optional_params,
diff --git a/litellm/llms/bedrock/image/amazon_stability3_transformation.py b/litellm/llms/bedrock/image/amazon_stability3_transformation.py
index 784e86b04..2c90b3a12 100644
--- a/litellm/llms/bedrock/image/amazon_stability3_transformation.py
+++ b/litellm/llms/bedrock/image/amazon_stability3_transformation.py
@@ -53,9 +53,15 @@ class AmazonStability3Config:
sd3-medium
sd3.5-large
sd3.5-large-turbo
+
+ Stability ultra models
+ stable-image-ultra-v1
"""
- if model and ("sd3" in model or "sd3.5" in model):
- return True
+ if model:
+ if "sd3" in model or "sd3.5" in model:
+ return True
+ if "stable-image-ultra-v1" in model:
+ return True
return False
@classmethod
diff --git a/litellm/llms/custom_httpx/types.py b/litellm/llms/custom_httpx/types.py
index dc0958118..8e6ad0eda 100644
--- a/litellm/llms/custom_httpx/types.py
+++ b/litellm/llms/custom_httpx/types.py
@@ -8,3 +8,4 @@ class httpxSpecialProvider(str, Enum):
GuardrailCallback = "guardrail_callback"
Caching = "caching"
Oauth2Check = "oauth2_check"
+ SecretManager = "secret_manager"
diff --git a/litellm/llms/jina_ai/embedding/transformation.py b/litellm/llms/jina_ai/embedding/transformation.py
index 26ff58878..97b7b2cfa 100644
--- a/litellm/llms/jina_ai/embedding/transformation.py
+++ b/litellm/llms/jina_ai/embedding/transformation.py
@@ -76,4 +76,4 @@ class JinaAIEmbeddingConfig:
or get_secret_str("JINA_AI_API_KEY")
or get_secret_str("JINA_AI_TOKEN")
)
- return LlmProviders.OPENAI_LIKE.value, api_base, dynamic_api_key
+ return LlmProviders.JINA_AI.value, api_base, dynamic_api_key
diff --git a/litellm/llms/jina_ai/rerank/handler.py b/litellm/llms/jina_ai/rerank/handler.py
new file mode 100644
index 000000000..a2cfdd49e
--- /dev/null
+++ b/litellm/llms/jina_ai/rerank/handler.py
@@ -0,0 +1,96 @@
+"""
+Jina AI Rerank API
+
+LiteLLM supports the rerank API format; no parameter transformation occurs.
+"""
+
+import uuid
+from typing import Any, Dict, List, Optional, Union
+
+import httpx
+from pydantic import BaseModel
+
+import litellm
+from litellm.llms.base import BaseLLM
+from litellm.llms.custom_httpx.http_handler import (
+ _get_httpx_client,
+ get_async_httpx_client,
+)
+from litellm.llms.jina_ai.rerank.transformation import JinaAIRerankConfig
+from litellm.types.rerank import RerankRequest, RerankResponse
+
+
+class JinaAIRerank(BaseLLM):
+ def rerank(
+ self,
+ model: str,
+ api_key: str,
+ query: str,
+ documents: List[Union[str, Dict[str, Any]]],
+ top_n: Optional[int] = None,
+ rank_fields: Optional[List[str]] = None,
+ return_documents: Optional[bool] = True,
+ max_chunks_per_doc: Optional[int] = None,
+ _is_async: Optional[bool] = False,
+ ) -> RerankResponse:
+ client = _get_httpx_client()
+
+ request_data = RerankRequest(
+ model=model,
+ query=query,
+ top_n=top_n,
+ documents=documents,
+ rank_fields=rank_fields,
+ return_documents=return_documents,
+ )
+
+ # exclude None values from request_data
+ request_data_dict = request_data.dict(exclude_none=True)
+
+ if _is_async:
+ return self.async_rerank(request_data_dict, api_key) # type: ignore # Call async method
+
+ response = client.post(
+ "https://api.jina.ai/v1/rerank",
+ headers={
+ "accept": "application/json",
+ "content-type": "application/json",
+ "authorization": f"Bearer {api_key}",
+ },
+ json=request_data_dict,
+ )
+
+ if response.status_code != 200:
+ raise Exception(response.text)
+
+ _json_response = response.json()
+
+ return JinaAIRerankConfig()._transform_response(_json_response)
+
+ async def async_rerank( # New async method
+ self,
+ request_data_dict: Dict[str, Any],
+ api_key: str,
+ ) -> RerankResponse:
+ client = get_async_httpx_client(
+ llm_provider=litellm.LlmProviders.JINA_AI
+ ) # Use async client
+
+ response = await client.post(
+ "https://api.jina.ai/v1/rerank",
+ headers={
+ "accept": "application/json",
+ "content-type": "application/json",
+ "authorization": f"Bearer {api_key}",
+ },
+ json=request_data_dict,
+ )
+
+ if response.status_code != 200:
+ raise Exception(response.text)
+
+ _json_response = response.json()
+
+ return JinaAIRerankConfig()._transform_response(_json_response)
+
+ pass
diff --git a/litellm/llms/jina_ai/rerank/transformation.py b/litellm/llms/jina_ai/rerank/transformation.py
new file mode 100644
index 000000000..82039a15b
--- /dev/null
+++ b/litellm/llms/jina_ai/rerank/transformation.py
@@ -0,0 +1,36 @@
+"""
+Transformation logic from Cohere's /v1/rerank format to Jina AI's `/v1/rerank` format.
+
+Why a separate file? To make it easy to see how the transformation works.
+
+Docs - https://jina.ai/reranker
+"""
+
+import uuid
+from typing import List, Optional
+
+from litellm.types.rerank import (
+ RerankBilledUnits,
+ RerankResponse,
+ RerankResponseMeta,
+ RerankTokens,
+)
+
+
+class JinaAIRerankConfig:
+ def _transform_response(self, response: dict) -> RerankResponse:
+
+ _billed_units = RerankBilledUnits(**response.get("usage", {}))
+ _tokens = RerankTokens(**response.get("usage", {}))
+ rerank_meta = RerankResponseMeta(billed_units=_billed_units, tokens=_tokens)
+
+ _results: Optional[List[dict]] = response.get("results")
+
+ if _results is None:
+ raise ValueError(f"No results found in the response={response}")
+
+ return RerankResponse(
+ id=response.get("id") or str(uuid.uuid4()),
+ results=_results,
+ meta=rerank_meta,
+ ) # Return response
diff --git a/litellm/llms/ollama.py b/litellm/llms/ollama.py
index 845d0e2dd..842d946c6 100644
--- a/litellm/llms/ollama.py
+++ b/litellm/llms/ollama.py
@@ -185,6 +185,8 @@ class OllamaConfig:
"name": "mistral"
}'
"""
+ if model.startswith("ollama/") or model.startswith("ollama_chat/"):
+ model = model.split("/", 1)[1]
api_base = get_secret_str("OLLAMA_API_BASE") or "http://localhost:11434"
try:
diff --git a/litellm/llms/together_ai/rerank.py b/litellm/llms/together_ai/rerank/handler.py
similarity index 84%
rename from litellm/llms/together_ai/rerank.py
rename to litellm/llms/together_ai/rerank/handler.py
index 1be73af2d..3e6d5d667 100644
--- a/litellm/llms/together_ai/rerank.py
+++ b/litellm/llms/together_ai/rerank/handler.py
@@ -15,7 +15,14 @@ from litellm.llms.custom_httpx.http_handler import (
_get_httpx_client,
get_async_httpx_client,
)
-from litellm.types.rerank import RerankRequest, RerankResponse
+from litellm.llms.together_ai.rerank.transformation import TogetherAIRerankConfig
+from litellm.types.rerank import (
+ RerankBilledUnits,
+ RerankRequest,
+ RerankResponse,
+ RerankResponseMeta,
+ RerankTokens,
+)
class TogetherAIRerank(BaseLLM):
@@ -65,13 +72,7 @@ class TogetherAIRerank(BaseLLM):
_json_response = response.json()
- response = RerankResponse(
- id=_json_response.get("id"),
- results=_json_response.get("results"),
- meta=_json_response.get("meta") or {},
- )
-
- return response
+ return TogetherAIRerankConfig()._transform_response(_json_response)
async def async_rerank( # New async method
self,
@@ -97,10 +98,4 @@ class TogetherAIRerank(BaseLLM):
_json_response = response.json()
- return RerankResponse(
- id=_json_response.get("id"),
- results=_json_response.get("results"),
- meta=_json_response.get("meta") or {},
- ) # Return response
-
- pass
+ return TogetherAIRerankConfig()._transform_response(_json_response)
diff --git a/litellm/llms/together_ai/rerank/transformation.py b/litellm/llms/together_ai/rerank/transformation.py
new file mode 100644
index 000000000..b2024b5cd
--- /dev/null
+++ b/litellm/llms/together_ai/rerank/transformation.py
@@ -0,0 +1,34 @@
+"""
+Transformation logic from Cohere's /v1/rerank format to Together AI's `/v1/rerank` format.
+
+Why a separate file? To make it easy to see how the transformation works.
+"""
+
+import uuid
+from typing import List, Optional
+
+from litellm.types.rerank import (
+ RerankBilledUnits,
+ RerankResponse,
+ RerankResponseMeta,
+ RerankTokens,
+)
+
+
+class TogetherAIRerankConfig:
+ def _transform_response(self, response: dict) -> RerankResponse:
+
+ _billed_units = RerankBilledUnits(**response.get("usage", {}))
+ _tokens = RerankTokens(**response.get("usage", {}))
+ rerank_meta = RerankResponseMeta(billed_units=_billed_units, tokens=_tokens)
+
+ _results: Optional[List[dict]] = response.get("results")
+
+ if _results is None:
+ raise ValueError(f"No results found in the response={response}")
+
+ return RerankResponse(
+ id=response.get("id") or str(uuid.uuid4()),
+ results=_results,
+ meta=rerank_meta,
+ ) # Return response
diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/common_utils.py b/litellm/llms/vertex_ai_and_google_ai_studio/common_utils.py
index 0f95b222c..74bab0b26 100644
--- a/litellm/llms/vertex_ai_and_google_ai_studio/common_utils.py
+++ b/litellm/llms/vertex_ai_and_google_ai_studio/common_utils.py
@@ -89,6 +89,9 @@ def _get_vertex_url(
elif mode == "embedding":
endpoint = "predict"
url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}"
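+        # numeric model ids refer to deployed endpoints (e.g. fine-tuned models) - use the endpoint :predict URL instead of the publisher model URL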
+ if model.isdigit():
+ # https://us-central1-aiplatform.googleapis.com/v1/projects/$PROJECT_ID/locations/us-central1/endpoints/$ENDPOINT_ID:predict
+ url = f"https://{vertex_location}-aiplatform.googleapis.com/{vertex_api_version}/projects/{vertex_project}/locations/{vertex_location}/endpoints/{model}:{endpoint}"
if not url or not endpoint:
raise ValueError(f"Unable to get vertex url/endpoint for mode: {mode}")
diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/image_generation/cost_calculator.py b/litellm/llms/vertex_ai_and_google_ai_studio/image_generation/cost_calculator.py
new file mode 100644
index 000000000..2d7fa37f7
--- /dev/null
+++ b/litellm/llms/vertex_ai_and_google_ai_studio/image_generation/cost_calculator.py
@@ -0,0 +1,25 @@
+"""
+Vertex AI Image Generation Cost Calculator
+"""
+
+from typing import Optional
+
+import litellm
+from litellm.types.utils import ImageResponse
+
+
+def cost_calculator(
+ model: str,
+ image_response: ImageResponse,
+) -> float:
+ """
+ Vertex AI Image Generation Cost Calculator
+ """
+ _model_info = litellm.get_model_info(
+ model=model,
+ custom_llm_provider="vertex_ai",
+ )
+
+ output_cost_per_image: float = _model_info.get("output_cost_per_image") or 0.0
+ num_images: int = len(image_response.data)
+ return output_cost_per_image * num_images
diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/vertex_embeddings/embedding_handler.py b/litellm/llms/vertex_ai_and_google_ai_studio/vertex_embeddings/embedding_handler.py
index 0cde5c3b5..26741ff4f 100644
--- a/litellm/llms/vertex_ai_and_google_ai_studio/vertex_embeddings/embedding_handler.py
+++ b/litellm/llms/vertex_ai_and_google_ai_studio/vertex_embeddings/embedding_handler.py
@@ -96,7 +96,7 @@ class VertexEmbedding(VertexBase):
headers = self.set_headers(auth_header=auth_header, extra_headers=extra_headers)
vertex_request: VertexEmbeddingRequest = (
litellm.vertexAITextEmbeddingConfig.transform_openai_request_to_vertex_embedding_request(
- input=input, optional_params=optional_params
+ input=input, optional_params=optional_params, model=model
)
)
@@ -188,7 +188,7 @@ class VertexEmbedding(VertexBase):
headers = self.set_headers(auth_header=auth_header, extra_headers=extra_headers)
vertex_request: VertexEmbeddingRequest = (
litellm.vertexAITextEmbeddingConfig.transform_openai_request_to_vertex_embedding_request(
- input=input, optional_params=optional_params
+ input=input, optional_params=optional_params, model=model
)
)
diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/vertex_embeddings/transformation.py b/litellm/llms/vertex_ai_and_google_ai_studio/vertex_embeddings/transformation.py
index 1ca405392..6f4b25cef 100644
--- a/litellm/llms/vertex_ai_and_google_ai_studio/vertex_embeddings/transformation.py
+++ b/litellm/llms/vertex_ai_and_google_ai_studio/vertex_embeddings/transformation.py
@@ -101,11 +101,16 @@ class VertexAITextEmbeddingConfig(BaseModel):
return optional_params
def transform_openai_request_to_vertex_embedding_request(
- self, input: Union[list, str], optional_params: dict
+ self, input: Union[list, str], optional_params: dict, model: str
) -> VertexEmbeddingRequest:
"""
Transforms an openai request to a vertex embedding request.
"""
+ if model.isdigit():
+ return self._transform_openai_request_to_fine_tuned_embedding_request(
+ input, optional_params, model
+ )
+
vertex_request: VertexEmbeddingRequest = VertexEmbeddingRequest()
vertex_text_embedding_input_list: List[TextEmbeddingInput] = []
task_type: Optional[TaskType] = optional_params.get("task_type")
@@ -125,6 +130,47 @@ class VertexAITextEmbeddingConfig(BaseModel):
return vertex_request
+ def _transform_openai_request_to_fine_tuned_embedding_request(
+ self, input: Union[list, str], optional_params: dict, model: str
+ ) -> VertexEmbeddingRequest:
+ """
+ Transforms an openai request to a vertex fine-tuned embedding request.
+
+ Vertex Doc: https://console.cloud.google.com/vertex-ai/model-garden?hl=en&project=adroit-crow-413218&pageState=(%22galleryStateKey%22:(%22f%22:(%22g%22:%5B%5D,%22o%22:%5B%5D),%22s%22:%22%22))
+ Sample Request:
+
+ ```json
+ {
+ "instances" : [
+ {
+ "inputs": "How would the Future of AI in 10 Years look?",
+ "parameters": {
+ "max_new_tokens": 128,
+ "temperature": 1.0,
+ "top_p": 0.9,
+ "top_k": 10
+ }
+ }
+ ]
+ }
+ ```
+ """
+ vertex_request: VertexEmbeddingRequest = VertexEmbeddingRequest()
+ vertex_text_embedding_input_list: List[TextEmbeddingFineTunedInput] = []
+ if isinstance(input, str):
+ input = [input] # Convert single string to list for uniform processing
+
+ for text in input:
+ embedding_input = TextEmbeddingFineTunedInput(inputs=text)
+ vertex_text_embedding_input_list.append(embedding_input)
+
+ vertex_request["instances"] = vertex_text_embedding_input_list
+ vertex_request["parameters"] = TextEmbeddingFineTunedParameters(
+ **optional_params
+ )
+
+ return vertex_request
+
def create_embedding_input(
self,
content: str,
@@ -157,6 +203,11 @@ class VertexAITextEmbeddingConfig(BaseModel):
"""
Transforms a vertex embedding response to an openai response.
"""
+ if model.isdigit():
+ return self._transform_vertex_response_to_openai_for_fine_tuned_models(
+ response, model, model_response
+ )
+
_predictions = response["predictions"]
embedding_response = []
@@ -181,3 +232,35 @@ class VertexAITextEmbeddingConfig(BaseModel):
)
setattr(model_response, "usage", usage)
return model_response
+
+ def _transform_vertex_response_to_openai_for_fine_tuned_models(
+ self, response: dict, model: str, model_response: litellm.EmbeddingResponse
+ ) -> litellm.EmbeddingResponse:
+ """
+ Transforms a vertex fine-tuned model embedding response to an openai response format.
+ """
+ _predictions = response["predictions"]
+
+ embedding_response = []
+ # For fine-tuned models, we don't get token counts in the response
+ input_tokens = 0
+
+ for idx, embedding_values in enumerate(_predictions):
+ embedding_response.append(
+ {
+ "object": "embedding",
+ "index": idx,
+ "embedding": embedding_values[
+ 0
+ ], # The embedding values are nested one level deeper
+ }
+ )
+
+ model_response.object = "list"
+ model_response.data = embedding_response
+ model_response.model = model
+ usage = Usage(
+ prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
+ )
+ setattr(model_response, "usage", usage)
+ return model_response
diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/vertex_embeddings/types.py b/litellm/llms/vertex_ai_and_google_ai_studio/vertex_embeddings/types.py
index 311809c82..433305516 100644
--- a/litellm/llms/vertex_ai_and_google_ai_studio/vertex_embeddings/types.py
+++ b/litellm/llms/vertex_ai_and_google_ai_studio/vertex_embeddings/types.py
@@ -23,14 +23,27 @@ class TextEmbeddingInput(TypedDict, total=False):
title: Optional[str]
+# Fine-tuned models require a different input format
+# Ref: https://console.cloud.google.com/vertex-ai/model-garden?hl=en&project=adroit-crow-413218&pageState=(%22galleryStateKey%22:(%22f%22:(%22g%22:%5B%5D,%22o%22:%5B%5D),%22s%22:%22%22))
+class TextEmbeddingFineTunedInput(TypedDict, total=False):
+ inputs: str
+
+
+class TextEmbeddingFineTunedParameters(TypedDict, total=False):
+ max_new_tokens: Optional[int]
+ temperature: Optional[float]
+ top_p: Optional[float]
+ top_k: Optional[int]
+
+
class EmbeddingParameters(TypedDict, total=False):
auto_truncate: Optional[bool]
output_dimensionality: Optional[int]
class VertexEmbeddingRequest(TypedDict, total=False):
- instances: List[TextEmbeddingInput]
- parameters: Optional[EmbeddingParameters]
+ instances: Union[List[TextEmbeddingInput], List[TextEmbeddingFineTunedInput]]
+ parameters: Optional[Union[EmbeddingParameters, TextEmbeddingFineTunedParameters]]
# Example usage:
diff --git a/litellm/main.py b/litellm/main.py
index afb46c698..543a93eea 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -1066,6 +1066,7 @@ def completion( # type: ignore # noqa: PLR0915
azure_ad_token_provider=kwargs.get("azure_ad_token_provider"),
user_continue_message=kwargs.get("user_continue_message"),
base_model=base_model,
+ litellm_trace_id=kwargs.get("litellm_trace_id"),
)
logging.update_environment_variables(
model=model,
@@ -3455,7 +3456,7 @@ def embedding( # noqa: PLR0915
client=client,
aembedding=aembedding,
)
- elif custom_llm_provider == "openai_like":
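+    # Jina AI's embedding endpoint is OpenAI-compatible, so it shares the openai_like path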
+ elif custom_llm_provider == "openai_like" or custom_llm_provider == "jina_ai":
api_base = (
api_base or litellm.api_base or get_secret_str("OPENAI_LIKE_API_BASE")
)
diff --git a/litellm/model_prices_and_context_window_backup.json b/litellm/model_prices_and_context_window_backup.json
index fb8fb105c..cae3bee12 100644
--- a/litellm/model_prices_and_context_window_backup.json
+++ b/litellm/model_prices_and_context_window_backup.json
@@ -2986,19 +2986,19 @@
"supports_function_calling": true
},
"vertex_ai/imagegeneration@006": {
- "cost_per_image": 0.020,
+ "output_cost_per_image": 0.020,
"litellm_provider": "vertex_ai-image-models",
"mode": "image_generation",
"source": "https://cloud.google.com/vertex-ai/generative-ai/pricing"
},
"vertex_ai/imagen-3.0-generate-001": {
- "cost_per_image": 0.04,
+ "output_cost_per_image": 0.04,
"litellm_provider": "vertex_ai-image-models",
"mode": "image_generation",
"source": "https://cloud.google.com/vertex-ai/generative-ai/pricing"
},
"vertex_ai/imagen-3.0-fast-generate-001": {
- "cost_per_image": 0.02,
+ "output_cost_per_image": 0.02,
"litellm_provider": "vertex_ai-image-models",
"mode": "image_generation",
"source": "https://cloud.google.com/vertex-ai/generative-ai/pricing"
@@ -5620,6 +5620,13 @@
"litellm_provider": "bedrock",
"mode": "image_generation"
},
+ "stability.stable-image-ultra-v1:0": {
+ "max_tokens": 77,
+ "max_input_tokens": 77,
+ "output_cost_per_image": 0.14,
+ "litellm_provider": "bedrock",
+ "mode": "image_generation"
+ },
"sagemaker/meta-textgeneration-llama-2-7b": {
"max_tokens": 4096,
"max_input_tokens": 4096,
diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml
index 911f15b86..b06a9e667 100644
--- a/litellm/proxy/_new_secret_config.yaml
+++ b/litellm/proxy/_new_secret_config.yaml
@@ -1,122 +1,15 @@
model_list:
- - model_name: "*"
- litellm_params:
- model: claude-3-5-sonnet-20240620
- api_key: os.environ/ANTHROPIC_API_KEY
- - model_name: claude-3-5-sonnet-aihubmix
- litellm_params:
- model: openai/claude-3-5-sonnet-20240620
- input_cost_per_token: 0.000003 # 3$/M
- output_cost_per_token: 0.000015 # 15$/M
- api_base: "https://exampleopenaiendpoint-production.up.railway.app"
- api_key: my-fake-key
- - model_name: fake-openai-endpoint-2
- litellm_params:
- model: openai/my-fake-model
- api_key: my-fake-key
- api_base: https://exampleopenaiendpoint-production.up.railway.app/
- stream_timeout: 0.001
- timeout: 1
- rpm: 1
- - model_name: fake-openai-endpoint
- litellm_params:
- model: openai/my-fake-model
- api_key: my-fake-key
- api_base: https://exampleopenaiendpoint-production.up.railway.app/
- ## bedrock chat completions
- - model_name: "*anthropic.claude*"
- litellm_params:
- model: bedrock/*anthropic.claude*
- aws_access_key_id: os.environ/BEDROCK_AWS_ACCESS_KEY_ID
- aws_secret_access_key: os.environ/BEDROCK_AWS_SECRET_ACCESS_KEY
- aws_region_name: os.environ/AWS_REGION_NAME
- guardrailConfig:
- "guardrailIdentifier": "h4dsqwhp6j66"
- "guardrailVersion": "2"
- "trace": "enabled"
-
-## bedrock embeddings
- - model_name: "*amazon.titan-embed-*"
- litellm_params:
- model: bedrock/amazon.titan-embed-*
- aws_access_key_id: os.environ/BEDROCK_AWS_ACCESS_KEY_ID
- aws_secret_access_key: os.environ/BEDROCK_AWS_SECRET_ACCESS_KEY
- aws_region_name: os.environ/AWS_REGION_NAME
- - model_name: "*cohere.embed-*"
- litellm_params:
- model: bedrock/cohere.embed-*
- aws_access_key_id: os.environ/BEDROCK_AWS_ACCESS_KEY_ID
- aws_secret_access_key: os.environ/BEDROCK_AWS_SECRET_ACCESS_KEY
- aws_region_name: os.environ/AWS_REGION_NAME
-
- - model_name: "bedrock/*"
- litellm_params:
- model: bedrock/*
- aws_access_key_id: os.environ/BEDROCK_AWS_ACCESS_KEY_ID
- aws_secret_access_key: os.environ/BEDROCK_AWS_SECRET_ACCESS_KEY
- aws_region_name: os.environ/AWS_REGION_NAME
-
+ # GPT-4 Turbo Models
- model_name: gpt-4
litellm_params:
- model: azure/chatgpt-v-2
- api_base: https://openai-gpt-4-test-v-1.openai.azure.com/
- api_version: "2023-05-15"
- api_key: os.environ/AZURE_API_KEY # The `os.environ/` prefix tells litellm to read this from the env. See https://docs.litellm.ai/docs/simple_proxy#load-api-keys-from-vault
- rpm: 480
- timeout: 300
- stream_timeout: 60
-
-litellm_settings:
- fallbacks: [{ "claude-3-5-sonnet-20240620": ["claude-3-5-sonnet-aihubmix"] }]
- # callbacks: ["otel", "prometheus"]
- default_redis_batch_cache_expiry: 10
- # default_team_settings:
- # - team_id: "dbe2f686-a686-4896-864a-4c3924458709"
- # success_callback: ["langfuse"]
- # langfuse_public_key: os.environ/LANGFUSE_PUB_KEY_1 # Project 1
- # langfuse_secret: os.environ/LANGFUSE_PRIVATE_KEY_1 # Project 1
-
-# litellm_settings:
-# cache: True
-# cache_params:
-# type: redis
-
-# # disable caching on the actual API call
-# supported_call_types: []
-
-# # see https://docs.litellm.ai/docs/proxy/prod#3-use-redis-porthost-password-not-redis_url
-# host: os.environ/REDIS_HOST
-# port: os.environ/REDIS_PORT
-# password: os.environ/REDIS_PASSWORD
-
-# # see https://docs.litellm.ai/docs/proxy/caching#turn-on-batch_redis_requests
-# # see https://docs.litellm.ai/docs/proxy/prometheus
-# callbacks: ['otel']
+ model: gpt-4
+ - model_name: rerank-model
+ litellm_params:
+ model: jina_ai/jina-reranker-v2-base-multilingual
-# # router_settings:
-# # routing_strategy: latency-based-routing
-# # routing_strategy_args:
-# # # only assign 40% of traffic to the fastest deployment to avoid overloading it
-# # lowest_latency_buffer: 0.4
-
-# # # consider last five minutes of calls for latency calculation
-# # ttl: 300
-# # redis_host: os.environ/REDIS_HOST
-# # redis_port: os.environ/REDIS_PORT
-# # redis_password: os.environ/REDIS_PASSWORD
-
-# # # see https://docs.litellm.ai/docs/proxy/prod#1-use-this-configyaml
-# # general_settings:
-# # master_key: os.environ/LITELLM_MASTER_KEY
-# # database_url: os.environ/DATABASE_URL
-# # disable_master_key_return: true
-# # # alerting: ['slack', 'email']
-# # alerting: ['email']
-
-# # # Batch write spend updates every 60s
-# # proxy_batch_write_at: 60
-
-# # # see https://docs.litellm.ai/docs/proxy/caching#advanced---user-api-key-cache-ttl
-# # # our api keys rarely change
-# # user_api_key_cache_ttl: 3600
+router_settings:
+ model_group_alias:
+ "gpt-4-turbo": # Aliased model name
+ model: "gpt-4" # Actual model name in 'model_list'
+ hidden: true
\ No newline at end of file
diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py
index 2d869af85..4baf13b61 100644
--- a/litellm/proxy/_types.py
+++ b/litellm/proxy/_types.py
@@ -1128,7 +1128,16 @@ class KeyManagementSystem(enum.Enum):
class KeyManagementSettings(LiteLLMBase):
- hosted_keys: List
+ hosted_keys: Optional[List] = None
+ store_virtual_keys: Optional[bool] = False
+ """
+ If True, virtual keys created by litellm will be stored in the secret manager
+ """
+
+ access_mode: Literal["read_only", "write_only", "read_and_write"] = "read_only"
+ """
+    Access mode for the secret manager; when set to "write_only", the secret manager is only used for writing secrets.
+ """
class TeamDefaultSettings(LiteLLMBase):
diff --git a/litellm/proxy/auth/auth_checks.py b/litellm/proxy/auth/auth_checks.py
index 12b6ec372..8d3afa33f 100644
--- a/litellm/proxy/auth/auth_checks.py
+++ b/litellm/proxy/auth/auth_checks.py
@@ -8,6 +8,7 @@ Run checks for:
2. If user is in budget
3. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget
"""
+
import time
import traceback
from datetime import datetime
diff --git a/litellm/proxy/hooks/key_management_event_hooks.py b/litellm/proxy/hooks/key_management_event_hooks.py
new file mode 100644
index 000000000..08645a468
--- /dev/null
+++ b/litellm/proxy/hooks/key_management_event_hooks.py
@@ -0,0 +1,267 @@
+import asyncio
+import json
+import uuid
+from datetime import datetime, timezone
+from typing import Any, List, Optional
+
+from fastapi import status
+
+import litellm
+from litellm._logging import verbose_proxy_logger
+from litellm.proxy._types import (
+ GenerateKeyRequest,
+ KeyManagementSystem,
+ KeyRequest,
+ LiteLLM_AuditLogs,
+ LiteLLM_VerificationToken,
+ LitellmTableNames,
+ ProxyErrorTypes,
+ ProxyException,
+ UpdateKeyRequest,
+ UserAPIKeyAuth,
+ WebhookEvent,
+)
+
+
+class KeyManagementEventHooks:
+
+ @staticmethod
+ async def async_key_generated_hook(
+ data: GenerateKeyRequest,
+ response: dict,
+ user_api_key_dict: UserAPIKeyAuth,
+ litellm_changed_by: Optional[str] = None,
+ ):
+ """
+ Hook that runs after a successful /key/generate request
+
+ Handles the following:
+ - Sending Email with Key Details
+ - Storing Audit Logs for key generation
+        - Storing the generated key in the secret manager (when enabled)
+ """
+ from litellm.proxy.management_helpers.audit_logs import (
+ create_audit_log_for_update,
+ )
+ from litellm.proxy.proxy_server import (
+ general_settings,
+ litellm_proxy_admin_name,
+ proxy_logging_obj,
+ )
+
+ if data.send_invite_email is True:
+ await KeyManagementEventHooks._send_key_created_email(response)
+
+ # Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True
+ if litellm.store_audit_logs is True:
+ _updated_values = json.dumps(response, default=str)
+ asyncio.create_task(
+ create_audit_log_for_update(
+ request_data=LiteLLM_AuditLogs(
+ id=str(uuid.uuid4()),
+ updated_at=datetime.now(timezone.utc),
+ changed_by=litellm_changed_by
+ or user_api_key_dict.user_id
+ or litellm_proxy_admin_name,
+ changed_by_api_key=user_api_key_dict.api_key,
+ table_name=LitellmTableNames.KEY_TABLE_NAME,
+ object_id=response.get("token_id", ""),
+ action="created",
+ updated_values=_updated_values,
+ before_value=None,
+ )
+ )
+ )
+ # store the generated key in the secret manager
+ await KeyManagementEventHooks._store_virtual_key_in_secret_manager(
+ secret_name=data.key_alias or f"virtual-key-{uuid.uuid4()}",
+ secret_token=response.get("token", ""),
+ )
+
+ @staticmethod
+ async def async_key_updated_hook(
+ data: UpdateKeyRequest,
+ existing_key_row: Any,
+ response: Any,
+ user_api_key_dict: UserAPIKeyAuth,
+ litellm_changed_by: Optional[str] = None,
+ ):
+ """
+ Post /key/update processing hook
+
+ Handles the following:
+ - Storing Audit Logs for key update
+ """
+ from litellm.proxy.management_helpers.audit_logs import (
+ create_audit_log_for_update,
+ )
+ from litellm.proxy.proxy_server import litellm_proxy_admin_name
+
+ # Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True
+ if litellm.store_audit_logs is True:
+ _updated_values = json.dumps(data.json(exclude_none=True), default=str)
+
+ _before_value = existing_key_row.json(exclude_none=True)
+ _before_value = json.dumps(_before_value, default=str)
+
+ asyncio.create_task(
+ create_audit_log_for_update(
+ request_data=LiteLLM_AuditLogs(
+ id=str(uuid.uuid4()),
+ updated_at=datetime.now(timezone.utc),
+ changed_by=litellm_changed_by
+ or user_api_key_dict.user_id
+ or litellm_proxy_admin_name,
+ changed_by_api_key=user_api_key_dict.api_key,
+ table_name=LitellmTableNames.KEY_TABLE_NAME,
+ object_id=data.key,
+ action="updated",
+ updated_values=_updated_values,
+ before_value=_before_value,
+ )
+ )
+ )
+ pass
+
+ @staticmethod
+ async def async_key_deleted_hook(
+ data: KeyRequest,
+ keys_being_deleted: List[LiteLLM_VerificationToken],
+ response: dict,
+ user_api_key_dict: UserAPIKeyAuth,
+ litellm_changed_by: Optional[str] = None,
+ ):
+ """
+ Post /key/delete processing hook
+
+ Handles the following:
+ - Storing Audit Logs for key deletion
+ """
+ from litellm.proxy.management_helpers.audit_logs import (
+ create_audit_log_for_update,
+ )
+ from litellm.proxy.proxy_server import litellm_proxy_admin_name, prisma_client
+
+ # Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True
+        # audit logs are written only after the /key/delete endpoint has validated the keys being deleted
+ if litellm.store_audit_logs is True:
+            # make an audit log for each key deleted
+ for key in data.keys:
+ key_row = await prisma_client.get_data( # type: ignore
+ token=key, table_name="key", query_type="find_unique"
+ )
+
+ if key_row is None:
+ raise ProxyException(
+ message=f"Key {key} not found",
+ type=ProxyErrorTypes.bad_request_error,
+ param="key",
+ code=status.HTTP_404_NOT_FOUND,
+ )
+
+ key_row = key_row.json(exclude_none=True)
+ _key_row = json.dumps(key_row, default=str)
+
+ asyncio.create_task(
+ create_audit_log_for_update(
+ request_data=LiteLLM_AuditLogs(
+ id=str(uuid.uuid4()),
+ updated_at=datetime.now(timezone.utc),
+ changed_by=litellm_changed_by
+ or user_api_key_dict.user_id
+ or litellm_proxy_admin_name,
+ changed_by_api_key=user_api_key_dict.api_key,
+ table_name=LitellmTableNames.KEY_TABLE_NAME,
+ object_id=key,
+ action="deleted",
+ updated_values="{}",
+ before_value=_key_row,
+ )
+ )
+ )
+ # delete the keys from the secret manager
+ await KeyManagementEventHooks._delete_virtual_keys_from_secret_manager(
+ keys_being_deleted=keys_being_deleted
+ )
+ pass
+
+ @staticmethod
+ async def _store_virtual_key_in_secret_manager(secret_name: str, secret_token: str):
+ """
+ Store a virtual key in the secret manager
+
+ Args:
+ secret_name: Name of the virtual key
+ secret_token: Value of the virtual key (example: sk-1234)
+ """
+ if litellm._key_management_settings is not None:
+ if litellm._key_management_settings.store_virtual_keys is True:
+ from litellm.secret_managers.aws_secret_manager_v2 import (
+ AWSSecretsManagerV2,
+ )
+
+ # store the key in the secret manager
+ if (
+ litellm._key_management_system
+ == KeyManagementSystem.AWS_SECRET_MANAGER
+ and isinstance(litellm.secret_manager_client, AWSSecretsManagerV2)
+ ):
+ await litellm.secret_manager_client.async_write_secret(
+ secret_name=secret_name,
+ secret_value=secret_token,
+ )
+
+ @staticmethod
+ async def _delete_virtual_keys_from_secret_manager(
+ keys_being_deleted: List[LiteLLM_VerificationToken],
+ ):
+ """
+ Deletes virtual keys from the secret manager
+
+ Args:
+ keys_being_deleted: List of keys being deleted, this is passed down from the /key/delete operation
+ """
+ if litellm._key_management_settings is not None:
+ if litellm._key_management_settings.store_virtual_keys is True:
+ from litellm.secret_managers.aws_secret_manager_v2 import (
+ AWSSecretsManagerV2,
+ )
+
+ if isinstance(litellm.secret_manager_client, AWSSecretsManagerV2):
+ for key in keys_being_deleted:
+ if key.key_alias is not None:
+ await litellm.secret_manager_client.async_delete_secret(
+ secret_name=key.key_alias
+ )
+ else:
+ verbose_proxy_logger.warning(
+ f"KeyManagementEventHooks._delete_virtual_key_from_secret_manager: Key alias not found for key {key.token}. Skipping deletion from secret manager."
+ )
+
+ @staticmethod
+ async def _send_key_created_email(response: dict):
+ from litellm.proxy.proxy_server import general_settings, proxy_logging_obj
+
+ if "email" not in general_settings.get("alerting", []):
+ raise ValueError(
+                "Email alerting is not set up in config.yaml. Please set `alerting=['email']`.\nDocs: https://docs.litellm.ai/docs/proxy/email"
+ )
+ event = WebhookEvent(
+ event="key_created",
+ event_group="key",
+ event_message="API Key Created",
+ token=response.get("token", ""),
+ spend=response.get("spend", 0.0),
+ max_budget=response.get("max_budget", 0.0),
+ user_id=response.get("user_id", None),
+ team_id=response.get("team_id", "Default Team"),
+ key_alias=response.get("key_alias", None),
+ )
+
+ # If user configured email alerting - send an Email letting their end-user know the key was created
+ asyncio.create_task(
+ proxy_logging_obj.slack_alerting_instance.send_key_created_or_user_invited_email(
+ webhook_event=event,
+ )
+ )
diff --git a/litellm/proxy/litellm_pre_call_utils.py b/litellm/proxy/litellm_pre_call_utils.py
index 789e79f37..3d1d3b491 100644
--- a/litellm/proxy/litellm_pre_call_utils.py
+++ b/litellm/proxy/litellm_pre_call_utils.py
@@ -274,6 +274,51 @@ class LiteLLMProxyRequestSetup:
)
return user_api_key_logged_metadata
+ @staticmethod
+ def add_key_level_controls(
+ key_metadata: dict, data: dict, _metadata_variable_name: str
+ ):
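+        """
+        Apply key-level controls from the key's metadata to the request data:
+        cache settings, spend-log tags/metadata, and the disable_fallbacks flag.
+        """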
+ data = data.copy()
+ if "cache" in key_metadata:
+ data["cache"] = {}
+ if isinstance(key_metadata["cache"], dict):
+ for k, v in key_metadata["cache"].items():
+ if k in SupportedCacheControls:
+ data["cache"][k] = v
+
+ ## KEY-LEVEL SPEND LOGS / TAGS
+ if "tags" in key_metadata and key_metadata["tags"] is not None:
+ if "tags" in data[_metadata_variable_name] and isinstance(
+ data[_metadata_variable_name]["tags"], list
+ ):
+ data[_metadata_variable_name]["tags"].extend(key_metadata["tags"])
+ else:
+ data[_metadata_variable_name]["tags"] = key_metadata["tags"]
+ if "spend_logs_metadata" in key_metadata and isinstance(
+ key_metadata["spend_logs_metadata"], dict
+ ):
+ if "spend_logs_metadata" in data[_metadata_variable_name] and isinstance(
+ data[_metadata_variable_name]["spend_logs_metadata"], dict
+ ):
+ for key, value in key_metadata["spend_logs_metadata"].items():
+ if (
+ key not in data[_metadata_variable_name]["spend_logs_metadata"]
+ ): # don't override k-v pair sent by request (user request)
+ data[_metadata_variable_name]["spend_logs_metadata"][
+ key
+ ] = value
+ else:
+ data[_metadata_variable_name]["spend_logs_metadata"] = key_metadata[
+ "spend_logs_metadata"
+ ]
+
+ ## KEY-LEVEL DISABLE FALLBACKS
+ if "disable_fallbacks" in key_metadata and isinstance(
+ key_metadata["disable_fallbacks"], bool
+ ):
+ data["disable_fallbacks"] = key_metadata["disable_fallbacks"]
+ return data
+
async def add_litellm_data_to_request( # noqa: PLR0915
data: dict,
@@ -389,37 +434,11 @@ async def add_litellm_data_to_request( # noqa: PLR0915
### KEY-LEVEL Controls
key_metadata = user_api_key_dict.metadata
- if "cache" in key_metadata:
- data["cache"] = {}
- if isinstance(key_metadata["cache"], dict):
- for k, v in key_metadata["cache"].items():
- if k in SupportedCacheControls:
- data["cache"][k] = v
-
- ## KEY-LEVEL SPEND LOGS / TAGS
- if "tags" in key_metadata and key_metadata["tags"] is not None:
- if "tags" in data[_metadata_variable_name] and isinstance(
- data[_metadata_variable_name]["tags"], list
- ):
- data[_metadata_variable_name]["tags"].extend(key_metadata["tags"])
- else:
- data[_metadata_variable_name]["tags"] = key_metadata["tags"]
- if "spend_logs_metadata" in key_metadata and isinstance(
- key_metadata["spend_logs_metadata"], dict
- ):
- if "spend_logs_metadata" in data[_metadata_variable_name] and isinstance(
- data[_metadata_variable_name]["spend_logs_metadata"], dict
- ):
- for key, value in key_metadata["spend_logs_metadata"].items():
- if (
- key not in data[_metadata_variable_name]["spend_logs_metadata"]
- ): # don't override k-v pair sent by request (user request)
- data[_metadata_variable_name]["spend_logs_metadata"][key] = value
- else:
- data[_metadata_variable_name]["spend_logs_metadata"] = key_metadata[
- "spend_logs_metadata"
- ]
-
+ data = LiteLLMProxyRequestSetup.add_key_level_controls(
+ key_metadata=key_metadata,
+ data=data,
+ _metadata_variable_name=_metadata_variable_name,
+ )
## TEAM-LEVEL SPEND LOGS/TAGS
team_metadata = user_api_key_dict.team_metadata or {}
if "tags" in team_metadata and team_metadata["tags"] is not None:
diff --git a/litellm/proxy/management_endpoints/key_management_endpoints.py b/litellm/proxy/management_endpoints/key_management_endpoints.py
index 01baa5a43..e38236e9b 100644
--- a/litellm/proxy/management_endpoints/key_management_endpoints.py
+++ b/litellm/proxy/management_endpoints/key_management_endpoints.py
@@ -17,7 +17,7 @@ import secrets
import traceback
import uuid
from datetime import datetime, timedelta, timezone
-from typing import List, Optional
+from typing import List, Optional, Tuple
import fastapi
from fastapi import APIRouter, Depends, Header, HTTPException, Query, Request, status
@@ -31,6 +31,7 @@ from litellm.proxy.auth.auth_checks import (
get_key_object,
)
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
+from litellm.proxy.hooks.key_management_event_hooks import KeyManagementEventHooks
from litellm.proxy.management_helpers.utils import management_endpoint_wrapper
from litellm.proxy.utils import _duration_in_seconds, _hash_token_if_needed
from litellm.secret_managers.main import get_secret
@@ -234,50 +235,14 @@ async def generate_key_fn( # noqa: PLR0915
data.soft_budget
) # include the user-input soft budget in the response
- if data.send_invite_email is True:
- if "email" not in general_settings.get("alerting", []):
- raise ValueError(
- "Email alerting not setup on config.yaml. Please set `alerting=['email']. \nDocs: https://docs.litellm.ai/docs/proxy/email`"
- )
- event = WebhookEvent(
- event="key_created",
- event_group="key",
- event_message="API Key Created",
- token=response.get("token", ""),
- spend=response.get("spend", 0.0),
- max_budget=response.get("max_budget", 0.0),
- user_id=response.get("user_id", None),
- team_id=response.get("team_id", "Default Team"),
- key_alias=response.get("key_alias", None),
- )
-
- # If user configured email alerting - send an Email letting their end-user know the key was created
- asyncio.create_task(
- proxy_logging_obj.slack_alerting_instance.send_key_created_or_user_invited_email(
- webhook_event=event,
- )
- )
-
- # Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True
- if litellm.store_audit_logs is True:
- _updated_values = json.dumps(response, default=str)
- asyncio.create_task(
- create_audit_log_for_update(
- request_data=LiteLLM_AuditLogs(
- id=str(uuid.uuid4()),
- updated_at=datetime.now(timezone.utc),
- changed_by=litellm_changed_by
- or user_api_key_dict.user_id
- or litellm_proxy_admin_name,
- changed_by_api_key=user_api_key_dict.api_key,
- table_name=LitellmTableNames.KEY_TABLE_NAME,
- object_id=response.get("token_id", ""),
- action="created",
- updated_values=_updated_values,
- before_value=None,
- )
- )
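+        # run post-generation side effects (invite email, audit log, secret-manager write) in the background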
+ asyncio.create_task(
+ KeyManagementEventHooks.async_key_generated_hook(
+ data=data,
+ response=response,
+ user_api_key_dict=user_api_key_dict,
+ litellm_changed_by=litellm_changed_by,
)
+ )
return GenerateKeyResponse(**response)
except Exception as e:
@@ -407,30 +372,15 @@ async def update_key_fn(
proxy_logging_obj=proxy_logging_obj,
)
- # Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True
- if litellm.store_audit_logs is True:
- _updated_values = json.dumps(data_json, default=str)
-
- _before_value = existing_key_row.json(exclude_none=True)
- _before_value = json.dumps(_before_value, default=str)
-
- asyncio.create_task(
- create_audit_log_for_update(
- request_data=LiteLLM_AuditLogs(
- id=str(uuid.uuid4()),
- updated_at=datetime.now(timezone.utc),
- changed_by=litellm_changed_by
- or user_api_key_dict.user_id
- or litellm_proxy_admin_name,
- changed_by_api_key=user_api_key_dict.api_key,
- table_name=LitellmTableNames.KEY_TABLE_NAME,
- object_id=data.key,
- action="updated",
- updated_values=_updated_values,
- before_value=_before_value,
- )
- )
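+        # write the key-update audit log in the background via the event hook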
+ asyncio.create_task(
+ KeyManagementEventHooks.async_key_updated_hook(
+ data=data,
+ existing_key_row=existing_key_row,
+ response=response,
+ user_api_key_dict=user_api_key_dict,
+ litellm_changed_by=litellm_changed_by,
)
+ )
if response is None:
raise ValueError("Failed to update key got response = None")
@@ -496,6 +446,9 @@ async def delete_key_fn(
user_custom_key_generate,
)
+ if prisma_client is None:
+ raise Exception("Not connected to DB!")
+
keys = data.keys
if len(keys) == 0:
raise ProxyException(
@@ -516,45 +469,7 @@ async def delete_key_fn(
):
user_id = None # unless they're admin
- # Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True
- # we do this after the first for loop, since first for loop is for validation. we only want this inserted after validation passes
- if litellm.store_audit_logs is True:
- # make an audit log for each team deleted
- for key in data.keys:
- key_row = await prisma_client.get_data( # type: ignore
- token=key, table_name="key", query_type="find_unique"
- )
-
- if key_row is None:
- raise ProxyException(
- message=f"Key {key} not found",
- type=ProxyErrorTypes.bad_request_error,
- param="key",
- code=status.HTTP_404_NOT_FOUND,
- )
-
- key_row = key_row.json(exclude_none=True)
- _key_row = json.dumps(key_row, default=str)
-
- asyncio.create_task(
- create_audit_log_for_update(
- request_data=LiteLLM_AuditLogs(
- id=str(uuid.uuid4()),
- updated_at=datetime.now(timezone.utc),
- changed_by=litellm_changed_by
- or user_api_key_dict.user_id
- or litellm_proxy_admin_name,
- changed_by_api_key=user_api_key_dict.api_key,
- table_name=LitellmTableNames.KEY_TABLE_NAME,
- object_id=key,
- action="deleted",
- updated_values="{}",
- before_value=_key_row,
- )
- )
- )
-
- number_deleted_keys = await delete_verification_token(
+ number_deleted_keys, _keys_being_deleted = await delete_verification_token(
tokens=keys, user_id=user_id
)
if number_deleted_keys is None:
@@ -588,6 +503,16 @@ async def delete_key_fn(
f"/keys/delete - cache after delete: {user_api_key_cache.in_memory_cache.cache_dict}"
)
+ asyncio.create_task(
+ KeyManagementEventHooks.async_key_deleted_hook(
+ data=data,
+ keys_being_deleted=_keys_being_deleted,
+ user_api_key_dict=user_api_key_dict,
+ litellm_changed_by=litellm_changed_by,
+ response=number_deleted_keys,
+ )
+ )
+
return {"deleted_keys": keys}
except Exception as e:
if isinstance(e, HTTPException):
@@ -1026,11 +951,35 @@ async def generate_key_helper_fn( # noqa: PLR0915
return key_data
-async def delete_verification_token(tokens: List, user_id: Optional[str] = None):
+async def delete_verification_token(
+ tokens: List, user_id: Optional[str] = None
+) -> Tuple[Optional[Dict], List[LiteLLM_VerificationToken]]:
+ """
+ Helper that deletes the list of tokens from the database
+
+ Args:
+ tokens: List of tokens to delete
+ user_id: Optional user_id to filter by
+
+ Returns:
+ Tuple[Optional[Dict], List[LiteLLM_VerificationToken]]:
+ Optional[Dict]:
+ - Number of deleted tokens
+ List[LiteLLM_VerificationToken]:
+            - The keys being deleted, including each key's key_alias, token, and user_id.
+              This is passed down to KeyManagementEventHooks to remove the keys from the secret manager and write audit logs.
+ """
from litellm.proxy.proxy_server import litellm_proxy_admin_name, prisma_client
try:
if prisma_client:
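+            # hash any raw keys first so they match the hashed tokens stored in the DB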
+ tokens = [_hash_token_if_needed(token=key) for key in tokens]
+ _keys_being_deleted = (
+ await prisma_client.db.litellm_verificationtoken.find_many(
+ where={"token": {"in": tokens}}
+ )
+ )
+
# Assuming 'db' is your Prisma Client instance
# check if admin making request - don't filter by user-id
if user_id == litellm_proxy_admin_name:
@@ -1060,7 +1009,7 @@ async def delete_verification_token(tokens: List, user_id: Optional[str] = None)
)
verbose_proxy_logger.debug(traceback.format_exc())
raise e
- return deleted_tokens
+ return deleted_tokens, _keys_being_deleted
@router.post(
diff --git a/litellm/proxy/proxy_cli.py b/litellm/proxy/proxy_cli.py
index f9f8276c7..094828de1 100644
--- a/litellm/proxy/proxy_cli.py
+++ b/litellm/proxy/proxy_cli.py
@@ -265,7 +265,6 @@ def run_server( # noqa: PLR0915
ProxyConfig,
app,
load_aws_kms,
- load_aws_secret_manager,
load_from_azure_key_vault,
load_google_kms,
save_worker_config,
@@ -278,7 +277,6 @@ def run_server( # noqa: PLR0915
ProxyConfig,
app,
load_aws_kms,
- load_aws_secret_manager,
load_from_azure_key_vault,
load_google_kms,
save_worker_config,
@@ -295,7 +293,6 @@ def run_server( # noqa: PLR0915
ProxyConfig,
app,
load_aws_kms,
- load_aws_secret_manager,
load_from_azure_key_vault,
load_google_kms,
save_worker_config,
@@ -559,8 +556,14 @@ def run_server( # noqa: PLR0915
key_management_system
== KeyManagementSystem.AWS_SECRET_MANAGER.value # noqa: F405
):
+ from litellm.secret_managers.aws_secret_manager_v2 import (
+ AWSSecretsManagerV2,
+ )
+
### LOAD FROM AWS SECRET MANAGER ###
- load_aws_secret_manager(use_aws_secret_manager=True)
+ AWSSecretsManagerV2.load_aws_secret_manager(
+ use_aws_secret_manager=True
+ )
elif key_management_system == KeyManagementSystem.AWS_KMS.value:
load_aws_kms(use_aws_kms=True)
elif (
diff --git a/litellm/proxy/proxy_config.yaml b/litellm/proxy/proxy_config.yaml
index 29d14c910..71e3dee0e 100644
--- a/litellm/proxy/proxy_config.yaml
+++ b/litellm/proxy/proxy_config.yaml
@@ -7,6 +7,8 @@ model_list:
-litellm_settings:
- callbacks: ["gcs_bucket"]
-
+general_settings:
+ key_management_system: "aws_secret_manager"
+ key_management_settings:
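+    # store_virtual_keys: whether generated virtual keys should also be stored in the secret manager
+    # access_mode: one of "write_only", "read_only", "read_and_write"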
+ store_virtual_keys: true
+ access_mode: "write_only"
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index bbf4b0b93..92ca32e52 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -245,10 +245,7 @@ from litellm.router import (
from litellm.router import ModelInfo as RouterModelInfo
from litellm.router import updateDeployment
from litellm.scheduler import DefaultPriorities, FlowItem, Scheduler
-from litellm.secret_managers.aws_secret_manager import (
- load_aws_kms,
- load_aws_secret_manager,
-)
+from litellm.secret_managers.aws_secret_manager import load_aws_kms
from litellm.secret_managers.google_kms import load_google_kms
from litellm.secret_managers.main import (
get_secret,
@@ -1825,8 +1822,13 @@ class ProxyConfig:
key_management_system
== KeyManagementSystem.AWS_SECRET_MANAGER.value # noqa: F405
):
- ### LOAD FROM AWS SECRET MANAGER ###
- load_aws_secret_manager(use_aws_secret_manager=True)
+ from litellm.secret_managers.aws_secret_manager_v2 import (
+ AWSSecretsManagerV2,
+ )
+
+ AWSSecretsManagerV2.load_aws_secret_manager(
+ use_aws_secret_manager=True
+ )
elif key_management_system == KeyManagementSystem.AWS_KMS.value:
load_aws_kms(use_aws_kms=True)
elif (
diff --git a/litellm/rerank_api/main.py b/litellm/rerank_api/main.py
index a06aff135..9cc8a8c1d 100644
--- a/litellm/rerank_api/main.py
+++ b/litellm/rerank_api/main.py
@@ -8,7 +8,8 @@ from litellm._logging import verbose_logger
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.azure_ai.rerank import AzureAIRerank
from litellm.llms.cohere.rerank import CohereRerank
-from litellm.llms.together_ai.rerank import TogetherAIRerank
+from litellm.llms.jina_ai.rerank.handler import JinaAIRerank
+from litellm.llms.together_ai.rerank.handler import TogetherAIRerank
from litellm.secret_managers.main import get_secret
from litellm.types.rerank import RerankRequest, RerankResponse
from litellm.types.router import *
@@ -19,6 +20,7 @@ from litellm.utils import client, exception_type, supports_httpx_timeout
cohere_rerank = CohereRerank()
together_rerank = TogetherAIRerank()
azure_ai_rerank = AzureAIRerank()
+jina_ai_rerank = JinaAIRerank()
#################################################
@@ -247,7 +249,23 @@ def rerank(
api_key=api_key,
_is_async=_is_async,
)
+ elif _custom_llm_provider == "jina_ai":
+ if dynamic_api_key is None:
+ raise ValueError(
+ "Jina AI API key is required, please set 'JINA_AI_API_KEY' in your environment"
+ )
+ response = jina_ai_rerank.rerank(
+ model=model,
+ api_key=dynamic_api_key,
+ query=query,
+ documents=documents,
+ top_n=top_n,
+ rank_fields=rank_fields,
+ return_documents=return_documents,
+ max_chunks_per_doc=max_chunks_per_doc,
+ _is_async=_is_async,
+ )
else:
raise ValueError(f"Unsupported provider: {_custom_llm_provider}")
diff --git a/litellm/router.py b/litellm/router.py
index 4735d422b..97065bc85 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -679,9 +679,8 @@ class Router:
kwargs["model"] = model
kwargs["messages"] = messages
kwargs["original_function"] = self._completion
- kwargs.get("request_timeout", self.timeout)
- kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
- kwargs.setdefault("metadata", {}).update({"model_group": model})
+ self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs)
+
response = self.function_with_fallbacks(**kwargs)
return response
except Exception as e:
@@ -783,8 +782,7 @@ class Router:
kwargs["stream"] = stream
kwargs["original_function"] = self._acompletion
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
-
- kwargs.setdefault("metadata", {}).update({"model_group": model})
+ self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs)
request_priority = kwargs.get("priority") or self.default_priority
@@ -948,6 +946,17 @@ class Router:
self.fail_calls[model_name] += 1
raise e
+ def _update_kwargs_before_fallbacks(self, model: str, kwargs: dict) -> None:
+ """
+ Adds/updates to kwargs:
+ - num_retries
+ - litellm_trace_id
+ - metadata
+ """
+ kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
+ kwargs.setdefault("litellm_trace_id", str(uuid.uuid4()))
+ kwargs.setdefault("metadata", {}).update({"model_group": model})
+
def _update_kwargs_with_default_litellm_params(self, kwargs: dict) -> None:
"""
Adds default litellm params to kwargs, if set.
@@ -1511,9 +1520,7 @@ class Router:
kwargs["model"] = model
kwargs["file"] = file
kwargs["original_function"] = self._atranscription
- kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
- kwargs.get("request_timeout", self.timeout)
- kwargs.setdefault("metadata", {}).update({"model_group": model})
+ self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs)
response = await self.async_function_with_fallbacks(**kwargs)
return response
@@ -1688,9 +1695,7 @@ class Router:
kwargs["model"] = model
kwargs["input"] = input
kwargs["original_function"] = self._arerank
- kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
- kwargs.get("request_timeout", self.timeout)
- kwargs.setdefault("metadata", {}).update({"model_group": model})
+ self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs)
response = await self.async_function_with_fallbacks(**kwargs)
@@ -1839,9 +1844,7 @@ class Router:
kwargs["model"] = model
kwargs["prompt"] = prompt
kwargs["original_function"] = self._atext_completion
- kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
- kwargs.get("request_timeout", self.timeout)
- kwargs.setdefault("metadata", {}).update({"model_group": model})
+ self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs)
response = await self.async_function_with_fallbacks(**kwargs)
return response
@@ -2112,9 +2115,7 @@ class Router:
kwargs["model"] = model
kwargs["input"] = input
kwargs["original_function"] = self._aembedding
- kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
- kwargs.get("request_timeout", self.timeout)
- kwargs.setdefault("metadata", {}).update({"model_group": model})
+ self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs)
response = await self.async_function_with_fallbacks(**kwargs)
return response
except Exception as e:
@@ -2609,6 +2610,7 @@ class Router:
If it fails after num_retries, fall back to another model group
"""
model_group: Optional[str] = kwargs.get("model")
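+        # pop disable_fallbacks so it is not forwarded to the underlying LLM call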
+ disable_fallbacks: Optional[bool] = kwargs.pop("disable_fallbacks", False)
fallbacks: Optional[List] = kwargs.get("fallbacks", self.fallbacks)
context_window_fallbacks: Optional[List] = kwargs.get(
"context_window_fallbacks", self.context_window_fallbacks
@@ -2616,6 +2618,7 @@ class Router:
content_policy_fallbacks: Optional[List] = kwargs.get(
"content_policy_fallbacks", self.content_policy_fallbacks
)
+
try:
self._handle_mock_testing_fallbacks(
kwargs=kwargs,
@@ -2635,7 +2638,7 @@ class Router:
original_model_group: Optional[str] = kwargs.get("model") # type: ignore
fallback_failure_exception_str = ""
- if original_model_group is None:
+ if disable_fallbacks is True or original_model_group is None:
raise e
input_kwargs = {
diff --git a/litellm/secret_managers/aws_secret_manager.py b/litellm/secret_managers/aws_secret_manager.py
index f0e510fa8..fbe951e64 100644
--- a/litellm/secret_managers/aws_secret_manager.py
+++ b/litellm/secret_managers/aws_secret_manager.py
@@ -23,28 +23,6 @@ def validate_environment():
raise ValueError("Missing required environment variable - AWS_REGION_NAME")
-def load_aws_secret_manager(use_aws_secret_manager: Optional[bool]):
- if use_aws_secret_manager is None or use_aws_secret_manager is False:
- return
- try:
- import boto3
- from botocore.exceptions import ClientError
-
- validate_environment()
-
- # Create a Secrets Manager client
- session = boto3.session.Session() # type: ignore
- client = session.client(
- service_name="secretsmanager", region_name=os.getenv("AWS_REGION_NAME")
- )
-
- litellm.secret_manager_client = client
- litellm._key_management_system = KeyManagementSystem.AWS_SECRET_MANAGER
-
- except Exception as e:
- raise e
-
-
def load_aws_kms(use_aws_kms: Optional[bool]):
if use_aws_kms is None or use_aws_kms is False:
return
diff --git a/litellm/secret_managers/aws_secret_manager_v2.py b/litellm/secret_managers/aws_secret_manager_v2.py
new file mode 100644
index 000000000..69add6f23
--- /dev/null
+++ b/litellm/secret_managers/aws_secret_manager_v2.py
@@ -0,0 +1,310 @@
+"""
+This is a file for the AWS Secret Manager Integration
+
+Handles Async Operations for:
+- Read Secret
+- Write Secret
+- Delete Secret
+
+Relevant issue: https://github.com/BerriAI/litellm/issues/1883
+
+Requires:
+* `os.environ["AWS_REGION_NAME"]`
+* `pip install boto3>=1.28.57`
+"""
+
+import ast
+import asyncio
+import base64
+import json
+import os
+import re
+import sys
+from typing import Any, Dict, Optional, Union
+
+import httpx
+
+import litellm
+from litellm._logging import verbose_logger
+from litellm.llms.base_aws_llm import BaseAWSLLM
+from litellm.llms.custom_httpx.http_handler import (
+ _get_httpx_client,
+ get_async_httpx_client,
+)
+from litellm.llms.custom_httpx.types import httpxSpecialProvider
+from litellm.proxy._types import KeyManagementSystem
+
+
+class AWSSecretsManagerV2(BaseAWSLLM):
+ @classmethod
+ def validate_environment(cls):
+ if "AWS_REGION_NAME" not in os.environ:
+ raise ValueError("Missing required environment variable - AWS_REGION_NAME")
+
+ @classmethod
+ def load_aws_secret_manager(cls, use_aws_secret_manager: Optional[bool]):
+ """
+        Initializes AWSSecretsManagerV2, setting litellm.secret_manager_client = AWSSecretsManagerV2() and litellm._key_management_system = KeyManagementSystem.AWS_SECRET_MANAGER
+ """
+ if use_aws_secret_manager is None or use_aws_secret_manager is False:
+ return
+ try:
+ import boto3
+
+ cls.validate_environment()
+ litellm.secret_manager_client = cls()
+ litellm._key_management_system = KeyManagementSystem.AWS_SECRET_MANAGER
+
+ except Exception as e:
+ raise e
+
+ async def async_read_secret(
+ self,
+ secret_name: str,
+ optional_params: Optional[dict] = None,
+ timeout: Optional[Union[float, httpx.Timeout]] = None,
+ ) -> Optional[str]:
+ """
+ Async function to read a secret from AWS Secrets Manager
+
+ Returns:
+ str: Secret value
+ Raises:
+ ValueError: If the secret is not found or an HTTP error occurs
+ """
+ endpoint_url, headers, body = self._prepare_request(
+ action="GetSecretValue",
+ secret_name=secret_name,
+ optional_params=optional_params,
+ )
+
+ async_client = get_async_httpx_client(
+ llm_provider=httpxSpecialProvider.SecretManager,
+ params={"timeout": timeout},
+ )
+
+ try:
+ response = await async_client.post(
+ url=endpoint_url, headers=headers, data=body.decode("utf-8")
+ )
+ response.raise_for_status()
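+            # GetSecretValue returns the plaintext secret under the "SecretString" key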
+ return response.json()["SecretString"]
+ except httpx.TimeoutException:
+ raise ValueError("Timeout error occurred")
+ except Exception as e:
+ verbose_logger.exception(
+ "Error reading secret from AWS Secrets Manager: %s", str(e)
+ )
+ return None
+
+ def sync_read_secret(
+ self,
+ secret_name: str,
+ optional_params: Optional[dict] = None,
+ timeout: Optional[Union[float, httpx.Timeout]] = None,
+ ) -> Optional[str]:
+ """
+ Sync function to read a secret from AWS Secrets Manager
+
+ Done for backwards compatibility with existing codebase, since get_secret is a sync function
+ """
+
+        # self._prepare_request needs these env vars, so they cannot themselves be read from AWS Secrets Manager - doing so would cause an infinite loop
+ if secret_name in [
+ "AWS_ACCESS_KEY_ID",
+ "AWS_SECRET_ACCESS_KEY",
+ "AWS_REGION_NAME",
+ "AWS_REGION",
+ "AWS_BEDROCK_RUNTIME_ENDPOINT",
+ ]:
+ return os.getenv(secret_name)
+
+ endpoint_url, headers, body = self._prepare_request(
+ action="GetSecretValue",
+ secret_name=secret_name,
+ optional_params=optional_params,
+ )
+
+ sync_client = _get_httpx_client(
+ params={"timeout": timeout},
+ )
+
+ try:
+ response = sync_client.post(
+ url=endpoint_url, headers=headers, data=body.decode("utf-8")
+ )
+ response.raise_for_status()
+ return response.json()["SecretString"]
+ except httpx.TimeoutException:
+ raise ValueError("Timeout error occurred")
+ except Exception as e:
+ verbose_logger.exception(
+ "Error reading secret from AWS Secrets Manager: %s", str(e)
+ )
+ return None
+
+ async def async_write_secret(
+ self,
+ secret_name: str,
+ secret_value: str,
+ description: Optional[str] = None,
+ client_request_token: Optional[str] = None,
+ optional_params: Optional[dict] = None,
+ timeout: Optional[Union[float, httpx.Timeout]] = None,
+ ) -> dict:
+ """
+ Async function to write a secret to AWS Secrets Manager
+
+ Args:
+ secret_name: Name of the secret
+ secret_value: Value to store (can be a JSON string)
+ description: Optional description for the secret
+ client_request_token: Optional unique identifier to ensure idempotency
+ optional_params: Additional AWS parameters
+ timeout: Request timeout
+ """
+ import uuid
+
+ # Prepare the request data
+ data = {"Name": secret_name, "SecretString": secret_value}
+ if description:
+ data["Description"] = description
+
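+        # ClientRequestToken makes the CreateSecret call idempotent if it is retried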
+ data["ClientRequestToken"] = str(uuid.uuid4())
+
+ endpoint_url, headers, body = self._prepare_request(
+ action="CreateSecret",
+ secret_name=secret_name,
+ secret_value=secret_value,
+ optional_params=optional_params,
+ request_data=data, # Pass the complete request data
+ )
+
+ async_client = get_async_httpx_client(
+ llm_provider=httpxSpecialProvider.SecretManager,
+ params={"timeout": timeout},
+ )
+
+ try:
+ response = await async_client.post(
+ url=endpoint_url, headers=headers, data=body.decode("utf-8")
+ )
+ response.raise_for_status()
+ return response.json()
+ except httpx.HTTPStatusError as err:
+ raise ValueError(f"HTTP error occurred: {err.response.text}")
+ except httpx.TimeoutException:
+ raise ValueError("Timeout error occurred")
+
+ async def async_delete_secret(
+ self,
+ secret_name: str,
+ recovery_window_in_days: Optional[int] = 7,
+ optional_params: Optional[dict] = None,
+ timeout: Optional[Union[float, httpx.Timeout]] = None,
+ ) -> dict:
+ """
+ Async function to delete a secret from AWS Secrets Manager
+
+ Args:
+ secret_name: Name of the secret to delete
+ recovery_window_in_days: Number of days before permanent deletion (default: 7)
+ optional_params: Additional AWS parameters
+ timeout: Request timeout
+
+ Returns:
+ dict: Response from AWS Secrets Manager containing deletion details
+ """
+ # Prepare the request data
+ data = {
+ "SecretId": secret_name,
+ "RecoveryWindowInDays": recovery_window_in_days,
+ }
+
+ endpoint_url, headers, body = self._prepare_request(
+ action="DeleteSecret",
+ secret_name=secret_name,
+ optional_params=optional_params,
+ request_data=data,
+ )
+
+ async_client = get_async_httpx_client(
+ llm_provider=httpxSpecialProvider.SecretManager,
+ params={"timeout": timeout},
+ )
+
+ try:
+ response = await async_client.post(
+ url=endpoint_url, headers=headers, data=body.decode("utf-8")
+ )
+ response.raise_for_status()
+ return response.json()
+ except httpx.HTTPStatusError as err:
+ raise ValueError(f"HTTP error occurred: {err.response.text}")
+ except httpx.TimeoutException:
+ raise ValueError("Timeout error occurred")
+
+ def _prepare_request(
+ self,
+        action: str,  # e.g. "GetSecretValue", "CreateSecret", "DeleteSecret"
+ secret_name: str,
+ secret_value: Optional[str] = None,
+ optional_params: Optional[dict] = None,
+ request_data: Optional[dict] = None,
+ ) -> tuple[str, Any, bytes]:
+ """Prepare the AWS Secrets Manager request"""
+ try:
+ import boto3
+ from botocore.auth import SigV4Auth
+ from botocore.awsrequest import AWSRequest
+ from botocore.credentials import Credentials
+ except ImportError:
+            raise ImportError("Missing boto3 to call AWS Secrets Manager. Run 'pip install boto3'.")
+ optional_params = optional_params or {}
+ boto3_credentials_info = self._get_boto_credentials_from_optional_params(
+ optional_params
+ )
+
+ # Get endpoint
+ _, endpoint_url = self.get_runtime_endpoint(
+ api_base=None,
+ aws_bedrock_runtime_endpoint=boto3_credentials_info.aws_bedrock_runtime_endpoint,
+ aws_region_name=boto3_credentials_info.aws_region_name,
+ )
+ endpoint_url = endpoint_url.replace("bedrock-runtime", "secretsmanager")
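+        # get_runtime_endpoint builds a Bedrock runtime URL, so swap the service name to hit the regional Secrets Manager endpoint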
+
+ # Use provided request_data if available, otherwise build default data
+ if request_data:
+ data = request_data
+ else:
+ data = {"SecretId": secret_name}
+ if secret_value and action == "PutSecretValue":
+ data["SecretString"] = secret_value
+
+ body = json.dumps(data).encode("utf-8")
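+        # Secrets Manager uses the AWS JSON 1.1 protocol: the operation is selected via the X-Amz-Target header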
+ headers = {
+ "Content-Type": "application/x-amz-json-1.1",
+ "X-Amz-Target": f"secretsmanager.{action}",
+ }
+
+ # Sign request
+ request = AWSRequest(
+ method="POST", url=endpoint_url, data=body, headers=headers
+ )
+ SigV4Auth(
+ boto3_credentials_info.credentials,
+ "secretsmanager",
+ boto3_credentials_info.aws_region_name,
+ ).add_auth(request)
+ prepped = request.prepare()
+
+ return endpoint_url, prepped.headers, body
+
+
+# if __name__ == "__main__":
+# print("loading aws secret manager v2")
+# aws_secret_manager_v2 = AWSSecretsManagerV2()
+
+# print("writing secret to aws secret manager v2")
+# asyncio.run(aws_secret_manager_v2.async_write_secret(secret_name="test_secret_3", secret_value="test_value_2"))
+# print("reading secret from aws secret manager v2")
diff --git a/litellm/secret_managers/main.py b/litellm/secret_managers/main.py
index f3d6d420a..ce6d30755 100644
--- a/litellm/secret_managers/main.py
+++ b/litellm/secret_managers/main.py
@@ -5,7 +5,7 @@ import json
import os
import sys
import traceback
-from typing import Any, Optional, Union
+from typing import TYPE_CHECKING, Any, Optional, Union
import httpx
from dotenv import load_dotenv
@@ -198,7 +198,10 @@ def get_secret( # noqa: PLR0915
raise ValueError("Unsupported OIDC provider")
try:
- if litellm.secret_manager_client is not None:
+ if (
+ _should_read_secret_from_secret_manager()
+ and litellm.secret_manager_client is not None
+ ):
try:
client = litellm.secret_manager_client
key_manager = "local"
@@ -207,7 +210,8 @@ def get_secret( # noqa: PLR0915
if key_management_settings is not None:
if (
- secret_name not in key_management_settings.hosted_keys
+ key_management_settings.hosted_keys is not None
+ and secret_name not in key_management_settings.hosted_keys
): # allow user to specify which keys to check in hosted key manager
key_manager = "local"
@@ -268,25 +272,13 @@ def get_secret( # noqa: PLR0915
if isinstance(secret, str):
secret = secret.strip()
elif key_manager == KeyManagementSystem.AWS_SECRET_MANAGER.value:
- try:
- get_secret_value_response = client.get_secret_value(
- SecretId=secret_name
- )
- print_verbose(
- f"get_secret_value_response: {get_secret_value_response}"
- )
- except Exception as e:
- print_verbose(f"An error occurred - {str(e)}")
- # For a list of exceptions thrown, see
- # https://docs.aws.amazon.com/secretsmanager/latest/apireference/API_GetSecretValue.html
- raise e
+ from litellm.secret_managers.aws_secret_manager_v2 import (
+ AWSSecretsManagerV2,
+ )
- # assume there is 1 secret per secret_name
- secret_dict = json.loads(get_secret_value_response["SecretString"])
- print_verbose(f"secret_dict: {secret_dict}")
- for k, v in secret_dict.items():
- secret = v
- print_verbose(f"secret: {secret}")
+ if isinstance(client, AWSSecretsManagerV2):
+ secret = client.sync_read_secret(secret_name=secret_name)
+ print_verbose(f"get_secret_value_response: {secret}")
elif key_manager == KeyManagementSystem.GOOGLE_SECRET_MANAGER.value:
try:
secret = client.get_secret_from_google_secret_manager(
@@ -332,3 +324,21 @@ def get_secret( # noqa: PLR0915
return default_value
else:
raise e
+
+
+def _should_read_secret_from_secret_manager() -> bool:
+ """
+ Returns True if the secret manager should be used to read the secret, False otherwise
+
+ - If the secret manager client is not set, return False
+ - If the `_key_management_settings` access mode is "read_only" or "read_and_write", return True
+ - Otherwise, return False
+ """
+ if litellm.secret_manager_client is not None:
+ if litellm._key_management_settings is not None:
+ if (
+ litellm._key_management_settings.access_mode == "read_only"
+ or litellm._key_management_settings.access_mode == "read_and_write"
+ ):
+ return True
+ return False
diff --git a/litellm/types/rerank.py b/litellm/types/rerank.py
index d016021fb..00b07ba13 100644
--- a/litellm/types/rerank.py
+++ b/litellm/types/rerank.py
@@ -7,6 +7,7 @@ https://docs.cohere.com/reference/rerank
from typing import List, Optional, Union
from pydantic import BaseModel, PrivateAttr
+from typing_extensions import TypedDict
class RerankRequest(BaseModel):
@@ -19,10 +20,26 @@ class RerankRequest(BaseModel):
max_chunks_per_doc: Optional[int] = None
+class RerankBilledUnits(TypedDict, total=False):
+ search_units: int
+ total_tokens: int
+
+
+class RerankTokens(TypedDict, total=False):
+ input_tokens: int
+ output_tokens: int
+
+
+class RerankResponseMeta(TypedDict, total=False):
+ api_version: dict
+ billed_units: RerankBilledUnits
+ tokens: RerankTokens
+
+
class RerankResponse(BaseModel):
id: str
results: List[dict] # Contains index and relevance_score
- meta: Optional[dict] = None # Contains api_version and billed_units
+ meta: Optional[RerankResponseMeta] = None # Contains api_version and billed_units
# Define private attributes using PrivateAttr
_hidden_params: dict = PrivateAttr(default_factory=dict)
diff --git a/litellm/types/router.py b/litellm/types/router.py
index 6119ca4b7..bb93aaa63 100644
--- a/litellm/types/router.py
+++ b/litellm/types/router.py
@@ -150,6 +150,8 @@ class GenericLiteLLMParams(BaseModel):
max_retries: Optional[int] = None
organization: Optional[str] = None # for openai orgs
configurable_clientside_auth_params: CONFIGURABLE_CLIENTSIDE_AUTH_PARAMS = None
+ ## LOGGING PARAMS ##
+ litellm_trace_id: Optional[str] = None
## UNIFIED PROJECT/REGION ##
region_name: Optional[str] = None
## VERTEX AI ##
@@ -186,6 +188,8 @@ class GenericLiteLLMParams(BaseModel):
None # timeout when making stream=True calls, if str, pass in as os.environ/
),
organization: Optional[str] = None, # for openai orgs
+ ## LOGGING PARAMS ##
+ litellm_trace_id: Optional[str] = None,
## UNIFIED PROJECT/REGION ##
region_name: Optional[str] = None,
## VERTEX AI ##
diff --git a/litellm/types/utils.py b/litellm/types/utils.py
index e3df357be..d02129681 100644
--- a/litellm/types/utils.py
+++ b/litellm/types/utils.py
@@ -1334,6 +1334,7 @@ class ResponseFormatChunk(TypedDict, total=False):
all_litellm_params = [
"metadata",
+ "litellm_trace_id",
"tags",
"acompletion",
"aimg_generation",
@@ -1523,6 +1524,7 @@ StandardLoggingPayloadStatus = Literal["success", "failure"]
class StandardLoggingPayload(TypedDict):
id: str
+ trace_id: str # Trace multiple LLM calls belonging to same overall request (e.g. fallbacks/retries)
call_type: str
response_cost: float
response_cost_failure_debug_info: Optional[
diff --git a/litellm/utils.py b/litellm/utils.py
index a0f544312..f4f31e6cf 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -527,6 +527,7 @@ def function_setup( # noqa: PLR0915
messages=messages,
stream=stream,
litellm_call_id=kwargs["litellm_call_id"],
+ litellm_trace_id=kwargs.get("litellm_trace_id"),
function_id=function_id or "",
call_type=call_type,
start_time=start_time,
@@ -2056,6 +2057,7 @@ def get_litellm_params(
azure_ad_token_provider=None,
user_continue_message=None,
base_model=None,
+ litellm_trace_id=None,
):
litellm_params = {
"acompletion": acompletion,
@@ -2084,6 +2086,7 @@ def get_litellm_params(
"user_continue_message": user_continue_message,
"base_model": base_model
or _get_base_model_from_litellm_call_metadata(metadata=metadata),
+ "litellm_trace_id": litellm_trace_id,
}
return litellm_params
diff --git a/model_prices_and_context_window.json b/model_prices_and_context_window.json
index fb8fb105c..cae3bee12 100644
--- a/model_prices_and_context_window.json
+++ b/model_prices_and_context_window.json
@@ -2986,19 +2986,19 @@
"supports_function_calling": true
},
"vertex_ai/imagegeneration@006": {
- "cost_per_image": 0.020,
+ "output_cost_per_image": 0.020,
"litellm_provider": "vertex_ai-image-models",
"mode": "image_generation",
"source": "https://cloud.google.com/vertex-ai/generative-ai/pricing"
},
"vertex_ai/imagen-3.0-generate-001": {
- "cost_per_image": 0.04,
+ "output_cost_per_image": 0.04,
"litellm_provider": "vertex_ai-image-models",
"mode": "image_generation",
"source": "https://cloud.google.com/vertex-ai/generative-ai/pricing"
},
"vertex_ai/imagen-3.0-fast-generate-001": {
- "cost_per_image": 0.02,
+ "output_cost_per_image": 0.02,
"litellm_provider": "vertex_ai-image-models",
"mode": "image_generation",
"source": "https://cloud.google.com/vertex-ai/generative-ai/pricing"
@@ -5620,6 +5620,13 @@
"litellm_provider": "bedrock",
"mode": "image_generation"
},
+ "stability.stable-image-ultra-v1:0": {
+ "max_tokens": 77,
+ "max_input_tokens": 77,
+ "output_cost_per_image": 0.14,
+ "litellm_provider": "bedrock",
+ "mode": "image_generation"
+ },
"sagemaker/meta-textgeneration-llama-2-7b": {
"max_tokens": 4096,
"max_input_tokens": 4096,
diff --git a/pyproject.toml b/pyproject.toml
index 17d37c0ce..fedfebc4c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "litellm"
-version = "1.52.6"
+version = "1.52.9"
description = "Library to easily interface with LLM API providers"
authors = ["BerriAI"]
license = "MIT"
@@ -91,7 +91,7 @@ requires = ["poetry-core", "wheel"]
build-backend = "poetry.core.masonry.api"
[tool.commitizen]
-version = "1.52.6"
+version = "1.52.9"
version_files = [
"pyproject.toml:^version"
]
diff --git a/tests/llm_translation/base_llm_unit_tests.py b/tests/llm_translation/base_llm_unit_tests.py
index acb764ba1..955eed957 100644
--- a/tests/llm_translation/base_llm_unit_tests.py
+++ b/tests/llm_translation/base_llm_unit_tests.py
@@ -13,8 +13,11 @@ sys.path.insert(
import litellm
from litellm.exceptions import BadRequestError
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
-from litellm.utils import CustomStreamWrapper
-
+from litellm.utils import (
+ CustomStreamWrapper,
+ get_supported_openai_params,
+ get_optional_params,
+)
# test_example.py
from abc import ABC, abstractmethod
@@ -45,6 +48,9 @@ class BaseLLMChatTest(ABC):
)
assert response is not None
+ # for OpenAI the content contains the JSON schema, so we need to assert that the content is not None
+ assert response.choices[0].message.content is not None
+
def test_message_with_name(self):
base_completion_call_args = self.get_base_completion_call_args()
messages = [
@@ -79,6 +85,49 @@ class BaseLLMChatTest(ABC):
print(response)
+ # OpenAI guarantees that the JSON schema is returned in the content
+ # relevant issue: https://github.com/BerriAI/litellm/issues/6741
+ assert response.choices[0].message.content is not None
+
+ def test_json_response_format_stream(self):
+ """
+ Test that the JSON response format with streaming is supported by the LLM API
+ """
+ base_completion_call_args = self.get_base_completion_call_args()
+ litellm.set_verbose = True
+
+ messages = [
+ {
+ "role": "system",
+ "content": "Your output should be a JSON object with no additional properties. ",
+ },
+ {
+ "role": "user",
+ "content": "Respond with this in json. city=San Francisco, state=CA, weather=sunny, temp=60",
+ },
+ ]
+
+ response = litellm.completion(
+ **base_completion_call_args,
+ messages=messages,
+ response_format={"type": "json_object"},
+ stream=True,
+ )
+
+ print(response)
+
+ content = ""
+ for chunk in response:
+ content += chunk.choices[0].delta.content or ""
+
+ print("content=", content)
+
+ # OpenAI guarantees that the JSON schema is returned in the content
+ # relevant issue: https://github.com/BerriAI/litellm/issues/6741
+ # we need to assert that the JSON schema was returned in the content, (for Anthropic we were returning it as part of the tool call)
+ assert content is not None
+ assert len(content) > 0
+
@pytest.fixture
def pdf_messages(self):
import base64
diff --git a/tests/llm_translation/base_rerank_unit_tests.py b/tests/llm_translation/base_rerank_unit_tests.py
new file mode 100644
index 000000000..2a8b80194
--- /dev/null
+++ b/tests/llm_translation/base_rerank_unit_tests.py
@@ -0,0 +1,115 @@
+import asyncio
+import httpx
+import json
+import pytest
+import sys
+from typing import Any, Dict, List
+from unittest.mock import MagicMock, Mock, patch
+import os
+
+sys.path.insert(
+ 0, os.path.abspath("../..")
+) # Adds the parent directory to the system path
+import litellm
+from litellm.exceptions import BadRequestError
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+from litellm.utils import (
+ CustomStreamWrapper,
+ get_supported_openai_params,
+ get_optional_params,
+)
+
+from abc import ABC, abstractmethod
+
+
+def assert_response_shape(response, custom_llm_provider):
+ expected_response_shape = {"id": str, "results": list, "meta": dict}
+
+ expected_results_shape = {"index": int, "relevance_score": float}
+
+ expected_meta_shape = {"api_version": dict, "billed_units": dict}
+
+ expected_api_version_shape = {"version": str}
+
+ expected_billed_units_shape = {"search_units": int}
+
+ assert isinstance(response.id, expected_response_shape["id"])
+ assert isinstance(response.results, expected_response_shape["results"])
+ for result in response.results:
+ assert isinstance(result["index"], expected_results_shape["index"])
+ assert isinstance(
+ result["relevance_score"], expected_results_shape["relevance_score"]
+ )
+ assert isinstance(response.meta, expected_response_shape["meta"])
+
+ if custom_llm_provider == "cohere":
+
+ assert isinstance(
+ response.meta["api_version"], expected_meta_shape["api_version"]
+ )
+ assert isinstance(
+ response.meta["api_version"]["version"],
+ expected_api_version_shape["version"],
+ )
+ assert isinstance(
+ response.meta["billed_units"], expected_meta_shape["billed_units"]
+ )
+ assert isinstance(
+ response.meta["billed_units"]["search_units"],
+ expected_billed_units_shape["search_units"],
+ )
+
+
+class BaseLLMRerankTest(ABC):
+ """
+ Abstract base test class that enforces a common test across all test classes.
+ """
+
+ @abstractmethod
+ def get_base_rerank_call_args(self) -> dict:
+ """Must return the base rerank call args"""
+ pass
+
+ @abstractmethod
+ def get_custom_llm_provider(self) -> litellm.LlmProviders:
+ """Must return the custom llm provider"""
+ pass
+
+ @pytest.mark.asyncio()
+ @pytest.mark.parametrize("sync_mode", [True, False])
+ async def test_basic_rerank(self, sync_mode):
+ rerank_call_args = self.get_base_rerank_call_args()
+ custom_llm_provider = self.get_custom_llm_provider()
+ if sync_mode is True:
+ response = litellm.rerank(
+ **rerank_call_args,
+ query="hello",
+ documents=["hello", "world"],
+ top_n=3,
+ )
+
+ print("re rank response: ", response)
+
+ assert response.id is not None
+ assert response.results is not None
+
+ assert_response_shape(
+ response=response, custom_llm_provider=custom_llm_provider.value
+ )
+ else:
+ response = await litellm.arerank(
+ **rerank_call_args,
+ query="hello",
+ documents=["hello", "world"],
+ top_n=3,
+ )
+
+ print("async re rank response: ", response)
+
+ assert response.id is not None
+ assert response.results is not None
+
+ assert_response_shape(
+ response=response, custom_llm_provider=custom_llm_provider.value
+ )
diff --git a/tests/llm_translation/test_anthropic_completion.py b/tests/llm_translation/test_anthropic_completion.py
index c399c3a47..8a788e0fb 100644
--- a/tests/llm_translation/test_anthropic_completion.py
+++ b/tests/llm_translation/test_anthropic_completion.py
@@ -33,8 +33,10 @@ from litellm import (
)
from litellm.adapters.anthropic_adapter import anthropic_adapter
from litellm.types.llms.anthropic import AnthropicResponse
-
+from litellm.types.utils import GenericStreamingChunk, ChatCompletionToolCallChunk
+from litellm.types.llms.openai import ChatCompletionToolCallFunctionChunk
from litellm.llms.anthropic.common_utils import process_anthropic_headers
+from litellm.llms.anthropic.chat.handler import AnthropicChatCompletion
from httpx import Headers
from base_llm_unit_tests import BaseLLMChatTest
@@ -694,3 +696,91 @@ class TestAnthropicCompletion(BaseLLMChatTest):
assert _document_validation["type"] == "document"
assert _document_validation["source"]["media_type"] == "application/pdf"
assert _document_validation["source"]["type"] == "base64"
+
+
+def test_convert_tool_response_to_message_with_values():
+ """Test converting a tool response with 'values' key to a message"""
+ tool_calls = [
+ ChatCompletionToolCallChunk(
+ id="test_id",
+ type="function",
+ function=ChatCompletionToolCallFunctionChunk(
+ name="json_tool_call",
+ arguments='{"values": {"name": "John", "age": 30}}',
+ ),
+ index=0,
+ )
+ ]
+
+ message = AnthropicChatCompletion._convert_tool_response_to_message(
+ tool_calls=tool_calls
+ )
+
+ assert message is not None
+ assert message.content == '{"name": "John", "age": 30}'
+
+
+def test_convert_tool_response_to_message_without_values():
+ """
+ Test converting a tool response without 'values' key to a message
+
+ Anthropic API returns the JSON schema in the tool call, OpenAI Spec expects it in the message. This test ensures that the tool call is converted to a message correctly.
+
+ Relevant issue: https://github.com/BerriAI/litellm/issues/6741
+ """
+ tool_calls = [
+ ChatCompletionToolCallChunk(
+ id="test_id",
+ type="function",
+ function=ChatCompletionToolCallFunctionChunk(
+ name="json_tool_call", arguments='{"name": "John", "age": 30}'
+ ),
+ index=0,
+ )
+ ]
+
+ message = AnthropicChatCompletion._convert_tool_response_to_message(
+ tool_calls=tool_calls
+ )
+
+ assert message is not None
+ assert message.content == '{"name": "John", "age": 30}'
+
+
+def test_convert_tool_response_to_message_invalid_json():
+ """Test converting a tool response with invalid JSON"""
+ tool_calls = [
+ ChatCompletionToolCallChunk(
+ id="test_id",
+ type="function",
+ function=ChatCompletionToolCallFunctionChunk(
+ name="json_tool_call", arguments="invalid json"
+ ),
+ index=0,
+ )
+ ]
+
+ message = AnthropicChatCompletion._convert_tool_response_to_message(
+ tool_calls=tool_calls
+ )
+
+ assert message is not None
+ assert message.content == "invalid json"
+
+
+def test_convert_tool_response_to_message_no_arguments():
+ """Test converting a tool response with no arguments"""
+ tool_calls = [
+ ChatCompletionToolCallChunk(
+ id="test_id",
+ type="function",
+ function=ChatCompletionToolCallFunctionChunk(name="json_tool_call"),
+ index=0,
+ )
+ ]
+
+ message = AnthropicChatCompletion._convert_tool_response_to_message(
+ tool_calls=tool_calls
+ )
+
+ assert message is None
diff --git a/tests/llm_translation/test_jina_ai.py b/tests/llm_translation/test_jina_ai.py
new file mode 100644
index 000000000..c169b5587
--- /dev/null
+++ b/tests/llm_translation/test_jina_ai.py
@@ -0,0 +1,23 @@
+import json
+import os
+import sys
+from datetime import datetime
+from unittest.mock import AsyncMock
+
+sys.path.insert(
+ 0, os.path.abspath("../..")
+) # Adds the parent directory to the system path
+
+
+from base_rerank_unit_tests import BaseLLMRerankTest
+import litellm
+
+
+class TestJinaAI(BaseLLMRerankTest):
+ def get_custom_llm_provider(self) -> litellm.LlmProviders:
+ return litellm.LlmProviders.JINA_AI
+
+ def get_base_rerank_call_args(self) -> dict:
+ return {
+ "model": "jina_ai/jina-reranker-v2-base-multilingual",
+ }
diff --git a/tests/llm_translation/test_optional_params.py b/tests/llm_translation/test_optional_params.py
index 8677d6b73..c9527c830 100644
--- a/tests/llm_translation/test_optional_params.py
+++ b/tests/llm_translation/test_optional_params.py
@@ -923,9 +923,22 @@ def test_watsonx_text_top_k():
assert optional_params["top_k"] == 10
+
def test_together_ai_model_params():
optional_params = get_optional_params(
model="together_ai", custom_llm_provider="together_ai", logprobs=1
)
print(optional_params)
assert optional_params["logprobs"] == 1
+
+def test_forward_user_param():
+ from litellm.utils import get_supported_openai_params, get_optional_params
+
+ model = "claude-3-5-sonnet-20240620"
+ optional_params = get_optional_params(
+ model=model,
+ user="test_user",
+ custom_llm_provider="anthropic",
+ )
+
+ assert optional_params["metadata"]["user_id"] == "test_user"
diff --git a/tests/llm_translation/test_vertex.py b/tests/llm_translation/test_vertex.py
index a06179a49..73960020d 100644
--- a/tests/llm_translation/test_vertex.py
+++ b/tests/llm_translation/test_vertex.py
@@ -16,6 +16,7 @@ import pytest
import litellm
from litellm import get_optional_params
from litellm.llms.custom_httpx.http_handler import HTTPHandler
+import httpx
def test_completion_pydantic_obj_2():
@@ -1317,3 +1318,39 @@ def test_image_completion_request(image_url):
mock_post.assert_called_once()
print("mock_post.call_args.kwargs['json']", mock_post.call_args.kwargs["json"])
assert mock_post.call_args.kwargs["json"] == expected_request_body
+
+
+@pytest.mark.parametrize(
+ "model, expected_url",
+ [
+ (
+ "textembedding-gecko@001",
+ "https://us-central1-aiplatform.googleapis.com/v1/projects/project-id/locations/us-central1/publishers/google/models/textembedding-gecko@001:predict",
+ ),
+ (
+ "123456789",
+ "https://us-central1-aiplatform.googleapis.com/v1/projects/project-id/locations/us-central1/endpoints/123456789:predict",
+ ),
+ ],
+)
+def test_vertex_embedding_url(model, expected_url):
+ """
+    Test URL generation for embedding models, including numeric model IDs (fine-tuned models).
+
+ Relevant issue: https://github.com/BerriAI/litellm/issues/6482
+
+ When a fine-tuned embedding model is used, the URL is different from the standard one.
+ """
+ from litellm.llms.vertex_ai_and_google_ai_studio.common_utils import _get_vertex_url
+
+ url, endpoint = _get_vertex_url(
+ mode="embedding",
+ model=model,
+ stream=False,
+ vertex_project="project-id",
+ vertex_location="us-central1",
+ vertex_api_version="v1",
+ )
+
+ assert url == expected_url
+ assert endpoint == "predict"
diff --git a/tests/local_testing/test_amazing_vertex_completion.py b/tests/local_testing/test_amazing_vertex_completion.py
index 2de53696f..5a07d17b7 100644
--- a/tests/local_testing/test_amazing_vertex_completion.py
+++ b/tests/local_testing/test_amazing_vertex_completion.py
@@ -18,6 +18,8 @@ import json
import os
import tempfile
from unittest.mock import AsyncMock, MagicMock, patch
+from respx import MockRouter
+import httpx
import pytest
@@ -973,6 +975,7 @@ async def test_partner_models_httpx(model, sync_mode):
data = {
"model": model,
"messages": messages,
+ "timeout": 10,
}
if sync_mode:
response = litellm.completion(**data)
@@ -986,6 +989,8 @@ async def test_partner_models_httpx(model, sync_mode):
assert isinstance(response._hidden_params["response_cost"], float)
except litellm.RateLimitError as e:
pass
+ except litellm.Timeout as e:
+ pass
except litellm.InternalServerError as e:
pass
except Exception as e:
@@ -3051,3 +3056,70 @@ def test_custom_api_base(api_base):
assert url == api_base + ":"
else:
assert url == test_endpoint
+
+
+@pytest.mark.asyncio
+@pytest.mark.respx
+async def test_vertexai_embedding_finetuned(respx_mock: MockRouter):
+ """
+ Tests that:
+ - Request URL and body are correctly formatted for Vertex AI embeddings
+ - Response is properly parsed into litellm's embedding response format
+ """
+ load_vertex_ai_credentials()
+ litellm.set_verbose = True
+
+ # Test input
+ input_text = ["good morning from litellm", "this is another item"]
+
+ # Expected request/response
+ expected_url = "https://us-central1-aiplatform.googleapis.com/v1/projects/633608382793/locations/us-central1/endpoints/1004708436694269952:predict"
+ expected_request = {
+ "instances": [
+ {"inputs": "good morning from litellm"},
+ {"inputs": "this is another item"},
+ ],
+ "parameters": {},
+ }
+
+ mock_response = {
+ "predictions": [
+ [[-0.000431762, -0.04416759, -0.03443353]], # Truncated embedding vector
+ [[-0.000431762, -0.04416759, -0.03443353]], # Truncated embedding vector
+ ],
+ "deployedModelId": "2275167734310371328",
+ "model": "projects/633608382793/locations/us-central1/models/snowflake-arctic-embed-m-long-1731622468876",
+ "modelDisplayName": "snowflake-arctic-embed-m-long-1731622468876",
+ "modelVersionId": "1",
+ }
+
+ # Setup mock request
+ mock_request = respx_mock.post(expected_url).mock(
+ return_value=httpx.Response(200, json=mock_response)
+ )
+
+ # Make request
+ response = await litellm.aembedding(
+ vertex_project="633608382793",
+ model="vertex_ai/1004708436694269952",
+ input=input_text,
+ )
+
+ # Assert request was made correctly
+ assert mock_request.called
+ request_body = json.loads(mock_request.calls[0].request.content)
+ print("\n\nrequest_body", request_body)
+ print("\n\nexpected_request", expected_request)
+ assert request_body == expected_request
+
+ # Assert response structure
+ assert response is not None
+ assert hasattr(response, "data")
+ assert len(response.data) == len(input_text)
+
+ # Assert embedding structure
+ for embedding in response.data:
+ assert "embedding" in embedding
+ assert isinstance(embedding["embedding"], list)
+ assert len(embedding["embedding"]) > 0
+ assert all(isinstance(x, float) for x in embedding["embedding"])
diff --git a/tests/local_testing/test_aws_secret_manager.py b/tests/local_testing/test_aws_secret_manager.py
new file mode 100644
index 000000000..f2e2319cc
--- /dev/null
+++ b/tests/local_testing/test_aws_secret_manager.py
@@ -0,0 +1,139 @@
+# What is this?
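+## Unit tests for the AWS Secrets Manager v2 integration (async write / read / delete of secrets)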
+
+import asyncio
+import os
+import sys
+import traceback
+
+from dotenv import load_dotenv
+
+import litellm.types
+import litellm.types.utils
+
+
+load_dotenv()
+import io
+
+# Ensure the project root is in the Python path
+sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../../..")))
+
+print("Python Path:", sys.path)
+print("Current Working Directory:", os.getcwd())
+
+
+from typing import Optional
+from unittest.mock import MagicMock, patch
+
+import pytest
+import uuid
+import json
+from litellm.secret_managers.aws_secret_manager_v2 import AWSSecretsManagerV2
+
+
+def check_aws_credentials():
+ """Helper function to check if AWS credentials are set"""
+ required_vars = ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_REGION_NAME"]
+ missing_vars = [var for var in required_vars if not os.getenv(var)]
+ if missing_vars:
+ pytest.skip(f"Missing required AWS credentials: {', '.join(missing_vars)}")
+
+
+@pytest.mark.asyncio
+async def test_write_and_read_simple_secret():
+ """Test writing and reading a simple string secret"""
+ check_aws_credentials()
+
+ secret_manager = AWSSecretsManagerV2()
+ test_secret_name = f"litellm_test_{uuid.uuid4().hex[:8]}"
+ test_secret_value = "test_value_123"
+
+ try:
+ # Write secret
+ write_response = await secret_manager.async_write_secret(
+ secret_name=test_secret_name,
+ secret_value=test_secret_value,
+ description="LiteLLM Test Secret",
+ )
+
+ print("Write Response:", write_response)
+
+ assert write_response is not None
+ assert "ARN" in write_response
+ assert "Name" in write_response
+ assert write_response["Name"] == test_secret_name
+
+ # Read secret back
+ read_value = await secret_manager.async_read_secret(
+ secret_name=test_secret_name
+ )
+
+ print("Read Value:", read_value)
+
+ assert read_value == test_secret_value
+ finally:
+ # Cleanup: Delete the secret
+ delete_response = await secret_manager.async_delete_secret(
+ secret_name=test_secret_name
+ )
+ print("Delete Response:", delete_response)
+ assert delete_response is not None
+
+
+@pytest.mark.asyncio
+async def test_write_and_read_json_secret():
+ """Test writing and reading a JSON structured secret"""
+ check_aws_credentials()
+
+ secret_manager = AWSSecretsManagerV2()
+ test_secret_name = f"litellm_test_{uuid.uuid4().hex[:8]}_json"
+ test_secret_value = {
+ "api_key": "test_key",
+ "model": "gpt-4",
+ "temperature": 0.7,
+ "metadata": {"team": "ml", "project": "litellm"},
+ }
+
+ try:
+ # Write JSON secret
+ write_response = await secret_manager.async_write_secret(
+ secret_name=test_secret_name,
+ secret_value=json.dumps(test_secret_value),
+ description="LiteLLM JSON Test Secret",
+ )
+
+ print("Write Response:", write_response)
+
+ # Read and parse JSON secret
+ read_value = await secret_manager.async_read_secret(
+ secret_name=test_secret_name
+ )
+ parsed_value = json.loads(read_value)
+
+ print("Read Value:", read_value)
+
+ assert parsed_value == test_secret_value
+ assert parsed_value["api_key"] == "test_key"
+ assert parsed_value["metadata"]["team"] == "ml"
+ finally:
+ # Cleanup: Delete the secret
+ delete_response = await secret_manager.async_delete_secret(
+ secret_name=test_secret_name
+ )
+ print("Delete Response:", delete_response)
+ assert delete_response is not None
+
+
+@pytest.mark.asyncio
+async def test_read_nonexistent_secret():
+ """Test reading a secret that doesn't exist"""
+ check_aws_credentials()
+
+ secret_manager = AWSSecretsManagerV2()
+ nonexistent_secret = f"litellm_nonexistent_{uuid.uuid4().hex}"
+
+ response = await secret_manager.async_read_secret(secret_name=nonexistent_secret)
+
+ assert response is None
diff --git a/tests/local_testing/test_completion.py b/tests/local_testing/test_completion.py
index 43bcfc882..3ce4cb7d7 100644
--- a/tests/local_testing/test_completion.py
+++ b/tests/local_testing/test_completion.py
@@ -24,7 +24,7 @@ from litellm import RateLimitError, Timeout, completion, completion_cost, embedd
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.llms.prompt_templates.factory import anthropic_messages_pt
-# litellm.num_retries = 3
+# litellm.num_retries=3
litellm.cache = None
litellm.success_callback = []
diff --git a/tests/local_testing/test_cost_calc.py b/tests/local_testing/test_cost_calc.py
index ecead0679..1831c2a45 100644
--- a/tests/local_testing/test_cost_calc.py
+++ b/tests/local_testing/test_cost_calc.py
@@ -10,7 +10,7 @@ import os
sys.path.insert(
0, os.path.abspath("../..")
-) # Adds the parent directory to the system path
+) # Adds the parent directory to the system-path
from typing import Literal
import pytest
diff --git a/tests/local_testing/test_custom_callback_input.py b/tests/local_testing/test_custom_callback_input.py
index 1744d3891..9b7b6d532 100644
--- a/tests/local_testing/test_custom_callback_input.py
+++ b/tests/local_testing/test_custom_callback_input.py
@@ -1624,3 +1624,55 @@ async def test_standard_logging_payload_stream_usage(sync_mode):
print(f"standard_logging_object usage: {built_response.usage}")
except litellm.InternalServerError:
pass
+
+
+def test_standard_logging_retries():
+ """
+    Test that the standard logging payload keeps the same trace_id across retries, so we know if a request was retried.
+ """
+ from litellm.types.utils import StandardLoggingPayload
+ from litellm.router import Router
+
+ customHandler = CompletionCustomHandler()
+ litellm.callbacks = [customHandler]
+
+ router = Router(
+ model_list=[
+ {
+ "model_name": "gpt-3.5-turbo",
+ "litellm_params": {
+ "model": "openai/gpt-3.5-turbo",
+ "api_key": "test-api-key",
+ },
+ }
+ ]
+ )
+
+ with patch.object(
+ customHandler, "log_failure_event", new=MagicMock()
+ ) as mock_client:
+ try:
+ router.completion(
+ model="gpt-3.5-turbo",
+ messages=[{"role": "user", "content": "Hey, how's it going?"}],
+ num_retries=1,
+ mock_response="litellm.RateLimitError",
+ )
+ except litellm.RateLimitError:
+ pass
+
+ assert mock_client.call_count == 2
+ assert (
+ mock_client.call_args_list[0].kwargs["kwargs"]["standard_logging_object"][
+ "trace_id"
+ ]
+ is not None
+ )
+ assert (
+ mock_client.call_args_list[0].kwargs["kwargs"]["standard_logging_object"][
+ "trace_id"
+ ]
+ == mock_client.call_args_list[1].kwargs["kwargs"][
+ "standard_logging_object"
+ ]["trace_id"]
+ )
diff --git a/tests/local_testing/test_exceptions.py b/tests/local_testing/test_exceptions.py
index d5f67cecf..67c36928f 100644
--- a/tests/local_testing/test_exceptions.py
+++ b/tests/local_testing/test_exceptions.py
@@ -58,6 +58,7 @@ async def test_content_policy_exception_azure():
except litellm.ContentPolicyViolationError as e:
print("caught a content policy violation error! Passed")
print("exception", e)
+ assert e.response is not None
assert e.litellm_debug_info is not None
assert isinstance(e.litellm_debug_info, str)
assert len(e.litellm_debug_info) > 0
@@ -1152,3 +1153,24 @@ async def test_exception_with_headers_httpx(
if exception_raised is False:
print(resp)
assert exception_raised
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model", ["azure/chatgpt-v-2", "openai/gpt-3.5-turbo"])
+async def test_bad_request_error_contains_httpx_response(model):
+ """
+ Test that the BadRequestError contains the httpx response
+
+ Relevant issue: https://github.com/BerriAI/litellm/issues/6732
+ """
+ try:
+ await litellm.acompletion(
+ model=model,
+ messages=[{"role": "user", "content": "Hello world"}],
+ bad_arg="bad_arg",
+ )
+ pytest.fail("Expected to raise BadRequestError")
+ except litellm.BadRequestError as e:
+ print("e.response", e.response)
+ print("vars(e.response)", vars(e.response))
+ assert e.response is not None
diff --git a/tests/local_testing/test_get_llm_provider.py b/tests/local_testing/test_get_llm_provider.py
index 6654c10c2..423ffe2fd 100644
--- a/tests/local_testing/test_get_llm_provider.py
+++ b/tests/local_testing/test_get_llm_provider.py
@@ -157,7 +157,7 @@ def test_get_llm_provider_jina_ai():
model, custom_llm_provider, dynamic_api_key, api_base = litellm.get_llm_provider(
model="jina_ai/jina-embeddings-v3",
)
- assert custom_llm_provider == "openai_like"
+ assert custom_llm_provider == "jina_ai"
assert api_base == "https://api.jina.ai/v1"
assert model == "jina-embeddings-v3"
diff --git a/tests/local_testing/test_get_model_info.py b/tests/local_testing/test_get_model_info.py
index 82ce9c465..11506ed3d 100644
--- a/tests/local_testing/test_get_model_info.py
+++ b/tests/local_testing/test_get_model_info.py
@@ -89,11 +89,16 @@ def test_get_model_info_ollama_chat():
"template": "tools",
}
),
- ):
+ ) as mock_client:
info = OllamaConfig().get_model_info("mistral")
- print("info", info)
assert info["supports_function_calling"] is True
info = get_model_info("ollama/mistral")
- print("info", info)
+
assert info["supports_function_calling"] is True
+
+ mock_client.assert_called()
+
+ print(mock_client.call_args.kwargs)
+
+ assert mock_client.call_args.kwargs["json"]["name"] == "mistral"
diff --git a/tests/local_testing/test_router_fallbacks.py b/tests/local_testing/test_router_fallbacks.py
index cad640a54..3c9750691 100644
--- a/tests/local_testing/test_router_fallbacks.py
+++ b/tests/local_testing/test_router_fallbacks.py
@@ -1138,9 +1138,9 @@ async def test_router_content_policy_fallbacks(
router = Router(
model_list=[
{
- "model_name": "claude-2",
+ "model_name": "claude-2.1",
"litellm_params": {
- "model": "claude-2",
+ "model": "claude-2.1",
"api_key": "",
"mock_response": mock_response,
},
@@ -1164,7 +1164,7 @@ async def test_router_content_policy_fallbacks(
{
"model_name": "my-general-model",
"litellm_params": {
- "model": "claude-2",
+ "model": "claude-2.1",
"api_key": "",
"mock_response": Exception("Should not have called this."),
},
@@ -1172,14 +1172,14 @@ async def test_router_content_policy_fallbacks(
{
"model_name": "my-context-window-model",
"litellm_params": {
- "model": "claude-2",
+ "model": "claude-2.1",
"api_key": "",
"mock_response": Exception("Should not have called this."),
},
},
],
content_policy_fallbacks=(
- [{"claude-2": ["my-fallback-model"]}]
+ [{"claude-2.1": ["my-fallback-model"]}]
if fallback_type == "model-specific"
else None
),
@@ -1190,12 +1190,12 @@ async def test_router_content_policy_fallbacks(
if sync_mode is True:
response = router.completion(
- model="claude-2",
+ model="claude-2.1",
messages=[{"role": "user", "content": "Hey, how's it going?"}],
)
else:
response = await router.acompletion(
- model="claude-2",
+ model="claude-2.1",
messages=[{"role": "user", "content": "Hey, how's it going?"}],
)
@@ -1455,3 +1455,46 @@ async def test_router_fallbacks_default_and_model_specific_fallbacks(sync_mode):
assert isinstance(
exc_info.value, litellm.AuthenticationError
), f"Expected AuthenticationError, but got {type(exc_info.value).__name__}"
+
+
+@pytest.mark.asyncio
+async def test_router_disable_fallbacks_dynamically():
+ from litellm.router import run_async_fallback
+
+ router = Router(
+ model_list=[
+ {
+ "model_name": "bad-model",
+ "litellm_params": {
+ "model": "openai/my-bad-model",
+ "api_key": "my-bad-api-key",
+ },
+ },
+ {
+ "model_name": "good-model",
+ "litellm_params": {
+ "model": "gpt-4o",
+ "api_key": os.getenv("OPENAI_API_KEY"),
+ },
+ },
+ ],
+ fallbacks=[{"bad-model": ["good-model"]}],
+ default_fallbacks=["good-model"],
+ )
+
+ with patch.object(
+ router,
+ "log_retry",
+ new=MagicMock(return_value=None),
+ ) as mock_client:
+ try:
+ resp = await router.acompletion(
+ model="bad-model",
+ messages=[{"role": "user", "content": "Hey, how's it going?"}],
+ disable_fallbacks=True,
+ )
+ print(resp)
+ except Exception as e:
+ print(e)
+
+ mock_client.assert_not_called()
diff --git a/tests/local_testing/test_router_utils.py b/tests/local_testing/test_router_utils.py
index 538ab4d0b..d266cfbd9 100644
--- a/tests/local_testing/test_router_utils.py
+++ b/tests/local_testing/test_router_utils.py
@@ -14,6 +14,7 @@ from litellm.router import Deployment, LiteLLM_Params, ModelInfo
from concurrent.futures import ThreadPoolExecutor
from collections import defaultdict
from dotenv import load_dotenv
+from unittest.mock import patch, MagicMock, AsyncMock
load_dotenv()
@@ -83,3 +84,93 @@ def test_returned_settings():
except Exception:
print(traceback.format_exc())
pytest.fail("An error occurred - " + traceback.format_exc())
+
+
+from litellm.types.utils import CallTypes
+
+
+def test_update_kwargs_before_fallbacks_unit_test():
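+    """
+    _update_kwargs_before_fallbacks should populate the request kwargs with a
+    non-null litellm_trace_id.
+    """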
+ router = Router(
+ model_list=[
+ {
+ "model_name": "gpt-3.5-turbo",
+ "litellm_params": {
+ "model": "azure/chatgpt-v-2",
+ "api_key": "bad-key",
+ "api_version": os.getenv("AZURE_API_VERSION"),
+ "api_base": os.getenv("AZURE_API_BASE"),
+ },
+ }
+ ],
+ )
+
+ kwargs = {"messages": [{"role": "user", "content": "write 1 sentence poem"}]}
+
+ router._update_kwargs_before_fallbacks(
+ model="gpt-3.5-turbo",
+ kwargs=kwargs,
+ )
+
+ assert kwargs["litellm_trace_id"] is not None
+
+
+@pytest.mark.parametrize(
+ "call_type",
+ [
+ CallTypes.acompletion,
+ CallTypes.atext_completion,
+ CallTypes.aembedding,
+ CallTypes.arerank,
+ CallTypes.atranscription,
+ ],
+)
+@pytest.mark.asyncio
+async def test_update_kwargs_before_fallbacks(call_type):
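+    """
+    Each async router entrypoint should route through async_function_with_fallbacks
+    and forward a litellm_trace_id in its kwargs.
+    """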
+
+ router = Router(
+ model_list=[
+ {
+ "model_name": "gpt-3.5-turbo",
+ "litellm_params": {
+ "model": "azure/chatgpt-v-2",
+ "api_key": "bad-key",
+ "api_version": os.getenv("AZURE_API_VERSION"),
+ "api_base": os.getenv("AZURE_API_BASE"),
+ },
+ }
+ ],
+ )
+
+ if call_type.value.startswith("a"):
+ with patch.object(router, "async_function_with_fallbacks") as mock_client:
+ if call_type.value == "acompletion":
+ input_kwarg = {
+ "messages": [{"role": "user", "content": "Hello, how are you?"}],
+ }
+ elif (
+ call_type.value == "atext_completion"
+ or call_type.value == "aimage_generation"
+ ):
+ input_kwarg = {
+ "prompt": "Hello, how are you?",
+ }
+ elif call_type.value == "aembedding" or call_type.value == "arerank":
+ input_kwarg = {
+ "input": "Hello, how are you?",
+ }
+ elif call_type.value == "atranscription":
+ input_kwarg = {
+ "file": "path/to/file",
+ }
+ else:
+ input_kwarg = {}
+
+ await getattr(router, call_type.value)(
+ model="gpt-3.5-turbo",
+ **input_kwarg,
+ )
+
+ mock_client.assert_called_once()
+
+ print(mock_client.call_args.kwargs)
+ assert mock_client.call_args.kwargs["litellm_trace_id"] is not None
diff --git a/tests/local_testing/test_secret_manager.py b/tests/local_testing/test_secret_manager.py
index 397128ecb..1b95119a3 100644
--- a/tests/local_testing/test_secret_manager.py
+++ b/tests/local_testing/test_secret_manager.py
@@ -15,22 +15,29 @@ sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import pytest
-
+import litellm
from litellm.llms.AzureOpenAI.azure import get_azure_ad_token_from_oidc
from litellm.llms.bedrock.chat import BedrockConverseLLM, BedrockLLM
-from litellm.secret_managers.aws_secret_manager import load_aws_secret_manager
-from litellm.secret_managers.main import get_secret
+from litellm.secret_managers.aws_secret_manager_v2 import AWSSecretsManagerV2
+from litellm.secret_managers.main import (
+ get_secret,
+ _should_read_secret_from_secret_manager,
+)
-@pytest.mark.skip(reason="AWS Suspended Account")
def test_aws_secret_manager():
- load_aws_secret_manager(use_aws_secret_manager=True)
+ import json
+
+ AWSSecretsManagerV2.load_aws_secret_manager(use_aws_secret_manager=True)
secret_val = get_secret("litellm_master_key")
print(f"secret_val: {secret_val}")
- assert secret_val == "sk-1234"
+    # the secret is stored as a JSON string; parse it into a dict
+ secret_val = json.loads(secret_val)
+
+ assert secret_val["litellm_master_key"] == "sk-1234"
def redact_oidc_signature(secret_val):
@@ -240,3 +247,71 @@ def test_google_secret_manager_read_in_memory():
)
print("secret_val: {}".format(secret_val))
assert secret_val == "lite-llm"
+
+
+def test_should_read_secret_from_secret_manager():
+ """
+ Test that _should_read_secret_from_secret_manager returns correct values based on access mode
+ """
+ from litellm.proxy._types import KeyManagementSettings
+
+ # Test when secret manager client is None
+ litellm.secret_manager_client = None
+ litellm._key_management_settings = KeyManagementSettings()
+ assert _should_read_secret_from_secret_manager() is False
+
+ # Test with secret manager client and read_only access
+ litellm.secret_manager_client = "dummy_client"
+ litellm._key_management_settings = KeyManagementSettings(access_mode="read_only")
+ assert _should_read_secret_from_secret_manager() is True
+
+ # Test with secret manager client and read_and_write access
+ litellm._key_management_settings = KeyManagementSettings(
+ access_mode="read_and_write"
+ )
+ assert _should_read_secret_from_secret_manager() is True
+
+ # Test with secret manager client and write_only access
+ litellm._key_management_settings = KeyManagementSettings(access_mode="write_only")
+ assert _should_read_secret_from_secret_manager() is False
+
+ # Reset global variables
+ litellm.secret_manager_client = None
+ litellm._key_management_settings = KeyManagementSettings()
+
+
+def test_get_secret_with_access_mode():
+ """
+ Test that get_secret respects access mode settings
+ """
+ from litellm.proxy._types import KeyManagementSettings
+
+ # Set up test environment
+ test_secret_name = "TEST_SECRET_KEY"
+ test_secret_value = "test_secret_value"
+ os.environ[test_secret_name] = test_secret_value
+
+ # Test with write_only access (should read from os.environ)
+ litellm.secret_manager_client = "dummy_client"
+ litellm._key_management_settings = KeyManagementSettings(access_mode="write_only")
+ assert get_secret(test_secret_name) == test_secret_value
+
+ # Test with no KeyManagementSettings but secret_manager_client set
+ litellm.secret_manager_client = "dummy_client"
+ litellm._key_management_settings = KeyManagementSettings()
+ assert _should_read_secret_from_secret_manager() is True
+
+ # Test with read_only access
+ litellm._key_management_settings = KeyManagementSettings(access_mode="read_only")
+ assert _should_read_secret_from_secret_manager() is True
+
+ # Test with read_and_write access
+ litellm._key_management_settings = KeyManagementSettings(
+ access_mode="read_and_write"
+ )
+ assert _should_read_secret_from_secret_manager() is True
+
+ # Reset global variables
+ litellm.secret_manager_client = None
+ litellm._key_management_settings = KeyManagementSettings()
+ del os.environ[test_secret_name]
diff --git a/tests/local_testing/test_stream_chunk_builder.py b/tests/local_testing/test_stream_chunk_builder.py
index 2548abdb7..4fb44299d 100644
--- a/tests/local_testing/test_stream_chunk_builder.py
+++ b/tests/local_testing/test_stream_chunk_builder.py
@@ -184,12 +184,11 @@ def test_stream_chunk_builder_litellm_usage_chunks():
{"role": "assistant", "content": "uhhhh\n\n\nhmmmm.....\nthinking....\n"},
{"role": "user", "content": "\nI am waiting...\n\n...\n"},
]
- # make a regular gemini call
usage: litellm.Usage = Usage(
- completion_tokens=64,
+ completion_tokens=27,
prompt_tokens=55,
- total_tokens=119,
+ total_tokens=82,
completion_tokens_details=None,
prompt_tokens_details=None,
)
diff --git a/tests/local_testing/test_streaming.py b/tests/local_testing/test_streaming.py
index bc4827d92..0bc6953f9 100644
--- a/tests/local_testing/test_streaming.py
+++ b/tests/local_testing/test_streaming.py
@@ -718,7 +718,7 @@ async def test_acompletion_claude_2_stream():
try:
litellm.set_verbose = True
response = await litellm.acompletion(
- model="claude-2",
+ model="claude-2.1",
messages=[{"role": "user", "content": "hello from litellm"}],
stream=True,
)
@@ -3274,7 +3274,7 @@ def test_completion_claude_3_function_call_with_streaming():
], # "claude-3-opus-20240229"
) #
@pytest.mark.asyncio
-async def test_acompletion_claude_3_function_call_with_streaming(model):
+async def test_acompletion_function_call_with_streaming(model):
litellm.set_verbose = True
tools = [
{
@@ -3335,6 +3335,8 @@ async def test_acompletion_claude_3_function_call_with_streaming(model):
# raise Exception("it worked! ")
except litellm.InternalServerError as e:
pytest.skip(f"InternalServerError - {str(e)}")
+ except litellm.ServiceUnavailableError:
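+        # like the InternalServerError case above, treat provider unavailability as non-fatal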
+ pass
except Exception as e:
pytest.fail(f"Error occurred: {e}")
diff --git a/tests/proxy_unit_tests/test_key_generate_prisma.py b/tests/proxy_unit_tests/test_key_generate_prisma.py
index 78b558cd2..b97ab3514 100644
--- a/tests/proxy_unit_tests/test_key_generate_prisma.py
+++ b/tests/proxy_unit_tests/test_key_generate_prisma.py
@@ -3451,3 +3451,90 @@ async def test_user_api_key_auth_db_unavailable_not_allowed():
request=request,
api_key="Bearer sk-123456789",
)
+
+
+## E2E Virtual Key + Secret Manager Tests #########################################
+
+
+@pytest.mark.asyncio
+async def test_key_generate_with_secret_manager_call(prisma_client):
+ """
+    Generate a key and assert it is stored in the secret manager.
+
+    Then delete the key and assert it is removed from the secret manager.
+ """
+ from litellm.secret_managers.aws_secret_manager_v2 import AWSSecretsManagerV2
+ from litellm.proxy._types import KeyManagementSystem, KeyManagementSettings
+
+ litellm.set_verbose = True
+
+ #### Test Setup ############################################################
+ aws_secret_manager_client = AWSSecretsManagerV2()
+ litellm.secret_manager_client = aws_secret_manager_client
+ litellm._key_management_system = KeyManagementSystem.AWS_SECRET_MANAGER
+ litellm._key_management_settings = KeyManagementSettings(
+ store_virtual_keys=True,
+ )
+ general_settings = {
+ "key_management_system": "aws_secret_manager",
+ "key_management_settings": {
+ "store_virtual_keys": True,
+ },
+ }
+
+ setattr(litellm.proxy.proxy_server, "general_settings", general_settings)
+ setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
+ setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
+ await litellm.proxy.proxy_server.prisma_client.connect()
+ ############################################################################
+
+ # generate new key
+ key_alias = f"test_alias_secret_manager_key-{uuid.uuid4()}"
+ spend = 100
+ max_budget = 400
+ models = ["fake-openai-endpoint"]
+ new_key = await generate_key_fn(
+ data=GenerateKeyRequest(
+ key_alias=key_alias, spend=spend, max_budget=max_budget, models=models
+ ),
+ user_api_key_dict=UserAPIKeyAuth(
+ user_role=LitellmUserRoles.PROXY_ADMIN,
+ api_key="sk-1234",
+ user_id="1234",
+ ),
+ )
+
+ generated_key = new_key.key
+ print(generated_key)
+
+ await asyncio.sleep(2)
+
+ # read from the secret manager
+ result = await aws_secret_manager_client.async_read_secret(secret_name=key_alias)
+
+ # Assert the correct key is stored in the secret manager
+ print("response from AWS Secret Manager")
+ print(result)
+ assert result == generated_key
+
+ # delete the key
+ await delete_key_fn(
+ data=KeyRequest(keys=[generated_key]),
+ user_api_key_dict=UserAPIKeyAuth(
+ user_role=LitellmUserRoles.PROXY_ADMIN, api_key="sk-1234", user_id="1234"
+ ),
+ )
+
+ await asyncio.sleep(2)
+
+ # Assert the key is deleted from the secret manager
+ result = await aws_secret_manager_client.async_read_secret(secret_name=key_alias)
+ assert result is None
+
+ # cleanup
+ setattr(litellm.proxy.proxy_server, "general_settings", {})
+
+
+################################################################################
diff --git a/tests/proxy_unit_tests/test_proxy_server.py b/tests/proxy_unit_tests/test_proxy_server.py
index 5588d0414..b1c00ce75 100644
--- a/tests/proxy_unit_tests/test_proxy_server.py
+++ b/tests/proxy_unit_tests/test_proxy_server.py
@@ -1500,6 +1500,31 @@ async def test_add_callback_via_key_litellm_pre_call_utils(
assert new_data["failure_callback"] == expected_failure_callbacks
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+ "disable_fallbacks_set",
+ [
+ True,
+ False,
+ ],
+)
+async def test_disable_fallbacks_by_key(disable_fallbacks_set):
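+    """
+    The key-level metadata flag disable_fallbacks should be copied onto the
+    request data by LiteLLMProxyRequestSetup.add_key_level_controls.
+    """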
+ from litellm.proxy.litellm_pre_call_utils import LiteLLMProxyRequestSetup
+
+ key_metadata = {"disable_fallbacks": disable_fallbacks_set}
+ existing_data = {
+ "model": "azure/chatgpt-v-2",
+ "messages": [{"role": "user", "content": "write 1 sentence poem"}],
+ }
+ data = LiteLLMProxyRequestSetup.add_key_level_controls(
+ key_metadata=key_metadata,
+ data=existing_data,
+ _metadata_variable_name="metadata",
+ )
+
+ assert data["disable_fallbacks"] == disable_fallbacks_set
+
+
@pytest.mark.asyncio
@pytest.mark.parametrize(
"callback_type, expected_success_callbacks, expected_failure_callbacks",