mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
(Feat) pass through vertex - allow using credentials defined on litellm router for vertex pass through (#8100)
* test_add_vertex_pass_through_deployment
* VertexPassThroughRouter
* fix use_in_pass_through
* VertexPassThroughRouter
* fix vertex_credentials
* allow using _initialize_deployment_for_pass_through
* test_add_vertex_pass_through_deployment
* _set_default_vertex_config
* fix verbose_proxy_logger
* fix use_in_pass_through
* fix _get_token_and_url
* test_get_vertex_location_from_url
* test_get_vertex_credentials_none
* run pt unit testing again
* fix add_vertex_credentials
* test_adding_deployments.py
* rename file
This commit is contained in:
parent
892581ffc3
commit
b6d61ec22b
7 changed files with 490 additions and 19 deletions
|
@ -15,7 +15,10 @@ from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
|
|||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.passthrough_endpoints.vertex_ai import *
|
||||
|
||||
from .vertex_passthrough_router import VertexPassThroughRouter
|
||||
|
||||
router = APIRouter()
|
||||
vertex_pass_through_router = VertexPassThroughRouter()
|
||||
|
||||
default_vertex_config: VertexPassThroughCredentials = VertexPassThroughCredentials()
|
||||
|
||||
|
@ -59,7 +62,14 @@ def set_default_vertex_config(config: Optional[dict] = None):
|
|||
if isinstance(value, str) and value.startswith("os.environ/"):
|
||||
config[key] = litellm.get_secret(value)
|
||||
|
||||
default_vertex_config = VertexPassThroughCredentials(**config)
|
||||
_set_default_vertex_config(VertexPassThroughCredentials(**config))
|
||||
|
||||
|
||||
def _set_default_vertex_config(
|
||||
vertex_pass_through_credentials: VertexPassThroughCredentials,
|
||||
):
|
||||
global default_vertex_config
|
||||
default_vertex_config = vertex_pass_through_credentials
|
||||
|
||||
|
||||
def exception_handler(e: Exception):
|
||||
|
@ -147,9 +157,6 @@ async def vertex_proxy_route(
|
|||
[Docs](https://docs.litellm.ai/docs/pass_through/vertex_ai)
|
||||
"""
|
||||
encoded_endpoint = httpx.URL(endpoint).path
|
||||
|
||||
import re
|
||||
|
||||
verbose_proxy_logger.debug("requested endpoint %s", endpoint)
|
||||
headers: dict = {}
|
||||
api_key_to_use = get_litellm_virtual_key(request=request)
|
||||
|
@ -158,31 +165,37 @@ async def vertex_proxy_route(
|
|||
api_key=api_key_to_use,
|
||||
)
|
||||
|
||||
vertex_project = None
|
||||
vertex_location = None
|
||||
# Use headers from the incoming request if default_vertex_config is not set
|
||||
if default_vertex_config.vertex_project is None:
|
||||
vertex_project: Optional[str] = (
|
||||
VertexPassThroughRouter._get_vertex_project_id_from_url(endpoint)
|
||||
)
|
||||
vertex_location: Optional[str] = (
|
||||
VertexPassThroughRouter._get_vertex_location_from_url(endpoint)
|
||||
)
|
||||
vertex_credentials = vertex_pass_through_router.get_vertex_credentials(
|
||||
project_id=vertex_project,
|
||||
location=vertex_location,
|
||||
)
|
||||
|
||||
# Use headers from the incoming request if no vertex credentials are found
|
||||
if vertex_credentials.vertex_project is None:
|
||||
headers = dict(request.headers) or {}
|
||||
verbose_proxy_logger.debug(
|
||||
"default_vertex_config not set, incoming request headers %s", headers
|
||||
)
|
||||
# extract location from endpoint, endpoint
|
||||
# "v1beta1/projects/adroit-crow-413218/locations/us-central1/publishers/google/models/gemini-1.5-pro:generateContent"
|
||||
match = re.search(r"/locations/([^/]+)", endpoint)
|
||||
vertex_location = match.group(1) if match else None
|
||||
base_target_url = f"https://{vertex_location}-aiplatform.googleapis.com/"
|
||||
headers.pop("content-length", None)
|
||||
headers.pop("host", None)
|
||||
else:
|
||||
vertex_project = default_vertex_config.vertex_project
|
||||
vertex_location = default_vertex_config.vertex_location
|
||||
vertex_credentials = default_vertex_config.vertex_credentials
|
||||
vertex_project = vertex_credentials.vertex_project
|
||||
vertex_location = vertex_credentials.vertex_location
|
||||
vertex_credentials_str = vertex_credentials.vertex_credentials
|
||||
|
||||
# Construct base URL for the target endpoint
|
||||
base_target_url = f"https://{vertex_location}-aiplatform.googleapis.com/"
|
||||
|
||||
_auth_header, vertex_project = (
|
||||
await vertex_fine_tuning_apis_instance._ensure_access_token_async(
|
||||
credentials=vertex_credentials,
|
||||
credentials=vertex_credentials_str,
|
||||
project_id=vertex_project,
|
||||
custom_llm_provider="vertex_ai_beta",
|
||||
)
|
||||
|
@ -192,7 +205,7 @@ async def vertex_proxy_route(
|
|||
model="",
|
||||
auth_header=_auth_header,
|
||||
gemini_api_key=None,
|
||||
vertex_credentials=vertex_credentials,
|
||||
vertex_credentials=vertex_credentials_str,
|
||||
vertex_project=vertex_project,
|
||||
vertex_location=vertex_location,
|
||||
stream=False,
|
||||
|
|
120
litellm/proxy/vertex_ai_endpoints/vertex_passthrough_router.py
Normal file
120
litellm/proxy/vertex_ai_endpoints/vertex_passthrough_router.py
Normal file
|
@ -0,0 +1,120 @@
|
|||
import json
|
||||
import re
|
||||
from typing import Dict, Optional
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import (
|
||||
VertexPassThroughCredentials,
|
||||
)
|
||||
|
||||
|
||||
class VertexPassThroughRouter:
    """
    Vertex Pass Through Router for Vertex AI pass-through endpoints

    - if request specifies a project-id, location -> use credentials corresponding to the project-id, location
    - if request does not specify a project-id, location -> use credentials corresponding to the DEFAULT_VERTEXAI_PROJECT, DEFAULT_VERTEXAI_LOCATION
    """

    def __init__(self):
        """
        Initialize the VertexPassThroughRouter

        Stores the vertex credentials for each deployment key
        ```
        {
            "project_id-location": VertexPassThroughCredentials,
            "adroit-crow-us-central1": VertexPassThroughCredentials,
        }
        ```
        """
        self.deployment_key_to_vertex_credentials: Dict[
            str, VertexPassThroughCredentials
        ] = {}

    def get_vertex_credentials(
        self, project_id: Optional[str], location: Optional[str]
    ) -> VertexPassThroughCredentials:
        """
        Get the vertex credentials for the given project-id, location.

        Falls back to the module-level `default_vertex_config` when either
        part of the key is missing or no credentials were registered for it.
        """
        from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import (
            default_vertex_config,
        )

        deployment_key = self._get_deployment_key(
            project_id=project_id,
            location=location,
        )
        if deployment_key is None:
            return default_vertex_config
        # single lookup with a default instead of `in`-check + subscript
        return self.deployment_key_to_vertex_credentials.get(
            deployment_key, default_vertex_config
        )

    def add_vertex_credentials(
        self,
        project_id: str,
        location: str,
        vertex_credentials: str,
    ):
        """
        Add the vertex credentials for the given project-id, location.

        No-op (with a debug log) when either project_id or location is None.
        """
        from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import (
            _set_default_vertex_config,
        )

        deployment_key = self._get_deployment_key(
            project_id=project_id,
            location=location,
        )
        if deployment_key is None:
            verbose_proxy_logger.debug(
                "No deployment key found for project-id, location"
            )
            return
        vertex_pass_through_credentials = VertexPassThroughCredentials(
            vertex_project=project_id,
            vertex_location=location,
            vertex_credentials=vertex_credentials,
        )
        self.deployment_key_to_vertex_credentials[deployment_key] = (
            vertex_pass_through_credentials
        )
        verbose_proxy_logger.debug(
            f"self.deployment_key_to_vertex_credentials: {json.dumps(self.deployment_key_to_vertex_credentials, indent=4, default=str)}"
        )
        # NOTE(review): the most recently added deployment also becomes the
        # module-level default config — confirm last-writer-wins is intended.
        _set_default_vertex_config(vertex_pass_through_credentials)

    def _get_deployment_key(
        self, project_id: Optional[str], location: Optional[str]
    ) -> Optional[str]:
        """
        Get the deployment key for the given project-id, location.

        Returns None when either part is missing.
        """
        if project_id is None or location is None:
            return None
        return f"{project_id}-{location}"

    @staticmethod
    def _get_vertex_project_id_from_url(url: str) -> Optional[str]:
        """
        Get the vertex project id from the url

        `https://${LOCATION}-aiplatform.googleapis.com/v1/projects/${PROJECT_ID}/locations/${LOCATION}/publishers/google/models/${MODEL_ID}:streamGenerateContent`
        """
        match = re.search(r"/projects/([^/]+)", url)
        return match.group(1) if match else None

    @staticmethod
    def _get_vertex_location_from_url(url: str) -> Optional[str]:
        """
        Get the vertex location from the url

        `https://${LOCATION}-aiplatform.googleapis.com/v1/projects/${PROJECT_ID}/locations/${LOCATION}/publishers/google/models/${MODEL_ID}:streamGenerateContent`
        """
        match = re.search(r"/locations/([^/]+)", url)
        return match.group(1) if match else None
|
|
@ -4133,8 +4133,48 @@ class Router:
|
|||
litellm_router_instance=self, model=deployment.to_json(exclude_none=True)
|
||||
)
|
||||
|
||||
self._initialize_deployment_for_pass_through(
|
||||
deployment=deployment,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
model=deployment.litellm_params.model,
|
||||
)
|
||||
|
||||
return deployment
|
||||
|
||||
def _initialize_deployment_for_pass_through(
|
||||
self, deployment: Deployment, custom_llm_provider: str, model: str
|
||||
):
|
||||
"""
|
||||
Optional: Initialize deployment for pass-through endpoints if `deployment.litellm_params.use_in_pass_through` is True
|
||||
|
||||
Each provider uses diff .env vars for pass-through endpoints, this helper uses the deployment credentials to set the .env vars for pass-through endpoints
|
||||
"""
|
||||
if deployment.litellm_params.use_in_pass_through is True:
|
||||
if custom_llm_provider == "vertex_ai":
|
||||
from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import (
|
||||
vertex_pass_through_router,
|
||||
)
|
||||
|
||||
if (
|
||||
deployment.litellm_params.vertex_project is None
|
||||
or deployment.litellm_params.vertex_location is None
|
||||
or deployment.litellm_params.vertex_credentials is None
|
||||
):
|
||||
raise ValueError(
|
||||
"vertex_project, vertex_location, and vertex_credentials must be set in litellm_params for pass-through endpoints"
|
||||
)
|
||||
vertex_pass_through_router.add_vertex_credentials(
|
||||
project_id=deployment.litellm_params.vertex_project,
|
||||
location=deployment.litellm_params.vertex_location,
|
||||
vertex_credentials=deployment.litellm_params.vertex_credentials,
|
||||
)
|
||||
else:
|
||||
verbose_router_logger.error(
|
||||
f"Unsupported provider - {custom_llm_provider} for pass-through endpoints"
|
||||
)
|
||||
pass
|
||||
pass
|
||||
|
||||
def add_deployment(self, deployment: Deployment) -> Optional[Deployment]:
|
||||
"""
|
||||
Parameters:
|
||||
|
|
|
@ -176,7 +176,7 @@ class GenericLiteLLMParams(BaseModel):
|
|||
# Deployment budgets
|
||||
max_budget: Optional[float] = None
|
||||
budget_duration: Optional[str] = None
|
||||
|
||||
use_in_pass_through: Optional[bool] = False
|
||||
model_config = ConfigDict(extra="allow", arbitrary_types_allowed=True)
|
||||
|
||||
def __init__(
|
||||
|
@ -215,6 +215,8 @@ class GenericLiteLLMParams(BaseModel):
|
|||
# Deployment budgets
|
||||
max_budget: Optional[float] = None,
|
||||
budget_duration: Optional[str] = None,
|
||||
# Pass through params
|
||||
use_in_pass_through: Optional[bool] = False,
|
||||
**params,
|
||||
):
|
||||
args = locals()
|
||||
|
@ -276,6 +278,8 @@ class LiteLLM_Params(GenericLiteLLMParams):
|
|||
# OpenAI / Azure Whisper
|
||||
# set a max-size of file that can be passed to litellm proxy
|
||||
max_file_size_mb: Optional[float] = None,
|
||||
# will use deployment on pass-through endpoints if True
|
||||
use_in_pass_through: Optional[bool] = False,
|
||||
**params,
|
||||
):
|
||||
args = locals()
|
||||
|
|
|
@ -1729,6 +1729,7 @@ all_litellm_params = [
|
|||
"max_fallbacks",
|
||||
"max_budget",
|
||||
"budget_duration",
|
||||
"use_in_pass_through",
|
||||
] + list(StandardCallbackDynamicParams.__annotations__.keys())
|
||||
|
||||
|
||||
|
|
|
@ -6,7 +6,7 @@ from unittest.mock import AsyncMock, Mock, patch
|
|||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system path
|
||||
) # Adds the parent directory to the system-path
|
||||
|
||||
|
||||
import httpx
|
||||
|
@ -23,6 +23,9 @@ from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import (
|
|||
VertexPassThroughCredentials,
|
||||
default_vertex_config,
|
||||
)
|
||||
from litellm.proxy.vertex_ai_endpoints.vertex_passthrough_router import (
|
||||
VertexPassThroughRouter,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
@ -167,3 +170,123 @@ async def test_set_default_vertex_config():
|
|||
del os.environ["DEFAULT_VERTEXAI_LOCATION"]
|
||||
del os.environ["DEFAULT_GOOGLE_APPLICATION_CREDENTIALS"]
|
||||
del os.environ["GOOGLE_CREDS"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_vertex_passthrough_router_init():
    """A fresh router starts with an empty credential map."""
    pass_through_router = VertexPassThroughRouter()
    credential_map = pass_through_router.deployment_key_to_vertex_credentials
    assert isinstance(credential_map, dict)
    assert not credential_map
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_vertex_credentials_none():
    """get_vertex_credentials falls back to the default config."""
    from litellm.proxy.vertex_ai_endpoints import vertex_endpoints

    vertex_endpoints.default_vertex_config = VertexPassThroughCredentials()
    pass_through_router = VertexPassThroughRouter()

    # Missing project/location -> default (empty) config
    default_creds = pass_through_router.get_vertex_credentials(None, None)
    assert isinstance(default_creds, VertexPassThroughCredentials)

    # Valid key format but nothing registered -> still the default config
    unregistered = pass_through_router.get_vertex_credentials(
        "test-project", "us-central1"
    )
    assert isinstance(unregistered, VertexPassThroughCredentials)
    assert unregistered.vertex_project is None
    assert unregistered.vertex_location is None
    assert unregistered.vertex_credentials is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_vertex_credentials_stored():
    """Registered credentials are returned for their project/location pair."""
    pass_through_router = VertexPassThroughRouter()
    pass_through_router.add_vertex_credentials(
        project_id="test-project",
        location="us-central1",
        vertex_credentials="test-creds",
    )

    stored = pass_through_router.get_vertex_credentials(
        project_id="test-project", location="us-central1"
    )
    assert stored.vertex_project == "test-project"
    assert stored.vertex_location == "us-central1"
    assert stored.vertex_credentials == "test-creds"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_add_vertex_credentials():
    """add_vertex_credentials stores under '<project>-<location>' and ignores None keys."""
    pass_through_router = VertexPassThroughRouter()

    # Valid credentials are stored under the composite key
    pass_through_router.add_vertex_credentials(
        project_id="test-project",
        location="us-central1",
        vertex_credentials="test-creds",
    )

    composite_key = "test-project-us-central1"
    assert composite_key in pass_through_router.deployment_key_to_vertex_credentials
    stored = pass_through_router.deployment_key_to_vertex_credentials[composite_key]
    assert stored.vertex_project == "test-project"
    assert stored.vertex_location == "us-central1"
    assert stored.vertex_credentials == "test-creds"

    # None project/location is ignored — nothing new is stored
    pass_through_router.add_vertex_credentials(
        project_id=None, location=None, vertex_credentials="test-creds"
    )
    assert len(pass_through_router.deployment_key_to_vertex_credentials) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_deployment_key():
    """_get_deployment_key joins both parts or returns None when either is missing."""
    pass_through_router = VertexPassThroughRouter()

    # Both parts present -> "<project>-<location>"
    assert (
        pass_through_router._get_deployment_key("test-project", "us-central1")
        == "test-project-us-central1"
    )

    # Any missing part -> None
    for project_id, location in [
        (None, "us-central1"),
        ("test-project", None),
        (None, None),
    ]:
        assert pass_through_router._get_deployment_key(project_id, location) is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_vertex_project_id_from_url():
    """Project id is parsed from the /projects/ path segment."""
    vertex_url = "https://us-central1-aiplatform.googleapis.com/v1/projects/test-project/locations/us-central1/publishers/google/models/gemini-pro:streamGenerateContent"
    assert (
        VertexPassThroughRouter._get_vertex_project_id_from_url(vertex_url)
        == "test-project"
    )

    # A URL without a /projects/ segment yields None
    assert (
        VertexPassThroughRouter._get_vertex_project_id_from_url(
            "https://invalid-url.com"
        )
        is None
    )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_vertex_location_from_url():
    """Location is parsed from the /locations/ path segment."""
    vertex_url = "https://us-central1-aiplatform.googleapis.com/v1/projects/test-project/locations/us-central1/publishers/google/models/gemini-pro:streamGenerateContent"
    assert (
        VertexPassThroughRouter._get_vertex_location_from_url(vertex_url)
        == "us-central1"
    )

    # A URL without a /locations/ segment yields None
    assert (
        VertexPassThroughRouter._get_vertex_location_from_url(
            "https://invalid-url.com"
        )
        is None
    )
|
||||
|
|
170
tests/router_unit_tests/test_router_adding_deployments.py
Normal file
170
tests/router_unit_tests/test_router_adding_deployments.py
Normal file
|
@ -0,0 +1,170 @@
|
|||
import sys, os
|
||||
import pytest
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system path
|
||||
from litellm import Router
|
||||
from litellm.router import Deployment, LiteLLM_Params
|
||||
from unittest.mock import patch
|
||||
import json
|
||||
|
||||
|
||||
def test_initialize_deployment_for_pass_through_success():
    """
    A fully-specified Vertex AI deployment registers its credentials with the
    pass-through router.
    """
    service_account_json = json.dumps(
        {"type": "service_account", "project_id": "test"}
    )
    test_router = Router(model_list=[])
    vertex_deployment = Deployment(
        model_name="vertex-test",
        litellm_params=LiteLLM_Params(
            model="vertex_ai/test-model",
            vertex_project="test-project",
            vertex_location="us-central1",
            vertex_credentials=service_account_json,
            use_in_pass_through=True,
        ),
    )

    test_router._initialize_deployment_for_pass_through(
        deployment=vertex_deployment,
        custom_llm_provider="vertex_ai",
        model="vertex_ai/test-model",
    )

    # The pass-through router should now hold this deployment's credentials
    from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import (
        vertex_pass_through_router,
    )

    stored_creds = vertex_pass_through_router.get_vertex_credentials(
        project_id="test-project", location="us-central1"
    )
    assert stored_creds.vertex_project == "test-project"
    assert stored_creds.vertex_location == "us-central1"
    assert stored_creds.vertex_credentials == service_account_json
|
||||
|
||||
|
||||
def test_initialize_deployment_for_pass_through_missing_params():
    """
    Initialization raises when the required Vertex AI parameters are absent.
    """
    test_router = Router(model_list=[])
    incomplete_deployment = Deployment(
        model_name="vertex-test",
        litellm_params=LiteLLM_Params(
            model="vertex_ai/test-model",
            # vertex_project / vertex_location / vertex_credentials omitted
            use_in_pass_through=True,
        ),
    )

    with pytest.raises(
        ValueError,
        match="vertex_project, vertex_location, and vertex_credentials must be set",
    ):
        test_router._initialize_deployment_for_pass_through(
            deployment=incomplete_deployment,
            custom_llm_provider="vertex_ai",
            model="vertex_ai/test-model",
        )
|
||||
|
||||
|
||||
def test_initialize_deployment_for_pass_through_unsupported_provider():
    """
    An unsupported provider is logged, not raised.
    """
    test_router = Router(model_list=[])
    unsupported_deployment = Deployment(
        model_name="unsupported-test",
        litellm_params=LiteLLM_Params(
            model="unsupported/test-model",
            use_in_pass_through=True,
        ),
    )

    # Must complete without raising
    test_router._initialize_deployment_for_pass_through(
        deployment=unsupported_deployment,
        custom_llm_provider="unsupported_provider",
        model="unsupported/test-model",
    )
|
||||
|
||||
|
||||
def test_initialize_deployment_when_pass_through_disabled():
    """
    With use_in_pass_through left at its default (False), initialization is a
    no-op even when all vertex parameters are missing.
    """
    test_router = Router(model_list=[])
    plain_deployment = Deployment(
        model_name="vertex-test",
        litellm_params=LiteLLM_Params(model="vertex_ai/test-model"),
    )

    # Must exit silently — reaching the assert below means no error was raised
    test_router._initialize_deployment_for_pass_through(
        deployment=plain_deployment,
        custom_llm_provider="vertex_ai",
        model="vertex_ai/test-model",
    )
    assert True
|
||||
|
||||
|
||||
def test_add_vertex_pass_through_deployment():
    """
    Adding a pass-through-enabled Vertex AI deployment via Router.add_deployment
    registers its credentials with the global pass-through router.
    """
    service_account_json = json.dumps(
        {"type": "service_account", "project_id": "test"}
    )
    test_router = Router(model_list=[])
    vertex_deployment = Deployment(
        model_name="vertex-test",
        litellm_params=LiteLLM_Params(
            model="vertex_ai/test-model",
            vertex_project="test-project",
            vertex_location="us-central1",
            vertex_credentials=service_account_json,
            use_in_pass_through=True,
        ),
    )

    test_router.add_deployment(vertex_deployment)

    from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import (
        vertex_pass_through_router,
    )

    # Dump current pass-through router state for debugging
    print("\n vertex_pass_through_router.deployment_key_to_vertex_credentials\n\n")
    print(
        json.dumps(
            vertex_pass_through_router.deployment_key_to_vertex_credentials,
            indent=4,
            default=str,
        )
    )

    stored_creds = vertex_pass_through_router.get_vertex_credentials(
        project_id="test-project", location="us-central1"
    )
    assert stored_creds.vertex_project == "test-project"
    assert stored_creds.vertex_location == "us-central1"
    assert stored_creds.vertex_credentials == service_account_json
Loading…
Add table
Add a link
Reference in a new issue