forked from phoenix/litellm-mirror
fix(s3.py): fix s3 logging payload to have valid json values
Previously pydantic objects were being stringified, making them unparsable
This commit is contained in:
parent
eb6a0a32f1
commit
cda50e5d47
5 changed files with 244 additions and 23 deletions
|
@ -7,9 +7,11 @@ import subprocess
|
||||||
import sys
|
import sys
|
||||||
import traceback
|
import traceback
|
||||||
import uuid
|
import uuid
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
import litellm
|
import litellm
|
||||||
from litellm._logging import print_verbose, verbose_logger
|
from litellm._logging import print_verbose, verbose_logger
|
||||||
|
from litellm.types.utils import StandardLoggingPayload
|
||||||
|
|
||||||
|
|
||||||
class S3Logger:
|
class S3Logger:
|
||||||
|
@ -123,29 +125,13 @@ class S3Logger:
|
||||||
else:
|
else:
|
||||||
clean_metadata[key] = value
|
clean_metadata[key] = value
|
||||||
|
|
||||||
# Build the initial payload
|
|
||||||
payload = {
|
|
||||||
"id": id,
|
|
||||||
"call_type": call_type,
|
|
||||||
"cache_hit": cache_hit,
|
|
||||||
"startTime": start_time,
|
|
||||||
"endTime": end_time,
|
|
||||||
"model": kwargs.get("model", ""),
|
|
||||||
"user": kwargs.get("user", ""),
|
|
||||||
"modelParameters": optional_params,
|
|
||||||
"messages": messages,
|
|
||||||
"response": response_obj,
|
|
||||||
"usage": usage,
|
|
||||||
"metadata": clean_metadata,
|
|
||||||
}
|
|
||||||
|
|
||||||
# Ensure everything in the payload is converted to str
|
# Ensure everything in the payload is converted to str
|
||||||
for key, value in payload.items():
|
payload: Optional[StandardLoggingPayload] = kwargs.get(
|
||||||
try:
|
"standard_logging_object", None
|
||||||
payload[key] = str(value)
|
)
|
||||||
except:
|
|
||||||
# non blocking if it can't cast to a str
|
if payload is None:
|
||||||
pass
|
return
|
||||||
|
|
||||||
s3_file_name = litellm.utils.get_logging_id(start_time, payload) or ""
|
s3_file_name = litellm.utils.get_logging_id(start_time, payload) or ""
|
||||||
s3_object_key = (
|
s3_object_key = (
|
||||||
|
|
|
@ -10,6 +10,7 @@ import sys
|
||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
import uuid
|
import uuid
|
||||||
|
from datetime import datetime as dt_object
|
||||||
from typing import Any, Callable, Dict, List, Literal, Optional, Union
|
from typing import Any, Callable, Dict, List, Literal, Optional, Union
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
@ -33,6 +34,8 @@ from litellm.types.utils import (
|
||||||
EmbeddingResponse,
|
EmbeddingResponse,
|
||||||
ImageResponse,
|
ImageResponse,
|
||||||
ModelResponse,
|
ModelResponse,
|
||||||
|
StandardLoggingMetadata,
|
||||||
|
StandardLoggingPayload,
|
||||||
TextCompletionResponse,
|
TextCompletionResponse,
|
||||||
TranscriptionResponse,
|
TranscriptionResponse,
|
||||||
)
|
)
|
||||||
|
@ -560,6 +563,14 @@ class Logging:
|
||||||
self.model_call_details["log_event_type"] = "successful_api_call"
|
self.model_call_details["log_event_type"] = "successful_api_call"
|
||||||
self.model_call_details["end_time"] = end_time
|
self.model_call_details["end_time"] = end_time
|
||||||
self.model_call_details["cache_hit"] = cache_hit
|
self.model_call_details["cache_hit"] = cache_hit
|
||||||
|
self.model_call_details["standard_logging_object"] = (
|
||||||
|
get_standard_logging_object_payload(
|
||||||
|
kwargs=self.model_call_details,
|
||||||
|
init_response_obj=result,
|
||||||
|
start_time=start_time,
|
||||||
|
end_time=end_time,
|
||||||
|
)
|
||||||
|
)
|
||||||
## if model in model cost map - log the response cost
|
## if model in model cost map - log the response cost
|
||||||
## else set cost to None
|
## else set cost to None
|
||||||
if (
|
if (
|
||||||
|
@ -2166,3 +2177,123 @@ def use_custom_pricing_for_model(litellm_params: Optional[dict]) -> bool:
|
||||||
if k in SPECIAL_MODEL_INFO_PARAMS:
|
if k in SPECIAL_MODEL_INFO_PARAMS:
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def get_standard_logging_object_payload(
|
||||||
|
kwargs: dict, init_response_obj: Any, start_time: dt_object, end_time: dt_object
|
||||||
|
) -> Optional[StandardLoggingPayload]:
|
||||||
|
if kwargs is None:
|
||||||
|
kwargs = {}
|
||||||
|
if init_response_obj is None:
|
||||||
|
response_obj = {}
|
||||||
|
elif isinstance(init_response_obj, BaseModel):
|
||||||
|
response_obj = init_response_obj.model_dump()
|
||||||
|
elif isinstance(init_response_obj, dict):
|
||||||
|
response_obj = init_response_obj
|
||||||
|
|
||||||
|
# standardize this function to be used across, s3, dynamoDB, langfuse logging
|
||||||
|
litellm_params = kwargs.get("litellm_params", {})
|
||||||
|
proxy_server_request = litellm_params.get("proxy_server_request") or {}
|
||||||
|
end_user_id = proxy_server_request.get("body", {}).get("user", None)
|
||||||
|
metadata = (
|
||||||
|
litellm_params.get("metadata", {}) or {}
|
||||||
|
) # if litellm_params['metadata'] == None
|
||||||
|
completion_start_time = kwargs.get("completion_start_time", end_time)
|
||||||
|
call_type = kwargs.get("call_type")
|
||||||
|
cache_hit = kwargs.get("cache_hit", False)
|
||||||
|
usage = response_obj.get("usage", None) or {}
|
||||||
|
if type(usage) == litellm.Usage:
|
||||||
|
usage = dict(usage)
|
||||||
|
id = response_obj.get("id", kwargs.get("litellm_call_id"))
|
||||||
|
api_key = metadata.get("user_api_key", "")
|
||||||
|
if api_key is not None and isinstance(api_key, str) and api_key.startswith("sk-"):
|
||||||
|
# redact the api key
|
||||||
|
api_key = "REDACTED-BY-LITELLM--contains-sk-keyword"
|
||||||
|
|
||||||
|
_model_id = metadata.get("model_info", {}).get("id", "")
|
||||||
|
_model_group = metadata.get("model_group", "")
|
||||||
|
|
||||||
|
request_tags = (
|
||||||
|
json.dumps(metadata.get("tags", []))
|
||||||
|
if isinstance(metadata.get("tags", []), list)
|
||||||
|
else "[]"
|
||||||
|
)
|
||||||
|
|
||||||
|
# cleanup timestamps
|
||||||
|
if isinstance(start_time, datetime.datetime):
|
||||||
|
start_time_float = start_time.timestamp()
|
||||||
|
if isinstance(end_time, datetime.datetime):
|
||||||
|
end_time_float = end_time.timestamp()
|
||||||
|
if isinstance(completion_start_time, datetime.datetime):
|
||||||
|
completion_start_time_float = completion_start_time.timestamp()
|
||||||
|
|
||||||
|
# clean up litellm metadata
|
||||||
|
clean_metadata = StandardLoggingMetadata(
|
||||||
|
user_api_key=None,
|
||||||
|
user_api_key_alias=None,
|
||||||
|
user_api_key_team_id=None,
|
||||||
|
user_api_key_user_id=None,
|
||||||
|
user_api_key_team_alias=None,
|
||||||
|
spend_logs_metadata=None,
|
||||||
|
requester_ip_address=None,
|
||||||
|
)
|
||||||
|
if isinstance(metadata, dict):
|
||||||
|
# Filter the metadata dictionary to include only the specified keys
|
||||||
|
clean_metadata = StandardLoggingMetadata(
|
||||||
|
**{ # type: ignore
|
||||||
|
key: metadata[key]
|
||||||
|
for key in StandardLoggingMetadata.__annotations__.keys()
|
||||||
|
if key in metadata
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
if litellm.cache is not None:
|
||||||
|
cache_key = litellm.cache.get_cache_key(**kwargs)
|
||||||
|
else:
|
||||||
|
cache_key = "Cache OFF"
|
||||||
|
if cache_hit is True:
|
||||||
|
import time
|
||||||
|
|
||||||
|
id = f"{id}_cache_hit{time.time()}" # do not duplicate the request id
|
||||||
|
|
||||||
|
try:
|
||||||
|
payload: StandardLoggingPayload = StandardLoggingPayload(
|
||||||
|
id=str(id),
|
||||||
|
call_type=call_type or "",
|
||||||
|
api_key=str(api_key),
|
||||||
|
cache_hit=cache_hit,
|
||||||
|
startTime=start_time_float,
|
||||||
|
endTime=end_time_float,
|
||||||
|
completionStartTime=completion_start_time_float,
|
||||||
|
model=kwargs.get("model", "") or "",
|
||||||
|
user=metadata.get("user_api_key_user_id", "") or "",
|
||||||
|
team_id=metadata.get("user_api_key_team_id", "") or "",
|
||||||
|
metadata=clean_metadata,
|
||||||
|
cache_key=cache_key,
|
||||||
|
spend=kwargs.get("response_cost", 0),
|
||||||
|
total_tokens=usage.get("total_tokens", 0),
|
||||||
|
prompt_tokens=usage.get("prompt_tokens", 0),
|
||||||
|
completion_tokens=usage.get("completion_tokens", 0),
|
||||||
|
request_tags=request_tags,
|
||||||
|
end_user=end_user_id or "",
|
||||||
|
api_base=litellm_params.get("api_base", ""),
|
||||||
|
model_group=_model_group,
|
||||||
|
model_id=_model_id,
|
||||||
|
requester_ip_address=clean_metadata.get("requester_ip_address", None),
|
||||||
|
messages=kwargs.get("messages"),
|
||||||
|
response=response_obj,
|
||||||
|
model_parameters=kwargs.get("optional_params", None),
|
||||||
|
)
|
||||||
|
|
||||||
|
verbose_logger.debug(
|
||||||
|
"Standard Logging: created payload - payload: %s\n\n", payload
|
||||||
|
)
|
||||||
|
|
||||||
|
return payload
|
||||||
|
except Exception as e:
|
||||||
|
verbose_logger.warning(
|
||||||
|
"Error creating standard logging object - {}\n{}".format(
|
||||||
|
str(e), traceback.format_exc()
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
|
@ -4,3 +4,11 @@ model_list:
|
||||||
model: "gpt-4"
|
model: "gpt-4"
|
||||||
model_info:
|
model_info:
|
||||||
my_custom_key: "my_custom_value"
|
my_custom_key: "my_custom_value"
|
||||||
|
|
||||||
|
litellm_settings:
|
||||||
|
success_callback: ["s3"]
|
||||||
|
s3_callback_params:
|
||||||
|
s3_bucket_name: mytestbucketlitellm # AWS Bucket Name for S3
|
||||||
|
s3_region_name: us-west-2 # AWS Region Name for S3
|
||||||
|
s3_aws_access_key_id: os.environ/AWS_ACCESS_KEY_ID # us os.environ/<variable name> to pass environment variables. This is AWS Access Key ID for S3
|
||||||
|
s3_aws_secret_access_key: os.environ/AWS_SECRET_ACCESS_KEY # AWS Secret Access Key for S3
|
||||||
|
|
|
@ -1166,3 +1166,55 @@ def test_turn_off_message_logging():
|
||||||
|
|
||||||
time.sleep(2)
|
time.sleep(2)
|
||||||
assert len(customHandler.errors) == 0
|
assert len(customHandler.errors) == 0
|
||||||
|
|
||||||
|
|
||||||
|
##### VALID JSON ######
|
||||||
|
|
||||||
|
|
||||||
|
def test_standard_logging_payload():
|
||||||
|
"""
|
||||||
|
Ensure valid standard_logging_payload is passed for logging calls to s3
|
||||||
|
|
||||||
|
Motivation: provide a standard set of things that are logged to s3/gcs/future integrations across all llm calls
|
||||||
|
"""
|
||||||
|
from litellm.types.utils import StandardLoggingPayload
|
||||||
|
|
||||||
|
# sync completion
|
||||||
|
customHandler = CompletionCustomHandler()
|
||||||
|
litellm.callbacks = [customHandler]
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
customHandler, "log_success_event", new=MagicMock()
|
||||||
|
) as mock_client:
|
||||||
|
_ = litellm.completion(
|
||||||
|
model="gpt-3.5-turbo",
|
||||||
|
messages=[{"role": "user", "content": "Hey, how's it going?"}],
|
||||||
|
mock_response="Going well!",
|
||||||
|
)
|
||||||
|
|
||||||
|
time.sleep(2)
|
||||||
|
mock_client.assert_called_once()
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"mock_client_post.call_args: {mock_client.call_args.kwargs['kwargs'].keys()}"
|
||||||
|
)
|
||||||
|
assert "standard_logging_object" in mock_client.call_args.kwargs["kwargs"]
|
||||||
|
assert (
|
||||||
|
mock_client.call_args.kwargs["kwargs"]["standard_logging_object"]
|
||||||
|
is not None
|
||||||
|
)
|
||||||
|
|
||||||
|
print(mock_client.call_args.kwargs["kwargs"]["standard_logging_object"])
|
||||||
|
|
||||||
|
keys_list = list(StandardLoggingPayload.__annotations__.keys())
|
||||||
|
|
||||||
|
for k in keys_list:
|
||||||
|
assert (
|
||||||
|
k in mock_client.call_args.kwargs["kwargs"]["standard_logging_object"]
|
||||||
|
)
|
||||||
|
|
||||||
|
## json serializable
|
||||||
|
json_str_payload = json.dumps(
|
||||||
|
mock_client.call_args.kwargs["kwargs"]["standard_logging_object"]
|
||||||
|
)
|
||||||
|
json.loads(json_str_payload)
|
||||||
|
|
|
@ -1166,3 +1166,47 @@ class AdapterCompletionStreamWrapper:
|
||||||
raise StopIteration
|
raise StopIteration
|
||||||
except StopIteration:
|
except StopIteration:
|
||||||
raise StopAsyncIteration
|
raise StopAsyncIteration
|
||||||
|
|
||||||
|
|
||||||
|
class StandardLoggingMetadata(TypedDict):
|
||||||
|
"""
|
||||||
|
Specific metadata k,v pairs logged to integration for easier cost tracking
|
||||||
|
"""
|
||||||
|
|
||||||
|
user_api_key: Optional[str]
|
||||||
|
user_api_key_alias: Optional[str]
|
||||||
|
user_api_key_team_id: Optional[str]
|
||||||
|
user_api_key_user_id: Optional[str]
|
||||||
|
user_api_key_team_alias: Optional[str]
|
||||||
|
spend_logs_metadata: Optional[
|
||||||
|
dict
|
||||||
|
] # special param to log k,v pairs to spendlogs for a call
|
||||||
|
requester_ip_address: Optional[str]
|
||||||
|
|
||||||
|
|
||||||
|
class StandardLoggingPayload(TypedDict):
|
||||||
|
id: str
|
||||||
|
call_type: str
|
||||||
|
api_key: str
|
||||||
|
spend: float
|
||||||
|
total_tokens: int
|
||||||
|
prompt_tokens: int
|
||||||
|
completion_tokens: int
|
||||||
|
startTime: float
|
||||||
|
endTime: float
|
||||||
|
completionStartTime: float
|
||||||
|
model: str
|
||||||
|
model_id: Optional[str]
|
||||||
|
model_group: Optional[str]
|
||||||
|
api_base: str
|
||||||
|
user: str
|
||||||
|
metadata: StandardLoggingMetadata
|
||||||
|
cache_hit: Optional[bool]
|
||||||
|
cache_key: Optional[str]
|
||||||
|
request_tags: str # json str
|
||||||
|
team_id: Optional[str]
|
||||||
|
end_user: Optional[str]
|
||||||
|
requester_ip_address: Optional[str]
|
||||||
|
messages: Optional[Union[str, list, dict]]
|
||||||
|
response: Optional[Union[str, list, dict]]
|
||||||
|
model_parameters: dict
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue