diff --git a/litellm/proxy/vertex_ai_endpoints/vertex_endpoints.py b/litellm/proxy/vertex_ai_endpoints/vertex_endpoints.py
index 1a0d09a88..cc0e7e208 100644
--- a/litellm/proxy/vertex_ai_endpoints/vertex_endpoints.py
+++ b/litellm/proxy/vertex_ai_endpoints/vertex_endpoints.py
@@ -28,14 +28,46 @@ from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
 from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
     create_pass_through_route,
 )
+from litellm.secret_managers.main import get_secret_str
+from litellm.types.passthrough_endpoints.vertex_ai import *
 
 router = APIRouter()
-default_vertex_config = None
+
+default_vertex_config: VertexPassThroughCredentials = VertexPassThroughCredentials()
 
 
-def set_default_vertex_config(config):
+def _get_vertex_env_vars() -> VertexPassThroughCredentials:
+    """
+    Helper to get vertex pass through config from environment variables
+
+    The following environment variables are used:
+    - DEFAULT_VERTEXAI_PROJECT (project id)
+    - DEFAULT_VERTEXAI_LOCATION (location)
+    - DEFAULT_GOOGLE_APPLICATION_CREDENTIALS (path to credentials file)
+    """
+    return VertexPassThroughCredentials(
+        vertex_project=get_secret_str("DEFAULT_VERTEXAI_PROJECT"),
+        vertex_location=get_secret_str("DEFAULT_VERTEXAI_LOCATION"),
+        vertex_credentials=get_secret_str("DEFAULT_GOOGLE_APPLICATION_CREDENTIALS"),
+    )
+
+
+def set_default_vertex_config(config: Optional[dict]):
+    """Sets vertex configuration from provided config and/or environment variables
+
+    Args:
+        config (Optional[dict]): Configuration dictionary
+        Example: {
+            "vertex_project": "my-project-123",
+            "vertex_location": "us-central1",
+            "vertex_credentials": "os.environ/GOOGLE_CREDS"
+        }
+    """
     global default_vertex_config
+
+    # Initialize config dictionary if None
     if config is None:
+        default_vertex_config = _get_vertex_env_vars()
         return
 
     if not isinstance(config, dict):
@@ -46,7 +78,7 @@ def set_default_vertex_config(config):
         if isinstance(value, str) and value.startswith("os.environ/"):
             config[key] = litellm.get_secret(value)
 
-    default_vertex_config = config
+    default_vertex_config = VertexPassThroughCredentials(**config)
 
 
 def exception_handler(e: Exception):
@@ -140,7 +172,7 @@ async def vertex_proxy_route(
     vertex_project = None
     vertex_location = None
     # Use headers from the incoming request if default_vertex_config is not set
-    if default_vertex_config is None:
+    if default_vertex_config.vertex_project is None:
         headers = dict(request.headers) or {}
         verbose_proxy_logger.debug(
             "default_vertex_config not set, incoming request headers %s", headers
@@ -153,9 +185,9 @@ async def vertex_proxy_route(
         headers.pop("content-length", None)
         headers.pop("host", None)
     else:
-        vertex_project = default_vertex_config.get("vertex_project")
-        vertex_location = default_vertex_config.get("vertex_location")
-        vertex_credentials = default_vertex_config.get("vertex_credentials")
+        vertex_project = default_vertex_config.vertex_project
+        vertex_location = default_vertex_config.vertex_location
+        vertex_credentials = default_vertex_config.vertex_credentials
 
     base_target_url = f"https://{vertex_location}-aiplatform.googleapis.com/"
 
diff --git a/litellm/types/passthrough_endpoints/vertex_ai.py b/litellm/types/passthrough_endpoints/vertex_ai.py
new file mode 100644
index 000000000..3933aadcd
--- /dev/null
+++ b/litellm/types/passthrough_endpoints/vertex_ai.py
@@ -0,0 +1,18 @@
+"""
+Used for /vertex_ai/ pass through endpoints
+"""
+
+from typing import Optional
+
+from pydantic import BaseModel
+
+
+class VertexPassThroughCredentials(BaseModel):
+    # Example: vertex_project = "my-project-123"
+    vertex_project: Optional[str] = None
+
+    # Example: vertex_location = "us-central1"
+    vertex_location: Optional[str] = None
+
+    # Example: vertex_credentials = "/path/to/credentials.json" or "os.environ/GOOGLE_CREDS"
+    vertex_credentials: Optional[str] = None