LiteLLM Minor Fixes & Improvements (11/26/2024) (#6913)

* docs(config_settings.md): document all router_settings

* ci(config.yml): add router_settings doc test to ci/cd

* test: debug test on ci/cd

* test: debug ci/cd test

* test: fix test

* fix(team_endpoints.py): skip invalid team object. don't fail `/team/list` call

Causes downstream errors if ui just fails to load team list

* test(base_llm_unit_tests.py): add 'response_format={"type": "text"}' test to base_llm_unit_tests

adds complete coverage for all 'response_format' values to ci/cd

* feat(router.py): support wildcard routes in `get_router_model_info()`

Addresses https://github.com/BerriAI/litellm/issues/6914

* build(model_prices_and_context_window.json): add tpm/rpm limits for all gemini models

Allows for ratelimit tracking for gemini models even with wildcard routing enabled

Addresses https://github.com/BerriAI/litellm/issues/6914

* feat(router.py): add tpm/rpm tracking on success/failure to global_router

Addresses https://github.com/BerriAI/litellm/issues/6914

* feat(router.py): support wildcard routes on router.get_model_group_usage()

* fix(router.py): fix linting error

* fix(router.py): implement get_remaining_tokens_and_requests

Addresses https://github.com/BerriAI/litellm/issues/6914

* fix(router.py): fix linting errors

* test: fix test

* test: fix tests

* docs(config_settings.md): add missing dd env vars to docs

* fix(router.py): check if hidden params is dict
This commit is contained in:
Krish Dholakia 2024-11-28 00:01:38 +05:30 committed by GitHub
parent 5d13302e6b
commit 2d2931a215
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
22 changed files with 878 additions and 131 deletions

View file

@ -811,7 +811,8 @@ jobs:
- run: python ./tests/code_coverage_tests/router_code_coverage.py
- run: python ./tests/code_coverage_tests/test_router_strategy_async.py
- run: python ./tests/code_coverage_tests/litellm_logging_code_coverage.py
# - run: python ./tests/documentation_tests/test_env_keys.py
- run: python ./tests/documentation_tests/test_env_keys.py
- run: python ./tests/documentation_tests/test_router_settings.py
- run: python ./tests/documentation_tests/test_api_docs.py
- run: python ./tests/code_coverage_tests/ensure_async_clients_test.py
- run: helm lint ./deploy/charts/litellm-helm

View file

@ -279,7 +279,31 @@ router_settings:
| retry_policy | object | Specifies the number of retries for different types of exceptions. [More information here](reliability) |
| allowed_fails | integer | The number of failures allowed before cooling down a model. [More information here](reliability) |
| allowed_fails_policy | object | Specifies the number of allowed failures for different error types before cooling down a deployment. [More information here](reliability) |
| default_max_parallel_requests | Optional[int] | The default maximum number of parallel requests for a deployment. |
| default_priority | (Optional[int]) | The default priority for a request. Only for '.scheduler_acompletion()'. Default is None. |
| polling_interval | (Optional[float]) | frequency of polling queue. Only for '.scheduler_acompletion()'. Default is 3ms. |
| max_fallbacks | Optional[int] | The maximum number of fallbacks to try before exiting the call. Defaults to 5. |
| default_litellm_params | Optional[dict] | The default litellm parameters to add to all requests (e.g. `temperature`, `max_tokens`). |
| timeout | Optional[float] | The default timeout for a request. |
| debug_level | Literal["DEBUG", "INFO"] | The debug level for the logging library in the router. Defaults to "INFO". |
| client_ttl | int | Time-to-live for cached clients in seconds. Defaults to 3600. |
| cache_kwargs | dict | Additional keyword arguments for the cache initialization. |
| routing_strategy_args | dict | Additional keyword arguments for the routing strategy - e.g. lowest latency routing default ttl |
| model_group_alias | dict | Model group alias mapping. E.g. `{"claude-3-haiku": "claude-3-haiku-20240229"}` |
| num_retries | int | Number of retries for a request. Defaults to 3. |
| default_fallbacks | Optional[List[str]] | Fallbacks to try if no model group-specific fallbacks are defined. |
| caching_groups | Optional[List[tuple]] | List of model groups for caching across model groups. Defaults to None. - e.g. caching_groups=[("openai-gpt-3.5-turbo", "azure-gpt-3.5-turbo")]|
| alerting_config | AlertingConfig | [SDK-only arg] Slack alerting configuration. Defaults to None. [Further Docs](../routing.md#alerting-) |
| assistants_config | AssistantsConfig | Set on proxy via `assistant_settings`. [Further docs](../assistants.md) |
| set_verbose | boolean | [DEPRECATED PARAM - see debug docs](./debugging.md) If true, sets the logging level to verbose. |
| retry_after | int | Time to wait before retrying a request in seconds. Defaults to 0. If `x-retry-after` is received from LLM API, this value is overridden. |
| provider_budget_config | ProviderBudgetConfig | Provider budget configuration. Use this to set llm_provider budget limits. example $100/day to OpenAI, $100/day to Azure, etc. Defaults to None. [Further Docs](./provider_budget_routing.md) |
| enable_pre_call_checks | boolean | If true, checks if a call is within the model's context window before making the call. [More information here](reliability) |
| model_group_retry_policy | Dict[str, RetryPolicy] | [SDK-only arg] Set retry policy for model groups. |
| context_window_fallbacks | List[Dict[str, List[str]]] | Fallback models for context window violations. |
| redis_url | str | URL for Redis server. **Known performance issue with Redis URL.** |
| cache_responses | boolean | Flag to enable caching LLM Responses, if cache set under `router_settings`. If true, caches responses. Defaults to False. |
| router_general_settings | RouterGeneralSettings | [SDK-Only] Router general settings - contains optimizations like 'async_only_mode'. [Docs](../routing.md#router-general-settings) |
### environment variables - Reference
@ -335,6 +359,8 @@ router_settings:
| DD_SITE | Site URL for Datadog (e.g., datadoghq.com)
| DD_SOURCE | Source identifier for Datadog logs
| DD_ENV | Environment identifier for Datadog logs. Only supported for `datadog_llm_observability` callback
| DD_SERVICE | Service identifier for Datadog logs. Defaults to "litellm-server"
| DD_VERSION | Version identifier for Datadog logs. Defaults to "unknown"
| DEBUG_OTEL | Enable debug mode for OpenTelemetry
| DIRECT_URL | Direct URL for service endpoint
| DISABLE_ADMIN_UI | Toggle to disable the admin UI

View file

@ -357,77 +357,6 @@ curl --location 'http://0.0.0.0:4000/v1/model/info' \
--data ''
```
### Provider specific wildcard routing
**Proxy all models from a provider**
Use this if you want to **proxy all models from a specific provider without defining them on the config.yaml**
**Step 1** - define provider specific routing on config.yaml
```yaml
model_list:
# provider specific wildcard routing
- model_name: "anthropic/*"
litellm_params:
model: "anthropic/*"
api_key: os.environ/ANTHROPIC_API_KEY
- model_name: "groq/*"
litellm_params:
model: "groq/*"
api_key: os.environ/GROQ_API_KEY
- model_name: "fo::*:static::*" # all requests matching this pattern will be routed to this deployment, example: model="fo::hi::static::hi" will be routed to deployment: "openai/fo::*:static::*"
litellm_params:
model: "openai/fo::*:static::*"
api_key: os.environ/OPENAI_API_KEY
```
Step 2 - Run litellm proxy
```shell
$ litellm --config /path/to/config.yaml
```
Step 3 Test it
Test with `anthropic/` - all models with `anthropic/` prefix will get routed to `anthropic/*`
```shell
curl http://localhost:4000/v1/chat/completions \
-H "Content-Type: application/json" \
-H "Authorization: Bearer sk-1234" \
-d '{
"model": "anthropic/claude-3-sonnet-20240229",
"messages": [
{"role": "user", "content": "Hello, Claude!"}
]
}'
```
Test with `groq/` - all models with `groq/` prefix will get routed to `groq/*`
```shell
curl http://localhost:4000/v1/chat/completions \
-H "Content-Type: application/json" \
-H "Authorization: Bearer sk-1234" \
-d '{
"model": "groq/llama3-8b-8192",
"messages": [
{"role": "user", "content": "Hello, Claude!"}
]
}'
```
Test with `fo::*::static::*` - all requests matching this pattern will be routed to `openai/fo::*:static::*`
```shell
curl http://localhost:4000/v1/chat/completions \
-H "Content-Type: application/json" \
-H "Authorization: Bearer sk-1234" \
-d '{
"model": "fo::hi::static::hi",
"messages": [
{"role": "user", "content": "Hello, Claude!"}
]
}'
```
### Load Balancing
:::info

View file

@ -1891,3 +1891,22 @@ router = Router(
debug_level="DEBUG" # defaults to INFO
)
```
## Router General Settings
### Usage
```python
router = Router(model_list=..., router_general_settings=RouterGeneralSettings(async_only_mode=True))
```
### Spec
```python
class RouterGeneralSettings(BaseModel):
async_only_mode: bool = Field(
default=False
) # this will only initialize async clients. Good for memory utils
pass_through_all_models: bool = Field(
default=False
) # if passed a model not llm_router model list, pass through the request to litellm.acompletion/embedding
```

View file

@ -0,0 +1,140 @@
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
# Provider specific Wildcard routing
**Proxy all models from a provider**
Use this if you want to **proxy all models from a specific provider without defining them on the config.yaml**
## Step 1. Define provider specific routing
<Tabs>
<TabItem value="sdk" label="SDK">
```python
from litellm import Router
router = Router(
model_list=[
{
"model_name": "anthropic/*",
"litellm_params": {
"model": "anthropic/*",
"api_key": os.environ["ANTHROPIC_API_KEY"]
}
},
{
"model_name": "groq/*",
"litellm_params": {
"model": "groq/*",
"api_key": os.environ["GROQ_API_KEY"]
}
},
{
"model_name": "fo::*:static::*", # all requests matching this pattern will be routed to this deployment, example: model="fo::hi::static::hi" will be routed to deployment: "openai/fo::*:static::*"
"litellm_params": {
"model": "openai/fo::*:static::*",
"api_key": os.environ["OPENAI_API_KEY"]
}
}
]
)
```
</TabItem>
<TabItem value="proxy" label="PROXY">
**Step 1** - define provider specific routing on config.yaml
```yaml
model_list:
# provider specific wildcard routing
- model_name: "anthropic/*"
litellm_params:
model: "anthropic/*"
api_key: os.environ/ANTHROPIC_API_KEY
- model_name: "groq/*"
litellm_params:
model: "groq/*"
api_key: os.environ/GROQ_API_KEY
- model_name: "fo::*:static::*" # all requests matching this pattern will be routed to this deployment, example: model="fo::hi::static::hi" will be routed to deployment: "openai/fo::*:static::*"
litellm_params:
model: "openai/fo::*:static::*"
api_key: os.environ/OPENAI_API_KEY
```
</TabItem>
</Tabs>
## [PROXY-Only] Step 2 - Run litellm proxy
```shell
$ litellm --config /path/to/config.yaml
```
## Step 3 - Test it
<Tabs>
<TabItem value="sdk" label="SDK">
```python
from litellm import Router
router = Router(model_list=...)
# Test with `anthropic/` - all models with `anthropic/` prefix will get routed to `anthropic/*`
resp = completion(model="anthropic/claude-3-sonnet-20240229", messages=[{"role": "user", "content": "Hello, Claude!"}])
print(resp)
# Test with `groq/` - all models with `groq/` prefix will get routed to `groq/*`
resp = completion(model="groq/llama3-8b-8192", messages=[{"role": "user", "content": "Hello, Groq!"}])
print(resp)
# Test with `fo::*::static::*` - all requests matching this pattern will be routed to `openai/fo::*:static::*`
resp = completion(model="fo::hi::static::hi", messages=[{"role": "user", "content": "Hello, Claude!"}])
print(resp)
```
</TabItem>
<TabItem value="proxy" label="PROXY">
Test with `anthropic/` - all models with `anthropic/` prefix will get routed to `anthropic/*`
```bash
curl http://localhost:4000/v1/chat/completions \
-H "Content-Type: application/json" \
-H "Authorization: Bearer sk-1234" \
-d '{
"model": "anthropic/claude-3-sonnet-20240229",
"messages": [
{"role": "user", "content": "Hello, Claude!"}
]
}'
```
Test with `groq/` - all models with `groq/` prefix will get routed to `groq/*`
```shell
curl http://localhost:4000/v1/chat/completions \
-H "Content-Type: application/json" \
-H "Authorization: Bearer sk-1234" \
-d '{
"model": "groq/llama3-8b-8192",
"messages": [
{"role": "user", "content": "Hello, Claude!"}
]
}'
```
Test with `fo::*::static::*` - all requests matching this pattern will be routed to `openai/fo::*:static::*`
```shell
curl http://localhost:4000/v1/chat/completions \
-H "Content-Type: application/json" \
-H "Authorization: Bearer sk-1234" \
-d '{
"model": "fo::hi::static::hi",
"messages": [
{"role": "user", "content": "Hello, Claude!"}
]
}'
```
</TabItem>
</Tabs>

View file

@ -277,7 +277,7 @@ const sidebars = {
description: "Learn how to load balance, route, and set fallbacks for your LLM requests",
slug: "/routing-load-balancing",
},
items: ["routing", "scheduler", "proxy/load_balancing", "proxy/reliability", "proxy/tag_routing", "proxy/provider_budget_routing", "proxy/team_based_routing", "proxy/customer_routing"],
items: ["routing", "scheduler", "proxy/load_balancing", "proxy/reliability", "proxy/tag_routing", "proxy/provider_budget_routing", "proxy/team_based_routing", "proxy/customer_routing", "wildcard_routing"],
},
{
type: "category",

View file

@ -3383,6 +3383,8 @@
"supports_vision": true,
"supports_response_schema": true,
"supports_prompt_caching": true,
"tpm": 4000000,
"rpm": 2000,
"source": "https://ai.google.dev/pricing"
},
"gemini/gemini-1.5-flash-001": {
@ -3406,6 +3408,8 @@
"supports_vision": true,
"supports_response_schema": true,
"supports_prompt_caching": true,
"tpm": 4000000,
"rpm": 2000,
"source": "https://ai.google.dev/pricing"
},
"gemini/gemini-1.5-flash": {
@ -3428,6 +3432,8 @@
"supports_function_calling": true,
"supports_vision": true,
"supports_response_schema": true,
"tpm": 4000000,
"rpm": 2000,
"source": "https://ai.google.dev/pricing"
},
"gemini/gemini-1.5-flash-latest": {
@ -3450,6 +3456,32 @@
"supports_function_calling": true,
"supports_vision": true,
"supports_response_schema": true,
"tpm": 4000000,
"rpm": 2000,
"source": "https://ai.google.dev/pricing"
},
"gemini/gemini-1.5-flash-8b": {
"max_tokens": 8192,
"max_input_tokens": 1048576,
"max_output_tokens": 8192,
"max_images_per_prompt": 3000,
"max_videos_per_prompt": 10,
"max_video_length": 1,
"max_audio_length_hours": 8.4,
"max_audio_per_prompt": 1,
"max_pdf_size_mb": 30,
"input_cost_per_token": 0,
"input_cost_per_token_above_128k_tokens": 0,
"output_cost_per_token": 0,
"output_cost_per_token_above_128k_tokens": 0,
"litellm_provider": "gemini",
"mode": "chat",
"supports_system_messages": true,
"supports_function_calling": true,
"supports_vision": true,
"supports_response_schema": true,
"tpm": 4000000,
"rpm": 4000,
"source": "https://ai.google.dev/pricing"
},
"gemini/gemini-1.5-flash-8b-exp-0924": {
@ -3472,6 +3504,8 @@
"supports_function_calling": true,
"supports_vision": true,
"supports_response_schema": true,
"tpm": 4000000,
"rpm": 4000,
"source": "https://ai.google.dev/pricing"
},
"gemini/gemini-exp-1114": {
@ -3494,7 +3528,12 @@
"supports_function_calling": true,
"supports_vision": true,
"supports_response_schema": true,
"source": "https://ai.google.dev/pricing"
"tpm": 4000000,
"rpm": 1000,
"source": "https://ai.google.dev/pricing",
"metadata": {
"notes": "Rate limits not documented for gemini-exp-1114. Assuming same as gemini-1.5-pro."
}
},
"gemini/gemini-1.5-flash-exp-0827": {
"max_tokens": 8192,
@ -3516,6 +3555,8 @@
"supports_function_calling": true,
"supports_vision": true,
"supports_response_schema": true,
"tpm": 4000000,
"rpm": 2000,
"source": "https://ai.google.dev/pricing"
},
"gemini/gemini-1.5-flash-8b-exp-0827": {
@ -3537,6 +3578,9 @@
"supports_system_messages": true,
"supports_function_calling": true,
"supports_vision": true,
"supports_response_schema": true,
"tpm": 4000000,
"rpm": 4000,
"source": "https://ai.google.dev/pricing"
},
"gemini/gemini-pro": {
@ -3550,7 +3594,10 @@
"litellm_provider": "gemini",
"mode": "chat",
"supports_function_calling": true,
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
"rpd": 30000,
"tpm": 120000,
"rpm": 360,
"source": "https://ai.google.dev/gemini-api/docs/models/gemini"
},
"gemini/gemini-1.5-pro": {
"max_tokens": 8192,
@ -3567,6 +3614,8 @@
"supports_vision": true,
"supports_tool_choice": true,
"supports_response_schema": true,
"tpm": 4000000,
"rpm": 1000,
"source": "https://ai.google.dev/pricing"
},
"gemini/gemini-1.5-pro-002": {
@ -3585,6 +3634,8 @@
"supports_tool_choice": true,
"supports_response_schema": true,
"supports_prompt_caching": true,
"tpm": 4000000,
"rpm": 1000,
"source": "https://ai.google.dev/pricing"
},
"gemini/gemini-1.5-pro-001": {
@ -3603,6 +3654,8 @@
"supports_tool_choice": true,
"supports_response_schema": true,
"supports_prompt_caching": true,
"tpm": 4000000,
"rpm": 1000,
"source": "https://ai.google.dev/pricing"
},
"gemini/gemini-1.5-pro-exp-0801": {
@ -3620,6 +3673,8 @@
"supports_vision": true,
"supports_tool_choice": true,
"supports_response_schema": true,
"tpm": 4000000,
"rpm": 1000,
"source": "https://ai.google.dev/pricing"
},
"gemini/gemini-1.5-pro-exp-0827": {
@ -3637,6 +3692,8 @@
"supports_vision": true,
"supports_tool_choice": true,
"supports_response_schema": true,
"tpm": 4000000,
"rpm": 1000,
"source": "https://ai.google.dev/pricing"
},
"gemini/gemini-1.5-pro-latest": {
@ -3654,6 +3711,8 @@
"supports_vision": true,
"supports_tool_choice": true,
"supports_response_schema": true,
"tpm": 4000000,
"rpm": 1000,
"source": "https://ai.google.dev/pricing"
},
"gemini/gemini-pro-vision": {
@ -3668,6 +3727,9 @@
"mode": "chat",
"supports_function_calling": true,
"supports_vision": true,
"rpd": 30000,
"tpm": 120000,
"rpm": 360,
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
},
"gemini/gemini-gemma-2-27b-it": {

View file

@ -1367,6 +1367,7 @@ async def list_team(
""".format(
team.team_id, team.model_dump(), str(e)
)
raise HTTPException(status_code=400, detail={"error": team_exception})
verbose_proxy_logger.exception(team_exception)
continue
return returned_responses

View file

@ -41,6 +41,7 @@ from typing import (
import httpx
import openai
from openai import AsyncOpenAI
from pydantic import BaseModel
from typing_extensions import overload
import litellm
@ -122,6 +123,7 @@ from litellm.types.router import (
ModelInfo,
ProviderBudgetConfigType,
RetryPolicy,
RouterCacheEnum,
RouterErrors,
RouterGeneralSettings,
RouterModelGroupAliasItem,
@ -239,7 +241,6 @@ class Router:
] = "simple-shuffle",
routing_strategy_args: dict = {}, # just for latency-based
provider_budget_config: Optional[ProviderBudgetConfigType] = None,
semaphore: Optional[asyncio.Semaphore] = None,
alerting_config: Optional[AlertingConfig] = None,
router_general_settings: Optional[
RouterGeneralSettings
@ -315,8 +316,6 @@ class Router:
from litellm._service_logger import ServiceLogging
if semaphore:
self.semaphore = semaphore
self.set_verbose = set_verbose
self.debug_level = debug_level
self.enable_pre_call_checks = enable_pre_call_checks
@ -506,6 +505,14 @@ class Router:
litellm.success_callback.append(self.sync_deployment_callback_on_success)
else:
litellm.success_callback = [self.sync_deployment_callback_on_success]
if isinstance(litellm._async_failure_callback, list):
litellm._async_failure_callback.append(
self.async_deployment_callback_on_failure
)
else:
litellm._async_failure_callback = [
self.async_deployment_callback_on_failure
]
## COOLDOWNS ##
if isinstance(litellm.failure_callback, list):
litellm.failure_callback.append(self.deployment_callback_on_failure)
@ -3291,13 +3298,14 @@ class Router:
):
"""
Track remaining tpm/rpm quota for model in model_list
Currently, only updates TPM usage.
"""
try:
if kwargs["litellm_params"].get("metadata") is None:
pass
else:
deployment_name = kwargs["litellm_params"]["metadata"].get(
"deployment", None
) # stable name - works for wildcard routes as well
model_group = kwargs["litellm_params"]["metadata"].get(
"model_group", None
)
@ -3308,6 +3316,8 @@ class Router:
elif isinstance(id, int):
id = str(id)
parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
_usage_obj = completion_response.get("usage")
total_tokens = _usage_obj.get("total_tokens", 0) if _usage_obj else 0
@ -3319,13 +3329,14 @@ class Router:
"%H-%M"
) # use the same timezone regardless of system clock
tpm_key = f"global_router:{id}:tpm:{current_minute}"
tpm_key = RouterCacheEnum.TPM.value.format(
id=id, current_minute=current_minute, model=deployment_name
)
# ------------
# Update usage
# ------------
# update cache
parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
## TPM
await self.cache.async_increment_cache(
key=tpm_key,
@ -3334,6 +3345,17 @@ class Router:
ttl=RoutingArgs.ttl.value,
)
## RPM
rpm_key = RouterCacheEnum.RPM.value.format(
id=id, current_minute=current_minute, model=deployment_name
)
await self.cache.async_increment_cache(
key=rpm_key,
value=1,
parent_otel_span=parent_otel_span,
ttl=RoutingArgs.ttl.value,
)
increment_deployment_successes_for_current_minute(
litellm_router_instance=self,
deployment_id=id,
@ -3446,6 +3468,40 @@ class Router:
except Exception as e:
raise e
async def async_deployment_callback_on_failure(
self, kwargs, completion_response: Optional[Any], start_time, end_time
):
"""
Update RPM usage for a deployment
"""
deployment_name = kwargs["litellm_params"]["metadata"].get(
"deployment", None
) # handles wildcard routes - by giving the original name sent to `litellm.completion`
model_group = kwargs["litellm_params"]["metadata"].get("model_group", None)
model_info = kwargs["litellm_params"].get("model_info", {}) or {}
id = model_info.get("id", None)
if model_group is None or id is None:
return
elif isinstance(id, int):
id = str(id)
parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
dt = get_utc_datetime()
current_minute = dt.strftime(
"%H-%M"
) # use the same timezone regardless of system clock
## RPM
rpm_key = RouterCacheEnum.RPM.value.format(
id=id, current_minute=current_minute, model=deployment_name
)
await self.cache.async_increment_cache(
key=rpm_key,
value=1,
parent_otel_span=parent_otel_span,
ttl=RoutingArgs.ttl.value,
)
def log_retry(self, kwargs: dict, e: Exception) -> dict:
"""
When a retry or fallback happens, log the details of the just failed model call - similar to Sentry breadcrumbing
@ -4123,7 +4179,24 @@ class Router:
raise Exception("Model Name invalid - {}".format(type(model)))
return None
def get_router_model_info(self, deployment: dict) -> ModelMapInfo:
@overload
def get_router_model_info(
self, deployment: dict, received_model_name: str, id: None = None
) -> ModelMapInfo:
pass
@overload
def get_router_model_info(
self, deployment: None, received_model_name: str, id: str
) -> ModelMapInfo:
pass
def get_router_model_info(
self,
deployment: Optional[dict],
received_model_name: str,
id: Optional[str] = None,
) -> ModelMapInfo:
"""
For a given model id, return the model info (max tokens, input cost, output cost, etc.).
@ -4137,6 +4210,14 @@ class Router:
Raises:
- ValueError -> If model is not mapped yet
"""
if id is not None:
_deployment = self.get_deployment(model_id=id)
if _deployment is not None:
deployment = _deployment.model_dump(exclude_none=True)
if deployment is None:
raise ValueError("Deployment not found")
## GET BASE MODEL
base_model = deployment.get("model_info", {}).get("base_model", None)
if base_model is None:
@ -4158,10 +4239,27 @@ class Router:
elif custom_llm_provider != "azure":
model = _model
## GET LITELLM MODEL INFO - raises exception, if model is not mapped
model_info = litellm.get_model_info(
model="{}/{}".format(custom_llm_provider, model)
potential_models = self.pattern_router.route(received_model_name)
if "*" in model and potential_models is not None: # if wildcard route
for potential_model in potential_models:
try:
if potential_model.get("model_info", {}).get(
"id"
) == deployment.get("model_info", {}).get("id"):
model = potential_model.get("litellm_params", {}).get(
"model"
)
break
except Exception:
pass
## GET LITELLM MODEL INFO - raises exception, if model is not mapped
if not model.startswith(custom_llm_provider):
model_info_name = "{}/{}".format(custom_llm_provider, model)
else:
model_info_name = model
model_info = litellm.get_model_info(model=model_info_name)
## CHECK USER SET MODEL INFO
user_model_info = deployment.get("model_info", {})
@ -4211,8 +4309,10 @@ class Router:
total_tpm: Optional[int] = None
total_rpm: Optional[int] = None
configurable_clientside_auth_params: CONFIGURABLE_CLIENTSIDE_AUTH_PARAMS = None
for model in self.model_list:
model_list = self.get_model_list(model_name=model_group)
if model_list is None:
return None
for model in model_list:
is_match = False
if (
"model_name" in model and model["model_name"] == model_group
@ -4227,7 +4327,7 @@ class Router:
if not is_match:
continue
# model in model group found #
litellm_params = LiteLLM_Params(**model["litellm_params"])
litellm_params = LiteLLM_Params(**model["litellm_params"]) # type: ignore
# get configurable clientside auth params
configurable_clientside_auth_params = (
litellm_params.configurable_clientside_auth_params
@ -4235,38 +4335,30 @@ class Router:
# get model tpm
_deployment_tpm: Optional[int] = None
if _deployment_tpm is None:
_deployment_tpm = model.get("tpm", None)
_deployment_tpm = model.get("tpm", None) # type: ignore
if _deployment_tpm is None:
_deployment_tpm = model.get("litellm_params", {}).get("tpm", None)
_deployment_tpm = model.get("litellm_params", {}).get("tpm", None) # type: ignore
if _deployment_tpm is None:
_deployment_tpm = model.get("model_info", {}).get("tpm", None)
_deployment_tpm = model.get("model_info", {}).get("tpm", None) # type: ignore
if _deployment_tpm is not None:
if total_tpm is None:
total_tpm = 0
total_tpm += _deployment_tpm # type: ignore
# get model rpm
_deployment_rpm: Optional[int] = None
if _deployment_rpm is None:
_deployment_rpm = model.get("rpm", None)
_deployment_rpm = model.get("rpm", None) # type: ignore
if _deployment_rpm is None:
_deployment_rpm = model.get("litellm_params", {}).get("rpm", None)
_deployment_rpm = model.get("litellm_params", {}).get("rpm", None) # type: ignore
if _deployment_rpm is None:
_deployment_rpm = model.get("model_info", {}).get("rpm", None)
_deployment_rpm = model.get("model_info", {}).get("rpm", None) # type: ignore
if _deployment_rpm is not None:
if total_rpm is None:
total_rpm = 0
total_rpm += _deployment_rpm # type: ignore
# get model info
try:
model_info = litellm.get_model_info(model=litellm_params.model)
except Exception:
model_info = None
# get llm provider
model, llm_provider = "", ""
litellm_model, llm_provider = "", ""
try:
model, llm_provider, _, _ = litellm.get_llm_provider(
litellm_model, llm_provider, _, _ = litellm.get_llm_provider(
model=litellm_params.model,
custom_llm_provider=litellm_params.custom_llm_provider,
)
@ -4277,7 +4369,7 @@ class Router:
if model_info is None:
supported_openai_params = litellm.get_supported_openai_params(
model=model, custom_llm_provider=llm_provider
model=litellm_model, custom_llm_provider=llm_provider
)
if supported_openai_params is None:
supported_openai_params = []
@ -4367,7 +4459,20 @@ class Router:
model_group_info.supported_openai_params = model_info[
"supported_openai_params"
]
if model_info.get("tpm", None) is not None and _deployment_tpm is None:
_deployment_tpm = model_info.get("tpm")
if model_info.get("rpm", None) is not None and _deployment_rpm is None:
_deployment_rpm = model_info.get("rpm")
if _deployment_tpm is not None:
if total_tpm is None:
total_tpm = 0
total_tpm += _deployment_tpm # type: ignore
if _deployment_rpm is not None:
if total_rpm is None:
total_rpm = 0
total_rpm += _deployment_rpm # type: ignore
if model_group_info is not None:
## UPDATE WITH TOTAL TPM/RPM FOR MODEL GROUP
if total_tpm is not None:
@ -4419,7 +4524,10 @@ class Router:
self, model_group: str
) -> Tuple[Optional[int], Optional[int]]:
"""
Returns remaining tpm/rpm quota for model group
Returns current tpm/rpm usage for model group
Parameters:
- model_group: str - the received model name from the user (can be a wildcard route).
Returns:
- usage: Tuple[tpm, rpm]
@ -4430,20 +4538,37 @@ class Router:
) # use the same timezone regardless of system clock
tpm_keys: List[str] = []
rpm_keys: List[str] = []
for model in self.model_list:
if "model_name" in model and model["model_name"] == model_group:
model_list = self.get_model_list(model_name=model_group)
if model_list is None: # no matching deployments
return None, None
for model in model_list:
id: Optional[str] = model.get("model_info", {}).get("id") # type: ignore
litellm_model: Optional[str] = model["litellm_params"].get(
"model"
) # USE THE MODEL SENT TO litellm.completion() - consistent with how global_router cache is written.
if id is None or litellm_model is None:
continue
tpm_keys.append(
f"global_router:{model['model_info']['id']}:tpm:{current_minute}"
RouterCacheEnum.TPM.value.format(
id=id,
model=litellm_model,
current_minute=current_minute,
)
)
rpm_keys.append(
f"global_router:{model['model_info']['id']}:rpm:{current_minute}"
RouterCacheEnum.RPM.value.format(
id=id,
model=litellm_model,
current_minute=current_minute,
)
)
combined_tpm_rpm_keys = tpm_keys + rpm_keys
combined_tpm_rpm_values = await self.cache.async_batch_get_cache(
keys=combined_tpm_rpm_keys
)
if combined_tpm_rpm_values is None:
return None, None
@ -4468,6 +4593,32 @@ class Router:
rpm_usage += t
return tpm_usage, rpm_usage
async def get_remaining_model_group_usage(self, model_group: str) -> Dict[str, int]:
current_tpm, current_rpm = await self.get_model_group_usage(model_group)
model_group_info = self.get_model_group_info(model_group)
if model_group_info is not None and model_group_info.tpm is not None:
tpm_limit = model_group_info.tpm
else:
tpm_limit = None
if model_group_info is not None and model_group_info.rpm is not None:
rpm_limit = model_group_info.rpm
else:
rpm_limit = None
returned_dict = {}
if tpm_limit is not None and current_tpm is not None:
returned_dict["x-ratelimit-remaining-tokens"] = tpm_limit - current_tpm
returned_dict["x-ratelimit-limit-tokens"] = tpm_limit
if rpm_limit is not None and current_rpm is not None:
returned_dict["x-ratelimit-remaining-requests"] = rpm_limit - current_rpm
returned_dict["x-ratelimit-limit-requests"] = rpm_limit
return returned_dict
async def set_response_headers(
self, response: Any, model_group: Optional[str] = None
) -> Any:
@ -4478,6 +4629,30 @@ class Router:
# - if healthy_deployments > 1, return model group rate limit headers
# - else return the model's rate limit headers
"""
if (
isinstance(response, BaseModel)
and hasattr(response, "_hidden_params")
and isinstance(response._hidden_params, dict) # type: ignore
):
response._hidden_params.setdefault("additional_headers", {}) # type: ignore
response._hidden_params["additional_headers"][ # type: ignore
"x-litellm-model-group"
] = model_group
additional_headers = response._hidden_params["additional_headers"] # type: ignore
if (
"x-ratelimit-remaining-tokens" not in additional_headers
and "x-ratelimit-remaining-requests" not in additional_headers
and model_group is not None
):
remaining_usage = await self.get_remaining_model_group_usage(
model_group
)
for header, value in remaining_usage.items():
if value is not None:
additional_headers[header] = value
return response
def get_model_ids(self, model_name: Optional[str] = None) -> List[str]:
@ -4560,6 +4735,13 @@ class Router:
)
)
if len(returned_models) == 0: # check if wildcard route
potential_wildcard_models = self.pattern_router.route(model_name)
if potential_wildcard_models is not None:
returned_models.extend(
[DeploymentTypedDict(**m) for m in potential_wildcard_models] # type: ignore
)
if model_name is None:
returned_models += self.model_list
@ -4810,10 +4992,12 @@ class Router:
base_model = deployment.get("litellm_params", {}).get(
"base_model", None
)
model_info = self.get_router_model_info(
deployment=deployment, received_model_name=model
)
model = base_model or deployment.get("litellm_params", {}).get(
"model", None
)
model_info = self.get_router_model_info(deployment=deployment)
if (
isinstance(model_info, dict)

View file

View file

@ -9,7 +9,7 @@ from typing import Any, Dict, List, Literal, Optional, Tuple, Union
import httpx
from pydantic import BaseModel, ConfigDict, Field
from typing_extensions import TypedDict
from typing_extensions import Required, TypedDict
from ..exceptions import RateLimitError
from .completion import CompletionRequest
@ -352,9 +352,10 @@ class LiteLLMParamsTypedDict(TypedDict, total=False):
tags: Optional[List[str]]
class DeploymentTypedDict(TypedDict):
model_name: str
litellm_params: LiteLLMParamsTypedDict
class DeploymentTypedDict(TypedDict, total=False):
model_name: Required[str]
litellm_params: Required[LiteLLMParamsTypedDict]
model_info: Optional[dict]
SPECIAL_MODEL_INFO_PARAMS = [
@ -640,3 +641,8 @@ class ProviderBudgetInfo(BaseModel):
ProviderBudgetConfigType = Dict[str, ProviderBudgetInfo]
class RouterCacheEnum(enum.Enum):
TPM = "global_router:{id}:{model}:tpm:{current_minute}"
RPM = "global_router:{id}:{model}:rpm:{current_minute}"

View file

@ -106,6 +106,8 @@ class ModelInfo(TypedDict, total=False):
supports_prompt_caching: Optional[bool]
supports_audio_input: Optional[bool]
supports_audio_output: Optional[bool]
tpm: Optional[int]
rpm: Optional[int]
class GenericStreamingChunk(TypedDict, total=False):

View file

@ -4656,6 +4656,8 @@ def get_model_info( # noqa: PLR0915
),
supports_audio_input=_model_info.get("supports_audio_input", False),
supports_audio_output=_model_info.get("supports_audio_output", False),
tpm=_model_info.get("tpm", None),
rpm=_model_info.get("rpm", None),
)
except Exception as e:
if "OllamaError" in str(e):

View file

@ -3383,6 +3383,8 @@
"supports_vision": true,
"supports_response_schema": true,
"supports_prompt_caching": true,
"tpm": 4000000,
"rpm": 2000,
"source": "https://ai.google.dev/pricing"
},
"gemini/gemini-1.5-flash-001": {
@ -3406,6 +3408,8 @@
"supports_vision": true,
"supports_response_schema": true,
"supports_prompt_caching": true,
"tpm": 4000000,
"rpm": 2000,
"source": "https://ai.google.dev/pricing"
},
"gemini/gemini-1.5-flash": {
@ -3428,6 +3432,8 @@
"supports_function_calling": true,
"supports_vision": true,
"supports_response_schema": true,
"tpm": 4000000,
"rpm": 2000,
"source": "https://ai.google.dev/pricing"
},
"gemini/gemini-1.5-flash-latest": {
@ -3450,6 +3456,32 @@
"supports_function_calling": true,
"supports_vision": true,
"supports_response_schema": true,
"tpm": 4000000,
"rpm": 2000,
"source": "https://ai.google.dev/pricing"
},
"gemini/gemini-1.5-flash-8b": {
"max_tokens": 8192,
"max_input_tokens": 1048576,
"max_output_tokens": 8192,
"max_images_per_prompt": 3000,
"max_videos_per_prompt": 10,
"max_video_length": 1,
"max_audio_length_hours": 8.4,
"max_audio_per_prompt": 1,
"max_pdf_size_mb": 30,
"input_cost_per_token": 0,
"input_cost_per_token_above_128k_tokens": 0,
"output_cost_per_token": 0,
"output_cost_per_token_above_128k_tokens": 0,
"litellm_provider": "gemini",
"mode": "chat",
"supports_system_messages": true,
"supports_function_calling": true,
"supports_vision": true,
"supports_response_schema": true,
"tpm": 4000000,
"rpm": 4000,
"source": "https://ai.google.dev/pricing"
},
"gemini/gemini-1.5-flash-8b-exp-0924": {
@ -3472,6 +3504,8 @@
"supports_function_calling": true,
"supports_vision": true,
"supports_response_schema": true,
"tpm": 4000000,
"rpm": 4000,
"source": "https://ai.google.dev/pricing"
},
"gemini/gemini-exp-1114": {
@ -3494,7 +3528,12 @@
"supports_function_calling": true,
"supports_vision": true,
"supports_response_schema": true,
"source": "https://ai.google.dev/pricing"
"tpm": 4000000,
"rpm": 1000,
"source": "https://ai.google.dev/pricing",
"metadata": {
"notes": "Rate limits not documented for gemini-exp-1114. Assuming same as gemini-1.5-pro."
}
},
"gemini/gemini-1.5-flash-exp-0827": {
"max_tokens": 8192,
@ -3516,6 +3555,8 @@
"supports_function_calling": true,
"supports_vision": true,
"supports_response_schema": true,
"tpm": 4000000,
"rpm": 2000,
"source": "https://ai.google.dev/pricing"
},
"gemini/gemini-1.5-flash-8b-exp-0827": {
@ -3537,6 +3578,9 @@
"supports_system_messages": true,
"supports_function_calling": true,
"supports_vision": true,
"supports_response_schema": true,
"tpm": 4000000,
"rpm": 4000,
"source": "https://ai.google.dev/pricing"
},
"gemini/gemini-pro": {
@ -3550,7 +3594,10 @@
"litellm_provider": "gemini",
"mode": "chat",
"supports_function_calling": true,
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
"rpd": 30000,
"tpm": 120000,
"rpm": 360,
"source": "https://ai.google.dev/gemini-api/docs/models/gemini"
},
"gemini/gemini-1.5-pro": {
"max_tokens": 8192,
@ -3567,6 +3614,8 @@
"supports_vision": true,
"supports_tool_choice": true,
"supports_response_schema": true,
"tpm": 4000000,
"rpm": 1000,
"source": "https://ai.google.dev/pricing"
},
"gemini/gemini-1.5-pro-002": {
@ -3585,6 +3634,8 @@
"supports_tool_choice": true,
"supports_response_schema": true,
"supports_prompt_caching": true,
"tpm": 4000000,
"rpm": 1000,
"source": "https://ai.google.dev/pricing"
},
"gemini/gemini-1.5-pro-001": {
@ -3603,6 +3654,8 @@
"supports_tool_choice": true,
"supports_response_schema": true,
"supports_prompt_caching": true,
"tpm": 4000000,
"rpm": 1000,
"source": "https://ai.google.dev/pricing"
},
"gemini/gemini-1.5-pro-exp-0801": {
@ -3620,6 +3673,8 @@
"supports_vision": true,
"supports_tool_choice": true,
"supports_response_schema": true,
"tpm": 4000000,
"rpm": 1000,
"source": "https://ai.google.dev/pricing"
},
"gemini/gemini-1.5-pro-exp-0827": {
@ -3637,6 +3692,8 @@
"supports_vision": true,
"supports_tool_choice": true,
"supports_response_schema": true,
"tpm": 4000000,
"rpm": 1000,
"source": "https://ai.google.dev/pricing"
},
"gemini/gemini-1.5-pro-latest": {
@ -3654,6 +3711,8 @@
"supports_vision": true,
"supports_tool_choice": true,
"supports_response_schema": true,
"tpm": 4000000,
"rpm": 1000,
"source": "https://ai.google.dev/pricing"
},
"gemini/gemini-pro-vision": {
@ -3668,6 +3727,9 @@
"mode": "chat",
"supports_function_calling": true,
"supports_vision": true,
"rpd": 30000,
"tpm": 120000,
"rpm": 360,
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
},
"gemini/gemini-gemma-2-27b-it": {

View file

@ -46,17 +46,22 @@ print(env_keys)
repo_base = "./"
print(os.listdir(repo_base))
docs_path = (
"../../docs/my-website/docs/proxy/config_settings.md" # Path to the documentation
"./docs/my-website/docs/proxy/config_settings.md" # Path to the documentation
)
documented_keys = set()
try:
with open(docs_path, "r", encoding="utf-8") as docs_file:
content = docs_file.read()
print(f"content: {content}")
# Find the section titled "general_settings - Reference"
general_settings_section = re.search(
r"### environment variables - Reference(.*?)###", content, re.DOTALL
r"### environment variables - Reference(.*?)(?=\n###|\Z)",
content,
re.DOTALL | re.MULTILINE,
)
print(f"general_settings_section: {general_settings_section}")
if general_settings_section:
# Extract the table rows, which contain the documented keys
table_content = general_settings_section.group(1)
@ -70,6 +75,7 @@ except Exception as e:
)
print(f"documented_keys: {documented_keys}")
# Compare and find undocumented keys
undocumented_keys = env_keys - documented_keys

View file

@ -0,0 +1,87 @@
import os
import re
import inspect
from typing import Type
import sys
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import litellm
def get_init_params(cls: Type) -> list[str]:
"""
Retrieve all parameters supported by the `__init__` method of a given class.
Args:
cls: The class to inspect.
Returns:
A list of parameter names.
"""
if not hasattr(cls, "__init__"):
raise ValueError(
f"The provided class {cls.__name__} does not have an __init__ method."
)
init_method = cls.__init__
argspec = inspect.getfullargspec(init_method)
# The first argument is usually 'self', so we exclude it
return argspec.args[1:] # Exclude 'self'
router_init_params = set(get_init_params(litellm.router.Router))
print(router_init_params)
router_init_params.remove("model_list")
# Parse the documentation to extract documented keys
repo_base = "./"
print(os.listdir(repo_base))
docs_path = (
"./docs/my-website/docs/proxy/config_settings.md" # Path to the documentation
)
# docs_path = (
# "../../docs/my-website/docs/proxy/config_settings.md" # Path to the documentation
# )
documented_keys = set()
try:
with open(docs_path, "r", encoding="utf-8") as docs_file:
content = docs_file.read()
# Find the section titled "general_settings - Reference"
general_settings_section = re.search(
r"### router_settings - Reference(.*?)###", content, re.DOTALL
)
if general_settings_section:
# Extract the table rows, which contain the documented keys
table_content = general_settings_section.group(1)
doc_key_pattern = re.compile(
r"\|\s*([^\|]+?)\s*\|"
) # Capture the key from each row of the table
documented_keys.update(doc_key_pattern.findall(table_content))
except Exception as e:
raise Exception(
f"Error reading documentation: {e}, \n repo base - {os.listdir(repo_base)}"
)
# Compare and find undocumented keys
undocumented_keys = router_init_params - documented_keys
# Print results
print("Keys expected in 'router settings' (found in code):")
for key in sorted(router_init_params):
print(key)
if undocumented_keys:
raise Exception(
f"\nKeys not documented in 'router settings - Reference': {undocumented_keys}"
)
else:
print(
"\nAll keys are documented in 'router settings - Reference'. - {}".format(
router_init_params
)
)

View file

@ -62,7 +62,14 @@ class BaseLLMChatTest(ABC):
response = litellm.completion(**base_completion_call_args, messages=messages)
assert response is not None
def test_json_response_format(self):
@pytest.mark.parametrize(
"response_format",
[
{"type": "json_object"},
{"type": "text"},
],
)
def test_json_response_format(self, response_format):
"""
Test that the JSON response format is supported by the LLM API
"""
@ -83,7 +90,7 @@ class BaseLLMChatTest(ABC):
response = litellm.completion(
**base_completion_call_args,
messages=messages,
response_format={"type": "json_object"},
response_format=response_format,
)
print(response)

View file

@ -102,3 +102,17 @@ def test_get_model_info_ollama_chat():
print(mock_client.call_args.kwargs)
assert mock_client.call_args.kwargs["json"]["name"] == "mistral"
def test_get_model_info_gemini():
"""
Tests if ALL gemini models have 'tpm' and 'rpm' in the model info
"""
os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
litellm.model_cost = litellm.get_model_cost_map(url="")
model_map = litellm.model_cost
for model, info in model_map.items():
if model.startswith("gemini/") and not "gemma" in model:
assert info.get("tpm") is not None, f"{model} does not have tpm"
assert info.get("rpm") is not None, f"{model} does not have rpm"

View file

@ -2115,10 +2115,14 @@ def test_router_get_model_info(model, base_model, llm_provider):
assert deployment is not None
if llm_provider == "openai" or (base_model is not None and llm_provider == "azure"):
router.get_router_model_info(deployment=deployment.to_json())
router.get_router_model_info(
deployment=deployment.to_json(), received_model_name=model
)
else:
try:
router.get_router_model_info(deployment=deployment.to_json())
router.get_router_model_info(
deployment=deployment.to_json(), received_model_name=model
)
pytest.fail("Expected this to raise model not mapped error")
except Exception as e:
if "This model isn't mapped yet" in str(e):

View file

@ -174,3 +174,185 @@ async def test_update_kwargs_before_fallbacks(call_type):
print(mock_client.call_args.kwargs)
assert mock_client.call_args.kwargs["litellm_trace_id"] is not None
def test_router_get_model_info_wildcard_routes():
os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
litellm.model_cost = litellm.get_model_cost_map(url="")
router = Router(
model_list=[
{
"model_name": "gemini/*",
"litellm_params": {"model": "gemini/*"},
"model_info": {"id": 1},
},
]
)
model_info = router.get_router_model_info(
deployment=None, received_model_name="gemini/gemini-1.5-flash", id="1"
)
print(model_info)
assert model_info is not None
assert model_info["tpm"] is not None
assert model_info["rpm"] is not None
@pytest.mark.asyncio
async def test_router_get_model_group_usage_wildcard_routes():
os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
litellm.model_cost = litellm.get_model_cost_map(url="")
router = Router(
model_list=[
{
"model_name": "gemini/*",
"litellm_params": {"model": "gemini/*"},
"model_info": {"id": 1},
},
]
)
resp = await router.acompletion(
model="gemini/gemini-1.5-flash",
messages=[{"role": "user", "content": "Hello, how are you?"}],
mock_response="Hello, I'm good.",
)
print(resp)
await asyncio.sleep(1)
tpm, rpm = await router.get_model_group_usage(model_group="gemini/gemini-1.5-flash")
assert tpm is not None, "tpm is None"
assert rpm is not None, "rpm is None"
@pytest.mark.asyncio
async def test_call_router_callbacks_on_success():
router = Router(
model_list=[
{
"model_name": "gemini/*",
"litellm_params": {"model": "gemini/*"},
"model_info": {"id": 1},
},
]
)
with patch.object(
router.cache, "async_increment_cache", new=AsyncMock()
) as mock_callback:
await router.acompletion(
model="gemini/gemini-1.5-flash",
messages=[{"role": "user", "content": "Hello, how are you?"}],
mock_response="Hello, I'm good.",
)
await asyncio.sleep(1)
assert mock_callback.call_count == 2
assert (
mock_callback.call_args_list[0]
.kwargs["key"]
.startswith("global_router:1:gemini/gemini-1.5-flash:tpm")
)
assert (
mock_callback.call_args_list[1]
.kwargs["key"]
.startswith("global_router:1:gemini/gemini-1.5-flash:rpm")
)
@pytest.mark.asyncio
async def test_call_router_callbacks_on_failure():
router = Router(
model_list=[
{
"model_name": "gemini/*",
"litellm_params": {"model": "gemini/*"},
"model_info": {"id": 1},
},
]
)
with patch.object(
router.cache, "async_increment_cache", new=AsyncMock()
) as mock_callback:
with pytest.raises(litellm.RateLimitError):
await router.acompletion(
model="gemini/gemini-1.5-flash",
messages=[{"role": "user", "content": "Hello, how are you?"}],
mock_response="litellm.RateLimitError",
num_retries=0,
)
await asyncio.sleep(1)
print(mock_callback.call_args_list)
assert mock_callback.call_count == 1
assert (
mock_callback.call_args_list[0]
.kwargs["key"]
.startswith("global_router:1:gemini/gemini-1.5-flash:rpm")
)
@pytest.mark.asyncio
async def test_router_model_group_headers():
os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
litellm.model_cost = litellm.get_model_cost_map(url="")
from litellm.types.utils import OPENAI_RESPONSE_HEADERS
router = Router(
model_list=[
{
"model_name": "gemini/*",
"litellm_params": {"model": "gemini/*"},
"model_info": {"id": 1},
}
]
)
for _ in range(2):
resp = await router.acompletion(
model="gemini/gemini-1.5-flash",
messages=[{"role": "user", "content": "Hello, how are you?"}],
mock_response="Hello, I'm good.",
)
await asyncio.sleep(1)
assert (
resp._hidden_params["additional_headers"]["x-litellm-model-group"]
== "gemini/gemini-1.5-flash"
)
assert "x-ratelimit-remaining-requests" in resp._hidden_params["additional_headers"]
assert "x-ratelimit-remaining-tokens" in resp._hidden_params["additional_headers"]
@pytest.mark.asyncio
async def test_get_remaining_model_group_usage():
os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
litellm.model_cost = litellm.get_model_cost_map(url="")
from litellm.types.utils import OPENAI_RESPONSE_HEADERS
router = Router(
model_list=[
{
"model_name": "gemini/*",
"litellm_params": {"model": "gemini/*"},
"model_info": {"id": 1},
}
]
)
for _ in range(2):
await router.acompletion(
model="gemini/gemini-1.5-flash",
messages=[{"role": "user", "content": "Hello, how are you?"}],
mock_response="Hello, I'm good.",
)
await asyncio.sleep(1)
remaining_usage = await router.get_remaining_model_group_usage(
model_group="gemini/gemini-1.5-flash"
)
assert remaining_usage is not None
assert "x-ratelimit-remaining-requests" in remaining_usage
assert "x-ratelimit-remaining-tokens" in remaining_usage

View file

@ -506,7 +506,7 @@ async def test_router_caching_ttl():
) as mock_client:
await router.acompletion(model=model, messages=messages)
mock_client.assert_called_once()
# mock_client.assert_called_once()
print(f"mock_client.call_args.kwargs: {mock_client.call_args.kwargs}")
print(f"mock_client.call_args.args: {mock_client.call_args.args}")

View file

@ -396,7 +396,8 @@ async def test_deployment_callback_on_success(model_list, sync_mode):
assert tpm_key is not None
def test_deployment_callback_on_failure(model_list):
@pytest.mark.asyncio
async def test_deployment_callback_on_failure(model_list):
"""Test if the '_deployment_callback_on_failure' function is working correctly"""
import time
@ -418,6 +419,18 @@ def test_deployment_callback_on_failure(model_list):
assert isinstance(result, bool)
assert result is False
model_response = router.completion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hello, how are you?"}],
mock_response="I'm fine, thank you!",
)
result = await router.async_deployment_callback_on_failure(
kwargs=kwargs,
completion_response=model_response,
start_time=time.time(),
end_time=time.time(),
)
def test_log_retry(model_list):
"""Test if the '_log_retry' function is working correctly"""