feat(llm_cost_calc/google.py): do character based cost calculation for vertex ai

Calculate cost for vertex ai responses using characters in query/response

 Closes https://github.com/BerriAI/litellm/issues/4165
This commit is contained in:
Krrish Dholakia 2024-06-19 17:18:42 -07:00
parent cab057da4a
commit 16da21e839
5 changed files with 287 additions and 17 deletions

View file

@ -6,6 +6,9 @@ from typing import List, Literal, Optional, Tuple, Union
import litellm
import litellm._logging
from litellm import verbose_logger
from litellm.litellm_core_utils.llm_cost_calc.google import (
cost_per_character as google_cost_per_character,
)
from litellm.litellm_core_utils.llm_cost_calc.google import (
cost_per_token as google_cost_per_token,
)
@ -23,8 +26,8 @@ from litellm.utils import (
def _cost_per_token_custom_pricing_helper(
prompt_tokens=0,
completion_tokens=0,
prompt_tokens: float = 0,
completion_tokens: float = 0,
response_time_ms=None,
### CUSTOM PRICING ###
custom_cost_per_token: Optional[CostPerToken] = None,
@ -52,6 +55,9 @@ def cost_per_token(
response_time_ms=None,
custom_llm_provider: Optional[str] = None,
region_name=None,
### CHARACTER PRICING ###
prompt_characters: float = 0,
completion_characters: float = 0,
### CUSTOM PRICING ###
custom_cost_per_token: Optional[CostPerToken] = None,
custom_cost_per_second: Optional[float] = None,
@ -64,6 +70,8 @@ def cost_per_token(
prompt_tokens (int): The number of tokens in the prompt.
completion_tokens (int): The number of tokens in the completion.
response_time (float): The amount of time, in milliseconds, it took the call to complete.
prompt_characters (float): The number of characters in the prompt. Used for vertex ai cost calculation.
completion_characters (float): The number of characters in the completion response. Used for vertex ai cost calculation.
custom_llm_provider (str): The llm provider to whom the call was made (see init.py for full list)
custom_cost_per_token: Optional[CostPerToken]: the cost per input + output token for the llm api call.
custom_cost_per_second: Optional[float]: the cost per second for the llm api call.
@ -127,7 +135,16 @@ def cost_per_token(
# see this https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models
print_verbose(f"Looking up model={model} in model_cost_map")
if custom_llm_provider == "vertex_ai" or custom_llm_provider == "gemini":
if custom_llm_provider == "vertex_ai":
return google_cost_per_character(
model=model_without_prefix,
custom_llm_provider=custom_llm_provider,
prompt_characters=prompt_characters,
completion_characters=completion_characters,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
)
elif custom_llm_provider == "gemini":
return google_cost_per_token(
model=model_without_prefix,
custom_llm_provider=custom_llm_provider,
@ -378,7 +395,9 @@ def completion_cost(
model = "dall-e-2" # for dall-e-2, azure expects an empty model name
# Handle Inputs to completion_cost
prompt_tokens = 0
prompt_characters = 0
completion_tokens = 0
completion_characters = 0
custom_llm_provider = None
if completion_response is not None:
# get input/output tokens from completion_response
@ -495,6 +514,30 @@ def completion_cost(
f"Model is None and does not exist in passed completion_response. Passed completion_response={completion_response}, model={model}"
)
if (
custom_llm_provider is not None
and custom_llm_provider == "vertex_ai"
and completion_response is not None
and isinstance(completion_response, ModelResponse)
):
# Calculate the prompt characters + response characters
if len("messages") > 0:
prompt_string = litellm.utils.get_formatted_prompt(
data={"messages": messages}, call_type="completion"
)
else:
prompt_string = ""
prompt_characters = litellm.utils._count_characters(text=prompt_string)
completion_string = litellm.utils.get_response_string(
response_obj=completion_response
)
completion_characters = litellm.utils._count_characters(
text=completion_string
)
(
prompt_tokens_cost_usd_dollar,
completion_tokens_cost_usd_dollar,
@ -507,6 +550,8 @@ def completion_cost(
region_name=region_name,
custom_cost_per_second=custom_cost_per_second,
custom_cost_per_token=custom_cost_per_token,
prompt_characters=prompt_characters,
completion_characters=completion_characters,
)
_final_cost = prompt_tokens_cost_usd_dollar + completion_tokens_cost_usd_dollar
print_verbose(