mirror of https://github.com/BerriAI/litellm.git
(Feat - Batches API) add support for retrieving vertex api batch jobs (#7661)
* add _async_retrieve_batch
* fix aretrieve_batch
* fix _get_batch_id_from_vertex_ai_batch_response
* fix batches docs
parent 2507c275f6
commit 13f364682d

5 changed files with 177 additions and 2 deletions
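In short, this change lets an existing Vertex AI batch prediction job be retrieved through litellm's unified, OpenAI-compatible Batches API. A minimal sketch of the new capability, assuming Vertex project, location, and credentials are configured via the `VERTEXAI_*` environment variables read by the code below (the batch id is the placeholder from the docs example):

```python
import asyncio

import litellm

async def main() -> None:
    # Retrieve an existing Vertex AI batch prediction job through the
    # unified Batches API (the capability this commit adds).
    retrieved_batch = await litellm.aretrieve_batch(
        batch_id="3814889423749775360",  # placeholder id from the docs example
        custom_llm_provider="vertex_ai",
    )
    print("status:", retrieved_batch.status)  # e.g. "completed"

asyncio.run(main())
```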
@@ -2480,7 +2480,7 @@ create_batch_response = oai_client.batches.create(
```json
{
-   "id": "projects/633608382793/locations/us-central1/batchPredictionJobs/986266568679751680",
+   "id": "3814889423749775360",
    "completion_window": "24hrs",
    "created_at": 1733392026,
    "endpoint": "",
@@ -2503,6 +2503,43 @@ create_batch_response = oai_client.batches.create(
}
```

#### 4. Retrieve a batch

```python
retrieved_batch = oai_client.batches.retrieve(
    batch_id=create_batch_response.id,
    extra_body={"custom_llm_provider": "vertex_ai"},  # tell litellm to use `vertex_ai` for this batch request
)
```

**Expected Response**

```json
{
    "id": "3814889423749775360",
    "completion_window": "24hrs",
    "created_at": 1736500100,
    "endpoint": "",
    "input_file_id": "gs://example-bucket-1-litellm/litellm-vertex-files/publishers/google/models/gemini-1.5-flash-001/7b2e47f5-3dd4-436d-920f-f9155bbdc952",
    "object": "batch",
    "status": "completed",
    "cancelled_at": null,
    "cancelling_at": null,
    "completed_at": null,
    "error_file_id": null,
    "errors": null,
    "expired_at": null,
    "expires_at": null,
    "failed_at": null,
    "finalizing_at": null,
    "in_progress_at": null,
    "metadata": null,
    "output_file_id": "gs://example-bucket-1-litellm/litellm-vertex-files/publishers/google/models/gemini-1.5-flash-001",
    "request_counts": null
}
```

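For callers using the litellm SDK directly rather than an OpenAI client pointed at the proxy, the equivalent retrieval (mirroring the updated test at the bottom of this diff) is a sketch like:

```python
import litellm

# Synchronous variant; the async test below uses litellm.aretrieve_batch.
retrieved_batch = litellm.retrieve_batch(
    batch_id=create_batch_response.id,
    custom_llm_provider="vertex_ai",
)
```
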
## **Fine Tuning APIs**
@@ -416,6 +416,32 @@ def retrieve_batch(
            max_retries=optional_params.max_retries,
            retrieve_batch_data=_retrieve_batch_request,
        )
    elif custom_llm_provider == "vertex_ai":
        api_base = optional_params.api_base or ""
        vertex_ai_project = (
            optional_params.vertex_project
            or litellm.vertex_project
            or get_secret_str("VERTEXAI_PROJECT")
        )
        vertex_ai_location = (
            optional_params.vertex_location
            or litellm.vertex_location
            or get_secret_str("VERTEXAI_LOCATION")
        )
        vertex_credentials = optional_params.vertex_credentials or get_secret_str(
            "VERTEXAI_CREDENTIALS"
        )

        response = vertex_ai_batches_instance.retrieve_batch(
            _is_async=_is_async,
            batch_id=batch_id,
            api_base=api_base,
            vertex_project=vertex_ai_project,
            vertex_location=vertex_ai_location,
            vertex_credentials=vertex_credentials,
            timeout=timeout,
            max_retries=optional_params.max_retries,
        )
    else:
        raise litellm.exceptions.BadRequestError(
            message="LiteLLM doesn't support {} for 'retrieve_batch'. Only 'openai' is supported.".format(
@@ -124,3 +124,91 @@ class VertexAIBatchPrediction(VertexLLM):
        """Return the base url for the vertex garden models"""
        # POST https://LOCATION-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/LOCATION/batchPredictionJobs
        return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/batchPredictionJobs"

    def retrieve_batch(
        self,
        _is_async: bool,
        batch_id: str,
        api_base: Optional[str],
        vertex_credentials: Optional[str],
        vertex_project: Optional[str],
        vertex_location: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
    ) -> Union[Batch, Coroutine[Any, Any, Batch]]:
        sync_handler = _get_httpx_client()

        access_token, project_id = self._ensure_access_token(
            credentials=vertex_credentials,
            project_id=vertex_project,
            custom_llm_provider="vertex_ai",
        )

        default_api_base = self.create_vertex_url(
            vertex_location=vertex_location or "us-central1",
            vertex_project=vertex_project or project_id,
        )

        # Append batch_id to the URL
        default_api_base = f"{default_api_base}/{batch_id}"

        if len(default_api_base.split(":")) > 1:
            endpoint = default_api_base.split(":")[-1]
        else:
            endpoint = ""

        _, api_base = self._check_custom_proxy(
            api_base=api_base,
            custom_llm_provider="vertex_ai",
            gemini_api_key=None,
            endpoint=endpoint,
            stream=None,
            auth_header=None,
            url=default_api_base,
        )

        headers = {
            "Content-Type": "application/json; charset=utf-8",
            "Authorization": f"Bearer {access_token}",
        }

        if _is_async is True:
            return self._async_retrieve_batch(
                api_base=api_base,
                headers=headers,
            )

        response = sync_handler.get(
            url=api_base,
            headers=headers,
        )

        if response.status_code != 200:
            raise Exception(f"Error: {response.status_code} {response.text}")

        _json_response = response.json()
        vertex_batch_response = VertexAIBatchTransformation.transform_vertex_ai_batch_response_to_openai_batch_response(
            response=_json_response
        )
        return vertex_batch_response

    async def _async_retrieve_batch(
        self,
        api_base: str,
        headers: Dict[str, str],
    ) -> Batch:
        client = get_async_httpx_client(
            llm_provider=litellm.LlmProviders.VERTEX_AI,
        )
        response = await client.get(
            url=api_base,
            headers=headers,
        )
        if response.status_code != 200:
            raise Exception(f"Error: {response.status_code} {response.text}")

        _json_response = response.json()
        vertex_batch_response = VertexAIBatchTransformation.transform_vertex_ai_batch_response_to_openai_batch_response(
            response=_json_response
        )
        return vertex_batch_response
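
Combining `create_vertex_url` with the appended `batch_id`, the retrieval issues a GET against the standard Vertex AI `batchPredictionJobs` resource URL. A quick sketch of the assembled URL (project and location values are placeholders):

```python
vertex_location = "us-central1"   # placeholder
vertex_project = "633608382793"   # placeholder
batch_id = "3814889423749775360"  # placeholder from the docs example

# Mirrors create_vertex_url plus the batch_id suffix added in retrieve_batch.
url = (
    f"https://{vertex_location}-aiplatform.googleapis.com/v1"
    f"/projects/{vertex_project}/locations/{vertex_location}"
    f"/batchPredictionJobs/{batch_id}"
)
print(url)
```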
@@ -49,7 +49,7 @@ class VertexAIBatchTransformation:
        cls, response: VertexBatchPredictionResponse
    ) -> Batch:
        return Batch(
-            id=response.get("name", ""),
+            id=cls._get_batch_id_from_vertex_ai_batch_response(response),
            completion_window="24hrs",
            created_at=_convert_vertex_datetime_to_openai_datetime(
                vertex_datetime=response.get("createTime", "")
@@ -66,6 +66,24 @@ class VertexAIBatchTransformation:
            ),
        )

    @classmethod
    def _get_batch_id_from_vertex_ai_batch_response(
        cls, response: VertexBatchPredictionResponse
    ) -> str:
        """
        Gets the batch id from the Vertex AI batch response safely.

        vertex response: `projects/510528649030/locations/us-central1/batchPredictionJobs/3814889423749775360`
        returns: `3814889423749775360`
        """
        _name = response.get("name", "")
        if not _name:
            return ""

        # Split by '/' and get the last part if it exists
        parts = _name.split("/")
        return parts[-1] if parts else _name

    @classmethod
    def _get_input_file_id_from_vertex_ai_batch_response(
        cls, response: VertexBatchPredictionResponse
@@ -393,4 +393,10 @@ async def test_avertex_batch_prediction():
        metadata={"key1": "value1", "key2": "value2"},
    )
    print("create_batch_response=", create_batch_response)

    retrieved_batch = await litellm.aretrieve_batch(
        batch_id=create_batch_response.id,
        custom_llm_provider="vertex_ai",
    )
    print("retrieved_batch=", retrieved_batch)
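
This test drives the full create-then-retrieve loop against Vertex AI, so it needs real credentials; once those are configured it can be selected by name, e.g. `pytest -k test_avertex_batch_prediction`.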