mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 18:54:30 +00:00
Support litellm.api_base for vertex_ai + gemini/ across completion, embedding, image_generation (#9516)
* test(tests): add unit testing for litellm_proxy integration
* fix(cost_calculator.py): fix tracking cost in sdk when calling proxy
* fix(main.py): respect litellm.api_base on `vertex_ai/` and `gemini/` routes
* fix(main.py): consistently support custom api base across gemini + vertexai on embedding + completion
* feat(vertex_ai/): test
* fix: fix linting error
* test: set api base as None before starting loadtest
This commit is contained in:
parent
8657816477
commit
6fd18651d1
10 changed files with 223 additions and 43 deletions
|
@@ -55,7 +55,9 @@ def get_supports_response_schema(
|
|||
|
||||
from typing import Literal, Optional
|
||||
|
||||
all_gemini_url_modes = Literal["chat", "embedding", "batch_embedding"]
|
||||
all_gemini_url_modes = Literal[
|
||||
"chat", "embedding", "batch_embedding", "image_generation"
|
||||
]
|
||||
|
||||
|
||||
def _get_vertex_url(
|
||||
|
@@ -91,7 +93,11 @@ def _get_vertex_url(
|
|||
if model.isdigit():
|
||||
# https://us-central1-aiplatform.googleapis.com/v1/projects/$PROJECT_ID/locations/us-central1/endpoints/$ENDPOINT_ID:predict
|
||||
url = f"https://{vertex_location}-aiplatform.googleapis.com/{vertex_api_version}/projects/{vertex_project}/locations/{vertex_location}/endpoints/{model}:{endpoint}"
|
||||
|
||||
elif mode == "image_generation":
|
||||
endpoint = "predict"
|
||||
url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}"
|
||||
if model.isdigit():
|
||||
url = f"https://{vertex_location}-aiplatform.googleapis.com/{vertex_api_version}/projects/{vertex_project}/locations/{vertex_location}/endpoints/{model}:{endpoint}"
|
||||
if not url or not endpoint:
|
||||
raise ValueError(f"Unable to get vertex url/endpoint for mode: {mode}")
|
||||
return url, endpoint
|
||||
|
@@ -127,6 +133,10 @@ def _get_gemini_url(
|
|||
url = "https://generativelanguage.googleapis.com/v1beta/{}:{}?key={}".format(
|
||||
_gemini_model_name, endpoint, gemini_api_key
|
||||
)
|
||||
elif mode == "image_generation":
|
||||
raise ValueError(
|
||||
"LiteLLM's `gemini/` route does not support image generation yet. Let us know if you need this feature by opening an issue at https://github.com/BerriAI/litellm/issues"
|
||||
)
|
||||
|
||||
return url, endpoint
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue