fix(utils.py): fix custom pricing when litellm model != response obj model name

Krrish Dholakia 2024-05-13 15:24:56 -07:00
parent 1be6ea0c0d
commit b4a8665d11
4 changed files with 87 additions and 5 deletions
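For orientation (not part of the commit), a minimal sketch of the behavior this fix targets, modeled on the test added below: per-call custom pricing should reach the response_cost that a CustomLogger success callback receives, even when the provider's response object reports a different model name than the one litellm was called with. The handler name is illustrative; the import and call pattern are taken from the new test.

import time

import litellm
from litellm.utils import CustomLogger  # same import the new test uses


class CostSpy(CustomLogger):
    response_cost = None

    def log_success_event(self, kwargs, response_obj, start_time, end_time):
        # litellm hands the computed cost to success callbacks via kwargs
        self.response_cost = kwargs["response_cost"]


spy = CostSpy()
litellm.callbacks = [spy]

litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hey!"}],
    mock_response="What do you want?",
    input_cost_per_token=0.0,  # custom pricing, registered under the provider-prefixed model name
    output_cost_per_token=0.0,
)
time.sleep(5)  # the new test also waits before asserting, since callbacks may run off-thread

assert spy.response_cost == 0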


@@ -727,7 +727,6 @@ def completion(
     ### REGISTER CUSTOM MODEL PRICING -- IF GIVEN ###
     if input_cost_per_token is not None and output_cost_per_token is not None:
-        print_verbose(f"Registering model={model} in model cost map")
         litellm.register_model(
             {
                 f"{custom_llm_provider}/{model}": {
@@ -849,6 +848,10 @@ def completion(
         proxy_server_request=proxy_server_request,
         preset_cache_key=preset_cache_key,
         no_log=no_log,
+        input_cost_per_second=input_cost_per_second,
+        input_cost_per_token=input_cost_per_token,
+        output_cost_per_second=output_cost_per_second,
+        output_cost_per_token=output_cost_per_token,
     )
     logging.update_environment_variables(
         model=model,


@@ -18,6 +18,8 @@ model_list:
       model: azure/chatgpt-v-2
       api_key: os.environ/AZURE_API_KEY
       api_base: os.environ/AZURE_API_BASE
+      input_cost_per_token: 0.0
+      output_cost_per_token: 0.0

 router_settings:
   redis_host: redis
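For context (not part of the commit), a rough sketch of where these YAML values end up: per the completion() hunk above, litellm registers per-deployment pricing in its model cost map under the provider-prefixed model name. The nested keys shown here are an assumption based on the two cost fields this commit plumbs through.

import litellm

# Approximation of what completion() does when per-call cost params are supplied:
# register the deployment's pricing under "<custom_llm_provider>/<model>".
litellm.register_model(
    {
        "azure/chatgpt-v-2": {
            "input_cost_per_token": 0.0,
            "output_cost_per_token": 0.0,
        }
    }
)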


@@ -5,6 +5,7 @@ sys.path.insert(
     0, os.path.abspath("../..")
 )  # Adds the parent directory to the system path
 import time
+from typing import Optional
 import litellm
 from litellm import (
     get_max_tokens,
@@ -12,7 +13,56 @@ from litellm import (
     open_ai_chat_completion_models,
     TranscriptionResponse,
 )
-import pytest
+from litellm.utils import CustomLogger
+import pytest, asyncio
+
+
+class CustomLoggingHandler(CustomLogger):
+    response_cost: Optional[float] = None
+
+    def __init__(self):
+        super().__init__()
+
+    def log_success_event(self, kwargs, response_obj, start_time, end_time):
+        self.response_cost = kwargs["response_cost"]
+
+    async def async_log_success_event(
+        self, kwargs, response_obj, start_time, end_time
+    ):
+        print(f"kwargs - {kwargs}")
+        print(f"kwargs response cost - {kwargs.get('response_cost')}")
+        self.response_cost = kwargs["response_cost"]
+
+        print(f"response_cost: {self.response_cost} ")
+
+
+@pytest.mark.parametrize("sync_mode", [True, False])
+@pytest.mark.asyncio
+async def test_custom_pricing(sync_mode):
+    new_handler = CustomLoggingHandler()
+    litellm.callbacks = [new_handler]
+    if sync_mode:
+        response = litellm.completion(
+            model="gpt-3.5-turbo",
+            messages=[{"role": "user", "content": "Hey!"}],
+            mock_response="What do you want?",
+            input_cost_per_token=0.0,
+            output_cost_per_token=0.0,
+        )
+
+        time.sleep(5)
+    else:
+        response = await litellm.acompletion(
+            model="gpt-3.5-turbo",
+            messages=[{"role": "user", "content": "Hey!"}],
+            mock_response="What do you want?",
+            input_cost_per_token=0.0,
+            output_cost_per_token=0.0,
+        )
+
+        await asyncio.sleep(5)
+
+    print(f"new_handler.response_cost: {new_handler.response_cost}")
+    assert new_handler.response_cost is not None
+
+    assert new_handler.response_cost == 0
+
+
 def test_get_gpt3_tokens():


@@ -1083,6 +1083,8 @@ class CallTypes(Enum):
 class Logging:
     global supabaseClient, liteDebuggerClient, promptLayerLogger, weightsBiasesLogger, langsmithLogger, capture_exception, add_breadcrumb, lunaryLogger

+    custom_pricing: bool = False
+
     def __init__(
         self,
         model,
@@ -1165,6 +1167,15 @@ class Logging:
             **additional_params,
         }

+        ## check if custom pricing set ##
+        if (
+            litellm_params.get("input_cost_per_token") is not None
+            or litellm_params.get("input_cost_per_second") is not None
+            or litellm_params.get("output_cost_per_token") is not None
+            or litellm_params.get("output_cost_per_second") is not None
+        ):
+            self.custom_pricing = True
+
     def _pre_call(self, input, api_key, model=None, additional_args={}):
         """
         Common helper function across the sync + async pre-call function
@@ -1442,10 +1453,18 @@ class Logging:
                             )
                         )
                     else:
+                        base_model: Optional[str] = None
                         # check if base_model set on azure
                         base_model = _get_base_model_from_metadata(
                             model_call_details=self.model_call_details
                         )
+                        # litellm model name
+                        litellm_model = self.model_call_details["model"]
+                        if (
+                            litellm_model in litellm.model_cost
+                            and self.custom_pricing == True
+                        ):
+                            base_model = litellm_model
                         # base_model defaults to None if not set on model_info
                         self.model_call_details["response_cost"] = (
                             litellm.completion_cost(
@@ -4365,7 +4384,7 @@ def completion_cost(
     size=None,
     quality=None,
     n=None,  # number of images
-):
+) -> float:
     """
     Calculate the cost of a given completion call fot GPT-3.5-turbo, llama2, any litellm supported llm.
@@ -4386,10 +4405,10 @@ def completion_cost(
     - If completion_response is not provided, the function calculates token counts based on the model and input text.
     - The cost is calculated based on the model, prompt tokens, and completion tokens.
    - For certain models containing "togethercomputer" in the name, prices are based on the model size.
-    - For Replicate models, the cost is calculated based on the total time used for the request.
+    - For un-mapped Replicate models, the cost is calculated based on the total time used for the request.

     Exceptions:
-    - If an error occurs during execution, the function returns 0.0 without blocking the user's execution path.
+    - If an error occurs during execution, the error is raised
     """
     try:
         if (
@@ -4701,6 +4720,10 @@ def get_litellm_params(
     acompletion=None,
     preset_cache_key=None,
     no_log=None,
+    input_cost_per_second=None,
+    input_cost_per_token=None,
+    output_cost_per_token=None,
+    output_cost_per_second=None,
 ):
     litellm_params = {
         "acompletion": acompletion,
@@ -4719,6 +4742,10 @@ def get_litellm_params(
         "preset_cache_key": preset_cache_key,
         "no-log": no_log,
         "stream_response": {},  # litellm_call_id: ModelResponse Dict
+        "input_cost_per_token": input_cost_per_token,
+        "input_cost_per_second": input_cost_per_second,
+        "output_cost_per_token": output_cost_per_token,
+        "output_cost_per_second": output_cost_per_second,
     }
     return litellm_params