(feat) use custom_llm_provider in completion_cost

This commit is contained in:
ishaan-jaff 2024-01-13 12:29:51 -08:00
parent 70426cad76
commit 53fd62b0cd

View file

@ -2780,7 +2780,9 @@ def token_counter(
return num_tokens
def cost_per_token(model="", prompt_tokens=0, completion_tokens=0):
def cost_per_token(
model="", prompt_tokens=0, completion_tokens=0, custom_llm_provider=None
):
"""
Calculates the cost per token for a given model, prompt tokens, and completion tokens.
@ -2796,6 +2798,7 @@ def cost_per_token(model="", prompt_tokens=0, completion_tokens=0):
prompt_tokens_cost_usd_dollar = 0
completion_tokens_cost_usd_dollar = 0
model_cost_ref = litellm.model_cost
# custom_llm_provider defaults to None (and may be ""), so only build the
# "<provider>/<model>" lookup key when a provider was actually supplied;
# otherwise fall back to the bare model name instead of raising TypeError.
model_with_provider = (
f"{custom_llm_provider}/{model}" if custom_llm_provider else model
)
# see this https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models
print_verbose(f"Looking up model={model} in model_cost_map")
@ -2807,6 +2810,16 @@ def cost_per_token(model="", prompt_tokens=0, completion_tokens=0):
model_cost_ref[model]["output_cost_per_token"] * completion_tokens
)
return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar
elif model_with_provider in model_cost_ref:
print_verbose(f"Looking up model={model_with_provider} in model_cost_map")
prompt_tokens_cost_usd_dollar = (
model_cost_ref[model_with_provider]["input_cost_per_token"] * prompt_tokens
)
completion_tokens_cost_usd_dollar = (
model_cost_ref[model_with_provider]["output_cost_per_token"]
* completion_tokens
)
return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar
elif "ft:gpt-3.5-turbo" in model:
print_verbose(f"Cost Tracking: {model} is an OpenAI FinteTuned LLM")
# fuzzy match ft:gpt-3.5-turbo:abcd-id-cool-litellm
@ -2890,6 +2903,7 @@ def completion_cost(
# Handle Inputs to completion_cost
prompt_tokens = 0
completion_tokens = 0
custom_llm_provider = None
if completion_response is not None:
# get input/output tokens from completion_response
prompt_tokens = completion_response.get("usage", {}).get("prompt_tokens", 0)
@ -2899,6 +2913,12 @@ def completion_cost(
model = (
model or completion_response["model"]
) # check if user passed an override for model, if it's none check completion_response['model']
if completion_response is not None and hasattr(
completion_response, "_hidden_params"
):
custom_llm_provider = completion_response._hidden_params.get(
"custom_llm_provider", ""
)
else:
if len(messages) > 0:
prompt_tokens = token_counter(model=model, messages=messages)
@ -2926,6 +2946,7 @@ def completion_cost(
model=model,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
custom_llm_provider=custom_llm_provider,
)
return prompt_tokens_cost_usd_dollar + completion_tokens_cost_usd_dollar
except Exception as e: