Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 18:54:30 +00:00)

Merge pull request #4295 from BerriAI/litellm_gemini_pricing_2
Vertex AI - character based cost calculation

Commit 71716bec48
5 changed files with 287 additions and 17 deletions
@@ -3811,6 +3811,12 @@ def get_supported_openai_params(
     return None


+def _count_characters(text: str) -> int:
+    # Remove white spaces and count characters
+    filtered_text = "".join(char for char in text if not char.isspace())
+    return len(filtered_text)
+
+
 def get_formatted_prompt(
     data: dict,
     call_type: Literal[
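For context, the new helper strips all whitespace before counting, which matches Vertex AI's character-based billing, where whitespace is not counted. A minimal standalone check (the assert is illustrative, not part of the PR):

def _count_characters(text: str) -> int:
    # Remove white spaces and count characters
    filtered_text = "".join(char for char in text if not char.isspace())
    return len(filtered_text)

# "hello world\n" has 12 characters, but only 10 are billable
assert _count_characters("hello world\n") == 10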
@@ -3829,9 +3835,20 @@ def get_formatted_prompt
     """
     prompt = ""
     if call_type == "completion":
-        for m in data["messages"]:
-            if "content" in m and isinstance(m["content"], str):
-                prompt += m["content"]
+        for message in data["messages"]:
+            if message.get("content", None) is not None:
+                content = message.get("content")
+                if isinstance(content, str):
+                    prompt += message["content"]
+                elif isinstance(content, List):
+                    for c in content:
+                        if c["type"] == "text":
+                            prompt += c["text"]
+            if "tool_calls" in message:
+                for tool_call in message["tool_calls"]:
+                    if "function" in tool_call:
+                        function_arguments = tool_call["function"]["arguments"]
+                        prompt += function_arguments
     elif call_type == "text_completion":
         prompt = data["prompt"]
     elif call_type == "embedding" or call_type == "moderation":
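The reworked completion branch now also counts multi-part content and tool-call arguments toward the prompt. A self-contained sketch of the same logic on a hypothetical messages payload (the payload below is made up for illustration; "tool_calls" follows the OpenAI chat format):

data = {
    "messages": [
        {"role": "user", "content": "What is 2+2?"},
        {
            "role": "assistant",
            "content": None,
            "tool_calls": [
                {"function": {"name": "calc", "arguments": '{"expr": "2+2"}'}}
            ],
        },
        {"role": "user", "content": [{"type": "text", "text": "thanks"}]},
    ]
}

prompt = ""
for message in data["messages"]:
    if message.get("content", None) is not None:
        content = message.get("content")
        if isinstance(content, str):
            prompt += content
        elif isinstance(content, list):
            for c in content:
                if c["type"] == "text":
                    prompt += c["text"]
    if "tool_calls" in message:
        for tool_call in message["tool_calls"]:
            if "function" in tool_call:
                prompt += tool_call["function"]["arguments"]

print(prompt)  # What is 2+2?{"expr": "2+2"}thanks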
@@ -3848,6 +3865,17 @@ def get_formatted_prompt
     return prompt


+def get_response_string(response_obj: ModelResponse) -> str:
+    _choices: List[Choices] = response_obj.choices  # type: ignore
+
+    response_str = ""
+    for choice in _choices:
+        if choice.message.content is not None:
+            response_str += choice.message.content
+
+    return response_str
+
+
 def _is_non_openai_azure_model(model: str) -> bool:
     try:
         model_name = model.split("/", 1)[1]
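get_response_string is the output-side counterpart: it concatenates the message content of every choice so the result can be character-counted. A sketch of how it could be used once this lands (assuming the helper is importable from litellm.utils; mock_response is litellm's built-in way to fake a completion):

import litellm
from litellm.utils import get_response_string

response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Say hi"}],
    mock_response="Hi there!",
)
print(get_response_string(response))  # Hi there!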
@@ -4394,13 +4422,22 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod
             max_input_tokens=_model_info.get("max_input_tokens", None),
             max_output_tokens=_model_info.get("max_output_tokens", None),
             input_cost_per_token=_model_info.get("input_cost_per_token", 0),
+            input_cost_per_character=_model_info.get(
+                "input_cost_per_character", None
+            ),
             input_cost_per_token_above_128k_tokens=_model_info.get(
                 "input_cost_per_token_above_128k_tokens", None
             ),
             output_cost_per_token=_model_info.get("output_cost_per_token", 0),
+            output_cost_per_character=_model_info.get(
+                "output_cost_per_character", None
+            ),
             output_cost_per_token_above_128k_tokens=_model_info.get(
                 "output_cost_per_token_above_128k_tokens", None
             ),
+            output_cost_per_character_above_128k_tokens=_model_info.get(
+                "output_cost_per_character_above_128k_tokens", None
+            ),
             litellm_provider=_model_info.get(
                 "litellm_provider", custom_llm_provider
             ),
@@ -4428,13 +4465,22 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod
             max_input_tokens=_model_info.get("max_input_tokens", None),
             max_output_tokens=_model_info.get("max_output_tokens", None),
             input_cost_per_token=_model_info.get("input_cost_per_token", 0),
+            input_cost_per_character=_model_info.get(
+                "input_cost_per_character", None
+            ),
             input_cost_per_token_above_128k_tokens=_model_info.get(
                 "input_cost_per_token_above_128k_tokens", None
             ),
             output_cost_per_token=_model_info.get("output_cost_per_token", 0),
+            output_cost_per_character=_model_info.get(
+                "output_cost_per_character", None
+            ),
             output_cost_per_token_above_128k_tokens=_model_info.get(
                 "output_cost_per_token_above_128k_tokens", None
             ),
+            output_cost_per_character_above_128k_tokens=_model_info.get(
+                "output_cost_per_character_above_128k_tokens", None
+            ),
             litellm_provider=_model_info.get(
                 "litellm_provider", custom_llm_provider
             ),
@@ -4462,13 +4508,22 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod
             max_input_tokens=_model_info.get("max_input_tokens", None),
             max_output_tokens=_model_info.get("max_output_tokens", None),
             input_cost_per_token=_model_info.get("input_cost_per_token", 0),
+            input_cost_per_character=_model_info.get(
+                "input_cost_per_character", None
+            ),
             input_cost_per_token_above_128k_tokens=_model_info.get(
                 "input_cost_per_token_above_128k_tokens", None
             ),
             output_cost_per_token=_model_info.get("output_cost_per_token", 0),
+            output_cost_per_character=_model_info.get(
+                "output_cost_per_character", None
+            ),
             output_cost_per_token_above_128k_tokens=_model_info.get(
                 "output_cost_per_token_above_128k_tokens", None
             ),
+            output_cost_per_character_above_128k_tokens=_model_info.get(
+                "output_cost_per_character_above_128k_tokens", None
+            ),
             litellm_provider=_model_info.get(
                 "litellm_provider", custom_llm_provider
             ),
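The three hunks above add the same character-rate fields to each provider branch of get_model_info, so ModelInfo can carry per-character prices alongside per-token ones. A hypothetical end-to-end sketch of how these pieces could combine (the helper name and rates below are made up; the PR's actual cost logic lives in the other changed files and also applies the above-128k rates):

def character_cost(
    prompt: str,
    completion: str,
    input_cost_per_character: float,
    output_cost_per_character: float,
) -> float:
    # Whitespace is not billed, matching _count_characters above
    def _count(text: str) -> int:
        return len("".join(ch for ch in text if not ch.isspace()))

    return (
        _count(prompt) * input_cost_per_character
        + _count(completion) * output_cost_per_character
    )

# Made-up rates: 10 billable prompt chars, 9 billable completion chars
print(character_cost("hello world", "hi friend!", 1.25e-06, 3.75e-06))
# 4.625e-05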