Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 11:14:04 +00:00)

Commit 26630cd263: Merge branch 'main' into litellm-fix-vertexaibeta

73 changed files with 3482 additions and 782 deletions
@@ -66,7 +66,7 @@ jobs:
 pip install "pydantic==2.7.1"
 pip install "diskcache==5.6.1"
 pip install "Pillow==10.3.0"
-pip install "ijson==3.2.3"
+pip install "jsonschema==4.22.0"
 - save_cache:
     paths:
       - ./venv

@@ -128,7 +128,7 @@ jobs:
 pip install jinja2
 pip install tokenizers
 pip install openai
-pip install ijson
+pip install jsonschema
 - run:
     name: Run tests
     command: |

@@ -183,7 +183,7 @@ jobs:
 pip install numpydoc
 pip install prisma
 pip install fastapi
-pip install ijson
+pip install jsonschema
 pip install "httpx==0.24.1"
 pip install "gunicorn==21.2.0"
 pip install "anyio==3.7.1"

@@ -212,6 +212,7 @@ jobs:
 -e AWS_REGION_NAME=$AWS_REGION_NAME \
 -e AUTO_INFER_REGION=True \
 -e OPENAI_API_KEY=$OPENAI_API_KEY \
+-e LITELLM_LICENSE=$LITELLM_LICENSE \
 -e LANGFUSE_PROJECT1_PUBLIC=$LANGFUSE_PROJECT1_PUBLIC \
 -e LANGFUSE_PROJECT2_PUBLIC=$LANGFUSE_PROJECT2_PUBLIC \
 -e LANGFUSE_PROJECT1_SECRET=$LANGFUSE_PROJECT1_SECRET \
@@ -50,7 +50,7 @@ Use `litellm.get_supported_openai_params()` for an updated list of params for ea
 |Huggingface| ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | | |
 |Openrouter| ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | | | | ✅ | | | | |
 |AI21| ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | |
-|VertexAI| ✅ | ✅ | | ✅ | ✅ | | | | | | | | | | ✅ | | |
+|VertexAI| ✅ | ✅ | | ✅ | ✅ | | | | | | | | | ✅ | ✅ | | |
 |Bedrock| ✅ | ✅ | ✅ | ✅ | ✅ | | | | | | | | | | ✅ (for anthropic) | |
 |Sagemaker| ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | | |
 |TogetherAI| ✅ | ✅ | ✅ | ✅ | ✅ | | | | | | ✅ |
@@ -2,26 +2,39 @@
 For companies that need SSO, user management and professional support for LiteLLM Proxy
 
 :::info
+Interested in Enterprise? Schedule a meeting with us here 👉
 [Talk to founders](https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat)
 
 :::
 
 This covers:
-- ✅ **Features under the [LiteLLM Commercial License (Content Mod, Custom Tags, etc.)](https://docs.litellm.ai/docs/proxy/enterprise)**
-- ✅ [**Secure UI access with Single Sign-On**](../docs/proxy/ui.md#setup-ssoauth-for-ui)
-- ✅ [**Audit Logs with retention policy**](../docs/proxy/enterprise.md#audit-logs)
-- ✅ [**JWT-Auth**](../docs/proxy/token_auth.md)
-- ✅ [**Control available public, private routes**](../docs/proxy/enterprise.md#control-available-public-private-routes)
-- ✅ [**Guardrails, Content Moderation, PII Masking, Secret/API Key Masking**](../docs/proxy/enterprise.md#prompt-injection-detection---lakeraai)
-- ✅ [**Prompt Injection Detection**](../docs/proxy/enterprise.md#prompt-injection-detection---lakeraai)
-- ✅ [**Invite Team Members to access `/spend` Routes**](../docs/proxy/cost_tracking#allowing-non-proxy-admins-to-access-spend-endpoints)
+- **Enterprise Features**
+    - **Security**
+        - ✅ [SSO for Admin UI](./proxy/ui#✨-enterprise-features)
+        - ✅ [Audit Logs with retention policy](./proxy/enterprise#audit-logs)
+        - ✅ [JWT-Auth](../docs/proxy/token_auth.md)
+        - ✅ [Control available public, private routes](./proxy/enterprise#control-available-public-private-routes)
+        - ✅ [[BETA] AWS Key Manager v2 - Key Decryption](./proxy/enterprise#beta-aws-key-manager---key-decryption)
+        - ✅ [Use LiteLLM keys/authentication on Pass Through Endpoints](./proxy/pass_through#✨-enterprise---use-litellm-keysauthentication-on-pass-through-endpoints)
+        - ✅ [Enforce Required Params for LLM Requests (ex. Reject requests missing ["metadata"]["generation_name"])](./proxy/enterprise#enforce-required-params-for-llm-requests)
+    - **Spend Tracking**
+        - ✅ [Tracking Spend for Custom Tags](./proxy/enterprise#tracking-spend-for-custom-tags)
+        - ✅ [API Endpoints to get Spend Reports per Team, API Key, Customer](./proxy/cost_tracking.md#✨-enterprise-api-endpoints-to-get-spend)
+    - **Advanced Metrics**
+        - ✅ [`x-ratelimit-remaining-requests`, `x-ratelimit-remaining-tokens` for LLM APIs on Prometheus](./proxy/prometheus#✨-enterprise-llm-remaining-requests-and-remaining-tokens)
+    - **Guardrails, PII Masking, Content Moderation**
+        - ✅ [Content Moderation with LLM Guard, LlamaGuard, Secret Detection, Google Text Moderations](./proxy/enterprise#content-moderation)
+        - ✅ [Prompt Injection Detection (with LakeraAI API)](./proxy/enterprise#prompt-injection-detection---lakeraai)
+        - ✅ Reject calls from Blocked User list
+        - ✅ Reject calls (incoming / outgoing) with Banned Keywords (e.g. competitors)
+    - **Custom Branding**
+        - ✅ [Custom Branding + Routes on Swagger Docs](./proxy/enterprise#swagger-docs---custom-routes--branding)
+        - ✅ [Public Model Hub](../docs/proxy/enterprise.md#public-model-hub)
+        - ✅ [Custom Email Branding](../docs/proxy/email.md#customizing-email-branding)
 - ✅ **Feature Prioritization**
 - ✅ **Custom Integrations**
 - ✅ **Professional Support - Dedicated discord + slack**
-- ✅ [**Custom Swagger**](../docs/proxy/enterprise.md#swagger-docs---custom-routes--branding)
-- ✅ [**Public Model Hub**](../docs/proxy/enterprise.md#public-model-hub)
-- ✅ [**Custom Email Branding**](../docs/proxy/email.md#customizing-email-branding)
@@ -168,8 +168,12 @@ print(response)
 
 ## Supported Models
 
+`Model Name` 👉 Human-friendly name.
+
+`Function Call` 👉 How to call the model in LiteLLM.
+
 | Model Name | Function Call |
 |------------------|--------------------------------------------|
+| claude-3-5-sonnet | `completion('claude-3-5-sonnet-20240620', messages)` | `os.environ['ANTHROPIC_API_KEY']` |
 | claude-3-haiku | `completion('claude-3-haiku-20240307', messages)` | `os.environ['ANTHROPIC_API_KEY']` |
 | claude-3-opus | `completion('claude-3-opus-20240229', messages)` | `os.environ['ANTHROPIC_API_KEY']` |
 | claude-3-5-sonnet-20240620 | `completion('claude-3-5-sonnet-20240620', messages)` | `os.environ['ANTHROPIC_API_KEY']` |
@@ -14,7 +14,7 @@ LiteLLM supports all models on Azure AI Studio
 ### ENV VAR
 ```python
 import os
-os.environ["AZURE_API_API_KEY"] = ""
+os.environ["AZURE_AI_API_KEY"] = ""
 os.environ["AZURE_AI_API_BASE"] = ""
 ```
 
@@ -24,7 +24,7 @@ os.environ["AZURE_AI_API_BASE"] = ""
 from litellm import completion
 import os
 ## set ENV variables
-os.environ["AZURE_API_API_KEY"] = "azure ai key"
+os.environ["AZURE_AI_API_KEY"] = "azure ai key"
 os.environ["AZURE_AI_API_BASE"] = "azure ai base url" # e.g.: https://Mistral-large-dfgfj-serverless.eastus2.inference.ai.azure.com/
 
 # predibase llama-3 call
@@ -549,6 +549,10 @@ response = completion(
 
 This is a deprecated flow. Boto3 is not async. And boto3.client does not let us make the http call through httpx. Pass in your aws params through the method above 👆. [See Auth Code](https://github.com/BerriAI/litellm/blob/55a20c7cce99a93d36a82bf3ae90ba3baf9a7f89/litellm/llms/bedrock_httpx.py#L284) [Add new auth flow](https://github.com/BerriAI/litellm/issues)
 
+Experimental - 2024-Jun-23:
+
+`aws_access_key_id`, `aws_secret_access_key`, and `aws_session_token` will be extracted from boto3.client and be passed into the httpx client
+
 :::
 
 Pass an external BedrockRuntime.Client object as a parameter to litellm.completion. Useful when using an AWS credentials profile, SSO session, assumed role session, or if environment variables are not available for auth.
@@ -18,7 +18,7 @@ import litellm
 import os
 
 response = litellm.completion(
-    model="openai/mistral, # add `openai/` prefix to model so litellm knows to route to OpenAI
+    model="openai/mistral", # add `openai/` prefix to model so litellm knows to route to OpenAI
     api_key="sk-1234", # api key to your openai compatible endpoint
     api_base="http://0.0.0.0:4000", # set API Base of your Custom OpenAI Endpoint
     messages=[
@@ -123,6 +123,182 @@ print(completion(**data))
 
 ### **JSON Schema**
 
+From v`1.40.1+` LiteLLM supports sending `response_schema` as a param for Gemini-1.5-Pro on Vertex AI. For other models (e.g. `gemini-1.5-flash` or `claude-3-5-sonnet`), LiteLLM adds the schema to the message list with a user-controlled prompt.
+
+**Response Schema**
+<Tabs>
+<TabItem value="sdk" label="SDK">
+
+```python
+from litellm import completion
+import json
+
+## SETUP ENVIRONMENT
+# !gcloud auth application-default login - run this to add vertex credentials to your env
+
+messages = [
+    {
+        "role": "user",
+        "content": "List 5 popular cookie recipes."
+    }
+]
+
+response_schema = {
+    "type": "array",
+    "items": {
+        "type": "object",
+        "properties": {
+            "recipe_name": {
+                "type": "string",
+            },
+        },
+        "required": ["recipe_name"],
+    },
+}
+
+
+completion(
+    model="vertex_ai_beta/gemini-1.5-pro",
+    messages=messages,
+    response_format={"type": "json_object", "response_schema": response_schema} # 👈 KEY CHANGE
+)
+
+print(json.loads(completion.choices[0].message.content))
+```
+
+</TabItem>
+<TabItem value="proxy" label="PROXY">
+
+1. Add model to config.yaml
+```yaml
+model_list:
+  - model_name: gemini-pro
+    litellm_params:
+      model: vertex_ai_beta/gemini-1.5-pro
+      vertex_project: "project-id"
+      vertex_location: "us-central1"
+      vertex_credentials: "/path/to/service_account.json" # [OPTIONAL] Do this OR `!gcloud auth application-default login` - run this to add vertex credentials to your env
+```
+
+2. Start Proxy
+
+```
+$ litellm --config /path/to/config.yaml
+```
+
+3. Make Request!
+
+```bash
+curl -X POST 'http://0.0.0.0:4000/chat/completions' \
+-H 'Content-Type: application/json' \
+-H 'Authorization: Bearer sk-1234' \
+-D '{
+  "model": "gemini-pro",
+  "messages": [
+        {"role": "user", "content": "List 5 popular cookie recipes."}
+    ],
+  "response_format": {"type": "json_object", "response_schema": {
+        "type": "array",
+        "items": {
+            "type": "object",
+            "properties": {
+                "recipe_name": {
+                    "type": "string",
+                },
+            },
+            "required": ["recipe_name"],
+        },
+  }}
+}
+'
+```
+
+</TabItem>
+</Tabs>
+
+**Validate Schema**
+
+To validate the response_schema, set `enforce_validation: true`.
+
+<Tabs>
+<TabItem value="sdk" label="SDK">
+
+```python
+from litellm import completion, JSONSchemaValidationError
+try:
+    completion(
+        model="vertex_ai_beta/gemini-1.5-pro",
+        messages=messages,
+        response_format={
+            "type": "json_object",
+            "response_schema": response_schema,
+            "enforce_validation": true # 👈 KEY CHANGE
+        }
+    )
+except JSONSchemaValidationError as e:
+    print("Raw Response: {}".format(e.raw_response))
+    raise e
+```
+</TabItem>
+<TabItem value="proxy" label="PROXY">
+
+1. Add model to config.yaml
+```yaml
+model_list:
+  - model_name: gemini-pro
+    litellm_params:
+      model: vertex_ai_beta/gemini-1.5-pro
+      vertex_project: "project-id"
+      vertex_location: "us-central1"
+      vertex_credentials: "/path/to/service_account.json" # [OPTIONAL] Do this OR `!gcloud auth application-default login` - run this to add vertex credentials to your env
+```
+
+2. Start Proxy
+
+```
+$ litellm --config /path/to/config.yaml
+```
+
+3. Make Request!
+
+```bash
+curl -X POST 'http://0.0.0.0:4000/chat/completions' \
+-H 'Content-Type: application/json' \
+-H 'Authorization: Bearer sk-1234' \
+-D '{
+  "model": "gemini-pro",
+  "messages": [
+        {"role": "user", "content": "List 5 popular cookie recipes."}
+    ],
+  "response_format": {"type": "json_object", "response_schema": {
+        "type": "array",
+        "items": {
+            "type": "object",
+            "properties": {
+                "recipe_name": {
+                    "type": "string",
+                },
+            },
+            "required": ["recipe_name"],
+        },
+    },
+    "enforce_validation": true
+  }
+}
+'
+```
+
+</TabItem>
+</Tabs>
+
+LiteLLM will validate the response against the schema, and raise a `JSONSchemaValidationError` if the response does not match the schema.
+
+JSONSchemaValidationError inherits from `openai.APIError`
+
+Access the raw response with `e.raw_response`
+
+**Add to prompt yourself**
+
 ```python
 from litellm import completion
@@ -645,6 +821,86 @@ assert isinstance(
 ```
 
 
+## Usage - PDF / Videos / etc. Files
+
+Pass any file supported by Vertex AI, through LiteLLM.
+
+<Tabs>
+<TabItem value="sdk" label="SDK">
+
+```python
+from litellm import completion
+
+response = completion(
+    model="vertex_ai/gemini-1.5-flash",
+    messages=[
+        {
+            "role": "user",
+            "content": [
+                {"type": "text", "text": "You are a very professional document summarization specialist. Please summarize the given document."},
+                {
+                    "type": "image_url",
+                    "image_url": "gs://cloud-samples-data/generative-ai/pdf/2403.05530.pdf",
+                },
+            ],
+        }
+    ],
+    max_tokens=300,
+)
+
+print(response.choices[0])
+
+```
+</TabItem>
+<TabItem value="proxy" lable="PROXY">
+
+1. Add model to config
+
+```yaml
+- model_name: gemini-1.5-flash
+  litellm_params:
+    model: vertex_ai/gemini-1.5-flash
+    vertex_credentials: "/path/to/service_account.json"
+```
+
+2. Start Proxy
+
+```
+litellm --config /path/to/config.yaml
+```
+
+3. Test it!
+
+```bash
+curl http://0.0.0.0:4000/v1/chat/completions \
+  -H "Content-Type: application/json" \
+  -H "Authorization: Bearer <YOUR-LITELLM-KEY>" \
+  -d '{
+    "model": "gemini-1.5-flash",
+    "messages": [
+      {
+        "role": "user",
+        "content": [
+          {
+            "type": "text",
+            "text": "You are a very professional document summarization specialist. Please summarize the given document"
+          },
+          {
+            "type": "image_url",
+            "image_url": "gs://cloud-samples-data/generative-ai/pdf/2403.05530.pdf",
+          },
+        }
+      ]
+    }
+    ],
+    "max_tokens": 300
+  }'
+
+```
+
+</TabItem>
+</Tabs>
+
 ## Chat Models
 | Model Name | Function Call |
 |------------------|--------------------------------------|
@@ -277,6 +277,54 @@ curl --location 'http://0.0.0.0:4000/v1/model/info' \
     --data ''
 ```
 
+
+## Wildcard Model Name (Add ALL MODELS from env)
+
+Dynamically call any model from any given provider without the need to predefine it in the config YAML file. As long as the relevant keys are in the environment (see [providers list](../providers/)), LiteLLM will make the call correctly.
+
+
+
+1. Setup config.yaml
+```
+model_list:
+  - model_name: "*" # all requests where model not in your config go to this deployment
+    litellm_params:
+      model: "openai/*" # passes our validation check that a real provider is given
+```
+
+2. Start LiteLLM proxy
+
+```
+litellm --config /path/to/config.yaml
+```
+
+3. Try claude 3-5 sonnet from anthropic
+
+```bash
+curl -X POST 'http://0.0.0.0:4000/chat/completions' \
+-H 'Content-Type: application/json' \
+-H 'Authorization: Bearer sk-1234' \
+-D '{
+  "model": "claude-3-5-sonnet-20240620",
+  "messages": [
+    {"role": "user", "content": "Hey, how'\''s it going?"},
+    {
+      "role": "assistant",
+      "content": "I'\''m doing well. Would like to hear the rest of the story?"
+    },
+    {"role": "user", "content": "Na"},
+    {
+      "role": "assistant",
+      "content": "No problem, is there anything else i can help you with today?"
+    },
+    {
+      "role": "user",
+      "content": "I think you'\''re getting cut off sometimes"
+    }
+  ]
+}
+'
+```
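A minimal Python sketch of the same wildcard call, assuming the proxy from step 2 is running locally with the example `sk-1234` key and the OpenAI SDK v1 client; this is an illustrative sketch rather than part of the upstream example set:

```python
# Minimal sketch: call a model that is NOT listed in config.yaml; the "*" wildcard
# deployment routes it as long as the provider's key is in the proxy's environment.
import openai

client = openai.OpenAI(
    api_key="sk-1234",               # LiteLLM proxy key
    base_url="http://0.0.0.0:4000",  # LiteLLM proxy endpoint
)

response = client.chat.completions.create(
    model="claude-3-5-sonnet-20240620",  # not predefined in config.yaml; matched by "*"
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
)
print(response.choices[0].message.content)
```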
 ## Load Balancing
 
 :::info
@@ -117,6 +117,8 @@ That's IT. Now Verify your spend was tracked
 <Tabs>
 <TabItem value="curl" label="Response Headers">
+
+Expect to see `x-litellm-response-cost` in the response headers with calculated cost
 
 <Image img={require('../../img/response_cost_img.png')} />
 
 </TabItem>
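A small sketch of reading that header from a script, assuming a locally running proxy with key `sk-1234` and a deployment named `gpt-3.5-turbo`; the `with_raw_response` helper comes from the OpenAI Python SDK v1:

```python
# Minimal sketch: read the x-litellm-response-cost header the proxy attaches
# to each chat completion response.
import openai

client = openai.OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")

raw = client.chat.completions.with_raw_response.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "ping"}],
)

print("cost:", raw.headers.get("x-litellm-response-cost"))  # calculated cost for this call
print(raw.parse().choices[0].message.content)               # the usual ChatCompletion object
```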
@@ -145,16 +147,16 @@ Navigate to the Usage Tab on the LiteLLM UI (found on https://your-proxy-endpoin
 
 <Image img={require('../../img/admin_ui_spend.png')} />
 
-## API Endpoints to get Spend
-#### Getting Spend Reports - To Charge Other Teams, Customers
-
-Use the `/global/spend/report` endpoint to get daily spend report per
-- team
-- customer [this is `user` passed to `/chat/completions` request](#how-to-track-spend-with-litellm)
-
 </TabItem>
 </Tabs>
 
+## ✨ (Enterprise) API Endpoints to get Spend
+#### Getting Spend Reports - To Charge Other Teams, Customers
+
+Use the `/global/spend/report` endpoint to get daily spend report per
+- Team
+- Customer [this is `user` passed to `/chat/completions` request](#how-to-track-spend-with-litellm)
+- [LiteLLM API key](virtual_keys.md)
+
 <Tabs>
@@ -337,6 +339,61 @@ curl -X GET 'http://localhost:4000/global/spend/report?start_date=2024-04-01&end
 ```
 
 
+</TabItem>
+
+<TabItem value="per key" label="Spend Per API Key">
+
+
+👉 Key Change: Specify `group_by=api_key`
+
+
+```shell
+curl -X GET 'http://localhost:4000/global/spend/report?start_date=2024-04-01&end_date=2024-06-30&group_by=api_key' \
+  -H 'Authorization: Bearer sk-1234'
+```
+
+##### Example Response
+
+
+```shell
+[
+    {
+        "api_key": "ad64768847d05d978d62f623d872bff0f9616cc14b9c1e651c84d14fe3b9f539",
+        "total_cost": 0.0002157,
+        "total_input_tokens": 45.0,
+        "total_output_tokens": 1375.0,
+        "model_details": [
+            {
+                "model": "gpt-3.5-turbo",
+                "total_cost": 0.0001095,
+                "total_input_tokens": 9,
+                "total_output_tokens": 70
+            },
+            {
+                "model": "llama3-8b-8192",
+                "total_cost": 0.0001062,
+                "total_input_tokens": 36,
+                "total_output_tokens": 1305
+            }
+        ]
+    },
+    {
+        "api_key": "88dc28d0f030c55ed4ab77ed8faf098196cb1c05df778539800c9f1243fe6b4b",
+        "total_cost": 0.00012924,
+        "total_input_tokens": 36.0,
+        "total_output_tokens": 1593.0,
+        "model_details": [
+            {
+                "model": "llama3-8b-8192",
+                "total_cost": 0.00012924,
+                "total_input_tokens": 36,
+                "total_output_tokens": 1593
+            }
+        ]
+    }
+]
+```
+
 </TabItem>
 
 </Tabs>
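A short sketch of consuming the per-key report from Python, assuming the same local proxy, dates, and `sk-1234` master key used above:

```python
# Minimal sketch: fetch the per-API-key spend report and print a total per key.
import requests

resp = requests.get(
    "http://localhost:4000/global/spend/report",
    params={"start_date": "2024-04-01", "end_date": "2024-06-30", "group_by": "api_key"},
    headers={"Authorization": "Bearer sk-1234"},
)
resp.raise_for_status()

for row in resp.json():
    # row["api_key"] is the hashed key; total_cost is summed across models
    print(f"{row['api_key'][:8]}...  ${row['total_cost']:.6f}")
```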
@@ -89,3 +89,30 @@ Expected Output:
 ```bash
 # no info statements
 ```
+
+## Common Errors
+
+1. "No available deployments..."
+
+```
+No deployments available for selected model, Try again in 60 seconds. Passed model=claude-3-5-sonnet. pre-call-checks=False, allowed_model_region=n/a.
+```
+
+This can be caused due to all your models hitting rate limit errors, causing the cooldown to kick in.
+
+How to control this?
+- Adjust the cooldown time
+
+```yaml
+router_settings:
+  cooldown_time: 0 # 👈 KEY CHANGE
+```
+
+- Disable Cooldowns [NOT RECOMMENDED]
+
+```yaml
+router_settings:
+  disable_cooldowns: True
+```
+
+This is not recommended, as it will lead to requests being routed to deployments over their tpm/rpm limit.
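For SDK users, a minimal sketch of the equivalent Router-level setting is below; passing `cooldown_time` to the Router constructor mirrors the `router_settings.cooldown_time` value above and is an assumption about the SDK surface rather than a verbatim example from these docs:

```python
# Minimal sketch: shorten the cooldown window on the SDK Router so rate-limited
# deployments come back into rotation quickly. cooldown_time is in seconds.
from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "claude-3-5-sonnet",
            "litellm_params": {"model": "claude-3-5-sonnet-20240620"},
        }
    ],
    cooldown_time=0,  # assumed to mirror router_settings.cooldown_time above
)
```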
@@ -6,21 +6,34 @@ import TabItem from '@theme/TabItem';
 
 :::tip
 
-Get in touch with us [here](https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat)
+To get a license, get in touch with us [here](https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat)
 
 :::
 
 Features:
-- ✅ [SSO for Admin UI](./ui.md#✨-enterprise-features)
-- ✅ [Audit Logs](#audit-logs)
-- ✅ [Tracking Spend for Custom Tags](#tracking-spend-for-custom-tags)
-- ✅ [Control available public, private routes](#control-available-public-private-routes)
-- ✅ [Content Moderation with LLM Guard, LlamaGuard, Secret Detection, Google Text Moderations](#content-moderation)
-- ✅ [Prompt Injection Detection (with LakeraAI API)](#prompt-injection-detection---lakeraai)
-- ✅ [Custom Branding + Routes on Swagger Docs](#swagger-docs---custom-routes--branding)
-- ✅ [Enforce Required Params for LLM Requests (ex. Reject requests missing ["metadata"]["generation_name"])](#enforce-required-params-for-llm-requests)
-- ✅ Reject calls from Blocked User list
-- ✅ Reject calls (incoming / outgoing) with Banned Keywords (e.g. competitors)
+- **Security**
+    - ✅ [SSO for Admin UI](./ui.md#✨-enterprise-features)
+    - ✅ [Audit Logs with retention policy](#audit-logs)
+    - ✅ [JWT-Auth](../docs/proxy/token_auth.md)
+    - ✅ [Control available public, private routes](#control-available-public-private-routes)
+    - ✅ [[BETA] AWS Key Manager v2 - Key Decryption](#beta-aws-key-manager---key-decryption)
+    - ✅ [Use LiteLLM keys/authentication on Pass Through Endpoints](pass_through#✨-enterprise---use-litellm-keysauthentication-on-pass-through-endpoints)
+    - ✅ [Enforce Required Params for LLM Requests (ex. Reject requests missing ["metadata"]["generation_name"])](#enforce-required-params-for-llm-requests)
+- **Spend Tracking**
+    - ✅ [Tracking Spend for Custom Tags](#tracking-spend-for-custom-tags)
+    - ✅ [API Endpoints to get Spend Reports per Team, API Key, Customer](cost_tracking.md#✨-enterprise-api-endpoints-to-get-spend)
+- **Advanced Metrics**
+    - ✅ [`x-ratelimit-remaining-requests`, `x-ratelimit-remaining-tokens` for LLM APIs on Prometheus](prometheus#✨-enterprise-llm-remaining-requests-and-remaining-tokens)
+- **Guardrails, PII Masking, Content Moderation**
+    - ✅ [Content Moderation with LLM Guard, LlamaGuard, Secret Detection, Google Text Moderations](#content-moderation)
+    - ✅ [Prompt Injection Detection (with LakeraAI API)](#prompt-injection-detection---lakeraai)
+    - ✅ Reject calls from Blocked User list
+    - ✅ Reject calls (incoming / outgoing) with Banned Keywords (e.g. competitors)
+- **Custom Branding**
+    - ✅ [Custom Branding + Routes on Swagger Docs](#swagger-docs---custom-routes--branding)
+    - ✅ [Public Model Hub](../docs/proxy/enterprise.md#public-model-hub)
+    - ✅ [Custom Email Branding](../docs/proxy/email.md#customizing-email-branding)
 
 ## Audit Logs
@@ -1020,3 +1033,34 @@ curl --location 'http://0.0.0.0:4000/chat/completions' \
 Share a public page of available models for users
 
 <Image img={require('../../img/model_hub.png')} style={{ width: '900px', height: 'auto' }}/>
+
+
+## [BETA] AWS Key Manager - Key Decryption
+
+This is a beta feature, and subject to changes.
+
+
+**Step 1.** Add `USE_AWS_KMS` to env
+
+```env
+USE_AWS_KMS="True"
+```
+
+**Step 2.** Add `aws_kms/` to encrypted keys in env
+
+```env
+DATABASE_URL="aws_kms/AQICAH.."
+```
+
+**Step 3.** Start proxy
+
+```
+$ litellm
+```
+
+How it works?
+- Key Decryption runs before server starts up. [**Code**](https://github.com/BerriAI/litellm/blob/8571cb45e80cc561dc34bc6aa89611eb96b9fe3e/litellm/proxy/proxy_cli.py#L445)
+- It adds the decrypted value to the `os.environ` for the python process.
+
+**Note:** Setting an environment variable within a Python script using os.environ will not make that variable accessible via SSH sessions or any other new processes that are started independently of the Python script. Environment variables set this way only affect the current process and its child processes.
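Conceptually, the decryption step looks roughly like the sketch below. This is an illustration of the boto3 KMS call involved, not the proxy's actual code path; see the linked `proxy_cli.py` code for the real implementation, and treat the base64 handling here as an assumption:

```python
# Illustrative sketch: take an aws_kms/-prefixed env var, decrypt the ciphertext
# with AWS KMS, and export the plaintext so only this process and its children see it.
import base64
import os

import boto3

kms = boto3.client("kms", region_name=os.environ.get("AWS_REGION_NAME", "us-west-2"))

raw = os.environ["DATABASE_URL"]
if raw.startswith("aws_kms/"):
    ciphertext = base64.b64decode(raw[len("aws_kms/"):])          # assumed encoding
    plaintext = kms.decrypt(CiphertextBlob=ciphertext)["Plaintext"].decode("utf-8")
    os.environ["DATABASE_URL"] = plaintext                        # visible to this process only
```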
@@ -1188,6 +1188,7 @@ litellm_settings:
     s3_region_name: us-west-2 # AWS Region Name for S3
     s3_aws_access_key_id: os.environ/AWS_ACCESS_KEY_ID # us os.environ/<variable name> to pass environment variables. This is AWS Access Key ID for S3
     s3_aws_secret_access_key: os.environ/AWS_SECRET_ACCESS_KEY # AWS Secret Access Key for S3
+    s3_path: my-test-path # [OPTIONAL] set path in bucket you want to write logs to
     s3_endpoint_url: https://s3.amazonaws.com # [OPTIONAL] S3 endpoint URL, if you want to use Backblaze/cloudflare s3 buckets
 ```
docs/my-website/docs/proxy/pass_through.md (new file, 220 lines)
@@ -0,0 +1,220 @@
import Image from '@theme/IdealImage';

# ➡️ Create Pass Through Endpoints

Add pass through routes to LiteLLM Proxy

**Example:** Add a route `/v1/rerank` that forwards requests to `https://api.cohere.com/v1/rerank` through LiteLLM Proxy


💡 This allows making the following Request to LiteLLM Proxy
```shell
curl --request POST \
  --url http://localhost:4000/v1/rerank \
  --header 'accept: application/json' \
  --header 'content-type: application/json' \
  --data '{
    "model": "rerank-english-v3.0",
    "query": "What is the capital of the United States?",
    "top_n": 3,
    "documents": ["Carson City is the capital city of the American state of Nevada."]
  }'
```

## Tutorial - Pass through Cohere Re-Rank Endpoint

**Step 1** Define pass through routes on [litellm config.yaml](configs.md)

```yaml
general_settings:
  master_key: sk-1234
  pass_through_endpoints:
    - path: "/v1/rerank"                                  # route you want to add to LiteLLM Proxy Server
      target: "https://api.cohere.com/v1/rerank"          # URL this route should forward requests to
      headers:                                            # headers to forward to this URL
        Authorization: "bearer os.environ/COHERE_API_KEY" # (Optional) Auth Header to forward to your Endpoint
        content-type: application/json                    # (Optional) Extra Headers to pass to this endpoint
        accept: application/json
```

**Step 2** Start Proxy Server in detailed_debug mode

```shell
litellm --config config.yaml --detailed_debug
```
**Step 3** Make Request to pass through endpoint

Here `http://localhost:4000` is your litellm proxy endpoint

```shell
curl --request POST \
  --url http://localhost:4000/v1/rerank \
  --header 'accept: application/json' \
  --header 'content-type: application/json' \
  --data '{
    "model": "rerank-english-v3.0",
    "query": "What is the capital of the United States?",
    "top_n": 3,
    "documents": ["Carson City is the capital city of the American state of Nevada.",
      "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.",
      "Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district.",
      "Capitalization or capitalisation in English grammar is the use of a capital letter at the start of a word. English usage varies from capitalization in other languages.",
      "Capital punishment (the death penalty) has existed in the United States since beforethe United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states."]
  }'
```


🎉 **Expected Response**

This request got forwarded from LiteLLM Proxy -> Defined Target URL (with headers)

```shell
{
  "id": "37103a5b-8cfb-48d3-87c7-da288bedd429",
  "results": [
    {
      "index": 2,
      "relevance_score": 0.999071
    },
    {
      "index": 4,
      "relevance_score": 0.7867867
    },
    {
      "index": 0,
      "relevance_score": 0.32713068
    }
  ],
  "meta": {
    "api_version": {
      "version": "1"
    },
    "billed_units": {
      "search_units": 1
    }
  }
}
```

## Tutorial - Pass Through Langfuse Requests


**Step 1** Define pass through routes on [litellm config.yaml](configs.md)

```yaml
general_settings:
  master_key: sk-1234
  pass_through_endpoints:
    - path: "/api/public/ingestion"                                # route you want to add to LiteLLM Proxy Server
      target: "https://us.cloud.langfuse.com/api/public/ingestion" # URL this route should forward
      headers:
        LANGFUSE_PUBLIC_KEY: "os.environ/LANGFUSE_DEV_PUBLIC_KEY"  # your langfuse account public key
        LANGFUSE_SECRET_KEY: "os.environ/LANGFUSE_DEV_SK_KEY"      # your langfuse account secret key
```

**Step 2** Start Proxy Server in detailed_debug mode

```shell
litellm --config config.yaml --detailed_debug
```
**Step 3** Make Request to pass through endpoint

Run this code to make a sample trace
```python
from langfuse import Langfuse

langfuse = Langfuse(
    host="http://localhost:4000",  # your litellm proxy endpoint
    public_key="anything",         # no key required since this is a pass through
    secret_key="anything",         # no key required since this is a pass through
)

print("sending langfuse trace request")
trace = langfuse.trace(name="test-trace-litellm-proxy-passthrough")
print("flushing langfuse request")
langfuse.flush()

print("flushed langfuse request")
```


🎉 **Expected Response**

On success
Expect to see the following Trace Generated on your Langfuse Dashboard

<Image img={require('../../img/proxy_langfuse.png')} />

You will see the following endpoint called on your litellm proxy server logs

```shell
POST /api/public/ingestion HTTP/1.1" 207 Multi-Status
```


## ✨ [Enterprise] - Use LiteLLM keys/authentication on Pass Through Endpoints

Use this if you want the pass through endpoint to honour LiteLLM keys/authentication

Usage - set `auth: true` on the config
```yaml
general_settings:
  master_key: sk-1234
  pass_through_endpoints:
    - path: "/v1/rerank"
      target: "https://api.cohere.com/v1/rerank"
      auth: true # 👈 Key change to use LiteLLM Auth / Keys
      headers:
        Authorization: "bearer os.environ/COHERE_API_KEY"
        content-type: application/json
        accept: application/json
```

Test Request with LiteLLM Key

```shell
curl --request POST \
  --url http://localhost:4000/v1/rerank \
  --header 'accept: application/json' \
  --header 'Authorization: Bearer sk-1234'\
  --header 'content-type: application/json' \
  --data '{
    "model": "rerank-english-v3.0",
    "query": "What is the capital of the United States?",
    "top_n": 3,
    "documents": ["Carson City is the capital city of the American state of Nevada.",
      "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.",
      "Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district.",
      "Capitalization or capitalisation in English grammar is the use of a capital letter at the start of a word. English usage varies from capitalization in other languages.",
      "Capital punishment (the death penalty) has existed in the United States since beforethe United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states."]
  }'
```

## `pass_through_endpoints` Spec on config.yaml

All possible values for `pass_through_endpoints` and what they mean

**Example config**
```yaml
general_settings:
  pass_through_endpoints:
    - path: "/v1/rerank"                                  # route you want to add to LiteLLM Proxy Server
      target: "https://api.cohere.com/v1/rerank"          # URL this route should forward requests to
      headers:                                            # headers to forward to this URL
        Authorization: "bearer os.environ/COHERE_API_KEY" # (Optional) Auth Header to forward to your Endpoint
        content-type: application/json                    # (Optional) Extra Headers to pass to this endpoint
        accept: application/json
```

**Spec**

* `pass_through_endpoints` *list*: A collection of endpoint configurations for request forwarding.
  * `path` *string*: The route to be added to the LiteLLM Proxy Server.
  * `target` *string*: The URL to which requests for this path should be forwarded.
  * `headers` *object*: Key-value pairs of headers to be forwarded with the request. You can set any key value pair here and it will be forwarded to your target endpoint
    * `Authorization` *string*: The authentication header for the target API.
    * `content-type` *string*: The format specification for the request body.
    * `accept` *string*: The expected response format from the server.
    * `LANGFUSE_PUBLIC_KEY` *string*: Your Langfuse account public key - only set this when forwarding to Langfuse.
    * `LANGFUSE_SECRET_KEY` *string*: Your Langfuse account secret key - only set this when forwarding to Langfuse.
    * `<your-custom-header>` *string*: Pass any custom header key/value pair
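A small Python sketch of the same authenticated pass-through call, assuming the `/v1/rerank` route with `auth: true` and the `sk-1234` key configured above:

```python
# Minimal sketch: call the LiteLLM pass-through /v1/rerank route with a LiteLLM key.
import requests

resp = requests.post(
    "http://localhost:4000/v1/rerank",
    headers={
        "Authorization": "Bearer sk-1234",  # LiteLLM key; the Cohere key stays on the proxy
        "content-type": "application/json",
        "accept": "application/json",
    },
    json={
        "model": "rerank-english-v3.0",
        "query": "What is the capital of the United States?",
        "top_n": 3,
        "documents": ["Carson City is the capital city of the American state of Nevada."],
    },
)
print(resp.status_code, resp.json())
```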
@@ -1,3 +1,6 @@
+import Tabs from '@theme/Tabs';
+import TabItem from '@theme/TabItem';
+
 # 📈 Prometheus metrics [BETA]
 
 LiteLLM Exposes a `/metrics` endpoint for Prometheus to Poll

@@ -61,6 +64,56 @@ http://localhost:4000/metrics
 | `litellm_remaining_api_key_budget_metric` | Remaining Budget for API Key (A key Created on LiteLLM)|
 
+
+### ✨ (Enterprise) LLM Remaining Requests and Remaining Tokens
+Set this on your config.yaml to allow you to track how close you are to hitting your TPM / RPM limits on each model group
+
+```yaml
+litellm_settings:
+  success_callback: ["prometheus"]
+  failure_callback: ["prometheus"]
+  return_response_headers: true # ensures the LLM API calls track the response headers
+```
+
+| Metric Name | Description |
+|----------------------|--------------------------------------|
+| `litellm_remaining_requests_metric` | Track `x-ratelimit-remaining-requests` returned from LLM API Deployment |
+| `litellm_remaining_tokens` | Track `x-ratelimit-remaining-tokens` return from LLM API Deployment |
+
+Example Metric
+<Tabs>
+
+<TabItem value="Remaining Requests" label="Remaining Requests">
+
+```shell
+litellm_remaining_requests
+{
+  api_base="https://api.openai.com/v1",
+  api_provider="openai",
+  litellm_model_name="gpt-3.5-turbo",
+  model_group="gpt-3.5-turbo"
+}
+8998.0
+```
+
+</TabItem>
+
+<TabItem value="Requests" label="Remaining Tokens">
+
+```shell
+litellm_remaining_tokens
+{
+  api_base="https://api.openai.com/v1",
+  api_provider="openai",
+  litellm_model_name="gpt-3.5-turbo",
+  model_group="gpt-3.5-turbo"
+}
+999981.0
+```
+
+</TabItem>
+
+</Tabs>
+
 ## Monitor System Health
 
 To monitor the health of litellm adjacent services (redis / postgres), do:
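A quick sketch of checking the new gauges without a full Prometheus setup; it simply scrapes the proxy's `/metrics` endpoint over HTTP and filters the relevant lines, assuming a proxy on `localhost:4000`:

```python
# Minimal sketch: scrape /metrics and print the enterprise remaining-requests /
# remaining-tokens gauges exposed after the config above is applied.
import requests

metrics_text = requests.get("http://localhost:4000/metrics").text

for line in metrics_text.splitlines():
    if line.startswith(("litellm_remaining_requests", "litellm_remaining_tokens")):
        print(line)
```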
@@ -815,6 +815,35 @@ model_list:
 </TabItem>
 </Tabs>
 
+**Expected Response**
+
+```
+No deployments available for selected model, Try again in 60 seconds. Passed model=claude-3-5-sonnet. pre-call-checks=False, allowed_model_region=n/a.
+```
+
+#### **Disable cooldowns**
+
+<Tabs>
+<TabItem value="sdk" label="SDK">
+
+```python
+from litellm import Router
+
+
+router = Router(..., disable_cooldowns=True)
+```
+</TabItem>
+<TabItem value="proxy" label="PROXY">
+
+```yaml
+router_settings:
+  disable_cooldowns: True
+```
+
+</TabItem>
+</Tabs>
+
 ### Retries
 
 For both async + sync functions, we support retrying failed requests.
@@ -8,7 +8,13 @@ LiteLLM supports reading secrets from Azure Key Vault and Infisical
 - [Infisical Secret Manager](#infisical-secret-manager)
 - [.env Files](#env-files)
 
-## AWS Key Management Service
+## AWS Key Management V1
+
+:::tip
+
+[BETA] AWS Key Management v2 is on the enterprise tier. Go [here for docs](./proxy/enterprise.md#beta-aws-key-manager---key-decryption)
+
+:::
 
 Use AWS KMS to storing a hashed copy of your Proxy Master Key in the environment.
BIN  docs/my-website/img/proxy_langfuse.png: new binary file (212 KiB), not shown
BIN  (modified image): binary file changed (142 KiB before, 206 KiB after), not shown
@@ -48,6 +48,7 @@ const sidebars = {
         "proxy/billing",
         "proxy/user_keys",
         "proxy/virtual_keys",
+        "proxy/token_auth",
         "proxy/alerting",
         {
           type: "category",
@@ -56,11 +57,11 @@ const sidebars = {
         },
         "proxy/ui",
         "proxy/prometheus",
+        "proxy/pass_through",
         "proxy/email",
         "proxy/multiple_admins",
         "proxy/team_based_routing",
         "proxy/customer_routing",
-        "proxy/token_auth",
         {
           type: "category",
           label: "Extra Load Balancing",
@@ -114,7 +114,11 @@ class _ENTERPRISE_lakeraAI_Moderation(CustomLogger):
 
         if flagged == True:
             raise HTTPException(
-                status_code=400, detail={"error": "Violated content safety policy"}
+                status_code=400,
+                detail={
+                    "error": "Violated content safety policy",
+                    "lakera_ai_response": _json_response,
+                },
             )
 
         pass
@@ -1,48 +1,13 @@
-#!/bin/sh
+#!/bin/bash
+echo $(pwd)
 
-# Check if DATABASE_URL is not set
-if [ -z "$DATABASE_URL" ]; then
-    # Check if all required variables are provided
-    if [ -n "$DATABASE_HOST" ] && [ -n "$DATABASE_USERNAME" ] && [ -n "$DATABASE_PASSWORD" ] && [ -n "$DATABASE_NAME" ]; then
-        # Construct DATABASE_URL from the provided variables
-        DATABASE_URL="postgresql://${DATABASE_USERNAME}:${DATABASE_PASSWORD}@${DATABASE_HOST}/${DATABASE_NAME}"
-        export DATABASE_URL
-    else
-        echo "Error: Required database environment variables are not set. Provide a postgres url for DATABASE_URL."
-        exit 1
-    fi
-fi
+# Run the Python migration script
+python3 litellm/proxy/prisma_migration.py
 
-# Set DIRECT_URL to the value of DATABASE_URL if it is not set, required for migrations
-if [ -z "$DIRECT_URL" ]; then
-    export DIRECT_URL=$DATABASE_URL
-fi
-
-# Apply migrations
-retry_count=0
-max_retries=3
-exit_code=1
-
-until [ $retry_count -ge $max_retries ] || [ $exit_code -eq 0 ]
-do
-    retry_count=$((retry_count+1))
-    echo "Attempt $retry_count..."
-
-    # Run the Prisma db push command
-    prisma db push --accept-data-loss
-
-    exit_code=$?
-
-    if [ $exit_code -ne 0 ] && [ $retry_count -lt $max_retries ]; then
-        echo "Retrying in 10 seconds..."
-        sleep 10
-    fi
-done
-
-if [ $exit_code -ne 0 ]; then
-    echo "Unable to push database changes after $max_retries retries."
+# Check if the Python script executed successfully
+if [ $? -eq 0 ]; then
+    echo "Migration script ran successfully!"
+else
+    echo "Migration script failed!"
     exit 1
 fi
-
-echo "Database push successful!"
@@ -125,6 +125,9 @@ llm_guard_mode: Literal["all", "key-specific", "request-specific"] = "all"
 ##################
 ### PREVIEW FEATURES ###
 enable_preview_features: bool = False
+return_response_headers: bool = (
+    False  # get response headers from LLM Api providers - example x-remaining-requests,
+)
 ##################
 logging: bool = True
 caching: bool = (
@@ -749,6 +752,7 @@ from .utils import (
     create_pretrained_tokenizer,
     create_tokenizer,
     supports_function_calling,
+    supports_response_schema,
     supports_parallel_function_calling,
     supports_vision,
     supports_system_messages,
@@ -799,7 +803,11 @@ from .llms.sagemaker import SagemakerConfig
 from .llms.ollama import OllamaConfig
 from .llms.ollama_chat import OllamaChatConfig
 from .llms.maritalk import MaritTalkConfig
-from .llms.bedrock_httpx import AmazonCohereChatConfig, AmazonConverseConfig
+from .llms.bedrock_httpx import (
+    AmazonCohereChatConfig,
+    AmazonConverseConfig,
+    BEDROCK_CONVERSE_MODELS,
+)
 from .llms.bedrock import (
     AmazonTitanConfig,
     AmazonAI21Config,
@@ -848,6 +856,7 @@ from .exceptions import (
     APIResponseValidationError,
     UnprocessableEntityError,
     InternalServerError,
+    JSONSchemaValidationError,
     LITELLM_EXCEPTION_TYPES,
 )
 from .budget_manager import BudgetManager
@@ -1,6 +1,7 @@
 # What is this?
 ## File for 'response_cost' calculation in Logging
 import time
+import traceback
 from typing import List, Literal, Optional, Tuple, Union
 
 import litellm
@@ -668,3 +669,10 @@ def response_cost_calculator(
                 f"Model={model} for LLM Provider={custom_llm_provider} not found in completion cost map."
             )
             return None
+    except Exception as e:
+        verbose_logger.error(
+            "litellm.cost_calculator.py::response_cost_calculator - Exception occurred - {}/n{}".format(
+                str(e), traceback.format_exc()
+            )
+        )
+        return None
@@ -551,7 +551,7 @@ class APIError(openai.APIError):  # type: ignore
         message,
         llm_provider,
         model,
-        request: httpx.Request,
+        request: Optional[httpx.Request] = None,
         litellm_debug_info: Optional[str] = None,
         max_retries: Optional[int] = None,
         num_retries: Optional[int] = None,
@@ -563,6 +563,8 @@ class APIError(openai.APIError):  # type: ignore
         self.litellm_debug_info = litellm_debug_info
         self.max_retries = max_retries
         self.num_retries = num_retries
+        if request is None:
+            request = httpx.Request(method="POST", url="https://api.openai.com/v1")
         super().__init__(self.message, request=request, body=None)  # type: ignore
 
     def __str__(self):
@@ -664,6 +666,22 @@ class OpenAIError(openai.OpenAIError):  # type: ignore
         self.llm_provider = "openai"
 
 
+class JSONSchemaValidationError(APIError):
+    def __init__(
+        self, model: str, llm_provider: str, raw_response: str, schema: str
+    ) -> None:
+        self.raw_response = raw_response
+        self.schema = schema
+        self.model = model
+        message = "litellm.JSONSchemaValidationError: model={}, returned an invalid response={}, for schema={}.\nAccess raw response with `e.raw_response`".format(
+            model, raw_response, schema
+        )
+        self.message = message
+        super().__init__(
+            model=model, message=message, llm_provider=llm_provider, status_code=500
+        )
+
+
 LITELLM_EXCEPTION_TYPES = [
     AuthenticationError,
     NotFoundError,
@@ -682,6 +700,7 @@ LITELLM_EXCEPTION_TYPES = [
     APIResponseValidationError,
     OpenAIError,
     InternalServerError,
+    JSONSchemaValidationError,
 ]
@@ -311,11 +311,6 @@ class LangFuseLogger:
 
         try:
             tags = []
-            try:
-                metadata = copy.deepcopy(
-                    metadata
-                )  # Avoid modifying the original metadata
-            except:
             new_metadata = {}
             for key, value in metadata.items():
                 if (
litellm/integrations/prometheus.py
@@ -2,14 +2,20 @@
 #### What this does ####
 # On success, log events to Prometheus

-import dotenv, os
-import requests  # type: ignore
+import datetime
+import os
+import subprocess
+import sys
 import traceback
-import datetime, subprocess, sys
-import litellm, uuid
-from litellm._logging import print_verbose, verbose_logger
+import uuid
 from typing import Optional, Union

+import dotenv
+import requests  # type: ignore
+
+import litellm
+from litellm._logging import print_verbose, verbose_logger
+

 class PrometheusLogger:
     # Class variables or attributes
@@ -20,6 +26,8 @@ class PrometheusLogger:
         try:
             from prometheus_client import Counter, Gauge

+            from litellm.proxy.proxy_server import premium_user
+
             self.litellm_llm_api_failed_requests_metric = Counter(
                 name="litellm_llm_api_failed_requests_metric",
                 documentation="Total number of failed LLM API calls via litellm",
@@ -88,6 +96,31 @@ class PrometheusLogger:
                 labelnames=["hashed_api_key", "api_key_alias"],
             )

+            # Litellm-Enterprise Metrics
+            if premium_user is True:
+                # Remaining Rate Limit for model
+                self.litellm_remaining_requests_metric = Gauge(
+                    "litellm_remaining_requests",
+                    "remaining requests for model, returned from LLM API Provider",
+                    labelnames=[
+                        "model_group",
+                        "api_provider",
+                        "api_base",
+                        "litellm_model_name",
+                    ],
+                )
+
+                self.litellm_remaining_tokens_metric = Gauge(
+                    "litellm_remaining_tokens",
+                    "remaining tokens for model, returned from LLM API Provider",
+                    labelnames=[
+                        "model_group",
+                        "api_provider",
+                        "api_base",
+                        "litellm_model_name",
+                    ],
+                )
+
         except Exception as e:
             print_verbose(f"Got exception on init prometheus client {str(e)}")
             raise e
@@ -104,6 +137,8 @@ class PrometheusLogger:
     ):
         try:
             # Define prometheus client
+            from litellm.proxy.proxy_server import premium_user
+
             verbose_logger.debug(
                 f"prometheus Logging - Enters logging function for model {kwargs}"
             )
@@ -199,6 +234,10 @@ class PrometheusLogger:
                 user_api_key, user_api_key_alias
             ).set(_remaining_api_key_budget)

+            # set x-ratelimit headers
+            if premium_user is True:
+                self.set_remaining_tokens_requests_metric(kwargs)
+
             ### FAILURE INCREMENT ###
             if "exception" in kwargs:
                 self.litellm_llm_api_failed_requests_metric.labels(
@@ -216,6 +255,58 @@ class PrometheusLogger:
             verbose_logger.debug(traceback.format_exc())
             pass

+    def set_remaining_tokens_requests_metric(self, request_kwargs: dict):
+        try:
+            verbose_logger.debug("setting remaining tokens requests metric")
+            _response_headers = request_kwargs.get("response_headers")
+            _litellm_params = request_kwargs.get("litellm_params", {}) or {}
+            _metadata = _litellm_params.get("metadata", {})
+            litellm_model_name = request_kwargs.get("model", None)
+            model_group = _metadata.get("model_group", None)
+            api_base = _metadata.get("api_base", None)
+            llm_provider = _litellm_params.get("custom_llm_provider", None)
+
+            remaining_requests = None
+            remaining_tokens = None
+            # OpenAI / OpenAI Compatible headers
+            if (
+                _response_headers
+                and "x-ratelimit-remaining-requests" in _response_headers
+            ):
+                remaining_requests = _response_headers["x-ratelimit-remaining-requests"]
+            if (
+                _response_headers
+                and "x-ratelimit-remaining-tokens" in _response_headers
+            ):
+                remaining_tokens = _response_headers["x-ratelimit-remaining-tokens"]
+            verbose_logger.debug(
+                f"remaining requests: {remaining_requests}, remaining tokens: {remaining_tokens}"
+            )
+
+            if remaining_requests:
+                """
+                "model_group",
+                "api_provider",
+                "api_base",
+                "litellm_model_name"
+                """
+                self.litellm_remaining_requests_metric.labels(
+                    model_group, llm_provider, api_base, litellm_model_name
+                ).set(remaining_requests)
+
+            if remaining_tokens:
+                self.litellm_remaining_tokens_metric.labels(
+                    model_group, llm_provider, api_base, litellm_model_name
+                ).set(remaining_tokens)
+
+        except Exception as e:
+            verbose_logger.error(
+                "Prometheus Error: set_remaining_tokens_requests_metric. Exception occured - {}".format(
+                    str(e)
+                )
+            )
+            return
+
+
 def safe_get_remaining_budget(
     max_budget: Optional[float], spend: Optional[float]
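Not part of the commit — a hedged sketch of how these enterprise-only gauges get fed. The Prometheus logger is typically enabled on the proxy via `success_callback: ["prometheus"]` in the config; the in-process form below is only for illustration.

```python
# Hedged sketch: with the prometheus callback enabled and premium_user set, the logger
# reads the provider's x-ratelimit-remaining-requests / -tokens response headers and
# sets litellm_remaining_requests / litellm_remaining_tokens with labels
# (model_group, api_provider, api_base, litellm_model_name).
import litellm

litellm.success_callback = ["prometheus"]

litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hi"}],
)
```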
litellm/integrations/s3.py
@@ -1,10 +1,14 @@
 #### What this does ####
 # On success + failure, log events to Supabase

+import datetime
 import os
+import subprocess
+import sys
 import traceback
-import datetime, subprocess, sys
-import litellm, uuid
+import uuid
+
+import litellm
 from litellm._logging import print_verbose, verbose_logger

@@ -54,6 +58,7 @@ class S3Logger:
                     "s3_aws_session_token"
                 )
                 s3_config = litellm.s3_callback_params.get("s3_config")
+                s3_path = litellm.s3_callback_params.get("s3_path")
                 # done reading litellm.s3_callback_params

                 self.bucket_name = s3_bucket_name
litellm/litellm_core_utils/core_helpers.py
@@ -26,7 +26,7 @@ def map_finish_reason(
         finish_reason == "FINISH_REASON_UNSPECIFIED" or finish_reason == "STOP"
     ):  # vertex ai - got from running `print(dir(response_obj.candidates[0].finish_reason))`: ['FINISH_REASON_UNSPECIFIED', 'MAX_TOKENS', 'OTHER', 'RECITATION', 'SAFETY', 'STOP',]
         return "stop"
-    elif finish_reason == "SAFETY":  # vertex ai
+    elif finish_reason == "SAFETY" or finish_reason == "RECITATION":  # vertex ai
         return "content_filter"
     elif finish_reason == "STOP":  # vertex ai
         return "stop"
litellm/litellm_core_utils/exception_mapping_utils.py (new file, 40 lines)
@@ -0,0 +1,40 @@
+import json
+from typing import Optional
+
+
+def get_error_message(error_obj) -> Optional[str]:
+    """
+    OpenAI Returns Error message that is nested, this extract the message
+
+    Example:
+    {
+        'request': "<Request('POST', 'https://api.openai.com/v1/chat/completions')>",
+        'message': "Error code: 400 - {\'error\': {\'message\': \"Invalid 'temperature': decimal above maximum value. Expected a value <= 2, but got 200 instead.\", 'type': 'invalid_request_error', 'param': 'temperature', 'code': 'decimal_above_max_value'}}",
+        'body': {
+            'message': "Invalid 'temperature': decimal above maximum value. Expected a value <= 2, but got 200 instead.",
+            'type': 'invalid_request_error',
+            'param': 'temperature',
+            'code': 'decimal_above_max_value'
+        },
+        'code': 'decimal_above_max_value',
+        'param': 'temperature',
+        'type': 'invalid_request_error',
+        'response': "<Response [400 Bad Request]>",
+        'status_code': 400,
+        'request_id': 'req_f287898caa6364cd42bc01355f74dd2a'
+    }
+    """
+    try:
+        # First, try to access the message directly from the 'body' key
+        if error_obj is None:
+            return None
+
+        if hasattr(error_obj, "body"):
+            _error_obj_body = getattr(error_obj, "body")
+            if isinstance(_error_obj_body, dict):
+                return _error_obj_body.get("message")
+
+        # If all else fails, return None
+        return None
+    except Exception as e:
+        return None
litellm/litellm_core_utils/json_validation_rule.py (new file, 23 lines)
@@ -0,0 +1,23 @@
+import json
+
+
+def validate_schema(schema: dict, response: str):
+    """
+    Validate if the returned json response follows the schema.
+
+    Params:
+    - schema - dict: JSON schema
+    - response - str: Received json response as string.
+    """
+    from jsonschema import ValidationError, validate
+
+    from litellm import JSONSchemaValidationError
+
+    response_dict = json.loads(response)
+
+    try:
+        validate(response_dict, schema=schema)
+    except ValidationError:
+        raise JSONSchemaValidationError(
+            model="", llm_provider="", raw_response=response, schema=json.dumps(schema)
+        )
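For illustration only (not part of the commit): a usage sketch for the new helper. It assumes the `jsonschema` package is installed, which the CI changes in this commit add.

```python
from litellm import JSONSchemaValidationError
from litellm.litellm_core_utils.json_validation_rule import validate_schema

schema = {
    "type": "object",
    "properties": {"name": {"type": "string"}},
    "required": ["name"],
}

validate_schema(schema=schema, response='{"name": "litellm"}')  # passes silently

try:
    validate_schema(schema=schema, response='{"age": 3}')  # missing required "name"
except JSONSchemaValidationError as e:
    print(e.schema)        # schema serialized as a JSON string
    print(e.raw_response)  # the response that failed validation
```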
litellm/llms/anthropic.py
@@ -1,23 +1,28 @@
-import os, types
+import copy
 import json
-from enum import Enum
-import requests, copy  # type: ignore
+import os
 import time
+import types
+from enum import Enum
 from functools import partial
-from typing import Callable, Optional, List, Union
-import litellm.litellm_core_utils
-from litellm.utils import ModelResponse, Usage, CustomStreamWrapper
-from litellm.litellm_core_utils.core_helpers import map_finish_reason
+from typing import Callable, List, Optional, Union
+
+import httpx  # type: ignore
+import requests  # type: ignore
+
 import litellm
-from .prompt_templates.factory import prompt_factory, custom_prompt
+import litellm.litellm_core_utils
+from litellm.litellm_core_utils.core_helpers import map_finish_reason
 from litellm.llms.custom_httpx.http_handler import (
     AsyncHTTPHandler,
     _get_async_httpx_client,
     _get_httpx_client,
 )
-from .base import BaseLLM
-import httpx  # type: ignore
 from litellm.types.llms.anthropic import AnthropicMessagesToolChoice
+from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
+
+from .base import BaseLLM
+from .prompt_templates.factory import custom_prompt, prompt_factory


 class AnthropicConstants(Enum):
@@ -179,10 +184,19 @@ async def make_call(
     if client is None:
         client = _get_async_httpx_client()  # Create a new client if none provided

-    response = await client.post(api_base, headers=headers, data=data, stream=True)
+    try:
+        response = await client.post(api_base, headers=headers, data=data, stream=True)
+    except httpx.HTTPStatusError as e:
+        raise AnthropicError(
+            status_code=e.response.status_code, message=await e.response.aread()
+        )
+    except Exception as e:
+        raise AnthropicError(status_code=500, message=str(e))

     if response.status_code != 200:
-        raise AnthropicError(status_code=response.status_code, message=response.text)
+        raise AnthropicError(
+            status_code=response.status_code, message=await response.aread()
+        )

     completion_stream = response.aiter_lines()
@ -23,6 +23,7 @@ from typing_extensions import overload
|
||||||
import litellm
|
import litellm
|
||||||
from litellm import OpenAIConfig
|
from litellm import OpenAIConfig
|
||||||
from litellm.caching import DualCache
|
from litellm.caching import DualCache
|
||||||
|
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||||
from litellm.utils import (
|
from litellm.utils import (
|
||||||
Choices,
|
Choices,
|
||||||
CustomStreamWrapper,
|
CustomStreamWrapper,
|
||||||
|
@ -458,6 +459,36 @@ class AzureChatCompletion(BaseLLM):
|
||||||
|
|
||||||
return azure_client
|
return azure_client
|
||||||
|
|
||||||
|
async def make_azure_openai_chat_completion_request(
|
||||||
|
self,
|
||||||
|
azure_client: AsyncAzureOpenAI,
|
||||||
|
data: dict,
|
||||||
|
timeout: Union[float, httpx.Timeout],
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Helper to:
|
||||||
|
- call chat.completions.create.with_raw_response when litellm.return_response_headers is True
|
||||||
|
- call chat.completions.create by default
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if litellm.return_response_headers is True:
|
||||||
|
raw_response = (
|
||||||
|
await azure_client.chat.completions.with_raw_response.create(
|
||||||
|
**data, timeout=timeout
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
headers = dict(raw_response.headers)
|
||||||
|
response = raw_response.parse()
|
||||||
|
return headers, response
|
||||||
|
else:
|
||||||
|
response = await azure_client.chat.completions.create(
|
||||||
|
**data, timeout=timeout
|
||||||
|
)
|
||||||
|
return None, response
|
||||||
|
except Exception as e:
|
||||||
|
raise e
|
||||||
|
|
||||||
def completion(
|
def completion(
|
||||||
self,
|
self,
|
||||||
model: str,
|
model: str,
|
||||||
|
@ -470,7 +501,7 @@ class AzureChatCompletion(BaseLLM):
|
||||||
azure_ad_token: str,
|
azure_ad_token: str,
|
||||||
print_verbose: Callable,
|
print_verbose: Callable,
|
||||||
timeout: Union[float, httpx.Timeout],
|
timeout: Union[float, httpx.Timeout],
|
||||||
logging_obj,
|
logging_obj: LiteLLMLoggingObj,
|
||||||
optional_params,
|
optional_params,
|
||||||
litellm_params,
|
litellm_params,
|
||||||
logger_fn,
|
logger_fn,
|
||||||
|
@ -649,9 +680,9 @@ class AzureChatCompletion(BaseLLM):
|
||||||
data: dict,
|
data: dict,
|
||||||
timeout: Any,
|
timeout: Any,
|
||||||
model_response: ModelResponse,
|
model_response: ModelResponse,
|
||||||
|
logging_obj: LiteLLMLoggingObj,
|
||||||
azure_ad_token: Optional[str] = None,
|
azure_ad_token: Optional[str] = None,
|
||||||
client=None, # this is the AsyncAzureOpenAI
|
client=None, # this is the AsyncAzureOpenAI
|
||||||
logging_obj=None,
|
|
||||||
):
|
):
|
||||||
response = None
|
response = None
|
||||||
try:
|
try:
|
||||||
|
@ -701,9 +732,13 @@ class AzureChatCompletion(BaseLLM):
|
||||||
"complete_input_dict": data,
|
"complete_input_dict": data,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
response = await azure_client.chat.completions.create(
|
|
||||||
**data, timeout=timeout
|
headers, response = await self.make_azure_openai_chat_completion_request(
|
||||||
|
azure_client=azure_client,
|
||||||
|
data=data,
|
||||||
|
timeout=timeout,
|
||||||
)
|
)
|
||||||
|
logging_obj.model_call_details["response_headers"] = headers
|
||||||
|
|
||||||
stringified_response = response.model_dump()
|
stringified_response = response.model_dump()
|
||||||
logging_obj.post_call(
|
logging_obj.post_call(
|
||||||
|
@ -717,11 +752,32 @@ class AzureChatCompletion(BaseLLM):
|
||||||
model_response_object=model_response,
|
model_response_object=model_response,
|
||||||
)
|
)
|
||||||
except AzureOpenAIError as e:
|
except AzureOpenAIError as e:
|
||||||
|
## LOGGING
|
||||||
|
logging_obj.post_call(
|
||||||
|
input=data["messages"],
|
||||||
|
api_key=api_key,
|
||||||
|
additional_args={"complete_input_dict": data},
|
||||||
|
original_response=str(e),
|
||||||
|
)
|
||||||
exception_mapping_worked = True
|
exception_mapping_worked = True
|
||||||
raise e
|
raise e
|
||||||
except asyncio.CancelledError as e:
|
except asyncio.CancelledError as e:
|
||||||
|
## LOGGING
|
||||||
|
logging_obj.post_call(
|
||||||
|
input=data["messages"],
|
||||||
|
api_key=api_key,
|
||||||
|
additional_args={"complete_input_dict": data},
|
||||||
|
original_response=str(e),
|
||||||
|
)
|
||||||
raise AzureOpenAIError(status_code=500, message=str(e))
|
raise AzureOpenAIError(status_code=500, message=str(e))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
## LOGGING
|
||||||
|
logging_obj.post_call(
|
||||||
|
input=data["messages"],
|
||||||
|
api_key=api_key,
|
||||||
|
additional_args={"complete_input_dict": data},
|
||||||
|
original_response=str(e),
|
||||||
|
)
|
||||||
if hasattr(e, "status_code"):
|
if hasattr(e, "status_code"):
|
||||||
raise e
|
raise e
|
||||||
else:
|
else:
|
||||||
|
@ -791,7 +847,7 @@ class AzureChatCompletion(BaseLLM):
|
||||||
|
|
||||||
async def async_streaming(
|
async def async_streaming(
|
||||||
self,
|
self,
|
||||||
logging_obj,
|
logging_obj: LiteLLMLoggingObj,
|
||||||
api_base: str,
|
api_base: str,
|
||||||
api_key: str,
|
api_key: str,
|
||||||
api_version: str,
|
api_version: str,
|
||||||
|
@ -840,9 +896,14 @@ class AzureChatCompletion(BaseLLM):
|
||||||
"complete_input_dict": data,
|
"complete_input_dict": data,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
response = await azure_client.chat.completions.create(
|
|
||||||
**data, timeout=timeout
|
headers, response = await self.make_azure_openai_chat_completion_request(
|
||||||
|
azure_client=azure_client,
|
||||||
|
data=data,
|
||||||
|
timeout=timeout,
|
||||||
)
|
)
|
||||||
|
logging_obj.model_call_details["response_headers"] = headers
|
||||||
|
|
||||||
# return response
|
# return response
|
||||||
streamwrapper = CustomStreamWrapper(
|
streamwrapper = CustomStreamWrapper(
|
||||||
completion_stream=response,
|
completion_stream=response,
|
||||||
|
|
|
@ -60,6 +60,17 @@ from .prompt_templates.factory import (
|
||||||
prompt_factory,
|
prompt_factory,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
BEDROCK_CONVERSE_MODELS = [
|
||||||
|
"anthropic.claude-3-5-sonnet-20240620-v1:0",
|
||||||
|
"anthropic.claude-3-opus-20240229-v1:0",
|
||||||
|
"anthropic.claude-3-sonnet-20240229-v1:0",
|
||||||
|
"anthropic.claude-3-haiku-20240307-v1:0",
|
||||||
|
"anthropic.claude-v2",
|
||||||
|
"anthropic.claude-v2:1",
|
||||||
|
"anthropic.claude-v1",
|
||||||
|
"anthropic.claude-instant-v1",
|
||||||
|
]
|
||||||
|
|
||||||
iam_cache = DualCache()
|
iam_cache = DualCache()
|
||||||
|
|
||||||
|
|
||||||
|
@ -305,6 +316,7 @@ class BedrockLLM(BaseLLM):
|
||||||
self,
|
self,
|
||||||
aws_access_key_id: Optional[str] = None,
|
aws_access_key_id: Optional[str] = None,
|
||||||
aws_secret_access_key: Optional[str] = None,
|
aws_secret_access_key: Optional[str] = None,
|
||||||
|
aws_session_token: Optional[str] = None,
|
||||||
aws_region_name: Optional[str] = None,
|
aws_region_name: Optional[str] = None,
|
||||||
aws_session_name: Optional[str] = None,
|
aws_session_name: Optional[str] = None,
|
||||||
aws_profile_name: Optional[str] = None,
|
aws_profile_name: Optional[str] = None,
|
||||||
|
@ -320,6 +332,7 @@ class BedrockLLM(BaseLLM):
|
||||||
params_to_check: List[Optional[str]] = [
|
params_to_check: List[Optional[str]] = [
|
||||||
aws_access_key_id,
|
aws_access_key_id,
|
||||||
aws_secret_access_key,
|
aws_secret_access_key,
|
||||||
|
aws_session_token,
|
||||||
aws_region_name,
|
aws_region_name,
|
||||||
aws_session_name,
|
aws_session_name,
|
||||||
aws_profile_name,
|
aws_profile_name,
|
||||||
|
@ -337,6 +350,7 @@ class BedrockLLM(BaseLLM):
|
||||||
(
|
(
|
||||||
aws_access_key_id,
|
aws_access_key_id,
|
||||||
aws_secret_access_key,
|
aws_secret_access_key,
|
||||||
|
aws_session_token,
|
||||||
aws_region_name,
|
aws_region_name,
|
||||||
aws_session_name,
|
aws_session_name,
|
||||||
aws_profile_name,
|
aws_profile_name,
|
||||||
|
@ -430,6 +444,19 @@ class BedrockLLM(BaseLLM):
|
||||||
client = boto3.Session(profile_name=aws_profile_name)
|
client = boto3.Session(profile_name=aws_profile_name)
|
||||||
|
|
||||||
return client.get_credentials()
|
return client.get_credentials()
|
||||||
|
elif (
|
||||||
|
aws_access_key_id is not None
|
||||||
|
and aws_secret_access_key is not None
|
||||||
|
and aws_session_token is not None
|
||||||
|
): ### CHECK FOR AWS SESSION TOKEN ###
|
||||||
|
from botocore.credentials import Credentials
|
||||||
|
|
||||||
|
credentials = Credentials(
|
||||||
|
access_key=aws_access_key_id,
|
||||||
|
secret_key=aws_secret_access_key,
|
||||||
|
token=aws_session_token,
|
||||||
|
)
|
||||||
|
return credentials
|
||||||
else:
|
else:
|
||||||
session = boto3.Session(
|
session = boto3.Session(
|
||||||
aws_access_key_id=aws_access_key_id,
|
aws_access_key_id=aws_access_key_id,
|
||||||
|
@ -734,9 +761,10 @@ class BedrockLLM(BaseLLM):
|
||||||
provider = model.split(".")[0]
|
provider = model.split(".")[0]
|
||||||
|
|
||||||
## CREDENTIALS ##
|
## CREDENTIALS ##
|
||||||
# pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
|
# pop aws_secret_access_key, aws_access_key_id, aws_session_token, aws_region_name from kwargs, since completion calls fail with them
|
||||||
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
|
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
|
||||||
aws_access_key_id = optional_params.pop("aws_access_key_id", None)
|
aws_access_key_id = optional_params.pop("aws_access_key_id", None)
|
||||||
|
aws_session_token = optional_params.pop("aws_session_token", None)
|
||||||
aws_region_name = optional_params.pop("aws_region_name", None)
|
aws_region_name = optional_params.pop("aws_region_name", None)
|
||||||
aws_role_name = optional_params.pop("aws_role_name", None)
|
aws_role_name = optional_params.pop("aws_role_name", None)
|
||||||
aws_session_name = optional_params.pop("aws_session_name", None)
|
aws_session_name = optional_params.pop("aws_session_name", None)
|
||||||
|
@ -768,6 +796,7 @@ class BedrockLLM(BaseLLM):
|
||||||
credentials: Credentials = self.get_credentials(
|
credentials: Credentials = self.get_credentials(
|
||||||
aws_access_key_id=aws_access_key_id,
|
aws_access_key_id=aws_access_key_id,
|
||||||
aws_secret_access_key=aws_secret_access_key,
|
aws_secret_access_key=aws_secret_access_key,
|
||||||
|
aws_session_token=aws_session_token,
|
||||||
aws_region_name=aws_region_name,
|
aws_region_name=aws_region_name,
|
||||||
aws_session_name=aws_session_name,
|
aws_session_name=aws_session_name,
|
||||||
aws_profile_name=aws_profile_name,
|
aws_profile_name=aws_profile_name,
|
||||||
|
@ -1422,6 +1451,7 @@ class BedrockConverseLLM(BaseLLM):
|
||||||
self,
|
self,
|
||||||
aws_access_key_id: Optional[str] = None,
|
aws_access_key_id: Optional[str] = None,
|
||||||
aws_secret_access_key: Optional[str] = None,
|
aws_secret_access_key: Optional[str] = None,
|
||||||
|
aws_session_token: Optional[str] = None,
|
||||||
aws_region_name: Optional[str] = None,
|
aws_region_name: Optional[str] = None,
|
||||||
aws_session_name: Optional[str] = None,
|
aws_session_name: Optional[str] = None,
|
||||||
aws_profile_name: Optional[str] = None,
|
aws_profile_name: Optional[str] = None,
|
||||||
|
@ -1437,6 +1467,7 @@ class BedrockConverseLLM(BaseLLM):
|
||||||
params_to_check: List[Optional[str]] = [
|
params_to_check: List[Optional[str]] = [
|
||||||
aws_access_key_id,
|
aws_access_key_id,
|
||||||
aws_secret_access_key,
|
aws_secret_access_key,
|
||||||
|
aws_session_token,
|
||||||
aws_region_name,
|
aws_region_name,
|
||||||
aws_session_name,
|
aws_session_name,
|
||||||
aws_profile_name,
|
aws_profile_name,
|
||||||
|
@ -1454,6 +1485,7 @@ class BedrockConverseLLM(BaseLLM):
|
||||||
(
|
(
|
||||||
aws_access_key_id,
|
aws_access_key_id,
|
||||||
aws_secret_access_key,
|
aws_secret_access_key,
|
||||||
|
aws_session_token,
|
||||||
aws_region_name,
|
aws_region_name,
|
||||||
aws_session_name,
|
aws_session_name,
|
||||||
aws_profile_name,
|
aws_profile_name,
|
||||||
|
@ -1547,6 +1579,19 @@ class BedrockConverseLLM(BaseLLM):
|
||||||
client = boto3.Session(profile_name=aws_profile_name)
|
client = boto3.Session(profile_name=aws_profile_name)
|
||||||
|
|
||||||
return client.get_credentials()
|
return client.get_credentials()
|
||||||
|
elif (
|
||||||
|
aws_access_key_id is not None
|
||||||
|
and aws_secret_access_key is not None
|
||||||
|
and aws_session_token is not None
|
||||||
|
): ### CHECK FOR AWS SESSION TOKEN ###
|
||||||
|
from botocore.credentials import Credentials
|
||||||
|
|
||||||
|
credentials = Credentials(
|
||||||
|
access_key=aws_access_key_id,
|
||||||
|
secret_key=aws_secret_access_key,
|
||||||
|
token=aws_session_token,
|
||||||
|
)
|
||||||
|
return credentials
|
||||||
else:
|
else:
|
||||||
session = boto3.Session(
|
session = boto3.Session(
|
||||||
aws_access_key_id=aws_access_key_id,
|
aws_access_key_id=aws_access_key_id,
|
||||||
|
@ -1682,6 +1727,7 @@ class BedrockConverseLLM(BaseLLM):
|
||||||
# pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
|
# pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
|
||||||
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
|
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
|
||||||
aws_access_key_id = optional_params.pop("aws_access_key_id", None)
|
aws_access_key_id = optional_params.pop("aws_access_key_id", None)
|
||||||
|
aws_session_token = optional_params.pop("aws_session_token", None)
|
||||||
aws_region_name = optional_params.pop("aws_region_name", None)
|
aws_region_name = optional_params.pop("aws_region_name", None)
|
||||||
aws_role_name = optional_params.pop("aws_role_name", None)
|
aws_role_name = optional_params.pop("aws_role_name", None)
|
||||||
aws_session_name = optional_params.pop("aws_session_name", None)
|
aws_session_name = optional_params.pop("aws_session_name", None)
|
||||||
|
@ -1713,6 +1759,7 @@ class BedrockConverseLLM(BaseLLM):
|
||||||
credentials: Credentials = self.get_credentials(
|
credentials: Credentials = self.get_credentials(
|
||||||
aws_access_key_id=aws_access_key_id,
|
aws_access_key_id=aws_access_key_id,
|
||||||
aws_secret_access_key=aws_secret_access_key,
|
aws_secret_access_key=aws_secret_access_key,
|
||||||
|
aws_session_token=aws_session_token,
|
||||||
aws_region_name=aws_region_name,
|
aws_region_name=aws_region_name,
|
||||||
aws_session_name=aws_session_name,
|
aws_session_name=aws_session_name,
|
||||||
aws_profile_name=aws_profile_name,
|
aws_profile_name=aws_profile_name,
|
||||||
|
|
|
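For illustration only (not part of the commit): a hedged sketch of passing temporary STS credentials to Bedrock now that `aws_session_token` is accepted; the credential values are placeholders.

```python
import litellm

response = litellm.completion(
    model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
    messages=[{"role": "user", "content": "hi"}],
    aws_access_key_id="ASIA...",   # placeholder temporary credentials
    aws_secret_access_key="...",
    aws_session_token="...",       # now forwarded to botocore Credentials
    aws_region_name="us-east-1",
)
```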
@ -1,4 +1,8 @@
|
||||||
import time, json, httpx, asyncio
|
import asyncio
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
|
||||||
class AsyncCustomHTTPTransport(httpx.AsyncHTTPTransport):
|
class AsyncCustomHTTPTransport(httpx.AsyncHTTPTransport):
|
||||||
|
@ -7,15 +11,18 @@ class AsyncCustomHTTPTransport(httpx.AsyncHTTPTransport):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
async def handle_async_request(self, request: httpx.Request) -> httpx.Response:
|
async def handle_async_request(self, request: httpx.Request) -> httpx.Response:
|
||||||
if "images/generations" in request.url.path and request.url.params[
|
_api_version = request.url.params.get("api-version", "")
|
||||||
"api-version"
|
if (
|
||||||
] in [ # dall-e-3 starts from `2023-12-01-preview` so we should be able to avoid conflict
|
"images/generations" in request.url.path
|
||||||
|
and _api_version
|
||||||
|
in [ # dall-e-3 starts from `2023-12-01-preview` so we should be able to avoid conflict
|
||||||
"2023-06-01-preview",
|
"2023-06-01-preview",
|
||||||
"2023-07-01-preview",
|
"2023-07-01-preview",
|
||||||
"2023-08-01-preview",
|
"2023-08-01-preview",
|
||||||
"2023-09-01-preview",
|
"2023-09-01-preview",
|
||||||
"2023-10-01-preview",
|
"2023-10-01-preview",
|
||||||
]:
|
]
|
||||||
|
):
|
||||||
request.url = request.url.copy_with(
|
request.url = request.url.copy_with(
|
||||||
path="/openai/images/generations:submit"
|
path="/openai/images/generations:submit"
|
||||||
)
|
)
|
||||||
|
@ -77,15 +84,18 @@ class CustomHTTPTransport(httpx.HTTPTransport):
|
||||||
self,
|
self,
|
||||||
request: httpx.Request,
|
request: httpx.Request,
|
||||||
) -> httpx.Response:
|
) -> httpx.Response:
|
||||||
if "images/generations" in request.url.path and request.url.params[
|
_api_version = request.url.params.get("api-version", "")
|
||||||
"api-version"
|
if (
|
||||||
] in [ # dall-e-3 starts from `2023-12-01-preview` so we should be able to avoid conflict
|
"images/generations" in request.url.path
|
||||||
|
and _api_version
|
||||||
|
in [ # dall-e-3 starts from `2023-12-01-preview` so we should be able to avoid conflict
|
||||||
"2023-06-01-preview",
|
"2023-06-01-preview",
|
||||||
"2023-07-01-preview",
|
"2023-07-01-preview",
|
||||||
"2023-08-01-preview",
|
"2023-08-01-preview",
|
||||||
"2023-09-01-preview",
|
"2023-09-01-preview",
|
||||||
"2023-10-01-preview",
|
"2023-10-01-preview",
|
||||||
]:
|
]
|
||||||
|
):
|
||||||
request.url = request.url.copy_with(
|
request.url = request.url.copy_with(
|
||||||
path="/openai/images/generations:submit"
|
path="/openai/images/generations:submit"
|
||||||
)
|
)
|
||||||
|
|
|
litellm/llms/custom_httpx/http_handler.py
@@ -1,6 +1,11 @@
+import asyncio
+import os
+import traceback
+from typing import Any, Mapping, Optional, Union
+
+import httpx
+
 import litellm
-import httpx, asyncio, traceback, os
-from typing import Optional, Union, Mapping, Any

 # https://www.python-httpx.org/advanced/timeouts
 _DEFAULT_TIMEOUT = httpx.Timeout(timeout=5.0, connect=5.0)
@@ -93,7 +98,7 @@ class AsyncHTTPHandler:
             response = await self.client.send(req, stream=stream)
             response.raise_for_status()
             return response
-        except httpx.RemoteProtocolError:
+        except (httpx.RemoteProtocolError, httpx.ConnectError):
             # Retry the request with a new session if there is a connection error
             new_client = self.create_client(timeout=self.timeout, concurrent_limit=1)
             try:
@@ -109,6 +114,11 @@ class AsyncHTTPHandler:
             finally:
                 await new_client.aclose()
         except httpx.HTTPStatusError as e:
+            setattr(e, "status_code", e.response.status_code)
+            if stream is True:
+                setattr(e, "message", await e.response.aread())
+            else:
+                setattr(e, "message", e.response.text)
             raise e
         except Exception as e:
             raise e
@@ -208,6 +218,7 @@ class HTTPHandler:
         headers: Optional[dict] = None,
         stream: bool = False,
     ):
+
         req = self.client.build_request(
             "POST", url, data=data, json=json, params=params, headers=headers  # type: ignore
         )
@ -21,6 +21,7 @@ from pydantic import BaseModel
|
||||||
from typing_extensions import overload, override
|
from typing_extensions import overload, override
|
||||||
|
|
||||||
import litellm
|
import litellm
|
||||||
|
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||||
from litellm.types.utils import ProviderField
|
from litellm.types.utils import ProviderField
|
||||||
from litellm.utils import (
|
from litellm.utils import (
|
||||||
Choices,
|
Choices,
|
||||||
|
@ -652,6 +653,36 @@ class OpenAIChatCompletion(BaseLLM):
|
||||||
else:
|
else:
|
||||||
return client
|
return client
|
||||||
|
|
||||||
|
async def make_openai_chat_completion_request(
|
||||||
|
self,
|
||||||
|
openai_aclient: AsyncOpenAI,
|
||||||
|
data: dict,
|
||||||
|
timeout: Union[float, httpx.Timeout],
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Helper to:
|
||||||
|
- call chat.completions.create.with_raw_response when litellm.return_response_headers is True
|
||||||
|
- call chat.completions.create by default
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if litellm.return_response_headers is True:
|
||||||
|
raw_response = (
|
||||||
|
await openai_aclient.chat.completions.with_raw_response.create(
|
||||||
|
**data, timeout=timeout
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
headers = dict(raw_response.headers)
|
||||||
|
response = raw_response.parse()
|
||||||
|
return headers, response
|
||||||
|
else:
|
||||||
|
response = await openai_aclient.chat.completions.create(
|
||||||
|
**data, timeout=timeout
|
||||||
|
)
|
||||||
|
return None, response
|
||||||
|
except Exception as e:
|
||||||
|
raise e
|
||||||
|
|
||||||
def completion(
|
def completion(
|
||||||
self,
|
self,
|
||||||
model_response: ModelResponse,
|
model_response: ModelResponse,
|
||||||
|
@ -678,17 +709,17 @@ class OpenAIChatCompletion(BaseLLM):
|
||||||
if headers:
|
if headers:
|
||||||
optional_params["extra_headers"] = headers
|
optional_params["extra_headers"] = headers
|
||||||
if model is None or messages is None:
|
if model is None or messages is None:
|
||||||
raise OpenAIError(status_code=422, message=f"Missing model or messages")
|
raise OpenAIError(status_code=422, message="Missing model or messages")
|
||||||
|
|
||||||
if not isinstance(timeout, float) and not isinstance(
|
if not isinstance(timeout, float) and not isinstance(
|
||||||
timeout, httpx.Timeout
|
timeout, httpx.Timeout
|
||||||
):
|
):
|
||||||
raise OpenAIError(
|
raise OpenAIError(
|
||||||
status_code=422,
|
status_code=422,
|
||||||
message=f"Timeout needs to be a float or httpx.Timeout",
|
message="Timeout needs to be a float or httpx.Timeout",
|
||||||
)
|
)
|
||||||
|
|
||||||
if custom_llm_provider != "openai":
|
if custom_llm_provider is not None and custom_llm_provider != "openai":
|
||||||
model_response.model = f"{custom_llm_provider}/{model}"
|
model_response.model = f"{custom_llm_provider}/{model}"
|
||||||
# process all OpenAI compatible provider logic here
|
# process all OpenAI compatible provider logic here
|
||||||
if custom_llm_provider == "mistral":
|
if custom_llm_provider == "mistral":
|
||||||
|
@ -836,13 +867,13 @@ class OpenAIChatCompletion(BaseLLM):
|
||||||
self,
|
self,
|
||||||
data: dict,
|
data: dict,
|
||||||
model_response: ModelResponse,
|
model_response: ModelResponse,
|
||||||
|
logging_obj: LiteLLMLoggingObj,
|
||||||
timeout: Union[float, httpx.Timeout],
|
timeout: Union[float, httpx.Timeout],
|
||||||
api_key: Optional[str] = None,
|
api_key: Optional[str] = None,
|
||||||
api_base: Optional[str] = None,
|
api_base: Optional[str] = None,
|
||||||
organization: Optional[str] = None,
|
organization: Optional[str] = None,
|
||||||
client=None,
|
client=None,
|
||||||
max_retries=None,
|
max_retries=None,
|
||||||
logging_obj=None,
|
|
||||||
headers=None,
|
headers=None,
|
||||||
):
|
):
|
||||||
response = None
|
response = None
|
||||||
|
@ -869,8 +900,8 @@ class OpenAIChatCompletion(BaseLLM):
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
response = await openai_aclient.chat.completions.create(
|
headers, response = await self.make_openai_chat_completion_request(
|
||||||
**data, timeout=timeout
|
openai_aclient=openai_aclient, data=data, timeout=timeout
|
||||||
)
|
)
|
||||||
stringified_response = response.model_dump()
|
stringified_response = response.model_dump()
|
||||||
logging_obj.post_call(
|
logging_obj.post_call(
|
||||||
|
@ -879,9 +910,11 @@ class OpenAIChatCompletion(BaseLLM):
|
||||||
original_response=stringified_response,
|
original_response=stringified_response,
|
||||||
additional_args={"complete_input_dict": data},
|
additional_args={"complete_input_dict": data},
|
||||||
)
|
)
|
||||||
|
logging_obj.model_call_details["response_headers"] = headers
|
||||||
return convert_to_model_response_object(
|
return convert_to_model_response_object(
|
||||||
response_object=stringified_response,
|
response_object=stringified_response,
|
||||||
model_response_object=model_response,
|
model_response_object=model_response,
|
||||||
|
hidden_params={"headers": headers},
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise e
|
raise e
|
||||||
|
@ -931,10 +964,10 @@ class OpenAIChatCompletion(BaseLLM):
|
||||||
|
|
||||||
async def async_streaming(
|
async def async_streaming(
|
||||||
self,
|
self,
|
||||||
logging_obj,
|
|
||||||
timeout: Union[float, httpx.Timeout],
|
timeout: Union[float, httpx.Timeout],
|
||||||
data: dict,
|
data: dict,
|
||||||
model: str,
|
model: str,
|
||||||
|
logging_obj: LiteLLMLoggingObj,
|
||||||
api_key: Optional[str] = None,
|
api_key: Optional[str] = None,
|
||||||
api_base: Optional[str] = None,
|
api_base: Optional[str] = None,
|
||||||
organization: Optional[str] = None,
|
organization: Optional[str] = None,
|
||||||
|
@ -965,9 +998,10 @@ class OpenAIChatCompletion(BaseLLM):
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
response = await openai_aclient.chat.completions.create(
|
headers, response = await self.make_openai_chat_completion_request(
|
||||||
**data, timeout=timeout
|
openai_aclient=openai_aclient, data=data, timeout=timeout
|
||||||
)
|
)
|
||||||
|
logging_obj.model_call_details["response_headers"] = headers
|
||||||
streamwrapper = CustomStreamWrapper(
|
streamwrapper = CustomStreamWrapper(
|
||||||
completion_stream=response,
|
completion_stream=response,
|
||||||
model=model,
|
model=model,
|
||||||
|
@ -992,17 +1026,43 @@ class OpenAIChatCompletion(BaseLLM):
|
||||||
else:
|
else:
|
||||||
raise OpenAIError(status_code=500, message=f"{str(e)}")
|
raise OpenAIError(status_code=500, message=f"{str(e)}")
|
||||||
|
|
||||||
|
# Embedding
|
||||||
|
async def make_openai_embedding_request(
|
||||||
|
self,
|
||||||
|
openai_aclient: AsyncOpenAI,
|
||||||
|
data: dict,
|
||||||
|
timeout: Union[float, httpx.Timeout],
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Helper to:
|
||||||
|
- call embeddings.create.with_raw_response when litellm.return_response_headers is True
|
||||||
|
- call embeddings.create by default
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if litellm.return_response_headers is True:
|
||||||
|
raw_response = await openai_aclient.embeddings.with_raw_response.create(
|
||||||
|
**data, timeout=timeout
|
||||||
|
) # type: ignore
|
||||||
|
headers = dict(raw_response.headers)
|
||||||
|
response = raw_response.parse()
|
||||||
|
return headers, response
|
||||||
|
else:
|
||||||
|
response = await openai_aclient.embeddings.create(**data, timeout=timeout) # type: ignore
|
||||||
|
return None, response
|
||||||
|
except Exception as e:
|
||||||
|
raise e
|
||||||
|
|
||||||
async def aembedding(
|
async def aembedding(
|
||||||
self,
|
self,
|
||||||
input: list,
|
input: list,
|
||||||
data: dict,
|
data: dict,
|
||||||
model_response: litellm.utils.EmbeddingResponse,
|
model_response: litellm.utils.EmbeddingResponse,
|
||||||
timeout: float,
|
timeout: float,
|
||||||
|
logging_obj: LiteLLMLoggingObj,
|
||||||
api_key: Optional[str] = None,
|
api_key: Optional[str] = None,
|
||||||
api_base: Optional[str] = None,
|
api_base: Optional[str] = None,
|
||||||
client: Optional[AsyncOpenAI] = None,
|
client: Optional[AsyncOpenAI] = None,
|
||||||
max_retries=None,
|
max_retries=None,
|
||||||
logging_obj=None,
|
|
||||||
):
|
):
|
||||||
response = None
|
response = None
|
||||||
try:
|
try:
|
||||||
|
@ -1014,7 +1074,10 @@ class OpenAIChatCompletion(BaseLLM):
|
||||||
max_retries=max_retries,
|
max_retries=max_retries,
|
||||||
client=client,
|
client=client,
|
||||||
)
|
)
|
||||||
response = await openai_aclient.embeddings.create(**data, timeout=timeout) # type: ignore
|
headers, response = await self.make_openai_embedding_request(
|
||||||
|
openai_aclient=openai_aclient, data=data, timeout=timeout
|
||||||
|
)
|
||||||
|
logging_obj.model_call_details["response_headers"] = headers
|
||||||
stringified_response = response.model_dump()
|
stringified_response = response.model_dump()
|
||||||
## LOGGING
|
## LOGGING
|
||||||
logging_obj.post_call(
|
logging_obj.post_call(
|
||||||
|
@ -1229,6 +1292,34 @@ class OpenAIChatCompletion(BaseLLM):
|
||||||
else:
|
else:
|
||||||
raise OpenAIError(status_code=500, message=str(e))
|
raise OpenAIError(status_code=500, message=str(e))
|
||||||
|
|
||||||
|
# Audio Transcriptions
|
||||||
|
async def make_openai_audio_transcriptions_request(
|
||||||
|
self,
|
||||||
|
openai_aclient: AsyncOpenAI,
|
||||||
|
data: dict,
|
||||||
|
timeout: Union[float, httpx.Timeout],
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Helper to:
|
||||||
|
- call openai_aclient.audio.transcriptions.with_raw_response when litellm.return_response_headers is True
|
||||||
|
- call openai_aclient.audio.transcriptions.create by default
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if litellm.return_response_headers is True:
|
||||||
|
raw_response = (
|
||||||
|
await openai_aclient.audio.transcriptions.with_raw_response.create(
|
||||||
|
**data, timeout=timeout
|
||||||
|
)
|
||||||
|
) # type: ignore
|
||||||
|
headers = dict(raw_response.headers)
|
||||||
|
response = raw_response.parse()
|
||||||
|
return headers, response
|
||||||
|
else:
|
||||||
|
response = await openai_aclient.audio.transcriptions.create(**data, timeout=timeout) # type: ignore
|
||||||
|
return None, response
|
||||||
|
except Exception as e:
|
||||||
|
raise e
|
||||||
|
|
||||||
def audio_transcriptions(
|
def audio_transcriptions(
|
||||||
self,
|
self,
|
||||||
model: str,
|
model: str,
|
||||||
|
@ -1286,11 +1377,11 @@ class OpenAIChatCompletion(BaseLLM):
|
||||||
data: dict,
|
data: dict,
|
||||||
model_response: TranscriptionResponse,
|
model_response: TranscriptionResponse,
|
||||||
timeout: float,
|
timeout: float,
|
||||||
|
logging_obj: LiteLLMLoggingObj,
|
||||||
api_key: Optional[str] = None,
|
api_key: Optional[str] = None,
|
||||||
api_base: Optional[str] = None,
|
api_base: Optional[str] = None,
|
||||||
client=None,
|
client=None,
|
||||||
max_retries=None,
|
max_retries=None,
|
||||||
logging_obj=None,
|
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
openai_aclient = self._get_openai_client(
|
openai_aclient = self._get_openai_client(
|
||||||
|
@ -1302,9 +1393,12 @@ class OpenAIChatCompletion(BaseLLM):
|
||||||
client=client,
|
client=client,
|
||||||
)
|
)
|
||||||
|
|
||||||
response = await openai_aclient.audio.transcriptions.create(
|
headers, response = await self.make_openai_audio_transcriptions_request(
|
||||||
**data, timeout=timeout
|
openai_aclient=openai_aclient,
|
||||||
) # type: ignore
|
data=data,
|
||||||
|
timeout=timeout,
|
||||||
|
)
|
||||||
|
logging_obj.model_call_details["response_headers"] = headers
|
||||||
stringified_response = response.model_dump()
|
stringified_response = response.model_dump()
|
||||||
## LOGGING
|
## LOGGING
|
||||||
logging_obj.post_call(
|
logging_obj.post_call(
|
||||||
|
@ -1497,9 +1591,9 @@ class OpenAITextCompletion(BaseLLM):
|
||||||
model: str,
|
model: str,
|
||||||
messages: list,
|
messages: list,
|
||||||
timeout: float,
|
timeout: float,
|
||||||
|
logging_obj: LiteLLMLoggingObj,
|
||||||
print_verbose: Optional[Callable] = None,
|
print_verbose: Optional[Callable] = None,
|
||||||
api_base: Optional[str] = None,
|
api_base: Optional[str] = None,
|
||||||
logging_obj=None,
|
|
||||||
acompletion: bool = False,
|
acompletion: bool = False,
|
||||||
optional_params=None,
|
optional_params=None,
|
||||||
litellm_params=None,
|
litellm_params=None,
|
||||||
|
|
|
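For illustration only (not part of the commit): a hedged sketch of the `litellm.return_response_headers` flag these new helpers key off; the exact accessor for the forwarded headers may vary by litellm version.

```python
import asyncio

import litellm

litellm.return_response_headers = True


async def main():
    response = await litellm.acompletion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "hi"}],
    )
    # The async path stores provider headers (e.g. x-ratelimit-remaining-requests)
    # on the logging object and forwards them as hidden params on the response.
    print(getattr(response, "_hidden_params", {}).get("headers"))


asyncio.run(main())
```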
litellm/llms/prompt_templates/factory.py
@@ -2033,6 +2033,50 @@ def function_call_prompt(messages: list, functions: list):
     return messages


+def response_schema_prompt(model: str, response_schema: dict) -> str:
+    """
+    Decides if a user-defined custom prompt or default needs to be used
+
+    Returns the prompt str that's passed to the model as a user message
+    """
+    custom_prompt_details: Optional[dict] = None
+    response_schema_as_message = [
+        {"role": "user", "content": "{}".format(response_schema)}
+    ]
+    if f"{model}/response_schema_prompt" in litellm.custom_prompt_dict:
+
+        custom_prompt_details = litellm.custom_prompt_dict[
+            f"{model}/response_schema_prompt"
+        ]  # allow user to define custom response schema prompt by model
+    elif "response_schema_prompt" in litellm.custom_prompt_dict:
+        custom_prompt_details = litellm.custom_prompt_dict["response_schema_prompt"]
+
+    if custom_prompt_details is not None:
+        return custom_prompt(
+            role_dict=custom_prompt_details["roles"],
+            initial_prompt_value=custom_prompt_details["initial_prompt_value"],
+            final_prompt_value=custom_prompt_details["final_prompt_value"],
+            messages=response_schema_as_message,
+        )
+    else:
+        return default_response_schema_prompt(response_schema=response_schema)
+
+
+def default_response_schema_prompt(response_schema: dict) -> str:
+    """
+    Used if provider/model doesn't support 'response_schema' param.
+
+    This is the default prompt. Allow user to override this with a custom_prompt.
+    """
+    prompt_str = """Use this JSON schema:
+    ```json
+    {}
+    ```""".format(
+        response_schema
+    )
+    return prompt_str
+
+
 # Custom prompt template
 def custom_prompt(
     role_dict: dict,
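For illustration only (not part of the commit): a hedged sketch of overriding the default schema prompt via `litellm.custom_prompt_dict`; the role/prompt values follow litellm's custom prompt template convention and are placeholders.

```python
import litellm

# A global override; a per-model override would use the key
# "<model>/response_schema_prompt" instead.
litellm.custom_prompt_dict["response_schema_prompt"] = {
    "roles": {"user": {"pre_message": "", "post_message": ""}},
    "initial_prompt_value": "Return JSON that matches this schema:\n",
    "final_prompt_value": "\nRespond with JSON only.",
}
```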
@ -12,6 +12,7 @@ import requests # type: ignore
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
import litellm
|
import litellm
|
||||||
|
from litellm._logging import verbose_logger
|
||||||
from litellm.litellm_core_utils.core_helpers import map_finish_reason
|
from litellm.litellm_core_utils.core_helpers import map_finish_reason
|
||||||
from litellm.llms.prompt_templates.factory import (
|
from litellm.llms.prompt_templates.factory import (
|
||||||
convert_to_anthropic_image_obj,
|
convert_to_anthropic_image_obj,
|
||||||
|
@ -328,11 +329,14 @@ def _gemini_convert_messages_with_history(messages: list) -> List[ContentType]:
|
||||||
contents: List[ContentType] = []
|
contents: List[ContentType] = []
|
||||||
|
|
||||||
msg_i = 0
|
msg_i = 0
|
||||||
|
try:
|
||||||
while msg_i < len(messages):
|
while msg_i < len(messages):
|
||||||
user_content: List[PartType] = []
|
user_content: List[PartType] = []
|
||||||
init_msg_i = msg_i
|
init_msg_i = msg_i
|
||||||
## MERGE CONSECUTIVE USER CONTENT ##
|
## MERGE CONSECUTIVE USER CONTENT ##
|
||||||
while msg_i < len(messages) and messages[msg_i]["role"] in user_message_types:
|
while (
|
||||||
|
msg_i < len(messages) and messages[msg_i]["role"] in user_message_types
|
||||||
|
):
|
||||||
if isinstance(messages[msg_i]["content"], list):
|
if isinstance(messages[msg_i]["content"], list):
|
||||||
_parts: List[PartType] = []
|
_parts: List[PartType] = []
|
||||||
for element in messages[msg_i]["content"]:
|
for element in messages[msg_i]["content"]:
|
||||||
|
@ -375,7 +379,9 @@ def _gemini_convert_messages_with_history(messages: list) -> List[ContentType]:
|
||||||
"tool_calls", []
|
"tool_calls", []
|
||||||
): # support assistant tool invoke conversion
|
): # support assistant tool invoke conversion
|
||||||
assistant_content.extend(
|
assistant_content.extend(
|
||||||
convert_to_gemini_tool_call_invoke(messages[msg_i]["tool_calls"])
|
convert_to_gemini_tool_call_invoke(
|
||||||
|
messages[msg_i]["tool_calls"]
|
||||||
|
)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
assistant_text = (
|
assistant_text = (
|
||||||
|
@ -400,8 +406,9 @@ def _gemini_convert_messages_with_history(messages: list) -> List[ContentType]:
|
||||||
messages[msg_i]
|
messages[msg_i]
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
return contents
|
return contents
|
||||||
|
except Exception as e:
|
||||||
|
raise e
|
||||||
|
|
||||||
|
|
||||||
def _get_client_cache_key(model: str, vertex_project: str, vertex_location: str):
|
def _get_client_cache_key(model: str, vertex_project: str, vertex_location: str):
|
||||||
|
|
|
@ -1,24 +1,32 @@
|
||||||
# What is this?
|
# What is this?
|
||||||
## Handler file for calling claude-3 on vertex ai
|
## Handler file for calling claude-3 on vertex ai
|
||||||
import os, types
|
import copy
|
||||||
import json
|
import json
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
import types
|
||||||
|
import uuid
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
import requests, copy # type: ignore
|
from typing import Any, Callable, List, Optional, Tuple
|
||||||
import time, uuid
|
|
||||||
from typing import Callable, Optional, List
|
import httpx # type: ignore
|
||||||
from litellm.utils import ModelResponse, Usage, CustomStreamWrapper
|
import requests # type: ignore
|
||||||
from litellm.litellm_core_utils.core_helpers import map_finish_reason
|
|
||||||
import litellm
|
import litellm
|
||||||
|
from litellm.litellm_core_utils.core_helpers import map_finish_reason
|
||||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
||||||
|
from litellm.types.utils import ResponseFormatChunk
|
||||||
|
from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
|
||||||
|
|
||||||
from .prompt_templates.factory import (
|
from .prompt_templates.factory import (
|
||||||
contains_tag,
|
|
||||||
prompt_factory,
|
|
||||||
custom_prompt,
|
|
||||||
construct_tool_use_system_prompt,
|
construct_tool_use_system_prompt,
|
||||||
|
contains_tag,
|
||||||
|
custom_prompt,
|
||||||
extract_between_tags,
|
extract_between_tags,
|
||||||
parse_xml_params,
|
parse_xml_params,
|
||||||
|
prompt_factory,
|
||||||
|
response_schema_prompt,
|
||||||
)
|
)
|
||||||
import httpx # type: ignore
|
|
||||||
|
|
||||||
|
|
||||||
class VertexAIError(Exception):
|
class VertexAIError(Exception):
|
||||||
|
@ -104,6 +112,7 @@ class VertexAIAnthropicConfig:
|
||||||
"stop",
|
"stop",
|
||||||
"temperature",
|
"temperature",
|
||||||
"top_p",
|
"top_p",
|
||||||
|
"response_format",
|
||||||
]
|
]
|
||||||
|
|
||||||
def map_openai_params(self, non_default_params: dict, optional_params: dict):
|
def map_openai_params(self, non_default_params: dict, optional_params: dict):
|
||||||
|
@ -120,6 +129,8 @@ class VertexAIAnthropicConfig:
|
||||||
optional_params["temperature"] = value
|
optional_params["temperature"] = value
|
||||||
if param == "top_p":
|
if param == "top_p":
|
||||||
optional_params["top_p"] = value
|
optional_params["top_p"] = value
|
||||||
|
if param == "response_format" and "response_schema" in value:
|
||||||
|
optional_params["response_format"] = ResponseFormatChunk(**value) # type: ignore
|
||||||
return optional_params
|
return optional_params
|
||||||
|
|
||||||
|
|
||||||
|
@ -129,7 +140,6 @@ class VertexAIAnthropicConfig:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
# makes headers for API call
|
|
||||||
def refresh_auth(
|
def refresh_auth(
|
||||||
credentials,
|
credentials,
|
||||||
) -> str: # used when user passes in credentials as json string
|
) -> str: # used when user passes in credentials as json string
|
||||||
|
@ -144,6 +154,40 @@ def refresh_auth(
|
||||||
return credentials.token
|
return credentials.token
|
||||||
|
|
||||||
|
|
||||||
|
def get_vertex_client(
|
||||||
|
client: Any,
|
||||||
|
vertex_project: Optional[str],
|
||||||
|
vertex_location: Optional[str],
|
||||||
|
vertex_credentials: Optional[str],
|
||||||
|
) -> Tuple[Any, Optional[str]]:
|
||||||
|
args = locals()
|
||||||
|
from litellm.llms.vertex_httpx import VertexLLM
|
||||||
|
|
||||||
|
try:
|
||||||
|
from anthropic import AnthropicVertex
|
||||||
|
except Exception:
|
||||||
|
raise VertexAIError(
|
||||||
|
status_code=400,
|
||||||
|
message="""vertexai import failed please run `pip install -U google-cloud-aiplatform "anthropic[vertex]"`""",
|
||||||
|
)
|
||||||
|
|
||||||
|
access_token: Optional[str] = None
|
||||||
|
|
||||||
|
if client is None:
|
||||||
|
_credentials, cred_project_id = VertexLLM().load_auth(
|
||||||
|
credentials=vertex_credentials, project_id=vertex_project
|
||||||
|
)
|
||||||
|
vertex_ai_client = AnthropicVertex(
|
||||||
|
project_id=vertex_project or cred_project_id,
|
||||||
|
region=vertex_location or "us-central1",
|
||||||
|
access_token=_credentials.token,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
vertex_ai_client = client
|
||||||
|
|
||||||
|
return vertex_ai_client, access_token
|
||||||
|
|
||||||
|
|
||||||
def completion(
|
def completion(
|
||||||
model: str,
|
model: str,
|
||||||
messages: list,
|
messages: list,
|
||||||
|
@ -151,10 +195,10 @@ def completion(
|
||||||
print_verbose: Callable,
|
print_verbose: Callable,
|
||||||
encoding,
|
encoding,
|
||||||
logging_obj,
|
logging_obj,
|
||||||
|
optional_params: dict,
|
||||||
vertex_project=None,
|
vertex_project=None,
|
||||||
vertex_location=None,
|
vertex_location=None,
|
||||||
vertex_credentials=None,
|
vertex_credentials=None,
|
||||||
optional_params=None,
|
|
||||||
litellm_params=None,
|
litellm_params=None,
|
||||||
logger_fn=None,
|
logger_fn=None,
|
||||||
acompletion: bool = False,
|
acompletion: bool = False,
|
||||||
|
@ -178,6 +222,13 @@ def completion(
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
|
|
||||||
|
vertex_ai_client, access_token = get_vertex_client(
|
||||||
|
client=client,
|
||||||
|
vertex_project=vertex_project,
|
||||||
|
vertex_location=vertex_location,
|
||||||
|
vertex_credentials=vertex_credentials,
|
||||||
|
)
|
||||||
|
|
||||||
## Load Config
|
## Load Config
|
||||||
config = litellm.VertexAIAnthropicConfig.get_config()
|
config = litellm.VertexAIAnthropicConfig.get_config()
|
||||||
for k, v in config.items():
|
for k, v in config.items():
|
||||||
|
@ -186,6 +237,7 @@ def completion(
|
||||||
|
|
||||||
## Format Prompt
|
## Format Prompt
|
||||||
_is_function_call = False
|
_is_function_call = False
|
||||||
|
_is_json_schema = False
|
||||||
messages = copy.deepcopy(messages)
|
messages = copy.deepcopy(messages)
|
||||||
optional_params = copy.deepcopy(optional_params)
|
optional_params = copy.deepcopy(optional_params)
|
||||||
# Separate system prompt from rest of message
|
# Separate system prompt from rest of message
|
||||||
|
@ -200,6 +252,29 @@ def completion(
|
||||||
messages.pop(idx)
|
messages.pop(idx)
|
||||||
if len(system_prompt) > 0:
|
if len(system_prompt) > 0:
|
||||||
optional_params["system"] = system_prompt
|
optional_params["system"] = system_prompt
|
||||||
|
# Checks for 'response_schema' support - if passed in
|
||||||
|
if "response_format" in optional_params:
|
||||||
|
response_format_chunk = ResponseFormatChunk(
|
||||||
|
**optional_params["response_format"] # type: ignore
|
||||||
|
)
|
||||||
|
supports_response_schema = litellm.supports_response_schema(
|
||||||
|
model=model, custom_llm_provider="vertex_ai"
|
||||||
|
)
|
||||||
|
if (
|
||||||
|
supports_response_schema is False
|
||||||
|
and response_format_chunk["type"] == "json_object"
|
||||||
|
and "response_schema" in response_format_chunk
|
||||||
|
):
|
||||||
|
_is_json_schema = True
|
||||||
|
user_response_schema_message = response_schema_prompt(
|
||||||
|
model=model,
|
||||||
|
response_schema=response_format_chunk["response_schema"],
|
||||||
|
)
|
||||||
|
messages.append(
|
||||||
|
{"role": "user", "content": user_response_schema_message}
|
||||||
|
)
|
||||||
|
messages.append({"role": "assistant", "content": "{"})
|
||||||
|
optional_params.pop("response_format")
|
||||||
# Format rest of message according to anthropic guidelines
|
# Format rest of message according to anthropic guidelines
|
||||||
try:
|
try:
|
||||||
messages = prompt_factory(
|
messages = prompt_factory(
|
||||||
|
@ -233,32 +308,6 @@ def completion(
|
||||||
print_verbose(
|
print_verbose(
|
||||||
f"VERTEX AI: vertex_project={vertex_project}; vertex_location={vertex_location}; vertex_credentials={vertex_credentials}"
|
f"VERTEX AI: vertex_project={vertex_project}; vertex_location={vertex_location}; vertex_credentials={vertex_credentials}"
|
||||||
)
|
)
|
||||||
access_token = None
|
|
||||||
if client is None:
|
|
||||||
if vertex_credentials is not None and isinstance(vertex_credentials, str):
|
|
||||||
import google.oauth2.service_account
|
|
||||||
|
|
||||||
try:
|
|
||||||
json_obj = json.loads(vertex_credentials)
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
json_obj = json.load(open(vertex_credentials))
|
|
||||||
|
|
||||||
creds = (
|
|
||||||
google.oauth2.service_account.Credentials.from_service_account_info(
|
|
||||||
json_obj,
|
|
||||||
scopes=["https://www.googleapis.com/auth/cloud-platform"],
|
|
||||||
)
|
|
||||||
)
|
|
||||||
### CHECK IF ACCESS
|
|
||||||
access_token = refresh_auth(credentials=creds)
|
|
||||||
|
|
||||||
vertex_ai_client = AnthropicVertex(
|
|
||||||
project_id=vertex_project,
|
|
||||||
region=vertex_location,
|
|
||||||
access_token=access_token,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
vertex_ai_client = client
|
|
||||||
|
|
||||||
if acompletion == True:
|
if acompletion == True:
|
||||||
"""
|
"""
|
||||||
|
@ -315,7 +364,16 @@ def completion(
|
||||||
)
|
)
|
||||||
|
|
||||||
message = vertex_ai_client.messages.create(**data) # type: ignore
|
message = vertex_ai_client.messages.create(**data) # type: ignore
|
||||||
text_content = message.content[0].text
|
|
||||||
|
## LOGGING
|
||||||
|
logging_obj.post_call(
|
||||||
|
input=messages,
|
||||||
|
api_key="",
|
||||||
|
original_response=message,
|
||||||
|
additional_args={"complete_input_dict": data},
|
||||||
|
)
|
||||||
|
|
||||||
|
text_content: str = message.content[0].text
|
||||||
## TOOL CALLING - OUTPUT PARSE
|
## TOOL CALLING - OUTPUT PARSE
|
||||||
if text_content is not None and contains_tag("invoke", text_content):
|
if text_content is not None and contains_tag("invoke", text_content):
|
||||||
function_name = extract_between_tags("tool_name", text_content)[0]
|
function_name = extract_between_tags("tool_name", text_content)[0]
|
||||||
|
@ -338,6 +396,12 @@ def completion(
|
||||||
content=None,
|
content=None,
|
||||||
)
|
)
|
||||||
model_response.choices[0].message = _message # type: ignore
|
model_response.choices[0].message = _message # type: ignore
|
||||||
|
else:
|
||||||
|
if (
|
||||||
|
_is_json_schema
|
||||||
|
): # follows https://github.com/anthropics/anthropic-cookbook/blob/main/misc/how_to_enable_json_mode.ipynb
|
||||||
|
json_response = "{" + text_content[: text_content.rfind("}") + 1]
|
||||||
|
model_response.choices[0].message.content = json_response # type: ignore
|
||||||
else:
|
else:
|
||||||
model_response.choices[0].message.content = text_content # type: ignore
|
model_response.choices[0].message.content = text_content # type: ignore
|
||||||
model_response.choices[0].finish_reason = map_finish_reason(message.stop_reason)
|
model_response.choices[0].finish_reason = map_finish_reason(message.stop_reason)
|
||||||
|
|
|
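Taken together, the changes above let a `response_format` carrying a `response_schema` be passed to Claude models on Vertex AI; when the model has no native schema support, the schema is injected as a prompt and the reply is parsed back as JSON. A minimal sketch of how a caller might exercise this path, assuming Vertex AI credentials are already configured in the environment (the model name and schema below are illustrative, not taken from this commit):

```python
import litellm

# Hypothetical schema - any JSON-schema-style dict the caller wants enforced.
response_schema = {
    "type": "object",
    "properties": {"title": {"type": "string"}, "year": {"type": "integer"}},
}

# Sketch only: assumes vertex_ai project/location/credentials are set via env or kwargs.
resp = litellm.completion(
    model="vertex_ai/claude-3-5-sonnet@20240620",
    messages=[{"role": "user", "content": "Return your favourite film as JSON."}],
    response_format={"type": "json_object", "response_schema": response_schema},
)
print(resp.choices[0].message.content)  # expected: a JSON string shaped like the schema
```
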
@@ -12,7 +12,6 @@ from functools import partial
 from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union

 import httpx  # type: ignore
-import ijson
 import requests  # type: ignore

 import litellm
@@ -21,7 +20,10 @@ import litellm.litellm_core_utils.litellm_logging
 from litellm import verbose_logger
 from litellm.litellm_core_utils.core_helpers import map_finish_reason
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
-from litellm.llms.prompt_templates.factory import convert_url_to_base64
+from litellm.llms.prompt_templates.factory import (
+    convert_url_to_base64,
+    response_schema_prompt,
+)
 from litellm.llms.vertex_ai import _gemini_convert_messages_with_history
 from litellm.types.llms.openai import (
     ChatCompletionResponseMessage,
@@ -183,10 +185,17 @@ class GoogleAIStudioGeminiConfig:  # key diff from VertexAI - 'frequency_penalty
             if param == "tools" and isinstance(value, list):
                 gtool_func_declarations = []
                 for tool in value:
+                    _parameters = tool.get("function", {}).get("parameters", {})
+                    _properties = _parameters.get("properties", {})
+                    if isinstance(_properties, dict):
+                        for _, _property in _properties.items():
+                            if "enum" in _property and "format" not in _property:
+                                _property["format"] = "enum"
+
                     gtool_func_declaration = FunctionDeclaration(
                         name=tool["function"]["name"],
                         description=tool["function"].get("description", ""),
-                        parameters=tool["function"].get("parameters", {}),
+                        parameters=_parameters,
                     )
                     gtool_func_declarations.append(gtool_func_declaration)
                 optional_params["tools"] = [
@@ -349,6 +358,7 @@ class VertexGeminiConfig:
         model: str,
         non_default_params: dict,
         optional_params: dict,
+        drop_params: bool,
     ):
         for param, value in non_default_params.items():
             if param == "temperature":
@@ -368,8 +378,13 @@ class VertexGeminiConfig:
                 optional_params["stop_sequences"] = value
             if param == "max_tokens":
                 optional_params["max_output_tokens"] = value
-            if param == "response_format" and value["type"] == "json_object":  # type: ignore
+            if param == "response_format" and isinstance(value, dict):  # type: ignore
+                if value["type"] == "json_object":
                     optional_params["response_mime_type"] = "application/json"
+                elif value["type"] == "text":
+                    optional_params["response_mime_type"] = "text/plain"
+                if "response_schema" in value:
+                    optional_params["response_schema"] = value["response_schema"]
             if param == "frequency_penalty":
                 optional_params["frequency_penalty"] = value
             if param == "presence_penalty":
@@ -460,7 +475,7 @@ async def make_call(
         raise VertexAIError(status_code=response.status_code, message=response.text)

     completion_stream = ModelResponseIterator(
-        streaming_response=response.aiter_bytes(), sync_stream=False
+        streaming_response=response.aiter_lines(), sync_stream=False
     )
     # LOGGING
     logging_obj.post_call(
@@ -491,7 +506,7 @@ def make_sync_call(
         raise VertexAIError(status_code=response.status_code, message=response.read())

     completion_stream = ModelResponseIterator(
-        streaming_response=response.iter_bytes(chunk_size=2056), sync_stream=True
+        streaming_response=response.iter_lines(), sync_stream=True
     )

     # LOGGING
@@ -813,11 +828,12 @@ class VertexLLM(BaseLLM):
                 endpoint = "generateContent"
                 if stream is True:
                     endpoint = "streamGenerateContent"
-                url = (
-                    "https://generativelanguage.googleapis.com/v1beta/{}:{}?key={}".format(
+                    url = "https://generativelanguage.googleapis.com/v1beta/{}:{}?key={}&alt=sse".format(
                         _gemini_model_name, endpoint, gemini_api_key
                     )
+                else:
+                    url = "https://generativelanguage.googleapis.com/v1beta/{}:{}?key={}".format(
+                        _gemini_model_name, endpoint, gemini_api_key
                 )
             else:
                 auth_header, vertex_project = self._ensure_access_token(
@@ -829,6 +845,8 @@ class VertexLLM(BaseLLM):
                 endpoint = "generateContent"
                 if stream is True:
                     endpoint = "streamGenerateContent"
+                    url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}?alt=sse"
+                else:
                     url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}"

         if (
@@ -842,6 +860,9 @@ class VertexLLM(BaseLLM):
             else:
                 url = "{}:{}".format(api_base, endpoint)

+            if stream is True:
+                url = url + "?alt=sse"
+
         return auth_header, url

     async def async_streaming(
@@ -994,6 +1015,22 @@ class VertexLLM(BaseLLM):
             if len(system_prompt_indices) > 0:
                 for idx in reversed(system_prompt_indices):
                     messages.pop(idx)
+
+            # Checks for 'response_schema' support - if passed in
+            if "response_schema" in optional_params:
+                supports_response_schema = litellm.supports_response_schema(
+                    model=model, custom_llm_provider="vertex_ai"
+                )
+                if supports_response_schema is False:
+                    user_response_schema_message = response_schema_prompt(
+                        model=model, response_schema=optional_params.get("response_schema")  # type: ignore
+                    )
+                    messages.append(
+                        {"role": "user", "content": user_response_schema_message}
+                    )
+                    optional_params.pop("response_schema")
+
+        try:
             content = _gemini_convert_messages_with_history(messages=messages)
             tools: Optional[Tools] = optional_params.pop("tools", None)
             tool_choice: Optional[ToolConfig] = optional_params.pop("tool_choice", None)
@@ -1017,12 +1054,14 @@ class VertexLLM(BaseLLM):
                 data["generationConfig"] = generation_config

             headers = {
-                "Content-Type": "application/json; charset=utf-8",
+                "Content-Type": "application/json",
             }
             if auth_header is not None:
                 headers["Authorization"] = f"Bearer {auth_header}"
             if extra_headers is not None:
                 headers.update(extra_headers)
+        except Exception as e:
+            raise e

         ## LOGGING
         logging_obj.pre_call(
@@ -1270,11 +1309,6 @@ class VertexLLM(BaseLLM):
 class ModelResponseIterator:
     def __init__(self, streaming_response, sync_stream: bool):
         self.streaming_response = streaming_response
-        if sync_stream:
-            self.response_iterator = iter(self.streaming_response)
-
-        self.events = ijson.sendable_list()
-        self.coro = ijson.items_coro(self.events, "item")

     def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
         try:
@@ -1304,9 +1338,9 @@ class ModelResponseIterator:
             if "usageMetadata" in processed_chunk:
                 usage = ChatCompletionUsageBlock(
                     prompt_tokens=processed_chunk["usageMetadata"]["promptTokenCount"],
-                    completion_tokens=processed_chunk["usageMetadata"][
-                        "candidatesTokenCount"
-                    ],
+                    completion_tokens=processed_chunk["usageMetadata"].get(
+                        "candidatesTokenCount", 0
+                    ),
                     total_tokens=processed_chunk["usageMetadata"]["totalTokenCount"],
                 )

@@ -1324,16 +1358,24 @@ class ModelResponseIterator:

     # Sync iterator
     def __iter__(self):
+        self.response_iterator = self.streaming_response
         return self

     def __next__(self):
         try:
             chunk = self.response_iterator.__next__()
-            self.coro.send(chunk)
-            if self.events:
-                event = self.events.pop(0)
-                json_chunk = event
+        except StopIteration:
+            raise StopIteration
+        except ValueError as e:
+            raise RuntimeError(f"Error receiving chunk from stream: {e}")
+
+        try:
+            chunk = chunk.replace("data:", "")
+            chunk = chunk.strip()
+            if len(chunk) > 0:
+                json_chunk = json.loads(chunk)
                 return self.chunk_parser(chunk=json_chunk)
+            else:
                 return GenericStreamingChunk(
                     text="",
                     is_finished=False,
@@ -1343,12 +1385,9 @@ class ModelResponseIterator:
                     tool_use=None,
                 )
         except StopIteration:
-            if self.events:  # flush the events
-                event = self.events.pop(0)  # Remove the first event
-                return self.chunk_parser(chunk=event)
             raise StopIteration
         except ValueError as e:
-            raise RuntimeError(f"Error parsing chunk: {e}")
+            raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")

     # Async iterator
     def __aiter__(self):
@@ -1358,11 +1397,18 @@ class ModelResponseIterator:
     async def __anext__(self):
         try:
             chunk = await self.async_response_iterator.__anext__()
-            self.coro.send(chunk)
-            if self.events:
-                event = self.events.pop(0)
-                json_chunk = event
+        except StopAsyncIteration:
+            raise StopAsyncIteration
+        except ValueError as e:
+            raise RuntimeError(f"Error receiving chunk from stream: {e}")
+
+        try:
+            chunk = chunk.replace("data:", "")
+            chunk = chunk.strip()
+            if len(chunk) > 0:
+                json_chunk = json.loads(chunk)
                 return self.chunk_parser(chunk=json_chunk)
+            else:
                 return GenericStreamingChunk(
                     text="",
                     is_finished=False,
@@ -1372,9 +1418,6 @@ class ModelResponseIterator:
                     tool_use=None,
                 )
         except StopAsyncIteration:
-            if self.events:  # flush the events
-                event = self.events.pop(0)  # Remove the first event
-                return self.chunk_parser(chunk=event)
             raise StopAsyncIteration
         except ValueError as e:
-            raise RuntimeError(f"Error parsing chunk: {e}")
+            raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")

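The iterator above switches from `ijson` byte parsing to `iter_lines()`/`aiter_lines()` plus the `?alt=sse` query parameter, so each streamed line is a `data: {...}` SSE record. A standalone sketch of that parsing step, independent of litellm internals (function and variable names here are illustrative only):

```python
import json
from typing import Iterator, Optional


def parse_sse_lines(lines: Iterator[str]) -> Iterator[Optional[dict]]:
    """Yield one parsed JSON object per non-empty 'data:' line; None for keep-alives."""
    for line in lines:
        line = line.replace("data:", "").strip()
        if not line:
            yield None  # empty keep-alive line - the real iterator emits an empty chunk here
            continue
        yield json.loads(line)


# Example with fake SSE lines shaped roughly like Gemini streaming output.
sample = [
    'data: {"candidates": [{"content": {"parts": [{"text": "Hel"}]}}]}',
    "",
    'data: {"candidates": [{"content": {"parts": [{"text": "lo"}]}}], "usageMetadata": {"promptTokenCount": 5, "totalTokenCount": 7}}',
]
for chunk in parse_sse_lines(iter(sample)):
    print(chunk)
```
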
@@ -476,6 +476,15 @@ def mock_completion(
                 model=model,  # type: ignore
                 request=httpx.Request(method="POST", url="https://api.openai.com/v1/"),
             )
+        elif (
+            isinstance(mock_response, str) and mock_response == "litellm.RateLimitError"
+        ):
+            raise litellm.RateLimitError(
+                message="this is a mock rate limit error",
+                status_code=getattr(mock_response, "status_code", 429),  # type: ignore
+                llm_provider=getattr(mock_response, "llm_provider", custom_llm_provider or "openai"),  # type: ignore
+                model=model,
+            )
         time_delay = kwargs.get("mock_delay", None)
         if time_delay is not None:
             time.sleep(time_delay)
@@ -676,6 +685,8 @@ def completion(
     client = kwargs.get("client", None)
     ### Admin Controls ###
     no_log = kwargs.get("no-log", False)
+    ### COPY MESSAGES ### - related issue https://github.com/BerriAI/litellm/discussions/4489
+    messages = deepcopy(messages)
     ######## end of unpacking kwargs ###########
     openai_params = [
         "functions",
@@ -1828,6 +1839,7 @@ def completion(
                 logging_obj=logging,
                 acompletion=acompletion,
                 timeout=timeout,  # type: ignore
+                custom_llm_provider="openrouter",
             )
             ## LOGGING
             logging.post_call(
@@ -2199,46 +2211,29 @@ def completion(
             # boto3 reads keys from .env
             custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict

+            if "aws_bedrock_client" in optional_params:
+                verbose_logger.warning(
+                    "'aws_bedrock_client' is a deprecated param. Please move to another auth method - https://docs.litellm.ai/docs/providers/bedrock#boto3---authentication."
+                )
+                # Extract credentials for legacy boto3 client and pass thru to httpx
+                aws_bedrock_client = optional_params.pop("aws_bedrock_client")
+                creds = aws_bedrock_client._get_credentials().get_frozen_credentials()
+
+                if creds.access_key:
+                    optional_params["aws_access_key_id"] = creds.access_key
+                if creds.secret_key:
+                    optional_params["aws_secret_access_key"] = creds.secret_key
+                if creds.token:
+                    optional_params["aws_session_token"] = creds.token
                 if (
-                "aws_bedrock_client" in optional_params
-            ):  # use old bedrock flow for aws_bedrock_client users.
-                response = bedrock.completion(
-                    model=model,
-                    messages=messages,
-                    custom_prompt_dict=litellm.custom_prompt_dict,
-                    model_response=model_response,
-                    print_verbose=print_verbose,
-                    optional_params=optional_params,
-                    litellm_params=litellm_params,
-                    logger_fn=logger_fn,
-                    encoding=encoding,
-                    logging_obj=logging,
-                    extra_headers=extra_headers,
-                    timeout=timeout,
+                    "aws_region_name" not in optional_params
+                    or optional_params["aws_region_name"] is None
+                ):
+                    optional_params["aws_region_name"] = (
+                        aws_bedrock_client.meta.region_name
                 )

-            if (
-                "stream" in optional_params
-                and optional_params["stream"] == True
-                and not isinstance(response, CustomStreamWrapper)
-            ):
-                # don't try to access stream object,
-                if "ai21" in model:
-                    response = CustomStreamWrapper(
-                        response,
-                        model,
-                        custom_llm_provider="bedrock",
-                        logging_obj=logging,
-                    )
-                else:
-                    response = CustomStreamWrapper(
-                        iter(response),
-                        model,
-                        custom_llm_provider="bedrock",
-                        logging_obj=logging,
-                    )
-            else:
-                if model.startswith("anthropic"):
+            if model in litellm.BEDROCK_CONVERSE_MODELS:
                 response = bedrock_converse_chat_completion.completion(
                     model=model,
                     messages=messages,
@@ -2272,6 +2267,7 @@ def completion(
                 acompletion=acompletion,
                 client=client,
             )
+
             if optional_params.get("stream", False):
                 ## LOGGING
                 logging.post_call(

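The new `mock_response="litellm.RateLimitError"` branch gives a cheap way to exercise retry and fallback handling without ever hitting a provider. A hedged example of how it could be used in a test (the expectation that a `RateLimitError` is raised follows directly from the branch added above):

```python
import litellm
import pytest


def test_mock_rate_limit_error():
    # No real API call is made - mock_response short-circuits inside mock_completion().
    with pytest.raises(litellm.RateLimitError):
        litellm.completion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "hi"}],
            mock_response="litellm.RateLimitError",
        )
```
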
@@ -1486,6 +1486,7 @@
         "supports_system_messages": true,
         "supports_function_calling": true,
         "supports_tool_choice": true,
+        "supports_response_schema": true,
         "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
     },
     "gemini-1.5-pro-001": {
@@ -1511,6 +1512,7 @@
         "supports_system_messages": true,
         "supports_function_calling": true,
         "supports_tool_choice": true,
+        "supports_response_schema": true,
         "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
     },
     "gemini-1.5-pro-preview-0514": {
@@ -1536,6 +1538,7 @@
         "supports_system_messages": true,
         "supports_function_calling": true,
         "supports_tool_choice": true,
+        "supports_response_schema": true,
         "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
     },
     "gemini-1.5-pro-preview-0215": {
@@ -1561,6 +1564,7 @@
         "supports_system_messages": true,
         "supports_function_calling": true,
         "supports_tool_choice": true,
+        "supports_response_schema": true,
         "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
     },
     "gemini-1.5-pro-preview-0409": {
@@ -1585,6 +1589,7 @@
         "mode": "chat",
         "supports_function_calling": true,
         "supports_tool_choice": true,
+        "supports_response_schema": true,
         "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
     },
     "gemini-1.5-flash": {
@@ -2007,6 +2012,7 @@
         "supports_function_calling": true,
         "supports_vision": true,
         "supports_tool_choice": true,
+        "supports_response_schema": true,
         "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
     },
     "gemini/gemini-1.5-pro-latest": {
@@ -2023,6 +2029,7 @@
         "supports_function_calling": true,
         "supports_vision": true,
         "supports_tool_choice": true,
+        "supports_response_schema": true,
         "source": "https://ai.google.dev/models/gemini"
     },
     "gemini/gemini-pro-vision": {

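With `supports_response_schema` now recorded in the model cost map, callers (and the Vertex code paths earlier in this commit) can branch on it. A small sketch using the `litellm.supports_response_schema()` helper referenced in the diff; the model names are the ones updated above:

```python
import litellm

for model in ["gemini-1.5-pro-001", "gemini-1.5-flash"]:
    supported = litellm.supports_response_schema(
        model=model, custom_llm_provider="vertex_ai"
    )
    print(model, "supports response_schema:", supported)
```
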
@@ -1,54 +1,5 @@
-# model_list:
-#   - model_name: my-fake-model
-#     litellm_params:
-#       model: bedrock/anthropic.claude-3-sonnet-20240229-v1:0
-#       api_key: my-fake-key
-#       aws_bedrock_runtime_endpoint: http://127.0.0.1:8000
-#       mock_response: "Hello world 1"
-#     model_info:
-#       max_input_tokens: 0 # trigger context window fallback
-#   - model_name: my-fake-model
-#     litellm_params:
-#       model: bedrock/anthropic.claude-3-sonnet-20240229-v1:0
-#       api_key: my-fake-key
-#       aws_bedrock_runtime_endpoint: http://127.0.0.1:8000
-#       mock_response: "Hello world 2"
-#     model_info:
-#       max_input_tokens: 0
-
-# router_settings:
-#   enable_pre_call_checks: True
-
-
-# litellm_settings:
-#   failure_callback: ["langfuse"]
-
 model_list:
-  - model_name: summarize
+  - model_name: claude-3-5-sonnet # all requests where model not in your config go to this deployment
     litellm_params:
-      model: openai/gpt-4o
+      model: "openai/*"
-      rpm: 10000
+      mock_response: "litellm.RateLimitError"
-      tpm: 12000000
-      api_key: os.environ/OPENAI_API_KEY
-      mock_response: Hello world 1
-
-  - model_name: summarize-l
-    litellm_params:
-      model: claude-3-5-sonnet-20240620
-      rpm: 4000
-      tpm: 400000
-      api_key: os.environ/ANTHROPIC_API_KEY
-      mock_response: Hello world 2
-
-litellm_settings:
-  num_retries: 3
-  request_timeout: 120
-  allowed_fails: 3
-  # fallbacks: [{"summarize": ["summarize-l", "summarize-xl"]}, {"summarize-l": ["summarize-xl"]}]
-  # context_window_fallbacks: [{"summarize": ["summarize-l", "summarize-xl"]}, {"summarize-l": ["summarize-xl"]}]
-
-
-
-router_settings:
-  routing_strategy: simple-shuffle
-  enable_pre_call_checks: true.

@@ -1,10 +1,11 @@
 model_list:
+  - model_name: claude-3-5-sonnet
+    litellm_params:
+      model: anthropic/claude-3-5-sonnet
   - model_name: gemini-1.5-flash-gemini
     litellm_params:
-      model: gemini/gemini-1.5-flash
-  - model_name: gemini-1.5-flash-gemini
-    litellm_params:
-      model: gemini/gemini-1.5-flash
+      model: vertex_ai_beta/gemini-1.5-flash
+      api_base: https://gateway.ai.cloudflare.com/v1/fa4cdcab1f32b95ca3b53fd36043d691/test/google-vertex-ai/v1/projects/adroit-crow-413218/locations/us-central1/publishers/google/models/gemini-1.5-flash
   - litellm_params:
       api_base: http://0.0.0.0:8080
       api_key: ''

@@ -1622,7 +1622,7 @@ class ProxyException(Exception):
     }


-class CommonProxyErrors(enum.Enum):
+class CommonProxyErrors(str, enum.Enum):
     db_not_connected_error = "DB not connected"
     no_llm_router = "No models configured on proxy"
     not_allowed_access = "Admin-only endpoint. Not allowed to access this."

litellm/proxy/pass_through_endpoints/pass_through_endpoints.py (new file, 173 lines)
@@ -0,0 +1,173 @@
+import ast
+import traceback
+from base64 import b64encode
+
+import httpx
+from fastapi import (
+    APIRouter,
+    Depends,
+    FastAPI,
+    HTTPException,
+    Request,
+    Response,
+    status,
+)
+from fastapi.responses import StreamingResponse
+
+import litellm
+from litellm._logging import verbose_proxy_logger
+from litellm.proxy._types import ProxyException
+from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
+
+async_client = httpx.AsyncClient()
+
+
+async def set_env_variables_in_header(custom_headers: dict):
+    """
+    checks if nay headers on config.yaml are defined as os.environ/COHERE_API_KEY etc
+
+    only runs for headers defined on config.yaml
+
+    example header can be
+
+    {"Authorization": "bearer os.environ/COHERE_API_KEY"}
+    """
+    headers = {}
+    for key, value in custom_headers.items():
+        # langfuse Api requires base64 encoded headers - it's simpleer to just ask litellm users to set their langfuse public and secret keys
+        # we can then get the b64 encoded keys here
+        if key == "LANGFUSE_PUBLIC_KEY" or key == "LANGFUSE_SECRET_KEY":
+            # langfuse requires b64 encoded headers - we construct that here
+            _langfuse_public_key = custom_headers["LANGFUSE_PUBLIC_KEY"]
+            _langfuse_secret_key = custom_headers["LANGFUSE_SECRET_KEY"]
+            if isinstance(
+                _langfuse_public_key, str
+            ) and _langfuse_public_key.startswith("os.environ/"):
+                _langfuse_public_key = litellm.get_secret(_langfuse_public_key)
+            if isinstance(
+                _langfuse_secret_key, str
+            ) and _langfuse_secret_key.startswith("os.environ/"):
+                _langfuse_secret_key = litellm.get_secret(_langfuse_secret_key)
+            headers["Authorization"] = "Basic " + b64encode(
+                f"{_langfuse_public_key}:{_langfuse_secret_key}".encode("utf-8")
+            ).decode("ascii")
+        else:
+            # for all other headers
+            headers[key] = value
+            if isinstance(value, str) and "os.environ/" in value:
+                verbose_proxy_logger.debug(
+                    "pass through endpoint - looking up 'os.environ/' variable"
+                )
+                # get string section that is os.environ/
+                start_index = value.find("os.environ/")
+                _variable_name = value[start_index:]
+
+                verbose_proxy_logger.debug(
+                    "pass through endpoint - getting secret for variable name: %s",
+                    _variable_name,
+                )
+                _secret_value = litellm.get_secret(_variable_name)
+                new_value = value.replace(_variable_name, _secret_value)
+                headers[key] = new_value
+    return headers
+
+
+async def pass_through_request(request: Request, target: str, custom_headers: dict):
+    try:
+
+        url = httpx.URL(target)
+        headers = custom_headers
+
+        request_body = await request.body()
+        _parsed_body = ast.literal_eval(request_body.decode("utf-8"))
+
+        verbose_proxy_logger.debug(
+            "Pass through endpoint sending request to \nURL {}\nheaders: {}\nbody: {}\n".format(
+                url, headers, _parsed_body
+            )
+        )
+
+        response = await async_client.request(
+            method=request.method,
+            url=url,
+            headers=headers,
+            params=request.query_params,
+            json=_parsed_body,
+        )
+
+        if response.status_code != 200:
+            raise HTTPException(status_code=response.status_code, detail=response.text)
+
+        content = await response.aread()
+        return Response(
+            content=content,
+            status_code=response.status_code,
+            headers=dict(response.headers),
+        )
+    except Exception as e:
+        verbose_proxy_logger.error(
+            "litellm.proxy.proxy_server.pass through endpoint(): Exception occured - {}".format(
+                str(e)
+            )
+        )
+        verbose_proxy_logger.debug(traceback.format_exc())
+        if isinstance(e, HTTPException):
+            raise ProxyException(
+                message=getattr(e, "message", str(e.detail)),
+                type=getattr(e, "type", "None"),
+                param=getattr(e, "param", "None"),
+                code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
+            )
+        else:
+            error_msg = f"{str(e)}"
+            raise ProxyException(
+                message=getattr(e, "message", error_msg),
+                type=getattr(e, "type", "None"),
+                param=getattr(e, "param", "None"),
+                code=getattr(e, "status_code", 500),
+            )
+
+
+def create_pass_through_route(endpoint, target, custom_headers=None):
+    async def endpoint_func(request: Request):
+        return await pass_through_request(request, target, custom_headers)
+
+    return endpoint_func
+
+
+async def initialize_pass_through_endpoints(pass_through_endpoints: list):
+
+    verbose_proxy_logger.debug("initializing pass through endpoints")
+    from litellm.proxy._types import CommonProxyErrors, LiteLLMRoutes
+    from litellm.proxy.proxy_server import app, premium_user
+
+    for endpoint in pass_through_endpoints:
+        _target = endpoint.get("target", None)
+        _path = endpoint.get("path", None)
+        _custom_headers = endpoint.get("headers", None)
+        _custom_headers = await set_env_variables_in_header(
+            custom_headers=_custom_headers
+        )
+        _auth = endpoint.get("auth", None)
+        _dependencies = None
+        if _auth is not None and str(_auth).lower() == "true":
+            if premium_user is not True:
+                raise ValueError(
+                    f"Error Setting Authentication on Pass Through Endpoint: {CommonProxyErrors.not_premium_user}"
+                )
+            _dependencies = [Depends(user_api_key_auth)]
+            LiteLLMRoutes.openai_routes.value.append(_path)
+
+        if _target is None:
+            continue
+
+        verbose_proxy_logger.debug("adding pass through endpoint: %s", _path)
+
+        app.add_api_route(
+            path=_path,
+            endpoint=create_pass_through_route(_path, _target, _custom_headers),
+            methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
+            dependencies=_dependencies,
+        )
+
+        verbose_proxy_logger.debug("Added new pass through endpoint: %s", _path)

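Once a pass-through route like `/v1/rerank` is registered (see the proxy config further down in this commit), a client simply calls the proxy path and the body is forwarded to the configured target with the configured headers. A minimal sketch, assuming the proxy runs on localhost:4000 with auth enabled and a virtual key; the key value and request fields below are illustrative placeholders:

```python
import httpx

resp = httpx.post(
    "http://localhost:4000/v1/rerank",
    headers={"Authorization": "Bearer sk-1234"},  # virtual key, checked by user_api_key_auth
    json={
        "model": "rerank-english-v3.0",
        "query": "What is the capital of France?",
        "documents": ["Paris is the capital of France.", "Berlin is in Germany."],
    },
)
print(resp.status_code, resp.json())
```
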
litellm/proxy/prisma_migration.py (new file, 68 lines)
@@ -0,0 +1,68 @@
+# What is this?
+## Script to apply initial prisma migration on Docker setup
+
+import os
+import subprocess
+import sys
+import time
+
+sys.path.insert(
+    0, os.path.abspath("./")
+)  # Adds the parent directory to the system path
+from litellm.proxy.secret_managers.aws_secret_manager import decrypt_env_var
+
+if os.getenv("USE_AWS_KMS", None) is not None and os.getenv("USE_AWS_KMS") == "True":
+    ## V2 IMPLEMENTATION OF AWS KMS - USER WANTS TO DECRYPT MULTIPLE KEYS IN THEIR ENV
+    new_env_var = decrypt_env_var()
+
+    for k, v in new_env_var.items():
+        os.environ[k] = v
+
+# Check if DATABASE_URL is not set
+database_url = os.getenv("DATABASE_URL")
+if not database_url:
+    # Check if all required variables are provided
+    database_host = os.getenv("DATABASE_HOST")
+    database_username = os.getenv("DATABASE_USERNAME")
+    database_password = os.getenv("DATABASE_PASSWORD")
+    database_name = os.getenv("DATABASE_NAME")
+
+    if database_host and database_username and database_password and database_name:
+        # Construct DATABASE_URL from the provided variables
+        database_url = f"postgresql://{database_username}:{database_password}@{database_host}/{database_name}"
+        os.environ["DATABASE_URL"] = database_url
+    else:
+        print(  # noqa
+            "Error: Required database environment variables are not set. Provide a postgres url for DATABASE_URL."  # noqa
+        )
+        exit(1)
+
+# Set DIRECT_URL to the value of DATABASE_URL if it is not set, required for migrations
+direct_url = os.getenv("DIRECT_URL")
+if not direct_url:
+    os.environ["DIRECT_URL"] = database_url
+
+# Apply migrations
+retry_count = 0
+max_retries = 3
+exit_code = 1
+
+while retry_count < max_retries and exit_code != 0:
+    retry_count += 1
+    print(f"Attempt {retry_count}...")  # noqa
+
+    # Run the Prisma db push command
+    result = subprocess.run(
+        ["prisma", "db", "push", "--accept-data-loss"], capture_output=True
+    )
+    exit_code = result.returncode
+
+    if exit_code != 0 and retry_count < max_retries:
+        print("Retrying in 10 seconds...")  # noqa
+        time.sleep(10)
+
+if exit_code != 0:
+    print(f"Unable to push database changes after {max_retries} retries.")  # noqa
+    exit(1)
+
+print("Database push successful!")  # noqa

@@ -442,6 +442,20 @@ def run_server(

         db_connection_pool_limit = 100
         db_connection_timeout = 60
+        ### DECRYPT ENV VAR ###
+
+        from litellm.proxy.secret_managers.aws_secret_manager import decrypt_env_var
+
+        if (
+            os.getenv("USE_AWS_KMS", None) is not None
+            and os.getenv("USE_AWS_KMS") == "True"
+        ):
+            ## V2 IMPLEMENTATION OF AWS KMS - USER WANTS TO DECRYPT MULTIPLE KEYS IN THEIR ENV
+            new_env_var = decrypt_env_var()
+
+            for k, v in new_env_var.items():
+                os.environ[k] = v
+
         if config is not None:
             """
             Allow user to pass in db url via config
@@ -459,6 +473,7 @@ def run_server(

             proxy_config = ProxyConfig()
             _config = asyncio.run(proxy_config.get_config(config_file_path=config))
+
             ### LITELLM SETTINGS ###
             litellm_settings = _config.get("litellm_settings", None)
             if (

@@ -20,11 +20,23 @@ model_list:

 general_settings:
   master_key: sk-1234
-  alerting: ["slack", "email"]
-  public_routes: ["LiteLLMRoutes.public_routes", "/spend/calculate"]
+  pass_through_endpoints:
+    - path: "/v1/rerank"
+      target: "https://api.cohere.com/v1/rerank"
+      auth: true # 👈 Key change to use LiteLLM Auth / Keys
+      headers:
+        Authorization: "bearer os.environ/COHERE_API_KEY"
+        content-type: application/json
+        accept: application/json
+    - path: "/api/public/ingestion"
+      target: "https://us.cloud.langfuse.com/api/public/ingestion"
+      auth: true
+      headers:
+        LANGFUSE_PUBLIC_KEY: "os.environ/LANGFUSE_DEV_PUBLIC_KEY"
+        LANGFUSE_SECRET_KEY: "os.environ/LANGFUSE_DEV_SK_KEY"

 litellm_settings:
+  return_response_headers: true
   success_callback: ["prometheus"]
   callbacks: ["otel", "hide_secrets"]
   failure_callback: ["prometheus"]
@@ -34,6 +46,5 @@ litellm_settings:
     - user
     - metadata
     - metadata.generation_name
-  cache: True

@@ -161,6 +161,9 @@ from litellm.proxy.management_endpoints.key_management_endpoints import (
     router as key_management_router,
 )
 from litellm.proxy.management_endpoints.team_endpoints import router as team_router
+from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
+    initialize_pass_through_endpoints,
+)
 from litellm.proxy.secret_managers.aws_secret_manager import (
     load_aws_kms,
     load_aws_secret_manager,
@@ -590,7 +593,7 @@ async def _PROXY_failure_handler(
             _model_id = _metadata.get("model_info", {}).get("id", "")
             _model_group = _metadata.get("model_group", "")
             api_base = litellm.get_api_base(model=_model, optional_params=_litellm_params)
-            _exception_string = str(_exception)[:500]
+            _exception_string = str(_exception)

             error_log = LiteLLM_ErrorLogs(
                 request_id=str(uuid.uuid4()),
@@ -1856,6 +1859,11 @@ class ProxyConfig:
             user_custom_key_generate = get_instance_fn(
                 value=custom_key_generate, config_file_path=config_file_path
             )
+        ## pass through endpoints
+        if general_settings.get("pass_through_endpoints", None) is not None:
+            await initialize_pass_through_endpoints(
+                pass_through_endpoints=general_settings["pass_through_endpoints"]
+            )
         ## dynamodb
         database_type = general_settings.get("database_type", None)
         if database_type is not None and (
@@ -7503,7 +7511,9 @@ async def login(request: Request):
         # Non SSO -> If user is using UI_USERNAME and UI_PASSWORD they are Proxy admin
         user_role = LitellmUserRoles.PROXY_ADMIN
         user_id = username
-        key_user_id = user_id
+
+        # we want the key created to have PROXY_ADMIN_PERMISSIONS
+        key_user_id = litellm_proxy_admin_name
         if (
             os.getenv("PROXY_ADMIN_ID", None) is not None
             and os.environ["PROXY_ADMIN_ID"] == user_id
@@ -7523,7 +7533,17 @@ async def login(request: Request):
         if os.getenv("DATABASE_URL") is not None:
             response = await generate_key_helper_fn(
                 request_type="key",
-                **{"user_role": LitellmUserRoles.PROXY_ADMIN, "duration": "2hr", "key_max_budget": 5, "models": [], "aliases": {}, "config": {}, "spend": 0, "user_id": key_user_id, "team_id": "litellm-dashboard"},  # type: ignore
+                **{
+                    "user_role": LitellmUserRoles.PROXY_ADMIN,
+                    "duration": "2hr",
+                    "key_max_budget": 5,
+                    "models": [],
+                    "aliases": {},
+                    "config": {},
+                    "spend": 0,
+                    "user_id": key_user_id,
+                    "team_id": "litellm-dashboard",
+                },  # type: ignore
             )
         else:
             raise ProxyException(

@@ -8,9 +8,13 @@ Requires:
 * `pip install boto3>=1.28.57`
 """

-import litellm
+import ast
+import base64
 import os
-from typing import Optional
+import re
+from typing import Any, Dict, Optional
+
+import litellm
 from litellm.proxy._types import KeyManagementSystem


@@ -57,3 +61,99 @@ def load_aws_kms(use_aws_kms: Optional[bool]):

     except Exception as e:
         raise e
+
+
+class AWSKeyManagementService_V2:
+    """
+    V2 Clean Class for decrypting keys from AWS KeyManagementService
+    """
+
+    def __init__(self) -> None:
+        self.validate_environment()
+        self.kms_client = self.load_aws_kms(use_aws_kms=True)
+
+    def validate_environment(
+        self,
+    ):
+        if "AWS_REGION_NAME" not in os.environ:
+            raise ValueError("Missing required environment variable - AWS_REGION_NAME")
+
+        ## CHECK IF LICENSE IN ENV ## - premium feature
+        if os.getenv("LITELLM_LICENSE", None) is None:
+            raise ValueError(
+                "AWSKeyManagementService V2 is an Enterprise Feature. Please add a valid LITELLM_LICENSE to your envionment."
+            )
+
+    def load_aws_kms(self, use_aws_kms: Optional[bool]):
+        if use_aws_kms is None or use_aws_kms is False:
+            return
+        try:
+            import boto3
+
+            validate_environment()
+
+            # Create a Secrets Manager client
+            kms_client = boto3.client("kms", region_name=os.getenv("AWS_REGION_NAME"))
+
+            return kms_client
+        except Exception as e:
+            raise e
+
+    def decrypt_value(self, secret_name: str) -> Any:
+        if self.kms_client is None:
+            raise ValueError("kms_client is None")
+        encrypted_value = os.getenv(secret_name, None)
+        if encrypted_value is None:
+            raise Exception(
+                "AWS KMS - Encrypted Value of Key={} is None".format(secret_name)
+            )
+        if isinstance(encrypted_value, str) and encrypted_value.startswith("aws_kms/"):
+            encrypted_value = encrypted_value.replace("aws_kms/", "")
+
+        # Decode the base64 encoded ciphertext
+        ciphertext_blob = base64.b64decode(encrypted_value)
+
+        # Set up the parameters for the decrypt call
+        params = {"CiphertextBlob": ciphertext_blob}
+        # Perform the decryption
+        response = self.kms_client.decrypt(**params)
+
+        # Extract and decode the plaintext
+        plaintext = response["Plaintext"]
+        secret = plaintext.decode("utf-8")
+        if isinstance(secret, str):
+            secret = secret.strip()
+        try:
+            secret_value_as_bool = ast.literal_eval(secret)
+            if isinstance(secret_value_as_bool, bool):
+                return secret_value_as_bool
+        except Exception:
+            pass
+
+        return secret
+
+
+"""
+- look for all values in the env with `aws_kms/<hashed_key>`
+- decrypt keys
+- rewrite env var with decrypted key (). Note: this environment variable will only be available to the current process and any child processes spawned from it. Once the Python script ends, the environment variable will not persist.
+"""
+
+
+def decrypt_env_var() -> Dict[str, Any]:
+    # setup client class
+    aws_kms = AWSKeyManagementService_V2()
+    # iterate through env - for `aws_kms/`
+    new_values = {}
+    for k, v in os.environ.items():
+        if (
+            k is not None
+            and isinstance(k, str)
+            and k.lower().startswith("litellm_secret_aws_kms")
+        ) or (v is not None and isinstance(v, str) and v.startswith("aws_kms/")):
+            decrypted_value = aws_kms.decrypt_value(secret_name=k)
+            # reset env var
+            k = re.sub("litellm_secret_aws_kms_", "", k, flags=re.IGNORECASE)
+            new_values[k] = decrypted_value
+
+    return new_values

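A sketch of how `decrypt_env_var()` is meant to be driven purely by environment variables: keys prefixed with `LITELLM_SECRET_AWS_KMS_` (or values prefixed with `aws_kms/`) are decrypted and returned under the stripped name. This assumes `AWS_REGION_NAME`, AWS credentials and a valid `LITELLM_LICENSE` are set; the encrypted value shown is a placeholder, not real ciphertext:

```python
import os

from litellm.proxy.secret_managers.aws_secret_manager import decrypt_env_var

# Placeholder ciphertext - in practice this is base64 KMS output, optionally "aws_kms/"-prefixed.
os.environ["LITELLM_SECRET_AWS_KMS_DATABASE_URL"] = "aws_kms/AQICAHh...base64..."

decrypted = decrypt_env_var()
# Expected shape: {"DATABASE_URL": "<plaintext>"}; the caller writes it back into os.environ.
for name, value in decrypted.items():
    os.environ[name] = value
```
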
@@ -817,9 +817,9 @@ async def get_global_spend_report(
         default=None,
         description="Time till which to view spend",
     ),
-    group_by: Optional[Literal["team", "customer"]] = fastapi.Query(
+    group_by: Optional[Literal["team", "customer", "api_key"]] = fastapi.Query(
         default="team",
-        description="Group spend by internal team or customer",
+        description="Group spend by internal team or customer or api_key",
     ),
 ):
     """
@@ -860,7 +860,7 @@ async def get_global_spend_report(
     start_date_obj = datetime.strptime(start_date, "%Y-%m-%d")
     end_date_obj = datetime.strptime(end_date, "%Y-%m-%d")
 
-    from litellm.proxy.proxy_server import prisma_client
+    from litellm.proxy.proxy_server import premium_user, prisma_client
 
     try:
         if prisma_client is None:
@@ -868,6 +868,12 @@ async def get_global_spend_report(
                 f"Database not connected. Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys"
             )
 
+        if premium_user is not True:
+            verbose_proxy_logger.debug("accessing /spend/report but not a premium user")
+            raise ValueError(
+                "/spend/report endpoint " + CommonProxyErrors.not_premium_user.value
+            )
+
         if group_by == "team":
             # first get data from spend logs -> SpendByModelApiKey
             # then read data from "SpendByModelApiKey" to format the response obj
@@ -992,6 +998,48 @@ async def get_global_spend_report(
                 return []
 
             return db_response
+        elif group_by == "api_key":
+            sql_query = """
+                WITH SpendByModelApiKey AS (
+                    SELECT
+                        sl.api_key,
+                        sl.model,
+                        SUM(sl.spend) AS model_cost,
+                        SUM(sl.prompt_tokens) AS model_input_tokens,
+                        SUM(sl.completion_tokens) AS model_output_tokens
+                    FROM
+                        "LiteLLM_SpendLogs" sl
+                    WHERE
+                        sl."startTime" BETWEEN $1::date AND $2::date
+                    GROUP BY
+                        sl.api_key,
+                        sl.model
+                )
+                SELECT
+                    api_key,
+                    SUM(model_cost) AS total_cost,
+                    SUM(model_input_tokens) AS total_input_tokens,
+                    SUM(model_output_tokens) AS total_output_tokens,
+                    jsonb_agg(jsonb_build_object(
+                        'model', model,
+                        'total_cost', model_cost,
+                        'total_input_tokens', model_input_tokens,
+                        'total_output_tokens', model_output_tokens
+                    )) AS model_details
+                FROM
+                    SpendByModelApiKey
+                GROUP BY
+                    api_key
+                ORDER BY
+                    total_cost DESC;
+            """
+            db_response = await prisma_client.db.query_raw(
+                sql_query, start_date_obj, end_date_obj
+            )
+            if db_response is None:
+                return []
+
+            return db_response
+
         except Exception as e:
             raise HTTPException(
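A minimal sketch of calling the endpoint with the new grouping (proxy URL and master key below are placeholders; per the SQL above, each row carries `api_key`, `total_cost`, token totals, and per-model `model_details`; note the route is gated to premium users):

    import requests

    resp = requests.get(
        "http://localhost:4000/spend/report",  # placeholder proxy base URL
        params={
            "start_date": "2024-06-01",
            "end_date": "2024-06-30",
            "group_by": "api_key",
        },
        headers={"Authorization": "Bearer sk-1234"},  # placeholder master key
    )
    for row in resp.json():
        print(row["api_key"], row["total_cost"], row["model_details"])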
14 litellm/proxy/tests/test_pass_through_langfuse.py Normal file

@@ -0,0 +1,14 @@
+from langfuse import Langfuse
+
+langfuse = Langfuse(
+    host="http://localhost:4000",
+    public_key="anything",
+    secret_key="anything",
+)
+
+print("sending langfuse trace request")
+trace = langfuse.trace(name="test-trace-litellm-proxy-passthrough")
+print("flushing langfuse request")
+langfuse.flush()
+
+print("flushed langfuse request")
@@ -105,7 +105,9 @@ class Router:
 
     def __init__(
         self,
-        model_list: Optional[List[Union[DeploymentTypedDict, Dict]]] = None,
+        model_list: Optional[
+            Union[List[DeploymentTypedDict], List[Dict[str, Any]]]
+        ] = None,
         ## ASSISTANTS API ##
         assistants_config: Optional[AssistantsTypedDict] = None,
         ## CACHING ##
@@ -154,6 +156,7 @@ class Router:
         cooldown_time: Optional[
             float
         ] = None,  # (seconds) time to cooldown a deployment after failure
+        disable_cooldowns: Optional[bool] = None,
         routing_strategy: Literal[
             "simple-shuffle",
             "least-busy",
@@ -305,6 +308,7 @@ class Router:
 
         self.allowed_fails = allowed_fails or litellm.allowed_fails
         self.cooldown_time = cooldown_time or 60
+        self.disable_cooldowns = disable_cooldowns
         self.failed_calls = (
             InMemoryCache()
         )  # cache to track failed call per deployment, if num failed calls within 1 minute > allowed fails, then add it to cooldown
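A minimal sketch of the new constructor flag wired up above — with `disable_cooldowns=True`, `_set_cooldown_deployments` returns early, so failing deployments are never added to the per-minute cooldown list (model name and key are placeholders):

    from litellm import Router

    router = Router(
        model_list=[
            {
                "model_name": "gpt-3.5-turbo",
                "litellm_params": {"model": "gpt-3.5-turbo", "api_key": "sk-placeholder"},
            }
        ],
        disable_cooldowns=True,  # skip cooldown tracking entirely
    )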
@@ -2988,6 +2992,8 @@ class Router:
 
         the exception is not one that should be immediately retried (e.g. 401)
         """
+        if self.disable_cooldowns is True:
+            return
 
         if deployment is None:
             return
@@ -3028,24 +3034,50 @@ class Router:
                 exception_status = 500
         _should_retry = litellm._should_retry(status_code=exception_status)
 
-        if updated_fails > allowed_fails or _should_retry == False:
+        if updated_fails > allowed_fails or _should_retry is False:
             # get the current cooldown list for that minute
             cooldown_key = f"{current_minute}:cooldown_models"  # group cooldown models by minute to reduce number of redis calls
-            cached_value = self.cache.get_cache(key=cooldown_key)
+            cached_value = self.cache.get_cache(
+                key=cooldown_key
+            )  # [(deployment_id, {last_error_str, last_error_status_code})]
 
+            cached_value_deployment_ids = []
+            if (
+                cached_value is not None
+                and isinstance(cached_value, list)
+                and len(cached_value) > 0
+                and isinstance(cached_value[0], tuple)
+            ):
+                cached_value_deployment_ids = [cv[0] for cv in cached_value]
             verbose_router_logger.debug(f"adding {deployment} to cooldown models")
             # update value
-            try:
-                if deployment in cached_value:
+            if cached_value is not None and len(cached_value_deployment_ids) > 0:
+                if deployment in cached_value_deployment_ids:
                     pass
                 else:
-                    cached_value = cached_value + [deployment]
+                    cached_value = cached_value + [
+                        (
+                            deployment,
+                            {
+                                "Exception Received": str(original_exception),
+                                "Status Code": str(exception_status),
+                            },
+                        )
+                    ]
                     # save updated value
                     self.cache.set_cache(
                         value=cached_value, key=cooldown_key, ttl=cooldown_time
                     )
-            except:
-                cached_value = [deployment]
+            else:
+                cached_value = [
+                    (
+                        deployment,
+                        {
+                            "Exception Received": str(original_exception),
+                            "Status Code": str(exception_status),
+                        },
+                    )
+                ]
                 # save updated value
                 self.cache.set_cache(
                     value=cached_value, key=cooldown_key, ttl=cooldown_time
@@ -3061,7 +3093,33 @@ class Router:
                 key=deployment, value=updated_fails, ttl=cooldown_time
             )
 
-    async def _async_get_cooldown_deployments(self):
+    async def _async_get_cooldown_deployments(self) -> List[str]:
+        """
+        Async implementation of '_get_cooldown_deployments'
+        """
+        dt = get_utc_datetime()
+        current_minute = dt.strftime("%H-%M")
+        # get the current cooldown list for that minute
+        cooldown_key = f"{current_minute}:cooldown_models"
+
+        # ----------------------
+        # Return cooldown models
+        # ----------------------
+        cooldown_models = await self.cache.async_get_cache(key=cooldown_key) or []
+
+        cached_value_deployment_ids = []
+        if (
+            cooldown_models is not None
+            and isinstance(cooldown_models, list)
+            and len(cooldown_models) > 0
+            and isinstance(cooldown_models[0], tuple)
+        ):
+            cached_value_deployment_ids = [cv[0] for cv in cooldown_models]
+
+        verbose_router_logger.debug(f"retrieve cooldown models: {cooldown_models}")
+        return cached_value_deployment_ids
+
+    async def _async_get_cooldown_deployments_with_debug_info(self) -> List[tuple]:
         """
         Async implementation of '_get_cooldown_deployments'
         """
@@ -3078,7 +3136,7 @@ class Router:
         verbose_router_logger.debug(f"retrieve cooldown models: {cooldown_models}")
         return cooldown_models
 
-    def _get_cooldown_deployments(self):
+    def _get_cooldown_deployments(self) -> List[str]:
         """
         Get the list of models being cooled down for this minute
         """
@@ -3092,8 +3150,17 @@ class Router:
         # ----------------------
         cooldown_models = self.cache.get_cache(key=cooldown_key) or []
 
+        cached_value_deployment_ids = []
+        if (
+            cooldown_models is not None
+            and isinstance(cooldown_models, list)
+            and len(cooldown_models) > 0
+            and isinstance(cooldown_models[0], tuple)
+        ):
+            cached_value_deployment_ids = [cv[0] for cv in cooldown_models]
+
         verbose_router_logger.debug(f"retrieve cooldown models: {cooldown_models}")
-        return cooldown_models
+        return cached_value_deployment_ids
 
     def _get_healthy_deployments(self, model: str):
         _all_deployments: list = []
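For reference, a sketch of the value shape the cooldown cache stores after this change — a list of `(deployment_id, debug_info)` tuples — and the id-extraction the getters apply so callers still receive a flat list of deployment ids (the ids and errors below are illustrative):

    # shape written by _set_cooldown_deployments
    cooldown_models = [
        ("deployment-id-1", {"Exception Received": "Timeout ...", "Status Code": "408"}),
        ("deployment-id-2", {"Exception Received": "APIError ...", "Status Code": "500"}),
    ]

    # _get_cooldown_deployments / _async_get_cooldown_deployments keep only the ids
    cooldown_deployment_ids = [cv[0] for cv in cooldown_models]
    assert cooldown_deployment_ids == ["deployment-id-1", "deployment-id-2"]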
@@ -3970,16 +4037,36 @@ class Router:
 
         Augment litellm info with additional params set in `model_info`.
 
+        For azure models, ignore the `model:`. Only set max tokens, cost values if base_model is set.
+
         Returns
         - ModelInfo - If found -> typed dict with max tokens, input cost, etc.
+
+        Raises:
+        - ValueError -> If model is not mapped yet
         """
-        ## SET MODEL NAME
+        ## GET BASE MODEL
         base_model = deployment.get("model_info", {}).get("base_model", None)
         if base_model is None:
             base_model = deployment.get("litellm_params", {}).get("base_model", None)
-        model = base_model or deployment.get("litellm_params", {}).get("model", None)
 
-        ## GET LITELLM MODEL INFO
+        model = base_model
+
+        ## GET PROVIDER
+        _model, custom_llm_provider, _, _ = litellm.get_llm_provider(
+            model=deployment.get("litellm_params", {}).get("model", ""),
+            litellm_params=LiteLLM_Params(**deployment.get("litellm_params", {})),
+        )
+
+        ## SET MODEL TO 'model=' - if base_model is None + not azure
+        if custom_llm_provider == "azure" and base_model is None:
+            verbose_router_logger.error(
+                "Could not identify azure model. Set azure 'base_model' for accurate max tokens, cost tracking, etc.- https://docs.litellm.ai/docs/proxy/cost_tracking#spend-tracking-for-azure-openai-models"
+            )
+        elif custom_llm_provider != "azure":
+            model = _model
+
+        ## GET LITELLM MODEL INFO - raises exception, if model is not mapped
         model_info = litellm.get_model_info(model=model)
 
         ## CHECK USER SET MODEL INFO
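A minimal sketch of the deployment shape `get_router_model_info` now expects for Azure — without `base_model` in `model_info` (or `litellm_params`), the cost / max-token lookup logs the error above instead of resolving model info (deployment name, key, and base are placeholders):

    azure_deployment = {
        "model_name": "gpt-4",
        "litellm_params": {
            "model": "azure/my-gpt-4-deployment",  # placeholder Azure deployment name
            "api_key": "azure-key",
            "api_base": "https://example.openai.azure.com",
        },
        "model_info": {"base_model": "azure/gpt-4-0125-preview"},  # enables model-info lookup
    }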
@@ -4365,7 +4452,7 @@ class Router:
         """
         Filter out model in model group, if:
 
-        - model context window < message length
+        - model context window < message length. For azure openai models, requires 'base_model' is set. - https://docs.litellm.ai/docs/proxy/cost_tracking#spend-tracking-for-azure-openai-models
         - filter models above rpm limits
         - if region given, filter out models not in that region / unknown region
         - [TODO] function call and model doesn't support function calling
@@ -4382,6 +4469,11 @@ class Router:
         try:
             input_tokens = litellm.token_counter(messages=messages)
         except Exception as e:
+            verbose_router_logger.error(
+                "litellm.router.py::_pre_call_checks: failed to count tokens. Returning initial list of deployments. Got - {}".format(
+                    str(e)
+                )
+            )
             return _returned_deployments
 
         _context_window_error = False
@@ -4425,7 +4517,7 @@ class Router:
                     )
                     continue
             except Exception as e:
-                verbose_router_logger.debug("An error occurs - {}".format(str(e)))
+                verbose_router_logger.error("An error occurs - {}".format(str(e)))
 
             _litellm_params = deployment.get("litellm_params", {})
             model_id = deployment.get("model_info", {}).get("id", "")
@@ -4686,7 +4778,7 @@ class Router:
             if _allowed_model_region is None:
                 _allowed_model_region = "n/a"
             raise ValueError(
-                f"{RouterErrors.no_deployments_available.value}, Try again in {self.cooldown_time} seconds. Passed model={model}. pre-call-checks={self.enable_pre_call_checks}, allowed_model_region={_allowed_model_region}"
+                f"{RouterErrors.no_deployments_available.value}, Try again in {self.cooldown_time} seconds. Passed model={model}. pre-call-checks={self.enable_pre_call_checks}, allowed_model_region={_allowed_model_region}, cooldown_list={await self._async_get_cooldown_deployments_with_debug_info()}"
             )
 
         if (
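A minimal sketch of turning on the pre-call checks documented above; with `enable_pre_call_checks=True` the router filters out deployments whose context window is smaller than the prompt (for Azure this again relies on `base_model` being set):

    from litellm import Router

    router = Router(
        model_list=[
            {
                "model_name": "gpt-4",
                "litellm_params": {
                    "model": "azure/my-gpt-4-deployment",  # placeholder
                    "api_key": "azure-key",
                    "api_base": "https://example.openai.azure.com",
                },
                "model_info": {"base_model": "azure/gpt-4-0125-preview"},
            }
        ],
        enable_pre_call_checks=True,  # filter by context window / rpm / region before routing
    )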
@@ -880,6 +880,208 @@ Using this JSON schema:
         mock_call.assert_called_once()
 
 
+def vertex_httpx_mock_post_valid_response(*args, **kwargs):
+    mock_response = MagicMock()
+    mock_response.status_code = 200
+    mock_response.headers = {"Content-Type": "application/json"}
+    mock_response.json.return_value = {
+        "candidates": [
+            {
+                "content": {
+                    "role": "model",
+                    "parts": [
+                        {
+                            "text": '[{"recipe_name": "Chocolate Chip Cookies"}, {"recipe_name": "Oatmeal Raisin Cookies"}, {"recipe_name": "Peanut Butter Cookies"}, {"recipe_name": "Sugar Cookies"}, {"recipe_name": "Snickerdoodles"}]\n'
+                        }
+                    ],
+                },
+                "finishReason": "STOP",
+                "safetyRatings": [
+                    {
+                        "category": "HARM_CATEGORY_HATE_SPEECH",
+                        "probability": "NEGLIGIBLE",
+                        "probabilityScore": 0.09790669,
+                        "severity": "HARM_SEVERITY_NEGLIGIBLE",
+                        "severityScore": 0.11736965,
+                    },
+                    {
+                        "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
+                        "probability": "NEGLIGIBLE",
+                        "probabilityScore": 0.1261379,
+                        "severity": "HARM_SEVERITY_NEGLIGIBLE",
+                        "severityScore": 0.08601588,
+                    },
+                    {
+                        "category": "HARM_CATEGORY_HARASSMENT",
+                        "probability": "NEGLIGIBLE",
+                        "probabilityScore": 0.083441176,
+                        "severity": "HARM_SEVERITY_NEGLIGIBLE",
+                        "severityScore": 0.0355444,
+                    },
+                    {
+                        "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
+                        "probability": "NEGLIGIBLE",
+                        "probabilityScore": 0.071981624,
+                        "severity": "HARM_SEVERITY_NEGLIGIBLE",
+                        "severityScore": 0.08108212,
+                    },
+                ],
+            }
+        ],
+        "usageMetadata": {
+            "promptTokenCount": 60,
+            "candidatesTokenCount": 55,
+            "totalTokenCount": 115,
+        },
+    }
+    return mock_response
+
+
+def vertex_httpx_mock_post_invalid_schema_response(*args, **kwargs):
+    mock_response = MagicMock()
+    mock_response.status_code = 200
+    mock_response.headers = {"Content-Type": "application/json"}
+    mock_response.json.return_value = {
+        "candidates": [
+            {
+                "content": {
+                    "role": "model",
+                    "parts": [
+                        {"text": '[{"recipe_world": "Chocolate Chip Cookies"}]\n'}
+                    ],
+                },
+                "finishReason": "STOP",
+                "safetyRatings": [
+                    {
+                        "category": "HARM_CATEGORY_HATE_SPEECH",
+                        "probability": "NEGLIGIBLE",
+                        "probabilityScore": 0.09790669,
+                        "severity": "HARM_SEVERITY_NEGLIGIBLE",
+                        "severityScore": 0.11736965,
+                    },
+                    {
+                        "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
+                        "probability": "NEGLIGIBLE",
+                        "probabilityScore": 0.1261379,
+                        "severity": "HARM_SEVERITY_NEGLIGIBLE",
+                        "severityScore": 0.08601588,
+                    },
+                    {
+                        "category": "HARM_CATEGORY_HARASSMENT",
+                        "probability": "NEGLIGIBLE",
+                        "probabilityScore": 0.083441176,
+                        "severity": "HARM_SEVERITY_NEGLIGIBLE",
+                        "severityScore": 0.0355444,
+                    },
+                    {
+                        "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
+                        "probability": "NEGLIGIBLE",
+                        "probabilityScore": 0.071981624,
+                        "severity": "HARM_SEVERITY_NEGLIGIBLE",
+                        "severityScore": 0.08108212,
+                    },
+                ],
+            }
+        ],
+        "usageMetadata": {
+            "promptTokenCount": 60,
+            "candidatesTokenCount": 55,
+            "totalTokenCount": 115,
+        },
+    }
+    return mock_response
+
+
+@pytest.mark.parametrize(
+    "model, vertex_location, supports_response_schema",
+    [
+        ("vertex_ai_beta/gemini-1.5-pro-001", "us-central1", True),
+        ("vertex_ai_beta/gemini-1.5-flash", "us-central1", False),
+    ],
+)
+@pytest.mark.parametrize(
+    "invalid_response",
+    [True, False],
+)
+@pytest.mark.parametrize(
+    "enforce_validation",
+    [True, False],
+)
+@pytest.mark.asyncio
+async def test_gemini_pro_json_schema_args_sent_httpx(
+    model,
+    supports_response_schema,
+    vertex_location,
+    invalid_response,
+    enforce_validation,
+):
+    load_vertex_ai_credentials()
+    os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
+    litellm.model_cost = litellm.get_model_cost_map(url="")
+
+    litellm.set_verbose = True
+    messages = [{"role": "user", "content": "List 5 cookie recipes"}]
+    from litellm.llms.custom_httpx.http_handler import HTTPHandler
+
+    response_schema = {
+        "type": "array",
+        "items": {
+            "type": "object",
+            "properties": {
+                "recipe_name": {
+                    "type": "string",
+                },
+            },
+            "required": ["recipe_name"],
+        },
+    }
+
+    client = HTTPHandler()
+
+    httpx_response = MagicMock()
+    if invalid_response is True:
+        httpx_response.side_effect = vertex_httpx_mock_post_invalid_schema_response
+    else:
+        httpx_response.side_effect = vertex_httpx_mock_post_valid_response
+    with patch.object(client, "post", new=httpx_response) as mock_call:
+        try:
+            _ = completion(
+                model=model,
+                messages=messages,
+                response_format={
+                    "type": "json_object",
+                    "response_schema": response_schema,
+                    "enforce_validation": enforce_validation,
+                },
+                vertex_location=vertex_location,
+                client=client,
+            )
+            if invalid_response is True and enforce_validation is True:
+                pytest.fail("Expected this to fail")
+        except litellm.JSONSchemaValidationError as e:
+            if invalid_response is False and "claude-3" not in model:
+                pytest.fail("Expected this to pass. Got={}".format(e))
+
+    mock_call.assert_called_once()
+    print(mock_call.call_args.kwargs)
+    print(mock_call.call_args.kwargs["json"]["generationConfig"])
+
+    if supports_response_schema:
+        assert (
+            "response_schema"
+            in mock_call.call_args.kwargs["json"]["generationConfig"]
+        )
+    else:
+        assert (
+            "response_schema"
+            not in mock_call.call_args.kwargs["json"]["generationConfig"]
+        )
+        assert (
+            "Use this JSON schema:"
+            in mock_call.call_args.kwargs["json"]["contents"][0]["parts"][1]["text"]
+        )
+
+
 @pytest.mark.parametrize("provider", ["vertex_ai_beta"])  # "vertex_ai",
 @pytest.mark.asyncio
 async def test_gemini_pro_httpx_custom_api_base(provider):
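Mirroring the test above, a minimal sketch of the caller-side usage — `response_schema` (and optionally `enforce_validation`) ride along in `response_format`; when validation is enforced and the returned JSON does not match the schema, litellm raises `JSONSchemaValidationError`:

    import litellm

    response_schema = {
        "type": "array",
        "items": {
            "type": "object",
            "properties": {"recipe_name": {"type": "string"}},
            "required": ["recipe_name"],
        },
    }

    try:
        resp = litellm.completion(
            model="vertex_ai_beta/gemini-1.5-pro-001",
            messages=[{"role": "user", "content": "List 5 cookie recipes"}],
            response_format={
                "type": "json_object",
                "response_schema": response_schema,
                "enforce_validation": True,
            },
        )
    except litellm.JSONSchemaValidationError as e:
        print("model output did not match the schema:", e)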
@@ -25,6 +25,7 @@ from litellm import (
     completion_cost,
     embedding,
 )
+from litellm.llms.bedrock_httpx import BedrockLLM
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
 
 # litellm.num_retries = 3
@@ -217,6 +218,234 @@ def test_completion_bedrock_claude_sts_client_auth():
         pytest.fail(f"Error occurred: {e}")
 
 
+@pytest.fixture()
+def bedrock_session_token_creds():
+    print("\ncalling oidc auto to get aws_session_token credentials")
+    import os
+
+    aws_region_name = os.environ["AWS_REGION_NAME"]
+    aws_session_token = os.environ.get("AWS_SESSION_TOKEN")
+
+    bllm = BedrockLLM()
+    if aws_session_token is not None:
+        # For local testing
+        creds = bllm.get_credentials(
+            aws_region_name=aws_region_name,
+            aws_access_key_id=os.environ["AWS_ACCESS_KEY_ID"],
+            aws_secret_access_key=os.environ["AWS_SECRET_ACCESS_KEY"],
+            aws_session_token=aws_session_token,
+        )
+    else:
+        # For circle-ci testing
+        # aws_role_name = os.environ["AWS_TEMP_ROLE_NAME"]
+        # TODO: This is using ai.moda's IAM role, we should use LiteLLM's IAM role eventually
+        aws_role_name = (
+            "arn:aws:iam::335785316107:role/litellm-github-unit-tests-circleci"
+        )
+        aws_web_identity_token = "oidc/circleci_v2/"
+
+        creds = bllm.get_credentials(
+            aws_region_name=aws_region_name,
+            aws_web_identity_token=aws_web_identity_token,
+            aws_role_name=aws_role_name,
+            aws_session_name="my-test-session",
+        )
+    return creds
+
+
+def process_stream_response(res, messages):
+    import types
+
+    if isinstance(res, litellm.utils.CustomStreamWrapper):
+        chunks = []
+        for part in res:
+            chunks.append(part)
+            text = part.choices[0].delta.content or ""
+            print(text, end="")
+        res = litellm.stream_chunk_builder(chunks, messages=messages)
+    else:
+        raise ValueError("Response object is not a streaming response")
+
+    return res
+
+
+@pytest.mark.skipif(
+    os.environ.get("CIRCLE_OIDC_TOKEN_V2") is None,
+    reason="Cannot run without being in CircleCI Runner",
+)
+def test_completion_bedrock_claude_aws_session_token(bedrock_session_token_creds):
+    print("\ncalling bedrock claude with aws_session_token auth")
+
+    import os
+
+    aws_region_name = os.environ["AWS_REGION_NAME"]
+    aws_access_key_id = bedrock_session_token_creds.access_key
+    aws_secret_access_key = bedrock_session_token_creds.secret_key
+    aws_session_token = bedrock_session_token_creds.token
+
+    try:
+        litellm.set_verbose = True
+
+        response_1 = completion(
+            model="bedrock/anthropic.claude-3-haiku-20240307-v1:0",
+            messages=messages,
+            max_tokens=10,
+            temperature=0.1,
+            aws_region_name=aws_region_name,
+            aws_access_key_id=aws_access_key_id,
+            aws_secret_access_key=aws_secret_access_key,
+            aws_session_token=aws_session_token,
+        )
+        print(response_1)
+        assert len(response_1.choices) > 0
+        assert len(response_1.choices[0].message.content) > 0
+
+        # This second call is to verify that the cache isn't breaking anything
+        response_2 = completion(
+            model="bedrock/anthropic.claude-3-haiku-20240307-v1:0",
+            messages=messages,
+            max_tokens=5,
+            temperature=0.2,
+            aws_region_name=aws_region_name,
+            aws_access_key_id=aws_access_key_id,
+            aws_secret_access_key=aws_secret_access_key,
+            aws_session_token=aws_session_token,
+        )
+        print(response_2)
+        assert len(response_2.choices) > 0
+        assert len(response_2.choices[0].message.content) > 0
+
+        # This third call is to verify that the cache isn't used for a different region
+        response_3 = completion(
+            model="bedrock/anthropic.claude-3-haiku-20240307-v1:0",
+            messages=messages,
+            max_tokens=6,
+            temperature=0.3,
+            aws_region_name="us-east-1",
+            aws_access_key_id=aws_access_key_id,
+            aws_secret_access_key=aws_secret_access_key,
+            aws_session_token=aws_session_token,
+        )
+        print(response_3)
+        assert len(response_3.choices) > 0
+        assert len(response_3.choices[0].message.content) > 0
+
+        # This fourth call is to verify streaming api works
+        response_4 = completion(
+            model="bedrock/anthropic.claude-3-haiku-20240307-v1:0",
+            messages=messages,
+            max_tokens=6,
+            temperature=0.3,
+            aws_region_name="us-east-1",
+            aws_access_key_id=aws_access_key_id,
+            aws_secret_access_key=aws_secret_access_key,
+            aws_session_token=aws_session_token,
+            stream=True,
+        )
+        response_4 = process_stream_response(response_4, messages)
+        print(response_4)
+        assert len(response_4.choices) > 0
+        assert len(response_4.choices[0].message.content) > 0
+
+    except RateLimitError:
+        pass
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
+@pytest.mark.skipif(
+    os.environ.get("CIRCLE_OIDC_TOKEN_V2") is None,
+    reason="Cannot run without being in CircleCI Runner",
+)
+def test_completion_bedrock_claude_aws_bedrock_client(bedrock_session_token_creds):
+    print("\ncalling bedrock claude with aws_session_token auth")
+
+    import os
+
+    import boto3
+    from botocore.client import Config
+
+    aws_region_name = os.environ["AWS_REGION_NAME"]
+    aws_access_key_id = bedrock_session_token_creds.access_key
+    aws_secret_access_key = bedrock_session_token_creds.secret_key
+    aws_session_token = bedrock_session_token_creds.token
+
+    aws_bedrock_client_west = boto3.client(
+        service_name="bedrock-runtime",
+        region_name=aws_region_name,
+        aws_access_key_id=aws_access_key_id,
+        aws_secret_access_key=aws_secret_access_key,
+        aws_session_token=aws_session_token,
+        config=Config(read_timeout=600),
+    )
+
+    try:
+        litellm.set_verbose = True
+
+        response_1 = completion(
+            model="bedrock/anthropic.claude-3-haiku-20240307-v1:0",
+            messages=messages,
+            max_tokens=10,
+            temperature=0.1,
+            aws_bedrock_client=aws_bedrock_client_west,
+        )
+        print(response_1)
+        assert len(response_1.choices) > 0
+        assert len(response_1.choices[0].message.content) > 0
+
+        # This second call is to verify that the cache isn't breaking anything
+        response_2 = completion(
+            model="bedrock/anthropic.claude-3-haiku-20240307-v1:0",
+            messages=messages,
+            max_tokens=5,
+            temperature=0.2,
+            aws_bedrock_client=aws_bedrock_client_west,
+        )
+        print(response_2)
+        assert len(response_2.choices) > 0
+        assert len(response_2.choices[0].message.content) > 0
+
+        # This third call is to verify that the cache isn't used for a different region
+        aws_bedrock_client_east = boto3.client(
+            service_name="bedrock-runtime",
+            region_name="us-east-1",
+            aws_access_key_id=aws_access_key_id,
+            aws_secret_access_key=aws_secret_access_key,
+            aws_session_token=aws_session_token,
+            config=Config(read_timeout=600),
+        )
+
+        response_3 = completion(
+            model="bedrock/anthropic.claude-3-haiku-20240307-v1:0",
+            messages=messages,
+            max_tokens=6,
+            temperature=0.3,
+            aws_bedrock_client=aws_bedrock_client_east,
+        )
+        print(response_3)
+        assert len(response_3.choices) > 0
+        assert len(response_3.choices[0].message.content) > 0
+
+        # This fourth call is to verify streaming api works
+        response_4 = completion(
+            model="bedrock/anthropic.claude-3-haiku-20240307-v1:0",
+            messages=messages,
+            max_tokens=6,
+            temperature=0.3,
+            aws_bedrock_client=aws_bedrock_client_east,
+            stream=True,
+        )
+        response_4 = process_stream_response(response_4, messages)
+        print(response_4)
+        assert len(response_4.choices) > 0
+        assert len(response_4.choices[0].message.content) > 0
+
+    except RateLimitError:
+        pass
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
 # test_completion_bedrock_claude_sts_client_auth()
 
 
@@ -489,61 +718,6 @@ def test_completion_claude_3_base64():
         pytest.fail(f"An exception occurred - {str(e)}")
 
 
-def test_provisioned_throughput():
-    try:
-        litellm.set_verbose = True
-        import io
-        import json
-
-        import botocore
-        import botocore.session
-        from botocore.stub import Stubber
-
-        bedrock_client = botocore.session.get_session().create_client(
-            "bedrock-runtime", region_name="us-east-1"
-        )
-
-        expected_params = {
-            "accept": "application/json",
-            "body": '{"prompt": "\\n\\nHuman: Hello, how are you?\\n\\nAssistant: ", '
-            '"max_tokens_to_sample": 256}',
-            "contentType": "application/json",
-            "modelId": "provisioned-model-arn",
-        }
-        response_from_bedrock = {
-            "body": io.StringIO(
-                json.dumps(
-                    {
-                        "completion": " Here is a short poem about the sky:",
-                        "stop_reason": "max_tokens",
-                        "stop": None,
-                    }
-                )
-            ),
-            "contentType": "contentType",
-            "ResponseMetadata": {"HTTPStatusCode": 200},
-        }
-
-        with Stubber(bedrock_client) as stubber:
-            stubber.add_response(
-                "invoke_model",
-                service_response=response_from_bedrock,
-                expected_params=expected_params,
-            )
-            response = litellm.completion(
-                model="bedrock/anthropic.claude-instant-v1",
-                model_id="provisioned-model-arn",
-                messages=[{"content": "Hello, how are you?", "role": "user"}],
-                aws_bedrock_client=bedrock_client,
-            )
-            print("response stubbed", response)
-    except Exception as e:
-        pytest.fail(f"Error occurred: {e}")
-
-
-# test_provisioned_throughput()
-
-
 def test_completion_bedrock_mistral_completion_auth():
     print("calling bedrock mistral completion params auth")
     import os
@@ -682,3 +856,56 @@ async def test_bedrock_custom_prompt_template():
         prompt = json.loads(mock_client_post.call_args.kwargs["data"])["prompt"]
         assert prompt == "<|im_start|>user\nWhat's AWS?<|im_end|>"
         mock_client_post.assert_called_once()
+
+
+def test_completion_bedrock_external_client_region():
+    print("\ncalling bedrock claude external client auth")
+    import os
+
+    aws_access_key_id = os.environ["AWS_ACCESS_KEY_ID"]
+    aws_secret_access_key = os.environ["AWS_SECRET_ACCESS_KEY"]
+    aws_region_name = "us-east-1"
+
+    os.environ.pop("AWS_ACCESS_KEY_ID", None)
+    os.environ.pop("AWS_SECRET_ACCESS_KEY", None)
+
+    client = HTTPHandler()
+
+    try:
+        import boto3
+
+        litellm.set_verbose = True
+
+        bedrock = boto3.client(
+            service_name="bedrock-runtime",
+            region_name=aws_region_name,
+            aws_access_key_id=aws_access_key_id,
+            aws_secret_access_key=aws_secret_access_key,
+            endpoint_url=f"https://bedrock-runtime.{aws_region_name}.amazonaws.com",
+        )
+        with patch.object(client, "post", new=Mock()) as mock_client_post:
+            try:
+                response = completion(
+                    model="bedrock/anthropic.claude-instant-v1",
+                    messages=messages,
+                    max_tokens=10,
+                    temperature=0.1,
+                    aws_bedrock_client=bedrock,
+                    client=client,
+                )
+                # Add any assertions here to check the response
+                print(response)
+            except Exception as e:
+                pass
+
+        print(f"mock_client_post.call_args: {mock_client_post.call_args}")
+        assert "us-east-1" in mock_client_post.call_args.kwargs["url"]
+
+        mock_client_post.assert_called_once()
+
+        os.environ["AWS_ACCESS_KEY_ID"] = aws_access_key_id
+        os.environ["AWS_SECRET_ACCESS_KEY"] = aws_secret_access_key
+    except RateLimitError:
+        pass
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
@@ -23,7 +23,7 @@ from litellm import RateLimitError, Timeout, completion, completion_cost, embedd
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
 from litellm.llms.prompt_templates.factory import anthropic_messages_pt
 
-# litellm.num_retries = 3
+# litellm.num_retries=3
 litellm.cache = None
 litellm.success_callback = []
 user_message = "Write a short poem about the sky"
@@ -249,6 +249,25 @@ def test_completion_azure_exception():
 # test_completion_azure_exception()
 
 
+def test_azure_embedding_exceptions():
+    try:
+
+        response = litellm.embedding(
+            model="azure/azure-embedding-model",
+            input="hello",
+            messages="hello",
+        )
+        pytest.fail(f"Bad request this should have failed but got {response}")
+
+    except Exception as e:
+        print(vars(e))
+        # CRUCIAL Test - Ensures our exceptions are readable and not overly complicated. some users have complained exceptions will randomly have another exception raised in our exception mapping
+        assert (
+            e.message
+            == "litellm.APIError: AzureException APIError - Embeddings.create() got an unexpected keyword argument 'messages'"
+        )
+
+
 async def asynctest_completion_azure_exception():
     try:
         import openai
@@ -61,7 +61,6 @@ async def test_token_single_public_key():
     import jwt
-
     jwt_handler = JWTHandler()
 
     backend_keys = {
         "keys": [
             {
@@ -1,10 +1,16 @@
 # What is this?
 ## This tests the Lakera AI integration
 
-import sys, os, asyncio, time, random
-from datetime import datetime
+import asyncio
+import os
+import random
+import sys
+import time
 import traceback
+from datetime import datetime
 
 from dotenv import load_dotenv
+from fastapi import HTTPException
+
 load_dotenv()
 import os
@@ -12,17 +18,19 @@ import os
 sys.path.insert(
     0, os.path.abspath("../..")
 )  # Adds the parent directory to the system path
+import logging
+
 import pytest
 
 import litellm
+from litellm import Router, mock_completion
+from litellm._logging import verbose_proxy_logger
+from litellm.caching import DualCache
+from litellm.proxy._types import UserAPIKeyAuth
 from litellm.proxy.enterprise.enterprise_hooks.lakera_ai import (
     _ENTERPRISE_lakeraAI_Moderation,
 )
-from litellm import Router, mock_completion
 from litellm.proxy.utils import ProxyLogging, hash_token
-from litellm.proxy._types import UserAPIKeyAuth
-from litellm.caching import DualCache
-from litellm._logging import verbose_proxy_logger
-import logging
 
 verbose_proxy_logger.setLevel(logging.DEBUG)
@@ -55,10 +63,12 @@ async def test_lakera_prompt_injection_detection():
             call_type="completion",
         )
         pytest.fail(f"Should have failed")
-    except Exception as e:
-        print("Got exception: ", e)
-        assert "Violated content safety policy" in str(e)
-        pass
+    except HTTPException as http_exception:
+        print("http exception details=", http_exception.detail)
+
+        # Assert that the laker ai response is in the exception raise
+        assert "lakera_ai_response" in http_exception.detail
+        assert "Violated content safety policy" in str(http_exception)
 
 
 @pytest.mark.asyncio
85 litellm/tests/test_pass_through_endpoints.py Normal file

@@ -0,0 +1,85 @@
+import os
+import sys
+
+import pytest
+from fastapi import FastAPI
+from fastapi.testclient import TestClient
+
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the system path
+
+import asyncio
+
+import httpx
+
+from litellm.proxy.proxy_server import app, initialize_pass_through_endpoints
+
+
+# Mock the async_client used in the pass_through_request function
+async def mock_request(*args, **kwargs):
+    return httpx.Response(200, json={"message": "Mocked response"})
+
+
+@pytest.fixture
+def client():
+    return TestClient(app)
+
+
+@pytest.mark.asyncio
+async def test_pass_through_endpoint(client, monkeypatch):
+    # Mock the httpx.AsyncClient.request method
+    monkeypatch.setattr("httpx.AsyncClient.request", mock_request)
+
+    # Define a pass-through endpoint
+    pass_through_endpoints = [
+        {
+            "path": "/test-endpoint",
+            "target": "https://api.example.com/v1/chat/completions",
+            "headers": {"Authorization": "Bearer test-token"},
+        }
+    ]
+
+    # Initialize the pass-through endpoint
+    await initialize_pass_through_endpoints(pass_through_endpoints)
+
+    # Make a request to the pass-through endpoint
+    response = client.post("/test-endpoint", json={"prompt": "Hello, world!"})
+
+    # Assert the response
+    assert response.status_code == 200
+    assert response.json() == {"message": "Mocked response"}
+
+
+@pytest.mark.asyncio
+async def test_pass_through_endpoint_rerank(client):
+    _cohere_api_key = os.environ.get("COHERE_API_KEY")
+
+    # Define a pass-through endpoint
+    pass_through_endpoints = [
+        {
+            "path": "/v1/rerank",
+            "target": "https://api.cohere.com/v1/rerank",
+            "headers": {"Authorization": f"bearer {_cohere_api_key}"},
+        }
+    ]
+
+    # Initialize the pass-through endpoint
+    await initialize_pass_through_endpoints(pass_through_endpoints)
+
+    _json_data = {
+        "model": "rerank-english-v3.0",
+        "query": "What is the capital of the United States?",
+        "top_n": 3,
+        "documents": [
+            "Carson City is the capital city of the American state of Nevada."
+        ],
+    }
+
+    # Make a request to the pass-through endpoint
+    response = client.post("/v1/rerank", json=_json_data)
+
+    print("JSON response: ", _json_data)
+
+    # Assert the response
+    assert response.status_code == 200
@@ -1,25 +1,31 @@
 # test that the proxy actually does exception mapping to the OpenAI format
 
-import sys, os
-from unittest import mock
 import json
+import os
+import sys
+from unittest import mock
 
 from dotenv import load_dotenv
 
 load_dotenv()
-import os, io, asyncio
+import asyncio
+import io
+import os
 
 sys.path.insert(
     0, os.path.abspath("../..")
 )  # Adds the parent directory to the system path
+import openai
 import pytest
-import litellm, openai
-from fastapi.testclient import TestClient
 from fastapi import Response
-from litellm.proxy.proxy_server import (
+from fastapi.testclient import TestClient
+
+import litellm
+from litellm.proxy.proxy_server import (  # Replace with the actual module where your FastAPI router is defined
+    initialize,
     router,
     save_worker_config,
-    initialize,
-)  # Replace with the actual module where your FastAPI router is defined
+)
 
 invalid_authentication_error_response = Response(
     status_code=401,
@@ -66,6 +72,12 @@ def test_chat_completion_exception(client):
     json_response = response.json()
     print("keys in json response", json_response.keys())
     assert json_response.keys() == {"error"}
+    print("ERROR=", json_response["error"])
+    assert isinstance(json_response["error"]["message"], str)
+    assert (
+        json_response["error"]["message"]
+        == "litellm.AuthenticationError: AuthenticationError: OpenAIException - Incorrect API key provided: bad-key. You can find your API key at https://platform.openai.com/account/api-keys."
+    )
 
     # make an openai client to call _make_status_error_from_response
     openai_client = openai.OpenAI(api_key="anything")
@@ -16,6 +16,7 @@ sys.path.insert(
 import os
 from collections import defaultdict
 from concurrent.futures import ThreadPoolExecutor
+from unittest.mock import AsyncMock, MagicMock, patch
 
 import httpx
 from dotenv import load_dotenv
@@ -811,6 +812,7 @@ def test_router_context_window_check_pre_call_check():
                     "base_model": "azure/gpt-35-turbo",
                     "mock_response": "Hello world 1!",
                 },
+                "model_info": {"base_model": "azure/gpt-35-turbo"},
             },
             {
                 "model_name": "gpt-3.5-turbo",  # openai model name
@@ -1884,3 +1886,106 @@ async def test_router_model_usage(mock_response):
         else:
             print(f"allowed_fails: {allowed_fails}")
             raise e
+
+
+@pytest.mark.parametrize(
+    "model, base_model, llm_provider",
+    [
+        ("azure/gpt-4", None, "azure"),
+        ("azure/gpt-4", "azure/gpt-4-0125-preview", "azure"),
+        ("gpt-4", None, "openai"),
+    ],
+)
+def test_router_get_model_info(model, base_model, llm_provider):
+    """
+    Test if router get model info works based on provider
+
+    For azure -> only if base model set
+    For openai -> use model=
+    """
+    router = Router(
+        model_list=[
+            {
+                "model_name": "gpt-4",
+                "litellm_params": {
+                    "model": model,
+                    "api_key": "my-fake-key",
+                    "api_base": "my-fake-base",
+                },
+                "model_info": {"base_model": base_model, "id": "1"},
+            }
+        ]
+    )
+
+    deployment = router.get_deployment(model_id="1")
+
+    assert deployment is not None
+
+    if llm_provider == "openai" or (base_model is not None and llm_provider == "azure"):
+        router.get_router_model_info(deployment=deployment.to_json())
+    else:
+        try:
+            router.get_router_model_info(deployment=deployment.to_json())
+            pytest.fail("Expected this to raise model not mapped error")
+        except Exception as e:
+            if "This model isn't mapped yet" in str(e):
+                pass
+
+
+@pytest.mark.parametrize(
+    "model, base_model, llm_provider",
+    [
+        ("azure/gpt-4", None, "azure"),
+        ("azure/gpt-4", "azure/gpt-4-0125-preview", "azure"),
+        ("gpt-4", None, "openai"),
+    ],
+)
+def test_router_context_window_pre_call_check(model, base_model, llm_provider):
+    """
+    - For an azure model
+    - if no base model set
+    - don't enforce context window limits
+    """
+    try:
+        model_list = [
+            {
+                "model_name": "gpt-4",
+                "litellm_params": {
+                    "model": model,
+                    "api_key": "my-fake-key",
+                    "api_base": "my-fake-base",
+                },
+                "model_info": {"base_model": base_model, "id": "1"},
+            }
+        ]
+        router = Router(
+            model_list=model_list,
+            set_verbose=True,
+            enable_pre_call_checks=True,
+            num_retries=0,
+        )
+
+        litellm.token_counter = MagicMock()
+
+        def token_counter_side_effect(*args, **kwargs):
+            # Process args and kwargs if needed
+            return 1000000
+
+        litellm.token_counter.side_effect = token_counter_side_effect
+        try:
+            updated_list = router._pre_call_checks(
+                model="gpt-4",
+                healthy_deployments=model_list,
+                messages=[{"role": "user", "content": "Hey, how's it going?"}],
+            )
+            if llm_provider == "azure" and base_model is None:
+                assert len(updated_list) == 1
+            else:
+                pytest.fail("Expected to raise an error. Got={}".format(updated_list))
+        except Exception as e:
+            if (
+                llm_provider == "azure" and base_model is not None
+            ) or llm_provider == "openai":
+                pass
+    except Exception as e:
+        pytest.fail(f"Got unexpected exception on router! - {str(e)}")
@@ -1,16 +1,23 @@
-import sys, os, time
-import traceback, asyncio
+import asyncio
+import os
+import sys
+import time
+import traceback
+
 import pytest
 
 sys.path.insert(
     0, os.path.abspath("../..")
 )  # Adds the parent directory to the system path
 
-import litellm, asyncio, logging
+import asyncio
+import logging
+
+import litellm
 from litellm import Router
 
 # this tests debug logs from litellm router and litellm proxy server
-from litellm._logging import verbose_router_logger, verbose_logger, verbose_proxy_logger
+from litellm._logging import verbose_logger, verbose_proxy_logger, verbose_router_logger
 
 
 # this tests debug logs from litellm router and litellm proxy server
@@ -81,7 +88,7 @@ def test_async_fallbacks(caplog):
     # Define the expected log messages
     # - error request, falling back notice, success notice
     expected_logs = [
-        "litellm.acompletion(model=gpt-3.5-turbo)\x1b[31m Exception litellm.AuthenticationError: AuthenticationError: OpenAIException - Error code: 401 - {'error': {'message': 'Incorrect API key provided: bad-key. You can find your API key at https://platform.openai.com/account/api-keys.', 'type': 'invalid_request_error', 'param': None, 'code': 'invalid_api_key'}}\x1b[0m",
+        "litellm.acompletion(model=gpt-3.5-turbo)\x1b[31m Exception litellm.AuthenticationError: AuthenticationError: OpenAIException - Incorrect API key provided: bad-key. You can find your API key at https://platform.openai.com/account/api-keys.\x1b[0m",
         "Falling back to model_group = azure/gpt-3.5-turbo",
         "litellm.acompletion(model=azure/chatgpt-v-2)\x1b[32m 200 OK\x1b[0m",
         "Successful fallback b/w models.",
@ -742,7 +742,10 @@ def test_completion_palm_stream():
|
||||||
# test_completion_palm_stream()
|
# test_completion_palm_stream()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("sync_mode", [False]) # True,
|
@pytest.mark.parametrize(
|
||||||
|
"sync_mode",
|
||||||
|
[True, False],
|
||||||
|
) # ,
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_completion_gemini_stream(sync_mode):
|
async def test_completion_gemini_stream(sync_mode):
|
||||||
try:
|
try:
|
||||||
|
@ -807,49 +810,6 @@ async def test_completion_gemini_stream(sync_mode):
|
||||||
pytest.fail(f"Error occurred: {e}")
|
pytest.fail(f"Error occurred: {e}")
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_acompletion_gemini_stream():
|
|
||||||
try:
|
|
||||||
litellm.set_verbose = True
|
|
||||||
print("Streaming gemini response")
|
|
||||||
messages = [
|
|
||||||
# {"role": "system", "content": "You are a helpful assistant."},
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": "What do you know?",
|
|
||||||
},
|
|
||||||
]
|
|
||||||
print("testing gemini streaming")
|
|
||||||
response = await acompletion(
|
|
||||||
model="gemini/gemini-pro", messages=messages, max_tokens=50, stream=True
|
|
||||||
)
|
|
||||||
print(f"type of response at the top: {response}")
|
|
||||||
complete_response = ""
|
|
||||||
idx = 0
|
|
||||||
# Add any assertions here to check, the response
|
|
||||||
async for chunk in response:
|
|
||||||
print(f"chunk in acompletion gemini: {chunk}")
|
|
||||||
print(chunk.choices[0].delta)
|
|
||||||
chunk, finished = streaming_format_tests(idx, chunk)
|
|
||||||
if finished:
|
|
||||||
break
|
|
||||||
print(f"chunk: {chunk}")
|
|
||||||
complete_response += chunk
|
|
||||||
idx += 1
|
|
||||||
print(f"completion_response: {complete_response}")
|
|
||||||
if complete_response.strip() == "":
|
|
||||||
raise Exception("Empty response received")
|
|
||||||
except litellm.APIError as e:
|
|
||||||
pass
|
|
||||||
except litellm.RateLimitError as e:
|
|
||||||
pass
|
|
||||||
except Exception as e:
|
|
||||||
if "429 Resource has been exhausted" in str(e):
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
pytest.fail(f"Error occurred: {e}")
|
|
||||||
|
|
||||||
|
|
||||||
# asyncio.run(test_acompletion_gemini_stream())
|
# asyncio.run(test_acompletion_gemini_stream())
|
||||||
|
|
||||||
|
|
||||||
|
@ -1071,7 +1031,7 @@ def test_completion_claude_stream_bad_key():
|
||||||
# test_completion_replicate_stream()
|
# test_completion_replicate_stream()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("provider", ["vertex_ai"]) # "vertex_ai_beta"
|
@pytest.mark.parametrize("provider", ["vertex_ai_beta"]) # ""
|
||||||
def test_vertex_ai_stream(provider):
|
def test_vertex_ai_stream(provider):
|
||||||
from litellm.tests.test_amazing_vertex_completion import load_vertex_ai_credentials
|
from litellm.tests.test_amazing_vertex_completion import load_vertex_ai_credentials
|
||||||
|
|
||||||
|
@ -1080,14 +1040,27 @@ def test_vertex_ai_stream(provider):
|
||||||
litellm.vertex_project = "adroit-crow-413218"
|
litellm.vertex_project = "adroit-crow-413218"
|
||||||
import random
|
import random
|
||||||
|
|
||||||
test_models = ["gemini-1.0-pro"]
|
test_models = ["gemini-1.5-pro"]
|
||||||
for model in test_models:
|
for model in test_models:
|
||||||
try:
|
try:
|
||||||
print("making request", model)
|
print("making request", model)
|
||||||
response = completion(
|
response = completion(
|
||||||
model="{}/{}".format(provider, model),
|
model="{}/{}".format(provider, model),
|
||||||
messages=[
|
messages=[
|
||||||
{"role": "user", "content": "write 10 line code code for saying hi"}
|
{"role": "user", "content": "Hey, how's it going?"},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "I'm doing well. Would like to hear the rest of the story?",
|
||||||
|
},
|
||||||
|
{"role": "user", "content": "Na"},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "No problem, is there anything else i can help you with today?",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "I think you're getting cut off sometimes",
|
||||||
|
},
|
||||||
],
|
],
|
||||||
stream=True,
|
stream=True,
|
||||||
)
|
)
|
||||||
|
@ -1104,6 +1077,8 @@ def test_vertex_ai_stream(provider):
|
||||||
raise Exception("Empty response received")
|
raise Exception("Empty response received")
|
||||||
print(f"completion_response: {complete_response}")
|
print(f"completion_response: {complete_response}")
|
||||||
assert is_finished == True
|
assert is_finished == True
|
||||||
|
|
||||||
|
assert False
|
||||||
except litellm.RateLimitError as e:
|
except litellm.RateLimitError as e:
|
||||||
pass
|
pass
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
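The test above now exercises the `vertex_ai_beta` provider route with a multi-turn conversation. A hedged, minimal sketch of calling that route directly; the project id and region are illustrative, and setup is assumed to be configured the same way the test does it (via `litellm.vertex_project`).

```python
import litellm

litellm.vertex_project = "my-gcp-project"  # illustrative project id
litellm.vertex_location = "us-central1"    # illustrative region

# Same "provider/model" formatting the test builds with str.format().
response = litellm.completion(
    model="vertex_ai_beta/gemini-1.5-pro",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
    stream=True,
)
for chunk in response:
    # delta.content can be None on the final chunk, hence the fallback to "".
    print(chunk.choices[0].delta.content or "", end="")
```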
|
@ -1251,6 +1226,7 @@ async def test_completion_replicate_llama3_streaming(sync_mode):
|
||||||
messages=messages,
|
messages=messages,
|
||||||
max_tokens=10, # type: ignore
|
max_tokens=10, # type: ignore
|
||||||
stream=True,
|
stream=True,
|
||||||
|
num_retries=3,
|
||||||
)
|
)
|
||||||
complete_response = ""
|
complete_response = ""
|
||||||
# Add any assertions here to check the response
|
# Add any assertions here to check the response
|
||||||
|
@ -1272,6 +1248,7 @@ async def test_completion_replicate_llama3_streaming(sync_mode):
|
||||||
messages=messages,
|
messages=messages,
|
||||||
max_tokens=100, # type: ignore
|
max_tokens=100, # type: ignore
|
||||||
stream=True,
|
stream=True,
|
||||||
|
num_retries=3,
|
||||||
)
|
)
|
||||||
complete_response = ""
|
complete_response = ""
|
||||||
# Add any assertions here to check the response
|
# Add any assertions here to check the response
|
||||||
|
@ -1290,6 +1267,8 @@ async def test_completion_replicate_llama3_streaming(sync_mode):
|
||||||
raise Exception("finish reason not set")
|
raise Exception("finish reason not set")
|
||||||
if complete_response.strip() == "":
|
if complete_response.strip() == "":
|
||||||
raise Exception("Empty response received")
|
raise Exception("Empty response received")
|
||||||
|
except litellm.UnprocessableEntityError as e:
|
||||||
|
pass
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
pytest.fail(f"Error occurred: {e}")
|
pytest.fail(f"Error occurred: {e}")
|
||||||
|
|
||||||
|
|
|
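The updated replicate streaming tests above pass `num_retries` per call and tolerate `UnprocessableEntityError`. A hedged sketch of that per-call retry pattern outside the test harness; the model id and prompt are illustrative, not taken from this diff.

```python
import litellm

try:
    # Per-call retries: transient provider failures are retried before surfacing.
    response = litellm.completion(
        model="replicate/meta/meta-llama-3-8b-instruct",  # illustrative model id
        messages=[{"role": "user", "content": "Say hi in one sentence."}],
        max_tokens=10,
        stream=True,
        num_retries=3,
    )
    for chunk in response:
        # delta.content can be None on the final chunk, hence the fallback to "".
        print(chunk.choices[0].delta.content or "", end="")
except litellm.UnprocessableEntityError:
    # Mirrors the test: a 422 from the provider is a tolerated outcome.
    pass
```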
@ -609,3 +609,83 @@ def test_logging_trace_id(langfuse_trace_id, langfuse_existing_trace_id):
|
||||||
litellm_logging_obj._get_trace_id(service_name="langfuse")
|
litellm_logging_obj._get_trace_id(service_name="langfuse")
|
||||||
== litellm_call_id
|
== litellm_call_id
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_convert_model_response_object():
|
||||||
|
"""
|
||||||
|
Unit test to ensure model response object correctly handles openrouter errors.
|
||||||
|
"""
|
||||||
|
args = {
|
||||||
|
"response_object": {
|
||||||
|
"id": None,
|
||||||
|
"choices": None,
|
||||||
|
"created": None,
|
||||||
|
"model": None,
|
||||||
|
"object": None,
|
||||||
|
"service_tier": None,
|
||||||
|
"system_fingerprint": None,
|
||||||
|
"usage": None,
|
||||||
|
"error": {
|
||||||
|
"message": '{"type":"error","error":{"type":"invalid_request_error","message":"Output blocked by content filtering policy"}}',
|
||||||
|
"code": 400,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"model_response_object": litellm.ModelResponse(
|
||||||
|
id="chatcmpl-b88ce43a-7bfc-437c-b8cc-e90d59372cfb",
|
||||||
|
choices=[
|
||||||
|
litellm.Choices(
|
||||||
|
finish_reason="stop",
|
||||||
|
index=0,
|
||||||
|
message=litellm.Message(content="default", role="assistant"),
|
||||||
|
)
|
||||||
|
],
|
||||||
|
created=1719376241,
|
||||||
|
model="openrouter/anthropic/claude-3.5-sonnet",
|
||||||
|
object="chat.completion",
|
||||||
|
system_fingerprint=None,
|
||||||
|
usage=litellm.Usage(),
|
||||||
|
),
|
||||||
|
"response_type": "completion",
|
||||||
|
"stream": False,
|
||||||
|
"start_time": None,
|
||||||
|
"end_time": None,
|
||||||
|
"hidden_params": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
litellm.convert_to_model_response_object(**args)
|
||||||
|
pytest.fail("Expected this to fail")
|
||||||
|
except Exception as e:
|
||||||
|
assert hasattr(e, "status_code")
|
||||||
|
assert e.status_code == 400
|
||||||
|
assert hasattr(e, "message")
|
||||||
|
assert (
|
||||||
|
e.message
|
||||||
|
== '{"type":"error","error":{"type":"invalid_request_error","message":"Output blocked by content filtering policy"}}'
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"model, expected_bool",
|
||||||
|
[
|
||||||
|
("vertex_ai/gemini-1.5-pro", True),
|
||||||
|
("gemini/gemini-1.5-pro", True),
|
||||||
|
("predibase/llama3-8b-instruct", True),
|
||||||
|
("gpt-4o", False),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_supports_response_schema(model, expected_bool):
|
||||||
|
"""
|
||||||
|
Unit tests for 'supports_response_schema' helper function.
|
||||||
|
|
||||||
|
Should be true for gemini-1.5-pro on google ai studio / vertex ai AND predibase models
|
||||||
|
Should be false otherwise
|
||||||
|
"""
|
||||||
|
os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
|
||||||
|
litellm.model_cost = litellm.get_model_cost_map(url="")
|
||||||
|
|
||||||
|
from litellm.utils import supports_response_schema
|
||||||
|
|
||||||
|
response = supports_response_schema(model=model, custom_llm_provider=None)
|
||||||
|
|
||||||
|
assert expected_bool == response
|
||||||
|
|
|
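Taken together, the helpers tested above enable client-side JSON-schema enforcement. A hedged end-to-end sketch with an illustrative schema and prompt; `enforce_validation` is the key checked in the `post_call_processing` change further down in this diff.

```python
import litellm
from litellm.utils import supports_response_schema

# Illustrative schema for the structured answer we want back.
answer_schema = {
    "type": "object",
    "properties": {"answer": {"type": "string"}},
    "required": ["answer"],
}

model = "vertex_ai/gemini-1.5-pro"

if supports_response_schema(model=model, custom_llm_provider=None):
    response = litellm.completion(
        model=model,
        messages=[{"role": "user", "content": "Answer in JSON: what is 2 + 2?"}],
        response_format={
            "type": "json_object",
            "response_schema": answer_schema,
            "enforce_validation": True,  # triggers the post-call jsonschema check
        },
    )
    print(response.choices[0].message.content)
else:
    print(f"{model} does not advertise response_schema support")
```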
@ -71,6 +71,7 @@ class ModelInfo(TypedDict, total=False):
|
||||||
]
|
]
|
||||||
supported_openai_params: Required[Optional[List[str]]]
|
supported_openai_params: Required[Optional[List[str]]]
|
||||||
supports_system_messages: Optional[bool]
|
supports_system_messages: Optional[bool]
|
||||||
|
supports_response_schema: Optional[bool]
|
||||||
|
|
||||||
|
|
||||||
class GenericStreamingChunk(TypedDict):
|
class GenericStreamingChunk(TypedDict):
|
||||||
|
@ -994,3 +995,8 @@ class GenericImageParsingChunk(TypedDict):
|
||||||
type: str
|
type: str
|
||||||
media_type: str
|
media_type: str
|
||||||
data: str
|
data: str
|
||||||
|
|
||||||
|
|
||||||
|
class ResponseFormatChunk(TypedDict, total=False):
|
||||||
|
type: Required[Literal["json_object", "text"]]
|
||||||
|
response_schema: dict
|
||||||
|
|
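A short sketch of populating the new `ResponseFormatChunk`; the import path is assumed from the surrounding hunks and the schema is illustrative.

```python
from litellm.types.utils import ResponseFormatChunk  # assumed module path

# A response_format payload asking for JSON output constrained to a schema.
response_format: ResponseFormatChunk = {
    "type": "json_object",
    "response_schema": {
        "type": "object",
        "properties": {"city": {"type": "string"}},
        "required": ["city"],
    },
}
```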
litellm/utils.py (305 changed lines)
|
@ -48,8 +48,10 @@ from tokenizers import Tokenizer
|
||||||
import litellm
|
import litellm
|
||||||
import litellm._service_logger # for storing API inputs, outputs, and metadata
|
import litellm._service_logger # for storing API inputs, outputs, and metadata
|
||||||
import litellm.litellm_core_utils
|
import litellm.litellm_core_utils
|
||||||
|
import litellm.litellm_core_utils.json_validation_rule
|
||||||
from litellm.caching import DualCache
|
from litellm.caching import DualCache
|
||||||
from litellm.litellm_core_utils.core_helpers import map_finish_reason
|
from litellm.litellm_core_utils.core_helpers import map_finish_reason
|
||||||
|
from litellm.litellm_core_utils.exception_mapping_utils import get_error_message
|
||||||
from litellm.litellm_core_utils.llm_request_utils import _ensure_extra_body_is_safe
|
from litellm.litellm_core_utils.llm_request_utils import _ensure_extra_body_is_safe
|
||||||
from litellm.litellm_core_utils.redact_messages import (
|
from litellm.litellm_core_utils.redact_messages import (
|
||||||
redact_message_input_output_from_logging,
|
redact_message_input_output_from_logging,
|
||||||
|
@ -579,7 +581,7 @@ def client(original_function):
|
||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def post_call_processing(original_response, model):
|
def post_call_processing(original_response, model, optional_params: Optional[dict]):
|
||||||
try:
|
try:
|
||||||
if original_response is None:
|
if original_response is None:
|
||||||
pass
|
pass
|
||||||
|
@ -594,11 +596,47 @@ def client(original_function):
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
if isinstance(original_response, ModelResponse):
|
if isinstance(original_response, ModelResponse):
|
||||||
model_response = original_response.choices[
|
model_response: Optional[str] = original_response.choices[
|
||||||
0
|
0
|
||||||
].message.content
|
].message.content # type: ignore
|
||||||
|
if model_response is not None:
|
||||||
### POST-CALL RULES ###
|
### POST-CALL RULES ###
|
||||||
rules_obj.post_call_rules(input=model_response, model=model)
|
rules_obj.post_call_rules(
|
||||||
|
input=model_response, model=model
|
||||||
|
)
|
||||||
|
### JSON SCHEMA VALIDATION ###
|
||||||
|
if (
|
||||||
|
optional_params is not None
|
||||||
|
and "response_format" in optional_params
|
||||||
|
and isinstance(
|
||||||
|
optional_params["response_format"], dict
|
||||||
|
)
|
||||||
|
and "type" in optional_params["response_format"]
|
||||||
|
and optional_params["response_format"]["type"]
|
||||||
|
== "json_object"
|
||||||
|
and "response_schema"
|
||||||
|
in optional_params["response_format"]
|
||||||
|
and isinstance(
|
||||||
|
optional_params["response_format"][
|
||||||
|
"response_schema"
|
||||||
|
],
|
||||||
|
dict,
|
||||||
|
)
|
||||||
|
and "enforce_validation"
|
||||||
|
in optional_params["response_format"]
|
||||||
|
and optional_params["response_format"][
|
||||||
|
"enforce_validation"
|
||||||
|
]
|
||||||
|
is True
|
||||||
|
):
|
||||||
|
# schema given, json response expected, and validation enforced
|
||||||
|
litellm.litellm_core_utils.json_validation_rule.validate_schema(
|
||||||
|
schema=optional_params["response_format"][
|
||||||
|
"response_schema"
|
||||||
|
],
|
||||||
|
response=model_response,
|
||||||
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
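The enforcement path above defers to `litellm.litellm_core_utils.json_validation_rule.validate_schema`, whose body is not shown in this diff. A minimal sketch of what such a validator could look like on top of the newly added `jsonschema` dependency; the real helper may raise a litellm-specific exception type instead of `ValueError`.

```python
import json

from jsonschema import validate
from jsonschema.exceptions import ValidationError


def validate_schema(schema: dict, response: str) -> None:
    """Parse the model's text response and check it against a JSON schema."""
    try:
        response_dict = json.loads(response)
    except json.JSONDecodeError:
        raise ValueError(f"Response is not valid JSON: {response}")

    try:
        validate(instance=response_dict, schema=schema)
    except ValidationError as e:
        raise ValueError(f"Response does not match the given schema: {e.message}")
```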
@ -867,7 +905,11 @@ def client(original_function):
|
||||||
return result
|
return result
|
||||||
|
|
||||||
### POST-CALL RULES ###
|
### POST-CALL RULES ###
|
||||||
post_call_processing(original_response=result, model=model or None)
|
post_call_processing(
|
||||||
|
original_response=result,
|
||||||
|
model=model or None,
|
||||||
|
optional_params=kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
# [OPTIONAL] ADD TO CACHE
|
# [OPTIONAL] ADD TO CACHE
|
||||||
if (
|
if (
|
||||||
|
@ -1316,7 +1358,9 @@ def client(original_function):
|
||||||
).total_seconds() * 1000 # return response latency in ms like openai
|
).total_seconds() * 1000 # return response latency in ms like openai
|
||||||
|
|
||||||
### POST-CALL RULES ###
|
### POST-CALL RULES ###
|
||||||
post_call_processing(original_response=result, model=model)
|
post_call_processing(
|
||||||
|
original_response=result, model=model, optional_params=kwargs
|
||||||
|
)
|
||||||
|
|
||||||
# [OPTIONAL] ADD TO CACHE
|
# [OPTIONAL] ADD TO CACHE
|
||||||
if (
|
if (
|
||||||
|
@ -1847,9 +1891,10 @@ def supports_system_messages(model: str, custom_llm_provider: Optional[str]) ->
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
model (str): The model name to be checked.
|
model (str): The model name to be checked.
|
||||||
|
custom_llm_provider (str): The provider to be checked.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
bool: True if the model supports function calling, False otherwise.
|
bool: True if the model supports system messages, False otherwise.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
Exception: If the given model is not found in model_prices_and_context_window.json.
|
Exception: If the given model is not found in model_prices_and_context_window.json.
|
||||||
|
@ -1867,6 +1912,43 @@ def supports_system_messages(model: str, custom_llm_provider: Optional[str]) ->
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def supports_response_schema(model: str, custom_llm_provider: Optional[str]) -> bool:
|
||||||
|
"""
|
||||||
|
Check if the given model + provider supports 'response_schema' as a param.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
model (str): The model name to be checked.
|
||||||
|
custom_llm_provider (str): The provider to be checked.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if the model supports response_schema, False otherwise.
|
||||||
|
|
||||||
|
Does not raise error. Defaults to 'False'. Outputs logging.error.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
## GET LLM PROVIDER ##
|
||||||
|
model, custom_llm_provider, _, _ = get_llm_provider(
|
||||||
|
model=model, custom_llm_provider=custom_llm_provider
|
||||||
|
)
|
||||||
|
|
||||||
|
if custom_llm_provider == "predibase": # predibase supports this globally
|
||||||
|
return True
|
||||||
|
|
||||||
|
## GET MODEL INFO
|
||||||
|
model_info = litellm.get_model_info(
|
||||||
|
model=model, custom_llm_provider=custom_llm_provider
|
||||||
|
)
|
||||||
|
|
||||||
|
if model_info.get("supports_response_schema", False) is True:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
except Exception:
|
||||||
|
verbose_logger.error(
|
||||||
|
f"Model not in model_prices_and_context_window.json. You passed model={model}, custom_llm_provider={custom_llm_provider}."
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
def supports_function_calling(model: str) -> bool:
|
def supports_function_calling(model: str) -> bool:
|
||||||
"""
|
"""
|
||||||
Check if the given model supports function calling and return a boolean value.
|
Check if the given model supports function calling and return a boolean value.
|
||||||
|
@ -2324,7 +2406,9 @@ def get_optional_params(
|
||||||
elif k == "hf_model_name" and custom_llm_provider != "sagemaker":
|
elif k == "hf_model_name" and custom_llm_provider != "sagemaker":
|
||||||
continue
|
continue
|
||||||
elif (
|
elif (
|
||||||
k.startswith("vertex_") and custom_llm_provider != "vertex_ai" and custom_llm_provider != "vertex_ai_beta"
|
k.startswith("vertex_")
|
||||||
|
and custom_llm_provider != "vertex_ai"
|
||||||
|
and custom_llm_provider != "vertex_ai_beta"
|
||||||
): # allow dynamically setting vertex ai init logic
|
): # allow dynamically setting vertex ai init logic
|
||||||
continue
|
continue
|
||||||
passed_params[k] = v
|
passed_params[k] = v
|
||||||
|
@ -2756,6 +2840,11 @@ def get_optional_params(
|
||||||
non_default_params=non_default_params,
|
non_default_params=non_default_params,
|
||||||
optional_params=optional_params,
|
optional_params=optional_params,
|
||||||
model=model,
|
model=model,
|
||||||
|
drop_params=(
|
||||||
|
drop_params
|
||||||
|
if drop_params is not None and isinstance(drop_params, bool)
|
||||||
|
else False
|
||||||
|
),
|
||||||
)
|
)
|
||||||
elif (
|
elif (
|
||||||
custom_llm_provider == "vertex_ai" and model in litellm.vertex_anthropic_models
|
custom_llm_provider == "vertex_ai" and model in litellm.vertex_anthropic_models
|
||||||
|
@ -2824,12 +2913,7 @@ def get_optional_params(
|
||||||
optional_params=optional_params,
|
optional_params=optional_params,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
else:
|
elif model in litellm.BEDROCK_CONVERSE_MODELS:
|
||||||
optional_params = litellm.AmazonAnthropicConfig().map_openai_params(
|
|
||||||
non_default_params=non_default_params,
|
|
||||||
optional_params=optional_params,
|
|
||||||
)
|
|
||||||
else: # bedrock httpx route
|
|
||||||
optional_params = litellm.AmazonConverseConfig().map_openai_params(
|
optional_params = litellm.AmazonConverseConfig().map_openai_params(
|
||||||
model=model,
|
model=model,
|
||||||
non_default_params=non_default_params,
|
non_default_params=non_default_params,
|
||||||
|
@ -2840,6 +2924,11 @@ def get_optional_params(
|
||||||
else False
|
else False
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
optional_params = litellm.AmazonAnthropicConfig().map_openai_params(
|
||||||
|
non_default_params=non_default_params,
|
||||||
|
optional_params=optional_params,
|
||||||
|
)
|
||||||
elif "amazon" in model: # amazon titan llms
|
elif "amazon" in model: # amazon titan llms
|
||||||
_check_valid_arg(supported_params=supported_params)
|
_check_valid_arg(supported_params=supported_params)
|
||||||
# see https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=titan-large
|
# see https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=titan-large
|
||||||
|
@ -3755,23 +3844,18 @@ def get_supported_openai_params(
|
||||||
return litellm.AzureOpenAIConfig().get_supported_openai_params()
|
return litellm.AzureOpenAIConfig().get_supported_openai_params()
|
||||||
elif custom_llm_provider == "openrouter":
|
elif custom_llm_provider == "openrouter":
|
||||||
return [
|
return [
|
||||||
"functions",
|
|
||||||
"function_call",
|
|
||||||
"temperature",
|
"temperature",
|
||||||
"top_p",
|
"top_p",
|
||||||
"n",
|
|
||||||
"stream",
|
|
||||||
"stop",
|
|
||||||
"max_tokens",
|
|
||||||
"presence_penalty",
|
|
||||||
"frequency_penalty",
|
"frequency_penalty",
|
||||||
"logit_bias",
|
"presence_penalty",
|
||||||
"user",
|
"repetition_penalty",
|
||||||
"response_format",
|
|
||||||
"seed",
|
"seed",
|
||||||
"tools",
|
"max_tokens",
|
||||||
"tool_choice",
|
"logit_bias",
|
||||||
"max_retries",
|
"logprobs",
|
||||||
|
"top_logprobs",
|
||||||
|
"response_format",
|
||||||
|
"stop",
|
||||||
]
|
]
|
||||||
elif custom_llm_provider == "mistral" or custom_llm_provider == "codestral":
|
elif custom_llm_provider == "mistral" or custom_llm_provider == "codestral":
|
||||||
# mistal and codestral api have the exact same params
|
# mistal and codestral api have the exact same params
|
||||||
|
@ -3789,6 +3873,10 @@ def get_supported_openai_params(
|
||||||
"top_p",
|
"top_p",
|
||||||
"stop",
|
"stop",
|
||||||
"seed",
|
"seed",
|
||||||
|
"tools",
|
||||||
|
"tool_choice",
|
||||||
|
"functions",
|
||||||
|
"function_call",
|
||||||
]
|
]
|
||||||
elif custom_llm_provider == "huggingface":
|
elif custom_llm_provider == "huggingface":
|
||||||
return litellm.HuggingfaceConfig().get_supported_openai_params()
|
return litellm.HuggingfaceConfig().get_supported_openai_params()
|
||||||
|
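For reference, a hedged sketch of querying the updated OpenRouter parameter list via the public helper; the model name is illustrative.

```python
import litellm

# Returns the provider's supported OpenAI-style params, e.g. the revised
# OpenRouter list above (temperature, top_p, ..., response_format, stop).
params = litellm.get_supported_openai_params(
    model="openrouter/anthropic/claude-3.5-sonnet",
    custom_llm_provider="openrouter",
)
print(params)
```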
@ -4434,8 +4522,7 @@ def get_max_tokens(model: str) -> Optional[int]:
|
||||||
|
|
||||||
def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> ModelInfo:
|
def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> ModelInfo:
|
||||||
"""
|
"""
|
||||||
Get a dict for the maximum tokens (context window),
|
Get a dict for the maximum tokens (context window), input_cost_per_token, output_cost_per_token for a given model.
|
||||||
input_cost_per_token, output_cost_per_token for a given model.
|
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
- model (str): The name of the model.
|
- model (str): The name of the model.
|
||||||
|
@ -4520,6 +4607,7 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod
|
||||||
mode="chat",
|
mode="chat",
|
||||||
supported_openai_params=supported_openai_params,
|
supported_openai_params=supported_openai_params,
|
||||||
supports_system_messages=None,
|
supports_system_messages=None,
|
||||||
|
supports_response_schema=None,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
"""
|
"""
|
||||||
|
@ -4541,36 +4629,6 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
raise Exception
|
raise Exception
|
||||||
return ModelInfo(
|
|
||||||
max_tokens=_model_info.get("max_tokens", None),
|
|
||||||
max_input_tokens=_model_info.get("max_input_tokens", None),
|
|
||||||
max_output_tokens=_model_info.get("max_output_tokens", None),
|
|
||||||
input_cost_per_token=_model_info.get("input_cost_per_token", 0),
|
|
||||||
input_cost_per_character=_model_info.get(
|
|
||||||
"input_cost_per_character", None
|
|
||||||
),
|
|
||||||
input_cost_per_token_above_128k_tokens=_model_info.get(
|
|
||||||
"input_cost_per_token_above_128k_tokens", None
|
|
||||||
),
|
|
||||||
output_cost_per_token=_model_info.get("output_cost_per_token", 0),
|
|
||||||
output_cost_per_character=_model_info.get(
|
|
||||||
"output_cost_per_character", None
|
|
||||||
),
|
|
||||||
output_cost_per_token_above_128k_tokens=_model_info.get(
|
|
||||||
"output_cost_per_token_above_128k_tokens", None
|
|
||||||
),
|
|
||||||
output_cost_per_character_above_128k_tokens=_model_info.get(
|
|
||||||
"output_cost_per_character_above_128k_tokens", None
|
|
||||||
),
|
|
||||||
litellm_provider=_model_info.get(
|
|
||||||
"litellm_provider", custom_llm_provider
|
|
||||||
),
|
|
||||||
mode=_model_info.get("mode"),
|
|
||||||
supported_openai_params=supported_openai_params,
|
|
||||||
supports_system_messages=_model_info.get(
|
|
||||||
"supports_system_messages", None
|
|
||||||
),
|
|
||||||
)
|
|
||||||
elif model in litellm.model_cost:
|
elif model in litellm.model_cost:
|
||||||
_model_info = litellm.model_cost[model]
|
_model_info = litellm.model_cost[model]
|
||||||
_model_info["supported_openai_params"] = supported_openai_params
|
_model_info["supported_openai_params"] = supported_openai_params
|
||||||
|
@ -4584,36 +4642,6 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
raise Exception
|
raise Exception
|
||||||
return ModelInfo(
|
|
||||||
max_tokens=_model_info.get("max_tokens", None),
|
|
||||||
max_input_tokens=_model_info.get("max_input_tokens", None),
|
|
||||||
max_output_tokens=_model_info.get("max_output_tokens", None),
|
|
||||||
input_cost_per_token=_model_info.get("input_cost_per_token", 0),
|
|
||||||
input_cost_per_character=_model_info.get(
|
|
||||||
"input_cost_per_character", None
|
|
||||||
),
|
|
||||||
input_cost_per_token_above_128k_tokens=_model_info.get(
|
|
||||||
"input_cost_per_token_above_128k_tokens", None
|
|
||||||
),
|
|
||||||
output_cost_per_token=_model_info.get("output_cost_per_token", 0),
|
|
||||||
output_cost_per_character=_model_info.get(
|
|
||||||
"output_cost_per_character", None
|
|
||||||
),
|
|
||||||
output_cost_per_token_above_128k_tokens=_model_info.get(
|
|
||||||
"output_cost_per_token_above_128k_tokens", None
|
|
||||||
),
|
|
||||||
output_cost_per_character_above_128k_tokens=_model_info.get(
|
|
||||||
"output_cost_per_character_above_128k_tokens", None
|
|
||||||
),
|
|
||||||
litellm_provider=_model_info.get(
|
|
||||||
"litellm_provider", custom_llm_provider
|
|
||||||
),
|
|
||||||
mode=_model_info.get("mode"),
|
|
||||||
supported_openai_params=supported_openai_params,
|
|
||||||
supports_system_messages=_model_info.get(
|
|
||||||
"supports_system_messages", None
|
|
||||||
),
|
|
||||||
)
|
|
||||||
elif split_model in litellm.model_cost:
|
elif split_model in litellm.model_cost:
|
||||||
_model_info = litellm.model_cost[split_model]
|
_model_info = litellm.model_cost[split_model]
|
||||||
_model_info["supported_openai_params"] = supported_openai_params
|
_model_info["supported_openai_params"] = supported_openai_params
|
||||||
|
@ -4627,6 +4655,15 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
raise Exception
|
raise Exception
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"This model isn't mapped yet. Add it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json"
|
||||||
|
)
|
||||||
|
|
||||||
|
## PROVIDER-SPECIFIC INFORMATION
|
||||||
|
if custom_llm_provider == "predibase":
|
||||||
|
_model_info["supports_response_schema"] = True
|
||||||
|
|
||||||
return ModelInfo(
|
return ModelInfo(
|
||||||
max_tokens=_model_info.get("max_tokens", None),
|
max_tokens=_model_info.get("max_tokens", None),
|
||||||
max_input_tokens=_model_info.get("max_input_tokens", None),
|
max_input_tokens=_model_info.get("max_input_tokens", None),
|
||||||
|
@ -4656,10 +4693,9 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod
|
||||||
supports_system_messages=_model_info.get(
|
supports_system_messages=_model_info.get(
|
||||||
"supports_system_messages", None
|
"supports_system_messages", None
|
||||||
),
|
),
|
||||||
)
|
supports_response_schema=_model_info.get(
|
||||||
else:
|
"supports_response_schema", None
|
||||||
raise ValueError(
|
),
|
||||||
"This model isn't mapped yet. Add it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json"
|
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
|
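A small, hedged example of reading the new capability flag off `get_model_info`; the model name matches the pricing-map entries updated later in this diff, but any mapped model works.

```python
import litellm

info = litellm.get_model_info(model="vertex_ai/gemini-1.5-pro")
# ModelInfo is a TypedDict, so plain dict access works; the flag is None when unknown.
print(info.get("supports_response_schema"))
```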
@ -5278,6 +5314,27 @@ def convert_to_model_response_object(
|
||||||
hidden_params: Optional[dict] = None,
|
hidden_params: Optional[dict] = None,
|
||||||
):
|
):
|
||||||
received_args = locals()
|
received_args = locals()
|
||||||
|
### CHECK IF ERROR IN RESPONSE ### - openrouter returns these in the dictionary
|
||||||
|
if (
|
||||||
|
response_object is not None
|
||||||
|
and "error" in response_object
|
||||||
|
and response_object["error"] is not None
|
||||||
|
):
|
||||||
|
error_args = {"status_code": 422, "message": "Error in response object"}
|
||||||
|
if isinstance(response_object["error"], dict):
|
||||||
|
if "code" in response_object["error"]:
|
||||||
|
error_args["status_code"] = response_object["error"]["code"]
|
||||||
|
if "message" in response_object["error"]:
|
||||||
|
if isinstance(response_object["error"]["message"], dict):
|
||||||
|
message_str = json.dumps(response_object["error"]["message"])
|
||||||
|
else:
|
||||||
|
message_str = str(response_object["error"]["message"])
|
||||||
|
error_args["message"] = message_str
|
||||||
|
raised_exception = Exception()
|
||||||
|
setattr(raised_exception, "status_code", error_args["status_code"])
|
||||||
|
setattr(raised_exception, "message", error_args["message"])
|
||||||
|
raise raised_exception
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if response_type == "completion" and (
|
if response_type == "completion" and (
|
||||||
model_response_object is None
|
model_response_object is None
|
||||||
|
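The guard added above means an OpenRouter-style payload carrying an `error` key now raises instead of being coerced into a normal `ModelResponse`. A hedged illustration of that behavior, mirroring the unit test earlier in this diff:

```python
import litellm

error_payload = {
    "id": None,
    "choices": None,
    "created": None,
    "model": None,
    "object": None,
    "usage": None,
    "error": {"message": "Output blocked by content filtering policy", "code": 400},
}

try:
    litellm.convert_to_model_response_object(
        response_object=error_payload,
        model_response_object=litellm.ModelResponse(),
        response_type="completion",
        stream=False,
    )
except Exception as e:
    # The raised exception carries the upstream status code and message.
    print(getattr(e, "status_code", None), getattr(e, "message", None))
```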
@ -5733,6 +5790,9 @@ def exception_type(
|
||||||
print() # noqa
|
print() # noqa
|
||||||
try:
|
try:
|
||||||
if model:
|
if model:
|
||||||
|
if hasattr(original_exception, "message"):
|
||||||
|
error_str = str(original_exception.message)
|
||||||
|
else:
|
||||||
error_str = str(original_exception)
|
error_str = str(original_exception)
|
||||||
if isinstance(original_exception, BaseException):
|
if isinstance(original_exception, BaseException):
|
||||||
exception_type = type(original_exception).__name__
|
exception_type = type(original_exception).__name__
|
||||||
|
@ -5755,6 +5815,18 @@ def exception_type(
|
||||||
_model_group = _metadata.get("model_group")
|
_model_group = _metadata.get("model_group")
|
||||||
_deployment = _metadata.get("deployment")
|
_deployment = _metadata.get("deployment")
|
||||||
extra_information = f"\nModel: {model}"
|
extra_information = f"\nModel: {model}"
|
||||||
|
|
||||||
|
exception_provider = "Unknown"
|
||||||
|
if (
|
||||||
|
isinstance(custom_llm_provider, str)
|
||||||
|
and len(custom_llm_provider) > 0
|
||||||
|
):
|
||||||
|
exception_provider = (
|
||||||
|
custom_llm_provider[0].upper()
|
||||||
|
+ custom_llm_provider[1:]
|
||||||
|
+ "Exception"
|
||||||
|
)
|
||||||
|
|
||||||
if _api_base:
|
if _api_base:
|
||||||
extra_information += f"\nAPI Base: `{_api_base}`"
|
extra_information += f"\nAPI Base: `{_api_base}`"
|
||||||
if (
|
if (
|
||||||
|
@ -5805,10 +5877,13 @@ def exception_type(
|
||||||
or custom_llm_provider in litellm.openai_compatible_providers
|
or custom_llm_provider in litellm.openai_compatible_providers
|
||||||
):
|
):
|
||||||
# custom_llm_provider is openai, make it OpenAI
|
# custom_llm_provider is openai, make it OpenAI
|
||||||
|
message = get_error_message(error_obj=original_exception)
|
||||||
|
if message is None:
|
||||||
if hasattr(original_exception, "message"):
|
if hasattr(original_exception, "message"):
|
||||||
message = original_exception.message
|
message = original_exception.message
|
||||||
else:
|
else:
|
||||||
message = str(original_exception)
|
message = str(original_exception)
|
||||||
|
|
||||||
if message is not None and isinstance(message, str):
|
if message is not None and isinstance(message, str):
|
||||||
message = message.replace("OPENAI", custom_llm_provider.upper())
|
message = message.replace("OPENAI", custom_llm_provider.upper())
|
||||||
message = message.replace("openai", custom_llm_provider)
|
message = message.replace("openai", custom_llm_provider)
|
||||||
|
@ -6141,7 +6216,6 @@ def exception_type(
|
||||||
)
|
)
|
||||||
elif (
|
elif (
|
||||||
original_exception.status_code == 400
|
original_exception.status_code == 400
|
||||||
or original_exception.status_code == 422
|
|
||||||
or original_exception.status_code == 413
|
or original_exception.status_code == 413
|
||||||
):
|
):
|
||||||
exception_mapping_worked = True
|
exception_mapping_worked = True
|
||||||
|
@ -6151,6 +6225,14 @@ def exception_type(
|
||||||
llm_provider="replicate",
|
llm_provider="replicate",
|
||||||
response=original_exception.response,
|
response=original_exception.response,
|
||||||
)
|
)
|
||||||
|
elif original_exception.status_code == 422:
|
||||||
|
exception_mapping_worked = True
|
||||||
|
raise UnprocessableEntityError(
|
||||||
|
message=f"ReplicateException - {original_exception.message}",
|
||||||
|
model=model,
|
||||||
|
llm_provider="replicate",
|
||||||
|
response=original_exception.response,
|
||||||
|
)
|
||||||
elif original_exception.status_code == 408:
|
elif original_exception.status_code == 408:
|
||||||
exception_mapping_worked = True
|
exception_mapping_worked = True
|
||||||
raise Timeout(
|
raise Timeout(
|
||||||
|
@ -7254,10 +7336,17 @@ def exception_type(
|
||||||
request=original_exception.request,
|
request=original_exception.request,
|
||||||
)
|
)
|
||||||
elif custom_llm_provider == "azure":
|
elif custom_llm_provider == "azure":
|
||||||
|
message = get_error_message(error_obj=original_exception)
|
||||||
|
if message is None:
|
||||||
|
if hasattr(original_exception, "message"):
|
||||||
|
message = original_exception.message
|
||||||
|
else:
|
||||||
|
message = str(original_exception)
|
||||||
|
|
||||||
if "Internal server error" in error_str:
|
if "Internal server error" in error_str:
|
||||||
exception_mapping_worked = True
|
exception_mapping_worked = True
|
||||||
raise litellm.InternalServerError(
|
raise litellm.InternalServerError(
|
||||||
message=f"AzureException Internal server error - {original_exception.message}",
|
message=f"AzureException Internal server error - {message}",
|
||||||
llm_provider="azure",
|
llm_provider="azure",
|
||||||
model=model,
|
model=model,
|
||||||
litellm_debug_info=extra_information,
|
litellm_debug_info=extra_information,
|
||||||
|
@ -7270,7 +7359,7 @@ def exception_type(
|
||||||
elif "This model's maximum context length is" in error_str:
|
elif "This model's maximum context length is" in error_str:
|
||||||
exception_mapping_worked = True
|
exception_mapping_worked = True
|
||||||
raise ContextWindowExceededError(
|
raise ContextWindowExceededError(
|
||||||
message=f"AzureException ContextWindowExceededError - {original_exception.message}",
|
message=f"AzureException ContextWindowExceededError - {message}",
|
||||||
llm_provider="azure",
|
llm_provider="azure",
|
||||||
model=model,
|
model=model,
|
||||||
litellm_debug_info=extra_information,
|
litellm_debug_info=extra_information,
|
||||||
|
@ -7279,7 +7368,7 @@ def exception_type(
|
||||||
elif "DeploymentNotFound" in error_str:
|
elif "DeploymentNotFound" in error_str:
|
||||||
exception_mapping_worked = True
|
exception_mapping_worked = True
|
||||||
raise NotFoundError(
|
raise NotFoundError(
|
||||||
message=f"AzureException NotFoundError - {original_exception.message}",
|
message=f"AzureException NotFoundError - {message}",
|
||||||
llm_provider="azure",
|
llm_provider="azure",
|
||||||
model=model,
|
model=model,
|
||||||
litellm_debug_info=extra_information,
|
litellm_debug_info=extra_information,
|
||||||
|
@ -7299,7 +7388,7 @@ def exception_type(
|
||||||
):
|
):
|
||||||
exception_mapping_worked = True
|
exception_mapping_worked = True
|
||||||
raise ContentPolicyViolationError(
|
raise ContentPolicyViolationError(
|
||||||
message=f"litellm.ContentPolicyViolationError: AzureException - {original_exception.message}",
|
message=f"litellm.ContentPolicyViolationError: AzureException - {message}",
|
||||||
llm_provider="azure",
|
llm_provider="azure",
|
||||||
model=model,
|
model=model,
|
||||||
litellm_debug_info=extra_information,
|
litellm_debug_info=extra_information,
|
||||||
|
@ -7308,7 +7397,7 @@ def exception_type(
|
||||||
elif "invalid_request_error" in error_str:
|
elif "invalid_request_error" in error_str:
|
||||||
exception_mapping_worked = True
|
exception_mapping_worked = True
|
||||||
raise BadRequestError(
|
raise BadRequestError(
|
||||||
message=f"AzureException BadRequestError - {original_exception.message}",
|
message=f"AzureException BadRequestError - {message}",
|
||||||
llm_provider="azure",
|
llm_provider="azure",
|
||||||
model=model,
|
model=model,
|
||||||
litellm_debug_info=extra_information,
|
litellm_debug_info=extra_information,
|
||||||
|
@ -7320,7 +7409,7 @@ def exception_type(
|
||||||
):
|
):
|
||||||
exception_mapping_worked = True
|
exception_mapping_worked = True
|
||||||
raise AuthenticationError(
|
raise AuthenticationError(
|
||||||
message=f"{exception_provider} AuthenticationError - {original_exception.message}",
|
message=f"{exception_provider} AuthenticationError - {message}",
|
||||||
llm_provider=custom_llm_provider,
|
llm_provider=custom_llm_provider,
|
||||||
model=model,
|
model=model,
|
||||||
litellm_debug_info=extra_information,
|
litellm_debug_info=extra_information,
|
||||||
|
@ -7331,7 +7420,7 @@ def exception_type(
|
||||||
if original_exception.status_code == 400:
|
if original_exception.status_code == 400:
|
||||||
exception_mapping_worked = True
|
exception_mapping_worked = True
|
||||||
raise BadRequestError(
|
raise BadRequestError(
|
||||||
message=f"AzureException - {original_exception.message}",
|
message=f"AzureException - {message}",
|
||||||
llm_provider="azure",
|
llm_provider="azure",
|
||||||
model=model,
|
model=model,
|
||||||
litellm_debug_info=extra_information,
|
litellm_debug_info=extra_information,
|
||||||
|
@ -7340,7 +7429,7 @@ def exception_type(
|
||||||
elif original_exception.status_code == 401:
|
elif original_exception.status_code == 401:
|
||||||
exception_mapping_worked = True
|
exception_mapping_worked = True
|
||||||
raise AuthenticationError(
|
raise AuthenticationError(
|
||||||
message=f"AzureException AuthenticationError - {original_exception.message}",
|
message=f"AzureException AuthenticationError - {message}",
|
||||||
llm_provider="azure",
|
llm_provider="azure",
|
||||||
model=model,
|
model=model,
|
||||||
litellm_debug_info=extra_information,
|
litellm_debug_info=extra_information,
|
||||||
|
@ -7349,7 +7438,7 @@ def exception_type(
|
||||||
elif original_exception.status_code == 408:
|
elif original_exception.status_code == 408:
|
||||||
exception_mapping_worked = True
|
exception_mapping_worked = True
|
||||||
raise Timeout(
|
raise Timeout(
|
||||||
message=f"AzureException Timeout - {original_exception.message}",
|
message=f"AzureException Timeout - {message}",
|
||||||
model=model,
|
model=model,
|
||||||
litellm_debug_info=extra_information,
|
litellm_debug_info=extra_information,
|
||||||
llm_provider="azure",
|
llm_provider="azure",
|
||||||
|
@ -7357,7 +7446,7 @@ def exception_type(
|
||||||
elif original_exception.status_code == 422:
|
elif original_exception.status_code == 422:
|
||||||
exception_mapping_worked = True
|
exception_mapping_worked = True
|
||||||
raise BadRequestError(
|
raise BadRequestError(
|
||||||
message=f"AzureException BadRequestError - {original_exception.message}",
|
message=f"AzureException BadRequestError - {message}",
|
||||||
model=model,
|
model=model,
|
||||||
llm_provider="azure",
|
llm_provider="azure",
|
||||||
litellm_debug_info=extra_information,
|
litellm_debug_info=extra_information,
|
||||||
|
@ -7366,7 +7455,7 @@ def exception_type(
|
||||||
elif original_exception.status_code == 429:
|
elif original_exception.status_code == 429:
|
||||||
exception_mapping_worked = True
|
exception_mapping_worked = True
|
||||||
raise RateLimitError(
|
raise RateLimitError(
|
||||||
message=f"AzureException RateLimitError - {original_exception.message}",
|
message=f"AzureException RateLimitError - {message}",
|
||||||
model=model,
|
model=model,
|
||||||
llm_provider="azure",
|
llm_provider="azure",
|
||||||
litellm_debug_info=extra_information,
|
litellm_debug_info=extra_information,
|
||||||
|
@ -7375,7 +7464,7 @@ def exception_type(
|
||||||
elif original_exception.status_code == 503:
|
elif original_exception.status_code == 503:
|
||||||
exception_mapping_worked = True
|
exception_mapping_worked = True
|
||||||
raise ServiceUnavailableError(
|
raise ServiceUnavailableError(
|
||||||
message=f"AzureException ServiceUnavailableError - {original_exception.message}",
|
message=f"AzureException ServiceUnavailableError - {message}",
|
||||||
model=model,
|
model=model,
|
||||||
llm_provider="azure",
|
llm_provider="azure",
|
||||||
litellm_debug_info=extra_information,
|
litellm_debug_info=extra_information,
|
||||||
|
@ -7384,7 +7473,7 @@ def exception_type(
|
||||||
elif original_exception.status_code == 504: # gateway timeout error
|
elif original_exception.status_code == 504: # gateway timeout error
|
||||||
exception_mapping_worked = True
|
exception_mapping_worked = True
|
||||||
raise Timeout(
|
raise Timeout(
|
||||||
message=f"AzureException Timeout - {original_exception.message}",
|
message=f"AzureException Timeout - {message}",
|
||||||
model=model,
|
model=model,
|
||||||
litellm_debug_info=extra_information,
|
litellm_debug_info=extra_information,
|
||||||
llm_provider="azure",
|
llm_provider="azure",
|
||||||
|
@ -7393,7 +7482,7 @@ def exception_type(
|
||||||
exception_mapping_worked = True
|
exception_mapping_worked = True
|
||||||
raise APIError(
|
raise APIError(
|
||||||
status_code=original_exception.status_code,
|
status_code=original_exception.status_code,
|
||||||
message=f"AzureException APIError - {original_exception.message}",
|
message=f"AzureException APIError - {message}",
|
||||||
llm_provider="azure",
|
llm_provider="azure",
|
||||||
litellm_debug_info=extra_information,
|
litellm_debug_info=extra_information,
|
||||||
model=model,
|
model=model,
|
||||||
|
|
|
@ -1486,6 +1486,7 @@
|
||||||
"supports_system_messages": true,
|
"supports_system_messages": true,
|
||||||
"supports_function_calling": true,
|
"supports_function_calling": true,
|
||||||
"supports_tool_choice": true,
|
"supports_tool_choice": true,
|
||||||
|
"supports_response_schema": true,
|
||||||
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
|
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
|
||||||
},
|
},
|
||||||
"gemini-1.5-pro-001": {
|
"gemini-1.5-pro-001": {
|
||||||
|
@ -1511,6 +1512,7 @@
|
||||||
"supports_system_messages": true,
|
"supports_system_messages": true,
|
||||||
"supports_function_calling": true,
|
"supports_function_calling": true,
|
||||||
"supports_tool_choice": true,
|
"supports_tool_choice": true,
|
||||||
|
"supports_response_schema": true,
|
||||||
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
|
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
|
||||||
},
|
},
|
||||||
"gemini-1.5-pro-preview-0514": {
|
"gemini-1.5-pro-preview-0514": {
|
||||||
|
@ -1536,6 +1538,7 @@
|
||||||
"supports_system_messages": true,
|
"supports_system_messages": true,
|
||||||
"supports_function_calling": true,
|
"supports_function_calling": true,
|
||||||
"supports_tool_choice": true,
|
"supports_tool_choice": true,
|
||||||
|
"supports_response_schema": true,
|
||||||
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
|
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
|
||||||
},
|
},
|
||||||
"gemini-1.5-pro-preview-0215": {
|
"gemini-1.5-pro-preview-0215": {
|
||||||
|
@ -1561,6 +1564,7 @@
|
||||||
"supports_system_messages": true,
|
"supports_system_messages": true,
|
||||||
"supports_function_calling": true,
|
"supports_function_calling": true,
|
||||||
"supports_tool_choice": true,
|
"supports_tool_choice": true,
|
||||||
|
"supports_response_schema": true,
|
||||||
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
|
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
|
||||||
},
|
},
|
||||||
"gemini-1.5-pro-preview-0409": {
|
"gemini-1.5-pro-preview-0409": {
|
||||||
|
@ -1585,6 +1589,7 @@
|
||||||
"mode": "chat",
|
"mode": "chat",
|
||||||
"supports_function_calling": true,
|
"supports_function_calling": true,
|
||||||
"supports_tool_choice": true,
|
"supports_tool_choice": true,
|
||||||
|
"supports_response_schema": true,
|
||||||
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
|
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
|
||||||
},
|
},
|
||||||
"gemini-1.5-flash": {
|
"gemini-1.5-flash": {
|
||||||
|
@ -2007,6 +2012,7 @@
|
||||||
"supports_function_calling": true,
|
"supports_function_calling": true,
|
||||||
"supports_vision": true,
|
"supports_vision": true,
|
||||||
"supports_tool_choice": true,
|
"supports_tool_choice": true,
|
||||||
|
"supports_response_schema": true,
|
||||||
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
|
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
|
||||||
},
|
},
|
||||||
"gemini/gemini-1.5-pro-latest": {
|
"gemini/gemini-1.5-pro-latest": {
|
||||||
|
@ -2023,6 +2029,7 @@
|
||||||
"supports_function_calling": true,
|
"supports_function_calling": true,
|
||||||
"supports_vision": true,
|
"supports_vision": true,
|
||||||
"supports_tool_choice": true,
|
"supports_tool_choice": true,
|
||||||
|
"supports_response_schema": true,
|
||||||
"source": "https://ai.google.dev/models/gemini"
|
"source": "https://ai.google.dev/models/gemini"
|
||||||
},
|
},
|
||||||
"gemini/gemini-pro-vision": {
|
"gemini/gemini-pro-vision": {
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
[tool.poetry]
|
[tool.poetry]
|
||||||
name = "litellm"
|
name = "litellm"
|
||||||
version = "1.40.31"
|
version = "1.41.3"
|
||||||
description = "Library to easily interface with LLM API providers"
|
description = "Library to easily interface with LLM API providers"
|
||||||
authors = ["BerriAI"]
|
authors = ["BerriAI"]
|
||||||
license = "MIT"
|
license = "MIT"
|
||||||
|
@ -27,7 +27,7 @@ jinja2 = "^3.1.2"
|
||||||
aiohttp = "*"
|
aiohttp = "*"
|
||||||
requests = "^2.31.0"
|
requests = "^2.31.0"
|
||||||
pydantic = "^2.0.0"
|
pydantic = "^2.0.0"
|
||||||
ijson = "*"
|
jsonschema = "^4.22.0"
|
||||||
|
|
||||||
uvicorn = {version = "^0.22.0", optional = true}
|
uvicorn = {version = "^0.22.0", optional = true}
|
||||||
gunicorn = {version = "^22.0.0", optional = true}
|
gunicorn = {version = "^22.0.0", optional = true}
|
||||||
|
@ -90,7 +90,7 @@ requires = ["poetry-core", "wheel"]
|
||||||
build-backend = "poetry.core.masonry.api"
|
build-backend = "poetry.core.masonry.api"
|
||||||
|
|
||||||
[tool.commitizen]
|
[tool.commitizen]
|
||||||
version = "1.40.31"
|
version = "1.41.3"
|
||||||
version_files = [
|
version_files = [
|
||||||
"pyproject.toml:^version"
|
"pyproject.toml:^version"
|
||||||
]
|
]
|
||||||
|
|
|
@ -46,5 +46,5 @@ aiohttp==3.9.0 # for network calls
|
||||||
aioboto3==12.3.0 # for async sagemaker calls
|
aioboto3==12.3.0 # for async sagemaker calls
|
||||||
tenacity==8.2.3 # for retrying requests, when litellm.num_retries set
|
tenacity==8.2.3 # for retrying requests, when litellm.num_retries set
|
||||||
pydantic==2.7.1 # proxy + openai req.
|
pydantic==2.7.1 # proxy + openai req.
|
||||||
ijson==3.2.3 # for google ai studio streaming
|
jsonschema==4.22.0 # validating json schema
|
||||||
####
|
####
|
tests/test_entrypoint.py (new file, 59 lines)
|
@ -0,0 +1,59 @@
|
||||||
|
# What is this?
|
||||||
|
## Unit tests for 'entrypoint.sh'
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
|
||||||
|
sys.path.insert(
|
||||||
|
0, os.path.abspath("../")
|
||||||
|
) # Adds the parent directory to the system path
|
||||||
|
import litellm
|
||||||
|
import subprocess
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip(reason="local test")
|
||||||
|
def test_decrypt_and_reset_env():
|
||||||
|
os.environ["DATABASE_URL"] = (
|
||||||
|
"aws_kms/AQICAHgwddjZ9xjVaZ9CNCG8smFU6FiQvfdrjL12DIqi9vUAQwHwF6U7caMgHQa6tK+TzaoMAAAAzjCBywYJKoZIhvcNAQcGoIG9MIG6AgEAMIG0BgkqhkiG9w0BBwEwHgYJYIZIAWUDBAEuMBEEDCmu+DVeKTm5tFZu6AIBEICBhnOFQYviL8JsciGk0bZsn9pfzeYWtNkVXEsl01AdgHBqT9UOZOI4ZC+T3wO/fXA7wdNF4o8ASPDbVZ34ZFdBs8xt4LKp9niufL30WYBkuuzz89ztly0jvE9pZ8L6BMw0ATTaMgIweVtVSDCeCzEb5PUPyxt4QayrlYHBGrNH5Aq/axFTe0La"
|
||||||
|
)
|
||||||
|
from litellm.proxy.secret_managers.aws_secret_manager import (
|
||||||
|
decrypt_and_reset_env_var,
|
||||||
|
)
|
||||||
|
|
||||||
|
decrypt_and_reset_env_var()
|
||||||
|
|
||||||
|
assert os.environ["DATABASE_URL"] is not None
|
||||||
|
assert isinstance(os.environ["DATABASE_URL"], str)
|
||||||
|
assert not os.environ["DATABASE_URL"].startswith("aws_kms/")
|
||||||
|
|
||||||
|
print("DATABASE_URL={}".format(os.environ["DATABASE_URL"]))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip(reason="local test")
|
||||||
|
def test_entrypoint_decrypt_and_reset():
|
||||||
|
os.environ["DATABASE_URL"] = (
|
||||||
|
"aws_kms/AQICAHgwddjZ9xjVaZ9CNCG8smFU6FiQvfdrjL12DIqi9vUAQwHwF6U7caMgHQa6tK+TzaoMAAAAzjCBywYJKoZIhvcNAQcGoIG9MIG6AgEAMIG0BgkqhkiG9w0BBwEwHgYJYIZIAWUDBAEuMBEEDCmu+DVeKTm5tFZu6AIBEICBhnOFQYviL8JsciGk0bZsn9pfzeYWtNkVXEsl01AdgHBqT9UOZOI4ZC+T3wO/fXA7wdNF4o8ASPDbVZ34ZFdBs8xt4LKp9niufL30WYBkuuzz89ztly0jvE9pZ8L6BMw0ATTaMgIweVtVSDCeCzEb5PUPyxt4QayrlYHBGrNH5Aq/axFTe0La"
|
||||||
|
)
|
||||||
|
command = "./entrypoint.sh"
|
||||||
|
directory = ".." # Relative to the current directory
|
||||||
|
|
||||||
|
# Run the command using subprocess
|
||||||
|
result = subprocess.run(
|
||||||
|
command, shell=True, cwd=directory, capture_output=True, text=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# Print the output for debugging purposes
|
||||||
|
print("STDOUT:", result.stdout)
|
||||||
|
print("STDERR:", result.stderr)
|
||||||
|
|
||||||
|
# Assert the script ran successfully
|
||||||
|
assert result.returncode == 0, "The shell script did not execute successfully"
|
||||||
|
assert (
|
||||||
|
"DECRYPTS VALUE" in result.stdout
|
||||||
|
), "Expected output not found in script output"
|
||||||
|
assert (
|
||||||
|
"Database push successful!" in result.stdout
|
||||||
|
), "Expected output not found in script output"
|
||||||
|
|
||||||
|
assert False
|