feat(llm_passthrough_endpoints.py): base case passing for refactored vertex passthrough route

This commit is contained in:
Krrish Dholakia 2025-03-22 11:06:52 -07:00
parent 94d3413335
commit b44b3bd36b
7 changed files with 115 additions and 111 deletions

View file

@@ -456,7 +456,6 @@ async def vertex_proxy_route(
request=request, request=request,
api_key=api_key_to_use, api_key=api_key_to_use,
) )
vertex_project: Optional[str] = get_vertex_project_id_from_url(endpoint) vertex_project: Optional[str] = get_vertex_project_id_from_url(endpoint)
vertex_location: Optional[str] = get_vertex_location_from_url(endpoint) vertex_location: Optional[str] = get_vertex_location_from_url(endpoint)
vertex_credentials = passthrough_endpoint_router.get_vertex_credentials( vertex_credentials = passthrough_endpoint_router.get_vertex_credentials(

View file

@@ -1,7 +1,8 @@
from typing import Dict, Optional from typing import Dict, Optional
from litellm._logging import verbose_logger from litellm._logging import verbose_router_logger
from litellm.secret_managers.main import get_secret_str from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.vertex_ai import VERTEX_CREDENTIALS_TYPES
from litellm.types.passthrough_endpoints.vertex_ai import VertexPassThroughCredentials from litellm.types.passthrough_endpoints.vertex_ai import VertexPassThroughCredentials
@@ -15,6 +16,7 @@ class PassthroughEndpointRouter:
self.deployment_key_to_vertex_credentials: Dict[ self.deployment_key_to_vertex_credentials: Dict[
str, VertexPassThroughCredentials str, VertexPassThroughCredentials
] = {} ] = {}
self.default_vertex_config: Optional[VertexPassThroughCredentials] = None
def set_pass_through_credentials( def set_pass_through_credentials(
self, self,
@@ -49,14 +51,14 @@ class PassthroughEndpointRouter:
custom_llm_provider=custom_llm_provider, custom_llm_provider=custom_llm_provider,
region_name=region_name, region_name=region_name,
) )
verbose_logger.debug( verbose_router_logger.debug(
f"Pass-through llm endpoints router, looking for credentials for {credential_name}" f"Pass-through llm endpoints router, looking for credentials for {credential_name}"
) )
if credential_name in self.credentials: if credential_name in self.credentials:
verbose_logger.debug(f"Found credentials for {credential_name}") verbose_router_logger.debug(f"Found credentials for {credential_name}")
return self.credentials[credential_name] return self.credentials[credential_name]
else: else:
verbose_logger.debug( verbose_router_logger.debug(
f"No credentials found for {credential_name}, looking for env variable" f"No credentials found for {credential_name}, looking for env variable"
) )
_env_variable_name = ( _env_variable_name = (
@@ -66,6 +68,72 @@ class PassthroughEndpointRouter:
) )
return get_secret_str(_env_variable_name) return get_secret_str(_env_variable_name)
def _get_vertex_env_vars(self) -> VertexPassThroughCredentials:
    """
    Build the default vertex pass-through credentials from environment variables.

    The following environment variables are read (each may be unset, in which
    case the corresponding field is None):
    - DEFAULT_VERTEXAI_PROJECT (project id)
    - DEFAULT_VERTEXAI_LOCATION (location)
    - DEFAULT_GOOGLE_APPLICATION_CREDENTIALS (path to credentials file)
    """
    default_project = get_secret_str("DEFAULT_VERTEXAI_PROJECT")
    default_location = get_secret_str("DEFAULT_VERTEXAI_LOCATION")
    default_credentials = get_secret_str("DEFAULT_GOOGLE_APPLICATION_CREDENTIALS")
    return VertexPassThroughCredentials(
        vertex_project=default_project,
        vertex_location=default_location,
        vertex_credentials=default_credentials,
    )
def set_default_vertex_config(self, config: Optional[dict] = None):
    """Sets vertex configuration from provided config and/or environment variables

    Args:
        config (Optional[dict]): Configuration dictionary
            Example: {
                "vertex_project": "my-project-123",
                "vertex_location": "us-central1",
                "vertex_credentials": "os.environ/GOOGLE_CREDS"
            }

    Side effects:
        Sets self.default_vertex_config. Unlike the previous implementation,
        the caller's `config` dict is NOT mutated while resolving
        "os.environ/<VAR>" secret references.
    """
    # No explicit config: fall back to the DEFAULT_VERTEXAI_* env vars.
    if config is None:
        self.default_vertex_config = self._get_vertex_env_vars()
        return

    if isinstance(config, dict):
        # Resolve "os.environ/<VAR>" references into secret values on a
        # copy, so the caller's dict is left untouched.
        resolved = {
            key: (
                get_secret_str(value)
                if isinstance(value, str) and value.startswith("os.environ/")
                else value
            )
            for key, value in config.items()
        }
        self.default_vertex_config = VertexPassThroughCredentials(**resolved)
def add_vertex_credentials(
    self,
    project_id: str,
    location: str,
    vertex_credentials: VERTEX_CREDENTIALS_TYPES,
):
    """
    Register vertex credentials under the deployment key derived from the
    given project-id and location, for later lookup by pass-through routes.
    """
    deployment_key = self._get_deployment_key(
        project_id=project_id, location=location
    )
    # Defensive: _get_deployment_key yields None when inputs are missing.
    if deployment_key is None:
        verbose_router_logger.debug(
            "No deployment key found for project-id, location"
        )
        return
    credentials_entry = VertexPassThroughCredentials(
        vertex_project=project_id,
        vertex_location=location,
        vertex_credentials=vertex_credentials,
    )
    self.deployment_key_to_vertex_credentials[deployment_key] = credentials_entry
def _get_deployment_key( def _get_deployment_key(
self, project_id: Optional[str], location: Optional[str] self, project_id: Optional[str], location: Optional[str]
) -> Optional[str]: ) -> Optional[str]:
@@ -82,15 +150,13 @@ class PassthroughEndpointRouter:
""" """
Get the vertex credentials for the given project-id, location Get the vertex credentials for the given project-id, location
""" """
# from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import (
# default_vertex_config,
# )
default_vertex_config: Optional[VertexPassThroughCredentials] = None default_vertex_config: Optional[VertexPassThroughCredentials] = None
deployment_key = self._get_deployment_key( deployment_key = self._get_deployment_key(
project_id=project_id, project_id=project_id,
location=location, location=location,
) )
if deployment_key is None: if deployment_key is None:
return default_vertex_config return default_vertex_config
if deployment_key in self.deployment_key_to_vertex_credentials: if deployment_key in self.deployment_key_to_vertex_credentials:

View file

@@ -235,6 +235,9 @@ from litellm.proxy.openai_files_endpoints.files_endpoints import (
router as openai_files_router, router as openai_files_router,
) )
from litellm.proxy.openai_files_endpoints.files_endpoints import set_files_config from litellm.proxy.openai_files_endpoints.files_endpoints import set_files_config
from litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints import (
passthrough_endpoint_router,
)
from litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints import ( from litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints import (
router as llm_passthrough_router, router as llm_passthrough_router,
) )
@@ -272,8 +275,6 @@ from litellm.proxy.utils import (
from litellm.proxy.vertex_ai_endpoints.langfuse_endpoints import ( from litellm.proxy.vertex_ai_endpoints.langfuse_endpoints import (
router as langfuse_router, router as langfuse_router,
) )
from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import router as vertex_router
from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import set_default_vertex_config
from litellm.router import ( from litellm.router import (
AssistantsTypedDict, AssistantsTypedDict,
Deployment, Deployment,
@@ -2115,7 +2116,9 @@ class ProxyConfig:
## default config for vertex ai routes ## default config for vertex ai routes
default_vertex_config = config.get("default_vertex_config", None) default_vertex_config = config.get("default_vertex_config", None)
set_default_vertex_config(config=default_vertex_config) passthrough_endpoint_router.set_default_vertex_config(
config=default_vertex_config
)
## ROUTER SETTINGS (e.g. routing_strategy, ...) ## ROUTER SETTINGS (e.g. routing_strategy, ...)
router_settings = config.get("router_settings", None) router_settings = config.get("router_settings", None)
@@ -8161,7 +8164,6 @@ app.include_router(batches_router)
app.include_router(rerank_router) app.include_router(rerank_router)
app.include_router(fine_tuning_router) app.include_router(fine_tuning_router)
app.include_router(credential_router) app.include_router(credential_router)
app.include_router(vertex_router)
app.include_router(llm_passthrough_router) app.include_router(llm_passthrough_router)
app.include_router(anthropic_router) app.include_router(anthropic_router)
app.include_router(langfuse_router) app.include_router(langfuse_router)

View file

@@ -23,48 +23,6 @@
# default_vertex_config: Optional[VertexPassThroughCredentials] = None # default_vertex_config: Optional[VertexPassThroughCredentials] = None
# def _get_vertex_env_vars() -> VertexPassThroughCredentials:
# """
# Helper to get vertex pass through config from environment variables
# The following environment variables are used:
# - DEFAULT_VERTEXAI_PROJECT (project id)
# - DEFAULT_VERTEXAI_LOCATION (location)
# - DEFAULT_GOOGLE_APPLICATION_CREDENTIALS (path to credentials file)
# """
# return VertexPassThroughCredentials(
# vertex_project=get_secret_str("DEFAULT_VERTEXAI_PROJECT"),
# vertex_location=get_secret_str("DEFAULT_VERTEXAI_LOCATION"),
# vertex_credentials=get_secret_str("DEFAULT_GOOGLE_APPLICATION_CREDENTIALS"),
# )
# def set_default_vertex_config(config: Optional[dict] = None):
# """Sets vertex configuration from provided config and/or environment variables
# Args:
# config (Optional[dict]): Configuration dictionary
# Example: {
# "vertex_project": "my-project-123",
# "vertex_location": "us-central1",
# "vertex_credentials": "os.environ/GOOGLE_CREDS"
# }
# """
# global default_vertex_config
# # Initialize config dictionary if None
# if config is None:
# default_vertex_config = _get_vertex_env_vars()
# return
# if isinstance(config, dict):
# for key, value in config.items():
# if isinstance(value, str) and value.startswith("os.environ/"):
# config[key] = litellm.get_secret(value)
# _set_default_vertex_config(VertexPassThroughCredentials(**config))
# def _set_default_vertex_config( # def _set_default_vertex_config(
# vertex_pass_through_credentials: VertexPassThroughCredentials, # vertex_pass_through_credentials: VertexPassThroughCredentials,
# ): # ):

View file

@@ -55,37 +55,6 @@
# else: # else:
# return default_vertex_config # return default_vertex_config
# def add_vertex_credentials(
# self,
# project_id: str,
# location: str,
# vertex_credentials: VERTEX_CREDENTIALS_TYPES,
# ):
# """
# Add the vertex credentials for the given project-id, location
# """
# deployment_key = self._get_deployment_key(
# project_id=project_id,
# location=location,
# )
# if deployment_key is None:
# verbose_proxy_logger.debug(
# "No deployment key found for project-id, location"
# )
# return
# vertex_pass_through_credentials = VertexPassThroughCredentials(
# vertex_project=project_id,
# vertex_location=location,
# vertex_credentials=vertex_credentials,
# )
# self.deployment_key_to_vertex_credentials[deployment_key] = (
# vertex_pass_through_credentials
# )
# verbose_proxy_logger.debug(
# f"self.deployment_key_to_vertex_credentials: {json.dumps(self.deployment_key_to_vertex_credentials, indent=4, default=str)}"
# )
# @staticmethod # @staticmethod
# def _get_vertex_project_id_from_url(url: str) -> Optional[str]: # def _get_vertex_project_id_from_url(url: str) -> Optional[str]:

View file

@@ -4495,11 +4495,11 @@ class Router:
Each provider uses diff .env vars for pass-through endpoints, this helper uses the deployment credentials to set the .env vars for pass-through endpoints Each provider uses diff .env vars for pass-through endpoints, this helper uses the deployment credentials to set the .env vars for pass-through endpoints
""" """
if deployment.litellm_params.use_in_pass_through is True: if deployment.litellm_params.use_in_pass_through is True:
if custom_llm_provider == "vertex_ai": from litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints import (
from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import ( passthrough_endpoint_router,
vertex_pass_through_router, )
)
if custom_llm_provider == "vertex_ai":
if ( if (
deployment.litellm_params.vertex_project is None deployment.litellm_params.vertex_project is None
or deployment.litellm_params.vertex_location is None or deployment.litellm_params.vertex_location is None
@@ -4508,16 +4508,12 @@ class Router:
raise ValueError( raise ValueError(
"vertex_project, vertex_location, and vertex_credentials must be set in litellm_params for pass-through endpoints" "vertex_project, vertex_location, and vertex_credentials must be set in litellm_params for pass-through endpoints"
) )
vertex_pass_through_router.add_vertex_credentials( passthrough_endpoint_router.add_vertex_credentials(
project_id=deployment.litellm_params.vertex_project, project_id=deployment.litellm_params.vertex_project,
location=deployment.litellm_params.vertex_location, location=deployment.litellm_params.vertex_location,
vertex_credentials=deployment.litellm_params.vertex_credentials, vertex_credentials=deployment.litellm_params.vertex_credentials,
) )
else: else:
from litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints import (
passthrough_endpoint_router,
)
passthrough_endpoint_router.set_pass_through_credentials( passthrough_endpoint_router.set_pass_through_credentials(
custom_llm_provider=custom_llm_provider, custom_llm_provider=custom_llm_provider,
api_base=deployment.litellm_params.api_base, api_base=deployment.litellm_params.api_base,

View file

@@ -191,16 +191,39 @@ class TestVertexAIPassThroughHandler:
""" """
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_vertex_passthrough_with_credentials(self): async def test_vertex_passthrough_with_credentials(self, monkeypatch):
""" """
Test that when passthrough credentials are set, they are correctly used in the request Test that when passthrough credentials are set, they are correctly used in the request
""" """
from litellm.proxy.pass_through_endpoints.passthrough_endpoint_router import (
PassthroughEndpointRouter,
)
vertex_project = "test-project"
vertex_location = "us-central1"
vertex_credentials = "test-creds"
pass_through_router = PassthroughEndpointRouter()
pass_through_router.add_vertex_credentials(
project_id=vertex_project,
location=vertex_location,
vertex_credentials=vertex_credentials,
)
monkeypatch.setattr(
"litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.passthrough_endpoint_router",
pass_through_router,
)
endpoint = f"/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/gemini-1.5-flash:generateContent"
# Mock request # Mock request
mock_request = Request( mock_request = Request(
scope={ scope={
"type": "http", "type": "http",
"method": "POST", "method": "POST",
"path": "/vertex_ai/models/test-model/predict", "path": endpoint,
"headers": {}, "headers": {},
} }
) )
@@ -209,33 +232,24 @@ class TestVertexAIPassThroughHandler:
mock_response = Response() mock_response = Response()
# Mock vertex credentials # Mock vertex credentials
test_project = "test-project" test_project = vertex_project
test_location = "us-central1" test_location = vertex_location
test_token = "test-token-123" test_token = vertex_credentials
with mock.patch( with mock.patch(
"litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.passthrough_endpoint_router.get_vertex_credentials"
) as mock_get_creds, mock.patch(
"litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.vertex_llm_base._ensure_access_token_async" "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.vertex_llm_base._ensure_access_token_async"
) as mock_ensure_token, mock.patch( ) as mock_ensure_token, mock.patch(
"litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.vertex_llm_base._get_token_and_url" "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.vertex_llm_base._get_token_and_url"
) as mock_get_token, mock.patch( ) as mock_get_token, mock.patch(
"litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.create_pass_through_route" "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.create_pass_through_route"
) as mock_create_route: ) as mock_create_route:
# Setup mock returns
mock_get_creds.return_value = VertexPassThroughCredentials(
vertex_project=test_project,
vertex_location=test_location,
vertex_credentials="test-creds",
)
mock_ensure_token.return_value = ("test-auth-header", test_project) mock_ensure_token.return_value = ("test-auth-header", test_project)
mock_get_token.return_value = (test_token, "") mock_get_token.return_value = (test_token, "")
# Call the route # Call the route
try: try:
await vertex_proxy_route( await vertex_proxy_route(
endpoint="models/test-model/predict", endpoint=endpoint,
request=mock_request, request=mock_request,
fastapi_response=mock_response, fastapi_response=mock_response,
) )
@@ -244,7 +258,7 @@ class TestVertexAIPassThroughHandler:
# Verify create_pass_through_route was called with correct arguments # Verify create_pass_through_route was called with correct arguments
mock_create_route.assert_called_once_with( mock_create_route.assert_called_once_with(
endpoint="models/test-model/predict", endpoint=endpoint,
target=f"https://{test_location}-aiplatform.googleapis.com/v1/projects/{test_project}/locations/{test_location}/models/test-model/predict", target=f"https://{test_location}-aiplatform.googleapis.com/v1/projects/{test_project}/locations/{test_location}/publishers/google/models/gemini-1.5-flash:generateContent",
custom_headers={"Authorization": f"Bearer {test_token}"}, custom_headers={"Authorization": f"Bearer {test_token}"},
) )