forked from phoenix/litellm-mirror

Compare commits: main...litellm_al (7 commits)

| SHA1 |
|---|
| ea54fba9f3 |
| ca2f0680df |
| 97b846cb08 |
| 459ba98607 |
| 5c6d9200c4 |
| 0869a1c13f |
| 778dbf1c2f |

75 changed files with 903 additions and 3098 deletions

@ -811,8 +811,7 @@ jobs:
- run: python ./tests/code_coverage_tests/router_code_coverage.py
- run: python ./tests/code_coverage_tests/test_router_strategy_async.py
- run: python ./tests/code_coverage_tests/litellm_logging_code_coverage.py
- run: python ./tests/documentation_tests/test_env_keys.py
- run: python ./tests/documentation_tests/test_router_settings.py
# - run: python ./tests/documentation_tests/test_env_keys.py
- run: python ./tests/documentation_tests/test_api_docs.py
- run: python ./tests/code_coverage_tests/ensure_async_clients_test.py
- run: helm lint ./deploy/charts/litellm-helm

@ -1408,7 +1407,7 @@ jobs:
command: |
docker run -d \
-p 4000:4000 \
-e DATABASE_URL=$PROXY_DATABASE_URL_2 \
-e DATABASE_URL=$PROXY_DATABASE_URL \
-e LITELLM_MASTER_KEY="sk-1234" \
-e OPENAI_API_KEY=$OPENAI_API_KEY \
-e UI_USERNAME="admin" \
|
|
@ -1,135 +0,0 @@
|
|||
import Tabs from '@theme/Tabs';
|
||||
import TabItem from '@theme/TabItem';
|
||||
|
||||
# Moderation
|
||||
|
||||
|
||||
### Usage
|
||||
<Tabs>
|
||||
<TabItem value="python" label="LiteLLM Python SDK">
|
||||
|
||||
```python
|
||||
from litellm import moderation
|
||||
|
||||
response = moderation(
|
||||
input="hello from litellm",
|
||||
model="text-moderation-stable"
|
||||
)
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
<TabItem value="proxy" label="LiteLLM Proxy Server">
|
||||
|
||||
For `/moderations` endpoint, there is **no need to specify `model` in the request or on the litellm config.yaml**
|
||||
|
||||
Start litellm proxy server
|
||||
|
||||
```
|
||||
litellm
|
||||
```
|
||||
|
||||
|
||||
<Tabs>
|
||||
<TabItem value="python" label="OpenAI Python SDK">
|
||||
|
||||
```python
|
||||
from openai import OpenAI
|
||||
|
||||
# set base_url to your proxy server
|
||||
# set api_key to send to proxy server
|
||||
client = OpenAI(api_key="<proxy-api-key>", base_url="http://0.0.0.0:4000")
|
||||
|
||||
response = client.moderations.create(
|
||||
input="hello from litellm",
|
||||
model="text-moderation-stable" # optional, defaults to `omni-moderation-latest`
|
||||
)
|
||||
|
||||
print(response)
|
||||
```
|
||||
</TabItem>
|
||||
|
||||
<TabItem value="curl" label="Curl Request">
|
||||
|
||||
```shell
|
||||
curl --location 'http://0.0.0.0:4000/moderations' \
|
||||
--header 'Content-Type: application/json' \
|
||||
--header 'Authorization: Bearer sk-1234' \
|
||||
--data '{"input": "Sample text goes here", "model": "text-moderation-stable"}'
|
||||
```
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
|
||||
## Input Params
|
||||
LiteLLM accepts and translates the [OpenAI Moderation params](https://platform.openai.com/docs/api-reference/moderations) across all supported providers.
|
||||
|
||||
### Required Fields
|
||||
|
||||
- `input`: *string or array* - Input (or inputs) to classify. Can be a single string, an array of strings, or an array of multi-modal input objects similar to other models.
|
||||
- If string: A string of text to classify for moderation
|
||||
- If array of strings: An array of strings to classify for moderation
|
||||
- If array of objects: An array of multi-modal inputs to the moderation model (see the sketch after this list), where each object can be:
|
||||
- An object describing an image to classify with:
|
||||
- `type`: *string, required* - Always `image_url`
|
||||
- `image_url`: *object, required* - Contains either an image URL or a data URL for a base64 encoded image
|
||||
- An object describing text to classify with:
|
||||
- `type`: *string, required* - Always `text`
|
||||
- `text`: *string, required* - A string of text to classify
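For example, a multi-modal input mixing text and an image can be passed as the array form described above. This is a sketch of the request shape only: the image URL is a placeholder, and multi-modal inputs require an omni-moderation model.

```python
from litellm import moderation

response = moderation(
    model="omni-moderation-latest",
    input=[
        {"type": "text", "text": "Sample text to classify"},
        {"type": "image_url", "image_url": {"url": "https://example.com/image.png"}},
    ],
)
print(response)
```
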
### Optional Fields
|
||||
|
||||
- `model`: *string (optional)* - The moderation model to use. Defaults to `omni-moderation-latest`.
|
||||
|
||||
## Output Format
|
||||
Here's the exact JSON output and type you can expect from all moderation calls:
|
||||
|
||||
[**LiteLLM follows OpenAI's output format**](https://platform.openai.com/docs/api-reference/moderations/object)
|
||||
|
||||
|
||||
```python
|
||||
{
|
||||
"id": "modr-AB8CjOTu2jiq12hp1AQPfeqFWaORR",
|
||||
"model": "text-moderation-007",
|
||||
"results": [
|
||||
{
|
||||
"flagged": true,
|
||||
"categories": {
|
||||
"sexual": false,
|
||||
"hate": false,
|
||||
"harassment": true,
|
||||
"self-harm": false,
|
||||
"sexual/minors": false,
|
||||
"hate/threatening": false,
|
||||
"violence/graphic": false,
|
||||
"self-harm/intent": false,
|
||||
"self-harm/instructions": false,
|
||||
"harassment/threatening": true,
|
||||
"violence": true
|
||||
},
|
||||
"category_scores": {
|
||||
"sexual": 0.000011726012417057063,
|
||||
"hate": 0.22706663608551025,
|
||||
"harassment": 0.5215635299682617,
|
||||
"self-harm": 2.227119921371923e-6,
|
||||
"sexual/minors": 7.107352217872176e-8,
|
||||
"hate/threatening": 0.023547329008579254,
|
||||
"violence/graphic": 0.00003391829886822961,
|
||||
"self-harm/intent": 1.646940972932498e-6,
|
||||
"self-harm/instructions": 1.1198755256458526e-9,
|
||||
"harassment/threatening": 0.5694745779037476,
|
||||
"violence": 0.9971134662628174
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
|
||||
## **Supported Providers**
|
||||
|
||||
| Provider |
|
||||
|-------------|
|
||||
| OpenAI |
|
|
@ -4,63 +4,24 @@ import TabItem from '@theme/TabItem';
|
|||
|
||||
# Argilla
|
||||
|
||||
Argilla is a collaborative annotation tool for AI engineers and domain experts who need to build high-quality datasets for their projects.
|
||||
Argilla is a tool for annotating datasets.
|
||||
|
||||
|
||||
## Getting Started
|
||||
|
||||
To log the data to Argilla, first you need to deploy the Argilla server. If you have not deployed the Argilla server, please follow the instructions [here](https://docs.argilla.io/latest/getting_started/quickstart/).
|
||||
|
||||
Next, you will need to configure and create the Argilla dataset.
|
||||
|
||||
```python
|
||||
import argilla as rg
|
||||
|
||||
client = rg.Argilla(api_url="<api_url>", api_key="<api_key>")
|
||||
|
||||
settings = rg.Settings(
|
||||
guidelines="These are some guidelines.",
|
||||
fields=[
|
||||
rg.ChatField(
|
||||
name="user_input",
|
||||
),
|
||||
rg.TextField(
|
||||
name="llm_output",
|
||||
),
|
||||
],
|
||||
questions=[
|
||||
rg.RatingQuestion(
|
||||
name="rating",
|
||||
values=[1, 2, 3, 4, 5, 6, 7],
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
dataset = rg.Dataset(
|
||||
name="my_first_dataset",
|
||||
settings=settings,
|
||||
)
|
||||
|
||||
dataset.create()
|
||||
```
|
||||
|
||||
For further configuration, please refer to the [Argilla documentation](https://docs.argilla.io/latest/how_to_guides/dataset/).
|
||||
|
||||
|
||||
## Usage
|
||||
## Usage
|
||||
|
||||
<Tabs>
|
||||
<Tab value="sdk" label="SDK">
|
||||
|
||||
```python
|
||||
import os
|
||||
import litellm
|
||||
from litellm import completion
|
||||
import litellm
|
||||
import os
|
||||
|
||||
# add env vars
|
||||
os.environ["ARGILLA_API_KEY"]="argilla.apikey"
|
||||
os.environ["ARGILLA_BASE_URL"]="http://localhost:6900"
|
||||
os.environ["ARGILLA_DATASET_NAME"]="my_first_dataset"
|
||||
os.environ["ARGILLA_DATASET_NAME"]="my_second_dataset"
|
||||
os.environ["OPENAI_API_KEY"]="sk-proj-..."
|
||||
|
||||
litellm.callbacks = ["argilla"]
|
||||
|
|
|
@ -279,31 +279,7 @@ router_settings:
| retry_policy | object | Specifies the number of retries for different types of exceptions. [More information here](reliability) |
| allowed_fails | integer | The number of failures allowed before cooling down a model. [More information here](reliability) |
| allowed_fails_policy | object | Specifies the number of allowed failures for different error types before cooling down a deployment. [More information here](reliability) |
| default_max_parallel_requests | Optional[int] | The default maximum number of parallel requests for a deployment. |
| default_priority | (Optional[int]) | The default priority for a request. Only for '.scheduler_acompletion()'. Default is None. |
| polling_interval | (Optional[float]) | Frequency of polling the queue. Only for '.scheduler_acompletion()'. Default is 3ms. |
| max_fallbacks | Optional[int] | The maximum number of fallbacks to try before exiting the call. Defaults to 5. |
| default_litellm_params | Optional[dict] | The default litellm parameters to add to all requests (e.g. `temperature`, `max_tokens`). |
| timeout | Optional[float] | The default timeout for a request. |
| debug_level | Literal["DEBUG", "INFO"] | The debug level for the logging library in the router. Defaults to "INFO". |
| client_ttl | int | Time-to-live for cached clients in seconds. Defaults to 3600. |
| cache_kwargs | dict | Additional keyword arguments for the cache initialization. |
| routing_strategy_args | dict | Additional keyword arguments for the routing strategy - e.g. lowest latency routing default ttl |
| model_group_alias | dict | Model group alias mapping. E.g. `{"claude-3-haiku": "claude-3-haiku-20240229"}` |
| num_retries | int | Number of retries for a request. Defaults to 3. |
| default_fallbacks | Optional[List[str]] | Fallbacks to try if no model group-specific fallbacks are defined. |
| caching_groups | Optional[List[tuple]] | List of model groups for caching across model groups. Defaults to None. E.g. caching_groups=[("openai-gpt-3.5-turbo", "azure-gpt-3.5-turbo")] |
| alerting_config | AlertingConfig | [SDK-only arg] Slack alerting configuration. Defaults to None. [Further Docs](../routing.md#alerting-) |
| assistants_config | AssistantsConfig | Set on proxy via `assistant_settings`. [Further docs](../assistants.md) |
| set_verbose | boolean | [DEPRECATED PARAM - see debug docs](./debugging.md) If true, sets the logging level to verbose. |
| retry_after | int | Time to wait before retrying a request in seconds. Defaults to 0. If `x-retry-after` is received from LLM API, this value is overridden. |
| provider_budget_config | ProviderBudgetConfig | Provider budget configuration. Use this to set llm_provider budget limits. Example: $100/day to OpenAI, $100/day to Azure, etc. Defaults to None. [Further Docs](./provider_budget_routing.md) |
| enable_pre_call_checks | boolean | If true, checks if a call is within the model's context window before making the call. [More information here](reliability) |
| model_group_retry_policy | Dict[str, RetryPolicy] | [SDK-only arg] Set retry policy for model groups. |
| context_window_fallbacks | List[Dict[str, List[str]]] | Fallback models for context window violations. |
| redis_url | str | URL for Redis server. **Known performance issue with Redis URL.** |
| cache_responses | boolean | Flag to enable caching LLM Responses, if cache set under `router_settings`. If true, caches responses. Defaults to False. |
| router_general_settings | RouterGeneralSettings | [SDK-Only] Router general settings - contains optimizations like 'async_only_mode'. [Docs](../routing.md#router-general-settings) |
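As a quick reference, a minimal `router_settings` block using a handful of the fields above might look like this (a sketch only: values and model names are illustrative placeholders, not recommendations):

```yaml
router_settings:
  num_retries: 3                     # retries per request
  timeout: 30                        # default request timeout (seconds)
  allowed_fails: 3                   # failures before cooling down a deployment
  retry_after: 5                     # seconds to wait before retrying
  enable_pre_call_checks: true       # reject calls that exceed the model's context window
  default_fallbacks: ["my-fallback-model"]
  model_group_alias: {"gpt-4": "gpt-4o"}
```
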
### environment variables - Reference

@ -359,8 +335,6 @@ router_settings:
| DD_SITE | Site URL for Datadog (e.g., datadoghq.com)
| DD_SOURCE | Source identifier for Datadog logs
| DD_ENV | Environment identifier for Datadog logs. Only supported for `datadog_llm_observability` callback
| DD_SERVICE | Service identifier for Datadog logs. Defaults to "litellm-server"
| DD_VERSION | Version identifier for Datadog logs. Defaults to "unknown"
| DEBUG_OTEL | Enable debug mode for OpenTelemetry
| DIRECT_URL | Direct URL for service endpoint
| DISABLE_ADMIN_UI | Toggle to disable the admin UI
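For example, a proxy shipping logs to Datadog might export a few of these before startup. This is a sketch: values are placeholders, and `DD_API_KEY` is assumed to be required by the Datadog integration even though it falls outside the rows shown here.

```shell
export DD_API_KEY="<your-datadog-api-key>"   # assumed prerequisite, not listed above
export DD_SITE="datadoghq.com"
export DD_ENV="production"                   # only used by the datadog_llm_observability callback
export DD_SERVICE="litellm-server"
litellm --config /path/to/config.yaml
```
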
@ -357,6 +357,77 @@ curl --location 'http://0.0.0.0:4000/v1/model/info' \
|
|||
--data ''
|
||||
```
|
||||
|
||||
|
||||
### Provider specific wildcard routing
|
||||
**Proxy all models from a provider**
|
||||
|
||||
Use this if you want to **proxy all models from a specific provider without defining them on the config.yaml**
|
||||
|
||||
**Step 1** - define provider specific routing on config.yaml
|
||||
```yaml
|
||||
model_list:
|
||||
# provider specific wildcard routing
|
||||
- model_name: "anthropic/*"
|
||||
litellm_params:
|
||||
model: "anthropic/*"
|
||||
api_key: os.environ/ANTHROPIC_API_KEY
|
||||
- model_name: "groq/*"
|
||||
litellm_params:
|
||||
model: "groq/*"
|
||||
api_key: os.environ/GROQ_API_KEY
|
||||
- model_name: "fo::*:static::*" # all requests matching this pattern will be routed to this deployment, example: model="fo::hi::static::hi" will be routed to deployment: "openai/fo::*:static::*"
|
||||
litellm_params:
|
||||
model: "openai/fo::*:static::*"
|
||||
api_key: os.environ/OPENAI_API_KEY
|
||||
```
|
||||
|
||||
Step 2 - Run litellm proxy
|
||||
|
||||
```shell
|
||||
$ litellm --config /path/to/config.yaml
|
||||
```
|
||||
|
||||
Step 3 - Test it
|
||||
|
||||
Test with `anthropic/` - all models with `anthropic/` prefix will get routed to `anthropic/*`
|
||||
```shell
|
||||
curl http://localhost:4000/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "Authorization: Bearer sk-1234" \
|
||||
-d '{
|
||||
"model": "anthropic/claude-3-sonnet-20240229",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello, Claude!"}
|
||||
]
|
||||
}'
|
||||
```
|
||||
|
||||
Test with `groq/` - all models with `groq/` prefix will get routed to `groq/*`
|
||||
```shell
|
||||
curl http://localhost:4000/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "Authorization: Bearer sk-1234" \
|
||||
-d '{
|
||||
"model": "groq/llama3-8b-8192",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello, Claude!"}
|
||||
]
|
||||
}'
|
||||
```
|
||||
|
||||
Test with `fo::*::static::*` - all requests matching this pattern will be routed to `openai/fo::*:static::*`
|
||||
```shell
|
||||
curl http://localhost:4000/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "Authorization: Bearer sk-1234" \
|
||||
-d '{
|
||||
"model": "fo::hi::static::hi",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello, Claude!"}
|
||||
]
|
||||
}'
|
||||
```
|
||||
|
||||
### Load Balancing
|
||||
|
||||
:::info
|
||||
|
|
|
@ -192,13 +192,3 @@ Here is a screenshot of the metrics you can monitor with the LiteLLM Grafana Das
|----------------------|--------------------------------------|
| `litellm_llm_api_failed_requests_metric` | **deprecated** use `litellm_proxy_failed_requests_metric` |
| `litellm_requests_metric` | **deprecated** use `litellm_proxy_total_requests_metric` |

## FAQ

### What are `_created` vs. `_total` metrics?

- `_created` metrics are metrics that are created when the proxy starts
- `_total` metrics are metrics that are incremented for each request

You should consume the `_total` metrics for your counting purposes.
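For example, to chart request throughput you would query the `_total` series with a PromQL `rate()` over `litellm_proxy_total_requests_metric` from the table above. This is a sketch: it assumes the counter is exposed with the usual `_total` suffix and that Prometheus listens on localhost:9090.

```shell
curl -G 'http://localhost:9090/api/v1/query' \
  --data-urlencode 'query=rate(litellm_proxy_total_requests_metric_total[5m])'
```
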
@ -1891,22 +1891,3 @@ router = Router(
    debug_level="DEBUG"  # defaults to INFO
)
```

## Router General Settings

### Usage

```python
router = Router(model_list=..., router_general_settings=RouterGeneralSettings(async_only_mode=True))
```

### Spec

```python
class RouterGeneralSettings(BaseModel):
    async_only_mode: bool = Field(
        default=False
    )  # this will only initialize async clients. Good for memory utils
    pass_through_all_models: bool = Field(
        default=False
    )  # if passed a model not llm_router model list, pass through the request to litellm.acompletion/embedding
```
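
Both fields can be set together. A small sketch, assuming `RouterGeneralSettings` is importable from `litellm.types.router` and with the model list elided:

```python
from litellm import Router
from litellm.types.router import RouterGeneralSettings  # assumed import path

router = Router(
    model_list=[...],  # your deployments
    router_general_settings=RouterGeneralSettings(
        async_only_mode=True,          # only initialize async clients
        pass_through_all_models=True,  # unknown models fall through to litellm.acompletion / litellm.aembedding
    ),
)
```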
|
|
@ -1,174 +0,0 @@
|
|||
import Tabs from '@theme/Tabs';
|
||||
import TabItem from '@theme/TabItem';
|
||||
|
||||
# Text Completion
|
||||
|
||||
### Usage
|
||||
<Tabs>
|
||||
<TabItem value="python" label="LiteLLM Python SDK">
|
||||
|
||||
```python
|
||||
from litellm import text_completion
|
||||
|
||||
response = text_completion(
|
||||
model="gpt-3.5-turbo-instruct",
|
||||
prompt="Say this is a test",
|
||||
max_tokens=7
|
||||
)
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
<TabItem value="proxy" label="LiteLLM Proxy Server">
|
||||
|
||||
1. Define models on config.yaml
|
||||
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: gpt-3.5-turbo-instruct
|
||||
litellm_params:
|
||||
model: text-completion-openai/gpt-3.5-turbo-instruct # The `text-completion-openai/` prefix will call openai.completions.create
|
||||
api_key: os.environ/OPENAI_API_KEY
|
||||
- model_name: text-davinci-003
|
||||
litellm_params:
|
||||
model: text-completion-openai/text-davinci-003
|
||||
api_key: os.environ/OPENAI_API_KEY
|
||||
```
|
||||
|
||||
2. Start litellm proxy server
|
||||
|
||||
```
|
||||
litellm --config config.yaml
|
||||
```
|
||||
|
||||
<Tabs>
|
||||
<TabItem value="python" label="OpenAI Python SDK">
|
||||
|
||||
```python
|
||||
from openai import OpenAI
|
||||
|
||||
# set base_url to your proxy server
|
||||
# set api_key to send to proxy server
|
||||
client = OpenAI(api_key="<proxy-api-key>", base_url="http://0.0.0.0:4000")
|
||||
|
||||
response = client.completions.create(
|
||||
model="gpt-3.5-turbo-instruct",
|
||||
prompt="Say this is a test",
|
||||
max_tokens=7
|
||||
)
|
||||
|
||||
print(response)
|
||||
```
|
||||
</TabItem>
|
||||
|
||||
<TabItem value="curl" label="Curl Request">
|
||||
|
||||
```shell
|
||||
curl --location 'http://0.0.0.0:4000/completions' \
|
||||
--header 'Content-Type: application/json' \
|
||||
--header 'Authorization: Bearer sk-1234' \
|
||||
--data '{
|
||||
"model": "gpt-3.5-turbo-instruct",
|
||||
"prompt": "Say this is a test",
|
||||
"max_tokens": 7
|
||||
}'
|
||||
```
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
|
||||
## Input Params
|
||||
|
||||
LiteLLM accepts and translates the [OpenAI Text Completion params](https://platform.openai.com/docs/api-reference/completions) across all supported providers.
|
||||
|
||||
### Required Fields
|
||||
|
||||
- `model`: *string* - ID of the model to use
|
||||
- `prompt`: *string or array* - The prompt(s) to generate completions for
|
||||
|
||||
### Optional Fields
|
||||
|
||||
- `best_of`: *integer* - Generates best_of completions server-side and returns the "best" one
|
||||
- `echo`: *boolean* - Echo back the prompt in addition to the completion.
|
||||
- `frequency_penalty`: *number* - Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency.
|
||||
- `logit_bias`: *map* - Modify the likelihood of specified tokens appearing in the completion
|
||||
- `logprobs`: *integer* - Include the log probabilities on the logprobs most likely tokens. Max value of 5
|
||||
- `max_tokens`: *integer* - The maximum number of tokens to generate.
|
||||
- `n`: *integer* - How many completions to generate for each prompt.
|
||||
- `presence_penalty`: *number* - Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far.
|
||||
- `seed`: *integer* - If specified, system will attempt to make deterministic samples
|
||||
- `stop`: *string or array* - Up to 4 sequences where the API will stop generating tokens
|
||||
- `stream`: *boolean* - Whether to stream back partial progress. Defaults to false (see the streaming sketch after this list)
|
||||
- `suffix`: *string* - The suffix that comes after a completion of inserted text
|
||||
- `temperature`: *number* - What sampling temperature to use, between 0 and 2.
|
||||
- `top_p`: *number* - An alternative to sampling with temperature, called nucleus sampling.
|
||||
- `user`: *string* - A unique identifier representing your end-user
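Because `stream` is supported, a streaming variant of the SDK example above is the same call with `stream=True`. A sketch; each chunk follows the streaming format shown below:

```python
from litellm import text_completion

response = text_completion(
    model="gpt-3.5-turbo-instruct",
    prompt="Say this is a test",
    max_tokens=7,
    stream=True,
)

for chunk in response:  # each chunk is a partial completion
    print(chunk)
```
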
## Output Format
|
||||
Here's the exact JSON output format you can expect from completion calls:
|
||||
|
||||
|
||||
[**Follows OpenAI's output format**](https://platform.openai.com/docs/api-reference/completions/object)
|
||||
|
||||
<Tabs>
|
||||
|
||||
<TabItem value="non-streaming" label="Non-Streaming Response">
|
||||
|
||||
```python
|
||||
{
|
||||
"id": "cmpl-uqkvlQyYK7bGYrRHQ0eXlWi7",
|
||||
"object": "text_completion",
|
||||
"created": 1589478378,
|
||||
"model": "gpt-3.5-turbo-instruct",
|
||||
"system_fingerprint": "fp_44709d6fcb",
|
||||
"choices": [
|
||||
{
|
||||
"text": "\n\nThis is indeed a test",
|
||||
"index": 0,
|
||||
"logprobs": null,
|
||||
"finish_reason": "length"
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"prompt_tokens": 5,
|
||||
"completion_tokens": 7,
|
||||
"total_tokens": 12
|
||||
}
|
||||
}
|
||||
|
||||
```
|
||||
</TabItem>
|
||||
<TabItem value="streaming" label="Streaming Response">
|
||||
|
||||
```python
|
||||
{
|
||||
"id": "cmpl-7iA7iJjj8V2zOkCGvWF2hAkDWBQZe",
|
||||
"object": "text_completion",
|
||||
"created": 1690759702,
|
||||
"choices": [
|
||||
{
|
||||
"text": "This",
|
||||
"index": 0,
|
||||
"logprobs": null,
|
||||
"finish_reason": null
|
||||
}
|
||||
],
|
||||
"model": "gpt-3.5-turbo-instruct"
|
||||
"system_fingerprint": "fp_44709d6fcb",
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
|
||||
|
||||
## **Supported Providers**
|
||||
|
||||
| Provider | Link to Usage |
|
||||
|-------------|--------------------|
|
||||
| OpenAI | [Usage](../docs/providers/text_completion_openai) |
|
||||
| Azure OpenAI| [Usage](../docs/providers/azure) |
|
||||
|
||||
|
|
@ -1,140 +0,0 @@
|
|||
import Tabs from '@theme/Tabs';
|
||||
import TabItem from '@theme/TabItem';
|
||||
|
||||
# Provider specific Wildcard routing
|
||||
|
||||
**Proxy all models from a provider**
|
||||
|
||||
Use this if you want to **proxy all models from a specific provider without defining them on the config.yaml**
|
||||
|
||||
## Step 1. Define provider specific routing
|
||||
|
||||
<Tabs>
|
||||
<TabItem value="sdk" label="SDK">
|
||||
|
||||
```python
|
||||
from litellm import Router
|
||||
|
||||
router = Router(
|
||||
model_list=[
|
||||
{
|
||||
"model_name": "anthropic/*",
|
||||
"litellm_params": {
|
||||
"model": "anthropic/*",
|
||||
"api_key": os.environ["ANTHROPIC_API_KEY"]
|
||||
}
|
||||
},
|
||||
{
|
||||
"model_name": "groq/*",
|
||||
"litellm_params": {
|
||||
"model": "groq/*",
|
||||
"api_key": os.environ["GROQ_API_KEY"]
|
||||
}
|
||||
},
|
||||
{
|
||||
"model_name": "fo::*:static::*", # all requests matching this pattern will be routed to this deployment, example: model="fo::hi::static::hi" will be routed to deployment: "openai/fo::*:static::*"
|
||||
"litellm_params": {
|
||||
"model": "openai/fo::*:static::*",
|
||||
"api_key": os.environ["OPENAI_API_KEY"]
|
||||
}
|
||||
}
|
||||
]
|
||||
)
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
<TabItem value="proxy" label="PROXY">
|
||||
|
||||
**Step 1** - define provider specific routing on config.yaml
|
||||
```yaml
|
||||
model_list:
|
||||
# provider specific wildcard routing
|
||||
- model_name: "anthropic/*"
|
||||
litellm_params:
|
||||
model: "anthropic/*"
|
||||
api_key: os.environ/ANTHROPIC_API_KEY
|
||||
- model_name: "groq/*"
|
||||
litellm_params:
|
||||
model: "groq/*"
|
||||
api_key: os.environ/GROQ_API_KEY
|
||||
- model_name: "fo::*:static::*" # all requests matching this pattern will be routed to this deployment, example: model="fo::hi::static::hi" will be routed to deployment: "openai/fo::*:static::*"
|
||||
litellm_params:
|
||||
model: "openai/fo::*:static::*"
|
||||
api_key: os.environ/OPENAI_API_KEY
|
||||
```
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
|
||||
## [PROXY-Only] Step 2 - Run litellm proxy
|
||||
|
||||
```shell
|
||||
$ litellm --config /path/to/config.yaml
|
||||
```
|
||||
|
||||
## Step 3 - Test it
|
||||
|
||||
<Tabs>
|
||||
<TabItem value="sdk" label="SDK">
|
||||
|
||||
```python
|
||||
from litellm import Router
|
||||
|
||||
router = Router(model_list=...)
|
||||
|
||||
# Test with `anthropic/` - all models with `anthropic/` prefix will get routed to `anthropic/*`
|
||||
resp = completion(model="anthropic/claude-3-sonnet-20240229", messages=[{"role": "user", "content": "Hello, Claude!"}])
|
||||
print(resp)
|
||||
|
||||
# Test with `groq/` - all models with `groq/` prefix will get routed to `groq/*`
|
||||
resp = completion(model="groq/llama3-8b-8192", messages=[{"role": "user", "content": "Hello, Groq!"}])
|
||||
print(resp)
|
||||
|
||||
# Test with `fo::*::static::*` - all requests matching this pattern will be routed to `openai/fo::*:static::*`
|
||||
resp = completion(model="fo::hi::static::hi", messages=[{"role": "user", "content": "Hello, Claude!"}])
|
||||
print(resp)
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
<TabItem value="proxy" label="PROXY">
|
||||
|
||||
Test with `anthropic/` - all models with `anthropic/` prefix will get routed to `anthropic/*`
|
||||
```bash
|
||||
curl http://localhost:4000/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "Authorization: Bearer sk-1234" \
|
||||
-d '{
|
||||
"model": "anthropic/claude-3-sonnet-20240229",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello, Claude!"}
|
||||
]
|
||||
}'
|
||||
```
|
||||
|
||||
Test with `groq/` - all models with `groq/` prefix will get routed to `groq/*`
|
||||
```shell
|
||||
curl http://localhost:4000/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "Authorization: Bearer sk-1234" \
|
||||
-d '{
|
||||
"model": "groq/llama3-8b-8192",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello, Claude!"}
|
||||
]
|
||||
}'
|
||||
```
|
||||
|
||||
Test with `fo::*::static::*` - all requests matching this pattern will be routed to `openai/fo::*:static::*`
|
||||
```shell
|
||||
curl http://localhost:4000/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "Authorization: Bearer sk-1234" \
|
||||
-d '{
|
||||
"model": "fo::hi::static::hi",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello, Claude!"}
|
||||
]
|
||||
}'
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
</Tabs>
|
|
@ -246,7 +246,6 @@ const sidebars = {
|
|||
"completion/usage",
|
||||
],
|
||||
},
|
||||
"text_completion",
|
||||
"embedding/supported_embedding",
|
||||
"image_generation",
|
||||
{
|
||||
|
@ -262,7 +261,6 @@ const sidebars = {
|
|||
"batches",
|
||||
"realtime",
|
||||
"fine_tuning",
|
||||
"moderation",
|
||||
{
|
||||
type: "link",
|
||||
label: "Use LiteLLM Proxy with Vertex, Bedrock SDK",
|
||||
|
@ -279,7 +277,7 @@ const sidebars = {
|
|||
description: "Learn how to load balance, route, and set fallbacks for your LLM requests",
|
||||
slug: "/routing-load-balancing",
|
||||
},
|
||||
items: ["routing", "scheduler", "proxy/load_balancing", "proxy/reliability", "proxy/tag_routing", "proxy/provider_budget_routing", "proxy/team_based_routing", "proxy/customer_routing", "wildcard_routing"],
|
||||
items: ["routing", "scheduler", "proxy/load_balancing", "proxy/reliability", "proxy/tag_routing", "proxy/provider_budget_routing", "proxy/team_based_routing", "proxy/customer_routing"],
|
||||
},
|
||||
{
|
||||
type: "category",
|
||||
|
|
|
@ -2,9 +2,7 @@
|
|||
from typing import Optional, List
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.proxy.proxy_server import PrismaClient, HTTPException
|
||||
from litellm.llms.custom_httpx.http_handler import HTTPHandler
|
||||
import collections
|
||||
import httpx
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
|
@ -116,6 +114,7 @@ async def ui_get_spend_by_tags(
|
|||
|
||||
|
||||
def _forecast_daily_cost(data: list):
|
||||
import requests # type: ignore
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
if len(data) == 0:
|
||||
|
@ -137,17 +136,17 @@ def _forecast_daily_cost(data: list):
|
|||
|
||||
print("last entry date", last_entry_date)
|
||||
|
||||
# Assuming today_date is a datetime object
|
||||
today_date = datetime.now()
|
||||
|
||||
# Calculate the last day of the month
|
||||
last_day_of_todays_month = datetime(
|
||||
today_date.year, today_date.month % 12 + 1, 1
|
||||
) - timedelta(days=1)
|
||||
|
||||
print("last day of todays month", last_day_of_todays_month)
|
||||
# Calculate the remaining days in the month
|
||||
remaining_days = (last_day_of_todays_month - last_entry_date).days
|
||||
|
||||
print("remaining days", remaining_days)
|
||||
|
||||
current_spend_this_month = 0
|
||||
series = {}
|
||||
for entry in data:
|
||||
|
@ -177,19 +176,13 @@ def _forecast_daily_cost(data: list):
|
|||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
client = HTTPHandler()
|
||||
|
||||
try:
|
||||
response = client.post(
|
||||
url="https://trend-api-production.up.railway.app/forecast",
|
||||
json=payload,
|
||||
headers=headers,
|
||||
)
|
||||
except httpx.HTTPStatusError as e:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail={"error": f"Error getting forecast: {e.response.text}"},
|
||||
)
|
||||
response = requests.post(
|
||||
url="https://trend-api-production.up.railway.app/forecast",
|
||||
json=payload,
|
||||
headers=headers,
|
||||
)
|
||||
# check the status code
|
||||
response.raise_for_status()
|
||||
|
||||
json_response = response.json()
|
||||
forecast_data = json_response["forecast"]
|
||||
|
@ -213,3 +206,13 @@ def _forecast_daily_cost(data: list):
|
|||
f"Predicted Spend for { today_month } 2024, ${total_predicted_spend}"
|
||||
)
|
||||
return {"response": response_data, "predicted_spend": predicted_spend}
|
||||
|
||||
# print(f"Date: {entry['date']}, Spend: {entry['spend']}, Response: {response.text}")
|
||||
|
||||
|
||||
# _forecast_daily_cost(
|
||||
# [
|
||||
# {"date": "2022-01-01", "spend": 100},
|
||||
|
||||
# ]
|
||||
# )
|
||||
|
|
|
@ -458,7 +458,7 @@ class AmazonConverseConfig:
|
|||
"""
|
||||
Abbreviations of regions AWS Bedrock supports for cross region inference
|
||||
"""
|
||||
return ["us", "eu", "apac"]
|
||||
return ["us", "eu"]
|
||||
|
||||
def _get_base_model(self, model: str) -> str:
|
||||
"""
|
||||
|
|
|
@ -28,62 +28,6 @@ headers = {
|
|||
_DEFAULT_TIMEOUT = httpx.Timeout(timeout=5.0, connect=5.0)
|
||||
_DEFAULT_TTL_FOR_HTTPX_CLIENTS = 3600 # 1 hour, re-use the same httpx client for 1 hour
|
||||
|
||||
import re
|
||||
|
||||
|
||||
def mask_sensitive_info(error_message):
|
||||
# Find the start of the key parameter
|
||||
if isinstance(error_message, str):
|
||||
key_index = error_message.find("key=")
|
||||
else:
|
||||
return error_message
|
||||
|
||||
# If key is found
|
||||
if key_index != -1:
|
||||
# Find the end of the key parameter (next & or end of string)
|
||||
next_param = error_message.find("&", key_index)
|
||||
|
||||
if next_param == -1:
|
||||
# If no more parameters, mask until the end of the string
|
||||
masked_message = error_message[: key_index + 4] + "[REDACTED_API_KEY]"
|
||||
else:
|
||||
# Replace the key with redacted value, keeping other parameters
|
||||
masked_message = (
|
||||
error_message[: key_index + 4]
|
||||
+ "[REDACTED_API_KEY]"
|
||||
+ error_message[next_param:]
|
||||
)
|
||||
|
||||
return masked_message
|
||||
|
||||
return error_message
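# Example (illustrative): the `key` query parameter is redacted while other
# parameters are preserved:
#   mask_sensitive_info("https://example-llm-api.com/v1?key=abc123&alt=sse")
#   -> "https://example-llm-api.com/v1?key=[REDACTED_API_KEY]&alt=sse"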
|
||||
|
||||
|
||||
class MaskedHTTPStatusError(httpx.HTTPStatusError):
|
||||
def __init__(
|
||||
self, original_error, message: Optional[str] = None, text: Optional[str] = None
|
||||
):
|
||||
# Create a new error with the masked URL
|
||||
masked_url = mask_sensitive_info(str(original_error.request.url))
|
||||
# Create a new error that looks like the original, but with a masked URL
|
||||
|
||||
super().__init__(
|
||||
message=original_error.message,
|
||||
request=httpx.Request(
|
||||
method=original_error.request.method,
|
||||
url=masked_url,
|
||||
headers=original_error.request.headers,
|
||||
content=original_error.request.content,
|
||||
),
|
||||
response=httpx.Response(
|
||||
status_code=original_error.response.status_code,
|
||||
content=original_error.response.content,
|
||||
headers=original_error.response.headers,
|
||||
),
|
||||
)
|
||||
self.message = message
|
||||
self.text = text
|
||||
|
||||
|
||||
class AsyncHTTPHandler:
|
||||
def __init__(
|
||||
|
@ -211,16 +155,13 @@ class AsyncHTTPHandler:
|
|||
headers=headers,
|
||||
)
|
||||
except httpx.HTTPStatusError as e:
|
||||
|
||||
setattr(e, "status_code", e.response.status_code)
|
||||
if stream is True:
|
||||
setattr(e, "message", await e.response.aread())
|
||||
setattr(e, "text", await e.response.aread())
|
||||
else:
|
||||
setattr(e, "message", mask_sensitive_info(e.response.text))
|
||||
setattr(e, "text", mask_sensitive_info(e.response.text))
|
||||
|
||||
setattr(e, "status_code", e.response.status_code)
|
||||
|
||||
setattr(e, "message", e.response.text)
|
||||
setattr(e, "text", e.response.text)
|
||||
raise e
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
@ -458,17 +399,11 @@ class HTTPHandler:
|
|||
llm_provider="litellm-httpx-handler",
|
||||
)
|
||||
except httpx.HTTPStatusError as e:
|
||||
|
||||
if stream is True:
|
||||
setattr(e, "message", mask_sensitive_info(e.response.read()))
|
||||
setattr(e, "text", mask_sensitive_info(e.response.read()))
|
||||
else:
|
||||
error_text = mask_sensitive_info(e.response.text)
|
||||
setattr(e, "message", error_text)
|
||||
setattr(e, "text", error_text)
|
||||
|
||||
setattr(e, "status_code", e.response.status_code)
|
||||
|
||||
if stream is True:
|
||||
setattr(e, "message", e.response.read())
|
||||
else:
|
||||
setattr(e, "message", e.response.text)
|
||||
raise e
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
|
|
@ -1159,44 +1159,15 @@ def convert_to_anthropic_tool_result(
|
|||
]
|
||||
}
|
||||
"""
|
||||
anthropic_content: Union[
|
||||
str,
|
||||
List[Union[AnthropicMessagesToolResultContent, AnthropicMessagesImageParam]],
|
||||
] = ""
|
||||
content_str: str = ""
|
||||
if isinstance(message["content"], str):
|
||||
anthropic_content = message["content"]
|
||||
content_str = message["content"]
|
||||
elif isinstance(message["content"], List):
|
||||
content_list = message["content"]
|
||||
anthropic_content_list: List[
|
||||
Union[AnthropicMessagesToolResultContent, AnthropicMessagesImageParam]
|
||||
] = []
|
||||
for content in content_list:
|
||||
if content["type"] == "text":
|
||||
anthropic_content_list.append(
|
||||
AnthropicMessagesToolResultContent(
|
||||
type="text",
|
||||
text=content["text"],
|
||||
)
|
||||
)
|
||||
elif content["type"] == "image_url":
|
||||
if isinstance(content["image_url"], str):
|
||||
image_chunk = convert_to_anthropic_image_obj(content["image_url"])
|
||||
else:
|
||||
image_chunk = convert_to_anthropic_image_obj(
|
||||
content["image_url"]["url"]
|
||||
)
|
||||
anthropic_content_list.append(
|
||||
AnthropicMessagesImageParam(
|
||||
type="image",
|
||||
source=AnthropicContentParamSource(
|
||||
type="base64",
|
||||
media_type=image_chunk["media_type"],
|
||||
data=image_chunk["data"],
|
||||
),
|
||||
)
|
||||
)
|
||||
content_str += content["text"]
|
||||
|
||||
anthropic_content = anthropic_content_list
|
||||
anthropic_tool_result: Optional[AnthropicMessagesToolResultParam] = None
|
||||
## PROMPT CACHING CHECK ##
|
||||
cache_control = message.get("cache_control", None)
|
||||
|
@ -1207,14 +1178,14 @@ def convert_to_anthropic_tool_result(
|
|||
# We can't determine from openai message format whether it's a successful or
|
||||
# error call result so default to the successful result template
|
||||
anthropic_tool_result = AnthropicMessagesToolResultParam(
|
||||
type="tool_result", tool_use_id=tool_call_id, content=anthropic_content
|
||||
type="tool_result", tool_use_id=tool_call_id, content=content_str
|
||||
)
|
||||
|
||||
if message["role"] == "function":
|
||||
function_message: ChatCompletionFunctionMessage = message
|
||||
tool_call_id = function_message.get("tool_call_id") or str(uuid.uuid4())
|
||||
anthropic_tool_result = AnthropicMessagesToolResultParam(
|
||||
type="tool_result", tool_use_id=tool_call_id, content=anthropic_content
|
||||
type="tool_result", tool_use_id=tool_call_id, content=content_str
|
||||
)
|
||||
|
||||
if anthropic_tool_result is None:
|
||||
|
|
|
@ -107,10 +107,6 @@ def _get_image_mime_type_from_url(url: str) -> Optional[str]:
|
|||
return "image/png"
|
||||
elif url.endswith(".webp"):
|
||||
return "image/webp"
|
||||
elif url.endswith(".mp4"):
|
||||
return "video/mp4"
|
||||
elif url.endswith(".pdf"):
|
||||
return "application/pdf"
|
||||
return None
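# Examples (illustrative URLs):
#   _get_image_mime_type_from_url("https://example.com/photo.png") -> "image/png"
#   _get_image_mime_type_from_url("https://example.com/photo.heic") -> None  (unrecognized extension)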
|
||||
|
||||
|
||||
|
|
|
@ -3383,8 +3383,6 @@
|
|||
"supports_vision": true,
|
||||
"supports_response_schema": true,
|
||||
"supports_prompt_caching": true,
|
||||
"tpm": 4000000,
|
||||
"rpm": 2000,
|
||||
"source": "https://ai.google.dev/pricing"
|
||||
},
|
||||
"gemini/gemini-1.5-flash-001": {
|
||||
|
@ -3408,8 +3406,6 @@
|
|||
"supports_vision": true,
|
||||
"supports_response_schema": true,
|
||||
"supports_prompt_caching": true,
|
||||
"tpm": 4000000,
|
||||
"rpm": 2000,
|
||||
"source": "https://ai.google.dev/pricing"
|
||||
},
|
||||
"gemini/gemini-1.5-flash": {
|
||||
|
@ -3432,8 +3428,6 @@
|
|||
"supports_function_calling": true,
|
||||
"supports_vision": true,
|
||||
"supports_response_schema": true,
|
||||
"tpm": 4000000,
|
||||
"rpm": 2000,
|
||||
"source": "https://ai.google.dev/pricing"
|
||||
},
|
||||
"gemini/gemini-1.5-flash-latest": {
|
||||
|
@ -3456,32 +3450,6 @@
|
|||
"supports_function_calling": true,
|
||||
"supports_vision": true,
|
||||
"supports_response_schema": true,
|
||||
"tpm": 4000000,
|
||||
"rpm": 2000,
|
||||
"source": "https://ai.google.dev/pricing"
|
||||
},
|
||||
"gemini/gemini-1.5-flash-8b": {
|
||||
"max_tokens": 8192,
|
||||
"max_input_tokens": 1048576,
|
||||
"max_output_tokens": 8192,
|
||||
"max_images_per_prompt": 3000,
|
||||
"max_videos_per_prompt": 10,
|
||||
"max_video_length": 1,
|
||||
"max_audio_length_hours": 8.4,
|
||||
"max_audio_per_prompt": 1,
|
||||
"max_pdf_size_mb": 30,
|
||||
"input_cost_per_token": 0,
|
||||
"input_cost_per_token_above_128k_tokens": 0,
|
||||
"output_cost_per_token": 0,
|
||||
"output_cost_per_token_above_128k_tokens": 0,
|
||||
"litellm_provider": "gemini",
|
||||
"mode": "chat",
|
||||
"supports_system_messages": true,
|
||||
"supports_function_calling": true,
|
||||
"supports_vision": true,
|
||||
"supports_response_schema": true,
|
||||
"tpm": 4000000,
|
||||
"rpm": 4000,
|
||||
"source": "https://ai.google.dev/pricing"
|
||||
},
|
||||
"gemini/gemini-1.5-flash-8b-exp-0924": {
|
||||
|
@ -3504,8 +3472,6 @@
|
|||
"supports_function_calling": true,
|
||||
"supports_vision": true,
|
||||
"supports_response_schema": true,
|
||||
"tpm": 4000000,
|
||||
"rpm": 4000,
|
||||
"source": "https://ai.google.dev/pricing"
|
||||
},
|
||||
"gemini/gemini-exp-1114": {
|
||||
|
@ -3528,12 +3494,7 @@
|
|||
"supports_function_calling": true,
|
||||
"supports_vision": true,
|
||||
"supports_response_schema": true,
|
||||
"tpm": 4000000,
|
||||
"rpm": 1000,
|
||||
"source": "https://ai.google.dev/pricing",
|
||||
"metadata": {
|
||||
"notes": "Rate limits not documented for gemini-exp-1114. Assuming same as gemini-1.5-pro."
|
||||
}
|
||||
"source": "https://ai.google.dev/pricing"
|
||||
},
|
||||
"gemini/gemini-1.5-flash-exp-0827": {
|
||||
"max_tokens": 8192,
|
||||
|
@ -3555,8 +3516,6 @@
|
|||
"supports_function_calling": true,
|
||||
"supports_vision": true,
|
||||
"supports_response_schema": true,
|
||||
"tpm": 4000000,
|
||||
"rpm": 2000,
|
||||
"source": "https://ai.google.dev/pricing"
|
||||
},
|
||||
"gemini/gemini-1.5-flash-8b-exp-0827": {
|
||||
|
@ -3578,9 +3537,6 @@
|
|||
"supports_system_messages": true,
|
||||
"supports_function_calling": true,
|
||||
"supports_vision": true,
|
||||
"supports_response_schema": true,
|
||||
"tpm": 4000000,
|
||||
"rpm": 4000,
|
||||
"source": "https://ai.google.dev/pricing"
|
||||
},
|
||||
"gemini/gemini-pro": {
|
||||
|
@ -3594,10 +3550,7 @@
|
|||
"litellm_provider": "gemini",
|
||||
"mode": "chat",
|
||||
"supports_function_calling": true,
|
||||
"rpd": 30000,
|
||||
"tpm": 120000,
|
||||
"rpm": 360,
|
||||
"source": "https://ai.google.dev/gemini-api/docs/models/gemini"
|
||||
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
|
||||
},
|
||||
"gemini/gemini-1.5-pro": {
|
||||
"max_tokens": 8192,
|
||||
|
@ -3614,8 +3567,6 @@
|
|||
"supports_vision": true,
|
||||
"supports_tool_choice": true,
|
||||
"supports_response_schema": true,
|
||||
"tpm": 4000000,
|
||||
"rpm": 1000,
|
||||
"source": "https://ai.google.dev/pricing"
|
||||
},
|
||||
"gemini/gemini-1.5-pro-002": {
|
||||
|
@ -3634,8 +3585,6 @@
|
|||
"supports_tool_choice": true,
|
||||
"supports_response_schema": true,
|
||||
"supports_prompt_caching": true,
|
||||
"tpm": 4000000,
|
||||
"rpm": 1000,
|
||||
"source": "https://ai.google.dev/pricing"
|
||||
},
|
||||
"gemini/gemini-1.5-pro-001": {
|
||||
|
@ -3654,8 +3603,6 @@
|
|||
"supports_tool_choice": true,
|
||||
"supports_response_schema": true,
|
||||
"supports_prompt_caching": true,
|
||||
"tpm": 4000000,
|
||||
"rpm": 1000,
|
||||
"source": "https://ai.google.dev/pricing"
|
||||
},
|
||||
"gemini/gemini-1.5-pro-exp-0801": {
|
||||
|
@ -3673,8 +3620,6 @@
|
|||
"supports_vision": true,
|
||||
"supports_tool_choice": true,
|
||||
"supports_response_schema": true,
|
||||
"tpm": 4000000,
|
||||
"rpm": 1000,
|
||||
"source": "https://ai.google.dev/pricing"
|
||||
},
|
||||
"gemini/gemini-1.5-pro-exp-0827": {
|
||||
|
@ -3692,8 +3637,6 @@
|
|||
"supports_vision": true,
|
||||
"supports_tool_choice": true,
|
||||
"supports_response_schema": true,
|
||||
"tpm": 4000000,
|
||||
"rpm": 1000,
|
||||
"source": "https://ai.google.dev/pricing"
|
||||
},
|
||||
"gemini/gemini-1.5-pro-latest": {
|
||||
|
@ -3711,8 +3654,6 @@
|
|||
"supports_vision": true,
|
||||
"supports_tool_choice": true,
|
||||
"supports_response_schema": true,
|
||||
"tpm": 4000000,
|
||||
"rpm": 1000,
|
||||
"source": "https://ai.google.dev/pricing"
|
||||
},
|
||||
"gemini/gemini-pro-vision": {
|
||||
|
@ -3727,9 +3668,6 @@
|
|||
"mode": "chat",
|
||||
"supports_function_calling": true,
|
||||
"supports_vision": true,
|
||||
"rpd": 30000,
|
||||
"tpm": 120000,
|
||||
"rpm": 360,
|
||||
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
|
||||
},
|
||||
"gemini/gemini-gemma-2-27b-it": {
|
||||
|
|
|
@ -11,27 +11,7 @@ model_list:
|
|||
model: vertex_ai/claude-3-5-sonnet-v2
|
||||
vertex_ai_project: "adroit-crow-413218"
|
||||
vertex_ai_location: "us-east5"
|
||||
- model_name: openai-gpt-4o-realtime-audio
|
||||
litellm_params:
|
||||
model: openai/gpt-4o-realtime-preview-2024-10-01
|
||||
api_key: os.environ/OPENAI_API_KEY
|
||||
- model_name: openai/*
|
||||
litellm_params:
|
||||
model: openai/*
|
||||
api_key: os.environ/OPENAI_API_KEY
|
||||
- model_name: openai/*
|
||||
litellm_params:
|
||||
model: openai/*
|
||||
api_key: os.environ/OPENAI_API_KEY
|
||||
model_info:
|
||||
access_groups: ["public-openai-models"]
|
||||
- model_name: openai/gpt-4o
|
||||
litellm_params:
|
||||
model: openai/gpt-4o
|
||||
api_key: os.environ/OPENAI_API_KEY
|
||||
model_info:
|
||||
access_groups: ["private-openai-models"]
|
||||
|
||||
|
||||
router_settings:
|
||||
routing_strategy: usage-based-routing-v2
|
||||
#redis_url: "os.environ/REDIS_URL"
|
||||
|
|
|
@ -2183,11 +2183,3 @@ PassThroughEndpointLoggingResultValues = Union[
|
|||
class PassThroughEndpointLoggingTypedDict(TypedDict):
|
||||
result: Optional[PassThroughEndpointLoggingResultValues]
|
||||
kwargs: dict
|
||||
|
||||
|
||||
LiteLLM_ManagementEndpoint_MetadataFields = [
|
||||
"model_rpm_limit",
|
||||
"model_tpm_limit",
|
||||
"guardrails",
|
||||
"tags",
|
||||
]
|
||||
|
|
|
@ -60,7 +60,6 @@ def common_checks( # noqa: PLR0915
|
|||
global_proxy_spend: Optional[float],
|
||||
general_settings: dict,
|
||||
route: str,
|
||||
llm_router: Optional[litellm.Router],
|
||||
) -> bool:
|
||||
"""
|
||||
Common checks across jwt + key-based auth.
|
||||
|
@ -98,12 +97,7 @@ def common_checks( # noqa: PLR0915
|
|||
# this means the team has access to all models on the proxy
|
||||
pass
|
||||
# check if the team model is an access_group
|
||||
elif (
|
||||
model_in_access_group(
|
||||
model=_model, team_models=team_object.models, llm_router=llm_router
|
||||
)
|
||||
is True
|
||||
):
|
||||
elif model_in_access_group(_model, team_object.models) is True:
|
||||
pass
|
||||
elif _model and "*" in _model:
|
||||
pass
|
||||
|
@ -379,33 +373,36 @@ async def get_end_user_object(
|
|||
return None
|
||||
|
||||
|
||||
def model_in_access_group(
|
||||
model: str, team_models: Optional[List[str]], llm_router: Optional[litellm.Router]
|
||||
) -> bool:
|
||||
def model_in_access_group(model: str, team_models: Optional[List[str]]) -> bool:
|
||||
from collections import defaultdict
|
||||
|
||||
from litellm.proxy.proxy_server import llm_router
|
||||
|
||||
if team_models is None:
|
||||
return True
|
||||
if model in team_models:
|
||||
return True
|
||||
|
||||
access_groups: dict[str, list[str]] = defaultdict(list)
|
||||
access_groups = defaultdict(list)
|
||||
if llm_router:
|
||||
access_groups = llm_router.get_model_access_groups(model_name=model)
|
||||
access_groups = llm_router.get_model_access_groups()
|
||||
|
||||
models_in_current_access_groups = []
|
||||
if len(access_groups) > 0: # check if token contains any model access groups
|
||||
for idx, m in enumerate(
|
||||
team_models
|
||||
): # loop token models, if any of them are an access group add the access group
|
||||
if m in access_groups:
|
||||
return True
|
||||
# if it is an access group we need to remove it from valid_token.models
|
||||
models_in_group = access_groups[m]
|
||||
models_in_current_access_groups.extend(models_in_group)
|
||||
|
||||
# Filter out models that are access_groups
|
||||
filtered_models = [m for m in team_models if m not in access_groups]
|
||||
filtered_models += models_in_current_access_groups
|
||||
|
||||
if model in filtered_models:
|
||||
return True
|
||||
|
||||
return False
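# Example (illustrative): if the router defines an access group
#   "public-openai-models" -> ["openai/gpt-4o", "openai/gpt-4o-mini"]
# then a team with team_models=["public-openai-models"] is allowed to call
# "openai/gpt-4o-mini", because the group expands into the filtered model list checked above.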
|
||||
|
||||
|
||||
|
@ -526,6 +523,10 @@ async def _cache_management_object(
|
|||
proxy_logging_obj: Optional[ProxyLogging],
|
||||
):
|
||||
await user_api_key_cache.async_set_cache(key=key, value=value)
|
||||
if proxy_logging_obj is not None:
|
||||
await proxy_logging_obj.internal_usage_cache.dual_cache.async_set_cache(
|
||||
key=key, value=value
|
||||
)
|
||||
|
||||
|
||||
async def _cache_team_object(
|
||||
|
@ -877,10 +878,7 @@ async def get_org_object(
|
|||
|
||||
|
||||
async def can_key_call_model(
|
||||
model: str,
|
||||
llm_model_list: Optional[list],
|
||||
valid_token: UserAPIKeyAuth,
|
||||
llm_router: Optional[litellm.Router],
|
||||
model: str, llm_model_list: Optional[list], valid_token: UserAPIKeyAuth
|
||||
) -> Literal[True]:
|
||||
"""
|
||||
Checks if token can call a given model
|
||||
|
@ -900,29 +898,35 @@ async def can_key_call_model(
|
|||
)
|
||||
from collections import defaultdict
|
||||
|
||||
from litellm.proxy.proxy_server import llm_router
|
||||
|
||||
access_groups = defaultdict(list)
|
||||
if llm_router:
|
||||
access_groups = llm_router.get_model_access_groups(model_name=model)
|
||||
access_groups = llm_router.get_model_access_groups()
|
||||
|
||||
if (
|
||||
len(access_groups) > 0 and llm_router is not None
|
||||
): # check if token contains any model access groups
|
||||
models_in_current_access_groups = []
|
||||
if len(access_groups) > 0: # check if token contains any model access groups
|
||||
for idx, m in enumerate(
|
||||
valid_token.models
|
||||
): # loop token models, if any of them are an access group add the access group
|
||||
if m in access_groups:
|
||||
return True
|
||||
# if it is an access group we need to remove it from valid_token.models
|
||||
models_in_group = access_groups[m]
|
||||
models_in_current_access_groups.extend(models_in_group)
|
||||
|
||||
# Filter out models that are access_groups
|
||||
filtered_models = [m for m in valid_token.models if m not in access_groups]
|
||||
|
||||
filtered_models += models_in_current_access_groups
|
||||
verbose_proxy_logger.debug(f"model: {model}; allowed_models: {filtered_models}")
|
||||
|
||||
all_model_access: bool = False
|
||||
|
||||
if (
|
||||
len(filtered_models) == 0 and len(valid_token.models) == 0
|
||||
) or "*" in filtered_models:
|
||||
len(filtered_models) == 0
|
||||
or "*" in filtered_models
|
||||
or "openai/*" in filtered_models
|
||||
):
|
||||
all_model_access = True
|
||||
|
||||
if model is not None and model not in filtered_models and all_model_access is False:
|
||||
|
|
|
@ -28,8 +28,6 @@ from fastapi import (
|
|||
Request,
|
||||
Response,
|
||||
UploadFile,
|
||||
WebSocket,
|
||||
WebSocketDisconnect,
|
||||
status,
|
||||
)
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
@ -197,52 +195,6 @@ def _is_allowed_route(
|
|||
)
|
||||
|
||||
|
||||
async def user_api_key_auth_websocket(websocket: WebSocket):
|
||||
# Accept the WebSocket connection
|
||||
|
||||
request = Request(scope={"type": "http"})
|
||||
request._url = websocket.url
|
||||
|
||||
query_params = websocket.query_params
|
||||
|
||||
model = query_params.get("model")
|
||||
|
||||
async def return_body():
|
||||
return_string = f'{{"model": "{model}"}}'
|
||||
# return string as bytes
|
||||
return return_string.encode()
|
||||
|
||||
request.body = return_body # type: ignore
|
||||
|
||||
# Extract the Authorization header
|
||||
authorization = websocket.headers.get("authorization")
|
||||
|
||||
# If no Authorization header, try the api-key header
|
||||
if not authorization:
|
||||
api_key = websocket.headers.get("api-key")
|
||||
if not api_key:
|
||||
await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
|
||||
raise HTTPException(status_code=403, detail="No API key provided")
|
||||
else:
|
||||
# Extract the API key from the Bearer token
|
||||
if not authorization.startswith("Bearer "):
|
||||
await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
|
||||
raise HTTPException(
|
||||
status_code=403, detail="Invalid Authorization header format"
|
||||
)
|
||||
|
||||
api_key = authorization[len("Bearer ") :].strip()
|
||||
|
||||
# Call user_api_key_auth with the extracted API key
|
||||
# Note: You'll need to modify this to work with WebSocket context if needed
|
||||
try:
|
||||
return await user_api_key_auth(request=request, api_key=f"Bearer {api_key}")
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(e)
|
||||
await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
|
||||
raise HTTPException(status_code=403, detail=str(e))
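# Example (illustrative): a realtime WebSocket client authenticates the handshake with
# either header accepted above, e.g. "Authorization: Bearer sk-1234" or "api-key: sk-1234",
# and passes the target model as a query parameter (?model=...). The helper rebuilds a
# minimal HTTP request from those values and delegates to user_api_key_auth.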
|
||||
|
||||
|
||||
async def user_api_key_auth( # noqa: PLR0915
|
||||
request: Request,
|
||||
api_key: str = fastapi.Security(api_key_header),
|
||||
|
@ -259,7 +211,6 @@ async def user_api_key_auth( # noqa: PLR0915
|
|||
jwt_handler,
|
||||
litellm_proxy_admin_name,
|
||||
llm_model_list,
|
||||
llm_router,
|
||||
master_key,
|
||||
open_telemetry_logger,
|
||||
prisma_client,
|
||||
|
@ -543,7 +494,6 @@ async def user_api_key_auth( # noqa: PLR0915
|
|||
general_settings=general_settings,
|
||||
global_proxy_spend=global_proxy_spend,
|
||||
route=route,
|
||||
llm_router=llm_router,
|
||||
)
|
||||
|
||||
# return UserAPIKeyAuth object
|
||||
|
@ -907,7 +857,6 @@ async def user_api_key_auth( # noqa: PLR0915
|
|||
model=model,
|
||||
llm_model_list=llm_model_list,
|
||||
valid_token=valid_token,
|
||||
llm_router=llm_router,
|
||||
)
|
||||
|
||||
if fallback_models is not None:
|
||||
|
@ -916,7 +865,6 @@ async def user_api_key_auth( # noqa: PLR0915
|
|||
model=m,
|
||||
llm_model_list=llm_model_list,
|
||||
valid_token=valid_token,
|
||||
llm_router=llm_router,
|
||||
)
|
||||
|
||||
# Check 2. If user_id for this token is in budget - done in common_checks()
|
||||
|
@ -1177,7 +1125,6 @@ async def user_api_key_auth( # noqa: PLR0915
|
|||
general_settings=general_settings,
|
||||
global_proxy_spend=global_proxy_spend,
|
||||
route=route,
|
||||
llm_router=llm_router,
|
||||
)
|
||||
# Token passed all checks
|
||||
if valid_token is None:
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import ast
|
||||
import json
|
||||
from typing import Dict, List, Optional
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import Request, UploadFile, status
|
||||
|
||||
|
@ -8,43 +8,31 @@ from litellm._logging import verbose_proxy_logger
|
|||
from litellm.types.router import Deployment
|
||||
|
||||
|
||||
async def _read_request_body(request: Optional[Request]) -> Dict:
|
||||
async def _read_request_body(request: Optional[Request]) -> dict:
|
||||
"""
|
||||
Safely read the request body and parse it as JSON.
|
||||
Asynchronous function to read the request body and parse it as JSON or literal data.
|
||||
|
||||
Parameters:
|
||||
- request: The request object to read the body from
|
||||
|
||||
Returns:
|
||||
- dict: Parsed request data as a dictionary or an empty dictionary if parsing fails
|
||||
- dict: Parsed request data as a dictionary
|
||||
"""
|
||||
try:
|
||||
request_data: dict = {}
|
||||
if request is None:
|
||||
return {}
|
||||
|
||||
# Read the request body
|
||||
return request_data
|
||||
body = await request.body()
|
||||
|
||||
# Return empty dict if body is empty or None
|
||||
if not body:
|
||||
return {}
|
||||
|
||||
# Decode the body to a string
|
||||
if body == b"" or body is None:
|
||||
return request_data
|
||||
body_str = body.decode()
|
||||
|
||||
# Attempt JSON parsing (safe for untrusted input)
|
||||
return json.loads(body_str)
|
||||
|
||||
except json.JSONDecodeError:
|
||||
# Log detailed information for debugging
|
||||
verbose_proxy_logger.exception("Invalid JSON payload received.")
|
||||
return {}
|
||||
|
||||
except Exception as e:
|
||||
# Catch unexpected errors to avoid crashes
|
||||
verbose_proxy_logger.exception(
|
||||
"Unexpected error reading request body - {}".format(e)
|
||||
)
|
||||
try:
|
||||
request_data = ast.literal_eval(body_str)
|
||||
except Exception:
|
||||
request_data = json.loads(body_str)
|
||||
return request_data
|
||||
except Exception:
|
||||
return {}
|
||||
|
||||
|
||||
|
|
|
@ -214,10 +214,10 @@ class BedrockGuardrail(CustomGuardrail, BaseAWSLLM):
|
|||
prepared_request.url,
|
||||
prepared_request.headers,
|
||||
)
|
||||
|
||||
_json_data = json.dumps(request_data) # type: ignore
|
||||
response = await self.async_handler.post(
|
||||
url=prepared_request.url,
|
||||
data=prepared_request.body, # type: ignore
|
||||
json=request_data, # type: ignore
|
||||
headers=prepared_request.headers, # type: ignore
|
||||
)
|
||||
verbose_proxy_logger.debug("Bedrock AI response: %s", response.text)
|
||||
|
|
|
@ -288,12 +288,12 @@ class LiteLLMProxyRequestSetup:
|
|||
|
||||
## KEY-LEVEL SPEND LOGS / TAGS
|
||||
if "tags" in key_metadata and key_metadata["tags"] is not None:
|
||||
data[_metadata_variable_name]["tags"] = (
|
||||
LiteLLMProxyRequestSetup._merge_tags(
|
||||
request_tags=data[_metadata_variable_name].get("tags"),
|
||||
tags_to_add=key_metadata["tags"],
|
||||
)
|
||||
)
|
||||
if "tags" in data[_metadata_variable_name] and isinstance(
|
||||
data[_metadata_variable_name]["tags"], list
|
||||
):
|
||||
data[_metadata_variable_name]["tags"].extend(key_metadata["tags"])
|
||||
else:
|
||||
data[_metadata_variable_name]["tags"] = key_metadata["tags"]
|
||||
if "spend_logs_metadata" in key_metadata and isinstance(
|
||||
key_metadata["spend_logs_metadata"], dict
|
||||
):
|
||||
|
@ -319,30 +319,6 @@ class LiteLLMProxyRequestSetup:
|
|||
data["disable_fallbacks"] = key_metadata["disable_fallbacks"]
|
||||
return data
|
||||
|
||||
@staticmethod
|
||||
def _merge_tags(request_tags: Optional[list], tags_to_add: Optional[list]) -> list:
|
||||
"""
|
||||
Helper function to merge two lists of tags, ensuring no duplicates.
|
||||
|
||||
Args:
|
||||
request_tags (Optional[list]): List of tags from the original request
|
||||
tags_to_add (Optional[list]): List of tags to add
|
||||
|
||||
Returns:
|
||||
list: Combined list of unique tags
|
||||
"""
|
||||
final_tags = []
|
||||
|
||||
if request_tags and isinstance(request_tags, list):
|
||||
final_tags.extend(request_tags)
|
||||
|
||||
if tags_to_add and isinstance(tags_to_add, list):
|
||||
for tag in tags_to_add:
|
||||
if tag not in final_tags:
|
||||
final_tags.append(tag)
|
||||
|
||||
return final_tags
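# Example (illustrative):
#   _merge_tags(request_tags=["prod", "batch"], tags_to_add=["prod", "team-a"])
#   -> ["prod", "batch", "team-a"]   (order preserved, duplicates skipped)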
|
||||
|
||||

async def add_litellm_data_to_request(  # noqa: PLR0915
data: dict,

@ -466,10 +442,12 @@ async def add_litellm_data_to_request(  # noqa: PLR0915
## TEAM-LEVEL SPEND LOGS/TAGS
team_metadata = user_api_key_dict.team_metadata or {}
if "tags" in team_metadata and team_metadata["tags"] is not None:
data[_metadata_variable_name]["tags"] = LiteLLMProxyRequestSetup._merge_tags(
request_tags=data[_metadata_variable_name].get("tags"),
tags_to_add=team_metadata["tags"],
)
if "tags" in data[_metadata_variable_name] and isinstance(
data[_metadata_variable_name]["tags"], list
):
data[_metadata_variable_name]["tags"].extend(team_metadata["tags"])
else:
data[_metadata_variable_name]["tags"] = team_metadata["tags"]
if "spend_logs_metadata" in team_metadata and isinstance(
team_metadata["spend_logs_metadata"], dict
):
@ -32,7 +32,6 @@ from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.management_endpoints.key_management_endpoints import (
duration_in_seconds,
generate_key_helper_fn,
prepare_metadata_fields,
)
from litellm.proxy.management_helpers.utils import (
add_new_member,

@ -43,7 +42,7 @@ from litellm.proxy.utils import handle_exception_on_proxy
router = APIRouter()


def _update_internal_new_user_params(data_json: dict, data: NewUserRequest) -> dict:
def _update_internal_user_params(data_json: dict, data: NewUserRequest) -> dict:
if "user_id" in data_json and data_json["user_id"] is None:
data_json["user_id"] = str(uuid.uuid4())
auto_create_key = data_json.pop("auto_create_key", True)

@ -146,7 +145,7 @@ async def new_user(
from litellm.proxy.proxy_server import general_settings, proxy_logging_obj

data_json = data.json()  # type: ignore
data_json = _update_internal_new_user_params(data_json, data)
data_json = _update_internal_user_params(data_json, data)
response = await generate_key_helper_fn(request_type="user", **data_json)

# Admin UI Logic
|
@ -439,52 +438,6 @@ async def user_info( # noqa: PLR0915
|
|||
raise handle_exception_on_proxy(e)
|
||||
|
||||
|
||||
def _update_internal_user_params(data_json: dict, data: UpdateUserRequest) -> dict:
|
||||
non_default_values = {}
|
||||
for k, v in data_json.items():
|
||||
if (
|
||||
v is not None
|
||||
and v
|
||||
not in (
|
||||
[],
|
||||
{},
|
||||
0,
|
||||
)
|
||||
and k not in LiteLLM_ManagementEndpoint_MetadataFields
|
||||
): # models default to [], spend defaults to 0, we should not reset these values
|
||||
non_default_values[k] = v
|
||||
|
||||
is_internal_user = False
|
||||
if data.user_role == LitellmUserRoles.INTERNAL_USER:
|
||||
is_internal_user = True
|
||||
|
||||
if "budget_duration" in non_default_values:
|
||||
duration_s = duration_in_seconds(duration=non_default_values["budget_duration"])
|
||||
user_reset_at = datetime.now(timezone.utc) + timedelta(seconds=duration_s)
|
||||
non_default_values["budget_reset_at"] = user_reset_at
|
||||
|
||||
if "max_budget" not in non_default_values:
|
||||
if (
|
||||
is_internal_user and litellm.max_internal_user_budget is not None
|
||||
): # applies internal user limits, if user role updated
|
||||
non_default_values["max_budget"] = litellm.max_internal_user_budget
|
||||
|
||||
if (
|
||||
"budget_duration" not in non_default_values
|
||||
): # applies internal user limits, if user role updated
|
||||
if is_internal_user and litellm.internal_user_budget_duration is not None:
|
||||
non_default_values["budget_duration"] = (
|
||||
litellm.internal_user_budget_duration
|
||||
)
|
||||
duration_s = duration_in_seconds(
|
||||
duration=non_default_values["budget_duration"]
|
||||
)
|
||||
user_reset_at = datetime.now(timezone.utc) + timedelta(seconds=duration_s)
|
||||
non_default_values["budget_reset_at"] = user_reset_at
|
||||
|
||||
return non_default_values
|
||||
|
||||
|
||||
@router.post(
|
||||
"/user/update",
|
||||
tags=["Internal User management"],
|
||||
|
@ -506,8 +459,7 @@ async def user_update(
|
|||
"user_id": "test-litellm-user-4",
|
||||
"user_role": "proxy_admin_viewer"
|
||||
}'
|
||||
```
|
||||
|
||||
|
||||
Parameters:
|
||||
- user_id: Optional[str] - Specify a user id. If not set, a unique id will be generated.
|
||||
- user_email: Optional[str] - Specify a user email.
|
||||
|
@ -539,7 +491,7 @@ async def user_update(
|
|||
- duration: Optional[str] - [NOT IMPLEMENTED].
|
||||
- key_alias: Optional[str] - [NOT IMPLEMENTED].
|
||||
|
||||
|
||||
```
|
||||
"""
|
||||
from litellm.proxy.proxy_server import prisma_client
|
||||
|
||||
|
@ -550,21 +502,46 @@ async def user_update(
|
|||
raise Exception("Not connected to DB!")
|
||||
|
||||
# get non default values for key
|
||||
non_default_values = _update_internal_user_params(
|
||||
data_json=data_json, data=data
|
||||
)
|
||||
non_default_values = {}
|
||||
for k, v in data_json.items():
|
||||
if v is not None and v not in (
|
||||
[],
|
||||
{},
|
||||
0,
|
||||
): # models default to [], spend defaults to 0, we should not reset these values
|
||||
non_default_values[k] = v
|
||||
|
||||
existing_user_row = await prisma_client.get_data(
|
||||
user_id=data.user_id, table_name="user", query_type="find_unique"
|
||||
)
|
||||
is_internal_user = False
|
||||
if data.user_role == LitellmUserRoles.INTERNAL_USER:
|
||||
is_internal_user = True
|
||||
|
||||
existing_metadata = existing_user_row.metadata if existing_user_row else {}
|
||||
if "budget_duration" in non_default_values:
|
||||
duration_s = duration_in_seconds(
|
||||
duration=non_default_values["budget_duration"]
|
||||
)
|
||||
user_reset_at = datetime.now(timezone.utc) + timedelta(seconds=duration_s)
|
||||
non_default_values["budget_reset_at"] = user_reset_at
|
||||
|
||||
non_default_values = prepare_metadata_fields(
|
||||
data=data,
|
||||
non_default_values=non_default_values,
|
||||
existing_metadata=existing_metadata or {},
|
||||
)
|
||||
if "max_budget" not in non_default_values:
|
||||
if (
|
||||
is_internal_user and litellm.max_internal_user_budget is not None
|
||||
): # applies internal user limits, if user role updated
|
||||
non_default_values["max_budget"] = litellm.max_internal_user_budget
|
||||
|
||||
if (
|
||||
"budget_duration" not in non_default_values
|
||||
): # applies internal user limits, if user role updated
|
||||
if is_internal_user and litellm.internal_user_budget_duration is not None:
|
||||
non_default_values["budget_duration"] = (
|
||||
litellm.internal_user_budget_duration
|
||||
)
|
||||
duration_s = duration_in_seconds(
|
||||
duration=non_default_values["budget_duration"]
|
||||
)
|
||||
user_reset_at = datetime.now(timezone.utc) + timedelta(
|
||||
seconds=duration_s
|
||||
)
|
||||
non_default_values["budget_reset_at"] = user_reset_at
|
||||
|
||||
## ADD USER, IF NEW ##
|
||||
verbose_proxy_logger.debug("/user/update: Received data = %s", data)
|
||||
|
|
|
@ -17,7 +17,7 @@ import secrets
|
|||
import traceback
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import List, Optional, Tuple, cast
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import fastapi
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException, Query, Request, status
|
||||
|
@ -394,8 +394,7 @@ async def generate_key_fn( # noqa: PLR0915
|
|||
}
|
||||
)
|
||||
_budget_id = getattr(_budget, "budget_id", None)
|
||||
data_json = data.model_dump(exclude_unset=True, exclude_none=True) # type: ignore
|
||||
|
||||
data_json = data.json() # type: ignore
|
||||
# if we get max_budget passed to /key/generate, then use it as key_max_budget. Since generate_key_helper_fn is used to make new users
|
||||
if "max_budget" in data_json:
|
||||
data_json["key_max_budget"] = data_json.pop("max_budget", None)
|
||||
|
@ -421,11 +420,6 @@ async def generate_key_fn( # noqa: PLR0915
|
|||
|
||||
data_json.pop("tags")
|
||||
|
||||
await _enforce_unique_key_alias(
|
||||
key_alias=data_json.get("key_alias", None),
|
||||
prisma_client=prisma_client,
|
||||
)
|
||||
|
||||
response = await generate_key_helper_fn(
|
||||
request_type="key", **data_json, table_name="key"
|
||||
)
|
||||
|
@ -453,52 +447,12 @@ async def generate_key_fn( # noqa: PLR0915
|
|||
raise handle_exception_on_proxy(e)
|
||||
|
||||
|
||||
def prepare_metadata_fields(
    data: BaseModel, non_default_values: dict, existing_metadata: dict
) -> dict:
    """
    Check LiteLLM_ManagementEndpoint_MetadataFields (proxy/_types.py) for fields that are allowed to be updated
    """

    if "metadata" not in non_default_values:  # allow user to set metadata to none
        non_default_values["metadata"] = existing_metadata.copy()

    casted_metadata = cast(dict, non_default_values["metadata"])

    data_json = data.model_dump(exclude_unset=True, exclude_none=True)

    try:
        for k, v in data_json.items():
            if k == "model_tpm_limit" or k == "model_rpm_limit":
                if k not in casted_metadata or casted_metadata[k] is None:
                    casted_metadata[k] = {}
                casted_metadata[k].update(v)

            if k == "tags" or k == "guardrails":
                if k not in casted_metadata or casted_metadata[k] is None:
                    casted_metadata[k] = []
                seen = set(casted_metadata[k])
                casted_metadata[k].extend(
                    x for x in v if x not in seen and not seen.add(x)  # type: ignore
                )  # prevent duplicates from being added + maintain initial order

    except Exception as e:
        verbose_proxy_logger.exception(
            "litellm.proxy.proxy_server.prepare_metadata_fields(): Exception occured - {}".format(
                str(e)
            )
        )

    non_default_values["metadata"] = casted_metadata
    return non_default_values

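A standalone sketch (not litellm code, values made up) of the order-preserving de-duplication trick used in `prepare_metadata_fields` above: `seen.add(x)` returns `None`, so `not seen.add(x)` is always true and only serves to record `x` as seen, while the membership test filters out duplicates.

```python
existing = ["prod", "batch"]
incoming = ["batch", "team-alpha", "prod", "canary"]

seen = set(existing)
# append only unseen items, preserving the order they arrive in
existing.extend(x for x in incoming if x not in seen and not seen.add(x))

print(existing)  # ['prod', 'batch', 'team-alpha', 'canary']
```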
def prepare_key_update_data(
|
||||
data: Union[UpdateKeyRequest, RegenerateKeyRequest], existing_key_row
|
||||
):
|
||||
data_json: dict = data.model_dump(exclude_unset=True)
|
||||
data_json.pop("key", None)
|
||||
_metadata_fields = ["model_rpm_limit", "model_tpm_limit", "guardrails", "tags"]
|
||||
_metadata_fields = ["model_rpm_limit", "model_tpm_limit", "guardrails"]
|
||||
non_default_values = {}
|
||||
for k, v in data_json.items():
|
||||
if k in _metadata_fields:
|
||||
|
@ -522,13 +476,24 @@ def prepare_key_update_data(
|
|||
duration_s = duration_in_seconds(duration=budget_duration)
|
||||
key_reset_at = datetime.now(timezone.utc) + timedelta(seconds=duration_s)
|
||||
non_default_values["budget_reset_at"] = key_reset_at
|
||||
non_default_values["budget_duration"] = budget_duration
|
||||
|
||||
_metadata = existing_key_row.metadata or {}
|
||||
|
||||
non_default_values = prepare_metadata_fields(
|
||||
data=data, non_default_values=non_default_values, existing_metadata=_metadata
|
||||
)
|
||||
if data.model_tpm_limit:
|
||||
if "model_tpm_limit" not in _metadata:
|
||||
_metadata["model_tpm_limit"] = {}
|
||||
_metadata["model_tpm_limit"].update(data.model_tpm_limit)
|
||||
non_default_values["metadata"] = _metadata
|
||||
|
||||
if data.model_rpm_limit:
|
||||
if "model_rpm_limit" not in _metadata:
|
||||
_metadata["model_rpm_limit"] = {}
|
||||
_metadata["model_rpm_limit"].update(data.model_rpm_limit)
|
||||
non_default_values["metadata"] = _metadata
|
||||
|
||||
if data.guardrails:
|
||||
_metadata["guardrails"] = data.guardrails
|
||||
non_default_values["metadata"] = _metadata
|
||||
|
||||
return non_default_values
|
||||
|
||||
|
@ -620,12 +585,6 @@ async def update_key_fn(
|
|||
data=data, existing_key_row=existing_key_row
|
||||
)
|
||||
|
||||
await _enforce_unique_key_alias(
|
||||
key_alias=non_default_values.get("key_alias", None),
|
||||
prisma_client=prisma_client,
|
||||
existing_key_token=existing_key_row.token,
|
||||
)
|
||||
|
||||
response = await prisma_client.update_data(
|
||||
token=key, data={**non_default_values, "token": key}
|
||||
)
|
||||
|
@ -953,11 +912,11 @@ async def generate_key_helper_fn( # noqa: PLR0915
|
|||
request_type: Literal[
|
||||
"user", "key"
|
||||
], # identifies if this request is from /user/new or /key/generate
|
||||
duration: Optional[str] = None,
|
||||
models: list = [],
|
||||
aliases: dict = {},
|
||||
config: dict = {},
|
||||
spend: float = 0.0,
|
||||
duration: Optional[str],
|
||||
models: list,
|
||||
aliases: dict,
|
||||
config: dict,
|
||||
spend: float,
|
||||
key_max_budget: Optional[float] = None, # key_max_budget is used to Budget Per key
|
||||
key_budget_duration: Optional[str] = None,
|
||||
budget_id: Optional[float] = None, # budget id <-> LiteLLM_BudgetTable
|
||||
|
@ -986,8 +945,8 @@ async def generate_key_helper_fn( # noqa: PLR0915
|
|||
allowed_cache_controls: Optional[list] = [],
|
||||
permissions: Optional[dict] = {},
|
||||
model_max_budget: Optional[dict] = {},
|
||||
model_rpm_limit: Optional[dict] = None,
|
||||
model_tpm_limit: Optional[dict] = None,
|
||||
model_rpm_limit: Optional[dict] = {},
|
||||
model_tpm_limit: Optional[dict] = {},
|
||||
guardrails: Optional[list] = None,
|
||||
teams: Optional[list] = None,
|
||||
organization_id: Optional[str] = None,
|
||||
|
@ -1924,38 +1883,3 @@ async def test_key_logging(
|
|||
status="healthy",
|
||||
details=f"No logger exceptions triggered, system is healthy. Manually check if logs were sent to {logging_callbacks} ",
|
||||
)
|
||||
|
||||
|
||||
async def _enforce_unique_key_alias(
|
||||
key_alias: Optional[str],
|
||||
prisma_client: Any,
|
||||
existing_key_token: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Helper to enforce unique key aliases across all keys.
|
||||
|
||||
Args:
|
||||
key_alias (Optional[str]): The key alias to check
|
||||
prisma_client (Any): Prisma client instance
|
||||
existing_key_token (Optional[str]): ID of existing key being updated, to exclude from uniqueness check
|
||||
(The Admin UI passes key_alias, in all Edit key requests. So we need to be sure that if we find a key with the same alias, it's not the same key we're updating)
|
||||
|
||||
Raises:
|
||||
ProxyException: If key alias already exists on a different key
|
||||
"""
|
||||
if key_alias is not None and prisma_client is not None:
|
||||
where_clause: dict[str, Any] = {"key_alias": key_alias}
|
||||
if existing_key_token:
|
||||
# Exclude the current key from the uniqueness check
|
||||
where_clause["NOT"] = {"token": existing_key_token}
|
||||
|
||||
existing_key = await prisma_client.db.litellm_verificationtoken.find_first(
|
||||
where=where_clause
|
||||
)
|
||||
if existing_key is not None:
|
||||
raise ProxyException(
|
||||
message=f"Key with alias '{key_alias}' already exists. Unique key aliases across all keys are required.",
|
||||
type=ProxyErrorTypes.bad_request_error,
|
||||
param="key_alias",
|
||||
code=status.HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
|
|
|
@ -1367,7 +1367,6 @@ async def list_team(
|
|||
""".format(
|
||||
team.team_id, team.model_dump(), str(e)
|
||||
)
|
||||
verbose_proxy_logger.exception(team_exception)
|
||||
continue
|
||||
raise HTTPException(status_code=400, detail={"error": team_exception})
|
||||
|
||||
return returned_responses
|
||||
|
|
|
@ -134,10 +134,7 @@ from litellm.proxy.auth.model_checks import (
|
|||
get_key_models,
|
||||
get_team_models,
|
||||
)
|
||||
from litellm.proxy.auth.user_api_key_auth import (
|
||||
user_api_key_auth,
|
||||
user_api_key_auth_websocket,
|
||||
)
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
|
||||
## Import All Misc routes here ##
|
||||
from litellm.proxy.caching_routes import router as caching_router
|
||||
|
@ -4328,11 +4325,7 @@ from litellm import _arealtime
|
|||
|
||||
|
||||
@app.websocket("/v1/realtime")
|
||||
async def websocket_endpoint(
|
||||
websocket: WebSocket,
|
||||
model: str,
|
||||
user_api_key_dict=Depends(user_api_key_auth_websocket),
|
||||
):
|
||||
async def websocket_endpoint(websocket: WebSocket, model: str):
|
||||
import websockets
|
||||
|
||||
await websocket.accept()
|
||||
|
|
|
@ -86,6 +86,7 @@ async def route_request(
|
|||
else:
|
||||
models = [model.strip() for model in data.pop("model").split(",")]
|
||||
return llm_router.abatch_completion(models=models, **data)
|
||||
|
||||
elif llm_router is not None:
|
||||
if (
|
||||
data["model"] in router_model_names
|
||||
|
@ -112,9 +113,6 @@ async def route_request(
|
|||
or len(llm_router.pattern_router.patterns) > 0
|
||||
):
|
||||
return getattr(llm_router, f"{route_type}")(**data)
|
||||
elif route_type == "amoderation":
|
||||
# moderation endpoint does not require `model` parameter
|
||||
return getattr(llm_router, f"{route_type}")(**data)
|
||||
|
||||
elif user_model is not None:
|
||||
return getattr(litellm, f"{route_type}")(**data)
|
||||
|
|
|
@ -891,7 +891,7 @@ class ProxyLogging:
|
|||
original_exception: Exception,
|
||||
request: Request,
|
||||
parent_otel_span: Optional[Any],
|
||||
api_key: Optional[str],
|
||||
api_key: str,
|
||||
):
|
||||
"""
|
||||
Handler for Logging Authentication Errors on LiteLLM Proxy
|
||||
|
@ -905,13 +905,9 @@ class ProxyLogging:
|
|||
|
||||
user_api_key_dict = UserAPIKeyAuth(
|
||||
parent_otel_span=parent_otel_span,
|
||||
token=_hash_token_if_needed(token=api_key or ""),
|
||||
token=_hash_token_if_needed(token=api_key),
|
||||
)
|
||||
try:
|
||||
request_data = await request.json()
|
||||
except json.JSONDecodeError:
|
||||
# For GET requests or requests without a JSON body
|
||||
request_data = {}
|
||||
request_data = await request.json()
|
||||
await self._run_post_call_failure_hook_custom_loggers(
|
||||
original_exception=original_exception,
|
||||
request_data=request_data,
|
||||
|
|
|
@ -41,7 +41,6 @@ from typing import (
|
|||
import httpx
|
||||
import openai
|
||||
from openai import AsyncOpenAI
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import overload
|
||||
|
||||
import litellm
|
||||
|
@ -123,7 +122,6 @@ from litellm.types.router import (
|
|||
ModelInfo,
|
||||
ProviderBudgetConfigType,
|
||||
RetryPolicy,
|
||||
RouterCacheEnum,
|
||||
RouterErrors,
|
||||
RouterGeneralSettings,
|
||||
RouterModelGroupAliasItem,
|
||||
|
@ -241,6 +239,7 @@ class Router:
|
|||
] = "simple-shuffle",
|
||||
routing_strategy_args: dict = {}, # just for latency-based
|
||||
provider_budget_config: Optional[ProviderBudgetConfigType] = None,
|
||||
semaphore: Optional[asyncio.Semaphore] = None,
|
||||
alerting_config: Optional[AlertingConfig] = None,
|
||||
router_general_settings: Optional[
|
||||
RouterGeneralSettings
|
||||
|
@ -316,6 +315,8 @@ class Router:
|
|||
|
||||
from litellm._service_logger import ServiceLogging
|
||||
|
||||
if semaphore:
|
||||
self.semaphore = semaphore
|
||||
self.set_verbose = set_verbose
|
||||
self.debug_level = debug_level
|
||||
self.enable_pre_call_checks = enable_pre_call_checks
|
||||
|
@ -505,14 +506,6 @@ class Router:
|
|||
litellm.success_callback.append(self.sync_deployment_callback_on_success)
|
||||
else:
|
||||
litellm.success_callback = [self.sync_deployment_callback_on_success]
|
||||
if isinstance(litellm._async_failure_callback, list):
|
||||
litellm._async_failure_callback.append(
|
||||
self.async_deployment_callback_on_failure
|
||||
)
|
||||
else:
|
||||
litellm._async_failure_callback = [
|
||||
self.async_deployment_callback_on_failure
|
||||
]
|
||||
## COOLDOWNS ##
|
||||
if isinstance(litellm.failure_callback, list):
|
||||
litellm.failure_callback.append(self.deployment_callback_on_failure)
|
||||
|
@ -2563,7 +2556,10 @@ class Router:
|
|||
original_function: Callable,
|
||||
**kwargs,
|
||||
):
|
||||
if kwargs.get("model") and self.get_model_list(model_name=kwargs["model"]):
|
||||
if (
|
||||
"model" in kwargs
|
||||
and self.get_model_list(model_name=kwargs["model"]) is not None
|
||||
):
|
||||
deployment = await self.async_get_available_deployment(
|
||||
model=kwargs["model"]
|
||||
)
|
||||
|
@ -3295,14 +3291,13 @@ class Router:
|
|||
):
|
||||
"""
|
||||
Track remaining tpm/rpm quota for model in model_list
|
||||
|
||||
Currently, only updates TPM usage.
|
||||
"""
|
||||
try:
|
||||
if kwargs["litellm_params"].get("metadata") is None:
|
||||
pass
|
||||
else:
|
||||
deployment_name = kwargs["litellm_params"]["metadata"].get(
|
||||
"deployment", None
|
||||
) # stable name - works for wildcard routes as well
|
||||
model_group = kwargs["litellm_params"]["metadata"].get(
|
||||
"model_group", None
|
||||
)
|
||||
|
@ -3313,8 +3308,6 @@ class Router:
|
|||
elif isinstance(id, int):
|
||||
id = str(id)
|
||||
|
||||
parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
|
||||
|
||||
_usage_obj = completion_response.get("usage")
|
||||
total_tokens = _usage_obj.get("total_tokens", 0) if _usage_obj else 0
|
||||
|
||||
|
@ -3326,14 +3319,13 @@ class Router:
|
|||
"%H-%M"
|
||||
) # use the same timezone regardless of system clock
|
||||
|
||||
tpm_key = RouterCacheEnum.TPM.value.format(
|
||||
id=id, current_minute=current_minute, model=deployment_name
|
||||
)
|
||||
tpm_key = f"global_router:{id}:tpm:{current_minute}"
|
||||
# ------------
|
||||
# Update usage
|
||||
# ------------
|
||||
# update cache
|
||||
|
||||
parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
|
||||
## TPM
|
||||
await self.cache.async_increment_cache(
|
||||
key=tpm_key,
|
||||
|
@ -3342,17 +3334,6 @@ class Router:
|
|||
ttl=RoutingArgs.ttl.value,
|
||||
)
|
||||
|
||||
## RPM
|
||||
rpm_key = RouterCacheEnum.RPM.value.format(
|
||||
id=id, current_minute=current_minute, model=deployment_name
|
||||
)
|
||||
await self.cache.async_increment_cache(
|
||||
key=rpm_key,
|
||||
value=1,
|
||||
parent_otel_span=parent_otel_span,
|
||||
ttl=RoutingArgs.ttl.value,
|
||||
)
|
||||
|
||||
increment_deployment_successes_for_current_minute(
|
||||
litellm_router_instance=self,
|
||||
deployment_id=id,
|
||||
|
@ -3465,40 +3446,6 @@ class Router:
|
|||
except Exception as e:
|
||||
raise e
|
||||
|
||||
async def async_deployment_callback_on_failure(
|
||||
self, kwargs, completion_response: Optional[Any], start_time, end_time
|
||||
):
|
||||
"""
|
||||
Update RPM usage for a deployment
|
||||
"""
|
||||
deployment_name = kwargs["litellm_params"]["metadata"].get(
|
||||
"deployment", None
|
||||
) # handles wildcard routes - by giving the original name sent to `litellm.completion`
|
||||
model_group = kwargs["litellm_params"]["metadata"].get("model_group", None)
|
||||
model_info = kwargs["litellm_params"].get("model_info", {}) or {}
|
||||
id = model_info.get("id", None)
|
||||
if model_group is None or id is None:
|
||||
return
|
||||
elif isinstance(id, int):
|
||||
id = str(id)
|
||||
parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
|
||||
|
||||
dt = get_utc_datetime()
|
||||
current_minute = dt.strftime(
|
||||
"%H-%M"
|
||||
) # use the same timezone regardless of system clock
|
||||
|
||||
## RPM
|
||||
rpm_key = RouterCacheEnum.RPM.value.format(
|
||||
id=id, current_minute=current_minute, model=deployment_name
|
||||
)
|
||||
await self.cache.async_increment_cache(
|
||||
key=rpm_key,
|
||||
value=1,
|
||||
parent_otel_span=parent_otel_span,
|
||||
ttl=RoutingArgs.ttl.value,
|
||||
)
|
||||
|
||||
def log_retry(self, kwargs: dict, e: Exception) -> dict:
|
||||
"""
|
||||
When a retry or fallback happens, log the details of the just failed model call - similar to Sentry breadcrumbing
|
||||
|
@ -4176,24 +4123,7 @@ class Router:
|
|||
raise Exception("Model Name invalid - {}".format(type(model)))
|
||||
return None
|
||||
|
||||
@overload
|
||||
def get_router_model_info(
|
||||
self, deployment: dict, received_model_name: str, id: None = None
|
||||
) -> ModelMapInfo:
|
||||
pass
|
||||
|
||||
@overload
|
||||
def get_router_model_info(
|
||||
self, deployment: None, received_model_name: str, id: str
|
||||
) -> ModelMapInfo:
|
||||
pass
|
||||
|
||||
def get_router_model_info(
|
||||
self,
|
||||
deployment: Optional[dict],
|
||||
received_model_name: str,
|
||||
id: Optional[str] = None,
|
||||
) -> ModelMapInfo:
|
||||
def get_router_model_info(self, deployment: dict) -> ModelMapInfo:
|
||||
"""
|
||||
For a given model id, return the model info (max tokens, input cost, output cost, etc.).
|
||||
|
||||
|
@ -4207,14 +4137,6 @@ class Router:
|
|||
Raises:
|
||||
- ValueError -> If model is not mapped yet
|
||||
"""
|
||||
if id is not None:
|
||||
_deployment = self.get_deployment(model_id=id)
|
||||
if _deployment is not None:
|
||||
deployment = _deployment.model_dump(exclude_none=True)
|
||||
|
||||
if deployment is None:
|
||||
raise ValueError("Deployment not found")
|
||||
|
||||
## GET BASE MODEL
|
||||
base_model = deployment.get("model_info", {}).get("base_model", None)
|
||||
if base_model is None:
|
||||
|
@ -4236,27 +4158,10 @@ class Router:
|
|||
elif custom_llm_provider != "azure":
|
||||
model = _model
|
||||
|
||||
potential_models = self.pattern_router.route(received_model_name)
|
||||
if "*" in model and potential_models is not None: # if wildcard route
|
||||
for potential_model in potential_models:
|
||||
try:
|
||||
if potential_model.get("model_info", {}).get(
|
||||
"id"
|
||||
) == deployment.get("model_info", {}).get("id"):
|
||||
model = potential_model.get("litellm_params", {}).get(
|
||||
"model"
|
||||
)
|
||||
break
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
## GET LITELLM MODEL INFO - raises exception, if model is not mapped
|
||||
if not model.startswith(custom_llm_provider):
|
||||
model_info_name = "{}/{}".format(custom_llm_provider, model)
|
||||
else:
|
||||
model_info_name = model
|
||||
|
||||
model_info = litellm.get_model_info(model=model_info_name)
|
||||
model_info = litellm.get_model_info(
|
||||
model="{}/{}".format(custom_llm_provider, model)
|
||||
)
|
||||
|
||||
## CHECK USER SET MODEL INFO
|
||||
user_model_info = deployment.get("model_info", {})
|
||||
|
@ -4306,10 +4211,8 @@ class Router:
|
|||
total_tpm: Optional[int] = None
|
||||
total_rpm: Optional[int] = None
|
||||
configurable_clientside_auth_params: CONFIGURABLE_CLIENTSIDE_AUTH_PARAMS = None
|
||||
model_list = self.get_model_list(model_name=model_group)
|
||||
if model_list is None:
|
||||
return None
|
||||
for model in model_list:
|
||||
|
||||
for model in self.model_list:
|
||||
is_match = False
|
||||
if (
|
||||
"model_name" in model and model["model_name"] == model_group
|
||||
|
@ -4324,7 +4227,7 @@ class Router:
|
|||
if not is_match:
|
||||
continue
|
||||
# model in model group found #
|
||||
litellm_params = LiteLLM_Params(**model["litellm_params"]) # type: ignore
|
||||
litellm_params = LiteLLM_Params(**model["litellm_params"])
|
||||
# get configurable clientside auth params
|
||||
configurable_clientside_auth_params = (
|
||||
litellm_params.configurable_clientside_auth_params
|
||||
|
@ -4332,30 +4235,38 @@ class Router:
|
|||
# get model tpm
|
||||
_deployment_tpm: Optional[int] = None
|
||||
if _deployment_tpm is None:
|
||||
_deployment_tpm = model.get("tpm", None) # type: ignore
|
||||
_deployment_tpm = model.get("tpm", None)
|
||||
if _deployment_tpm is None:
|
||||
_deployment_tpm = model.get("litellm_params", {}).get("tpm", None) # type: ignore
|
||||
_deployment_tpm = model.get("litellm_params", {}).get("tpm", None)
|
||||
if _deployment_tpm is None:
|
||||
_deployment_tpm = model.get("model_info", {}).get("tpm", None) # type: ignore
|
||||
_deployment_tpm = model.get("model_info", {}).get("tpm", None)
|
||||
|
||||
if _deployment_tpm is not None:
|
||||
if total_tpm is None:
|
||||
total_tpm = 0
|
||||
total_tpm += _deployment_tpm # type: ignore
|
||||
# get model rpm
|
||||
_deployment_rpm: Optional[int] = None
|
||||
if _deployment_rpm is None:
|
||||
_deployment_rpm = model.get("rpm", None) # type: ignore
|
||||
_deployment_rpm = model.get("rpm", None)
|
||||
if _deployment_rpm is None:
|
||||
_deployment_rpm = model.get("litellm_params", {}).get("rpm", None) # type: ignore
|
||||
_deployment_rpm = model.get("litellm_params", {}).get("rpm", None)
|
||||
if _deployment_rpm is None:
|
||||
_deployment_rpm = model.get("model_info", {}).get("rpm", None) # type: ignore
|
||||
_deployment_rpm = model.get("model_info", {}).get("rpm", None)
|
||||
|
||||
if _deployment_rpm is not None:
|
||||
if total_rpm is None:
|
||||
total_rpm = 0
|
||||
total_rpm += _deployment_rpm # type: ignore
|
||||
# get model info
|
||||
try:
|
||||
model_info = litellm.get_model_info(model=litellm_params.model)
|
||||
except Exception:
|
||||
model_info = None
|
||||
# get llm provider
|
||||
litellm_model, llm_provider = "", ""
|
||||
model, llm_provider = "", ""
|
||||
try:
|
||||
litellm_model, llm_provider, _, _ = litellm.get_llm_provider(
|
||||
model, llm_provider, _, _ = litellm.get_llm_provider(
|
||||
model=litellm_params.model,
|
||||
custom_llm_provider=litellm_params.custom_llm_provider,
|
||||
)
|
||||
|
@ -4366,7 +4277,7 @@ class Router:
|
|||
|
||||
if model_info is None:
|
||||
supported_openai_params = litellm.get_supported_openai_params(
|
||||
model=litellm_model, custom_llm_provider=llm_provider
|
||||
model=model, custom_llm_provider=llm_provider
|
||||
)
|
||||
if supported_openai_params is None:
|
||||
supported_openai_params = []
|
||||
|
@ -4456,20 +4367,7 @@ class Router:
|
|||
model_group_info.supported_openai_params = model_info[
|
||||
"supported_openai_params"
|
||||
]
|
||||
if model_info.get("tpm", None) is not None and _deployment_tpm is None:
|
||||
_deployment_tpm = model_info.get("tpm")
|
||||
if model_info.get("rpm", None) is not None and _deployment_rpm is None:
|
||||
_deployment_rpm = model_info.get("rpm")
|
||||
|
||||
if _deployment_tpm is not None:
|
||||
if total_tpm is None:
|
||||
total_tpm = 0
|
||||
total_tpm += _deployment_tpm # type: ignore
|
||||
|
||||
if _deployment_rpm is not None:
|
||||
if total_rpm is None:
|
||||
total_rpm = 0
|
||||
total_rpm += _deployment_rpm # type: ignore
|
||||
if model_group_info is not None:
|
||||
## UPDATE WITH TOTAL TPM/RPM FOR MODEL GROUP
|
||||
if total_tpm is not None:
|
||||
|
@ -4521,10 +4419,7 @@ class Router:
|
|||
self, model_group: str
|
||||
) -> Tuple[Optional[int], Optional[int]]:
|
||||
"""
|
||||
Returns current tpm/rpm usage for model group
|
||||
|
||||
Parameters:
|
||||
- model_group: str - the received model name from the user (can be a wildcard route).
|
||||
Returns remaining tpm/rpm quota for model group
|
||||
|
||||
Returns:
|
||||
- usage: Tuple[tpm, rpm]
|
||||
|
@ -4535,37 +4430,20 @@ class Router:
|
|||
) # use the same timezone regardless of system clock
|
||||
tpm_keys: List[str] = []
|
||||
rpm_keys: List[str] = []
|
||||
|
||||
model_list = self.get_model_list(model_name=model_group)
|
||||
if model_list is None: # no matching deployments
|
||||
return None, None
|
||||
|
||||
for model in model_list:
|
||||
id: Optional[str] = model.get("model_info", {}).get("id") # type: ignore
|
||||
litellm_model: Optional[str] = model["litellm_params"].get(
|
||||
"model"
|
||||
) # USE THE MODEL SENT TO litellm.completion() - consistent with how global_router cache is written.
|
||||
if id is None or litellm_model is None:
|
||||
continue
|
||||
tpm_keys.append(
|
||||
RouterCacheEnum.TPM.value.format(
|
||||
id=id,
|
||||
model=litellm_model,
|
||||
current_minute=current_minute,
|
||||
for model in self.model_list:
|
||||
if "model_name" in model and model["model_name"] == model_group:
|
||||
tpm_keys.append(
|
||||
f"global_router:{model['model_info']['id']}:tpm:{current_minute}"
|
||||
)
|
||||
)
|
||||
rpm_keys.append(
|
||||
RouterCacheEnum.RPM.value.format(
|
||||
id=id,
|
||||
model=litellm_model,
|
||||
current_minute=current_minute,
|
||||
rpm_keys.append(
|
||||
f"global_router:{model['model_info']['id']}:rpm:{current_minute}"
|
||||
)
|
||||
)
|
||||
combined_tpm_rpm_keys = tpm_keys + rpm_keys
|
||||
|
||||
combined_tpm_rpm_values = await self.cache.async_batch_get_cache(
|
||||
keys=combined_tpm_rpm_keys
|
||||
)
|
||||
|
||||
if combined_tpm_rpm_values is None:
|
||||
return None, None
|
||||
|
||||
|
@ -4590,32 +4468,6 @@ class Router:
|
|||
rpm_usage += t
|
||||
return tpm_usage, rpm_usage
|
||||
|
||||
    async def get_remaining_model_group_usage(self, model_group: str) -> Dict[str, int]:

        current_tpm, current_rpm = await self.get_model_group_usage(model_group)

        model_group_info = self.get_model_group_info(model_group)

        if model_group_info is not None and model_group_info.tpm is not None:
            tpm_limit = model_group_info.tpm
        else:
            tpm_limit = None

        if model_group_info is not None and model_group_info.rpm is not None:
            rpm_limit = model_group_info.rpm
        else:
            rpm_limit = None

        returned_dict = {}
        if tpm_limit is not None and current_tpm is not None:
            returned_dict["x-ratelimit-remaining-tokens"] = tpm_limit - current_tpm
            returned_dict["x-ratelimit-limit-tokens"] = tpm_limit
        if rpm_limit is not None and current_rpm is not None:
            returned_dict["x-ratelimit-remaining-requests"] = rpm_limit - current_rpm
            returned_dict["x-ratelimit-limit-requests"] = rpm_limit

        return returned_dict

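For illustration only (all numbers hypothetical), the header math in `get_remaining_model_group_usage` above resolves to a dict like this for a model group with tpm=4,000,000 / rpm=1,000 limits and 150,000 tokens / 42 requests used in the current minute:

```python
headers = {
    "x-ratelimit-remaining-tokens": 4_000_000 - 150_000,  # 3_850_000
    "x-ratelimit-limit-tokens": 4_000_000,
    "x-ratelimit-remaining-requests": 1_000 - 42,          # 958
    "x-ratelimit-limit-requests": 1_000,
}
```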
async def set_response_headers(
|
||||
self, response: Any, model_group: Optional[str] = None
|
||||
) -> Any:
|
||||
|
@ -4626,30 +4478,6 @@ class Router:
|
|||
# - if healthy_deployments > 1, return model group rate limit headers
|
||||
# - else return the model's rate limit headers
|
||||
"""
|
||||
if (
|
||||
isinstance(response, BaseModel)
|
||||
and hasattr(response, "_hidden_params")
|
||||
and isinstance(response._hidden_params, dict) # type: ignore
|
||||
):
|
||||
response._hidden_params.setdefault("additional_headers", {}) # type: ignore
|
||||
response._hidden_params["additional_headers"][ # type: ignore
|
||||
"x-litellm-model-group"
|
||||
] = model_group
|
||||
|
||||
additional_headers = response._hidden_params["additional_headers"] # type: ignore
|
||||
|
||||
if (
|
||||
"x-ratelimit-remaining-tokens" not in additional_headers
|
||||
and "x-ratelimit-remaining-requests" not in additional_headers
|
||||
and model_group is not None
|
||||
):
|
||||
remaining_usage = await self.get_remaining_model_group_usage(
|
||||
model_group
|
||||
)
|
||||
|
||||
for header, value in remaining_usage.items():
|
||||
if value is not None:
|
||||
additional_headers[header] = value
|
||||
return response
|
||||
|
||||
def get_model_ids(self, model_name: Optional[str] = None) -> List[str]:
|
||||
|
@ -4712,9 +4540,6 @@ class Router:
|
|||
if hasattr(self, "model_list"):
|
||||
returned_models: List[DeploymentTypedDict] = []
|
||||
|
||||
if model_name is not None:
|
||||
returned_models.extend(self._get_all_deployments(model_name=model_name))
|
||||
|
||||
if hasattr(self, "model_group_alias"):
|
||||
for model_alias, model_value in self.model_group_alias.items():
|
||||
|
||||
|
@ -4735,32 +4560,21 @@ class Router:
|
|||
)
|
||||
)
|
||||
|
||||
if len(returned_models) == 0: # check if wildcard route
|
||||
potential_wildcard_models = self.pattern_router.route(model_name)
|
||||
if potential_wildcard_models is not None:
|
||||
returned_models.extend(
|
||||
[DeploymentTypedDict(**m) for m in potential_wildcard_models] # type: ignore
|
||||
)
|
||||
|
||||
if model_name is None:
|
||||
returned_models += self.model_list
|
||||
|
||||
return returned_models
|
||||
|
||||
returned_models.extend(self._get_all_deployments(model_name=model_name))
|
||||
return returned_models
|
||||
return None
|
||||
|
||||
def get_model_access_groups(self, model_name: Optional[str] = None):
|
||||
"""
|
||||
If model_name is provided, only return access groups for that model.
|
||||
"""
|
||||
def get_model_access_groups(self):
|
||||
from collections import defaultdict
|
||||
|
||||
access_groups = defaultdict(list)
|
||||
|
||||
model_list = self.get_model_list(model_name=model_name)
|
||||
if model_list:
|
||||
for m in model_list:
|
||||
if self.model_list:
|
||||
for m in self.model_list:
|
||||
for group in m.get("model_info", {}).get("access_groups", []):
|
||||
model_name = m["model_name"]
|
||||
access_groups[group].append(model_name)
|
||||
|
@ -4996,12 +4810,10 @@ class Router:
|
|||
base_model = deployment.get("litellm_params", {}).get(
|
||||
"base_model", None
|
||||
)
|
||||
model_info = self.get_router_model_info(
|
||||
deployment=deployment, received_model_name=model
|
||||
)
|
||||
model = base_model or deployment.get("litellm_params", {}).get(
|
||||
"model", None
|
||||
)
|
||||
model_info = self.get_router_model_info(deployment=deployment)
|
||||
|
||||
if (
|
||||
isinstance(model_info, dict)
|
||||
|
|
|
@ -79,9 +79,7 @@ class PatternMatchRouter:
|
|||
|
||||
return new_deployments
|
||||
|
||||
def route(
|
||||
self, request: Optional[str], filtered_model_names: Optional[List[str]] = None
|
||||
) -> Optional[List[Dict]]:
|
||||
def route(self, request: Optional[str]) -> Optional[List[Dict]]:
|
||||
"""
|
||||
Route a requested model to the corresponding llm deployments based on the regex pattern
|
||||
|
||||
|
@ -91,26 +89,14 @@ class PatternMatchRouter:
|
|||
|
||||
Args:
|
||||
request: Optional[str]
|
||||
filtered_model_names: Optional[List[str]] - if provided, only return deployments that match the filtered_model_names
|
||||
|
||||
Returns:
|
||||
Optional[List[Deployment]]: llm deployments
|
||||
"""
|
||||
try:
|
||||
if request is None:
|
||||
return None
|
||||
|
||||
regex_filtered_model_names = (
|
||||
[self._pattern_to_regex(m) for m in filtered_model_names]
|
||||
if filtered_model_names is not None
|
||||
else []
|
||||
)
|
||||
|
||||
for pattern, llm_deployments in self.patterns.items():
|
||||
if (
|
||||
filtered_model_names is not None
|
||||
and pattern not in regex_filtered_model_names
|
||||
):
|
||||
continue
|
||||
pattern_match = re.match(pattern, request)
|
||||
if pattern_match:
|
||||
return self._return_pattern_matched_deployments(
|
||||
|
|
litellm/tests/test_mlflow.py (new file, 29 lines)
@ -0,0 +1,29 @@
import pytest

import litellm


def test_mlflow_logging():
    litellm.success_callback = ["mlflow"]
    litellm.failure_callback = ["mlflow"]

    litellm.completion(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": "what llm are u"}],
        max_tokens=10,
        temperature=0.2,
        user="test-user",
    )


@pytest.mark.asyncio()
async def test_async_mlflow_logging():
    litellm.success_callback = ["mlflow"]
    litellm.failure_callback = ["mlflow"]

    await litellm.acompletion(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": "hi test from local arize"}],
        mock_response="hello",
        temperature=0.1,
        user="OTEL_USER",
    )
@ -9,7 +9,7 @@ from typing import Any, Dict, List, Literal, Optional, Tuple, Union
|
|||
|
||||
import httpx
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from typing_extensions import Required, TypedDict
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from ..exceptions import RateLimitError
|
||||
from .completion import CompletionRequest
|
||||
|
@ -352,10 +352,9 @@ class LiteLLMParamsTypedDict(TypedDict, total=False):
|
|||
tags: Optional[List[str]]
|
||||
|
||||
|
||||
class DeploymentTypedDict(TypedDict, total=False):
|
||||
model_name: Required[str]
|
||||
litellm_params: Required[LiteLLMParamsTypedDict]
|
||||
model_info: dict
|
||||
class DeploymentTypedDict(TypedDict):
|
||||
model_name: str
|
||||
litellm_params: LiteLLMParamsTypedDict
|
||||
|
||||
|
||||
SPECIAL_MODEL_INFO_PARAMS = [
|
||||
|
@ -641,8 +640,3 @@ class ProviderBudgetInfo(BaseModel):
|
|||
|
||||
|
||||
ProviderBudgetConfigType = Dict[str, ProviderBudgetInfo]
|
||||
|
||||
|
||||
class RouterCacheEnum(enum.Enum):
    TPM = "global_router:{id}:{model}:tpm:{current_minute}"
    RPM = "global_router:{id}:{model}:rpm:{current_minute}"
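A quick illustration of how the enum templates above expand into cache keys (the deployment id, model name, and minute bucket are made-up values; the import path is inferred from the surrounding hunks):

```python
from litellm.types.router import RouterCacheEnum  # path assumed from this diff

tpm_key = RouterCacheEnum.TPM.value.format(
    id="model-id-123", model="gpt-4o-mini", current_minute="14-32"
)
print(tpm_key)  # global_router:model-id-123:gpt-4o-mini:tpm:14-32
```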
@ -106,8 +106,6 @@ class ModelInfo(TypedDict, total=False):
|
|||
supports_prompt_caching: Optional[bool]
|
||||
supports_audio_input: Optional[bool]
|
||||
supports_audio_output: Optional[bool]
|
||||
tpm: Optional[int]
|
||||
rpm: Optional[int]
|
||||
|
||||
|
||||
class GenericStreamingChunk(TypedDict, total=False):
|
||||
|
|
|
@ -4656,8 +4656,6 @@ def get_model_info( # noqa: PLR0915
|
|||
),
|
||||
supports_audio_input=_model_info.get("supports_audio_input", False),
|
||||
supports_audio_output=_model_info.get("supports_audio_output", False),
|
||||
tpm=_model_info.get("tpm", None),
|
||||
rpm=_model_info.get("rpm", None),
|
||||
)
|
||||
except Exception as e:
|
||||
if "OllamaError" in str(e):
|
||||
|
|
|
@ -3383,8 +3383,6 @@
|
|||
"supports_vision": true,
|
||||
"supports_response_schema": true,
|
||||
"supports_prompt_caching": true,
|
||||
"tpm": 4000000,
|
||||
"rpm": 2000,
|
||||
"source": "https://ai.google.dev/pricing"
|
||||
},
|
||||
"gemini/gemini-1.5-flash-001": {
|
||||
|
@ -3408,8 +3406,6 @@
|
|||
"supports_vision": true,
|
||||
"supports_response_schema": true,
|
||||
"supports_prompt_caching": true,
|
||||
"tpm": 4000000,
|
||||
"rpm": 2000,
|
||||
"source": "https://ai.google.dev/pricing"
|
||||
},
|
||||
"gemini/gemini-1.5-flash": {
|
||||
|
@ -3432,8 +3428,6 @@
|
|||
"supports_function_calling": true,
|
||||
"supports_vision": true,
|
||||
"supports_response_schema": true,
|
||||
"tpm": 4000000,
|
||||
"rpm": 2000,
|
||||
"source": "https://ai.google.dev/pricing"
|
||||
},
|
||||
"gemini/gemini-1.5-flash-latest": {
|
||||
|
@ -3456,32 +3450,6 @@
|
|||
"supports_function_calling": true,
|
||||
"supports_vision": true,
|
||||
"supports_response_schema": true,
|
||||
"tpm": 4000000,
|
||||
"rpm": 2000,
|
||||
"source": "https://ai.google.dev/pricing"
|
||||
},
|
||||
"gemini/gemini-1.5-flash-8b": {
|
||||
"max_tokens": 8192,
|
||||
"max_input_tokens": 1048576,
|
||||
"max_output_tokens": 8192,
|
||||
"max_images_per_prompt": 3000,
|
||||
"max_videos_per_prompt": 10,
|
||||
"max_video_length": 1,
|
||||
"max_audio_length_hours": 8.4,
|
||||
"max_audio_per_prompt": 1,
|
||||
"max_pdf_size_mb": 30,
|
||||
"input_cost_per_token": 0,
|
||||
"input_cost_per_token_above_128k_tokens": 0,
|
||||
"output_cost_per_token": 0,
|
||||
"output_cost_per_token_above_128k_tokens": 0,
|
||||
"litellm_provider": "gemini",
|
||||
"mode": "chat",
|
||||
"supports_system_messages": true,
|
||||
"supports_function_calling": true,
|
||||
"supports_vision": true,
|
||||
"supports_response_schema": true,
|
||||
"tpm": 4000000,
|
||||
"rpm": 4000,
|
||||
"source": "https://ai.google.dev/pricing"
|
||||
},
|
||||
"gemini/gemini-1.5-flash-8b-exp-0924": {
|
||||
|
@ -3504,8 +3472,6 @@
|
|||
"supports_function_calling": true,
|
||||
"supports_vision": true,
|
||||
"supports_response_schema": true,
|
||||
"tpm": 4000000,
|
||||
"rpm": 4000,
|
||||
"source": "https://ai.google.dev/pricing"
|
||||
},
|
||||
"gemini/gemini-exp-1114": {
|
||||
|
@ -3528,12 +3494,7 @@
|
|||
"supports_function_calling": true,
|
||||
"supports_vision": true,
|
||||
"supports_response_schema": true,
|
||||
"tpm": 4000000,
|
||||
"rpm": 1000,
|
||||
"source": "https://ai.google.dev/pricing",
|
||||
"metadata": {
|
||||
"notes": "Rate limits not documented for gemini-exp-1114. Assuming same as gemini-1.5-pro."
|
||||
}
|
||||
"source": "https://ai.google.dev/pricing"
|
||||
},
|
||||
"gemini/gemini-1.5-flash-exp-0827": {
|
||||
"max_tokens": 8192,
|
||||
|
@ -3555,8 +3516,6 @@
|
|||
"supports_function_calling": true,
|
||||
"supports_vision": true,
|
||||
"supports_response_schema": true,
|
||||
"tpm": 4000000,
|
||||
"rpm": 2000,
|
||||
"source": "https://ai.google.dev/pricing"
|
||||
},
|
||||
"gemini/gemini-1.5-flash-8b-exp-0827": {
|
||||
|
@ -3578,9 +3537,6 @@
|
|||
"supports_system_messages": true,
|
||||
"supports_function_calling": true,
|
||||
"supports_vision": true,
|
||||
"supports_response_schema": true,
|
||||
"tpm": 4000000,
|
||||
"rpm": 4000,
|
||||
"source": "https://ai.google.dev/pricing"
|
||||
},
|
||||
"gemini/gemini-pro": {
|
||||
|
@ -3594,10 +3550,7 @@
|
|||
"litellm_provider": "gemini",
|
||||
"mode": "chat",
|
||||
"supports_function_calling": true,
|
||||
"rpd": 30000,
|
||||
"tpm": 120000,
|
||||
"rpm": 360,
|
||||
"source": "https://ai.google.dev/gemini-api/docs/models/gemini"
|
||||
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
|
||||
},
|
||||
"gemini/gemini-1.5-pro": {
|
||||
"max_tokens": 8192,
|
||||
|
@ -3614,8 +3567,6 @@
|
|||
"supports_vision": true,
|
||||
"supports_tool_choice": true,
|
||||
"supports_response_schema": true,
|
||||
"tpm": 4000000,
|
||||
"rpm": 1000,
|
||||
"source": "https://ai.google.dev/pricing"
|
||||
},
|
||||
"gemini/gemini-1.5-pro-002": {
|
||||
|
@ -3634,8 +3585,6 @@
|
|||
"supports_tool_choice": true,
|
||||
"supports_response_schema": true,
|
||||
"supports_prompt_caching": true,
|
||||
"tpm": 4000000,
|
||||
"rpm": 1000,
|
||||
"source": "https://ai.google.dev/pricing"
|
||||
},
|
||||
"gemini/gemini-1.5-pro-001": {
|
||||
|
@ -3654,8 +3603,6 @@
|
|||
"supports_tool_choice": true,
|
||||
"supports_response_schema": true,
|
||||
"supports_prompt_caching": true,
|
||||
"tpm": 4000000,
|
||||
"rpm": 1000,
|
||||
"source": "https://ai.google.dev/pricing"
|
||||
},
|
||||
"gemini/gemini-1.5-pro-exp-0801": {
|
||||
|
@ -3673,8 +3620,6 @@
|
|||
"supports_vision": true,
|
||||
"supports_tool_choice": true,
|
||||
"supports_response_schema": true,
|
||||
"tpm": 4000000,
|
||||
"rpm": 1000,
|
||||
"source": "https://ai.google.dev/pricing"
|
||||
},
|
||||
"gemini/gemini-1.5-pro-exp-0827": {
|
||||
|
@ -3692,8 +3637,6 @@
|
|||
"supports_vision": true,
|
||||
"supports_tool_choice": true,
|
||||
"supports_response_schema": true,
|
||||
"tpm": 4000000,
|
||||
"rpm": 1000,
|
||||
"source": "https://ai.google.dev/pricing"
|
||||
},
|
||||
"gemini/gemini-1.5-pro-latest": {
|
||||
|
@ -3711,8 +3654,6 @@
|
|||
"supports_vision": true,
|
||||
"supports_tool_choice": true,
|
||||
"supports_response_schema": true,
|
||||
"tpm": 4000000,
|
||||
"rpm": 1000,
|
||||
"source": "https://ai.google.dev/pricing"
|
||||
},
|
||||
"gemini/gemini-pro-vision": {
|
||||
|
@ -3727,9 +3668,6 @@
|
|||
"mode": "chat",
|
||||
"supports_function_calling": true,
|
||||
"supports_vision": true,
|
||||
"rpd": 30000,
|
||||
"tpm": 120000,
|
||||
"rpm": 360,
|
||||
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
|
||||
},
|
||||
"gemini/gemini-gemma-2-27b-it": {
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
[tool.poetry]
|
||||
name = "litellm"
|
||||
version = "1.53.2"
|
||||
version = "1.53.1"
|
||||
description = "Library to easily interface with LLM API providers"
|
||||
authors = ["BerriAI"]
|
||||
license = "MIT"
|
||||
|
@ -91,7 +91,7 @@ requires = ["poetry-core", "wheel"]
|
|||
build-backend = "poetry.core.masonry.api"
|
||||
|
||||
[tool.commitizen]
|
||||
version = "1.53.2"
|
||||
version = "1.53.1"
|
||||
version_files = [
|
||||
"pyproject.toml:^version"
|
||||
]
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
# LITELLM PROXY DEPENDENCIES #
|
||||
anyio==4.4.0 # openai + http req.
|
||||
openai==1.55.3 # openai req.
|
||||
openai==1.54.0 # openai req.
|
||||
fastapi==0.111.0 # server dep
|
||||
backoff==2.2.1 # server dep
|
||||
pyyaml==6.0.0 # server dep
|
||||
|
|
|
@ -46,22 +46,17 @@ print(env_keys)
repo_base = "./"
print(os.listdir(repo_base))
docs_path = (
"./docs/my-website/docs/proxy/config_settings.md"  # Path to the documentation
"../../docs/my-website/docs/proxy/config_settings.md"  # Path to the documentation
)
documented_keys = set()
try:
with open(docs_path, "r", encoding="utf-8") as docs_file:
content = docs_file.read()

print(f"content: {content}")

# Find the section titled "general_settings - Reference"
general_settings_section = re.search(
r"### environment variables - Reference(.*?)(?=\n###|\Z)",
content,
re.DOTALL | re.MULTILINE,
r"### environment variables - Reference(.*?)###", content, re.DOTALL
)
print(f"general_settings_section: {general_settings_section}")
if general_settings_section:
# Extract the table rows, which contain the documented keys
table_content = general_settings_section.group(1)
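A standalone sketch (made-up doc text, not the real config_settings.md) of why the updated pattern uses a lookahead: the old pattern needs a trailing `###` heading to anchor on, so it fails when the environment-variables table is the last section of the file, while `(?=\n###|\Z)` also matches at end-of-document.

```python
import re

doc = "### environment variables - Reference\n| OPENAI_API_KEY | ... |\n"

old = re.search(r"### environment variables - Reference(.*?)###", doc, re.DOTALL)
new = re.search(
    r"### environment variables - Reference(.*?)(?=\n###|\Z)",
    doc,
    re.DOTALL | re.MULTILINE,
)

print(old)        # None - no trailing "###" to anchor on
print(bool(new))  # True - \Z matches the end of the document
```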
@ -75,7 +70,6 @@ except Exception as e:
|
|||
)
|
||||
|
||||
|
||||
print(f"documented_keys: {documented_keys}")
|
||||
# Compare and find undocumented keys
|
||||
undocumented_keys = env_keys - documented_keys
|
||||
|
||||
|
|
|
@ -1,87 +0,0 @@
|
|||
import os
|
||||
import re
|
||||
import inspect
|
||||
from typing import Type
|
||||
import sys
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system path
|
||||
import litellm
|
||||
|
||||
|
||||
def get_init_params(cls: Type) -> list[str]:
|
||||
"""
|
||||
Retrieve all parameters supported by the `__init__` method of a given class.
|
||||
|
||||
Args:
|
||||
cls: The class to inspect.
|
||||
|
||||
Returns:
|
||||
A list of parameter names.
|
||||
"""
|
||||
if not hasattr(cls, "__init__"):
|
||||
raise ValueError(
|
||||
f"The provided class {cls.__name__} does not have an __init__ method."
|
||||
)
|
||||
|
||||
init_method = cls.__init__
|
||||
argspec = inspect.getfullargspec(init_method)
|
||||
|
||||
# The first argument is usually 'self', so we exclude it
|
||||
return argspec.args[1:] # Exclude 'self'
|
||||
|
||||
|
||||
router_init_params = set(get_init_params(litellm.router.Router))
|
||||
print(router_init_params)
|
||||
router_init_params.remove("model_list")
|
||||
|
||||
# Parse the documentation to extract documented keys
|
||||
repo_base = "./"
|
||||
print(os.listdir(repo_base))
|
||||
docs_path = (
|
||||
"./docs/my-website/docs/proxy/config_settings.md" # Path to the documentation
|
||||
)
|
||||
# docs_path = (
|
||||
# "../../docs/my-website/docs/proxy/config_settings.md" # Path to the documentation
|
||||
# )
|
||||
documented_keys = set()
|
||||
try:
|
||||
with open(docs_path, "r", encoding="utf-8") as docs_file:
|
||||
content = docs_file.read()
|
||||
|
||||
# Find the section titled "general_settings - Reference"
|
||||
general_settings_section = re.search(
|
||||
r"### router_settings - Reference(.*?)###", content, re.DOTALL
|
||||
)
|
||||
if general_settings_section:
|
||||
# Extract the table rows, which contain the documented keys
|
||||
table_content = general_settings_section.group(1)
|
||||
doc_key_pattern = re.compile(
|
||||
r"\|\s*([^\|]+?)\s*\|"
|
||||
) # Capture the key from each row of the table
|
||||
documented_keys.update(doc_key_pattern.findall(table_content))
|
||||
except Exception as e:
|
||||
raise Exception(
|
||||
f"Error reading documentation: {e}, \n repo base - {os.listdir(repo_base)}"
|
||||
)
|
||||
|
||||
|
||||
# Compare and find undocumented keys
|
||||
undocumented_keys = router_init_params - documented_keys
|
||||
|
||||
# Print results
|
||||
print("Keys expected in 'router settings' (found in code):")
|
||||
for key in sorted(router_init_params):
|
||||
print(key)
|
||||
|
||||
if undocumented_keys:
|
||||
raise Exception(
|
||||
f"\nKeys not documented in 'router settings - Reference': {undocumented_keys}"
|
||||
)
|
||||
else:
|
||||
print(
|
||||
"\nAll keys are documented in 'router settings - Reference'. - {}".format(
|
||||
router_init_params
|
||||
)
|
||||
)
|
|
@ -1,3 +1 @@
Unit tests for individual LLM providers.

Name of the test file is the name of the LLM provider - e.g. `test_openai.py` is for OpenAI.
More tests under `litellm/litellm/tests/*`.
|
@ -62,14 +62,7 @@ class BaseLLMChatTest(ABC):
|
|||
response = litellm.completion(**base_completion_call_args, messages=messages)
|
||||
assert response is not None
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"response_format",
|
||||
[
|
||||
{"type": "json_object"},
|
||||
{"type": "text"},
|
||||
],
|
||||
)
|
||||
def test_json_response_format(self, response_format):
|
||||
def test_json_response_format(self):
|
||||
"""
|
||||
Test that the JSON response format is supported by the LLM API
|
||||
"""
|
||||
|
@ -90,7 +83,7 @@ class BaseLLMChatTest(ABC):
|
|||
response = litellm.completion(
|
||||
**base_completion_call_args,
|
||||
messages=messages,
|
||||
response_format=response_format,
|
||||
response_format={"type": "json_object"},
|
||||
)
|
||||
|
||||
print(response)
|
||||
|
|
File diff suppressed because one or more lines are too long
|
@ -45,59 +45,81 @@ def test_map_azure_model_group(model_group_header, expected_model):
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_azure_ai_with_image_url():
|
||||
@pytest.mark.respx
|
||||
async def test_azure_ai_with_image_url(respx_mock: MockRouter):
|
||||
"""
|
||||
Important test:
|
||||
|
||||
Test that Azure AI studio can handle image_url passed when content is a list containing both text and image_url
|
||||
"""
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
litellm.set_verbose = True
|
||||
|
||||
client = AsyncOpenAI(
|
||||
api_key="fake-api-key",
|
||||
base_url="https://Phi-3-5-vision-instruct-dcvov.eastus2.models.ai.azure.com",
|
||||
)
|
||||
# Mock response based on the actual API response
|
||||
mock_response = {
|
||||
"id": "cmpl-53860ea1efa24d2883555bfec13d2254",
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "stop",
|
||||
"index": 0,
|
||||
"logprobs": None,
|
||||
"message": {
|
||||
"content": "The image displays a graphic with the text 'LiteLLM' in black",
|
||||
"role": "assistant",
|
||||
"refusal": None,
|
||||
"audio": None,
|
||||
"function_call": None,
|
||||
"tool_calls": None,
|
||||
},
|
||||
}
|
||||
],
|
||||
"created": 1731801937,
|
||||
"model": "phi35-vision-instruct",
|
||||
"object": "chat.completion",
|
||||
"usage": {
|
||||
"completion_tokens": 69,
|
||||
"prompt_tokens": 617,
|
||||
"total_tokens": 686,
|
||||
"completion_tokens_details": None,
|
||||
"prompt_tokens_details": None,
|
||||
},
|
||||
}
|
||||
|
||||
with patch.object(
|
||||
client.chat.completions.with_raw_response, "create"
|
||||
) as mock_client:
|
||||
try:
|
||||
await litellm.acompletion(
|
||||
model="azure_ai/Phi-3-5-vision-instruct-dcvov",
|
||||
api_base="https://Phi-3-5-vision-instruct-dcvov.eastus2.models.ai.azure.com",
|
||||
messages=[
|
||||
# Mock the API request
|
||||
mock_request = respx_mock.post(
|
||||
"https://Phi-3-5-vision-instruct-dcvov.eastus2.models.ai.azure.com"
|
||||
).mock(return_value=httpx.Response(200, json=mock_response))
|
||||
|
||||
response = await litellm.acompletion(
|
||||
model="azure_ai/Phi-3-5-vision-instruct-dcvov",
|
||||
api_base="https://Phi-3-5-vision-instruct-dcvov.eastus2.models.ai.azure.com",
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "What is in this image?",
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "https://litellm-listing.s3.amazonaws.com/litellm_logo.png"
|
||||
},
|
||||
},
|
||||
],
|
||||
"type": "text",
|
||||
"text": "What is in this image?",
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "https://litellm-listing.s3.amazonaws.com/litellm_logo.png"
|
||||
},
|
||||
},
|
||||
],
|
||||
api_key="fake-api-key",
|
||||
client=client,
|
||||
)
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
print(f"Error: {e}")
|
||||
},
|
||||
],
|
||||
api_key="fake-api-key",
|
||||
)
|
||||
|
||||
# Verify the request was made
|
||||
mock_client.assert_called_once()
|
||||
# Verify the request was made
|
||||
assert mock_request.called
|
||||
|
||||
# Check the request body
|
||||
request_body = mock_client.call_args.kwargs
|
||||
assert request_body["model"] == "Phi-3-5-vision-instruct-dcvov"
|
||||
assert request_body["messages"] == [
|
||||
# Check the request body
|
||||
request_body = json.loads(mock_request.calls[0].request.content)
|
||||
assert request_body == {
|
||||
"model": "Phi-3-5-vision-instruct-dcvov",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
|
@ -110,4 +132,7 @@ async def test_azure_ai_with_image_url():
|
|||
},
|
||||
],
|
||||
}
|
||||
]
|
||||
],
|
||||
}
|
||||
|
||||
print(f"response: {response}")
|
||||
|
|
|
@ -1243,19 +1243,6 @@ def test_bedrock_cross_region_inference(model):
    )


@pytest.mark.parametrize(
    "model, expected_base_model",
    [
        (
            "apac.anthropic.claude-3-5-sonnet-20240620-v1:0",
            "anthropic.claude-3-5-sonnet-20240620-v1:0",
        ),
    ],
)
def test_bedrock_get_base_model(model, expected_base_model):
    assert litellm.AmazonConverseConfig()._get_base_model(model) == expected_base_model


from litellm.llms.prompt_templates.factory import _bedrock_converse_messages_pt
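The parametrized case above checks that `AmazonConverseConfig()._get_base_model()` maps a cross-region inference profile ID (here the `apac.` prefix) back to the underlying foundation-model ID. A minimal sketch of that idea, with the prefix set assumed for illustration rather than taken from the library:

```python
# Illustrative only: the real prefix handling lives in litellm's AmazonConverseConfig.
CROSS_REGION_PREFIXES = ("us.", "eu.", "apac.")  # assumed set of routing prefixes


def get_base_model(model_id: str) -> str:
    """Strip a cross-region routing prefix, if present, to recover the base model ID."""
    for prefix in CROSS_REGION_PREFIXES:
        if model_id.startswith(prefix):
            return model_id[len(prefix):]
    return model_id


assert (
    get_base_model("apac.anthropic.claude-3-5-sonnet-20240620-v1:0")
    == "anthropic.claude-3-5-sonnet-20240620-v1:0"
)
```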
@ -13,7 +13,6 @@ load_dotenv()
import httpx
import pytest
from respx import MockRouter
from unittest.mock import patch, MagicMock, AsyncMock

import litellm
from litellm import Choices, Message, ModelResponse
@ -42,58 +41,56 @@ def return_mocked_response(model: str):
        "bedrock/mistral.mistral-large-2407-v1:0",
    ],
)
@pytest.mark.respx
@pytest.mark.asyncio()
async def test_bedrock_max_completion_tokens(model: str):
async def test_bedrock_max_completion_tokens(model: str, respx_mock: MockRouter):
    """
    Tests that:
    - max_completion_tokens is passed as max_tokens to bedrock models
    """
    from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler

    litellm.set_verbose = True

    client = AsyncHTTPHandler()

    mock_response = return_mocked_response(model)
    _model = model.split("/")[1]
    print("\n\nmock_response: ", mock_response)
    url = f"https://bedrock-runtime.us-west-2.amazonaws.com/model/{_model}/converse"
    mock_request = respx_mock.post(url).mock(
        return_value=httpx.Response(200, json=mock_response)
    )

    with patch.object(client, "post") as mock_client:
        try:
            response = await litellm.acompletion(
                model=model,
                max_completion_tokens=10,
                messages=[{"role": "user", "content": "Hello!"}],
                client=client,
            )
        except Exception as e:
            print(f"Error: {e}")
    response = await litellm.acompletion(
        model=model,
        max_completion_tokens=10,
        messages=[{"role": "user", "content": "Hello!"}],
    )

        mock_client.assert_called_once()
        request_body = json.loads(mock_client.call_args.kwargs["data"])
    assert mock_request.called
    request_body = json.loads(mock_request.calls[0].request.content)

        print("request_body: ", request_body)
    print("request_body: ", request_body)

        assert request_body == {
            "messages": [{"role": "user", "content": [{"text": "Hello!"}]}],
            "additionalModelRequestFields": {},
            "system": [],
            "inferenceConfig": {"maxTokens": 10},
        }
    assert request_body == {
        "messages": [{"role": "user", "content": [{"text": "Hello!"}]}],
        "additionalModelRequestFields": {},
        "system": [],
        "inferenceConfig": {"maxTokens": 10},
    }
    print(f"response: {response}")
    assert isinstance(response, ModelResponse)

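The hunk above swaps between two request-capture styles: patching the async HTTP handler's `post` and inspecting `call_args`, versus registering a respx route and reading the recorded call. A standalone sketch of the respx pattern, independent of litellm; the URL here is a placeholder:

```python
import json

import httpx
import respx


@respx.mock
def test_request_body_is_captured():
    # Register a mock route; no real network traffic is sent.
    route = respx.post("https://example.invalid/converse").mock(  # placeholder URL
        return_value=httpx.Response(200, json={"ok": True})
    )

    httpx.post(
        "https://example.invalid/converse",
        json={"inferenceConfig": {"maxTokens": 10}},
    )

    # Assert on the request body that the mocked route actually received.
    assert route.called
    request_body = json.loads(route.calls[0].request.content)
    assert request_body == {"inferenceConfig": {"maxTokens": 10}}
```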
@pytest.mark.parametrize(
|
||||
"model",
|
||||
["anthropic/claude-3-sonnet-20240229", "anthropic/claude-3-opus-20240229"],
|
||||
["anthropic/claude-3-sonnet-20240229", "anthropic/claude-3-opus-20240229,"],
|
||||
)
|
||||
@pytest.mark.respx
|
||||
@pytest.mark.asyncio()
|
||||
async def test_anthropic_api_max_completion_tokens(model: str):
|
||||
async def test_anthropic_api_max_completion_tokens(model: str, respx_mock: MockRouter):
|
||||
"""
|
||||
Tests that:
|
||||
- max_completion_tokens is passed as max_tokens to anthropic models
|
||||
"""
|
||||
litellm.set_verbose = True
|
||||
from litellm.llms.custom_httpx.http_handler import HTTPHandler
|
||||
|
||||
mock_response = {
|
||||
"content": [{"text": "Hi! My name is Claude.", "type": "text"}],
|
||||
|
@ -106,32 +103,30 @@ async def test_anthropic_api_max_completion_tokens(model: str):
|
|||
"usage": {"input_tokens": 2095, "output_tokens": 503},
|
||||
}
|
||||
|
||||
client = HTTPHandler()
|
||||
|
||||
print("\n\nmock_response: ", mock_response)
|
||||
url = f"https://api.anthropic.com/v1/messages"
|
||||
mock_request = respx_mock.post(url).mock(
|
||||
return_value=httpx.Response(200, json=mock_response)
|
||||
)
|
||||
|
||||
with patch.object(client, "post") as mock_client:
|
||||
try:
|
||||
response = await litellm.acompletion(
|
||||
model=model,
|
||||
max_completion_tokens=10,
|
||||
messages=[{"role": "user", "content": "Hello!"}],
|
||||
client=client,
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
mock_client.assert_called_once()
|
||||
request_body = mock_client.call_args.kwargs["json"]
|
||||
response = await litellm.acompletion(
|
||||
model=model,
|
||||
max_completion_tokens=10,
|
||||
messages=[{"role": "user", "content": "Hello!"}],
|
||||
)
|
||||
|
||||
print("request_body: ", request_body)
|
||||
assert mock_request.called
|
||||
request_body = json.loads(mock_request.calls[0].request.content)
|
||||
|
||||
assert request_body == {
|
||||
"messages": [
|
||||
{"role": "user", "content": [{"type": "text", "text": "Hello!"}]}
|
||||
],
|
||||
"max_tokens": 10,
|
||||
"model": model.split("/")[-1],
|
||||
}
|
||||
print("request_body: ", request_body)
|
||||
|
||||
assert request_body == {
|
||||
"messages": [{"role": "user", "content": [{"type": "text", "text": "Hello!"}]}],
|
||||
"max_tokens": 10,
|
||||
"model": model.split("/")[-1],
|
||||
}
|
||||
print(f"response: {response}")
|
||||
assert isinstance(response, ModelResponse)
|
||||
|
||||
|
||||
def test_all_model_configs():
|
||||
|
|
|
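The Bedrock and Anthropic tests in the file above rely on the same idea: an OpenAI-style `max_completion_tokens` argument must arrive at the provider as `max_tokens`. A toy translation helper, covering only this one parameter, to make the expected mapping concrete (not litellm's actual transformation code):

```python
from typing import Any, Dict


def to_provider_params(openai_params: Dict[str, Any]) -> Dict[str, Any]:
    """Toy mapping: rename max_completion_tokens -> max_tokens, pass the rest through."""
    translated = dict(openai_params)
    if "max_completion_tokens" in translated:
        translated["max_tokens"] = translated.pop("max_completion_tokens")
    return translated


assert to_provider_params({"max_completion_tokens": 10}) == {"max_tokens": 10}
```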
@ -12,78 +12,95 @@ sys.path.insert(
|
|||
import httpx
|
||||
import pytest
|
||||
from respx import MockRouter
|
||||
from unittest.mock import patch, MagicMock, AsyncMock
|
||||
|
||||
import litellm
|
||||
from litellm import Choices, Message, ModelResponse, EmbeddingResponse, Usage
|
||||
from litellm import completion
|
||||
|
||||
|
||||
def test_completion_nvidia_nim():
|
||||
from openai import OpenAI
|
||||
|
||||
@pytest.mark.respx
|
||||
def test_completion_nvidia_nim(respx_mock: MockRouter):
|
||||
litellm.set_verbose = True
|
||||
mock_response = ModelResponse(
|
||||
id="cmpl-mock",
|
||||
choices=[Choices(message=Message(content="Mocked response", role="assistant"))],
|
||||
created=int(datetime.now().timestamp()),
|
||||
model="databricks/dbrx-instruct",
|
||||
)
|
||||
model_name = "nvidia_nim/databricks/dbrx-instruct"
|
||||
client = OpenAI(
|
||||
api_key="fake-api-key",
|
||||
)
|
||||
|
||||
with patch.object(
|
||||
client.chat.completions.with_raw_response, "create"
|
||||
) as mock_client:
|
||||
try:
|
||||
completion(
|
||||
model=model_name,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What's the weather like in Boston today in Fahrenheit?",
|
||||
}
|
||||
],
|
||||
presence_penalty=0.5,
|
||||
frequency_penalty=0.1,
|
||||
client=client,
|
||||
)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
mock_request = respx_mock.post(
|
||||
"https://integrate.api.nvidia.com/v1/chat/completions"
|
||||
).mock(return_value=httpx.Response(200, json=mock_response.dict()))
|
||||
try:
|
||||
response = completion(
|
||||
model=model_name,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What's the weather like in Boston today in Fahrenheit?",
|
||||
}
|
||||
],
|
||||
presence_penalty=0.5,
|
||||
frequency_penalty=0.1,
|
||||
)
|
||||
# Add any assertions here to check the response
|
||||
print(response)
|
||||
assert response.choices[0].message.content is not None
|
||||
assert len(response.choices[0].message.content) > 0
|
||||
|
||||
mock_client.assert_called_once()
|
||||
request_body = mock_client.call_args.kwargs
|
||||
assert mock_request.called
|
||||
request_body = json.loads(mock_request.calls[0].request.content)
|
||||
|
||||
print("request_body: ", request_body)
|
||||
|
||||
assert request_body["messages"] == [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What's the weather like in Boston today in Fahrenheit?",
|
||||
},
|
||||
]
|
||||
assert request_body["model"] == "databricks/dbrx-instruct"
|
||||
assert request_body["frequency_penalty"] == 0.1
|
||||
assert request_body["presence_penalty"] == 0.5
|
||||
assert request_body == {
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What's the weather like in Boston today in Fahrenheit?",
|
||||
}
|
||||
],
|
||||
"model": "databricks/dbrx-instruct",
|
||||
"frequency_penalty": 0.1,
|
||||
"presence_penalty": 0.5,
|
||||
}
|
||||
except litellm.exceptions.Timeout as e:
|
||||
pass
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
||||
def test_embedding_nvidia_nim():
|
||||
def test_embedding_nvidia_nim(respx_mock: MockRouter):
|
||||
litellm.set_verbose = True
|
||||
from openai import OpenAI
|
||||
|
||||
client = OpenAI(
|
||||
api_key="fake-api-key",
|
||||
mock_response = EmbeddingResponse(
|
||||
model="nvidia_nim/databricks/dbrx-instruct",
|
||||
data=[
|
||||
{
|
||||
"embedding": [0.1, 0.2, 0.3],
|
||||
"index": 0,
|
||||
}
|
||||
],
|
||||
usage=Usage(
|
||||
prompt_tokens=10,
|
||||
completion_tokens=0,
|
||||
total_tokens=10,
|
||||
),
|
||||
)
|
||||
with patch.object(client.embeddings.with_raw_response, "create") as mock_client:
|
||||
try:
|
||||
litellm.embedding(
|
||||
model="nvidia_nim/nvidia/nv-embedqa-e5-v5",
|
||||
input="What is the meaning of life?",
|
||||
input_type="passage",
|
||||
client=client,
|
||||
)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
mock_client.assert_called_once()
|
||||
request_body = mock_client.call_args.kwargs
|
||||
print("request_body: ", request_body)
|
||||
assert request_body["input"] == "What is the meaning of life?"
|
||||
assert request_body["model"] == "nvidia/nv-embedqa-e5-v5"
|
||||
assert request_body["extra_body"]["input_type"] == "passage"
|
||||
mock_request = respx_mock.post(
|
||||
"https://integrate.api.nvidia.com/v1/embeddings"
|
||||
).mock(return_value=httpx.Response(200, json=mock_response.dict()))
|
||||
response = litellm.embedding(
|
||||
model="nvidia_nim/nvidia/nv-embedqa-e5-v5",
|
||||
input="What is the meaning of life?",
|
||||
input_type="passage",
|
||||
)
|
||||
assert mock_request.called
|
||||
request_body = json.loads(mock_request.calls[0].request.content)
|
||||
print("request_body: ", request_body)
|
||||
assert request_body == {
|
||||
"input": "What is the meaning of life?",
|
||||
"model": "nvidia/nv-embedqa-e5-v5",
|
||||
"input_type": "passage",
|
||||
"encoding_format": "base64",
|
||||
}
|
||||
|
|
|
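The NVIDIA NIM tests above show the other mocking style used in this branch: instead of intercepting HTTP with respx, they patch the OpenAI SDK client's `with_raw_response.create` method and inspect the keyword arguments it was called with. A minimal, litellm-free sketch of that pattern; the model name is a placeholder:

```python
from unittest.mock import patch

from openai import OpenAI


def test_capture_create_kwargs():
    client = OpenAI(api_key="fake-api-key")  # no request is ever sent

    with patch.object(
        client.chat.completions.with_raw_response, "create"
    ) as mock_create:
        client.chat.completions.with_raw_response.create(
            model="databricks/dbrx-instruct",  # placeholder model name
            messages=[{"role": "user", "content": "hi"}],
            presence_penalty=0.5,
        )

        # The patched method records exactly what the caller passed in.
        mock_create.assert_called_once()
        assert mock_create.call_args.kwargs["presence_penalty"] == 0.5
```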
@ -2,7 +2,7 @@ import json
import os
import sys
from datetime import datetime
from unittest.mock import AsyncMock, patch, MagicMock
from unittest.mock import AsyncMock

sys.path.insert(
    0, os.path.abspath("../..")
@ -18,75 +18,87 @@ from litellm import Choices, Message, ModelResponse
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_o1_handle_system_role():
|
||||
@pytest.mark.respx
|
||||
async def test_o1_handle_system_role(respx_mock: MockRouter):
|
||||
"""
|
||||
Tests that:
|
||||
- max_tokens is translated to 'max_completion_tokens'
|
||||
- role 'system' is translated to 'user'
|
||||
"""
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
litellm.set_verbose = True
|
||||
|
||||
client = AsyncOpenAI(api_key="fake-api-key")
|
||||
mock_response = ModelResponse(
|
||||
id="cmpl-mock",
|
||||
choices=[Choices(message=Message(content="Mocked response", role="assistant"))],
|
||||
created=int(datetime.now().timestamp()),
|
||||
model="o1-preview",
|
||||
)
|
||||
|
||||
with patch.object(
|
||||
client.chat.completions.with_raw_response, "create"
|
||||
) as mock_client:
|
||||
try:
|
||||
await litellm.acompletion(
|
||||
model="o1-preview",
|
||||
max_tokens=10,
|
||||
messages=[{"role": "system", "content": "Hello!"}],
|
||||
client=client,
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
mock_request = respx_mock.post("https://api.openai.com/v1/chat/completions").mock(
|
||||
return_value=httpx.Response(200, json=mock_response.dict())
|
||||
)
|
||||
|
||||
mock_client.assert_called_once()
|
||||
request_body = mock_client.call_args.kwargs
|
||||
response = await litellm.acompletion(
|
||||
model="o1-preview",
|
||||
max_tokens=10,
|
||||
messages=[{"role": "system", "content": "Hello!"}],
|
||||
)
|
||||
|
||||
print("request_body: ", request_body)
|
||||
assert mock_request.called
|
||||
request_body = json.loads(mock_request.calls[0].request.content)
|
||||
|
||||
assert request_body["model"] == "o1-preview"
|
||||
assert request_body["max_completion_tokens"] == 10
|
||||
assert request_body["messages"] == [{"role": "user", "content": "Hello!"}]
|
||||
print("request_body: ", request_body)
|
||||
|
||||
assert request_body == {
|
||||
"model": "o1-preview",
|
||||
"max_completion_tokens": 10,
|
||||
"messages": [{"role": "user", "content": "Hello!"}],
|
||||
}
|
||||
|
||||
print(f"response: {response}")
|
||||
assert isinstance(response, ModelResponse)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.respx
|
||||
@pytest.mark.parametrize("model", ["gpt-4", "gpt-4-0314", "gpt-4-32k", "o1-preview"])
|
||||
async def test_o1_max_completion_tokens(model: str):
|
||||
async def test_o1_max_completion_tokens(respx_mock: MockRouter, model: str):
|
||||
"""
|
||||
Tests that:
|
||||
- max_completion_tokens is passed directly to OpenAI chat completion models
|
||||
"""
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
litellm.set_verbose = True
|
||||
|
||||
client = AsyncOpenAI(api_key="fake-api-key")
|
||||
mock_response = ModelResponse(
|
||||
id="cmpl-mock",
|
||||
choices=[Choices(message=Message(content="Mocked response", role="assistant"))],
|
||||
created=int(datetime.now().timestamp()),
|
||||
model=model,
|
||||
)
|
||||
|
||||
with patch.object(
|
||||
client.chat.completions.with_raw_response, "create"
|
||||
) as mock_client:
|
||||
try:
|
||||
await litellm.acompletion(
|
||||
model=model,
|
||||
max_completion_tokens=10,
|
||||
messages=[{"role": "user", "content": "Hello!"}],
|
||||
client=client,
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
mock_request = respx_mock.post("https://api.openai.com/v1/chat/completions").mock(
|
||||
return_value=httpx.Response(200, json=mock_response.dict())
|
||||
)
|
||||
|
||||
mock_client.assert_called_once()
|
||||
request_body = mock_client.call_args.kwargs
|
||||
response = await litellm.acompletion(
|
||||
model=model,
|
||||
max_completion_tokens=10,
|
||||
messages=[{"role": "user", "content": "Hello!"}],
|
||||
)
|
||||
|
||||
print("request_body: ", request_body)
|
||||
assert mock_request.called
|
||||
request_body = json.loads(mock_request.calls[0].request.content)
|
||||
|
||||
assert request_body["model"] == model
|
||||
assert request_body["max_completion_tokens"] == 10
|
||||
assert request_body["messages"] == [{"role": "user", "content": "Hello!"}]
|
||||
print("request_body: ", request_body)
|
||||
|
||||
assert request_body == {
|
||||
"model": model,
|
||||
"max_completion_tokens": 10,
|
||||
"messages": [{"role": "user", "content": "Hello!"}],
|
||||
}
|
||||
|
||||
print(f"response: {response}")
|
||||
assert isinstance(response, ModelResponse)
|
||||
|
||||
|
||||
def test_litellm_responses():
|
||||
|
|
|
@ -2,7 +2,7 @@ import json
import os
import sys
from datetime import datetime
from unittest.mock import AsyncMock, patch
from unittest.mock import AsyncMock

sys.path.insert(
    0, os.path.abspath("../..")
@ -63,7 +63,8 @@ def test_openai_prediction_param():
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openai_prediction_param_mock():
|
||||
@pytest.mark.respx
|
||||
async def test_openai_prediction_param_mock(respx_mock: MockRouter):
|
||||
"""
|
||||
Tests that prediction parameter is correctly passed to the API
|
||||
"""
|
||||
|
@ -91,36 +92,60 @@ async def test_openai_prediction_param_mock():
|
|||
public string Username { get; set; }
|
||||
}
|
||||
"""
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
client = AsyncOpenAI(api_key="fake-api-key")
|
||||
|
||||
with patch.object(
|
||||
client.chat.completions.with_raw_response, "create"
|
||||
) as mock_client:
|
||||
try:
|
||||
await litellm.acompletion(
|
||||
model="gpt-4o-mini",
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Replace the Username property with an Email property. Respond only with code, and with no markdown formatting.",
|
||||
},
|
||||
{"role": "user", "content": code},
|
||||
],
|
||||
prediction={"type": "content", "content": code},
|
||||
client=client,
|
||||
mock_response = ModelResponse(
|
||||
id="chatcmpl-AQ5RmV8GvVSRxEcDxnuXlQnsibiY9",
|
||||
choices=[
|
||||
Choices(
|
||||
message=Message(
|
||||
content=code.replace("Username", "Email").replace(
|
||||
"username", "email"
|
||||
),
|
||||
role="assistant",
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
],
|
||||
created=int(datetime.now().timestamp()),
|
||||
model="gpt-4o-mini-2024-07-18",
|
||||
usage={
|
||||
"completion_tokens": 207,
|
||||
"prompt_tokens": 175,
|
||||
"total_tokens": 382,
|
||||
"completion_tokens_details": {
|
||||
"accepted_prediction_tokens": 0,
|
||||
"reasoning_tokens": 0,
|
||||
"rejected_prediction_tokens": 80,
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
mock_client.assert_called_once()
|
||||
request_body = mock_client.call_args.kwargs
|
||||
mock_request = respx_mock.post("https://api.openai.com/v1/chat/completions").mock(
|
||||
return_value=httpx.Response(200, json=mock_response.dict())
|
||||
)
|
||||
|
||||
# Verify the request contains the prediction parameter
|
||||
assert "prediction" in request_body
|
||||
# verify prediction is correctly sent to the API
|
||||
assert request_body["prediction"] == {"type": "content", "content": code}
|
||||
completion = await litellm.acompletion(
|
||||
model="gpt-4o-mini",
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Replace the Username property with an Email property. Respond only with code, and with no markdown formatting.",
|
||||
},
|
||||
{"role": "user", "content": code},
|
||||
],
|
||||
prediction={"type": "content", "content": code},
|
||||
)
|
||||
|
||||
assert mock_request.called
|
||||
request_body = json.loads(mock_request.calls[0].request.content)
|
||||
|
||||
# Verify the request contains the prediction parameter
|
||||
assert "prediction" in request_body
|
||||
# verify prediction is correctly sent to the API
|
||||
assert request_body["prediction"] == {"type": "content", "content": code}
|
||||
|
||||
# Verify the completion tokens details
|
||||
assert completion.usage.completion_tokens_details.accepted_prediction_tokens == 0
|
||||
assert completion.usage.completion_tokens_details.rejected_prediction_tokens == 80
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
@ -198,73 +223,3 @@ async def test_openai_prediction_param_with_caching():
|
|||
)
|
||||
|
||||
assert completion_response_3.id != completion_response_1.id
|
||||
|
||||
|
||||
@pytest.mark.asyncio()
|
||||
async def test_vision_with_custom_model():
|
||||
"""
|
||||
Tests that an OpenAI compatible endpoint when sent an image will receive the image in the request
|
||||
|
||||
"""
|
||||
import base64
|
||||
import requests
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
client = AsyncOpenAI(api_key="fake-api-key")
|
||||
|
||||
litellm.set_verbose = True
|
||||
api_base = "https://my-custom.api.openai.com"
|
||||
|
||||
# Fetch and encode a test image
|
||||
url = "https://dummyimage.com/100/100/fff&text=Test+image"
|
||||
response = requests.get(url)
|
||||
file_data = response.content
|
||||
encoded_file = base64.b64encode(file_data).decode("utf-8")
|
||||
base64_image = f"data:image/png;base64,{encoded_file}"
|
||||
|
||||
with patch.object(
|
||||
client.chat.completions.with_raw_response, "create"
|
||||
) as mock_client:
|
||||
try:
|
||||
response = await litellm.acompletion(
|
||||
model="openai/my-custom-model",
|
||||
max_tokens=10,
|
||||
api_base=api_base, # use the mock api
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "What's in this image?"},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": base64_image},
|
||||
},
|
||||
],
|
||||
}
|
||||
],
|
||||
client=client,
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
|
||||
mock_client.assert_called_once()
|
||||
request_body = mock_client.call_args.kwargs
|
||||
|
||||
print("request_body: ", request_body)
|
||||
|
||||
assert request_body["messages"] == [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "What's in this image?"},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAGQAAABkBAMAAACCzIhnAAAAG1BMVEURAAD///+ln5/h39/Dv79qX18uHx+If39MPz9oMSdmAAAACXBIWXMAAA7EAAAOxAGVKw4bAAABB0lEQVRYhe2SzWrEIBCAh2A0jxEs4j6GLDS9hqWmV5Flt0cJS+lRwv742DXpEjY1kOZW6HwHFZnPmVEBEARBEARB/jd0KYA/bcUYbPrRLh6amXHJ/K+ypMoyUaGthILzw0l+xI0jsO7ZcmCcm4ILd+QuVYgpHOmDmz6jBeJImdcUCmeBqQpuqRIbVmQsLCrAalrGpfoEqEogqbLTWuXCPCo+Ki1XGqgQ+jVVuhB8bOaHkvmYuzm/b0KYLWwoK58oFqi6XfxQ4Uz7d6WeKpna6ytUs5e8betMcqAv5YPC5EZB2Lm9FIn0/VP6R58+/GEY1X1egVoZ/3bt/EqF6malgSAIgiDIH+QL41409QMY0LMAAAAASUVORK5CYII="
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
]
|
||||
assert request_body["model"] == "my-custom-model"
|
||||
assert request_body["max_tokens"] == 10
|
tests/llm_translation/test_supports_vision.py (new file, 94 lines)
@ -0,0 +1,94 @@
|
|||
import json
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system path
|
||||
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from respx import MockRouter
|
||||
|
||||
import litellm
|
||||
from litellm import Choices, Message, ModelResponse
|
||||
|
||||
|
||||
@pytest.mark.asyncio()
|
||||
@pytest.mark.respx
|
||||
async def test_vision_with_custom_model(respx_mock: MockRouter):
|
||||
"""
|
||||
Tests that an OpenAI compatible endpoint when sent an image will receive the image in the request
|
||||
|
||||
"""
|
||||
import base64
|
||||
import requests
|
||||
|
||||
litellm.set_verbose = True
|
||||
api_base = "https://my-custom.api.openai.com"
|
||||
|
||||
# Fetch and encode a test image
|
||||
url = "https://dummyimage.com/100/100/fff&text=Test+image"
|
||||
response = requests.get(url)
|
||||
file_data = response.content
|
||||
encoded_file = base64.b64encode(file_data).decode("utf-8")
|
||||
base64_image = f"data:image/png;base64,{encoded_file}"
|
||||
|
||||
mock_response = ModelResponse(
|
||||
id="cmpl-mock",
|
||||
choices=[Choices(message=Message(content="Mocked response", role="assistant"))],
|
||||
created=int(datetime.now().timestamp()),
|
||||
model="my-custom-model",
|
||||
)
|
||||
|
||||
mock_request = respx_mock.post(f"{api_base}/chat/completions").mock(
|
||||
return_value=httpx.Response(200, json=mock_response.dict())
|
||||
)
|
||||
|
||||
response = await litellm.acompletion(
|
||||
model="openai/my-custom-model",
|
||||
max_tokens=10,
|
||||
api_base=api_base, # use the mock api
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "What's in this image?"},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": base64_image},
|
||||
},
|
||||
],
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
assert mock_request.called
|
||||
request_body = json.loads(mock_request.calls[0].request.content)
|
||||
|
||||
print("request_body: ", request_body)
|
||||
|
||||
assert request_body == {
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "What's in this image?"},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAGQAAABkBAMAAACCzIhnAAAAG1BMVEURAAD///+ln5/h39/Dv79qX18uHx+If39MPz9oMSdmAAAACXBIWXMAAA7EAAAOxAGVKw4bAAABB0lEQVRYhe2SzWrEIBCAh2A0jxEs4j6GLDS9hqWmV5Flt0cJS+lRwv742DXpEjY1kOZW6HwHFZnPmVEBEARBEARB/jd0KYA/bcUYbPrRLh6amXHJ/K+ypMoyUaGthILzw0l+xI0jsO7ZcmCcm4ILd+QuVYgpHOmDmz6jBeJImdcUCmeBqQpuqRIbVmQsLCrAalrGpfoEqEogqbLTWuXCPCo+Ki1XGqgQ+jVVuhB8bOaHkvmYuzm/b0KYLWwoK58oFqi6XfxQ4Uz7d6WeKpna6ytUs5e8betMcqAv5YPC5EZB2Lm9FIn0/VP6R58+/GEY1X1egVoZ/3bt/EqF6malgSAIgiDIH+QL41409QMY0LMAAAAASUVORK5CYII="
|
||||
},
|
||||
},
|
||||
],
|
||||
}
|
||||
],
|
||||
"model": "my-custom-model",
|
||||
"max_tokens": 10,
|
||||
}
|
||||
|
||||
print(f"response: {response}")
|
||||
assert isinstance(response, ModelResponse)
|
|
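The new `test_supports_vision.py` file above builds a `data:` URL by base64-encoding a downloaded image and passing it as an OpenAI-style `image_url` content part. A small standalone sketch of just that payload construction; the image URL is a placeholder and the helper name is illustrative:

```python
import base64

import requests


def build_image_message(image_url: str, question: str) -> dict:
    """Download an image and wrap it as an OpenAI-style multimodal user message."""
    raw = requests.get(image_url, timeout=10).content
    data_url = "data:image/png;base64," + base64.b64encode(raw).decode("utf-8")
    return {
        "role": "user",
        "content": [
            {"type": "text", "text": question},
            {"type": "image_url", "image_url": {"url": data_url}},
        ],
    }


message = build_image_message(
    "https://dummyimage.com/100/100/fff&text=Test+image", "What's in this image?"
)
```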
@ -6,7 +6,6 @@ from unittest.mock import AsyncMock
import pytest
import httpx
from respx import MockRouter
from unittest.mock import patch, MagicMock, AsyncMock

sys.path.insert(
    0, os.path.abspath("../..")
@ -69,16 +68,13 @@ def test_convert_dict_to_text_completion_response():
|
|||
assert response.choices[0].logprobs.top_logprobs == [None, {",": -2.1568563}]
|
||||
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="need to migrate huggingface to support httpx client being passed in"
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.respx
|
||||
async def test_huggingface_text_completion_logprobs():
|
||||
async def test_huggingface_text_completion_logprobs(respx_mock: MockRouter):
|
||||
"""Test text completion with Hugging Face, focusing on logprobs structure"""
|
||||
litellm.set_verbose = True
|
||||
from litellm.llms.custom_httpx.http_handler import HTTPHandler, AsyncHTTPHandler
|
||||
|
||||
# Mock the raw response from Hugging Face
|
||||
mock_response = [
|
||||
{
|
||||
"generated_text": ",\n\nI have a question...", # truncated for brevity
|
||||
|
@ -95,48 +91,46 @@ async def test_huggingface_text_completion_logprobs():
|
|||
}
|
||||
]
|
||||
|
||||
return_val = AsyncMock()
|
||||
# Mock the API request
|
||||
mock_request = respx_mock.post(
|
||||
"https://api-inference.huggingface.co/models/mistralai/Mistral-7B-v0.1"
|
||||
).mock(return_value=httpx.Response(200, json=mock_response))
|
||||
|
||||
return_val.json.return_value = mock_response
|
||||
response = await litellm.atext_completion(
|
||||
model="huggingface/mistralai/Mistral-7B-v0.1",
|
||||
prompt="good morning",
|
||||
)
|
||||
|
||||
client = AsyncHTTPHandler()
|
||||
with patch.object(client, "post", return_value=return_val) as mock_post:
|
||||
response = await litellm.atext_completion(
|
||||
model="huggingface/mistralai/Mistral-7B-v0.1",
|
||||
prompt="good morning",
|
||||
client=client,
|
||||
)
|
||||
# Verify the request
|
||||
assert mock_request.called
|
||||
request_body = json.loads(mock_request.calls[0].request.content)
|
||||
assert request_body == {
|
||||
"inputs": "good morning",
|
||||
"parameters": {"details": True, "return_full_text": False},
|
||||
"stream": False,
|
||||
}
|
||||
|
||||
# Verify the request
|
||||
mock_post.assert_called_once()
|
||||
request_body = json.loads(mock_post.call_args.kwargs["data"])
|
||||
assert request_body == {
|
||||
"inputs": "good morning",
|
||||
"parameters": {"details": True, "return_full_text": False},
|
||||
"stream": False,
|
||||
}
|
||||
print("response=", response)
|
||||
|
||||
print("response=", response)
|
||||
# Verify response structure
|
||||
assert isinstance(response, TextCompletionResponse)
|
||||
assert response.object == "text_completion"
|
||||
assert response.model == "mistralai/Mistral-7B-v0.1"
|
||||
|
||||
# Verify response structure
|
||||
assert isinstance(response, TextCompletionResponse)
|
||||
assert response.object == "text_completion"
|
||||
assert response.model == "mistralai/Mistral-7B-v0.1"
|
||||
# Verify logprobs structure
|
||||
choice = response.choices[0]
|
||||
assert choice.finish_reason == "length"
|
||||
assert choice.index == 0
|
||||
assert isinstance(choice.logprobs.tokens, list)
|
||||
assert isinstance(choice.logprobs.token_logprobs, list)
|
||||
assert isinstance(choice.logprobs.text_offset, list)
|
||||
assert isinstance(choice.logprobs.top_logprobs, list)
|
||||
assert choice.logprobs.tokens == [",", "\n"]
|
||||
assert choice.logprobs.token_logprobs == [-1.7626953, -1.7314453]
|
||||
assert choice.logprobs.text_offset == [0, 1]
|
||||
assert choice.logprobs.top_logprobs == [{}, {}]
|
||||
|
||||
# Verify logprobs structure
|
||||
choice = response.choices[0]
|
||||
assert choice.finish_reason == "length"
|
||||
assert choice.index == 0
|
||||
assert isinstance(choice.logprobs.tokens, list)
|
||||
assert isinstance(choice.logprobs.token_logprobs, list)
|
||||
assert isinstance(choice.logprobs.text_offset, list)
|
||||
assert isinstance(choice.logprobs.top_logprobs, list)
|
||||
assert choice.logprobs.tokens == [",", "\n"]
|
||||
assert choice.logprobs.token_logprobs == [-1.7626953, -1.7314453]
|
||||
assert choice.logprobs.text_offset == [0, 1]
|
||||
assert choice.logprobs.top_logprobs == [{}, {}]
|
||||
|
||||
# Verify usage
|
||||
assert response.usage["completion_tokens"] > 0
|
||||
assert response.usage["prompt_tokens"] > 0
|
||||
assert response.usage["total_tokens"] > 0
|
||||
# Verify usage
|
||||
assert response.usage["completion_tokens"] > 0
|
||||
assert response.usage["prompt_tokens"] > 0
|
||||
assert response.usage["total_tokens"] > 0
|
||||
|
|
|
@ -1146,21 +1146,6 @@ def test_process_gemini_image():
|
|||
mime_type="image/png", file_uri="https://example.com/image.png"
|
||||
)
|
||||
|
||||
# Test HTTPS VIDEO URL
|
||||
https_result = _process_gemini_image("https://cloud-samples-data/video/animals.mp4")
|
||||
print("https_result PNG", https_result)
|
||||
assert https_result["file_data"] == FileDataType(
|
||||
mime_type="video/mp4", file_uri="https://cloud-samples-data/video/animals.mp4"
|
||||
)
|
||||
|
||||
# Test HTTPS PDF URL
|
||||
https_result = _process_gemini_image("https://cloud-samples-data/pdf/animals.pdf")
|
||||
print("https_result PDF", https_result)
|
||||
assert https_result["file_data"] == FileDataType(
|
||||
mime_type="application/pdf",
|
||||
file_uri="https://cloud-samples-data/pdf/animals.pdf",
|
||||
)
|
||||
|
||||
# Test base64 image
|
||||
base64_image = "data:image/jpeg;base64,/9j/4AAQSkZJRg..."
|
||||
base64_result = _process_gemini_image(base64_image)
|
||||
|
|
|
@ -95,107 +95,3 @@ async def test_handle_failed_db_connection():
|
|||
print("_handle_failed_db_connection_for_get_key_object got exception", exc_info)
|
||||
|
||||
assert str(exc_info.value) == "Failed to connect to DB"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model, expect_to_work",
|
||||
[("openai/gpt-4o-mini", True), ("openai/gpt-4o", False)],
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_can_key_call_model(model, expect_to_work):
|
||||
"""
|
||||
If wildcard model + specific model is used, choose the specific model settings
|
||||
"""
|
||||
from litellm.proxy.auth.auth_checks import can_key_call_model
|
||||
from fastapi import HTTPException
|
||||
|
||||
llm_model_list = [
|
||||
{
|
||||
"model_name": "openai/*",
|
||||
"litellm_params": {
|
||||
"model": "openai/*",
|
||||
"api_key": "test-api-key",
|
||||
},
|
||||
"model_info": {
|
||||
"id": "e6e7006f83029df40ebc02ddd068890253f4cd3092bcb203d3d8e6f6f606f30f",
|
||||
"db_model": False,
|
||||
"access_groups": ["public-openai-models"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"model_name": "openai/gpt-4o",
|
||||
"litellm_params": {
|
||||
"model": "openai/gpt-4o",
|
||||
"api_key": "test-api-key",
|
||||
},
|
||||
"model_info": {
|
||||
"id": "0cfcd87f2cb12a783a466888d05c6c89df66db23e01cecd75ec0b83aed73c9ad",
|
||||
"db_model": False,
|
||||
"access_groups": ["private-openai-models"],
|
||||
},
|
||||
},
|
||||
]
|
||||
router = litellm.Router(model_list=llm_model_list)
|
||||
args = {
|
||||
"model": model,
|
||||
"llm_model_list": llm_model_list,
|
||||
"valid_token": UserAPIKeyAuth(
|
||||
models=["public-openai-models"],
|
||||
),
|
||||
"llm_router": router,
|
||||
}
|
||||
if expect_to_work:
|
||||
await can_key_call_model(**args)
|
||||
else:
|
||||
with pytest.raises(Exception) as e:
|
||||
await can_key_call_model(**args)
|
||||
|
||||
print(e)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model, expect_to_work",
|
||||
[("openai/gpt-4o", False), ("openai/gpt-4o-mini", True)],
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_can_team_call_model(model, expect_to_work):
|
||||
from litellm.proxy.auth.auth_checks import model_in_access_group
|
||||
from fastapi import HTTPException
|
||||
|
||||
llm_model_list = [
|
||||
{
|
||||
"model_name": "openai/*",
|
||||
"litellm_params": {
|
||||
"model": "openai/*",
|
||||
"api_key": "test-api-key",
|
||||
},
|
||||
"model_info": {
|
||||
"id": "e6e7006f83029df40ebc02ddd068890253f4cd3092bcb203d3d8e6f6f606f30f",
|
||||
"db_model": False,
|
||||
"access_groups": ["public-openai-models"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"model_name": "openai/gpt-4o",
|
||||
"litellm_params": {
|
||||
"model": "openai/gpt-4o",
|
||||
"api_key": "test-api-key",
|
||||
},
|
||||
"model_info": {
|
||||
"id": "0cfcd87f2cb12a783a466888d05c6c89df66db23e01cecd75ec0b83aed73c9ad",
|
||||
"db_model": False,
|
||||
"access_groups": ["private-openai-models"],
|
||||
},
|
||||
},
|
||||
]
|
||||
router = litellm.Router(model_list=llm_model_list)
|
||||
|
||||
args = {
|
||||
"model": model,
|
||||
"team_models": ["public-openai-models"],
|
||||
"llm_router": router,
|
||||
}
|
||||
if expect_to_work:
|
||||
assert model_in_access_group(**args)
|
||||
else:
|
||||
assert not model_in_access_group(**args)
|
||||
|
|
|
@ -33,7 +33,7 @@ from litellm.router import Router

@pytest.mark.asyncio()
@pytest.mark.respx()
async def test_aaaaazure_tenant_id_auth(respx_mock: MockRouter):
async def test_azure_tenant_id_auth(respx_mock: MockRouter):
    """

    Tests when we set tenant_id, client_id, client_secret they don't get sent with the request
@ -1,128 +1,128 @@
|
|||
# #### What this tests ####
|
||||
# # This adds perf testing to the router, to ensure it's never > 50ms slower than the azure-openai sdk.
|
||||
# import sys, os, time, inspect, asyncio, traceback
|
||||
# from datetime import datetime
|
||||
# import pytest
|
||||
#### What this tests ####
|
||||
# This adds perf testing to the router, to ensure it's never > 50ms slower than the azure-openai sdk.
|
||||
import sys, os, time, inspect, asyncio, traceback
|
||||
from datetime import datetime
|
||||
import pytest
|
||||
|
||||
# sys.path.insert(0, os.path.abspath("../.."))
|
||||
# import openai, litellm, uuid
|
||||
# from openai import AsyncAzureOpenAI
|
||||
sys.path.insert(0, os.path.abspath("../.."))
|
||||
import openai, litellm, uuid
|
||||
from openai import AsyncAzureOpenAI
|
||||
|
||||
# client = AsyncAzureOpenAI(
|
||||
# api_key=os.getenv("AZURE_API_KEY"),
|
||||
# azure_endpoint=os.getenv("AZURE_API_BASE"), # type: ignore
|
||||
# api_version=os.getenv("AZURE_API_VERSION"),
|
||||
# )
|
||||
client = AsyncAzureOpenAI(
|
||||
api_key=os.getenv("AZURE_API_KEY"),
|
||||
azure_endpoint=os.getenv("AZURE_API_BASE"), # type: ignore
|
||||
api_version=os.getenv("AZURE_API_VERSION"),
|
||||
)
|
||||
|
||||
# model_list = [
|
||||
# {
|
||||
# "model_name": "azure-test",
|
||||
# "litellm_params": {
|
||||
# "model": "azure/chatgpt-v-2",
|
||||
# "api_key": os.getenv("AZURE_API_KEY"),
|
||||
# "api_base": os.getenv("AZURE_API_BASE"),
|
||||
# "api_version": os.getenv("AZURE_API_VERSION"),
|
||||
# },
|
||||
# }
|
||||
# ]
|
||||
model_list = [
|
||||
{
|
||||
"model_name": "azure-test",
|
||||
"litellm_params": {
|
||||
"model": "azure/chatgpt-v-2",
|
||||
"api_key": os.getenv("AZURE_API_KEY"),
|
||||
"api_base": os.getenv("AZURE_API_BASE"),
|
||||
"api_version": os.getenv("AZURE_API_VERSION"),
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
# router = litellm.Router(model_list=model_list) # type: ignore
|
||||
router = litellm.Router(model_list=model_list) # type: ignore
|
||||
|
||||
|
||||
# async def _openai_completion():
|
||||
# try:
|
||||
# start_time = time.time()
|
||||
# response = await client.chat.completions.create(
|
||||
# model="chatgpt-v-2",
|
||||
# messages=[{"role": "user", "content": f"This is a test: {uuid.uuid4()}"}],
|
||||
# stream=True,
|
||||
# )
|
||||
# time_to_first_token = None
|
||||
# first_token_ts = None
|
||||
# init_chunk = None
|
||||
# async for chunk in response:
|
||||
# if (
|
||||
# time_to_first_token is None
|
||||
# and len(chunk.choices) > 0
|
||||
# and chunk.choices[0].delta.content is not None
|
||||
# ):
|
||||
# first_token_ts = time.time()
|
||||
# time_to_first_token = first_token_ts - start_time
|
||||
# init_chunk = chunk
|
||||
# end_time = time.time()
|
||||
# print(
|
||||
# "OpenAI Call: ",
|
||||
# init_chunk,
|
||||
# start_time,
|
||||
# first_token_ts,
|
||||
# time_to_first_token,
|
||||
# end_time,
|
||||
# )
|
||||
# return time_to_first_token
|
||||
# except Exception as e:
|
||||
# print(e)
|
||||
# return None
|
||||
async def _openai_completion():
|
||||
try:
|
||||
start_time = time.time()
|
||||
response = await client.chat.completions.create(
|
||||
model="chatgpt-v-2",
|
||||
messages=[{"role": "user", "content": f"This is a test: {uuid.uuid4()}"}],
|
||||
stream=True,
|
||||
)
|
||||
time_to_first_token = None
|
||||
first_token_ts = None
|
||||
init_chunk = None
|
||||
async for chunk in response:
|
||||
if (
|
||||
time_to_first_token is None
|
||||
and len(chunk.choices) > 0
|
||||
and chunk.choices[0].delta.content is not None
|
||||
):
|
||||
first_token_ts = time.time()
|
||||
time_to_first_token = first_token_ts - start_time
|
||||
init_chunk = chunk
|
||||
end_time = time.time()
|
||||
print(
|
||||
"OpenAI Call: ",
|
||||
init_chunk,
|
||||
start_time,
|
||||
first_token_ts,
|
||||
time_to_first_token,
|
||||
end_time,
|
||||
)
|
||||
return time_to_first_token
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return None
|
||||
|
||||
|
||||
# async def _router_completion():
|
||||
# try:
|
||||
# start_time = time.time()
|
||||
# response = await router.acompletion(
|
||||
# model="azure-test",
|
||||
# messages=[{"role": "user", "content": f"This is a test: {uuid.uuid4()}"}],
|
||||
# stream=True,
|
||||
# )
|
||||
# time_to_first_token = None
|
||||
# first_token_ts = None
|
||||
# init_chunk = None
|
||||
# async for chunk in response:
|
||||
# if (
|
||||
# time_to_first_token is None
|
||||
# and len(chunk.choices) > 0
|
||||
# and chunk.choices[0].delta.content is not None
|
||||
# ):
|
||||
# first_token_ts = time.time()
|
||||
# time_to_first_token = first_token_ts - start_time
|
||||
# init_chunk = chunk
|
||||
# end_time = time.time()
|
||||
# print(
|
||||
# "Router Call: ",
|
||||
# init_chunk,
|
||||
# start_time,
|
||||
# first_token_ts,
|
||||
# time_to_first_token,
|
||||
# end_time - first_token_ts,
|
||||
# )
|
||||
# return time_to_first_token
|
||||
# except Exception as e:
|
||||
# print(e)
|
||||
# return None
|
||||
async def _router_completion():
|
||||
try:
|
||||
start_time = time.time()
|
||||
response = await router.acompletion(
|
||||
model="azure-test",
|
||||
messages=[{"role": "user", "content": f"This is a test: {uuid.uuid4()}"}],
|
||||
stream=True,
|
||||
)
|
||||
time_to_first_token = None
|
||||
first_token_ts = None
|
||||
init_chunk = None
|
||||
async for chunk in response:
|
||||
if (
|
||||
time_to_first_token is None
|
||||
and len(chunk.choices) > 0
|
||||
and chunk.choices[0].delta.content is not None
|
||||
):
|
||||
first_token_ts = time.time()
|
||||
time_to_first_token = first_token_ts - start_time
|
||||
init_chunk = chunk
|
||||
end_time = time.time()
|
||||
print(
|
||||
"Router Call: ",
|
||||
init_chunk,
|
||||
start_time,
|
||||
first_token_ts,
|
||||
time_to_first_token,
|
||||
end_time - first_token_ts,
|
||||
)
|
||||
return time_to_first_token
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return None
|
||||
|
||||
|
||||
# async def test_azure_completion_streaming():
|
||||
# """
|
||||
# Test azure streaming call - measure on time to first (non-null) token.
|
||||
# """
|
||||
# n = 3 # Number of concurrent tasks
|
||||
# ## OPENAI AVG. TIME
|
||||
# tasks = [_openai_completion() for _ in range(n)]
|
||||
# chat_completions = await asyncio.gather(*tasks)
|
||||
# successful_completions = [c for c in chat_completions if c is not None]
|
||||
# total_time = 0
|
||||
# for item in successful_completions:
|
||||
# total_time += item
|
||||
# avg_openai_time = total_time / 3
|
||||
# ## ROUTER AVG. TIME
|
||||
# tasks = [_router_completion() for _ in range(n)]
|
||||
# chat_completions = await asyncio.gather(*tasks)
|
||||
# successful_completions = [c for c in chat_completions if c is not None]
|
||||
# total_time = 0
|
||||
# for item in successful_completions:
|
||||
# total_time += item
|
||||
# avg_router_time = total_time / 3
|
||||
# ## COMPARE
|
||||
# print(f"avg_router_time: {avg_router_time}; avg_openai_time: {avg_openai_time}")
|
||||
# assert avg_router_time < avg_openai_time + 0.5
|
||||
async def test_azure_completion_streaming():
|
||||
"""
|
||||
Test azure streaming call - measure on time to first (non-null) token.
|
||||
"""
|
||||
n = 3 # Number of concurrent tasks
|
||||
## OPENAI AVG. TIME
|
||||
tasks = [_openai_completion() for _ in range(n)]
|
||||
chat_completions = await asyncio.gather(*tasks)
|
||||
successful_completions = [c for c in chat_completions if c is not None]
|
||||
total_time = 0
|
||||
for item in successful_completions:
|
||||
total_time += item
|
||||
avg_openai_time = total_time / 3
|
||||
## ROUTER AVG. TIME
|
||||
tasks = [_router_completion() for _ in range(n)]
|
||||
chat_completions = await asyncio.gather(*tasks)
|
||||
successful_completions = [c for c in chat_completions if c is not None]
|
||||
total_time = 0
|
||||
for item in successful_completions:
|
||||
total_time += item
|
||||
avg_router_time = total_time / 3
|
||||
## COMPARE
|
||||
print(f"avg_router_time: {avg_router_time}; avg_openai_time: {avg_openai_time}")
|
||||
assert avg_router_time < avg_openai_time + 0.5
|
||||
|
||||
|
||||
# # asyncio.run(test_azure_completion_streaming())
|
||||
# asyncio.run(test_azure_completion_streaming())
|
||||
|
|
|
@ -1146,9 +1146,7 @@ async def test_exception_with_headers_httpx(

    except litellm.RateLimitError as e:
        exception_raised = True
        assert (
            e.litellm_response_headers is not None
        ), "litellm_response_headers is None"
        assert e.litellm_response_headers is not None
        print("e.litellm_response_headers", e.litellm_response_headers)
        assert int(e.litellm_response_headers["retry-after"]) == cooldown_time

@ -102,17 +102,3 @@ def test_get_model_info_ollama_chat():
    print(mock_client.call_args.kwargs)

    assert mock_client.call_args.kwargs["json"]["name"] == "mistral"


def test_get_model_info_gemini():
    """
    Tests if ALL gemini models have 'tpm' and 'rpm' in the model info
    """
    os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
    litellm.model_cost = litellm.get_model_cost_map(url="")

    model_map = litellm.model_cost
    for model, info in model_map.items():
        if model.startswith("gemini/") and not "gemma" in model:
            assert info.get("tpm") is not None, f"{model} does not have tpm"
            assert info.get("rpm") is not None, f"{model} does not have rpm"
@ -1,79 +0,0 @@
|
|||
import pytest
|
||||
from fastapi import Request
|
||||
from fastapi.testclient import TestClient
|
||||
from starlette.datastructures import Headers
|
||||
from starlette.requests import HTTPConnection
|
||||
import os
|
||||
import sys
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system path
|
||||
|
||||
from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_request_body_valid_json():
|
||||
"""Test the function with a valid JSON payload."""
|
||||
|
||||
class MockRequest:
|
||||
async def body(self):
|
||||
return b'{"key": "value"}'
|
||||
|
||||
request = MockRequest()
|
||||
result = await _read_request_body(request)
|
||||
assert result == {"key": "value"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_request_body_empty_body():
|
||||
"""Test the function with an empty body."""
|
||||
|
||||
class MockRequest:
|
||||
async def body(self):
|
||||
return b""
|
||||
|
||||
request = MockRequest()
|
||||
result = await _read_request_body(request)
|
||||
assert result == {}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_request_body_invalid_json():
|
||||
"""Test the function with an invalid JSON payload."""
|
||||
|
||||
class MockRequest:
|
||||
async def body(self):
|
||||
return b'{"key": value}' # Missing quotes around `value`
|
||||
|
||||
request = MockRequest()
|
||||
result = await _read_request_body(request)
|
||||
assert result == {} # Should return an empty dict on failure
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_request_body_large_payload():
|
||||
"""Test the function with a very large payload."""
|
||||
large_payload = '{"key":' + '"a"' * 10**6 + "}" # Large payload
|
||||
|
||||
class MockRequest:
|
||||
async def body(self):
|
||||
return large_payload.encode()
|
||||
|
||||
request = MockRequest()
|
||||
result = await _read_request_body(request)
|
||||
assert result == {} # Large payloads could trigger errors, so validate behavior
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_request_body_unexpected_error():
|
||||
"""Test the function when an unexpected error occurs."""
|
||||
|
||||
class MockRequest:
|
||||
async def body(self):
|
||||
raise ValueError("Unexpected error")
|
||||
|
||||
request = MockRequest()
|
||||
result = await _read_request_body(request)
|
||||
assert result == {} # Ensure fallback behavior
|
|
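The deleted `_read_request_body` tests above all assert the same contract: malformed, empty, oversized, or error-raising bodies degrade to an empty dict instead of raising. A minimal sketch of a parser with that contract, not litellm's actual implementation:

```python
import json
from typing import Any, Dict


async def read_request_body(request: Any) -> Dict[str, Any]:
    """Best-effort JSON body parsing: any failure degrades to an empty dict."""
    try:
        raw = await request.body()
        parsed = json.loads(raw)
        return parsed if isinstance(parsed, dict) else {}
    except Exception:
        # Invalid JSON, unexpected errors, or oversized/garbage payloads all end here.
        return {}
```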
@ -2115,14 +2115,10 @@ def test_router_get_model_info(model, base_model, llm_provider):
    assert deployment is not None

    if llm_provider == "openai" or (base_model is not None and llm_provider == "azure"):
        router.get_router_model_info(
            deployment=deployment.to_json(), received_model_name=model
        )
        router.get_router_model_info(deployment=deployment.to_json())
    else:
        try:
            router.get_router_model_info(
                deployment=deployment.to_json(), received_model_name=model
            )
            router.get_router_model_info(deployment=deployment.to_json())
            pytest.fail("Expected this to raise model not mapped error")
        except Exception as e:
            if "This model isn't mapped yet" in str(e):
@ -536,7 +536,7 @@ def test_init_clients_azure_command_r_plus():


@pytest.mark.asyncio
async def test_aaaaatext_completion_with_organization():
async def test_text_completion_with_organization():
    try:
        print("Testing Text OpenAI with organization")
        model_list = [
@ -174,185 +174,3 @@ async def test_update_kwargs_before_fallbacks(call_type):
|
|||
|
||||
print(mock_client.call_args.kwargs)
|
||||
assert mock_client.call_args.kwargs["litellm_trace_id"] is not None
|
||||
|
||||
|
||||
def test_router_get_model_info_wildcard_routes():
|
||||
os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
|
||||
litellm.model_cost = litellm.get_model_cost_map(url="")
|
||||
router = Router(
|
||||
model_list=[
|
||||
{
|
||||
"model_name": "gemini/*",
|
||||
"litellm_params": {"model": "gemini/*"},
|
||||
"model_info": {"id": 1},
|
||||
},
|
||||
]
|
||||
)
|
||||
model_info = router.get_router_model_info(
|
||||
deployment=None, received_model_name="gemini/gemini-1.5-flash", id="1"
|
||||
)
|
||||
print(model_info)
|
||||
assert model_info is not None
|
||||
assert model_info["tpm"] is not None
|
||||
assert model_info["rpm"] is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_router_get_model_group_usage_wildcard_routes():
|
||||
os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
|
||||
litellm.model_cost = litellm.get_model_cost_map(url="")
|
||||
router = Router(
|
||||
model_list=[
|
||||
{
|
||||
"model_name": "gemini/*",
|
||||
"litellm_params": {"model": "gemini/*"},
|
||||
"model_info": {"id": 1},
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
resp = await router.acompletion(
|
||||
model="gemini/gemini-1.5-flash",
|
||||
messages=[{"role": "user", "content": "Hello, how are you?"}],
|
||||
mock_response="Hello, I'm good.",
|
||||
)
|
||||
print(resp)
|
||||
|
||||
await asyncio.sleep(1)
|
||||
|
||||
tpm, rpm = await router.get_model_group_usage(model_group="gemini/gemini-1.5-flash")
|
||||
|
||||
assert tpm is not None, "tpm is None"
|
||||
assert rpm is not None, "rpm is None"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_router_callbacks_on_success():
|
||||
router = Router(
|
||||
model_list=[
|
||||
{
|
||||
"model_name": "gemini/*",
|
||||
"litellm_params": {"model": "gemini/*"},
|
||||
"model_info": {"id": 1},
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
with patch.object(
|
||||
router.cache, "async_increment_cache", new=AsyncMock()
|
||||
) as mock_callback:
|
||||
await router.acompletion(
|
||||
model="gemini/gemini-1.5-flash",
|
||||
messages=[{"role": "user", "content": "Hello, how are you?"}],
|
||||
mock_response="Hello, I'm good.",
|
||||
)
|
||||
await asyncio.sleep(1)
|
||||
assert mock_callback.call_count == 2
|
||||
|
||||
assert (
|
||||
mock_callback.call_args_list[0]
|
||||
.kwargs["key"]
|
||||
.startswith("global_router:1:gemini/gemini-1.5-flash:tpm")
|
||||
)
|
||||
assert (
|
||||
mock_callback.call_args_list[1]
|
||||
.kwargs["key"]
|
||||
.startswith("global_router:1:gemini/gemini-1.5-flash:rpm")
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_router_callbacks_on_failure():
|
||||
router = Router(
|
||||
model_list=[
|
||||
{
|
||||
"model_name": "gemini/*",
|
||||
"litellm_params": {"model": "gemini/*"},
|
||||
"model_info": {"id": 1},
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
with patch.object(
|
||||
router.cache, "async_increment_cache", new=AsyncMock()
|
||||
) as mock_callback:
|
||||
with pytest.raises(litellm.RateLimitError):
|
||||
await router.acompletion(
|
||||
model="gemini/gemini-1.5-flash",
|
||||
messages=[{"role": "user", "content": "Hello, how are you?"}],
|
||||
mock_response="litellm.RateLimitError",
|
||||
num_retries=0,
|
||||
)
|
||||
await asyncio.sleep(1)
|
||||
print(mock_callback.call_args_list)
|
||||
assert mock_callback.call_count == 1
|
||||
|
||||
assert (
|
||||
mock_callback.call_args_list[0]
|
||||
.kwargs["key"]
|
||||
.startswith("global_router:1:gemini/gemini-1.5-flash:rpm")
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_router_model_group_headers():
|
||||
os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
|
||||
litellm.model_cost = litellm.get_model_cost_map(url="")
|
||||
from litellm.types.utils import OPENAI_RESPONSE_HEADERS
|
||||
|
||||
router = Router(
|
||||
model_list=[
|
||||
{
|
||||
"model_name": "gemini/*",
|
||||
"litellm_params": {"model": "gemini/*"},
|
||||
"model_info": {"id": 1},
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
for _ in range(2):
|
||||
resp = await router.acompletion(
|
||||
model="gemini/gemini-1.5-flash",
|
||||
messages=[{"role": "user", "content": "Hello, how are you?"}],
|
||||
mock_response="Hello, I'm good.",
|
||||
)
|
||||
await asyncio.sleep(1)
|
||||
|
||||
assert (
|
||||
resp._hidden_params["additional_headers"]["x-litellm-model-group"]
|
||||
== "gemini/gemini-1.5-flash"
|
||||
)
|
||||
|
||||
assert "x-ratelimit-remaining-requests" in resp._hidden_params["additional_headers"]
|
||||
assert "x-ratelimit-remaining-tokens" in resp._hidden_params["additional_headers"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_remaining_model_group_usage():
|
||||
os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
|
||||
litellm.model_cost = litellm.get_model_cost_map(url="")
|
||||
from litellm.types.utils import OPENAI_RESPONSE_HEADERS
|
||||
|
||||
router = Router(
|
||||
model_list=[
|
||||
{
|
||||
"model_name": "gemini/*",
|
||||
"litellm_params": {"model": "gemini/*"},
|
||||
"model_info": {"id": 1},
|
||||
}
|
||||
]
|
||||
)
|
||||
for _ in range(2):
|
||||
await router.acompletion(
|
||||
model="gemini/gemini-1.5-flash",
|
||||
messages=[{"role": "user", "content": "Hello, how are you?"}],
|
||||
mock_response="Hello, I'm good.",
|
||||
)
|
||||
await asyncio.sleep(1)
|
||||
|
||||
remaining_usage = await router.get_remaining_model_group_usage(
|
||||
model_group="gemini/gemini-1.5-flash"
|
||||
)
|
||||
assert remaining_usage is not None
|
||||
assert "x-ratelimit-remaining-requests" in remaining_usage
|
||||
assert "x-ratelimit-remaining-tokens" in remaining_usage
|
||||
|
|
|
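Several of the router tests above exercise wildcard routes such as `gemini/*` and assert that usage headers and cache keys are tracked per model group. A toy illustration of matching a concrete model name against such a wildcard pattern, using `fnmatch` purely for the sake of example (litellm's own matching logic may differ):

```python
from fnmatch import fnmatch


def matches_wildcard_route(model_name: str, route_pattern: str) -> bool:
    """Return True if a concrete model name falls under a wildcard route."""
    return fnmatch(model_name, route_pattern)


assert matches_wildcard_route("gemini/gemini-1.5-flash", "gemini/*")
assert not matches_wildcard_route("openai/gpt-4o", "gemini/*")
```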
@ -506,7 +506,7 @@ async def test_router_caching_ttl():
    ) as mock_client:
        await router.acompletion(model=model, messages=messages)

        # mock_client.assert_called_once()
        mock_client.assert_called_once()
        print(f"mock_client.call_args.kwargs: {mock_client.call_args.kwargs}")
        print(f"mock_client.call_args.args: {mock_client.call_args.args}")

@ -415,18 +415,3 @@ def test_allowed_route_inside_route(
        )
        == expected_result
    )


def test_read_request_body():
    from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
    from fastapi import Request

    payload = "()" * 1000000
    request = Request(scope={"type": "http"})

    async def return_body():
        return payload

    request.body = return_body
    result = _read_request_body(request)
    assert result is not None
@ -212,7 +212,7 @@ async def test_bedrock_guardrail_triggered():
                session,
                "sk-1234",
                model="fake-openai-endpoint",
                messages=[{"role": "user", "content": "Hello do you like coffee?"}],
                messages=[{"role": "user", "content": f"Hello do you like coffee?"}],
                guardrails=["bedrock-pre-guard"],
            )
            pytest.fail("Should have thrown an exception")
@ -1,71 +0,0 @@
|
|||
import pytest
|
||||
import asyncio
|
||||
import aiohttp, openai
|
||||
from openai import OpenAI, AsyncOpenAI
|
||||
from typing import Optional, List, Union
|
||||
import uuid
|
||||
|
||||
|
||||
async def make_moderations_curl_request(
|
||||
session,
|
||||
key,
|
||||
request_data: dict,
|
||||
):
|
||||
url = "http://0.0.0.0:4000/moderations"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
async with session.post(url, headers=headers, json=request_data) as response:
|
||||
status = response.status
|
||||
response_text = await response.text()
|
||||
|
||||
if status != 200:
|
||||
raise Exception(response_text)
|
||||
|
||||
return await response.json()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_basic_moderations_on_proxy_no_model():
|
||||
"""
|
||||
Test moderations endpoint on proxy when no `model` is specified in the request
|
||||
"""
|
||||
async with aiohttp.ClientSession() as session:
|
||||
test_text = "I want to harm someone" # Test text that should trigger moderation
|
||||
request_data = {
|
||||
"input": test_text,
|
||||
}
|
||||
try:
|
||||
response = await make_moderations_curl_request(
|
||||
session,
|
||||
"sk-1234",
|
||||
request_data,
|
||||
)
|
||||
print("response=", response)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
pytest.fail("Moderations request failed")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_basic_moderations_on_proxy_with_model():
|
||||
"""
|
||||
Test moderations endpoint on proxy when `model` is specified in the request
|
||||
"""
|
||||
async with aiohttp.ClientSession() as session:
|
||||
test_text = "I want to harm someone" # Test text that should trigger moderation
|
||||
request_data = {
|
||||
"input": test_text,
|
||||
"model": "text-moderation-stable",
|
||||
}
|
||||
try:
|
||||
response = await make_moderations_curl_request(
|
||||
session,
|
||||
"sk-1234",
|
||||
request_data,
|
||||
)
|
||||
print("response=", response)
|
||||
except Exception as e:
|
||||
pytest.fail("Moderations request failed")
|
|
@@ -693,47 +693,3 @@ def test_personal_key_generation_check():
),
data=GenerateKeyRequest(),
)


def test_prepare_metadata_fields():
from litellm.proxy.management_endpoints.key_management_endpoints import (
prepare_metadata_fields,
)

new_metadata = {"test": "new"}
old_metadata = {"test": "test"}

args = {
"data": UpdateKeyRequest(
key_alias=None,
duration=None,
models=[],
spend=None,
max_budget=None,
user_id=None,
team_id=None,
max_parallel_requests=None,
metadata=new_metadata,
tpm_limit=None,
rpm_limit=None,
budget_duration=None,
allowed_cache_controls=[],
soft_budget=None,
config={},
permissions={},
model_max_budget={},
send_invite_email=None,
model_rpm_limit=None,
model_tpm_limit=None,
guardrails=None,
blocked=None,
aliases={},
key="sk-1qGQUJJTcljeaPfzgWRrXQ",
tags=None,
),
"non_default_values": {"metadata": new_metadata},
"existing_metadata": {"tags": None, **old_metadata},
}

non_default_values = prepare_metadata_fields(**args)
assert non_default_values == {"metadata": new_metadata}
@@ -23,7 +23,7 @@ import os
import sys
import traceback
import uuid
from datetime import datetime, timezone
from datetime import datetime

from dotenv import load_dotenv
from fastapi import Request

@@ -1305,8 +1305,6 @@ def test_generate_and_update_key(prisma_client):
data=UpdateKeyRequest(
key=generated_key,
models=["ada", "babbage", "curie", "davinci"],
budget_duration="1mo",
max_budget=100,
),
)
@@ -1335,18 +1333,6 @@ def test_generate_and_update_key(prisma_client):
}
assert result["info"]["models"] == ["ada", "babbage", "curie", "davinci"]
assert result["info"]["team_id"] == _team_2
assert result["info"]["budget_duration"] == "1mo"
assert result["info"]["max_budget"] == 100

# budget_reset_at should be 30 days from now
assert result["info"]["budget_reset_at"] is not None
budget_reset_at = result["info"]["budget_reset_at"].replace(
tzinfo=timezone.utc
)
current_time = datetime.now(timezone.utc)

# assert budget_reset_at is 30 days from now
assert 31 >= (budget_reset_at - current_time).days >= 29

# cleanup - delete key
delete_key_request = KeyRequest(keys=[generated_key])
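The hunk above asserts that a `budget_duration` of `"1mo"` yields a `budget_reset_at` roughly 30 days out (a 29-31 day window). A minimal sketch of the duration arithmetic those assertions imply is below; `duration_to_reset_at` is a hypothetical helper named for illustration, not LiteLLM's actual parser, and the "mo equals 30 days" rule is an assumption read off the test's tolerance.

```python
from datetime import datetime, timedelta, timezone
from typing import Optional


def duration_to_reset_at(budget_duration: str, now: Optional[datetime] = None) -> datetime:
    """Illustrative only: map "<n>d" to n days and "<n>mo" to n * 30 days."""
    now = now or datetime.now(timezone.utc)
    if budget_duration.endswith("mo"):
        return now + timedelta(days=30 * int(budget_duration[:-2]))
    if budget_duration.endswith("d"):
        return now + timedelta(days=int(budget_duration[:-1]))
    raise ValueError(f"unsupported budget_duration: {budget_duration}")
```

Under that reading, `duration_to_reset_at("1mo")` lands exactly 30 days out, which sits inside the 29-31 day window the assertion allows for month-length drift.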
@@ -2627,15 +2613,6 @@ async def test_create_update_team(prisma_client):
_updated_info["budget_reset_at"], datetime.datetime
)

# budget_reset_at should be 2 days from now
budget_reset_at = _updated_info["budget_reset_at"].replace(tzinfo=timezone.utc)
current_time = datetime.datetime.now(timezone.utc)

# assert budget_reset_at is 2 days from now
assert (
abs((budget_reset_at - current_time).total_seconds() - 2 * 24 * 60 * 60) <= 10
)

# now hit team_info
try:
response = await team_info(
@@ -2779,56 +2756,6 @@ async def test_update_user_role(prisma_client):
print("result from user auth with new key", result)


@pytest.mark.asyncio()
async def test_update_user_unit_test(prisma_client):
"""
Unit test for /user/update

Ensure that params are updated for UpdateUserRequest
"""
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
await litellm.proxy.proxy_server.prisma_client.connect()
key = await new_user(
data=NewUserRequest(
user_email="test@test.com",
)
)

print(key)

user_info = await user_update(
data=UpdateUserRequest(
user_id=key.user_id,
team_id="1234",
max_budget=100,
budget_duration="10d",
tpm_limit=100,
rpm_limit=100,
metadata={"very-new-metadata": "something"},
)
)

print("user_info", user_info)
assert user_info is not None
_user_info = user_info["data"].model_dump()

assert _user_info["user_id"] == key.user_id
assert _user_info["team_id"] == "1234"
assert _user_info["max_budget"] == 100
assert _user_info["budget_duration"] == "10d"
assert _user_info["tpm_limit"] == 100
assert _user_info["rpm_limit"] == 100
assert _user_info["metadata"] == {"very-new-metadata": "something"}

# budget reset at should be 10 days from now
budget_reset_at = _user_info["budget_reset_at"].replace(tzinfo=timezone.utc)
current_time = datetime.now(timezone.utc)
assert (
abs((budget_reset_at - current_time).total_seconds() - 10 * 24 * 60 * 60) <= 10
)


@pytest.mark.asyncio()
async def test_custom_api_key_header_name(prisma_client):
""" """
@@ -2917,6 +2844,7 @@ async def test_generate_key_with_model_tpm_limit(prisma_client):
"team": "litellm-team3",
"model_tpm_limit": {"gpt-4": 100},
"model_rpm_limit": {"gpt-4": 2},
"tags": None,
}

# Update model tpm_limit and rpm_limit

@@ -2940,6 +2868,7 @@ async def test_generate_key_with_model_tpm_limit(prisma_client):
"team": "litellm-team3",
"model_tpm_limit": {"gpt-4": 200},
"model_rpm_limit": {"gpt-4": 3},
"tags": None,
}

@@ -2979,6 +2908,7 @@ async def test_generate_key_with_guardrails(prisma_client):
assert result["info"]["metadata"] == {
"team": "litellm-team3",
"guardrails": ["aporia-pre-call"],
"tags": None,
}

# Update model tpm_limit and rpm_limit

@@ -3000,6 +2930,7 @@ async def test_generate_key_with_guardrails(prisma_client):
assert result["info"]["metadata"] == {
"team": "litellm-team3",
"guardrails": ["aporia-pre-call", "aporia-post-call"],
"tags": None,
}
@@ -3619,152 +3550,3 @@ async def test_key_generate_with_secret_manager_call(prisma_client):

################################################################################


@pytest.mark.asyncio
async def test_key_alias_uniqueness(prisma_client):
"""
Test that:
1. We cannot create two keys with the same alias
2. We cannot update a key to use an alias that's already taken
3. We can update a key while keeping its existing alias
"""
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
await litellm.proxy.proxy_server.prisma_client.connect()

try:
# Create first key with an alias
unique_alias = f"test-alias-{uuid.uuid4()}"
key1 = await generate_key_fn(
data=GenerateKeyRequest(key_alias=unique_alias),
user_api_key_dict=UserAPIKeyAuth(
user_role=LitellmUserRoles.PROXY_ADMIN,
api_key="sk-1234",
user_id="1234",
),
)

# Try to create second key with same alias - should fail
try:
key2 = await generate_key_fn(
data=GenerateKeyRequest(key_alias=unique_alias),
user_api_key_dict=UserAPIKeyAuth(
user_role=LitellmUserRoles.PROXY_ADMIN,
api_key="sk-1234",
user_id="1234",
),
)
pytest.fail("Should not be able to create a second key with the same alias")
except Exception as e:
print("vars(e)=", vars(e))
assert "Unique key aliases across all keys are required" in str(e.message)

# Create another key with different alias
another_alias = f"test-alias-{uuid.uuid4()}"
key3 = await generate_key_fn(
data=GenerateKeyRequest(key_alias=another_alias),
user_api_key_dict=UserAPIKeyAuth(
user_role=LitellmUserRoles.PROXY_ADMIN,
api_key="sk-1234",
user_id="1234",
),
)

# Try to update key3 to use key1's alias - should fail
try:
await update_key_fn(
data=UpdateKeyRequest(key=key3.key, key_alias=unique_alias),
request=Request(scope={"type": "http"}),
)
pytest.fail("Should not be able to update a key to use an existing alias")
except Exception as e:
assert "Unique key aliases across all keys are required" in str(e.message)

# Update key1 with its own existing alias - should succeed
updated_key = await update_key_fn(
data=UpdateKeyRequest(key=key1.key, key_alias=unique_alias),
request=Request(scope={"type": "http"}),
)
assert updated_key is not None

except Exception as e:
print("got exceptions, e=", e)
print("vars(e)=", vars(e))
pytest.fail(f"An unexpected error occurred: {str(e)}")


@pytest.mark.asyncio
async def test_enforce_unique_key_alias(prisma_client):
"""
Unit test the _enforce_unique_key_alias function:
1. Test it allows unique aliases
2. Test it blocks duplicate aliases for new keys
3. Test it allows updating a key with its own existing alias
4. Test it blocks updating a key with another key's alias
"""
from litellm.proxy.management_endpoints.key_management_endpoints import (
_enforce_unique_key_alias,
)

setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
await litellm.proxy.proxy_server.prisma_client.connect()

try:
# Test 1: Allow unique alias
unique_alias = f"test-alias-{uuid.uuid4()}"
await _enforce_unique_key_alias(
key_alias=unique_alias,
prisma_client=prisma_client,
) # Should pass

# Create a key with this alias in the database
key1 = await generate_key_fn(
data=GenerateKeyRequest(key_alias=unique_alias),
user_api_key_dict=UserAPIKeyAuth(
user_role=LitellmUserRoles.PROXY_ADMIN,
api_key="sk-1234",
user_id="1234",
),
)

# Test 2: Block duplicate alias for new key
try:
await _enforce_unique_key_alias(
key_alias=unique_alias,
prisma_client=prisma_client,
)
pytest.fail("Should not allow duplicate alias")
except Exception as e:
assert "Unique key aliases across all keys are required" in str(e.message)

# Test 3: Allow updating key with its own alias
await _enforce_unique_key_alias(
key_alias=unique_alias,
existing_key_token=hash_token(key1.key),
prisma_client=prisma_client,
) # Should pass

# Test 4: Block updating with another key's alias
another_key = await generate_key_fn(
data=GenerateKeyRequest(key_alias=f"test-alias-{uuid.uuid4()}"),
user_api_key_dict=UserAPIKeyAuth(
user_role=LitellmUserRoles.PROXY_ADMIN,
api_key="sk-1234",
user_id="1234",
),
)

try:
await _enforce_unique_key_alias(
key_alias=unique_alias,
existing_key_token=another_key.key,
prisma_client=prisma_client,
)
pytest.fail("Should not allow using another key's alias")
except Exception as e:
assert "Unique key aliases across all keys are required" in str(e.message)

except Exception as e:
print("Unexpected error:", e)
pytest.fail(f"An unexpected error occurred: {str(e)}")
@@ -2195,126 +2195,3 @@ async def test_async_log_proxy_authentication_errors():
assert (
mock_logger.user_api_key_dict_logged.token is not None
) # token should be hashed


@pytest.mark.asyncio
async def test_async_log_proxy_authentication_errors_get_request():
"""
Test if async_log_proxy_authentication_errors correctly handles GET requests
that don't have a JSON body
"""
import json
from fastapi import Request
from litellm.proxy.utils import ProxyLogging
from litellm.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger

class MockCustomLogger(CustomLogger):
def __init__(self):
self.called = False
self.exception_logged = None
self.request_data_logged = None
self.user_api_key_dict_logged = None

async def async_post_call_failure_hook(
self,
request_data: dict,
original_exception: Exception,
user_api_key_dict: UserAPIKeyAuth,
):
self.called = True
self.exception_logged = original_exception
self.request_data_logged = request_data
self.user_api_key_dict_logged = user_api_key_dict

# Create a mock GET request
request = Request(scope={"type": "http", "method": "GET"})

# Mock the json() method to raise JSONDecodeError
async def mock_json():
raise json.JSONDecodeError("Expecting value", "", 0)

request.json = mock_json

# Create a test exception
test_exception = Exception("Invalid API Key")

# Initialize ProxyLogging
mock_logger = MockCustomLogger()
litellm.callbacks = [mock_logger]
proxy_logging_obj = ProxyLogging(user_api_key_cache=DualCache())

# Call the method
await proxy_logging_obj.async_log_proxy_authentication_errors(
original_exception=test_exception,
request=request,
parent_otel_span=None,
api_key="test-key",
)

# Verify the mock logger was called with correct parameters
assert mock_logger.called == True
assert mock_logger.exception_logged == test_exception
assert mock_logger.user_api_key_dict_logged is not None
assert mock_logger.user_api_key_dict_logged.token is not None


@pytest.mark.asyncio
async def test_async_log_proxy_authentication_errors_no_api_key():
"""
Test if async_log_proxy_authentication_errors correctly handles requests
with no API key provided
"""
from fastapi import Request
from litellm.proxy.utils import ProxyLogging
from litellm.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger

class MockCustomLogger(CustomLogger):
def __init__(self):
self.called = False
self.exception_logged = None
self.request_data_logged = None
self.user_api_key_dict_logged = None

async def async_post_call_failure_hook(
self,
request_data: dict,
original_exception: Exception,
user_api_key_dict: UserAPIKeyAuth,
):
self.called = True
self.exception_logged = original_exception
self.request_data_logged = request_data
self.user_api_key_dict_logged = user_api_key_dict

# Create test data
test_data = {"model": "gpt-4", "messages": [{"role": "user", "content": "Hello"}]}

# Create a mock request
request = Request(scope={"type": "http", "method": "POST"})
request._json = AsyncMock(return_value=test_data)

# Create a test exception
test_exception = Exception("No API Key Provided")

# Initialize ProxyLogging
mock_logger = MockCustomLogger()
litellm.callbacks = [mock_logger]
proxy_logging_obj = ProxyLogging(user_api_key_cache=DualCache())

# Call the method with api_key=None
await proxy_logging_obj.async_log_proxy_authentication_errors(
original_exception=test_exception,
request=request,
parent_otel_span=None,
api_key=None,
)

# Verify the mock logger was called with correct parameters
assert mock_logger.called == True
assert mock_logger.exception_logged == test_exception
assert mock_logger.user_api_key_dict_logged is not None
assert (
mock_logger.user_api_key_dict_logged.token == ""
) # Empty token for no API key
@@ -444,7 +444,7 @@ def test_foward_litellm_user_info_to_backend_llm_call():

def test_update_internal_user_params():
from litellm.proxy.management_endpoints.internal_user_endpoints import (
_update_internal_new_user_params,
_update_internal_user_params,
)
from litellm.proxy._types import NewUserRequest

@@ -456,7 +456,7 @@ def test_update_internal_user_params():
data = NewUserRequest(user_role="internal_user", user_email="krrish3@berri.ai")
data_json = data.model_dump()
updated_data_json = _update_internal_new_user_params(data_json, data)
updated_data_json = _update_internal_user_params(data_json, data)
assert updated_data_json["models"] == litellm.default_internal_user_params["models"]
assert (
updated_data_json["max_budget"]

@@ -530,7 +530,7 @@ def test_prepare_key_update_data():
data = UpdateKeyRequest(key="test_key", metadata=None)
updated_data = prepare_key_update_data(data, existing_key_row)
assert updated_data["metadata"] is None
assert updated_data["metadata"] == None


@pytest.mark.parametrize(
@@ -574,108 +574,3 @@ def test_get_docs_url(env_vars, expected_url):

result = _get_docs_url()
assert result == expected_url


@pytest.mark.parametrize(
"request_tags, tags_to_add, expected_tags",
[
(None, None, []), # both None
(["tag1", "tag2"], None, ["tag1", "tag2"]), # tags_to_add is None
(None, ["tag3", "tag4"], ["tag3", "tag4"]), # request_tags is None
(
["tag1", "tag2"],
["tag3", "tag4"],
["tag1", "tag2", "tag3", "tag4"],
), # both have unique tags
(
["tag1", "tag2"],
["tag2", "tag3"],
["tag1", "tag2", "tag3"],
), # overlapping tags
([], [], []), # both empty lists
("not_a_list", ["tag1"], ["tag1"]), # request_tags invalid type
(["tag1"], "not_a_list", ["tag1"]), # tags_to_add invalid type
(
["tag1"],
["tag1", "tag2"],
["tag1", "tag2"],
), # duplicate tags in inputs
],
)
def test_merge_tags(request_tags, tags_to_add, expected_tags):
from litellm.proxy.litellm_pre_call_utils import LiteLLMProxyRequestSetup

result = LiteLLMProxyRequestSetup._merge_tags(
request_tags=request_tags, tags_to_add=tags_to_add
)

assert isinstance(result, list)
assert sorted(result) == sorted(expected_tags)


@pytest.mark.asyncio
@pytest.mark.parametrize(
"key_tags, request_tags, expected_tags",
[
# exact duplicates
(["tag1", "tag2", "tag3"], ["tag1", "tag2", "tag3"], ["tag1", "tag2", "tag3"]),
# partial duplicates
(
["tag1", "tag2", "tag3"],
["tag2", "tag3", "tag4"],
["tag1", "tag2", "tag3", "tag4"],
),
# duplicates within key tags
(["tag1", "tag2"], ["tag3", "tag4"], ["tag1", "tag2", "tag3", "tag4"]),
# duplicates within request tags
(["tag1", "tag2"], ["tag2", "tag3", "tag4"], ["tag1", "tag2", "tag3", "tag4"]),
# case sensitive duplicates
(["Tag1", "TAG2"], ["tag1", "tag2"], ["Tag1", "TAG2", "tag1", "tag2"]),
],
)
async def test_add_litellm_data_to_request_duplicate_tags(
key_tags, request_tags, expected_tags
):
"""
Test to verify duplicate tags between request and key metadata are handled correctly

Aggregation logic when checking spend can be impacted if duplicate tags are not handled correctly.

User feedback:
"If I register my key with tag1 and
also pass the same tag1 when using the key
then I see tag1 twice in the
LiteLLM_SpendLogs table request_tags column. This can mess up aggregation logic"
"""
mock_request = Mock(spec=Request)
mock_request.url.path = "/chat/completions"
mock_request.query_params = {}
mock_request.headers = {}

# Setup key with tags in metadata
user_api_key_dict = UserAPIKeyAuth(
api_key="test_api_key",
user_id="test_user_id",
org_id="test_org_id",
metadata={"tags": key_tags},
)

# Setup request data with tags
data = {"metadata": {"tags": request_tags}}

# Process request
proxy_config = Mock()
result = await add_litellm_data_to_request(
data=data,
request=mock_request,
user_api_key_dict=user_api_key_dict,
proxy_config=proxy_config,
)

# Verify results
assert "metadata" in result
assert "tags" in result["metadata"]
assert sorted(result["metadata"]["tags"]) == sorted(
expected_tags
), f"Expected {expected_tags}, got {result['metadata']['tags']}"
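For context, the parametrized cases in the hunk above pin the tag-merge behaviour down to an order-preserving, case-sensitive union that treats non-list inputs as empty. A minimal sketch of that contract is below; `merge_tags` is a stand-in name for illustration, not LiteLLM's actual `LiteLLMProxyRequestSetup._merge_tags` implementation.

```python
def merge_tags(request_tags, tags_to_add):
    # Stand-in sketch: union of both inputs in first-seen order, case-sensitive,
    # with non-list inputs ignored - matching the parametrized expectations above.
    merged = []
    for source in (request_tags, tags_to_add):
        if not isinstance(source, list):
            continue
        for tag in source:
            if tag not in merged:
                merged.append(tag)
    return merged
```

For example, `merge_tags(["tag1", "tag2"], ["tag2", "tag3"])` returns `["tag1", "tag2", "tag3"]`, which is the deduplication the spend-log aggregation feedback quoted in the test asks for.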
@@ -215,7 +215,7 @@ async def test_rerank_endpoint(model_list):

@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_aaaaatext_completion_endpoint(model_list, sync_mode):
async def test_text_completion_endpoint(model_list, sync_mode):
router = Router(model_list=model_list)

if sync_mode:

@@ -396,8 +396,7 @@ async def test_deployment_callback_on_success(model_list, sync_mode):
assert tpm_key is not None


@pytest.mark.asyncio
async def test_deployment_callback_on_failure(model_list):
def test_deployment_callback_on_failure(model_list):
"""Test if the '_deployment_callback_on_failure' function is working correctly"""
import time
@@ -419,18 +418,6 @@ async def test_deployment_callback_on_failure(model_list):
assert isinstance(result, bool)
assert result is False

model_response = router.completion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hello, how are you?"}],
mock_response="I'm fine, thank you!",
)
result = await router.async_deployment_callback_on_failure(
kwargs=kwargs,
completion_response=model_response,
start_time=time.time(),
end_time=time.time(),
)


def test_log_retry(model_list):
"""Test if the '_log_retry' function is working correctly"""
@@ -1040,11 +1027,8 @@ def test_pattern_match_deployment_set_model_name(
async def test_pass_through_moderation_endpoint_factory(model_list):
router = Router(model_list=model_list)
response = await router._pass_through_moderation_endpoint_factory(
original_function=litellm.amoderation,
input="this is valid good text",
model=None,
original_function=litellm.amoderation, input="this is valid good text"
)
assert response is not None


@pytest.mark.parametrize(
@@ -300,7 +300,6 @@ async def test_key_update(metadata):
get_key=key,
metadata=metadata,
)
print(f"updated_key['metadata']: {updated_key['metadata']}")
assert updated_key["metadata"] == metadata
await update_proxy_budget(session=session) # resets proxy spend
await chat_completion(session=session, key=key)

@@ -114,7 +114,7 @@ async def test_spend_logs():


async def get_predict_spend_logs(session):
url = "http://0.0.0.0:4000/global/predict/spend/logs"
url = f"http://0.0.0.0:4000/global/predict/spend/logs"
headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
data = {
"data": [

@@ -155,7 +155,6 @@ async def get_spend_report(session, start_date, end_date):
return await response.json()


@pytest.mark.skip(reason="datetime in ci/cd gets set weirdly")
@pytest.mark.asyncio
async def test_get_predicted_spend_logs():
"""
@@ -24,7 +24,6 @@ import {
Icon,
BarChart,
TextInput,
Textarea,
} from "@tremor/react";
import { Select as Select3, SelectItem, MultiSelect, MultiSelectItem } from "@tremor/react";
import {

@@ -41,7 +40,6 @@ import {
} from "antd";

import { CopyToClipboard } from "react-copy-to-clipboard";
import TextArea from "antd/es/input/TextArea";

const { Option } = Select;
const isLocal = process.env.NODE_ENV === "development";

@@ -440,16 +438,6 @@ const ViewKeyTable: React.FC<ViewKeyTableProps> = ({
>
<InputNumber step={1} precision={1} width={200} />
</Form.Item>
<Form.Item
label="Metadata"
name="metadata"
initialValue={token.metadata}
>
<TextArea
value={String(token.metadata)}
rows={10}
/>
</Form.Item>
</>
<div style={{ textAlign: "right", marginTop: "10px" }}>
<Button2 htmlType="submit">Edit Key</Button2>