forked from phoenix/litellm-mirror
Compare commits
9 commits
litellm_mo
...
main
Author | SHA1 | Date | |
---|---|---|---|
|
d2b123eef7 | ||
|
859b47f08b | ||
|
bd59f18809 | ||
|
05f810922c | ||
|
0ac2d8b256 | ||
|
9393434d01 | ||
|
d6181b2c9f | ||
|
4ebb7c8a7f | ||
|
eba700a491 |
61 changed files with 2209 additions and 830 deletions
|
@ -1408,7 +1408,7 @@ jobs:
|
|||
command: |
|
||||
docker run -d \
|
||||
-p 4000:4000 \
|
||||
-e DATABASE_URL=$PROXY_DATABASE_URL \
|
||||
-e DATABASE_URL=$PROXY_DATABASE_URL_2 \
|
||||
-e LITELLM_MASTER_KEY="sk-1234" \
|
||||
-e OPENAI_API_KEY=$OPENAI_API_KEY \
|
||||
-e UI_USERNAME="admin" \
|
||||
|
|
135
docs/my-website/docs/moderation.md
Normal file
135
docs/my-website/docs/moderation.md
Normal file
|
@ -0,0 +1,135 @@
|
|||
import Tabs from '@theme/Tabs';
|
||||
import TabItem from '@theme/TabItem';
|
||||
|
||||
# Moderation
|
||||
|
||||
|
||||
### Usage
|
||||
<Tabs>
|
||||
<TabItem value="python" label="LiteLLM Python SDK">
|
||||
|
||||
```python
|
||||
from litellm import moderation
|
||||
|
||||
response = moderation(
|
||||
input="hello from litellm",
|
||||
model="text-moderation-stable"
|
||||
)
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
<TabItem value="proxy" label="LiteLLM Proxy Server">
|
||||
|
||||
For `/moderations` endpoint, there is **no need to specify `model` in the request or on the litellm config.yaml**
|
||||
|
||||
Start litellm proxy server
|
||||
|
||||
```
|
||||
litellm
|
||||
```
|
||||
|
||||
|
||||
<Tabs>
|
||||
<TabItem value="python" label="OpenAI Python SDK">
|
||||
|
||||
```python
|
||||
from openai import OpenAI
|
||||
|
||||
# set base_url to your proxy server
|
||||
# set api_key to send to proxy server
|
||||
client = OpenAI(api_key="<proxy-api-key>", base_url="http://0.0.0.0:4000")
|
||||
|
||||
response = client.moderations.create(
|
||||
input="hello from litellm",
|
||||
model="text-moderation-stable" # optional, defaults to `omni-moderation-latest`
|
||||
)
|
||||
|
||||
print(response)
|
||||
```
|
||||
</TabItem>
|
||||
|
||||
<TabItem value="curl" label="Curl Request">
|
||||
|
||||
```shell
|
||||
curl --location 'http://0.0.0.0:4000/moderations' \
|
||||
--header 'Content-Type: application/json' \
|
||||
--header 'Authorization: Bearer sk-1234' \
|
||||
--data '{"input": "Sample text goes here", "model": "text-moderation-stable"}'
|
||||
```
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
|
||||
## Input Params
|
||||
LiteLLM accepts and translates the [OpenAI Moderation params](https://platform.openai.com/docs/api-reference/moderations) across all supported providers.
|
||||
|
||||
### Required Fields
|
||||
|
||||
- `input`: *string or array* - Input (or inputs) to classify. Can be a single string, an array of strings, or an array of multi-modal input objects similar to other models.
|
||||
- If string: A string of text to classify for moderation
|
||||
- If array of strings: An array of strings to classify for moderation
|
||||
- If array of objects: An array of multi-modal inputs to the moderation model, where each object can be:
|
||||
- An object describing an image to classify with:
|
||||
- `type`: *string, required* - Always `image_url`
|
||||
- `image_url`: *object, required* - Contains either an image URL or a data URL for a base64 encoded image
|
||||
- An object describing text to classify with:
|
||||
- `type`: *string, required* - Always `text`
|
||||
- `text`: *string, required* - A string of text to classify
|
||||
|
||||
### Optional Fields
|
||||
|
||||
- `model`: *string (optional)* - The moderation model to use. Defaults to `omni-moderation-latest`.
|
||||
|
||||
## Output Format
|
||||
Here's the exact json output and type you can expect from all moderation calls:
|
||||
|
||||
[**LiteLLM follows OpenAI's output format**](https://platform.openai.com/docs/api-reference/moderations/object)
|
||||
|
||||
|
||||
```python
|
||||
{
|
||||
"id": "modr-AB8CjOTu2jiq12hp1AQPfeqFWaORR",
|
||||
"model": "text-moderation-007",
|
||||
"results": [
|
||||
{
|
||||
"flagged": true,
|
||||
"categories": {
|
||||
"sexual": false,
|
||||
"hate": false,
|
||||
"harassment": true,
|
||||
"self-harm": false,
|
||||
"sexual/minors": false,
|
||||
"hate/threatening": false,
|
||||
"violence/graphic": false,
|
||||
"self-harm/intent": false,
|
||||
"self-harm/instructions": false,
|
||||
"harassment/threatening": true,
|
||||
"violence": true
|
||||
},
|
||||
"category_scores": {
|
||||
"sexual": 0.000011726012417057063,
|
||||
"hate": 0.22706663608551025,
|
||||
"harassment": 0.5215635299682617,
|
||||
"self-harm": 2.227119921371923e-6,
|
||||
"sexual/minors": 7.107352217872176e-8,
|
||||
"hate/threatening": 0.023547329008579254,
|
||||
"violence/graphic": 0.00003391829886822961,
|
||||
"self-harm/intent": 1.646940972932498e-6,
|
||||
"self-harm/instructions": 1.1198755256458526e-9,
|
||||
"harassment/threatening": 0.5694745779037476,
|
||||
"violence": 0.9971134662628174
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
|
||||
## **Supported Providers**
|
||||
|
||||
| Provider |
|
||||
|-------------|
|
||||
| OpenAI |
|
59
docs/my-website/docs/proxy/config_management.md
Normal file
59
docs/my-website/docs/proxy/config_management.md
Normal file
|
@ -0,0 +1,59 @@
|
|||
# File Management
|
||||
|
||||
## `include` external YAML files in a config.yaml
|
||||
|
||||
You can use `include` to include external YAML files in a config.yaml.
|
||||
|
||||
**Quick Start Usage:**
|
||||
|
||||
To include a config file, use `include` with either a single file or a list of files.
|
||||
|
||||
Contents of `parent_config.yaml`:
|
||||
```yaml
|
||||
include:
|
||||
- model_config.yaml # 👈 Key change, will include the contents of model_config.yaml
|
||||
|
||||
litellm_settings:
|
||||
callbacks: ["prometheus"]
|
||||
```
|
||||
|
||||
|
||||
Contents of `model_config.yaml`:
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: gpt-4o
|
||||
litellm_params:
|
||||
model: openai/gpt-4o
|
||||
api_base: https://exampleopenaiendpoint-production.up.railway.app/
|
||||
- model_name: fake-anthropic-endpoint
|
||||
litellm_params:
|
||||
model: anthropic/fake
|
||||
api_base: https://exampleanthropicendpoint-production.up.railway.app/
|
||||
|
||||
```
|
||||
|
||||
Start proxy server
|
||||
|
||||
This will start the proxy server with config `parent_config.yaml`. Since the `include` directive is used, the server will also include the contents of `model_config.yaml`.
|
||||
```
|
||||
litellm --config parent_config.yaml --detailed_debug
|
||||
```
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
## Examples using `include`
|
||||
|
||||
Include a single file:
|
||||
```yaml
|
||||
include:
|
||||
- model_config.yaml
|
||||
```
|
||||
|
||||
Include multiple files:
|
||||
```yaml
|
||||
include:
|
||||
- model_config.yaml
|
||||
- another_config.yaml
|
||||
```
|
|
@ -2,7 +2,7 @@ import Image from '@theme/IdealImage';
|
|||
import Tabs from '@theme/Tabs';
|
||||
import TabItem from '@theme/TabItem';
|
||||
|
||||
# Proxy Config.yaml
|
||||
# Overview
|
||||
Set model list, `api_base`, `api_key`, `temperature` & proxy server settings (`master-key`) on the config.yaml.
|
||||
|
||||
| Param Name | Description |
|
||||
|
|
|
@ -50,18 +50,22 @@ You can see the full DB Schema [here](https://github.com/BerriAI/litellm/blob/ma
|
|||
| LiteLLM_ErrorLogs | Captures failed requests and errors. Stores exception details and request information. Helps with debugging and monitoring. | **Medium - on errors only** |
|
||||
| LiteLLM_AuditLog | Tracks changes to system configuration. Records who made changes and what was modified. Maintains history of updates to teams, users, and models. | **Off by default**, **High - when enabled** |
|
||||
|
||||
## How to Disable `LiteLLM_SpendLogs`
|
||||
## Disable `LiteLLM_SpendLogs` & `LiteLLM_ErrorLogs`
|
||||
|
||||
You can disable spend_logs by setting `disable_spend_logs` to `True` on the `general_settings` section of your proxy_config.yaml file.
|
||||
You can disable spend_logs and error_logs by setting `disable_spend_logs` and `disable_error_logs` to `True` on the `general_settings` section of your proxy_config.yaml file.
|
||||
|
||||
```yaml
|
||||
general_settings:
|
||||
disable_spend_logs: True
|
||||
disable_spend_logs: True # Disable writing spend logs to DB
|
||||
disable_error_logs: True # Disable writing error logs to DB
|
||||
```
|
||||
|
||||
### What is the impact of disabling these logs?
|
||||
|
||||
### What is the impact of disabling `LiteLLM_SpendLogs`?
|
||||
|
||||
When disabling spend logs (`disable_spend_logs: True`):
|
||||
- You **will not** be able to view Usage on the LiteLLM UI
|
||||
- You **will** continue seeing cost metrics on s3, Prometheus, Langfuse (any other Logging integration you are using)
|
||||
|
||||
When disabling error logs (`disable_error_logs: True`):
|
||||
- You **will not** be able to view Errors on the LiteLLM UI
|
||||
- You **will** continue seeing error logs in your application logs and any other logging integrations you are using
|
||||
|
|
|
@ -23,6 +23,7 @@ general_settings:
|
|||
|
||||
# OPTIONAL Best Practices
|
||||
disable_spend_logs: True # turn off writing each transaction to the db. We recommend doing this is you don't need to see Usage on the LiteLLM UI and are tracking metrics via Prometheus
|
||||
disable_error_logs: True # turn off writing LLM Exceptions to DB
|
||||
allow_requests_on_db_unavailable: True # Only USE when running LiteLLM on your VPC. Allow requests to still be processed even if the DB is unavailable. We recommend doing this if you're running LiteLLM on VPC that cannot be accessed from the public internet.
|
||||
|
||||
litellm_settings:
|
||||
|
@ -102,17 +103,22 @@ general_settings:
|
|||
allow_requests_on_db_unavailable: True
|
||||
```
|
||||
|
||||
## 6. Disable spend_logs if you're not using the LiteLLM UI
|
||||
## 6. Disable spend_logs & error_logs if not using the LiteLLM UI
|
||||
|
||||
By default LiteLLM will write every request to the `LiteLLM_SpendLogs` table. This is used for viewing Usage on the LiteLLM UI.
|
||||
By default, LiteLLM writes several types of logs to the database:
|
||||
- Every LLM API request to the `LiteLLM_SpendLogs` table
|
||||
- LLM Exceptions to the `LiteLLM_LogsErrors` table
|
||||
|
||||
If you're not viewing Usage on the LiteLLM UI (most users use Prometheus when this is disabled), you can disable spend_logs by setting `disable_spend_logs` to `True`.
|
||||
If you're not viewing these logs on the LiteLLM UI (most users use Prometheus for monitoring), you can disable them by setting the following flags to `True`:
|
||||
|
||||
```yaml
|
||||
general_settings:
|
||||
disable_spend_logs: True
|
||||
disable_spend_logs: True # Disable writing spend logs to DB
|
||||
disable_error_logs: True # Disable writing error logs to DB
|
||||
```
|
||||
|
||||
[More information about what the Database is used for here](db_info)
|
||||
|
||||
## 7. Use Helm PreSync Hook for Database Migrations [BETA]
|
||||
|
||||
To ensure only one service manages database migrations, use our [Helm PreSync hook for Database Migrations](https://github.com/BerriAI/litellm/blob/main/deploy/charts/litellm-helm/templates/migrations-job.yaml). This ensures migrations are handled during `helm upgrade` or `helm install`, while LiteLLM pods explicitly disable migrations.
|
||||
|
|
|
@ -192,3 +192,13 @@ Here is a screenshot of the metrics you can monitor with the LiteLLM Grafana Das
|
|||
|----------------------|--------------------------------------|
|
||||
| `litellm_llm_api_failed_requests_metric` | **deprecated** use `litellm_proxy_failed_requests_metric` |
|
||||
| `litellm_requests_metric` | **deprecated** use `litellm_proxy_total_requests_metric` |
|
||||
|
||||
|
||||
## FAQ
|
||||
|
||||
### What are `_created` vs. `_total` metrics?
|
||||
|
||||
- `_created` metrics are metrics that are created when the proxy starts
|
||||
- `_total` metrics are metrics that are incremented for each request
|
||||
|
||||
You should consume the `_total` metrics for your counting purposes
|
174
docs/my-website/docs/text_completion.md
Normal file
174
docs/my-website/docs/text_completion.md
Normal file
|
@ -0,0 +1,174 @@
|
|||
import Tabs from '@theme/Tabs';
|
||||
import TabItem from '@theme/TabItem';
|
||||
|
||||
# Text Completion
|
||||
|
||||
### Usage
|
||||
<Tabs>
|
||||
<TabItem value="python" label="LiteLLM Python SDK">
|
||||
|
||||
```python
|
||||
from litellm import text_completion
|
||||
|
||||
response = text_completion(
|
||||
model="gpt-3.5-turbo-instruct",
|
||||
prompt="Say this is a test",
|
||||
max_tokens=7
|
||||
)
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
<TabItem value="proxy" label="LiteLLM Proxy Server">
|
||||
|
||||
1. Define models on config.yaml
|
||||
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: gpt-3.5-turbo-instruct
|
||||
litellm_params:
|
||||
model: text-completion-openai/gpt-3.5-turbo-instruct # The `text-completion-openai/` prefix will call openai.completions.create
|
||||
api_key: os.environ/OPENAI_API_KEY
|
||||
- model_name: text-davinci-003
|
||||
litellm_params:
|
||||
model: text-completion-openai/text-davinci-003
|
||||
api_key: os.environ/OPENAI_API_KEY
|
||||
```
|
||||
|
||||
2. Start litellm proxy server
|
||||
|
||||
```
|
||||
litellm --config config.yaml
|
||||
```
|
||||
|
||||
<Tabs>
|
||||
<TabItem value="python" label="OpenAI Python SDK">
|
||||
|
||||
```python
|
||||
from openai import OpenAI
|
||||
|
||||
# set base_url to your proxy server
|
||||
# set api_key to send to proxy server
|
||||
client = OpenAI(api_key="<proxy-api-key>", base_url="http://0.0.0.0:4000")
|
||||
|
||||
response = client.completions.create(
|
||||
model="gpt-3.5-turbo-instruct",
|
||||
prompt="Say this is a test",
|
||||
max_tokens=7
|
||||
)
|
||||
|
||||
print(response)
|
||||
```
|
||||
</TabItem>
|
||||
|
||||
<TabItem value="curl" label="Curl Request">
|
||||
|
||||
```shell
|
||||
curl --location 'http://0.0.0.0:4000/completions' \
|
||||
--header 'Content-Type: application/json' \
|
||||
--header 'Authorization: Bearer sk-1234' \
|
||||
--data '{
|
||||
"model": "gpt-3.5-turbo-instruct",
|
||||
"prompt": "Say this is a test",
|
||||
"max_tokens": 7
|
||||
}'
|
||||
```
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
|
||||
## Input Params
|
||||
|
||||
LiteLLM accepts and translates the [OpenAI Text Completion params](https://platform.openai.com/docs/api-reference/completions) across all supported providers.
|
||||
|
||||
### Required Fields
|
||||
|
||||
- `model`: *string* - ID of the model to use
|
||||
- `prompt`: *string or array* - The prompt(s) to generate completions for
|
||||
|
||||
### Optional Fields
|
||||
|
||||
- `best_of`: *integer* - Generates best_of completions server-side and returns the "best" one
|
||||
- `echo`: *boolean* - Echo back the prompt in addition to the completion.
|
||||
- `frequency_penalty`: *number* - Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency.
|
||||
- `logit_bias`: *map* - Modify the likelihood of specified tokens appearing in the completion
|
||||
- `logprobs`: *integer* - Include the log probabilities on the logprobs most likely tokens. Max value of 5
|
||||
- `max_tokens`: *integer* - The maximum number of tokens to generate.
|
||||
- `n`: *integer* - How many completions to generate for each prompt.
|
||||
- `presence_penalty`: *number* - Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far.
|
||||
- `seed`: *integer* - If specified, system will attempt to make deterministic samples
|
||||
- `stop`: *string or array* - Up to 4 sequences where the API will stop generating tokens
|
||||
- `stream`: *boolean* - Whether to stream back partial progress. Defaults to false
|
||||
- `suffix`: *string* - The suffix that comes after a completion of inserted text
|
||||
- `temperature`: *number* - What sampling temperature to use, between 0 and 2.
|
||||
- `top_p`: *number* - An alternative to sampling with temperature, called nucleus sampling.
|
||||
- `user`: *string* - A unique identifier representing your end-user
|
||||
|
||||
## Output Format
|
||||
Here's the exact JSON output format you can expect from completion calls:
|
||||
|
||||
|
||||
[**Follows OpenAI's output format**](https://platform.openai.com/docs/api-reference/completions/object)
|
||||
|
||||
<Tabs>
|
||||
|
||||
<TabItem value="non-streaming" label="Non-Streaming Response">
|
||||
|
||||
```python
|
||||
{
|
||||
"id": "cmpl-uqkvlQyYK7bGYrRHQ0eXlWi7",
|
||||
"object": "text_completion",
|
||||
"created": 1589478378,
|
||||
"model": "gpt-3.5-turbo-instruct",
|
||||
"system_fingerprint": "fp_44709d6fcb",
|
||||
"choices": [
|
||||
{
|
||||
"text": "\n\nThis is indeed a test",
|
||||
"index": 0,
|
||||
"logprobs": null,
|
||||
"finish_reason": "length"
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"prompt_tokens": 5,
|
||||
"completion_tokens": 7,
|
||||
"total_tokens": 12
|
||||
}
|
||||
}
|
||||
|
||||
```
|
||||
</TabItem>
|
||||
<TabItem value="streaming" label="Streaming Response">
|
||||
|
||||
```python
|
||||
{
|
||||
"id": "cmpl-7iA7iJjj8V2zOkCGvWF2hAkDWBQZe",
|
||||
"object": "text_completion",
|
||||
"created": 1690759702,
|
||||
"choices": [
|
||||
{
|
||||
"text": "This",
|
||||
"index": 0,
|
||||
"logprobs": null,
|
||||
"finish_reason": null
|
||||
}
|
||||
],
|
||||
"model": "gpt-3.5-turbo-instruct"
|
||||
"system_fingerprint": "fp_44709d6fcb",
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
|
||||
|
||||
## **Supported Providers**
|
||||
|
||||
| Provider | Link to Usage |
|
||||
|-------------|--------------------|
|
||||
| OpenAI | [Usage](../docs/providers/text_completion_openai) |
|
||||
| Azure OpenAI| [Usage](../docs/providers/azure) |
|
||||
|
||||
|
|
@ -32,7 +32,7 @@ const sidebars = {
|
|||
{
|
||||
"type": "category",
|
||||
"label": "Config.yaml",
|
||||
"items": ["proxy/configs", "proxy/config_settings"]
|
||||
"items": ["proxy/configs", "proxy/config_management", "proxy/config_settings"]
|
||||
},
|
||||
{
|
||||
type: "category",
|
||||
|
@ -246,6 +246,7 @@ const sidebars = {
|
|||
"completion/usage",
|
||||
],
|
||||
},
|
||||
"text_completion",
|
||||
"embedding/supported_embedding",
|
||||
"image_generation",
|
||||
{
|
||||
|
@ -261,6 +262,7 @@ const sidebars = {
|
|||
"batches",
|
||||
"realtime",
|
||||
"fine_tuning",
|
||||
"moderation",
|
||||
{
|
||||
type: "link",
|
||||
label: "Use LiteLLM Proxy with Vertex, Bedrock SDK",
|
||||
|
|
|
@ -2,7 +2,9 @@
|
|||
from typing import Optional, List
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.proxy.proxy_server import PrismaClient, HTTPException
|
||||
from litellm.llms.custom_httpx.http_handler import HTTPHandler
|
||||
import collections
|
||||
import httpx
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
|
@ -114,7 +116,6 @@ async def ui_get_spend_by_tags(
|
|||
|
||||
|
||||
def _forecast_daily_cost(data: list):
|
||||
import requests # type: ignore
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
if len(data) == 0:
|
||||
|
@ -136,17 +137,17 @@ def _forecast_daily_cost(data: list):
|
|||
|
||||
print("last entry date", last_entry_date)
|
||||
|
||||
# Assuming today_date is a datetime object
|
||||
today_date = datetime.now()
|
||||
|
||||
# Calculate the last day of the month
|
||||
last_day_of_todays_month = datetime(
|
||||
today_date.year, today_date.month % 12 + 1, 1
|
||||
) - timedelta(days=1)
|
||||
|
||||
print("last day of todays month", last_day_of_todays_month)
|
||||
# Calculate the remaining days in the month
|
||||
remaining_days = (last_day_of_todays_month - last_entry_date).days
|
||||
|
||||
print("remaining days", remaining_days)
|
||||
|
||||
current_spend_this_month = 0
|
||||
series = {}
|
||||
for entry in data:
|
||||
|
@ -176,13 +177,19 @@ def _forecast_daily_cost(data: list):
|
|||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
response = requests.post(
|
||||
url="https://trend-api-production.up.railway.app/forecast",
|
||||
json=payload,
|
||||
headers=headers,
|
||||
)
|
||||
# check the status code
|
||||
response.raise_for_status()
|
||||
client = HTTPHandler()
|
||||
|
||||
try:
|
||||
response = client.post(
|
||||
url="https://trend-api-production.up.railway.app/forecast",
|
||||
json=payload,
|
||||
headers=headers,
|
||||
)
|
||||
except httpx.HTTPStatusError as e:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail={"error": f"Error getting forecast: {e.response.text}"},
|
||||
)
|
||||
|
||||
json_response = response.json()
|
||||
forecast_data = json_response["forecast"]
|
||||
|
@ -206,13 +213,3 @@ def _forecast_daily_cost(data: list):
|
|||
f"Predicted Spend for { today_month } 2024, ${total_predicted_spend}"
|
||||
)
|
||||
return {"response": response_data, "predicted_spend": predicted_spend}
|
||||
|
||||
# print(f"Date: {entry['date']}, Spend: {entry['spend']}, Response: {response.text}")
|
||||
|
||||
|
||||
# _forecast_daily_cost(
|
||||
# [
|
||||
# {"date": "2022-01-01", "spend": 100},
|
||||
|
||||
# ]
|
||||
# )
|
||||
|
|
|
@ -28,6 +28,62 @@ headers = {
|
|||
_DEFAULT_TIMEOUT = httpx.Timeout(timeout=5.0, connect=5.0)
|
||||
_DEFAULT_TTL_FOR_HTTPX_CLIENTS = 3600 # 1 hour, re-use the same httpx client for 1 hour
|
||||
|
||||
import re
|
||||
|
||||
|
||||
def mask_sensitive_info(error_message):
|
||||
# Find the start of the key parameter
|
||||
if isinstance(error_message, str):
|
||||
key_index = error_message.find("key=")
|
||||
else:
|
||||
return error_message
|
||||
|
||||
# If key is found
|
||||
if key_index != -1:
|
||||
# Find the end of the key parameter (next & or end of string)
|
||||
next_param = error_message.find("&", key_index)
|
||||
|
||||
if next_param == -1:
|
||||
# If no more parameters, mask until the end of the string
|
||||
masked_message = error_message[: key_index + 4] + "[REDACTED_API_KEY]"
|
||||
else:
|
||||
# Replace the key with redacted value, keeping other parameters
|
||||
masked_message = (
|
||||
error_message[: key_index + 4]
|
||||
+ "[REDACTED_API_KEY]"
|
||||
+ error_message[next_param:]
|
||||
)
|
||||
|
||||
return masked_message
|
||||
|
||||
return error_message
|
||||
|
||||
|
||||
class MaskedHTTPStatusError(httpx.HTTPStatusError):
|
||||
def __init__(
|
||||
self, original_error, message: Optional[str] = None, text: Optional[str] = None
|
||||
):
|
||||
# Create a new error with the masked URL
|
||||
masked_url = mask_sensitive_info(str(original_error.request.url))
|
||||
# Create a new error that looks like the original, but with a masked URL
|
||||
|
||||
super().__init__(
|
||||
message=original_error.message,
|
||||
request=httpx.Request(
|
||||
method=original_error.request.method,
|
||||
url=masked_url,
|
||||
headers=original_error.request.headers,
|
||||
content=original_error.request.content,
|
||||
),
|
||||
response=httpx.Response(
|
||||
status_code=original_error.response.status_code,
|
||||
content=original_error.response.content,
|
||||
headers=original_error.response.headers,
|
||||
),
|
||||
)
|
||||
self.message = message
|
||||
self.text = text
|
||||
|
||||
|
||||
class AsyncHTTPHandler:
|
||||
def __init__(
|
||||
|
@ -155,13 +211,16 @@ class AsyncHTTPHandler:
|
|||
headers=headers,
|
||||
)
|
||||
except httpx.HTTPStatusError as e:
|
||||
setattr(e, "status_code", e.response.status_code)
|
||||
|
||||
if stream is True:
|
||||
setattr(e, "message", await e.response.aread())
|
||||
setattr(e, "text", await e.response.aread())
|
||||
else:
|
||||
setattr(e, "message", e.response.text)
|
||||
setattr(e, "text", e.response.text)
|
||||
setattr(e, "message", mask_sensitive_info(e.response.text))
|
||||
setattr(e, "text", mask_sensitive_info(e.response.text))
|
||||
|
||||
setattr(e, "status_code", e.response.status_code)
|
||||
|
||||
raise e
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
@ -399,11 +458,17 @@ class HTTPHandler:
|
|||
llm_provider="litellm-httpx-handler",
|
||||
)
|
||||
except httpx.HTTPStatusError as e:
|
||||
setattr(e, "status_code", e.response.status_code)
|
||||
|
||||
if stream is True:
|
||||
setattr(e, "message", e.response.read())
|
||||
setattr(e, "message", mask_sensitive_info(e.response.read()))
|
||||
setattr(e, "text", mask_sensitive_info(e.response.read()))
|
||||
else:
|
||||
setattr(e, "message", e.response.text)
|
||||
error_text = mask_sensitive_info(e.response.text)
|
||||
setattr(e, "message", error_text)
|
||||
setattr(e, "text", error_text)
|
||||
|
||||
setattr(e, "status_code", e.response.status_code)
|
||||
|
||||
raise e
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
|
|
@ -1159,15 +1159,44 @@ def convert_to_anthropic_tool_result(
|
|||
]
|
||||
}
|
||||
"""
|
||||
content_str: str = ""
|
||||
anthropic_content: Union[
|
||||
str,
|
||||
List[Union[AnthropicMessagesToolResultContent, AnthropicMessagesImageParam]],
|
||||
] = ""
|
||||
if isinstance(message["content"], str):
|
||||
content_str = message["content"]
|
||||
anthropic_content = message["content"]
|
||||
elif isinstance(message["content"], List):
|
||||
content_list = message["content"]
|
||||
anthropic_content_list: List[
|
||||
Union[AnthropicMessagesToolResultContent, AnthropicMessagesImageParam]
|
||||
] = []
|
||||
for content in content_list:
|
||||
if content["type"] == "text":
|
||||
content_str += content["text"]
|
||||
anthropic_content_list.append(
|
||||
AnthropicMessagesToolResultContent(
|
||||
type="text",
|
||||
text=content["text"],
|
||||
)
|
||||
)
|
||||
elif content["type"] == "image_url":
|
||||
if isinstance(content["image_url"], str):
|
||||
image_chunk = convert_to_anthropic_image_obj(content["image_url"])
|
||||
else:
|
||||
image_chunk = convert_to_anthropic_image_obj(
|
||||
content["image_url"]["url"]
|
||||
)
|
||||
anthropic_content_list.append(
|
||||
AnthropicMessagesImageParam(
|
||||
type="image",
|
||||
source=AnthropicContentParamSource(
|
||||
type="base64",
|
||||
media_type=image_chunk["media_type"],
|
||||
data=image_chunk["data"],
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
anthropic_content = anthropic_content_list
|
||||
anthropic_tool_result: Optional[AnthropicMessagesToolResultParam] = None
|
||||
## PROMPT CACHING CHECK ##
|
||||
cache_control = message.get("cache_control", None)
|
||||
|
@ -1178,14 +1207,14 @@ def convert_to_anthropic_tool_result(
|
|||
# We can't determine from openai message format whether it's a successful or
|
||||
# error call result so default to the successful result template
|
||||
anthropic_tool_result = AnthropicMessagesToolResultParam(
|
||||
type="tool_result", tool_use_id=tool_call_id, content=content_str
|
||||
type="tool_result", tool_use_id=tool_call_id, content=anthropic_content
|
||||
)
|
||||
|
||||
if message["role"] == "function":
|
||||
function_message: ChatCompletionFunctionMessage = message
|
||||
tool_call_id = function_message.get("tool_call_id") or str(uuid.uuid4())
|
||||
anthropic_tool_result = AnthropicMessagesToolResultParam(
|
||||
type="tool_result", tool_use_id=tool_call_id, content=content_str
|
||||
type="tool_result", tool_use_id=tool_call_id, content=anthropic_content
|
||||
)
|
||||
|
||||
if anthropic_tool_result is None:
|
||||
|
|
|
@ -107,6 +107,10 @@ def _get_image_mime_type_from_url(url: str) -> Optional[str]:
|
|||
return "image/png"
|
||||
elif url.endswith(".webp"):
|
||||
return "image/webp"
|
||||
elif url.endswith(".mp4"):
|
||||
return "video/mp4"
|
||||
elif url.endswith(".pdf"):
|
||||
return "application/pdf"
|
||||
return None
|
||||
|
||||
|
||||
|
|
|
@ -15,6 +15,22 @@ model_list:
|
|||
litellm_params:
|
||||
model: openai/gpt-4o-realtime-preview-2024-10-01
|
||||
api_key: os.environ/OPENAI_API_KEY
|
||||
- model_name: openai/*
|
||||
litellm_params:
|
||||
model: openai/*
|
||||
api_key: os.environ/OPENAI_API_KEY
|
||||
- model_name: openai/*
|
||||
litellm_params:
|
||||
model: openai/*
|
||||
api_key: os.environ/OPENAI_API_KEY
|
||||
model_info:
|
||||
access_groups: ["public-openai-models"]
|
||||
- model_name: openai/gpt-4o
|
||||
litellm_params:
|
||||
model: openai/gpt-4o
|
||||
api_key: os.environ/OPENAI_API_KEY
|
||||
model_info:
|
||||
access_groups: ["private-openai-models"]
|
||||
|
||||
router_settings:
|
||||
routing_strategy: usage-based-routing-v2
|
||||
|
|
|
@ -2183,3 +2183,11 @@ PassThroughEndpointLoggingResultValues = Union[
|
|||
class PassThroughEndpointLoggingTypedDict(TypedDict):
|
||||
result: Optional[PassThroughEndpointLoggingResultValues]
|
||||
kwargs: dict
|
||||
|
||||
|
||||
LiteLLM_ManagementEndpoint_MetadataFields = [
|
||||
"model_rpm_limit",
|
||||
"model_tpm_limit",
|
||||
"guardrails",
|
||||
"tags",
|
||||
]
|
||||
|
|
|
@ -60,6 +60,7 @@ def common_checks( # noqa: PLR0915
|
|||
global_proxy_spend: Optional[float],
|
||||
general_settings: dict,
|
||||
route: str,
|
||||
llm_router: Optional[litellm.Router],
|
||||
) -> bool:
|
||||
"""
|
||||
Common checks across jwt + key-based auth.
|
||||
|
@ -97,7 +98,12 @@ def common_checks( # noqa: PLR0915
|
|||
# this means the team has access to all models on the proxy
|
||||
pass
|
||||
# check if the team model is an access_group
|
||||
elif model_in_access_group(_model, team_object.models) is True:
|
||||
elif (
|
||||
model_in_access_group(
|
||||
model=_model, team_models=team_object.models, llm_router=llm_router
|
||||
)
|
||||
is True
|
||||
):
|
||||
pass
|
||||
elif _model and "*" in _model:
|
||||
pass
|
||||
|
@ -373,36 +379,33 @@ async def get_end_user_object(
|
|||
return None
|
||||
|
||||
|
||||
def model_in_access_group(model: str, team_models: Optional[List[str]]) -> bool:
|
||||
def model_in_access_group(
|
||||
model: str, team_models: Optional[List[str]], llm_router: Optional[litellm.Router]
|
||||
) -> bool:
|
||||
from collections import defaultdict
|
||||
|
||||
from litellm.proxy.proxy_server import llm_router
|
||||
|
||||
if team_models is None:
|
||||
return True
|
||||
if model in team_models:
|
||||
return True
|
||||
|
||||
access_groups = defaultdict(list)
|
||||
access_groups: dict[str, list[str]] = defaultdict(list)
|
||||
if llm_router:
|
||||
access_groups = llm_router.get_model_access_groups()
|
||||
access_groups = llm_router.get_model_access_groups(model_name=model)
|
||||
|
||||
models_in_current_access_groups = []
|
||||
if len(access_groups) > 0: # check if token contains any model access groups
|
||||
for idx, m in enumerate(
|
||||
team_models
|
||||
): # loop token models, if any of them are an access group add the access group
|
||||
if m in access_groups:
|
||||
# if it is an access group we need to remove it from valid_token.models
|
||||
models_in_group = access_groups[m]
|
||||
models_in_current_access_groups.extend(models_in_group)
|
||||
return True
|
||||
|
||||
# Filter out models that are access_groups
|
||||
filtered_models = [m for m in team_models if m not in access_groups]
|
||||
filtered_models += models_in_current_access_groups
|
||||
|
||||
if model in filtered_models:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
|
@ -523,10 +526,6 @@ async def _cache_management_object(
|
|||
proxy_logging_obj: Optional[ProxyLogging],
|
||||
):
|
||||
await user_api_key_cache.async_set_cache(key=key, value=value)
|
||||
if proxy_logging_obj is not None:
|
||||
await proxy_logging_obj.internal_usage_cache.dual_cache.async_set_cache(
|
||||
key=key, value=value
|
||||
)
|
||||
|
||||
|
||||
async def _cache_team_object(
|
||||
|
@ -878,7 +877,10 @@ async def get_org_object(
|
|||
|
||||
|
||||
async def can_key_call_model(
|
||||
model: str, llm_model_list: Optional[list], valid_token: UserAPIKeyAuth
|
||||
model: str,
|
||||
llm_model_list: Optional[list],
|
||||
valid_token: UserAPIKeyAuth,
|
||||
llm_router: Optional[litellm.Router],
|
||||
) -> Literal[True]:
|
||||
"""
|
||||
Checks if token can call a given model
|
||||
|
@ -898,35 +900,29 @@ async def can_key_call_model(
|
|||
)
|
||||
from collections import defaultdict
|
||||
|
||||
from litellm.proxy.proxy_server import llm_router
|
||||
|
||||
access_groups = defaultdict(list)
|
||||
if llm_router:
|
||||
access_groups = llm_router.get_model_access_groups()
|
||||
access_groups = llm_router.get_model_access_groups(model_name=model)
|
||||
|
||||
models_in_current_access_groups = []
|
||||
if len(access_groups) > 0: # check if token contains any model access groups
|
||||
if (
|
||||
len(access_groups) > 0 and llm_router is not None
|
||||
): # check if token contains any model access groups
|
||||
for idx, m in enumerate(
|
||||
valid_token.models
|
||||
): # loop token models, if any of them are an access group add the access group
|
||||
if m in access_groups:
|
||||
# if it is an access group we need to remove it from valid_token.models
|
||||
models_in_group = access_groups[m]
|
||||
models_in_current_access_groups.extend(models_in_group)
|
||||
return True
|
||||
|
||||
# Filter out models that are access_groups
|
||||
filtered_models = [m for m in valid_token.models if m not in access_groups]
|
||||
|
||||
filtered_models += models_in_current_access_groups
|
||||
verbose_proxy_logger.debug(f"model: {model}; allowed_models: {filtered_models}")
|
||||
|
||||
all_model_access: bool = False
|
||||
|
||||
if (
|
||||
len(filtered_models) == 0
|
||||
or "*" in filtered_models
|
||||
or "openai/*" in filtered_models
|
||||
):
|
||||
len(filtered_models) == 0 and len(valid_token.models) == 0
|
||||
) or "*" in filtered_models:
|
||||
all_model_access = True
|
||||
|
||||
if model is not None and model not in filtered_models and all_model_access is False:
|
||||
|
|
|
@ -259,6 +259,7 @@ async def user_api_key_auth( # noqa: PLR0915
|
|||
jwt_handler,
|
||||
litellm_proxy_admin_name,
|
||||
llm_model_list,
|
||||
llm_router,
|
||||
master_key,
|
||||
open_telemetry_logger,
|
||||
prisma_client,
|
||||
|
@ -542,6 +543,7 @@ async def user_api_key_auth( # noqa: PLR0915
|
|||
general_settings=general_settings,
|
||||
global_proxy_spend=global_proxy_spend,
|
||||
route=route,
|
||||
llm_router=llm_router,
|
||||
)
|
||||
|
||||
# return UserAPIKeyAuth object
|
||||
|
@ -905,6 +907,7 @@ async def user_api_key_auth( # noqa: PLR0915
|
|||
model=model,
|
||||
llm_model_list=llm_model_list,
|
||||
valid_token=valid_token,
|
||||
llm_router=llm_router,
|
||||
)
|
||||
|
||||
if fallback_models is not None:
|
||||
|
@ -913,6 +916,7 @@ async def user_api_key_auth( # noqa: PLR0915
|
|||
model=m,
|
||||
llm_model_list=llm_model_list,
|
||||
valid_token=valid_token,
|
||||
llm_router=llm_router,
|
||||
)
|
||||
|
||||
# Check 2. If user_id for this token is in budget - done in common_checks()
|
||||
|
@ -1173,6 +1177,7 @@ async def user_api_key_auth( # noqa: PLR0915
|
|||
general_settings=general_settings,
|
||||
global_proxy_spend=global_proxy_spend,
|
||||
route=route,
|
||||
llm_router=llm_router,
|
||||
)
|
||||
# Token passed all checks
|
||||
if valid_token is None:
|
||||
|
|
|
@ -214,10 +214,10 @@ class BedrockGuardrail(CustomGuardrail, BaseAWSLLM):
|
|||
prepared_request.url,
|
||||
prepared_request.headers,
|
||||
)
|
||||
_json_data = json.dumps(request_data) # type: ignore
|
||||
|
||||
response = await self.async_handler.post(
|
||||
url=prepared_request.url,
|
||||
json=request_data, # type: ignore
|
||||
data=prepared_request.body, # type: ignore
|
||||
headers=prepared_request.headers, # type: ignore
|
||||
)
|
||||
verbose_proxy_logger.debug("Bedrock AI response: %s", response.text)
|
||||
|
|
87
litellm/proxy/hooks/proxy_failure_handler.py
Normal file
87
litellm/proxy/hooks/proxy_failure_handler.py
Normal file
|
@ -0,0 +1,87 @@
|
|||
"""
|
||||
Runs when LLM Exceptions occur on LiteLLM Proxy
|
||||
"""
|
||||
|
||||
import copy
|
||||
import json
|
||||
import uuid
|
||||
|
||||
import litellm
|
||||
from litellm.proxy._types import LiteLLM_ErrorLogs
|
||||
|
||||
|
||||
async def _PROXY_failure_handler(
|
||||
kwargs, # kwargs to completion
|
||||
completion_response: litellm.ModelResponse, # response from completion
|
||||
start_time=None,
|
||||
end_time=None, # start/end time for completion
|
||||
):
|
||||
"""
|
||||
Async Failure Handler - runs when LLM Exceptions occur on LiteLLM Proxy.
|
||||
This function logs the errors to the Prisma DB
|
||||
|
||||
Can be disabled by setting the following on proxy_config.yaml:
|
||||
```yaml
|
||||
general_settings:
|
||||
disable_error_logs: True
|
||||
```
|
||||
|
||||
"""
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy.proxy_server import general_settings, prisma_client
|
||||
|
||||
if general_settings.get("disable_error_logs") is True:
|
||||
return
|
||||
|
||||
if prisma_client is not None:
|
||||
verbose_proxy_logger.debug(
|
||||
"inside _PROXY_failure_handler kwargs=", extra=kwargs
|
||||
)
|
||||
|
||||
_exception = kwargs.get("exception")
|
||||
_exception_type = _exception.__class__.__name__
|
||||
_model = kwargs.get("model", None)
|
||||
|
||||
_optional_params = kwargs.get("optional_params", {})
|
||||
_optional_params = copy.deepcopy(_optional_params)
|
||||
|
||||
for k, v in _optional_params.items():
|
||||
v = str(v)
|
||||
v = v[:100]
|
||||
|
||||
_status_code = "500"
|
||||
try:
|
||||
_status_code = str(_exception.status_code)
|
||||
except Exception:
|
||||
# Don't let this fail logging the exception to the dB
|
||||
pass
|
||||
|
||||
_litellm_params = kwargs.get("litellm_params", {}) or {}
|
||||
_metadata = _litellm_params.get("metadata", {}) or {}
|
||||
_model_id = _metadata.get("model_info", {}).get("id", "")
|
||||
_model_group = _metadata.get("model_group", "")
|
||||
api_base = litellm.get_api_base(model=_model, optional_params=_litellm_params)
|
||||
_exception_string = str(_exception)
|
||||
|
||||
error_log = LiteLLM_ErrorLogs(
|
||||
request_id=str(uuid.uuid4()),
|
||||
model_group=_model_group,
|
||||
model_id=_model_id,
|
||||
litellm_model_name=kwargs.get("model"),
|
||||
request_kwargs=_optional_params,
|
||||
api_base=api_base,
|
||||
exception_type=_exception_type,
|
||||
status_code=_status_code,
|
||||
exception_string=_exception_string,
|
||||
startTime=kwargs.get("start_time"),
|
||||
endTime=kwargs.get("end_time"),
|
||||
)
|
||||
|
||||
error_log_dict = error_log.model_dump()
|
||||
error_log_dict["request_kwargs"] = json.dumps(error_log_dict["request_kwargs"])
|
||||
|
||||
await prisma_client.db.litellm_errorlogs.create(
|
||||
data=error_log_dict # type: ignore
|
||||
)
|
||||
|
||||
pass
|
|
@ -288,12 +288,12 @@ class LiteLLMProxyRequestSetup:
|
|||
|
||||
## KEY-LEVEL SPEND LOGS / TAGS
|
||||
if "tags" in key_metadata and key_metadata["tags"] is not None:
|
||||
if "tags" in data[_metadata_variable_name] and isinstance(
|
||||
data[_metadata_variable_name]["tags"], list
|
||||
):
|
||||
data[_metadata_variable_name]["tags"].extend(key_metadata["tags"])
|
||||
else:
|
||||
data[_metadata_variable_name]["tags"] = key_metadata["tags"]
|
||||
data[_metadata_variable_name]["tags"] = (
|
||||
LiteLLMProxyRequestSetup._merge_tags(
|
||||
request_tags=data[_metadata_variable_name].get("tags"),
|
||||
tags_to_add=key_metadata["tags"],
|
||||
)
|
||||
)
|
||||
if "spend_logs_metadata" in key_metadata and isinstance(
|
||||
key_metadata["spend_logs_metadata"], dict
|
||||
):
|
||||
|
@ -319,6 +319,30 @@ class LiteLLMProxyRequestSetup:
|
|||
data["disable_fallbacks"] = key_metadata["disable_fallbacks"]
|
||||
return data
|
||||
|
||||
@staticmethod
|
||||
def _merge_tags(request_tags: Optional[list], tags_to_add: Optional[list]) -> list:
|
||||
"""
|
||||
Helper function to merge two lists of tags, ensuring no duplicates.
|
||||
|
||||
Args:
|
||||
request_tags (Optional[list]): List of tags from the original request
|
||||
tags_to_add (Optional[list]): List of tags to add
|
||||
|
||||
Returns:
|
||||
list: Combined list of unique tags
|
||||
"""
|
||||
final_tags = []
|
||||
|
||||
if request_tags and isinstance(request_tags, list):
|
||||
final_tags.extend(request_tags)
|
||||
|
||||
if tags_to_add and isinstance(tags_to_add, list):
|
||||
for tag in tags_to_add:
|
||||
if tag not in final_tags:
|
||||
final_tags.append(tag)
|
||||
|
||||
return final_tags
|
||||
|
||||
|
||||
async def add_litellm_data_to_request( # noqa: PLR0915
|
||||
data: dict,
|
||||
|
@ -442,12 +466,10 @@ async def add_litellm_data_to_request( # noqa: PLR0915
|
|||
## TEAM-LEVEL SPEND LOGS/TAGS
|
||||
team_metadata = user_api_key_dict.team_metadata or {}
|
||||
if "tags" in team_metadata and team_metadata["tags"] is not None:
|
||||
if "tags" in data[_metadata_variable_name] and isinstance(
|
||||
data[_metadata_variable_name]["tags"], list
|
||||
):
|
||||
data[_metadata_variable_name]["tags"].extend(team_metadata["tags"])
|
||||
else:
|
||||
data[_metadata_variable_name]["tags"] = team_metadata["tags"]
|
||||
data[_metadata_variable_name]["tags"] = LiteLLMProxyRequestSetup._merge_tags(
|
||||
request_tags=data[_metadata_variable_name].get("tags"),
|
||||
tags_to_add=team_metadata["tags"],
|
||||
)
|
||||
if "spend_logs_metadata" in team_metadata and isinstance(
|
||||
team_metadata["spend_logs_metadata"], dict
|
||||
):
|
||||
|
|
|
@ -32,6 +32,7 @@ from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
|||
from litellm.proxy.management_endpoints.key_management_endpoints import (
|
||||
duration_in_seconds,
|
||||
generate_key_helper_fn,
|
||||
prepare_metadata_fields,
|
||||
)
|
||||
from litellm.proxy.management_helpers.utils import (
|
||||
add_new_member,
|
||||
|
@ -42,7 +43,7 @@ from litellm.proxy.utils import handle_exception_on_proxy
|
|||
router = APIRouter()
|
||||
|
||||
|
||||
def _update_internal_user_params(data_json: dict, data: NewUserRequest) -> dict:
|
||||
def _update_internal_new_user_params(data_json: dict, data: NewUserRequest) -> dict:
|
||||
if "user_id" in data_json and data_json["user_id"] is None:
|
||||
data_json["user_id"] = str(uuid.uuid4())
|
||||
auto_create_key = data_json.pop("auto_create_key", True)
|
||||
|
@ -145,7 +146,7 @@ async def new_user(
|
|||
from litellm.proxy.proxy_server import general_settings, proxy_logging_obj
|
||||
|
||||
data_json = data.json() # type: ignore
|
||||
data_json = _update_internal_user_params(data_json, data)
|
||||
data_json = _update_internal_new_user_params(data_json, data)
|
||||
response = await generate_key_helper_fn(request_type="user", **data_json)
|
||||
|
||||
# Admin UI Logic
|
||||
|
@ -438,6 +439,52 @@ async def user_info( # noqa: PLR0915
|
|||
raise handle_exception_on_proxy(e)
|
||||
|
||||
|
||||
def _update_internal_user_params(data_json: dict, data: UpdateUserRequest) -> dict:
|
||||
non_default_values = {}
|
||||
for k, v in data_json.items():
|
||||
if (
|
||||
v is not None
|
||||
and v
|
||||
not in (
|
||||
[],
|
||||
{},
|
||||
0,
|
||||
)
|
||||
and k not in LiteLLM_ManagementEndpoint_MetadataFields
|
||||
): # models default to [], spend defaults to 0, we should not reset these values
|
||||
non_default_values[k] = v
|
||||
|
||||
is_internal_user = False
|
||||
if data.user_role == LitellmUserRoles.INTERNAL_USER:
|
||||
is_internal_user = True
|
||||
|
||||
if "budget_duration" in non_default_values:
|
||||
duration_s = duration_in_seconds(duration=non_default_values["budget_duration"])
|
||||
user_reset_at = datetime.now(timezone.utc) + timedelta(seconds=duration_s)
|
||||
non_default_values["budget_reset_at"] = user_reset_at
|
||||
|
||||
if "max_budget" not in non_default_values:
|
||||
if (
|
||||
is_internal_user and litellm.max_internal_user_budget is not None
|
||||
): # applies internal user limits, if user role updated
|
||||
non_default_values["max_budget"] = litellm.max_internal_user_budget
|
||||
|
||||
if (
|
||||
"budget_duration" not in non_default_values
|
||||
): # applies internal user limits, if user role updated
|
||||
if is_internal_user and litellm.internal_user_budget_duration is not None:
|
||||
non_default_values["budget_duration"] = (
|
||||
litellm.internal_user_budget_duration
|
||||
)
|
||||
duration_s = duration_in_seconds(
|
||||
duration=non_default_values["budget_duration"]
|
||||
)
|
||||
user_reset_at = datetime.now(timezone.utc) + timedelta(seconds=duration_s)
|
||||
non_default_values["budget_reset_at"] = user_reset_at
|
||||
|
||||
return non_default_values
|
||||
|
||||
|
||||
@router.post(
|
||||
"/user/update",
|
||||
tags=["Internal User management"],
|
||||
|
@ -459,6 +506,7 @@ async def user_update(
|
|||
"user_id": "test-litellm-user-4",
|
||||
"user_role": "proxy_admin_viewer"
|
||||
}'
|
||||
```
|
||||
|
||||
Parameters:
|
||||
- user_id: Optional[str] - Specify a user id. If not set, a unique id will be generated.
|
||||
|
@ -491,7 +539,7 @@ async def user_update(
|
|||
- duration: Optional[str] - [NOT IMPLEMENTED].
|
||||
- key_alias: Optional[str] - [NOT IMPLEMENTED].
|
||||
|
||||
```
|
||||
|
||||
"""
|
||||
from litellm.proxy.proxy_server import prisma_client
|
||||
|
||||
|
@ -502,46 +550,21 @@ async def user_update(
|
|||
raise Exception("Not connected to DB!")
|
||||
|
||||
# get non default values for key
|
||||
non_default_values = {}
|
||||
for k, v in data_json.items():
|
||||
if v is not None and v not in (
|
||||
[],
|
||||
{},
|
||||
0,
|
||||
): # models default to [], spend defaults to 0, we should not reset these values
|
||||
non_default_values[k] = v
|
||||
non_default_values = _update_internal_user_params(
|
||||
data_json=data_json, data=data
|
||||
)
|
||||
|
||||
is_internal_user = False
|
||||
if data.user_role == LitellmUserRoles.INTERNAL_USER:
|
||||
is_internal_user = True
|
||||
existing_user_row = await prisma_client.get_data(
|
||||
user_id=data.user_id, table_name="user", query_type="find_unique"
|
||||
)
|
||||
|
||||
if "budget_duration" in non_default_values:
|
||||
duration_s = duration_in_seconds(
|
||||
duration=non_default_values["budget_duration"]
|
||||
)
|
||||
user_reset_at = datetime.now(timezone.utc) + timedelta(seconds=duration_s)
|
||||
non_default_values["budget_reset_at"] = user_reset_at
|
||||
existing_metadata = existing_user_row.metadata if existing_user_row else {}
|
||||
|
||||
if "max_budget" not in non_default_values:
|
||||
if (
|
||||
is_internal_user and litellm.max_internal_user_budget is not None
|
||||
): # applies internal user limits, if user role updated
|
||||
non_default_values["max_budget"] = litellm.max_internal_user_budget
|
||||
|
||||
if (
|
||||
"budget_duration" not in non_default_values
|
||||
): # applies internal user limits, if user role updated
|
||||
if is_internal_user and litellm.internal_user_budget_duration is not None:
|
||||
non_default_values["budget_duration"] = (
|
||||
litellm.internal_user_budget_duration
|
||||
)
|
||||
duration_s = duration_in_seconds(
|
||||
duration=non_default_values["budget_duration"]
|
||||
)
|
||||
user_reset_at = datetime.now(timezone.utc) + timedelta(
|
||||
seconds=duration_s
|
||||
)
|
||||
non_default_values["budget_reset_at"] = user_reset_at
|
||||
non_default_values = prepare_metadata_fields(
|
||||
data=data,
|
||||
non_default_values=non_default_values,
|
||||
existing_metadata=existing_metadata or {},
|
||||
)
|
||||
|
||||
## ADD USER, IF NEW ##
|
||||
verbose_proxy_logger.debug("/user/update: Received data = %s", data)
|
||||
|
|
|
@ -17,7 +17,7 @@ import secrets
|
|||
import traceback
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import List, Optional, Tuple, cast
|
||||
|
||||
import fastapi
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException, Query, Request, status
|
||||
|
@ -394,7 +394,8 @@ async def generate_key_fn( # noqa: PLR0915
|
|||
}
|
||||
)
|
||||
_budget_id = getattr(_budget, "budget_id", None)
|
||||
data_json = data.json() # type: ignore
|
||||
data_json = data.model_dump(exclude_unset=True, exclude_none=True) # type: ignore
|
||||
|
||||
# if we get max_budget passed to /key/generate, then use it as key_max_budget. Since generate_key_helper_fn is used to make new users
|
||||
if "max_budget" in data_json:
|
||||
data_json["key_max_budget"] = data_json.pop("max_budget", None)
|
||||
|
@ -420,6 +421,11 @@ async def generate_key_fn( # noqa: PLR0915
|
|||
|
||||
data_json.pop("tags")
|
||||
|
||||
await _enforce_unique_key_alias(
|
||||
key_alias=data_json.get("key_alias", None),
|
||||
prisma_client=prisma_client,
|
||||
)
|
||||
|
||||
response = await generate_key_helper_fn(
|
||||
request_type="key", **data_json, table_name="key"
|
||||
)
|
||||
|
@ -447,12 +453,52 @@ async def generate_key_fn( # noqa: PLR0915
|
|||
raise handle_exception_on_proxy(e)
|
||||
|
||||
|
||||
def prepare_metadata_fields(
|
||||
data: BaseModel, non_default_values: dict, existing_metadata: dict
|
||||
) -> dict:
|
||||
"""
|
||||
Check LiteLLM_ManagementEndpoint_MetadataFields (proxy/_types.py) for fields that are allowed to be updated
|
||||
"""
|
||||
|
||||
if "metadata" not in non_default_values: # allow user to set metadata to none
|
||||
non_default_values["metadata"] = existing_metadata.copy()
|
||||
|
||||
casted_metadata = cast(dict, non_default_values["metadata"])
|
||||
|
||||
data_json = data.model_dump(exclude_unset=True, exclude_none=True)
|
||||
|
||||
try:
|
||||
for k, v in data_json.items():
|
||||
if k == "model_tpm_limit" or k == "model_rpm_limit":
|
||||
if k not in casted_metadata or casted_metadata[k] is None:
|
||||
casted_metadata[k] = {}
|
||||
casted_metadata[k].update(v)
|
||||
|
||||
if k == "tags" or k == "guardrails":
|
||||
if k not in casted_metadata or casted_metadata[k] is None:
|
||||
casted_metadata[k] = []
|
||||
seen = set(casted_metadata[k])
|
||||
casted_metadata[k].extend(
|
||||
x for x in v if x not in seen and not seen.add(x) # type: ignore
|
||||
) # prevent duplicates from being added + maintain initial order
|
||||
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(
|
||||
"litellm.proxy.proxy_server.prepare_metadata_fields(): Exception occured - {}".format(
|
||||
str(e)
|
||||
)
|
||||
)
|
||||
|
||||
non_default_values["metadata"] = casted_metadata
|
||||
return non_default_values
|
||||
|
||||
|
||||
def prepare_key_update_data(
|
||||
data: Union[UpdateKeyRequest, RegenerateKeyRequest], existing_key_row
|
||||
):
|
||||
data_json: dict = data.model_dump(exclude_unset=True)
|
||||
data_json.pop("key", None)
|
||||
_metadata_fields = ["model_rpm_limit", "model_tpm_limit", "guardrails"]
|
||||
_metadata_fields = ["model_rpm_limit", "model_tpm_limit", "guardrails", "tags"]
|
||||
non_default_values = {}
|
||||
for k, v in data_json.items():
|
||||
if k in _metadata_fields:
|
||||
|
@ -480,21 +526,9 @@ def prepare_key_update_data(
|
|||
|
||||
_metadata = existing_key_row.metadata or {}
|
||||
|
||||
if data.model_tpm_limit:
|
||||
if "model_tpm_limit" not in _metadata:
|
||||
_metadata["model_tpm_limit"] = {}
|
||||
_metadata["model_tpm_limit"].update(data.model_tpm_limit)
|
||||
non_default_values["metadata"] = _metadata
|
||||
|
||||
if data.model_rpm_limit:
|
||||
if "model_rpm_limit" not in _metadata:
|
||||
_metadata["model_rpm_limit"] = {}
|
||||
_metadata["model_rpm_limit"].update(data.model_rpm_limit)
|
||||
non_default_values["metadata"] = _metadata
|
||||
|
||||
if data.guardrails:
|
||||
_metadata["guardrails"] = data.guardrails
|
||||
non_default_values["metadata"] = _metadata
|
||||
non_default_values = prepare_metadata_fields(
|
||||
data=data, non_default_values=non_default_values, existing_metadata=_metadata
|
||||
)
|
||||
|
||||
return non_default_values
|
||||
|
||||
|
@ -586,6 +620,12 @@ async def update_key_fn(
|
|||
data=data, existing_key_row=existing_key_row
|
||||
)
|
||||
|
||||
await _enforce_unique_key_alias(
|
||||
key_alias=non_default_values.get("key_alias", None),
|
||||
prisma_client=prisma_client,
|
||||
existing_key_token=existing_key_row.token,
|
||||
)
|
||||
|
||||
response = await prisma_client.update_data(
|
||||
token=key, data={**non_default_values, "token": key}
|
||||
)
|
||||
|
@ -913,11 +953,11 @@ async def generate_key_helper_fn( # noqa: PLR0915
|
|||
request_type: Literal[
|
||||
"user", "key"
|
||||
], # identifies if this request is from /user/new or /key/generate
|
||||
duration: Optional[str],
|
||||
models: list,
|
||||
aliases: dict,
|
||||
config: dict,
|
||||
spend: float,
|
||||
duration: Optional[str] = None,
|
||||
models: list = [],
|
||||
aliases: dict = {},
|
||||
config: dict = {},
|
||||
spend: float = 0.0,
|
||||
key_max_budget: Optional[float] = None, # key_max_budget is used to Budget Per key
|
||||
key_budget_duration: Optional[str] = None,
|
||||
budget_id: Optional[float] = None, # budget id <-> LiteLLM_BudgetTable
|
||||
|
@ -946,8 +986,8 @@ async def generate_key_helper_fn( # noqa: PLR0915
|
|||
allowed_cache_controls: Optional[list] = [],
|
||||
permissions: Optional[dict] = {},
|
||||
model_max_budget: Optional[dict] = {},
|
||||
model_rpm_limit: Optional[dict] = {},
|
||||
model_tpm_limit: Optional[dict] = {},
|
||||
model_rpm_limit: Optional[dict] = None,
|
||||
model_tpm_limit: Optional[dict] = None,
|
||||
guardrails: Optional[list] = None,
|
||||
teams: Optional[list] = None,
|
||||
organization_id: Optional[str] = None,
|
||||
|
@ -1884,3 +1924,38 @@ async def test_key_logging(
|
|||
status="healthy",
|
||||
details=f"No logger exceptions triggered, system is healthy. Manually check if logs were sent to {logging_callbacks} ",
|
||||
)
|
||||
|
||||
|
||||
async def _enforce_unique_key_alias(
|
||||
key_alias: Optional[str],
|
||||
prisma_client: Any,
|
||||
existing_key_token: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Helper to enforce unique key aliases across all keys.
|
||||
|
||||
Args:
|
||||
key_alias (Optional[str]): The key alias to check
|
||||
prisma_client (Any): Prisma client instance
|
||||
existing_key_token (Optional[str]): ID of existing key being updated, to exclude from uniqueness check
|
||||
(The Admin UI passes key_alias, in all Edit key requests. So we need to be sure that if we find a key with the same alias, it's not the same key we're updating)
|
||||
|
||||
Raises:
|
||||
ProxyException: If key alias already exists on a different key
|
||||
"""
|
||||
if key_alias is not None and prisma_client is not None:
|
||||
where_clause: dict[str, Any] = {"key_alias": key_alias}
|
||||
if existing_key_token:
|
||||
# Exclude the current key from the uniqueness check
|
||||
where_clause["NOT"] = {"token": existing_key_token}
|
||||
|
||||
existing_key = await prisma_client.db.litellm_verificationtoken.find_first(
|
||||
where=where_clause
|
||||
)
|
||||
if existing_key is not None:
|
||||
raise ProxyException(
|
||||
message=f"Key with alias '{key_alias}' already exists. Unique key aliases across all keys are required.",
|
||||
type=ProxyErrorTypes.bad_request_error,
|
||||
param="key_alias",
|
||||
code=status.HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
|
|
10
litellm/proxy/model_config.yaml
Normal file
10
litellm/proxy/model_config.yaml
Normal file
|
@ -0,0 +1,10 @@
|
|||
model_list:
|
||||
- model_name: gpt-4o
|
||||
litellm_params:
|
||||
model: openai/gpt-4o
|
||||
api_base: https://exampleopenaiendpoint-production.up.railway.app/
|
||||
- model_name: fake-anthropic-endpoint
|
||||
litellm_params:
|
||||
model: anthropic/fake
|
||||
api_base: https://exampleanthropicendpoint-production.up.railway.app/
|
||||
|
|
@ -1,24 +1,5 @@
|
|||
model_list:
|
||||
- model_name: gpt-4o
|
||||
litellm_params:
|
||||
model: openai/gpt-4o
|
||||
api_base: https://exampleopenaiendpoint-production.up.railway.app/
|
||||
- model_name: fake-anthropic-endpoint
|
||||
litellm_params:
|
||||
model: anthropic/fake
|
||||
api_base: https://exampleanthropicendpoint-production.up.railway.app/
|
||||
|
||||
router_settings:
|
||||
provider_budget_config:
|
||||
openai:
|
||||
budget_limit: 0.3 # float of $ value budget for time period
|
||||
time_period: 1d # can be 1d, 2d, 30d
|
||||
anthropic:
|
||||
budget_limit: 5
|
||||
time_period: 1d
|
||||
redis_host: os.environ/REDIS_HOST
|
||||
redis_port: os.environ/REDIS_PORT
|
||||
redis_password: os.environ/REDIS_PASSWORD
|
||||
include:
|
||||
- model_config.yaml
|
||||
|
||||
litellm_settings:
|
||||
callbacks: ["datadog"]
|
||||
|
|
|
@ -176,6 +176,7 @@ from litellm.proxy.health_endpoints._health_endpoints import router as health_ro
|
|||
from litellm.proxy.hooks.prompt_injection_detection import (
|
||||
_OPTIONAL_PromptInjectionDetection,
|
||||
)
|
||||
from litellm.proxy.hooks.proxy_failure_handler import _PROXY_failure_handler
|
||||
from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request
|
||||
from litellm.proxy.management_endpoints.customer_endpoints import (
|
||||
router as customer_router,
|
||||
|
@ -529,14 +530,6 @@ db_writer_client: Optional[HTTPHandler] = None
|
|||
### logger ###
|
||||
|
||||
|
||||
def _get_pydantic_json_dict(pydantic_obj: BaseModel) -> dict:
|
||||
try:
|
||||
return pydantic_obj.model_dump() # type: ignore
|
||||
except Exception:
|
||||
# if using pydantic v1
|
||||
return pydantic_obj.dict()
|
||||
|
||||
|
||||
def get_custom_headers(
|
||||
*,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
|
@ -690,68 +683,6 @@ def cost_tracking():
|
|||
litellm._async_success_callback.append(_PROXY_track_cost_callback) # type: ignore
|
||||
|
||||
|
||||
async def _PROXY_failure_handler(
|
||||
kwargs, # kwargs to completion
|
||||
completion_response: litellm.ModelResponse, # response from completion
|
||||
start_time=None,
|
||||
end_time=None, # start/end time for completion
|
||||
):
|
||||
global prisma_client
|
||||
if prisma_client is not None:
|
||||
verbose_proxy_logger.debug(
|
||||
"inside _PROXY_failure_handler kwargs=", extra=kwargs
|
||||
)
|
||||
|
||||
_exception = kwargs.get("exception")
|
||||
_exception_type = _exception.__class__.__name__
|
||||
_model = kwargs.get("model", None)
|
||||
|
||||
_optional_params = kwargs.get("optional_params", {})
|
||||
_optional_params = copy.deepcopy(_optional_params)
|
||||
|
||||
for k, v in _optional_params.items():
|
||||
v = str(v)
|
||||
v = v[:100]
|
||||
|
||||
_status_code = "500"
|
||||
try:
|
||||
_status_code = str(_exception.status_code)
|
||||
except Exception:
|
||||
# Don't let this fail logging the exception to the dB
|
||||
pass
|
||||
|
||||
_litellm_params = kwargs.get("litellm_params", {}) or {}
|
||||
_metadata = _litellm_params.get("metadata", {}) or {}
|
||||
_model_id = _metadata.get("model_info", {}).get("id", "")
|
||||
_model_group = _metadata.get("model_group", "")
|
||||
api_base = litellm.get_api_base(model=_model, optional_params=_litellm_params)
|
||||
_exception_string = str(_exception)
|
||||
|
||||
error_log = LiteLLM_ErrorLogs(
|
||||
request_id=str(uuid.uuid4()),
|
||||
model_group=_model_group,
|
||||
model_id=_model_id,
|
||||
litellm_model_name=kwargs.get("model"),
|
||||
request_kwargs=_optional_params,
|
||||
api_base=api_base,
|
||||
exception_type=_exception_type,
|
||||
status_code=_status_code,
|
||||
exception_string=_exception_string,
|
||||
startTime=kwargs.get("start_time"),
|
||||
endTime=kwargs.get("end_time"),
|
||||
)
|
||||
|
||||
# helper function to convert to dict on pydantic v2 & v1
|
||||
error_log_dict = _get_pydantic_json_dict(error_log)
|
||||
error_log_dict["request_kwargs"] = json.dumps(error_log_dict["request_kwargs"])
|
||||
|
||||
await prisma_client.db.litellm_errorlogs.create(
|
||||
data=error_log_dict # type: ignore
|
||||
)
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@log_db_metrics
|
||||
async def _PROXY_track_cost_callback(
|
||||
kwargs, # kwargs to completion
|
||||
|
@ -1380,6 +1311,16 @@ class ProxyConfig:
|
|||
_, file_extension = os.path.splitext(config_file_path)
|
||||
return file_extension.lower() == ".yaml" or file_extension.lower() == ".yml"
|
||||
|
||||
def _load_yaml_file(self, file_path: str) -> dict:
|
||||
"""
|
||||
Load and parse a YAML file
|
||||
"""
|
||||
try:
|
||||
with open(file_path, "r") as file:
|
||||
return yaml.safe_load(file) or {}
|
||||
except Exception as e:
|
||||
raise Exception(f"Error loading yaml file {file_path}: {str(e)}")
|
||||
|
||||
async def _get_config_from_file(
|
||||
self, config_file_path: Optional[str] = None
|
||||
) -> dict:
|
||||
|
@ -1410,6 +1351,51 @@ class ProxyConfig:
|
|||
"litellm_settings": {},
|
||||
}
|
||||
|
||||
# Process includes
|
||||
config = self._process_includes(
|
||||
config=config, base_dir=os.path.dirname(os.path.abspath(file_path or ""))
|
||||
)
|
||||
|
||||
verbose_proxy_logger.debug(f"loaded config={json.dumps(config, indent=4)}")
|
||||
return config
|
||||
|
||||
def _process_includes(self, config: dict, base_dir: str) -> dict:
|
||||
"""
|
||||
Process includes by appending their contents to the main config
|
||||
|
||||
Handles nested config.yamls with `include` section
|
||||
|
||||
Example config: This will get the contents from files in `include` and append it
|
||||
```yaml
|
||||
include:
|
||||
- model_config.yaml
|
||||
|
||||
litellm_settings:
|
||||
callbacks: ["prometheus"]
|
||||
```
|
||||
"""
|
||||
if "include" not in config:
|
||||
return config
|
||||
|
||||
if not isinstance(config["include"], list):
|
||||
raise ValueError("'include' must be a list of file paths")
|
||||
|
||||
# Load and append all included files
|
||||
for include_file in config["include"]:
|
||||
file_path = os.path.join(base_dir, include_file)
|
||||
if not os.path.exists(file_path):
|
||||
raise FileNotFoundError(f"Included file not found: {file_path}")
|
||||
|
||||
included_config = self._load_yaml_file(file_path)
|
||||
# Simply update/extend the main config with included config
|
||||
for key, value in included_config.items():
|
||||
if isinstance(value, list) and key in config:
|
||||
config[key].extend(value)
|
||||
else:
|
||||
config[key] = value
|
||||
|
||||
# Remove the include directive
|
||||
del config["include"]
|
||||
return config
|
||||
|
||||
async def save_config(self, new_config: dict):
|
||||
|
|
|
@ -86,7 +86,6 @@ async def route_request(
|
|||
else:
|
||||
models = [model.strip() for model in data.pop("model").split(",")]
|
||||
return llm_router.abatch_completion(models=models, **data)
|
||||
|
||||
elif llm_router is not None:
|
||||
if (
|
||||
data["model"] in router_model_names
|
||||
|
@ -113,6 +112,9 @@ async def route_request(
|
|||
or len(llm_router.pattern_router.patterns) > 0
|
||||
):
|
||||
return getattr(llm_router, f"{route_type}")(**data)
|
||||
elif route_type == "amoderation":
|
||||
# moderation endpoint does not require `model` parameter
|
||||
return getattr(llm_router, f"{route_type}")(**data)
|
||||
|
||||
elif user_model is not None:
|
||||
return getattr(litellm, f"{route_type}")(**data)
|
||||
|
|
|
@ -2563,10 +2563,7 @@ class Router:
|
|||
original_function: Callable,
|
||||
**kwargs,
|
||||
):
|
||||
if (
|
||||
"model" in kwargs
|
||||
and self.get_model_list(model_name=kwargs["model"]) is not None
|
||||
):
|
||||
if kwargs.get("model") and self.get_model_list(model_name=kwargs["model"]):
|
||||
deployment = await self.async_get_available_deployment(
|
||||
model=kwargs["model"]
|
||||
)
|
||||
|
@ -4715,6 +4712,9 @@ class Router:
|
|||
if hasattr(self, "model_list"):
|
||||
returned_models: List[DeploymentTypedDict] = []
|
||||
|
||||
if model_name is not None:
|
||||
returned_models.extend(self._get_all_deployments(model_name=model_name))
|
||||
|
||||
if hasattr(self, "model_group_alias"):
|
||||
for model_alias, model_value in self.model_group_alias.items():
|
||||
|
||||
|
@ -4746,17 +4746,21 @@ class Router:
|
|||
returned_models += self.model_list
|
||||
|
||||
return returned_models
|
||||
returned_models.extend(self._get_all_deployments(model_name=model_name))
|
||||
|
||||
return returned_models
|
||||
return None
|
||||
|
||||
def get_model_access_groups(self):
|
||||
def get_model_access_groups(self, model_name: Optional[str] = None):
|
||||
"""
|
||||
If model_name is provided, only return access groups for that model.
|
||||
"""
|
||||
from collections import defaultdict
|
||||
|
||||
access_groups = defaultdict(list)
|
||||
|
||||
if self.model_list:
|
||||
for m in self.model_list:
|
||||
model_list = self.get_model_list(model_name=model_name)
|
||||
if model_list:
|
||||
for m in model_list:
|
||||
for group in m.get("model_info", {}).get("access_groups", []):
|
||||
model_name = m["model_name"]
|
||||
access_groups[group].append(model_name)
|
||||
|
|
|
@ -79,7 +79,9 @@ class PatternMatchRouter:
|
|||
|
||||
return new_deployments
|
||||
|
||||
def route(self, request: Optional[str]) -> Optional[List[Dict]]:
|
||||
def route(
|
||||
self, request: Optional[str], filtered_model_names: Optional[List[str]] = None
|
||||
) -> Optional[List[Dict]]:
|
||||
"""
|
||||
Route a requested model to the corresponding llm deployments based on the regex pattern
|
||||
|
||||
|
@ -89,14 +91,26 @@ class PatternMatchRouter:
|
|||
|
||||
Args:
|
||||
request: Optional[str]
|
||||
|
||||
filtered_model_names: Optional[List[str]] - if provided, only return deployments that match the filtered_model_names
|
||||
Returns:
|
||||
Optional[List[Deployment]]: llm deployments
|
||||
"""
|
||||
try:
|
||||
if request is None:
|
||||
return None
|
||||
|
||||
regex_filtered_model_names = (
|
||||
[self._pattern_to_regex(m) for m in filtered_model_names]
|
||||
if filtered_model_names is not None
|
||||
else []
|
||||
)
|
||||
|
||||
for pattern, llm_deployments in self.patterns.items():
|
||||
if (
|
||||
filtered_model_names is not None
|
||||
and pattern not in regex_filtered_model_names
|
||||
):
|
||||
continue
|
||||
pattern_match = re.match(pattern, request)
|
||||
if pattern_match:
|
||||
return self._return_pattern_matched_deployments(
|
||||
|
|
|
@ -355,7 +355,7 @@ class LiteLLMParamsTypedDict(TypedDict, total=False):
|
|||
class DeploymentTypedDict(TypedDict, total=False):
|
||||
model_name: Required[str]
|
||||
litellm_params: Required[LiteLLMParamsTypedDict]
|
||||
model_info: Optional[dict]
|
||||
model_info: dict
|
||||
|
||||
|
||||
SPECIAL_MODEL_INFO_PARAMS = [
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
[tool.poetry]
|
||||
name = "litellm"
|
||||
version = "1.53.1"
|
||||
version = "1.53.2"
|
||||
description = "Library to easily interface with LLM API providers"
|
||||
authors = ["BerriAI"]
|
||||
license = "MIT"
|
||||
|
@ -91,7 +91,7 @@ requires = ["poetry-core", "wheel"]
|
|||
build-backend = "poetry.core.masonry.api"
|
||||
|
||||
[tool.commitizen]
|
||||
version = "1.53.1"
|
||||
version = "1.53.2"
|
||||
version_files = [
|
||||
"pyproject.toml:^version"
|
||||
]
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
# LITELLM PROXY DEPENDENCIES #
|
||||
anyio==4.4.0 # openai + http req.
|
||||
openai==1.54.0 # openai req.
|
||||
openai==1.55.3 # openai req.
|
||||
fastapi==0.111.0 # server dep
|
||||
backoff==2.2.1 # server dep
|
||||
pyyaml==6.0.0 # server dep
|
||||
|
|
|
@ -1 +1,3 @@
|
|||
More tests under `litellm/litellm/tests/*`.
|
||||
Unit tests for individual LLM providers.
|
||||
|
||||
Name of the test file is the name of the LLM provider - e.g. `test_openai.py` is for OpenAI.
|
File diff suppressed because one or more lines are too long
|
@ -45,81 +45,59 @@ def test_map_azure_model_group(model_group_header, expected_model):
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.respx
|
||||
async def test_azure_ai_with_image_url(respx_mock: MockRouter):
|
||||
async def test_azure_ai_with_image_url():
|
||||
"""
|
||||
Important test:
|
||||
|
||||
Test that Azure AI studio can handle image_url passed when content is a list containing both text and image_url
|
||||
"""
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
litellm.set_verbose = True
|
||||
|
||||
# Mock response based on the actual API response
|
||||
mock_response = {
|
||||
"id": "cmpl-53860ea1efa24d2883555bfec13d2254",
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "stop",
|
||||
"index": 0,
|
||||
"logprobs": None,
|
||||
"message": {
|
||||
"content": "The image displays a graphic with the text 'LiteLLM' in black",
|
||||
"role": "assistant",
|
||||
"refusal": None,
|
||||
"audio": None,
|
||||
"function_call": None,
|
||||
"tool_calls": None,
|
||||
},
|
||||
}
|
||||
],
|
||||
"created": 1731801937,
|
||||
"model": "phi35-vision-instruct",
|
||||
"object": "chat.completion",
|
||||
"usage": {
|
||||
"completion_tokens": 69,
|
||||
"prompt_tokens": 617,
|
||||
"total_tokens": 686,
|
||||
"completion_tokens_details": None,
|
||||
"prompt_tokens_details": None,
|
||||
},
|
||||
}
|
||||
|
||||
# Mock the API request
|
||||
mock_request = respx_mock.post(
|
||||
"https://Phi-3-5-vision-instruct-dcvov.eastus2.models.ai.azure.com"
|
||||
).mock(return_value=httpx.Response(200, json=mock_response))
|
||||
|
||||
response = await litellm.acompletion(
|
||||
model="azure_ai/Phi-3-5-vision-instruct-dcvov",
|
||||
api_base="https://Phi-3-5-vision-instruct-dcvov.eastus2.models.ai.azure.com",
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "What is in this image?",
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "https://litellm-listing.s3.amazonaws.com/litellm_logo.png"
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
],
|
||||
client = AsyncOpenAI(
|
||||
api_key="fake-api-key",
|
||||
base_url="https://Phi-3-5-vision-instruct-dcvov.eastus2.models.ai.azure.com",
|
||||
)
|
||||
|
||||
# Verify the request was made
|
||||
assert mock_request.called
|
||||
with patch.object(
|
||||
client.chat.completions.with_raw_response, "create"
|
||||
) as mock_client:
|
||||
try:
|
||||
await litellm.acompletion(
|
||||
model="azure_ai/Phi-3-5-vision-instruct-dcvov",
|
||||
api_base="https://Phi-3-5-vision-instruct-dcvov.eastus2.models.ai.azure.com",
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "What is in this image?",
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "https://litellm-listing.s3.amazonaws.com/litellm_logo.png"
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
],
|
||||
api_key="fake-api-key",
|
||||
client=client,
|
||||
)
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
print(f"Error: {e}")
|
||||
|
||||
# Check the request body
|
||||
request_body = json.loads(mock_request.calls[0].request.content)
|
||||
assert request_body == {
|
||||
"model": "Phi-3-5-vision-instruct-dcvov",
|
||||
"messages": [
|
||||
# Verify the request was made
|
||||
mock_client.assert_called_once()
|
||||
|
||||
# Check the request body
|
||||
request_body = mock_client.call_args.kwargs
|
||||
assert request_body["model"] == "Phi-3-5-vision-instruct-dcvov"
|
||||
assert request_body["messages"] == [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
|
@ -132,7 +110,4 @@ async def test_azure_ai_with_image_url(respx_mock: MockRouter):
|
|||
},
|
||||
],
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
print(f"response: {response}")
|
||||
]
|
||||
|
|
|
@ -13,6 +13,7 @@ load_dotenv()
|
|||
import httpx
|
||||
import pytest
|
||||
from respx import MockRouter
|
||||
from unittest.mock import patch, MagicMock, AsyncMock
|
||||
|
||||
import litellm
|
||||
from litellm import Choices, Message, ModelResponse
|
||||
|
@ -41,56 +42,58 @@ def return_mocked_response(model: str):
|
|||
"bedrock/mistral.mistral-large-2407-v1:0",
|
||||
],
|
||||
)
|
||||
@pytest.mark.respx
|
||||
@pytest.mark.asyncio()
|
||||
async def test_bedrock_max_completion_tokens(model: str, respx_mock: MockRouter):
|
||||
async def test_bedrock_max_completion_tokens(model: str):
|
||||
"""
|
||||
Tests that:
|
||||
- max_completion_tokens is passed as max_tokens to bedrock models
|
||||
"""
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
|
||||
|
||||
litellm.set_verbose = True
|
||||
|
||||
client = AsyncHTTPHandler()
|
||||
|
||||
mock_response = return_mocked_response(model)
|
||||
_model = model.split("/")[1]
|
||||
print("\n\nmock_response: ", mock_response)
|
||||
url = f"https://bedrock-runtime.us-west-2.amazonaws.com/model/{_model}/converse"
|
||||
mock_request = respx_mock.post(url).mock(
|
||||
return_value=httpx.Response(200, json=mock_response)
|
||||
)
|
||||
|
||||
response = await litellm.acompletion(
|
||||
model=model,
|
||||
max_completion_tokens=10,
|
||||
messages=[{"role": "user", "content": "Hello!"}],
|
||||
)
|
||||
with patch.object(client, "post") as mock_client:
|
||||
try:
|
||||
response = await litellm.acompletion(
|
||||
model=model,
|
||||
max_completion_tokens=10,
|
||||
messages=[{"role": "user", "content": "Hello!"}],
|
||||
client=client,
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
|
||||
assert mock_request.called
|
||||
request_body = json.loads(mock_request.calls[0].request.content)
|
||||
mock_client.assert_called_once()
|
||||
request_body = json.loads(mock_client.call_args.kwargs["data"])
|
||||
|
||||
print("request_body: ", request_body)
|
||||
print("request_body: ", request_body)
|
||||
|
||||
assert request_body == {
|
||||
"messages": [{"role": "user", "content": [{"text": "Hello!"}]}],
|
||||
"additionalModelRequestFields": {},
|
||||
"system": [],
|
||||
"inferenceConfig": {"maxTokens": 10},
|
||||
}
|
||||
print(f"response: {response}")
|
||||
assert isinstance(response, ModelResponse)
|
||||
assert request_body == {
|
||||
"messages": [{"role": "user", "content": [{"text": "Hello!"}]}],
|
||||
"additionalModelRequestFields": {},
|
||||
"system": [],
|
||||
"inferenceConfig": {"maxTokens": 10},
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model",
|
||||
["anthropic/claude-3-sonnet-20240229", "anthropic/claude-3-opus-20240229,"],
|
||||
["anthropic/claude-3-sonnet-20240229", "anthropic/claude-3-opus-20240229"],
|
||||
)
|
||||
@pytest.mark.respx
|
||||
@pytest.mark.asyncio()
|
||||
async def test_anthropic_api_max_completion_tokens(model: str, respx_mock: MockRouter):
|
||||
async def test_anthropic_api_max_completion_tokens(model: str):
|
||||
"""
|
||||
Tests that:
|
||||
- max_completion_tokens is passed as max_tokens to anthropic models
|
||||
"""
|
||||
litellm.set_verbose = True
|
||||
from litellm.llms.custom_httpx.http_handler import HTTPHandler
|
||||
|
||||
mock_response = {
|
||||
"content": [{"text": "Hi! My name is Claude.", "type": "text"}],
|
||||
|
@ -103,30 +106,32 @@ async def test_anthropic_api_max_completion_tokens(model: str, respx_mock: MockR
|
|||
"usage": {"input_tokens": 2095, "output_tokens": 503},
|
||||
}
|
||||
|
||||
client = HTTPHandler()
|
||||
|
||||
print("\n\nmock_response: ", mock_response)
|
||||
url = f"https://api.anthropic.com/v1/messages"
|
||||
mock_request = respx_mock.post(url).mock(
|
||||
return_value=httpx.Response(200, json=mock_response)
|
||||
)
|
||||
|
||||
response = await litellm.acompletion(
|
||||
model=model,
|
||||
max_completion_tokens=10,
|
||||
messages=[{"role": "user", "content": "Hello!"}],
|
||||
)
|
||||
with patch.object(client, "post") as mock_client:
|
||||
try:
|
||||
response = await litellm.acompletion(
|
||||
model=model,
|
||||
max_completion_tokens=10,
|
||||
messages=[{"role": "user", "content": "Hello!"}],
|
||||
client=client,
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
mock_client.assert_called_once()
|
||||
request_body = mock_client.call_args.kwargs["json"]
|
||||
|
||||
assert mock_request.called
|
||||
request_body = json.loads(mock_request.calls[0].request.content)
|
||||
print("request_body: ", request_body)
|
||||
|
||||
print("request_body: ", request_body)
|
||||
|
||||
assert request_body == {
|
||||
"messages": [{"role": "user", "content": [{"type": "text", "text": "Hello!"}]}],
|
||||
"max_tokens": 10,
|
||||
"model": model.split("/")[-1],
|
||||
}
|
||||
print(f"response: {response}")
|
||||
assert isinstance(response, ModelResponse)
|
||||
assert request_body == {
|
||||
"messages": [
|
||||
{"role": "user", "content": [{"type": "text", "text": "Hello!"}]}
|
||||
],
|
||||
"max_tokens": 10,
|
||||
"model": model.split("/")[-1],
|
||||
}
|
||||
|
||||
|
||||
def test_all_model_configs():
|
||||
|
|
|
@ -12,95 +12,78 @@ sys.path.insert(
|
|||
import httpx
|
||||
import pytest
|
||||
from respx import MockRouter
|
||||
from unittest.mock import patch, MagicMock, AsyncMock
|
||||
|
||||
import litellm
|
||||
from litellm import Choices, Message, ModelResponse, EmbeddingResponse, Usage
|
||||
from litellm import completion
|
||||
|
||||
|
||||
@pytest.mark.respx
|
||||
def test_completion_nvidia_nim(respx_mock: MockRouter):
|
||||
def test_completion_nvidia_nim():
|
||||
from openai import OpenAI
|
||||
|
||||
litellm.set_verbose = True
|
||||
mock_response = ModelResponse(
|
||||
id="cmpl-mock",
|
||||
choices=[Choices(message=Message(content="Mocked response", role="assistant"))],
|
||||
created=int(datetime.now().timestamp()),
|
||||
model="databricks/dbrx-instruct",
|
||||
)
|
||||
model_name = "nvidia_nim/databricks/dbrx-instruct"
|
||||
client = OpenAI(
|
||||
api_key="fake-api-key",
|
||||
)
|
||||
|
||||
mock_request = respx_mock.post(
|
||||
"https://integrate.api.nvidia.com/v1/chat/completions"
|
||||
).mock(return_value=httpx.Response(200, json=mock_response.dict()))
|
||||
try:
|
||||
response = completion(
|
||||
model=model_name,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What's the weather like in Boston today in Fahrenheit?",
|
||||
}
|
||||
],
|
||||
presence_penalty=0.5,
|
||||
frequency_penalty=0.1,
|
||||
)
|
||||
with patch.object(
|
||||
client.chat.completions.with_raw_response, "create"
|
||||
) as mock_client:
|
||||
try:
|
||||
completion(
|
||||
model=model_name,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What's the weather like in Boston today in Fahrenheit?",
|
||||
}
|
||||
],
|
||||
presence_penalty=0.5,
|
||||
frequency_penalty=0.1,
|
||||
client=client,
|
||||
)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
# Add any assertions here to check the response
|
||||
print(response)
|
||||
assert response.choices[0].message.content is not None
|
||||
assert len(response.choices[0].message.content) > 0
|
||||
|
||||
assert mock_request.called
|
||||
request_body = json.loads(mock_request.calls[0].request.content)
|
||||
mock_client.assert_called_once()
|
||||
request_body = mock_client.call_args.kwargs
|
||||
|
||||
print("request_body: ", request_body)
|
||||
|
||||
assert request_body == {
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What's the weather like in Boston today in Fahrenheit?",
|
||||
}
|
||||
],
|
||||
"model": "databricks/dbrx-instruct",
|
||||
"frequency_penalty": 0.1,
|
||||
"presence_penalty": 0.5,
|
||||
}
|
||||
except litellm.exceptions.Timeout as e:
|
||||
pass
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
||||
def test_embedding_nvidia_nim(respx_mock: MockRouter):
|
||||
litellm.set_verbose = True
|
||||
mock_response = EmbeddingResponse(
|
||||
model="nvidia_nim/databricks/dbrx-instruct",
|
||||
data=[
|
||||
assert request_body["messages"] == [
|
||||
{
|
||||
"embedding": [0.1, 0.2, 0.3],
|
||||
"index": 0,
|
||||
}
|
||||
],
|
||||
usage=Usage(
|
||||
prompt_tokens=10,
|
||||
completion_tokens=0,
|
||||
total_tokens=10,
|
||||
),
|
||||
"role": "user",
|
||||
"content": "What's the weather like in Boston today in Fahrenheit?",
|
||||
},
|
||||
]
|
||||
assert request_body["model"] == "databricks/dbrx-instruct"
|
||||
assert request_body["frequency_penalty"] == 0.1
|
||||
assert request_body["presence_penalty"] == 0.5
|
||||
|
||||
|
||||
def test_embedding_nvidia_nim():
|
||||
litellm.set_verbose = True
|
||||
from openai import OpenAI
|
||||
|
||||
client = OpenAI(
|
||||
api_key="fake-api-key",
|
||||
)
|
||||
mock_request = respx_mock.post(
|
||||
"https://integrate.api.nvidia.com/v1/embeddings"
|
||||
).mock(return_value=httpx.Response(200, json=mock_response.dict()))
|
||||
response = litellm.embedding(
|
||||
model="nvidia_nim/nvidia/nv-embedqa-e5-v5",
|
||||
input="What is the meaning of life?",
|
||||
input_type="passage",
|
||||
)
|
||||
assert mock_request.called
|
||||
request_body = json.loads(mock_request.calls[0].request.content)
|
||||
print("request_body: ", request_body)
|
||||
assert request_body == {
|
||||
"input": "What is the meaning of life?",
|
||||
"model": "nvidia/nv-embedqa-e5-v5",
|
||||
"input_type": "passage",
|
||||
"encoding_format": "base64",
|
||||
}
|
||||
with patch.object(client.embeddings.with_raw_response, "create") as mock_client:
|
||||
try:
|
||||
litellm.embedding(
|
||||
model="nvidia_nim/nvidia/nv-embedqa-e5-v5",
|
||||
input="What is the meaning of life?",
|
||||
input_type="passage",
|
||||
client=client,
|
||||
)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
mock_client.assert_called_once()
|
||||
request_body = mock_client.call_args.kwargs
|
||||
print("request_body: ", request_body)
|
||||
assert request_body["input"] == "What is the meaning of life?"
|
||||
assert request_body["model"] == "nvidia/nv-embedqa-e5-v5"
|
||||
assert request_body["extra_body"]["input_type"] == "passage"
|
||||
|
|
|
@ -2,7 +2,7 @@ import json
|
|||
import os
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from unittest.mock import AsyncMock
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
|
@ -63,8 +63,7 @@ def test_openai_prediction_param():
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.respx
|
||||
async def test_openai_prediction_param_mock(respx_mock: MockRouter):
|
||||
async def test_openai_prediction_param_mock():
|
||||
"""
|
||||
Tests that prediction parameter is correctly passed to the API
|
||||
"""
|
||||
|
@ -92,60 +91,36 @@ async def test_openai_prediction_param_mock(respx_mock: MockRouter):
|
|||
public string Username { get; set; }
|
||||
}
|
||||
"""
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
mock_response = ModelResponse(
|
||||
id="chatcmpl-AQ5RmV8GvVSRxEcDxnuXlQnsibiY9",
|
||||
choices=[
|
||||
Choices(
|
||||
message=Message(
|
||||
content=code.replace("Username", "Email").replace(
|
||||
"username", "email"
|
||||
),
|
||||
role="assistant",
|
||||
)
|
||||
client = AsyncOpenAI(api_key="fake-api-key")
|
||||
|
||||
with patch.object(
|
||||
client.chat.completions.with_raw_response, "create"
|
||||
) as mock_client:
|
||||
try:
|
||||
await litellm.acompletion(
|
||||
model="gpt-4o-mini",
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Replace the Username property with an Email property. Respond only with code, and with no markdown formatting.",
|
||||
},
|
||||
{"role": "user", "content": code},
|
||||
],
|
||||
prediction={"type": "content", "content": code},
|
||||
client=client,
|
||||
)
|
||||
],
|
||||
created=int(datetime.now().timestamp()),
|
||||
model="gpt-4o-mini-2024-07-18",
|
||||
usage={
|
||||
"completion_tokens": 207,
|
||||
"prompt_tokens": 175,
|
||||
"total_tokens": 382,
|
||||
"completion_tokens_details": {
|
||||
"accepted_prediction_tokens": 0,
|
||||
"reasoning_tokens": 0,
|
||||
"rejected_prediction_tokens": 80,
|
||||
},
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
|
||||
mock_request = respx_mock.post("https://api.openai.com/v1/chat/completions").mock(
|
||||
return_value=httpx.Response(200, json=mock_response.dict())
|
||||
)
|
||||
mock_client.assert_called_once()
|
||||
request_body = mock_client.call_args.kwargs
|
||||
|
||||
completion = await litellm.acompletion(
|
||||
model="gpt-4o-mini",
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Replace the Username property with an Email property. Respond only with code, and with no markdown formatting.",
|
||||
},
|
||||
{"role": "user", "content": code},
|
||||
],
|
||||
prediction={"type": "content", "content": code},
|
||||
)
|
||||
|
||||
assert mock_request.called
|
||||
request_body = json.loads(mock_request.calls[0].request.content)
|
||||
|
||||
# Verify the request contains the prediction parameter
|
||||
assert "prediction" in request_body
|
||||
# verify prediction is correctly sent to the API
|
||||
assert request_body["prediction"] == {"type": "content", "content": code}
|
||||
|
||||
# Verify the completion tokens details
|
||||
assert completion.usage.completion_tokens_details.accepted_prediction_tokens == 0
|
||||
assert completion.usage.completion_tokens_details.rejected_prediction_tokens == 80
|
||||
# Verify the request contains the prediction parameter
|
||||
assert "prediction" in request_body
|
||||
# verify prediction is correctly sent to the API
|
||||
assert request_body["prediction"] == {"type": "content", "content": code}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
@ -223,3 +198,73 @@ async def test_openai_prediction_param_with_caching():
|
|||
)
|
||||
|
||||
assert completion_response_3.id != completion_response_1.id
|
||||
|
||||
|
||||
@pytest.mark.asyncio()
|
||||
async def test_vision_with_custom_model():
|
||||
"""
|
||||
Tests that an OpenAI compatible endpoint when sent an image will receive the image in the request
|
||||
|
||||
"""
|
||||
import base64
|
||||
import requests
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
client = AsyncOpenAI(api_key="fake-api-key")
|
||||
|
||||
litellm.set_verbose = True
|
||||
api_base = "https://my-custom.api.openai.com"
|
||||
|
||||
# Fetch and encode a test image
|
||||
url = "https://dummyimage.com/100/100/fff&text=Test+image"
|
||||
response = requests.get(url)
|
||||
file_data = response.content
|
||||
encoded_file = base64.b64encode(file_data).decode("utf-8")
|
||||
base64_image = f"data:image/png;base64,{encoded_file}"
|
||||
|
||||
with patch.object(
|
||||
client.chat.completions.with_raw_response, "create"
|
||||
) as mock_client:
|
||||
try:
|
||||
response = await litellm.acompletion(
|
||||
model="openai/my-custom-model",
|
||||
max_tokens=10,
|
||||
api_base=api_base, # use the mock api
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "What's in this image?"},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": base64_image},
|
||||
},
|
||||
],
|
||||
}
|
||||
],
|
||||
client=client,
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
|
||||
mock_client.assert_called_once()
|
||||
request_body = mock_client.call_args.kwargs
|
||||
|
||||
print("request_body: ", request_body)
|
||||
|
||||
assert request_body["messages"] == [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "What's in this image?"},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAGQAAABkBAMAAACCzIhnAAAAG1BMVEURAAD///+ln5/h39/Dv79qX18uHx+If39MPz9oMSdmAAAACXBIWXMAAA7EAAAOxAGVKw4bAAABB0lEQVRYhe2SzWrEIBCAh2A0jxEs4j6GLDS9hqWmV5Flt0cJS+lRwv742DXpEjY1kOZW6HwHFZnPmVEBEARBEARB/jd0KYA/bcUYbPrRLh6amXHJ/K+ypMoyUaGthILzw0l+xI0jsO7ZcmCcm4ILd+QuVYgpHOmDmz6jBeJImdcUCmeBqQpuqRIbVmQsLCrAalrGpfoEqEogqbLTWuXCPCo+Ki1XGqgQ+jVVuhB8bOaHkvmYuzm/b0KYLWwoK58oFqi6XfxQ4Uz7d6WeKpna6ytUs5e8betMcqAv5YPC5EZB2Lm9FIn0/VP6R58+/GEY1X1egVoZ/3bt/EqF6malgSAIgiDIH+QL41409QMY0LMAAAAASUVORK5CYII="
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
]
|
||||
assert request_body["model"] == "my-custom-model"
|
||||
assert request_body["max_tokens"] == 10
|
|
@ -2,7 +2,7 @@ import json
|
|||
import os
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from unittest.mock import AsyncMock
|
||||
from unittest.mock import AsyncMock, patch, MagicMock
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
|
@ -18,87 +18,75 @@ from litellm import Choices, Message, ModelResponse
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.respx
|
||||
async def test_o1_handle_system_role(respx_mock: MockRouter):
|
||||
async def test_o1_handle_system_role():
|
||||
"""
|
||||
Tests that:
|
||||
- max_tokens is translated to 'max_completion_tokens'
|
||||
- role 'system' is translated to 'user'
|
||||
"""
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
litellm.set_verbose = True
|
||||
|
||||
mock_response = ModelResponse(
|
||||
id="cmpl-mock",
|
||||
choices=[Choices(message=Message(content="Mocked response", role="assistant"))],
|
||||
created=int(datetime.now().timestamp()),
|
||||
model="o1-preview",
|
||||
)
|
||||
client = AsyncOpenAI(api_key="fake-api-key")
|
||||
|
||||
mock_request = respx_mock.post("https://api.openai.com/v1/chat/completions").mock(
|
||||
return_value=httpx.Response(200, json=mock_response.dict())
|
||||
)
|
||||
with patch.object(
|
||||
client.chat.completions.with_raw_response, "create"
|
||||
) as mock_client:
|
||||
try:
|
||||
await litellm.acompletion(
|
||||
model="o1-preview",
|
||||
max_tokens=10,
|
||||
messages=[{"role": "system", "content": "Hello!"}],
|
||||
client=client,
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
|
||||
response = await litellm.acompletion(
|
||||
model="o1-preview",
|
||||
max_tokens=10,
|
||||
messages=[{"role": "system", "content": "Hello!"}],
|
||||
)
|
||||
mock_client.assert_called_once()
|
||||
request_body = mock_client.call_args.kwargs
|
||||
|
||||
assert mock_request.called
|
||||
request_body = json.loads(mock_request.calls[0].request.content)
|
||||
print("request_body: ", request_body)
|
||||
|
||||
print("request_body: ", request_body)
|
||||
|
||||
assert request_body == {
|
||||
"model": "o1-preview",
|
||||
"max_completion_tokens": 10,
|
||||
"messages": [{"role": "user", "content": "Hello!"}],
|
||||
}
|
||||
|
||||
print(f"response: {response}")
|
||||
assert isinstance(response, ModelResponse)
|
||||
assert request_body["model"] == "o1-preview"
|
||||
assert request_body["max_completion_tokens"] == 10
|
||||
assert request_body["messages"] == [{"role": "user", "content": "Hello!"}]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.respx
|
||||
@pytest.mark.parametrize("model", ["gpt-4", "gpt-4-0314", "gpt-4-32k", "o1-preview"])
|
||||
async def test_o1_max_completion_tokens(respx_mock: MockRouter, model: str):
|
||||
async def test_o1_max_completion_tokens(model: str):
|
||||
"""
|
||||
Tests that:
|
||||
- max_completion_tokens is passed directly to OpenAI chat completion models
|
||||
"""
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
litellm.set_verbose = True
|
||||
|
||||
mock_response = ModelResponse(
|
||||
id="cmpl-mock",
|
||||
choices=[Choices(message=Message(content="Mocked response", role="assistant"))],
|
||||
created=int(datetime.now().timestamp()),
|
||||
model=model,
|
||||
)
|
||||
client = AsyncOpenAI(api_key="fake-api-key")
|
||||
|
||||
mock_request = respx_mock.post("https://api.openai.com/v1/chat/completions").mock(
|
||||
return_value=httpx.Response(200, json=mock_response.dict())
|
||||
)
|
||||
with patch.object(
|
||||
client.chat.completions.with_raw_response, "create"
|
||||
) as mock_client:
|
||||
try:
|
||||
await litellm.acompletion(
|
||||
model=model,
|
||||
max_completion_tokens=10,
|
||||
messages=[{"role": "user", "content": "Hello!"}],
|
||||
client=client,
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
|
||||
response = await litellm.acompletion(
|
||||
model=model,
|
||||
max_completion_tokens=10,
|
||||
messages=[{"role": "user", "content": "Hello!"}],
|
||||
)
|
||||
mock_client.assert_called_once()
|
||||
request_body = mock_client.call_args.kwargs
|
||||
|
||||
assert mock_request.called
|
||||
request_body = json.loads(mock_request.calls[0].request.content)
|
||||
print("request_body: ", request_body)
|
||||
|
||||
print("request_body: ", request_body)
|
||||
|
||||
assert request_body == {
|
||||
"model": model,
|
||||
"max_completion_tokens": 10,
|
||||
"messages": [{"role": "user", "content": "Hello!"}],
|
||||
}
|
||||
|
||||
print(f"response: {response}")
|
||||
assert isinstance(response, ModelResponse)
|
||||
assert request_body["model"] == model
|
||||
assert request_body["max_completion_tokens"] == 10
|
||||
assert request_body["messages"] == [{"role": "user", "content": "Hello!"}]
|
||||
|
||||
|
||||
def test_litellm_responses():
|
||||
|
|
|
@ -1,94 +0,0 @@
|
|||
import json
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system path
|
||||
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from respx import MockRouter
|
||||
|
||||
import litellm
|
||||
from litellm import Choices, Message, ModelResponse
|
||||
|
||||
|
||||
@pytest.mark.asyncio()
|
||||
@pytest.mark.respx
|
||||
async def test_vision_with_custom_model(respx_mock: MockRouter):
|
||||
"""
|
||||
Tests that an OpenAI compatible endpoint when sent an image will receive the image in the request
|
||||
|
||||
"""
|
||||
import base64
|
||||
import requests
|
||||
|
||||
litellm.set_verbose = True
|
||||
api_base = "https://my-custom.api.openai.com"
|
||||
|
||||
# Fetch and encode a test image
|
||||
url = "https://dummyimage.com/100/100/fff&text=Test+image"
|
||||
response = requests.get(url)
|
||||
file_data = response.content
|
||||
encoded_file = base64.b64encode(file_data).decode("utf-8")
|
||||
base64_image = f"data:image/png;base64,{encoded_file}"
|
||||
|
||||
mock_response = ModelResponse(
|
||||
id="cmpl-mock",
|
||||
choices=[Choices(message=Message(content="Mocked response", role="assistant"))],
|
||||
created=int(datetime.now().timestamp()),
|
||||
model="my-custom-model",
|
||||
)
|
||||
|
||||
mock_request = respx_mock.post(f"{api_base}/chat/completions").mock(
|
||||
return_value=httpx.Response(200, json=mock_response.dict())
|
||||
)
|
||||
|
||||
response = await litellm.acompletion(
|
||||
model="openai/my-custom-model",
|
||||
max_tokens=10,
|
||||
api_base=api_base, # use the mock api
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "What's in this image?"},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": base64_image},
|
||||
},
|
||||
],
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
assert mock_request.called
|
||||
request_body = json.loads(mock_request.calls[0].request.content)
|
||||
|
||||
print("request_body: ", request_body)
|
||||
|
||||
assert request_body == {
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "What's in this image?"},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAGQAAABkBAMAAACCzIhnAAAAG1BMVEURAAD///+ln5/h39/Dv79qX18uHx+If39MPz9oMSdmAAAACXBIWXMAAA7EAAAOxAGVKw4bAAABB0lEQVRYhe2SzWrEIBCAh2A0jxEs4j6GLDS9hqWmV5Flt0cJS+lRwv742DXpEjY1kOZW6HwHFZnPmVEBEARBEARB/jd0KYA/bcUYbPrRLh6amXHJ/K+ypMoyUaGthILzw0l+xI0jsO7ZcmCcm4ILd+QuVYgpHOmDmz6jBeJImdcUCmeBqQpuqRIbVmQsLCrAalrGpfoEqEogqbLTWuXCPCo+Ki1XGqgQ+jVVuhB8bOaHkvmYuzm/b0KYLWwoK58oFqi6XfxQ4Uz7d6WeKpna6ytUs5e8betMcqAv5YPC5EZB2Lm9FIn0/VP6R58+/GEY1X1egVoZ/3bt/EqF6malgSAIgiDIH+QL41409QMY0LMAAAAASUVORK5CYII="
|
||||
},
|
||||
},
|
||||
],
|
||||
}
|
||||
],
|
||||
"model": "my-custom-model",
|
||||
"max_tokens": 10,
|
||||
}
|
||||
|
||||
print(f"response: {response}")
|
||||
assert isinstance(response, ModelResponse)
|
|
@ -6,6 +6,7 @@ from unittest.mock import AsyncMock
|
|||
import pytest
|
||||
import httpx
|
||||
from respx import MockRouter
|
||||
from unittest.mock import patch, MagicMock, AsyncMock
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
|
@ -68,13 +69,16 @@ def test_convert_dict_to_text_completion_response():
|
|||
assert response.choices[0].logprobs.top_logprobs == [None, {",": -2.1568563}]
|
||||
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="need to migrate huggingface to support httpx client being passed in"
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.respx
|
||||
async def test_huggingface_text_completion_logprobs(respx_mock: MockRouter):
|
||||
async def test_huggingface_text_completion_logprobs():
|
||||
"""Test text completion with Hugging Face, focusing on logprobs structure"""
|
||||
litellm.set_verbose = True
|
||||
from litellm.llms.custom_httpx.http_handler import HTTPHandler, AsyncHTTPHandler
|
||||
|
||||
# Mock the raw response from Hugging Face
|
||||
mock_response = [
|
||||
{
|
||||
"generated_text": ",\n\nI have a question...", # truncated for brevity
|
||||
|
@ -91,46 +95,48 @@ async def test_huggingface_text_completion_logprobs(respx_mock: MockRouter):
|
|||
}
|
||||
]
|
||||
|
||||
# Mock the API request
|
||||
mock_request = respx_mock.post(
|
||||
"https://api-inference.huggingface.co/models/mistralai/Mistral-7B-v0.1"
|
||||
).mock(return_value=httpx.Response(200, json=mock_response))
|
||||
return_val = AsyncMock()
|
||||
|
||||
response = await litellm.atext_completion(
|
||||
model="huggingface/mistralai/Mistral-7B-v0.1",
|
||||
prompt="good morning",
|
||||
)
|
||||
return_val.json.return_value = mock_response
|
||||
|
||||
# Verify the request
|
||||
assert mock_request.called
|
||||
request_body = json.loads(mock_request.calls[0].request.content)
|
||||
assert request_body == {
|
||||
"inputs": "good morning",
|
||||
"parameters": {"details": True, "return_full_text": False},
|
||||
"stream": False,
|
||||
}
|
||||
client = AsyncHTTPHandler()
|
||||
with patch.object(client, "post", return_value=return_val) as mock_post:
|
||||
response = await litellm.atext_completion(
|
||||
model="huggingface/mistralai/Mistral-7B-v0.1",
|
||||
prompt="good morning",
|
||||
client=client,
|
||||
)
|
||||
|
||||
print("response=", response)
|
||||
# Verify the request
|
||||
mock_post.assert_called_once()
|
||||
request_body = json.loads(mock_post.call_args.kwargs["data"])
|
||||
assert request_body == {
|
||||
"inputs": "good morning",
|
||||
"parameters": {"details": True, "return_full_text": False},
|
||||
"stream": False,
|
||||
}
|
||||
|
||||
# Verify response structure
|
||||
assert isinstance(response, TextCompletionResponse)
|
||||
assert response.object == "text_completion"
|
||||
assert response.model == "mistralai/Mistral-7B-v0.1"
|
||||
print("response=", response)
|
||||
|
||||
# Verify logprobs structure
|
||||
choice = response.choices[0]
|
||||
assert choice.finish_reason == "length"
|
||||
assert choice.index == 0
|
||||
assert isinstance(choice.logprobs.tokens, list)
|
||||
assert isinstance(choice.logprobs.token_logprobs, list)
|
||||
assert isinstance(choice.logprobs.text_offset, list)
|
||||
assert isinstance(choice.logprobs.top_logprobs, list)
|
||||
assert choice.logprobs.tokens == [",", "\n"]
|
||||
assert choice.logprobs.token_logprobs == [-1.7626953, -1.7314453]
|
||||
assert choice.logprobs.text_offset == [0, 1]
|
||||
assert choice.logprobs.top_logprobs == [{}, {}]
|
||||
# Verify response structure
|
||||
assert isinstance(response, TextCompletionResponse)
|
||||
assert response.object == "text_completion"
|
||||
assert response.model == "mistralai/Mistral-7B-v0.1"
|
||||
|
||||
# Verify usage
|
||||
assert response.usage["completion_tokens"] > 0
|
||||
assert response.usage["prompt_tokens"] > 0
|
||||
assert response.usage["total_tokens"] > 0
|
||||
# Verify logprobs structure
|
||||
choice = response.choices[0]
|
||||
assert choice.finish_reason == "length"
|
||||
assert choice.index == 0
|
||||
assert isinstance(choice.logprobs.tokens, list)
|
||||
assert isinstance(choice.logprobs.token_logprobs, list)
|
||||
assert isinstance(choice.logprobs.text_offset, list)
|
||||
assert isinstance(choice.logprobs.top_logprobs, list)
|
||||
assert choice.logprobs.tokens == [",", "\n"]
|
||||
assert choice.logprobs.token_logprobs == [-1.7626953, -1.7314453]
|
||||
assert choice.logprobs.text_offset == [0, 1]
|
||||
assert choice.logprobs.top_logprobs == [{}, {}]
|
||||
|
||||
# Verify usage
|
||||
assert response.usage["completion_tokens"] > 0
|
||||
assert response.usage["prompt_tokens"] > 0
|
||||
assert response.usage["total_tokens"] > 0
|
||||
|
|
|
@ -1146,6 +1146,21 @@ def test_process_gemini_image():
|
|||
mime_type="image/png", file_uri="https://example.com/image.png"
|
||||
)
|
||||
|
||||
# Test HTTPS VIDEO URL
|
||||
https_result = _process_gemini_image("https://cloud-samples-data/video/animals.mp4")
|
||||
print("https_result PNG", https_result)
|
||||
assert https_result["file_data"] == FileDataType(
|
||||
mime_type="video/mp4", file_uri="https://cloud-samples-data/video/animals.mp4"
|
||||
)
|
||||
|
||||
# Test HTTPS PDF URL
|
||||
https_result = _process_gemini_image("https://cloud-samples-data/pdf/animals.pdf")
|
||||
print("https_result PDF", https_result)
|
||||
assert https_result["file_data"] == FileDataType(
|
||||
mime_type="application/pdf",
|
||||
file_uri="https://cloud-samples-data/pdf/animals.pdf",
|
||||
)
|
||||
|
||||
# Test base64 image
|
||||
base64_image = "data:image/jpeg;base64,/9j/4AAQSkZJRg..."
|
||||
base64_result = _process_gemini_image(base64_image)
|
||||
|
|
|
@ -95,3 +95,107 @@ async def test_handle_failed_db_connection():
|
|||
print("_handle_failed_db_connection_for_get_key_object got exception", exc_info)
|
||||
|
||||
assert str(exc_info.value) == "Failed to connect to DB"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model, expect_to_work",
|
||||
[("openai/gpt-4o-mini", True), ("openai/gpt-4o", False)],
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_can_key_call_model(model, expect_to_work):
|
||||
"""
|
||||
If wildcard model + specific model is used, choose the specific model settings
|
||||
"""
|
||||
from litellm.proxy.auth.auth_checks import can_key_call_model
|
||||
from fastapi import HTTPException
|
||||
|
||||
llm_model_list = [
|
||||
{
|
||||
"model_name": "openai/*",
|
||||
"litellm_params": {
|
||||
"model": "openai/*",
|
||||
"api_key": "test-api-key",
|
||||
},
|
||||
"model_info": {
|
||||
"id": "e6e7006f83029df40ebc02ddd068890253f4cd3092bcb203d3d8e6f6f606f30f",
|
||||
"db_model": False,
|
||||
"access_groups": ["public-openai-models"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"model_name": "openai/gpt-4o",
|
||||
"litellm_params": {
|
||||
"model": "openai/gpt-4o",
|
||||
"api_key": "test-api-key",
|
||||
},
|
||||
"model_info": {
|
||||
"id": "0cfcd87f2cb12a783a466888d05c6c89df66db23e01cecd75ec0b83aed73c9ad",
|
||||
"db_model": False,
|
||||
"access_groups": ["private-openai-models"],
|
||||
},
|
||||
},
|
||||
]
|
||||
router = litellm.Router(model_list=llm_model_list)
|
||||
args = {
|
||||
"model": model,
|
||||
"llm_model_list": llm_model_list,
|
||||
"valid_token": UserAPIKeyAuth(
|
||||
models=["public-openai-models"],
|
||||
),
|
||||
"llm_router": router,
|
||||
}
|
||||
if expect_to_work:
|
||||
await can_key_call_model(**args)
|
||||
else:
|
||||
with pytest.raises(Exception) as e:
|
||||
await can_key_call_model(**args)
|
||||
|
||||
print(e)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model, expect_to_work",
|
||||
[("openai/gpt-4o", False), ("openai/gpt-4o-mini", True)],
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_can_team_call_model(model, expect_to_work):
|
||||
from litellm.proxy.auth.auth_checks import model_in_access_group
|
||||
from fastapi import HTTPException
|
||||
|
||||
llm_model_list = [
|
||||
{
|
||||
"model_name": "openai/*",
|
||||
"litellm_params": {
|
||||
"model": "openai/*",
|
||||
"api_key": "test-api-key",
|
||||
},
|
||||
"model_info": {
|
||||
"id": "e6e7006f83029df40ebc02ddd068890253f4cd3092bcb203d3d8e6f6f606f30f",
|
||||
"db_model": False,
|
||||
"access_groups": ["public-openai-models"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"model_name": "openai/gpt-4o",
|
||||
"litellm_params": {
|
||||
"model": "openai/gpt-4o",
|
||||
"api_key": "test-api-key",
|
||||
},
|
||||
"model_info": {
|
||||
"id": "0cfcd87f2cb12a783a466888d05c6c89df66db23e01cecd75ec0b83aed73c9ad",
|
||||
"db_model": False,
|
||||
"access_groups": ["private-openai-models"],
|
||||
},
|
||||
},
|
||||
]
|
||||
router = litellm.Router(model_list=llm_model_list)
|
||||
|
||||
args = {
|
||||
"model": model,
|
||||
"team_models": ["public-openai-models"],
|
||||
"llm_router": router,
|
||||
}
|
||||
if expect_to_work:
|
||||
assert model_in_access_group(**args)
|
||||
else:
|
||||
assert not model_in_access_group(**args)
|
||||
|
|
|
@ -33,7 +33,7 @@ from litellm.router import Router
|
|||
|
||||
@pytest.mark.asyncio()
|
||||
@pytest.mark.respx()
|
||||
async def test_azure_tenant_id_auth(respx_mock: MockRouter):
|
||||
async def test_aaaaazure_tenant_id_auth(respx_mock: MockRouter):
|
||||
"""
|
||||
|
||||
Tests when we set tenant_id, client_id, client_secret they don't get sent with the request
|
||||
|
|
|
@ -1,128 +1,128 @@
|
|||
#### What this tests ####
|
||||
# This adds perf testing to the router, to ensure it's never > 50ms slower than the azure-openai sdk.
|
||||
import sys, os, time, inspect, asyncio, traceback
|
||||
from datetime import datetime
|
||||
import pytest
|
||||
# #### What this tests ####
|
||||
# # This adds perf testing to the router, to ensure it's never > 50ms slower than the azure-openai sdk.
|
||||
# import sys, os, time, inspect, asyncio, traceback
|
||||
# from datetime import datetime
|
||||
# import pytest
|
||||
|
||||
sys.path.insert(0, os.path.abspath("../.."))
|
||||
import openai, litellm, uuid
|
||||
from openai import AsyncAzureOpenAI
|
||||
# sys.path.insert(0, os.path.abspath("../.."))
|
||||
# import openai, litellm, uuid
|
||||
# from openai import AsyncAzureOpenAI
|
||||
|
||||
client = AsyncAzureOpenAI(
|
||||
api_key=os.getenv("AZURE_API_KEY"),
|
||||
azure_endpoint=os.getenv("AZURE_API_BASE"), # type: ignore
|
||||
api_version=os.getenv("AZURE_API_VERSION"),
|
||||
)
|
||||
# client = AsyncAzureOpenAI(
|
||||
# api_key=os.getenv("AZURE_API_KEY"),
|
||||
# azure_endpoint=os.getenv("AZURE_API_BASE"), # type: ignore
|
||||
# api_version=os.getenv("AZURE_API_VERSION"),
|
||||
# )
|
||||
|
||||
model_list = [
|
||||
{
|
||||
"model_name": "azure-test",
|
||||
"litellm_params": {
|
||||
"model": "azure/chatgpt-v-2",
|
||||
"api_key": os.getenv("AZURE_API_KEY"),
|
||||
"api_base": os.getenv("AZURE_API_BASE"),
|
||||
"api_version": os.getenv("AZURE_API_VERSION"),
|
||||
},
|
||||
}
|
||||
]
|
||||
# model_list = [
|
||||
# {
|
||||
# "model_name": "azure-test",
|
||||
# "litellm_params": {
|
||||
# "model": "azure/chatgpt-v-2",
|
||||
# "api_key": os.getenv("AZURE_API_KEY"),
|
||||
# "api_base": os.getenv("AZURE_API_BASE"),
|
||||
# "api_version": os.getenv("AZURE_API_VERSION"),
|
||||
# },
|
||||
# }
|
||||
# ]
|
||||
|
||||
router = litellm.Router(model_list=model_list) # type: ignore
|
||||
# router = litellm.Router(model_list=model_list) # type: ignore
|
||||
|
||||
|
||||
async def _openai_completion():
|
||||
try:
|
||||
start_time = time.time()
|
||||
response = await client.chat.completions.create(
|
||||
model="chatgpt-v-2",
|
||||
messages=[{"role": "user", "content": f"This is a test: {uuid.uuid4()}"}],
|
||||
stream=True,
|
||||
)
|
||||
time_to_first_token = None
|
||||
first_token_ts = None
|
||||
init_chunk = None
|
||||
async for chunk in response:
|
||||
if (
|
||||
time_to_first_token is None
|
||||
and len(chunk.choices) > 0
|
||||
and chunk.choices[0].delta.content is not None
|
||||
):
|
||||
first_token_ts = time.time()
|
||||
time_to_first_token = first_token_ts - start_time
|
||||
init_chunk = chunk
|
||||
end_time = time.time()
|
||||
print(
|
||||
"OpenAI Call: ",
|
||||
init_chunk,
|
||||
start_time,
|
||||
first_token_ts,
|
||||
time_to_first_token,
|
||||
end_time,
|
||||
)
|
||||
return time_to_first_token
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return None
|
||||
# async def _openai_completion():
|
||||
# try:
|
||||
# start_time = time.time()
|
||||
# response = await client.chat.completions.create(
|
||||
# model="chatgpt-v-2",
|
||||
# messages=[{"role": "user", "content": f"This is a test: {uuid.uuid4()}"}],
|
||||
# stream=True,
|
||||
# )
|
||||
# time_to_first_token = None
|
||||
# first_token_ts = None
|
||||
# init_chunk = None
|
||||
# async for chunk in response:
|
||||
# if (
|
||||
# time_to_first_token is None
|
||||
# and len(chunk.choices) > 0
|
||||
# and chunk.choices[0].delta.content is not None
|
||||
# ):
|
||||
# first_token_ts = time.time()
|
||||
# time_to_first_token = first_token_ts - start_time
|
||||
# init_chunk = chunk
|
||||
# end_time = time.time()
|
||||
# print(
|
||||
# "OpenAI Call: ",
|
||||
# init_chunk,
|
||||
# start_time,
|
||||
# first_token_ts,
|
||||
# time_to_first_token,
|
||||
# end_time,
|
||||
# )
|
||||
# return time_to_first_token
|
||||
# except Exception as e:
|
||||
# print(e)
|
||||
# return None
|
||||
|
||||
|
||||
async def _router_completion():
|
||||
try:
|
||||
start_time = time.time()
|
||||
response = await router.acompletion(
|
||||
model="azure-test",
|
||||
messages=[{"role": "user", "content": f"This is a test: {uuid.uuid4()}"}],
|
||||
stream=True,
|
||||
)
|
||||
time_to_first_token = None
|
||||
first_token_ts = None
|
||||
init_chunk = None
|
||||
async for chunk in response:
|
||||
if (
|
||||
time_to_first_token is None
|
||||
and len(chunk.choices) > 0
|
||||
and chunk.choices[0].delta.content is not None
|
||||
):
|
||||
first_token_ts = time.time()
|
||||
time_to_first_token = first_token_ts - start_time
|
||||
init_chunk = chunk
|
||||
end_time = time.time()
|
||||
print(
|
||||
"Router Call: ",
|
||||
init_chunk,
|
||||
start_time,
|
||||
first_token_ts,
|
||||
time_to_first_token,
|
||||
end_time - first_token_ts,
|
||||
)
|
||||
return time_to_first_token
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return None
|
||||
# async def _router_completion():
|
||||
# try:
|
||||
# start_time = time.time()
|
||||
# response = await router.acompletion(
|
||||
# model="azure-test",
|
||||
# messages=[{"role": "user", "content": f"This is a test: {uuid.uuid4()}"}],
|
||||
# stream=True,
|
||||
# )
|
||||
# time_to_first_token = None
|
||||
# first_token_ts = None
|
||||
# init_chunk = None
|
||||
# async for chunk in response:
|
||||
# if (
|
||||
# time_to_first_token is None
|
||||
# and len(chunk.choices) > 0
|
||||
# and chunk.choices[0].delta.content is not None
|
||||
# ):
|
||||
# first_token_ts = time.time()
|
||||
# time_to_first_token = first_token_ts - start_time
|
||||
# init_chunk = chunk
|
||||
# end_time = time.time()
|
||||
# print(
|
||||
# "Router Call: ",
|
||||
# init_chunk,
|
||||
# start_time,
|
||||
# first_token_ts,
|
||||
# time_to_first_token,
|
||||
# end_time - first_token_ts,
|
||||
# )
|
||||
# return time_to_first_token
|
||||
# except Exception as e:
|
||||
# print(e)
|
||||
# return None
|
||||
|
||||
|
||||
async def test_azure_completion_streaming():
|
||||
"""
|
||||
Test azure streaming call - measure on time to first (non-null) token.
|
||||
"""
|
||||
n = 3 # Number of concurrent tasks
|
||||
## OPENAI AVG. TIME
|
||||
tasks = [_openai_completion() for _ in range(n)]
|
||||
chat_completions = await asyncio.gather(*tasks)
|
||||
successful_completions = [c for c in chat_completions if c is not None]
|
||||
total_time = 0
|
||||
for item in successful_completions:
|
||||
total_time += item
|
||||
avg_openai_time = total_time / 3
|
||||
## ROUTER AVG. TIME
|
||||
tasks = [_router_completion() for _ in range(n)]
|
||||
chat_completions = await asyncio.gather(*tasks)
|
||||
successful_completions = [c for c in chat_completions if c is not None]
|
||||
total_time = 0
|
||||
for item in successful_completions:
|
||||
total_time += item
|
||||
avg_router_time = total_time / 3
|
||||
## COMPARE
|
||||
print(f"avg_router_time: {avg_router_time}; avg_openai_time: {avg_openai_time}")
|
||||
assert avg_router_time < avg_openai_time + 0.5
|
||||
# async def test_azure_completion_streaming():
|
||||
# """
|
||||
# Test azure streaming call - measure on time to first (non-null) token.
|
||||
# """
|
||||
# n = 3 # Number of concurrent tasks
|
||||
# ## OPENAI AVG. TIME
|
||||
# tasks = [_openai_completion() for _ in range(n)]
|
||||
# chat_completions = await asyncio.gather(*tasks)
|
||||
# successful_completions = [c for c in chat_completions if c is not None]
|
||||
# total_time = 0
|
||||
# for item in successful_completions:
|
||||
# total_time += item
|
||||
# avg_openai_time = total_time / 3
|
||||
# ## ROUTER AVG. TIME
|
||||
# tasks = [_router_completion() for _ in range(n)]
|
||||
# chat_completions = await asyncio.gather(*tasks)
|
||||
# successful_completions = [c for c in chat_completions if c is not None]
|
||||
# total_time = 0
|
||||
# for item in successful_completions:
|
||||
# total_time += item
|
||||
# avg_router_time = total_time / 3
|
||||
# ## COMPARE
|
||||
# print(f"avg_router_time: {avg_router_time}; avg_openai_time: {avg_openai_time}")
|
||||
# assert avg_router_time < avg_openai_time + 0.5
|
||||
|
||||
|
||||
# asyncio.run(test_azure_completion_streaming())
|
||||
# # asyncio.run(test_azure_completion_streaming())
|
||||
|
|
|
@ -1146,7 +1146,9 @@ async def test_exception_with_headers_httpx(
|
|||
|
||||
except litellm.RateLimitError as e:
|
||||
exception_raised = True
|
||||
assert e.litellm_response_headers is not None
|
||||
assert (
|
||||
e.litellm_response_headers is not None
|
||||
), "litellm_response_headers is None"
|
||||
print("e.litellm_response_headers", e.litellm_response_headers)
|
||||
assert int(e.litellm_response_headers["retry-after"]) == cooldown_time
|
||||
|
||||
|
|
|
@ -212,7 +212,7 @@ async def test_bedrock_guardrail_triggered():
|
|||
session,
|
||||
"sk-1234",
|
||||
model="fake-openai-endpoint",
|
||||
messages=[{"role": "user", "content": f"Hello do you like coffee?"}],
|
||||
messages=[{"role": "user", "content": "Hello do you like coffee?"}],
|
||||
guardrails=["bedrock-pre-guard"],
|
||||
)
|
||||
pytest.fail("Should have thrown an exception")
|
||||
|
|
71
tests/otel_tests/test_moderations.py
Normal file
71
tests/otel_tests/test_moderations.py
Normal file
|
@ -0,0 +1,71 @@
|
|||
import pytest
|
||||
import asyncio
|
||||
import aiohttp, openai
|
||||
from openai import OpenAI, AsyncOpenAI
|
||||
from typing import Optional, List, Union
|
||||
import uuid
|
||||
|
||||
|
||||
async def make_moderations_curl_request(
|
||||
session,
|
||||
key,
|
||||
request_data: dict,
|
||||
):
|
||||
url = "http://0.0.0.0:4000/moderations"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
async with session.post(url, headers=headers, json=request_data) as response:
|
||||
status = response.status
|
||||
response_text = await response.text()
|
||||
|
||||
if status != 200:
|
||||
raise Exception(response_text)
|
||||
|
||||
return await response.json()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_basic_moderations_on_proxy_no_model():
|
||||
"""
|
||||
Test moderations endpoint on proxy when no `model` is specified in the request
|
||||
"""
|
||||
async with aiohttp.ClientSession() as session:
|
||||
test_text = "I want to harm someone" # Test text that should trigger moderation
|
||||
request_data = {
|
||||
"input": test_text,
|
||||
}
|
||||
try:
|
||||
response = await make_moderations_curl_request(
|
||||
session,
|
||||
"sk-1234",
|
||||
request_data,
|
||||
)
|
||||
print("response=", response)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
pytest.fail("Moderations request failed")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_basic_moderations_on_proxy_with_model():
|
||||
"""
|
||||
Test moderations endpoint on proxy when `model` is specified in the request
|
||||
"""
|
||||
async with aiohttp.ClientSession() as session:
|
||||
test_text = "I want to harm someone" # Test text that should trigger moderation
|
||||
request_data = {
|
||||
"input": test_text,
|
||||
"model": "text-moderation-stable",
|
||||
}
|
||||
try:
|
||||
response = await make_moderations_curl_request(
|
||||
session,
|
||||
"sk-1234",
|
||||
request_data,
|
||||
)
|
||||
print("response=", response)
|
||||
except Exception as e:
|
||||
pytest.fail("Moderations request failed")
|
|
@ -693,3 +693,47 @@ def test_personal_key_generation_check():
|
|||
),
|
||||
data=GenerateKeyRequest(),
|
||||
)
|
||||
|
||||
|
||||
def test_prepare_metadata_fields():
|
||||
from litellm.proxy.management_endpoints.key_management_endpoints import (
|
||||
prepare_metadata_fields,
|
||||
)
|
||||
|
||||
new_metadata = {"test": "new"}
|
||||
old_metadata = {"test": "test"}
|
||||
|
||||
args = {
|
||||
"data": UpdateKeyRequest(
|
||||
key_alias=None,
|
||||
duration=None,
|
||||
models=[],
|
||||
spend=None,
|
||||
max_budget=None,
|
||||
user_id=None,
|
||||
team_id=None,
|
||||
max_parallel_requests=None,
|
||||
metadata=new_metadata,
|
||||
tpm_limit=None,
|
||||
rpm_limit=None,
|
||||
budget_duration=None,
|
||||
allowed_cache_controls=[],
|
||||
soft_budget=None,
|
||||
config={},
|
||||
permissions={},
|
||||
model_max_budget={},
|
||||
send_invite_email=None,
|
||||
model_rpm_limit=None,
|
||||
model_tpm_limit=None,
|
||||
guardrails=None,
|
||||
blocked=None,
|
||||
aliases={},
|
||||
key="sk-1qGQUJJTcljeaPfzgWRrXQ",
|
||||
tags=None,
|
||||
),
|
||||
"non_default_values": {"metadata": new_metadata},
|
||||
"existing_metadata": {"tags": None, **old_metadata},
|
||||
}
|
||||
|
||||
non_default_values = prepare_metadata_fields(**args)
|
||||
assert non_default_values == {"metadata": new_metadata}
|
||||
|
|
|
@ -0,0 +1,5 @@
|
|||
include:
|
||||
- included_models.yaml
|
||||
|
||||
litellm_settings:
|
||||
callbacks: ["prometheus"]
|
|
@ -0,0 +1,5 @@
|
|||
include:
|
||||
- non-existent-file.yaml
|
||||
|
||||
litellm_settings:
|
||||
callbacks: ["prometheus"]
|
|
@ -0,0 +1,6 @@
|
|||
include:
|
||||
- models_file_1.yaml
|
||||
- models_file_2.yaml
|
||||
|
||||
litellm_settings:
|
||||
callbacks: ["prometheus"]
|
|
@ -0,0 +1,4 @@
|
|||
model_list:
|
||||
- model_name: included-model
|
||||
litellm_params:
|
||||
model: gpt-4
|
|
@ -0,0 +1,4 @@
|
|||
model_list:
|
||||
- model_name: included-model-1
|
||||
litellm_params:
|
||||
model: gpt-4
|
|
@ -0,0 +1,4 @@
|
|||
model_list:
|
||||
- model_name: included-model-2
|
||||
litellm_params:
|
||||
model: gpt-3.5-turbo
|
|
@ -1345,17 +1345,8 @@ def test_generate_and_update_key(prisma_client):
|
|||
)
|
||||
current_time = datetime.now(timezone.utc)
|
||||
|
||||
print(
|
||||
"days between now and budget_reset_at",
|
||||
(budget_reset_at - current_time).days,
|
||||
)
|
||||
# assert budget_reset_at is 30 days from now
|
||||
assert (
|
||||
abs(
|
||||
(budget_reset_at - current_time).total_seconds() - 30 * 24 * 60 * 60
|
||||
)
|
||||
<= 10
|
||||
)
|
||||
assert 31 >= (budget_reset_at - current_time).days >= 29
|
||||
|
||||
# cleanup - delete key
|
||||
delete_key_request = KeyRequest(keys=[generated_key])
|
||||
|
@ -2926,7 +2917,6 @@ async def test_generate_key_with_model_tpm_limit(prisma_client):
|
|||
"team": "litellm-team3",
|
||||
"model_tpm_limit": {"gpt-4": 100},
|
||||
"model_rpm_limit": {"gpt-4": 2},
|
||||
"tags": None,
|
||||
}
|
||||
|
||||
# Update model tpm_limit and rpm_limit
|
||||
|
@ -2950,7 +2940,6 @@ async def test_generate_key_with_model_tpm_limit(prisma_client):
|
|||
"team": "litellm-team3",
|
||||
"model_tpm_limit": {"gpt-4": 200},
|
||||
"model_rpm_limit": {"gpt-4": 3},
|
||||
"tags": None,
|
||||
}
|
||||
|
||||
|
||||
|
@ -2990,7 +2979,6 @@ async def test_generate_key_with_guardrails(prisma_client):
|
|||
assert result["info"]["metadata"] == {
|
||||
"team": "litellm-team3",
|
||||
"guardrails": ["aporia-pre-call"],
|
||||
"tags": None,
|
||||
}
|
||||
|
||||
# Update model tpm_limit and rpm_limit
|
||||
|
@ -3012,7 +3000,6 @@ async def test_generate_key_with_guardrails(prisma_client):
|
|||
assert result["info"]["metadata"] == {
|
||||
"team": "litellm-team3",
|
||||
"guardrails": ["aporia-pre-call", "aporia-post-call"],
|
||||
"tags": None,
|
||||
}
|
||||
|
||||
|
||||
|
@ -3632,3 +3619,152 @@ async def test_key_generate_with_secret_manager_call(prisma_client):
|
|||
|
||||
|
||||
################################################################################
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_key_alias_uniqueness(prisma_client):
|
||||
"""
|
||||
Test that:
|
||||
1. We cannot create two keys with the same alias
|
||||
2. We cannot update a key to use an alias that's already taken
|
||||
3. We can update a key while keeping its existing alias
|
||||
"""
|
||||
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
|
||||
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
|
||||
await litellm.proxy.proxy_server.prisma_client.connect()
|
||||
|
||||
try:
|
||||
# Create first key with an alias
|
||||
unique_alias = f"test-alias-{uuid.uuid4()}"
|
||||
key1 = await generate_key_fn(
|
||||
data=GenerateKeyRequest(key_alias=unique_alias),
|
||||
user_api_key_dict=UserAPIKeyAuth(
|
||||
user_role=LitellmUserRoles.PROXY_ADMIN,
|
||||
api_key="sk-1234",
|
||||
user_id="1234",
|
||||
),
|
||||
)
|
||||
|
||||
# Try to create second key with same alias - should fail
|
||||
try:
|
||||
key2 = await generate_key_fn(
|
||||
data=GenerateKeyRequest(key_alias=unique_alias),
|
||||
user_api_key_dict=UserAPIKeyAuth(
|
||||
user_role=LitellmUserRoles.PROXY_ADMIN,
|
||||
api_key="sk-1234",
|
||||
user_id="1234",
|
||||
),
|
||||
)
|
||||
pytest.fail("Should not be able to create a second key with the same alias")
|
||||
except Exception as e:
|
||||
print("vars(e)=", vars(e))
|
||||
assert "Unique key aliases across all keys are required" in str(e.message)
|
||||
|
||||
# Create another key with different alias
|
||||
another_alias = f"test-alias-{uuid.uuid4()}"
|
||||
key3 = await generate_key_fn(
|
||||
data=GenerateKeyRequest(key_alias=another_alias),
|
||||
user_api_key_dict=UserAPIKeyAuth(
|
||||
user_role=LitellmUserRoles.PROXY_ADMIN,
|
||||
api_key="sk-1234",
|
||||
user_id="1234",
|
||||
),
|
||||
)
|
||||
|
||||
# Try to update key3 to use key1's alias - should fail
|
||||
try:
|
||||
await update_key_fn(
|
||||
data=UpdateKeyRequest(key=key3.key, key_alias=unique_alias),
|
||||
request=Request(scope={"type": "http"}),
|
||||
)
|
||||
pytest.fail("Should not be able to update a key to use an existing alias")
|
||||
except Exception as e:
|
||||
assert "Unique key aliases across all keys are required" in str(e.message)
|
||||
|
||||
# Update key1 with its own existing alias - should succeed
|
||||
updated_key = await update_key_fn(
|
||||
data=UpdateKeyRequest(key=key1.key, key_alias=unique_alias),
|
||||
request=Request(scope={"type": "http"}),
|
||||
)
|
||||
assert updated_key is not None
|
||||
|
||||
except Exception as e:
|
||||
print("got exceptions, e=", e)
|
||||
print("vars(e)=", vars(e))
|
||||
pytest.fail(f"An unexpected error occurred: {str(e)}")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_enforce_unique_key_alias(prisma_client):
|
||||
"""
|
||||
Unit test the _enforce_unique_key_alias function:
|
||||
1. Test it allows unique aliases
|
||||
2. Test it blocks duplicate aliases for new keys
|
||||
3. Test it allows updating a key with its own existing alias
|
||||
4. Test it blocks updating a key with another key's alias
|
||||
"""
|
||||
from litellm.proxy.management_endpoints.key_management_endpoints import (
|
||||
_enforce_unique_key_alias,
|
||||
)
|
||||
|
||||
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
|
||||
await litellm.proxy.proxy_server.prisma_client.connect()
|
||||
|
||||
try:
|
||||
# Test 1: Allow unique alias
|
||||
unique_alias = f"test-alias-{uuid.uuid4()}"
|
||||
await _enforce_unique_key_alias(
|
||||
key_alias=unique_alias,
|
||||
prisma_client=prisma_client,
|
||||
) # Should pass
|
||||
|
||||
# Create a key with this alias in the database
|
||||
key1 = await generate_key_fn(
|
||||
data=GenerateKeyRequest(key_alias=unique_alias),
|
||||
user_api_key_dict=UserAPIKeyAuth(
|
||||
user_role=LitellmUserRoles.PROXY_ADMIN,
|
||||
api_key="sk-1234",
|
||||
user_id="1234",
|
||||
),
|
||||
)
|
||||
|
||||
# Test 2: Block duplicate alias for new key
|
||||
try:
|
||||
await _enforce_unique_key_alias(
|
||||
key_alias=unique_alias,
|
||||
prisma_client=prisma_client,
|
||||
)
|
||||
pytest.fail("Should not allow duplicate alias")
|
||||
except Exception as e:
|
||||
assert "Unique key aliases across all keys are required" in str(e.message)
|
||||
|
||||
# Test 3: Allow updating key with its own alias
|
||||
await _enforce_unique_key_alias(
|
||||
key_alias=unique_alias,
|
||||
existing_key_token=hash_token(key1.key),
|
||||
prisma_client=prisma_client,
|
||||
) # Should pass
|
||||
|
||||
# Test 4: Block updating with another key's alias
|
||||
another_key = await generate_key_fn(
|
||||
data=GenerateKeyRequest(key_alias=f"test-alias-{uuid.uuid4()}"),
|
||||
user_api_key_dict=UserAPIKeyAuth(
|
||||
user_role=LitellmUserRoles.PROXY_ADMIN,
|
||||
api_key="sk-1234",
|
||||
user_id="1234",
|
||||
),
|
||||
)
|
||||
|
||||
try:
|
||||
await _enforce_unique_key_alias(
|
||||
key_alias=unique_alias,
|
||||
existing_key_token=another_key.key,
|
||||
prisma_client=prisma_client,
|
||||
)
|
||||
pytest.fail("Should not allow using another key's alias")
|
||||
except Exception as e:
|
||||
assert "Unique key aliases across all keys are required" in str(e.message)
|
||||
|
||||
except Exception as e:
|
||||
print("Unexpected error:", e)
|
||||
pytest.fail(f"An unexpected error occurred: {str(e)}")
|
||||
|
|
|
@ -23,6 +23,8 @@ import logging
|
|||
|
||||
from litellm.proxy.proxy_server import ProxyConfig
|
||||
|
||||
INVALID_FILES = ["config_with_missing_include.yaml"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_basic_reading_configs_from_files():
|
||||
|
@ -38,6 +40,9 @@ async def test_basic_reading_configs_from_files():
|
|||
print(files)
|
||||
|
||||
for file in files:
|
||||
if file in INVALID_FILES: # these are intentionally invalid files
|
||||
continue
|
||||
print("reading file=", file)
|
||||
config_path = os.path.join(example_config_yaml_path, file)
|
||||
config = await proxy_config_instance.get_config(config_file_path=config_path)
|
||||
print(config)
|
||||
|
@ -115,3 +120,67 @@ async def test_read_config_file_with_os_environ_vars():
|
|||
os.environ[key] = _old_env_vars[key]
|
||||
else:
|
||||
del os.environ[key]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_basic_include_directive():
|
||||
"""
|
||||
Test that the include directive correctly loads and merges configs
|
||||
"""
|
||||
proxy_config_instance = ProxyConfig()
|
||||
current_path = os.path.dirname(os.path.abspath(__file__))
|
||||
config_path = os.path.join(
|
||||
current_path, "example_config_yaml", "config_with_include.yaml"
|
||||
)
|
||||
|
||||
config = await proxy_config_instance.get_config(config_file_path=config_path)
|
||||
|
||||
# Verify the included model list was merged
|
||||
assert len(config["model_list"]) > 0
|
||||
assert any(
|
||||
model["model_name"] == "included-model" for model in config["model_list"]
|
||||
)
|
||||
|
||||
# Verify original config settings remain
|
||||
assert config["litellm_settings"]["callbacks"] == ["prometheus"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_include_file():
|
||||
"""
|
||||
Test that a missing included file raises FileNotFoundError
|
||||
"""
|
||||
proxy_config_instance = ProxyConfig()
|
||||
current_path = os.path.dirname(os.path.abspath(__file__))
|
||||
config_path = os.path.join(
|
||||
current_path, "example_config_yaml", "config_with_missing_include.yaml"
|
||||
)
|
||||
|
||||
with pytest.raises(FileNotFoundError):
|
||||
await proxy_config_instance.get_config(config_file_path=config_path)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_includes():
|
||||
"""
|
||||
Test that multiple files in the include list are all processed correctly
|
||||
"""
|
||||
proxy_config_instance = ProxyConfig()
|
||||
current_path = os.path.dirname(os.path.abspath(__file__))
|
||||
config_path = os.path.join(
|
||||
current_path, "example_config_yaml", "config_with_multiple_includes.yaml"
|
||||
)
|
||||
|
||||
config = await proxy_config_instance.get_config(config_file_path=config_path)
|
||||
|
||||
# Verify models from both included files are present
|
||||
assert len(config["model_list"]) == 2
|
||||
assert any(
|
||||
model["model_name"] == "included-model-1" for model in config["model_list"]
|
||||
)
|
||||
assert any(
|
||||
model["model_name"] == "included-model-2" for model in config["model_list"]
|
||||
)
|
||||
|
||||
# Verify original config settings remain
|
||||
assert config["litellm_settings"]["callbacks"] == ["prometheus"]
|
||||
|
|
|
@ -444,7 +444,7 @@ def test_foward_litellm_user_info_to_backend_llm_call():
|
|||
|
||||
def test_update_internal_user_params():
|
||||
from litellm.proxy.management_endpoints.internal_user_endpoints import (
|
||||
_update_internal_user_params,
|
||||
_update_internal_new_user_params,
|
||||
)
|
||||
from litellm.proxy._types import NewUserRequest
|
||||
|
||||
|
@ -456,7 +456,7 @@ def test_update_internal_user_params():
|
|||
|
||||
data = NewUserRequest(user_role="internal_user", user_email="krrish3@berri.ai")
|
||||
data_json = data.model_dump()
|
||||
updated_data_json = _update_internal_user_params(data_json, data)
|
||||
updated_data_json = _update_internal_new_user_params(data_json, data)
|
||||
assert updated_data_json["models"] == litellm.default_internal_user_params["models"]
|
||||
assert (
|
||||
updated_data_json["max_budget"]
|
||||
|
@ -530,7 +530,7 @@ def test_prepare_key_update_data():
|
|||
|
||||
data = UpdateKeyRequest(key="test_key", metadata=None)
|
||||
updated_data = prepare_key_update_data(data, existing_key_row)
|
||||
assert updated_data["metadata"] == None
|
||||
assert updated_data["metadata"] is None
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
@ -574,3 +574,108 @@ def test_get_docs_url(env_vars, expected_url):
|
|||
|
||||
result = _get_docs_url()
|
||||
assert result == expected_url
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"request_tags, tags_to_add, expected_tags",
|
||||
[
|
||||
(None, None, []), # both None
|
||||
(["tag1", "tag2"], None, ["tag1", "tag2"]), # tags_to_add is None
|
||||
(None, ["tag3", "tag4"], ["tag3", "tag4"]), # request_tags is None
|
||||
(
|
||||
["tag1", "tag2"],
|
||||
["tag3", "tag4"],
|
||||
["tag1", "tag2", "tag3", "tag4"],
|
||||
), # both have unique tags
|
||||
(
|
||||
["tag1", "tag2"],
|
||||
["tag2", "tag3"],
|
||||
["tag1", "tag2", "tag3"],
|
||||
), # overlapping tags
|
||||
([], [], []), # both empty lists
|
||||
("not_a_list", ["tag1"], ["tag1"]), # request_tags invalid type
|
||||
(["tag1"], "not_a_list", ["tag1"]), # tags_to_add invalid type
|
||||
(
|
||||
["tag1"],
|
||||
["tag1", "tag2"],
|
||||
["tag1", "tag2"],
|
||||
), # duplicate tags in inputs
|
||||
],
|
||||
)
|
||||
def test_merge_tags(request_tags, tags_to_add, expected_tags):
|
||||
from litellm.proxy.litellm_pre_call_utils import LiteLLMProxyRequestSetup
|
||||
|
||||
result = LiteLLMProxyRequestSetup._merge_tags(
|
||||
request_tags=request_tags, tags_to_add=tags_to_add
|
||||
)
|
||||
|
||||
assert isinstance(result, list)
|
||||
assert sorted(result) == sorted(expected_tags)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"key_tags, request_tags, expected_tags",
|
||||
[
|
||||
# exact duplicates
|
||||
(["tag1", "tag2", "tag3"], ["tag1", "tag2", "tag3"], ["tag1", "tag2", "tag3"]),
|
||||
# partial duplicates
|
||||
(
|
||||
["tag1", "tag2", "tag3"],
|
||||
["tag2", "tag3", "tag4"],
|
||||
["tag1", "tag2", "tag3", "tag4"],
|
||||
),
|
||||
# duplicates within key tags
|
||||
(["tag1", "tag2"], ["tag3", "tag4"], ["tag1", "tag2", "tag3", "tag4"]),
|
||||
# duplicates within request tags
|
||||
(["tag1", "tag2"], ["tag2", "tag3", "tag4"], ["tag1", "tag2", "tag3", "tag4"]),
|
||||
# case sensitive duplicates
|
||||
(["Tag1", "TAG2"], ["tag1", "tag2"], ["Tag1", "TAG2", "tag1", "tag2"]),
|
||||
],
|
||||
)
|
||||
async def test_add_litellm_data_to_request_duplicate_tags(
|
||||
key_tags, request_tags, expected_tags
|
||||
):
|
||||
"""
|
||||
Test to verify duplicate tags between request and key metadata are handled correctly
|
||||
|
||||
|
||||
Aggregation logic when checking spend can be impacted if duplicate tags are not handled correctly.
|
||||
|
||||
User feedback:
|
||||
"If I register my key with tag1 and
|
||||
also pass the same tag1 when using the key
|
||||
then I see tag1 twice in the
|
||||
LiteLLM_SpendLogs table request_tags column. This can mess up aggregation logic"
|
||||
"""
|
||||
mock_request = Mock(spec=Request)
|
||||
mock_request.url.path = "/chat/completions"
|
||||
mock_request.query_params = {}
|
||||
mock_request.headers = {}
|
||||
|
||||
# Setup key with tags in metadata
|
||||
user_api_key_dict = UserAPIKeyAuth(
|
||||
api_key="test_api_key",
|
||||
user_id="test_user_id",
|
||||
org_id="test_org_id",
|
||||
metadata={"tags": key_tags},
|
||||
)
|
||||
|
||||
# Setup request data with tags
|
||||
data = {"metadata": {"tags": request_tags}}
|
||||
|
||||
# Process request
|
||||
proxy_config = Mock()
|
||||
result = await add_litellm_data_to_request(
|
||||
data=data,
|
||||
request=mock_request,
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
proxy_config=proxy_config,
|
||||
)
|
||||
|
||||
# Verify results
|
||||
assert "metadata" in result
|
||||
assert "tags" in result["metadata"]
|
||||
assert sorted(result["metadata"]["tags"]) == sorted(
|
||||
expected_tags
|
||||
), f"Expected {expected_tags}, got {result['metadata']['tags']}"
|
||||
|
|
111
tests/proxy_unit_tests/test_unit_test_proxy_hooks.py
Normal file
111
tests/proxy_unit_tests/test_unit_test_proxy_hooks.py
Normal file
|
@ -0,0 +1,111 @@
|
|||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
from unittest.mock import Mock, patch, AsyncMock
|
||||
import pytest
|
||||
from fastapi import Request
|
||||
from litellm.proxy.utils import _get_redoc_url, _get_docs_url
|
||||
|
||||
sys.path.insert(0, os.path.abspath("../.."))
|
||||
import litellm
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_disable_error_logs():
|
||||
"""
|
||||
Test that the error logs are not written to the database when disable_error_logs is True
|
||||
"""
|
||||
# Mock the necessary components
|
||||
mock_prisma_client = AsyncMock()
|
||||
mock_general_settings = {"disable_error_logs": True}
|
||||
|
||||
with patch(
|
||||
"litellm.proxy.proxy_server.general_settings", mock_general_settings
|
||||
), patch("litellm.proxy.proxy_server.prisma_client", mock_prisma_client):
|
||||
|
||||
# Create a test exception
|
||||
test_exception = Exception("Test error")
|
||||
test_kwargs = {
|
||||
"model": "gpt-4",
|
||||
"exception": test_exception,
|
||||
"optional_params": {},
|
||||
"litellm_params": {"metadata": {}},
|
||||
}
|
||||
|
||||
# Call the failure handler
|
||||
from litellm.proxy.proxy_server import _PROXY_failure_handler
|
||||
|
||||
await _PROXY_failure_handler(
|
||||
kwargs=test_kwargs,
|
||||
completion_response=None,
|
||||
start_time="2024-01-01",
|
||||
end_time="2024-01-01",
|
||||
)
|
||||
|
||||
# Verify prisma client was not called to create error logs
|
||||
if hasattr(mock_prisma_client, "db"):
|
||||
assert not mock_prisma_client.db.litellm_errorlogs.create.called
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_disable_spend_logs():
|
||||
"""
|
||||
Test that the spend logs are not written to the database when disable_spend_logs is True
|
||||
"""
|
||||
# Mock the necessary components
|
||||
mock_prisma_client = Mock()
|
||||
mock_prisma_client.spend_log_transactions = []
|
||||
|
||||
with patch("litellm.proxy.proxy_server.disable_spend_logs", True), patch(
|
||||
"litellm.proxy.proxy_server.prisma_client", mock_prisma_client
|
||||
):
|
||||
from litellm.proxy.proxy_server import update_database
|
||||
|
||||
# Call update_database with disable_spend_logs=True
|
||||
await update_database(
|
||||
token="fake-token",
|
||||
response_cost=0.1,
|
||||
user_id="user123",
|
||||
completion_response=None,
|
||||
start_time="2024-01-01",
|
||||
end_time="2024-01-01",
|
||||
)
|
||||
# Verify no spend logs were added
|
||||
assert len(mock_prisma_client.spend_log_transactions) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_enable_error_logs():
|
||||
"""
|
||||
Test that the error logs are written to the database when disable_error_logs is False
|
||||
"""
|
||||
# Mock the necessary components
|
||||
mock_prisma_client = AsyncMock()
|
||||
mock_general_settings = {"disable_error_logs": False}
|
||||
|
||||
with patch(
|
||||
"litellm.proxy.proxy_server.general_settings", mock_general_settings
|
||||
), patch("litellm.proxy.proxy_server.prisma_client", mock_prisma_client):
|
||||
|
||||
# Create a test exception
|
||||
test_exception = Exception("Test error")
|
||||
test_kwargs = {
|
||||
"model": "gpt-4",
|
||||
"exception": test_exception,
|
||||
"optional_params": {},
|
||||
"litellm_params": {"metadata": {}},
|
||||
}
|
||||
|
||||
# Call the failure handler
|
||||
from litellm.proxy.proxy_server import _PROXY_failure_handler
|
||||
|
||||
await _PROXY_failure_handler(
|
||||
kwargs=test_kwargs,
|
||||
completion_response=None,
|
||||
start_time="2024-01-01",
|
||||
end_time="2024-01-01",
|
||||
)
|
||||
|
||||
# Verify prisma client was called to create error logs
|
||||
if hasattr(mock_prisma_client, "db"):
|
||||
assert mock_prisma_client.db.litellm_errorlogs.create.called
|
|
@ -1040,8 +1040,11 @@ def test_pattern_match_deployment_set_model_name(
|
|||
async def test_pass_through_moderation_endpoint_factory(model_list):
|
||||
router = Router(model_list=model_list)
|
||||
response = await router._pass_through_moderation_endpoint_factory(
|
||||
original_function=litellm.amoderation, input="this is valid good text"
|
||||
original_function=litellm.amoderation,
|
||||
input="this is valid good text",
|
||||
model=None,
|
||||
)
|
||||
assert response is not None
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
|
|
@ -300,6 +300,7 @@ async def test_key_update(metadata):
|
|||
get_key=key,
|
||||
metadata=metadata,
|
||||
)
|
||||
print(f"updated_key['metadata']: {updated_key['metadata']}")
|
||||
assert updated_key["metadata"] == metadata
|
||||
await update_proxy_budget(session=session) # resets proxy spend
|
||||
await chat_completion(session=session, key=key)
|
||||
|
|
|
@ -114,7 +114,7 @@ async def test_spend_logs():
|
|||
|
||||
|
||||
async def get_predict_spend_logs(session):
|
||||
url = f"http://0.0.0.0:4000/global/predict/spend/logs"
|
||||
url = "http://0.0.0.0:4000/global/predict/spend/logs"
|
||||
headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
|
||||
data = {
|
||||
"data": [
|
||||
|
@ -155,6 +155,7 @@ async def get_spend_report(session, start_date, end_date):
|
|||
return await response.json()
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="datetime in ci/cd gets set weirdly")
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_predicted_spend_logs():
|
||||
"""
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue