Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 03:04:13 +00:00)
(Feat) pass through vertex - allow using credentials defined on litellm router for vertex pass through (#8100)
* test_add_vertex_pass_through_deployment
* VertexPassThroughRouter
* fix use_in_pass_through
* VertexPassThroughRouter
* fix vertex_credentials
* allow using _initialize_deployment_for_pass_through
* test_add_vertex_pass_through_deployment
* _set_default_vertex_config
* fix verbose_proxy_logger
* fix use_in_pass_through
* fix _get_token_and_url
* test_get_vertex_location_from_url
* test_get_vertex_credentials_none
* run pt unit testing again
* fix add_vertex_credentials
* test_adding_deployments.py
* rename file
Parent: 892581ffc3
Commit: b6d61ec22b
7 changed files with 490 additions and 19 deletions
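Per the commit title, this change lets the Vertex pass-through endpoint reuse credentials that are already defined on the litellm router, instead of relying only on the single global default_vertex_config or on headers supplied by the caller. A minimal sketch of the router-side setup this targets is below. The vertex_project, vertex_location and vertex_credentials names come from the diff, and use_in_pass_through comes from the commit message; the surrounding model_list / litellm_params shape is the usual litellm router config and is an assumption here, not something shown in this commit.

```python
import litellm

# Hypothetical deployment entry: a Vertex model registered on the litellm
# router whose credentials should also be usable by the pass-through routes.
# The exact schema is assumed; only the field names are taken from this commit.
router = litellm.Router(
    model_list=[
        {
            "model_name": "gemini-1.5-pro",
            "litellm_params": {
                "model": "vertex_ai/gemini-1.5-pro",
                "vertex_project": "adroit-crow-413218",  # example project id from the diff
                "vertex_location": "us-central1",
                "vertex_credentials": "/path/to/service_account.json",
                "use_in_pass_through": True,  # flag named in the commit message
            },
        }
    ]
)
```

With a deployment like this registered, the route changes below can parse the project and location out of the incoming URL and look up the matching credentials instead of requiring them on every request.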
```diff
@@ -15,7 +15,10 @@ from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
 from litellm.secret_managers.main import get_secret_str
 from litellm.types.passthrough_endpoints.vertex_ai import *
 
+from .vertex_passthrough_router import VertexPassThroughRouter
+
 router = APIRouter()
+vertex_pass_through_router = VertexPassThroughRouter()
 
 default_vertex_config: VertexPassThroughCredentials = VertexPassThroughCredentials()
 
```
```diff
@@ -59,7 +62,14 @@ def set_default_vertex_config(config: Optional[dict] = None):
             if isinstance(value, str) and value.startswith("os.environ/"):
                 config[key] = litellm.get_secret(value)
 
-    default_vertex_config = VertexPassThroughCredentials(**config)
+    _set_default_vertex_config(VertexPassThroughCredentials(**config))
+
+
+def _set_default_vertex_config(
+    vertex_pass_through_credentials: VertexPassThroughCredentials,
+):
+    global default_vertex_config
+    default_vertex_config = vertex_pass_through_credentials
 
 
 def exception_handler(e: Exception):
```
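As context for the hunk above: set_default_vertex_config resolves any string value of the form "os.environ/<VAR>" before building VertexPassThroughCredentials. The real code calls litellm.get_secret, which may also consult a configured secret manager; the sketch below uses plain os.environ only so it stays self-contained and runnable.

```python
import os

def resolve_env_references(config: dict) -> dict:
    """Resolve "os.environ/<VAR>" values in a vertex config dict (illustrative only)."""
    for key, value in config.items():
        if isinstance(value, str) and value.startswith("os.environ/"):
            env_var = value.removeprefix("os.environ/")
            config[key] = os.environ.get(env_var)
    return config

# Hypothetical values for demonstration.
os.environ["VERTEX_SERVICE_ACCOUNT_JSON"] = "/path/to/service_account.json"
print(
    resolve_env_references(
        {
            "vertex_project": "my-project",
            "vertex_location": "us-central1",
            "vertex_credentials": "os.environ/VERTEX_SERVICE_ACCOUNT_JSON",
        }
    )
)
```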
```diff
@@ -147,9 +157,6 @@ async def vertex_proxy_route(
     [Docs](https://docs.litellm.ai/docs/pass_through/vertex_ai)
     """
     encoded_endpoint = httpx.URL(endpoint).path
-
-    import re
-
     verbose_proxy_logger.debug("requested endpoint %s", endpoint)
     headers: dict = {}
     api_key_to_use = get_litellm_virtual_key(request=request)
```
```diff
@@ -158,31 +165,37 @@ async def vertex_proxy_route(
         api_key=api_key_to_use,
     )
 
-    vertex_project = None
-    vertex_location = None
-    # Use headers from the incoming request if default_vertex_config is not set
-    if default_vertex_config.vertex_project is None:
+    vertex_project: Optional[str] = (
+        VertexPassThroughRouter._get_vertex_project_id_from_url(endpoint)
+    )
+    vertex_location: Optional[str] = (
+        VertexPassThroughRouter._get_vertex_location_from_url(endpoint)
+    )
+    vertex_credentials = vertex_pass_through_router.get_vertex_credentials(
+        project_id=vertex_project,
+        location=vertex_location,
+    )
+
+    # Use headers from the incoming request if no vertex credentials are found
+    if vertex_credentials.vertex_project is None:
         headers = dict(request.headers) or {}
         verbose_proxy_logger.debug(
             "default_vertex_config not set, incoming request headers %s", headers
         )
-        # extract location from endpoint, endpoint
-        # "v1beta1/projects/adroit-crow-413218/locations/us-central1/publishers/google/models/gemini-1.5-pro:generateContent"
-        match = re.search(r"/locations/([^/]+)", endpoint)
-        vertex_location = match.group(1) if match else None
         base_target_url = f"https://{vertex_location}-aiplatform.googleapis.com/"
         headers.pop("content-length", None)
         headers.pop("host", None)
     else:
-        vertex_project = default_vertex_config.vertex_project
-        vertex_location = default_vertex_config.vertex_location
-        vertex_credentials = default_vertex_config.vertex_credentials
+        vertex_project = vertex_credentials.vertex_project
+        vertex_location = vertex_credentials.vertex_location
+        vertex_credentials_str = vertex_credentials.vertex_credentials
 
         # Construct base URL for the target endpoint
         base_target_url = f"https://{vertex_location}-aiplatform.googleapis.com/"
 
         _auth_header, vertex_project = (
             await vertex_fine_tuning_apis_instance._ensure_access_token_async(
-                credentials=vertex_credentials,
+                credentials=vertex_credentials_str,
                 project_id=vertex_project,
                 custom_llm_provider="vertex_ai_beta",
             )
```
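The hunk above drops the inline regex that only recovered the location and delegates URL parsing to VertexPassThroughRouter helpers that recover both the project and the location. A standalone sketch of that parsing is below: the /locations/ pattern is copied from the removed line, while the /projects/ pattern is an assumption about what _get_vertex_project_id_from_url likely does, not code taken from this commit.

```python
import re
from typing import Optional

def get_vertex_location_from_url(endpoint: str) -> Optional[str]:
    # Pattern taken from the removed inline code above.
    match = re.search(r"/locations/([^/]+)", endpoint)
    return match.group(1) if match else None

def get_vertex_project_id_from_url(endpoint: str) -> Optional[str]:
    # Assumed analogue for the project id; the real helper may differ.
    match = re.search(r"/projects/([^/]+)", endpoint)
    return match.group(1) if match else None

# Example endpoint from the removed comment in the diff.
endpoint = (
    "v1beta1/projects/adroit-crow-413218/locations/us-central1"
    "/publishers/google/models/gemini-1.5-pro:generateContent"
)
assert get_vertex_project_id_from_url(endpoint) == "adroit-crow-413218"
assert get_vertex_location_from_url(endpoint) == "us-central1"
```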
```diff
@@ -192,7 +205,7 @@ async def vertex_proxy_route(
         model="",
         auth_header=_auth_header,
         gemini_api_key=None,
-        vertex_credentials=vertex_credentials,
+        vertex_credentials=vertex_credentials_str,
         vertex_project=vertex_project,
         vertex_location=vertex_location,
         stream=False,
```
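The lookup itself, vertex_pass_through_router.get_vertex_credentials(project_id=..., location=...), is what makes router-defined credentials reachable from this route, but its implementation is not part of the hunks shown here. The following is therefore only an illustrative stand-in, assuming a store keyed by project and location with an empty-credentials fallback, which mirrors how the route above falls back to the caller's own headers when vertex_credentials.vertex_project is None.

```python
from dataclasses import dataclass
from typing import Dict, Optional

@dataclass
class PassThroughCredentials:
    # Field names taken from the diff; this is not the real litellm class.
    vertex_project: Optional[str] = None
    vertex_location: Optional[str] = None
    vertex_credentials: Optional[str] = None

class CredentialStore:
    """Toy stand-in for the credential lookup done by VertexPassThroughRouter."""

    def __init__(self) -> None:
        self._creds: Dict[str, PassThroughCredentials] = {}

    def add(self, creds: PassThroughCredentials) -> None:
        self._creds[f"{creds.vertex_project}-{creds.vertex_location}"] = creds

    def get(
        self, project_id: Optional[str], location: Optional[str]
    ) -> PassThroughCredentials:
        # Fall back to empty credentials when nothing matches.
        return self._creds.get(f"{project_id}-{location}", PassThroughCredentials())

store = CredentialStore()
store.add(
    PassThroughCredentials(
        vertex_project="adroit-crow-413218",
        vertex_location="us-central1",
        vertex_credentials="/path/to/service_account.json",
    )
)
print(store.get("adroit-crow-413218", "us-central1").vertex_credentials)
```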