add api manager

Krrish Dholakia 2023-09-09 15:55:36 -07:00
parent 15c40625c6
commit a39756bfda
10 changed files with 110 additions and 19 deletions

@@ -35,6 +35,11 @@ caching = False # deprecated soon
 caching_with_models = False # if you want the caching key to be model + prompt # deprecated soon
 cache: Optional[Cache] = None # cache object
 model_alias_map: Dict[str, str] = {}
+
+####### APIManager ###################
+from .apimanager import APIManager
+apiManager = APIManager()
+
 def get_model_cost_map():
     url = "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json"

litellm/apimanager.py (new file)

@@ -0,0 +1,25 @@
+import litellm
+from litellm.utils import ModelResponse
+class APIManager:
+    def __init__(self):
+        self.user_dict = {}
+
+    def create_budget(self, total_budget: float, user: str):
+        self.user_dict[user] = {"total_budget": total_budget}
+        return self.user_dict[user]
+
+    def projected_cost(self, model: str, messages: list, user: str):
+        text = "".join(message["content"] for message in messages)
+        prompt_tokens = litellm.token_counter(model=model, text=text)
+        prompt_cost, _ = litellm.cost_per_token(model=model, prompt_tokens=prompt_tokens, completion_tokens=0)
+        current_cost = self.user_dict[user].get("current_cost", 0)
+        projected_cost = prompt_cost + current_cost
+        return projected_cost
+
+    def get_total_budget(self, user: str):
+        return self.user_dict[user]["total_budget"]
+
+    def update_cost(self, completion_obj: ModelResponse, user: str):
+        cost = litellm.completion_cost(completion_response=completion_obj)
+        self.user_dict[user]["current_cost"] = cost + self.user_dict[user].get("current_cost", 0)
+        return self.user_dict[user]["current_cost"]
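
Worth noting: projected_cost prices only the prompt (completion_tokens=0) on top of the user's prior spend, so the budget gate checks whether sending the prompt fits the budget, not what the reply will add. A worked example with made-up numbers, not real model rates:

    # hypothetical: $0.0000015 per prompt token, 10 prompt tokens, $0.50 already spent
    # projected_cost = 10 * 0.0000015 + 0.50 = 0.500015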

@@ -168,7 +168,7 @@ def completion(
         model_alias_map=litellm.model_alias_map,
         completion_call_id=id
     )
-    logging.update_environment_variables(model=model, optional_params=optional_params, litellm_params=litellm_params)
+    logging.update_environment_variables(model=model, user=user, optional_params=optional_params, litellm_params=litellm_params)
     if custom_llm_provider == "azure":
         # azure configs
         openai.api_type = "azure"
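
Threading user through to update_environment_variables is what lets the api_manager callback in utils.py (below) bill the right budget, which suggests a budget-tracked call should pass the user id along, e.g. (id illustrative):

    response = completion(model="gpt-3.5-turbo", messages=messages, user="1234")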

@@ -0,0 +1,60 @@
+#### What this tests ####
+# This tests the litellm APIManager - creating user budgets and gating completion calls on projected cost
+
+import sys, os
+import traceback
+
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the system path
+import litellm
+from litellm import apiManager, completion
+
+litellm.success_callback = ["api_manager"]
+
+
+## Scenario 1: User budget enough to make call
+def test_user_budget_enough():
+    user = "1234"
+    # create a budget for a user
+    apiManager.create_budget(total_budget=10, user=user)
+
+    # check if a given call can be made
+    data = {
+        "model": "gpt-3.5-turbo",
+        "messages": [{"role": "user", "content": "Hey, how's it going?"}],
+    }
+    model = data["model"]
+    messages = data["messages"]
+    if apiManager.projected_cost(**data, user=user) <= apiManager.get_total_budget(user):
+        response = completion(**data)
+    else:
+        response = "Sorry - no budget!"
+    print(f"response: {response}")
+
+
+## Scenario 2: User budget not enough to make call
+def test_user_budget_not_enough():
+    user = "12345"
+    # create a budget for a user
+    apiManager.create_budget(total_budget=0, user=user)
+
+    # check if a given call can be made
+    data = {
+        "model": "gpt-3.5-turbo",
+        "messages": [{"role": "user", "content": "Hey, how's it going?"}],
+    }
+    model = data["model"]
+    messages = data["messages"]
+    projectedCost = apiManager.projected_cost(**data, user=user)
+    print(f"projectedCost: {projectedCost}")
+    totalBudget = apiManager.get_total_budget(user)
+    print(f"totalBudget: {totalBudget}")
+    if projectedCost <= totalBudget:
+        response = completion(**data)
+    else:
+        response = "Sorry - no budget!"
+    print(f"response: {response}")
+
+
+test_user_budget_not_enough()

@@ -117,9 +117,6 @@ def invalid_auth(model):  # set the model key to an invalid key, depending on the model
         os.environ["TOGETHERAI_API_KEY"] = temporary_key
     return
 
-
-invalid_auth(test_model)
-
 # Test 3: Rate Limit Errors
 # def test_model_call(model):
 #     try:

@@ -144,9 +144,10 @@ class Logging:
         self.litellm_call_id = litellm_call_id
         self.function_id = function_id
 
-    def update_environment_variables(self, model, optional_params, litellm_params):
+    def update_environment_variables(self, model, user, optional_params, litellm_params):
         self.optional_params = optional_params
         self.model = model
+        self.user = user
         self.litellm_params = litellm_params
         self.logger_fn = litellm_params["logger_fn"]
         print_verbose(f"self.optional_params: {self.optional_params}")
@@ -298,19 +299,22 @@ class Logging:
             for callback in litellm.success_callback:
                 try:
                     if callback == "lite_debugger":
                         print_verbose("reaches lite_debugger for logging!")
                         print_verbose(f"liteDebuggerClient: {liteDebuggerClient}")
                         print_verbose(f"liteDebuggerClient details function {self.call_type} and stream set to {self.stream}")
                         liteDebuggerClient.log_event(
                             end_user=litellm._thread_context.user,
                             response_obj=result,
                             start_time=start_time,
                             end_time=end_time,
                             litellm_call_id=self.litellm_call_id,
                             print_verbose=print_verbose,
                             call_type = self.call_type,
                             stream = self.stream,
                         )
+                    if callback == "api_manager":
+                        print_verbose("reaches api manager for updating model cost")
+                        litellm.apiManager.update_cost(completion_obj=result, user=self.user)
                     if callback == "cache":
                         # print("entering logger first time")
                         # print(self.litellm_params["stream_response"])
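
With "api_manager" registered in litellm.success_callback, every successful completion passes through here and update_cost adds that response's cost to the caller's running total. The class exposes no getter for current spend, so checking remaining headroom means reading user_dict directly, e.g.:

    remaining = apiManager.get_total_budget(user) - apiManager.user_dict[user].get("current_cost", 0)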

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "litellm"
-version = "0.1.576"
+version = "0.1.577"
 description = "Library to easily interface with LLM API providers"
 authors = ["BerriAI"]
 license = "MIT License"