diff --git a/litellm/main.py b/litellm/main.py
index 0dbd5a166..9dd618631 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -727,7 +727,6 @@ def completion(
         ### REGISTER CUSTOM MODEL PRICING -- IF GIVEN ###
         if input_cost_per_token is not None and output_cost_per_token is not None:
-            print_verbose(f"Registering model={model} in model cost map")
             litellm.register_model(
                 {
                     f"{custom_llm_provider}/{model}": {
@@ -849,6 +848,10 @@ def completion(
             proxy_server_request=proxy_server_request,
             preset_cache_key=preset_cache_key,
             no_log=no_log,
+            input_cost_per_second=input_cost_per_second,
+            input_cost_per_token=input_cost_per_token,
+            output_cost_per_second=output_cost_per_second,
+            output_cost_per_token=output_cost_per_token,
         )
         logging.update_environment_variables(
             model=model,
diff --git a/litellm/proxy/_super_secret_config.yaml b/litellm/proxy/_super_secret_config.yaml
index b83883beb..5f5ee89c3 100644
--- a/litellm/proxy/_super_secret_config.yaml
+++ b/litellm/proxy/_super_secret_config.yaml
@@ -18,6 +18,8 @@ model_list:
     model: azure/chatgpt-v-2
     api_key: os.environ/AZURE_API_KEY
     api_base: os.environ/AZURE_API_BASE
+    input_cost_per_token: 0.0
+    output_cost_per_token: 0.0
 
 router_settings:
   redis_host: redis
diff --git a/litellm/tests/test_completion_cost.py b/litellm/tests/test_completion_cost.py
index 35e0496fb..8b817eb3c 100644
--- a/litellm/tests/test_completion_cost.py
+++ b/litellm/tests/test_completion_cost.py
@@ -5,6 +5,7 @@ sys.path.insert(
     0, os.path.abspath("../..")
 )  # Adds the parent directory to the system path
 import time
+from typing import Optional
 import litellm
 from litellm import (
     get_max_tokens,
@@ -12,7 +13,56 @@ from litellm import (
     open_ai_chat_completion_models,
     TranscriptionResponse,
 )
-import pytest
+from litellm.utils import CustomLogger
+import pytest, asyncio
+
+
+class CustomLoggingHandler(CustomLogger):
+    response_cost: Optional[float] = None
+
+    def __init__(self):
+        super().__init__()
+
+    def log_success_event(self, kwargs, response_obj, start_time, end_time):
+        self.response_cost = kwargs["response_cost"]
+
+    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
+        print(f"kwargs - {kwargs}")
+        print(f"kwargs response cost - {kwargs.get('response_cost')}")
+        self.response_cost = kwargs["response_cost"]
+
+        print(f"response_cost: {self.response_cost} ")
+
+
+@pytest.mark.parametrize("sync_mode", [True, False])
+@pytest.mark.asyncio
+async def test_custom_pricing(sync_mode):
+    new_handler = CustomLoggingHandler()
+    litellm.callbacks = [new_handler]
+    if sync_mode:
+        response = litellm.completion(
+            model="gpt-3.5-turbo",
+            messages=[{"role": "user", "content": "Hey!"}],
+            mock_response="What do you want?",
+            input_cost_per_token=0.0,
+            output_cost_per_token=0.0,
+        )
+        time.sleep(5)
+    else:
+        response = await litellm.acompletion(
+            model="gpt-3.5-turbo",
+            messages=[{"role": "user", "content": "Hey!"}],
+            mock_response="What do you want?",
+            input_cost_per_token=0.0,
+            output_cost_per_token=0.0,
+        )
+
+        await asyncio.sleep(5)
+
+    print(f"new_handler.response_cost: {new_handler.response_cost}")
+    assert new_handler.response_cost is not None
+
+    assert new_handler.response_cost == 0
 
 
 def test_get_gpt3_tokens():
diff --git a/litellm/utils.py b/litellm/utils.py
index 113a1a550..e2bc8bec1 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -1083,6 +1083,8 @@ class CallTypes(Enum):
 class Logging:
     global supabaseClient, liteDebuggerClient, promptLayerLogger, weightsBiasesLogger, langsmithLogger, capture_exception, add_breadcrumb, lunaryLogger
 
+    custom_pricing: bool = False
+
     def __init__(
         self,
         model,
@@ -1165,6 +1167,15 @@ class Logging:
             **additional_params,
         }
 
+        ## check if custom pricing set ##
+        if (
+            litellm_params.get("input_cost_per_token") is not None
+            or litellm_params.get("input_cost_per_second") is not None
+            or litellm_params.get("output_cost_per_token") is not None
+            or litellm_params.get("output_cost_per_second") is not None
+        ):
+            self.custom_pricing = True
+
     def _pre_call(self, input, api_key, model=None, additional_args={}):
         """
         Common helper function across the sync + async pre-call function
@@ -1442,10 +1453,18 @@ class Logging:
                             )
                         )
                     else:
+                        base_model: Optional[str] = None
                         # check if base_model set on azure
                         base_model = _get_base_model_from_metadata(
                             model_call_details=self.model_call_details
                         )
+                        # litellm model name
+                        litellm_model = self.model_call_details["model"]
+                        if (
+                            litellm_model in litellm.model_cost
+                            and self.custom_pricing == True
+                        ):
+                            base_model = litellm_model
                         # base_model defaults to None if not set on model_info
                         self.model_call_details["response_cost"] = (
                             litellm.completion_cost(
@@ -4365,7 +4384,7 @@ def completion_cost(
     size=None,
     quality=None,
     n=None,  # number of images
-):
+) -> float:
     """
     Calculate the cost of a given completion call fot GPT-3.5-turbo, llama2, any litellm supported llm.
 
@@ -4386,10 +4405,10 @@ def completion_cost(
    - If completion_response is not provided, the function calculates token counts based on the model and input text.
    - The cost is calculated based on the model, prompt tokens, and completion tokens.
    - For certain models containing "togethercomputer" in the name, prices are based on the model size.
-    - For Replicate models, the cost is calculated based on the total time used for the request.
+    - For un-mapped Replicate models, the cost is calculated based on the total time used for the request.
 
     Exceptions:
-    - If an error occurs during execution, the function returns 0.0 without blocking the user's execution path.
+    - If an error occurs during execution, the error is raised
     """
     try:
         if (
@@ -4701,6 +4720,10 @@ def get_litellm_params(
     acompletion=None,
     preset_cache_key=None,
     no_log=None,
+    input_cost_per_second=None,
+    input_cost_per_token=None,
+    output_cost_per_token=None,
+    output_cost_per_second=None,
 ):
     litellm_params = {
         "acompletion": acompletion,
@@ -4719,6 +4742,10 @@ def get_litellm_params(
         "preset_cache_key": preset_cache_key,
         "no-log": no_log,
         "stream_response": {},  # litellm_call_id: ModelResponse Dict
+        "input_cost_per_token": input_cost_per_token,
+        "input_cost_per_second": input_cost_per_second,
+        "output_cost_per_token": output_cost_per_token,
+        "output_cost_per_second": output_cost_per_second,
     }
 
     return litellm_params
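
Usage sketch (not part of the diff): a minimal example of the behaviour this change enables, mirroring the new `test_custom_pricing` test above. It assumes the `CustomLogger` import path and the `response_cost` key in the callback `kwargs` shown in that test; the handler name `CostTracker` is hypothetical.

```python
import litellm
from litellm.utils import CustomLogger  # import path as used in the new test


class CostTracker(CustomLogger):
    # hypothetical handler; captures the per-request cost litellm computes
    response_cost = None

    def log_success_event(self, kwargs, response_obj, start_time, end_time):
        self.response_cost = kwargs["response_cost"]


tracker = CostTracker()
litellm.callbacks = [tracker]

# Passing input_cost_per_token / output_cost_per_token registers custom pricing
# for this call, so the logged response_cost reflects the caller-supplied rates
# (0.0 here, matching the test).
response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hey!"}],
    mock_response="What do you want?",
    input_cost_per_token=0.0,
    output_cost_per_token=0.0,
)
```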