Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 02:34:29 +00:00)
fix(key_management_endpoints.py): override metadata field value on up… (#7008)

* fix(key_management_endpoints.py): override metadata field value on update; allow user to override tags
* feat(__init__.py): expose new `disable_end_user_cost_tracking_prometheus_only` metric; allows disabling end user cost tracking on prometheus only, fixing a cardinality issue
* fix(litellm_pre_call_utils.py): add key/team level enforced params. Fixes https://github.com/BerriAI/litellm/issues/6652
* fix(key_management_endpoints.py): allow user to pass in `enforced_params` as a top level param on /key/generate and /key/update
* docs(enterprise.md): add docs on enforcing required params for llm requests
* Add support of Galadriel API (#7005)
* fix(router.py): robust retry-after handling; set retry-after time to 0 if >0 healthy deployments, handle base case = 1 deployment
* test(test_router.py): fix test
* feat(bedrock/): add support for 'nova' models; also adds explicit 'converse/' route for simpler routing
* fix: fix 'supports_pdf_input' return if model supports pdf input on get_model_info
* feat(converse_transformation.py): support bedrock pdf input
* docs(document_understanding.md): add document understanding to docs
* fix(litellm_pre_call_utils.py): fix linting error
* fix(init.py): fix passing of bedrock converse models
* feat(bedrock/converse): support 'response_format={"type": "json_object"}'
* fix(converse_handler.py): fix linting error
* fix(base_llm_unit_tests.py): fix test
* fix: fix test
* test: fix test
* test: fix test
* test: remove duplicate test

Co-authored-by: h4n0 <4738254+h4n0@users.noreply.github.com>
This commit is contained in:
parent: d558b643be
commit: 6bb934c0ac

37 changed files with 1297 additions and 503 deletions
@@ -271,6 +271,7 @@ curl 'http://0.0.0.0:4000/key/generate' \
| [voyage ai](https://docs.litellm.ai/docs/providers/voyage) | | | | | ✅ | |
| [xinference [Xorbits Inference]](https://docs.litellm.ai/docs/providers/xinference) | | | | | ✅ | |
| [FriendliAI](https://docs.litellm.ai/docs/providers/friendliai) | ✅ | ✅ | ✅ | ✅ | | |
| [Galadriel](https://docs.litellm.ai/docs/providers/galadriel) | ✅ | ✅ | ✅ | ✅ | | |

[**Read the Docs**](https://docs.litellm.ai/docs/)
202 docs/my-website/docs/completion/document_understanding.md Normal file
@@ -0,0 +1,202 @@
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';

# Using PDF Input

How to send / receive PDFs (and other document types) to a `/chat/completions` endpoint

Works for:
- Vertex AI models (Gemini + Anthropic)
- Bedrock Models
- Anthropic API Models

## Quick Start

### url

<Tabs>
<TabItem value="sdk" label="SDK">

```python
import os

from litellm.utils import supports_pdf_input, completion

# set aws credentials
os.environ["AWS_ACCESS_KEY_ID"] = ""
os.environ["AWS_SECRET_ACCESS_KEY"] = ""
os.environ["AWS_REGION_NAME"] = ""


# pdf url
image_url = "https://www.w3.org/WAI/ER/tests/xhtml/testfiles/resources/pdf/dummy.pdf"

# model
model = "bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0"

image_content = [
    {"type": "text", "text": "What's this file about?"},
    {
        "type": "image_url",
        "image_url": image_url,  # OR {"url": image_url}
    },
]


if not supports_pdf_input(model, None):
    print("Model does not support pdf input")

response = completion(
    model=model,
    messages=[{"role": "user", "content": image_content}],
)
assert response is not None
```

</TabItem>
<TabItem value="proxy" label="PROXY">

1. Setup config.yaml

```yaml
model_list:
  - model_name: bedrock-model
    litellm_params:
      model: bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0
      aws_access_key_id: os.environ/AWS_ACCESS_KEY_ID
      aws_secret_access_key: os.environ/AWS_SECRET_ACCESS_KEY
      aws_region_name: os.environ/AWS_REGION_NAME
```

2. Start the proxy

```bash
litellm --config /path/to/config.yaml
```

3. Test it!

```bash
curl -X POST 'http://0.0.0.0:4000/chat/completions' \
-H 'Content-Type: application/json' \
-H 'Authorization: Bearer sk-1234' \
-d '{
    "model": "bedrock-model",
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "What is this file about?"},
                {
                    "type": "image_url",
                    "image_url": "https://www.w3.org/WAI/ER/tests/xhtml/testfiles/resources/pdf/dummy.pdf"
                }
            ]
        }
    ]
}'
```

</TabItem>
</Tabs>

### base64

<Tabs>
<TabItem value="sdk" label="SDK">

```python
import base64
import os

import requests

from litellm.utils import supports_pdf_input, completion

# set aws credentials
os.environ["AWS_ACCESS_KEY_ID"] = ""
os.environ["AWS_SECRET_ACCESS_KEY"] = ""
os.environ["AWS_REGION_NAME"] = ""


# pdf url
image_url = "https://www.w3.org/WAI/ER/tests/xhtml/testfiles/resources/pdf/dummy.pdf"
response = requests.get(image_url)
file_data = response.content

encoded_file = base64.b64encode(file_data).decode("utf-8")
base64_url = f"data:application/pdf;base64,{encoded_file}"

# model
model = "bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0"

image_content = [
    {"type": "text", "text": "What's this file about?"},
    {
        "type": "image_url",
        "image_url": base64_url,  # OR {"url": base64_url}
    },
]


if not supports_pdf_input(model, None):
    print("Model does not support pdf input")

response = completion(
    model=model,
    messages=[{"role": "user", "content": image_content}],
)
assert response is not None
```

</TabItem>
</Tabs>

## Checking if a model supports pdf input

<Tabs>
<TabItem label="SDK" value="sdk">

Use `litellm.supports_pdf_input(model="bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0")` -> returns `True` if the model can accept pdf input

```python
assert litellm.supports_pdf_input(model="bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0") == True
```

</TabItem>

<TabItem label="PROXY" value="proxy">

1. Define bedrock models on config.yaml

```yaml
model_list:
  - model_name: bedrock-model # model group name
    litellm_params:
      model: bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0
      aws_access_key_id: os.environ/AWS_ACCESS_KEY_ID
      aws_secret_access_key: os.environ/AWS_SECRET_ACCESS_KEY
      aws_region_name: os.environ/AWS_REGION_NAME
    model_info: # OPTIONAL - set manually
      supports_pdf_input: True
```

2. Run proxy server

```bash
litellm --config config.yaml
```

3. Call `/model_group/info` to check if a model supports `pdf` input

```shell
curl -X 'GET' \
  'http://localhost:4000/model_group/info' \
  -H 'accept: application/json' \
  -H 'x-api-key: sk-1234'
```

Expected Response

```json
{
    "data": [
        {
            "model_group": "bedrock-model",
            "providers": ["bedrock"],
            "max_input_tokens": 128000,
            "max_output_tokens": 16384,
            "mode": "chat",
            ...,
            "supports_pdf_input": true, # 👈 supports_pdf_input is true
        }
    ]
}
```

</TabItem>
</Tabs>
@@ -706,6 +706,37 @@ print(response)
</Tabs>

## Set 'converse' / 'invoke' route

LiteLLM defaults to the `invoke` route, but uses the `converse` route for Bedrock models that support it.

To explicitly set the route, use `bedrock/converse/<model>` or `bedrock/invoke/<model>`.

E.g.

<Tabs>
<TabItem value="sdk" label="SDK">

```python
from litellm import completion

completion(
    model="bedrock/converse/us.amazon.nova-pro-v1:0",
    messages=[{"role": "user", "content": "Hey!"}],
)
```

</TabItem>
<TabItem value="proxy" label="PROXY">

```yaml
model_list:
  - model_name: bedrock-model
    litellm_params:
      model: bedrock/converse/us.amazon.nova-pro-v1:0
```

</TabItem>
</Tabs>

## Alternate user/assistant messages

Use `user_continue_message` to add a default user message, for cases (e.g. Autogen) where the client might not follow alternating user/assistant messages starting and ending with a user message; a minimal sketch follows this section.
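A minimal proxy config sketch for `user_continue_message` (assuming it is set per-model under `litellm_params`; the exact message shape is an assumption, not confirmed by this diff):

```yaml
model_list:
  - model_name: bedrock-claude
    litellm_params:
      model: bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0
      # assumption: a default user message litellm appends when the
      # conversation doesn't start/end with a user turn
      user_continue_message: {"role": "user", "content": "Please continue"}
```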
@@ -745,6 +776,174 @@ curl -X POST 'http://0.0.0.0:4000/chat/completions' \
}'
```

## Usage - PDF / Document Understanding

LiteLLM supports Document Understanding for Bedrock models - [AWS Bedrock Docs](https://docs.aws.amazon.com/nova/latest/userguide/modalities-document.html).

### url

<Tabs>
<TabItem value="sdk" label="SDK">

```python
import os

from litellm.utils import supports_pdf_input, completion

# set aws credentials
os.environ["AWS_ACCESS_KEY_ID"] = ""
os.environ["AWS_SECRET_ACCESS_KEY"] = ""
os.environ["AWS_REGION_NAME"] = ""


# pdf url
image_url = "https://www.w3.org/WAI/ER/tests/xhtml/testfiles/resources/pdf/dummy.pdf"

# model
model = "bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0"

image_content = [
    {"type": "text", "text": "What's this file about?"},
    {
        "type": "image_url",
        "image_url": image_url,  # OR {"url": image_url}
    },
]


if not supports_pdf_input(model, None):
    print("Model does not support pdf input")

response = completion(
    model=model,
    messages=[{"role": "user", "content": image_content}],
)
assert response is not None
```

</TabItem>
<TabItem value="proxy" label="PROXY">

1. Setup config.yaml

```yaml
model_list:
  - model_name: bedrock-model
    litellm_params:
      model: bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0
      aws_access_key_id: os.environ/AWS_ACCESS_KEY_ID
      aws_secret_access_key: os.environ/AWS_SECRET_ACCESS_KEY
      aws_region_name: os.environ/AWS_REGION_NAME
```

2. Start the proxy

```bash
litellm --config /path/to/config.yaml
```

3. Test it!

```bash
curl -X POST 'http://0.0.0.0:4000/chat/completions' \
-H 'Content-Type: application/json' \
-H 'Authorization: Bearer sk-1234' \
-d '{
    "model": "bedrock-model",
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "What is this file about?"},
                {
                    "type": "image_url",
                    "image_url": "https://www.w3.org/WAI/ER/tests/xhtml/testfiles/resources/pdf/dummy.pdf"
                }
            ]
        }
    ]
}'
```

</TabItem>
</Tabs>

### base64

<Tabs>
<TabItem value="sdk" label="SDK">

```python
import base64
import os

import requests

from litellm.utils import supports_pdf_input, completion

# set aws credentials
os.environ["AWS_ACCESS_KEY_ID"] = ""
os.environ["AWS_SECRET_ACCESS_KEY"] = ""
os.environ["AWS_REGION_NAME"] = ""


# pdf url
image_url = "https://www.w3.org/WAI/ER/tests/xhtml/testfiles/resources/pdf/dummy.pdf"
response = requests.get(image_url)
file_data = response.content

encoded_file = base64.b64encode(file_data).decode("utf-8")
base64_url = f"data:application/pdf;base64,{encoded_file}"

# model
model = "bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0"

image_content = [
    {"type": "text", "text": "What's this file about?"},
    {
        "type": "image_url",
        "image_url": base64_url,  # OR {"url": base64_url}
    },
]


if not supports_pdf_input(model, None):
    print("Model does not support pdf input")

response = completion(
    model=model,
    messages=[{"role": "user", "content": image_content}],
)
assert response is not None
```

</TabItem>
<TabItem value="proxy" label="PROXY">

1. Setup config.yaml

```yaml
model_list:
  - model_name: bedrock-model
    litellm_params:
      model: bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0
      aws_access_key_id: os.environ/AWS_ACCESS_KEY_ID
      aws_secret_access_key: os.environ/AWS_SECRET_ACCESS_KEY
      aws_region_name: os.environ/AWS_REGION_NAME
```

2. Start the proxy

```bash
litellm --config /path/to/config.yaml
```

3. Test it!

```bash
curl -X POST 'http://0.0.0.0:4000/chat/completions' \
-H 'Content-Type: application/json' \
-H 'Authorization: Bearer sk-1234' \
-d '{
    "model": "bedrock-model",
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "What is this file about?"},
                {
                    "type": "image_url",
                    "image_url": "data:application/pdf;base64,{b64_encoded_file}"
                }
            ]
        }
    ]
}'
```

</TabItem>
</Tabs>


## Boto3 - Authentication

### Passing credentials as parameters - Completion()
63 docs/my-website/docs/providers/galadriel.md Normal file
@@ -0,0 +1,63 @@
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';

# Galadriel
https://docs.galadriel.com/api-reference/chat-completion-API

LiteLLM supports all models on Galadriel.

## API Key
```python
import os
os.environ['GALADRIEL_API_KEY'] = "your-api-key"
```

## Sample Usage
```python
from litellm import completion
import os

os.environ['GALADRIEL_API_KEY'] = ""
response = completion(
    model="galadriel/llama3.1",
    messages=[
        {"role": "user", "content": "hello from litellm"}
    ],
)
print(response)
```

## Sample Usage - Streaming
```python
from litellm import completion
import os

os.environ['GALADRIEL_API_KEY'] = ""
response = completion(
    model="galadriel/llama3.1",
    messages=[
        {"role": "user", "content": "hello from litellm"}
    ],
    stream=True
)

for chunk in response:
    print(chunk)
```


## Supported Models
### Serverless Endpoints
We support ALL Galadriel AI models - just set `galadriel/` as a prefix when sending completion requests.

Both the complete model name and a simplified name are supported, e.g. `llama3.1:70b`; see the sketch after the table.

| Model Name | Simplified Name | Function Call |
| -------------------------------------------------------- | -------------------------------- | ------------------------------------------------------- |
| neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8 | llama3.1 or llama3.1:8b | `completion(model="galadriel/llama3.1", messages)` |
| neuralmagic/Meta-Llama-3.1-70B-Instruct-quantized.w4a16 | llama3.1:70b | `completion(model="galadriel/llama3.1:70b", messages)` |
| neuralmagic/Meta-Llama-3.1-405B-Instruct-quantized.w4a16 | llama3.1:405b | `completion(model="galadriel/llama3.1:405b", messages)` |
| neuralmagic/Mistral-Nemo-Instruct-2407-quantized.w4a16 | mistral-nemo or mistral-nemo:12b | `completion(model="galadriel/mistral-nemo", messages)` |
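For example, using the full model name from the table (a sketch; the full-name form is assumed to route the same way as its simplified alias):

```python
from litellm import completion

# assumption: full names behave like their simplified aliases
response = completion(
    model="galadriel/neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8",
    messages=[{"role": "user", "content": "hello from litellm"}],
)
print(response)
```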
@@ -134,28 +134,9 @@ general_settings:
| content_policy_fallbacks | array of objects | Fallbacks to use when a ContentPolicyViolationError is encountered. [Further docs](./reliability#content-policy-fallbacks) |
| context_window_fallbacks | array of objects | Fallbacks to use when a ContextWindowExceededError is encountered. [Further docs](./reliability#context-window-fallbacks) |
| cache | boolean | If true, enables caching. [Further docs](./caching) |
| cache_params | object | Parameters for the cache. [Further docs](./caching) |
| cache_params.type | string | The type of cache to initialize. Can be one of ["local", "redis", "redis-semantic", "s3", "disk", "qdrant-semantic"]. Defaults to "redis". [Further docs](./caching) |
| cache_params.host | string | The host address for the Redis cache. Required if type is "redis". |
| cache_params.port | integer | The port number for the Redis cache. Required if type is "redis". |
| cache_params.password | string | The password for the Redis cache. Required if type is "redis". |
| cache_params.namespace | string | The namespace for the Redis cache. |
| cache_params.redis_startup_nodes | array of objects | Redis Cluster Settings. [Further docs](./caching) |
| cache_params.service_name | string | Redis Sentinel Settings. [Further docs](./caching) |
| cache_params.sentinel_nodes | array of arrays | Redis Sentinel Settings. [Further docs](./caching) |
| cache_params.ttl | integer | The time (in seconds) to store entries in cache. |
| cache_params.qdrant_semantic_cache_embedding_model | string | The embedding model to use for qdrant semantic cache. |
| cache_params.qdrant_collection_name | string | The name of the collection to use for qdrant semantic cache. |
| cache_params.qdrant_quantization_config | string | The quantization configuration for the qdrant semantic cache. |
| cache_params.similarity_threshold | float | The similarity threshold for the semantic cache. |
| cache_params.s3_bucket_name | string | The name of the S3 bucket to use for the semantic cache. |
| cache_params.s3_region_name | string | The region name for the S3 bucket. |
| cache_params.s3_aws_access_key_id | string | The AWS access key ID for the S3 bucket. |
| cache_params.s3_aws_secret_access_key | string | The AWS secret access key for the S3 bucket. |
| cache_params.s3_endpoint_url | string | Optional - The endpoint URL for the S3 bucket. |
| cache_params.supported_call_types | array of strings | The types of calls to cache. [Further docs](./caching) |
| cache_params.mode | string | The mode of the cache. [Further docs](./caching) |
| cache_params | object | Parameters for the cache. [Further docs](./caching#supported-cache_params-on-proxy-configyaml) |
| disable_end_user_cost_tracking | boolean | If true, turns off end user cost tracking on prometheus metrics + litellm spend logs table on proxy. |
| disable_end_user_cost_tracking_prometheus_only | boolean | If true, turns off end user cost tracking on prometheus metrics only. |
| key_generation_settings | object | Restricts who can generate keys. [Further docs](./virtual_keys.md#restricting-key-generation) |
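A config sketch combining the cache settings with the flag added in this PR (a sketch; key placement under `litellm_settings` and the values are assumptions based on the table above):

```yaml
litellm_settings:
  cache: true
  cache_params:
    type: redis            # one of the listed cache types
    host: os.environ/REDIS_HOST
    port: os.environ/REDIS_PORT
    ttl: 600               # seconds to keep entries
  disable_end_user_cost_tracking_prometheus_only: true
```
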
### general_settings - Reference
@@ -507,6 +507,11 @@ curl -X GET "http://0.0.0.0:4000/spend/logs?request_id=<your-call-id>" \ # e.g.:
## Enforce Required Params for LLM Requests

Use this when you want to enforce that all requests include certain params. For example, you may need all requests to include the `user` and `["metadata"]["generation_name"]` params.

<Tabs>

<TabItem value="config" label="Set on Config">

**Step 1** Define all params you want to enforce on config.yaml

This means `["user"]` and `["metadata"]["generation_name"]` are required in all LLM requests to LiteLLM

@@ -518,8 +523,21 @@ general_settings:
      - user
      - metadata.generation_name
```
</TabItem>

Start LiteLLM Proxy

<TabItem value="key" label="Set on Key">

```bash
curl -L -X POST 'http://0.0.0.0:4000/key/generate' \
-H 'Authorization: Bearer sk-1234' \
-H 'Content-Type: application/json' \
-d '{
    "enforced_params": ["user", "metadata.generation_name"]
}'
```

</TabItem>
</Tabs>

**Step 2** Verify it works
@@ -828,8 +828,18 @@ litellm_settings:
#### Spec

```python
key_generation_settings: Optional[StandardKeyGenerationConfig] = None
```

#### Types

```python
class StandardKeyGenerationConfig(TypedDict, total=False):
    team_key_generation: TeamUIKeyGenerationConfig
    personal_key_generation: PersonalUIKeyGenerationConfig


class TeamUIKeyGenerationConfig(TypedDict):
    allowed_team_member_roles: List[str]
    allowed_team_member_roles: List[str]  # either 'user' or 'admin'
    required_params: List[str]  # require params on `/key/generate` to be set if a team key (team_id in request) is being generated

@@ -838,11 +848,6 @@ class PersonalUIKeyGenerationConfig(TypedDict):
    required_params: List[str]  # require params on `/key/generate` to be set if a personal key (no team_id in request) is being generated


class StandardKeyGenerationConfig(TypedDict, total=False):
    team_key_generation: TeamUIKeyGenerationConfig
    personal_key_generation: PersonalUIKeyGenerationConfig


class LitellmUserRoles(str, enum.Enum):
    """
    Admin Roles:
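A config sketch matching these types (values are illustrative assumptions, not taken from this diff):

```yaml
litellm_settings:
  key_generation_settings:
    team_key_generation:
      allowed_team_member_roles: ["admin"]
      required_params: ["tags"]        # assumed example param
    personal_key_generation:
      required_params: ["key_alias"]   # assumed example param
```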
@@ -177,6 +177,7 @@ const sidebars = {
        "providers/ollama",
        "providers/perplexity",
        "providers/friendliai",
        "providers/galadriel",
        "providers/groq",
        "providers/github",
        "providers/deepseek",
@@ -210,6 +211,7 @@ const sidebars = {
        "completion/provider_specific_params",
        "guides/finetuned_models",
        "completion/audio",
        "completion/document_understanding",
        "completion/vision",
        "completion/json_mode",
        "completion/prompt_caching",
@@ -288,6 +288,7 @@ max_internal_user_budget: Optional[float] = None
internal_user_budget_duration: Optional[str] = None
max_end_user_budget: Optional[float] = None
disable_end_user_cost_tracking: Optional[bool] = None
disable_end_user_cost_tracking_prometheus_only: Optional[bool] = None
#### REQUEST PRIORITIZATION ####
priority_reservation: Optional[Dict[str, float]] = None
#### RELIABILITY ####
@@ -385,6 +386,31 @@ organization = None
project = None
config_path = None
vertex_ai_safety_settings: Optional[dict] = None
BEDROCK_CONVERSE_MODELS = [
    "anthropic.claude-3-5-haiku-20241022-v1:0",
    "anthropic.claude-3-5-sonnet-20241022-v2:0",
    "anthropic.claude-3-5-sonnet-20240620-v1:0",
    "anthropic.claude-3-opus-20240229-v1:0",
    "anthropic.claude-3-sonnet-20240229-v1:0",
    "anthropic.claude-3-haiku-20240307-v1:0",
    "anthropic.claude-v2",
    "anthropic.claude-v2:1",
    "anthropic.claude-v1",
    "anthropic.claude-instant-v1",
    "ai21.jamba-instruct-v1:0",
    "meta.llama3-70b-instruct-v1:0",
    "meta.llama3-8b-instruct-v1:0",
    "meta.llama3-1-8b-instruct-v1:0",
    "meta.llama3-1-70b-instruct-v1:0",
    "meta.llama3-1-405b-instruct-v1:0",
    "meta.llama3-70b-instruct-v1:0",
    "mistral.mistral-large-2407-v1:0",
    "meta.llama3-2-1b-instruct-v1:0",
    "meta.llama3-2-3b-instruct-v1:0",
    "meta.llama3-2-11b-instruct-v1:0",
    "meta.llama3-2-90b-instruct-v1:0",
    "meta.llama3-2-405b-instruct-v1:0",
]
####### COMPLETION MODELS ###################
open_ai_chat_completion_models: List = []
open_ai_text_completion_models: List = []
@@ -412,6 +438,7 @@ ai21_chat_models: List = []
nlp_cloud_models: List = []
aleph_alpha_models: List = []
bedrock_models: List = []
bedrock_converse_models: List = BEDROCK_CONVERSE_MODELS
fireworks_ai_models: List = []
fireworks_ai_embedding_models: List = []
deepinfra_models: List = []
@@ -431,6 +458,7 @@ groq_models: List = []
azure_models: List = []
anyscale_models: List = []
cerebras_models: List = []
galadriel_models: List = []


def add_known_models():
@@ -491,6 +519,8 @@ def add_known_models():
            aleph_alpha_models.append(key)
        elif value.get("litellm_provider") == "bedrock":
            bedrock_models.append(key)
        elif value.get("litellm_provider") == "bedrock_converse":
            bedrock_converse_models.append(key)
        elif value.get("litellm_provider") == "deepinfra":
            deepinfra_models.append(key)
        elif value.get("litellm_provider") == "perplexity":
@@ -535,6 +565,8 @@ def add_known_models():
            anyscale_models.append(key)
        elif value.get("litellm_provider") == "cerebras":
            cerebras_models.append(key)
        elif value.get("litellm_provider") == "galadriel":
            galadriel_models.append(key)


add_known_models()
@@ -554,6 +586,7 @@ openai_compatible_endpoints: List = [
    "inference.friendli.ai/v1",
    "api.sambanova.ai/v1",
    "api.x.ai/v1",
    "api.galadriel.ai/v1",
]

# this is maintained for Exception Mapping
@@ -581,6 +614,7 @@ openai_compatible_providers: List = [
    "litellm_proxy",
    "hosted_vllm",
    "lm_studio",
    "galadriel",
]
openai_text_completion_compatible_providers: List = (
    [  # providers that support `/v1/completions`
@@ -794,6 +828,7 @@ model_list = (
    + azure_models
    + anyscale_models
    + cerebras_models
    + galadriel_models
)


@@ -860,6 +895,7 @@ class LlmProviders(str, Enum):
    LITELLM_PROXY = "litellm_proxy"
    HOSTED_VLLM = "hosted_vllm"
    LM_STUDIO = "lm_studio"
    GALADRIEL = "galadriel"


provider_list: List[Union[LlmProviders, str]] = list(LlmProviders)
@@ -882,7 +918,7 @@ models_by_provider: dict = {
    + vertex_vision_models
    + vertex_language_models,
    "ai21": ai21_models,
    "bedrock": bedrock_models,
    "bedrock": bedrock_models + bedrock_converse_models,
    "petals": petals_models,
    "ollama": ollama_models,
    "deepinfra": deepinfra_models,
@@ -908,6 +944,7 @@ models_by_provider: dict = {
    "azure": azure_models,
    "anyscale": anyscale_models,
    "cerebras": cerebras_models,
    "galadriel": galadriel_models,
}

# mapping for those models which have larger equivalents
@@ -1067,9 +1104,6 @@ from .llms.bedrock.chat.invoke_handler import (
    AmazonConverseConfig,
    bedrock_tool_name_mappings,
)
from .llms.bedrock.chat.converse_handler import (
    BEDROCK_CONVERSE_MODELS,
)
from .llms.bedrock.common_utils import (
    AmazonTitanConfig,
    AmazonAI21Config,
@@ -365,7 +365,9 @@ class PrometheusLogger(CustomLogger):
        model = kwargs.get("model", "")
        litellm_params = kwargs.get("litellm_params", {}) or {}
        _metadata = litellm_params.get("metadata", {})
        end_user_id = get_end_user_id_for_cost_tracking(litellm_params)
        end_user_id = get_end_user_id_for_cost_tracking(
            litellm_params, service_type="prometheus"
        )
        user_id = standard_logging_payload["metadata"]["user_api_key_user_id"]
        user_api_key = standard_logging_payload["metadata"]["user_api_key_hash"]
        user_api_key_alias = standard_logging_payload["metadata"]["user_api_key_alias"]
@@ -668,7 +670,9 @@ class PrometheusLogger(CustomLogger):
            "standard_logging_object", {}
        )
        litellm_params = kwargs.get("litellm_params", {}) or {}
        end_user_id = get_end_user_id_for_cost_tracking(litellm_params)
        end_user_id = get_end_user_id_for_cost_tracking(
            litellm_params, service_type="prometheus"
        )
        user_id = standard_logging_payload["metadata"]["user_api_key_user_id"]
        user_api_key = standard_logging_payload["metadata"]["user_api_key_hash"]
        user_api_key_alias = standard_logging_payload["metadata"]["user_api_key_alias"]
@@ -177,6 +177,9 @@ def get_llm_provider(  # noqa: PLR0915
                dynamic_api_key = get_secret_str(
                    "FRIENDLIAI_API_KEY"
                ) or get_secret("FRIENDLI_TOKEN")
            elif endpoint == "api.galadriel.com/v1":
                custom_llm_provider = "galadriel"
                dynamic_api_key = get_secret_str("GALADRIEL_API_KEY")

        if api_base is not None and not isinstance(api_base, str):
            raise Exception(
@@ -526,6 +529,11 @@ def _get_openai_compatible_provider_info(  # noqa: PLR0915
            or get_secret_str("FRIENDLIAI_API_KEY")
            or get_secret_str("FRIENDLI_TOKEN")
        )
    elif custom_llm_provider == "galadriel":
        api_base = (
            api_base or get_secret("GALADRIEL_API_BASE") or "https://api.galadriel.com/v1"
        )  # type: ignore
        dynamic_api_key = api_key or get_secret_str("GALADRIEL_API_KEY")
    if api_base is not None and not isinstance(api_base, str):
        raise Exception("api base needs to be a string. api_base={}".format(api_base))
    if dynamic_api_key is not None and not isinstance(dynamic_api_key, str):
@@ -18,32 +18,6 @@ from ...base_aws_llm import BaseAWSLLM
from ..common_utils import BedrockError
from .invoke_handler import AWSEventStreamDecoder, MockResponseIterator, make_call

BEDROCK_CONVERSE_MODELS = [
    "anthropic.claude-3-5-haiku-20241022-v1:0",
    "anthropic.claude-3-5-sonnet-20241022-v2:0",
    "anthropic.claude-3-5-sonnet-20240620-v1:0",
    "anthropic.claude-3-opus-20240229-v1:0",
    "anthropic.claude-3-sonnet-20240229-v1:0",
    "anthropic.claude-3-haiku-20240307-v1:0",
    "anthropic.claude-v2",
    "anthropic.claude-v2:1",
    "anthropic.claude-v1",
    "anthropic.claude-instant-v1",
    "ai21.jamba-instruct-v1:0",
    "meta.llama3-70b-instruct-v1:0",
    "meta.llama3-8b-instruct-v1:0",
    "meta.llama3-1-8b-instruct-v1:0",
    "meta.llama3-1-70b-instruct-v1:0",
    "meta.llama3-1-405b-instruct-v1:0",
    "meta.llama3-70b-instruct-v1:0",
    "mistral.mistral-large-2407-v1:0",
    "meta.llama3-2-1b-instruct-v1:0",
    "meta.llama3-2-3b-instruct-v1:0",
    "meta.llama3-2-11b-instruct-v1:0",
    "meta.llama3-2-90b-instruct-v1:0",
    "meta.llama3-2-405b-instruct-v1:0",
]


def make_sync_call(
    client: Optional[HTTPHandler],
@@ -53,6 +27,8 @@ def make_sync_call(
    model: str,
    messages: list,
    logging_obj,
    json_mode: Optional[bool] = False,
    fake_stream: bool = False,
):
    if client is None:
        client = _get_httpx_client()  # Create a new client if none provided
@@ -61,13 +37,13 @@ def make_sync_call(
        api_base,
        headers=headers,
        data=data,
        stream=True if "ai21" not in api_base else False,
        stream=not fake_stream,
    )

    if response.status_code != 200:
        raise BedrockError(status_code=response.status_code, message=response.read())

    if "ai21" in api_base:
    if fake_stream:
        model_response: (
            ModelResponse
        ) = litellm.AmazonConverseConfig()._transform_response(
@@ -83,7 +59,9 @@ def make_sync_call(
            print_verbose=litellm.print_verbose,
            encoding=litellm.encoding,
        )  # type: ignore
        completion_stream: Any = MockResponseIterator(model_response=model_response)
        completion_stream: Any = MockResponseIterator(
            model_response=model_response, json_mode=json_mode
        )
    else:
        decoder = AWSEventStreamDecoder(model=model)
        completion_stream = decoder.iter_bytes(response.iter_bytes(chunk_size=1024))
@@ -130,6 +108,8 @@ class BedrockConverseLLM(BaseAWSLLM):
        logger_fn=None,
        headers={},
        client: Optional[AsyncHTTPHandler] = None,
        fake_stream: bool = False,
        json_mode: Optional[bool] = False,
    ) -> CustomStreamWrapper:

        completion_stream = await make_call(
@@ -140,6 +120,8 @@ class BedrockConverseLLM(BaseAWSLLM):
            model=model,
            messages=messages,
            logging_obj=logging_obj,
            fake_stream=fake_stream,
            json_mode=json_mode,
        )
        streaming_response = CustomStreamWrapper(
            completion_stream=completion_stream,
@@ -231,12 +213,15 @@ class BedrockConverseLLM(BaseAWSLLM):
        ## SETUP ##
        stream = optional_params.pop("stream", None)
        modelId = optional_params.pop("model_id", None)
        fake_stream = optional_params.pop("fake_stream", False)
        json_mode = optional_params.get("json_mode", False)
        if modelId is not None:
            modelId = self.encode_model_id(model_id=modelId)
        else:
            modelId = model

        provider = model.split(".")[0]
        if stream is True and "ai21" in modelId:
            fake_stream = True

        ## CREDENTIALS ##
        # pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
@@ -290,7 +275,7 @@ class BedrockConverseLLM(BaseAWSLLM):
            aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
            aws_region_name=aws_region_name,
        )
        if (stream is not None and stream is True) and provider != "ai21":
        if (stream is not None and stream is True) and not fake_stream:
            endpoint_url = f"{endpoint_url}/model/{modelId}/converse-stream"
            proxy_endpoint_url = f"{proxy_endpoint_url}/model/{modelId}/converse-stream"
        else:
@@ -355,6 +340,8 @@ class BedrockConverseLLM(BaseAWSLLM):
                headers=prepped.headers,
                timeout=timeout,
                client=client,
                json_mode=json_mode,
                fake_stream=fake_stream,
            )  # type: ignore
        ### ASYNC COMPLETION
        return self.async_completion(
@@ -398,6 +385,8 @@ class BedrockConverseLLM(BaseAWSLLM):
            model=model,
            messages=messages,
            logging_obj=logging_obj,
            json_mode=json_mode,
            fake_stream=fake_stream,
        )
        streaming_response = CustomStreamWrapper(
            completion_stream=completion_stream,
@@ -134,6 +134,43 @@ class AmazonConverseConfig:
    def get_supported_image_types(self) -> List[str]:
        return ["png", "jpeg", "gif", "webp"]

    def get_supported_document_types(self) -> List[str]:
        return ["pdf", "csv", "doc", "docx", "xls", "xlsx", "html", "txt", "md"]

    def _create_json_tool_call_for_response_format(
        self,
        json_schema: Optional[dict] = None,
        schema_name: str = "json_tool_call",
    ) -> ChatCompletionToolParam:
        """
        Handles creating a tool call for getting responses in JSON format.

        Args:
            json_schema (Optional[dict]): The JSON schema the response should be in

        Returns:
            AnthropicMessagesTool: The tool call to send to Anthropic API to get responses in JSON format
        """

        if json_schema is None:
            # Anthropic raises a 400 BadRequest error if properties is passed as None
            # see usage with additionalProperties (Example 5) https://github.com/anthropics/anthropic-cookbook/blob/main/tool_use/extracting_structured_json.ipynb
            _input_schema = {
                "type": "object",
                "additionalProperties": True,
                "properties": {},
            }
        else:
            _input_schema = json_schema

        _tool = ChatCompletionToolParam(
            type="function",
            function=ChatCompletionToolParamFunctionChunk(
                name=schema_name, parameters=_input_schema
            ),
        )
        return _tool

    def map_openai_params(
        self,
        model: str,
@@ -160,31 +197,20 @@ class AmazonConverseConfig:
                - You should set tool_choice (see Forcing tool use) to instruct the model to explicitly use that tool
                - Remember that the model will pass the input to the tool, so the name of the tool and description should be from the model's perspective.
                """
                if json_schema is not None:
                    _tool_choice = self.map_tool_choice_values(
                        model=model, tool_choice="required", drop_params=drop_params  # type: ignore
                _tool_choice = {"name": schema_name, "type": "tool"}
                _tool = self._create_json_tool_call_for_response_format(
                    json_schema=json_schema,
                    schema_name=schema_name if schema_name != "" else "json_tool_call",
                )
                optional_params["tools"] = [_tool]
                optional_params["tool_choice"] = ToolChoiceValuesBlock(
                    tool=SpecificToolChoiceBlock(
                        name=schema_name if schema_name != "" else "json_tool_call"
                    )

                    _tool = ChatCompletionToolParam(
                        type="function",
                        function=ChatCompletionToolParamFunctionChunk(
                            name=schema_name, parameters=json_schema
                        ),
                    )

                    optional_params["tools"] = [_tool]
                    optional_params["tool_choice"] = _tool_choice
                    optional_params["json_mode"] = True
                else:
                    if litellm.drop_params is True or drop_params is True:
                        pass
                    else:
                        raise litellm.utils.UnsupportedParamsError(
                            message="Bedrock doesn't support response_format={}. To drop it from the call, set `litellm.drop_params = True`.".format(
                                value
                            ),
                            status_code=400,
                        )
                )
                optional_params["json_mode"] = True
                if non_default_params.get("stream", False) is True:
                    optional_params["fake_stream"] = True
            if param == "max_tokens" or param == "max_completion_tokens":
                optional_params["maxTokens"] = value
            if param == "stream":
@@ -330,7 +356,6 @@ class AmazonConverseConfig:
        print_verbose,
        encoding,
    ) -> Union[ModelResponse, CustomStreamWrapper]:

        ## LOGGING
        if logging_obj is not None:
            logging_obj.post_call(
@@ -468,6 +493,9 @@ class AmazonConverseConfig:
        AND "meta.llama3-2-11b-instruct-v1:0" -> "meta.llama3-2-11b-instruct-v1"
        """

        if model.startswith("converse/"):
            model = model.split("/")[1]

        potential_region = model.split(".", 1)[0]
        if potential_region in self._supported_cross_region_inference_region():
            return model.split(".", 1)[1]
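With the transformation above in place, Bedrock converse models can honor OpenAI-style JSON mode. A sketch of the resulting call (model choice illustrative; this commit's message confirms `response_format={"type": "json_object"}` support for bedrock/converse):

```python
from litellm import completion

# response_format is mapped to a forced `json_tool_call` tool under the hood,
# and the tool's JSON output is surfaced back as message content
response = completion(
    model="bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0",
    messages=[{"role": "user", "content": "List three colors as a JSON object"}],
    response_format={"type": "json_object"},
)
print(response.choices[0].message.content)
```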
@@ -182,6 +182,8 @@ async def make_call(
    model: str,
    messages: list,
    logging_obj,
    fake_stream: bool = False,
    json_mode: Optional[bool] = False,
):
    try:
        if client is None:
@@ -193,13 +195,13 @@ async def make_call(
            api_base,
            headers=headers,
            data=data,
            stream=True if "ai21" not in api_base else False,
            stream=not fake_stream,
        )

        if response.status_code != 200:
            raise BedrockError(status_code=response.status_code, message=response.text)

        if "ai21" in api_base:
        if fake_stream:
            model_response: (
                ModelResponse
            ) = litellm.AmazonConverseConfig()._transform_response(
@@ -215,7 +217,9 @@ async def make_call(
                print_verbose=litellm.print_verbose,
                encoding=litellm.encoding,
            )  # type: ignore
            completion_stream: Any = MockResponseIterator(model_response=model_response)
            completion_stream: Any = MockResponseIterator(
                model_response=model_response, json_mode=json_mode
            )
        else:
            decoder = AWSEventStreamDecoder(model=model)
            completion_stream = decoder.aiter_bytes(
@@ -1028,6 +1032,7 @@ class BedrockLLM(BaseAWSLLM):
                model=model,
                messages=messages,
                logging_obj=logging_obj,
                fake_stream=True if "ai21" in api_base else False,
            ),
            model=model,
            custom_llm_provider="bedrock",
@@ -1271,21 +1276,58 @@ class AWSEventStreamDecoder:


class MockResponseIterator:  # for returning ai21 streaming responses
    def __init__(self, model_response):
    def __init__(self, model_response, json_mode: Optional[bool] = False):
        self.model_response = model_response
        self.json_mode = json_mode
        self.is_done = False

    # Sync iterator
    def __iter__(self):
        return self

    def _chunk_parser(self, chunk_data: ModelResponse) -> GChunk:
    def _handle_json_mode_chunk(
        self, text: str, tool_calls: Optional[List[ChatCompletionToolCallChunk]]
    ) -> Tuple[str, Optional[ChatCompletionToolCallChunk]]:
        """
        If JSON mode is enabled, convert the tool call to a message.

        Bedrock returns the JSON schema as part of the tool call;
        OpenAI returns the JSON schema as part of the content. This handles placing it in the content.

        Args:
            text: str
            tool_calls: Optional[List[ChatCompletionToolCallChunk]]
        Returns:
            Tuple[str, Optional[ChatCompletionToolCallChunk]]

            text: The text to use in the content
            tool_use: The ChatCompletionToolCallChunk to use in the chunk response
        """
        tool_use: Optional[ChatCompletionToolCallChunk] = None
        if self.json_mode is True and tool_calls is not None:
            message = litellm.AnthropicConfig()._convert_tool_response_to_message(
                tool_calls=tool_calls
            )
            if message is not None:
                text = message.content or ""
                tool_use = None
        elif tool_calls is not None and len(tool_calls) > 0:
            tool_use = tool_calls[0]
        return text, tool_use

    def _chunk_parser(self, chunk_data: ModelResponse) -> GChunk:
        try:
            chunk_usage: litellm.Usage = getattr(chunk_data, "usage")
            text = chunk_data.choices[0].message.content or ""  # type: ignore
            tool_use = None
            if self.json_mode is True:
                text, tool_use = self._handle_json_mode_chunk(
                    text=text,
                    tool_calls=chunk_data.choices[0].message.tool_calls,  # type: ignore
                )
            processed_chunk = GChunk(
                text=chunk_data.choices[0].message.content or "",  # type: ignore
                tool_use=None,
                text=text,
                tool_use=tool_use,
                is_finished=True,
                finish_reason=map_finish_reason(
                    finish_reason=chunk_data.choices[0].finish_reason or ""
@@ -1298,8 +1340,8 @@ class MockResponseIterator:  # for returning ai21 streaming responses
                index=0,
            )
            return processed_chunk
        except Exception:
            raise ValueError(f"Failed to decode chunk: {chunk_data}")
        except Exception as e:
            raise ValueError(f"Failed to decode chunk: {chunk_data}. Error: {e}")

    def __next__(self):
        if self.is_done:
@@ -2162,8 +2162,9 @@ def stringify_json_tool_call_content(messages: List) -> List:
###### AMAZON BEDROCK #######

from litellm.types.llms.bedrock import ContentBlock as BedrockContentBlock
from litellm.types.llms.bedrock import DocumentBlock as BedrockDocumentBlock
from litellm.types.llms.bedrock import ImageBlock as BedrockImageBlock
from litellm.types.llms.bedrock import ImageSourceBlock as BedrockImageSourceBlock
from litellm.types.llms.bedrock import SourceBlock as BedrockSourceBlock
from litellm.types.llms.bedrock import ToolBlock as BedrockToolBlock
from litellm.types.llms.bedrock import (
    ToolChoiceValuesBlock as BedrockToolChoiceValuesBlock,
@@ -2210,7 +2211,9 @@ def get_image_details(image_url) -> Tuple[str, str]:
        raise e


def _process_bedrock_converse_image_block(image_url: str) -> BedrockImageBlock:
def _process_bedrock_converse_image_block(
    image_url: str,
) -> BedrockContentBlock:
    if "base64" in image_url:
        # Case 1: Images with base64 encoding
        import base64
@@ -2228,12 +2231,17 @@ def _process_bedrock_converse_image_block(image_url: str) -> BedrockImageBlock:
        else:
            mime_type = "image/jpeg"
            image_format = "jpeg"
        _blob = BedrockImageSourceBlock(bytes=img_without_base_64)
        _blob = BedrockSourceBlock(bytes=img_without_base_64)
        supported_image_formats = (
            litellm.AmazonConverseConfig().get_supported_image_types()
        )
        supported_document_types = (
            litellm.AmazonConverseConfig().get_supported_document_types()
        )
        if image_format in supported_image_formats:
            return BedrockImageBlock(source=_blob, format=image_format)  # type: ignore
            return BedrockContentBlock(image=BedrockImageBlock(source=_blob, format=image_format))  # type: ignore
        elif image_format in supported_document_types:
            return BedrockContentBlock(document=BedrockDocumentBlock(source=_blob, format=image_format, name="DocumentPDFmessages_{}".format(str(uuid.uuid4()))))  # type: ignore
        else:
            # Handle the case when the image format is not supported
            raise ValueError(
@@ -2244,12 +2252,17 @@ def _process_bedrock_converse_image_block(image_url: str) -> BedrockImageBlock:
    elif "https:/" in image_url:
        # Case 2: Images with direct links
        image_bytes, image_format = get_image_details(image_url)
        _blob = BedrockImageSourceBlock(bytes=image_bytes)
        _blob = BedrockSourceBlock(bytes=image_bytes)
        supported_image_formats = (
            litellm.AmazonConverseConfig().get_supported_image_types()
        )
        supported_document_types = (
            litellm.AmazonConverseConfig().get_supported_document_types()
        )
        if image_format in supported_image_formats:
            return BedrockImageBlock(source=_blob, format=image_format)  # type: ignore
            return BedrockContentBlock(image=BedrockImageBlock(source=_blob, format=image_format))  # type: ignore
        elif image_format in supported_document_types:
            return BedrockContentBlock(document=BedrockDocumentBlock(source=_blob, format=image_format, name="DocumentPDFmessages_{}".format(str(uuid.uuid4()))))  # type: ignore
        else:
            # Handle the case when the image format is not supported
            raise ValueError(
@@ -2464,11 +2477,14 @@ def _bedrock_converse_messages_pt(  # noqa: PLR0915
                        _part = BedrockContentBlock(text=element["text"])
                        _parts.append(_part)
                    elif element["type"] == "image_url":
                        image_url = element["image_url"]["url"]
                        if isinstance(element["image_url"], dict):
                            image_url = element["image_url"]["url"]
                        else:
                            image_url = element["image_url"]
                        _part = _process_bedrock_converse_image_block(  # type: ignore
                            image_url=image_url
                        )
                        _parts.append(BedrockContentBlock(image=_part))  # type: ignore
                        _parts.append(_part)  # type: ignore
                user_content.extend(_parts)
            else:
                _part = BedrockContentBlock(text=messages[msg_i]["content"])
@@ -2539,13 +2555,14 @@ def _bedrock_converse_messages_pt(  # noqa: PLR0915
                        assistants_part = BedrockContentBlock(text=element["text"])
                        assistants_parts.append(assistants_part)
                    elif element["type"] == "image_url":
                        image_url = element["image_url"]["url"]
                        if isinstance(element["image_url"], dict):
                            image_url = element["image_url"]["url"]
                        else:
                            image_url = element["image_url"]
                        assistants_part = _process_bedrock_converse_image_block(  # type: ignore
                            image_url=image_url
                        )
                        assistants_parts.append(
                            BedrockContentBlock(image=assistants_part)  # type: ignore
                        )
                        assistants_parts.append(assistants_part)
                assistant_content.extend(assistants_parts)
            elif messages[msg_i].get("content", None) is not None and isinstance(
                messages[msg_i]["content"], str
@@ -2603,7 +2603,10 @@ def completion(  # type: ignore # noqa: PLR0915

            base_model = litellm.AmazonConverseConfig()._get_base_model(model)

            if base_model in litellm.BEDROCK_CONVERSE_MODELS:
            if base_model in litellm.bedrock_converse_models or model.startswith(
                "converse/"
            ):
                model = model.replace("converse/", "")
                response = bedrock_converse_chat_completion.completion(
                    model=model,
                    messages=messages,
@@ -2622,6 +2625,7 @@ def completion(  # type: ignore # noqa: PLR0915
                    api_base=api_base,
                )
            else:
                model = model.replace("invoke/", "")
                response = bedrock_chat_completion.completion(
                    model=model,
                    messages=messages,
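The check above means either an explicit `converse/` prefix or a model registered in `bedrock_converse_models` takes the converse path. A quick sketch (model IDs taken from the lists added in this PR; AWS credentials assumed to be configured):

```python
from litellm import completion

# explicit route prefix
completion(
    model="bedrock/converse/anthropic.claude-3-5-sonnet-20240620-v1:0",
    messages=[{"role": "user", "content": "Hey!"}],
)

# implicit: nova models are registered as bedrock_converse, so this also uses converse
completion(
    model="bedrock/us.amazon.nova-pro-v1:0",
    messages=[{"role": "user", "content": "Hey!"}],
)
```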
@@ -4795,6 +4795,42 @@
        "mode": "chat",
        "supports_function_calling": true
    },
    "amazon.nova-micro-v1:0": {
        "max_tokens": 4096,
        "max_input_tokens": 300000,
        "max_output_tokens": 4096,
        "input_cost_per_token": 0.000000035,
        "output_cost_per_token": 0.00000014,
        "litellm_provider": "bedrock_converse",
        "mode": "chat",
        "supports_function_calling": true,
        "supports_vision": true,
        "supports_pdf_input": true
    },
    "amazon.nova-lite-v1:0": {
        "max_tokens": 4096,
        "max_input_tokens": 128000,
        "max_output_tokens": 4096,
        "input_cost_per_token": 0.00000006,
        "output_cost_per_token": 0.00000024,
        "litellm_provider": "bedrock_converse",
        "mode": "chat",
        "supports_function_calling": true,
        "supports_vision": true,
        "supports_pdf_input": true
    },
    "amazon.nova-pro-v1:0": {
        "max_tokens": 4096,
        "max_input_tokens": 300000,
        "max_output_tokens": 4096,
        "input_cost_per_token": 0.0000008,
        "output_cost_per_token": 0.0000032,
        "litellm_provider": "bedrock_converse",
        "mode": "chat",
        "supports_function_calling": true,
        "supports_vision": true,
        "supports_pdf_input": true
    },
    "anthropic.claude-3-sonnet-20240229-v1:0": {
        "max_tokens": 4096,
        "max_input_tokens": 200000,
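For scale, a quick cost check under the nova-pro rates above (a sketch; token counts are illustrative):

```python
# nova-pro: $0.0000008 per input token, $0.0000032 per output token
input_tokens, output_tokens = 300_000, 4_096  # max-size request
cost = input_tokens * 0.0000008 + output_tokens * 0.0000032
print(f"${cost:.4f}")  # ~ $0.2531
```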
@@ -3,6 +3,12 @@ model_list:
  - model_name: gpt-4
    litellm_params:
      model: gpt-4
      rpm: 1
  - model_name: gpt-4
    litellm_params:
      model: azure/chatgpt-v-2
      api_key: os.environ/AZURE_API_KEY
      api_base: os.environ/AZURE_API_BASE
  - model_name: rerank-model
    litellm_params:
      model: jina_ai/jina-reranker-v2-base-multilingual
@@ -667,6 +667,7 @@ class _GenerateKeyRequest(GenerateRequestBase):

class GenerateKeyRequest(_GenerateKeyRequest):
    tags: Optional[List[str]] = None
    enforced_params: Optional[List[str]] = None


class GenerateKeyResponse(_GenerateKeyRequest):
@@ -2190,4 +2191,5 @@ LiteLLM_ManagementEndpoint_MetadataFields = [
    "model_tpm_limit",
    "guardrails",
    "tags",
    "enforced_params",
]
@@ -156,40 +156,6 @@ def common_checks(  # noqa: PLR0915
        raise Exception(
            f"'user' param not passed in. 'enforce_user_param'={general_settings['enforce_user_param']}"
        )
    if general_settings.get("enforced_params", None) is not None:
        # Enterprise ONLY Feature
        # we already validate if user is premium_user when reading the config
        # Add an extra premium_user check here too, just in case
        from litellm.proxy.proxy_server import CommonProxyErrors, premium_user

        if premium_user is not True:
            raise ValueError(
                "Trying to use `enforced_params`"
                + CommonProxyErrors.not_premium_user.value
            )

        if RouteChecks.is_llm_api_route(route=route):
            # loop through each enforced param
            # example enforced_params ['user', 'metadata', 'metadata.generation_name']
            for enforced_param in general_settings["enforced_params"]:
                _enforced_params = enforced_param.split(".")
                if len(_enforced_params) == 1:
                    if _enforced_params[0] not in request_body:
                        raise ValueError(
                            f"BadRequest please pass param={_enforced_params[0]} in request body. This is a required param"
                        )
                elif len(_enforced_params) == 2:
                    # this is a scenario where user requires request['metadata']['generation_name'] to exist
                    if _enforced_params[0] not in request_body:
                        raise ValueError(
                            f"BadRequest please pass param={_enforced_params[0]} in request body. This is a required param"
                        )
                    if _enforced_params[1] not in request_body[_enforced_params[0]]:
                        raise ValueError(
                            f"BadRequest please pass param=[{_enforced_params[0]}][{_enforced_params[1]}] in request body. This is a required param"
                        )

    pass
    # 7. [OPTIONAL] If 'litellm.max_budget' is set (>0), is proxy under budget
    if (
        litellm.max_budget > 0
@ -587,6 +587,16 @@ async def add_litellm_data_to_request( # noqa: PLR0915
|
|||
f"[PROXY]returned data from litellm_pre_call_utils: {data}"
|
||||
)
|
||||
|
||||
## ENFORCED PARAMS CHECK
|
||||
# loop through each enforced param
|
||||
# example enforced_params ['user', 'metadata', 'metadata.generation_name']
|
||||
_enforced_params_check(
|
||||
request_body=data,
|
||||
general_settings=general_settings,
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
premium_user=premium_user,
|
||||
)
|
||||
|
||||
end_time = time.time()
|
||||
await service_logger_obj.async_service_success_hook(
|
||||
service=ServiceTypes.PROXY_PRE_CALL,
|
||||
|
@@ -599,6 +609,64 @@ async def add_litellm_data_to_request(  # noqa: PLR0915
     return data


+def _get_enforced_params(
+    general_settings: Optional[dict], user_api_key_dict: UserAPIKeyAuth
+) -> Optional[list]:
+    enforced_params: Optional[list] = None
+    if general_settings is not None:
+        enforced_params = general_settings.get("enforced_params")
+        if "service_account_settings" in general_settings:
+            service_account_settings = general_settings["service_account_settings"]
+            if "enforced_params" in service_account_settings:
+                if enforced_params is None:
+                    enforced_params = []
+                enforced_params.extend(service_account_settings["enforced_params"])
+    if user_api_key_dict.metadata.get("enforced_params", None) is not None:
+        if enforced_params is None:
+            enforced_params = []
+        enforced_params.extend(user_api_key_dict.metadata["enforced_params"])
+    return enforced_params
+
+
+def _enforced_params_check(
+    request_body: dict,
+    general_settings: Optional[dict],
+    user_api_key_dict: UserAPIKeyAuth,
+    premium_user: bool,
+) -> bool:
+    """
+    If enforced params are set, check if the request body contains the enforced params.
+    """
+    enforced_params: Optional[list] = _get_enforced_params(
+        general_settings=general_settings, user_api_key_dict=user_api_key_dict
+    )
+    if enforced_params is None:
+        return True
+    if enforced_params is not None and premium_user is not True:
+        raise ValueError(
+            f"Enforced Params is an Enterprise feature. Enforced Params: {enforced_params}. {CommonProxyErrors.not_premium_user.value}"
+        )
+
+    for enforced_param in enforced_params:
+        _enforced_params = enforced_param.split(".")
+        if len(_enforced_params) == 1:
+            if _enforced_params[0] not in request_body:
+                raise ValueError(
+                    f"BadRequest please pass param={_enforced_params[0]} in request body. This is a required param"
+                )
+        elif len(_enforced_params) == 2:
+            # this is a scenario where user requires request['metadata']['generation_name'] to exist
+            if _enforced_params[0] not in request_body:
+                raise ValueError(
+                    f"BadRequest please pass param={_enforced_params[0]} in request body. This is a required param"
+                )
+            if _enforced_params[1] not in request_body[_enforced_params[0]]:
+                raise ValueError(
+                    f"BadRequest please pass param=[{_enforced_params[0]}][{_enforced_params[1]}] in request body. This is a required param"
+                )
+    return True
+
+
 def move_guardrails_to_metadata(
     data: dict,
     _metadata_variable_name: str,
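Taken together, these helpers merge enforced params from `general_settings`, `service_account_settings`, and the key's metadata, then validate each one against the request body, with `a.b` meaning a required nested field. A minimal standalone sketch of the validation rule (plain dicts stand in for the proxy's settings and auth objects; the function name here is illustrative):

```python
def check_enforced_params(request_body: dict, enforced_params: list) -> bool:
    # mirrors the dotted-path check above: "metadata.generation_name"
    # requires request_body["metadata"]["generation_name"] to exist
    for enforced_param in enforced_params:
        parts = enforced_param.split(".")
        if parts[0] not in request_body:
            raise ValueError(f"BadRequest please pass param={parts[0]} in request body")
        if len(parts) == 2 and parts[1] not in request_body[parts[0]]:
            raise ValueError(
                f"BadRequest please pass param=[{parts[0]}][{parts[1]}] in request body"
            )
    return True

# passes: both 'user' and 'metadata.generation_name' are present
check_enforced_params(
    {"user": "u-123", "metadata": {"generation_name": "chat"}},
    ["user", "metadata.generation_name"],
)

# raises: metadata exists but generation_name is missing
try:
    check_enforced_params({"user": "u-123", "metadata": {}}, ["metadata.generation_name"])
except ValueError as e:
    print(e)
```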
@@ -255,6 +255,7 @@ async def generate_key_fn(  # noqa: PLR0915
     - tpm_limit: Optional[int] - Specify tpm limit for a given key (Tokens per minute)
     - soft_budget: Optional[float] - Specify soft budget for a given key. Will trigger a slack alert when this soft budget is reached.
     - tags: Optional[List[str]] - Tags for [tracking spend](https://litellm.vercel.app/docs/proxy/enterprise#tracking-spend-for-custom-tags) and/or doing [tag-based routing](https://litellm.vercel.app/docs/proxy/tag_routing).
+    - enforced_params: Optional[List[str]] - List of enforced params for the key (Enterprise only). [Docs](https://docs.litellm.ai/docs/proxy/enterprise#enforce-required-params-for-llm-requests)

     Examples:
@@ -459,7 +460,6 @@ def prepare_metadata_fields(
     """
     Check LiteLLM_ManagementEndpoint_MetadataFields (proxy/_types.py) for fields that are allowed to be updated
     """
     if "metadata" not in non_default_values:  # allow user to set metadata to none
         non_default_values["metadata"] = existing_metadata.copy()
@@ -469,18 +469,8 @@ def prepare_metadata_fields(

     try:
         for k, v in data_json.items():
-            if k == "model_tpm_limit" or k == "model_rpm_limit":
-                if k not in casted_metadata or casted_metadata[k] is None:
-                    casted_metadata[k] = {}
-                casted_metadata[k].update(v)
-
-            if k == "tags" or k == "guardrails":
-                if k not in casted_metadata or casted_metadata[k] is None:
-                    casted_metadata[k] = []
-                seen = set(casted_metadata[k])
-                casted_metadata[k].extend(
-                    x for x in v if x not in seen and not seen.add(x)  # type: ignore
-                )  # prevent duplicates from being added + maintain initial order
+            if k in LiteLLM_ManagementEndpoint_MetadataFields:
+                casted_metadata[k] = v

     except Exception as e:
         verbose_proxy_logger.exception(
@@ -498,10 +488,9 @@ def prepare_key_update_data(
 ):
     data_json: dict = data.model_dump(exclude_unset=True)
     data_json.pop("key", None)
-    _metadata_fields = ["model_rpm_limit", "model_tpm_limit", "guardrails", "tags"]
     non_default_values = {}
     for k, v in data_json.items():
-        if k in _metadata_fields:
+        if k in LiteLLM_ManagementEndpoint_MetadataFields:
             continue
         non_default_values[k] = v
@@ -556,6 +545,7 @@ async def update_key_fn(
     - team_id: Optional[str] - Team ID associated with key
     - models: Optional[list] - Model_name's a user is allowed to call
     - tags: Optional[List[str]] - Tags for organizing keys (Enterprise only)
+    - enforced_params: Optional[List[str]] - List of enforced params for the key (Enterprise only). [Docs](https://docs.litellm.ai/docs/proxy/enterprise#enforce-required-params-for-llm-requests)
     - spend: Optional[float] - Amount spent by key
     - max_budget: Optional[float] - Max budget for key
     - model_max_budget: Optional[dict] - Model-specific budgets {"gpt-4": 0.5, "claude-v1": 1.0}
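With `enforced_params` accepted as a top-level field on `/key/generate` and `/key/update`, a key can carry its own required-params list. A hedged sketch of creating such a key against a locally running proxy (the endpoint and field name come from this diff; the host, master key, `requests` usage, and response shape are assumptions for illustration):

```python
import requests

# assumes a proxy at localhost:4000 and a master key of "sk-1234"
resp = requests.post(
    "http://0.0.0.0:4000/key/generate",
    headers={"Authorization": "Bearer sk-1234"},
    json={"enforced_params": ["user", "metadata.generation_name"]},
)
print(resp.json()["key"])  # requests made with this key must include both params
```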
@@ -2940,6 +2940,7 @@ class Router:
                 remaining_retries=num_retries,
                 num_retries=num_retries,
                 healthy_deployments=_healthy_deployments,
+                all_deployments=_all_deployments,
             )

             await asyncio.sleep(retry_after)

@@ -2972,6 +2973,7 @@ class Router:
                     remaining_retries=remaining_retries,
                     num_retries=num_retries,
                     healthy_deployments=_healthy_deployments,
+                    all_deployments=_all_deployments,
                 )
                 await asyncio.sleep(_timeout)
@@ -3149,6 +3151,7 @@ class Router:
         remaining_retries: int,
         num_retries: int,
         healthy_deployments: Optional[List] = None,
+        all_deployments: Optional[List] = None,
     ) -> Union[int, float]:
         """
         Calculate back-off, then retry

@@ -3157,10 +3160,14 @@ class Router:
         1. there are healthy deployments in the same model group
         2. there are fallbacks for the completion call
         """
-        if (
+
+        ## base case - single deployment
+        if all_deployments is not None and len(all_deployments) == 1:
+            pass
+        elif (
             healthy_deployments is not None
             and isinstance(healthy_deployments, list)
-            and len(healthy_deployments) > 1
+            and len(healthy_deployments) > 0
         ):
             return 0
@@ -3242,6 +3249,7 @@ class Router:
                 remaining_retries=num_retries,
                 num_retries=num_retries,
                 healthy_deployments=_healthy_deployments,
+                all_deployments=_all_deployments,
             )

             ## LOGGING

@@ -3276,6 +3284,7 @@ class Router:
                     remaining_retries=remaining_retries,
                     num_retries=num_retries,
                     healthy_deployments=_healthy_deployments,
+                    all_deployments=_all_deployments,
                 )
                 time.sleep(_timeout)
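The net effect of the `all_deployments` plumbing above: back-off is only honored when there is nothing to fail over to. A hypothetical distillation of the rule (not the Router's actual method, which also parses `retry-after` headers and computes exponential back-off):

```python
from typing import List, Optional

def time_to_sleep_before_retry(
    retry_after: float,
    healthy_deployments: Optional[List] = None,
    all_deployments: Optional[List] = None,
) -> float:
    if all_deployments is not None and len(all_deployments) == 1:
        return retry_after  # base case: only one deployment, honor the back-off
    if healthy_deployments:  # >0 healthy alternatives -> retry immediately
        return 0.0
    return retry_after

assert time_to_sleep_before_retry(60, healthy_deployments=["a", "b"], all_deployments=["a", "b"]) == 0.0
assert time_to_sleep_before_retry(60, healthy_deployments=["a"], all_deployments=["a"]) == 60
```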
@@ -18,17 +18,24 @@ class SystemContentBlock(TypedDict):
     text: str


-class ImageSourceBlock(TypedDict):
+class SourceBlock(TypedDict):
     bytes: Optional[str]  # base 64 encoded string


 class ImageBlock(TypedDict):
     format: Literal["png", "jpeg", "gif", "webp"]
-    source: ImageSourceBlock
+    source: SourceBlock


+class DocumentBlock(TypedDict):
+    format: Literal["pdf", "csv", "doc", "docx", "xls", "xlsx", "html", "txt", "md"]
+    source: SourceBlock
+    name: str
+
+
 class ToolResultContentBlock(TypedDict, total=False):
     image: ImageBlock
+    document: DocumentBlock
     json: dict
     text: str

@@ -48,6 +55,7 @@ class ToolUseBlock(TypedDict):
 class ContentBlock(TypedDict, total=False):
     text: str
     image: ImageBlock
+    document: DocumentBlock
     toolResult: ToolResultBlock
     toolUse: ToolUseBlock
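For reference, a sketch of what a Converse `document` content block might look like when built from these TypedDicts (hand-assembled here for illustration; in practice litellm's converse transformation builds this from an OpenAI-style `image_url` message):

```python
import base64

# stand-in PDF bytes for illustration; normally read from disk or fetched from a URL
pdf_bytes = b"%PDF-1.4 ..."
encoded = base64.b64encode(pdf_bytes).decode("utf-8")

document_block = {
    "document": {
        "format": "pdf",
        "name": "dummy",  # DocumentBlock requires a name
        "source": {"bytes": encoded},  # SourceBlock: base64-encoded string
    }
}
message = {
    "role": "user",
    "content": [{"text": "What's this file about?"}, document_block],
}
```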
@@ -106,6 +106,7 @@ class ModelInfo(TypedDict, total=False):
     supports_prompt_caching: Optional[bool]
     supports_audio_input: Optional[bool]
     supports_audio_output: Optional[bool]
+    supports_pdf_input: Optional[bool]
     tpm: Optional[int]
     rpm: Optional[int]
155 litellm/utils.py
@@ -3142,7 +3142,7 @@ def get_optional_params(  # noqa: PLR0915
             model=model, custom_llm_provider=custom_llm_provider
         )
         base_model = litellm.AmazonConverseConfig()._get_base_model(model)
-        if base_model in litellm.BEDROCK_CONVERSE_MODELS:
+        if base_model in litellm.bedrock_converse_models:
             _check_valid_arg(supported_params=supported_params)
             optional_params = litellm.AmazonConverseConfig().map_openai_params(
                 model=model,
@@ -4308,6 +4308,10 @@ def _strip_stable_vertex_version(model_name) -> str:
     return re.sub(r"-\d+$", "", model_name)


+def _strip_bedrock_region(model_name) -> str:
+    return litellm.AmazonConverseConfig()._get_base_model(model_name)
+
+
 def _strip_openai_finetune_model_name(model_name: str) -> str:
     """
     Strips the organization, custom suffix, and ID from an OpenAI fine-tuned model name.
@@ -4324,16 +4328,50 @@ def _strip_openai_finetune_model_name(model_name: str) -> str:
     return re.sub(r"(:[^:]+){3}$", "", model_name)


-def _strip_model_name(model: str) -> str:
-    strip_version = _strip_stable_vertex_version(model_name=model)
-    strip_finetune = _strip_openai_finetune_model_name(model_name=strip_version)
-    return strip_finetune
+def _strip_model_name(model: str, custom_llm_provider: Optional[str]) -> str:
+    if custom_llm_provider and custom_llm_provider == "bedrock":
+        strip_bedrock_region = _strip_bedrock_region(model_name=model)
+        return strip_bedrock_region
+    elif custom_llm_provider and (
+        custom_llm_provider == "vertex_ai" or custom_llm_provider == "gemini"
+    ):
+        strip_version = _strip_stable_vertex_version(model_name=model)
+        return strip_version
+    else:
+        strip_finetune = _strip_openai_finetune_model_name(model_name=model)
+        return strip_finetune
+
+
+def _get_model_info_from_model_cost(key: str) -> dict:
+    return litellm.model_cost[key]
+
+
+def _check_provider_match(model_info: dict, custom_llm_provider: Optional[str]) -> bool:
+    """
+    Check if the model info provider matches the custom provider.
+    """
+    if custom_llm_provider and (
+        "litellm_provider" in model_info
+        and model_info["litellm_provider"] != custom_llm_provider
+    ):
+        if custom_llm_provider == "vertex_ai" and model_info[
+            "litellm_provider"
+        ].startswith("vertex_ai"):
+            return True
+        elif custom_llm_provider == "fireworks_ai" and model_info[
+            "litellm_provider"
+        ].startswith("fireworks_ai"):
+            return True
+        elif custom_llm_provider == "bedrock" and model_info[
+            "litellm_provider"
+        ].startswith("bedrock"):
+            return True
+        else:
+            return False
+
+    return True
+
+
 def get_model_info(  # noqa: PLR0915
     model: str, custom_llm_provider: Optional[str] = None
 ) -> ModelInfo:
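The stripping rules are easiest to see on concrete names. A standalone sketch of assumed equivalents (the real bedrock rule delegates to `AmazonConverseConfig()._get_base_model`, so the region regex here is an illustration only):

```python
import re
from typing import Optional

def strip_model_name(model: str, provider: Optional[str]) -> str:
    if provider == "bedrock":
        return re.sub(r"^(us|eu|apac)\.", "", model)  # drop region prefix
    if provider in ("vertex_ai", "gemini"):
        return re.sub(r"-\d+$", "", model)  # drop stable-version suffix
    return re.sub(r"(:[^:]+){3}$", "", model)  # drop OpenAI finetune suffix

print(strip_model_name("us.amazon.nova-pro-v1:0", "bedrock"))  # amazon.nova-pro-v1:0
print(strip_model_name("gemini-1.5-pro-001", "vertex_ai"))  # gemini-1.5-pro
print(strip_model_name("ft:gpt-3.5-turbo:my-org:custom_suffix:id", None))  # ft:gpt-3.5-turbo
```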
|
@ -4388,6 +4426,7 @@ def get_model_info( # noqa: PLR0915
|
|||
supports_prompt_caching: Optional[bool]
|
||||
supports_audio_input: Optional[bool]
|
||||
supports_audio_output: Optional[bool]
|
||||
supports_pdf_input: Optional[bool]
|
||||
Raises:
|
||||
Exception: If the model is not mapped yet.
|
||||
|
||||
|
@@ -4445,15 +4484,21 @@ def get_model_info(  # noqa: PLR0915
         except Exception:
             split_model = model
             combined_model_name = model
-            stripped_model_name = _strip_model_name(model=model)
+            stripped_model_name = _strip_model_name(
+                model=model, custom_llm_provider=custom_llm_provider
+            )
             combined_stripped_model_name = stripped_model_name
     else:
         split_model = model
         combined_model_name = "{}/{}".format(custom_llm_provider, model)
-        stripped_model_name = _strip_model_name(model=model)
-        combined_stripped_model_name = "{}/{}".format(
-            custom_llm_provider, _strip_model_name(model=model)
+        stripped_model_name = _strip_model_name(
+            model=model, custom_llm_provider=custom_llm_provider
+        )
+        combined_stripped_model_name = "{}/{}".format(
+            custom_llm_provider,
+            _strip_model_name(model=model, custom_llm_provider=custom_llm_provider),
         )

     #########################

     supported_openai_params = litellm.get_supported_openai_params(
@@ -4476,6 +4521,7 @@ def get_model_info(  # noqa: PLR0915
             supports_function_calling=None,
             supports_assistant_prefill=None,
             supports_prompt_caching=None,
+            supports_pdf_input=None,
         )
     elif custom_llm_provider == "ollama" or custom_llm_provider == "ollama_chat":
         return litellm.OllamaConfig().get_model_info(model)
@@ -4488,40 +4534,25 @@ def get_model_info(  # noqa: PLR0915
     4. 'stripped_model_name' in litellm.model_cost. Checks if 'ft:gpt-3.5-turbo' in model map, if 'ft:gpt-3.5-turbo:my-org:custom_suffix:id' given.
     5. 'split_model' in litellm.model_cost. Checks "llama3-8b-8192" in litellm.model_cost if model="groq/llama3-8b-8192"
     """

     _model_info: Optional[Dict[str, Any]] = None
     key: Optional[str] = None
     if combined_model_name in litellm.model_cost:
         key = combined_model_name
         _model_info = _get_model_info_from_model_cost(key=key)
         _model_info["supported_openai_params"] = supported_openai_params
-        if (
-            "litellm_provider" in _model_info
-            and _model_info["litellm_provider"] != custom_llm_provider
+        if not _check_provider_match(
+            model_info=_model_info, custom_llm_provider=custom_llm_provider
         ):
-            if custom_llm_provider == "vertex_ai" and _model_info[
-                "litellm_provider"
-            ].startswith("vertex_ai"):
-                pass
-            else:
-                _model_info = None
+            _model_info = None
     if _model_info is None and model in litellm.model_cost:
         key = model
         _model_info = _get_model_info_from_model_cost(key=key)
         _model_info["supported_openai_params"] = supported_openai_params
-        if (
-            "litellm_provider" in _model_info
-            and _model_info["litellm_provider"] != custom_llm_provider
+        if not _check_provider_match(
+            model_info=_model_info, custom_llm_provider=custom_llm_provider
         ):
-            if custom_llm_provider == "vertex_ai" and _model_info[
-                "litellm_provider"
-            ].startswith("vertex_ai"):
-                pass
-            elif custom_llm_provider == "fireworks_ai" and _model_info[
-                "litellm_provider"
-            ].startswith("fireworks_ai"):
-                pass
-            else:
-                _model_info = None
+            _model_info = None
     if (
         _model_info is None
         and combined_stripped_model_name in litellm.model_cost
@@ -4529,57 +4560,26 @@ def get_model_info(  # noqa: PLR0915
         key = combined_stripped_model_name
         _model_info = _get_model_info_from_model_cost(key=key)
         _model_info["supported_openai_params"] = supported_openai_params
-        if (
-            "litellm_provider" in _model_info
-            and _model_info["litellm_provider"] != custom_llm_provider
+        if not _check_provider_match(
+            model_info=_model_info, custom_llm_provider=custom_llm_provider
         ):
-            if custom_llm_provider == "vertex_ai" and _model_info[
-                "litellm_provider"
-            ].startswith("vertex_ai"):
-                pass
-            elif custom_llm_provider == "fireworks_ai" and _model_info[
-                "litellm_provider"
-            ].startswith("fireworks_ai"):
-                pass
-            else:
-                _model_info = None
+            _model_info = None
     if _model_info is None and stripped_model_name in litellm.model_cost:
         key = stripped_model_name
         _model_info = _get_model_info_from_model_cost(key=key)
         _model_info["supported_openai_params"] = supported_openai_params
-        if (
-            "litellm_provider" in _model_info
-            and _model_info["litellm_provider"] != custom_llm_provider
+        if not _check_provider_match(
+            model_info=_model_info, custom_llm_provider=custom_llm_provider
         ):
-            if custom_llm_provider == "vertex_ai" and _model_info[
-                "litellm_provider"
-            ].startswith("vertex_ai"):
-                pass
-            elif custom_llm_provider == "fireworks_ai" and _model_info[
-                "litellm_provider"
-            ].startswith("fireworks_ai"):
-                pass
-            else:
-                _model_info = None
+            _model_info = None
     if _model_info is None and split_model in litellm.model_cost:
         key = split_model
         _model_info = _get_model_info_from_model_cost(key=key)
         _model_info["supported_openai_params"] = supported_openai_params
-        if (
-            "litellm_provider" in _model_info
-            and _model_info["litellm_provider"] != custom_llm_provider
+        if not _check_provider_match(
+            model_info=_model_info, custom_llm_provider=custom_llm_provider
         ):
-            if custom_llm_provider == "vertex_ai" and _model_info[
-                "litellm_provider"
-            ].startswith("vertex_ai"):
-                pass
-            elif custom_llm_provider == "fireworks_ai" and _model_info[
-                "litellm_provider"
-            ].startswith("fireworks_ai"):
-                pass
-            else:
-                _model_info = None
+            _model_info = None
     if _model_info is None or key is None:
         raise ValueError(
             "This model isn't mapped yet. Add it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json"
@@ -4675,6 +4675,7 @@ def get_model_info(  # noqa: PLR0915
             ),
             supports_audio_input=_model_info.get("supports_audio_input", False),
             supports_audio_output=_model_info.get("supports_audio_output", False),
+            supports_pdf_input=_model_info.get("supports_pdf_input", False),
             tpm=_model_info.get("tpm", None),
             rpm=_model_info.get("rpm", None),
         )
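With `supports_pdf_input` flowing through `get_model_info`, callers can gate document requests the same way the new base test does. A small hedged usage sketch (the helper and its positional signature appear in this diff; the example assumes the local model-cost map includes the Nova entries added below):

```python
from litellm.utils import supports_pdf_input

# True for models whose model-cost entry sets "supports_pdf_input": true
if supports_pdf_input("bedrock/us.amazon.nova-pro-v1:0", None):
    print("this model accepts PDF input on /chat/completions")
```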
@@ -6195,11 +6196,21 @@ class ProviderConfigManager:
         return OpenAIGPTConfig()


-def get_end_user_id_for_cost_tracking(litellm_params: dict) -> Optional[str]:
+def get_end_user_id_for_cost_tracking(
+    litellm_params: dict,
+    service_type: Literal["litellm_logging", "prometheus"] = "litellm_logging",
+) -> Optional[str]:
     """
     Used for enforcing `disable_end_user_cost_tracking` param.
+
+    service_type: "litellm_logging" or "prometheus" - used to allow prometheus only disable cost tracking.
     """
     proxy_server_request = litellm_params.get("proxy_server_request") or {}
     if litellm.disable_end_user_cost_tracking:
         return None
+    if (
+        service_type == "prometheus"
+        and litellm.disable_end_user_cost_tracking_prometheus_only
+    ):
+        return None
     return proxy_server_request.get("body", {}).get("user", None)
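A quick sketch of how the new `service_type` switch behaves (the flag and function signature are from this diff; the params dict is a made-up example). Logging still sees the end user while the Prometheus path returns `None`, which keeps the end-user label out of the metric and avoids the cardinality blow-up:

```python
import litellm
from litellm.utils import get_end_user_id_for_cost_tracking

params = {"proxy_server_request": {"body": {"user": "end-user-123"}}}

litellm.disable_end_user_cost_tracking_prometheus_only = True
print(get_end_user_id_for_cost_tracking(litellm_params=params))  # "end-user-123"
print(get_end_user_id_for_cost_tracking(litellm_params=params, service_type="prometheus"))  # None
```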
@@ -4795,6 +4795,42 @@
         "mode": "chat",
         "supports_function_calling": true
     },
+    "amazon.nova-micro-v1:0": {
+        "max_tokens": 4096,
+        "max_input_tokens": 300000,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.000000035,
+        "output_cost_per_token": 0.00000014,
+        "litellm_provider": "bedrock_converse",
+        "mode": "chat",
+        "supports_function_calling": true,
+        "supports_vision": true,
+        "supports_pdf_input": true
+    },
+    "amazon.nova-lite-v1:0": {
+        "max_tokens": 4096,
+        "max_input_tokens": 128000,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.00000006,
+        "output_cost_per_token": 0.00000024,
+        "litellm_provider": "bedrock_converse",
+        "mode": "chat",
+        "supports_function_calling": true,
+        "supports_vision": true,
+        "supports_pdf_input": true
+    },
+    "amazon.nova-pro-v1:0": {
+        "max_tokens": 4096,
+        "max_input_tokens": 300000,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.0000008,
+        "output_cost_per_token": 0.0000032,
+        "litellm_provider": "bedrock_converse",
+        "mode": "chat",
+        "supports_function_calling": true,
+        "supports_vision": true,
+        "supports_pdf_input": true
+    },
     "anthropic.claude-3-sonnet-20240229-v1:0": {
         "max_tokens": 4096,
         "max_input_tokens": 200000,
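As a sanity check on the per-token prices above, here is the cost of a hypothetical 10,000-input / 500-output token call on each Nova model (pure arithmetic on the values in this diff):

```python
prices = {
    "amazon.nova-micro-v1:0": (0.000000035, 0.00000014),
    "amazon.nova-lite-v1:0": (0.00000006, 0.00000024),
    "amazon.nova-pro-v1:0": (0.0000008, 0.0000032),
}
for model, (inp, out) in prices.items():
    cost = 10_000 * inp + 500 * out
    print(f"{model}: ${cost:.6f}")
# amazon.nova-micro-v1:0: $0.000420
# amazon.nova-lite-v1:0: $0.000720
# amazon.nova-pro-v1:0: $0.009600
```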
@@ -54,6 +54,36 @@ class BaseLLMChatTest(ABC):
         # for OpenAI the content contains the JSON schema, so we need to assert that the content is not None
         assert response.choices[0].message.content is not None

+    @pytest.mark.parametrize("image_url", ["str", "dict"])
+    def test_pdf_handling(self, pdf_messages, image_url):
+        from litellm.utils import supports_pdf_input
+
+        if image_url == "str":
+            image_url = pdf_messages
+        elif image_url == "dict":
+            image_url = {"url": pdf_messages}
+
+        image_content = [
+            {"type": "text", "text": "What's this file about?"},
+            {
+                "type": "image_url",
+                "image_url": image_url,
+            },
+        ]
+
+        image_messages = [{"role": "user", "content": image_content}]
+
+        base_completion_call_args = self.get_base_completion_call_args()
+
+        if not supports_pdf_input(base_completion_call_args["model"], None):
+            pytest.skip("Model does not support image input")
+
+        response = litellm.completion(
+            **base_completion_call_args,
+            messages=image_messages,
+        )
+        assert response is not None
+
     def test_message_with_name(self):
         litellm.set_verbose = True
         base_completion_call_args = self.get_base_completion_call_args()
@@ -187,7 +217,7 @@ class BaseLLMChatTest(ABC):
         for chunk in response:
             content += chunk.choices[0].delta.content or ""

-        print("content=", content)
+        print(f"content={content}<END>")

         # OpenAI guarantees that the JSON schema is returned in the content
         # relevant issue: https://github.com/BerriAI/litellm/issues/6741
@@ -250,7 +280,7 @@ class BaseLLMChatTest(ABC):
         import requests

         # URL of the file
-        url = "https://storage.googleapis.com/cloud-samples-data/generative-ai/pdf/2403.05530.pdf"
+        url = "https://www.w3.org/WAI/ER/tests/xhtml/testfiles/resources/pdf/dummy.pdf"

         response = requests.get(url)
         file_data = response.content
@@ -258,14 +288,4 @@ class BaseLLMChatTest(ABC):
         encoded_file = base64.b64encode(file_data).decode("utf-8")
         url = f"data:application/pdf;base64,{encoded_file}"

-        image_content = [
-            {"type": "text", "text": "What's this file about?"},
-            {
-                "type": "image_url",
-                "image_url": {"url": url},
-            },
-        ]
-
-        image_messages = [{"role": "user", "content": image_content}]
-
-        return image_messages
+        return url
@@ -668,35 +668,6 @@ class TestAnthropicCompletion(BaseLLMChatTest):
     def get_base_completion_call_args(self) -> dict:
         return {"model": "claude-3-haiku-20240307"}

-    def test_pdf_handling(self, pdf_messages):
-        from litellm.llms.custom_httpx.http_handler import HTTPHandler
-        from litellm.types.llms.anthropic import AnthropicMessagesDocumentParam
-        import json
-
-        client = HTTPHandler()
-
-        with patch.object(client, "post", new=MagicMock()) as mock_client:
-            response = completion(
-                model="claude-3-5-sonnet-20241022",
-                messages=pdf_messages,
-                client=client,
-            )
-
-        mock_client.assert_called_once()
-
-        json_data = json.loads(mock_client.call_args.kwargs["data"])
-        headers = mock_client.call_args.kwargs["headers"]
-
-        assert headers["anthropic-beta"] == "pdfs-2024-09-25"
-
-        json_data["messages"][0]["role"] == "user"
-        _document_validation = AnthropicMessagesDocumentParam(
-            **json_data["messages"][0]["content"][1]
-        )
-        assert _document_validation["type"] == "document"
-        assert _document_validation["source"]["media_type"] == "application/pdf"
-        assert _document_validation["source"]["type"] == "base64"
-
     def test_tool_call_no_arguments(self, tool_call_no_arguments):
         """Test that tool calls with no arguments is translated correctly. Relevant issue: https://github.com/BerriAI/litellm/issues/6833"""
         from litellm.llms.prompt_templates.factory import (
@@ -30,6 +30,7 @@ from litellm import (
 from litellm.llms.bedrock.chat import BedrockLLM
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
 from litellm.llms.prompt_templates.factory import _bedrock_tools_pt
+from base_llm_unit_tests import BaseLLMChatTest

 # litellm.num_retries = 3
 litellm.cache = None
@@ -1943,3 +1944,42 @@ def test_bedrock_context_window_error():
         messages=[{"role": "user", "content": "Hello, world!"}],
         mock_response=Exception("prompt is too long"),
     )
+
+
+def test_bedrock_converse_route():
+    litellm.set_verbose = True
+    litellm.completion(
+        model="bedrock/converse/us.amazon.nova-pro-v1:0",
+        messages=[{"role": "user", "content": "Hello, world!"}],
+    )
+
+
+def test_bedrock_mapped_converse_models():
+    litellm.set_verbose = True
+    os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
+    litellm.model_cost = litellm.get_model_cost_map(url="")
+    litellm.add_known_models()
+    litellm.completion(
+        model="bedrock/us.amazon.nova-pro-v1:0",
+        messages=[{"role": "user", "content": "Hello, world!"}],
+    )
+
+
+def test_bedrock_base_model_helper():
+    model = "us.amazon.nova-pro-v1:0"
+    litellm.AmazonConverseConfig()._get_base_model(model)
+    assert model == "us.amazon.nova-pro-v1:0"
+
+
+class TestBedrockConverseChat(BaseLLMChatTest):
+    def get_base_completion_call_args(self) -> dict:
+        os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
+        litellm.model_cost = litellm.get_model_cost_map(url="")
+        litellm.add_known_models()
+        return {
+            "model": "bedrock/us.anthropic.claude-3-haiku-20240307-v1:0",
+        }
+
+    def test_tool_call_no_arguments(self, tool_call_no_arguments):
+        """Test that tool calls with no arguments is translated correctly. Relevant issue: https://github.com/BerriAI/litellm/issues/6833"""
+        pass
@@ -695,29 +695,6 @@ async def test_anthropic_no_content_error():
         pytest.fail(f"An unexpected error occurred - {str(e)}")


-def test_gemini_completion_call_error():
-    try:
-        print("test completion + streaming")
-        litellm.num_retries = 3
-        litellm.set_verbose = True
-        messages = [{"role": "user", "content": "what is the capital of congo?"}]
-        response = completion(
-            model="gemini/gemini-1.5-pro-latest",
-            messages=messages,
-            stream=True,
-            max_tokens=10,
-        )
-        print(f"response: {response}")
-        for chunk in response:
-            print(chunk)
-    except litellm.RateLimitError:
-        pass
-    except litellm.InternalServerError:
-        pass
-    except Exception as e:
-        pytest.fail(f"error occurred: {str(e)}")
-
-
 def test_completion_cohere_command_r_plus_function_call():
     litellm.set_verbose = True
     tools = [
@@ -2342,12 +2342,6 @@ def test_router_dynamic_cooldown_correct_retry_after_time():
         pass

     new_retry_after_mock_client.assert_called()
-    print(
-        f"new_retry_after_mock_client.call_args.kwargs: {new_retry_after_mock_client.call_args.kwargs}"
-    )
-    print(
-        f"new_retry_after_mock_client.call_args: {new_retry_after_mock_client.call_args[0][0]}"
-    )

     response_headers: httpx.Headers = new_retry_after_mock_client.call_args[0][0]
     assert int(response_headers["retry-after"]) == cooldown_time
@@ -539,45 +539,71 @@ def test_raise_context_window_exceeded_error_no_retry():
     """


-def test_timeout_for_rate_limit_error_with_healthy_deployments():
+@pytest.mark.parametrize("num_deployments, expected_timeout", [(1, 60), (2, 0.0)])
+def test_timeout_for_rate_limit_error_with_healthy_deployments(
+    num_deployments, expected_timeout
+):
     """
     Test 1. Timeout is 0.0 when RateLimit Error and healthy deployments are > 0
     """
-    healthy_deployments = [
-        "deployment1",
-        "deployment2",
-    ]  # multiple healthy deployments mocked up
-    fallbacks = None
-
-    router = litellm.Router(
-        model_list=[
-            {
-                "model_name": "gpt-3.5-turbo",
-                "litellm_params": {
-                    "model": "azure/chatgpt-v-2",
-                    "api_key": os.getenv("AZURE_API_KEY"),
-                    "api_version": os.getenv("AZURE_API_VERSION"),
-                    "api_base": os.getenv("AZURE_API_BASE"),
-                },
-            }
-        ]
-    )
     cooldown_time = 60
     rate_limit_error = litellm.RateLimitError(
         message="{RouterErrors.no_deployments_available.value}. 12345 Passed model={model_group}. Deployments={deployment_dict}",
         llm_provider="",
         model="gpt-3.5-turbo",
         response=httpx.Response(
             status_code=429,
             content="",
             headers={"retry-after": str(cooldown_time)},  # type: ignore
             request=httpx.Request(method="tpm_rpm_limits", url="https://github.com/BerriAI/litellm"),  # type: ignore
         ),
     )
+    model_list = [
+        {
+            "model_name": "gpt-3.5-turbo",
+            "litellm_params": {
+                "model": "azure/chatgpt-v-2",
+                "api_key": os.getenv("AZURE_API_KEY"),
+                "api_version": os.getenv("AZURE_API_VERSION"),
+                "api_base": os.getenv("AZURE_API_BASE"),
+            },
+        }
+    ]
+    if num_deployments == 2:
+        model_list.append(
+            {
+                "model_name": "gpt-4",
+                "litellm_params": {"model": "gpt-3.5-turbo"},
+            }
+        )
+
+    router = litellm.Router(model_list=model_list)

     _timeout = router._time_to_sleep_before_retry(
         e=rate_limit_error,
-        remaining_retries=4,
-        num_retries=4,
-        healthy_deployments=healthy_deployments,
+        remaining_retries=2,
+        num_retries=2,
+        healthy_deployments=[
+            {
+                "model_name": "gpt-4",
+                "litellm_params": {
+                    "api_key": "my-key",
+                    "api_base": "https://openai-gpt-4-test-v-1.openai.azure.com",
+                    "model": "azure/chatgpt-v-2",
+                },
+                "model_info": {
+                    "id": "0e30bc8a63fa91ae4415d4234e231b3f9e6dd900cac57d118ce13a720d95e9d6",
+                    "db_model": False,
+                },
+            }
+        ],
+        all_deployments=model_list,
     )

-    print(
-        "timeout=",
-        _timeout,
-        "error is rate_limit_error and there are healthy deployments=",
-        healthy_deployments,
-    )
-
-    assert _timeout == 0.0
+    if expected_timeout == 0.0:
+        assert _timeout == expected_timeout
+    else:
+        assert _timeout > 0.0


 def test_timeout_for_rate_limit_error_with_no_healthy_deployments():
@@ -585,26 +611,26 @@ def test_timeout_for_rate_limit_error_with_no_healthy_deployments():
     Test 2. Timeout is > 0.0 when RateLimit Error and healthy deployments == 0
     """
     healthy_deployments = []
+    model_list = [
+        {
+            "model_name": "gpt-3.5-turbo",
+            "litellm_params": {
+                "model": "azure/chatgpt-v-2",
+                "api_key": os.getenv("AZURE_API_KEY"),
+                "api_version": os.getenv("AZURE_API_VERSION"),
+                "api_base": os.getenv("AZURE_API_BASE"),
+            },
+        }
+    ]

-    router = litellm.Router(
-        model_list=[
-            {
-                "model_name": "gpt-3.5-turbo",
-                "litellm_params": {
-                    "model": "azure/chatgpt-v-2",
-                    "api_key": os.getenv("AZURE_API_KEY"),
-                    "api_version": os.getenv("AZURE_API_VERSION"),
-                    "api_base": os.getenv("AZURE_API_BASE"),
-                },
-            }
-        ]
-    )
+    router = litellm.Router(model_list=model_list)

     _timeout = router._time_to_sleep_before_retry(
         e=rate_limit_error,
         remaining_retries=4,
         num_retries=4,
         healthy_deployments=healthy_deployments,
+        all_deployments=model_list,
     )

     print(
@@ -1005,7 +1005,10 @@ def test_models_by_provider():
             continue
         elif k == "sample_spec":
             continue
-        elif v["litellm_provider"] == "sagemaker":
+        elif (
+            v["litellm_provider"] == "sagemaker"
+            or v["litellm_provider"] == "bedrock_converse"
+        ):
             continue
         else:
             providers.add(v["litellm_provider"])
@@ -1032,3 +1035,27 @@ def test_get_end_user_id_for_cost_tracking(
         get_end_user_id_for_cost_tracking(litellm_params=litellm_params)
         == expected_end_user_id
     )
+
+
+@pytest.mark.parametrize(
+    "litellm_params, disable_end_user_cost_tracking_prometheus_only, expected_end_user_id",
+    [
+        ({}, False, None),
+        ({"proxy_server_request": {"body": {"user": "123"}}}, False, "123"),
+        ({"proxy_server_request": {"body": {"user": "123"}}}, True, None),
+    ],
+)
+def test_get_end_user_id_for_cost_tracking_prometheus_only(
+    litellm_params, disable_end_user_cost_tracking_prometheus_only, expected_end_user_id
+):
+    from litellm.utils import get_end_user_id_for_cost_tracking
+
+    litellm.disable_end_user_cost_tracking_prometheus_only = (
+        disable_end_user_cost_tracking_prometheus_only
+    )
+    assert (
+        get_end_user_id_for_cost_tracking(
+            litellm_params=litellm_params, service_type="prometheus"
+        )
+        == expected_end_user_id
+    )
@@ -695,45 +695,43 @@ def test_personal_key_generation_check():
     )


-def test_prepare_metadata_fields():
+@pytest.mark.parametrize(
+    "update_request_data, non_default_values, existing_metadata, expected_result",
+    [
+        (
+            {"metadata": {"test": "new"}},
+            {"metadata": {"test": "new"}},
+            {"test": "test"},
+            {"metadata": {"test": "new"}},
+        ),
+        (
+            {"tags": ["new_tag"]},
+            {},
+            {"tags": ["old_tag"]},
+            {"metadata": {"tags": ["new_tag"]}},
+        ),
+        (
+            {"enforced_params": ["metadata.tags"]},
+            {},
+            {"tags": ["old_tag"]},
+            {"metadata": {"tags": ["old_tag"], "enforced_params": ["metadata.tags"]}},
+        ),
+    ],
+)
+def test_prepare_metadata_fields(
+    update_request_data, non_default_values, existing_metadata, expected_result
+):
     from litellm.proxy.management_endpoints.key_management_endpoints import (
         prepare_metadata_fields,
     )

-    new_metadata = {"test": "new"}
-    old_metadata = {"test": "test"}
-
     args = {
         "data": UpdateKeyRequest(
-            key_alias=None,
-            duration=None,
-            models=[],
-            spend=None,
-            max_budget=None,
-            user_id=None,
-            team_id=None,
-            max_parallel_requests=None,
-            metadata=new_metadata,
-            tpm_limit=None,
-            rpm_limit=None,
-            budget_duration=None,
-            allowed_cache_controls=[],
-            soft_budget=None,
-            config={},
-            permissions={},
-            model_max_budget={},
-            send_invite_email=None,
-            model_rpm_limit=None,
-            model_tpm_limit=None,
-            guardrails=None,
-            blocked=None,
-            aliases={},
-            key="sk-1qGQUJJTcljeaPfzgWRrXQ",
-            tags=None,
+            key="sk-1qGQUJJTcljeaPfzgWRrXQ", **update_request_data
         ),
-        "non_default_values": {"metadata": new_metadata},
-        "existing_metadata": {"tags": None, **old_metadata},
+        "non_default_values": non_default_values,
+        "existing_metadata": existing_metadata,
     }

-    non_default_values = prepare_metadata_fields(**args)
-    assert non_default_values == {"metadata": new_metadata}
+    updated_non_default_values = prepare_metadata_fields(**args)
+    assert updated_non_default_values == expected_result
@@ -2664,66 +2664,6 @@ async def test_create_update_team(prisma_client):
     )


-@pytest.mark.asyncio()
-async def test_enforced_params(prisma_client):
-    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
-    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
-    from litellm.proxy.proxy_server import general_settings
-
-    general_settings["enforced_params"] = [
-        "user",
-        "metadata",
-        "metadata.generation_name",
-    ]
-
-    await litellm.proxy.proxy_server.prisma_client.connect()
-    request = NewUserRequest()
-    key = await new_user(
-        data=request,
-        user_api_key_dict=UserAPIKeyAuth(
-            user_role=LitellmUserRoles.PROXY_ADMIN,
-            api_key="sk-1234",
-            user_id="1234",
-        ),
-    )
-    print(key)
-
-    generated_key = key.key
-    bearer_token = "Bearer " + generated_key
-
-    request = Request(scope={"type": "http"})
-    request._url = URL(url="/chat/completions")
-
-    # Case 1: Missing user
-    async def return_body():
-        return b'{"model": "gemini-pro-vision"}'
-
-    request.body = return_body
-    try:
-        await user_api_key_auth(request=request, api_key=bearer_token)
-        pytest.fail(f"This should have failed!. IT's an invalid request")
-    except Exception as e:
-        assert (
-            "BadRequest please pass param=user in request body. This is a required param"
-            in e.message
-        )
-
-    # Case 2: Missing metadata["generation_name"]
-    async def return_body_2():
-        return b'{"model": "gemini-pro-vision", "user": "1234", "metadata": {}}'
-
-    request.body = return_body_2
-    try:
-        await user_api_key_auth(request=request, api_key=bearer_token)
-        pytest.fail(f"This should have failed!. IT's an invalid request")
-    except Exception as e:
-        assert (
-            "Authentication Error, BadRequest please pass param=[metadata][generation_name] in request body"
-            in e.message
-        )
-    general_settings.pop("enforced_params")
-
-
 @pytest.mark.asyncio()
 async def test_update_user_role(prisma_client):
     """
@@ -3363,64 +3303,6 @@ async def test_auth_vertex_ai_route(prisma_client):
         pass


-@pytest.mark.asyncio
-async def test_service_accounts(prisma_client):
-    """
-    Do not delete
-    this is the Admin UI flow
-    """
-    # Make a call to a key with model = `all-proxy-models` this is an Alias from LiteLLM Admin UI
-    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
-    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
-    setattr(
-        litellm.proxy.proxy_server,
-        "general_settings",
-        {"service_account_settings": {"enforced_params": ["user"]}},
-    )
-
-    await litellm.proxy.proxy_server.prisma_client.connect()
-
-    request = GenerateKeyRequest(
-        metadata={"service_account_id": f"prod-service-{uuid.uuid4()}"},
-    )
-    response = await generate_key_fn(
-        data=request,
-    )
-
-    print("key generated=", response)
-    generated_key = response.key
-    bearer_token = "Bearer " + generated_key
-    # make a bad /chat/completions call expect it to fail
-
-    request = Request(scope={"type": "http"})
-    request._url = URL(url="/chat/completions")
-
-    async def return_body():
-        return b'{"model": "gemini-pro-vision"}'
-
-    request.body = return_body
-
-    # use generated key to auth in
-    print("Bearer token being sent to user_api_key_auth() - {}".format(bearer_token))
-    try:
-        result = await user_api_key_auth(request=request, api_key=bearer_token)
-        pytest.fail("Expected this call to fail. Bad request using service account")
-    except Exception as e:
-        print("error str=", str(e.message))
-        assert "This is a required param for service account" in str(e.message)
-
-    # make a good /chat/completions call it should pass
-    async def good_return_body():
-        return b'{"model": "gemini-pro-vision", "user": "foo"}'
-
-    request.body = good_return_body
-
-    result = await user_api_key_auth(request=request, api_key=bearer_token)
-    print("response from user_api_key_auth", result)
-
-    setattr(litellm.proxy.proxy_server, "general_settings", {})
-
-
 @pytest.mark.asyncio
 async def test_user_api_key_auth_db_unavailable():
     """
@@ -679,3 +679,132 @@ async def test_add_litellm_data_to_request_duplicate_tags(
     assert sorted(result["metadata"]["tags"]) == sorted(
         expected_tags
     ), f"Expected {expected_tags}, got {result['metadata']['tags']}"
+
+
+@pytest.mark.parametrize(
+    "general_settings, user_api_key_dict, expected_enforced_params",
+    [
+        (
+            {"enforced_params": ["param1", "param2"]},
+            UserAPIKeyAuth(
+                api_key="test_api_key", user_id="test_user_id", org_id="test_org_id"
+            ),
+            ["param1", "param2"],
+        ),
+        (
+            {"service_account_settings": {"enforced_params": ["param1", "param2"]}},
+            UserAPIKeyAuth(
+                api_key="test_api_key", user_id="test_user_id", org_id="test_org_id"
+            ),
+            ["param1", "param2"],
+        ),
+        (
+            {"service_account_settings": {"enforced_params": ["param1", "param2"]}},
+            UserAPIKeyAuth(
+                api_key="test_api_key",
+                metadata={"enforced_params": ["param3", "param4"]},
+            ),
+            ["param1", "param2", "param3", "param4"],
+        ),
+    ],
+)
+def test_get_enforced_params(
+    general_settings, user_api_key_dict, expected_enforced_params
+):
+    from litellm.proxy.litellm_pre_call_utils import _get_enforced_params
+
+    enforced_params = _get_enforced_params(general_settings, user_api_key_dict)
+    assert enforced_params == expected_enforced_params
+
+
+@pytest.mark.parametrize(
+    "general_settings, user_api_key_dict, request_body, expected_error",
+    [
+        (
+            {"enforced_params": ["param1", "param2"]},
+            UserAPIKeyAuth(
+                api_key="test_api_key", user_id="test_user_id", org_id="test_org_id"
+            ),
+            {},
+            True,
+        ),
+        (
+            {"service_account_settings": {"enforced_params": ["user"]}},
+            UserAPIKeyAuth(
+                api_key="test_api_key", user_id="test_user_id", org_id="test_org_id"
+            ),
+            {},
+            True,
+        ),
+        (
+            {},
+            UserAPIKeyAuth(
+                api_key="test_api_key",
+                metadata={"enforced_params": ["user"]},
+            ),
+            {},
+            True,
+        ),
+        (
+            {},
+            UserAPIKeyAuth(
+                api_key="test_api_key",
+                metadata={"enforced_params": ["user"]},
+            ),
+            {"user": "test_user"},
+            False,
+        ),
+        (
+            {"enforced_params": ["user"]},
+            UserAPIKeyAuth(
+                api_key="test_api_key",
+            ),
+            {"user": "test_user"},
+            False,
+        ),
+        (
+            {"service_account_settings": {"enforced_params": ["user"]}},
+            UserAPIKeyAuth(
+                api_key="test_api_key",
+            ),
+            {"user": "test_user"},
+            False,
+        ),
+        (
+            {"enforced_params": ["metadata.generation_name"]},
+            UserAPIKeyAuth(
+                api_key="test_api_key",
+            ),
+            {"metadata": {}},
+            True,
+        ),
+        (
+            {"enforced_params": ["metadata.generation_name"]},
+            UserAPIKeyAuth(
+                api_key="test_api_key",
+            ),
+            {"metadata": {"generation_name": "test_generation_name"}},
+            False,
+        ),
+    ],
+)
+def test_enforced_params_check(
+    general_settings, user_api_key_dict, request_body, expected_error
+):
+    from litellm.proxy.litellm_pre_call_utils import _enforced_params_check
+
+    if expected_error:
+        with pytest.raises(ValueError):
+            _enforced_params_check(
+                request_body=request_body,
+                general_settings=general_settings,
+                user_api_key_dict=user_api_key_dict,
+                premium_user=True,
+            )
+    else:
+        _enforced_params_check(
+            request_body=request_body,
+            general_settings=general_settings,
+            user_api_key_dict=user_api_key_dict,
+            premium_user=True,
+        )