fix(key_management_endpoints.py): override metadata field value on up… (#7008)

* fix(key_management_endpoints.py): override metadata field value on update

allow user to override tags

* feat(__init__.py): expose new disable_end_user_cost_tracking_prometheus_only metric

allow disabling end user cost tracking on prometheus - fixes cardinality issue

* fix(litellm_pre_call_utils.py): add key/team level enforced params

Fixes https://github.com/BerriAI/litellm/issues/6652

* fix(key_management_endpoints.py): allow user to pass in `enforced_params` as a top level param on /key/generate and /key/update

* docs(enterprise.md): add docs on enforcing required params for llm requests

* Add support of Galadriel API (#7005)

* fix(router.py): robust retry after handling

set retry after time to 0 if >0 healthy deployments. handle base case = 1 deployment

* test(test_router.py): fix test

* feat(bedrock/): add support for 'nova' models

also adds explicit 'converse/' route for simpler routing

* fix: fix 'supports_pdf_input'

return if model supports pdf input on get_model_info

* feat(converse_transformation.py): support bedrock pdf input

* docs(document_understanding.md): add document understanding to docs

* fix(litellm_pre_call_utils.py): fix linting error

* fix(init.py): fix passing of bedrock converse models

* feat(bedrock/converse): support 'response_format={"type": "json_object"}'

* fix(converse_handler.py): fix linting error

* fix(base_llm_unit_tests.py): fix test

* fix: fix test

* test: fix test

* test: fix test

* test: remove duplicate test

---------

Co-authored-by: h4n0 <4738254+h4n0@users.noreply.github.com>
Krish Dholakia 2024-12-03 23:03:50 -08:00 committed by GitHub
parent d558b643be
commit 6bb934c0ac
37 changed files with 1297 additions and 503 deletions


@ -271,6 +271,7 @@ curl 'http://0.0.0.0:4000/key/generate' \
| [voyage ai](https://docs.litellm.ai/docs/providers/voyage) | | | | | ✅ | |
| [xinference [Xorbits Inference]](https://docs.litellm.ai/docs/providers/xinference) | | | | | ✅ | |
| [FriendliAI](https://docs.litellm.ai/docs/providers/friendliai) | ✅ | ✅ | ✅ | ✅ | | |
| [Galadriel](https://docs.litellm.ai/docs/providers/galadriel) | ✅ | ✅ | ✅ | ✅ | | |
[**Read the Docs**](https://docs.litellm.ai/docs/)


@ -0,0 +1,202 @@
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
# Using PDF Input
How to send / receive PDFs (and other document types) to a `/chat/completions` endpoint
Works for:
- Vertex AI models (Gemini + Anthropic)
- Bedrock Models
- Anthropic API Models
## Quick Start
### url
<Tabs>
<TabItem value="sdk" label="SDK">
```python
from litellm.utils import supports_pdf_input, completion
import os

# set aws credentials
os.environ["AWS_ACCESS_KEY_ID"] = ""
os.environ["AWS_SECRET_ACCESS_KEY"] = ""
os.environ["AWS_REGION_NAME"] = ""

# pdf url
image_url = "https://www.w3.org/WAI/ER/tests/xhtml/testfiles/resources/pdf/dummy.pdf"

# model
model = "bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0"

image_content = [
    {"type": "text", "text": "What's this file about?"},
    {
        "type": "image_url",
        "image_url": image_url,  # OR {"url": image_url}
    },
]

if not supports_pdf_input(model, None):
    print("Model does not support pdf input")

response = completion(
    model=model,
    messages=[{"role": "user", "content": image_content}],
)
assert response is not None
```
</TabItem>
<TabItem value="proxy" label="PROXY">
1. Setup config.yaml
```yaml
model_list:
  - model_name: bedrock-model
    litellm_params:
      model: bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0
      aws_access_key_id: os.environ/AWS_ACCESS_KEY_ID
      aws_secret_access_key: os.environ/AWS_SECRET_ACCESS_KEY
      aws_region_name: os.environ/AWS_REGION_NAME
```
2. Start the proxy
```bash
litellm --config /path/to/config.yaml
```
3. Test it!
```bash
curl -X POST 'http://0.0.0.0:4000/chat/completions' \
-H 'Content-Type: application/json' \
-H 'Authorization: Bearer sk-1234' \
-d '{
    "model": "bedrock-model",
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "What is this file about?"},
                {
                    "type": "image_url",
                    "image_url": "https://www.w3.org/WAI/ER/tests/xhtml/testfiles/resources/pdf/dummy.pdf"
                }
            ]
        }
    ]
}'
```
</TabItem>
</Tabs>
### base64
<Tabs>
<TabItem value="sdk" label="SDK">
```python
from litellm.utils import supports_pdf_input, completion
import base64
import os
import requests

# set aws credentials
os.environ["AWS_ACCESS_KEY_ID"] = ""
os.environ["AWS_SECRET_ACCESS_KEY"] = ""
os.environ["AWS_REGION_NAME"] = ""

# pdf url -> base64 data url
image_url = "https://www.w3.org/WAI/ER/tests/xhtml/testfiles/resources/pdf/dummy.pdf"
response = requests.get(image_url)
file_data = response.content
encoded_file = base64.b64encode(file_data).decode("utf-8")
base64_url = f"data:application/pdf;base64,{encoded_file}"

# model
model = "bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0"

image_content = [
    {"type": "text", "text": "What's this file about?"},
    {
        "type": "image_url",
        "image_url": base64_url,  # OR {"url": base64_url}
    },
]

if not supports_pdf_input(model, None):
    print("Model does not support pdf input")

response = completion(
    model=model,
    messages=[{"role": "user", "content": image_content}],
)
assert response is not None
```
</TabItem>
</Tabs>
## Checking if a model supports pdf input
<Tabs>
<TabItem label="SDK" value="sdk">
Use `litellm.supports_pdf_input(model="bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0")` -> returns `True` if model can accept pdf input
```python
assert litellm.supports_pdf_input(model="bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0") == True
```
</TabItem>
<TabItem label="PROXY" value="proxy">
1. Define bedrock models on config.yaml
```yaml
model_list:
  - model_name: bedrock-model # model group name
    litellm_params:
      model: bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0
      aws_access_key_id: os.environ/AWS_ACCESS_KEY_ID
      aws_secret_access_key: os.environ/AWS_SECRET_ACCESS_KEY
      aws_region_name: os.environ/AWS_REGION_NAME
    model_info: # OPTIONAL - set manually
      supports_pdf_input: True
```
2. Run proxy server
```bash
litellm --config config.yaml
```
3. Call `/model_group/info` to check if a model supports `pdf` input
```shell
curl -X 'GET' \
'http://localhost:4000/model_group/info' \
-H 'accept: application/json' \
-H 'x-api-key: sk-1234'
```
Expected Response
```json
{
    "data": [
        {
            "model_group": "bedrock-model",
            "providers": ["bedrock"],
            "max_input_tokens": 128000,
            "max_output_tokens": 16384,
            "mode": "chat",
            ...,
            "supports_pdf_input": true, # 👈 supports_pdf_input is true
        }
    ]
}
```
</TabItem>
</Tabs>


@ -706,6 +706,37 @@ print(response)
</Tabs>
## Set 'converse' / 'invoke' route
LiteLLM defaults to the `invoke` route, and automatically uses the `converse` route for Bedrock models that support it (see `litellm.BEDROCK_CONVERSE_MODELS`).
To set the route explicitly, use `bedrock/converse/<model>` or `bedrock/invoke/<model>`.
E.g.
<Tabs>
<TabItem value="sdk" label="SDK">
```python
from litellm import completion
completion(model="bedrock/converse/us.amazon.nova-pro-v1:0", messages=[{"role": "user", "content": "Hello"}])
```
</TabItem>
<TabItem value="proxy" label="PROXY">
```yaml
model_list:
  - model_name: bedrock-model
    litellm_params:
      model: bedrock/converse/us.amazon.nova-pro-v1:0
```
</TabItem>
</Tabs>
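To force the `invoke` route instead, the same pattern applies (the model name here is illustrative):

```python
from litellm import completion

completion(
    model="bedrock/invoke/anthropic.claude-v2",
    messages=[{"role": "user", "content": "Hello"}],
)
```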
## Alternate user/assistant messages
Use `user_continue_message` to add a default user message for cases (e.g. Autogen) where the client might not follow the required pattern of alternating user/assistant messages, starting and ending with a user message.
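A minimal proxy config sketch (the continue-message content shown is an assumption; adjust as needed):

```yaml
model_list:
  - model_name: bedrock-model
    litellm_params:
      model: bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0
      user_continue_message: {"role": "user", "content": "Please continue"}
```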
@ -745,6 +776,174 @@ curl -X POST 'http://0.0.0.0:4000/chat/completions' \
}'
```
## Usage - PDF / Document Understanding
LiteLLM supports Document Understanding for Bedrock models - [AWS Bedrock Docs](https://docs.aws.amazon.com/nova/latest/userguide/modalities-document.html).
### url
<Tabs>
<TabItem value="sdk" label="SDK">
```python
from litellm.utils import supports_pdf_input, completion
import os

# set aws credentials
os.environ["AWS_ACCESS_KEY_ID"] = ""
os.environ["AWS_SECRET_ACCESS_KEY"] = ""
os.environ["AWS_REGION_NAME"] = ""

# pdf url
image_url = "https://www.w3.org/WAI/ER/tests/xhtml/testfiles/resources/pdf/dummy.pdf"

# model
model = "bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0"

image_content = [
    {"type": "text", "text": "What's this file about?"},
    {
        "type": "image_url",
        "image_url": image_url,  # OR {"url": image_url}
    },
]

if not supports_pdf_input(model, None):
    print("Model does not support pdf input")

response = completion(
    model=model,
    messages=[{"role": "user", "content": image_content}],
)
assert response is not None
```
</TabItem>
<TabItem value="proxy" label="PROXY">
1. Setup config.yaml
```yaml
model_list:
  - model_name: bedrock-model
    litellm_params:
      model: bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0
      aws_access_key_id: os.environ/AWS_ACCESS_KEY_ID
      aws_secret_access_key: os.environ/AWS_SECRET_ACCESS_KEY
      aws_region_name: os.environ/AWS_REGION_NAME
```
2. Start the proxy
```bash
litellm --config /path/to/config.yaml
```
3. Test it!
```bash
curl -X POST 'http://0.0.0.0:4000/chat/completions' \
-H 'Content-Type: application/json' \
-H 'Authorization: Bearer sk-1234' \
-d '{
    "model": "bedrock-model",
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "What is this file about?"},
                {
                    "type": "image_url",
                    "image_url": "https://www.w3.org/WAI/ER/tests/xhtml/testfiles/resources/pdf/dummy.pdf"
                }
            ]
        }
    ]
}'
```
</TabItem>
</Tabs>
### base64
<Tabs>
<TabItem value="sdk" label="SDK">
```python
from litellm.utils import supports_pdf_input, completion
import base64
import os
import requests

# set aws credentials
os.environ["AWS_ACCESS_KEY_ID"] = ""
os.environ["AWS_SECRET_ACCESS_KEY"] = ""
os.environ["AWS_REGION_NAME"] = ""

# pdf url -> base64 data url
image_url = "https://www.w3.org/WAI/ER/tests/xhtml/testfiles/resources/pdf/dummy.pdf"
response = requests.get(image_url)
file_data = response.content
encoded_file = base64.b64encode(file_data).decode("utf-8")
base64_url = f"data:application/pdf;base64,{encoded_file}"

# model
model = "bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0"

image_content = [
    {"type": "text", "text": "What's this file about?"},
    {
        "type": "image_url",
        "image_url": base64_url,  # OR {"url": base64_url}
    },
]

if not supports_pdf_input(model, None):
    print("Model does not support pdf input")

response = completion(
    model=model,
    messages=[{"role": "user", "content": image_content}],
)
assert response is not None
```
</TabItem>
<TabItem value="proxy" label="PROXY">
1. Setup config.yaml
```yaml
model_list:
  - model_name: bedrock-model
    litellm_params:
      model: bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0
      aws_access_key_id: os.environ/AWS_ACCESS_KEY_ID
      aws_secret_access_key: os.environ/AWS_SECRET_ACCESS_KEY
      aws_region_name: os.environ/AWS_REGION_NAME
```
2. Start the proxy
```bash
litellm --config /path/to/config.yaml
```
3. Test it!
```bash
curl -X POST 'http://0.0.0.0:4000/chat/completions' \
-H 'Content-Type: application/json' \
-H 'Authorization: Bearer sk-1234' \
-d '{
    "model": "bedrock-model",
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "What is this file about?"},
                {
                    "type": "image_url",
                    "image_url": "data:application/pdf;base64,{b64_encoded_file}"
                }
            ]
        }
    ]
}'
```
</TabItem>
</Tabs>
## Boto3 - Authentication
### Passing credentials as parameters - Completion()


@ -0,0 +1,63 @@
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
# Galadriel
https://docs.galadriel.com/api-reference/chat-completion-API
LiteLLM supports all models on Galadriel.
## API Key
```python
import os
os.environ['GALADRIEL_API_KEY'] = "your-api-key"
```
## Sample Usage
```python
from litellm import completion
import os
os.environ['GALADRIEL_API_KEY'] = ""
response = completion(
    model="galadriel/llama3.1",
    messages=[
        {"role": "user", "content": "hello from litellm"}
    ],
)
print(response)
```
## Sample Usage - Streaming
```python
from litellm import completion
import os
os.environ['GALADRIEL_API_KEY'] = ""
response = completion(
    model="galadriel/llama3.1",
    messages=[
        {"role": "user", "content": "hello from litellm"}
    ],
    stream=True,
)

for chunk in response:
    print(chunk)
```
## Supported Models
### Serverless Endpoints
We support ALL Galadriel AI models. Just set `galadriel/` as a prefix when sending completion requests.
You can specify the model either by its complete name or by its simplified name, e.g. `llama3.1:70b`. A proxy config sketch follows the table below.
| Model Name | Simplified Name | Function Call |
| -------------------------------------------------------- | -------------------------------- | ------------------------------------------------------- |
| neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8 | llama3.1 or llama3.1:8b | `completion(model="galadriel/llama3.1", messages)` |
| neuralmagic/Meta-Llama-3.1-70B-Instruct-quantized.w4a16 | llama3.1:70b | `completion(model="galadriel/llama3.1:70b", messages)` |
| neuralmagic/Meta-Llama-3.1-405B-Instruct-quantized.w4a16 | llama3.1:405b | `completion(model="galadriel/llama3.1:405b", messages)` |
| neuralmagic/Mistral-Nemo-Instruct-2407-quantized.w4a16 | mistral-nemo or mistral-nemo:12b | `completion(model="galadriel/mistral-nemo", messages)` |
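To route Galadriel models through the LiteLLM proxy, a minimal config sketch (the model alias is arbitrary; the env-var wiring follows the standard proxy pattern):

```yaml
model_list:
  - model_name: galadriel-llama
    litellm_params:
      model: galadriel/llama3.1
      api_key: os.environ/GALADRIEL_API_KEY
```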


@ -134,28 +134,9 @@ general_settings:
| content_policy_fallbacks | array of objects | Fallbacks to use when a ContentPolicyViolationError is encountered. [Further docs](./reliability#content-policy-fallbacks) |
| context_window_fallbacks | array of objects | Fallbacks to use when a ContextWindowExceededError is encountered. [Further docs](./reliability#context-window-fallbacks) |
| cache | boolean | If true, enables caching. [Further docs](./caching) |
| cache_params | object | Parameters for the cache. [Further docs](./caching) |
| cache_params.type | string | The type of cache to initialize. Can be one of ["local", "redis", "redis-semantic", "s3", "disk", "qdrant-semantic"]. Defaults to "redis". [Further docs](./caching) |
| cache_params.host | string | The host address for the Redis cache. Required if type is "redis". |
| cache_params.port | integer | The port number for the Redis cache. Required if type is "redis". |
| cache_params.password | string | The password for the Redis cache. Required if type is "redis". |
| cache_params.namespace | string | The namespace for the Redis cache. |
| cache_params.redis_startup_nodes | array of objects | Redis Cluster Settings. [Further docs](./caching) |
| cache_params.service_name | string | Redis Sentinel Settings. [Further docs](./caching) |
| cache_params.sentinel_nodes | array of arrays | Redis Sentinel Settings. [Further docs](./caching) |
| cache_params.ttl | integer | The time (in seconds) to store entries in cache. |
| cache_params.qdrant_semantic_cache_embedding_model | string | The embedding model to use for qdrant semantic cache. |
| cache_params.qdrant_collection_name | string | The name of the collection to use for qdrant semantic cache. |
| cache_params.qdrant_quantization_config | string | The quantization configuration for the qdrant semantic cache. |
| cache_params.similarity_threshold | float | The similarity threshold for the semantic cache. |
| cache_params.s3_bucket_name | string | The name of the S3 bucket to use for the semantic cache. |
| cache_params.s3_region_name | string | The region name for the S3 bucket. |
| cache_params.s3_aws_access_key_id | string | The AWS access key ID for the S3 bucket. |
| cache_params.s3_aws_secret_access_key | string | The AWS secret access key for the S3 bucket. |
| cache_params.s3_endpoint_url | string | Optional - The endpoint URL for the S3 bucket. |
| cache_params.supported_call_types | array of strings | The types of calls to cache. [Further docs](./caching) |
| cache_params.mode | string | The mode of the cache. [Further docs](./caching) |
| cache_params | object | Parameters for the cache. [Further docs](./caching#supported-cache_params-on-proxy-configyaml) |
| disable_end_user_cost_tracking | boolean | If true, turns off end user cost tracking on prometheus metrics + litellm spend logs table on proxy. |
| disable_end_user_cost_tracking_prometheus_only | boolean | If true, turns off end user cost tracking on prometheus metrics only. |
| key_generation_settings | object | Restricts who can generate keys. [Further docs](./virtual_keys.md#restricting-key-generation) |
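For example, a minimal sketch that turns off end-user cost tracking for Prometheus metrics only (assuming these flags are set under `litellm_settings` in the proxy `config.yaml`):

```yaml
litellm_settings:
  disable_end_user_cost_tracking_prometheus_only: true
  # or, to turn off end-user cost tracking everywhere (spend logs + prometheus):
  # disable_end_user_cost_tracking: true
```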
### general_settings - Reference


@ -507,6 +507,11 @@ curl -X GET "http://0.0.0.0:4000/spend/logs?request_id=<your-call-id" \ # e.g.:
## Enforce Required Params for LLM Requests
Use this when you want to enforce that all requests include certain params. Example: you need all requests to include the `user` and `["metadata"]["generation_name"]` params.
<Tabs>
<TabItem value="config" label="Set on Config">
**Step 1** Define all Params you want to enforce on config.yaml
This means `["user"]` and `["metadata"]["generation_name"]` are required in all LLM requests to LiteLLM
@ -518,8 +523,21 @@ general_settings:
- user
- metadata.generation_name
```
</TabItem>
Start LiteLLM Proxy
<TabItem value="key" label="Set on Key">
```bash
curl -L -X POST 'http://0.0.0.0:4000/key/generate' \
-H 'Authorization: Bearer sk-1234' \
-H 'Content-Type: application/json' \
-d '{
"enforced_params": ["user", "metadata.generation_name"]
}'
```
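You can also set `enforced_params` when updating an existing key (the key value below is a placeholder):

```bash
curl -L -X POST 'http://0.0.0.0:4000/key/update' \
-H 'Authorization: Bearer sk-1234' \
-H 'Content-Type: application/json' \
-d '{
    "key": "sk-existing-key",
    "enforced_params": ["user", "metadata.generation_name"]
}'
```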
</TabItem>
</Tabs>
**Step 2: Verify this works**
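For example, a request that omits the enforced `user` param should be rejected with an error naming the missing param (a sketch; the model name and exact error text are illustrative):

```bash
curl -X POST 'http://0.0.0.0:4000/chat/completions' \
-H 'Content-Type: application/json' \
-H 'Authorization: Bearer sk-1234' \
-d '{
    "model": "gpt-4",
    "messages": [{"role": "user", "content": "hi"}]
}'
```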


@ -828,8 +828,18 @@ litellm_settings:
#### Spec
```python
key_generation_settings: Optional[StandardKeyGenerationConfig] = None
```
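For example, a minimal sketch restricting team-key generation to team admins (assuming this lives under `litellm_settings` in the proxy `config.yaml`; the field values are illustrative):

```yaml
litellm_settings:
  key_generation_settings:
    team_key_generation:
      allowed_team_member_roles: ["admin"]
      required_params: ["tags"]
    personal_key_generation:
      required_params: ["key_alias"]
```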
#### Types
```python
class StandardKeyGenerationConfig(TypedDict, total=False):
team_key_generation: TeamUIKeyGenerationConfig
personal_key_generation: PersonalUIKeyGenerationConfig
class TeamUIKeyGenerationConfig(TypedDict):
allowed_team_member_roles: List[str]
allowed_team_member_roles: List[str] # either 'user' or 'admin'
required_params: List[str] # require params on `/key/generate` to be set if a team key (team_id in request) is being generated
@ -838,11 +848,6 @@ class PersonalUIKeyGenerationConfig(TypedDict):
required_params: List[str] # require params on `/key/generate` to be set if a personal key (no team_id in request) is being generated
class StandardKeyGenerationConfig(TypedDict, total=False):
team_key_generation: TeamUIKeyGenerationConfig
personal_key_generation: PersonalUIKeyGenerationConfig
class LitellmUserRoles(str, enum.Enum):
"""
Admin Roles:


@ -177,6 +177,7 @@ const sidebars = {
"providers/ollama",
"providers/perplexity",
"providers/friendliai",
"providers/galadriel",
"providers/groq",
"providers/github",
"providers/deepseek",
@ -210,6 +211,7 @@ const sidebars = {
"completion/provider_specific_params",
"guides/finetuned_models",
"completion/audio",
"completion/document_understanding",
"completion/vision",
"completion/json_mode",
"completion/prompt_caching",


@ -288,6 +288,7 @@ max_internal_user_budget: Optional[float] = None
internal_user_budget_duration: Optional[str] = None
max_end_user_budget: Optional[float] = None
disable_end_user_cost_tracking: Optional[bool] = None
disable_end_user_cost_tracking_prometheus_only: Optional[bool] = None
#### REQUEST PRIORITIZATION ####
priority_reservation: Optional[Dict[str, float]] = None
#### RELIABILITY ####
@ -385,6 +386,31 @@ organization = None
project = None
config_path = None
vertex_ai_safety_settings: Optional[dict] = None
BEDROCK_CONVERSE_MODELS = [
"anthropic.claude-3-5-haiku-20241022-v1:0",
"anthropic.claude-3-5-sonnet-20241022-v2:0",
"anthropic.claude-3-5-sonnet-20240620-v1:0",
"anthropic.claude-3-opus-20240229-v1:0",
"anthropic.claude-3-sonnet-20240229-v1:0",
"anthropic.claude-3-haiku-20240307-v1:0",
"anthropic.claude-v2",
"anthropic.claude-v2:1",
"anthropic.claude-v1",
"anthropic.claude-instant-v1",
"ai21.jamba-instruct-v1:0",
"meta.llama3-70b-instruct-v1:0",
"meta.llama3-8b-instruct-v1:0",
"meta.llama3-1-8b-instruct-v1:0",
"meta.llama3-1-70b-instruct-v1:0",
"meta.llama3-1-405b-instruct-v1:0",
"meta.llama3-70b-instruct-v1:0",
"mistral.mistral-large-2407-v1:0",
"meta.llama3-2-1b-instruct-v1:0",
"meta.llama3-2-3b-instruct-v1:0",
"meta.llama3-2-11b-instruct-v1:0",
"meta.llama3-2-90b-instruct-v1:0",
"meta.llama3-2-405b-instruct-v1:0",
]
####### COMPLETION MODELS ###################
open_ai_chat_completion_models: List = []
open_ai_text_completion_models: List = []
@ -412,6 +438,7 @@ ai21_chat_models: List = []
nlp_cloud_models: List = []
aleph_alpha_models: List = []
bedrock_models: List = []
bedrock_converse_models: List = BEDROCK_CONVERSE_MODELS
fireworks_ai_models: List = []
fireworks_ai_embedding_models: List = []
deepinfra_models: List = []
@ -431,6 +458,7 @@ groq_models: List = []
azure_models: List = []
anyscale_models: List = []
cerebras_models: List = []
galadriel_models: List = []
def add_known_models():
@ -491,6 +519,8 @@ def add_known_models():
aleph_alpha_models.append(key)
elif value.get("litellm_provider") == "bedrock":
bedrock_models.append(key)
elif value.get("litellm_provider") == "bedrock_converse":
bedrock_converse_models.append(key)
elif value.get("litellm_provider") == "deepinfra":
deepinfra_models.append(key)
elif value.get("litellm_provider") == "perplexity":
@ -535,6 +565,8 @@ def add_known_models():
anyscale_models.append(key)
elif value.get("litellm_provider") == "cerebras":
cerebras_models.append(key)
elif value.get("litellm_provider") == "galadriel":
galadriel_models.append(key)
add_known_models()
@ -554,6 +586,7 @@ openai_compatible_endpoints: List = [
"inference.friendli.ai/v1",
"api.sambanova.ai/v1",
"api.x.ai/v1",
"api.galadriel.ai/v1",
]
# this is maintained for Exception Mapping
@ -581,6 +614,7 @@ openai_compatible_providers: List = [
"litellm_proxy",
"hosted_vllm",
"lm_studio",
"galadriel",
]
openai_text_completion_compatible_providers: List = (
[ # providers that support `/v1/completions`
@ -794,6 +828,7 @@ model_list = (
+ azure_models
+ anyscale_models
+ cerebras_models
+ galadriel_models
)
@ -860,6 +895,7 @@ class LlmProviders(str, Enum):
LITELLM_PROXY = "litellm_proxy"
HOSTED_VLLM = "hosted_vllm"
LM_STUDIO = "lm_studio"
GALADRIEL = "galadriel"
provider_list: List[Union[LlmProviders, str]] = list(LlmProviders)
@ -882,7 +918,7 @@ models_by_provider: dict = {
+ vertex_vision_models
+ vertex_language_models,
"ai21": ai21_models,
"bedrock": bedrock_models,
"bedrock": bedrock_models + bedrock_converse_models,
"petals": petals_models,
"ollama": ollama_models,
"deepinfra": deepinfra_models,
@ -908,6 +944,7 @@ models_by_provider: dict = {
"azure": azure_models,
"anyscale": anyscale_models,
"cerebras": cerebras_models,
"galadriel": galadriel_models,
}
# mapping for those models which have larger equivalents
@ -1067,9 +1104,6 @@ from .llms.bedrock.chat.invoke_handler import (
AmazonConverseConfig,
bedrock_tool_name_mappings,
)
from .llms.bedrock.chat.converse_handler import (
BEDROCK_CONVERSE_MODELS,
)
from .llms.bedrock.common_utils import (
AmazonTitanConfig,
AmazonAI21Config,


@ -365,7 +365,9 @@ class PrometheusLogger(CustomLogger):
model = kwargs.get("model", "")
litellm_params = kwargs.get("litellm_params", {}) or {}
_metadata = litellm_params.get("metadata", {})
end_user_id = get_end_user_id_for_cost_tracking(litellm_params)
end_user_id = get_end_user_id_for_cost_tracking(
litellm_params, service_type="prometheus"
)
user_id = standard_logging_payload["metadata"]["user_api_key_user_id"]
user_api_key = standard_logging_payload["metadata"]["user_api_key_hash"]
user_api_key_alias = standard_logging_payload["metadata"]["user_api_key_alias"]
@ -668,7 +670,9 @@ class PrometheusLogger(CustomLogger):
"standard_logging_object", {}
)
litellm_params = kwargs.get("litellm_params", {}) or {}
end_user_id = get_end_user_id_for_cost_tracking(litellm_params)
end_user_id = get_end_user_id_for_cost_tracking(
litellm_params, service_type="prometheus"
)
user_id = standard_logging_payload["metadata"]["user_api_key_user_id"]
user_api_key = standard_logging_payload["metadata"]["user_api_key_hash"]
user_api_key_alias = standard_logging_payload["metadata"]["user_api_key_alias"]


@ -177,6 +177,9 @@ def get_llm_provider( # noqa: PLR0915
dynamic_api_key = get_secret_str(
"FRIENDLIAI_API_KEY"
) or get_secret("FRIENDLI_TOKEN")
elif endpoint == "api.galadriel.com/v1":
custom_llm_provider = "galadriel"
dynamic_api_key = get_secret_str("GALADRIEL_API_KEY")
if api_base is not None and not isinstance(api_base, str):
raise Exception(
@ -526,6 +529,11 @@ def _get_openai_compatible_provider_info( # noqa: PLR0915
or get_secret_str("FRIENDLIAI_API_KEY")
or get_secret_str("FRIENDLI_TOKEN")
)
elif custom_llm_provider == "galadriel":
api_base = (
api_base or get_secret("GALADRIEL_API_BASE") or "https://api.galadriel.com/v1"
) # type: ignore
dynamic_api_key = api_key or get_secret_str("GALADRIEL_API_KEY")
if api_base is not None and not isinstance(api_base, str):
raise Exception("api base needs to be a string. api_base={}".format(api_base))
if dynamic_api_key is not None and not isinstance(dynamic_api_key, str):


@ -18,32 +18,6 @@ from ...base_aws_llm import BaseAWSLLM
from ..common_utils import BedrockError
from .invoke_handler import AWSEventStreamDecoder, MockResponseIterator, make_call
BEDROCK_CONVERSE_MODELS = [
"anthropic.claude-3-5-haiku-20241022-v1:0",
"anthropic.claude-3-5-sonnet-20241022-v2:0",
"anthropic.claude-3-5-sonnet-20240620-v1:0",
"anthropic.claude-3-opus-20240229-v1:0",
"anthropic.claude-3-sonnet-20240229-v1:0",
"anthropic.claude-3-haiku-20240307-v1:0",
"anthropic.claude-v2",
"anthropic.claude-v2:1",
"anthropic.claude-v1",
"anthropic.claude-instant-v1",
"ai21.jamba-instruct-v1:0",
"meta.llama3-70b-instruct-v1:0",
"meta.llama3-8b-instruct-v1:0",
"meta.llama3-1-8b-instruct-v1:0",
"meta.llama3-1-70b-instruct-v1:0",
"meta.llama3-1-405b-instruct-v1:0",
"meta.llama3-70b-instruct-v1:0",
"mistral.mistral-large-2407-v1:0",
"meta.llama3-2-1b-instruct-v1:0",
"meta.llama3-2-3b-instruct-v1:0",
"meta.llama3-2-11b-instruct-v1:0",
"meta.llama3-2-90b-instruct-v1:0",
"meta.llama3-2-405b-instruct-v1:0",
]
def make_sync_call(
client: Optional[HTTPHandler],
@ -53,6 +27,8 @@ def make_sync_call(
model: str,
messages: list,
logging_obj,
json_mode: Optional[bool] = False,
fake_stream: bool = False,
):
if client is None:
client = _get_httpx_client() # Create a new client if none provided
@ -61,13 +37,13 @@ def make_sync_call(
api_base,
headers=headers,
data=data,
stream=True if "ai21" not in api_base else False,
stream=not fake_stream,
)
if response.status_code != 200:
raise BedrockError(status_code=response.status_code, message=response.read())
if "ai21" in api_base:
if fake_stream:
model_response: (
ModelResponse
) = litellm.AmazonConverseConfig()._transform_response(
@ -83,7 +59,9 @@ def make_sync_call(
print_verbose=litellm.print_verbose,
encoding=litellm.encoding,
) # type: ignore
completion_stream: Any = MockResponseIterator(model_response=model_response)
completion_stream: Any = MockResponseIterator(
model_response=model_response, json_mode=json_mode
)
else:
decoder = AWSEventStreamDecoder(model=model)
completion_stream = decoder.iter_bytes(response.iter_bytes(chunk_size=1024))
@ -130,6 +108,8 @@ class BedrockConverseLLM(BaseAWSLLM):
logger_fn=None,
headers={},
client: Optional[AsyncHTTPHandler] = None,
fake_stream: bool = False,
json_mode: Optional[bool] = False,
) -> CustomStreamWrapper:
completion_stream = await make_call(
@ -140,6 +120,8 @@ class BedrockConverseLLM(BaseAWSLLM):
model=model,
messages=messages,
logging_obj=logging_obj,
fake_stream=fake_stream,
json_mode=json_mode,
)
streaming_response = CustomStreamWrapper(
completion_stream=completion_stream,
@ -231,12 +213,15 @@ class BedrockConverseLLM(BaseAWSLLM):
## SETUP ##
stream = optional_params.pop("stream", None)
modelId = optional_params.pop("model_id", None)
fake_stream = optional_params.pop("fake_stream", False)
json_mode = optional_params.get("json_mode", False)
if modelId is not None:
modelId = self.encode_model_id(model_id=modelId)
else:
modelId = model
provider = model.split(".")[0]
if stream is True and "ai21" in modelId:
fake_stream = True
## CREDENTIALS ##
# pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
@ -290,7 +275,7 @@ class BedrockConverseLLM(BaseAWSLLM):
aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
aws_region_name=aws_region_name,
)
if (stream is not None and stream is True) and provider != "ai21":
if (stream is not None and stream is True) and not fake_stream:
endpoint_url = f"{endpoint_url}/model/{modelId}/converse-stream"
proxy_endpoint_url = f"{proxy_endpoint_url}/model/{modelId}/converse-stream"
else:
@ -355,6 +340,8 @@ class BedrockConverseLLM(BaseAWSLLM):
headers=prepped.headers,
timeout=timeout,
client=client,
json_mode=json_mode,
fake_stream=fake_stream,
) # type: ignore
### ASYNC COMPLETION
return self.async_completion(
@ -398,6 +385,8 @@ class BedrockConverseLLM(BaseAWSLLM):
model=model,
messages=messages,
logging_obj=logging_obj,
json_mode=json_mode,
fake_stream=fake_stream,
)
streaming_response = CustomStreamWrapper(
completion_stream=completion_stream,


@ -134,6 +134,43 @@ class AmazonConverseConfig:
def get_supported_image_types(self) -> List[str]:
return ["png", "jpeg", "gif", "webp"]
def get_supported_document_types(self) -> List[str]:
return ["pdf", "csv", "doc", "docx", "xls", "xlsx", "html", "txt", "md"]
def _create_json_tool_call_for_response_format(
self,
json_schema: Optional[dict] = None,
schema_name: str = "json_tool_call",
) -> ChatCompletionToolParam:
"""
Handles creating a tool call for getting responses in JSON format.
Args:
json_schema (Optional[dict]): The JSON schema the response should be in
Returns:
AnthropicMessagesTool: The tool call to send to Anthropic API to get responses in JSON format
"""
if json_schema is None:
# Anthropic raises a 400 BadRequest error if properties is passed as None
# see usage with additionalProperties (Example 5) https://github.com/anthropics/anthropic-cookbook/blob/main/tool_use/extracting_structured_json.ipynb
_input_schema = {
"type": "object",
"additionalProperties": True,
"properties": {},
}
else:
_input_schema = json_schema
_tool = ChatCompletionToolParam(
type="function",
function=ChatCompletionToolParamFunctionChunk(
name=schema_name, parameters=_input_schema
),
)
return _tool
def map_openai_params(
self,
model: str,
@ -160,31 +197,20 @@ class AmazonConverseConfig:
- You should set tool_choice (see Forcing tool use) to instruct the model to explicitly use that tool
- Remember that the model will pass the input to the tool, so the name of the tool and description should be from the models perspective.
"""
if json_schema is not None:
_tool_choice = self.map_tool_choice_values(
model=model, tool_choice="required", drop_params=drop_params # type: ignore
_tool_choice = {"name": schema_name, "type": "tool"}
_tool = self._create_json_tool_call_for_response_format(
json_schema=json_schema,
schema_name=schema_name if schema_name != "" else "json_tool_call",
)
optional_params["tools"] = [_tool]
optional_params["tool_choice"] = ToolChoiceValuesBlock(
tool=SpecificToolChoiceBlock(
name=schema_name if schema_name != "" else "json_tool_call"
)
_tool = ChatCompletionToolParam(
type="function",
function=ChatCompletionToolParamFunctionChunk(
name=schema_name, parameters=json_schema
),
)
optional_params["tools"] = [_tool]
optional_params["tool_choice"] = _tool_choice
optional_params["json_mode"] = True
else:
if litellm.drop_params is True or drop_params is True:
pass
else:
raise litellm.utils.UnsupportedParamsError(
message="Bedrock doesn't support response_format={}. To drop it from the call, set `litellm.drop_params = True.".format(
value
),
status_code=400,
)
)
optional_params["json_mode"] = True
if non_default_params.get("stream", False) is True:
optional_params["fake_stream"] = True
if param == "max_tokens" or param == "max_completion_tokens":
optional_params["maxTokens"] = value
if param == "stream":
@ -330,7 +356,6 @@ class AmazonConverseConfig:
print_verbose,
encoding,
) -> Union[ModelResponse, CustomStreamWrapper]:
## LOGGING
if logging_obj is not None:
logging_obj.post_call(
@ -468,6 +493,9 @@ class AmazonConverseConfig:
AND "meta.llama3-2-11b-instruct-v1:0" -> "meta.llama3-2-11b-instruct-v1"
"""
if model.startswith("converse/"):
model = model.split("/")[1]
potential_region = model.split(".", 1)[0]
if potential_region in self._supported_cross_region_inference_region():
return model.split(".", 1)[1]


@ -182,6 +182,8 @@ async def make_call(
model: str,
messages: list,
logging_obj,
fake_stream: bool = False,
json_mode: Optional[bool] = False,
):
try:
if client is None:
@ -193,13 +195,13 @@ async def make_call(
api_base,
headers=headers,
data=data,
stream=True if "ai21" not in api_base else False,
stream=not fake_stream,
)
if response.status_code != 200:
raise BedrockError(status_code=response.status_code, message=response.text)
if "ai21" in api_base:
if fake_stream:
model_response: (
ModelResponse
) = litellm.AmazonConverseConfig()._transform_response(
@ -215,7 +217,9 @@ async def make_call(
print_verbose=litellm.print_verbose,
encoding=litellm.encoding,
) # type: ignore
completion_stream: Any = MockResponseIterator(model_response=model_response)
completion_stream: Any = MockResponseIterator(
model_response=model_response, json_mode=json_mode
)
else:
decoder = AWSEventStreamDecoder(model=model)
completion_stream = decoder.aiter_bytes(
@ -1028,6 +1032,7 @@ class BedrockLLM(BaseAWSLLM):
model=model,
messages=messages,
logging_obj=logging_obj,
fake_stream=True if "ai21" in api_base else False,
),
model=model,
custom_llm_provider="bedrock",
@ -1271,21 +1276,58 @@ class AWSEventStreamDecoder:
class MockResponseIterator: # for returning ai21 streaming responses
def __init__(self, model_response):
def __init__(self, model_response, json_mode: Optional[bool] = False):
self.model_response = model_response
self.json_mode = json_mode
self.is_done = False
# Sync iterator
def __iter__(self):
return self
def _chunk_parser(self, chunk_data: ModelResponse) -> GChunk:
def _handle_json_mode_chunk(
self, text: str, tool_calls: Optional[List[ChatCompletionToolCallChunk]]
) -> Tuple[str, Optional[ChatCompletionToolCallChunk]]:
"""
If JSON mode is enabled, convert the tool call to a message.
Bedrock returns the JSON schema as part of the tool call
OpenAI returns the JSON schema as part of the content, this handles placing it in the content
Args:
text: str
tool_use: Optional[ChatCompletionToolCallChunk]
Returns:
Tuple[str, Optional[ChatCompletionToolCallChunk]]
text: The text to use in the content
tool_use: The ChatCompletionToolCallChunk to use in the chunk response
"""
tool_use: Optional[ChatCompletionToolCallChunk] = None
if self.json_mode is True and tool_calls is not None:
message = litellm.AnthropicConfig()._convert_tool_response_to_message(
tool_calls=tool_calls
)
if message is not None:
text = message.content or ""
tool_use = None
elif tool_calls is not None and len(tool_calls) > 0:
tool_use = tool_calls[0]
return text, tool_use
def _chunk_parser(self, chunk_data: ModelResponse) -> GChunk:
try:
chunk_usage: litellm.Usage = getattr(chunk_data, "usage")
text = chunk_data.choices[0].message.content or "" # type: ignore
tool_use = None
if self.json_mode is True:
text, tool_use = self._handle_json_mode_chunk(
text=text,
tool_calls=chunk_data.choices[0].message.tool_calls, # type: ignore
)
processed_chunk = GChunk(
text=chunk_data.choices[0].message.content or "", # type: ignore
tool_use=None,
text=text,
tool_use=tool_use,
is_finished=True,
finish_reason=map_finish_reason(
finish_reason=chunk_data.choices[0].finish_reason or ""
@ -1298,8 +1340,8 @@ class MockResponseIterator: # for returning ai21 streaming responses
index=0,
)
return processed_chunk
except Exception:
raise ValueError(f"Failed to decode chunk: {chunk_data}")
except Exception as e:
raise ValueError(f"Failed to decode chunk: {chunk_data}. Error: {e}")
def __next__(self):
if self.is_done:


@ -2162,8 +2162,9 @@ def stringify_json_tool_call_content(messages: List) -> List:
###### AMAZON BEDROCK #######
from litellm.types.llms.bedrock import ContentBlock as BedrockContentBlock
from litellm.types.llms.bedrock import DocumentBlock as BedrockDocumentBlock
from litellm.types.llms.bedrock import ImageBlock as BedrockImageBlock
from litellm.types.llms.bedrock import ImageSourceBlock as BedrockImageSourceBlock
from litellm.types.llms.bedrock import SourceBlock as BedrockSourceBlock
from litellm.types.llms.bedrock import ToolBlock as BedrockToolBlock
from litellm.types.llms.bedrock import (
ToolChoiceValuesBlock as BedrockToolChoiceValuesBlock,
@ -2210,7 +2211,9 @@ def get_image_details(image_url) -> Tuple[str, str]:
raise e
def _process_bedrock_converse_image_block(image_url: str) -> BedrockImageBlock:
def _process_bedrock_converse_image_block(
image_url: str,
) -> BedrockContentBlock:
if "base64" in image_url:
# Case 1: Images with base64 encoding
import base64
@ -2228,12 +2231,17 @@ def _process_bedrock_converse_image_block(image_url: str) -> BedrockImageBlock:
else:
mime_type = "image/jpeg"
image_format = "jpeg"
_blob = BedrockImageSourceBlock(bytes=img_without_base_64)
_blob = BedrockSourceBlock(bytes=img_without_base_64)
supported_image_formats = (
litellm.AmazonConverseConfig().get_supported_image_types()
)
supported_document_types = (
litellm.AmazonConverseConfig().get_supported_document_types()
)
if image_format in supported_image_formats:
return BedrockImageBlock(source=_blob, format=image_format) # type: ignore
return BedrockContentBlock(image=BedrockImageBlock(source=_blob, format=image_format)) # type: ignore
elif image_format in supported_document_types:
return BedrockContentBlock(document=BedrockDocumentBlock(source=_blob, format=image_format, name="DocumentPDFmessages_{}".format(str(uuid.uuid4())))) # type: ignore
else:
# Handle the case when the image format is not supported
raise ValueError(
@ -2244,12 +2252,17 @@ def _process_bedrock_converse_image_block(image_url: str) -> BedrockImageBlock:
elif "https:/" in image_url:
# Case 2: Images with direct links
image_bytes, image_format = get_image_details(image_url)
_blob = BedrockImageSourceBlock(bytes=image_bytes)
_blob = BedrockSourceBlock(bytes=image_bytes)
supported_image_formats = (
litellm.AmazonConverseConfig().get_supported_image_types()
)
supported_document_types = (
litellm.AmazonConverseConfig().get_supported_document_types()
)
if image_format in supported_image_formats:
return BedrockImageBlock(source=_blob, format=image_format) # type: ignore
return BedrockContentBlock(image=BedrockImageBlock(source=_blob, format=image_format)) # type: ignore
elif image_format in supported_document_types:
return BedrockContentBlock(document=BedrockDocumentBlock(source=_blob, format=image_format, name="DocumentPDFmessages_{}".format(str(uuid.uuid4())))) # type: ignore
else:
# Handle the case when the image format is not supported
raise ValueError(
@ -2464,11 +2477,14 @@ def _bedrock_converse_messages_pt( # noqa: PLR0915
_part = BedrockContentBlock(text=element["text"])
_parts.append(_part)
elif element["type"] == "image_url":
image_url = element["image_url"]["url"]
if isinstance(element["image_url"], dict):
image_url = element["image_url"]["url"]
else:
image_url = element["image_url"]
_part = _process_bedrock_converse_image_block( # type: ignore
image_url=image_url
)
_parts.append(BedrockContentBlock(image=_part)) # type: ignore
_parts.append(_part) # type: ignore
user_content.extend(_parts)
else:
_part = BedrockContentBlock(text=messages[msg_i]["content"])
@ -2539,13 +2555,14 @@ def _bedrock_converse_messages_pt( # noqa: PLR0915
assistants_part = BedrockContentBlock(text=element["text"])
assistants_parts.append(assistants_part)
elif element["type"] == "image_url":
image_url = element["image_url"]["url"]
if isinstance(element["image_url"], dict):
image_url = element["image_url"]["url"]
else:
image_url = element["image_url"]
assistants_part = _process_bedrock_converse_image_block( # type: ignore
image_url=image_url
)
assistants_parts.append(
BedrockContentBlock(image=assistants_part) # type: ignore
)
assistants_parts.append(assistants_part)
assistant_content.extend(assistants_parts)
elif messages[msg_i].get("content", None) is not None and isinstance(
messages[msg_i]["content"], str


@ -2603,7 +2603,10 @@ def completion( # type: ignore # noqa: PLR0915
base_model = litellm.AmazonConverseConfig()._get_base_model(model)
if base_model in litellm.BEDROCK_CONVERSE_MODELS:
if base_model in litellm.bedrock_converse_models or model.startswith(
"converse/"
):
model = model.replace("converse/", "")
response = bedrock_converse_chat_completion.completion(
model=model,
messages=messages,
@ -2622,6 +2625,7 @@ def completion( # type: ignore # noqa: PLR0915
api_base=api_base,
)
else:
model = model.replace("invoke/", "")
response = bedrock_chat_completion.completion(
model=model,
messages=messages,


@ -4795,6 +4795,42 @@
"mode": "chat",
"supports_function_calling": true
},
"amazon.nova-micro-v1:0": {
"max_tokens": 4096,
"max_input_tokens": 300000,
"max_output_tokens": 4096,
"input_cost_per_token": 0.000000035,
"output_cost_per_token": 0.00000014,
"litellm_provider": "bedrock_converse",
"mode": "chat",
"supports_function_calling": true,
"supports_vision": true,
"supports_pdf_input": true
},
"amazon.nova-lite-v1:0": {
"max_tokens": 4096,
"max_input_tokens": 128000,
"max_output_tokens": 4096,
"input_cost_per_token": 0.00000006,
"output_cost_per_token": 0.00000024,
"litellm_provider": "bedrock_converse",
"mode": "chat",
"supports_function_calling": true,
"supports_vision": true,
"supports_pdf_input": true
},
"amazon.nova-pro-v1:0": {
"max_tokens": 4096,
"max_input_tokens": 300000,
"max_output_tokens": 4096,
"input_cost_per_token": 0.0000008,
"output_cost_per_token": 0.0000032,
"litellm_provider": "bedrock_converse",
"mode": "chat",
"supports_function_calling": true,
"supports_vision": true,
"supports_pdf_input": true
},
"anthropic.claude-3-sonnet-20240229-v1:0": {
"max_tokens": 4096,
"max_input_tokens": 200000,


@ -3,6 +3,12 @@ model_list:
- model_name: gpt-4
litellm_params:
model: gpt-4
rpm: 1
- model_name: gpt-4
litellm_params:
model: azure/chatgpt-v-2
api_key: os.environ/AZURE_API_KEY
api_base: os.environ/AZURE_API_BASE
- model_name: rerank-model
litellm_params:
model: jina_ai/jina-reranker-v2-base-multilingual


@ -667,6 +667,7 @@ class _GenerateKeyRequest(GenerateRequestBase):
class GenerateKeyRequest(_GenerateKeyRequest):
tags: Optional[List[str]] = None
enforced_params: Optional[List[str]] = None
class GenerateKeyResponse(_GenerateKeyRequest):
@ -2190,4 +2191,5 @@ LiteLLM_ManagementEndpoint_MetadataFields = [
"model_tpm_limit",
"guardrails",
"tags",
"enforced_params",
]


@ -156,40 +156,6 @@ def common_checks( # noqa: PLR0915
raise Exception(
f"'user' param not passed in. 'enforce_user_param'={general_settings['enforce_user_param']}"
)
if general_settings.get("enforced_params", None) is not None:
# Enterprise ONLY Feature
# we already validate if user is premium_user when reading the config
# Add an extra premium_usercheck here too, just incase
from litellm.proxy.proxy_server import CommonProxyErrors, premium_user
if premium_user is not True:
raise ValueError(
"Trying to use `enforced_params`"
+ CommonProxyErrors.not_premium_user.value
)
if RouteChecks.is_llm_api_route(route=route):
# loop through each enforced param
# example enforced_params ['user', 'metadata', 'metadata.generation_name']
for enforced_param in general_settings["enforced_params"]:
_enforced_params = enforced_param.split(".")
if len(_enforced_params) == 1:
if _enforced_params[0] not in request_body:
raise ValueError(
f"BadRequest please pass param={_enforced_params[0]} in request body. This is a required param"
)
elif len(_enforced_params) == 2:
# this is a scenario where user requires request['metadata']['generation_name'] to exist
if _enforced_params[0] not in request_body:
raise ValueError(
f"BadRequest please pass param={_enforced_params[0]} in request body. This is a required param"
)
if _enforced_params[1] not in request_body[_enforced_params[0]]:
raise ValueError(
f"BadRequest please pass param=[{_enforced_params[0]}][{_enforced_params[1]}] in request body. This is a required param"
)
pass
# 7. [OPTIONAL] If 'litellm.max_budget' is set (>0), is proxy under budget
if (
litellm.max_budget > 0


@ -587,6 +587,16 @@ async def add_litellm_data_to_request( # noqa: PLR0915
f"[PROXY]returned data from litellm_pre_call_utils: {data}"
)
## ENFORCED PARAMS CHECK
# loop through each enforced param
# example enforced_params ['user', 'metadata', 'metadata.generation_name']
_enforced_params_check(
request_body=data,
general_settings=general_settings,
user_api_key_dict=user_api_key_dict,
premium_user=premium_user,
)
end_time = time.time()
await service_logger_obj.async_service_success_hook(
service=ServiceTypes.PROXY_PRE_CALL,
@ -599,6 +609,64 @@ async def add_litellm_data_to_request( # noqa: PLR0915
return data
def _get_enforced_params(
general_settings: Optional[dict], user_api_key_dict: UserAPIKeyAuth
) -> Optional[list]:
enforced_params: Optional[list] = None
if general_settings is not None:
enforced_params = general_settings.get("enforced_params")
if "service_account_settings" in general_settings:
service_account_settings = general_settings["service_account_settings"]
if "enforced_params" in service_account_settings:
if enforced_params is None:
enforced_params = []
enforced_params.extend(service_account_settings["enforced_params"])
if user_api_key_dict.metadata.get("enforced_params", None) is not None:
if enforced_params is None:
enforced_params = []
enforced_params.extend(user_api_key_dict.metadata["enforced_params"])
return enforced_params
def _enforced_params_check(
request_body: dict,
general_settings: Optional[dict],
user_api_key_dict: UserAPIKeyAuth,
premium_user: bool,
) -> bool:
"""
If enforced params are set, check if the request body contains the enforced params.
"""
enforced_params: Optional[list] = _get_enforced_params(
general_settings=general_settings, user_api_key_dict=user_api_key_dict
)
if enforced_params is None:
return True
if enforced_params is not None and premium_user is not True:
raise ValueError(
f"Enforced Params is an Enterprise feature. Enforced Params: {enforced_params}. {CommonProxyErrors.not_premium_user.value}"
)
for enforced_param in enforced_params:
_enforced_params = enforced_param.split(".")
if len(_enforced_params) == 1:
if _enforced_params[0] not in request_body:
raise ValueError(
f"BadRequest please pass param={_enforced_params[0]} in request body. This is a required param"
)
elif len(_enforced_params) == 2:
# this is a scenario where user requires request['metadata']['generation_name'] to exist
if _enforced_params[0] not in request_body:
raise ValueError(
f"BadRequest please pass param={_enforced_params[0]} in request body. This is a required param"
)
if _enforced_params[1] not in request_body[_enforced_params[0]]:
raise ValueError(
f"BadRequest please pass param=[{_enforced_params[0]}][{_enforced_params[1]}] in request body. This is a required param"
)
return True
def move_guardrails_to_metadata(
data: dict,
_metadata_variable_name: str,


@ -255,6 +255,7 @@ async def generate_key_fn( # noqa: PLR0915
- tpm_limit: Optional[int] - Specify tpm limit for a given key (Tokens per minute)
- soft_budget: Optional[float] - Specify soft budget for a given key. Will trigger a slack alert when this soft budget is reached.
- tags: Optional[List[str]] - Tags for [tracking spend](https://litellm.vercel.app/docs/proxy/enterprise#tracking-spend-for-custom-tags) and/or doing [tag-based routing](https://litellm.vercel.app/docs/proxy/tag_routing).
- enforced_params: Optional[List[str]] - List of enforced params for the key (Enterprise only). [Docs](https://docs.litellm.ai/docs/proxy/enterprise#enforce-required-params-for-llm-requests)
Examples:
@ -459,7 +460,6 @@ def prepare_metadata_fields(
"""
Check LiteLLM_ManagementEndpoint_MetadataFields (proxy/_types.py) for fields that are allowed to be updated
"""
if "metadata" not in non_default_values: # allow user to set metadata to none
non_default_values["metadata"] = existing_metadata.copy()
@ -469,18 +469,8 @@ def prepare_metadata_fields(
try:
for k, v in data_json.items():
if k == "model_tpm_limit" or k == "model_rpm_limit":
if k not in casted_metadata or casted_metadata[k] is None:
casted_metadata[k] = {}
casted_metadata[k].update(v)
if k == "tags" or k == "guardrails":
if k not in casted_metadata or casted_metadata[k] is None:
casted_metadata[k] = []
seen = set(casted_metadata[k])
casted_metadata[k].extend(
x for x in v if x not in seen and not seen.add(x) # type: ignore
) # prevent duplicates from being added + maintain initial order
if k in LiteLLM_ManagementEndpoint_MetadataFields:
casted_metadata[k] = v
except Exception as e:
verbose_proxy_logger.exception(
@ -498,10 +488,9 @@ def prepare_key_update_data(
):
data_json: dict = data.model_dump(exclude_unset=True)
data_json.pop("key", None)
_metadata_fields = ["model_rpm_limit", "model_tpm_limit", "guardrails", "tags"]
non_default_values = {}
for k, v in data_json.items():
if k in _metadata_fields:
if k in LiteLLM_ManagementEndpoint_MetadataFields:
continue
non_default_values[k] = v
@ -556,6 +545,7 @@ async def update_key_fn(
- team_id: Optional[str] - Team ID associated with key
- models: Optional[list] - Model_name's a user is allowed to call
- tags: Optional[List[str]] - Tags for organizing keys (Enterprise only)
- enforced_params: Optional[List[str]] - List of enforced params for the key (Enterprise only). [Docs](https://docs.litellm.ai/docs/proxy/enterprise#enforce-required-params-for-llm-requests)
- spend: Optional[float] - Amount spent by key
- max_budget: Optional[float] - Max budget for key
- model_max_budget: Optional[dict] - Model-specific budgets {"gpt-4": 0.5, "claude-v1": 1.0}


@ -2940,6 +2940,7 @@ class Router:
remaining_retries=num_retries,
num_retries=num_retries,
healthy_deployments=_healthy_deployments,
all_deployments=_all_deployments,
)
await asyncio.sleep(retry_after)
@ -2972,6 +2973,7 @@ class Router:
remaining_retries=remaining_retries,
num_retries=num_retries,
healthy_deployments=_healthy_deployments,
all_deployments=_all_deployments,
)
await asyncio.sleep(_timeout)
@ -3149,6 +3151,7 @@ class Router:
remaining_retries: int,
num_retries: int,
healthy_deployments: Optional[List] = None,
all_deployments: Optional[List] = None,
) -> Union[int, float]:
"""
Calculate back-off, then retry
@ -3157,10 +3160,14 @@ class Router:
1. there are healthy deployments in the same model group
2. there are fallbacks for the completion call
"""
if (
## base case - single deployment
if all_deployments is not None and len(all_deployments) == 1:
pass
elif (
healthy_deployments is not None
and isinstance(healthy_deployments, list)
and len(healthy_deployments) > 1
and len(healthy_deployments) > 0
):
return 0
@ -3242,6 +3249,7 @@ class Router:
remaining_retries=num_retries,
num_retries=num_retries,
healthy_deployments=_healthy_deployments,
all_deployments=_all_deployments,
)
## LOGGING
@ -3276,6 +3284,7 @@ class Router:
remaining_retries=remaining_retries,
num_retries=num_retries,
healthy_deployments=_healthy_deployments,
all_deployments=_all_deployments,
)
time.sleep(_timeout)


@ -18,17 +18,24 @@ class SystemContentBlock(TypedDict):
text: str
class ImageSourceBlock(TypedDict):
class SourceBlock(TypedDict):
bytes: Optional[str] # base 64 encoded string
class ImageBlock(TypedDict):
format: Literal["png", "jpeg", "gif", "webp"]
source: ImageSourceBlock
source: SourceBlock
class DocumentBlock(TypedDict):
format: Literal["pdf", "csv", "doc", "docx", "xls", "xlsx", "html", "txt", "md"]
source: SourceBlock
name: str
class ToolResultContentBlock(TypedDict, total=False):
image: ImageBlock
document: DocumentBlock
json: dict
text: str
@ -48,6 +55,7 @@ class ToolUseBlock(TypedDict):
class ContentBlock(TypedDict, total=False):
text: str
image: ImageBlock
document: DocumentBlock
toolResult: ToolResultBlock
toolUse: ToolUseBlock


@ -106,6 +106,7 @@ class ModelInfo(TypedDict, total=False):
supports_prompt_caching: Optional[bool]
supports_audio_input: Optional[bool]
supports_audio_output: Optional[bool]
supports_pdf_input: Optional[bool]
tpm: Optional[int]
rpm: Optional[int]


@ -3142,7 +3142,7 @@ def get_optional_params( # noqa: PLR0915
model=model, custom_llm_provider=custom_llm_provider
)
base_model = litellm.AmazonConverseConfig()._get_base_model(model)
if base_model in litellm.BEDROCK_CONVERSE_MODELS:
if base_model in litellm.bedrock_converse_models:
_check_valid_arg(supported_params=supported_params)
optional_params = litellm.AmazonConverseConfig().map_openai_params(
model=model,
@ -4308,6 +4308,10 @@ def _strip_stable_vertex_version(model_name) -> str:
return re.sub(r"-\d+$", "", model_name)
def _strip_bedrock_region(model_name) -> str:
return litellm.AmazonConverseConfig()._get_base_model(model_name)
def _strip_openai_finetune_model_name(model_name: str) -> str:
"""
Strips the organization, custom suffix, and ID from an OpenAI fine-tuned model name.
@ -4324,16 +4328,50 @@ def _strip_openai_finetune_model_name(model_name: str) -> str:
return re.sub(r"(:[^:]+){3}$", "", model_name)
def _strip_model_name(model: str) -> str:
strip_version = _strip_stable_vertex_version(model_name=model)
strip_finetune = _strip_openai_finetune_model_name(model_name=strip_version)
return strip_finetune
def _strip_model_name(model: str, custom_llm_provider: Optional[str]) -> str:
if custom_llm_provider and custom_llm_provider == "bedrock":
strip_bedrock_region = _strip_bedrock_region(model_name=model)
return strip_bedrock_region
elif custom_llm_provider and (
custom_llm_provider == "vertex_ai" or custom_llm_provider == "gemini"
):
strip_version = _strip_stable_vertex_version(model_name=model)
return strip_version
else:
strip_finetune = _strip_openai_finetune_model_name(model_name=model)
return strip_finetune
def _get_model_info_from_model_cost(key: str) -> dict:
return litellm.model_cost[key]
def _check_provider_match(model_info: dict, custom_llm_provider: Optional[str]) -> bool:
"""
Check if the model info provider matches the custom provider.
"""
if custom_llm_provider and (
"litellm_provider" in model_info
and model_info["litellm_provider"] != custom_llm_provider
):
if custom_llm_provider == "vertex_ai" and model_info[
"litellm_provider"
].startswith("vertex_ai"):
return True
elif custom_llm_provider == "fireworks_ai" and model_info[
"litellm_provider"
].startswith("fireworks_ai"):
return True
elif custom_llm_provider == "bedrock" and model_info[
"litellm_provider"
].startswith("bedrock"):
return True
else:
return False
return True
def get_model_info( # noqa: PLR0915
model: str, custom_llm_provider: Optional[str] = None
) -> ModelInfo:
@ -4388,6 +4426,7 @@ def get_model_info( # noqa: PLR0915
supports_prompt_caching: Optional[bool]
supports_audio_input: Optional[bool]
supports_audio_output: Optional[bool]
supports_pdf_input: Optional[bool]
Raises:
Exception: If the model is not mapped yet.
@ -4445,15 +4484,21 @@ def get_model_info( # noqa: PLR0915
except Exception:
split_model = model
combined_model_name = model
stripped_model_name = _strip_model_name(model=model)
stripped_model_name = _strip_model_name(
model=model, custom_llm_provider=custom_llm_provider
)
combined_stripped_model_name = stripped_model_name
else:
split_model = model
combined_model_name = "{}/{}".format(custom_llm_provider, model)
stripped_model_name = _strip_model_name(model=model)
combined_stripped_model_name = "{}/{}".format(
custom_llm_provider, _strip_model_name(model=model)
stripped_model_name = _strip_model_name(
model=model, custom_llm_provider=custom_llm_provider
)
combined_stripped_model_name = "{}/{}".format(
custom_llm_provider,
_strip_model_name(model=model, custom_llm_provider=custom_llm_provider),
)
#########################
supported_openai_params = litellm.get_supported_openai_params(
@ -4476,6 +4521,7 @@ def get_model_info( # noqa: PLR0915
supports_function_calling=None,
supports_assistant_prefill=None,
supports_prompt_caching=None,
supports_pdf_input=None,
)
elif custom_llm_provider == "ollama" or custom_llm_provider == "ollama_chat":
return litellm.OllamaConfig().get_model_info(model)
@ -4488,40 +4534,25 @@ def get_model_info( # noqa: PLR0915
4. 'stripped_model_name' in litellm.model_cost. Checks if 'ft:gpt-3.5-turbo' in model map, if 'ft:gpt-3.5-turbo:my-org:custom_suffix:id' given.
5. 'split_model' in litellm.model_cost. Checks "llama3-8b-8192" in litellm.model_cost if model="groq/llama3-8b-8192"
"""
_model_info: Optional[Dict[str, Any]] = None
key: Optional[str] = None
if combined_model_name in litellm.model_cost:
key = combined_model_name
_model_info = _get_model_info_from_model_cost(key=key)
_model_info["supported_openai_params"] = supported_openai_params
if (
"litellm_provider" in _model_info
and _model_info["litellm_provider"] != custom_llm_provider
if not _check_provider_match(
model_info=_model_info, custom_llm_provider=custom_llm_provider
):
if custom_llm_provider == "vertex_ai" and _model_info[
"litellm_provider"
].startswith("vertex_ai"):
pass
else:
_model_info = None
_model_info = None
if _model_info is None and model in litellm.model_cost:
key = model
_model_info = _get_model_info_from_model_cost(key=key)
_model_info["supported_openai_params"] = supported_openai_params
if (
"litellm_provider" in _model_info
and _model_info["litellm_provider"] != custom_llm_provider
if not _check_provider_match(
model_info=_model_info, custom_llm_provider=custom_llm_provider
):
if custom_llm_provider == "vertex_ai" and _model_info[
"litellm_provider"
].startswith("vertex_ai"):
pass
elif custom_llm_provider == "fireworks_ai" and _model_info[
"litellm_provider"
].startswith("fireworks_ai"):
pass
else:
_model_info = None
_model_info = None
if (
_model_info is None
and combined_stripped_model_name in litellm.model_cost
@ -4529,57 +4560,26 @@ def get_model_info( # noqa: PLR0915
key = combined_stripped_model_name
_model_info = _get_model_info_from_model_cost(key=key)
_model_info["supported_openai_params"] = supported_openai_params
if (
"litellm_provider" in _model_info
and _model_info["litellm_provider"] != custom_llm_provider
if not _check_provider_match(
model_info=_model_info, custom_llm_provider=custom_llm_provider
):
if custom_llm_provider == "vertex_ai" and _model_info[
"litellm_provider"
].startswith("vertex_ai"):
pass
elif custom_llm_provider == "fireworks_ai" and _model_info[
"litellm_provider"
].startswith("fireworks_ai"):
pass
else:
_model_info = None
_model_info = None
if _model_info is None and stripped_model_name in litellm.model_cost:
key = stripped_model_name
_model_info = _get_model_info_from_model_cost(key=key)
_model_info["supported_openai_params"] = supported_openai_params
if (
"litellm_provider" in _model_info
and _model_info["litellm_provider"] != custom_llm_provider
if not _check_provider_match(
model_info=_model_info, custom_llm_provider=custom_llm_provider
):
if custom_llm_provider == "vertex_ai" and _model_info[
"litellm_provider"
].startswith("vertex_ai"):
pass
elif custom_llm_provider == "fireworks_ai" and _model_info[
"litellm_provider"
].startswith("fireworks_ai"):
pass
else:
_model_info = None
_model_info = None
if _model_info is None and split_model in litellm.model_cost:
key = split_model
_model_info = _get_model_info_from_model_cost(key=key)
_model_info["supported_openai_params"] = supported_openai_params
if (
"litellm_provider" in _model_info
and _model_info["litellm_provider"] != custom_llm_provider
if not _check_provider_match(
model_info=_model_info, custom_llm_provider=custom_llm_provider
):
if custom_llm_provider == "vertex_ai" and _model_info[
"litellm_provider"
].startswith("vertex_ai"):
pass
elif custom_llm_provider == "fireworks_ai" and _model_info[
"litellm_provider"
].startswith("fireworks_ai"):
pass
else:
_model_info = None
_model_info = None
if _model_info is None or key is None:
raise ValueError(
"This model isn't mapped yet. Add it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json"
@ -4675,6 +4675,7 @@ def get_model_info( # noqa: PLR0915
),
supports_audio_input=_model_info.get("supports_audio_input", False),
supports_audio_output=_model_info.get("supports_audio_output", False),
supports_pdf_input=_model_info.get("supports_pdf_input", False),
tpm=_model_info.get("tpm", None),
rpm=_model_info.get("rpm", None),
)
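
The new `supports_pdf_input` field surfaces through `get_model_info` and the `supports_pdf_input` helper. A hedged usage sketch, assuming the running model-cost map already carries the `amazon.nova-*` entries added later in this commit:

```python
import litellm
from litellm.utils import supports_pdf_input

# ModelInfo behaves like a dict at runtime, so .get() works here
info = litellm.get_model_info("bedrock/amazon.nova-pro-v1:0")
print(info.get("supports_pdf_input"))  # expected: True for the nova entries

print(supports_pdf_input("bedrock/amazon.nova-pro-v1:0", None))  # expected: True
```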
@ -6195,11 +6196,21 @@ class ProviderConfigManager:
return OpenAIGPTConfig()
def get_end_user_id_for_cost_tracking(litellm_params: dict) -> Optional[str]:
def get_end_user_id_for_cost_tracking(
litellm_params: dict,
service_type: Literal["litellm_logging", "prometheus"] = "litellm_logging",
) -> Optional[str]:
"""
Used for enforcing `disable_end_user_cost_tracking` param.
service_type: "litellm_logging" or "prometheus" - lets end-user cost tracking be disabled for Prometheus only.
"""
proxy_server_request = litellm_params.get("proxy_server_request") or {}
if litellm.disable_end_user_cost_tracking:
return None
if (
service_type == "prometheus"
and litellm.disable_end_user_cost_tracking_prometheus_only
):
return None
return proxy_server_request.get("body", {}).get("user", None)
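
A minimal sketch of the new Prometheus-only switch, mirroring the parametrized test added further down in this commit:

```python
import litellm
from litellm.utils import get_end_user_id_for_cost_tracking

params = {"proxy_server_request": {"body": {"user": "123"}}}

litellm.disable_end_user_cost_tracking_prometheus_only = True

# general logging still sees the end user ...
print(get_end_user_id_for_cost_tracking(litellm_params=params))  # "123"

# ... while the prometheus path drops it, avoiding the cardinality blow-up
print(get_end_user_id_for_cost_tracking(litellm_params=params, service_type="prometheus"))  # None
```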

View file

@ -4795,6 +4795,42 @@
"mode": "chat",
"supports_function_calling": true
},
"amazon.nova-micro-v1:0": {
"max_tokens": 4096,
"max_input_tokens": 300000,
"max_output_tokens": 4096,
"input_cost_per_token": 0.000000035,
"output_cost_per_token": 0.00000014,
"litellm_provider": "bedrock_converse",
"mode": "chat",
"supports_function_calling": true,
"supports_vision": true,
"supports_pdf_input": true
},
"amazon.nova-lite-v1:0": {
"max_tokens": 4096,
"max_input_tokens": 128000,
"max_output_tokens": 4096,
"input_cost_per_token": 0.00000006,
"output_cost_per_token": 0.00000024,
"litellm_provider": "bedrock_converse",
"mode": "chat",
"supports_function_calling": true,
"supports_vision": true,
"supports_pdf_input": true
},
"amazon.nova-pro-v1:0": {
"max_tokens": 4096,
"max_input_tokens": 300000,
"max_output_tokens": 4096,
"input_cost_per_token": 0.0000008,
"output_cost_per_token": 0.0000032,
"litellm_provider": "bedrock_converse",
"mode": "chat",
"supports_function_calling": true,
"supports_vision": true,
"supports_pdf_input": true
},
"anthropic.claude-3-sonnet-20240229-v1:0": {
"max_tokens": 4096,
"max_input_tokens": 200000,

View file

@ -54,6 +54,36 @@ class BaseLLMChatTest(ABC):
# for OpenAI the content contains the JSON schema, so we need to assert that the content is not None
assert response.choices[0].message.content is not None
@pytest.mark.parametrize("image_url", ["str", "dict"])
def test_pdf_handling(self, pdf_messages, image_url):
from litellm.utils import supports_pdf_input
if image_url == "str":
image_url = pdf_messages
elif image_url == "dict":
image_url = {"url": pdf_messages}
image_content = [
{"type": "text", "text": "What's this file about?"},
{
"type": "image_url",
"image_url": image_url,
},
]
image_messages = [{"role": "user", "content": image_content}]
base_completion_call_args = self.get_base_completion_call_args()
if not supports_pdf_input(base_completion_call_args["model"], None):
pytest.skip("Model does not support image input")
response = litellm.completion(
**base_completion_call_args,
messages=image_messages,
)
assert response is not None
def test_message_with_name(self):
litellm.set_verbose = True
base_completion_call_args = self.get_base_completion_call_args()
@ -187,7 +217,7 @@ class BaseLLMChatTest(ABC):
for chunk in response:
content += chunk.choices[0].delta.content or ""
print("content=", content)
print(f"content={content}<END>")
# OpenAI guarantees that the JSON schema is returned in the content
# relevant issue: https://github.com/BerriAI/litellm/issues/6741
@ -250,7 +280,7 @@ class BaseLLMChatTest(ABC):
import requests
# URL of the file
url = "https://storage.googleapis.com/cloud-samples-data/generative-ai/pdf/2403.05530.pdf"
url = "https://www.w3.org/WAI/ER/tests/xhtml/testfiles/resources/pdf/dummy.pdf"
response = requests.get(url)
file_data = response.content
@ -258,14 +288,4 @@ class BaseLLMChatTest(ABC):
encoded_file = base64.b64encode(file_data).decode("utf-8")
url = f"data:application/pdf;base64,{encoded_file}"
image_content = [
{"type": "text", "text": "What's this file about?"},
{
"type": "image_url",
"image_url": {"url": url},
},
]
image_messages = [{"role": "user", "content": image_content}]
return image_messages
return url

View file

@ -668,35 +668,6 @@ class TestAnthropicCompletion(BaseLLMChatTest):
def get_base_completion_call_args(self) -> dict:
return {"model": "claude-3-haiku-20240307"}
def test_pdf_handling(self, pdf_messages):
from litellm.llms.custom_httpx.http_handler import HTTPHandler
from litellm.types.llms.anthropic import AnthropicMessagesDocumentParam
import json
client = HTTPHandler()
with patch.object(client, "post", new=MagicMock()) as mock_client:
response = completion(
model="claude-3-5-sonnet-20241022",
messages=pdf_messages,
client=client,
)
mock_client.assert_called_once()
json_data = json.loads(mock_client.call_args.kwargs["data"])
headers = mock_client.call_args.kwargs["headers"]
assert headers["anthropic-beta"] == "pdfs-2024-09-25"
json_data["messages"][0]["role"] == "user"
_document_validation = AnthropicMessagesDocumentParam(
**json_data["messages"][0]["content"][1]
)
assert _document_validation["type"] == "document"
assert _document_validation["source"]["media_type"] == "application/pdf"
assert _document_validation["source"]["type"] == "base64"
def test_tool_call_no_arguments(self, tool_call_no_arguments):
"""Test that tool calls with no arguments is translated correctly. Relevant issue: https://github.com/BerriAI/litellm/issues/6833"""
from litellm.llms.prompt_templates.factory import (

View file

@ -30,6 +30,7 @@ from litellm import (
from litellm.llms.bedrock.chat import BedrockLLM
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.llms.prompt_templates.factory import _bedrock_tools_pt
from base_llm_unit_tests import BaseLLMChatTest
# litellm.num_retries = 3
litellm.cache = None
@ -1943,3 +1944,42 @@ def test_bedrock_context_window_error():
messages=[{"role": "user", "content": "Hello, world!"}],
mock_response=Exception("prompt is too long"),
)
def test_bedrock_converse_route():
litellm.set_verbose = True
litellm.completion(
model="bedrock/converse/us.amazon.nova-pro-v1:0",
messages=[{"role": "user", "content": "Hello, world!"}],
)
def test_bedrock_mapped_converse_models():
litellm.set_verbose = True
os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
litellm.model_cost = litellm.get_model_cost_map(url="")
litellm.add_known_models()
litellm.completion(
model="bedrock/us.amazon.nova-pro-v1:0",
messages=[{"role": "user", "content": "Hello, world!"}],
)
def test_bedrock_base_model_helper():
model = "us.amazon.nova-pro-v1:0"
litellm.AmazonConverseConfig()._get_base_model(model)
assert model == "us.amazon.nova-pro-v1:0"
class TestBedrockConverseChat(BaseLLMChatTest):
def get_base_completion_call_args(self) -> dict:
os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
litellm.model_cost = litellm.get_model_cost_map(url="")
litellm.add_known_models()
return {
"model": "bedrock/us.anthropic.claude-3-haiku-20240307-v1:0",
}
def test_tool_call_no_arguments(self, tool_call_no_arguments):
"""Test that tool calls with no arguments is translated correctly. Relevant issue: https://github.com/BerriAI/litellm/issues/6833"""
pass

View file

@ -695,29 +695,6 @@ async def test_anthropic_no_content_error():
pytest.fail(f"An unexpected error occurred - {str(e)}")
def test_gemini_completion_call_error():
try:
print("test completion + streaming")
litellm.num_retries = 3
litellm.set_verbose = True
messages = [{"role": "user", "content": "what is the capital of congo?"}]
response = completion(
model="gemini/gemini-1.5-pro-latest",
messages=messages,
stream=True,
max_tokens=10,
)
print(f"response: {response}")
for chunk in response:
print(chunk)
except litellm.RateLimitError:
pass
except litellm.InternalServerError:
pass
except Exception as e:
pytest.fail(f"error occurred: {str(e)}")
def test_completion_cohere_command_r_plus_function_call():
litellm.set_verbose = True
tools = [

View file

@ -2342,12 +2342,6 @@ def test_router_dynamic_cooldown_correct_retry_after_time():
pass
new_retry_after_mock_client.assert_called()
print(
f"new_retry_after_mock_client.call_args.kwargs: {new_retry_after_mock_client.call_args.kwargs}"
)
print(
f"new_retry_after_mock_client.call_args: {new_retry_after_mock_client.call_args[0][0]}"
)
response_headers: httpx.Headers = new_retry_after_mock_client.call_args[0][0]
assert int(response_headers["retry-after"]) == cooldown_time

View file

@ -539,45 +539,71 @@ def test_raise_context_window_exceeded_error_no_retry():
"""
def test_timeout_for_rate_limit_error_with_healthy_deployments():
@pytest.mark.parametrize("num_deployments, expected_timeout", [(1, 60), (2, 0.0)])
def test_timeout_for_rate_limit_error_with_healthy_deployments(
num_deployments, expected_timeout
):
"""
Test that the retry timeout is 0.0 on a RateLimitError when other healthy deployments exist, and > 0.0 when only a single deployment is configured
"""
healthy_deployments = [
"deployment1",
"deployment2",
] # multiple healthy deployments mocked up
fallbacks = None
router = litellm.Router(
model_list=[
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {
"model": "azure/chatgpt-v-2",
"api_key": os.getenv("AZURE_API_KEY"),
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE"),
},
}
]
cooldown_time = 60
rate_limit_error = litellm.RateLimitError(
message="{RouterErrors.no_deployments_available.value}. 12345 Passed model={model_group}. Deployments={deployment_dict}",
llm_provider="",
model="gpt-3.5-turbo",
response=httpx.Response(
status_code=429,
content="",
headers={"retry-after": str(cooldown_time)}, # type: ignore
request=httpx.Request(method="tpm_rpm_limits", url="https://github.com/BerriAI/litellm"), # type: ignore
),
)
model_list = [
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {
"model": "azure/chatgpt-v-2",
"api_key": os.getenv("AZURE_API_KEY"),
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE"),
},
}
]
if num_deployments == 2:
model_list.append(
{
"model_name": "gpt-4",
"litellm_params": {"model": "gpt-3.5-turbo"},
}
)
router = litellm.Router(model_list=model_list)
_timeout = router._time_to_sleep_before_retry(
e=rate_limit_error,
remaining_retries=4,
num_retries=4,
healthy_deployments=healthy_deployments,
remaining_retries=2,
num_retries=2,
healthy_deployments=[
{
"model_name": "gpt-4",
"litellm_params": {
"api_key": "my-key",
"api_base": "https://openai-gpt-4-test-v-1.openai.azure.com",
"model": "azure/chatgpt-v-2",
},
"model_info": {
"id": "0e30bc8a63fa91ae4415d4234e231b3f9e6dd900cac57d118ce13a720d95e9d6",
"db_model": False,
},
}
],
all_deployments=model_list,
)
print(
"timeout=",
_timeout,
"error is rate_limit_error and there are healthy deployments=",
healthy_deployments,
)
assert _timeout == 0.0
if expected_timeout == 0.0:
assert _timeout == expected_timeout
else:
assert _timeout > 0.0
def test_timeout_for_rate_limit_error_with_no_healthy_deployments():
@ -585,26 +611,26 @@ def test_timeout_for_rate_limit_error_with_no_healthy_deployments():
Test 2. Timeout is > 0.0 when a RateLimitError occurs and there are no healthy deployments
"""
healthy_deployments = []
model_list = [
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {
"model": "azure/chatgpt-v-2",
"api_key": os.getenv("AZURE_API_KEY"),
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE"),
},
}
]
router = litellm.Router(
model_list=[
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {
"model": "azure/chatgpt-v-2",
"api_key": os.getenv("AZURE_API_KEY"),
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE"),
},
}
]
)
router = litellm.Router(model_list=model_list)
_timeout = router._time_to_sleep_before_retry(
e=rate_limit_error,
remaining_retries=4,
num_retries=4,
healthy_deployments=healthy_deployments,
all_deployments=model_list,
)
print(

View file

@ -1005,7 +1005,10 @@ def test_models_by_provider():
continue
elif k == "sample_spec":
continue
elif v["litellm_provider"] == "sagemaker":
elif (
v["litellm_provider"] == "sagemaker"
or v["litellm_provider"] == "bedrock_converse"
):
continue
else:
providers.add(v["litellm_provider"])
@ -1032,3 +1035,27 @@ def test_get_end_user_id_for_cost_tracking(
get_end_user_id_for_cost_tracking(litellm_params=litellm_params)
== expected_end_user_id
)
@pytest.mark.parametrize(
"litellm_params, disable_end_user_cost_tracking_prometheus_only, expected_end_user_id",
[
({}, False, None),
({"proxy_server_request": {"body": {"user": "123"}}}, False, "123"),
({"proxy_server_request": {"body": {"user": "123"}}}, True, None),
],
)
def test_get_end_user_id_for_cost_tracking_prometheus_only(
litellm_params, disable_end_user_cost_tracking_prometheus_only, expected_end_user_id
):
from litellm.utils import get_end_user_id_for_cost_tracking
litellm.disable_end_user_cost_tracking_prometheus_only = (
disable_end_user_cost_tracking_prometheus_only
)
assert (
get_end_user_id_for_cost_tracking(
litellm_params=litellm_params, service_type="prometheus"
)
== expected_end_user_id
)

View file

@ -695,45 +695,43 @@ def test_personal_key_generation_check():
)
def test_prepare_metadata_fields():
@pytest.mark.parametrize(
"update_request_data, non_default_values, existing_metadata, expected_result",
[
(
{"metadata": {"test": "new"}},
{"metadata": {"test": "new"}},
{"test": "test"},
{"metadata": {"test": "new"}},
),
(
{"tags": ["new_tag"]},
{},
{"tags": ["old_tag"]},
{"metadata": {"tags": ["new_tag"]}},
),
(
{"enforced_params": ["metadata.tags"]},
{},
{"tags": ["old_tag"]},
{"metadata": {"tags": ["old_tag"], "enforced_params": ["metadata.tags"]}},
),
],
)
def test_prepare_metadata_fields(
update_request_data, non_default_values, existing_metadata, expected_result
):
from litellm.proxy.management_endpoints.key_management_endpoints import (
prepare_metadata_fields,
)
new_metadata = {"test": "new"}
old_metadata = {"test": "test"}
args = {
"data": UpdateKeyRequest(
key_alias=None,
duration=None,
models=[],
spend=None,
max_budget=None,
user_id=None,
team_id=None,
max_parallel_requests=None,
metadata=new_metadata,
tpm_limit=None,
rpm_limit=None,
budget_duration=None,
allowed_cache_controls=[],
soft_budget=None,
config={},
permissions={},
model_max_budget={},
send_invite_email=None,
model_rpm_limit=None,
model_tpm_limit=None,
guardrails=None,
blocked=None,
aliases={},
key="sk-1qGQUJJTcljeaPfzgWRrXQ",
tags=None,
key="sk-1qGQUJJTcljeaPfzgWRrXQ", **update_request_data
),
"non_default_values": {"metadata": new_metadata},
"existing_metadata": {"tags": None, **old_metadata},
"non_default_values": non_default_values,
"existing_metadata": existing_metadata,
}
non_default_values = prepare_metadata_fields(**args)
assert non_default_values == {"metadata": new_metadata}
updated_non_default_values = prepare_metadata_fields(**args)
assert updated_non_default_values == expected_result

View file

@ -2664,66 +2664,6 @@ async def test_create_update_team(prisma_client):
)
@pytest.mark.asyncio()
async def test_enforced_params(prisma_client):
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
from litellm.proxy.proxy_server import general_settings
general_settings["enforced_params"] = [
"user",
"metadata",
"metadata.generation_name",
]
await litellm.proxy.proxy_server.prisma_client.connect()
request = NewUserRequest()
key = await new_user(
data=request,
user_api_key_dict=UserAPIKeyAuth(
user_role=LitellmUserRoles.PROXY_ADMIN,
api_key="sk-1234",
user_id="1234",
),
)
print(key)
generated_key = key.key
bearer_token = "Bearer " + generated_key
request = Request(scope={"type": "http"})
request._url = URL(url="/chat/completions")
# Case 1: Missing user
async def return_body():
return b'{"model": "gemini-pro-vision"}'
request.body = return_body
try:
await user_api_key_auth(request=request, api_key=bearer_token)
pytest.fail(f"This should have failed!. IT's an invalid request")
except Exception as e:
assert (
"BadRequest please pass param=user in request body. This is a required param"
in e.message
)
# Case 2: Missing metadata["generation_name"]
async def return_body_2():
return b'{"model": "gemini-pro-vision", "user": "1234", "metadata": {}}'
request.body = return_body_2
try:
await user_api_key_auth(request=request, api_key=bearer_token)
pytest.fail(f"This should have failed!. IT's an invalid request")
except Exception as e:
assert (
"Authentication Error, BadRequest please pass param=[metadata][generation_name] in request body"
in e.message
)
general_settings.pop("enforced_params")
@pytest.mark.asyncio()
async def test_update_user_role(prisma_client):
"""
@ -3363,64 +3303,6 @@ async def test_auth_vertex_ai_route(prisma_client):
pass
@pytest.mark.asyncio
async def test_service_accounts(prisma_client):
"""
Do not delete
this is the Admin UI flow
"""
# Make a call to a key with model = `all-proxy-models` this is an Alias from LiteLLM Admin UI
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
setattr(
litellm.proxy.proxy_server,
"general_settings",
{"service_account_settings": {"enforced_params": ["user"]}},
)
await litellm.proxy.proxy_server.prisma_client.connect()
request = GenerateKeyRequest(
metadata={"service_account_id": f"prod-service-{uuid.uuid4()}"},
)
response = await generate_key_fn(
data=request,
)
print("key generated=", response)
generated_key = response.key
bearer_token = "Bearer " + generated_key
# make a bad /chat/completions call expect it to fail
request = Request(scope={"type": "http"})
request._url = URL(url="/chat/completions")
async def return_body():
return b'{"model": "gemini-pro-vision"}'
request.body = return_body
# use generated key to auth in
print("Bearer token being sent to user_api_key_auth() - {}".format(bearer_token))
try:
result = await user_api_key_auth(request=request, api_key=bearer_token)
pytest.fail("Expected this call to fail. Bad request using service account")
except Exception as e:
print("error str=", str(e.message))
assert "This is a required param for service account" in str(e.message)
# make a good /chat/completions call it should pass
async def good_return_body():
return b'{"model": "gemini-pro-vision", "user": "foo"}'
request.body = good_return_body
result = await user_api_key_auth(request=request, api_key=bearer_token)
print("response from user_api_key_auth", result)
setattr(litellm.proxy.proxy_server, "general_settings", {})
@pytest.mark.asyncio
async def test_user_api_key_auth_db_unavailable():
"""

View file

@ -679,3 +679,132 @@ async def test_add_litellm_data_to_request_duplicate_tags(
assert sorted(result["metadata"]["tags"]) == sorted(
expected_tags
), f"Expected {expected_tags}, got {result['metadata']['tags']}"
@pytest.mark.parametrize(
"general_settings, user_api_key_dict, expected_enforced_params",
[
(
{"enforced_params": ["param1", "param2"]},
UserAPIKeyAuth(
api_key="test_api_key", user_id="test_user_id", org_id="test_org_id"
),
["param1", "param2"],
),
(
{"service_account_settings": {"enforced_params": ["param1", "param2"]}},
UserAPIKeyAuth(
api_key="test_api_key", user_id="test_user_id", org_id="test_org_id"
),
["param1", "param2"],
),
(
{"service_account_settings": {"enforced_params": ["param1", "param2"]}},
UserAPIKeyAuth(
api_key="test_api_key",
metadata={"enforced_params": ["param3", "param4"]},
),
["param1", "param2", "param3", "param4"],
),
],
)
def test_get_enforced_params(
general_settings, user_api_key_dict, expected_enforced_params
):
from litellm.proxy.litellm_pre_call_utils import _get_enforced_params
enforced_params = _get_enforced_params(general_settings, user_api_key_dict)
assert enforced_params == expected_enforced_params
@pytest.mark.parametrize(
"general_settings, user_api_key_dict, request_body, expected_error",
[
(
{"enforced_params": ["param1", "param2"]},
UserAPIKeyAuth(
api_key="test_api_key", user_id="test_user_id", org_id="test_org_id"
),
{},
True,
),
(
{"service_account_settings": {"enforced_params": ["user"]}},
UserAPIKeyAuth(
api_key="test_api_key", user_id="test_user_id", org_id="test_org_id"
),
{},
True,
),
(
{},
UserAPIKeyAuth(
api_key="test_api_key",
metadata={"enforced_params": ["user"]},
),
{},
True,
),
(
{},
UserAPIKeyAuth(
api_key="test_api_key",
metadata={"enforced_params": ["user"]},
),
{"user": "test_user"},
False,
),
(
{"enforced_params": ["user"]},
UserAPIKeyAuth(
api_key="test_api_key",
),
{"user": "test_user"},
False,
),
(
{"service_account_settings": {"enforced_params": ["user"]}},
UserAPIKeyAuth(
api_key="test_api_key",
),
{"user": "test_user"},
False,
),
(
{"enforced_params": ["metadata.generation_name"]},
UserAPIKeyAuth(
api_key="test_api_key",
),
{"metadata": {}},
True,
),
(
{"enforced_params": ["metadata.generation_name"]},
UserAPIKeyAuth(
api_key="test_api_key",
),
{"metadata": {"generation_name": "test_generation_name"}},
False,
),
],
)
def test_enforced_params_check(
general_settings, user_api_key_dict, request_body, expected_error
):
from litellm.proxy.litellm_pre_call_utils import _enforced_params_check
if expected_error:
with pytest.raises(ValueError):
_enforced_params_check(
request_body=request_body,
general_settings=general_settings,
user_api_key_dict=user_api_key_dict,
premium_user=True,
)
else:
_enforced_params_check(
request_body=request_body,
general_settings=general_settings,
user_api_key_dict=user_api_key_dict,
premium_user=True,
)