(feat proxy) add key based logging for GCS bucket (#6031)

* init litellm langfuse / gcs credentials in litellm logging obj

* add gcs key based test

* rename vars

* save standard_callback_dynamic_params in model call details

* add working gcs bucket key based logging

* test_basic_gcs_logging_per_request

* linting fix

* add doc on gcs bucket team based logging
This commit is contained in:
Ishaan Jaff 2024-10-03 02:54:31 -07:00 committed by GitHub
parent 835db6ae98
commit 21e05a0f3e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 495 additions and 142 deletions

View file

@ -201,6 +201,9 @@ Use the `/key/generate` or `/key/update` endpoints to add logging callbacks to a
:::
<Tabs>
<TabItem label="Langfuse" value="langfuse">
```bash
curl -X POST 'http://0.0.0.0:4000/key/generate' \
-H 'Authorization: Bearer sk-1234' \
@ -208,7 +211,7 @@ curl -X POST 'http://0.0.0.0:4000/key/generate' \
-d '{
"metadata": {
"logging": [{
"callback_name": "langfuse", # "otel", "langfuse", "lunary"
"callback_name": "langfuse", # "otel", "gcs_bucket"
"callback_type": "success", # "success", "failure", "success_and_failure"
"callback_vars": {
"langfuse_public_key": "os.environ/LANGFUSE_PUBLIC_KEY", # [RECOMMENDED] reference key in proxy environment
@ -223,6 +226,30 @@ curl -X POST 'http://0.0.0.0:4000/key/generate' \
<iframe width="840" height="500" src="https://www.youtube.com/embed/8iF0Hvwk0YU" frameborder="0" webkitallowfullscreen mozallowfullscreen allowfullscreen></iframe>
</TabItem>
<TabItem label="GCS Bucket" value="gcs_bucket">
```bash
curl -X POST 'http://0.0.0.0:4000/key/generate' \
-H 'Authorization: Bearer sk-1234' \
-H 'Content-Type: application/json' \
-d '{
"metadata": {
"logging": [{
"callback_name": "gcs_bucket", # "otel", "gcs_bucket"
"callback_type": "success", # "success", "failure", "success_and_failure"
"callback_vars": {
"gcs_bucket_name": "my-gcs-bucket",
"gcs_path_service_account": "os.environ/GCS_SERVICE_ACCOUNT"
}
}]
}
}'
```
</TabItem>
</Tabs>
---

View file

@ -2,7 +2,8 @@ import json
import os
import uuid
from datetime import datetime
from typing import Any, Dict, List, Optional, TypedDict, Union
from re import S
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, TypedDict, Union
import httpx
from pydantic import BaseModel, Field
@ -16,13 +17,22 @@ from litellm.litellm_core_utils.logging_utils import (
)
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
from litellm.proxy._types import CommonProxyErrors, SpendLogsMetadata, SpendLogsPayload
from litellm.types.utils import StandardLoggingMetadata, StandardLoggingPayload
from litellm.types.utils import (
StandardCallbackDynamicParams,
StandardLoggingMetadata,
StandardLoggingPayload,
)
if TYPE_CHECKING:
from litellm.llms.vertex_ai_and_google_ai_studio.vertex_llm_base import VertexBase
else:
VertexBase = Any
class RequestKwargs(TypedDict):
    # Subset of a request's kwargs captured for logging purposes.
    model: Optional[str]  # model name used for the call, if any
    messages: Optional[List]  # chat messages sent with the request
    optional_params: Optional[Dict[str, Any]]  # extra provider params
class GCSLoggingConfig(TypedDict):
    # Resolved per-request GCS logging configuration, produced by
    # get_gcs_logging_config (defaults + optional key/team overrides).
    bucket_name: str  # target GCS bucket for the log object
    vertex_instance: VertexBase  # authenticated client used to mint access tokens
    path_service_account: str  # service-account credential path/reference
class GCSBucketLogger(GCSBucketBase):
@ -30,6 +40,7 @@ class GCSBucketLogger(GCSBucketBase):
from litellm.proxy.proxy_server import premium_user
super().__init__(bucket_name=bucket_name)
self.vertex_instances: Dict[str, VertexBase] = {}
if premium_user is not True:
raise ValueError(
f"GCS Bucket logging is a premium feature. Please upgrade to use it. {CommonProxyErrors.not_premium_user.value}"
@ -55,10 +66,14 @@ class GCSBucketLogger(GCSBucketBase):
kwargs,
response_obj,
)
start_time.strftime("%Y-%m-%d %H:%M:%S")
end_time.strftime("%Y-%m-%d %H:%M:%S")
headers = await self.construct_request_headers()
gcs_logging_config: GCSLoggingConfig = await self.get_gcs_logging_config(
kwargs
)
headers = await self.construct_request_headers(
vertex_instance=gcs_logging_config["vertex_instance"],
service_account_json=gcs_logging_config["path_service_account"],
)
bucket_name = gcs_logging_config["bucket_name"]
logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
"standard_logging_object", None
@ -76,7 +91,7 @@ class GCSBucketLogger(GCSBucketBase):
object_name = f"{current_date}/{response_obj['id']}"
response = await self.async_httpx_client.post(
headers=headers,
url=f"https://storage.googleapis.com/upload/storage/v1/b/{self.BUCKET_NAME}/o?uploadType=media&name={object_name}",
url=f"https://storage.googleapis.com/upload/storage/v1/b/{bucket_name}/o?uploadType=media&name={object_name}",
data=json_logged_payload,
)
@ -87,7 +102,7 @@ class GCSBucketLogger(GCSBucketBase):
verbose_logger.debug("GCS Bucket status code %s", response.status_code)
verbose_logger.debug("GCS Bucket response.text %s", response.text)
except Exception as e:
verbose_logger.error("GCS Bucket logging error: %s", str(e))
verbose_logger.exception(f"GCS Bucket logging error: {str(e)}")
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
from litellm.proxy.proxy_server import premium_user
@ -103,9 +118,14 @@ class GCSBucketLogger(GCSBucketBase):
response_obj,
)
start_time.strftime("%Y-%m-%d %H:%M:%S")
end_time.strftime("%Y-%m-%d %H:%M:%S")
headers = await self.construct_request_headers()
gcs_logging_config: GCSLoggingConfig = await self.get_gcs_logging_config(
kwargs
)
headers = await self.construct_request_headers(
vertex_instance=gcs_logging_config["vertex_instance"],
service_account_json=gcs_logging_config["path_service_account"],
)
bucket_name = gcs_logging_config["bucket_name"]
logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
"standard_logging_object", None
@ -130,7 +150,7 @@ class GCSBucketLogger(GCSBucketBase):
response = await self.async_httpx_client.post(
headers=headers,
url=f"https://storage.googleapis.com/upload/storage/v1/b/{self.BUCKET_NAME}/o?uploadType=media&name={object_name}",
url=f"https://storage.googleapis.com/upload/storage/v1/b/{bucket_name}/o?uploadType=media&name={object_name}",
data=json_logged_payload,
)
@ -141,4 +161,146 @@ class GCSBucketLogger(GCSBucketBase):
verbose_logger.debug("GCS Bucket status code %s", response.status_code)
verbose_logger.debug("GCS Bucket response.text %s", response.text)
except Exception as e:
verbose_logger.error("GCS Bucket logging error: %s", str(e))
verbose_logger.exception(f"GCS Bucket logging error: {str(e)}")
async def get_gcs_logging_config(
    self, kwargs: Optional[Dict[str, Any]] = None
) -> GCSLoggingConfig:
    """
    Resolve the GCS logging config (bucket name, service account, vertex
    instance) for a single request.

    If ``standard_callback_dynamic_params`` is present in *kwargs* (key/team
    based logging), its ``gcs_bucket_name`` / ``gcs_path_service_account``
    override the logger-level defaults; otherwise the defaults configured on
    this instance are used.
    """
    # BUG FIX: default was a mutable `{}` shared across calls; the body
    # already handles None, so use None as the default instead.
    if kwargs is None:
        kwargs = {}

    standard_callback_dynamic_params: Optional[StandardCallbackDynamicParams] = (
        kwargs.get("standard_callback_dynamic_params", None)
    )

    # Start from the logger-level defaults, then apply per-request overrides.
    bucket_name: str = self.BUCKET_NAME
    path_service_account: str = self.path_service_account_json

    if standard_callback_dynamic_params is not None:
        verbose_logger.debug("Using dynamic GCS logging")
        verbose_logger.debug(
            "standard_callback_dynamic_params: %s", standard_callback_dynamic_params
        )
        bucket_name = (
            standard_callback_dynamic_params.get("gcs_bucket_name", None)
            or self.BUCKET_NAME
        )
        path_service_account = (
            standard_callback_dynamic_params.get("gcs_path_service_account", None)
            or self.path_service_account_json
        )

    # One cached vertex instance per service-account credential (hoisted out
    # of both branches -- it was duplicated in the original).
    vertex_instance = await self.get_or_create_vertex_instance(
        credentials=path_service_account
    )

    return GCSLoggingConfig(
        bucket_name=bucket_name,
        vertex_instance=vertex_instance,
        path_service_account=path_service_account,
    )
async def get_or_create_vertex_instance(self, credentials: str) -> VertexBase:
    """
    Return a VertexBase for *credentials*, creating and caching one on first use.

    Instances are cached per credential string in ``self.vertex_instances`` so
    the access token is only bootstrapped once per service account.
    """
    from litellm.llms.vertex_ai_and_google_ai_studio.vertex_llm_base import (
        VertexBase,
    )

    cached = self.vertex_instances.get(credentials)
    if cached is None:
        cached = VertexBase()
        # Prime the access token so later header construction is cheap.
        await cached._ensure_access_token_async(
            credentials=credentials,
            project_id=None,
            custom_llm_provider="vertex_ai",
        )
        self.vertex_instances[credentials] = cached
    return cached
async def download_gcs_object(self, object_name: str, **kwargs):
    """
    Download an object from GCS and return its raw bytes, or None on failure.

    https://cloud.google.com/storage/docs/downloading-objects#download-object-json
    """
    try:
        config: GCSLoggingConfig = await self.get_gcs_logging_config(kwargs=kwargs)
        bucket_name = config["bucket_name"]
        request_headers = await self.construct_request_headers(
            vertex_instance=config["vertex_instance"],
            service_account_json=config["path_service_account"],
        )

        # Send the GET request to download the object
        response = await self.async_httpx_client.get(
            url=f"https://storage.googleapis.com/storage/v1/b/{bucket_name}/o/{object_name}?alt=media",
            headers=request_headers,
        )

        if response.status_code != 200:
            verbose_logger.error("GCS object download error: %s", str(response.text))
            return None

        verbose_logger.debug(
            "GCS object download response status code: %s", response.status_code
        )
        # Return the content of the downloaded object
        return response.content
    except Exception as e:
        verbose_logger.error("GCS object download error: %s", str(e))
        return None
async def delete_gcs_object(self, object_name: str, **kwargs):
    """
    Delete an object from GCS.

    Returns the response text on success, or None on failure.
    """
    try:
        gcs_logging_config: GCSLoggingConfig = await self.get_gcs_logging_config(
            kwargs=kwargs
        )
        headers = await self.construct_request_headers(
            vertex_instance=gcs_logging_config["vertex_instance"],
            service_account_json=gcs_logging_config["path_service_account"],
        )
        bucket_name = gcs_logging_config["bucket_name"]
        url = f"https://storage.googleapis.com/storage/v1/b/{bucket_name}/o/{object_name}"

        # Send the DELETE request to delete the object
        response = await self.async_httpx_client.delete(url=url, headers=headers)

        # BUG FIX: was `(… != 200) or (… != 204)`, which is always true, so
        # every delete (even a successful 200/204) was reported as an error
        # and returned None.
        if response.status_code not in (200, 204):
            verbose_logger.error(
                "GCS object delete error: %s, status code: %s",
                str(response.text),
                response.status_code,
            )
            return None

        verbose_logger.debug(
            "GCS object delete response status code: %s, response: %s",
            response.status_code,
            response.text,
        )
        # Return the delete response body
        return response.text
    except Exception as e:
        # BUG FIX: the original logged "download error" in this delete path.
        verbose_logger.error("GCS object delete error: %s", str(e))
        return None

View file

@ -2,7 +2,7 @@ import json
import os
import uuid
from datetime import datetime
from typing import Any, Dict, List, Optional, TypedDict, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, TypedDict, Union
import httpx
from pydantic import BaseModel, Field
@ -18,37 +18,48 @@ from litellm.llms.custom_httpx.http_handler import (
httpxSpecialProvider,
)
if TYPE_CHECKING:
from litellm.llms.vertex_ai_and_google_ai_studio.vertex_llm_base import VertexBase
else:
VertexBase = Any
class GCSBucketBase(CustomLogger):
def __init__(self, bucket_name: Optional[str] = None) -> None:
from litellm.proxy.proxy_server import premium_user
self.async_httpx_client = get_async_httpx_client(
llm_provider=httpxSpecialProvider.LoggingCallback
)
self.path_service_account_json = os.getenv("GCS_PATH_SERVICE_ACCOUNT", None)
self.BUCKET_NAME = bucket_name or os.getenv("GCS_BUCKET_NAME", None)
if self.BUCKET_NAME is None:
_path_service_account = os.getenv("GCS_PATH_SERVICE_ACCOUNT")
_bucket_name = bucket_name or os.getenv("GCS_BUCKET_NAME")
if _path_service_account is None:
raise ValueError("GCS_PATH_SERVICE_ACCOUNT environment variable is not set")
if _bucket_name is None:
raise ValueError(
"GCS_BUCKET_NAME is not set in the environment, but GCS Bucket is being used as a logging callback. Please set 'GCS_BUCKET_NAME' in the environment."
)
self.path_service_account_json: str = _path_service_account
self.BUCKET_NAME: str = _bucket_name
async def construct_request_headers(self) -> Dict[str, str]:
async def construct_request_headers(
self,
service_account_json: str,
vertex_instance: Optional[VertexBase] = None,
) -> Dict[str, str]:
from litellm import vertex_chat_completion
_auth_header, vertex_project = (
await vertex_chat_completion._ensure_access_token_async(
credentials=self.path_service_account_json,
project_id=None,
custom_llm_provider="vertex_ai",
)
if vertex_instance is None:
vertex_instance = vertex_chat_completion
_auth_header, vertex_project = await vertex_instance._ensure_access_token_async(
credentials=service_account_json,
project_id=None,
custom_llm_provider="vertex_ai",
)
auth_header, _ = vertex_chat_completion._get_token_and_url(
auth_header, _ = vertex_instance._get_token_and_url(
model="gcs-bucket",
auth_header=_auth_header,
vertex_credentials=self.path_service_account_json,
vertex_credentials=service_account_json,
vertex_project=vertex_project,
vertex_location=None,
gemini_api_key=None,
@ -91,65 +102,3 @@ class GCSBucketBase(CustomLogger):
}
return headers
async def download_gcs_object(self, object_name):
    """
    Download an object from GCS.

    https://cloud.google.com/storage/docs/downloading-objects#download-object-json
    """
    try:
        # Uses the logger-wide BUCKET_NAME / service account; there is no
        # per-request override in this base-class variant.
        headers = await self.construct_request_headers()
        url = f"https://storage.googleapis.com/storage/v1/b/{self.BUCKET_NAME}/o/{object_name}?alt=media"

        # Send the GET request to download the object
        response = await self.async_httpx_client.get(url=url, headers=headers)

        # Any non-200 (including 404 for a missing object) is treated as a
        # soft failure: log and return None.
        if response.status_code != 200:
            verbose_logger.error(
                "GCS object download error: %s", str(response.text)
            )
            return None

        verbose_logger.debug(
            "GCS object download response status code: %s", response.status_code
        )

        # Return the content of the downloaded object
        return response.content
    except Exception as e:
        verbose_logger.error("GCS object download error: %s", str(e))
        return None
async def delete_gcs_object(self, object_name):
    """
    Delete an object from GCS.

    Returns the response text on success, or None on failure.
    """
    try:
        headers = await self.construct_request_headers()
        url = f"https://storage.googleapis.com/storage/v1/b/{self.BUCKET_NAME}/o/{object_name}"

        # Send the DELETE request to delete the object
        response = await self.async_httpx_client.delete(url=url, headers=headers)

        # BUG FIX: was `(… != 200) or (… != 204)`, which is always true, so
        # every delete (even a successful 200/204) was reported as an error
        # and returned None.
        if response.status_code not in (200, 204):
            verbose_logger.error(
                "GCS object delete error: %s, status code: %s",
                str(response.text),
                response.status_code,
            )
            return None

        verbose_logger.debug(
            "GCS object delete response status code: %s, response: %s",
            response.status_code,
            response.text,
        )
        # Return the delete response body
        return response.text
    except Exception as e:
        # BUG FIX: the original logged "download error" in this delete path.
        verbose_logger.error("GCS object delete error: %s", str(e))
        return None

View file

@ -40,6 +40,7 @@ from litellm.types.utils import (
EmbeddingResponse,
ImageResponse,
ModelResponse,
StandardCallbackDynamicParams,
StandardLoggingHiddenParams,
StandardLoggingMetadata,
StandardLoggingModelCostFailureDebugInformation,
@ -200,9 +201,7 @@ class Logging:
dynamic_success_callbacks=None,
dynamic_failure_callbacks=None,
dynamic_async_success_callbacks=None,
langfuse_public_key=None,
langfuse_secret=None,
langfuse_host=None,
kwargs: Optional[Dict] = None,
):
if messages is not None:
if isinstance(messages, str):
@ -225,10 +224,14 @@ class Logging:
self.call_type = call_type
self.litellm_call_id = litellm_call_id
self.function_id = function_id
self.streaming_chunks = [] # for generating complete stream response
self.sync_streaming_chunks = [] # for generating complete stream response
self.model_call_details = {}
self.dynamic_input_callbacks = [] # [TODO] callbacks set for just that call
self.streaming_chunks: List[Any] = [] # for generating complete stream response
self.sync_streaming_chunks: List[Any] = (
[]
) # for generating complete stream response
self.model_call_details: Dict[Any, Any] = {}
self.dynamic_input_callbacks: List[Any] = (
[]
) # [TODO] callbacks set for just that call
self.dynamic_failure_callbacks = dynamic_failure_callbacks
self.dynamic_success_callbacks = (
dynamic_success_callbacks # callbacks set for just that call
@ -236,13 +239,27 @@ class Logging:
self.dynamic_async_success_callbacks = (
dynamic_async_success_callbacks # callbacks set for just that call
)
## DYNAMIC LANGFUSE KEYS ##
self.langfuse_public_key = langfuse_public_key
self.langfuse_secret = langfuse_secret
self.langfuse_host = langfuse_host
## DYNAMIC LANGFUSE / GCS / logging callback KEYS ##
self.standard_callback_dynamic_params: StandardCallbackDynamicParams = (
self.initialize_standard_callback_dynamic_params(kwargs)
)
## TIME TO FIRST TOKEN LOGGING ##
self.completion_start_time: Optional[datetime.datetime] = None
def initialize_standard_callback_dynamic_params(
    self, kwargs: Optional[Dict] = None
) -> StandardCallbackDynamicParams:
    """
    Collect any supported dynamic-callback credentials out of *kwargs* into a
    StandardCallbackDynamicParams dict.

    NOTE: matching keys are *popped* from kwargs as a side effect, so they are
    not forwarded to the underlying completion call.
    """
    dynamic_params = StandardCallbackDynamicParams()
    if not kwargs:
        return dynamic_params

    for param_name in StandardCallbackDynamicParams.__annotations__.keys():
        if param_name in kwargs:
            dynamic_params[param_name] = kwargs.pop(param_name)  # type: ignore
    return dynamic_params
def update_environment_variables(
self, model, user, optional_params, litellm_params, **additional_params
):
@ -264,6 +281,7 @@ class Logging:
"call_type": str(self.call_type),
"litellm_call_id": self.litellm_call_id,
"completion_start_time": self.completion_start_time,
"standard_callback_dynamic_params": self.standard_callback_dynamic_params,
**self.optional_params,
**additional_params,
}
@ -999,23 +1017,46 @@ class Logging:
temp_langfuse_logger = langFuseLogger
if langFuseLogger is None or (
(
self.langfuse_public_key is not None
and self.langfuse_public_key
self.standard_callback_dynamic_params.get(
"langfuse_public_key"
)
is not None
and self.standard_callback_dynamic_params.get(
"langfuse_public_key"
)
!= langFuseLogger.public_key
)
or (
self.langfuse_secret is not None
and self.langfuse_secret != langFuseLogger.secret_key
self.standard_callback_dynamic_params.get(
"langfuse_secret"
)
is not None
and self.standard_callback_dynamic_params.get(
"langfuse_secret"
)
!= langFuseLogger.secret_key
)
or (
self.langfuse_host is not None
and self.langfuse_host != langFuseLogger.langfuse_host
self.standard_callback_dynamic_params.get(
"langfuse_host"
)
is not None
and self.standard_callback_dynamic_params.get(
"langfuse_host"
)
!= langFuseLogger.langfuse_host
)
):
credentials = {
"langfuse_public_key": self.langfuse_public_key,
"langfuse_secret": self.langfuse_secret,
"langfuse_host": self.langfuse_host,
"langfuse_public_key": self.standard_callback_dynamic_params.get(
"langfuse_public_key"
),
"langfuse_secret": self.standard_callback_dynamic_params.get(
"langfuse_secret"
),
"langfuse_host": self.standard_callback_dynamic_params.get(
"langfuse_host"
),
}
temp_langfuse_logger = (
in_memory_dynamic_logger_cache.get_cache(
@ -1024,9 +1065,15 @@ class Logging:
)
if temp_langfuse_logger is None:
temp_langfuse_logger = LangFuseLogger(
langfuse_public_key=self.langfuse_public_key,
langfuse_secret=self.langfuse_secret,
langfuse_host=self.langfuse_host,
langfuse_public_key=self.standard_callback_dynamic_params.get(
"langfuse_public_key"
),
langfuse_secret=self.standard_callback_dynamic_params.get(
"langfuse_secret"
),
langfuse_host=self.standard_callback_dynamic_params.get(
"langfuse_host"
),
)
in_memory_dynamic_logger_cache.set_cache(
credentials=credentials,
@ -1838,24 +1885,46 @@ class Logging:
# this only logs streaming once, complete_streaming_response exists i.e when stream ends
if langFuseLogger is None or (
(
self.langfuse_public_key is not None
and self.langfuse_public_key
self.standard_callback_dynamic_params.get(
"langfuse_public_key"
)
is not None
and self.standard_callback_dynamic_params.get(
"langfuse_public_key"
)
!= langFuseLogger.public_key
)
or (
self.langfuse_public_key is not None
and self.langfuse_public_key
self.standard_callback_dynamic_params.get(
"langfuse_public_key"
)
is not None
and self.standard_callback_dynamic_params.get(
"langfuse_public_key"
)
!= langFuseLogger.public_key
)
or (
self.langfuse_host is not None
and self.langfuse_host != langFuseLogger.langfuse_host
self.standard_callback_dynamic_params.get(
"langfuse_host"
)
is not None
and self.standard_callback_dynamic_params.get(
"langfuse_host"
)
!= langFuseLogger.langfuse_host
)
):
langFuseLogger = LangFuseLogger(
langfuse_public_key=self.langfuse_public_key,
langfuse_secret=self.langfuse_secret,
langfuse_host=self.langfuse_host,
langfuse_public_key=self.standard_callback_dynamic_params.get(
"langfuse_public_key"
),
langfuse_secret=self.standard_callback_dynamic_params.get(
"langfuse_secret"
),
langfuse_host=self.standard_callback_dynamic_params.get(
"langfuse_host"
),
)
_response = langFuseLogger.log_event(
start_time=start_time,
@ -1992,22 +2061,34 @@ class Logging:
if service_name == "langfuse":
if langFuseLogger is None or (
(
self.langfuse_public_key is not None
and self.langfuse_public_key != langFuseLogger.public_key
self.standard_callback_dynamic_params.get("langfuse_public_key")
is not None
and self.standard_callback_dynamic_params.get("langfuse_public_key")
!= langFuseLogger.public_key
)
or (
self.langfuse_public_key is not None
and self.langfuse_public_key != langFuseLogger.public_key
self.standard_callback_dynamic_params.get("langfuse_public_key")
is not None
and self.standard_callback_dynamic_params.get("langfuse_public_key")
!= langFuseLogger.public_key
)
or (
self.langfuse_host is not None
and self.langfuse_host != langFuseLogger.langfuse_host
self.standard_callback_dynamic_params.get("langfuse_host")
is not None
and self.standard_callback_dynamic_params.get("langfuse_host")
!= langFuseLogger.langfuse_host
)
):
return LangFuseLogger(
langfuse_public_key=self.langfuse_public_key,
langfuse_secret=self.langfuse_secret,
langfuse_host=self.langfuse_host,
langfuse_public_key=self.standard_callback_dynamic_params.get(
"langfuse_public_key"
),
langfuse_secret=self.standard_callback_dynamic_params.get(
"langfuse_secret"
),
langfuse_host=self.standard_callback_dynamic_params.get(
"langfuse_host"
),
)
return langFuseLogger

View file

@ -1365,3 +1365,11 @@ OPENAI_RESPONSE_HEADERS = [
"x-ratelimit-reset-requests",
"x-ratelimit-reset-tokens",
]
class StandardCallbackDynamicParams(TypedDict, total=False):
    # Per-request (key/team based) logging-callback credentials. All keys are
    # optional; a missing key falls back to the callback's static config.
    # Langfuse credentials
    langfuse_public_key: Optional[str]
    langfuse_secret: Optional[str]
    langfuse_host: Optional[str]
    # GCS bucket logging
    gcs_bucket_name: Optional[str]
    gcs_path_service_account: Optional[str]

View file

@ -561,13 +561,11 @@ def function_setup(
dynamic_success_callbacks=dynamic_success_callbacks,
dynamic_failure_callbacks=dynamic_failure_callbacks,
dynamic_async_success_callbacks=dynamic_async_success_callbacks,
langfuse_public_key=kwargs.pop("langfuse_public_key", None),
langfuse_secret=kwargs.pop("langfuse_secret", None)
or kwargs.pop("langfuse_secret_key", None),
langfuse_host=kwargs.pop("langfuse_host", None),
kwargs=kwargs,
)
## check if metadata is passed in
litellm_params = {"api_base": ""}
litellm_params: Dict[str, Any] = {"api_base": ""}
if "metadata" in kwargs:
litellm_params["metadata"] = kwargs["metadata"]
logging_obj.update_environment_variables(

View file

@ -17,6 +17,7 @@ import litellm
from litellm import completion
from litellm._logging import verbose_logger
from litellm.integrations.gcs_bucket import GCSBucketLogger, StandardLoggingPayload
from litellm.types.utils import StandardCallbackDynamicParams
verbose_logger.setLevel(logging.DEBUG)
@ -263,3 +264,130 @@ async def test_basic_gcs_logger_failure():
# Delete Object from GCS
print("deleting object from GCS")
await gcs_logger.delete_gcs_object(object_name=object_name)
async def _assert_gcs_payload_and_cleanup(
    gcs_logger, response, standard_callback_dynamic_params
):
    """Download the logged payload for *response* from GCS, validate it, delete it."""
    # Log objects are stored under a date-based folder: <YYYY-MM-DD>/<response id>
    current_date = datetime.now().strftime("%Y-%m-%d")
    object_name = f"{current_date}%2F{response.id}"
    print("object_name", object_name)

    # Check if object landed on GCS
    object_from_gcs = await gcs_logger.download_gcs_object(
        object_name=object_name,
        standard_callback_dynamic_params=standard_callback_dynamic_params,
    )
    print("object from gcs=", object_from_gcs)
    # convert object_from_gcs from bytes to DICT
    parsed_data = json.loads(object_from_gcs)
    print("object_from_gcs as dict", parsed_data)
    print("type of object_from_gcs", type(parsed_data))

    gcs_payload = StandardLoggingPayload(**parsed_data)
    assert gcs_payload["model"] == "gpt-4o-mini"
    assert gcs_payload["messages"] == [{"role": "user", "content": "This is a test"}]
    assert gcs_payload["response_cost"] > 0.0
    assert gcs_payload["status"] == "success"

    # clean up the object from GCS
    await gcs_logger.delete_gcs_object(
        object_name=object_name,
        standard_callback_dynamic_params=standard_callback_dynamic_params,
    )


@pytest.mark.asyncio
async def test_basic_gcs_logging_per_request():
    """
    Test GCS Bucket logging per request

    Request 1 - pass gcs_bucket_name in kwargs
    Request 2 - don't pass gcs_bucket_name in kwargs - ensure 'litellm-testing-bucket'
    """
    import logging

    from litellm._logging import verbose_logger

    verbose_logger.setLevel(logging.DEBUG)
    load_vertex_ai_credentials()
    gcs_logger = GCSBucketLogger()
    print("GCSBucketLogger", gcs_logger)
    litellm.callbacks = [gcs_logger]

    GCS_BUCKET_NAME = "key-logging-project1"
    standard_callback_dynamic_params: StandardCallbackDynamicParams = (
        StandardCallbackDynamicParams(gcs_bucket_name=GCS_BUCKET_NAME)
    )

    # Request 1 - pass gcs_bucket_name in kwargs
    try:
        response = await litellm.acompletion(
            model="gpt-4o-mini",
            temperature=0.7,
            messages=[{"role": "user", "content": "This is a test"}],
            max_tokens=10,
            user="ishaan-2",
            gcs_bucket_name=GCS_BUCKET_NAME,
        )
    except Exception:
        # best-effort: a failed completion should still be logged to GCS
        # (BUG FIX: was a bare `except:` which also swallowed KeyboardInterrupt)
        pass

    # give the async success callback time to upload the payload
    await asyncio.sleep(5)
    await _assert_gcs_payload_and_cleanup(
        gcs_logger=gcs_logger,
        response=response,
        standard_callback_dynamic_params=standard_callback_dynamic_params,
    )

    # Request 2 - don't pass gcs_bucket_name in kwargs - ensure 'litellm-testing-bucket'
    try:
        response = await litellm.acompletion(
            model="gpt-4o-mini",
            temperature=0.7,
            messages=[{"role": "user", "content": "This is a test"}],
            max_tokens=10,
            user="ishaan-2",
            mock_response="Hi!",
        )
    except Exception:
        pass

    await asyncio.sleep(5)
    standard_callback_dynamic_params = StandardCallbackDynamicParams(
        gcs_bucket_name="litellm-testing-bucket"
    )
    await _assert_gcs_payload_and_cleanup(
        gcs_logger=gcs_logger,
        response=response,
        standard_callback_dynamic_params=standard_callback_dynamic_params,
    )