diff --git a/README.md b/README.md
index 5d3efe3559..41bb756199 100644
--- a/README.md
+++ b/README.md
@@ -271,6 +271,7 @@ curl 'http://0.0.0.0:4000/key/generate' \
| [voyage ai](https://docs.litellm.ai/docs/providers/voyage) | | | | | ✅ | |
| [xinference [Xorbits Inference]](https://docs.litellm.ai/docs/providers/xinference) | | | | | ✅ | |
| [FriendliAI](https://docs.litellm.ai/docs/providers/friendliai) | ✅ | ✅ | ✅ | ✅ | | |
+| [Galadriel](https://docs.litellm.ai/docs/providers/galadriel) | ✅ | ✅ | ✅ | ✅ | | |
[**Read the Docs**](https://docs.litellm.ai/docs/)
diff --git a/docs/my-website/docs/completion/document_understanding.md b/docs/my-website/docs/completion/document_understanding.md
new file mode 100644
index 0000000000..6719169aef
--- /dev/null
+++ b/docs/my-website/docs/completion/document_understanding.md
@@ -0,0 +1,202 @@
+import Tabs from '@theme/Tabs';
+import TabItem from '@theme/TabItem';
+
+# Using PDF Input
+
+How to send / receive PDFs (and other document types) to a `/chat/completions` endpoint.
+
+Works for:
+- Vertex AI models (Gemini + Anthropic)
+- Bedrock Models
+- Anthropic API Models
+
+## Quick Start
+
+### url
+
+
+
+
+```python
+import os
+
+from litellm.utils import supports_pdf_input, completion
+
+# set aws credentials
+os.environ["AWS_ACCESS_KEY_ID"] = ""
+os.environ["AWS_SECRET_ACCESS_KEY"] = ""
+os.environ["AWS_REGION_NAME"] = ""
+
+
+# pdf url
+image_url = "https://www.w3.org/WAI/ER/tests/xhtml/testfiles/resources/pdf/dummy.pdf"
+
+# model
+model = "bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0"
+
+image_content = [
+ {"type": "text", "text": "What's this file about?"},
+ {
+ "type": "image_url",
+ "image_url": image_url, # OR {"url": image_url}
+ },
+]
+
+
+if not supports_pdf_input(model, None):
+ print("Model does not support image input")
+
+response = completion(
+ model=model,
+ messages=[{"role": "user", "content": image_content}],
+)
+assert response is not None
+```
+
+
+
+1. Setup config.yaml
+
+```yaml
+model_list:
+ - model_name: bedrock-model
+ litellm_params:
+ model: bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0
+ aws_access_key_id: os.environ/AWS_ACCESS_KEY_ID
+ aws_secret_access_key: os.environ/AWS_SECRET_ACCESS_KEY
+ aws_region_name: os.environ/AWS_REGION_NAME
+```
+
+2. Start the proxy
+
+```bash
+litellm --config /path/to/config.yaml
+```
+
+3. Test it!
+
+```bash
+curl -X POST 'http://0.0.0.0:4000/chat/completions' \
+-H 'Content-Type: application/json' \
+-H 'Authorization: Bearer sk-1234' \
+-d '{
+ "model": "bedrock-model",
+ "messages": [
+ {"role": "user", "content": {"type": "text", "text": "What's this file about?"}},
+ {
+ "type": "image_url",
+ "image_url": "https://www.w3.org/WAI/ER/tests/xhtml/testfiles/resources/pdf/dummy.pdf",
+ }
+ ]
+}'
+```
+
+
+
+### base64
+
+
+
+
+```python
+import base64
+import os
+
+import requests
+
+from litellm.utils import supports_pdf_input, completion
+
+# set aws credentials
+os.environ["AWS_ACCESS_KEY_ID"] = ""
+os.environ["AWS_SECRET_ACCESS_KEY"] = ""
+os.environ["AWS_REGION_NAME"] = ""
+
+
+# pdf url
+image_url = "https://www.w3.org/WAI/ER/tests/xhtml/testfiles/resources/pdf/dummy.pdf"
+response = requests.get(image_url)
+file_data = response.content
+
+encoded_file = base64.b64encode(file_data).decode("utf-8")
+base64_url = f"data:application/pdf;base64,{encoded_file}"
+
+# model
+model = "bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0"
+
+image_content = [
+ {"type": "text", "text": "What's this file about?"},
+ {
+ "type": "image_url",
+ "image_url": base64_url, # OR {"url": base64_url}
+ },
+]
+
+
+if not supports_pdf_input(model, None):
+ print("Model does not support image input")
+
+response = completion(
+ model=model,
+ messages=[{"role": "user", "content": image_content}],
+)
+assert response is not None
+```
+
+
+
+## Checking if a model supports pdf input
+
+
+
+
+Use `litellm.supports_pdf_input(model="bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0")` -> returns `True` if the model can accept pdf input
+
+```python
+assert litellm.supports_pdf_input(model="bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0") == True
+```
+
+
+
+
+1. Define bedrock models on config.yaml
+
+```yaml
+model_list:
+ - model_name: bedrock-model # model group name
+ litellm_params:
+ model: bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0
+ aws_access_key_id: os.environ/AWS_ACCESS_KEY_ID
+ aws_secret_access_key: os.environ/AWS_SECRET_ACCESS_KEY
+ aws_region_name: os.environ/AWS_REGION_NAME
+ model_info: # OPTIONAL - set manually
+ supports_pdf_input: True
+```
+
+2. Run proxy server
+
+```bash
+litellm --config config.yaml
+```
+
+3. Call `/model_group/info` to check if a model supports `pdf` input
+
+```shell
+curl -X 'GET' \
+ 'http://localhost:4000/model_group/info' \
+ -H 'accept: application/json' \
+ -H 'x-api-key: sk-1234'
+```
+
+Expected Response
+
+```json
+{
+ "data": [
+ {
+ "model_group": "bedrock-model",
+ "providers": ["bedrock"],
+ "max_input_tokens": 128000,
+ "max_output_tokens": 16384,
+ "mode": "chat",
+ ...,
+ "supports_pdf_input": true, # 👈 supports_pdf_input is true
+ }
+ ]
+}
+```
+
+
+
diff --git a/docs/my-website/docs/providers/bedrock.md b/docs/my-website/docs/providers/bedrock.md
index 579353d652..741f8fcf29 100644
--- a/docs/my-website/docs/providers/bedrock.md
+++ b/docs/my-website/docs/providers/bedrock.md
@@ -706,6 +706,37 @@ print(response)
+## Set 'converse' / 'invoke' route
+
+LiteLLM defaults to the `invoke` route, but uses the `converse` route for Bedrock models that support it.
+
+To explicitly set the route, prefix the model name with `bedrock/converse/` or `bedrock/invoke/`.
+
+
+E.g.
+
+
+
+
+```python
+from litellm import completion
+
+completion(model="bedrock/converse/us.amazon.nova-pro-v1:0")
+```
+
+
+
+
+```yaml
+model_list:
+ - model_name: bedrock-model
+ litellm_params:
+ model: bedrock/converse/us.amazon.nova-pro-v1:0
+```
+
+
+
+
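+A minimal `invoke` sketch looks the same, just with the `invoke/` prefix (assuming AWS credentials are set in the environment and the model id is available to your account):
+
+```python
+from litellm import completion
+
+# explicitly route this call through the older `invoke` API
+completion(
+    model="bedrock/invoke/anthropic.claude-3-5-sonnet-20240620-v1:0",
+    messages=[{"role": "user", "content": "Hey!"}],
+)
+```
+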
## Alternate user/assistant messages
Use `user_continue_message` to add a default user message, for cases (e.g. Autogen) where the client might not follow alternating user/assistant messages starting and ending with a user message.
@@ -745,6 +776,174 @@ curl -X POST 'http://0.0.0.0:4000/chat/completions' \
}'
```
+## Usage - PDF / Document Understanding
+
+LiteLLM supports Document Understanding for Bedrock models - [AWS Bedrock Docs](https://docs.aws.amazon.com/nova/latest/userguide/modalities-document.html).
+
+### url
+
+
+
+
+```python
+import os
+
+from litellm.utils import supports_pdf_input, completion
+
+# set aws credentials
+os.environ["AWS_ACCESS_KEY_ID"] = ""
+os.environ["AWS_SECRET_ACCESS_KEY"] = ""
+os.environ["AWS_REGION_NAME"] = ""
+
+
+# pdf url
+image_url = "https://www.w3.org/WAI/ER/tests/xhtml/testfiles/resources/pdf/dummy.pdf"
+
+# model
+model = "bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0"
+
+image_content = [
+ {"type": "text", "text": "What's this file about?"},
+ {
+ "type": "image_url",
+ "image_url": image_url, # OR {"url": image_url}
+ },
+]
+
+
+if not supports_pdf_input(model, None):
+ print("Model does not support image input")
+
+response = completion(
+ model=model,
+ messages=[{"role": "user", "content": image_content}],
+)
+assert response is not None
+```
+
+
+
+1. Setup config.yaml
+
+```yaml
+model_list:
+ - model_name: bedrock-model
+ litellm_params:
+ model: bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0
+ aws_access_key_id: os.environ/AWS_ACCESS_KEY_ID
+ aws_secret_access_key: os.environ/AWS_SECRET_ACCESS_KEY
+ aws_region_name: os.environ/AWS_REGION_NAME
+```
+
+2. Start the proxy
+
+```bash
+litellm --config /path/to/config.yaml
+```
+
+3. Test it!
+
+```bash
+curl -X POST 'http://0.0.0.0:4000/chat/completions' \
+-H 'Content-Type: application/json' \
+-H 'Authorization: Bearer sk-1234' \
+-d '{
+ "model": "bedrock-model",
+ "messages": [
+ {"role": "user", "content": {"type": "text", "text": "What's this file about?"}},
+ {
+ "type": "image_url",
+ "image_url": "https://www.w3.org/WAI/ER/tests/xhtml/testfiles/resources/pdf/dummy.pdf",
+ }
+ ]
+}'
+```
+
+
+
+### base64
+
+
+
+
+```python
+import base64
+import os
+
+import requests
+
+from litellm.utils import supports_pdf_input, completion
+
+# set aws credentials
+os.environ["AWS_ACCESS_KEY_ID"] = ""
+os.environ["AWS_SECRET_ACCESS_KEY"] = ""
+os.environ["AWS_REGION_NAME"] = ""
+
+
+# pdf url
+image_url = "https://www.w3.org/WAI/ER/tests/xhtml/testfiles/resources/pdf/dummy.pdf"
+response = requests.get(image_url)
+file_data = response.content
+
+encoded_file = base64.b64encode(file_data).decode("utf-8")
+base64_url = f"data:application/pdf;base64,{encoded_file}"
+
+# model
+model = "bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0"
+
+image_content = [
+ {"type": "text", "text": "What's this file about?"},
+ {
+ "type": "image_url",
+ "image_url": base64_url, # OR {"url": base64_url}
+ },
+]
+
+
+if not supports_pdf_input(model, None):
+ print("Model does not support image input")
+
+response = completion(
+ model=model,
+ messages=[{"role": "user", "content": image_content}],
+)
+assert response is not None
+```
+
+
+
+1. Setup config.yaml
+
+```yaml
+model_list:
+ - model_name: bedrock-model
+ litellm_params:
+ model: bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0
+ aws_access_key_id: os.environ/AWS_ACCESS_KEY_ID
+ aws_secret_access_key: os.environ/AWS_SECRET_ACCESS_KEY
+ aws_region_name: os.environ/AWS_REGION_NAME
+```
+
+2. Start the proxy
+
+```bash
+litellm --config /path/to/config.yaml
+```
+
+3. Test it!
+
+```bash
+curl -X POST 'http://0.0.0.0:4000/chat/completions' \
+-H 'Content-Type: application/json' \
+-H 'Authorization: Bearer sk-1234' \
+-d '{
+ "model": "bedrock-model",
+ "messages": [
+ {"role": "user", "content": {"type": "text", "text": "What's this file about?"}},
+ {
+ "type": "image_url",
+ "image_url": "data:application/pdf;base64,{b64_encoded_file}",
+ }
+ ]
+}'
+```
+
+
+
+
## Boto3 - Authentication
### Passing credentials as parameters - Completion()
diff --git a/docs/my-website/docs/providers/galadriel.md b/docs/my-website/docs/providers/galadriel.md
new file mode 100644
index 0000000000..73f1ec8e76
--- /dev/null
+++ b/docs/my-website/docs/providers/galadriel.md
@@ -0,0 +1,63 @@
+import Tabs from '@theme/Tabs';
+import TabItem from '@theme/TabItem';
+
+# Galadriel
+https://docs.galadriel.com/api-reference/chat-completion-API
+
+LiteLLM supports all models on Galadriel.
+
+## API Key
+```python
+import os
+os.environ['GALADRIEL_API_KEY'] = "your-api-key"
+```
+
+## Sample Usage
+```python
+from litellm import completion
+import os
+
+os.environ['GALADRIEL_API_KEY'] = ""
+response = completion(
+ model="galadriel/llama3.1",
+ messages=[
+ {"role": "user", "content": "hello from litellm"}
+ ],
+)
+print(response)
+```
+
+## Sample Usage - Streaming
+```python
+from litellm import completion
+import os
+
+os.environ['GALADRIEL_API_KEY'] = ""
+response = completion(
+ model="galadriel/llama3.1",
+ messages=[
+ {"role": "user", "content": "hello from litellm"}
+ ],
+ stream=True
+)
+
+for chunk in response:
+ print(chunk)
+```
+
+
+## Supported Models
+### Serverless Endpoints
+We support ALL Galadriel AI models. Just set `galadriel/` as a prefix when sending completion requests.
+
+You can specify a model either by its complete name or by its simplified name, e.g. `llama3.1:70b` (see the sketch after the table below).
+
+| Model Name | Simplified Name | Function Call |
+| -------------------------------------------------------- | -------------------------------- | ------------------------------------------------------- |
+| neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8 | llama3.1 or llama3.1:8b | `completion(model="galadriel/llama3.1", messages)` |
+| neuralmagic/Meta-Llama-3.1-70B-Instruct-quantized.w4a16 | llama3.1:70b | `completion(model="galadriel/llama3.1:70b", messages)` |
+| neuralmagic/Meta-Llama-3.1-405B-Instruct-quantized.w4a16 | llama3.1:405b | `completion(model="galadriel/llama3.1:405b", messages)` |
+| neuralmagic/Mistral-Nemo-Instruct-2407-quantized.w4a16 | mistral-nemo or mistral-nemo:12b | `completion(model="galadriel/mistral-nemo", messages)` |
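+
+For example, the following two calls target the same deployment, once via the full model name and once via its simplified name (a minimal sketch, assuming `GALADRIEL_API_KEY` is set):
+
+```python
+from litellm import completion
+
+messages = [{"role": "user", "content": "hello from litellm"}]
+
+# full model name
+completion(model="galadriel/neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8", messages=messages)
+
+# simplified name for the same model
+completion(model="galadriel/llama3.1:8b", messages=messages)
+```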
+
diff --git a/docs/my-website/docs/proxy/config_settings.md b/docs/my-website/docs/proxy/config_settings.md
index c762a0716c..875c691e37 100644
--- a/docs/my-website/docs/proxy/config_settings.md
+++ b/docs/my-website/docs/proxy/config_settings.md
@@ -134,28 +134,9 @@ general_settings:
| content_policy_fallbacks | array of objects | Fallbacks to use when a ContentPolicyViolationError is encountered. [Further docs](./reliability#content-policy-fallbacks) |
| context_window_fallbacks | array of objects | Fallbacks to use when a ContextWindowExceededError is encountered. [Further docs](./reliability#context-window-fallbacks) |
| cache | boolean | If true, enables caching. [Further docs](./caching) |
-| cache_params | object | Parameters for the cache. [Further docs](./caching) |
-| cache_params.type | string | The type of cache to initialize. Can be one of ["local", "redis", "redis-semantic", "s3", "disk", "qdrant-semantic"]. Defaults to "redis". [Furher docs](./caching) |
-| cache_params.host | string | The host address for the Redis cache. Required if type is "redis". |
-| cache_params.port | integer | The port number for the Redis cache. Required if type is "redis". |
-| cache_params.password | string | The password for the Redis cache. Required if type is "redis". |
-| cache_params.namespace | string | The namespace for the Redis cache. |
-| cache_params.redis_startup_nodes | array of objects | Redis Cluster Settings. [Further docs](./caching) |
-| cache_params.service_name | string | Redis Sentinel Settings. [Further docs](./caching) |
-| cache_params.sentinel_nodes | array of arrays | Redis Sentinel Settings. [Further docs](./caching) |
-| cache_params.ttl | integer | The time (in seconds) to store entries in cache. |
-| cache_params.qdrant_semantic_cache_embedding_model | string | The embedding model to use for qdrant semantic cache. |
-| cache_params.qdrant_collection_name | string | The name of the collection to use for qdrant semantic cache. |
-| cache_params.qdrant_quantization_config | string | The quantization configuration for the qdrant semantic cache. |
-| cache_params.similarity_threshold | float | The similarity threshold for the semantic cache. |
-| cache_params.s3_bucket_name | string | The name of the S3 bucket to use for the semantic cache. |
-| cache_params.s3_region_name | string | The region name for the S3 bucket. |
-| cache_params.s3_aws_access_key_id | string | The AWS access key ID for the S3 bucket. |
-| cache_params.s3_aws_secret_access_key | string | The AWS secret access key for the S3 bucket. |
-| cache_params.s3_endpoint_url | string | Optional - The endpoint URL for the S3 bucket. |
-| cache_params.supported_call_types | array of strings | The types of calls to cache. [Further docs](./caching) |
-| cache_params.mode | string | The mode of the cache. [Further docs](./caching) |
+| cache_params | object | Parameters for the cache. [Further docs](./caching#supported-cache_params-on-proxy-configyaml) |
| disable_end_user_cost_tracking | boolean | If true, turns off end user cost tracking on prometheus metrics + litellm spend logs table on proxy. |
+| disable_end_user_cost_tracking_prometheus_only | boolean | If true, turns off end user cost tracking on prometheus metrics only. |
| key_generation_settings | object | Restricts who can generate keys. [Further docs](./virtual_keys.md#restricting-key-generation) |
### general_settings - Reference
diff --git a/docs/my-website/docs/proxy/enterprise.md b/docs/my-website/docs/proxy/enterprise.md
index a41d02bc25..f4d649fa2d 100644
--- a/docs/my-website/docs/proxy/enterprise.md
+++ b/docs/my-website/docs/proxy/enterprise.md
@@ -507,6 +507,11 @@ curl -X GET "http://0.0.0.0:4000/spend/logs?request_id=
+
+
+
**Step 1** Define all Params you want to enforce on config.yaml
This means `["user"]` and `["metadata]["generation_name"]` are required in all LLM Requests to LiteLLM
@@ -518,8 +523,21 @@ general_settings:
- user
- metadata.generation_name
```
+
-Start LiteLLM Proxy
+
+
+```bash
+curl -L -X POST 'http://0.0.0.0:4000/key/generate' \
+-H 'Authorization: Bearer sk-1234' \
+-H 'Content-Type: application/json' \
+-d '{
+ "enforced_params": ["user", "metadata.generation_name"]
+}'
+```
+
+
+
**Step 2 Verify if this works**
diff --git a/docs/my-website/docs/proxy/virtual_keys.md b/docs/my-website/docs/proxy/virtual_keys.md
index 5bbb6b2a00..2107698f32 100644
--- a/docs/my-website/docs/proxy/virtual_keys.md
+++ b/docs/my-website/docs/proxy/virtual_keys.md
@@ -828,8 +828,18 @@ litellm_settings:
#### Spec
```python
+key_generation_settings: Optional[StandardKeyGenerationConfig] = None
+```
+
+#### Types
+
+```python
+class StandardKeyGenerationConfig(TypedDict, total=False):
+ team_key_generation: TeamUIKeyGenerationConfig
+ personal_key_generation: PersonalUIKeyGenerationConfig
+
class TeamUIKeyGenerationConfig(TypedDict):
- allowed_team_member_roles: List[str]
+ allowed_team_member_roles: List[str] # either 'user' or 'admin'
required_params: List[str] # require params on `/key/generate` to be set if a team key (team_id in request) is being generated
@@ -838,11 +848,6 @@ class PersonalUIKeyGenerationConfig(TypedDict):
required_params: List[str] # require params on `/key/generate` to be set if a personal key (no team_id in request) is being generated
-class StandardKeyGenerationConfig(TypedDict, total=False):
- team_key_generation: TeamUIKeyGenerationConfig
- personal_key_generation: PersonalUIKeyGenerationConfig
-
-
class LitellmUserRoles(str, enum.Enum):
"""
Admin Roles:
diff --git a/docs/my-website/sidebars.js b/docs/my-website/sidebars.js
index e6a028d831..7e1b4bb257 100644
--- a/docs/my-website/sidebars.js
+++ b/docs/my-website/sidebars.js
@@ -177,6 +177,7 @@ const sidebars = {
"providers/ollama",
"providers/perplexity",
"providers/friendliai",
+ "providers/galadriel",
"providers/groq",
"providers/github",
"providers/deepseek",
@@ -210,6 +211,7 @@ const sidebars = {
"completion/provider_specific_params",
"guides/finetuned_models",
"completion/audio",
+ "completion/document_understanding",
"completion/vision",
"completion/json_mode",
"completion/prompt_caching",
diff --git a/litellm/__init__.py b/litellm/__init__.py
index d80caaf8f6..244b7e6e85 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -288,6 +288,7 @@ max_internal_user_budget: Optional[float] = None
internal_user_budget_duration: Optional[str] = None
max_end_user_budget: Optional[float] = None
disable_end_user_cost_tracking: Optional[bool] = None
+disable_end_user_cost_tracking_prometheus_only: Optional[bool] = None
#### REQUEST PRIORITIZATION ####
priority_reservation: Optional[Dict[str, float]] = None
#### RELIABILITY ####
@@ -385,6 +386,31 @@ organization = None
project = None
config_path = None
vertex_ai_safety_settings: Optional[dict] = None
+BEDROCK_CONVERSE_MODELS = [
+ "anthropic.claude-3-5-haiku-20241022-v1:0",
+ "anthropic.claude-3-5-sonnet-20241022-v2:0",
+ "anthropic.claude-3-5-sonnet-20240620-v1:0",
+ "anthropic.claude-3-opus-20240229-v1:0",
+ "anthropic.claude-3-sonnet-20240229-v1:0",
+ "anthropic.claude-3-haiku-20240307-v1:0",
+ "anthropic.claude-v2",
+ "anthropic.claude-v2:1",
+ "anthropic.claude-v1",
+ "anthropic.claude-instant-v1",
+ "ai21.jamba-instruct-v1:0",
+ "meta.llama3-70b-instruct-v1:0",
+ "meta.llama3-8b-instruct-v1:0",
+ "meta.llama3-1-8b-instruct-v1:0",
+ "meta.llama3-1-70b-instruct-v1:0",
+ "meta.llama3-1-405b-instruct-v1:0",
+ "meta.llama3-70b-instruct-v1:0",
+ "mistral.mistral-large-2407-v1:0",
+ "meta.llama3-2-1b-instruct-v1:0",
+ "meta.llama3-2-3b-instruct-v1:0",
+ "meta.llama3-2-11b-instruct-v1:0",
+ "meta.llama3-2-90b-instruct-v1:0",
+ "meta.llama3-2-405b-instruct-v1:0",
+]
####### COMPLETION MODELS ###################
open_ai_chat_completion_models: List = []
open_ai_text_completion_models: List = []
@@ -412,6 +438,7 @@ ai21_chat_models: List = []
nlp_cloud_models: List = []
aleph_alpha_models: List = []
bedrock_models: List = []
+bedrock_converse_models: List = BEDROCK_CONVERSE_MODELS
fireworks_ai_models: List = []
fireworks_ai_embedding_models: List = []
deepinfra_models: List = []
@@ -431,6 +458,7 @@ groq_models: List = []
azure_models: List = []
anyscale_models: List = []
cerebras_models: List = []
+galadriel_models: List = []
def add_known_models():
@@ -491,6 +519,8 @@ def add_known_models():
aleph_alpha_models.append(key)
elif value.get("litellm_provider") == "bedrock":
bedrock_models.append(key)
+ elif value.get("litellm_provider") == "bedrock_converse":
+ bedrock_converse_models.append(key)
elif value.get("litellm_provider") == "deepinfra":
deepinfra_models.append(key)
elif value.get("litellm_provider") == "perplexity":
@@ -535,6 +565,8 @@ def add_known_models():
anyscale_models.append(key)
elif value.get("litellm_provider") == "cerebras":
cerebras_models.append(key)
+ elif value.get("litellm_provider") == "galadriel":
+ galadriel_models.append(key)
add_known_models()
@@ -554,6 +586,7 @@ openai_compatible_endpoints: List = [
"inference.friendli.ai/v1",
"api.sambanova.ai/v1",
"api.x.ai/v1",
+ "api.galadriel.ai/v1",
]
# this is maintained for Exception Mapping
@@ -581,6 +614,7 @@ openai_compatible_providers: List = [
"litellm_proxy",
"hosted_vllm",
"lm_studio",
+ "galadriel",
]
openai_text_completion_compatible_providers: List = (
[ # providers that support `/v1/completions`
@@ -794,6 +828,7 @@ model_list = (
+ azure_models
+ anyscale_models
+ cerebras_models
+ + galadriel_models
)
@@ -860,6 +895,7 @@ class LlmProviders(str, Enum):
LITELLM_PROXY = "litellm_proxy"
HOSTED_VLLM = "hosted_vllm"
LM_STUDIO = "lm_studio"
+ GALADRIEL = "galadriel"
provider_list: List[Union[LlmProviders, str]] = list(LlmProviders)
@@ -882,7 +918,7 @@ models_by_provider: dict = {
+ vertex_vision_models
+ vertex_language_models,
"ai21": ai21_models,
- "bedrock": bedrock_models,
+ "bedrock": bedrock_models + bedrock_converse_models,
"petals": petals_models,
"ollama": ollama_models,
"deepinfra": deepinfra_models,
@@ -908,6 +944,7 @@ models_by_provider: dict = {
"azure": azure_models,
"anyscale": anyscale_models,
"cerebras": cerebras_models,
+ "galadriel": galadriel_models,
}
# mapping for those models which have larger equivalents
@@ -1067,9 +1104,6 @@ from .llms.bedrock.chat.invoke_handler import (
AmazonConverseConfig,
bedrock_tool_name_mappings,
)
-from .llms.bedrock.chat.converse_handler import (
- BEDROCK_CONVERSE_MODELS,
-)
from .llms.bedrock.common_utils import (
AmazonTitanConfig,
AmazonAI21Config,
diff --git a/litellm/integrations/prometheus.py b/litellm/integrations/prometheus.py
index 1460a1d7f0..b87f245240 100644
--- a/litellm/integrations/prometheus.py
+++ b/litellm/integrations/prometheus.py
@@ -365,7 +365,9 @@ class PrometheusLogger(CustomLogger):
model = kwargs.get("model", "")
litellm_params = kwargs.get("litellm_params", {}) or {}
_metadata = litellm_params.get("metadata", {})
- end_user_id = get_end_user_id_for_cost_tracking(litellm_params)
+ end_user_id = get_end_user_id_for_cost_tracking(
+ litellm_params, service_type="prometheus"
+ )
user_id = standard_logging_payload["metadata"]["user_api_key_user_id"]
user_api_key = standard_logging_payload["metadata"]["user_api_key_hash"]
user_api_key_alias = standard_logging_payload["metadata"]["user_api_key_alias"]
@@ -668,7 +670,9 @@ class PrometheusLogger(CustomLogger):
"standard_logging_object", {}
)
litellm_params = kwargs.get("litellm_params", {}) or {}
- end_user_id = get_end_user_id_for_cost_tracking(litellm_params)
+ end_user_id = get_end_user_id_for_cost_tracking(
+ litellm_params, service_type="prometheus"
+ )
user_id = standard_logging_payload["metadata"]["user_api_key_user_id"]
user_api_key = standard_logging_payload["metadata"]["user_api_key_hash"]
user_api_key_alias = standard_logging_payload["metadata"]["user_api_key_alias"]
diff --git a/litellm/litellm_core_utils/get_llm_provider_logic.py b/litellm/litellm_core_utils/get_llm_provider_logic.py
index 71eaaead0c..fea931491f 100644
--- a/litellm/litellm_core_utils/get_llm_provider_logic.py
+++ b/litellm/litellm_core_utils/get_llm_provider_logic.py
@@ -177,6 +177,9 @@ def get_llm_provider( # noqa: PLR0915
dynamic_api_key = get_secret_str(
"FRIENDLIAI_API_KEY"
) or get_secret("FRIENDLI_TOKEN")
+ elif endpoint == "api.galadriel.com/v1":
+ custom_llm_provider = "galadriel"
+ dynamic_api_key = get_secret_str("GALADRIEL_API_KEY")
if api_base is not None and not isinstance(api_base, str):
raise Exception(
@@ -526,6 +529,11 @@ def _get_openai_compatible_provider_info( # noqa: PLR0915
or get_secret_str("FRIENDLIAI_API_KEY")
or get_secret_str("FRIENDLI_TOKEN")
)
+ elif custom_llm_provider == "galadriel":
+ api_base = (
+ api_base or get_secret("GALADRIEL_API_BASE") or "https://api.galadriel.com/v1"
+ ) # type: ignore
+ dynamic_api_key = api_key or get_secret_str("GALADRIEL_API_KEY")
if api_base is not None and not isinstance(api_base, str):
raise Exception("api base needs to be a string. api_base={}".format(api_base))
if dynamic_api_key is not None and not isinstance(dynamic_api_key, str):
diff --git a/litellm/llms/bedrock/chat/converse_handler.py b/litellm/llms/bedrock/chat/converse_handler.py
index e47ba4f421..743e596f8e 100644
--- a/litellm/llms/bedrock/chat/converse_handler.py
+++ b/litellm/llms/bedrock/chat/converse_handler.py
@@ -18,32 +18,6 @@ from ...base_aws_llm import BaseAWSLLM
from ..common_utils import BedrockError
from .invoke_handler import AWSEventStreamDecoder, MockResponseIterator, make_call
-BEDROCK_CONVERSE_MODELS = [
- "anthropic.claude-3-5-haiku-20241022-v1:0",
- "anthropic.claude-3-5-sonnet-20241022-v2:0",
- "anthropic.claude-3-5-sonnet-20240620-v1:0",
- "anthropic.claude-3-opus-20240229-v1:0",
- "anthropic.claude-3-sonnet-20240229-v1:0",
- "anthropic.claude-3-haiku-20240307-v1:0",
- "anthropic.claude-v2",
- "anthropic.claude-v2:1",
- "anthropic.claude-v1",
- "anthropic.claude-instant-v1",
- "ai21.jamba-instruct-v1:0",
- "meta.llama3-70b-instruct-v1:0",
- "meta.llama3-8b-instruct-v1:0",
- "meta.llama3-1-8b-instruct-v1:0",
- "meta.llama3-1-70b-instruct-v1:0",
- "meta.llama3-1-405b-instruct-v1:0",
- "meta.llama3-70b-instruct-v1:0",
- "mistral.mistral-large-2407-v1:0",
- "meta.llama3-2-1b-instruct-v1:0",
- "meta.llama3-2-3b-instruct-v1:0",
- "meta.llama3-2-11b-instruct-v1:0",
- "meta.llama3-2-90b-instruct-v1:0",
- "meta.llama3-2-405b-instruct-v1:0",
-]
-
def make_sync_call(
client: Optional[HTTPHandler],
@@ -53,6 +27,8 @@ def make_sync_call(
model: str,
messages: list,
logging_obj,
+ json_mode: Optional[bool] = False,
+ fake_stream: bool = False,
):
if client is None:
client = _get_httpx_client() # Create a new client if none provided
@@ -61,13 +37,13 @@ def make_sync_call(
api_base,
headers=headers,
data=data,
- stream=True if "ai21" not in api_base else False,
+ stream=not fake_stream,
)
if response.status_code != 200:
raise BedrockError(status_code=response.status_code, message=response.read())
- if "ai21" in api_base:
+ if fake_stream:
model_response: (
ModelResponse
) = litellm.AmazonConverseConfig()._transform_response(
@@ -83,7 +59,9 @@ def make_sync_call(
print_verbose=litellm.print_verbose,
encoding=litellm.encoding,
) # type: ignore
- completion_stream: Any = MockResponseIterator(model_response=model_response)
+ completion_stream: Any = MockResponseIterator(
+ model_response=model_response, json_mode=json_mode
+ )
else:
decoder = AWSEventStreamDecoder(model=model)
completion_stream = decoder.iter_bytes(response.iter_bytes(chunk_size=1024))
@@ -130,6 +108,8 @@ class BedrockConverseLLM(BaseAWSLLM):
logger_fn=None,
headers={},
client: Optional[AsyncHTTPHandler] = None,
+ fake_stream: bool = False,
+ json_mode: Optional[bool] = False,
) -> CustomStreamWrapper:
completion_stream = await make_call(
@@ -140,6 +120,8 @@ class BedrockConverseLLM(BaseAWSLLM):
model=model,
messages=messages,
logging_obj=logging_obj,
+ fake_stream=fake_stream,
+ json_mode=json_mode,
)
streaming_response = CustomStreamWrapper(
completion_stream=completion_stream,
@@ -231,12 +213,15 @@ class BedrockConverseLLM(BaseAWSLLM):
## SETUP ##
stream = optional_params.pop("stream", None)
modelId = optional_params.pop("model_id", None)
+ fake_stream = optional_params.pop("fake_stream", False)
+ json_mode = optional_params.get("json_mode", False)
if modelId is not None:
modelId = self.encode_model_id(model_id=modelId)
else:
modelId = model
- provider = model.split(".")[0]
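+        # ai21 models are not streamed over the converse API - the full response is returned and wrapped in a mock stream (see MockResponseIterator)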
+ if stream is True and "ai21" in modelId:
+ fake_stream = True
## CREDENTIALS ##
# pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
@@ -290,7 +275,7 @@ class BedrockConverseLLM(BaseAWSLLM):
aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
aws_region_name=aws_region_name,
)
- if (stream is not None and stream is True) and provider != "ai21":
+ if (stream is not None and stream is True) and not fake_stream:
endpoint_url = f"{endpoint_url}/model/{modelId}/converse-stream"
proxy_endpoint_url = f"{proxy_endpoint_url}/model/{modelId}/converse-stream"
else:
@@ -355,6 +340,8 @@ class BedrockConverseLLM(BaseAWSLLM):
headers=prepped.headers,
timeout=timeout,
client=client,
+ json_mode=json_mode,
+ fake_stream=fake_stream,
) # type: ignore
### ASYNC COMPLETION
return self.async_completion(
@@ -398,6 +385,8 @@ class BedrockConverseLLM(BaseAWSLLM):
model=model,
messages=messages,
logging_obj=logging_obj,
+ json_mode=json_mode,
+ fake_stream=fake_stream,
)
streaming_response = CustomStreamWrapper(
completion_stream=completion_stream,
diff --git a/litellm/llms/bedrock/chat/converse_transformation.py b/litellm/llms/bedrock/chat/converse_transformation.py
index 23ee97a47e..e8d7e80b74 100644
--- a/litellm/llms/bedrock/chat/converse_transformation.py
+++ b/litellm/llms/bedrock/chat/converse_transformation.py
@@ -134,6 +134,43 @@ class AmazonConverseConfig:
def get_supported_image_types(self) -> List[str]:
return ["png", "jpeg", "gif", "webp"]
+ def get_supported_document_types(self) -> List[str]:
+ return ["pdf", "csv", "doc", "docx", "xls", "xlsx", "html", "txt", "md"]
+
+ def _create_json_tool_call_for_response_format(
+ self,
+ json_schema: Optional[dict] = None,
+ schema_name: str = "json_tool_call",
+ ) -> ChatCompletionToolParam:
+ """
+ Handles creating a tool call for getting responses in JSON format.
+
+ Args:
+ json_schema (Optional[dict]): The JSON schema the response should be in
+
+ Returns:
+            ChatCompletionToolParam: The tool call to send to the Bedrock Converse API to get responses in JSON format
+ """
+
+ if json_schema is None:
+ # Anthropic raises a 400 BadRequest error if properties is passed as None
+ # see usage with additionalProperties (Example 5) https://github.com/anthropics/anthropic-cookbook/blob/main/tool_use/extracting_structured_json.ipynb
+ _input_schema = {
+ "type": "object",
+ "additionalProperties": True,
+ "properties": {},
+ }
+ else:
+ _input_schema = json_schema
+
+ _tool = ChatCompletionToolParam(
+ type="function",
+ function=ChatCompletionToolParamFunctionChunk(
+ name=schema_name, parameters=_input_schema
+ ),
+ )
+ return _tool
+
def map_openai_params(
self,
model: str,
@@ -160,31 +197,20 @@ class AmazonConverseConfig:
- You should set tool_choice (see Forcing tool use) to instruct the model to explicitly use that tool
- Remember that the model will pass the input to the tool, so the name of the tool and description should be from the model’s perspective.
"""
- if json_schema is not None:
- _tool_choice = self.map_tool_choice_values(
- model=model, tool_choice="required", drop_params=drop_params # type: ignore
+ _tool_choice = {"name": schema_name, "type": "tool"}
+ _tool = self._create_json_tool_call_for_response_format(
+ json_schema=json_schema,
+ schema_name=schema_name if schema_name != "" else "json_tool_call",
+ )
+ optional_params["tools"] = [_tool]
+ optional_params["tool_choice"] = ToolChoiceValuesBlock(
+ tool=SpecificToolChoiceBlock(
+ name=schema_name if schema_name != "" else "json_tool_call"
)
-
- _tool = ChatCompletionToolParam(
- type="function",
- function=ChatCompletionToolParamFunctionChunk(
- name=schema_name, parameters=json_schema
- ),
- )
-
- optional_params["tools"] = [_tool]
- optional_params["tool_choice"] = _tool_choice
- optional_params["json_mode"] = True
- else:
- if litellm.drop_params is True or drop_params is True:
- pass
- else:
- raise litellm.utils.UnsupportedParamsError(
- message="Bedrock doesn't support response_format={}. To drop it from the call, set `litellm.drop_params = True.".format(
- value
- ),
- status_code=400,
- )
+ )
+ optional_params["json_mode"] = True
+ if non_default_params.get("stream", False) is True:
+ optional_params["fake_stream"] = True
if param == "max_tokens" or param == "max_completion_tokens":
optional_params["maxTokens"] = value
if param == "stream":
@@ -330,7 +356,6 @@ class AmazonConverseConfig:
print_verbose,
encoding,
) -> Union[ModelResponse, CustomStreamWrapper]:
-
## LOGGING
if logging_obj is not None:
logging_obj.post_call(
@@ -468,6 +493,9 @@ class AmazonConverseConfig:
AND "meta.llama3-2-11b-instruct-v1:0" -> "meta.llama3-2-11b-instruct-v1"
"""
+ if model.startswith("converse/"):
+ model = model.split("/")[1]
+
potential_region = model.split(".", 1)[0]
if potential_region in self._supported_cross_region_inference_region():
return model.split(".", 1)[1]
diff --git a/litellm/llms/bedrock/chat/invoke_handler.py b/litellm/llms/bedrock/chat/invoke_handler.py
index 7805f74dc3..3a72cf8057 100644
--- a/litellm/llms/bedrock/chat/invoke_handler.py
+++ b/litellm/llms/bedrock/chat/invoke_handler.py
@@ -182,6 +182,8 @@ async def make_call(
model: str,
messages: list,
logging_obj,
+ fake_stream: bool = False,
+ json_mode: Optional[bool] = False,
):
try:
if client is None:
@@ -193,13 +195,13 @@ async def make_call(
api_base,
headers=headers,
data=data,
- stream=True if "ai21" not in api_base else False,
+ stream=not fake_stream,
)
if response.status_code != 200:
raise BedrockError(status_code=response.status_code, message=response.text)
- if "ai21" in api_base:
+ if fake_stream:
model_response: (
ModelResponse
) = litellm.AmazonConverseConfig()._transform_response(
@@ -215,7 +217,9 @@ async def make_call(
print_verbose=litellm.print_verbose,
encoding=litellm.encoding,
) # type: ignore
- completion_stream: Any = MockResponseIterator(model_response=model_response)
+ completion_stream: Any = MockResponseIterator(
+ model_response=model_response, json_mode=json_mode
+ )
else:
decoder = AWSEventStreamDecoder(model=model)
completion_stream = decoder.aiter_bytes(
@@ -1028,6 +1032,7 @@ class BedrockLLM(BaseAWSLLM):
model=model,
messages=messages,
logging_obj=logging_obj,
+ fake_stream=True if "ai21" in api_base else False,
),
model=model,
custom_llm_provider="bedrock",
@@ -1271,21 +1276,58 @@ class AWSEventStreamDecoder:
class MockResponseIterator: # for returning ai21 streaming responses
- def __init__(self, model_response):
+ def __init__(self, model_response, json_mode: Optional[bool] = False):
self.model_response = model_response
+ self.json_mode = json_mode
self.is_done = False
# Sync iterator
def __iter__(self):
return self
- def _chunk_parser(self, chunk_data: ModelResponse) -> GChunk:
+ def _handle_json_mode_chunk(
+ self, text: str, tool_calls: Optional[List[ChatCompletionToolCallChunk]]
+ ) -> Tuple[str, Optional[ChatCompletionToolCallChunk]]:
+ """
+ If JSON mode is enabled, convert the tool call to a message.
+ Bedrock returns the JSON schema as part of the tool call
+ OpenAI returns the JSON schema as part of the content, this handles placing it in the content
+
+ Args:
+ text: str
+            tool_calls: Optional[List[ChatCompletionToolCallChunk]]
+ Returns:
+ Tuple[str, Optional[ChatCompletionToolCallChunk]]
+
+ text: The text to use in the content
+ tool_use: The ChatCompletionToolCallChunk to use in the chunk response
+ """
+ tool_use: Optional[ChatCompletionToolCallChunk] = None
+ if self.json_mode is True and tool_calls is not None:
+ message = litellm.AnthropicConfig()._convert_tool_response_to_message(
+ tool_calls=tool_calls
+ )
+ if message is not None:
+ text = message.content or ""
+ tool_use = None
+ elif tool_calls is not None and len(tool_calls) > 0:
+ tool_use = tool_calls[0]
+ return text, tool_use
+
+ def _chunk_parser(self, chunk_data: ModelResponse) -> GChunk:
try:
chunk_usage: litellm.Usage = getattr(chunk_data, "usage")
+ text = chunk_data.choices[0].message.content or "" # type: ignore
+ tool_use = None
+ if self.json_mode is True:
+ text, tool_use = self._handle_json_mode_chunk(
+ text=text,
+ tool_calls=chunk_data.choices[0].message.tool_calls, # type: ignore
+ )
processed_chunk = GChunk(
- text=chunk_data.choices[0].message.content or "", # type: ignore
- tool_use=None,
+ text=text,
+ tool_use=tool_use,
is_finished=True,
finish_reason=map_finish_reason(
finish_reason=chunk_data.choices[0].finish_reason or ""
@@ -1298,8 +1340,8 @@ class MockResponseIterator: # for returning ai21 streaming responses
index=0,
)
return processed_chunk
- except Exception:
- raise ValueError(f"Failed to decode chunk: {chunk_data}")
+ except Exception as e:
+ raise ValueError(f"Failed to decode chunk: {chunk_data}. Error: {e}")
def __next__(self):
if self.is_done:
diff --git a/litellm/llms/prompt_templates/factory.py b/litellm/llms/prompt_templates/factory.py
index bfd35ca475..2f55bb7bac 100644
--- a/litellm/llms/prompt_templates/factory.py
+++ b/litellm/llms/prompt_templates/factory.py
@@ -2162,8 +2162,9 @@ def stringify_json_tool_call_content(messages: List) -> List:
###### AMAZON BEDROCK #######
from litellm.types.llms.bedrock import ContentBlock as BedrockContentBlock
+from litellm.types.llms.bedrock import DocumentBlock as BedrockDocumentBlock
from litellm.types.llms.bedrock import ImageBlock as BedrockImageBlock
-from litellm.types.llms.bedrock import ImageSourceBlock as BedrockImageSourceBlock
+from litellm.types.llms.bedrock import SourceBlock as BedrockSourceBlock
from litellm.types.llms.bedrock import ToolBlock as BedrockToolBlock
from litellm.types.llms.bedrock import (
ToolChoiceValuesBlock as BedrockToolChoiceValuesBlock,
@@ -2210,7 +2211,9 @@ def get_image_details(image_url) -> Tuple[str, str]:
raise e
-def _process_bedrock_converse_image_block(image_url: str) -> BedrockImageBlock:
+def _process_bedrock_converse_image_block(
+ image_url: str,
+) -> BedrockContentBlock:
if "base64" in image_url:
# Case 1: Images with base64 encoding
import base64
@@ -2228,12 +2231,17 @@ def _process_bedrock_converse_image_block(image_url: str) -> BedrockImageBlock:
else:
mime_type = "image/jpeg"
image_format = "jpeg"
- _blob = BedrockImageSourceBlock(bytes=img_without_base_64)
+ _blob = BedrockSourceBlock(bytes=img_without_base_64)
supported_image_formats = (
litellm.AmazonConverseConfig().get_supported_image_types()
)
+ supported_document_types = (
+ litellm.AmazonConverseConfig().get_supported_document_types()
+ )
if image_format in supported_image_formats:
- return BedrockImageBlock(source=_blob, format=image_format) # type: ignore
+ return BedrockContentBlock(image=BedrockImageBlock(source=_blob, format=image_format)) # type: ignore
+ elif image_format in supported_document_types:
+ return BedrockContentBlock(document=BedrockDocumentBlock(source=_blob, format=image_format, name="DocumentPDFmessages_{}".format(str(uuid.uuid4())))) # type: ignore
else:
# Handle the case when the image format is not supported
raise ValueError(
@@ -2244,12 +2252,17 @@ def _process_bedrock_converse_image_block(image_url: str) -> BedrockImageBlock:
elif "https:/" in image_url:
# Case 2: Images with direct links
image_bytes, image_format = get_image_details(image_url)
- _blob = BedrockImageSourceBlock(bytes=image_bytes)
+ _blob = BedrockSourceBlock(bytes=image_bytes)
supported_image_formats = (
litellm.AmazonConverseConfig().get_supported_image_types()
)
+ supported_document_types = (
+ litellm.AmazonConverseConfig().get_supported_document_types()
+ )
if image_format in supported_image_formats:
- return BedrockImageBlock(source=_blob, format=image_format) # type: ignore
+ return BedrockContentBlock(image=BedrockImageBlock(source=_blob, format=image_format)) # type: ignore
+ elif image_format in supported_document_types:
+ return BedrockContentBlock(document=BedrockDocumentBlock(source=_blob, format=image_format, name="DocumentPDFmessages_{}".format(str(uuid.uuid4())))) # type: ignore
else:
# Handle the case when the image format is not supported
raise ValueError(
@@ -2464,11 +2477,14 @@ def _bedrock_converse_messages_pt( # noqa: PLR0915
_part = BedrockContentBlock(text=element["text"])
_parts.append(_part)
elif element["type"] == "image_url":
- image_url = element["image_url"]["url"]
+ if isinstance(element["image_url"], dict):
+ image_url = element["image_url"]["url"]
+ else:
+ image_url = element["image_url"]
_part = _process_bedrock_converse_image_block( # type: ignore
image_url=image_url
)
- _parts.append(BedrockContentBlock(image=_part)) # type: ignore
+ _parts.append(_part) # type: ignore
user_content.extend(_parts)
else:
_part = BedrockContentBlock(text=messages[msg_i]["content"])
@@ -2539,13 +2555,14 @@ def _bedrock_converse_messages_pt( # noqa: PLR0915
assistants_part = BedrockContentBlock(text=element["text"])
assistants_parts.append(assistants_part)
elif element["type"] == "image_url":
- image_url = element["image_url"]["url"]
+ if isinstance(element["image_url"], dict):
+ image_url = element["image_url"]["url"]
+ else:
+ image_url = element["image_url"]
assistants_part = _process_bedrock_converse_image_block( # type: ignore
image_url=image_url
)
- assistants_parts.append(
- BedrockContentBlock(image=assistants_part) # type: ignore
- )
+ assistants_parts.append(assistants_part)
assistant_content.extend(assistants_parts)
elif messages[msg_i].get("content", None) is not None and isinstance(
messages[msg_i]["content"], str
diff --git a/litellm/main.py b/litellm/main.py
index a32e8b6c05..d17c73c91e 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -2603,7 +2603,10 @@ def completion( # type: ignore # noqa: PLR0915
base_model = litellm.AmazonConverseConfig()._get_base_model(model)
- if base_model in litellm.BEDROCK_CONVERSE_MODELS:
+ if base_model in litellm.bedrock_converse_models or model.startswith(
+ "converse/"
+ ):
+ model = model.replace("converse/", "")
response = bedrock_converse_chat_completion.completion(
model=model,
messages=messages,
@@ -2622,6 +2625,7 @@ def completion( # type: ignore # noqa: PLR0915
api_base=api_base,
)
else:
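+            # strip the optional 'invoke/' route prefix before calling the bedrock invoke handler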
+ model = model.replace("invoke/", "")
response = bedrock_chat_completion.completion(
model=model,
messages=messages,
diff --git a/litellm/model_prices_and_context_window_backup.json b/litellm/model_prices_and_context_window_backup.json
index ac22871bcc..4d89091c4d 100644
--- a/litellm/model_prices_and_context_window_backup.json
+++ b/litellm/model_prices_and_context_window_backup.json
@@ -4795,6 +4795,42 @@
"mode": "chat",
"supports_function_calling": true
},
+ "amazon.nova-micro-v1:0": {
+ "max_tokens": 4096,
+ "max_input_tokens": 300000,
+ "max_output_tokens": 4096,
+ "input_cost_per_token": 0.000000035,
+ "output_cost_per_token": 0.00000014,
+ "litellm_provider": "bedrock_converse",
+ "mode": "chat",
+ "supports_function_calling": true,
+ "supports_vision": true,
+ "supports_pdf_input": true
+ },
+ "amazon.nova-lite-v1:0": {
+ "max_tokens": 4096,
+ "max_input_tokens": 128000,
+ "max_output_tokens": 4096,
+ "input_cost_per_token": 0.00000006,
+ "output_cost_per_token": 0.00000024,
+ "litellm_provider": "bedrock_converse",
+ "mode": "chat",
+ "supports_function_calling": true,
+ "supports_vision": true,
+ "supports_pdf_input": true
+ },
+ "amazon.nova-pro-v1:0": {
+ "max_tokens": 4096,
+ "max_input_tokens": 300000,
+ "max_output_tokens": 4096,
+ "input_cost_per_token": 0.0000008,
+ "output_cost_per_token": 0.0000032,
+ "litellm_provider": "bedrock_converse",
+ "mode": "chat",
+ "supports_function_calling": true,
+ "supports_vision": true,
+ "supports_pdf_input": true
+ },
"anthropic.claude-3-sonnet-20240229-v1:0": {
"max_tokens": 4096,
"max_input_tokens": 200000,
diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml
index b9fae4d254..f2569f2f28 100644
--- a/litellm/proxy/_new_secret_config.yaml
+++ b/litellm/proxy/_new_secret_config.yaml
@@ -3,6 +3,12 @@ model_list:
- model_name: gpt-4
litellm_params:
model: gpt-4
+ rpm: 1
+ - model_name: gpt-4
+ litellm_params:
+ model: azure/chatgpt-v-2
+ api_key: os.environ/AZURE_API_KEY
+ api_base: os.environ/AZURE_API_BASE
- model_name: rerank-model
litellm_params:
model: jina_ai/jina-reranker-v2-base-multilingual
diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py
index d2b417c9d4..d170cd92b6 100644
--- a/litellm/proxy/_types.py
+++ b/litellm/proxy/_types.py
@@ -667,6 +667,7 @@ class _GenerateKeyRequest(GenerateRequestBase):
class GenerateKeyRequest(_GenerateKeyRequest):
tags: Optional[List[str]] = None
+ enforced_params: Optional[List[str]] = None
class GenerateKeyResponse(_GenerateKeyRequest):
@@ -2190,4 +2191,5 @@ LiteLLM_ManagementEndpoint_MetadataFields = [
"model_tpm_limit",
"guardrails",
"tags",
+ "enforced_params",
]
diff --git a/litellm/proxy/auth/auth_checks.py b/litellm/proxy/auth/auth_checks.py
index 21a25c8c1a..179a6fcf54 100644
--- a/litellm/proxy/auth/auth_checks.py
+++ b/litellm/proxy/auth/auth_checks.py
@@ -156,40 +156,6 @@ def common_checks( # noqa: PLR0915
raise Exception(
f"'user' param not passed in. 'enforce_user_param'={general_settings['enforce_user_param']}"
)
- if general_settings.get("enforced_params", None) is not None:
- # Enterprise ONLY Feature
- # we already validate if user is premium_user when reading the config
- # Add an extra premium_usercheck here too, just incase
- from litellm.proxy.proxy_server import CommonProxyErrors, premium_user
-
- if premium_user is not True:
- raise ValueError(
- "Trying to use `enforced_params`"
- + CommonProxyErrors.not_premium_user.value
- )
-
- if RouteChecks.is_llm_api_route(route=route):
- # loop through each enforced param
- # example enforced_params ['user', 'metadata', 'metadata.generation_name']
- for enforced_param in general_settings["enforced_params"]:
- _enforced_params = enforced_param.split(".")
- if len(_enforced_params) == 1:
- if _enforced_params[0] not in request_body:
- raise ValueError(
- f"BadRequest please pass param={_enforced_params[0]} in request body. This is a required param"
- )
- elif len(_enforced_params) == 2:
- # this is a scenario where user requires request['metadata']['generation_name'] to exist
- if _enforced_params[0] not in request_body:
- raise ValueError(
- f"BadRequest please pass param={_enforced_params[0]} in request body. This is a required param"
- )
- if _enforced_params[1] not in request_body[_enforced_params[0]]:
- raise ValueError(
- f"BadRequest please pass param=[{_enforced_params[0]}][{_enforced_params[1]}] in request body. This is a required param"
- )
-
- pass
# 7. [OPTIONAL] If 'litellm.max_budget' is set (>0), is proxy under budget
if (
litellm.max_budget > 0
diff --git a/litellm/proxy/litellm_pre_call_utils.py b/litellm/proxy/litellm_pre_call_utils.py
index 9e47c20d6f..f690f517c7 100644
--- a/litellm/proxy/litellm_pre_call_utils.py
+++ b/litellm/proxy/litellm_pre_call_utils.py
@@ -587,6 +587,16 @@ async def add_litellm_data_to_request( # noqa: PLR0915
f"[PROXY]returned data from litellm_pre_call_utils: {data}"
)
+ ## ENFORCED PARAMS CHECK
+ # loop through each enforced param
+ # example enforced_params ['user', 'metadata', 'metadata.generation_name']
+ _enforced_params_check(
+ request_body=data,
+ general_settings=general_settings,
+ user_api_key_dict=user_api_key_dict,
+ premium_user=premium_user,
+ )
+
end_time = time.time()
await service_logger_obj.async_service_success_hook(
service=ServiceTypes.PROXY_PRE_CALL,
@@ -599,6 +609,64 @@ async def add_litellm_data_to_request( # noqa: PLR0915
return data
+def _get_enforced_params(
+ general_settings: Optional[dict], user_api_key_dict: UserAPIKeyAuth
+) -> Optional[list]:
+ enforced_params: Optional[list] = None
+ if general_settings is not None:
+ enforced_params = general_settings.get("enforced_params")
+ if "service_account_settings" in general_settings:
+ service_account_settings = general_settings["service_account_settings"]
+ if "enforced_params" in service_account_settings:
+ if enforced_params is None:
+ enforced_params = []
+ enforced_params.extend(service_account_settings["enforced_params"])
+ if user_api_key_dict.metadata.get("enforced_params", None) is not None:
+ if enforced_params is None:
+ enforced_params = []
+ enforced_params.extend(user_api_key_dict.metadata["enforced_params"])
+ return enforced_params
+
+
+def _enforced_params_check(
+ request_body: dict,
+ general_settings: Optional[dict],
+ user_api_key_dict: UserAPIKeyAuth,
+ premium_user: bool,
+) -> bool:
+ """
+ If enforced params are set, check if the request body contains the enforced params.
+ """
+ enforced_params: Optional[list] = _get_enforced_params(
+ general_settings=general_settings, user_api_key_dict=user_api_key_dict
+ )
+ if enforced_params is None:
+ return True
+ if enforced_params is not None and premium_user is not True:
+ raise ValueError(
+ f"Enforced Params is an Enterprise feature. Enforced Params: {enforced_params}. {CommonProxyErrors.not_premium_user.value}"
+ )
+
+ for enforced_param in enforced_params:
+ _enforced_params = enforced_param.split(".")
+ if len(_enforced_params) == 1:
+ if _enforced_params[0] not in request_body:
+ raise ValueError(
+ f"BadRequest please pass param={_enforced_params[0]} in request body. This is a required param"
+ )
+ elif len(_enforced_params) == 2:
+ # this is a scenario where user requires request['metadata']['generation_name'] to exist
+ if _enforced_params[0] not in request_body:
+ raise ValueError(
+ f"BadRequest please pass param={_enforced_params[0]} in request body. This is a required param"
+ )
+ if _enforced_params[1] not in request_body[_enforced_params[0]]:
+ raise ValueError(
+ f"BadRequest please pass param=[{_enforced_params[0]}][{_enforced_params[1]}] in request body. This is a required param"
+ )
+ return True
+
+
def move_guardrails_to_metadata(
data: dict,
_metadata_variable_name: str,
diff --git a/litellm/proxy/management_endpoints/key_management_endpoints.py b/litellm/proxy/management_endpoints/key_management_endpoints.py
index 287de56964..ee1b9bd8b3 100644
--- a/litellm/proxy/management_endpoints/key_management_endpoints.py
+++ b/litellm/proxy/management_endpoints/key_management_endpoints.py
@@ -255,6 +255,7 @@ async def generate_key_fn( # noqa: PLR0915
- tpm_limit: Optional[int] - Specify tpm limit for a given key (Tokens per minute)
- soft_budget: Optional[float] - Specify soft budget for a given key. Will trigger a slack alert when this soft budget is reached.
- tags: Optional[List[str]] - Tags for [tracking spend](https://litellm.vercel.app/docs/proxy/enterprise#tracking-spend-for-custom-tags) and/or doing [tag-based routing](https://litellm.vercel.app/docs/proxy/tag_routing).
+ - enforced_params: Optional[List[str]] - List of enforced params for the key (Enterprise only). [Docs](https://docs.litellm.ai/docs/proxy/enterprise#enforce-required-params-for-llm-requests)
Examples:
@@ -459,7 +460,6 @@ def prepare_metadata_fields(
"""
Check LiteLLM_ManagementEndpoint_MetadataFields (proxy/_types.py) for fields that are allowed to be updated
"""
-
if "metadata" not in non_default_values: # allow user to set metadata to none
non_default_values["metadata"] = existing_metadata.copy()
@@ -469,18 +469,8 @@ def prepare_metadata_fields(
try:
for k, v in data_json.items():
- if k == "model_tpm_limit" or k == "model_rpm_limit":
- if k not in casted_metadata or casted_metadata[k] is None:
- casted_metadata[k] = {}
- casted_metadata[k].update(v)
-
- if k == "tags" or k == "guardrails":
- if k not in casted_metadata or casted_metadata[k] is None:
- casted_metadata[k] = []
- seen = set(casted_metadata[k])
- casted_metadata[k].extend(
- x for x in v if x not in seen and not seen.add(x) # type: ignore
- ) # prevent duplicates from being added + maintain initial order
+ if k in LiteLLM_ManagementEndpoint_MetadataFields:
+ casted_metadata[k] = v
except Exception as e:
verbose_proxy_logger.exception(
@@ -498,10 +488,9 @@ def prepare_key_update_data(
):
data_json: dict = data.model_dump(exclude_unset=True)
data_json.pop("key", None)
- _metadata_fields = ["model_rpm_limit", "model_tpm_limit", "guardrails", "tags"]
non_default_values = {}
for k, v in data_json.items():
- if k in _metadata_fields:
+ if k in LiteLLM_ManagementEndpoint_MetadataFields:
continue
non_default_values[k] = v
@@ -556,6 +545,7 @@ async def update_key_fn(
- team_id: Optional[str] - Team ID associated with key
- models: Optional[list] - Model_name's a user is allowed to call
- tags: Optional[List[str]] - Tags for organizing keys (Enterprise only)
+ - enforced_params: Optional[List[str]] - List of enforced params for the key (Enterprise only). [Docs](https://docs.litellm.ai/docs/proxy/enterprise#enforce-required-params-for-llm-requests)
- spend: Optional[float] - Amount spent by key
- max_budget: Optional[float] - Max budget for key
- model_max_budget: Optional[dict] - Model-specific budgets {"gpt-4": 0.5, "claude-v1": 1.0}
diff --git a/litellm/router.py b/litellm/router.py
index bbee4d7bb2..800c7a9b95 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -2940,6 +2940,7 @@ class Router:
remaining_retries=num_retries,
num_retries=num_retries,
healthy_deployments=_healthy_deployments,
+ all_deployments=_all_deployments,
)
await asyncio.sleep(retry_after)
@@ -2972,6 +2973,7 @@ class Router:
remaining_retries=remaining_retries,
num_retries=num_retries,
healthy_deployments=_healthy_deployments,
+ all_deployments=_all_deployments,
)
await asyncio.sleep(_timeout)
@@ -3149,6 +3151,7 @@ class Router:
remaining_retries: int,
num_retries: int,
healthy_deployments: Optional[List] = None,
+ all_deployments: Optional[List] = None,
) -> Union[int, float]:
"""
Calculate back-off, then retry
@@ -3157,10 +3160,14 @@ class Router:
1. there are healthy deployments in the same model group
2. there are fallbacks for the completion call
"""
- if (
+
+ ## base case - single deployment
+ if all_deployments is not None and len(all_deployments) == 1:
+ pass
+ elif (
healthy_deployments is not None
and isinstance(healthy_deployments, list)
- and len(healthy_deployments) > 1
+ and len(healthy_deployments) > 0
):
return 0
@@ -3242,6 +3249,7 @@ class Router:
remaining_retries=num_retries,
num_retries=num_retries,
healthy_deployments=_healthy_deployments,
+ all_deployments=_all_deployments,
)
## LOGGING
@@ -3276,6 +3284,7 @@ class Router:
remaining_retries=remaining_retries,
num_retries=num_retries,
healthy_deployments=_healthy_deployments,
+ all_deployments=_all_deployments,
)
time.sleep(_timeout)
diff --git a/litellm/types/llms/bedrock.py b/litellm/types/llms/bedrock.py
index c80b16f6e9..88f329adeb 100644
--- a/litellm/types/llms/bedrock.py
+++ b/litellm/types/llms/bedrock.py
@@ -18,17 +18,24 @@ class SystemContentBlock(TypedDict):
text: str
-class ImageSourceBlock(TypedDict):
+class SourceBlock(TypedDict):
bytes: Optional[str] # base 64 encoded string
class ImageBlock(TypedDict):
format: Literal["png", "jpeg", "gif", "webp"]
- source: ImageSourceBlock
+ source: SourceBlock
+
+
+class DocumentBlock(TypedDict):
+ format: Literal["pdf", "csv", "doc", "docx", "xls", "xlsx", "html", "txt", "md"]
+ source: SourceBlock
+ name: str
class ToolResultContentBlock(TypedDict, total=False):
image: ImageBlock
+ document: DocumentBlock
json: dict
text: str
@@ -48,6 +55,7 @@ class ToolUseBlock(TypedDict):
class ContentBlock(TypedDict, total=False):
text: str
image: ImageBlock
+ document: DocumentBlock
toolResult: ToolResultBlock
toolUse: ToolUseBlock
diff --git a/litellm/types/utils.py b/litellm/types/utils.py
index 93b4a39d3b..3467e1a107 100644
--- a/litellm/types/utils.py
+++ b/litellm/types/utils.py
@@ -106,6 +106,7 @@ class ModelInfo(TypedDict, total=False):
supports_prompt_caching: Optional[bool]
supports_audio_input: Optional[bool]
supports_audio_output: Optional[bool]
+ supports_pdf_input: Optional[bool]
tpm: Optional[int]
rpm: Optional[int]
diff --git a/litellm/utils.py b/litellm/utils.py
index 946d81982b..86c0a60294 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -3142,7 +3142,7 @@ def get_optional_params( # noqa: PLR0915
model=model, custom_llm_provider=custom_llm_provider
)
base_model = litellm.AmazonConverseConfig()._get_base_model(model)
- if base_model in litellm.BEDROCK_CONVERSE_MODELS:
+ if base_model in litellm.bedrock_converse_models:
_check_valid_arg(supported_params=supported_params)
optional_params = litellm.AmazonConverseConfig().map_openai_params(
model=model,
@@ -4308,6 +4308,10 @@ def _strip_stable_vertex_version(model_name) -> str:
return re.sub(r"-\d+$", "", model_name)
+def _strip_bedrock_region(model_name) -> str:
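+    """Map a Bedrock model id to its base model (e.g. strip cross-region prefixes like 'us.') for model-info lookups."""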
+ return litellm.AmazonConverseConfig()._get_base_model(model_name)
+
+
def _strip_openai_finetune_model_name(model_name: str) -> str:
"""
Strips the organization, custom suffix, and ID from an OpenAI fine-tuned model name.
@@ -4324,16 +4328,50 @@ def _strip_openai_finetune_model_name(model_name: str) -> str:
return re.sub(r"(:[^:]+){3}$", "", model_name)
-def _strip_model_name(model: str) -> str:
- strip_version = _strip_stable_vertex_version(model_name=model)
- strip_finetune = _strip_openai_finetune_model_name(model_name=strip_version)
- return strip_finetune
+def _strip_model_name(model: str, custom_llm_provider: Optional[str]) -> str:
+ if custom_llm_provider and custom_llm_provider == "bedrock":
+ strip_bedrock_region = _strip_bedrock_region(model_name=model)
+ return strip_bedrock_region
+ elif custom_llm_provider and (
+ custom_llm_provider == "vertex_ai" or custom_llm_provider == "gemini"
+ ):
+ strip_version = _strip_stable_vertex_version(model_name=model)
+ return strip_version
+ else:
+ strip_finetune = _strip_openai_finetune_model_name(model_name=model)
+ return strip_finetune
def _get_model_info_from_model_cost(key: str) -> dict:
return litellm.model_cost[key]
+def _check_provider_match(model_info: dict, custom_llm_provider: Optional[str]) -> bool:
+ """
+ Check if the model info provider matches the custom provider.
+ """
+ if custom_llm_provider and (
+ "litellm_provider" in model_info
+ and model_info["litellm_provider"] != custom_llm_provider
+ ):
+ if custom_llm_provider == "vertex_ai" and model_info[
+ "litellm_provider"
+ ].startswith("vertex_ai"):
+ return True
+ elif custom_llm_provider == "fireworks_ai" and model_info[
+ "litellm_provider"
+ ].startswith("fireworks_ai"):
+ return True
+ elif custom_llm_provider == "bedrock" and model_info[
+ "litellm_provider"
+ ].startswith("bedrock"):
+ return True
+ else:
+ return False
+
+ return True
+
+
def get_model_info( # noqa: PLR0915
model: str, custom_llm_provider: Optional[str] = None
) -> ModelInfo:
@@ -4388,6 +4426,7 @@ def get_model_info( # noqa: PLR0915
supports_prompt_caching: Optional[bool]
supports_audio_input: Optional[bool]
supports_audio_output: Optional[bool]
+ supports_pdf_input: Optional[bool]
Raises:
Exception: If the model is not mapped yet.
@@ -4445,15 +4484,21 @@ def get_model_info( # noqa: PLR0915
except Exception:
split_model = model
combined_model_name = model
- stripped_model_name = _strip_model_name(model=model)
+ stripped_model_name = _strip_model_name(
+ model=model, custom_llm_provider=custom_llm_provider
+ )
combined_stripped_model_name = stripped_model_name
else:
split_model = model
combined_model_name = "{}/{}".format(custom_llm_provider, model)
- stripped_model_name = _strip_model_name(model=model)
- combined_stripped_model_name = "{}/{}".format(
- custom_llm_provider, _strip_model_name(model=model)
+ stripped_model_name = _strip_model_name(
+ model=model, custom_llm_provider=custom_llm_provider
)
+ combined_stripped_model_name = "{}/{}".format(
+ custom_llm_provider,
+ _strip_model_name(model=model, custom_llm_provider=custom_llm_provider),
+ )
+
#########################
supported_openai_params = litellm.get_supported_openai_params(
@@ -4476,6 +4521,7 @@ def get_model_info( # noqa: PLR0915
supports_function_calling=None,
supports_assistant_prefill=None,
supports_prompt_caching=None,
+ supports_pdf_input=None,
)
elif custom_llm_provider == "ollama" or custom_llm_provider == "ollama_chat":
return litellm.OllamaConfig().get_model_info(model)
@@ -4488,40 +4534,25 @@ def get_model_info( # noqa: PLR0915
4. 'stripped_model_name' in litellm.model_cost. Checks if 'ft:gpt-3.5-turbo' in model map, if 'ft:gpt-3.5-turbo:my-org:custom_suffix:id' given.
5. 'split_model' in litellm.model_cost. Checks "llama3-8b-8192" in litellm.model_cost if model="groq/llama3-8b-8192"
"""
+
_model_info: Optional[Dict[str, Any]] = None
key: Optional[str] = None
if combined_model_name in litellm.model_cost:
key = combined_model_name
_model_info = _get_model_info_from_model_cost(key=key)
_model_info["supported_openai_params"] = supported_openai_params
- if (
- "litellm_provider" in _model_info
- and _model_info["litellm_provider"] != custom_llm_provider
+ if not _check_provider_match(
+ model_info=_model_info, custom_llm_provider=custom_llm_provider
):
- if custom_llm_provider == "vertex_ai" and _model_info[
- "litellm_provider"
- ].startswith("vertex_ai"):
- pass
- else:
- _model_info = None
+ _model_info = None
if _model_info is None and model in litellm.model_cost:
key = model
_model_info = _get_model_info_from_model_cost(key=key)
_model_info["supported_openai_params"] = supported_openai_params
- if (
- "litellm_provider" in _model_info
- and _model_info["litellm_provider"] != custom_llm_provider
+ if not _check_provider_match(
+ model_info=_model_info, custom_llm_provider=custom_llm_provider
):
- if custom_llm_provider == "vertex_ai" and _model_info[
- "litellm_provider"
- ].startswith("vertex_ai"):
- pass
- elif custom_llm_provider == "fireworks_ai" and _model_info[
- "litellm_provider"
- ].startswith("fireworks_ai"):
- pass
- else:
- _model_info = None
+ _model_info = None
if (
_model_info is None
and combined_stripped_model_name in litellm.model_cost
@@ -4529,57 +4560,26 @@ def get_model_info( # noqa: PLR0915
key = combined_stripped_model_name
_model_info = _get_model_info_from_model_cost(key=key)
_model_info["supported_openai_params"] = supported_openai_params
- if (
- "litellm_provider" in _model_info
- and _model_info["litellm_provider"] != custom_llm_provider
+ if not _check_provider_match(
+ model_info=_model_info, custom_llm_provider=custom_llm_provider
):
- if custom_llm_provider == "vertex_ai" and _model_info[
- "litellm_provider"
- ].startswith("vertex_ai"):
- pass
- elif custom_llm_provider == "fireworks_ai" and _model_info[
- "litellm_provider"
- ].startswith("fireworks_ai"):
- pass
- else:
- _model_info = None
+ _model_info = None
if _model_info is None and stripped_model_name in litellm.model_cost:
key = stripped_model_name
_model_info = _get_model_info_from_model_cost(key=key)
_model_info["supported_openai_params"] = supported_openai_params
- if (
- "litellm_provider" in _model_info
- and _model_info["litellm_provider"] != custom_llm_provider
+ if not _check_provider_match(
+ model_info=_model_info, custom_llm_provider=custom_llm_provider
):
- if custom_llm_provider == "vertex_ai" and _model_info[
- "litellm_provider"
- ].startswith("vertex_ai"):
- pass
- elif custom_llm_provider == "fireworks_ai" and _model_info[
- "litellm_provider"
- ].startswith("fireworks_ai"):
- pass
- else:
- _model_info = None
-
+ _model_info = None
if _model_info is None and split_model in litellm.model_cost:
key = split_model
_model_info = _get_model_info_from_model_cost(key=key)
_model_info["supported_openai_params"] = supported_openai_params
- if (
- "litellm_provider" in _model_info
- and _model_info["litellm_provider"] != custom_llm_provider
+ if not _check_provider_match(
+ model_info=_model_info, custom_llm_provider=custom_llm_provider
):
- if custom_llm_provider == "vertex_ai" and _model_info[
- "litellm_provider"
- ].startswith("vertex_ai"):
- pass
- elif custom_llm_provider == "fireworks_ai" and _model_info[
- "litellm_provider"
- ].startswith("fireworks_ai"):
- pass
- else:
- _model_info = None
+ _model_info = None
if _model_info is None or key is None:
raise ValueError(
"This model isn't mapped yet. Add it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json"
@@ -4675,6 +4675,7 @@ def get_model_info( # noqa: PLR0915
),
supports_audio_input=_model_info.get("supports_audio_input", False),
supports_audio_output=_model_info.get("supports_audio_output", False),
+ supports_pdf_input=_model_info.get("supports_pdf_input", False),
tpm=_model_info.get("tpm", None),
rpm=_model_info.get("rpm", None),
)
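
The repeated provider-mismatch branches below are collapsed into `_check_provider_match`. A standalone restatement of the same rule — exact provider match, or a `vertex_ai` / `fireworks_ai` / `bedrock` family prefix — kept separate from litellm so it runs on its own:

```python
from typing import Optional

_PREFIX_FAMILIES = ("vertex_ai", "fireworks_ai", "bedrock")  # per the diff above


def provider_matches(model_info: dict, custom_llm_provider: Optional[str]) -> bool:
    mapped = model_info.get("litellm_provider")
    if custom_llm_provider is None or mapped is None:
        return True  # nothing to compare against
    if mapped == custom_llm_provider:
        return True
    # "bedrock" also matches "bedrock_converse", "vertex_ai" matches "vertex_ai-language-models", etc.
    return custom_llm_provider in _PREFIX_FAMILIES and mapped.startswith(custom_llm_provider)


print(provider_matches({"litellm_provider": "bedrock_converse"}, "bedrock"))  # True
print(provider_matches({"litellm_provider": "openai"}, "bedrock"))            # False
```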
@@ -6195,11 +6196,21 @@ class ProviderConfigManager:
return OpenAIGPTConfig()
-def get_end_user_id_for_cost_tracking(litellm_params: dict) -> Optional[str]:
+def get_end_user_id_for_cost_tracking(
+ litellm_params: dict,
+ service_type: Literal["litellm_logging", "prometheus"] = "litellm_logging",
+) -> Optional[str]:
"""
Used for enforcing `disable_end_user_cost_tracking` param.
+
+ service_type: "litellm_logging" or "prometheus" - used to allow prometheus only disable cost tracking.
"""
proxy_server_request = litellm_params.get("proxy_server_request") or {}
if litellm.disable_end_user_cost_tracking:
return None
+ if (
+ service_type == "prometheus"
+ and litellm.disable_end_user_cost_tracking_prometheus_only
+ ):
+ return None
return proxy_server_request.get("body", {}).get("user", None)
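
The `service_type` argument above lets the proxy withhold end-user IDs from Prometheus labels while keeping them for the rest of logging. A self-contained sketch of that gate, with module-level flags standing in for the two litellm settings:

```python
from typing import Literal, Optional

disable_end_user_cost_tracking = False
disable_end_user_cost_tracking_prometheus_only = True


def end_user_id_for_cost_tracking(
    litellm_params: dict,
    service_type: Literal["litellm_logging", "prometheus"] = "litellm_logging",
) -> Optional[str]:
    request = litellm_params.get("proxy_server_request") or {}
    if disable_end_user_cost_tracking:
        return None  # tracking disabled everywhere
    if service_type == "prometheus" and disable_end_user_cost_tracking_prometheus_only:
        return None  # tracking withheld from Prometheus only
    return request.get("body", {}).get("user", None)


params = {"proxy_server_request": {"body": {"user": "end-user-123"}}}
print(end_user_id_for_cost_tracking(params))                             # end-user-123
print(end_user_id_for_cost_tracking(params, service_type="prometheus"))  # None
```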
diff --git a/model_prices_and_context_window.json b/model_prices_and_context_window.json
index ac22871bcc..4d89091c4d 100644
--- a/model_prices_and_context_window.json
+++ b/model_prices_and_context_window.json
@@ -4795,6 +4795,42 @@
"mode": "chat",
"supports_function_calling": true
},
+ "amazon.nova-micro-v1:0": {
+ "max_tokens": 4096,
+ "max_input_tokens": 300000,
+ "max_output_tokens": 4096,
+ "input_cost_per_token": 0.000000035,
+ "output_cost_per_token": 0.00000014,
+ "litellm_provider": "bedrock_converse",
+ "mode": "chat",
+ "supports_function_calling": true,
+ "supports_vision": true,
+ "supports_pdf_input": true
+ },
+ "amazon.nova-lite-v1:0": {
+ "max_tokens": 4096,
+ "max_input_tokens": 128000,
+ "max_output_tokens": 4096,
+ "input_cost_per_token": 0.00000006,
+ "output_cost_per_token": 0.00000024,
+ "litellm_provider": "bedrock_converse",
+ "mode": "chat",
+ "supports_function_calling": true,
+ "supports_vision": true,
+ "supports_pdf_input": true
+ },
+ "amazon.nova-pro-v1:0": {
+ "max_tokens": 4096,
+ "max_input_tokens": 300000,
+ "max_output_tokens": 4096,
+ "input_cost_per_token": 0.0000008,
+ "output_cost_per_token": 0.0000032,
+ "litellm_provider": "bedrock_converse",
+ "mode": "chat",
+ "supports_function_calling": true,
+ "supports_vision": true,
+ "supports_pdf_input": true
+ },
"anthropic.claude-3-sonnet-20240229-v1:0": {
"max_tokens": 4096,
"max_input_tokens": 200000,
diff --git a/tests/llm_translation/base_llm_unit_tests.py b/tests/llm_translation/base_llm_unit_tests.py
index a4551378eb..5004d45994 100644
--- a/tests/llm_translation/base_llm_unit_tests.py
+++ b/tests/llm_translation/base_llm_unit_tests.py
@@ -54,6 +54,36 @@ class BaseLLMChatTest(ABC):
# for OpenAI the content contains the JSON schema, so we need to assert that the content is not None
assert response.choices[0].message.content is not None
+ @pytest.mark.parametrize("image_url", ["str", "dict"])
+ def test_pdf_handling(self, pdf_messages, image_url):
+ from litellm.utils import supports_pdf_input
+
+ if image_url == "str":
+ image_url = pdf_messages
+ elif image_url == "dict":
+ image_url = {"url": pdf_messages}
+
+ image_content = [
+ {"type": "text", "text": "What's this file about?"},
+ {
+ "type": "image_url",
+ "image_url": image_url,
+ },
+ ]
+
+ image_messages = [{"role": "user", "content": image_content}]
+
+ base_completion_call_args = self.get_base_completion_call_args()
+
+ if not supports_pdf_input(base_completion_call_args["model"], None):
+ pytest.skip("Model does not support image input")
+
+ response = litellm.completion(
+ **base_completion_call_args,
+ messages=image_messages,
+ )
+ assert response is not None
+
def test_message_with_name(self):
litellm.set_verbose = True
base_completion_call_args = self.get_base_completion_call_args()
@@ -187,7 +217,7 @@ class BaseLLMChatTest(ABC):
for chunk in response:
content += chunk.choices[0].delta.content or ""
- print("content=", content)
+ print(f"content={content}")
# OpenAI guarantees that the JSON schema is returned in the content
# relevant issue: https://github.com/BerriAI/litellm/issues/6741
@@ -250,7 +280,7 @@ class BaseLLMChatTest(ABC):
import requests
# URL of the file
- url = "https://storage.googleapis.com/cloud-samples-data/generative-ai/pdf/2403.05530.pdf"
+ url = "https://www.w3.org/WAI/ER/tests/xhtml/testfiles/resources/pdf/dummy.pdf"
response = requests.get(url)
file_data = response.content
@@ -258,14 +288,4 @@ class BaseLLMChatTest(ABC):
encoded_file = base64.b64encode(file_data).decode("utf-8")
url = f"data:application/pdf;base64,{encoded_file}"
- image_content = [
- {"type": "text", "text": "What's this file about?"},
- {
- "type": "image_url",
- "image_url": {"url": url},
- },
- ]
-
- image_messages = [{"role": "user", "content": image_content}]
-
- return image_messages
+ return url
diff --git a/tests/llm_translation/test_anthropic_completion.py b/tests/llm_translation/test_anthropic_completion.py
index a61bdfd731..b5a10953d4 100644
--- a/tests/llm_translation/test_anthropic_completion.py
+++ b/tests/llm_translation/test_anthropic_completion.py
@@ -668,35 +668,6 @@ class TestAnthropicCompletion(BaseLLMChatTest):
def get_base_completion_call_args(self) -> dict:
return {"model": "claude-3-haiku-20240307"}
- def test_pdf_handling(self, pdf_messages):
- from litellm.llms.custom_httpx.http_handler import HTTPHandler
- from litellm.types.llms.anthropic import AnthropicMessagesDocumentParam
- import json
-
- client = HTTPHandler()
-
- with patch.object(client, "post", new=MagicMock()) as mock_client:
- response = completion(
- model="claude-3-5-sonnet-20241022",
- messages=pdf_messages,
- client=client,
- )
-
- mock_client.assert_called_once()
-
- json_data = json.loads(mock_client.call_args.kwargs["data"])
- headers = mock_client.call_args.kwargs["headers"]
-
- assert headers["anthropic-beta"] == "pdfs-2024-09-25"
-
- json_data["messages"][0]["role"] == "user"
- _document_validation = AnthropicMessagesDocumentParam(
- **json_data["messages"][0]["content"][1]
- )
- assert _document_validation["type"] == "document"
- assert _document_validation["source"]["media_type"] == "application/pdf"
- assert _document_validation["source"]["type"] == "base64"
-
def test_tool_call_no_arguments(self, tool_call_no_arguments):
"""Test that tool calls with no arguments is translated correctly. Relevant issue: https://github.com/BerriAI/litellm/issues/6833"""
from litellm.llms.prompt_templates.factory import (
diff --git a/tests/llm_translation/test_bedrock_completion.py b/tests/llm_translation/test_bedrock_completion.py
index b2054dc232..a89b1c943d 100644
--- a/tests/llm_translation/test_bedrock_completion.py
+++ b/tests/llm_translation/test_bedrock_completion.py
@@ -30,6 +30,7 @@ from litellm import (
from litellm.llms.bedrock.chat import BedrockLLM
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.llms.prompt_templates.factory import _bedrock_tools_pt
+from base_llm_unit_tests import BaseLLMChatTest
# litellm.num_retries = 3
litellm.cache = None
@@ -1943,3 +1944,42 @@ def test_bedrock_context_window_error():
messages=[{"role": "user", "content": "Hello, world!"}],
mock_response=Exception("prompt is too long"),
)
+
+
+def test_bedrock_converse_route():
+ litellm.set_verbose = True
+ litellm.completion(
+ model="bedrock/converse/us.amazon.nova-pro-v1:0",
+ messages=[{"role": "user", "content": "Hello, world!"}],
+ )
+
+
+def test_bedrock_mapped_converse_models():
+ litellm.set_verbose = True
+ os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
+ litellm.model_cost = litellm.get_model_cost_map(url="")
+ litellm.add_known_models()
+ litellm.completion(
+ model="bedrock/us.amazon.nova-pro-v1:0",
+ messages=[{"role": "user", "content": "Hello, world!"}],
+ )
+
+
+def test_bedrock_base_model_helper():
+ model = "us.amazon.nova-pro-v1:0"
+ base_model = litellm.AmazonConverseConfig()._get_base_model(model)
+ assert base_model == "amazon.nova-pro-v1:0"
+ assert model == "us.amazon.nova-pro-v1:0"  # input string is left untouched
+
+
+class TestBedrockConverseChat(BaseLLMChatTest):
+ def get_base_completion_call_args(self) -> dict:
+ os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
+ litellm.model_cost = litellm.get_model_cost_map(url="")
+ litellm.add_known_models()
+ return {
+ "model": "bedrock/us.anthropic.claude-3-haiku-20240307-v1:0",
+ }
+
+ def test_tool_call_no_arguments(self, tool_call_no_arguments):
+ """Test that tool calls with no arguments is translated correctly. Relevant issue: https://github.com/BerriAI/litellm/issues/6833"""
+ pass
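
The converse-route tests above rely on the `converse/` route prefix and cross-region prefixes such as `us.` being stripped before the cost-map lookup. A rough sketch of that intent (not litellm's actual `_get_base_model` implementation):

```python
def base_model_for_cost_lookup(model: str) -> str:
    # Drop an explicit converse route prefix, if present.
    if model.startswith("converse/"):
        model = model[len("converse/"):]
    # Drop a cross-region inference-profile prefix such as "us." or "eu.".
    for region_prefix in ("us.", "eu.", "apac."):
        if model.startswith(region_prefix):
            model = model[len(region_prefix):]
            break
    return model


print(base_model_for_cost_lookup("converse/us.amazon.nova-pro-v1:0"))           # amazon.nova-pro-v1:0
print(base_model_for_cost_lookup("us.anthropic.claude-3-haiku-20240307-v1:0"))  # anthropic.claude-3-haiku-20240307-v1:0
```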
diff --git a/tests/local_testing/test_completion.py b/tests/local_testing/test_completion.py
index f69778e484..6f42d9a396 100644
--- a/tests/local_testing/test_completion.py
+++ b/tests/local_testing/test_completion.py
@@ -695,29 +695,6 @@ async def test_anthropic_no_content_error():
pytest.fail(f"An unexpected error occurred - {str(e)}")
-def test_gemini_completion_call_error():
- try:
- print("test completion + streaming")
- litellm.num_retries = 3
- litellm.set_verbose = True
- messages = [{"role": "user", "content": "what is the capital of congo?"}]
- response = completion(
- model="gemini/gemini-1.5-pro-latest",
- messages=messages,
- stream=True,
- max_tokens=10,
- )
- print(f"response: {response}")
- for chunk in response:
- print(chunk)
- except litellm.RateLimitError:
- pass
- except litellm.InternalServerError:
- pass
- except Exception as e:
- pytest.fail(f"error occurred: {str(e)}")
-
-
def test_completion_cohere_command_r_plus_function_call():
litellm.set_verbose = True
tools = [
diff --git a/tests/local_testing/test_router.py b/tests/local_testing/test_router.py
index 7b53d42db0..d3db083f68 100644
--- a/tests/local_testing/test_router.py
+++ b/tests/local_testing/test_router.py
@@ -2342,12 +2342,6 @@ def test_router_dynamic_cooldown_correct_retry_after_time():
pass
new_retry_after_mock_client.assert_called()
- print(
- f"new_retry_after_mock_client.call_args.kwargs: {new_retry_after_mock_client.call_args.kwargs}"
- )
- print(
- f"new_retry_after_mock_client.call_args: {new_retry_after_mock_client.call_args[0][0]}"
- )
response_headers: httpx.Headers = new_retry_after_mock_client.call_args[0][0]
assert int(response_headers["retry-after"]) == cooldown_time
diff --git a/tests/local_testing/test_router_retries.py b/tests/local_testing/test_router_retries.py
index 6922f55ca5..24b46b6549 100644
--- a/tests/local_testing/test_router_retries.py
+++ b/tests/local_testing/test_router_retries.py
@@ -539,45 +539,71 @@ def test_raise_context_window_exceeded_error_no_retry():
"""
-def test_timeout_for_rate_limit_error_with_healthy_deployments():
+@pytest.mark.parametrize("num_deployments, expected_timeout", [(1, 60), (2, 0.0)])
+def test_timeout_for_rate_limit_error_with_healthy_deployments(
+ num_deployments, expected_timeout
+):
"""
Test 1. Timeout is 0.0 when RateLimit Error and healthy deployments are > 0
"""
- healthy_deployments = [
- "deployment1",
- "deployment2",
- ] # multiple healthy deployments mocked up
- fallbacks = None
-
- router = litellm.Router(
- model_list=[
- {
- "model_name": "gpt-3.5-turbo",
- "litellm_params": {
- "model": "azure/chatgpt-v-2",
- "api_key": os.getenv("AZURE_API_KEY"),
- "api_version": os.getenv("AZURE_API_VERSION"),
- "api_base": os.getenv("AZURE_API_BASE"),
- },
- }
- ]
+ cooldown_time = 60
+ rate_limit_error = litellm.RateLimitError(
+ message="{RouterErrors.no_deployments_available.value}. 12345 Passed model={model_group}. Deployments={deployment_dict}",
+ llm_provider="",
+ model="gpt-3.5-turbo",
+ response=httpx.Response(
+ status_code=429,
+ content="",
+ headers={"retry-after": str(cooldown_time)}, # type: ignore
+ request=httpx.Request(method="tpm_rpm_limits", url="https://github.com/BerriAI/litellm"), # type: ignore
+ ),
)
+ model_list = [
+ {
+ "model_name": "gpt-3.5-turbo",
+ "litellm_params": {
+ "model": "azure/chatgpt-v-2",
+ "api_key": os.getenv("AZURE_API_KEY"),
+ "api_version": os.getenv("AZURE_API_VERSION"),
+ "api_base": os.getenv("AZURE_API_BASE"),
+ },
+ }
+ ]
+ if num_deployments == 2:
+ model_list.append(
+ {
+ "model_name": "gpt-4",
+ "litellm_params": {"model": "gpt-3.5-turbo"},
+ }
+ )
+
+ router = litellm.Router(model_list=model_list)
_timeout = router._time_to_sleep_before_retry(
e=rate_limit_error,
- remaining_retries=4,
- num_retries=4,
- healthy_deployments=healthy_deployments,
+ remaining_retries=2,
+ num_retries=2,
+ healthy_deployments=[
+ {
+ "model_name": "gpt-4",
+ "litellm_params": {
+ "api_key": "my-key",
+ "api_base": "https://openai-gpt-4-test-v-1.openai.azure.com",
+ "model": "azure/chatgpt-v-2",
+ },
+ "model_info": {
+ "id": "0e30bc8a63fa91ae4415d4234e231b3f9e6dd900cac57d118ce13a720d95e9d6",
+ "db_model": False,
+ },
+ }
+ ],
+ all_deployments=model_list,
)
- print(
- "timeout=",
- _timeout,
- "error is rate_limit_error and there are healthy deployments=",
- healthy_deployments,
- )
-
- assert _timeout == 0.0
+ if expected_timeout == 0.0:
+ assert _timeout == expected_timeout
+ else:
+ assert _timeout > 0.0
def test_timeout_for_rate_limit_error_with_no_healthy_deployments():
@@ -585,26 +611,26 @@ def test_timeout_for_rate_limit_error_with_no_healthy_deployments():
Test 2. Timeout is > 0.0 when RateLimit Error and healthy deployments == 0
"""
healthy_deployments = []
+ model_list = [
+ {
+ "model_name": "gpt-3.5-turbo",
+ "litellm_params": {
+ "model": "azure/chatgpt-v-2",
+ "api_key": os.getenv("AZURE_API_KEY"),
+ "api_version": os.getenv("AZURE_API_VERSION"),
+ "api_base": os.getenv("AZURE_API_BASE"),
+ },
+ }
+ ]
- router = litellm.Router(
- model_list=[
- {
- "model_name": "gpt-3.5-turbo",
- "litellm_params": {
- "model": "azure/chatgpt-v-2",
- "api_key": os.getenv("AZURE_API_KEY"),
- "api_version": os.getenv("AZURE_API_VERSION"),
- "api_base": os.getenv("AZURE_API_BASE"),
- },
- }
- ]
- )
+ router = litellm.Router(model_list=model_list)
_timeout = router._time_to_sleep_before_retry(
e=rate_limit_error,
remaining_retries=4,
num_retries=4,
healthy_deployments=healthy_deployments,
+ all_deployments=model_list,
)
print(
diff --git a/tests/local_testing/test_utils.py b/tests/local_testing/test_utils.py
index 7c349a6580..54c2e8c6c3 100644
--- a/tests/local_testing/test_utils.py
+++ b/tests/local_testing/test_utils.py
@@ -1005,7 +1005,10 @@ def test_models_by_provider():
continue
elif k == "sample_spec":
continue
- elif v["litellm_provider"] == "sagemaker":
+ elif (
+ v["litellm_provider"] == "sagemaker"
+ or v["litellm_provider"] == "bedrock_converse"
+ ):
continue
else:
providers.add(v["litellm_provider"])
@@ -1032,3 +1035,27 @@ def test_get_end_user_id_for_cost_tracking(
get_end_user_id_for_cost_tracking(litellm_params=litellm_params)
== expected_end_user_id
)
+
+
+@pytest.mark.parametrize(
+ "litellm_params, disable_end_user_cost_tracking_prometheus_only, expected_end_user_id",
+ [
+ ({}, False, None),
+ ({"proxy_server_request": {"body": {"user": "123"}}}, False, "123"),
+ ({"proxy_server_request": {"body": {"user": "123"}}}, True, None),
+ ],
+)
+def test_get_end_user_id_for_cost_tracking_prometheus_only(
+ litellm_params, disable_end_user_cost_tracking_prometheus_only, expected_end_user_id
+):
+ from litellm.utils import get_end_user_id_for_cost_tracking
+
+ litellm.disable_end_user_cost_tracking_prometheus_only = (
+ disable_end_user_cost_tracking_prometheus_only
+ )
+ assert (
+ get_end_user_id_for_cost_tracking(
+ litellm_params=litellm_params, service_type="prometheus"
+ )
+ == expected_end_user_id
+ )
diff --git a/tests/proxy_admin_ui_tests/test_key_management.py b/tests/proxy_admin_ui_tests/test_key_management.py
index 7a2764e3f2..e1fe3e4d30 100644
--- a/tests/proxy_admin_ui_tests/test_key_management.py
+++ b/tests/proxy_admin_ui_tests/test_key_management.py
@@ -695,45 +695,43 @@ def test_personal_key_generation_check():
)
-def test_prepare_metadata_fields():
+@pytest.mark.parametrize(
+ "update_request_data, non_default_values, existing_metadata, expected_result",
+ [
+ (
+ {"metadata": {"test": "new"}},
+ {"metadata": {"test": "new"}},
+ {"test": "test"},
+ {"metadata": {"test": "new"}},
+ ),
+ (
+ {"tags": ["new_tag"]},
+ {},
+ {"tags": ["old_tag"]},
+ {"metadata": {"tags": ["new_tag"]}},
+ ),
+ (
+ {"enforced_params": ["metadata.tags"]},
+ {},
+ {"tags": ["old_tag"]},
+ {"metadata": {"tags": ["old_tag"], "enforced_params": ["metadata.tags"]}},
+ ),
+ ],
+)
+def test_prepare_metadata_fields(
+ update_request_data, non_default_values, existing_metadata, expected_result
+):
from litellm.proxy.management_endpoints.key_management_endpoints import (
prepare_metadata_fields,
)
- new_metadata = {"test": "new"}
- old_metadata = {"test": "test"}
-
args = {
"data": UpdateKeyRequest(
- key_alias=None,
- duration=None,
- models=[],
- spend=None,
- max_budget=None,
- user_id=None,
- team_id=None,
- max_parallel_requests=None,
- metadata=new_metadata,
- tpm_limit=None,
- rpm_limit=None,
- budget_duration=None,
- allowed_cache_controls=[],
- soft_budget=None,
- config={},
- permissions={},
- model_max_budget={},
- send_invite_email=None,
- model_rpm_limit=None,
- model_tpm_limit=None,
- guardrails=None,
- blocked=None,
- aliases={},
- key="sk-1qGQUJJTcljeaPfzgWRrXQ",
- tags=None,
+ key="sk-1qGQUJJTcljeaPfzgWRrXQ", **update_request_data
),
- "non_default_values": {"metadata": new_metadata},
- "existing_metadata": {"tags": None, **old_metadata},
+ "non_default_values": non_default_values,
+ "existing_metadata": existing_metadata,
}
- non_default_values = prepare_metadata_fields(**args)
- assert non_default_values == {"metadata": new_metadata}
+ updated_non_default_values = prepare_metadata_fields(**args)
+ assert updated_non_default_values == expected_result
diff --git a/tests/proxy_unit_tests/test_key_generate_prisma.py b/tests/proxy_unit_tests/test_key_generate_prisma.py
index e1720654b1..2c8ba5b2ab 100644
--- a/tests/proxy_unit_tests/test_key_generate_prisma.py
+++ b/tests/proxy_unit_tests/test_key_generate_prisma.py
@@ -2664,66 +2664,6 @@ async def test_create_update_team(prisma_client):
)
-@pytest.mark.asyncio()
-async def test_enforced_params(prisma_client):
- setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
- setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
- from litellm.proxy.proxy_server import general_settings
-
- general_settings["enforced_params"] = [
- "user",
- "metadata",
- "metadata.generation_name",
- ]
-
- await litellm.proxy.proxy_server.prisma_client.connect()
- request = NewUserRequest()
- key = await new_user(
- data=request,
- user_api_key_dict=UserAPIKeyAuth(
- user_role=LitellmUserRoles.PROXY_ADMIN,
- api_key="sk-1234",
- user_id="1234",
- ),
- )
- print(key)
-
- generated_key = key.key
- bearer_token = "Bearer " + generated_key
-
- request = Request(scope={"type": "http"})
- request._url = URL(url="/chat/completions")
-
- # Case 1: Missing user
- async def return_body():
- return b'{"model": "gemini-pro-vision"}'
-
- request.body = return_body
- try:
- await user_api_key_auth(request=request, api_key=bearer_token)
- pytest.fail(f"This should have failed!. IT's an invalid request")
- except Exception as e:
- assert (
- "BadRequest please pass param=user in request body. This is a required param"
- in e.message
- )
-
- # Case 2: Missing metadata["generation_name"]
- async def return_body_2():
- return b'{"model": "gemini-pro-vision", "user": "1234", "metadata": {}}'
-
- request.body = return_body_2
- try:
- await user_api_key_auth(request=request, api_key=bearer_token)
- pytest.fail(f"This should have failed!. IT's an invalid request")
- except Exception as e:
- assert (
- "Authentication Error, BadRequest please pass param=[metadata][generation_name] in request body"
- in e.message
- )
- general_settings.pop("enforced_params")
-
-
@pytest.mark.asyncio()
async def test_update_user_role(prisma_client):
"""
@@ -3363,64 +3303,6 @@ async def test_auth_vertex_ai_route(prisma_client):
pass
-@pytest.mark.asyncio
-async def test_service_accounts(prisma_client):
- """
- Do not delete
- this is the Admin UI flow
- """
- # Make a call to a key with model = `all-proxy-models` this is an Alias from LiteLLM Admin UI
- setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
- setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
- setattr(
- litellm.proxy.proxy_server,
- "general_settings",
- {"service_account_settings": {"enforced_params": ["user"]}},
- )
-
- await litellm.proxy.proxy_server.prisma_client.connect()
-
- request = GenerateKeyRequest(
- metadata={"service_account_id": f"prod-service-{uuid.uuid4()}"},
- )
- response = await generate_key_fn(
- data=request,
- )
-
- print("key generated=", response)
- generated_key = response.key
- bearer_token = "Bearer " + generated_key
- # make a bad /chat/completions call expect it to fail
-
- request = Request(scope={"type": "http"})
- request._url = URL(url="/chat/completions")
-
- async def return_body():
- return b'{"model": "gemini-pro-vision"}'
-
- request.body = return_body
-
- # use generated key to auth in
- print("Bearer token being sent to user_api_key_auth() - {}".format(bearer_token))
- try:
- result = await user_api_key_auth(request=request, api_key=bearer_token)
- pytest.fail("Expected this call to fail. Bad request using service account")
- except Exception as e:
- print("error str=", str(e.message))
- assert "This is a required param for service account" in str(e.message)
-
- # make a good /chat/completions call it should pass
- async def good_return_body():
- return b'{"model": "gemini-pro-vision", "user": "foo"}'
-
- request.body = good_return_body
-
- result = await user_api_key_auth(request=request, api_key=bearer_token)
- print("response from user_api_key_auth", result)
-
- setattr(litellm.proxy.proxy_server, "general_settings", {})
-
-
@pytest.mark.asyncio
async def test_user_api_key_auth_db_unavailable():
"""
diff --git a/tests/proxy_unit_tests/test_proxy_utils.py b/tests/proxy_unit_tests/test_proxy_utils.py
index 6de47b6eed..7910cc7ea9 100644
--- a/tests/proxy_unit_tests/test_proxy_utils.py
+++ b/tests/proxy_unit_tests/test_proxy_utils.py
@@ -679,3 +679,132 @@ async def test_add_litellm_data_to_request_duplicate_tags(
assert sorted(result["metadata"]["tags"]) == sorted(
expected_tags
), f"Expected {expected_tags}, got {result['metadata']['tags']}"
+
+
+@pytest.mark.parametrize(
+ "general_settings, user_api_key_dict, expected_enforced_params",
+ [
+ (
+ {"enforced_params": ["param1", "param2"]},
+ UserAPIKeyAuth(
+ api_key="test_api_key", user_id="test_user_id", org_id="test_org_id"
+ ),
+ ["param1", "param2"],
+ ),
+ (
+ {"service_account_settings": {"enforced_params": ["param1", "param2"]}},
+ UserAPIKeyAuth(
+ api_key="test_api_key", user_id="test_user_id", org_id="test_org_id"
+ ),
+ ["param1", "param2"],
+ ),
+ (
+ {"service_account_settings": {"enforced_params": ["param1", "param2"]}},
+ UserAPIKeyAuth(
+ api_key="test_api_key",
+ metadata={"enforced_params": ["param3", "param4"]},
+ ),
+ ["param1", "param2", "param3", "param4"],
+ ),
+ ],
+)
+def test_get_enforced_params(
+ general_settings, user_api_key_dict, expected_enforced_params
+):
+ from litellm.proxy.litellm_pre_call_utils import _get_enforced_params
+
+ enforced_params = _get_enforced_params(general_settings, user_api_key_dict)
+ assert enforced_params == expected_enforced_params
+
+
+@pytest.mark.parametrize(
+ "general_settings, user_api_key_dict, request_body, expected_error",
+ [
+ (
+ {"enforced_params": ["param1", "param2"]},
+ UserAPIKeyAuth(
+ api_key="test_api_key", user_id="test_user_id", org_id="test_org_id"
+ ),
+ {},
+ True,
+ ),
+ (
+ {"service_account_settings": {"enforced_params": ["user"]}},
+ UserAPIKeyAuth(
+ api_key="test_api_key", user_id="test_user_id", org_id="test_org_id"
+ ),
+ {},
+ True,
+ ),
+ (
+ {},
+ UserAPIKeyAuth(
+ api_key="test_api_key",
+ metadata={"enforced_params": ["user"]},
+ ),
+ {},
+ True,
+ ),
+ (
+ {},
+ UserAPIKeyAuth(
+ api_key="test_api_key",
+ metadata={"enforced_params": ["user"]},
+ ),
+ {"user": "test_user"},
+ False,
+ ),
+ (
+ {"enforced_params": ["user"]},
+ UserAPIKeyAuth(
+ api_key="test_api_key",
+ ),
+ {"user": "test_user"},
+ False,
+ ),
+ (
+ {"service_account_settings": {"enforced_params": ["user"]}},
+ UserAPIKeyAuth(
+ api_key="test_api_key",
+ ),
+ {"user": "test_user"},
+ False,
+ ),
+ (
+ {"enforced_params": ["metadata.generation_name"]},
+ UserAPIKeyAuth(
+ api_key="test_api_key",
+ ),
+ {"metadata": {}},
+ True,
+ ),
+ (
+ {"enforced_params": ["metadata.generation_name"]},
+ UserAPIKeyAuth(
+ api_key="test_api_key",
+ ),
+ {"metadata": {"generation_name": "test_generation_name"}},
+ False,
+ ),
+ ],
+)
+def test_enforced_params_check(
+ general_settings, user_api_key_dict, request_body, expected_error
+):
+ from litellm.proxy.litellm_pre_call_utils import _enforced_params_check
+
+ if expected_error:
+ with pytest.raises(ValueError):
+ _enforced_params_check(
+ request_body=request_body,
+ general_settings=general_settings,
+ user_api_key_dict=user_api_key_dict,
+ premium_user=True,
+ )
+ else:
+ _enforced_params_check(
+ request_body=request_body,
+ general_settings=general_settings,
+ user_api_key_dict=user_api_key_dict,
+ premium_user=True,
+ )