litellm-mirror/litellm/integrations/custom_logger.py
Krish Dholakia 2f08341a08
Litellm dev readd prompt caching (#7299)
* fix(router.py): re-add saving model id on prompt caching valid successful deployment

* fix(router.py): introduce optional pre_call_checks

isolate prompt caching logic in a separate file

* fix(prompt_caching_deployment_check.py): fix import

* fix(router.py): new 'async_filter_deployments' event hook

allows custom logger to filter deployments returned to routing strategy

* feat(prompt_caching_deployment_check.py): initial working commit of prompt caching based routing

* fix(cooldown_callbacks.py): fix linting error

* fix(budget_limiter.py): move budget logger to async_filter_deployment hook

* test: add unit test

* test(test_router_helper_utils.py): add unit testing

* fix(budget_limiter.py): fix linting errors

* docs(config_settings.md): add 'optional_pre_call_checks' to router_settings param docs
2024-12-18 15:13:49 -08:00

278 lines
8.6 KiB
Python

#### What this does ####
# Defines CustomLogger, the base class for custom callbacks (logging, routing and proxy hooks)
import os
import traceback
from datetime import datetime as datetimeObj
from typing import TYPE_CHECKING, Any, List, Literal, Optional, Tuple, Union
import dotenv
from pydantic import BaseModel
from litellm.caching.caching import DualCache
from litellm.proxy._types import UserAPIKeyAuth
from litellm.types.integrations.argilla import ArgillaItem
from litellm.types.llms.openai import AllMessageValues, ChatCompletionRequest
from litellm.types.services import ServiceLoggerPayload
from litellm.types.utils import (
AdapterCompletionStreamWrapper,
EmbeddingResponse,
ImageResponse,
ModelResponse,
StandardLoggingPayload,
)
# `Span` is only used in type annotations. The real OpenTelemetry class is
# imported for static type checkers only; at runtime TYPE_CHECKING is False,
# so the import is skipped and `Span` falls back to `Any`.
if TYPE_CHECKING:
    from opentelemetry.trace import Span as _Span

    Span = _Span
else:
    Span = Any
class CustomLogger:  # https://docs.litellm.ai/docs/observability/custom_callback#callback-class
    """Base class for custom callbacks.

    Almost every hook defined here is a no-op (or returns its input
    unchanged), so a subclass only overrides the hooks it cares about.
    The exception is ``async_dataset_hook``, which raises
    ``NotImplementedError`` until it is overridden. The "single-use" helpers
    at the bottom of the class wrap a plain callback function instead of
    relying on overrides.
    """

    def __init__(self, message_logging: bool = True) -> None:
        """Store the ``message_logging`` flag on the instance.

        How the flag is consumed is not visible in this file.
        """
        self.message_logging = message_logging

    def log_pre_api_call(self, model, messages, kwargs):
        """Sync pre-API-call hook. No-op by default."""

    def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
        """Sync post-API-call hook. No-op by default."""

    def log_stream_event(self, kwargs, response_obj, start_time, end_time):
        """Sync stream-event hook. No-op by default."""

    def log_success_event(self, kwargs, response_obj, start_time, end_time):
        """Sync success-event hook. No-op by default."""

    def log_failure_event(self, kwargs, response_obj, start_time, end_time):
        """Sync failure-event hook. No-op by default."""

    #### ASYNC ####

    async def async_log_stream_event(self, kwargs, response_obj, start_time, end_time):
        """Async stream-event hook. No-op by default."""

    async def async_log_pre_api_call(self, model, messages, kwargs):
        """Async pre-API-call hook. No-op by default."""

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        """Async success-event hook. No-op by default."""

    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
        """Async failure-event hook. No-op by default."""

    #### PRE-CALL CHECKS - router/proxy only ####
    # Allows usage-based-routing-v2 to run pre-call rpm checks within the
    # picked deployment's semaphore (concurrency-safe tpm/rpm checks).

    async def async_filter_deployments(
        self,
        model: str,
        healthy_deployments: List,
        messages: Optional[List[AllMessageValues]],
        request_kwargs: Optional[dict] = None,
        parent_otel_span: Optional[Span] = None,
    ) -> List[dict]:
        """Return the deployments that remain eligible for this request.

        The default implementation filters nothing: it returns
        ``healthy_deployments`` itself (same object, not a copy).
        """
        return healthy_deployments

    async def async_pre_call_check(
        self, deployment: dict, parent_otel_span: Optional[Span]
    ) -> Optional[dict]:
        """Async pre-call check for ``deployment``. Returns None by default."""
        return None

    def pre_call_check(self, deployment: dict) -> Optional[dict]:
        """Sync pre-call check for ``deployment``. Returns None by default."""
        return None

    #### Fallback Events - router/proxy only ####

    async def log_model_group_rate_limit_error(
        self, exception: Exception, original_model_group: Optional[str], kwargs: dict
    ):
        """Model-group rate-limit-error hook. No-op by default."""

    async def log_success_fallback_event(
        self, original_model_group: str, kwargs: dict, original_exception: Exception
    ):
        """Successful-fallback hook. No-op by default."""

    async def log_failure_fallback_event(
        self, original_model_group: str, kwargs: dict, original_exception: Exception
    ):
        """Failed-fallback hook. No-op by default."""

    #### ADAPTERS #### Allow calling 100+ LLMs in custom format - https://github.com/BerriAI/litellm/pulls

    def translate_completion_input_params(
        self, kwargs
    ) -> Optional[ChatCompletionRequest]:
        """
        Translates the input params, from the provider's native format to the litellm.completion() format.

        Returns None by default.
        """
        return None

    def translate_completion_output_params(
        self, response: ModelResponse
    ) -> Optional[BaseModel]:
        """
        Translates the output params, from the OpenAI format to the custom format.

        Returns None by default.
        """
        return None

    def translate_completion_output_params_streaming(
        self, completion_stream: Any
    ) -> Optional[AdapterCompletionStreamWrapper]:
        """
        Translates the streaming chunk, from the OpenAI format to the custom format.

        Returns None by default.
        """
        return None

    ### DATASET HOOKS #### - currently only used for Argilla

    async def async_dataset_hook(
        self,
        logged_item: ArgillaItem,
        standard_logging_payload: Optional[StandardLoggingPayload],
    ) -> Optional[ArgillaItem]:
        """
        - Decide if the result should be logged to Argilla.
        - Modify the result before logging to Argilla.
        - Return None if the result should not be logged to Argilla.

        Raises:
            NotImplementedError: always, unless a subclass overrides this hook.
        """
        raise NotImplementedError("async_dataset_hook not implemented")

    #### CALL HOOKS - proxy only ####
    # Control / modify incoming and outgoing data before calling the model.

    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: dict,
        call_type: Literal[
            "completion",
            "text_completion",
            "embeddings",
            "image_generation",
            "moderation",
            "audio_transcription",
            "pass_through_endpoint",
            "rerank",
        ],
    ) -> Optional[
        Union[Exception, str, dict]
    ]:  # raise exception if invalid, return a str for the user to receive - if rejected, or return a modified dictionary for passing into litellm
        """Pre-call hook for an incoming request.

        Contract for overrides: raise an exception if the request is invalid,
        return a ``str`` for the user to receive if it is rejected, or return
        a modified dictionary for passing into litellm. Returns None by default.
        """
        return None

    async def async_post_call_failure_hook(
        self,
        request_data: dict,
        original_exception: Exception,
        user_api_key_dict: UserAPIKeyAuth,
    ):
        """Post-call failure hook. No-op by default."""

    async def async_post_call_success_hook(
        self,
        data: dict,
        user_api_key_dict: UserAPIKeyAuth,
        response: Union[Any, ModelResponse, EmbeddingResponse, ImageResponse],
    ) -> Any:
        """Post-call success hook. Returns None by default."""
        return None

    async def async_logging_hook(
        self, kwargs: dict, result: Any, call_type: str
    ) -> Tuple[dict, Any]:
        """For masking logged request/response. Return a modified version of the request/result.

        The default returns ``(kwargs, result)`` unchanged.
        """
        return kwargs, result

    def logging_hook(
        self, kwargs: dict, result: Any, call_type: str
    ) -> Tuple[dict, Any]:
        """For masking logged request/response. Return a modified version of the request/result.

        The default returns ``(kwargs, result)`` unchanged.
        """
        return kwargs, result

    async def async_moderation_hook(
        self,
        data: dict,
        user_api_key_dict: UserAPIKeyAuth,
        call_type: Literal[
            "completion",
            "embeddings",
            "image_generation",
            "moderation",
            "audio_transcription",
        ],
    ) -> Any:
        """Moderation hook. Returns None by default."""
        return None

    async def async_post_call_streaming_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        response: str,
    ) -> Any:
        """Post-call streaming hook. Returns None by default."""
        return None

    #### SINGLE-USE #### - https://docs.litellm.ai/docs/observability/custom_callback#using-your-custom-callback-function

    def log_input_event(self, model, messages, kwargs, print_verbose, callback_func):
        """Invoke a plain callback function for a pre-API-call event.

        ``kwargs`` is mutated in place: ``model``, ``messages`` and
        ``log_event_type="pre_api_call"`` are written into it before
        ``callback_func(kwargs)`` is called. Any exception is caught and its
        traceback is reported through ``print_verbose``; nothing is re-raised.
        """
        try:
            kwargs["model"] = model
            kwargs["messages"] = messages
            kwargs["log_event_type"] = "pre_api_call"
            callback_func(
                kwargs,
            )
            print_verbose(f"Custom Logger - model call details: {kwargs}")
        except Exception:
            # Best-effort: a failing user callback must not break the caller.
            print_verbose(f"Custom Logger Error - {traceback.format_exc()}")

    async def async_log_input_event(
        self, model, messages, kwargs, print_verbose, callback_func
    ):
        """Async counterpart of ``log_input_event``.

        Same in-place mutation of ``kwargs`` and same error handling, but
        ``callback_func(kwargs)`` is awaited.
        """
        try:
            kwargs["model"] = model
            kwargs["messages"] = messages
            kwargs["log_event_type"] = "pre_api_call"
            await callback_func(
                kwargs,
            )
            print_verbose(f"Custom Logger - model call details: {kwargs}")
        except Exception:
            # Best-effort: a failing user callback must not break the caller.
            print_verbose(f"Custom Logger Error - {traceback.format_exc()}")

    def log_event(
        self, kwargs, response_obj, start_time, end_time, print_verbose, callback_func
    ):
        """Invoke a plain callback function for a post-API-call event.

        Sets ``kwargs["log_event_type"] = "post_api_call"`` in place, then
        calls ``callback_func(kwargs, response_obj, start_time, end_time)``.
        Any exception is caught and its traceback is reported through
        ``print_verbose``; nothing is re-raised.
        """
        try:
            kwargs["log_event_type"] = "post_api_call"
            callback_func(
                kwargs,  # kwargs to func
                response_obj,
                start_time,
                end_time,
            )
        except Exception:
            # Best-effort: a failing user callback must not break the caller.
            print_verbose(f"Custom Logger Error - {traceback.format_exc()}")

    async def async_log_event(
        self, kwargs, response_obj, start_time, end_time, print_verbose, callback_func
    ):
        """Async counterpart of ``log_event``.

        Same in-place mutation of ``kwargs`` and same error handling, but the
        callback is awaited.
        """
        try:
            kwargs["log_event_type"] = "post_api_call"
            await callback_func(
                kwargs,  # kwargs to func
                response_obj,
                start_time,
                end_time,
            )
        except Exception:
            # Best-effort: a failing user callback must not break the caller.
            print_verbose(f"Custom Logger Error - {traceback.format_exc()}")