Support litellm.api_base for vertex_ai + gemini/ across completion, embedding, image_generation (#9516)
All checks were successful
Read Version from pyproject.toml / read-version (push) Successful in 19s
Helm unit test / unit-test (push) Successful in 20s

* test(tests): add unit testing for litellm_proxy integration

* fix(cost_calculator.py): fix tracking cost in sdk when calling proxy

* fix(main.py): respect litellm.api_base on `vertex_ai/` and `gemini/` routes

* fix(main.py): consistently support custom api base across gemini + vertexai on embedding + completion

* feat(vertex_ai/): test

* fix: fix linting error

* test: set api base as None before starting loadtest
This commit is contained in:
Krish Dholakia 2025-03-25 23:46:20 -07:00 committed by GitHub
parent 8657816477
commit 6fd18651d1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 223 additions and 43 deletions

View file

@ -55,7 +55,9 @@ def get_supports_response_schema(
from typing import Literal, Optional
all_gemini_url_modes = Literal["chat", "embedding", "batch_embedding"]
all_gemini_url_modes = Literal[
"chat", "embedding", "batch_embedding", "image_generation"
]
def _get_vertex_url(
@ -91,7 +93,11 @@ def _get_vertex_url(
if model.isdigit():
# https://us-central1-aiplatform.googleapis.com/v1/projects/$PROJECT_ID/locations/us-central1/endpoints/$ENDPOINT_ID:predict
url = f"https://{vertex_location}-aiplatform.googleapis.com/{vertex_api_version}/projects/{vertex_project}/locations/{vertex_location}/endpoints/{model}:{endpoint}"
elif mode == "image_generation":
endpoint = "predict"
url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}"
if model.isdigit():
url = f"https://{vertex_location}-aiplatform.googleapis.com/{vertex_api_version}/projects/{vertex_project}/locations/{vertex_location}/endpoints/{model}:{endpoint}"
if not url or not endpoint:
raise ValueError(f"Unable to get vertex url/endpoint for mode: {mode}")
return url, endpoint
@ -127,6 +133,10 @@ def _get_gemini_url(
url = "https://generativelanguage.googleapis.com/v1beta/{}:{}?key={}".format(
_gemini_model_name, endpoint, gemini_api_key
)
elif mode == "image_generation":
raise ValueError(
"LiteLLM's `gemini/` route does not support image generation yet. Let us know if you need this feature by opening an issue at https://github.com/BerriAI/litellm/issues"
)
return url, endpoint