Mirror of https://github.com/BerriAI/litellm.git
Realtime API Cost tracking (#9795)
* fix(proxy_server.py): log realtime calls to spendlogs. Fixes https://github.com/BerriAI/litellm/issues/8410
* feat(realtime/): OpenAI Realtime API cost tracking. Closes https://github.com/BerriAI/litellm/issues/8410
* test: add unit testing for coverage
* test: add more unit testing
* fix: handle edge cases
Parent: 1c63e3dad0
Commit: ae6bc8ac77
12 changed files with 401 additions and 39 deletions
@@ -16,7 +16,10 @@ from litellm.constants import (
 from litellm.litellm_core_utils.llm_cost_calc.tool_call_cost_tracking import (
     StandardBuiltInToolCostTracking,
 )
-from litellm.litellm_core_utils.llm_cost_calc.utils import _generic_cost_per_character
+from litellm.litellm_core_utils.llm_cost_calc.utils import (
+    _generic_cost_per_character,
+    generic_cost_per_token,
+)
 from litellm.llms.anthropic.cost_calculation import (
     cost_per_token as anthropic_cost_per_token,
 )

@@ -54,6 +57,9 @@ from litellm.llms.vertex_ai.image_generation.cost_calculator import (
 from litellm.responses.utils import ResponseAPILoggingUtils
 from litellm.types.llms.openai import (
     HttpxBinaryResponseContent,
+    OpenAIRealtimeStreamList,
+    OpenAIRealtimeStreamResponseBaseObject,
+    OpenAIRealtimeStreamSessionEvents,
     ResponseAPIUsage,
     ResponsesAPIResponse,
 )

@@ -1141,3 +1147,50 @@ def batch_cost_calculator(
     )  # batch cost is usually half of the regular token cost

     return total_prompt_cost, total_completion_cost
+
+
+def handle_realtime_stream_cost_calculation(
+    results: OpenAIRealtimeStreamList, custom_llm_provider: str, litellm_model_name: str
+) -> float:
+    """
+    Handles the cost calculation for realtime stream responses.
+
+    Pick the 'response.done' events. Calculate total cost across all 'response.done' events.
+
+    Args:
+        results: A list of OpenAIRealtimeStreamBaseObject objects
+    """
+    response_done_events: List[OpenAIRealtimeStreamResponseBaseObject] = cast(
+        List[OpenAIRealtimeStreamResponseBaseObject],
+        [result for result in results if result["type"] == "response.done"],
+    )
+    received_model = None
+    potential_model_names = []
+    for result in results:
+        if result["type"] == "session.created":
+            received_model = cast(OpenAIRealtimeStreamSessionEvents, result)["session"][
+                "model"
+            ]
+            potential_model_names.append(received_model)
+
+    potential_model_names.append(litellm_model_name)
+    input_cost_per_token = 0.0
+    output_cost_per_token = 0.0
+    for result in response_done_events:
+        usage_object = (
+            ResponseAPILoggingUtils._transform_response_api_usage_to_chat_usage(
+                result["response"].get("usage", {})
+            )
+        )
+
+        for model_name in potential_model_names:
+            _input_cost_per_token, _output_cost_per_token = generic_cost_per_token(
+                model=model_name,
+                usage=usage_object,
+                custom_llm_provider=custom_llm_provider,
+            )
+            input_cost_per_token += _input_cost_per_token
+            output_cost_per_token += _output_cost_per_token
+    total_cost = input_cost_per_token + output_cost_per_token
+
+    return total_cost
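A minimal usage sketch (not part of this commit) of the helper added above. The event dicts mirror the OpenAIRealtimeStreamSessionEvents / OpenAIRealtimeStreamResponseBaseObject shapes added in litellm/types/llms/openai.py further down; the model name and token counts are illustrative assumptions taken from the unit test in this PR.

# Sketch: compute spend from a captured list of realtime stream events.
# In the proxy, the event list is collected by RealTimeStreaming (see below).
from litellm.cost_calculator import handle_realtime_stream_cost_calculation

events = [
    {"type": "session.created", "session": {"model": "gpt-3.5-turbo"}},
    {
        "type": "response.done",
        "response": {
            "usage": {"input_tokens": 100, "output_tokens": 50, "total_tokens": 150}
        },
    },
]

cost = handle_realtime_stream_cost_calculation(
    results=events,
    custom_llm_provider="openai",
    litellm_model_name="gpt-3.5-turbo",
)
# 'cost' is the float the logging layer stores as "response_cost" for the realtime call.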
@@ -32,7 +32,10 @@ from litellm.constants import (
     DEFAULT_MOCK_RESPONSE_COMPLETION_TOKEN_COUNT,
     DEFAULT_MOCK_RESPONSE_PROMPT_TOKEN_COUNT,
 )
-from litellm.cost_calculator import _select_model_name_for_cost_calc
+from litellm.cost_calculator import (
+    _select_model_name_for_cost_calc,
+    handle_realtime_stream_cost_calculation,
+)
 from litellm.integrations.arize.arize import ArizeLogger
 from litellm.integrations.custom_guardrail import CustomGuardrail
 from litellm.integrations.custom_logger import CustomLogger

@@ -1049,6 +1052,13 @@ class Logging(LiteLLMLoggingBaseClass):
             result = self._handle_anthropic_messages_response_logging(result=result)
         ## if model in model cost map - log the response cost
         ## else set cost to None
+
+        if self.call_type == CallTypes.arealtime.value and isinstance(result, list):
+            self.model_call_details[
+                "response_cost"
+            ] = handle_realtime_stream_cost_calculation(
+                result, self.custom_llm_provider, self.model
+            )
         if (
             standard_logging_object is None
             and result is not None
@@ -30,6 +30,11 @@ import json
 from typing import Any, Dict, List, Optional, Union

 import litellm
+from litellm._logging import verbose_logger
+from litellm.types.llms.openai import (
+    OpenAIRealtimeStreamResponseBaseObject,
+    OpenAIRealtimeStreamSessionEvents,
+)

 from .litellm_logging import Logging as LiteLLMLogging

@@ -53,7 +58,12 @@ class RealTimeStreaming:
         self.websocket = websocket
         self.backend_ws = backend_ws
         self.logging_obj = logging_obj
-        self.messages: List = []
+        self.messages: List[
+            Union[
+                OpenAIRealtimeStreamResponseBaseObject,
+                OpenAIRealtimeStreamSessionEvents,
+            ]
+        ] = []
         self.input_message: Dict = {}

         _logged_real_time_event_types = litellm.logged_real_time_event_types

@@ -62,10 +72,14 @@ class RealTimeStreaming:
             _logged_real_time_event_types = DefaultLoggedRealTimeEventTypes
         self.logged_real_time_event_types = _logged_real_time_event_types

-    def _should_store_message(self, message: Union[str, bytes]) -> bool:
-        if isinstance(message, bytes):
-            message = message.decode("utf-8")
-        message_obj = json.loads(message)
+    def _should_store_message(
+        self,
+        message_obj: Union[
+            dict,
+            OpenAIRealtimeStreamSessionEvents,
+            OpenAIRealtimeStreamResponseBaseObject,
+        ],
+    ) -> bool:
         _msg_type = message_obj["type"]
         if self.logged_real_time_event_types == "*":
             return True

@@ -75,8 +89,22 @@ class RealTimeStreaming:

     def store_message(self, message: Union[str, bytes]):
         """Store message in list"""
-        if self._should_store_message(message):
-            self.messages.append(message)
+        if isinstance(message, bytes):
+            message = message.decode("utf-8")
+        message_obj = json.loads(message)
+        try:
+            if (
+                message_obj.get("type") == "session.created"
+                or message_obj.get("type") == "session.updated"
+            ):
+                message_obj = OpenAIRealtimeStreamSessionEvents(**message_obj)  # type: ignore
+            else:
+                message_obj = OpenAIRealtimeStreamResponseBaseObject(**message_obj)  # type: ignore
+        except Exception as e:
+            verbose_logger.debug(f"Error parsing message for logging: {e}")
+            raise e
+        if self._should_store_message(message_obj):
+            self.messages.append(message_obj)

     def store_input(self, message: dict):
         """Store input message"""
@@ -40,9 +40,4 @@ litellm_settings:

 files_settings:
   - custom_llm_provider: gemini
     api_key: os.environ/GEMINI_API_KEY
-
-
-general_settings:
-  disable_spend_logs: True
-  disable_error_logs: True
@@ -2,7 +2,7 @@ import asyncio
 import json
 import uuid
 from datetime import datetime
-from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union
+from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Tuple, Union

 import httpx
 from fastapi import HTTPException, Request, status

@@ -101,33 +101,22 @@ class ProxyBaseLLMRequestProcessing:
             verbose_proxy_logger.error(f"Error setting custom headers: {e}")
             return {}

-    async def base_process_llm_request(
+    async def common_processing_pre_call_logic(
         self,
         request: Request,
-        fastapi_response: Response,
-        user_api_key_dict: UserAPIKeyAuth,
-        route_type: Literal["acompletion", "aresponses"],
-        proxy_logging_obj: ProxyLogging,
         general_settings: dict,
+        user_api_key_dict: UserAPIKeyAuth,
+        proxy_logging_obj: ProxyLogging,
         proxy_config: ProxyConfig,
-        select_data_generator: Callable,
-        llm_router: Optional[Router] = None,
-        model: Optional[str] = None,
+        route_type: Literal["acompletion", "aresponses", "_arealtime"],
+        version: Optional[str] = None,
         user_model: Optional[str] = None,
         user_temperature: Optional[float] = None,
         user_request_timeout: Optional[float] = None,
         user_max_tokens: Optional[int] = None,
         user_api_base: Optional[str] = None,
-        version: Optional[str] = None,
-    ) -> Any:
-        """
-        Common request processing logic for both chat completions and responses API endpoints
-        """
-
-        verbose_proxy_logger.debug(
-            "Request received by LiteLLM:\n{}".format(json.dumps(self.data, indent=4)),
-        )
-
+        model: Optional[str] = None,
+    ) -> Tuple[dict, LiteLLMLoggingObj]:
         self.data = await add_litellm_data_to_request(
             data=self.data,
             request=request,

@@ -182,13 +171,57 @@ class ProxyBaseLLMRequestProcessing:
         self.data["litellm_logging_obj"] = logging_obj

+        return self.data, logging_obj
+
+    async def base_process_llm_request(
+        self,
+        request: Request,
+        fastapi_response: Response,
+        user_api_key_dict: UserAPIKeyAuth,
+        route_type: Literal["acompletion", "aresponses", "_arealtime"],
+        proxy_logging_obj: ProxyLogging,
+        general_settings: dict,
+        proxy_config: ProxyConfig,
+        select_data_generator: Callable,
+        llm_router: Optional[Router] = None,
+        model: Optional[str] = None,
+        user_model: Optional[str] = None,
+        user_temperature: Optional[float] = None,
+        user_request_timeout: Optional[float] = None,
+        user_max_tokens: Optional[int] = None,
+        user_api_base: Optional[str] = None,
+        version: Optional[str] = None,
+    ) -> Any:
+        """
+        Common request processing logic for both chat completions and responses API endpoints
+        """
+        verbose_proxy_logger.debug(
+            "Request received by LiteLLM:\n{}".format(json.dumps(self.data, indent=4)),
+        )
+
+        self.data, logging_obj = await self.common_processing_pre_call_logic(
+            request=request,
+            general_settings=general_settings,
+            proxy_logging_obj=proxy_logging_obj,
+            user_api_key_dict=user_api_key_dict,
+            version=version,
+            proxy_config=proxy_config,
+            user_model=user_model,
+            user_temperature=user_temperature,
+            user_request_timeout=user_request_timeout,
+            user_max_tokens=user_max_tokens,
+            user_api_base=user_api_base,
+            model=model,
+            route_type=route_type,
+        )
+
         tasks = []
         tasks.append(
             proxy_logging_obj.during_call_hook(
                 data=self.data,
                 user_api_key_dict=user_api_key_dict,
                 call_type=ProxyBaseLLMRequestProcessing._get_pre_call_type(
-                    route_type=route_type
+                    route_type=route_type  # type: ignore
                 ),
             )
         )
@@ -194,13 +194,15 @@ class _ProxyDBLogger(CustomLogger):
             error_msg = f"Error in tracking cost callback - {str(e)}\n Traceback:{traceback.format_exc()}"
             model = kwargs.get("model", "")
             metadata = kwargs.get("litellm_params", {}).get("metadata", {})
-            error_msg += f"\n Args to _PROXY_track_cost_callback\n model: {model}\n metadata: {metadata}\n"
+            call_type = kwargs.get("call_type", "")
+            error_msg += f"\n Args to _PROXY_track_cost_callback\n model: {model}\n metadata: {metadata}\n call_type: {call_type}\n"
             asyncio.create_task(
                 proxy_logging_obj.failed_tracking_alert(
                     error_message=error_msg,
                     failing_model=model,
                 )
             )

             verbose_proxy_logger.exception(
                 "Error in tracking cost callback - %s", str(e)
             )
@@ -191,6 +191,7 @@ def clean_headers(
     if litellm_key_header_name is not None:
         special_headers.append(litellm_key_header_name.lower())
     clean_headers = {}

     for header, value in headers.items():
         if header.lower() not in special_headers:
             clean_headers[header] = value
@@ -4261,8 +4261,47 @@ async def websocket_endpoint(
         "websocket": websocket,
     }

+    headers = dict(websocket.headers.items())  # Convert headers to dict first
+
+    request = Request(
+        scope={
+            "type": "http",
+            "headers": [(k.lower().encode(), v.encode()) for k, v in headers.items()],
+            "method": "POST",
+            "path": "/v1/realtime",
+        }
+    )
+
+    request._url = websocket.url
+
+    async def return_body():
+        return_string = f'{{"model": "{model}"}}'
+        # return string as bytes
+        return return_string.encode()
+
+    request.body = return_body  # type: ignore
+
     ### ROUTE THE REQUEST ###
+    base_llm_response_processor = ProxyBaseLLMRequestProcessing(data=data)
     try:
+        (
+            data,
+            litellm_logging_obj,
+        ) = await base_llm_response_processor.common_processing_pre_call_logic(
+            request=request,
+            general_settings=general_settings,
+            user_api_key_dict=user_api_key_dict,
+            version=version,
+            proxy_logging_obj=proxy_logging_obj,
+            proxy_config=proxy_config,
+            user_model=user_model,
+            user_temperature=user_temperature,
+            user_request_timeout=user_request_timeout,
+            user_max_tokens=user_max_tokens,
+            user_api_base=user_api_base,
+            model=model,
+            route_type="_arealtime",
+        )
         llm_call = await route_request(
             data=data,
             route_type="_arealtime",
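The hunk above wraps the websocket's metadata in a synthetic HTTP-style Request so the shared common_processing_pre_call_logic (and with it request enrichment and spend tracking) can be reused for realtime calls. A reduced standalone sketch of that pattern follows; the header value and model name are illustrative assumptions, not taken from the commit.

# Sketch: build a Starlette Request from a synthetic HTTP scope so helpers that
# expect an HTTP request can run for a websocket-originated call.
from starlette.requests import Request

headers = {"authorization": "Bearer sk-1234"}  # illustrative; the real code copies websocket.headers

request = Request(
    scope={
        "type": "http",
        "headers": [(k.lower().encode(), v.encode()) for k, v in headers.items()],
        "method": "POST",
        "path": "/v1/realtime",
    }
)

async def return_body() -> bytes:
    # Realtime calls carry no HTTP body, so a minimal JSON body with the model is faked.
    return b'{"model": "gpt-4o-realtime-preview"}'  # model name is an assumption

request.body = return_body  # type: ignore[method-assign]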
@@ -892,7 +892,17 @@ class BaseLiteLLMOpenAIResponseObject(BaseModel):


 class OutputTokensDetails(BaseLiteLLMOpenAIResponseObject):
-    reasoning_tokens: int
+    reasoning_tokens: Optional[int] = None
+
+    text_tokens: Optional[int] = None
+
+    model_config = {"extra": "allow"}
+
+
+class InputTokensDetails(BaseLiteLLMOpenAIResponseObject):
+    audio_tokens: Optional[int] = None
+    cached_tokens: Optional[int] = None
+    text_tokens: Optional[int] = None

     model_config = {"extra": "allow"}

@@ -901,10 +911,13 @@ class ResponseAPIUsage(BaseLiteLLMOpenAIResponseObject):
     input_tokens: int
     """The number of input tokens."""

+    input_tokens_details: Optional[InputTokensDetails] = None
+    """A detailed breakdown of the input tokens."""
+
     output_tokens: int
     """The number of output tokens."""

-    output_tokens_details: Optional[OutputTokensDetails]
+    output_tokens_details: Optional[OutputTokensDetails] = None
     """A detailed breakdown of the output tokens."""

     total_tokens: int

@@ -1173,3 +1186,20 @@ ResponsesAPIStreamingResponse = Annotated[


 REASONING_EFFORT = Literal["low", "medium", "high"]
+
+
+class OpenAIRealtimeStreamSessionEvents(TypedDict):
+    event_id: str
+    session: dict
+    type: Union[Literal["session.created"], Literal["session.updated"]]
+
+
+class OpenAIRealtimeStreamResponseBaseObject(TypedDict):
+    event_id: str
+    response: dict
+    type: str
+
+
+OpenAIRealtimeStreamList = List[
+    Union[OpenAIRealtimeStreamResponseBaseObject, OpenAIRealtimeStreamSessionEvents]
+]
tests/litellm/litellm_core_utils/test_realtime_streaming.py (new file, 65 lines)

import json
import os
import sys
from unittest.mock import MagicMock, patch

import pytest

sys.path.insert(
    0, os.path.abspath("../../..")
)  # Adds the parent directory to the system path

from litellm.litellm_core_utils.realtime_streaming import RealTimeStreaming
from litellm.types.llms.openai import (
    OpenAIRealtimeStreamResponseBaseObject,
    OpenAIRealtimeStreamSessionEvents,
)


def test_realtime_streaming_store_message():
    # Setup
    websocket = MagicMock()
    backend_ws = MagicMock()
    logging_obj = MagicMock()
    streaming = RealTimeStreaming(websocket, backend_ws, logging_obj)

    # Test 1: Session created event (string input)
    session_created_msg = json.dumps(
        {"type": "session.created", "session": {"id": "test-session"}}
    )
    streaming.store_message(session_created_msg)
    assert len(streaming.messages) == 1
    assert "session" in streaming.messages[0]
    assert streaming.messages[0]["type"] == "session.created"

    # Test 2: Response object (bytes input)
    response_msg = json.dumps(
        {
            "type": "response.create",
            "event_id": "test-event",
            "response": {"text": "test response"},
        }
    ).encode("utf-8")
    streaming.store_message(response_msg)
    assert len(streaming.messages) == 2
    assert "response" in streaming.messages[1]
    assert streaming.messages[1]["type"] == "response.create"

    # Test 3: Invalid message format
    invalid_msg = "invalid json"
    with pytest.raises(Exception):
        streaming.store_message(invalid_msg)

    # Test 4: Message type not in logged events
    streaming.logged_real_time_event_types = [
        "session.created"
    ]  # Only log session.created
    other_msg = json.dumps(
        {
            "type": "response.done",
            "event_id": "test-event",
            "response": {"text": "test response"},
        }
    )
    streaming.store_message(other_msg)
    assert len(streaming.messages) == 2  # Should not store the new message
@@ -13,7 +13,11 @@ from unittest.mock import MagicMock, patch
 from pydantic import BaseModel

 import litellm
-from litellm.cost_calculator import response_cost_calculator
+from litellm.cost_calculator import (
+    handle_realtime_stream_cost_calculation,
+    response_cost_calculator,
+)
+from litellm.types.llms.openai import OpenAIRealtimeStreamList
 from litellm.types.utils import ModelResponse, PromptTokensDetailsWrapper, Usage

@@ -71,3 +75,66 @@ def test_cost_calculator_with_usage():
     )

     assert result == expected_cost, f"Got {result}, Expected {expected_cost}"
+
+
+def test_handle_realtime_stream_cost_calculation():
+    # Setup test data
+    results: OpenAIRealtimeStreamList = [
+        {"type": "session.created", "session": {"model": "gpt-3.5-turbo"}},
+        {
+            "type": "response.done",
+            "response": {
+                "usage": {"input_tokens": 100, "output_tokens": 50, "total_tokens": 150}
+            },
+        },
+        {
+            "type": "response.done",
+            "response": {
+                "usage": {
+                    "input_tokens": 200,
+                    "output_tokens": 100,
+                    "total_tokens": 300,
+                }
+            },
+        },
+    ]
+
+    # Test with explicit model name
+    cost = handle_realtime_stream_cost_calculation(
+        results=results,
+        custom_llm_provider="openai",
+        litellm_model_name="gpt-3.5-turbo",
+    )
+
+    # Calculate expected cost
+    # gpt-3.5-turbo costs: $0.0015/1K tokens input, $0.002/1K tokens output
+    expected_cost = (300 * 0.0015 / 1000) + (  # input tokens (100 + 200)
+        150 * 0.002 / 1000
+    )  # output tokens (50 + 100)
+    assert (
+        abs(cost - expected_cost) <= 0.00075
+    )  # Allow small floating point differences
+
+    # Test with different model name in session
+    results[0]["session"]["model"] = "gpt-4"
+    cost = handle_realtime_stream_cost_calculation(
+        results=results,
+        custom_llm_provider="openai",
+        litellm_model_name="gpt-3.5-turbo",
+    )
+
+    # Calculate expected cost using gpt-4 rates
+    # gpt-4 costs: $0.03/1K tokens input, $0.06/1K tokens output
+    expected_cost = (300 * 0.03 / 1000) + (  # input tokens
+        150 * 0.06 / 1000
+    )  # output tokens
+    assert abs(cost - expected_cost) < 0.00076
+
+    # Test with no response.done events
+    results = [{"type": "session.created", "session": {"model": "gpt-3.5-turbo"}}]
+    cost = handle_realtime_stream_cost_calculation(
+        results=results,
+        custom_llm_provider="openai",
+        litellm_model_name="gpt-3.5-turbo",
+    )
+    assert cost == 0.0  # No usage, no cost
tests/logging_callback_tests/log.txt (new file, 39 lines)

============================= test session starts ==============================
platform darwin -- Python 3.11.4, pytest-7.4.1, pluggy-1.2.0 -- /Library/Frameworks/Python.framework/Versions/3.11/bin/python3
cachedir: .pytest_cache
rootdir: /Users/krrishdholakia/Documents/litellm/tests/logging_callback_tests
plugins: snapshot-0.9.0, cov-5.0.0, timeout-2.2.0, postgresql-7.0.1, respx-0.21.1, asyncio-0.21.1, langsmith-0.3.4, anyio-4.8.0, mock-3.11.1, Faker-25.9.2
asyncio: mode=Mode.STRICT
collecting ... collected 4 items

test_built_in_tools_cost_tracking.py::test_openai_responses_api_web_search_cost_tracking[tools_config0-search_context_size_low-True] PASSED [ 25%]
test_built_in_tools_cost_tracking.py::test_openai_responses_api_web_search_cost_tracking[tools_config1-search_context_size_low-False] PASSED [ 50%]
test_built_in_tools_cost_tracking.py::test_openai_responses_api_web_search_cost_tracking[tools_config2-search_context_size_medium-True] PASSED [ 75%]
test_built_in_tools_cost_tracking.py::test_openai_responses_api_web_search_cost_tracking[tools_config3-search_context_size_medium-False] PASSED [100%]

=============================== warnings summary ===============================
../../../../../../Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/pydantic/_internal/_config.py:295
  /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/pydantic/_internal/_config.py:295: PydanticDeprecatedSince20: Support for class-based `config` is deprecated, use ConfigDict instead. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.10/migration/
    warnings.warn(DEPRECATION_MESSAGE, DeprecationWarning)

../../litellm/litellm_core_utils/get_model_cost_map.py:24
test_built_in_tools_cost_tracking.py::test_openai_responses_api_web_search_cost_tracking[tools_config0-search_context_size_low-True]
test_built_in_tools_cost_tracking.py::test_openai_responses_api_web_search_cost_tracking[tools_config1-search_context_size_low-False]
test_built_in_tools_cost_tracking.py::test_openai_responses_api_web_search_cost_tracking[tools_config2-search_context_size_medium-True]
test_built_in_tools_cost_tracking.py::test_openai_responses_api_web_search_cost_tracking[tools_config3-search_context_size_medium-False]
  /Users/krrishdholakia/Documents/litellm/litellm/litellm_core_utils/get_model_cost_map.py:24: DeprecationWarning: open_text is deprecated. Use files() instead. Refer to https://importlib-resources.readthedocs.io/en/latest/using.html#migrating-from-legacy for migration advice.
    with importlib.resources.open_text(

../../litellm/utils.py:183
  /Users/krrishdholakia/Documents/litellm/litellm/utils.py:183: DeprecationWarning: open_text is deprecated. Use files() instead. Refer to https://importlib-resources.readthedocs.io/en/latest/using.html#migrating-from-legacy for migration advice.
    with resources.open_text(

test_built_in_tools_cost_tracking.py::test_openai_responses_api_web_search_cost_tracking[tools_config0-search_context_size_low-True]
test_built_in_tools_cost_tracking.py::test_openai_responses_api_web_search_cost_tracking[tools_config1-search_context_size_low-False]
test_built_in_tools_cost_tracking.py::test_openai_responses_api_web_search_cost_tracking[tools_config2-search_context_size_medium-True]
test_built_in_tools_cost_tracking.py::test_openai_responses_api_web_search_cost_tracking[tools_config3-search_context_size_medium-False]
  /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/httpx/_content.py:204: DeprecationWarning: Use 'content=<...>' to upload raw bytes/text content.
    warnings.warn(message, DeprecationWarning)

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
======================= 4 passed, 11 warnings in 18.95s ========================