fix(azure.py): support cost tracking for azure/dall-e-3

This commit is contained in:
Krrish Dholakia 2024-03-12 10:55:54 -07:00
parent 01921aeb5a
commit 7dd94c802e
3 changed files with 44 additions and 1 deletion

View file

@ -715,6 +715,16 @@ class AzureChatCompletion(BaseLLM):
model = model model = model
else: else:
model = None model = None
## BASE MODEL CHECK
if (
model_response is not None
and optional_params.get("base_model", None) is not None
):
model_response._hidden_params["model"] = optional_params.pop(
"base_model"
)
data = {"model": model, "prompt": prompt, **optional_params} data = {"model": model, "prompt": prompt, **optional_params}
max_retries = data.pop("max_retries", 2) max_retries = data.pop("max_retries", 2)
if not isinstance(max_retries, int): if not isinstance(max_retries, int):

View file

@ -297,3 +297,34 @@ def test_whisper_azure():
5, 5,
) )
assert cost == expected_cost assert cost == expected_cost
def test_dalle_3_azure_cost_tracking():
    """Check that ``completion_cost`` returns a positive cost for a DALL-E 3 image response.

    No network call is made: the test builds a ``litellm.ImageResponse`` by
    hand, tags it with ``_hidden_params["model"] = "dall-e-3"``, and asks
    ``litellm.completion_cost`` to price it as an ``image_generation`` call.
    """
    litellm.set_verbose = True

    # Live equivalent (disabled; needs AZURE_SWEDEN_API_BASE / AZURE_SWEDEN_API_KEY):
    #   litellm.image_generation(
    #       model="azure/dall-e-3-test",
    #       prompt="A cute baby sea otter",
    #       api_version="2023-12-01-preview",
    #       api_base=os.getenv("AZURE_SWEDEN_API_BASE"),
    #       api_key=os.getenv("AZURE_SWEDEN_API_KEY"),
    #       base_model="dall-e-3",
    #   )
    # NOTE(review): the payload below looks like a capture of that call's output - confirm.
    generated_image = {
        "b64_json": None,
        "revised_prompt": "A close-up image of an adorable baby sea otter. Its fur is thick and fluffy to provide buoyancy and insulation against the cold water. Its eyes are round, curious and full of life. It's lying on its back, floating effortlessly on the calm sea surface under the warm sun. Surrounding the otter are patches of colorful kelp drifting along the gentle waves, giving the scene a touch of vibrancy. The sea otter has its small paws folded on its chest, and it seems to be taking a break from its play.",
        "url": "https://dalleprodsec.blob.core.windows.net/private/images/3e5d00f3-700e-4b75-869d-2de73c3c975d/generated_00.png?se=2024-03-13T17%3A49%3A51Z&sig=R9RJD5oOSe0Vp9Eg7ze%2FZ8QR7ldRyGH6XhMxiau16Jc%3D&ske=2024-03-19T11%3A08%3A03Z&skoid=e52d5ed7-0657-4f62-bc12-7e5dbb260a96&sks=b&skt=2024-03-12T11%3A08%3A03Z&sktid=33e01921-4d64-4f8c-a055-5bdaffd5e33d&skv=2020-10-02&sp=r&spr=https&sr=b&sv=2020-10-02",
    }
    image_response = litellm.ImageResponse(
        created=1710265780,
        data=[generated_image],
    )

    # Zeroed token usage is attached explicitly, so any positive cost cannot
    # come from token counts.
    image_response.usage = {
        "prompt_tokens": 0,
        "completion_tokens": 0,
        "total_tokens": 0,
    }
    # The base model name is supplied through the hidden params rather than
    # as an argument to completion_cost.
    image_response._hidden_params = {"model": "dall-e-3", "model_id": None}
    print(f"response hidden params: {image_response._hidden_params}")

    tracked_cost = litellm.completion_cost(
        completion_response=image_response,
        call_type="image_generation",
    )
    assert tracked_cost > 0

View file

@ -3841,7 +3841,9 @@ def completion_cost(
* n * n
) )
else: else:
raise Exception(f"Model={model} not found in completion cost model map") raise Exception(
f"Model={image_gen_model_name} not found in completion cost model map"
)
# Calculate cost based on prompt_tokens, completion_tokens # Calculate cost based on prompt_tokens, completion_tokens
if ( if (
"togethercomputer" in model "togethercomputer" in model