Merge branch 'main' into litellm_web_search_2

commit 3a454d00df (19 changed files with 1099 additions and 731 deletions)
@@ -15,6 +15,91 @@ Pass-through endpoints for Vertex AI - call provider-specific endpoint, in native format

Just replace `https://REGION-aiplatform.googleapis.com` with `LITELLM_PROXY_BASE_URL/vertex_ai`

LiteLLM supports 3 flows for calling Vertex AI endpoints via pass-through:

1. **Specific Credentials**: Admin sets passthrough credentials for a specific project/region.

2. **Default Credentials**: Admin sets default credentials.

3. **Client-Side Credentials**: User can send client-side credentials through to Vertex AI (default behavior - if no default or mapped credentials are found, the request is passed through directly).

## Example Usage

<Tabs>
<TabItem value="specific_credentials" label="Specific Project/Region">

```yaml
model_list:
  - model_name: gemini-1.0-pro
    litellm_params:
      model: vertex_ai/gemini-1.0-pro
      vertex_project: adroit-crow-413218
      vertex_region: us-central1
      vertex_credentials: /path/to/credentials.json
      use_in_pass_through: true # 👈 KEY CHANGE
```

</TabItem>
<TabItem value="default_credentials" label="Default Credentials">

<Tabs>
<TabItem value="yaml" label="Set in config.yaml">

```yaml
default_vertex_config:
  vertex_project: adroit-crow-413218
  vertex_region: us-central1
  vertex_credentials: /path/to/credentials.json
```

</TabItem>
<TabItem value="env_var" label="Set in environment variables">

```bash
export DEFAULT_VERTEXAI_PROJECT="adroit-crow-413218"
export DEFAULT_VERTEXAI_LOCATION="us-central1"
export DEFAULT_GOOGLE_APPLICATION_CREDENTIALS="/path/to/credentials.json"
```

</TabItem>
</Tabs>

</TabItem>
<TabItem value="client_credentials" label="Client Credentials">

Try Gemini 2.0 Flash (curl):

```bash
MODEL_ID="gemini-2.0-flash-001"
PROJECT_ID="YOUR_PROJECT_ID"
```

```bash
curl \
  -X POST \
  -H "Authorization: Bearer $(gcloud auth application-default print-access-token)" \
  -H "Content-Type: application/json" \
  "${LITELLM_PROXY_BASE_URL}/vertex_ai/v1/projects/${PROJECT_ID}/locations/us-central1/publishers/google/models/${MODEL_ID}:streamGenerateContent" -d \
  $'{
    "contents": {
      "role": "user",
      "parts": [
        {
          "fileData": {
            "mimeType": "image/png",
            "fileUri": "gs://generativeai-downloads/images/scones.jpg"
          }
        },
        {
          "text": "Describe this picture."
        }
      ]
    }
  }'
```

</TabItem>
</Tabs>

#### **Example Usage**
@@ -22,7 +107,7 @@ Just replace `https://REGION-aiplatform.googleapis.com` with `LITELLM_PROXY_BASE_URL/vertex_ai`

<TabItem value="curl" label="curl">

```bash
-curl http://localhost:4000/vertex_ai/publishers/google/models/gemini-1.0-pro:generateContent \
+curl http://localhost:4000/vertex_ai/v1/projects/${PROJECT_ID}/locations/us-central1/publishers/google/models/${MODEL_ID}:generateContent \
  -H "Content-Type: application/json" \
  -H "x-litellm-api-key: Bearer sk-1234" \
  -d '{
@@ -101,7 +186,7 @@ litellm

Let's call the Google AI Studio token counting endpoint

```bash
-curl http://localhost:4000/vertex-ai/publishers/google/models/gemini-1.0-pro:generateContent \
+curl http://localhost:4000/vertex-ai/v1/projects/${PROJECT_ID}/locations/us-central1/publishers/google/models/gemini-1.0-pro:generateContent \
  -H "Content-Type: application/json" \
  -H "Authorization: Bearer sk-1234" \
  -d '{
@@ -140,7 +225,7 @@ LiteLLM Proxy Server supports two methods of authentication to Vertex AI:

```shell
-curl http://localhost:4000/vertex_ai/publishers/google/models/gemini-1.5-flash-001:generateContent \
+curl http://localhost:4000/vertex_ai/v1/projects/${PROJECT_ID}/locations/us-central1/publishers/google/models/gemini-1.5-flash-001:generateContent \
  -H "Content-Type: application/json" \
  -H "x-litellm-api-key: Bearer sk-1234" \
  -d '{"contents":[{"role": "user", "parts":[{"text": "hi"}]}]}'
@@ -152,7 +237,7 @@ curl http://localhost:4000/vertex_ai/publishers/google/models/gemini-1.5-flash-001:generateContent \

```shell
-curl http://localhost:4000/vertex_ai/publishers/google/models/textembedding-gecko@001:predict \
+curl http://localhost:4000/vertex_ai/v1/projects/${PROJECT_ID}/locations/us-central1/publishers/google/models/textembedding-gecko@001:predict \
  -H "Content-Type: application/json" \
  -H "x-litellm-api-key: Bearer sk-1234" \
  -d '{"instances":[{"content": "gm"}]}'
@@ -162,7 +247,7 @@ curl http://localhost:4000/vertex_ai/publishers/google/models/textembedding-gecko@001:predict \

### Imagen API

```shell
-curl http://localhost:4000/vertex_ai/publishers/google/models/imagen-3.0-generate-001:predict \
+curl http://localhost:4000/vertex_ai/v1/projects/${PROJECT_ID}/locations/us-central1/publishers/google/models/imagen-3.0-generate-001:predict \
  -H "Content-Type: application/json" \
  -H "x-litellm-api-key: Bearer sk-1234" \
  -d '{"instances":[{"prompt": "make an otter"}], "parameters": {"sampleCount": 1}}'
|
@ -172,7 +257,7 @@ curl http://localhost:4000/vertex_ai/publishers/google/models/imagen-3.0-generat
|
|||
### Count Tokens API
|
||||
|
||||
```shell
|
||||
curl http://localhost:4000/vertex_ai/publishers/google/models/gemini-1.5-flash-001:countTokens \
|
||||
curl http://localhost:4000/vertex_ai/v1/projects/${PROJECT_ID}/locations/us-central1/publishers/google/models/gemini-1.5-flash-001:countTokens \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "x-litellm-api-key: Bearer sk-1234" \
|
||||
-d '{"contents":[{"role": "user", "parts":[{"text": "hi"}]}]}'
|
||||
|
@ -183,7 +268,7 @@ Create Fine Tuning Job
|
|||
|
||||
|
||||
```shell
|
||||
curl http://localhost:4000/vertex_ai/tuningJobs \
|
||||
curl http://localhost:4000/vertex_ai/v1/projects/${PROJECT_ID}/locations/us-central1/publishers/google/models/gemini-1.5-flash-001:tuningJobs \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "x-litellm-api-key: Bearer sk-1234" \
|
||||
-d '{
|
||||
|
@@ -243,7 +328,7 @@ Expected Response

```bash
-curl http://localhost:4000/vertex_ai/publishers/google/models/gemini-1.0-pro:generateContent \
+curl http://localhost:4000/vertex_ai/v1/projects/${PROJECT_ID}/locations/us-central1/publishers/google/models/gemini-1.0-pro:generateContent \
  -H "Content-Type: application/json" \
  -H "x-litellm-api-key: Bearer sk-1234" \
  -d '{
@@ -268,7 +353,7 @@ tags: ["vertex-js-sdk", "pass-through-endpoint"]

<TabItem value="curl" label="curl">

```bash
-curl http://localhost:4000/vertex-ai/publishers/google/models/gemini-1.0-pro:generateContent \
+curl http://localhost:4000/vertex_ai/v1/projects/${PROJECT_ID}/locations/us-central1/publishers/google/models/gemini-1.0-pro:generateContent \
  -H "Content-Type: application/json" \
  -H "x-litellm-api-key: Bearer sk-1234" \
  -H "tags: vertex-js-sdk,pass-through-endpoint" \
@@ -1,3 +1,4 @@
+import re
from typing import Dict, List, Literal, Optional, Tuple, Union

import httpx
@@ -280,3 +281,81 @@ def _convert_vertex_datetime_to_openai_datetime(vertex_datetime: str) -> int:
    dt = datetime.strptime(vertex_datetime, "%Y-%m-%dT%H:%M:%S.%fZ")
    # Convert to Unix timestamp (seconds since epoch)
    return int(dt.timestamp())


def get_vertex_project_id_from_url(url: str) -> Optional[str]:
    """
    Get the vertex project id from the url

    `https://${LOCATION}-aiplatform.googleapis.com/v1/projects/${PROJECT_ID}/locations/${LOCATION}/publishers/google/models/${MODEL_ID}:streamGenerateContent`
    """
    match = re.search(r"/projects/([^/]+)", url)
    return match.group(1) if match else None


def get_vertex_location_from_url(url: str) -> Optional[str]:
    """
    Get the vertex location from the url

    `https://${LOCATION}-aiplatform.googleapis.com/v1/projects/${PROJECT_ID}/locations/${LOCATION}/publishers/google/models/${MODEL_ID}:streamGenerateContent`
    """
    match = re.search(r"/locations/([^/]+)", url)
    return match.group(1) if match else None


def replace_project_and_location_in_route(
    requested_route: str, vertex_project: str, vertex_location: str
) -> str:
    """
    Replace project and location values in the route with the provided values
    """
    # Replace project and location values while keeping route structure
    modified_route = re.sub(
        r"/projects/[^/]+/locations/[^/]+/",
        f"/projects/{vertex_project}/locations/{vertex_location}/",
        requested_route,
    )
    return modified_route


def construct_target_url(
    base_url: str,
    requested_route: str,
    vertex_location: Optional[str],
    vertex_project: Optional[str],
) -> httpx.URL:
    """
    Allow user to specify their own project id / location.

    If missing, use defaults

    Handle cachedContent scenario - https://github.com/BerriAI/litellm/issues/5460

    Constructed Url:
    POST https://LOCATION-aiplatform.googleapis.com/{version}/projects/PROJECT_ID/locations/LOCATION/cachedContents
    """
    new_base_url = httpx.URL(base_url)
    if "locations" in requested_route:  # contains the target project id + location
        if vertex_project and vertex_location:
            requested_route = replace_project_and_location_in_route(
                requested_route, vertex_project, vertex_location
            )
        return new_base_url.copy_with(path=requested_route)

    """
    - Add endpoint version (e.g. v1beta for cachedContent, v1 for rest)
    - Add default project id
    - Add default location
    """
    vertex_version: Literal["v1", "v1beta1"] = "v1"
    if "cachedContent" in requested_route:
        vertex_version = "v1beta1"

    base_requested_route = "{}/projects/{}/locations/{}".format(
        vertex_version, vertex_project, vertex_location
    )

    updated_requested_route = "/" + base_requested_route + requested_route

    updated_url = new_base_url.copy_with(path=updated_requested_route)
    return updated_url
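For orientation, a minimal sketch of how these new helpers compose; the project id, location, and URL below are made-up values, not anything from the commit:

```python
# Illustrative only: exercising the new URL helpers with sample values.
from litellm.llms.vertex_ai.common_utils import (
    construct_target_url,
    get_vertex_location_from_url,
    get_vertex_project_id_from_url,
)

url = (
    "https://us-central1-aiplatform.googleapis.com/v1/projects/my-proj"
    "/locations/us-central1/publishers/google/models/gemini-1.5-flash:generateContent"
)
assert get_vertex_project_id_from_url(url) == "my-proj"
assert get_vertex_location_from_url(url) == "us-central1"

# The route already contains /projects/.../locations/..., so the provided
# values are swapped in while the rest of the route is preserved:
target = construct_target_url(
    base_url="https://us-central1-aiplatform.googleapis.com/",
    requested_route="/v1/projects/my-proj/locations/us-central1/cachedContents",
    vertex_location="us-central1",
    vertex_project="other-proj",
)
# -> https://us-central1-aiplatform.googleapis.com/v1/projects/other-proj/locations/us-central1/cachedContents
```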
@@ -82,6 +82,31 @@
        "supports_system_messages": true,
        "supports_tool_choice": true
    },
    "gpt-4o-search-preview-2025-03-11": {
        "max_tokens": 16384,
        "max_input_tokens": 128000,
        "max_output_tokens": 16384,
        "input_cost_per_token": 0.0000025,
        "output_cost_per_token": 0.000010,
        "input_cost_per_token_batches": 0.00000125,
        "output_cost_per_token_batches": 0.00000500,
        "cache_read_input_token_cost": 0.00000125,
        "litellm_provider": "openai",
        "mode": "chat",
        "supports_function_calling": true,
        "supports_parallel_function_calling": true,
        "supports_response_schema": true,
        "supports_vision": true,
        "supports_prompt_caching": true,
        "supports_system_messages": true,
        "supports_tool_choice": true,
        "supports_web_search": true,
        "search_context_cost_per_query": {
            "search_context_size_low": 0.030,
            "search_context_size_medium": 0.035,
            "search_context_size_high": 0.050
        }
    },
    "gpt-4o-search-preview": {
        "max_tokens": 16384,
        "max_input_tokens": 128000,

@@ -232,6 +257,31 @@
        "supports_system_messages": true,
        "supports_tool_choice": true
    },
    "gpt-4o-mini-search-preview-2025-03-11": {
        "max_tokens": 16384,
        "max_input_tokens": 128000,
        "max_output_tokens": 16384,
        "input_cost_per_token": 0.00000015,
        "output_cost_per_token": 0.00000060,
        "input_cost_per_token_batches": 0.000000075,
        "output_cost_per_token_batches": 0.00000030,
        "cache_read_input_token_cost": 0.000000075,
        "litellm_provider": "openai",
        "mode": "chat",
        "supports_function_calling": true,
        "supports_parallel_function_calling": true,
        "supports_response_schema": true,
        "supports_vision": true,
        "supports_prompt_caching": true,
        "supports_system_messages": true,
        "supports_tool_choice": true,
        "supports_web_search": true,
        "search_context_cost_per_query": {
            "search_context_size_low": 0.025,
            "search_context_size_medium": 0.0275,
            "search_context_size_high": 0.030
        }
    },
    "gpt-4o-mini-search-preview": {
        "max_tokens": 16384,
        "max_input_tokens": 128000,
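A back-of-the-envelope sketch of how these new pricing fields combine (my own illustration, not LiteLLM's cost function), assuming `search_context_cost_per_query` is billed once per web-search call and using made-up token counts:

```python
# Rough cost estimate for one gpt-4o-search-preview-2025-03-11 call, using the
# per-token prices above plus one medium-context web search query.
input_tokens = 1_000
output_tokens = 200

token_cost = input_tokens * 0.0000025 + output_tokens * 0.000010  # $0.0045
search_cost = 0.035  # search_context_size_medium, assumed charged per query
total = token_cost + search_cost
print(f"${total:.4f}")  # $0.0395
```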
litellm/proxy/pass_through_endpoints/common_utils.py (new file, 16 lines)
@@ -0,0 +1,16 @@
from fastapi import Request


def get_litellm_virtual_key(request: Request) -> str:
    """
    Extract and format API key from request headers.
    Prioritizes x-litellm-api-key over Authorization header.

    Vertex JS SDK uses `Authorization` header, we use `x-litellm-api-key` to pass litellm virtual key
    """
    litellm_api_key = request.headers.get("x-litellm-api-key")
    if litellm_api_key:
        return f"Bearer {litellm_api_key}"
    return request.headers.get("Authorization", "")
@@ -12,10 +12,13 @@ import httpx
from fastapi import APIRouter, Depends, HTTPException, Request, Response

import litellm
from litellm._logging import verbose_proxy_logger
from litellm.constants import BEDROCK_AGENT_RUNTIME_PASS_THROUGH_ROUTES
from litellm.llms.vertex_ai.vertex_llm_base import VertexBase
from litellm.proxy._types import *
from litellm.proxy.auth.route_checks import RouteChecks
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.pass_through_endpoints.common_utils import get_litellm_virtual_key
from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
    create_pass_through_route,
)

@@ -23,6 +26,7 @@ from litellm.secret_managers.main import get_secret_str

from .passthrough_endpoint_router import PassthroughEndpointRouter

+vertex_llm_base = VertexBase()
router = APIRouter()
default_vertex_config = None
@@ -417,6 +421,138 @@ async def azure_proxy_route(
    )


@router.api_route(
    "/vertex-ai/{endpoint:path}",
    methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
    tags=["Vertex AI Pass-through", "pass-through"],
    include_in_schema=False,
)
@router.api_route(
    "/vertex_ai/{endpoint:path}",
    methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
    tags=["Vertex AI Pass-through", "pass-through"],
)
async def vertex_proxy_route(
    endpoint: str,
    request: Request,
    fastapi_response: Response,
):
    """
    Call LiteLLM proxy via Vertex AI SDK.

    [Docs](https://docs.litellm.ai/docs/pass_through/vertex_ai)
    """
    from litellm.llms.vertex_ai.common_utils import (
        construct_target_url,
        get_vertex_location_from_url,
        get_vertex_project_id_from_url,
    )

    encoded_endpoint = httpx.URL(endpoint).path
    verbose_proxy_logger.debug("requested endpoint %s", endpoint)
    headers: dict = {}
    api_key_to_use = get_litellm_virtual_key(request=request)
    user_api_key_dict = await user_api_key_auth(
        request=request,
        api_key=api_key_to_use,
    )
    vertex_project: Optional[str] = get_vertex_project_id_from_url(endpoint)
    vertex_location: Optional[str] = get_vertex_location_from_url(endpoint)
    vertex_credentials = passthrough_endpoint_router.get_vertex_credentials(
        project_id=vertex_project,
        location=vertex_location,
    )

    headers_passed_through = False
    # Use headers from the incoming request if no vertex credentials are found
    if vertex_credentials is None or vertex_credentials.vertex_project is None:
        headers = dict(request.headers) or {}
        headers_passed_through = True
        verbose_proxy_logger.debug(
            "default_vertex_config not set, incoming request headers %s", headers
        )
        base_target_url = f"https://{vertex_location}-aiplatform.googleapis.com/"
        headers.pop("content-length", None)
        headers.pop("host", None)
    else:
        vertex_project = vertex_credentials.vertex_project
        vertex_location = vertex_credentials.vertex_location
        vertex_credentials_str = vertex_credentials.vertex_credentials

        # Construct base URL for the target endpoint
        base_target_url = f"https://{vertex_location}-aiplatform.googleapis.com/"

        _auth_header, vertex_project = await vertex_llm_base._ensure_access_token_async(
            credentials=vertex_credentials_str,
            project_id=vertex_project,
            custom_llm_provider="vertex_ai_beta",
        )

        auth_header, _ = vertex_llm_base._get_token_and_url(
            model="",
            auth_header=_auth_header,
            gemini_api_key=None,
            vertex_credentials=vertex_credentials_str,
            vertex_project=vertex_project,
            vertex_location=vertex_location,
            stream=False,
            custom_llm_provider="vertex_ai_beta",
            api_base="",
        )

        headers = {
            "Authorization": f"Bearer {auth_header}",
        }

    request_route = encoded_endpoint
    verbose_proxy_logger.debug("request_route %s", request_route)

    # Ensure endpoint starts with '/' for proper URL construction
    if not encoded_endpoint.startswith("/"):
        encoded_endpoint = "/" + encoded_endpoint

    # Construct the full target URL using httpx
    updated_url = construct_target_url(
        base_url=base_target_url,
        requested_route=encoded_endpoint,
        vertex_location=vertex_location,
        vertex_project=vertex_project,
    )

    verbose_proxy_logger.debug("updated url %s", updated_url)

    ## check for streaming
    target = str(updated_url)
    is_streaming_request = False
    if "stream" in str(updated_url):
        is_streaming_request = True
        target += "?alt=sse"

    ## CREATE PASS-THROUGH
    endpoint_func = create_pass_through_route(
        endpoint=endpoint,
        target=target,
        custom_headers=headers,
    )  # dynamically construct pass-through endpoint based on incoming path

    try:
        received_value = await endpoint_func(
            request,
            fastapi_response,
            user_api_key_dict,
            stream=is_streaming_request,  # type: ignore
        )
    except Exception as e:
        if headers_passed_through:
            raise Exception(
                f"No credentials found on proxy for this request. Headers were passed through directly but request failed with error: {str(e)}"
            )
        else:
            raise e

    return received_value


@router.api_route(
    "/openai/{endpoint:path}",
    methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
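A minimal client-side sketch of hitting this route (assumed proxy at localhost:4000 with a placeholder virtual key, mirroring the docs' curl examples above); a `:streamGenerateContent` path would match the `"stream" in str(updated_url)` check and get `?alt=sse` appended:

```python
import httpx

# Illustrative only: proxy URL, project id, and key are placeholder values.
resp = httpx.post(
    "http://localhost:4000/vertex_ai/v1/projects/my-proj/locations/us-central1"
    "/publishers/google/models/gemini-1.5-flash:generateContent",
    headers={"x-litellm-api-key": "Bearer sk-1234"},
    json={"contents": [{"role": "user", "parts": [{"text": "hi"}]}]},
)
print(resp.status_code, resp.text[:200])
```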
@@ -1,7 +1,9 @@
from typing import Dict, Optional

-from litellm._logging import verbose_logger
+from litellm._logging import verbose_router_logger
from litellm.secret_managers.main import get_secret_str
+from litellm.types.llms.vertex_ai import VERTEX_CREDENTIALS_TYPES
+from litellm.types.passthrough_endpoints.vertex_ai import VertexPassThroughCredentials


class PassthroughEndpointRouter:
@@ -11,6 +13,10 @@ class PassthroughEndpointRouter:

    def __init__(self):
        self.credentials: Dict[str, str] = {}
+        self.deployment_key_to_vertex_credentials: Dict[
+            str, VertexPassThroughCredentials
+        ] = {}
+        self.default_vertex_config: Optional[VertexPassThroughCredentials] = None

    def set_pass_through_credentials(
        self,
@@ -45,14 +51,14 @@ class PassthroughEndpointRouter:
            custom_llm_provider=custom_llm_provider,
            region_name=region_name,
        )
-        verbose_logger.debug(
+        verbose_router_logger.debug(
            f"Pass-through llm endpoints router, looking for credentials for {credential_name}"
        )
        if credential_name in self.credentials:
-            verbose_logger.debug(f"Found credentials for {credential_name}")
+            verbose_router_logger.debug(f"Found credentials for {credential_name}")
            return self.credentials[credential_name]
        else:
-            verbose_logger.debug(
+            verbose_router_logger.debug(
                f"No credentials found for {credential_name}, looking for env variable"
            )
            _env_variable_name = (
@@ -62,6 +68,100 @@ class PassthroughEndpointRouter:
        )
        return get_secret_str(_env_variable_name)

    def _get_vertex_env_vars(self) -> VertexPassThroughCredentials:
        """
        Helper to get vertex pass through config from environment variables

        The following environment variables are used:
        - DEFAULT_VERTEXAI_PROJECT (project id)
        - DEFAULT_VERTEXAI_LOCATION (location)
        - DEFAULT_GOOGLE_APPLICATION_CREDENTIALS (path to credentials file)
        """
        return VertexPassThroughCredentials(
            vertex_project=get_secret_str("DEFAULT_VERTEXAI_PROJECT"),
            vertex_location=get_secret_str("DEFAULT_VERTEXAI_LOCATION"),
            vertex_credentials=get_secret_str("DEFAULT_GOOGLE_APPLICATION_CREDENTIALS"),
        )

    def set_default_vertex_config(self, config: Optional[dict] = None):
        """Sets vertex configuration from provided config and/or environment variables

        Args:
            config (Optional[dict]): Configuration dictionary
            Example: {
                "vertex_project": "my-project-123",
                "vertex_location": "us-central1",
                "vertex_credentials": "os.environ/GOOGLE_CREDS"
            }
        """
        # Initialize config dictionary if None
        if config is None:
            self.default_vertex_config = self._get_vertex_env_vars()
            return

        if isinstance(config, dict):
            for key, value in config.items():
                if isinstance(value, str) and value.startswith("os.environ/"):
                    config[key] = get_secret_str(value)

        self.default_vertex_config = VertexPassThroughCredentials(**config)

    def add_vertex_credentials(
        self,
        project_id: str,
        location: str,
        vertex_credentials: VERTEX_CREDENTIALS_TYPES,
    ):
        """
        Add the vertex credentials for the given project-id, location
        """
        deployment_key = self._get_deployment_key(
            project_id=project_id,
            location=location,
        )
        if deployment_key is None:
            verbose_router_logger.debug(
                "No deployment key found for project-id, location"
            )
            return
        vertex_pass_through_credentials = VertexPassThroughCredentials(
            vertex_project=project_id,
            vertex_location=location,
            vertex_credentials=vertex_credentials,
        )
        self.deployment_key_to_vertex_credentials[deployment_key] = (
            vertex_pass_through_credentials
        )

    def _get_deployment_key(
        self, project_id: Optional[str], location: Optional[str]
    ) -> Optional[str]:
        """
        Get the deployment key for the given project-id, location
        """
        if project_id is None or location is None:
            return None
        return f"{project_id}-{location}"

    def get_vertex_credentials(
        self, project_id: Optional[str], location: Optional[str]
    ) -> Optional[VertexPassThroughCredentials]:
        """
        Get the vertex credentials for the given project-id, location
        """
        deployment_key = self._get_deployment_key(
            project_id=project_id,
            location=location,
        )

        if deployment_key is None:
            return self.default_vertex_config
        if deployment_key in self.deployment_key_to_vertex_credentials:
            return self.deployment_key_to_vertex_credentials[deployment_key]
        else:
            return self.default_vertex_config

    def _get_credential_name_for_provider(
        self,
        custom_llm_provider: str,
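An illustrative sketch of the lookup order these methods implement (all values hypothetical): an exact `{project_id}-{location}` deployment-key match wins, and anything else falls back to `default_vertex_config`.

```python
from litellm.proxy.pass_through_endpoints.passthrough_endpoint_router import (
    PassthroughEndpointRouter,
)

router = PassthroughEndpointRouter()
router.set_default_vertex_config(
    {
        "vertex_project": "default-proj",
        "vertex_location": "us-central1",
        "vertex_credentials": "/path/to/default.json",
    }
)
router.add_vertex_credentials(
    project_id="proj-a",
    location="europe-west1",
    vertex_credentials="/path/to/proj-a.json",
)

# Exact deployment-key match ("proj-a-europe-west1") wins:
assert router.get_vertex_credentials("proj-a", "europe-west1").vertex_project == "proj-a"
# Missing project/location (deployment key is None) falls back to the default:
assert router.get_vertex_credentials("unknown", None).vertex_project == "default-proj"
```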
@@ -235,6 +235,9 @@ from litellm.proxy.openai_files_endpoints.files_endpoints import (
    router as openai_files_router,
)
from litellm.proxy.openai_files_endpoints.files_endpoints import set_files_config
+from litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints import (
+    passthrough_endpoint_router,
+)
from litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints import (
    router as llm_passthrough_router,
)

@@ -272,8 +275,6 @@ from litellm.proxy.utils import (
from litellm.proxy.vertex_ai_endpoints.langfuse_endpoints import (
    router as langfuse_router,
)
-from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import router as vertex_router
-from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import set_default_vertex_config
from litellm.router import (
    AssistantsTypedDict,
    Deployment,
@@ -2115,7 +2116,9 @@ class ProxyConfig:

        ## default config for vertex ai routes
        default_vertex_config = config.get("default_vertex_config", None)
-        set_default_vertex_config(config=default_vertex_config)
+        passthrough_endpoint_router.set_default_vertex_config(
+            config=default_vertex_config
+        )

        ## ROUTER SETTINGS (e.g. routing_strategy, ...)
        router_settings = config.get("router_settings", None)
@@ -8161,7 +8164,6 @@ app.include_router(batches_router)
app.include_router(rerank_router)
app.include_router(fine_tuning_router)
app.include_router(credential_router)
-app.include_router(vertex_router)
app.include_router(llm_passthrough_router)
app.include_router(anthropic_router)
app.include_router(langfuse_router)
@@ -1,274 +0,0 @@
import traceback
from typing import Optional

import httpx
from fastapi import APIRouter, HTTPException, Request, Response, status

import litellm
from litellm._logging import verbose_proxy_logger
from litellm.fine_tuning.main import vertex_fine_tuning_apis_instance
from litellm.proxy._types import *
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
    create_pass_through_route,
)
from litellm.secret_managers.main import get_secret_str
from litellm.types.passthrough_endpoints.vertex_ai import *

from .vertex_passthrough_router import VertexPassThroughRouter

router = APIRouter()
vertex_pass_through_router = VertexPassThroughRouter()

default_vertex_config: VertexPassThroughCredentials = VertexPassThroughCredentials()


def _get_vertex_env_vars() -> VertexPassThroughCredentials:
    """
    Helper to get vertex pass through config from environment variables

    The following environment variables are used:
    - DEFAULT_VERTEXAI_PROJECT (project id)
    - DEFAULT_VERTEXAI_LOCATION (location)
    - DEFAULT_GOOGLE_APPLICATION_CREDENTIALS (path to credentials file)
    """
    return VertexPassThroughCredentials(
        vertex_project=get_secret_str("DEFAULT_VERTEXAI_PROJECT"),
        vertex_location=get_secret_str("DEFAULT_VERTEXAI_LOCATION"),
        vertex_credentials=get_secret_str("DEFAULT_GOOGLE_APPLICATION_CREDENTIALS"),
    )


def set_default_vertex_config(config: Optional[dict] = None):
    """Sets vertex configuration from provided config and/or environment variables

    Args:
        config (Optional[dict]): Configuration dictionary
        Example: {
            "vertex_project": "my-project-123",
            "vertex_location": "us-central1",
            "vertex_credentials": "os.environ/GOOGLE_CREDS"
        }
    """
    global default_vertex_config

    # Initialize config dictionary if None
    if config is None:
        default_vertex_config = _get_vertex_env_vars()
        return

    if isinstance(config, dict):
        for key, value in config.items():
            if isinstance(value, str) and value.startswith("os.environ/"):
                config[key] = litellm.get_secret(value)

    _set_default_vertex_config(VertexPassThroughCredentials(**config))


def _set_default_vertex_config(
    vertex_pass_through_credentials: VertexPassThroughCredentials,
):
    global default_vertex_config
    default_vertex_config = vertex_pass_through_credentials


def exception_handler(e: Exception):
    verbose_proxy_logger.error(
        "litellm.proxy.proxy_server.v1/projects/tuningJobs(): Exception occurred - {}".format(
            str(e)
        )
    )
    verbose_proxy_logger.debug(traceback.format_exc())
    if isinstance(e, HTTPException):
        return ProxyException(
            message=getattr(e, "message", str(e.detail)),
            type=getattr(e, "type", "None"),
            param=getattr(e, "param", "None"),
            code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
        )
    else:
        error_msg = f"{str(e)}"
        return ProxyException(
            message=getattr(e, "message", error_msg),
            type=getattr(e, "type", "None"),
            param=getattr(e, "param", "None"),
            code=getattr(e, "status_code", 500),
        )


def construct_target_url(
    base_url: str,
    requested_route: str,
    default_vertex_location: Optional[str],
    default_vertex_project: Optional[str],
) -> httpx.URL:
    """
    Allow user to specify their own project id / location.

    If missing, use defaults

    Handle cachedContent scenario - https://github.com/BerriAI/litellm/issues/5460

    Constructed Url:
    POST https://LOCATION-aiplatform.googleapis.com/{version}/projects/PROJECT_ID/locations/LOCATION/cachedContents
    """
    new_base_url = httpx.URL(base_url)
    if "locations" in requested_route:  # contains the target project id + location
        updated_url = new_base_url.copy_with(path=requested_route)
        return updated_url
    """
    - Add endpoint version (e.g. v1beta for cachedContent, v1 for rest)
    - Add default project id
    - Add default location
    """
    vertex_version: Literal["v1", "v1beta1"] = "v1"
    if "cachedContent" in requested_route:
        vertex_version = "v1beta1"

    base_requested_route = "{}/projects/{}/locations/{}".format(
        vertex_version, default_vertex_project, default_vertex_location
    )

    updated_requested_route = "/" + base_requested_route + requested_route

    updated_url = new_base_url.copy_with(path=updated_requested_route)
    return updated_url


@router.api_route(
    "/vertex-ai/{endpoint:path}",
    methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
    tags=["Vertex AI Pass-through", "pass-through"],
    include_in_schema=False,
)
@router.api_route(
    "/vertex_ai/{endpoint:path}",
    methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
    tags=["Vertex AI Pass-through", "pass-through"],
)
async def vertex_proxy_route(
    endpoint: str,
    request: Request,
    fastapi_response: Response,
):
    """
    Call LiteLLM proxy via Vertex AI SDK.

    [Docs](https://docs.litellm.ai/docs/pass_through/vertex_ai)
    """
    encoded_endpoint = httpx.URL(endpoint).path
    verbose_proxy_logger.debug("requested endpoint %s", endpoint)
    headers: dict = {}
    api_key_to_use = get_litellm_virtual_key(request=request)
    user_api_key_dict = await user_api_key_auth(
        request=request,
        api_key=api_key_to_use,
    )

    vertex_project: Optional[str] = (
        VertexPassThroughRouter._get_vertex_project_id_from_url(endpoint)
    )
    vertex_location: Optional[str] = (
        VertexPassThroughRouter._get_vertex_location_from_url(endpoint)
    )
    vertex_credentials = vertex_pass_through_router.get_vertex_credentials(
        project_id=vertex_project,
        location=vertex_location,
    )

    # Use headers from the incoming request if no vertex credentials are found
    if vertex_credentials.vertex_project is None:
        headers = dict(request.headers) or {}
        verbose_proxy_logger.debug(
            "default_vertex_config not set, incoming request headers %s", headers
        )
        base_target_url = f"https://{vertex_location}-aiplatform.googleapis.com/"
        headers.pop("content-length", None)
        headers.pop("host", None)
    else:
        vertex_project = vertex_credentials.vertex_project
        vertex_location = vertex_credentials.vertex_location
        vertex_credentials_str = vertex_credentials.vertex_credentials

        # Construct base URL for the target endpoint
        base_target_url = f"https://{vertex_location}-aiplatform.googleapis.com/"

        _auth_header, vertex_project = (
            await vertex_fine_tuning_apis_instance._ensure_access_token_async(
                credentials=vertex_credentials_str,
                project_id=vertex_project,
                custom_llm_provider="vertex_ai_beta",
            )
        )

        auth_header, _ = vertex_fine_tuning_apis_instance._get_token_and_url(
            model="",
            auth_header=_auth_header,
            gemini_api_key=None,
            vertex_credentials=vertex_credentials_str,
            vertex_project=vertex_project,
            vertex_location=vertex_location,
            stream=False,
            custom_llm_provider="vertex_ai_beta",
            api_base="",
        )

        headers = {
            "Authorization": f"Bearer {auth_header}",
        }

    request_route = encoded_endpoint
    verbose_proxy_logger.debug("request_route %s", request_route)

    # Ensure endpoint starts with '/' for proper URL construction
    if not encoded_endpoint.startswith("/"):
        encoded_endpoint = "/" + encoded_endpoint

    # Construct the full target URL using httpx
    updated_url = construct_target_url(
        base_url=base_target_url,
        requested_route=encoded_endpoint,
        default_vertex_location=vertex_location,
        default_vertex_project=vertex_project,
    )
    # base_url = httpx.URL(base_target_url)
    # updated_url = base_url.copy_with(path=encoded_endpoint)

    verbose_proxy_logger.debug("updated url %s", updated_url)

    ## check for streaming
    target = str(updated_url)
    is_streaming_request = False
    if "stream" in str(updated_url):
        is_streaming_request = True
        target += "?alt=sse"

    ## CREATE PASS-THROUGH
    endpoint_func = create_pass_through_route(
        endpoint=endpoint,
        target=target,
        custom_headers=headers,
    )  # dynamically construct pass-through endpoint based on incoming path
    received_value = await endpoint_func(
        request,
        fastapi_response,
        user_api_key_dict,
        stream=is_streaming_request,  # type: ignore
    )

    return received_value


def get_litellm_virtual_key(request: Request) -> str:
    """
    Extract and format API key from request headers.
    Prioritizes x-litellm-api-key over Authorization header.

    Vertex JS SDK uses `Authorization` header, we use `x-litellm-api-key` to pass litellm virtual key
    """
    litellm_api_key = request.headers.get("x-litellm-api-key")
    if litellm_api_key:
        return f"Bearer {litellm_api_key}"
    return request.headers.get("Authorization", "")
@@ -1,121 +0,0 @@
import json
import re
from typing import Dict, Optional

from litellm._logging import verbose_proxy_logger
from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import (
    VertexPassThroughCredentials,
)
from litellm.types.llms.vertex_ai import VERTEX_CREDENTIALS_TYPES


class VertexPassThroughRouter:
    """
    Vertex Pass Through Router for Vertex AI pass-through endpoints

    - if request specifies a project-id, location -> use credentials corresponding to the project-id, location
    - if request does not specify a project-id, location -> use credentials corresponding to the DEFAULT_VERTEXAI_PROJECT, DEFAULT_VERTEXAI_LOCATION
    """

    def __init__(self):
        """
        Initialize the VertexPassThroughRouter
        Stores the vertex credentials for each deployment key
        ```
        {
            "project_id-location": VertexPassThroughCredentials,
            "adroit-crow-us-central1": VertexPassThroughCredentials,
        }
        ```
        """
        self.deployment_key_to_vertex_credentials: Dict[
            str, VertexPassThroughCredentials
        ] = {}
        pass

    def get_vertex_credentials(
        self, project_id: Optional[str], location: Optional[str]
    ) -> VertexPassThroughCredentials:
        """
        Get the vertex credentials for the given project-id, location
        """
        from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import (
            default_vertex_config,
        )

        deployment_key = self._get_deployment_key(
            project_id=project_id,
            location=location,
        )
        if deployment_key is None:
            return default_vertex_config
        if deployment_key in self.deployment_key_to_vertex_credentials:
            return self.deployment_key_to_vertex_credentials[deployment_key]
        else:
            return default_vertex_config

    def add_vertex_credentials(
        self,
        project_id: str,
        location: str,
        vertex_credentials: VERTEX_CREDENTIALS_TYPES,
    ):
        """
        Add the vertex credentials for the given project-id, location
        """
        from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import (
            _set_default_vertex_config,
        )

        deployment_key = self._get_deployment_key(
            project_id=project_id,
            location=location,
        )
        if deployment_key is None:
            verbose_proxy_logger.debug(
                "No deployment key found for project-id, location"
            )
            return
        vertex_pass_through_credentials = VertexPassThroughCredentials(
            vertex_project=project_id,
            vertex_location=location,
            vertex_credentials=vertex_credentials,
        )
        self.deployment_key_to_vertex_credentials[deployment_key] = (
            vertex_pass_through_credentials
        )
        verbose_proxy_logger.debug(
            f"self.deployment_key_to_vertex_credentials: {json.dumps(self.deployment_key_to_vertex_credentials, indent=4, default=str)}"
        )
        _set_default_vertex_config(vertex_pass_through_credentials)

    def _get_deployment_key(
        self, project_id: Optional[str], location: Optional[str]
    ) -> Optional[str]:
        """
        Get the deployment key for the given project-id, location
        """
        if project_id is None or location is None:
            return None
        return f"{project_id}-{location}"

    @staticmethod
    def _get_vertex_project_id_from_url(url: str) -> Optional[str]:
        """
        Get the vertex project id from the url

        `https://${LOCATION}-aiplatform.googleapis.com/v1/projects/${PROJECT_ID}/locations/${LOCATION}/publishers/google/models/${MODEL_ID}:streamGenerateContent`
        """
        match = re.search(r"/projects/([^/]+)", url)
        return match.group(1) if match else None

    @staticmethod
    def _get_vertex_location_from_url(url: str) -> Optional[str]:
        """
        Get the vertex location from the url

        `https://${LOCATION}-aiplatform.googleapis.com/v1/projects/${PROJECT_ID}/locations/${LOCATION}/publishers/google/models/${MODEL_ID}:streamGenerateContent`
        """
        match = re.search(r"/locations/([^/]+)", url)
        return match.group(1) if match else None
@@ -4495,11 +4495,11 @@ class Router:
        Each provider uses diff .env vars for pass-through endpoints, this helper uses the deployment credentials to set the .env vars for pass-through endpoints
        """
        if deployment.litellm_params.use_in_pass_through is True:
-            if custom_llm_provider == "vertex_ai":
-                from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import (
-                    vertex_pass_through_router,
-                )
+            from litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints import (
+                passthrough_endpoint_router,
+            )
+
+            if custom_llm_provider == "vertex_ai":
                if (
                    deployment.litellm_params.vertex_project is None
                    or deployment.litellm_params.vertex_location is None

@@ -4508,16 +4508,12 @@ class Router:
                    raise ValueError(
                        "vertex_project, vertex_location, and vertex_credentials must be set in litellm_params for pass-through endpoints"
                    )
-                vertex_pass_through_router.add_vertex_credentials(
+                passthrough_endpoint_router.add_vertex_credentials(
                    project_id=deployment.litellm_params.vertex_project,
                    location=deployment.litellm_params.vertex_location,
                    vertex_credentials=deployment.litellm_params.vertex_credentials,
                )
            else:
-                from litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints import (
-                    passthrough_endpoint_router,
-                )
-
                passthrough_endpoint_router.set_pass_through_credentials(
                    custom_llm_provider=custom_llm_provider,
                    api_base=deployment.litellm_params.api_base,
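For context, a hedged sketch of the deployment registration that exercises this branch (project id and paths are placeholders); setting `use_in_pass_through: true` is what hands the deployment's Vertex credentials to the pass-through router:

```python
from litellm import Router

# Illustrative only: mirrors the config.yaml example from the docs above.
router = Router(
    model_list=[
        {
            "model_name": "gemini-1.0-pro",
            "litellm_params": {
                "model": "vertex_ai/gemini-1.0-pro",
                "vertex_project": "adroit-crow-413218",
                "vertex_location": "us-central1",
                "vertex_credentials": "/path/to/credentials.json",
                # Opt this deployment's credentials into pass-through routes:
                "use_in_pass_through": True,
            },
        }
    ]
)
```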
@@ -82,6 +82,31 @@
        "supports_system_messages": true,
        "supports_tool_choice": true
    },
    "gpt-4o-search-preview-2025-03-11": {
        "max_tokens": 16384,
        "max_input_tokens": 128000,
        "max_output_tokens": 16384,
        "input_cost_per_token": 0.0000025,
        "output_cost_per_token": 0.000010,
        "input_cost_per_token_batches": 0.00000125,
        "output_cost_per_token_batches": 0.00000500,
        "cache_read_input_token_cost": 0.00000125,
        "litellm_provider": "openai",
        "mode": "chat",
        "supports_function_calling": true,
        "supports_parallel_function_calling": true,
        "supports_response_schema": true,
        "supports_vision": true,
        "supports_prompt_caching": true,
        "supports_system_messages": true,
        "supports_tool_choice": true,
        "supports_web_search": true,
        "search_context_cost_per_query": {
            "search_context_size_low": 0.030,
            "search_context_size_medium": 0.035,
            "search_context_size_high": 0.050
        }
    },
    "gpt-4o-search-preview": {
        "max_tokens": 16384,
        "max_input_tokens": 128000,

@@ -232,6 +257,31 @@
        "supports_system_messages": true,
        "supports_tool_choice": true
    },
    "gpt-4o-mini-search-preview-2025-03-11": {
        "max_tokens": 16384,
        "max_input_tokens": 128000,
        "max_output_tokens": 16384,
        "input_cost_per_token": 0.00000015,
        "output_cost_per_token": 0.00000060,
        "input_cost_per_token_batches": 0.000000075,
        "output_cost_per_token_batches": 0.00000030,
        "cache_read_input_token_cost": 0.000000075,
        "litellm_provider": "openai",
        "mode": "chat",
        "supports_function_calling": true,
        "supports_parallel_function_calling": true,
        "supports_response_schema": true,
        "supports_vision": true,
        "supports_prompt_caching": true,
        "supports_system_messages": true,
        "supports_tool_choice": true,
        "supports_web_search": true,
        "search_context_cost_per_query": {
            "search_context_size_low": 0.025,
            "search_context_size_medium": 0.0275,
            "search_context_size_high": 0.030
        }
    },
    "gpt-4o-mini-search-preview": {
        "max_tokens": 16384,
        "max_input_tokens": 128000,
tests/litellm/llms/vertex_ai/test_vertex_ai_common_utils.py (new file, 43 lines)
@@ -0,0 +1,43 @@
import os
import sys
from unittest.mock import MagicMock, call, patch

import pytest

sys.path.insert(
    0, os.path.abspath("../../..")
)  # Adds the parent directory to the system path

import litellm
from litellm.llms.vertex_ai.common_utils import (
    get_vertex_location_from_url,
    get_vertex_project_id_from_url,
)


@pytest.mark.asyncio
async def test_get_vertex_project_id_from_url():
    """Test get_vertex_project_id_from_url with various URLs"""
    # Test with valid URL
    url = "https://us-central1-aiplatform.googleapis.com/v1/projects/test-project/locations/us-central1/publishers/google/models/gemini-pro:streamGenerateContent"
    project_id = get_vertex_project_id_from_url(url)
    assert project_id == "test-project"

    # Test with invalid URL
    url = "https://invalid-url.com"
    project_id = get_vertex_project_id_from_url(url)
    assert project_id is None


@pytest.mark.asyncio
async def test_get_vertex_location_from_url():
    """Test get_vertex_location_from_url with various URLs"""
    # Test with valid URL
    url = "https://us-central1-aiplatform.googleapis.com/v1/projects/test-project/locations/us-central1/publishers/google/models/gemini-pro:streamGenerateContent"
    location = get_vertex_location_from_url(url)
    assert location == "us-central1"

    # Test with invalid URL
    url = "https://invalid-url.com"
    location = get_vertex_location_from_url(url)
    assert location is None
@@ -1,7 +1,9 @@
import json
import os
import sys
-from unittest.mock import MagicMock, patch
+import traceback
+from unittest import mock
+from unittest.mock import AsyncMock, MagicMock, Mock, patch

import httpx
import pytest
@@ -17,7 +19,9 @@ from litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints import (
    BaseOpenAIPassThroughHandler,
    RouteChecks,
    create_pass_through_route,
+    vertex_proxy_route,
)
+from litellm.types.passthrough_endpoints.vertex_ai import VertexPassThroughCredentials


class TestBaseOpenAIPassThroughHandler:
@@ -176,3 +180,279 @@ class TestBaseOpenAIPassThroughHandler:
        print(f"query_params: {call_kwargs['query_params']}")
        assert call_kwargs["stream"] is False
        assert call_kwargs["query_params"] == {"model": "gpt-4"}


class TestVertexAIPassThroughHandler:
    """
    Case 1: User set passthrough credentials - confirm credentials used.

    Case 2: User set default credentials, no exact passthrough credentials - confirm default credentials used.

    Case 3: No default credentials, no mapped credentials - request passed through directly.
    """

    @pytest.mark.asyncio
    async def test_vertex_passthrough_with_credentials(self, monkeypatch):
        """
        Test that when passthrough credentials are set, they are correctly used in the request
        """
        from litellm.proxy.pass_through_endpoints.passthrough_endpoint_router import (
            PassthroughEndpointRouter,
        )

        vertex_project = "test-project"
        vertex_location = "us-central1"
        vertex_credentials = "test-creds"

        pass_through_router = PassthroughEndpointRouter()

        pass_through_router.add_vertex_credentials(
            project_id=vertex_project,
            location=vertex_location,
            vertex_credentials=vertex_credentials,
        )

        monkeypatch.setattr(
            "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.passthrough_endpoint_router",
            pass_through_router,
        )

        endpoint = f"/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/gemini-1.5-flash:generateContent"

        # Mock request
        mock_request = Request(
            scope={
                "type": "http",
                "method": "POST",
                "path": endpoint,
                "headers": [
                    (b"Authorization", b"Bearer test-creds"),
                    (b"Content-Type", b"application/json"),
                ],
            }
        )

        # Mock response
        mock_response = Response()

        # Mock vertex credentials
        test_project = vertex_project
        test_location = vertex_location
        test_token = vertex_credentials

        with mock.patch(
            "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.vertex_llm_base._ensure_access_token_async"
        ) as mock_ensure_token, mock.patch(
            "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.vertex_llm_base._get_token_and_url"
        ) as mock_get_token, mock.patch(
            "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.create_pass_through_route"
        ) as mock_create_route:
            mock_ensure_token.return_value = ("test-auth-header", test_project)
            mock_get_token.return_value = (test_token, "")

            # Call the route
            try:
                await vertex_proxy_route(
                    endpoint=endpoint,
                    request=mock_request,
                    fastapi_response=mock_response,
                )
            except Exception as e:
                print(f"Error: {e}")

            # Verify create_pass_through_route was called with correct arguments
            mock_create_route.assert_called_once_with(
                endpoint=endpoint,
                target=f"https://{test_location}-aiplatform.googleapis.com/v1/projects/{test_project}/locations/{test_location}/publishers/google/models/gemini-1.5-flash:generateContent",
                custom_headers={"Authorization": f"Bearer {test_token}"},
            )

    @pytest.mark.parametrize(
        "initial_endpoint",
        [
            "publishers/google/models/gemini-1.5-flash:generateContent",
            "v1/projects/bad-project/locations/bad-location/publishers/google/models/gemini-1.5-flash:generateContent",
        ],
    )
    @pytest.mark.asyncio
    async def test_vertex_passthrough_with_default_credentials(
        self, monkeypatch, initial_endpoint
    ):
        """
        Test that when no passthrough credentials are set, default credentials are used in the request
        """
        from litellm.proxy.pass_through_endpoints.passthrough_endpoint_router import (
            PassthroughEndpointRouter,
        )

        # Setup default credentials
        default_project = "default-project"
        default_location = "us-central1"
        default_credentials = "default-creds"

        pass_through_router = PassthroughEndpointRouter()
        pass_through_router.default_vertex_config = VertexPassThroughCredentials(
            vertex_project=default_project,
            vertex_location=default_location,
            vertex_credentials=default_credentials,
        )

        monkeypatch.setattr(
            "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.passthrough_endpoint_router",
            pass_through_router,
        )

        # Use different project/location in request than the default
        endpoint = initial_endpoint

        mock_request = Request(
            scope={
                "type": "http",
                "method": "POST",
                "path": f"/vertex_ai/{endpoint}",
                "headers": {},
            }
        )
        mock_response = Response()

        with mock.patch(
            "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.vertex_llm_base._ensure_access_token_async"
        ) as mock_ensure_token, mock.patch(
            "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.vertex_llm_base._get_token_and_url"
        ) as mock_get_token, mock.patch(
            "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.create_pass_through_route"
        ) as mock_create_route:
            mock_ensure_token.return_value = ("test-auth-header", default_project)
            mock_get_token.return_value = (default_credentials, "")

            try:
                await vertex_proxy_route(
                    endpoint=endpoint,
                    request=mock_request,
                    fastapi_response=mock_response,
                )
            except Exception as e:
                traceback.print_exc()
                print(f"Error: {e}")

            # Verify default credentials were used
            mock_create_route.assert_called_once_with(
                endpoint=endpoint,
                target=f"https://{default_location}-aiplatform.googleapis.com/v1/projects/{default_project}/locations/{default_location}/publishers/google/models/gemini-1.5-flash:generateContent",
                custom_headers={"Authorization": f"Bearer {default_credentials}"},
            )

    @pytest.mark.asyncio
    async def test_vertex_passthrough_with_no_default_credentials(self, monkeypatch):
        """
        Test that when no default or matching credentials are set, the request's own headers are passed through directly
        """
        from litellm.proxy.pass_through_endpoints.passthrough_endpoint_router import (
            PassthroughEndpointRouter,
        )

        vertex_project = "my-project"
        vertex_location = "us-central1"
        vertex_credentials = "test-creds"

        test_project = "test-project"
        test_location = "test-location"
        test_token = "test-creds"

        pass_through_router = PassthroughEndpointRouter()

        pass_through_router.add_vertex_credentials(
            project_id=vertex_project,
            location=vertex_location,
            vertex_credentials=vertex_credentials,
        )

        monkeypatch.setattr(
            "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.passthrough_endpoint_router",
            pass_through_router,
        )

        endpoint = f"/v1/projects/{test_project}/locations/{test_location}/publishers/google/models/gemini-1.5-flash:generateContent"

        # Mock request
        mock_request = Request(
            scope={
                "type": "http",
                "method": "POST",
                "path": endpoint,
                "headers": [
                    (b"authorization", b"Bearer test-creds"),
                ],
            }
        )

        # Mock response
        mock_response = Response()

        with mock.patch(
            "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.vertex_llm_base._ensure_access_token_async"
        ) as mock_ensure_token, mock.patch(
            "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.vertex_llm_base._get_token_and_url"
        ) as mock_get_token, mock.patch(
            "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.create_pass_through_route"
        ) as mock_create_route:
            mock_ensure_token.return_value = ("test-auth-header", test_project)
            mock_get_token.return_value = (test_token, "")

            # Call the route
            try:
                await vertex_proxy_route(
                    endpoint=endpoint,
                    request=mock_request,
                    fastapi_response=mock_response,
                )
            except Exception as e:
                traceback.print_exc()
                print(f"Error: {e}")

            # Verify create_pass_through_route was called with correct arguments
            mock_create_route.assert_called_once_with(
                endpoint=endpoint,
                target=f"https://{test_location}-aiplatform.googleapis.com/v1/projects/{test_project}/locations/{test_location}/publishers/google/models/gemini-1.5-flash:generateContent",
                custom_headers={"authorization": f"Bearer {test_token}"},
            )

    @pytest.mark.asyncio
    async def test_async_vertex_proxy_route_api_key_auth(self):
        """
        Critical

        This is how Vertex AI JS SDK will Auth to Litellm Proxy
        """
        # Mock dependencies
        mock_request = Mock()
        mock_request.headers = {"x-litellm-api-key": "test-key-123"}
        mock_request.method = "POST"
        mock_response = Mock()

        with patch(
            "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.user_api_key_auth"
        ) as mock_auth:
            mock_auth.return_value = {"api_key": "test-key-123"}

            with patch(
                "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.create_pass_through_route"
            ) as mock_pass_through:
                mock_pass_through.return_value = AsyncMock(
                    return_value={"status": "success"}
                )

                # Call the function
                result = await vertex_proxy_route(
                    endpoint="v1/projects/test-project/locations/us-central1/publishers/google/models/gemini-1.5-pro:generateContent",
                    request=mock_request,
                    fastapi_response=mock_response,
                )

                # Verify user_api_key_auth was called with the correct Bearer token
                mock_auth.assert_called_once()
                call_args = mock_auth.call_args[1]
                assert call_args["api_key"] == "Bearer test-key-123"
@ -0,0 +1,44 @@
import json
import os
import sys
import traceback
from unittest import mock
from unittest.mock import MagicMock, patch

import httpx
import pytest
from fastapi import Request, Response
from fastapi.testclient import TestClient

sys.path.insert(
    0, os.path.abspath("../../../..")
)  # Adds the parent directory to the system path

from unittest.mock import Mock

from litellm.proxy.pass_through_endpoints.common_utils import get_litellm_virtual_key


@pytest.mark.asyncio
async def test_get_litellm_virtual_key():
    """
    Test that the get_litellm_virtual_key function correctly handles the API key authentication
    """
    # Test with x-litellm-api-key
    mock_request = Mock()
    mock_request.headers = {"x-litellm-api-key": "test-key-123"}
    result = get_litellm_virtual_key(mock_request)
    assert result == "Bearer test-key-123"

    # Test with Authorization header
    mock_request.headers = {"Authorization": "Bearer auth-key-456"}
    result = get_litellm_virtual_key(mock_request)
    assert result == "Bearer auth-key-456"

    # Test with both headers (x-litellm-api-key should take precedence)
    mock_request.headers = {
        "x-litellm-api-key": "test-key-123",
        "Authorization": "Bearer auth-key-456",
    }
    result = get_litellm_virtual_key(mock_request)
    assert result == "Bearer test-key-123"
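For context, the precedence these assertions pin down is small enough to sketch inline. The following is an illustrative reimplementation inferred from the test above, not the actual `common_utils` source; only the two header names and the `Bearer` normalization are guaranteed by the assertions.

```python
from fastapi import Request


def get_litellm_virtual_key_sketch(request: Request) -> str:
    """Illustrative sketch: mirrors the behavior asserted in the test above."""
    # x-litellm-api-key takes precedence over the Authorization header
    litellm_api_key = request.headers.get("x-litellm-api-key")
    if litellm_api_key:
        return f"Bearer {litellm_api_key}"
    # Otherwise fall back to the Authorization header, assumed to already
    # carry a "Bearer <key>" value
    return request.headers.get("Authorization")
```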
@ -339,9 +339,6 @@ def test_pass_through_routes_support_all_methods():
    from litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints import (
        router as llm_router,
    )
    from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import (
        router as vertex_router,
    )

    # Expected HTTP methods
    expected_methods = {"GET", "POST", "PUT", "DELETE", "PATCH"}

@ -361,7 +358,6 @@ def test_pass_through_routes_support_all_methods():

    # Check both routers
    check_router_methods(llm_router)
    check_router_methods(vertex_router)


def test_is_bedrock_agent_runtime_route():
@ -11,6 +11,7 @@ from unittest.mock import patch
from litellm.proxy.pass_through_endpoints.passthrough_endpoint_router import (
    PassthroughEndpointRouter,
)
from litellm.types.passthrough_endpoints.vertex_ai import VertexPassThroughCredentials

passthrough_endpoint_router = PassthroughEndpointRouter()


@ -132,3 +133,185 @@ class TestPassthroughEndpointRouter(unittest.TestCase):
            ),
            "COHERE_API_KEY",
        )

    def test_get_deployment_key(self):
        """Test _get_deployment_key with various inputs"""
        router = PassthroughEndpointRouter()

        # Test with valid inputs
        key = router._get_deployment_key("test-project", "us-central1")
        assert key == "test-project-us-central1"

        # Test with None values
        key = router._get_deployment_key(None, "us-central1")
        assert key is None

        key = router._get_deployment_key("test-project", None)
        assert key is None

        key = router._get_deployment_key(None, None)
        assert key is None

    def test_add_vertex_credentials(self):
        """Test add_vertex_credentials functionality"""
        router = PassthroughEndpointRouter()

        # Test adding valid credentials
        router.add_vertex_credentials(
            project_id="test-project",
            location="us-central1",
            vertex_credentials='{"credentials": "test-creds"}',
        )

        assert "test-project-us-central1" in router.deployment_key_to_vertex_credentials
        creds = router.deployment_key_to_vertex_credentials["test-project-us-central1"]
        assert creds.vertex_project == "test-project"
        assert creds.vertex_location == "us-central1"
        assert creds.vertex_credentials == '{"credentials": "test-creds"}'

        # Test adding with None values
        router.add_vertex_credentials(
            project_id=None,
            location=None,
            vertex_credentials='{"credentials": "test-creds"}',
        )
        # Should not add None values
        assert len(router.deployment_key_to_vertex_credentials) == 1
    def test_default_credentials(self):
        """
        Test get_vertex_credentials when no default credentials are set.

        A lookup for a project/location pair with no stored credentials
        should return None rather than falling back to a default.
        """
        router = PassthroughEndpointRouter()
        router.add_vertex_credentials(
            project_id="test-project",
            location="us-central1",
            vertex_credentials='{"credentials": "test-creds"}',
        )

        creds = router.get_vertex_credentials(
            project_id="test-project", location="us-central2"
        )

        assert creds is None

    def test_get_vertex_env_vars(self):
        """Test that _get_vertex_env_vars correctly reads environment variables"""
        # Set environment variables for the test
        os.environ["DEFAULT_VERTEXAI_PROJECT"] = "test-project-123"
        os.environ["DEFAULT_VERTEXAI_LOCATION"] = "us-central1"
        os.environ["DEFAULT_GOOGLE_APPLICATION_CREDENTIALS"] = "/path/to/creds"

        try:
            result = self.router._get_vertex_env_vars()
            print(result)

            # Verify the result
            assert isinstance(result, VertexPassThroughCredentials)
            assert result.vertex_project == "test-project-123"
            assert result.vertex_location == "us-central1"
            assert result.vertex_credentials == "/path/to/creds"

        finally:
            # Clean up environment variables
            del os.environ["DEFAULT_VERTEXAI_PROJECT"]
            del os.environ["DEFAULT_VERTEXAI_LOCATION"]
            del os.environ["DEFAULT_GOOGLE_APPLICATION_CREDENTIALS"]

    def test_set_default_vertex_config(self):
        """Test set_default_vertex_config with various inputs"""
        # Test with None config - set environment variables first
        os.environ["DEFAULT_VERTEXAI_PROJECT"] = "env-project"
        os.environ["DEFAULT_VERTEXAI_LOCATION"] = "env-location"
        os.environ["DEFAULT_GOOGLE_APPLICATION_CREDENTIALS"] = "env-creds"
        os.environ["GOOGLE_CREDS"] = "secret-creds"

        try:
            # Test with None config
            self.router.set_default_vertex_config()

            assert self.router.default_vertex_config.vertex_project == "env-project"
            assert self.router.default_vertex_config.vertex_location == "env-location"
            assert self.router.default_vertex_config.vertex_credentials == "env-creds"

            # Test with valid config.yaml settings on vertex_config
            test_config = {
                "vertex_project": "my-project-123",
                "vertex_location": "us-central1",
                "vertex_credentials": "path/to/creds",
            }
            self.router.set_default_vertex_config(test_config)

            assert self.router.default_vertex_config.vertex_project == "my-project-123"
            assert self.router.default_vertex_config.vertex_location == "us-central1"
            assert (
                self.router.default_vertex_config.vertex_credentials == "path/to/creds"
            )

            # Test with environment variable reference
            test_config = {
                "vertex_project": "my-project-123",
                "vertex_location": "us-central1",
                "vertex_credentials": "os.environ/GOOGLE_CREDS",
            }
            self.router.set_default_vertex_config(test_config)

            assert (
                self.router.default_vertex_config.vertex_credentials == "secret-creds"
            )

        finally:
            # Clean up environment variables
            del os.environ["DEFAULT_VERTEXAI_PROJECT"]
            del os.environ["DEFAULT_VERTEXAI_LOCATION"]
            del os.environ["DEFAULT_GOOGLE_APPLICATION_CREDENTIALS"]
            del os.environ["GOOGLE_CREDS"]

    def test_vertex_passthrough_router_init(self):
        """Test VertexPassThroughRouter initialization"""
        router = PassthroughEndpointRouter()
        assert isinstance(router.deployment_key_to_vertex_credentials, dict)
        assert len(router.deployment_key_to_vertex_credentials) == 0

    def test_get_vertex_credentials_none(self):
        """Test get_vertex_credentials with various inputs"""
        router = PassthroughEndpointRouter()

        router.set_default_vertex_config(
            config={
                "vertex_project": None,
                "vertex_location": None,
                "vertex_credentials": None,
            }
        )

        # Test with None project_id and location - should return default config
        creds = router.get_vertex_credentials(None, None)
        assert isinstance(creds, VertexPassThroughCredentials)

        # Test with valid project_id and location but no stored credentials
        creds = router.get_vertex_credentials("test-project", "us-central1")
        assert isinstance(creds, VertexPassThroughCredentials)
        assert creds.vertex_project is None
        assert creds.vertex_location is None
        assert creds.vertex_credentials is None

    def test_get_vertex_credentials_stored(self):
        """Test get_vertex_credentials with stored credentials"""
        router = PassthroughEndpointRouter()
        router.add_vertex_credentials(
            project_id="test-project",
            location="us-central1",
            vertex_credentials='{"credentials": "test-creds"}',
        )

        creds = router.get_vertex_credentials(
            project_id="test-project", location="us-central1"
        )
        assert creds.vertex_project == "test-project"
        assert creds.vertex_location == "us-central1"
        assert creds.vertex_credentials == '{"credentials": "test-creds"}'
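Taken together, these tests pin down the router's lookup contract: credentials are keyed by `"<project>-<location>"`, a partial key is rejected, and an unmapped pair falls back to whatever default config is set (possibly None). Below is a minimal sketch of that contract, inferred from the assertions rather than copied from the shipped `PassthroughEndpointRouter`.

```python
from typing import Dict, Optional

from litellm.types.passthrough_endpoints.vertex_ai import VertexPassThroughCredentials


class PassthroughEndpointRouterSketch:
    """Illustrative only: the lookup contract the tests above exercise."""

    def __init__(self):
        self.deployment_key_to_vertex_credentials: Dict[
            str, VertexPassThroughCredentials
        ] = {}
        self.default_vertex_config: Optional[VertexPassThroughCredentials] = None

    def _get_deployment_key(
        self, project_id: Optional[str], location: Optional[str]
    ) -> Optional[str]:
        # Both parts are required; a partial key would be ambiguous
        if project_id is None or location is None:
            return None
        return f"{project_id}-{location}"

    def get_vertex_credentials(
        self, project_id: Optional[str], location: Optional[str]
    ) -> Optional[VertexPassThroughCredentials]:
        key = self._get_deployment_key(project_id, location)
        if key in self.deployment_key_to_vertex_credentials:
            return self.deployment_key_to_vertex_credentials[key]
        # No exact match: fall back to the default config if one was set
        return self.default_vertex_config
```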
@ -1,294 +0,0 @@
import json
import os
import sys
from datetime import datetime
from unittest.mock import AsyncMock, Mock, patch

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path


import httpx
import pytest
import litellm
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj


from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import (
    get_litellm_virtual_key,
    vertex_proxy_route,
    _get_vertex_env_vars,
    set_default_vertex_config,
    VertexPassThroughCredentials,
    default_vertex_config,
)
from litellm.proxy.vertex_ai_endpoints.vertex_passthrough_router import (
    VertexPassThroughRouter,
)


@pytest.mark.asyncio
async def test_get_litellm_virtual_key():
    """
    Test that the get_litellm_virtual_key function correctly handles the API key authentication
    """
    # Test with x-litellm-api-key
    mock_request = Mock()
    mock_request.headers = {"x-litellm-api-key": "test-key-123"}
    result = get_litellm_virtual_key(mock_request)
    assert result == "Bearer test-key-123"

    # Test with Authorization header
    mock_request.headers = {"Authorization": "Bearer auth-key-456"}
    result = get_litellm_virtual_key(mock_request)
    assert result == "Bearer auth-key-456"

    # Test with both headers (x-litellm-api-key should take precedence)
    mock_request.headers = {
        "x-litellm-api-key": "test-key-123",
        "Authorization": "Bearer auth-key-456",
    }
    result = get_litellm_virtual_key(mock_request)
    assert result == "Bearer test-key-123"


@pytest.mark.asyncio
async def test_async_vertex_proxy_route_api_key_auth():
    """
    Critical

    This is how Vertex AI JS SDK will Auth to Litellm Proxy
    """
    # Mock dependencies
    mock_request = Mock()
    mock_request.headers = {"x-litellm-api-key": "test-key-123"}
    mock_request.method = "POST"
    mock_response = Mock()

    with patch(
        "litellm.proxy.vertex_ai_endpoints.vertex_endpoints.user_api_key_auth"
    ) as mock_auth:
        mock_auth.return_value = {"api_key": "test-key-123"}

        with patch(
            "litellm.proxy.vertex_ai_endpoints.vertex_endpoints.create_pass_through_route"
        ) as mock_pass_through:
            mock_pass_through.return_value = AsyncMock(
                return_value={"status": "success"}
            )

            # Call the function
            result = await vertex_proxy_route(
                endpoint="v1/projects/test-project/locations/us-central1/publishers/google/models/gemini-1.5-pro:generateContent",
                request=mock_request,
                fastapi_response=mock_response,
            )

            # Verify user_api_key_auth was called with the correct Bearer token
            mock_auth.assert_called_once()
            call_args = mock_auth.call_args[1]
            assert call_args["api_key"] == "Bearer test-key-123"


@pytest.mark.asyncio
async def test_get_vertex_env_vars():
    """Test that _get_vertex_env_vars correctly reads environment variables"""
    # Set environment variables for the test
    os.environ["DEFAULT_VERTEXAI_PROJECT"] = "test-project-123"
    os.environ["DEFAULT_VERTEXAI_LOCATION"] = "us-central1"
    os.environ["DEFAULT_GOOGLE_APPLICATION_CREDENTIALS"] = "/path/to/creds"

    try:
        result = _get_vertex_env_vars()
        print(result)

        # Verify the result
        assert isinstance(result, VertexPassThroughCredentials)
        assert result.vertex_project == "test-project-123"
        assert result.vertex_location == "us-central1"
        assert result.vertex_credentials == "/path/to/creds"

    finally:
        # Clean up environment variables
        del os.environ["DEFAULT_VERTEXAI_PROJECT"]
        del os.environ["DEFAULT_VERTEXAI_LOCATION"]
        del os.environ["DEFAULT_GOOGLE_APPLICATION_CREDENTIALS"]


@pytest.mark.asyncio
async def test_set_default_vertex_config():
    """Test set_default_vertex_config with various inputs"""
    # Test with None config - set environment variables first
    os.environ["DEFAULT_VERTEXAI_PROJECT"] = "env-project"
    os.environ["DEFAULT_VERTEXAI_LOCATION"] = "env-location"
    os.environ["DEFAULT_GOOGLE_APPLICATION_CREDENTIALS"] = "env-creds"
    os.environ["GOOGLE_CREDS"] = "secret-creds"

    try:
        # Test with None config
        set_default_vertex_config()
        from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import (
            default_vertex_config,
        )

        assert default_vertex_config.vertex_project == "env-project"
        assert default_vertex_config.vertex_location == "env-location"
        assert default_vertex_config.vertex_credentials == "env-creds"

        # Test with valid config.yaml settings on vertex_config
        test_config = {
            "vertex_project": "my-project-123",
            "vertex_location": "us-central1",
            "vertex_credentials": "path/to/creds",
        }
        set_default_vertex_config(test_config)
        from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import (
            default_vertex_config,
        )

        assert default_vertex_config.vertex_project == "my-project-123"
        assert default_vertex_config.vertex_location == "us-central1"
        assert default_vertex_config.vertex_credentials == "path/to/creds"

        # Test with environment variable reference
        test_config = {
            "vertex_project": "my-project-123",
            "vertex_location": "us-central1",
            "vertex_credentials": "os.environ/GOOGLE_CREDS",
        }
        set_default_vertex_config(test_config)
        from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import (
            default_vertex_config,
        )

        assert default_vertex_config.vertex_credentials == "secret-creds"

    finally:
        # Clean up environment variables
        del os.environ["DEFAULT_VERTEXAI_PROJECT"]
        del os.environ["DEFAULT_VERTEXAI_LOCATION"]
        del os.environ["DEFAULT_GOOGLE_APPLICATION_CREDENTIALS"]
        del os.environ["GOOGLE_CREDS"]


@pytest.mark.asyncio
async def test_vertex_passthrough_router_init():
    """Test VertexPassThroughRouter initialization"""
    router = VertexPassThroughRouter()
    assert isinstance(router.deployment_key_to_vertex_credentials, dict)
    assert len(router.deployment_key_to_vertex_credentials) == 0


@pytest.mark.asyncio
async def test_get_vertex_credentials_none():
    """Test get_vertex_credentials with various inputs"""
    from litellm.proxy.vertex_ai_endpoints import vertex_endpoints

    setattr(vertex_endpoints, "default_vertex_config", VertexPassThroughCredentials())
    router = VertexPassThroughRouter()

    # Test with None project_id and location - should return default config
    creds = router.get_vertex_credentials(None, None)
    assert isinstance(creds, VertexPassThroughCredentials)

    # Test with valid project_id and location but no stored credentials
    creds = router.get_vertex_credentials("test-project", "us-central1")
    assert isinstance(creds, VertexPassThroughCredentials)
    assert creds.vertex_project is None
    assert creds.vertex_location is None
    assert creds.vertex_credentials is None


@pytest.mark.asyncio
async def test_get_vertex_credentials_stored():
    """Test get_vertex_credentials with stored credentials"""
    router = VertexPassThroughRouter()
    router.add_vertex_credentials(
        project_id="test-project",
        location="us-central1",
        vertex_credentials='{"credentials": "test-creds"}',
    )

    creds = router.get_vertex_credentials(
        project_id="test-project", location="us-central1"
    )
    assert creds.vertex_project == "test-project"
    assert creds.vertex_location == "us-central1"
    assert creds.vertex_credentials == '{"credentials": "test-creds"}'


@pytest.mark.asyncio
async def test_add_vertex_credentials():
    """Test add_vertex_credentials functionality"""
    router = VertexPassThroughRouter()

    # Test adding valid credentials
    router.add_vertex_credentials(
        project_id="test-project",
        location="us-central1",
        vertex_credentials='{"credentials": "test-creds"}',
    )

    assert "test-project-us-central1" in router.deployment_key_to_vertex_credentials
    creds = router.deployment_key_to_vertex_credentials["test-project-us-central1"]
    assert creds.vertex_project == "test-project"
    assert creds.vertex_location == "us-central1"
    assert creds.vertex_credentials == '{"credentials": "test-creds"}'

    # Test adding with None values
    router.add_vertex_credentials(
        project_id=None,
        location=None,
        vertex_credentials='{"credentials": "test-creds"}',
    )
    # Should not add None values
    assert len(router.deployment_key_to_vertex_credentials) == 1


@pytest.mark.asyncio
async def test_get_deployment_key():
    """Test _get_deployment_key with various inputs"""
    router = VertexPassThroughRouter()

    # Test with valid inputs
    key = router._get_deployment_key("test-project", "us-central1")
    assert key == "test-project-us-central1"

    # Test with None values
    key = router._get_deployment_key(None, "us-central1")
    assert key is None

    key = router._get_deployment_key("test-project", None)
    assert key is None

    key = router._get_deployment_key(None, None)
    assert key is None


@pytest.mark.asyncio
async def test_get_vertex_project_id_from_url():
    """Test _get_vertex_project_id_from_url with various URLs"""
    # Test with valid URL
    url = "https://us-central1-aiplatform.googleapis.com/v1/projects/test-project/locations/us-central1/publishers/google/models/gemini-pro:streamGenerateContent"
    project_id = VertexPassThroughRouter._get_vertex_project_id_from_url(url)
    assert project_id == "test-project"

    # Test with invalid URL
    url = "https://invalid-url.com"
    project_id = VertexPassThroughRouter._get_vertex_project_id_from_url(url)
    assert project_id is None


@pytest.mark.asyncio
async def test_get_vertex_location_from_url():
    """Test _get_vertex_location_from_url with various URLs"""
    # Test with valid URL
    url = "https://us-central1-aiplatform.googleapis.com/v1/projects/test-project/locations/us-central1/publishers/google/models/gemini-pro:streamGenerateContent"
    location = VertexPassThroughRouter._get_vertex_location_from_url(url)
    assert location == "us-central1"

    # Test with invalid URL
    url = "https://invalid-url.com"
    location = VertexPassThroughRouter._get_vertex_location_from_url(url)
    assert location is None
@ -30,9 +30,6 @@ from litellm.proxy._types import LiteLLM_UserTable, LitellmUserRoles, UserAPIKey
from litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints import (
    router as llm_passthrough_router,
)
from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import (
    router as vertex_router,
)

# Replace the actual hash_token function with our mock
import litellm.proxy.auth.route_checks

@ -96,7 +93,7 @@ def test_is_llm_api_route():
    assert RouteChecks.is_llm_api_route("/key/regenerate/82akk800000000jjsk") is False
    assert RouteChecks.is_llm_api_route("/key/82akk800000000jjsk/delete") is False

    all_llm_api_routes = vertex_router.routes + llm_passthrough_router.routes
    all_llm_api_routes = llm_passthrough_router.routes

    # check all routes in llm_passthrough_router, ensure they are considered llm api routes
    for route in all_llm_api_routes:

@ -36,11 +36,11 @@ def test_initialize_deployment_for_pass_through_success():
    )

    # Verify the credentials were properly set
    from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import (
        vertex_pass_through_router,
    from litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints import (
        passthrough_endpoint_router,
    )

    vertex_creds = vertex_pass_through_router.get_vertex_credentials(
    vertex_creds = passthrough_endpoint_router.get_vertex_credentials(
        project_id="test-project", location="us-central1"
    )
    assert vertex_creds.vertex_project == "test-project"

@ -123,21 +123,21 @@ def test_add_vertex_pass_through_deployment():
    router.add_deployment(deployment)

    # Get the vertex credentials from the router
    from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import (
        vertex_pass_through_router,
    from litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints import (
        passthrough_endpoint_router,
    )

    # current state of pass-through vertex router
    print("\n vertex_pass_through_router.deployment_key_to_vertex_credentials\n\n")
    print(
        json.dumps(
            vertex_pass_through_router.deployment_key_to_vertex_credentials,
            passthrough_endpoint_router.deployment_key_to_vertex_credentials,
            indent=4,
            default=str,
        )
    )

    vertex_creds = vertex_pass_through_router.get_vertex_credentials(
    vertex_creds = passthrough_endpoint_router.get_vertex_credentials(
        project_id="test-project", location="us-central1"
    )
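End to end, the renamed `passthrough_endpoint_router` keeps the same registration/resolution flow the old `vertex_pass_through_router` had. A short usage sketch, using only calls exercised by the tests above:

```python
# Illustrative only: how a deployment's Vertex credentials appear to be
# registered and later resolved, inferred from the tests above.
from litellm.proxy.pass_through_endpoints.passthrough_endpoint_router import (
    PassthroughEndpointRouter,
)

router = PassthroughEndpointRouter()

# Registering credentials keys them by "<project>-<location>"
router.add_vertex_credentials(
    project_id="test-project",
    location="us-central1",
    vertex_credentials='{"credentials": "test-creds"}',
)

# Later, a pass-through request for that project/location resolves them
creds = router.get_vertex_credentials(
    project_id="test-project", location="us-central1"
)
assert creds.vertex_credentials == '{"credentials": "test-creds"}'
```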