From cda50e5d47e7f39b55e49e74766d90091b48ace1 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Thu, 15 Aug 2024 17:09:02 -0700 Subject: [PATCH] fix(s3.py): fix s3 logging payload to have valid json values Previously pydantic objects were being stringified, making them unparsable --- litellm/integrations/s3.py | 30 ++-- litellm/litellm_core_utils/litellm_logging.py | 131 ++++++++++++++++++ litellm/proxy/_new_secret_config.yaml | 10 +- litellm/tests/test_custom_callback_input.py | 52 +++++++ litellm/types/utils.py | 44 ++++++ 5 files changed, 244 insertions(+), 23 deletions(-) diff --git a/litellm/integrations/s3.py b/litellm/integrations/s3.py index 6e8c4a4e4..c440be5f1 100644 --- a/litellm/integrations/s3.py +++ b/litellm/integrations/s3.py @@ -7,9 +7,11 @@ import subprocess import sys import traceback import uuid +from typing import Optional import litellm from litellm._logging import print_verbose, verbose_logger +from litellm.types.utils import StandardLoggingPayload class S3Logger: @@ -123,29 +125,13 @@ class S3Logger: else: clean_metadata[key] = value - # Build the initial payload - payload = { - "id": id, - "call_type": call_type, - "cache_hit": cache_hit, - "startTime": start_time, - "endTime": end_time, - "model": kwargs.get("model", ""), - "user": kwargs.get("user", ""), - "modelParameters": optional_params, - "messages": messages, - "response": response_obj, - "usage": usage, - "metadata": clean_metadata, - } - # Ensure everything in the payload is converted to str - for key, value in payload.items(): - try: - payload[key] = str(value) - except: - # non blocking if it can't cast to a str - pass + payload: Optional[StandardLoggingPayload] = kwargs.get( + "standard_logging_object", None + ) + + if payload is None: + return s3_file_name = litellm.utils.get_logging_id(start_time, payload) or "" s3_object_key = ( diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py index 9f84b26d6..db37fe450 100644 --- 
def _to_epoch_seconds(value: Any, fallback: float) -> float:
    """Return ``value`` as POSIX seconds, or ``fallback`` if it is not a time.

    ``datetime`` objects are converted with ``timestamp()``; real numbers are
    assumed to already be epoch seconds and are passed through as ``float``.
    Anything else (including ``None``) yields ``fallback``.
    """
    if isinstance(value, dt_object):
        return value.timestamp()
    # bool is an int subclass; never treat True/False as a timestamp.
    if isinstance(value, (int, float)) and not isinstance(value, bool):
        return float(value)
    return fallback


def get_standard_logging_object_payload(
    kwargs: dict, init_response_obj: Any, start_time: dt_object, end_time: dt_object
) -> Optional[StandardLoggingPayload]:
    """Build the standardized, JSON-friendly logging payload for one LLM call.

    Intended to be shared across logging integrations (s3, dynamoDB,
    langfuse, ...) so they all log the same set of fields.

    Args:
        kwargs: The call details dict. Keys read here: ``litellm_params``
            (and its ``metadata``, ``proxy_server_request``, ``api_base``),
            ``completion_start_time``, ``call_type``, ``cache_hit``,
            ``litellm_call_id``, ``model``, ``response_cost``, ``messages``
            and ``optional_params``. ``None`` is treated as an empty dict.
        init_response_obj: The raw call result. A pydantic ``BaseModel`` is
            dumped to a dict, a dict is used as-is, and anything else
            (``None``, plain strings, ...) is logged as an empty response.
        start_time: When the call started.
        end_time: When the call finished. Also used as the fallback for a
            missing ``completion_start_time``.

    Returns:
        The populated ``StandardLoggingPayload``, or ``None`` if building it
        failed. Failures are logged as a warning and never raised, so
        logging cannot break the request that is being logged.
    """
    try:
        kwargs = kwargs or {}

        if isinstance(init_response_obj, BaseModel):
            response_obj = init_response_obj.model_dump()
        elif isinstance(init_response_obj, dict):
            response_obj = init_response_obj
        else:
            # None, or a type we cannot introspect: log an empty response
            # rather than failing on an unbound variable.
            response_obj = {}

        # Every nested lookup below tolerates an explicit None value.
        litellm_params = kwargs.get("litellm_params") or {}
        proxy_server_request = litellm_params.get("proxy_server_request") or {}
        end_user_id = (proxy_server_request.get("body") or {}).get("user")
        metadata = litellm_params.get("metadata") or {}
        if not isinstance(metadata, dict):
            metadata = {}

        call_type = kwargs.get("call_type")
        cache_hit = kwargs.get("cache_hit", False)

        usage = response_obj.get("usage") or {}
        if isinstance(usage, litellm.Usage):
            usage = dict(usage)
        if not isinstance(usage, dict):
            usage = {}

        request_id = response_obj.get("id", kwargs.get("litellm_call_id"))

        api_key = metadata.get("user_api_key") or ""
        key_was_redacted = isinstance(api_key, str) and api_key.startswith("sk-")
        if key_was_redacted:
            # redact the api key
            api_key = "REDACTED-BY-LITELLM--contains-sk-keyword"

        model_id = (metadata.get("model_info") or {}).get("id", "")
        model_group = metadata.get("model_group", "")

        tags = metadata.get("tags", [])
        request_tags = json.dumps(tags) if isinstance(tags, list) else "[]"

        # cleanup timestamps: always end up with floats, even when a caller
        # hands over something other than a datetime.
        end_time_float = _to_epoch_seconds(end_time, time.time())
        start_time_float = _to_epoch_seconds(start_time, end_time_float)
        completion_start_time_float = _to_epoch_seconds(
            kwargs.get("completion_start_time"), end_time_float
        )

        # clean up litellm metadata: start from all-None defaults so every
        # declared key is present, then copy over only the declared keys.
        clean_metadata = StandardLoggingMetadata(
            user_api_key=None,
            user_api_key_alias=None,
            user_api_key_team_id=None,
            user_api_key_user_id=None,
            user_api_key_team_alias=None,
            spend_logs_metadata=None,
            requester_ip_address=None,
        )
        for key in StandardLoggingMetadata.__annotations__:
            if key in metadata:
                clean_metadata[key] = metadata[key]  # type: ignore
        if key_was_redacted:
            # Do not leak through the metadata copy what was redacted above.
            clean_metadata["user_api_key"] = api_key

        if litellm.cache is not None:
            cache_key = litellm.cache.get_cache_key(**kwargs)
        else:
            cache_key = "Cache OFF"
        if cache_hit is True:
            # do not duplicate the request id
            request_id = f"{request_id}_cache_hit{time.time()}"

        payload: StandardLoggingPayload = StandardLoggingPayload(
            id=str(request_id),
            call_type=call_type or "",
            api_key=str(api_key),
            cache_hit=cache_hit,
            startTime=start_time_float,
            endTime=end_time_float,
            completionStartTime=completion_start_time_float,
            model=kwargs.get("model", "") or "",
            user=metadata.get("user_api_key_user_id", "") or "",
            team_id=metadata.get("user_api_key_team_id", "") or "",
            metadata=clean_metadata,
            cache_key=cache_key,
            spend=kwargs.get("response_cost") or 0,
            total_tokens=usage.get("total_tokens") or 0,
            prompt_tokens=usage.get("prompt_tokens") or 0,
            completion_tokens=usage.get("completion_tokens") or 0,
            request_tags=request_tags,
            end_user=end_user_id or "",
            api_base=litellm_params.get("api_base") or "",
            model_group=model_group,
            model_id=model_id,
            requester_ip_address=clean_metadata.get("requester_ip_address", None),
            messages=kwargs.get("messages"),
            response=response_obj,
            model_parameters=kwargs.get("optional_params") or {},
        )

        verbose_logger.debug(
            "Standard Logging: created payload - payload: %s\n\n", payload
        )

        return payload
    except Exception as e:
        verbose_logger.warning(
            "Error creating standard logging object - {}\n{}".format(
                str(e), traceback.format_exc()
            )
        )
        return None
def test_standard_logging_payload():
    """
    Check that success callbacks receive a complete, JSON-serializable
    ``standard_logging_object`` in their kwargs.

    Motivation: provide a standard set of things that are logged to s3/gcs/future integrations across all llm calls
    """
    from litellm.types.utils import StandardLoggingPayload

    # sync completion
    customHandler = CompletionCustomHandler()
    litellm.callbacks = [customHandler]

    with patch.object(
        customHandler, "log_success_event", new=MagicMock()
    ) as mock_client:
        _ = litellm.completion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "Hey, how's it going?"}],
            mock_response="Going well!",
        )

        # NOTE(review): presumably the success callback fires off the calling
        # thread, hence the wait before inspecting the mock - confirm.
        time.sleep(2)
        mock_client.assert_called_once()

        # The handler is invoked with keyword arguments; the call details
        # dict is passed under the "kwargs" keyword.
        logged_kwargs = mock_client.call_args.kwargs["kwargs"]
        print(f"mock_client_post.call_args: {logged_kwargs.keys()}")

        assert "standard_logging_object" in logged_kwargs
        standard_payload = logged_kwargs["standard_logging_object"]
        assert standard_payload is not None

        print(standard_payload)

        # Every field declared on the TypedDict must be present.
        for expected_key in StandardLoggingPayload.__annotations__:
            assert expected_key in standard_payload

        ## json serializable
        json.loads(json.dumps(standard_payload))
class StandardLoggingMetadata(TypedDict):
    """
    Specific metadata k,v pairs logged to integration for easier cost tracking
    """

    # These annotation names double as the allow-list of request-metadata
    # keys that the standard logging payload copies over.
    user_api_key: Optional[str]
    user_api_key_alias: Optional[str]
    user_api_key_team_id: Optional[str]
    user_api_key_user_id: Optional[str]
    user_api_key_team_alias: Optional[str]
    spend_logs_metadata: Optional[
        dict
    ]  # special param to log k,v pairs to spendlogs for a call
    requester_ip_address: Optional[str]


class StandardLoggingPayload(TypedDict):
    """
    Standard set of fields logged for an LLM call across logging integrations.

    Values are meant to be JSON-serializable.
    """

    id: str  # response id, falling back to the litellm call id
    call_type: str
    api_key: str  # keys starting with "sk-" are replaced by a redaction marker
    spend: float  # taken from the call's "response_cost"
    total_tokens: int
    prompt_tokens: int
    completion_tokens: int
    startTime: float  # datetime.timestamp() of the call start
    endTime: float  # datetime.timestamp() of the call end
    completionStartTime: float  # defaults to the end time when not recorded
    model: str
    model_id: Optional[str]  # metadata["model_info"]["id"]
    model_group: Optional[str]  # metadata["model_group"]
    api_base: str
    user: str  # metadata["user_api_key_user_id"]
    metadata: StandardLoggingMetadata
    cache_hit: Optional[bool]
    cache_key: Optional[str]  # "Cache OFF" when no litellm cache is configured
    request_tags: str  # json str
    team_id: Optional[str]  # metadata["user_api_key_team_id"]
    end_user: Optional[str]  # "user" field of the proxied request body
    requester_ip_address: Optional[str]
    messages: Optional[Union[str, list, dict]]
    response: Optional[Union[str, list, dict]]
    model_parameters: dict  # the call's "optional_params"