Merge pull request #4245 from BerriAI/litellm_gemini_pricing_token_counter

VertexAI/Gemini: Calculate cost based on context window
Krish Dholakia 2024-06-17 16:35:23 -07:00 committed by GitHub
commit 02d9d96141
5 changed files with 247 additions and 51 deletions

@@ -1,20 +1,24 @@
# What is this?
## File for 'response_cost' calculation in Logging
from typing import Optional, Union, Literal, List, Tuple
from typing import List, Literal, Optional, Tuple, Union
import litellm
import litellm._logging
from litellm import verbose_logger
from litellm.litellm_core_utils.llm_cost_calc.google import (
cost_per_token as google_cost_per_token,
)
from litellm.utils import (
ModelResponse,
CallTypes,
CostPerToken,
EmbeddingResponse,
ImageResponse,
TranscriptionResponse,
ModelResponse,
TextCompletionResponse,
CallTypes,
TranscriptionResponse,
print_verbose,
CostPerToken,
token_counter,
)
import litellm
from litellm import verbose_logger
def _cost_per_token_custom_pricing_helper(
@@ -42,10 +46,10 @@ def _cost_per_token_custom_pricing_helper(
def cost_per_token(
model: str = "",
prompt_tokens=0,
completion_tokens=0,
prompt_tokens: float = 0,
completion_tokens: float = 0,
response_time_ms=None,
custom_llm_provider=None,
custom_llm_provider: Optional[str] = None,
region_name=None,
### CUSTOM PRICING ###
custom_cost_per_token: Optional[CostPerToken] = None,
@@ -66,6 +70,7 @@ def cost_per_token(
Returns:
tuple: A tuple containing the cost in USD for prompt tokens and completion tokens, respectively.
"""
args = locals()
if model is None:
raise Exception("Invalid arg. Model cannot be none.")
## CUSTOM PRICING ##
@@ -94,7 +99,8 @@ def cost_per_token(
model_with_provider_and_region in model_cost_ref
): # use region based pricing, if it's available
model_with_provider = model_with_provider_and_region
else:
_, custom_llm_provider, _, _ = litellm.get_llm_provider(model=model)
model_without_prefix = model
model_parts = model.split("/")
if len(model_parts) > 1:
@@ -120,7 +126,14 @@ def cost_per_token(
# see this https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models
print_verbose(f"Looking up model={model} in model_cost_map")
if model in model_cost_ref:
if custom_llm_provider == "vertex_ai" or custom_llm_provider == "gemini":
return google_cost_per_token(
model=model_without_prefix,
custom_llm_provider=custom_llm_provider,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
)
elif model in model_cost_ref:
print_verbose(f"Success: model={model} in model_cost_map")
print_verbose(
f"prompt_tokens={prompt_tokens}; completion_tokens={completion_tokens}"

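The hunk above adds the routing: when custom_llm_provider is "vertex_ai" or "gemini", cost_per_token hands the calculation to google_cost_per_token instead of the flat per-token lookup in the model cost map. A minimal caller-side sketch, assuming only the litellm.cost_per_token signature shown in this diff and a Gemini model that is present in the cost map:

# Hedged sketch: exercising the new vertex_ai/gemini branch via the public API.
from litellm import cost_per_token

prompt_cost_usd, completion_cost_usd = cost_per_token(
    model="gemini-1.5-flash-latest",
    custom_llm_provider="gemini",  # triggers the google_cost_per_token path above
    prompt_tokens=1_000,
    completion_tokens=250,
)
print(prompt_cost_usd, completion_cost_usd)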
@@ -0,0 +1,82 @@
# What is this?
## Cost calculation for Google AI Studio / Vertex AI models
from typing import Literal, Tuple
import litellm
"""
Gemini pricing covers:
- token
- image
- audio
- video
"""
models_without_dynamic_pricing = ["gemini-1.0-pro", "gemini-pro"]
def _is_above_128k(tokens: float) -> bool:
if tokens > 128000:
return True
return False
def cost_per_token(
model: str,
custom_llm_provider: str,
prompt_tokens: float,
completion_tokens: float,
) -> Tuple[float, float]:
"""
Calculates the cost per token for a given model, prompt tokens, and completion tokens.
Input:
- model: str, the model name without provider prefix
- custom_llm_provider: str, either "vertex_ai-*" or "gemini"
- prompt_tokens: float, the number of input tokens
- completion_tokens: float, the number of output tokens
Returns:
Tuple[float, float] - prompt_cost_in_usd, completion_cost_in_usd
Raises:
Exception if the model requires >128k pricing but no above-128k cost is mapped for it
"""
## GET MODEL INFO
model_info = litellm.get_model_info(
model=model, custom_llm_provider=custom_llm_provider
)
## CALCULATE INPUT COST
if (
_is_above_128k(tokens=prompt_tokens)
and model not in models_without_dynamic_pricing
):
assert (
model_info["input_cost_per_token_above_128k_tokens"] is not None
), "model info for model={} does not have pricing for > 128k tokens\nmodel_info={}".format(
model, model_info
)
prompt_cost = (
prompt_tokens * model_info["input_cost_per_token_above_128k_tokens"]
)
else:
prompt_cost = prompt_tokens * model_info["input_cost_per_token"]
## CALCULATE OUTPUT COST
if (
_is_above_128k(tokens=completion_tokens)
and model not in models_without_dynamic_pricing
):
assert (
model_info["output_cost_per_token_above_128k_tokens"] is not None
), "model info for model={} does not have pricing for > 128k tokens\nmodel_info={}".format(
model, model_info
)
completion_cost = (
completion_tokens * model_info["output_cost_per_token_above_128k_tokens"]
)
else:
completion_cost = completion_tokens * model_info["output_cost_per_token"]
return prompt_cost, completion_cost
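
The new module keeps the context-window logic in one place: it only depends on litellm.get_model_info and the *_above_128k_tokens fields added to ModelInfo later in this commit. A small usage sketch of the helper called directly (module path taken from the import in the first file; token counts are illustrative):

# Hedged sketch: calling the Google AI Studio / Vertex AI cost helper directly.
from litellm.litellm_core_utils.llm_cost_calc.google import cost_per_token

# Below the 128k threshold, the base input/output per-token pricing applies.
prompt_cost, completion_cost = cost_per_token(
    model="gemini-1.5-flash-latest",
    custom_llm_provider="gemini",
    prompt_tokens=1_000.0,
    completion_tokens=200.0,
)

# Above the 128k threshold, the *_above_128k_tokens pricing applies instead;
# the asserts in the module fail if that pricing is not mapped for the model.
prompt_cost_128k, completion_cost_128k = cost_per_token(
    model="gemini-1.5-flash-latest",
    custom_llm_provider="gemini",
    prompt_tokens=130_000.0,
    completion_tokens=200.0,
)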

@@ -1,20 +1,28 @@
import sys, os
import os
import sys
import traceback
import litellm.cost_calculator
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import asyncio
import time
from typing import Optional
import pytest
import litellm
from litellm import (
TranscriptionResponse,
completion_cost,
cost_per_token,
get_max_tokens,
model_cost,
open_ai_chat_completion_models,
TranscriptionResponse,
)
from litellm.litellm_core_utils.litellm_logging import CustomLogger
import pytest, asyncio
class CustomLoggingHandler(CustomLogger):
@@ -66,7 +74,7 @@ async def test_custom_pricing(sync_mode):
def test_custom_pricing_as_completion_cost_param():
from litellm import ModelResponse, Choices, Message
from litellm import Choices, Message, ModelResponse
from litellm.utils import Usage
resp = ModelResponse(
@@ -134,7 +142,7 @@ def test_cost_ft_gpt_35():
try:
# this tests if litellm.completion_cost can calculate cost for ft:gpt-3.5-turbo:my-org:custom_suffix:id
# it needs to look up ft:gpt-3.5-turbo in the litellm model_cost map to get the correct cost
from litellm import ModelResponse, Choices, Message
from litellm import Choices, Message, ModelResponse
from litellm.utils import Usage
resp = ModelResponse(
@@ -179,7 +187,7 @@ def test_cost_azure_gpt_35():
try:
# this tests if litellm.completion_cost can calculate cost for azure/chatgpt-deployment-2 which maps to azure/gpt-3.5-turbo
# for this test we check if passing `model` to completion_cost overrides the completion cost
from litellm import ModelResponse, Choices, Message
from litellm import Choices, Message, ModelResponse
from litellm.utils import Usage
resp = ModelResponse(
@@ -266,7 +274,7 @@ def test_cost_bedrock_pricing():
"""
- get pricing specific to region for a model
"""
from litellm import ModelResponse, Choices, Message
from litellm import Choices, Message, ModelResponse
from litellm.utils import Usage
litellm.set_verbose = True
@@ -475,13 +483,13 @@ def test_replicate_llama3_cost_tracking():
@pytest.mark.parametrize("is_streaming", [True, False]) #
def test_groq_response_cost_tracking(is_streaming):
from litellm.utils import (
ModelResponse,
Choices,
Message,
Usage,
CallTypes,
StreamingChoices,
Choices,
Delta,
Message,
ModelResponse,
StreamingChoices,
Usage,
)
response = ModelResponse(
@@ -565,3 +573,58 @@ def test_together_ai_qwen_completion_cost():
)
assert response == "together-ai-41.1b-80b"
@pytest.mark.parametrize("above_128k", [False, True])
@pytest.mark.parametrize("provider", ["vertex_ai", "gemini"])
def test_gemini_completion_cost(above_128k, provider):
"""
Check that cost is correctly calculated for Gemini models based on the context window
"""
os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
litellm.model_cost = litellm.get_model_cost_map(url="")
if provider == "gemini":
model_name = "gemini-1.5-flash-latest"
else:
model_name = "gemini-1.5-flash-preview-0514"
if above_128k:
prompt_tokens = 128001.0
output_tokens = 228001.0
else:
prompt_tokens = 128.0
output_tokens = 228.0
## GET MODEL FROM LITELLM.MODEL_INFO
model_info = litellm.get_model_info(model=model_name, custom_llm_provider=provider)
## EXPECTED COST
if above_128k:
assert (
model_info["input_cost_per_token_above_128k_tokens"] is not None
), "model info for model={} does not have pricing for > 128k tokens\nmodel_info={}".format(
model_name, model_info
)
assert (
model_info["output_cost_per_token_above_128k_tokens"] is not None
), "model info for model={} does not have pricing for > 128k tokens\nmodel_info={}".format(
model_name, model_info
)
input_cost = (
prompt_tokens * model_info["input_cost_per_token_above_128k_tokens"]
)
output_cost = (
output_tokens * model_info["output_cost_per_token_above_128k_tokens"]
)
else:
input_cost = prompt_tokens * model_info["input_cost_per_token"]
output_cost = output_tokens * model_info["output_cost_per_token"]
## CALCULATED COST
calculated_input_cost, calculated_output_cost = cost_per_token(
model=model_name,
prompt_tokens=prompt_tokens,
completion_tokens=output_tokens,
custom_llm_provider=provider,
)
assert calculated_input_cost == input_cost
assert calculated_output_cost == output_cost

@@ -1,14 +1,15 @@
from typing import List, Optional, Union, Dict, Tuple, Literal
from typing_extensions import TypedDict
from enum import Enum
from typing_extensions import override, Required, Dict
from .llms.openai import ChatCompletionUsageBlock, ChatCompletionToolCallChunk
from ..litellm_core_utils.core_helpers import map_finish_reason
from openai._models import BaseModel as OpenAIObject
from pydantic import ConfigDict
import uuid
import json
import time
import uuid
from enum import Enum
from typing import Dict, List, Literal, Optional, Tuple, Union
from openai._models import BaseModel as OpenAIObject
from pydantic import ConfigDict
from typing_extensions import Dict, Required, TypedDict, override
from ..litellm_core_utils.core_helpers import map_finish_reason
from .llms.openai import ChatCompletionToolCallChunk, ChatCompletionUsageBlock
def _generate_id(): # private helper function
@@ -34,21 +35,31 @@ class ProviderField(TypedDict):
field_value: str
class ModelInfo(TypedDict):
class ModelInfo(TypedDict, total=False):
"""
Model info for a given model; this is the information found in litellm.model_prices_and_context_window.json
"""
max_tokens: Optional[int]
max_input_tokens: Optional[int]
max_output_tokens: Optional[int]
input_cost_per_token: float
output_cost_per_token: float
litellm_provider: str
mode: Literal[
max_tokens: Required[Optional[int]]
max_input_tokens: Required[Optional[int]]
max_output_tokens: Required[Optional[int]]
input_cost_per_token: Required[float]
input_cost_per_token_above_128k_tokens: Optional[float]
input_cost_per_image: Optional[float]
input_cost_per_audio_per_second: Optional[float]
input_cost_per_video_per_second: Optional[float]
output_cost_per_token: Required[float]
output_cost_per_token_above_128k_tokens: Optional[float]
output_cost_per_image: Optional[float]
output_cost_per_video_per_second: Optional[float]
output_cost_per_audio_per_second: Optional[float]
litellm_provider: Required[str]
mode: Required[
Literal[
"completion", "embedding", "image_generation", "chat", "audio_transcription"
]
supported_openai_params: Optional[List[str]]
]
supported_openai_params: Required[Optional[List[str]]]
class GenericStreamingChunk(TypedDict):

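Switching ModelInfo to total=False with Required[...] keeps the long-standing keys mandatory while letting the new above-128k and image/audio/video pricing keys be absent for models that do not define them. A short sketch of the semantics (the _PricingExample class and its values are hypothetical, not part of litellm):

# Hedged sketch of TypedDict(total=False) combined with Required[...].
from typing import Optional
from typing_extensions import Required, TypedDict

class _PricingExample(TypedDict, total=False):
    input_cost_per_token: Required[float]  # must always be present
    input_cost_per_token_above_128k_tokens: Optional[float]  # key may be missing entirely

info: _PricingExample = {"input_cost_per_token": 3.5e-07}  # valid: optional key omitted
above_128k = info.get("input_cost_per_token_above_128k_tokens")  # None when not mapped
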
@@ -4286,8 +4286,10 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod
split_model, custom_llm_provider, _, _ = get_llm_provider(model=model)
except:
pass
combined_model_name = model
else:
split_model = model
combined_model_name = "{}/{}".format(custom_llm_provider, model)
#########################
supported_openai_params = litellm.get_supported_openai_params(
@@ -4305,33 +4307,58 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod
}
else:
"""
Check if:
1. 'model' in litellm.model_cost. Checks "groq/llama3-8b-8192" in litellm.model_cost
2. 'split_model' in litellm.model_cost. Checks "llama3-8b-8192" in litellm.model_cost
Check if: (in order of specificity)
1. 'custom_llm_provider/model' in litellm.model_cost. Checks "groq/llama3-8b-8192" if model="llama3-8b-8192" and custom_llm_provider="groq"
2. 'model' in litellm.model_cost. Checks "groq/llama3-8b-8192" in litellm.model_cost if model="groq/llama3-8b-8192" and custom_llm_provider=None
3. 'split_model' in litellm.model_cost. Checks "llama3-8b-8192" in litellm.model_cost if model="groq/llama3-8b-8192"
"""
if model in litellm.model_cost:
if combined_model_name in litellm.model_cost:
_model_info = litellm.model_cost[combined_model_name]
_model_info["supported_openai_params"] = supported_openai_params
if (
"litellm_provider" in _model_info
and _model_info["litellm_provider"] != custom_llm_provider
):
if custom_llm_provider == "vertex_ai" and _model_info[
"litellm_provider"
].startswith("vertex_ai"):
pass
else:
raise Exception
return _model_info
elif model in litellm.model_cost:
_model_info = litellm.model_cost[model]
_model_info["supported_openai_params"] = supported_openai_params
if (
"litellm_provider" in _model_info
and _model_info["litellm_provider"] != custom_llm_provider
):
if custom_llm_provider == "vertex_ai" and _model_info[
"litellm_provider"
].startswith("vertex_ai"):
pass
else:
raise Exception
return _model_info
if split_model in litellm.model_cost:
elif split_model in litellm.model_cost:
_model_info = litellm.model_cost[split_model]
_model_info["supported_openai_params"] = supported_openai_params
if (
"litellm_provider" in _model_info
and _model_info["litellm_provider"] != custom_llm_provider
):
if custom_llm_provider == "vertex_ai" and _model_info[
"litellm_provider"
].startswith("vertex_ai"):
pass
else:
raise Exception
return _model_info
else:
raise ValueError(
"This model isn't mapped yet. Add it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json"
)
except:
except Exception:
raise Exception(
"This model isn't mapped yet. Add it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json"
)
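
With this change, get_model_info resolves pricing in order of specificity: "{custom_llm_provider}/{model}" (combined_model_name) first, then the model string as passed, then the prefix-stripped split_model; an entry whose litellm_provider disagrees with the requested provider is rejected, with a carve-out for vertex_ai-* providers. A hedged lookup sketch using the same model name as the new test:

# Hedged sketch: the lookup order get_model_info follows after this change.
import litellm

# 1. tries "vertex_ai/gemini-1.5-flash-preview-0514"  (combined_model_name)
# 2. then "gemini-1.5-flash-preview-0514"             (model as passed)
# 3. then the prefix-stripped split_model
info = litellm.get_model_info(
    model="gemini-1.5-flash-preview-0514", custom_llm_provider="vertex_ai"
)
print(info["input_cost_per_token"])
print(info.get("input_cost_per_token_above_128k_tokens"))  # None if no >128k pricing is mapped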