diff --git a/docs/my-website/docs/proxy/team_logging.md b/docs/my-website/docs/proxy/team_logging.md index fb177da76..ed5d28af4 100644 --- a/docs/my-website/docs/proxy/team_logging.md +++ b/docs/my-website/docs/proxy/team_logging.md @@ -201,6 +201,9 @@ Use the `/key/generate` or `/key/update` endpoints to add logging callbacks to a ::: + + + ```bash curl -X POST 'http://0.0.0.0:4000/key/generate' \ -H 'Authorization: Bearer sk-1234' \ @@ -208,7 +211,7 @@ curl -X POST 'http://0.0.0.0:4000/key/generate' \ -d '{ "metadata": { "logging": [{ - "callback_name": "langfuse", # "otel", "langfuse", "lunary" + "callback_name": "langfuse", # "otel", "gcs_bucket" "callback_type": "success", # "success", "failure", "success_and_failure" "callback_vars": { "langfuse_public_key": "os.environ/LANGFUSE_PUBLIC_KEY", # [RECOMMENDED] reference key in proxy environment @@ -223,6 +226,30 @@ curl -X POST 'http://0.0.0.0:4000/key/generate' \ + + + +```bash +curl -X POST 'http://0.0.0.0:4000/key/generate' \ +-H 'Authorization: Bearer sk-1234' \ +-H 'Content-Type: application/json' \ +-d '{ + "metadata": { + "logging": [{ + "callback_name": "gcs_bucket", # "otel", "gcs_bucket" + "callback_type": "success", # "success", "failure", "success_and_failure" + "callback_vars": { + "gcs_bucket_name": "my-gcs-bucket", + "gcs_path_service_account": "os.environ/GCS_SERVICE_ACCOUNT" + } + }] + } +}' + +``` + + + --- diff --git a/litellm/integrations/gcs_bucket.py b/litellm/integrations/gcs_bucket.py index d33268bff..dea12025b 100644 --- a/litellm/integrations/gcs_bucket.py +++ b/litellm/integrations/gcs_bucket.py @@ -2,7 +2,8 @@ import json import os import uuid from datetime import datetime -from typing import Any, Dict, List, Optional, TypedDict, Union +from re import S +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, TypedDict, Union import httpx from pydantic import BaseModel, Field @@ -16,13 +17,22 @@ from litellm.litellm_core_utils.logging_utils import ( ) from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler from litellm.proxy._types import CommonProxyErrors, SpendLogsMetadata, SpendLogsPayload -from litellm.types.utils import StandardLoggingMetadata, StandardLoggingPayload +from litellm.types.utils import ( + StandardCallbackDynamicParams, + StandardLoggingMetadata, + StandardLoggingPayload, +) + +if TYPE_CHECKING: + from litellm.llms.vertex_ai_and_google_ai_studio.vertex_llm_base import VertexBase +else: + VertexBase = Any -class RequestKwargs(TypedDict): - model: Optional[str] - messages: Optional[List] - optional_params: Optional[Dict[str, Any]] +class GCSLoggingConfig(TypedDict): + bucket_name: str + vertex_instance: VertexBase + path_service_account: str class GCSBucketLogger(GCSBucketBase): @@ -30,6 +40,7 @@ class GCSBucketLogger(GCSBucketBase): from litellm.proxy.proxy_server import premium_user super().__init__(bucket_name=bucket_name) + self.vertex_instances: Dict[str, VertexBase] = {} if premium_user is not True: raise ValueError( f"GCS Bucket logging is a premium feature. Please upgrade to use it. {CommonProxyErrors.not_premium_user.value}" @@ -55,10 +66,14 @@ class GCSBucketLogger(GCSBucketBase): kwargs, response_obj, ) - - start_time.strftime("%Y-%m-%d %H:%M:%S") - end_time.strftime("%Y-%m-%d %H:%M:%S") - headers = await self.construct_request_headers() + gcs_logging_config: GCSLoggingConfig = await self.get_gcs_logging_config( + kwargs + ) + headers = await self.construct_request_headers( + vertex_instance=gcs_logging_config["vertex_instance"], + service_account_json=gcs_logging_config["path_service_account"], + ) + bucket_name = gcs_logging_config["bucket_name"] logging_payload: Optional[StandardLoggingPayload] = kwargs.get( "standard_logging_object", None @@ -76,7 +91,7 @@ class GCSBucketLogger(GCSBucketBase): object_name = f"{current_date}/{response_obj['id']}" response = await self.async_httpx_client.post( headers=headers, - url=f"https://storage.googleapis.com/upload/storage/v1/b/{self.BUCKET_NAME}/o?uploadType=media&name={object_name}", + url=f"https://storage.googleapis.com/upload/storage/v1/b/{bucket_name}/o?uploadType=media&name={object_name}", data=json_logged_payload, ) @@ -87,7 +102,7 @@ class GCSBucketLogger(GCSBucketBase): verbose_logger.debug("GCS Bucket status code %s", response.status_code) verbose_logger.debug("GCS Bucket response.text %s", response.text) except Exception as e: - verbose_logger.error("GCS Bucket logging error: %s", str(e)) + verbose_logger.exception(f"GCS Bucket logging error: {str(e)}") async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time): from litellm.proxy.proxy_server import premium_user @@ -103,9 +118,14 @@ class GCSBucketLogger(GCSBucketBase): response_obj, ) - start_time.strftime("%Y-%m-%d %H:%M:%S") - end_time.strftime("%Y-%m-%d %H:%M:%S") - headers = await self.construct_request_headers() + gcs_logging_config: GCSLoggingConfig = await self.get_gcs_logging_config( + kwargs + ) + headers = await self.construct_request_headers( + vertex_instance=gcs_logging_config["vertex_instance"], + service_account_json=gcs_logging_config["path_service_account"], + ) + bucket_name = gcs_logging_config["bucket_name"] logging_payload: Optional[StandardLoggingPayload] = kwargs.get( "standard_logging_object", None @@ -130,7 +150,7 @@ class GCSBucketLogger(GCSBucketBase): response = await self.async_httpx_client.post( headers=headers, - url=f"https://storage.googleapis.com/upload/storage/v1/b/{self.BUCKET_NAME}/o?uploadType=media&name={object_name}", + url=f"https://storage.googleapis.com/upload/storage/v1/b/{bucket_name}/o?uploadType=media&name={object_name}", data=json_logged_payload, ) @@ -141,4 +161,146 @@ class GCSBucketLogger(GCSBucketBase): verbose_logger.debug("GCS Bucket status code %s", response.status_code) verbose_logger.debug("GCS Bucket response.text %s", response.text) except Exception as e: - verbose_logger.error("GCS Bucket logging error: %s", str(e)) + verbose_logger.exception(f"GCS Bucket logging error: {str(e)}") + + async def get_gcs_logging_config( + self, kwargs: Optional[Dict[str, Any]] = {} + ) -> GCSLoggingConfig: + """ + This function is used to get the GCS logging config for the GCS Bucket Logger. + It checks if the dynamic parameters are provided in the kwargs and uses them to get the GCS logging config. + If no dynamic parameters are provided, it uses the default values. + """ + if kwargs is None: + kwargs = {} + + standard_callback_dynamic_params: Optional[StandardCallbackDynamicParams] = ( + kwargs.get("standard_callback_dynamic_params", None) + ) + + if standard_callback_dynamic_params is not None: + verbose_logger.debug("Using dynamic GCS logging") + verbose_logger.debug( + "standard_callback_dynamic_params: %s", standard_callback_dynamic_params + ) + + bucket_name: str = ( + standard_callback_dynamic_params.get("gcs_bucket_name", None) + or self.BUCKET_NAME + ) + path_service_account: str = ( + standard_callback_dynamic_params.get("gcs_path_service_account", None) + or self.path_service_account_json + ) + + vertex_instance = await self.get_or_create_vertex_instance( + credentials=path_service_account + ) + else: + # If no dynamic parameters, use the default instance + bucket_name = self.BUCKET_NAME + path_service_account = self.path_service_account_json + vertex_instance = await self.get_or_create_vertex_instance( + credentials=path_service_account + ) + + return GCSLoggingConfig( + bucket_name=bucket_name, + vertex_instance=vertex_instance, + path_service_account=path_service_account, + ) + + async def get_or_create_vertex_instance(self, credentials: str) -> VertexBase: + """ + This function is used to get the Vertex instance for the GCS Bucket Logger. + It checks if the Vertex instance is already created and cached, if not it creates a new instance and caches it. + """ + from litellm.llms.vertex_ai_and_google_ai_studio.vertex_llm_base import ( + VertexBase, + ) + + if credentials not in self.vertex_instances: + vertex_instance = VertexBase() + await vertex_instance._ensure_access_token_async( + credentials=credentials, + project_id=None, + custom_llm_provider="vertex_ai", + ) + self.vertex_instances[credentials] = vertex_instance + return self.vertex_instances[credentials] + + async def download_gcs_object(self, object_name: str, **kwargs): + """ + Download an object from GCS. + + https://cloud.google.com/storage/docs/downloading-objects#download-object-json + """ + try: + gcs_logging_config: GCSLoggingConfig = await self.get_gcs_logging_config( + kwargs=kwargs + ) + headers = await self.construct_request_headers( + vertex_instance=gcs_logging_config["vertex_instance"], + service_account_json=gcs_logging_config["path_service_account"], + ) + bucket_name = gcs_logging_config["bucket_name"] + url = f"https://storage.googleapis.com/storage/v1/b/{bucket_name}/o/{object_name}?alt=media" + + # Send the GET request to download the object + response = await self.async_httpx_client.get(url=url, headers=headers) + + if response.status_code != 200: + verbose_logger.error( + "GCS object download error: %s", str(response.text) + ) + return None + + verbose_logger.debug( + "GCS object download response status code: %s", response.status_code + ) + + # Return the content of the downloaded object + return response.content + + except Exception as e: + verbose_logger.error("GCS object download error: %s", str(e)) + return None + + async def delete_gcs_object(self, object_name: str, **kwargs): + """ + Delete an object from GCS. + """ + try: + gcs_logging_config: GCSLoggingConfig = await self.get_gcs_logging_config( + kwargs=kwargs + ) + headers = await self.construct_request_headers( + vertex_instance=gcs_logging_config["vertex_instance"], + service_account_json=gcs_logging_config["path_service_account"], + ) + bucket_name = gcs_logging_config["bucket_name"] + url = f"https://storage.googleapis.com/storage/v1/b/{bucket_name}/o/{object_name}" + + # Send the DELETE request to delete the object + response = await self.async_httpx_client.delete(url=url, headers=headers) + + if (response.status_code != 200) or (response.status_code != 204): + verbose_logger.error( + "GCS object delete error: %s, status code: %s", + str(response.text), + response.status_code, + ) + return None + + verbose_logger.debug( + "GCS object delete response status code: %s, response: %s", + response.status_code, + response.text, + ) + + # Return the content of the downloaded object + return response.text + + except Exception as e: + verbose_logger.error("GCS object download error: %s", str(e)) + return None diff --git a/litellm/integrations/gcs_bucket_base.py b/litellm/integrations/gcs_bucket_base.py index 165686ed6..1d1672c4d 100644 --- a/litellm/integrations/gcs_bucket_base.py +++ b/litellm/integrations/gcs_bucket_base.py @@ -2,7 +2,7 @@ import json import os import uuid from datetime import datetime -from typing import Any, Dict, List, Optional, TypedDict, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, TypedDict, Union import httpx from pydantic import BaseModel, Field @@ -18,37 +18,48 @@ from litellm.llms.custom_httpx.http_handler import ( httpxSpecialProvider, ) +if TYPE_CHECKING: + from litellm.llms.vertex_ai_and_google_ai_studio.vertex_llm_base import VertexBase +else: + VertexBase = Any + class GCSBucketBase(CustomLogger): def __init__(self, bucket_name: Optional[str] = None) -> None: - from litellm.proxy.proxy_server import premium_user - self.async_httpx_client = get_async_httpx_client( llm_provider=httpxSpecialProvider.LoggingCallback ) - self.path_service_account_json = os.getenv("GCS_PATH_SERVICE_ACCOUNT", None) - self.BUCKET_NAME = bucket_name or os.getenv("GCS_BUCKET_NAME", None) - - if self.BUCKET_NAME is None: + _path_service_account = os.getenv("GCS_PATH_SERVICE_ACCOUNT") + _bucket_name = bucket_name or os.getenv("GCS_BUCKET_NAME") + if _path_service_account is None: + raise ValueError("GCS_PATH_SERVICE_ACCOUNT environment variable is not set") + if _bucket_name is None: raise ValueError( "GCS_BUCKET_NAME is not set in the environment, but GCS Bucket is being used as a logging callback. Please set 'GCS_BUCKET_NAME' in the environment." ) + self.path_service_account_json: str = _path_service_account + self.BUCKET_NAME: str = _bucket_name - async def construct_request_headers(self) -> Dict[str, str]: + async def construct_request_headers( + self, + service_account_json: str, + vertex_instance: Optional[VertexBase] = None, + ) -> Dict[str, str]: from litellm import vertex_chat_completion - _auth_header, vertex_project = ( - await vertex_chat_completion._ensure_access_token_async( - credentials=self.path_service_account_json, - project_id=None, - custom_llm_provider="vertex_ai", - ) + if vertex_instance is None: + vertex_instance = vertex_chat_completion + + _auth_header, vertex_project = await vertex_instance._ensure_access_token_async( + credentials=service_account_json, + project_id=None, + custom_llm_provider="vertex_ai", ) - auth_header, _ = vertex_chat_completion._get_token_and_url( + auth_header, _ = vertex_instance._get_token_and_url( model="gcs-bucket", auth_header=_auth_header, - vertex_credentials=self.path_service_account_json, + vertex_credentials=service_account_json, vertex_project=vertex_project, vertex_location=None, gemini_api_key=None, @@ -91,65 +102,3 @@ class GCSBucketBase(CustomLogger): } return headers - - async def download_gcs_object(self, object_name): - """ - Download an object from GCS. - - https://cloud.google.com/storage/docs/downloading-objects#download-object-json - """ - try: - headers = await self.construct_request_headers() - url = f"https://storage.googleapis.com/storage/v1/b/{self.BUCKET_NAME}/o/{object_name}?alt=media" - - # Send the GET request to download the object - response = await self.async_httpx_client.get(url=url, headers=headers) - - if response.status_code != 200: - verbose_logger.error( - "GCS object download error: %s", str(response.text) - ) - return None - - verbose_logger.debug( - "GCS object download response status code: %s", response.status_code - ) - - # Return the content of the downloaded object - return response.content - - except Exception as e: - verbose_logger.error("GCS object download error: %s", str(e)) - return None - - async def delete_gcs_object(self, object_name): - """ - Delete an object from GCS. - """ - try: - headers = await self.construct_request_headers() - url = f"https://storage.googleapis.com/storage/v1/b/{self.BUCKET_NAME}/o/{object_name}" - - # Send the DELETE request to delete the object - response = await self.async_httpx_client.delete(url=url, headers=headers) - - if (response.status_code != 200) or (response.status_code != 204): - verbose_logger.error( - "GCS object delete error: %s, status code: %s", - str(response.text), - response.status_code, - ) - return None - - verbose_logger.debug( - "GCS object delete response status code: %s, response: %s", - response.status_code, - response.text, - ) - - # Return the content of the downloaded object - return response.text - - except Exception as e: - verbose_logger.error("GCS object download error: %s", str(e)) - return None diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py index f3af2dcbd..3702004f5 100644 --- a/litellm/litellm_core_utils/litellm_logging.py +++ b/litellm/litellm_core_utils/litellm_logging.py @@ -40,6 +40,7 @@ from litellm.types.utils import ( EmbeddingResponse, ImageResponse, ModelResponse, + StandardCallbackDynamicParams, StandardLoggingHiddenParams, StandardLoggingMetadata, StandardLoggingModelCostFailureDebugInformation, @@ -200,9 +201,7 @@ class Logging: dynamic_success_callbacks=None, dynamic_failure_callbacks=None, dynamic_async_success_callbacks=None, - langfuse_public_key=None, - langfuse_secret=None, - langfuse_host=None, + kwargs: Optional[Dict] = None, ): if messages is not None: if isinstance(messages, str): @@ -225,10 +224,14 @@ class Logging: self.call_type = call_type self.litellm_call_id = litellm_call_id self.function_id = function_id - self.streaming_chunks = [] # for generating complete stream response - self.sync_streaming_chunks = [] # for generating complete stream response - self.model_call_details = {} - self.dynamic_input_callbacks = [] # [TODO] callbacks set for just that call + self.streaming_chunks: List[Any] = [] # for generating complete stream response + self.sync_streaming_chunks: List[Any] = ( + [] + ) # for generating complete stream response + self.model_call_details: Dict[Any, Any] = {} + self.dynamic_input_callbacks: List[Any] = ( + [] + ) # [TODO] callbacks set for just that call self.dynamic_failure_callbacks = dynamic_failure_callbacks self.dynamic_success_callbacks = ( dynamic_success_callbacks # callbacks set for just that call @@ -236,13 +239,27 @@ class Logging: self.dynamic_async_success_callbacks = ( dynamic_async_success_callbacks # callbacks set for just that call ) - ## DYNAMIC LANGFUSE KEYS ## - self.langfuse_public_key = langfuse_public_key - self.langfuse_secret = langfuse_secret - self.langfuse_host = langfuse_host + ## DYNAMIC LANGFUSE / GCS / logging callback KEYS ## + self.standard_callback_dynamic_params: StandardCallbackDynamicParams = ( + self.initialize_standard_callback_dynamic_params(kwargs) + ) ## TIME TO FIRST TOKEN LOGGING ## + self.completion_start_time: Optional[datetime.datetime] = None + def initialize_standard_callback_dynamic_params( + self, kwargs: Optional[Dict] = None + ) -> StandardCallbackDynamicParams: + standard_callback_dynamic_params = StandardCallbackDynamicParams() + if kwargs: + _supported_callback_params = ( + StandardCallbackDynamicParams.__annotations__.keys() + ) + for param in _supported_callback_params: + if param in kwargs: + standard_callback_dynamic_params[param] = kwargs.pop(param) # type: ignore + return standard_callback_dynamic_params + def update_environment_variables( self, model, user, optional_params, litellm_params, **additional_params ): @@ -264,6 +281,7 @@ class Logging: "call_type": str(self.call_type), "litellm_call_id": self.litellm_call_id, "completion_start_time": self.completion_start_time, + "standard_callback_dynamic_params": self.standard_callback_dynamic_params, **self.optional_params, **additional_params, } @@ -999,23 +1017,46 @@ class Logging: temp_langfuse_logger = langFuseLogger if langFuseLogger is None or ( ( - self.langfuse_public_key is not None - and self.langfuse_public_key + self.standard_callback_dynamic_params.get( + "langfuse_public_key" + ) + is not None + and self.standard_callback_dynamic_params.get( + "langfuse_public_key" + ) != langFuseLogger.public_key ) or ( - self.langfuse_secret is not None - and self.langfuse_secret != langFuseLogger.secret_key + self.standard_callback_dynamic_params.get( + "langfuse_secret" + ) + is not None + and self.standard_callback_dynamic_params.get( + "langfuse_secret" + ) + != langFuseLogger.secret_key ) or ( - self.langfuse_host is not None - and self.langfuse_host != langFuseLogger.langfuse_host + self.standard_callback_dynamic_params.get( + "langfuse_host" + ) + is not None + and self.standard_callback_dynamic_params.get( + "langfuse_host" + ) + != langFuseLogger.langfuse_host ) ): credentials = { - "langfuse_public_key": self.langfuse_public_key, - "langfuse_secret": self.langfuse_secret, - "langfuse_host": self.langfuse_host, + "langfuse_public_key": self.standard_callback_dynamic_params.get( + "langfuse_public_key" + ), + "langfuse_secret": self.standard_callback_dynamic_params.get( + "langfuse_secret" + ), + "langfuse_host": self.standard_callback_dynamic_params.get( + "langfuse_host" + ), } temp_langfuse_logger = ( in_memory_dynamic_logger_cache.get_cache( @@ -1024,9 +1065,15 @@ class Logging: ) if temp_langfuse_logger is None: temp_langfuse_logger = LangFuseLogger( - langfuse_public_key=self.langfuse_public_key, - langfuse_secret=self.langfuse_secret, - langfuse_host=self.langfuse_host, + langfuse_public_key=self.standard_callback_dynamic_params.get( + "langfuse_public_key" + ), + langfuse_secret=self.standard_callback_dynamic_params.get( + "langfuse_secret" + ), + langfuse_host=self.standard_callback_dynamic_params.get( + "langfuse_host" + ), ) in_memory_dynamic_logger_cache.set_cache( credentials=credentials, @@ -1838,24 +1885,46 @@ class Logging: # this only logs streaming once, complete_streaming_response exists i.e when stream ends if langFuseLogger is None or ( ( - self.langfuse_public_key is not None - and self.langfuse_public_key + self.standard_callback_dynamic_params.get( + "langfuse_public_key" + ) + is not None + and self.standard_callback_dynamic_params.get( + "langfuse_public_key" + ) != langFuseLogger.public_key ) or ( - self.langfuse_public_key is not None - and self.langfuse_public_key + self.standard_callback_dynamic_params.get( + "langfuse_public_key" + ) + is not None + and self.standard_callback_dynamic_params.get( + "langfuse_public_key" + ) != langFuseLogger.public_key ) or ( - self.langfuse_host is not None - and self.langfuse_host != langFuseLogger.langfuse_host + self.standard_callback_dynamic_params.get( + "langfuse_host" + ) + is not None + and self.standard_callback_dynamic_params.get( + "langfuse_host" + ) + != langFuseLogger.langfuse_host ) ): langFuseLogger = LangFuseLogger( - langfuse_public_key=self.langfuse_public_key, - langfuse_secret=self.langfuse_secret, - langfuse_host=self.langfuse_host, + langfuse_public_key=self.standard_callback_dynamic_params.get( + "langfuse_public_key" + ), + langfuse_secret=self.standard_callback_dynamic_params.get( + "langfuse_secret" + ), + langfuse_host=self.standard_callback_dynamic_params.get( + "langfuse_host" + ), ) _response = langFuseLogger.log_event( start_time=start_time, @@ -1992,22 +2061,34 @@ class Logging: if service_name == "langfuse": if langFuseLogger is None or ( ( - self.langfuse_public_key is not None - and self.langfuse_public_key != langFuseLogger.public_key + self.standard_callback_dynamic_params.get("langfuse_public_key") + is not None + and self.standard_callback_dynamic_params.get("langfuse_public_key") + != langFuseLogger.public_key ) or ( - self.langfuse_public_key is not None - and self.langfuse_public_key != langFuseLogger.public_key + self.standard_callback_dynamic_params.get("langfuse_public_key") + is not None + and self.standard_callback_dynamic_params.get("langfuse_public_key") + != langFuseLogger.public_key ) or ( - self.langfuse_host is not None - and self.langfuse_host != langFuseLogger.langfuse_host + self.standard_callback_dynamic_params.get("langfuse_host") + is not None + and self.standard_callback_dynamic_params.get("langfuse_host") + != langFuseLogger.langfuse_host ) ): return LangFuseLogger( - langfuse_public_key=self.langfuse_public_key, - langfuse_secret=self.langfuse_secret, - langfuse_host=self.langfuse_host, + langfuse_public_key=self.standard_callback_dynamic_params.get( + "langfuse_public_key" + ), + langfuse_secret=self.standard_callback_dynamic_params.get( + "langfuse_secret" + ), + langfuse_host=self.standard_callback_dynamic_params.get( + "langfuse_host" + ), ) return langFuseLogger diff --git a/litellm/types/utils.py b/litellm/types/utils.py index 0b2c34a28..b043ff098 100644 --- a/litellm/types/utils.py +++ b/litellm/types/utils.py @@ -1365,3 +1365,11 @@ OPENAI_RESPONSE_HEADERS = [ "x-ratelimit-reset-requests", "x-ratelimit-reset-tokens", ] + + +class StandardCallbackDynamicParams(TypedDict, total=False): + langfuse_public_key: Optional[str] + langfuse_secret: Optional[str] + langfuse_host: Optional[str] + gcs_bucket_name: Optional[str] + gcs_path_service_account: Optional[str] diff --git a/litellm/utils.py b/litellm/utils.py index b0417babd..b06fe3986 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -561,13 +561,11 @@ def function_setup( dynamic_success_callbacks=dynamic_success_callbacks, dynamic_failure_callbacks=dynamic_failure_callbacks, dynamic_async_success_callbacks=dynamic_async_success_callbacks, - langfuse_public_key=kwargs.pop("langfuse_public_key", None), - langfuse_secret=kwargs.pop("langfuse_secret", None) - or kwargs.pop("langfuse_secret_key", None), - langfuse_host=kwargs.pop("langfuse_host", None), + kwargs=kwargs, ) + ## check if metadata is passed in - litellm_params = {"api_base": ""} + litellm_params: Dict[str, Any] = {"api_base": ""} if "metadata" in kwargs: litellm_params["metadata"] = kwargs["metadata"] logging_obj.update_environment_variables( diff --git a/tests/local_testing/test_gcs_bucket.py b/tests/local_testing/test_gcs_bucket.py index bc4322855..2e0899fe4 100644 --- a/tests/local_testing/test_gcs_bucket.py +++ b/tests/local_testing/test_gcs_bucket.py @@ -17,6 +17,7 @@ import litellm from litellm import completion from litellm._logging import verbose_logger from litellm.integrations.gcs_bucket import GCSBucketLogger, StandardLoggingPayload +from litellm.types.utils import StandardCallbackDynamicParams verbose_logger.setLevel(logging.DEBUG) @@ -263,3 +264,130 @@ async def test_basic_gcs_logger_failure(): # Delete Object from GCS print("deleting object from GCS") await gcs_logger.delete_gcs_object(object_name=object_name) + + +@pytest.mark.asyncio +async def test_basic_gcs_logging_per_request(): + """ + Test GCS Bucket logging per request + + Request 1 - pass gcs_bucket_name in kwargs + Request 2 - don't pass gcs_bucket_name in kwargs - ensure 'litellm-testing-bucket' + """ + import logging + from litellm._logging import verbose_logger + + verbose_logger.setLevel(logging.DEBUG) + load_vertex_ai_credentials() + gcs_logger = GCSBucketLogger() + print("GCSBucketLogger", gcs_logger) + litellm.callbacks = [gcs_logger] + + GCS_BUCKET_NAME = "key-logging-project1" + standard_callback_dynamic_params: StandardCallbackDynamicParams = ( + StandardCallbackDynamicParams(gcs_bucket_name=GCS_BUCKET_NAME) + ) + + try: + response = await litellm.acompletion( + model="gpt-4o-mini", + temperature=0.7, + messages=[{"role": "user", "content": "This is a test"}], + max_tokens=10, + user="ishaan-2", + gcs_bucket_name=GCS_BUCKET_NAME, + ) + except: + pass + + await asyncio.sleep(5) + + # Get the current date + # Get the current date + current_date = datetime.now().strftime("%Y-%m-%d") + + # Modify the object_name to include the date-based folder + object_name = f"{current_date}%2F{response.id}" + + print("object_name", object_name) + + # Check if object landed on GCS + object_from_gcs = await gcs_logger.download_gcs_object( + object_name=object_name, + standard_callback_dynamic_params=standard_callback_dynamic_params, + ) + print("object from gcs=", object_from_gcs) + # convert object_from_gcs from bytes to DICT + parsed_data = json.loads(object_from_gcs) + print("object_from_gcs as dict", parsed_data) + + print("type of object_from_gcs", type(parsed_data)) + + gcs_payload = StandardLoggingPayload(**parsed_data) + + assert gcs_payload["model"] == "gpt-4o-mini" + assert gcs_payload["messages"] == [{"role": "user", "content": "This is a test"}] + + assert gcs_payload["response_cost"] > 0.0 + + assert gcs_payload["status"] == "success" + + # clean up the object from GCS + await gcs_logger.delete_gcs_object( + object_name=object_name, + standard_callback_dynamic_params=standard_callback_dynamic_params, + ) + + # Request 2 - don't pass gcs_bucket_name in kwargs - ensure 'litellm-testing-bucket' + try: + response = await litellm.acompletion( + model="gpt-4o-mini", + temperature=0.7, + messages=[{"role": "user", "content": "This is a test"}], + max_tokens=10, + user="ishaan-2", + mock_response="Hi!", + ) + except: + pass + + await asyncio.sleep(5) + + # Get the current date + # Get the current date + current_date = datetime.now().strftime("%Y-%m-%d") + standard_callback_dynamic_params = StandardCallbackDynamicParams( + gcs_bucket_name="litellm-testing-bucket" + ) + + # Modify the object_name to include the date-based folder + object_name = f"{current_date}%2F{response.id}" + + print("object_name", object_name) + + # Check if object landed on GCS + object_from_gcs = await gcs_logger.download_gcs_object( + object_name=object_name, + standard_callback_dynamic_params=standard_callback_dynamic_params, + ) + print("object from gcs=", object_from_gcs) + # convert object_from_gcs from bytes to DICT + parsed_data = json.loads(object_from_gcs) + print("object_from_gcs as dict", parsed_data) + + print("type of object_from_gcs", type(parsed_data)) + + gcs_payload = StandardLoggingPayload(**parsed_data) + + assert gcs_payload["model"] == "gpt-4o-mini" + assert gcs_payload["messages"] == [{"role": "user", "content": "This is a test"}] + + assert gcs_payload["response_cost"] > 0.0 + + assert gcs_payload["status"] == "success" + + # clean up the object from GCS + await gcs_logger.delete_gcs_object( + object_name=object_name, + standard_callback_dynamic_params=standard_callback_dynamic_params, + )