exposing a litellm.max_budget

This commit is contained in:
Krrish Dholakia 2023-09-14 14:19:51 -07:00
parent 970822c79a
commit 73a084c19c
8 changed files with 45 additions and 5 deletions

View file

@@ -36,7 +36,8 @@ caching = False # deprecated soon
caching_with_models = False # if you want the caching key to be model + prompt # deprecated soon
cache: Optional[Cache] = None # cache object
model_alias_map: Dict[str, str] = {}
max_budget = None # set the max budget across all providers
_current_cost = 0 # private variable, used if max budget is set
#############################################
def get_model_cost_map():

View file

@@ -108,3 +108,10 @@ class OpenAIError(OpenAIError): # type: ignore
code=original_exception.code,
)
self.llm_provider = "openai"
class BudgetExceededError(Exception):
    """Raised when accumulated LLM spend crosses the user-configured cap.

    Attributes:
        current_cost: The spend accumulated so far across calls.
        max_budget: The configured ceiling that was exceeded.
    """

    def __init__(self, current_cost, max_budget):
        # Store both figures on the instance so callers can inspect
        # exactly how far over budget the run went.
        self.current_cost = current_cost
        self.max_budget = max_budget
        super().__init__(
            f"Budget has been exceeded! Current cost: {current_cost}, Max budget: {max_budget}"
        )

View file

@@ -10,7 +10,7 @@ sys.path.insert(
) # Adds the parent directory to the system path
import litellm
litellm.set_verbose = True
-from litellm import BudgetManager, completion
+from litellm import completion, BudgetManager
budget_manager = BudgetManager(project_name="test_project", client_type="hosted")

View file

@@ -0,0 +1,21 @@
#### What this tests ####
# This tests calling litellm.max_budget by making back-to-back gpt-4 calls
# commenting out this test for circle ci, as it causes other tests to fail, since litellm.max_budget would impact other litellm imports
# import sys, os, json
# import traceback
# import pytest
# sys.path.insert(
# 0, os.path.abspath("../..")
# ) # Adds the parent directory to the system path
# import litellm
# litellm.set_verbose = True
# from litellm import completion
# litellm.max_budget = 0.001 # sets a max budget of $0.001
# messages = [{"role": "user", "content": "Hey, how's it going"}]
# completion(model="gpt-4", messages=messages)
# completion(model="gpt-4", messages=messages)
# print(litellm._current_cost)

View file

@@ -31,7 +31,8 @@ from .exceptions import (
ContextWindowExceededError,
Timeout,
APIConnectionError,
-APIError
+APIError,
+BudgetExceededError
)
from typing import List, Dict, Union, Optional
from .caching import Cache
@@ -542,6 +543,12 @@ def client(original_function):
try:
logging_obj = function_setup(start_time, *args, **kwargs)
kwargs["litellm_logging_obj"] = logging_obj
# [OPTIONAL] CHECK BUDGET
if litellm.max_budget:
if litellm._current_cost > litellm.max_budget:
raise BudgetExceededError(current_cost=litellm._current_cost, max_budget=litellm.max_budget)
# [OPTIONAL] CHECK CACHE
# remove this after deprecating litellm.caching
if (litellm.caching or litellm.caching_with_models) and litellm.cache is None:
@@ -567,6 +574,10 @@
if litellm.caching or litellm.caching_with_models or litellm.cache != None: # user init a cache object
litellm.cache.add_cache(result, *args, **kwargs)
# [OPTIONAL] UPDATE BUDGET
if litellm.max_budget:
litellm._current_cost += litellm.completion_cost(completion_response=result)
# [OPTIONAL] Return LiteLLM call_id
if litellm.use_client == True:
result['litellm_call_id'] = litellm_call_id

View file

@ -1,6 +1,6 @@
[tool.poetry] [tool.poetry]
name = "litellm" name = "litellm"
version = "0.1.628" version = "0.1.629"
description = "Library to easily interface with LLM API providers" description = "Library to easily interface with LLM API providers"
authors = ["BerriAI"] authors = ["BerriAI"]
license = "MIT License" license = "MIT License"