(Feat - Batches API) add support for retrieving vertex api batch jobs (#7661)

* add _async_retrieve_batch

* fix aretrieve_batch

* fix _get_batch_id_from_vertex_ai_batch_response

* fix batches docs
Ishaan Jaff 2025-01-09 18:35:03 -08:00 committed by GitHub
parent 2507c275f6
commit 13f364682d
5 changed files with 177 additions and 2 deletions


@@ -2480,7 +2480,7 @@ create_batch_response = oai_client.batches.create(
```json
{
-    "id": "projects/633608382793/locations/us-central1/batchPredictionJobs/986266568679751680",
+    "id": "3814889423749775360",
    "completion_window": "24hrs",
    "created_at": 1733392026,
    "endpoint": "",
@@ -2503,6 +2503,43 @@ create_batch_response = oai_client.batches.create(
}
```
#### 4. Retrieve a batch

```python
retrieved_batch = oai_client.batches.retrieve(
    batch_id=create_batch_response.id,
    extra_body={"custom_llm_provider": "vertex_ai"}, # tell litellm to use `vertex_ai` for this batch request
)
```

**Expected Response**

```json
{
    "id": "3814889423749775360",
    "completion_window": "24hrs",
    "created_at": 1736500100,
    "endpoint": "",
    "input_file_id": "gs://example-bucket-1-litellm/litellm-vertex-files/publishers/google/models/gemini-1.5-flash-001/7b2e47f5-3dd4-436d-920f-f9155bbdc952",
    "object": "batch",
    "status": "completed",
    "cancelled_at": null,
    "cancelling_at": null,
    "completed_at": null,
    "error_file_id": null,
    "errors": null,
    "expired_at": null,
    "expires_at": null,
    "failed_at": null,
    "finalizing_at": null,
    "in_progress_at": null,
    "metadata": null,
    "output_file_id": "gs://example-bucket-1-litellm/litellm-vertex-files/publishers/google/models/gemini-1.5-flash-001",
    "request_counts": null
}
```
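Since a Vertex batch job usually takes a while to finish, the retrieve call from step 4 can be wrapped in a simple polling loop. This is a sketch under the same setup as above; the terminal status names and the sleep interval are assumptions, not part of the committed docs:

```python
import time

# Sketch only: poll the batch until it reaches a terminal state.
# Status names follow the OpenAI Batch object; the 30s interval is arbitrary.
while True:
    retrieved_batch = oai_client.batches.retrieve(
        batch_id=create_batch_response.id,
        extra_body={"custom_llm_provider": "vertex_ai"},
    )
    if retrieved_batch.status in ("completed", "failed", "cancelled", "expired"):
        break
    time.sleep(30)
```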
## **Fine Tuning APIs**


@@ -416,6 +416,32 @@ def retrieve_batch(
                max_retries=optional_params.max_retries,
                retrieve_batch_data=_retrieve_batch_request,
            )
        elif custom_llm_provider == "vertex_ai":
            api_base = optional_params.api_base or ""
            vertex_ai_project = (
                optional_params.vertex_project
                or litellm.vertex_project
                or get_secret_str("VERTEXAI_PROJECT")
            )
            vertex_ai_location = (
                optional_params.vertex_location
                or litellm.vertex_location
                or get_secret_str("VERTEXAI_LOCATION")
            )
            vertex_credentials = optional_params.vertex_credentials or get_secret_str(
                "VERTEXAI_CREDENTIALS"
            )
            response = vertex_ai_batches_instance.retrieve_batch(
                _is_async=_is_async,
                batch_id=batch_id,
                api_base=api_base,
                vertex_project=vertex_ai_project,
                vertex_location=vertex_ai_location,
                vertex_credentials=vertex_credentials,
                timeout=timeout,
                max_retries=optional_params.max_retries,
            )
        else:
            raise litellm.exceptions.BadRequestError(
                message="LiteLLM doesn't support {} for 'create_batch'. Only 'openai' is supported.".format(


@@ -124,3 +124,91 @@ class VertexAIBatchPrediction(VertexLLM):
        """Return the base url for the vertex garden models"""
        # POST https://LOCATION-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/LOCATION/batchPredictionJobs
        return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/batchPredictionJobs"

    def retrieve_batch(
        self,
        _is_async: bool,
        batch_id: str,
        api_base: Optional[str],
        vertex_credentials: Optional[str],
        vertex_project: Optional[str],
        vertex_location: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
    ) -> Union[Batch, Coroutine[Any, Any, Batch]]:
        sync_handler = _get_httpx_client()

        access_token, project_id = self._ensure_access_token(
            credentials=vertex_credentials,
            project_id=vertex_project,
            custom_llm_provider="vertex_ai",
        )

        default_api_base = self.create_vertex_url(
            vertex_location=vertex_location or "us-central1",
            vertex_project=vertex_project or project_id,
        )

        # Append batch_id to the URL
        default_api_base = f"{default_api_base}/{batch_id}"

        if len(default_api_base.split(":")) > 1:
            endpoint = default_api_base.split(":")[-1]
        else:
            endpoint = ""

        _, api_base = self._check_custom_proxy(
            api_base=api_base,
            custom_llm_provider="vertex_ai",
            gemini_api_key=None,
            endpoint=endpoint,
            stream=None,
            auth_header=None,
            url=default_api_base,
        )

        headers = {
            "Content-Type": "application/json; charset=utf-8",
            "Authorization": f"Bearer {access_token}",
        }

        if _is_async is True:
            return self._async_retrieve_batch(
                api_base=api_base,
                headers=headers,
            )

        response = sync_handler.get(
            url=api_base,
            headers=headers,
        )

        if response.status_code != 200:
            raise Exception(f"Error: {response.status_code} {response.text}")

        _json_response = response.json()
        vertex_batch_response = VertexAIBatchTransformation.transform_vertex_ai_batch_response_to_openai_batch_response(
            response=_json_response
        )
        return vertex_batch_response

    async def _async_retrieve_batch(
        self,
        api_base: str,
        headers: Dict[str, str],
    ) -> Batch:
        client = get_async_httpx_client(
            llm_provider=litellm.LlmProviders.VERTEX_AI,
        )
        response = await client.get(
            url=api_base,
            headers=headers,
        )
        if response.status_code != 200:
            raise Exception(f"Error: {response.status_code} {response.text}")
        _json_response = response.json()
        vertex_batch_response = VertexAIBatchTransformation.transform_vertex_ai_batch_response_to_openai_batch_response(
            response=_json_response
        )
        return vertex_batch_response
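The GET issued above targets the job-specific form of the batchPredictionJobs URL built by `create_vertex_url`, with the batch id appended. A small sketch of the resulting URL, using placeholder project, location, and id values:

```python
# Placeholder values for illustration only.
vertex_location = "us-central1"
vertex_project = "633608382793"
batch_id = "3814889423749775360"

url = (
    f"https://{vertex_location}-aiplatform.googleapis.com/v1"
    f"/projects/{vertex_project}/locations/{vertex_location}"
    f"/batchPredictionJobs/{batch_id}"
)
print(url)
# https://us-central1-aiplatform.googleapis.com/v1/projects/633608382793/locations/us-central1/batchPredictionJobs/3814889423749775360
```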


@@ -49,7 +49,7 @@ class VertexAIBatchTransformation:
        cls, response: VertexBatchPredictionResponse
    ) -> Batch:
        return Batch(
-            id=response.get("name", ""),
+            id=cls._get_batch_id_from_vertex_ai_batch_response(response),
            completion_window="24hrs",
            created_at=_convert_vertex_datetime_to_openai_datetime(
                vertex_datetime=response.get("createTime", "")
@@ -66,6 +66,24 @@
            ),
        )

    @classmethod
    def _get_batch_id_from_vertex_ai_batch_response(
        cls, response: VertexBatchPredictionResponse
    ) -> str:
        """
        Gets the batch id from the Vertex AI Batch response safely

        vertex response: `projects/510528649030/locations/us-central1/batchPredictionJobs/3814889423749775360`
        returns: `3814889423749775360`
        """
        _name = response.get("name", "")
        if not _name:
            return ""

        # Split by '/' and get the last part if it exists
        parts = _name.split("/")
        return parts[-1] if parts else _name

    @classmethod
    def _get_input_file_id_from_vertex_ai_batch_response(
        cls, response: VertexBatchPredictionResponse
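To make the new helper's behavior concrete, here is a standalone sketch that mirrors its logic on trimmed, hypothetical Vertex responses (not part of the committed code):

```python
# Hypothetical inputs; the real helper receives a VertexBatchPredictionResponse.
full = {"name": "projects/510528649030/locations/us-central1/batchPredictionJobs/3814889423749775360"}
empty: dict = {}

def batch_id_from(resp: dict) -> str:
    # Mirrors _get_batch_id_from_vertex_ai_batch_response: last path segment of `name`, or "".
    _name = resp.get("name", "")
    return _name.split("/")[-1] if _name else ""

print(batch_id_from(full))   # 3814889423749775360
print(batch_id_from(empty))  # "" (safe fallback when `name` is missing)
```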


@@ -393,4 +393,10 @@ async def test_avertex_batch_prediction():
        metadata={"key1": "value1", "key2": "value2"},
    )
    print("create_batch_response=", create_batch_response)

    retrieved_batch = await litellm.aretrieve_batch(
        batch_id=create_batch_response.id,
        custom_llm_provider="vertex_ai",
    )
    print("retrieved_batch=", retrieved_batch)

    pass
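If the test is later tightened beyond printing, a couple of structural assertions could follow the retrieve call inside the same test function. This is a sketch only; exact field values depend on the live Vertex job:

```python
    # Sketch of possible follow-up assertions (not part of the committed test).
    # The retrieved id should match the created job's id now that both go through
    # _get_batch_id_from_vertex_ai_batch_response.
    assert retrieved_batch.id == create_batch_response.id
    assert retrieved_batch.object == "batch"
```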