Merge pull request #4786 from BerriAI/litellm_use_model_tier_keys

[Feat-Enterprise] Use free/paid tiers for Virtual Keys
This commit is contained in:
Ishaan Jaff 2024-07-18 18:07:09 -07:00 committed by GitHub
commit 4b96cd46b2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 359 additions and 23 deletions


@@ -0,0 +1,102 @@
# 💸 Free, Paid Tier Routing
Route Virtual Keys on the `free tier` to cheaper models.
### 1. Define free, paid tier models on config.yaml
:::info
Requests with `model=gpt-4` will be routed to either `openai/fake` or `openai/gpt-4o` depending on which tier the virtual key is on
:::
```yaml
model_list:
- model_name: gpt-4
litellm_params:
model: openai/fake
api_key: fake-key
api_base: https://exampleopenaiendpoint-production.up.railway.app/
model_info:
tier: free # 👈 Key Change - set `tier` to `paid` or `free`
- model_name: gpt-4
litellm_params:
model: openai/gpt-4o
api_key: os.environ/OPENAI_API_KEY
model_info:
tier: paid # 👈 Key Change - set `tier` to `paid` or `free`
general_settings:
master_key: sk-1234
```
### 2. Create Virtual Keys with pricing `tier=free`
```shell
curl --location 'http://0.0.0.0:4000/key/generate' \
--header 'Authorization: Bearer sk-1234' \
--header 'Content-Type: application/json' \
--data '{
"metadata": {"tier": "free"}
}'
```
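The `/key/generate` response contains the new virtual key in its `key` field; that key is what gets passed as the bearer token in the next step. A minimal Python sketch of the same call, assuming the proxy is running locally with the master key from the config above:

```python
import requests

# Ask the proxy to mint a virtual key whose metadata marks it as free tier.
resp = requests.post(
    "http://0.0.0.0:4000/key/generate",
    headers={"Authorization": "Bearer sk-1234"},  # master key from config.yaml
    json={"metadata": {"tier": "free"}},
)
resp.raise_for_status()

free_tier_key = resp.json()["key"]  # use this as the Authorization bearer token below
print("free tier key:", free_tier_key)
```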
### 3. Make Request with Key on `Free Tier`
```shell
curl -i http://localhost:4000/v1/chat/completions \
-H "Content-Type: application/json" \
-H "Authorization: Bearer sk-inxzoSurQsjog9gPrVOCcA" \
-d '{
"model": "gpt-4",
"messages": [
{"role": "user", "content": "Hello, Claude gm!"}
]
}'
```
**Expected Response**
If this worked as expected, the `x-litellm-model-api-base` header in the response should be `https://exampleopenaiendpoint-production.up.railway.app/`.
```shell
x-litellm-model-api-base: https://exampleopenaiendpoint-production.up.railway.app/
{"id":"chatcmpl-657b750f581240c1908679ed94b31bfe","choices":[{"finish_reason":"stop","index":0,"message":{"content":"\n\nHello there, how may I assist you today?","role":"assistant","tool_calls":null,"function_call":null}}],"created":1677652288,"model":"gpt-3.5-turbo-0125","object":"chat.completion","system_fingerprint":"fp_44709d6fcb","usage":{"completion_tokens":12,"prompt_tokens":9,"total_tokens":21}}
```
### 4. Create Virtual Keys with pricing `tier=paid`
```shell
curl --location 'http://0.0.0.0:4000/key/generate' \
--header 'Authorization: Bearer sk-1234' \
--header 'Content-Type: application/json' \
--data '{
"metadata": {"tier": "paid"}
}'
```
### 5. Make Request with Key on `Paid Tier`
```shell
curl -i http://localhost:4000/v1/chat/completions \
-H "Content-Type: application/json" \
-H "Authorization: Bearer sk-mnJoeSc6jFjzZr256q-iqA" \
-d '{
"model": "gpt-4",
"messages": [
{"role": "user", "content": "Hello, Claude gm!"}
]
}'
```
**Expected Response**
If this worked as expected, the `x-litellm-model-api-base` header in the response should be `https://api.openai.com`.
```shell
x-litellm-model-api-base: https://api.openai.com
{"id":"chatcmpl-9mW75EbJCgwmLcO0M5DmwxpiBgWdc","choices":[{"finish_reason":"stop","index":0,"message":{"content":"Good morning! How can I assist you today?","role":"assistant","tool_calls":null,"function_call":null}}],"created":1721350215,"model":"gpt-4o-2024-05-13","object":"chat.completion","system_fingerprint":"fp_c4e5b6fa31","usage":{"completion_tokens":10,"prompt_tokens":12,"total_tokens":22}}
```
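The same flow works from the OpenAI Python SDK by pointing it at the proxy. A minimal sketch, assuming the proxy runs on `localhost:4000` and `sk-<your-virtual-key>` is replaced with a key generated above (the tier stored on that key decides which deployment serves the request):

```python
from openai import OpenAI

# The virtual key carries the tier; the request itself just asks for "gpt-4".
client = OpenAI(base_url="http://localhost:4000/v1", api_key="sk-<your-virtual-key>")

# Use with_raw_response so the routing headers set by the proxy are visible.
raw = client.chat.completions.with_raw_response.create(
    model="gpt-4",
    messages=[{"role": "user", "content": "Hello, gm!"}],
)

print("routed to:", raw.headers.get("x-litellm-model-api-base"))
print(raw.parse().choices[0].message.content)
```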


@@ -43,11 +43,12 @@ const sidebars = {
         "proxy/reliability",
         "proxy/cost_tracking",
         "proxy/self_serve",
+        "proxy/virtual_keys",
+        "proxy/free_paid_tier",
         "proxy/users",
         "proxy/team_budgets",
         "proxy/customers",
         "proxy/billing",
-        "proxy/virtual_keys",
         "proxy/guardrails",
         "proxy/token_auth",
         "proxy/alerting",


@@ -4,7 +4,7 @@ from typing import TYPE_CHECKING, Any, Dict, Optional
 from fastapi import Request

 from litellm._logging import verbose_logger, verbose_proxy_logger
-from litellm.proxy._types import UserAPIKeyAuth
+from litellm.proxy._types import CommonProxyErrors, UserAPIKeyAuth
 from litellm.types.utils import SupportedCacheControls

 if TYPE_CHECKING:
@@ -95,15 +95,6 @@ async def add_litellm_data_to_request(
         cache_dict = parse_cache_control(cache_control_header)
         data["ttl"] = cache_dict.get("s-maxage")

-    ### KEY-LEVEL CACHNG
-    key_metadata = user_api_key_dict.metadata
-    if "cache" in key_metadata:
-        data["cache"] = {}
-        if isinstance(key_metadata["cache"], dict):
-            for k, v in key_metadata["cache"].items():
-                if k in SupportedCacheControls:
-                    data["cache"][k] = v
-
     verbose_proxy_logger.debug("receiving data: %s", data)

     _metadata_variable_name = _get_metadata_variable_name(request)
@@ -133,6 +124,24 @@ async def add_litellm_data_to_request(
             user_api_key_dict, "team_alias", None
         )

+    ### KEY-LEVEL Contorls
+    key_metadata = user_api_key_dict.metadata
+    if "cache" in key_metadata:
+        data["cache"] = {}
+        if isinstance(key_metadata["cache"], dict):
+            for k, v in key_metadata["cache"].items():
+                if k in SupportedCacheControls:
+                    data["cache"][k] = v
+
+    if "tier" in key_metadata:
+        if premium_user is not True:
+            verbose_logger.warning(
+                "Trying to use free/paid tier feature. This will not be applied %s",
+                CommonProxyErrors.not_premium_user.value,
+            )
+
+        # add request tier to metadata
+        data[_metadata_variable_name]["tier"] = key_metadata["tier"]
+
     # Team spend, budget - used by prometheus.py
     data[_metadata_variable_name][
         "user_api_key_team_max_budget"

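In effect, the key-level control added above copies the `tier` stored in a virtual key's metadata into the request's metadata so the router can filter deployments later. A standalone illustration of that transformation (simplified shapes, not the proxy code itself):

```python
# Simplified illustration of the key-level tier control above.
key_metadata = {"tier": "free"}              # metadata stored on the virtual key
data = {"model": "gpt-4", "metadata": {}}    # incoming request body (simplified)

if "tier" in key_metadata:
    # the proxy copies the tier onto the request so the router can see it
    data["metadata"]["tier"] = key_metadata["tier"]

assert data["metadata"]["tier"] == "free"
```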

@@ -1,23 +1,19 @@
 model_list:
-  - model_name: fake-openai-endpoint
+  - model_name: gpt-4
     litellm_params:
       model: openai/fake
       api_key: fake-key
       api_base: https://exampleopenaiendpoint-production.up.railway.app/
-  - model_name: gemini-flash
-    litellm_params:
-      model: gemini/gemini-1.5-flash
-  - model_name: whisper
-    litellm_params:
-      model: whisper-1
-      api_key: sk-*******
-      max_file_size_mb: 1000
     model_info:
-      mode: audio_transcription
+      tier: free # 👈 Key Change - set `tier`
+  - model_name: gpt-4
+    litellm_params:
+      model: openai/gpt-4o
+      api_key: os.environ/OPENAI_API_KEY
+    model_info:
+      tier: paid # 👈 Key Change - set `tier`

 general_settings:
   master_key: sk-1234
-
-litellm_settings:
-  success_callback: ["langsmith"]


@@ -47,6 +47,7 @@ from litellm.assistants.main import AssistantDeleted
 from litellm.caching import DualCache, InMemoryCache, RedisCache
 from litellm.integrations.custom_logger import CustomLogger
 from litellm.llms.azure import get_azure_ad_token_from_oidc
+from litellm.router_strategy.free_paid_tiers import get_deployments_for_tier
 from litellm.router_strategy.least_busy import LeastBusyLoggingHandler
 from litellm.router_strategy.lowest_cost import LowestCostLoggingHandler
 from litellm.router_strategy.lowest_latency import LowestLatencyLoggingHandler
@@ -4481,6 +4482,12 @@ class Router:
                 request_kwargs=request_kwargs,
             )

+        # check free / paid tier for each deployment
+        healthy_deployments = await get_deployments_for_tier(
+            request_kwargs=request_kwargs,
+            healthy_deployments=healthy_deployments,
+        )
+
         if len(healthy_deployments) == 0:
             if _allowed_model_region is None:
                 _allowed_model_region = "n/a"


@@ -0,0 +1,69 @@
"""
Use this to route requests between free and paid tiers
"""

from typing import Any, Dict, List, Literal, Optional, TypedDict, Union, cast

from litellm._logging import verbose_logger
from litellm.types.router import DeploymentTypedDict


class ModelInfo(TypedDict):
    tier: Literal["free", "paid"]


class Deployment(TypedDict):
    model_info: ModelInfo


async def get_deployments_for_tier(
    request_kwargs: Optional[Dict[Any, Any]] = None,
    healthy_deployments: Optional[Union[List[Any], Dict[Any, Any]]] = None,
):
    """
    if request_kwargs contains {"metadata": {"tier": "free"}} or {"metadata": {"tier": "paid"}}, then routes the request to free/paid tier models
    """
    if request_kwargs is None:
        verbose_logger.debug(
            "get_deployments_for_tier: request_kwargs is None returning healthy_deployments: %s",
            healthy_deployments,
        )
        return healthy_deployments

    verbose_logger.debug("request metadata: %s", request_kwargs.get("metadata"))
    if "metadata" in request_kwargs:
        metadata = request_kwargs["metadata"]
        if "tier" in metadata:
            selected_tier: Literal["free", "paid"] = metadata["tier"]
            if healthy_deployments is None:
                return None

            if selected_tier == "free":
                # get all deployments where model_info has tier = free
                free_deployments: List[Any] = []
                verbose_logger.debug(
                    "Getting deployments in free tier, all_deployments: %s",
                    healthy_deployments,
                )
                for deployment in healthy_deployments:
                    typed_deployment = cast(Deployment, deployment)
                    if typed_deployment["model_info"]["tier"] == "free":
                        free_deployments.append(deployment)
                verbose_logger.debug("free_deployments: %s", free_deployments)
                return free_deployments
            elif selected_tier == "paid":
                # get all deployments where model_info has tier = paid
                paid_deployments: List[Any] = []
                for deployment in healthy_deployments:
                    typed_deployment = cast(Deployment, deployment)
                    if typed_deployment["model_info"]["tier"] == "paid":
                        paid_deployments.append(deployment)
                verbose_logger.debug("paid_deployments: %s", paid_deployments)
                return paid_deployments

    verbose_logger.debug(
        "no tier found in metadata, returning healthy_deployments: %s",
        healthy_deployments,
    )
    return healthy_deployments
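A small usage sketch for the helper above, showing how it filters a deployment list by the tier carried in the request metadata (the deployment dicts here are illustrative):

```python
import asyncio

from litellm.router_strategy.free_paid_tiers import get_deployments_for_tier

deployments = [
    {"model_name": "gpt-4", "model_info": {"tier": "free", "id": "cheap"}},
    {"model_name": "gpt-4", "model_info": {"tier": "paid", "id": "expensive"}},
]

# A request made with a free-tier key carries {"metadata": {"tier": "free"}}.
filtered = asyncio.run(
    get_deployments_for_tier(
        request_kwargs={"metadata": {"tier": "free"}},
        healthy_deployments=deployments,
    )
)
print(filtered)  # only the deployment whose model_info tier is "free"
```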


@@ -0,0 +1,60 @@
"""
Tests litellm pre_call_utils
"""

import os
import sys
import traceback
import uuid
from datetime import datetime

from dotenv import load_dotenv
from fastapi import Request
from fastapi.routing import APIRoute

from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request
from litellm.proxy.proxy_server import ProxyConfig, chat_completion

load_dotenv()
import io
import os
import time

import pytest

# this file is to test litellm/proxy

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path


@pytest.mark.parametrize("tier", ["free", "paid"])
@pytest.mark.asyncio()
async def test_adding_key_tier_to_request_metadata(tier):
    """
    Tests if we can add tier: free/paid from key metadata to the request metadata
    """
    data = {}

    api_route = APIRoute(path="/chat/completions", endpoint=chat_completion)
    request = Request(
        {
            "type": "http",
            "method": "POST",
            "route": api_route,
            "path": api_route.path,
            "headers": [],
        }
    )

    new_data = await add_litellm_data_to_request(
        data=data,
        request=request,
        user_api_key_dict=UserAPIKeyAuth(metadata={"tier": tier}),
        proxy_config=ProxyConfig(),
    )

    print("new_data", new_data)

    assert new_data["metadata"]["tier"] == tier


@@ -0,0 +1,90 @@
#### What this tests ####
#  This tests free/paid tier routing on the litellm router

import asyncio
import os
import sys
import time
import traceback

import openai
import pytest

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import logging
import os
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from unittest.mock import AsyncMock, MagicMock, patch

import httpx
from dotenv import load_dotenv

import litellm
from litellm import Router
from litellm._logging import verbose_logger

verbose_logger.setLevel(logging.DEBUG)

load_dotenv()


@pytest.mark.asyncio()
async def test_router_free_paid_tier():
    """
    Define two deployments for the same model name, one on the free tier and one on the paid tier;
    expect requests to be routed to the deployment matching the tier in the request metadata
    """
    router = litellm.Router(
        model_list=[
            {
                "model_name": "gpt-4",
                "litellm_params": {
                    "model": "gpt-4o",
                    "api_base": "https://exampleopenaiendpoint-production.up.railway.app/",
                },
                "model_info": {"tier": "paid", "id": "very-expensive-model"},
            },
            {
                "model_name": "gpt-4",
                "litellm_params": {
                    "model": "gpt-4o-mini",
                    "api_base": "https://exampleopenaiendpoint-production.up.railway.app/",
                },
                "model_info": {"tier": "free", "id": "very-cheap-model"},
            },
        ]
    )

    for _ in range(5):
        # this should pick model with id == very-cheap-model
        response = await router.acompletion(
            model="gpt-4",
            messages=[{"role": "user", "content": "Tell me a joke."}],
            metadata={"tier": "free"},
        )

        print("Response: ", response)

        response_extra_info = response._hidden_params
        print("response_extra_info: ", response_extra_info)

        assert response_extra_info["model_id"] == "very-cheap-model"

    for _ in range(5):
        # this should pick model with id == very-expensive-model
        response = await router.acompletion(
            model="gpt-4",
            messages=[{"role": "user", "content": "Tell me a joke."}],
            metadata={"tier": "paid"},
        )

        print("Response: ", response)

        response_extra_info = response._hidden_params
        print("response_extra_info: ", response_extra_info)

        assert response_extra_info["model_id"] == "very-expensive-model"


@@ -91,6 +91,7 @@ class ModelInfo(BaseModel):
     base_model: Optional[str] = (
         None  # specify if the base model is azure/gpt-3.5-turbo etc for accurate cost tracking
     )
+    tier: Optional[Literal["free", "paid"]] = None

     def __init__(self, id: Optional[Union[str, int]] = None, **params):
         if id is None:
@@ -328,6 +329,7 @@ class LiteLLMParamsTypedDict(TypedDict, total=False):
 class DeploymentTypedDict(TypedDict):
     model_name: str
     litellm_params: LiteLLMParamsTypedDict
+    model_info: ModelInfo


 SPECIAL_MODEL_INFO_PARAMS = [