Support litellm.api_base for vertex_ai + gemini/ across completion, embedding, image_generation (#9516)
All checks were successful
Read Version from pyproject.toml / read-version (push) Successful in 19s
Helm unit test / unit-test (push) Successful in 20s

* test(tests): add unit testing for litellm_proxy integration

* fix(cost_calculator.py): fix tracking cost in sdk when calling proxy

* fix(main.py): respect litellm.api_base on `vertex_ai/` and `gemini/` routes

* fix(main.py): consistently support custom api base across gemini + vertex_ai on embedding + completion

* feat(vertex_ai/): test

* fix: fix linting error

* test: set api_base to None before starting the load test
This commit is contained in:
Krish Dholakia 2025-03-25 23:46:20 -07:00 committed by GitHub
parent 8657816477
commit 6fd18651d1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 223 additions and 43 deletions

View file

@@ -2350,6 +2350,8 @@ def completion( # type: ignore # noqa: PLR0915
or litellm.api_key
)
api_base = api_base or litellm.api_base or get_secret("GEMINI_API_BASE")
new_params = deepcopy(optional_params)
response = vertex_chat_completion.completion( # type: ignore
model=model,
@@ -2392,6 +2394,8 @@ def completion( # type: ignore # noqa: PLR0915
or get_secret("VERTEXAI_CREDENTIALS")
)
api_base = api_base or litellm.api_base or get_secret("VERTEXAI_API_BASE")
new_params = deepcopy(optional_params)
if (
model.startswith("meta/")
@@ -3657,6 +3661,8 @@ def embedding( # noqa: PLR0915
api_key or get_secret_str("GEMINI_API_KEY") or litellm.api_key
)
api_base = api_base or litellm.api_base or get_secret_str("GEMINI_API_BASE")
response = google_batch_embeddings.batch_embeddings( # type: ignore
model=model,
input=input,
@@ -3671,6 +3677,8 @@ def embedding( # noqa: PLR0915
print_verbose=print_verbose,
custom_llm_provider="gemini",
api_key=gemini_api_key,
api_base=api_base,
client=client,
)
elif custom_llm_provider == "vertex_ai":
@@ -3695,6 +3703,13 @@ def embedding( # noqa: PLR0915
or get_secret_str("VERTEX_CREDENTIALS")
)
api_base = (
api_base
or litellm.api_base
or get_secret_str("VERTEXAI_API_BASE")
or get_secret_str("VERTEX_API_BASE")
)
if (
"image" in optional_params
or "video" in optional_params
@@ -3716,6 +3731,7 @@ def embedding( # noqa: PLR0915
print_verbose=print_verbose,
custom_llm_provider="vertex_ai",
client=client,
api_base=api_base,
)
else:
response = vertex_embedding.embedding(
@@ -3733,6 +3749,8 @@ def embedding( # noqa: PLR0915
aembedding=aembedding,
print_verbose=print_verbose,
api_key=api_key,
api_base=api_base,
client=client,
)
elif custom_llm_provider == "oobabooga":
response = oobabooga.embedding(
@@ -4695,6 +4713,14 @@ def image_generation( # noqa: PLR0915
or optional_params.pop("vertex_ai_credentials", None)
or get_secret_str("VERTEXAI_CREDENTIALS")
)
api_base = (
api_base
or litellm.api_base
or get_secret_str("VERTEXAI_API_BASE")
or get_secret_str("VERTEX_API_BASE")
)
model_response = vertex_image_generation.image_generation(
model=model,
prompt=prompt,
@@ -4706,6 +4732,8 @@ def image_generation( # noqa: PLR0915
vertex_location=vertex_ai_location,
vertex_credentials=vertex_credentials,
aimg_generation=aimg_generation,
api_base=api_base,
client=client,
)
elif (
custom_llm_provider in litellm._custom_providers