forked from phoenix/litellm-mirror
Merge branch 'main' into litellm_dev_11_13_2024
This commit is contained in:
commit
1dcbfda202
76 changed files with 2836 additions and 560 deletions
30
README.md
30
README.md
|
@ -305,6 +305,36 @@ Step 4: Submit a PR with your changes! 🚀
|
||||||
- push your fork to your GitHub repo
|
- push your fork to your GitHub repo
|
||||||
- submit a PR from there
|
- submit a PR from there
|
||||||
|
|
||||||
|
### Building LiteLLM Docker Image
|
||||||
|
|
||||||
|
Follow these instructions if you want to build / run the LiteLLM Docker Image yourself.
|
||||||
|
|
||||||
|
Step 1: Clone the repo
|
||||||
|
|
||||||
|
```
|
||||||
|
git clone https://github.com/BerriAI/litellm.git
|
||||||
|
```
|
||||||
|
|
||||||
|
Step 2: Build the Docker Image
|
||||||
|
|
||||||
|
Build using Dockerfile.non_root
|
||||||
|
```
|
||||||
|
docker build -f docker/Dockerfile.non_root -t litellm_test_image .
|
||||||
|
```
|
||||||
|
|
||||||
|
Step 3: Run the Docker Image
|
||||||
|
|
||||||
|
Make sure config.yaml is present in the root directory. This is your litellm proxy config file.
|
||||||
|
```
|
||||||
|
docker run \
|
||||||
|
-v $(pwd)/proxy_config.yaml:/app/config.yaml \
|
||||||
|
-e DATABASE_URL="postgresql://xxxxxxxx" \
|
||||||
|
-e LITELLM_MASTER_KEY="sk-1234" \
|
||||||
|
-p 4000:4000 \
|
||||||
|
litellm_test_image \
|
||||||
|
--config /app/config.yaml --detailed_debug
|
||||||
|
```
|
||||||
|
|
||||||
# Enterprise
|
# Enterprise
|
||||||
For companies that need better security, user management and professional support
|
For companies that need better security, user management and professional support
|
||||||
|
|
||||||
|
|
|
@ -13,18 +13,18 @@ spec:
|
||||||
spec:
|
spec:
|
||||||
containers:
|
containers:
|
||||||
- name: prisma-migrations
|
- name: prisma-migrations
|
||||||
image: "ghcr.io/berriai/litellm:main-stable"
|
image: ghcr.io/berriai/litellm-database:main-latest
|
||||||
command: ["python", "litellm/proxy/prisma_migration.py"]
|
command: ["python", "litellm/proxy/prisma_migration.py"]
|
||||||
workingDir: "/app"
|
workingDir: "/app"
|
||||||
env:
|
env:
|
||||||
{{- if .Values.db.deployStandalone }}
|
{{- if .Values.db.useExisting }}
|
||||||
- name: DATABASE_URL
|
|
||||||
value: postgresql://{{ .Values.postgresql.auth.username }}:{{ .Values.postgresql.auth.password }}@{{ .Release.Name }}-postgresql/{{ .Values.postgresql.auth.database }}
|
|
||||||
{{- else if .Values.db.useExisting }}
|
|
||||||
- name: DATABASE_URL
|
- name: DATABASE_URL
|
||||||
value: {{ .Values.db.url | quote }}
|
value: {{ .Values.db.url | quote }}
|
||||||
|
{{- else }}
|
||||||
|
- name: DATABASE_URL
|
||||||
|
value: postgresql://{{ .Values.postgresql.auth.username }}:{{ .Values.postgresql.auth.password }}@{{ .Release.Name }}-postgresql/{{ .Values.postgresql.auth.database }}
|
||||||
{{- end }}
|
{{- end }}
|
||||||
- name: DISABLE_SCHEMA_UPDATE
|
- name: DISABLE_SCHEMA_UPDATE
|
||||||
value: "{{ .Values.migrationJob.disableSchemaUpdate }}"
|
value: "false" # always run the migration from the Helm PreSync hook, override the value set
|
||||||
restartPolicy: OnFailure
|
restartPolicy: OnFailure
|
||||||
backoffLimit: {{ .Values.migrationJob.backoffLimit }}
|
backoffLimit: {{ .Values.migrationJob.backoffLimit }}
|
||||||
|
|
|
@ -75,6 +75,7 @@ Works for:
|
||||||
- Google AI Studio - Gemini models
|
- Google AI Studio - Gemini models
|
||||||
- Vertex AI models (Gemini + Anthropic)
|
- Vertex AI models (Gemini + Anthropic)
|
||||||
- Bedrock Models
|
- Bedrock Models
|
||||||
|
- Anthropic API Models
|
||||||
|
|
||||||
<Tabs>
|
<Tabs>
|
||||||
<TabItem value="sdk" label="SDK">
|
<TabItem value="sdk" label="SDK">
|
||||||
|
|
|
@ -93,7 +93,7 @@ curl http://0.0.0.0:4000/v1/chat/completions \
|
||||||
|
|
||||||
## Check Model Support
|
## Check Model Support
|
||||||
|
|
||||||
Call `litellm.get_model_info` to check if a model/provider supports `response_format`.
|
Call `litellm.get_model_info` to check if a model/provider supports `prefix`.
|
||||||
|
|
||||||
<Tabs>
|
<Tabs>
|
||||||
<TabItem value="sdk" label="SDK">
|
<TabItem value="sdk" label="SDK">
|
||||||
|
|
|
@ -957,3 +957,69 @@ curl http://0.0.0.0:4000/v1/chat/completions \
|
||||||
```
|
```
|
||||||
</TabItem>
|
</TabItem>
|
||||||
</Tabs>
|
</Tabs>
|
||||||
|
|
||||||
|
## Usage - passing 'user_id' to Anthropic
|
||||||
|
|
||||||
|
LiteLLM translates the OpenAI `user` param to Anthropic's `metadata[user_id]` param.
|
||||||
|
|
||||||
|
<Tabs>
|
||||||
|
<TabItem value="sdk" label="SDK">
|
||||||
|
|
||||||
|
```python
|
||||||
|
response = completion(
|
||||||
|
model="claude-3-5-sonnet-20240620",
|
||||||
|
messages=messages,
|
||||||
|
user="user_123",
|
||||||
|
)
|
||||||
|
```
|
||||||
|
</TabItem>
|
||||||
|
<TabItem value="proxy" label="PROXY">
|
||||||
|
|
||||||
|
1. Setup config.yaml
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
model_list:
|
||||||
|
- model_name: claude-3-5-sonnet-20240620
|
||||||
|
litellm_params:
|
||||||
|
model: anthropic/claude-3-5-sonnet-20240620
|
||||||
|
api_key: os.environ/ANTHROPIC_API_KEY
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Start Proxy
|
||||||
|
|
||||||
|
```
|
||||||
|
litellm --config /path/to/config.yaml
|
||||||
|
```
|
||||||
|
|
||||||
|
3. Test it!
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl http://0.0.0.0:4000/v1/chat/completions \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-H "Authorization: Bearer <YOUR-LITELLM-KEY>" \
|
||||||
|
-d '{
|
||||||
|
"model": "claude-3-5-sonnet-20240620",
|
||||||
|
"messages": [{"role": "user", "content": "What is Anthropic?"}],
|
||||||
|
"user": "user_123"
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
</Tabs>
|
||||||
|
|
||||||
|
## All Supported OpenAI Params
|
||||||
|
|
||||||
|
```
|
||||||
|
"stream",
|
||||||
|
"stop",
|
||||||
|
"temperature",
|
||||||
|
"top_p",
|
||||||
|
"max_tokens",
|
||||||
|
"max_completion_tokens",
|
||||||
|
"tools",
|
||||||
|
"tool_choice",
|
||||||
|
"extra_headers",
|
||||||
|
"parallel_tool_calls",
|
||||||
|
"response_format",
|
||||||
|
"user"
|
||||||
|
```
|
|
@ -37,7 +37,7 @@ os.environ["HUGGINGFACE_API_KEY"] = "huggingface_api_key"
|
||||||
messages = [{ "content": "There's a llama in my garden 😱 What should I do?","role": "user"}]
|
messages = [{ "content": "There's a llama in my garden 😱 What should I do?","role": "user"}]
|
||||||
|
|
||||||
# e.g. Call 'https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct' from Serverless Inference API
|
# e.g. Call 'https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct' from Serverless Inference API
|
||||||
response = litellm.completion(
|
response = completion(
|
||||||
model="huggingface/meta-llama/Meta-Llama-3.1-8B-Instruct",
|
model="huggingface/meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||||
messages=[{ "content": "Hello, how are you?","role": "user"}],
|
messages=[{ "content": "Hello, how are you?","role": "user"}],
|
||||||
stream=True
|
stream=True
|
||||||
|
@ -165,14 +165,14 @@ Steps to use
|
||||||
|
|
||||||
```python
|
```python
|
||||||
import os
|
import os
|
||||||
import litellm
|
from litellm import completion
|
||||||
|
|
||||||
os.environ["HUGGINGFACE_API_KEY"] = ""
|
os.environ["HUGGINGFACE_API_KEY"] = ""
|
||||||
|
|
||||||
# TGI model: Call https://huggingface.co/glaiveai/glaive-coder-7b
|
# TGI model: Call https://huggingface.co/glaiveai/glaive-coder-7b
|
||||||
# add the 'huggingface/' prefix to the model to set huggingface as the provider
|
# add the 'huggingface/' prefix to the model to set huggingface as the provider
|
||||||
# set api base to your deployed api endpoint from hugging face
|
# set api base to your deployed api endpoint from hugging face
|
||||||
response = litellm.completion(
|
response = completion(
|
||||||
model="huggingface/glaiveai/glaive-coder-7b",
|
model="huggingface/glaiveai/glaive-coder-7b",
|
||||||
messages=[{ "content": "Hello, how are you?","role": "user"}],
|
messages=[{ "content": "Hello, how are you?","role": "user"}],
|
||||||
api_base="https://wjiegasee9bmqke2.us-east-1.aws.endpoints.huggingface.cloud"
|
api_base="https://wjiegasee9bmqke2.us-east-1.aws.endpoints.huggingface.cloud"
|
||||||
|
@ -383,6 +383,8 @@ def default_pt(messages):
|
||||||
#### Custom prompt templates
|
#### Custom prompt templates
|
||||||
|
|
||||||
```python
|
```python
|
||||||
|
import litellm
|
||||||
|
|
||||||
# Create your own custom prompt template works
|
# Create your own custom prompt template works
|
||||||
litellm.register_prompt_template(
|
litellm.register_prompt_template(
|
||||||
model="togethercomputer/LLaMA-2-7B-32K",
|
model="togethercomputer/LLaMA-2-7B-32K",
|
||||||
|
|
|
@ -1,6 +1,13 @@
|
||||||
|
import Tabs from '@theme/Tabs';
|
||||||
|
import TabItem from '@theme/TabItem';
|
||||||
|
|
||||||
# Jina AI
|
# Jina AI
|
||||||
https://jina.ai/embeddings/
|
https://jina.ai/embeddings/
|
||||||
|
|
||||||
|
Supported endpoints:
|
||||||
|
- /embeddings
|
||||||
|
- /rerank
|
||||||
|
|
||||||
## API Key
|
## API Key
|
||||||
```python
|
```python
|
||||||
# env variable
|
# env variable
|
||||||
|
@ -8,6 +15,10 @@ os.environ['JINA_AI_API_KEY']
|
||||||
```
|
```
|
||||||
|
|
||||||
## Sample Usage - Embedding
|
## Sample Usage - Embedding
|
||||||
|
|
||||||
|
<Tabs>
|
||||||
|
<TabItem value="sdk" label="SDK">
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from litellm import embedding
|
from litellm import embedding
|
||||||
import os
|
import os
|
||||||
|
@ -19,6 +30,142 @@ response = embedding(
|
||||||
)
|
)
|
||||||
print(response)
|
print(response)
|
||||||
```
|
```
|
||||||
|
</TabItem>
|
||||||
|
<TabItem value="proxy" label="PROXY">
|
||||||
|
|
||||||
|
1. Add to config.yaml
|
||||||
|
```yaml
|
||||||
|
model_list:
|
||||||
|
- model_name: embedding-model
|
||||||
|
litellm_params:
|
||||||
|
model: jina_ai/jina-embeddings-v3
|
||||||
|
api_key: os.environ/JINA_AI_API_KEY
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Start proxy
|
||||||
|
|
||||||
|
```bash
|
||||||
|
litellm --config /path/to/config.yaml
|
||||||
|
|
||||||
|
# RUNNING on http://0.0.0.0:4000/
|
||||||
|
```
|
||||||
|
|
||||||
|
3. Test it!
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl -L -X POST 'http://0.0.0.0:4000/embeddings' \
|
||||||
|
-H 'Authorization: Bearer sk-1234' \
|
||||||
|
-H 'Content-Type: application/json' \
|
||||||
|
-d '{"input": ["hello world"], "model": "embedding-model"}'
|
||||||
|
```
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
</Tabs>
|
||||||
|
|
||||||
|
## Sample Usage - Rerank
|
||||||
|
|
||||||
|
<Tabs>
|
||||||
|
<TabItem value="sdk" label="SDK">
|
||||||
|
|
||||||
|
```python
|
||||||
|
from litellm import rerank
|
||||||
|
import os
|
||||||
|
|
||||||
|
os.environ["JINA_AI_API_KEY"] = "sk-..."
|
||||||
|
|
||||||
|
query = "What is the capital of the United States?"
|
||||||
|
documents = [
|
||||||
|
"Carson City is the capital city of the American state of Nevada.",
|
||||||
|
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.",
|
||||||
|
"Washington, D.C. is the capital of the United States.",
|
||||||
|
"Capital punishment has existed in the United States since before it was a country.",
|
||||||
|
]
|
||||||
|
|
||||||
|
response = rerank(
|
||||||
|
model="jina_ai/jina-reranker-v2-base-multilingual",
|
||||||
|
query=query,
|
||||||
|
documents=documents,
|
||||||
|
top_n=3,
|
||||||
|
)
|
||||||
|
print(response)
|
||||||
|
```
|
||||||
|
</TabItem>
|
||||||
|
<TabItem value="proxy" label="PROXY">
|
||||||
|
|
||||||
|
1. Add to config.yaml
|
||||||
|
```yaml
|
||||||
|
model_list:
|
||||||
|
- model_name: rerank-model
|
||||||
|
litellm_params:
|
||||||
|
model: jina_ai/jina-reranker-v2-base-multilingual
|
||||||
|
api_key: os.environ/JINA_AI_API_KEY
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Start proxy
|
||||||
|
|
||||||
|
```bash
|
||||||
|
litellm --config /path/to/config.yaml
|
||||||
|
```
|
||||||
|
|
||||||
|
3. Test it!
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl -L -X POST 'http://0.0.0.0:4000/rerank' \
|
||||||
|
-H 'Authorization: Bearer sk-1234' \
|
||||||
|
-H 'Content-Type: application/json' \
|
||||||
|
-d '{
|
||||||
|
"model": "rerank-model",
|
||||||
|
"query": "What is the capital of the United States?",
|
||||||
|
"documents": [
|
||||||
|
"Carson City is the capital city of the American state of Nevada.",
|
||||||
|
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.",
|
||||||
|
"Washington, D.C. is the capital of the United States.",
|
||||||
|
"Capital punishment has existed in the United States since before it was a country."
|
||||||
|
],
|
||||||
|
"top_n": 3
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
</Tabs>
|
||||||
|
|
||||||
## Supported Models
|
## Supported Models
|
||||||
All models listed here https://jina.ai/embeddings/ are supported
|
All models listed here https://jina.ai/embeddings/ are supported
|
||||||
|
|
||||||
|
## Supported Optional Rerank Parameters
|
||||||
|
|
||||||
|
All cohere rerank parameters are supported.
|
||||||
|
|
||||||
|
## Supported Optional Embeddings Parameters
|
||||||
|
|
||||||
|
```
|
||||||
|
dimensions
|
||||||
|
```
|
||||||
|
|
||||||
|
## Provider-specific parameters
|
||||||
|
|
||||||
|
Pass any jina ai specific parameters as a keyword argument to the `embedding` or `rerank` function, e.g.
|
||||||
|
|
||||||
|
<Tabs>
|
||||||
|
<TabItem value="sdk" label="SDK">
|
||||||
|
|
||||||
|
```python
|
||||||
|
response = embedding(
|
||||||
|
model="jina_ai/jina-embeddings-v3",
|
||||||
|
input=["good morning from litellm"],
|
||||||
|
dimensions=1536,
|
||||||
|
my_custom_param="my_custom_value", # any other jina ai specific parameters
|
||||||
|
)
|
||||||
|
```
|
||||||
|
</TabItem>
|
||||||
|
<TabItem value="proxy" label="PROXY">
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl -L -X POST 'http://0.0.0.0:4000/embeddings' \
|
||||||
|
-H 'Authorization: Bearer sk-1234' \
|
||||||
|
-H 'Content-Type: application/json' \
|
||||||
|
-d '{"input": ["good morning from litellm"], "model": "jina_ai/jina-embeddings-v3", "dimensions": 1536, "my_custom_param": "my_custom_value"}'
|
||||||
|
```
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
</Tabs>
|
||||||
|
|
|
@ -1562,6 +1562,10 @@ curl http://0.0.0.0:4000/v1/chat/completions \
|
||||||
## **Embedding Models**
|
## **Embedding Models**
|
||||||
|
|
||||||
#### Usage - Embedding
|
#### Usage - Embedding
|
||||||
|
|
||||||
|
<Tabs>
|
||||||
|
<TabItem value="sdk" label="SDK">
|
||||||
|
|
||||||
```python
|
```python
|
||||||
import litellm
|
import litellm
|
||||||
from litellm import embedding
|
from litellm import embedding
|
||||||
|
@ -1574,6 +1578,49 @@ response = embedding(
|
||||||
)
|
)
|
||||||
print(response)
|
print(response)
|
||||||
```
|
```
|
||||||
|
</TabItem>
|
||||||
|
|
||||||
|
<TabItem value="proxy" label="LiteLLM PROXY">
|
||||||
|
|
||||||
|
|
||||||
|
1. Add model to config.yaml
|
||||||
|
```yaml
|
||||||
|
model_list:
|
||||||
|
- model_name: snowflake-arctic-embed-m-long-1731622468876
|
||||||
|
litellm_params:
|
||||||
|
model: vertex_ai/<your-model-id>
|
||||||
|
vertex_project: "adroit-crow-413218"
|
||||||
|
vertex_location: "us-central1"
|
||||||
|
vertex_credentials: adroit-crow-413218-a956eef1a2a8.json
|
||||||
|
|
||||||
|
litellm_settings:
|
||||||
|
drop_params: True
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Start Proxy
|
||||||
|
|
||||||
|
```
|
||||||
|
$ litellm --config /path/to/config.yaml
|
||||||
|
```
|
||||||
|
|
||||||
|
3. Make Request using OpenAI Python SDK, Langchain Python SDK
|
||||||
|
|
||||||
|
```python
|
||||||
|
import openai
|
||||||
|
|
||||||
|
client = openai.OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")
|
||||||
|
|
||||||
|
response = client.embeddings.create(
|
||||||
|
model="snowflake-arctic-embed-m-long-1731622468876",
|
||||||
|
input = ["good morning from litellm", "this is another item"],
|
||||||
|
)
|
||||||
|
|
||||||
|
print(response)
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
</Tabs>
|
||||||
|
|
||||||
#### Supported Embedding Models
|
#### Supported Embedding Models
|
||||||
All models listed [here](https://github.com/BerriAI/litellm/blob/57f37f743886a0249f630a6792d49dffc2c5d9b7/model_prices_and_context_window.json#L835) are supported
|
All models listed [here](https://github.com/BerriAI/litellm/blob/57f37f743886a0249f630a6792d49dffc2c5d9b7/model_prices_and_context_window.json#L835) are supported
|
||||||
|
@ -1589,6 +1636,7 @@ All models listed [here](https://github.com/BerriAI/litellm/blob/57f37f743886a02
|
||||||
| textembedding-gecko@003 | `embedding(model="vertex_ai/textembedding-gecko@003", input)` |
|
| textembedding-gecko@003 | `embedding(model="vertex_ai/textembedding-gecko@003", input)` |
|
||||||
| text-embedding-preview-0409 | `embedding(model="vertex_ai/text-embedding-preview-0409", input)` |
|
| text-embedding-preview-0409 | `embedding(model="vertex_ai/text-embedding-preview-0409", input)` |
|
||||||
| text-multilingual-embedding-preview-0409 | `embedding(model="vertex_ai/text-multilingual-embedding-preview-0409", input)` |
|
| text-multilingual-embedding-preview-0409 | `embedding(model="vertex_ai/text-multilingual-embedding-preview-0409", input)` |
|
||||||
|
| Fine-tuned OR Custom Embedding models | `embedding(model="vertex_ai/<your-model-id>", input)` |
|
||||||
|
|
||||||
### Supported OpenAI (Unified) Params
|
### Supported OpenAI (Unified) Params
|
||||||
|
|
||||||
|
|
|
@ -791,9 +791,9 @@ general_settings:
|
||||||
| store_model_in_db | boolean | If true, allows `/model/new` endpoint to store model information in db. Endpoint disabled by default. [Doc on `/model/new` endpoint](./model_management.md#create-a-new-model) |
|
| store_model_in_db | boolean | If true, allows `/model/new` endpoint to store model information in db. Endpoint disabled by default. [Doc on `/model/new` endpoint](./model_management.md#create-a-new-model) |
|
||||||
| max_request_size_mb | int | The maximum size for requests in MB. Requests above this size will be rejected. |
|
| max_request_size_mb | int | The maximum size for requests in MB. Requests above this size will be rejected. |
|
||||||
| max_response_size_mb | int | The maximum size for responses in MB. LLM Responses above this size will not be sent. |
|
| max_response_size_mb | int | The maximum size for responses in MB. LLM Responses above this size will not be sent. |
|
||||||
| proxy_budget_rescheduler_min_time | int | The minimum time (in seconds) to wait before checking db for budget resets. |
|
| proxy_budget_rescheduler_min_time | int | The minimum time (in seconds) to wait before checking db for budget resets. **Default is 597 seconds** |
|
||||||
| proxy_budget_rescheduler_max_time | int | The maximum time (in seconds) to wait before checking db for budget resets. |
|
| proxy_budget_rescheduler_max_time | int | The maximum time (in seconds) to wait before checking db for budget resets. **Default is 605 seconds** |
|
||||||
| proxy_batch_write_at | int | Time (in seconds) to wait before batch writing spend logs to the db. |
|
| proxy_batch_write_at | int | Time (in seconds) to wait before batch writing spend logs to the db. **Default is 10 seconds** |
|
||||||
| alerting_args | dict | Args for Slack Alerting [Doc on Slack Alerting](./alerting.md) |
|
| alerting_args | dict | Args for Slack Alerting [Doc on Slack Alerting](./alerting.md) |
|
||||||
| custom_key_generate | str | Custom function for key generation [Doc on custom key generation](./virtual_keys.md#custom--key-generate) |
|
| custom_key_generate | str | Custom function for key generation [Doc on custom key generation](./virtual_keys.md#custom--key-generate) |
|
||||||
| allowed_ips | List[str] | List of IPs allowed to access the proxy. If not set, all IPs are allowed. |
|
| allowed_ips | List[str] | List of IPs allowed to access the proxy. If not set, all IPs are allowed. |
|
||||||
|
|
|
@ -66,10 +66,16 @@ Removes any field with `user_api_key_*` from metadata.
|
||||||
Found under `kwargs["standard_logging_object"]`. This is a standard payload, logged for every response.
|
Found under `kwargs["standard_logging_object"]`. This is a standard payload, logged for every response.
|
||||||
|
|
||||||
```python
|
```python
|
||||||
|
|
||||||
class StandardLoggingPayload(TypedDict):
|
class StandardLoggingPayload(TypedDict):
|
||||||
id: str
|
id: str
|
||||||
|
trace_id: str # Trace multiple LLM calls belonging to same overall request (e.g. fallbacks/retries)
|
||||||
call_type: str
|
call_type: str
|
||||||
response_cost: float
|
response_cost: float
|
||||||
|
response_cost_failure_debug_info: Optional[
|
||||||
|
StandardLoggingModelCostFailureDebugInformation
|
||||||
|
]
|
||||||
|
status: StandardLoggingPayloadStatus
|
||||||
total_tokens: int
|
total_tokens: int
|
||||||
prompt_tokens: int
|
prompt_tokens: int
|
||||||
completion_tokens: int
|
completion_tokens: int
|
||||||
|
@ -84,13 +90,13 @@ class StandardLoggingPayload(TypedDict):
|
||||||
metadata: StandardLoggingMetadata
|
metadata: StandardLoggingMetadata
|
||||||
cache_hit: Optional[bool]
|
cache_hit: Optional[bool]
|
||||||
cache_key: Optional[str]
|
cache_key: Optional[str]
|
||||||
saved_cache_cost: Optional[float]
|
saved_cache_cost: float
|
||||||
request_tags: list
|
request_tags: list
|
||||||
end_user: Optional[str]
|
end_user: Optional[str]
|
||||||
requester_ip_address: Optional[str] # IP address of requester
|
requester_ip_address: Optional[str]
|
||||||
requester_metadata: Optional[dict] # metadata passed in request in the "metadata" field
|
|
||||||
messages: Optional[Union[str, list, dict]]
|
messages: Optional[Union[str, list, dict]]
|
||||||
response: Optional[Union[str, list, dict]]
|
response: Optional[Union[str, list, dict]]
|
||||||
|
error_str: Optional[str]
|
||||||
model_parameters: dict
|
model_parameters: dict
|
||||||
hidden_params: StandardLoggingHiddenParams
|
hidden_params: StandardLoggingHiddenParams
|
||||||
|
|
||||||
|
@ -99,12 +105,47 @@ class StandardLoggingHiddenParams(TypedDict):
|
||||||
cache_key: Optional[str]
|
cache_key: Optional[str]
|
||||||
api_base: Optional[str]
|
api_base: Optional[str]
|
||||||
response_cost: Optional[str]
|
response_cost: Optional[str]
|
||||||
additional_headers: Optional[dict]
|
additional_headers: Optional[StandardLoggingAdditionalHeaders]
|
||||||
|
|
||||||
|
class StandardLoggingAdditionalHeaders(TypedDict, total=False):
|
||||||
|
x_ratelimit_limit_requests: int
|
||||||
|
x_ratelimit_limit_tokens: int
|
||||||
|
x_ratelimit_remaining_requests: int
|
||||||
|
x_ratelimit_remaining_tokens: int
|
||||||
|
|
||||||
|
class StandardLoggingMetadata(StandardLoggingUserAPIKeyMetadata):
|
||||||
|
"""
|
||||||
|
Specific metadata k,v pairs logged to integration for easier cost tracking
|
||||||
|
"""
|
||||||
|
|
||||||
|
spend_logs_metadata: Optional[
|
||||||
|
dict
|
||||||
|
] # special param to log k,v pairs to spendlogs for a call
|
||||||
|
requester_ip_address: Optional[str]
|
||||||
|
requester_metadata: Optional[dict]
|
||||||
|
|
||||||
class StandardLoggingModelInformation(TypedDict):
|
class StandardLoggingModelInformation(TypedDict):
|
||||||
model_map_key: str
|
model_map_key: str
|
||||||
model_map_value: Optional[ModelInfo]
|
model_map_value: Optional[ModelInfo]
|
||||||
|
|
||||||
|
|
||||||
|
StandardLoggingPayloadStatus = Literal["success", "failure"]
|
||||||
|
|
||||||
|
class StandardLoggingModelCostFailureDebugInformation(TypedDict, total=False):
|
||||||
|
"""
|
||||||
|
Debug information, if cost tracking fails.
|
||||||
|
|
||||||
|
Avoid logging sensitive information like response or optional params
|
||||||
|
"""
|
||||||
|
|
||||||
|
error_str: Required[str]
|
||||||
|
traceback_str: Required[str]
|
||||||
|
model: str
|
||||||
|
cache_hit: Optional[bool]
|
||||||
|
custom_llm_provider: Optional[str]
|
||||||
|
base_model: Optional[str]
|
||||||
|
call_type: str
|
||||||
|
custom_pricing: Optional[bool]
|
||||||
```
|
```
|
||||||
|
|
||||||
## Langfuse
|
## Langfuse
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
import Tabs from '@theme/Tabs';
|
import Tabs from '@theme/Tabs';
|
||||||
import TabItem from '@theme/TabItem';
|
import TabItem from '@theme/TabItem';
|
||||||
|
import Image from '@theme/IdealImage';
|
||||||
|
|
||||||
# ⚡ Best Practices for Production
|
# ⚡ Best Practices for Production
|
||||||
|
|
||||||
|
@ -112,7 +113,35 @@ general_settings:
|
||||||
disable_spend_logs: True
|
disable_spend_logs: True
|
||||||
```
|
```
|
||||||
|
|
||||||
## 7. Set LiteLLM Salt Key
|
## 7. Use Helm PreSync Hook for Database Migrations [BETA]
|
||||||
|
|
||||||
|
To ensure only one service manages database migrations, use our [Helm PreSync hook for Database Migrations](https://github.com/BerriAI/litellm/blob/main/deploy/charts/litellm-helm/templates/migrations-job.yaml). This ensures migrations are handled during `helm upgrade` or `helm install`, while LiteLLM pods explicitly disable migrations.
|
||||||
|
|
||||||
|
|
||||||
|
1. **Helm PreSync Hook**:
|
||||||
|
- The Helm PreSync hook is configured in the chart to run database migrations during deployments.
|
||||||
|
- The hook always sets `DISABLE_SCHEMA_UPDATE=false`, ensuring migrations are executed reliably.
|
||||||
|
|
||||||
|
Reference Settings to set on ArgoCD for `values.yaml`
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
db:
|
||||||
|
useExisting: true # use existing Postgres DB
|
||||||
|
url: postgresql://ishaanjaffer0324:3rnwpOBau6hT@ep-withered-mud-a5dkdpke.us-east-2.aws.neon.tech/test-argo-cd?sslmode=require # url of existing Postgres DB
|
||||||
|
```
|
||||||
|
|
||||||
|
2. **LiteLLM Pods**:
|
||||||
|
- Set `DISABLE_SCHEMA_UPDATE=true` in LiteLLM pod configurations to prevent them from running migrations.
|
||||||
|
|
||||||
|
Example configuration for LiteLLM pod:
|
||||||
|
```yaml
|
||||||
|
env:
|
||||||
|
- name: DISABLE_SCHEMA_UPDATE
|
||||||
|
value: "true"
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
## 8. Set LiteLLM Salt Key
|
||||||
|
|
||||||
If you plan on using the DB, set a salt key for encrypting/decrypting variables in the DB.
|
If you plan on using the DB, set a salt key for encrypting/decrypting variables in the DB.
|
||||||
|
|
||||||
|
|
|
@ -749,3 +749,18 @@ curl -L -X POST 'http://0.0.0.0:4000/v1/chat/completions' \
|
||||||
"mock_testing_fallbacks": true
|
"mock_testing_fallbacks": true
|
||||||
}'
|
}'
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Disable Fallbacks per key
|
||||||
|
|
||||||
|
You can disable fallbacks per key by setting `disable_fallbacks: true` in your key metadata.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl -L -X POST 'http://0.0.0.0:4000/key/generate' \
|
||||||
|
-H 'Authorization: Bearer sk-1234' \
|
||||||
|
-H 'Content-Type: application/json' \
|
||||||
|
-d '{
|
||||||
|
"metadata": {
|
||||||
|
"disable_fallbacks": true
|
||||||
|
}
|
||||||
|
}'
|
||||||
|
```
|
|
@ -114,3 +114,4 @@ curl http://0.0.0.0:4000/rerank \
|
||||||
| Cohere | [Usage](#quick-start) |
|
| Cohere | [Usage](#quick-start) |
|
||||||
| Together AI| [Usage](../docs/providers/togetherai) |
|
| Together AI| [Usage](../docs/providers/togetherai) |
|
||||||
| Azure AI| [Usage](../docs/providers/azure_ai) |
|
| Azure AI| [Usage](../docs/providers/azure_ai) |
|
||||||
|
| Jina AI| [Usage](../docs/providers/jina_ai) |
|
|
@ -1,3 +1,6 @@
|
||||||
|
import Tabs from '@theme/Tabs';
|
||||||
|
import TabItem from '@theme/TabItem';
|
||||||
|
|
||||||
# Secret Manager
|
# Secret Manager
|
||||||
LiteLLM supports reading secrets from Azure Key Vault, Google Secret Manager
|
LiteLLM supports reading secrets from Azure Key Vault, Google Secret Manager
|
||||||
|
|
||||||
|
@ -59,14 +62,35 @@ os.environ["AWS_REGION_NAME"] = "" # us-east-1, us-east-2, us-west-1, us-west-2
|
||||||
```
|
```
|
||||||
|
|
||||||
2. Enable AWS Secret Manager in config.
|
2. Enable AWS Secret Manager in config.
|
||||||
|
|
||||||
|
<Tabs>
|
||||||
|
<TabItem value="read_only" label="Read Keys from AWS Secret Manager">
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
general_settings:
|
general_settings:
|
||||||
master_key: os.environ/litellm_master_key
|
master_key: os.environ/litellm_master_key
|
||||||
key_management_system: "aws_secret_manager" # 👈 KEY CHANGE
|
key_management_system: "aws_secret_manager" # 👈 KEY CHANGE
|
||||||
key_management_settings:
|
key_management_settings:
|
||||||
hosted_keys: ["litellm_master_key"] # 👈 Specify which env keys you stored on AWS
|
hosted_keys: ["litellm_master_key"] # 👈 Specify which env keys you stored on AWS
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
|
||||||
|
<TabItem value="write_only" label="Write Virtual Keys to AWS Secret Manager">
|
||||||
|
|
||||||
|
This will only store virtual keys in AWS Secret Manager. No keys will be read from AWS Secret Manager.
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
general_settings:
|
||||||
|
key_management_system: "aws_secret_manager" # 👈 KEY CHANGE
|
||||||
|
key_management_settings:
|
||||||
|
store_virtual_keys: true
|
||||||
|
access_mode: "write_only" # Literal["read_only", "write_only", "read_and_write"]
|
||||||
|
```
|
||||||
|
</TabItem>
|
||||||
|
</Tabs>
|
||||||
|
|
||||||
3. Run proxy
|
3. Run proxy
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
|
@ -181,16 +205,14 @@ litellm --config /path/to/config.yaml
|
||||||
|
|
||||||
Use encrypted keys from Google KMS on the proxy
|
Use encrypted keys from Google KMS on the proxy
|
||||||
|
|
||||||
### Usage with LiteLLM Proxy Server
|
Step 1. Add keys to env
|
||||||
|
|
||||||
## Step 1. Add keys to env
|
|
||||||
```
|
```
|
||||||
export GOOGLE_APPLICATION_CREDENTIALS="/path/to/credentials.json"
|
export GOOGLE_APPLICATION_CREDENTIALS="/path/to/credentials.json"
|
||||||
export GOOGLE_KMS_RESOURCE_NAME="projects/*/locations/*/keyRings/*/cryptoKeys/*"
|
export GOOGLE_KMS_RESOURCE_NAME="projects/*/locations/*/keyRings/*/cryptoKeys/*"
|
||||||
export PROXY_DATABASE_URL_ENCRYPTED=b'\n$\x00D\xac\xb4/\x8e\xc...'
|
export PROXY_DATABASE_URL_ENCRYPTED=b'\n$\x00D\xac\xb4/\x8e\xc...'
|
||||||
```
|
```
|
||||||
|
|
||||||
## Step 2: Update Config
|
Step 2: Update Config
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
general_settings:
|
general_settings:
|
||||||
|
@ -199,7 +221,7 @@ general_settings:
|
||||||
master_key: sk-1234
|
master_key: sk-1234
|
||||||
```
|
```
|
||||||
|
|
||||||
## Step 3: Start + test proxy
|
Step 3: Start + test proxy
|
||||||
|
|
||||||
```
|
```
|
||||||
$ litellm --config /path/to/config.yaml
|
$ litellm --config /path/to/config.yaml
|
||||||
|
@ -215,3 +237,17 @@ $ litellm --test
|
||||||
<!--
|
<!--
|
||||||
## .env Files
|
## .env Files
|
||||||
If no secret manager client is specified, Litellm automatically uses the `.env` file to manage sensitive data. -->
|
If no secret manager client is specified, Litellm automatically uses the `.env` file to manage sensitive data. -->
|
||||||
|
|
||||||
|
|
||||||
|
## All Secret Manager Settings
|
||||||
|
|
||||||
|
All settings related to secret management
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
general_settings:
|
||||||
|
key_management_system: "aws_secret_manager" # REQUIRED
|
||||||
|
key_management_settings:
|
||||||
|
store_virtual_keys: true # OPTIONAL. Defaults to False, when True will store virtual keys in secret manager
|
||||||
|
access_mode: "write_only" # OPTIONAL. Literal["read_only", "write_only", "read_and_write"]. Defaults to "read_only"
|
||||||
|
hosted_keys: ["litellm_master_key"] # OPTIONAL. Specify which env keys you stored on AWS
|
||||||
|
```
|
|
@ -305,7 +305,7 @@ secret_manager_client: Optional[Any] = (
|
||||||
)
|
)
|
||||||
_google_kms_resource_name: Optional[str] = None
|
_google_kms_resource_name: Optional[str] = None
|
||||||
_key_management_system: Optional[KeyManagementSystem] = None
|
_key_management_system: Optional[KeyManagementSystem] = None
|
||||||
_key_management_settings: Optional[KeyManagementSettings] = None
|
_key_management_settings: KeyManagementSettings = KeyManagementSettings()
|
||||||
#### PII MASKING ####
|
#### PII MASKING ####
|
||||||
output_parse_pii: bool = False
|
output_parse_pii: bool = False
|
||||||
#############################################
|
#############################################
|
||||||
|
@ -962,6 +962,8 @@ from .utils import (
|
||||||
supports_response_schema,
|
supports_response_schema,
|
||||||
supports_parallel_function_calling,
|
supports_parallel_function_calling,
|
||||||
supports_vision,
|
supports_vision,
|
||||||
|
supports_audio_input,
|
||||||
|
supports_audio_output,
|
||||||
supports_system_messages,
|
supports_system_messages,
|
||||||
get_litellm_params,
|
get_litellm_params,
|
||||||
acreate,
|
acreate,
|
||||||
|
|
|
@ -46,6 +46,9 @@ from litellm.llms.OpenAI.cost_calculation import (
|
||||||
from litellm.llms.OpenAI.cost_calculation import cost_per_token as openai_cost_per_token
|
from litellm.llms.OpenAI.cost_calculation import cost_per_token as openai_cost_per_token
|
||||||
from litellm.llms.OpenAI.cost_calculation import cost_router as openai_cost_router
|
from litellm.llms.OpenAI.cost_calculation import cost_router as openai_cost_router
|
||||||
from litellm.llms.together_ai.cost_calculator import get_model_params_and_category
|
from litellm.llms.together_ai.cost_calculator import get_model_params_and_category
|
||||||
|
from litellm.llms.vertex_ai_and_google_ai_studio.image_generation.cost_calculator import (
|
||||||
|
cost_calculator as vertex_ai_image_cost_calculator,
|
||||||
|
)
|
||||||
from litellm.types.llms.openai import HttpxBinaryResponseContent
|
from litellm.types.llms.openai import HttpxBinaryResponseContent
|
||||||
from litellm.types.rerank import RerankResponse
|
from litellm.types.rerank import RerankResponse
|
||||||
from litellm.types.router import SPECIAL_MODEL_INFO_PARAMS
|
from litellm.types.router import SPECIAL_MODEL_INFO_PARAMS
|
||||||
|
@ -667,9 +670,11 @@ def completion_cost( # noqa: PLR0915
|
||||||
):
|
):
|
||||||
### IMAGE GENERATION COST CALCULATION ###
|
### IMAGE GENERATION COST CALCULATION ###
|
||||||
if custom_llm_provider == "vertex_ai":
|
if custom_llm_provider == "vertex_ai":
|
||||||
# https://cloud.google.com/vertex-ai/generative-ai/pricing
|
if isinstance(completion_response, ImageResponse):
|
||||||
# Vertex Charges Flat $0.20 per image
|
return vertex_ai_image_cost_calculator(
|
||||||
return 0.020
|
model=model,
|
||||||
|
image_response=completion_response,
|
||||||
|
)
|
||||||
elif custom_llm_provider == "bedrock":
|
elif custom_llm_provider == "bedrock":
|
||||||
if isinstance(completion_response, ImageResponse):
|
if isinstance(completion_response, ImageResponse):
|
||||||
return bedrock_image_cost_calculator(
|
return bedrock_image_cost_calculator(
|
||||||
|
|
|
@ -239,7 +239,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
message=f"ContextWindowExceededError: {exception_provider} - {message}",
|
message=f"ContextWindowExceededError: {exception_provider} - {message}",
|
||||||
llm_provider=custom_llm_provider,
|
llm_provider=custom_llm_provider,
|
||||||
model=model,
|
model=model,
|
||||||
response=original_exception.response,
|
response=getattr(original_exception, "response", None),
|
||||||
litellm_debug_info=extra_information,
|
litellm_debug_info=extra_information,
|
||||||
)
|
)
|
||||||
elif (
|
elif (
|
||||||
|
@ -251,7 +251,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
message=f"{exception_provider} - {message}",
|
message=f"{exception_provider} - {message}",
|
||||||
llm_provider=custom_llm_provider,
|
llm_provider=custom_llm_provider,
|
||||||
model=model,
|
model=model,
|
||||||
response=original_exception.response,
|
response=getattr(original_exception, "response", None),
|
||||||
litellm_debug_info=extra_information,
|
litellm_debug_info=extra_information,
|
||||||
)
|
)
|
||||||
elif "A timeout occurred" in error_str:
|
elif "A timeout occurred" in error_str:
|
||||||
|
@ -271,7 +271,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
message=f"ContentPolicyViolationError: {exception_provider} - {message}",
|
message=f"ContentPolicyViolationError: {exception_provider} - {message}",
|
||||||
llm_provider=custom_llm_provider,
|
llm_provider=custom_llm_provider,
|
||||||
model=model,
|
model=model,
|
||||||
response=original_exception.response,
|
response=getattr(original_exception, "response", None),
|
||||||
litellm_debug_info=extra_information,
|
litellm_debug_info=extra_information,
|
||||||
)
|
)
|
||||||
elif (
|
elif (
|
||||||
|
@ -283,7 +283,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
message=f"{exception_provider} - {message}",
|
message=f"{exception_provider} - {message}",
|
||||||
llm_provider=custom_llm_provider,
|
llm_provider=custom_llm_provider,
|
||||||
model=model,
|
model=model,
|
||||||
response=original_exception.response,
|
response=getattr(original_exception, "response", None),
|
||||||
litellm_debug_info=extra_information,
|
litellm_debug_info=extra_information,
|
||||||
)
|
)
|
||||||
elif "Web server is returning an unknown error" in error_str:
|
elif "Web server is returning an unknown error" in error_str:
|
||||||
|
@ -299,7 +299,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
message=f"RateLimitError: {exception_provider} - {message}",
|
message=f"RateLimitError: {exception_provider} - {message}",
|
||||||
model=model,
|
model=model,
|
||||||
llm_provider=custom_llm_provider,
|
llm_provider=custom_llm_provider,
|
||||||
response=original_exception.response,
|
response=getattr(original_exception, "response", None),
|
||||||
litellm_debug_info=extra_information,
|
litellm_debug_info=extra_information,
|
||||||
)
|
)
|
||||||
elif (
|
elif (
|
||||||
|
@ -311,7 +311,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
message=f"AuthenticationError: {exception_provider} - {message}",
|
message=f"AuthenticationError: {exception_provider} - {message}",
|
||||||
llm_provider=custom_llm_provider,
|
llm_provider=custom_llm_provider,
|
||||||
model=model,
|
model=model,
|
||||||
response=original_exception.response,
|
response=getattr(original_exception, "response", None),
|
||||||
litellm_debug_info=extra_information,
|
litellm_debug_info=extra_information,
|
||||||
)
|
)
|
||||||
elif "Mistral API raised a streaming error" in error_str:
|
elif "Mistral API raised a streaming error" in error_str:
|
||||||
|
@ -335,7 +335,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
message=f"{exception_provider} - {message}",
|
message=f"{exception_provider} - {message}",
|
||||||
llm_provider=custom_llm_provider,
|
llm_provider=custom_llm_provider,
|
||||||
model=model,
|
model=model,
|
||||||
response=original_exception.response,
|
response=getattr(original_exception, "response", None),
|
||||||
litellm_debug_info=extra_information,
|
litellm_debug_info=extra_information,
|
||||||
)
|
)
|
||||||
elif original_exception.status_code == 401:
|
elif original_exception.status_code == 401:
|
||||||
|
@ -344,7 +344,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
message=f"AuthenticationError: {exception_provider} - {message}",
|
message=f"AuthenticationError: {exception_provider} - {message}",
|
||||||
llm_provider=custom_llm_provider,
|
llm_provider=custom_llm_provider,
|
||||||
model=model,
|
model=model,
|
||||||
response=original_exception.response,
|
response=getattr(original_exception, "response", None),
|
||||||
litellm_debug_info=extra_information,
|
litellm_debug_info=extra_information,
|
||||||
)
|
)
|
||||||
elif original_exception.status_code == 404:
|
elif original_exception.status_code == 404:
|
||||||
|
@ -353,7 +353,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
message=f"NotFoundError: {exception_provider} - {message}",
|
message=f"NotFoundError: {exception_provider} - {message}",
|
||||||
model=model,
|
model=model,
|
||||||
llm_provider=custom_llm_provider,
|
llm_provider=custom_llm_provider,
|
||||||
response=original_exception.response,
|
response=getattr(original_exception, "response", None),
|
||||||
litellm_debug_info=extra_information,
|
litellm_debug_info=extra_information,
|
||||||
)
|
)
|
||||||
elif original_exception.status_code == 408:
|
elif original_exception.status_code == 408:
|
||||||
|
@ -516,7 +516,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
message=f"ReplicateException - {error_str}",
|
message=f"ReplicateException - {error_str}",
|
||||||
llm_provider="replicate",
|
llm_provider="replicate",
|
||||||
model=model,
|
model=model,
|
||||||
response=original_exception.response,
|
response=getattr(original_exception, "response", None),
|
||||||
)
|
)
|
||||||
elif "input is too long" in error_str:
|
elif "input is too long" in error_str:
|
||||||
exception_mapping_worked = True
|
exception_mapping_worked = True
|
||||||
|
@ -524,7 +524,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
message=f"ReplicateException - {error_str}",
|
message=f"ReplicateException - {error_str}",
|
||||||
model=model,
|
model=model,
|
||||||
llm_provider="replicate",
|
llm_provider="replicate",
|
||||||
response=original_exception.response,
|
response=getattr(original_exception, "response", None),
|
||||||
)
|
)
|
||||||
elif exception_type == "ModelError":
|
elif exception_type == "ModelError":
|
||||||
exception_mapping_worked = True
|
exception_mapping_worked = True
|
||||||
|
@ -532,7 +532,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
message=f"ReplicateException - {error_str}",
|
message=f"ReplicateException - {error_str}",
|
||||||
model=model,
|
model=model,
|
||||||
llm_provider="replicate",
|
llm_provider="replicate",
|
||||||
response=original_exception.response,
|
response=getattr(original_exception, "response", None),
|
||||||
)
|
)
|
||||||
elif "Request was throttled" in error_str:
|
elif "Request was throttled" in error_str:
|
||||||
exception_mapping_worked = True
|
exception_mapping_worked = True
|
||||||
|
@ -540,7 +540,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
message=f"ReplicateException - {error_str}",
|
message=f"ReplicateException - {error_str}",
|
||||||
llm_provider="replicate",
|
llm_provider="replicate",
|
||||||
model=model,
|
model=model,
|
||||||
response=original_exception.response,
|
response=getattr(original_exception, "response", None),
|
||||||
)
|
)
|
||||||
elif hasattr(original_exception, "status_code"):
|
elif hasattr(original_exception, "status_code"):
|
||||||
if original_exception.status_code == 401:
|
if original_exception.status_code == 401:
|
||||||
|
@ -549,7 +549,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
message=f"ReplicateException - {original_exception.message}",
|
message=f"ReplicateException - {original_exception.message}",
|
||||||
llm_provider="replicate",
|
llm_provider="replicate",
|
||||||
model=model,
|
model=model,
|
||||||
response=original_exception.response,
|
response=getattr(original_exception, "response", None),
|
||||||
)
|
)
|
||||||
elif (
|
elif (
|
||||||
original_exception.status_code == 400
|
original_exception.status_code == 400
|
||||||
|
@ -560,7 +560,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
message=f"ReplicateException - {original_exception.message}",
|
message=f"ReplicateException - {original_exception.message}",
|
||||||
model=model,
|
model=model,
|
||||||
llm_provider="replicate",
|
llm_provider="replicate",
|
||||||
response=original_exception.response,
|
response=getattr(original_exception, "response", None),
|
||||||
)
|
)
|
||||||
elif original_exception.status_code == 422:
|
elif original_exception.status_code == 422:
|
||||||
exception_mapping_worked = True
|
exception_mapping_worked = True
|
||||||
|
@ -568,7 +568,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
message=f"ReplicateException - {original_exception.message}",
|
message=f"ReplicateException - {original_exception.message}",
|
||||||
model=model,
|
model=model,
|
||||||
llm_provider="replicate",
|
llm_provider="replicate",
|
||||||
response=original_exception.response,
|
response=getattr(original_exception, "response", None),
|
||||||
)
|
)
|
||||||
elif original_exception.status_code == 408:
|
elif original_exception.status_code == 408:
|
||||||
exception_mapping_worked = True
|
exception_mapping_worked = True
|
||||||
|
@ -583,7 +583,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
message=f"ReplicateException - {original_exception.message}",
|
message=f"ReplicateException - {original_exception.message}",
|
||||||
llm_provider="replicate",
|
llm_provider="replicate",
|
||||||
model=model,
|
model=model,
|
||||||
response=original_exception.response,
|
response=getattr(original_exception, "response", None),
|
||||||
)
|
)
|
||||||
elif original_exception.status_code == 429:
|
elif original_exception.status_code == 429:
|
||||||
exception_mapping_worked = True
|
exception_mapping_worked = True
|
||||||
|
@ -591,7 +591,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
message=f"ReplicateException - {original_exception.message}",
|
message=f"ReplicateException - {original_exception.message}",
|
||||||
llm_provider="replicate",
|
llm_provider="replicate",
|
||||||
model=model,
|
model=model,
|
||||||
response=original_exception.response,
|
response=getattr(original_exception, "response", None),
|
||||||
)
|
)
|
||||||
elif original_exception.status_code == 500:
|
elif original_exception.status_code == 500:
|
||||||
exception_mapping_worked = True
|
exception_mapping_worked = True
|
||||||
|
@ -599,7 +599,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
message=f"ReplicateException - {original_exception.message}",
|
message=f"ReplicateException - {original_exception.message}",
|
||||||
llm_provider="replicate",
|
llm_provider="replicate",
|
||||||
model=model,
|
model=model,
|
||||||
response=original_exception.response,
|
response=getattr(original_exception, "response", None),
|
||||||
)
|
)
|
||||||
exception_mapping_worked = True
|
exception_mapping_worked = True
|
||||||
raise APIError(
|
raise APIError(
|
||||||
|
@ -631,7 +631,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
message=f"{custom_llm_provider}Exception: Authentication Error - {error_str}",
|
message=f"{custom_llm_provider}Exception: Authentication Error - {error_str}",
|
||||||
llm_provider=custom_llm_provider,
|
llm_provider=custom_llm_provider,
|
||||||
model=model,
|
model=model,
|
||||||
response=original_exception.response,
|
response=getattr(original_exception, "response", None),
|
||||||
litellm_debug_info=extra_information,
|
litellm_debug_info=extra_information,
|
||||||
)
|
)
|
||||||
elif "token_quota_reached" in error_str:
|
elif "token_quota_reached" in error_str:
|
||||||
|
@ -640,7 +640,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
message=f"{custom_llm_provider}Exception: Rate Limit Errror - {error_str}",
|
message=f"{custom_llm_provider}Exception: Rate Limit Errror - {error_str}",
|
||||||
llm_provider=custom_llm_provider,
|
llm_provider=custom_llm_provider,
|
||||||
model=model,
|
model=model,
|
||||||
response=original_exception.response,
|
response=getattr(original_exception, "response", None),
|
||||||
)
|
)
|
||||||
elif (
|
elif (
|
||||||
"The server received an invalid response from an upstream server."
|
"The server received an invalid response from an upstream server."
|
||||||
|
@ -750,7 +750,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
message=f"BedrockException - {error_str}\n. Enable 'litellm.modify_params=True' (for PROXY do: `litellm_settings::modify_params: True`) to insert a dummy assistant message and fix this error.",
|
message=f"BedrockException - {error_str}\n. Enable 'litellm.modify_params=True' (for PROXY do: `litellm_settings::modify_params: True`) to insert a dummy assistant message and fix this error.",
|
||||||
model=model,
|
model=model,
|
||||||
llm_provider="bedrock",
|
llm_provider="bedrock",
|
||||||
response=original_exception.response,
|
response=getattr(original_exception, "response", None),
|
||||||
)
|
)
|
||||||
elif "Malformed input request" in error_str:
|
elif "Malformed input request" in error_str:
|
||||||
exception_mapping_worked = True
|
exception_mapping_worked = True
|
||||||
|
@ -758,7 +758,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
message=f"BedrockException - {error_str}",
|
message=f"BedrockException - {error_str}",
|
||||||
model=model,
|
model=model,
|
||||||
llm_provider="bedrock",
|
llm_provider="bedrock",
|
||||||
response=original_exception.response,
|
response=getattr(original_exception, "response", None),
|
||||||
)
|
)
|
||||||
elif "A conversation must start with a user message." in error_str:
|
elif "A conversation must start with a user message." in error_str:
|
||||||
exception_mapping_worked = True
|
exception_mapping_worked = True
|
||||||
|
@ -766,7 +766,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
message=f"BedrockException - {error_str}\n. Pass in default user message via `completion(..,user_continue_message=)` or enable `litellm.modify_params=True`.\nFor Proxy: do via `litellm_settings::modify_params: True` or user_continue_message under `litellm_params`",
|
message=f"BedrockException - {error_str}\n. Pass in default user message via `completion(..,user_continue_message=)` or enable `litellm.modify_params=True`.\nFor Proxy: do via `litellm_settings::modify_params: True` or user_continue_message under `litellm_params`",
|
||||||
model=model,
|
model=model,
|
||||||
llm_provider="bedrock",
|
llm_provider="bedrock",
|
||||||
response=original_exception.response,
|
response=getattr(original_exception, "response", None),
|
||||||
)
|
)
|
||||||
elif (
|
elif (
|
||||||
"Unable to locate credentials" in error_str
|
"Unable to locate credentials" in error_str
|
||||||
|
@ -778,7 +778,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
message=f"BedrockException Invalid Authentication - {error_str}",
|
message=f"BedrockException Invalid Authentication - {error_str}",
|
||||||
model=model,
|
model=model,
|
||||||
llm_provider="bedrock",
|
llm_provider="bedrock",
|
||||||
response=original_exception.response,
|
response=getattr(original_exception, "response", None),
|
||||||
)
|
)
|
||||||
elif "AccessDeniedException" in error_str:
|
elif "AccessDeniedException" in error_str:
|
||||||
exception_mapping_worked = True
|
exception_mapping_worked = True
|
||||||
|
@ -786,7 +786,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
message=f"BedrockException PermissionDeniedError - {error_str}",
|
message=f"BedrockException PermissionDeniedError - {error_str}",
|
||||||
model=model,
|
model=model,
|
||||||
llm_provider="bedrock",
|
llm_provider="bedrock",
|
||||||
response=original_exception.response,
|
response=getattr(original_exception, "response", None),
|
||||||
)
|
)
|
||||||
elif (
|
elif (
|
||||||
"throttlingException" in error_str
|
"throttlingException" in error_str
|
||||||
|
@ -797,7 +797,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
message=f"BedrockException: Rate Limit Error - {error_str}",
|
message=f"BedrockException: Rate Limit Error - {error_str}",
|
||||||
model=model,
|
model=model,
|
||||||
llm_provider="bedrock",
|
llm_provider="bedrock",
|
||||||
response=original_exception.response,
|
response=getattr(original_exception, "response", None),
|
||||||
)
|
)
|
||||||
elif (
|
elif (
|
||||||
"Connect timeout on endpoint URL" in error_str
|
"Connect timeout on endpoint URL" in error_str
|
||||||
|
@ -836,7 +836,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
message=f"BedrockException - {original_exception.message}",
|
message=f"BedrockException - {original_exception.message}",
|
||||||
llm_provider="bedrock",
|
llm_provider="bedrock",
|
||||||
model=model,
|
model=model,
|
||||||
response=original_exception.response,
|
response=getattr(original_exception, "response", None),
|
||||||
)
|
)
|
||||||
elif original_exception.status_code == 400:
|
elif original_exception.status_code == 400:
|
||||||
exception_mapping_worked = True
|
exception_mapping_worked = True
|
||||||
|
@ -844,7 +844,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
message=f"BedrockException - {original_exception.message}",
|
message=f"BedrockException - {original_exception.message}",
|
||||||
llm_provider="bedrock",
|
llm_provider="bedrock",
|
||||||
model=model,
|
model=model,
|
||||||
response=original_exception.response,
|
response=getattr(original_exception, "response", None),
|
||||||
)
|
)
|
||||||
elif original_exception.status_code == 404:
|
elif original_exception.status_code == 404:
|
||||||
exception_mapping_worked = True
|
exception_mapping_worked = True
|
||||||
|
@ -852,7 +852,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
message=f"BedrockException - {original_exception.message}",
|
message=f"BedrockException - {original_exception.message}",
|
||||||
llm_provider="bedrock",
|
llm_provider="bedrock",
|
||||||
model=model,
|
model=model,
|
||||||
response=original_exception.response,
|
response=getattr(original_exception, "response", None),
|
||||||
)
|
)
|
||||||
elif original_exception.status_code == 408:
|
elif original_exception.status_code == 408:
|
||||||
exception_mapping_worked = True
|
exception_mapping_worked = True
|
||||||
|
@ -868,7 +868,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
message=f"BedrockException - {original_exception.message}",
|
message=f"BedrockException - {original_exception.message}",
|
||||||
model=model,
|
model=model,
|
||||||
llm_provider=custom_llm_provider,
|
llm_provider=custom_llm_provider,
|
||||||
response=original_exception.response,
|
response=getattr(original_exception, "response", None),
|
||||||
litellm_debug_info=extra_information,
|
litellm_debug_info=extra_information,
|
||||||
)
|
)
|
||||||
elif original_exception.status_code == 429:
|
elif original_exception.status_code == 429:
|
||||||
|
@ -877,7 +877,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
message=f"BedrockException - {original_exception.message}",
|
message=f"BedrockException - {original_exception.message}",
|
||||||
model=model,
|
model=model,
|
||||||
llm_provider=custom_llm_provider,
|
llm_provider=custom_llm_provider,
|
||||||
response=original_exception.response,
|
response=getattr(original_exception, "response", None),
|
||||||
litellm_debug_info=extra_information,
|
litellm_debug_info=extra_information,
|
||||||
)
|
)
|
||||||
elif original_exception.status_code == 503:
|
elif original_exception.status_code == 503:
|
||||||
|
@ -886,7 +886,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
message=f"BedrockException - {original_exception.message}",
|
message=f"BedrockException - {original_exception.message}",
|
||||||
model=model,
|
model=model,
|
||||||
llm_provider=custom_llm_provider,
|
llm_provider=custom_llm_provider,
|
||||||
response=original_exception.response,
|
response=getattr(original_exception, "response", None),
|
||||||
litellm_debug_info=extra_information,
|
litellm_debug_info=extra_information,
|
||||||
)
|
)
|
||||||
elif original_exception.status_code == 504: # gateway timeout error
|
elif original_exception.status_code == 504: # gateway timeout error
|
||||||
|
@ -907,7 +907,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
message=f"litellm.BadRequestError: SagemakerException - {error_str}",
|
message=f"litellm.BadRequestError: SagemakerException - {error_str}",
|
||||||
model=model,
|
model=model,
|
||||||
llm_provider="sagemaker",
|
llm_provider="sagemaker",
|
||||||
response=original_exception.response,
|
response=getattr(original_exception, "response", None),
|
||||||
)
|
)
|
||||||
elif (
|
elif (
|
||||||
"Input validation error: `best_of` must be > 0 and <= 2"
|
"Input validation error: `best_of` must be > 0 and <= 2"
|
||||||
|
@ -918,7 +918,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
message="SagemakerException - the value of 'n' must be > 0 and <= 2 for sagemaker endpoints",
|
message="SagemakerException - the value of 'n' must be > 0 and <= 2 for sagemaker endpoints",
|
||||||
model=model,
|
model=model,
|
||||||
llm_provider="sagemaker",
|
llm_provider="sagemaker",
|
||||||
response=original_exception.response,
|
response=getattr(original_exception, "response", None),
|
||||||
)
|
)
|
||||||
elif (
|
elif (
|
||||||
"`inputs` tokens + `max_new_tokens` must be <=" in error_str
|
"`inputs` tokens + `max_new_tokens` must be <=" in error_str
|
||||||
|
@ -929,7 +929,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
message=f"SagemakerException - {error_str}",
|
message=f"SagemakerException - {error_str}",
|
||||||
model=model,
|
model=model,
|
||||||
llm_provider="sagemaker",
|
llm_provider="sagemaker",
|
||||||
response=original_exception.response,
|
response=getattr(original_exception, "response", None),
|
||||||
)
|
)
|
||||||
elif hasattr(original_exception, "status_code"):
|
elif hasattr(original_exception, "status_code"):
|
||||||
if original_exception.status_code == 500:
|
if original_exception.status_code == 500:
|
||||||
|
@ -951,7 +951,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
message=f"SagemakerException - {original_exception.message}",
|
message=f"SagemakerException - {original_exception.message}",
|
||||||
llm_provider=custom_llm_provider,
|
llm_provider=custom_llm_provider,
|
||||||
model=model,
|
model=model,
|
||||||
response=original_exception.response,
|
response=getattr(original_exception, "response", None),
|
||||||
)
|
)
|
||||||
elif original_exception.status_code == 400:
|
elif original_exception.status_code == 400:
|
||||||
exception_mapping_worked = True
|
exception_mapping_worked = True
|
||||||
|
@ -959,7 +959,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
message=f"SagemakerException - {original_exception.message}",
|
message=f"SagemakerException - {original_exception.message}",
|
||||||
llm_provider=custom_llm_provider,
|
llm_provider=custom_llm_provider,
|
||||||
model=model,
|
model=model,
|
||||||
response=original_exception.response,
|
response=getattr(original_exception, "response", None),
|
||||||
)
|
)
|
||||||
elif original_exception.status_code == 404:
|
elif original_exception.status_code == 404:
|
||||||
exception_mapping_worked = True
|
exception_mapping_worked = True
|
||||||
|
@ -967,7 +967,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
message=f"SagemakerException - {original_exception.message}",
|
message=f"SagemakerException - {original_exception.message}",
|
||||||
llm_provider=custom_llm_provider,
|
llm_provider=custom_llm_provider,
|
||||||
model=model,
|
model=model,
|
||||||
response=original_exception.response,
|
response=getattr(original_exception, "response", None),
|
||||||
)
|
)
|
||||||
elif original_exception.status_code == 408:
|
elif original_exception.status_code == 408:
|
||||||
exception_mapping_worked = True
|
exception_mapping_worked = True
|
||||||
|
@ -986,7 +986,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
message=f"SagemakerException - {original_exception.message}",
|
message=f"SagemakerException - {original_exception.message}",
|
||||||
model=model,
|
model=model,
|
||||||
llm_provider=custom_llm_provider,
|
llm_provider=custom_llm_provider,
|
||||||
response=original_exception.response,
|
response=getattr(original_exception, "response", None),
|
||||||
litellm_debug_info=extra_information,
|
litellm_debug_info=extra_information,
|
||||||
)
|
)
|
||||||
elif original_exception.status_code == 429:
|
elif original_exception.status_code == 429:
|
||||||
|
@ -995,7 +995,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
message=f"SagemakerException - {original_exception.message}",
|
message=f"SagemakerException - {original_exception.message}",
|
||||||
model=model,
|
model=model,
|
||||||
llm_provider=custom_llm_provider,
|
llm_provider=custom_llm_provider,
|
||||||
response=original_exception.response,
|
response=getattr(original_exception, "response", None),
|
||||||
litellm_debug_info=extra_information,
|
litellm_debug_info=extra_information,
|
||||||
)
|
)
|
||||||
elif original_exception.status_code == 503:
|
elif original_exception.status_code == 503:
|
||||||
|
@ -1004,7 +1004,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
message=f"SagemakerException - {original_exception.message}",
|
message=f"SagemakerException - {original_exception.message}",
|
||||||
model=model,
|
model=model,
|
||||||
llm_provider=custom_llm_provider,
|
llm_provider=custom_llm_provider,
|
||||||
response=original_exception.response,
|
response=getattr(original_exception, "response", None),
|
||||||
litellm_debug_info=extra_information,
|
litellm_debug_info=extra_information,
|
||||||
)
|
)
|
||||||
elif original_exception.status_code == 504: # gateway timeout error
|
elif original_exception.status_code == 504: # gateway timeout error
|
||||||
|
@ -1217,7 +1217,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
message="GeminiException - Invalid api key",
|
message="GeminiException - Invalid api key",
|
||||||
model=model,
|
model=model,
|
||||||
llm_provider="palm",
|
llm_provider="palm",
|
||||||
response=original_exception.response,
|
response=getattr(original_exception, "response", None),
|
||||||
)
|
)
|
||||||
if (
|
if (
|
||||||
"504 Deadline expired before operation could complete." in error_str
|
"504 Deadline expired before operation could complete." in error_str
|
||||||
|
@ -1235,7 +1235,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
message=f"GeminiException - {error_str}",
|
message=f"GeminiException - {error_str}",
|
||||||
model=model,
|
model=model,
|
||||||
llm_provider="palm",
|
llm_provider="palm",
|
||||||
response=original_exception.response,
|
response=getattr(original_exception, "response", None),
|
||||||
)
|
)
|
||||||
if (
|
if (
|
||||||
"500 An internal error has occurred." in error_str
|
"500 An internal error has occurred." in error_str
|
||||||
|
@ -1262,7 +1262,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
message=f"GeminiException - {error_str}",
|
message=f"GeminiException - {error_str}",
|
||||||
model=model,
|
model=model,
|
||||||
llm_provider="palm",
|
llm_provider="palm",
|
||||||
response=original_exception.response,
|
response=getattr(original_exception, "response", None),
|
||||||
)
|
)
|
||||||
# Dailed: Error occurred: 400 Request payload size exceeds the limit: 20000 bytes
|
# Dailed: Error occurred: 400 Request payload size exceeds the limit: 20000 bytes
|
||||||
elif custom_llm_provider == "cloudflare":
|
elif custom_llm_provider == "cloudflare":
|
||||||
|
@ -1272,7 +1272,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
message=f"Cloudflare Exception - {original_exception.message}",
|
message=f"Cloudflare Exception - {original_exception.message}",
|
||||||
llm_provider="cloudflare",
|
llm_provider="cloudflare",
|
||||||
model=model,
|
model=model,
|
||||||
response=original_exception.response,
|
response=getattr(original_exception, "response", None),
|
||||||
)
|
)
|
||||||
if "must have required property" in error_str:
|
if "must have required property" in error_str:
|
||||||
exception_mapping_worked = True
|
exception_mapping_worked = True
|
||||||
|
@ -1280,7 +1280,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
message=f"Cloudflare Exception - {original_exception.message}",
|
message=f"Cloudflare Exception - {original_exception.message}",
|
||||||
llm_provider="cloudflare",
|
llm_provider="cloudflare",
|
||||||
model=model,
|
model=model,
|
||||||
response=original_exception.response,
|
response=getattr(original_exception, "response", None),
|
||||||
)
|
)
|
||||||
elif (
|
elif (
|
||||||
custom_llm_provider == "cohere" or custom_llm_provider == "cohere_chat"
|
custom_llm_provider == "cohere" or custom_llm_provider == "cohere_chat"
|
||||||
|
@ -1294,7 +1294,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
message=f"CohereException - {original_exception.message}",
|
message=f"CohereException - {original_exception.message}",
|
||||||
llm_provider="cohere",
|
llm_provider="cohere",
|
||||||
model=model,
|
model=model,
|
||||||
response=original_exception.response,
|
response=getattr(original_exception, "response", None),
|
||||||
)
|
)
|
||||||
elif "too many tokens" in error_str:
|
elif "too many tokens" in error_str:
|
||||||
exception_mapping_worked = True
|
exception_mapping_worked = True
|
||||||
|
@ -1302,7 +1302,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
message=f"CohereException - {original_exception.message}",
|
message=f"CohereException - {original_exception.message}",
|
||||||
model=model,
|
model=model,
|
||||||
llm_provider="cohere",
|
llm_provider="cohere",
|
||||||
response=original_exception.response,
|
response=getattr(original_exception, "response", None),
|
||||||
)
|
)
|
||||||
elif hasattr(original_exception, "status_code"):
|
elif hasattr(original_exception, "status_code"):
|
||||||
if (
|
if (
|
||||||
|
@ -1314,7 +1314,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
message=f"CohereException - {original_exception.message}",
|
message=f"CohereException - {original_exception.message}",
|
||||||
llm_provider="cohere",
|
llm_provider="cohere",
|
||||||
model=model,
|
model=model,
|
||||||
response=original_exception.response,
|
response=getattr(original_exception, "response", None),
|
||||||
)
|
)
|
||||||
elif original_exception.status_code == 408:
|
elif original_exception.status_code == 408:
|
||||||
exception_mapping_worked = True
|
exception_mapping_worked = True
|
||||||
|
@ -1329,7 +1329,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
message=f"CohereException - {original_exception.message}",
|
message=f"CohereException - {original_exception.message}",
|
||||||
llm_provider="cohere",
|
llm_provider="cohere",
|
||||||
model=model,
|
model=model,
|
||||||
response=original_exception.response,
|
response=getattr(original_exception, "response", None),
|
||||||
)
|
)
|
||||||
elif (
|
elif (
|
||||||
"CohereConnectionError" in exception_type
|
"CohereConnectionError" in exception_type
|
||||||
|
@ -1339,7 +1339,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
message=f"CohereException - {original_exception.message}",
|
message=f"CohereException - {original_exception.message}",
|
||||||
llm_provider="cohere",
|
llm_provider="cohere",
|
||||||
model=model,
|
model=model,
|
||||||
response=original_exception.response,
|
response=getattr(original_exception, "response", None),
|
||||||
)
|
)
|
||||||
elif "invalid type:" in error_str:
|
elif "invalid type:" in error_str:
|
||||||
exception_mapping_worked = True
|
exception_mapping_worked = True
|
||||||
|
@ -1347,7 +1347,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
message=f"CohereException - {original_exception.message}",
|
message=f"CohereException - {original_exception.message}",
|
||||||
llm_provider="cohere",
|
llm_provider="cohere",
|
||||||
model=model,
|
model=model,
|
||||||
response=original_exception.response,
|
response=getattr(original_exception, "response", None),
|
||||||
)
|
)
|
||||||
elif "Unexpected server error" in error_str:
|
elif "Unexpected server error" in error_str:
|
||||||
exception_mapping_worked = True
|
exception_mapping_worked = True
|
||||||
|
@ -1355,7 +1355,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
message=f"CohereException - {original_exception.message}",
|
message=f"CohereException - {original_exception.message}",
|
||||||
llm_provider="cohere",
|
llm_provider="cohere",
|
||||||
model=model,
|
model=model,
|
||||||
response=original_exception.response,
|
response=getattr(original_exception, "response", None),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if hasattr(original_exception, "status_code"):
|
if hasattr(original_exception, "status_code"):
|
||||||
|
@ -1375,7 +1375,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
message=error_str,
|
message=error_str,
|
||||||
model=model,
|
model=model,
|
||||||
llm_provider="huggingface",
|
llm_provider="huggingface",
|
||||||
response=original_exception.response,
|
response=getattr(original_exception, "response", None),
|
||||||
)
|
)
|
||||||
elif "A valid user token is required" in error_str:
|
elif "A valid user token is required" in error_str:
|
||||||
exception_mapping_worked = True
|
exception_mapping_worked = True
|
||||||
|
@ -1383,7 +1383,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
message=error_str,
|
message=error_str,
|
||||||
llm_provider="huggingface",
|
llm_provider="huggingface",
|
||||||
model=model,
|
model=model,
|
||||||
response=original_exception.response,
|
response=getattr(original_exception, "response", None),
|
||||||
)
|
)
|
||||||
elif "Rate limit reached" in error_str:
|
elif "Rate limit reached" in error_str:
|
||||||
exception_mapping_worked = True
|
exception_mapping_worked = True
|
||||||
|
@ -1391,7 +1391,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
message=error_str,
|
message=error_str,
|
||||||
llm_provider="huggingface",
|
llm_provider="huggingface",
|
||||||
model=model,
|
model=model,
|
||||||
response=original_exception.response,
|
response=getattr(original_exception, "response", None),
|
||||||
)
|
)
|
||||||
if hasattr(original_exception, "status_code"):
|
if hasattr(original_exception, "status_code"):
|
||||||
if original_exception.status_code == 401:
|
if original_exception.status_code == 401:
|
||||||
|
@ -1400,7 +1400,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
message=f"HuggingfaceException - {original_exception.message}",
|
message=f"HuggingfaceException - {original_exception.message}",
|
||||||
llm_provider="huggingface",
|
llm_provider="huggingface",
|
||||||
model=model,
|
model=model,
|
||||||
response=original_exception.response,
|
response=getattr(original_exception, "response", None),
|
||||||
)
|
)
|
||||||
elif original_exception.status_code == 400:
|
elif original_exception.status_code == 400:
|
||||||
exception_mapping_worked = True
|
exception_mapping_worked = True
|
||||||
|
@ -1408,7 +1408,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
message=f"HuggingfaceException - {original_exception.message}",
|
message=f"HuggingfaceException - {original_exception.message}",
|
||||||
model=model,
|
model=model,
|
||||||
llm_provider="huggingface",
|
llm_provider="huggingface",
|
||||||
response=original_exception.response,
|
response=getattr(original_exception, "response", None),
|
||||||
)
|
)
|
||||||
elif original_exception.status_code == 408:
|
elif original_exception.status_code == 408:
|
||||||
exception_mapping_worked = True
|
exception_mapping_worked = True
|
||||||
|
@ -1423,7 +1423,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
message=f"HuggingfaceException - {original_exception.message}",
|
message=f"HuggingfaceException - {original_exception.message}",
|
||||||
llm_provider="huggingface",
|
llm_provider="huggingface",
|
||||||
model=model,
|
model=model,
|
||||||
response=original_exception.response,
|
response=getattr(original_exception, "response", None),
|
||||||
)
|
)
|
||||||
elif original_exception.status_code == 503:
|
elif original_exception.status_code == 503:
|
||||||
exception_mapping_worked = True
|
exception_mapping_worked = True
|
||||||
|
@ -1431,7 +1431,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
message=f"HuggingfaceException - {original_exception.message}",
|
message=f"HuggingfaceException - {original_exception.message}",
|
||||||
llm_provider="huggingface",
|
llm_provider="huggingface",
|
||||||
model=model,
|
model=model,
|
||||||
response=original_exception.response,
|
response=getattr(original_exception, "response", None),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
exception_mapping_worked = True
|
exception_mapping_worked = True
|
||||||
|
@ -1450,7 +1450,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
message=f"AI21Exception - {original_exception.message}",
|
message=f"AI21Exception - {original_exception.message}",
|
||||||
model=model,
|
model=model,
|
||||||
llm_provider="ai21",
|
llm_provider="ai21",
|
||||||
response=original_exception.response,
|
response=getattr(original_exception, "response", None),
|
||||||
)
|
)
|
||||||
if "Bad or missing API token." in original_exception.message:
|
if "Bad or missing API token." in original_exception.message:
|
||||||
exception_mapping_worked = True
|
exception_mapping_worked = True
|
||||||
|
@ -1458,7 +1458,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
message=f"AI21Exception - {original_exception.message}",
|
message=f"AI21Exception - {original_exception.message}",
|
||||||
model=model,
|
model=model,
|
||||||
llm_provider="ai21",
|
llm_provider="ai21",
|
||||||
response=original_exception.response,
|
response=getattr(original_exception, "response", None),
|
||||||
)
|
)
|
||||||
if hasattr(original_exception, "status_code"):
|
if hasattr(original_exception, "status_code"):
|
||||||
if original_exception.status_code == 401:
|
if original_exception.status_code == 401:
|
||||||
|
@ -1467,7 +1467,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
message=f"AI21Exception - {original_exception.message}",
|
message=f"AI21Exception - {original_exception.message}",
|
||||||
llm_provider="ai21",
|
llm_provider="ai21",
|
||||||
model=model,
|
model=model,
|
||||||
response=original_exception.response,
|
response=getattr(original_exception, "response", None),
|
||||||
)
|
)
|
||||||
elif original_exception.status_code == 408:
|
elif original_exception.status_code == 408:
|
||||||
exception_mapping_worked = True
|
exception_mapping_worked = True
|
||||||
|
@ -1482,7 +1482,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
message=f"AI21Exception - {original_exception.message}",
|
message=f"AI21Exception - {original_exception.message}",
|
||||||
model=model,
|
model=model,
|
||||||
llm_provider="ai21",
|
llm_provider="ai21",
|
||||||
response=original_exception.response,
|
response=getattr(original_exception, "response", None),
|
||||||
)
|
)
|
||||||
elif original_exception.status_code == 429:
|
elif original_exception.status_code == 429:
|
||||||
exception_mapping_worked = True
|
exception_mapping_worked = True
|
||||||
|
@ -1490,7 +1490,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
message=f"AI21Exception - {original_exception.message}",
|
message=f"AI21Exception - {original_exception.message}",
|
||||||
llm_provider="ai21",
|
llm_provider="ai21",
|
||||||
model=model,
|
model=model,
|
||||||
response=original_exception.response,
|
response=getattr(original_exception, "response", None),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
exception_mapping_worked = True
|
exception_mapping_worked = True
|
||||||
|
@ -1509,7 +1509,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
message=f"NLPCloudException - {error_str}",
|
message=f"NLPCloudException - {error_str}",
|
||||||
model=model,
|
model=model,
|
||||||
llm_provider="nlp_cloud",
|
llm_provider="nlp_cloud",
|
||||||
response=original_exception.response,
|
response=getattr(original_exception, "response", None),
|
||||||
)
|
)
|
||||||
elif "value is not a valid" in error_str:
|
elif "value is not a valid" in error_str:
|
||||||
exception_mapping_worked = True
|
exception_mapping_worked = True
|
||||||
|
@ -1517,7 +1517,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
message=f"NLPCloudException - {error_str}",
|
message=f"NLPCloudException - {error_str}",
|
||||||
model=model,
|
model=model,
|
||||||
llm_provider="nlp_cloud",
|
llm_provider="nlp_cloud",
|
||||||
response=original_exception.response,
|
response=getattr(original_exception, "response", None),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
exception_mapping_worked = True
|
exception_mapping_worked = True
|
||||||
|
@ -1542,7 +1542,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
message=f"NLPCloudException - {original_exception.message}",
|
message=f"NLPCloudException - {original_exception.message}",
|
||||||
llm_provider="nlp_cloud",
|
llm_provider="nlp_cloud",
|
||||||
model=model,
|
model=model,
|
||||||
response=original_exception.response,
|
response=getattr(original_exception, "response", None),
|
||||||
)
|
)
|
||||||
elif (
|
elif (
|
||||||
original_exception.status_code == 401
|
original_exception.status_code == 401
|
||||||
|
@ -1553,7 +1553,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
message=f"NLPCloudException - {original_exception.message}",
|
message=f"NLPCloudException - {original_exception.message}",
|
||||||
llm_provider="nlp_cloud",
|
llm_provider="nlp_cloud",
|
||||||
model=model,
|
model=model,
|
||||||
response=original_exception.response,
|
response=getattr(original_exception, "response", None),
|
||||||
)
|
)
|
||||||
elif (
|
elif (
|
||||||
original_exception.status_code == 522
|
original_exception.status_code == 522
|
||||||
|
@ -1574,7 +1574,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
message=f"NLPCloudException - {original_exception.message}",
|
message=f"NLPCloudException - {original_exception.message}",
|
||||||
llm_provider="nlp_cloud",
|
llm_provider="nlp_cloud",
|
||||||
model=model,
|
model=model,
|
||||||
response=original_exception.response,
|
response=getattr(original_exception, "response", None),
|
||||||
)
|
)
|
||||||
elif (
|
elif (
|
||||||
original_exception.status_code == 500
|
original_exception.status_code == 500
|
||||||
|
@ -1597,7 +1597,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
message=f"NLPCloudException - {original_exception.message}",
|
message=f"NLPCloudException - {original_exception.message}",
|
||||||
model=model,
|
model=model,
|
||||||
llm_provider="nlp_cloud",
|
llm_provider="nlp_cloud",
|
||||||
response=original_exception.response,
|
response=getattr(original_exception, "response", None),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
exception_mapping_worked = True
|
exception_mapping_worked = True
|
||||||
|
@ -1623,7 +1623,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
message=f"TogetherAIException - {error_response['error']}",
|
message=f"TogetherAIException - {error_response['error']}",
|
||||||
model=model,
|
model=model,
|
||||||
llm_provider="together_ai",
|
llm_provider="together_ai",
|
||||||
response=original_exception.response,
|
response=getattr(original_exception, "response", None),
|
||||||
)
|
)
|
||||||
elif (
|
elif (
|
||||||
"error" in error_response
|
"error" in error_response
|
||||||
|
@ -1634,7 +1634,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
message=f"TogetherAIException - {error_response['error']}",
|
message=f"TogetherAIException - {error_response['error']}",
|
||||||
llm_provider="together_ai",
|
llm_provider="together_ai",
|
||||||
model=model,
|
model=model,
|
||||||
response=original_exception.response,
|
response=getattr(original_exception, "response", None),
|
||||||
)
|
)
|
||||||
elif (
|
elif (
|
||||||
"error" in error_response
|
"error" in error_response
|
||||||
|
@ -1645,7 +1645,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
message=f"TogetherAIException - {error_response['error']}",
|
message=f"TogetherAIException - {error_response['error']}",
|
||||||
model=model,
|
model=model,
|
||||||
llm_provider="together_ai",
|
llm_provider="together_ai",
|
||||||
response=original_exception.response,
|
response=getattr(original_exception, "response", None),
|
||||||
)
|
)
|
||||||
elif "A timeout occurred" in error_str:
|
elif "A timeout occurred" in error_str:
|
||||||
exception_mapping_worked = True
|
exception_mapping_worked = True
|
||||||
|
@ -1664,7 +1664,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
message=f"TogetherAIException - {error_response['error']}",
|
message=f"TogetherAIException - {error_response['error']}",
|
||||||
model=model,
|
model=model,
|
||||||
llm_provider="together_ai",
|
llm_provider="together_ai",
|
||||||
response=original_exception.response,
|
response=getattr(original_exception, "response", None),
|
||||||
)
|
)
|
||||||
elif (
|
elif (
|
||||||
"error_type" in error_response
|
"error_type" in error_response
|
||||||
|
@ -1675,7 +1675,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
message=f"TogetherAIException - {error_response['error']}",
|
message=f"TogetherAIException - {error_response['error']}",
|
||||||
model=model,
|
model=model,
|
||||||
llm_provider="together_ai",
|
llm_provider="together_ai",
|
||||||
response=original_exception.response,
|
response=getattr(original_exception, "response", None),
|
||||||
)
|
)
|
||||||
if hasattr(original_exception, "status_code"):
|
if hasattr(original_exception, "status_code"):
|
||||||
if original_exception.status_code == 408:
|
if original_exception.status_code == 408:
|
||||||
|
@ -1691,7 +1691,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
message=f"TogetherAIException - {error_response['error']}",
|
message=f"TogetherAIException - {error_response['error']}",
|
||||||
model=model,
|
model=model,
|
||||||
llm_provider="together_ai",
|
llm_provider="together_ai",
|
||||||
response=original_exception.response,
|
response=getattr(original_exception, "response", None),
|
||||||
)
|
)
|
||||||
elif original_exception.status_code == 429:
|
elif original_exception.status_code == 429:
|
||||||
exception_mapping_worked = True
|
exception_mapping_worked = True
|
||||||
|
@ -1699,7 +1699,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
message=f"TogetherAIException - {original_exception.message}",
|
message=f"TogetherAIException - {original_exception.message}",
|
||||||
llm_provider="together_ai",
|
llm_provider="together_ai",
|
||||||
model=model,
|
model=model,
|
||||||
response=original_exception.response,
|
response=getattr(original_exception, "response", None),
|
||||||
)
|
)
|
||||||
elif original_exception.status_code == 524:
|
elif original_exception.status_code == 524:
|
||||||
exception_mapping_worked = True
|
exception_mapping_worked = True
|
||||||
|
@ -1727,7 +1727,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
message=f"AlephAlphaException - {original_exception.message}",
|
message=f"AlephAlphaException - {original_exception.message}",
|
||||||
llm_provider="aleph_alpha",
|
llm_provider="aleph_alpha",
|
||||||
model=model,
|
model=model,
|
||||||
response=original_exception.response,
|
response=getattr(original_exception, "response", None),
|
||||||
)
|
)
|
||||||
elif "InvalidToken" in error_str or "No token provided" in error_str:
|
elif "InvalidToken" in error_str or "No token provided" in error_str:
|
||||||
exception_mapping_worked = True
|
exception_mapping_worked = True
|
||||||
|
@ -1735,7 +1735,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
message=f"AlephAlphaException - {original_exception.message}",
|
message=f"AlephAlphaException - {original_exception.message}",
|
||||||
llm_provider="aleph_alpha",
|
llm_provider="aleph_alpha",
|
||||||
model=model,
|
model=model,
|
||||||
response=original_exception.response,
|
response=getattr(original_exception, "response", None),
|
||||||
)
|
)
|
||||||
elif hasattr(original_exception, "status_code"):
|
elif hasattr(original_exception, "status_code"):
|
||||||
verbose_logger.debug(
|
verbose_logger.debug(
|
||||||
|
@ -1754,7 +1754,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
message=f"AlephAlphaException - {original_exception.message}",
|
message=f"AlephAlphaException - {original_exception.message}",
|
||||||
llm_provider="aleph_alpha",
|
llm_provider="aleph_alpha",
|
||||||
model=model,
|
model=model,
|
||||||
response=original_exception.response,
|
response=getattr(original_exception, "response", None),
|
||||||
)
|
)
|
||||||
elif original_exception.status_code == 429:
|
elif original_exception.status_code == 429:
|
||||||
exception_mapping_worked = True
|
exception_mapping_worked = True
|
||||||
|
@ -1762,7 +1762,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
message=f"AlephAlphaException - {original_exception.message}",
|
message=f"AlephAlphaException - {original_exception.message}",
|
||||||
llm_provider="aleph_alpha",
|
llm_provider="aleph_alpha",
|
||||||
model=model,
|
model=model,
|
||||||
response=original_exception.response,
|
response=getattr(original_exception, "response", None),
|
||||||
)
|
)
|
||||||
elif original_exception.status_code == 500:
|
elif original_exception.status_code == 500:
|
||||||
exception_mapping_worked = True
|
exception_mapping_worked = True
|
||||||
|
@ -1770,7 +1770,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
message=f"AlephAlphaException - {original_exception.message}",
|
message=f"AlephAlphaException - {original_exception.message}",
|
||||||
llm_provider="aleph_alpha",
|
llm_provider="aleph_alpha",
|
||||||
model=model,
|
model=model,
|
||||||
response=original_exception.response,
|
response=getattr(original_exception, "response", None),
|
||||||
)
|
)
|
||||||
raise original_exception
|
raise original_exception
|
||||||
raise original_exception
|
raise original_exception
|
||||||
|
@ -1787,7 +1787,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
message=f"OllamaException: Invalid Model/Model not loaded - {original_exception}",
|
message=f"OllamaException: Invalid Model/Model not loaded - {original_exception}",
|
||||||
model=model,
|
model=model,
|
||||||
llm_provider="ollama",
|
llm_provider="ollama",
|
||||||
response=original_exception.response,
|
response=getattr(original_exception, "response", None),
|
||||||
)
|
)
|
||||||
elif "Failed to establish a new connection" in error_str:
|
elif "Failed to establish a new connection" in error_str:
|
||||||
exception_mapping_worked = True
|
exception_mapping_worked = True
|
||||||
|
@ -1795,7 +1795,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
message=f"OllamaException: {original_exception}",
|
message=f"OllamaException: {original_exception}",
|
||||||
llm_provider="ollama",
|
llm_provider="ollama",
|
||||||
model=model,
|
model=model,
|
||||||
response=original_exception.response,
|
response=getattr(original_exception, "response", None),
|
||||||
)
|
)
|
||||||
elif "Invalid response object from API" in error_str:
|
elif "Invalid response object from API" in error_str:
|
||||||
exception_mapping_worked = True
|
exception_mapping_worked = True
|
||||||
|
@ -1803,7 +1803,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
message=f"OllamaException: {original_exception}",
|
message=f"OllamaException: {original_exception}",
|
||||||
llm_provider="ollama",
|
llm_provider="ollama",
|
||||||
model=model,
|
model=model,
|
||||||
response=original_exception.response,
|
response=getattr(original_exception, "response", None),
|
||||||
)
|
)
|
||||||
elif "Read timed out" in error_str:
|
elif "Read timed out" in error_str:
|
||||||
exception_mapping_worked = True
|
exception_mapping_worked = True
|
||||||
|
@ -1837,6 +1837,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
llm_provider="azure",
|
llm_provider="azure",
|
||||||
model=model,
|
model=model,
|
||||||
litellm_debug_info=extra_information,
|
litellm_debug_info=extra_information,
|
||||||
|
response=getattr(original_exception, "response", None),
|
||||||
)
|
)
|
||||||
elif "This model's maximum context length is" in error_str:
|
elif "This model's maximum context length is" in error_str:
|
||||||
exception_mapping_worked = True
|
exception_mapping_worked = True
|
||||||
|
@ -1845,6 +1846,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
llm_provider="azure",
|
llm_provider="azure",
|
||||||
model=model,
|
model=model,
|
||||||
litellm_debug_info=extra_information,
|
litellm_debug_info=extra_information,
|
||||||
|
response=getattr(original_exception, "response", None),
|
||||||
)
|
)
|
||||||
elif "DeploymentNotFound" in error_str:
|
elif "DeploymentNotFound" in error_str:
|
||||||
exception_mapping_worked = True
|
exception_mapping_worked = True
|
||||||
|
@ -1853,6 +1855,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
llm_provider="azure",
|
llm_provider="azure",
|
||||||
model=model,
|
model=model,
|
||||||
litellm_debug_info=extra_information,
|
litellm_debug_info=extra_information,
|
||||||
|
response=getattr(original_exception, "response", None),
|
||||||
)
|
)
|
||||||
elif (
|
elif (
|
||||||
(
|
(
|
||||||
|
@ -1873,6 +1876,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
llm_provider="azure",
|
llm_provider="azure",
|
||||||
model=model,
|
model=model,
|
||||||
litellm_debug_info=extra_information,
|
litellm_debug_info=extra_information,
|
||||||
|
response=getattr(original_exception, "response", None),
|
||||||
)
|
)
|
||||||
elif "invalid_request_error" in error_str:
|
elif "invalid_request_error" in error_str:
|
||||||
exception_mapping_worked = True
|
exception_mapping_worked = True
|
||||||
|
@ -1881,6 +1885,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
llm_provider="azure",
|
llm_provider="azure",
|
||||||
model=model,
|
model=model,
|
||||||
litellm_debug_info=extra_information,
|
litellm_debug_info=extra_information,
|
||||||
|
response=getattr(original_exception, "response", None),
|
||||||
)
|
)
|
||||||
elif (
|
elif (
|
||||||
"The api_key client option must be set either by passing api_key to the client or by setting"
|
"The api_key client option must be set either by passing api_key to the client or by setting"
|
||||||
|
@ -1892,6 +1897,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
llm_provider=custom_llm_provider,
|
llm_provider=custom_llm_provider,
|
||||||
model=model,
|
model=model,
|
||||||
litellm_debug_info=extra_information,
|
litellm_debug_info=extra_information,
|
||||||
|
response=getattr(original_exception, "response", None),
|
||||||
)
|
)
|
||||||
elif "Connection error" in error_str:
|
elif "Connection error" in error_str:
|
||||||
exception_mapping_worked = True
|
exception_mapping_worked = True
|
||||||
|
@ -1910,6 +1916,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
llm_provider="azure",
|
llm_provider="azure",
|
||||||
model=model,
|
model=model,
|
||||||
litellm_debug_info=extra_information,
|
litellm_debug_info=extra_information,
|
||||||
|
response=getattr(original_exception, "response", None),
|
||||||
)
|
)
|
||||||
elif original_exception.status_code == 401:
|
elif original_exception.status_code == 401:
|
||||||
exception_mapping_worked = True
|
exception_mapping_worked = True
|
||||||
|
@ -1918,6 +1925,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
llm_provider="azure",
|
llm_provider="azure",
|
||||||
model=model,
|
model=model,
|
||||||
litellm_debug_info=extra_information,
|
litellm_debug_info=extra_information,
|
||||||
|
response=getattr(original_exception, "response", None),
|
||||||
)
|
)
|
||||||
elif original_exception.status_code == 408:
|
elif original_exception.status_code == 408:
|
||||||
exception_mapping_worked = True
|
exception_mapping_worked = True
|
||||||
|
@ -1934,6 +1942,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
model=model,
|
model=model,
|
||||||
llm_provider="azure",
|
llm_provider="azure",
|
||||||
litellm_debug_info=extra_information,
|
litellm_debug_info=extra_information,
|
||||||
|
response=getattr(original_exception, "response", None),
|
||||||
)
|
)
|
||||||
elif original_exception.status_code == 429:
|
elif original_exception.status_code == 429:
|
||||||
exception_mapping_worked = True
|
exception_mapping_worked = True
|
||||||
|
@ -1942,6 +1951,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
model=model,
|
model=model,
|
||||||
llm_provider="azure",
|
llm_provider="azure",
|
||||||
litellm_debug_info=extra_information,
|
litellm_debug_info=extra_information,
|
||||||
|
response=getattr(original_exception, "response", None),
|
||||||
)
|
)
|
||||||
elif original_exception.status_code == 503:
|
elif original_exception.status_code == 503:
|
||||||
exception_mapping_worked = True
|
exception_mapping_worked = True
|
||||||
|
@ -1950,6 +1960,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
model=model,
|
model=model,
|
||||||
llm_provider="azure",
|
llm_provider="azure",
|
||||||
litellm_debug_info=extra_information,
|
litellm_debug_info=extra_information,
|
||||||
|
response=getattr(original_exception, "response", None),
|
||||||
)
|
)
|
||||||
elif original_exception.status_code == 504: # gateway timeout error
|
elif original_exception.status_code == 504: # gateway timeout error
|
||||||
exception_mapping_worked = True
|
exception_mapping_worked = True
|
||||||
|
@ -1989,7 +2000,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
message=f"{exception_provider} - {error_str}",
|
message=f"{exception_provider} - {error_str}",
|
||||||
llm_provider=custom_llm_provider,
|
llm_provider=custom_llm_provider,
|
||||||
model=model,
|
model=model,
|
||||||
response=original_exception.response,
|
response=getattr(original_exception, "response", None),
|
||||||
litellm_debug_info=extra_information,
|
litellm_debug_info=extra_information,
|
||||||
)
|
)
|
||||||
elif original_exception.status_code == 401:
|
elif original_exception.status_code == 401:
|
||||||
|
@ -1998,7 +2009,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
message=f"AuthenticationError: {exception_provider} - {error_str}",
|
message=f"AuthenticationError: {exception_provider} - {error_str}",
|
||||||
llm_provider=custom_llm_provider,
|
llm_provider=custom_llm_provider,
|
||||||
model=model,
|
model=model,
|
||||||
response=original_exception.response,
|
response=getattr(original_exception, "response", None),
|
||||||
litellm_debug_info=extra_information,
|
litellm_debug_info=extra_information,
|
||||||
)
|
)
|
||||||
elif original_exception.status_code == 404:
|
elif original_exception.status_code == 404:
|
||||||
|
@ -2007,7 +2018,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
message=f"NotFoundError: {exception_provider} - {error_str}",
|
message=f"NotFoundError: {exception_provider} - {error_str}",
|
||||||
model=model,
|
model=model,
|
||||||
llm_provider=custom_llm_provider,
|
llm_provider=custom_llm_provider,
|
||||||
response=original_exception.response,
|
response=getattr(original_exception, "response", None),
|
||||||
litellm_debug_info=extra_information,
|
litellm_debug_info=extra_information,
|
||||||
)
|
)
|
||||||
elif original_exception.status_code == 408:
|
elif original_exception.status_code == 408:
|
||||||
|
@ -2024,7 +2035,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
message=f"BadRequestError: {exception_provider} - {error_str}",
|
message=f"BadRequestError: {exception_provider} - {error_str}",
|
||||||
model=model,
|
model=model,
|
||||||
llm_provider=custom_llm_provider,
|
llm_provider=custom_llm_provider,
|
||||||
response=original_exception.response,
|
response=getattr(original_exception, "response", None),
|
||||||
litellm_debug_info=extra_information,
|
litellm_debug_info=extra_information,
|
||||||
)
|
)
|
||||||
elif original_exception.status_code == 429:
|
elif original_exception.status_code == 429:
|
||||||
|
@ -2033,7 +2044,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
message=f"RateLimitError: {exception_provider} - {error_str}",
|
message=f"RateLimitError: {exception_provider} - {error_str}",
|
||||||
model=model,
|
model=model,
|
||||||
llm_provider=custom_llm_provider,
|
llm_provider=custom_llm_provider,
|
||||||
response=original_exception.response,
|
response=getattr(original_exception, "response", None),
|
||||||
litellm_debug_info=extra_information,
|
litellm_debug_info=extra_information,
|
||||||
)
|
)
|
||||||
elif original_exception.status_code == 503:
|
elif original_exception.status_code == 503:
|
||||||
|
@ -2042,7 +2053,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
||||||
message=f"ServiceUnavailableError: {exception_provider} - {error_str}",
|
message=f"ServiceUnavailableError: {exception_provider} - {error_str}",
|
||||||
model=model,
|
model=model,
|
||||||
llm_provider=custom_llm_provider,
|
llm_provider=custom_llm_provider,
|
||||||
response=original_exception.response,
|
response=getattr(original_exception, "response", None),
|
||||||
litellm_debug_info=extra_information,
|
litellm_debug_info=extra_information,
|
||||||
)
|
)
|
||||||
elif original_exception.status_code == 504: # gateway timeout error
|
elif original_exception.status_code == 504: # gateway timeout error
|
||||||
|
|
|
@ -202,6 +202,7 @@ class Logging:
|
||||||
start_time,
|
start_time,
|
||||||
litellm_call_id: str,
|
litellm_call_id: str,
|
||||||
function_id: str,
|
function_id: str,
|
||||||
|
litellm_trace_id: Optional[str] = None,
|
||||||
dynamic_input_callbacks: Optional[
|
dynamic_input_callbacks: Optional[
|
||||||
List[Union[str, Callable, CustomLogger]]
|
List[Union[str, Callable, CustomLogger]]
|
||||||
] = None,
|
] = None,
|
||||||
|
@ -239,6 +240,7 @@ class Logging:
|
||||||
self.start_time = start_time # log the call start time
|
self.start_time = start_time # log the call start time
|
||||||
self.call_type = call_type
|
self.call_type = call_type
|
||||||
self.litellm_call_id = litellm_call_id
|
self.litellm_call_id = litellm_call_id
|
||||||
|
self.litellm_trace_id = litellm_trace_id
|
||||||
self.function_id = function_id
|
self.function_id = function_id
|
||||||
self.streaming_chunks: List[Any] = [] # for generating complete stream response
|
self.streaming_chunks: List[Any] = [] # for generating complete stream response
|
||||||
self.sync_streaming_chunks: List[Any] = (
|
self.sync_streaming_chunks: List[Any] = (
|
||||||
|
@ -275,6 +277,11 @@ class Logging:
|
||||||
self.completion_start_time: Optional[datetime.datetime] = None
|
self.completion_start_time: Optional[datetime.datetime] = None
|
||||||
self._llm_caching_handler: Optional[LLMCachingHandler] = None
|
self._llm_caching_handler: Optional[LLMCachingHandler] = None
|
||||||
|
|
||||||
|
self.model_call_details = {
|
||||||
|
"litellm_trace_id": litellm_trace_id,
|
||||||
|
"litellm_call_id": litellm_call_id,
|
||||||
|
}
|
||||||
|
|
||||||
def process_dynamic_callbacks(self):
|
def process_dynamic_callbacks(self):
|
||||||
"""
|
"""
|
||||||
Initializes CustomLogger compatible callbacks in self.dynamic_* callbacks
|
Initializes CustomLogger compatible callbacks in self.dynamic_* callbacks
|
||||||
|
@ -382,7 +389,8 @@ class Logging:
|
||||||
self.logger_fn = litellm_params.get("logger_fn", None)
|
self.logger_fn = litellm_params.get("logger_fn", None)
|
||||||
verbose_logger.debug(f"self.optional_params: {self.optional_params}")
|
verbose_logger.debug(f"self.optional_params: {self.optional_params}")
|
||||||
|
|
||||||
self.model_call_details = {
|
self.model_call_details.update(
|
||||||
|
{
|
||||||
"model": self.model,
|
"model": self.model,
|
||||||
"messages": self.messages,
|
"messages": self.messages,
|
||||||
"optional_params": self.optional_params,
|
"optional_params": self.optional_params,
|
||||||
|
@ -397,6 +405,7 @@ class Logging:
|
||||||
**self.optional_params,
|
**self.optional_params,
|
||||||
**additional_params,
|
**additional_params,
|
||||||
}
|
}
|
||||||
|
)
|
||||||
|
|
||||||
## check if stream options is set ## - used by CustomStreamWrapper for easy instrumentation
|
## check if stream options is set ## - used by CustomStreamWrapper for easy instrumentation
|
||||||
if "stream_options" in additional_params:
|
if "stream_options" in additional_params:
|
||||||
|
@ -2823,6 +2832,7 @@ def get_standard_logging_object_payload(
|
||||||
|
|
||||||
payload: StandardLoggingPayload = StandardLoggingPayload(
|
payload: StandardLoggingPayload = StandardLoggingPayload(
|
||||||
id=str(id),
|
id=str(id),
|
||||||
|
trace_id=kwargs.get("litellm_trace_id"), # type: ignore
|
||||||
call_type=call_type or "",
|
call_type=call_type or "",
|
||||||
cache_hit=cache_hit,
|
cache_hit=cache_hit,
|
||||||
status=status,
|
status=status,
|
||||||
|
|
|
@ -44,7 +44,9 @@ from litellm.types.llms.openai import (
|
||||||
ChatCompletionToolCallFunctionChunk,
|
ChatCompletionToolCallFunctionChunk,
|
||||||
ChatCompletionUsageBlock,
|
ChatCompletionUsageBlock,
|
||||||
)
|
)
|
||||||
from litellm.types.utils import GenericStreamingChunk, PromptTokensDetailsWrapper
|
from litellm.types.utils import GenericStreamingChunk
|
||||||
|
from litellm.types.utils import Message as LitellmMessage
|
||||||
|
from litellm.types.utils import PromptTokensDetailsWrapper
|
||||||
from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
|
from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
|
||||||
|
|
||||||
from ...base import BaseLLM
|
from ...base import BaseLLM
|
||||||
|
@ -94,6 +96,7 @@ async def make_call(
|
||||||
messages: list,
|
messages: list,
|
||||||
logging_obj,
|
logging_obj,
|
||||||
timeout: Optional[Union[float, httpx.Timeout]],
|
timeout: Optional[Union[float, httpx.Timeout]],
|
||||||
|
json_mode: bool,
|
||||||
) -> Tuple[Any, httpx.Headers]:
|
) -> Tuple[Any, httpx.Headers]:
|
||||||
if client is None:
|
if client is None:
|
||||||
client = litellm.module_level_aclient
|
client = litellm.module_level_aclient
|
||||||
|
@ -119,7 +122,9 @@ async def make_call(
|
||||||
raise AnthropicError(status_code=500, message=str(e))
|
raise AnthropicError(status_code=500, message=str(e))
|
||||||
|
|
||||||
completion_stream = ModelResponseIterator(
|
completion_stream = ModelResponseIterator(
|
||||||
streaming_response=response.aiter_lines(), sync_stream=False
|
streaming_response=response.aiter_lines(),
|
||||||
|
sync_stream=False,
|
||||||
|
json_mode=json_mode,
|
||||||
)
|
)
|
||||||
|
|
||||||
# LOGGING
|
# LOGGING
|
||||||
|
@ -142,6 +147,7 @@ def make_sync_call(
|
||||||
messages: list,
|
messages: list,
|
||||||
logging_obj,
|
logging_obj,
|
||||||
timeout: Optional[Union[float, httpx.Timeout]],
|
timeout: Optional[Union[float, httpx.Timeout]],
|
||||||
|
json_mode: bool,
|
||||||
) -> Tuple[Any, httpx.Headers]:
|
) -> Tuple[Any, httpx.Headers]:
|
||||||
if client is None:
|
if client is None:
|
||||||
client = litellm.module_level_client # re-use a module level client
|
client = litellm.module_level_client # re-use a module level client
|
||||||
|
@ -175,7 +181,7 @@ def make_sync_call(
|
||||||
)
|
)
|
||||||
|
|
||||||
completion_stream = ModelResponseIterator(
|
completion_stream = ModelResponseIterator(
|
||||||
streaming_response=response.iter_lines(), sync_stream=True
|
streaming_response=response.iter_lines(), sync_stream=True, json_mode=json_mode
|
||||||
)
|
)
|
||||||
|
|
||||||
# LOGGING
|
# LOGGING
|
||||||
|
@ -270,11 +276,12 @@ class AnthropicChatCompletion(BaseLLM):
|
||||||
"arguments"
|
"arguments"
|
||||||
)
|
)
|
||||||
if json_mode_content_str is not None:
|
if json_mode_content_str is not None:
|
||||||
args = json.loads(json_mode_content_str)
|
_converted_message = self._convert_tool_response_to_message(
|
||||||
values: Optional[dict] = args.get("values")
|
tool_calls=tool_calls,
|
||||||
if values is not None:
|
)
|
||||||
_message = litellm.Message(content=json.dumps(values))
|
if _converted_message is not None:
|
||||||
completion_response["stop_reason"] = "stop"
|
completion_response["stop_reason"] = "stop"
|
||||||
|
_message = _converted_message
|
||||||
model_response.choices[0].message = _message # type: ignore
|
model_response.choices[0].message = _message # type: ignore
|
||||||
model_response._hidden_params["original_response"] = completion_response[
|
model_response._hidden_params["original_response"] = completion_response[
|
||||||
"content"
|
"content"
|
||||||
|
@ -318,6 +325,37 @@ class AnthropicChatCompletion(BaseLLM):
|
||||||
model_response._hidden_params = _hidden_params
|
model_response._hidden_params = _hidden_params
|
||||||
return model_response
|
return model_response
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _convert_tool_response_to_message(
|
||||||
|
tool_calls: List[ChatCompletionToolCallChunk],
|
||||||
|
) -> Optional[LitellmMessage]:
|
||||||
|
"""
|
||||||
|
In JSON mode, Anthropic API returns JSON schema as a tool call, we need to convert it to a message to follow the OpenAI format
|
||||||
|
|
||||||
|
"""
|
||||||
|
## HANDLE JSON MODE - anthropic returns single function call
|
||||||
|
json_mode_content_str: Optional[str] = tool_calls[0]["function"].get(
|
||||||
|
"arguments"
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
if json_mode_content_str is not None:
|
||||||
|
args = json.loads(json_mode_content_str)
|
||||||
|
if (
|
||||||
|
isinstance(args, dict)
|
||||||
|
and (values := args.get("values")) is not None
|
||||||
|
):
|
||||||
|
_message = litellm.Message(content=json.dumps(values))
|
||||||
|
return _message
|
||||||
|
else:
|
||||||
|
# a lot of the times the `values` key is not present in the tool response
|
||||||
|
# relevant issue: https://github.com/BerriAI/litellm/issues/6741
|
||||||
|
_message = litellm.Message(content=json.dumps(args))
|
||||||
|
return _message
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
# json decode error does occur, return the original tool response str
|
||||||
|
return litellm.Message(content=json_mode_content_str)
|
||||||
|
return None
|
||||||
|
|
||||||
async def acompletion_stream_function(
|
async def acompletion_stream_function(
|
||||||
self,
|
self,
|
||||||
model: str,
|
model: str,
|
||||||
|
@ -334,6 +372,7 @@ class AnthropicChatCompletion(BaseLLM):
|
||||||
stream,
|
stream,
|
||||||
_is_function_call,
|
_is_function_call,
|
||||||
data: dict,
|
data: dict,
|
||||||
|
json_mode: bool,
|
||||||
optional_params=None,
|
optional_params=None,
|
||||||
litellm_params=None,
|
litellm_params=None,
|
||||||
logger_fn=None,
|
logger_fn=None,
|
||||||
|
@ -350,6 +389,7 @@ class AnthropicChatCompletion(BaseLLM):
|
||||||
messages=messages,
|
messages=messages,
|
||||||
logging_obj=logging_obj,
|
logging_obj=logging_obj,
|
||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
|
json_mode=json_mode,
|
||||||
)
|
)
|
||||||
streamwrapper = CustomStreamWrapper(
|
streamwrapper = CustomStreamWrapper(
|
||||||
completion_stream=completion_stream,
|
completion_stream=completion_stream,
|
||||||
|
@ -440,8 +480,8 @@ class AnthropicChatCompletion(BaseLLM):
|
||||||
logging_obj,
|
logging_obj,
|
||||||
optional_params: dict,
|
optional_params: dict,
|
||||||
timeout: Union[float, httpx.Timeout],
|
timeout: Union[float, httpx.Timeout],
|
||||||
|
litellm_params: dict,
|
||||||
acompletion=None,
|
acompletion=None,
|
||||||
litellm_params=None,
|
|
||||||
logger_fn=None,
|
logger_fn=None,
|
||||||
headers={},
|
headers={},
|
||||||
client=None,
|
client=None,
|
||||||
|
@ -464,6 +504,7 @@ class AnthropicChatCompletion(BaseLLM):
|
||||||
model=model,
|
model=model,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
optional_params=optional_params,
|
optional_params=optional_params,
|
||||||
|
litellm_params=litellm_params,
|
||||||
headers=headers,
|
headers=headers,
|
||||||
_is_function_call=_is_function_call,
|
_is_function_call=_is_function_call,
|
||||||
is_vertex_request=is_vertex_request,
|
is_vertex_request=is_vertex_request,
|
||||||
|
@ -500,6 +541,7 @@ class AnthropicChatCompletion(BaseLLM):
|
||||||
optional_params=optional_params,
|
optional_params=optional_params,
|
||||||
stream=stream,
|
stream=stream,
|
||||||
_is_function_call=_is_function_call,
|
_is_function_call=_is_function_call,
|
||||||
|
json_mode=json_mode,
|
||||||
litellm_params=litellm_params,
|
litellm_params=litellm_params,
|
||||||
logger_fn=logger_fn,
|
logger_fn=logger_fn,
|
||||||
headers=headers,
|
headers=headers,
|
||||||
|
@ -547,6 +589,7 @@ class AnthropicChatCompletion(BaseLLM):
|
||||||
messages=messages,
|
messages=messages,
|
||||||
logging_obj=logging_obj,
|
logging_obj=logging_obj,
|
||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
|
json_mode=json_mode,
|
||||||
)
|
)
|
||||||
return CustomStreamWrapper(
|
return CustomStreamWrapper(
|
||||||
completion_stream=completion_stream,
|
completion_stream=completion_stream,
|
||||||
|
@ -605,11 +648,14 @@ class AnthropicChatCompletion(BaseLLM):
|
||||||
|
|
||||||
|
|
||||||
class ModelResponseIterator:
|
class ModelResponseIterator:
|
||||||
def __init__(self, streaming_response, sync_stream: bool):
|
def __init__(
|
||||||
|
self, streaming_response, sync_stream: bool, json_mode: Optional[bool] = False
|
||||||
|
):
|
||||||
self.streaming_response = streaming_response
|
self.streaming_response = streaming_response
|
||||||
self.response_iterator = self.streaming_response
|
self.response_iterator = self.streaming_response
|
||||||
self.content_blocks: List[ContentBlockDelta] = []
|
self.content_blocks: List[ContentBlockDelta] = []
|
||||||
self.tool_index = -1
|
self.tool_index = -1
|
||||||
|
self.json_mode = json_mode
|
||||||
|
|
||||||
def check_empty_tool_call_args(self) -> bool:
|
def check_empty_tool_call_args(self) -> bool:
|
||||||
"""
|
"""
|
||||||
|
@ -771,6 +817,8 @@ class ModelResponseIterator:
|
||||||
status_code=500, # it looks like Anthropic API does not return a status code in the chunk error - default to 500
|
status_code=500, # it looks like Anthropic API does not return a status code in the chunk error - default to 500
|
||||||
)
|
)
|
||||||
|
|
||||||
|
text, tool_use = self._handle_json_mode_chunk(text=text, tool_use=tool_use)
|
||||||
|
|
||||||
returned_chunk = GenericStreamingChunk(
|
returned_chunk = GenericStreamingChunk(
|
||||||
text=text,
|
text=text,
|
||||||
tool_use=tool_use,
|
tool_use=tool_use,
|
||||||
|
@ -785,6 +833,34 @@ class ModelResponseIterator:
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
raise ValueError(f"Failed to decode JSON from chunk: {chunk}")
|
raise ValueError(f"Failed to decode JSON from chunk: {chunk}")
|
||||||
|
|
||||||
|
def _handle_json_mode_chunk(
|
||||||
|
self, text: str, tool_use: Optional[ChatCompletionToolCallChunk]
|
||||||
|
) -> Tuple[str, Optional[ChatCompletionToolCallChunk]]:
|
||||||
|
"""
|
||||||
|
If JSON mode is enabled, convert the tool call to a message.
|
||||||
|
|
||||||
|
Anthropic returns the JSON schema as part of the tool call
|
||||||
|
OpenAI returns the JSON schema as part of the content, this handles placing it in the content
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: str
|
||||||
|
tool_use: Optional[ChatCompletionToolCallChunk]
|
||||||
|
Returns:
|
||||||
|
Tuple[str, Optional[ChatCompletionToolCallChunk]]
|
||||||
|
|
||||||
|
text: The text to use in the content
|
||||||
|
tool_use: The ChatCompletionToolCallChunk to use in the chunk response
|
||||||
|
"""
|
||||||
|
if self.json_mode is True and tool_use is not None:
|
||||||
|
message = AnthropicChatCompletion._convert_tool_response_to_message(
|
||||||
|
tool_calls=[tool_use]
|
||||||
|
)
|
||||||
|
if message is not None:
|
||||||
|
text = message.content or ""
|
||||||
|
tool_use = None
|
||||||
|
|
||||||
|
return text, tool_use
|
||||||
|
|
||||||
# Sync iterator
|
# Sync iterator
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
return self
|
return self
|
||||||
|
|
|
@ -91,6 +91,7 @@ class AnthropicConfig:
|
||||||
"extra_headers",
|
"extra_headers",
|
||||||
"parallel_tool_calls",
|
"parallel_tool_calls",
|
||||||
"response_format",
|
"response_format",
|
||||||
|
"user",
|
||||||
]
|
]
|
||||||
|
|
||||||
def get_cache_control_headers(self) -> dict:
|
def get_cache_control_headers(self) -> dict:
|
||||||
|
@ -246,6 +247,28 @@ class AnthropicConfig:
|
||||||
anthropic_tools.append(new_tool)
|
anthropic_tools.append(new_tool)
|
||||||
return anthropic_tools
|
return anthropic_tools
|
||||||
|
|
||||||
|
def _map_stop_sequences(
|
||||||
|
self, stop: Optional[Union[str, List[str]]]
|
||||||
|
) -> Optional[List[str]]:
|
||||||
|
new_stop: Optional[List[str]] = None
|
||||||
|
if isinstance(stop, str):
|
||||||
|
if (
|
||||||
|
stop == "\n"
|
||||||
|
) and litellm.drop_params is True: # anthropic doesn't allow whitespace characters as stop-sequences
|
||||||
|
return new_stop
|
||||||
|
new_stop = [stop]
|
||||||
|
elif isinstance(stop, list):
|
||||||
|
new_v = []
|
||||||
|
for v in stop:
|
||||||
|
if (
|
||||||
|
v == "\n"
|
||||||
|
) and litellm.drop_params is True: # anthropic doesn't allow whitespace characters as stop-sequences
|
||||||
|
continue
|
||||||
|
new_v.append(v)
|
||||||
|
if len(new_v) > 0:
|
||||||
|
new_stop = new_v
|
||||||
|
return new_stop
|
||||||
|
|
||||||
def map_openai_params(
|
def map_openai_params(
|
||||||
self,
|
self,
|
||||||
non_default_params: dict,
|
non_default_params: dict,
|
||||||
|
@ -271,26 +294,10 @@ class AnthropicConfig:
|
||||||
optional_params["tool_choice"] = _tool_choice
|
optional_params["tool_choice"] = _tool_choice
|
||||||
if param == "stream" and value is True:
|
if param == "stream" and value is True:
|
||||||
optional_params["stream"] = value
|
optional_params["stream"] = value
|
||||||
if param == "stop":
|
if param == "stop" and (isinstance(value, str) or isinstance(value, list)):
|
||||||
if isinstance(value, str):
|
_value = self._map_stop_sequences(value)
|
||||||
if (
|
if _value is not None:
|
||||||
value == "\n"
|
optional_params["stop_sequences"] = _value
|
||||||
) and litellm.drop_params is True: # anthropic doesn't allow whitespace characters as stop-sequences
|
|
||||||
continue
|
|
||||||
value = [value]
|
|
||||||
elif isinstance(value, list):
|
|
||||||
new_v = []
|
|
||||||
for v in value:
|
|
||||||
if (
|
|
||||||
v == "\n"
|
|
||||||
) and litellm.drop_params is True: # anthropic doesn't allow whitespace characters as stop-sequences
|
|
||||||
continue
|
|
||||||
new_v.append(v)
|
|
||||||
if len(new_v) > 0:
|
|
||||||
value = new_v
|
|
||||||
else:
|
|
||||||
continue
|
|
||||||
optional_params["stop_sequences"] = value
|
|
||||||
if param == "temperature":
|
if param == "temperature":
|
||||||
optional_params["temperature"] = value
|
optional_params["temperature"] = value
|
||||||
if param == "top_p":
|
if param == "top_p":
|
||||||
|
@ -314,7 +321,8 @@ class AnthropicConfig:
|
||||||
optional_params["tools"] = [_tool]
|
optional_params["tools"] = [_tool]
|
||||||
optional_params["tool_choice"] = _tool_choice
|
optional_params["tool_choice"] = _tool_choice
|
||||||
optional_params["json_mode"] = True
|
optional_params["json_mode"] = True
|
||||||
|
if param == "user":
|
||||||
|
optional_params["metadata"] = {"user_id": value}
|
||||||
## VALIDATE REQUEST
|
## VALIDATE REQUEST
|
||||||
"""
|
"""
|
||||||
Anthropic doesn't support tool calling without `tools=` param specified.
|
Anthropic doesn't support tool calling without `tools=` param specified.
|
||||||
|
@ -465,6 +473,7 @@ class AnthropicConfig:
|
||||||
model: str,
|
model: str,
|
||||||
messages: List[AllMessageValues],
|
messages: List[AllMessageValues],
|
||||||
optional_params: dict,
|
optional_params: dict,
|
||||||
|
litellm_params: dict,
|
||||||
headers: dict,
|
headers: dict,
|
||||||
_is_function_call: bool,
|
_is_function_call: bool,
|
||||||
is_vertex_request: bool,
|
is_vertex_request: bool,
|
||||||
|
@ -502,6 +511,15 @@ class AnthropicConfig:
|
||||||
if "tools" in optional_params:
|
if "tools" in optional_params:
|
||||||
_is_function_call = True
|
_is_function_call = True
|
||||||
|
|
||||||
|
## Handle user_id in metadata
|
||||||
|
_litellm_metadata = litellm_params.get("metadata", None)
|
||||||
|
if (
|
||||||
|
_litellm_metadata
|
||||||
|
and isinstance(_litellm_metadata, dict)
|
||||||
|
and "user_id" in _litellm_metadata
|
||||||
|
):
|
||||||
|
optional_params["metadata"] = {"user_id": _litellm_metadata["user_id"]}
|
||||||
|
|
||||||
data = {
|
data = {
|
||||||
"messages": anthropic_messages,
|
"messages": anthropic_messages,
|
||||||
**optional_params,
|
**optional_params,
|
||||||
|
|
|
@ -53,8 +53,14 @@ class AmazonStability3Config:
|
||||||
sd3-medium
|
sd3-medium
|
||||||
sd3.5-large
|
sd3.5-large
|
||||||
sd3.5-large-turbo
|
sd3.5-large-turbo
|
||||||
|
|
||||||
|
Stability ultra models
|
||||||
|
stable-image-ultra-v1
|
||||||
"""
|
"""
|
||||||
if model and ("sd3" in model or "sd3.5" in model):
|
if model:
|
||||||
|
if "sd3" in model or "sd3.5" in model:
|
||||||
|
return True
|
||||||
|
if "stable-image-ultra-v1" in model:
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
|
@ -8,3 +8,4 @@ class httpxSpecialProvider(str, Enum):
|
||||||
GuardrailCallback = "guardrail_callback"
|
GuardrailCallback = "guardrail_callback"
|
||||||
Caching = "caching"
|
Caching = "caching"
|
||||||
Oauth2Check = "oauth2_check"
|
Oauth2Check = "oauth2_check"
|
||||||
|
SecretManager = "secret_manager"
|
||||||
|
|
|
@ -76,4 +76,4 @@ class JinaAIEmbeddingConfig:
|
||||||
or get_secret_str("JINA_AI_API_KEY")
|
or get_secret_str("JINA_AI_API_KEY")
|
||||||
or get_secret_str("JINA_AI_TOKEN")
|
or get_secret_str("JINA_AI_TOKEN")
|
||||||
)
|
)
|
||||||
return LlmProviders.OPENAI_LIKE.value, api_base, dynamic_api_key
|
return LlmProviders.JINA_AI.value, api_base, dynamic_api_key
|
||||||
|
|
96
litellm/llms/jina_ai/rerank/handler.py
Normal file
96
litellm/llms/jina_ai/rerank/handler.py
Normal file
|
@ -0,0 +1,96 @@
|
||||||
|
"""
|
||||||
|
Re rank api
|
||||||
|
|
||||||
|
LiteLLM supports the re rank API format, no paramter transformation occurs
|
||||||
|
"""
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
import litellm
|
||||||
|
from litellm.llms.base import BaseLLM
|
||||||
|
from litellm.llms.custom_httpx.http_handler import (
|
||||||
|
_get_httpx_client,
|
||||||
|
get_async_httpx_client,
|
||||||
|
)
|
||||||
|
from litellm.llms.jina_ai.rerank.transformation import JinaAIRerankConfig
|
||||||
|
from litellm.types.rerank import RerankRequest, RerankResponse
|
||||||
|
|
||||||
|
|
||||||
|
class JinaAIRerank(BaseLLM):
|
||||||
|
def rerank(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
api_key: str,
|
||||||
|
query: str,
|
||||||
|
documents: List[Union[str, Dict[str, Any]]],
|
||||||
|
top_n: Optional[int] = None,
|
||||||
|
rank_fields: Optional[List[str]] = None,
|
||||||
|
return_documents: Optional[bool] = True,
|
||||||
|
max_chunks_per_doc: Optional[int] = None,
|
||||||
|
_is_async: Optional[bool] = False,
|
||||||
|
) -> RerankResponse:
|
||||||
|
client = _get_httpx_client()
|
||||||
|
|
||||||
|
request_data = RerankRequest(
|
||||||
|
model=model,
|
||||||
|
query=query,
|
||||||
|
top_n=top_n,
|
||||||
|
documents=documents,
|
||||||
|
rank_fields=rank_fields,
|
||||||
|
return_documents=return_documents,
|
||||||
|
)
|
||||||
|
|
||||||
|
# exclude None values from request_data
|
||||||
|
request_data_dict = request_data.dict(exclude_none=True)
|
||||||
|
|
||||||
|
if _is_async:
|
||||||
|
return self.async_rerank(request_data_dict, api_key) # type: ignore # Call async method
|
||||||
|
|
||||||
|
response = client.post(
|
||||||
|
"https://api.jina.ai/v1/rerank",
|
||||||
|
headers={
|
||||||
|
"accept": "application/json",
|
||||||
|
"content-type": "application/json",
|
||||||
|
"authorization": f"Bearer {api_key}",
|
||||||
|
},
|
||||||
|
json=request_data_dict,
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
raise Exception(response.text)
|
||||||
|
|
||||||
|
_json_response = response.json()
|
||||||
|
|
||||||
|
return JinaAIRerankConfig()._transform_response(_json_response)
|
||||||
|
|
||||||
|
async def async_rerank( # New async method
|
||||||
|
self,
|
||||||
|
request_data_dict: Dict[str, Any],
|
||||||
|
api_key: str,
|
||||||
|
) -> RerankResponse:
|
||||||
|
client = get_async_httpx_client(
|
||||||
|
llm_provider=litellm.LlmProviders.JINA_AI
|
||||||
|
) # Use async client
|
||||||
|
|
||||||
|
response = await client.post(
|
||||||
|
"https://api.jina.ai/v1/rerank",
|
||||||
|
headers={
|
||||||
|
"accept": "application/json",
|
||||||
|
"content-type": "application/json",
|
||||||
|
"authorization": f"Bearer {api_key}",
|
||||||
|
},
|
||||||
|
json=request_data_dict,
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
raise Exception(response.text)
|
||||||
|
|
||||||
|
_json_response = response.json()
|
||||||
|
|
||||||
|
return JinaAIRerankConfig()._transform_response(_json_response)
|
||||||
|
|
||||||
|
pass
|
36
litellm/llms/jina_ai/rerank/transformation.py
Normal file
36
litellm/llms/jina_ai/rerank/transformation.py
Normal file
|
@ -0,0 +1,36 @@
|
||||||
|
"""
|
||||||
|
Transformation logic from Cohere's /v1/rerank format to Jina AI's `/v1/rerank` format.
|
||||||
|
|
||||||
|
Why separate file? Make it easy to see how transformation works
|
||||||
|
|
||||||
|
Docs - https://jina.ai/reranker
|
||||||
|
"""
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from litellm.types.rerank import (
|
||||||
|
RerankBilledUnits,
|
||||||
|
RerankResponse,
|
||||||
|
RerankResponseMeta,
|
||||||
|
RerankTokens,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class JinaAIRerankConfig:
|
||||||
|
def _transform_response(self, response: dict) -> RerankResponse:
|
||||||
|
|
||||||
|
_billed_units = RerankBilledUnits(**response.get("usage", {}))
|
||||||
|
_tokens = RerankTokens(**response.get("usage", {}))
|
||||||
|
rerank_meta = RerankResponseMeta(billed_units=_billed_units, tokens=_tokens)
|
||||||
|
|
||||||
|
_results: Optional[List[dict]] = response.get("results")
|
||||||
|
|
||||||
|
if _results is None:
|
||||||
|
raise ValueError(f"No results found in the response={response}")
|
||||||
|
|
||||||
|
return RerankResponse(
|
||||||
|
id=response.get("id") or str(uuid.uuid4()),
|
||||||
|
results=_results,
|
||||||
|
meta=rerank_meta,
|
||||||
|
) # Return response
|
|
@ -185,6 +185,8 @@ class OllamaConfig:
|
||||||
"name": "mistral"
|
"name": "mistral"
|
||||||
}'
|
}'
|
||||||
"""
|
"""
|
||||||
|
if model.startswith("ollama/") or model.startswith("ollama_chat/"):
|
||||||
|
model = model.split("/", 1)[1]
|
||||||
api_base = get_secret_str("OLLAMA_API_BASE") or "http://localhost:11434"
|
api_base = get_secret_str("OLLAMA_API_BASE") or "http://localhost:11434"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -15,7 +15,14 @@ from litellm.llms.custom_httpx.http_handler import (
|
||||||
_get_httpx_client,
|
_get_httpx_client,
|
||||||
get_async_httpx_client,
|
get_async_httpx_client,
|
||||||
)
|
)
|
||||||
from litellm.types.rerank import RerankRequest, RerankResponse
|
from litellm.llms.together_ai.rerank.transformation import TogetherAIRerankConfig
|
||||||
|
from litellm.types.rerank import (
|
||||||
|
RerankBilledUnits,
|
||||||
|
RerankRequest,
|
||||||
|
RerankResponse,
|
||||||
|
RerankResponseMeta,
|
||||||
|
RerankTokens,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TogetherAIRerank(BaseLLM):
|
class TogetherAIRerank(BaseLLM):
|
||||||
|
@ -65,13 +72,7 @@ class TogetherAIRerank(BaseLLM):
|
||||||
|
|
||||||
_json_response = response.json()
|
_json_response = response.json()
|
||||||
|
|
||||||
response = RerankResponse(
|
return TogetherAIRerankConfig()._transform_response(_json_response)
|
||||||
id=_json_response.get("id"),
|
|
||||||
results=_json_response.get("results"),
|
|
||||||
meta=_json_response.get("meta") or {},
|
|
||||||
)
|
|
||||||
|
|
||||||
return response
|
|
||||||
|
|
||||||
async def async_rerank( # New async method
|
async def async_rerank( # New async method
|
||||||
self,
|
self,
|
||||||
|
@ -97,10 +98,4 @@ class TogetherAIRerank(BaseLLM):
|
||||||
|
|
||||||
_json_response = response.json()
|
_json_response = response.json()
|
||||||
|
|
||||||
return RerankResponse(
|
return TogetherAIRerankConfig()._transform_response(_json_response)
|
||||||
id=_json_response.get("id"),
|
|
||||||
results=_json_response.get("results"),
|
|
||||||
meta=_json_response.get("meta") or {},
|
|
||||||
) # Return response
|
|
||||||
|
|
||||||
pass
|
|
34
litellm/llms/together_ai/rerank/transformation.py
Normal file
34
litellm/llms/together_ai/rerank/transformation.py
Normal file
|
@ -0,0 +1,34 @@
|
||||||
|
"""
|
||||||
|
Transformation logic from Cohere's /v1/rerank format to Together AI's `/v1/rerank` format.
|
||||||
|
|
||||||
|
Why separate file? Make it easy to see how transformation works
|
||||||
|
"""
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from litellm.types.rerank import (
|
||||||
|
RerankBilledUnits,
|
||||||
|
RerankResponse,
|
||||||
|
RerankResponseMeta,
|
||||||
|
RerankTokens,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TogetherAIRerankConfig:
|
||||||
|
def _transform_response(self, response: dict) -> RerankResponse:
|
||||||
|
|
||||||
|
_billed_units = RerankBilledUnits(**response.get("usage", {}))
|
||||||
|
_tokens = RerankTokens(**response.get("usage", {}))
|
||||||
|
rerank_meta = RerankResponseMeta(billed_units=_billed_units, tokens=_tokens)
|
||||||
|
|
||||||
|
_results: Optional[List[dict]] = response.get("results")
|
||||||
|
|
||||||
|
if _results is None:
|
||||||
|
raise ValueError(f"No results found in the response={response}")
|
||||||
|
|
||||||
|
return RerankResponse(
|
||||||
|
id=response.get("id") or str(uuid.uuid4()),
|
||||||
|
results=_results,
|
||||||
|
meta=rerank_meta,
|
||||||
|
) # Return response
|
|
@ -89,6 +89,9 @@ def _get_vertex_url(
|
||||||
elif mode == "embedding":
|
elif mode == "embedding":
|
||||||
endpoint = "predict"
|
endpoint = "predict"
|
||||||
url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}"
|
url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}"
|
||||||
|
if model.isdigit():
|
||||||
|
# https://us-central1-aiplatform.googleapis.com/v1/projects/$PROJECT_ID/locations/us-central1/endpoints/$ENDPOINT_ID:predict
|
||||||
|
url = f"https://{vertex_location}-aiplatform.googleapis.com/{vertex_api_version}/projects/{vertex_project}/locations/{vertex_location}/endpoints/{model}:{endpoint}"
|
||||||
|
|
||||||
if not url or not endpoint:
|
if not url or not endpoint:
|
||||||
raise ValueError(f"Unable to get vertex url/endpoint for mode: {mode}")
|
raise ValueError(f"Unable to get vertex url/endpoint for mode: {mode}")
|
||||||
|
|
|
@ -0,0 +1,25 @@
|
||||||
|
"""
|
||||||
|
Vertex AI Image Generation Cost Calculator
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import litellm
|
||||||
|
from litellm.types.utils import ImageResponse
|
||||||
|
|
||||||
|
|
||||||
|
def cost_calculator(
|
||||||
|
model: str,
|
||||||
|
image_response: ImageResponse,
|
||||||
|
) -> float:
|
||||||
|
"""
|
||||||
|
Vertex AI Image Generation Cost Calculator
|
||||||
|
"""
|
||||||
|
_model_info = litellm.get_model_info(
|
||||||
|
model=model,
|
||||||
|
custom_llm_provider="vertex_ai",
|
||||||
|
)
|
||||||
|
|
||||||
|
output_cost_per_image: float = _model_info.get("output_cost_per_image") or 0.0
|
||||||
|
num_images: int = len(image_response.data)
|
||||||
|
return output_cost_per_image * num_images
|
|
@ -96,7 +96,7 @@ class VertexEmbedding(VertexBase):
|
||||||
headers = self.set_headers(auth_header=auth_header, extra_headers=extra_headers)
|
headers = self.set_headers(auth_header=auth_header, extra_headers=extra_headers)
|
||||||
vertex_request: VertexEmbeddingRequest = (
|
vertex_request: VertexEmbeddingRequest = (
|
||||||
litellm.vertexAITextEmbeddingConfig.transform_openai_request_to_vertex_embedding_request(
|
litellm.vertexAITextEmbeddingConfig.transform_openai_request_to_vertex_embedding_request(
|
||||||
input=input, optional_params=optional_params
|
input=input, optional_params=optional_params, model=model
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -188,7 +188,7 @@ class VertexEmbedding(VertexBase):
|
||||||
headers = self.set_headers(auth_header=auth_header, extra_headers=extra_headers)
|
headers = self.set_headers(auth_header=auth_header, extra_headers=extra_headers)
|
||||||
vertex_request: VertexEmbeddingRequest = (
|
vertex_request: VertexEmbeddingRequest = (
|
||||||
litellm.vertexAITextEmbeddingConfig.transform_openai_request_to_vertex_embedding_request(
|
litellm.vertexAITextEmbeddingConfig.transform_openai_request_to_vertex_embedding_request(
|
||||||
input=input, optional_params=optional_params
|
input=input, optional_params=optional_params, model=model
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -101,11 +101,16 @@ class VertexAITextEmbeddingConfig(BaseModel):
|
||||||
return optional_params
|
return optional_params
|
||||||
|
|
||||||
def transform_openai_request_to_vertex_embedding_request(
|
def transform_openai_request_to_vertex_embedding_request(
|
||||||
self, input: Union[list, str], optional_params: dict
|
self, input: Union[list, str], optional_params: dict, model: str
|
||||||
) -> VertexEmbeddingRequest:
|
) -> VertexEmbeddingRequest:
|
||||||
"""
|
"""
|
||||||
Transforms an openai request to a vertex embedding request.
|
Transforms an openai request to a vertex embedding request.
|
||||||
"""
|
"""
|
||||||
|
if model.isdigit():
|
||||||
|
return self._transform_openai_request_to_fine_tuned_embedding_request(
|
||||||
|
input, optional_params, model
|
||||||
|
)
|
||||||
|
|
||||||
vertex_request: VertexEmbeddingRequest = VertexEmbeddingRequest()
|
vertex_request: VertexEmbeddingRequest = VertexEmbeddingRequest()
|
||||||
vertex_text_embedding_input_list: List[TextEmbeddingInput] = []
|
vertex_text_embedding_input_list: List[TextEmbeddingInput] = []
|
||||||
task_type: Optional[TaskType] = optional_params.get("task_type")
|
task_type: Optional[TaskType] = optional_params.get("task_type")
|
||||||
|
@ -125,6 +130,47 @@ class VertexAITextEmbeddingConfig(BaseModel):
|
||||||
|
|
||||||
return vertex_request
|
return vertex_request
|
||||||
|
|
||||||
|
def _transform_openai_request_to_fine_tuned_embedding_request(
|
||||||
|
self, input: Union[list, str], optional_params: dict, model: str
|
||||||
|
) -> VertexEmbeddingRequest:
|
||||||
|
"""
|
||||||
|
Transforms an openai request to a vertex fine-tuned embedding request.
|
||||||
|
|
||||||
|
Vertex Doc: https://console.cloud.google.com/vertex-ai/model-garden?hl=en&project=adroit-crow-413218&pageState=(%22galleryStateKey%22:(%22f%22:(%22g%22:%5B%5D,%22o%22:%5B%5D),%22s%22:%22%22))
|
||||||
|
Sample Request:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"instances" : [
|
||||||
|
{
|
||||||
|
"inputs": "How would the Future of AI in 10 Years look?",
|
||||||
|
"parameters": {
|
||||||
|
"max_new_tokens": 128,
|
||||||
|
"temperature": 1.0,
|
||||||
|
"top_p": 0.9,
|
||||||
|
"top_k": 10
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
vertex_request: VertexEmbeddingRequest = VertexEmbeddingRequest()
|
||||||
|
vertex_text_embedding_input_list: List[TextEmbeddingFineTunedInput] = []
|
||||||
|
if isinstance(input, str):
|
||||||
|
input = [input] # Convert single string to list for uniform processing
|
||||||
|
|
||||||
|
for text in input:
|
||||||
|
embedding_input = TextEmbeddingFineTunedInput(inputs=text)
|
||||||
|
vertex_text_embedding_input_list.append(embedding_input)
|
||||||
|
|
||||||
|
vertex_request["instances"] = vertex_text_embedding_input_list
|
||||||
|
vertex_request["parameters"] = TextEmbeddingFineTunedParameters(
|
||||||
|
**optional_params
|
||||||
|
)
|
||||||
|
|
||||||
|
return vertex_request
|
||||||
|
|
||||||
def create_embedding_input(
|
def create_embedding_input(
|
||||||
self,
|
self,
|
||||||
content: str,
|
content: str,
|
||||||
|
@ -157,6 +203,11 @@ class VertexAITextEmbeddingConfig(BaseModel):
|
||||||
"""
|
"""
|
||||||
Transforms a vertex embedding response to an openai response.
|
Transforms a vertex embedding response to an openai response.
|
||||||
"""
|
"""
|
||||||
|
if model.isdigit():
|
||||||
|
return self._transform_vertex_response_to_openai_for_fine_tuned_models(
|
||||||
|
response, model, model_response
|
||||||
|
)
|
||||||
|
|
||||||
_predictions = response["predictions"]
|
_predictions = response["predictions"]
|
||||||
|
|
||||||
embedding_response = []
|
embedding_response = []
|
||||||
|
@ -181,3 +232,35 @@ class VertexAITextEmbeddingConfig(BaseModel):
|
||||||
)
|
)
|
||||||
setattr(model_response, "usage", usage)
|
setattr(model_response, "usage", usage)
|
||||||
return model_response
|
return model_response
|
||||||
|
|
||||||
|
def _transform_vertex_response_to_openai_for_fine_tuned_models(
|
||||||
|
self, response: dict, model: str, model_response: litellm.EmbeddingResponse
|
||||||
|
) -> litellm.EmbeddingResponse:
|
||||||
|
"""
|
||||||
|
Transforms a vertex fine-tuned model embedding response to an openai response format.
|
||||||
|
"""
|
||||||
|
_predictions = response["predictions"]
|
||||||
|
|
||||||
|
embedding_response = []
|
||||||
|
# For fine-tuned models, we don't get token counts in the response
|
||||||
|
input_tokens = 0
|
||||||
|
|
||||||
|
for idx, embedding_values in enumerate(_predictions):
|
||||||
|
embedding_response.append(
|
||||||
|
{
|
||||||
|
"object": "embedding",
|
||||||
|
"index": idx,
|
||||||
|
"embedding": embedding_values[
|
||||||
|
0
|
||||||
|
], # The embedding values are nested one level deeper
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
model_response.object = "list"
|
||||||
|
model_response.data = embedding_response
|
||||||
|
model_response.model = model
|
||||||
|
usage = Usage(
|
||||||
|
prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
|
||||||
|
)
|
||||||
|
setattr(model_response, "usage", usage)
|
||||||
|
return model_response
|
||||||
|
|
|
@ -23,14 +23,27 @@ class TextEmbeddingInput(TypedDict, total=False):
|
||||||
title: Optional[str]
|
title: Optional[str]
|
||||||
|
|
||||||
|
|
||||||
|
# Fine-tuned models require a different input format
|
||||||
|
# Ref: https://console.cloud.google.com/vertex-ai/model-garden?hl=en&project=adroit-crow-413218&pageState=(%22galleryStateKey%22:(%22f%22:(%22g%22:%5B%5D,%22o%22:%5B%5D),%22s%22:%22%22))
|
||||||
|
class TextEmbeddingFineTunedInput(TypedDict, total=False):
|
||||||
|
inputs: str
|
||||||
|
|
||||||
|
|
||||||
|
class TextEmbeddingFineTunedParameters(TypedDict, total=False):
|
||||||
|
max_new_tokens: Optional[int]
|
||||||
|
temperature: Optional[float]
|
||||||
|
top_p: Optional[float]
|
||||||
|
top_k: Optional[int]
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingParameters(TypedDict, total=False):
|
class EmbeddingParameters(TypedDict, total=False):
|
||||||
auto_truncate: Optional[bool]
|
auto_truncate: Optional[bool]
|
||||||
output_dimensionality: Optional[int]
|
output_dimensionality: Optional[int]
|
||||||
|
|
||||||
|
|
||||||
class VertexEmbeddingRequest(TypedDict, total=False):
|
class VertexEmbeddingRequest(TypedDict, total=False):
|
||||||
instances: List[TextEmbeddingInput]
|
instances: Union[List[TextEmbeddingInput], List[TextEmbeddingFineTunedInput]]
|
||||||
parameters: Optional[EmbeddingParameters]
|
parameters: Optional[Union[EmbeddingParameters, TextEmbeddingFineTunedParameters]]
|
||||||
|
|
||||||
|
|
||||||
# Example usage:
|
# Example usage:
|
||||||
|
|
|
@ -1066,6 +1066,7 @@ def completion( # type: ignore # noqa: PLR0915
|
||||||
azure_ad_token_provider=kwargs.get("azure_ad_token_provider"),
|
azure_ad_token_provider=kwargs.get("azure_ad_token_provider"),
|
||||||
user_continue_message=kwargs.get("user_continue_message"),
|
user_continue_message=kwargs.get("user_continue_message"),
|
||||||
base_model=base_model,
|
base_model=base_model,
|
||||||
|
litellm_trace_id=kwargs.get("litellm_trace_id"),
|
||||||
)
|
)
|
||||||
logging.update_environment_variables(
|
logging.update_environment_variables(
|
||||||
model=model,
|
model=model,
|
||||||
|
@ -3455,7 +3456,7 @@ def embedding( # noqa: PLR0915
|
||||||
client=client,
|
client=client,
|
||||||
aembedding=aembedding,
|
aembedding=aembedding,
|
||||||
)
|
)
|
||||||
elif custom_llm_provider == "openai_like":
|
elif custom_llm_provider == "openai_like" or custom_llm_provider == "jina_ai":
|
||||||
api_base = (
|
api_base = (
|
||||||
api_base or litellm.api_base or get_secret_str("OPENAI_LIKE_API_BASE")
|
api_base or litellm.api_base or get_secret_str("OPENAI_LIKE_API_BASE")
|
||||||
)
|
)
|
||||||
|
|
|
@ -2986,19 +2986,19 @@
|
||||||
"supports_function_calling": true
|
"supports_function_calling": true
|
||||||
},
|
},
|
||||||
"vertex_ai/imagegeneration@006": {
|
"vertex_ai/imagegeneration@006": {
|
||||||
"cost_per_image": 0.020,
|
"output_cost_per_image": 0.020,
|
||||||
"litellm_provider": "vertex_ai-image-models",
|
"litellm_provider": "vertex_ai-image-models",
|
||||||
"mode": "image_generation",
|
"mode": "image_generation",
|
||||||
"source": "https://cloud.google.com/vertex-ai/generative-ai/pricing"
|
"source": "https://cloud.google.com/vertex-ai/generative-ai/pricing"
|
||||||
},
|
},
|
||||||
"vertex_ai/imagen-3.0-generate-001": {
|
"vertex_ai/imagen-3.0-generate-001": {
|
||||||
"cost_per_image": 0.04,
|
"output_cost_per_image": 0.04,
|
||||||
"litellm_provider": "vertex_ai-image-models",
|
"litellm_provider": "vertex_ai-image-models",
|
||||||
"mode": "image_generation",
|
"mode": "image_generation",
|
||||||
"source": "https://cloud.google.com/vertex-ai/generative-ai/pricing"
|
"source": "https://cloud.google.com/vertex-ai/generative-ai/pricing"
|
||||||
},
|
},
|
||||||
"vertex_ai/imagen-3.0-fast-generate-001": {
|
"vertex_ai/imagen-3.0-fast-generate-001": {
|
||||||
"cost_per_image": 0.02,
|
"output_cost_per_image": 0.02,
|
||||||
"litellm_provider": "vertex_ai-image-models",
|
"litellm_provider": "vertex_ai-image-models",
|
||||||
"mode": "image_generation",
|
"mode": "image_generation",
|
||||||
"source": "https://cloud.google.com/vertex-ai/generative-ai/pricing"
|
"source": "https://cloud.google.com/vertex-ai/generative-ai/pricing"
|
||||||
|
@ -5620,6 +5620,13 @@
|
||||||
"litellm_provider": "bedrock",
|
"litellm_provider": "bedrock",
|
||||||
"mode": "image_generation"
|
"mode": "image_generation"
|
||||||
},
|
},
|
||||||
|
"stability.stable-image-ultra-v1:0": {
|
||||||
|
"max_tokens": 77,
|
||||||
|
"max_input_tokens": 77,
|
||||||
|
"output_cost_per_image": 0.14,
|
||||||
|
"litellm_provider": "bedrock",
|
||||||
|
"mode": "image_generation"
|
||||||
|
},
|
||||||
"sagemaker/meta-textgeneration-llama-2-7b": {
|
"sagemaker/meta-textgeneration-llama-2-7b": {
|
||||||
"max_tokens": 4096,
|
"max_tokens": 4096,
|
||||||
"max_input_tokens": 4096,
|
"max_input_tokens": 4096,
|
||||||
|
|
|
@ -1,122 +1,15 @@
|
||||||
model_list:
|
model_list:
|
||||||
- model_name: "*"
|
# GPT-4 Turbo Models
|
||||||
litellm_params:
|
|
||||||
model: claude-3-5-sonnet-20240620
|
|
||||||
api_key: os.environ/ANTHROPIC_API_KEY
|
|
||||||
- model_name: claude-3-5-sonnet-aihubmix
|
|
||||||
litellm_params:
|
|
||||||
model: openai/claude-3-5-sonnet-20240620
|
|
||||||
input_cost_per_token: 0.000003 # 3$/M
|
|
||||||
output_cost_per_token: 0.000015 # 15$/M
|
|
||||||
api_base: "https://exampleopenaiendpoint-production.up.railway.app"
|
|
||||||
api_key: my-fake-key
|
|
||||||
- model_name: fake-openai-endpoint-2
|
|
||||||
litellm_params:
|
|
||||||
model: openai/my-fake-model
|
|
||||||
api_key: my-fake-key
|
|
||||||
api_base: https://exampleopenaiendpoint-production.up.railway.app/
|
|
||||||
stream_timeout: 0.001
|
|
||||||
timeout: 1
|
|
||||||
rpm: 1
|
|
||||||
- model_name: fake-openai-endpoint
|
|
||||||
litellm_params:
|
|
||||||
model: openai/my-fake-model
|
|
||||||
api_key: my-fake-key
|
|
||||||
api_base: https://exampleopenaiendpoint-production.up.railway.app/
|
|
||||||
## bedrock chat completions
|
|
||||||
- model_name: "*anthropic.claude*"
|
|
||||||
litellm_params:
|
|
||||||
model: bedrock/*anthropic.claude*
|
|
||||||
aws_access_key_id: os.environ/BEDROCK_AWS_ACCESS_KEY_ID
|
|
||||||
aws_secret_access_key: os.environ/BEDROCK_AWS_SECRET_ACCESS_KEY
|
|
||||||
aws_region_name: os.environ/AWS_REGION_NAME
|
|
||||||
guardrailConfig:
|
|
||||||
"guardrailIdentifier": "h4dsqwhp6j66"
|
|
||||||
"guardrailVersion": "2"
|
|
||||||
"trace": "enabled"
|
|
||||||
|
|
||||||
## bedrock embeddings
|
|
||||||
- model_name: "*amazon.titan-embed-*"
|
|
||||||
litellm_params:
|
|
||||||
model: bedrock/amazon.titan-embed-*
|
|
||||||
aws_access_key_id: os.environ/BEDROCK_AWS_ACCESS_KEY_ID
|
|
||||||
aws_secret_access_key: os.environ/BEDROCK_AWS_SECRET_ACCESS_KEY
|
|
||||||
aws_region_name: os.environ/AWS_REGION_NAME
|
|
||||||
- model_name: "*cohere.embed-*"
|
|
||||||
litellm_params:
|
|
||||||
model: bedrock/cohere.embed-*
|
|
||||||
aws_access_key_id: os.environ/BEDROCK_AWS_ACCESS_KEY_ID
|
|
||||||
aws_secret_access_key: os.environ/BEDROCK_AWS_SECRET_ACCESS_KEY
|
|
||||||
aws_region_name: os.environ/AWS_REGION_NAME
|
|
||||||
|
|
||||||
- model_name: "bedrock/*"
|
|
||||||
litellm_params:
|
|
||||||
model: bedrock/*
|
|
||||||
aws_access_key_id: os.environ/BEDROCK_AWS_ACCESS_KEY_ID
|
|
||||||
aws_secret_access_key: os.environ/BEDROCK_AWS_SECRET_ACCESS_KEY
|
|
||||||
aws_region_name: os.environ/AWS_REGION_NAME
|
|
||||||
|
|
||||||
- model_name: gpt-4
|
- model_name: gpt-4
|
||||||
litellm_params:
|
litellm_params:
|
||||||
model: azure/chatgpt-v-2
|
model: gpt-4
|
||||||
api_base: https://openai-gpt-4-test-v-1.openai.azure.com/
|
- model_name: rerank-model
|
||||||
api_version: "2023-05-15"
|
litellm_params:
|
||||||
api_key: os.environ/AZURE_API_KEY # The `os.environ/` prefix tells litellm to read this from the env. See https://docs.litellm.ai/docs/simple_proxy#load-api-keys-from-vault
|
model: jina_ai/jina-reranker-v2-base-multilingual
|
||||||
rpm: 480
|
|
||||||
timeout: 300
|
|
||||||
stream_timeout: 60
|
|
||||||
|
|
||||||
litellm_settings:
|
|
||||||
fallbacks: [{ "claude-3-5-sonnet-20240620": ["claude-3-5-sonnet-aihubmix"] }]
|
|
||||||
# callbacks: ["otel", "prometheus"]
|
|
||||||
default_redis_batch_cache_expiry: 10
|
|
||||||
# default_team_settings:
|
|
||||||
# - team_id: "dbe2f686-a686-4896-864a-4c3924458709"
|
|
||||||
# success_callback: ["langfuse"]
|
|
||||||
# langfuse_public_key: os.environ/LANGFUSE_PUB_KEY_1 # Project 1
|
|
||||||
# langfuse_secret: os.environ/LANGFUSE_PRIVATE_KEY_1 # Project 1
|
|
||||||
|
|
||||||
# litellm_settings:
|
|
||||||
# cache: True
|
|
||||||
# cache_params:
|
|
||||||
# type: redis
|
|
||||||
|
|
||||||
# # disable caching on the actual API call
|
|
||||||
# supported_call_types: []
|
|
||||||
|
|
||||||
# # see https://docs.litellm.ai/docs/proxy/prod#3-use-redis-porthost-password-not-redis_url
|
|
||||||
# host: os.environ/REDIS_HOST
|
|
||||||
# port: os.environ/REDIS_PORT
|
|
||||||
# password: os.environ/REDIS_PASSWORD
|
|
||||||
|
|
||||||
# # see https://docs.litellm.ai/docs/proxy/caching#turn-on-batch_redis_requests
|
|
||||||
# # see https://docs.litellm.ai/docs/proxy/prometheus
|
|
||||||
# callbacks: ['otel']
|
|
||||||
|
|
||||||
|
|
||||||
# # router_settings:
|
router_settings:
|
||||||
# # routing_strategy: latency-based-routing
|
model_group_alias:
|
||||||
# # routing_strategy_args:
|
"gpt-4-turbo": # Aliased model name
|
||||||
# # # only assign 40% of traffic to the fastest deployment to avoid overloading it
|
model: "gpt-4" # Actual model name in 'model_list'
|
||||||
# # lowest_latency_buffer: 0.4
|
hidden: true
|
||||||
|
|
||||||
# # # consider last five minutes of calls for latency calculation
|
|
||||||
# # ttl: 300
|
|
||||||
# # redis_host: os.environ/REDIS_HOST
|
|
||||||
# # redis_port: os.environ/REDIS_PORT
|
|
||||||
# # redis_password: os.environ/REDIS_PASSWORD
|
|
||||||
|
|
||||||
# # # see https://docs.litellm.ai/docs/proxy/prod#1-use-this-configyaml
|
|
||||||
# # general_settings:
|
|
||||||
# # master_key: os.environ/LITELLM_MASTER_KEY
|
|
||||||
# # database_url: os.environ/DATABASE_URL
|
|
||||||
# # disable_master_key_return: true
|
|
||||||
# # # alerting: ['slack', 'email']
|
|
||||||
# # alerting: ['email']
|
|
||||||
|
|
||||||
# # # Batch write spend updates every 60s
|
|
||||||
# # proxy_batch_write_at: 60
|
|
||||||
|
|
||||||
# # # see https://docs.litellm.ai/docs/proxy/caching#advanced---user-api-key-cache-ttl
|
|
||||||
# # # our api keys rarely change
|
|
||||||
# # user_api_key_cache_ttl: 3600
|
|
|
@ -1128,7 +1128,16 @@ class KeyManagementSystem(enum.Enum):
|
||||||
|
|
||||||
|
|
||||||
class KeyManagementSettings(LiteLLMBase):
|
class KeyManagementSettings(LiteLLMBase):
|
||||||
hosted_keys: List
|
hosted_keys: Optional[List] = None
|
||||||
|
store_virtual_keys: Optional[bool] = False
|
||||||
|
"""
|
||||||
|
If True, virtual keys created by litellm will be stored in the secret manager
|
||||||
|
"""
|
||||||
|
|
||||||
|
access_mode: Literal["read_only", "write_only", "read_and_write"] = "read_only"
|
||||||
|
"""
|
||||||
|
Access mode for the secret manager, when write_only will only use for writing secrets
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
class TeamDefaultSettings(LiteLLMBase):
|
class TeamDefaultSettings(LiteLLMBase):
|
||||||
|
|
|
@ -8,6 +8,7 @@ Run checks for:
|
||||||
2. If user is in budget
|
2. If user is in budget
|
||||||
3. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget
|
3. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
267
litellm/proxy/hooks/key_management_event_hooks.py
Normal file
267
litellm/proxy/hooks/key_management_event_hooks.py
Normal file
|
@ -0,0 +1,267 @@
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from re import A
|
||||||
|
from typing import Any, List, Optional
|
||||||
|
|
||||||
|
from fastapi import status
|
||||||
|
|
||||||
|
import litellm
|
||||||
|
from litellm._logging import verbose_proxy_logger
|
||||||
|
from litellm.proxy._types import (
|
||||||
|
GenerateKeyRequest,
|
||||||
|
KeyManagementSystem,
|
||||||
|
KeyRequest,
|
||||||
|
LiteLLM_AuditLogs,
|
||||||
|
LiteLLM_VerificationToken,
|
||||||
|
LitellmTableNames,
|
||||||
|
ProxyErrorTypes,
|
||||||
|
ProxyException,
|
||||||
|
UpdateKeyRequest,
|
||||||
|
UserAPIKeyAuth,
|
||||||
|
WebhookEvent,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class KeyManagementEventHooks:
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def async_key_generated_hook(
|
||||||
|
data: GenerateKeyRequest,
|
||||||
|
response: dict,
|
||||||
|
user_api_key_dict: UserAPIKeyAuth,
|
||||||
|
litellm_changed_by: Optional[str] = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Hook that runs after a successful /key/generate request
|
||||||
|
|
||||||
|
Handles the following:
|
||||||
|
- Sending Email with Key Details
|
||||||
|
- Storing Audit Logs for key generation
|
||||||
|
- Storing Generated Key in DB
|
||||||
|
"""
|
||||||
|
from litellm.proxy.management_helpers.audit_logs import (
|
||||||
|
create_audit_log_for_update,
|
||||||
|
)
|
||||||
|
from litellm.proxy.proxy_server import (
|
||||||
|
general_settings,
|
||||||
|
litellm_proxy_admin_name,
|
||||||
|
proxy_logging_obj,
|
||||||
|
)
|
||||||
|
|
||||||
|
if data.send_invite_email is True:
|
||||||
|
await KeyManagementEventHooks._send_key_created_email(response)
|
||||||
|
|
||||||
|
# Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True
|
||||||
|
if litellm.store_audit_logs is True:
|
||||||
|
_updated_values = json.dumps(response, default=str)
|
||||||
|
asyncio.create_task(
|
||||||
|
create_audit_log_for_update(
|
||||||
|
request_data=LiteLLM_AuditLogs(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
updated_at=datetime.now(timezone.utc),
|
||||||
|
changed_by=litellm_changed_by
|
||||||
|
or user_api_key_dict.user_id
|
||||||
|
or litellm_proxy_admin_name,
|
||||||
|
changed_by_api_key=user_api_key_dict.api_key,
|
||||||
|
table_name=LitellmTableNames.KEY_TABLE_NAME,
|
||||||
|
object_id=response.get("token_id", ""),
|
||||||
|
action="created",
|
||||||
|
updated_values=_updated_values,
|
||||||
|
before_value=None,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# store the generated key in the secret manager
|
||||||
|
await KeyManagementEventHooks._store_virtual_key_in_secret_manager(
|
||||||
|
secret_name=data.key_alias or f"virtual-key-{uuid.uuid4()}",
|
||||||
|
secret_token=response.get("token", ""),
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def async_key_updated_hook(
|
||||||
|
data: UpdateKeyRequest,
|
||||||
|
existing_key_row: Any,
|
||||||
|
response: Any,
|
||||||
|
user_api_key_dict: UserAPIKeyAuth,
|
||||||
|
litellm_changed_by: Optional[str] = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Post /key/update processing hook
|
||||||
|
|
||||||
|
Handles the following:
|
||||||
|
- Storing Audit Logs for key update
|
||||||
|
"""
|
||||||
|
from litellm.proxy.management_helpers.audit_logs import (
|
||||||
|
create_audit_log_for_update,
|
||||||
|
)
|
||||||
|
from litellm.proxy.proxy_server import litellm_proxy_admin_name
|
||||||
|
|
||||||
|
# Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True
|
||||||
|
if litellm.store_audit_logs is True:
|
||||||
|
_updated_values = json.dumps(data.json(exclude_none=True), default=str)
|
||||||
|
|
||||||
|
_before_value = existing_key_row.json(exclude_none=True)
|
||||||
|
_before_value = json.dumps(_before_value, default=str)
|
||||||
|
|
||||||
|
asyncio.create_task(
|
||||||
|
create_audit_log_for_update(
|
||||||
|
request_data=LiteLLM_AuditLogs(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
updated_at=datetime.now(timezone.utc),
|
||||||
|
changed_by=litellm_changed_by
|
||||||
|
or user_api_key_dict.user_id
|
||||||
|
or litellm_proxy_admin_name,
|
||||||
|
changed_by_api_key=user_api_key_dict.api_key,
|
||||||
|
table_name=LitellmTableNames.KEY_TABLE_NAME,
|
||||||
|
object_id=data.key,
|
||||||
|
action="updated",
|
||||||
|
updated_values=_updated_values,
|
||||||
|
before_value=_before_value,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
pass
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def async_key_deleted_hook(
|
||||||
|
data: KeyRequest,
|
||||||
|
keys_being_deleted: List[LiteLLM_VerificationToken],
|
||||||
|
response: dict,
|
||||||
|
user_api_key_dict: UserAPIKeyAuth,
|
||||||
|
litellm_changed_by: Optional[str] = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Post /key/delete processing hook
|
||||||
|
|
||||||
|
Handles the following:
|
||||||
|
- Storing Audit Logs for key deletion
|
||||||
|
"""
|
||||||
|
from litellm.proxy.management_helpers.audit_logs import (
|
||||||
|
create_audit_log_for_update,
|
||||||
|
)
|
||||||
|
from litellm.proxy.proxy_server import litellm_proxy_admin_name, prisma_client
|
||||||
|
|
||||||
|
# Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True
|
||||||
|
# we do this after the first for loop, since first for loop is for validation. we only want this inserted after validation passes
|
||||||
|
if litellm.store_audit_logs is True:
|
||||||
|
# make an audit log for each team deleted
|
||||||
|
for key in data.keys:
|
||||||
|
key_row = await prisma_client.get_data( # type: ignore
|
||||||
|
token=key, table_name="key", query_type="find_unique"
|
||||||
|
)
|
||||||
|
|
||||||
|
if key_row is None:
|
||||||
|
raise ProxyException(
|
||||||
|
message=f"Key {key} not found",
|
||||||
|
type=ProxyErrorTypes.bad_request_error,
|
||||||
|
param="key",
|
||||||
|
code=status.HTTP_404_NOT_FOUND,
|
||||||
|
)
|
||||||
|
|
||||||
|
key_row = key_row.json(exclude_none=True)
|
||||||
|
_key_row = json.dumps(key_row, default=str)
|
||||||
|
|
||||||
|
asyncio.create_task(
|
||||||
|
create_audit_log_for_update(
|
||||||
|
request_data=LiteLLM_AuditLogs(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
updated_at=datetime.now(timezone.utc),
|
||||||
|
changed_by=litellm_changed_by
|
||||||
|
or user_api_key_dict.user_id
|
||||||
|
or litellm_proxy_admin_name,
|
||||||
|
changed_by_api_key=user_api_key_dict.api_key,
|
||||||
|
table_name=LitellmTableNames.KEY_TABLE_NAME,
|
||||||
|
object_id=key,
|
||||||
|
action="deleted",
|
||||||
|
updated_values="{}",
|
||||||
|
before_value=_key_row,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# delete the keys from the secret manager
|
||||||
|
await KeyManagementEventHooks._delete_virtual_keys_from_secret_manager(
|
||||||
|
keys_being_deleted=keys_being_deleted
|
||||||
|
)
|
||||||
|
pass
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def _store_virtual_key_in_secret_manager(secret_name: str, secret_token: str):
|
||||||
|
"""
|
||||||
|
Store a virtual key in the secret manager
|
||||||
|
|
||||||
|
Args:
|
||||||
|
secret_name: Name of the virtual key
|
||||||
|
secret_token: Value of the virtual key (example: sk-1234)
|
||||||
|
"""
|
||||||
|
if litellm._key_management_settings is not None:
|
||||||
|
if litellm._key_management_settings.store_virtual_keys is True:
|
||||||
|
from litellm.secret_managers.aws_secret_manager_v2 import (
|
||||||
|
AWSSecretsManagerV2,
|
||||||
|
)
|
||||||
|
|
||||||
|
# store the key in the secret manager
|
||||||
|
if (
|
||||||
|
litellm._key_management_system
|
||||||
|
== KeyManagementSystem.AWS_SECRET_MANAGER
|
||||||
|
and isinstance(litellm.secret_manager_client, AWSSecretsManagerV2)
|
||||||
|
):
|
||||||
|
await litellm.secret_manager_client.async_write_secret(
|
||||||
|
secret_name=secret_name,
|
||||||
|
secret_value=secret_token,
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def _delete_virtual_keys_from_secret_manager(
|
||||||
|
keys_being_deleted: List[LiteLLM_VerificationToken],
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Deletes virtual keys from the secret manager
|
||||||
|
|
||||||
|
Args:
|
||||||
|
keys_being_deleted: List of keys being deleted, this is passed down from the /key/delete operation
|
||||||
|
"""
|
||||||
|
if litellm._key_management_settings is not None:
|
||||||
|
if litellm._key_management_settings.store_virtual_keys is True:
|
||||||
|
from litellm.secret_managers.aws_secret_manager_v2 import (
|
||||||
|
AWSSecretsManagerV2,
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(litellm.secret_manager_client, AWSSecretsManagerV2):
|
||||||
|
for key in keys_being_deleted:
|
||||||
|
if key.key_alias is not None:
|
||||||
|
await litellm.secret_manager_client.async_delete_secret(
|
||||||
|
secret_name=key.key_alias
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
verbose_proxy_logger.warning(
|
||||||
|
f"KeyManagementEventHooks._delete_virtual_key_from_secret_manager: Key alias not found for key {key.token}. Skipping deletion from secret manager."
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def _send_key_created_email(response: dict):
|
||||||
|
from litellm.proxy.proxy_server import general_settings, proxy_logging_obj
|
||||||
|
|
||||||
|
if "email" not in general_settings.get("alerting", []):
|
||||||
|
raise ValueError(
|
||||||
|
"Email alerting not setup on config.yaml. Please set `alerting=['email']. \nDocs: https://docs.litellm.ai/docs/proxy/email`"
|
||||||
|
)
|
||||||
|
event = WebhookEvent(
|
||||||
|
event="key_created",
|
||||||
|
event_group="key",
|
||||||
|
event_message="API Key Created",
|
||||||
|
token=response.get("token", ""),
|
||||||
|
spend=response.get("spend", 0.0),
|
||||||
|
max_budget=response.get("max_budget", 0.0),
|
||||||
|
user_id=response.get("user_id", None),
|
||||||
|
team_id=response.get("team_id", "Default Team"),
|
||||||
|
key_alias=response.get("key_alias", None),
|
||||||
|
)
|
||||||
|
|
||||||
|
# If user configured email alerting - send an Email letting their end-user know the key was created
|
||||||
|
asyncio.create_task(
|
||||||
|
proxy_logging_obj.slack_alerting_instance.send_key_created_or_user_invited_email(
|
||||||
|
webhook_event=event,
|
||||||
|
)
|
||||||
|
)
|
|
@ -274,6 +274,51 @@ class LiteLLMProxyRequestSetup:
|
||||||
)
|
)
|
||||||
return user_api_key_logged_metadata
|
return user_api_key_logged_metadata
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def add_key_level_controls(
|
||||||
|
key_metadata: dict, data: dict, _metadata_variable_name: str
|
||||||
|
):
|
||||||
|
data = data.copy()
|
||||||
|
if "cache" in key_metadata:
|
||||||
|
data["cache"] = {}
|
||||||
|
if isinstance(key_metadata["cache"], dict):
|
||||||
|
for k, v in key_metadata["cache"].items():
|
||||||
|
if k in SupportedCacheControls:
|
||||||
|
data["cache"][k] = v
|
||||||
|
|
||||||
|
## KEY-LEVEL SPEND LOGS / TAGS
|
||||||
|
if "tags" in key_metadata and key_metadata["tags"] is not None:
|
||||||
|
if "tags" in data[_metadata_variable_name] and isinstance(
|
||||||
|
data[_metadata_variable_name]["tags"], list
|
||||||
|
):
|
||||||
|
data[_metadata_variable_name]["tags"].extend(key_metadata["tags"])
|
||||||
|
else:
|
||||||
|
data[_metadata_variable_name]["tags"] = key_metadata["tags"]
|
||||||
|
if "spend_logs_metadata" in key_metadata and isinstance(
|
||||||
|
key_metadata["spend_logs_metadata"], dict
|
||||||
|
):
|
||||||
|
if "spend_logs_metadata" in data[_metadata_variable_name] and isinstance(
|
||||||
|
data[_metadata_variable_name]["spend_logs_metadata"], dict
|
||||||
|
):
|
||||||
|
for key, value in key_metadata["spend_logs_metadata"].items():
|
||||||
|
if (
|
||||||
|
key not in data[_metadata_variable_name]["spend_logs_metadata"]
|
||||||
|
): # don't override k-v pair sent by request (user request)
|
||||||
|
data[_metadata_variable_name]["spend_logs_metadata"][
|
||||||
|
key
|
||||||
|
] = value
|
||||||
|
else:
|
||||||
|
data[_metadata_variable_name]["spend_logs_metadata"] = key_metadata[
|
||||||
|
"spend_logs_metadata"
|
||||||
|
]
|
||||||
|
|
||||||
|
## KEY-LEVEL DISABLE FALLBACKS
|
||||||
|
if "disable_fallbacks" in key_metadata and isinstance(
|
||||||
|
key_metadata["disable_fallbacks"], bool
|
||||||
|
):
|
||||||
|
data["disable_fallbacks"] = key_metadata["disable_fallbacks"]
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
async def add_litellm_data_to_request( # noqa: PLR0915
|
async def add_litellm_data_to_request( # noqa: PLR0915
|
||||||
data: dict,
|
data: dict,
|
||||||
|
@ -389,37 +434,11 @@ async def add_litellm_data_to_request( # noqa: PLR0915
|
||||||
|
|
||||||
### KEY-LEVEL Controls
|
### KEY-LEVEL Controls
|
||||||
key_metadata = user_api_key_dict.metadata
|
key_metadata = user_api_key_dict.metadata
|
||||||
if "cache" in key_metadata:
|
data = LiteLLMProxyRequestSetup.add_key_level_controls(
|
||||||
data["cache"] = {}
|
key_metadata=key_metadata,
|
||||||
if isinstance(key_metadata["cache"], dict):
|
data=data,
|
||||||
for k, v in key_metadata["cache"].items():
|
_metadata_variable_name=_metadata_variable_name,
|
||||||
if k in SupportedCacheControls:
|
)
|
||||||
data["cache"][k] = v
|
|
||||||
|
|
||||||
## KEY-LEVEL SPEND LOGS / TAGS
|
|
||||||
if "tags" in key_metadata and key_metadata["tags"] is not None:
|
|
||||||
if "tags" in data[_metadata_variable_name] and isinstance(
|
|
||||||
data[_metadata_variable_name]["tags"], list
|
|
||||||
):
|
|
||||||
data[_metadata_variable_name]["tags"].extend(key_metadata["tags"])
|
|
||||||
else:
|
|
||||||
data[_metadata_variable_name]["tags"] = key_metadata["tags"]
|
|
||||||
if "spend_logs_metadata" in key_metadata and isinstance(
|
|
||||||
key_metadata["spend_logs_metadata"], dict
|
|
||||||
):
|
|
||||||
if "spend_logs_metadata" in data[_metadata_variable_name] and isinstance(
|
|
||||||
data[_metadata_variable_name]["spend_logs_metadata"], dict
|
|
||||||
):
|
|
||||||
for key, value in key_metadata["spend_logs_metadata"].items():
|
|
||||||
if (
|
|
||||||
key not in data[_metadata_variable_name]["spend_logs_metadata"]
|
|
||||||
): # don't override k-v pair sent by request (user request)
|
|
||||||
data[_metadata_variable_name]["spend_logs_metadata"][key] = value
|
|
||||||
else:
|
|
||||||
data[_metadata_variable_name]["spend_logs_metadata"] = key_metadata[
|
|
||||||
"spend_logs_metadata"
|
|
||||||
]
|
|
||||||
|
|
||||||
## TEAM-LEVEL SPEND LOGS/TAGS
|
## TEAM-LEVEL SPEND LOGS/TAGS
|
||||||
team_metadata = user_api_key_dict.team_metadata or {}
|
team_metadata = user_api_key_dict.team_metadata or {}
|
||||||
if "tags" in team_metadata and team_metadata["tags"] is not None:
|
if "tags" in team_metadata and team_metadata["tags"] is not None:
|
||||||
|
|
|
@ -17,7 +17,7 @@ import secrets
|
||||||
import traceback
|
import traceback
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
from typing import List, Optional
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
import fastapi
|
import fastapi
|
||||||
from fastapi import APIRouter, Depends, Header, HTTPException, Query, Request, status
|
from fastapi import APIRouter, Depends, Header, HTTPException, Query, Request, status
|
||||||
|
@ -31,6 +31,7 @@ from litellm.proxy.auth.auth_checks import (
|
||||||
get_key_object,
|
get_key_object,
|
||||||
)
|
)
|
||||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||||
|
from litellm.proxy.hooks.key_management_event_hooks import KeyManagementEventHooks
|
||||||
from litellm.proxy.management_helpers.utils import management_endpoint_wrapper
|
from litellm.proxy.management_helpers.utils import management_endpoint_wrapper
|
||||||
from litellm.proxy.utils import _duration_in_seconds, _hash_token_if_needed
|
from litellm.proxy.utils import _duration_in_seconds, _hash_token_if_needed
|
||||||
from litellm.secret_managers.main import get_secret
|
from litellm.secret_managers.main import get_secret
|
||||||
|
@ -234,48 +235,12 @@ async def generate_key_fn( # noqa: PLR0915
|
||||||
data.soft_budget
|
data.soft_budget
|
||||||
) # include the user-input soft budget in the response
|
) # include the user-input soft budget in the response
|
||||||
|
|
||||||
if data.send_invite_email is True:
|
|
||||||
if "email" not in general_settings.get("alerting", []):
|
|
||||||
raise ValueError(
|
|
||||||
"Email alerting not setup on config.yaml. Please set `alerting=['email']. \nDocs: https://docs.litellm.ai/docs/proxy/email`"
|
|
||||||
)
|
|
||||||
event = WebhookEvent(
|
|
||||||
event="key_created",
|
|
||||||
event_group="key",
|
|
||||||
event_message="API Key Created",
|
|
||||||
token=response.get("token", ""),
|
|
||||||
spend=response.get("spend", 0.0),
|
|
||||||
max_budget=response.get("max_budget", 0.0),
|
|
||||||
user_id=response.get("user_id", None),
|
|
||||||
team_id=response.get("team_id", "Default Team"),
|
|
||||||
key_alias=response.get("key_alias", None),
|
|
||||||
)
|
|
||||||
|
|
||||||
# If user configured email alerting - send an Email letting their end-user know the key was created
|
|
||||||
asyncio.create_task(
|
asyncio.create_task(
|
||||||
proxy_logging_obj.slack_alerting_instance.send_key_created_or_user_invited_email(
|
KeyManagementEventHooks.async_key_generated_hook(
|
||||||
webhook_event=event,
|
data=data,
|
||||||
)
|
response=response,
|
||||||
)
|
user_api_key_dict=user_api_key_dict,
|
||||||
|
litellm_changed_by=litellm_changed_by,
|
||||||
# Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True
|
|
||||||
if litellm.store_audit_logs is True:
|
|
||||||
_updated_values = json.dumps(response, default=str)
|
|
||||||
asyncio.create_task(
|
|
||||||
create_audit_log_for_update(
|
|
||||||
request_data=LiteLLM_AuditLogs(
|
|
||||||
id=str(uuid.uuid4()),
|
|
||||||
updated_at=datetime.now(timezone.utc),
|
|
||||||
changed_by=litellm_changed_by
|
|
||||||
or user_api_key_dict.user_id
|
|
||||||
or litellm_proxy_admin_name,
|
|
||||||
changed_by_api_key=user_api_key_dict.api_key,
|
|
||||||
table_name=LitellmTableNames.KEY_TABLE_NAME,
|
|
||||||
object_id=response.get("token_id", ""),
|
|
||||||
action="created",
|
|
||||||
updated_values=_updated_values,
|
|
||||||
before_value=None,
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -407,28 +372,13 @@ async def update_key_fn(
|
||||||
proxy_logging_obj=proxy_logging_obj,
|
proxy_logging_obj=proxy_logging_obj,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True
|
|
||||||
if litellm.store_audit_logs is True:
|
|
||||||
_updated_values = json.dumps(data_json, default=str)
|
|
||||||
|
|
||||||
_before_value = existing_key_row.json(exclude_none=True)
|
|
||||||
_before_value = json.dumps(_before_value, default=str)
|
|
||||||
|
|
||||||
asyncio.create_task(
|
asyncio.create_task(
|
||||||
create_audit_log_for_update(
|
KeyManagementEventHooks.async_key_updated_hook(
|
||||||
request_data=LiteLLM_AuditLogs(
|
data=data,
|
||||||
id=str(uuid.uuid4()),
|
existing_key_row=existing_key_row,
|
||||||
updated_at=datetime.now(timezone.utc),
|
response=response,
|
||||||
changed_by=litellm_changed_by
|
user_api_key_dict=user_api_key_dict,
|
||||||
or user_api_key_dict.user_id
|
litellm_changed_by=litellm_changed_by,
|
||||||
or litellm_proxy_admin_name,
|
|
||||||
changed_by_api_key=user_api_key_dict.api_key,
|
|
||||||
table_name=LitellmTableNames.KEY_TABLE_NAME,
|
|
||||||
object_id=data.key,
|
|
||||||
action="updated",
|
|
||||||
updated_values=_updated_values,
|
|
||||||
before_value=_before_value,
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -496,6 +446,9 @@ async def delete_key_fn(
|
||||||
user_custom_key_generate,
|
user_custom_key_generate,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if prisma_client is None:
|
||||||
|
raise Exception("Not connected to DB!")
|
||||||
|
|
||||||
keys = data.keys
|
keys = data.keys
|
||||||
if len(keys) == 0:
|
if len(keys) == 0:
|
||||||
raise ProxyException(
|
raise ProxyException(
|
||||||
|
@ -516,45 +469,7 @@ async def delete_key_fn(
|
||||||
):
|
):
|
||||||
user_id = None # unless they're admin
|
user_id = None # unless they're admin
|
||||||
|
|
||||||
# Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True
|
number_deleted_keys, _keys_being_deleted = await delete_verification_token(
|
||||||
# we do this after the first for loop, since first for loop is for validation. we only want this inserted after validation passes
|
|
||||||
if litellm.store_audit_logs is True:
|
|
||||||
# make an audit log for each team deleted
|
|
||||||
for key in data.keys:
|
|
||||||
key_row = await prisma_client.get_data( # type: ignore
|
|
||||||
token=key, table_name="key", query_type="find_unique"
|
|
||||||
)
|
|
||||||
|
|
||||||
if key_row is None:
|
|
||||||
raise ProxyException(
|
|
||||||
message=f"Key {key} not found",
|
|
||||||
type=ProxyErrorTypes.bad_request_error,
|
|
||||||
param="key",
|
|
||||||
code=status.HTTP_404_NOT_FOUND,
|
|
||||||
)
|
|
||||||
|
|
||||||
key_row = key_row.json(exclude_none=True)
|
|
||||||
_key_row = json.dumps(key_row, default=str)
|
|
||||||
|
|
||||||
asyncio.create_task(
|
|
||||||
create_audit_log_for_update(
|
|
||||||
request_data=LiteLLM_AuditLogs(
|
|
||||||
id=str(uuid.uuid4()),
|
|
||||||
updated_at=datetime.now(timezone.utc),
|
|
||||||
changed_by=litellm_changed_by
|
|
||||||
or user_api_key_dict.user_id
|
|
||||||
or litellm_proxy_admin_name,
|
|
||||||
changed_by_api_key=user_api_key_dict.api_key,
|
|
||||||
table_name=LitellmTableNames.KEY_TABLE_NAME,
|
|
||||||
object_id=key,
|
|
||||||
action="deleted",
|
|
||||||
updated_values="{}",
|
|
||||||
before_value=_key_row,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
number_deleted_keys = await delete_verification_token(
|
|
||||||
tokens=keys, user_id=user_id
|
tokens=keys, user_id=user_id
|
||||||
)
|
)
|
||||||
if number_deleted_keys is None:
|
if number_deleted_keys is None:
|
||||||
|
@ -588,6 +503,16 @@ async def delete_key_fn(
|
||||||
f"/keys/delete - cache after delete: {user_api_key_cache.in_memory_cache.cache_dict}"
|
f"/keys/delete - cache after delete: {user_api_key_cache.in_memory_cache.cache_dict}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
asyncio.create_task(
|
||||||
|
KeyManagementEventHooks.async_key_deleted_hook(
|
||||||
|
data=data,
|
||||||
|
keys_being_deleted=_keys_being_deleted,
|
||||||
|
user_api_key_dict=user_api_key_dict,
|
||||||
|
litellm_changed_by=litellm_changed_by,
|
||||||
|
response=number_deleted_keys,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
return {"deleted_keys": keys}
|
return {"deleted_keys": keys}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if isinstance(e, HTTPException):
|
if isinstance(e, HTTPException):
|
||||||
|
@ -1026,11 +951,35 @@ async def generate_key_helper_fn( # noqa: PLR0915
|
||||||
return key_data
|
return key_data
|
||||||
|
|
||||||
|
|
||||||
async def delete_verification_token(tokens: List, user_id: Optional[str] = None):
|
async def delete_verification_token(
|
||||||
|
tokens: List, user_id: Optional[str] = None
|
||||||
|
) -> Tuple[Optional[Dict], List[LiteLLM_VerificationToken]]:
|
||||||
|
"""
|
||||||
|
Helper that deletes the list of tokens from the database
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tokens: List of tokens to delete
|
||||||
|
user_id: Optional user_id to filter by
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[Optional[Dict], List[LiteLLM_VerificationToken]]:
|
||||||
|
Optional[Dict]:
|
||||||
|
- Number of deleted tokens
|
||||||
|
List[LiteLLM_VerificationToken]:
|
||||||
|
- List of keys being deleted, this contains information about the key_alias, token, and user_id being deleted,
|
||||||
|
this is passed down to the KeyManagementEventHooks to delete the keys from the secret manager and handle audit logs
|
||||||
|
"""
|
||||||
from litellm.proxy.proxy_server import litellm_proxy_admin_name, prisma_client
|
from litellm.proxy.proxy_server import litellm_proxy_admin_name, prisma_client
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if prisma_client:
|
if prisma_client:
|
||||||
|
tokens = [_hash_token_if_needed(token=key) for key in tokens]
|
||||||
|
_keys_being_deleted = (
|
||||||
|
await prisma_client.db.litellm_verificationtoken.find_many(
|
||||||
|
where={"token": {"in": tokens}}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# Assuming 'db' is your Prisma Client instance
|
# Assuming 'db' is your Prisma Client instance
|
||||||
# check if admin making request - don't filter by user-id
|
# check if admin making request - don't filter by user-id
|
||||||
if user_id == litellm_proxy_admin_name:
|
if user_id == litellm_proxy_admin_name:
|
||||||
|
@ -1060,7 +1009,7 @@ async def delete_verification_token(tokens: List, user_id: Optional[str] = None)
|
||||||
)
|
)
|
||||||
verbose_proxy_logger.debug(traceback.format_exc())
|
verbose_proxy_logger.debug(traceback.format_exc())
|
||||||
raise e
|
raise e
|
||||||
return deleted_tokens
|
return deleted_tokens, _keys_being_deleted
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
@router.post(
|
||||||
|
|
|
@ -265,7 +265,6 @@ def run_server( # noqa: PLR0915
|
||||||
ProxyConfig,
|
ProxyConfig,
|
||||||
app,
|
app,
|
||||||
load_aws_kms,
|
load_aws_kms,
|
||||||
load_aws_secret_manager,
|
|
||||||
load_from_azure_key_vault,
|
load_from_azure_key_vault,
|
||||||
load_google_kms,
|
load_google_kms,
|
||||||
save_worker_config,
|
save_worker_config,
|
||||||
|
@ -278,7 +277,6 @@ def run_server( # noqa: PLR0915
|
||||||
ProxyConfig,
|
ProxyConfig,
|
||||||
app,
|
app,
|
||||||
load_aws_kms,
|
load_aws_kms,
|
||||||
load_aws_secret_manager,
|
|
||||||
load_from_azure_key_vault,
|
load_from_azure_key_vault,
|
||||||
load_google_kms,
|
load_google_kms,
|
||||||
save_worker_config,
|
save_worker_config,
|
||||||
|
@ -295,7 +293,6 @@ def run_server( # noqa: PLR0915
|
||||||
ProxyConfig,
|
ProxyConfig,
|
||||||
app,
|
app,
|
||||||
load_aws_kms,
|
load_aws_kms,
|
||||||
load_aws_secret_manager,
|
|
||||||
load_from_azure_key_vault,
|
load_from_azure_key_vault,
|
||||||
load_google_kms,
|
load_google_kms,
|
||||||
save_worker_config,
|
save_worker_config,
|
||||||
|
@ -559,8 +556,14 @@ def run_server( # noqa: PLR0915
|
||||||
key_management_system
|
key_management_system
|
||||||
== KeyManagementSystem.AWS_SECRET_MANAGER.value # noqa: F405
|
== KeyManagementSystem.AWS_SECRET_MANAGER.value # noqa: F405
|
||||||
):
|
):
|
||||||
|
from litellm.secret_managers.aws_secret_manager_v2 import (
|
||||||
|
AWSSecretsManagerV2,
|
||||||
|
)
|
||||||
|
|
||||||
### LOAD FROM AWS SECRET MANAGER ###
|
### LOAD FROM AWS SECRET MANAGER ###
|
||||||
load_aws_secret_manager(use_aws_secret_manager=True)
|
AWSSecretsManagerV2.load_aws_secret_manager(
|
||||||
|
use_aws_secret_manager=True
|
||||||
|
)
|
||||||
elif key_management_system == KeyManagementSystem.AWS_KMS.value:
|
elif key_management_system == KeyManagementSystem.AWS_KMS.value:
|
||||||
load_aws_kms(use_aws_kms=True)
|
load_aws_kms(use_aws_kms=True)
|
||||||
elif (
|
elif (
|
||||||
|
|
|
@ -7,6 +7,8 @@ model_list:
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
litellm_settings:
|
general_settings:
|
||||||
callbacks: ["gcs_bucket"]
|
key_management_system: "aws_secret_manager"
|
||||||
|
key_management_settings:
|
||||||
|
store_virtual_keys: true
|
||||||
|
access_mode: "write_only"
|
||||||
|
|
|
@ -245,10 +245,7 @@ from litellm.router import (
|
||||||
from litellm.router import ModelInfo as RouterModelInfo
|
from litellm.router import ModelInfo as RouterModelInfo
|
||||||
from litellm.router import updateDeployment
|
from litellm.router import updateDeployment
|
||||||
from litellm.scheduler import DefaultPriorities, FlowItem, Scheduler
|
from litellm.scheduler import DefaultPriorities, FlowItem, Scheduler
|
||||||
from litellm.secret_managers.aws_secret_manager import (
|
from litellm.secret_managers.aws_secret_manager import load_aws_kms
|
||||||
load_aws_kms,
|
|
||||||
load_aws_secret_manager,
|
|
||||||
)
|
|
||||||
from litellm.secret_managers.google_kms import load_google_kms
|
from litellm.secret_managers.google_kms import load_google_kms
|
||||||
from litellm.secret_managers.main import (
|
from litellm.secret_managers.main import (
|
||||||
get_secret,
|
get_secret,
|
||||||
|
@ -1825,8 +1822,13 @@ class ProxyConfig:
|
||||||
key_management_system
|
key_management_system
|
||||||
== KeyManagementSystem.AWS_SECRET_MANAGER.value # noqa: F405
|
== KeyManagementSystem.AWS_SECRET_MANAGER.value # noqa: F405
|
||||||
):
|
):
|
||||||
### LOAD FROM AWS SECRET MANAGER ###
|
from litellm.secret_managers.aws_secret_manager_v2 import (
|
||||||
load_aws_secret_manager(use_aws_secret_manager=True)
|
AWSSecretsManagerV2,
|
||||||
|
)
|
||||||
|
|
||||||
|
AWSSecretsManagerV2.load_aws_secret_manager(
|
||||||
|
use_aws_secret_manager=True
|
||||||
|
)
|
||||||
elif key_management_system == KeyManagementSystem.AWS_KMS.value:
|
elif key_management_system == KeyManagementSystem.AWS_KMS.value:
|
||||||
load_aws_kms(use_aws_kms=True)
|
load_aws_kms(use_aws_kms=True)
|
||||||
elif (
|
elif (
|
||||||
|
|
|
@ -8,7 +8,8 @@ from litellm._logging import verbose_logger
|
||||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||||
from litellm.llms.azure_ai.rerank import AzureAIRerank
|
from litellm.llms.azure_ai.rerank import AzureAIRerank
|
||||||
from litellm.llms.cohere.rerank import CohereRerank
|
from litellm.llms.cohere.rerank import CohereRerank
|
||||||
from litellm.llms.together_ai.rerank import TogetherAIRerank
|
from litellm.llms.jina_ai.rerank.handler import JinaAIRerank
|
||||||
|
from litellm.llms.together_ai.rerank.handler import TogetherAIRerank
|
||||||
from litellm.secret_managers.main import get_secret
|
from litellm.secret_managers.main import get_secret
|
||||||
from litellm.types.rerank import RerankRequest, RerankResponse
|
from litellm.types.rerank import RerankRequest, RerankResponse
|
||||||
from litellm.types.router import *
|
from litellm.types.router import *
|
||||||
|
@ -19,6 +20,7 @@ from litellm.utils import client, exception_type, supports_httpx_timeout
|
||||||
cohere_rerank = CohereRerank()
|
cohere_rerank = CohereRerank()
|
||||||
together_rerank = TogetherAIRerank()
|
together_rerank = TogetherAIRerank()
|
||||||
azure_ai_rerank = AzureAIRerank()
|
azure_ai_rerank = AzureAIRerank()
|
||||||
|
jina_ai_rerank = JinaAIRerank()
|
||||||
#################################################
|
#################################################
|
||||||
|
|
||||||
|
|
||||||
|
@ -247,7 +249,23 @@ def rerank(
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
_is_async=_is_async,
|
_is_async=_is_async,
|
||||||
)
|
)
|
||||||
|
elif _custom_llm_provider == "jina_ai":
|
||||||
|
|
||||||
|
if dynamic_api_key is None:
|
||||||
|
raise ValueError(
|
||||||
|
"Jina AI API key is required, please set 'JINA_AI_API_KEY' in your environment"
|
||||||
|
)
|
||||||
|
response = jina_ai_rerank.rerank(
|
||||||
|
model=model,
|
||||||
|
api_key=dynamic_api_key,
|
||||||
|
query=query,
|
||||||
|
documents=documents,
|
||||||
|
top_n=top_n,
|
||||||
|
rank_fields=rank_fields,
|
||||||
|
return_documents=return_documents,
|
||||||
|
max_chunks_per_doc=max_chunks_per_doc,
|
||||||
|
_is_async=_is_async,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported provider: {_custom_llm_provider}")
|
raise ValueError(f"Unsupported provider: {_custom_llm_provider}")
|
||||||
|
|
||||||
|
|
|
@ -679,9 +679,8 @@ class Router:
|
||||||
kwargs["model"] = model
|
kwargs["model"] = model
|
||||||
kwargs["messages"] = messages
|
kwargs["messages"] = messages
|
||||||
kwargs["original_function"] = self._completion
|
kwargs["original_function"] = self._completion
|
||||||
kwargs.get("request_timeout", self.timeout)
|
self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs)
|
||||||
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
|
|
||||||
kwargs.setdefault("metadata", {}).update({"model_group": model})
|
|
||||||
response = self.function_with_fallbacks(**kwargs)
|
response = self.function_with_fallbacks(**kwargs)
|
||||||
return response
|
return response
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -783,8 +782,7 @@ class Router:
|
||||||
kwargs["stream"] = stream
|
kwargs["stream"] = stream
|
||||||
kwargs["original_function"] = self._acompletion
|
kwargs["original_function"] = self._acompletion
|
||||||
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
|
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
|
||||||
|
self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs)
|
||||||
kwargs.setdefault("metadata", {}).update({"model_group": model})
|
|
||||||
|
|
||||||
request_priority = kwargs.get("priority") or self.default_priority
|
request_priority = kwargs.get("priority") or self.default_priority
|
||||||
|
|
||||||
|
@ -948,6 +946,17 @@ class Router:
|
||||||
self.fail_calls[model_name] += 1
|
self.fail_calls[model_name] += 1
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
def _update_kwargs_before_fallbacks(self, model: str, kwargs: dict) -> None:
|
||||||
|
"""
|
||||||
|
Adds/updates to kwargs:
|
||||||
|
- num_retries
|
||||||
|
- litellm_trace_id
|
||||||
|
- metadata
|
||||||
|
"""
|
||||||
|
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
|
||||||
|
kwargs.setdefault("litellm_trace_id", str(uuid.uuid4()))
|
||||||
|
kwargs.setdefault("metadata", {}).update({"model_group": model})
|
||||||
|
|
||||||
def _update_kwargs_with_default_litellm_params(self, kwargs: dict) -> None:
|
def _update_kwargs_with_default_litellm_params(self, kwargs: dict) -> None:
|
||||||
"""
|
"""
|
||||||
Adds default litellm params to kwargs, if set.
|
Adds default litellm params to kwargs, if set.
|
||||||
|
@ -1511,9 +1520,7 @@ class Router:
|
||||||
kwargs["model"] = model
|
kwargs["model"] = model
|
||||||
kwargs["file"] = file
|
kwargs["file"] = file
|
||||||
kwargs["original_function"] = self._atranscription
|
kwargs["original_function"] = self._atranscription
|
||||||
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
|
self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs)
|
||||||
kwargs.get("request_timeout", self.timeout)
|
|
||||||
kwargs.setdefault("metadata", {}).update({"model_group": model})
|
|
||||||
response = await self.async_function_with_fallbacks(**kwargs)
|
response = await self.async_function_with_fallbacks(**kwargs)
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
@ -1688,9 +1695,7 @@ class Router:
|
||||||
kwargs["model"] = model
|
kwargs["model"] = model
|
||||||
kwargs["input"] = input
|
kwargs["input"] = input
|
||||||
kwargs["original_function"] = self._arerank
|
kwargs["original_function"] = self._arerank
|
||||||
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
|
self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs)
|
||||||
kwargs.get("request_timeout", self.timeout)
|
|
||||||
kwargs.setdefault("metadata", {}).update({"model_group": model})
|
|
||||||
|
|
||||||
response = await self.async_function_with_fallbacks(**kwargs)
|
response = await self.async_function_with_fallbacks(**kwargs)
|
||||||
|
|
||||||
|
@ -1839,9 +1844,7 @@ class Router:
|
||||||
kwargs["model"] = model
|
kwargs["model"] = model
|
||||||
kwargs["prompt"] = prompt
|
kwargs["prompt"] = prompt
|
||||||
kwargs["original_function"] = self._atext_completion
|
kwargs["original_function"] = self._atext_completion
|
||||||
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
|
self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs)
|
||||||
kwargs.get("request_timeout", self.timeout)
|
|
||||||
kwargs.setdefault("metadata", {}).update({"model_group": model})
|
|
||||||
response = await self.async_function_with_fallbacks(**kwargs)
|
response = await self.async_function_with_fallbacks(**kwargs)
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
@ -2112,9 +2115,7 @@ class Router:
|
||||||
kwargs["model"] = model
|
kwargs["model"] = model
|
||||||
kwargs["input"] = input
|
kwargs["input"] = input
|
||||||
kwargs["original_function"] = self._aembedding
|
kwargs["original_function"] = self._aembedding
|
||||||
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
|
self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs)
|
||||||
kwargs.get("request_timeout", self.timeout)
|
|
||||||
kwargs.setdefault("metadata", {}).update({"model_group": model})
|
|
||||||
response = await self.async_function_with_fallbacks(**kwargs)
|
response = await self.async_function_with_fallbacks(**kwargs)
|
||||||
return response
|
return response
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -2609,6 +2610,7 @@ class Router:
|
||||||
If it fails after num_retries, fall back to another model group
|
If it fails after num_retries, fall back to another model group
|
||||||
"""
|
"""
|
||||||
model_group: Optional[str] = kwargs.get("model")
|
model_group: Optional[str] = kwargs.get("model")
|
||||||
|
disable_fallbacks: Optional[bool] = kwargs.pop("disable_fallbacks", False)
|
||||||
fallbacks: Optional[List] = kwargs.get("fallbacks", self.fallbacks)
|
fallbacks: Optional[List] = kwargs.get("fallbacks", self.fallbacks)
|
||||||
context_window_fallbacks: Optional[List] = kwargs.get(
|
context_window_fallbacks: Optional[List] = kwargs.get(
|
||||||
"context_window_fallbacks", self.context_window_fallbacks
|
"context_window_fallbacks", self.context_window_fallbacks
|
||||||
|
@ -2616,6 +2618,7 @@ class Router:
|
||||||
content_policy_fallbacks: Optional[List] = kwargs.get(
|
content_policy_fallbacks: Optional[List] = kwargs.get(
|
||||||
"content_policy_fallbacks", self.content_policy_fallbacks
|
"content_policy_fallbacks", self.content_policy_fallbacks
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self._handle_mock_testing_fallbacks(
|
self._handle_mock_testing_fallbacks(
|
||||||
kwargs=kwargs,
|
kwargs=kwargs,
|
||||||
|
@ -2635,7 +2638,7 @@ class Router:
|
||||||
original_model_group: Optional[str] = kwargs.get("model") # type: ignore
|
original_model_group: Optional[str] = kwargs.get("model") # type: ignore
|
||||||
fallback_failure_exception_str = ""
|
fallback_failure_exception_str = ""
|
||||||
|
|
||||||
if original_model_group is None:
|
if disable_fallbacks is True or original_model_group is None:
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
input_kwargs = {
|
input_kwargs = {
|
||||||
|
|
|
@ -23,28 +23,6 @@ def validate_environment():
|
||||||
raise ValueError("Missing required environment variable - AWS_REGION_NAME")
|
raise ValueError("Missing required environment variable - AWS_REGION_NAME")
|
||||||
|
|
||||||
|
|
||||||
def load_aws_secret_manager(use_aws_secret_manager: Optional[bool]):
|
|
||||||
if use_aws_secret_manager is None or use_aws_secret_manager is False:
|
|
||||||
return
|
|
||||||
try:
|
|
||||||
import boto3
|
|
||||||
from botocore.exceptions import ClientError
|
|
||||||
|
|
||||||
validate_environment()
|
|
||||||
|
|
||||||
# Create a Secrets Manager client
|
|
||||||
session = boto3.session.Session() # type: ignore
|
|
||||||
client = session.client(
|
|
||||||
service_name="secretsmanager", region_name=os.getenv("AWS_REGION_NAME")
|
|
||||||
)
|
|
||||||
|
|
||||||
litellm.secret_manager_client = client
|
|
||||||
litellm._key_management_system = KeyManagementSystem.AWS_SECRET_MANAGER
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
raise e
|
|
||||||
|
|
||||||
|
|
||||||
def load_aws_kms(use_aws_kms: Optional[bool]):
|
def load_aws_kms(use_aws_kms: Optional[bool]):
|
||||||
if use_aws_kms is None or use_aws_kms is False:
|
if use_aws_kms is None or use_aws_kms is False:
|
||||||
return
|
return
|
||||||
|
|
310
litellm/secret_managers/aws_secret_manager_v2.py
Normal file
310
litellm/secret_managers/aws_secret_manager_v2.py
Normal file
|
@ -0,0 +1,310 @@
|
||||||
|
"""
|
||||||
|
This is a file for the AWS Secret Manager Integration
|
||||||
|
|
||||||
|
Handles Async Operations for:
|
||||||
|
- Read Secret
|
||||||
|
- Write Secret
|
||||||
|
- Delete Secret
|
||||||
|
|
||||||
|
Relevant issue: https://github.com/BerriAI/litellm/issues/1883
|
||||||
|
|
||||||
|
Requires:
|
||||||
|
* `os.environ["AWS_REGION_NAME"],
|
||||||
|
* `pip install boto3>=1.28.57`
|
||||||
|
"""
|
||||||
|
|
||||||
|
import ast
|
||||||
|
import asyncio
|
||||||
|
import base64
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import sys
|
||||||
|
from typing import Any, Dict, Optional, Union
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
import litellm
|
||||||
|
from litellm._logging import verbose_logger
|
||||||
|
from litellm.llms.base_aws_llm import BaseAWSLLM
|
||||||
|
from litellm.llms.custom_httpx.http_handler import (
|
||||||
|
_get_httpx_client,
|
||||||
|
get_async_httpx_client,
|
||||||
|
)
|
||||||
|
from litellm.llms.custom_httpx.types import httpxSpecialProvider
|
||||||
|
from litellm.proxy._types import KeyManagementSystem
|
||||||
|
|
||||||
|
|
||||||
|
class AWSSecretsManagerV2(BaseAWSLLM):
|
||||||
|
@classmethod
|
||||||
|
def validate_environment(cls):
|
||||||
|
if "AWS_REGION_NAME" not in os.environ:
|
||||||
|
raise ValueError("Missing required environment variable - AWS_REGION_NAME")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def load_aws_secret_manager(cls, use_aws_secret_manager: Optional[bool]):
|
||||||
|
"""
|
||||||
|
Initialize AWSSecretsManagerV2 and sets litellm.secret_manager_client = AWSSecretsManagerV2() and litellm._key_management_system = KeyManagementSystem.AWS_SECRET_MANAGER
|
||||||
|
"""
|
||||||
|
if use_aws_secret_manager is None or use_aws_secret_manager is False:
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
import boto3
|
||||||
|
|
||||||
|
cls.validate_environment()
|
||||||
|
litellm.secret_manager_client = cls()
|
||||||
|
litellm._key_management_system = KeyManagementSystem.AWS_SECRET_MANAGER
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
raise e
|
||||||
|
|
||||||
|
async def async_read_secret(
|
||||||
|
self,
|
||||||
|
secret_name: str,
|
||||||
|
optional_params: Optional[dict] = None,
|
||||||
|
timeout: Optional[Union[float, httpx.Timeout]] = None,
|
||||||
|
) -> Optional[str]:
|
||||||
|
"""
|
||||||
|
Async function to read a secret from AWS Secrets Manager
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: Secret value
|
||||||
|
Raises:
|
||||||
|
ValueError: If the secret is not found or an HTTP error occurs
|
||||||
|
"""
|
||||||
|
endpoint_url, headers, body = self._prepare_request(
|
||||||
|
action="GetSecretValue",
|
||||||
|
secret_name=secret_name,
|
||||||
|
optional_params=optional_params,
|
||||||
|
)
|
||||||
|
|
||||||
|
async_client = get_async_httpx_client(
|
||||||
|
llm_provider=httpxSpecialProvider.SecretManager,
|
||||||
|
params={"timeout": timeout},
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = await async_client.post(
|
||||||
|
url=endpoint_url, headers=headers, data=body.decode("utf-8")
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
return response.json()["SecretString"]
|
||||||
|
except httpx.TimeoutException:
|
||||||
|
raise ValueError("Timeout error occurred")
|
||||||
|
except Exception as e:
|
||||||
|
verbose_logger.exception(
|
||||||
|
"Error reading secret from AWS Secrets Manager: %s", str(e)
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def sync_read_secret(
|
||||||
|
self,
|
||||||
|
secret_name: str,
|
||||||
|
optional_params: Optional[dict] = None,
|
||||||
|
timeout: Optional[Union[float, httpx.Timeout]] = None,
|
||||||
|
) -> Optional[str]:
|
||||||
|
"""
|
||||||
|
Sync function to read a secret from AWS Secrets Manager
|
||||||
|
|
||||||
|
Done for backwards compatibility with existing codebase, since get_secret is a sync function
|
||||||
|
"""
|
||||||
|
|
||||||
|
# self._prepare_request uses these env vars, we cannot read them from AWS Secrets Manager. If we do we'd get stuck in an infinite loop
|
||||||
|
if secret_name in [
|
||||||
|
"AWS_ACCESS_KEY_ID",
|
||||||
|
"AWS_SECRET_ACCESS_KEY",
|
||||||
|
"AWS_REGION_NAME",
|
||||||
|
"AWS_REGION",
|
||||||
|
"AWS_BEDROCK_RUNTIME_ENDPOINT",
|
||||||
|
]:
|
||||||
|
return os.getenv(secret_name)
|
||||||
|
|
||||||
|
endpoint_url, headers, body = self._prepare_request(
|
||||||
|
action="GetSecretValue",
|
||||||
|
secret_name=secret_name,
|
||||||
|
optional_params=optional_params,
|
||||||
|
)
|
||||||
|
|
||||||
|
sync_client = _get_httpx_client(
|
||||||
|
params={"timeout": timeout},
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = sync_client.post(
|
||||||
|
url=endpoint_url, headers=headers, data=body.decode("utf-8")
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
return response.json()["SecretString"]
|
||||||
|
except httpx.TimeoutException:
|
||||||
|
raise ValueError("Timeout error occurred")
|
||||||
|
except Exception as e:
|
||||||
|
verbose_logger.exception(
|
||||||
|
"Error reading secret from AWS Secrets Manager: %s", str(e)
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def async_write_secret(
|
||||||
|
self,
|
||||||
|
secret_name: str,
|
||||||
|
secret_value: str,
|
||||||
|
description: Optional[str] = None,
|
||||||
|
client_request_token: Optional[str] = None,
|
||||||
|
optional_params: Optional[dict] = None,
|
||||||
|
timeout: Optional[Union[float, httpx.Timeout]] = None,
|
||||||
|
) -> dict:
|
||||||
|
"""
|
||||||
|
Async function to write a secret to AWS Secrets Manager
|
||||||
|
|
||||||
|
Args:
|
||||||
|
secret_name: Name of the secret
|
||||||
|
secret_value: Value to store (can be a JSON string)
|
||||||
|
description: Optional description for the secret
|
||||||
|
client_request_token: Optional unique identifier to ensure idempotency
|
||||||
|
optional_params: Additional AWS parameters
|
||||||
|
timeout: Request timeout
|
||||||
|
"""
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
# Prepare the request data
|
||||||
|
data = {"Name": secret_name, "SecretString": secret_value}
|
||||||
|
if description:
|
||||||
|
data["Description"] = description
|
||||||
|
|
||||||
|
data["ClientRequestToken"] = str(uuid.uuid4())
|
||||||
|
|
||||||
|
endpoint_url, headers, body = self._prepare_request(
|
||||||
|
action="CreateSecret",
|
||||||
|
secret_name=secret_name,
|
||||||
|
secret_value=secret_value,
|
||||||
|
optional_params=optional_params,
|
||||||
|
request_data=data, # Pass the complete request data
|
||||||
|
)
|
||||||
|
|
||||||
|
async_client = get_async_httpx_client(
|
||||||
|
llm_provider=httpxSpecialProvider.SecretManager,
|
||||||
|
params={"timeout": timeout},
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = await async_client.post(
|
||||||
|
url=endpoint_url, headers=headers, data=body.decode("utf-8")
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
return response.json()
|
||||||
|
except httpx.HTTPStatusError as err:
|
||||||
|
raise ValueError(f"HTTP error occurred: {err.response.text}")
|
||||||
|
except httpx.TimeoutException:
|
||||||
|
raise ValueError("Timeout error occurred")
|
||||||
|
|
||||||
|
async def async_delete_secret(
|
||||||
|
self,
|
||||||
|
secret_name: str,
|
||||||
|
recovery_window_in_days: Optional[int] = 7,
|
||||||
|
optional_params: Optional[dict] = None,
|
||||||
|
timeout: Optional[Union[float, httpx.Timeout]] = None,
|
||||||
|
) -> dict:
|
||||||
|
"""
|
||||||
|
Async function to delete a secret from AWS Secrets Manager
|
||||||
|
|
||||||
|
Args:
|
||||||
|
secret_name: Name of the secret to delete
|
||||||
|
recovery_window_in_days: Number of days before permanent deletion (default: 7)
|
||||||
|
optional_params: Additional AWS parameters
|
||||||
|
timeout: Request timeout
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: Response from AWS Secrets Manager containing deletion details
|
||||||
|
"""
|
||||||
|
# Prepare the request data
|
||||||
|
data = {
|
||||||
|
"SecretId": secret_name,
|
||||||
|
"RecoveryWindowInDays": recovery_window_in_days,
|
||||||
|
}
|
||||||
|
|
||||||
|
endpoint_url, headers, body = self._prepare_request(
|
||||||
|
action="DeleteSecret",
|
||||||
|
secret_name=secret_name,
|
||||||
|
optional_params=optional_params,
|
||||||
|
request_data=data,
|
||||||
|
)
|
||||||
|
|
||||||
|
async_client = get_async_httpx_client(
|
||||||
|
llm_provider=httpxSpecialProvider.SecretManager,
|
||||||
|
params={"timeout": timeout},
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = await async_client.post(
|
||||||
|
url=endpoint_url, headers=headers, data=body.decode("utf-8")
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
return response.json()
|
||||||
|
except httpx.HTTPStatusError as err:
|
||||||
|
raise ValueError(f"HTTP error occurred: {err.response.text}")
|
||||||
|
except httpx.TimeoutException:
|
||||||
|
raise ValueError("Timeout error occurred")
|
||||||
|
|
||||||
|
def _prepare_request(
|
||||||
|
self,
|
||||||
|
action: str, # "GetSecretValue" or "PutSecretValue"
|
||||||
|
secret_name: str,
|
||||||
|
secret_value: Optional[str] = None,
|
||||||
|
optional_params: Optional[dict] = None,
|
||||||
|
request_data: Optional[dict] = None,
|
||||||
|
) -> tuple[str, Any, bytes]:
|
||||||
|
"""Prepare the AWS Secrets Manager request"""
|
||||||
|
try:
|
||||||
|
import boto3
|
||||||
|
from botocore.auth import SigV4Auth
|
||||||
|
from botocore.awsrequest import AWSRequest
|
||||||
|
from botocore.credentials import Credentials
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
|
||||||
|
optional_params = optional_params or {}
|
||||||
|
boto3_credentials_info = self._get_boto_credentials_from_optional_params(
|
||||||
|
optional_params
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get endpoint
|
||||||
|
_, endpoint_url = self.get_runtime_endpoint(
|
||||||
|
api_base=None,
|
||||||
|
aws_bedrock_runtime_endpoint=boto3_credentials_info.aws_bedrock_runtime_endpoint,
|
||||||
|
aws_region_name=boto3_credentials_info.aws_region_name,
|
||||||
|
)
|
||||||
|
endpoint_url = endpoint_url.replace("bedrock-runtime", "secretsmanager")
|
||||||
|
|
||||||
|
# Use provided request_data if available, otherwise build default data
|
||||||
|
if request_data:
|
||||||
|
data = request_data
|
||||||
|
else:
|
||||||
|
data = {"SecretId": secret_name}
|
||||||
|
if secret_value and action == "PutSecretValue":
|
||||||
|
data["SecretString"] = secret_value
|
||||||
|
|
||||||
|
body = json.dumps(data).encode("utf-8")
|
||||||
|
headers = {
|
||||||
|
"Content-Type": "application/x-amz-json-1.1",
|
||||||
|
"X-Amz-Target": f"secretsmanager.{action}",
|
||||||
|
}
|
||||||
|
|
||||||
|
# Sign request
|
||||||
|
request = AWSRequest(
|
||||||
|
method="POST", url=endpoint_url, data=body, headers=headers
|
||||||
|
)
|
||||||
|
SigV4Auth(
|
||||||
|
boto3_credentials_info.credentials,
|
||||||
|
"secretsmanager",
|
||||||
|
boto3_credentials_info.aws_region_name,
|
||||||
|
).add_auth(request)
|
||||||
|
prepped = request.prepare()
|
||||||
|
|
||||||
|
return endpoint_url, prepped.headers, body
|
||||||
|
|
||||||
|
|
||||||
|
# if __name__ == "__main__":
|
||||||
|
# print("loading aws secret manager v2")
|
||||||
|
# aws_secret_manager_v2 = AWSSecretsManagerV2()
|
||||||
|
|
||||||
|
# print("writing secret to aws secret manager v2")
|
||||||
|
# asyncio.run(aws_secret_manager_v2.async_write_secret(secret_name="test_secret_3", secret_value="test_value_2"))
|
||||||
|
# print("reading secret from aws secret manager v2")
|
|
@ -5,7 +5,7 @@ import json
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import traceback
|
import traceback
|
||||||
from typing import Any, Optional, Union
|
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
@ -198,7 +198,10 @@ def get_secret( # noqa: PLR0915
|
||||||
raise ValueError("Unsupported OIDC provider")
|
raise ValueError("Unsupported OIDC provider")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if litellm.secret_manager_client is not None:
|
if (
|
||||||
|
_should_read_secret_from_secret_manager()
|
||||||
|
and litellm.secret_manager_client is not None
|
||||||
|
):
|
||||||
try:
|
try:
|
||||||
client = litellm.secret_manager_client
|
client = litellm.secret_manager_client
|
||||||
key_manager = "local"
|
key_manager = "local"
|
||||||
|
@ -207,7 +210,8 @@ def get_secret( # noqa: PLR0915
|
||||||
|
|
||||||
if key_management_settings is not None:
|
if key_management_settings is not None:
|
||||||
if (
|
if (
|
||||||
secret_name not in key_management_settings.hosted_keys
|
key_management_settings.hosted_keys is not None
|
||||||
|
and secret_name not in key_management_settings.hosted_keys
|
||||||
): # allow user to specify which keys to check in hosted key manager
|
): # allow user to specify which keys to check in hosted key manager
|
||||||
key_manager = "local"
|
key_manager = "local"
|
||||||
|
|
||||||
|
@ -268,25 +272,13 @@ def get_secret( # noqa: PLR0915
|
||||||
if isinstance(secret, str):
|
if isinstance(secret, str):
|
||||||
secret = secret.strip()
|
secret = secret.strip()
|
||||||
elif key_manager == KeyManagementSystem.AWS_SECRET_MANAGER.value:
|
elif key_manager == KeyManagementSystem.AWS_SECRET_MANAGER.value:
|
||||||
try:
|
from litellm.secret_managers.aws_secret_manager_v2 import (
|
||||||
get_secret_value_response = client.get_secret_value(
|
AWSSecretsManagerV2,
|
||||||
SecretId=secret_name
|
|
||||||
)
|
)
|
||||||
print_verbose(
|
|
||||||
f"get_secret_value_response: {get_secret_value_response}"
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
print_verbose(f"An error occurred - {str(e)}")
|
|
||||||
# For a list of exceptions thrown, see
|
|
||||||
# https://docs.aws.amazon.com/secretsmanager/latest/apireference/API_GetSecretValue.html
|
|
||||||
raise e
|
|
||||||
|
|
||||||
# assume there is 1 secret per secret_name
|
if isinstance(client, AWSSecretsManagerV2):
|
||||||
secret_dict = json.loads(get_secret_value_response["SecretString"])
|
secret = client.sync_read_secret(secret_name=secret_name)
|
||||||
print_verbose(f"secret_dict: {secret_dict}")
|
print_verbose(f"get_secret_value_response: {secret}")
|
||||||
for k, v in secret_dict.items():
|
|
||||||
secret = v
|
|
||||||
print_verbose(f"secret: {secret}")
|
|
||||||
elif key_manager == KeyManagementSystem.GOOGLE_SECRET_MANAGER.value:
|
elif key_manager == KeyManagementSystem.GOOGLE_SECRET_MANAGER.value:
|
||||||
try:
|
try:
|
||||||
secret = client.get_secret_from_google_secret_manager(
|
secret = client.get_secret_from_google_secret_manager(
|
||||||
|
@ -332,3 +324,21 @@ def get_secret( # noqa: PLR0915
|
||||||
return default_value
|
return default_value
|
||||||
else:
|
else:
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
|
||||||
|
def _should_read_secret_from_secret_manager() -> bool:
|
||||||
|
"""
|
||||||
|
Returns True if the secret manager should be used to read the secret, False otherwise
|
||||||
|
|
||||||
|
- If the secret manager client is not set, return False
|
||||||
|
- If the `_key_management_settings` access mode is "read_only" or "read_and_write", return True
|
||||||
|
- Otherwise, return False
|
||||||
|
"""
|
||||||
|
if litellm.secret_manager_client is not None:
|
||||||
|
if litellm._key_management_settings is not None:
|
||||||
|
if (
|
||||||
|
litellm._key_management_settings.access_mode == "read_only"
|
||||||
|
or litellm._key_management_settings.access_mode == "read_and_write"
|
||||||
|
):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
|
@ -7,6 +7,7 @@ https://docs.cohere.com/reference/rerank
|
||||||
from typing import List, Optional, Union
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
from pydantic import BaseModel, PrivateAttr
|
from pydantic import BaseModel, PrivateAttr
|
||||||
|
from typing_extensions import TypedDict
|
||||||
|
|
||||||
|
|
||||||
class RerankRequest(BaseModel):
|
class RerankRequest(BaseModel):
|
||||||
|
@ -19,10 +20,26 @@ class RerankRequest(BaseModel):
|
||||||
max_chunks_per_doc: Optional[int] = None
|
max_chunks_per_doc: Optional[int] = None
|
||||||
|
|
||||||
|
|
||||||
|
class RerankBilledUnits(TypedDict, total=False):
|
||||||
|
search_units: int
|
||||||
|
total_tokens: int
|
||||||
|
|
||||||
|
|
||||||
|
class RerankTokens(TypedDict, total=False):
|
||||||
|
input_tokens: int
|
||||||
|
output_tokens: int
|
||||||
|
|
||||||
|
|
||||||
|
class RerankResponseMeta(TypedDict, total=False):
|
||||||
|
api_version: dict
|
||||||
|
billed_units: RerankBilledUnits
|
||||||
|
tokens: RerankTokens
|
||||||
|
|
||||||
|
|
||||||
class RerankResponse(BaseModel):
|
class RerankResponse(BaseModel):
|
||||||
id: str
|
id: str
|
||||||
results: List[dict] # Contains index and relevance_score
|
results: List[dict] # Contains index and relevance_score
|
||||||
meta: Optional[dict] = None # Contains api_version and billed_units
|
meta: Optional[RerankResponseMeta] = None # Contains api_version and billed_units
|
||||||
|
|
||||||
# Define private attributes using PrivateAttr
|
# Define private attributes using PrivateAttr
|
||||||
_hidden_params: dict = PrivateAttr(default_factory=dict)
|
_hidden_params: dict = PrivateAttr(default_factory=dict)
|
||||||
|
|
|
@ -150,6 +150,8 @@ class GenericLiteLLMParams(BaseModel):
|
||||||
max_retries: Optional[int] = None
|
max_retries: Optional[int] = None
|
||||||
organization: Optional[str] = None # for openai orgs
|
organization: Optional[str] = None # for openai orgs
|
||||||
configurable_clientside_auth_params: CONFIGURABLE_CLIENTSIDE_AUTH_PARAMS = None
|
configurable_clientside_auth_params: CONFIGURABLE_CLIENTSIDE_AUTH_PARAMS = None
|
||||||
|
## LOGGING PARAMS ##
|
||||||
|
litellm_trace_id: Optional[str] = None
|
||||||
## UNIFIED PROJECT/REGION ##
|
## UNIFIED PROJECT/REGION ##
|
||||||
region_name: Optional[str] = None
|
region_name: Optional[str] = None
|
||||||
## VERTEX AI ##
|
## VERTEX AI ##
|
||||||
|
@ -186,6 +188,8 @@ class GenericLiteLLMParams(BaseModel):
|
||||||
None # timeout when making stream=True calls, if str, pass in as os.environ/
|
None # timeout when making stream=True calls, if str, pass in as os.environ/
|
||||||
),
|
),
|
||||||
organization: Optional[str] = None, # for openai orgs
|
organization: Optional[str] = None, # for openai orgs
|
||||||
|
## LOGGING PARAMS ##
|
||||||
|
litellm_trace_id: Optional[str] = None,
|
||||||
## UNIFIED PROJECT/REGION ##
|
## UNIFIED PROJECT/REGION ##
|
||||||
region_name: Optional[str] = None,
|
region_name: Optional[str] = None,
|
||||||
## VERTEX AI ##
|
## VERTEX AI ##
|
||||||
|
|
|
@ -1334,6 +1334,7 @@ class ResponseFormatChunk(TypedDict, total=False):
|
||||||
|
|
||||||
all_litellm_params = [
|
all_litellm_params = [
|
||||||
"metadata",
|
"metadata",
|
||||||
|
"litellm_trace_id",
|
||||||
"tags",
|
"tags",
|
||||||
"acompletion",
|
"acompletion",
|
||||||
"aimg_generation",
|
"aimg_generation",
|
||||||
|
@ -1523,6 +1524,7 @@ StandardLoggingPayloadStatus = Literal["success", "failure"]
|
||||||
|
|
||||||
class StandardLoggingPayload(TypedDict):
|
class StandardLoggingPayload(TypedDict):
|
||||||
id: str
|
id: str
|
||||||
|
trace_id: str # Trace multiple LLM calls belonging to same overall request (e.g. fallbacks/retries)
|
||||||
call_type: str
|
call_type: str
|
||||||
response_cost: float
|
response_cost: float
|
||||||
response_cost_failure_debug_info: Optional[
|
response_cost_failure_debug_info: Optional[
|
||||||
|
|
|
@ -527,6 +527,7 @@ def function_setup( # noqa: PLR0915
|
||||||
messages=messages,
|
messages=messages,
|
||||||
stream=stream,
|
stream=stream,
|
||||||
litellm_call_id=kwargs["litellm_call_id"],
|
litellm_call_id=kwargs["litellm_call_id"],
|
||||||
|
litellm_trace_id=kwargs.get("litellm_trace_id"),
|
||||||
function_id=function_id or "",
|
function_id=function_id or "",
|
||||||
call_type=call_type,
|
call_type=call_type,
|
||||||
start_time=start_time,
|
start_time=start_time,
|
||||||
|
@ -2056,6 +2057,7 @@ def get_litellm_params(
|
||||||
azure_ad_token_provider=None,
|
azure_ad_token_provider=None,
|
||||||
user_continue_message=None,
|
user_continue_message=None,
|
||||||
base_model=None,
|
base_model=None,
|
||||||
|
litellm_trace_id=None,
|
||||||
):
|
):
|
||||||
litellm_params = {
|
litellm_params = {
|
||||||
"acompletion": acompletion,
|
"acompletion": acompletion,
|
||||||
|
@ -2084,6 +2086,7 @@ def get_litellm_params(
|
||||||
"user_continue_message": user_continue_message,
|
"user_continue_message": user_continue_message,
|
||||||
"base_model": base_model
|
"base_model": base_model
|
||||||
or _get_base_model_from_litellm_call_metadata(metadata=metadata),
|
or _get_base_model_from_litellm_call_metadata(metadata=metadata),
|
||||||
|
"litellm_trace_id": litellm_trace_id,
|
||||||
}
|
}
|
||||||
|
|
||||||
return litellm_params
|
return litellm_params
|
||||||
|
|
|
@ -2986,19 +2986,19 @@
|
||||||
"supports_function_calling": true
|
"supports_function_calling": true
|
||||||
},
|
},
|
||||||
"vertex_ai/imagegeneration@006": {
|
"vertex_ai/imagegeneration@006": {
|
||||||
"cost_per_image": 0.020,
|
"output_cost_per_image": 0.020,
|
||||||
"litellm_provider": "vertex_ai-image-models",
|
"litellm_provider": "vertex_ai-image-models",
|
||||||
"mode": "image_generation",
|
"mode": "image_generation",
|
||||||
"source": "https://cloud.google.com/vertex-ai/generative-ai/pricing"
|
"source": "https://cloud.google.com/vertex-ai/generative-ai/pricing"
|
||||||
},
|
},
|
||||||
"vertex_ai/imagen-3.0-generate-001": {
|
"vertex_ai/imagen-3.0-generate-001": {
|
||||||
"cost_per_image": 0.04,
|
"output_cost_per_image": 0.04,
|
||||||
"litellm_provider": "vertex_ai-image-models",
|
"litellm_provider": "vertex_ai-image-models",
|
||||||
"mode": "image_generation",
|
"mode": "image_generation",
|
||||||
"source": "https://cloud.google.com/vertex-ai/generative-ai/pricing"
|
"source": "https://cloud.google.com/vertex-ai/generative-ai/pricing"
|
||||||
},
|
},
|
||||||
"vertex_ai/imagen-3.0-fast-generate-001": {
|
"vertex_ai/imagen-3.0-fast-generate-001": {
|
||||||
"cost_per_image": 0.02,
|
"output_cost_per_image": 0.02,
|
||||||
"litellm_provider": "vertex_ai-image-models",
|
"litellm_provider": "vertex_ai-image-models",
|
||||||
"mode": "image_generation",
|
"mode": "image_generation",
|
||||||
"source": "https://cloud.google.com/vertex-ai/generative-ai/pricing"
|
"source": "https://cloud.google.com/vertex-ai/generative-ai/pricing"
|
||||||
|
@ -5620,6 +5620,13 @@
|
||||||
"litellm_provider": "bedrock",
|
"litellm_provider": "bedrock",
|
||||||
"mode": "image_generation"
|
"mode": "image_generation"
|
||||||
},
|
},
|
||||||
|
"stability.stable-image-ultra-v1:0": {
|
||||||
|
"max_tokens": 77,
|
||||||
|
"max_input_tokens": 77,
|
||||||
|
"output_cost_per_image": 0.14,
|
||||||
|
"litellm_provider": "bedrock",
|
||||||
|
"mode": "image_generation"
|
||||||
|
},
|
||||||
"sagemaker/meta-textgeneration-llama-2-7b": {
|
"sagemaker/meta-textgeneration-llama-2-7b": {
|
||||||
"max_tokens": 4096,
|
"max_tokens": 4096,
|
||||||
"max_input_tokens": 4096,
|
"max_input_tokens": 4096,
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
[tool.poetry]
|
[tool.poetry]
|
||||||
name = "litellm"
|
name = "litellm"
|
||||||
version = "1.52.6"
|
version = "1.52.9"
|
||||||
description = "Library to easily interface with LLM API providers"
|
description = "Library to easily interface with LLM API providers"
|
||||||
authors = ["BerriAI"]
|
authors = ["BerriAI"]
|
||||||
license = "MIT"
|
license = "MIT"
|
||||||
|
@ -91,7 +91,7 @@ requires = ["poetry-core", "wheel"]
|
||||||
build-backend = "poetry.core.masonry.api"
|
build-backend = "poetry.core.masonry.api"
|
||||||
|
|
||||||
[tool.commitizen]
|
[tool.commitizen]
|
||||||
version = "1.52.6"
|
version = "1.52.9"
|
||||||
version_files = [
|
version_files = [
|
||||||
"pyproject.toml:^version"
|
"pyproject.toml:^version"
|
||||||
]
|
]
|
||||||
|
|
|
@ -13,8 +13,11 @@ sys.path.insert(
|
||||||
import litellm
|
import litellm
|
||||||
from litellm.exceptions import BadRequestError
|
from litellm.exceptions import BadRequestError
|
||||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
||||||
from litellm.utils import CustomStreamWrapper
|
from litellm.utils import (
|
||||||
|
CustomStreamWrapper,
|
||||||
|
get_supported_openai_params,
|
||||||
|
get_optional_params,
|
||||||
|
)
|
||||||
|
|
||||||
# test_example.py
|
# test_example.py
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
@ -45,6 +48,9 @@ class BaseLLMChatTest(ABC):
|
||||||
)
|
)
|
||||||
assert response is not None
|
assert response is not None
|
||||||
|
|
||||||
|
# for OpenAI the content contains the JSON schema, so we need to assert that the content is not None
|
||||||
|
assert response.choices[0].message.content is not None
|
||||||
|
|
||||||
def test_message_with_name(self):
|
def test_message_with_name(self):
|
||||||
base_completion_call_args = self.get_base_completion_call_args()
|
base_completion_call_args = self.get_base_completion_call_args()
|
||||||
messages = [
|
messages = [
|
||||||
|
@ -79,6 +85,49 @@ class BaseLLMChatTest(ABC):
|
||||||
|
|
||||||
print(response)
|
print(response)
|
||||||
|
|
||||||
|
# OpenAI guarantees that the JSON schema is returned in the content
|
||||||
|
# relevant issue: https://github.com/BerriAI/litellm/issues/6741
|
||||||
|
assert response.choices[0].message.content is not None
|
||||||
|
|
||||||
|
def test_json_response_format_stream(self):
|
||||||
|
"""
|
||||||
|
Test that the JSON response format with streaming is supported by the LLM API
|
||||||
|
"""
|
||||||
|
base_completion_call_args = self.get_base_completion_call_args()
|
||||||
|
litellm.set_verbose = True
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": "Your output should be a JSON object with no additional properties. ",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "Respond with this in json. city=San Francisco, state=CA, weather=sunny, temp=60",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
response = litellm.completion(
|
||||||
|
**base_completion_call_args,
|
||||||
|
messages=messages,
|
||||||
|
response_format={"type": "json_object"},
|
||||||
|
stream=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(response)
|
||||||
|
|
||||||
|
content = ""
|
||||||
|
for chunk in response:
|
||||||
|
content += chunk.choices[0].delta.content or ""
|
||||||
|
|
||||||
|
print("content=", content)
|
||||||
|
|
||||||
|
# OpenAI guarantees that the JSON schema is returned in the content
|
||||||
|
# relevant issue: https://github.com/BerriAI/litellm/issues/6741
|
||||||
|
# we need to assert that the JSON schema was returned in the content, (for Anthropic we were returning it as part of the tool call)
|
||||||
|
assert content is not None
|
||||||
|
assert len(content) > 0
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def pdf_messages(self):
|
def pdf_messages(self):
|
||||||
import base64
|
import base64
|
||||||
|
|
115
tests/llm_translation/base_rerank_unit_tests.py
Normal file
115
tests/llm_translation/base_rerank_unit_tests.py
Normal file
|
@ -0,0 +1,115 @@
|
||||||
|
import asyncio
|
||||||
|
import httpx
|
||||||
|
import json
|
||||||
|
import pytest
|
||||||
|
import sys
|
||||||
|
from typing import Any, Dict, List
|
||||||
|
from unittest.mock import MagicMock, Mock, patch
|
||||||
|
import os
|
||||||
|
|
||||||
|
sys.path.insert(
|
||||||
|
0, os.path.abspath("../..")
|
||||||
|
) # Adds the parent directory to the system path
|
||||||
|
import litellm
|
||||||
|
from litellm.exceptions import BadRequestError
|
||||||
|
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
||||||
|
from litellm.utils import (
|
||||||
|
CustomStreamWrapper,
|
||||||
|
get_supported_openai_params,
|
||||||
|
get_optional_params,
|
||||||
|
)
|
||||||
|
|
||||||
|
# test_example.py
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
|
|
||||||
|
def assert_response_shape(response, custom_llm_provider):
|
||||||
|
expected_response_shape = {"id": str, "results": list, "meta": dict}
|
||||||
|
|
||||||
|
expected_results_shape = {"index": int, "relevance_score": float}
|
||||||
|
|
||||||
|
expected_meta_shape = {"api_version": dict, "billed_units": dict}
|
||||||
|
|
||||||
|
expected_api_version_shape = {"version": str}
|
||||||
|
|
||||||
|
expected_billed_units_shape = {"search_units": int}
|
||||||
|
|
||||||
|
assert isinstance(response.id, expected_response_shape["id"])
|
||||||
|
assert isinstance(response.results, expected_response_shape["results"])
|
||||||
|
for result in response.results:
|
||||||
|
assert isinstance(result["index"], expected_results_shape["index"])
|
||||||
|
assert isinstance(
|
||||||
|
result["relevance_score"], expected_results_shape["relevance_score"]
|
||||||
|
)
|
||||||
|
assert isinstance(response.meta, expected_response_shape["meta"])
|
||||||
|
|
||||||
|
if custom_llm_provider == "cohere":
|
||||||
|
|
||||||
|
assert isinstance(
|
||||||
|
response.meta["api_version"], expected_meta_shape["api_version"]
|
||||||
|
)
|
||||||
|
assert isinstance(
|
||||||
|
response.meta["api_version"]["version"],
|
||||||
|
expected_api_version_shape["version"],
|
||||||
|
)
|
||||||
|
assert isinstance(
|
||||||
|
response.meta["billed_units"], expected_meta_shape["billed_units"]
|
||||||
|
)
|
||||||
|
assert isinstance(
|
||||||
|
response.meta["billed_units"]["search_units"],
|
||||||
|
expected_billed_units_shape["search_units"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class BaseLLMRerankTest(ABC):
|
||||||
|
"""
|
||||||
|
Abstract base test class that enforces a common test across all test classes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_base_rerank_call_args(self) -> dict:
|
||||||
|
"""Must return the base rerank call args"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_custom_llm_provider(self) -> litellm.LlmProviders:
|
||||||
|
"""Must return the custom llm provider"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@pytest.mark.asyncio()
|
||||||
|
@pytest.mark.parametrize("sync_mode", [True, False])
|
||||||
|
async def test_basic_rerank(self, sync_mode):
|
||||||
|
rerank_call_args = self.get_base_rerank_call_args()
|
||||||
|
custom_llm_provider = self.get_custom_llm_provider()
|
||||||
|
if sync_mode is True:
|
||||||
|
response = litellm.rerank(
|
||||||
|
**rerank_call_args,
|
||||||
|
query="hello",
|
||||||
|
documents=["hello", "world"],
|
||||||
|
top_n=3,
|
||||||
|
)
|
||||||
|
|
||||||
|
print("re rank response: ", response)
|
||||||
|
|
||||||
|
assert response.id is not None
|
||||||
|
assert response.results is not None
|
||||||
|
|
||||||
|
assert_response_shape(
|
||||||
|
response=response, custom_llm_provider=custom_llm_provider.value
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
response = await litellm.arerank(
|
||||||
|
**rerank_call_args,
|
||||||
|
query="hello",
|
||||||
|
documents=["hello", "world"],
|
||||||
|
top_n=3,
|
||||||
|
)
|
||||||
|
|
||||||
|
print("async re rank response: ", response)
|
||||||
|
|
||||||
|
assert response.id is not None
|
||||||
|
assert response.results is not None
|
||||||
|
|
||||||
|
assert_response_shape(
|
||||||
|
response=response, custom_llm_provider=custom_llm_provider.value
|
||||||
|
)
|
|
@ -33,8 +33,10 @@ from litellm import (
|
||||||
)
|
)
|
||||||
from litellm.adapters.anthropic_adapter import anthropic_adapter
|
from litellm.adapters.anthropic_adapter import anthropic_adapter
|
||||||
from litellm.types.llms.anthropic import AnthropicResponse
|
from litellm.types.llms.anthropic import AnthropicResponse
|
||||||
|
from litellm.types.utils import GenericStreamingChunk, ChatCompletionToolCallChunk
|
||||||
|
from litellm.types.llms.openai import ChatCompletionToolCallFunctionChunk
|
||||||
from litellm.llms.anthropic.common_utils import process_anthropic_headers
|
from litellm.llms.anthropic.common_utils import process_anthropic_headers
|
||||||
|
from litellm.llms.anthropic.chat.handler import AnthropicChatCompletion
|
||||||
from httpx import Headers
|
from httpx import Headers
|
||||||
from base_llm_unit_tests import BaseLLMChatTest
|
from base_llm_unit_tests import BaseLLMChatTest
|
||||||
|
|
||||||
|
@ -694,3 +696,91 @@ class TestAnthropicCompletion(BaseLLMChatTest):
|
||||||
assert _document_validation["type"] == "document"
|
assert _document_validation["type"] == "document"
|
||||||
assert _document_validation["source"]["media_type"] == "application/pdf"
|
assert _document_validation["source"]["media_type"] == "application/pdf"
|
||||||
assert _document_validation["source"]["type"] == "base64"
|
assert _document_validation["source"]["type"] == "base64"
|
||||||
|
|
||||||
|
|
||||||
|
def test_convert_tool_response_to_message_with_values():
|
||||||
|
"""Test converting a tool response with 'values' key to a message"""
|
||||||
|
tool_calls = [
|
||||||
|
ChatCompletionToolCallChunk(
|
||||||
|
id="test_id",
|
||||||
|
type="function",
|
||||||
|
function=ChatCompletionToolCallFunctionChunk(
|
||||||
|
name="json_tool_call",
|
||||||
|
arguments='{"values": {"name": "John", "age": 30}}',
|
||||||
|
),
|
||||||
|
index=0,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
message = AnthropicChatCompletion._convert_tool_response_to_message(
|
||||||
|
tool_calls=tool_calls
|
||||||
|
)
|
||||||
|
|
||||||
|
assert message is not None
|
||||||
|
assert message.content == '{"name": "John", "age": 30}'
|
||||||
|
|
||||||
|
|
||||||
|
def test_convert_tool_response_to_message_without_values():
|
||||||
|
"""
|
||||||
|
Test converting a tool response without 'values' key to a message
|
||||||
|
|
||||||
|
Anthropic API returns the JSON schema in the tool call, OpenAI Spec expects it in the message. This test ensures that the tool call is converted to a message correctly.
|
||||||
|
|
||||||
|
Relevant issue: https://github.com/BerriAI/litellm/issues/6741
|
||||||
|
"""
|
||||||
|
tool_calls = [
|
||||||
|
ChatCompletionToolCallChunk(
|
||||||
|
id="test_id",
|
||||||
|
type="function",
|
||||||
|
function=ChatCompletionToolCallFunctionChunk(
|
||||||
|
name="json_tool_call", arguments='{"name": "John", "age": 30}'
|
||||||
|
),
|
||||||
|
index=0,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
message = AnthropicChatCompletion._convert_tool_response_to_message(
|
||||||
|
tool_calls=tool_calls
|
||||||
|
)
|
||||||
|
|
||||||
|
assert message is not None
|
||||||
|
assert message.content == '{"name": "John", "age": 30}'
|
||||||
|
|
||||||
|
|
||||||
|
def test_convert_tool_response_to_message_invalid_json():
|
||||||
|
"""Test converting a tool response with invalid JSON"""
|
||||||
|
tool_calls = [
|
||||||
|
ChatCompletionToolCallChunk(
|
||||||
|
id="test_id",
|
||||||
|
type="function",
|
||||||
|
function=ChatCompletionToolCallFunctionChunk(
|
||||||
|
name="json_tool_call", arguments="invalid json"
|
||||||
|
),
|
||||||
|
index=0,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
message = AnthropicChatCompletion._convert_tool_response_to_message(
|
||||||
|
tool_calls=tool_calls
|
||||||
|
)
|
||||||
|
|
||||||
|
assert message is not None
|
||||||
|
assert message.content == "invalid json"
|
||||||
|
|
||||||
|
|
||||||
|
def test_convert_tool_response_to_message_no_arguments():
|
||||||
|
"""Test converting a tool response with no arguments"""
|
||||||
|
tool_calls = [
|
||||||
|
ChatCompletionToolCallChunk(
|
||||||
|
id="test_id",
|
||||||
|
type="function",
|
||||||
|
function=ChatCompletionToolCallFunctionChunk(name="json_tool_call"),
|
||||||
|
index=0,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
message = AnthropicChatCompletion._convert_tool_response_to_message(
|
||||||
|
tool_calls=tool_calls
|
||||||
|
)
|
||||||
|
|
||||||
|
assert message is None
|
||||||
|
|
23
tests/llm_translation/test_jina_ai.py
Normal file
23
tests/llm_translation/test_jina_ai.py
Normal file
|
@ -0,0 +1,23 @@
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from datetime import datetime
|
||||||
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
|
sys.path.insert(
|
||||||
|
0, os.path.abspath("../..")
|
||||||
|
) # Adds the parent directory to the system path
|
||||||
|
|
||||||
|
|
||||||
|
from base_rerank_unit_tests import BaseLLMRerankTest
|
||||||
|
import litellm
|
||||||
|
|
||||||
|
|
||||||
|
class TestJinaAI(BaseLLMRerankTest):
|
||||||
|
def get_custom_llm_provider(self) -> litellm.LlmProviders:
|
||||||
|
return litellm.LlmProviders.JINA_AI
|
||||||
|
|
||||||
|
def get_base_rerank_call_args(self) -> dict:
|
||||||
|
return {
|
||||||
|
"model": "jina_ai/jina-reranker-v2-base-multilingual",
|
||||||
|
}
|
|
@ -923,9 +923,22 @@ def test_watsonx_text_top_k():
|
||||||
assert optional_params["top_k"] == 10
|
assert optional_params["top_k"] == 10
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def test_together_ai_model_params():
|
def test_together_ai_model_params():
|
||||||
optional_params = get_optional_params(
|
optional_params = get_optional_params(
|
||||||
model="together_ai", custom_llm_provider="together_ai", logprobs=1
|
model="together_ai", custom_llm_provider="together_ai", logprobs=1
|
||||||
)
|
)
|
||||||
print(optional_params)
|
print(optional_params)
|
||||||
assert optional_params["logprobs"] == 1
|
assert optional_params["logprobs"] == 1
|
||||||
|
|
||||||
|
def test_forward_user_param():
|
||||||
|
from litellm.utils import get_supported_openai_params, get_optional_params
|
||||||
|
|
||||||
|
model = "claude-3-5-sonnet-20240620"
|
||||||
|
optional_params = get_optional_params(
|
||||||
|
model=model,
|
||||||
|
user="test_user",
|
||||||
|
custom_llm_provider="anthropic",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert optional_params["metadata"]["user_id"] == "test_user"
|
||||||
|
|
|
@ -16,6 +16,7 @@ import pytest
|
||||||
import litellm
|
import litellm
|
||||||
from litellm import get_optional_params
|
from litellm import get_optional_params
|
||||||
from litellm.llms.custom_httpx.http_handler import HTTPHandler
|
from litellm.llms.custom_httpx.http_handler import HTTPHandler
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
|
||||||
def test_completion_pydantic_obj_2():
|
def test_completion_pydantic_obj_2():
|
||||||
|
@ -1317,3 +1318,39 @@ def test_image_completion_request(image_url):
|
||||||
mock_post.assert_called_once()
|
mock_post.assert_called_once()
|
||||||
print("mock_post.call_args.kwargs['json']", mock_post.call_args.kwargs["json"])
|
print("mock_post.call_args.kwargs['json']", mock_post.call_args.kwargs["json"])
|
||||||
assert mock_post.call_args.kwargs["json"] == expected_request_body
|
assert mock_post.call_args.kwargs["json"] == expected_request_body
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"model, expected_url",
|
||||||
|
[
|
||||||
|
(
|
||||||
|
"textembedding-gecko@001",
|
||||||
|
"https://us-central1-aiplatform.googleapis.com/v1/projects/project-id/locations/us-central1/publishers/google/models/textembedding-gecko@001:predict",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"123456789",
|
||||||
|
"https://us-central1-aiplatform.googleapis.com/v1/projects/project-id/locations/us-central1/endpoints/123456789:predict",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_vertex_embedding_url(model, expected_url):
|
||||||
|
"""
|
||||||
|
Test URL generation for embedding models, including numeric model IDs (fine-tuned models
|
||||||
|
|
||||||
|
Relevant issue: https://github.com/BerriAI/litellm/issues/6482
|
||||||
|
|
||||||
|
When a fine-tuned embedding model is used, the URL is different from the standard one.
|
||||||
|
"""
|
||||||
|
from litellm.llms.vertex_ai_and_google_ai_studio.common_utils import _get_vertex_url
|
||||||
|
|
||||||
|
url, endpoint = _get_vertex_url(
|
||||||
|
mode="embedding",
|
||||||
|
model=model,
|
||||||
|
stream=False,
|
||||||
|
vertex_project="project-id",
|
||||||
|
vertex_location="us-central1",
|
||||||
|
vertex_api_version="v1",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert url == expected_url
|
||||||
|
assert endpoint == "predict"
|
||||||
|
|
|
@ -18,6 +18,8 @@ import json
|
||||||
import os
|
import os
|
||||||
import tempfile
|
import tempfile
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
from respx import MockRouter
|
||||||
|
import httpx
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
@ -973,6 +975,7 @@ async def test_partner_models_httpx(model, sync_mode):
|
||||||
data = {
|
data = {
|
||||||
"model": model,
|
"model": model,
|
||||||
"messages": messages,
|
"messages": messages,
|
||||||
|
"timeout": 10,
|
||||||
}
|
}
|
||||||
if sync_mode:
|
if sync_mode:
|
||||||
response = litellm.completion(**data)
|
response = litellm.completion(**data)
|
||||||
|
@ -986,6 +989,8 @@ async def test_partner_models_httpx(model, sync_mode):
|
||||||
assert isinstance(response._hidden_params["response_cost"], float)
|
assert isinstance(response._hidden_params["response_cost"], float)
|
||||||
except litellm.RateLimitError as e:
|
except litellm.RateLimitError as e:
|
||||||
pass
|
pass
|
||||||
|
except litellm.Timeout as e:
|
||||||
|
pass
|
||||||
except litellm.InternalServerError as e:
|
except litellm.InternalServerError as e:
|
||||||
pass
|
pass
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -3051,3 +3056,70 @@ def test_custom_api_base(api_base):
|
||||||
assert url == api_base + ":"
|
assert url == api_base + ":"
|
||||||
else:
|
else:
|
||||||
assert url == test_endpoint
|
assert url == test_endpoint
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.respx
|
||||||
|
async def test_vertexai_embedding_finetuned(respx_mock: MockRouter):
|
||||||
|
"""
|
||||||
|
Tests that:
|
||||||
|
- Request URL and body are correctly formatted for Vertex AI embeddings
|
||||||
|
- Response is properly parsed into litellm's embedding response format
|
||||||
|
"""
|
||||||
|
load_vertex_ai_credentials()
|
||||||
|
litellm.set_verbose = True
|
||||||
|
|
||||||
|
# Test input
|
||||||
|
input_text = ["good morning from litellm", "this is another item"]
|
||||||
|
|
||||||
|
# Expected request/response
|
||||||
|
expected_url = "https://us-central1-aiplatform.googleapis.com/v1/projects/633608382793/locations/us-central1/endpoints/1004708436694269952:predict"
|
||||||
|
expected_request = {
|
||||||
|
"instances": [
|
||||||
|
{"inputs": "good morning from litellm"},
|
||||||
|
{"inputs": "this is another item"},
|
||||||
|
],
|
||||||
|
"parameters": {},
|
||||||
|
}
|
||||||
|
|
||||||
|
mock_response = {
|
||||||
|
"predictions": [
|
||||||
|
[[-0.000431762, -0.04416759, -0.03443353]], # Truncated embedding vector
|
||||||
|
[[-0.000431762, -0.04416759, -0.03443353]], # Truncated embedding vector
|
||||||
|
],
|
||||||
|
"deployedModelId": "2275167734310371328",
|
||||||
|
"model": "projects/633608382793/locations/us-central1/models/snowflake-arctic-embed-m-long-1731622468876",
|
||||||
|
"modelDisplayName": "snowflake-arctic-embed-m-long-1731622468876",
|
||||||
|
"modelVersionId": "1",
|
||||||
|
}
|
||||||
|
|
||||||
|
# Setup mock request
|
||||||
|
mock_request = respx_mock.post(expected_url).mock(
|
||||||
|
return_value=httpx.Response(200, json=mock_response)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Make request
|
||||||
|
response = await litellm.aembedding(
|
||||||
|
vertex_project="633608382793",
|
||||||
|
model="vertex_ai/1004708436694269952",
|
||||||
|
input=input_text,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert request was made correctly
|
||||||
|
assert mock_request.called
|
||||||
|
request_body = json.loads(mock_request.calls[0].request.content)
|
||||||
|
print("\n\nrequest_body", request_body)
|
||||||
|
print("\n\nexpected_request", expected_request)
|
||||||
|
assert request_body == expected_request
|
||||||
|
|
||||||
|
# Assert response structure
|
||||||
|
assert response is not None
|
||||||
|
assert hasattr(response, "data")
|
||||||
|
assert len(response.data) == len(input_text)
|
||||||
|
|
||||||
|
# Assert embedding structure
|
||||||
|
for embedding in response.data:
|
||||||
|
assert "embedding" in embedding
|
||||||
|
assert isinstance(embedding["embedding"], list)
|
||||||
|
assert len(embedding["embedding"]) > 0
|
||||||
|
assert all(isinstance(x, float) for x in embedding["embedding"])
|
||||||
|
|
139
tests/local_testing/test_aws_secret_manager.py
Normal file
139
tests/local_testing/test_aws_secret_manager.py
Normal file
|
@ -0,0 +1,139 @@
|
||||||
|
# What is this?
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
import litellm.types
|
||||||
|
import litellm.types.utils
|
||||||
|
|
||||||
|
|
||||||
|
load_dotenv()
|
||||||
|
import io
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
|
||||||
|
# Ensure the project root is in the Python path
|
||||||
|
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../../..")))
|
||||||
|
|
||||||
|
print("Python Path:", sys.path)
|
||||||
|
print("Current Working Directory:", os.getcwd())
|
||||||
|
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import uuid
|
||||||
|
import json
|
||||||
|
from litellm.secret_managers.aws_secret_manager_v2 import AWSSecretsManagerV2
|
||||||
|
|
||||||
|
|
||||||
|
def check_aws_credentials():
|
||||||
|
"""Helper function to check if AWS credentials are set"""
|
||||||
|
required_vars = ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_REGION_NAME"]
|
||||||
|
missing_vars = [var for var in required_vars if not os.getenv(var)]
|
||||||
|
if missing_vars:
|
||||||
|
pytest.skip(f"Missing required AWS credentials: {', '.join(missing_vars)}")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_write_and_read_simple_secret():
|
||||||
|
"""Test writing and reading a simple string secret"""
|
||||||
|
check_aws_credentials()
|
||||||
|
|
||||||
|
secret_manager = AWSSecretsManagerV2()
|
||||||
|
test_secret_name = f"litellm_test_{uuid.uuid4().hex[:8]}"
|
||||||
|
test_secret_value = "test_value_123"
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Write secret
|
||||||
|
write_response = await secret_manager.async_write_secret(
|
||||||
|
secret_name=test_secret_name,
|
||||||
|
secret_value=test_secret_value,
|
||||||
|
description="LiteLLM Test Secret",
|
||||||
|
)
|
||||||
|
|
||||||
|
print("Write Response:", write_response)
|
||||||
|
|
||||||
|
assert write_response is not None
|
||||||
|
assert "ARN" in write_response
|
||||||
|
assert "Name" in write_response
|
||||||
|
assert write_response["Name"] == test_secret_name
|
||||||
|
|
||||||
|
# Read secret back
|
||||||
|
read_value = await secret_manager.async_read_secret(
|
||||||
|
secret_name=test_secret_name
|
||||||
|
)
|
||||||
|
|
||||||
|
print("Read Value:", read_value)
|
||||||
|
|
||||||
|
assert read_value == test_secret_value
|
||||||
|
finally:
|
||||||
|
# Cleanup: Delete the secret
|
||||||
|
delete_response = await secret_manager.async_delete_secret(
|
||||||
|
secret_name=test_secret_name
|
||||||
|
)
|
||||||
|
print("Delete Response:", delete_response)
|
||||||
|
assert delete_response is not None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_write_and_read_json_secret():
|
||||||
|
"""Test writing and reading a JSON structured secret"""
|
||||||
|
check_aws_credentials()
|
||||||
|
|
||||||
|
secret_manager = AWSSecretsManagerV2()
|
||||||
|
test_secret_name = f"litellm_test_{uuid.uuid4().hex[:8]}_json"
|
||||||
|
test_secret_value = {
|
||||||
|
"api_key": "test_key",
|
||||||
|
"model": "gpt-4",
|
||||||
|
"temperature": 0.7,
|
||||||
|
"metadata": {"team": "ml", "project": "litellm"},
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Write JSON secret
|
||||||
|
write_response = await secret_manager.async_write_secret(
|
||||||
|
secret_name=test_secret_name,
|
||||||
|
secret_value=json.dumps(test_secret_value),
|
||||||
|
description="LiteLLM JSON Test Secret",
|
||||||
|
)
|
||||||
|
|
||||||
|
print("Write Response:", write_response)
|
||||||
|
|
||||||
|
# Read and parse JSON secret
|
||||||
|
read_value = await secret_manager.async_read_secret(
|
||||||
|
secret_name=test_secret_name
|
||||||
|
)
|
||||||
|
parsed_value = json.loads(read_value)
|
||||||
|
|
||||||
|
print("Read Value:", read_value)
|
||||||
|
|
||||||
|
assert parsed_value == test_secret_value
|
||||||
|
assert parsed_value["api_key"] == "test_key"
|
||||||
|
assert parsed_value["metadata"]["team"] == "ml"
|
||||||
|
finally:
|
||||||
|
# Cleanup: Delete the secret
|
||||||
|
delete_response = await secret_manager.async_delete_secret(
|
||||||
|
secret_name=test_secret_name
|
||||||
|
)
|
||||||
|
print("Delete Response:", delete_response)
|
||||||
|
assert delete_response is not None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_read_nonexistent_secret():
|
||||||
|
"""Test reading a secret that doesn't exist"""
|
||||||
|
check_aws_credentials()
|
||||||
|
|
||||||
|
secret_manager = AWSSecretsManagerV2()
|
||||||
|
nonexistent_secret = f"litellm_nonexistent_{uuid.uuid4().hex}"
|
||||||
|
|
||||||
|
response = await secret_manager.async_read_secret(secret_name=nonexistent_secret)
|
||||||
|
|
||||||
|
assert response is None
|
|
@ -24,7 +24,7 @@ from litellm import RateLimitError, Timeout, completion, completion_cost, embedd
|
||||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
||||||
from litellm.llms.prompt_templates.factory import anthropic_messages_pt
|
from litellm.llms.prompt_templates.factory import anthropic_messages_pt
|
||||||
|
|
||||||
# litellm.num_retries = 3
|
# litellm.num_retries=3
|
||||||
|
|
||||||
litellm.cache = None
|
litellm.cache = None
|
||||||
litellm.success_callback = []
|
litellm.success_callback = []
|
||||||
|
|
|
@ -10,7 +10,7 @@ import os
|
||||||
|
|
||||||
sys.path.insert(
|
sys.path.insert(
|
||||||
0, os.path.abspath("../..")
|
0, os.path.abspath("../..")
|
||||||
) # Adds the parent directory to the system path
|
) # Adds the parent directory to the system-path
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
|
@ -1624,3 +1624,55 @@ async def test_standard_logging_payload_stream_usage(sync_mode):
|
||||||
print(f"standard_logging_object usage: {built_response.usage}")
|
print(f"standard_logging_object usage: {built_response.usage}")
|
||||||
except litellm.InternalServerError:
|
except litellm.InternalServerError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def test_standard_logging_retries():
|
||||||
|
"""
|
||||||
|
know if a request was retried.
|
||||||
|
"""
|
||||||
|
from litellm.types.utils import StandardLoggingPayload
|
||||||
|
from litellm.router import Router
|
||||||
|
|
||||||
|
customHandler = CompletionCustomHandler()
|
||||||
|
litellm.callbacks = [customHandler]
|
||||||
|
|
||||||
|
router = Router(
|
||||||
|
model_list=[
|
||||||
|
{
|
||||||
|
"model_name": "gpt-3.5-turbo",
|
||||||
|
"litellm_params": {
|
||||||
|
"model": "openai/gpt-3.5-turbo",
|
||||||
|
"api_key": "test-api-key",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
customHandler, "log_failure_event", new=MagicMock()
|
||||||
|
) as mock_client:
|
||||||
|
try:
|
||||||
|
router.completion(
|
||||||
|
model="gpt-3.5-turbo",
|
||||||
|
messages=[{"role": "user", "content": "Hey, how's it going?"}],
|
||||||
|
num_retries=1,
|
||||||
|
mock_response="litellm.RateLimitError",
|
||||||
|
)
|
||||||
|
except litellm.RateLimitError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
assert mock_client.call_count == 2
|
||||||
|
assert (
|
||||||
|
mock_client.call_args_list[0].kwargs["kwargs"]["standard_logging_object"][
|
||||||
|
"trace_id"
|
||||||
|
]
|
||||||
|
is not None
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
mock_client.call_args_list[0].kwargs["kwargs"]["standard_logging_object"][
|
||||||
|
"trace_id"
|
||||||
|
]
|
||||||
|
== mock_client.call_args_list[1].kwargs["kwargs"][
|
||||||
|
"standard_logging_object"
|
||||||
|
]["trace_id"]
|
||||||
|
)
|
||||||
|
|
|
@ -58,6 +58,7 @@ async def test_content_policy_exception_azure():
|
||||||
except litellm.ContentPolicyViolationError as e:
|
except litellm.ContentPolicyViolationError as e:
|
||||||
print("caught a content policy violation error! Passed")
|
print("caught a content policy violation error! Passed")
|
||||||
print("exception", e)
|
print("exception", e)
|
||||||
|
assert e.response is not None
|
||||||
assert e.litellm_debug_info is not None
|
assert e.litellm_debug_info is not None
|
||||||
assert isinstance(e.litellm_debug_info, str)
|
assert isinstance(e.litellm_debug_info, str)
|
||||||
assert len(e.litellm_debug_info) > 0
|
assert len(e.litellm_debug_info) > 0
|
||||||
|
@ -1152,3 +1153,24 @@ async def test_exception_with_headers_httpx(
|
||||||
if exception_raised is False:
|
if exception_raised is False:
|
||||||
print(resp)
|
print(resp)
|
||||||
assert exception_raised
|
assert exception_raised
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.parametrize("model", ["azure/chatgpt-v-2", "openai/gpt-3.5-turbo"])
|
||||||
|
async def test_bad_request_error_contains_httpx_response(model):
|
||||||
|
"""
|
||||||
|
Test that the BadRequestError contains the httpx response
|
||||||
|
|
||||||
|
Relevant issue: https://github.com/BerriAI/litellm/issues/6732
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
await litellm.acompletion(
|
||||||
|
model=model,
|
||||||
|
messages=[{"role": "user", "content": "Hello world"}],
|
||||||
|
bad_arg="bad_arg",
|
||||||
|
)
|
||||||
|
pytest.fail("Expected to raise BadRequestError")
|
||||||
|
except litellm.BadRequestError as e:
|
||||||
|
print("e.response", e.response)
|
||||||
|
print("vars(e.response)", vars(e.response))
|
||||||
|
assert e.response is not None
|
||||||
|
|
|
@ -157,7 +157,7 @@ def test_get_llm_provider_jina_ai():
|
||||||
model, custom_llm_provider, dynamic_api_key, api_base = litellm.get_llm_provider(
|
model, custom_llm_provider, dynamic_api_key, api_base = litellm.get_llm_provider(
|
||||||
model="jina_ai/jina-embeddings-v3",
|
model="jina_ai/jina-embeddings-v3",
|
||||||
)
|
)
|
||||||
assert custom_llm_provider == "openai_like"
|
assert custom_llm_provider == "jina_ai"
|
||||||
assert api_base == "https://api.jina.ai/v1"
|
assert api_base == "https://api.jina.ai/v1"
|
||||||
assert model == "jina-embeddings-v3"
|
assert model == "jina-embeddings-v3"
|
||||||
|
|
||||||
|
|
|
@ -89,11 +89,16 @@ def test_get_model_info_ollama_chat():
|
||||||
"template": "tools",
|
"template": "tools",
|
||||||
}
|
}
|
||||||
),
|
),
|
||||||
):
|
) as mock_client:
|
||||||
info = OllamaConfig().get_model_info("mistral")
|
info = OllamaConfig().get_model_info("mistral")
|
||||||
print("info", info)
|
|
||||||
assert info["supports_function_calling"] is True
|
assert info["supports_function_calling"] is True
|
||||||
|
|
||||||
info = get_model_info("ollama/mistral")
|
info = get_model_info("ollama/mistral")
|
||||||
print("info", info)
|
|
||||||
assert info["supports_function_calling"] is True
|
assert info["supports_function_calling"] is True
|
||||||
|
|
||||||
|
mock_client.assert_called()
|
||||||
|
|
||||||
|
print(mock_client.call_args.kwargs)
|
||||||
|
|
||||||
|
assert mock_client.call_args.kwargs["json"]["name"] == "mistral"
|
||||||
|
|
|
@ -1138,9 +1138,9 @@ async def test_router_content_policy_fallbacks(
|
||||||
router = Router(
|
router = Router(
|
||||||
model_list=[
|
model_list=[
|
||||||
{
|
{
|
||||||
"model_name": "claude-2",
|
"model_name": "claude-2.1",
|
||||||
"litellm_params": {
|
"litellm_params": {
|
||||||
"model": "claude-2",
|
"model": "claude-2.1",
|
||||||
"api_key": "",
|
"api_key": "",
|
||||||
"mock_response": mock_response,
|
"mock_response": mock_response,
|
||||||
},
|
},
|
||||||
|
@ -1164,7 +1164,7 @@ async def test_router_content_policy_fallbacks(
|
||||||
{
|
{
|
||||||
"model_name": "my-general-model",
|
"model_name": "my-general-model",
|
||||||
"litellm_params": {
|
"litellm_params": {
|
||||||
"model": "claude-2",
|
"model": "claude-2.1",
|
||||||
"api_key": "",
|
"api_key": "",
|
||||||
"mock_response": Exception("Should not have called this."),
|
"mock_response": Exception("Should not have called this."),
|
||||||
},
|
},
|
||||||
|
@ -1172,14 +1172,14 @@ async def test_router_content_policy_fallbacks(
|
||||||
{
|
{
|
||||||
"model_name": "my-context-window-model",
|
"model_name": "my-context-window-model",
|
||||||
"litellm_params": {
|
"litellm_params": {
|
||||||
"model": "claude-2",
|
"model": "claude-2.1",
|
||||||
"api_key": "",
|
"api_key": "",
|
||||||
"mock_response": Exception("Should not have called this."),
|
"mock_response": Exception("Should not have called this."),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
content_policy_fallbacks=(
|
content_policy_fallbacks=(
|
||||||
[{"claude-2": ["my-fallback-model"]}]
|
[{"claude-2.1": ["my-fallback-model"]}]
|
||||||
if fallback_type == "model-specific"
|
if fallback_type == "model-specific"
|
||||||
else None
|
else None
|
||||||
),
|
),
|
||||||
|
@ -1190,12 +1190,12 @@ async def test_router_content_policy_fallbacks(
|
||||||
|
|
||||||
if sync_mode is True:
|
if sync_mode is True:
|
||||||
response = router.completion(
|
response = router.completion(
|
||||||
model="claude-2",
|
model="claude-2.1",
|
||||||
messages=[{"role": "user", "content": "Hey, how's it going?"}],
|
messages=[{"role": "user", "content": "Hey, how's it going?"}],
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
response = await router.acompletion(
|
response = await router.acompletion(
|
||||||
model="claude-2",
|
model="claude-2.1",
|
||||||
messages=[{"role": "user", "content": "Hey, how's it going?"}],
|
messages=[{"role": "user", "content": "Hey, how's it going?"}],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -1455,3 +1455,46 @@ async def test_router_fallbacks_default_and_model_specific_fallbacks(sync_mode):
|
||||||
assert isinstance(
|
assert isinstance(
|
||||||
exc_info.value, litellm.AuthenticationError
|
exc_info.value, litellm.AuthenticationError
|
||||||
), f"Expected AuthenticationError, but got {type(exc_info.value).__name__}"
|
), f"Expected AuthenticationError, but got {type(exc_info.value).__name__}"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_router_disable_fallbacks_dynamically():
|
||||||
|
from litellm.router import run_async_fallback
|
||||||
|
|
||||||
|
router = Router(
|
||||||
|
model_list=[
|
||||||
|
{
|
||||||
|
"model_name": "bad-model",
|
||||||
|
"litellm_params": {
|
||||||
|
"model": "openai/my-bad-model",
|
||||||
|
"api_key": "my-bad-api-key",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"model_name": "good-model",
|
||||||
|
"litellm_params": {
|
||||||
|
"model": "gpt-4o",
|
||||||
|
"api_key": os.getenv("OPENAI_API_KEY"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
fallbacks=[{"bad-model": ["good-model"]}],
|
||||||
|
default_fallbacks=["good-model"],
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
router,
|
||||||
|
"log_retry",
|
||||||
|
new=MagicMock(return_value=None),
|
||||||
|
) as mock_client:
|
||||||
|
try:
|
||||||
|
resp = await router.acompletion(
|
||||||
|
model="bad-model",
|
||||||
|
messages=[{"role": "user", "content": "Hey, how's it going?"}],
|
||||||
|
disable_fallbacks=True,
|
||||||
|
)
|
||||||
|
print(resp)
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
|
|
||||||
|
mock_client.assert_not_called()
|
||||||
|
|
|
@ -14,6 +14,7 @@ from litellm.router import Deployment, LiteLLM_Params, ModelInfo
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
from unittest.mock import patch, MagicMock, AsyncMock
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
|
@ -83,3 +84,93 @@ def test_returned_settings():
|
||||||
except Exception:
|
except Exception:
|
||||||
print(traceback.format_exc())
|
print(traceback.format_exc())
|
||||||
pytest.fail("An error occurred - " + traceback.format_exc())
|
pytest.fail("An error occurred - " + traceback.format_exc())
|
||||||
|
|
||||||
|
|
||||||
|
from litellm.types.utils import CallTypes
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_kwargs_before_fallbacks_unit_test():
|
||||||
|
router = Router(
|
||||||
|
model_list=[
|
||||||
|
{
|
||||||
|
"model_name": "gpt-3.5-turbo",
|
||||||
|
"litellm_params": {
|
||||||
|
"model": "azure/chatgpt-v-2",
|
||||||
|
"api_key": "bad-key",
|
||||||
|
"api_version": os.getenv("AZURE_API_VERSION"),
|
||||||
|
"api_base": os.getenv("AZURE_API_BASE"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
kwargs = {"messages": [{"role": "user", "content": "write 1 sentence poem"}]}
|
||||||
|
|
||||||
|
router._update_kwargs_before_fallbacks(
|
||||||
|
model="gpt-3.5-turbo",
|
||||||
|
kwargs=kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert kwargs["litellm_trace_id"] is not None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"call_type",
|
||||||
|
[
|
||||||
|
CallTypes.acompletion,
|
||||||
|
CallTypes.atext_completion,
|
||||||
|
CallTypes.aembedding,
|
||||||
|
CallTypes.arerank,
|
||||||
|
CallTypes.atranscription,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_kwargs_before_fallbacks(call_type):
|
||||||
|
|
||||||
|
router = Router(
|
||||||
|
model_list=[
|
||||||
|
{
|
||||||
|
"model_name": "gpt-3.5-turbo",
|
||||||
|
"litellm_params": {
|
||||||
|
"model": "azure/chatgpt-v-2",
|
||||||
|
"api_key": "bad-key",
|
||||||
|
"api_version": os.getenv("AZURE_API_VERSION"),
|
||||||
|
"api_base": os.getenv("AZURE_API_BASE"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
if call_type.value.startswith("a"):
|
||||||
|
with patch.object(router, "async_function_with_fallbacks") as mock_client:
|
||||||
|
if call_type.value == "acompletion":
|
||||||
|
input_kwarg = {
|
||||||
|
"messages": [{"role": "user", "content": "Hello, how are you?"}],
|
||||||
|
}
|
||||||
|
elif (
|
||||||
|
call_type.value == "atext_completion"
|
||||||
|
or call_type.value == "aimage_generation"
|
||||||
|
):
|
||||||
|
input_kwarg = {
|
||||||
|
"prompt": "Hello, how are you?",
|
||||||
|
}
|
||||||
|
elif call_type.value == "aembedding" or call_type.value == "arerank":
|
||||||
|
input_kwarg = {
|
||||||
|
"input": "Hello, how are you?",
|
||||||
|
}
|
||||||
|
elif call_type.value == "atranscription":
|
||||||
|
input_kwarg = {
|
||||||
|
"file": "path/to/file",
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
input_kwarg = {}
|
||||||
|
|
||||||
|
await getattr(router, call_type.value)(
|
||||||
|
model="gpt-3.5-turbo",
|
||||||
|
**input_kwarg,
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_client.assert_called_once()
|
||||||
|
|
||||||
|
print(mock_client.call_args.kwargs)
|
||||||
|
assert mock_client.call_args.kwargs["litellm_trace_id"] is not None
|
||||||
|
|
|
@ -15,22 +15,29 @@ sys.path.insert(
|
||||||
0, os.path.abspath("../..")
|
0, os.path.abspath("../..")
|
||||||
) # Adds the parent directory to the system path
|
) # Adds the parent directory to the system path
|
||||||
import pytest
|
import pytest
|
||||||
|
import litellm
|
||||||
from litellm.llms.AzureOpenAI.azure import get_azure_ad_token_from_oidc
|
from litellm.llms.AzureOpenAI.azure import get_azure_ad_token_from_oidc
|
||||||
from litellm.llms.bedrock.chat import BedrockConverseLLM, BedrockLLM
|
from litellm.llms.bedrock.chat import BedrockConverseLLM, BedrockLLM
|
||||||
from litellm.secret_managers.aws_secret_manager import load_aws_secret_manager
|
from litellm.secret_managers.aws_secret_manager_v2 import AWSSecretsManagerV2
|
||||||
from litellm.secret_managers.main import get_secret
|
from litellm.secret_managers.main import (
|
||||||
|
get_secret,
|
||||||
|
_should_read_secret_from_secret_manager,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip(reason="AWS Suspended Account")
|
|
||||||
def test_aws_secret_manager():
|
def test_aws_secret_manager():
|
||||||
load_aws_secret_manager(use_aws_secret_manager=True)
|
import json
|
||||||
|
|
||||||
|
AWSSecretsManagerV2.load_aws_secret_manager(use_aws_secret_manager=True)
|
||||||
|
|
||||||
secret_val = get_secret("litellm_master_key")
|
secret_val = get_secret("litellm_master_key")
|
||||||
|
|
||||||
print(f"secret_val: {secret_val}")
|
print(f"secret_val: {secret_val}")
|
||||||
|
|
||||||
assert secret_val == "sk-1234"
|
# cast json to dict
|
||||||
|
secret_val = json.loads(secret_val)
|
||||||
|
|
||||||
|
assert secret_val["litellm_master_key"] == "sk-1234"
|
||||||
|
|
||||||
|
|
||||||
def redact_oidc_signature(secret_val):
|
def redact_oidc_signature(secret_val):
|
||||||
|
@ -240,3 +247,71 @@ def test_google_secret_manager_read_in_memory():
|
||||||
)
|
)
|
||||||
print("secret_val: {}".format(secret_val))
|
print("secret_val: {}".format(secret_val))
|
||||||
assert secret_val == "lite-llm"
|
assert secret_val == "lite-llm"
|
||||||
|
|
||||||
|
|
||||||
|
def test_should_read_secret_from_secret_manager():
|
||||||
|
"""
|
||||||
|
Test that _should_read_secret_from_secret_manager returns correct values based on access mode
|
||||||
|
"""
|
||||||
|
from litellm.proxy._types import KeyManagementSettings
|
||||||
|
|
||||||
|
# Test when secret manager client is None
|
||||||
|
litellm.secret_manager_client = None
|
||||||
|
litellm._key_management_settings = KeyManagementSettings()
|
||||||
|
assert _should_read_secret_from_secret_manager() is False
|
||||||
|
|
||||||
|
# Test with secret manager client and read_only access
|
||||||
|
litellm.secret_manager_client = "dummy_client"
|
||||||
|
litellm._key_management_settings = KeyManagementSettings(access_mode="read_only")
|
||||||
|
assert _should_read_secret_from_secret_manager() is True
|
||||||
|
|
||||||
|
# Test with secret manager client and read_and_write access
|
||||||
|
litellm._key_management_settings = KeyManagementSettings(
|
||||||
|
access_mode="read_and_write"
|
||||||
|
)
|
||||||
|
assert _should_read_secret_from_secret_manager() is True
|
||||||
|
|
||||||
|
# Test with secret manager client and write_only access
|
||||||
|
litellm._key_management_settings = KeyManagementSettings(access_mode="write_only")
|
||||||
|
assert _should_read_secret_from_secret_manager() is False
|
||||||
|
|
||||||
|
# Reset global variables
|
||||||
|
litellm.secret_manager_client = None
|
||||||
|
litellm._key_management_settings = KeyManagementSettings()
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_secret_with_access_mode():
|
||||||
|
"""
|
||||||
|
Test that get_secret respects access mode settings
|
||||||
|
"""
|
||||||
|
from litellm.proxy._types import KeyManagementSettings
|
||||||
|
|
||||||
|
# Set up test environment
|
||||||
|
test_secret_name = "TEST_SECRET_KEY"
|
||||||
|
test_secret_value = "test_secret_value"
|
||||||
|
os.environ[test_secret_name] = test_secret_value
|
||||||
|
|
||||||
|
# Test with write_only access (should read from os.environ)
|
||||||
|
litellm.secret_manager_client = "dummy_client"
|
||||||
|
litellm._key_management_settings = KeyManagementSettings(access_mode="write_only")
|
||||||
|
assert get_secret(test_secret_name) == test_secret_value
|
||||||
|
|
||||||
|
# Test with no KeyManagementSettings but secret_manager_client set
|
||||||
|
litellm.secret_manager_client = "dummy_client"
|
||||||
|
litellm._key_management_settings = KeyManagementSettings()
|
||||||
|
assert _should_read_secret_from_secret_manager() is True
|
||||||
|
|
||||||
|
# Test with read_only access
|
||||||
|
litellm._key_management_settings = KeyManagementSettings(access_mode="read_only")
|
||||||
|
assert _should_read_secret_from_secret_manager() is True
|
||||||
|
|
||||||
|
# Test with read_and_write access
|
||||||
|
litellm._key_management_settings = KeyManagementSettings(
|
||||||
|
access_mode="read_and_write"
|
||||||
|
)
|
||||||
|
assert _should_read_secret_from_secret_manager() is True
|
||||||
|
|
||||||
|
# Reset global variables
|
||||||
|
litellm.secret_manager_client = None
|
||||||
|
litellm._key_management_settings = KeyManagementSettings()
|
||||||
|
del os.environ[test_secret_name]
|
||||||
|
|
|
@ -184,12 +184,11 @@ def test_stream_chunk_builder_litellm_usage_chunks():
|
||||||
{"role": "assistant", "content": "uhhhh\n\n\nhmmmm.....\nthinking....\n"},
|
{"role": "assistant", "content": "uhhhh\n\n\nhmmmm.....\nthinking....\n"},
|
||||||
{"role": "user", "content": "\nI am waiting...\n\n...\n"},
|
{"role": "user", "content": "\nI am waiting...\n\n...\n"},
|
||||||
]
|
]
|
||||||
# make a regular gemini call
|
|
||||||
|
|
||||||
usage: litellm.Usage = Usage(
|
usage: litellm.Usage = Usage(
|
||||||
completion_tokens=64,
|
completion_tokens=27,
|
||||||
prompt_tokens=55,
|
prompt_tokens=55,
|
||||||
total_tokens=119,
|
total_tokens=82,
|
||||||
completion_tokens_details=None,
|
completion_tokens_details=None,
|
||||||
prompt_tokens_details=None,
|
prompt_tokens_details=None,
|
||||||
)
|
)
|
||||||
|
|
|
@ -718,7 +718,7 @@ async def test_acompletion_claude_2_stream():
|
||||||
try:
|
try:
|
||||||
litellm.set_verbose = True
|
litellm.set_verbose = True
|
||||||
response = await litellm.acompletion(
|
response = await litellm.acompletion(
|
||||||
model="claude-2",
|
model="claude-2.1",
|
||||||
messages=[{"role": "user", "content": "hello from litellm"}],
|
messages=[{"role": "user", "content": "hello from litellm"}],
|
||||||
stream=True,
|
stream=True,
|
||||||
)
|
)
|
||||||
|
@ -3274,7 +3274,7 @@ def test_completion_claude_3_function_call_with_streaming():
|
||||||
], # "claude-3-opus-20240229"
|
], # "claude-3-opus-20240229"
|
||||||
) #
|
) #
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_acompletion_claude_3_function_call_with_streaming(model):
|
async def test_acompletion_function_call_with_streaming(model):
|
||||||
litellm.set_verbose = True
|
litellm.set_verbose = True
|
||||||
tools = [
|
tools = [
|
||||||
{
|
{
|
||||||
|
@ -3335,6 +3335,8 @@ async def test_acompletion_claude_3_function_call_with_streaming(model):
|
||||||
# raise Exception("it worked! ")
|
# raise Exception("it worked! ")
|
||||||
except litellm.InternalServerError as e:
|
except litellm.InternalServerError as e:
|
||||||
pytest.skip(f"InternalServerError - {str(e)}")
|
pytest.skip(f"InternalServerError - {str(e)}")
|
||||||
|
except litellm.ServiceUnavailableError:
|
||||||
|
pass
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
pytest.fail(f"Error occurred: {e}")
|
pytest.fail(f"Error occurred: {e}")
|
||||||
|
|
||||||
|
|
|
@ -3451,3 +3451,90 @@ async def test_user_api_key_auth_db_unavailable_not_allowed():
|
||||||
request=request,
|
request=request,
|
||||||
api_key="Bearer sk-123456789",
|
api_key="Bearer sk-123456789",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
## E2E Virtual Key + Secret Manager Tests #########################################
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_key_generate_with_secret_manager_call(prisma_client):
|
||||||
|
"""
|
||||||
|
Generate a key
|
||||||
|
assert it exists in the secret manager
|
||||||
|
|
||||||
|
delete the key
|
||||||
|
assert it is deleted from the secret manager
|
||||||
|
"""
|
||||||
|
from litellm.secret_managers.aws_secret_manager_v2 import AWSSecretsManagerV2
|
||||||
|
from litellm.proxy._types import KeyManagementSystem, KeyManagementSettings
|
||||||
|
|
||||||
|
litellm.set_verbose = True
|
||||||
|
|
||||||
|
#### Test Setup ############################################################
|
||||||
|
aws_secret_manager_client = AWSSecretsManagerV2()
|
||||||
|
litellm.secret_manager_client = aws_secret_manager_client
|
||||||
|
litellm._key_management_system = KeyManagementSystem.AWS_SECRET_MANAGER
|
||||||
|
litellm._key_management_settings = KeyManagementSettings(
|
||||||
|
store_virtual_keys=True,
|
||||||
|
)
|
||||||
|
general_settings = {
|
||||||
|
"key_management_system": "aws_secret_manager",
|
||||||
|
"key_management_settings": {
|
||||||
|
"store_virtual_keys": True,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
setattr(litellm.proxy.proxy_server, "general_settings", general_settings)
|
||||||
|
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
|
||||||
|
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
|
||||||
|
await litellm.proxy.proxy_server.prisma_client.connect()
|
||||||
|
############################################################################
|
||||||
|
|
||||||
|
# generate new key
|
||||||
|
key_alias = f"test_alias_secret_manager_key-{uuid.uuid4()}"
|
||||||
|
spend = 100
|
||||||
|
max_budget = 400
|
||||||
|
models = ["fake-openai-endpoint"]
|
||||||
|
new_key = await generate_key_fn(
|
||||||
|
data=GenerateKeyRequest(
|
||||||
|
key_alias=key_alias, spend=spend, max_budget=max_budget, models=models
|
||||||
|
),
|
||||||
|
user_api_key_dict=UserAPIKeyAuth(
|
||||||
|
user_role=LitellmUserRoles.PROXY_ADMIN,
|
||||||
|
api_key="sk-1234",
|
||||||
|
user_id="1234",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
generated_key = new_key.key
|
||||||
|
print(generated_key)
|
||||||
|
|
||||||
|
await asyncio.sleep(2)
|
||||||
|
|
||||||
|
# read from the secret manager
|
||||||
|
result = await aws_secret_manager_client.async_read_secret(secret_name=key_alias)
|
||||||
|
|
||||||
|
# Assert the correct key is stored in the secret manager
|
||||||
|
print("response from AWS Secret Manager")
|
||||||
|
print(result)
|
||||||
|
assert result == generated_key
|
||||||
|
|
||||||
|
# delete the key
|
||||||
|
await delete_key_fn(
|
||||||
|
data=KeyRequest(keys=[generated_key]),
|
||||||
|
user_api_key_dict=UserAPIKeyAuth(
|
||||||
|
user_role=LitellmUserRoles.PROXY_ADMIN, api_key="sk-1234", user_id="1234"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
await asyncio.sleep(2)
|
||||||
|
|
||||||
|
# Assert the key is deleted from the secret manager
|
||||||
|
result = await aws_secret_manager_client.async_read_secret(secret_name=key_alias)
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
# cleanup
|
||||||
|
setattr(litellm.proxy.proxy_server, "general_settings", {})
|
||||||
|
|
||||||
|
|
||||||
|
################################################################################
|
||||||
|
|
|
@ -1500,6 +1500,31 @@ async def test_add_callback_via_key_litellm_pre_call_utils(
|
||||||
assert new_data["failure_callback"] == expected_failure_callbacks
|
assert new_data["failure_callback"] == expected_failure_callbacks
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"disable_fallbacks_set",
|
||||||
|
[
|
||||||
|
True,
|
||||||
|
False,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
async def test_disable_fallbacks_by_key(disable_fallbacks_set):
|
||||||
|
from litellm.proxy.litellm_pre_call_utils import LiteLLMProxyRequestSetup
|
||||||
|
|
||||||
|
key_metadata = {"disable_fallbacks": disable_fallbacks_set}
|
||||||
|
existing_data = {
|
||||||
|
"model": "azure/chatgpt-v-2",
|
||||||
|
"messages": [{"role": "user", "content": "write 1 sentence poem"}],
|
||||||
|
}
|
||||||
|
data = LiteLLMProxyRequestSetup.add_key_level_controls(
|
||||||
|
key_metadata=key_metadata,
|
||||||
|
data=existing_data,
|
||||||
|
_metadata_variable_name="metadata",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert data["disable_fallbacks"] == disable_fallbacks_set
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"callback_type, expected_success_callbacks, expected_failure_callbacks",
|
"callback_type, expected_success_callbacks, expected_failure_callbacks",
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue