diff --git a/docs/my-website/docs/pass_through/vertex_ai.md b/docs/my-website/docs/pass_through/vertex_ai.md
index ce366af541..4918d889ed 100644
--- a/docs/my-website/docs/pass_through/vertex_ai.md
+++ b/docs/my-website/docs/pass_through/vertex_ai.md
@@ -15,6 +15,91 @@ Pass-through endpoints for Vertex AI - call provider-specific endpoint, in nativ
 
 Just replace `https://REGION-aiplatform.googleapis.com` with `LITELLM_PROXY_BASE_URL/vertex_ai`
 
+LiteLLM supports 3 flows for calling Vertex AI endpoints via pass-through:
+
+1. **Specific Credentials**: Admin sets passthrough credentials for a specific project/region.
+
+2. **Default Credentials**: Admin sets default credentials.
+
+3. **Client-Side Credentials**: User can send client-side credentials through to Vertex AI (default behavior - if no default or mapped credentials are found, the request is passed through directly).
+
+
+## Example Usage
+
+
+```yaml
+model_list:
+  - model_name: gemini-1.0-pro
+    litellm_params:
+      model: vertex_ai/gemini-1.0-pro
+      vertex_project: adroit-crow-413218
+      vertex_region: us-central1
+      vertex_credentials: /path/to/credentials.json
+      use_in_pass_through: true # 👈 KEY CHANGE
+```
+
+
+```yaml
+default_vertex_config:
+  vertex_project: adroit-crow-413218
+  vertex_region: us-central1
+  vertex_credentials: /path/to/credentials.json
+```
+
+
+```bash
+export DEFAULT_VERTEXAI_PROJECT="adroit-crow-413218"
+export DEFAULT_VERTEXAI_LOCATION="us-central1"
+export DEFAULT_GOOGLE_APPLICATION_CREDENTIALS="/path/to/credentials.json"
+```
+
+
+Try Gemini 2.0 Flash (curl)
+
+```
+MODEL_ID="gemini-2.0-flash-001"
+PROJECT_ID="YOUR_PROJECT_ID"
+```
+
+```bash
+curl \
+  -X POST \
+  -H "Authorization: Bearer $(gcloud auth application-default print-access-token)" \
+  -H "Content-Type: application/json" \
+  "${LITELLM_PROXY_BASE_URL}/vertex_ai/v1/projects/${PROJECT_ID}/locations/us-central1/publishers/google/models/${MODEL_ID}:streamGenerateContent" -d \
+  $'{
+    "contents": {
+      "role": "user",
+      "parts": [
+        {
+          "fileData": {
+            "mimeType": "image/jpeg",
+            "fileUri": "gs://generativeai-downloads/images/scones.jpg"
+          }
+        },
+        {
+          "text": "Describe this picture."
+        }
+      ]
+    }
+  }'
+```
+
+
+
 
 #### **Example Usage**
 
@@ -22,7 +107,7 @@ Just replace `https://REGION-aiplatform.googleapis.com` with `LITELLM_PROXY_BASE
 
 
 ```bash
-curl http://localhost:4000/vertex_ai/publishers/google/models/gemini-1.0-pro:generateContent \
+curl http://localhost:4000/vertex_ai/v1/projects/${PROJECT_ID}/locations/us-central1/publishers/google/models/${MODEL_ID}:generateContent \
   -H "Content-Type: application/json" \
   -H "x-litellm-api-key: Bearer sk-1234" \
   -d '{
@@ -101,7 +186,7 @@ litellm
 Let's call the Google AI Studio token counting endpoint
 
 ```bash
-curl http://localhost:4000/vertex-ai/publishers/google/models/gemini-1.0-pro:generateContent \
+curl http://localhost:4000/vertex-ai/v1/projects/${PROJECT_ID}/locations/us-central1/publishers/google/models/gemini-1.0-pro:generateContent \
   -H "Content-Type: application/json" \
   -H "Authorization: Bearer sk-1234" \
   -d '{
@@ -140,7 +225,7 @@ LiteLLM Proxy Server supports two methods of authentication to Vertex AI:
 
 
 ```shell
-curl http://localhost:4000/vertex_ai/publishers/google/models/gemini-1.5-flash-001:generateContent \
+curl http://localhost:4000/vertex_ai/v1/projects/${PROJECT_ID}/locations/us-central1/publishers/google/models/gemini-1.5-flash-001:generateContent \
   -H "Content-Type: application/json" \
   -H "x-litellm-api-key: Bearer sk-1234" \
   -d '{"contents":[{"role": "user", "parts":[{"text": "hi"}]}]}'
@@ -152,7 +237,7 @@ curl http://localhost:4000/vertex_ai/publishers/google/models/gemini-1.5-flash-0
 
 
 ```shell
-curl http://localhost:4000/vertex_ai/publishers/google/models/textembedding-gecko@001:predict \
+curl http://localhost:4000/vertex_ai/v1/projects/${PROJECT_ID}/locations/us-central1/publishers/google/models/textembedding-gecko@001:predict \
   -H "Content-Type: application/json" \
   -H "x-litellm-api-key: Bearer sk-1234" \
   -d '{"instances":[{"content": "gm"}]}'
@@ -162,7 +247,7 @@
 ### Imagen API
 
 ```shell
-curl http://localhost:4000/vertex_ai/publishers/google/models/imagen-3.0-generate-001:predict \
+curl http://localhost:4000/vertex_ai/v1/projects/${PROJECT_ID}/locations/us-central1/publishers/google/models/imagen-3.0-generate-001:predict \
   -H "Content-Type: application/json" \
   -H "x-litellm-api-key: Bearer sk-1234" \
   -d '{"instances":[{"prompt": "make an otter"}], "parameters": {"sampleCount": 1}}'
@@ -172,7 +257,7 @@
 ### Count Tokens API
 
 ```shell
-curl http://localhost:4000/vertex_ai/publishers/google/models/gemini-1.5-flash-001:countTokens \
+curl http://localhost:4000/vertex_ai/v1/projects/${PROJECT_ID}/locations/us-central1/publishers/google/models/gemini-1.5-flash-001:countTokens \
   -H "Content-Type: application/json" \
   -H "x-litellm-api-key: Bearer sk-1234" \
   -d '{"contents":[{"role": "user", "parts":[{"text": "hi"}]}]}'
@@ -183,7 +268,7 @@
 Create Fine Tuning Job
 
 ```shell
-curl http://localhost:4000/vertex_ai/tuningJobs \
+curl http://localhost:4000/vertex_ai/v1/projects/${PROJECT_ID}/locations/us-central1/tuningJobs \
  -H "Content-Type: application/json" \
  -H "x-litellm-api-key: Bearer sk-1234" \
  -d '{
@@ -243,7 +328,7 @@ Expected Response
 
 
 ```bash
-curl http://localhost:4000/vertex_ai/publishers/google/models/gemini-1.0-pro:generateContent \
+curl http://localhost:4000/vertex_ai/v1/projects/${PROJECT_ID}/locations/us-central1/publishers/google/models/gemini-1.0-pro:generateContent \
   -H
"Content-Type: application/json" \ -H "x-litellm-api-key: Bearer sk-1234" \ -d '{ @@ -268,7 +353,7 @@ tags: ["vertex-js-sdk", "pass-through-endpoint"] ```bash -curl http://localhost:4000/vertex-ai/publishers/google/models/gemini-1.0-pro:generateContent \ +curl http://localhost:4000/vertex_ai/v1/projects/${PROJECT_ID}/locations/us-central1/publishers/google/models/gemini-1.0-pro:generateContent \ -H "Content-Type: application/json" \ -H "x-litellm-api-key: Bearer sk-1234" \ -H "tags: vertex-js-sdk,pass-through-endpoint" \ diff --git a/litellm/llms/vertex_ai/common_utils.py b/litellm/llms/vertex_ai/common_utils.py index f7149c349a..a3f91fbacc 100644 --- a/litellm/llms/vertex_ai/common_utils.py +++ b/litellm/llms/vertex_ai/common_utils.py @@ -1,3 +1,4 @@ +import re from typing import Dict, List, Literal, Optional, Tuple, Union import httpx @@ -280,3 +281,81 @@ def _convert_vertex_datetime_to_openai_datetime(vertex_datetime: str) -> int: dt = datetime.strptime(vertex_datetime, "%Y-%m-%dT%H:%M:%S.%fZ") # Convert to Unix timestamp (seconds since epoch) return int(dt.timestamp()) + + +def get_vertex_project_id_from_url(url: str) -> Optional[str]: + """ + Get the vertex project id from the url + + `https://${LOCATION}-aiplatform.googleapis.com/v1/projects/${PROJECT_ID}/locations/${LOCATION}/publishers/google/models/${MODEL_ID}:streamGenerateContent` + """ + match = re.search(r"/projects/([^/]+)", url) + return match.group(1) if match else None + + +def get_vertex_location_from_url(url: str) -> Optional[str]: + """ + Get the vertex location from the url + + `https://${LOCATION}-aiplatform.googleapis.com/v1/projects/${PROJECT_ID}/locations/${LOCATION}/publishers/google/models/${MODEL_ID}:streamGenerateContent` + """ + match = re.search(r"/locations/([^/]+)", url) + return match.group(1) if match else None + + +def replace_project_and_location_in_route( + requested_route: str, vertex_project: str, vertex_location: str +) -> str: + """ + Replace project and location values in the route with the provided values + """ + # Replace project and location values while keeping route structure + modified_route = re.sub( + r"/projects/[^/]+/locations/[^/]+/", + f"/projects/{vertex_project}/locations/{vertex_location}/", + requested_route, + ) + return modified_route + + +def construct_target_url( + base_url: str, + requested_route: str, + vertex_location: Optional[str], + vertex_project: Optional[str], +) -> httpx.URL: + """ + Allow user to specify their own project id / location. + + If missing, use defaults + + Handle cachedContent scenario - https://github.com/BerriAI/litellm/issues/5460 + + Constructed Url: + POST https://LOCATION-aiplatform.googleapis.com/{version}/projects/PROJECT_ID/locations/LOCATION/cachedContents + """ + new_base_url = httpx.URL(base_url) + if "locations" in requested_route: # contains the target project id + location + if vertex_project and vertex_location: + requested_route = replace_project_and_location_in_route( + requested_route, vertex_project, vertex_location + ) + return new_base_url.copy_with(path=requested_route) + + """ + - Add endpoint version (e.g. 
v1beta for cachedContent, v1 for rest) + - Add default project id + - Add default location + """ + vertex_version: Literal["v1", "v1beta1"] = "v1" + if "cachedContent" in requested_route: + vertex_version = "v1beta1" + + base_requested_route = "{}/projects/{}/locations/{}".format( + vertex_version, vertex_project, vertex_location + ) + + updated_requested_route = "/" + base_requested_route + requested_route + + updated_url = new_base_url.copy_with(path=updated_requested_route) + return updated_url diff --git a/litellm/model_prices_and_context_window_backup.json b/litellm/model_prices_and_context_window_backup.json index 6131afdd51..330fe890e4 100644 --- a/litellm/model_prices_and_context_window_backup.json +++ b/litellm/model_prices_and_context_window_backup.json @@ -82,6 +82,31 @@ "supports_system_messages": true, "supports_tool_choice": true }, + "gpt-4o-search-preview-2025-03-11": { + "max_tokens": 16384, + "max_input_tokens": 128000, + "max_output_tokens": 16384, + "input_cost_per_token": 0.0000025, + "output_cost_per_token": 0.000010, + "input_cost_per_token_batches": 0.00000125, + "output_cost_per_token_batches": 0.00000500, + "cache_read_input_token_cost": 0.00000125, + "litellm_provider": "openai", + "mode": "chat", + "supports_function_calling": true, + "supports_parallel_function_calling": true, + "supports_response_schema": true, + "supports_vision": true, + "supports_prompt_caching": true, + "supports_system_messages": true, + "supports_tool_choice": true, + "supports_web_search": true, + "search_context_cost_per_query": { + "search_context_size_low": 0.030, + "search_context_size_medium": 0.035, + "search_context_size_high": 0.050 + } + }, "gpt-4o-search-preview": { "max_tokens": 16384, "max_input_tokens": 128000, @@ -232,6 +257,31 @@ "supports_system_messages": true, "supports_tool_choice": true }, + "gpt-4o-mini-search-preview-2025-03-11":{ + "max_tokens": 16384, + "max_input_tokens": 128000, + "max_output_tokens": 16384, + "input_cost_per_token": 0.00000015, + "output_cost_per_token": 0.00000060, + "input_cost_per_token_batches": 0.000000075, + "output_cost_per_token_batches": 0.00000030, + "cache_read_input_token_cost": 0.000000075, + "litellm_provider": "openai", + "mode": "chat", + "supports_function_calling": true, + "supports_parallel_function_calling": true, + "supports_response_schema": true, + "supports_vision": true, + "supports_prompt_caching": true, + "supports_system_messages": true, + "supports_tool_choice": true, + "supports_web_search": true, + "search_context_cost_per_query": { + "search_context_size_low": 0.025, + "search_context_size_medium": 0.0275, + "search_context_size_high": 0.030 + } + }, "gpt-4o-mini-search-preview": { "max_tokens": 16384, "max_input_tokens": 128000, diff --git a/litellm/proxy/pass_through_endpoints/common_utils.py b/litellm/proxy/pass_through_endpoints/common_utils.py new file mode 100644 index 0000000000..3a3783dd57 --- /dev/null +++ b/litellm/proxy/pass_through_endpoints/common_utils.py @@ -0,0 +1,16 @@ +from fastapi import Request + + +def get_litellm_virtual_key(request: Request) -> str: + """ + Extract and format API key from request headers. + Prioritizes x-litellm-api-key over Authorization header. 
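+
+    For example, a request sent with `x-litellm-api-key: sk-1234` resolves to
+    `Bearer sk-1234`, even when an `Authorization` header is also present.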
+ + + Vertex JS SDK uses `Authorization` header, we use `x-litellm-api-key` to pass litellm virtual key + + """ + litellm_api_key = request.headers.get("x-litellm-api-key") + if litellm_api_key: + return f"Bearer {litellm_api_key}" + return request.headers.get("Authorization", "") diff --git a/litellm/proxy/pass_through_endpoints/llm_passthrough_endpoints.py b/litellm/proxy/pass_through_endpoints/llm_passthrough_endpoints.py index 4724c7f9d1..c4d96b67f6 100644 --- a/litellm/proxy/pass_through_endpoints/llm_passthrough_endpoints.py +++ b/litellm/proxy/pass_through_endpoints/llm_passthrough_endpoints.py @@ -12,10 +12,13 @@ import httpx from fastapi import APIRouter, Depends, HTTPException, Request, Response import litellm +from litellm._logging import verbose_proxy_logger from litellm.constants import BEDROCK_AGENT_RUNTIME_PASS_THROUGH_ROUTES +from litellm.llms.vertex_ai.vertex_llm_base import VertexBase from litellm.proxy._types import * from litellm.proxy.auth.route_checks import RouteChecks from litellm.proxy.auth.user_api_key_auth import user_api_key_auth +from litellm.proxy.pass_through_endpoints.common_utils import get_litellm_virtual_key from litellm.proxy.pass_through_endpoints.pass_through_endpoints import ( create_pass_through_route, ) @@ -23,6 +26,7 @@ from litellm.secret_managers.main import get_secret_str from .passthrough_endpoint_router import PassthroughEndpointRouter +vertex_llm_base = VertexBase() router = APIRouter() default_vertex_config = None @@ -417,6 +421,138 @@ async def azure_proxy_route( ) +@router.api_route( + "/vertex-ai/{endpoint:path}", + methods=["GET", "POST", "PUT", "DELETE", "PATCH"], + tags=["Vertex AI Pass-through", "pass-through"], + include_in_schema=False, +) +@router.api_route( + "/vertex_ai/{endpoint:path}", + methods=["GET", "POST", "PUT", "DELETE", "PATCH"], + tags=["Vertex AI Pass-through", "pass-through"], +) +async def vertex_proxy_route( + endpoint: str, + request: Request, + fastapi_response: Response, +): + """ + Call LiteLLM proxy via Vertex AI SDK. 
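+
+    Credential resolution order: project/region-specific passthrough credentials
+    first, then the default vertex config; if neither is found, the request's own
+    headers are forwarded to Vertex AI unchanged.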
+ + [Docs](https://docs.litellm.ai/docs/pass_through/vertex_ai) + """ + from litellm.llms.vertex_ai.common_utils import ( + construct_target_url, + get_vertex_location_from_url, + get_vertex_project_id_from_url, + ) + + encoded_endpoint = httpx.URL(endpoint).path + verbose_proxy_logger.debug("requested endpoint %s", endpoint) + headers: dict = {} + api_key_to_use = get_litellm_virtual_key(request=request) + user_api_key_dict = await user_api_key_auth( + request=request, + api_key=api_key_to_use, + ) + vertex_project: Optional[str] = get_vertex_project_id_from_url(endpoint) + vertex_location: Optional[str] = get_vertex_location_from_url(endpoint) + vertex_credentials = passthrough_endpoint_router.get_vertex_credentials( + project_id=vertex_project, + location=vertex_location, + ) + + headers_passed_through = False + # Use headers from the incoming request if no vertex credentials are found + if vertex_credentials is None or vertex_credentials.vertex_project is None: + headers = dict(request.headers) or {} + headers_passed_through = True + verbose_proxy_logger.debug( + "default_vertex_config not set, incoming request headers %s", headers + ) + base_target_url = f"https://{vertex_location}-aiplatform.googleapis.com/" + headers.pop("content-length", None) + headers.pop("host", None) + else: + vertex_project = vertex_credentials.vertex_project + vertex_location = vertex_credentials.vertex_location + vertex_credentials_str = vertex_credentials.vertex_credentials + + # Construct base URL for the target endpoint + base_target_url = f"https://{vertex_location}-aiplatform.googleapis.com/" + + _auth_header, vertex_project = await vertex_llm_base._ensure_access_token_async( + credentials=vertex_credentials_str, + project_id=vertex_project, + custom_llm_provider="vertex_ai_beta", + ) + + auth_header, _ = vertex_llm_base._get_token_and_url( + model="", + auth_header=_auth_header, + gemini_api_key=None, + vertex_credentials=vertex_credentials_str, + vertex_project=vertex_project, + vertex_location=vertex_location, + stream=False, + custom_llm_provider="vertex_ai_beta", + api_base="", + ) + + headers = { + "Authorization": f"Bearer {auth_header}", + } + + request_route = encoded_endpoint + verbose_proxy_logger.debug("request_route %s", request_route) + + # Ensure endpoint starts with '/' for proper URL construction + if not encoded_endpoint.startswith("/"): + encoded_endpoint = "/" + encoded_endpoint + + # Construct the full target URL using httpx + updated_url = construct_target_url( + base_url=base_target_url, + requested_route=encoded_endpoint, + vertex_location=vertex_location, + vertex_project=vertex_project, + ) + + verbose_proxy_logger.debug("updated url %s", updated_url) + + ## check for streaming + target = str(updated_url) + is_streaming_request = False + if "stream" in str(updated_url): + is_streaming_request = True + target += "?alt=sse" + + ## CREATE PASS-THROUGH + endpoint_func = create_pass_through_route( + endpoint=endpoint, + target=target, + custom_headers=headers, + ) # dynamically construct pass-through endpoint based on incoming path + + try: + received_value = await endpoint_func( + request, + fastapi_response, + user_api_key_dict, + stream=is_streaming_request, # type: ignore + ) + except Exception as e: + if headers_passed_through: + raise Exception( + f"No credentials found on proxy for this request. 
Headers were passed through directly but request failed with error: {str(e)}" + ) + else: + raise e + + return received_value + + @router.api_route( "/openai/{endpoint:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH"], diff --git a/litellm/proxy/pass_through_endpoints/passthrough_endpoint_router.py b/litellm/proxy/pass_through_endpoints/passthrough_endpoint_router.py index adf7d0f30c..89cccfc071 100644 --- a/litellm/proxy/pass_through_endpoints/passthrough_endpoint_router.py +++ b/litellm/proxy/pass_through_endpoints/passthrough_endpoint_router.py @@ -1,7 +1,9 @@ from typing import Dict, Optional -from litellm._logging import verbose_logger +from litellm._logging import verbose_router_logger from litellm.secret_managers.main import get_secret_str +from litellm.types.llms.vertex_ai import VERTEX_CREDENTIALS_TYPES +from litellm.types.passthrough_endpoints.vertex_ai import VertexPassThroughCredentials class PassthroughEndpointRouter: @@ -11,6 +13,10 @@ class PassthroughEndpointRouter: def __init__(self): self.credentials: Dict[str, str] = {} + self.deployment_key_to_vertex_credentials: Dict[ + str, VertexPassThroughCredentials + ] = {} + self.default_vertex_config: Optional[VertexPassThroughCredentials] = None def set_pass_through_credentials( self, @@ -45,14 +51,14 @@ class PassthroughEndpointRouter: custom_llm_provider=custom_llm_provider, region_name=region_name, ) - verbose_logger.debug( + verbose_router_logger.debug( f"Pass-through llm endpoints router, looking for credentials for {credential_name}" ) if credential_name in self.credentials: - verbose_logger.debug(f"Found credentials for {credential_name}") + verbose_router_logger.debug(f"Found credentials for {credential_name}") return self.credentials[credential_name] else: - verbose_logger.debug( + verbose_router_logger.debug( f"No credentials found for {credential_name}, looking for env variable" ) _env_variable_name = ( @@ -62,6 +68,100 @@ class PassthroughEndpointRouter: ) return get_secret_str(_env_variable_name) + def _get_vertex_env_vars(self) -> VertexPassThroughCredentials: + """ + Helper to get vertex pass through config from environment variables + + The following environment variables are used: + - DEFAULT_VERTEXAI_PROJECT (project id) + - DEFAULT_VERTEXAI_LOCATION (location) + - DEFAULT_GOOGLE_APPLICATION_CREDENTIALS (path to credentials file) + """ + return VertexPassThroughCredentials( + vertex_project=get_secret_str("DEFAULT_VERTEXAI_PROJECT"), + vertex_location=get_secret_str("DEFAULT_VERTEXAI_LOCATION"), + vertex_credentials=get_secret_str("DEFAULT_GOOGLE_APPLICATION_CREDENTIALS"), + ) + + def set_default_vertex_config(self, config: Optional[dict] = None): + """Sets vertex configuration from provided config and/or environment variables + + Args: + config (Optional[dict]): Configuration dictionary + Example: { + "vertex_project": "my-project-123", + "vertex_location": "us-central1", + "vertex_credentials": "os.environ/GOOGLE_CREDS" + } + """ + # Initialize config dictionary if None + if config is None: + self.default_vertex_config = self._get_vertex_env_vars() + return + + if isinstance(config, dict): + for key, value in config.items(): + if isinstance(value, str) and value.startswith("os.environ/"): + config[key] = get_secret_str(value) + + self.default_vertex_config = VertexPassThroughCredentials(**config) + + def add_vertex_credentials( + self, + project_id: str, + location: str, + vertex_credentials: VERTEX_CREDENTIALS_TYPES, + ): + """ + Add the vertex credentials for the given project-id, location + """ + 
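+        # Credentials are stored under the key "{project_id}-{location}", so
+        # lookups can match the project/location parsed from the request URL.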
+ deployment_key = self._get_deployment_key( + project_id=project_id, + location=location, + ) + if deployment_key is None: + verbose_router_logger.debug( + "No deployment key found for project-id, location" + ) + return + vertex_pass_through_credentials = VertexPassThroughCredentials( + vertex_project=project_id, + vertex_location=location, + vertex_credentials=vertex_credentials, + ) + self.deployment_key_to_vertex_credentials[deployment_key] = ( + vertex_pass_through_credentials + ) + + def _get_deployment_key( + self, project_id: Optional[str], location: Optional[str] + ) -> Optional[str]: + """ + Get the deployment key for the given project-id, location + """ + if project_id is None or location is None: + return None + return f"{project_id}-{location}" + + def get_vertex_credentials( + self, project_id: Optional[str], location: Optional[str] + ) -> Optional[VertexPassThroughCredentials]: + """ + Get the vertex credentials for the given project-id, location + """ + deployment_key = self._get_deployment_key( + project_id=project_id, + location=location, + ) + + if deployment_key is None: + return self.default_vertex_config + if deployment_key in self.deployment_key_to_vertex_credentials: + return self.deployment_key_to_vertex_credentials[deployment_key] + else: + return self.default_vertex_config + def _get_credential_name_for_provider( self, custom_llm_provider: str, diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index ae1c8d18af..ee2a906200 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -235,6 +235,9 @@ from litellm.proxy.openai_files_endpoints.files_endpoints import ( router as openai_files_router, ) from litellm.proxy.openai_files_endpoints.files_endpoints import set_files_config +from litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints import ( + passthrough_endpoint_router, +) from litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints import ( router as llm_passthrough_router, ) @@ -272,8 +275,6 @@ from litellm.proxy.utils import ( from litellm.proxy.vertex_ai_endpoints.langfuse_endpoints import ( router as langfuse_router, ) -from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import router as vertex_router -from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import set_default_vertex_config from litellm.router import ( AssistantsTypedDict, Deployment, @@ -2115,7 +2116,9 @@ class ProxyConfig: ## default config for vertex ai routes default_vertex_config = config.get("default_vertex_config", None) - set_default_vertex_config(config=default_vertex_config) + passthrough_endpoint_router.set_default_vertex_config( + config=default_vertex_config + ) ## ROUTER SETTINGS (e.g. routing_strategy, ...) 
router_settings = config.get("router_settings", None) @@ -8161,7 +8164,6 @@ app.include_router(batches_router) app.include_router(rerank_router) app.include_router(fine_tuning_router) app.include_router(credential_router) -app.include_router(vertex_router) app.include_router(llm_passthrough_router) app.include_router(anthropic_router) app.include_router(langfuse_router) diff --git a/litellm/proxy/vertex_ai_endpoints/vertex_endpoints.py b/litellm/proxy/vertex_ai_endpoints/vertex_endpoints.py deleted file mode 100644 index 7444e3d1c1..0000000000 --- a/litellm/proxy/vertex_ai_endpoints/vertex_endpoints.py +++ /dev/null @@ -1,274 +0,0 @@ -import traceback -from typing import Optional - -import httpx -from fastapi import APIRouter, HTTPException, Request, Response, status - -import litellm -from litellm._logging import verbose_proxy_logger -from litellm.fine_tuning.main import vertex_fine_tuning_apis_instance -from litellm.proxy._types import * -from litellm.proxy.auth.user_api_key_auth import user_api_key_auth -from litellm.proxy.pass_through_endpoints.pass_through_endpoints import ( - create_pass_through_route, -) -from litellm.secret_managers.main import get_secret_str -from litellm.types.passthrough_endpoints.vertex_ai import * - -from .vertex_passthrough_router import VertexPassThroughRouter - -router = APIRouter() -vertex_pass_through_router = VertexPassThroughRouter() - -default_vertex_config: VertexPassThroughCredentials = VertexPassThroughCredentials() - - -def _get_vertex_env_vars() -> VertexPassThroughCredentials: - """ - Helper to get vertex pass through config from environment variables - - The following environment variables are used: - - DEFAULT_VERTEXAI_PROJECT (project id) - - DEFAULT_VERTEXAI_LOCATION (location) - - DEFAULT_GOOGLE_APPLICATION_CREDENTIALS (path to credentials file) - """ - return VertexPassThroughCredentials( - vertex_project=get_secret_str("DEFAULT_VERTEXAI_PROJECT"), - vertex_location=get_secret_str("DEFAULT_VERTEXAI_LOCATION"), - vertex_credentials=get_secret_str("DEFAULT_GOOGLE_APPLICATION_CREDENTIALS"), - ) - - -def set_default_vertex_config(config: Optional[dict] = None): - """Sets vertex configuration from provided config and/or environment variables - - Args: - config (Optional[dict]): Configuration dictionary - Example: { - "vertex_project": "my-project-123", - "vertex_location": "us-central1", - "vertex_credentials": "os.environ/GOOGLE_CREDS" - } - """ - global default_vertex_config - - # Initialize config dictionary if None - if config is None: - default_vertex_config = _get_vertex_env_vars() - return - - if isinstance(config, dict): - for key, value in config.items(): - if isinstance(value, str) and value.startswith("os.environ/"): - config[key] = litellm.get_secret(value) - - _set_default_vertex_config(VertexPassThroughCredentials(**config)) - - -def _set_default_vertex_config( - vertex_pass_through_credentials: VertexPassThroughCredentials, -): - global default_vertex_config - default_vertex_config = vertex_pass_through_credentials - - -def exception_handler(e: Exception): - verbose_proxy_logger.error( - "litellm.proxy.proxy_server.v1/projects/tuningJobs(): Exception occurred - {}".format( - str(e) - ) - ) - verbose_proxy_logger.debug(traceback.format_exc()) - if isinstance(e, HTTPException): - return ProxyException( - message=getattr(e, "message", str(e.detail)), - type=getattr(e, "type", "None"), - param=getattr(e, "param", "None"), - code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST), - ) - else: - error_msg = f"{str(e)}" - return 
ProxyException( - message=getattr(e, "message", error_msg), - type=getattr(e, "type", "None"), - param=getattr(e, "param", "None"), - code=getattr(e, "status_code", 500), - ) - - -def construct_target_url( - base_url: str, - requested_route: str, - default_vertex_location: Optional[str], - default_vertex_project: Optional[str], -) -> httpx.URL: - """ - Allow user to specify their own project id / location. - - If missing, use defaults - - Handle cachedContent scenario - https://github.com/BerriAI/litellm/issues/5460 - - Constructed Url: - POST https://LOCATION-aiplatform.googleapis.com/{version}/projects/PROJECT_ID/locations/LOCATION/cachedContents - """ - new_base_url = httpx.URL(base_url) - if "locations" in requested_route: # contains the target project id + location - updated_url = new_base_url.copy_with(path=requested_route) - return updated_url - """ - - Add endpoint version (e.g. v1beta for cachedContent, v1 for rest) - - Add default project id - - Add default location - """ - vertex_version: Literal["v1", "v1beta1"] = "v1" - if "cachedContent" in requested_route: - vertex_version = "v1beta1" - - base_requested_route = "{}/projects/{}/locations/{}".format( - vertex_version, default_vertex_project, default_vertex_location - ) - - updated_requested_route = "/" + base_requested_route + requested_route - - updated_url = new_base_url.copy_with(path=updated_requested_route) - return updated_url - - -@router.api_route( - "/vertex-ai/{endpoint:path}", - methods=["GET", "POST", "PUT", "DELETE", "PATCH"], - tags=["Vertex AI Pass-through", "pass-through"], - include_in_schema=False, -) -@router.api_route( - "/vertex_ai/{endpoint:path}", - methods=["GET", "POST", "PUT", "DELETE", "PATCH"], - tags=["Vertex AI Pass-through", "pass-through"], -) -async def vertex_proxy_route( - endpoint: str, - request: Request, - fastapi_response: Response, -): - """ - Call LiteLLM proxy via Vertex AI SDK. 
- - [Docs](https://docs.litellm.ai/docs/pass_through/vertex_ai) - """ - encoded_endpoint = httpx.URL(endpoint).path - verbose_proxy_logger.debug("requested endpoint %s", endpoint) - headers: dict = {} - api_key_to_use = get_litellm_virtual_key(request=request) - user_api_key_dict = await user_api_key_auth( - request=request, - api_key=api_key_to_use, - ) - - vertex_project: Optional[str] = ( - VertexPassThroughRouter._get_vertex_project_id_from_url(endpoint) - ) - vertex_location: Optional[str] = ( - VertexPassThroughRouter._get_vertex_location_from_url(endpoint) - ) - vertex_credentials = vertex_pass_through_router.get_vertex_credentials( - project_id=vertex_project, - location=vertex_location, - ) - - # Use headers from the incoming request if no vertex credentials are found - if vertex_credentials.vertex_project is None: - headers = dict(request.headers) or {} - verbose_proxy_logger.debug( - "default_vertex_config not set, incoming request headers %s", headers - ) - base_target_url = f"https://{vertex_location}-aiplatform.googleapis.com/" - headers.pop("content-length", None) - headers.pop("host", None) - else: - vertex_project = vertex_credentials.vertex_project - vertex_location = vertex_credentials.vertex_location - vertex_credentials_str = vertex_credentials.vertex_credentials - - # Construct base URL for the target endpoint - base_target_url = f"https://{vertex_location}-aiplatform.googleapis.com/" - - _auth_header, vertex_project = ( - await vertex_fine_tuning_apis_instance._ensure_access_token_async( - credentials=vertex_credentials_str, - project_id=vertex_project, - custom_llm_provider="vertex_ai_beta", - ) - ) - - auth_header, _ = vertex_fine_tuning_apis_instance._get_token_and_url( - model="", - auth_header=_auth_header, - gemini_api_key=None, - vertex_credentials=vertex_credentials_str, - vertex_project=vertex_project, - vertex_location=vertex_location, - stream=False, - custom_llm_provider="vertex_ai_beta", - api_base="", - ) - - headers = { - "Authorization": f"Bearer {auth_header}", - } - - request_route = encoded_endpoint - verbose_proxy_logger.debug("request_route %s", request_route) - - # Ensure endpoint starts with '/' for proper URL construction - if not encoded_endpoint.startswith("/"): - encoded_endpoint = "/" + encoded_endpoint - - # Construct the full target URL using httpx - updated_url = construct_target_url( - base_url=base_target_url, - requested_route=encoded_endpoint, - default_vertex_location=vertex_location, - default_vertex_project=vertex_project, - ) - # base_url = httpx.URL(base_target_url) - # updated_url = base_url.copy_with(path=encoded_endpoint) - - verbose_proxy_logger.debug("updated url %s", updated_url) - - ## check for streaming - target = str(updated_url) - is_streaming_request = False - if "stream" in str(updated_url): - is_streaming_request = True - target += "?alt=sse" - - ## CREATE PASS-THROUGH - endpoint_func = create_pass_through_route( - endpoint=endpoint, - target=target, - custom_headers=headers, - ) # dynamically construct pass-through endpoint based on incoming path - received_value = await endpoint_func( - request, - fastapi_response, - user_api_key_dict, - stream=is_streaming_request, # type: ignore - ) - - return received_value - - -def get_litellm_virtual_key(request: Request) -> str: - """ - Extract and format API key from request headers. - Prioritizes x-litellm-api-key over Authorization header. 
- - - Vertex JS SDK uses `Authorization` header, we use `x-litellm-api-key` to pass litellm virtual key - - """ - litellm_api_key = request.headers.get("x-litellm-api-key") - if litellm_api_key: - return f"Bearer {litellm_api_key}" - return request.headers.get("Authorization", "") diff --git a/litellm/proxy/vertex_ai_endpoints/vertex_passthrough_router.py b/litellm/proxy/vertex_ai_endpoints/vertex_passthrough_router.py deleted file mode 100644 index 0273a62047..0000000000 --- a/litellm/proxy/vertex_ai_endpoints/vertex_passthrough_router.py +++ /dev/null @@ -1,121 +0,0 @@ -import json -import re -from typing import Dict, Optional - -from litellm._logging import verbose_proxy_logger -from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import ( - VertexPassThroughCredentials, -) -from litellm.types.llms.vertex_ai import VERTEX_CREDENTIALS_TYPES - - -class VertexPassThroughRouter: - """ - Vertex Pass Through Router for Vertex AI pass-through endpoints - - - - if request specifies a project-id, location -> use credentials corresponding to the project-id, location - - if request does not specify a project-id, location -> use credentials corresponding to the DEFAULT_VERTEXAI_PROJECT, DEFAULT_VERTEXAI_LOCATION - """ - - def __init__(self): - """ - Initialize the VertexPassThroughRouter - Stores the vertex credentials for each deployment key - ``` - { - "project_id-location": VertexPassThroughCredentials, - "adroit-crow-us-central1": VertexPassThroughCredentials, - } - ``` - """ - self.deployment_key_to_vertex_credentials: Dict[ - str, VertexPassThroughCredentials - ] = {} - pass - - def get_vertex_credentials( - self, project_id: Optional[str], location: Optional[str] - ) -> VertexPassThroughCredentials: - """ - Get the vertex credentials for the given project-id, location - """ - from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import ( - default_vertex_config, - ) - - deployment_key = self._get_deployment_key( - project_id=project_id, - location=location, - ) - if deployment_key is None: - return default_vertex_config - if deployment_key in self.deployment_key_to_vertex_credentials: - return self.deployment_key_to_vertex_credentials[deployment_key] - else: - return default_vertex_config - - def add_vertex_credentials( - self, - project_id: str, - location: str, - vertex_credentials: VERTEX_CREDENTIALS_TYPES, - ): - """ - Add the vertex credentials for the given project-id, location - """ - from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import ( - _set_default_vertex_config, - ) - - deployment_key = self._get_deployment_key( - project_id=project_id, - location=location, - ) - if deployment_key is None: - verbose_proxy_logger.debug( - "No deployment key found for project-id, location" - ) - return - vertex_pass_through_credentials = VertexPassThroughCredentials( - vertex_project=project_id, - vertex_location=location, - vertex_credentials=vertex_credentials, - ) - self.deployment_key_to_vertex_credentials[deployment_key] = ( - vertex_pass_through_credentials - ) - verbose_proxy_logger.debug( - f"self.deployment_key_to_vertex_credentials: {json.dumps(self.deployment_key_to_vertex_credentials, indent=4, default=str)}" - ) - _set_default_vertex_config(vertex_pass_through_credentials) - - def _get_deployment_key( - self, project_id: Optional[str], location: Optional[str] - ) -> Optional[str]: - """ - Get the deployment key for the given project-id, location - """ - if project_id is None or location is None: - return None - return f"{project_id}-{location}" - - @staticmethod - def 
_get_vertex_project_id_from_url(url: str) -> Optional[str]: - """ - Get the vertex project id from the url - - `https://${LOCATION}-aiplatform.googleapis.com/v1/projects/${PROJECT_ID}/locations/${LOCATION}/publishers/google/models/${MODEL_ID}:streamGenerateContent` - """ - match = re.search(r"/projects/([^/]+)", url) - return match.group(1) if match else None - - @staticmethod - def _get_vertex_location_from_url(url: str) -> Optional[str]: - """ - Get the vertex location from the url - - `https://${LOCATION}-aiplatform.googleapis.com/v1/projects/${PROJECT_ID}/locations/${LOCATION}/publishers/google/models/${MODEL_ID}:streamGenerateContent` - """ - match = re.search(r"/locations/([^/]+)", url) - return match.group(1) if match else None diff --git a/litellm/router.py b/litellm/router.py index 49ce815a24..524c55539f 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -4495,11 +4495,11 @@ class Router: Each provider uses diff .env vars for pass-through endpoints, this helper uses the deployment credentials to set the .env vars for pass-through endpoints """ if deployment.litellm_params.use_in_pass_through is True: - if custom_llm_provider == "vertex_ai": - from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import ( - vertex_pass_through_router, - ) + from litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints import ( + passthrough_endpoint_router, + ) + if custom_llm_provider == "vertex_ai": if ( deployment.litellm_params.vertex_project is None or deployment.litellm_params.vertex_location is None @@ -4508,16 +4508,12 @@ class Router: raise ValueError( "vertex_project, vertex_location, and vertex_credentials must be set in litellm_params for pass-through endpoints" ) - vertex_pass_through_router.add_vertex_credentials( + passthrough_endpoint_router.add_vertex_credentials( project_id=deployment.litellm_params.vertex_project, location=deployment.litellm_params.vertex_location, vertex_credentials=deployment.litellm_params.vertex_credentials, ) else: - from litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints import ( - passthrough_endpoint_router, - ) - passthrough_endpoint_router.set_pass_through_credentials( custom_llm_provider=custom_llm_provider, api_base=deployment.litellm_params.api_base, diff --git a/model_prices_and_context_window.json b/model_prices_and_context_window.json index 6131afdd51..330fe890e4 100644 --- a/model_prices_and_context_window.json +++ b/model_prices_and_context_window.json @@ -82,6 +82,31 @@ "supports_system_messages": true, "supports_tool_choice": true }, + "gpt-4o-search-preview-2025-03-11": { + "max_tokens": 16384, + "max_input_tokens": 128000, + "max_output_tokens": 16384, + "input_cost_per_token": 0.0000025, + "output_cost_per_token": 0.000010, + "input_cost_per_token_batches": 0.00000125, + "output_cost_per_token_batches": 0.00000500, + "cache_read_input_token_cost": 0.00000125, + "litellm_provider": "openai", + "mode": "chat", + "supports_function_calling": true, + "supports_parallel_function_calling": true, + "supports_response_schema": true, + "supports_vision": true, + "supports_prompt_caching": true, + "supports_system_messages": true, + "supports_tool_choice": true, + "supports_web_search": true, + "search_context_cost_per_query": { + "search_context_size_low": 0.030, + "search_context_size_medium": 0.035, + "search_context_size_high": 0.050 + } + }, "gpt-4o-search-preview": { "max_tokens": 16384, "max_input_tokens": 128000, @@ -232,6 +257,31 @@ "supports_system_messages": true, "supports_tool_choice": true }, + 
"gpt-4o-mini-search-preview-2025-03-11":{ + "max_tokens": 16384, + "max_input_tokens": 128000, + "max_output_tokens": 16384, + "input_cost_per_token": 0.00000015, + "output_cost_per_token": 0.00000060, + "input_cost_per_token_batches": 0.000000075, + "output_cost_per_token_batches": 0.00000030, + "cache_read_input_token_cost": 0.000000075, + "litellm_provider": "openai", + "mode": "chat", + "supports_function_calling": true, + "supports_parallel_function_calling": true, + "supports_response_schema": true, + "supports_vision": true, + "supports_prompt_caching": true, + "supports_system_messages": true, + "supports_tool_choice": true, + "supports_web_search": true, + "search_context_cost_per_query": { + "search_context_size_low": 0.025, + "search_context_size_medium": 0.0275, + "search_context_size_high": 0.030 + } + }, "gpt-4o-mini-search-preview": { "max_tokens": 16384, "max_input_tokens": 128000, diff --git a/tests/litellm/llms/vertex_ai/test_vertex_ai_common_utils.py b/tests/litellm/llms/vertex_ai/test_vertex_ai_common_utils.py new file mode 100644 index 0000000000..e89355443f --- /dev/null +++ b/tests/litellm/llms/vertex_ai/test_vertex_ai_common_utils.py @@ -0,0 +1,43 @@ +import os +import sys +from unittest.mock import MagicMock, call, patch + +import pytest + +sys.path.insert( + 0, os.path.abspath("../../..") +) # Adds the parent directory to the system path + +import litellm +from litellm.llms.vertex_ai.common_utils import ( + get_vertex_location_from_url, + get_vertex_project_id_from_url, +) + + +@pytest.mark.asyncio +async def test_get_vertex_project_id_from_url(): + """Test _get_vertex_project_id_from_url with various URLs""" + # Test with valid URL + url = "https://us-central1-aiplatform.googleapis.com/v1/projects/test-project/locations/us-central1/publishers/google/models/gemini-pro:streamGenerateContent" + project_id = get_vertex_project_id_from_url(url) + assert project_id == "test-project" + + # Test with invalid URL + url = "https://invalid-url.com" + project_id = get_vertex_project_id_from_url(url) + assert project_id is None + + +@pytest.mark.asyncio +async def test_get_vertex_location_from_url(): + """Test _get_vertex_location_from_url with various URLs""" + # Test with valid URL + url = "https://us-central1-aiplatform.googleapis.com/v1/projects/test-project/locations/us-central1/publishers/google/models/gemini-pro:streamGenerateContent" + location = get_vertex_location_from_url(url) + assert location == "us-central1" + + # Test with invalid URL + url = "https://invalid-url.com" + location = get_vertex_location_from_url(url) + assert location is None diff --git a/tests/litellm/proxy/pass_through_endpoints/test_llm_pass_through_endpoints.py b/tests/litellm/proxy/pass_through_endpoints/test_llm_pass_through_endpoints.py index 2f5ce85de7..da08dea605 100644 --- a/tests/litellm/proxy/pass_through_endpoints/test_llm_pass_through_endpoints.py +++ b/tests/litellm/proxy/pass_through_endpoints/test_llm_pass_through_endpoints.py @@ -1,7 +1,9 @@ import json import os import sys -from unittest.mock import MagicMock, patch +import traceback +from unittest import mock +from unittest.mock import AsyncMock, MagicMock, Mock, patch import httpx import pytest @@ -17,7 +19,9 @@ from litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints import ( BaseOpenAIPassThroughHandler, RouteChecks, create_pass_through_route, + vertex_proxy_route, ) +from litellm.types.passthrough_endpoints.vertex_ai import VertexPassThroughCredentials class TestBaseOpenAIPassThroughHandler: @@ -176,3 +180,279 @@ 
class TestBaseOpenAIPassThroughHandler: print(f"query_params: {call_kwargs['query_params']}") assert call_kwargs["stream"] is False assert call_kwargs["query_params"] == {"model": "gpt-4"} + + +class TestVertexAIPassThroughHandler: + """ + Case 1: User set passthrough credentials - confirm credentials used. + + Case 2: User set default credentials, no exact passthrough credentials - confirm default credentials used. + + Case 3: No default credentials, no mapped credentials - request passed through directly. + """ + + @pytest.mark.asyncio + async def test_vertex_passthrough_with_credentials(self, monkeypatch): + """ + Test that when passthrough credentials are set, they are correctly used in the request + """ + from litellm.proxy.pass_through_endpoints.passthrough_endpoint_router import ( + PassthroughEndpointRouter, + ) + + vertex_project = "test-project" + vertex_location = "us-central1" + vertex_credentials = "test-creds" + + pass_through_router = PassthroughEndpointRouter() + + pass_through_router.add_vertex_credentials( + project_id=vertex_project, + location=vertex_location, + vertex_credentials=vertex_credentials, + ) + + monkeypatch.setattr( + "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.passthrough_endpoint_router", + pass_through_router, + ) + + endpoint = f"/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/gemini-1.5-flash:generateContent" + + # Mock request + mock_request = Request( + scope={ + "type": "http", + "method": "POST", + "path": endpoint, + "headers": [ + (b"Authorization", b"Bearer test-creds"), + (b"Content-Type", b"application/json"), + ], + } + ) + + # Mock response + mock_response = Response() + + # Mock vertex credentials + test_project = vertex_project + test_location = vertex_location + test_token = vertex_credentials + + with mock.patch( + "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.vertex_llm_base._ensure_access_token_async" + ) as mock_ensure_token, mock.patch( + "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.vertex_llm_base._get_token_and_url" + ) as mock_get_token, mock.patch( + "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.create_pass_through_route" + ) as mock_create_route: + mock_ensure_token.return_value = ("test-auth-header", test_project) + mock_get_token.return_value = (test_token, "") + + # Call the route + try: + await vertex_proxy_route( + endpoint=endpoint, + request=mock_request, + fastapi_response=mock_response, + ) + except Exception as e: + print(f"Error: {e}") + + # Verify create_pass_through_route was called with correct arguments + mock_create_route.assert_called_once_with( + endpoint=endpoint, + target=f"https://{test_location}-aiplatform.googleapis.com/v1/projects/{test_project}/locations/{test_location}/publishers/google/models/gemini-1.5-flash:generateContent", + custom_headers={"Authorization": f"Bearer {test_token}"}, + ) + + @pytest.mark.parametrize( + "initial_endpoint", + [ + "publishers/google/models/gemini-1.5-flash:generateContent", + "v1/projects/bad-project/locations/bad-location/publishers/google/models/gemini-1.5-flash:generateContent", + ], + ) + @pytest.mark.asyncio + async def test_vertex_passthrough_with_default_credentials( + self, monkeypatch, initial_endpoint + ): + """ + Test that when no passthrough credentials are set, default credentials are used in the request + """ + from litellm.proxy.pass_through_endpoints.passthrough_endpoint_router import ( + PassthroughEndpointRouter, + ) + + # Setup default 
credentials
+        default_project = "default-project"
+        default_location = "us-central1"
+        default_credentials = "default-creds"
+
+        pass_through_router = PassthroughEndpointRouter()
+        pass_through_router.default_vertex_config = VertexPassThroughCredentials(
+            vertex_project=default_project,
+            vertex_location=default_location,
+            vertex_credentials=default_credentials,
+        )
+
+        monkeypatch.setattr(
+            "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.passthrough_endpoint_router",
+            pass_through_router,
+        )
+
+        # Use different project/location in request than the default
+        endpoint = initial_endpoint
+
+        mock_request = Request(
+            scope={
+                "type": "http",
+                "method": "POST",
+                "path": f"/vertex_ai/{endpoint}",
+                "headers": {},
+            }
+        )
+        mock_response = Response()
+
+        with mock.patch(
+            "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.vertex_llm_base._ensure_access_token_async"
+        ) as mock_ensure_token, mock.patch(
+            "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.vertex_llm_base._get_token_and_url"
+        ) as mock_get_token, mock.patch(
+            "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.create_pass_through_route"
+        ) as mock_create_route:
+            mock_ensure_token.return_value = ("test-auth-header", default_project)
+            mock_get_token.return_value = (default_credentials, "")
+
+            try:
+                await vertex_proxy_route(
+                    endpoint=endpoint,
+                    request=mock_request,
+                    fastapi_response=mock_response,
+                )
+            except Exception as e:
+                traceback.print_exc()
+                print(f"Error: {e}")
+
+            # Verify default credentials were used
+            mock_create_route.assert_called_once_with(
+                endpoint=endpoint,
+                target=f"https://{default_location}-aiplatform.googleapis.com/v1/projects/{default_project}/locations/{default_location}/publishers/google/models/gemini-1.5-flash:generateContent",
+                custom_headers={"Authorization": f"Bearer {default_credentials}"},
+            )
+
+    @pytest.mark.asyncio
+    async def test_vertex_passthrough_with_no_default_credentials(self, monkeypatch):
+        """
+        Test that when no matching or default credentials are set, the incoming
+        request headers are passed through to Vertex AI directly
+        """
+        from litellm.proxy.pass_through_endpoints.passthrough_endpoint_router import (
+            PassthroughEndpointRouter,
+        )
+
+        vertex_project = "my-project"
+        vertex_location = "us-central1"
+        vertex_credentials = "test-creds"
+
+        test_project = "test-project"
+        test_location = "test-location"
+        test_token = "test-creds"
+
+        pass_through_router = PassthroughEndpointRouter()
+
+        pass_through_router.add_vertex_credentials(
+            project_id=vertex_project,
+            location=vertex_location,
+            vertex_credentials=vertex_credentials,
+        )
+
+        monkeypatch.setattr(
+            "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.passthrough_endpoint_router",
+            pass_through_router,
+        )
+
+        endpoint = f"/v1/projects/{test_project}/locations/{test_location}/publishers/google/models/gemini-1.5-flash:generateContent"
+
+        # Mock request
+        mock_request = Request(
+            scope={
+                "type": "http",
+                "method": "POST",
+                "path": endpoint,
+                "headers": [
+                    (b"authorization", b"Bearer test-creds"),
+                ],
+            }
+        )
+
+        # Mock response
+        mock_response = Response()
+
+        with mock.patch(
+            "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.vertex_llm_base._ensure_access_token_async"
+        ) as mock_ensure_token, mock.patch(
+            "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.vertex_llm_base._get_token_and_url"
+        ) as mock_get_token, mock.patch(
"litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.create_pass_through_route" + ) as mock_create_route: + mock_ensure_token.return_value = ("test-auth-header", test_project) + mock_get_token.return_value = (test_token, "") + + # Call the route + try: + await vertex_proxy_route( + endpoint=endpoint, + request=mock_request, + fastapi_response=mock_response, + ) + except Exception as e: + traceback.print_exc() + print(f"Error: {e}") + + # Verify create_pass_through_route was called with correct arguments + mock_create_route.assert_called_once_with( + endpoint=endpoint, + target=f"https://{test_location}-aiplatform.googleapis.com/v1/projects/{test_project}/locations/{test_location}/publishers/google/models/gemini-1.5-flash:generateContent", + custom_headers={"authorization": f"Bearer {test_token}"}, + ) + + @pytest.mark.asyncio + async def test_async_vertex_proxy_route_api_key_auth(self): + """ + Critical + + This is how Vertex AI JS SDK will Auth to Litellm Proxy + """ + # Mock dependencies + mock_request = Mock() + mock_request.headers = {"x-litellm-api-key": "test-key-123"} + mock_request.method = "POST" + mock_response = Mock() + + with patch( + "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.user_api_key_auth" + ) as mock_auth: + mock_auth.return_value = {"api_key": "test-key-123"} + + with patch( + "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.create_pass_through_route" + ) as mock_pass_through: + mock_pass_through.return_value = AsyncMock( + return_value={"status": "success"} + ) + + # Call the function + result = await vertex_proxy_route( + endpoint="v1/projects/test-project/locations/us-central1/publishers/google/models/gemini-1.5-pro:generateContent", + request=mock_request, + fastapi_response=mock_response, + ) + + # Verify user_api_key_auth was called with the correct Bearer token + mock_auth.assert_called_once() + call_args = mock_auth.call_args[1] + assert call_args["api_key"] == "Bearer test-key-123" diff --git a/tests/litellm/proxy/pass_through_endpoints/test_passthrough_endpoints_common_utils.py b/tests/litellm/proxy/pass_through_endpoints/test_passthrough_endpoints_common_utils.py new file mode 100644 index 0000000000..bd8c5f5a99 --- /dev/null +++ b/tests/litellm/proxy/pass_through_endpoints/test_passthrough_endpoints_common_utils.py @@ -0,0 +1,44 @@ +import json +import os +import sys +import traceback +from unittest import mock +from unittest.mock import MagicMock, patch + +import httpx +import pytest +from fastapi import Request, Response +from fastapi.testclient import TestClient + +sys.path.insert( + 0, os.path.abspath("../../../..") +) # Adds the parent directory to the system path + +from unittest.mock import Mock + +from litellm.proxy.pass_through_endpoints.common_utils import get_litellm_virtual_key + + +@pytest.mark.asyncio +async def test_get_litellm_virtual_key(): + """ + Test that the get_litellm_virtual_key function correctly handles the API key authentication + """ + # Test with x-litellm-api-key + mock_request = Mock() + mock_request.headers = {"x-litellm-api-key": "test-key-123"} + result = get_litellm_virtual_key(mock_request) + assert result == "Bearer test-key-123" + + # Test with Authorization header + mock_request.headers = {"Authorization": "Bearer auth-key-456"} + result = get_litellm_virtual_key(mock_request) + assert result == "Bearer auth-key-456" + + # Test with both headers (x-litellm-api-key should take precedence) + mock_request.headers = { + "x-litellm-api-key": "test-key-123", + "Authorization": 
"Bearer auth-key-456", + } + result = get_litellm_virtual_key(mock_request) + assert result == "Bearer test-key-123" diff --git a/tests/pass_through_unit_tests/test_pass_through_unit_tests.py b/tests/pass_through_unit_tests/test_pass_through_unit_tests.py index db0a647e41..cb9db00324 100644 --- a/tests/pass_through_unit_tests/test_pass_through_unit_tests.py +++ b/tests/pass_through_unit_tests/test_pass_through_unit_tests.py @@ -339,9 +339,6 @@ def test_pass_through_routes_support_all_methods(): from litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints import ( router as llm_router, ) - from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import ( - router as vertex_router, - ) # Expected HTTP methods expected_methods = {"GET", "POST", "PUT", "DELETE", "PATCH"} @@ -361,7 +358,6 @@ def test_pass_through_routes_support_all_methods(): # Check both routers check_router_methods(llm_router) - check_router_methods(vertex_router) def test_is_bedrock_agent_runtime_route(): diff --git a/tests/pass_through_unit_tests/test_unit_test_passthrough_router.py b/tests/pass_through_unit_tests/test_unit_test_passthrough_router.py index 6e8296876a..8e016b68d0 100644 --- a/tests/pass_through_unit_tests/test_unit_test_passthrough_router.py +++ b/tests/pass_through_unit_tests/test_unit_test_passthrough_router.py @@ -11,6 +11,7 @@ from unittest.mock import patch from litellm.proxy.pass_through_endpoints.passthrough_endpoint_router import ( PassthroughEndpointRouter, ) +from litellm.types.passthrough_endpoints.vertex_ai import VertexPassThroughCredentials passthrough_endpoint_router = PassthroughEndpointRouter() @@ -132,3 +133,185 @@ class TestPassthroughEndpointRouter(unittest.TestCase): ), "COHERE_API_KEY", ) + + def test_get_deployment_key(self): + """Test _get_deployment_key with various inputs""" + router = PassthroughEndpointRouter() + + # Test with valid inputs + key = router._get_deployment_key("test-project", "us-central1") + assert key == "test-project-us-central1" + + # Test with None values + key = router._get_deployment_key(None, "us-central1") + assert key is None + + key = router._get_deployment_key("test-project", None) + assert key is None + + key = router._get_deployment_key(None, None) + assert key is None + + def test_add_vertex_credentials(self): + """Test add_vertex_credentials functionality""" + router = PassthroughEndpointRouter() + + # Test adding valid credentials + router.add_vertex_credentials( + project_id="test-project", + location="us-central1", + vertex_credentials='{"credentials": "test-creds"}', + ) + + assert "test-project-us-central1" in router.deployment_key_to_vertex_credentials + creds = router.deployment_key_to_vertex_credentials["test-project-us-central1"] + assert creds.vertex_project == "test-project" + assert creds.vertex_location == "us-central1" + assert creds.vertex_credentials == '{"credentials": "test-creds"}' + + # Test adding with None values + router.add_vertex_credentials( + project_id=None, + location=None, + vertex_credentials='{"credentials": "test-creds"}', + ) + # Should not add None values + assert len(router.deployment_key_to_vertex_credentials) == 1 + + def test_default_credentials(self): + """ + Test get_vertex_credentials with stored credentials. + + Tests if default credentials are used if set. 
+    def test_default_credentials(self):
+        """
+        Test get_vertex_credentials location matching.
+
+        Credentials stored for one project/location are not returned for a different location.
+
+        With no default config set, the lookup returns None.
+        """
+        router = PassthroughEndpointRouter()
+        router.add_vertex_credentials(
+            project_id="test-project",
+            location="us-central1",
+            vertex_credentials='{"credentials": "test-creds"}',
+        )
+
+        creds = router.get_vertex_credentials(
+            project_id="test-project", location="us-central2"
+        )
+
+        assert creds is None
+
+    def test_get_vertex_env_vars(self):
+        """Test that _get_vertex_env_vars correctly reads environment variables"""
+        # Set environment variables for the test
+        os.environ["DEFAULT_VERTEXAI_PROJECT"] = "test-project-123"
+        os.environ["DEFAULT_VERTEXAI_LOCATION"] = "us-central1"
+        os.environ["DEFAULT_GOOGLE_APPLICATION_CREDENTIALS"] = "/path/to/creds"
+
+        try:
+            result = self.router._get_vertex_env_vars()
+            print(result)
+
+            # Verify the result
+            assert isinstance(result, VertexPassThroughCredentials)
+            assert result.vertex_project == "test-project-123"
+            assert result.vertex_location == "us-central1"
+            assert result.vertex_credentials == "/path/to/creds"
+
+        finally:
+            # Clean up environment variables
+            del os.environ["DEFAULT_VERTEXAI_PROJECT"]
+            del os.environ["DEFAULT_VERTEXAI_LOCATION"]
+            del os.environ["DEFAULT_GOOGLE_APPLICATION_CREDENTIALS"]
+
+    def test_set_default_vertex_config(self):
+        """Test set_default_vertex_config with various inputs"""
+        # Test with None config - set environment variables first
+        os.environ["DEFAULT_VERTEXAI_PROJECT"] = "env-project"
+        os.environ["DEFAULT_VERTEXAI_LOCATION"] = "env-location"
+        os.environ["DEFAULT_GOOGLE_APPLICATION_CREDENTIALS"] = "env-creds"
+        os.environ["GOOGLE_CREDS"] = "secret-creds"
+
+        try:
+            # Test with None config
+            self.router.set_default_vertex_config()
+
+            assert self.router.default_vertex_config.vertex_project == "env-project"
+            assert self.router.default_vertex_config.vertex_location == "env-location"
+            assert self.router.default_vertex_config.vertex_credentials == "env-creds"
+
+            # Test with valid config.yaml settings on vertex_config
+            test_config = {
+                "vertex_project": "my-project-123",
+                "vertex_location": "us-central1",
+                "vertex_credentials": "path/to/creds",
+            }
+            self.router.set_default_vertex_config(test_config)
+
+            assert self.router.default_vertex_config.vertex_project == "my-project-123"
+            assert self.router.default_vertex_config.vertex_location == "us-central1"
+            assert (
+                self.router.default_vertex_config.vertex_credentials == "path/to/creds"
+            )
+
+            # Test with environment variable reference
+            test_config = {
+                "vertex_project": "my-project-123",
+                "vertex_location": "us-central1",
+                "vertex_credentials": "os.environ/GOOGLE_CREDS",
+            }
+            self.router.set_default_vertex_config(test_config)
+
+            assert (
+                self.router.default_vertex_config.vertex_credentials == "secret-creds"
+            )
+
+        finally:
+            # Clean up environment variables
+            del os.environ["DEFAULT_VERTEXAI_PROJECT"]
+            del os.environ["DEFAULT_VERTEXAI_LOCATION"]
+            del os.environ["DEFAULT_GOOGLE_APPLICATION_CREDENTIALS"]
+            del os.environ["GOOGLE_CREDS"]
+
+    def test_vertex_passthrough_router_init(self):
+        """Test PassthroughEndpointRouter initialization"""
+        router = PassthroughEndpointRouter()
+        assert isinstance(router.deployment_key_to_vertex_credentials, dict)
+        assert len(router.deployment_key_to_vertex_credentials) == 0
+
+    def test_get_vertex_credentials_none(self):
+        """Test get_vertex_credentials falls back to the default config"""
+        router = PassthroughEndpointRouter()
+
+        router.set_default_vertex_config(
+            config={
+                "vertex_project": None,
+                "vertex_location": None,
+                "vertex_credentials": None,
+            }
+        )
+
+        # Test with None project_id and location - should return default config
+        creds = router.get_vertex_credentials(None, None)
+        assert isinstance(creds, VertexPassThroughCredentials)
+
+        # Test with valid project_id and location but no stored credentials
+        creds = router.get_vertex_credentials("test-project", "us-central1")
+        assert isinstance(creds, VertexPassThroughCredentials)
+        assert creds.vertex_project is None
+        assert creds.vertex_location is None
+        assert creds.vertex_credentials is None
+
+    def test_get_vertex_credentials_stored(self):
+        """Test get_vertex_credentials with stored credentials"""
+        router = PassthroughEndpointRouter()
+        router.add_vertex_credentials(
+            project_id="test-project",
+            location="us-central1",
+            vertex_credentials='{"credentials": "test-creds"}',
+        )
+
+        creds = router.get_vertex_credentials(
+            project_id="test-project", location="us-central1"
+        )
+        assert creds.vertex_project == "test-project"
+        assert creds.vertex_location == "us-central1"
+        assert creds.vertex_credentials == '{"credentials": "test-creds"}'
diff --git a/tests/pass_through_unit_tests/test_unit_test_vertex_pass_through.py b/tests/pass_through_unit_tests/test_unit_test_vertex_pass_through.py
deleted file mode 100644
index ba5dfa33a8..0000000000
--- a/tests/pass_through_unit_tests/test_unit_test_vertex_pass_through.py
+++ /dev/null
@@ -1,294 +0,0 @@
-import json
-import os
-import sys
-from datetime import datetime
-from unittest.mock import AsyncMock, Mock, patch
-
-sys.path.insert(
-    0, os.path.abspath("../..")
-)  # Adds the parent directory to the system-path
-
-
-import httpx
-import pytest
-import litellm
-from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
-
-
-from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import (
-    get_litellm_virtual_key,
-    vertex_proxy_route,
-    _get_vertex_env_vars,
-    set_default_vertex_config,
-    VertexPassThroughCredentials,
-    default_vertex_config,
-)
-from litellm.proxy.vertex_ai_endpoints.vertex_passthrough_router import (
-    VertexPassThroughRouter,
-)
-
-
-@pytest.mark.asyncio
-async def test_get_litellm_virtual_key():
-    """
-    Test that the get_litellm_virtual_key function correctly handles the API key authentication
-    """
-    # Test with x-litellm-api-key
-    mock_request = Mock()
-    mock_request.headers = {"x-litellm-api-key": "test-key-123"}
-    result = get_litellm_virtual_key(mock_request)
-    assert result == "Bearer test-key-123"
-
-    # Test with Authorization header
-    mock_request.headers = {"Authorization": "Bearer auth-key-456"}
-    result = get_litellm_virtual_key(mock_request)
-    assert result == "Bearer auth-key-456"
-
-    # Test with both headers (x-litellm-api-key should take precedence)
-    mock_request.headers = {
-        "x-litellm-api-key": "test-key-123",
-        "Authorization": "Bearer auth-key-456",
-    }
-    result = get_litellm_virtual_key(mock_request)
-    assert result == "Bearer test-key-123"
-
-
-@pytest.mark.asyncio
-async def test_async_vertex_proxy_route_api_key_auth():
-    """
-    Critical
-
-    This is how Vertex AI JS SDK will Auth to Litellm Proxy
-    """
-    # Mock dependencies
-    mock_request = Mock()
-    mock_request.headers = {"x-litellm-api-key": "test-key-123"}
-    mock_request.method = "POST"
-    mock_response = Mock()
-
-    with patch(
-        "litellm.proxy.vertex_ai_endpoints.vertex_endpoints.user_api_key_auth"
-    ) as mock_auth:
-        mock_auth.return_value = {"api_key": "test-key-123"}
-
-        with patch(
-            "litellm.proxy.vertex_ai_endpoints.vertex_endpoints.create_pass_through_route"
-        ) as mock_pass_through:
-            mock_pass_through.return_value = AsyncMock(
-                return_value={"status": "success"}
-            )
-
-            # Call the function
-            result = await vertex_proxy_route(
-                endpoint="v1/projects/test-project/locations/us-central1/publishers/google/models/gemini-1.5-pro:generateContent",
-                request=mock_request,
-                fastapi_response=mock_response,
-            )
-
-            # Verify user_api_key_auth was called with the correct Bearer token
-            mock_auth.assert_called_once()
-            call_args = mock_auth.call_args[1]
-            assert call_args["api_key"] == "Bearer test-key-123"
-
-
-@pytest.mark.asyncio
-async def test_get_vertex_env_vars():
-    """Test that _get_vertex_env_vars correctly reads environment variables"""
-    # Set environment variables for the test
-    os.environ["DEFAULT_VERTEXAI_PROJECT"] = "test-project-123"
-    os.environ["DEFAULT_VERTEXAI_LOCATION"] = "us-central1"
-    os.environ["DEFAULT_GOOGLE_APPLICATION_CREDENTIALS"] = "/path/to/creds"
-
-    try:
-        result = _get_vertex_env_vars()
-        print(result)
-
-        # Verify the result
-        assert isinstance(result, VertexPassThroughCredentials)
-        assert result.vertex_project == "test-project-123"
-        assert result.vertex_location == "us-central1"
-        assert result.vertex_credentials == "/path/to/creds"
-
-    finally:
-        # Clean up environment variables
-        del os.environ["DEFAULT_VERTEXAI_PROJECT"]
-        del os.environ["DEFAULT_VERTEXAI_LOCATION"]
-        del os.environ["DEFAULT_GOOGLE_APPLICATION_CREDENTIALS"]
-
-
-@pytest.mark.asyncio
-async def test_set_default_vertex_config():
-    """Test set_default_vertex_config with various inputs"""
-    # Test with None config - set environment variables first
-    os.environ["DEFAULT_VERTEXAI_PROJECT"] = "env-project"
-    os.environ["DEFAULT_VERTEXAI_LOCATION"] = "env-location"
-    os.environ["DEFAULT_GOOGLE_APPLICATION_CREDENTIALS"] = "env-creds"
-    os.environ["GOOGLE_CREDS"] = "secret-creds"
-
-    try:
-        # Test with None config
-        set_default_vertex_config()
-        from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import (
-            default_vertex_config,
-        )
-
-        assert default_vertex_config.vertex_project == "env-project"
-        assert default_vertex_config.vertex_location == "env-location"
-        assert default_vertex_config.vertex_credentials == "env-creds"
-
-        # Test with valid config.yaml settings on vertex_config
-        test_config = {
-            "vertex_project": "my-project-123",
-            "vertex_location": "us-central1",
-            "vertex_credentials": "path/to/creds",
-        }
-        set_default_vertex_config(test_config)
-        from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import (
-            default_vertex_config,
-        )
-
-        assert default_vertex_config.vertex_project == "my-project-123"
-        assert default_vertex_config.vertex_location == "us-central1"
-        assert default_vertex_config.vertex_credentials == "path/to/creds"
-
-        # Test with environment variable reference
-        test_config = {
-            "vertex_project": "my-project-123",
-            "vertex_location": "us-central1",
-            "vertex_credentials": "os.environ/GOOGLE_CREDS",
-        }
-        set_default_vertex_config(test_config)
-        from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import (
-            default_vertex_config,
-        )
-
-        assert default_vertex_config.vertex_credentials == "secret-creds"
-
-    finally:
-        # Clean up environment variables
-        del os.environ["DEFAULT_VERTEXAI_PROJECT"]
-        del os.environ["DEFAULT_VERTEXAI_LOCATION"]
-        del os.environ["DEFAULT_GOOGLE_APPLICATION_CREDENTIALS"]
-        del os.environ["GOOGLE_CREDS"]
-
-
-@pytest.mark.asyncio
-async def test_vertex_passthrough_router_init():
-    """Test VertexPassThroughRouter initialization"""
-    router = VertexPassThroughRouter()
-    assert isinstance(router.deployment_key_to_vertex_credentials, dict)
-    assert len(router.deployment_key_to_vertex_credentials) == 0
-
-
-@pytest.mark.asyncio
-async def test_get_vertex_credentials_none():
-    """Test get_vertex_credentials with various inputs"""
-    from litellm.proxy.vertex_ai_endpoints import vertex_endpoints
-
-    setattr(vertex_endpoints, "default_vertex_config", VertexPassThroughCredentials())
-    router = VertexPassThroughRouter()
-
-    # Test with None project_id and location - should return default config
-    creds = router.get_vertex_credentials(None, None)
-    assert isinstance(creds, VertexPassThroughCredentials)
-
-    # Test with valid project_id and location but no stored credentials
-    creds = router.get_vertex_credentials("test-project", "us-central1")
-    assert isinstance(creds, VertexPassThroughCredentials)
-    assert creds.vertex_project is None
-    assert creds.vertex_location is None
-    assert creds.vertex_credentials is None
-
-
-@pytest.mark.asyncio
-async def test_get_vertex_credentials_stored():
-    """Test get_vertex_credentials with stored credentials"""
-    router = VertexPassThroughRouter()
-    router.add_vertex_credentials(
-        project_id="test-project",
-        location="us-central1",
-        vertex_credentials='{"credentials": "test-creds"}',
-    )
-
-    creds = router.get_vertex_credentials(
-        project_id="test-project", location="us-central1"
-    )
-    assert creds.vertex_project == "test-project"
-    assert creds.vertex_location == "us-central1"
-    assert creds.vertex_credentials == '{"credentials": "test-creds"}'
-
-
-@pytest.mark.asyncio
-async def test_add_vertex_credentials():
-    """Test add_vertex_credentials functionality"""
-    router = VertexPassThroughRouter()
-
-    # Test adding valid credentials
-    router.add_vertex_credentials(
-        project_id="test-project",
-        location="us-central1",
-        vertex_credentials='{"credentials": "test-creds"}',
-    )
-
-    assert "test-project-us-central1" in router.deployment_key_to_vertex_credentials
-    creds = router.deployment_key_to_vertex_credentials["test-project-us-central1"]
-    assert creds.vertex_project == "test-project"
-    assert creds.vertex_location == "us-central1"
-    assert creds.vertex_credentials == '{"credentials": "test-creds"}'
-
-    # Test adding with None values
-    router.add_vertex_credentials(
-        project_id=None,
-        location=None,
-        vertex_credentials='{"credentials": "test-creds"}',
-    )
-    # Should not add None values
-    assert len(router.deployment_key_to_vertex_credentials) == 1
-
-
-@pytest.mark.asyncio
-async def test_get_deployment_key():
-    """Test _get_deployment_key with various inputs"""
-    router = VertexPassThroughRouter()
-
-    # Test with valid inputs
-    key = router._get_deployment_key("test-project", "us-central1")
-    assert key == "test-project-us-central1"
-
-    # Test with None values
-    key = router._get_deployment_key(None, "us-central1")
-    assert key is None
-
-    key = router._get_deployment_key("test-project", None)
-    assert key is None
-
-    key = router._get_deployment_key(None, None)
-    assert key is None
-
-
-@pytest.mark.asyncio
-async def test_get_vertex_project_id_from_url():
-    """Test _get_vertex_project_id_from_url with various URLs"""
-    # Test with valid URL
-    url = "https://us-central1-aiplatform.googleapis.com/v1/projects/test-project/locations/us-central1/publishers/google/models/gemini-pro:streamGenerateContent"
-    project_id = VertexPassThroughRouter._get_vertex_project_id_from_url(url)
-    assert project_id == "test-project"
-
-    # Test with invalid URL
-    url = "https://invalid-url.com"
-    project_id = VertexPassThroughRouter._get_vertex_project_id_from_url(url)
-    assert project_id is None
-
-
-@pytest.mark.asyncio
-async def test_get_vertex_location_from_url():
-    """Test _get_vertex_location_from_url with various URLs"""
-    # Test with valid URL
-    url = "https://us-central1-aiplatform.googleapis.com/v1/projects/test-project/locations/us-central1/publishers/google/models/gemini-pro:streamGenerateContent"
-    location = VertexPassThroughRouter._get_vertex_location_from_url(url)
-    assert location == "us-central1"
-
-    # Test with invalid URL
-    url = "https://invalid-url.com"
-    location = VertexPassThroughRouter._get_vertex_location_from_url(url)
-    assert location is None
diff --git a/tests/proxy_admin_ui_tests/test_route_check_unit_tests.py b/tests/proxy_admin_ui_tests/test_route_check_unit_tests.py
index 7f5ed297ca..937eb6f298 100644
--- a/tests/proxy_admin_ui_tests/test_route_check_unit_tests.py
+++ b/tests/proxy_admin_ui_tests/test_route_check_unit_tests.py
@@ -30,9 +30,6 @@ from litellm.proxy._types import LiteLLM_UserTable, LitellmUserRoles, UserAPIKey
 from litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints import (
     router as llm_passthrough_router,
 )
-from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import (
-    router as vertex_router,
-)
 
 # Replace the actual hash_token function with our mock
 import litellm.proxy.auth.route_checks
@@ -96,7 +93,7 @@ def test_is_llm_api_route():
     assert RouteChecks.is_llm_api_route("/key/regenerate/82akk800000000jjsk") is False
     assert RouteChecks.is_llm_api_route("/key/82akk800000000jjsk/delete") is False
 
-    all_llm_api_routes = vertex_router.routes + llm_passthrough_router.routes
+    all_llm_api_routes = llm_passthrough_router.routes
 
     # check all routes in llm_passthrough_router, ensure they are considered llm api routes
     for route in all_llm_api_routes:
diff --git a/tests/router_unit_tests/test_router_adding_deployments.py b/tests/router_unit_tests/test_router_adding_deployments.py
index fca3f147e5..53fe7347d3 100644
--- a/tests/router_unit_tests/test_router_adding_deployments.py
+++ b/tests/router_unit_tests/test_router_adding_deployments.py
@@ -36,11 +36,11 @@ def test_initialize_deployment_for_pass_through_success():
     )
 
     # Verify the credentials were properly set
-    from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import (
-        vertex_pass_through_router,
+    from litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints import (
+        passthrough_endpoint_router,
     )
 
-    vertex_creds = vertex_pass_through_router.get_vertex_credentials(
+    vertex_creds = passthrough_endpoint_router.get_vertex_credentials(
         project_id="test-project", location="us-central1"
     )
     assert vertex_creds.vertex_project == "test-project"
@@ -123,21 +123,21 @@ def test_add_vertex_pass_through_deployment():
     router.add_deployment(deployment)
 
     # Get the vertex credentials from the router
-    from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import (
-        vertex_pass_through_router,
+    from litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints import (
+        passthrough_endpoint_router,
     )
 
     # current state of pass-through vertex router
     print("\n vertex_pass_through_router.deployment_key_to_vertex_credentials\n\n")
     print(
         json.dumps(
-            vertex_pass_through_router.deployment_key_to_vertex_credentials,
+            passthrough_endpoint_router.deployment_key_to_vertex_credentials,
             indent=4,
            default=str,
         )
     )
 
-    vertex_creds = vertex_pass_through_router.get_vertex_credentials(
+    vertex_creds = passthrough_endpoint_router.get_vertex_credentials(
         project_id="test-project", location="us-central1"
     )
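
Reviewer note: the lookup order these tests lock in is the three-flow resolution described earlier in this change, specific stored credentials first, then the admin default, then client-side passthrough. Below is a minimal sketch of that order; it is a toy illustration under assumed simplified types, not the litellm implementation, and only the attribute and method names mirror `PassthroughEndpointRouter`.

```python
# Toy sketch of the credential-resolution order exercised by the tests above.
# NOT the litellm source: the class name ToyPassthroughRouter and all internals
# here are illustrative assumptions.
from dataclasses import dataclass
from typing import Dict, Optional


@dataclass
class VertexPassThroughCredentials:
    vertex_project: Optional[str] = None
    vertex_location: Optional[str] = None
    vertex_credentials: Optional[str] = None


class ToyPassthroughRouter:
    def __init__(self) -> None:
        self.deployment_key_to_vertex_credentials: Dict[
            str, VertexPassThroughCredentials
        ] = {}
        self.default_vertex_config: Optional[VertexPassThroughCredentials] = None

    def _get_deployment_key(
        self, project_id: Optional[str], location: Optional[str]
    ) -> Optional[str]:
        # Both parts are required: a partial key could shadow other deployments.
        if project_id is None or location is None:
            return None
        return f"{project_id}-{location}"

    def get_vertex_credentials(
        self, project_id: Optional[str], location: Optional[str]
    ) -> Optional[VertexPassThroughCredentials]:
        # 1. Specific credentials stored for this exact project/region.
        key = self._get_deployment_key(project_id, location)
        if key is not None and key in self.deployment_key_to_vertex_credentials:
            return self.deployment_key_to_vertex_credentials[key]
        # 2. Admin-set default credentials, if configured.
        if self.default_vertex_config is not None:
            return self.default_vertex_config
        # 3. Nothing stored: return None so the request is passed through with
        #    the caller's own (client-side) credentials.
        return None


# Usage mirroring test_default_credentials: a us-central2 lookup matches no
# stored deployment key and no default is set, so it falls through to None.
router = ToyPassthroughRouter()
router.deployment_key_to_vertex_credentials["test-project-us-central1"] = (
    VertexPassThroughCredentials(
        "test-project", "us-central1", '{"credentials": "test-creds"}'
    )
)
assert router.get_vertex_credentials("test-project", "us-central2") is None
```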