Litellm dev 12 28 2024 p2 (#7458)

* docs(sidebar.js): docs for support model access groups for wildcard routes

* feat(key_management_endpoints.py): add check if user is premium_user when adding model access group for wildcard route

* refactor(docs/): make control model access a root-level doc in proxy sidebar

easier to discover how to control model access on litellm

* docs: more cleanup

* feat(fireworks_ai/): add document inlining support

Enables user to call non-vision models with images/pdfs/etc.

* test(test_fireworks_ai_translation.py): add unit testing for fireworks ai transform inline helper util

* docs(docs/): add document inlining details to fireworks ai docs

* feat(fireworks_ai/): allow user to dynamically disable auto add transform inline

allows client-side disabling of this feature for proxy users

* feat(fireworks_ai/): return 'supports_vision' and 'supports_pdf_input' true on all fireworks ai models

now true as fireworks ai supports document inlining

* test: fix tests

* fix(router.py): add unit testing for _is_model_access_group_for_wildcard_route
Krish Dholakia 2024-12-28 19:38:06 -08:00 committed by GitHub
parent 3eb962c594
commit cfb6890b9f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
19 changed files with 832 additions and 305 deletions

View file

@ -190,6 +190,116 @@ print(response)
</TabItem>
</Tabs>
## Document Inlining
LiteLLM supports document inlining for Fireworks AI models. This is useful for calling models that are not vision models but still need to parse documents, images, etc.
If the model is not a vision model, LiteLLM appends `#transform=inline` to the `image_url`. [**See Code**](https://github.com/BerriAI/litellm/blob/1ae9d45798bdaf8450f2dfdec703369f3d2212b7/litellm/llms/fireworks_ai/chat/transformation.py#L114)
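For illustration, here is roughly what that rewrite does to a single `image_url` content block (a minimal sketch of the behavior, not the library code itself):
```python
# Sketch: how the image_url is rewritten for a non-vision Fireworks AI model.
content = {
    "type": "image_url",
    "image_url": {
        "url": "https://storage.googleapis.com/fireworks-public/test/sample_resume.pdf"
    },
}
model = "accounts/fireworks/models/llama-v3p3-70b-instruct"

if "vision" not in model:  # vision models are left untouched
    content["image_url"]["url"] += "#transform=inline"

print(content["image_url"]["url"])
# https://storage.googleapis.com/fireworks-public/test/sample_resume.pdf#transform=inline
```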
<Tabs>
<TabItem value="sdk" label="SDK">
```python
import os
import litellm
os.environ["FIREWORKS_AI_API_KEY"] = "YOUR_API_KEY"
# os.environ["FIREWORKS_AI_API_BASE"] = "https://api.fireworks.ai/inference/v1"  # optional, this is the default
completion = litellm.completion(
    model="fireworks_ai/accounts/fireworks/models/llama-v3p3-70b-instruct",
    messages=[
        {
            "role": "user",
            "content": [
                {
                    "type": "image_url",
                    "image_url": {
                        "url": "https://storage.googleapis.com/fireworks-public/test/sample_resume.pdf"
                    },
                },
                {
                    "type": "text",
                    "text": "What are the candidate's BA and MBA GPAs?",
                },
            ],
        }
    ],
)
print(completion)
```
</TabItem>
<TabItem value="proxy" label="PROXY">
1. Setup config.yaml
```yaml
model_list:
  - model_name: llama-v3p3-70b-instruct
    litellm_params:
      model: fireworks_ai/accounts/fireworks/models/llama-v3p3-70b-instruct
      api_key: os.environ/FIREWORKS_AI_API_KEY
      # api_base: os.environ/FIREWORKS_AI_API_BASE [OPTIONAL], defaults to "https://api.fireworks.ai/inference/v1"
```
2. Start Proxy
```
litellm --config config.yaml
```
3. Test it
```bash
curl -L -X POST 'http://0.0.0.0:4000/chat/completions' \
-H 'Content-Type: application/json' \
-H 'Authorization: Bearer YOUR_API_KEY' \
-d '{"model": "llama-v3p3-70b-instruct",
"messages": [
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {
"url": "https://storage.googleapis.com/fireworks-public/test/sample_resume.pdf"
},
},
{
"type": "text",
"text": "What are the candidate's BA and MBA GPAs?",
},
],
}
]}'
```
</TabItem>
</Tabs>
### Disable Auto-add
If you want to disable the auto-add of `#transform=inline` to the `image_url`, set `disable_add_transform_inline_image_block` to `true`.
<Tabs>
<TabItem value="sdk" label="SDK">
```python
litellm.disable_add_transform_inline_image_block = True
```
</TabItem>
<TabItem value="proxy" label="PROXY">
```yaml
litellm_settings:
  disable_add_transform_inline_image_block: true
```
</TabItem>
</Tabs>
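You can also pass the flag per request from the SDK; it is forwarded as a litellm param (see the `disable_add_transform_inline_image_block` kwarg in the changes below). A minimal sketch:
```python
import litellm

# Skip the #transform=inline auto-add for just this request
response = litellm.completion(
    model="fireworks_ai/accounts/fireworks/models/llama-v3p3-70b-instruct",
    messages=[{"role": "user", "content": "this is a test request, write a short poem"}],
    disable_add_transform_inline_image_block=True,
)
print(response)
```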
## Supported Models - ALL Fireworks AI Models Supported!
:::info

View file

@ -138,6 +138,7 @@ general_settings:
| disable_end_user_cost_tracking | boolean | If true, turns off end user cost tracking on prometheus metrics + litellm spend logs table on proxy. |
| disable_end_user_cost_tracking_prometheus_only | boolean | If true, turns off end user cost tracking on prometheus metrics only. |
| key_generation_settings | object | Restricts who can generate keys. [Further docs](./virtual_keys.md#restricting-key-generation) |
| disable_add_transform_inline_image_block | boolean | For Fireworks AI models - if true, turns off the auto-add of `#transform=inline` to the url of the image_url, if the model is not a vision model. |
### general_settings - Reference

View file

@ -0,0 +1,346 @@
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
# Control Model Access
## **Restrict models by Virtual Key**
Set allowed models for a key using the `models` param
```shell
curl 'http://0.0.0.0:4000/key/generate' \
--header 'Authorization: Bearer <your-master-key>' \
--header 'Content-Type: application/json' \
--data-raw '{"models": ["gpt-3.5-turbo", "gpt-4"]}'
```
:::info
This key can only make requests to `models` that are `gpt-3.5-turbo` or `gpt-4`
:::
Verify this is set correctly:
<Tabs>
<TabItem label="Allowed Access" value = "allowed">
```shell
curl -i http://localhost:4000/v1/chat/completions \
-H "Content-Type: application/json" \
-H "Authorization: Bearer sk-1234" \
-d '{
"model": "gpt-4",
"messages": [
{"role": "user", "content": "Hello"}
]
}'
```
</TabItem>
<TabItem label="Disallowed Access" value = "not-allowed">
:::info
Expect this to fail since gpt-4o is not in the `models` for the key generated
:::
```shell
curl -i http://localhost:4000/v1/chat/completions \
-H "Content-Type: application/json" \
-H "Authorization: Bearer sk-1234" \
-d '{
"model": "gpt-4o",
"messages": [
{"role": "user", "content": "Hello"}
]
}'
```
</TabItem>
</Tabs>
### [API Reference](https://litellm-api.up.railway.app/#/key%20management/generate_key_fn_key_generate_post)
## **Restrict models by `team_id`**
`litellm-dev` can only access `azure-gpt-3.5`
**1. Create a team via `/team/new`**
```shell
curl --location 'http://localhost:4000/team/new' \
--header 'Authorization: Bearer <your-master-key>' \
--header 'Content-Type: application/json' \
--data-raw '{
"team_alias": "litellm-dev",
"models": ["azure-gpt-3.5"]
}'
# returns {...,"team_id": "my-unique-id"}
```
**2. Create a key for team**
```shell
curl --location 'http://localhost:4000/key/generate' \
--header 'Authorization: Bearer sk-1234' \
--header 'Content-Type: application/json' \
--data-raw '{"team_id": "my-unique-id"}'
```
**3. Test it**
```shell
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Content-Type: application/json' \
--header 'Authorization: Bearer sk-qo992IjKOC2CHKZGRoJIGA' \
--data '{
"model": "BEDROCK_GROUP",
"messages": [
{
"role": "user",
"content": "hi"
}
]
}'
```
```shell
{"error":{"message":"Invalid model for team litellm-dev: BEDROCK_GROUP. Valid models for team are: ['azure-gpt-3.5']\n\n\nTraceback (most recent call last):\n File \"/Users/ishaanjaffer/Github/litellm/litellm/proxy/proxy_server.py\", line 2298, in chat_completion\n _is_valid_team_configs(\n File \"/Users/ishaanjaffer/Github/litellm/litellm/proxy/utils.py\", line 1296, in _is_valid_team_configs\n raise Exception(\nException: Invalid model for team litellm-dev: BEDROCK_GROUP. Valid models for team are: ['azure-gpt-3.5']\n\n","type":"None","param":"None","code":500}}%
```
### [API Reference](https://litellm-api.up.railway.app/#/team%20management/new_team_team_new_post)
## **Model Access Groups**
Use model access groups to give users access to select models, and add new ones to it over time (e.g. mistral, llama-2, etc.)
**Step 1. Assign model, access group in config.yaml**
```yaml
model_list:
  - model_name: gpt-4
    litellm_params:
      model: openai/fake
      api_key: fake-key
      api_base: https://exampleopenaiendpoint-production.up.railway.app/
    model_info:
      access_groups: ["beta-models"] # 👈 Model Access Group
  - model_name: fireworks-llama-v3-70b-instruct
    litellm_params:
      model: fireworks_ai/accounts/fireworks/models/llama-v3-70b-instruct
      api_key: "os.environ/FIREWORKS"
    model_info:
      access_groups: ["beta-models"] # 👈 Model Access Group
```
<Tabs>
<TabItem value="key" label="Key Access Groups">
**Create key with access group**
```bash
curl --location 'http://localhost:4000/key/generate' \
-H 'Authorization: Bearer <your-master-key>' \
-H 'Content-Type: application/json' \
-d '{"models": ["beta-models"], # 👈 Model Access Group
"max_budget": 0,}'
```
Test Key
<Tabs>
<TabItem label="Allowed Access" value = "allowed">
```shell
curl -i http://localhost:4000/v1/chat/completions \
-H "Content-Type: application/json" \
-H "Authorization: Bearer sk-<key-from-previous-step>" \
-d '{
"model": "gpt-4",
"messages": [
{"role": "user", "content": "Hello"}
]
}'
```
</TabItem>
<TabItem label="Disallowed Access" value = "not-allowed">
:::info
Expect this to fail since gpt-4o is not in the `beta-models` access group
:::
```shell
curl -i http://localhost:4000/v1/chat/completions \
-H "Content-Type: application/json" \
-H "Authorization: Bearer sk-<key-from-previous-step>" \
-d '{
"model": "gpt-4o",
"messages": [
{"role": "user", "content": "Hello"}
]
}'
```
</TabItem>
</Tabs>
</TabItem>
<TabItem value="team" label="Team Access Groups">
Create Team
```shell
curl --location 'http://localhost:4000/team/new' \
-H 'Authorization: Bearer <your-master-key>' \
-H 'Content-Type: application/json' \
-d '{"models": ["beta-models"]}'
```
Create Key for Team
```shell
curl --location 'http://0.0.0.0:4000/key/generate' \
--header 'Authorization: Bearer <your-master-key>' \
--header 'Content-Type: application/json' \
--data '{"team_id": "0ac97648-c194-4c90-8cd6-40af7b0d2d2a"}
```
Test Key
<Tabs>
<TabItem label="Allowed Access" value = "allowed">
```shell
curl -i http://localhost:4000/v1/chat/completions \
-H "Content-Type: application/json" \
-H "Authorization: Bearer sk-<key-from-previous-step>" \
-d '{
"model": "gpt-4",
"messages": [
{"role": "user", "content": "Hello"}
]
}'
```
</TabItem>
<TabItem label="Disallowed Access" value = "not-allowed">
:::info
Expect this to fail since gpt-4o is not in the `beta-models` access group
:::
```shell
curl -i http://localhost:4000/v1/chat/completions \
-H "Content-Type: application/json" \
-H "Authorization: Bearer sk-<key-from-previous-step>" \
-d '{
"model": "gpt-4o",
"messages": [
{"role": "user", "content": "Hello"}
]
}'
```
</TabItem>
</Tabs>
</TabItem>
</Tabs>
### ✨ Control Access on Wildcard Models
Control access to all models with a specific prefix (e.g. `openai/*`).
Use this to also give users access to all models, except for a few that you don't want them to use (e.g. `openai/o1-*`).
:::info
Setting model access groups on wildcard models is an Enterprise feature.
See pricing [here](https://litellm.ai/#pricing)
Get a trial key [here](https://litellm.ai/#trial)
:::
1. Setup config.yaml
```yaml
model_list:
  - model_name: openai/*
    litellm_params:
      model: openai/*
      api_key: os.environ/OPENAI_API_KEY
    model_info:
      access_groups: ["default-models"]
  - model_name: openai/o1-*
    litellm_params:
      model: openai/o1-*
      api_key: os.environ/OPENAI_API_KEY
    model_info:
      access_groups: ["restricted-models"]
```
2. Generate a key with access to `default-models`
```bash
curl -L -X POST 'http://0.0.0.0:4000/key/generate' \
-H 'Authorization: Bearer sk-1234' \
-H 'Content-Type: application/json' \
-d '{
"models": ["default-models"],
}'
```
3. Test the key
<Tabs>
<TabItem label="Successful Request" value = "success">
```bash
curl -i http://localhost:4000/v1/chat/completions \
-H "Content-Type: application/json" \
-H "Authorization: Bearer sk-<key-from-previous-step>" \
-d '{
"model": "openai/gpt-4",
"messages": [
{"role": "user", "content": "Hello"}
]
}'
```
</TabItem>
<TabItem value="bad-request" label="Rejected Request">
```bash
curl -i http://localhost:4000/v1/chat/completions \
-H "Content-Type: application/json" \
-H "Authorization: Bearer sk-<key-from-previous-step>" \
-d '{
"model": "openai/o1-mini",
"messages": [
{"role": "user", "content": "Hello"}
]
}'
```
</TabItem>
</Tabs>

View file

@ -224,272 +224,13 @@ Expected Response
</TabItem>
</Tabs>
## **Model Access**
### **Restrict models by Virtual Key**
## Model Aliases
Set allowed models for a key using the `models` param
```shell
curl 'http://0.0.0.0:4000/key/generate' \
--header 'Authorization: Bearer <your-master-key>' \
--header 'Content-Type: application/json' \
--data-raw '{"models": ["gpt-3.5-turbo", "gpt-4"]}'
```
:::info
This key can only make requests to `models` that are `gpt-3.5-turbo` or `gpt-4`
:::
Verify this is set correctly by
<Tabs>
<TabItem label="Allowed Access" value = "allowed">
```shell
curl -i http://localhost:4000/v1/chat/completions \
-H "Content-Type: application/json" \
-H "Authorization: Bearer sk-1234" \
-d '{
"model": "gpt-4",
"messages": [
{"role": "user", "content": "Hello"}
]
}'
```
</TabItem>
<TabItem label="Disallowed Access" value = "not-allowed">
:::info
Expect this to fail since gpt-4o is not in the `models` for the key generated
:::
```shell
curl -i http://localhost:4000/v1/chat/completions \
-H "Content-Type: application/json" \
-H "Authorization: Bearer sk-1234" \
-d '{
"model": "gpt-4o",
"messages": [
{"role": "user", "content": "Hello"}
]
}'
```
</TabItem>
</Tabs>
### **Restrict models by `team_id`**
`litellm-dev` can only access `azure-gpt-3.5`
**1. Create a team via `/team/new`**
```shell
curl --location 'http://localhost:4000/team/new' \
--header 'Authorization: Bearer <your-master-key>' \
--header 'Content-Type: application/json' \
--data-raw '{
"team_alias": "litellm-dev",
"models": ["azure-gpt-3.5"]
}'
# returns {...,"team_id": "my-unique-id"}
```
**2. Create a key for team**
```shell
curl --location 'http://localhost:4000/key/generate' \
--header 'Authorization: Bearer sk-1234' \
--header 'Content-Type: application/json' \
--data-raw '{"team_id": "my-unique-id"}'
```
**3. Test it**
```shell
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Content-Type: application/json' \
--header 'Authorization: Bearer sk-qo992IjKOC2CHKZGRoJIGA' \
--data '{
"model": "BEDROCK_GROUP",
"messages": [
{
"role": "user",
"content": "hi"
}
]
}'
```
```shell
{"error":{"message":"Invalid model for team litellm-dev: BEDROCK_GROUP. Valid models for team are: ['azure-gpt-3.5']\n\n\nTraceback (most recent call last):\n File \"/Users/ishaanjaffer/Github/litellm/litellm/proxy/proxy_server.py\", line 2298, in chat_completion\n _is_valid_team_configs(\n File \"/Users/ishaanjaffer/Github/litellm/litellm/proxy/utils.py\", line 1296, in _is_valid_team_configs\n raise Exception(\nException: Invalid model for team litellm-dev: BEDROCK_GROUP. Valid models for team are: ['azure-gpt-3.5']\n\n","type":"None","param":"None","code":500}}%
```
### **Grant Access to new model (Access Groups)**
Use model access groups to give users access to select models, and add new ones to it over time (e.g. mistral, llama-2, etc.)
**Step 1. Assign model, access group in config.yaml**
```yaml
model_list:
- model_name: gpt-4
litellm_params:
model: openai/fake
api_key: fake-key
api_base: https://exampleopenaiendpoint-production.up.railway.app/
model_info:
access_groups: ["beta-models"] # 👈 Model Access Group
- model_name: fireworks-llama-v3-70b-instruct
litellm_params:
model: fireworks_ai/accounts/fireworks/models/llama-v3-70b-instruct
api_key: "os.environ/FIREWORKS"
model_info:
access_groups: ["beta-models"] # 👈 Model Access Group
```
<Tabs>
<TabItem value="key" label="Key Access Groups">
**Create key with access group**
```bash
curl --location 'http://localhost:4000/key/generate' \
-H 'Authorization: Bearer <your-master-key>' \
-H 'Content-Type: application/json' \
-d '{"models": ["beta-models"], # 👈 Model Access Group
"max_budget": 0,}'
```
Test Key
<Tabs>
<TabItem label="Allowed Access" value = "allowed">
```shell
curl -i http://localhost:4000/v1/chat/completions \
-H "Content-Type: application/json" \
-H "Authorization: Bearer sk-<key-from-previous-step>" \
-d '{
"model": "gpt-4",
"messages": [
{"role": "user", "content": "Hello"}
]
}'
```
</TabItem>
<TabItem label="Disallowed Access" value = "not-allowed">
:::info
Expect this to fail since gpt-4o is not in the `beta-models` access group
:::
```shell
curl -i http://localhost:4000/v1/chat/completions \
-H "Content-Type: application/json" \
-H "Authorization: Bearer sk-<key-from-previous-step>" \
-d '{
"model": "gpt-4o",
"messages": [
{"role": "user", "content": "Hello"}
]
}'
```
</TabItem>
</Tabs>
</TabItem>
<TabItem value="team" label="Team Access Groups">
Create Team
```shell
curl --location 'http://localhost:4000/team/new' \
-H 'Authorization: Bearer sk-<key-from-previous-step>' \
-H 'Content-Type: application/json' \
-d '{"models": ["beta-models"]}'
```
Create Key for Team
```shell
curl --location 'http://0.0.0.0:4000/key/generate' \
--header 'Authorization: Bearer sk-<key-from-previous-step>' \
--header 'Content-Type: application/json' \
--data '{"team_id": "0ac97648-c194-4c90-8cd6-40af7b0d2d2a"}
```
Test Key
<Tabs>
<TabItem label="Allowed Access" value = "allowed">
```shell
curl -i http://localhost:4000/v1/chat/completions \
-H "Content-Type: application/json" \
-H "Authorization: Bearer sk-<key-from-previous-step>" \
-d '{
"model": "gpt-4",
"messages": [
{"role": "user", "content": "Hello"}
]
}'
```
</TabItem>
<TabItem label="Disallowed Access" value = "not-allowed">
:::info
Expect this to fail since gpt-4o is not in the `beta-models` access group
:::
```shell
curl -i http://localhost:4000/v1/chat/completions \
-H "Content-Type: application/json" \
-H "Authorization: Bearer sk-<key-from-previous-step>" \
-d '{
"model": "gpt-4o",
"messages": [
{"role": "user", "content": "Hello"}
]
}'
```
</TabItem>
</Tabs>
</TabItem>
</Tabs>
### Model Aliases
If a user is expected to use a given model (i.e. gpt3-5), and you want to:
- try to upgrade the request (i.e. GPT4)
- or downgrade it (i.e. Mistral)
- OR rotate the API KEY (i.e. open AI)
- OR access the same model through different end points (i.e. openAI vs openrouter vs Azure)
Here's how you can do that:
@ -509,13 +250,13 @@ model_list:
    litellm_params:
      model: huggingface/HuggingFaceH4/zephyr-7b-beta
      api_base: http://0.0.0.0:8003
  - model_name: my-paid-tier
    litellm_params:
      model: gpt-4
      api_key: my-api-key
```
**Step 2: Generate a user key - enabling them access to specific models, custom model aliases, etc.**
**Step 2: Generate a key**
```bash
curl -X POST "https://0.0.0.0:4000/key/generate" \
@ -523,13 +264,29 @@ curl -X POST "https://0.0.0.0:4000/key/generate" \
-H "Content-Type: application/json" \
-d '{
  "models": ["my-free-tier"],
  "aliases": {"gpt-3.5-turbo": "my-free-tier"}, # 👈 KEY CHANGE
  "duration": "30min"
}'
```
- **How to upgrade / downgrade request?** Change the alias mapping
- **How are routing between diff keys/api bases done?** litellm handles this by shuffling between different models in the model list with the same model_name. [**See Code**](https://github.com/BerriAI/litellm/blob/main/litellm/router.py)
**Step 3: Test the key**
```bash
curl -X POST "https://0.0.0.0:4000/key/generate" \
-H "Authorization: Bearer <user-key>" \
-H "Content-Type: application/json" \
-d '{
"model": "gpt-3.5-turbo",
"messages": [
{
"role": "user",
"content": "this is a test request, write a short poem"
}
]
}'
```
## Advanced

View file

@ -138,3 +138,6 @@ curl http://localhost:4000/v1/chat/completions \
</TabItem>
</Tabs>
## [[PROXY-Only] Control Wildcard Model Access](./proxy/model_access#-control-access-on-wildcard-models)

View file

@ -81,6 +81,14 @@ const sidebars = {
"proxy/multiple_admins", "proxy/multiple_admins",
], ],
}, },
{
type: "category",
label: "Model Access",
items: [
"proxy/model_access",
"proxy/team_model_add"
]
},
{
type: "category",
label: "Admin UI",
@ -91,13 +99,6 @@ const sidebars = {
"proxy/custom_sso"
],
},
{
type: "category",
label: "Team Management",
items: [
"proxy/team_model_add"
],
},
{
type: "category",
label: "Spend Tracking",

View file

@ -151,6 +151,7 @@ use_client: bool = False
ssl_verify: Union[str, bool] = True
ssl_certificate: Optional[str] = None
disable_streaming_logging: bool = False
disable_add_transform_inline_image_block: bool = False
in_memory_llm_clients_cache: InMemoryCache = InMemoryCache()
safe_memory_mode: bool = False
enable_azure_ad_token_refresh: Optional[bool] = False

View file

@ -0,0 +1,9 @@
from abc import ABC, abstractmethod

from litellm.types.utils import ModelInfoBase


class BaseLLMModelInfo(ABC):
    @abstractmethod
    def get_model_info(self, model: str) -> ModelInfoBase:
        pass

View file

@ -1,12 +1,15 @@
from typing import List, Literal, Optional, Tuple, Union
from typing import List, Literal, Optional, Tuple, Union, cast
import litellm
from litellm.llms.base_llm.base_utils import BaseLLMModelInfo
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import AllMessageValues
from litellm.types.llms.openai import AllMessageValues, ChatCompletionImageObject
from litellm.types.utils import ModelInfoBase, ProviderSpecificModelInfo
from ...openai.chat.gpt_transformation import OpenAIGPTConfig
class FireworksAIConfig(OpenAIGPTConfig):
class FireworksAIConfig(BaseLLMModelInfo, OpenAIGPTConfig):
    """
    Reference: https://docs.fireworks.ai/api-reference/post-chatcompletions
@ -110,6 +113,80 @@ class FireworksAIConfig(OpenAIGPTConfig):
                optional_params[param] = value
        return optional_params
    def _add_transform_inline_image_block(
        self,
        content: ChatCompletionImageObject,
        model: str,
        disable_add_transform_inline_image_block: Optional[bool],
    ) -> ChatCompletionImageObject:
        """
        Add transform_inline to the image_url (allows non-vision models to parse documents/images/etc.)
        - ignore if model is a vision model
        - ignore if user has disabled this feature
        """
        if (
            "vision" in model or disable_add_transform_inline_image_block
        ):  # allow user to toggle this feature.
            return content
        if isinstance(content["image_url"], str):
            content["image_url"] = f"{content['image_url']}#transform=inline"
        elif isinstance(content["image_url"], dict):
            content["image_url"][
                "url"
            ] = f"{content['image_url']['url']}#transform=inline"
        return content
    def _transform_messages_helper(
        self, messages: List[AllMessageValues], model: str, litellm_params: dict
    ) -> List[AllMessageValues]:
        """
        Add 'transform=inline' to the url of the image_url
        """
        disable_add_transform_inline_image_block = cast(
            Optional[bool],
            litellm_params.get(
                "disable_add_transform_inline_image_block",
                litellm.disable_add_transform_inline_image_block,
            ),
        )
        for message in messages:
            if message["role"] == "user":
                _message_content = message.get("content")
                if _message_content is not None and isinstance(_message_content, list):
                    for content in _message_content:
                        if content["type"] == "image_url":
                            content = self._add_transform_inline_image_block(
                                content=content,
                                model=model,
                                disable_add_transform_inline_image_block=disable_add_transform_inline_image_block,
                            )
        return messages
    def get_model_info(
        self, model: str, existing_model_info: Optional[ModelInfoBase] = None
    ) -> ModelInfoBase:
        provider_specific_model_info = ProviderSpecificModelInfo(
            supports_function_calling=True,
            supports_prompt_caching=True,  # https://docs.fireworks.ai/guides/prompt-caching
            supports_pdf_input=True,  # via document inlining
            supports_vision=True,  # via document inlining
        )
        if existing_model_info is not None:
            return ModelInfoBase(
                **{**existing_model_info, **provider_specific_model_info}
            )
        return ModelInfoBase(
            key=model,
            litellm_provider="fireworks_ai",
            mode="chat",
            input_cost_per_token=0.0,
            output_cost_per_token=0.0,
            max_tokens=None,
            max_input_tokens=None,
            max_output_tokens=None,
            **provider_specific_model_info,
        )
    def transform_request(
        self,
        model: str,
@ -120,6 +197,9 @@ class FireworksAIConfig(OpenAIGPTConfig):
    ) -> dict:
        if not model.startswith("accounts/"):
            model = f"accounts/fireworks/models/{model}"
        messages = self._transform_messages_helper(
            messages=messages, model=model, litellm_params=litellm_params
        )
        return super().transform_request(
            model=model,
            messages=messages,

View file

@ -899,6 +899,10 @@ def completion( # type: ignore # noqa: PLR0915
hf_model_name = kwargs.get("hf_model_name", None)
supports_system_message = kwargs.get("supports_system_message", None)
base_model = kwargs.get("base_model", None)
### DISABLE FLAGS ###
disable_add_transform_inline_image_block = kwargs.get(
"disable_add_transform_inline_image_block", None
)
### TEXT COMPLETION CALLS ###
text_completion = kwargs.get("text_completion", False)
atext_completion = kwargs.get("atext_completion", False)
@ -956,14 +960,11 @@ def completion( # type: ignore # noqa: PLR0915
"top_logprobs", "top_logprobs",
"extra_headers", "extra_headers",
] ]
default_params = openai_params + all_litellm_params default_params = openai_params + all_litellm_params
litellm_params = {} # used to prevent unbound var errors litellm_params = {} # used to prevent unbound var errors
non_default_params = { non_default_params = {
k: v for k, v in kwargs.items() if k not in default_params k: v for k, v in kwargs.items() if k not in default_params
} # model-specific params - pass them straight to the model/provider } # model-specific params - pass them straight to the model/provider
## PROMPT MANAGEMENT HOOKS ## ## PROMPT MANAGEMENT HOOKS ##
if isinstance(litellm_logging_obj, LiteLLMLoggingObj) and prompt_id is not None: if isinstance(litellm_logging_obj, LiteLLMLoggingObj) and prompt_id is not None:
@ -1156,6 +1157,7 @@ def completion( # type: ignore # noqa: PLR0915
hf_model_name=hf_model_name,
custom_prompt_dict=custom_prompt_dict,
litellm_metadata=kwargs.get("litellm_metadata"),
disable_add_transform_inline_image_block=disable_add_transform_inline_image_block,
)
logging.update_environment_variables(
model=model,

View file

@ -1,17 +1,13 @@
model_list:
- model_name: model-test
- model_name: openai/*
litellm_params:
model: openai/gpt-3.5-turbo
model: openai/*
api_key: os.environ/OPENAI_API_KEY
mock_response: "Hello, world!"
model_info:
rpm: 1
access_groups: ["default-models"]
- model_name: model-test
- model_name: openai/o1-*
litellm_params:
model: openai/o1-mini
model: openai/o1-*
api_key: os.environ/OPENAI_API_KEY
mock_response: "Hello, world, it's o1!"
model_info:
rpm: 10
access_groups: ["restricted-models"]
router_settings:
routing_strategy: usage-based-routing-v2
disable_cooldowns: True

View file

@ -16,7 +16,7 @@ import secrets
import traceback
import uuid
from datetime import datetime, timedelta, timezone
from typing import List, Optional, Tuple, cast
from typing import List, Literal, Optional, Tuple, cast
import fastapi
from fastapi import APIRouter, Depends, Header, HTTPException, Query, Request, status
@ -38,6 +38,7 @@ from litellm.proxy.utils import (
duration_in_seconds,
handle_exception_on_proxy,
)
from litellm.router import Router
from litellm.secret_managers.main import get_secret
from litellm.types.utils import (
BudgetConfig,
@ -330,6 +331,8 @@ async def generate_key_fn( # noqa: PLR0915
try:
from litellm.proxy.proxy_server import (
litellm_proxy_admin_name,
llm_router,
premium_user,
prisma_client,
user_api_key_cache,
user_custom_key_generate,
@ -386,6 +389,12 @@ async def generate_key_fn( # noqa: PLR0915
detail=str(e),
)
_check_model_access_group(
models=data.models,
llm_router=llm_router,
premium_user=premium_user,
)
# check if user set default key/generate params on config.yaml
if litellm.default_key_generate_params is not None:
for elem in data:
@ -992,6 +1001,34 @@
raise handle_exception_on_proxy(e)
def _check_model_access_group(
models: Optional[List[str]], llm_router: Optional[Router], premium_user: bool
) -> Literal[True]:
"""
if is_model_access_group is True + is_wildcard_route is True, check if user is a premium user
Return True if user is a premium user, False otherwise
"""
if models is None or llm_router is None:
return True
for model in models:
if llm_router._is_model_access_group_for_wildcard_route(
model_access_group=model
):
if not premium_user:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail={
"error": "Setting a model access group on a wildcard model is only available for LiteLLM Enterprise users.{}".format(
CommonProxyErrors.not_premium_user.value
)
},
)
return True
async def generate_key_helper_fn( # noqa: PLR0915
request_type: Literal[
"user", "key"

View file

@ -4713,10 +4713,14 @@ class Router:
return None
def get_model_access_groups(
self, model_name: Optional[str] = None
self, model_name: Optional[str] = None, model_access_group: Optional[str] = None
) -> Dict[str, List[str]]:
"""
If model_name is provided, only return access groups for that model.
Parameters:
- model_name: Optional[str] - the received model name from the user (can be a wildcard route). If set, will only return access groups for that model.
- model_access_group: Optional[str] - the received model access group from the user. If set, will only return models for that access group.
"""
from collections import defaultdict
@ -4726,11 +4730,39 @@ class Router:
if model_list:
for m in model_list:
for group in m.get("model_info", {}).get("access_groups", []):
model_name = m["model_name"]
if model_access_group is not None:
access_groups[group].append(model_name)
if group == model_access_group:
model_name = m["model_name"]
access_groups[group].append(model_name)
else:
model_name = m["model_name"]
access_groups[group].append(model_name)
return access_groups
def _is_model_access_group_for_wildcard_route(
self, model_access_group: str
) -> bool:
"""
Return True if model access group is a wildcard route
"""
# GET ACCESS GROUPS
access_groups = self.get_model_access_groups(
model_access_group=model_access_group
)
if len(access_groups) == 0:
return False
models = access_groups.get(model_access_group, [])
for model in models:
# CHECK IF MODEL ACCESS GROUP IS A WILDCARD ROUTE
if self.pattern_router.route(request=model) is not None:
return True
return False
def get_settings(self):
"""
Get router settings method, returns a dictionary of the settings and their values.

View file

@ -128,7 +128,7 @@ class PatternMatchRouter:
if no pattern is found, return None
Args:
request: Optional[str]
request: str - the received model name from the user (can be a wildcard route). If none, No deployments will be returned.
filtered_model_names: Optional[List[str]] - if provided, only return deployments that match the filtered_model_names
Returns:
Optional[List[Deployment]]: llm deployments

View file

@ -75,7 +75,20 @@ class ProviderField(TypedDict):
field_value: str
class ModelInfoBase(TypedDict, total=False):
class ProviderSpecificModelInfo(TypedDict, total=False):
supports_system_messages: Optional[bool]
supports_response_schema: Optional[bool]
supports_vision: Optional[bool]
supports_function_calling: Optional[bool]
supports_assistant_prefill: Optional[bool]
supports_prompt_caching: Optional[bool]
supports_audio_input: Optional[bool]
supports_embedding_image_input: Optional[bool]
supports_audio_output: Optional[bool]
supports_pdf_input: Optional[bool]
class ModelInfoBase(ProviderSpecificModelInfo, total=False):
key: Required[str] # the key in litellm.model_cost which is returned
max_tokens: Required[Optional[int]]
@ -116,16 +129,6 @@ class ModelInfoBase(TypedDict, total=False):
"completion", "embedding", "image_generation", "chat", "audio_transcription"
]
]
supports_system_messages: Optional[bool]
supports_response_schema: Optional[bool]
supports_vision: Optional[bool]
supports_function_calling: Optional[bool]
supports_assistant_prefill: Optional[bool]
supports_prompt_caching: Optional[bool]
supports_audio_input: Optional[bool]
supports_embedding_image_input: Optional[bool]
supports_audio_output: Optional[bool]
supports_pdf_input: Optional[bool]
tpm: Optional[int]
rpm: Optional[int]
@ -1613,6 +1616,7 @@ all_litellm_params = [
"caching", "caching",
"mock_response", "mock_response",
"mock_timeout", "mock_timeout",
"disable_add_transform_inline_image_block",
"api_key", "api_key",
"api_version", "api_version",
"prompt_id", "prompt_id",

View file

@ -174,6 +174,7 @@ from openai import OpenAIError as OriginalError
from litellm.llms.base_llm.audio_transcription.transformation import (
BaseAudioTranscriptionConfig,
)
from litellm.llms.base_llm.base_utils import BaseLLMModelInfo
from litellm.llms.base_llm.chat.transformation import BaseConfig
from litellm.llms.base_llm.completion.transformation import BaseTextCompletionConfig
from litellm.llms.base_llm.embedding.transformation import BaseEmbeddingConfig
@ -1989,6 +1990,7 @@ def get_litellm_params(
hf_model_name: Optional[str] = None,
custom_prompt_dict: Optional[dict] = None,
litellm_metadata: Optional[dict] = None,
disable_add_transform_inline_image_block: Optional[bool] = None,
):
litellm_params = {
"acompletion": acompletion,
@ -2021,6 +2023,7 @@ def get_litellm_params(
"hf_model_name": hf_model_name,
"custom_prompt_dict": custom_prompt_dict,
"litellm_metadata": litellm_metadata,
"disable_add_transform_inline_image_block": disable_add_transform_inline_image_block,
}
return litellm_params
@ -4373,6 +4376,17 @@ def _get_model_info_helper( # noqa: PLR0915
model_info=_model_info, custom_llm_provider=custom_llm_provider
):
_model_info = None
if _model_info is None and ProviderConfigManager.get_provider_model_info(
model=model, provider=LlmProviders(custom_llm_provider)
):
provider_config = ProviderConfigManager.get_provider_model_info(
model=model, provider=LlmProviders(custom_llm_provider)
)
if provider_config is not None:
_model_info = cast(
dict, provider_config.get_model_info(model=model)
)
key = "provider_specific_model_info"
if _model_info is None or key is None:
raise ValueError(
"This model isn't mapped yet. Add it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json"
@ -6338,6 +6352,15 @@ class ProviderConfigManager:
return litellm.TogetherAITextCompletionConfig()
return litellm.OpenAITextCompletionConfig()
@staticmethod
def get_provider_model_info(
model: str,
provider: LlmProviders,
) -> Optional[BaseLLMModelInfo]:
if LlmProviders.FIREWORKS_AI == provider:
return litellm.FireworksAIConfig()
return None
def get_end_user_id_for_cost_tracking(
litellm_params: dict,

View file

@ -103,3 +103,96 @@ class TestFireworksAIAudioTranscription(BaseLLMAudioTranscriptionTest):
def get_custom_llm_provider(self) -> litellm.LlmProviders:
return litellm.LlmProviders.FIREWORKS_AI
@pytest.mark.parametrize(
"disable_add_transform_inline_image_block",
[True, False],
)
def test_document_inlining_example(disable_add_transform_inline_image_block):
litellm.set_verbose = True
if disable_add_transform_inline_image_block is True:
with pytest.raises(Exception):
completion = litellm.completion(
model="fireworks_ai/accounts/fireworks/models/llama-v3p3-70b-instruct",
messages=[
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {
"url": "https://storage.googleapis.com/fireworks-public/test/sample_resume.pdf"
},
},
{
"type": "text",
"text": "What are the candidate's BA and MBA GPAs?",
},
],
}
],
disable_add_transform_inline_image_block=disable_add_transform_inline_image_block,
)
else:
completion = litellm.completion(
model="fireworks_ai/accounts/fireworks/models/llama-v3p3-70b-instruct",
messages=[
{
"role": "user",
"content": "this is a test request, write a short poem",
},
],
disable_add_transform_inline_image_block=disable_add_transform_inline_image_block,
)
print(completion)
@pytest.mark.parametrize(
"content, model, expected_url",
[
(
{"image_url": "http://example.com/image.png"},
"gpt-4",
"http://example.com/image.png#transform=inline",
),
(
{"image_url": {"url": "http://example.com/image.png"}},
"gpt-4",
{"url": "http://example.com/image.png#transform=inline"},
),
(
{"image_url": "http://example.com/image.png"},
"vision-gpt",
"http://example.com/image.png",
),
],
)
def test_transform_inline(content, model, expected_url):
result = litellm.FireworksAIConfig()._add_transform_inline_image_block(
content=content, model=model, disable_add_transform_inline_image_block=False
)
if isinstance(expected_url, str):
assert result["image_url"] == expected_url
else:
assert result["image_url"]["url"] == expected_url["url"]
@pytest.mark.parametrize(
"model, is_disabled, expected_url",
[
("gpt-4", True, "http://example.com/image.png"),
("vision-gpt", False, "http://example.com/image.png"),
("gpt-4", False, "http://example.com/image.png#transform=inline"),
],
)
def test_global_disable_flag(model, is_disabled, expected_url):
content = {"image_url": "http://example.com/image.png"}
result = litellm.FireworksAIConfig()._add_transform_inline_image_block(
content=content,
model=model,
disable_add_transform_inline_image_block=is_disabled,
)
assert result["image_url"] == expected_url
litellm.disable_add_transform_inline_image_block = False # Reset for other tests

View file

@ -364,3 +364,23 @@ async def test_get_remaining_model_group_usage():
assert remaining_usage is not None
assert "x-ratelimit-remaining-requests" in remaining_usage
assert "x-ratelimit-remaining-tokens" in remaining_usage
@pytest.mark.parametrize(
"potential_access_group, expected_result",
[("gemini-models", True), ("gemini-models-2", False), ("gemini/*", False)],
)
def test_router_get_model_access_groups(potential_access_group, expected_result):
router = Router(
model_list=[
{
"model_name": "gemini/*",
"litellm_params": {"model": "gemini/*"},
"model_info": {"id": 1, "access_groups": ["gemini-models"]},
},
]
)
access_groups = router._is_model_access_group_for_wildcard_route(
model_access_group=potential_access_group
)
assert access_groups == expected_result

View file

@ -1240,3 +1240,15 @@ def test_token_counter_with_image_url_with_detail_high():
)
print("tokens", _tokens)
assert _tokens == DEFAULT_IMAGE_TOKEN_COUNT + 7
def test_fireworks_ai_document_inlining():
"""
With document inlining, all fireworks ai models are now:
- supports_pdf
- supports_vision
"""
from litellm.utils import supports_pdf_input, supports_vision
assert supports_pdf_input("fireworks_ai/llama-3.1-8b-instruct") is True
assert supports_vision("fireworks_ai/llama-3.1-8b-instruct") is True