mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 10:44:24 +00:00
(Feat - Batches API) add support for retrieving vertex api batch jobs (#7661)
* add _async_retrieve_batch * fix aretrieve_batch * fix _get_batch_id_from_vertex_ai_batch_response * fix batches docs
This commit is contained in:
parent
2507c275f6
commit
13f364682d
5 changed files with 177 additions and 2 deletions
|
@ -2480,7 +2480,7 @@ create_batch_response = oai_client.batches.create(
|
||||||
|
|
||||||
```json
|
```json
|
||||||
{
|
{
|
||||||
"id": "projects/633608382793/locations/us-central1/batchPredictionJobs/986266568679751680",
|
"id": "3814889423749775360",
|
||||||
"completion_window": "24hrs",
|
"completion_window": "24hrs",
|
||||||
"created_at": 1733392026,
|
"created_at": 1733392026,
|
||||||
"endpoint": "",
|
"endpoint": "",
|
||||||
|
@ -2503,6 +2503,43 @@ create_batch_response = oai_client.batches.create(
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
#### 4. Retrieve a batch
|
||||||
|
|
||||||
|
```python
|
||||||
|
retrieved_batch = oai_client.batches.retrieve(
|
||||||
|
batch_id=create_batch_response.id,
|
||||||
|
extra_body={"custom_llm_provider": "vertex_ai"}, # tell litellm to use `vertex_ai` for this batch request
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Expected Response**
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"id": "3814889423749775360",
|
||||||
|
"completion_window": "24hrs",
|
||||||
|
"created_at": 1736500100,
|
||||||
|
"endpoint": "",
|
||||||
|
"input_file_id": "gs://example-bucket-1-litellm/litellm-vertex-files/publishers/google/models/gemini-1.5-flash-001/7b2e47f5-3dd4-436d-920f-f9155bbdc952",
|
||||||
|
"object": "batch",
|
||||||
|
"status": "completed",
|
||||||
|
"cancelled_at": null,
|
||||||
|
"cancelling_at": null,
|
||||||
|
"completed_at": null,
|
||||||
|
"error_file_id": null,
|
||||||
|
"errors": null,
|
||||||
|
"expired_at": null,
|
||||||
|
"expires_at": null,
|
||||||
|
"failed_at": null,
|
||||||
|
"finalizing_at": null,
|
||||||
|
"in_progress_at": null,
|
||||||
|
"metadata": null,
|
||||||
|
"output_file_id": "gs://example-bucket-1-litellm/litellm-vertex-files/publishers/google/models/gemini-1.5-flash-001",
|
||||||
|
"request_counts": null
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
## **Fine Tuning APIs**
|
## **Fine Tuning APIs**
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -416,6 +416,32 @@ def retrieve_batch(
|
||||||
max_retries=optional_params.max_retries,
|
max_retries=optional_params.max_retries,
|
||||||
retrieve_batch_data=_retrieve_batch_request,
|
retrieve_batch_data=_retrieve_batch_request,
|
||||||
)
|
)
|
||||||
|
elif custom_llm_provider == "vertex_ai":
|
||||||
|
api_base = optional_params.api_base or ""
|
||||||
|
vertex_ai_project = (
|
||||||
|
optional_params.vertex_project
|
||||||
|
or litellm.vertex_project
|
||||||
|
or get_secret_str("VERTEXAI_PROJECT")
|
||||||
|
)
|
||||||
|
vertex_ai_location = (
|
||||||
|
optional_params.vertex_location
|
||||||
|
or litellm.vertex_location
|
||||||
|
or get_secret_str("VERTEXAI_LOCATION")
|
||||||
|
)
|
||||||
|
vertex_credentials = optional_params.vertex_credentials or get_secret_str(
|
||||||
|
"VERTEXAI_CREDENTIALS"
|
||||||
|
)
|
||||||
|
|
||||||
|
response = vertex_ai_batches_instance.retrieve_batch(
|
||||||
|
_is_async=_is_async,
|
||||||
|
batch_id=batch_id,
|
||||||
|
api_base=api_base,
|
||||||
|
vertex_project=vertex_ai_project,
|
||||||
|
vertex_location=vertex_ai_location,
|
||||||
|
vertex_credentials=vertex_credentials,
|
||||||
|
timeout=timeout,
|
||||||
|
max_retries=optional_params.max_retries,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise litellm.exceptions.BadRequestError(
|
raise litellm.exceptions.BadRequestError(
|
||||||
message="LiteLLM doesn't support {} for 'create_batch'. Only 'openai' is supported.".format(
|
message="LiteLLM doesn't support {} for 'create_batch'. Only 'openai' is supported.".format(
|
||||||
|
|
|
@ -124,3 +124,91 @@ class VertexAIBatchPrediction(VertexLLM):
|
||||||
"""Return the base url for the vertex garden models"""
|
"""Return the base url for the vertex garden models"""
|
||||||
# POST https://LOCATION-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/LOCATION/batchPredictionJobs
|
# POST https://LOCATION-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/LOCATION/batchPredictionJobs
|
||||||
return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/batchPredictionJobs"
|
return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/batchPredictionJobs"
|
||||||
|
|
||||||
|
def retrieve_batch(
|
||||||
|
self,
|
||||||
|
_is_async: bool,
|
||||||
|
batch_id: str,
|
||||||
|
api_base: Optional[str],
|
||||||
|
vertex_credentials: Optional[str],
|
||||||
|
vertex_project: Optional[str],
|
||||||
|
vertex_location: Optional[str],
|
||||||
|
timeout: Union[float, httpx.Timeout],
|
||||||
|
max_retries: Optional[int],
|
||||||
|
) -> Union[Batch, Coroutine[Any, Any, Batch]]:
|
||||||
|
sync_handler = _get_httpx_client()
|
||||||
|
|
||||||
|
access_token, project_id = self._ensure_access_token(
|
||||||
|
credentials=vertex_credentials,
|
||||||
|
project_id=vertex_project,
|
||||||
|
custom_llm_provider="vertex_ai",
|
||||||
|
)
|
||||||
|
|
||||||
|
default_api_base = self.create_vertex_url(
|
||||||
|
vertex_location=vertex_location or "us-central1",
|
||||||
|
vertex_project=vertex_project or project_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Append batch_id to the URL
|
||||||
|
default_api_base = f"{default_api_base}/{batch_id}"
|
||||||
|
|
||||||
|
if len(default_api_base.split(":")) > 1:
|
||||||
|
endpoint = default_api_base.split(":")[-1]
|
||||||
|
else:
|
||||||
|
endpoint = ""
|
||||||
|
|
||||||
|
_, api_base = self._check_custom_proxy(
|
||||||
|
api_base=api_base,
|
||||||
|
custom_llm_provider="vertex_ai",
|
||||||
|
gemini_api_key=None,
|
||||||
|
endpoint=endpoint,
|
||||||
|
stream=None,
|
||||||
|
auth_header=None,
|
||||||
|
url=default_api_base,
|
||||||
|
)
|
||||||
|
|
||||||
|
headers = {
|
||||||
|
"Content-Type": "application/json; charset=utf-8",
|
||||||
|
"Authorization": f"Bearer {access_token}",
|
||||||
|
}
|
||||||
|
|
||||||
|
if _is_async is True:
|
||||||
|
return self._async_retrieve_batch(
|
||||||
|
api_base=api_base,
|
||||||
|
headers=headers,
|
||||||
|
)
|
||||||
|
|
||||||
|
response = sync_handler.get(
|
||||||
|
url=api_base,
|
||||||
|
headers=headers,
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
raise Exception(f"Error: {response.status_code} {response.text}")
|
||||||
|
|
||||||
|
_json_response = response.json()
|
||||||
|
vertex_batch_response = VertexAIBatchTransformation.transform_vertex_ai_batch_response_to_openai_batch_response(
|
||||||
|
response=_json_response
|
||||||
|
)
|
||||||
|
return vertex_batch_response
|
||||||
|
|
||||||
|
async def _async_retrieve_batch(
|
||||||
|
self,
|
||||||
|
api_base: str,
|
||||||
|
headers: Dict[str, str],
|
||||||
|
) -> Batch:
|
||||||
|
client = get_async_httpx_client(
|
||||||
|
llm_provider=litellm.LlmProviders.VERTEX_AI,
|
||||||
|
)
|
||||||
|
response = await client.get(
|
||||||
|
url=api_base,
|
||||||
|
headers=headers,
|
||||||
|
)
|
||||||
|
if response.status_code != 200:
|
||||||
|
raise Exception(f"Error: {response.status_code} {response.text}")
|
||||||
|
|
||||||
|
_json_response = response.json()
|
||||||
|
vertex_batch_response = VertexAIBatchTransformation.transform_vertex_ai_batch_response_to_openai_batch_response(
|
||||||
|
response=_json_response
|
||||||
|
)
|
||||||
|
return vertex_batch_response
|
||||||
|
|
|
@ -49,7 +49,7 @@ class VertexAIBatchTransformation:
|
||||||
cls, response: VertexBatchPredictionResponse
|
cls, response: VertexBatchPredictionResponse
|
||||||
) -> Batch:
|
) -> Batch:
|
||||||
return Batch(
|
return Batch(
|
||||||
id=response.get("name", ""),
|
id=cls._get_batch_id_from_vertex_ai_batch_response(response),
|
||||||
completion_window="24hrs",
|
completion_window="24hrs",
|
||||||
created_at=_convert_vertex_datetime_to_openai_datetime(
|
created_at=_convert_vertex_datetime_to_openai_datetime(
|
||||||
vertex_datetime=response.get("createTime", "")
|
vertex_datetime=response.get("createTime", "")
|
||||||
|
@ -66,6 +66,24 @@ class VertexAIBatchTransformation:
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _get_batch_id_from_vertex_ai_batch_response(
|
||||||
|
cls, response: VertexBatchPredictionResponse
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Gets the batch id from the Vertex AI Batch response safely
|
||||||
|
|
||||||
|
vertex response: `projects/510528649030/locations/us-central1/batchPredictionJobs/3814889423749775360`
|
||||||
|
returns: `3814889423749775360`
|
||||||
|
"""
|
||||||
|
_name = response.get("name", "")
|
||||||
|
if not _name:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
# Split by '/' and get the last part if it exists
|
||||||
|
parts = _name.split("/")
|
||||||
|
return parts[-1] if parts else _name
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _get_input_file_id_from_vertex_ai_batch_response(
|
def _get_input_file_id_from_vertex_ai_batch_response(
|
||||||
cls, response: VertexBatchPredictionResponse
|
cls, response: VertexBatchPredictionResponse
|
||||||
|
|
|
@ -393,4 +393,10 @@ async def test_avertex_batch_prediction():
|
||||||
metadata={"key1": "value1", "key2": "value2"},
|
metadata={"key1": "value1", "key2": "value2"},
|
||||||
)
|
)
|
||||||
print("create_batch_response=", create_batch_response)
|
print("create_batch_response=", create_batch_response)
|
||||||
|
|
||||||
|
retrieved_batch = await litellm.aretrieve_batch(
|
||||||
|
batch_id=create_batch_response.id,
|
||||||
|
custom_llm_provider="vertex_ai",
|
||||||
|
)
|
||||||
|
print("retrieved_batch=", retrieved_batch)
|
||||||
pass
|
pass
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue