diff --git a/docs/my-website/docs/providers/vertex.md b/docs/my-website/docs/providers/vertex.md index aa551d711c..cb8c031c06 100644 --- a/docs/my-website/docs/providers/vertex.md +++ b/docs/my-website/docs/providers/vertex.md @@ -2480,7 +2480,7 @@ create_batch_response = oai_client.batches.create( ```json { - "id": "projects/633608382793/locations/us-central1/batchPredictionJobs/986266568679751680", + "id": "3814889423749775360", "completion_window": "24hrs", "created_at": 1733392026, "endpoint": "", @@ -2503,6 +2503,43 @@ create_batch_response = oai_client.batches.create( } ``` +#### 4. Retrieve a batch + +```python +retrieved_batch = oai_client.batches.retrieve( + batch_id=create_batch_response.id, + extra_body={"custom_llm_provider": "vertex_ai"}, # tell litellm to use `vertex_ai` for this batch request +) +``` + +**Expected Response** + +```json +{ + "id": "3814889423749775360", + "completion_window": "24hrs", + "created_at": 1736500100, + "endpoint": "", + "input_file_id": "gs://example-bucket-1-litellm/litellm-vertex-files/publishers/google/models/gemini-1.5-flash-001/7b2e47f5-3dd4-436d-920f-f9155bbdc952", + "object": "batch", + "status": "completed", + "cancelled_at": null, + "cancelling_at": null, + "completed_at": null, + "error_file_id": null, + "errors": null, + "expired_at": null, + "expires_at": null, + "failed_at": null, + "finalizing_at": null, + "in_progress_at": null, + "metadata": null, + "output_file_id": "gs://example-bucket-1-litellm/litellm-vertex-files/publishers/google/models/gemini-1.5-flash-001", + "request_counts": null +} +``` + + ## **Fine Tuning APIs** diff --git a/litellm/batches/main.py b/litellm/batches/main.py index c7e524f2b0..32428c9c18 100644 --- a/litellm/batches/main.py +++ b/litellm/batches/main.py @@ -416,6 +416,32 @@ def retrieve_batch( max_retries=optional_params.max_retries, retrieve_batch_data=_retrieve_batch_request, ) + elif custom_llm_provider == "vertex_ai": + api_base = optional_params.api_base or "" + vertex_ai_project = ( + optional_params.vertex_project + or litellm.vertex_project + or get_secret_str("VERTEXAI_PROJECT") + ) + vertex_ai_location = ( + optional_params.vertex_location + or litellm.vertex_location + or get_secret_str("VERTEXAI_LOCATION") + ) + vertex_credentials = optional_params.vertex_credentials or get_secret_str( + "VERTEXAI_CREDENTIALS" + ) + + response = vertex_ai_batches_instance.retrieve_batch( + _is_async=_is_async, + batch_id=batch_id, + api_base=api_base, + vertex_project=vertex_ai_project, + vertex_location=vertex_ai_location, + vertex_credentials=vertex_credentials, + timeout=timeout, + max_retries=optional_params.max_retries, + ) else: raise litellm.exceptions.BadRequestError( message="LiteLLM doesn't support {} for 'create_batch'. 
Only 'openai' is supported.".format( diff --git a/litellm/llms/vertex_ai/batches/handler.py b/litellm/llms/vertex_ai/batches/handler.py index 06b2fd6f9d..0274cd5b05 100644 --- a/litellm/llms/vertex_ai/batches/handler.py +++ b/litellm/llms/vertex_ai/batches/handler.py @@ -124,3 +124,91 @@ class VertexAIBatchPrediction(VertexLLM): """Return the base url for the vertex garden models""" # POST https://LOCATION-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/LOCATION/batchPredictionJobs return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/batchPredictionJobs" + + def retrieve_batch( + self, + _is_async: bool, + batch_id: str, + api_base: Optional[str], + vertex_credentials: Optional[str], + vertex_project: Optional[str], + vertex_location: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + ) -> Union[Batch, Coroutine[Any, Any, Batch]]: + sync_handler = _get_httpx_client() + + access_token, project_id = self._ensure_access_token( + credentials=vertex_credentials, + project_id=vertex_project, + custom_llm_provider="vertex_ai", + ) + + default_api_base = self.create_vertex_url( + vertex_location=vertex_location or "us-central1", + vertex_project=vertex_project or project_id, + ) + + # Append batch_id to the URL + default_api_base = f"{default_api_base}/{batch_id}" + + if len(default_api_base.split(":")) > 1: + endpoint = default_api_base.split(":")[-1] + else: + endpoint = "" + + _, api_base = self._check_custom_proxy( + api_base=api_base, + custom_llm_provider="vertex_ai", + gemini_api_key=None, + endpoint=endpoint, + stream=None, + auth_header=None, + url=default_api_base, + ) + + headers = { + "Content-Type": "application/json; charset=utf-8", + "Authorization": f"Bearer {access_token}", + } + + if _is_async is True: + return self._async_retrieve_batch( + api_base=api_base, + headers=headers, + ) + + response = sync_handler.get( + url=api_base, + headers=headers, + ) + + if response.status_code != 200: + raise Exception(f"Error: {response.status_code} {response.text}") + + _json_response = response.json() + vertex_batch_response = VertexAIBatchTransformation.transform_vertex_ai_batch_response_to_openai_batch_response( + response=_json_response + ) + return vertex_batch_response + + async def _async_retrieve_batch( + self, + api_base: str, + headers: Dict[str, str], + ) -> Batch: + client = get_async_httpx_client( + llm_provider=litellm.LlmProviders.VERTEX_AI, + ) + response = await client.get( + url=api_base, + headers=headers, + ) + if response.status_code != 200: + raise Exception(f"Error: {response.status_code} {response.text}") + + _json_response = response.json() + vertex_batch_response = VertexAIBatchTransformation.transform_vertex_ai_batch_response_to_openai_batch_response( + response=_json_response + ) + return vertex_batch_response diff --git a/litellm/llms/vertex_ai/batches/transformation.py b/litellm/llms/vertex_ai/batches/transformation.py index c18bbe4292..32cabdcf56 100644 --- a/litellm/llms/vertex_ai/batches/transformation.py +++ b/litellm/llms/vertex_ai/batches/transformation.py @@ -49,7 +49,7 @@ class VertexAIBatchTransformation: cls, response: VertexBatchPredictionResponse ) -> Batch: return Batch( - id=response.get("name", ""), + id=cls._get_batch_id_from_vertex_ai_batch_response(response), completion_window="24hrs", created_at=_convert_vertex_datetime_to_openai_datetime( vertex_datetime=response.get("createTime", "") @@ -66,6 +66,24 @@ class 
VertexAIBatchTransformation: ), ) + @classmethod + def _get_batch_id_from_vertex_ai_batch_response( + cls, response: VertexBatchPredictionResponse + ) -> str: + """ + Gets the batch id from the Vertex AI Batch response safely + + vertex response: `projects/510528649030/locations/us-central1/batchPredictionJobs/3814889423749775360` + returns: `3814889423749775360` + """ + _name = response.get("name", "") + if not _name: + return "" + + # Split by '/' and get the last part if it exists + parts = _name.split("/") + return parts[-1] if parts else _name + @classmethod def _get_input_file_id_from_vertex_ai_batch_response( cls, response: VertexBatchPredictionResponse diff --git a/tests/batches_tests/test_openai_batches_and_files.py b/tests/batches_tests/test_openai_batches_and_files.py index 14339106a1..c1ab45d6ae 100644 --- a/tests/batches_tests/test_openai_batches_and_files.py +++ b/tests/batches_tests/test_openai_batches_and_files.py @@ -393,4 +393,10 @@ async def test_avertex_batch_prediction(): metadata={"key1": "value1", "key2": "value2"}, ) print("create_batch_response=", create_batch_response) + + retrieved_batch = await litellm.aretrieve_batch( + batch_id=create_batch_response.id, + custom_llm_provider="vertex_ai", + ) + print("retrieved_batch=", retrieved_batch) pass
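
End-to-end, the path added in this diff can be exercised directly through the LiteLLM SDK, mirroring the call added in `test_avertex_batch_prediction` above. The sketch below is illustrative only (not part of the diff): it assumes Vertex credentials are supplied via the `VERTEXAI_PROJECT`, `VERTEXAI_LOCATION`, and `VERTEXAI_CREDENTIALS` environment variables (the fallbacks read in `retrieve_batch`), and uses a placeholder batch id in the numeric form now returned by the transformation layer.

```python
# Minimal usage sketch, assuming VERTEXAI_PROJECT / VERTEXAI_LOCATION /
# VERTEXAI_CREDENTIALS are set in the environment. The batch id below is a
# placeholder for the numeric id returned by batches.create after this change.
import asyncio

import litellm


async def main() -> None:
    retrieved_batch = await litellm.aretrieve_batch(
        batch_id="3814889423749775360",  # numeric job id, not the full resource name
        custom_llm_provider="vertex_ai",
    )
    # The handler appends this id to .../batchPredictionJobs/ and maps the Vertex
    # job into an OpenAI-style Batch object, so the standard fields are available.
    print(retrieved_batch.id, retrieved_batch.status)


if __name__ == "__main__":
    asyncio.run(main())
```

Returning only the trailing segment of `name` from the transformation keeps the OpenAI-compatible `id` round-trippable: `retrieve_batch` can append it directly to the `batchPredictionJobs` URL rather than re-parsing the full `projects/.../locations/.../batchPredictionJobs/...` resource path.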