Realtime API: Support 'base_model' cost tracking + show response in spend logs (if enabled) (#9897)

* refactor(litellm_logging.py): refactor realtime cost tracking to use the same common code path as the rest of the endpoints

Ensures basic features like base model just work

* feat(realtime/): support 'base_model' cost tracking on realtime api

Fixes an issue where 'base_model' was not being applied on the realtime API (see the sketch below the commit message)

* fix: fix ruff linting error

* test: fix test
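
For context on why 'base_model' matters here: Azure-style deployment aliases are usually not present in LiteLLM's model cost map, so cost tracking has to fall back to the configured base model. A minimal sketch of that lookup, using the deployment alias and base model from the proxy config change later in this commit (names are illustrative):

```python
import litellm

# The deployment alias is typically absent from the cost map, so pricing
# must come from the configured base_model instead.
deployment_alias = "azure/gpt-4o-realtime-preview-2"     # from the proxy config below
base_model = "azure/gpt-4o-realtime-preview-2024-10-01"  # configured as base_model

print(deployment_alias in litellm.model_cost)  # usually False
print(base_model in litellm.model_cost)        # True, assuming an up-to-date cost map
```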
Krish Dholakia, 2025-04-10 21:24:45 -07:00, committed by GitHub
parent 78879c68a9
commit 9f27e8363f
6 changed files with 102 additions and 46 deletions


@@ -66,6 +66,7 @@ from litellm.types.llms.openai import (
from litellm.types.rerank import RerankBilledUnits, RerankResponse
from litellm.types.utils import (
CallTypesLiteral,
LiteLLMRealtimeStreamLoggingObject,
LlmProviders,
LlmProvidersSet,
ModelInfo,
@@ -617,6 +618,7 @@ def completion_cost( # noqa: PLR0915
completion_response=completion_response
)
rerank_billed_units: Optional[RerankBilledUnits] = None
selected_model = _select_model_name_for_cost_calc(
model=model,
completion_response=completion_response,
@@ -792,6 +794,25 @@ def completion_cost( # noqa: PLR0915
billed_units.get("search_units") or 1
) # cohere charges per request by default.
completion_tokens = search_units
elif call_type == CallTypes.arealtime.value and isinstance(
completion_response, LiteLLMRealtimeStreamLoggingObject
):
if (
cost_per_token_usage_object is None
or custom_llm_provider is None
):
raise ValueError(
"usage object and custom_llm_provider must be provided for realtime stream cost calculation. Got cost_per_token_usage_object={}, custom_llm_provider={}".format(
cost_per_token_usage_object,
custom_llm_provider,
)
)
return handle_realtime_stream_cost_calculation(
results=completion_response.results,
combined_usage_object=cost_per_token_usage_object,
custom_llm_provider=custom_llm_provider,
litellm_model_name=model,
)
# Calculate cost based on prompt_tokens, completion_tokens
if (
"togethercomputer" in model
@@ -921,6 +942,7 @@ def response_cost_calculator(
HttpxBinaryResponseContent,
RerankResponse,
ResponsesAPIResponse,
LiteLLMRealtimeStreamLoggingObject,
],
model: str,
custom_llm_provider: Optional[str],
@@ -1274,6 +1296,15 @@ class RealtimeAPITokenUsageProcessor:
)
return combined_usage_object
@staticmethod
def create_logging_realtime_object(
usage: Usage, results: OpenAIRealtimeStreamList
) -> LiteLLMRealtimeStreamLoggingObject:
return LiteLLMRealtimeStreamLoggingObject(
usage=usage,
results=results,
)
def handle_realtime_stream_cost_calculation(
results: OpenAIRealtimeStreamList,
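
A minimal usage sketch of the new helper, mirroring how the logging code in the next file wires it up (here `results` is a placeholder for the realtime stream events collected for a single call):

```python
from litellm.cost_calculator import RealtimeAPITokenUsageProcessor

results = []  # placeholder: the OpenAIRealtimeStreamList gathered for one realtime call

# Combine the token usage reported across the realtime stream events...
usage = RealtimeAPITokenUsageProcessor.collect_and_combine_usage_from_realtime_stream_results(
    results=results
)
# ...then wrap usage + raw events so the standard cost/spend-log path can handle them.
logging_result = RealtimeAPITokenUsageProcessor.create_logging_realtime_object(
    usage=usage,
    results=results,
)
```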


@@ -35,7 +35,6 @@ from litellm.constants import (
from litellm.cost_calculator import (
RealtimeAPITokenUsageProcessor,
_select_model_name_for_cost_calc,
handle_realtime_stream_cost_calculation,
)
from litellm.integrations.arize.arize import ArizeLogger
from litellm.integrations.custom_guardrail import CustomGuardrail
@@ -68,6 +67,7 @@ from litellm.types.utils import (
ImageResponse,
LiteLLMBatch,
LiteLLMLoggingBaseClass,
LiteLLMRealtimeStreamLoggingObject,
ModelResponse,
ModelResponseStream,
RawRequestTypedDict,
@@ -902,6 +902,7 @@ class Logging(LiteLLMLoggingBaseClass):
FineTuningJob,
ResponsesAPIResponse,
ResponseCompletedEvent,
LiteLLMRealtimeStreamLoggingObject,
],
cache_hit: Optional[bool] = None,
litellm_model_name: Optional[str] = None,
@@ -1055,39 +1056,49 @@
## if model in model cost map - log the response cost
## else set cost to None
logging_result = result
if self.call_type == CallTypes.arealtime.value and isinstance(result, list):
combined_usage_object = RealtimeAPITokenUsageProcessor.collect_and_combine_usage_from_realtime_stream_results(
results=result
)
self.model_call_details[
"response_cost"
] = handle_realtime_stream_cost_calculation(
results=result,
combined_usage_object=combined_usage_object,
custom_llm_provider=self.custom_llm_provider,
litellm_model_name=self.model,
logging_result = (
RealtimeAPITokenUsageProcessor.create_logging_realtime_object(
usage=combined_usage_object,
results=result,
)
)
self.model_call_details["combined_usage_object"] = combined_usage_object
# self.model_call_details[
# "response_cost"
# ] = handle_realtime_stream_cost_calculation(
# results=result,
# combined_usage_object=combined_usage_object,
# custom_llm_provider=self.custom_llm_provider,
# litellm_model_name=self.model,
# )
# self.model_call_details["combined_usage_object"] = combined_usage_object
if (
standard_logging_object is None
and result is not None
and self.stream is not True
):
if (
isinstance(result, ModelResponse)
or isinstance(result, ModelResponseStream)
or isinstance(result, EmbeddingResponse)
or isinstance(result, ImageResponse)
or isinstance(result, TranscriptionResponse)
or isinstance(result, TextCompletionResponse)
or isinstance(result, HttpxBinaryResponseContent) # tts
or isinstance(result, RerankResponse)
or isinstance(result, FineTuningJob)
or isinstance(result, LiteLLMBatch)
or isinstance(result, ResponsesAPIResponse)
isinstance(logging_result, ModelResponse)
or isinstance(logging_result, ModelResponseStream)
or isinstance(logging_result, EmbeddingResponse)
or isinstance(logging_result, ImageResponse)
or isinstance(logging_result, TranscriptionResponse)
or isinstance(logging_result, TextCompletionResponse)
or isinstance(logging_result, HttpxBinaryResponseContent) # tts
or isinstance(logging_result, RerankResponse)
or isinstance(logging_result, FineTuningJob)
or isinstance(logging_result, LiteLLMBatch)
or isinstance(logging_result, ResponsesAPIResponse)
or isinstance(logging_result, LiteLLMRealtimeStreamLoggingObject)
):
## HIDDEN PARAMS ##
hidden_params = getattr(result, "_hidden_params", {})
hidden_params = getattr(logging_result, "_hidden_params", {})
if hidden_params:
# add to metadata for logging
if self.model_call_details.get("litellm_params") is not None:
@@ -1105,7 +1116,7 @@ class Logging(LiteLLMLoggingBaseClass):
self.model_call_details["litellm_params"]["metadata"][ # type: ignore
"hidden_params"
] = getattr(
result, "_hidden_params", {}
logging_result, "_hidden_params", {}
)
## RESPONSE COST - Only calculate if not in hidden_params ##
if "response_cost" in hidden_params:
@@ -1115,14 +1126,14 @@
else:
self.model_call_details[
"response_cost"
] = self._response_cost_calculator(result=result)
] = self._response_cost_calculator(result=logging_result)
## STANDARDIZED LOGGING PAYLOAD
self.model_call_details[
"standard_logging_object"
] = get_standard_logging_object_payload(
kwargs=self.model_call_details,
init_response_obj=result,
init_response_obj=logging_result,
start_time=start_time,
end_time=end_time,
logging_obj=self,


@@ -8,10 +8,6 @@ model_list:
litellm_params:
model: gpt-4o-mini
api_key: os.environ/OPENAI_API_KEY
- model_name: "openai/*"
litellm_params:
model: openai/*
api_key: os.environ/OPENAI_API_KEY
- model_name: "bedrock-nova"
litellm_params:
model: us.amazon.nova-pro-v1:0
@@ -29,14 +25,13 @@ model_list:
model: databricks/databricks-claude-3-7-sonnet
api_key: os.environ/DATABRICKS_API_KEY
api_base: os.environ/DATABRICKS_API_BASE
- model_name: "llmaas-meta/llama-3.1-8b-instruct"
- model_name: "gpt-4o-realtime-preview"
litellm_params:
model: nvidia_nim/meta/llama-3.3-70b-instruct
api_key: "invalid"
api_base: "http://0.0.0.0:8090"
model: azure/gpt-4o-realtime-preview-2
api_key: os.environ/AZURE_API_KEY_REALTIME
api_base: https://krris-m2f9a9i7-eastus2.openai.azure.com/
model_info:
input_cost_per_token: "100"
output_cost_per_token: "100"
base_model: azure/gpt-4o-realtime-preview-2024-10-01
litellm_settings:
num_retries: 0
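
With the config above, spend for the gpt-4o-realtime-preview deployment should be priced from the configured base_model. A quick way to sanity-check that pricing data exists for that base model (a sketch; exact per-token prices depend on the installed cost map):

```python
import litellm

# Assumes the base model entry exists in the installed cost map.
info = litellm.get_model_info(model="azure/gpt-4o-realtime-preview-2024-10-01")
print(info["input_cost_per_token"], info["output_cost_per_token"])
```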


@@ -7,6 +7,7 @@ from litellm import get_llm_provider
from litellm.secret_managers.main import get_secret_str
from litellm.types.router import GenericLiteLLMParams
from ..litellm_core_utils.get_litellm_params import get_litellm_params
from ..litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
from ..llms.azure.realtime.handler import AzureOpenAIRealtime
from ..llms.openai.realtime.handler import OpenAIRealtime
@@ -34,13 +35,11 @@ async def _arealtime(
For PROXY use only.
"""
litellm_logging_obj: LiteLLMLogging = kwargs.get("litellm_logging_obj") # type: ignore
litellm_call_id: Optional[str] = kwargs.get("litellm_call_id", None)
proxy_server_request = kwargs.get("proxy_server_request", None)
model_info = kwargs.get("model_info", None)
metadata = kwargs.get("metadata", {})
user = kwargs.get("user", None)
litellm_params = GenericLiteLLMParams(**kwargs)
litellm_params_dict = get_litellm_params(**kwargs)
model, _custom_llm_provider, dynamic_api_key, dynamic_api_base = get_llm_provider(
model=model,
api_base=api_base,
@@ -51,14 +50,7 @@ async def _arealtime(
model=model,
user=user,
optional_params={},
litellm_params={
"litellm_call_id": litellm_call_id,
"proxy_server_request": proxy_server_request,
"model_info": model_info,
"metadata": metadata,
"preset_cache_key": None,
"stream_response": {},
},
litellm_params=litellm_params_dict,
custom_llm_provider=_custom_llm_provider,
)


@@ -2138,6 +2138,7 @@ class Router:
request_kwargs=kwargs,
)
self._update_kwargs_with_deployment(deployment=deployment, kwargs=kwargs)
data = deployment["litellm_params"].copy()
for k, v in self.default_litellm_params.items():
if (


@@ -34,6 +34,7 @@ from .llms.openai import (
ChatCompletionUsageBlock,
FileSearchTool,
OpenAIChatCompletionChunk,
OpenAIRealtimeStreamList,
WebSearchOptions,
)
from .rerank import RerankResponse
@@ -2152,6 +2153,31 @@ class LiteLLMBatch(Batch):
return self.dict()
class LiteLLMRealtimeStreamLoggingObject(LiteLLMPydanticObjectBase):
results: OpenAIRealtimeStreamList
usage: Usage
_hidden_params: dict = {}
def __contains__(self, key):
# Define custom behavior for the 'in' operator
return hasattr(self, key)
def get(self, key, default=None):
# Custom .get() method to access attributes with a default value if the attribute doesn't exist
return getattr(self, key, default)
def __getitem__(self, key):
# Allow dictionary-style access to attributes
return getattr(self, key)
def json(self, **kwargs): # type: ignore
try:
return self.model_dump() # noqa
except Exception:
# if using pydantic v1
return self.dict()
class RawRequestTypedDict(TypedDict, total=False):
raw_request_api_base: Optional[str]
raw_request_body: Optional[dict]
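
The helper methods on LiteLLMRealtimeStreamLoggingObject give it dict-style ergonomics, which the logging code relies on when it treats the realtime result like any other response object. A small sketch of that behavior (placeholder values):

```python
from litellm.types.utils import LiteLLMRealtimeStreamLoggingObject, Usage

obj = LiteLLMRealtimeStreamLoggingObject(
    results=[],  # placeholder: realtime stream events
    usage=Usage(prompt_tokens=10, completion_tokens=5, total_tokens=15),
)

print("usage" in obj)             # True  - __contains__ checks attributes
print(obj.get("missing", None))   # None  - .get() falls back to the default
print(obj["usage"].total_tokens)  # 15    - dict-style access via __getitem__
print(obj.json())                 # plain dict via model_dump() (or .dict() on pydantic v1)
```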