mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
Litellm dev 12 28 2024 p1 (#7463)
* refactor(utils.py): migrate amazon titan config to base config * refactor(utils.py): refactor bedrock meta invoke model translation to use base config * refactor(utils.py): move bedrock ai21 to base config * refactor(utils.py): move bedrock cohere to base config * refactor(utils.py): move bedrock mistral to use base config * refactor(utils.py): move all provider optional param translations to using a config * docs(clientside_auth.md): clarify how to pass vertex region to litellm proxy * fix(utils.py): handle scenario where custom llm provider is none / empty * fix: fix get config * test(test_otel_load_tests.py): widen perf margin * fix(utils.py): fix get provider config check to handle custom llm's * fix(utils.py): fix check
This commit is contained in:
parent
ec7fcc982d
commit
31ace870a2
11 changed files with 753 additions and 446 deletions
284
docs/my-website/docs/proxy/clientside_auth.md
Normal file
284
docs/my-website/docs/proxy/clientside_auth.md
Normal file
|
@ -0,0 +1,284 @@
|
||||||
|
# Clientside LLM Credentials
|
||||||
|
|
||||||
|
|
||||||
|
### Pass User LLM API Keys, Fallbacks
|
||||||
|
Allow your end-users to pass their model list, api base, OpenAI API key (any LiteLLM supported provider) to make requests
|
||||||
|
|
||||||
|
**Note** This is not related to [virtual keys](./virtual_keys.md). This is for when you want to pass in your users actual LLM API keys.
|
||||||
|
|
||||||
|
:::info
|
||||||
|
|
||||||
|
**You can pass a litellm.RouterConfig as `user_config`, See all supported params here https://github.com/BerriAI/litellm/blob/main/litellm/types/router.py **
|
||||||
|
|
||||||
|
:::
|
||||||
|
|
||||||
|
<Tabs>
|
||||||
|
|
||||||
|
<TabItem value="openai-py" label="OpenAI Python">
|
||||||
|
|
||||||
|
#### Step 1: Define user model list & config
|
||||||
|
```python
|
||||||
|
import os
|
||||||
|
|
||||||
|
user_config = {
|
||||||
|
'model_list': [
|
||||||
|
{
|
||||||
|
'model_name': 'user-azure-instance',
|
||||||
|
'litellm_params': {
|
||||||
|
'model': 'azure/chatgpt-v-2',
|
||||||
|
'api_key': os.getenv('AZURE_API_KEY'),
|
||||||
|
'api_version': os.getenv('AZURE_API_VERSION'),
|
||||||
|
'api_base': os.getenv('AZURE_API_BASE'),
|
||||||
|
'timeout': 10,
|
||||||
|
},
|
||||||
|
'tpm': 240000,
|
||||||
|
'rpm': 1800,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'model_name': 'user-openai-instance',
|
||||||
|
'litellm_params': {
|
||||||
|
'model': 'gpt-3.5-turbo',
|
||||||
|
'api_key': os.getenv('OPENAI_API_KEY'),
|
||||||
|
'timeout': 10,
|
||||||
|
},
|
||||||
|
'tpm': 240000,
|
||||||
|
'rpm': 1800,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
'num_retries': 2,
|
||||||
|
'allowed_fails': 3,
|
||||||
|
'fallbacks': [
|
||||||
|
{
|
||||||
|
'user-azure-instance': ['user-openai-instance']
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Step 2: Send user_config in `extra_body`
|
||||||
|
```python
|
||||||
|
import openai
|
||||||
|
client = openai.OpenAI(
|
||||||
|
api_key="sk-1234",
|
||||||
|
base_url="http://0.0.0.0:4000"
|
||||||
|
)
|
||||||
|
|
||||||
|
# send request to `user-azure-instance`
|
||||||
|
response = client.chat.completions.create(model="user-azure-instance", messages = [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "this is a test request, write a short poem"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
extra_body={
|
||||||
|
"user_config": user_config
|
||||||
|
}
|
||||||
|
) # 👈 User config
|
||||||
|
|
||||||
|
print(response)
|
||||||
|
```
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
|
||||||
|
<TabItem value="openai-js" label="OpenAI JS">
|
||||||
|
|
||||||
|
#### Step 1: Define user model list & config
|
||||||
|
```javascript
|
||||||
|
const os = require('os');
|
||||||
|
|
||||||
|
const userConfig = {
|
||||||
|
model_list: [
|
||||||
|
{
|
||||||
|
model_name: 'user-azure-instance',
|
||||||
|
litellm_params: {
|
||||||
|
model: 'azure/chatgpt-v-2',
|
||||||
|
api_key: process.env.AZURE_API_KEY,
|
||||||
|
api_version: process.env.AZURE_API_VERSION,
|
||||||
|
api_base: process.env.AZURE_API_BASE,
|
||||||
|
timeout: 10,
|
||||||
|
},
|
||||||
|
tpm: 240000,
|
||||||
|
rpm: 1800,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
model_name: 'user-openai-instance',
|
||||||
|
litellm_params: {
|
||||||
|
model: 'gpt-3.5-turbo',
|
||||||
|
api_key: process.env.OPENAI_API_KEY,
|
||||||
|
timeout: 10,
|
||||||
|
},
|
||||||
|
tpm: 240000,
|
||||||
|
rpm: 1800,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
num_retries: 2,
|
||||||
|
allowed_fails: 3,
|
||||||
|
fallbacks: [
|
||||||
|
{
|
||||||
|
'user-azure-instance': ['user-openai-instance']
|
||||||
|
}
|
||||||
|
]
|
||||||
|
};
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Step 2: Send `user_config` as a param to `openai.chat.completions.create`
|
||||||
|
|
||||||
|
```javascript
|
||||||
|
const { OpenAI } = require('openai');
|
||||||
|
|
||||||
|
const openai = new OpenAI({
|
||||||
|
apiKey: "sk-1234",
|
||||||
|
baseURL: "http://0.0.0.0:4000"
|
||||||
|
});
|
||||||
|
|
||||||
|
async function main() {
|
||||||
|
const chatCompletion = await openai.chat.completions.create({
|
||||||
|
messages: [{ role: 'user', content: 'Say this is a test' }],
|
||||||
|
model: 'gpt-3.5-turbo',
|
||||||
|
user_config: userConfig // # 👈 User config
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
main();
|
||||||
|
```
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
|
||||||
|
</Tabs>
|
||||||
|
|
||||||
|
### Pass User LLM API Keys / API Base
|
||||||
|
Allows your users to pass in their OpenAI API key/API base (any LiteLLM supported provider) to make requests
|
||||||
|
|
||||||
|
Here's how to do it:
|
||||||
|
|
||||||
|
#### 1. Enable configurable clientside auth credentials for a provider
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
model_list:
|
||||||
|
- model_name: "fireworks_ai/*"
|
||||||
|
litellm_params:
|
||||||
|
model: "fireworks_ai/*"
|
||||||
|
configurable_clientside_auth_params: ["api_base"]
|
||||||
|
# OR
|
||||||
|
configurable_clientside_auth_params: [{"api_base": "^https://litellm.*direct\.fireworks\.ai/v1$"}] # 👈 regex
|
||||||
|
```
|
||||||
|
|
||||||
|
Specify any/all auth params you want the user to be able to configure:
|
||||||
|
|
||||||
|
- api_base (✅ regex supported)
|
||||||
|
- api_key
|
||||||
|
- base_url
|
||||||
|
|
||||||
|
(check [provider docs](../providers/) for provider-specific auth params - e.g. `vertex_project`)
|
||||||
|
|
||||||
|
|
||||||
|
#### 2. Test it!
|
||||||
|
|
||||||
|
```python
|
||||||
|
import openai
|
||||||
|
client = openai.OpenAI(
|
||||||
|
api_key="sk-1234",
|
||||||
|
base_url="http://0.0.0.0:4000"
|
||||||
|
)
|
||||||
|
|
||||||
|
# request sent to model set on litellm proxy, `litellm --model`
|
||||||
|
response = client.chat.completions.create(model="gpt-3.5-turbo", messages = [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "this is a test request, write a short poem"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
extra_body={"api_key": "my-bad-key", "api_base": "https://litellm-dev.direct.fireworks.ai/v1"}) # 👈 clientside credentials
|
||||||
|
|
||||||
|
print(response)
|
||||||
|
```
|
||||||
|
|
||||||
|
More examples:
|
||||||
|
<Tabs>
|
||||||
|
<TabItem value="openai-py" label="Azure Credentials">
|
||||||
|
|
||||||
|
Pass in the litellm_params (E.g. api_key, api_base, etc.) via the `extra_body` parameter in the OpenAI client.
|
||||||
|
|
||||||
|
```python
|
||||||
|
import openai
|
||||||
|
client = openai.OpenAI(
|
||||||
|
api_key="sk-1234",
|
||||||
|
base_url="http://0.0.0.0:4000"
|
||||||
|
)
|
||||||
|
|
||||||
|
# request sent to model set on litellm proxy, `litellm --model`
|
||||||
|
response = client.chat.completions.create(model="gpt-3.5-turbo", messages = [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "this is a test request, write a short poem"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
extra_body={
|
||||||
|
"api_key": "my-azure-key",
|
||||||
|
"api_base": "my-azure-base",
|
||||||
|
"api_version": "my-azure-version"
|
||||||
|
}) # 👈 User Key
|
||||||
|
|
||||||
|
print(response)
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
<TabItem value="openai-js" label="OpenAI JS">
|
||||||
|
|
||||||
|
For JS, the OpenAI client accepts passing params in the `create(..)` body as normal.
|
||||||
|
|
||||||
|
```javascript
|
||||||
|
const { OpenAI } = require('openai');
|
||||||
|
|
||||||
|
const openai = new OpenAI({
|
||||||
|
apiKey: "sk-1234",
|
||||||
|
baseURL: "http://0.0.0.0:4000"
|
||||||
|
});
|
||||||
|
|
||||||
|
async function main() {
|
||||||
|
const chatCompletion = await openai.chat.completions.create({
|
||||||
|
messages: [{ role: 'user', content: 'Say this is a test' }],
|
||||||
|
model: 'gpt-3.5-turbo',
|
||||||
|
api_key: "my-bad-key" // 👈 User Key
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
main();
|
||||||
|
```
|
||||||
|
</TabItem>
|
||||||
|
</Tabs>
|
||||||
|
|
||||||
|
### Pass provider-specific params (e.g. Region, Project ID, etc.)
|
||||||
|
|
||||||
|
Specify the region, project id, etc. to use for making requests to Vertex AI on the clientside.
|
||||||
|
|
||||||
|
Any value passed in the Proxy's request body, will be checked by LiteLLM against the mapped openai / litellm auth params.
|
||||||
|
|
||||||
|
Unmapped params, will be assumed to be provider-specific params, and will be passed through to the provider in the LLM API's request body.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
import openai
|
||||||
|
client = openai.OpenAI(
|
||||||
|
api_key="anything",
|
||||||
|
base_url="http://0.0.0.0:4000"
|
||||||
|
)
|
||||||
|
|
||||||
|
# request sent to model set on litellm proxy, `litellm --model`
|
||||||
|
response = client.chat.completions.create(
|
||||||
|
model="gpt-3.5-turbo",
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "this is a test request, write a short poem"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
extra_body={ # pass any additional litellm_params here
|
||||||
|
vertex_ai_location: "us-east1"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
print(response)
|
||||||
|
```
|
|
@ -996,254 +996,3 @@ Get a list of responses when `model` is passed as a list
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
### Pass User LLM API Keys, Fallbacks
|
|
||||||
Allow your end-users to pass their model list, api base, OpenAI API key (any LiteLLM supported provider) to make requests
|
|
||||||
|
|
||||||
**Note** This is not related to [virtual keys](./virtual_keys.md). This is for when you want to pass in your users actual LLM API keys.
|
|
||||||
|
|
||||||
:::info
|
|
||||||
|
|
||||||
**You can pass a litellm.RouterConfig as `user_config`, See all supported params here https://github.com/BerriAI/litellm/blob/main/litellm/types/router.py **
|
|
||||||
|
|
||||||
:::
|
|
||||||
|
|
||||||
<Tabs>
|
|
||||||
|
|
||||||
<TabItem value="openai-py" label="OpenAI Python">
|
|
||||||
|
|
||||||
#### Step 1: Define user model list & config
|
|
||||||
```python
|
|
||||||
import os
|
|
||||||
|
|
||||||
user_config = {
|
|
||||||
'model_list': [
|
|
||||||
{
|
|
||||||
'model_name': 'user-azure-instance',
|
|
||||||
'litellm_params': {
|
|
||||||
'model': 'azure/chatgpt-v-2',
|
|
||||||
'api_key': os.getenv('AZURE_API_KEY'),
|
|
||||||
'api_version': os.getenv('AZURE_API_VERSION'),
|
|
||||||
'api_base': os.getenv('AZURE_API_BASE'),
|
|
||||||
'timeout': 10,
|
|
||||||
},
|
|
||||||
'tpm': 240000,
|
|
||||||
'rpm': 1800,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
'model_name': 'user-openai-instance',
|
|
||||||
'litellm_params': {
|
|
||||||
'model': 'gpt-3.5-turbo',
|
|
||||||
'api_key': os.getenv('OPENAI_API_KEY'),
|
|
||||||
'timeout': 10,
|
|
||||||
},
|
|
||||||
'tpm': 240000,
|
|
||||||
'rpm': 1800,
|
|
||||||
},
|
|
||||||
],
|
|
||||||
'num_retries': 2,
|
|
||||||
'allowed_fails': 3,
|
|
||||||
'fallbacks': [
|
|
||||||
{
|
|
||||||
'user-azure-instance': ['user-openai-instance']
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
```
|
|
||||||
|
|
||||||
#### Step 2: Send user_config in `extra_body`
|
|
||||||
```python
|
|
||||||
import openai
|
|
||||||
client = openai.OpenAI(
|
|
||||||
api_key="sk-1234",
|
|
||||||
base_url="http://0.0.0.0:4000"
|
|
||||||
)
|
|
||||||
|
|
||||||
# send request to `user-azure-instance`
|
|
||||||
response = client.chat.completions.create(model="user-azure-instance", messages = [
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": "this is a test request, write a short poem"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
extra_body={
|
|
||||||
"user_config": user_config
|
|
||||||
}
|
|
||||||
) # 👈 User config
|
|
||||||
|
|
||||||
print(response)
|
|
||||||
```
|
|
||||||
|
|
||||||
</TabItem>
|
|
||||||
|
|
||||||
<TabItem value="openai-js" label="OpenAI JS">
|
|
||||||
|
|
||||||
#### Step 1: Define user model list & config
|
|
||||||
```javascript
|
|
||||||
const os = require('os');
|
|
||||||
|
|
||||||
const userConfig = {
|
|
||||||
model_list: [
|
|
||||||
{
|
|
||||||
model_name: 'user-azure-instance',
|
|
||||||
litellm_params: {
|
|
||||||
model: 'azure/chatgpt-v-2',
|
|
||||||
api_key: process.env.AZURE_API_KEY,
|
|
||||||
api_version: process.env.AZURE_API_VERSION,
|
|
||||||
api_base: process.env.AZURE_API_BASE,
|
|
||||||
timeout: 10,
|
|
||||||
},
|
|
||||||
tpm: 240000,
|
|
||||||
rpm: 1800,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
model_name: 'user-openai-instance',
|
|
||||||
litellm_params: {
|
|
||||||
model: 'gpt-3.5-turbo',
|
|
||||||
api_key: process.env.OPENAI_API_KEY,
|
|
||||||
timeout: 10,
|
|
||||||
},
|
|
||||||
tpm: 240000,
|
|
||||||
rpm: 1800,
|
|
||||||
},
|
|
||||||
],
|
|
||||||
num_retries: 2,
|
|
||||||
allowed_fails: 3,
|
|
||||||
fallbacks: [
|
|
||||||
{
|
|
||||||
'user-azure-instance': ['user-openai-instance']
|
|
||||||
}
|
|
||||||
]
|
|
||||||
};
|
|
||||||
```
|
|
||||||
|
|
||||||
#### Step 2: Send `user_config` as a param to `openai.chat.completions.create`
|
|
||||||
|
|
||||||
```javascript
|
|
||||||
const { OpenAI } = require('openai');
|
|
||||||
|
|
||||||
const openai = new OpenAI({
|
|
||||||
apiKey: "sk-1234",
|
|
||||||
baseURL: "http://0.0.0.0:4000"
|
|
||||||
});
|
|
||||||
|
|
||||||
async function main() {
|
|
||||||
const chatCompletion = await openai.chat.completions.create({
|
|
||||||
messages: [{ role: 'user', content: 'Say this is a test' }],
|
|
||||||
model: 'gpt-3.5-turbo',
|
|
||||||
user_config: userConfig // # 👈 User config
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
main();
|
|
||||||
```
|
|
||||||
|
|
||||||
</TabItem>
|
|
||||||
|
|
||||||
</Tabs>
|
|
||||||
|
|
||||||
### Pass User LLM API Keys / API Base
|
|
||||||
Allows your users to pass in their OpenAI API key/API base (any LiteLLM supported provider) to make requests
|
|
||||||
|
|
||||||
Here's how to do it:
|
|
||||||
|
|
||||||
#### 1. Enable configurable clientside auth credentials for a provider
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
model_list:
|
|
||||||
- model_name: "fireworks_ai/*"
|
|
||||||
litellm_params:
|
|
||||||
model: "fireworks_ai/*"
|
|
||||||
configurable_clientside_auth_params: ["api_base"]
|
|
||||||
# OR
|
|
||||||
configurable_clientside_auth_params: [{"api_base": "^https://litellm.*direct\.fireworks\.ai/v1$"}] # 👈 regex
|
|
||||||
```
|
|
||||||
|
|
||||||
Specify any/all auth params you want the user to be able to configure:
|
|
||||||
|
|
||||||
- api_base (✅ regex supported)
|
|
||||||
- api_key
|
|
||||||
- base_url
|
|
||||||
|
|
||||||
(check [provider docs](../providers/) for provider-specific auth params - e.g. `vertex_project`)
|
|
||||||
|
|
||||||
|
|
||||||
#### 2. Test it!
|
|
||||||
|
|
||||||
```python
|
|
||||||
import openai
|
|
||||||
client = openai.OpenAI(
|
|
||||||
api_key="sk-1234",
|
|
||||||
base_url="http://0.0.0.0:4000"
|
|
||||||
)
|
|
||||||
|
|
||||||
# request sent to model set on litellm proxy, `litellm --model`
|
|
||||||
response = client.chat.completions.create(model="gpt-3.5-turbo", messages = [
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": "this is a test request, write a short poem"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
extra_body={"api_key": "my-bad-key", "api_base": "https://litellm-dev.direct.fireworks.ai/v1"}) # 👈 clientside credentials
|
|
||||||
|
|
||||||
print(response)
|
|
||||||
```
|
|
||||||
|
|
||||||
More examples:
|
|
||||||
<Tabs>
|
|
||||||
<TabItem value="openai-py" label="Azure Credentials">
|
|
||||||
|
|
||||||
Pass in the litellm_params (E.g. api_key, api_base, etc.) via the `extra_body` parameter in the OpenAI client.
|
|
||||||
|
|
||||||
```python
|
|
||||||
import openai
|
|
||||||
client = openai.OpenAI(
|
|
||||||
api_key="sk-1234",
|
|
||||||
base_url="http://0.0.0.0:4000"
|
|
||||||
)
|
|
||||||
|
|
||||||
# request sent to model set on litellm proxy, `litellm --model`
|
|
||||||
response = client.chat.completions.create(model="gpt-3.5-turbo", messages = [
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": "this is a test request, write a short poem"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
extra_body={
|
|
||||||
"api_key": "my-azure-key",
|
|
||||||
"api_base": "my-azure-base",
|
|
||||||
"api_version": "my-azure-version"
|
|
||||||
}) # 👈 User Key
|
|
||||||
|
|
||||||
print(response)
|
|
||||||
```
|
|
||||||
|
|
||||||
|
|
||||||
</TabItem>
|
|
||||||
<TabItem value="openai-js" label="OpenAI JS">
|
|
||||||
|
|
||||||
For JS, the OpenAI client accepts passing params in the `create(..)` body as normal.
|
|
||||||
|
|
||||||
```javascript
|
|
||||||
const { OpenAI } = require('openai');
|
|
||||||
|
|
||||||
const openai = new OpenAI({
|
|
||||||
apiKey: "sk-1234",
|
|
||||||
baseURL: "http://0.0.0.0:4000"
|
|
||||||
});
|
|
||||||
|
|
||||||
async function main() {
|
|
||||||
const chatCompletion = await openai.chat.completions.create({
|
|
||||||
messages: [{ role: 'user', content: 'Say this is a test' }],
|
|
||||||
model: 'gpt-3.5-turbo',
|
|
||||||
api_key: "my-bad-key" // 👈 User Key
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
main();
|
|
||||||
```
|
|
||||||
</TabItem>
|
|
||||||
</Tabs>
|
|
|
@ -65,6 +65,7 @@ const sidebars = {
|
||||||
label: "Making LLM Requests",
|
label: "Making LLM Requests",
|
||||||
items: [
|
items: [
|
||||||
"proxy/user_keys",
|
"proxy/user_keys",
|
||||||
|
"proxy/clientside_auth",
|
||||||
"proxy/response_headers",
|
"proxy/response_headers",
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
|
|
|
@ -3,27 +3,26 @@ Common utilities used across bedrock chat/embedding/image generation
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import re
|
||||||
import types
|
import types
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import List, Optional, Union
|
from typing import Any, List, Optional, Union
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
import litellm
|
import litellm
|
||||||
|
from litellm.llms.base_llm.chat.transformation import (
|
||||||
|
BaseConfig,
|
||||||
|
BaseLLMException,
|
||||||
|
LiteLLMLoggingObj,
|
||||||
|
)
|
||||||
from litellm.secret_managers.main import get_secret
|
from litellm.secret_managers.main import get_secret
|
||||||
|
from litellm.types.llms.openai import AllMessageValues
|
||||||
|
from litellm.types.utils import ModelResponse
|
||||||
|
|
||||||
|
|
||||||
class BedrockError(Exception):
|
class BedrockError(BaseLLMException):
|
||||||
def __init__(self, status_code, message):
|
pass
|
||||||
self.status_code = status_code
|
|
||||||
self.message = message
|
|
||||||
self.request = httpx.Request(
|
|
||||||
method="POST", url="https://us-west-2.console.aws.amazon.com/bedrock"
|
|
||||||
)
|
|
||||||
self.response = httpx.Response(status_code=status_code, request=self.request)
|
|
||||||
super().__init__(
|
|
||||||
self.message
|
|
||||||
) # Call the base class constructor with the parameters it needs
|
|
||||||
|
|
||||||
|
|
||||||
class AmazonBedrockGlobalConfig:
|
class AmazonBedrockGlobalConfig:
|
||||||
|
@ -65,7 +64,64 @@ class AmazonBedrockGlobalConfig:
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
class AmazonTitanConfig:
|
class AmazonInvokeMixin:
|
||||||
|
"""
|
||||||
|
Base class for bedrock models going through invoke_handler.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
def get_error_class(
|
||||||
|
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
|
||||||
|
) -> BaseLLMException:
|
||||||
|
return BedrockError(
|
||||||
|
message=error_message,
|
||||||
|
status_code=status_code,
|
||||||
|
headers=headers,
|
||||||
|
)
|
||||||
|
|
||||||
|
def transform_request(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
messages: List[AllMessageValues],
|
||||||
|
optional_params: dict,
|
||||||
|
litellm_params: dict,
|
||||||
|
headers: dict,
|
||||||
|
) -> dict:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"transform_request not implemented for config. Done in invoke_handler.py"
|
||||||
|
)
|
||||||
|
|
||||||
|
def transform_response(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
raw_response: httpx.Response,
|
||||||
|
model_response: ModelResponse,
|
||||||
|
logging_obj: LiteLLMLoggingObj,
|
||||||
|
request_data: dict,
|
||||||
|
messages: List[AllMessageValues],
|
||||||
|
optional_params: dict,
|
||||||
|
litellm_params: dict,
|
||||||
|
encoding: Any,
|
||||||
|
api_key: Optional[str] = None,
|
||||||
|
json_mode: Optional[bool] = None,
|
||||||
|
) -> ModelResponse:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"transform_response not implemented for config. Done in invoke_handler.py"
|
||||||
|
)
|
||||||
|
|
||||||
|
def validate_environment(
|
||||||
|
self,
|
||||||
|
headers: dict,
|
||||||
|
model: str,
|
||||||
|
messages: List[AllMessageValues],
|
||||||
|
optional_params: dict,
|
||||||
|
api_key: Optional[str] = None,
|
||||||
|
) -> dict:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"validate_environment not implemented for config. Done in invoke_handler.py"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class AmazonTitanConfig(AmazonInvokeMixin, BaseConfig):
|
||||||
"""
|
"""
|
||||||
Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=titan-text-express-v1
|
Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=titan-text-express-v1
|
||||||
|
|
||||||
|
@ -100,6 +156,7 @@ class AmazonTitanConfig:
|
||||||
k: v
|
k: v
|
||||||
for k, v in cls.__dict__.items()
|
for k, v in cls.__dict__.items()
|
||||||
if not k.startswith("__")
|
if not k.startswith("__")
|
||||||
|
and not k.startswith("_abc")
|
||||||
and not isinstance(
|
and not isinstance(
|
||||||
v,
|
v,
|
||||||
(
|
(
|
||||||
|
@ -112,6 +169,62 @@ class AmazonTitanConfig:
|
||||||
and v is not None
|
and v is not None
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def _map_and_modify_arg(
|
||||||
|
self,
|
||||||
|
supported_params: dict,
|
||||||
|
provider: str,
|
||||||
|
model: str,
|
||||||
|
stop: Union[List[str], str],
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
filter params to fit the required provider format, drop those that don't fit if user sets `litellm.drop_params = True`.
|
||||||
|
"""
|
||||||
|
filtered_stop = None
|
||||||
|
if "stop" in supported_params and litellm.drop_params:
|
||||||
|
if provider == "bedrock" and "amazon" in model:
|
||||||
|
filtered_stop = []
|
||||||
|
if isinstance(stop, list):
|
||||||
|
for s in stop:
|
||||||
|
if re.match(r"^(\|+|User:)$", s):
|
||||||
|
filtered_stop.append(s)
|
||||||
|
if filtered_stop is not None:
|
||||||
|
supported_params["stop"] = filtered_stop
|
||||||
|
|
||||||
|
return supported_params
|
||||||
|
|
||||||
|
def get_supported_openai_params(self, model: str) -> List[str]:
|
||||||
|
return [
|
||||||
|
"max_tokens",
|
||||||
|
"max_completion_tokens",
|
||||||
|
"stop",
|
||||||
|
"temperature",
|
||||||
|
"top_p",
|
||||||
|
"stream",
|
||||||
|
]
|
||||||
|
|
||||||
|
def map_openai_params(
|
||||||
|
self,
|
||||||
|
non_default_params: dict,
|
||||||
|
optional_params: dict,
|
||||||
|
model: str,
|
||||||
|
drop_params: bool,
|
||||||
|
) -> dict:
|
||||||
|
for k, v in non_default_params.items():
|
||||||
|
if k == "max_tokens" or k == "max_completion_tokens":
|
||||||
|
optional_params["maxTokenCount"] = v
|
||||||
|
if k == "temperature":
|
||||||
|
optional_params["temperature"] = v
|
||||||
|
if k == "stop":
|
||||||
|
filtered_stop = self._map_and_modify_arg(
|
||||||
|
{"stop": v}, provider="bedrock", model=model, stop=v
|
||||||
|
)
|
||||||
|
optional_params["stopSequences"] = filtered_stop["stop"]
|
||||||
|
if k == "top_p":
|
||||||
|
optional_params["topP"] = v
|
||||||
|
if k == "stream":
|
||||||
|
optional_params["stream"] = v
|
||||||
|
return optional_params
|
||||||
|
|
||||||
|
|
||||||
class AmazonAnthropicClaude3Config:
|
class AmazonAnthropicClaude3Config:
|
||||||
"""
|
"""
|
||||||
|
@ -276,7 +389,7 @@ class AmazonAnthropicConfig:
|
||||||
return optional_params
|
return optional_params
|
||||||
|
|
||||||
|
|
||||||
class AmazonCohereConfig:
|
class AmazonCohereConfig(AmazonInvokeMixin, BaseConfig):
|
||||||
"""
|
"""
|
||||||
Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=command
|
Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=command
|
||||||
|
|
||||||
|
@ -308,6 +421,7 @@ class AmazonCohereConfig:
|
||||||
k: v
|
k: v
|
||||||
for k, v in cls.__dict__.items()
|
for k, v in cls.__dict__.items()
|
||||||
if not k.startswith("__")
|
if not k.startswith("__")
|
||||||
|
and not k.startswith("_abc")
|
||||||
and not isinstance(
|
and not isinstance(
|
||||||
v,
|
v,
|
||||||
(
|
(
|
||||||
|
@ -320,8 +434,31 @@ class AmazonCohereConfig:
|
||||||
and v is not None
|
and v is not None
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def get_supported_openai_params(self, model: str) -> List[str]:
|
||||||
|
return [
|
||||||
|
"max_tokens",
|
||||||
|
"temperature",
|
||||||
|
"stream",
|
||||||
|
]
|
||||||
|
|
||||||
class AmazonAI21Config:
|
def map_openai_params(
|
||||||
|
self,
|
||||||
|
non_default_params: dict,
|
||||||
|
optional_params: dict,
|
||||||
|
model: str,
|
||||||
|
drop_params: bool,
|
||||||
|
) -> dict:
|
||||||
|
for k, v in non_default_params.items():
|
||||||
|
if k == "stream":
|
||||||
|
optional_params["stream"] = v
|
||||||
|
if k == "temperature":
|
||||||
|
optional_params["temperature"] = v
|
||||||
|
if k == "max_tokens":
|
||||||
|
optional_params["max_tokens"] = v
|
||||||
|
return optional_params
|
||||||
|
|
||||||
|
|
||||||
|
class AmazonAI21Config(AmazonInvokeMixin, BaseConfig):
|
||||||
"""
|
"""
|
||||||
Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=j2-ultra
|
Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=j2-ultra
|
||||||
|
|
||||||
|
@ -371,6 +508,7 @@ class AmazonAI21Config:
|
||||||
k: v
|
k: v
|
||||||
for k, v in cls.__dict__.items()
|
for k, v in cls.__dict__.items()
|
||||||
if not k.startswith("__")
|
if not k.startswith("__")
|
||||||
|
and not k.startswith("_abc")
|
||||||
and not isinstance(
|
and not isinstance(
|
||||||
v,
|
v,
|
||||||
(
|
(
|
||||||
|
@ -383,13 +521,39 @@ class AmazonAI21Config:
|
||||||
and v is not None
|
and v is not None
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def get_supported_openai_params(self, model: str) -> List:
|
||||||
|
return [
|
||||||
|
"max_tokens",
|
||||||
|
"temperature",
|
||||||
|
"top_p",
|
||||||
|
"stream",
|
||||||
|
]
|
||||||
|
|
||||||
|
def map_openai_params(
|
||||||
|
self,
|
||||||
|
non_default_params: dict,
|
||||||
|
optional_params: dict,
|
||||||
|
model: str,
|
||||||
|
drop_params: bool,
|
||||||
|
) -> dict:
|
||||||
|
for k, v in non_default_params.items():
|
||||||
|
if k == "max_tokens":
|
||||||
|
optional_params["maxTokens"] = v
|
||||||
|
if k == "temperature":
|
||||||
|
optional_params["temperature"] = v
|
||||||
|
if k == "top_p":
|
||||||
|
optional_params["topP"] = v
|
||||||
|
if k == "stream":
|
||||||
|
optional_params["stream"] = v
|
||||||
|
return optional_params
|
||||||
|
|
||||||
|
|
||||||
class AnthropicConstants(Enum):
|
class AnthropicConstants(Enum):
|
||||||
HUMAN_PROMPT = "\n\nHuman: "
|
HUMAN_PROMPT = "\n\nHuman: "
|
||||||
AI_PROMPT = "\n\nAssistant: "
|
AI_PROMPT = "\n\nAssistant: "
|
||||||
|
|
||||||
|
|
||||||
class AmazonLlamaConfig:
|
class AmazonLlamaConfig(AmazonInvokeMixin, BaseConfig):
|
||||||
"""
|
"""
|
||||||
Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=meta.llama2-13b-chat-v1
|
Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=meta.llama2-13b-chat-v1
|
||||||
|
|
||||||
|
@ -421,6 +585,7 @@ class AmazonLlamaConfig:
|
||||||
k: v
|
k: v
|
||||||
for k, v in cls.__dict__.items()
|
for k, v in cls.__dict__.items()
|
||||||
if not k.startswith("__")
|
if not k.startswith("__")
|
||||||
|
and not k.startswith("_abc")
|
||||||
and not isinstance(
|
and not isinstance(
|
||||||
v,
|
v,
|
||||||
(
|
(
|
||||||
|
@ -433,8 +598,34 @@ class AmazonLlamaConfig:
|
||||||
and v is not None
|
and v is not None
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def get_supported_openai_params(self, model: str) -> List:
|
||||||
|
return [
|
||||||
|
"max_tokens",
|
||||||
|
"temperature",
|
||||||
|
"top_p",
|
||||||
|
"stream",
|
||||||
|
]
|
||||||
|
|
||||||
class AmazonMistralConfig:
|
def map_openai_params(
|
||||||
|
self,
|
||||||
|
non_default_params: dict,
|
||||||
|
optional_params: dict,
|
||||||
|
model: str,
|
||||||
|
drop_params: bool,
|
||||||
|
) -> dict:
|
||||||
|
for k, v in non_default_params.items():
|
||||||
|
if k == "max_tokens":
|
||||||
|
optional_params["max_gen_len"] = v
|
||||||
|
if k == "temperature":
|
||||||
|
optional_params["temperature"] = v
|
||||||
|
if k == "top_p":
|
||||||
|
optional_params["top_p"] = v
|
||||||
|
if k == "stream":
|
||||||
|
optional_params["stream"] = v
|
||||||
|
return optional_params
|
||||||
|
|
||||||
|
|
||||||
|
class AmazonMistralConfig(AmazonInvokeMixin, BaseConfig):
|
||||||
"""
|
"""
|
||||||
Reference: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-mistral.html
|
Reference: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-mistral.html
|
||||||
Supported Params for the Amazon / Mistral models:
|
Supported Params for the Amazon / Mistral models:
|
||||||
|
@ -471,6 +662,7 @@ class AmazonMistralConfig:
|
||||||
k: v
|
k: v
|
||||||
for k, v in cls.__dict__.items()
|
for k, v in cls.__dict__.items()
|
||||||
if not k.startswith("__")
|
if not k.startswith("__")
|
||||||
|
and not k.startswith("_abc")
|
||||||
and not isinstance(
|
and not isinstance(
|
||||||
v,
|
v,
|
||||||
(
|
(
|
||||||
|
@ -483,6 +675,29 @@ class AmazonMistralConfig:
|
||||||
and v is not None
|
and v is not None
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def get_supported_openai_params(self, model: str) -> List[str]:
|
||||||
|
return ["max_tokens", "temperature", "top_p", "stop", "stream"]
|
||||||
|
|
||||||
|
def map_openai_params(
|
||||||
|
self,
|
||||||
|
non_default_params: dict,
|
||||||
|
optional_params: dict,
|
||||||
|
model: str,
|
||||||
|
drop_params: bool,
|
||||||
|
) -> dict:
|
||||||
|
for k, v in non_default_params.items():
|
||||||
|
if k == "max_tokens":
|
||||||
|
optional_params["max_tokens"] = v
|
||||||
|
if k == "temperature":
|
||||||
|
optional_params["temperature"] = v
|
||||||
|
if k == "top_p":
|
||||||
|
optional_params["top_p"] = v
|
||||||
|
if k == "stop":
|
||||||
|
optional_params["stop"] = v
|
||||||
|
if k == "stream":
|
||||||
|
optional_params["stream"] = v
|
||||||
|
return optional_params
|
||||||
|
|
||||||
|
|
||||||
def add_custom_header(headers):
|
def add_custom_header(headers):
|
||||||
"""Closure to capture the headers and add them."""
|
"""Closure to capture the headers and add them."""
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
from typing import Dict, List, Optional
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
|
import litellm
|
||||||
from litellm.litellm_core_utils.prompt_templates.factory import (
|
from litellm.litellm_core_utils.prompt_templates.factory import (
|
||||||
convert_generic_image_chunk_to_openai_image_obj,
|
convert_generic_image_chunk_to_openai_image_obj,
|
||||||
convert_to_anthropic_image_obj,
|
convert_to_anthropic_image_obj,
|
||||||
|
@ -96,6 +97,8 @@ class GoogleAIStudioGeminiConfig(
|
||||||
del non_default_params["frequency_penalty"]
|
del non_default_params["frequency_penalty"]
|
||||||
if "presence_penalty" in non_default_params:
|
if "presence_penalty" in non_default_params:
|
||||||
del non_default_params["presence_penalty"]
|
del non_default_params["presence_penalty"]
|
||||||
|
if litellm.vertex_ai_safety_settings is not None:
|
||||||
|
optional_params["safety_settings"] = litellm.vertex_ai_safety_settings
|
||||||
return super().map_openai_params(
|
return super().map_openai_params(
|
||||||
model=model,
|
model=model,
|
||||||
non_default_params=non_default_params,
|
non_default_params=non_default_params,
|
||||||
|
|
|
@ -380,6 +380,8 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig):
|
||||||
if param == "seed":
|
if param == "seed":
|
||||||
optional_params["seed"] = value
|
optional_params["seed"] = value
|
||||||
|
|
||||||
|
if litellm.vertex_ai_safety_settings is not None:
|
||||||
|
optional_params["safety_settings"] = litellm.vertex_ai_safety_settings
|
||||||
return optional_params
|
return optional_params
|
||||||
|
|
||||||
def get_mapped_special_auth_params(self) -> dict:
|
def get_mapped_special_auth_params(self) -> dict:
|
||||||
|
|
224
litellm/utils.py
224
litellm/utils.py
|
@ -2773,23 +2773,13 @@ def get_optional_params( # noqa: PLR0915
|
||||||
message=f"{custom_llm_provider} does not support parameters: {unsupported_params}, for model={model}. To drop these, set `litellm.drop_params=True` or for proxy:\n\n`litellm_settings:\n drop_params: true`\n",
|
message=f"{custom_llm_provider} does not support parameters: {unsupported_params}, for model={model}. To drop these, set `litellm.drop_params=True` or for proxy:\n\n`litellm_settings:\n drop_params: true`\n",
|
||||||
)
|
)
|
||||||
|
|
||||||
def _map_and_modify_arg(supported_params: dict, provider: str, model: str):
|
provider_config: Optional[BaseConfig] = None
|
||||||
"""
|
if custom_llm_provider is not None and custom_llm_provider in [
|
||||||
filter params to fit the required provider format, drop those that don't fit if user sets `litellm.drop_params = True`.
|
provider.value for provider in LlmProviders
|
||||||
"""
|
]:
|
||||||
filtered_stop = None
|
provider_config = ProviderConfigManager.get_provider_chat_config(
|
||||||
if "stop" in supported_params and litellm.drop_params:
|
model=model, provider=LlmProviders(custom_llm_provider)
|
||||||
if provider == "bedrock" and "amazon" in model:
|
)
|
||||||
filtered_stop = []
|
|
||||||
if isinstance(stop, list):
|
|
||||||
for s in stop:
|
|
||||||
if re.match(r"^(\|+|User:)$", s):
|
|
||||||
filtered_stop.append(s)
|
|
||||||
if filtered_stop is not None:
|
|
||||||
supported_params["stop"] = filtered_stop
|
|
||||||
|
|
||||||
return supported_params
|
|
||||||
|
|
||||||
## raise exception if provider doesn't support passed in param
|
## raise exception if provider doesn't support passed in param
|
||||||
if custom_llm_provider == "anthropic":
|
if custom_llm_provider == "anthropic":
|
||||||
## check if unsupported param passed in
|
## check if unsupported param passed in
|
||||||
|
@ -2885,21 +2875,16 @@ def get_optional_params( # noqa: PLR0915
|
||||||
model=model, custom_llm_provider=custom_llm_provider
|
model=model, custom_llm_provider=custom_llm_provider
|
||||||
)
|
)
|
||||||
_check_valid_arg(supported_params=supported_params)
|
_check_valid_arg(supported_params=supported_params)
|
||||||
# handle cohere params
|
optional_params = litellm.MaritalkConfig().map_openai_params(
|
||||||
if stream:
|
non_default_params=non_default_params,
|
||||||
optional_params["stream"] = stream
|
optional_params=optional_params,
|
||||||
if temperature is not None:
|
model=model,
|
||||||
optional_params["temperature"] = temperature
|
drop_params=(
|
||||||
if max_tokens is not None:
|
drop_params
|
||||||
optional_params["max_tokens"] = max_tokens
|
if drop_params is not None and isinstance(drop_params, bool)
|
||||||
if logit_bias is not None:
|
else False
|
||||||
optional_params["logit_bias"] = logit_bias
|
),
|
||||||
if top_p is not None:
|
)
|
||||||
optional_params["p"] = top_p
|
|
||||||
if presence_penalty is not None:
|
|
||||||
optional_params["repetition_penalty"] = presence_penalty
|
|
||||||
if stop is not None:
|
|
||||||
optional_params["stopping_tokens"] = stop
|
|
||||||
elif custom_llm_provider == "replicate":
|
elif custom_llm_provider == "replicate":
|
||||||
## check if unsupported param passed in
|
## check if unsupported param passed in
|
||||||
supported_params = get_supported_openai_params(
|
supported_params = get_supported_openai_params(
|
||||||
|
@ -2990,8 +2975,6 @@ def get_optional_params( # noqa: PLR0915
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
if litellm.vertex_ai_safety_settings is not None:
|
|
||||||
optional_params["safety_settings"] = litellm.vertex_ai_safety_settings
|
|
||||||
elif custom_llm_provider == "gemini":
|
elif custom_llm_provider == "gemini":
|
||||||
supported_params = get_supported_openai_params(
|
supported_params = get_supported_openai_params(
|
||||||
model=model, custom_llm_provider=custom_llm_provider
|
model=model, custom_llm_provider=custom_llm_provider
|
||||||
|
@ -3024,8 +3007,6 @@ def get_optional_params( # noqa: PLR0915
|
||||||
else False
|
else False
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
if litellm.vertex_ai_safety_settings is not None:
|
|
||||||
optional_params["safety_settings"] = litellm.vertex_ai_safety_settings
|
|
||||||
elif litellm.VertexAIAnthropicConfig.is_supported_model(
|
elif litellm.VertexAIAnthropicConfig.is_supported_model(
|
||||||
model=model, custom_llm_provider=custom_llm_provider
|
model=model, custom_llm_provider=custom_llm_provider
|
||||||
):
|
):
|
||||||
|
@ -3135,18 +3116,7 @@ def get_optional_params( # noqa: PLR0915
|
||||||
),
|
),
|
||||||
messages=messages,
|
messages=messages,
|
||||||
)
|
)
|
||||||
elif "ai21" in model:
|
|
||||||
_check_valid_arg(supported_params=supported_params)
|
|
||||||
# params "maxTokens":200,"temperature":0,"topP":250,"stop_sequences":[],
|
|
||||||
# https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=j2-ultra
|
|
||||||
if max_tokens is not None:
|
|
||||||
optional_params["maxTokens"] = max_tokens
|
|
||||||
if temperature is not None:
|
|
||||||
optional_params["temperature"] = temperature
|
|
||||||
if top_p is not None:
|
|
||||||
optional_params["topP"] = top_p
|
|
||||||
if stream:
|
|
||||||
optional_params["stream"] = stream
|
|
||||||
elif "anthropic" in model:
|
elif "anthropic" in model:
|
||||||
_check_valid_arg(supported_params=supported_params)
|
_check_valid_arg(supported_params=supported_params)
|
||||||
if "aws_bedrock_client" in passed_params: # deprecated boto3.invoke route.
|
if "aws_bedrock_client" in passed_params: # deprecated boto3.invoke route.
|
||||||
|
@ -3162,84 +3132,18 @@ def get_optional_params( # noqa: PLR0915
|
||||||
non_default_params=non_default_params,
|
non_default_params=non_default_params,
|
||||||
optional_params=optional_params,
|
optional_params=optional_params,
|
||||||
)
|
)
|
||||||
elif "amazon" in model: # amazon titan llms
|
elif provider_config is not None:
|
||||||
_check_valid_arg(supported_params=supported_params)
|
_check_valid_arg(supported_params=supported_params)
|
||||||
# see https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=titan-large
|
optional_params = provider_config.map_openai_params(
|
||||||
if max_tokens is not None:
|
non_default_params=non_default_params,
|
||||||
optional_params["maxTokenCount"] = max_tokens
|
optional_params=optional_params,
|
||||||
if temperature is not None:
|
model=model,
|
||||||
optional_params["temperature"] = temperature
|
drop_params=(
|
||||||
if stop is not None:
|
drop_params
|
||||||
filtered_stop = _map_and_modify_arg(
|
if drop_params is not None and isinstance(drop_params, bool)
|
||||||
{"stop": stop}, provider="bedrock", model=model
|
else False
|
||||||
)
|
),
|
||||||
optional_params["stopSequences"] = filtered_stop["stop"]
|
)
|
||||||
if top_p is not None:
|
|
||||||
optional_params["topP"] = top_p
|
|
||||||
if stream:
|
|
||||||
optional_params["stream"] = stream
|
|
||||||
elif "meta" in model: # amazon / meta llms
|
|
||||||
_check_valid_arg(supported_params=supported_params)
|
|
||||||
# see https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=titan-large
|
|
||||||
if max_tokens is not None:
|
|
||||||
optional_params["max_gen_len"] = max_tokens
|
|
||||||
if temperature is not None:
|
|
||||||
optional_params["temperature"] = temperature
|
|
||||||
if top_p is not None:
|
|
||||||
optional_params["top_p"] = top_p
|
|
||||||
if stream:
|
|
||||||
optional_params["stream"] = stream
|
|
||||||
elif "cohere" in model: # cohere models on bedrock
|
|
||||||
_check_valid_arg(supported_params=supported_params)
|
|
||||||
# handle cohere params
|
|
||||||
if stream:
|
|
||||||
optional_params["stream"] = stream
|
|
||||||
if temperature is not None:
|
|
||||||
optional_params["temperature"] = temperature
|
|
||||||
if max_tokens is not None:
|
|
||||||
optional_params["max_tokens"] = max_tokens
|
|
||||||
elif "mistral" in model:
|
|
||||||
_check_valid_arg(supported_params=supported_params)
|
|
||||||
# mistral params on bedrock
|
|
||||||
# \"max_tokens\":400,\"temperature\":0.7,\"top_p\":0.7,\"stop\":[\"\\\\n\\\\nHuman:\"]}"
|
|
||||||
if max_tokens is not None:
|
|
||||||
optional_params["max_tokens"] = max_tokens
|
|
||||||
if temperature is not None:
|
|
||||||
optional_params["temperature"] = temperature
|
|
||||||
if top_p is not None:
|
|
||||||
optional_params["top_p"] = top_p
|
|
||||||
if stop is not None:
|
|
||||||
optional_params["stop"] = stop
|
|
||||||
if stream is not None:
|
|
||||||
optional_params["stream"] = stream
|
|
||||||
elif custom_llm_provider == "aleph_alpha":
|
|
||||||
supported_params = [
|
|
||||||
"max_tokens",
|
|
||||||
"stream",
|
|
||||||
"top_p",
|
|
||||||
"temperature",
|
|
||||||
"presence_penalty",
|
|
||||||
"frequency_penalty",
|
|
||||||
"n",
|
|
||||||
"stop",
|
|
||||||
]
|
|
||||||
_check_valid_arg(supported_params=supported_params)
|
|
||||||
if max_tokens is not None:
|
|
||||||
optional_params["maximum_tokens"] = max_tokens
|
|
||||||
if stream:
|
|
||||||
optional_params["stream"] = stream
|
|
||||||
if temperature is not None:
|
|
||||||
optional_params["temperature"] = temperature
|
|
||||||
if top_p is not None:
|
|
||||||
optional_params["top_p"] = top_p
|
|
||||||
if presence_penalty is not None:
|
|
||||||
optional_params["presence_penalty"] = presence_penalty
|
|
||||||
if frequency_penalty is not None:
|
|
||||||
optional_params["frequency_penalty"] = frequency_penalty
|
|
||||||
if n is not None:
|
|
||||||
optional_params["n"] = n
|
|
||||||
if stop is not None:
|
|
||||||
optional_params["stop_sequences"] = stop
|
|
||||||
elif custom_llm_provider == "cloudflare":
|
elif custom_llm_provider == "cloudflare":
|
||||||
# https://developers.cloudflare.com/workers-ai/models/text-generation/#input
|
# https://developers.cloudflare.com/workers-ai/models/text-generation/#input
|
||||||
supported_params = get_supported_openai_params(
|
supported_params = get_supported_openai_params(
|
||||||
|
@ -3336,57 +3240,21 @@ def get_optional_params( # noqa: PLR0915
|
||||||
else False
|
else False
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
elif custom_llm_provider == "perplexity":
|
elif custom_llm_provider == "perplexity" and provider_config is not None:
|
||||||
supported_params = get_supported_openai_params(
|
supported_params = get_supported_openai_params(
|
||||||
model=model, custom_llm_provider=custom_llm_provider
|
model=model, custom_llm_provider=custom_llm_provider
|
||||||
)
|
)
|
||||||
_check_valid_arg(supported_params=supported_params)
|
_check_valid_arg(supported_params=supported_params)
|
||||||
if temperature is not None:
|
optional_params = provider_config.map_openai_params(
|
||||||
if (
|
non_default_params=non_default_params,
|
||||||
temperature == 0 and model == "mistral-7b-instruct"
|
optional_params=optional_params,
|
||||||
): # this model does no support temperature == 0
|
model=model,
|
||||||
temperature = 0.0001 # close to 0
|
drop_params=(
|
||||||
optional_params["temperature"] = temperature
|
drop_params
|
||||||
if top_p:
|
if drop_params is not None and isinstance(drop_params, bool)
|
||||||
optional_params["top_p"] = top_p
|
else False
|
||||||
if stream:
|
),
|
||||||
optional_params["stream"] = stream
|
|
||||||
if max_tokens:
|
|
||||||
optional_params["max_tokens"] = max_tokens
|
|
||||||
if presence_penalty:
|
|
||||||
optional_params["presence_penalty"] = presence_penalty
|
|
||||||
if frequency_penalty:
|
|
||||||
optional_params["frequency_penalty"] = frequency_penalty
|
|
||||||
elif custom_llm_provider == "anyscale":
|
|
||||||
supported_params = get_supported_openai_params(
|
|
||||||
model=model, custom_llm_provider=custom_llm_provider
|
|
||||||
)
|
)
|
||||||
if model in [
|
|
||||||
"mistralai/Mistral-7B-Instruct-v0.1",
|
|
||||||
"mistralai/Mixtral-8x7B-Instruct-v0.1",
|
|
||||||
]:
|
|
||||||
supported_params += [ # type: ignore
|
|
||||||
"functions",
|
|
||||||
"function_call",
|
|
||||||
"tools",
|
|
||||||
"tool_choice",
|
|
||||||
"response_format",
|
|
||||||
]
|
|
||||||
_check_valid_arg(supported_params=supported_params)
|
|
||||||
optional_params = non_default_params
|
|
||||||
if temperature is not None:
|
|
||||||
if temperature == 0 and model in [
|
|
||||||
"mistralai/Mistral-7B-Instruct-v0.1",
|
|
||||||
"mistralai/Mixtral-8x7B-Instruct-v0.1",
|
|
||||||
]: # this model does no support temperature == 0
|
|
||||||
temperature = 0.0001 # close to 0
|
|
||||||
optional_params["temperature"] = temperature
|
|
||||||
if top_p:
|
|
||||||
optional_params["top_p"] = top_p
|
|
||||||
if stream:
|
|
||||||
optional_params["stream"] = stream
|
|
||||||
if max_tokens:
|
|
||||||
optional_params["max_tokens"] = max_tokens
|
|
||||||
elif custom_llm_provider == "mistral" or custom_llm_provider == "codestral":
|
elif custom_llm_provider == "mistral" or custom_llm_provider == "codestral":
|
||||||
supported_params = get_supported_openai_params(
|
supported_params = get_supported_openai_params(
|
||||||
model=model, custom_llm_provider=custom_llm_provider
|
model=model, custom_llm_provider=custom_llm_provider
|
||||||
|
@ -6302,6 +6170,20 @@ class ProviderConfigManager:
|
||||||
return litellm.TritonConfig()
|
return litellm.TritonConfig()
|
||||||
elif litellm.LlmProviders.PETALS == provider:
|
elif litellm.LlmProviders.PETALS == provider:
|
||||||
return litellm.PetalsConfig()
|
return litellm.PetalsConfig()
|
||||||
|
elif litellm.LlmProviders.BEDROCK == provider:
|
||||||
|
base_model = litellm.AmazonConverseConfig()._get_base_model(model)
|
||||||
|
if base_model in litellm.bedrock_converse_models:
|
||||||
|
pass
|
||||||
|
elif "amazon" in model: # amazon titan llms
|
||||||
|
return litellm.AmazonTitanConfig()
|
||||||
|
elif "meta" in model: # amazon / meta llms
|
||||||
|
return litellm.AmazonLlamaConfig()
|
||||||
|
elif "ai21" in model: # ai21 llms
|
||||||
|
return litellm.AmazonAI21Config()
|
||||||
|
elif "cohere" in model: # cohere models on bedrock
|
||||||
|
return litellm.AmazonCohereConfig()
|
||||||
|
elif "mistral" in model: # mistral models on bedrock
|
||||||
|
return litellm.AmazonMistralConfig()
|
||||||
return litellm.OpenAIGPTConfig()
|
return litellm.OpenAIGPTConfig()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|
151
tests/documentation_tests/test_optional_params.py
Normal file
151
tests/documentation_tests/test_optional_params.py
Normal file
|
@ -0,0 +1,151 @@
|
||||||
|
import ast
|
||||||
|
from typing import List, Set, Dict, Optional
|
||||||
|
import sys
|
||||||
|
|
||||||
|
|
||||||
|
class ConfigChecker(ast.NodeVisitor):
|
||||||
|
def __init__(self):
|
||||||
|
self.errors: List[str] = []
|
||||||
|
self.current_provider_block: Optional[str] = None
|
||||||
|
self.param_assignments: Dict[str, Set[str]] = {}
|
||||||
|
self.map_openai_calls: Set[str] = set()
|
||||||
|
self.class_inheritance: Dict[str, List[str]] = {}
|
||||||
|
|
||||||
|
def get_full_name(self, node):
|
||||||
|
"""Recursively extract the full name from a node."""
|
||||||
|
if isinstance(node, ast.Name):
|
||||||
|
return node.id
|
||||||
|
elif isinstance(node, ast.Attribute):
|
||||||
|
base = self.get_full_name(node.value)
|
||||||
|
if base:
|
||||||
|
return f"{base}.{node.attr}"
|
||||||
|
return None
|
||||||
|
|
||||||
|
def visit_ClassDef(self, node: ast.ClassDef):
|
||||||
|
# Record class inheritance
|
||||||
|
bases = [base.id for base in node.bases if isinstance(base, ast.Name)]
|
||||||
|
print(f"Found class {node.name} with bases {bases}")
|
||||||
|
self.class_inheritance[node.name] = bases
|
||||||
|
self.generic_visit(node)
|
||||||
|
|
||||||
|
def visit_Call(self, node: ast.Call):
|
||||||
|
# Check for map_openai_params calls
|
||||||
|
if (
|
||||||
|
isinstance(node.func, ast.Attribute)
|
||||||
|
and node.func.attr == "map_openai_params"
|
||||||
|
):
|
||||||
|
if isinstance(node.func.value, ast.Name):
|
||||||
|
config_name = node.func.value.id
|
||||||
|
self.map_openai_calls.add(config_name)
|
||||||
|
self.generic_visit(node)
|
||||||
|
|
||||||
|
def visit_If(self, node: ast.If):
|
||||||
|
# Detect custom_llm_provider blocks
|
||||||
|
provider = self._extract_provider_from_if(node)
|
||||||
|
if provider:
|
||||||
|
old_provider = self.current_provider_block
|
||||||
|
self.current_provider_block = provider
|
||||||
|
self.generic_visit(node)
|
||||||
|
self.current_provider_block = old_provider
|
||||||
|
else:
|
||||||
|
self.generic_visit(node)
|
||||||
|
|
||||||
|
def visit_Assign(self, node: ast.Assign):
|
||||||
|
# Track assignments to optional_params
|
||||||
|
if self.current_provider_block and len(node.targets) == 1:
|
||||||
|
target = node.targets[0]
|
||||||
|
if isinstance(target, ast.Subscript) and isinstance(target.value, ast.Name):
|
||||||
|
if target.value.id == "optional_params":
|
||||||
|
if isinstance(target.slice, ast.Constant):
|
||||||
|
key = target.slice.value
|
||||||
|
if self.current_provider_block not in self.param_assignments:
|
||||||
|
self.param_assignments[self.current_provider_block] = set()
|
||||||
|
self.param_assignments[self.current_provider_block].add(key)
|
||||||
|
self.generic_visit(node)
|
||||||
|
|
||||||
|
def _extract_provider_from_if(self, node: ast.If) -> Optional[str]:
|
||||||
|
"""Extract the provider name from an if condition checking custom_llm_provider"""
|
||||||
|
if isinstance(node.test, ast.Compare):
|
||||||
|
if len(node.test.ops) == 1 and isinstance(node.test.ops[0], ast.Eq):
|
||||||
|
if (
|
||||||
|
isinstance(node.test.left, ast.Name)
|
||||||
|
and node.test.left.id == "custom_llm_provider"
|
||||||
|
):
|
||||||
|
if isinstance(node.test.comparators[0], ast.Constant):
|
||||||
|
return node.test.comparators[0].value
|
||||||
|
return None
|
||||||
|
|
||||||
|
def check_patterns(self) -> List[str]:
|
||||||
|
# Check if all configs using map_openai_params inherit from BaseConfig
|
||||||
|
for config_name in self.map_openai_calls:
|
||||||
|
print(f"Checking config: {config_name}")
|
||||||
|
if (
|
||||||
|
config_name not in self.class_inheritance
|
||||||
|
or "BaseConfig" not in self.class_inheritance[config_name]
|
||||||
|
):
|
||||||
|
# Retrieve the associated class name, if any
|
||||||
|
class_name = next(
|
||||||
|
(
|
||||||
|
cls
|
||||||
|
for cls, bases in self.class_inheritance.items()
|
||||||
|
if config_name in bases
|
||||||
|
),
|
||||||
|
"Unknown Class",
|
||||||
|
)
|
||||||
|
self.errors.append(
|
||||||
|
f"Error: {config_name} calls map_openai_params but doesn't inherit from BaseConfig. "
|
||||||
|
f"It is used in the class: {class_name}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check for parameter assignments in provider blocks
|
||||||
|
for provider, params in self.param_assignments.items():
|
||||||
|
# You can customize which parameters should raise warnings for each provider
|
||||||
|
for param in params:
|
||||||
|
if param not in self._get_allowed_params(provider):
|
||||||
|
self.errors.append(
|
||||||
|
f"Warning: Parameter '{param}' is directly assigned in {provider} block. "
|
||||||
|
f"Consider using a config class instead."
|
||||||
|
)
|
||||||
|
|
||||||
|
return self.errors
|
||||||
|
|
||||||
|
def _get_allowed_params(self, provider: str) -> Set[str]:
|
||||||
|
"""Define allowed direct parameter assignments for each provider"""
|
||||||
|
# You can customize this based on your requirements
|
||||||
|
common_allowed = {"stream", "api_key", "api_base"}
|
||||||
|
provider_specific = {
|
||||||
|
"anthropic": {"api_version"},
|
||||||
|
"openai": {"organization"},
|
||||||
|
# Add more providers and their allowed params here
|
||||||
|
}
|
||||||
|
return common_allowed.union(provider_specific.get(provider, set()))
|
||||||
|
|
||||||
|
|
||||||
|
def check_file(file_path: str) -> List[str]:
|
||||||
|
with open(file_path, "r") as file:
|
||||||
|
tree = ast.parse(file.read())
|
||||||
|
|
||||||
|
checker = ConfigChecker()
|
||||||
|
for node in tree.body:
|
||||||
|
if isinstance(node, ast.FunctionDef) and node.name == "get_optional_params":
|
||||||
|
checker.visit(node)
|
||||||
|
break # No need to visit other functions
|
||||||
|
return checker.check_patterns()
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
file_path = "../../litellm/utils.py"
|
||||||
|
errors = check_file(file_path)
|
||||||
|
|
||||||
|
if errors:
|
||||||
|
print("\nFound the following issues:")
|
||||||
|
for error in errors:
|
||||||
|
print(f"- {error}")
|
||||||
|
sys.exit(1)
|
||||||
|
else:
|
||||||
|
print("No issues found!")
|
||||||
|
sys.exit(0)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
|
@ -121,6 +121,26 @@ def test_bedrock_optional_params_completions(model):
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"model",
|
||||||
|
[
|
||||||
|
"bedrock/amazon.titan-large",
|
||||||
|
"bedrock/meta.llama3-2-11b-instruct-v1:0",
|
||||||
|
"bedrock/ai21.j2-ultra-v1",
|
||||||
|
"bedrock/cohere.command-nightly",
|
||||||
|
"bedrock/mistral.mistral-7b",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_bedrock_optional_params_simple(model):
|
||||||
|
litellm.drop_params = True
|
||||||
|
get_optional_params(
|
||||||
|
model=model,
|
||||||
|
max_tokens=10,
|
||||||
|
temperature=0.1,
|
||||||
|
custom_llm_provider="bedrock",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"model, expected_dimensions, dimensions_kwarg",
|
"model, expected_dimensions, dimensions_kwarg",
|
||||||
[
|
[
|
||||||
|
|
|
@ -42,7 +42,7 @@ def test_otel_logging_async():
|
||||||
print(f"Average performance difference: {avg_percent_diff:.2f}%")
|
print(f"Average performance difference: {avg_percent_diff:.2f}%")
|
||||||
|
|
||||||
assert (
|
assert (
|
||||||
avg_percent_diff < 15
|
avg_percent_diff < 20
|
||||||
), f"Average performance difference of {avg_percent_diff:.2f}% exceeds 15% threshold"
|
), f"Average performance difference of {avg_percent_diff:.2f}% exceeds 15% threshold"
|
||||||
|
|
||||||
except litellm.Timeout as e:
|
except litellm.Timeout as e:
|
||||||
|
|
|
@ -1385,13 +1385,13 @@ async def test_completion_replicate_llama3_streaming(sync_mode):
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"model, region",
|
"model, region",
|
||||||
[
|
[
|
||||||
["bedrock/ai21.jamba-instruct-v1:0", "us-east-1"],
|
# ["bedrock/ai21.jamba-instruct-v1:0", "us-east-1"],
|
||||||
["bedrock/cohere.command-r-plus-v1:0", None],
|
# ["bedrock/cohere.command-r-plus-v1:0", None],
|
||||||
["anthropic.claude-3-sonnet-20240229-v1:0", None],
|
# ["anthropic.claude-3-sonnet-20240229-v1:0", None],
|
||||||
["anthropic.claude-instant-v1", None],
|
# ["anthropic.claude-instant-v1", None],
|
||||||
["mistral.mistral-7b-instruct-v0:2", None],
|
# ["mistral.mistral-7b-instruct-v0:2", None],
|
||||||
["bedrock/amazon.titan-tg1-large", None],
|
["bedrock/amazon.titan-tg1-large", None],
|
||||||
["meta.llama3-8b-instruct-v1:0", None],
|
# ["meta.llama3-8b-instruct-v1:0", None],
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue