Realtime API Cost tracking (#9795)

* fix(proxy_server.py): log realtime calls to spendlogs

Fixes https://github.com/BerriAI/litellm/issues/8410

* feat(realtime/): OpenAI Realtime API cost tracking

Closes https://github.com/BerriAI/litellm/issues/8410

* test: add unit testing for coverage

* test: add more unit testing

* fix: handle edge cases
Krish Dholakia 2025-04-07 16:43:12 -07:00 committed by GitHub
parent 9a60cd9deb
commit 4a128cfd64
12 changed files with 401 additions and 39 deletions

View file

@@ -16,7 +16,10 @@ from litellm.constants import (
from litellm.litellm_core_utils.llm_cost_calc.tool_call_cost_tracking import (
StandardBuiltInToolCostTracking,
)
from litellm.litellm_core_utils.llm_cost_calc.utils import _generic_cost_per_character
from litellm.litellm_core_utils.llm_cost_calc.utils import (
_generic_cost_per_character,
generic_cost_per_token,
)
from litellm.llms.anthropic.cost_calculation import (
cost_per_token as anthropic_cost_per_token,
)
@@ -54,6 +57,9 @@ from litellm.llms.vertex_ai.image_generation.cost_calculator import (
from litellm.responses.utils import ResponseAPILoggingUtils
from litellm.types.llms.openai import (
HttpxBinaryResponseContent,
OpenAIRealtimeStreamList,
OpenAIRealtimeStreamResponseBaseObject,
OpenAIRealtimeStreamSessionEvents,
ResponseAPIUsage,
ResponsesAPIResponse,
)
@@ -1141,3 +1147,50 @@ def batch_cost_calculator(
) # batch cost is usually half of the regular token cost
return total_prompt_cost, total_completion_cost
def handle_realtime_stream_cost_calculation(
results: OpenAIRealtimeStreamList, custom_llm_provider: str, litellm_model_name: str
) -> float:
"""
Handles the cost calculation for realtime stream responses.
Picks the 'response.done' events and sums the cost across all of them.
Args:
results: a list of realtime stream events (OpenAIRealtimeStreamList)
"""
response_done_events: List[OpenAIRealtimeStreamResponseBaseObject] = cast(
List[OpenAIRealtimeStreamResponseBaseObject],
[result for result in results if result["type"] == "response.done"],
)
received_model = None
potential_model_names = []
for result in results:
if result["type"] == "session.created":
received_model = cast(OpenAIRealtimeStreamSessionEvents, result)["session"][
"model"
]
potential_model_names.append(received_model)
potential_model_names.append(litellm_model_name)
input_cost_per_token = 0.0
output_cost_per_token = 0.0
for result in response_done_events:
usage_object = (
ResponseAPILoggingUtils._transform_response_api_usage_to_chat_usage(
result["response"].get("usage", {})
)
)
for model_name in potential_model_names:
_input_cost_per_token, _output_cost_per_token = generic_cost_per_token(
model=model_name,
usage=usage_object,
custom_llm_provider=custom_llm_provider,
)
input_cost_per_token += _input_cost_per_token
output_cost_per_token += _output_cost_per_token
total_cost = input_cost_per_token + output_cost_per_token
return total_cost
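As a quick illustration (not part of the diff), the helper can be exercised directly with the same event shapes the new unit test constructs later in this commit; the payload values here are made up:

from litellm.cost_calculator import handle_realtime_stream_cost_calculation
from litellm.types.llms.openai import OpenAIRealtimeStreamList

# A captured realtime stream: one session.created event plus one response.done event.
events: OpenAIRealtimeStreamList = [
    {"type": "session.created", "session": {"model": "gpt-3.5-turbo"}},
    {
        "type": "response.done",
        "response": {
            "usage": {"input_tokens": 100, "output_tokens": 50, "total_tokens": 150}
        },
    },
]

cost = handle_realtime_stream_cost_calculation(
    results=events,
    custom_llm_provider="openai",
    litellm_model_name="gpt-3.5-turbo",
)
print(f"realtime stream cost: ${cost:.6f}")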

View file

@@ -32,7 +32,10 @@ from litellm.constants import (
DEFAULT_MOCK_RESPONSE_COMPLETION_TOKEN_COUNT,
DEFAULT_MOCK_RESPONSE_PROMPT_TOKEN_COUNT,
)
from litellm.cost_calculator import _select_model_name_for_cost_calc
from litellm.cost_calculator import (
_select_model_name_for_cost_calc,
handle_realtime_stream_cost_calculation,
)
from litellm.integrations.arize.arize import ArizeLogger
from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.integrations.custom_logger import CustomLogger
@@ -1049,6 +1052,13 @@ class Logging(LiteLLMLoggingBaseClass):
result = self._handle_anthropic_messages_response_logging(result=result)
## if model in model cost map - log the response cost
## else set cost to None
if self.call_type == CallTypes.arealtime.value and isinstance(result, list):
self.model_call_details[
"response_cost"
] = handle_realtime_stream_cost_calculation(
result, self.custom_llm_provider, self.model
)
if (
standard_logging_object is None
and result is not None

View file

@@ -30,6 +30,11 @@ import json
from typing import Any, Dict, List, Optional, Union
import litellm
from litellm._logging import verbose_logger
from litellm.types.llms.openai import (
OpenAIRealtimeStreamResponseBaseObject,
OpenAIRealtimeStreamSessionEvents,
)
from .litellm_logging import Logging as LiteLLMLogging
@@ -53,7 +58,12 @@ class RealTimeStreaming:
self.websocket = websocket
self.backend_ws = backend_ws
self.logging_obj = logging_obj
self.messages: List = []
self.messages: List[
Union[
OpenAIRealtimeStreamResponseBaseObject,
OpenAIRealtimeStreamSessionEvents,
]
] = []
self.input_message: Dict = {}
_logged_real_time_event_types = litellm.logged_real_time_event_types
@@ -62,10 +72,14 @@
_logged_real_time_event_types = DefaultLoggedRealTimeEventTypes
self.logged_real_time_event_types = _logged_real_time_event_types
def _should_store_message(self, message: Union[str, bytes]) -> bool:
if isinstance(message, bytes):
message = message.decode("utf-8")
message_obj = json.loads(message)
def _should_store_message(
self,
message_obj: Union[
dict,
OpenAIRealtimeStreamSessionEvents,
OpenAIRealtimeStreamResponseBaseObject,
],
) -> bool:
_msg_type = message_obj["type"]
if self.logged_real_time_event_types == "*":
return True
@@ -75,8 +89,22 @@
def store_message(self, message: Union[str, bytes]):
"""Store message in list"""
if self._should_store_message(message):
self.messages.append(message)
if isinstance(message, bytes):
message = message.decode("utf-8")
message_obj = json.loads(message)
try:
if (
message_obj.get("type") == "session.created"
or message_obj.get("type") == "session.updated"
):
message_obj = OpenAIRealtimeStreamSessionEvents(**message_obj) # type: ignore
else:
message_obj = OpenAIRealtimeStreamResponseBaseObject(**message_obj) # type: ignore
except Exception as e:
verbose_logger.debug(f"Error parsing message for logging: {e}")
raise e
if self._should_store_message(message_obj):
self.messages.append(message_obj)
def store_input(self, message: dict):
"""Store input message"""

View file

@@ -40,9 +40,4 @@ litellm_settings:
files_settings:
- custom_llm_provider: gemini
api_key: os.environ/GEMINI_API_KEY
general_settings:
disable_spend_logs: True
disable_error_logs: True
api_key: os.environ/GEMINI_API_KEY

View file

@@ -2,7 +2,7 @@ import asyncio
import json
import uuid
from datetime import datetime
from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union
from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Tuple, Union
import httpx
from fastapi import HTTPException, Request, status
@@ -101,33 +101,22 @@ class ProxyBaseLLMRequestProcessing:
verbose_proxy_logger.error(f"Error setting custom headers: {e}")
return {}
async def base_process_llm_request(
async def common_processing_pre_call_logic(
self,
request: Request,
fastapi_response: Response,
user_api_key_dict: UserAPIKeyAuth,
route_type: Literal["acompletion", "aresponses"],
proxy_logging_obj: ProxyLogging,
general_settings: dict,
user_api_key_dict: UserAPIKeyAuth,
proxy_logging_obj: ProxyLogging,
proxy_config: ProxyConfig,
select_data_generator: Callable,
llm_router: Optional[Router] = None,
model: Optional[str] = None,
route_type: Literal["acompletion", "aresponses", "_arealtime"],
version: Optional[str] = None,
user_model: Optional[str] = None,
user_temperature: Optional[float] = None,
user_request_timeout: Optional[float] = None,
user_max_tokens: Optional[int] = None,
user_api_base: Optional[str] = None,
version: Optional[str] = None,
) -> Any:
"""
Common request processing logic for both chat completions and responses API endpoints
"""
verbose_proxy_logger.debug(
"Request received by LiteLLM:\n{}".format(json.dumps(self.data, indent=4)),
)
model: Optional[str] = None,
) -> Tuple[dict, LiteLLMLoggingObj]:
self.data = await add_litellm_data_to_request(
data=self.data,
request=request,
@@ -182,13 +171,57 @@
self.data["litellm_logging_obj"] = logging_obj
return self.data, logging_obj
async def base_process_llm_request(
self,
request: Request,
fastapi_response: Response,
user_api_key_dict: UserAPIKeyAuth,
route_type: Literal["acompletion", "aresponses", "_arealtime"],
proxy_logging_obj: ProxyLogging,
general_settings: dict,
proxy_config: ProxyConfig,
select_data_generator: Callable,
llm_router: Optional[Router] = None,
model: Optional[str] = None,
user_model: Optional[str] = None,
user_temperature: Optional[float] = None,
user_request_timeout: Optional[float] = None,
user_max_tokens: Optional[int] = None,
user_api_base: Optional[str] = None,
version: Optional[str] = None,
) -> Any:
"""
Common request processing logic for both chat completions and responses API endpoints
"""
verbose_proxy_logger.debug(
"Request received by LiteLLM:\n{}".format(json.dumps(self.data, indent=4)),
)
self.data, logging_obj = await self.common_processing_pre_call_logic(
request=request,
general_settings=general_settings,
proxy_logging_obj=proxy_logging_obj,
user_api_key_dict=user_api_key_dict,
version=version,
proxy_config=proxy_config,
user_model=user_model,
user_temperature=user_temperature,
user_request_timeout=user_request_timeout,
user_max_tokens=user_max_tokens,
user_api_base=user_api_base,
model=model,
route_type=route_type,
)
tasks = []
tasks.append(
proxy_logging_obj.during_call_hook(
data=self.data,
user_api_key_dict=user_api_key_dict,
call_type=ProxyBaseLLMRequestProcessing._get_pre_call_type(
route_type=route_type
route_type=route_type # type: ignore
),
)
)

View file

@@ -194,13 +194,15 @@ class _ProxyDBLogger(CustomLogger):
error_msg = f"Error in tracking cost callback - {str(e)}\n Traceback:{traceback.format_exc()}"
model = kwargs.get("model", "")
metadata = kwargs.get("litellm_params", {}).get("metadata", {})
error_msg += f"\n Args to _PROXY_track_cost_callback\n model: {model}\n metadata: {metadata}\n"
call_type = kwargs.get("call_type", "")
error_msg += f"\n Args to _PROXY_track_cost_callback\n model: {model}\n metadata: {metadata}\n call_type: {call_type}\n"
asyncio.create_task(
proxy_logging_obj.failed_tracking_alert(
error_message=error_msg,
failing_model=model,
)
)
verbose_proxy_logger.exception(
"Error in tracking cost callback - %s", str(e)
)

View file

@@ -191,6 +191,7 @@ def clean_headers(
if litellm_key_header_name is not None:
special_headers.append(litellm_key_header_name.lower())
clean_headers = {}
for header, value in headers.items():
if header.lower() not in special_headers:
clean_headers[header] = value

View file

@@ -4261,8 +4261,47 @@ async def websocket_endpoint(
"websocket": websocket,
}
headers = dict(websocket.headers.items()) # Convert headers to dict first
request = Request(
scope={
"type": "http",
"headers": [(k.lower().encode(), v.encode()) for k, v in headers.items()],
"method": "POST",
"path": "/v1/realtime",
}
)
request._url = websocket.url
async def return_body():
return_string = f'{{"model": "{model}"}}'
# return string as bytes
return return_string.encode()
request.body = return_body # type: ignore
### ROUTE THE REQUEST ###
base_llm_response_processor = ProxyBaseLLMRequestProcessing(data=data)
try:
(
data,
litellm_logging_obj,
) = await base_llm_response_processor.common_processing_pre_call_logic(
request=request,
general_settings=general_settings,
user_api_key_dict=user_api_key_dict,
version=version,
proxy_logging_obj=proxy_logging_obj,
proxy_config=proxy_config,
user_model=user_model,
user_temperature=user_temperature,
user_request_timeout=user_request_timeout,
user_max_tokens=user_max_tokens,
user_api_base=user_api_base,
model=model,
route_type="_arealtime",
)
llm_call = await route_request(
data=data,
route_type="_arealtime",

View file

@@ -892,7 +892,17 @@ class BaseLiteLLMOpenAIResponseObject(BaseModel):
class OutputTokensDetails(BaseLiteLLMOpenAIResponseObject):
reasoning_tokens: int
reasoning_tokens: Optional[int] = None
text_tokens: Optional[int] = None
model_config = {"extra": "allow"}
class InputTokensDetails(BaseLiteLLMOpenAIResponseObject):
audio_tokens: Optional[int] = None
cached_tokens: Optional[int] = None
text_tokens: Optional[int] = None
model_config = {"extra": "allow"}
@@ -901,10 +911,13 @@ class ResponseAPIUsage(BaseLiteLLMOpenAIResponseObject):
input_tokens: int
"""The number of input tokens."""
input_tokens_details: Optional[InputTokensDetails] = None
"""A detailed breakdown of the input tokens."""
output_tokens: int
"""The number of output tokens."""
output_tokens_details: Optional[OutputTokensDetails]
output_tokens_details: Optional[OutputTokensDetails] = None
"""A detailed breakdown of the output tokens."""
total_tokens: int
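The switch to Optional detail fields matters because realtime usage payloads may carry only the top-level token counts (as in the unit test below); a small illustration with made-up values:

from litellm.types.llms.openai import ResponseAPIUsage

# output_tokens_details used to be required; both detail breakdowns can now be absent.
usage = ResponseAPIUsage(input_tokens=100, output_tokens=50, total_tokens=150)
assert usage.input_tokens_details is None
assert usage.output_tokens_details is None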
@@ -1173,3 +1186,20 @@ ResponsesAPIStreamingResponse = Annotated[
REASONING_EFFORT = Literal["low", "medium", "high"]
class OpenAIRealtimeStreamSessionEvents(TypedDict):
event_id: str
session: dict
type: Union[Literal["session.created"], Literal["session.updated"]]
class OpenAIRealtimeStreamResponseBaseObject(TypedDict):
event_id: str
response: dict
type: str
OpenAIRealtimeStreamList = List[
Union[OpenAIRealtimeStreamResponseBaseObject, OpenAIRealtimeStreamSessionEvents]
]
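Since the new event types are TypedDicts, plain dicts satisfy them at type-check time; a brief illustration (event contents are made up):

from litellm.types.llms.openai import (
    OpenAIRealtimeStreamList,
    OpenAIRealtimeStreamResponseBaseObject,
    OpenAIRealtimeStreamSessionEvents,
)

session_event: OpenAIRealtimeStreamSessionEvents = {
    "event_id": "evt_1",
    "session": {"model": "gpt-4o-realtime-preview"},
    "type": "session.created",
}
done_event: OpenAIRealtimeStreamResponseBaseObject = {
    "event_id": "evt_2",
    "response": {"usage": {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}},
    "type": "response.done",
}
stream: OpenAIRealtimeStreamList = [session_event, done_event]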

View file

@@ -0,0 +1,65 @@
import json
import os
import sys
from unittest.mock import MagicMock, patch
import pytest
sys.path.insert(
0, os.path.abspath("../../..")
) # Adds the parent directory to the system path
from litellm.litellm_core_utils.realtime_streaming import RealTimeStreaming
from litellm.types.llms.openai import (
OpenAIRealtimeStreamResponseBaseObject,
OpenAIRealtimeStreamSessionEvents,
)
def test_realtime_streaming_store_message():
# Setup
websocket = MagicMock()
backend_ws = MagicMock()
logging_obj = MagicMock()
streaming = RealTimeStreaming(websocket, backend_ws, logging_obj)
# Test 1: Session created event (string input)
session_created_msg = json.dumps(
{"type": "session.created", "session": {"id": "test-session"}}
)
streaming.store_message(session_created_msg)
assert len(streaming.messages) == 1
assert "session" in streaming.messages[0]
assert streaming.messages[0]["type"] == "session.created"
# Test 2: Response object (bytes input)
response_msg = json.dumps(
{
"type": "response.create",
"event_id": "test-event",
"response": {"text": "test response"},
}
).encode("utf-8")
streaming.store_message(response_msg)
assert len(streaming.messages) == 2
assert "response" in streaming.messages[1]
assert streaming.messages[1]["type"] == "response.create"
# Test 3: Invalid message format
invalid_msg = "invalid json"
with pytest.raises(Exception):
streaming.store_message(invalid_msg)
# Test 4: Message type not in logged events
streaming.logged_real_time_event_types = [
"session.created"
] # Only log session.created
other_msg = json.dumps(
{
"type": "response.done",
"event_id": "test-event",
"response": {"text": "test response"},
}
)
streaming.store_message(other_msg)
assert len(streaming.messages) == 2 # Should not store the new message

View file

@@ -13,7 +13,11 @@ from unittest.mock import MagicMock, patch
from pydantic import BaseModel
import litellm
from litellm.cost_calculator import response_cost_calculator
from litellm.cost_calculator import (
handle_realtime_stream_cost_calculation,
response_cost_calculator,
)
from litellm.types.llms.openai import OpenAIRealtimeStreamList
from litellm.types.utils import ModelResponse, PromptTokensDetailsWrapper, Usage
@@ -71,3 +75,66 @@ def test_cost_calculator_with_usage():
)
assert result == expected_cost, f"Got {result}, Expected {expected_cost}"
def test_handle_realtime_stream_cost_calculation():
# Setup test data
results: OpenAIRealtimeStreamList = [
{"type": "session.created", "session": {"model": "gpt-3.5-turbo"}},
{
"type": "response.done",
"response": {
"usage": {"input_tokens": 100, "output_tokens": 50, "total_tokens": 150}
},
},
{
"type": "response.done",
"response": {
"usage": {
"input_tokens": 200,
"output_tokens": 100,
"total_tokens": 300,
}
},
},
]
# Test with explicit model name
cost = handle_realtime_stream_cost_calculation(
results=results,
custom_llm_provider="openai",
litellm_model_name="gpt-3.5-turbo",
)
# Calculate expected cost
# gpt-3.5-turbo costs: $0.0015/1K tokens input, $0.002/1K tokens output
expected_cost = (300 * 0.0015 / 1000) + ( # input tokens (100 + 200)
150 * 0.002 / 1000
) # output tokens (50 + 100)
assert (
abs(cost - expected_cost) <= 0.00075
) # Allow small floating point differences
# Test with different model name in session
results[0]["session"]["model"] = "gpt-4"
cost = handle_realtime_stream_cost_calculation(
results=results,
custom_llm_provider="openai",
litellm_model_name="gpt-3.5-turbo",
)
# Calculate expected cost using gpt-4 rates
# gpt-4 costs: $0.03/1K tokens input, $0.06/1K tokens output
expected_cost = (300 * 0.03 / 1000) + ( # input tokens
150 * 0.06 / 1000
) # output tokens
assert abs(cost - expected_cost) < 0.00076
# Test with no response.done events
results = [{"type": "session.created", "session": {"model": "gpt-3.5-turbo"}}]
cost = handle_realtime_stream_cost_calculation(
results=results,
custom_llm_provider="openai",
litellm_model_name="gpt-3.5-turbo",
)
assert cost == 0.0 # No usage, no cost
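For reference, the arithmetic behind the two expected_cost values above, at the per-1K rates stated in the test comments:

# 300 input tokens and 150 output tokens accumulated across both response.done events
expected_gpt35 = (300 * 0.0015 / 1000) + (150 * 0.002 / 1000)  # 0.00045 + 0.00030 = 0.00075
expected_gpt4 = (300 * 0.03 / 1000) + (150 * 0.06 / 1000)      # 0.00900 + 0.00900 = 0.01800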

View file

@@ -0,0 +1,39 @@
============================= test session starts ==============================
platform darwin -- Python 3.11.4, pytest-7.4.1, pluggy-1.2.0 -- /Library/Frameworks/Python.framework/Versions/3.11/bin/python3
cachedir: .pytest_cache
rootdir: /Users/krrishdholakia/Documents/litellm/tests/logging_callback_tests
plugins: snapshot-0.9.0, cov-5.0.0, timeout-2.2.0, postgresql-7.0.1, respx-0.21.1, asyncio-0.21.1, langsmith-0.3.4, anyio-4.8.0, mock-3.11.1, Faker-25.9.2
asyncio: mode=Mode.STRICT
collecting ... collected 4 items
test_built_in_tools_cost_tracking.py::test_openai_responses_api_web_search_cost_tracking[tools_config0-search_context_size_low-True] PASSED [ 25%]
test_built_in_tools_cost_tracking.py::test_openai_responses_api_web_search_cost_tracking[tools_config1-search_context_size_low-False] PASSED [ 50%]
test_built_in_tools_cost_tracking.py::test_openai_responses_api_web_search_cost_tracking[tools_config2-search_context_size_medium-True] PASSED [ 75%]
test_built_in_tools_cost_tracking.py::test_openai_responses_api_web_search_cost_tracking[tools_config3-search_context_size_medium-False] PASSED [100%]
=============================== warnings summary ===============================
../../../../../../Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/pydantic/_internal/_config.py:295
/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/pydantic/_internal/_config.py:295: PydanticDeprecatedSince20: Support for class-based `config` is deprecated, use ConfigDict instead. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.10/migration/
warnings.warn(DEPRECATION_MESSAGE, DeprecationWarning)
../../litellm/litellm_core_utils/get_model_cost_map.py:24
test_built_in_tools_cost_tracking.py::test_openai_responses_api_web_search_cost_tracking[tools_config0-search_context_size_low-True]
test_built_in_tools_cost_tracking.py::test_openai_responses_api_web_search_cost_tracking[tools_config1-search_context_size_low-False]
test_built_in_tools_cost_tracking.py::test_openai_responses_api_web_search_cost_tracking[tools_config2-search_context_size_medium-True]
test_built_in_tools_cost_tracking.py::test_openai_responses_api_web_search_cost_tracking[tools_config3-search_context_size_medium-False]
/Users/krrishdholakia/Documents/litellm/litellm/litellm_core_utils/get_model_cost_map.py:24: DeprecationWarning: open_text is deprecated. Use files() instead. Refer to https://importlib-resources.readthedocs.io/en/latest/using.html#migrating-from-legacy for migration advice.
with importlib.resources.open_text(
../../litellm/utils.py:183
/Users/krrishdholakia/Documents/litellm/litellm/utils.py:183: DeprecationWarning: open_text is deprecated. Use files() instead. Refer to https://importlib-resources.readthedocs.io/en/latest/using.html#migrating-from-legacy for migration advice.
with resources.open_text(
test_built_in_tools_cost_tracking.py::test_openai_responses_api_web_search_cost_tracking[tools_config0-search_context_size_low-True]
test_built_in_tools_cost_tracking.py::test_openai_responses_api_web_search_cost_tracking[tools_config1-search_context_size_low-False]
test_built_in_tools_cost_tracking.py::test_openai_responses_api_web_search_cost_tracking[tools_config2-search_context_size_medium-True]
test_built_in_tools_cost_tracking.py::test_openai_responses_api_web_search_cost_tracking[tools_config3-search_context_size_medium-False]
/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_content.py:204: DeprecationWarning: Use 'content=<...>' to upload raw bytes/text content.
warnings.warn(message, DeprecationWarning)
-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
======================= 4 passed, 11 warnings in 18.95s ========================