mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 19:24:27 +00:00
fix(vertex_ai/common_utils.py): fix handling constructed url with default vertex config
This commit is contained in:
parent
b44b3bd36b
commit
06e69a414e
4 changed files with 98 additions and 17 deletions
|
@ -303,11 +303,26 @@ def get_vertex_location_from_url(url: str) -> Optional[str]:
|
||||||
return match.group(1) if match else None
|
return match.group(1) if match else None
|
||||||
|
|
||||||
|
|
||||||
|
def replace_project_and_location_in_route(
    requested_route: str, vertex_project: str, vertex_location: str
) -> str:
    """
    Replace project and location values in the route with the provided values

    Every ``/projects/<segment>/locations/<segment>/`` occurrence in
    ``requested_route`` is rewritten to use ``vertex_project`` and
    ``vertex_location``; the rest of the route is left untouched. A route
    that does not contain that pattern (including one where the location
    segment is not followed by a ``/``) is returned unchanged.

    Args:
        requested_route: The route path to rewrite.
        vertex_project: Project id to place after ``/projects/``.
        vertex_location: Location to place after ``/locations/``.

    Returns:
        The route with project and location substituted.
    """
    replacement = f"/projects/{vertex_project}/locations/{vertex_location}/"

    # Pass a callable rather than the string itself: re.sub processes
    # backslash escapes and group references in a *string* replacement, so a
    # value such as "a\1" would raise re.error and "a\n" would inject a
    # newline. A callable's return value is inserted literally.
    return re.sub(
        r"/projects/[^/]+/locations/[^/]+/",
        lambda _match: replacement,
        requested_route,
    )
|
||||||
|
|
||||||
|
|
||||||
def construct_target_url(
|
def construct_target_url(
|
||||||
base_url: str,
|
base_url: str,
|
||||||
requested_route: str,
|
requested_route: str,
|
||||||
default_vertex_location: Optional[str],
|
vertex_location: Optional[str],
|
||||||
default_vertex_project: Optional[str],
|
vertex_project: Optional[str],
|
||||||
) -> httpx.URL:
|
) -> httpx.URL:
|
||||||
"""
|
"""
|
||||||
Allow user to specify their own project id / location.
|
Allow user to specify their own project id / location.
|
||||||
|
@ -321,8 +336,12 @@ def construct_target_url(
|
||||||
"""
|
"""
|
||||||
new_base_url = httpx.URL(base_url)
|
new_base_url = httpx.URL(base_url)
|
||||||
if "locations" in requested_route: # contains the target project id + location
|
if "locations" in requested_route: # contains the target project id + location
|
||||||
updated_url = new_base_url.copy_with(path=requested_route)
|
if vertex_project and vertex_location:
|
||||||
return updated_url
|
requested_route = replace_project_and_location_in_route(
|
||||||
|
requested_route, vertex_project, vertex_location
|
||||||
|
)
|
||||||
|
return new_base_url.copy_with(path=requested_route)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
- Add endpoint version (e.g. v1beta for cachedContent, v1 for rest)
|
- Add endpoint version (e.g. v1beta for cachedContent, v1 for rest)
|
||||||
- Add default project id
|
- Add default project id
|
||||||
|
@ -333,7 +352,7 @@ def construct_target_url(
|
||||||
vertex_version = "v1beta1"
|
vertex_version = "v1beta1"
|
||||||
|
|
||||||
base_requested_route = "{}/projects/{}/locations/{}".format(
|
base_requested_route = "{}/projects/{}/locations/{}".format(
|
||||||
vertex_version, default_vertex_project, default_vertex_location
|
vertex_version, vertex_project, vertex_location
|
||||||
)
|
)
|
||||||
|
|
||||||
updated_requested_route = "/" + base_requested_route + requested_route
|
updated_requested_route = "/" + base_requested_route + requested_route
|
||||||
|
|
|
@ -463,13 +463,8 @@ async def vertex_proxy_route(
|
||||||
location=vertex_location,
|
location=vertex_location,
|
||||||
)
|
)
|
||||||
|
|
||||||
if vertex_credentials is None:
|
|
||||||
raise Exception(
|
|
||||||
f"No matching vertex credentials found, for project_id: {vertex_project}, location: {vertex_location}. No default_vertex_config set either."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Use headers from the incoming request if no vertex credentials are found
|
# Use headers from the incoming request if no vertex credentials are found
|
||||||
if vertex_credentials.vertex_project is None:
|
if vertex_credentials is None or vertex_credentials.vertex_project is None:
|
||||||
headers = dict(request.headers) or {}
|
headers = dict(request.headers) or {}
|
||||||
verbose_proxy_logger.debug(
|
verbose_proxy_logger.debug(
|
||||||
"default_vertex_config not set, incoming request headers %s", headers
|
"default_vertex_config not set, incoming request headers %s", headers
|
||||||
|
@ -518,8 +513,8 @@ async def vertex_proxy_route(
|
||||||
updated_url = construct_target_url(
|
updated_url = construct_target_url(
|
||||||
base_url=base_target_url,
|
base_url=base_target_url,
|
||||||
requested_route=encoded_endpoint,
|
requested_route=encoded_endpoint,
|
||||||
default_vertex_location=vertex_location,
|
vertex_location=vertex_location,
|
||||||
default_vertex_project=vertex_project,
|
vertex_project=vertex_project,
|
||||||
)
|
)
|
||||||
# base_url = httpx.URL(base_target_url)
|
# base_url = httpx.URL(base_target_url)
|
||||||
# updated_url = base_url.copy_with(path=encoded_endpoint)
|
# updated_url = base_url.copy_with(path=encoded_endpoint)
|
||||||
|
|
|
@ -150,19 +150,17 @@ class PassthroughEndpointRouter:
|
||||||
"""
|
"""
|
||||||
Get the vertex credentials for the given project-id, location
|
Get the vertex credentials for the given project-id, location
|
||||||
"""
|
"""
|
||||||
default_vertex_config: Optional[VertexPassThroughCredentials] = None
|
|
||||||
|
|
||||||
deployment_key = self._get_deployment_key(
|
deployment_key = self._get_deployment_key(
|
||||||
project_id=project_id,
|
project_id=project_id,
|
||||||
location=location,
|
location=location,
|
||||||
)
|
)
|
||||||
|
|
||||||
if deployment_key is None:
|
if deployment_key is None:
|
||||||
return default_vertex_config
|
return self.default_vertex_config
|
||||||
if deployment_key in self.deployment_key_to_vertex_credentials:
|
if deployment_key in self.deployment_key_to_vertex_credentials:
|
||||||
return self.deployment_key_to_vertex_credentials[deployment_key]
|
return self.deployment_key_to_vertex_credentials[deployment_key]
|
||||||
else:
|
else:
|
||||||
return default_vertex_config
|
return self.default_vertex_config
|
||||||
|
|
||||||
def _get_credential_name_for_provider(
|
def _get_credential_name_for_provider(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
import traceback
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
@ -262,3 +263,71 @@ class TestVertexAIPassThroughHandler:
|
||||||
target=f"https://{test_location}-aiplatform.googleapis.com/v1/projects/{test_project}/locations/{test_location}/publishers/google/models/gemini-1.5-flash:generateContent",
|
target=f"https://{test_location}-aiplatform.googleapis.com/v1/projects/{test_project}/locations/{test_location}/publishers/google/models/gemini-1.5-flash:generateContent",
|
||||||
custom_headers={"Authorization": f"Bearer {test_token}"},
|
custom_headers={"Authorization": f"Bearer {test_token}"},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_vertex_passthrough_with_default_credentials(self, monkeypatch):
    """
    Check that a request naming an unknown project/location is routed with
    the router's default vertex config: the target URL and the Authorization
    header handed to create_pass_through_route must carry the default values.
    """
    from litellm.proxy.pass_through_endpoints.passthrough_endpoint_router import (
        PassthroughEndpointRouter,
    )

    # Values the request is expected to end up using
    default_project = "default-project"
    default_location = "us-central1"
    default_credentials = "default-creds"

    # Values the incoming request asks for; they differ from the defaults
    request_project = "non-existing-project"
    request_location = "bad-location"
    endpoint = f"/v1/projects/{request_project}/locations/{request_location}/publishers/google/models/gemini-1.5-flash:generateContent"

    # Router whose only configuration is the default vertex config
    router = PassthroughEndpointRouter()
    router.default_vertex_config = VertexPassThroughCredentials(
        vertex_project=default_project,
        vertex_location=default_location,
        vertex_credentials=default_credentials,
    )
    monkeypatch.setattr(
        "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.passthrough_endpoint_router",
        router,
    )

    request_scope = {
        "type": "http",
        "method": "POST",
        "path": endpoint,
        "headers": {},
    }
    incoming_request = Request(scope=request_scope)
    outgoing_response = Response()

    with mock.patch(
        "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.vertex_llm_base._ensure_access_token_async"
    ) as ensure_token_mock:
        with mock.patch(
            "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.vertex_llm_base._get_token_and_url"
        ) as get_token_mock:
            with mock.patch(
                "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.create_pass_through_route"
            ) as create_route_mock:
                ensure_token_mock.return_value = ("test-auth-header", default_project)
                get_token_mock.return_value = (default_credentials, "")

                # Failures raised by the route are only reported here; the
                # assertion below is what decides whether the test passes.
                try:
                    await vertex_proxy_route(
                        endpoint=endpoint,
                        request=incoming_request,
                        fastapi_response=outgoing_response,
                    )
                except Exception as e:
                    traceback.print_exc()
                    print(f"Error: {e}")

                # The route must have been built from the default config
                create_route_mock.assert_called_once_with(
                    endpoint=endpoint,
                    target=f"https://{default_location}-aiplatform.googleapis.com/v1/projects/{default_project}/locations/{default_location}/publishers/google/models/gemini-1.5-flash:generateContent",
                    custom_headers={"Authorization": f"Bearer {default_credentials}"},
                )
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue