mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 02:34:29 +00:00
commit b12a9892b7
Author: Krrish Dholakia <krrishdholakia@gmail.com> Date: Wed Apr 2 08:09:56 2025 -0700 fix(utils.py): don't modify openai_token_counter commit 294de31803
Author: Krrish Dholakia <krrishdholakia@gmail.com> Date: Mon Mar 24 21:22:40 2025 -0700 fix: fix linting error commit cb6e9fbe40
Author: Krrish Dholakia <krrishdholakia@gmail.com> Date: Mon Mar 24 19:52:45 2025 -0700 refactor: complete migration commit bfc159172d
Author: Krrish Dholakia <krrishdholakia@gmail.com> Date: Mon Mar 24 19:09:59 2025 -0700 refactor: refactor more constants commit 43ffb6a558
Author: Krrish Dholakia <krrishdholakia@gmail.com> Date: Mon Mar 24 18:45:24 2025 -0700 fix: test commit 04dbe4310c
Author: Krrish Dholakia <krrishdholakia@gmail.com> Date: Mon Mar 24 18:28:58 2025 -0700 refactor: refactor: move more constants into constants.py commit 3c26284aff
Author: Krrish Dholakia <krrishdholakia@gmail.com> Date: Mon Mar 24 18:14:46 2025 -0700 refactor: migrate hardcoded constants out of __init__.py commit c11e0de69d
Author: Krrish Dholakia <krrishdholakia@gmail.com> Date: Mon Mar 24 18:11:21 2025 -0700 build: migrate all constants into constants.py commit 7882bdc787
Author: Krrish Dholakia <krrishdholakia@gmail.com> Date: Mon Mar 24 18:07:37 2025 -0700 build: initial test banning hardcoded numbers in repo
84 lines
2.7 KiB
Python
84 lines
2.7 KiB
Python
"""
|
|
For calculating cost of fireworks ai serverless inference models.
|
|
"""
|
|
|
|
from typing import Tuple
|
|
|
|
from litellm.constants import (
|
|
FIREWORKS_AI_16_B,
|
|
FIREWORKS_AI_56_B_MOE,
|
|
FIREWORKS_AI_80_B,
|
|
FIREWORKS_AI_176_B_MOE,
|
|
)
|
|
from litellm.types.utils import Usage
|
|
from litellm.utils import get_model_info
|
|
|
|
|
|
# Extract the number of billion parameters from the model name
# and map it to a fireworks.ai serverless-inference pricing tier
|
def get_base_model_for_pricing(model_name: str) -> str:
    """
    Map a fireworks_ai model name onto its serverless-inference pricing tier.

    The tier is derived from the parameter count embedded in the model name:
    MoE models named "<experts>x<size>b" are bucketed by total parameters,
    dense models named "<size>b" by their single parameter count.

    Returns:
        - str: model pricing category if mapped else "fireworks-ai-default"
    """
    import re

    normalized = model_name.lower()

    # MoE naming convention: "<experts>x<per-expert-size>b" (e.g. "8x7b").
    moe = re.search(r"(\d+)x(\d+)b", normalized)
    if moe is not None:
        total_params = int(moe.group(1)) * int(moe.group(2))
        if total_params <= FIREWORKS_AI_56_B_MOE:
            return "fireworks-ai-moe-up-to-56b"
        if total_params <= FIREWORKS_AI_176_B_MOE:
            return "fireworks-ai-56b-to-176b"

    # Dense naming convention: "<size>b" (e.g. "70b").
    dense = re.search(r"(\d+)b", normalized)
    if dense is not None:
        size_billion = float(dense.group(1))
        if size_billion <= FIREWORKS_AI_16_B:
            return "fireworks-ai-up-to-16b"
        if size_billion <= FIREWORKS_AI_80_B:
            return "fireworks-ai-16b-80b"

    # No parseable size (or larger than any mapped tier): default pricing.
    return "fireworks-ai-default"
|
|
|
|
|
|
def cost_per_token(model: str, usage: Usage) -> Tuple[float, float]:
    """
    Calculates the cost per token for a given model, prompt tokens, and completion tokens.

    Input:
        - model: str, the model name without provider prefix
        - usage: LiteLLM Usage block, containing anthropic caching information

    Returns:
        Tuple[float, float] - prompt_cost_in_usd, completion_cost_in_usd
    """
    ## check if model mapped, else use default pricing
    try:
        # Model name is directly present in the pricing map.
        model_info = get_model_info(model=model, custom_llm_provider="fireworks_ai")
    except Exception:
        # Unmapped model: fall back to a size-derived pricing tier.
        pricing_tier = get_base_model_for_pricing(model_name=model)
        model_info = get_model_info(
            model=pricing_tier, custom_llm_provider="fireworks_ai"
        )

    # Per-token USD rates from the resolved pricing entry.
    input_rate = model_info["input_cost_per_token"]
    output_rate = model_info["output_cost_per_token"]

    prompt_cost: float = usage["prompt_tokens"] * input_rate
    completion_cost: float = usage["completion_tokens"] * output_rate
    return prompt_cost, completion_cost
|