From 9a1c2f091c469b1dceda65b440aa6a72d9d26842 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Mon, 25 Nov 2024 23:55:56 -0800 Subject: [PATCH] add unit testing for vtx pass through auth --- .../vertex_ai_endpoints/vertex_endpoints.py | 5 +- .../test_unit_test_vertex_pass_through.py | 85 +++++++++++++++++++ 2 files changed, 86 insertions(+), 4 deletions(-) diff --git a/litellm/proxy/vertex_ai_endpoints/vertex_endpoints.py b/litellm/proxy/vertex_ai_endpoints/vertex_endpoints.py index cc0e7e208..271e8992c 100644 --- a/litellm/proxy/vertex_ai_endpoints/vertex_endpoints.py +++ b/litellm/proxy/vertex_ai_endpoints/vertex_endpoints.py @@ -52,7 +52,7 @@ def _get_vertex_env_vars() -> VertexPassThroughCredentials: ) -def set_default_vertex_config(config: Optional[dict]): +def set_default_vertex_config(config: Optional[dict] = None): """Sets vertex configuration from provided config and/or environment variables Args: @@ -70,9 +70,6 @@ def set_default_vertex_config(config: Optional[dict]): default_vertex_config = _get_vertex_env_vars() return - if not isinstance(config, dict): - raise ValueError("invalid config, vertex default config must be a dictionary") - if isinstance(config, dict): for key, value in config.items(): if isinstance(value, str) and value.startswith("os.environ/"): diff --git a/tests/pass_through_unit_tests/test_unit_test_vertex_pass_through.py b/tests/pass_through_unit_tests/test_unit_test_vertex_pass_through.py index a7b668813..4c66f6993 100644 --- a/tests/pass_through_unit_tests/test_unit_test_vertex_pass_through.py +++ b/tests/pass_through_unit_tests/test_unit_test_vertex_pass_through.py @@ -18,6 +18,10 @@ from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLogging from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import ( get_litellm_virtual_key, vertex_proxy_route, + _get_vertex_env_vars, + set_default_vertex_config, + VertexPassThroughCredentials, + default_vertex_config, ) @@ -82,3 +86,84 @@ async def test_vertex_proxy_route_api_key_auth(): mock_auth.assert_called_once() call_args = mock_auth.call_args[1] assert call_args["api_key"] == "Bearer test-key-123" + + +@pytest.mark.asyncio +async def test_get_vertex_env_vars(): + """Test that _get_vertex_env_vars correctly reads environment variables""" + # Set environment variables for the test + os.environ["DEFAULT_VERTEXAI_PROJECT"] = "test-project-123" + os.environ["DEFAULT_VERTEXAI_LOCATION"] = "us-central1" + os.environ["DEFAULT_GOOGLE_APPLICATION_CREDENTIALS"] = "/path/to/creds" + + try: + result = _get_vertex_env_vars() + print(result) + + # Verify the result + assert isinstance(result, VertexPassThroughCredentials) + assert result.vertex_project == "test-project-123" + assert result.vertex_location == "us-central1" + assert result.vertex_credentials == "/path/to/creds" + + finally: + # Clean up environment variables + del os.environ["DEFAULT_VERTEXAI_PROJECT"] + del os.environ["DEFAULT_VERTEXAI_LOCATION"] + del os.environ["DEFAULT_GOOGLE_APPLICATION_CREDENTIALS"] + + +@pytest.mark.asyncio +async def test_set_default_vertex_config(): + """Test set_default_vertex_config with various inputs""" + # Test with None config - set environment variables first + os.environ["DEFAULT_VERTEXAI_PROJECT"] = "env-project" + os.environ["DEFAULT_VERTEXAI_LOCATION"] = "env-location" + os.environ["DEFAULT_GOOGLE_APPLICATION_CREDENTIALS"] = "env-creds" + os.environ["GOOGLE_CREDS"] = "secret-creds" + + try: + # Test with None config + set_default_vertex_config() + from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import ( + default_vertex_config, + ) + + assert default_vertex_config.vertex_project == "env-project" + assert default_vertex_config.vertex_location == "env-location" + assert default_vertex_config.vertex_credentials == "env-creds" + + # Test with valid config.yaml settings on vertex_config + test_config = { + "vertex_project": "my-project-123", + "vertex_location": "us-central1", + "vertex_credentials": "path/to/creds", + } + set_default_vertex_config(test_config) + from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import ( + default_vertex_config, + ) + + assert default_vertex_config.vertex_project == "my-project-123" + assert default_vertex_config.vertex_location == "us-central1" + assert default_vertex_config.vertex_credentials == "path/to/creds" + + # Test with environment variable reference + test_config = { + "vertex_project": "my-project-123", + "vertex_location": "us-central1", + "vertex_credentials": "os.environ/GOOGLE_CREDS", + } + set_default_vertex_config(test_config) + from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import ( + default_vertex_config, + ) + + assert default_vertex_config.vertex_credentials == "secret-creds" + + finally: + # Clean up environment variables + del os.environ["DEFAULT_VERTEXAI_PROJECT"] + del os.environ["DEFAULT_VERTEXAI_LOCATION"] + del os.environ["DEFAULT_GOOGLE_APPLICATION_CREDENTIALS"] + del os.environ["GOOGLE_CREDS"]