Compare commits

...
Sign in to create a new pull request.

3 commits

Author SHA1 Message Date
Ishaan Jaff
9a1c2f091c add unit testing for vtx pass through auth 2024-11-25 23:55:56 -08:00
Ishaan Jaff
7dfc25f894 allow using known path for setting up pass throughs 2024-11-25 23:39:37 -08:00
Ishaan Jaff
c61c429c44 simplify vertex pass through docs 2024-11-25 23:23:39 -08:00
4 changed files with 186 additions and 778 deletions

View file

@ -69,6 +69,44 @@ generateContent();
</Tabs>
## Quick Start
Let's call the Vertex AI [`/generateContent` endpoint](https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference)
1. Add Vertex AI Credentials to your environment
```bash
export DEFAULT_VERTEXAI_PROJECT="" # "adroit-crow-413218"
export DEFAULT_VERTEXAI_LOCATION="" # "us-central1"
export DEFAULT_GOOGLE_APPLICATION_CREDENTIALS="" # "/Users/Downloads/adroit-crow-413218-a956eef1a2a8.json"
```
2. Start LiteLLM Proxy
```bash
litellm
# RUNNING on http://0.0.0.0:4000
```
3. Test it!
Let's call the Vertex AI `generateContent` endpoint
```bash
curl http://localhost:4000/vertex-ai/publishers/google/models/gemini-1.0-pro:generateContent \
-H "Content-Type: application/json" \
-H "Authorization: Bearer sk-1234" \
-d '{
"contents":[{
"role": "user",
"parts":[{"text": "How are you doing today?"}]
}]
}'
```
## Supported API Endpoints
- Gemini API
@ -87,206 +125,12 @@ LiteLLM Proxy Server supports two methods of authentication to Vertex AI:
2. Set Vertex AI credentials on proxy server
## Quick Start Usage
<Tabs>
<TabItem value="without_default_config" label="Pass Vertex Credentials client side to proxy server">
#### 1. Start litellm proxy
```shell
litellm --config /path/to/config.yaml
```
#### 2. Test it
```python
import vertexai
from vertexai.preview.generative_models import GenerativeModel
LITE_LLM_ENDPOINT = "http://localhost:4000"
vertexai.init(
project="<your-vertex_ai-project-id>", # enter your project id
location="<your-vertex_ai-location>", # enter your region
api_endpoint=f"{LITE_LLM_ENDPOINT}/vertex_ai", # route on litellm
api_transport="rest",
)
model = GenerativeModel(model_name="gemini-1.0-pro")
model.generate_content("hi")
```
</TabItem>
<TabItem value="with_default_config" label="Set Vertex AI Credentials on Proxy Server">
#### 1. Set `default_vertex_config` on your `config.yaml`
Add the following credentials to your litellm config.yaml to use the Vertex AI endpoints.
```yaml
default_vertex_config:
vertex_project: "adroit-crow-413218"
vertex_location: "us-central1"
vertex_credentials: "/Users/ishaanjaffer/Downloads/adroit-crow-413218-a956eef1a2a8.json" # Add path to service account.json
```
#### 2. Start litellm proxy
```shell
litellm --config /path/to/config.yaml
```
#### 3. Test it
```python
import vertexai
from google.auth.credentials import Credentials
from vertexai.generative_models import GenerativeModel
LITELLM_PROXY_API_KEY = "sk-1234"
LITELLM_PROXY_BASE = "http://0.0.0.0:4000/vertex_ai"
import datetime
class CredentialsWrapper(Credentials):
def __init__(self, token=None):
super().__init__()
self.token = token
self.expiry = None # or set to a future date if needed
def refresh(self, request):
pass
def apply(self, headers, token=None):
headers["Authorization"] = f"Bearer {self.token}"
@property
def expired(self):
return False # Always consider the token as non-expired
@property
def valid(self):
return True # Always consider the credentials as valid
credentials = CredentialsWrapper(token=LITELLM_PROXY_API_KEY)
vertexai.init(
project="adroit-crow-413218",
location="us-central1",
api_endpoint=LITELLM_PROXY_BASE,
credentials=credentials,
api_transport="rest",
)
model = GenerativeModel("gemini-1.5-flash-001")
response = model.generate_content(
"What's a good name for a flower shop that specializes in selling bouquets of dried flowers?"
)
print(response.text)
```
</TabItem>
</Tabs>
## Usage Examples
### Gemini API (Generate Content)
<Tabs>
<TabItem value="client_side" label="Vertex Python SDK (client side vertex credentials)">
```python
import vertexai
from vertexai.generative_models import GenerativeModel
LITELLM_PROXY_API_KEY = "sk-1234"
LITELLM_PROXY_BASE = "http://0.0.0.0:4000/vertex_ai"
vertexai.init(
project="adroit-crow-413218",
location="us-central1",
api_endpoint=LITELLM_PROXY_BASE,
api_transport="rest",
)
model = GenerativeModel("gemini-1.5-flash-001")
response = model.generate_content(
"What's a good name for a flower shop that specializes in selling bouquets of dried flowers?"
)
print(response.text)
```
</TabItem>
<TabItem value="py" label="Vertex Python SDK (litellm virtual keys client side)">
```python
import vertexai
from google.auth.credentials import Credentials
from vertexai.generative_models import GenerativeModel
LITELLM_PROXY_API_KEY = "sk-1234"
LITELLM_PROXY_BASE = "http://0.0.0.0:4000/vertex_ai"
import datetime
class CredentialsWrapper(Credentials):
def __init__(self, token=None):
super().__init__()
self.token = token
self.expiry = None # or set to a future date if needed
def refresh(self, request):
pass
def apply(self, headers, token=None):
headers["Authorization"] = f"Bearer {self.token}"
@property
def expired(self):
return False # Always consider the token as non-expired
@property
def valid(self):
return True # Always consider the credentials as valid
credentials = CredentialsWrapper(token=LITELLM_PROXY_API_KEY)
vertexai.init(
project="adroit-crow-413218",
location="us-central1",
api_endpoint=LITELLM_PROXY_BASE,
credentials=credentials,
api_transport="rest",
)
model = GenerativeModel("gemini-1.5-flash-001")
response = model.generate_content(
"What's a good name for a flower shop that specializes in selling bouquets of dried flowers?"
)
print(response.text)
```
</TabItem>
<TabItem value="Curl" label="Curl">
```shell
curl http://localhost:4000/vertex_ai/publishers/google/models/gemini-1.5-flash-001:generateContent \
@ -295,114 +139,10 @@ curl http://localhost:4000/vertex_ai/publishers/google/models/gemini-1.5-flash-0
-d '{"contents":[{"role": "user", "parts":[{"text": "hi"}]}]}'
```
</TabItem>
</Tabs>
### Embeddings API
<Tabs>
<TabItem value="client_side" label="Vertex Python SDK (client side vertex credentials)">
```python
from typing import List, Optional
from vertexai.language_models import TextEmbeddingInput, TextEmbeddingModel
import vertexai
from vertexai.generative_models import GenerativeModel
LITELLM_PROXY_API_KEY = "sk-1234"
LITELLM_PROXY_BASE = "http://0.0.0.0:4000/vertex_ai"
import datetime
vertexai.init(
project="adroit-crow-413218",
location="us-central1",
api_endpoint=LITELLM_PROXY_BASE,
api_transport="rest",
)
def embed_text(
texts: List[str] = ["banana muffins? ", "banana bread? banana muffins?"],
task: str = "RETRIEVAL_DOCUMENT",
model_name: str = "text-embedding-004",
dimensionality: Optional[int] = 256,
) -> List[List[float]]:
"""Embeds texts with a pre-trained, foundational model."""
model = TextEmbeddingModel.from_pretrained(model_name)
inputs = [TextEmbeddingInput(text, task) for text in texts]
kwargs = dict(output_dimensionality=dimensionality) if dimensionality else {}
embeddings = model.get_embeddings(inputs, **kwargs)
return [embedding.values for embedding in embeddings]
```
</TabItem>
<TabItem value="py" label="Vertex Python SDK (litellm virtual keys client side)">
```python
from typing import List, Optional
from vertexai.language_models import TextEmbeddingInput, TextEmbeddingModel
import vertexai
from google.auth.credentials import Credentials
from vertexai.generative_models import GenerativeModel
LITELLM_PROXY_API_KEY = "sk-1234"
LITELLM_PROXY_BASE = "http://0.0.0.0:4000/vertex_ai"
import datetime
class CredentialsWrapper(Credentials):
def __init__(self, token=None):
super().__init__()
self.token = token
self.expiry = None # or set to a future date if needed
def refresh(self, request):
pass
def apply(self, headers, token=None):
headers["Authorization"] = f"Bearer {self.token}"
@property
def expired(self):
return False # Always consider the token as non-expired
@property
def valid(self):
return True # Always consider the credentials as valid
credentials = CredentialsWrapper(token=LITELLM_PROXY_API_KEY)
vertexai.init(
project="adroit-crow-413218",
location="us-central1",
api_endpoint=LITELLM_PROXY_BASE,
credentials=credentials,
api_transport="rest",
)
def embed_text(
texts: List[str] = ["banana muffins? ", "banana bread? banana muffins?"],
task: str = "RETRIEVAL_DOCUMENT",
model_name: str = "text-embedding-004",
dimensionality: Optional[int] = 256,
) -> List[List[float]]:
"""Embeds texts with a pre-trained, foundational model."""
model = TextEmbeddingModel.from_pretrained(model_name)
inputs = [TextEmbeddingInput(text, task) for text in texts]
kwargs = dict(output_dimensionality=dimensionality) if dimensionality else {}
embeddings = model.get_embeddings(inputs, **kwargs)
return [embedding.values for embedding in embeddings]
```
</TabItem>
<TabItem value="curl" label="Curl">
```shell
curl http://localhost:4000/vertex_ai/publishers/google/models/textembedding-gecko@001:predict \
@ -411,133 +151,9 @@ curl http://localhost:4000/vertex_ai/publishers/google/models/textembedding-geck
-d '{"instances":[{"content": "gm"}]}'
```
</TabItem>
</Tabs>
### Imagen API
<Tabs>
<TabItem value="client_side" label="Vertex Python SDK (client side vertex credentials)">
```python
from typing import List, Optional
from vertexai.preview.vision_models import ImageGenerationModel
import vertexai
from google.auth.credentials import Credentials
LITELLM_PROXY_API_KEY = "sk-1234"
LITELLM_PROXY_BASE = "http://0.0.0.0:4000/vertex_ai"
import datetime
vertexai.init(
project="adroit-crow-413218",
location="us-central1",
api_endpoint=LITELLM_PROXY_BASE,
api_transport="rest",
)
model = ImageGenerationModel.from_pretrained("imagen-3.0-generate-001")
images = model.generate_images(
prompt=prompt,
# Optional parameters
number_of_images=1,
language="en",
# You can't use a seed value and watermark at the same time.
# add_watermark=False,
# seed=100,
aspect_ratio="1:1",
safety_filter_level="block_some",
person_generation="allow_adult",
)
images[0].save(location=output_file, include_generation_parameters=False)
# Optional. View the generated image in a notebook.
# images[0].show()
print(f"Created output image using {len(images[0]._image_bytes)} bytes")
```
</TabItem>
<TabItem value="py" label="Vertex Python SDK (litellm virtual keys client side)">
```python
from typing import List, Optional
from vertexai.preview.vision_models import ImageGenerationModel
import vertexai
from google.auth.credentials import Credentials
LITELLM_PROXY_API_KEY = "sk-1234"
LITELLM_PROXY_BASE = "http://0.0.0.0:4000/vertex_ai"
import datetime
class CredentialsWrapper(Credentials):
def __init__(self, token=None):
super().__init__()
self.token = token
self.expiry = None # or set to a future date if needed
def refresh(self, request):
pass
def apply(self, headers, token=None):
headers["Authorization"] = f"Bearer {self.token}"
@property
def expired(self):
return False # Always consider the token as non-expired
@property
def valid(self):
return True # Always consider the credentials as valid
credentials = CredentialsWrapper(token=LITELLM_PROXY_API_KEY)
vertexai.init(
project="adroit-crow-413218",
location="us-central1",
api_endpoint=LITELLM_PROXY_BASE,
credentials=credentials,
api_transport="rest",
)
model = ImageGenerationModel.from_pretrained("imagen-3.0-generate-001")
images = model.generate_images(
prompt=prompt,
# Optional parameters
number_of_images=1,
language="en",
# You can't use a seed value and watermark at the same time.
# add_watermark=False,
# seed=100,
aspect_ratio="1:1",
safety_filter_level="block_some",
person_generation="allow_adult",
)
images[0].save(location=output_file, include_generation_parameters=False)
# Optional. View the generated image in a notebook.
# images[0].show()
print(f"Created output image using {len(images[0]._image_bytes)} bytes")
```
</TabItem>
<TabItem value="curl" label="Curl">
```shell
curl http://localhost:4000/vertex_ai/publishers/google/models/imagen-3.0-generate-001:predict \
-H "Content-Type: application/json" \
@ -545,252 +161,19 @@ curl http://localhost:4000/vertex_ai/publishers/google/models/imagen-3.0-generat
-d '{"instances":[{"prompt": "make an otter"}], "parameters": {"sampleCount": 1}}'
```
</TabItem>
</Tabs>
### Count Tokens API
<Tabs>
<TabItem value="client_side" label="Vertex Python SDK (client side vertex credentials)">
```python
from typing import List, Optional
from vertexai.generative_models import GenerativeModel
import vertexai
LITELLM_PROXY_API_KEY = "sk-1234"
LITELLM_PROXY_BASE = "http://0.0.0.0:4000/vertex_ai"
import datetime
vertexai.init(
project="adroit-crow-413218",
location="us-central1",
api_endpoint=LITELLM_PROXY_BASE,
api_transport="rest",
)
model = GenerativeModel("gemini-1.5-flash-001")
prompt = "Why is the sky blue?"
# Prompt tokens count
response = model.count_tokens(prompt)
print(f"Prompt Token Count: {response.total_tokens}")
print(f"Prompt Character Count: {response.total_billable_characters}")
# Send text to Gemini
response = model.generate_content(prompt)
# Response tokens count
usage_metadata = response.usage_metadata
print(f"Prompt Token Count: {usage_metadata.prompt_token_count}")
print(f"Candidates Token Count: {usage_metadata.candidates_token_count}")
print(f"Total Token Count: {usage_metadata.total_token_count}")
```
</TabItem>
<TabItem value="py" label="Vertex Python SDK (litellm virtual keys client side)">
```python
from typing import List, Optional
from vertexai.generative_models import GenerativeModel
import vertexai
from google.auth.credentials import Credentials
LITELLM_PROXY_API_KEY = "sk-1234"
LITELLM_PROXY_BASE = "http://0.0.0.0:4000/vertex_ai"
import datetime
class CredentialsWrapper(Credentials):
def __init__(self, token=None):
super().__init__()
self.token = token
self.expiry = None # or set to a future date if needed
def refresh(self, request):
pass
def apply(self, headers, token=None):
headers["Authorization"] = f"Bearer {self.token}"
@property
def expired(self):
return False # Always consider the token as non-expired
@property
def valid(self):
return True # Always consider the credentials as valid
credentials = CredentialsWrapper(token=LITELLM_PROXY_API_KEY)
vertexai.init(
project="adroit-crow-413218",
location="us-central1",
api_endpoint=LITELLM_PROXY_BASE,
credentials=credentials,
api_transport="rest",
)
model = GenerativeModel("gemini-1.5-flash-001")
prompt = "Why is the sky blue?"
# Prompt tokens count
response = model.count_tokens(prompt)
print(f"Prompt Token Count: {response.total_tokens}")
print(f"Prompt Character Count: {response.total_billable_characters}")
# Send text to Gemini
response = model.generate_content(prompt)
# Response tokens count
usage_metadata = response.usage_metadata
print(f"Prompt Token Count: {usage_metadata.prompt_token_count}")
print(f"Candidates Token Count: {usage_metadata.candidates_token_count}")
print(f"Total Token Count: {usage_metadata.total_token_count}")
```
</TabItem>
<TabItem value="curl" label="Curl">
```shell
curl http://localhost:4000/vertex_ai/publishers/google/models/gemini-1.5-flash-001:countTokens \
-H "Content-Type: application/json" \
-H "x-litellm-api-key: Bearer sk-1234" \
-d '{"contents":[{"role": "user", "parts":[{"text": "hi"}]}]}'
```
</TabItem>
</Tabs>
### Tuning API
Create Fine Tuning Job
<Tabs>
<TabItem value="client_side" label="Vertex Python SDK (client side vertex credentials)">
```python
from typing import List, Optional
from vertexai.preview.tuning import sft
import vertexai
LITELLM_PROXY_API_KEY = "sk-1234"
LITELLM_PROXY_BASE = "http://0.0.0.0:4000/vertex_ai"
vertexai.init(
project="adroit-crow-413218",
location="us-central1",
api_endpoint=LITELLM_PROXY_BASE,
api_transport="rest",
)
# TODO(developer): Update project
vertexai.init(project=PROJECT_ID, location="us-central1")
sft_tuning_job = sft.train(
source_model="gemini-1.0-pro-002",
train_dataset="gs://cloud-samples-data/ai-platform/generative_ai/sft_train_data.jsonl",
)
# Polling for job completion
while not sft_tuning_job.has_ended:
time.sleep(60)
sft_tuning_job.refresh()
print(sft_tuning_job.tuned_model_name)
print(sft_tuning_job.tuned_model_endpoint_name)
print(sft_tuning_job.experiment)
```
</TabItem>
<TabItem value="py" label="Vertex Python SDK (litellm virtual keys client side)">
```python
from typing import List, Optional
from vertexai.preview.tuning import sft
import vertexai
from google.auth.credentials import Credentials
LITELLM_PROXY_API_KEY = "sk-1234"
LITELLM_PROXY_BASE = "http://0.0.0.0:4000/vertex_ai"
import datetime
class CredentialsWrapper(Credentials):
def __init__(self, token=None):
super().__init__()
self.token = token
self.expiry = None # or set to a future date if needed
def refresh(self, request):
pass
def apply(self, headers, token=None):
headers["Authorization"] = f"Bearer {self.token}"
@property
def expired(self):
return False # Always consider the token as non-expired
@property
def valid(self):
return True # Always consider the credentials as valid
credentials = CredentialsWrapper(token=LITELLM_PROXY_API_KEY)
vertexai.init(
project="adroit-crow-413218",
location="us-central1",
api_endpoint=LITELLM_PROXY_BASE,
credentials=credentials,
api_transport="rest",
)
# TODO(developer): Update project
vertexai.init(project=PROJECT_ID, location="us-central1")
sft_tuning_job = sft.train(
source_model="gemini-1.0-pro-002",
train_dataset="gs://cloud-samples-data/ai-platform/generative_ai/sft_train_data.jsonl",
)
# Polling for job completion
while not sft_tuning_job.has_ended:
time.sleep(60)
sft_tuning_job.refresh()
print(sft_tuning_job.tuned_model_name)
print(sft_tuning_job.tuned_model_endpoint_name)
print(sft_tuning_job.experiment)
```
</TabItem>
<TabItem value="curl" label="Curl">
```shell
curl http://localhost:4000/vertex_ai/tuningJobs \
@ -804,118 +187,6 @@ curl http://localhost:4000/vertex_ai/tuningJobs \
}'
```
</TabItem>
</Tabs>
### Context Caching
Use Vertex AI Context Caching
[**Relevant VertexAI Docs**](https://cloud.google.com/vertex-ai/generative-ai/docs/context-cache/context-cache-overview)
<Tabs>
<TabItem value="proxy" label="LiteLLM PROXY">
1. Add model to config.yaml
```yaml
model_list:
# used for /chat/completions, /completions, /embeddings endpoints
- model_name: gemini-1.5-pro-001
litellm_params:
model: vertex_ai/gemini-1.5-pro-001
vertex_project: "project-id"
vertex_location: "us-central1"
vertex_credentials: "adroit-crow-413218-a956eef1a2a8.json" # Add path to service account.json
# used for the /cachedContent and vertexAI native endpoints
default_vertex_config:
vertex_project: "adroit-crow-413218"
vertex_location: "us-central1"
vertex_credentials: "adroit-crow-413218-a956eef1a2a8.json" # Add path to service account.json
```
2. Start Proxy
```
$ litellm --config /path/to/config.yaml
```
3. Make Request!
We make the request in two steps:
- Create a cachedContents object
- Use the cachedContents object in your /chat/completions
**Create a cachedContents object**
First, create a cachedContents object by calling the Vertex `cachedContents` endpoint. The LiteLLM proxy forwards the `/cachedContents` request to the VertexAI API.
```python
import httpx
# Set Litellm proxy variables
LITELLM_BASE_URL = "http://0.0.0.0:4000"
LITELLM_PROXY_API_KEY = "sk-1234"
httpx_client = httpx.Client(timeout=30)
print("Creating cached content")
create_cache = httpx_client.post(
url=f"{LITELLM_BASE_URL}/vertex_ai/cachedContents",
headers={"x-litellm-api-key": f"Bearer {LITELLM_PROXY_API_KEY}"},
json={
"model": "gemini-1.5-pro-001",
"contents": [
{
"role": "user",
"parts": [{
"text": "This is sample text to demonstrate explicit caching." * 4000
}]
}
],
}
)
print("Response from create_cache:", create_cache)
create_cache_response = create_cache.json()
print("JSON from create_cache:", create_cache_response)
cached_content_name = create_cache_response["name"]
```
**Use the cachedContents object in your /chat/completions request to VertexAI**
```python
import openai
# Set Litellm proxy variables
LITELLM_BASE_URL = "http://0.0.0.0:4000"
LITELLM_PROXY_API_KEY = "sk-1234"
client = openai.OpenAI(api_key=LITELLM_PROXY_API_KEY, base_url=LITELLM_BASE_URL)
response = client.chat.completions.create(
model="gemini-1.5-pro-001",
max_tokens=8192,
messages=[
{
"role": "user",
"content": "What is the sample text about?",
},
],
temperature=0.7,
extra_body={"cached_content": cached_content_name}, # Use the cached content
)
print("Response from proxy:", response)
```
</TabItem>
</Tabs>
## Advanced
Pre-requisites
@ -930,6 +201,11 @@ Use this, to avoid giving developers the raw Anthropic API key, but still lettin
```bash
export DATABASE_URL=""
export LITELLM_MASTER_KEY=""
# vertex ai credentials
export DEFAULT_VERTEXAI_PROJECT="" # "adroit-crow-413218"
export DEFAULT_VERTEXAI_LOCATION="" # "us-central1"
export DEFAULT_GOOGLE_APPLICATION_CREDENTIALS="" # "/Users/Downloads/adroit-crow-413218-a956eef1a2a8.json"
```
```bash

View file

@ -28,25 +28,54 @@ from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
create_pass_through_route,
)
from litellm.secret_managers.main import get_secret_str
from litellm.types.passthrough_endpoints.vertex_ai import *
router = APIRouter()
default_vertex_config = None
default_vertex_config: VertexPassThroughCredentials = VertexPassThroughCredentials()
def set_default_vertex_config(config):
def _get_vertex_env_vars() -> VertexPassThroughCredentials:
    """
    Build the vertex pass through credentials from environment variables.

    Each value is looked up with ``get_secret_str``:
    - DEFAULT_VERTEXAI_PROJECT -> vertex_project (project id)
    - DEFAULT_VERTEXAI_LOCATION -> vertex_location (location)
    - DEFAULT_GOOGLE_APPLICATION_CREDENTIALS -> vertex_credentials
      (path to credentials file)
    """
    # Lookups happen in the same order as the fields they populate.
    project = get_secret_str("DEFAULT_VERTEXAI_PROJECT")
    location = get_secret_str("DEFAULT_VERTEXAI_LOCATION")
    credentials = get_secret_str("DEFAULT_GOOGLE_APPLICATION_CREDENTIALS")

    return VertexPassThroughCredentials(
        vertex_project=project,
        vertex_location=location,
        vertex_credentials=credentials,
    )
def set_default_vertex_config(config: Optional[dict] = None):
"""Sets vertex configuration from provided config and/or environment variables
Args:
config (Optional[dict]): Configuration dictionary
Example: {
"vertex_project": "my-project-123",
"vertex_location": "us-central1",
"vertex_credentials": "os.environ/GOOGLE_CREDS"
}
"""
global default_vertex_config
if config is None:
return
if not isinstance(config, dict):
raise ValueError("invalid config, vertex default config must be a dictionary")
# Initialize config dictionary if None
if config is None:
default_vertex_config = _get_vertex_env_vars()
return
if isinstance(config, dict):
for key, value in config.items():
if isinstance(value, str) and value.startswith("os.environ/"):
config[key] = litellm.get_secret(value)
default_vertex_config = config
default_vertex_config = VertexPassThroughCredentials(**config)
def exception_handler(e: Exception):
@ -140,7 +169,7 @@ async def vertex_proxy_route(
vertex_project = None
vertex_location = None
# Use headers from the incoming request if default_vertex_config is not set
if default_vertex_config is None:
if default_vertex_config.vertex_project is None:
headers = dict(request.headers) or {}
verbose_proxy_logger.debug(
"default_vertex_config not set, incoming request headers %s", headers
@ -153,9 +182,9 @@ async def vertex_proxy_route(
headers.pop("content-length", None)
headers.pop("host", None)
else:
vertex_project = default_vertex_config.get("vertex_project")
vertex_location = default_vertex_config.get("vertex_location")
vertex_credentials = default_vertex_config.get("vertex_credentials")
vertex_project = default_vertex_config.vertex_project
vertex_location = default_vertex_config.vertex_location
vertex_credentials = default_vertex_config.vertex_credentials
base_target_url = f"https://{vertex_location}-aiplatform.googleapis.com/"

View file

@ -0,0 +1,18 @@
"""
Used for /vertex_ai/ pass through endpoints
"""
from typing import Optional
from pydantic import BaseModel
class VertexPassThroughCredentials(BaseModel):
    """Credentials used for the /vertex_ai/ pass through endpoints.

    Every field is optional and defaults to ``None``.
    """

    # Project id, e.g. "my-project-123"
    vertex_project: Optional[str] = None

    # Location, e.g. "us-central1"
    vertex_location: Optional[str] = None

    # Credentials reference, e.g. "/path/to/credentials.json"
    # or "os.environ/GOOGLE_CREDS"
    vertex_credentials: Optional[str] = None

View file

@ -18,6 +18,10 @@ from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import (
get_litellm_virtual_key,
vertex_proxy_route,
_get_vertex_env_vars,
set_default_vertex_config,
VertexPassThroughCredentials,
default_vertex_config,
)
@ -82,3 +86,84 @@ async def test_vertex_proxy_route_api_key_auth():
mock_auth.assert_called_once()
call_args = mock_auth.call_args[1]
assert call_args["api_key"] == "Bearer test-key-123"
@pytest.mark.asyncio
async def test_get_vertex_env_vars():
    """Test that _get_vertex_env_vars correctly reads environment variables.

    The variables are injected with ``patch.dict`` so that ``os.environ`` is
    returned to its previous state when the block exits. Deleting the keys by
    hand would discard any values that were already set before the test ran.
    """
    # Local import keeps this test self-contained regardless of what the
    # module imports at the top of the file.
    from unittest.mock import patch

    env_vars = {
        "DEFAULT_VERTEXAI_PROJECT": "test-project-123",
        "DEFAULT_VERTEXAI_LOCATION": "us-central1",
        "DEFAULT_GOOGLE_APPLICATION_CREDENTIALS": "/path/to/creds",
    }

    with patch.dict(os.environ, env_vars):
        result = _get_vertex_env_vars()

    # Verify the result
    assert isinstance(result, VertexPassThroughCredentials)
    assert result.vertex_project == "test-project-123"
    assert result.vertex_location == "us-central1"
    assert result.vertex_credentials == "/path/to/creds"
@pytest.mark.asyncio
async def test_set_default_vertex_config():
    """Test set_default_vertex_config with various inputs.

    Covers three cases:
    1. No config passed -> values come from the DEFAULT_* environment variables.
    2. A plain config dict -> values are taken from the dict as-is.
    3. A config value of the form "os.environ/<NAME>" -> resolved to the
       value of that environment variable.

    ``set_default_vertex_config`` rebinds a module-level global, so the
    current value is read from the module object after each call (a name
    imported with ``from ... import`` would keep pointing at the old object).
    The global and ``os.environ`` are both restored afterwards so this test
    does not leak state into other tests.
    """
    from unittest.mock import patch

    from litellm.proxy.vertex_ai_endpoints import vertex_endpoints

    env_vars = {
        "DEFAULT_VERTEXAI_PROJECT": "env-project",
        "DEFAULT_VERTEXAI_LOCATION": "env-location",
        "DEFAULT_GOOGLE_APPLICATION_CREDENTIALS": "env-creds",
        "GOOGLE_CREDS": "secret-creds",
    }

    original_config = vertex_endpoints.default_vertex_config
    try:
        with patch.dict(os.environ, env_vars):
            # Test with None config - falls back to environment variables
            set_default_vertex_config()
            current = vertex_endpoints.default_vertex_config
            assert current.vertex_project == "env-project"
            assert current.vertex_location == "env-location"
            assert current.vertex_credentials == "env-creds"

            # Test with valid config.yaml settings on vertex_config
            test_config = {
                "vertex_project": "my-project-123",
                "vertex_location": "us-central1",
                "vertex_credentials": "path/to/creds",
            }
            set_default_vertex_config(test_config)
            current = vertex_endpoints.default_vertex_config
            assert current.vertex_project == "my-project-123"
            assert current.vertex_location == "us-central1"
            assert current.vertex_credentials == "path/to/creds"

            # Test with environment variable reference
            test_config = {
                "vertex_project": "my-project-123",
                "vertex_location": "us-central1",
                "vertex_credentials": "os.environ/GOOGLE_CREDS",
            }
            set_default_vertex_config(test_config)
            current = vertex_endpoints.default_vertex_config
            assert current.vertex_credentials == "secret-creds"
    finally:
        # Put the module-level config back so later tests see the state they
        # would have seen had this test not run.
        vertex_endpoints.default_vertex_config = original_config