Realtime API: Support 'base_model' cost tracking + show response in spend logs (if enabled) (#9897)

* refactor(litellm_logging.py): refactor realtime cost tracking to use the same common code path as the rest of the API

Ensures basic features like 'base_model' just work
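
A rough sketch of that shared path (illustrative only; it assumes the proxy has already collected the realtime websocket events into a list, stubbed out below as an empty 'realtime_results'): the stream results are reduced to a single Usage object and wrapped in a logging object, which then flows through the same response-cost / spend-log code as any other response type.

    from litellm.cost_calculator import RealtimeAPITokenUsageProcessor

    # Stand-in for the OpenAIRealtimeStreamList of websocket events collected
    # during a realtime session (empty here purely for illustration).
    realtime_results: list = []

    # Reduce all per-event usage blocks into one combined Usage object.
    combined_usage = RealtimeAPITokenUsageProcessor.collect_and_combine_usage_from_realtime_stream_results(
        results=realtime_results
    )

    # Wrap usage + raw results so the normal cost / spend-log pipeline can
    # treat the realtime session like any other response object.
    logging_result = RealtimeAPITokenUsageProcessor.create_logging_realtime_object(
        usage=combined_usage,
        results=realtime_results,
    )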

* feat(realtime/): support 'base_model' cost tracking on realtime api

Fixes an issue where 'base_model' was not applied to realtime API requests
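
As a small illustration of what 'base_model' buys here (a sketch, not the shipped implementation; the token counts are invented and the lookup assumes the base model exists in litellm's model cost map): the proxy deployment is registered under an alias, but spend is computed against the pricing of the model named in model_info.base_model.

    import litellm

    # Alias clients call vs. the priced base model, taken from the proxy
    # config in this PR; token counts below are made up for the example.
    deployment_alias = "gpt-4o-realtime-preview"
    base_model = "azure/gpt-4o-realtime-preview-2024-10-01"  # model_info.base_model

    # Pricing is looked up for the base model, not the alias.
    prompt_cost, completion_cost = litellm.cost_per_token(
        model=base_model,
        prompt_tokens=1200,    # combined prompt tokens across the realtime session
        completion_tokens=800,
    )
    print(f"{deployment_alias}: ${prompt_cost + completion_cost:.6f}")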

* fix: fix ruff linting error

* test: fix test
Krish Dholakia 2025-04-10 21:24:45 -07:00 committed by GitHub
parent 78879c68a9
commit 9f27e8363f
6 changed files with 102 additions and 46 deletions


@@ -66,6 +66,7 @@ from litellm.types.llms.openai import (
 from litellm.types.rerank import RerankBilledUnits, RerankResponse
 from litellm.types.utils import (
     CallTypesLiteral,
+    LiteLLMRealtimeStreamLoggingObject,
     LlmProviders,
     LlmProvidersSet,
     ModelInfo,
@@ -617,6 +618,7 @@ def completion_cost(  # noqa: PLR0915
         completion_response=completion_response
     )
     rerank_billed_units: Optional[RerankBilledUnits] = None
     selected_model = _select_model_name_for_cost_calc(
         model=model,
         completion_response=completion_response,
@@ -792,6 +794,25 @@
                     billed_units.get("search_units") or 1
                 )  # cohere charges per request by default.
                 completion_tokens = search_units
+        elif call_type == CallTypes.arealtime.value and isinstance(
+            completion_response, LiteLLMRealtimeStreamLoggingObject
+        ):
+            if (
+                cost_per_token_usage_object is None
+                or custom_llm_provider is None
+            ):
+                raise ValueError(
+                    "usage object and custom_llm_provider must be provided for realtime stream cost calculation. Got cost_per_token_usage_object={}, custom_llm_provider={}".format(
+                        cost_per_token_usage_object,
+                        custom_llm_provider,
+                    )
+                )
+            return handle_realtime_stream_cost_calculation(
+                results=completion_response.results,
+                combined_usage_object=cost_per_token_usage_object,
+                custom_llm_provider=custom_llm_provider,
+                litellm_model_name=model,
+            )
         # Calculate cost based on prompt_tokens, completion_tokens
         if (
             "togethercomputer" in model
@@ -921,6 +942,7 @@ def response_cost_calculator(
         HttpxBinaryResponseContent,
         RerankResponse,
         ResponsesAPIResponse,
+        LiteLLMRealtimeStreamLoggingObject,
     ],
     model: str,
     custom_llm_provider: Optional[str],
@@ -1274,6 +1296,15 @@ class RealtimeAPITokenUsageProcessor:
         )
         return combined_usage_object

+    @staticmethod
+    def create_logging_realtime_object(
+        usage: Usage, results: OpenAIRealtimeStreamList
+    ) -> LiteLLMRealtimeStreamLoggingObject:
+        return LiteLLMRealtimeStreamLoggingObject(
+            usage=usage,
+            results=results,
+        )
+

 def handle_realtime_stream_cost_calculation(
     results: OpenAIRealtimeStreamList,


@@ -35,7 +35,6 @@ from litellm.constants import (
 from litellm.cost_calculator import (
     RealtimeAPITokenUsageProcessor,
     _select_model_name_for_cost_calc,
-    handle_realtime_stream_cost_calculation,
 )
 from litellm.integrations.arize.arize import ArizeLogger
 from litellm.integrations.custom_guardrail import CustomGuardrail
@@ -68,6 +67,7 @@ from litellm.types.utils import (
     ImageResponse,
     LiteLLMBatch,
     LiteLLMLoggingBaseClass,
+    LiteLLMRealtimeStreamLoggingObject,
     ModelResponse,
     ModelResponseStream,
     RawRequestTypedDict,
@@ -902,6 +902,7 @@ class Logging(LiteLLMLoggingBaseClass):
             FineTuningJob,
             ResponsesAPIResponse,
             ResponseCompletedEvent,
+            LiteLLMRealtimeStreamLoggingObject,
         ],
         cache_hit: Optional[bool] = None,
         litellm_model_name: Optional[str] = None,
@@ -1055,39 +1056,49 @@
             ## if model in model cost map - log the response cost
             ## else set cost to None
+            logging_result = result
             if self.call_type == CallTypes.arealtime.value and isinstance(result, list):
                 combined_usage_object = RealtimeAPITokenUsageProcessor.collect_and_combine_usage_from_realtime_stream_results(
                     results=result
                 )
-                self.model_call_details[
-                    "response_cost"
-                ] = handle_realtime_stream_cost_calculation(
-                    results=result,
-                    combined_usage_object=combined_usage_object,
-                    custom_llm_provider=self.custom_llm_provider,
-                    litellm_model_name=self.model,
-                )
-                self.model_call_details["combined_usage_object"] = combined_usage_object
+                logging_result = (
+                    RealtimeAPITokenUsageProcessor.create_logging_realtime_object(
+                        usage=combined_usage_object,
+                        results=result,
+                    )
+                )
+                # self.model_call_details[
+                #     "response_cost"
+                # ] = handle_realtime_stream_cost_calculation(
+                #     results=result,
+                #     combined_usage_object=combined_usage_object,
+                #     custom_llm_provider=self.custom_llm_provider,
+                #     litellm_model_name=self.model,
+                # )
+                # self.model_call_details["combined_usage_object"] = combined_usage_object
             if (
                 standard_logging_object is None
                 and result is not None
                 and self.stream is not True
             ):
                 if (
-                    isinstance(result, ModelResponse)
-                    or isinstance(result, ModelResponseStream)
-                    or isinstance(result, EmbeddingResponse)
-                    or isinstance(result, ImageResponse)
-                    or isinstance(result, TranscriptionResponse)
-                    or isinstance(result, TextCompletionResponse)
-                    or isinstance(result, HttpxBinaryResponseContent)  # tts
-                    or isinstance(result, RerankResponse)
-                    or isinstance(result, FineTuningJob)
-                    or isinstance(result, LiteLLMBatch)
-                    or isinstance(result, ResponsesAPIResponse)
+                    isinstance(logging_result, ModelResponse)
+                    or isinstance(logging_result, ModelResponseStream)
+                    or isinstance(logging_result, EmbeddingResponse)
+                    or isinstance(logging_result, ImageResponse)
+                    or isinstance(logging_result, TranscriptionResponse)
+                    or isinstance(logging_result, TextCompletionResponse)
+                    or isinstance(logging_result, HttpxBinaryResponseContent)  # tts
+                    or isinstance(logging_result, RerankResponse)
+                    or isinstance(logging_result, FineTuningJob)
+                    or isinstance(logging_result, LiteLLMBatch)
+                    or isinstance(logging_result, ResponsesAPIResponse)
+                    or isinstance(logging_result, LiteLLMRealtimeStreamLoggingObject)
                 ):
                     ## HIDDEN PARAMS ##
-                    hidden_params = getattr(result, "_hidden_params", {})
+                    hidden_params = getattr(logging_result, "_hidden_params", {})
                     if hidden_params:
                         # add to metadata for logging
                         if self.model_call_details.get("litellm_params") is not None:
@@ -1105,7 +1116,7 @@
                         self.model_call_details["litellm_params"]["metadata"][  # type: ignore
                             "hidden_params"
                         ] = getattr(
-                            result, "_hidden_params", {}
+                            logging_result, "_hidden_params", {}
                         )
                     ## RESPONSE COST - Only calculate if not in hidden_params ##
                     if "response_cost" in hidden_params:
@@ -1115,14 +1126,14 @@
                     else:
                         self.model_call_details[
                             "response_cost"
-                        ] = self._response_cost_calculator(result=result)
+                        ] = self._response_cost_calculator(result=logging_result)

                     ## STANDARDIZED LOGGING PAYLOAD
                     self.model_call_details[
                         "standard_logging_object"
                     ] = get_standard_logging_object_payload(
                         kwargs=self.model_call_details,
-                        init_response_obj=result,
+                        init_response_obj=logging_result,
                         start_time=start_time,
                         end_time=end_time,
                         logging_obj=self,


@@ -8,10 +8,6 @@ model_list:
     litellm_params:
       model: gpt-4o-mini
       api_key: os.environ/OPENAI_API_KEY
-  - model_name: "openai/*"
-    litellm_params:
-      model: openai/*
-      api_key: os.environ/OPENAI_API_KEY
   - model_name: "bedrock-nova"
     litellm_params:
       model: us.amazon.nova-pro-v1:0
@@ -29,14 +25,13 @@ model_list:
       model: databricks/databricks-claude-3-7-sonnet
       api_key: os.environ/DATABRICKS_API_KEY
       api_base: os.environ/DATABRICKS_API_BASE
-  - model_name: "llmaas-meta/llama-3.1-8b-instruct"
+  - model_name: "gpt-4o-realtime-preview"
     litellm_params:
-      model: nvidia_nim/meta/llama-3.3-70b-instruct
-      api_key: "invalid"
-      api_base: "http://0.0.0.0:8090"
+      model: azure/gpt-4o-realtime-preview-2
+      api_key: os.environ/AZURE_API_KEY_REALTIME
+      api_base: https://krris-m2f9a9i7-eastus2.openai.azure.com/
     model_info:
-      input_cost_per_token: "100"
-      output_cost_per_token: "100"
+      base_model: azure/gpt-4o-realtime-preview-2024-10-01

 litellm_settings:
   num_retries: 0


@@ -7,6 +7,7 @@ from litellm import get_llm_provider
 from litellm.secret_managers.main import get_secret_str
 from litellm.types.router import GenericLiteLLMParams

+from ..litellm_core_utils.get_litellm_params import get_litellm_params
 from ..litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
 from ..llms.azure.realtime.handler import AzureOpenAIRealtime
 from ..llms.openai.realtime.handler import OpenAIRealtime
@@ -34,13 +35,11 @@ async def _arealtime(
     For PROXY use only.
     """
     litellm_logging_obj: LiteLLMLogging = kwargs.get("litellm_logging_obj")  # type: ignore
-    litellm_call_id: Optional[str] = kwargs.get("litellm_call_id", None)
-    proxy_server_request = kwargs.get("proxy_server_request", None)
-    model_info = kwargs.get("model_info", None)
-    metadata = kwargs.get("metadata", {})
     user = kwargs.get("user", None)
     litellm_params = GenericLiteLLMParams(**kwargs)
+    litellm_params_dict = get_litellm_params(**kwargs)

     model, _custom_llm_provider, dynamic_api_key, dynamic_api_base = get_llm_provider(
         model=model,
         api_base=api_base,
@@ -51,14 +50,7 @@
         model=model,
         user=user,
         optional_params={},
-        litellm_params={
-            "litellm_call_id": litellm_call_id,
-            "proxy_server_request": proxy_server_request,
-            "model_info": model_info,
-            "metadata": metadata,
-            "preset_cache_key": None,
-            "stream_response": {},
-        },
+        litellm_params=litellm_params_dict,
         custom_llm_provider=_custom_llm_provider,
     )


@@ -2138,6 +2138,7 @@ class Router:
                 request_kwargs=kwargs,
             )

+            self._update_kwargs_with_deployment(deployment=deployment, kwargs=kwargs)
             data = deployment["litellm_params"].copy()
             for k, v in self.default_litellm_params.items():
                 if (


@@ -34,6 +34,7 @@ from .llms.openai import (
     ChatCompletionUsageBlock,
     FileSearchTool,
     OpenAIChatCompletionChunk,
+    OpenAIRealtimeStreamList,
     WebSearchOptions,
 )
 from .rerank import RerankResponse
@@ -2152,6 +2153,31 @@ class LiteLLMBatch(Batch):
             return self.dict()


+class LiteLLMRealtimeStreamLoggingObject(LiteLLMPydanticObjectBase):
+    results: OpenAIRealtimeStreamList
+    usage: Usage
+    _hidden_params: dict = {}
+
+    def __contains__(self, key):
+        # Define custom behavior for the 'in' operator
+        return hasattr(self, key)
+
+    def get(self, key, default=None):
+        # Custom .get() method to access attributes with a default value if the attribute doesn't exist
+        return getattr(self, key, default)
+
+    def __getitem__(self, key):
+        # Allow dictionary-style access to attributes
+        return getattr(self, key)
+
+    def json(self, **kwargs):  # type: ignore
+        try:
+            return self.model_dump()  # noqa
+        except Exception:
+            # if using pydantic v1
+            return self.dict()
+
+
 class RawRequestTypedDict(TypedDict, total=False):
     raw_request_api_base: Optional[str]
     raw_request_body: Optional[dict]
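
For reference, the new wrapper behaves like a read-only mapping over its fields; a small usage illustration (values invented):

    from litellm.types.utils import LiteLLMRealtimeStreamLoggingObject, Usage

    obj = LiteLLMRealtimeStreamLoggingObject(
        usage=Usage(prompt_tokens=10, completion_tokens=5, total_tokens=15),
        results=[],
    )

    assert "usage" in obj                  # __contains__ falls back to hasattr
    assert obj.get("not_a_field") is None  # .get() with a default
    print(obj["results"])                  # dictionary-style attribute access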