Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-27 03:34:10 +00:00)
Realtime API: Support 'base_model' cost tracking + show response in spend logs (if enabled) (#9897)

* refactor(litellm_logging.py): refactor realtime cost tracking to use the same code path as the rest of the endpoints; ensures basic features like base model just work
* feat(realtime/): support 'base_model' cost tracking on the realtime API; fixes issue where base model was not working on realtime
* fix: fix ruff linting error
* test: fix test
Parent: 78879c68a9
Commit: 9f27e8363f

6 changed files with 102 additions and 46 deletions
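At a high level, the change routes realtime spend through the same pipeline as every other call type: per-event usage from the realtime stream is combined into a single Usage object, wrapped in the new LiteLLMRealtimeStreamLoggingObject, and priced by the shared cost calculator. A rough sketch of how the pieces introduced in this diff fit together — the wiring below is illustrative, not the exact call sites, and an empty event list stands in for a real websocket stream (the combine step may expect at least one usage event):

    from litellm.cost_calculator import (
        RealtimeAPITokenUsageProcessor,
        handle_realtime_stream_cost_calculation,
    )

    results: list = []  # stand-in for the OpenAIRealtimeStreamList of websocket events

    # 1. fold per-event usage blocks into one Usage object
    combined_usage_object = RealtimeAPITokenUsageProcessor.collect_and_combine_usage_from_realtime_stream_results(
        results=results
    )

    # 2. wrap usage + raw events so the logger can treat it like any response object
    logging_result = RealtimeAPITokenUsageProcessor.create_logging_realtime_object(
        usage=combined_usage_object,
        results=results,
    )

    # 3. price it; base_model resolution happens inside the shared calculator
    cost = handle_realtime_stream_cost_calculation(
        results=logging_result.results,
        combined_usage_object=combined_usage_object,
        custom_llm_provider="azure",
        litellm_model_name="gpt-4o-realtime-preview",
    )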
@@ -66,6 +66,7 @@ from litellm.types.llms.openai import (
 from litellm.types.rerank import RerankBilledUnits, RerankResponse
 from litellm.types.utils import (
     CallTypesLiteral,
+    LiteLLMRealtimeStreamLoggingObject,
     LlmProviders,
     LlmProvidersSet,
     ModelInfo,

@@ -617,6 +618,7 @@ def completion_cost(  # noqa: PLR0915
             completion_response=completion_response
         )
         rerank_billed_units: Optional[RerankBilledUnits] = None
+
         selected_model = _select_model_name_for_cost_calc(
             model=model,
             completion_response=completion_response,

@@ -792,6 +794,25 @@ def completion_cost(  # noqa: PLR0915
                 billed_units.get("search_units") or 1
             )  # cohere charges per request by default.
             completion_tokens = search_units
+        elif call_type == CallTypes.arealtime.value and isinstance(
+            completion_response, LiteLLMRealtimeStreamLoggingObject
+        ):
+            if (
+                cost_per_token_usage_object is None
+                or custom_llm_provider is None
+            ):
+                raise ValueError(
+                    "usage object and custom_llm_provider must be provided for realtime stream cost calculation. Got cost_per_token_usage_object={}, custom_llm_provider={}".format(
+                        cost_per_token_usage_object,
+                        custom_llm_provider,
+                    )
+                )
+            return handle_realtime_stream_cost_calculation(
+                results=completion_response.results,
+                combined_usage_object=cost_per_token_usage_object,
+                custom_llm_provider=custom_llm_provider,
+                litellm_model_name=model,
+            )
         # Calculate cost based on prompt_tokens, completion_tokens
         if (
             "togethercomputer" in model
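The new arealtime branch above fails fast when either the combined usage object or the provider is missing, then delegates pricing. A toy extraction of that guard, runnable on its own (the helper name here is hypothetical; in the source this logic is inline in completion_cost):

    from typing import Optional

    def require_realtime_inputs(  # hypothetical helper mirroring the inline guard
        cost_per_token_usage_object: Optional[object],
        custom_llm_provider: Optional[str],
    ) -> None:
        # both inputs are required before any realtime pricing can happen
        if cost_per_token_usage_object is None or custom_llm_provider is None:
            raise ValueError(
                "usage object and custom_llm_provider must be provided for realtime "
                "stream cost calculation. Got cost_per_token_usage_object={}, "
                "custom_llm_provider={}".format(
                    cost_per_token_usage_object, custom_llm_provider
                )
            )

    require_realtime_inputs(object(), "azure")  # passes silently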
@@ -921,6 +942,7 @@ def response_cost_calculator(
         HttpxBinaryResponseContent,
         RerankResponse,
         ResponsesAPIResponse,
+        LiteLLMRealtimeStreamLoggingObject,
     ],
     model: str,
     custom_llm_provider: Optional[str],
@@ -1274,6 +1296,15 @@ class RealtimeAPITokenUsageProcessor:
         )
         return combined_usage_object
 
+    @staticmethod
+    def create_logging_realtime_object(
+        usage: Usage, results: OpenAIRealtimeStreamList
+    ) -> LiteLLMRealtimeStreamLoggingObject:
+        return LiteLLMRealtimeStreamLoggingObject(
+            usage=usage,
+            results=results,
+        )
+
 
 def handle_realtime_stream_cost_calculation(
     results: OpenAIRealtimeStreamList,
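create_logging_realtime_object is a small factory: it pairs the combined Usage with the raw event list, which is what lets spend logs also show the response when that is enabled (per the commit title). A minimal usage sketch, assuming litellm is installed and accepting an empty event list purely for illustration:

    from litellm.cost_calculator import RealtimeAPITokenUsageProcessor
    from litellm.types.utils import Usage

    usage = Usage(prompt_tokens=0, completion_tokens=0, total_tokens=0)
    logging_obj = RealtimeAPITokenUsageProcessor.create_logging_realtime_object(
        usage=usage,
        results=[],  # stand-in for real realtime stream events
    )
    # the factory just pairs the two inputs on one typed object
    assert logging_obj.usage.total_tokens == 0 and logging_obj.results == []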
@@ -35,7 +35,6 @@ from litellm.constants import (
 from litellm.cost_calculator import (
     RealtimeAPITokenUsageProcessor,
     _select_model_name_for_cost_calc,
-    handle_realtime_stream_cost_calculation,
 )
 from litellm.integrations.arize.arize import ArizeLogger
 from litellm.integrations.custom_guardrail import CustomGuardrail
@@ -68,6 +67,7 @@ from litellm.types.utils import (
     ImageResponse,
     LiteLLMBatch,
     LiteLLMLoggingBaseClass,
+    LiteLLMRealtimeStreamLoggingObject,
     ModelResponse,
     ModelResponseStream,
     RawRequestTypedDict,
@@ -902,6 +902,7 @@ class Logging(LiteLLMLoggingBaseClass):
             FineTuningJob,
             ResponsesAPIResponse,
             ResponseCompletedEvent,
+            LiteLLMRealtimeStreamLoggingObject,
         ],
         cache_hit: Optional[bool] = None,
         litellm_model_name: Optional[str] = None,
@@ -1055,39 +1056,49 @@ class Logging(LiteLLMLoggingBaseClass):
         ## if model in model cost map - log the response cost
         ## else set cost to None
 
+        logging_result = result
+
         if self.call_type == CallTypes.arealtime.value and isinstance(result, list):
             combined_usage_object = RealtimeAPITokenUsageProcessor.collect_and_combine_usage_from_realtime_stream_results(
                 results=result
             )
-            self.model_call_details[
-                "response_cost"
-            ] = handle_realtime_stream_cost_calculation(
-                results=result,
-                combined_usage_object=combined_usage_object,
-                custom_llm_provider=self.custom_llm_provider,
-                litellm_model_name=self.model,
-            )
-            self.model_call_details["combined_usage_object"] = combined_usage_object
+            logging_result = (
+                RealtimeAPITokenUsageProcessor.create_logging_realtime_object(
+                    usage=combined_usage_object,
+                    results=result,
+                )
+            )
+
+            # self.model_call_details[
+            #     "response_cost"
+            # ] = handle_realtime_stream_cost_calculation(
+            #     results=result,
+            #     combined_usage_object=combined_usage_object,
+            #     custom_llm_provider=self.custom_llm_provider,
+            #     litellm_model_name=self.model,
+            # )
+            # self.model_call_details["combined_usage_object"] = combined_usage_object
         if (
             standard_logging_object is None
             and result is not None
             and self.stream is not True
         ):
             if (
-                isinstance(result, ModelResponse)
-                or isinstance(result, ModelResponseStream)
-                or isinstance(result, EmbeddingResponse)
-                or isinstance(result, ImageResponse)
-                or isinstance(result, TranscriptionResponse)
-                or isinstance(result, TextCompletionResponse)
-                or isinstance(result, HttpxBinaryResponseContent)  # tts
-                or isinstance(result, RerankResponse)
-                or isinstance(result, FineTuningJob)
-                or isinstance(result, LiteLLMBatch)
-                or isinstance(result, ResponsesAPIResponse)
+                isinstance(logging_result, ModelResponse)
+                or isinstance(logging_result, ModelResponseStream)
+                or isinstance(logging_result, EmbeddingResponse)
+                or isinstance(logging_result, ImageResponse)
+                or isinstance(logging_result, TranscriptionResponse)
+                or isinstance(logging_result, TextCompletionResponse)
+                or isinstance(logging_result, HttpxBinaryResponseContent)  # tts
+                or isinstance(logging_result, RerankResponse)
+                or isinstance(logging_result, FineTuningJob)
+                or isinstance(logging_result, LiteLLMBatch)
+                or isinstance(logging_result, ResponsesAPIResponse)
+                or isinstance(logging_result, LiteLLMRealtimeStreamLoggingObject)
             ):
                 ## HIDDEN PARAMS ##
-                hidden_params = getattr(result, "_hidden_params", {})
+                hidden_params = getattr(logging_result, "_hidden_params", {})
                 if hidden_params:
                     # add to metadata for logging
                     if self.model_call_details.get("litellm_params") is not None:
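The key move in this hunk is wrapping the raw event list before the isinstance dispatch: a plain list matches none of the response types, so realtime calls previously needed their own cost path; the wrapper slots into the existing chain instead. A short check of that behavior, assuming litellm is installed:

    from litellm.types.utils import LiteLLMRealtimeStreamLoggingObject, Usage

    logging_result = LiteLLMRealtimeStreamLoggingObject(
        usage=Usage(prompt_tokens=3, completion_tokens=5, total_tokens=8),
        results=[],
    )
    # a raw list would fail every isinstance arm; the wrapper passes the new one
    assert isinstance(logging_result, LiteLLMRealtimeStreamLoggingObject)
    # hidden params are then read uniformly, as for any other response object
    hidden_params = getattr(logging_result, "_hidden_params", {})
    assert hidden_params == {}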
@@ -1105,7 +1116,7 @@ class Logging(LiteLLMLoggingBaseClass):
                         self.model_call_details["litellm_params"]["metadata"][  # type: ignore
                             "hidden_params"
                         ] = getattr(
-                            result, "_hidden_params", {}
+                            logging_result, "_hidden_params", {}
                         )
                 ## RESPONSE COST - Only calculate if not in hidden_params ##
                 if "response_cost" in hidden_params:
@@ -1115,14 +1126,14 @@ class Logging(LiteLLMLoggingBaseClass):
                 else:
                     self.model_call_details[
                         "response_cost"
-                    ] = self._response_cost_calculator(result=result)
+                    ] = self._response_cost_calculator(result=logging_result)
                 ## STANDARDIZED LOGGING PAYLOAD
 
                 self.model_call_details[
                     "standard_logging_object"
                 ] = get_standard_logging_object_payload(
                     kwargs=self.model_call_details,
-                    init_response_obj=result,
+                    init_response_obj=logging_result,
                     start_time=start_time,
                     end_time=end_time,
                     logging_obj=self,
@@ -8,10 +8,6 @@ model_list:
     litellm_params:
       model: gpt-4o-mini
       api_key: os.environ/OPENAI_API_KEY
-  - model_name: "openai/*"
-    litellm_params:
-      model: openai/*
-      api_key: os.environ/OPENAI_API_KEY
   - model_name: "bedrock-nova"
     litellm_params:
       model: us.amazon.nova-pro-v1:0
@@ -29,14 +25,13 @@ model_list:
       model: databricks/databricks-claude-3-7-sonnet
       api_key: os.environ/DATABRICKS_API_KEY
       api_base: os.environ/DATABRICKS_API_BASE
-  - model_name: "llmaas-meta/llama-3.1-8b-instruct"
+  - model_name: "gpt-4o-realtime-preview"
     litellm_params:
-      model: nvidia_nim/meta/llama-3.3-70b-instruct
-      api_key: "invalid"
-      api_base: "http://0.0.0.0:8090"
+      model: azure/gpt-4o-realtime-preview-2
+      api_key: os.environ/AZURE_API_KEY_REALTIME
+      api_base: https://krris-m2f9a9i7-eastus2.openai.azure.com/
     model_info:
-      input_cost_per_token: "100"
-      output_cost_per_token: "100"
+      base_model: azure/gpt-4o-realtime-preview-2024-10-01
 
 litellm_settings:
   num_retries: 0
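The config change is what exercises the fix: azure/gpt-4o-realtime-preview-2 is a deployment name that is not in litellm's cost map, so model_info.base_model points pricing at an entry that is. A quick sanity check of the mapping (assumes a litellm version whose cost map includes this realtime entry):

    import litellm

    # the base_model value from the config above
    info = litellm.get_model_info(model="azure/gpt-4o-realtime-preview-2024-10-01")
    print(info["input_cost_per_token"], info["output_cost_per_token"])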
@@ -7,6 +7,7 @@ from litellm import get_llm_provider
 from litellm.secret_managers.main import get_secret_str
 from litellm.types.router import GenericLiteLLMParams
 
+from ..litellm_core_utils.get_litellm_params import get_litellm_params
 from ..litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
 from ..llms.azure.realtime.handler import AzureOpenAIRealtime
 from ..llms.openai.realtime.handler import OpenAIRealtime
@@ -34,13 +35,11 @@ async def _arealtime(
     For PROXY use only.
     """
     litellm_logging_obj: LiteLLMLogging = kwargs.get("litellm_logging_obj")  # type: ignore
-    litellm_call_id: Optional[str] = kwargs.get("litellm_call_id", None)
-    proxy_server_request = kwargs.get("proxy_server_request", None)
-    model_info = kwargs.get("model_info", None)
-    metadata = kwargs.get("metadata", {})
     user = kwargs.get("user", None)
     litellm_params = GenericLiteLLMParams(**kwargs)
 
+    litellm_params_dict = get_litellm_params(**kwargs)
+
     model, _custom_llm_provider, dynamic_api_key, dynamic_api_base = get_llm_provider(
         model=model,
         api_base=api_base,
@@ -51,14 +50,7 @@ async def _arealtime(
         model=model,
         user=user,
         optional_params={},
-        litellm_params={
-            "litellm_call_id": litellm_call_id,
-            "proxy_server_request": proxy_server_request,
-            "model_info": model_info,
-            "metadata": metadata,
-            "preset_cache_key": None,
-            "stream_response": {},
-        },
+        litellm_params=litellm_params_dict,
         custom_llm_provider=_custom_llm_provider,
     )
 
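The _arealtime cleanup drops the hand-assembled litellm_params dict in favor of the shared get_litellm_params helper, so fields like litellm_call_id and metadata stay consistent with the other endpoints. The shape of the consolidation, with a stand-in helper (the real one lives in litellm/litellm_core_utils/get_litellm_params.py and covers more fields):

    def get_params(**kwargs) -> dict:
        # stand-in: collect the common logging fields once instead of
        # re-listing them at every call site
        keys = ("litellm_call_id", "proxy_server_request", "model_info", "metadata")
        return {k: kwargs.get(k) for k in keys}

    params = get_params(litellm_call_id="abc-123", metadata={"user": "test-user"})
    assert params["model_info"] is None  # absent keys default to None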
@@ -2138,6 +2138,7 @@ class Router:
                 request_kwargs=kwargs,
             )
 
+            self._update_kwargs_with_deployment(deployment=deployment, kwargs=kwargs)
             data = deployment["litellm_params"].copy()
             for k, v in self.default_litellm_params.items():
                 if (
@@ -34,6 +34,7 @@ from .llms.openai import (
     ChatCompletionUsageBlock,
     FileSearchTool,
     OpenAIChatCompletionChunk,
+    OpenAIRealtimeStreamList,
     WebSearchOptions,
 )
 from .rerank import RerankResponse
@@ -2152,6 +2153,31 @@ class LiteLLMBatch(Batch):
             return self.dict()
 
 
+class LiteLLMRealtimeStreamLoggingObject(LiteLLMPydanticObjectBase):
+    results: OpenAIRealtimeStreamList
+    usage: Usage
+    _hidden_params: dict = {}
+
+    def __contains__(self, key):
+        # Define custom behavior for the 'in' operator
+        return hasattr(self, key)
+
+    def get(self, key, default=None):
+        # Custom .get() method to access attributes with a default value if the attribute doesn't exist
+        return getattr(self, key, default)
+
+    def __getitem__(self, key):
+        # Allow dictionary-style access to attributes
+        return getattr(self, key)
+
+    def json(self, **kwargs):  # type: ignore
+        try:
+            return self.model_dump()  # noqa
+        except Exception:
+            # if using pydantic v1
+            return self.dict()
+
+
 class RawRequestTypedDict(TypedDict, total=False):
     raw_request_api_base: Optional[str]
     raw_request_body: Optional[dict]
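The dunder methods on the new class let existing loggers that treat responses as dicts keep working against the realtime object. A short demonstration, assuming litellm is installed:

    from litellm.types.utils import LiteLLMRealtimeStreamLoggingObject, Usage

    obj = LiteLLMRealtimeStreamLoggingObject(
        usage=Usage(prompt_tokens=10, completion_tokens=5, total_tokens=15),
        results=[],
    )
    assert "usage" in obj              # __contains__ -> hasattr
    assert obj.get("missing") is None  # .get() with a default
    print(obj["usage"].total_tokens)   # __getitem__ -> getattr, prints 15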