Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 10:44:24 +00:00)

Merge branch 'main' into litellm_web_search_2

Commit 3a454d00df: 19 changed files with 1099 additions and 731 deletions
@@ -15,6 +15,91 @@ Pass-through endpoints for Vertex AI - call provider-specific endpoint, in native format (no translation).

 Just replace `https://REGION-aiplatform.googleapis.com` with `LITELLM_PROXY_BASE_URL/vertex_ai`

+LiteLLM supports 3 flows for calling Vertex AI endpoints via pass-through:
+
+1. **Specific Credentials**: Admin sets passthrough credentials for a specific project/region.
+2. **Default Credentials**: Admin sets default credentials.
+3. **Client-Side Credentials**: User can send client-side credentials through to Vertex AI (default behavior - if no default or mapped credentials are found, the request is passed through directly).
+
+## Example Usage
+
+<Tabs>
+<TabItem value="specific_credentials" label="Specific Project/Region">
+
+```yaml
+model_list:
+  - model_name: gemini-1.0-pro
+    litellm_params:
+      model: vertex_ai/gemini-1.0-pro
+      vertex_project: adroit-crow-413218
+      vertex_region: us-central1
+      vertex_credentials: /path/to/credentials.json
+      use_in_pass_through: true # 👈 KEY CHANGE
+```
+
+</TabItem>
+<TabItem value="default_credentials" label="Default Credentials">
+
+<Tabs>
+<TabItem value="yaml" label="Set in config.yaml">
+
+```yaml
+default_vertex_config:
+  vertex_project: adroit-crow-413218
+  vertex_region: us-central1
+  vertex_credentials: /path/to/credentials.json
+```
+
+</TabItem>
+<TabItem value="env_var" label="Set in environment variables">
+
+```bash
+export DEFAULT_VERTEXAI_PROJECT="adroit-crow-413218"
+export DEFAULT_VERTEXAI_LOCATION="us-central1"
+export DEFAULT_GOOGLE_APPLICATION_CREDENTIALS="/path/to/credentials.json"
+```
+
+</TabItem>
+</Tabs>
+
+</TabItem>
+<TabItem value="client_credentials" label="Client Credentials">
+
+Try Gemini 2.0 Flash (curl):
+
+```bash
+MODEL_ID="gemini-2.0-flash-001"
+PROJECT_ID="YOUR_PROJECT_ID"
+```
+
+```bash
+curl \
+  -X POST \
+  -H "Authorization: Bearer $(gcloud auth application-default print-access-token)" \
+  -H "Content-Type: application/json" \
+  "${LITELLM_PROXY_BASE_URL}/vertex_ai/v1/projects/${PROJECT_ID}/locations/us-central1/publishers/google/models/${MODEL_ID}:streamGenerateContent" -d \
+  $'{
+    "contents": {
+      "role": "user",
+      "parts": [
+        {
+          "fileData": {
+            "mimeType": "image/png",
+            "fileUri": "gs://generativeai-downloads/images/scones.jpg"
+          }
+        },
+        {
+          "text": "Describe this picture."
+        }
+      ]
+    }
+  }'
+```
+
+</TabItem>
+</Tabs>
+
 #### **Example Usage**
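To make the precedence among the three flows above concrete, here is a minimal sketch of the lookup order, assuming the `PassthroughEndpointRouter` fields introduced later in this commit (the function name itself is illustrative, not part of the commit):

```python
from typing import Optional

def resolve_vertex_credentials(router, project_id: Optional[str], location: Optional[str]):
    """Illustrative only: mirrors PassthroughEndpointRouter.get_vertex_credentials."""
    # Flow 1: credentials registered for this exact project/region (use_in_pass_through)
    key = f"{project_id}-{location}" if project_id and location else None
    if key and key in router.deployment_key_to_vertex_credentials:
        return router.deployment_key_to_vertex_credentials[key]
    # Flow 2: admin-set defaults (config.yaml default_vertex_config or DEFAULT_VERTEXAI_* env vars)
    if router.default_vertex_config is not None:
        return router.default_vertex_config
    # Flow 3: nothing configured -> the caller's own headers are forwarded to Vertex AI
    return None
```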
@@ -22,7 +107,7 @@ Just replace `https://REGION-aiplatform.googleapis.com` with `LITELLM_PROXY_BASE_URL/vertex_ai`

 <TabItem value="curl" label="curl">

 ```bash
-curl http://localhost:4000/vertex_ai/publishers/google/models/gemini-1.0-pro:generateContent \
+curl http://localhost:4000/vertex_ai/v1/projects/${PROJECT_ID}/locations/us-central1/publishers/google/models/${MODEL_ID}:generateContent \
   -H "Content-Type: application/json" \
   -H "x-litellm-api-key: Bearer sk-1234" \
   -d '{
@@ -101,7 +186,7 @@ litellm

 Let's call the Google AI Studio token counting endpoint

 ```bash
-curl http://localhost:4000/vertex-ai/publishers/google/models/gemini-1.0-pro:generateContent \
+curl http://localhost:4000/vertex-ai/v1/projects/${PROJECT_ID}/locations/us-central1/publishers/google/models/gemini-1.0-pro:generateContent \
   -H "Content-Type: application/json" \
   -H "Authorization: Bearer sk-1234" \
   -d '{
@@ -140,7 +225,7 @@ LiteLLM Proxy Server supports two methods of authentication to Vertex AI:

 ```shell
-curl http://localhost:4000/vertex_ai/publishers/google/models/gemini-1.5-flash-001:generateContent \
+curl http://localhost:4000/vertex_ai/v1/projects/${PROJECT_ID}/locations/us-central1/publishers/google/models/gemini-1.5-flash-001:generateContent \
   -H "Content-Type: application/json" \
   -H "x-litellm-api-key: Bearer sk-1234" \
   -d '{"contents":[{"role": "user", "parts":[{"text": "hi"}]}]}'
@@ -152,7 +237,7 @@ curl http://localhost:4000/vertex_ai/publishers/google/models/gemini-1.5-flash-001:generateContent \

 ```shell
-curl http://localhost:4000/vertex_ai/publishers/google/models/textembedding-gecko@001:predict \
+curl http://localhost:4000/vertex_ai/v1/projects/${PROJECT_ID}/locations/us-central1/publishers/google/models/textembedding-gecko@001:predict \
   -H "Content-Type: application/json" \
   -H "x-litellm-api-key: Bearer sk-1234" \
   -d '{"instances":[{"content": "gm"}]}'
@@ -162,7 +247,7 @@ curl http://localhost:4000/vertex_ai/publishers/google/models/textembedding-gecko@001:predict \

 ### Imagen API

 ```shell
-curl http://localhost:4000/vertex_ai/publishers/google/models/imagen-3.0-generate-001:predict \
+curl http://localhost:4000/vertex_ai/v1/projects/${PROJECT_ID}/locations/us-central1/publishers/google/models/imagen-3.0-generate-001:predict \
   -H "Content-Type: application/json" \
   -H "x-litellm-api-key: Bearer sk-1234" \
   -d '{"instances":[{"prompt": "make an otter"}], "parameters": {"sampleCount": 1}}'
@@ -172,7 +257,7 @@ curl http://localhost:4000/vertex_ai/publishers/google/models/imagen-3.0-generate-001:predict \

 ### Count Tokens API

 ```shell
-curl http://localhost:4000/vertex_ai/publishers/google/models/gemini-1.5-flash-001:countTokens \
+curl http://localhost:4000/vertex_ai/v1/projects/${PROJECT_ID}/locations/us-central1/publishers/google/models/gemini-1.5-flash-001:countTokens \
   -H "Content-Type: application/json" \
   -H "x-litellm-api-key: Bearer sk-1234" \
   -d '{"contents":[{"role": "user", "parts":[{"text": "hi"}]}]}'
@@ -183,7 +268,7 @@ Create Fine Tuning Job

 ```shell
-curl http://localhost:4000/vertex_ai/tuningJobs \
+curl http://localhost:4000/vertex_ai/v1/projects/${PROJECT_ID}/locations/us-central1/publishers/google/models/gemini-1.5-flash-001:tuningJobs \
   -H "Content-Type: application/json" \
   -H "x-litellm-api-key: Bearer sk-1234" \
   -d '{
@@ -243,7 +328,7 @@ Expected Response

 ```bash
-curl http://localhost:4000/vertex_ai/publishers/google/models/gemini-1.0-pro:generateContent \
+curl http://localhost:4000/vertex_ai/v1/projects/${PROJECT_ID}/locations/us-central1/publishers/google/models/gemini-1.0-pro:generateContent \
   -H "Content-Type: application/json" \
   -H "x-litellm-api-key: Bearer sk-1234" \
   -d '{
@@ -268,7 +353,7 @@ tags: ["vertex-js-sdk", "pass-through-endpoint"]

 <TabItem value="curl" label="curl">

 ```bash
-curl http://localhost:4000/vertex-ai/publishers/google/models/gemini-1.0-pro:generateContent \
+curl http://localhost:4000/vertex_ai/v1/projects/${PROJECT_ID}/locations/us-central1/publishers/google/models/gemini-1.0-pro:generateContent \
   -H "Content-Type: application/json" \
   -H "x-litellm-api-key: Bearer sk-1234" \
   -H "tags: vertex-js-sdk,pass-through-endpoint" \
@@ -1,3 +1,4 @@
+import re
 from typing import Dict, List, Literal, Optional, Tuple, Union

 import httpx
@@ -280,3 +281,81 @@ def _convert_vertex_datetime_to_openai_datetime(vertex_datetime: str) -> int:
     dt = datetime.strptime(vertex_datetime, "%Y-%m-%dT%H:%M:%S.%fZ")
     # Convert to Unix timestamp (seconds since epoch)
     return int(dt.timestamp())
+
+
+def get_vertex_project_id_from_url(url: str) -> Optional[str]:
+    """
+    Get the vertex project id from the url
+
+    `https://${LOCATION}-aiplatform.googleapis.com/v1/projects/${PROJECT_ID}/locations/${LOCATION}/publishers/google/models/${MODEL_ID}:streamGenerateContent`
+    """
+    match = re.search(r"/projects/([^/]+)", url)
+    return match.group(1) if match else None
+
+
+def get_vertex_location_from_url(url: str) -> Optional[str]:
+    """
+    Get the vertex location from the url
+
+    `https://${LOCATION}-aiplatform.googleapis.com/v1/projects/${PROJECT_ID}/locations/${LOCATION}/publishers/google/models/${MODEL_ID}:streamGenerateContent`
+    """
+    match = re.search(r"/locations/([^/]+)", url)
+    return match.group(1) if match else None
+
+
+def replace_project_and_location_in_route(
+    requested_route: str, vertex_project: str, vertex_location: str
+) -> str:
+    """
+    Replace project and location values in the route with the provided values
+    """
+    # Replace project and location values while keeping route structure
+    modified_route = re.sub(
+        r"/projects/[^/]+/locations/[^/]+/",
+        f"/projects/{vertex_project}/locations/{vertex_location}/",
+        requested_route,
+    )
+    return modified_route
+
+
+def construct_target_url(
+    base_url: str,
+    requested_route: str,
+    vertex_location: Optional[str],
+    vertex_project: Optional[str],
+) -> httpx.URL:
+    """
+    Allow user to specify their own project id / location.
+
+    If missing, use defaults
+
+    Handle cachedContent scenario - https://github.com/BerriAI/litellm/issues/5460
+
+    Constructed Url:
+    POST https://LOCATION-aiplatform.googleapis.com/{version}/projects/PROJECT_ID/locations/LOCATION/cachedContents
+    """
+    new_base_url = httpx.URL(base_url)
+    if "locations" in requested_route:  # contains the target project id + location
+        if vertex_project and vertex_location:
+            requested_route = replace_project_and_location_in_route(
+                requested_route, vertex_project, vertex_location
+            )
+        return new_base_url.copy_with(path=requested_route)
+
+    """
+    - Add endpoint version (e.g. v1beta for cachedContent, v1 for rest)
+    - Add default project id
+    - Add default location
+    """
+    vertex_version: Literal["v1", "v1beta1"] = "v1"
+    if "cachedContent" in requested_route:
+        vertex_version = "v1beta1"
+
+    base_requested_route = "{}/projects/{}/locations/{}".format(
+        vertex_version, vertex_project, vertex_location
+    )
+
+    updated_requested_route = "/" + base_requested_route + requested_route
+
+    updated_url = new_base_url.copy_with(path=updated_requested_route)
+    return updated_url
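A quick sketch of how these helpers behave (values are made up; the assertions just restate the regexes and routing rules above):

```python
from litellm.llms.vertex_ai.common_utils import (
    construct_target_url,
    get_vertex_location_from_url,
    get_vertex_project_id_from_url,
)

url = (
    "https://us-central1-aiplatform.googleapis.com/v1/projects/my-proj"
    "/locations/us-central1/publishers/google/models/gemini-1.5-flash:generateContent"
)
assert get_vertex_project_id_from_url(url) == "my-proj"    # first /projects/<id> segment
assert get_vertex_location_from_url(url) == "us-central1"  # first /locations/<loc> segment

# A route without /locations/ gets version + project + location prepended;
# cachedContent requests are sent to the v1beta1 API surface.
target = construct_target_url(
    base_url="https://us-central1-aiplatform.googleapis.com/",
    requested_route="/cachedContents",
    vertex_location="us-central1",
    vertex_project="my-proj",
)
assert str(target).endswith("/v1beta1/projects/my-proj/locations/us-central1/cachedContents")
```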
@@ -82,6 +82,31 @@
     "supports_system_messages": true,
     "supports_tool_choice": true
   },
+  "gpt-4o-search-preview-2025-03-11": {
+    "max_tokens": 16384,
+    "max_input_tokens": 128000,
+    "max_output_tokens": 16384,
+    "input_cost_per_token": 0.0000025,
+    "output_cost_per_token": 0.000010,
+    "input_cost_per_token_batches": 0.00000125,
+    "output_cost_per_token_batches": 0.00000500,
+    "cache_read_input_token_cost": 0.00000125,
+    "litellm_provider": "openai",
+    "mode": "chat",
+    "supports_function_calling": true,
+    "supports_parallel_function_calling": true,
+    "supports_response_schema": true,
+    "supports_vision": true,
+    "supports_prompt_caching": true,
+    "supports_system_messages": true,
+    "supports_tool_choice": true,
+    "supports_web_search": true,
+    "search_context_cost_per_query": {
+      "search_context_size_low": 0.030,
+      "search_context_size_medium": 0.035,
+      "search_context_size_high": 0.050
+    }
+  },
   "gpt-4o-search-preview": {
     "max_tokens": 16384,
     "max_input_tokens": 128000,
@@ -232,6 +257,31 @@
     "supports_system_messages": true,
     "supports_tool_choice": true
   },
+  "gpt-4o-mini-search-preview-2025-03-11": {
+    "max_tokens": 16384,
+    "max_input_tokens": 128000,
+    "max_output_tokens": 16384,
+    "input_cost_per_token": 0.00000015,
+    "output_cost_per_token": 0.00000060,
+    "input_cost_per_token_batches": 0.000000075,
+    "output_cost_per_token_batches": 0.00000030,
+    "cache_read_input_token_cost": 0.000000075,
+    "litellm_provider": "openai",
+    "mode": "chat",
+    "supports_function_calling": true,
+    "supports_parallel_function_calling": true,
+    "supports_response_schema": true,
+    "supports_vision": true,
+    "supports_prompt_caching": true,
+    "supports_system_messages": true,
+    "supports_tool_choice": true,
+    "supports_web_search": true,
+    "search_context_cost_per_query": {
+      "search_context_size_low": 0.025,
+      "search_context_size_medium": 0.0275,
+      "search_context_size_high": 0.030
+    }
+  },
   "gpt-4o-mini-search-preview": {
     "max_tokens": 16384,
     "max_input_tokens": 128000,
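As a sanity check on the pricing entries, a small illustrative cost calculation (hand-rolled, not litellm's internal cost-tracking code, and assuming `search_context_cost_per_query` is USD per query, as the key name suggests):

```python
# Cost of a hypothetical gpt-4o-search-preview-2025-03-11 call:
# 1,000 input tokens, 500 output tokens, one web search at medium context size.
entry = {
    "input_cost_per_token": 0.0000025,
    "output_cost_per_token": 0.000010,
    "search_context_cost_per_query": {"search_context_size_medium": 0.035},
}

cost = (
    1_000 * entry["input_cost_per_token"]   # $0.0025
    + 500 * entry["output_cost_per_token"]  # $0.0050
    + 1 * entry["search_context_cost_per_query"]["search_context_size_medium"]  # $0.0350
)
assert round(cost, 4) == 0.0425
```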
litellm/proxy/pass_through_endpoints/common_utils.py (new file, 16 lines)

@@ -0,0 +1,16 @@
from fastapi import Request


def get_litellm_virtual_key(request: Request) -> str:
    """
    Extract and format API key from request headers.
    Prioritizes `x-litellm-api-key` over the `Authorization` header.

    The Vertex JS SDK uses the `Authorization` header, so `x-litellm-api-key` is used to pass the litellm virtual key.
    """
    litellm_api_key = request.headers.get("x-litellm-api-key")
    if litellm_api_key:
        return f"Bearer {litellm_api_key}"
    return request.headers.get("Authorization", "")
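A short sketch of the two header paths (constructed with Starlette's `Request` the same way the tests later in this diff do; values are hypothetical):

```python
from starlette.requests import Request

from litellm.proxy.pass_through_endpoints.common_utils import get_litellm_virtual_key

def make_request(headers: list) -> Request:
    # minimal ASGI scope; header names must be lowercase bytes
    return Request(scope={"type": "http", "method": "POST", "path": "/", "headers": headers})

# x-litellm-api-key wins and is wrapped as a Bearer token
assert get_litellm_virtual_key(
    make_request([(b"x-litellm-api-key", b"sk-1234")])
) == "Bearer sk-1234"

# otherwise the Authorization header is returned unchanged (e.g. from the Vertex JS SDK)
assert get_litellm_virtual_key(
    make_request([(b"authorization", b"Bearer vertex-oauth-token")])
) == "Bearer vertex-oauth-token"
```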
@@ -12,10 +12,13 @@ import httpx
 from fastapi import APIRouter, Depends, HTTPException, Request, Response

 import litellm
+from litellm._logging import verbose_proxy_logger
 from litellm.constants import BEDROCK_AGENT_RUNTIME_PASS_THROUGH_ROUTES
+from litellm.llms.vertex_ai.vertex_llm_base import VertexBase
 from litellm.proxy._types import *
 from litellm.proxy.auth.route_checks import RouteChecks
 from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
+from litellm.proxy.pass_through_endpoints.common_utils import get_litellm_virtual_key
 from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
     create_pass_through_route,
 )

@@ -23,6 +26,7 @@ from litellm.secret_managers.main import get_secret_str

 from .passthrough_endpoint_router import PassthroughEndpointRouter

+vertex_llm_base = VertexBase()
 router = APIRouter()
 default_vertex_config = None
@@ -417,6 +421,138 @@ async def azure_proxy_route(
     )


+@router.api_route(
+    "/vertex-ai/{endpoint:path}",
+    methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
+    tags=["Vertex AI Pass-through", "pass-through"],
+    include_in_schema=False,
+)
+@router.api_route(
+    "/vertex_ai/{endpoint:path}",
+    methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
+    tags=["Vertex AI Pass-through", "pass-through"],
+)
+async def vertex_proxy_route(
+    endpoint: str,
+    request: Request,
+    fastapi_response: Response,
+):
+    """
+    Call LiteLLM proxy via Vertex AI SDK.
+
+    [Docs](https://docs.litellm.ai/docs/pass_through/vertex_ai)
+    """
+    from litellm.llms.vertex_ai.common_utils import (
+        construct_target_url,
+        get_vertex_location_from_url,
+        get_vertex_project_id_from_url,
+    )
+
+    encoded_endpoint = httpx.URL(endpoint).path
+    verbose_proxy_logger.debug("requested endpoint %s", endpoint)
+    headers: dict = {}
+    api_key_to_use = get_litellm_virtual_key(request=request)
+    user_api_key_dict = await user_api_key_auth(
+        request=request,
+        api_key=api_key_to_use,
+    )
+    vertex_project: Optional[str] = get_vertex_project_id_from_url(endpoint)
+    vertex_location: Optional[str] = get_vertex_location_from_url(endpoint)
+    vertex_credentials = passthrough_endpoint_router.get_vertex_credentials(
+        project_id=vertex_project,
+        location=vertex_location,
+    )
+
+    headers_passed_through = False
+    # Use headers from the incoming request if no vertex credentials are found
+    if vertex_credentials is None or vertex_credentials.vertex_project is None:
+        headers = dict(request.headers) or {}
+        headers_passed_through = True
+        verbose_proxy_logger.debug(
+            "default_vertex_config not set, incoming request headers %s", headers
+        )
+        base_target_url = f"https://{vertex_location}-aiplatform.googleapis.com/"
+        headers.pop("content-length", None)
+        headers.pop("host", None)
+    else:
+        vertex_project = vertex_credentials.vertex_project
+        vertex_location = vertex_credentials.vertex_location
+        vertex_credentials_str = vertex_credentials.vertex_credentials
+
+        # Construct base URL for the target endpoint
+        base_target_url = f"https://{vertex_location}-aiplatform.googleapis.com/"
+
+        _auth_header, vertex_project = await vertex_llm_base._ensure_access_token_async(
+            credentials=vertex_credentials_str,
+            project_id=vertex_project,
+            custom_llm_provider="vertex_ai_beta",
+        )
+
+        auth_header, _ = vertex_llm_base._get_token_and_url(
+            model="",
+            auth_header=_auth_header,
+            gemini_api_key=None,
+            vertex_credentials=vertex_credentials_str,
+            vertex_project=vertex_project,
+            vertex_location=vertex_location,
+            stream=False,
+            custom_llm_provider="vertex_ai_beta",
+            api_base="",
+        )
+
+        headers = {
+            "Authorization": f"Bearer {auth_header}",
+        }
+
+    request_route = encoded_endpoint
+    verbose_proxy_logger.debug("request_route %s", request_route)
+
+    # Ensure endpoint starts with '/' for proper URL construction
+    if not encoded_endpoint.startswith("/"):
+        encoded_endpoint = "/" + encoded_endpoint
+
+    # Construct the full target URL using httpx
+    updated_url = construct_target_url(
+        base_url=base_target_url,
+        requested_route=encoded_endpoint,
+        vertex_location=vertex_location,
+        vertex_project=vertex_project,
+    )
+
+    verbose_proxy_logger.debug("updated url %s", updated_url)
+
+    ## check for streaming
+    target = str(updated_url)
+    is_streaming_request = False
+    if "stream" in str(updated_url):
+        is_streaming_request = True
+        target += "?alt=sse"
+
+    ## CREATE PASS-THROUGH
+    endpoint_func = create_pass_through_route(
+        endpoint=endpoint,
+        target=target,
+        custom_headers=headers,
+    )  # dynamically construct pass-through endpoint based on incoming path
+
+    try:
+        received_value = await endpoint_func(
+            request,
+            fastapi_response,
+            user_api_key_dict,
+            stream=is_streaming_request,  # type: ignore
+        )
+    except Exception as e:
+        if headers_passed_through:
+            raise Exception(
+                f"No credentials found on proxy for this request. Headers were passed through directly but request failed with error: {str(e)}"
+            )
+        else:
+            raise e
+
+    return received_value
+
+
 @router.api_route(
     "/openai/{endpoint:path}",
     methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
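Given the docstring above ("Call LiteLLM proxy via Vertex AI SDK"), a sketch of client-side usage with the Vertex AI Python SDK; the endpoint and model values are placeholders, and authentication is elided (the proxy accepts a litellm virtual key via `x-litellm-api-key`, or a Google token via `Authorization`):

```python
import vertexai
from vertexai.generative_models import GenerativeModel

# Point the SDK at the proxy's pass-through route instead of
# https://REGION-aiplatform.googleapis.com (credential setup omitted here).
vertexai.init(
    project="adroit-crow-413218",
    location="us-central1",
    api_endpoint="http://localhost:4000/vertex_ai",  # LITELLM_PROXY_BASE_URL + /vertex_ai
    api_transport="rest",
)

model = GenerativeModel("gemini-1.5-flash-001")
print(model.generate_content("hi").text)
```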
@@ -1,7 +1,9 @@
 from typing import Dict, Optional

-from litellm._logging import verbose_logger
+from litellm._logging import verbose_router_logger
 from litellm.secret_managers.main import get_secret_str
+from litellm.types.llms.vertex_ai import VERTEX_CREDENTIALS_TYPES
+from litellm.types.passthrough_endpoints.vertex_ai import VertexPassThroughCredentials


 class PassthroughEndpointRouter:

@@ -11,6 +13,10 @@ class PassthroughEndpointRouter:

     def __init__(self):
         self.credentials: Dict[str, str] = {}
+        self.deployment_key_to_vertex_credentials: Dict[
+            str, VertexPassThroughCredentials
+        ] = {}
+        self.default_vertex_config: Optional[VertexPassThroughCredentials] = None

     def set_pass_through_credentials(
         self,

@@ -45,14 +51,14 @@ class PassthroughEndpointRouter:
             custom_llm_provider=custom_llm_provider,
             region_name=region_name,
         )
-        verbose_logger.debug(
+        verbose_router_logger.debug(
             f"Pass-through llm endpoints router, looking for credentials for {credential_name}"
         )
         if credential_name in self.credentials:
-            verbose_logger.debug(f"Found credentials for {credential_name}")
+            verbose_router_logger.debug(f"Found credentials for {credential_name}")
             return self.credentials[credential_name]
         else:
-            verbose_logger.debug(
+            verbose_router_logger.debug(
                 f"No credentials found for {credential_name}, looking for env variable"
             )
             _env_variable_name = (

@@ -62,6 +68,100 @@ class PassthroughEndpointRouter:
         )
         return get_secret_str(_env_variable_name)

+    def _get_vertex_env_vars(self) -> VertexPassThroughCredentials:
+        """
+        Helper to get vertex pass through config from environment variables
+
+        The following environment variables are used:
+        - DEFAULT_VERTEXAI_PROJECT (project id)
+        - DEFAULT_VERTEXAI_LOCATION (location)
+        - DEFAULT_GOOGLE_APPLICATION_CREDENTIALS (path to credentials file)
+        """
+        return VertexPassThroughCredentials(
+            vertex_project=get_secret_str("DEFAULT_VERTEXAI_PROJECT"),
+            vertex_location=get_secret_str("DEFAULT_VERTEXAI_LOCATION"),
+            vertex_credentials=get_secret_str("DEFAULT_GOOGLE_APPLICATION_CREDENTIALS"),
+        )
+
+    def set_default_vertex_config(self, config: Optional[dict] = None):
+        """Sets vertex configuration from provided config and/or environment variables
+
+        Args:
+            config (Optional[dict]): Configuration dictionary
+            Example: {
+                "vertex_project": "my-project-123",
+                "vertex_location": "us-central1",
+                "vertex_credentials": "os.environ/GOOGLE_CREDS"
+            }
+        """
+        # Initialize config dictionary if None
+        if config is None:
+            self.default_vertex_config = self._get_vertex_env_vars()
+            return
+
+        if isinstance(config, dict):
+            for key, value in config.items():
+                if isinstance(value, str) and value.startswith("os.environ/"):
+                    config[key] = get_secret_str(value)
+
+        self.default_vertex_config = VertexPassThroughCredentials(**config)
+
+    def add_vertex_credentials(
+        self,
+        project_id: str,
+        location: str,
+        vertex_credentials: VERTEX_CREDENTIALS_TYPES,
+    ):
+        """
+        Add the vertex credentials for the given project-id, location
+        """
+        deployment_key = self._get_deployment_key(
+            project_id=project_id,
+            location=location,
+        )
+        if deployment_key is None:
+            verbose_router_logger.debug(
+                "No deployment key found for project-id, location"
+            )
+            return
+        vertex_pass_through_credentials = VertexPassThroughCredentials(
+            vertex_project=project_id,
+            vertex_location=location,
+            vertex_credentials=vertex_credentials,
+        )
+        self.deployment_key_to_vertex_credentials[deployment_key] = (
+            vertex_pass_through_credentials
+        )
+
+    def _get_deployment_key(
+        self, project_id: Optional[str], location: Optional[str]
+    ) -> Optional[str]:
+        """
+        Get the deployment key for the given project-id, location
+        """
+        if project_id is None or location is None:
+            return None
+        return f"{project_id}-{location}"
+
+    def get_vertex_credentials(
+        self, project_id: Optional[str], location: Optional[str]
+    ) -> Optional[VertexPassThroughCredentials]:
+        """
+        Get the vertex credentials for the given project-id, location
+        """
+        deployment_key = self._get_deployment_key(
+            project_id=project_id,
+            location=location,
+        )
+
+        if deployment_key is None:
+            return self.default_vertex_config
+        if deployment_key in self.deployment_key_to_vertex_credentials:
+            return self.deployment_key_to_vertex_credentials[deployment_key]
+        else:
+            return self.default_vertex_config
+
     def _get_credential_name_for_provider(
         self,
         custom_llm_provider: str,
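A round-trip through the new router methods (a sketch with hypothetical values, not a test from this commit):

```python
from litellm.proxy.pass_through_endpoints.passthrough_endpoint_router import (
    PassthroughEndpointRouter,
)

router = PassthroughEndpointRouter()

# default credentials (flow 2)
router.set_default_vertex_config(
    config={"vertex_project": "default-proj", "vertex_location": "us-central1"}
)
# project/region-specific credentials (flow 1)
router.add_vertex_credentials(
    project_id="my-proj", location="us-east1", vertex_credentials="/path/to/creds.json"
)

exact = router.get_vertex_credentials(project_id="my-proj", location="us-east1")
assert exact is not None and exact.vertex_project == "my-proj"

# unknown (project, location) pairs fall back to the default config
fallback = router.get_vertex_credentials(project_id="other-proj", location="us-west1")
assert fallback is not None and fallback.vertex_project == "default-proj"
```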
@@ -235,6 +235,9 @@ from litellm.proxy.openai_files_endpoints.files_endpoints import (
     router as openai_files_router,
 )
 from litellm.proxy.openai_files_endpoints.files_endpoints import set_files_config
+from litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints import (
+    passthrough_endpoint_router,
+)
 from litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints import (
     router as llm_passthrough_router,
 )

@@ -272,8 +275,6 @@ from litellm.proxy.utils import (
 from litellm.proxy.vertex_ai_endpoints.langfuse_endpoints import (
     router as langfuse_router,
 )
-from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import router as vertex_router
-from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import set_default_vertex_config
 from litellm.router import (
     AssistantsTypedDict,
     Deployment,

@@ -2115,7 +2116,9 @@ class ProxyConfig:

         ## default config for vertex ai routes
         default_vertex_config = config.get("default_vertex_config", None)
-        set_default_vertex_config(config=default_vertex_config)
+        passthrough_endpoint_router.set_default_vertex_config(
+            config=default_vertex_config
+        )

         ## ROUTER SETTINGS (e.g. routing_strategy, ...)
         router_settings = config.get("router_settings", None)

@@ -8161,7 +8164,6 @@ app.include_router(batches_router)
 app.include_router(rerank_router)
 app.include_router(fine_tuning_router)
 app.include_router(credential_router)
-app.include_router(vertex_router)
 app.include_router(llm_passthrough_router)
 app.include_router(anthropic_router)
 app.include_router(langfuse_router)
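The net effect of these proxy_server.py changes, in miniature (hypothetical config dict; this is what `ProxyConfig` now runs instead of the deleted module-level `set_default_vertex_config`):

```python
from litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints import (
    passthrough_endpoint_router,
)

config = {
    "default_vertex_config": {
        "vertex_project": "adroit-crow-413218",
        "vertex_location": "us-central1",
        "vertex_credentials": "os.environ/GOOGLE_CREDS",  # resolved via get_secret_str
    }
}

## default config for vertex ai routes
default_vertex_config = config.get("default_vertex_config", None)
passthrough_endpoint_router.set_default_vertex_config(config=default_vertex_config)
```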
litellm/proxy/vertex_ai_endpoints/vertex_endpoints.py (deleted file)

@@ -1,274 +0,0 @@
import traceback
from typing import Optional

import httpx
from fastapi import APIRouter, HTTPException, Request, Response, status

import litellm
from litellm._logging import verbose_proxy_logger
from litellm.fine_tuning.main import vertex_fine_tuning_apis_instance
from litellm.proxy._types import *
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
    create_pass_through_route,
)
from litellm.secret_managers.main import get_secret_str
from litellm.types.passthrough_endpoints.vertex_ai import *

from .vertex_passthrough_router import VertexPassThroughRouter

router = APIRouter()
vertex_pass_through_router = VertexPassThroughRouter()

default_vertex_config: VertexPassThroughCredentials = VertexPassThroughCredentials()


def _get_vertex_env_vars() -> VertexPassThroughCredentials:
    """
    Helper to get vertex pass through config from environment variables

    The following environment variables are used:
    - DEFAULT_VERTEXAI_PROJECT (project id)
    - DEFAULT_VERTEXAI_LOCATION (location)
    - DEFAULT_GOOGLE_APPLICATION_CREDENTIALS (path to credentials file)
    """
    return VertexPassThroughCredentials(
        vertex_project=get_secret_str("DEFAULT_VERTEXAI_PROJECT"),
        vertex_location=get_secret_str("DEFAULT_VERTEXAI_LOCATION"),
        vertex_credentials=get_secret_str("DEFAULT_GOOGLE_APPLICATION_CREDENTIALS"),
    )


def set_default_vertex_config(config: Optional[dict] = None):
    """Sets vertex configuration from provided config and/or environment variables

    Args:
        config (Optional[dict]): Configuration dictionary
        Example: {
            "vertex_project": "my-project-123",
            "vertex_location": "us-central1",
            "vertex_credentials": "os.environ/GOOGLE_CREDS"
        }
    """
    global default_vertex_config

    # Initialize config dictionary if None
    if config is None:
        default_vertex_config = _get_vertex_env_vars()
        return

    if isinstance(config, dict):
        for key, value in config.items():
            if isinstance(value, str) and value.startswith("os.environ/"):
                config[key] = litellm.get_secret(value)

    _set_default_vertex_config(VertexPassThroughCredentials(**config))


def _set_default_vertex_config(
    vertex_pass_through_credentials: VertexPassThroughCredentials,
):
    global default_vertex_config
    default_vertex_config = vertex_pass_through_credentials


def exception_handler(e: Exception):
    verbose_proxy_logger.error(
        "litellm.proxy.proxy_server.v1/projects/tuningJobs(): Exception occurred - {}".format(
            str(e)
        )
    )
    verbose_proxy_logger.debug(traceback.format_exc())
    if isinstance(e, HTTPException):
        return ProxyException(
            message=getattr(e, "message", str(e.detail)),
            type=getattr(e, "type", "None"),
            param=getattr(e, "param", "None"),
            code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
        )
    else:
        error_msg = f"{str(e)}"
        return ProxyException(
            message=getattr(e, "message", error_msg),
            type=getattr(e, "type", "None"),
            param=getattr(e, "param", "None"),
            code=getattr(e, "status_code", 500),
        )


def construct_target_url(
    base_url: str,
    requested_route: str,
    default_vertex_location: Optional[str],
    default_vertex_project: Optional[str],
) -> httpx.URL:
    """
    Allow user to specify their own project id / location.

    If missing, use defaults

    Handle cachedContent scenario - https://github.com/BerriAI/litellm/issues/5460

    Constructed Url:
    POST https://LOCATION-aiplatform.googleapis.com/{version}/projects/PROJECT_ID/locations/LOCATION/cachedContents
    """
    new_base_url = httpx.URL(base_url)
    if "locations" in requested_route:  # contains the target project id + location
        updated_url = new_base_url.copy_with(path=requested_route)
        return updated_url
    """
    - Add endpoint version (e.g. v1beta for cachedContent, v1 for rest)
    - Add default project id
    - Add default location
    """
    vertex_version: Literal["v1", "v1beta1"] = "v1"
    if "cachedContent" in requested_route:
        vertex_version = "v1beta1"

    base_requested_route = "{}/projects/{}/locations/{}".format(
        vertex_version, default_vertex_project, default_vertex_location
    )

    updated_requested_route = "/" + base_requested_route + requested_route

    updated_url = new_base_url.copy_with(path=updated_requested_route)
    return updated_url


@router.api_route(
    "/vertex-ai/{endpoint:path}",
    methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
    tags=["Vertex AI Pass-through", "pass-through"],
    include_in_schema=False,
)
@router.api_route(
    "/vertex_ai/{endpoint:path}",
    methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
    tags=["Vertex AI Pass-through", "pass-through"],
)
async def vertex_proxy_route(
    endpoint: str,
    request: Request,
    fastapi_response: Response,
):
    """
    Call LiteLLM proxy via Vertex AI SDK.

    [Docs](https://docs.litellm.ai/docs/pass_through/vertex_ai)
    """
    encoded_endpoint = httpx.URL(endpoint).path
    verbose_proxy_logger.debug("requested endpoint %s", endpoint)
    headers: dict = {}
    api_key_to_use = get_litellm_virtual_key(request=request)
    user_api_key_dict = await user_api_key_auth(
        request=request,
        api_key=api_key_to_use,
    )

    vertex_project: Optional[str] = (
        VertexPassThroughRouter._get_vertex_project_id_from_url(endpoint)
    )
    vertex_location: Optional[str] = (
        VertexPassThroughRouter._get_vertex_location_from_url(endpoint)
    )
    vertex_credentials = vertex_pass_through_router.get_vertex_credentials(
        project_id=vertex_project,
        location=vertex_location,
    )

    # Use headers from the incoming request if no vertex credentials are found
    if vertex_credentials.vertex_project is None:
        headers = dict(request.headers) or {}
        verbose_proxy_logger.debug(
            "default_vertex_config not set, incoming request headers %s", headers
        )
        base_target_url = f"https://{vertex_location}-aiplatform.googleapis.com/"
        headers.pop("content-length", None)
        headers.pop("host", None)
    else:
        vertex_project = vertex_credentials.vertex_project
        vertex_location = vertex_credentials.vertex_location
        vertex_credentials_str = vertex_credentials.vertex_credentials

        # Construct base URL for the target endpoint
        base_target_url = f"https://{vertex_location}-aiplatform.googleapis.com/"

        _auth_header, vertex_project = (
            await vertex_fine_tuning_apis_instance._ensure_access_token_async(
                credentials=vertex_credentials_str,
                project_id=vertex_project,
                custom_llm_provider="vertex_ai_beta",
            )
        )

        auth_header, _ = vertex_fine_tuning_apis_instance._get_token_and_url(
            model="",
            auth_header=_auth_header,
            gemini_api_key=None,
            vertex_credentials=vertex_credentials_str,
            vertex_project=vertex_project,
            vertex_location=vertex_location,
            stream=False,
            custom_llm_provider="vertex_ai_beta",
            api_base="",
        )

        headers = {
            "Authorization": f"Bearer {auth_header}",
        }

    request_route = encoded_endpoint
    verbose_proxy_logger.debug("request_route %s", request_route)

    # Ensure endpoint starts with '/' for proper URL construction
    if not encoded_endpoint.startswith("/"):
        encoded_endpoint = "/" + encoded_endpoint

    # Construct the full target URL using httpx
    updated_url = construct_target_url(
        base_url=base_target_url,
        requested_route=encoded_endpoint,
        default_vertex_location=vertex_location,
        default_vertex_project=vertex_project,
    )
    # base_url = httpx.URL(base_target_url)
    # updated_url = base_url.copy_with(path=encoded_endpoint)

    verbose_proxy_logger.debug("updated url %s", updated_url)

    ## check for streaming
    target = str(updated_url)
    is_streaming_request = False
    if "stream" in str(updated_url):
        is_streaming_request = True
        target += "?alt=sse"

    ## CREATE PASS-THROUGH
    endpoint_func = create_pass_through_route(
        endpoint=endpoint,
        target=target,
        custom_headers=headers,
    )  # dynamically construct pass-through endpoint based on incoming path
    received_value = await endpoint_func(
        request,
        fastapi_response,
        user_api_key_dict,
        stream=is_streaming_request,  # type: ignore
    )

    return received_value


def get_litellm_virtual_key(request: Request) -> str:
    """
    Extract and format API key from request headers.
    Prioritizes x-litellm-api-key over Authorization header.

    Vertex JS SDK uses `Authorization` header, we use `x-litellm-api-key` to pass litellm virtual key
    """
    litellm_api_key = request.headers.get("x-litellm-api-key")
    if litellm_api_key:
        return f"Bearer {litellm_api_key}"
    return request.headers.get("Authorization", "")
litellm/proxy/vertex_ai_endpoints/vertex_passthrough_router.py (deleted file)

@@ -1,121 +0,0 @@
import json
import re
from typing import Dict, Optional

from litellm._logging import verbose_proxy_logger
from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import (
    VertexPassThroughCredentials,
)
from litellm.types.llms.vertex_ai import VERTEX_CREDENTIALS_TYPES


class VertexPassThroughRouter:
    """
    Vertex Pass Through Router for Vertex AI pass-through endpoints

    - if request specifies a project-id, location -> use credentials corresponding to the project-id, location
    - if request does not specify a project-id, location -> use credentials corresponding to the DEFAULT_VERTEXAI_PROJECT, DEFAULT_VERTEXAI_LOCATION
    """

    def __init__(self):
        """
        Initialize the VertexPassThroughRouter
        Stores the vertex credentials for each deployment key
        ```
        {
            "project_id-location": VertexPassThroughCredentials,
            "adroit-crow-us-central1": VertexPassThroughCredentials,
        }
        ```
        """
        self.deployment_key_to_vertex_credentials: Dict[
            str, VertexPassThroughCredentials
        ] = {}
        pass

    def get_vertex_credentials(
        self, project_id: Optional[str], location: Optional[str]
    ) -> VertexPassThroughCredentials:
        """
        Get the vertex credentials for the given project-id, location
        """
        from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import (
            default_vertex_config,
        )

        deployment_key = self._get_deployment_key(
            project_id=project_id,
            location=location,
        )
        if deployment_key is None:
            return default_vertex_config
        if deployment_key in self.deployment_key_to_vertex_credentials:
            return self.deployment_key_to_vertex_credentials[deployment_key]
        else:
            return default_vertex_config

    def add_vertex_credentials(
        self,
        project_id: str,
        location: str,
        vertex_credentials: VERTEX_CREDENTIALS_TYPES,
    ):
        """
        Add the vertex credentials for the given project-id, location
        """
        from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import (
            _set_default_vertex_config,
        )

        deployment_key = self._get_deployment_key(
            project_id=project_id,
            location=location,
        )
        if deployment_key is None:
            verbose_proxy_logger.debug(
                "No deployment key found for project-id, location"
            )
            return
        vertex_pass_through_credentials = VertexPassThroughCredentials(
            vertex_project=project_id,
            vertex_location=location,
            vertex_credentials=vertex_credentials,
        )
        self.deployment_key_to_vertex_credentials[deployment_key] = (
            vertex_pass_through_credentials
        )
        verbose_proxy_logger.debug(
            f"self.deployment_key_to_vertex_credentials: {json.dumps(self.deployment_key_to_vertex_credentials, indent=4, default=str)}"
        )
        _set_default_vertex_config(vertex_pass_through_credentials)

    def _get_deployment_key(
        self, project_id: Optional[str], location: Optional[str]
    ) -> Optional[str]:
        """
        Get the deployment key for the given project-id, location
        """
        if project_id is None or location is None:
            return None
        return f"{project_id}-{location}"

    @staticmethod
    def _get_vertex_project_id_from_url(url: str) -> Optional[str]:
        """
        Get the vertex project id from the url

        `https://${LOCATION}-aiplatform.googleapis.com/v1/projects/${PROJECT_ID}/locations/${LOCATION}/publishers/google/models/${MODEL_ID}:streamGenerateContent`
        """
        match = re.search(r"/projects/([^/]+)", url)
        return match.group(1) if match else None

    @staticmethod
    def _get_vertex_location_from_url(url: str) -> Optional[str]:
        """
        Get the vertex location from the url

        `https://${LOCATION}-aiplatform.googleapis.com/v1/projects/${PROJECT_ID}/locations/${LOCATION}/publishers/google/models/${MODEL_ID}:streamGenerateContent`
        """
        match = re.search(r"/locations/([^/]+)", url)
        return match.group(1) if match else None
@@ -4495,11 +4495,11 @@ class Router:
         Each provider uses diff .env vars for pass-through endpoints, this helper uses the deployment credentials to set the .env vars for pass-through endpoints
         """
         if deployment.litellm_params.use_in_pass_through is True:
-            if custom_llm_provider == "vertex_ai":
-                from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import (
-                    vertex_pass_through_router,
-                )
+            from litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints import (
+                passthrough_endpoint_router,
+            )

+            if custom_llm_provider == "vertex_ai":
                 if (
                     deployment.litellm_params.vertex_project is None
                     or deployment.litellm_params.vertex_location is None

@@ -4508,16 +4508,12 @@ class Router:
                 raise ValueError(
                     "vertex_project, vertex_location, and vertex_credentials must be set in litellm_params for pass-through endpoints"
                 )
-                vertex_pass_through_router.add_vertex_credentials(
+                passthrough_endpoint_router.add_vertex_credentials(
                     project_id=deployment.litellm_params.vertex_project,
                     location=deployment.litellm_params.vertex_location,
                     vertex_credentials=deployment.litellm_params.vertex_credentials,
                 )
             else:
-                from litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints import (
-                    passthrough_endpoint_router,
-                )
-
                 passthrough_endpoint_router.set_pass_through_credentials(
                     custom_llm_provider=custom_llm_provider,
                     api_base=deployment.litellm_params.api_base,
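What triggers this Router code path, sketched as a deployment entry (values hypothetical; mirrors the `use_in_pass_through` yaml example at the top of this diff, using the `vertex_location` field the code above checks):

```python
import litellm

router = litellm.Router(
    model_list=[
        {
            "model_name": "gemini-1.0-pro",
            "litellm_params": {
                "model": "vertex_ai/gemini-1.0-pro",
                # all three are required for pass-through registration:
                "vertex_project": "adroit-crow-413218",
                "vertex_location": "us-central1",
                "vertex_credentials": "/path/to/credentials.json",
                "use_in_pass_through": True,  # registers creds with passthrough_endpoint_router
            },
        }
    ]
)
```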
tests/litellm/llms/vertex_ai/test_vertex_ai_common_utils.py (new file, 43 lines)

@@ -0,0 +1,43 @@
import os
import sys
from unittest.mock import MagicMock, call, patch

import pytest

sys.path.insert(
    0, os.path.abspath("../../..")
)  # Adds the parent directory to the system path

import litellm
from litellm.llms.vertex_ai.common_utils import (
    get_vertex_location_from_url,
    get_vertex_project_id_from_url,
)


@pytest.mark.asyncio
async def test_get_vertex_project_id_from_url():
    """Test get_vertex_project_id_from_url with various URLs"""
    # Test with valid URL
    url = "https://us-central1-aiplatform.googleapis.com/v1/projects/test-project/locations/us-central1/publishers/google/models/gemini-pro:streamGenerateContent"
    project_id = get_vertex_project_id_from_url(url)
    assert project_id == "test-project"

    # Test with invalid URL
    url = "https://invalid-url.com"
    project_id = get_vertex_project_id_from_url(url)
    assert project_id is None


@pytest.mark.asyncio
async def test_get_vertex_location_from_url():
    """Test get_vertex_location_from_url with various URLs"""
    # Test with valid URL
    url = "https://us-central1-aiplatform.googleapis.com/v1/projects/test-project/locations/us-central1/publishers/google/models/gemini-pro:streamGenerateContent"
    location = get_vertex_location_from_url(url)
    assert location == "us-central1"

    # Test with invalid URL
    url = "https://invalid-url.com"
    location = get_vertex_location_from_url(url)
    assert location is None
@ -1,7 +1,9 @@
 import json
 import os
 import sys
-from unittest.mock import MagicMock, patch
+import traceback
+from unittest import mock
+from unittest.mock import AsyncMock, MagicMock, Mock, patch

 import httpx
 import pytest

@ -17,7 +19,9 @@ from litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints import (
     BaseOpenAIPassThroughHandler,
     RouteChecks,
     create_pass_through_route,
+    vertex_proxy_route,
 )
+from litellm.types.passthrough_endpoints.vertex_ai import VertexPassThroughCredentials


 class TestBaseOpenAIPassThroughHandler:
@ -176,3 +180,279 @@ class TestBaseOpenAIPassThroughHandler:
         print(f"query_params: {call_kwargs['query_params']}")
         assert call_kwargs["stream"] is False
         assert call_kwargs["query_params"] == {"model": "gpt-4"}
+
+
+class TestVertexAIPassThroughHandler:
+    """
+    Case 1: User set passthrough credentials - confirm credentials used.
+
+    Case 2: User set default credentials, no exact passthrough credentials - confirm default credentials used.
+
+    Case 3: No default credentials, no mapped credentials - request passed through directly.
+    """
+
+    @pytest.mark.asyncio
+    async def test_vertex_passthrough_with_credentials(self, monkeypatch):
+        """
+        Test that when passthrough credentials are set, they are correctly used in the request
+        """
+        from litellm.proxy.pass_through_endpoints.passthrough_endpoint_router import (
+            PassthroughEndpointRouter,
+        )
+
+        vertex_project = "test-project"
+        vertex_location = "us-central1"
+        vertex_credentials = "test-creds"
+
+        pass_through_router = PassthroughEndpointRouter()
+
+        pass_through_router.add_vertex_credentials(
+            project_id=vertex_project,
+            location=vertex_location,
+            vertex_credentials=vertex_credentials,
+        )
+
+        monkeypatch.setattr(
+            "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.passthrough_endpoint_router",
+            pass_through_router,
+        )
+
+        endpoint = f"/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/gemini-1.5-flash:generateContent"
+
+        # Mock request
+        mock_request = Request(
+            scope={
+                "type": "http",
+                "method": "POST",
+                "path": endpoint,
+                "headers": [
+                    (b"Authorization", b"Bearer test-creds"),
+                    (b"Content-Type", b"application/json"),
+                ],
+            }
+        )
+
+        # Mock response
+        mock_response = Response()
+
+        # Mock vertex credentials
+        test_project = vertex_project
+        test_location = vertex_location
+        test_token = vertex_credentials
+
+        with mock.patch(
+            "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.vertex_llm_base._ensure_access_token_async"
+        ) as mock_ensure_token, mock.patch(
+            "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.vertex_llm_base._get_token_and_url"
+        ) as mock_get_token, mock.patch(
+            "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.create_pass_through_route"
+        ) as mock_create_route:
+            mock_ensure_token.return_value = ("test-auth-header", test_project)
+            mock_get_token.return_value = (test_token, "")
+
+            # Call the route
+            try:
+                await vertex_proxy_route(
+                    endpoint=endpoint,
+                    request=mock_request,
+                    fastapi_response=mock_response,
+                )
+            except Exception as e:
+                print(f"Error: {e}")
+
+            # Verify create_pass_through_route was called with correct arguments
+            mock_create_route.assert_called_once_with(
+                endpoint=endpoint,
+                target=f"https://{test_location}-aiplatform.googleapis.com/v1/projects/{test_project}/locations/{test_location}/publishers/google/models/gemini-1.5-flash:generateContent",
+                custom_headers={"Authorization": f"Bearer {test_token}"},
+            )
+
+    @pytest.mark.parametrize(
+        "initial_endpoint",
+        [
+            "publishers/google/models/gemini-1.5-flash:generateContent",
+            "v1/projects/bad-project/locations/bad-location/publishers/google/models/gemini-1.5-flash:generateContent",
+        ],
+    )
+    @pytest.mark.asyncio
+    async def test_vertex_passthrough_with_default_credentials(
+        self, monkeypatch, initial_endpoint
+    ):
+        """
+        Test that when no passthrough credentials are set, default credentials are used in the request
+        """
+        from litellm.proxy.pass_through_endpoints.passthrough_endpoint_router import (
+            PassthroughEndpointRouter,
+        )
+
+        # Setup default credentials
+        default_project = "default-project"
+        default_location = "us-central1"
+        default_credentials = "default-creds"
+
+        pass_through_router = PassthroughEndpointRouter()
+        pass_through_router.default_vertex_config = VertexPassThroughCredentials(
+            vertex_project=default_project,
+            vertex_location=default_location,
+            vertex_credentials=default_credentials,
+        )
+
+        monkeypatch.setattr(
+            "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.passthrough_endpoint_router",
+            pass_through_router,
+        )
+
+        # Use different project/location in request than the default
+        endpoint = initial_endpoint
+
+        mock_request = Request(
+            scope={
+                "type": "http",
+                "method": "POST",
+                "path": f"/vertex_ai/{endpoint}",
+                "headers": {},
+            }
+        )
+        mock_response = Response()
+
+        with mock.patch(
+            "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.vertex_llm_base._ensure_access_token_async"
+        ) as mock_ensure_token, mock.patch(
+            "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.vertex_llm_base._get_token_and_url"
+        ) as mock_get_token, mock.patch(
+            "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.create_pass_through_route"
+        ) as mock_create_route:
+            mock_ensure_token.return_value = ("test-auth-header", default_project)
+            mock_get_token.return_value = (default_credentials, "")
+
+            try:
+                await vertex_proxy_route(
+                    endpoint=endpoint,
+                    request=mock_request,
+                    fastapi_response=mock_response,
+                )
+            except Exception as e:
+                traceback.print_exc()
+                print(f"Error: {e}")
+
+            # Verify default credentials were used
+            mock_create_route.assert_called_once_with(
+                endpoint=endpoint,
+                target=f"https://{default_location}-aiplatform.googleapis.com/v1/projects/{default_project}/locations/{default_location}/publishers/google/models/gemini-1.5-flash:generateContent",
+                custom_headers={"Authorization": f"Bearer {default_credentials}"},
+            )
+
+    @pytest.mark.asyncio
+    async def test_vertex_passthrough_with_no_default_credentials(self, monkeypatch):
+        """
+        Test that when neither mapped nor default credentials match, the
+        client's own credentials are passed through to Vertex AI
+        """
+        from litellm.proxy.pass_through_endpoints.passthrough_endpoint_router import (
+            PassthroughEndpointRouter,
+        )
+
+        vertex_project = "my-project"
+        vertex_location = "us-central1"
+        vertex_credentials = "test-creds"
+
+        test_project = "test-project"
+        test_location = "test-location"
+        test_token = "test-creds"
+
+        pass_through_router = PassthroughEndpointRouter()
+
+        pass_through_router.add_vertex_credentials(
+            project_id=vertex_project,
+            location=vertex_location,
+            vertex_credentials=vertex_credentials,
+        )
+
+        monkeypatch.setattr(
+            "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.passthrough_endpoint_router",
+            pass_through_router,
+        )
+
+        endpoint = f"/v1/projects/{test_project}/locations/{test_location}/publishers/google/models/gemini-1.5-flash:generateContent"
+
+        # Mock request
+        mock_request = Request(
+            scope={
+                "type": "http",
+                "method": "POST",
+                "path": endpoint,
+                "headers": [
+                    (b"authorization", b"Bearer test-creds"),
+                ],
+            }
+        )
+
+        # Mock response
+        mock_response = Response()
+
+        with mock.patch(
+            "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.vertex_llm_base._ensure_access_token_async"
+        ) as mock_ensure_token, mock.patch(
+            "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.vertex_llm_base._get_token_and_url"
+        ) as mock_get_token, mock.patch(
+            "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.create_pass_through_route"
+        ) as mock_create_route:
+            mock_ensure_token.return_value = ("test-auth-header", test_project)
+            mock_get_token.return_value = (test_token, "")
+
+            # Call the route
+            try:
+                await vertex_proxy_route(
+                    endpoint=endpoint,
+                    request=mock_request,
+                    fastapi_response=mock_response,
+                )
+            except Exception as e:
+                traceback.print_exc()
+                print(f"Error: {e}")
+
+            # Verify create_pass_through_route was called with correct arguments
+            mock_create_route.assert_called_once_with(
+                endpoint=endpoint,
+                target=f"https://{test_location}-aiplatform.googleapis.com/v1/projects/{test_project}/locations/{test_location}/publishers/google/models/gemini-1.5-flash:generateContent",
+                custom_headers={"authorization": f"Bearer {test_token}"},
+            )
+
+    @pytest.mark.asyncio
+    async def test_async_vertex_proxy_route_api_key_auth(self):
+        """
+        Critical
+
+        This is how the Vertex AI JS SDK authenticates to the LiteLLM Proxy
+        """
+        # Mock dependencies
+        mock_request = Mock()
+        mock_request.headers = {"x-litellm-api-key": "test-key-123"}
+        mock_request.method = "POST"
+        mock_response = Mock()
+
+        with patch(
+            "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.user_api_key_auth"
+        ) as mock_auth:
+            mock_auth.return_value = {"api_key": "test-key-123"}
+
+            with patch(
+                "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.create_pass_through_route"
+            ) as mock_pass_through:
+                mock_pass_through.return_value = AsyncMock(
+                    return_value={"status": "success"}
+                )
+
+                # Call the function
+                result = await vertex_proxy_route(
+                    endpoint="v1/projects/test-project/locations/us-central1/publishers/google/models/gemini-1.5-pro:generateContent",
+                    request=mock_request,
+                    fastapi_response=mock_response,
+                )
+
+        # Verify user_api_key_auth was called with the correct Bearer token
+        mock_auth.assert_called_once()
+        call_args = mock_auth.call_args[1]
+        assert call_args["api_key"] == "Bearer test-key-123"
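Read together, the three cases pin down a resolution order for pass-through auth: an exact project/location mapping wins, then the admin's default config, and only then does the client's own header travel through untouched. A condensed sketch of that fallback chain (the function name and signature are invented for illustration; the real logic lives in `vertex_proxy_route` and the router):

```python
from typing import Optional


def resolve_vertex_auth_header(
    router,  # assumed: a PassthroughEndpointRouter-like object
    project_id: Optional[str],
    location: Optional[str],
    client_auth_header: Optional[str],
) -> Optional[str]:
    # Cases 1 and 2: an exact project/location mapping, falling back to the
    # default config (router.get_vertex_credentials covers both lookups).
    creds = router.get_vertex_credentials(project_id, location)
    if creds is not None and creds.vertex_credentials is not None:
        # In the real flow an access token is minted from these credentials;
        # the tests mock _ensure_access_token_async / _get_token_and_url.
        return f"Bearer {creds.vertex_credentials}"
    # Case 3: nothing configured, so the client's own header passes through.
    return client_auth_header
```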
@ -0,0 +1,44 @@
import json
import os
import sys
import traceback
from unittest import mock
from unittest.mock import MagicMock, patch

import httpx
import pytest
from fastapi import Request, Response
from fastapi.testclient import TestClient

sys.path.insert(
    0, os.path.abspath("../../../..")
)  # Adds the parent directory to the system path

from unittest.mock import Mock

from litellm.proxy.pass_through_endpoints.common_utils import get_litellm_virtual_key


@pytest.mark.asyncio
async def test_get_litellm_virtual_key():
    """
    Test that the get_litellm_virtual_key function correctly handles the API key authentication
    """
    # Test with x-litellm-api-key
    mock_request = Mock()
    mock_request.headers = {"x-litellm-api-key": "test-key-123"}
    result = get_litellm_virtual_key(mock_request)
    assert result == "Bearer test-key-123"

    # Test with Authorization header
    mock_request.headers = {"Authorization": "Bearer auth-key-456"}
    result = get_litellm_virtual_key(mock_request)
    assert result == "Bearer auth-key-456"

    # Test with both headers (x-litellm-api-key should take precedence)
    mock_request.headers = {
        "x-litellm-api-key": "test-key-123",
        "Authorization": "Bearer auth-key-456",
    }
    result = get_litellm_virtual_key(mock_request)
    assert result == "Bearer test-key-123"
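The expected behavior is fully pinned down by these assertions; a sketch consistent with them (not necessarily the exact `common_utils` implementation) is:

```python
from fastapi import Request


def get_litellm_virtual_key(request: Request) -> str:
    # The x-litellm-api-key header takes precedence over Authorization,
    # and a bare key is normalized to "Bearer <key>" form.
    litellm_api_key = request.headers.get("x-litellm-api-key")
    if litellm_api_key:
        return f"Bearer {litellm_api_key}"
    return request.headers.get("Authorization", "")
```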
@ -339,9 +339,6 @@ def test_pass_through_routes_support_all_methods():
     from litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints import (
         router as llm_router,
     )
-    from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import (
-        router as vertex_router,
-    )

     # Expected HTTP methods
     expected_methods = {"GET", "POST", "PUT", "DELETE", "PATCH"}

@ -361,7 +358,6 @@ def test_pass_through_routes_support_all_methods():

     # Check both routers
     check_router_methods(llm_router)
-    check_router_methods(vertex_router)


 def test_is_bedrock_agent_runtime_route():
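`check_router_methods` is defined outside this hunk; given the `expected_methods` set above, a plausible shape for it (an assumption, not the file's actual helper) is:

```python
def check_router_methods(router) -> None:
    # Every pass-through route should accept all expected HTTP methods.
    expected_methods = {"GET", "POST", "PUT", "DELETE", "PATCH"}
    for route in router.routes:
        missing = expected_methods - set(route.methods)
        assert not missing, f"Route {route.path} is missing methods: {missing}"
```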
@ -11,6 +11,7 @@ from unittest.mock import patch
 from litellm.proxy.pass_through_endpoints.passthrough_endpoint_router import (
     PassthroughEndpointRouter,
 )
+from litellm.types.passthrough_endpoints.vertex_ai import VertexPassThroughCredentials

 passthrough_endpoint_router = PassthroughEndpointRouter()

@ -132,3 +133,185 @@ class TestPassthroughEndpointRouter(unittest.TestCase):
             ),
             "COHERE_API_KEY",
         )
+
+    def test_get_deployment_key(self):
+        """Test _get_deployment_key with various inputs"""
+        router = PassthroughEndpointRouter()
+
+        # Test with valid inputs
+        key = router._get_deployment_key("test-project", "us-central1")
+        assert key == "test-project-us-central1"
+
+        # Test with None values
+        key = router._get_deployment_key(None, "us-central1")
+        assert key is None
+
+        key = router._get_deployment_key("test-project", None)
+        assert key is None
+
+        key = router._get_deployment_key(None, None)
+        assert key is None
+
+    def test_add_vertex_credentials(self):
+        """Test add_vertex_credentials functionality"""
+        router = PassthroughEndpointRouter()
+
+        # Test adding valid credentials
+        router.add_vertex_credentials(
+            project_id="test-project",
+            location="us-central1",
+            vertex_credentials='{"credentials": "test-creds"}',
+        )
+
+        assert "test-project-us-central1" in router.deployment_key_to_vertex_credentials
+        creds = router.deployment_key_to_vertex_credentials["test-project-us-central1"]
+        assert creds.vertex_project == "test-project"
+        assert creds.vertex_location == "us-central1"
+        assert creds.vertex_credentials == '{"credentials": "test-creds"}'
+
+        # Test adding with None values
+        router.add_vertex_credentials(
+            project_id=None,
+            location=None,
+            vertex_credentials='{"credentials": "test-creds"}',
+        )
+        # Should not add None values
+        assert len(router.deployment_key_to_vertex_credentials) == 1
+
+    def test_default_credentials(self):
+        """
+        Test get_vertex_credentials with stored credentials.
+
+        When no default config is set and no exact project/location match
+        exists, no credentials should be returned.
+        """
+        router = PassthroughEndpointRouter()
+        router.add_vertex_credentials(
+            project_id="test-project",
+            location="us-central1",
+            vertex_credentials='{"credentials": "test-creds"}',
+        )
+
+        creds = router.get_vertex_credentials(
+            project_id="test-project", location="us-central2"
+        )
+
+        assert creds is None
+
+    def test_get_vertex_env_vars(self):
+        """Test that _get_vertex_env_vars correctly reads environment variables"""
+        # Set environment variables for the test
+        os.environ["DEFAULT_VERTEXAI_PROJECT"] = "test-project-123"
+        os.environ["DEFAULT_VERTEXAI_LOCATION"] = "us-central1"
+        os.environ["DEFAULT_GOOGLE_APPLICATION_CREDENTIALS"] = "/path/to/creds"
+
+        try:
+            result = self.router._get_vertex_env_vars()
+            print(result)
+
+            # Verify the result
+            assert isinstance(result, VertexPassThroughCredentials)
+            assert result.vertex_project == "test-project-123"
+            assert result.vertex_location == "us-central1"
+            assert result.vertex_credentials == "/path/to/creds"
+
+        finally:
+            # Clean up environment variables
+            del os.environ["DEFAULT_VERTEXAI_PROJECT"]
+            del os.environ["DEFAULT_VERTEXAI_LOCATION"]
+            del os.environ["DEFAULT_GOOGLE_APPLICATION_CREDENTIALS"]
+
+    def test_set_default_vertex_config(self):
+        """Test set_default_vertex_config with various inputs"""
+        # Test with None config - set environment variables first
+        os.environ["DEFAULT_VERTEXAI_PROJECT"] = "env-project"
+        os.environ["DEFAULT_VERTEXAI_LOCATION"] = "env-location"
+        os.environ["DEFAULT_GOOGLE_APPLICATION_CREDENTIALS"] = "env-creds"
+        os.environ["GOOGLE_CREDS"] = "secret-creds"
+
+        try:
+            # Test with None config
+            self.router.set_default_vertex_config()
+
+            assert self.router.default_vertex_config.vertex_project == "env-project"
+            assert self.router.default_vertex_config.vertex_location == "env-location"
+            assert self.router.default_vertex_config.vertex_credentials == "env-creds"
+
+            # Test with valid config.yaml settings on vertex_config
+            test_config = {
+                "vertex_project": "my-project-123",
+                "vertex_location": "us-central1",
+                "vertex_credentials": "path/to/creds",
+            }
+            self.router.set_default_vertex_config(test_config)
+
+            assert self.router.default_vertex_config.vertex_project == "my-project-123"
+            assert self.router.default_vertex_config.vertex_location == "us-central1"
+            assert (
+                self.router.default_vertex_config.vertex_credentials == "path/to/creds"
+            )
+
+            # Test with environment variable reference
+            test_config = {
+                "vertex_project": "my-project-123",
+                "vertex_location": "us-central1",
+                "vertex_credentials": "os.environ/GOOGLE_CREDS",
+            }
+            self.router.set_default_vertex_config(test_config)
+
+            assert (
+                self.router.default_vertex_config.vertex_credentials == "secret-creds"
+            )
+
+        finally:
+            # Clean up environment variables
+            del os.environ["DEFAULT_VERTEXAI_PROJECT"]
+            del os.environ["DEFAULT_VERTEXAI_LOCATION"]
+            del os.environ["DEFAULT_GOOGLE_APPLICATION_CREDENTIALS"]
+            del os.environ["GOOGLE_CREDS"]
+
+    def test_vertex_passthrough_router_init(self):
+        """Test PassthroughEndpointRouter initialization"""
+        router = PassthroughEndpointRouter()
+        assert isinstance(router.deployment_key_to_vertex_credentials, dict)
+        assert len(router.deployment_key_to_vertex_credentials) == 0
+
+    def test_get_vertex_credentials_none(self):
+        """Test get_vertex_credentials with various inputs"""
+        router = PassthroughEndpointRouter()
+
+        router.set_default_vertex_config(
+            config={
+                "vertex_project": None,
+                "vertex_location": None,
+                "vertex_credentials": None,
+            }
+        )
+
+        # Test with None project_id and location - should return default config
+        creds = router.get_vertex_credentials(None, None)
+        assert isinstance(creds, VertexPassThroughCredentials)
+
+        # Test with valid project_id and location but no stored credentials
+        creds = router.get_vertex_credentials("test-project", "us-central1")
+        assert isinstance(creds, VertexPassThroughCredentials)
+        assert creds.vertex_project is None
+        assert creds.vertex_location is None
+        assert creds.vertex_credentials is None
+
+    def test_get_vertex_credentials_stored(self):
+        """Test get_vertex_credentials with stored credentials"""
+        router = PassthroughEndpointRouter()
+        router.add_vertex_credentials(
+            project_id="test-project",
+            location="us-central1",
+            vertex_credentials='{"credentials": "test-creds"}',
+        )
+
+        creds = router.get_vertex_credentials(
+            project_id="test-project", location="us-central1"
+        )
+        assert creds.vertex_project == "test-project"
+        assert creds.vertex_location == "us-central1"
+        assert creds.vertex_credentials == '{"credentials": "test-creds"}'
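Taken together, these tests specify the router's whole lookup path: keys are `<project>-<location>` strings, `None` parts are rejected, and a failed exact match falls back to whatever default config was set (or `None` if none was). A hypothetical distillation of that contract (method and attribute names come from the tests; the bodies are assumptions):

```python
from typing import Dict, Optional


class RouterLookupSketch:
    """Hypothetical condensation of the behavior the tests assert."""

    def __init__(self) -> None:
        self.deployment_key_to_vertex_credentials: Dict[str, object] = {}
        self.default_vertex_config = None  # populated by set_default_vertex_config

    def _get_deployment_key(
        self, project_id: Optional[str], location: Optional[str]
    ) -> Optional[str]:
        # Both parts are required to form a "<project>-<location>" key.
        if project_id is None or location is None:
            return None
        return f"{project_id}-{location}"

    def add_vertex_credentials(self, project_id, location, vertex_credentials):
        key = self._get_deployment_key(project_id, location)
        if key is None:
            return  # None values are silently skipped, per the tests
        self.deployment_key_to_vertex_credentials[key] = vertex_credentials

    def get_vertex_credentials(self, project_id, location):
        key = self._get_deployment_key(project_id, location)
        if key in self.deployment_key_to_vertex_credentials:
            return self.deployment_key_to_vertex_credentials[key]
        # No exact match: fall back to the default config, which stays None
        # when no default was ever configured.
        return self.default_vertex_config
```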
@ -1,294 +0,0 @@
import json
import os
import sys
from datetime import datetime
from unittest.mock import AsyncMock, Mock, patch

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system-path


import httpx
import pytest
import litellm
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj


from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import (
    get_litellm_virtual_key,
    vertex_proxy_route,
    _get_vertex_env_vars,
    set_default_vertex_config,
    VertexPassThroughCredentials,
    default_vertex_config,
)
from litellm.proxy.vertex_ai_endpoints.vertex_passthrough_router import (
    VertexPassThroughRouter,
)


@pytest.mark.asyncio
async def test_get_litellm_virtual_key():
    """
    Test that the get_litellm_virtual_key function correctly handles the API key authentication
    """
    # Test with x-litellm-api-key
    mock_request = Mock()
    mock_request.headers = {"x-litellm-api-key": "test-key-123"}
    result = get_litellm_virtual_key(mock_request)
    assert result == "Bearer test-key-123"

    # Test with Authorization header
    mock_request.headers = {"Authorization": "Bearer auth-key-456"}
    result = get_litellm_virtual_key(mock_request)
    assert result == "Bearer auth-key-456"

    # Test with both headers (x-litellm-api-key should take precedence)
    mock_request.headers = {
        "x-litellm-api-key": "test-key-123",
        "Authorization": "Bearer auth-key-456",
    }
    result = get_litellm_virtual_key(mock_request)
    assert result == "Bearer test-key-123"


@pytest.mark.asyncio
async def test_async_vertex_proxy_route_api_key_auth():
    """
    Critical

    This is how Vertex AI JS SDK will Auth to Litellm Proxy
    """
    # Mock dependencies
    mock_request = Mock()
    mock_request.headers = {"x-litellm-api-key": "test-key-123"}
    mock_request.method = "POST"
    mock_response = Mock()

    with patch(
        "litellm.proxy.vertex_ai_endpoints.vertex_endpoints.user_api_key_auth"
    ) as mock_auth:
        mock_auth.return_value = {"api_key": "test-key-123"}

        with patch(
            "litellm.proxy.vertex_ai_endpoints.vertex_endpoints.create_pass_through_route"
        ) as mock_pass_through:
            mock_pass_through.return_value = AsyncMock(
                return_value={"status": "success"}
            )

            # Call the function
            result = await vertex_proxy_route(
                endpoint="v1/projects/test-project/locations/us-central1/publishers/google/models/gemini-1.5-pro:generateContent",
                request=mock_request,
                fastapi_response=mock_response,
            )

    # Verify user_api_key_auth was called with the correct Bearer token
    mock_auth.assert_called_once()
    call_args = mock_auth.call_args[1]
    assert call_args["api_key"] == "Bearer test-key-123"


@pytest.mark.asyncio
async def test_get_vertex_env_vars():
    """Test that _get_vertex_env_vars correctly reads environment variables"""
    # Set environment variables for the test
    os.environ["DEFAULT_VERTEXAI_PROJECT"] = "test-project-123"
    os.environ["DEFAULT_VERTEXAI_LOCATION"] = "us-central1"
    os.environ["DEFAULT_GOOGLE_APPLICATION_CREDENTIALS"] = "/path/to/creds"

    try:
        result = _get_vertex_env_vars()
        print(result)

        # Verify the result
        assert isinstance(result, VertexPassThroughCredentials)
        assert result.vertex_project == "test-project-123"
        assert result.vertex_location == "us-central1"
        assert result.vertex_credentials == "/path/to/creds"

    finally:
        # Clean up environment variables
        del os.environ["DEFAULT_VERTEXAI_PROJECT"]
        del os.environ["DEFAULT_VERTEXAI_LOCATION"]
        del os.environ["DEFAULT_GOOGLE_APPLICATION_CREDENTIALS"]


@pytest.mark.asyncio
async def test_set_default_vertex_config():
    """Test set_default_vertex_config with various inputs"""
    # Test with None config - set environment variables first
    os.environ["DEFAULT_VERTEXAI_PROJECT"] = "env-project"
    os.environ["DEFAULT_VERTEXAI_LOCATION"] = "env-location"
    os.environ["DEFAULT_GOOGLE_APPLICATION_CREDENTIALS"] = "env-creds"
    os.environ["GOOGLE_CREDS"] = "secret-creds"

    try:
        # Test with None config
        set_default_vertex_config()
        from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import (
            default_vertex_config,
        )

        assert default_vertex_config.vertex_project == "env-project"
        assert default_vertex_config.vertex_location == "env-location"
        assert default_vertex_config.vertex_credentials == "env-creds"

        # Test with valid config.yaml settings on vertex_config
        test_config = {
            "vertex_project": "my-project-123",
            "vertex_location": "us-central1",
            "vertex_credentials": "path/to/creds",
        }
        set_default_vertex_config(test_config)
        from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import (
            default_vertex_config,
        )

        assert default_vertex_config.vertex_project == "my-project-123"
        assert default_vertex_config.vertex_location == "us-central1"
        assert default_vertex_config.vertex_credentials == "path/to/creds"

        # Test with environment variable reference
        test_config = {
            "vertex_project": "my-project-123",
            "vertex_location": "us-central1",
            "vertex_credentials": "os.environ/GOOGLE_CREDS",
        }
        set_default_vertex_config(test_config)
        from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import (
            default_vertex_config,
        )

        assert default_vertex_config.vertex_credentials == "secret-creds"

    finally:
        # Clean up environment variables
        del os.environ["DEFAULT_VERTEXAI_PROJECT"]
        del os.environ["DEFAULT_VERTEXAI_LOCATION"]
        del os.environ["DEFAULT_GOOGLE_APPLICATION_CREDENTIALS"]
        del os.environ["GOOGLE_CREDS"]


@pytest.mark.asyncio
async def test_vertex_passthrough_router_init():
    """Test VertexPassThroughRouter initialization"""
    router = VertexPassThroughRouter()
    assert isinstance(router.deployment_key_to_vertex_credentials, dict)
    assert len(router.deployment_key_to_vertex_credentials) == 0


@pytest.mark.asyncio
async def test_get_vertex_credentials_none():
    """Test get_vertex_credentials with various inputs"""
    from litellm.proxy.vertex_ai_endpoints import vertex_endpoints

    setattr(vertex_endpoints, "default_vertex_config", VertexPassThroughCredentials())
    router = VertexPassThroughRouter()

    # Test with None project_id and location - should return default config
    creds = router.get_vertex_credentials(None, None)
    assert isinstance(creds, VertexPassThroughCredentials)

    # Test with valid project_id and location but no stored credentials
    creds = router.get_vertex_credentials("test-project", "us-central1")
    assert isinstance(creds, VertexPassThroughCredentials)
    assert creds.vertex_project is None
    assert creds.vertex_location is None
    assert creds.vertex_credentials is None


@pytest.mark.asyncio
async def test_get_vertex_credentials_stored():
    """Test get_vertex_credentials with stored credentials"""
    router = VertexPassThroughRouter()
    router.add_vertex_credentials(
        project_id="test-project",
        location="us-central1",
        vertex_credentials='{"credentials": "test-creds"}',
    )

    creds = router.get_vertex_credentials(
        project_id="test-project", location="us-central1"
    )
    assert creds.vertex_project == "test-project"
    assert creds.vertex_location == "us-central1"
    assert creds.vertex_credentials == '{"credentials": "test-creds"}'


@pytest.mark.asyncio
async def test_add_vertex_credentials():
    """Test add_vertex_credentials functionality"""
    router = VertexPassThroughRouter()

    # Test adding valid credentials
    router.add_vertex_credentials(
        project_id="test-project",
        location="us-central1",
        vertex_credentials='{"credentials": "test-creds"}',
    )

    assert "test-project-us-central1" in router.deployment_key_to_vertex_credentials
    creds = router.deployment_key_to_vertex_credentials["test-project-us-central1"]
    assert creds.vertex_project == "test-project"
    assert creds.vertex_location == "us-central1"
    assert creds.vertex_credentials == '{"credentials": "test-creds"}'

    # Test adding with None values
    router.add_vertex_credentials(
        project_id=None,
        location=None,
        vertex_credentials='{"credentials": "test-creds"}',
    )
    # Should not add None values
    assert len(router.deployment_key_to_vertex_credentials) == 1


@pytest.mark.asyncio
async def test_get_deployment_key():
    """Test _get_deployment_key with various inputs"""
    router = VertexPassThroughRouter()

    # Test with valid inputs
    key = router._get_deployment_key("test-project", "us-central1")
    assert key == "test-project-us-central1"

    # Test with None values
    key = router._get_deployment_key(None, "us-central1")
    assert key is None

    key = router._get_deployment_key("test-project", None)
    assert key is None

    key = router._get_deployment_key(None, None)
    assert key is None


@pytest.mark.asyncio
async def test_get_vertex_project_id_from_url():
    """Test _get_vertex_project_id_from_url with various URLs"""
    # Test with valid URL
    url = "https://us-central1-aiplatform.googleapis.com/v1/projects/test-project/locations/us-central1/publishers/google/models/gemini-pro:streamGenerateContent"
    project_id = VertexPassThroughRouter._get_vertex_project_id_from_url(url)
    assert project_id == "test-project"

    # Test with invalid URL
    url = "https://invalid-url.com"
    project_id = VertexPassThroughRouter._get_vertex_project_id_from_url(url)
    assert project_id is None


@pytest.mark.asyncio
async def test_get_vertex_location_from_url():
    """Test _get_vertex_location_from_url with various URLs"""
    # Test with valid URL
    url = "https://us-central1-aiplatform.googleapis.com/v1/projects/test-project/locations/us-central1/publishers/google/models/gemini-pro:streamGenerateContent"
    location = VertexPassThroughRouter._get_vertex_location_from_url(url)
    assert location == "us-central1"

    # Test with invalid URL
    url = "https://invalid-url.com"
    location = VertexPassThroughRouter._get_vertex_location_from_url(url)
    assert location is None
@ -30,9 +30,6 @@ from litellm.proxy._types import LiteLLM_UserTable, LitellmUserRoles, UserAPIKey
 from litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints import (
     router as llm_passthrough_router,
 )
-from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import (
-    router as vertex_router,
-)

 # Replace the actual hash_token function with our mock
 import litellm.proxy.auth.route_checks

@ -96,7 +93,7 @@ def test_is_llm_api_route():
     assert RouteChecks.is_llm_api_route("/key/regenerate/82akk800000000jjsk") is False
     assert RouteChecks.is_llm_api_route("/key/82akk800000000jjsk/delete") is False

-    all_llm_api_routes = vertex_router.routes + llm_passthrough_router.routes
+    all_llm_api_routes = llm_passthrough_router.routes

     # check all routes in llm_passthrough_router, ensure they are considered llm api routes
     for route in all_llm_api_routes:
@ -36,11 +36,11 @@ def test_initialize_deployment_for_pass_through_success():
     )

     # Verify the credentials were properly set
-    from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import (
-        vertex_pass_through_router,
+    from litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints import (
+        passthrough_endpoint_router,
     )

-    vertex_creds = vertex_pass_through_router.get_vertex_credentials(
+    vertex_creds = passthrough_endpoint_router.get_vertex_credentials(
         project_id="test-project", location="us-central1"
     )
     assert vertex_creds.vertex_project == "test-project"

@ -123,21 +123,21 @@ def test_add_vertex_pass_through_deployment():
     router.add_deployment(deployment)

     # Get the vertex credentials from the router
-    from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import (
-        vertex_pass_through_router,
+    from litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints import (
+        passthrough_endpoint_router,
     )

     # current state of pass-through vertex router
     print("\n vertex_pass_through_router.deployment_key_to_vertex_credentials\n\n")
     print(
         json.dumps(
-            vertex_pass_through_router.deployment_key_to_vertex_credentials,
+            passthrough_endpoint_router.deployment_key_to_vertex_credentials,
             indent=4,
             default=str,
         )
     )

-    vertex_creds = vertex_pass_through_router.get_vertex_credentials(
+    vertex_creds = passthrough_endpoint_router.get_vertex_credentials(
         project_id="test-project", location="us-central1"
     )
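These last hunks complete the migration: deployment initialization now registers Vertex credentials with the shared `passthrough_endpoint_router` instead of the removed vertex-specific router. A rough sketch of the registration step the first test exercises (the function name and the shape of `litellm_params` are assumptions based on the test, not the exact LiteLLM code):

```python
# Hypothetical wiring: when a deployment is configured for pass-through use,
# its Vertex credentials become retrievable from the shared router.
def initialize_deployment_for_pass_through(deployment, passthrough_endpoint_router):
    params = deployment.litellm_params  # assumed attribute layout
    passthrough_endpoint_router.add_vertex_credentials(
        project_id=params.vertex_project,
        location=params.vertex_location,
        vertex_credentials=params.vertex_credentials,
    )
```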