(feat) GCS Bucket logging. Allow using IAM auth for logging to GCS (#6628)

* fix gcs bucket auth

* allow iam auth for gcs logging

* test_get_gcs_logging_config_without_service_account
This commit is contained in:
Ishaan Jaff 2024-11-06 17:14:56 -08:00 committed by GitHub
parent 0f8cceb274
commit 0ca50d56a8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 73 additions and 19 deletions

View file

@ -26,10 +26,13 @@ else:
VertexBase = Any VertexBase = Any
IAM_AUTH_KEY = "IAM_AUTH"
class GCSLoggingConfig(TypedDict): class GCSLoggingConfig(TypedDict):
bucket_name: str bucket_name: str
vertex_instance: VertexBase vertex_instance: VertexBase
path_service_account: str path_service_account: Optional[str]
class GCSBucketLogger(GCSBucketBase): class GCSBucketLogger(GCSBucketBase):
@ -173,7 +176,7 @@ class GCSBucketLogger(GCSBucketBase):
) )
bucket_name: str bucket_name: str
path_service_account: str path_service_account: Optional[str]
if standard_callback_dynamic_params is not None: if standard_callback_dynamic_params is not None:
verbose_logger.debug("Using dynamic GCS logging") verbose_logger.debug("Using dynamic GCS logging")
verbose_logger.debug( verbose_logger.debug(
@ -193,10 +196,6 @@ class GCSBucketLogger(GCSBucketBase):
raise ValueError( raise ValueError(
"GCS_BUCKET_NAME is not set in the environment, but GCS Bucket is being used as a logging callback. Please set 'GCS_BUCKET_NAME' in the environment." "GCS_BUCKET_NAME is not set in the environment, but GCS Bucket is being used as a logging callback. Please set 'GCS_BUCKET_NAME' in the environment."
) )
if _path_service_account is None:
raise ValueError(
"GCS_PATH_SERVICE_ACCOUNT is not set in the environment, but GCS Bucket is being used as a logging callback. Please set 'GCS_PATH_SERVICE_ACCOUNT' in the environment."
)
bucket_name = _bucket_name bucket_name = _bucket_name
path_service_account = _path_service_account path_service_account = _path_service_account
vertex_instance = await self.get_or_create_vertex_instance( vertex_instance = await self.get_or_create_vertex_instance(
@ -208,10 +207,6 @@ class GCSBucketLogger(GCSBucketBase):
raise ValueError( raise ValueError(
"GCS_BUCKET_NAME is not set in the environment, but GCS Bucket is being used as a logging callback. Please set 'GCS_BUCKET_NAME' in the environment." "GCS_BUCKET_NAME is not set in the environment, but GCS Bucket is being used as a logging callback. Please set 'GCS_BUCKET_NAME' in the environment."
) )
if self.path_service_account_json is None:
raise ValueError(
"GCS_PATH_SERVICE_ACCOUNT is not set in the environment, but GCS Bucket is being used as a logging callback. Please set 'GCS_PATH_SERVICE_ACCOUNT' in the environment."
)
bucket_name = self.BUCKET_NAME bucket_name = self.BUCKET_NAME
path_service_account = self.path_service_account_json path_service_account = self.path_service_account_json
vertex_instance = await self.get_or_create_vertex_instance( vertex_instance = await self.get_or_create_vertex_instance(
@ -224,7 +219,9 @@ class GCSBucketLogger(GCSBucketBase):
path_service_account=path_service_account, path_service_account=path_service_account,
) )
async def get_or_create_vertex_instance(self, credentials: str) -> VertexBase: async def get_or_create_vertex_instance(
self, credentials: Optional[str]
) -> VertexBase:
""" """
This function is used to get the Vertex instance for the GCS Bucket Logger. This function is used to get the Vertex instance for the GCS Bucket Logger.
It checks if the Vertex instance is already created and cached, if not it creates a new instance and caches it. It checks if the Vertex instance is already created and cached, if not it creates a new instance and caches it.
@ -233,15 +230,27 @@ class GCSBucketLogger(GCSBucketBase):
VertexBase, VertexBase,
) )
if credentials not in self.vertex_instances: _in_memory_key = self._get_in_memory_key_for_vertex_instance(credentials)
if _in_memory_key not in self.vertex_instances:
vertex_instance = VertexBase() vertex_instance = VertexBase()
await vertex_instance._ensure_access_token_async( await vertex_instance._ensure_access_token_async(
credentials=credentials, credentials=credentials,
project_id=None, project_id=None,
custom_llm_provider="vertex_ai", custom_llm_provider="vertex_ai",
) )
self.vertex_instances[credentials] = vertex_instance self.vertex_instances[_in_memory_key] = vertex_instance
return self.vertex_instances[credentials] return self.vertex_instances[_in_memory_key]
def _get_in_memory_key_for_vertex_instance(self, credentials: Optional[str]) -> str:
    """
    Pick the key under which a Vertex instance is cached in-memory.

    Key based logging can supply different credentials, so each distinct
    credentials string gets its own cache entry.
    - A non-empty credentials string is returned unchanged and used as the key.
    - Missing (None) or empty credentials fall back to IAM_AUTH_KEY ("IAM_AUTH").
    """
    if credentials:
        return credentials
    return IAM_AUTH_KEY
async def download_gcs_object(self, object_name: str, **kwargs): async def download_gcs_object(self, object_name: str, **kwargs):
""" """

View file

@ -33,7 +33,7 @@ class GCSBucketBase(CustomLogger):
async def construct_request_headers( async def construct_request_headers(
self, self,
service_account_json: str, service_account_json: Optional[str],
vertex_instance: Optional[VertexBase] = None, vertex_instance: Optional[VertexBase] = None,
) -> Dict[str, str]: ) -> Dict[str, str]:
from litellm import vertex_chat_completion from litellm import vertex_chat_completion

View file

@ -5,7 +5,5 @@ model_list:
api_key: os.environ/OPENAI_API_KEY api_key: os.environ/OPENAI_API_KEY
api_base: https://exampleopenaiendpoint-production.up.railway.app/ api_base: https://exampleopenaiendpoint-production.up.railway.app/
litellm_settings:
general_settings: callbacks: ["gcs_bucket"]
alerting: ["slack"]
alerting_threshold: 0.001

View file

@ -519,3 +519,50 @@ async def test_basic_gcs_logging_per_request_with_no_litellm_callback_set():
object_name=gcs_log_id, object_name=gcs_log_id,
standard_callback_dynamic_params=standard_callback_dynamic_params, standard_callback_dynamic_params=standard_callback_dynamic_params,
) )
@pytest.mark.asyncio
async def test_get_gcs_logging_config_without_service_account():
    """
    Test that get_gcs_logging_config works for IAM auth on GCS (no service account).

    1. Key based logging where the service account is omitted
    2. Key based logging where the service account is explicitly None
    3. A missing bucket name raises a ValueError
    """
    # Test 1: With standard_callback_dynamic_params (service account omitted - this is IAM auth)
    gcs_logger = GCSBucketLogger()
    dynamic_params = StandardCallbackDynamicParams(
        gcs_bucket_name="dynamic-bucket",
    )
    config = await gcs_logger.get_gcs_logging_config(
        {"standard_callback_dynamic_params": dynamic_params}
    )

    assert config["bucket_name"] == "dynamic-bucket"
    assert config["path_service_account"] is None
    assert config["vertex_instance"] is not None

    # Test 2: With standard_callback_dynamic_params (service account explicitly None - this is IAM auth)
    dynamic_params = StandardCallbackDynamicParams(
        gcs_bucket_name="dynamic-bucket", gcs_path_service_account=None
    )
    config = await gcs_logger.get_gcs_logging_config(
        {"standard_callback_dynamic_params": dynamic_params}
    )

    assert config["bucket_name"] == "dynamic-bucket"
    assert config["path_service_account"] is None
    assert config["vertex_instance"] is not None

    # Test 3: With missing bucket name
    # pop() with a default so an already-unset variable does not raise KeyError,
    # which would escape pytest.raises(ValueError) and fail the test spuriously.
    _old_gcs_bucket_name = os.environ.pop("GCS_BUCKET_NAME", None)
    try:
        with pytest.raises(ValueError, match="GCS_BUCKET_NAME is not set"):
            gcs_logger = GCSBucketLogger(bucket_name=None)
            await gcs_logger.get_gcs_logging_config({})
    finally:
        # Always restore the environment, even when the assertion above fails,
        # so later tests in the same process still see the bucket name.
        if _old_gcs_bucket_name is not None:
            os.environ["GCS_BUCKET_NAME"] = _old_gcs_bucket_name