feat(litellm_logging.py): cleanup payload + add response cost to logged payload

Krrish Dholakia, 2024-08-15 17:53:25 -07:00
parent 3ddeb3297d
commit f6dba82882
3 changed files with 51 additions and 15 deletions


@@ -35,6 +35,7 @@ from litellm.types.utils import (
     EmbeddingResponse,
     ImageResponse,
     ModelResponse,
+    StandardLoggingHiddenParams,
     StandardLoggingMetadata,
     StandardLoggingPayload,
     TextCompletionResponse,
@@ -564,14 +565,6 @@ class Logging:
             self.model_call_details["log_event_type"] = "successful_api_call"
             self.model_call_details["end_time"] = end_time
             self.model_call_details["cache_hit"] = cache_hit
-            self.model_call_details["standard_logging_object"] = (
-                get_standard_logging_object_payload(
-                    kwargs=self.model_call_details,
-                    init_response_obj=result,
-                    start_time=start_time,
-                    end_time=end_time,
-                )
-            )
             ## if model in model cost map - log the response cost
             ## else set cost to None
             if (
@@ -629,6 +622,16 @@ class Logging:
                         total_time=float_diff,
                     )

+            ## STANDARDIZED LOGGING PAYLOAD
+            self.model_call_details["standard_logging_object"] = (
+                get_standard_logging_object_payload(
+                    kwargs=self.model_call_details,
+                    init_response_obj=result,
+                    start_time=start_time,
+                    end_time=end_time,
+                )
+            )
+
             return start_time, end_time, result
         except Exception as e:
             raise Exception(f"[Non-Blocking] LiteLLM.Success_Call Error: {str(e)}")
@@ -2193,10 +2196,13 @@ def get_standard_logging_object_payload(
 ) -> Optional[StandardLoggingPayload]:
     if kwargs is None:
         kwargs = {}

+    hidden_params: Optional[dict] = None
+
     if init_response_obj is None:
         response_obj = {}
     elif isinstance(init_response_obj, BaseModel):
         response_obj = init_response_obj.model_dump()
+        hidden_params = getattr(init_response_obj, "_hidden_params", None)
     elif isinstance(init_response_obj, dict):
         response_obj = init_response_obj
@@ -2230,6 +2236,22 @@ def get_standard_logging_object_payload(
     if isinstance(completion_start_time, datetime.datetime):
         completion_start_time_float = completion_start_time.timestamp()

+    # clean up litellm hidden params
+    clean_hidden_params = StandardLoggingHiddenParams(
+        model_id=None,
+        cache_key=None,
+        api_base=None,
+        response_cost=None,
+        additional_headers=None,
+    )
+    if hidden_params is not None:
+        clean_hidden_params = StandardLoggingHiddenParams(
+            **{  # type: ignore
+                key: hidden_params[key]
+                for key in StandardLoggingHiddenParams.__annotations__.keys()
+                if key in hidden_params
+            }
+        )
     # clean up litellm metadata
     clean_metadata = StandardLoggingMetadata(
         user_api_key_hash=None,
@@ -2259,7 +2281,7 @@ def get_standard_logging_object_payload(
     if litellm.cache is not None:
         cache_key = litellm.cache.get_cache_key(**kwargs)
     else:
-        cache_key = "Cache OFF"
+        cache_key = None

     if cache_hit is True:
         import time
@@ -2274,11 +2296,9 @@ def get_standard_logging_object_payload(
         endTime=end_time_float,
         completionStartTime=completion_start_time_float,
         model=kwargs.get("model", "") or "",
-        user=metadata.get("user_api_key_user_id", "") or "",
-        team_id=metadata.get("user_api_key_team_id", "") or "",
         metadata=clean_metadata,
         cache_key=cache_key,
-        spend=kwargs.get("response_cost", 0),
+        response_cost=kwargs.get("response_cost", 0),
         total_tokens=usage.get("total_tokens", 0),
         prompt_tokens=usage.get("prompt_tokens", 0),
         completion_tokens=usage.get("completion_tokens", 0),
@@ -2291,6 +2311,7 @@ def get_standard_logging_object_payload(
         messages=kwargs.get("messages"),
         response=response_obj,
         model_parameters=kwargs.get("optional_params", None),
+        hidden_params=clean_hidden_params,
     )

     verbose_logger.debug(
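For illustration, the hidden-params cleanup above boils down to the following standalone sketch: filter an arbitrary _hidden_params dict down to only the keys declared on the StandardLoggingHiddenParams TypedDict. The sample hidden_params values below are made up.

# A minimal standalone sketch of the _hidden_params cleanup above: keep only
# the keys that StandardLoggingHiddenParams declares and drop everything else.
from typing import Optional, TypedDict


class StandardLoggingHiddenParams(TypedDict):
    model_id: Optional[str]
    cache_key: Optional[str]
    api_base: Optional[str]
    response_cost: Optional[str]
    additional_headers: Optional[dict]


hidden_params = {
    "model_id": "model-1234",                 # hypothetical value
    "api_base": "https://api.openai.com/v1",  # hypothetical value
    "internal_debug_field": "dropped",        # not in the TypedDict -> filtered out
}

clean_hidden_params = StandardLoggingHiddenParams(
    **{  # type: ignore
        key: hidden_params[key]
        for key in StandardLoggingHiddenParams.__annotations__.keys()
        if key in hidden_params
    }
)
print(clean_hidden_params)
# {'model_id': 'model-1234', 'api_base': 'https://api.openai.com/v1'}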


@@ -1218,3 +1218,11 @@ def test_standard_logging_payload():
         mock_client.call_args.kwargs["kwargs"]["standard_logging_object"]
     )
     json.loads(json_str_payload)
+
+    ## response cost
+    assert (
+        mock_client.call_args.kwargs["kwargs"]["standard_logging_object"][
+            "response_cost"
+        ]
+        > 0
+    )
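As a usage note, a success callback could read the logged cost the same way the test asserts on it. A hedged sketch follows; the kwargs shape is assumed from the mocked call above, and extract_response_cost is a hypothetical helper, not part of litellm.

from typing import Any, Dict


def extract_response_cost(kwargs: Dict[str, Any]) -> float:
    # standard_logging_object is attached to the callback kwargs by the commit above
    payload = kwargs.get("standard_logging_object") or {}
    return payload.get("response_cost", 0.0)


sample_kwargs = {"standard_logging_object": {"response_cost": 4.5e-05}}  # illustrative
assert extract_response_cost(sample_kwargs) > 0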


@@ -1184,10 +1184,18 @@ class StandardLoggingMetadata(TypedDict):
     requester_ip_address: Optional[str]


+class StandardLoggingHiddenParams(TypedDict):
+    model_id: Optional[str]
+    cache_key: Optional[str]
+    api_base: Optional[str]
+    response_cost: Optional[str]
+    additional_headers: Optional[dict]
+
+
 class StandardLoggingPayload(TypedDict):
     id: str
     call_type: str
-    spend: float
+    response_cost: float
     total_tokens: int
     prompt_tokens: int
     completion_tokens: int

@@ -1198,14 +1206,13 @@ class StandardLoggingPayload(TypedDict):
     model_id: Optional[str]
     model_group: Optional[str]
     api_base: str
-    user: str
     metadata: StandardLoggingMetadata
     cache_hit: Optional[bool]
     cache_key: Optional[str]
     request_tags: list
-    team_id: Optional[str]
     end_user: Optional[str]
     requester_ip_address: Optional[str]
     messages: Optional[Union[str, list, dict]]
     response: Optional[Union[str, list, dict]]
     model_parameters: dict
+    hidden_params: StandardLoggingHiddenParams
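
Putting the type changes together, here is a hedged sketch of how a consumer of the reshaped payload would read the renamed and relocated fields. The sample values are made up; the user_api_key_* metadata keys come from the removed top-level user / team_id assignments shown earlier in this commit.

# Illustrative only: reading the reshaped standard logging payload downstream.
payload = {
    "response_cost": 4.5e-05,  # previously `spend`
    "cache_key": None,         # previously the string "Cache OFF" when caching was disabled
    "metadata": {
        "user_api_key_user_id": "user-123",   # previously top-level `user`
        "user_api_key_team_id": "team-456",   # previously top-level `team_id`
    },
    "hidden_params": {
        "model_id": None,
        "cache_key": None,
        "api_base": None,
        "response_cost": None,
        "additional_headers": None,
    },
}

cost = payload["response_cost"]
user_id = payload["metadata"].get("user_api_key_user_id")
team_id = payload["metadata"].get("user_api_key_team_id")
print(cost, user_id, team_id)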