Support litellm.api_base for vertex_ai/ + gemini/ across completion, embedding, image_generation (#9516)

* test(tests): add unit testing for litellm_proxy integration

* fix(cost_calculator.py): fix tracking cost in sdk when calling proxy

* fix(main.py): respect litellm.api_base on `vertex_ai/` and `gemini/` routes

* fix(main.py): consistently support custom api base across gemini + vertexai on embedding + completion

* feat(vertex_ai/): test

* fix: fix linting error

* test: set api base as None before starting loadtest
Krish Dholakia 2025-03-25 23:46:20 -07:00 committed by GitHub
parent 8657816477
commit 6fd18651d1
10 changed files with 223 additions and 43 deletions
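
In practice, the headline change means a module-level `litellm.api_base` now takes effect on the Google routes. A minimal usage sketch (the base URL is a placeholder):

import litellm

# Placeholder URL; before this change the override was ignored on
# vertex_ai/ and gemini/ routes.
litellm.api_base = "https://my-litellm-proxy.example.com"

# completion, embedding, and image_generation all honor the override now.
response = litellm.completion(
    model="vertex_ai/gemini-2.0-flash-001",
    messages=[{"role": "user", "content": "Hello, world!"}],
)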

View file

@@ -828,11 +828,14 @@ def get_response_cost_from_hidden_params(
_hidden_params_dict = hidden_params
additional_headers = _hidden_params_dict.get("additional_headers", {})
if additional_headers and "x-litellm-response-cost" in additional_headers:
response_cost = additional_headers["x-litellm-response-cost"]
if (
additional_headers
and "llm_provider-x-litellm-response-cost" in additional_headers
):
response_cost = additional_headers["llm_provider-x-litellm-response-cost"]
if response_cost is None:
return None
return float(additional_headers["x-litellm-response-cost"])
return float(additional_headers["llm_provider-x-litellm-response-cost"])
return None
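
Condensed, the new lookup prefers the provider-prefixed header and tolerates a missing value. A standalone sketch of the equivalent logic (helper name is illustrative):

def _cost_from_headers(additional_headers: dict) -> float | None:
    # The proxy surfaces response cost under a provider-prefixed key.
    cost = additional_headers.get("llm_provider-x-litellm-response-cost")
    return float(cost) if cost is not None else None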

View file

@@ -55,7 +55,9 @@ def get_supports_response_schema(
from typing import Literal, Optional
all_gemini_url_modes = Literal["chat", "embedding", "batch_embedding"]
all_gemini_url_modes = Literal[
"chat", "embedding", "batch_embedding", "image_generation"
]
def _get_vertex_url(
@@ -91,7 +93,11 @@ def _get_vertex_url(
if model.isdigit():
# https://us-central1-aiplatform.googleapis.com/v1/projects/$PROJECT_ID/locations/us-central1/endpoints/$ENDPOINT_ID:predict
url = f"https://{vertex_location}-aiplatform.googleapis.com/{vertex_api_version}/projects/{vertex_project}/locations/{vertex_location}/endpoints/{model}:{endpoint}"
elif mode == "image_generation":
endpoint = "predict"
url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}"
if model.isdigit():
url = f"https://{vertex_location}-aiplatform.googleapis.com/{vertex_api_version}/projects/{vertex_project}/locations/{vertex_location}/endpoints/{model}:{endpoint}"
if not url or not endpoint:
raise ValueError(f"Unable to get vertex url/endpoint for mode: {mode}")
return url, endpoint
@@ -127,6 +133,10 @@ def _get_gemini_url(
url = "https://generativelanguage.googleapis.com/v1beta/{}:{}?key={}".format(
_gemini_model_name, endpoint, gemini_api_key
)
elif mode == "image_generation":
raise ValueError(
"LiteLLM's `gemini/` route does not support image generation yet. Let us know if you need this feature by opening an issue at https://github.com/BerriAI/litellm/issues"
)
return url, endpoint
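
For reference, the predict URL the new image_generation branch builds resolves like this (project and location values are placeholders):

vertex_location = "us-central1"  # placeholder
vertex_project = "my-project"    # placeholder
model = "imagegeneration"

url = (
    f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/"
    f"{vertex_project}/locations/{vertex_location}"
    f"/publishers/google/models/{model}:predict"
)
# -> https://us-central1-aiplatform.googleapis.com/v1/projects/my-project/locations/us-central1/publishers/google/models/imagegeneration:predict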

View file

@@ -43,22 +43,23 @@ class VertexImageGeneration(VertexLLM):
def image_generation(
self,
prompt: str,
api_base: Optional[str],
vertex_project: Optional[str],
vertex_location: Optional[str],
vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES],
model_response: ImageResponse,
logging_obj: Any,
model: Optional[
str
] = "imagegeneration", # vertex ai uses imagegeneration as the default model
model: str = "imagegeneration", # vertex ai uses imagegeneration as the default model
client: Optional[Any] = None,
optional_params: Optional[dict] = None,
timeout: Optional[int] = None,
aimg_generation=False,
extra_headers: Optional[dict] = None,
) -> ImageResponse:
if aimg_generation is True:
return self.aimage_generation( # type: ignore
prompt=prompt,
api_base=api_base,
vertex_project=vertex_project,
vertex_location=vertex_location,
vertex_credentials=vertex_credentials,
@@ -83,13 +84,27 @@ class VertexImageGeneration(VertexLLM):
else:
sync_handler = client # type: ignore
url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:predict"
# url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:predict"
auth_header: Optional[str] = None
auth_header, _ = self._ensure_access_token(
credentials=vertex_credentials,
project_id=vertex_project,
custom_llm_provider="vertex_ai",
)
auth_header, api_base = self._get_token_and_url(
model=model,
gemini_api_key=None,
auth_header=auth_header,
vertex_project=vertex_project,
vertex_location=vertex_location,
vertex_credentials=vertex_credentials,
stream=False,
custom_llm_provider="vertex_ai",
api_base=api_base,
should_use_v1beta1_features=False,
mode="image_generation",
)
optional_params = optional_params or {
"sampleCount": 1
} # default optional params
@@ -99,31 +114,21 @@ class VertexImageGeneration(VertexLLM):
"parameters": optional_params,
}
request_str = f"\n curl -X POST \\\n -H \"Authorization: Bearer {auth_header[:10] + 'XXXXXXXXXX'}\" \\\n -H \"Content-Type: application/json; charset=utf-8\" \\\n -d {request_data} \\\n \"{api_base}\""
logging_obj.pre_call(
input=prompt,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
"request_str": request_str,
},
)
headers = self.set_headers(auth_header=auth_header, extra_headers=extra_headers)
logging_obj.pre_call(
input=prompt,
api_key=None,
api_key="",
additional_args={
"complete_input_dict": optional_params,
"request_str": request_str,
"api_base": api_base,
"headers": headers,
},
)
response = sync_handler.post(
url=url,
headers={
"Content-Type": "application/json; charset=utf-8",
"Authorization": f"Bearer {auth_header}",
},
url=api_base,
headers=headers,
data=json.dumps(request_data),
)
@@ -138,17 +143,17 @@ class VertexImageGeneration(VertexLLM):
async def aimage_generation(
self,
prompt: str,
api_base: Optional[str],
vertex_project: Optional[str],
vertex_location: Optional[str],
vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES],
model_response: litellm.ImageResponse,
logging_obj: Any,
model: Optional[
str
] = "imagegeneration", # vertex ai uses imagegeneration as the default model
model: str = "imagegeneration", # vertex ai uses imagegeneration as the default model
client: Optional[AsyncHTTPHandler] = None,
optional_params: Optional[dict] = None,
timeout: Optional[int] = None,
extra_headers: Optional[dict] = None,
):
response = None
if client is None:
@@ -169,7 +174,6 @@ class VertexImageGeneration(VertexLLM):
# make POST request to
# https://us-central1-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/us-central1/publishers/google/models/imagegeneration:predict
url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:predict"
"""
Docs link: https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/imagegeneration?project=adroit-crow-413218
@@ -188,11 +192,25 @@ class VertexImageGeneration(VertexLLM):
} \
"https://us-central1-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/us-central1/publishers/google/models/imagegeneration:predict"
"""
auth_header: Optional[str] = None
auth_header, _ = self._ensure_access_token(
credentials=vertex_credentials,
project_id=vertex_project,
custom_llm_provider="vertex_ai",
)
auth_header, api_base = self._get_token_and_url(
model=model,
gemini_api_key=None,
auth_header=auth_header,
vertex_project=vertex_project,
vertex_location=vertex_location,
vertex_credentials=vertex_credentials,
stream=False,
custom_llm_provider="vertex_ai",
api_base=api_base,
should_use_v1beta1_features=False,
mode="image_generation",
)
optional_params = optional_params or {
"sampleCount": 1
} # default optional params
@@ -202,22 +220,21 @@ class VertexImageGeneration(VertexLLM):
"parameters": optional_params,
}
request_str = f"\n curl -X POST \\\n -H \"Authorization: Bearer {auth_header[:10] + 'XXXXXXXXXX'}\" \\\n -H \"Content-Type: application/json; charset=utf-8\" \\\n -d {request_data} \\\n \"{api_base}\""
headers = self.set_headers(auth_header=auth_header, extra_headers=extra_headers)
logging_obj.pre_call(
input=prompt,
api_key=None,
api_key="",
additional_args={
"complete_input_dict": optional_params,
"request_str": request_str,
"api_base": api_base,
"headers": headers,
},
)
response = await self.async_handler.post(
url=url,
headers={
"Content-Type": "application/json; charset=utf-8",
"Authorization": f"Bearer {auth_header}",
},
url=api_base,
headers=headers,
data=json.dumps(request_data),
)
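
With `_get_token_and_url` now resolving the final URL in both the sync and async paths, callers can point image generation at a custom base. A hedged usage sketch (URL, project, and location are placeholders; `vertex_project`/`vertex_location` are the usual litellm kwargs):

import litellm

litellm.api_base = "https://my-vertex-mirror.example.com"  # placeholder

image = litellm.image_generation(
    model="vertex_ai/imagegeneration",
    prompt="A watercolor fox",
    vertex_project="my-project",    # placeholder
    vertex_location="us-central1",  # placeholder
)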

View file

@@ -111,7 +111,7 @@ class VertexEmbedding(VertexBase):
)
try:
response = client.post(api_base, headers=headers, json=vertex_request) # type: ignore
response = client.post(url=api_base, headers=headers, json=vertex_request) # type: ignore
response.raise_for_status()
except httpx.HTTPStatusError as err:
error_code = err.response.status_code

View file

@@ -2350,6 +2350,8 @@ def completion( # type: ignore # noqa: PLR0915
or litellm.api_key
)
api_base = api_base or litellm.api_base or get_secret("GEMINI_API_BASE")
new_params = deepcopy(optional_params)
response = vertex_chat_completion.completion( # type: ignore
model=model,
@@ -2392,6 +2394,8 @@ def completion( # type: ignore # noqa: PLR0915
or get_secret("VERTEXAI_CREDENTIALS")
)
api_base = api_base or litellm.api_base or get_secret("VERTEXAI_API_BASE")
new_params = deepcopy(optional_params)
if (
model.startswith("meta/")
@@ -3657,6 +3661,8 @@ def embedding( # noqa: PLR0915
api_key or get_secret_str("GEMINI_API_KEY") or litellm.api_key
)
api_base = api_base or litellm.api_base or get_secret_str("GEMINI_API_BASE")
response = google_batch_embeddings.batch_embeddings( # type: ignore
model=model,
input=input,
@@ -3671,6 +3677,8 @@ def embedding( # noqa: PLR0915
print_verbose=print_verbose,
custom_llm_provider="gemini",
api_key=gemini_api_key,
api_base=api_base,
client=client,
)
elif custom_llm_provider == "vertex_ai":
@@ -3695,6 +3703,13 @@ def embedding( # noqa: PLR0915
or get_secret_str("VERTEX_CREDENTIALS")
)
api_base = (
api_base
or litellm.api_base
or get_secret_str("VERTEXAI_API_BASE")
or get_secret_str("VERTEX_API_BASE")
)
if (
"image" in optional_params
or "video" in optional_params
@@ -3716,6 +3731,7 @@ def embedding( # noqa: PLR0915
print_verbose=print_verbose,
custom_llm_provider="vertex_ai",
client=client,
api_base=api_base,
)
else:
response = vertex_embedding.embedding(
@@ -3733,6 +3749,8 @@ def embedding( # noqa: PLR0915
aembedding=aembedding,
print_verbose=print_verbose,
api_key=api_key,
api_base=api_base,
client=client,
)
elif custom_llm_provider == "oobabooga":
response = oobabooga.embedding(
@@ -4695,6 +4713,14 @@ def image_generation( # noqa: PLR0915
or optional_params.pop("vertex_ai_credentials", None)
or get_secret_str("VERTEXAI_CREDENTIALS")
)
api_base = (
api_base
or litellm.api_base
or get_secret_str("VERTEXAI_API_BASE")
or get_secret_str("VERTEX_API_BASE")
)
model_response = vertex_image_generation.image_generation(
model=model,
prompt=prompt,
@@ -4706,6 +4732,8 @@ def image_generation( # noqa: PLR0915
vertex_location=vertex_ai_location,
vertex_credentials=vertex_credentials,
aimg_generation=aimg_generation,
api_base=api_base,
client=client,
)
elif (
custom_llm_provider in litellm._custom_providers
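
Each branch above applies the same resolution order. A condensed sketch of that precedence (helper name is illustrative, and os.environ stands in for `get_secret_str`, which also checks secret managers):

import os
import litellm

def _resolve_vertex_api_base(api_base: str | None) -> str | None:
    # Explicit argument wins, then the module-level override,
    # then either env-var spelling.
    return (
        api_base
        or litellm.api_base
        or os.environ.get("VERTEXAI_API_BASE")
        or os.environ.get("VERTEX_API_BASE")
    )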

View file

@@ -1,5 +1,5 @@
model_list:
- model_name: "gpt-3.5-turbo"
- model_name: "gpt-4o"
litellm_params:
model: azure/chatgpt-v-2
api_key: os.environ/AZURE_API_KEY

View file

@@ -15,9 +15,11 @@ from pydantic import BaseModel
from litellm.cost_calculator import response_cost_calculator
def test_cost_calculator():
def test_cost_calculator_with_response_cost_in_additional_headers():
class MockResponse(BaseModel):
_hidden_params = {"additional_headers": {"x-litellm-response-cost": 1000}}
_hidden_params = {
"additional_headers": {"llm_provider-x-litellm-response-cost": 1000}
}
result = response_cost_calculator(
response_object=MockResponse(),

View file

@@ -31,7 +31,7 @@ async def test_litellm_gateway_from_sdk():
openai_client = OpenAI(api_key="fake-key")
with patch.object(
openai_client.chat.completions, "create", new=MagicMock()
openai_client.chat.completions.with_raw_response, "create", new=MagicMock()
) as mock_call:
try:
completion(
@@ -374,3 +374,78 @@ async def test_litellm_gateway_from_sdk_rerank(is_async):
assert request_body["query"] == "What is machine learning?"
assert request_body["model"] == "rerank-english-v2.0"
assert len(request_body["documents"]) == 2
def test_litellm_gateway_from_sdk_with_response_cost_in_additional_headers():
litellm.set_verbose = True
litellm._turn_on_debug()
from openai import OpenAI
openai_client = OpenAI(api_key="fake-key")
# Create mock response object
mock_response = MagicMock()
mock_response.headers = {"x-litellm-response-cost": "120"}
mock_response.parse.return_value = litellm.ModelResponse(
**{
"id": "chatcmpl-BEkxQvRGp9VAushfAsOZCbhMFLsoy",
"choices": [
{
"finish_reason": "stop",
"index": 0,
"logprobs": None,
"message": {
"content": "Hello! How can I assist you today?",
"refusal": None,
"role": "assistant",
"annotations": [],
"audio": None,
"function_call": None,
"tool_calls": None,
},
}
],
"created": 1742856796,
"model": "gpt-4o-2024-08-06",
"object": "chat.completion",
"service_tier": "default",
"system_fingerprint": "fp_6ec83003ad",
"usage": {
"completion_tokens": 10,
"prompt_tokens": 9,
"total_tokens": 19,
"completion_tokens_details": {
"accepted_prediction_tokens": 0,
"audio_tokens": 0,
"reasoning_tokens": 0,
"rejected_prediction_tokens": 0,
},
"prompt_tokens_details": {"audio_tokens": 0, "cached_tokens": 0},
},
}
)
with patch.object(
openai_client.chat.completions.with_raw_response,
"create",
return_value=mock_response,
) as mock_call:
response = litellm.completion(
model="litellm_proxy/gpt-4o",
messages=[{"role": "user", "content": "Hello world"}],
api_base="http://0.0.0.0:4000",
api_key="sk-PIp1h0RekR",
client=openai_client,
)
# Assert the headers were properly passed through
print(f"additional_headers: {response._hidden_params['additional_headers']}")
assert (
response._hidden_params["additional_headers"][
"llm_provider-x-litellm-response-cost"
]
== "120"
)
assert response._hidden_params["response_cost"] == 120

View file

@@ -109,12 +109,13 @@ def analyze_results(vertex_times):
@pytest.mark.asyncio
async def test_embedding_performance():
async def test_embedding_performance(monkeypatch):
"""
Run load test on vertex AI embeddings to ensure vertex median response time is less than 300ms
20 RPS for 20 seconds
"""
monkeypatch.setattr(litellm, "api_base", None)
duration_seconds = 20
requests_per_second = 20
vertex_times = await run_load_test(duration_seconds, requests_per_second)

View file

@@ -31,6 +31,7 @@ from litellm import (
completion,
completion_cost,
embedding,
image_generation,
)
from litellm.llms.vertex_ai.gemini.transformation import (
_gemini_convert_messages_with_history,
@@ -3327,3 +3328,46 @@ def test_signed_s3_url_with_format():
json_str = json.dumps(mock_client.call_args.kwargs["json"])
assert "image/jpeg" in json_str
assert "image/png" not in json_str
@pytest.mark.parametrize("provider", ["vertex_ai", "gemini"])
@pytest.mark.parametrize("route", ["completion", "embedding", "image_generation"])
def test_litellm_api_base(monkeypatch, provider, route):
from litellm.llms.custom_httpx.http_handler import HTTPHandler
client = HTTPHandler()
import litellm
monkeypatch.setattr(litellm, "api_base", "https://litellm.com")
load_vertex_ai_credentials()
if route == "image_generation" and provider == "gemini":
pytest.skip("Gemini does not support image generation")
with patch.object(client, "post", new=MagicMock()) as mock_client:
try:
if route == "completion":
response = completion(
model=f"{provider}/gemini-2.0-flash-001",
messages=[{"role": "user", "content": "Hello, world!"}],
client=client,
)
elif route == "embedding":
response = embedding(
model=f"{provider}/gemini-2.0-flash-001",
input=["Hello, world!"],
client=client,
)
elif route == "image_generation":
response = image_generation(
model=f"{provider}/gemini-2.0-flash-001",
prompt="Hello, world!",
client=client,
)
except Exception as e:
print(e)
mock_client.assert_called()
assert mock_client.call_args.kwargs["url"].startswith("https://litellm.com")