mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-24 18:24:20 +00:00)
Realtime API: Support 'base_model' cost tracking + show response in spend logs (if enabled) (#9897)
* refactor(litellm_logging.py): refactor realtime cost tracking to use the same common code path as the rest of the library; ensures basic features like base model just work
* feat(realtime/): support 'base_model' cost tracking on the realtime API; fixes base model not being applied on realtime calls
* fix: fix ruff linting error
* test: fix test
This commit is contained in:
parent 78879c68a9, commit 9f27e8363f
6 changed files with 102 additions and 46 deletions
litellm/cost_calculator.py

@@ -66,6 +66,7 @@ from litellm.types.llms.openai import (
 from litellm.types.rerank import RerankBilledUnits, RerankResponse
 from litellm.types.utils import (
     CallTypesLiteral,
+    LiteLLMRealtimeStreamLoggingObject,
     LlmProviders,
     LlmProvidersSet,
     ModelInfo,
@@ -617,6 +618,7 @@ def completion_cost(  # noqa: PLR0915
         completion_response=completion_response
     )
     rerank_billed_units: Optional[RerankBilledUnits] = None

     selected_model = _select_model_name_for_cost_calc(
         model=model,
         completion_response=completion_response,
@@ -792,6 +794,25 @@ def completion_cost(  # noqa: PLR0915
                 billed_units.get("search_units") or 1
             )  # cohere charges per request by default.
             completion_tokens = search_units
+        elif call_type == CallTypes.arealtime.value and isinstance(
+            completion_response, LiteLLMRealtimeStreamLoggingObject
+        ):
+            if (
+                cost_per_token_usage_object is None
+                or custom_llm_provider is None
+            ):
+                raise ValueError(
+                    "usage object and custom_llm_provider must be provided for realtime stream cost calculation. Got cost_per_token_usage_object={}, custom_llm_provider={}".format(
+                        cost_per_token_usage_object,
+                        custom_llm_provider,
+                    )
+                )
+            return handle_realtime_stream_cost_calculation(
+                results=completion_response.results,
+                combined_usage_object=cost_per_token_usage_object,
+                custom_llm_provider=custom_llm_provider,
+                litellm_model_name=model,
+            )
         # Calculate cost based on prompt_tokens, completion_tokens
         if (
             "togethercomputer" in model
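In plain terms, completion_cost now short-circuits for realtime calls: when the call type is arealtime and the response is the new LiteLLMRealtimeStreamLoggingObject, pricing is delegated to handle_realtime_stream_cost_calculation with the litellm model name (so a configured base_model can be resolved), and missing usage or provider information raises instead of silently producing a wrong cost. Below is a minimal self-contained sketch of that control flow; the types and the pricing helper are simplified stand-ins, not litellm's actual signatures.

from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class RealtimeUsage:  # stand-in for litellm's combined Usage object
    prompt_tokens: int = 0
    completion_tokens: int = 0


@dataclass
class RealtimeStreamResult:  # stand-in for LiteLLMRealtimeStreamLoggingObject
    results: List[dict] = field(default_factory=list)


def price_realtime_stream(
    usage: RealtimeUsage, custom_llm_provider: str, litellm_model_name: str
) -> float:
    # Stand-in for handle_realtime_stream_cost_calculation: the real helper
    # resolves per-token rates from the cost map entry for the model (or its
    # configured base_model); flat example rates are used here.
    input_rate, output_rate = 5e-06, 2e-05
    return usage.prompt_tokens * input_rate + usage.completion_tokens * output_rate


def completion_cost_sketch(
    response: object,
    call_type: str,
    usage: Optional[RealtimeUsage],
    custom_llm_provider: Optional[str],
    model: str,
) -> float:
    if call_type == "arealtime" and isinstance(response, RealtimeStreamResult):
        if usage is None or custom_llm_provider is None:
            # same guard as the diff: refuse to guess rather than mis-bill
            raise ValueError("usage object and custom_llm_provider are required")
        return price_realtime_stream(usage, custom_llm_provider, model)
    raise NotImplementedError("generic token-based pricing path elided")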
@@ -921,6 +942,7 @@ def response_cost_calculator(
         HttpxBinaryResponseContent,
         RerankResponse,
         ResponsesAPIResponse,
+        LiteLLMRealtimeStreamLoggingObject,
     ],
     model: str,
     custom_llm_provider: Optional[str],
@@ -1274,6 +1296,15 @@ class RealtimeAPITokenUsageProcessor:
         )
         return combined_usage_object

+    @staticmethod
+    def create_logging_realtime_object(
+        usage: Usage, results: OpenAIRealtimeStreamList
+    ) -> LiteLLMRealtimeStreamLoggingObject:
+        return LiteLLMRealtimeStreamLoggingObject(
+            usage=usage,
+            results=results,
+        )
+

 def handle_realtime_stream_cost_calculation(
     results: OpenAIRealtimeStreamList,
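Read together, the processor now has a collect step and a wrap step, and the two are meant to be chained. Both calls appear verbatim in the litellm_logging.py hunks below; schematically, with results being the OpenAIRealtimeStreamList accumulated over the websocket session:

combined_usage_object = RealtimeAPITokenUsageProcessor.collect_and_combine_usage_from_realtime_stream_results(
    results=results
)
logging_result = RealtimeAPITokenUsageProcessor.create_logging_realtime_object(
    usage=combined_usage_object,
    results=results,
)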
litellm/litellm_core_utils/litellm_logging.py

@@ -35,7 +35,6 @@ from litellm.constants import (
 from litellm.cost_calculator import (
     RealtimeAPITokenUsageProcessor,
     _select_model_name_for_cost_calc,
-    handle_realtime_stream_cost_calculation,
 )
 from litellm.integrations.arize.arize import ArizeLogger
 from litellm.integrations.custom_guardrail import CustomGuardrail
@@ -68,6 +67,7 @@ from litellm.types.utils import (
     ImageResponse,
     LiteLLMBatch,
     LiteLLMLoggingBaseClass,
+    LiteLLMRealtimeStreamLoggingObject,
     ModelResponse,
     ModelResponseStream,
     RawRequestTypedDict,
@@ -902,6 +902,7 @@ class Logging(LiteLLMLoggingBaseClass):
             FineTuningJob,
             ResponsesAPIResponse,
             ResponseCompletedEvent,
+            LiteLLMRealtimeStreamLoggingObject,
         ],
         cache_hit: Optional[bool] = None,
         litellm_model_name: Optional[str] = None,
@@ -1055,39 +1056,49 @@ class Logging(LiteLLMLoggingBaseClass):
         ## if model in model cost map - log the response cost
         ## else set cost to None

+        logging_result = result
+
         if self.call_type == CallTypes.arealtime.value and isinstance(result, list):
             combined_usage_object = RealtimeAPITokenUsageProcessor.collect_and_combine_usage_from_realtime_stream_results(
                 results=result
             )
-            self.model_call_details[
-                "response_cost"
-            ] = handle_realtime_stream_cost_calculation(
-                results=result,
-                combined_usage_object=combined_usage_object,
-                custom_llm_provider=self.custom_llm_provider,
-                litellm_model_name=self.model,
+            logging_result = (
+                RealtimeAPITokenUsageProcessor.create_logging_realtime_object(
+                    usage=combined_usage_object,
+                    results=result,
+                )
             )
             self.model_call_details["combined_usage_object"] = combined_usage_object
+
+            # self.model_call_details[
+            #     "response_cost"
+            # ] = handle_realtime_stream_cost_calculation(
+            #     results=result,
+            #     combined_usage_object=combined_usage_object,
+            #     custom_llm_provider=self.custom_llm_provider,
+            #     litellm_model_name=self.model,
+            # )
+            # self.model_call_details["combined_usage_object"] = combined_usage_object
         if (
             standard_logging_object is None
             and result is not None
            and self.stream is not True
         ):
             if (
-                isinstance(result, ModelResponse)
-                or isinstance(result, ModelResponseStream)
-                or isinstance(result, EmbeddingResponse)
-                or isinstance(result, ImageResponse)
-                or isinstance(result, TranscriptionResponse)
-                or isinstance(result, TextCompletionResponse)
-                or isinstance(result, HttpxBinaryResponseContent)  # tts
-                or isinstance(result, RerankResponse)
-                or isinstance(result, FineTuningJob)
-                or isinstance(result, LiteLLMBatch)
-                or isinstance(result, ResponsesAPIResponse)
+                isinstance(logging_result, ModelResponse)
+                or isinstance(logging_result, ModelResponseStream)
+                or isinstance(logging_result, EmbeddingResponse)
+                or isinstance(logging_result, ImageResponse)
+                or isinstance(logging_result, TranscriptionResponse)
+                or isinstance(logging_result, TextCompletionResponse)
+                or isinstance(logging_result, HttpxBinaryResponseContent)  # tts
+                or isinstance(logging_result, RerankResponse)
+                or isinstance(logging_result, FineTuningJob)
+                or isinstance(logging_result, LiteLLMBatch)
+                or isinstance(logging_result, ResponsesAPIResponse)
+                or isinstance(logging_result, LiteLLMRealtimeStreamLoggingObject)
             ):
                 ## HIDDEN PARAMS ##
-                hidden_params = getattr(result, "_hidden_params", {})
+                hidden_params = getattr(logging_result, "_hidden_params", {})
                 if hidden_params:
                     # add to metadata for logging
                     if self.model_call_details.get("litellm_params") is not None:
@@ -1105,7 +1116,7 @@ class Logging(LiteLLMLoggingBaseClass):
                        self.model_call_details["litellm_params"]["metadata"][  # type: ignore
                            "hidden_params"
                        ] = getattr(
-                           result, "_hidden_params", {}
+                           logging_result, "_hidden_params", {}
                        )
                ## RESPONSE COST - Only calculate if not in hidden_params ##
                if "response_cost" in hidden_params:
@@ -1115,14 +1126,14 @@ class Logging(LiteLLMLoggingBaseClass):
                else:
                    self.model_call_details[
                        "response_cost"
-                   ] = self._response_cost_calculator(result=result)
+                   ] = self._response_cost_calculator(result=logging_result)
                ## STANDARDIZED LOGGING PAYLOAD

                self.model_call_details[
                    "standard_logging_object"
                ] = get_standard_logging_object_payload(
                    kwargs=self.model_call_details,
-                   init_response_obj=result,
+                   init_response_obj=logging_result,
                    start_time=start_time,
                    end_time=end_time,
                    logging_obj=self,
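The net effect of these logging changes is a normalize-then-log pattern: a realtime call's raw result (a list of stream events) is wrapped exactly once into the typed logging object, and everything downstream (hidden-params extraction, response-cost calculation, and the standard logging payload that feeds spend logs) operates on logging_result instead of the raw result. A minimal sketch of the pattern, with hypothetical stand-ins for litellm's helpers:

from typing import Any, List, Union


class RealtimeLogObject:  # stand-in for LiteLLMRealtimeStreamLoggingObject
    def __init__(self, results: List[dict]):
        self.results = results
        self._hidden_params: dict = {}


def cost_of(obj: Any) -> float:  # stand-in for _response_cost_calculator
    return 0.0


def on_success(result: Union[List[dict], Any], is_realtime: bool) -> dict:
    logging_result = result
    if is_realtime and isinstance(result, list):
        logging_result = RealtimeLogObject(result)  # normalize once
    # ...every downstream consumer now sees a single typed object:
    hidden_params = getattr(logging_result, "_hidden_params", {})
    return {
        "hidden_params": hidden_params,
        "response_cost": cost_of(logging_result),
        "logged_response": logging_result,  # what spend logs can now show
    }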
proxy config (YAML)

@@ -8,10 +8,6 @@ model_list:
     litellm_params:
       model: gpt-4o-mini
       api_key: os.environ/OPENAI_API_KEY
-  - model_name: "openai/*"
-    litellm_params:
-      model: openai/*
-      api_key: os.environ/OPENAI_API_KEY
   - model_name: "bedrock-nova"
     litellm_params:
       model: us.amazon.nova-pro-v1:0
@@ -29,14 +25,13 @@ model_list:
       model: databricks/databricks-claude-3-7-sonnet
       api_key: os.environ/DATABRICKS_API_KEY
       api_base: os.environ/DATABRICKS_API_BASE
-  - model_name: "llmaas-meta/llama-3.1-8b-instruct"
+  - model_name: "gpt-4o-realtime-preview"
     litellm_params:
-      model: nvidia_nim/meta/llama-3.3-70b-instruct
-      api_key: "invalid"
-      api_base: "http://0.0.0.0:8090"
+      model: azure/gpt-4o-realtime-preview-2
+      api_key: os.environ/AZURE_API_KEY_REALTIME
+      api_base: https://krris-m2f9a9i7-eastus2.openai.azure.com/
+    model_info:
+      input_cost_per_token: "100"
+      output_cost_per_token: "100"
+      base_model: azure/gpt-4o-realtime-preview-2024-10-01

 litellm_settings:
   num_retries: 0
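The model_info block is what this PR exercises: base_model points cost tracking at an entry that exists in litellm's model cost map when the deployment's own name (azure/gpt-4o-realtime-preview-2 here) does not, and input_cost_per_token / output_cost_per_token can override the mapped rates outright (the deliberately absurd "100" makes cost changes easy to spot in spend logs while testing). A rough illustration of that lookup order; the cost map and resolver below are hypothetical, not litellm internals:

# Hypothetical illustration of the pricing lookup order implied by the config.
COST_MAP = {
    "azure/gpt-4o-realtime-preview-2024-10-01": {
        "input_cost_per_token": 5e-06,
        "output_cost_per_token": 2e-05,
    }
}


def resolve_rates(model: str, model_info: dict) -> dict:
    if "input_cost_per_token" in model_info:  # explicit overrides win
        return {
            "input_cost_per_token": float(model_info["input_cost_per_token"]),
            "output_cost_per_token": float(model_info["output_cost_per_token"]),
        }
    # else try the deployment's own name, then fall back to its base_model
    entry = COST_MAP.get(model) or COST_MAP.get(model_info.get("base_model", ""))
    if entry is None:
        raise KeyError(f"no pricing entry for {model!r} or its base_model")
    return entry


rates = resolve_rates(
    "azure/gpt-4o-realtime-preview-2",
    {"base_model": "azure/gpt-4o-realtime-preview-2024-10-01"},
)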
litellm/realtime_api/main.py

@@ -7,6 +7,7 @@ from litellm import get_llm_provider
 from litellm.secret_managers.main import get_secret_str
 from litellm.types.router import GenericLiteLLMParams

+from ..litellm_core_utils.get_litellm_params import get_litellm_params
 from ..litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
 from ..llms.azure.realtime.handler import AzureOpenAIRealtime
 from ..llms.openai.realtime.handler import OpenAIRealtime
@@ -34,13 +35,11 @@ async def _arealtime(
     For PROXY use only.
     """
     litellm_logging_obj: LiteLLMLogging = kwargs.get("litellm_logging_obj")  # type: ignore
     litellm_call_id: Optional[str] = kwargs.get("litellm_call_id", None)
-    proxy_server_request = kwargs.get("proxy_server_request", None)
-    model_info = kwargs.get("model_info", None)
-    metadata = kwargs.get("metadata", {})
     user = kwargs.get("user", None)
     litellm_params = GenericLiteLLMParams(**kwargs)

+    litellm_params_dict = get_litellm_params(**kwargs)

     model, _custom_llm_provider, dynamic_api_key, dynamic_api_base = get_llm_provider(
         model=model,
         api_base=api_base,
@@ -51,14 +50,7 @@ async def _arealtime(
         model=model,
         user=user,
         optional_params={},
-        litellm_params={
-            "litellm_call_id": litellm_call_id,
-            "proxy_server_request": proxy_server_request,
-            "model_info": model_info,
-            "metadata": metadata,
-            "preset_cache_key": None,
-            "stream_response": {},
-        },
+        litellm_params=litellm_params_dict,
         custom_llm_provider=_custom_llm_provider,
     )
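This refactor is what lets base_model reach the realtime cost path: instead of hand-picking a few kwargs into an ad-hoc litellm_params dict, _arealtime now builds the dict with the shared get_litellm_params helper, so fields the cost tracker relies on (model_info, and through it base_model) are no longer dropped at this call site. A schematic of the idea; the extractor below is a hypothetical stand-in, not litellm's actual helper:

def extract_logging_params(**kwargs) -> dict:  # stand-in for get_litellm_params
    # One shared place to enumerate logging/cost fields; the removed code
    # hand-picked exactly these keys at the call site and nothing else.
    keys = (
        "litellm_call_id",
        "proxy_server_request",
        "model_info",  # carries base_model, which realtime cost tracking needs
        "metadata",
        "preset_cache_key",
        "stream_response",
    )
    return {key: kwargs.get(key) for key in keys}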
litellm/router.py

@@ -2138,6 +2138,7 @@ class Router:
                 request_kwargs=kwargs,
             )

+        self._update_kwargs_with_deployment(deployment=deployment, kwargs=kwargs)
         data = deployment["litellm_params"].copy()
         for k, v in self.default_litellm_params.items():
             if (
litellm/types/utils.py

@@ -34,6 +34,7 @@ from .llms.openai import (
     ChatCompletionUsageBlock,
     FileSearchTool,
     OpenAIChatCompletionChunk,
+    OpenAIRealtimeStreamList,
     WebSearchOptions,
 )
 from .rerank import RerankResponse
@@ -2152,6 +2153,31 @@ class LiteLLMBatch(Batch):
         return self.dict()


+class LiteLLMRealtimeStreamLoggingObject(LiteLLMPydanticObjectBase):
+    results: OpenAIRealtimeStreamList
+    usage: Usage
+    _hidden_params: dict = {}
+
+    def __contains__(self, key):
+        # Define custom behavior for the 'in' operator
+        return hasattr(self, key)
+
+    def get(self, key, default=None):
+        # Custom .get() method to access attributes with a default value if the attribute doesn't exist
+        return getattr(self, key, default)
+
+    def __getitem__(self, key):
+        # Allow dictionary-style access to attributes
+        return getattr(self, key)
+
+    def json(self, **kwargs):  # type: ignore
+        try:
+            return self.model_dump()  # noqa
+        except Exception:
+            # if using pydantic v1
+            return self.dict()
+
+
 class RawRequestTypedDict(TypedDict, total=False):
     raw_request_api_base: Optional[str]
     raw_request_body: Optional[dict]
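The dict-style shims on LiteLLMRealtimeStreamLoggingObject exist because downstream logging code treats responses interchangeably as objects and as dicts. A small usage sketch, assuming litellm is importable and using illustrative field values:

from litellm.types.utils import LiteLLMRealtimeStreamLoggingObject, Usage

obj = LiteLLMRealtimeStreamLoggingObject(
    results=[],  # the raw realtime stream events would normally go here
    usage=Usage(prompt_tokens=10, completion_tokens=5, total_tokens=15),
)

assert "usage" in obj              # __contains__ -> hasattr check
assert obj.get("missing") is None  # .get() with a default, dict-style
assert obj["results"] == []        # __getitem__ -> attribute access
payload = obj.json()               # model_dump() on pydantic v2, dict() on v1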