Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 10:44:24 +00:00)
(fixes) gcs bucket key based logging (#6044)
* fixes for gcs bucket logging
* fix StandardCallbackDynamicParams
* fix - gcs logging when payload is not serializable
* add test_add_callback_via_key_litellm_pre_call_utils_gcs_bucket
* working success callbacks
* linting fixes
* fix linting error
* add type hints to functions
* fixes for dynamic success and failure logging
* fix for test_async_chat_openai_stream
This commit is contained in: parent 793593e735 · commit 670ecda4e2
9 changed files with 446 additions and 39 deletions
@@ -82,7 +82,7 @@ class GCSBucketLogger(GCSBucketBase):
         if logging_payload is None:
             raise ValueError("standard_logging_object not found in kwargs")
 
-        json_logged_payload = json.dumps(logging_payload)
+        json_logged_payload = json.dumps(logging_payload, default=str)
 
         # Get the current date
         current_date = datetime.now().strftime("%Y-%m-%d")
@@ -137,7 +137,7 @@ class GCSBucketLogger(GCSBucketBase):
         _litellm_params = kwargs.get("litellm_params") or {}
         metadata = _litellm_params.get("metadata") or {}
 
-        json_logged_payload = json.dumps(logging_payload)
+        json_logged_payload = json.dumps(logging_payload, default=str)
 
         # Get the current date
         current_date = datetime.now().strftime("%Y-%m-%d")
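The two hunks above are the "payload is not serializable" fix from the commit message: `default=str` makes `json.dumps` stringify any value it cannot encode natively (datetimes, exceptions, pydantic objects) instead of raising `TypeError`. A minimal, standalone sketch of the difference, using a made-up payload:

```python
import json
from datetime import datetime

# hypothetical logging payload containing a non-JSON-serializable datetime
logging_payload = {"model": "gpt-4o-mini", "startTime": datetime(2024, 10, 3, 12, 0)}

try:
    json.dumps(logging_payload)
except TypeError as e:
    print("without default=str:", e)  # Object of type datetime is not JSON serializable

# with default=str the datetime is coerced to its string form and logging succeeds
print(json.dumps(logging_payload, default=str))
# {"model": "gpt-4o-mini", "startTime": "2024-10-03 12:00:00"}
```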
@@ -192,16 +192,28 @@ class Logging:
 
     def __init__(
         self,
-        model,
+        model: str,
         messages,
         stream,
         call_type,
         start_time,
-        litellm_call_id,
-        function_id,
-        dynamic_success_callbacks=None,
-        dynamic_failure_callbacks=None,
-        dynamic_async_success_callbacks=None,
+        litellm_call_id: str,
+        function_id: str,
+        dynamic_input_callbacks: Optional[
+            List[Union[str, Callable, CustomLogger]]
+        ] = None,
+        dynamic_success_callbacks: Optional[
+            List[Union[str, Callable, CustomLogger]]
+        ] = None,
+        dynamic_async_success_callbacks: Optional[
+            List[Union[str, Callable, CustomLogger]]
+        ] = None,
+        dynamic_failure_callbacks: Optional[
+            List[Union[str, Callable, CustomLogger]]
+        ] = None,
+        dynamic_async_failure_callbacks: Optional[
+            List[Union[str, Callable, CustomLogger]]
+        ] = None,
         kwargs: Optional[Dict] = None,
     ):
         if messages is not None:
@@ -230,27 +242,117 @@ class Logging:
             []
         )  # for generating complete stream response
         self.model_call_details: Dict[Any, Any] = {}
-        self.dynamic_input_callbacks: List[Any] = (
-            []
-        )  # [TODO] callbacks set for just that call
-        self.dynamic_failure_callbacks = dynamic_failure_callbacks
-        self.dynamic_success_callbacks = (
-            dynamic_success_callbacks  # callbacks set for just that call
-        )
-        self.dynamic_async_success_callbacks = (
-            dynamic_async_success_callbacks  # callbacks set for just that call
-        )
+
+        # Initialize dynamic callbacks
+        self.dynamic_input_callbacks: Optional[
+            List[Union[str, Callable, CustomLogger]]
+        ] = dynamic_input_callbacks
+        self.dynamic_success_callbacks: Optional[
+            List[Union[str, Callable, CustomLogger]]
+        ] = dynamic_success_callbacks
+        self.dynamic_async_success_callbacks: Optional[
+            List[Union[str, Callable, CustomLogger]]
+        ] = dynamic_async_success_callbacks
+        self.dynamic_failure_callbacks: Optional[
+            List[Union[str, Callable, CustomLogger]]
+        ] = dynamic_failure_callbacks
+        self.dynamic_async_failure_callbacks: Optional[
+            List[Union[str, Callable, CustomLogger]]
+        ] = dynamic_async_failure_callbacks
+
+        # Process dynamic callbacks
+        self.process_dynamic_callbacks()
+
         ## DYNAMIC LANGFUSE / GCS / logging callback KEYS ##
         self.standard_callback_dynamic_params: StandardCallbackDynamicParams = (
             self.initialize_standard_callback_dynamic_params(kwargs)
         )
-        ## TIME TO FIRST TOKEN LOGGING ##
 
+        ## TIME TO FIRST TOKEN LOGGING ##
         self.completion_start_time: Optional[datetime.datetime] = None
 
+    def process_dynamic_callbacks(self):
+        """
+        Initializes CustomLogger compatible callbacks in self.dynamic_* callbacks
+
+        If a callback is in litellm._known_custom_logger_compatible_callbacks, it needs to be intialized and added to the respective dynamic_* callback list.
+        """
+        # Process input callbacks
+        self.dynamic_input_callbacks = self._process_dynamic_callback_list(
+            self.dynamic_input_callbacks, dynamic_callbacks_type="input"
+        )
+
+        # Process failure callbacks
+        self.dynamic_failure_callbacks = self._process_dynamic_callback_list(
+            self.dynamic_failure_callbacks, dynamic_callbacks_type="failure"
+        )
+
+        # Process async failure callbacks
+        self.dynamic_async_failure_callbacks = self._process_dynamic_callback_list(
+            self.dynamic_async_failure_callbacks, dynamic_callbacks_type="async_failure"
+        )
+
+        # Process success callbacks
+        self.dynamic_success_callbacks = self._process_dynamic_callback_list(
+            self.dynamic_success_callbacks, dynamic_callbacks_type="success"
+        )
+
+        # Process async success callbacks
+        self.dynamic_async_success_callbacks = self._process_dynamic_callback_list(
+            self.dynamic_async_success_callbacks, dynamic_callbacks_type="async_success"
+        )
+
+    def _process_dynamic_callback_list(
+        self,
+        callback_list: Optional[List[Union[str, Callable, CustomLogger]]],
+        dynamic_callbacks_type: Literal[
+            "input", "success", "failure", "async_success", "async_failure"
+        ],
+    ) -> Optional[List[Union[str, Callable, CustomLogger]]]:
+        """
+        Helper function to initialize CustomLogger compatible callbacks in self.dynamic_* callbacks
+
+        - If a callback is in litellm._known_custom_logger_compatible_callbacks,
+          replace the string with the initialized callback class.
+        - If dynamic callback is a "success" callback that is a known_custom_logger_compatible_callbacks then add it to dynamic_async_success_callbacks
+        - If dynamic callback is a "failure" callback that is a known_custom_logger_compatible_callbacks then add it to dynamic_failure_callbacks
+        """
+        if callback_list is None:
+            return None
+
+        processed_list: List[Union[str, Callable, CustomLogger]] = []
+        for callback in callback_list:
+            if (
+                isinstance(callback, str)
+                and callback in litellm._known_custom_logger_compatible_callbacks
+            ):
+                callback_class = _init_custom_logger_compatible_class(
+                    callback, internal_usage_cache=None, llm_router=None  # type: ignore
+                )
+                if callback_class is not None:
+                    processed_list.append(callback_class)
+
+                    # If processing dynamic_success_callbacks, add to dynamic_async_success_callbacks
+                    if dynamic_callbacks_type == "success":
+                        if self.dynamic_async_success_callbacks is None:
+                            self.dynamic_async_success_callbacks = []
+                        self.dynamic_async_success_callbacks.append(callback_class)
+                    elif dynamic_callbacks_type == "failure":
+                        if self.dynamic_async_failure_callbacks is None:
+                            self.dynamic_async_failure_callbacks = []
+                        self.dynamic_async_failure_callbacks.append(callback_class)
+            else:
+                processed_list.append(callback)
+        return processed_list
+
     def initialize_standard_callback_dynamic_params(
         self, kwargs: Optional[Dict] = None
     ) -> StandardCallbackDynamicParams:
+        """
+        Initialize the standard callback dynamic params from the kwargs
+
+        checks if langfuse_secret_key, gcs_bucket_name in kwargs and sets the corresponding attributes in StandardCallbackDynamicParams
+        """
         standard_callback_dynamic_params = StandardCallbackDynamicParams()
         if kwargs:
             _supported_callback_params = (
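The `_process_dynamic_callback_list` helper added above lets a per-call callback list mix plain strings, callables, and `CustomLogger` instances, swapping only the strings that name a known CustomLogger-compatible integration for an initialized logger object. A self-contained sketch of that pattern (the registry and classes below are stand-ins for illustration, not litellm's actual internals):

```python
from typing import Callable, List, Optional, Union


class CustomLogger:  # stand-in for litellm's CustomLogger base class
    pass


class GCSBucketLogger(CustomLogger):  # stand-in for the real GCS integration
    pass


# hypothetical registry mapping callback names to logger classes
KNOWN_COMPATIBLE = {"gcs_bucket": GCSBucketLogger}


def process_callback_list(
    callback_list: Optional[List[Union[str, Callable, CustomLogger]]],
) -> Optional[List[Union[str, Callable, CustomLogger]]]:
    """Swap known strings for initialized logger instances; pass everything else through."""
    if callback_list is None:
        return None
    processed: List[Union[str, Callable, CustomLogger]] = []
    for callback in callback_list:
        if isinstance(callback, str) and callback in KNOWN_COMPATIBLE:
            processed.append(KNOWN_COMPATIBLE[callback]())  # initialize the class
        else:
            processed.append(callback)  # plain strings and callables stay as-is
    return processed


print(process_callback_list(["gcs_bucket", "supabase"]))
# e.g. [<__main__.GCSBucketLogger object at 0x...>, 'supabase']
```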
@@ -413,7 +515,7 @@ class Logging:
 
         self.model_call_details["api_call_start_time"] = datetime.datetime.now()
         # Input Integration Logging -> If you want to log the fact that an attempt to call the model was made
-        callbacks = litellm.input_callback + self.dynamic_input_callbacks
+        callbacks = litellm.input_callback + (self.dynamic_input_callbacks or [])
         for callback in callbacks:
             try:
                 if callback == "supabase" and supabaseClient is not None:
@@ -529,7 +631,7 @@ class Logging:
         )
         # Input Integration Logging -> If you want to log the fact that an attempt to call the model was made
 
-        callbacks = litellm.input_callback + self.dynamic_input_callbacks
+        callbacks = litellm.input_callback + (self.dynamic_input_callbacks or [])
         for callback in callbacks:
             try:
                 if callback == "sentry" and add_breadcrumb:
@@ -2004,8 +2106,25 @@ class Logging:
                 start_time=start_time,
                 end_time=end_time,
             )
+
+        callbacks = []  # init this to empty incase it's not created
+
+        if self.dynamic_async_failure_callbacks is not None and isinstance(
+            self.dynamic_async_failure_callbacks, list
+        ):
+            callbacks = self.dynamic_async_failure_callbacks
+            ## keep the internal functions ##
+            for callback in litellm._async_failure_callback:
+                if (
+                    isinstance(callback, CustomLogger)
+                    and "_PROXY_" in callback.__class__.__name__
+                ):
+                    callbacks.append(callback)
+        else:
+            callbacks = litellm._async_failure_callback
+
         result = None  # result sent to all loggers, init this to None incase it's not created
-        for callback in litellm._async_failure_callback:
+        for callback in callbacks:
             try:
                 if isinstance(callback, CustomLogger):  # custom logger class
                     await callback.async_log_failure_event(
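The failure path now prefers the per-call `dynamic_async_failure_callbacks` while still keeping proxy-internal hooks (CustomLogger subclasses whose class name contains `_PROXY_`). A simplified, runnable sketch of that selection order, using stand-in classes rather than litellm's own:

```python
from typing import List


class CustomLogger:  # stand-in for litellm's CustomLogger base class
    pass


class _PROXY_MaxBudgetLimiter(CustomLogger):  # hypothetical proxy-internal hook
    pass


def select_failure_callbacks(dynamic_async_failure_callbacks, global_async_failure_callbacks) -> List:
    """Per-call callbacks take precedence, but proxy-internal hooks are always kept."""
    if dynamic_async_failure_callbacks is not None and isinstance(
        dynamic_async_failure_callbacks, list
    ):
        callbacks = list(dynamic_async_failure_callbacks)
        for callback in global_async_failure_callbacks:
            # "_PROXY_" classes are internal proxy hooks that must still run
            if isinstance(callback, CustomLogger) and "_PROXY_" in callback.__class__.__name__:
                callbacks.append(callback)
        return callbacks
    return list(global_async_failure_callbacks)


print(select_failure_callbacks(["gcs_bucket"], [_PROXY_MaxBudgetLimiter(), "langfuse"]))
# e.g. ['gcs_bucket', <__main__._PROXY_MaxBudgetLimiter object at 0x...>]
```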
@@ -12,7 +12,7 @@ from typing_extensions import Annotated, TypedDict
 
 from litellm.integrations.SlackAlerting.types import AlertType
 from litellm.types.router import RouterErrors, UpdateRouterConfig
-from litellm.types.utils import ProviderField
+from litellm.types.utils import ProviderField, StandardCallbackDynamicParams
 
 if TYPE_CHECKING:
     from opentelemetry.trace import Span as _Span
@@ -959,21 +959,38 @@ class BlockKeyRequest(LiteLLMBase):
 class AddTeamCallback(LiteLLMBase):
     callback_name: str
     callback_type: Literal["success", "failure", "success_and_failure"]
-    # for now - only supported for langfuse
-    callback_vars: Dict[
-        Literal["langfuse_public_key", "langfuse_secret_key", "langfuse_host"], str
-    ]
+    callback_vars: Dict[str, str]
+
+    @model_validator(mode="before")
+    @classmethod
+    def validate_callback_vars(cls, values):
+        callback_vars = values.get("callback_vars", {})
+        valid_keys = set(StandardCallbackDynamicParams.__annotations__.keys())
+        for key in callback_vars:
+            if key not in valid_keys:
+                raise ValueError(
+                    f"Invalid callback variable: {key}. Must be one of {valid_keys}"
+                )
+        return values
 
 
 class TeamCallbackMetadata(LiteLLMBase):
     success_callback: Optional[List[str]] = []
     failure_callback: Optional[List[str]] = []
     # for now - only supported for langfuse
-    callback_vars: Optional[
-        Dict[
-            Literal["langfuse_public_key", "langfuse_secret_key", "langfuse_host"], str
-        ]
-    ] = {}
+    callback_vars: Optional[Dict[str, str]] = {}
+
+    @model_validator(mode="before")
+    @classmethod
+    def validate_callback_vars(cls, values):
+        callback_vars = values.get("callback_vars", {})
+        valid_keys = set(StandardCallbackDynamicParams.__annotations__.keys())
+        for key in callback_vars:
+            if key not in valid_keys:
+                raise ValueError(
+                    f"Invalid callback variable: {key}. Must be one of {valid_keys}"
+                )
+        return values
 
 
 class LiteLLM_TeamTable(TeamBase):
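The new `validate_callback_vars` hook replaces the hard-coded Langfuse-only `Literal` keys with validation against whatever `StandardCallbackDynamicParams` declares, so GCS keys such as `gcs_bucket_name` are now accepted. A runnable approximation (assuming pydantic v2; the models below are trimmed stand-ins, not the proxy's full types):

```python
from typing import Dict, Optional

from pydantic import BaseModel, model_validator
from typing_extensions import TypedDict


class StandardCallbackDynamicParams(TypedDict, total=False):
    # trimmed to a few keys for illustration
    langfuse_public_key: Optional[str]
    langfuse_secret_key: Optional[str]
    gcs_bucket_name: Optional[str]
    gcs_path_service_account: Optional[str]


class AddTeamCallback(BaseModel):
    callback_name: str
    callback_vars: Dict[str, str]

    @model_validator(mode="before")
    @classmethod
    def validate_callback_vars(cls, values):
        # only accept keys that StandardCallbackDynamicParams declares
        valid_keys = set(StandardCallbackDynamicParams.__annotations__.keys())
        for key in values.get("callback_vars", {}):
            if key not in valid_keys:
                raise ValueError(f"Invalid callback variable: {key}. Must be one of {valid_keys}")
        return values


print(AddTeamCallback(callback_name="gcs_bucket", callback_vars={"gcs_bucket_name": "my-bucket"}))

try:
    AddTeamCallback(callback_name="gcs_bucket", callback_vars={"bucket": "oops"})
except Exception as err:
    print("rejected:", err)
```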
@@ -115,7 +115,7 @@ class PassThroughEndpointLogging:
                 encoding=None,
             )
         )
-        logging_obj.model = litellm_model_response.model
+        logging_obj.model = litellm_model_response.model or model
        logging_obj.model_call_details["model"] = logging_obj.model
 
         await logging_obj.async_success_handler(
@@ -5,7 +5,6 @@ model_list:
       api_key: fake-key
       api_base: https://exampleopenaiendpoint-production.up.railway.app/
 
-general_settings:
-  alerting: ["slack"]
+
 
@@ -1384,8 +1384,12 @@ OPENAI_RESPONSE_HEADERS = [
 
 
 class StandardCallbackDynamicParams(TypedDict, total=False):
+    # Langfuse dynamic params
     langfuse_public_key: Optional[str]
     langfuse_secret: Optional[str]
+    langfuse_secret_key: Optional[str]
     langfuse_host: Optional[str]
+
+    # GCS dynamic params
     gcs_bucket_name: Optional[str]
     gcs_path_service_account: Optional[str]
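These TypedDict keys are what the proxy's `callback_vars` validation above and the GCS tests below key off of. A condensed sketch of the per-request usage the tests exercise; the bucket name is a placeholder, and working OpenAI plus GCS credentials are assumed:

```python
import asyncio

import litellm


async def main():
    # Per-request GCS logging: the bucket comes from the call kwargs rather
    # than from environment-level settings ("my-logging-bucket" is a placeholder).
    response = await litellm.acompletion(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": "This is a test"}],
        max_tokens=10,
        gcs_bucket_name="my-logging-bucket",
        success_callback=["gcs_bucket"],
        failure_callback=["gcs_bucket"],
    )
    print(response.id)


asyncio.run(main())
```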
@@ -58,6 +58,7 @@ import litellm.litellm_core_utils
 import litellm.litellm_core_utils.audio_utils.utils
 import litellm.litellm_core_utils.json_validation_rule
 from litellm.caching import DualCache
+from litellm.integrations.custom_logger import CustomLogger
 from litellm.litellm_core_utils.core_helpers import map_finish_reason
 from litellm.litellm_core_utils.exception_mapping_utils import (
     _get_litellm_response_headers,
@@ -430,9 +431,18 @@ def function_setup(
         for index in reversed(removed_async_items):
             litellm.failure_callback.pop(index)
     ### DYNAMIC CALLBACKS ###
-    dynamic_success_callbacks = None
-    dynamic_async_success_callbacks = None
-    dynamic_failure_callbacks = None
+    dynamic_success_callbacks: Optional[
+        List[Union[str, Callable, CustomLogger]]
+    ] = None
+    dynamic_async_success_callbacks: Optional[
+        List[Union[str, Callable, CustomLogger]]
+    ] = None
+    dynamic_failure_callbacks: Optional[
+        List[Union[str, Callable, CustomLogger]]
+    ] = None
+    dynamic_async_failure_callbacks: Optional[
+        List[Union[str, Callable, CustomLogger]]
+    ] = None
     if kwargs.get("success_callback", None) is not None and isinstance(
         kwargs["success_callback"], list
     ):
@@ -561,6 +571,7 @@ def function_setup(
         dynamic_success_callbacks=dynamic_success_callbacks,
         dynamic_failure_callbacks=dynamic_failure_callbacks,
         dynamic_async_success_callbacks=dynamic_async_success_callbacks,
+        dynamic_async_failure_callbacks=dynamic_async_failure_callbacks,
         kwargs=kwargs,
     )
 
@@ -267,7 +267,7 @@ async def test_basic_gcs_logger_failure():
 
 
 @pytest.mark.asyncio
-async def test_basic_gcs_logging_per_request():
+async def test_basic_gcs_logging_per_request_with_callback_set():
     """
     Test GCS Bucket logging per request
 
@@ -391,3 +391,128 @@ async def test_basic_gcs_logging_per_request():
         object_name=object_name,
         standard_callback_dynamic_params=standard_callback_dynamic_params,
     )
+
+
+@pytest.mark.asyncio
+async def test_basic_gcs_logging_per_request_with_no_litellm_callback_set():
+    """
+    Test GCS Bucket logging per request
+
+    key difference: no litellm.callbacks set
+
+    Request 1 - pass gcs_bucket_name in kwargs
+    Request 2 - don't pass gcs_bucket_name in kwargs - ensure 'litellm-testing-bucket'
+    """
+    import logging
+    from litellm._logging import verbose_logger
+
+    verbose_logger.setLevel(logging.DEBUG)
+    load_vertex_ai_credentials()
+    gcs_logger = GCSBucketLogger()
+
+    GCS_BUCKET_NAME = "key-logging-project1"
+    standard_callback_dynamic_params: StandardCallbackDynamicParams = (
+        StandardCallbackDynamicParams(gcs_bucket_name=GCS_BUCKET_NAME)
+    )
+
+    try:
+        response = await litellm.acompletion(
+            model="gpt-4o-mini",
+            temperature=0.7,
+            messages=[{"role": "user", "content": "This is a test"}],
+            max_tokens=10,
+            user="ishaan-2",
+            gcs_bucket_name=GCS_BUCKET_NAME,
+            success_callback=["gcs_bucket"],
+            failure_callback=["gcs_bucket"],
+        )
+    except:
+        pass
+
+    await asyncio.sleep(5)
+
+    # Get the current date
+    # Get the current date
+    current_date = datetime.now().strftime("%Y-%m-%d")
+
+    # Modify the object_name to include the date-based folder
+    object_name = f"{current_date}%2F{response.id}"
+
+    print("object_name", object_name)
+
+    # Check if object landed on GCS
+    object_from_gcs = await gcs_logger.download_gcs_object(
+        object_name=object_name,
+        standard_callback_dynamic_params=standard_callback_dynamic_params,
+    )
+    print("object from gcs=", object_from_gcs)
+    # convert object_from_gcs from bytes to DICT
+    parsed_data = json.loads(object_from_gcs)
+    print("object_from_gcs as dict", parsed_data)
+
+    print("type of object_from_gcs", type(parsed_data))
+
+    gcs_payload = StandardLoggingPayload(**parsed_data)
+
+    assert gcs_payload["model"] == "gpt-4o-mini"
+    assert gcs_payload["messages"] == [{"role": "user", "content": "This is a test"}]
+
+    assert gcs_payload["response_cost"] > 0.0
+
+    assert gcs_payload["status"] == "success"
+
+    # clean up the object from GCS
+    await gcs_logger.delete_gcs_object(
+        object_name=object_name,
+        standard_callback_dynamic_params=standard_callback_dynamic_params,
+    )
+
+    # make a failure request - assert that failure callback is hit
+    gcs_log_id = f"failure-test-{uuid.uuid4().hex}"
+    try:
+        response = await litellm.acompletion(
+            model="gpt-4o-mini",
+            temperature=0.7,
+            messages=[{"role": "user", "content": "This is a test"}],
+            max_tokens=10,
+            user="ishaan-2",
+            mock_response=litellm.BadRequestError(
+                model="gpt-3.5-turbo",
+                message="Error: 400: Bad Request: Invalid API key, please check your API key and try again.",
+                llm_provider="openai",
+            ),
+            success_callback=["gcs_bucket"],
+            failure_callback=["gcs_bucket"],
+            gcs_bucket_name=GCS_BUCKET_NAME,
+            metadata={
+                "gcs_log_id": gcs_log_id,
+            },
+        )
+    except:
+        pass
+
+    await asyncio.sleep(5)
+
+    # check if the failure object is logged in GCS
+    object_from_gcs = await gcs_logger.download_gcs_object(
+        object_name=gcs_log_id,
+        standard_callback_dynamic_params=standard_callback_dynamic_params,
+    )
+    print("object from gcs=", object_from_gcs)
+    # convert object_from_gcs from bytes to DICT
+    parsed_data = json.loads(object_from_gcs)
+    print("object_from_gcs as dict", parsed_data)
+
+    gcs_payload = StandardLoggingPayload(**parsed_data)
+
+    assert gcs_payload["model"] == "gpt-4o-mini"
+    assert gcs_payload["messages"] == [{"role": "user", "content": "This is a test"}]
+
+    assert gcs_payload["response_cost"] == 0
+    assert gcs_payload["status"] == "failure"
+
+    # clean up the object from GCS
+    await gcs_logger.delete_gcs_object(
+        object_name=gcs_log_id,
+        standard_callback_dynamic_params=standard_callback_dynamic_params,
+    )
@@ -1389,6 +1389,138 @@ async def test_add_callback_via_key_litellm_pre_call_utils(
         assert new_data["failure_callback"] == expected_failure_callbacks
 
 
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "callback_type, expected_success_callbacks, expected_failure_callbacks",
+    [
+        ("success", ["gcs_bucket"], []),
+        ("failure", [], ["gcs_bucket"]),
+        ("success_and_failure", ["gcs_bucket"], ["gcs_bucket"]),
+    ],
+)
+async def test_add_callback_via_key_litellm_pre_call_utils_gcs_bucket(
+    prisma_client, callback_type, expected_success_callbacks, expected_failure_callbacks
+):
+    import json
+
+    from fastapi import HTTPException, Request, Response
+    from starlette.datastructures import URL
+
+    from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request
+
+    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
+    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
+    await litellm.proxy.proxy_server.prisma_client.connect()
+
+    proxy_config = getattr(litellm.proxy.proxy_server, "proxy_config")
+
+    request = Request(scope={"type": "http", "method": "POST", "headers": {}})
+    request._url = URL(url="/chat/completions")
+
+    test_data = {
+        "model": "azure/chatgpt-v-2",
+        "messages": [
+            {"role": "user", "content": "write 1 sentence poem"},
+        ],
+        "max_tokens": 10,
+        "mock_response": "Hello world",
+        "api_key": "my-fake-key",
+    }
+
+    json_bytes = json.dumps(test_data).encode("utf-8")
+
+    request._body = json_bytes
+
+    data = {
+        "data": {
+            "model": "azure/chatgpt-v-2",
+            "messages": [{"role": "user", "content": "write 1 sentence poem"}],
+            "max_tokens": 10,
+            "mock_response": "Hello world",
+            "api_key": "my-fake-key",
+        },
+        "request": request,
+        "user_api_key_dict": UserAPIKeyAuth(
+            token=None,
+            key_name=None,
+            key_alias=None,
+            spend=0.0,
+            max_budget=None,
+            expires=None,
+            models=[],
+            aliases={},
+            config={},
+            user_id=None,
+            team_id=None,
+            max_parallel_requests=None,
+            metadata={
+                "logging": [
+                    {
+                        "callback_name": "gcs_bucket",
+                        "callback_type": callback_type,
+                        "callback_vars": {
+                            "gcs_bucket_name": "key-logging-project1",
+                            "gcs_path_service_account": "adroit-crow-413218-a956eef1a2a8.json",
+                        },
+                    }
+                ]
+            },
+            tpm_limit=None,
+            rpm_limit=None,
+            budget_duration=None,
+            budget_reset_at=None,
+            allowed_cache_controls=[],
+            permissions={},
+            model_spend={},
+            model_max_budget={},
+            soft_budget_cooldown=False,
+            litellm_budget_table=None,
+            org_id=None,
+            team_spend=None,
+            team_alias=None,
+            team_tpm_limit=None,
+            team_rpm_limit=None,
+            team_max_budget=None,
+            team_models=[],
+            team_blocked=False,
+            soft_budget=None,
+            team_model_aliases=None,
+            team_member_spend=None,
+            team_metadata=None,
+            end_user_id=None,
+            end_user_tpm_limit=None,
+            end_user_rpm_limit=None,
+            end_user_max_budget=None,
+            last_refreshed_at=None,
+            api_key=None,
+            user_role=None,
+            allowed_model_region=None,
+            parent_otel_span=None,
+        ),
+        "proxy_config": proxy_config,
+        "general_settings": {},
+        "version": "0.0.0",
+    }
+
+    new_data = await add_litellm_data_to_request(**data)
+    print("NEW DATA: {}".format(new_data))
+
+    assert "gcs_bucket_name" in new_data
+    assert new_data["gcs_bucket_name"] == "key-logging-project1"
+    assert "gcs_path_service_account" in new_data
+    assert (
+        new_data["gcs_path_service_account"] == "adroit-crow-413218-a956eef1a2a8.json"
+    )
+
+    if expected_success_callbacks:
+        assert "success_callback" in new_data
+        assert new_data["success_callback"] == expected_success_callbacks
+
+    if expected_failure_callbacks:
+        assert "failure_callback" in new_data
+        assert new_data["failure_callback"] == expected_failure_callbacks
+
+
 @pytest.mark.asyncio
 async def test_gemini_pass_through_endpoint():
     from starlette.datastructures import URL