feat(llm_cost_calc/google.py): do character based cost calculation for vertex ai

Calculate cost for Vertex AI responses from the character counts of the query and response

Closes https://github.com/BerriAI/litellm/issues/4165
Krrish Dholakia 2024-06-19 17:18:42 -07:00
parent 8bb304997a
commit edfe550165
5 changed files with 287 additions and 17 deletions
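Vertex AI prices many of its text models (Gemini, PaLM) per character rather than per token, so this change derives cost from character counts. A minimal sketch of the pricing math, using hypothetical per-character rates (the real rates come from litellm's model cost map):

    # Minimal sketch of character-based pricing; the rates below are
    # hypothetical stand-ins for the per-character prices in the cost map.
    input_cost_per_character = 0.125 / 1_000_000   # hypothetical $ per character
    output_cost_per_character = 0.375 / 1_000_000  # hypothetical $ per character

    prompt_characters = 1_200    # whitespace-stripped characters in the query
    completion_characters = 800  # whitespace-stripped characters in the response

    total_cost = (
        prompt_characters * input_cost_per_character
        + completion_characters * output_cost_per_character
    )  # 0.00045 dollars for this hypothetical call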


@@ -3810,6 +3810,12 @@ def get_supported_openai_params(
     return None
 
 
+def _count_characters(text: str) -> int:
+    # Remove white spaces and count characters
+    filtered_text = "".join(char for char in text if not char.isspace())
+    return len(filtered_text)
+
+
 def get_formatted_prompt(
     data: dict,
     call_type: Literal[
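The new _count_characters helper strips all whitespace before counting, which lines up with Google excluding whitespace from billable character counts. For example:

    # _count_characters drops every whitespace character, not just spaces:
    _count_characters("Hello, world!")  # 12 -- counts "Hello,world!"
    _count_characters("a\tb\nc d")      # 4  -- counts "abcd"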
@@ -3828,9 +3834,20 @@ def get_formatted_prompt(
     """
     prompt = ""
     if call_type == "completion":
-        for m in data["messages"]:
-            if "content" in m and isinstance(m["content"], str):
-                prompt += m["content"]
+        for message in data["messages"]:
+            if message.get("content", None) is not None:
+                content = message.get("content")
+                if isinstance(content, str):
+                    prompt += message["content"]
+                elif isinstance(content, List):
+                    for c in content:
+                        if c["type"] == "text":
+                            prompt += c["text"]
+            if "tool_calls" in message:
+                for tool_call in message["tool_calls"]:
+                    if "function" in tool_call:
+                        function_arguments = tool_call["function"]["arguments"]
+                        prompt += function_arguments
     elif call_type == "text_completion":
         prompt = data["prompt"]
     elif call_type == "embedding" or call_type == "moderation":
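get_formatted_prompt now also walks list-style content blocks and tool-call arguments, so multimodal and function-calling requests contribute their text to the character count. A hypothetical request illustrating both paths:

    # Hypothetical request: the text block and the tool-call arguments
    # both end up in the returned prompt string.
    data = {
        "messages": [
            {
                "role": "user",
                "content": [{"type": "text", "text": "What's the weather in SF?"}],
            },
            {
                "role": "assistant",
                "content": None,
                "tool_calls": [
                    {"function": {"name": "get_weather", "arguments": '{"city": "SF"}'}}
                ],
            },
        ]
    }
    prompt = get_formatted_prompt(data=data, call_type="completion")
    # prompt is "What's the weather in SF?" followed by '{"city": "SF"}'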
@@ -3847,6 +3864,17 @@ def get_formatted_prompt(
     return prompt
 
 
+def get_response_string(response_obj: ModelResponse) -> str:
+    _choices: List[Choices] = response_obj.choices  # type: ignore
+
+    response_str = ""
+    for choice in _choices:
+        if choice.message.content is not None:
+            response_str += choice.message.content
+
+    return response_str
+
+
 def _is_non_openai_azure_model(model: str) -> bool:
     try:
         model_name = model.split("/", 1)[1]
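get_response_string is the response-side counterpart: it concatenates the assistant text across all choices, skipping choices whose content is None. The equivalent logic on plain dicts (the real helper operates on a litellm ModelResponse):

    # Sketch of the same logic on plain dicts rather than litellm types.
    choices = [
        {"message": {"content": "It is "}},
        {"message": {"content": None}},
        {"message": {"content": "sunny in SF."}},
    ]
    response_str = "".join(c["message"]["content"] or "" for c in choices)
    # response_str == "It is sunny in SF."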
@@ -4392,13 +4420,22 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod
                 max_input_tokens=_model_info.get("max_input_tokens", None),
                 max_output_tokens=_model_info.get("max_output_tokens", None),
                 input_cost_per_token=_model_info.get("input_cost_per_token", 0),
+                input_cost_per_character=_model_info.get(
+                    "input_cost_per_character", None
+                ),
                 input_cost_per_token_above_128k_tokens=_model_info.get(
                     "input_cost_per_token_above_128k_tokens", None
                 ),
                 output_cost_per_token=_model_info.get("output_cost_per_token", 0),
+                output_cost_per_character=_model_info.get(
+                    "output_cost_per_character", None
+                ),
                 output_cost_per_token_above_128k_tokens=_model_info.get(
                     "output_cost_per_token_above_128k_tokens", None
                 ),
+                output_cost_per_character_above_128k_tokens=_model_info.get(
+                    "output_cost_per_character_above_128k_tokens", None
+                ),
                 litellm_provider=_model_info.get(
                     "litellm_provider", custom_llm_provider
                 ),
@@ -4426,13 +4463,22 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod
                 max_input_tokens=_model_info.get("max_input_tokens", None),
                 max_output_tokens=_model_info.get("max_output_tokens", None),
                 input_cost_per_token=_model_info.get("input_cost_per_token", 0),
+                input_cost_per_character=_model_info.get(
+                    "input_cost_per_character", None
+                ),
                 input_cost_per_token_above_128k_tokens=_model_info.get(
                     "input_cost_per_token_above_128k_tokens", None
                 ),
                 output_cost_per_token=_model_info.get("output_cost_per_token", 0),
+                output_cost_per_character=_model_info.get(
+                    "output_cost_per_character", None
+                ),
                 output_cost_per_token_above_128k_tokens=_model_info.get(
                     "output_cost_per_token_above_128k_tokens", None
                 ),
+                output_cost_per_character_above_128k_tokens=_model_info.get(
+                    "output_cost_per_character_above_128k_tokens", None
+                ),
                 litellm_provider=_model_info.get(
                     "litellm_provider", custom_llm_provider
                 ),
@@ -4460,13 +4506,22 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod
                 max_input_tokens=_model_info.get("max_input_tokens", None),
                 max_output_tokens=_model_info.get("max_output_tokens", None),
                 input_cost_per_token=_model_info.get("input_cost_per_token", 0),
+                input_cost_per_character=_model_info.get(
+                    "input_cost_per_character", None
+                ),
                 input_cost_per_token_above_128k_tokens=_model_info.get(
                     "input_cost_per_token_above_128k_tokens", None
                 ),
                 output_cost_per_token=_model_info.get("output_cost_per_token", 0),
+                output_cost_per_character=_model_info.get(
+                    "output_cost_per_character", None
+                ),
                 output_cost_per_token_above_128k_tokens=_model_info.get(
                     "output_cost_per_token_above_128k_tokens", None
                 ),
+                output_cost_per_character_above_128k_tokens=_model_info.get(
+                    "output_cost_per_character_above_128k_tokens", None
+                ),
                 litellm_provider=_model_info.get(
                     "litellm_provider", custom_llm_provider
                 ),
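Taken together, these new ModelInfo fields let the Vertex AI cost path (llm_cost_calc/google.py per the commit title) prefer per-character rates and fall back to per-token pricing when a model has no per-character entry. A hedged sketch of that selection logic; the helper name and wiring are hypothetical:

    # Hedged sketch: prefer per-character pricing when ModelInfo provides it,
    # otherwise fall back to per-token pricing. Helper name is hypothetical.
    def _input_cost(model_info: dict, prompt_characters: int, prompt_tokens: int) -> float:
        per_char = model_info.get("input_cost_per_character")
        if per_char is not None:
            return prompt_characters * per_char
        return prompt_tokens * model_info.get("input_cost_per_token", 0)

    # e.g. model_info = get_model_info("gemini-1.0-pro", custom_llm_provider="vertex_ai")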