mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
(Feat) pass through vertex - allow using credentials defined on litellm router for vertex pass through (#8100)
* test_add_vertex_pass_through_deployment * VertexPassThroughRouter * fix use_in_pass_through * VertexPassThroughRouter * fix vertex_credentials * allow using _initialize_deployment_for_pass_through * test_add_vertex_pass_through_deployment * _set_default_vertex_config * fix verbose_proxy_logger * fix use_in_pass_through * fix _get_token_and_url * test_get_vertex_location_from_url * test_get_vertex_credentials_none * run pt unit testing again * fix add_vertex_credentials * test_adding_deployments.py * rename file
This commit is contained in:
parent
892581ffc3
commit
b6d61ec22b
7 changed files with 490 additions and 19 deletions
|
@ -15,7 +15,10 @@ from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
|
||||||
from litellm.secret_managers.main import get_secret_str
|
from litellm.secret_managers.main import get_secret_str
|
||||||
from litellm.types.passthrough_endpoints.vertex_ai import *
|
from litellm.types.passthrough_endpoints.vertex_ai import *
|
||||||
|
|
||||||
|
from .vertex_passthrough_router import VertexPassThroughRouter
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
vertex_pass_through_router = VertexPassThroughRouter()
|
||||||
|
|
||||||
default_vertex_config: VertexPassThroughCredentials = VertexPassThroughCredentials()
|
default_vertex_config: VertexPassThroughCredentials = VertexPassThroughCredentials()
|
||||||
|
|
||||||
|
@ -59,7 +62,14 @@ def set_default_vertex_config(config: Optional[dict] = None):
|
||||||
if isinstance(value, str) and value.startswith("os.environ/"):
|
if isinstance(value, str) and value.startswith("os.environ/"):
|
||||||
config[key] = litellm.get_secret(value)
|
config[key] = litellm.get_secret(value)
|
||||||
|
|
||||||
default_vertex_config = VertexPassThroughCredentials(**config)
|
_set_default_vertex_config(VertexPassThroughCredentials(**config))
|
||||||
|
|
||||||
|
|
||||||
|
def _set_default_vertex_config(
    vertex_pass_through_credentials: VertexPassThroughCredentials,
):
    """
    Replace the module-level `default_vertex_config`.

    Parameters:
    - vertex_pass_through_credentials: the credentials object that becomes the
      new value of the `default_vertex_config` global.
    """
    # Rebind the module global rather than creating a function-local name.
    global default_vertex_config
    default_vertex_config = vertex_pass_through_credentials
|
||||||
|
|
||||||
|
|
||||||
def exception_handler(e: Exception):
|
def exception_handler(e: Exception):
|
||||||
|
@ -147,9 +157,6 @@ async def vertex_proxy_route(
|
||||||
[Docs](https://docs.litellm.ai/docs/pass_through/vertex_ai)
|
[Docs](https://docs.litellm.ai/docs/pass_through/vertex_ai)
|
||||||
"""
|
"""
|
||||||
encoded_endpoint = httpx.URL(endpoint).path
|
encoded_endpoint = httpx.URL(endpoint).path
|
||||||
|
|
||||||
import re
|
|
||||||
|
|
||||||
verbose_proxy_logger.debug("requested endpoint %s", endpoint)
|
verbose_proxy_logger.debug("requested endpoint %s", endpoint)
|
||||||
headers: dict = {}
|
headers: dict = {}
|
||||||
api_key_to_use = get_litellm_virtual_key(request=request)
|
api_key_to_use = get_litellm_virtual_key(request=request)
|
||||||
|
@ -158,31 +165,37 @@ async def vertex_proxy_route(
|
||||||
api_key=api_key_to_use,
|
api_key=api_key_to_use,
|
||||||
)
|
)
|
||||||
|
|
||||||
vertex_project = None
|
vertex_project: Optional[str] = (
|
||||||
vertex_location = None
|
VertexPassThroughRouter._get_vertex_project_id_from_url(endpoint)
|
||||||
# Use headers from the incoming request if default_vertex_config is not set
|
)
|
||||||
if default_vertex_config.vertex_project is None:
|
vertex_location: Optional[str] = (
|
||||||
|
VertexPassThroughRouter._get_vertex_location_from_url(endpoint)
|
||||||
|
)
|
||||||
|
vertex_credentials = vertex_pass_through_router.get_vertex_credentials(
|
||||||
|
project_id=vertex_project,
|
||||||
|
location=vertex_location,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Use headers from the incoming request if no vertex credentials are found
|
||||||
|
if vertex_credentials.vertex_project is None:
|
||||||
headers = dict(request.headers) or {}
|
headers = dict(request.headers) or {}
|
||||||
verbose_proxy_logger.debug(
|
verbose_proxy_logger.debug(
|
||||||
"default_vertex_config not set, incoming request headers %s", headers
|
"default_vertex_config not set, incoming request headers %s", headers
|
||||||
)
|
)
|
||||||
# extract location from endpoint, endpoint
|
|
||||||
# "v1beta1/projects/adroit-crow-413218/locations/us-central1/publishers/google/models/gemini-1.5-pro:generateContent"
|
|
||||||
match = re.search(r"/locations/([^/]+)", endpoint)
|
|
||||||
vertex_location = match.group(1) if match else None
|
|
||||||
base_target_url = f"https://{vertex_location}-aiplatform.googleapis.com/"
|
base_target_url = f"https://{vertex_location}-aiplatform.googleapis.com/"
|
||||||
headers.pop("content-length", None)
|
headers.pop("content-length", None)
|
||||||
headers.pop("host", None)
|
headers.pop("host", None)
|
||||||
else:
|
else:
|
||||||
vertex_project = default_vertex_config.vertex_project
|
vertex_project = vertex_credentials.vertex_project
|
||||||
vertex_location = default_vertex_config.vertex_location
|
vertex_location = vertex_credentials.vertex_location
|
||||||
vertex_credentials = default_vertex_config.vertex_credentials
|
vertex_credentials_str = vertex_credentials.vertex_credentials
|
||||||
|
|
||||||
|
# Construct base URL for the target endpoint
|
||||||
base_target_url = f"https://{vertex_location}-aiplatform.googleapis.com/"
|
base_target_url = f"https://{vertex_location}-aiplatform.googleapis.com/"
|
||||||
|
|
||||||
_auth_header, vertex_project = (
|
_auth_header, vertex_project = (
|
||||||
await vertex_fine_tuning_apis_instance._ensure_access_token_async(
|
await vertex_fine_tuning_apis_instance._ensure_access_token_async(
|
||||||
credentials=vertex_credentials,
|
credentials=vertex_credentials_str,
|
||||||
project_id=vertex_project,
|
project_id=vertex_project,
|
||||||
custom_llm_provider="vertex_ai_beta",
|
custom_llm_provider="vertex_ai_beta",
|
||||||
)
|
)
|
||||||
|
@ -192,7 +205,7 @@ async def vertex_proxy_route(
|
||||||
model="",
|
model="",
|
||||||
auth_header=_auth_header,
|
auth_header=_auth_header,
|
||||||
gemini_api_key=None,
|
gemini_api_key=None,
|
||||||
vertex_credentials=vertex_credentials,
|
vertex_credentials=vertex_credentials_str,
|
||||||
vertex_project=vertex_project,
|
vertex_project=vertex_project,
|
||||||
vertex_location=vertex_location,
|
vertex_location=vertex_location,
|
||||||
stream=False,
|
stream=False,
|
||||||
|
|
120
litellm/proxy/vertex_ai_endpoints/vertex_passthrough_router.py
Normal file
120
litellm/proxy/vertex_ai_endpoints/vertex_passthrough_router.py
Normal file
|
@ -0,0 +1,120 @@
|
||||||
|
import json
|
||||||
|
import re
|
||||||
|
from typing import Dict, Optional
|
||||||
|
|
||||||
|
from litellm._logging import verbose_proxy_logger
|
||||||
|
from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import (
|
||||||
|
VertexPassThroughCredentials,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class VertexPassThroughRouter:
    """
    Vertex Pass Through Router for Vertex AI pass-through endpoints

    - if request specifies a project-id, location -> use credentials corresponding to the project-id, location
    - if request does not specify a project-id, location -> use credentials corresponding to the DEFAULT_VERTEXAI_PROJECT, DEFAULT_VERTEXAI_LOCATION

    A project-id, location pair that was never registered also falls back to
    the default config.
    """

    def __init__(self):
        """
        Initialize the VertexPassThroughRouter

        Stores the vertex credentials for each deployment key
        ```
        {
            "project_id-location": VertexPassThroughCredentials,
            "adroit-crow-us-central1": VertexPassThroughCredentials,
        }
        ```
        """
        self.deployment_key_to_vertex_credentials: Dict[
            str, VertexPassThroughCredentials
        ] = {}

    def get_vertex_credentials(
        self, project_id: Optional[str], location: Optional[str]
    ) -> "VertexPassThroughCredentials":
        """
        Get the vertex credentials for the given project-id, location

        Returns the credentials stored for the pair, or the module-level
        `default_vertex_config` when either value is None or no credentials
        were stored for the pair.
        """
        # Imported at call time so the current value of the global is read;
        # vertex_endpoints itself imports this module.
        from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import (
            default_vertex_config,
        )

        deployment_key = self._get_deployment_key(
            project_id=project_id,
            location=location,
        )
        if deployment_key is None:
            return default_vertex_config
        return self.deployment_key_to_vertex_credentials.get(
            deployment_key, default_vertex_config
        )

    def add_vertex_credentials(
        self,
        project_id: str,
        location: str,
        vertex_credentials: str,
    ):
        """
        Add the vertex credentials for the given project-id, location

        The credentials are stored under the pair's deployment key and are also
        installed as the module-level default config, so the most recently
        added credentials are what unmatched requests fall back to.

        Does nothing (besides a debug log) when project_id or location is None.
        """
        from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import (
            _set_default_vertex_config,
        )

        deployment_key = self._get_deployment_key(
            project_id=project_id,
            location=location,
        )
        if deployment_key is None:
            verbose_proxy_logger.debug(
                "No deployment key found for project-id, location"
            )
            return

        vertex_pass_through_credentials = VertexPassThroughCredentials(
            vertex_project=project_id,
            vertex_location=location,
            vertex_credentials=vertex_credentials,
        )
        self.deployment_key_to_vertex_credentials[deployment_key] = (
            vertex_pass_through_credentials
        )
        # Log only the keys: stringifying the stored objects could write
        # credential material into the logs. Logger args keep formatting lazy.
        verbose_proxy_logger.debug(
            "vertex pass-through credentials registered for deployment keys: %s",
            list(self.deployment_key_to_vertex_credentials),
        )
        _set_default_vertex_config(vertex_pass_through_credentials)

    def _get_deployment_key(
        self, project_id: Optional[str], location: Optional[str]
    ) -> Optional[str]:
        """
        Get the deployment key for the given project-id, location

        Returns "<project_id>-<location>", or None when either value is None.
        """
        if project_id is None or location is None:
            return None
        return f"{project_id}-{location}"

    @staticmethod
    def _get_vertex_project_id_from_url(url: str) -> Optional[str]:
        """
        Get the vertex project id from the url

        `https://${LOCATION}-aiplatform.googleapis.com/v1/projects/${PROJECT_ID}/locations/${LOCATION}/publishers/google/models/${MODEL_ID}:streamGenerateContent`

        Returns None when the url has no `/projects/<id>` segment.
        """
        match = re.search(r"/projects/([^/]+)", url)
        return match.group(1) if match else None

    @staticmethod
    def _get_vertex_location_from_url(url: str) -> Optional[str]:
        """
        Get the vertex location from the url

        `https://${LOCATION}-aiplatform.googleapis.com/v1/projects/${PROJECT_ID}/locations/${LOCATION}/publishers/google/models/${MODEL_ID}:streamGenerateContent`

        Returns None when the url has no `/locations/<location>` segment.
        """
        match = re.search(r"/locations/([^/]+)", url)
        return match.group(1) if match else None
|
|
@ -4133,8 +4133,48 @@ class Router:
|
||||||
litellm_router_instance=self, model=deployment.to_json(exclude_none=True)
|
litellm_router_instance=self, model=deployment.to_json(exclude_none=True)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self._initialize_deployment_for_pass_through(
|
||||||
|
deployment=deployment,
|
||||||
|
custom_llm_provider=custom_llm_provider,
|
||||||
|
model=deployment.litellm_params.model,
|
||||||
|
)
|
||||||
|
|
||||||
return deployment
|
return deployment
|
||||||
|
|
||||||
|
def _initialize_deployment_for_pass_through(
    self, deployment: Deployment, custom_llm_provider: str, model: str
):
    """
    Optionally register a deployment's credentials for pass-through endpoints.

    Only acts when `deployment.litellm_params.use_in_pass_through` is True.
    For the "vertex_ai" provider the deployment's vertex project, location and
    credentials are handed to the vertex pass-through router; any other
    provider is reported through the router logger and otherwise ignored.

    Raises:
    - ValueError: provider is "vertex_ai" and vertex_project, vertex_location
      or vertex_credentials is None on the deployment's litellm_params.
    """
    params = deployment.litellm_params
    if params.use_in_pass_through is not True:
        return

    if custom_llm_provider != "vertex_ai":
        verbose_router_logger.error(
            f"Unsupported provider - {custom_llm_provider} for pass-through endpoints"
        )
        return

    from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import (
        vertex_pass_through_router,
    )

    # All three values are required before anything is registered.
    if (
        params.vertex_project is None
        or params.vertex_location is None
        or params.vertex_credentials is None
    ):
        raise ValueError(
            "vertex_project, vertex_location, and vertex_credentials must be set in litellm_params for pass-through endpoints"
        )

    vertex_pass_through_router.add_vertex_credentials(
        project_id=params.vertex_project,
        location=params.vertex_location,
        vertex_credentials=params.vertex_credentials,
    )
|
||||||
|
|
||||||
def add_deployment(self, deployment: Deployment) -> Optional[Deployment]:
|
def add_deployment(self, deployment: Deployment) -> Optional[Deployment]:
|
||||||
"""
|
"""
|
||||||
Parameters:
|
Parameters:
|
||||||
|
|
|
@ -176,7 +176,7 @@ class GenericLiteLLMParams(BaseModel):
|
||||||
# Deployment budgets
|
# Deployment budgets
|
||||||
max_budget: Optional[float] = None
|
max_budget: Optional[float] = None
|
||||||
budget_duration: Optional[str] = None
|
budget_duration: Optional[str] = None
|
||||||
|
use_in_pass_through: Optional[bool] = False
|
||||||
model_config = ConfigDict(extra="allow", arbitrary_types_allowed=True)
|
model_config = ConfigDict(extra="allow", arbitrary_types_allowed=True)
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
@ -215,6 +215,8 @@ class GenericLiteLLMParams(BaseModel):
|
||||||
# Deployment budgets
|
# Deployment budgets
|
||||||
max_budget: Optional[float] = None,
|
max_budget: Optional[float] = None,
|
||||||
budget_duration: Optional[str] = None,
|
budget_duration: Optional[str] = None,
|
||||||
|
# Pass through params
|
||||||
|
use_in_pass_through: Optional[bool] = False,
|
||||||
**params,
|
**params,
|
||||||
):
|
):
|
||||||
args = locals()
|
args = locals()
|
||||||
|
@ -276,6 +278,8 @@ class LiteLLM_Params(GenericLiteLLMParams):
|
||||||
# OpenAI / Azure Whisper
|
# OpenAI / Azure Whisper
|
||||||
# set a max-size of file that can be passed to litellm proxy
|
# set a max-size of file that can be passed to litellm proxy
|
||||||
max_file_size_mb: Optional[float] = None,
|
max_file_size_mb: Optional[float] = None,
|
||||||
|
# will use deployment on pass-through endpoints if True
|
||||||
|
use_in_pass_through: Optional[bool] = False,
|
||||||
**params,
|
**params,
|
||||||
):
|
):
|
||||||
args = locals()
|
args = locals()
|
||||||
|
|
|
@ -1729,6 +1729,7 @@ all_litellm_params = [
|
||||||
"max_fallbacks",
|
"max_fallbacks",
|
||||||
"max_budget",
|
"max_budget",
|
||||||
"budget_duration",
|
"budget_duration",
|
||||||
|
"use_in_pass_through",
|
||||||
] + list(StandardCallbackDynamicParams.__annotations__.keys())
|
] + list(StandardCallbackDynamicParams.__annotations__.keys())
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -6,7 +6,7 @@ from unittest.mock import AsyncMock, Mock, patch
|
||||||
|
|
||||||
sys.path.insert(
|
sys.path.insert(
|
||||||
0, os.path.abspath("../..")
|
0, os.path.abspath("../..")
|
||||||
) # Adds the parent directory to the system path
|
) # Adds the parent directory to the system path
|
||||||
|
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
@ -23,6 +23,9 @@ from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import (
|
||||||
VertexPassThroughCredentials,
|
VertexPassThroughCredentials,
|
||||||
default_vertex_config,
|
default_vertex_config,
|
||||||
)
|
)
|
||||||
|
from litellm.proxy.vertex_ai_endpoints.vertex_passthrough_router import (
|
||||||
|
VertexPassThroughRouter,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
@ -167,3 +170,123 @@ async def test_set_default_vertex_config():
|
||||||
del os.environ["DEFAULT_VERTEXAI_LOCATION"]
|
del os.environ["DEFAULT_VERTEXAI_LOCATION"]
|
||||||
del os.environ["DEFAULT_GOOGLE_APPLICATION_CREDENTIALS"]
|
del os.environ["DEFAULT_GOOGLE_APPLICATION_CREDENTIALS"]
|
||||||
del os.environ["GOOGLE_CREDS"]
|
del os.environ["GOOGLE_CREDS"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_vertex_passthrough_router_init():
    """A newly constructed router starts with an empty credentials dict."""
    stored = VertexPassThroughRouter().deployment_key_to_vertex_credentials
    assert isinstance(stored, dict)
    assert len(stored) == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_get_vertex_credentials_none():
    """get_vertex_credentials returns the default config when nothing is stored"""
    from litellm.proxy.vertex_ai_endpoints import vertex_endpoints

    # Start from an empty module-level default config
    vertex_endpoints.default_vertex_config = VertexPassThroughCredentials()
    router = VertexPassThroughRouter()

    # None project_id and location -> default config
    default_creds = router.get_vertex_credentials(None, None)
    assert isinstance(default_creds, VertexPassThroughCredentials)

    # Valid project_id and location, but nothing stored for them
    unknown_creds = router.get_vertex_credentials("test-project", "us-central1")
    assert isinstance(unknown_creds, VertexPassThroughCredentials)
    assert unknown_creds.vertex_project is None
    assert unknown_creds.vertex_location is None
    assert unknown_creds.vertex_credentials is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_get_vertex_credentials_stored():
    """Credentials added for a project/location are returned by the lookup"""
    project, region, secret = "test-project", "us-central1", "test-creds"

    router = VertexPassThroughRouter()
    router.add_vertex_credentials(
        project_id=project,
        location=region,
        vertex_credentials=secret,
    )

    found = router.get_vertex_credentials(project_id=project, location=region)
    assert found.vertex_project == "test-project"
    assert found.vertex_location == "us-central1"
    assert found.vertex_credentials == "test-creds"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_add_vertex_credentials():
    """add_vertex_credentials stores valid entries and ignores None keys"""
    router = VertexPassThroughRouter()
    stored = router.deployment_key_to_vertex_credentials

    # Valid credentials are stored under "<project>-<location>"
    router.add_vertex_credentials(
        project_id="test-project",
        location="us-central1",
        vertex_credentials="test-creds",
    )
    assert "test-project-us-central1" in stored
    entry = stored["test-project-us-central1"]
    assert entry.vertex_project == "test-project"
    assert entry.vertex_location == "us-central1"
    assert entry.vertex_credentials == "test-creds"

    # None project/location must not add another entry
    router.add_vertex_credentials(
        project_id=None, location=None, vertex_credentials="test-creds"
    )
    assert len(router.deployment_key_to_vertex_credentials) == 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_get_deployment_key():
    """_get_deployment_key joins both values, or returns None if either is None"""
    router = VertexPassThroughRouter()

    # Both values present
    assert (
        router._get_deployment_key("test-project", "us-central1")
        == "test-project-us-central1"
    )

    # Any None value yields no key
    assert router._get_deployment_key(None, "us-central1") is None
    assert router._get_deployment_key("test-project", None) is None
    assert router._get_deployment_key(None, None) is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_get_vertex_project_id_from_url():
    """_get_vertex_project_id_from_url extracts the project id, else None"""
    extract = VertexPassThroughRouter._get_vertex_project_id_from_url

    # URL containing a /projects/<id> segment
    valid_url = "https://us-central1-aiplatform.googleapis.com/v1/projects/test-project/locations/us-central1/publishers/google/models/gemini-pro:streamGenerateContent"
    assert extract(valid_url) == "test-project"

    # URL without a project segment
    assert extract("https://invalid-url.com") is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_get_vertex_location_from_url():
    """_get_vertex_location_from_url extracts the location, else None"""
    extract = VertexPassThroughRouter._get_vertex_location_from_url

    # URL containing a /locations/<location> segment
    valid_url = "https://us-central1-aiplatform.googleapis.com/v1/projects/test-project/locations/us-central1/publishers/google/models/gemini-pro:streamGenerateContent"
    assert extract(valid_url) == "us-central1"

    # URL without a location segment
    assert extract("https://invalid-url.com") is None
|
||||||
|
|
170
tests/router_unit_tests/test_router_adding_deployments.py
Normal file
170
tests/router_unit_tests/test_router_adding_deployments.py
Normal file
|
@ -0,0 +1,170 @@
|
||||||
|
import sys, os
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
sys.path.insert(
|
||||||
|
0, os.path.abspath("../..")
|
||||||
|
) # Adds the parent directory to the system path
|
||||||
|
from litellm import Router
|
||||||
|
from litellm.router import Deployment, LiteLLM_Params
|
||||||
|
from unittest.mock import patch
|
||||||
|
import json
|
||||||
|
|
||||||
|
|
||||||
|
def test_initialize_deployment_for_pass_through_success():
    """
    A Vertex AI deployment with use_in_pass_through=True registers its
    credentials on the vertex pass-through router
    """
    service_account = json.dumps({"type": "service_account", "project_id": "test"})
    litellm_params = LiteLLM_Params(
        model="vertex_ai/test-model",
        vertex_project="test-project",
        vertex_location="us-central1",
        vertex_credentials=service_account,
        use_in_pass_through=True,
    )
    deployment = Deployment(model_name="vertex-test", litellm_params=litellm_params)
    router = Router(model_list=[])

    # Run the initialization under test
    router._initialize_deployment_for_pass_through(
        deployment=deployment,
        custom_llm_provider="vertex_ai",
        model="vertex_ai/test-model",
    )

    # The shared pass-through router should now hold the credentials
    from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import (
        vertex_pass_through_router,
    )

    stored = vertex_pass_through_router.get_vertex_credentials(
        project_id="test-project", location="us-central1"
    )
    assert stored.vertex_project == "test-project"
    assert stored.vertex_location == "us-central1"
    assert stored.vertex_credentials == json.dumps(
        {"type": "service_account", "project_id": "test"}
    )
|
||||||
|
|
||||||
|
|
||||||
|
def test_initialize_deployment_for_pass_through_missing_params():
    """
    Initialization raises ValueError when the Vertex AI project, location and
    credentials are not set on the deployment
    """
    # No vertex_project / vertex_location / vertex_credentials supplied
    litellm_params = LiteLLM_Params(
        model="vertex_ai/test-model",
        use_in_pass_through=True,
    )
    deployment = Deployment(model_name="vertex-test", litellm_params=litellm_params)
    router = Router(model_list=[])

    expected_message = (
        "vertex_project, vertex_location, and vertex_credentials must be set"
    )
    with pytest.raises(ValueError, match=expected_message):
        router._initialize_deployment_for_pass_through(
            deployment=deployment,
            custom_llm_provider="vertex_ai",
            model="vertex_ai/test-model",
        )
|
||||||
|
|
||||||
|
|
||||||
|
def test_initialize_deployment_for_pass_through_unsupported_provider():
    """
    Initialization with a provider other than vertex_ai does not raise
    """
    litellm_params = LiteLLM_Params(
        model="unsupported/test-model",
        use_in_pass_through=True,
    )
    deployment = Deployment(
        model_name="unsupported-test", litellm_params=litellm_params
    )
    router = Router(model_list=[])

    # Must not raise; the unsupported provider is only reported via the logger
    router._initialize_deployment_for_pass_through(
        deployment=deployment,
        custom_llm_provider="unsupported_provider",
        model="unsupported/test-model",
    )
|
||||||
|
|
||||||
|
|
||||||
|
def test_initialize_deployment_when_pass_through_disabled():
    """
    Initialization is a no-op when use_in_pass_through is not enabled
    """
    litellm_params = LiteLLM_Params(model="vertex_ai/test-model")
    deployment = Deployment(model_name="vertex-test", litellm_params=litellm_params)
    router = Router(model_list=[])

    # Passing means this call returns without raising, even though the
    # vertex parameters are missing
    router._initialize_deployment_for_pass_through(
        deployment=deployment,
        custom_llm_provider="vertex_ai",
        model="vertex_ai/test-model",
    )
|
||||||
|
|
||||||
|
|
||||||
|
def test_add_vertex_pass_through_deployment():
    """
    Adding a Vertex AI deployment with pass-through enabled makes its
    credentials available on the vertex pass-through router
    """
    service_account = json.dumps({"type": "service_account", "project_id": "test"})
    router = Router(model_list=[])

    # Deployment carrying Vertex AI pass-through settings
    litellm_params = LiteLLM_Params(
        model="vertex_ai/test-model",
        vertex_project="test-project",
        vertex_location="us-central1",
        vertex_credentials=service_account,
        use_in_pass_through=True,
    )
    deployment = Deployment(model_name="vertex-test", litellm_params=litellm_params)

    router.add_deployment(deployment)

    from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import (
        vertex_pass_through_router,
    )

    # Show the current state of the pass-through vertex router
    print("\n vertex_pass_through_router.deployment_key_to_vertex_credentials\n\n")
    state_dump = json.dumps(
        vertex_pass_through_router.deployment_key_to_vertex_credentials,
        indent=4,
        default=str,
    )
    print(state_dump)

    stored = vertex_pass_through_router.get_vertex_credentials(
        project_id="test-project", location="us-central1"
    )

    # The stored credentials must match what the deployment declared
    assert stored.vertex_project == "test-project"
    assert stored.vertex_location == "us-central1"
    assert stored.vertex_credentials == json.dumps(
        {"type": "service_account", "project_id": "test"}
    )
|
Loading…
Add table
Add a link
Reference in a new issue