Merge branch 'main' into litellm_dev_11_13_2024

Krish Dholakia 2024-11-15 11:18:02 +05:30 committed by GitHub
commit 1dcbfda202
76 changed files with 2836 additions and 560 deletions

View file

@ -305,6 +305,36 @@ Step 4: Submit a PR with your changes! 🚀
- push your fork to your GitHub repo
- submit a PR from there
### Building LiteLLM Docker Image
Follow these instructions if you want to build / run the LiteLLM Docker Image yourself.
Step 1: Clone the repo
```
git clone https://github.com/BerriAI/litellm.git
```
Step 2: Build the Docker Image
Build using Dockerfile.non_root
```
docker build -f docker/Dockerfile.non_root -t litellm_test_image .
```
Step 3: Run the Docker Image
Make sure your litellm proxy config file is present in the root directory (the command below mounts it as `proxy_config.yaml`).
```
docker run \
-v $(pwd)/proxy_config.yaml:/app/config.yaml \
-e DATABASE_URL="postgresql://xxxxxxxx" \
-e LITELLM_MASTER_KEY="sk-1234" \
-p 4000:4000 \
litellm_test_image \
--config /app/config.yaml --detailed_debug
```
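Once the container is up, you can sanity-check the proxy with any OpenAI-compatible client. A minimal sketch, assuming the master key from the command above; the model name `my-model` is a placeholder for a `model_name` defined in your config.yaml:
```python
import openai

# Point the OpenAI SDK at the LiteLLM proxy started above
client = openai.OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")

response = client.chat.completions.create(
    model="my-model",  # replace with a model_name defined in your config.yaml
    messages=[{"role": "user", "content": "ping"}],
)
print(response.choices[0].message.content)
```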
# Enterprise
For companies that need better security, user management and professional support

View file

@ -13,18 +13,18 @@ spec:
spec:
containers:
- name: prisma-migrations
image: "ghcr.io/berriai/litellm:main-stable"
image: ghcr.io/berriai/litellm-database:main-latest
command: ["python", "litellm/proxy/prisma_migration.py"]
workingDir: "/app"
env:
{{- if .Values.db.deployStandalone }}
- name: DATABASE_URL
value: postgresql://{{ .Values.postgresql.auth.username }}:{{ .Values.postgresql.auth.password }}@{{ .Release.Name }}-postgresql/{{ .Values.postgresql.auth.database }}
{{- else if .Values.db.useExisting }}
{{- if .Values.db.useExisting }}
- name: DATABASE_URL
value: {{ .Values.db.url | quote }}
{{- else }}
- name: DATABASE_URL
value: postgresql://{{ .Values.postgresql.auth.username }}:{{ .Values.postgresql.auth.password }}@{{ .Release.Name }}-postgresql/{{ .Values.postgresql.auth.database }}
{{- end }}
- name: DISABLE_SCHEMA_UPDATE
value: "{{ .Values.migrationJob.disableSchemaUpdate }}"
value: "false" # always run the migration from the Helm PreSync hook, override the value set
restartPolicy: OnFailure
backoffLimit: {{ .Values.migrationJob.backoffLimit }}

View file

@ -75,6 +75,7 @@ Works for:
- Google AI Studio - Gemini models
- Vertex AI models (Gemini + Anthropic)
- Bedrock Models
- Anthropic API Models
<Tabs>
<TabItem value="sdk" label="SDK">

View file

@ -93,7 +93,7 @@ curl http://0.0.0.0:4000/v1/chat/completions \
## Check Model Support
Call `litellm.get_model_info` to check if a model/provider supports `response_format`.
Call `litellm.get_model_info` to check if a model/provider supports `prefix`.
<Tabs>
<TabItem value="sdk" label="SDK">

View file

@ -957,3 +957,69 @@ curl http://0.0.0.0:4000/v1/chat/completions \
```
</TabItem>
</Tabs>
## Usage - passing 'user_id' to Anthropic
LiteLLM translates the OpenAI `user` param to Anthropic's `metadata[user_id]` param.
<Tabs>
<TabItem value="sdk" label="SDK">
```python
response = completion(
model="claude-3-5-sonnet-20240620",
messages=messages,
user="user_123",
)
```
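Under the hood, the `user` value ends up in the `metadata` block of the Anthropic `/v1/messages` request. The sketch below illustrates the translated request shape; it is not a verbatim capture of LiteLLM's outgoing payload:
```python
# Illustrative shape of the translated Anthropic request body
anthropic_request_body = {
    "model": "claude-3-5-sonnet-20240620",
    "messages": [{"role": "user", "content": "What is Anthropic?"}],
    "max_tokens": 4096,  # Anthropic requires max_tokens; the value here is illustrative
    "metadata": {"user_id": "user_123"},  # <- OpenAI `user` param lands here
}
```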
</TabItem>
<TabItem value="proxy" label="PROXY">
1. Setup config.yaml
```yaml
model_list:
- model_name: claude-3-5-sonnet-20240620
litellm_params:
model: anthropic/claude-3-5-sonnet-20240620
api_key: os.environ/ANTHROPIC_API_KEY
```
2. Start Proxy
```
litellm --config /path/to/config.yaml
```
3. Test it!
```bash
curl http://0.0.0.0:4000/v1/chat/completions \
-H "Content-Type: application/json" \
-H "Authorization: Bearer <YOUR-LITELLM-KEY>" \
-d '{
"model": "claude-3-5-sonnet-20240620",
"messages": [{"role": "user", "content": "What is Anthropic?"}],
"user": "user_123"
}'
```
</TabItem>
</Tabs>
## All Supported OpenAI Params
```
"stream",
"stop",
"temperature",
"top_p",
"max_tokens",
"max_completion_tokens",
"tools",
"tool_choice",
"extra_headers",
"parallel_tool_calls",
"response_format",
"user"
```
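For reference, several of these params can be combined in a single `completion()` call; the parameter values below are illustrative:
```python
from litellm import completion

# Sketch: combining a few of the supported OpenAI params on an Anthropic model
response = completion(
    model="anthropic/claude-3-5-sonnet-20240620",
    messages=[{"role": "user", "content": "Summarize LiteLLM in one sentence."}],
    temperature=0.2,
    max_tokens=256,
    stop=["\n\n"],
    user="user_123",
)
print(response.choices[0].message.content)
```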

View file

@ -37,7 +37,7 @@ os.environ["HUGGINGFACE_API_KEY"] = "huggingface_api_key"
messages = [{ "content": "There's a llama in my garden 😱 What should I do?","role": "user"}]
# e.g. Call 'https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct' from Serverless Inference API
response = litellm.completion(
response = completion(
model="huggingface/meta-llama/Meta-Llama-3.1-8B-Instruct",
messages=[{ "content": "Hello, how are you?","role": "user"}],
stream=True
@ -165,14 +165,14 @@ Steps to use
```python
import os
import litellm
from litellm import completion
os.environ["HUGGINGFACE_API_KEY"] = ""
# TGI model: Call https://huggingface.co/glaiveai/glaive-coder-7b
# add the 'huggingface/' prefix to the model to set huggingface as the provider
# set api base to your deployed api endpoint from hugging face
response = litellm.completion(
response = completion(
model="huggingface/glaiveai/glaive-coder-7b",
messages=[{ "content": "Hello, how are you?","role": "user"}],
api_base="https://wjiegasee9bmqke2.us-east-1.aws.endpoints.huggingface.cloud"
@ -383,6 +383,8 @@ def default_pt(messages):
#### Custom prompt templates
```python
import litellm
# Create your own custom prompt template
litellm.register_prompt_template(
model="togethercomputer/LLaMA-2-7B-32K",

View file

@ -1,6 +1,13 @@
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
# Jina AI
https://jina.ai/embeddings/
Supported endpoints:
- /embeddings
- /rerank
## API Key
```python
# env variable
@ -8,6 +15,10 @@ os.environ['JINA_AI_API_KEY']
```
## Sample Usage - Embedding
<Tabs>
<TabItem value="sdk" label="SDK">
```python
from litellm import embedding
import os
@ -19,6 +30,142 @@ response = embedding(
)
print(response)
```
</TabItem>
<TabItem value="proxy" label="PROXY">
1. Add to config.yaml
```yaml
model_list:
- model_name: embedding-model
litellm_params:
model: jina_ai/jina-embeddings-v3
api_key: os.environ/JINA_AI_API_KEY
```
2. Start proxy
```bash
litellm --config /path/to/config.yaml
# RUNNING on http://0.0.0.0:4000/
```
3. Test it!
```bash
curl -L -X POST 'http://0.0.0.0:4000/embeddings' \
-H 'Authorization: Bearer sk-1234' \
-H 'Content-Type: application/json' \
-d '{"input": ["hello world"], "model": "embedding-model"}'
```
</TabItem>
</Tabs>
## Sample Usage - Rerank
<Tabs>
<TabItem value="sdk" label="SDK">
```python
from litellm import rerank
import os
os.environ["JINA_AI_API_KEY"] = "sk-..."
query = "What is the capital of the United States?"
documents = [
"Carson City is the capital city of the American state of Nevada.",
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.",
"Washington, D.C. is the capital of the United States.",
"Capital punishment has existed in the United States since before it was a country.",
]
response = rerank(
model="jina_ai/jina-reranker-v2-base-multilingual",
query=query,
documents=documents,
top_n=3,
)
print(response)
```
</TabItem>
<TabItem value="proxy" label="PROXY">
1. Add to config.yaml
```yaml
model_list:
- model_name: rerank-model
litellm_params:
model: jina_ai/jina-reranker-v2-base-multilingual
api_key: os.environ/JINA_AI_API_KEY
```
2. Start proxy
```bash
litellm --config /path/to/config.yaml
```
3. Test it!
```bash
curl -L -X POST 'http://0.0.0.0:4000/rerank' \
-H 'Authorization: Bearer sk-1234' \
-H 'Content-Type: application/json' \
-d '{
"model": "rerank-model",
"query": "What is the capital of the United States?",
"documents": [
"Carson City is the capital city of the American state of Nevada.",
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.",
"Washington, D.C. is the capital of the United States.",
"Capital punishment has existed in the United States since before it was a country."
],
"top_n": 3
}'
```
</TabItem>
</Tabs>
## Supported Models
All models listed [here](https://jina.ai/embeddings/) are supported
## Supported Optional Rerank Parameters
All Cohere rerank parameters are supported.
## Supported Optional Embeddings Parameters
```
dimensions
```
## Provider-specific parameters
Pass any Jina AI-specific parameters as keyword arguments to the `embedding` or `rerank` function, e.g.
<Tabs>
<TabItem value="sdk" label="SDK">
```python
response = embedding(
model="jina_ai/jina-embeddings-v3",
input=["good morning from litellm"],
dimensions=1536,
my_custom_param="my_custom_value", # any other jina ai specific parameters
)
```
</TabItem>
<TabItem value="proxy" label="PROXY">
```bash
curl -L -X POST 'http://0.0.0.0:4000/embeddings' \
-H 'Authorization: Bearer sk-1234' \
-H 'Content-Type: application/json' \
-d '{"input": ["good morning from litellm"], "model": "jina_ai/jina-embeddings-v3", "dimensions": 1536, "my_custom_param": "my_custom_value"}'
```
</TabItem>
</Tabs>

View file

@ -1562,6 +1562,10 @@ curl http://0.0.0.0:4000/v1/chat/completions \
## **Embedding Models**
#### Usage - Embedding
<Tabs>
<TabItem value="sdk" label="SDK">
```python
import litellm
from litellm import embedding
@ -1574,6 +1578,49 @@ response = embedding(
)
print(response)
```
</TabItem>
<TabItem value="proxy" label="LiteLLM PROXY">
1. Add model to config.yaml
```yaml
model_list:
- model_name: snowflake-arctic-embed-m-long-1731622468876
litellm_params:
model: vertex_ai/<your-model-id>
vertex_project: "adroit-crow-413218"
vertex_location: "us-central1"
vertex_credentials: adroit-crow-413218-a956eef1a2a8.json
litellm_settings:
drop_params: True
```
2. Start Proxy
```
$ litellm --config /path/to/config.yaml
```
3. Make a request using the OpenAI Python SDK or Langchain Python SDK
```python
import openai
client = openai.OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")
response = client.embeddings.create(
model="snowflake-arctic-embed-m-long-1731622468876",
input = ["good morning from litellm", "this is another item"],
)
print(response)
```
</TabItem>
</Tabs>
#### Supported Embedding Models
All models listed [here](https://github.com/BerriAI/litellm/blob/57f37f743886a0249f630a6792d49dffc2c5d9b7/model_prices_and_context_window.json#L835) are supported
@ -1589,6 +1636,7 @@ All models listed [here](https://github.com/BerriAI/litellm/blob/57f37f743886a02
| textembedding-gecko@003 | `embedding(model="vertex_ai/textembedding-gecko@003", input)` |
| text-embedding-preview-0409 | `embedding(model="vertex_ai/text-embedding-preview-0409", input)` |
| text-multilingual-embedding-preview-0409 | `embedding(model="vertex_ai/text-multilingual-embedding-preview-0409", input)` |
| Fine-tuned OR Custom Embedding models | `embedding(model="vertex_ai/<your-model-id>", input)` |
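For the fine-tuned / custom row above, the SDK call has the same shape as for first-party models. A minimal sketch, reusing the example project/location from the proxy config above; keep `<your-model-id>` as the endpoint or model ID from your Vertex AI project:
```python
from litellm import embedding

# Sketch: calling a fine-tuned or custom Vertex AI embedding model
response = embedding(
    model="vertex_ai/<your-model-id>",    # endpoint / model ID from your Vertex AI project
    input=["good morning from litellm"],
    vertex_project="adroit-crow-413218",  # example project id reused from the config above
    vertex_location="us-central1",
)
print(response)
```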
### Supported OpenAI (Unified) Params

View file

@ -791,9 +791,9 @@ general_settings:
| store_model_in_db | boolean | If true, allows `/model/new` endpoint to store model information in db. Endpoint disabled by default. [Doc on `/model/new` endpoint](./model_management.md#create-a-new-model) |
| max_request_size_mb | int | The maximum size for requests in MB. Requests above this size will be rejected. |
| max_response_size_mb | int | The maximum size for responses in MB. LLM Responses above this size will not be sent. |
| proxy_budget_rescheduler_min_time | int | The minimum time (in seconds) to wait before checking db for budget resets. |
| proxy_budget_rescheduler_max_time | int | The maximum time (in seconds) to wait before checking db for budget resets. |
| proxy_batch_write_at | int | Time (in seconds) to wait before batch writing spend logs to the db. |
| proxy_budget_rescheduler_min_time | int | The minimum time (in seconds) to wait before checking db for budget resets. **Default is 597 seconds** |
| proxy_budget_rescheduler_max_time | int | The maximum time (in seconds) to wait before checking db for budget resets. **Default is 605 seconds** |
| proxy_batch_write_at | int | Time (in seconds) to wait before batch writing spend logs to the db. **Default is 10 seconds** |
| alerting_args | dict | Args for Slack Alerting [Doc on Slack Alerting](./alerting.md) |
| custom_key_generate | str | Custom function for key generation [Doc on custom key generation](./virtual_keys.md#custom--key-generate) |
| allowed_ips | List[str] | List of IPs allowed to access the proxy. If not set, all IPs are allowed. |

View file

@ -66,10 +66,16 @@ Removes any field with `user_api_key_*` from metadata.
Found under `kwargs["standard_logging_object"]`. This is a standard payload, logged for every response.
```python
class StandardLoggingPayload(TypedDict):
id: str
trace_id: str # Trace multiple LLM calls belonging to same overall request (e.g. fallbacks/retries)
call_type: str
response_cost: float
response_cost_failure_debug_info: Optional[
StandardLoggingModelCostFailureDebugInformation
]
status: StandardLoggingPayloadStatus
total_tokens: int
prompt_tokens: int
completion_tokens: int
@ -84,13 +90,13 @@ class StandardLoggingPayload(TypedDict):
metadata: StandardLoggingMetadata
cache_hit: Optional[bool]
cache_key: Optional[str]
saved_cache_cost: Optional[float]
saved_cache_cost: float
request_tags: list
end_user: Optional[str]
requester_ip_address: Optional[str] # IP address of requester
requester_metadata: Optional[dict] # metadata passed in request in the "metadata" field
requester_ip_address: Optional[str]
messages: Optional[Union[str, list, dict]]
response: Optional[Union[str, list, dict]]
error_str: Optional[str]
model_parameters: dict
hidden_params: StandardLoggingHiddenParams
@ -99,12 +105,47 @@ class StandardLoggingHiddenParams(TypedDict):
cache_key: Optional[str]
api_base: Optional[str]
response_cost: Optional[str]
additional_headers: Optional[dict]
additional_headers: Optional[StandardLoggingAdditionalHeaders]
class StandardLoggingAdditionalHeaders(TypedDict, total=False):
x_ratelimit_limit_requests: int
x_ratelimit_limit_tokens: int
x_ratelimit_remaining_requests: int
x_ratelimit_remaining_tokens: int
class StandardLoggingMetadata(StandardLoggingUserAPIKeyMetadata):
"""
Specific metadata k,v pairs logged to integration for easier cost tracking
"""
spend_logs_metadata: Optional[
dict
] # special param to log k,v pairs to spendlogs for a call
requester_ip_address: Optional[str]
requester_metadata: Optional[dict]
class StandardLoggingModelInformation(TypedDict):
model_map_key: str
model_map_value: Optional[ModelInfo]
StandardLoggingPayloadStatus = Literal["success", "failure"]
class StandardLoggingModelCostFailureDebugInformation(TypedDict, total=False):
"""
Debug information, if cost tracking fails.
Avoid logging sensitive information like response or optional params
"""
error_str: Required[str]
traceback_str: Required[str]
model: str
cache_hit: Optional[bool]
custom_llm_provider: Optional[str]
base_model: Optional[str]
call_type: str
custom_pricing: Optional[bool]
```
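To consume this payload, read it off `kwargs` inside a custom callback. A minimal sketch (the `SpendTracker` class name is illustrative), assuming a logger registered via `litellm.callbacks`; the fields accessed follow the type definitions above:
```python
import litellm
from litellm.integrations.custom_logger import CustomLogger


class SpendTracker(CustomLogger):
    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        # Standard payload logged for every response (see StandardLoggingPayload above)
        payload = kwargs.get("standard_logging_object") or {}
        print(payload.get("trace_id"), payload.get("call_type"), payload.get("response_cost"))


litellm.callbacks = [SpendTracker()]
```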
## Langfuse

View file

@ -1,5 +1,6 @@
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
import Image from '@theme/IdealImage';
# ⚡ Best Practices for Production
@ -112,7 +113,35 @@ general_settings:
disable_spend_logs: True
```
## 7. Set LiteLLM Salt Key
## 7. Use Helm PreSync Hook for Database Migrations [BETA]
To ensure only one service manages database migrations, use our [Helm PreSync hook for Database Migrations](https://github.com/BerriAI/litellm/blob/main/deploy/charts/litellm-helm/templates/migrations-job.yaml). The hook runs migrations during `helm upgrade` or `helm install`, while the LiteLLM pods themselves explicitly disable migrations.
1. **Helm PreSync Hook**:
- The Helm PreSync hook is configured in the chart to run database migrations during deployments.
- The hook always sets `DISABLE_SCHEMA_UPDATE=false`, ensuring migrations are executed reliably.
Reference settings to set in `values.yaml` when deploying via ArgoCD:
```yaml
db:
useExisting: true # use existing Postgres DB
url: postgresql://ishaanjaffer0324:3rnwpOBau6hT@ep-withered-mud-a5dkdpke.us-east-2.aws.neon.tech/test-argo-cd?sslmode=require # url of existing Postgres DB
```
2. **LiteLLM Pods**:
- Set `DISABLE_SCHEMA_UPDATE=true` in LiteLLM pod configurations to prevent them from running migrations.
Example configuration for LiteLLM pod:
```yaml
env:
- name: DISABLE_SCHEMA_UPDATE
value: "true"
```
## 8. Set LiteLLM Salt Key
If you plan on using the DB, set a salt key for encrypting/decrypting variables in the DB.

View file

@ -749,3 +749,18 @@ curl -L -X POST 'http://0.0.0.0:4000/v1/chat/completions' \
"mock_testing_fallbacks": true
}'
```
### Disable Fallbacks per key
You can disable fallbacks per key by setting `disable_fallbacks: true` in your key metadata.
```bash
curl -L -X POST 'http://0.0.0.0:4000/key/generate' \
-H 'Authorization: Bearer sk-1234' \
-H 'Content-Type: application/json' \
-d '{
"metadata": {
"disable_fallbacks": true
}
}'
```

View file

@ -114,3 +114,4 @@ curl http://0.0.0.0:4000/rerank \
| Cohere | [Usage](#quick-start) |
| Together AI| [Usage](../docs/providers/togetherai) |
| Azure AI| [Usage](../docs/providers/azure_ai) |
| Jina AI| [Usage](../docs/providers/jina_ai) |

View file

@ -1,3 +1,6 @@
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
# Secret Manager
LiteLLM supports reading secrets from Azure Key Vault, Google Secret Manager, and AWS Secret Manager
@ -59,14 +62,35 @@ os.environ["AWS_REGION_NAME"] = "" # us-east-1, us-east-2, us-west-1, us-west-2
```
2. Enable AWS Secret Manager in config.
<Tabs>
<TabItem value="read_only" label="Read Keys from AWS Secret Manager">
```yaml
general_settings:
master_key: os.environ/litellm_master_key
key_management_system: "aws_secret_manager" # 👈 KEY CHANGE
key_management_settings:
hosted_keys: ["litellm_master_key"] # 👈 Specify which env keys you stored on AWS
```
</TabItem>
<TabItem value="write_only" label="Write Virtual Keys to AWS Secret Manager">
This will only store virtual keys in AWS Secret Manager. No keys will be read from AWS Secret Manager.
```yaml
general_settings:
key_management_system: "aws_secret_manager" # 👈 KEY CHANGE
key_management_settings:
store_virtual_keys: true
access_mode: "write_only" # Literal["read_only", "write_only", "read_and_write"]
```
</TabItem>
</Tabs>
3. Run proxy
```bash
@ -181,16 +205,14 @@ litellm --config /path/to/config.yaml
Use encrypted keys from Google KMS on the proxy
### Usage with LiteLLM Proxy Server
## Step 1. Add keys to env
Step 1. Add keys to env
```
export GOOGLE_APPLICATION_CREDENTIALS="/path/to/credentials.json"
export GOOGLE_KMS_RESOURCE_NAME="projects/*/locations/*/keyRings/*/cryptoKeys/*"
export PROXY_DATABASE_URL_ENCRYPTED=b'\n$\x00D\xac\xb4/\x8e\xc...'
```
## Step 2: Update Config
Step 2: Update Config
```yaml
general_settings:
@ -199,7 +221,7 @@ general_settings:
master_key: sk-1234
```
## Step 3: Start + test proxy
Step 3: Start + test proxy
```
$ litellm --config /path/to/config.yaml
@ -215,3 +237,17 @@ $ litellm --test
<!--
## .env Files
If no secret manager client is specified, Litellm automatically uses the `.env` file to manage sensitive data. -->
## All Secret Manager Settings
All settings related to secret management
```yaml
general_settings:
key_management_system: "aws_secret_manager" # REQUIRED
key_management_settings:
store_virtual_keys: true # OPTIONAL. Defaults to False, when True will store virtual keys in secret manager
access_mode: "write_only" # OPTIONAL. Literal["read_only", "write_only", "read_and_write"]. Defaults to "read_only"
hosted_keys: ["litellm_master_key"] # OPTIONAL. Specify which env keys you stored on AWS
```

View file

@ -305,7 +305,7 @@ secret_manager_client: Optional[Any] = (
)
_google_kms_resource_name: Optional[str] = None
_key_management_system: Optional[KeyManagementSystem] = None
_key_management_settings: Optional[KeyManagementSettings] = None
_key_management_settings: KeyManagementSettings = KeyManagementSettings()
#### PII MASKING ####
output_parse_pii: bool = False
#############################################
@ -962,6 +962,8 @@ from .utils import (
supports_response_schema,
supports_parallel_function_calling,
supports_vision,
supports_audio_input,
supports_audio_output,
supports_system_messages,
get_litellm_params,
acreate,

View file

@ -46,6 +46,9 @@ from litellm.llms.OpenAI.cost_calculation import (
from litellm.llms.OpenAI.cost_calculation import cost_per_token as openai_cost_per_token
from litellm.llms.OpenAI.cost_calculation import cost_router as openai_cost_router
from litellm.llms.together_ai.cost_calculator import get_model_params_and_category
from litellm.llms.vertex_ai_and_google_ai_studio.image_generation.cost_calculator import (
cost_calculator as vertex_ai_image_cost_calculator,
)
from litellm.types.llms.openai import HttpxBinaryResponseContent
from litellm.types.rerank import RerankResponse
from litellm.types.router import SPECIAL_MODEL_INFO_PARAMS
@ -667,9 +670,11 @@ def completion_cost( # noqa: PLR0915
):
### IMAGE GENERATION COST CALCULATION ###
if custom_llm_provider == "vertex_ai":
# https://cloud.google.com/vertex-ai/generative-ai/pricing
# Vertex Charges Flat $0.20 per image
return 0.020
if isinstance(completion_response, ImageResponse):
return vertex_ai_image_cost_calculator(
model=model,
image_response=completion_response,
)
elif custom_llm_provider == "bedrock":
if isinstance(completion_response, ImageResponse):
return bedrock_image_cost_calculator(

View file

@ -239,7 +239,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"ContextWindowExceededError: {exception_provider} - {message}",
llm_provider=custom_llm_provider,
model=model,
response=original_exception.response,
response=getattr(original_exception, "response", None),
litellm_debug_info=extra_information,
)
elif (
@ -251,7 +251,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"{exception_provider} - {message}",
llm_provider=custom_llm_provider,
model=model,
response=original_exception.response,
response=getattr(original_exception, "response", None),
litellm_debug_info=extra_information,
)
elif "A timeout occurred" in error_str:
@ -271,7 +271,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"ContentPolicyViolationError: {exception_provider} - {message}",
llm_provider=custom_llm_provider,
model=model,
response=original_exception.response,
response=getattr(original_exception, "response", None),
litellm_debug_info=extra_information,
)
elif (
@ -283,7 +283,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"{exception_provider} - {message}",
llm_provider=custom_llm_provider,
model=model,
response=original_exception.response,
response=getattr(original_exception, "response", None),
litellm_debug_info=extra_information,
)
elif "Web server is returning an unknown error" in error_str:
@ -299,7 +299,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"RateLimitError: {exception_provider} - {message}",
model=model,
llm_provider=custom_llm_provider,
response=original_exception.response,
response=getattr(original_exception, "response", None),
litellm_debug_info=extra_information,
)
elif (
@ -311,7 +311,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"AuthenticationError: {exception_provider} - {message}",
llm_provider=custom_llm_provider,
model=model,
response=original_exception.response,
response=getattr(original_exception, "response", None),
litellm_debug_info=extra_information,
)
elif "Mistral API raised a streaming error" in error_str:
@ -335,7 +335,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"{exception_provider} - {message}",
llm_provider=custom_llm_provider,
model=model,
response=original_exception.response,
response=getattr(original_exception, "response", None),
litellm_debug_info=extra_information,
)
elif original_exception.status_code == 401:
@ -344,7 +344,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"AuthenticationError: {exception_provider} - {message}",
llm_provider=custom_llm_provider,
model=model,
response=original_exception.response,
response=getattr(original_exception, "response", None),
litellm_debug_info=extra_information,
)
elif original_exception.status_code == 404:
@ -353,7 +353,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"NotFoundError: {exception_provider} - {message}",
model=model,
llm_provider=custom_llm_provider,
response=original_exception.response,
response=getattr(original_exception, "response", None),
litellm_debug_info=extra_information,
)
elif original_exception.status_code == 408:
@ -516,7 +516,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"ReplicateException - {error_str}",
llm_provider="replicate",
model=model,
response=original_exception.response,
response=getattr(original_exception, "response", None),
)
elif "input is too long" in error_str:
exception_mapping_worked = True
@ -524,7 +524,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"ReplicateException - {error_str}",
model=model,
llm_provider="replicate",
response=original_exception.response,
response=getattr(original_exception, "response", None),
)
elif exception_type == "ModelError":
exception_mapping_worked = True
@ -532,7 +532,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"ReplicateException - {error_str}",
model=model,
llm_provider="replicate",
response=original_exception.response,
response=getattr(original_exception, "response", None),
)
elif "Request was throttled" in error_str:
exception_mapping_worked = True
@ -540,7 +540,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"ReplicateException - {error_str}",
llm_provider="replicate",
model=model,
response=original_exception.response,
response=getattr(original_exception, "response", None),
)
elif hasattr(original_exception, "status_code"):
if original_exception.status_code == 401:
@ -549,7 +549,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"ReplicateException - {original_exception.message}",
llm_provider="replicate",
model=model,
response=original_exception.response,
response=getattr(original_exception, "response", None),
)
elif (
original_exception.status_code == 400
@ -560,7 +560,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"ReplicateException - {original_exception.message}",
model=model,
llm_provider="replicate",
response=original_exception.response,
response=getattr(original_exception, "response", None),
)
elif original_exception.status_code == 422:
exception_mapping_worked = True
@ -568,7 +568,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"ReplicateException - {original_exception.message}",
model=model,
llm_provider="replicate",
response=original_exception.response,
response=getattr(original_exception, "response", None),
)
elif original_exception.status_code == 408:
exception_mapping_worked = True
@ -583,7 +583,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"ReplicateException - {original_exception.message}",
llm_provider="replicate",
model=model,
response=original_exception.response,
response=getattr(original_exception, "response", None),
)
elif original_exception.status_code == 429:
exception_mapping_worked = True
@ -591,7 +591,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"ReplicateException - {original_exception.message}",
llm_provider="replicate",
model=model,
response=original_exception.response,
response=getattr(original_exception, "response", None),
)
elif original_exception.status_code == 500:
exception_mapping_worked = True
@ -599,7 +599,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"ReplicateException - {original_exception.message}",
llm_provider="replicate",
model=model,
response=original_exception.response,
response=getattr(original_exception, "response", None),
)
exception_mapping_worked = True
raise APIError(
@ -631,7 +631,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"{custom_llm_provider}Exception: Authentication Error - {error_str}",
llm_provider=custom_llm_provider,
model=model,
response=original_exception.response,
response=getattr(original_exception, "response", None),
litellm_debug_info=extra_information,
)
elif "token_quota_reached" in error_str:
@ -640,7 +640,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"{custom_llm_provider}Exception: Rate Limit Errror - {error_str}",
llm_provider=custom_llm_provider,
model=model,
response=original_exception.response,
response=getattr(original_exception, "response", None),
)
elif (
"The server received an invalid response from an upstream server."
@ -750,7 +750,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"BedrockException - {error_str}\n. Enable 'litellm.modify_params=True' (for PROXY do: `litellm_settings::modify_params: True`) to insert a dummy assistant message and fix this error.",
model=model,
llm_provider="bedrock",
response=original_exception.response,
response=getattr(original_exception, "response", None),
)
elif "Malformed input request" in error_str:
exception_mapping_worked = True
@ -758,7 +758,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"BedrockException - {error_str}",
model=model,
llm_provider="bedrock",
response=original_exception.response,
response=getattr(original_exception, "response", None),
)
elif "A conversation must start with a user message." in error_str:
exception_mapping_worked = True
@ -766,7 +766,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"BedrockException - {error_str}\n. Pass in default user message via `completion(..,user_continue_message=)` or enable `litellm.modify_params=True`.\nFor Proxy: do via `litellm_settings::modify_params: True` or user_continue_message under `litellm_params`",
model=model,
llm_provider="bedrock",
response=original_exception.response,
response=getattr(original_exception, "response", None),
)
elif (
"Unable to locate credentials" in error_str
@ -778,7 +778,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"BedrockException Invalid Authentication - {error_str}",
model=model,
llm_provider="bedrock",
response=original_exception.response,
response=getattr(original_exception, "response", None),
)
elif "AccessDeniedException" in error_str:
exception_mapping_worked = True
@ -786,7 +786,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"BedrockException PermissionDeniedError - {error_str}",
model=model,
llm_provider="bedrock",
response=original_exception.response,
response=getattr(original_exception, "response", None),
)
elif (
"throttlingException" in error_str
@ -797,7 +797,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"BedrockException: Rate Limit Error - {error_str}",
model=model,
llm_provider="bedrock",
response=original_exception.response,
response=getattr(original_exception, "response", None),
)
elif (
"Connect timeout on endpoint URL" in error_str
@ -836,7 +836,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"BedrockException - {original_exception.message}",
llm_provider="bedrock",
model=model,
response=original_exception.response,
response=getattr(original_exception, "response", None),
)
elif original_exception.status_code == 400:
exception_mapping_worked = True
@ -844,7 +844,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"BedrockException - {original_exception.message}",
llm_provider="bedrock",
model=model,
response=original_exception.response,
response=getattr(original_exception, "response", None),
)
elif original_exception.status_code == 404:
exception_mapping_worked = True
@ -852,7 +852,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"BedrockException - {original_exception.message}",
llm_provider="bedrock",
model=model,
response=original_exception.response,
response=getattr(original_exception, "response", None),
)
elif original_exception.status_code == 408:
exception_mapping_worked = True
@ -868,7 +868,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"BedrockException - {original_exception.message}",
model=model,
llm_provider=custom_llm_provider,
response=original_exception.response,
response=getattr(original_exception, "response", None),
litellm_debug_info=extra_information,
)
elif original_exception.status_code == 429:
@ -877,7 +877,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"BedrockException - {original_exception.message}",
model=model,
llm_provider=custom_llm_provider,
response=original_exception.response,
response=getattr(original_exception, "response", None),
litellm_debug_info=extra_information,
)
elif original_exception.status_code == 503:
@ -886,7 +886,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"BedrockException - {original_exception.message}",
model=model,
llm_provider=custom_llm_provider,
response=original_exception.response,
response=getattr(original_exception, "response", None),
litellm_debug_info=extra_information,
)
elif original_exception.status_code == 504: # gateway timeout error
@ -907,7 +907,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"litellm.BadRequestError: SagemakerException - {error_str}",
model=model,
llm_provider="sagemaker",
response=original_exception.response,
response=getattr(original_exception, "response", None),
)
elif (
"Input validation error: `best_of` must be > 0 and <= 2"
@ -918,7 +918,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message="SagemakerException - the value of 'n' must be > 0 and <= 2 for sagemaker endpoints",
model=model,
llm_provider="sagemaker",
response=original_exception.response,
response=getattr(original_exception, "response", None),
)
elif (
"`inputs` tokens + `max_new_tokens` must be <=" in error_str
@ -929,7 +929,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"SagemakerException - {error_str}",
model=model,
llm_provider="sagemaker",
response=original_exception.response,
response=getattr(original_exception, "response", None),
)
elif hasattr(original_exception, "status_code"):
if original_exception.status_code == 500:
@ -951,7 +951,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"SagemakerException - {original_exception.message}",
llm_provider=custom_llm_provider,
model=model,
response=original_exception.response,
response=getattr(original_exception, "response", None),
)
elif original_exception.status_code == 400:
exception_mapping_worked = True
@ -959,7 +959,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"SagemakerException - {original_exception.message}",
llm_provider=custom_llm_provider,
model=model,
response=original_exception.response,
response=getattr(original_exception, "response", None),
)
elif original_exception.status_code == 404:
exception_mapping_worked = True
@ -967,7 +967,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"SagemakerException - {original_exception.message}",
llm_provider=custom_llm_provider,
model=model,
response=original_exception.response,
response=getattr(original_exception, "response", None),
)
elif original_exception.status_code == 408:
exception_mapping_worked = True
@ -986,7 +986,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"SagemakerException - {original_exception.message}",
model=model,
llm_provider=custom_llm_provider,
response=original_exception.response,
response=getattr(original_exception, "response", None),
litellm_debug_info=extra_information,
)
elif original_exception.status_code == 429:
@ -995,7 +995,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"SagemakerException - {original_exception.message}",
model=model,
llm_provider=custom_llm_provider,
response=original_exception.response,
response=getattr(original_exception, "response", None),
litellm_debug_info=extra_information,
)
elif original_exception.status_code == 503:
@ -1004,7 +1004,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"SagemakerException - {original_exception.message}",
model=model,
llm_provider=custom_llm_provider,
response=original_exception.response,
response=getattr(original_exception, "response", None),
litellm_debug_info=extra_information,
)
elif original_exception.status_code == 504: # gateway timeout error
@ -1217,7 +1217,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message="GeminiException - Invalid api key",
model=model,
llm_provider="palm",
response=original_exception.response,
response=getattr(original_exception, "response", None),
)
if (
"504 Deadline expired before operation could complete." in error_str
@ -1235,7 +1235,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"GeminiException - {error_str}",
model=model,
llm_provider="palm",
response=original_exception.response,
response=getattr(original_exception, "response", None),
)
if (
"500 An internal error has occurred." in error_str
@ -1262,7 +1262,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"GeminiException - {error_str}",
model=model,
llm_provider="palm",
response=original_exception.response,
response=getattr(original_exception, "response", None),
)
# Failed: Error occurred: 400 Request payload size exceeds the limit: 20000 bytes
elif custom_llm_provider == "cloudflare":
@ -1272,7 +1272,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"Cloudflare Exception - {original_exception.message}",
llm_provider="cloudflare",
model=model,
response=original_exception.response,
response=getattr(original_exception, "response", None),
)
if "must have required property" in error_str:
exception_mapping_worked = True
@ -1280,7 +1280,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"Cloudflare Exception - {original_exception.message}",
llm_provider="cloudflare",
model=model,
response=original_exception.response,
response=getattr(original_exception, "response", None),
)
elif (
custom_llm_provider == "cohere" or custom_llm_provider == "cohere_chat"
@ -1294,7 +1294,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"CohereException - {original_exception.message}",
llm_provider="cohere",
model=model,
response=original_exception.response,
response=getattr(original_exception, "response", None),
)
elif "too many tokens" in error_str:
exception_mapping_worked = True
@ -1302,7 +1302,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"CohereException - {original_exception.message}",
model=model,
llm_provider="cohere",
response=original_exception.response,
response=getattr(original_exception, "response", None),
)
elif hasattr(original_exception, "status_code"):
if (
@ -1314,7 +1314,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"CohereException - {original_exception.message}",
llm_provider="cohere",
model=model,
response=original_exception.response,
response=getattr(original_exception, "response", None),
)
elif original_exception.status_code == 408:
exception_mapping_worked = True
@ -1329,7 +1329,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"CohereException - {original_exception.message}",
llm_provider="cohere",
model=model,
response=original_exception.response,
response=getattr(original_exception, "response", None),
)
elif (
"CohereConnectionError" in exception_type
@ -1339,7 +1339,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"CohereException - {original_exception.message}",
llm_provider="cohere",
model=model,
response=original_exception.response,
response=getattr(original_exception, "response", None),
)
elif "invalid type:" in error_str:
exception_mapping_worked = True
@ -1347,7 +1347,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"CohereException - {original_exception.message}",
llm_provider="cohere",
model=model,
response=original_exception.response,
response=getattr(original_exception, "response", None),
)
elif "Unexpected server error" in error_str:
exception_mapping_worked = True
@ -1355,7 +1355,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"CohereException - {original_exception.message}",
llm_provider="cohere",
model=model,
response=original_exception.response,
response=getattr(original_exception, "response", None),
)
else:
if hasattr(original_exception, "status_code"):
@ -1375,7 +1375,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=error_str,
model=model,
llm_provider="huggingface",
response=original_exception.response,
response=getattr(original_exception, "response", None),
)
elif "A valid user token is required" in error_str:
exception_mapping_worked = True
@ -1383,7 +1383,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=error_str,
llm_provider="huggingface",
model=model,
response=original_exception.response,
response=getattr(original_exception, "response", None),
)
elif "Rate limit reached" in error_str:
exception_mapping_worked = True
@ -1391,7 +1391,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=error_str,
llm_provider="huggingface",
model=model,
response=original_exception.response,
response=getattr(original_exception, "response", None),
)
if hasattr(original_exception, "status_code"):
if original_exception.status_code == 401:
@ -1400,7 +1400,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"HuggingfaceException - {original_exception.message}",
llm_provider="huggingface",
model=model,
response=original_exception.response,
response=getattr(original_exception, "response", None),
)
elif original_exception.status_code == 400:
exception_mapping_worked = True
@ -1408,7 +1408,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"HuggingfaceException - {original_exception.message}",
model=model,
llm_provider="huggingface",
response=original_exception.response,
response=getattr(original_exception, "response", None),
)
elif original_exception.status_code == 408:
exception_mapping_worked = True
@ -1423,7 +1423,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"HuggingfaceException - {original_exception.message}",
llm_provider="huggingface",
model=model,
response=original_exception.response,
response=getattr(original_exception, "response", None),
)
elif original_exception.status_code == 503:
exception_mapping_worked = True
@ -1431,7 +1431,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"HuggingfaceException - {original_exception.message}",
llm_provider="huggingface",
model=model,
response=original_exception.response,
response=getattr(original_exception, "response", None),
)
else:
exception_mapping_worked = True
@ -1450,7 +1450,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"AI21Exception - {original_exception.message}",
model=model,
llm_provider="ai21",
response=original_exception.response,
response=getattr(original_exception, "response", None),
)
if "Bad or missing API token." in original_exception.message:
exception_mapping_worked = True
@ -1458,7 +1458,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"AI21Exception - {original_exception.message}",
model=model,
llm_provider="ai21",
response=original_exception.response,
response=getattr(original_exception, "response", None),
)
if hasattr(original_exception, "status_code"):
if original_exception.status_code == 401:
@ -1467,7 +1467,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"AI21Exception - {original_exception.message}",
llm_provider="ai21",
model=model,
response=original_exception.response,
response=getattr(original_exception, "response", None),
)
elif original_exception.status_code == 408:
exception_mapping_worked = True
@ -1482,7 +1482,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"AI21Exception - {original_exception.message}",
model=model,
llm_provider="ai21",
response=original_exception.response,
response=getattr(original_exception, "response", None),
)
elif original_exception.status_code == 429:
exception_mapping_worked = True
@ -1490,7 +1490,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"AI21Exception - {original_exception.message}",
llm_provider="ai21",
model=model,
response=original_exception.response,
response=getattr(original_exception, "response", None),
)
else:
exception_mapping_worked = True
@ -1509,7 +1509,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"NLPCloudException - {error_str}",
model=model,
llm_provider="nlp_cloud",
response=original_exception.response,
response=getattr(original_exception, "response", None),
)
elif "value is not a valid" in error_str:
exception_mapping_worked = True
@ -1517,7 +1517,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"NLPCloudException - {error_str}",
model=model,
llm_provider="nlp_cloud",
response=original_exception.response,
response=getattr(original_exception, "response", None),
)
else:
exception_mapping_worked = True
@ -1542,7 +1542,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"NLPCloudException - {original_exception.message}",
llm_provider="nlp_cloud",
model=model,
response=original_exception.response,
response=getattr(original_exception, "response", None),
)
elif (
original_exception.status_code == 401
@ -1553,7 +1553,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"NLPCloudException - {original_exception.message}",
llm_provider="nlp_cloud",
model=model,
response=original_exception.response,
response=getattr(original_exception, "response", None),
)
elif (
original_exception.status_code == 522
@ -1574,7 +1574,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"NLPCloudException - {original_exception.message}",
llm_provider="nlp_cloud",
model=model,
response=original_exception.response,
response=getattr(original_exception, "response", None),
)
elif (
original_exception.status_code == 500
@ -1597,7 +1597,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"NLPCloudException - {original_exception.message}",
model=model,
llm_provider="nlp_cloud",
response=original_exception.response,
response=getattr(original_exception, "response", None),
)
else:
exception_mapping_worked = True
@ -1623,7 +1623,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"TogetherAIException - {error_response['error']}",
model=model,
llm_provider="together_ai",
response=original_exception.response,
response=getattr(original_exception, "response", None),
)
elif (
"error" in error_response
@ -1634,7 +1634,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"TogetherAIException - {error_response['error']}",
llm_provider="together_ai",
model=model,
response=original_exception.response,
response=getattr(original_exception, "response", None),
)
elif (
"error" in error_response
@ -1645,7 +1645,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"TogetherAIException - {error_response['error']}",
model=model,
llm_provider="together_ai",
response=original_exception.response,
response=getattr(original_exception, "response", None),
)
elif "A timeout occurred" in error_str:
exception_mapping_worked = True
@ -1664,7 +1664,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"TogetherAIException - {error_response['error']}",
model=model,
llm_provider="together_ai",
response=original_exception.response,
response=getattr(original_exception, "response", None),
)
elif (
"error_type" in error_response
@ -1675,7 +1675,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"TogetherAIException - {error_response['error']}",
model=model,
llm_provider="together_ai",
response=original_exception.response,
response=getattr(original_exception, "response", None),
)
if hasattr(original_exception, "status_code"):
if original_exception.status_code == 408:
@ -1691,7 +1691,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"TogetherAIException - {error_response['error']}",
model=model,
llm_provider="together_ai",
response=original_exception.response,
response=getattr(original_exception, "response", None),
)
elif original_exception.status_code == 429:
exception_mapping_worked = True
@ -1699,7 +1699,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"TogetherAIException - {original_exception.message}",
llm_provider="together_ai",
model=model,
response=original_exception.response,
response=getattr(original_exception, "response", None),
)
elif original_exception.status_code == 524:
exception_mapping_worked = True
@ -1727,7 +1727,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"AlephAlphaException - {original_exception.message}",
llm_provider="aleph_alpha",
model=model,
response=original_exception.response,
response=getattr(original_exception, "response", None),
)
elif "InvalidToken" in error_str or "No token provided" in error_str:
exception_mapping_worked = True
@ -1735,7 +1735,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"AlephAlphaException - {original_exception.message}",
llm_provider="aleph_alpha",
model=model,
response=original_exception.response,
response=getattr(original_exception, "response", None),
)
elif hasattr(original_exception, "status_code"):
verbose_logger.debug(
@ -1754,7 +1754,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"AlephAlphaException - {original_exception.message}",
llm_provider="aleph_alpha",
model=model,
response=original_exception.response,
response=getattr(original_exception, "response", None),
)
elif original_exception.status_code == 429:
exception_mapping_worked = True
@ -1762,7 +1762,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"AlephAlphaException - {original_exception.message}",
llm_provider="aleph_alpha",
model=model,
response=original_exception.response,
response=getattr(original_exception, "response", None),
)
elif original_exception.status_code == 500:
exception_mapping_worked = True
@ -1770,7 +1770,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"AlephAlphaException - {original_exception.message}",
llm_provider="aleph_alpha",
model=model,
response=original_exception.response,
response=getattr(original_exception, "response", None),
)
raise original_exception
raise original_exception
@ -1787,7 +1787,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"OllamaException: Invalid Model/Model not loaded - {original_exception}",
model=model,
llm_provider="ollama",
response=original_exception.response,
response=getattr(original_exception, "response", None),
)
elif "Failed to establish a new connection" in error_str:
exception_mapping_worked = True
@ -1795,7 +1795,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"OllamaException: {original_exception}",
llm_provider="ollama",
model=model,
response=original_exception.response,
response=getattr(original_exception, "response", None),
)
elif "Invalid response object from API" in error_str:
exception_mapping_worked = True
@ -1803,7 +1803,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"OllamaException: {original_exception}",
llm_provider="ollama",
model=model,
response=original_exception.response,
response=getattr(original_exception, "response", None),
)
elif "Read timed out" in error_str:
exception_mapping_worked = True
@ -1837,6 +1837,7 @@ def exception_type( # type: ignore # noqa: PLR0915
llm_provider="azure",
model=model,
litellm_debug_info=extra_information,
response=getattr(original_exception, "response", None),
)
elif "This model's maximum context length is" in error_str:
exception_mapping_worked = True
@ -1845,6 +1846,7 @@ def exception_type( # type: ignore # noqa: PLR0915
llm_provider="azure",
model=model,
litellm_debug_info=extra_information,
response=getattr(original_exception, "response", None),
)
elif "DeploymentNotFound" in error_str:
exception_mapping_worked = True
@ -1853,6 +1855,7 @@ def exception_type( # type: ignore # noqa: PLR0915
llm_provider="azure",
model=model,
litellm_debug_info=extra_information,
response=getattr(original_exception, "response", None),
)
elif (
(
@ -1873,6 +1876,7 @@ def exception_type( # type: ignore # noqa: PLR0915
llm_provider="azure",
model=model,
litellm_debug_info=extra_information,
response=getattr(original_exception, "response", None),
)
elif "invalid_request_error" in error_str:
exception_mapping_worked = True
@ -1881,6 +1885,7 @@ def exception_type( # type: ignore # noqa: PLR0915
llm_provider="azure",
model=model,
litellm_debug_info=extra_information,
response=getattr(original_exception, "response", None),
)
elif (
"The api_key client option must be set either by passing api_key to the client or by setting"
@ -1892,6 +1897,7 @@ def exception_type( # type: ignore # noqa: PLR0915
llm_provider=custom_llm_provider,
model=model,
litellm_debug_info=extra_information,
response=getattr(original_exception, "response", None),
)
elif "Connection error" in error_str:
exception_mapping_worked = True
@ -1910,6 +1916,7 @@ def exception_type( # type: ignore # noqa: PLR0915
llm_provider="azure",
model=model,
litellm_debug_info=extra_information,
response=getattr(original_exception, "response", None),
)
elif original_exception.status_code == 401:
exception_mapping_worked = True
@ -1918,6 +1925,7 @@ def exception_type( # type: ignore # noqa: PLR0915
llm_provider="azure",
model=model,
litellm_debug_info=extra_information,
response=getattr(original_exception, "response", None),
)
elif original_exception.status_code == 408:
exception_mapping_worked = True
@ -1934,6 +1942,7 @@ def exception_type( # type: ignore # noqa: PLR0915
model=model,
llm_provider="azure",
litellm_debug_info=extra_information,
response=getattr(original_exception, "response", None),
)
elif original_exception.status_code == 429:
exception_mapping_worked = True
@ -1942,6 +1951,7 @@ def exception_type( # type: ignore # noqa: PLR0915
model=model,
llm_provider="azure",
litellm_debug_info=extra_information,
response=getattr(original_exception, "response", None),
)
elif original_exception.status_code == 503:
exception_mapping_worked = True
@ -1950,6 +1960,7 @@ def exception_type( # type: ignore # noqa: PLR0915
model=model,
llm_provider="azure",
litellm_debug_info=extra_information,
response=getattr(original_exception, "response", None),
)
elif original_exception.status_code == 504: # gateway timeout error
exception_mapping_worked = True
@ -1989,7 +2000,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"{exception_provider} - {error_str}",
llm_provider=custom_llm_provider,
model=model,
response=original_exception.response,
response=getattr(original_exception, "response", None),
litellm_debug_info=extra_information,
)
elif original_exception.status_code == 401:
@ -1998,7 +2009,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"AuthenticationError: {exception_provider} - {error_str}",
llm_provider=custom_llm_provider,
model=model,
response=original_exception.response,
response=getattr(original_exception, "response", None),
litellm_debug_info=extra_information,
)
elif original_exception.status_code == 404:
@ -2007,7 +2018,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"NotFoundError: {exception_provider} - {error_str}",
model=model,
llm_provider=custom_llm_provider,
response=original_exception.response,
response=getattr(original_exception, "response", None),
litellm_debug_info=extra_information,
)
elif original_exception.status_code == 408:
@ -2024,7 +2035,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"BadRequestError: {exception_provider} - {error_str}",
model=model,
llm_provider=custom_llm_provider,
response=original_exception.response,
response=getattr(original_exception, "response", None),
litellm_debug_info=extra_information,
)
elif original_exception.status_code == 429:
@ -2033,7 +2044,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"RateLimitError: {exception_provider} - {error_str}",
model=model,
llm_provider=custom_llm_provider,
response=original_exception.response,
response=getattr(original_exception, "response", None),
litellm_debug_info=extra_information,
)
elif original_exception.status_code == 503:
@ -2042,7 +2053,7 @@ def exception_type( # type: ignore # noqa: PLR0915
message=f"ServiceUnavailableError: {exception_provider} - {error_str}",
model=model,
llm_provider=custom_llm_provider,
response=original_exception.response,
response=getattr(original_exception, "response", None),
litellm_debug_info=extra_information,
)
elif original_exception.status_code == 504: # gateway timeout error

View file

@ -202,6 +202,7 @@ class Logging:
start_time,
litellm_call_id: str,
function_id: str,
litellm_trace_id: Optional[str] = None,
dynamic_input_callbacks: Optional[
List[Union[str, Callable, CustomLogger]]
] = None,
@ -239,6 +240,7 @@ class Logging:
self.start_time = start_time # log the call start time
self.call_type = call_type
self.litellm_call_id = litellm_call_id
self.litellm_trace_id = litellm_trace_id
self.function_id = function_id
self.streaming_chunks: List[Any] = [] # for generating complete stream response
self.sync_streaming_chunks: List[Any] = (
@ -275,6 +277,11 @@ class Logging:
self.completion_start_time: Optional[datetime.datetime] = None
self._llm_caching_handler: Optional[LLMCachingHandler] = None
self.model_call_details = {
"litellm_trace_id": litellm_trace_id,
"litellm_call_id": litellm_call_id,
}
def process_dynamic_callbacks(self):
"""
Initializes CustomLogger compatible callbacks in self.dynamic_* callbacks
@ -382,7 +389,8 @@ class Logging:
self.logger_fn = litellm_params.get("logger_fn", None)
verbose_logger.debug(f"self.optional_params: {self.optional_params}")
self.model_call_details = {
self.model_call_details.update(
{
"model": self.model,
"messages": self.messages,
"optional_params": self.optional_params,
@ -397,6 +405,7 @@ class Logging:
**self.optional_params,
**additional_params,
}
)
## check if stream options is set ## - used by CustomStreamWrapper for easy instrumentation
if "stream_options" in additional_params:
@ -2823,6 +2832,7 @@ def get_standard_logging_object_payload(
payload: StandardLoggingPayload = StandardLoggingPayload(
id=str(id),
trace_id=kwargs.get("litellm_trace_id"), # type: ignore
call_type=call_type or "",
cache_hit=cache_hit,
status=status,

View file

@ -44,7 +44,9 @@ from litellm.types.llms.openai import (
ChatCompletionToolCallFunctionChunk,
ChatCompletionUsageBlock,
)
from litellm.types.utils import GenericStreamingChunk, PromptTokensDetailsWrapper
from litellm.types.utils import GenericStreamingChunk
from litellm.types.utils import Message as LitellmMessage
from litellm.types.utils import PromptTokensDetailsWrapper
from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
from ...base import BaseLLM
@ -94,6 +96,7 @@ async def make_call(
messages: list,
logging_obj,
timeout: Optional[Union[float, httpx.Timeout]],
json_mode: bool,
) -> Tuple[Any, httpx.Headers]:
if client is None:
client = litellm.module_level_aclient
@ -119,7 +122,9 @@ async def make_call(
raise AnthropicError(status_code=500, message=str(e))
completion_stream = ModelResponseIterator(
streaming_response=response.aiter_lines(), sync_stream=False
streaming_response=response.aiter_lines(),
sync_stream=False,
json_mode=json_mode,
)
# LOGGING
@ -142,6 +147,7 @@ def make_sync_call(
messages: list,
logging_obj,
timeout: Optional[Union[float, httpx.Timeout]],
json_mode: bool,
) -> Tuple[Any, httpx.Headers]:
if client is None:
client = litellm.module_level_client # re-use a module level client
@ -175,7 +181,7 @@ def make_sync_call(
)
completion_stream = ModelResponseIterator(
streaming_response=response.iter_lines(), sync_stream=True
streaming_response=response.iter_lines(), sync_stream=True, json_mode=json_mode
)
# LOGGING
@ -270,11 +276,12 @@ class AnthropicChatCompletion(BaseLLM):
"arguments"
)
if json_mode_content_str is not None:
args = json.loads(json_mode_content_str)
values: Optional[dict] = args.get("values")
if values is not None:
_message = litellm.Message(content=json.dumps(values))
_converted_message = self._convert_tool_response_to_message(
tool_calls=tool_calls,
)
if _converted_message is not None:
completion_response["stop_reason"] = "stop"
_message = _converted_message
model_response.choices[0].message = _message # type: ignore
model_response._hidden_params["original_response"] = completion_response[
"content"
@ -318,6 +325,37 @@ class AnthropicChatCompletion(BaseLLM):
model_response._hidden_params = _hidden_params
return model_response
@staticmethod
def _convert_tool_response_to_message(
tool_calls: List[ChatCompletionToolCallChunk],
) -> Optional[LitellmMessage]:
"""
In JSON mode, the Anthropic API returns the JSON schema as a tool call; we convert it to a message to follow the OpenAI format
"""
## HANDLE JSON MODE - anthropic returns single function call
json_mode_content_str: Optional[str] = tool_calls[0]["function"].get(
"arguments"
)
try:
if json_mode_content_str is not None:
args = json.loads(json_mode_content_str)
if (
isinstance(args, dict)
and (values := args.get("values")) is not None
):
_message = litellm.Message(content=json.dumps(values))
return _message
else:
# often the `values` key is not present in the tool response
# relevant issue: https://github.com/BerriAI/litellm/issues/6741
_message = litellm.Message(content=json.dumps(args))
return _message
except json.JSONDecodeError:
# if JSON decoding fails, return the original tool response string
return litellm.Message(content=json_mode_content_str)
return None
async def acompletion_stream_function(
self,
model: str,
@ -334,6 +372,7 @@ class AnthropicChatCompletion(BaseLLM):
stream,
_is_function_call,
data: dict,
json_mode: bool,
optional_params=None,
litellm_params=None,
logger_fn=None,
@ -350,6 +389,7 @@ class AnthropicChatCompletion(BaseLLM):
messages=messages,
logging_obj=logging_obj,
timeout=timeout,
json_mode=json_mode,
)
streamwrapper = CustomStreamWrapper(
completion_stream=completion_stream,
@ -440,8 +480,8 @@ class AnthropicChatCompletion(BaseLLM):
logging_obj,
optional_params: dict,
timeout: Union[float, httpx.Timeout],
litellm_params: dict,
acompletion=None,
litellm_params=None,
logger_fn=None,
headers={},
client=None,
@ -464,6 +504,7 @@ class AnthropicChatCompletion(BaseLLM):
model=model,
messages=messages,
optional_params=optional_params,
litellm_params=litellm_params,
headers=headers,
_is_function_call=_is_function_call,
is_vertex_request=is_vertex_request,
@ -500,6 +541,7 @@ class AnthropicChatCompletion(BaseLLM):
optional_params=optional_params,
stream=stream,
_is_function_call=_is_function_call,
json_mode=json_mode,
litellm_params=litellm_params,
logger_fn=logger_fn,
headers=headers,
@ -547,6 +589,7 @@ class AnthropicChatCompletion(BaseLLM):
messages=messages,
logging_obj=logging_obj,
timeout=timeout,
json_mode=json_mode,
)
return CustomStreamWrapper(
completion_stream=completion_stream,
@ -605,11 +648,14 @@ class AnthropicChatCompletion(BaseLLM):
class ModelResponseIterator:
def __init__(self, streaming_response, sync_stream: bool):
def __init__(
self, streaming_response, sync_stream: bool, json_mode: Optional[bool] = False
):
self.streaming_response = streaming_response
self.response_iterator = self.streaming_response
self.content_blocks: List[ContentBlockDelta] = []
self.tool_index = -1
self.json_mode = json_mode
def check_empty_tool_call_args(self) -> bool:
"""
@ -771,6 +817,8 @@ class ModelResponseIterator:
status_code=500, # it looks like Anthropic API does not return a status code in the chunk error - default to 500
)
text, tool_use = self._handle_json_mode_chunk(text=text, tool_use=tool_use)
returned_chunk = GenericStreamingChunk(
text=text,
tool_use=tool_use,
@ -785,6 +833,34 @@ class ModelResponseIterator:
except json.JSONDecodeError:
raise ValueError(f"Failed to decode JSON from chunk: {chunk}")
def _handle_json_mode_chunk(
self, text: str, tool_use: Optional[ChatCompletionToolCallChunk]
) -> Tuple[str, Optional[ChatCompletionToolCallChunk]]:
"""
If JSON mode is enabled, convert the tool call to a message.
Anthropic returns the JSON schema as part of the tool call; OpenAI returns it as part of the content.
This handles placing the schema in the content.
Args:
text: str
tool_use: Optional[ChatCompletionToolCallChunk]
Returns:
Tuple[str, Optional[ChatCompletionToolCallChunk]]
text: The text to use in the content
tool_use: The ChatCompletionToolCallChunk to use in the chunk response
"""
if self.json_mode is True and tool_use is not None:
message = AnthropicChatCompletion._convert_tool_response_to_message(
tool_calls=[tool_use]
)
if message is not None:
text = message.content or ""
tool_use = None
return text, tool_use
# Sync iterator
def __iter__(self):
return self
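
For illustration, a minimal standalone sketch of the JSON-mode conversion introduced above: the first tool call's `arguments` string is parsed, the `values` wrapper is unwrapped when present, and on a decode error the raw string is kept. Function and variable names here are illustrative, not part of the library API.

```python
import json
from typing import Optional

def tool_args_to_content(arguments: Optional[str]) -> Optional[str]:
    """Illustrative re-implementation of the _convert_tool_response_to_message logic."""
    if arguments is None:
        return None
    try:
        args = json.loads(arguments)
        if isinstance(args, dict) and args.get("values") is not None:
            return json.dumps(args["values"])  # unwrap the {"values": ...} wrapper
        return json.dumps(args)  # no "values" key - return the parsed args as-is
    except json.JSONDecodeError:
        return arguments  # undecodable - fall back to the raw string

print(tool_args_to_content('{"values": {"name": "Alice", "age": 30}}'))  # {"name": "Alice", "age": 30}
print(tool_args_to_content('{"name": "Alice"}'))                         # {"name": "Alice"}
```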

View file

@ -91,6 +91,7 @@ class AnthropicConfig:
"extra_headers",
"parallel_tool_calls",
"response_format",
"user",
]
def get_cache_control_headers(self) -> dict:
@ -246,6 +247,28 @@ class AnthropicConfig:
anthropic_tools.append(new_tool)
return anthropic_tools
def _map_stop_sequences(
self, stop: Optional[Union[str, List[str]]]
) -> Optional[List[str]]:
new_stop: Optional[List[str]] = None
if isinstance(stop, str):
if (
stop == "\n"
) and litellm.drop_params is True: # anthropic doesn't allow whitespace characters as stop-sequences
return new_stop
new_stop = [stop]
elif isinstance(stop, list):
new_v = []
for v in stop:
if (
v == "\n"
) and litellm.drop_params is True: # anthropic doesn't allow whitespace characters as stop-sequences
continue
new_v.append(v)
if len(new_v) > 0:
new_stop = new_v
return new_stop
def map_openai_params(
self,
non_default_params: dict,
@ -271,26 +294,10 @@ class AnthropicConfig:
optional_params["tool_choice"] = _tool_choice
if param == "stream" and value is True:
optional_params["stream"] = value
if param == "stop":
if isinstance(value, str):
if (
value == "\n"
) and litellm.drop_params is True: # anthropic doesn't allow whitespace characters as stop-sequences
continue
value = [value]
elif isinstance(value, list):
new_v = []
for v in value:
if (
v == "\n"
) and litellm.drop_params is True: # anthropic doesn't allow whitespace characters as stop-sequences
continue
new_v.append(v)
if len(new_v) > 0:
value = new_v
else:
continue
optional_params["stop_sequences"] = value
if param == "stop" and (isinstance(value, str) or isinstance(value, list)):
_value = self._map_stop_sequences(value)
if _value is not None:
optional_params["stop_sequences"] = _value
if param == "temperature":
optional_params["temperature"] = value
if param == "top_p":
@ -314,7 +321,8 @@ class AnthropicConfig:
optional_params["tools"] = [_tool]
optional_params["tool_choice"] = _tool_choice
optional_params["json_mode"] = True
if param == "user":
optional_params["metadata"] = {"user_id": value}
## VALIDATE REQUEST
"""
Anthropic doesn't support tool calling without `tools=` param specified.
@ -465,6 +473,7 @@ class AnthropicConfig:
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
headers: dict,
_is_function_call: bool,
is_vertex_request: bool,
@ -502,6 +511,15 @@ class AnthropicConfig:
if "tools" in optional_params:
_is_function_call = True
## Handle user_id in metadata
_litellm_metadata = litellm_params.get("metadata", None)
if (
_litellm_metadata
and isinstance(_litellm_metadata, dict)
and "user_id" in _litellm_metadata
):
optional_params["metadata"] = {"user_id": _litellm_metadata["user_id"]}
data = {
"messages": anthropic_messages,
**optional_params,
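
For illustration, a small sketch of the stop-sequence mapping factored out above, assuming `litellm.drop_params` is enabled (the constant below stands in for it): lone newlines are dropped because Anthropic rejects whitespace-only stop sequences, and everything else passes through as `stop_sequences`.

```python
from typing import List, Optional, Union

DROP_PARAMS = True  # stand-in for litellm.drop_params

def map_stop_sequences(stop: Union[str, List[str]]) -> Optional[List[str]]:
    """Illustrative re-implementation of AnthropicConfig._map_stop_sequences."""
    if isinstance(stop, str):
        if stop == "\n" and DROP_PARAMS:
            return None  # whitespace-only stop sequences are rejected by Anthropic
        return [stop]
    filtered = [s for s in stop if not (s == "\n" and DROP_PARAMS)]
    return filtered or None

print(map_stop_sequences("\n"))             # None
print(map_stop_sequences(["\n", "User:"]))  # ['User:']
```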

View file

@ -53,8 +53,14 @@ class AmazonStability3Config:
sd3-medium
sd3.5-large
sd3.5-large-turbo
Stability ultra models
stable-image-ultra-v1
"""
if model and ("sd3" in model or "sd3.5" in model):
if model:
if "sd3" in model or "sd3.5" in model:
return True
if "stable-image-ultra-v1" in model:
return True
return False

View file

@ -8,3 +8,4 @@ class httpxSpecialProvider(str, Enum):
GuardrailCallback = "guardrail_callback"
Caching = "caching"
Oauth2Check = "oauth2_check"
SecretManager = "secret_manager"

View file

@ -76,4 +76,4 @@ class JinaAIEmbeddingConfig:
or get_secret_str("JINA_AI_API_KEY")
or get_secret_str("JINA_AI_TOKEN")
)
return LlmProviders.OPENAI_LIKE.value, api_base, dynamic_api_key
return LlmProviders.JINA_AI.value, api_base, dynamic_api_key

View file

@ -0,0 +1,96 @@
"""
Rerank API
LiteLLM supports the rerank API format; no parameter transformation occurs.
"""
import uuid
from typing import Any, Dict, List, Optional, Union
import httpx
from pydantic import BaseModel
import litellm
from litellm.llms.base import BaseLLM
from litellm.llms.custom_httpx.http_handler import (
_get_httpx_client,
get_async_httpx_client,
)
from litellm.llms.jina_ai.rerank.transformation import JinaAIRerankConfig
from litellm.types.rerank import RerankRequest, RerankResponse
class JinaAIRerank(BaseLLM):
def rerank(
self,
model: str,
api_key: str,
query: str,
documents: List[Union[str, Dict[str, Any]]],
top_n: Optional[int] = None,
rank_fields: Optional[List[str]] = None,
return_documents: Optional[bool] = True,
max_chunks_per_doc: Optional[int] = None,
_is_async: Optional[bool] = False,
) -> RerankResponse:
client = _get_httpx_client()
request_data = RerankRequest(
model=model,
query=query,
top_n=top_n,
documents=documents,
rank_fields=rank_fields,
return_documents=return_documents,
)
# exclude None values from request_data
request_data_dict = request_data.dict(exclude_none=True)
if _is_async:
return self.async_rerank(request_data_dict, api_key) # type: ignore # Call async method
response = client.post(
"https://api.jina.ai/v1/rerank",
headers={
"accept": "application/json",
"content-type": "application/json",
"authorization": f"Bearer {api_key}",
},
json=request_data_dict,
)
if response.status_code != 200:
raise Exception(response.text)
_json_response = response.json()
return JinaAIRerankConfig()._transform_response(_json_response)
async def async_rerank( # New async method
self,
request_data_dict: Dict[str, Any],
api_key: str,
) -> RerankResponse:
client = get_async_httpx_client(
llm_provider=litellm.LlmProviders.JINA_AI
) # Use async client
response = await client.post(
"https://api.jina.ai/v1/rerank",
headers={
"accept": "application/json",
"content-type": "application/json",
"authorization": f"Bearer {api_key}",
},
json=request_data_dict,
)
if response.status_code != 200:
raise Exception(response.text)
_json_response = response.json()
return JinaAIRerankConfig()._transform_response(_json_response)
pass
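
With this handler wired into `litellm.rerank` (see the rerank.py change later in this commit), a call could look like the sketch below; the query and documents are illustrative, and the API key is assumed to be set via `JINA_AI_API_KEY`.

```python
import litellm

# assumes JINA_AI_API_KEY is set in the environment
response = litellm.rerank(
    model="jina_ai/jina-reranker-v2-base-multilingual",
    query="What is the capital of France?",
    documents=[
        "Paris is the capital of France.",
        "Berlin is the capital of Germany.",
    ],
    top_n=1,
)
print(response.results)
```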

View file

@ -0,0 +1,36 @@
"""
Transformation logic from Cohere's /v1/rerank format to Jina AI's `/v1/rerank` format.
Why separate file? Make it easy to see how transformation works
Docs - https://jina.ai/reranker
"""
import uuid
from typing import List, Optional
from litellm.types.rerank import (
RerankBilledUnits,
RerankResponse,
RerankResponseMeta,
RerankTokens,
)
class JinaAIRerankConfig:
def _transform_response(self, response: dict) -> RerankResponse:
_billed_units = RerankBilledUnits(**response.get("usage", {}))
_tokens = RerankTokens(**response.get("usage", {}))
rerank_meta = RerankResponseMeta(billed_units=_billed_units, tokens=_tokens)
_results: Optional[List[dict]] = response.get("results")
if _results is None:
raise ValueError(f"No results found in the response={response}")
return RerankResponse(
id=response.get("id") or str(uuid.uuid4()),
results=_results,
meta=rerank_meta,
) # Return response

View file

@ -185,6 +185,8 @@ class OllamaConfig:
"name": "mistral"
}'
"""
if model.startswith("ollama/") or model.startswith("ollama_chat/"):
model = model.split("/", 1)[1]
api_base = get_secret_str("OLLAMA_API_BASE") or "http://localhost:11434"
try:

View file

@ -15,7 +15,14 @@ from litellm.llms.custom_httpx.http_handler import (
_get_httpx_client,
get_async_httpx_client,
)
from litellm.types.rerank import RerankRequest, RerankResponse
from litellm.llms.together_ai.rerank.transformation import TogetherAIRerankConfig
from litellm.types.rerank import (
RerankBilledUnits,
RerankRequest,
RerankResponse,
RerankResponseMeta,
RerankTokens,
)
class TogetherAIRerank(BaseLLM):
@ -65,13 +72,7 @@ class TogetherAIRerank(BaseLLM):
_json_response = response.json()
response = RerankResponse(
id=_json_response.get("id"),
results=_json_response.get("results"),
meta=_json_response.get("meta") or {},
)
return response
return TogetherAIRerankConfig()._transform_response(_json_response)
async def async_rerank( # New async method
self,
@ -97,10 +98,4 @@ class TogetherAIRerank(BaseLLM):
_json_response = response.json()
return RerankResponse(
id=_json_response.get("id"),
results=_json_response.get("results"),
meta=_json_response.get("meta") or {},
) # Return response
pass
return TogetherAIRerankConfig()._transform_response(_json_response)

View file

@ -0,0 +1,34 @@
"""
Transformation logic from Cohere's /v1/rerank format to Together AI's `/v1/rerank` format.
Why separate file? Make it easy to see how transformation works
"""
import uuid
from typing import List, Optional
from litellm.types.rerank import (
RerankBilledUnits,
RerankResponse,
RerankResponseMeta,
RerankTokens,
)
class TogetherAIRerankConfig:
def _transform_response(self, response: dict) -> RerankResponse:
_billed_units = RerankBilledUnits(**response.get("usage", {}))
_tokens = RerankTokens(**response.get("usage", {}))
rerank_meta = RerankResponseMeta(billed_units=_billed_units, tokens=_tokens)
_results: Optional[List[dict]] = response.get("results")
if _results is None:
raise ValueError(f"No results found in the response={response}")
return RerankResponse(
id=response.get("id") or str(uuid.uuid4()),
results=_results,
meta=rerank_meta,
) # Return response

View file

@ -89,6 +89,9 @@ def _get_vertex_url(
elif mode == "embedding":
endpoint = "predict"
url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}"
if model.isdigit():
# https://us-central1-aiplatform.googleapis.com/v1/projects/$PROJECT_ID/locations/us-central1/endpoints/$ENDPOINT_ID:predict
url = f"https://{vertex_location}-aiplatform.googleapis.com/{vertex_api_version}/projects/{vertex_project}/locations/{vertex_location}/endpoints/{model}:{endpoint}"
if not url or not endpoint:
raise ValueError(f"Unable to get vertex url/endpoint for mode: {mode}")

View file

@ -0,0 +1,25 @@
"""
Vertex AI Image Generation Cost Calculator
"""
from typing import Optional
import litellm
from litellm.types.utils import ImageResponse
def cost_calculator(
model: str,
image_response: ImageResponse,
) -> float:
"""
Vertex AI Image Generation Cost Calculator
"""
_model_info = litellm.get_model_info(
model=model,
custom_llm_provider="vertex_ai",
)
output_cost_per_image: float = _model_info.get("output_cost_per_image") or 0.0
num_images: int = len(image_response.data)
return output_cost_per_image * num_images
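
With the pricing entries switched to `output_cost_per_image` later in this commit (e.g. 0.04 for `vertex_ai/imagen-3.0-generate-001`), the cost is simply the per-image price times the number of images returned. A worked sketch:

```python
# illustrative arithmetic only - mirrors cost_calculator above
output_cost_per_image = 0.04  # vertex_ai/imagen-3.0-generate-001, per the pricing map in this commit
num_images = 2                # e.g. len(image_response.data)

total_cost = output_cost_per_image * num_images
print(f"${total_cost:.2f}")   # $0.08
```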

View file

@ -96,7 +96,7 @@ class VertexEmbedding(VertexBase):
headers = self.set_headers(auth_header=auth_header, extra_headers=extra_headers)
vertex_request: VertexEmbeddingRequest = (
litellm.vertexAITextEmbeddingConfig.transform_openai_request_to_vertex_embedding_request(
input=input, optional_params=optional_params
input=input, optional_params=optional_params, model=model
)
)
@ -188,7 +188,7 @@ class VertexEmbedding(VertexBase):
headers = self.set_headers(auth_header=auth_header, extra_headers=extra_headers)
vertex_request: VertexEmbeddingRequest = (
litellm.vertexAITextEmbeddingConfig.transform_openai_request_to_vertex_embedding_request(
input=input, optional_params=optional_params
input=input, optional_params=optional_params, model=model
)
)

View file

@ -101,11 +101,16 @@ class VertexAITextEmbeddingConfig(BaseModel):
return optional_params
def transform_openai_request_to_vertex_embedding_request(
self, input: Union[list, str], optional_params: dict
self, input: Union[list, str], optional_params: dict, model: str
) -> VertexEmbeddingRequest:
"""
Transforms an openai request to a vertex embedding request.
"""
if model.isdigit():
return self._transform_openai_request_to_fine_tuned_embedding_request(
input, optional_params, model
)
vertex_request: VertexEmbeddingRequest = VertexEmbeddingRequest()
vertex_text_embedding_input_list: List[TextEmbeddingInput] = []
task_type: Optional[TaskType] = optional_params.get("task_type")
@ -125,6 +130,47 @@ class VertexAITextEmbeddingConfig(BaseModel):
return vertex_request
def _transform_openai_request_to_fine_tuned_embedding_request(
self, input: Union[list, str], optional_params: dict, model: str
) -> VertexEmbeddingRequest:
"""
Transforms an openai request to a vertex fine-tuned embedding request.
Vertex Doc: https://console.cloud.google.com/vertex-ai/model-garden?hl=en&project=adroit-crow-413218&pageState=(%22galleryStateKey%22:(%22f%22:(%22g%22:%5B%5D,%22o%22:%5B%5D),%22s%22:%22%22))
Sample Request:
```json
{
"instances" : [
{
"inputs": "How would the Future of AI in 10 Years look?",
"parameters": {
"max_new_tokens": 128,
"temperature": 1.0,
"top_p": 0.9,
"top_k": 10
}
}
]
}
```
"""
vertex_request: VertexEmbeddingRequest = VertexEmbeddingRequest()
vertex_text_embedding_input_list: List[TextEmbeddingFineTunedInput] = []
if isinstance(input, str):
input = [input] # Convert single string to list for uniform processing
for text in input:
embedding_input = TextEmbeddingFineTunedInput(inputs=text)
vertex_text_embedding_input_list.append(embedding_input)
vertex_request["instances"] = vertex_text_embedding_input_list
vertex_request["parameters"] = TextEmbeddingFineTunedParameters(
**optional_params
)
return vertex_request
def create_embedding_input(
self,
content: str,
@ -157,6 +203,11 @@ class VertexAITextEmbeddingConfig(BaseModel):
"""
Transforms a vertex embedding response to an openai response.
"""
if model.isdigit():
return self._transform_vertex_response_to_openai_for_fine_tuned_models(
response, model, model_response
)
_predictions = response["predictions"]
embedding_response = []
@ -181,3 +232,35 @@ class VertexAITextEmbeddingConfig(BaseModel):
)
setattr(model_response, "usage", usage)
return model_response
def _transform_vertex_response_to_openai_for_fine_tuned_models(
self, response: dict, model: str, model_response: litellm.EmbeddingResponse
) -> litellm.EmbeddingResponse:
"""
Transforms a vertex fine-tuned model embedding response to an openai response format.
"""
_predictions = response["predictions"]
embedding_response = []
# For fine-tuned models, we don't get token counts in the response
input_tokens = 0
for idx, embedding_values in enumerate(_predictions):
embedding_response.append(
{
"object": "embedding",
"index": idx,
"embedding": embedding_values[
0
], # The embedding values are nested one level deeper
}
)
model_response.object = "list"
model_response.data = embedding_response
model_response.model = model
usage = Usage(
prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
)
setattr(model_response, "usage", usage)
return model_response
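
Since fine-tuned Vertex embedding deployments are addressed by their numeric endpoint id, a call through LiteLLM might look like the sketch below; the endpoint id, project, and location are placeholders.

```python
import litellm

# hypothetical numeric endpoint id for a fine-tuned embedding deployment
response = litellm.embedding(
    model="vertex_ai/1234567890",
    input=["How would the Future of AI in 10 Years look?"],
    vertex_project="my-project",    # placeholder
    vertex_location="us-central1",  # placeholder
)
print(len(response.data), "embeddings returned")
```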

View file

@ -23,14 +23,27 @@ class TextEmbeddingInput(TypedDict, total=False):
title: Optional[str]
# Fine-tuned models require a different input format
# Ref: https://console.cloud.google.com/vertex-ai/model-garden?hl=en&project=adroit-crow-413218&pageState=(%22galleryStateKey%22:(%22f%22:(%22g%22:%5B%5D,%22o%22:%5B%5D),%22s%22:%22%22))
class TextEmbeddingFineTunedInput(TypedDict, total=False):
inputs: str
class TextEmbeddingFineTunedParameters(TypedDict, total=False):
max_new_tokens: Optional[int]
temperature: Optional[float]
top_p: Optional[float]
top_k: Optional[int]
class EmbeddingParameters(TypedDict, total=False):
auto_truncate: Optional[bool]
output_dimensionality: Optional[int]
class VertexEmbeddingRequest(TypedDict, total=False):
instances: List[TextEmbeddingInput]
parameters: Optional[EmbeddingParameters]
instances: Union[List[TextEmbeddingInput], List[TextEmbeddingFineTunedInput]]
parameters: Optional[Union[EmbeddingParameters, TextEmbeddingFineTunedParameters]]
# Example usage:

View file

@ -1066,6 +1066,7 @@ def completion( # type: ignore # noqa: PLR0915
azure_ad_token_provider=kwargs.get("azure_ad_token_provider"),
user_continue_message=kwargs.get("user_continue_message"),
base_model=base_model,
litellm_trace_id=kwargs.get("litellm_trace_id"),
)
logging.update_environment_variables(
model=model,
@ -3455,7 +3456,7 @@ def embedding( # noqa: PLR0915
client=client,
aembedding=aembedding,
)
elif custom_llm_provider == "openai_like":
elif custom_llm_provider == "openai_like" or custom_llm_provider == "jina_ai":
api_base = (
api_base or litellm.api_base or get_secret_str("OPENAI_LIKE_API_BASE")
)

View file

@ -2986,19 +2986,19 @@
"supports_function_calling": true
},
"vertex_ai/imagegeneration@006": {
"cost_per_image": 0.020,
"output_cost_per_image": 0.020,
"litellm_provider": "vertex_ai-image-models",
"mode": "image_generation",
"source": "https://cloud.google.com/vertex-ai/generative-ai/pricing"
},
"vertex_ai/imagen-3.0-generate-001": {
"cost_per_image": 0.04,
"output_cost_per_image": 0.04,
"litellm_provider": "vertex_ai-image-models",
"mode": "image_generation",
"source": "https://cloud.google.com/vertex-ai/generative-ai/pricing"
},
"vertex_ai/imagen-3.0-fast-generate-001": {
"cost_per_image": 0.02,
"output_cost_per_image": 0.02,
"litellm_provider": "vertex_ai-image-models",
"mode": "image_generation",
"source": "https://cloud.google.com/vertex-ai/generative-ai/pricing"
@ -5620,6 +5620,13 @@
"litellm_provider": "bedrock",
"mode": "image_generation"
},
"stability.stable-image-ultra-v1:0": {
"max_tokens": 77,
"max_input_tokens": 77,
"output_cost_per_image": 0.14,
"litellm_provider": "bedrock",
"mode": "image_generation"
},
"sagemaker/meta-textgeneration-llama-2-7b": {
"max_tokens": 4096,
"max_input_tokens": 4096,

View file

@ -1,122 +1,15 @@
model_list:
- model_name: "*"
litellm_params:
model: claude-3-5-sonnet-20240620
api_key: os.environ/ANTHROPIC_API_KEY
- model_name: claude-3-5-sonnet-aihubmix
litellm_params:
model: openai/claude-3-5-sonnet-20240620
input_cost_per_token: 0.000003 # 3$/M
output_cost_per_token: 0.000015 # 15$/M
api_base: "https://exampleopenaiendpoint-production.up.railway.app"
api_key: my-fake-key
- model_name: fake-openai-endpoint-2
litellm_params:
model: openai/my-fake-model
api_key: my-fake-key
api_base: https://exampleopenaiendpoint-production.up.railway.app/
stream_timeout: 0.001
timeout: 1
rpm: 1
- model_name: fake-openai-endpoint
litellm_params:
model: openai/my-fake-model
api_key: my-fake-key
api_base: https://exampleopenaiendpoint-production.up.railway.app/
## bedrock chat completions
- model_name: "*anthropic.claude*"
litellm_params:
model: bedrock/*anthropic.claude*
aws_access_key_id: os.environ/BEDROCK_AWS_ACCESS_KEY_ID
aws_secret_access_key: os.environ/BEDROCK_AWS_SECRET_ACCESS_KEY
aws_region_name: os.environ/AWS_REGION_NAME
guardrailConfig:
"guardrailIdentifier": "h4dsqwhp6j66"
"guardrailVersion": "2"
"trace": "enabled"
## bedrock embeddings
- model_name: "*amazon.titan-embed-*"
litellm_params:
model: bedrock/amazon.titan-embed-*
aws_access_key_id: os.environ/BEDROCK_AWS_ACCESS_KEY_ID
aws_secret_access_key: os.environ/BEDROCK_AWS_SECRET_ACCESS_KEY
aws_region_name: os.environ/AWS_REGION_NAME
- model_name: "*cohere.embed-*"
litellm_params:
model: bedrock/cohere.embed-*
aws_access_key_id: os.environ/BEDROCK_AWS_ACCESS_KEY_ID
aws_secret_access_key: os.environ/BEDROCK_AWS_SECRET_ACCESS_KEY
aws_region_name: os.environ/AWS_REGION_NAME
- model_name: "bedrock/*"
litellm_params:
model: bedrock/*
aws_access_key_id: os.environ/BEDROCK_AWS_ACCESS_KEY_ID
aws_secret_access_key: os.environ/BEDROCK_AWS_SECRET_ACCESS_KEY
aws_region_name: os.environ/AWS_REGION_NAME
# GPT-4 Turbo Models
- model_name: gpt-4
litellm_params:
model: azure/chatgpt-v-2
api_base: https://openai-gpt-4-test-v-1.openai.azure.com/
api_version: "2023-05-15"
api_key: os.environ/AZURE_API_KEY # The `os.environ/` prefix tells litellm to read this from the env. See https://docs.litellm.ai/docs/simple_proxy#load-api-keys-from-vault
rpm: 480
timeout: 300
stream_timeout: 60
litellm_settings:
fallbacks: [{ "claude-3-5-sonnet-20240620": ["claude-3-5-sonnet-aihubmix"] }]
# callbacks: ["otel", "prometheus"]
default_redis_batch_cache_expiry: 10
# default_team_settings:
# - team_id: "dbe2f686-a686-4896-864a-4c3924458709"
# success_callback: ["langfuse"]
# langfuse_public_key: os.environ/LANGFUSE_PUB_KEY_1 # Project 1
# langfuse_secret: os.environ/LANGFUSE_PRIVATE_KEY_1 # Project 1
# litellm_settings:
# cache: True
# cache_params:
# type: redis
# # disable caching on the actual API call
# supported_call_types: []
# # see https://docs.litellm.ai/docs/proxy/prod#3-use-redis-porthost-password-not-redis_url
# host: os.environ/REDIS_HOST
# port: os.environ/REDIS_PORT
# password: os.environ/REDIS_PASSWORD
# # see https://docs.litellm.ai/docs/proxy/caching#turn-on-batch_redis_requests
# # see https://docs.litellm.ai/docs/proxy/prometheus
# callbacks: ['otel']
model: gpt-4
- model_name: rerank-model
litellm_params:
model: jina_ai/jina-reranker-v2-base-multilingual
# # router_settings:
# # routing_strategy: latency-based-routing
# # routing_strategy_args:
# # # only assign 40% of traffic to the fastest deployment to avoid overloading it
# # lowest_latency_buffer: 0.4
# # # consider last five minutes of calls for latency calculation
# # ttl: 300
# # redis_host: os.environ/REDIS_HOST
# # redis_port: os.environ/REDIS_PORT
# # redis_password: os.environ/REDIS_PASSWORD
# # # see https://docs.litellm.ai/docs/proxy/prod#1-use-this-configyaml
# # general_settings:
# # master_key: os.environ/LITELLM_MASTER_KEY
# # database_url: os.environ/DATABASE_URL
# # disable_master_key_return: true
# # # alerting: ['slack', 'email']
# # alerting: ['email']
# # # Batch write spend updates every 60s
# # proxy_batch_write_at: 60
# # # see https://docs.litellm.ai/docs/proxy/caching#advanced---user-api-key-cache-ttl
# # # our api keys rarely change
# # user_api_key_cache_ttl: 3600
router_settings:
model_group_alias:
"gpt-4-turbo": # Aliased model name
model: "gpt-4" # Actual model name in 'model_list'
hidden: true

View file

@ -1128,7 +1128,16 @@ class KeyManagementSystem(enum.Enum):
class KeyManagementSettings(LiteLLMBase):
hosted_keys: List
hosted_keys: Optional[List] = None
store_virtual_keys: Optional[bool] = False
"""
If True, virtual keys created by litellm will be stored in the secret manager
"""
access_mode: Literal["read_only", "write_only", "read_and_write"] = "read_only"
"""
Access mode for the secret manager; when set to "write_only", the secret manager is only used for writing secrets
"""
class TeamDefaultSettings(LiteLLMBase):

View file

@ -8,6 +8,7 @@ Run checks for:
2. If user is in budget
3. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget
"""
import time
import traceback
from datetime import datetime

View file

@ -0,0 +1,267 @@
import asyncio
import json
import uuid
from datetime import datetime, timezone
from typing import Any, List, Optional
from fastapi import status
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import (
GenerateKeyRequest,
KeyManagementSystem,
KeyRequest,
LiteLLM_AuditLogs,
LiteLLM_VerificationToken,
LitellmTableNames,
ProxyErrorTypes,
ProxyException,
UpdateKeyRequest,
UserAPIKeyAuth,
WebhookEvent,
)
class KeyManagementEventHooks:
@staticmethod
async def async_key_generated_hook(
data: GenerateKeyRequest,
response: dict,
user_api_key_dict: UserAPIKeyAuth,
litellm_changed_by: Optional[str] = None,
):
"""
Hook that runs after a successful /key/generate request
Handles the following:
- Sending Email with Key Details
- Storing Audit Logs for key generation
- Storing the generated key in the secret manager (when enabled)
"""
from litellm.proxy.management_helpers.audit_logs import (
create_audit_log_for_update,
)
from litellm.proxy.proxy_server import (
general_settings,
litellm_proxy_admin_name,
proxy_logging_obj,
)
if data.send_invite_email is True:
await KeyManagementEventHooks._send_key_created_email(response)
# Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True
if litellm.store_audit_logs is True:
_updated_values = json.dumps(response, default=str)
asyncio.create_task(
create_audit_log_for_update(
request_data=LiteLLM_AuditLogs(
id=str(uuid.uuid4()),
updated_at=datetime.now(timezone.utc),
changed_by=litellm_changed_by
or user_api_key_dict.user_id
or litellm_proxy_admin_name,
changed_by_api_key=user_api_key_dict.api_key,
table_name=LitellmTableNames.KEY_TABLE_NAME,
object_id=response.get("token_id", ""),
action="created",
updated_values=_updated_values,
before_value=None,
)
)
)
# store the generated key in the secret manager
await KeyManagementEventHooks._store_virtual_key_in_secret_manager(
secret_name=data.key_alias or f"virtual-key-{uuid.uuid4()}",
secret_token=response.get("token", ""),
)
@staticmethod
async def async_key_updated_hook(
data: UpdateKeyRequest,
existing_key_row: Any,
response: Any,
user_api_key_dict: UserAPIKeyAuth,
litellm_changed_by: Optional[str] = None,
):
"""
Post /key/update processing hook
Handles the following:
- Storing Audit Logs for key update
"""
from litellm.proxy.management_helpers.audit_logs import (
create_audit_log_for_update,
)
from litellm.proxy.proxy_server import litellm_proxy_admin_name
# Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True
if litellm.store_audit_logs is True:
_updated_values = json.dumps(data.json(exclude_none=True), default=str)
_before_value = existing_key_row.json(exclude_none=True)
_before_value = json.dumps(_before_value, default=str)
asyncio.create_task(
create_audit_log_for_update(
request_data=LiteLLM_AuditLogs(
id=str(uuid.uuid4()),
updated_at=datetime.now(timezone.utc),
changed_by=litellm_changed_by
or user_api_key_dict.user_id
or litellm_proxy_admin_name,
changed_by_api_key=user_api_key_dict.api_key,
table_name=LitellmTableNames.KEY_TABLE_NAME,
object_id=data.key,
action="updated",
updated_values=_updated_values,
before_value=_before_value,
)
)
)
pass
@staticmethod
async def async_key_deleted_hook(
data: KeyRequest,
keys_being_deleted: List[LiteLLM_VerificationToken],
response: dict,
user_api_key_dict: UserAPIKeyAuth,
litellm_changed_by: Optional[str] = None,
):
"""
Post /key/delete processing hook
Handles the following:
- Storing Audit Logs for key deletion
"""
from litellm.proxy.management_helpers.audit_logs import (
create_audit_log_for_update,
)
from litellm.proxy.proxy_server import litellm_proxy_admin_name, prisma_client
# Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True
# We do this after the first for loop, since the first loop is for validation; we only want this inserted after validation passes
if litellm.store_audit_logs is True:
# make an audit log for each team deleted
for key in data.keys:
key_row = await prisma_client.get_data( # type: ignore
token=key, table_name="key", query_type="find_unique"
)
if key_row is None:
raise ProxyException(
message=f"Key {key} not found",
type=ProxyErrorTypes.bad_request_error,
param="key",
code=status.HTTP_404_NOT_FOUND,
)
key_row = key_row.json(exclude_none=True)
_key_row = json.dumps(key_row, default=str)
asyncio.create_task(
create_audit_log_for_update(
request_data=LiteLLM_AuditLogs(
id=str(uuid.uuid4()),
updated_at=datetime.now(timezone.utc),
changed_by=litellm_changed_by
or user_api_key_dict.user_id
or litellm_proxy_admin_name,
changed_by_api_key=user_api_key_dict.api_key,
table_name=LitellmTableNames.KEY_TABLE_NAME,
object_id=key,
action="deleted",
updated_values="{}",
before_value=_key_row,
)
)
)
# delete the keys from the secret manager
await KeyManagementEventHooks._delete_virtual_keys_from_secret_manager(
keys_being_deleted=keys_being_deleted
)
pass
@staticmethod
async def _store_virtual_key_in_secret_manager(secret_name: str, secret_token: str):
"""
Store a virtual key in the secret manager
Args:
secret_name: Name of the virtual key
secret_token: Value of the virtual key (example: sk-1234)
"""
if litellm._key_management_settings is not None:
if litellm._key_management_settings.store_virtual_keys is True:
from litellm.secret_managers.aws_secret_manager_v2 import (
AWSSecretsManagerV2,
)
# store the key in the secret manager
if (
litellm._key_management_system
== KeyManagementSystem.AWS_SECRET_MANAGER
and isinstance(litellm.secret_manager_client, AWSSecretsManagerV2)
):
await litellm.secret_manager_client.async_write_secret(
secret_name=secret_name,
secret_value=secret_token,
)
@staticmethod
async def _delete_virtual_keys_from_secret_manager(
keys_being_deleted: List[LiteLLM_VerificationToken],
):
"""
Deletes virtual keys from the secret manager
Args:
keys_being_deleted: List of keys being deleted; this is passed down from the /key/delete operation
"""
if litellm._key_management_settings is not None:
if litellm._key_management_settings.store_virtual_keys is True:
from litellm.secret_managers.aws_secret_manager_v2 import (
AWSSecretsManagerV2,
)
if isinstance(litellm.secret_manager_client, AWSSecretsManagerV2):
for key in keys_being_deleted:
if key.key_alias is not None:
await litellm.secret_manager_client.async_delete_secret(
secret_name=key.key_alias
)
else:
verbose_proxy_logger.warning(
f"KeyManagementEventHooks._delete_virtual_key_from_secret_manager: Key alias not found for key {key.token}. Skipping deletion from secret manager."
)
@staticmethod
async def _send_key_created_email(response: dict):
from litellm.proxy.proxy_server import general_settings, proxy_logging_obj
if "email" not in general_settings.get("alerting", []):
raise ValueError(
"Email alerting not setup on config.yaml. Please set `alerting=['email']. \nDocs: https://docs.litellm.ai/docs/proxy/email`"
)
event = WebhookEvent(
event="key_created",
event_group="key",
event_message="API Key Created",
token=response.get("token", ""),
spend=response.get("spend", 0.0),
max_budget=response.get("max_budget", 0.0),
user_id=response.get("user_id", None),
team_id=response.get("team_id", "Default Team"),
key_alias=response.get("key_alias", None),
)
# If user configured email alerting - send an Email letting their end-user know the key was created
asyncio.create_task(
proxy_logging_obj.slack_alerting_instance.send_key_created_or_user_invited_email(
webhook_event=event,
)
)

View file

@ -274,6 +274,51 @@ class LiteLLMProxyRequestSetup:
)
return user_api_key_logged_metadata
@staticmethod
def add_key_level_controls(
key_metadata: dict, data: dict, _metadata_variable_name: str
):
data = data.copy()
if "cache" in key_metadata:
data["cache"] = {}
if isinstance(key_metadata["cache"], dict):
for k, v in key_metadata["cache"].items():
if k in SupportedCacheControls:
data["cache"][k] = v
## KEY-LEVEL SPEND LOGS / TAGS
if "tags" in key_metadata and key_metadata["tags"] is not None:
if "tags" in data[_metadata_variable_name] and isinstance(
data[_metadata_variable_name]["tags"], list
):
data[_metadata_variable_name]["tags"].extend(key_metadata["tags"])
else:
data[_metadata_variable_name]["tags"] = key_metadata["tags"]
if "spend_logs_metadata" in key_metadata and isinstance(
key_metadata["spend_logs_metadata"], dict
):
if "spend_logs_metadata" in data[_metadata_variable_name] and isinstance(
data[_metadata_variable_name]["spend_logs_metadata"], dict
):
for key, value in key_metadata["spend_logs_metadata"].items():
if (
key not in data[_metadata_variable_name]["spend_logs_metadata"]
): # don't override k-v pair sent by request (user request)
data[_metadata_variable_name]["spend_logs_metadata"][
key
] = value
else:
data[_metadata_variable_name]["spend_logs_metadata"] = key_metadata[
"spend_logs_metadata"
]
## KEY-LEVEL DISABLE FALLBACKS
if "disable_fallbacks" in key_metadata and isinstance(
key_metadata["disable_fallbacks"], bool
):
data["disable_fallbacks"] = key_metadata["disable_fallbacks"]
return data
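
The key-level controls above now include `disable_fallbacks`. A sketch of how the merge behaves for a key whose metadata carries tags and the new flag (the metadata shape is illustrative):

```python
# illustrative stand-in for LiteLLMProxyRequestSetup.add_key_level_controls
key_metadata = {"tags": ["team-a"], "disable_fallbacks": True}
data = {"metadata": {"tags": ["from-request"]}}

merged = dict(data)
# key-level tags are appended to request-level tags
merged["metadata"]["tags"].extend(key_metadata["tags"])
# key-level disable_fallbacks is surfaced at the top level of the request data
merged["disable_fallbacks"] = key_metadata["disable_fallbacks"]

print(merged)
# {'metadata': {'tags': ['from-request', 'team-a']}, 'disable_fallbacks': True}
```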
async def add_litellm_data_to_request( # noqa: PLR0915
data: dict,
@ -389,37 +434,11 @@ async def add_litellm_data_to_request( # noqa: PLR0915
### KEY-LEVEL Controls
key_metadata = user_api_key_dict.metadata
if "cache" in key_metadata:
data["cache"] = {}
if isinstance(key_metadata["cache"], dict):
for k, v in key_metadata["cache"].items():
if k in SupportedCacheControls:
data["cache"][k] = v
## KEY-LEVEL SPEND LOGS / TAGS
if "tags" in key_metadata and key_metadata["tags"] is not None:
if "tags" in data[_metadata_variable_name] and isinstance(
data[_metadata_variable_name]["tags"], list
):
data[_metadata_variable_name]["tags"].extend(key_metadata["tags"])
else:
data[_metadata_variable_name]["tags"] = key_metadata["tags"]
if "spend_logs_metadata" in key_metadata and isinstance(
key_metadata["spend_logs_metadata"], dict
):
if "spend_logs_metadata" in data[_metadata_variable_name] and isinstance(
data[_metadata_variable_name]["spend_logs_metadata"], dict
):
for key, value in key_metadata["spend_logs_metadata"].items():
if (
key not in data[_metadata_variable_name]["spend_logs_metadata"]
): # don't override k-v pair sent by request (user request)
data[_metadata_variable_name]["spend_logs_metadata"][key] = value
else:
data[_metadata_variable_name]["spend_logs_metadata"] = key_metadata[
"spend_logs_metadata"
]
data = LiteLLMProxyRequestSetup.add_key_level_controls(
key_metadata=key_metadata,
data=data,
_metadata_variable_name=_metadata_variable_name,
)
## TEAM-LEVEL SPEND LOGS/TAGS
team_metadata = user_api_key_dict.team_metadata or {}
if "tags" in team_metadata and team_metadata["tags"] is not None:

View file

@ -17,7 +17,7 @@ import secrets
import traceback
import uuid
from datetime import datetime, timedelta, timezone
from typing import List, Optional
from typing import List, Optional, Tuple
import fastapi
from fastapi import APIRouter, Depends, Header, HTTPException, Query, Request, status
@ -31,6 +31,7 @@ from litellm.proxy.auth.auth_checks import (
get_key_object,
)
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.hooks.key_management_event_hooks import KeyManagementEventHooks
from litellm.proxy.management_helpers.utils import management_endpoint_wrapper
from litellm.proxy.utils import _duration_in_seconds, _hash_token_if_needed
from litellm.secret_managers.main import get_secret
@ -234,48 +235,12 @@ async def generate_key_fn( # noqa: PLR0915
data.soft_budget
) # include the user-input soft budget in the response
if data.send_invite_email is True:
if "email" not in general_settings.get("alerting", []):
raise ValueError(
"Email alerting not setup on config.yaml. Please set `alerting=['email']. \nDocs: https://docs.litellm.ai/docs/proxy/email`"
)
event = WebhookEvent(
event="key_created",
event_group="key",
event_message="API Key Created",
token=response.get("token", ""),
spend=response.get("spend", 0.0),
max_budget=response.get("max_budget", 0.0),
user_id=response.get("user_id", None),
team_id=response.get("team_id", "Default Team"),
key_alias=response.get("key_alias", None),
)
# If user configured email alerting - send an Email letting their end-user know the key was created
asyncio.create_task(
proxy_logging_obj.slack_alerting_instance.send_key_created_or_user_invited_email(
webhook_event=event,
)
)
# Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True
if litellm.store_audit_logs is True:
_updated_values = json.dumps(response, default=str)
asyncio.create_task(
create_audit_log_for_update(
request_data=LiteLLM_AuditLogs(
id=str(uuid.uuid4()),
updated_at=datetime.now(timezone.utc),
changed_by=litellm_changed_by
or user_api_key_dict.user_id
or litellm_proxy_admin_name,
changed_by_api_key=user_api_key_dict.api_key,
table_name=LitellmTableNames.KEY_TABLE_NAME,
object_id=response.get("token_id", ""),
action="created",
updated_values=_updated_values,
before_value=None,
)
KeyManagementEventHooks.async_key_generated_hook(
data=data,
response=response,
user_api_key_dict=user_api_key_dict,
litellm_changed_by=litellm_changed_by,
)
)
@ -407,28 +372,13 @@ async def update_key_fn(
proxy_logging_obj=proxy_logging_obj,
)
# Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True
if litellm.store_audit_logs is True:
_updated_values = json.dumps(data_json, default=str)
_before_value = existing_key_row.json(exclude_none=True)
_before_value = json.dumps(_before_value, default=str)
asyncio.create_task(
create_audit_log_for_update(
request_data=LiteLLM_AuditLogs(
id=str(uuid.uuid4()),
updated_at=datetime.now(timezone.utc),
changed_by=litellm_changed_by
or user_api_key_dict.user_id
or litellm_proxy_admin_name,
changed_by_api_key=user_api_key_dict.api_key,
table_name=LitellmTableNames.KEY_TABLE_NAME,
object_id=data.key,
action="updated",
updated_values=_updated_values,
before_value=_before_value,
)
KeyManagementEventHooks.async_key_updated_hook(
data=data,
existing_key_row=existing_key_row,
response=response,
user_api_key_dict=user_api_key_dict,
litellm_changed_by=litellm_changed_by,
)
)
@ -496,6 +446,9 @@ async def delete_key_fn(
user_custom_key_generate,
)
if prisma_client is None:
raise Exception("Not connected to DB!")
keys = data.keys
if len(keys) == 0:
raise ProxyException(
@ -516,45 +469,7 @@ async def delete_key_fn(
):
user_id = None # unless they're admin
# Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True
# we do this after the first for loop, since first for loop is for validation. we only want this inserted after validation passes
if litellm.store_audit_logs is True:
# make an audit log for each team deleted
for key in data.keys:
key_row = await prisma_client.get_data( # type: ignore
token=key, table_name="key", query_type="find_unique"
)
if key_row is None:
raise ProxyException(
message=f"Key {key} not found",
type=ProxyErrorTypes.bad_request_error,
param="key",
code=status.HTTP_404_NOT_FOUND,
)
key_row = key_row.json(exclude_none=True)
_key_row = json.dumps(key_row, default=str)
asyncio.create_task(
create_audit_log_for_update(
request_data=LiteLLM_AuditLogs(
id=str(uuid.uuid4()),
updated_at=datetime.now(timezone.utc),
changed_by=litellm_changed_by
or user_api_key_dict.user_id
or litellm_proxy_admin_name,
changed_by_api_key=user_api_key_dict.api_key,
table_name=LitellmTableNames.KEY_TABLE_NAME,
object_id=key,
action="deleted",
updated_values="{}",
before_value=_key_row,
)
)
)
number_deleted_keys = await delete_verification_token(
number_deleted_keys, _keys_being_deleted = await delete_verification_token(
tokens=keys, user_id=user_id
)
if number_deleted_keys is None:
@ -588,6 +503,16 @@ async def delete_key_fn(
f"/keys/delete - cache after delete: {user_api_key_cache.in_memory_cache.cache_dict}"
)
asyncio.create_task(
KeyManagementEventHooks.async_key_deleted_hook(
data=data,
keys_being_deleted=_keys_being_deleted,
user_api_key_dict=user_api_key_dict,
litellm_changed_by=litellm_changed_by,
response=number_deleted_keys,
)
)
return {"deleted_keys": keys}
except Exception as e:
if isinstance(e, HTTPException):
@ -1026,11 +951,35 @@ async def generate_key_helper_fn( # noqa: PLR0915
return key_data
async def delete_verification_token(tokens: List, user_id: Optional[str] = None):
async def delete_verification_token(
tokens: List, user_id: Optional[str] = None
) -> Tuple[Optional[Dict], List[LiteLLM_VerificationToken]]:
"""
Helper that deletes the list of tokens from the database
Args:
tokens: List of tokens to delete
user_id: Optional user_id to filter by
Returns:
Tuple[Optional[Dict], List[LiteLLM_VerificationToken]]:
Optional[Dict]:
- Number of deleted tokens
List[LiteLLM_VerificationToken]:
- List of keys being deleted; contains the key_alias, token, and user_id of each key being deleted.
This is passed down to the KeyManagementEventHooks to delete the keys from the secret manager and handle audit logs.
"""
from litellm.proxy.proxy_server import litellm_proxy_admin_name, prisma_client
try:
if prisma_client:
tokens = [_hash_token_if_needed(token=key) for key in tokens]
_keys_being_deleted = (
await prisma_client.db.litellm_verificationtoken.find_many(
where={"token": {"in": tokens}}
)
)
# Assuming 'db' is your Prisma Client instance
# check if admin making request - don't filter by user-id
if user_id == litellm_proxy_admin_name:
@ -1060,7 +1009,7 @@ async def delete_verification_token(tokens: List, user_id: Optional[str] = None)
)
verbose_proxy_logger.debug(traceback.format_exc())
raise e
return deleted_tokens
return deleted_tokens, _keys_being_deleted
@router.post(

View file

@ -265,7 +265,6 @@ def run_server( # noqa: PLR0915
ProxyConfig,
app,
load_aws_kms,
load_aws_secret_manager,
load_from_azure_key_vault,
load_google_kms,
save_worker_config,
@ -278,7 +277,6 @@ def run_server( # noqa: PLR0915
ProxyConfig,
app,
load_aws_kms,
load_aws_secret_manager,
load_from_azure_key_vault,
load_google_kms,
save_worker_config,
@ -295,7 +293,6 @@ def run_server( # noqa: PLR0915
ProxyConfig,
app,
load_aws_kms,
load_aws_secret_manager,
load_from_azure_key_vault,
load_google_kms,
save_worker_config,
@ -559,8 +556,14 @@ def run_server( # noqa: PLR0915
key_management_system
== KeyManagementSystem.AWS_SECRET_MANAGER.value # noqa: F405
):
from litellm.secret_managers.aws_secret_manager_v2 import (
AWSSecretsManagerV2,
)
### LOAD FROM AWS SECRET MANAGER ###
load_aws_secret_manager(use_aws_secret_manager=True)
AWSSecretsManagerV2.load_aws_secret_manager(
use_aws_secret_manager=True
)
elif key_management_system == KeyManagementSystem.AWS_KMS.value:
load_aws_kms(use_aws_kms=True)
elif (

View file

@ -7,6 +7,8 @@ model_list:
litellm_settings:
callbacks: ["gcs_bucket"]
general_settings:
key_management_system: "aws_secret_manager"
key_management_settings:
store_virtual_keys: true
access_mode: "write_only"

View file

@ -245,10 +245,7 @@ from litellm.router import (
from litellm.router import ModelInfo as RouterModelInfo
from litellm.router import updateDeployment
from litellm.scheduler import DefaultPriorities, FlowItem, Scheduler
from litellm.secret_managers.aws_secret_manager import (
load_aws_kms,
load_aws_secret_manager,
)
from litellm.secret_managers.aws_secret_manager import load_aws_kms
from litellm.secret_managers.google_kms import load_google_kms
from litellm.secret_managers.main import (
get_secret,
@ -1825,8 +1822,13 @@ class ProxyConfig:
key_management_system
== KeyManagementSystem.AWS_SECRET_MANAGER.value # noqa: F405
):
### LOAD FROM AWS SECRET MANAGER ###
load_aws_secret_manager(use_aws_secret_manager=True)
from litellm.secret_managers.aws_secret_manager_v2 import (
AWSSecretsManagerV2,
)
AWSSecretsManagerV2.load_aws_secret_manager(
use_aws_secret_manager=True
)
elif key_management_system == KeyManagementSystem.AWS_KMS.value:
load_aws_kms(use_aws_kms=True)
elif (

View file

@ -8,7 +8,8 @@ from litellm._logging import verbose_logger
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.azure_ai.rerank import AzureAIRerank
from litellm.llms.cohere.rerank import CohereRerank
from litellm.llms.together_ai.rerank import TogetherAIRerank
from litellm.llms.jina_ai.rerank.handler import JinaAIRerank
from litellm.llms.together_ai.rerank.handler import TogetherAIRerank
from litellm.secret_managers.main import get_secret
from litellm.types.rerank import RerankRequest, RerankResponse
from litellm.types.router import *
@ -19,6 +20,7 @@ from litellm.utils import client, exception_type, supports_httpx_timeout
cohere_rerank = CohereRerank()
together_rerank = TogetherAIRerank()
azure_ai_rerank = AzureAIRerank()
jina_ai_rerank = JinaAIRerank()
#################################################
@ -247,7 +249,23 @@ def rerank(
api_key=api_key,
_is_async=_is_async,
)
elif _custom_llm_provider == "jina_ai":
if dynamic_api_key is None:
raise ValueError(
"Jina AI API key is required, please set 'JINA_AI_API_KEY' in your environment"
)
response = jina_ai_rerank.rerank(
model=model,
api_key=dynamic_api_key,
query=query,
documents=documents,
top_n=top_n,
rank_fields=rank_fields,
return_documents=return_documents,
max_chunks_per_doc=max_chunks_per_doc,
_is_async=_is_async,
)
else:
raise ValueError(f"Unsupported provider: {_custom_llm_provider}")

View file

@ -679,9 +679,8 @@ class Router:
kwargs["model"] = model
kwargs["messages"] = messages
kwargs["original_function"] = self._completion
kwargs.get("request_timeout", self.timeout)
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
kwargs.setdefault("metadata", {}).update({"model_group": model})
self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs)
response = self.function_with_fallbacks(**kwargs)
return response
except Exception as e:
@ -783,8 +782,7 @@ class Router:
kwargs["stream"] = stream
kwargs["original_function"] = self._acompletion
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
kwargs.setdefault("metadata", {}).update({"model_group": model})
self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs)
request_priority = kwargs.get("priority") or self.default_priority
@ -948,6 +946,17 @@ class Router:
self.fail_calls[model_name] += 1
raise e
def _update_kwargs_before_fallbacks(self, model: str, kwargs: dict) -> None:
"""
Adds/updates to kwargs:
- num_retries
- litellm_trace_id
- metadata
"""
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
kwargs.setdefault("litellm_trace_id", str(uuid.uuid4()))
kwargs.setdefault("metadata", {}).update({"model_group": model})
def _update_kwargs_with_default_litellm_params(self, kwargs: dict) -> None:
"""
Adds default litellm params to kwargs, if set.
@ -1511,9 +1520,7 @@ class Router:
kwargs["model"] = model
kwargs["file"] = file
kwargs["original_function"] = self._atranscription
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
kwargs.get("request_timeout", self.timeout)
kwargs.setdefault("metadata", {}).update({"model_group": model})
self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs)
response = await self.async_function_with_fallbacks(**kwargs)
return response
@ -1688,9 +1695,7 @@ class Router:
kwargs["model"] = model
kwargs["input"] = input
kwargs["original_function"] = self._arerank
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
kwargs.get("request_timeout", self.timeout)
kwargs.setdefault("metadata", {}).update({"model_group": model})
self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs)
response = await self.async_function_with_fallbacks(**kwargs)
@ -1839,9 +1844,7 @@ class Router:
kwargs["model"] = model
kwargs["prompt"] = prompt
kwargs["original_function"] = self._atext_completion
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
kwargs.get("request_timeout", self.timeout)
kwargs.setdefault("metadata", {}).update({"model_group": model})
self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs)
response = await self.async_function_with_fallbacks(**kwargs)
return response
@ -2112,9 +2115,7 @@ class Router:
kwargs["model"] = model
kwargs["input"] = input
kwargs["original_function"] = self._aembedding
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
kwargs.get("request_timeout", self.timeout)
kwargs.setdefault("metadata", {}).update({"model_group": model})
self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs)
response = await self.async_function_with_fallbacks(**kwargs)
return response
except Exception as e:
@ -2609,6 +2610,7 @@ class Router:
If it fails after num_retries, fall back to another model group
"""
model_group: Optional[str] = kwargs.get("model")
disable_fallbacks: Optional[bool] = kwargs.pop("disable_fallbacks", False)
fallbacks: Optional[List] = kwargs.get("fallbacks", self.fallbacks)
context_window_fallbacks: Optional[List] = kwargs.get(
"context_window_fallbacks", self.context_window_fallbacks
@ -2616,6 +2618,7 @@ class Router:
content_policy_fallbacks: Optional[List] = kwargs.get(
"content_policy_fallbacks", self.content_policy_fallbacks
)
try:
self._handle_mock_testing_fallbacks(
kwargs=kwargs,
@ -2635,7 +2638,7 @@ class Router:
original_model_group: Optional[str] = kwargs.get("model") # type: ignore
fallback_failure_exception_str = ""
if original_model_group is None:
if disable_fallbacks is True or original_model_group is None:
raise e
input_kwargs = {

View file

@ -23,28 +23,6 @@ def validate_environment():
raise ValueError("Missing required environment variable - AWS_REGION_NAME")
def load_aws_secret_manager(use_aws_secret_manager: Optional[bool]):
if use_aws_secret_manager is None or use_aws_secret_manager is False:
return
try:
import boto3
from botocore.exceptions import ClientError
validate_environment()
# Create a Secrets Manager client
session = boto3.session.Session() # type: ignore
client = session.client(
service_name="secretsmanager", region_name=os.getenv("AWS_REGION_NAME")
)
litellm.secret_manager_client = client
litellm._key_management_system = KeyManagementSystem.AWS_SECRET_MANAGER
except Exception as e:
raise e
def load_aws_kms(use_aws_kms: Optional[bool]):
if use_aws_kms is None or use_aws_kms is False:
return

View file

@ -0,0 +1,310 @@
"""
This is a file for the AWS Secret Manager Integration
Handles Async Operations for:
- Read Secret
- Write Secret
- Delete Secret
Relevant issue: https://github.com/BerriAI/litellm/issues/1883
Requires:
* `os.environ["AWS_REGION_NAME"]`
* `pip install boto3>=1.28.57`
"""
import ast
import asyncio
import base64
import json
import os
import re
import sys
from typing import Any, Dict, Optional, Union
import httpx
import litellm
from litellm._logging import verbose_logger
from litellm.llms.base_aws_llm import BaseAWSLLM
from litellm.llms.custom_httpx.http_handler import (
_get_httpx_client,
get_async_httpx_client,
)
from litellm.llms.custom_httpx.types import httpxSpecialProvider
from litellm.proxy._types import KeyManagementSystem
class AWSSecretsManagerV2(BaseAWSLLM):
@classmethod
def validate_environment(cls):
if "AWS_REGION_NAME" not in os.environ:
raise ValueError("Missing required environment variable - AWS_REGION_NAME")
@classmethod
def load_aws_secret_manager(cls, use_aws_secret_manager: Optional[bool]):
"""
Initializes AWSSecretsManagerV2, sets litellm.secret_manager_client = AWSSecretsManagerV2(), and sets litellm._key_management_system = KeyManagementSystem.AWS_SECRET_MANAGER
"""
if use_aws_secret_manager is None or use_aws_secret_manager is False:
return
try:
import boto3
cls.validate_environment()
litellm.secret_manager_client = cls()
litellm._key_management_system = KeyManagementSystem.AWS_SECRET_MANAGER
except Exception as e:
raise e
async def async_read_secret(
self,
secret_name: str,
optional_params: Optional[dict] = None,
timeout: Optional[Union[float, httpx.Timeout]] = None,
) -> Optional[str]:
"""
Async function to read a secret from AWS Secrets Manager
Returns:
Optional[str]: The secret value, or None if the secret could not be read
Raises:
ValueError: If the request times out
"""
endpoint_url, headers, body = self._prepare_request(
action="GetSecretValue",
secret_name=secret_name,
optional_params=optional_params,
)
async_client = get_async_httpx_client(
llm_provider=httpxSpecialProvider.SecretManager,
params={"timeout": timeout},
)
try:
response = await async_client.post(
url=endpoint_url, headers=headers, data=body.decode("utf-8")
)
response.raise_for_status()
return response.json()["SecretString"]
except httpx.TimeoutException:
raise ValueError("Timeout error occurred")
except Exception as e:
verbose_logger.exception(
"Error reading secret from AWS Secrets Manager: %s", str(e)
)
return None
def sync_read_secret(
self,
secret_name: str,
optional_params: Optional[dict] = None,
timeout: Optional[Union[float, httpx.Timeout]] = None,
) -> Optional[str]:
"""
Sync function to read a secret from AWS Secrets Manager
Done for backwards compatibility with existing codebase, since get_secret is a sync function
"""
# self._prepare_request uses these env vars; we cannot read them from AWS Secrets Manager, or we'd get stuck in an infinite loop
if secret_name in [
"AWS_ACCESS_KEY_ID",
"AWS_SECRET_ACCESS_KEY",
"AWS_REGION_NAME",
"AWS_REGION",
"AWS_BEDROCK_RUNTIME_ENDPOINT",
]:
return os.getenv(secret_name)
endpoint_url, headers, body = self._prepare_request(
action="GetSecretValue",
secret_name=secret_name,
optional_params=optional_params,
)
sync_client = _get_httpx_client(
params={"timeout": timeout},
)
try:
response = sync_client.post(
url=endpoint_url, headers=headers, data=body.decode("utf-8")
)
response.raise_for_status()
return response.json()["SecretString"]
except httpx.TimeoutException:
raise ValueError("Timeout error occurred")
except Exception as e:
verbose_logger.exception(
"Error reading secret from AWS Secrets Manager: %s", str(e)
)
return None
async def async_write_secret(
self,
secret_name: str,
secret_value: str,
description: Optional[str] = None,
client_request_token: Optional[str] = None,
optional_params: Optional[dict] = None,
timeout: Optional[Union[float, httpx.Timeout]] = None,
) -> dict:
"""
Async function to write a secret to AWS Secrets Manager
Args:
secret_name: Name of the secret
secret_value: Value to store (can be a JSON string)
description: Optional description for the secret
client_request_token: Optional unique identifier to ensure idempotency
optional_params: Additional AWS parameters
timeout: Request timeout
"""
import uuid
# Prepare the request data
data = {"Name": secret_name, "SecretString": secret_value}
if description:
data["Description"] = description
data["ClientRequestToken"] = str(uuid.uuid4())
endpoint_url, headers, body = self._prepare_request(
action="CreateSecret",
secret_name=secret_name,
secret_value=secret_value,
optional_params=optional_params,
request_data=data, # Pass the complete request data
)
async_client = get_async_httpx_client(
llm_provider=httpxSpecialProvider.SecretManager,
params={"timeout": timeout},
)
try:
response = await async_client.post(
url=endpoint_url, headers=headers, data=body.decode("utf-8")
)
response.raise_for_status()
return response.json()
except httpx.HTTPStatusError as err:
raise ValueError(f"HTTP error occurred: {err.response.text}")
except httpx.TimeoutException:
raise ValueError("Timeout error occurred")
async def async_delete_secret(
self,
secret_name: str,
recovery_window_in_days: Optional[int] = 7,
optional_params: Optional[dict] = None,
timeout: Optional[Union[float, httpx.Timeout]] = None,
) -> dict:
"""
Async function to delete a secret from AWS Secrets Manager
Args:
secret_name: Name of the secret to delete
recovery_window_in_days: Number of days before permanent deletion (default: 7)
optional_params: Additional AWS parameters
timeout: Request timeout
Returns:
dict: Response from AWS Secrets Manager containing deletion details
"""
# Prepare the request data
data = {
"SecretId": secret_name,
"RecoveryWindowInDays": recovery_window_in_days,
}
endpoint_url, headers, body = self._prepare_request(
action="DeleteSecret",
secret_name=secret_name,
optional_params=optional_params,
request_data=data,
)
async_client = get_async_httpx_client(
llm_provider=httpxSpecialProvider.SecretManager,
params={"timeout": timeout},
)
try:
response = await async_client.post(
url=endpoint_url, headers=headers, data=body.decode("utf-8")
)
response.raise_for_status()
return response.json()
except httpx.HTTPStatusError as err:
raise ValueError(f"HTTP error occurred: {err.response.text}")
except httpx.TimeoutException:
raise ValueError("Timeout error occurred")
def _prepare_request(
self,
action: str,  # "GetSecretValue", "CreateSecret", "DeleteSecret", or "PutSecretValue"
secret_name: str,
secret_value: Optional[str] = None,
optional_params: Optional[dict] = None,
request_data: Optional[dict] = None,
) -> tuple[str, Any, bytes]:
"""Prepare the AWS Secrets Manager request"""
try:
import boto3
from botocore.auth import SigV4Auth
from botocore.awsrequest import AWSRequest
from botocore.credentials import Credentials
except ImportError:
raise ImportError("Missing boto3 to call AWS Secrets Manager. Run 'pip install boto3'.")
optional_params = optional_params or {}
boto3_credentials_info = self._get_boto_credentials_from_optional_params(
optional_params
)
# Get endpoint
_, endpoint_url = self.get_runtime_endpoint(
api_base=None,
aws_bedrock_runtime_endpoint=boto3_credentials_info.aws_bedrock_runtime_endpoint,
aws_region_name=boto3_credentials_info.aws_region_name,
)
endpoint_url = endpoint_url.replace("bedrock-runtime", "secretsmanager")
# Use provided request_data if available, otherwise build default data
if request_data:
data = request_data
else:
data = {"SecretId": secret_name}
if secret_value and action == "PutSecretValue":
data["SecretString"] = secret_value
body = json.dumps(data).encode("utf-8")
headers = {
"Content-Type": "application/x-amz-json-1.1",
"X-Amz-Target": f"secretsmanager.{action}",
}
# Sign request
request = AWSRequest(
method="POST", url=endpoint_url, data=body, headers=headers
)
SigV4Auth(
boto3_credentials_info.credentials,
"secretsmanager",
boto3_credentials_info.aws_region_name,
).add_auth(request)
prepped = request.prepare()
return endpoint_url, prepped.headers, body
# if __name__ == "__main__":
# print("loading aws secret manager v2")
# aws_secret_manager_v2 = AWSSecretsManagerV2()
# print("writing secret to aws secret manager v2")
# asyncio.run(aws_secret_manager_v2.async_write_secret(secret_name="test_secret_3", secret_value="test_value_2"))
# print("reading secret from aws secret manager v2")

View file

@ -5,7 +5,7 @@ import json
import os
import sys
import traceback
from typing import Any, Optional, Union
from typing import TYPE_CHECKING, Any, Optional, Union
import httpx
from dotenv import load_dotenv
@ -198,7 +198,10 @@ def get_secret( # noqa: PLR0915
raise ValueError("Unsupported OIDC provider")
try:
if litellm.secret_manager_client is not None:
if (
_should_read_secret_from_secret_manager()
and litellm.secret_manager_client is not None
):
try:
client = litellm.secret_manager_client
key_manager = "local"
@ -207,7 +210,8 @@ def get_secret( # noqa: PLR0915
if key_management_settings is not None:
if (
secret_name not in key_management_settings.hosted_keys
key_management_settings.hosted_keys is not None
and secret_name not in key_management_settings.hosted_keys
): # allow user to specify which keys to check in hosted key manager
key_manager = "local"
@ -268,25 +272,13 @@ def get_secret( # noqa: PLR0915
if isinstance(secret, str):
secret = secret.strip()
elif key_manager == KeyManagementSystem.AWS_SECRET_MANAGER.value:
try:
get_secret_value_response = client.get_secret_value(
SecretId=secret_name
from litellm.secret_managers.aws_secret_manager_v2 import (
AWSSecretsManagerV2,
)
print_verbose(
f"get_secret_value_response: {get_secret_value_response}"
)
except Exception as e:
print_verbose(f"An error occurred - {str(e)}")
# For a list of exceptions thrown, see
# https://docs.aws.amazon.com/secretsmanager/latest/apireference/API_GetSecretValue.html
raise e
# assume there is 1 secret per secret_name
secret_dict = json.loads(get_secret_value_response["SecretString"])
print_verbose(f"secret_dict: {secret_dict}")
for k, v in secret_dict.items():
secret = v
print_verbose(f"secret: {secret}")
if isinstance(client, AWSSecretsManagerV2):
secret = client.sync_read_secret(secret_name=secret_name)
print_verbose(f"get_secret_value_response: {secret}")
elif key_manager == KeyManagementSystem.GOOGLE_SECRET_MANAGER.value:
try:
secret = client.get_secret_from_google_secret_manager(
@ -332,3 +324,21 @@ def get_secret( # noqa: PLR0915
return default_value
else:
raise e
def _should_read_secret_from_secret_manager() -> bool:
"""
Returns True if the secret manager should be used to read the secret, False otherwise
- If the secret manager client is not set, return False
- If the `_key_management_settings` access mode is "read_only" or "read_and_write", return True
- Otherwise, return False
"""
if litellm.secret_manager_client is not None:
if litellm._key_management_settings is not None:
if (
litellm._key_management_settings.access_mode == "read_only"
or litellm._key_management_settings.access_mode == "read_and_write"
):
return True
return False
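A short sketch of how this gate behaves, mirroring the unit tests further down in this diff (the dummy client only needs to be non-None; `KeyManagementSettings` comes from the proxy types):

```python
import litellm
from litellm.proxy._types import KeyManagementSettings
from litellm.secret_managers.main import _should_read_secret_from_secret_manager

litellm.secret_manager_client = object()  # any non-None client

litellm._key_management_settings = KeyManagementSettings(access_mode="write_only")
print(_should_read_secret_from_secret_manager())  # False -> get_secret falls back to os.environ

litellm._key_management_settings = KeyManagementSettings(access_mode="read_and_write")
print(_should_read_secret_from_secret_manager())  # True -> get_secret may query the secret manager
```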

View file

@ -7,6 +7,7 @@ https://docs.cohere.com/reference/rerank
from typing import List, Optional, Union
from pydantic import BaseModel, PrivateAttr
from typing_extensions import TypedDict
class RerankRequest(BaseModel):
@ -19,10 +20,26 @@ class RerankRequest(BaseModel):
max_chunks_per_doc: Optional[int] = None
class RerankBilledUnits(TypedDict, total=False):
search_units: int
total_tokens: int
class RerankTokens(TypedDict, total=False):
input_tokens: int
output_tokens: int
class RerankResponseMeta(TypedDict, total=False):
api_version: dict
billed_units: RerankBilledUnits
tokens: RerankTokens
class RerankResponse(BaseModel):
id: str
results: List[dict] # Contains index and relevance_score
meta: Optional[dict] = None # Contains api_version and billed_units
meta: Optional[RerankResponseMeta] = None # Contains api_version and billed_units
# Define private attributes using PrivateAttr
_hidden_params: dict = PrivateAttr(default_factory=dict)
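A brief sketch of the stricter `meta` typing, assuming these types live in `litellm.types.rerank` as in the current repo layout; the id, score, and token counts are illustrative:

```python
from litellm.types.rerank import (
    RerankBilledUnits,
    RerankResponse,
    RerankResponseMeta,
    RerankTokens,
)

meta: RerankResponseMeta = {
    "api_version": {"version": "2"},
    "billed_units": RerankBilledUnits(search_units=1),
    "tokens": RerankTokens(input_tokens=12, output_tokens=0),
}

response = RerankResponse(
    id="rerank-abc123",
    results=[{"index": 0, "relevance_score": 0.98}],
    meta=meta,
)
print(response.meta["billed_units"]["search_units"])  # 1
```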

View file

@ -150,6 +150,8 @@ class GenericLiteLLMParams(BaseModel):
max_retries: Optional[int] = None
organization: Optional[str] = None # for openai orgs
configurable_clientside_auth_params: CONFIGURABLE_CLIENTSIDE_AUTH_PARAMS = None
## LOGGING PARAMS ##
litellm_trace_id: Optional[str] = None
## UNIFIED PROJECT/REGION ##
region_name: Optional[str] = None
## VERTEX AI ##
@ -186,6 +188,8 @@ class GenericLiteLLMParams(BaseModel):
None # timeout when making stream=True calls, if str, pass in as os.environ/
),
organization: Optional[str] = None, # for openai orgs
## LOGGING PARAMS ##
litellm_trace_id: Optional[str] = None,
## UNIFIED PROJECT/REGION ##
region_name: Optional[str] = None,
## VERTEX AI ##

View file

@ -1334,6 +1334,7 @@ class ResponseFormatChunk(TypedDict, total=False):
all_litellm_params = [
"metadata",
"litellm_trace_id",
"tags",
"acompletion",
"aimg_generation",
@ -1523,6 +1524,7 @@ StandardLoggingPayloadStatus = Literal["success", "failure"]
class StandardLoggingPayload(TypedDict):
id: str
trace_id: str # Trace multiple LLM calls belonging to same overall request (e.g. fallbacks/retries)
call_type: str
response_cost: float
response_cost_failure_debug_info: Optional[

View file

@ -527,6 +527,7 @@ def function_setup( # noqa: PLR0915
messages=messages,
stream=stream,
litellm_call_id=kwargs["litellm_call_id"],
litellm_trace_id=kwargs.get("litellm_trace_id"),
function_id=function_id or "",
call_type=call_type,
start_time=start_time,
@ -2056,6 +2057,7 @@ def get_litellm_params(
azure_ad_token_provider=None,
user_continue_message=None,
base_model=None,
litellm_trace_id=None,
):
litellm_params = {
"acompletion": acompletion,
@ -2084,6 +2086,7 @@ def get_litellm_params(
"user_continue_message": user_continue_message,
"base_model": base_model
or _get_base_model_from_litellm_call_metadata(metadata=metadata),
"litellm_trace_id": litellm_trace_id,
}
return litellm_params
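These plumbing changes carry `litellm_trace_id` from the router into the standard logging payload, so retries and fallbacks of one request share a single `trace_id`. A small callback sketch that reads it, using LiteLLM's `CustomLogger` base class (field access mirrors the retry test later in this diff):

```python
import litellm
from litellm.integrations.custom_logger import CustomLogger


class TraceIdLogger(CustomLogger):
    def log_success_event(self, kwargs, response_obj, start_time, end_time):
        payload = kwargs.get("standard_logging_object") or {}
        # all retries/fallbacks of the same originating request share this id
        print("trace_id:", payload.get("trace_id"))

    def log_failure_event(self, kwargs, response_obj, start_time, end_time):
        payload = kwargs.get("standard_logging_object") or {}
        print("failed call trace_id:", payload.get("trace_id"))


litellm.callbacks = [TraceIdLogger()]
```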

View file

@ -2986,19 +2986,19 @@
"supports_function_calling": true
},
"vertex_ai/imagegeneration@006": {
"cost_per_image": 0.020,
"output_cost_per_image": 0.020,
"litellm_provider": "vertex_ai-image-models",
"mode": "image_generation",
"source": "https://cloud.google.com/vertex-ai/generative-ai/pricing"
},
"vertex_ai/imagen-3.0-generate-001": {
"cost_per_image": 0.04,
"output_cost_per_image": 0.04,
"litellm_provider": "vertex_ai-image-models",
"mode": "image_generation",
"source": "https://cloud.google.com/vertex-ai/generative-ai/pricing"
},
"vertex_ai/imagen-3.0-fast-generate-001": {
"cost_per_image": 0.02,
"output_cost_per_image": 0.02,
"litellm_provider": "vertex_ai-image-models",
"mode": "image_generation",
"source": "https://cloud.google.com/vertex-ai/generative-ai/pricing"
@ -5620,6 +5620,13 @@
"litellm_provider": "bedrock",
"mode": "image_generation"
},
"stability.stable-image-ultra-v1:0": {
"max_tokens": 77,
"max_input_tokens": 77,
"output_cost_per_image": 0.14,
"litellm_provider": "bedrock",
"mode": "image_generation"
},
"sagemaker/meta-textgeneration-llama-2-7b": {
"max_tokens": 4096,
"max_input_tokens": 4096,

View file

@ -1,6 +1,6 @@
[tool.poetry]
name = "litellm"
version = "1.52.6"
version = "1.52.9"
description = "Library to easily interface with LLM API providers"
authors = ["BerriAI"]
license = "MIT"
@ -91,7 +91,7 @@ requires = ["poetry-core", "wheel"]
build-backend = "poetry.core.masonry.api"
[tool.commitizen]
version = "1.52.6"
version = "1.52.9"
version_files = [
"pyproject.toml:^version"
]

View file

@ -13,8 +13,11 @@ sys.path.insert(
import litellm
from litellm.exceptions import BadRequestError
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.utils import CustomStreamWrapper
from litellm.utils import (
CustomStreamWrapper,
get_supported_openai_params,
get_optional_params,
)
# test_example.py
from abc import ABC, abstractmethod
@ -45,6 +48,9 @@ class BaseLLMChatTest(ABC):
)
assert response is not None
# for OpenAI the content contains the JSON schema, so we need to assert that the content is not None
assert response.choices[0].message.content is not None
def test_message_with_name(self):
base_completion_call_args = self.get_base_completion_call_args()
messages = [
@ -79,6 +85,49 @@ class BaseLLMChatTest(ABC):
print(response)
# OpenAI guarantees that the JSON schema is returned in the content
# relevant issue: https://github.com/BerriAI/litellm/issues/6741
assert response.choices[0].message.content is not None
def test_json_response_format_stream(self):
"""
Test that the JSON response format with streaming is supported by the LLM API
"""
base_completion_call_args = self.get_base_completion_call_args()
litellm.set_verbose = True
messages = [
{
"role": "system",
"content": "Your output should be a JSON object with no additional properties. ",
},
{
"role": "user",
"content": "Respond with this in json. city=San Francisco, state=CA, weather=sunny, temp=60",
},
]
response = litellm.completion(
**base_completion_call_args,
messages=messages,
response_format={"type": "json_object"},
stream=True,
)
print(response)
content = ""
for chunk in response:
content += chunk.choices[0].delta.content or ""
print("content=", content)
# OpenAI guarantees that the JSON schema is returned in the content
# relevant issue: https://github.com/BerriAI/litellm/issues/6741
# we need to assert that the JSON schema was returned in the content (for Anthropic it was previously returned as part of the tool call)
assert content is not None
assert len(content) > 0
@pytest.fixture
def pdf_messages(self):
import base64

View file

@ -0,0 +1,115 @@
import asyncio
import httpx
import json
import pytest
import sys
from typing import Any, Dict, List
from unittest.mock import MagicMock, Mock, patch
import os
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import litellm
from litellm.exceptions import BadRequestError
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.utils import (
CustomStreamWrapper,
get_supported_openai_params,
get_optional_params,
)
# test_example.py
from abc import ABC, abstractmethod
def assert_response_shape(response, custom_llm_provider):
expected_response_shape = {"id": str, "results": list, "meta": dict}
expected_results_shape = {"index": int, "relevance_score": float}
expected_meta_shape = {"api_version": dict, "billed_units": dict}
expected_api_version_shape = {"version": str}
expected_billed_units_shape = {"search_units": int}
assert isinstance(response.id, expected_response_shape["id"])
assert isinstance(response.results, expected_response_shape["results"])
for result in response.results:
assert isinstance(result["index"], expected_results_shape["index"])
assert isinstance(
result["relevance_score"], expected_results_shape["relevance_score"]
)
assert isinstance(response.meta, expected_response_shape["meta"])
if custom_llm_provider == "cohere":
assert isinstance(
response.meta["api_version"], expected_meta_shape["api_version"]
)
assert isinstance(
response.meta["api_version"]["version"],
expected_api_version_shape["version"],
)
assert isinstance(
response.meta["billed_units"], expected_meta_shape["billed_units"]
)
assert isinstance(
response.meta["billed_units"]["search_units"],
expected_billed_units_shape["search_units"],
)
class BaseLLMRerankTest(ABC):
"""
Abstract base test class that enforces a common test across all test classes.
"""
@abstractmethod
def get_base_rerank_call_args(self) -> dict:
"""Must return the base rerank call args"""
pass
@abstractmethod
def get_custom_llm_provider(self) -> litellm.LlmProviders:
"""Must return the custom llm provider"""
pass
@pytest.mark.asyncio()
@pytest.mark.parametrize("sync_mode", [True, False])
async def test_basic_rerank(self, sync_mode):
rerank_call_args = self.get_base_rerank_call_args()
custom_llm_provider = self.get_custom_llm_provider()
if sync_mode is True:
response = litellm.rerank(
**rerank_call_args,
query="hello",
documents=["hello", "world"],
top_n=3,
)
print("re rank response: ", response)
assert response.id is not None
assert response.results is not None
assert_response_shape(
response=response, custom_llm_provider=custom_llm_provider.value
)
else:
response = await litellm.arerank(
**rerank_call_args,
query="hello",
documents=["hello", "world"],
top_n=3,
)
print("async re rank response: ", response)
assert response.id is not None
assert response.results is not None
assert_response_shape(
response=response, custom_llm_provider=custom_llm_provider.value
)
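To illustrate the contract this base class enforces, a hypothetical provider-specific subclass could look like the following (Cohere is used purely as an example; the actual Jina AI subclass appears in a later file of this diff):

```python
import litellm
from base_rerank_unit_tests import BaseLLMRerankTest


class TestCohereRerank(BaseLLMRerankTest):
    def get_custom_llm_provider(self) -> litellm.LlmProviders:
        return litellm.LlmProviders.COHERE

    def get_base_rerank_call_args(self) -> dict:
        return {"model": "cohere/rerank-english-v3.0"}
```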

View file

@ -33,8 +33,10 @@ from litellm import (
)
from litellm.adapters.anthropic_adapter import anthropic_adapter
from litellm.types.llms.anthropic import AnthropicResponse
from litellm.types.utils import GenericStreamingChunk, ChatCompletionToolCallChunk
from litellm.types.llms.openai import ChatCompletionToolCallFunctionChunk
from litellm.llms.anthropic.common_utils import process_anthropic_headers
from litellm.llms.anthropic.chat.handler import AnthropicChatCompletion
from httpx import Headers
from base_llm_unit_tests import BaseLLMChatTest
@ -694,3 +696,91 @@ class TestAnthropicCompletion(BaseLLMChatTest):
assert _document_validation["type"] == "document"
assert _document_validation["source"]["media_type"] == "application/pdf"
assert _document_validation["source"]["type"] == "base64"
def test_convert_tool_response_to_message_with_values():
"""Test converting a tool response with 'values' key to a message"""
tool_calls = [
ChatCompletionToolCallChunk(
id="test_id",
type="function",
function=ChatCompletionToolCallFunctionChunk(
name="json_tool_call",
arguments='{"values": {"name": "John", "age": 30}}',
),
index=0,
)
]
message = AnthropicChatCompletion._convert_tool_response_to_message(
tool_calls=tool_calls
)
assert message is not None
assert message.content == '{"name": "John", "age": 30}'
def test_convert_tool_response_to_message_without_values():
"""
Test converting a tool response without 'values' key to a message
The Anthropic API returns the JSON schema in the tool call, while the OpenAI spec expects it in the message content. This test ensures that the tool call is converted to a message correctly.
Relevant issue: https://github.com/BerriAI/litellm/issues/6741
"""
tool_calls = [
ChatCompletionToolCallChunk(
id="test_id",
type="function",
function=ChatCompletionToolCallFunctionChunk(
name="json_tool_call", arguments='{"name": "John", "age": 30}'
),
index=0,
)
]
message = AnthropicChatCompletion._convert_tool_response_to_message(
tool_calls=tool_calls
)
assert message is not None
assert message.content == '{"name": "John", "age": 30}'
def test_convert_tool_response_to_message_invalid_json():
"""Test converting a tool response with invalid JSON"""
tool_calls = [
ChatCompletionToolCallChunk(
id="test_id",
type="function",
function=ChatCompletionToolCallFunctionChunk(
name="json_tool_call", arguments="invalid json"
),
index=0,
)
]
message = AnthropicChatCompletion._convert_tool_response_to_message(
tool_calls=tool_calls
)
assert message is not None
assert message.content == "invalid json"
def test_convert_tool_response_to_message_no_arguments():
"""Test converting a tool response with no arguments"""
tool_calls = [
ChatCompletionToolCallChunk(
id="test_id",
type="function",
function=ChatCompletionToolCallFunctionChunk(name="json_tool_call"),
index=0,
)
]
message = AnthropicChatCompletion._convert_tool_response_to_message(
tool_calls=tool_calls
)
assert message is None
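End to end, the conversion tested above means a plain `response_format` call against Anthropic now surfaces the JSON in `message.content`; a minimal sketch, assuming `ANTHROPIC_API_KEY` is set (the prompt is illustrative):

```python
import litellm

response = litellm.completion(
    model="anthropic/claude-3-5-sonnet-20240620",
    messages=[{"role": "user", "content": "Return a JSON object with keys name and age."}],
    response_format={"type": "json_object"},
)
# Anthropic answers via a `json_tool_call` tool call; LiteLLM converts it back into
# the message content, so this is a JSON string rather than None.
print(response.choices[0].message.content)
```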

View file

@ -0,0 +1,23 @@
import json
import os
import sys
from datetime import datetime
from unittest.mock import AsyncMock
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
from base_rerank_unit_tests import BaseLLMRerankTest
import litellm
class TestJinaAI(BaseLLMRerankTest):
def get_custom_llm_provider(self) -> litellm.LlmProviders:
return litellm.LlmProviders.JINA_AI
def get_base_rerank_call_args(self) -> dict:
return {
"model": "jina_ai/jina-reranker-v2-base-multilingual",
}
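For context, the same model can be exercised directly through the SDK; a minimal sketch, assuming your Jina AI API key is configured for LiteLLM:

```python
import litellm

response = litellm.rerank(
    model="jina_ai/jina-reranker-v2-base-multilingual",
    query="hello",
    documents=["hello", "world"],
    top_n=2,
)
print(response.results)  # list of {"index": ..., "relevance_score": ...}
```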

View file

@ -923,9 +923,22 @@ def test_watsonx_text_top_k():
assert optional_params["top_k"] == 10
def test_together_ai_model_params():
optional_params = get_optional_params(
model="together_ai", custom_llm_provider="together_ai", logprobs=1
)
print(optional_params)
assert optional_params["logprobs"] == 1
def test_forward_user_param():
from litellm.utils import get_supported_openai_params, get_optional_params
model = "claude-3-5-sonnet-20240620"
optional_params = get_optional_params(
model=model,
user="test_user",
custom_llm_provider="anthropic",
)
assert optional_params["metadata"]["user_id"] == "test_user"

View file

@ -16,6 +16,7 @@ import pytest
import litellm
from litellm import get_optional_params
from litellm.llms.custom_httpx.http_handler import HTTPHandler
import httpx
def test_completion_pydantic_obj_2():
@ -1317,3 +1318,39 @@ def test_image_completion_request(image_url):
mock_post.assert_called_once()
print("mock_post.call_args.kwargs['json']", mock_post.call_args.kwargs["json"])
assert mock_post.call_args.kwargs["json"] == expected_request_body
@pytest.mark.parametrize(
"model, expected_url",
[
(
"textembedding-gecko@001",
"https://us-central1-aiplatform.googleapis.com/v1/projects/project-id/locations/us-central1/publishers/google/models/textembedding-gecko@001:predict",
),
(
"123456789",
"https://us-central1-aiplatform.googleapis.com/v1/projects/project-id/locations/us-central1/endpoints/123456789:predict",
),
],
)
def test_vertex_embedding_url(model, expected_url):
"""
Test URL generation for embedding models, including numeric model IDs (fine-tuned models).
Relevant issue: https://github.com/BerriAI/litellm/issues/6482
When a fine-tuned embedding model is used, the URL is different from the standard one.
"""
from litellm.llms.vertex_ai_and_google_ai_studio.common_utils import _get_vertex_url
url, endpoint = _get_vertex_url(
mode="embedding",
model=model,
stream=False,
vertex_project="project-id",
vertex_location="us-central1",
vertex_api_version="v1",
)
assert url == expected_url
assert endpoint == "predict"

View file

@ -18,6 +18,8 @@ import json
import os
import tempfile
from unittest.mock import AsyncMock, MagicMock, patch
from respx import MockRouter
import httpx
import pytest
@ -973,6 +975,7 @@ async def test_partner_models_httpx(model, sync_mode):
data = {
"model": model,
"messages": messages,
"timeout": 10,
}
if sync_mode:
response = litellm.completion(**data)
@ -986,6 +989,8 @@ async def test_partner_models_httpx(model, sync_mode):
assert isinstance(response._hidden_params["response_cost"], float)
except litellm.RateLimitError as e:
pass
except litellm.Timeout as e:
pass
except litellm.InternalServerError as e:
pass
except Exception as e:
@ -3051,3 +3056,70 @@ def test_custom_api_base(api_base):
assert url == api_base + ":"
else:
assert url == test_endpoint
@pytest.mark.asyncio
@pytest.mark.respx
async def test_vertexai_embedding_finetuned(respx_mock: MockRouter):
"""
Tests that:
- Request URL and body are correctly formatted for Vertex AI embeddings
- Response is properly parsed into litellm's embedding response format
"""
load_vertex_ai_credentials()
litellm.set_verbose = True
# Test input
input_text = ["good morning from litellm", "this is another item"]
# Expected request/response
expected_url = "https://us-central1-aiplatform.googleapis.com/v1/projects/633608382793/locations/us-central1/endpoints/1004708436694269952:predict"
expected_request = {
"instances": [
{"inputs": "good morning from litellm"},
{"inputs": "this is another item"},
],
"parameters": {},
}
mock_response = {
"predictions": [
[[-0.000431762, -0.04416759, -0.03443353]], # Truncated embedding vector
[[-0.000431762, -0.04416759, -0.03443353]], # Truncated embedding vector
],
"deployedModelId": "2275167734310371328",
"model": "projects/633608382793/locations/us-central1/models/snowflake-arctic-embed-m-long-1731622468876",
"modelDisplayName": "snowflake-arctic-embed-m-long-1731622468876",
"modelVersionId": "1",
}
# Setup mock request
mock_request = respx_mock.post(expected_url).mock(
return_value=httpx.Response(200, json=mock_response)
)
# Make request
response = await litellm.aembedding(
vertex_project="633608382793",
model="vertex_ai/1004708436694269952",
input=input_text,
)
# Assert request was made correctly
assert mock_request.called
request_body = json.loads(mock_request.calls[0].request.content)
print("\n\nrequest_body", request_body)
print("\n\nexpected_request", expected_request)
assert request_body == expected_request
# Assert response structure
assert response is not None
assert hasattr(response, "data")
assert len(response.data) == len(input_text)
# Assert embedding structure
for embedding in response.data:
assert "embedding" in embedding
assert isinstance(embedding["embedding"], list)
assert len(embedding["embedding"]) > 0
assert all(isinstance(x, float) for x in embedding["embedding"])

View file

@ -0,0 +1,139 @@
# What is this?
# Unit tests for the AWS Secrets Manager v2 integration - write, read, and delete secrets
import asyncio
import os
import sys
import traceback
from dotenv import load_dotenv
import litellm.types
import litellm.types.utils
load_dotenv()
import io
import sys
import os
# Ensure the project root is in the Python path
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../../..")))
print("Python Path:", sys.path)
print("Current Working Directory:", os.getcwd())
from typing import Optional
from unittest.mock import MagicMock, patch
import pytest
import uuid
import json
from litellm.secret_managers.aws_secret_manager_v2 import AWSSecretsManagerV2
def check_aws_credentials():
"""Helper function to check if AWS credentials are set"""
required_vars = ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_REGION_NAME"]
missing_vars = [var for var in required_vars if not os.getenv(var)]
if missing_vars:
pytest.skip(f"Missing required AWS credentials: {', '.join(missing_vars)}")
@pytest.mark.asyncio
async def test_write_and_read_simple_secret():
"""Test writing and reading a simple string secret"""
check_aws_credentials()
secret_manager = AWSSecretsManagerV2()
test_secret_name = f"litellm_test_{uuid.uuid4().hex[:8]}"
test_secret_value = "test_value_123"
try:
# Write secret
write_response = await secret_manager.async_write_secret(
secret_name=test_secret_name,
secret_value=test_secret_value,
description="LiteLLM Test Secret",
)
print("Write Response:", write_response)
assert write_response is not None
assert "ARN" in write_response
assert "Name" in write_response
assert write_response["Name"] == test_secret_name
# Read secret back
read_value = await secret_manager.async_read_secret(
secret_name=test_secret_name
)
print("Read Value:", read_value)
assert read_value == test_secret_value
finally:
# Cleanup: Delete the secret
delete_response = await secret_manager.async_delete_secret(
secret_name=test_secret_name
)
print("Delete Response:", delete_response)
assert delete_response is not None
@pytest.mark.asyncio
async def test_write_and_read_json_secret():
"""Test writing and reading a JSON structured secret"""
check_aws_credentials()
secret_manager = AWSSecretsManagerV2()
test_secret_name = f"litellm_test_{uuid.uuid4().hex[:8]}_json"
test_secret_value = {
"api_key": "test_key",
"model": "gpt-4",
"temperature": 0.7,
"metadata": {"team": "ml", "project": "litellm"},
}
try:
# Write JSON secret
write_response = await secret_manager.async_write_secret(
secret_name=test_secret_name,
secret_value=json.dumps(test_secret_value),
description="LiteLLM JSON Test Secret",
)
print("Write Response:", write_response)
# Read and parse JSON secret
read_value = await secret_manager.async_read_secret(
secret_name=test_secret_name
)
parsed_value = json.loads(read_value)
print("Read Value:", read_value)
assert parsed_value == test_secret_value
assert parsed_value["api_key"] == "test_key"
assert parsed_value["metadata"]["team"] == "ml"
finally:
# Cleanup: Delete the secret
delete_response = await secret_manager.async_delete_secret(
secret_name=test_secret_name
)
print("Delete Response:", delete_response)
assert delete_response is not None
@pytest.mark.asyncio
async def test_read_nonexistent_secret():
"""Test reading a secret that doesn't exist"""
check_aws_credentials()
secret_manager = AWSSecretsManagerV2()
nonexistent_secret = f"litellm_nonexistent_{uuid.uuid4().hex}"
response = await secret_manager.async_read_secret(secret_name=nonexistent_secret)
assert response is None

View file

@ -10,7 +10,7 @@ import os
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
) # Adds the parent directory to the system-path
from typing import Literal
import pytest

View file

@ -1624,3 +1624,55 @@ async def test_standard_logging_payload_stream_usage(sync_mode):
print(f"standard_logging_object usage: {built_response.usage}")
except litellm.InternalServerError:
pass
def test_standard_logging_retries():
"""
Assert that all retries of a request share the same trace_id in the standard logging payload, so we know if a request was retried.
"""
from litellm.types.utils import StandardLoggingPayload
from litellm.router import Router
customHandler = CompletionCustomHandler()
litellm.callbacks = [customHandler]
router = Router(
model_list=[
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {
"model": "openai/gpt-3.5-turbo",
"api_key": "test-api-key",
},
}
]
)
with patch.object(
customHandler, "log_failure_event", new=MagicMock()
) as mock_client:
try:
router.completion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hey, how's it going?"}],
num_retries=1,
mock_response="litellm.RateLimitError",
)
except litellm.RateLimitError:
pass
assert mock_client.call_count == 2
assert (
mock_client.call_args_list[0].kwargs["kwargs"]["standard_logging_object"][
"trace_id"
]
is not None
)
assert (
mock_client.call_args_list[0].kwargs["kwargs"]["standard_logging_object"][
"trace_id"
]
== mock_client.call_args_list[1].kwargs["kwargs"][
"standard_logging_object"
]["trace_id"]
)

View file

@ -58,6 +58,7 @@ async def test_content_policy_exception_azure():
except litellm.ContentPolicyViolationError as e:
print("caught a content policy violation error! Passed")
print("exception", e)
assert e.response is not None
assert e.litellm_debug_info is not None
assert isinstance(e.litellm_debug_info, str)
assert len(e.litellm_debug_info) > 0
@ -1152,3 +1153,24 @@ async def test_exception_with_headers_httpx(
if exception_raised is False:
print(resp)
assert exception_raised
@pytest.mark.asyncio
@pytest.mark.parametrize("model", ["azure/chatgpt-v-2", "openai/gpt-3.5-turbo"])
async def test_bad_request_error_contains_httpx_response(model):
"""
Test that the BadRequestError contains the httpx response
Relevant issue: https://github.com/BerriAI/litellm/issues/6732
"""
try:
await litellm.acompletion(
model=model,
messages=[{"role": "user", "content": "Hello world"}],
bad_arg="bad_arg",
)
pytest.fail("Expected to raise BadRequestError")
except litellm.BadRequestError as e:
print("e.response", e.response)
print("vars(e.response)", vars(e.response))
assert e.response is not None

View file

@ -157,7 +157,7 @@ def test_get_llm_provider_jina_ai():
model, custom_llm_provider, dynamic_api_key, api_base = litellm.get_llm_provider(
model="jina_ai/jina-embeddings-v3",
)
assert custom_llm_provider == "openai_like"
assert custom_llm_provider == "jina_ai"
assert api_base == "https://api.jina.ai/v1"
assert model == "jina-embeddings-v3"

View file

@ -89,11 +89,16 @@ def test_get_model_info_ollama_chat():
"template": "tools",
}
),
):
) as mock_client:
info = OllamaConfig().get_model_info("mistral")
print("info", info)
assert info["supports_function_calling"] is True
info = get_model_info("ollama/mistral")
print("info", info)
assert info["supports_function_calling"] is True
mock_client.assert_called()
print(mock_client.call_args.kwargs)
assert mock_client.call_args.kwargs["json"]["name"] == "mistral"
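Outside the mocked test, the same lookup can be exercised directly; a sketch assuming a local Ollama server with the `mistral` model pulled, since `ollama/*` model info is fetched from Ollama's show endpoint:

```python
import litellm

info = litellm.get_model_info("ollama/mistral")
print(info["supports_function_calling"])  # True when the model template advertises tools
```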

View file

@ -1138,9 +1138,9 @@ async def test_router_content_policy_fallbacks(
router = Router(
model_list=[
{
"model_name": "claude-2",
"model_name": "claude-2.1",
"litellm_params": {
"model": "claude-2",
"model": "claude-2.1",
"api_key": "",
"mock_response": mock_response,
},
@ -1164,7 +1164,7 @@ async def test_router_content_policy_fallbacks(
{
"model_name": "my-general-model",
"litellm_params": {
"model": "claude-2",
"model": "claude-2.1",
"api_key": "",
"mock_response": Exception("Should not have called this."),
},
@ -1172,14 +1172,14 @@ async def test_router_content_policy_fallbacks(
{
"model_name": "my-context-window-model",
"litellm_params": {
"model": "claude-2",
"model": "claude-2.1",
"api_key": "",
"mock_response": Exception("Should not have called this."),
},
},
],
content_policy_fallbacks=(
[{"claude-2": ["my-fallback-model"]}]
[{"claude-2.1": ["my-fallback-model"]}]
if fallback_type == "model-specific"
else None
),
@ -1190,12 +1190,12 @@ async def test_router_content_policy_fallbacks(
if sync_mode is True:
response = router.completion(
model="claude-2",
model="claude-2.1",
messages=[{"role": "user", "content": "Hey, how's it going?"}],
)
else:
response = await router.acompletion(
model="claude-2",
model="claude-2.1",
messages=[{"role": "user", "content": "Hey, how's it going?"}],
)
@ -1455,3 +1455,46 @@ async def test_router_fallbacks_default_and_model_specific_fallbacks(sync_mode):
assert isinstance(
exc_info.value, litellm.AuthenticationError
), f"Expected AuthenticationError, but got {type(exc_info.value).__name__}"
@pytest.mark.asyncio
async def test_router_disable_fallbacks_dynamically():
from litellm.router import run_async_fallback
router = Router(
model_list=[
{
"model_name": "bad-model",
"litellm_params": {
"model": "openai/my-bad-model",
"api_key": "my-bad-api-key",
},
},
{
"model_name": "good-model",
"litellm_params": {
"model": "gpt-4o",
"api_key": os.getenv("OPENAI_API_KEY"),
},
},
],
fallbacks=[{"bad-model": ["good-model"]}],
default_fallbacks=["good-model"],
)
with patch.object(
router,
"log_retry",
new=MagicMock(return_value=None),
) as mock_client:
try:
resp = await router.acompletion(
model="bad-model",
messages=[{"role": "user", "content": "Hey, how's it going?"}],
disable_fallbacks=True,
)
print(resp)
except Exception as e:
print(e)
mock_client.assert_not_called()

View file

@ -14,6 +14,7 @@ from litellm.router import Deployment, LiteLLM_Params, ModelInfo
from concurrent.futures import ThreadPoolExecutor
from collections import defaultdict
from dotenv import load_dotenv
from unittest.mock import patch, MagicMock, AsyncMock
load_dotenv()
@ -83,3 +84,93 @@ def test_returned_settings():
except Exception:
print(traceback.format_exc())
pytest.fail("An error occurred - " + traceback.format_exc())
from litellm.types.utils import CallTypes
def test_update_kwargs_before_fallbacks_unit_test():
router = Router(
model_list=[
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {
"model": "azure/chatgpt-v-2",
"api_key": "bad-key",
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE"),
},
}
],
)
kwargs = {"messages": [{"role": "user", "content": "write 1 sentence poem"}]}
router._update_kwargs_before_fallbacks(
model="gpt-3.5-turbo",
kwargs=kwargs,
)
assert kwargs["litellm_trace_id"] is not None
@pytest.mark.parametrize(
"call_type",
[
CallTypes.acompletion,
CallTypes.atext_completion,
CallTypes.aembedding,
CallTypes.arerank,
CallTypes.atranscription,
],
)
@pytest.mark.asyncio
async def test_update_kwargs_before_fallbacks(call_type):
router = Router(
model_list=[
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {
"model": "azure/chatgpt-v-2",
"api_key": "bad-key",
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE"),
},
}
],
)
if call_type.value.startswith("a"):
with patch.object(router, "async_function_with_fallbacks") as mock_client:
if call_type.value == "acompletion":
input_kwarg = {
"messages": [{"role": "user", "content": "Hello, how are you?"}],
}
elif (
call_type.value == "atext_completion"
or call_type.value == "aimage_generation"
):
input_kwarg = {
"prompt": "Hello, how are you?",
}
elif call_type.value == "aembedding" or call_type.value == "arerank":
input_kwarg = {
"input": "Hello, how are you?",
}
elif call_type.value == "atranscription":
input_kwarg = {
"file": "path/to/file",
}
else:
input_kwarg = {}
await getattr(router, call_type.value)(
model="gpt-3.5-turbo",
**input_kwarg,
)
mock_client.assert_called_once()
print(mock_client.call_args.kwargs)
assert mock_client.call_args.kwargs["litellm_trace_id"] is not None

View file

@ -15,22 +15,29 @@ sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import pytest
import litellm
from litellm.llms.AzureOpenAI.azure import get_azure_ad_token_from_oidc
from litellm.llms.bedrock.chat import BedrockConverseLLM, BedrockLLM
from litellm.secret_managers.aws_secret_manager import load_aws_secret_manager
from litellm.secret_managers.main import get_secret
from litellm.secret_managers.aws_secret_manager_v2 import AWSSecretsManagerV2
from litellm.secret_managers.main import (
get_secret,
_should_read_secret_from_secret_manager,
)
@pytest.mark.skip(reason="AWS Suspended Account")
def test_aws_secret_manager():
load_aws_secret_manager(use_aws_secret_manager=True)
import json
AWSSecretsManagerV2.load_aws_secret_manager(use_aws_secret_manager=True)
secret_val = get_secret("litellm_master_key")
print(f"secret_val: {secret_val}")
assert secret_val == "sk-1234"
# cast json to dict
secret_val = json.loads(secret_val)
assert secret_val["litellm_master_key"] == "sk-1234"
def redact_oidc_signature(secret_val):
@ -240,3 +247,71 @@ def test_google_secret_manager_read_in_memory():
)
print("secret_val: {}".format(secret_val))
assert secret_val == "lite-llm"
def test_should_read_secret_from_secret_manager():
"""
Test that _should_read_secret_from_secret_manager returns correct values based on access mode
"""
from litellm.proxy._types import KeyManagementSettings
# Test when secret manager client is None
litellm.secret_manager_client = None
litellm._key_management_settings = KeyManagementSettings()
assert _should_read_secret_from_secret_manager() is False
# Test with secret manager client and read_only access
litellm.secret_manager_client = "dummy_client"
litellm._key_management_settings = KeyManagementSettings(access_mode="read_only")
assert _should_read_secret_from_secret_manager() is True
# Test with secret manager client and read_and_write access
litellm._key_management_settings = KeyManagementSettings(
access_mode="read_and_write"
)
assert _should_read_secret_from_secret_manager() is True
# Test with secret manager client and write_only access
litellm._key_management_settings = KeyManagementSettings(access_mode="write_only")
assert _should_read_secret_from_secret_manager() is False
# Reset global variables
litellm.secret_manager_client = None
litellm._key_management_settings = KeyManagementSettings()
def test_get_secret_with_access_mode():
"""
Test that get_secret respects access mode settings
"""
from litellm.proxy._types import KeyManagementSettings
# Set up test environment
test_secret_name = "TEST_SECRET_KEY"
test_secret_value = "test_secret_value"
os.environ[test_secret_name] = test_secret_value
# Test with write_only access (should read from os.environ)
litellm.secret_manager_client = "dummy_client"
litellm._key_management_settings = KeyManagementSettings(access_mode="write_only")
assert get_secret(test_secret_name) == test_secret_value
# Test with no KeyManagementSettings but secret_manager_client set
litellm.secret_manager_client = "dummy_client"
litellm._key_management_settings = KeyManagementSettings()
assert _should_read_secret_from_secret_manager() is True
# Test with read_only access
litellm._key_management_settings = KeyManagementSettings(access_mode="read_only")
assert _should_read_secret_from_secret_manager() is True
# Test with read_and_write access
litellm._key_management_settings = KeyManagementSettings(
access_mode="read_and_write"
)
assert _should_read_secret_from_secret_manager() is True
# Reset global variables
litellm.secret_manager_client = None
litellm._key_management_settings = KeyManagementSettings()
del os.environ[test_secret_name]

View file

@ -184,12 +184,11 @@ def test_stream_chunk_builder_litellm_usage_chunks():
{"role": "assistant", "content": "uhhhh\n\n\nhmmmm.....\nthinking....\n"},
{"role": "user", "content": "\nI am waiting...\n\n...\n"},
]
# make a regular gemini call
usage: litellm.Usage = Usage(
completion_tokens=64,
completion_tokens=27,
prompt_tokens=55,
total_tokens=119,
total_tokens=82,
completion_tokens_details=None,
prompt_tokens_details=None,
)

View file

@ -718,7 +718,7 @@ async def test_acompletion_claude_2_stream():
try:
litellm.set_verbose = True
response = await litellm.acompletion(
model="claude-2",
model="claude-2.1",
messages=[{"role": "user", "content": "hello from litellm"}],
stream=True,
)
@ -3274,7 +3274,7 @@ def test_completion_claude_3_function_call_with_streaming():
], # "claude-3-opus-20240229"
) #
@pytest.mark.asyncio
async def test_acompletion_claude_3_function_call_with_streaming(model):
async def test_acompletion_function_call_with_streaming(model):
litellm.set_verbose = True
tools = [
{
@ -3335,6 +3335,8 @@ async def test_acompletion_claude_3_function_call_with_streaming(model):
# raise Exception("it worked! ")
except litellm.InternalServerError as e:
pytest.skip(f"InternalServerError - {str(e)}")
except litellm.ServiceUnavailableError:
pass
except Exception as e:
pytest.fail(f"Error occurred: {e}")

View file

@ -3451,3 +3451,90 @@ async def test_user_api_key_auth_db_unavailable_not_allowed():
request=request,
api_key="Bearer sk-123456789",
)
## E2E Virtual Key + Secret Manager Tests #########################################
@pytest.mark.asyncio
async def test_key_generate_with_secret_manager_call(prisma_client):
"""
Generate a key
assert it exists in the secret manager
delete the key
assert it is deleted from the secret manager
"""
from litellm.secret_managers.aws_secret_manager_v2 import AWSSecretsManagerV2
from litellm.proxy._types import KeyManagementSystem, KeyManagementSettings
litellm.set_verbose = True
#### Test Setup ############################################################
aws_secret_manager_client = AWSSecretsManagerV2()
litellm.secret_manager_client = aws_secret_manager_client
litellm._key_management_system = KeyManagementSystem.AWS_SECRET_MANAGER
litellm._key_management_settings = KeyManagementSettings(
store_virtual_keys=True,
)
general_settings = {
"key_management_system": "aws_secret_manager",
"key_management_settings": {
"store_virtual_keys": True,
},
}
setattr(litellm.proxy.proxy_server, "general_settings", general_settings)
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
await litellm.proxy.proxy_server.prisma_client.connect()
############################################################################
# generate new key
key_alias = f"test_alias_secret_manager_key-{uuid.uuid4()}"
spend = 100
max_budget = 400
models = ["fake-openai-endpoint"]
new_key = await generate_key_fn(
data=GenerateKeyRequest(
key_alias=key_alias, spend=spend, max_budget=max_budget, models=models
),
user_api_key_dict=UserAPIKeyAuth(
user_role=LitellmUserRoles.PROXY_ADMIN,
api_key="sk-1234",
user_id="1234",
),
)
generated_key = new_key.key
print(generated_key)
await asyncio.sleep(2)
# read from the secret manager
result = await aws_secret_manager_client.async_read_secret(secret_name=key_alias)
# Assert the correct key is stored in the secret manager
print("response from AWS Secret Manager")
print(result)
assert result == generated_key
# delete the key
await delete_key_fn(
data=KeyRequest(keys=[generated_key]),
user_api_key_dict=UserAPIKeyAuth(
user_role=LitellmUserRoles.PROXY_ADMIN, api_key="sk-1234", user_id="1234"
),
)
await asyncio.sleep(2)
# Assert the key is deleted from the secret manager
result = await aws_secret_manager_client.async_read_secret(secret_name=key_alias)
assert result is None
# cleanup
setattr(litellm.proxy.proxy_server, "general_settings", {})
################################################################################

View file

@ -1500,6 +1500,31 @@ async def test_add_callback_via_key_litellm_pre_call_utils(
assert new_data["failure_callback"] == expected_failure_callbacks
@pytest.mark.asyncio
@pytest.mark.parametrize(
"disable_fallbacks_set",
[
True,
False,
],
)
async def test_disable_fallbacks_by_key(disable_fallbacks_set):
from litellm.proxy.litellm_pre_call_utils import LiteLLMProxyRequestSetup
key_metadata = {"disable_fallbacks": disable_fallbacks_set}
existing_data = {
"model": "azure/chatgpt-v-2",
"messages": [{"role": "user", "content": "write 1 sentence poem"}],
}
data = LiteLLMProxyRequestSetup.add_key_level_controls(
key_metadata=key_metadata,
data=existing_data,
_metadata_variable_name="metadata",
)
assert data["disable_fallbacks"] == disable_fallbacks_set
@pytest.mark.asyncio
@pytest.mark.parametrize(
"callback_type, expected_success_callbacks, expected_failure_callbacks",