From 670ecda4e211d773d89a66c533da43fd5d969ae8 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Fri, 4 Oct 2024 11:56:10 +0530 Subject: [PATCH] (fixes) gcs bucket key based logging (#6044) * fixes for gcs bucket logging * fix StandardCallbackDynamicParams * fix - gcs logging when payload is not serializable * add test_add_callback_via_key_litellm_pre_call_utils_gcs_bucket * working success callbacks * linting fixes * fix linting error * add type hints to functions * fixes for dynamic success and failure logging * fix for test_async_chat_openai_stream --- litellm/integrations/gcs_bucket.py | 4 +- litellm/litellm_core_utils/litellm_logging.py | 159 +++++++++++++++--- litellm/proxy/_types.py | 37 ++-- .../pass_through_endpoints/success_handler.py | 2 +- litellm/proxy/proxy_config.yaml | 3 +- litellm/types/utils.py | 4 + litellm/utils.py | 17 +- tests/local_testing/test_gcs_bucket.py | 127 +++++++++++++- tests/local_testing/test_proxy_server.py | 132 +++++++++++++++ 9 files changed, 446 insertions(+), 39 deletions(-) diff --git a/litellm/integrations/gcs_bucket.py b/litellm/integrations/gcs_bucket.py index dea12025b..4d82bd56b 100644 --- a/litellm/integrations/gcs_bucket.py +++ b/litellm/integrations/gcs_bucket.py @@ -82,7 +82,7 @@ class GCSBucketLogger(GCSBucketBase): if logging_payload is None: raise ValueError("standard_logging_object not found in kwargs") - json_logged_payload = json.dumps(logging_payload) + json_logged_payload = json.dumps(logging_payload, default=str) # Get the current date current_date = datetime.now().strftime("%Y-%m-%d") @@ -137,7 +137,7 @@ class GCSBucketLogger(GCSBucketBase): _litellm_params = kwargs.get("litellm_params") or {} metadata = _litellm_params.get("metadata") or {} - json_logged_payload = json.dumps(logging_payload) + json_logged_payload = json.dumps(logging_payload, default=str) # Get the current date current_date = datetime.now().strftime("%Y-%m-%d") diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py index c22df85af..2e874a487 100644 --- a/litellm/litellm_core_utils/litellm_logging.py +++ b/litellm/litellm_core_utils/litellm_logging.py @@ -192,16 +192,28 @@ class Logging: def __init__( self, - model, + model: str, messages, stream, call_type, start_time, - litellm_call_id, - function_id, - dynamic_success_callbacks=None, - dynamic_failure_callbacks=None, - dynamic_async_success_callbacks=None, + litellm_call_id: str, + function_id: str, + dynamic_input_callbacks: Optional[ + List[Union[str, Callable, CustomLogger]] + ] = None, + dynamic_success_callbacks: Optional[ + List[Union[str, Callable, CustomLogger]] + ] = None, + dynamic_async_success_callbacks: Optional[ + List[Union[str, Callable, CustomLogger]] + ] = None, + dynamic_failure_callbacks: Optional[ + List[Union[str, Callable, CustomLogger]] + ] = None, + dynamic_async_failure_callbacks: Optional[ + List[Union[str, Callable, CustomLogger]] + ] = None, kwargs: Optional[Dict] = None, ): if messages is not None: @@ -230,27 +242,117 @@ class Logging: [] ) # for generating complete stream response self.model_call_details: Dict[Any, Any] = {} - self.dynamic_input_callbacks: List[Any] = ( - [] - ) # [TODO] callbacks set for just that call - self.dynamic_failure_callbacks = dynamic_failure_callbacks - self.dynamic_success_callbacks = ( - dynamic_success_callbacks # callbacks set for just that call - ) - self.dynamic_async_success_callbacks = ( - dynamic_async_success_callbacks # callbacks set for just that call - ) + + # Initialize dynamic callbacks + self.dynamic_input_callbacks: Optional[ + List[Union[str, Callable, CustomLogger]] + ] = dynamic_input_callbacks + self.dynamic_success_callbacks: Optional[ + List[Union[str, Callable, CustomLogger]] + ] = dynamic_success_callbacks + self.dynamic_async_success_callbacks: Optional[ + List[Union[str, Callable, CustomLogger]] + ] = dynamic_async_success_callbacks + self.dynamic_failure_callbacks: Optional[ + List[Union[str, Callable, CustomLogger]] + ] = dynamic_failure_callbacks + self.dynamic_async_failure_callbacks: Optional[ + List[Union[str, Callable, CustomLogger]] + ] = dynamic_async_failure_callbacks + + # Process dynamic callbacks + self.process_dynamic_callbacks() + ## DYNAMIC LANGFUSE / GCS / logging callback KEYS ## self.standard_callback_dynamic_params: StandardCallbackDynamicParams = ( self.initialize_standard_callback_dynamic_params(kwargs) ) - ## TIME TO FIRST TOKEN LOGGING ## + ## TIME TO FIRST TOKEN LOGGING ## self.completion_start_time: Optional[datetime.datetime] = None + def process_dynamic_callbacks(self): + """ + Initializes CustomLogger compatible callbacks in self.dynamic_* callbacks + + If a callback is in litellm._known_custom_logger_compatible_callbacks, it needs to be intialized and added to the respective dynamic_* callback list. + """ + # Process input callbacks + self.dynamic_input_callbacks = self._process_dynamic_callback_list( + self.dynamic_input_callbacks, dynamic_callbacks_type="input" + ) + + # Process failure callbacks + self.dynamic_failure_callbacks = self._process_dynamic_callback_list( + self.dynamic_failure_callbacks, dynamic_callbacks_type="failure" + ) + + # Process async failure callbacks + self.dynamic_async_failure_callbacks = self._process_dynamic_callback_list( + self.dynamic_async_failure_callbacks, dynamic_callbacks_type="async_failure" + ) + + # Process success callbacks + self.dynamic_success_callbacks = self._process_dynamic_callback_list( + self.dynamic_success_callbacks, dynamic_callbacks_type="success" + ) + + # Process async success callbacks + self.dynamic_async_success_callbacks = self._process_dynamic_callback_list( + self.dynamic_async_success_callbacks, dynamic_callbacks_type="async_success" + ) + + def _process_dynamic_callback_list( + self, + callback_list: Optional[List[Union[str, Callable, CustomLogger]]], + dynamic_callbacks_type: Literal[ + "input", "success", "failure", "async_success", "async_failure" + ], + ) -> Optional[List[Union[str, Callable, CustomLogger]]]: + """ + Helper function to initialize CustomLogger compatible callbacks in self.dynamic_* callbacks + + - If a callback is in litellm._known_custom_logger_compatible_callbacks, + replace the string with the initialized callback class. + - If dynamic callback is a "success" callback that is a known_custom_logger_compatible_callbacks then add it to dynamic_async_success_callbacks + - If dynamic callback is a "failure" callback that is a known_custom_logger_compatible_callbacks then add it to dynamic_failure_callbacks + """ + if callback_list is None: + return None + + processed_list: List[Union[str, Callable, CustomLogger]] = [] + for callback in callback_list: + if ( + isinstance(callback, str) + and callback in litellm._known_custom_logger_compatible_callbacks + ): + callback_class = _init_custom_logger_compatible_class( + callback, internal_usage_cache=None, llm_router=None # type: ignore + ) + if callback_class is not None: + processed_list.append(callback_class) + + # If processing dynamic_success_callbacks, add to dynamic_async_success_callbacks + if dynamic_callbacks_type == "success": + if self.dynamic_async_success_callbacks is None: + self.dynamic_async_success_callbacks = [] + self.dynamic_async_success_callbacks.append(callback_class) + elif dynamic_callbacks_type == "failure": + if self.dynamic_async_failure_callbacks is None: + self.dynamic_async_failure_callbacks = [] + self.dynamic_async_failure_callbacks.append(callback_class) + else: + processed_list.append(callback) + return processed_list + def initialize_standard_callback_dynamic_params( self, kwargs: Optional[Dict] = None ) -> StandardCallbackDynamicParams: + """ + Initialize the standard callback dynamic params from the kwargs + + checks if langfuse_secret_key, gcs_bucket_name in kwargs and sets the corresponding attributes in StandardCallbackDynamicParams + """ standard_callback_dynamic_params = StandardCallbackDynamicParams() if kwargs: _supported_callback_params = ( @@ -413,7 +515,7 @@ class Logging: self.model_call_details["api_call_start_time"] = datetime.datetime.now() # Input Integration Logging -> If you want to log the fact that an attempt to call the model was made - callbacks = litellm.input_callback + self.dynamic_input_callbacks + callbacks = litellm.input_callback + (self.dynamic_input_callbacks or []) for callback in callbacks: try: if callback == "supabase" and supabaseClient is not None: @@ -529,7 +631,7 @@ class Logging: ) # Input Integration Logging -> If you want to log the fact that an attempt to call the model was made - callbacks = litellm.input_callback + self.dynamic_input_callbacks + callbacks = litellm.input_callback + (self.dynamic_input_callbacks or []) for callback in callbacks: try: if callback == "sentry" and add_breadcrumb: @@ -2004,8 +2106,25 @@ class Logging: start_time=start_time, end_time=end_time, ) + + callbacks = [] # init this to empty incase it's not created + + if self.dynamic_async_failure_callbacks is not None and isinstance( + self.dynamic_async_failure_callbacks, list + ): + callbacks = self.dynamic_async_failure_callbacks + ## keep the internal functions ## + for callback in litellm._async_failure_callback: + if ( + isinstance(callback, CustomLogger) + and "_PROXY_" in callback.__class__.__name__ + ): + callbacks.append(callback) + else: + callbacks = litellm._async_failure_callback + result = None # result sent to all loggers, init this to None incase it's not created - for callback in litellm._async_failure_callback: + for callback in callbacks: try: if isinstance(callback, CustomLogger): # custom logger class await callback.async_log_failure_event( diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 8f01bfbea..1b4996856 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -12,7 +12,7 @@ from typing_extensions import Annotated, TypedDict from litellm.integrations.SlackAlerting.types import AlertType from litellm.types.router import RouterErrors, UpdateRouterConfig -from litellm.types.utils import ProviderField +from litellm.types.utils import ProviderField, StandardCallbackDynamicParams if TYPE_CHECKING: from opentelemetry.trace import Span as _Span @@ -959,21 +959,38 @@ class BlockKeyRequest(LiteLLMBase): class AddTeamCallback(LiteLLMBase): callback_name: str callback_type: Literal["success", "failure", "success_and_failure"] - # for now - only supported for langfuse - callback_vars: Dict[ - Literal["langfuse_public_key", "langfuse_secret_key", "langfuse_host"], str - ] + callback_vars: Dict[str, str] + + @model_validator(mode="before") + @classmethod + def validate_callback_vars(cls, values): + callback_vars = values.get("callback_vars", {}) + valid_keys = set(StandardCallbackDynamicParams.__annotations__.keys()) + for key in callback_vars: + if key not in valid_keys: + raise ValueError( + f"Invalid callback variable: {key}. Must be one of {valid_keys}" + ) + return values class TeamCallbackMetadata(LiteLLMBase): success_callback: Optional[List[str]] = [] failure_callback: Optional[List[str]] = [] # for now - only supported for langfuse - callback_vars: Optional[ - Dict[ - Literal["langfuse_public_key", "langfuse_secret_key", "langfuse_host"], str - ] - ] = {} + callback_vars: Optional[Dict[str, str]] = {} + + @model_validator(mode="before") + @classmethod + def validate_callback_vars(cls, values): + callback_vars = values.get("callback_vars", {}) + valid_keys = set(StandardCallbackDynamicParams.__annotations__.keys()) + for key in callback_vars: + if key not in valid_keys: + raise ValueError( + f"Invalid callback variable: {key}. Must be one of {valid_keys}" + ) + return values class LiteLLM_TeamTable(TeamBase): diff --git a/litellm/proxy/pass_through_endpoints/success_handler.py b/litellm/proxy/pass_through_endpoints/success_handler.py index 45ba10f1c..4cfaf490f 100644 --- a/litellm/proxy/pass_through_endpoints/success_handler.py +++ b/litellm/proxy/pass_through_endpoints/success_handler.py @@ -115,7 +115,7 @@ class PassThroughEndpointLogging: encoding=None, ) ) - logging_obj.model = litellm_model_response.model + logging_obj.model = litellm_model_response.model or model logging_obj.model_call_details["model"] = logging_obj.model await logging_obj.async_success_handler( diff --git a/litellm/proxy/proxy_config.yaml b/litellm/proxy/proxy_config.yaml index c37e433de..d9b275ae8 100644 --- a/litellm/proxy/proxy_config.yaml +++ b/litellm/proxy/proxy_config.yaml @@ -5,7 +5,6 @@ model_list: api_key: fake-key api_base: https://exampleopenaiendpoint-production.up.railway.app/ -general_settings: - alerting: ["slack"] + diff --git a/litellm/types/utils.py b/litellm/types/utils.py index c8adf0bdc..429d8cf6e 100644 --- a/litellm/types/utils.py +++ b/litellm/types/utils.py @@ -1384,8 +1384,12 @@ OPENAI_RESPONSE_HEADERS = [ class StandardCallbackDynamicParams(TypedDict, total=False): + # Langfuse dynamic params langfuse_public_key: Optional[str] langfuse_secret: Optional[str] + langfuse_secret_key: Optional[str] langfuse_host: Optional[str] + + # GCS dynamic params gcs_bucket_name: Optional[str] gcs_path_service_account: Optional[str] diff --git a/litellm/utils.py b/litellm/utils.py index 7fcc6e6e1..e8dea5759 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -58,6 +58,7 @@ import litellm.litellm_core_utils import litellm.litellm_core_utils.audio_utils.utils import litellm.litellm_core_utils.json_validation_rule from litellm.caching import DualCache +from litellm.integrations.custom_logger import CustomLogger from litellm.litellm_core_utils.core_helpers import map_finish_reason from litellm.litellm_core_utils.exception_mapping_utils import ( _get_litellm_response_headers, @@ -430,9 +431,18 @@ def function_setup( for index in reversed(removed_async_items): litellm.failure_callback.pop(index) ### DYNAMIC CALLBACKS ### - dynamic_success_callbacks = None - dynamic_async_success_callbacks = None - dynamic_failure_callbacks = None + dynamic_success_callbacks: Optional[ + List[Union[str, Callable, CustomLogger]] + ] = None + dynamic_async_success_callbacks: Optional[ + List[Union[str, Callable, CustomLogger]] + ] = None + dynamic_failure_callbacks: Optional[ + List[Union[str, Callable, CustomLogger]] + ] = None + dynamic_async_failure_callbacks: Optional[ + List[Union[str, Callable, CustomLogger]] + ] = None if kwargs.get("success_callback", None) is not None and isinstance( kwargs["success_callback"], list ): @@ -561,6 +571,7 @@ def function_setup( dynamic_success_callbacks=dynamic_success_callbacks, dynamic_failure_callbacks=dynamic_failure_callbacks, dynamic_async_success_callbacks=dynamic_async_success_callbacks, + dynamic_async_failure_callbacks=dynamic_async_failure_callbacks, kwargs=kwargs, ) diff --git a/tests/local_testing/test_gcs_bucket.py b/tests/local_testing/test_gcs_bucket.py index 2e0899fe4..711f45459 100644 --- a/tests/local_testing/test_gcs_bucket.py +++ b/tests/local_testing/test_gcs_bucket.py @@ -267,7 +267,7 @@ async def test_basic_gcs_logger_failure(): @pytest.mark.asyncio -async def test_basic_gcs_logging_per_request(): +async def test_basic_gcs_logging_per_request_with_callback_set(): """ Test GCS Bucket logging per request @@ -391,3 +391,128 @@ async def test_basic_gcs_logging_per_request(): object_name=object_name, standard_callback_dynamic_params=standard_callback_dynamic_params, ) + + +@pytest.mark.asyncio +async def test_basic_gcs_logging_per_request_with_no_litellm_callback_set(): + """ + Test GCS Bucket logging per request + + key difference: no litellm.callbacks set + + Request 1 - pass gcs_bucket_name in kwargs + Request 2 - don't pass gcs_bucket_name in kwargs - ensure 'litellm-testing-bucket' + """ + import logging + from litellm._logging import verbose_logger + + verbose_logger.setLevel(logging.DEBUG) + load_vertex_ai_credentials() + gcs_logger = GCSBucketLogger() + + GCS_BUCKET_NAME = "key-logging-project1" + standard_callback_dynamic_params: StandardCallbackDynamicParams = ( + StandardCallbackDynamicParams(gcs_bucket_name=GCS_BUCKET_NAME) + ) + + try: + response = await litellm.acompletion( + model="gpt-4o-mini", + temperature=0.7, + messages=[{"role": "user", "content": "This is a test"}], + max_tokens=10, + user="ishaan-2", + gcs_bucket_name=GCS_BUCKET_NAME, + success_callback=["gcs_bucket"], + failure_callback=["gcs_bucket"], + ) + except: + pass + + await asyncio.sleep(5) + + # Get the current date + # Get the current date + current_date = datetime.now().strftime("%Y-%m-%d") + + # Modify the object_name to include the date-based folder + object_name = f"{current_date}%2F{response.id}" + + print("object_name", object_name) + + # Check if object landed on GCS + object_from_gcs = await gcs_logger.download_gcs_object( + object_name=object_name, + standard_callback_dynamic_params=standard_callback_dynamic_params, + ) + print("object from gcs=", object_from_gcs) + # convert object_from_gcs from bytes to DICT + parsed_data = json.loads(object_from_gcs) + print("object_from_gcs as dict", parsed_data) + + print("type of object_from_gcs", type(parsed_data)) + + gcs_payload = StandardLoggingPayload(**parsed_data) + + assert gcs_payload["model"] == "gpt-4o-mini" + assert gcs_payload["messages"] == [{"role": "user", "content": "This is a test"}] + + assert gcs_payload["response_cost"] > 0.0 + + assert gcs_payload["status"] == "success" + + # clean up the object from GCS + await gcs_logger.delete_gcs_object( + object_name=object_name, + standard_callback_dynamic_params=standard_callback_dynamic_params, + ) + + # make a failure request - assert that failure callback is hit + gcs_log_id = f"failure-test-{uuid.uuid4().hex}" + try: + response = await litellm.acompletion( + model="gpt-4o-mini", + temperature=0.7, + messages=[{"role": "user", "content": "This is a test"}], + max_tokens=10, + user="ishaan-2", + mock_response=litellm.BadRequestError( + model="gpt-3.5-turbo", + message="Error: 400: Bad Request: Invalid API key, please check your API key and try again.", + llm_provider="openai", + ), + success_callback=["gcs_bucket"], + failure_callback=["gcs_bucket"], + gcs_bucket_name=GCS_BUCKET_NAME, + metadata={ + "gcs_log_id": gcs_log_id, + }, + ) + except: + pass + + await asyncio.sleep(5) + + # check if the failure object is logged in GCS + object_from_gcs = await gcs_logger.download_gcs_object( + object_name=gcs_log_id, + standard_callback_dynamic_params=standard_callback_dynamic_params, + ) + print("object from gcs=", object_from_gcs) + # convert object_from_gcs from bytes to DICT + parsed_data = json.loads(object_from_gcs) + print("object_from_gcs as dict", parsed_data) + + gcs_payload = StandardLoggingPayload(**parsed_data) + + assert gcs_payload["model"] == "gpt-4o-mini" + assert gcs_payload["messages"] == [{"role": "user", "content": "This is a test"}] + + assert gcs_payload["response_cost"] == 0 + assert gcs_payload["status"] == "failure" + + # clean up the object from GCS + await gcs_logger.delete_gcs_object( + object_name=gcs_log_id, + standard_callback_dynamic_params=standard_callback_dynamic_params, + ) diff --git a/tests/local_testing/test_proxy_server.py b/tests/local_testing/test_proxy_server.py index 5bca1136f..98b5f058d 100644 --- a/tests/local_testing/test_proxy_server.py +++ b/tests/local_testing/test_proxy_server.py @@ -1389,6 +1389,138 @@ async def test_add_callback_via_key_litellm_pre_call_utils( assert new_data["failure_callback"] == expected_failure_callbacks +@pytest.mark.asyncio +@pytest.mark.parametrize( + "callback_type, expected_success_callbacks, expected_failure_callbacks", + [ + ("success", ["gcs_bucket"], []), + ("failure", [], ["gcs_bucket"]), + ("success_and_failure", ["gcs_bucket"], ["gcs_bucket"]), + ], +) +async def test_add_callback_via_key_litellm_pre_call_utils_gcs_bucket( + prisma_client, callback_type, expected_success_callbacks, expected_failure_callbacks +): + import json + + from fastapi import HTTPException, Request, Response + from starlette.datastructures import URL + + from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request + + setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) + setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") + await litellm.proxy.proxy_server.prisma_client.connect() + + proxy_config = getattr(litellm.proxy.proxy_server, "proxy_config") + + request = Request(scope={"type": "http", "method": "POST", "headers": {}}) + request._url = URL(url="/chat/completions") + + test_data = { + "model": "azure/chatgpt-v-2", + "messages": [ + {"role": "user", "content": "write 1 sentence poem"}, + ], + "max_tokens": 10, + "mock_response": "Hello world", + "api_key": "my-fake-key", + } + + json_bytes = json.dumps(test_data).encode("utf-8") + + request._body = json_bytes + + data = { + "data": { + "model": "azure/chatgpt-v-2", + "messages": [{"role": "user", "content": "write 1 sentence poem"}], + "max_tokens": 10, + "mock_response": "Hello world", + "api_key": "my-fake-key", + }, + "request": request, + "user_api_key_dict": UserAPIKeyAuth( + token=None, + key_name=None, + key_alias=None, + spend=0.0, + max_budget=None, + expires=None, + models=[], + aliases={}, + config={}, + user_id=None, + team_id=None, + max_parallel_requests=None, + metadata={ + "logging": [ + { + "callback_name": "gcs_bucket", + "callback_type": callback_type, + "callback_vars": { + "gcs_bucket_name": "key-logging-project1", + "gcs_path_service_account": "adroit-crow-413218-a956eef1a2a8.json", + }, + } + ] + }, + tpm_limit=None, + rpm_limit=None, + budget_duration=None, + budget_reset_at=None, + allowed_cache_controls=[], + permissions={}, + model_spend={}, + model_max_budget={}, + soft_budget_cooldown=False, + litellm_budget_table=None, + org_id=None, + team_spend=None, + team_alias=None, + team_tpm_limit=None, + team_rpm_limit=None, + team_max_budget=None, + team_models=[], + team_blocked=False, + soft_budget=None, + team_model_aliases=None, + team_member_spend=None, + team_metadata=None, + end_user_id=None, + end_user_tpm_limit=None, + end_user_rpm_limit=None, + end_user_max_budget=None, + last_refreshed_at=None, + api_key=None, + user_role=None, + allowed_model_region=None, + parent_otel_span=None, + ), + "proxy_config": proxy_config, + "general_settings": {}, + "version": "0.0.0", + } + + new_data = await add_litellm_data_to_request(**data) + print("NEW DATA: {}".format(new_data)) + + assert "gcs_bucket_name" in new_data + assert new_data["gcs_bucket_name"] == "key-logging-project1" + assert "gcs_path_service_account" in new_data + assert ( + new_data["gcs_path_service_account"] == "adroit-crow-413218-a956eef1a2a8.json" + ) + + if expected_success_callbacks: + assert "success_callback" in new_data + assert new_data["success_callback"] == expected_success_callbacks + + if expected_failure_callbacks: + assert "failure_callback" in new_data + assert new_data["failure_callback"] == expected_failure_callbacks + + @pytest.mark.asyncio async def test_gemini_pass_through_endpoint(): from starlette.datastructures import URL