Support litellm.api_base for vertex_ai + gemini/ across completion, embedding, image_generation (#9516)

* test(tests): add unit testing for litellm_proxy integration

* fix(cost_calculator.py): fix tracking cost in sdk when calling proxy

* fix(main.py): respect litellm.api_base on `vertex_ai/` and `gemini/` routes

* fix(main.py): consistently support custom api base across gemini + vertexai on embedding + completion

* feat(vertex_ai/): test

* fix: fix linting error

* test: set api base as None before starting loadtest
Krish Dholakia, 2025-03-25 23:46:20 -07:00 (committed by GitHub)
commit 6fd18651d1 (parent 8657816477)
10 changed files with 223 additions and 43 deletions
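In practice, this change lets a globally configured `litellm.api_base` flow through to `vertex_ai/` and `gemini/` calls on all three routes. A minimal usage sketch, assuming a placeholder proxy URL (not an endpoint from this commit):

```python
import litellm

# Global override: route vertex_ai/ and gemini/ traffic through a custom
# base URL, e.g. a LiteLLM proxy. The URL below is a placeholder.
litellm.api_base = "https://my-litellm-proxy.example.com"

# completion, embedding, and image_generation now honor litellm.api_base
response = litellm.completion(
    model="vertex_ai/gemini-2.0-flash-001",
    messages=[{"role": "user", "content": "Hello, world!"}],
)
```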

cost_calculator.py

@@ -828,11 +828,14 @@ def get_response_cost_from_hidden_params(
     _hidden_params_dict = hidden_params
     additional_headers = _hidden_params_dict.get("additional_headers", {})
-    if additional_headers and "x-litellm-response-cost" in additional_headers:
-        response_cost = additional_headers["x-litellm-response-cost"]
+    if (
+        additional_headers
+        and "llm_provider-x-litellm-response-cost" in additional_headers
+    ):
+        response_cost = additional_headers["llm_provider-x-litellm-response-cost"]
         if response_cost is None:
             return None
-        return float(additional_headers["x-litellm-response-cost"])
+        return float(additional_headers["llm_provider-x-litellm-response-cost"])
     return None
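The key change above: the proxy's cost header now arrives under the `llm_provider-` prefix, which LiteLLM applies to provider response headers when surfacing them in `additional_headers`. A rough standalone illustration of the patched lookup (a sketch, not the exact function):

```python
from typing import Optional, Union

COST_HEADER = "llm_provider-x-litellm-response-cost"

def extract_response_cost(hidden_params: dict) -> Optional[float]:
    # Provider response headers are surfaced under the "llm_provider-"
    # prefix, so the proxy's x-litellm-response-cost is looked up there.
    additional_headers: dict = hidden_params.get("additional_headers", {})
    if additional_headers and COST_HEADER in additional_headers:
        response_cost: Union[str, float, None] = additional_headers[COST_HEADER]
        if response_cost is None:
            return None
        return float(response_cost)
    return None

assert extract_response_cost(
    {"additional_headers": {"llm_provider-x-litellm-response-cost": "120"}}
) == 120.0
```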


@@ -55,7 +55,9 @@ def get_supports_response_schema(
 from typing import Literal, Optional

-all_gemini_url_modes = Literal["chat", "embedding", "batch_embedding"]
+all_gemini_url_modes = Literal[
+    "chat", "embedding", "batch_embedding", "image_generation"
+]


 def _get_vertex_url(
@@ -91,7 +93,11 @@ def _get_vertex_url(
         if model.isdigit():
             # https://us-central1-aiplatform.googleapis.com/v1/projects/$PROJECT_ID/locations/us-central1/endpoints/$ENDPOINT_ID:predict
             url = f"https://{vertex_location}-aiplatform.googleapis.com/{vertex_api_version}/projects/{vertex_project}/locations/{vertex_location}/endpoints/{model}:{endpoint}"
+    elif mode == "image_generation":
+        endpoint = "predict"
+        url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}"
+        if model.isdigit():
+            url = f"https://{vertex_location}-aiplatform.googleapis.com/{vertex_api_version}/projects/{vertex_project}/locations/{vertex_location}/endpoints/{model}:{endpoint}"
     if not url or not endpoint:
         raise ValueError(f"Unable to get vertex url/endpoint for mode: {mode}")
     return url, endpoint
@@ -127,6 +133,10 @@ def _get_gemini_url(
         url = "https://generativelanguage.googleapis.com/v1beta/{}:{}?key={}".format(
             _gemini_model_name, endpoint, gemini_api_key
         )
+    elif mode == "image_generation":
+        raise ValueError(
+            "LiteLLM's `gemini/` route does not support image generation yet. Let us know if you need this feature by opening an issue at https://github.com/BerriAI/litellm/issues"
+        )

     return url, endpoint
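For orientation, the new `image_generation` mode resolves to the Imagen `:predict` endpoint on Vertex AI, while the Gemini API route raises since it has no equivalent. A sketch of the URL the branch above produces, with placeholder project/location values:

```python
# Placeholder values for illustration only.
vertex_location = "us-central1"
vertex_project = "my-project"
model = "imagegeneration"
endpoint = "predict"

url = (
    f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/"
    f"{vertex_project}/locations/{vertex_location}/publishers/google/"
    f"models/{model}:{endpoint}"
)
# -> https://us-central1-aiplatform.googleapis.com/v1/projects/my-project/
#    locations/us-central1/publishers/google/models/imagegeneration:predict
```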


@@ -43,22 +43,23 @@ class VertexImageGeneration(VertexLLM):
     def image_generation(
         self,
         prompt: str,
+        api_base: Optional[str],
         vertex_project: Optional[str],
         vertex_location: Optional[str],
         vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES],
         model_response: ImageResponse,
         logging_obj: Any,
-        model: Optional[
-            str
-        ] = "imagegeneration",  # vertex ai uses imagegeneration as the default model
+        model: str = "imagegeneration",  # vertex ai uses imagegeneration as the default model
         client: Optional[Any] = None,
         optional_params: Optional[dict] = None,
         timeout: Optional[int] = None,
         aimg_generation=False,
+        extra_headers: Optional[dict] = None,
     ) -> ImageResponse:
         if aimg_generation is True:
             return self.aimage_generation(  # type: ignore
                 prompt=prompt,
+                api_base=api_base,
                 vertex_project=vertex_project,
                 vertex_location=vertex_location,
                 vertex_credentials=vertex_credentials,
@@ -83,13 +84,27 @@ class VertexImageGeneration(VertexLLM):
         else:
             sync_handler = client  # type: ignore

-        url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:predict"
+        # url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:predict"

+        auth_header: Optional[str] = None
         auth_header, _ = self._ensure_access_token(
             credentials=vertex_credentials,
             project_id=vertex_project,
             custom_llm_provider="vertex_ai",
         )
+        auth_header, api_base = self._get_token_and_url(
+            model=model,
+            gemini_api_key=None,
+            auth_header=auth_header,
+            vertex_project=vertex_project,
+            vertex_location=vertex_location,
+            vertex_credentials=vertex_credentials,
+            stream=False,
+            custom_llm_provider="vertex_ai",
+            api_base=api_base,
+            should_use_v1beta1_features=False,
+            mode="image_generation",
+        )
         optional_params = optional_params or {
             "sampleCount": 1
         }  # default optional params
@@ -99,31 +114,21 @@ class VertexImageGeneration(VertexLLM):
             "parameters": optional_params,
         }

-        request_str = f"\n curl -X POST \\\n -H \"Authorization: Bearer {auth_header[:10] + 'XXXXXXXXXX'}\" \\\n -H \"Content-Type: application/json; charset=utf-8\" \\\n -d {request_data} \\\n \"{url}\""
-        logging_obj.pre_call(
-            input=prompt,
-            api_key=None,
-            additional_args={
-                "complete_input_dict": optional_params,
-                "request_str": request_str,
-            },
-        )
-
+        headers = self.set_headers(auth_header=auth_header, extra_headers=extra_headers)
         logging_obj.pre_call(
             input=prompt,
-            api_key=None,
+            api_key="",
             additional_args={
                 "complete_input_dict": optional_params,
-                "request_str": request_str,
+                "api_base": api_base,
+                "headers": headers,
             },
         )

         response = sync_handler.post(
-            url=url,
-            headers={
-                "Content-Type": "application/json; charset=utf-8",
-                "Authorization": f"Bearer {auth_header}",
-            },
+            url=api_base,
+            headers=headers,
             data=json.dumps(request_data),
         )
@@ -138,17 +143,17 @@ class VertexImageGeneration(VertexLLM):
     async def aimage_generation(
         self,
         prompt: str,
+        api_base: Optional[str],
         vertex_project: Optional[str],
         vertex_location: Optional[str],
         vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES],
         model_response: litellm.ImageResponse,
         logging_obj: Any,
-        model: Optional[
-            str
-        ] = "imagegeneration",  # vertex ai uses imagegeneration as the default model
+        model: str = "imagegeneration",  # vertex ai uses imagegeneration as the default model
         client: Optional[AsyncHTTPHandler] = None,
         optional_params: Optional[dict] = None,
         timeout: Optional[int] = None,
+        extra_headers: Optional[dict] = None,
     ):
         response = None
         if client is None:
@@ -169,7 +174,6 @@ class VertexImageGeneration(VertexLLM):
         # make POST request to
         # https://us-central1-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/us-central1/publishers/google/models/imagegeneration:predict
-        url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:predict"

         """
         Docs link: https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/imagegeneration?project=adroit-crow-413218
@@ -188,11 +192,25 @@ class VertexImageGeneration(VertexLLM):
         } \
         "https://us-central1-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/us-central1/publishers/google/models/imagegeneration:predict"
         """
+        auth_header: Optional[str] = None
         auth_header, _ = self._ensure_access_token(
             credentials=vertex_credentials,
             project_id=vertex_project,
             custom_llm_provider="vertex_ai",
         )
+        auth_header, api_base = self._get_token_and_url(
+            model=model,
+            gemini_api_key=None,
+            auth_header=auth_header,
+            vertex_project=vertex_project,
+            vertex_location=vertex_location,
+            vertex_credentials=vertex_credentials,
+            stream=False,
+            custom_llm_provider="vertex_ai",
+            api_base=api_base,
+            should_use_v1beta1_features=False,
+            mode="image_generation",
+        )
         optional_params = optional_params or {
             "sampleCount": 1
         }  # default optional params
@@ -202,22 +220,21 @@ class VertexImageGeneration(VertexLLM):
             "parameters": optional_params,
         }

-        request_str = f"\n curl -X POST \\\n -H \"Authorization: Bearer {auth_header[:10] + 'XXXXXXXXXX'}\" \\\n -H \"Content-Type: application/json; charset=utf-8\" \\\n -d {request_data} \\\n \"{url}\""
+        headers = self.set_headers(auth_header=auth_header, extra_headers=extra_headers)
         logging_obj.pre_call(
             input=prompt,
-            api_key=None,
+            api_key="",
             additional_args={
                 "complete_input_dict": optional_params,
-                "request_str": request_str,
+                "api_base": api_base,
+                "headers": headers,
             },
         )

         response = await self.async_handler.post(
-            url=url,
-            headers={
-                "Content-Type": "application/json; charset=utf-8",
-                "Authorization": f"Bearer {auth_header}",
-            },
+            url=api_base,
+            headers=headers,
             data=json.dumps(request_data),
         )
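With `api_base` now threaded through `_get_token_and_url`, Vertex image generation can target a custom endpoint instead of the previously hardcoded URL. A hedged usage sketch; the base URL, project, and location below are placeholders, and the explicit vertex kwargs are optional (they may also come from the environment):

```python
import litellm

response = litellm.image_generation(
    model="vertex_ai/imagegeneration",
    prompt="An armchair in the shape of an avocado",
    vertex_ai_project="my-project",      # placeholder
    vertex_ai_location="us-central1",    # placeholder
    api_base="https://my-litellm-proxy.example.com",  # overrides the default Vertex URL
)
```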


@@ -111,7 +111,7 @@ class VertexEmbedding(VertexBase):
         )

         try:
-            response = client.post(api_base, headers=headers, json=vertex_request)  # type: ignore
+            response = client.post(url=api_base, headers=headers, json=vertex_request)  # type: ignore
             response.raise_for_status()
         except httpx.HTTPStatusError as err:
             error_code = err.response.status_code

main.py

@@ -2350,6 +2350,8 @@ def completion(  # type: ignore # noqa: PLR0915
                 or litellm.api_key
             )

+            api_base = api_base or litellm.api_base or get_secret("GEMINI_API_BASE")
+
             new_params = deepcopy(optional_params)
             response = vertex_chat_completion.completion(  # type: ignore
                 model=model,
@@ -2392,6 +2394,8 @@ def completion(  # type: ignore # noqa: PLR0915
                 or get_secret("VERTEXAI_CREDENTIALS")
             )

+            api_base = api_base or litellm.api_base or get_secret("VERTEXAI_API_BASE")
+
             new_params = deepcopy(optional_params)
             if (
                 model.startswith("meta/")
@@ -3657,6 +3661,8 @@ def embedding(  # noqa: PLR0915
                 api_key or get_secret_str("GEMINI_API_KEY") or litellm.api_key
             )

+            api_base = api_base or litellm.api_base or get_secret_str("GEMINI_API_BASE")
+
             response = google_batch_embeddings.batch_embeddings(  # type: ignore
                 model=model,
                 input=input,
@@ -3671,6 +3677,8 @@ def embedding(  # noqa: PLR0915
                 print_verbose=print_verbose,
                 custom_llm_provider="gemini",
                 api_key=gemini_api_key,
+                api_base=api_base,
+                client=client,
             )

         elif custom_llm_provider == "vertex_ai":
@@ -3695,6 +3703,13 @@ def embedding(  # noqa: PLR0915
                 or get_secret_str("VERTEX_CREDENTIALS")
             )

+            api_base = (
+                api_base
+                or litellm.api_base
+                or get_secret_str("VERTEXAI_API_BASE")
+                or get_secret_str("VERTEX_API_BASE")
+            )
+
             if (
                 "image" in optional_params
                 or "video" in optional_params
@@ -3716,6 +3731,7 @@ def embedding(  # noqa: PLR0915
                     print_verbose=print_verbose,
                     custom_llm_provider="vertex_ai",
                     client=client,
+                    api_base=api_base,
                 )
             else:
                 response = vertex_embedding.embedding(
@@ -3733,6 +3749,8 @@ def embedding(  # noqa: PLR0915
                     aembedding=aembedding,
                     print_verbose=print_verbose,
                     api_key=api_key,
+                    api_base=api_base,
+                    client=client,
                 )
         elif custom_llm_provider == "oobabooga":
             response = oobabooga.embedding(
@@ -4695,6 +4713,14 @@ def image_generation(  # noqa: PLR0915
                 or optional_params.pop("vertex_ai_credentials", None)
                 or get_secret_str("VERTEXAI_CREDENTIALS")
             )

+            api_base = (
+                api_base
+                or litellm.api_base
+                or get_secret_str("VERTEXAI_API_BASE")
+                or get_secret_str("VERTEX_API_BASE")
+            )
+
             model_response = vertex_image_generation.image_generation(
                 model=model,
                 prompt=prompt,
@@ -4706,6 +4732,8 @@ def image_generation(  # noqa: PLR0915
                 vertex_location=vertex_ai_location,
                 vertex_credentials=vertex_credentials,
                 aimg_generation=aimg_generation,
+                api_base=api_base,
+                client=client,
             )
     elif (
         custom_llm_provider in litellm._custom_providers
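Across all of these branches the resolution order is the same: an explicit `api_base` argument wins, then the global `litellm.api_base`, then provider env vars. A minimal sketch of the Vertex precedence, assuming secrets are plain environment variables (`get_secret_str` is approximated with `os.environ`):

```python
import os
from typing import Optional

import litellm

def resolve_vertex_api_base(api_base: Optional[str] = None) -> Optional[str]:
    # Mirrors the precedence added in main.py: call-site argument,
    # then the global setting, then environment variables.
    return (
        api_base
        or litellm.api_base
        or os.environ.get("VERTEXAI_API_BASE")
        or os.environ.get("VERTEX_API_BASE")
    )
```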


@@ -1,5 +1,5 @@
 model_list:
-  - model_name: "gpt-3.5-turbo"
+  - model_name: "gpt-4o"
     litellm_params:
       model: azure/chatgpt-v-2
       api_key: os.environ/AZURE_API_KEY


@@ -15,9 +15,11 @@ from pydantic import BaseModel
 from litellm.cost_calculator import response_cost_calculator


-def test_cost_calculator():
+def test_cost_calculator_with_response_cost_in_additional_headers():
     class MockResponse(BaseModel):
-        _hidden_params = {"additional_headers": {"x-litellm-response-cost": 1000}}
+        _hidden_params = {
+            "additional_headers": {"llm_provider-x-litellm-response-cost": 1000}
+        }

     result = response_cost_calculator(
         response_object=MockResponse(),


@@ -31,7 +31,7 @@ async def test_litellm_gateway_from_sdk():
     openai_client = OpenAI(api_key="fake-key")

     with patch.object(
-        openai_client.chat.completions, "create", new=MagicMock()
+        openai_client.chat.completions.with_raw_response, "create", new=MagicMock()
     ) as mock_call:
         try:
             completion(
@@ -374,3 +374,78 @@ async def test_litellm_gateway_from_sdk_rerank(is_async):
     assert request_body["query"] == "What is machine learning?"
     assert request_body["model"] == "rerank-english-v2.0"
     assert len(request_body["documents"]) == 2
+
+
+def test_litellm_gateway_from_sdk_with_response_cost_in_additional_headers():
+    litellm.set_verbose = True
+    litellm._turn_on_debug()
+    from openai import OpenAI
+
+    openai_client = OpenAI(api_key="fake-key")
+
+    # Create mock response object
+    mock_response = MagicMock()
+    mock_response.headers = {"x-litellm-response-cost": "120"}
+    mock_response.parse.return_value = litellm.ModelResponse(
+        **{
+            "id": "chatcmpl-BEkxQvRGp9VAushfAsOZCbhMFLsoy",
+            "choices": [
+                {
+                    "finish_reason": "stop",
+                    "index": 0,
+                    "logprobs": None,
+                    "message": {
+                        "content": "Hello! How can I assist you today?",
+                        "refusal": None,
+                        "role": "assistant",
+                        "annotations": [],
+                        "audio": None,
+                        "function_call": None,
+                        "tool_calls": None,
+                    },
+                }
+            ],
+            "created": 1742856796,
+            "model": "gpt-4o-2024-08-06",
+            "object": "chat.completion",
+            "service_tier": "default",
+            "system_fingerprint": "fp_6ec83003ad",
+            "usage": {
+                "completion_tokens": 10,
+                "prompt_tokens": 9,
+                "total_tokens": 19,
+                "completion_tokens_details": {
+                    "accepted_prediction_tokens": 0,
+                    "audio_tokens": 0,
+                    "reasoning_tokens": 0,
+                    "rejected_prediction_tokens": 0,
+                },
+                "prompt_tokens_details": {"audio_tokens": 0, "cached_tokens": 0},
+            },
+        }
+    )
+
+    with patch.object(
+        openai_client.chat.completions.with_raw_response,
+        "create",
+        return_value=mock_response,
+    ) as mock_call:
+        response = litellm.completion(
+            model="litellm_proxy/gpt-4o",
+            messages=[{"role": "user", "content": "Hello world"}],
+            api_base="http://0.0.0.0:4000",
+            api_key="sk-PIp1h0RekR",
+            client=openai_client,
+        )
+
+        # Assert the headers were properly passed through
+        print(f"additional_headers: {response._hidden_params['additional_headers']}")
+        assert (
+            response._hidden_params["additional_headers"][
+                "llm_provider-x-litellm-response-cost"
+            ]
+            == "120"
+        )
+
+        assert response._hidden_params["response_cost"] == 120
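The patch target moves to `with_raw_response` because that is the OpenAI SDK surface that exposes HTTP headers alongside the parsed body, which LiteLLM needs in order to capture header-borne metadata like the response cost. A short sketch of that SDK pattern (fake key and local URL, as in the test):

```python
from openai import OpenAI

client = OpenAI(api_key="fake-key", base_url="http://0.0.0.0:4000")

# with_raw_response returns the raw HTTP response: headers are readable,
# and .parse() yields the usual ChatCompletion object.
raw = client.chat.completions.with_raw_response.create(
    model="gpt-4o",
    messages=[{"role": "user", "content": "Hello world"}],
)
print(raw.headers.get("x-litellm-response-cost"))
chat_completion = raw.parse()
```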


@@ -109,12 +109,13 @@ def analyze_results(vertex_times):


 @pytest.mark.asyncio
-async def test_embedding_performance():
+async def test_embedding_performance(monkeypatch):
     """
     Run load test on vertex AI embeddings to ensure vertex median response time is less than 300ms

     20 RPS for 20 seconds
     """
+    monkeypatch.setattr(litellm, "api_base", None)
     duration_seconds = 20
     requests_per_second = 20
     vertex_times = await run_load_test(duration_seconds, requests_per_second)


@@ -31,6 +31,7 @@ from litellm import (
     completion,
     completion_cost,
     embedding,
+    image_generation,
 )
 from litellm.llms.vertex_ai.gemini.transformation import (
     _gemini_convert_messages_with_history,
@@ -3327,3 +3328,46 @@ def test_signed_s3_url_with_format():
     json_str = json.dumps(mock_client.call_args.kwargs["json"])
     assert "image/jpeg" in json_str
     assert "image/png" not in json_str
+
+
+@pytest.mark.parametrize("provider", ["vertex_ai", "gemini"])
+@pytest.mark.parametrize("route", ["completion", "embedding", "image_generation"])
+def test_litellm_api_base(monkeypatch, provider, route):
+    from litellm.llms.custom_httpx.http_handler import HTTPHandler
+
+    client = HTTPHandler()
+
+    import litellm
+
+    monkeypatch.setattr(litellm, "api_base", "https://litellm.com")
+
+    load_vertex_ai_credentials()
+
+    if route == "image_generation" and provider == "gemini":
+        pytest.skip("Gemini does not support image generation")
+
+    with patch.object(client, "post", new=MagicMock()) as mock_client:
+        try:
+            if route == "completion":
+                response = completion(
+                    model=f"{provider}/gemini-2.0-flash-001",
+                    messages=[{"role": "user", "content": "Hello, world!"}],
+                    client=client,
+                )
+            elif route == "embedding":
+                response = embedding(
+                    model=f"{provider}/gemini-2.0-flash-001",
+                    input=["Hello, world!"],
+                    client=client,
+                )
+            elif route == "image_generation":
+                response = image_generation(
+                    model=f"{provider}/gemini-2.0-flash-001",
+                    prompt="Hello, world!",
+                    client=client,
+                )
+        except Exception as e:
+            print(e)
+
+        mock_client.assert_called()
+        assert mock_client.call_args.kwargs["url"].startswith("https://litellm.com")