diff --git a/litellm/main.py b/litellm/main.py
index 7bc12ffef8..8ee0e7d7b8 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -536,6 +536,8 @@ def completion(
         "tpm",
         "input_cost_per_token",
         "output_cost_per_token",
+        "input_cost_per_second",
+        "output_cost_per_second",
         "hf_model_name",
         "model_info",
         "proxy_server_request",
@@ -2262,6 +2264,11 @@ def embedding(
     encoding_format = kwargs.get("encoding_format", None)
     proxy_server_request = kwargs.get("proxy_server_request", None)
     aembedding = kwargs.get("aembedding", None)
+    ### CUSTOM MODEL COST ###
+    input_cost_per_token = kwargs.get("input_cost_per_token", None)
+    output_cost_per_token = kwargs.get("output_cost_per_token", None)
+    input_cost_per_second = kwargs.get("input_cost_per_second", None)
+    output_cost_per_second = kwargs.get("output_cost_per_second", None)
     openai_params = [
         "user",
         "request_timeout",
@@ -2310,6 +2317,8 @@ def embedding(
         "tpm",
         "input_cost_per_token",
         "output_cost_per_token",
+        "input_cost_per_second",
+        "output_cost_per_second",
         "hf_model_name",
         "proxy_server_request",
         "model_info",
@@ -2335,6 +2344,28 @@ def embedding(
         custom_llm_provider=custom_llm_provider,
         **non_default_params,
     )
+    ### REGISTER CUSTOM MODEL PRICING -- IF GIVEN ###
+    if input_cost_per_token is not None and output_cost_per_token is not None:
+        litellm.register_model(
+            {
+                model: {
+                    "input_cost_per_token": input_cost_per_token,
+                    "output_cost_per_token": output_cost_per_token,
+                    "litellm_provider": custom_llm_provider,
+                }
+            }
+        )
+    if input_cost_per_second is not None:  # time based pricing just needs cost in place
+        output_cost_per_second = output_cost_per_second or 0.0
+        litellm.register_model(
+            {
+                model: {
+                    "input_cost_per_second": input_cost_per_second,
+                    "output_cost_per_second": output_cost_per_second,
+                    "litellm_provider": custom_llm_provider,
+                }
+            }
+        )
     try:
         response = None
         logging = litellm_logging_obj
diff --git a/litellm/utils.py b/litellm/utils.py
index f4c7e93a88..192ae099df 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -2371,7 +2371,9 @@ def client(original_function):
             result._hidden_params["model_id"] = kwargs.get("model_info", {}).get(
                 "id", None
            )
-            if isinstance(result, ModelResponse):
+            if isinstance(result, ModelResponse) or isinstance(
+                result, EmbeddingResponse
+            ):
                result._response_ms = (
                    end_time - start_time
                ).total_seconds() * 1000  # return response latency in ms like openai
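
A minimal usage sketch of what this diff enables, for reviewers: the model name and rates below are illustrative placeholders, not real pricing, and this assumes the kwargs flow through `embedding()` as shown in the hunks above.

```python
import litellm

# Token-based custom pricing: per the diff, BOTH input_cost_per_token and
# output_cost_per_token must be passed for the model to be registered.
response = litellm.embedding(
    model="azure/azure-embedding-model",  # placeholder model name
    input=["good morning from litellm"],
    input_cost_per_token=0.000002,  # placeholder rate
    output_cost_per_token=0.0,
)

# Time-based custom pricing: only input_cost_per_second is required;
# output_cost_per_second falls back to 0.0 in the diff above.
response = litellm.embedding(
    model="azure/azure-embedding-model",
    input=["good morning from litellm"],
    input_cost_per_second=0.0004,  # placeholder rate
)

# The utils.py change stamps _response_ms onto EmbeddingResponse as well,
# giving time-based pricing a request duration to bill against.
print(response._response_ms)
```

One design note: the two `if` blocks in `embedding()` are independent rather than an `if`/`elif`, so a caller passing both token-based and second-based kwargs triggers two `register_model` calls, with the time-based entry registered last.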