feat(utils.py): support custom cost tracking per second

https://github.com/BerriAI/litellm/issues/1374

This commit is contained in:
parent 44f756efb5
commit 276a685a59

4 changed files with 74 additions and 31 deletions
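In short: completion() now accepts input_cost_per_second / output_cost_per_second kwargs, registers them via litellm.register_model(), and cost_per_token() / completion_cost() bill by response time when per-second pricing is set. A minimal usage sketch based on the test added in this commit (the model name, max_tokens, and the 0.000420 $/sec rate are the test's values; the message content is illustrative):

import litellm
from litellm import completion, completion_cost

litellm.set_verbose = True

# Per-second pricing passed at call time; completion() registers it for
# this model via litellm.register_model() before dispatching the request.
response = completion(
    model="sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
    max_tokens=80,
    input_cost_per_second=0.000420,
)

# completion_cost() reads the response's _response_ms and bills
# input_cost_per_second * response_time_ms / 1000 (output side defaults to 0.0).
cost = completion_cost(completion_response=response)
print(cost)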
@@ -12,15 +12,6 @@ formatter = logging.Formatter("\033[92m%(name)s - %(levelname)s\033[0m: %(message)s")
 handler.setFormatter(formatter)
 
-
-def print_verbose(print_statement):
-    try:
-        if set_verbose:
-            print(print_statement)  # noqa
-    except:
-        pass
-
-
 
 verbose_proxy_logger = logging.getLogger("LiteLLM Proxy")
 verbose_router_logger = logging.getLogger("LiteLLM Router")
 verbose_logger = logging.getLogger("LiteLLM")
@@ -29,3 +20,18 @@ verbose_logger = logging.getLogger("LiteLLM")
 verbose_router_logger.addHandler(handler)
 verbose_proxy_logger.addHandler(handler)
 verbose_logger.addHandler(handler)
+
+
+def print_verbose(print_statement):
+    try:
+        if set_verbose:
+            print(print_statement)  # noqa
+            verbose_logger.setLevel(level=logging.DEBUG)  # set package log to debug
+            verbose_router_logger.setLevel(
+                level=logging.DEBUG
+            )  # set router logs to debug
+            verbose_proxy_logger.setLevel(
+                level=logging.DEBUG
+            )  # set proxy logs to debug
+    except:
+        pass
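A side effect of the relocated print_verbose worth flagging (inferred from the hunk above; the import path is assumed rather than confirmed by this diff): when the verbose flag is on, the first print_verbose call also drops the three LiteLLM loggers to DEBUG, so later verbose_logger.debug/info calls become visible. A rough sketch:

import logging
from litellm import _logging  # assumed module path for the file patched above

_logging.set_verbose = True      # flag checked inside print_verbose
_logging.print_verbose("hello")  # prints, and per this diff sets the loggers to DEBUG
assert _logging.verbose_logger.level == logging.DEBUG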
@@ -457,6 +457,8 @@ def completion(
     ### CUSTOM MODEL COST ###
     input_cost_per_token = kwargs.get("input_cost_per_token", None)
     output_cost_per_token = kwargs.get("output_cost_per_token", None)
+    input_cost_per_second = kwargs.get("input_cost_per_second", None)
+    output_cost_per_second = kwargs.get("output_cost_per_second", None)
     ### CUSTOM PROMPT TEMPLATE ###
     initial_prompt_value = kwargs.get("initial_prompt_value", None)
     roles = kwargs.get("roles", None)
@@ -596,6 +598,19 @@ def completion(
                     }
                 }
             )
+        if (
+            input_cost_per_second is not None
+        ):  # time based pricing just needs cost in place
+            output_cost_per_second = output_cost_per_second or 0.0
+            litellm.register_model(
+                {
+                    model: {
+                        "input_cost_per_second": input_cost_per_second,
+                        "output_cost_per_second": output_cost_per_second,
+                        "litellm_provider": custom_llm_provider,
+                    }
+                }
+            )
         ### BUILD CUSTOM PROMPT TEMPLATE -- IF GIVEN ###
         custom_prompt_dict = {}  # type: ignore
         if (
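Equivalently, per-second pricing can be registered up front with litellm.register_model(), which is all the branch above does. A sketch; the model name and provider here are placeholders, not values from this commit:

import litellm

litellm.register_model(
    {
        "my-custom-endpoint": {                 # hypothetical model name
            "input_cost_per_second": 0.000420,
            "output_cost_per_second": 0.0,
            "litellm_provider": "sagemaker",    # hypothetical provider
        }
    }
)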
@@ -1372,16 +1372,21 @@ def test_customprompt_together_ai():
 
 
 def test_completion_sagemaker():
     try:
-        print("testing sagemaker")
         litellm.set_verbose = True
+        print("testing sagemaker")
         response = completion(
             model="sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4",
             messages=messages,
             temperature=0.2,
             max_tokens=80,
+            input_cost_per_second=0.000420,
         )
         # Add any assertions here to check the response
         print(response)
+        cost = completion_cost(completion_response=response)
+        assert (
+            cost > 0.0 and cost < 1.0
+        )  # should never be > $1 for a single completion call
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
@@ -829,7 +829,7 @@ class Logging:
                 [f"-H '{k}: {v}'" for k, v in masked_headers.items()]
             )
 
-            print_verbose(f"PRE-API-CALL ADDITIONAL ARGS: {additional_args}")
+            verbose_logger.debug(f"PRE-API-CALL ADDITIONAL ARGS: {additional_args}")
 
             curl_command = "\n\nPOST Request Sent from LiteLLM:\n"
             curl_command += "curl -X POST \\\n"
@@ -995,13 +995,10 @@ class Logging:
             self.model_call_details["log_event_type"] = "post_api_call"
 
             # User Logging -> if you pass in a custom logging function
-            print_verbose(
+            verbose_logger.info(
                 f"RAW RESPONSE:\n{self.model_call_details.get('original_response', self.model_call_details)}\n\n"
             )
-            print_verbose(
-                f"Logging Details Post-API Call: logger_fn - {self.logger_fn} | callable(logger_fn) - {callable(self.logger_fn)}"
-            )
-            print_verbose(
+            verbose_logger.debug(
                 f"Logging Details Post-API Call: LiteLLM Params: {self.model_call_details}"
             )
             if self.logger_fn and callable(self.logger_fn):
@@ -2135,7 +2132,7 @@ def client(original_function):
                 litellm.cache.add_cache(result, *args, **kwargs)
 
             # LOG SUCCESS - handle streaming success logging in the _next_ object, remove `handle_success` once it's deprecated
-            print_verbose(f"Wrapper: Completed Call, calling success_handler")
+            verbose_logger.info(f"Wrapper: Completed Call, calling success_handler")
             threading.Thread(
                 target=logging_obj.success_handler, args=(result, start_time, end_time)
             ).start()
@@ -2807,7 +2804,11 @@ def token_counter(
 
 
 def cost_per_token(
-    model="", prompt_tokens=0, completion_tokens=0, custom_llm_provider=None
+    model="",
+    prompt_tokens=0,
+    completion_tokens=0,
+    response_time_ms=None,
+    custom_llm_provider=None,
 ):
     """
     Calculates the cost per token for a given model, prompt tokens, and completion tokens.
@@ -2829,15 +2830,29 @@ def cost_per_token(
     else:
         model_with_provider = model
     # see this https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models
-    print_verbose(f"Looking up model={model} in model_cost_map")
+    verbose_logger.debug(f"Looking up model={model} in model_cost_map")
 
     if model in model_cost_ref:
-        prompt_tokens_cost_usd_dollar = (
-            model_cost_ref[model]["input_cost_per_token"] * prompt_tokens
-        )
-        completion_tokens_cost_usd_dollar = (
-            model_cost_ref[model]["output_cost_per_token"] * completion_tokens
-        )
+        if (
+            model_cost_ref[model].get("input_cost_per_token", None) is not None
+            and model_cost_ref[model].get("output_cost_per_token", None) is not None
+        ):
+            ## COST PER TOKEN ##
+            prompt_tokens_cost_usd_dollar = (
+                model_cost_ref[model]["input_cost_per_token"] * prompt_tokens
+            )
+            completion_tokens_cost_usd_dollar = (
+                model_cost_ref[model]["output_cost_per_token"] * completion_tokens
+            )
+        elif (
+            model_cost_ref[model].get("input_cost_per_second", None) is not None
+            and response_time_ms is not None
+        ):
+            ## COST PER SECOND ##
+            prompt_tokens_cost_usd_dollar = (
+                model_cost_ref[model]["input_cost_per_second"] * response_time_ms / 1000
+            )
+            completion_tokens_cost_usd_dollar = 0.0
         return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar
     elif model_with_provider in model_cost_ref:
         print_verbose(f"Looking up model={model_with_provider} in model_cost_map")
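The per-second branch above only bills the input side: prompt cost = input_cost_per_second * response_time_ms / 1000, and the completion cost is pinned to 0.0. A quick sanity check using the rate from the new sagemaker test and a made-up 5-second response time:

input_cost_per_second = 0.000420   # $/sec, the rate used in the new test
response_time_ms = 5000            # hypothetical latency
prompt_cost = input_cost_per_second * response_time_ms / 1000
completion_cost_usd = 0.0
print(prompt_cost + completion_cost_usd)   # 0.0021 -> $0.0021 for the call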
@@ -2939,6 +2954,7 @@ def completion_cost(
         completion_tokens = completion_response.get("usage", {}).get(
             "completion_tokens", 0
         )
+        total_time = completion_response.get("_response_ms", 0)
         model = (
             model or completion_response["model"]
         )  # check if user passed an override for model, if it's none check completion_response['model']
@@ -2976,6 +2992,7 @@ def completion_cost(
             prompt_tokens=prompt_tokens,
             completion_tokens=completion_tokens,
             custom_llm_provider=custom_llm_provider,
+            response_time_ms=total_time,
         )
         return prompt_tokens_cost_usd_dollar + completion_tokens_cost_usd_dollar
     except Exception as e:
@@ -3006,9 +3023,7 @@ def register_model(model_cost: Union[str, dict]):
 
     for key, value in loaded_model_cost.items():
         ## override / add new keys to the existing model cost dictionary
-        if key in litellm.model_cost:
-            for k, v in loaded_model_cost[key].items():
-                litellm.model_cost[key][k] = v
+        litellm.model_cost.setdefault(key, {}).update(value)
         # add new model names to provider lists
         if value.get("litellm_provider") == "openai":
            if key not in litellm.open_ai_chat_completion_models:
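Worth noting about the register_model() change just above: the old loop only merged keys for models already present in litellm.model_cost, while setdefault(key, {}).update(value) also creates an entry for a previously unknown model name, which is what per-call pricing for custom endpoints needs. A small illustration of that dict behavior (standalone, not litellm code; both model names are hypothetical):

model_cost = {"gpt-x": {"input_cost_per_token": 1e-6}}

# Existing entry: keys are merged / overridden in place.
model_cost.setdefault("gpt-x", {}).update({"output_cost_per_token": 2e-6})

# New entry (e.g. a custom endpoint): created on the fly.
model_cost.setdefault("my-endpoint", {}).update({"input_cost_per_second": 0.000420})

print(model_cost["gpt-x"])        # {'input_cost_per_token': 1e-06, 'output_cost_per_token': 2e-06}
print(model_cost["my-endpoint"])  # {'input_cost_per_second': 0.00042}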
@@ -3301,11 +3316,13 @@ def get_optional_params(
     )
 
     def _check_valid_arg(supported_params):
-        print_verbose(
+        verbose_logger.debug(
             f"\nLiteLLM completion() model= {model}; provider = {custom_llm_provider}"
         )
-        print_verbose(f"\nLiteLLM: Params passed to completion() {passed_params}")
-        print_verbose(
+        verbose_logger.debug(
+            f"\nLiteLLM: Params passed to completion() {passed_params}"
+        )
+        verbose_logger.debug(
             f"\nLiteLLM: Non-Default params passed to completion() {non_default_params}"
         )
         unsupported_params = {}