diff --git a/docs/my-website/docs/pass_through/vertex_ai.md b/docs/my-website/docs/pass_through/vertex_ai.md
index 744c5e3ff..601f89f4b 100644
--- a/docs/my-website/docs/pass_through/vertex_ai.md
+++ b/docs/my-website/docs/pass_through/vertex_ai.md
@@ -69,6 +69,44 @@ generateContent();
+## Quick Start
+
+Let's call the Vertex AI [`/generateContent` endpoint](https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference).
+
+1. Add Vertex AI Credentials to your environment
+
+```bash
+export DEFAULT_VERTEXAI_PROJECT="" # "adroit-crow-413218"
+export DEFAULT_VERTEXAI_LOCATION="" # "us-central1"
+export DEFAULT_GOOGLE_APPLICATION_CREDENTIALS="" # "/Users/Downloads/adroit-crow-413218-a956eef1a2a8.json"
+```
+
+2. Start LiteLLM Proxy
+
+```bash
+litellm
+
+# RUNNING on http://0.0.0.0:4000
+```
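+
+You can also start the proxy with a config file: `litellm --config /path/to/config.yaml`. If no `default_vertex_config` is set in the config, the proxy falls back to the `DEFAULT_VERTEXAI_*` environment variables from step 1.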
+
+3. Test it!
+
+Send a sample `generateContent` request through LiteLLM Proxy
+
+```bash
+curl http://localhost:4000/vertex_ai/publishers/google/models/gemini-1.0-pro:generateContent \
+ -H "Content-Type: application/json" \
+ -H "Authorization: Bearer sk-1234" \
+ -d '{
+ "contents":[{
+ "role": "user",
+ "parts":[{"text": "How are you doing today?"}]
+ }]
+ }'
+```
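+
+For reference, here is the same request sent from Python with `httpx` (a minimal sketch; it assumes the proxy is running locally with the virtual key `sk-1234`):
+
+```python
+import httpx
+
+# LiteLLM proxy details (adjust to your deployment)
+LITELLM_BASE_URL = "http://localhost:4000"
+LITELLM_PROXY_API_KEY = "sk-1234"
+
+# Call the Vertex AI generateContent endpoint via the LiteLLM pass-through route
+response = httpx.post(
+    f"{LITELLM_BASE_URL}/vertex_ai/publishers/google/models/gemini-1.0-pro:generateContent",
+    headers={"Authorization": f"Bearer {LITELLM_PROXY_API_KEY}"},
+    json={
+        "contents": [
+            {"role": "user", "parts": [{"text": "How are you doing today?"}]}
+        ]
+    },
+)
+print(response.json())
+```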
+
## Supported API Endpoints
- Gemini API
@@ -87,206 +125,12 @@ LiteLLM Proxy Server supports two methods of authentication to Vertex AI:
2. Set Vertex AI credentials on proxy server
-## Quick Start Usage
-
-
-
-
-
-#### 1. Start litellm proxy
-
-```shell
-litellm --config /path/to/config.yaml
-```
-
-#### 2. Test it
-
-```python
-import vertexai
-from vertexai.preview.generative_models import GenerativeModel
-
-LITE_LLM_ENDPOINT = "http://localhost:4000"
-
-vertexai.init(
- project="", # enter your project id
- location="", # enter your region
- api_endpoint=f"{LITE_LLM_ENDPOINT}/vertex_ai", # route on litellm
- api_transport="rest",
-)
-
-model = GenerativeModel(model_name="gemini-1.0-pro")
-model.generate_content("hi")
-
-```
-
-
-
-
-
-
-#### 1. Set `default_vertex_config` on your `config.yaml`
-
-
-Add the following credentials to your litellm config.yaml to use the Vertex AI endpoints.
-
-```yaml
-default_vertex_config:
- vertex_project: "adroit-crow-413218"
- vertex_location: "us-central1"
- vertex_credentials: "/Users/ishaanjaffer/Downloads/adroit-crow-413218-a956eef1a2a8.json" # Add path to service account.json
-```
-
-#### 2. Start litellm proxy
-
-```shell
-litellm --config /path/to/config.yaml
-```
-
-#### 3. Test it
-
-```python
-import vertexai
-from google.auth.credentials import Credentials
-from vertexai.generative_models import GenerativeModel
-
-LITELLM_PROXY_API_KEY = "sk-1234"
-LITELLM_PROXY_BASE = "http://0.0.0.0:4000/vertex_ai"
-
-import datetime
-
-
-class CredentialsWrapper(Credentials):
- def __init__(self, token=None):
- super().__init__()
- self.token = token
- self.expiry = None # or set to a future date if needed
-
- def refresh(self, request):
- pass
-
- def apply(self, headers, token=None):
- headers["Authorization"] = f"Bearer {self.token}"
-
- @property
- def expired(self):
- return False # Always consider the token as non-expired
-
- @property
- def valid(self):
- return True # Always consider the credentials as valid
-
-
-credentials = CredentialsWrapper(token=LITELLM_PROXY_API_KEY)
-
-vertexai.init(
- project="adroit-crow-413218",
- location="us-central1",
- api_endpoint=LITELLM_PROXY_BASE,
- credentials=credentials,
- api_transport="rest",
-)
-
-model = GenerativeModel("gemini-1.5-flash-001")
-
-response = model.generate_content(
- "What's a good name for a flower shop that specializes in selling bouquets of dried flowers?"
-)
-
-print(response.text)
-```
-
-
-
-
## Usage Examples
### Gemini API (Generate Content)
-
-
-```python
-import vertexai
-from vertexai.generative_models import GenerativeModel
-
-LITELLM_PROXY_API_KEY = "sk-1234"
-LITELLM_PROXY_BASE = "http://0.0.0.0:4000/vertex_ai"
-
-vertexai.init(
- project="adroit-crow-413218",
- location="us-central1",
- api_endpoint=LITELLM_PROXY_BASE,
- api_transport="rest",
-
-)
-
-model = GenerativeModel("gemini-1.5-flash-001")
-
-response = model.generate_content(
- "What's a good name for a flower shop that specializes in selling bouquets of dried flowers?"
-)
-
-print(response.text)
-```
-
-
-
-
-```python
-import vertexai
-from google.auth.credentials import Credentials
-from vertexai.generative_models import GenerativeModel
-
-LITELLM_PROXY_API_KEY = "sk-1234"
-LITELLM_PROXY_BASE = "http://0.0.0.0:4000/vertex_ai"
-
-import datetime
-
-
-class CredentialsWrapper(Credentials):
- def __init__(self, token=None):
- super().__init__()
- self.token = token
- self.expiry = None # or set to a future date if needed
-
- def refresh(self, request):
- pass
-
- def apply(self, headers, token=None):
- headers["Authorization"] = f"Bearer {self.token}"
-
- @property
- def expired(self):
- return False # Always consider the token as non-expired
-
- @property
- def valid(self):
- return True # Always consider the credentials as valid
-
-
-credentials = CredentialsWrapper(token=LITELLM_PROXY_API_KEY)
-
-vertexai.init(
- project="adroit-crow-413218",
- location="us-central1",
- api_endpoint=LITELLM_PROXY_BASE,
- credentials=credentials,
- api_transport="rest",
-
-)
-
-model = GenerativeModel("gemini-1.5-flash-001")
-
-response = model.generate_content(
- "What's a good name for a flower shop that specializes in selling bouquets of dried flowers?"
-)
-
-print(response.text)
-```
-
-
-
```shell
curl http://localhost:4000/vertex_ai/publishers/google/models/gemini-1.5-flash-001:generateContent \
@@ -295,114 +139,10 @@ curl http://localhost:4000/vertex_ai/publishers/google/models/gemini-1.5-flash-0
-d '{"contents":[{"role": "user", "parts":[{"text": "hi"}]}]}'
```
-
-
### Embeddings API
-
-
-
-
-```python
-from typing import List, Optional
-from vertexai.language_models import TextEmbeddingInput, TextEmbeddingModel
-import vertexai
-from vertexai.generative_models import GenerativeModel
-
-LITELLM_PROXY_API_KEY = "sk-1234"
-LITELLM_PROXY_BASE = "http://0.0.0.0:4000/vertex_ai"
-
-import datetime
-
-vertexai.init(
- project="adroit-crow-413218",
- location="us-central1",
- api_endpoint=LITELLM_PROXY_BASE,
- api_transport="rest",
-)
-
-
-def embed_text(
- texts: List[str] = ["banana muffins? ", "banana bread? banana muffins?"],
- task: str = "RETRIEVAL_DOCUMENT",
- model_name: str = "text-embedding-004",
- dimensionality: Optional[int] = 256,
-) -> List[List[float]]:
- """Embeds texts with a pre-trained, foundational model."""
- model = TextEmbeddingModel.from_pretrained(model_name)
- inputs = [TextEmbeddingInput(text, task) for text in texts]
- kwargs = dict(output_dimensionality=dimensionality) if dimensionality else {}
- embeddings = model.get_embeddings(inputs, **kwargs)
- return [embedding.values for embedding in embeddings]
-```
-
-
-
-
-
-```python
-from typing import List, Optional
-from vertexai.language_models import TextEmbeddingInput, TextEmbeddingModel
-import vertexai
-from google.auth.credentials import Credentials
-from vertexai.generative_models import GenerativeModel
-
-LITELLM_PROXY_API_KEY = "sk-1234"
-LITELLM_PROXY_BASE = "http://0.0.0.0:4000/vertex_ai"
-
-import datetime
-
-
-class CredentialsWrapper(Credentials):
- def __init__(self, token=None):
- super().__init__()
- self.token = token
- self.expiry = None # or set to a future date if needed
-
- def refresh(self, request):
- pass
-
- def apply(self, headers, token=None):
- headers["Authorization"] = f"Bearer {self.token}"
-
- @property
- def expired(self):
- return False # Always consider the token as non-expired
-
- @property
- def valid(self):
- return True # Always consider the credentials as valid
-
-
-credentials = CredentialsWrapper(token=LITELLM_PROXY_API_KEY)
-
-vertexai.init(
- project="adroit-crow-413218",
- location="us-central1",
- api_endpoint=LITELLM_PROXY_BASE,
- credentials=credentials,
- api_transport="rest",
-)
-
-
-def embed_text(
- texts: List[str] = ["banana muffins? ", "banana bread? banana muffins?"],
- task: str = "RETRIEVAL_DOCUMENT",
- model_name: str = "text-embedding-004",
- dimensionality: Optional[int] = 256,
-) -> List[List[float]]:
- """Embeds texts with a pre-trained, foundational model."""
- model = TextEmbeddingModel.from_pretrained(model_name)
- inputs = [TextEmbeddingInput(text, task) for text in texts]
- kwargs = dict(output_dimensionality=dimensionality) if dimensionality else {}
- embeddings = model.get_embeddings(inputs, **kwargs)
- return [embedding.values for embedding in embeddings]
-```
-
-
-
```shell
curl http://localhost:4000/vertex_ai/publishers/google/models/textembedding-gecko@001:predict \
@@ -411,133 +151,9 @@ curl http://localhost:4000/vertex_ai/publishers/google/models/textembedding-geck
-d '{"instances":[{"content": "gm"}]}'
```
-
-
-
### Imagen API
-
-
-
-
-
-```python
-from typing import List, Optional
-from vertexai.preview.vision_models import ImageGenerationModel
-import vertexai
-from google.auth.credentials import Credentials
-
-LITELLM_PROXY_API_KEY = "sk-1234"
-LITELLM_PROXY_BASE = "http://0.0.0.0:4000/vertex_ai"
-
-import datetime
-
-vertexai.init(
- project="adroit-crow-413218",
- location="us-central1",
- api_endpoint=LITELLM_PROXY_BASE,
- api_transport="rest",
-)
-
-model = ImageGenerationModel.from_pretrained("imagen-3.0-generate-001")
-
-images = model.generate_images(
- prompt=prompt,
- # Optional parameters
- number_of_images=1,
- language="en",
- # You can't use a seed value and watermark at the same time.
- # add_watermark=False,
- # seed=100,
- aspect_ratio="1:1",
- safety_filter_level="block_some",
- person_generation="allow_adult",
-)
-
-images[0].save(location=output_file, include_generation_parameters=False)
-
-# Optional. View the generated image in a notebook.
-# images[0].show()
-
-print(f"Created output image using {len(images[0]._image_bytes)} bytes")
-
-```
-
-
-
-
-```python
-from typing import List, Optional
-from vertexai.preview.vision_models import ImageGenerationModel
-import vertexai
-from google.auth.credentials import Credentials
-
-LITELLM_PROXY_API_KEY = "sk-1234"
-LITELLM_PROXY_BASE = "http://0.0.0.0:4000/vertex_ai"
-
-import datetime
-
-
-class CredentialsWrapper(Credentials):
- def __init__(self, token=None):
- super().__init__()
- self.token = token
- self.expiry = None # or set to a future date if needed
-
- def refresh(self, request):
- pass
-
- def apply(self, headers, token=None):
- headers["Authorization"] = f"Bearer {self.token}"
-
- @property
- def expired(self):
- return False # Always consider the token as non-expired
-
- @property
- def valid(self):
- return True # Always consider the credentials as valid
-
-
-credentials = CredentialsWrapper(token=LITELLM_PROXY_API_KEY)
-
-vertexai.init(
- project="adroit-crow-413218",
- location="us-central1",
- api_endpoint=LITELLM_PROXY_BASE,
- credentials=credentials,
- api_transport="rest",
-)
-
-model = ImageGenerationModel.from_pretrained("imagen-3.0-generate-001")
-
-images = model.generate_images(
- prompt=prompt,
- # Optional parameters
- number_of_images=1,
- language="en",
- # You can't use a seed value and watermark at the same time.
- # add_watermark=False,
- # seed=100,
- aspect_ratio="1:1",
- safety_filter_level="block_some",
- person_generation="allow_adult",
-)
-
-images[0].save(location=output_file, include_generation_parameters=False)
-
-# Optional. View the generated image in a notebook.
-# images[0].show()
-
-print(f"Created output image using {len(images[0]._image_bytes)} bytes")
-
-```
-
-
-
-
-
```shell
curl http://localhost:4000/vertex_ai/publishers/google/models/imagen-3.0-generate-001:predict \
-H "Content-Type: application/json" \
@@ -545,252 +161,19 @@ curl http://localhost:4000/vertex_ai/publishers/google/models/imagen-3.0-generat
-d '{"instances":[{"prompt": "make an otter"}], "parameters": {"sampleCount": 1}}'
```
-
-
-
### Count Tokens API
-
-
-
-
-
-
-```python
-from typing import List, Optional
-from vertexai.generative_models import GenerativeModel
-import vertexai
-
-LITELLM_PROXY_API_KEY = "sk-1234"
-LITELLM_PROXY_BASE = "http://0.0.0.0:4000/vertex_ai"
-
-import datetime
-
-vertexai.init(
- project="adroit-crow-413218",
- location="us-central1",
- api_endpoint=LITELLM_PROXY_BASE,
- api_transport="rest",
-)
-
-
-model = GenerativeModel("gemini-1.5-flash-001")
-
-prompt = "Why is the sky blue?"
-
-# Prompt tokens count
-response = model.count_tokens(prompt)
-print(f"Prompt Token Count: {response.total_tokens}")
-print(f"Prompt Character Count: {response.total_billable_characters}")
-
-# Send text to Gemini
-response = model.generate_content(prompt)
-
-# Response tokens count
-usage_metadata = response.usage_metadata
-print(f"Prompt Token Count: {usage_metadata.prompt_token_count}")
-print(f"Candidates Token Count: {usage_metadata.candidates_token_count}")
-print(f"Total Token Count: {usage_metadata.total_token_count}")
-```
-
-
-
-
-
-
-```python
-from typing import List, Optional
-from vertexai.generative_models import GenerativeModel
-import vertexai
-from google.auth.credentials import Credentials
-
-LITELLM_PROXY_API_KEY = "sk-1234"
-LITELLM_PROXY_BASE = "http://0.0.0.0:4000/vertex_ai"
-
-import datetime
-
-
-class CredentialsWrapper(Credentials):
- def __init__(self, token=None):
- super().__init__()
- self.token = token
- self.expiry = None # or set to a future date if needed
-
- def refresh(self, request):
- pass
-
- def apply(self, headers, token=None):
- headers["Authorization"] = f"Bearer {self.token}"
-
- @property
- def expired(self):
- return False # Always consider the token as non-expired
-
- @property
- def valid(self):
- return True # Always consider the credentials as valid
-
-
-credentials = CredentialsWrapper(token=LITELLM_PROXY_API_KEY)
-
-vertexai.init(
- project="adroit-crow-413218",
- location="us-central1",
- api_endpoint=LITELLM_PROXY_BASE,
- credentials=credentials,
- api_transport="rest",
-)
-
-
-model = GenerativeModel("gemini-1.5-flash-001")
-
-prompt = "Why is the sky blue?"
-
-# Prompt tokens count
-response = model.count_tokens(prompt)
-print(f"Prompt Token Count: {response.total_tokens}")
-print(f"Prompt Character Count: {response.total_billable_characters}")
-
-# Send text to Gemini
-response = model.generate_content(prompt)
-
-# Response tokens count
-usage_metadata = response.usage_metadata
-print(f"Prompt Token Count: {usage_metadata.prompt_token_count}")
-print(f"Candidates Token Count: {usage_metadata.candidates_token_count}")
-print(f"Total Token Count: {usage_metadata.total_token_count}")
-```
-
-
-
-
-
-
-
```shell
curl http://localhost:4000/vertex_ai/publishers/google/models/gemini-1.5-flash-001:countTokens \
-H "Content-Type: application/json" \
-H "x-litellm-api-key: Bearer sk-1234" \
-d '{"contents":[{"role": "user", "parts":[{"text": "hi"}]}]}'
```
-
-
-
-
### Tuning API
Create Fine Tuning Job
-
-
-
-
-```python
-from typing import List, Optional
-from vertexai.preview.tuning import sft
-import vertexai
-
-LITELLM_PROXY_API_KEY = "sk-1234"
-LITELLM_PROXY_BASE = "http://0.0.0.0:4000/vertex_ai"
-
-
-vertexai.init(
- project="adroit-crow-413218",
- location="us-central1",
- api_endpoint=LITELLM_PROXY_BASE,
- api_transport="rest",
-)
-
-
-# TODO(developer): Update project
-vertexai.init(project=PROJECT_ID, location="us-central1")
-
-sft_tuning_job = sft.train(
- source_model="gemini-1.0-pro-002",
- train_dataset="gs://cloud-samples-data/ai-platform/generative_ai/sft_train_data.jsonl",
-)
-
-# Polling for job completion
-while not sft_tuning_job.has_ended:
- time.sleep(60)
- sft_tuning_job.refresh()
-
-print(sft_tuning_job.tuned_model_name)
-print(sft_tuning_job.tuned_model_endpoint_name)
-print(sft_tuning_job.experiment)
-
-```
-
-
-
-
-
-```python
-from typing import List, Optional
-from vertexai.preview.tuning import sft
-import vertexai
-from google.auth.credentials import Credentials
-
-LITELLM_PROXY_API_KEY = "sk-1234"
-LITELLM_PROXY_BASE = "http://0.0.0.0:4000/vertex_ai"
-
-import datetime
-
-
-class CredentialsWrapper(Credentials):
- def __init__(self, token=None):
- super().__init__()
- self.token = token
- self.expiry = None # or set to a future date if needed
-
- def refresh(self, request):
- pass
-
- def apply(self, headers, token=None):
- headers["Authorization"] = f"Bearer {self.token}"
-
- @property
- def expired(self):
- return False # Always consider the token as non-expired
-
- @property
- def valid(self):
- return True # Always consider the credentials as valid
-
-
-credentials = CredentialsWrapper(token=LITELLM_PROXY_API_KEY)
-
-vertexai.init(
- project="adroit-crow-413218",
- location="us-central1",
- api_endpoint=LITELLM_PROXY_BASE,
- credentials=credentials,
- api_transport="rest",
-)
-
-
-# TODO(developer): Update project
-vertexai.init(project=PROJECT_ID, location="us-central1")
-
-sft_tuning_job = sft.train(
- source_model="gemini-1.0-pro-002",
- train_dataset="gs://cloud-samples-data/ai-platform/generative_ai/sft_train_data.jsonl",
-)
-
-# Polling for job completion
-while not sft_tuning_job.has_ended:
- time.sleep(60)
- sft_tuning_job.refresh()
-
-print(sft_tuning_job.tuned_model_name)
-print(sft_tuning_job.tuned_model_endpoint_name)
-print(sft_tuning_job.experiment)
-```
-
-
-
-
```shell
curl http://localhost:4000/vertex_ai/tuningJobs \
@@ -804,118 +187,6 @@ curl http://localhost:4000/vertex_ai/tuningJobs \
}'
```
-
-
-
-
-
-### Context Caching
-
-Use Vertex AI Context Caching
-
-[**Relevant VertexAI Docs**](https://cloud.google.com/vertex-ai/generative-ai/docs/context-cache/context-cache-overview)
-
-
-
-
-
-1. Add model to config.yaml
-```yaml
-model_list:
- # used for /chat/completions, /completions, /embeddings endpoints
- - model_name: gemini-1.5-pro-001
- litellm_params:
- model: vertex_ai/gemini-1.5-pro-001
- vertex_project: "project-id"
- vertex_location: "us-central1"
- vertex_credentials: "adroit-crow-413218-a956eef1a2a8.json" # Add path to service account.json
-
-# used for the /cachedContent and vertexAI native endpoints
-default_vertex_config:
- vertex_project: "adroit-crow-413218"
- vertex_location: "us-central1"
- vertex_credentials: "adroit-crow-413218-a956eef1a2a8.json" # Add path to service account.json
-
-```
-
-2. Start Proxy
-
-```
-$ litellm --config /path/to/config.yaml
-```
-
-3. Make Request!
-We make the request in two steps:
-- Create a cachedContents object
-- Use the cachedContents object in your /chat/completions
-
-**Create a cachedContents object**
-
-First, create a cachedContents object by calling the Vertex `cachedContents` endpoint. The LiteLLM proxy forwards the `/cachedContents` request to the VertexAI API.
-
-```python
-import httpx
-
-# Set Litellm proxy variables
-LITELLM_BASE_URL = "http://0.0.0.0:4000"
-LITELLM_PROXY_API_KEY = "sk-1234"
-
-httpx_client = httpx.Client(timeout=30)
-
-print("Creating cached content")
-create_cache = httpx_client.post(
- url=f"{LITELLM_BASE_URL}/vertex_ai/cachedContents",
- headers={"x-litellm-api-key": f"Bearer {LITELLM_PROXY_API_KEY}"},
- json={
- "model": "gemini-1.5-pro-001",
- "contents": [
- {
- "role": "user",
- "parts": [{
- "text": "This is sample text to demonstrate explicit caching." * 4000
- }]
- }
- ],
- }
-)
-
-print("Response from create_cache:", create_cache)
-create_cache_response = create_cache.json()
-print("JSON from create_cache:", create_cache_response)
-cached_content_name = create_cache_response["name"]
-```
-
-**Use the cachedContents object in your /chat/completions request to VertexAI**
-
-```python
-import openai
-
-# Set Litellm proxy variables
-LITELLM_BASE_URL = "http://0.0.0.0:4000"
-LITELLM_PROXY_API_KEY = "sk-1234"
-
-client = openai.OpenAI(api_key=LITELLM_PROXY_API_KEY, base_url=LITELLM_BASE_URL)
-
-response = client.chat.completions.create(
- model="gemini-1.5-pro-001",
- max_tokens=8192,
- messages=[
- {
- "role": "user",
- "content": "What is the sample text about?",
- },
- ],
- temperature=0.7,
- extra_body={"cached_content": cached_content_name}, # Use the cached content
-)
-
-print("Response from proxy:", response)
-```
-
-
-
-
-
## Advanced
Pre-requisites
@@ -930,6 +201,11 @@ Use this, to avoid giving developers the raw Anthropic API key, but still lettin
```bash
export DATABASE_URL=""
export LITELLM_MASTER_KEY=""
+
+# vertex ai credentials
+export DEFAULT_VERTEXAI_PROJECT="" # "adroit-crow-413218"
+export DEFAULT_VERTEXAI_LOCATION="" # "us-central1"
+export DEFAULT_GOOGLE_APPLICATION_CREDENTIALS="" # "/Users/Downloads/adroit-crow-413218-a956eef1a2a8.json"
```
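+
+These `DEFAULT_VERTEXAI_*` variables are picked up by the proxy whenever `default_vertex_config` is not set in the `config.yaml`.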
```bash
diff --git a/litellm/proxy/vertex_ai_endpoints/vertex_endpoints.py b/litellm/proxy/vertex_ai_endpoints/vertex_endpoints.py
index 1a0d09a88..271e8992c 100644
--- a/litellm/proxy/vertex_ai_endpoints/vertex_endpoints.py
+++ b/litellm/proxy/vertex_ai_endpoints/vertex_endpoints.py
@@ -28,25 +28,54 @@ from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
create_pass_through_route,
)
+from litellm.secret_managers.main import get_secret_str
+from litellm.types.passthrough_endpoints.vertex_ai import *
router = APIRouter()
-default_vertex_config = None
+
+default_vertex_config: VertexPassThroughCredentials = VertexPassThroughCredentials()
-def set_default_vertex_config(config):
+def _get_vertex_env_vars() -> VertexPassThroughCredentials:
+ """
+ Helper to get vertex pass through config from environment variables
+
+ The following environment variables are used:
+ - DEFAULT_VERTEXAI_PROJECT (project id)
+ - DEFAULT_VERTEXAI_LOCATION (location)
+ - DEFAULT_GOOGLE_APPLICATION_CREDENTIALS (path to credentials file)
+ """
+ return VertexPassThroughCredentials(
+ vertex_project=get_secret_str("DEFAULT_VERTEXAI_PROJECT"),
+ vertex_location=get_secret_str("DEFAULT_VERTEXAI_LOCATION"),
+ vertex_credentials=get_secret_str("DEFAULT_GOOGLE_APPLICATION_CREDENTIALS"),
+ )
+
+
+def set_default_vertex_config(config: Optional[dict] = None):
+ """Sets vertex configuration from provided config and/or environment variables
+
+ Args:
+ config (Optional[dict]): Configuration dictionary
+ Example: {
+ "vertex_project": "my-project-123",
+ "vertex_location": "us-central1",
+ "vertex_credentials": "os.environ/GOOGLE_CREDS"
+ }
+ """
global default_vertex_config
- if config is None:
- return
- if not isinstance(config, dict):
- raise ValueError("invalid config, vertex default config must be a dictionary")
+    # No config provided; fall back to environment variables
+ if config is None:
+ default_vertex_config = _get_vertex_env_vars()
+ return
if isinstance(config, dict):
for key, value in config.items():
if isinstance(value, str) and value.startswith("os.environ/"):
config[key] = litellm.get_secret(value)
- default_vertex_config = config
+ default_vertex_config = VertexPassThroughCredentials(**config)
def exception_handler(e: Exception):
@@ -140,7 +169,7 @@ async def vertex_proxy_route(
vertex_project = None
vertex_location = None
# Use headers from the incoming request if default_vertex_config is not set
- if default_vertex_config is None:
+ if default_vertex_config.vertex_project is None:
headers = dict(request.headers) or {}
verbose_proxy_logger.debug(
"default_vertex_config not set, incoming request headers %s", headers
@@ -153,9 +182,9 @@ async def vertex_proxy_route(
headers.pop("content-length", None)
headers.pop("host", None)
else:
- vertex_project = default_vertex_config.get("vertex_project")
- vertex_location = default_vertex_config.get("vertex_location")
- vertex_credentials = default_vertex_config.get("vertex_credentials")
+ vertex_project = default_vertex_config.vertex_project
+ vertex_location = default_vertex_config.vertex_location
+ vertex_credentials = default_vertex_config.vertex_credentials
base_target_url = f"https://{vertex_location}-aiplatform.googleapis.com/"
diff --git a/litellm/types/passthrough_endpoints/vertex_ai.py b/litellm/types/passthrough_endpoints/vertex_ai.py
new file mode 100644
index 000000000..3933aadcd
--- /dev/null
+++ b/litellm/types/passthrough_endpoints/vertex_ai.py
@@ -0,0 +1,18 @@
+"""
+Used for /vertex_ai/ pass through endpoints
+"""
+
+from typing import Optional
+
+from pydantic import BaseModel
+
+
+class VertexPassThroughCredentials(BaseModel):
+ # Example: vertex_project = "my-project-123"
+ vertex_project: Optional[str] = None
+
+ # Example: vertex_location = "us-central1"
+ vertex_location: Optional[str] = None
+
+ # Example: vertex_credentials = "/path/to/credentials.json" or "os.environ/GOOGLE_CREDS"
+ vertex_credentials: Optional[str] = None
diff --git a/tests/pass_through_unit_tests/test_unit_test_vertex_pass_through.py b/tests/pass_through_unit_tests/test_unit_test_vertex_pass_through.py
index a7b668813..4c66f6993 100644
--- a/tests/pass_through_unit_tests/test_unit_test_vertex_pass_through.py
+++ b/tests/pass_through_unit_tests/test_unit_test_vertex_pass_through.py
@@ -18,6 +18,10 @@ from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import (
get_litellm_virtual_key,
vertex_proxy_route,
+ _get_vertex_env_vars,
+ set_default_vertex_config,
+ VertexPassThroughCredentials,
+ default_vertex_config,
)
@@ -82,3 +86,84 @@ async def test_vertex_proxy_route_api_key_auth():
mock_auth.assert_called_once()
call_args = mock_auth.call_args[1]
assert call_args["api_key"] == "Bearer test-key-123"
+
+
+@pytest.mark.asyncio
+async def test_get_vertex_env_vars():
+ """Test that _get_vertex_env_vars correctly reads environment variables"""
+ # Set environment variables for the test
+ os.environ["DEFAULT_VERTEXAI_PROJECT"] = "test-project-123"
+ os.environ["DEFAULT_VERTEXAI_LOCATION"] = "us-central1"
+ os.environ["DEFAULT_GOOGLE_APPLICATION_CREDENTIALS"] = "/path/to/creds"
+
+ try:
+ result = _get_vertex_env_vars()
+ print(result)
+
+ # Verify the result
+ assert isinstance(result, VertexPassThroughCredentials)
+ assert result.vertex_project == "test-project-123"
+ assert result.vertex_location == "us-central1"
+ assert result.vertex_credentials == "/path/to/creds"
+
+ finally:
+ # Clean up environment variables
+ del os.environ["DEFAULT_VERTEXAI_PROJECT"]
+ del os.environ["DEFAULT_VERTEXAI_LOCATION"]
+ del os.environ["DEFAULT_GOOGLE_APPLICATION_CREDENTIALS"]
+
+
+@pytest.mark.asyncio
+async def test_set_default_vertex_config():
+ """Test set_default_vertex_config with various inputs"""
+ # Test with None config - set environment variables first
+ os.environ["DEFAULT_VERTEXAI_PROJECT"] = "env-project"
+ os.environ["DEFAULT_VERTEXAI_LOCATION"] = "env-location"
+ os.environ["DEFAULT_GOOGLE_APPLICATION_CREDENTIALS"] = "env-creds"
+ os.environ["GOOGLE_CREDS"] = "secret-creds"
+
+ try:
+ # Test with None config
+ set_default_vertex_config()
+ from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import (
+ default_vertex_config,
+ )
+
+ assert default_vertex_config.vertex_project == "env-project"
+ assert default_vertex_config.vertex_location == "env-location"
+ assert default_vertex_config.vertex_credentials == "env-creds"
+
+ # Test with valid config.yaml settings on vertex_config
+ test_config = {
+ "vertex_project": "my-project-123",
+ "vertex_location": "us-central1",
+ "vertex_credentials": "path/to/creds",
+ }
+ set_default_vertex_config(test_config)
+ from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import (
+ default_vertex_config,
+ )
+
+ assert default_vertex_config.vertex_project == "my-project-123"
+ assert default_vertex_config.vertex_location == "us-central1"
+ assert default_vertex_config.vertex_credentials == "path/to/creds"
+
+ # Test with environment variable reference
+ test_config = {
+ "vertex_project": "my-project-123",
+ "vertex_location": "us-central1",
+ "vertex_credentials": "os.environ/GOOGLE_CREDS",
+ }
+ set_default_vertex_config(test_config)
+ from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import (
+ default_vertex_config,
+ )
+
+ assert default_vertex_config.vertex_credentials == "secret-creds"
+
+ finally:
+ # Clean up environment variables
+ del os.environ["DEFAULT_VERTEXAI_PROJECT"]
+ del os.environ["DEFAULT_VERTEXAI_LOCATION"]
+ del os.environ["DEFAULT_GOOGLE_APPLICATION_CREDENTIALS"]
+ del os.environ["GOOGLE_CREDS"]