Merge branch 'main' into litellm-fix-vertexaibeta

This commit is contained in:
Tiger Yu 2024-07-02 09:49:44 -07:00 committed by GitHub
commit 26630cd263
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
73 changed files with 3482 additions and 782 deletions

View file

@ -66,7 +66,7 @@ jobs:
pip install "pydantic==2.7.1"
pip install "diskcache==5.6.1"
pip install "Pillow==10.3.0"
pip install "ijson==3.2.3"
pip install "jsonschema==4.22.0"
- save_cache:
paths:
- ./venv
@ -128,7 +128,7 @@ jobs:
pip install jinja2
pip install tokenizers
pip install openai
pip install ijson
pip install jsonschema
- run:
name: Run tests
command: |
@ -183,7 +183,7 @@ jobs:
pip install numpydoc
pip install prisma
pip install fastapi
pip install ijson
pip install jsonschema
pip install "httpx==0.24.1"
pip install "gunicorn==21.2.0"
pip install "anyio==3.7.1"
@ -212,6 +212,7 @@ jobs:
-e AWS_REGION_NAME=$AWS_REGION_NAME \
-e AUTO_INFER_REGION=True \
-e OPENAI_API_KEY=$OPENAI_API_KEY \
-e LITELLM_LICENSE=$LITELLM_LICENSE \
-e LANGFUSE_PROJECT1_PUBLIC=$LANGFUSE_PROJECT1_PUBLIC \
-e LANGFUSE_PROJECT2_PUBLIC=$LANGFUSE_PROJECT2_PUBLIC \
-e LANGFUSE_PROJECT1_SECRET=$LANGFUSE_PROJECT1_SECRET \

View file

@ -50,7 +50,7 @@ Use `litellm.get_supported_openai_params()` for an updated list of params for ea
|Huggingface| ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | | |
|Openrouter| ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | | | | ✅ | | | | |
|AI21| ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | |
|VertexAI| ✅ | ✅ | | ✅ | ✅ | | | | | | | | | | ✅ | | |
|Bedrock| ✅ | ✅ | ✅ | ✅ | ✅ | | | | | | | | | | ✅ (for anthropic) | |
|Sagemaker| ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | | |
|TogetherAI| ✅ | ✅ | ✅ | ✅ | ✅ | | | | | | ✅ |
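
For a quick programmatic check, here is a minimal sketch using `litellm.get_supported_openai_params()` (mentioned above); the exact return value depends on your installed litellm version:

```python
import litellm

# Check which OpenAI-compatible params litellm will map for a given provider/model.
params = litellm.get_supported_openai_params(
    model="gemini-1.5-pro", custom_llm_provider="vertex_ai_beta"
)
print(params)  # e.g. ["temperature", "max_tokens", "response_format", ...]
```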

View file

@ -2,26 +2,39 @@
For companies that need SSO, user management and professional support for LiteLLM Proxy
:::info
Interested in Enterprise? Schedule a meeting with us here 👉
[Talk to founders](https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat)
:::
This covers:
- ✅ **Features under the [LiteLLM Commercial License (Content Mod, Custom Tags, etc.)](https://docs.litellm.ai/docs/proxy/enterprise)**
- ✅ [**Secure UI access with Single Sign-On**](../docs/proxy/ui.md#setup-ssoauth-for-ui)
- ✅ [**Audit Logs with retention policy**](../docs/proxy/enterprise.md#audit-logs)
- ✅ [**JWT-Auth**](../docs/proxy/token_auth.md)
- ✅ [**Control available public, private routes**](../docs/proxy/enterprise.md#control-available-public-private-routes)
- ✅ [**Guardrails, Content Moderation, PII Masking, Secret/API Key Masking**](../docs/proxy/enterprise.md#prompt-injection-detection---lakeraai)
- ✅ [**Prompt Injection Detection**](../docs/proxy/enterprise.md#prompt-injection-detection---lakeraai)
- ✅ [**Invite Team Members to access `/spend` Routes**](../docs/proxy/cost_tracking#allowing-non-proxy-admins-to-access-spend-endpoints)
- **Enterprise Features**
- **Security**
- ✅ [SSO for Admin UI](./proxy/ui#✨-enterprise-features)
- ✅ [Audit Logs with retention policy](./proxy/enterprise#audit-logs)
- ✅ [JWT-Auth](../docs/proxy/token_auth.md)
- ✅ [Control available public, private routes](./proxy/enterprise#control-available-public-private-routes)
- ✅ [[BETA] AWS Key Manager v2 - Key Decryption](./proxy/enterprise#beta-aws-key-manager---key-decryption)
- ✅ [Use LiteLLM keys/authentication on Pass Through Endpoints](./proxy/pass_through#✨-enterprise---use-litellm-keysauthentication-on-pass-through-endpoints)
- ✅ [Enforce Required Params for LLM Requests (ex. Reject requests missing ["metadata"]["generation_name"])](./proxy/enterprise#enforce-required-params-for-llm-requests)
- **Spend Tracking**
- ✅ [Tracking Spend for Custom Tags](./proxy/enterprise#tracking-spend-for-custom-tags)
- ✅ [API Endpoints to get Spend Reports per Team, API Key, Customer](./proxy/cost_tracking.md#✨-enterprise-api-endpoints-to-get-spend)
- **Advanced Metrics**
- ✅ [`x-ratelimit-remaining-requests`, `x-ratelimit-remaining-tokens` for LLM APIs on Prometheus](./proxy/prometheus#✨-enterprise-llm-remaining-requests-and-remaining-tokens)
- **Guardrails, PII Masking, Content Moderation**
- ✅ [Content Moderation with LLM Guard, LlamaGuard, Secret Detection, Google Text Moderations](./proxy/enterprise#content-moderation)
- ✅ [Prompt Injection Detection (with LakeraAI API)](./proxy/enterprise#prompt-injection-detection---lakeraai)
- ✅ Reject calls from Blocked User list
- ✅ Reject calls (incoming / outgoing) with Banned Keywords (e.g. competitors)
- **Custom Branding**
- ✅ [Custom Branding + Routes on Swagger Docs](./proxy/enterprise#swagger-docs---custom-routes--branding)
- ✅ [Public Model Hub](../docs/proxy/enterprise.md#public-model-hub)
- ✅ [Custom Email Branding](../docs/proxy/email.md#customizing-email-branding)
- ✅ **Feature Prioritization**
- ✅ **Custom Integrations**
- ✅ **Professional Support - Dedicated discord + slack**
- ✅ [**Custom Swagger**](../docs/proxy/enterprise.md#swagger-docs---custom-routes--branding)
- ✅ [**Public Model Hub**](../docs/proxy/enterprise.md#public-model-hub)
- ✅ [**Custom Email Branding**](../docs/proxy/email.md#customizing-email-branding)

View file

@ -168,8 +168,12 @@ print(response)
## Supported Models
`Model Name` 👉 Human-friendly name.
`Function Call` 👉 How to call the model in LiteLLM.
| Model Name | Function Call | Required OS Variables |
|------------------|--------------------------------------------|------------------------------------|
| claude-3-5-sonnet | `completion('claude-3-5-sonnet-20240620', messages)` | `os.environ['ANTHROPIC_API_KEY']` |
| claude-3-haiku | `completion('claude-3-haiku-20240307', messages)` | `os.environ['ANTHROPIC_API_KEY']` |
| claude-3-opus | `completion('claude-3-opus-20240229', messages)` | `os.environ['ANTHROPIC_API_KEY']` |
| claude-3-5-sonnet-20240620 | `completion('claude-3-5-sonnet-20240620', messages)` | `os.environ['ANTHROPIC_API_KEY']` |
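
For reference, a minimal call using the newly added model name (assumes `ANTHROPIC_API_KEY` is set in your environment):

```python
import os
from litellm import completion

os.environ["ANTHROPIC_API_KEY"] = "your-anthropic-key"  # placeholder

response = completion(
    model="claude-3-5-sonnet-20240620",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
)
print(response.choices[0].message.content)
```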

View file

@ -14,7 +14,7 @@ LiteLLM supports all models on Azure AI Studio
### ENV VAR
```python
import os
os.environ["AZURE_API_API_KEY"] = ""
os.environ["AZURE_AI_API_KEY"] = ""
os.environ["AZURE_AI_API_BASE"] = ""
```
@ -24,7 +24,7 @@ os.environ["AZURE_AI_API_BASE"] = ""
from litellm import completion
import os
## set ENV variables
os.environ["AZURE_API_API_KEY"] = "azure ai key"
os.environ["AZURE_AI_API_KEY"] = "azure ai key"
os.environ["AZURE_AI_API_BASE"] = "azure ai base url" # e.g.: https://Mistral-large-dfgfj-serverless.eastus2.inference.ai.azure.com/
# predibase llama-3 call

View file

@ -549,6 +549,10 @@ response = completion(
This is a deprecated flow. Boto3 is not async, and boto3.client does not let us make the HTTP call through httpx. Pass in your AWS params through the method above 👆. [See Auth Code](https://github.com/BerriAI/litellm/blob/55a20c7cce99a93d36a82bf3ae90ba3baf9a7f89/litellm/llms/bedrock_httpx.py#L284) [Add new auth flow](https://github.com/BerriAI/litellm/issues)
Experimental - 2024-Jun-23:
`aws_access_key_id`, `aws_secret_access_key`, and `aws_session_token` will be extracted from boto3.client and passed into the httpx client
:::
Pass an external BedrockRuntime.Client object as a parameter to litellm.completion. Useful when using an AWS credentials profile, SSO session, assumed role session, or if environment variables are not available for auth.
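
A minimal sketch of that deprecated flow, assuming the `aws_bedrock_client` parameter is still accepted by your litellm version (how you authenticate the boto3 client - profile, SSO, assumed role - is up to you):

```python
import boto3
from litellm import completion

# Create a BedrockRuntime client however you authenticate (profile, SSO, assumed role, ...)
bedrock = boto3.client(
    service_name="bedrock-runtime",
    region_name="us-east-1",
)

response = completion(
    model="bedrock/anthropic.claude-3-haiku-20240307-v1:0",
    messages=[{"role": "user", "content": "Hello, how are you?"}],
    aws_bedrock_client=bedrock,  # deprecated flow described above
)
print(response.choices[0].message.content)
```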

View file

@ -18,7 +18,7 @@ import litellm
import os
response = litellm.completion(
model="openai/mistral, # add `openai/` prefix to model so litellm knows to route to OpenAI
model="openai/mistral", # add `openai/` prefix to model so litellm knows to route to OpenAI
api_key="sk-1234", # api key to your openai compatible endpoint
api_base="http://0.0.0.0:4000", # set API Base of your Custom OpenAI Endpoint
messages=[

View file

@ -123,6 +123,182 @@ print(completion(**data))
### **JSON Schema**
From v`1.40.1+` LiteLLM supports sending `response_schema` as a param for Gemini-1.5-Pro on Vertex AI. For other models (e.g. `gemini-1.5-flash` or `claude-3-5-sonnet`), LiteLLM adds the schema to the message list with a user-controlled prompt.
**Response Schema**
<Tabs>
<TabItem value="sdk" label="SDK">
```python
from litellm import completion
import json
## SETUP ENVIRONMENT
# !gcloud auth application-default login - run this to add vertex credentials to your env
messages = [
{
"role": "user",
"content": "List 5 popular cookie recipes."
}
]
response_schema = {
"type": "array",
"items": {
"type": "object",
"properties": {
"recipe_name": {
"type": "string",
},
},
"required": ["recipe_name"],
},
}
response = completion(
    model="vertex_ai_beta/gemini-1.5-pro",
    messages=messages,
    response_format={"type": "json_object", "response_schema": response_schema} # 👈 KEY CHANGE
)

print(json.loads(response.choices[0].message.content))
```
</TabItem>
<TabItem value="proxy" label="PROXY">
1. Add model to config.yaml
```yaml
model_list:
- model_name: gemini-pro
litellm_params:
model: vertex_ai_beta/gemini-1.5-pro
vertex_project: "project-id"
vertex_location: "us-central1"
vertex_credentials: "/path/to/service_account.json" # [OPTIONAL] Do this OR `!gcloud auth application-default login` - run this to add vertex credentials to your env
```
2. Start Proxy
```
$ litellm --config /path/to/config.yaml
```
3. Make Request!
```bash
curl -X POST 'http://0.0.0.0:4000/chat/completions' \
-H 'Content-Type: application/json' \
-H 'Authorization: Bearer sk-1234' \
-d '{
  "model": "gemini-pro",
  "messages": [
    {"role": "user", "content": "List 5 popular cookie recipes."}
  ],
  "response_format": {"type": "json_object", "response_schema": {
    "type": "array",
    "items": {
      "type": "object",
      "properties": {
        "recipe_name": {
          "type": "string"
        }
      },
      "required": ["recipe_name"]
    }
  }}
}'
```
</TabItem>
</Tabs>
**Validate Schema**
To validate the response_schema, set `enforce_validation: true`.
<Tabs>
<TabItem value="sdk" label="SDK">
```python
from litellm import completion, JSONSchemaValidationError
try:
completion(
model="vertex_ai_beta/gemini-1.5-pro",
messages=messages,
response_format={
"type": "json_object",
"response_schema": response_schema,
"enforce_validation": true # 👈 KEY CHANGE
}
)
except JSONSchemaValidationError as e:
print("Raw Response: {}".format(e.raw_response))
raise e
```
</TabItem>
<TabItem value="proxy" label="PROXY">
1. Add model to config.yaml
```yaml
model_list:
- model_name: gemini-pro
litellm_params:
model: vertex_ai_beta/gemini-1.5-pro
vertex_project: "project-id"
vertex_location: "us-central1"
vertex_credentials: "/path/to/service_account.json" # [OPTIONAL] Do this OR `!gcloud auth application-default login` - run this to add vertex credentials to your env
```
2. Start Proxy
```
$ litellm --config /path/to/config.yaml
```
3. Make Request!
```bash
curl -X POST 'http://0.0.0.0:4000/chat/completions' \
-H 'Content-Type: application/json' \
-H 'Authorization: Bearer sk-1234' \
-d '{
  "model": "gemini-pro",
  "messages": [
    {"role": "user", "content": "List 5 popular cookie recipes."}
  ],
  "response_format": {
    "type": "json_object",
    "response_schema": {
      "type": "array",
      "items": {
        "type": "object",
        "properties": {
          "recipe_name": {
            "type": "string"
          }
        },
        "required": ["recipe_name"]
      }
    },
    "enforce_validation": true
  }
}'
```
</TabItem>
</Tabs>
LiteLLM will validate the response against the schema, and raise a `JSONSchemaValidationError` if the response does not match the schema.
`JSONSchemaValidationError` inherits from `openai.APIError`.
Access the raw response with `e.raw_response`.
**Add to prompt yourself**
```python
from litellm import completion
@ -645,6 +821,86 @@ assert isinstance(
```
## Usage - PDF / Videos / etc. Files
Pass any file supported by Vertex AI, through LiteLLM.
<Tabs>
<TabItem value="sdk" label="SDK">
```python
from litellm import completion
response = completion(
model="vertex_ai/gemini-1.5-flash",
messages=[
{
"role": "user",
"content": [
{"type": "text", "text": "You are a very professional document summarization specialist. Please summarize the given document."},
{
"type": "image_url",
"image_url": "gs://cloud-samples-data/generative-ai/pdf/2403.05530.pdf",
},
],
}
],
max_tokens=300,
)
print(response.choices[0])
```
</TabItem>
<TabItem value="proxy" lable="PROXY">
1. Add model to config
```yaml
- model_name: gemini-1.5-flash
litellm_params:
model: vertex_ai/gemini-1.5-flash
vertex_credentials: "/path/to/service_account.json"
```
2. Start Proxy
```
litellm --config /path/to/config.yaml
```
3. Test it!
```bash
curl http://0.0.0.0:4000/v1/chat/completions \
-H "Content-Type: application/json" \
-H "Authorization: Bearer <YOUR-LITELLM-KEY>" \
-d '{
  "model": "gemini-1.5-flash",
  "messages": [
    {
      "role": "user",
      "content": [
        {
          "type": "text",
          "text": "You are a very professional document summarization specialist. Please summarize the given document"
        },
        {
          "type": "image_url",
          "image_url": "gs://cloud-samples-data/generative-ai/pdf/2403.05530.pdf"
        }
      ]
    }
  ],
  "max_tokens": 300
}'
```
</TabItem>
</Tabs>
## Chat Models
| Model Name | Function Call |
|------------------|--------------------------------------|

View file

@ -277,6 +277,54 @@ curl --location 'http://0.0.0.0:4000/v1/model/info' \
--data ''
```
## Wildcard Model Name (Add ALL MODELS from env)
Dynamically call any model from any given provider without the need to predefine it in the config YAML file. As long as the relevant keys are in the environment (see [providers list](../providers/)), LiteLLM will make the call correctly.
1. Setup config.yaml
```
model_list:
- model_name: "*" # all requests where model not in your config go to this deployment
litellm_params:
model: "openai/*" # passes our validation check that a real provider is given
```
2. Start LiteLLM proxy
```
litellm --config /path/to/config.yaml
```
3. Try claude-3-5-sonnet from Anthropic
```bash
curl -X POST 'http://0.0.0.0:4000/chat/completions' \
-H 'Content-Type: application/json' \
-H 'Authorization: Bearer sk-1234' \
-d '{
"model": "claude-3-5-sonnet-20240620",
"messages": [
{"role": "user", "content": "Hey, how'\''s it going?"},
{
"role": "assistant",
"content": "I'\''m doing well. Would like to hear the rest of the story?"
},
{"role": "user", "content": "Na"},
{
"role": "assistant",
"content": "No problem, is there anything else i can help you with today?"
},
{
"role": "user",
"content": "I think you'\''re getting cut off sometimes"
}
]
}
'
```
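You can also hit the wildcard route from the OpenAI SDK pointed at the proxy; a minimal sketch (key and model are placeholders):

```python
import openai

# Point the OpenAI client at the LiteLLM proxy; any model not in config.yaml
# falls through to the "*" deployment as long as provider keys are in the env.
client = openai.OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")

response = client.chat.completions.create(
    model="claude-3-5-sonnet-20240620",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
)
print(response.choices[0].message.content)
```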
## Load Balancing
:::info

View file

@ -117,6 +117,8 @@ That's IT. Now Verify your spend was tracked
<Tabs>
<TabItem value="curl" label="Response Headers">
Expect to see `x-litellm-response-cost` in the response headers with the calculated cost
<Image img={require('../../img/response_cost_img.png')} />
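
A minimal sketch of reading that header from a raw HTTP call to the proxy (endpoint, key, and model are placeholders):

```python
import requests

# Call the proxy directly and inspect the response headers for the calculated cost.
resp = requests.post(
    "http://0.0.0.0:4000/chat/completions",
    headers={"Authorization": "Bearer sk-1234", "Content-Type": "application/json"},
    json={
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "Hi"}],
    },
)
print(resp.headers.get("x-litellm-response-cost"))
```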
</TabItem>
@ -145,16 +147,16 @@ Navigate to the Usage Tab on the LiteLLM UI (found on https://your-proxy-endpoin
<Image img={require('../../img/admin_ui_spend.png')} />
## API Endpoints to get Spend
#### Getting Spend Reports - To Charge Other Teams, Customers
Use the `/global/spend/report` endpoint to get daily spend report per
- team
- customer [this is `user` passed to `/chat/completions` request](#how-to-track-spend-with-litellm)
</TabItem>
</Tabs>
## ✨ (Enterprise) API Endpoints to get Spend
#### Getting Spend Reports - To Charge Other Teams, Customers
Use the `/global/spend/report` endpoint to get daily spend report per
- Team
- Customer [this is `user` passed to `/chat/completions` request](#how-to-track-spend-with-litellm)
- [LiteLLM API key](virtual_keys.md)
<Tabs>
@ -337,6 +339,61 @@ curl -X GET 'http://localhost:4000/global/spend/report?start_date=2024-04-01&end
```
</TabItem>
<TabItem value="per key" label="Spend Per API Key">
👉 Key Change: Specify `group_by=api_key`
```shell
curl -X GET 'http://localhost:4000/global/spend/report?start_date=2024-04-01&end_date=2024-06-30&group_by=api_key' \
-H 'Authorization: Bearer sk-1234'
```
##### Example Response
```shell
[
{
"api_key": "ad64768847d05d978d62f623d872bff0f9616cc14b9c1e651c84d14fe3b9f539",
"total_cost": 0.0002157,
"total_input_tokens": 45.0,
"total_output_tokens": 1375.0,
"model_details": [
{
"model": "gpt-3.5-turbo",
"total_cost": 0.0001095,
"total_input_tokens": 9,
"total_output_tokens": 70
},
{
"model": "llama3-8b-8192",
"total_cost": 0.0001062,
"total_input_tokens": 36,
"total_output_tokens": 1305
}
]
},
{
"api_key": "88dc28d0f030c55ed4ab77ed8faf098196cb1c05df778539800c9f1243fe6b4b",
"total_cost": 0.00012924,
"total_input_tokens": 36.0,
"total_output_tokens": 1593.0,
"model_details": [
{
"model": "llama3-8b-8192",
"total_cost": 0.00012924,
"total_input_tokens": 36,
"total_output_tokens": 1593
}
]
}
]
```
</TabItem>
</Tabs>

View file

@ -89,3 +89,30 @@ Expected Output:
```bash
# no info statements
```
## Common Errors
1. "No available deployments..."
```
No deployments available for selected model, Try again in 60 seconds. Passed model=claude-3-5-sonnet. pre-call-checks=False, allowed_model_region=n/a.
```
This can happen when all your models hit rate limit errors, causing the cooldown to kick in.
How to control this?
- Adjust the cooldown time
```yaml
router_settings:
cooldown_time: 0 # 👈 KEY CHANGE
```
- Disable Cooldowns [NOT RECOMMENDED]
```yaml
router_settings:
disable_cooldowns: True
```
This is not recommended, as it will lead to requests being routed to deployments over their tpm/rpm limit.
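
If you're using the Python SDK directly, the same knobs exist on the Router; a minimal sketch, assuming `cooldown_time` / `disable_cooldowns` are supported by your litellm version:

```python
from litellm import Router

model_list = [
    {
        "model_name": "claude-3-5-sonnet",
        "litellm_params": {"model": "claude-3-5-sonnet-20240620"},
    }
]

# Shorten (or zero out) the cooldown window instead of disabling cooldowns entirely.
router = Router(model_list=model_list, cooldown_time=0)

# Not recommended: router = Router(model_list=model_list, disable_cooldowns=True)
```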

View file

@ -6,21 +6,34 @@ import TabItem from '@theme/TabItem';
:::tip
Get in touch with us [here](https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat)
To get a license, get in touch with us [here](https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat)
:::
Features:
- ✅ [SSO for Admin UI](./ui.md#✨-enterprise-features)
- ✅ [Audit Logs](#audit-logs)
- ✅ [Tracking Spend for Custom Tags](#tracking-spend-for-custom-tags)
- ✅ [Control available public, private routes](#control-available-public-private-routes)
- ✅ [Content Moderation with LLM Guard, LlamaGuard, Secret Detection, Google Text Moderations](#content-moderation)
- ✅ [Prompt Injection Detection (with LakeraAI API)](#prompt-injection-detection---lakeraai)
- ✅ [Custom Branding + Routes on Swagger Docs](#swagger-docs---custom-routes--branding)
- ✅ [Enforce Required Params for LLM Requests (ex. Reject requests missing ["metadata"]["generation_name"])](#enforce-required-params-for-llm-requests)
- ✅ Reject calls from Blocked User list
- ✅ Reject calls (incoming / outgoing) with Banned Keywords (e.g. competitors)
- **Security**
- ✅ [SSO for Admin UI](./ui.md#✨-enterprise-features)
- ✅ [Audit Logs with retention policy](#audit-logs)
- ✅ [JWT-Auth](../docs/proxy/token_auth.md)
- ✅ [Control available public, private routes](#control-available-public-private-routes)
- ✅ [[BETA] AWS Key Manager v2 - Key Decryption](#beta-aws-key-manager---key-decryption)
- ✅ [Use LiteLLM keys/authentication on Pass Through Endpoints](pass_through#✨-enterprise---use-litellm-keysauthentication-on-pass-through-endpoints)
- ✅ [Enforce Required Params for LLM Requests (ex. Reject requests missing ["metadata"]["generation_name"])](#enforce-required-params-for-llm-requests)
- **Spend Tracking**
- ✅ [Tracking Spend for Custom Tags](#tracking-spend-for-custom-tags)
- ✅ [API Endpoints to get Spend Reports per Team, API Key, Customer](cost_tracking.md#✨-enterprise-api-endpoints-to-get-spend)
- **Advanced Metrics**
- ✅ [`x-ratelimit-remaining-requests`, `x-ratelimit-remaining-tokens` for LLM APIs on Prometheus](prometheus#✨-enterprise-llm-remaining-requests-and-remaining-tokens)
- **Guardrails, PII Masking, Content Moderation**
- ✅ [Content Moderation with LLM Guard, LlamaGuard, Secret Detection, Google Text Moderations](#content-moderation)
- ✅ [Prompt Injection Detection (with LakeraAI API)](#prompt-injection-detection---lakeraai)
- ✅ Reject calls from Blocked User list
- ✅ Reject calls (incoming / outgoing) with Banned Keywords (e.g. competitors)
- **Custom Branding**
- ✅ [Custom Branding + Routes on Swagger Docs](#swagger-docs---custom-routes--branding)
- ✅ [Public Model Hub](../docs/proxy/enterprise.md#public-model-hub)
- ✅ [Custom Email Branding](../docs/proxy/email.md#customizing-email-branding)
## Audit Logs
@ -1020,3 +1033,34 @@ curl --location 'http://0.0.0.0:4000/chat/completions' \
Share a public page of available models for users
<Image img={require('../../img/model_hub.png')} style={{ width: '900px', height: 'auto' }}/>
## [BETA] AWS Key Manager - Key Decryption
This is a beta feature, and subject to changes.
**Step 1.** Add `USE_AWS_KMS` to env
```env
USE_AWS_KMS="True"
```
**Step 2.** Add `aws_kms/` to encrypted keys in env
```env
DATABASE_URL="aws_kms/AQICAH.."
```
**Step 3.** Start proxy
```
$ litellm
```
How it works:
- Key Decryption runs before server starts up. [**Code**](https://github.com/BerriAI/litellm/blob/8571cb45e80cc561dc34bc6aa89611eb96b9fe3e/litellm/proxy/proxy_cli.py#L445)
- It adds the decrypted value to the `os.environ` for the python process.
**Note:** Setting an environment variable within a Python script using os.environ will not make that variable accessible via SSH sessions or any other new processes that are started independently of the Python script. Environment variables set this way only affect the current process and its child processes.
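
To make the note concrete, a small illustration of that behavior (not LiteLLM-specific):

```python
import os
import subprocess

os.environ["MY_DECRYPTED_SECRET"] = "value"  # visible to this process...

# ...and to child processes it spawns:
subprocess.run(
    ["python3", "-c", "import os; print(os.environ.get('MY_DECRYPTED_SECRET'))"]
)

# ...but NOT to a separate shell / SSH session started outside this process tree.
```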

View file

@ -1188,6 +1188,7 @@ litellm_settings:
s3_region_name: us-west-2 # AWS Region Name for S3
s3_aws_access_key_id: os.environ/AWS_ACCESS_KEY_ID # us os.environ/<variable name> to pass environment variables. This is AWS Access Key ID for S3
s3_aws_secret_access_key: os.environ/AWS_SECRET_ACCESS_KEY # AWS Secret Access Key for S3
s3_path: my-test-path # [OPTIONAL] set path in bucket you want to write logs to
s3_endpoint_url: https://s3.amazonaws.com # [OPTIONAL] S3 endpoint URL, if you want to use Backblaze/cloudflare s3 buckets
```

View file

@ -0,0 +1,220 @@
import Image from '@theme/IdealImage';
# ➡️ Create Pass Through Endpoints
Add pass through routes to LiteLLM Proxy
**Example:** Add a route `/v1/rerank` that forwards requests to `https://api.cohere.com/v1/rerank` through LiteLLM Proxy
💡 This allows making the following Request to LiteLLM Proxy
```shell
curl --request POST \
--url http://localhost:4000/v1/rerank \
--header 'accept: application/json' \
--header 'content-type: application/json' \
--data '{
"model": "rerank-english-v3.0",
"query": "What is the capital of the United States?",
"top_n": 3,
"documents": ["Carson City is the capital city of the American state of Nevada."]
}'
```
## Tutorial - Pass through Cohere Re-Rank Endpoint
**Step 1** Define pass through routes on [litellm config.yaml](configs.md)
```yaml
general_settings:
master_key: sk-1234
pass_through_endpoints:
- path: "/v1/rerank" # route you want to add to LiteLLM Proxy Server
target: "https://api.cohere.com/v1/rerank" # URL this route should forward requests to
headers: # headers to forward to this URL
Authorization: "bearer os.environ/COHERE_API_KEY" # (Optional) Auth Header to forward to your Endpoint
content-type: application/json # (Optional) Extra Headers to pass to this endpoint
accept: application/json
```
**Step 2** Start Proxy Server in detailed_debug mode
```shell
litellm --config config.yaml --detailed_debug
```
**Step 3** Make Request to pass through endpoint
Here `http://localhost:4000` is your litellm proxy endpoint
```shell
curl --request POST \
--url http://localhost:4000/v1/rerank \
--header 'accept: application/json' \
--header 'content-type: application/json' \
--data '{
"model": "rerank-english-v3.0",
"query": "What is the capital of the United States?",
"top_n": 3,
"documents": ["Carson City is the capital city of the American state of Nevada.",
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.",
"Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district.",
"Capitalization or capitalisation in English grammar is the use of a capital letter at the start of a word. English usage varies from capitalization in other languages.",
"Capital punishment (the death penalty) has existed in the United States since beforethe United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states."]
}'
```
🎉 **Expected Response**
This request got forwarded from LiteLLM Proxy -> Defined Target URL (with headers)
```shell
{
"id": "37103a5b-8cfb-48d3-87c7-da288bedd429",
"results": [
{
"index": 2,
"relevance_score": 0.999071
},
{
"index": 4,
"relevance_score": 0.7867867
},
{
"index": 0,
"relevance_score": 0.32713068
}
],
"meta": {
"api_version": {
"version": "1"
},
"billed_units": {
"search_units": 1
}
}
}
```
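The same request from Python, going through the proxy's pass-through route (a minimal sketch; the Cohere API key stays server-side in the proxy config):

```python
import requests

# Hit the pass-through route on the LiteLLM proxy; the proxy forwards the
# request (plus the configured Authorization header) to api.cohere.com.
resp = requests.post(
    "http://localhost:4000/v1/rerank",
    headers={"accept": "application/json", "content-type": "application/json"},
    json={
        "model": "rerank-english-v3.0",
        "query": "What is the capital of the United States?",
        "top_n": 3,
        "documents": ["Carson City is the capital city of the American state of Nevada."],
    },
)
print(resp.json())
```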
## Tutorial - Pass Through Langfuse Requests
**Step 1** Define pass through routes on [litellm config.yaml](configs.md)
```yaml
general_settings:
master_key: sk-1234
pass_through_endpoints:
- path: "/api/public/ingestion" # route you want to add to LiteLLM Proxy Server
target: "https://us.cloud.langfuse.com/api/public/ingestion" # URL this route should forward
headers:
LANGFUSE_PUBLIC_KEY: "os.environ/LANGFUSE_DEV_PUBLIC_KEY" # your langfuse account public key
LANGFUSE_SECRET_KEY: "os.environ/LANGFUSE_DEV_SK_KEY" # your langfuse account secret key
```
**Step 2** Start Proxy Server in detailed_debug mode
```shell
litellm --config config.yaml --detailed_debug
```
**Step 3** Make Request to pass through endpoint
Run this code to make a sample trace
```python
from langfuse import Langfuse
langfuse = Langfuse(
host="http://localhost:4000", # your litellm proxy endpoint
public_key="anything", # no key required since this is a pass through
secret_key="anything", # no key required since this is a pass through
)
print("sending langfuse trace request")
trace = langfuse.trace(name="test-trace-litellm-proxy-passthrough")
print("flushing langfuse request")
langfuse.flush()
print("flushed langfuse request")
```
🎉 **Expected Response**
On success
Expect to see the following Trace Generated on your Langfuse Dashboard
<Image img={require('../../img/proxy_langfuse.png')} />
You will see the following endpoint called on your litellm proxy server logs
```shell
POST /api/public/ingestion HTTP/1.1" 207 Multi-Status
```
## ✨ [Enterprise] - Use LiteLLM keys/authentication on Pass Through Endpoints
Use this if you want the pass through endpoint to honour LiteLLM keys/authentication
Usage - set `auth: true` on the config
```yaml
general_settings:
master_key: sk-1234
pass_through_endpoints:
- path: "/v1/rerank"
target: "https://api.cohere.com/v1/rerank"
auth: true # 👈 Key change to use LiteLLM Auth / Keys
headers:
Authorization: "bearer os.environ/COHERE_API_KEY"
content-type: application/json
accept: application/json
```
Test Request with LiteLLM Key
```shell
curl --request POST \
--url http://localhost:4000/v1/rerank \
--header 'accept: application/json' \
--header 'Authorization: Bearer sk-1234'\
--header 'content-type: application/json' \
--data '{
"model": "rerank-english-v3.0",
"query": "What is the capital of the United States?",
"top_n": 3,
"documents": ["Carson City is the capital city of the American state of Nevada.",
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.",
"Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district.",
"Capitalization or capitalisation in English grammar is the use of a capital letter at the start of a word. English usage varies from capitalization in other languages.",
"Capital punishment (the death penalty) has existed in the United States since beforethe United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states."]
}'
```
## `pass_through_endpoints` Spec on config.yaml
All possible values for `pass_through_endpoints` and what they mean
**Example config**
```yaml
general_settings:
pass_through_endpoints:
- path: "/v1/rerank" # route you want to add to LiteLLM Proxy Server
target: "https://api.cohere.com/v1/rerank" # URL this route should forward requests to
headers: # headers to forward to this URL
Authorization: "bearer os.environ/COHERE_API_KEY" # (Optional) Auth Header to forward to your Endpoint
content-type: application/json # (Optional) Extra Headers to pass to this endpoint
accept: application/json
```
**Spec**
* `pass_through_endpoints` *list*: A collection of endpoint configurations for request forwarding.
* `path` *string*: The route to be added to the LiteLLM Proxy Server.
* `target` *string*: The URL to which requests for this path should be forwarded.
* `headers` *object*: Key-value pairs of headers to be forwarded with the request. You can set any key value pair here and it will be forwarded to your target endpoint
* `Authorization` *string*: The authentication header for the target API.
* `content-type` *string*: The format specification for the request body.
* `accept` *string*: The expected response format from the server.
* `LANGFUSE_PUBLIC_KEY` *string*: Your Langfuse account public key - only set this when forwarding to Langfuse.
* `LANGFUSE_SECRET_KEY` *string*: Your Langfuse account secret key - only set this when forwarding to Langfuse.
* `<your-custom-header>` *string*: Pass any custom header key/value pair

View file

@ -1,3 +1,6 @@
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
# 📈 Prometheus metrics [BETA]
LiteLLM Exposes a `/metrics` endpoint for Prometheus to Poll
@ -61,6 +64,56 @@ http://localhost:4000/metrics
| `litellm_remaining_api_key_budget_metric` | Remaining Budget for API Key (A key Created on LiteLLM)|
### ✨ (Enterprise) LLM Remaining Requests and Remaining Tokens
Set this on your config.yaml to allow you to track how close you are to hitting your TPM / RPM limits on each model group
```yaml
litellm_settings:
success_callback: ["prometheus"]
failure_callback: ["prometheus"]
return_response_headers: true # ensures the LLM API calls track the response headers
```
| Metric Name | Description |
|----------------------|--------------------------------------|
| `litellm_remaining_requests_metric` | Track `x-ratelimit-remaining-requests` returned from LLM API Deployment |
| `litellm_remaining_tokens` | Track `x-ratelimit-remaining-tokens` returned from LLM API Deployment |
Example Metric
<Tabs>
<TabItem value="Remaining Requests" label="Remaining Requests">
```shell
litellm_remaining_requests
{
api_base="https://api.openai.com/v1",
api_provider="openai",
litellm_model_name="gpt-3.5-turbo",
model_group="gpt-3.5-turbo"
}
8998.0
```
</TabItem>
<TabItem value="Requests" label="Remaining Tokens">
```shell
litellm_remaining_tokens
{
api_base="https://api.openai.com/v1",
api_provider="openai",
litellm_model_name="gpt-3.5-turbo",
model_group="gpt-3.5-turbo"
}
999981.0
```
</TabItem>
</Tabs>
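A minimal sketch of pulling these gauges directly from the proxy's `/metrics` endpoint (outside of Prometheus):

```python
import requests

# Fetch the raw Prometheus exposition text from the proxy and filter the
# remaining-requests / remaining-tokens gauges.
metrics = requests.get("http://localhost:4000/metrics").text
for line in metrics.splitlines():
    if line.startswith(("litellm_remaining_requests", "litellm_remaining_tokens")):
        print(line)
```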
## Monitor System Health
To monitor the health of litellm adjacent services (redis / postgres), do:

View file

@ -815,6 +815,35 @@ model_list:
</TabItem>
</Tabs>
**Expected Response**
```
No deployments available for selected model, Try again in 60 seconds. Passed model=claude-3-5-sonnet. pre-call-checks=False, allowed_model_region=n/a.
```
#### **Disable cooldowns**
<Tabs>
<TabItem value="sdk" label="SDK">
```python
from litellm import Router
router = Router(..., disable_cooldowns=True)
```
</TabItem>
<TabItem value="proxy" label="PROXY">
```yaml
router_settings:
disable_cooldowns: True
```
</TabItem>
</Tabs>
### Retries
For both async + sync functions, we support retrying failed requests.

View file

@ -8,7 +8,13 @@ LiteLLM supports reading secrets from Azure Key Vault and Infisical
- [Infisical Secret Manager](#infisical-secret-manager)
- [.env Files](#env-files)
## AWS Key Management Service
## AWS Key Management V1
:::tip
[BETA] AWS Key Management v2 is on the enterprise tier. Go [here for docs](./proxy/enterprise.md#beta-aws-key-manager---key-decryption)
:::
Use AWS KMS to store a hashed copy of your Proxy Master Key in the environment.

Binary file not shown.

After

Width:  |  Height:  |  Size: 212 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 142 KiB

After

Width:  |  Height:  |  Size: 206 KiB


View file

@ -48,6 +48,7 @@ const sidebars = {
"proxy/billing",
"proxy/user_keys",
"proxy/virtual_keys",
"proxy/token_auth",
"proxy/alerting",
{
type: "category",
@ -56,11 +57,11 @@ const sidebars = {
},
"proxy/ui",
"proxy/prometheus",
"proxy/pass_through",
"proxy/email",
"proxy/multiple_admins",
"proxy/team_based_routing",
"proxy/customer_routing",
"proxy/token_auth",
{
type: "category",
label: "Extra Load Balancing",

View file

@ -114,7 +114,11 @@ class _ENTERPRISE_lakeraAI_Moderation(CustomLogger):
if flagged == True:
raise HTTPException(
status_code=400, detail={"error": "Violated content safety policy"}
status_code=400,
detail={
"error": "Violated content safety policy",
"lakera_ai_response": _json_response,
},
)
pass

View file

@ -1,48 +1,13 @@
#!/bin/sh
#!/bin/bash
echo $(pwd)
# Check if DATABASE_URL is not set
if [ -z "$DATABASE_URL" ]; then
# Check if all required variables are provided
if [ -n "$DATABASE_HOST" ] && [ -n "$DATABASE_USERNAME" ] && [ -n "$DATABASE_PASSWORD" ] && [ -n "$DATABASE_NAME" ]; then
# Construct DATABASE_URL from the provided variables
DATABASE_URL="postgresql://${DATABASE_USERNAME}:${DATABASE_PASSWORD}@${DATABASE_HOST}/${DATABASE_NAME}"
export DATABASE_URL
else
echo "Error: Required database environment variables are not set. Provide a postgres url for DATABASE_URL."
exit 1
fi
fi
# Run the Python migration script
python3 litellm/proxy/prisma_migration.py
# Set DIRECT_URL to the value of DATABASE_URL if it is not set, required for migrations
if [ -z "$DIRECT_URL" ]; then
export DIRECT_URL=$DATABASE_URL
fi
# Apply migrations
retry_count=0
max_retries=3
exit_code=1
until [ $retry_count -ge $max_retries ] || [ $exit_code -eq 0 ]
do
retry_count=$((retry_count+1))
echo "Attempt $retry_count..."
# Run the Prisma db push command
prisma db push --accept-data-loss
exit_code=$?
if [ $exit_code -ne 0 ] && [ $retry_count -lt $max_retries ]; then
echo "Retrying in 10 seconds..."
sleep 10
fi
done
if [ $exit_code -ne 0 ]; then
echo "Unable to push database changes after $max_retries retries."
# Check if the Python script executed successfully
if [ $? -eq 0 ]; then
echo "Migration script ran successfully!"
else
echo "Migration script failed!"
exit 1
fi
echo "Database push successful!"

View file

@ -125,6 +125,9 @@ llm_guard_mode: Literal["all", "key-specific", "request-specific"] = "all"
##################
### PREVIEW FEATURES ###
enable_preview_features: bool = False
return_response_headers: bool = (
False # get response headers from LLM Api providers - example x-remaining-requests,
)
##################
logging: bool = True
caching: bool = (
@ -749,6 +752,7 @@ from .utils import (
create_pretrained_tokenizer,
create_tokenizer,
supports_function_calling,
supports_response_schema,
supports_parallel_function_calling,
supports_vision,
supports_system_messages,
@ -799,7 +803,11 @@ from .llms.sagemaker import SagemakerConfig
from .llms.ollama import OllamaConfig
from .llms.ollama_chat import OllamaChatConfig
from .llms.maritalk import MaritTalkConfig
from .llms.bedrock_httpx import AmazonCohereChatConfig, AmazonConverseConfig
from .llms.bedrock_httpx import (
AmazonCohereChatConfig,
AmazonConverseConfig,
BEDROCK_CONVERSE_MODELS,
)
from .llms.bedrock import (
AmazonTitanConfig,
AmazonAI21Config,
@ -848,6 +856,7 @@ from .exceptions import (
APIResponseValidationError,
UnprocessableEntityError,
InternalServerError,
JSONSchemaValidationError,
LITELLM_EXCEPTION_TYPES,
)
from .budget_manager import BudgetManager

View file

@ -1,6 +1,7 @@
# What is this?
## File for 'response_cost' calculation in Logging
import time
import traceback
from typing import List, Literal, Optional, Tuple, Union
import litellm
@ -668,3 +669,10 @@ def response_cost_calculator(
f"Model={model} for LLM Provider={custom_llm_provider} not found in completion cost map."
)
return None
except Exception as e:
verbose_logger.error(
"litellm.cost_calculator.py::response_cost_calculator - Exception occurred - {}/n{}".format(
str(e), traceback.format_exc()
)
)
return None

View file

@ -551,7 +551,7 @@ class APIError(openai.APIError): # type: ignore
message,
llm_provider,
model,
request: httpx.Request,
request: Optional[httpx.Request] = None,
litellm_debug_info: Optional[str] = None,
max_retries: Optional[int] = None,
num_retries: Optional[int] = None,
@ -563,6 +563,8 @@ class APIError(openai.APIError): # type: ignore
self.litellm_debug_info = litellm_debug_info
self.max_retries = max_retries
self.num_retries = num_retries
if request is None:
request = httpx.Request(method="POST", url="https://api.openai.com/v1")
super().__init__(self.message, request=request, body=None) # type: ignore
def __str__(self):
@ -664,6 +666,22 @@ class OpenAIError(openai.OpenAIError): # type: ignore
self.llm_provider = "openai"
class JSONSchemaValidationError(APIError):
def __init__(
self, model: str, llm_provider: str, raw_response: str, schema: str
) -> None:
self.raw_response = raw_response
self.schema = schema
self.model = model
message = "litellm.JSONSchemaValidationError: model={}, returned an invalid response={}, for schema={}.\nAccess raw response with `e.raw_response`".format(
model, raw_response, schema
)
self.message = message
super().__init__(
model=model, message=message, llm_provider=llm_provider, status_code=500
)
LITELLM_EXCEPTION_TYPES = [
AuthenticationError,
NotFoundError,
@ -682,6 +700,7 @@ LITELLM_EXCEPTION_TYPES = [
APIResponseValidationError,
OpenAIError,
InternalServerError,
JSONSchemaValidationError,
]

View file

@ -311,11 +311,6 @@ class LangFuseLogger:
try:
tags = []
try:
metadata = copy.deepcopy(
metadata
) # Avoid modifying the original metadata
except:
new_metadata = {}
for key, value in metadata.items():
if (

View file

@ -2,14 +2,20 @@
#### What this does ####
# On success, log events to Prometheus
import dotenv, os
import requests # type: ignore
import datetime
import os
import subprocess
import sys
import traceback
import datetime, subprocess, sys
import litellm, uuid
from litellm._logging import print_verbose, verbose_logger
import uuid
from typing import Optional, Union
import dotenv
import requests # type: ignore
import litellm
from litellm._logging import print_verbose, verbose_logger
class PrometheusLogger:
# Class variables or attributes
@ -20,6 +26,8 @@ class PrometheusLogger:
try:
from prometheus_client import Counter, Gauge
from litellm.proxy.proxy_server import premium_user
self.litellm_llm_api_failed_requests_metric = Counter(
name="litellm_llm_api_failed_requests_metric",
documentation="Total number of failed LLM API calls via litellm",
@ -88,6 +96,31 @@ class PrometheusLogger:
labelnames=["hashed_api_key", "api_key_alias"],
)
# Litellm-Enterprise Metrics
if premium_user is True:
# Remaining Rate Limit for model
self.litellm_remaining_requests_metric = Gauge(
"litellm_remaining_requests",
"remaining requests for model, returned from LLM API Provider",
labelnames=[
"model_group",
"api_provider",
"api_base",
"litellm_model_name",
],
)
self.litellm_remaining_tokens_metric = Gauge(
"litellm_remaining_tokens",
"remaining tokens for model, returned from LLM API Provider",
labelnames=[
"model_group",
"api_provider",
"api_base",
"litellm_model_name",
],
)
except Exception as e:
print_verbose(f"Got exception on init prometheus client {str(e)}")
raise e
@ -104,6 +137,8 @@ class PrometheusLogger:
):
try:
# Define prometheus client
from litellm.proxy.proxy_server import premium_user
verbose_logger.debug(
f"prometheus Logging - Enters logging function for model {kwargs}"
)
@ -199,6 +234,10 @@ class PrometheusLogger:
user_api_key, user_api_key_alias
).set(_remaining_api_key_budget)
# set x-ratelimit headers
if premium_user is True:
self.set_remaining_tokens_requests_metric(kwargs)
### FAILURE INCREMENT ###
if "exception" in kwargs:
self.litellm_llm_api_failed_requests_metric.labels(
@ -216,6 +255,58 @@ class PrometheusLogger:
verbose_logger.debug(traceback.format_exc())
pass
def set_remaining_tokens_requests_metric(self, request_kwargs: dict):
try:
verbose_logger.debug("setting remaining tokens requests metric")
_response_headers = request_kwargs.get("response_headers")
_litellm_params = request_kwargs.get("litellm_params", {}) or {}
_metadata = _litellm_params.get("metadata", {})
litellm_model_name = request_kwargs.get("model", None)
model_group = _metadata.get("model_group", None)
api_base = _metadata.get("api_base", None)
llm_provider = _litellm_params.get("custom_llm_provider", None)
remaining_requests = None
remaining_tokens = None
# OpenAI / OpenAI Compatible headers
if (
_response_headers
and "x-ratelimit-remaining-requests" in _response_headers
):
remaining_requests = _response_headers["x-ratelimit-remaining-requests"]
if (
_response_headers
and "x-ratelimit-remaining-tokens" in _response_headers
):
remaining_tokens = _response_headers["x-ratelimit-remaining-tokens"]
verbose_logger.debug(
f"remaining requests: {remaining_requests}, remaining tokens: {remaining_tokens}"
)
if remaining_requests:
"""
"model_group",
"api_provider",
"api_base",
"litellm_model_name"
"""
self.litellm_remaining_requests_metric.labels(
model_group, llm_provider, api_base, litellm_model_name
).set(remaining_requests)
if remaining_tokens:
self.litellm_remaining_tokens_metric.labels(
model_group, llm_provider, api_base, litellm_model_name
).set(remaining_tokens)
except Exception as e:
verbose_logger.error(
"Prometheus Error: set_remaining_tokens_requests_metric. Exception occured - {}".format(
str(e)
)
)
return
def safe_get_remaining_budget(
max_budget: Optional[float], spend: Optional[float]

View file

@ -1,10 +1,14 @@
#### What this does ####
# On success + failure, log events to Supabase
import datetime
import os
import subprocess
import sys
import traceback
import datetime, subprocess, sys
import litellm, uuid
import uuid
import litellm
from litellm._logging import print_verbose, verbose_logger
@ -54,6 +58,7 @@ class S3Logger:
"s3_aws_session_token"
)
s3_config = litellm.s3_callback_params.get("s3_config")
s3_path = litellm.s3_callback_params.get("s3_path")
# done reading litellm.s3_callback_params
self.bucket_name = s3_bucket_name

View file

@ -26,7 +26,7 @@ def map_finish_reason(
finish_reason == "FINISH_REASON_UNSPECIFIED" or finish_reason == "STOP"
): # vertex ai - got from running `print(dir(response_obj.candidates[0].finish_reason))`: ['FINISH_REASON_UNSPECIFIED', 'MAX_TOKENS', 'OTHER', 'RECITATION', 'SAFETY', 'STOP',]
return "stop"
elif finish_reason == "SAFETY": # vertex ai
elif finish_reason == "SAFETY" or finish_reason == "RECITATION": # vertex ai
return "content_filter"
elif finish_reason == "STOP": # vertex ai
return "stop"

View file

@ -0,0 +1,40 @@
import json
from typing import Optional
def get_error_message(error_obj) -> Optional[str]:
"""
OpenAI returns an error message that is nested; this extracts the message
Example:
{
'request': "<Request('POST', 'https://api.openai.com/v1/chat/completions')>",
'message': "Error code: 400 - {\'error\': {\'message\': \"Invalid 'temperature': decimal above maximum value. Expected a value <= 2, but got 200 instead.\", 'type': 'invalid_request_error', 'param': 'temperature', 'code': 'decimal_above_max_value'}}",
'body': {
'message': "Invalid 'temperature': decimal above maximum value. Expected a value <= 2, but got 200 instead.",
'type': 'invalid_request_error',
'param': 'temperature',
'code': 'decimal_above_max_value'
},
'code': 'decimal_above_max_value',
'param': 'temperature',
'type': 'invalid_request_error',
'response': "<Response [400 Bad Request]>",
'status_code': 400,
'request_id': 'req_f287898caa6364cd42bc01355f74dd2a'
}
"""
try:
# First, try to access the message directly from the 'body' key
if error_obj is None:
return None
if hasattr(error_obj, "body"):
_error_obj_body = getattr(error_obj, "body")
if isinstance(_error_obj_body, dict):
return _error_obj_body.get("message")
# If all else fails, return None
return None
except Exception as e:
return None
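
A quick sketch of how this helper behaves, using a stand-in object with a `body` attribute (real callers pass OpenAI exception objects):

```python
from types import SimpleNamespace

# Stand-in for an OpenAI exception that carries a parsed `body` dict.
fake_error = SimpleNamespace(
    body={
        "message": "Invalid 'temperature': decimal above maximum value.",
        "type": "invalid_request_error",
    }
)

print(get_error_message(fake_error))  # -> "Invalid 'temperature': decimal above maximum value."
print(get_error_message(None))        # -> None
```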

View file

@ -0,0 +1,23 @@
import json
def validate_schema(schema: dict, response: str):
"""
Validate if the returned json response follows the schema.
Params:
- schema - dict: JSON schema
- response - str: Received json response as string.
"""
from jsonschema import ValidationError, validate
from litellm import JSONSchemaValidationError
response_dict = json.loads(response)
try:
validate(response_dict, schema=schema)
except ValidationError:
raise JSONSchemaValidationError(
model="", llm_provider="", raw_response=response, schema=json.dumps(schema)
)
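
A minimal usage sketch for this helper (values are illustrative):

```python
schema = {
    "type": "array",
    "items": {
        "type": "object",
        "properties": {"recipe_name": {"type": "string"}},
        "required": ["recipe_name"],
    },
}

# Passes silently when the response matches the schema...
validate_schema(schema=schema, response='[{"recipe_name": "Chocolate chip"}]')

# ...and raises litellm.JSONSchemaValidationError when it does not.
validate_schema(schema=schema, response='[{"name": "Chocolate chip"}]')
```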

View file

@ -1,23 +1,28 @@
import os, types
import copy
import json
from enum import Enum
import requests, copy # type: ignore
import os
import time
import types
from enum import Enum
from functools import partial
from typing import Callable, Optional, List, Union
import litellm.litellm_core_utils
from litellm.utils import ModelResponse, Usage, CustomStreamWrapper
from litellm.litellm_core_utils.core_helpers import map_finish_reason
from typing import Callable, List, Optional, Union
import httpx # type: ignore
import requests # type: ignore
import litellm
from .prompt_templates.factory import prompt_factory, custom_prompt
import litellm.litellm_core_utils
from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
_get_async_httpx_client,
_get_httpx_client,
)
from .base import BaseLLM
import httpx # type: ignore
from litellm.types.llms.anthropic import AnthropicMessagesToolChoice
from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
from .base import BaseLLM
from .prompt_templates.factory import custom_prompt, prompt_factory
class AnthropicConstants(Enum):
@ -179,10 +184,19 @@ async def make_call(
if client is None:
client = _get_async_httpx_client() # Create a new client if none provided
try:
response = await client.post(api_base, headers=headers, data=data, stream=True)
except httpx.HTTPStatusError as e:
raise AnthropicError(
status_code=e.response.status_code, message=await e.response.aread()
)
except Exception as e:
raise AnthropicError(status_code=500, message=str(e))
if response.status_code != 200:
raise AnthropicError(status_code=response.status_code, message=response.text)
raise AnthropicError(
status_code=response.status_code, message=await response.aread()
)
completion_stream = response.aiter_lines()

View file

@ -23,6 +23,7 @@ from typing_extensions import overload
import litellm
from litellm import OpenAIConfig
from litellm.caching import DualCache
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.utils import (
Choices,
CustomStreamWrapper,
@ -458,6 +459,36 @@ class AzureChatCompletion(BaseLLM):
return azure_client
async def make_azure_openai_chat_completion_request(
self,
azure_client: AsyncAzureOpenAI,
data: dict,
timeout: Union[float, httpx.Timeout],
):
"""
Helper to:
- call chat.completions.create.with_raw_response when litellm.return_response_headers is True
- call chat.completions.create by default
"""
try:
if litellm.return_response_headers is True:
raw_response = (
await azure_client.chat.completions.with_raw_response.create(
**data, timeout=timeout
)
)
headers = dict(raw_response.headers)
response = raw_response.parse()
return headers, response
else:
response = await azure_client.chat.completions.create(
**data, timeout=timeout
)
return None, response
except Exception as e:
raise e
def completion(
self,
model: str,
@ -470,7 +501,7 @@ class AzureChatCompletion(BaseLLM):
azure_ad_token: str,
print_verbose: Callable,
timeout: Union[float, httpx.Timeout],
logging_obj,
logging_obj: LiteLLMLoggingObj,
optional_params,
litellm_params,
logger_fn,
@ -649,9 +680,9 @@ class AzureChatCompletion(BaseLLM):
data: dict,
timeout: Any,
model_response: ModelResponse,
logging_obj: LiteLLMLoggingObj,
azure_ad_token: Optional[str] = None,
client=None, # this is the AsyncAzureOpenAI
logging_obj=None,
):
response = None
try:
@ -701,9 +732,13 @@ class AzureChatCompletion(BaseLLM):
"complete_input_dict": data,
},
)
response = await azure_client.chat.completions.create(
**data, timeout=timeout
headers, response = await self.make_azure_openai_chat_completion_request(
azure_client=azure_client,
data=data,
timeout=timeout,
)
logging_obj.model_call_details["response_headers"] = headers
stringified_response = response.model_dump()
logging_obj.post_call(
@ -717,11 +752,32 @@ class AzureChatCompletion(BaseLLM):
model_response_object=model_response,
)
except AzureOpenAIError as e:
## LOGGING
logging_obj.post_call(
input=data["messages"],
api_key=api_key,
additional_args={"complete_input_dict": data},
original_response=str(e),
)
exception_mapping_worked = True
raise e
except asyncio.CancelledError as e:
## LOGGING
logging_obj.post_call(
input=data["messages"],
api_key=api_key,
additional_args={"complete_input_dict": data},
original_response=str(e),
)
raise AzureOpenAIError(status_code=500, message=str(e))
except Exception as e:
## LOGGING
logging_obj.post_call(
input=data["messages"],
api_key=api_key,
additional_args={"complete_input_dict": data},
original_response=str(e),
)
if hasattr(e, "status_code"):
raise e
else:
@ -791,7 +847,7 @@ class AzureChatCompletion(BaseLLM):
async def async_streaming(
self,
logging_obj,
logging_obj: LiteLLMLoggingObj,
api_base: str,
api_key: str,
api_version: str,
@ -840,9 +896,14 @@ class AzureChatCompletion(BaseLLM):
"complete_input_dict": data,
},
)
response = await azure_client.chat.completions.create(
**data, timeout=timeout
headers, response = await self.make_azure_openai_chat_completion_request(
azure_client=azure_client,
data=data,
timeout=timeout,
)
logging_obj.model_call_details["response_headers"] = headers
# return response
streamwrapper = CustomStreamWrapper(
completion_stream=response,

View file

@ -60,6 +60,17 @@ from .prompt_templates.factory import (
prompt_factory,
)
BEDROCK_CONVERSE_MODELS = [
"anthropic.claude-3-5-sonnet-20240620-v1:0",
"anthropic.claude-3-opus-20240229-v1:0",
"anthropic.claude-3-sonnet-20240229-v1:0",
"anthropic.claude-3-haiku-20240307-v1:0",
"anthropic.claude-v2",
"anthropic.claude-v2:1",
"anthropic.claude-v1",
"anthropic.claude-instant-v1",
]
iam_cache = DualCache()
@ -305,6 +316,7 @@ class BedrockLLM(BaseLLM):
self,
aws_access_key_id: Optional[str] = None,
aws_secret_access_key: Optional[str] = None,
aws_session_token: Optional[str] = None,
aws_region_name: Optional[str] = None,
aws_session_name: Optional[str] = None,
aws_profile_name: Optional[str] = None,
@ -320,6 +332,7 @@ class BedrockLLM(BaseLLM):
params_to_check: List[Optional[str]] = [
aws_access_key_id,
aws_secret_access_key,
aws_session_token,
aws_region_name,
aws_session_name,
aws_profile_name,
@ -337,6 +350,7 @@ class BedrockLLM(BaseLLM):
(
aws_access_key_id,
aws_secret_access_key,
aws_session_token,
aws_region_name,
aws_session_name,
aws_profile_name,
@ -430,6 +444,19 @@ class BedrockLLM(BaseLLM):
client = boto3.Session(profile_name=aws_profile_name)
return client.get_credentials()
elif (
aws_access_key_id is not None
and aws_secret_access_key is not None
and aws_session_token is not None
): ### CHECK FOR AWS SESSION TOKEN ###
from botocore.credentials import Credentials
credentials = Credentials(
access_key=aws_access_key_id,
secret_key=aws_secret_access_key,
token=aws_session_token,
)
return credentials
else:
session = boto3.Session(
aws_access_key_id=aws_access_key_id,
@ -734,9 +761,10 @@ class BedrockLLM(BaseLLM):
provider = model.split(".")[0]
## CREDENTIALS ##
# pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
# pop aws_secret_access_key, aws_access_key_id, aws_session_token, aws_region_name from kwargs, since completion calls fail with them
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
aws_access_key_id = optional_params.pop("aws_access_key_id", None)
aws_session_token = optional_params.pop("aws_session_token", None)
aws_region_name = optional_params.pop("aws_region_name", None)
aws_role_name = optional_params.pop("aws_role_name", None)
aws_session_name = optional_params.pop("aws_session_name", None)
@ -768,6 +796,7 @@ class BedrockLLM(BaseLLM):
credentials: Credentials = self.get_credentials(
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
aws_session_token=aws_session_token,
aws_region_name=aws_region_name,
aws_session_name=aws_session_name,
aws_profile_name=aws_profile_name,
@ -1422,6 +1451,7 @@ class BedrockConverseLLM(BaseLLM):
self,
aws_access_key_id: Optional[str] = None,
aws_secret_access_key: Optional[str] = None,
aws_session_token: Optional[str] = None,
aws_region_name: Optional[str] = None,
aws_session_name: Optional[str] = None,
aws_profile_name: Optional[str] = None,
@ -1437,6 +1467,7 @@ class BedrockConverseLLM(BaseLLM):
params_to_check: List[Optional[str]] = [
aws_access_key_id,
aws_secret_access_key,
aws_session_token,
aws_region_name,
aws_session_name,
aws_profile_name,
@ -1454,6 +1485,7 @@ class BedrockConverseLLM(BaseLLM):
(
aws_access_key_id,
aws_secret_access_key,
aws_session_token,
aws_region_name,
aws_session_name,
aws_profile_name,
@ -1547,6 +1579,19 @@ class BedrockConverseLLM(BaseLLM):
client = boto3.Session(profile_name=aws_profile_name)
return client.get_credentials()
elif (
aws_access_key_id is not None
and aws_secret_access_key is not None
and aws_session_token is not None
): ### CHECK FOR AWS SESSION TOKEN ###
from botocore.credentials import Credentials
credentials = Credentials(
access_key=aws_access_key_id,
secret_key=aws_secret_access_key,
token=aws_session_token,
)
return credentials
else:
session = boto3.Session(
aws_access_key_id=aws_access_key_id,
@ -1682,6 +1727,7 @@ class BedrockConverseLLM(BaseLLM):
# pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
aws_access_key_id = optional_params.pop("aws_access_key_id", None)
aws_session_token = optional_params.pop("aws_session_token", None)
aws_region_name = optional_params.pop("aws_region_name", None)
aws_role_name = optional_params.pop("aws_role_name", None)
aws_session_name = optional_params.pop("aws_session_name", None)
@ -1713,6 +1759,7 @@ class BedrockConverseLLM(BaseLLM):
credentials: Credentials = self.get_credentials(
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
aws_session_token=aws_session_token,
aws_region_name=aws_region_name,
aws_session_name=aws_session_name,
aws_profile_name=aws_profile_name,

View file

@ -1,4 +1,8 @@
import time, json, httpx, asyncio
import asyncio
import json
import time
import httpx
class AsyncCustomHTTPTransport(httpx.AsyncHTTPTransport):
@ -7,15 +11,18 @@ class AsyncCustomHTTPTransport(httpx.AsyncHTTPTransport):
"""
async def handle_async_request(self, request: httpx.Request) -> httpx.Response:
if "images/generations" in request.url.path and request.url.params[
"api-version"
] in [ # dall-e-3 starts from `2023-12-01-preview` so we should be able to avoid conflict
_api_version = request.url.params.get("api-version", "")
if (
"images/generations" in request.url.path
and _api_version
in [ # dall-e-3 starts from `2023-12-01-preview` so we should be able to avoid conflict
"2023-06-01-preview",
"2023-07-01-preview",
"2023-08-01-preview",
"2023-09-01-preview",
"2023-10-01-preview",
]:
]
):
request.url = request.url.copy_with(
path="/openai/images/generations:submit"
)
@ -77,15 +84,18 @@ class CustomHTTPTransport(httpx.HTTPTransport):
self,
request: httpx.Request,
) -> httpx.Response:
if "images/generations" in request.url.path and request.url.params[
"api-version"
] in [ # dall-e-3 starts from `2023-12-01-preview` so we should be able to avoid conflict
_api_version = request.url.params.get("api-version", "")
if (
"images/generations" in request.url.path
and _api_version
in [ # dall-e-3 starts from `2023-12-01-preview` so we should be able to avoid conflict
"2023-06-01-preview",
"2023-07-01-preview",
"2023-08-01-preview",
"2023-09-01-preview",
"2023-10-01-preview",
]:
]
):
request.url = request.url.copy_with(
path="/openai/images/generations:submit"
)

View file

@ -1,6 +1,11 @@
import asyncio
import os
import traceback
from typing import Any, Mapping, Optional, Union
import httpx
import litellm
import httpx, asyncio, traceback, os
from typing import Optional, Union, Mapping, Any
# https://www.python-httpx.org/advanced/timeouts
_DEFAULT_TIMEOUT = httpx.Timeout(timeout=5.0, connect=5.0)
@ -93,7 +98,7 @@ class AsyncHTTPHandler:
response = await self.client.send(req, stream=stream)
response.raise_for_status()
return response
except httpx.RemoteProtocolError:
except (httpx.RemoteProtocolError, httpx.ConnectError):
# Retry the request with a new session if there is a connection error
new_client = self.create_client(timeout=self.timeout, concurrent_limit=1)
try:
@ -109,6 +114,11 @@ class AsyncHTTPHandler:
finally:
await new_client.aclose()
except httpx.HTTPStatusError as e:
setattr(e, "status_code", e.response.status_code)
if stream is True:
setattr(e, "message", await e.response.aread())
else:
setattr(e, "message", e.response.text)
raise e
except Exception as e:
raise e
@ -208,6 +218,7 @@ class HTTPHandler:
headers: Optional[dict] = None,
stream: bool = False,
):
req = self.client.build_request(
"POST", url, data=data, json=json, params=params, headers=headers # type: ignore
)

View file

@ -21,6 +21,7 @@ from pydantic import BaseModel
from typing_extensions import overload, override
import litellm
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.types.utils import ProviderField
from litellm.utils import (
Choices,
@ -652,6 +653,36 @@ class OpenAIChatCompletion(BaseLLM):
else:
return client
async def make_openai_chat_completion_request(
self,
openai_aclient: AsyncOpenAI,
data: dict,
timeout: Union[float, httpx.Timeout],
):
"""
Helper to:
- call chat.completions.create.with_raw_response when litellm.return_response_headers is True
- call chat.completions.create by default
"""
try:
if litellm.return_response_headers is True:
raw_response = (
await openai_aclient.chat.completions.with_raw_response.create(
**data, timeout=timeout
)
)
headers = dict(raw_response.headers)
response = raw_response.parse()
return headers, response
else:
response = await openai_aclient.chat.completions.create(
**data, timeout=timeout
)
return None, response
except Exception as e:
raise e
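
A minimal caller-side sketch of the new raw-response path, assuming the litellm.return_response_headers flag wired above and assuming the headers are surfaced on the response's _hidden_params (the exact accessor is an assumption, not confirmed by this diff):

    # Sketch: opt in to raw responses and read the provider headers back.
    import asyncio
    import litellm

    async def main():
        litellm.return_response_headers = True  # routes through with_raw_response.create(...)
        response = await litellm.acompletion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "hi"}],
        )
        # assumption: headers come from the hidden_params set in convert_to_model_response_object
        headers = (getattr(response, "_hidden_params", None) or {}).get("headers") or {}
        print(headers.get("x-ratelimit-remaining-requests"))

    asyncio.run(main())
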
def completion(
self,
model_response: ModelResponse,
@ -678,17 +709,17 @@ class OpenAIChatCompletion(BaseLLM):
if headers:
optional_params["extra_headers"] = headers
if model is None or messages is None:
raise OpenAIError(status_code=422, message=f"Missing model or messages")
raise OpenAIError(status_code=422, message="Missing model or messages")
if not isinstance(timeout, float) and not isinstance(
timeout, httpx.Timeout
):
raise OpenAIError(
status_code=422,
message=f"Timeout needs to be a float or httpx.Timeout",
message="Timeout needs to be a float or httpx.Timeout",
)
if custom_llm_provider != "openai":
if custom_llm_provider is not None and custom_llm_provider != "openai":
model_response.model = f"{custom_llm_provider}/{model}"
# process all OpenAI compatible provider logic here
if custom_llm_provider == "mistral":
@ -836,13 +867,13 @@ class OpenAIChatCompletion(BaseLLM):
self,
data: dict,
model_response: ModelResponse,
logging_obj: LiteLLMLoggingObj,
timeout: Union[float, httpx.Timeout],
api_key: Optional[str] = None,
api_base: Optional[str] = None,
organization: Optional[str] = None,
client=None,
max_retries=None,
logging_obj=None,
headers=None,
):
response = None
@ -869,8 +900,8 @@ class OpenAIChatCompletion(BaseLLM):
},
)
response = await openai_aclient.chat.completions.create(
**data, timeout=timeout
headers, response = await self.make_openai_chat_completion_request(
openai_aclient=openai_aclient, data=data, timeout=timeout
)
stringified_response = response.model_dump()
logging_obj.post_call(
@ -879,9 +910,11 @@ class OpenAIChatCompletion(BaseLLM):
original_response=stringified_response,
additional_args={"complete_input_dict": data},
)
logging_obj.model_call_details["response_headers"] = headers
return convert_to_model_response_object(
response_object=stringified_response,
model_response_object=model_response,
hidden_params={"headers": headers},
)
except Exception as e:
raise e
@ -931,10 +964,10 @@ class OpenAIChatCompletion(BaseLLM):
async def async_streaming(
self,
logging_obj,
timeout: Union[float, httpx.Timeout],
data: dict,
model: str,
logging_obj: LiteLLMLoggingObj,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
organization: Optional[str] = None,
@ -965,9 +998,10 @@ class OpenAIChatCompletion(BaseLLM):
},
)
response = await openai_aclient.chat.completions.create(
**data, timeout=timeout
headers, response = await self.make_openai_chat_completion_request(
openai_aclient=openai_aclient, data=data, timeout=timeout
)
logging_obj.model_call_details["response_headers"] = headers
streamwrapper = CustomStreamWrapper(
completion_stream=response,
model=model,
@ -992,17 +1026,43 @@ class OpenAIChatCompletion(BaseLLM):
else:
raise OpenAIError(status_code=500, message=f"{str(e)}")
# Embedding
async def make_openai_embedding_request(
self,
openai_aclient: AsyncOpenAI,
data: dict,
timeout: Union[float, httpx.Timeout],
):
"""
Helper to:
- call embeddings.create.with_raw_response when litellm.return_response_headers is True
- call embeddings.create by default
"""
try:
if litellm.return_response_headers is True:
raw_response = await openai_aclient.embeddings.with_raw_response.create(
**data, timeout=timeout
) # type: ignore
headers = dict(raw_response.headers)
response = raw_response.parse()
return headers, response
else:
response = await openai_aclient.embeddings.create(**data, timeout=timeout) # type: ignore
return None, response
except Exception as e:
raise e
async def aembedding(
self,
input: list,
data: dict,
model_response: litellm.utils.EmbeddingResponse,
timeout: float,
logging_obj: LiteLLMLoggingObj,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
client: Optional[AsyncOpenAI] = None,
max_retries=None,
logging_obj=None,
):
response = None
try:
@ -1014,7 +1074,10 @@ class OpenAIChatCompletion(BaseLLM):
max_retries=max_retries,
client=client,
)
response = await openai_aclient.embeddings.create(**data, timeout=timeout) # type: ignore
headers, response = await self.make_openai_embedding_request(
openai_aclient=openai_aclient, data=data, timeout=timeout
)
logging_obj.model_call_details["response_headers"] = headers
stringified_response = response.model_dump()
## LOGGING
logging_obj.post_call(
@ -1229,6 +1292,34 @@ class OpenAIChatCompletion(BaseLLM):
else:
raise OpenAIError(status_code=500, message=str(e))
# Audio Transcriptions
async def make_openai_audio_transcriptions_request(
self,
openai_aclient: AsyncOpenAI,
data: dict,
timeout: Union[float, httpx.Timeout],
):
"""
Helper to:
- call openai_aclient.audio.transcriptions.with_raw_response when litellm.return_response_headers is True
- call openai_aclient.audio.transcriptions.create by default
"""
try:
if litellm.return_response_headers is True:
raw_response = (
await openai_aclient.audio.transcriptions.with_raw_response.create(
**data, timeout=timeout
)
) # type: ignore
headers = dict(raw_response.headers)
response = raw_response.parse()
return headers, response
else:
response = await openai_aclient.audio.transcriptions.create(**data, timeout=timeout) # type: ignore
return None, response
except Exception as e:
raise e
def audio_transcriptions(
self,
model: str,
@ -1286,11 +1377,11 @@ class OpenAIChatCompletion(BaseLLM):
data: dict,
model_response: TranscriptionResponse,
timeout: float,
logging_obj: LiteLLMLoggingObj,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
client=None,
max_retries=None,
logging_obj=None,
):
try:
openai_aclient = self._get_openai_client(
@ -1302,9 +1393,12 @@ class OpenAIChatCompletion(BaseLLM):
client=client,
)
response = await openai_aclient.audio.transcriptions.create(
**data, timeout=timeout
) # type: ignore
headers, response = await self.make_openai_audio_transcriptions_request(
openai_aclient=openai_aclient,
data=data,
timeout=timeout,
)
logging_obj.model_call_details["response_headers"] = headers
stringified_response = response.model_dump()
## LOGGING
logging_obj.post_call(
@ -1497,9 +1591,9 @@ class OpenAITextCompletion(BaseLLM):
model: str,
messages: list,
timeout: float,
logging_obj: LiteLLMLoggingObj,
print_verbose: Optional[Callable] = None,
api_base: Optional[str] = None,
logging_obj=None,
acompletion: bool = False,
optional_params=None,
litellm_params=None,

View file

@ -2033,6 +2033,50 @@ def function_call_prompt(messages: list, functions: list):
return messages
def response_schema_prompt(model: str, response_schema: dict) -> str:
"""
Decides if a user-defined custom prompt or default needs to be used
Returns the prompt str that's passed to the model as a user message
"""
custom_prompt_details: Optional[dict] = None
response_schema_as_message = [
{"role": "user", "content": "{}".format(response_schema)}
]
if f"{model}/response_schema_prompt" in litellm.custom_prompt_dict:
custom_prompt_details = litellm.custom_prompt_dict[
f"{model}/response_schema_prompt"
] # allow user to define custom response schema prompt by model
elif "response_schema_prompt" in litellm.custom_prompt_dict:
custom_prompt_details = litellm.custom_prompt_dict["response_schema_prompt"]
if custom_prompt_details is not None:
return custom_prompt(
role_dict=custom_prompt_details["roles"],
initial_prompt_value=custom_prompt_details["initial_prompt_value"],
final_prompt_value=custom_prompt_details["final_prompt_value"],
messages=response_schema_as_message,
)
else:
return default_response_schema_prompt(response_schema=response_schema)
def default_response_schema_prompt(response_schema: dict) -> str:
"""
Used if provider/model doesn't support 'response_schema' param.
This is the default prompt. Allow user to override this with a custom_prompt.
"""
prompt_str = """Use this JSON schema:
```json
{}
```""".format(
response_schema
)
return prompt_str
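
Since response_schema_prompt() checks litellm.custom_prompt_dict first, a deployment can override the default schema prompt per model; a hedged sketch (the model name and role mapping are illustrative, the dict keys match what the function reads):

    # Sketch: register a custom response-schema prompt for one model.
    import litellm

    litellm.custom_prompt_dict["claude-3-sonnet-20240229/response_schema_prompt"] = {
        "roles": {"user": {"pre_message": "", "post_message": "\n"}},  # illustrative role_dict
        "initial_prompt_value": "Answer only with JSON that matches this schema:\n",
        "final_prompt_value": "\nReturn the JSON object and nothing else.",
    }
    # response_schema_prompt(model="claude-3-sonnet-20240229", response_schema={...})
    # now routes through custom_prompt() instead of the default prompt above.
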
# Custom prompt template
def custom_prompt(
role_dict: dict,

View file

@ -12,6 +12,7 @@ import requests # type: ignore
from pydantic import BaseModel
import litellm
from litellm._logging import verbose_logger
from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.llms.prompt_templates.factory import (
convert_to_anthropic_image_obj,
@ -328,11 +329,14 @@ def _gemini_convert_messages_with_history(messages: list) -> List[ContentType]:
contents: List[ContentType] = []
msg_i = 0
try:
while msg_i < len(messages):
user_content: List[PartType] = []
init_msg_i = msg_i
## MERGE CONSECUTIVE USER CONTENT ##
while msg_i < len(messages) and messages[msg_i]["role"] in user_message_types:
while (
msg_i < len(messages) and messages[msg_i]["role"] in user_message_types
):
if isinstance(messages[msg_i]["content"], list):
_parts: List[PartType] = []
for element in messages[msg_i]["content"]:
@ -375,7 +379,9 @@ def _gemini_convert_messages_with_history(messages: list) -> List[ContentType]:
"tool_calls", []
): # support assistant tool invoke conversion
assistant_content.extend(
convert_to_gemini_tool_call_invoke(messages[msg_i]["tool_calls"])
convert_to_gemini_tool_call_invoke(
messages[msg_i]["tool_calls"]
)
)
else:
assistant_text = (
@ -400,8 +406,9 @@ def _gemini_convert_messages_with_history(messages: list) -> List[ContentType]:
messages[msg_i]
)
)
return contents
except Exception as e:
raise e
def _get_client_cache_key(model: str, vertex_project: str, vertex_location: str):

View file

@ -1,24 +1,32 @@
# What is this?
## Handler file for calling claude-3 on vertex ai
import os, types
import copy
import json
import os
import time
import types
import uuid
from enum import Enum
import requests, copy # type: ignore
import time, uuid
from typing import Callable, Optional, List
from litellm.utils import ModelResponse, Usage, CustomStreamWrapper
from litellm.litellm_core_utils.core_helpers import map_finish_reason
from typing import Any, Callable, List, Optional, Tuple
import httpx # type: ignore
import requests # type: ignore
import litellm
from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.types.utils import ResponseFormatChunk
from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
from .prompt_templates.factory import (
contains_tag,
prompt_factory,
custom_prompt,
construct_tool_use_system_prompt,
contains_tag,
custom_prompt,
extract_between_tags,
parse_xml_params,
prompt_factory,
response_schema_prompt,
)
import httpx # type: ignore
class VertexAIError(Exception):
@ -104,6 +112,7 @@ class VertexAIAnthropicConfig:
"stop",
"temperature",
"top_p",
"response_format",
]
def map_openai_params(self, non_default_params: dict, optional_params: dict):
@ -120,6 +129,8 @@ class VertexAIAnthropicConfig:
optional_params["temperature"] = value
if param == "top_p":
optional_params["top_p"] = value
if param == "response_format" and "response_schema" in value:
optional_params["response_format"] = ResponseFormatChunk(**value) # type: ignore
return optional_params
@ -129,7 +140,6 @@ class VertexAIAnthropicConfig:
"""
# makes headers for API call
def refresh_auth(
credentials,
) -> str: # used when user passes in credentials as json string
@ -144,6 +154,40 @@ def refresh_auth(
return credentials.token
def get_vertex_client(
client: Any,
vertex_project: Optional[str],
vertex_location: Optional[str],
vertex_credentials: Optional[str],
) -> Tuple[Any, Optional[str]]:
args = locals()
from litellm.llms.vertex_httpx import VertexLLM
try:
from anthropic import AnthropicVertex
except Exception:
raise VertexAIError(
status_code=400,
message="""vertexai import failed please run `pip install -U google-cloud-aiplatform "anthropic[vertex]"`""",
)
access_token: Optional[str] = None
if client is None:
_credentials, cred_project_id = VertexLLM().load_auth(
credentials=vertex_credentials, project_id=vertex_project
)
vertex_ai_client = AnthropicVertex(
project_id=vertex_project or cred_project_id,
region=vertex_location or "us-central1",
access_token=_credentials.token,
)
else:
vertex_ai_client = client
return vertex_ai_client, access_token
def completion(
model: str,
messages: list,
@ -151,10 +195,10 @@ def completion(
print_verbose: Callable,
encoding,
logging_obj,
optional_params: dict,
vertex_project=None,
vertex_location=None,
vertex_credentials=None,
optional_params=None,
litellm_params=None,
logger_fn=None,
acompletion: bool = False,
@ -178,6 +222,13 @@ def completion(
)
try:
vertex_ai_client, access_token = get_vertex_client(
client=client,
vertex_project=vertex_project,
vertex_location=vertex_location,
vertex_credentials=vertex_credentials,
)
## Load Config
config = litellm.VertexAIAnthropicConfig.get_config()
for k, v in config.items():
@ -186,6 +237,7 @@ def completion(
## Format Prompt
_is_function_call = False
_is_json_schema = False
messages = copy.deepcopy(messages)
optional_params = copy.deepcopy(optional_params)
# Separate system prompt from rest of message
@ -200,6 +252,29 @@ def completion(
messages.pop(idx)
if len(system_prompt) > 0:
optional_params["system"] = system_prompt
# Checks for 'response_schema' support - if passed in
if "response_format" in optional_params:
response_format_chunk = ResponseFormatChunk(
**optional_params["response_format"] # type: ignore
)
supports_response_schema = litellm.supports_response_schema(
model=model, custom_llm_provider="vertex_ai"
)
if (
supports_response_schema is False
and response_format_chunk["type"] == "json_object"
and "response_schema" in response_format_chunk
):
_is_json_schema = True
user_response_schema_message = response_schema_prompt(
model=model,
response_schema=response_format_chunk["response_schema"],
)
messages.append(
{"role": "user", "content": user_response_schema_message}
)
messages.append({"role": "assistant", "content": "{"})
optional_params.pop("response_format")
# Format rest of message according to anthropic guidelines
try:
messages = prompt_factory(
@ -233,32 +308,6 @@ def completion(
print_verbose(
f"VERTEX AI: vertex_project={vertex_project}; vertex_location={vertex_location}; vertex_credentials={vertex_credentials}"
)
access_token = None
if client is None:
if vertex_credentials is not None and isinstance(vertex_credentials, str):
import google.oauth2.service_account
try:
json_obj = json.loads(vertex_credentials)
except json.JSONDecodeError:
json_obj = json.load(open(vertex_credentials))
creds = (
google.oauth2.service_account.Credentials.from_service_account_info(
json_obj,
scopes=["https://www.googleapis.com/auth/cloud-platform"],
)
)
### CHECK IF ACCESS
access_token = refresh_auth(credentials=creds)
vertex_ai_client = AnthropicVertex(
project_id=vertex_project,
region=vertex_location,
access_token=access_token,
)
else:
vertex_ai_client = client
if acompletion == True:
"""
@ -315,7 +364,16 @@ def completion(
)
message = vertex_ai_client.messages.create(**data) # type: ignore
text_content = message.content[0].text
## LOGGING
logging_obj.post_call(
input=messages,
api_key="",
original_response=message,
additional_args={"complete_input_dict": data},
)
text_content: str = message.content[0].text
## TOOL CALLING - OUTPUT PARSE
if text_content is not None and contains_tag("invoke", text_content):
function_name = extract_between_tags("tool_name", text_content)[0]
@ -338,6 +396,12 @@ def completion(
content=None,
)
model_response.choices[0].message = _message # type: ignore
else:
if (
_is_json_schema
): # follows https://github.com/anthropics/anthropic-cookbook/blob/main/misc/how_to_enable_json_mode.ipynb
json_response = "{" + text_content[: text_content.rfind("}") + 1]
model_response.choices[0].message.content = json_response # type: ignore
else:
model_response.choices[0].message.content = text_content # type: ignore
model_response.choices[0].finish_reason = map_finish_reason(message.stop_reason)
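
Taken together, these hunks let the vertex_ai Anthropic route accept an OpenAI-style response_format carrying a response_schema; when the model lacks native schema support, the schema is injected as a user message and the assistant turn is prefilled with "{" per the linked Anthropic cookbook. A hedged call sketch (model string, project and schema are illustrative):

    # Sketch: JSON-schema-constrained output through vertex_ai Anthropic.
    import litellm

    response = litellm.completion(
        model="vertex_ai/claude-3-5-sonnet@20240620",  # illustrative deployment name
        messages=[{"role": "user", "content": "List two cookie recipes."}],
        response_format={
            "type": "json_object",
            "response_schema": {
                "type": "array",
                "items": {"type": "object", "properties": {"recipe_name": {"type": "string"}}},
            },
        },
        vertex_project="my-gcp-project",  # assumption: your GCP project
        vertex_location="us-central1",
    )
    print(response.choices[0].message.content)  # expected to be a JSON string
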

View file

@ -12,7 +12,6 @@ from functools import partial
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
import httpx # type: ignore
import ijson
import requests # type: ignore
import litellm
@ -21,7 +20,10 @@ import litellm.litellm_core_utils.litellm_logging
from litellm import verbose_logger
from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.llms.prompt_templates.factory import convert_url_to_base64
from litellm.llms.prompt_templates.factory import (
convert_url_to_base64,
response_schema_prompt,
)
from litellm.llms.vertex_ai import _gemini_convert_messages_with_history
from litellm.types.llms.openai import (
ChatCompletionResponseMessage,
@ -183,10 +185,17 @@ class GoogleAIStudioGeminiConfig: # key diff from VertexAI - 'frequency_penalty
if param == "tools" and isinstance(value, list):
gtool_func_declarations = []
for tool in value:
_parameters = tool.get("function", {}).get("parameters", {})
_properties = _parameters.get("properties", {})
if isinstance(_properties, dict):
for _, _property in _properties.items():
if "enum" in _property and "format" not in _property:
_property["format"] = "enum"
gtool_func_declaration = FunctionDeclaration(
name=tool["function"]["name"],
description=tool["function"].get("description", ""),
parameters=tool["function"].get("parameters", {}),
parameters=_parameters,
)
gtool_func_declarations.append(gtool_func_declaration)
optional_params["tools"] = [
@ -349,6 +358,7 @@ class VertexGeminiConfig:
model: str,
non_default_params: dict,
optional_params: dict,
drop_params: bool,
):
for param, value in non_default_params.items():
if param == "temperature":
@ -368,8 +378,13 @@ class VertexGeminiConfig:
optional_params["stop_sequences"] = value
if param == "max_tokens":
optional_params["max_output_tokens"] = value
if param == "response_format" and value["type"] == "json_object": # type: ignore
if param == "response_format" and isinstance(value, dict): # type: ignore
if value["type"] == "json_object":
optional_params["response_mime_type"] = "application/json"
elif value["type"] == "text":
optional_params["response_mime_type"] = "text/plain"
if "response_schema" in value:
optional_params["response_schema"] = value["response_schema"]
if param == "frequency_penalty":
optional_params["frequency_penalty"] = value
if param == "presence_penalty":
@ -460,7 +475,7 @@ async def make_call(
raise VertexAIError(status_code=response.status_code, message=response.text)
completion_stream = ModelResponseIterator(
streaming_response=response.aiter_bytes(), sync_stream=False
streaming_response=response.aiter_lines(), sync_stream=False
)
# LOGGING
logging_obj.post_call(
@ -491,7 +506,7 @@ def make_sync_call(
raise VertexAIError(status_code=response.status_code, message=response.read())
completion_stream = ModelResponseIterator(
streaming_response=response.iter_bytes(chunk_size=2056), sync_stream=True
streaming_response=response.iter_lines(), sync_stream=True
)
# LOGGING
@ -813,11 +828,12 @@ class VertexLLM(BaseLLM):
endpoint = "generateContent"
if stream is True:
endpoint = "streamGenerateContent"
url = (
"https://generativelanguage.googleapis.com/v1beta/{}:{}?key={}".format(
url = "https://generativelanguage.googleapis.com/v1beta/{}:{}?key={}&alt=sse".format(
_gemini_model_name, endpoint, gemini_api_key
)
else:
url = "https://generativelanguage.googleapis.com/v1beta/{}:{}?key={}".format(
_gemini_model_name, endpoint, gemini_api_key
)
else:
auth_header, vertex_project = self._ensure_access_token(
@ -829,6 +845,8 @@ class VertexLLM(BaseLLM):
endpoint = "generateContent"
if stream is True:
endpoint = "streamGenerateContent"
url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}?alt=sse"
else:
url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}"
if (
@ -842,6 +860,9 @@ class VertexLLM(BaseLLM):
else:
url = "{}:{}".format(api_base, endpoint)
if stream is True:
url = url + "?alt=sse"
return auth_header, url
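
With stream=True the Vertex/Gemini URLs above now request server-sent events (?alt=sse); a hedged streaming sketch that also exercises the response_schema handling added further down (model, project and schema are illustrative):

    # Sketch: SSE streaming plus response_schema on the vertex_ai_beta (httpx) route.
    import litellm

    stream = litellm.completion(
        model="vertex_ai_beta/gemini-1.5-pro-001",
        messages=[{"role": "user", "content": "List 3 cookie recipes"}],
        response_format={
            "type": "json_object",
            "response_schema": {"type": "array", "items": {"type": "object"}},
        },
        vertex_project="my-gcp-project",  # assumption
        vertex_location="us-central1",
        stream=True,
    )
    for chunk in stream:
        print(chunk.choices[0].delta.content or "", end="")
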
async def async_streaming(
@ -994,6 +1015,22 @@ class VertexLLM(BaseLLM):
if len(system_prompt_indices) > 0:
for idx in reversed(system_prompt_indices):
messages.pop(idx)
# Checks for 'response_schema' support - if passed in
if "response_schema" in optional_params:
supports_response_schema = litellm.supports_response_schema(
model=model, custom_llm_provider="vertex_ai"
)
if supports_response_schema is False:
user_response_schema_message = response_schema_prompt(
model=model, response_schema=optional_params.get("response_schema") # type: ignore
)
messages.append(
{"role": "user", "content": user_response_schema_message}
)
optional_params.pop("response_schema")
try:
content = _gemini_convert_messages_with_history(messages=messages)
tools: Optional[Tools] = optional_params.pop("tools", None)
tool_choice: Optional[ToolConfig] = optional_params.pop("tool_choice", None)
@ -1017,12 +1054,14 @@ class VertexLLM(BaseLLM):
data["generationConfig"] = generation_config
headers = {
"Content-Type": "application/json; charset=utf-8",
"Content-Type": "application/json",
}
if auth_header is not None:
headers["Authorization"] = f"Bearer {auth_header}"
if extra_headers is not None:
headers.update(extra_headers)
except Exception as e:
raise e
## LOGGING
logging_obj.pre_call(
@ -1270,11 +1309,6 @@ class VertexLLM(BaseLLM):
class ModelResponseIterator:
def __init__(self, streaming_response, sync_stream: bool):
self.streaming_response = streaming_response
if sync_stream:
self.response_iterator = iter(self.streaming_response)
self.events = ijson.sendable_list()
self.coro = ijson.items_coro(self.events, "item")
def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
try:
@ -1304,9 +1338,9 @@ class ModelResponseIterator:
if "usageMetadata" in processed_chunk:
usage = ChatCompletionUsageBlock(
prompt_tokens=processed_chunk["usageMetadata"]["promptTokenCount"],
completion_tokens=processed_chunk["usageMetadata"][
"candidatesTokenCount"
],
completion_tokens=processed_chunk["usageMetadata"].get(
"candidatesTokenCount", 0
),
total_tokens=processed_chunk["usageMetadata"]["totalTokenCount"],
)
@ -1324,16 +1358,24 @@ class ModelResponseIterator:
# Sync iterator
def __iter__(self):
self.response_iterator = self.streaming_response
return self
def __next__(self):
try:
chunk = self.response_iterator.__next__()
self.coro.send(chunk)
if self.events:
event = self.events.pop(0)
json_chunk = event
except StopIteration:
raise StopIteration
except ValueError as e:
raise RuntimeError(f"Error receiving chunk from stream: {e}")
try:
chunk = chunk.replace("data:", "")
chunk = chunk.strip()
if len(chunk) > 0:
json_chunk = json.loads(chunk)
return self.chunk_parser(chunk=json_chunk)
else:
return GenericStreamingChunk(
text="",
is_finished=False,
@ -1343,12 +1385,9 @@ class ModelResponseIterator:
tool_use=None,
)
except StopIteration:
if self.events: # flush the events
event = self.events.pop(0) # Remove the first event
return self.chunk_parser(chunk=event)
raise StopIteration
except ValueError as e:
raise RuntimeError(f"Error parsing chunk: {e}")
raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
# Async iterator
def __aiter__(self):
@ -1358,11 +1397,18 @@ class ModelResponseIterator:
async def __anext__(self):
try:
chunk = await self.async_response_iterator.__anext__()
self.coro.send(chunk)
if self.events:
event = self.events.pop(0)
json_chunk = event
except StopAsyncIteration:
raise StopAsyncIteration
except ValueError as e:
raise RuntimeError(f"Error receiving chunk from stream: {e}")
try:
chunk = chunk.replace("data:", "")
chunk = chunk.strip()
if len(chunk) > 0:
json_chunk = json.loads(chunk)
return self.chunk_parser(chunk=json_chunk)
else:
return GenericStreamingChunk(
text="",
is_finished=False,
@ -1372,9 +1418,6 @@ class ModelResponseIterator:
tool_use=None,
)
except StopAsyncIteration:
if self.events: # flush the events
event = self.events.pop(0) # Remove the first event
return self.chunk_parser(chunk=event)
raise StopAsyncIteration
except ValueError as e:
raise RuntimeError(f"Error parsing chunk: {e}")
raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")

View file

@ -476,6 +476,15 @@ def mock_completion(
model=model, # type: ignore
request=httpx.Request(method="POST", url="https://api.openai.com/v1/"),
)
elif (
isinstance(mock_response, str) and mock_response == "litellm.RateLimitError"
):
raise litellm.RateLimitError(
message="this is a mock rate limit error",
status_code=getattr(mock_response, "status_code", 429), # type: ignore
llm_provider=getattr(mock_response, "llm_provider", custom_llm_provider or "openai"), # type: ignore
model=model,
)
time_delay = kwargs.get("mock_delay", None)
if time_delay is not None:
time.sleep(time_delay)
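
The new mock branch makes it possible to exercise rate-limit handling (retries, fallbacks, cooldowns) without a real provider call, which is also how the test proxy config further down uses it; a minimal sketch:

    # Sketch: trigger a mocked 429 via mock_response.
    import litellm

    try:
        litellm.completion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "hi"}],
            mock_response="litellm.RateLimitError",
        )
    except litellm.RateLimitError as e:
        print("got mock rate limit error:", str(e))
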
@ -676,6 +685,8 @@ def completion(
client = kwargs.get("client", None)
### Admin Controls ###
no_log = kwargs.get("no-log", False)
### COPY MESSAGES ### - related issue https://github.com/BerriAI/litellm/discussions/4489
messages = deepcopy(messages)
######## end of unpacking kwargs ###########
openai_params = [
"functions",
@ -1828,6 +1839,7 @@ def completion(
logging_obj=logging,
acompletion=acompletion,
timeout=timeout, # type: ignore
custom_llm_provider="openrouter",
)
## LOGGING
logging.post_call(
@ -2199,46 +2211,29 @@ def completion(
# boto3 reads keys from .env
custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
if "aws_bedrock_client" in optional_params:
verbose_logger.warning(
"'aws_bedrock_client' is a deprecated param. Please move to another auth method - https://docs.litellm.ai/docs/providers/bedrock#boto3---authentication."
)
# Extract credentials for legacy boto3 client and pass thru to httpx
aws_bedrock_client = optional_params.pop("aws_bedrock_client")
creds = aws_bedrock_client._get_credentials().get_frozen_credentials()
if creds.access_key:
optional_params["aws_access_key_id"] = creds.access_key
if creds.secret_key:
optional_params["aws_secret_access_key"] = creds.secret_key
if creds.token:
optional_params["aws_session_token"] = creds.token
if (
"aws_bedrock_client" in optional_params
): # use old bedrock flow for aws_bedrock_client users.
response = bedrock.completion(
model=model,
messages=messages,
custom_prompt_dict=litellm.custom_prompt_dict,
model_response=model_response,
print_verbose=print_verbose,
optional_params=optional_params,
litellm_params=litellm_params,
logger_fn=logger_fn,
encoding=encoding,
logging_obj=logging,
extra_headers=extra_headers,
timeout=timeout,
"aws_region_name" not in optional_params
or optional_params["aws_region_name"] is None
):
optional_params["aws_region_name"] = (
aws_bedrock_client.meta.region_name
)
if (
"stream" in optional_params
and optional_params["stream"] == True
and not isinstance(response, CustomStreamWrapper)
):
# don't try to access stream object,
if "ai21" in model:
response = CustomStreamWrapper(
response,
model,
custom_llm_provider="bedrock",
logging_obj=logging,
)
else:
response = CustomStreamWrapper(
iter(response),
model,
custom_llm_provider="bedrock",
logging_obj=logging,
)
else:
if model.startswith("anthropic"):
if model in litellm.BEDROCK_CONVERSE_MODELS:
response = bedrock_converse_chat_completion.completion(
model=model,
messages=messages,
@ -2272,6 +2267,7 @@ def completion(
acompletion=acompletion,
client=client,
)
if optional_params.get("stream", False):
## LOGGING
logging.post_call(

View file

@ -1486,6 +1486,7 @@
"supports_system_messages": true,
"supports_function_calling": true,
"supports_tool_choice": true,
"supports_response_schema": true,
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
},
"gemini-1.5-pro-001": {
@ -1511,6 +1512,7 @@
"supports_system_messages": true,
"supports_function_calling": true,
"supports_tool_choice": true,
"supports_response_schema": true,
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
},
"gemini-1.5-pro-preview-0514": {
@ -1536,6 +1538,7 @@
"supports_system_messages": true,
"supports_function_calling": true,
"supports_tool_choice": true,
"supports_response_schema": true,
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
},
"gemini-1.5-pro-preview-0215": {
@ -1561,6 +1564,7 @@
"supports_system_messages": true,
"supports_function_calling": true,
"supports_tool_choice": true,
"supports_response_schema": true,
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
},
"gemini-1.5-pro-preview-0409": {
@ -1585,6 +1589,7 @@
"mode": "chat",
"supports_function_calling": true,
"supports_tool_choice": true,
"supports_response_schema": true,
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
},
"gemini-1.5-flash": {
@ -2007,6 +2012,7 @@
"supports_function_calling": true,
"supports_vision": true,
"supports_tool_choice": true,
"supports_response_schema": true,
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
},
"gemini/gemini-1.5-pro-latest": {
@ -2023,6 +2029,7 @@
"supports_function_calling": true,
"supports_vision": true,
"supports_tool_choice": true,
"supports_response_schema": true,
"source": "https://ai.google.dev/models/gemini"
},
"gemini/gemini-pro-vision": {

View file

@ -1,54 +1,5 @@
# model_list:
# - model_name: my-fake-model
# litellm_params:
# model: bedrock/anthropic.claude-3-sonnet-20240229-v1:0
# api_key: my-fake-key
# aws_bedrock_runtime_endpoint: http://127.0.0.1:8000
# mock_response: "Hello world 1"
# model_info:
# max_input_tokens: 0 # trigger context window fallback
# - model_name: my-fake-model
# litellm_params:
# model: bedrock/anthropic.claude-3-sonnet-20240229-v1:0
# api_key: my-fake-key
# aws_bedrock_runtime_endpoint: http://127.0.0.1:8000
# mock_response: "Hello world 2"
# model_info:
# max_input_tokens: 0
# router_settings:
# enable_pre_call_checks: True
# litellm_settings:
# failure_callback: ["langfuse"]
model_list:
- model_name: summarize
- model_name: claude-3-5-sonnet # all requests where model not in your config go to this deployment
litellm_params:
model: openai/gpt-4o
rpm: 10000
tpm: 12000000
api_key: os.environ/OPENAI_API_KEY
mock_response: Hello world 1
- model_name: summarize-l
litellm_params:
model: claude-3-5-sonnet-20240620
rpm: 4000
tpm: 400000
api_key: os.environ/ANTHROPIC_API_KEY
mock_response: Hello world 2
litellm_settings:
num_retries: 3
request_timeout: 120
allowed_fails: 3
# fallbacks: [{"summarize": ["summarize-l", "summarize-xl"]}, {"summarize-l": ["summarize-xl"]}]
# context_window_fallbacks: [{"summarize": ["summarize-l", "summarize-xl"]}, {"summarize-l": ["summarize-xl"]}]
router_settings:
routing_strategy: simple-shuffle
enable_pre_call_checks: true.
model: "openai/*"
mock_response: "litellm.RateLimitError"

View file

@ -1,10 +1,11 @@
model_list:
- model_name: claude-3-5-sonnet
litellm_params:
model: anthropic/claude-3-5-sonnet
- model_name: gemini-1.5-flash-gemini
litellm_params:
model: gemini/gemini-1.5-flash
- model_name: gemini-1.5-flash-gemini
litellm_params:
model: gemini/gemini-1.5-flash
model: vertex_ai_beta/gemini-1.5-flash
api_base: https://gateway.ai.cloudflare.com/v1/fa4cdcab1f32b95ca3b53fd36043d691/test/google-vertex-ai/v1/projects/adroit-crow-413218/locations/us-central1/publishers/google/models/gemini-1.5-flash
- litellm_params:
api_base: http://0.0.0.0:8080
api_key: ''

View file

@ -1622,7 +1622,7 @@ class ProxyException(Exception):
}
class CommonProxyErrors(enum.Enum):
class CommonProxyErrors(str, enum.Enum):
db_not_connected_error = "DB not connected"
no_llm_router = "No models configured on proxy"
not_allowed_access = "Admin-only endpoint. Not allowed to access this."

View file

@ -0,0 +1,173 @@
import ast
import traceback
from base64 import b64encode
import httpx
from fastapi import (
APIRouter,
Depends,
FastAPI,
HTTPException,
Request,
Response,
status,
)
from fastapi.responses import StreamingResponse
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import ProxyException
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
async_client = httpx.AsyncClient()
async def set_env_variables_in_header(custom_headers: dict):
"""
checks if any headers on config.yaml are defined as os.environ/COHERE_API_KEY etc.
only runs for headers defined on config.yaml
an example header can be
{"Authorization": "bearer os.environ/COHERE_API_KEY"}
"""
headers = {}
for key, value in custom_headers.items():
# the langfuse API requires base64-encoded headers - it's simpler to just ask litellm users to set their langfuse public and secret keys
# we can then get the b64 encoded keys here
if key == "LANGFUSE_PUBLIC_KEY" or key == "LANGFUSE_SECRET_KEY":
# langfuse requires b64 encoded headers - we construct that here
_langfuse_public_key = custom_headers["LANGFUSE_PUBLIC_KEY"]
_langfuse_secret_key = custom_headers["LANGFUSE_SECRET_KEY"]
if isinstance(
_langfuse_public_key, str
) and _langfuse_public_key.startswith("os.environ/"):
_langfuse_public_key = litellm.get_secret(_langfuse_public_key)
if isinstance(
_langfuse_secret_key, str
) and _langfuse_secret_key.startswith("os.environ/"):
_langfuse_secret_key = litellm.get_secret(_langfuse_secret_key)
headers["Authorization"] = "Basic " + b64encode(
f"{_langfuse_public_key}:{_langfuse_secret_key}".encode("utf-8")
).decode("ascii")
else:
# for all other headers
headers[key] = value
if isinstance(value, str) and "os.environ/" in value:
verbose_proxy_logger.debug(
"pass through endpoint - looking up 'os.environ/' variable"
)
# get string section that is os.environ/
start_index = value.find("os.environ/")
_variable_name = value[start_index:]
verbose_proxy_logger.debug(
"pass through endpoint - getting secret for variable name: %s",
_variable_name,
)
_secret_value = litellm.get_secret(_variable_name)
new_value = value.replace(_variable_name, _secret_value)
headers[key] = new_value
return headers
async def pass_through_request(request: Request, target: str, custom_headers: dict):
try:
url = httpx.URL(target)
headers = custom_headers
request_body = await request.body()
_parsed_body = ast.literal_eval(request_body.decode("utf-8"))
verbose_proxy_logger.debug(
"Pass through endpoint sending request to \nURL {}\nheaders: {}\nbody: {}\n".format(
url, headers, _parsed_body
)
)
response = await async_client.request(
method=request.method,
url=url,
headers=headers,
params=request.query_params,
json=_parsed_body,
)
if response.status_code != 200:
raise HTTPException(status_code=response.status_code, detail=response.text)
content = await response.aread()
return Response(
content=content,
status_code=response.status_code,
headers=dict(response.headers),
)
except Exception as e:
verbose_proxy_logger.error(
"litellm.proxy.proxy_server.pass through endpoint(): Exception occured - {}".format(
str(e)
)
)
verbose_proxy_logger.debug(traceback.format_exc())
if isinstance(e, HTTPException):
raise ProxyException(
message=getattr(e, "message", str(e.detail)),
type=getattr(e, "type", "None"),
param=getattr(e, "param", "None"),
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
)
else:
error_msg = f"{str(e)}"
raise ProxyException(
message=getattr(e, "message", error_msg),
type=getattr(e, "type", "None"),
param=getattr(e, "param", "None"),
code=getattr(e, "status_code", 500),
)
def create_pass_through_route(endpoint, target, custom_headers=None):
async def endpoint_func(request: Request):
return await pass_through_request(request, target, custom_headers)
return endpoint_func
async def initialize_pass_through_endpoints(pass_through_endpoints: list):
verbose_proxy_logger.debug("initializing pass through endpoints")
from litellm.proxy._types import CommonProxyErrors, LiteLLMRoutes
from litellm.proxy.proxy_server import app, premium_user
for endpoint in pass_through_endpoints:
_target = endpoint.get("target", None)
_path = endpoint.get("path", None)
_custom_headers = endpoint.get("headers", None)
_custom_headers = await set_env_variables_in_header(
custom_headers=_custom_headers
)
_auth = endpoint.get("auth", None)
_dependencies = None
if _auth is not None and str(_auth).lower() == "true":
if premium_user is not True:
raise ValueError(
f"Error Setting Authentication on Pass Through Endpoint: {CommonProxyErrors.not_premium_user}"
)
_dependencies = [Depends(user_api_key_auth)]
LiteLLMRoutes.openai_routes.value.append(_path)
if _target is None:
continue
verbose_proxy_logger.debug("adding pass through endpoint: %s", _path)
app.add_api_route(
path=_path,
endpoint=create_pass_through_route(_path, _target, _custom_headers),
methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
dependencies=_dependencies,
)
verbose_proxy_logger.debug("Added new pass through endpoint: %s", _path)

View file

@ -0,0 +1,68 @@
# What is this?
## Script to apply initial prisma migration on Docker setup
import os
import subprocess
import sys
import time
sys.path.insert(
0, os.path.abspath("./")
) # Adds the parent directory to the system path
from litellm.proxy.secret_managers.aws_secret_manager import decrypt_env_var
if os.getenv("USE_AWS_KMS", None) is not None and os.getenv("USE_AWS_KMS") == "True":
## V2 IMPLEMENTATION OF AWS KMS - USER WANTS TO DECRYPT MULTIPLE KEYS IN THEIR ENV
new_env_var = decrypt_env_var()
for k, v in new_env_var.items():
os.environ[k] = v
# Check if DATABASE_URL is not set
database_url = os.getenv("DATABASE_URL")
if not database_url:
# Check if all required variables are provided
database_host = os.getenv("DATABASE_HOST")
database_username = os.getenv("DATABASE_USERNAME")
database_password = os.getenv("DATABASE_PASSWORD")
database_name = os.getenv("DATABASE_NAME")
if database_host and database_username and database_password and database_name:
# Construct DATABASE_URL from the provided variables
database_url = f"postgresql://{database_username}:{database_password}@{database_host}/{database_name}"
os.environ["DATABASE_URL"] = database_url
else:
print( # noqa
"Error: Required database environment variables are not set. Provide a postgres url for DATABASE_URL." # noqa
)
exit(1)
# Set DIRECT_URL to the value of DATABASE_URL if it is not set, required for migrations
direct_url = os.getenv("DIRECT_URL")
if not direct_url:
os.environ["DIRECT_URL"] = database_url
# Apply migrations
retry_count = 0
max_retries = 3
exit_code = 1
while retry_count < max_retries and exit_code != 0:
retry_count += 1
print(f"Attempt {retry_count}...") # noqa
# Run the Prisma db push command
result = subprocess.run(
["prisma", "db", "push", "--accept-data-loss"], capture_output=True
)
exit_code = result.returncode
if exit_code != 0 and retry_count < max_retries:
print("Retrying in 10 seconds...") # noqa
time.sleep(10)
if exit_code != 0:
print(f"Unable to push database changes after {max_retries} retries.") # noqa
exit(1)
print("Database push successful!") # noqa

View file

@ -442,6 +442,20 @@ def run_server(
db_connection_pool_limit = 100
db_connection_timeout = 60
### DECRYPT ENV VAR ###
from litellm.proxy.secret_managers.aws_secret_manager import decrypt_env_var
if (
os.getenv("USE_AWS_KMS", None) is not None
and os.getenv("USE_AWS_KMS") == "True"
):
## V2 IMPLEMENTATION OF AWS KMS - USER WANTS TO DECRYPT MULTIPLE KEYS IN THEIR ENV
new_env_var = decrypt_env_var()
for k, v in new_env_var.items():
os.environ[k] = v
if config is not None:
"""
Allow user to pass in db url via config
@ -459,6 +473,7 @@ def run_server(
proxy_config = ProxyConfig()
_config = asyncio.run(proxy_config.get_config(config_file_path=config))
### LITELLM SETTINGS ###
litellm_settings = _config.get("litellm_settings", None)
if (

View file

@ -20,11 +20,23 @@ model_list:
general_settings:
master_key: sk-1234
alerting: ["slack", "email"]
public_routes: ["LiteLLMRoutes.public_routes", "/spend/calculate"]
pass_through_endpoints:
- path: "/v1/rerank"
target: "https://api.cohere.com/v1/rerank"
auth: true # 👈 Key change to use LiteLLM Auth / Keys
headers:
Authorization: "bearer os.environ/COHERE_API_KEY"
content-type: application/json
accept: application/json
- path: "/api/public/ingestion"
target: "https://us.cloud.langfuse.com/api/public/ingestion"
auth: true
headers:
LANGFUSE_PUBLIC_KEY: "os.environ/LANGFUSE_DEV_PUBLIC_KEY"
LANGFUSE_SECRET_KEY: "os.environ/LANGFUSE_DEV_SK_KEY"
litellm_settings:
return_response_headers: true
success_callback: ["prometheus"]
callbacks: ["otel", "hide_secrets"]
failure_callback: ["prometheus"]
@ -34,6 +46,5 @@ litellm_settings:
- user
- metadata
- metadata.generation_name
cache: True
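
With auth: true, the pass-through routes above sit behind user_api_key_auth, so clients send a LiteLLM virtual key and the proxy swaps in the upstream header; a hedged client sketch against a locally running proxy (key, port and body are illustrative):

    # Sketch: call the /v1/rerank pass-through with a LiteLLM key instead of the Cohere key.
    import httpx

    resp = httpx.post(
        "http://localhost:4000/v1/rerank",            # assumption: default proxy address
        headers={"Authorization": "Bearer sk-1234"},  # LiteLLM virtual key
        json={
            "model": "rerank-english-v3.0",
            "query": "What is the capital of the United States?",
            "documents": ["Carson City is the capital city of Nevada."],
        },
    )
    print(resp.status_code, resp.json())
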

View file

@ -161,6 +161,9 @@ from litellm.proxy.management_endpoints.key_management_endpoints import (
router as key_management_router,
)
from litellm.proxy.management_endpoints.team_endpoints import router as team_router
from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
initialize_pass_through_endpoints,
)
from litellm.proxy.secret_managers.aws_secret_manager import (
load_aws_kms,
load_aws_secret_manager,
@ -590,7 +593,7 @@ async def _PROXY_failure_handler(
_model_id = _metadata.get("model_info", {}).get("id", "")
_model_group = _metadata.get("model_group", "")
api_base = litellm.get_api_base(model=_model, optional_params=_litellm_params)
_exception_string = str(_exception)[:500]
_exception_string = str(_exception)
error_log = LiteLLM_ErrorLogs(
request_id=str(uuid.uuid4()),
@ -1856,6 +1859,11 @@ class ProxyConfig:
user_custom_key_generate = get_instance_fn(
value=custom_key_generate, config_file_path=config_file_path
)
## pass through endpoints
if general_settings.get("pass_through_endpoints", None) is not None:
await initialize_pass_through_endpoints(
pass_through_endpoints=general_settings["pass_through_endpoints"]
)
## dynamodb
database_type = general_settings.get("database_type", None)
if database_type is not None and (
@ -7503,7 +7511,9 @@ async def login(request: Request):
# Non SSO -> If user is using UI_USERNAME and UI_PASSWORD they are Proxy admin
user_role = LitellmUserRoles.PROXY_ADMIN
user_id = username
key_user_id = user_id
# we want the key created to have PROXY_ADMIN_PERMISSIONS
key_user_id = litellm_proxy_admin_name
if (
os.getenv("PROXY_ADMIN_ID", None) is not None
and os.environ["PROXY_ADMIN_ID"] == user_id
@ -7523,7 +7533,17 @@ async def login(request: Request):
if os.getenv("DATABASE_URL") is not None:
response = await generate_key_helper_fn(
request_type="key",
**{"user_role": LitellmUserRoles.PROXY_ADMIN, "duration": "2hr", "key_max_budget": 5, "models": [], "aliases": {}, "config": {}, "spend": 0, "user_id": key_user_id, "team_id": "litellm-dashboard"}, # type: ignore
**{
"user_role": LitellmUserRoles.PROXY_ADMIN,
"duration": "2hr",
"key_max_budget": 5,
"models": [],
"aliases": {},
"config": {},
"spend": 0,
"user_id": key_user_id,
"team_id": "litellm-dashboard",
}, # type: ignore
)
else:
raise ProxyException(

View file

@ -8,9 +8,13 @@ Requires:
* `pip install boto3>=1.28.57`
"""
import litellm
import ast
import base64
import os
from typing import Optional
import re
from typing import Any, Dict, Optional
import litellm
from litellm.proxy._types import KeyManagementSystem
@ -57,3 +61,99 @@ def load_aws_kms(use_aws_kms: Optional[bool]):
except Exception as e:
raise e
class AWSKeyManagementService_V2:
"""
V2 Clean Class for decrypting keys from AWS KeyManagementService
"""
def __init__(self) -> None:
self.validate_environment()
self.kms_client = self.load_aws_kms(use_aws_kms=True)
def validate_environment(
self,
):
if "AWS_REGION_NAME" not in os.environ:
raise ValueError("Missing required environment variable - AWS_REGION_NAME")
## CHECK IF LICENSE IN ENV ## - premium feature
if os.getenv("LITELLM_LICENSE", None) is None:
raise ValueError(
"AWSKeyManagementService V2 is an Enterprise Feature. Please add a valid LITELLM_LICENSE to your envionment."
)
def load_aws_kms(self, use_aws_kms: Optional[bool]):
if use_aws_kms is None or use_aws_kms is False:
return
try:
import boto3
validate_environment()
# Create a Secrets Manager client
kms_client = boto3.client("kms", region_name=os.getenv("AWS_REGION_NAME"))
return kms_client
except Exception as e:
raise e
def decrypt_value(self, secret_name: str) -> Any:
if self.kms_client is None:
raise ValueError("kms_client is None")
encrypted_value = os.getenv(secret_name, None)
if encrypted_value is None:
raise Exception(
"AWS KMS - Encrypted Value of Key={} is None".format(secret_name)
)
if isinstance(encrypted_value, str) and encrypted_value.startswith("aws_kms/"):
encrypted_value = encrypted_value.replace("aws_kms/", "")
# Decode the base64 encoded ciphertext
ciphertext_blob = base64.b64decode(encrypted_value)
# Set up the parameters for the decrypt call
params = {"CiphertextBlob": ciphertext_blob}
# Perform the decryption
response = self.kms_client.decrypt(**params)
# Extract and decode the plaintext
plaintext = response["Plaintext"]
secret = plaintext.decode("utf-8")
if isinstance(secret, str):
secret = secret.strip()
try:
secret_value_as_bool = ast.literal_eval(secret)
if isinstance(secret_value_as_bool, bool):
return secret_value_as_bool
except Exception:
pass
return secret
"""
- look for all values in the env with `aws_kms/<hashed_key>`
- decrypt keys
- rewrite the env var with the decrypted key. Note: this environment variable will only be available to the current process and any child processes spawned from it. Once the Python script ends, the environment variable will not persist.
"""
def decrypt_env_var() -> Dict[str, Any]:
# setup client class
aws_kms = AWSKeyManagementService_V2()
# iterate through env - for `aws_kms/`
new_values = {}
for k, v in os.environ.items():
if (
k is not None
and isinstance(k, str)
and k.lower().startswith("litellm_secret_aws_kms")
) or (v is not None and isinstance(v, str) and v.startswith("aws_kms/")):
decrypted_value = aws_kms.decrypt_value(secret_name=k)
# reset env var
k = re.sub("litellm_secret_aws_kms_", "", k, flags=re.IGNORECASE)
new_values[k] = decrypted_value
return new_values
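
A hedged sketch of how the V2 KMS flow is driven purely from the environment (region, license and ciphertext below are placeholders; real decryption requires valid AWS credentials and a KMS-encrypted value):

    # Sketch: decrypt aws_kms/-prefixed env vars at startup, as the proxy and migration script do.
    import os

    from litellm.proxy.secret_managers.aws_secret_manager import decrypt_env_var

    os.environ["AWS_REGION_NAME"] = "us-west-2"                       # required by validate_environment()
    os.environ["LITELLM_LICENSE"] = "<your-license>"                  # enterprise gate
    os.environ["DATABASE_URL"] = "aws_kms/AQICAH...base64ciphertext"  # placeholder ciphertext

    decrypted = decrypt_env_var()  # e.g. {"DATABASE_URL": "<plaintext>"}
    os.environ.update({k: str(v) for k, v in decrypted.items()})
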

View file

@ -817,9 +817,9 @@ async def get_global_spend_report(
default=None,
description="Time till which to view spend",
),
group_by: Optional[Literal["team", "customer"]] = fastapi.Query(
group_by: Optional[Literal["team", "customer", "api_key"]] = fastapi.Query(
default="team",
description="Group spend by internal team or customer",
description="Group spend by internal team or customer or api_key",
),
):
"""
@ -860,7 +860,7 @@ async def get_global_spend_report(
start_date_obj = datetime.strptime(start_date, "%Y-%m-%d")
end_date_obj = datetime.strptime(end_date, "%Y-%m-%d")
from litellm.proxy.proxy_server import prisma_client
from litellm.proxy.proxy_server import premium_user, prisma_client
try:
if prisma_client is None:
@ -868,6 +868,12 @@ async def get_global_spend_report(
f"Database not connected. Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys"
)
if premium_user is not True:
verbose_proxy_logger.debug("accessing /spend/report but not a premium user")
raise ValueError(
"/spend/report endpoint " + CommonProxyErrors.not_premium_user.value
)
if group_by == "team":
# first get data from spend logs -> SpendByModelApiKey
# then read data from "SpendByModelApiKey" to format the response obj
@ -992,6 +998,48 @@ async def get_global_spend_report(
return []
return db_response
elif group_by == "api_key":
sql_query = """
WITH SpendByModelApiKey AS (
SELECT
sl.api_key,
sl.model,
SUM(sl.spend) AS model_cost,
SUM(sl.prompt_tokens) AS model_input_tokens,
SUM(sl.completion_tokens) AS model_output_tokens
FROM
"LiteLLM_SpendLogs" sl
WHERE
sl."startTime" BETWEEN $1::date AND $2::date
GROUP BY
sl.api_key,
sl.model
)
SELECT
api_key,
SUM(model_cost) AS total_cost,
SUM(model_input_tokens) AS total_input_tokens,
SUM(model_output_tokens) AS total_output_tokens,
jsonb_agg(jsonb_build_object(
'model', model,
'total_cost', model_cost,
'total_input_tokens', model_input_tokens,
'total_output_tokens', model_output_tokens
)) AS model_details
FROM
SpendByModelApiKey
GROUP BY
api_key
ORDER BY
total_cost DESC;
"""
db_response = await prisma_client.db.query_raw(
sql_query, start_date_obj, end_date_obj
)
if db_response is None:
return []
return db_response
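
The new api_key grouping is queried like the existing team/customer groupings; a hedged client sketch (the route path, key and dates are assumptions for illustration, and the endpoint is premium-gated per the check above):

    # Sketch: pull the per-API-key spend report.
    import httpx

    resp = httpx.get(
        "http://localhost:4000/global/spend/report",  # assumption: proxy address and route path
        params={"start_date": "2024-06-01", "end_date": "2024-06-30", "group_by": "api_key"},
        headers={"Authorization": "Bearer sk-1234"},
    )
    for row in resp.json():
        print(row["api_key"], row["total_cost"])
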
except Exception as e:
raise HTTPException(

View file

@ -0,0 +1,14 @@
from langfuse import Langfuse
langfuse = Langfuse(
host="http://localhost:4000",
public_key="anything",
secret_key="anything",
)
print("sending langfuse trace request")
trace = langfuse.trace(name="test-trace-litellm-proxy-passthrough")
print("flushing langfuse request")
langfuse.flush()
print("flushed langfuse request")

View file

@ -105,7 +105,9 @@ class Router:
def __init__(
self,
model_list: Optional[List[Union[DeploymentTypedDict, Dict]]] = None,
model_list: Optional[
Union[List[DeploymentTypedDict], List[Dict[str, Any]]]
] = None,
## ASSISTANTS API ##
assistants_config: Optional[AssistantsTypedDict] = None,
## CACHING ##
@ -154,6 +156,7 @@ class Router:
cooldown_time: Optional[
float
] = None, # (seconds) time to cooldown a deployment after failure
disable_cooldowns: Optional[bool] = None,
routing_strategy: Literal[
"simple-shuffle",
"least-busy",
@ -305,6 +308,7 @@ class Router:
self.allowed_fails = allowed_fails or litellm.allowed_fails
self.cooldown_time = cooldown_time or 60
self.disable_cooldowns = disable_cooldowns
self.failed_calls = (
InMemoryCache()
) # cache to track failed call per deployment, if num failed calls within 1 minute > allowed fails, then add it to cooldown
@ -2988,6 +2992,8 @@ class Router:
the exception is not one that should be immediately retried (e.g. 401)
"""
if self.disable_cooldowns is True:
return
if deployment is None:
return
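
A hedged caller-side sketch of the new flag (the model entry is illustrative):

    # Sketch: opt a Router out of deployment cooldowns entirely.
    from litellm import Router

    router = Router(
        model_list=[
            {
                "model_name": "gpt-3.5-turbo",
                "litellm_params": {"model": "gpt-3.5-turbo", "api_key": "os.environ/OPENAI_API_KEY"},
            }
        ],
        disable_cooldowns=True,  # _set_cooldown_deployments() returns early when True
    )
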
@ -3028,24 +3034,50 @@ class Router:
exception_status = 500
_should_retry = litellm._should_retry(status_code=exception_status)
if updated_fails > allowed_fails or _should_retry == False:
if updated_fails > allowed_fails or _should_retry is False:
# get the current cooldown list for that minute
cooldown_key = f"{current_minute}:cooldown_models" # group cooldown models by minute to reduce number of redis calls
cached_value = self.cache.get_cache(key=cooldown_key)
cached_value = self.cache.get_cache(
key=cooldown_key
) # [(deployment_id, {last_error_str, last_error_status_code})]
cached_value_deployment_ids = []
if (
cached_value is not None
and isinstance(cached_value, list)
and len(cached_value) > 0
and isinstance(cached_value[0], tuple)
):
cached_value_deployment_ids = [cv[0] for cv in cached_value]
verbose_router_logger.debug(f"adding {deployment} to cooldown models")
# update value
try:
if deployment in cached_value:
if cached_value is not None and len(cached_value_deployment_ids) > 0:
if deployment in cached_value_deployment_ids:
pass
else:
cached_value = cached_value + [deployment]
cached_value = cached_value + [
(
deployment,
{
"Exception Received": str(original_exception),
"Status Code": str(exception_status),
},
)
]
# save updated value
self.cache.set_cache(
value=cached_value, key=cooldown_key, ttl=cooldown_time
)
except:
cached_value = [deployment]
else:
cached_value = [
(
deployment,
{
"Exception Received": str(original_exception),
"Status Code": str(exception_status),
},
)
]
# save updated value
self.cache.set_cache(
value=cached_value, key=cooldown_key, ttl=cooldown_time
@ -3061,7 +3093,33 @@ class Router:
key=deployment, value=updated_fails, ttl=cooldown_time
)
async def _async_get_cooldown_deployments(self):
async def _async_get_cooldown_deployments(self) -> List[str]:
"""
Async implementation of '_get_cooldown_deployments'
"""
dt = get_utc_datetime()
current_minute = dt.strftime("%H-%M")
# get the current cooldown list for that minute
cooldown_key = f"{current_minute}:cooldown_models"
# ----------------------
# Return cooldown models
# ----------------------
cooldown_models = await self.cache.async_get_cache(key=cooldown_key) or []
cached_value_deployment_ids = []
if (
cooldown_models is not None
and isinstance(cooldown_models, list)
and len(cooldown_models) > 0
and isinstance(cooldown_models[0], tuple)
):
cached_value_deployment_ids = [cv[0] for cv in cooldown_models]
verbose_router_logger.debug(f"retrieve cooldown models: {cooldown_models}")
return cached_value_deployment_ids
async def _async_get_cooldown_deployments_with_debug_info(self) -> List[tuple]:
"""
Async implementation of '_get_cooldown_deployments'
"""
@ -3078,7 +3136,7 @@ class Router:
verbose_router_logger.debug(f"retrieve cooldown models: {cooldown_models}")
return cooldown_models
def _get_cooldown_deployments(self):
def _get_cooldown_deployments(self) -> List[str]:
"""
Get the list of models being cooled down for this minute
"""
@ -3092,8 +3150,17 @@ class Router:
# ----------------------
cooldown_models = self.cache.get_cache(key=cooldown_key) or []
cached_value_deployment_ids = []
if (
cooldown_models is not None
and isinstance(cooldown_models, list)
and len(cooldown_models) > 0
and isinstance(cooldown_models[0], tuple)
):
cached_value_deployment_ids = [cv[0] for cv in cooldown_models]
verbose_router_logger.debug(f"retrieve cooldown models: {cooldown_models}")
return cooldown_models
return cached_value_deployment_ids
def _get_healthy_deployments(self, model: str):
_all_deployments: list = []
@ -3970,16 +4037,36 @@ class Router:
Augment litellm info with additional params set in `model_info`.
For azure models, ignore the `model:`. Only set max tokens, cost values if base_model is set.
Returns
- ModelInfo - If found -> typed dict with max tokens, input cost, etc.
Raises:
- ValueError -> If model is not mapped yet
"""
## SET MODEL NAME
## GET BASE MODEL
base_model = deployment.get("model_info", {}).get("base_model", None)
if base_model is None:
base_model = deployment.get("litellm_params", {}).get("base_model", None)
model = base_model or deployment.get("litellm_params", {}).get("model", None)
## GET LITELLM MODEL INFO
model = base_model
## GET PROVIDER
_model, custom_llm_provider, _, _ = litellm.get_llm_provider(
model=deployment.get("litellm_params", {}).get("model", ""),
litellm_params=LiteLLM_Params(**deployment.get("litellm_params", {})),
)
## SET MODEL TO 'model=' - if base_model is None + not azure
if custom_llm_provider == "azure" and base_model is None:
verbose_router_logger.error(
"Could not identify azure model. Set azure 'base_model' for accurate max tokens, cost tracking, etc.- https://docs.litellm.ai/docs/proxy/cost_tracking#spend-tracking-for-azure-openai-models"
)
elif custom_llm_provider != "azure":
model = _model
## GET LITELLM MODEL INFO - raises exception, if model is not mapped
model_info = litellm.get_model_info(model=model)
## CHECK USER SET MODEL INFO
@ -4365,7 +4452,7 @@ class Router:
"""
Filter out model in model group, if:
- model context window < message length
- model context window < message length. For azure openai models, requires 'base_model' is set. - https://docs.litellm.ai/docs/proxy/cost_tracking#spend-tracking-for-azure-openai-models
- filter models above rpm limits
- if region given, filter out models not in that region / unknown region
- [TODO] function call and model doesn't support function calling
@ -4382,6 +4469,11 @@ class Router:
try:
input_tokens = litellm.token_counter(messages=messages)
except Exception as e:
verbose_router_logger.error(
"litellm.router.py::_pre_call_checks: failed to count tokens. Returning initial list of deployments. Got - {}".format(
str(e)
)
)
return _returned_deployments
_context_window_error = False
@ -4425,7 +4517,7 @@ class Router:
)
continue
except Exception as e:
verbose_router_logger.debug("An error occurs - {}".format(str(e)))
verbose_router_logger.error("An error occurs - {}".format(str(e)))
_litellm_params = deployment.get("litellm_params", {})
model_id = deployment.get("model_info", {}).get("id", "")
@ -4686,7 +4778,7 @@ class Router:
if _allowed_model_region is None:
_allowed_model_region = "n/a"
raise ValueError(
f"{RouterErrors.no_deployments_available.value}, Try again in {self.cooldown_time} seconds. Passed model={model}. pre-call-checks={self.enable_pre_call_checks}, allowed_model_region={_allowed_model_region}"
f"{RouterErrors.no_deployments_available.value}, Try again in {self.cooldown_time} seconds. Passed model={model}. pre-call-checks={self.enable_pre_call_checks}, allowed_model_region={_allowed_model_region}, cooldown_list={await self._async_get_cooldown_deployments_with_debug_info()}"
)
if (


@ -880,6 +880,208 @@ Using this JSON schema:
mock_call.assert_called_once()
def vertex_httpx_mock_post_valid_response(*args, **kwargs):
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.headers = {"Content-Type": "application/json"}
mock_response.json.return_value = {
"candidates": [
{
"content": {
"role": "model",
"parts": [
{
"text": '[{"recipe_name": "Chocolate Chip Cookies"}, {"recipe_name": "Oatmeal Raisin Cookies"}, {"recipe_name": "Peanut Butter Cookies"}, {"recipe_name": "Sugar Cookies"}, {"recipe_name": "Snickerdoodles"}]\n'
}
],
},
"finishReason": "STOP",
"safetyRatings": [
{
"category": "HARM_CATEGORY_HATE_SPEECH",
"probability": "NEGLIGIBLE",
"probabilityScore": 0.09790669,
"severity": "HARM_SEVERITY_NEGLIGIBLE",
"severityScore": 0.11736965,
},
{
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
"probability": "NEGLIGIBLE",
"probabilityScore": 0.1261379,
"severity": "HARM_SEVERITY_NEGLIGIBLE",
"severityScore": 0.08601588,
},
{
"category": "HARM_CATEGORY_HARASSMENT",
"probability": "NEGLIGIBLE",
"probabilityScore": 0.083441176,
"severity": "HARM_SEVERITY_NEGLIGIBLE",
"severityScore": 0.0355444,
},
{
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
"probability": "NEGLIGIBLE",
"probabilityScore": 0.071981624,
"severity": "HARM_SEVERITY_NEGLIGIBLE",
"severityScore": 0.08108212,
},
],
}
],
"usageMetadata": {
"promptTokenCount": 60,
"candidatesTokenCount": 55,
"totalTokenCount": 115,
},
}
return mock_response
def vertex_httpx_mock_post_invalid_schema_response(*args, **kwargs):
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.headers = {"Content-Type": "application/json"}
mock_response.json.return_value = {
"candidates": [
{
"content": {
"role": "model",
"parts": [
{"text": '[{"recipe_world": "Chocolate Chip Cookies"}]\n'}
],
},
"finishReason": "STOP",
"safetyRatings": [
{
"category": "HARM_CATEGORY_HATE_SPEECH",
"probability": "NEGLIGIBLE",
"probabilityScore": 0.09790669,
"severity": "HARM_SEVERITY_NEGLIGIBLE",
"severityScore": 0.11736965,
},
{
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
"probability": "NEGLIGIBLE",
"probabilityScore": 0.1261379,
"severity": "HARM_SEVERITY_NEGLIGIBLE",
"severityScore": 0.08601588,
},
{
"category": "HARM_CATEGORY_HARASSMENT",
"probability": "NEGLIGIBLE",
"probabilityScore": 0.083441176,
"severity": "HARM_SEVERITY_NEGLIGIBLE",
"severityScore": 0.0355444,
},
{
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
"probability": "NEGLIGIBLE",
"probabilityScore": 0.071981624,
"severity": "HARM_SEVERITY_NEGLIGIBLE",
"severityScore": 0.08108212,
},
],
}
],
"usageMetadata": {
"promptTokenCount": 60,
"candidatesTokenCount": 55,
"totalTokenCount": 115,
},
}
return mock_response
@pytest.mark.parametrize(
"model, vertex_location, supports_response_schema",
[
("vertex_ai_beta/gemini-1.5-pro-001", "us-central1", True),
("vertex_ai_beta/gemini-1.5-flash", "us-central1", False),
],
)
@pytest.mark.parametrize(
"invalid_response",
[True, False],
)
@pytest.mark.parametrize(
"enforce_validation",
[True, False],
)
@pytest.mark.asyncio
async def test_gemini_pro_json_schema_args_sent_httpx(
model,
supports_response_schema,
vertex_location,
invalid_response,
enforce_validation,
):
load_vertex_ai_credentials()
os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
litellm.model_cost = litellm.get_model_cost_map(url="")
litellm.set_verbose = True
messages = [{"role": "user", "content": "List 5 cookie recipes"}]
from litellm.llms.custom_httpx.http_handler import HTTPHandler
response_schema = {
"type": "array",
"items": {
"type": "object",
"properties": {
"recipe_name": {
"type": "string",
},
},
"required": ["recipe_name"],
},
}
client = HTTPHandler()
httpx_response = MagicMock()
if invalid_response is True:
httpx_response.side_effect = vertex_httpx_mock_post_invalid_schema_response
else:
httpx_response.side_effect = vertex_httpx_mock_post_valid_response
with patch.object(client, "post", new=httpx_response) as mock_call:
try:
_ = completion(
model=model,
messages=messages,
response_format={
"type": "json_object",
"response_schema": response_schema,
"enforce_validation": enforce_validation,
},
vertex_location=vertex_location,
client=client,
)
if invalid_response is True and enforce_validation is True:
pytest.fail("Expected this to fail")
except litellm.JSONSchemaValidationError as e:
if invalid_response is False and "claude-3" not in model:
pytest.fail("Expected this to pass. Got={}".format(e))
mock_call.assert_called_once()
print(mock_call.call_args.kwargs)
print(mock_call.call_args.kwargs["json"]["generationConfig"])
if supports_response_schema:
assert (
"response_schema"
in mock_call.call_args.kwargs["json"]["generationConfig"]
)
else:
assert (
"response_schema"
not in mock_call.call_args.kwargs["json"]["generationConfig"]
)
assert (
"Use this JSON schema:"
in mock_call.call_args.kwargs["json"]["contents"][0]["parts"][1]["text"]
)
@pytest.mark.parametrize("provider", ["vertex_ai_beta"]) # "vertex_ai",
@pytest.mark.asyncio
async def test_gemini_pro_httpx_custom_api_base(provider):
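The schema test above exercises the new `response_schema` + `enforce_validation` path end to end; for reference, a minimal user-level sketch of the same feature (the model name and schema are illustrative):

import litellm

response_schema = {
    "type": "array",
    "items": {
        "type": "object",
        "properties": {"recipe_name": {"type": "string"}},
        "required": ["recipe_name"],
    },
}

try:
    resp = litellm.completion(
        model="vertex_ai_beta/gemini-1.5-pro-001",  # illustrative
        messages=[{"role": "user", "content": "List 5 cookie recipes"}],
        response_format={
            "type": "json_object",
            "response_schema": response_schema,
            "enforce_validation": True,  # raise if the output fails the schema
        },
    )
    print(resp.choices[0].message.content)
except litellm.JSONSchemaValidationError as err:
    print("model output did not match the schema:", err)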


@ -25,6 +25,7 @@ from litellm import (
completion_cost,
embedding,
)
from litellm.llms.bedrock_httpx import BedrockLLM
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
# litellm.num_retries = 3
@ -217,6 +218,234 @@ def test_completion_bedrock_claude_sts_client_auth():
pytest.fail(f"Error occurred: {e}")
@pytest.fixture()
def bedrock_session_token_creds():
print("\ncalling oidc auto to get aws_session_token credentials")
import os
aws_region_name = os.environ["AWS_REGION_NAME"]
aws_session_token = os.environ.get("AWS_SESSION_TOKEN")
bllm = BedrockLLM()
if aws_session_token is not None:
# For local testing
creds = bllm.get_credentials(
aws_region_name=aws_region_name,
aws_access_key_id=os.environ["AWS_ACCESS_KEY_ID"],
aws_secret_access_key=os.environ["AWS_SECRET_ACCESS_KEY"],
aws_session_token=aws_session_token,
)
else:
# For circle-ci testing
# aws_role_name = os.environ["AWS_TEMP_ROLE_NAME"]
# TODO: This is using ai.moda's IAM role, we should use LiteLLM's IAM role eventually
aws_role_name = (
"arn:aws:iam::335785316107:role/litellm-github-unit-tests-circleci"
)
aws_web_identity_token = "oidc/circleci_v2/"
creds = bllm.get_credentials(
aws_region_name=aws_region_name,
aws_web_identity_token=aws_web_identity_token,
aws_role_name=aws_role_name,
aws_session_name="my-test-session",
)
return creds
def process_stream_response(res, messages):
import types
if isinstance(res, litellm.utils.CustomStreamWrapper):
chunks = []
for part in res:
chunks.append(part)
text = part.choices[0].delta.content or ""
print(text, end="")
res = litellm.stream_chunk_builder(chunks, messages=messages)
else:
raise ValueError("Response object is not a streaming response")
return res
@pytest.mark.skipif(
os.environ.get("CIRCLE_OIDC_TOKEN_V2") is None,
reason="Cannot run without being in CircleCI Runner",
)
def test_completion_bedrock_claude_aws_session_token(bedrock_session_token_creds):
print("\ncalling bedrock claude with aws_session_token auth")
import os
aws_region_name = os.environ["AWS_REGION_NAME"]
aws_access_key_id = bedrock_session_token_creds.access_key
aws_secret_access_key = bedrock_session_token_creds.secret_key
aws_session_token = bedrock_session_token_creds.token
try:
litellm.set_verbose = True
response_1 = completion(
model="bedrock/anthropic.claude-3-haiku-20240307-v1:0",
messages=messages,
max_tokens=10,
temperature=0.1,
aws_region_name=aws_region_name,
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
aws_session_token=aws_session_token,
)
print(response_1)
assert len(response_1.choices) > 0
assert len(response_1.choices[0].message.content) > 0
# This second call is to verify that the cache isn't breaking anything
response_2 = completion(
model="bedrock/anthropic.claude-3-haiku-20240307-v1:0",
messages=messages,
max_tokens=5,
temperature=0.2,
aws_region_name=aws_region_name,
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
aws_session_token=aws_session_token,
)
print(response_2)
assert len(response_2.choices) > 0
assert len(response_2.choices[0].message.content) > 0
# This third call is to verify that the cache isn't used for a different region
response_3 = completion(
model="bedrock/anthropic.claude-3-haiku-20240307-v1:0",
messages=messages,
max_tokens=6,
temperature=0.3,
aws_region_name="us-east-1",
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
aws_session_token=aws_session_token,
)
print(response_3)
assert len(response_3.choices) > 0
assert len(response_3.choices[0].message.content) > 0
# This fourth call is to verify streaming api works
response_4 = completion(
model="bedrock/anthropic.claude-3-haiku-20240307-v1:0",
messages=messages,
max_tokens=6,
temperature=0.3,
aws_region_name="us-east-1",
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
aws_session_token=aws_session_token,
stream=True,
)
response_4 = process_stream_response(response_4, messages)
print(response_4)
assert len(response_4.choices) > 0
assert len(response_4.choices[0].message.content) > 0
except RateLimitError:
pass
except Exception as e:
pytest.fail(f"Error occurred: {e}")
@pytest.mark.skipif(
os.environ.get("CIRCLE_OIDC_TOKEN_V2") is None,
reason="Cannot run without being in CircleCI Runner",
)
def test_completion_bedrock_claude_aws_bedrock_client(bedrock_session_token_creds):
print("\ncalling bedrock claude with aws_session_token auth")
import os
import boto3
from botocore.client import Config
aws_region_name = os.environ["AWS_REGION_NAME"]
aws_access_key_id = bedrock_session_token_creds.access_key
aws_secret_access_key = bedrock_session_token_creds.secret_key
aws_session_token = bedrock_session_token_creds.token
aws_bedrock_client_west = boto3.client(
service_name="bedrock-runtime",
region_name=aws_region_name,
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
aws_session_token=aws_session_token,
config=Config(read_timeout=600),
)
try:
litellm.set_verbose = True
response_1 = completion(
model="bedrock/anthropic.claude-3-haiku-20240307-v1:0",
messages=messages,
max_tokens=10,
temperature=0.1,
aws_bedrock_client=aws_bedrock_client_west,
)
print(response_1)
assert len(response_1.choices) > 0
assert len(response_1.choices[0].message.content) > 0
# This second call is to verify that the cache isn't breaking anything
response_2 = completion(
model="bedrock/anthropic.claude-3-haiku-20240307-v1:0",
messages=messages,
max_tokens=5,
temperature=0.2,
aws_bedrock_client=aws_bedrock_client_west,
)
print(response_2)
assert len(response_2.choices) > 0
assert len(response_2.choices[0].message.content) > 0
# This third call is to verify that the cache isn't used for a different region
aws_bedrock_client_east = boto3.client(
service_name="bedrock-runtime",
region_name="us-east-1",
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
aws_session_token=aws_session_token,
config=Config(read_timeout=600),
)
response_3 = completion(
model="bedrock/anthropic.claude-3-haiku-20240307-v1:0",
messages=messages,
max_tokens=6,
temperature=0.3,
aws_bedrock_client=aws_bedrock_client_east,
)
print(response_3)
assert len(response_3.choices) > 0
assert len(response_3.choices[0].message.content) > 0
# This fourth call is to verify streaming api works
response_4 = completion(
model="bedrock/anthropic.claude-3-haiku-20240307-v1:0",
messages=messages,
max_tokens=6,
temperature=0.3,
aws_bedrock_client=aws_bedrock_client_east,
stream=True,
)
response_4 = process_stream_response(response_4, messages)
print(response_4)
assert len(response_4.choices) > 0
assert len(response_4.choices[0].message.content) > 0
except RateLimitError:
pass
except Exception as e:
pytest.fail(f"Error occurred: {e}")
# test_completion_bedrock_claude_sts_client_auth()
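The two session-token tests above map to a simple user-level pattern: fetch temporary credentials (for example via STS assume-role) and pass them straight to `completion`. A sketch, with the role ARN, session name, and region as placeholders:

import boto3

import litellm

# Temporary credentials from STS (placeholder role / session name).
sts = boto3.client("sts")
creds = sts.assume_role(
    RoleArn="arn:aws:iam::123456789012:role/my-bedrock-role",
    RoleSessionName="litellm-session",
)["Credentials"]

response = litellm.completion(
    model="bedrock/anthropic.claude-3-haiku-20240307-v1:0",
    messages=[{"role": "user", "content": "Hello, how are you?"}],
    aws_region_name="us-west-2",  # placeholder region
    aws_access_key_id=creds["AccessKeyId"],
    aws_secret_access_key=creds["SecretAccessKey"],
    aws_session_token=creds["SessionToken"],
)
print(response.choices[0].message.content)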
@ -489,61 +718,6 @@ def test_completion_claude_3_base64():
pytest.fail(f"An exception occurred - {str(e)}")
def test_provisioned_throughput():
try:
litellm.set_verbose = True
import io
import json
import botocore
import botocore.session
from botocore.stub import Stubber
bedrock_client = botocore.session.get_session().create_client(
"bedrock-runtime", region_name="us-east-1"
)
expected_params = {
"accept": "application/json",
"body": '{"prompt": "\\n\\nHuman: Hello, how are you?\\n\\nAssistant: ", '
'"max_tokens_to_sample": 256}',
"contentType": "application/json",
"modelId": "provisioned-model-arn",
}
response_from_bedrock = {
"body": io.StringIO(
json.dumps(
{
"completion": " Here is a short poem about the sky:",
"stop_reason": "max_tokens",
"stop": None,
}
)
),
"contentType": "contentType",
"ResponseMetadata": {"HTTPStatusCode": 200},
}
with Stubber(bedrock_client) as stubber:
stubber.add_response(
"invoke_model",
service_response=response_from_bedrock,
expected_params=expected_params,
)
response = litellm.completion(
model="bedrock/anthropic.claude-instant-v1",
model_id="provisioned-model-arn",
messages=[{"content": "Hello, how are you?", "role": "user"}],
aws_bedrock_client=bedrock_client,
)
print("response stubbed", response)
except Exception as e:
pytest.fail(f"Error occurred: {e}")
# test_provisioned_throughput()
def test_completion_bedrock_mistral_completion_auth():
print("calling bedrock mistral completion params auth")
import os
@ -682,3 +856,56 @@ async def test_bedrock_custom_prompt_template():
prompt = json.loads(mock_client_post.call_args.kwargs["data"])["prompt"]
assert prompt == "<|im_start|>user\nWhat's AWS?<|im_end|>"
mock_client_post.assert_called_once()
def test_completion_bedrock_external_client_region():
print("\ncalling bedrock claude external client auth")
import os
aws_access_key_id = os.environ["AWS_ACCESS_KEY_ID"]
aws_secret_access_key = os.environ["AWS_SECRET_ACCESS_KEY"]
aws_region_name = "us-east-1"
os.environ.pop("AWS_ACCESS_KEY_ID", None)
os.environ.pop("AWS_SECRET_ACCESS_KEY", None)
client = HTTPHandler()
try:
import boto3
litellm.set_verbose = True
bedrock = boto3.client(
service_name="bedrock-runtime",
region_name=aws_region_name,
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
endpoint_url=f"https://bedrock-runtime.{aws_region_name}.amazonaws.com",
)
with patch.object(client, "post", new=Mock()) as mock_client_post:
try:
response = completion(
model="bedrock/anthropic.claude-instant-v1",
messages=messages,
max_tokens=10,
temperature=0.1,
aws_bedrock_client=bedrock,
client=client,
)
# Add any assertions here to check the response
print(response)
except Exception as e:
pass
print(f"mock_client_post.call_args: {mock_client_post.call_args}")
assert "us-east-1" in mock_client_post.call_args.kwargs["url"]
mock_client_post.assert_called_once()
os.environ["AWS_ACCESS_KEY_ID"] = aws_access_key_id
os.environ["AWS_SECRET_ACCESS_KEY"] = aws_secret_access_key
except RateLimitError:
pass
except Exception as e:
pytest.fail(f"Error occurred: {e}")


@ -23,7 +23,7 @@ from litellm import RateLimitError, Timeout, completion, completion_cost, embedd
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.llms.prompt_templates.factory import anthropic_messages_pt
# litellm.num_retries = 3
# litellm.num_retries=3
litellm.cache = None
litellm.success_callback = []
user_message = "Write a short poem about the sky"


@ -249,6 +249,25 @@ def test_completion_azure_exception():
# test_completion_azure_exception()
def test_azure_embedding_exceptions():
try:
response = litellm.embedding(
model="azure/azure-embedding-model",
input="hello",
messages="hello",
)
pytest.fail(f"Bad request this should have failed but got {response}")
except Exception as e:
print(vars(e))
# CRUCIAL Test - Ensures our exceptions are readable and not overly complicated. Some users have complained that exceptions will randomly have another exception raised in our exception mapping
assert (
e.message
== "litellm.APIError: AzureException APIError - Embeddings.create() got an unexpected keyword argument 'messages'"
)
async def asynctest_completion_azure_exception():
try:
import openai


@ -61,7 +61,6 @@ async def test_token_single_public_key():
import jwt
jwt_handler = JWTHandler()
backend_keys = {
"keys": [
{


@ -1,10 +1,16 @@
# What is this?
## This tests the Lakera AI integration
import sys, os, asyncio, time, random
from datetime import datetime
import asyncio
import os
import random
import sys
import time
import traceback
from datetime import datetime
from dotenv import load_dotenv
from fastapi import HTTPException
load_dotenv()
import os
@ -12,17 +18,19 @@ import os
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import logging
import pytest
import litellm
from litellm import Router, mock_completion
from litellm._logging import verbose_proxy_logger
from litellm.caching import DualCache
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.enterprise.enterprise_hooks.lakera_ai import (
_ENTERPRISE_lakeraAI_Moderation,
)
from litellm import Router, mock_completion
from litellm.proxy.utils import ProxyLogging, hash_token
from litellm.proxy._types import UserAPIKeyAuth
from litellm.caching import DualCache
from litellm._logging import verbose_proxy_logger
import logging
verbose_proxy_logger.setLevel(logging.DEBUG)
@ -55,10 +63,12 @@ async def test_lakera_prompt_injection_detection():
call_type="completion",
)
pytest.fail(f"Should have failed")
except Exception as e:
print("Got exception: ", e)
assert "Violated content safety policy" in str(e)
pass
except HTTPException as http_exception:
print("http exception details=", http_exception.detail)
# Assert that the lakera ai response is in the exception raised
assert "lakera_ai_response" in http_exception.detail
assert "Violated content safety policy" in str(http_exception)
@pytest.mark.asyncio


@ -0,0 +1,85 @@
import os
import sys
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import asyncio
import httpx
from litellm.proxy.proxy_server import app, initialize_pass_through_endpoints
# Mock the async_client used in the pass_through_request function
async def mock_request(*args, **kwargs):
return httpx.Response(200, json={"message": "Mocked response"})
@pytest.fixture
def client():
return TestClient(app)
@pytest.mark.asyncio
async def test_pass_through_endpoint(client, monkeypatch):
# Mock the httpx.AsyncClient.request method
monkeypatch.setattr("httpx.AsyncClient.request", mock_request)
# Define a pass-through endpoint
pass_through_endpoints = [
{
"path": "/test-endpoint",
"target": "https://api.example.com/v1/chat/completions",
"headers": {"Authorization": "Bearer test-token"},
}
]
# Initialize the pass-through endpoint
await initialize_pass_through_endpoints(pass_through_endpoints)
# Make a request to the pass-through endpoint
response = client.post("/test-endpoint", json={"prompt": "Hello, world!"})
# Assert the response
assert response.status_code == 200
assert response.json() == {"message": "Mocked response"}
@pytest.mark.asyncio
async def test_pass_through_endpoint_rerank(client):
_cohere_api_key = os.environ.get("COHERE_API_KEY")
# Define a pass-through endpoint
pass_through_endpoints = [
{
"path": "/v1/rerank",
"target": "https://api.cohere.com/v1/rerank",
"headers": {"Authorization": f"bearer {_cohere_api_key}"},
}
]
# Initialize the pass-through endpoint
await initialize_pass_through_endpoints(pass_through_endpoints)
_json_data = {
"model": "rerank-english-v3.0",
"query": "What is the capital of the United States?",
"top_n": 3,
"documents": [
"Carson City is the capital city of the American state of Nevada."
],
}
# Make a request to the pass-through endpoint
response = client.post("/v1/rerank", json=_json_data)
print("JSON response: ", _json_data)
# Assert the response
assert response.status_code == 200


@ -1,25 +1,31 @@
# test that the proxy actually does exception mapping to the OpenAI format
import sys, os
from unittest import mock
import json
import os
import sys
from unittest import mock
from dotenv import load_dotenv
load_dotenv()
import os, io, asyncio
import asyncio
import io
import os
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import openai
import pytest
import litellm, openai
from fastapi.testclient import TestClient
from fastapi import Response
from litellm.proxy.proxy_server import (
from fastapi.testclient import TestClient
import litellm
from litellm.proxy.proxy_server import ( # Replace with the actual module where your FastAPI router is defined
initialize,
router,
save_worker_config,
initialize,
) # Replace with the actual module where your FastAPI router is defined
)
invalid_authentication_error_response = Response(
status_code=401,
@ -66,6 +72,12 @@ def test_chat_completion_exception(client):
json_response = response.json()
print("keys in json response", json_response.keys())
assert json_response.keys() == {"error"}
print("ERROR=", json_response["error"])
assert isinstance(json_response["error"]["message"], str)
assert (
json_response["error"]["message"]
== "litellm.AuthenticationError: AuthenticationError: OpenAIException - Incorrect API key provided: bad-key. You can find your API key at https://platform.openai.com/account/api-keys."
)
# make an openai client to call _make_status_error_from_response
openai_client = openai.OpenAI(api_key="anything")


@ -16,6 +16,7 @@ sys.path.insert(
import os
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from unittest.mock import AsyncMock, MagicMock, patch
import httpx
from dotenv import load_dotenv
@ -811,6 +812,7 @@ def test_router_context_window_check_pre_call_check():
"base_model": "azure/gpt-35-turbo",
"mock_response": "Hello world 1!",
},
"model_info": {"base_model": "azure/gpt-35-turbo"},
},
{
"model_name": "gpt-3.5-turbo", # openai model name
@ -1884,3 +1886,106 @@ async def test_router_model_usage(mock_response):
else:
print(f"allowed_fails: {allowed_fails}")
raise e
@pytest.mark.parametrize(
"model, base_model, llm_provider",
[
("azure/gpt-4", None, "azure"),
("azure/gpt-4", "azure/gpt-4-0125-preview", "azure"),
("gpt-4", None, "openai"),
],
)
def test_router_get_model_info(model, base_model, llm_provider):
"""
Test if router get model info works based on provider
For azure -> only if base model set
For openai -> use model=
"""
router = Router(
model_list=[
{
"model_name": "gpt-4",
"litellm_params": {
"model": model,
"api_key": "my-fake-key",
"api_base": "my-fake-base",
},
"model_info": {"base_model": base_model, "id": "1"},
}
]
)
deployment = router.get_deployment(model_id="1")
assert deployment is not None
if llm_provider == "openai" or (base_model is not None and llm_provider == "azure"):
router.get_router_model_info(deployment=deployment.to_json())
else:
try:
router.get_router_model_info(deployment=deployment.to_json())
pytest.fail("Expected this to raise model not mapped error")
except Exception as e:
if "This model isn't mapped yet" in str(e):
pass
@pytest.mark.parametrize(
"model, base_model, llm_provider",
[
("azure/gpt-4", None, "azure"),
("azure/gpt-4", "azure/gpt-4-0125-preview", "azure"),
("gpt-4", None, "openai"),
],
)
def test_router_context_window_pre_call_check(model, base_model, llm_provider):
"""
- For an azure model
- if no base model set
- don't enforce context window limits
"""
try:
model_list = [
{
"model_name": "gpt-4",
"litellm_params": {
"model": model,
"api_key": "my-fake-key",
"api_base": "my-fake-base",
},
"model_info": {"base_model": base_model, "id": "1"},
}
]
router = Router(
model_list=model_list,
set_verbose=True,
enable_pre_call_checks=True,
num_retries=0,
)
litellm.token_counter = MagicMock()
def token_counter_side_effect(*args, **kwargs):
# Process args and kwargs if needed
return 1000000
litellm.token_counter.side_effect = token_counter_side_effect
try:
updated_list = router._pre_call_checks(
model="gpt-4",
healthy_deployments=model_list,
messages=[{"role": "user", "content": "Hey, how's it going?"}],
)
if llm_provider == "azure" and base_model is None:
assert len(updated_list) == 1
else:
pytest.fail("Expected to raise an error. Got={}".format(updated_list))
except Exception as e:
if (
llm_provider == "azure" and base_model is not None
) or llm_provider == "openai":
pass
except Exception as e:
pytest.fail(f"Got unexpected exception on router! - {str(e)}")


@ -1,16 +1,23 @@
import sys, os, time
import traceback, asyncio
import asyncio
import os
import sys
import time
import traceback
import pytest
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import litellm, asyncio, logging
import asyncio
import logging
import litellm
from litellm import Router
# this tests debug logs from litellm router and litellm proxy server
from litellm._logging import verbose_router_logger, verbose_logger, verbose_proxy_logger
from litellm._logging import verbose_logger, verbose_proxy_logger, verbose_router_logger
# this tests debug logs from litellm router and litellm proxy server
@ -81,7 +88,7 @@ def test_async_fallbacks(caplog):
# Define the expected log messages
# - error request, falling back notice, success notice
expected_logs = [
"litellm.acompletion(model=gpt-3.5-turbo)\x1b[31m Exception litellm.AuthenticationError: AuthenticationError: OpenAIException - Error code: 401 - {'error': {'message': 'Incorrect API key provided: bad-key. You can find your API key at https://platform.openai.com/account/api-keys.', 'type': 'invalid_request_error', 'param': None, 'code': 'invalid_api_key'}}\x1b[0m",
"litellm.acompletion(model=gpt-3.5-turbo)\x1b[31m Exception litellm.AuthenticationError: AuthenticationError: OpenAIException - Incorrect API key provided: bad-key. You can find your API key at https://platform.openai.com/account/api-keys.\x1b[0m",
"Falling back to model_group = azure/gpt-3.5-turbo",
"litellm.acompletion(model=azure/chatgpt-v-2)\x1b[32m 200 OK\x1b[0m",
"Successful fallback b/w models.",


@ -742,7 +742,10 @@ def test_completion_palm_stream():
# test_completion_palm_stream()
@pytest.mark.parametrize("sync_mode", [False]) # True,
@pytest.mark.parametrize(
"sync_mode",
[True, False],
) # ,
@pytest.mark.asyncio
async def test_completion_gemini_stream(sync_mode):
try:
@ -807,49 +810,6 @@ async def test_completion_gemini_stream(sync_mode):
pytest.fail(f"Error occurred: {e}")
@pytest.mark.asyncio
async def test_acompletion_gemini_stream():
try:
litellm.set_verbose = True
print("Streaming gemini response")
messages = [
# {"role": "system", "content": "You are a helpful assistant."},
{
"role": "user",
"content": "What do you know?",
},
]
print("testing gemini streaming")
response = await acompletion(
model="gemini/gemini-pro", messages=messages, max_tokens=50, stream=True
)
print(f"type of response at the top: {response}")
complete_response = ""
idx = 0
# Add any assertions here to check, the response
async for chunk in response:
print(f"chunk in acompletion gemini: {chunk}")
print(chunk.choices[0].delta)
chunk, finished = streaming_format_tests(idx, chunk)
if finished:
break
print(f"chunk: {chunk}")
complete_response += chunk
idx += 1
print(f"completion_response: {complete_response}")
if complete_response.strip() == "":
raise Exception("Empty response received")
except litellm.APIError as e:
pass
except litellm.RateLimitError as e:
pass
except Exception as e:
if "429 Resource has been exhausted" in str(e):
pass
else:
pytest.fail(f"Error occurred: {e}")
# asyncio.run(test_acompletion_gemini_stream())
@ -1071,7 +1031,7 @@ def test_completion_claude_stream_bad_key():
# test_completion_replicate_stream()
@pytest.mark.parametrize("provider", ["vertex_ai"]) # "vertex_ai_beta"
@pytest.mark.parametrize("provider", ["vertex_ai_beta"]) # ""
def test_vertex_ai_stream(provider):
from litellm.tests.test_amazing_vertex_completion import load_vertex_ai_credentials
@ -1080,14 +1040,27 @@ def test_vertex_ai_stream(provider):
litellm.vertex_project = "adroit-crow-413218"
import random
test_models = ["gemini-1.0-pro"]
test_models = ["gemini-1.5-pro"]
for model in test_models:
try:
print("making request", model)
response = completion(
model="{}/{}".format(provider, model),
messages=[
{"role": "user", "content": "write 10 line code code for saying hi"}
{"role": "user", "content": "Hey, how's it going?"},
{
"role": "assistant",
"content": "I'm doing well. Would like to hear the rest of the story?",
},
{"role": "user", "content": "Na"},
{
"role": "assistant",
"content": "No problem, is there anything else i can help you with today?",
},
{
"role": "user",
"content": "I think you're getting cut off sometimes",
},
],
stream=True,
)
@ -1104,6 +1077,8 @@ def test_vertex_ai_stream(provider):
raise Exception("Empty response received")
print(f"completion_response: {complete_response}")
assert is_finished == True
assert False
except litellm.RateLimitError as e:
pass
except Exception as e:
@ -1251,6 +1226,7 @@ async def test_completion_replicate_llama3_streaming(sync_mode):
messages=messages,
max_tokens=10, # type: ignore
stream=True,
num_retries=3,
)
complete_response = ""
# Add any assertions here to check the response
@ -1272,6 +1248,7 @@ async def test_completion_replicate_llama3_streaming(sync_mode):
messages=messages,
max_tokens=100, # type: ignore
stream=True,
num_retries=3,
)
complete_response = ""
# Add any assertions here to check the response
@ -1290,6 +1267,8 @@ async def test_completion_replicate_llama3_streaming(sync_mode):
raise Exception("finish reason not set")
if complete_response.strip() == "":
raise Exception("Empty response received")
except litellm.UnprocessableEntityError as e:
pass
except Exception as e:
pytest.fail(f"Error occurred: {e}")


@ -609,3 +609,83 @@ def test_logging_trace_id(langfuse_trace_id, langfuse_existing_trace_id):
litellm_logging_obj._get_trace_id(service_name="langfuse")
== litellm_call_id
)
def test_convert_model_response_object():
"""
Unit test to ensure model response object correctly handles openrouter errors.
"""
args = {
"response_object": {
"id": None,
"choices": None,
"created": None,
"model": None,
"object": None,
"service_tier": None,
"system_fingerprint": None,
"usage": None,
"error": {
"message": '{"type":"error","error":{"type":"invalid_request_error","message":"Output blocked by content filtering policy"}}',
"code": 400,
},
},
"model_response_object": litellm.ModelResponse(
id="chatcmpl-b88ce43a-7bfc-437c-b8cc-e90d59372cfb",
choices=[
litellm.Choices(
finish_reason="stop",
index=0,
message=litellm.Message(content="default", role="assistant"),
)
],
created=1719376241,
model="openrouter/anthropic/claude-3.5-sonnet",
object="chat.completion",
system_fingerprint=None,
usage=litellm.Usage(),
),
"response_type": "completion",
"stream": False,
"start_time": None,
"end_time": None,
"hidden_params": None,
}
try:
litellm.convert_to_model_response_object(**args)
pytest.fail("Expected this to fail")
except Exception as e:
assert hasattr(e, "status_code")
assert e.status_code == 400
assert hasattr(e, "message")
assert (
e.message
== '{"type":"error","error":{"type":"invalid_request_error","message":"Output blocked by content filtering policy"}}'
)
@pytest.mark.parametrize(
"model, expected_bool",
[
("vertex_ai/gemini-1.5-pro", True),
("gemini/gemini-1.5-pro", True),
("predibase/llama3-8b-instruct", True),
("gpt-4o", False),
],
)
def test_supports_response_schema(model, expected_bool):
"""
Unit tests for 'supports_response_schema' helper function.
Should be true for gemini-1.5-pro on google ai studio / vertex ai AND predibase models
Should be false otherwise
"""
os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
litellm.model_cost = litellm.get_model_cost_map(url="")
from litellm.utils import supports_response_schema
response = supports_response_schema(model=model, custom_llm_provider=None)
assert expected_bool == response
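From the caller's side, the openrouter handling covered by `test_convert_model_response_object` means a content-filter rejection now surfaces as a mapped exception instead of an empty `ModelResponse`. A hedged sketch of what handling that looks like (the model name is illustrative; the exact exception class depends on the mapped status code):

import litellm

try:
    resp = litellm.completion(
        model="openrouter/anthropic/claude-3.5-sonnet",  # illustrative
        messages=[{"role": "user", "content": "hi"}],
    )
except Exception as err:
    # The raised exception carries the provider's status code and raw message.
    print(getattr(err, "status_code", None), getattr(err, "message", str(err)))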


@ -71,6 +71,7 @@ class ModelInfo(TypedDict, total=False):
]
supported_openai_params: Required[Optional[List[str]]]
supports_system_messages: Optional[bool]
supports_response_schema: Optional[bool]
class GenericStreamingChunk(TypedDict):
@ -994,3 +995,8 @@ class GenericImageParsingChunk(TypedDict):
type: str
media_type: str
data: str
class ResponseFormatChunk(TypedDict, total=False):
type: Required[Literal["json_object", "text"]]
response_schema: dict


@ -48,8 +48,10 @@ from tokenizers import Tokenizer
import litellm
import litellm._service_logger # for storing API inputs, outputs, and metadata
import litellm.litellm_core_utils
import litellm.litellm_core_utils.json_validation_rule
from litellm.caching import DualCache
from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.litellm_core_utils.exception_mapping_utils import get_error_message
from litellm.litellm_core_utils.llm_request_utils import _ensure_extra_body_is_safe
from litellm.litellm_core_utils.redact_messages import (
redact_message_input_output_from_logging,
@ -579,7 +581,7 @@ def client(original_function):
else:
return False
def post_call_processing(original_response, model):
def post_call_processing(original_response, model, optional_params: Optional[dict]):
try:
if original_response is None:
pass
@ -594,11 +596,47 @@ def client(original_function):
pass
else:
if isinstance(original_response, ModelResponse):
model_response = original_response.choices[
model_response: Optional[str] = original_response.choices[
0
].message.content
].message.content # type: ignore
if model_response is not None:
### POST-CALL RULES ###
rules_obj.post_call_rules(input=model_response, model=model)
rules_obj.post_call_rules(
input=model_response, model=model
)
### JSON SCHEMA VALIDATION ###
if (
optional_params is not None
and "response_format" in optional_params
and isinstance(
optional_params["response_format"], dict
)
and "type" in optional_params["response_format"]
and optional_params["response_format"]["type"]
== "json_object"
and "response_schema"
in optional_params["response_format"]
and isinstance(
optional_params["response_format"][
"response_schema"
],
dict,
)
and "enforce_validation"
in optional_params["response_format"]
and optional_params["response_format"][
"enforce_validation"
]
is True
):
# schema given, json response expected, and validation enforced
litellm.litellm_core_utils.json_validation_rule.validate_schema(
schema=optional_params["response_format"][
"response_schema"
],
response=model_response,
)
except Exception as e:
raise e
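The validation call above delegates to `litellm.litellm_core_utils.json_validation_rule.validate_schema`, which pairs with the `ijson` -> `jsonschema` dependency swap later in this commit. A sketch of what such a helper could look like - an assumption, not the committed implementation:

import json

import jsonschema  # the dependency this commit switches to


def validate_schema(schema: dict, response: str) -> None:
    # Parse the model output, then check it against the user-supplied JSON schema.
    response_dict = json.loads(response)
    try:
        jsonschema.validate(instance=response_dict, schema=schema)
    except jsonschema.exceptions.ValidationError as err:
        # The real helper raises litellm.JSONSchemaValidationError here;
        # a plain ValueError keeps this sketch self-contained.
        raise ValueError(f"Response does not match schema: {err.message}") from err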
@ -867,7 +905,11 @@ def client(original_function):
return result
### POST-CALL RULES ###
post_call_processing(original_response=result, model=model or None)
post_call_processing(
original_response=result,
model=model or None,
optional_params=kwargs,
)
# [OPTIONAL] ADD TO CACHE
if (
@ -1316,7 +1358,9 @@ def client(original_function):
).total_seconds() * 1000 # return response latency in ms like openai
### POST-CALL RULES ###
post_call_processing(original_response=result, model=model)
post_call_processing(
original_response=result, model=model, optional_params=kwargs
)
# [OPTIONAL] ADD TO CACHE
if (
@ -1847,9 +1891,10 @@ def supports_system_messages(model: str, custom_llm_provider: Optional[str]) ->
Parameters:
model (str): The model name to be checked.
custom_llm_provider (str): The provider to be checked.
Returns:
bool: True if the model supports function calling, False otherwise.
bool: True if the model supports system messages, False otherwise.
Raises:
Exception: If the given model is not found in model_prices_and_context_window.json.
@ -1867,6 +1912,43 @@ def supports_system_messages(model: str, custom_llm_provider: Optional[str]) ->
)
def supports_response_schema(model: str, custom_llm_provider: Optional[str]) -> bool:
"""
Check if the given model + provider supports 'response_schema' as a param.
Parameters:
model (str): The model name to be checked.
custom_llm_provider (str): The provider to be checked.
Returns:
bool: True if the model supports response_schema, False otherwise.
Does not raise error. Defaults to 'False'. Outputs logging.error.
"""
try:
## GET LLM PROVIDER ##
model, custom_llm_provider, _, _ = get_llm_provider(
model=model, custom_llm_provider=custom_llm_provider
)
if custom_llm_provider == "predibase": # predibase supports this globally
return True
## GET MODEL INFO
model_info = litellm.get_model_info(
model=model, custom_llm_provider=custom_llm_provider
)
if model_info.get("supports_response_schema", False) is True:
return True
return False
except Exception:
verbose_logger.error(
f"Model not in model_prices_and_context_window.json. You passed model={model}, custom_llm_provider={custom_llm_provider}."
)
return False
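A short usage sketch for the helper above (model names are illustrative; the check only consults the local model map, so it is cheap to call before building a request):

from litellm.utils import supports_response_schema

for model in ["vertex_ai/gemini-1.5-pro", "gpt-4o"]:
    if supports_response_schema(model=model, custom_llm_provider=None):
        print(f"{model}: can pass response_schema in response_format")
    else:
        print(f"{model}: no native response_schema support")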
def supports_function_calling(model: str) -> bool:
"""
Check if the given model supports function calling and return a boolean value.
@ -2324,7 +2406,9 @@ def get_optional_params(
elif k == "hf_model_name" and custom_llm_provider != "sagemaker":
continue
elif (
k.startswith("vertex_") and custom_llm_provider != "vertex_ai" and custom_llm_provider != "vertex_ai_beta"
k.startswith("vertex_")
and custom_llm_provider != "vertex_ai"
and custom_llm_provider != "vertex_ai_beta"
): # allow dynamically setting vertex ai init logic
continue
passed_params[k] = v
@ -2756,6 +2840,11 @@ def get_optional_params(
non_default_params=non_default_params,
optional_params=optional_params,
model=model,
drop_params=(
drop_params
if drop_params is not None and isinstance(drop_params, bool)
else False
),
)
elif (
custom_llm_provider == "vertex_ai" and model in litellm.vertex_anthropic_models
@ -2824,12 +2913,7 @@ def get_optional_params(
optional_params=optional_params,
)
)
else:
optional_params = litellm.AmazonAnthropicConfig().map_openai_params(
non_default_params=non_default_params,
optional_params=optional_params,
)
else: # bedrock httpx route
elif model in litellm.BEDROCK_CONVERSE_MODELS:
optional_params = litellm.AmazonConverseConfig().map_openai_params(
model=model,
non_default_params=non_default_params,
@ -2840,6 +2924,11 @@ def get_optional_params(
else False
),
)
else:
optional_params = litellm.AmazonAnthropicConfig().map_openai_params(
non_default_params=non_default_params,
optional_params=optional_params,
)
elif "amazon" in model: # amazon titan llms
_check_valid_arg(supported_params=supported_params)
# see https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=titan-large
@ -3755,23 +3844,18 @@ def get_supported_openai_params(
return litellm.AzureOpenAIConfig().get_supported_openai_params()
elif custom_llm_provider == "openrouter":
return [
"functions",
"function_call",
"temperature",
"top_p",
"n",
"stream",
"stop",
"max_tokens",
"presence_penalty",
"frequency_penalty",
"logit_bias",
"user",
"response_format",
"presence_penalty",
"repetition_penalty",
"seed",
"tools",
"tool_choice",
"max_retries",
"max_tokens",
"logit_bias",
"logprobs",
"top_logprobs",
"response_format",
"stop",
]
elif custom_llm_provider == "mistral" or custom_llm_provider == "codestral":
# mistal and codestral api have the exact same params
@ -3789,6 +3873,10 @@ def get_supported_openai_params(
"top_p",
"stop",
"seed",
"tools",
"tool_choice",
"functions",
"function_call",
]
elif custom_llm_provider == "huggingface":
return litellm.HuggingfaceConfig().get_supported_openai_params()
@ -4434,8 +4522,7 @@ def get_max_tokens(model: str) -> Optional[int]:
def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> ModelInfo:
"""
Get a dict for the maximum tokens (context window),
input_cost_per_token, output_cost_per_token for a given model.
Get a dict for the maximum tokens (context window), input_cost_per_token, output_cost_per_token for a given model.
Parameters:
- model (str): The name of the model.
@ -4520,6 +4607,7 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod
mode="chat",
supported_openai_params=supported_openai_params,
supports_system_messages=None,
supports_response_schema=None,
)
else:
"""
@ -4541,36 +4629,6 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod
pass
else:
raise Exception
return ModelInfo(
max_tokens=_model_info.get("max_tokens", None),
max_input_tokens=_model_info.get("max_input_tokens", None),
max_output_tokens=_model_info.get("max_output_tokens", None),
input_cost_per_token=_model_info.get("input_cost_per_token", 0),
input_cost_per_character=_model_info.get(
"input_cost_per_character", None
),
input_cost_per_token_above_128k_tokens=_model_info.get(
"input_cost_per_token_above_128k_tokens", None
),
output_cost_per_token=_model_info.get("output_cost_per_token", 0),
output_cost_per_character=_model_info.get(
"output_cost_per_character", None
),
output_cost_per_token_above_128k_tokens=_model_info.get(
"output_cost_per_token_above_128k_tokens", None
),
output_cost_per_character_above_128k_tokens=_model_info.get(
"output_cost_per_character_above_128k_tokens", None
),
litellm_provider=_model_info.get(
"litellm_provider", custom_llm_provider
),
mode=_model_info.get("mode"),
supported_openai_params=supported_openai_params,
supports_system_messages=_model_info.get(
"supports_system_messages", None
),
)
elif model in litellm.model_cost:
_model_info = litellm.model_cost[model]
_model_info["supported_openai_params"] = supported_openai_params
@ -4584,36 +4642,6 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod
pass
else:
raise Exception
return ModelInfo(
max_tokens=_model_info.get("max_tokens", None),
max_input_tokens=_model_info.get("max_input_tokens", None),
max_output_tokens=_model_info.get("max_output_tokens", None),
input_cost_per_token=_model_info.get("input_cost_per_token", 0),
input_cost_per_character=_model_info.get(
"input_cost_per_character", None
),
input_cost_per_token_above_128k_tokens=_model_info.get(
"input_cost_per_token_above_128k_tokens", None
),
output_cost_per_token=_model_info.get("output_cost_per_token", 0),
output_cost_per_character=_model_info.get(
"output_cost_per_character", None
),
output_cost_per_token_above_128k_tokens=_model_info.get(
"output_cost_per_token_above_128k_tokens", None
),
output_cost_per_character_above_128k_tokens=_model_info.get(
"output_cost_per_character_above_128k_tokens", None
),
litellm_provider=_model_info.get(
"litellm_provider", custom_llm_provider
),
mode=_model_info.get("mode"),
supported_openai_params=supported_openai_params,
supports_system_messages=_model_info.get(
"supports_system_messages", None
),
)
elif split_model in litellm.model_cost:
_model_info = litellm.model_cost[split_model]
_model_info["supported_openai_params"] = supported_openai_params
@ -4627,6 +4655,15 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod
pass
else:
raise Exception
else:
raise ValueError(
"This model isn't mapped yet. Add it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json"
)
## PROVIDER-SPECIFIC INFORMATION
if custom_llm_provider == "predibase":
_model_info["supports_response_schema"] = True
return ModelInfo(
max_tokens=_model_info.get("max_tokens", None),
max_input_tokens=_model_info.get("max_input_tokens", None),
@ -4656,10 +4693,9 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod
supports_system_messages=_model_info.get(
"supports_system_messages", None
),
)
else:
raise ValueError(
"This model isn't mapped yet. Add it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json"
supports_response_schema=_model_info.get(
"supports_response_schema", None
),
)
except Exception:
raise Exception(
@ -5278,6 +5314,27 @@ def convert_to_model_response_object(
hidden_params: Optional[dict] = None,
):
received_args = locals()
### CHECK IF ERROR IN RESPONSE ### - openrouter returns these in the dictionary
if (
response_object is not None
and "error" in response_object
and response_object["error"] is not None
):
error_args = {"status_code": 422, "message": "Error in response object"}
if isinstance(response_object["error"], dict):
if "code" in response_object["error"]:
error_args["status_code"] = response_object["error"]["code"]
if "message" in response_object["error"]:
if isinstance(response_object["error"]["message"], dict):
message_str = json.dumps(response_object["error"]["message"])
else:
message_str = str(response_object["error"]["message"])
error_args["message"] = message_str
raised_exception = Exception()
setattr(raised_exception, "status_code", error_args["status_code"])
setattr(raised_exception, "message", error_args["message"])
raise raised_exception
try:
if response_type == "completion" and (
model_response_object is None
@ -5733,6 +5790,9 @@ def exception_type(
print() # noqa
try:
if model:
if hasattr(original_exception, "message"):
error_str = str(original_exception.message)
else:
error_str = str(original_exception)
if isinstance(original_exception, BaseException):
exception_type = type(original_exception).__name__
@ -5755,6 +5815,18 @@ def exception_type(
_model_group = _metadata.get("model_group")
_deployment = _metadata.get("deployment")
extra_information = f"\nModel: {model}"
exception_provider = "Unknown"
if (
isinstance(custom_llm_provider, str)
and len(custom_llm_provider) > 0
):
exception_provider = (
custom_llm_provider[0].upper()
+ custom_llm_provider[1:]
+ "Exception"
)
if _api_base:
extra_information += f"\nAPI Base: `{_api_base}`"
if (
@ -5805,10 +5877,13 @@ def exception_type(
or custom_llm_provider in litellm.openai_compatible_providers
):
# custom_llm_provider is openai, make it OpenAI
message = get_error_message(error_obj=original_exception)
if message is None:
if hasattr(original_exception, "message"):
message = original_exception.message
else:
message = str(original_exception)
if message is not None and isinstance(message, str):
message = message.replace("OPENAI", custom_llm_provider.upper())
message = message.replace("openai", custom_llm_provider)
@ -6141,7 +6216,6 @@ def exception_type(
)
elif (
original_exception.status_code == 400
or original_exception.status_code == 422
or original_exception.status_code == 413
):
exception_mapping_worked = True
@ -6151,6 +6225,14 @@ def exception_type(
llm_provider="replicate",
response=original_exception.response,
)
elif original_exception.status_code == 422:
exception_mapping_worked = True
raise UnprocessableEntityError(
message=f"ReplicateException - {original_exception.message}",
model=model,
llm_provider="replicate",
response=original_exception.response,
)
elif original_exception.status_code == 408:
exception_mapping_worked = True
raise Timeout(
@ -7254,10 +7336,17 @@ def exception_type(
request=original_exception.request,
)
elif custom_llm_provider == "azure":
message = get_error_message(error_obj=original_exception)
if message is None:
if hasattr(original_exception, "message"):
message = original_exception.message
else:
message = str(original_exception)
if "Internal server error" in error_str:
exception_mapping_worked = True
raise litellm.InternalServerError(
message=f"AzureException Internal server error - {original_exception.message}",
message=f"AzureException Internal server error - {message}",
llm_provider="azure",
model=model,
litellm_debug_info=extra_information,
@ -7270,7 +7359,7 @@ def exception_type(
elif "This model's maximum context length is" in error_str:
exception_mapping_worked = True
raise ContextWindowExceededError(
message=f"AzureException ContextWindowExceededError - {original_exception.message}",
message=f"AzureException ContextWindowExceededError - {message}",
llm_provider="azure",
model=model,
litellm_debug_info=extra_information,
@ -7279,7 +7368,7 @@ def exception_type(
elif "DeploymentNotFound" in error_str:
exception_mapping_worked = True
raise NotFoundError(
message=f"AzureException NotFoundError - {original_exception.message}",
message=f"AzureException NotFoundError - {message}",
llm_provider="azure",
model=model,
litellm_debug_info=extra_information,
@ -7299,7 +7388,7 @@ def exception_type(
):
exception_mapping_worked = True
raise ContentPolicyViolationError(
message=f"litellm.ContentPolicyViolationError: AzureException - {original_exception.message}",
message=f"litellm.ContentPolicyViolationError: AzureException - {message}",
llm_provider="azure",
model=model,
litellm_debug_info=extra_information,
@ -7308,7 +7397,7 @@ def exception_type(
elif "invalid_request_error" in error_str:
exception_mapping_worked = True
raise BadRequestError(
message=f"AzureException BadRequestError - {original_exception.message}",
message=f"AzureException BadRequestError - {message}",
llm_provider="azure",
model=model,
litellm_debug_info=extra_information,
@ -7320,7 +7409,7 @@ def exception_type(
):
exception_mapping_worked = True
raise AuthenticationError(
message=f"{exception_provider} AuthenticationError - {original_exception.message}",
message=f"{exception_provider} AuthenticationError - {message}",
llm_provider=custom_llm_provider,
model=model,
litellm_debug_info=extra_information,
@ -7331,7 +7420,7 @@ def exception_type(
if original_exception.status_code == 400:
exception_mapping_worked = True
raise BadRequestError(
message=f"AzureException - {original_exception.message}",
message=f"AzureException - {message}",
llm_provider="azure",
model=model,
litellm_debug_info=extra_information,
@ -7340,7 +7429,7 @@ def exception_type(
elif original_exception.status_code == 401:
exception_mapping_worked = True
raise AuthenticationError(
message=f"AzureException AuthenticationError - {original_exception.message}",
message=f"AzureException AuthenticationError - {message}",
llm_provider="azure",
model=model,
litellm_debug_info=extra_information,
@ -7349,7 +7438,7 @@ def exception_type(
elif original_exception.status_code == 408:
exception_mapping_worked = True
raise Timeout(
message=f"AzureException Timeout - {original_exception.message}",
message=f"AzureException Timeout - {message}",
model=model,
litellm_debug_info=extra_information,
llm_provider="azure",
@ -7357,7 +7446,7 @@ def exception_type(
elif original_exception.status_code == 422:
exception_mapping_worked = True
raise BadRequestError(
message=f"AzureException BadRequestError - {original_exception.message}",
message=f"AzureException BadRequestError - {message}",
model=model,
llm_provider="azure",
litellm_debug_info=extra_information,
@ -7366,7 +7455,7 @@ def exception_type(
elif original_exception.status_code == 429:
exception_mapping_worked = True
raise RateLimitError(
message=f"AzureException RateLimitError - {original_exception.message}",
message=f"AzureException RateLimitError - {message}",
model=model,
llm_provider="azure",
litellm_debug_info=extra_information,
@ -7375,7 +7464,7 @@ def exception_type(
elif original_exception.status_code == 503:
exception_mapping_worked = True
raise ServiceUnavailableError(
message=f"AzureException ServiceUnavailableError - {original_exception.message}",
message=f"AzureException ServiceUnavailableError - {message}",
model=model,
llm_provider="azure",
litellm_debug_info=extra_information,
@ -7384,7 +7473,7 @@ def exception_type(
elif original_exception.status_code == 504: # gateway timeout error
exception_mapping_worked = True
raise Timeout(
message=f"AzureException Timeout - {original_exception.message}",
message=f"AzureException Timeout - {message}",
model=model,
litellm_debug_info=extra_information,
llm_provider="azure",
@ -7393,7 +7482,7 @@ def exception_type(
exception_mapping_worked = True
raise APIError(
status_code=original_exception.status_code,
message=f"AzureException APIError - {original_exception.message}",
message=f"AzureException APIError - {message}",
llm_provider="azure",
litellm_debug_info=extra_information,
model=model,


@ -1486,6 +1486,7 @@
"supports_system_messages": true,
"supports_function_calling": true,
"supports_tool_choice": true,
"supports_response_schema": true,
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
},
"gemini-1.5-pro-001": {
@ -1511,6 +1512,7 @@
"supports_system_messages": true,
"supports_function_calling": true,
"supports_tool_choice": true,
"supports_response_schema": true,
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
},
"gemini-1.5-pro-preview-0514": {
@ -1536,6 +1538,7 @@
"supports_system_messages": true,
"supports_function_calling": true,
"supports_tool_choice": true,
"supports_response_schema": true,
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
},
"gemini-1.5-pro-preview-0215": {
@ -1561,6 +1564,7 @@
"supports_system_messages": true,
"supports_function_calling": true,
"supports_tool_choice": true,
"supports_response_schema": true,
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
},
"gemini-1.5-pro-preview-0409": {
@ -1585,6 +1589,7 @@
"mode": "chat",
"supports_function_calling": true,
"supports_tool_choice": true,
"supports_response_schema": true,
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
},
"gemini-1.5-flash": {
@ -2007,6 +2012,7 @@
"supports_function_calling": true,
"supports_vision": true,
"supports_tool_choice": true,
"supports_response_schema": true,
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
},
"gemini/gemini-1.5-pro-latest": {
@ -2023,6 +2029,7 @@
"supports_function_calling": true,
"supports_vision": true,
"supports_tool_choice": true,
"supports_response_schema": true,
"source": "https://ai.google.dev/models/gemini"
},
"gemini/gemini-pro-vision": {


@ -1,6 +1,6 @@
[tool.poetry]
name = "litellm"
version = "1.40.31"
version = "1.41.3"
description = "Library to easily interface with LLM API providers"
authors = ["BerriAI"]
license = "MIT"
@ -27,7 +27,7 @@ jinja2 = "^3.1.2"
aiohttp = "*"
requests = "^2.31.0"
pydantic = "^2.0.0"
ijson = "*"
jsonschema = "^4.22.0"
uvicorn = {version = "^0.22.0", optional = true}
gunicorn = {version = "^22.0.0", optional = true}
@ -90,7 +90,7 @@ requires = ["poetry-core", "wheel"]
build-backend = "poetry.core.masonry.api"
[tool.commitizen]
version = "1.40.31"
version = "1.41.3"
version_files = [
"pyproject.toml:^version"
]


@ -46,5 +46,5 @@ aiohttp==3.9.0 # for network calls
aioboto3==12.3.0 # for async sagemaker calls
tenacity==8.2.3 # for retrying requests, when litellm.num_retries set
pydantic==2.7.1 # proxy + openai req.
ijson==3.2.3 # for google ai studio streaming
jsonschema==4.22.0 # validating json schema
####

tests/test_entrypoint.py (new file, 59 lines)

@ -0,0 +1,59 @@
# What is this?
## Unit tests for 'entrypoint.sh'
import pytest
import sys
import os
sys.path.insert(
0, os.path.abspath("../")
) # Adds the parent directory to the system path
import litellm
import subprocess
@pytest.mark.skip(reason="local test")
def test_decrypt_and_reset_env():
os.environ["DATABASE_URL"] = (
"aws_kms/AQICAHgwddjZ9xjVaZ9CNCG8smFU6FiQvfdrjL12DIqi9vUAQwHwF6U7caMgHQa6tK+TzaoMAAAAzjCBywYJKoZIhvcNAQcGoIG9MIG6AgEAMIG0BgkqhkiG9w0BBwEwHgYJYIZIAWUDBAEuMBEEDCmu+DVeKTm5tFZu6AIBEICBhnOFQYviL8JsciGk0bZsn9pfzeYWtNkVXEsl01AdgHBqT9UOZOI4ZC+T3wO/fXA7wdNF4o8ASPDbVZ34ZFdBs8xt4LKp9niufL30WYBkuuzz89ztly0jvE9pZ8L6BMw0ATTaMgIweVtVSDCeCzEb5PUPyxt4QayrlYHBGrNH5Aq/axFTe0La"
)
from litellm.proxy.secret_managers.aws_secret_manager import (
decrypt_and_reset_env_var,
)
decrypt_and_reset_env_var()
assert os.environ["DATABASE_URL"] is not None
assert isinstance(os.environ["DATABASE_URL"], str)
assert not os.environ["DATABASE_URL"].startswith("aws_kms/")
print("DATABASE_URL={}".format(os.environ["DATABASE_URL"]))
@pytest.mark.skip(reason="local test")
def test_entrypoint_decrypt_and_reset():
os.environ["DATABASE_URL"] = (
"aws_kms/AQICAHgwddjZ9xjVaZ9CNCG8smFU6FiQvfdrjL12DIqi9vUAQwHwF6U7caMgHQa6tK+TzaoMAAAAzjCBywYJKoZIhvcNAQcGoIG9MIG6AgEAMIG0BgkqhkiG9w0BBwEwHgYJYIZIAWUDBAEuMBEEDCmu+DVeKTm5tFZu6AIBEICBhnOFQYviL8JsciGk0bZsn9pfzeYWtNkVXEsl01AdgHBqT9UOZOI4ZC+T3wO/fXA7wdNF4o8ASPDbVZ34ZFdBs8xt4LKp9niufL30WYBkuuzz89ztly0jvE9pZ8L6BMw0ATTaMgIweVtVSDCeCzEb5PUPyxt4QayrlYHBGrNH5Aq/axFTe0La"
)
command = "./entrypoint.sh"
directory = ".." # Relative to the current directory
# Run the command using subprocess
result = subprocess.run(
command, shell=True, cwd=directory, capture_output=True, text=True
)
# Print the output for debugging purposes
print("STDOUT:", result.stdout)
print("STDERR:", result.stderr)
# Assert the script ran successfully
assert result.returncode == 0, "The shell script did not execute successfully"
assert (
"DECRYPTS VALUE" in result.stdout
), "Expected output not found in script output"
assert (
"Database push successful!" in result.stdout
), "Expected output not found in script output"
assert False
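For context on what these skipped tests exercise: env vars whose value starts with "aws_kms/" are expected to be KMS-decrypted and rewritten in place before the proxy starts. A heavily hedged sketch of that flow - the helper's actual body is not shown in this commit, so everything below is an assumption:

import base64
import os

import boto3


def decrypt_and_reset_env_var_sketch() -> None:
    # Assumed flow: strip the "aws_kms/" prefix, base64-decode the payload,
    # KMS-decrypt it, and overwrite the env var with the plaintext.
    kms = boto3.client("kms", region_name=os.environ.get("AWS_REGION_NAME"))
    for name, value in list(os.environ.items()):
        if not value.startswith("aws_kms/"):
            continue
        ciphertext = base64.b64decode(value[len("aws_kms/"):])
        plaintext = kms.decrypt(CiphertextBlob=ciphertext)["Plaintext"]
        os.environ[name] = plaintext.decode("utf-8")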