diff --git a/litellm/proxy/spend_tracking/spend_management_endpoints.py b/litellm/proxy/spend_tracking/spend_management_endpoints.py
index abbdc3419..1fbd95b3c 100644
--- a/litellm/proxy/spend_tracking/spend_management_endpoints.py
+++ b/litellm/proxy/spend_tracking/spend_management_endpoints.py
@@ -1265,9 +1265,22 @@ async def calculate_spend(request: SpendCalculateRequest):
     _model_in_llm_router = None
     cost_per_token: Optional[CostPerToken] = None
     if llm_router is not None:
-        for model in llm_router.model_list:
-            if model.get("model_name") == request.model:
-                _model_in_llm_router = model
+        if (
+            llm_router.model_group_alias is not None
+            and request.model in llm_router.model_group_alias
+        ):
+            # lookup alias in llm_router
+            _model_group_name = llm_router.model_group_alias[request.model]
+            for model in llm_router.model_list:
+                if model.get("model_name") == _model_group_name:
+                    _model_in_llm_router = model
+
+        else:
+            # no model_group aliases set -> try finding model in llm_router
+            # find model in llm_router
+            for model in llm_router.model_list:
+                if model.get("model_name") == request.model:
+                    _model_in_llm_router = model
 
     """
     3 cases for /spend/calculate
diff --git a/litellm/tests/test_spend_calculate_endpoint.py b/litellm/tests/test_spend_calculate_endpoint.py
index f8aff337e..8bdd4a54d 100644
--- a/litellm/tests/test_spend_calculate_endpoint.py
+++ b/litellm/tests/test_spend_calculate_endpoint.py
@@ -101,3 +101,41 @@ async def test_spend_calc_using_response():
     print("calculated cost", cost_obj)
     cost = cost_obj["cost"]
     assert cost > 0.0
+
+
+@pytest.mark.asyncio
+async def test_spend_calc_model_alias_on_router_messages():
+    from litellm.proxy.proxy_server import llm_router as init_llm_router
+
+    temp_llm_router = Router(
+        model_list=[
+            {
+                "model_name": "gpt-4o",
+                "litellm_params": {
+                    "model": "gpt-4o",
+                },
+            }
+        ],
+        model_group_alias={
+            "gpt4o": "gpt-4o",
+        },
+    )
+
+    setattr(litellm.proxy.proxy_server, "llm_router", temp_llm_router)
+
+    cost_obj = await calculate_spend(
+        request=SpendCalculateRequest(
+            model="gpt4o",
+            messages=[
+                {"role": "user", "content": "What is the capital of France?"},
+            ],
+        )
+    )
+
+    print("calculated cost", cost_obj)
+    _cost = cost_obj["cost"]
+
+    assert _cost > 0.0
+
+    # set router to init value
+    setattr(litellm.proxy.proxy_server, "llm_router", init_llm_router)