diff --git a/litellm/__init__.py b/litellm/__init__.py index 8b8a50e462..03114b78e0 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -36,7 +36,8 @@ caching = False # deprecated son caching_with_models = False # if you want the caching key to be model + prompt # deprecated soon cache: Optional[Cache] = None # cache object model_alias_map: Dict[str, str] = {} - +max_budget = None # set the max budget across all providers +_current_cost = 0 # private variable, used if max budget is set ############################################# def get_model_cost_map():
diff --git a/litellm/__pycache__/__init__.cpython-311.pyc b/litellm/__pycache__/__init__.cpython-311.pyc index dabe860d76..9d7f4f02b4 100644 Binary files a/litellm/__pycache__/__init__.cpython-311.pyc and b/litellm/__pycache__/__init__.cpython-311.pyc differ
diff --git a/litellm/__pycache__/utils.cpython-311.pyc b/litellm/__pycache__/utils.cpython-311.pyc index 76b82a67db..d323a9eaef 100644 Binary files a/litellm/__pycache__/utils.cpython-311.pyc and b/litellm/__pycache__/utils.cpython-311.pyc differ
diff --git a/litellm/exceptions.py b/litellm/exceptions.py index b8ca92d6dc..4e0102b1de 100644 --- a/litellm/exceptions.py +++ b/litellm/exceptions.py @@ -107,4 +107,11 @@ class OpenAIError(OpenAIError): # type: ignore headers=original_exception.headers, code=original_exception.code, ) - self.llm_provider = "openai" \ No newline at end of file + self.llm_provider = "openai" + +class BudgetExceededError(Exception): + def __init__(self, current_cost, max_budget): + self.current_cost = current_cost + self.max_budget = max_budget + message = f"Budget has been exceeded! Current cost: {current_cost}, Max budget: {max_budget}" + super().__init__(message) \ No newline at end of file
diff --git a/litellm/tests/test_budget_manager.py b/litellm/tests/test_budget_manager.py index bc5738aa03..e47682d9b4 100644 --- a/litellm/tests/test_budget_manager.py +++ b/litellm/tests/test_budget_manager.py @@ -10,7 +10,7 @@ sys.path.insert( ) # Adds the parent directory to the system path import litellm litellm.set_verbose = True -from litellm import BudgetManager, completion +from litellm import completion, BudgetManager budget_manager = BudgetManager(project_name="test_project", client_type="hosted")
diff --git a/litellm/tests/test_litellm_max_budget.py b/litellm/tests/test_litellm_max_budget.py new file mode 100644 index 0000000000..15ca48efe8 --- /dev/null +++ b/litellm/tests/test_litellm_max_budget.py @@ -0,0 +1,21 @@ +#### What this tests #### +# This tests calling litellm.max_budget by making back-to-back gpt-4 calls +# commenting out this test for circle ci, as it causes other tests to fail, since litellm.max_budget would impact other litellm imports +# import sys, os, json +# import traceback +# import pytest + +# sys.path.insert( +# 0, os.path.abspath("../..") +# ) # Adds the parent directory to the system path +# import litellm +# litellm.set_verbose = True +# from litellm import completion + +# litellm.max_budget = 0.001 # sets a max budget of $0.001 + +# messages = [{"role": "user", "content": "Hey, how's it going"}] +# completion(model="gpt-4", messages=messages) +# completion(model="gpt-4", messages=messages) +# print(litellm._current_cost) +
diff --git a/litellm/utils.py b/litellm/utils.py index e0c29896ea..69f53a8776 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -31,7 +31,8 @@ from .exceptions import ( ContextWindowExceededError, Timeout, APIConnectionError, - APIError + APIError, + BudgetExceededError ) from typing import List, Dict, Union, Optional from .caching import Cache @@ -542,6 +543,12 @@ def client(original_function): try: logging_obj = function_setup(start_time, *args, **kwargs) kwargs["litellm_logging_obj"] = logging_obj + + # [OPTIONAL] CHECK BUDGET + if litellm.max_budget: + if litellm._current_cost > litellm.max_budget: + raise BudgetExceededError(current_cost=litellm._current_cost, max_budget=litellm.max_budget) + # [OPTIONAL] CHECK CACHE # remove this after deprecating litellm.caching if (litellm.caching or litellm.caching_with_models) and litellm.cache is None: @@ -567,6 +574,10 @@ def client(original_function): if litellm.caching or litellm.caching_with_models or litellm.cache != None: # user init a cache object litellm.cache.add_cache(result, *args, **kwargs) + # [OPTIONAL] UPDATE BUDGET + if litellm.max_budget: + litellm._current_cost += litellm.completion_cost(completion_response=result) + # [OPTIONAL] Return LiteLLM call_id if litellm.use_client == True: result['litellm_call_id'] = litellm_call_id
diff --git a/pyproject.toml b/pyproject.toml index e73a67509b..e19396fd8c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "litellm" -version = "0.1.628" +version = "0.1.629" description = "Library to easily interface with LLM API providers" authors = ["BerriAI"] license = "MIT License"