#### What this does ####
#    Defines the CustomLogger base class - subclass it to log events on
#    success/failure (e.g. to Promptlayer or any custom destination).
import traceback
from typing import (
    TYPE_CHECKING,
    Any,
    AsyncGenerator,
    List,
    Literal,
    Optional,
    Tuple,
    Union,
)

from pydantic import BaseModel

from litellm.caching.caching import DualCache
from litellm.types.integrations.argilla import ArgillaItem
from litellm.types.llms.openai import AllMessageValues, ChatCompletionRequest
from litellm.types.utils import (
    AdapterCompletionStreamWrapper,
    LLMResponseTypes,
    ModelResponse,
    ModelResponseStream,
    StandardCallbackDynamicParams,
    StandardLoggingPayload,
)
from litellm_proxy._types import UserAPIKeyAuth

if TYPE_CHECKING:
    from opentelemetry.trace import Span as _Span

    Span = Union[_Span, Any]
else:
    Span = Any


class CustomLogger:  # https://docs.litellm.ai/docs/observability/custom_callback#callback-class
    # Class variables or attributes
    def __init__(self, message_logging: bool = True) -> None:
        self.message_logging = message_logging

    def log_pre_api_call(self, model, messages, kwargs):
        pass

    def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
        pass

    def log_stream_event(self, kwargs, response_obj, start_time, end_time):
        pass

    def log_success_event(self, kwargs, response_obj, start_time, end_time):
        pass

    def log_failure_event(self, kwargs, response_obj, start_time, end_time):
        pass

    #### ASYNC ####

    async def async_log_stream_event(self, kwargs, response_obj, start_time, end_time):
        pass

    async def async_log_pre_api_call(self, model, messages, kwargs):
        pass

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        pass

    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
        pass
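
    # Illustrative sketch - a minimal subclass wiring the hooks above into
    # litellm (registration via `litellm.callbacks` is per the docs at
    # https://docs.litellm.ai/docs/observability/custom_callback; the handler
    # name and print statements are our own):
    #
    #     import litellm
    #
    #     class MyHandler(CustomLogger):
    #         def log_success_event(self, kwargs, response_obj, start_time, end_time):
    #             print(f"sync success: {kwargs.get('model')}")
    #
    #         async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
    #             print(f"async success: {kwargs.get('model')}")
    #
    #     litellm.callbacks = [MyHandler()]  # register the handler globally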

    #### PROMPT MANAGEMENT HOOKS ####

    async def async_get_chat_completion_prompt(
        self,
        model: str,
        messages: List[AllMessageValues],
        non_default_params: dict,
        prompt_id: str,
        prompt_variables: Optional[dict],
        dynamic_callback_params: StandardCallbackDynamicParams,
    ) -> Tuple[str, List[AllMessageValues], dict]:
        """
        Returns:
        - model: str - the model to use (can be pulled from the prompt management tool)
        - messages: List[AllMessageValues] - the messages to use (can be pulled from the prompt management tool)
        - non_default_params: dict - optional params to use (e.g. temperature, max_tokens; can be pulled from the prompt management tool)
        """
        return model, messages, non_default_params

    def get_chat_completion_prompt(
        self,
        model: str,
        messages: List[AllMessageValues],
        non_default_params: dict,
        prompt_id: Optional[str],
        prompt_variables: Optional[dict],
        dynamic_callback_params: StandardCallbackDynamicParams,
    ) -> Tuple[str, List[AllMessageValues], dict]:
        """
        Returns:
        - model: str - the model to use (can be pulled from the prompt management tool)
        - messages: List[AllMessageValues] - the messages to use (can be pulled from the prompt management tool)
        - non_default_params: dict - optional params to use (e.g. temperature, max_tokens; can be pulled from the prompt management tool)
        """
        return model, messages, non_default_params
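
    # Illustrative sketch - serving a stored prompt from a hypothetical
    # in-memory registry (PROMPTS is our assumption, not a litellm API):
    #
    #     PROMPTS = {"terse": [{"role": "system", "content": "Answer in one sentence."}]}
    #
    #     class PromptManager(CustomLogger):
    #         def get_chat_completion_prompt(
    #             self, model, messages, non_default_params,
    #             prompt_id, prompt_variables, dynamic_callback_params,
    #         ):
    #             if prompt_id in PROMPTS:
    #                 messages = PROMPTS[prompt_id] + messages
    #             return model, messages, non_default_params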

    #### PRE-CALL CHECKS - router/proxy only ####
    """
    Allows usage-based-routing-v2 to run pre-call rpm checks within the picked deployment's semaphore (concurrency-safe tpm/rpm checks).
    """

    async def async_filter_deployments(
        self,
        model: str,
        healthy_deployments: List,
        messages: Optional[List[AllMessageValues]],
        request_kwargs: Optional[dict] = None,
        parent_otel_span: Optional[Span] = None,
    ) -> List[dict]:
        return healthy_deployments

    async def async_pre_call_check(
        self, deployment: dict, parent_otel_span: Optional[Span]
    ) -> Optional[dict]:
        pass

    def pre_call_check(self, deployment: dict) -> Optional[dict]:
        pass
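
    # Illustrative sketch - filtering deployments before the router picks one.
    # Assumes each deployment dict carries an "rpm" entry under
    # "litellm_params", as in router model configs (treat that key as an
    # assumption):
    #
    #     class DropExhausted(CustomLogger):
    #         async def async_filter_deployments(
    #             self, model, healthy_deployments, messages,
    #             request_kwargs=None, parent_otel_span=None,
    #         ):
    #             return [
    #                 d for d in healthy_deployments
    #                 if d.get("litellm_params", {}).get("rpm", 1) > 0
    #             ]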

    #### Fallback Events - router/proxy only ####
    async def log_model_group_rate_limit_error(
        self, exception: Exception, original_model_group: Optional[str], kwargs: dict
    ):
        pass

    async def log_success_fallback_event(
        self, original_model_group: str, kwargs: dict, original_exception: Exception
    ):
        pass

    async def log_failure_fallback_event(
        self, original_model_group: str, kwargs: dict, original_exception: Exception
    ):
        pass
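
    # Illustrative sketch - counting fallback outcomes (the counter dict is
    # our own, not a litellm construct):
    #
    #     FALLBACKS = {"ok": 0, "failed": 0}
    #
    #     class FallbackCounter(CustomLogger):
    #         async def log_success_fallback_event(self, original_model_group, kwargs, original_exception):
    #             FALLBACKS["ok"] += 1
    #
    #         async def log_failure_fallback_event(self, original_model_group, kwargs, original_exception):
    #             FALLBACKS["failed"] += 1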

    #### ADAPTERS #### Allow calling 100+ LLMs in custom format - https://github.com/BerriAI/litellm/pulls

    def translate_completion_input_params(
        self, kwargs
    ) -> Optional[ChatCompletionRequest]:
        """
        Translates the input params, from the provider's native format to the litellm.completion() format.
        """
        pass

    def translate_completion_output_params(
        self, response: ModelResponse
    ) -> Optional[BaseModel]:
        """
        Translates the output params, from the OpenAI format to the custom format.
        """
        pass

    def translate_completion_output_params_streaming(
        self, completion_stream: Any
    ) -> Optional[AdapterCompletionStreamWrapper]:
        """
        Translates the streaming chunk, from the OpenAI format to the custom format.
        """
        pass
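
    # Illustrative sketch - an adapter mapping a hypothetical provider payload
    # (the "model_name" and "prompt" keys are our assumption) into the
    # litellm.completion() format:
    #
    #     class MyAdapter(CustomLogger):
    #         def translate_completion_input_params(self, kwargs):
    #             return ChatCompletionRequest(
    #                 model=kwargs["model_name"],
    #                 messages=[{"role": "user", "content": kwargs["prompt"]}],
    #             )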

    ### DATASET HOOKS #### - currently only used for Argilla

    async def async_dataset_hook(
        self,
        logged_item: ArgillaItem,
        standard_logging_payload: Optional[StandardLoggingPayload],
    ) -> Optional[ArgillaItem]:
        """
        - Decide if the result should be logged to Argilla.
        - Modify the result before logging to Argilla.
        - Return None if the result should not be logged to Argilla.
        """
        raise NotImplementedError("async_dataset_hook not implemented")
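
    # Illustrative sketch - only forwarding successful calls to Argilla by
    # returning None for everything else (reading "status" off the payload is
    # an assumption about StandardLoggingPayload's fields):
    #
    #     class SuccessOnlyDataset(CustomLogger):
    #         async def async_dataset_hook(self, logged_item, standard_logging_payload):
    #             if standard_logging_payload and standard_logging_payload.get("status") == "success":
    #                 return logged_item
    #             return None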

    #### CALL HOOKS - proxy only ####
    """
    Modify incoming / outgoing data before calling the model.
    """

    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: dict,
        call_type: Literal[
            "completion",
            "text_completion",
            "embeddings",
            "image_generation",
            "moderation",
            "audio_transcription",
            "pass_through_endpoint",
            "rerank",
        ],
    ) -> Optional[
        Union[Exception, str, dict]
    ]:  # raise an exception if the request is invalid; return a str to send back to the user if rejected; or return a modified dict to pass into litellm
        pass
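
    # Illustrative sketch - rejecting a request from the proxy's pre-call hook
    # by returning a string for the user (the banned-word check is our own):
    #
    #     class BlockList(CustomLogger):
    #         async def async_pre_call_hook(self, user_api_key_dict, cache, data, call_type):
    #             text = str(data.get("messages", ""))
    #             if "forbidden" in text.lower():
    #                 return "Request blocked by proxy policy."
    #             return data  # unchanged requests pass through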

    async def async_post_call_failure_hook(
        self,
        request_data: dict,
        original_exception: Exception,
        user_api_key_dict: UserAPIKeyAuth,
    ):
        pass

    async def async_post_call_success_hook(
        self,
        data: dict,
        user_api_key_dict: UserAPIKeyAuth,
        response: LLMResponseTypes,
    ) -> Any:
        pass

    async def async_logging_hook(
        self, kwargs: dict, result: Any, call_type: str
    ) -> Tuple[dict, Any]:
        """For masking logged request/response. Return a modified version of the request/result."""
        return kwargs, result

    def logging_hook(
        self, kwargs: dict, result: Any, call_type: str
    ) -> Tuple[dict, Any]:
        """For masking logged request/response. Return a modified version of the request/result."""
        return kwargs, result
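
    # Illustrative sketch - masking message content before it is logged, while
    # leaving the actual request untouched (note the shallow copy):
    #
    #     class RedactMessages(CustomLogger):
    #         async def async_logging_hook(self, kwargs, result, call_type):
    #             kwargs = {**kwargs, "messages": [{"role": "user", "content": "REDACTED"}]}
    #             return kwargs, result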

    async def async_moderation_hook(
        self,
        data: dict,
        user_api_key_dict: UserAPIKeyAuth,
        call_type: Literal[
            "completion",
            "embeddings",
            "image_generation",
            "moderation",
            "audio_transcription",
            "responses",
        ],
    ) -> Any:
        pass

    async def async_post_call_streaming_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        response: str,
    ) -> Any:
        pass

    async def async_post_call_streaming_iterator_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        response: Any,
        request_data: dict,
    ) -> AsyncGenerator[ModelResponseStream, None]:
        async for item in response:
            yield item
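
    # Illustrative sketch - post-processing streamed chunks as they pass
    # through the proxy (the scrub() helper is hypothetical):
    #
    #     class ScrubStream(CustomLogger):
    #         async def async_post_call_streaming_iterator_hook(
    #             self, user_api_key_dict, response, request_data
    #         ):
    #             async for chunk in response:
    #                 # e.g. rewrite chunk.choices[0].delta.content via scrub(...)
    #                 yield chunk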

    #### SINGLE-USE #### - https://docs.litellm.ai/docs/observability/custom_callback#using-your-custom-callback-function

    def log_input_event(self, model, messages, kwargs, print_verbose, callback_func):
        try:
            kwargs["model"] = model
            kwargs["messages"] = messages
            kwargs["log_event_type"] = "pre_api_call"
            callback_func(
                kwargs,
            )
            print_verbose(f"Custom Logger - model call details: {kwargs}")
        except Exception:
            print_verbose(f"Custom Logger Error - {traceback.format_exc()}")

    async def async_log_input_event(
        self, model, messages, kwargs, print_verbose, callback_func
    ):
        try:
            kwargs["model"] = model
            kwargs["messages"] = messages
            kwargs["log_event_type"] = "pre_api_call"
            await callback_func(
                kwargs,
            )
            print_verbose(f"Custom Logger - model call details: {kwargs}")
        except Exception:
            print_verbose(f"Custom Logger Error - {traceback.format_exc()}")

    def log_event(
        self, kwargs, response_obj, start_time, end_time, print_verbose, callback_func
    ):
        try:
            kwargs["log_event_type"] = "post_api_call"
            callback_func(
                kwargs,  # kwargs to func
                response_obj,
                start_time,
                end_time,
            )
        except Exception:
            print_verbose(f"Custom Logger Error - {traceback.format_exc()}")

    async def async_log_event(
        self, kwargs, response_obj, start_time, end_time, print_verbose, callback_func
    ):
        try:
            kwargs["log_event_type"] = "post_api_call"
            await callback_func(
                kwargs,  # kwargs to func
                response_obj,
                start_time,
                end_time,
            )
        except Exception:
            print_verbose(f"Custom Logger Error - {traceback.format_exc()}")

    # Useful helpers for custom logger classes

    def truncate_standard_logging_payload_content(
        self,
        standard_logging_object: StandardLoggingPayload,
    ):
        """
        Truncate error strings and message content in the logging payload.

        Some loggers, e.g. DataDog and GCS Bucket, cap the payload size (1MB).

        This function truncates the error string and the message content if they exceed that limit.
        """
        MAX_STR_LENGTH = 10_000

        # Truncate fields that might exceed max length
        fields_to_truncate = ["error_str", "messages", "response"]
        for field in fields_to_truncate:
            self._truncate_field(
                standard_logging_object=standard_logging_object,
                field_name=field,
                max_length=MAX_STR_LENGTH,
            )
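
    # Illustrative sketch - calling the truncation helper from a logger that
    # ships payloads to a size-limited backend (send_to_backend is hypothetical):
    #
    #     class SizeLimitedLogger(CustomLogger):
    #         async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
    #             payload = kwargs.get("standard_logging_object")
    #             if payload is not None:
    #                 self.truncate_standard_logging_payload_content(payload)
    #                 await send_to_backend(payload)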

    def _truncate_field(
        self,
        standard_logging_object: StandardLoggingPayload,
        field_name: str,
        max_length: int,
    ) -> None:
        """
        Helper to truncate a single field in the logging payload.

        The field is converted to a string, then truncated if it exceeds max_length.

        Why convert to string?
        1. Users may send a poorly formatted `messages` list, so we cannot predict where the content will be.
           Converting to a string before truncating catches this.
        2. We want to avoid mutating the original `messages`, `response`, and `error_str` in the logging payload, since these live in kwargs and could be returned to the user.
        """
        field_value = standard_logging_object.get(field_name)  # type: ignore
        if field_value:
            str_value = str(field_value)
            if len(str_value) > max_length:
                standard_logging_object[field_name] = self._truncate_text(  # type: ignore
                    text=str_value, max_length=max_length
                )

    def _truncate_text(self, text: str, max_length: int) -> str:
        """Truncate text if it exceeds max_length."""
        return (
            text[:max_length]
            + "...truncated by litellm, this logger does not support large content"
            if len(text) > max_length
            else text
        )