mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 18:54:30 +00:00
Merge pull request #4786 from BerriAI/litellm_use_model_tier_keys
[Feat-Enterprise] Use free/paid tiers for Virtual Keys
This commit is contained in:
commit
4b96cd46b2
9 changed files with 359 additions and 23 deletions
102
docs/my-website/docs/proxy/free_paid_tier.md
Normal file
102
docs/my-website/docs/proxy/free_paid_tier.md
Normal file
|
@ -0,0 +1,102 @@
|
|||
# 💸 Free, Paid Tier Routing
|
||||
|
||||
Route Virtual Keys on `free tier` to cheaper models
|
||||
|
||||
### 1. Define free, paid tier models on config.yaml
|
||||
|
||||
:::info
|
||||
Requests with `model=gpt-4` will be routed to either `openai/fake` or `openai/gpt-4o` depending on which tier the virtual key is on
|
||||
:::
|
||||
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: gpt-4
|
||||
litellm_params:
|
||||
model: openai/fake
|
||||
api_key: fake-key
|
||||
api_base: https://exampleopenaiendpoint-production.up.railway.app/
|
||||
model_info:
|
||||
tier: free # 👈 Key Change - set `tier to paid or free`
|
||||
- model_name: gpt-4
|
||||
litellm_params:
|
||||
model: openai/gpt-4o
|
||||
api_key: os.environ/OPENAI_API_KEY
|
||||
model_info:
|
||||
tier: paid # 👈 Key Change - set `tier to paid or free`
|
||||
|
||||
general_settings:
|
||||
master_key: sk-1234
|
||||
```
|
||||
|
||||
### 2. Create Virtual Keys with pricing `tier=free`
|
||||
|
||||
```shell
|
||||
curl --location 'http://0.0.0.0:4000/key/generate' \
|
||||
--header 'Authorization: Bearer sk-1234' \
|
||||
--header 'Content-Type: application/json' \
|
||||
--data '{
|
||||
"metadata": {"tier": "free"}
|
||||
}'
|
||||
```
|
||||
|
||||
### 3. Make Request with Key on `Free Tier`
|
||||
|
||||
```shell
|
||||
curl -i http://localhost:4000/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
curl -i http://localhost:4000/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "Authorization: Bearer sk-inxzoSurQsjog9gPrVOCcA" \
|
||||
-d '{
|
||||
"model": "gpt-4",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello, Claude gm!"}
|
||||
]
|
||||
}'
|
||||
```
|
||||
|
||||
**Expected Response**
|
||||
|
||||
If this worked as expected then `x-litellm-model-api-base` should be `https://exampleopenaiendpoint-production.up.railway.app/` in the response headers
|
||||
|
||||
```shell
|
||||
x-litellm-model-api-base: https://exampleopenaiendpoint-production.up.railway.app/
|
||||
|
||||
{"id":"chatcmpl-657b750f581240c1908679ed94b31bfe","choices":[{"finish_reason":"stop","index":0,"message":{"content":"\n\nHello there, how may I assist you today?","role":"assistant","tool_calls":null,"function_call":null}}],"created":1677652288,"model":"gpt-3.5-turbo-0125","object":"chat.completion","system_fingerprint":"fp_44709d6fcb","usage":{"completion_tokens":12,"prompt_tokens":9,"total_tokens":21}}%
|
||||
```
|
||||
|
||||
|
||||
### 4. Create Virtual Keys with pricing `tier=paid`
|
||||
|
||||
```shell
|
||||
curl --location 'http://0.0.0.0:4000/key/generate' \
|
||||
--header 'Authorization: Bearer sk-1234' \
|
||||
--header 'Content-Type: application/json' \
|
||||
--data '{
|
||||
"metadata": {"tier": "paid"}
|
||||
}'
|
||||
```
|
||||
|
||||
### 5. Make Request with Key on `Paid Tier`
|
||||
|
||||
```shell
|
||||
curl -i http://localhost:4000/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "Authorization: Bearer sk-mnJoeSc6jFjzZr256q-iqA" \
|
||||
-d '{
|
||||
"model": "gpt-4",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello, Claude gm!"}
|
||||
]
|
||||
}'
|
||||
```
|
||||
|
||||
**Expected Response**
|
||||
|
||||
If this worked as expected then `x-litellm-model-api-base` should be `https://api.openai.com` in the response headers
|
||||
|
||||
```shell
|
||||
x-litellm-model-api-base: https://api.openai.com
|
||||
|
||||
{"id":"chatcmpl-9mW75EbJCgwmLcO0M5DmwxpiBgWdc","choices":[{"finish_reason":"stop","index":0,"message":{"content":"Good morning! How can I assist you today?","role":"assistant","tool_calls":null,"function_call":null}}],"created":1721350215,"model":"gpt-4o-2024-05-13","object":"chat.completion","system_fingerprint":"fp_c4e5b6fa31","usage":{"completion_tokens":10,"prompt_tokens":12,"total_tokens":22}}
|
||||
```
|
|
@ -43,11 +43,12 @@ const sidebars = {
|
|||
"proxy/reliability",
|
||||
"proxy/cost_tracking",
|
||||
"proxy/self_serve",
|
||||
"proxy/virtual_keys",
|
||||
"proxy/free_paid_tier",
|
||||
"proxy/users",
|
||||
"proxy/team_budgets",
|
||||
"proxy/customers",
|
||||
"proxy/billing",
|
||||
"proxy/virtual_keys",
|
||||
"proxy/guardrails",
|
||||
"proxy/token_auth",
|
||||
"proxy/alerting",
|
||||
|
|
|
@ -4,7 +4,7 @@ from typing import TYPE_CHECKING, Any, Dict, Optional
|
|||
from fastapi import Request
|
||||
|
||||
from litellm._logging import verbose_logger, verbose_proxy_logger
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
from litellm.proxy._types import CommonProxyErrors, UserAPIKeyAuth
|
||||
from litellm.types.utils import SupportedCacheControls
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
@ -95,15 +95,6 @@ async def add_litellm_data_to_request(
|
|||
cache_dict = parse_cache_control(cache_control_header)
|
||||
data["ttl"] = cache_dict.get("s-maxage")
|
||||
|
||||
### KEY-LEVEL CACHNG
|
||||
key_metadata = user_api_key_dict.metadata
|
||||
if "cache" in key_metadata:
|
||||
data["cache"] = {}
|
||||
if isinstance(key_metadata["cache"], dict):
|
||||
for k, v in key_metadata["cache"].items():
|
||||
if k in SupportedCacheControls:
|
||||
data["cache"][k] = v
|
||||
|
||||
verbose_proxy_logger.debug("receiving data: %s", data)
|
||||
|
||||
_metadata_variable_name = _get_metadata_variable_name(request)
|
||||
|
@ -133,6 +124,24 @@ async def add_litellm_data_to_request(
|
|||
user_api_key_dict, "team_alias", None
|
||||
)
|
||||
|
||||
### KEY-LEVEL Contorls
|
||||
key_metadata = user_api_key_dict.metadata
|
||||
if "cache" in key_metadata:
|
||||
data["cache"] = {}
|
||||
if isinstance(key_metadata["cache"], dict):
|
||||
for k, v in key_metadata["cache"].items():
|
||||
if k in SupportedCacheControls:
|
||||
data["cache"][k] = v
|
||||
if "tier" in key_metadata:
|
||||
if premium_user is not True:
|
||||
verbose_logger.warning(
|
||||
"Trying to use free/paid tier feature. This will not be applied %s",
|
||||
CommonProxyErrors.not_premium_user.value,
|
||||
)
|
||||
|
||||
# add request tier to metadata
|
||||
data[_metadata_variable_name]["tier"] = key_metadata["tier"]
|
||||
|
||||
# Team spend, budget - used by prometheus.py
|
||||
data[_metadata_variable_name][
|
||||
"user_api_key_team_max_budget"
|
||||
|
|
|
@ -1,23 +1,19 @@
|
|||
model_list:
|
||||
- model_name: fake-openai-endpoint
|
||||
- model_name: gpt-4
|
||||
litellm_params:
|
||||
model: openai/fake
|
||||
api_key: fake-key
|
||||
api_base: https://exampleopenaiendpoint-production.up.railway.app/
|
||||
- model_name: gemini-flash
|
||||
litellm_params:
|
||||
model: gemini/gemini-1.5-flash
|
||||
- model_name: whisper
|
||||
litellm_params:
|
||||
model: whisper-1
|
||||
api_key: sk-*******
|
||||
max_file_size_mb: 1000
|
||||
model_info:
|
||||
mode: audio_transcription
|
||||
tier: free # 👈 Key Change - set `tier`
|
||||
- model_name: gpt-4
|
||||
litellm_params:
|
||||
model: openai/gpt-4o
|
||||
api_key: os.environ/OPENAI_API_KEY
|
||||
model_info:
|
||||
tier: paid # 👈 Key Change - set `tier`
|
||||
|
||||
general_settings:
|
||||
master_key: sk-1234
|
||||
|
||||
litellm_settings:
|
||||
success_callback: ["langsmith"]
|
||||
|
||||
|
|
|
@ -47,6 +47,7 @@ from litellm.assistants.main import AssistantDeleted
|
|||
from litellm.caching import DualCache, InMemoryCache, RedisCache
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm.llms.azure import get_azure_ad_token_from_oidc
|
||||
from litellm.router_strategy.free_paid_tiers import get_deployments_for_tier
|
||||
from litellm.router_strategy.least_busy import LeastBusyLoggingHandler
|
||||
from litellm.router_strategy.lowest_cost import LowestCostLoggingHandler
|
||||
from litellm.router_strategy.lowest_latency import LowestLatencyLoggingHandler
|
||||
|
@ -4481,6 +4482,12 @@ class Router:
|
|||
request_kwargs=request_kwargs,
|
||||
)
|
||||
|
||||
# check free / paid tier for each deployment
|
||||
healthy_deployments = await get_deployments_for_tier(
|
||||
request_kwargs=request_kwargs,
|
||||
healthy_deployments=healthy_deployments,
|
||||
)
|
||||
|
||||
if len(healthy_deployments) == 0:
|
||||
if _allowed_model_region is None:
|
||||
_allowed_model_region = "n/a"
|
||||
|
|
69
litellm/router_strategy/free_paid_tiers.py
Normal file
69
litellm/router_strategy/free_paid_tiers.py
Normal file
|
@ -0,0 +1,69 @@
|
|||
"""
|
||||
Use this to route requests between free and paid tiers
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Literal, Optional, TypedDict, Union, cast
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.types.router import DeploymentTypedDict
|
||||
|
||||
|
||||
class ModelInfo(TypedDict):
|
||||
tier: Literal["free", "paid"]
|
||||
|
||||
|
||||
class Deployment(TypedDict):
|
||||
model_info: ModelInfo
|
||||
|
||||
|
||||
async def get_deployments_for_tier(
|
||||
request_kwargs: Optional[Dict[Any, Any]] = None,
|
||||
healthy_deployments: Optional[Union[List[Any], Dict[Any, Any]]] = None,
|
||||
):
|
||||
"""
|
||||
if request_kwargs contains {"metadata": {"tier": "free"}} or {"metadata": {"tier": "paid"}}, then routes the request to free/paid tier models
|
||||
"""
|
||||
if request_kwargs is None:
|
||||
verbose_logger.debug(
|
||||
"get_deployments_for_tier: request_kwargs is None returning healthy_deployments: %s",
|
||||
healthy_deployments,
|
||||
)
|
||||
return healthy_deployments
|
||||
|
||||
verbose_logger.debug("request metadata: %s", request_kwargs.get("metadata"))
|
||||
if "metadata" in request_kwargs:
|
||||
metadata = request_kwargs["metadata"]
|
||||
if "tier" in metadata:
|
||||
selected_tier: Literal["free", "paid"] = metadata["tier"]
|
||||
if healthy_deployments is None:
|
||||
return None
|
||||
|
||||
if selected_tier == "free":
|
||||
# get all deployments where model_info has tier = free
|
||||
free_deployments: List[Any] = []
|
||||
verbose_logger.debug(
|
||||
"Getting deployments in free tier, all_deployments: %s",
|
||||
healthy_deployments,
|
||||
)
|
||||
for deployment in healthy_deployments:
|
||||
typed_deployment = cast(Deployment, deployment)
|
||||
if typed_deployment["model_info"]["tier"] == "free":
|
||||
free_deployments.append(deployment)
|
||||
verbose_logger.debug("free_deployments: %s", free_deployments)
|
||||
return free_deployments
|
||||
|
||||
elif selected_tier == "paid":
|
||||
# get all deployments where model_info has tier = paid
|
||||
paid_deployments: List[Any] = []
|
||||
for deployment in healthy_deployments:
|
||||
typed_deployment = cast(Deployment, deployment)
|
||||
if typed_deployment["model_info"]["tier"] == "paid":
|
||||
paid_deployments.append(deployment)
|
||||
verbose_logger.debug("paid_deployments: %s", paid_deployments)
|
||||
return paid_deployments
|
||||
|
||||
verbose_logger.debug(
|
||||
"no tier found in metadata, returning healthy_deployments: %s",
|
||||
healthy_deployments,
|
||||
)
|
||||
return healthy_deployments
|
60
litellm/tests/test_litellm_pre_call_utils.py
Normal file
60
litellm/tests/test_litellm_pre_call_utils.py
Normal file
|
@ -0,0 +1,60 @@
|
|||
"""
|
||||
Tests litellm pre_call_utils
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from fastapi import Request
|
||||
from fastapi.routing import APIRoute
|
||||
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request
|
||||
from litellm.proxy.proxy_server import ProxyConfig, chat_completion
|
||||
|
||||
load_dotenv()
|
||||
import io
|
||||
import os
|
||||
import time
|
||||
|
||||
import pytest
|
||||
|
||||
# this file is to test litellm/proxy
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system path
|
||||
|
||||
|
||||
@pytest.mark.parametrize("tier", ["free", "paid"])
|
||||
@pytest.mark.asyncio()
|
||||
async def test_adding_key_tier_to_request_metadata(tier):
|
||||
"""
|
||||
Tests if we can add tier: free/paid from key metadata to the request metadata
|
||||
"""
|
||||
data = {}
|
||||
|
||||
api_route = APIRoute(path="/chat/completions", endpoint=chat_completion)
|
||||
request = Request(
|
||||
{
|
||||
"type": "http",
|
||||
"method": "POST",
|
||||
"route": api_route,
|
||||
"path": api_route.path,
|
||||
"headers": [],
|
||||
}
|
||||
)
|
||||
new_data = await add_litellm_data_to_request(
|
||||
data=data,
|
||||
request=request,
|
||||
user_api_key_dict=UserAPIKeyAuth(metadata={"tier": tier}),
|
||||
proxy_config=ProxyConfig(),
|
||||
)
|
||||
|
||||
print("new_data", new_data)
|
||||
|
||||
assert new_data["metadata"]["tier"] == tier
|
90
litellm/tests/test_router_tiers.py
Normal file
90
litellm/tests/test_router_tiers.py
Normal file
|
@ -0,0 +1,90 @@
|
|||
#### What this tests ####
|
||||
# This tests litellm router
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
|
||||
import openai
|
||||
import pytest
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system path
|
||||
import logging
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import httpx
|
||||
from dotenv import load_dotenv
|
||||
|
||||
import litellm
|
||||
from litellm import Router
|
||||
from litellm._logging import verbose_logger
|
||||
|
||||
verbose_logger.setLevel(logging.DEBUG)
|
||||
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
@pytest.mark.asyncio()
|
||||
async def test_router_free_paid_tier():
|
||||
"""
|
||||
Pass list of orgs in 1 model definition,
|
||||
expect a unique deployment for each to be created
|
||||
"""
|
||||
router = litellm.Router(
|
||||
model_list=[
|
||||
{
|
||||
"model_name": "gpt-4",
|
||||
"litellm_params": {
|
||||
"model": "gpt-4o",
|
||||
"api_base": "https://exampleopenaiendpoint-production.up.railway.app/",
|
||||
},
|
||||
"model_info": {"tier": "paid", "id": "very-expensive-model"},
|
||||
},
|
||||
{
|
||||
"model_name": "gpt-4",
|
||||
"litellm_params": {
|
||||
"model": "gpt-4o-mini",
|
||||
"api_base": "https://exampleopenaiendpoint-production.up.railway.app/",
|
||||
},
|
||||
"model_info": {"tier": "free", "id": "very-cheap-model"},
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
for _ in range(5):
|
||||
# this should pick model with id == very-cheap-model
|
||||
response = await router.acompletion(
|
||||
model="gpt-4",
|
||||
messages=[{"role": "user", "content": "Tell me a joke."}],
|
||||
metadata={"tier": "free"},
|
||||
)
|
||||
|
||||
print("Response: ", response)
|
||||
|
||||
response_extra_info = response._hidden_params
|
||||
print("response_extra_info: ", response_extra_info)
|
||||
|
||||
assert response_extra_info["model_id"] == "very-cheap-model"
|
||||
|
||||
for _ in range(5):
|
||||
# this should pick model with id == very-cheap-model
|
||||
response = await router.acompletion(
|
||||
model="gpt-4",
|
||||
messages=[{"role": "user", "content": "Tell me a joke."}],
|
||||
metadata={"tier": "paid"},
|
||||
)
|
||||
|
||||
print("Response: ", response)
|
||||
|
||||
response_extra_info = response._hidden_params
|
||||
print("response_extra_info: ", response_extra_info)
|
||||
|
||||
assert response_extra_info["model_id"] == "very-expensive-model"
|
|
@ -91,6 +91,7 @@ class ModelInfo(BaseModel):
|
|||
base_model: Optional[str] = (
|
||||
None # specify if the base model is azure/gpt-3.5-turbo etc for accurate cost tracking
|
||||
)
|
||||
tier: Optional[Literal["free", "paid"]] = None
|
||||
|
||||
def __init__(self, id: Optional[Union[str, int]] = None, **params):
|
||||
if id is None:
|
||||
|
@ -328,6 +329,7 @@ class LiteLLMParamsTypedDict(TypedDict, total=False):
|
|||
class DeploymentTypedDict(TypedDict):
|
||||
model_name: str
|
||||
litellm_params: LiteLLMParamsTypedDict
|
||||
model_info: ModelInfo
|
||||
|
||||
|
||||
SPECIAL_MODEL_INFO_PARAMS = [
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue