diff --git a/litellm/utils.py b/litellm/utils.py index 9b6b9f54cd..ddfc39ff70 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -2780,7 +2780,9 @@ def token_counter( return num_tokens -def cost_per_token(model="", prompt_tokens=0, completion_tokens=0): +def cost_per_token( + model="", prompt_tokens=0, completion_tokens=0, custom_llm_provider=None +): """ Calculates the cost per token for a given model, prompt tokens, and completion tokens. @@ -2796,6 +2798,7 @@ def cost_per_token(model="", prompt_tokens=0, completion_tokens=0): prompt_tokens_cost_usd_dollar = 0 completion_tokens_cost_usd_dollar = 0 model_cost_ref = litellm.model_cost + model_with_provider = custom_llm_provider + "/" + model # see this https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models print_verbose(f"Looking up model={model} in model_cost_map") @@ -2807,6 +2810,16 @@ def cost_per_token(model="", prompt_tokens=0, completion_tokens=0): model_cost_ref[model]["output_cost_per_token"] * completion_tokens ) return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar + elif model_with_provider in model_cost_ref: + print_verbose(f"Looking up model={model_with_provider} in model_cost_map") + prompt_tokens_cost_usd_dollar = ( + model_cost_ref[model_with_provider]["input_cost_per_token"] * prompt_tokens + ) + completion_tokens_cost_usd_dollar = ( + model_cost_ref[model_with_provider]["output_cost_per_token"] + * completion_tokens + ) + return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar elif "ft:gpt-3.5-turbo" in model: print_verbose(f"Cost Tracking: {model} is an OpenAI FinteTuned LLM") # fuzzy match ft:gpt-3.5-turbo:abcd-id-cool-litellm @@ -2890,6 +2903,7 @@ def completion_cost( # Handle Inputs to completion_cost prompt_tokens = 0 completion_tokens = 0 + custom_llm_provider = None if completion_response is not None: # get input/output tokens from completion_response prompt_tokens = completion_response.get("usage", {}).get("prompt_tokens", 0) @@ -2899,6 +2913,12 @@ def completion_cost( model = ( model or completion_response["model"] ) # check if user passed an override for model, if it's none check completion_response['model'] + if completion_response is not None and hasattr( + completion_response, "_hidden_params" + ): + custom_llm_provider = completion_response._hidden_params.get( + "custom_llm_provider", "" + ) else: if len(messages) > 0: prompt_tokens = token_counter(model=model, messages=messages) @@ -2926,6 +2946,7 @@ def completion_cost( model=model, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, + custom_llm_provider=custom_llm_provider, ) return prompt_tokens_cost_usd_dollar + completion_tokens_cost_usd_dollar except Exception as e: