From c285132ad6a7b7699169289bad3cc391eb09b993 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Mon, 25 Nov 2024 23:57:50 -0800 Subject: [PATCH] (docs) Simplify `/vertex_ai/` pass through docs (#6910) * simplify vertex pass through docs * allow using known path for setting up pass throughs * add unit testing for vtx pass through auth --- .../my-website/docs/pass_through/vertex_ai.md | 810 +----------------- .../vertex_ai_endpoints/vertex_endpoints.py | 51 +- .../types/passthrough_endpoints/vertex_ai.py | 18 + .../test_unit_test_vertex_pass_through.py | 85 ++ 4 files changed, 186 insertions(+), 778 deletions(-) create mode 100644 litellm/types/passthrough_endpoints/vertex_ai.py diff --git a/docs/my-website/docs/pass_through/vertex_ai.md b/docs/my-website/docs/pass_through/vertex_ai.md index 744c5e3ff..601f89f4b 100644 --- a/docs/my-website/docs/pass_through/vertex_ai.md +++ b/docs/my-website/docs/pass_through/vertex_ai.md @@ -69,6 +69,44 @@ generateContent(); +## Quick Start + +Let's call the Vertex AI [`/generateContent` endpoint](https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference) + +1. Add Vertex AI Credentials to your environment + +```bash +export DEFAULT_VERTEXAI_PROJECT="" # "adroit-crow-413218" +export DEFAULT_VERTEXAI_LOCATION="" # "us-central1" +export DEFAULT_GOOGLE_APPLICATION_CREDENTIALS="" # "/Users/Downloads/adroit-crow-413218-a956eef1a2a8.json" +``` + +2. Start LiteLLM Proxy + +```bash +litellm + +# RUNNING on http://0.0.0.0:4000 +``` + +3. Test it! 
+ +Let's call the Vertex AI `generateContent` endpoint + +```bash +curl http://localhost:4000/vertex-ai/publishers/google/models/gemini-1.0-pro:generateContent \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer sk-1234" \ + -d '{ + "contents":[{ + "role": "user", + "parts":[{"text": "How are you doing today?"}] + }] + }' +``` + + + ## Supported API Endpoints - Gemini API @@ -87,206 +125,12 @@ LiteLLM Proxy Server supports two methods of authentication to Vertex AI: 2. Set Vertex AI credentials on proxy server -## Quick Start Usage - - - - - -#### 1. Start litellm proxy - -```shell -litellm --config /path/to/config.yaml -``` - -#### 2. Test it - -```python -import vertexai -from vertexai.preview.generative_models import GenerativeModel - -LITE_LLM_ENDPOINT = "http://localhost:4000" - -vertexai.init( - project="", # enter your project id - location="", # enter your region - api_endpoint=f"{LITE_LLM_ENDPOINT}/vertex_ai", # route on litellm - api_transport="rest", -) - -model = GenerativeModel(model_name="gemini-1.0-pro") -model.generate_content("hi") - -``` - - - - - - -#### 1. Set `default_vertex_config` on your `config.yaml` - - -Add the following credentials to your litellm config.yaml to use the Vertex AI endpoints. - -```yaml -default_vertex_config: - vertex_project: "adroit-crow-413218" - vertex_location: "us-central1" - vertex_credentials: "/Users/ishaanjaffer/Downloads/adroit-crow-413218-a956eef1a2a8.json" # Add path to service account.json -``` - -#### 2. Start litellm proxy - -```shell -litellm --config /path/to/config.yaml -``` - -#### 3. 
Test it - -```python -import vertexai -from google.auth.credentials import Credentials -from vertexai.generative_models import GenerativeModel - -LITELLM_PROXY_API_KEY = "sk-1234" -LITELLM_PROXY_BASE = "http://0.0.0.0:4000/vertex_ai" - -import datetime - - -class CredentialsWrapper(Credentials): - def __init__(self, token=None): - super().__init__() - self.token = token - self.expiry = None # or set to a future date if needed - - def refresh(self, request): - pass - - def apply(self, headers, token=None): - headers["Authorization"] = f"Bearer {self.token}" - - @property - def expired(self): - return False # Always consider the token as non-expired - - @property - def valid(self): - return True # Always consider the credentials as valid - - -credentials = CredentialsWrapper(token=LITELLM_PROXY_API_KEY) - -vertexai.init( - project="adroit-crow-413218", - location="us-central1", - api_endpoint=LITELLM_PROXY_BASE, - credentials=credentials, - api_transport="rest", -) - -model = GenerativeModel("gemini-1.5-flash-001") - -response = model.generate_content( - "What's a good name for a flower shop that specializes in selling bouquets of dried flowers?" -) - -print(response.text) -``` - - - - ## Usage Examples ### Gemini API (Generate Content) - - -```python -import vertexai -from vertexai.generative_models import GenerativeModel - -LITELLM_PROXY_API_KEY = "sk-1234" -LITELLM_PROXY_BASE = "http://0.0.0.0:4000/vertex_ai" - -vertexai.init( - project="adroit-crow-413218", - location="us-central1", - api_endpoint=LITELLM_PROXY_BASE, - api_transport="rest", - -) - -model = GenerativeModel("gemini-1.5-flash-001") - -response = model.generate_content( - "What's a good name for a flower shop that specializes in selling bouquets of dried flowers?" 
-) - -print(response.text) -``` - - - - -```python -import vertexai -from google.auth.credentials import Credentials -from vertexai.generative_models import GenerativeModel - -LITELLM_PROXY_API_KEY = "sk-1234" -LITELLM_PROXY_BASE = "http://0.0.0.0:4000/vertex_ai" - -import datetime - - -class CredentialsWrapper(Credentials): - def __init__(self, token=None): - super().__init__() - self.token = token - self.expiry = None # or set to a future date if needed - - def refresh(self, request): - pass - - def apply(self, headers, token=None): - headers["Authorization"] = f"Bearer {self.token}" - - @property - def expired(self): - return False # Always consider the token as non-expired - - @property - def valid(self): - return True # Always consider the credentials as valid - - -credentials = CredentialsWrapper(token=LITELLM_PROXY_API_KEY) - -vertexai.init( - project="adroit-crow-413218", - location="us-central1", - api_endpoint=LITELLM_PROXY_BASE, - credentials=credentials, - api_transport="rest", - -) - -model = GenerativeModel("gemini-1.5-flash-001") - -response = model.generate_content( - "What's a good name for a flower shop that specializes in selling bouquets of dried flowers?" 
-) - -print(response.text) -``` - - - ```shell curl http://localhost:4000/vertex_ai/publishers/google/models/gemini-1.5-flash-001:generateContent \ @@ -295,114 +139,10 @@ curl http://localhost:4000/vertex_ai/publishers/google/models/gemini-1.5-flash-0 -d '{"contents":[{"role": "user", "parts":[{"text": "hi"}]}]}' ``` - - ### Embeddings API - - - - -```python -from typing import List, Optional -from vertexai.language_models import TextEmbeddingInput, TextEmbeddingModel -import vertexai -from vertexai.generative_models import GenerativeModel - -LITELLM_PROXY_API_KEY = "sk-1234" -LITELLM_PROXY_BASE = "http://0.0.0.0:4000/vertex_ai" - -import datetime - -vertexai.init( - project="adroit-crow-413218", - location="us-central1", - api_endpoint=LITELLM_PROXY_BASE, - api_transport="rest", -) - - -def embed_text( - texts: List[str] = ["banana muffins? ", "banana bread? banana muffins?"], - task: str = "RETRIEVAL_DOCUMENT", - model_name: str = "text-embedding-004", - dimensionality: Optional[int] = 256, -) -> List[List[float]]: - """Embeds texts with a pre-trained, foundational model.""" - model = TextEmbeddingModel.from_pretrained(model_name) - inputs = [TextEmbeddingInput(text, task) for text in texts] - kwargs = dict(output_dimensionality=dimensionality) if dimensionality else {} - embeddings = model.get_embeddings(inputs, **kwargs) - return [embedding.values for embedding in embeddings] -``` - - - - - -```python -from typing import List, Optional -from vertexai.language_models import TextEmbeddingInput, TextEmbeddingModel -import vertexai -from google.auth.credentials import Credentials -from vertexai.generative_models import GenerativeModel - -LITELLM_PROXY_API_KEY = "sk-1234" -LITELLM_PROXY_BASE = "http://0.0.0.0:4000/vertex_ai" - -import datetime - - -class CredentialsWrapper(Credentials): - def __init__(self, token=None): - super().__init__() - self.token = token - self.expiry = None # or set to a future date if needed - - def refresh(self, request): - pass - - def 
apply(self, headers, token=None): - headers["Authorization"] = f"Bearer {self.token}" - - @property - def expired(self): - return False # Always consider the token as non-expired - - @property - def valid(self): - return True # Always consider the credentials as valid - - -credentials = CredentialsWrapper(token=LITELLM_PROXY_API_KEY) - -vertexai.init( - project="adroit-crow-413218", - location="us-central1", - api_endpoint=LITELLM_PROXY_BASE, - credentials=credentials, - api_transport="rest", -) - - -def embed_text( - texts: List[str] = ["banana muffins? ", "banana bread? banana muffins?"], - task: str = "RETRIEVAL_DOCUMENT", - model_name: str = "text-embedding-004", - dimensionality: Optional[int] = 256, -) -> List[List[float]]: - """Embeds texts with a pre-trained, foundational model.""" - model = TextEmbeddingModel.from_pretrained(model_name) - inputs = [TextEmbeddingInput(text, task) for text in texts] - kwargs = dict(output_dimensionality=dimensionality) if dimensionality else {} - embeddings = model.get_embeddings(inputs, **kwargs) - return [embedding.values for embedding in embeddings] -``` - - - ```shell curl http://localhost:4000/vertex_ai/publishers/google/models/textembedding-gecko@001:predict \ @@ -411,133 +151,9 @@ curl http://localhost:4000/vertex_ai/publishers/google/models/textembedding-geck -d '{"instances":[{"content": "gm"}]}' ``` - - - ### Imagen API - - - - - -```python -from typing import List, Optional -from vertexai.preview.vision_models import ImageGenerationModel -import vertexai -from google.auth.credentials import Credentials - -LITELLM_PROXY_API_KEY = "sk-1234" -LITELLM_PROXY_BASE = "http://0.0.0.0:4000/vertex_ai" - -import datetime - -vertexai.init( - project="adroit-crow-413218", - location="us-central1", - api_endpoint=LITELLM_PROXY_BASE, - api_transport="rest", -) - -model = ImageGenerationModel.from_pretrained("imagen-3.0-generate-001") - -images = model.generate_images( - prompt=prompt, - # Optional parameters - 
number_of_images=1, - language="en", - # You can't use a seed value and watermark at the same time. - # add_watermark=False, - # seed=100, - aspect_ratio="1:1", - safety_filter_level="block_some", - person_generation="allow_adult", -) - -images[0].save(location=output_file, include_generation_parameters=False) - -# Optional. View the generated image in a notebook. -# images[0].show() - -print(f"Created output image using {len(images[0]._image_bytes)} bytes") - -``` - - - - -```python -from typing import List, Optional -from vertexai.preview.vision_models import ImageGenerationModel -import vertexai -from google.auth.credentials import Credentials - -LITELLM_PROXY_API_KEY = "sk-1234" -LITELLM_PROXY_BASE = "http://0.0.0.0:4000/vertex_ai" - -import datetime - - -class CredentialsWrapper(Credentials): - def __init__(self, token=None): - super().__init__() - self.token = token - self.expiry = None # or set to a future date if needed - - def refresh(self, request): - pass - - def apply(self, headers, token=None): - headers["Authorization"] = f"Bearer {self.token}" - - @property - def expired(self): - return False # Always consider the token as non-expired - - @property - def valid(self): - return True # Always consider the credentials as valid - - -credentials = CredentialsWrapper(token=LITELLM_PROXY_API_KEY) - -vertexai.init( - project="adroit-crow-413218", - location="us-central1", - api_endpoint=LITELLM_PROXY_BASE, - credentials=credentials, - api_transport="rest", -) - -model = ImageGenerationModel.from_pretrained("imagen-3.0-generate-001") - -images = model.generate_images( - prompt=prompt, - # Optional parameters - number_of_images=1, - language="en", - # You can't use a seed value and watermark at the same time. - # add_watermark=False, - # seed=100, - aspect_ratio="1:1", - safety_filter_level="block_some", - person_generation="allow_adult", -) - -images[0].save(location=output_file, include_generation_parameters=False) - -# Optional. 
View the generated image in a notebook. -# images[0].show() - -print(f"Created output image using {len(images[0]._image_bytes)} bytes") - -``` - - - - - ```shell curl http://localhost:4000/vertex_ai/publishers/google/models/imagen-3.0-generate-001:predict \ -H "Content-Type: application/json" \ @@ -545,252 +161,19 @@ curl http://localhost:4000/vertex_ai/publishers/google/models/imagen-3.0-generat -d '{"instances":[{"prompt": "make an otter"}], "parameters": {"sampleCount": 1}}' ``` - - - ### Count Tokens API - - - - - - -```python -from typing import List, Optional -from vertexai.generative_models import GenerativeModel -import vertexai - -LITELLM_PROXY_API_KEY = "sk-1234" -LITELLM_PROXY_BASE = "http://0.0.0.0:4000/vertex_ai" - -import datetime - -vertexai.init( - project="adroit-crow-413218", - location="us-central1", - api_endpoint=LITELLM_PROXY_BASE, - api_transport="rest", -) - - -model = GenerativeModel("gemini-1.5-flash-001") - -prompt = "Why is the sky blue?" - -# Prompt tokens count -response = model.count_tokens(prompt) -print(f"Prompt Token Count: {response.total_tokens}") -print(f"Prompt Character Count: {response.total_billable_characters}") - -# Send text to Gemini -response = model.generate_content(prompt) - -# Response tokens count -usage_metadata = response.usage_metadata -print(f"Prompt Token Count: {usage_metadata.prompt_token_count}") -print(f"Candidates Token Count: {usage_metadata.candidates_token_count}") -print(f"Total Token Count: {usage_metadata.total_token_count}") -``` - - - - - - -```python -from typing import List, Optional -from vertexai.generative_models import GenerativeModel -import vertexai -from google.auth.credentials import Credentials - -LITELLM_PROXY_API_KEY = "sk-1234" -LITELLM_PROXY_BASE = "http://0.0.0.0:4000/vertex_ai" - -import datetime - - -class CredentialsWrapper(Credentials): - def __init__(self, token=None): - super().__init__() - self.token = token - self.expiry = None # or set to a future date if needed - - def 
refresh(self, request): - pass - - def apply(self, headers, token=None): - headers["Authorization"] = f"Bearer {self.token}" - - @property - def expired(self): - return False # Always consider the token as non-expired - - @property - def valid(self): - return True # Always consider the credentials as valid - - -credentials = CredentialsWrapper(token=LITELLM_PROXY_API_KEY) - -vertexai.init( - project="adroit-crow-413218", - location="us-central1", - api_endpoint=LITELLM_PROXY_BASE, - credentials=credentials, - api_transport="rest", -) - - -model = GenerativeModel("gemini-1.5-flash-001") - -prompt = "Why is the sky blue?" - -# Prompt tokens count -response = model.count_tokens(prompt) -print(f"Prompt Token Count: {response.total_tokens}") -print(f"Prompt Character Count: {response.total_billable_characters}") - -# Send text to Gemini -response = model.generate_content(prompt) - -# Response tokens count -usage_metadata = response.usage_metadata -print(f"Prompt Token Count: {usage_metadata.prompt_token_count}") -print(f"Candidates Token Count: {usage_metadata.candidates_token_count}") -print(f"Total Token Count: {usage_metadata.total_token_count}") -``` - - - - - - - ```shell curl http://localhost:4000/vertex_ai/publishers/google/models/gemini-1.5-flash-001:countTokens \ -H "Content-Type: application/json" \ -H "x-litellm-api-key: Bearer sk-1234" \ -d '{"contents":[{"role": "user", "parts":[{"text": "hi"}]}]}' ``` - - - - ### Tuning API Create Fine Tuning Job - - - - -```python -from typing import List, Optional -from vertexai.preview.tuning import sft -import vertexai - -LITELLM_PROXY_API_KEY = "sk-1234" -LITELLM_PROXY_BASE = "http://0.0.0.0:4000/vertex_ai" - - -vertexai.init( - project="adroit-crow-413218", - location="us-central1", - api_endpoint=LITELLM_PROXY_BASE, - api_transport="rest", -) - - -# TODO(developer): Update project -vertexai.init(project=PROJECT_ID, location="us-central1") - -sft_tuning_job = sft.train( - source_model="gemini-1.0-pro-002", - 
train_dataset="gs://cloud-samples-data/ai-platform/generative_ai/sft_train_data.jsonl", -) - -# Polling for job completion -while not sft_tuning_job.has_ended: - time.sleep(60) - sft_tuning_job.refresh() - -print(sft_tuning_job.tuned_model_name) -print(sft_tuning_job.tuned_model_endpoint_name) -print(sft_tuning_job.experiment) - -``` - - - - - -```python -from typing import List, Optional -from vertexai.preview.tuning import sft -import vertexai -from google.auth.credentials import Credentials - -LITELLM_PROXY_API_KEY = "sk-1234" -LITELLM_PROXY_BASE = "http://0.0.0.0:4000/vertex_ai" - -import datetime - - -class CredentialsWrapper(Credentials): - def __init__(self, token=None): - super().__init__() - self.token = token - self.expiry = None # or set to a future date if needed - - def refresh(self, request): - pass - - def apply(self, headers, token=None): - headers["Authorization"] = f"Bearer {self.token}" - - @property - def expired(self): - return False # Always consider the token as non-expired - - @property - def valid(self): - return True # Always consider the credentials as valid - - -credentials = CredentialsWrapper(token=LITELLM_PROXY_API_KEY) - -vertexai.init( - project="adroit-crow-413218", - location="us-central1", - api_endpoint=LITELLM_PROXY_BASE, - credentials=credentials, - api_transport="rest", -) - - -# TODO(developer): Update project -vertexai.init(project=PROJECT_ID, location="us-central1") - -sft_tuning_job = sft.train( - source_model="gemini-1.0-pro-002", - train_dataset="gs://cloud-samples-data/ai-platform/generative_ai/sft_train_data.jsonl", -) - -# Polling for job completion -while not sft_tuning_job.has_ended: - time.sleep(60) - sft_tuning_job.refresh() - -print(sft_tuning_job.tuned_model_name) -print(sft_tuning_job.tuned_model_endpoint_name) -print(sft_tuning_job.experiment) -``` - - - - ```shell curl http://localhost:4000/vertex_ai/tuningJobs \ @@ -804,118 +187,6 @@ curl http://localhost:4000/vertex_ai/tuningJobs \ }' ``` - - - - - -### 
Context Caching - -Use Vertex AI Context Caching - -[**Relevant VertexAI Docs**](https://cloud.google.com/vertex-ai/generative-ai/docs/context-cache/context-cache-overview) - - - - - -1. Add model to config.yaml -```yaml -model_list: - # used for /chat/completions, /completions, /embeddings endpoints - - model_name: gemini-1.5-pro-001 - litellm_params: - model: vertex_ai/gemini-1.5-pro-001 - vertex_project: "project-id" - vertex_location: "us-central1" - vertex_credentials: "adroit-crow-413218-a956eef1a2a8.json" # Add path to service account.json - -# used for the /cachedContent and vertexAI native endpoints -default_vertex_config: - vertex_project: "adroit-crow-413218" - vertex_location: "us-central1" - vertex_credentials: "adroit-crow-413218-a956eef1a2a8.json" # Add path to service account.json - -``` - -2. Start Proxy - -``` -$ litellm --config /path/to/config.yaml -``` - -3. Make Request! -We make the request in two steps: -- Create a cachedContents object -- Use the cachedContents object in your /chat/completions - -**Create a cachedContents object** - -First, create a cachedContents object by calling the Vertex `cachedContents` endpoint. The LiteLLM proxy forwards the `/cachedContents` request to the VertexAI API. - -```python -import httpx - -# Set Litellm proxy variables -LITELLM_BASE_URL = "http://0.0.0.0:4000" -LITELLM_PROXY_API_KEY = "sk-1234" - -httpx_client = httpx.Client(timeout=30) - -print("Creating cached content") -create_cache = httpx_client.post( - url=f"{LITELLM_BASE_URL}/vertex_ai/cachedContents", - headers={"x-litellm-api-key": f"Bearer {LITELLM_PROXY_API_KEY}"}, - json={ - "model": "gemini-1.5-pro-001", - "contents": [ - { - "role": "user", - "parts": [{ - "text": "This is sample text to demonstrate explicit caching." 
* 4000 - }] - } - ], - } -) - -print("Response from create_cache:", create_cache) -create_cache_response = create_cache.json() -print("JSON from create_cache:", create_cache_response) -cached_content_name = create_cache_response["name"] -``` - -**Use the cachedContents object in your /chat/completions request to VertexAI** - -```python -import openai - -# Set Litellm proxy variables -LITELLM_BASE_URL = "http://0.0.0.0:4000" -LITELLM_PROXY_API_KEY = "sk-1234" - -client = openai.OpenAI(api_key=LITELLM_PROXY_API_KEY, base_url=LITELLM_BASE_URL) - -response = client.chat.completions.create( - model="gemini-1.5-pro-001", - max_tokens=8192, - messages=[ - { - "role": "user", - "content": "What is the sample text about?", - }, - ], - temperature=0.7, - extra_body={"cached_content": cached_content_name}, # Use the cached content -) - -print("Response from proxy:", response) -``` - - - - - ## Advanced Pre-requisites @@ -930,6 +201,11 @@ Use this, to avoid giving developers the raw Anthropic API key, but still lettin ```bash export DATABASE_URL="" export LITELLM_MASTER_KEY="" + +# vertex ai credentials +export DEFAULT_VERTEXAI_PROJECT="" # "adroit-crow-413218" +export DEFAULT_VERTEXAI_LOCATION="" # "us-central1" +export DEFAULT_GOOGLE_APPLICATION_CREDENTIALS="" # "/Users/Downloads/adroit-crow-413218-a956eef1a2a8.json" ``` ```bash diff --git a/litellm/proxy/vertex_ai_endpoints/vertex_endpoints.py b/litellm/proxy/vertex_ai_endpoints/vertex_endpoints.py index 1a0d09a88..271e8992c 100644 --- a/litellm/proxy/vertex_ai_endpoints/vertex_endpoints.py +++ b/litellm/proxy/vertex_ai_endpoints/vertex_endpoints.py @@ -28,25 +28,54 @@ from litellm.proxy.auth.user_api_key_auth import user_api_key_auth from litellm.proxy.pass_through_endpoints.pass_through_endpoints import ( create_pass_through_route, ) +from litellm.secret_managers.main import get_secret_str +from litellm.types.passthrough_endpoints.vertex_ai import * router = APIRouter() -default_vertex_config = None + 
+default_vertex_config: VertexPassThroughCredentials = VertexPassThroughCredentials() -def set_default_vertex_config(config): +def _get_vertex_env_vars() -> VertexPassThroughCredentials: + """ + Helper to get vertex pass through config from environment variables + + The following environment variables are used: + - DEFAULT_VERTEXAI_PROJECT (project id) + - DEFAULT_VERTEXAI_LOCATION (location) + - DEFAULT_GOOGLE_APPLICATION_CREDENTIALS (path to credentials file) + """ + return VertexPassThroughCredentials( + vertex_project=get_secret_str("DEFAULT_VERTEXAI_PROJECT"), + vertex_location=get_secret_str("DEFAULT_VERTEXAI_LOCATION"), + vertex_credentials=get_secret_str("DEFAULT_GOOGLE_APPLICATION_CREDENTIALS"), + ) + + +def set_default_vertex_config(config: Optional[dict] = None): + """Sets vertex configuration from provided config and/or environment variables + + Args: + config (Optional[dict]): Configuration dictionary + Example: { + "vertex_project": "my-project-123", + "vertex_location": "us-central1", + "vertex_credentials": "os.environ/GOOGLE_CREDS" + } + """ global default_vertex_config - if config is None: - return - if not isinstance(config, dict): - raise ValueError("invalid config, vertex default config must be a dictionary") + # Initialize config dictionary if None + if config is None: + default_vertex_config = _get_vertex_env_vars() + return if isinstance(config, dict): for key, value in config.items(): if isinstance(value, str) and value.startswith("os.environ/"): config[key] = litellm.get_secret(value) - default_vertex_config = config + default_vertex_config = VertexPassThroughCredentials(**config) def exception_handler(e: Exception): @@ -140,7 +169,7 @@ async def vertex_proxy_route( vertex_project = None vertex_location = None # Use headers from the incoming request if default_vertex_config is not set - if default_vertex_config is None: + if default_vertex_config.vertex_project is None: headers = dict(request.headers) or {} verbose_proxy_logger.debug( 
"default_vertex_config not set, incoming request headers %s", headers @@ -153,9 +182,9 @@ async def vertex_proxy_route( headers.pop("content-length", None) headers.pop("host", None) else: - vertex_project = default_vertex_config.get("vertex_project") - vertex_location = default_vertex_config.get("vertex_location") - vertex_credentials = default_vertex_config.get("vertex_credentials") + vertex_project = default_vertex_config.vertex_project + vertex_location = default_vertex_config.vertex_location + vertex_credentials = default_vertex_config.vertex_credentials base_target_url = f"https://{vertex_location}-aiplatform.googleapis.com/" diff --git a/litellm/types/passthrough_endpoints/vertex_ai.py b/litellm/types/passthrough_endpoints/vertex_ai.py new file mode 100644 index 000000000..3933aadcd --- /dev/null +++ b/litellm/types/passthrough_endpoints/vertex_ai.py @@ -0,0 +1,18 @@ +""" +Used for /vertex_ai/ pass through endpoints +""" + +from typing import Optional + +from pydantic import BaseModel + + +class VertexPassThroughCredentials(BaseModel): + # Example: vertex_project = "my-project-123" + vertex_project: Optional[str] = None + + # Example: vertex_location = "us-central1" + vertex_location: Optional[str] = None + + # Example: vertex_credentials = "/path/to/credentials.json" or "os.environ/GOOGLE_CREDS" + vertex_credentials: Optional[str] = None diff --git a/tests/pass_through_unit_tests/test_unit_test_vertex_pass_through.py b/tests/pass_through_unit_tests/test_unit_test_vertex_pass_through.py index a7b668813..4c66f6993 100644 --- a/tests/pass_through_unit_tests/test_unit_test_vertex_pass_through.py +++ b/tests/pass_through_unit_tests/test_unit_test_vertex_pass_through.py @@ -18,6 +18,10 @@ from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLogging from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import ( get_litellm_virtual_key, vertex_proxy_route, + _get_vertex_env_vars, + set_default_vertex_config, + VertexPassThroughCredentials, + 
default_vertex_config, ) @@ -82,3 +86,84 @@ async def test_vertex_proxy_route_api_key_auth(): mock_auth.assert_called_once() call_args = mock_auth.call_args[1] assert call_args["api_key"] == "Bearer test-key-123" + + +@pytest.mark.asyncio +async def test_get_vertex_env_vars(): + """Test that _get_vertex_env_vars correctly reads environment variables""" + # Set environment variables for the test + os.environ["DEFAULT_VERTEXAI_PROJECT"] = "test-project-123" + os.environ["DEFAULT_VERTEXAI_LOCATION"] = "us-central1" + os.environ["DEFAULT_GOOGLE_APPLICATION_CREDENTIALS"] = "/path/to/creds" + + try: + result = _get_vertex_env_vars() + print(result) + + # Verify the result + assert isinstance(result, VertexPassThroughCredentials) + assert result.vertex_project == "test-project-123" + assert result.vertex_location == "us-central1" + assert result.vertex_credentials == "/path/to/creds" + + finally: + # Clean up environment variables + del os.environ["DEFAULT_VERTEXAI_PROJECT"] + del os.environ["DEFAULT_VERTEXAI_LOCATION"] + del os.environ["DEFAULT_GOOGLE_APPLICATION_CREDENTIALS"] + + +@pytest.mark.asyncio +async def test_set_default_vertex_config(): + """Test set_default_vertex_config with various inputs""" + # Test with None config - set environment variables first + os.environ["DEFAULT_VERTEXAI_PROJECT"] = "env-project" + os.environ["DEFAULT_VERTEXAI_LOCATION"] = "env-location" + os.environ["DEFAULT_GOOGLE_APPLICATION_CREDENTIALS"] = "env-creds" + os.environ["GOOGLE_CREDS"] = "secret-creds" + + try: + # Test with None config + set_default_vertex_config() + from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import ( + default_vertex_config, + ) + + assert default_vertex_config.vertex_project == "env-project" + assert default_vertex_config.vertex_location == "env-location" + assert default_vertex_config.vertex_credentials == "env-creds" + + # Test with valid config.yaml settings on vertex_config + test_config = { + "vertex_project": "my-project-123", + 
"vertex_location": "us-central1", + "vertex_credentials": "path/to/creds", + } + set_default_vertex_config(test_config) + from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import ( + default_vertex_config, + ) + + assert default_vertex_config.vertex_project == "my-project-123" + assert default_vertex_config.vertex_location == "us-central1" + assert default_vertex_config.vertex_credentials == "path/to/creds" + + # Test with environment variable reference + test_config = { + "vertex_project": "my-project-123", + "vertex_location": "us-central1", + "vertex_credentials": "os.environ/GOOGLE_CREDS", + } + set_default_vertex_config(test_config) + from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import ( + default_vertex_config, + ) + + assert default_vertex_config.vertex_credentials == "secret-creds" + + finally: + # Clean up environment variables + del os.environ["DEFAULT_VERTEXAI_PROJECT"] + del os.environ["DEFAULT_VERTEXAI_LOCATION"] + del os.environ["DEFAULT_GOOGLE_APPLICATION_CREDENTIALS"] + del os.environ["GOOGLE_CREDS"]