(feat) add Vertex Batches API support in OpenAI format (#7032)

* working request

* working transform

* working request

* transform vertex batch response

* add _async_create_batch

* move gcs functions to base

* fix _get_content_from_openai_file

* transform_openai_file_content_to_vertex_ai_file_content

* fix transform vertex gcs bucket upload to OAI files format

* working e2e test

* _get_gcs_object_name

* fix linting

* add doc string

* fix transform_gcs_bucket_response_to_openai_file_object

* use vertex for batch endpoints

* add batches support for vertex

* test_vertex_batches_endpoint

* test_vertex_batch_prediction

* fix gcs bucket base auth

* docs clean up batches

* docs Batch API

* docs vertex batches api

* test_get_gcs_logging_config_without_service_account

* undo change

* fix vertex md

* test_get_gcs_logging_config_without_service_account

* ci/cd run again
This commit is contained in:
Ishaan Jaff 2024-12-04 19:40:28 -08:00 committed by GitHub
parent dd5ccdd889
commit 0eef9df396
20 changed files with 1347 additions and 424 deletions

View file

@ -22,6 +22,9 @@ import litellm
from litellm import client
from litellm.llms.AzureOpenAI.azure import AzureBatchesAPI
from litellm.llms.OpenAI.openai import OpenAIBatchesAPI
from litellm.llms.vertex_ai_and_google_ai_studio.batches.handler import (
VertexAIBatchPrediction,
)
from litellm.secret_managers.main import get_secret, get_secret_str
from litellm.types.llms.openai import (
Batch,
@ -40,6 +43,7 @@ from litellm.utils import supports_httpx_timeout
####### ENVIRONMENT VARIABLES ###################
openai_batches_instance = OpenAIBatchesAPI()
azure_batches_instance = AzureBatchesAPI()
vertex_ai_batches_instance = VertexAIBatchPrediction(gcs_bucket_name="")
#################################################
@ -47,7 +51,7 @@ async def acreate_batch(
completion_window: Literal["24h"],
endpoint: Literal["/v1/chat/completions", "/v1/embeddings", "/v1/completions"],
input_file_id: str,
custom_llm_provider: Literal["openai", "azure"] = "openai",
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
metadata: Optional[Dict[str, str]] = None,
extra_headers: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, str]] = None,
@ -93,7 +97,7 @@ def create_batch(
completion_window: Literal["24h"],
endpoint: Literal["/v1/chat/completions", "/v1/embeddings", "/v1/completions"],
input_file_id: str,
custom_llm_provider: Literal["openai", "azure"] = "openai",
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
metadata: Optional[Dict[str, str]] = None,
extra_headers: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, str]] = None,
@ -199,6 +203,32 @@ def create_batch(
max_retries=optional_params.max_retries,
create_batch_data=_create_batch_request,
)
elif custom_llm_provider == "vertex_ai":
api_base = optional_params.api_base or ""
vertex_ai_project = (
optional_params.vertex_project
or litellm.vertex_project
or get_secret_str("VERTEXAI_PROJECT")
)
vertex_ai_location = (
optional_params.vertex_location
or litellm.vertex_location
or get_secret_str("VERTEXAI_LOCATION")
)
vertex_credentials = optional_params.vertex_credentials or get_secret_str(
"VERTEXAI_CREDENTIALS"
)
response = vertex_ai_batches_instance.create_batch(
_is_async=_is_async,
api_base=api_base,
vertex_project=vertex_ai_project,
vertex_location=vertex_ai_location,
vertex_credentials=vertex_credentials,
timeout=timeout,
max_retries=optional_params.max_retries,
create_batch_data=_create_batch_request,
)
else:
raise litellm.exceptions.BadRequestError(
message="LiteLLM doesn't support {} for 'create_batch'. Only 'openai' is supported.".format(