Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-26 03:04:13 +00:00
Support temporary budget increases on keys (#7754)
* fix(gpt_transformation.py): fix response_format translation check for 4o models. Fixes https://github.com/BerriAI/litellm/issues/7616
* feat(key_management_endpoints.py): support 'temp_budget_increase' and 'temp_budget_expiry' fields. Allows the proxy admin to grant temporary budget increases to keys
* fix(proxy/_types.py): enforce that temp_budget_increase and temp_budget_expiry are always passed together
* feat(user_api_key_auth.py): initial working temp budget increase logic. Ensures the key-budget-exceeded check accounts for the temp budget in key metadata
* feat(proxy_server.py): return the key's max budget and spend in the response headers, so the client-side user knows their remaining limits
* test: add unit testing for new proxy utils. Ensures the new key budget is correctly handled
* docs(temporary_budget_increase.md): add doc on temporary budget increase
* fix(utils.py): remove 3.5 from the response_format check for now, since not all Azure 3.5 models support response_format
* fix(user_api_key_auth.py): return a valid user api key auth object on all paths
This commit is contained in:
parent 000d3152a8
commit d7a13ad561

11 changed files with 259 additions and 52 deletions
docs/my-website/docs/proxy/temporary_budget_increase.md (new file, +74 lines)
# ✨ Temporary Budget Increase

Set a temporary budget increase for a LiteLLM Virtual Key. Use this when you're asked to temporarily raise the budget on a key.

| Hierarchy | Supported |
|-----------|-----------|
| LiteLLM Virtual Key | ✅ |
| User | ❌ |
| Team | ❌ |
| Organization | ❌ |

:::note

✨ Temporary Budget Increase is a LiteLLM Enterprise feature.

[Enterprise Pricing](https://www.litellm.ai/#pricing)

[Get free 7-day trial key](https://www.litellm.ai/#trial)

:::

1. Create a LiteLLM Virtual Key with a budget:

```bash
curl -L -X POST 'http://localhost:4000/key/generate' \
-H 'Content-Type: application/json' \
-H 'Authorization: Bearer LITELLM_MASTER_KEY' \
-d '{
    "max_budget": 0.0000001
}'
```

Expected response:

```json
{
    "key": "sk-your-new-key"
}
```

2. Update the key with a temporary budget increase:

```bash
curl -L -X POST 'http://localhost:4000/key/update' \
-H 'Authorization: Bearer LITELLM_MASTER_KEY' \
-H 'Content-Type: application/json' \
-d '{
    "key": "sk-your-new-key",
    "temp_budget_increase": 100,
    "temp_budget_expiry": "2025-01-15"
}'
```

3. Test it!

```bash
curl -L -X POST 'http://localhost:4000/chat/completions' \
-H 'Authorization: Bearer sk-your-new-key' \
-H 'Content-Type: application/json' \
-d '{
    "model": "gpt-4o",
    "messages": [{"role": "user", "content": "Hello, world!"}]
}'
```

Expected response header:

```
x-litellm-key-max-budget: 100.0000001
```
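
The `x-litellm-key-max-budget` value is the key's base `max_budget` plus the active temporary increase; `x-litellm-key-spend` is returned alongside it. A minimal client-side sketch for computing the remaining budget from these headers (the `requests` usage and key value are illustrative, not part of the commit):

```python
import requests

resp = requests.post(
    "http://localhost:4000/chat/completions",
    headers={"Authorization": "Bearer sk-your-new-key"},
    json={"model": "gpt-4o", "messages": [{"role": "user", "content": "Hello, world!"}]},
)

# Both headers are emitted by this commit's get_custom_headers change.
max_budget = float(resp.headers["x-litellm-key-max-budget"])
spend = float(resp.headers["x-litellm-key-spend"])
print(f"remaining budget: {max_budget - spend}")
```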
docs sidebar (sidebars.js): register the new page:

```diff
@@ -107,7 +107,7 @@ const sidebars = {
       {
         type: "category",
         label: "Budgets + Rate Limits",
-        items: ["proxy/users", "proxy/rate_limit_tiers", "proxy/team_budgets", "proxy/customers"],
+        items: ["proxy/users", "proxy/temporary_budget_increase", "proxy/rate_limit_tiers", "proxy/team_budgets", "proxy/customers"],
       },
       {
         type: "link",
```
gpt_transformation.py: import the shared response-schema check:

```diff
@@ -8,6 +8,7 @@ from litellm.litellm_core_utils.prompt_templates.factory import (
 )
 from litellm.llms.base_llm.chat.transformation import BaseLLMException
 from litellm.types.utils import ModelResponse
+from litellm.utils import supports_response_schema

 from ....exceptions import UnsupportedParamsError
 from ....types.llms.openai import (
```
gpt_transformation.py: add the model-support helper to AzureOpenAIConfig:

```diff
@@ -105,6 +106,19 @@ class AzureOpenAIConfig(BaseConfig):
         "parallel_tool_calls",
     ]

+    def _is_response_format_supported_model(self, model: str) -> bool:
+        """
+        - all 4o models are supported
+        - check if 'supports_response_format' is True from get_model_info
+        - [TODO] support smart retries for 3.5 models (some supported, some not)
+        """
+        if "4o" in model:
+            return True
+        elif supports_response_schema(model):
+            return True
+
+        return False
+
     def map_openai_params(
         self,
         non_default_params: dict,
```
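
A rough illustration of what the helper returns (the model names and the import path are assumptions for this sketch, not taken from the diff):

```python
# Assumed import path for AzureOpenAIConfig; adjust to where the class lives in your checkout.
from litellm.llms.azure.chat.gpt_transformation import AzureOpenAIConfig

cfg = AzureOpenAIConfig()
print(cfg._is_response_format_supported_model("gpt-4o-mini"))   # True: "4o" substring match
print(cfg._is_response_format_supported_model("gpt-35-turbo"))  # defers to supports_response_schema()
```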
gpt_transformation.py: use the helper in map_openai_params instead of the hard-coded "gpt-4o" check:

```diff
@@ -176,10 +190,14 @@ class AzureOpenAIConfig(BaseConfig):
         - You should set tool_choice (see Forcing tool use) to instruct the model to explicitly use that tool
         - Remember that the model will pass the input to the tool, so the name of the tool and description should be from the model's perspective.
         """
+        _is_response_format_supported_model = (
+            self._is_response_format_supported_model(model)
+        )
         if json_schema is not None and (
             (api_version_year <= "2024" and api_version_month < "08")
-            or "gpt-4o" not in model
-        ):  # azure api version "2024-08-01-preview" onwards supports 'json_schema' only for gpt-4o
+            or not _is_response_format_supported_model
+        ):  # azure api version "2024-08-01-preview" onwards supports 'json_schema' only for gpt-4o/3.5 models

             _tool_choice = ChatCompletionToolChoiceObjectParam(
                 type="function",
                 function=ChatCompletionToolChoiceFunctionParam(
```
proxy dev config (yaml): trim down to a single azure gpt-4o entry:

```diff
@@ -1,43 +1,7 @@
 model_list:
+  # At least one model must exist for the proxy to start.
-  - model_name: gpt-4o
+  - model_name: "gpt-4o"
     litellm_params:
-      model: gpt-4o
-      api_key: os.environ/OPENAI_API_KEY
+      model: "azure/gpt-4o"
+      api_key: os.environ/AZURE_API_KEY
+      api_base: os.environ/AZURE_API_BASE
-      # timeout: 0.1 # timeout in (seconds)
-      # stream_timeout: 0.01 # timeout for stream requests (seconds)
-  - model_name: anthropic.claude-3-5-sonnet-20241022-v2:0
-    litellm_params:
-      model: bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0
-  - model_name: nova-lite
-    litellm_params:
-      model: bedrock/us.amazon.nova-lite-v1:0
-  - model_name: llama3-2-11b-instruct-v1:0
-    litellm_params:
-      model: bedrock/us.meta.llama3-2-11b-instruct-v1:0
-  - model_name: gpt-4o-bad
-    litellm_params:
-      model: gpt-4o
-      api_key: bad
-  - model_name: "bedrock/*"
-    litellm_params:
-      model: "bedrock/*"
-  - model_name: "openai/*"
-    litellm_params:
-      model: "openai/*"
-      api_key: os.environ/OPENAI_API_KEY
-general_settings:
-  store_model_in_db: true
-  disable_prisma_schema_update: true
-  # master_key: os.environ/LITELLM_MASTER_KEY
-litellm_settings:
-  fallbacks: [{"gpt-4o-bad": ["gpt-4o"]}] #, {"gpt-4o": ["nova-lite"]}]
-  request_timeout: 600 # raise Timeout error if call takes longer than 600 seconds. Default value is 6000seconds if not set
-  # set_verbose: false # Switch off Debug Logging, ensure your logs do not have any debugging on
-  # json_logs: true # Get debug logs in json format
-  ssl_verify: true
-  callbacks: ["prometheus"]
-  service_callback: ["prometheus_system"]
-  turn_off_message_logging: true # turn off messages in otel
-  #callbacks: ["langfuse"]
-  redact_user_api_key_info: true
```
proxy/_types.py: add the new fields to UpdateKeyRequest, enforcing that they are passed together:

```diff
@@ -688,6 +688,17 @@ class UpdateKeyRequest(KeyRequestBase):
     duration: Optional[str] = None
     spend: Optional[float] = None
     metadata: Optional[dict] = None
+    temp_budget_increase: Optional[float] = None
+    temp_budget_expiry: Optional[datetime] = None
+
+    @model_validator(mode="after")
+    def validate_temp_budget(self) -> "UpdateKeyRequest":
+        if self.temp_budget_increase is not None or self.temp_budget_expiry is not None:
+            if self.temp_budget_increase is None or self.temp_budget_expiry is None:
+                raise ValueError(
+                    "temp_budget_increase and temp_budget_expiry must be set together"
+                )
+        return self


 class RegenerateKeyRequest(GenerateKeyRequest):
```
proxy/_types.py: whitelist the fields as management-endpoint metadata:

```diff
@@ -2229,6 +2240,8 @@ LiteLLM_ManagementEndpoint_MetadataFields = [
     "guardrails",
     "tags",
     "enforced_params",
+    "temp_budget_increase",
+    "temp_budget_expiry",
 ]
```
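
Because both names are in this whitelist, `/key/update` stores them inside the key's `metadata` blob rather than as first-class columns. After the doc's update call, the stored metadata plausibly looks like this (illustrative values, not captured output; the datetime is cast to isoformat by `prepare_metadata_fields` below):

```python
# Hypothetical stored shape of the key's metadata after the /key/update call:
key_metadata = {
    "temp_budget_increase": 100,
    "temp_budget_expiry": "2025-01-15T00:00:00",
}
```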
user_api_key_auth.py: apply the temp increase as soon as the token is loaded:

```diff
@@ -811,7 +811,10 @@ async def user_api_key_auth(  # noqa: PLR0915
                 valid_token.allowed_model_region = end_user_params.get(
                     "allowed_model_region"
                 )
-
+                # update key budget with temp budget increase
+                valid_token = _update_key_budget_with_temp_budget_increase(
+                    valid_token
+                )  # updating it here allows all downstream reporting / checks to use the updated budget
         except Exception:
             verbose_logger.info(
                 "litellm.proxy.auth.user_api_key_auth.py::user_api_key_auth() - Unable to find token={} in cache or `LiteLLM_VerificationTokenTable`. Defaulting 'valid_token' to None'".format(
```
user_api_key_auth.py: context around the budget-exceeded check (the updated max_budget now includes any active temp increase):

```diff
@@ -1016,6 +1019,7 @@ async def user_api_key_auth(  # noqa: PLR0915
                     current_cost=valid_token.spend,
                     max_budget=valid_token.max_budget,
                 )

             if valid_token.soft_budget and valid_token.spend >= valid_token.soft_budget:
                 verbose_proxy_logger.debug(
                     "Crossed Soft Budget for token %s, spend %s, soft_budget %s",
```
user_api_key_auth.py: the new helpers:

```diff
@@ -1383,3 +1387,25 @@ def get_api_key_from_custom_header(
             f"No LiteLLM Virtual Key pass. Please set header={custom_litellm_key_header_name}: Bearer <api_key>"
         )
     return api_key
+
+
+def _get_temp_budget_increase(valid_token: UserAPIKeyAuth):
+    valid_token_metadata = valid_token.metadata
+    if (
+        "temp_budget_increase" in valid_token_metadata
+        and "temp_budget_expiry" in valid_token_metadata
+    ):
+        expiry = datetime.fromisoformat(valid_token_metadata["temp_budget_expiry"])
+        if expiry > datetime.now():
+            return valid_token_metadata["temp_budget_increase"]
+    return None
+
+
+def _update_key_budget_with_temp_budget_increase(
+    valid_token: UserAPIKeyAuth,
+) -> UserAPIKeyAuth:
+    if valid_token.max_budget is None:
+        return valid_token
+    temp_budget_increase = _get_temp_budget_increase(valid_token) or 0.0
+    valid_token.max_budget = valid_token.max_budget + temp_budget_increase
+    return valid_token
```
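
Putting the two helpers together, the effective budget is just base plus increase while the expiry is in the future. A self-contained sketch of that arithmetic (values copied from the doc above, not from this diff):

```python
from datetime import datetime

key_metadata = {
    "temp_budget_increase": 100,
    "temp_budget_expiry": "2025-01-15",  # datetime.fromisoformat parses this as midnight
}

base_max_budget = 0.0000001
expiry = datetime.fromisoformat(key_metadata["temp_budget_expiry"])
# Same check as _get_temp_budget_increase: the increase only counts while unexpired.
increase = key_metadata["temp_budget_increase"] if expiry > datetime.now() else 0.0
effective_max_budget = base_max_budget + increase
print(effective_max_budget)  # 100.0000001 while active, matching the x-litellm-key-max-budget header
```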
key_management_endpoints.py: cast datetime metadata values (like temp_budget_expiry) to isoformat before storing:

```diff
@@ -558,6 +558,9 @@ def prepare_metadata_fields(
     try:
         for k, v in data_json.items():
             if k in LiteLLM_ManagementEndpoint_MetadataFields:
+                if isinstance(v, datetime):
+                    casted_metadata[k] = v.isoformat()
+                else:
                     casted_metadata[k] = v

     except Exception as e:
```
key_management_endpoints.py: document the new params on update_key_fn:

```diff
@@ -658,6 +661,8 @@ async def update_key_fn(
     - blocked: Optional[bool] - Whether the key is blocked
     - aliases: Optional[dict] - Model aliases for the key - [Docs](https://litellm.vercel.app/docs/proxy/virtual_keys#model-aliases)
     - config: Optional[dict] - [DEPRECATED PARAM] Key-specific config.
+    - temp_budget_increase: Optional[float] - Temporary budget increase for the key (Enterprise only).
+    - temp_budget_expiry: Optional[str] - Expiry time for the temporary budget increase (Enterprise only).

     Example:
     ```bash
```
key_management_endpoints.py: small refactor of the update call:

```diff
@@ -707,9 +712,8 @@ async def update_key_fn(
             existing_key_token=existing_key_row.token,
         )

-    response = await prisma_client.update_data(
-        token=key, data={**non_default_values, "token": key}
-    )
+    _data = {**non_default_values, "token": key}
+    response = await prisma_client.update_data(token=key, data=_data)

     # Delete - key from cache, since it's been updated!
     # key updated - a new model could have been added to this key. it should not block requests after this is done
```
proxy_server.py: expose the key's max budget and spend as response headers:

```diff
@@ -747,6 +747,8 @@ def get_custom_headers(
         "x-litellm-response-cost": str(response_cost),
         "x-litellm-key-tpm-limit": str(user_api_key_dict.tpm_limit),
         "x-litellm-key-rpm-limit": str(user_api_key_dict.rpm_limit),
+        "x-litellm-key-max-budget": str(user_api_key_dict.max_budget),
+        "x-litellm-key-spend": str(user_api_key_dict.spend),
         "x-litellm-fastest_response_batch_completion": (
             str(fastest_response_batch_completion)
             if fastest_response_batch_completion is not None
```
utils.py: supports_response_schema no longer raises on unknown models:

```diff
@@ -1732,9 +1732,15 @@ def supports_response_schema(
     Does not raise error. Defaults to 'False'. Outputs logging.error.
     """
     ## GET LLM PROVIDER ##
-    model, custom_llm_provider, _, _ = get_llm_provider(
-        model=model, custom_llm_provider=custom_llm_provider
-    )
+    try:
+        model, custom_llm_provider, _, _ = get_llm_provider(
+            model=model, custom_llm_provider=custom_llm_provider
+        )
+    except Exception as e:
+        verbose_logger.debug(
+            f"Model not found or error in checking response schema support. You passed model={model}, custom_llm_provider={custom_llm_provider}. Error: {str(e)}"
+        )
+        return False

     # providers that globally support response schema
     PROVIDERS_GLOBALLY_SUPPORT_RESPONSE_SCHEMA = [
```
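
With the try/except in place, an unrecognized model name degrades to a False result instead of an exception. A quick sketch (the behaviors in the comments follow the docstring and the hunk above; the model names are just examples):

```python
from litellm.utils import supports_response_schema

# Known model: answered from litellm's model info / provider support list.
print(supports_response_schema(model="gpt-4o"))
# Unknown model: get_llm_provider fails, the error is logged at debug level, and False is returned.
print(supports_response_schema(model="not-a-real-model"))
```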
Azure unit tests: ensure a 4o deployment named without a dash (azure/gpt4o) still gets native response_format:

```diff
@@ -248,3 +248,38 @@ def test_get_azure_ad_token_from_username_password(
     # Verify the result is the mock token provider
     assert result == mock_token_provider
+
+
+def test_azure_openai_gpt_4o_naming(monkeypatch):
+    from openai import AzureOpenAI
+    from pydantic import BaseModel, Field
+
+    monkeypatch.setenv("AZURE_API_VERSION", "2024-10-21")
+
+    client = AzureOpenAI(
+        api_key="test-api-key",
+        base_url="https://my-endpoint-sweden-berri992.openai.azure.com",
+        api_version="2023-12-01-preview",
+    )
+
+    class ResponseFormat(BaseModel):
+        number: str = Field(description="total number of days in a week")
+        days: list[str] = Field(description="name of days in a week")
+
+    with patch.object(client.chat.completions.with_raw_response, "create") as mock_post:
+        try:
+            completion(
+                model="azure/gpt4o",
+                messages=[{"role": "user", "content": "Hello world"}],
+                response_format=ResponseFormat,
+                client=client,
+            )
+        except Exception as e:
+            print(e)
+
+        mock_post.assert_called_once()
+
+        print(mock_post.call_args.kwargs)
+
+        assert "tool_calls" not in mock_post.call_args.kwargs
```
Auth unit tests: cover the validator and the two temp-budget helpers:

```diff
@@ -1382,3 +1382,68 @@ def test_custom_openid_response():
         jwt_handler=jwt_handler,
     )
     assert resp.team_ids == ["/test-group"]
+
+
+def test_update_key_request_validation():
+    """
+    Ensures that the UpdateKeyRequest model validates the temp_budget_increase and temp_budget_expiry fields together
+    """
+    from litellm.proxy._types import UpdateKeyRequest
+
+    with pytest.raises(Exception):
+        UpdateKeyRequest(
+            key="test_key",
+            temp_budget_increase=100,
+        )
+
+    with pytest.raises(Exception):
+        UpdateKeyRequest(
+            key="test_key",
+            temp_budget_expiry="2024-01-20T00:00:00Z",
+        )
+
+    UpdateKeyRequest(
+        key="test_key",
+        temp_budget_increase=100,
+        temp_budget_expiry="2024-01-20T00:00:00Z",
+    )
+
+
+def test_get_temp_budget_increase():
+    from litellm.proxy.auth.user_api_key_auth import _get_temp_budget_increase
+    from litellm.proxy._types import UserAPIKeyAuth
+    from datetime import datetime, timedelta
+
+    expiry = datetime.now() + timedelta(days=1)
+    expiry_in_isoformat = expiry.isoformat()
+
+    valid_token = UserAPIKeyAuth(
+        max_budget=100,
+        spend=0,
+        metadata={
+            "temp_budget_increase": 100,
+            "temp_budget_expiry": expiry_in_isoformat,
+        },
+    )
+    assert _get_temp_budget_increase(valid_token) == 100
+
+
+def test_update_key_budget_with_temp_budget_increase():
+    from litellm.proxy.auth.user_api_key_auth import (
+        _update_key_budget_with_temp_budget_increase,
+    )
+    from litellm.proxy._types import UserAPIKeyAuth
+    from datetime import datetime, timedelta
+
+    expiry = datetime.now() + timedelta(days=1)
+    expiry_in_isoformat = expiry.isoformat()
+
+    valid_token = UserAPIKeyAuth(
+        max_budget=100,
+        spend=0,
+        metadata={
+            "temp_budget_increase": 100,
+            "temp_budget_expiry": expiry_in_isoformat,
+        },
+    )
+    assert _update_key_budget_with_temp_budget_increase(valid_token).max_budget == 200
```