(feat) GCS Bucket logging. Allow using IAM auth for logging to GCS (#6628)

* fix gcs bucket auth

* allow iam auth for gcs logging

* test_get_gcs_logging_config_without_service_account
This commit is contained in:
Ishaan Jaff 2024-11-06 17:14:56 -08:00 committed by GitHub
parent 0f8cceb274
commit 0ca50d56a8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 73 additions and 19 deletions

View file

@ -26,10 +26,13 @@ else:
VertexBase = Any
IAM_AUTH_KEY = "IAM_AUTH"
class GCSLoggingConfig(TypedDict):
bucket_name: str
vertex_instance: VertexBase
path_service_account: str
path_service_account: Optional[str]
class GCSBucketLogger(GCSBucketBase):
@ -173,7 +176,7 @@ class GCSBucketLogger(GCSBucketBase):
)
bucket_name: str
path_service_account: str
path_service_account: Optional[str]
if standard_callback_dynamic_params is not None:
verbose_logger.debug("Using dynamic GCS logging")
verbose_logger.debug(
@ -193,10 +196,6 @@ class GCSBucketLogger(GCSBucketBase):
raise ValueError(
"GCS_BUCKET_NAME is not set in the environment, but GCS Bucket is being used as a logging callback. Please set 'GCS_BUCKET_NAME' in the environment."
)
if _path_service_account is None:
raise ValueError(
"GCS_PATH_SERVICE_ACCOUNT is not set in the environment, but GCS Bucket is being used as a logging callback. Please set 'GCS_PATH_SERVICE_ACCOUNT' in the environment."
)
bucket_name = _bucket_name
path_service_account = _path_service_account
vertex_instance = await self.get_or_create_vertex_instance(
@ -208,10 +207,6 @@ class GCSBucketLogger(GCSBucketBase):
raise ValueError(
"GCS_BUCKET_NAME is not set in the environment, but GCS Bucket is being used as a logging callback. Please set 'GCS_BUCKET_NAME' in the environment."
)
if self.path_service_account_json is None:
raise ValueError(
"GCS_PATH_SERVICE_ACCOUNT is not set in the environment, but GCS Bucket is being used as a logging callback. Please set 'GCS_PATH_SERVICE_ACCOUNT' in the environment."
)
bucket_name = self.BUCKET_NAME
path_service_account = self.path_service_account_json
vertex_instance = await self.get_or_create_vertex_instance(
@ -224,7 +219,9 @@ class GCSBucketLogger(GCSBucketBase):
path_service_account=path_service_account,
)
async def get_or_create_vertex_instance(self, credentials: str) -> VertexBase:
async def get_or_create_vertex_instance(
self, credentials: Optional[str]
) -> VertexBase:
"""
This function is used to get the Vertex instance for the GCS Bucket Logger.
It checks if the Vertex instance is already created and cached, if not it creates a new instance and caches it.
@ -233,15 +230,27 @@ class GCSBucketLogger(GCSBucketBase):
VertexBase,
)
if credentials not in self.vertex_instances:
_in_memory_key = self._get_in_memory_key_for_vertex_instance(credentials)
if _in_memory_key not in self.vertex_instances:
vertex_instance = VertexBase()
await vertex_instance._ensure_access_token_async(
credentials=credentials,
project_id=None,
custom_llm_provider="vertex_ai",
)
self.vertex_instances[credentials] = vertex_instance
return self.vertex_instances[credentials]
self.vertex_instances[_in_memory_key] = vertex_instance
return self.vertex_instances[_in_memory_key]
def _get_in_memory_key_for_vertex_instance(self, credentials: Optional[str]) -> str:
"""
Returns key to use for caching the Vertex instance in-memory.
When using Vertex with Key based logging, we need to cache the Vertex instance in-memory.
- If a credentials string is provided, it is used as the key.
- If no credentials string is provided, "IAM_AUTH" is used as the key.
"""
return credentials or IAM_AUTH_KEY
async def download_gcs_object(self, object_name: str, **kwargs):
"""

View file

@ -33,7 +33,7 @@ class GCSBucketBase(CustomLogger):
async def construct_request_headers(
self,
service_account_json: str,
service_account_json: Optional[str],
vertex_instance: Optional[VertexBase] = None,
) -> Dict[str, str]:
from litellm import vertex_chat_completion

View file

@ -5,7 +5,5 @@ model_list:
api_key: os.environ/OPENAI_API_KEY
api_base: https://exampleopenaiendpoint-production.up.railway.app/
general_settings:
alerting: ["slack"]
alerting_threshold: 0.001
litellm_settings:
callbacks: ["gcs_bucket"]

View file

@ -519,3 +519,50 @@ async def test_basic_gcs_logging_per_request_with_no_litellm_callback_set():
object_name=gcs_log_id,
standard_callback_dynamic_params=standard_callback_dynamic_params,
)
@pytest.mark.asyncio
async def test_get_gcs_logging_config_without_service_account():
"""
Test the get_gcs_logging_config works for IAM auth on GCS
1. Key based logging without a service account
2. Default Callback without a service account
"""
# Mock the load_auth function to avoid credential loading issues
# Test 1: With standard_callback_dynamic_params (with service account)
gcs_logger = GCSBucketLogger()
dynamic_params = StandardCallbackDynamicParams(
gcs_bucket_name="dynamic-bucket",
)
config = await gcs_logger.get_gcs_logging_config(
{"standard_callback_dynamic_params": dynamic_params}
)
assert config["bucket_name"] == "dynamic-bucket"
assert config["path_service_account"] is None
assert config["vertex_instance"] is not None
# Test 2: With standard_callback_dynamic_params (without service account - this is IAM auth)
dynamic_params = StandardCallbackDynamicParams(
gcs_bucket_name="dynamic-bucket", gcs_path_service_account=None
)
config = await gcs_logger.get_gcs_logging_config(
{"standard_callback_dynamic_params": dynamic_params}
)
assert config["bucket_name"] == "dynamic-bucket"
assert config["path_service_account"] is None
assert config["vertex_instance"] is not None
# Test 5: With missing bucket name
with pytest.raises(ValueError, match="GCS_BUCKET_NAME is not set"):
_old_gcs_bucket_name = os.environ.get("GCS_BUCKET_NAME")
os.environ.pop("GCS_BUCKET_NAME")
gcs_logger = GCSBucketLogger(bucket_name=None)
await gcs_logger.get_gcs_logging_config({})
if _old_gcs_bucket_name is not None:
os.environ["GCS_BUCKET_NAME"] = _old_gcs_bucket_name