import uuid
from typing import Dict

from litellm.llms.vertex_ai.common_utils import (
    _convert_vertex_datetime_to_openai_datetime,
)
from litellm.types.llms.openai import Batch, BatchJobStatus, CreateBatchRequest
from litellm.types.llms.vertex_ai import *


class VertexAIBatchTransformation:
    """
    Transforms OpenAI Batch requests to Vertex AI Batch requests

    API Ref: https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/batch-prediction-gemini
    """

    @classmethod
    def transform_openai_batch_request_to_vertex_ai_batch_request(
        cls,
        request: CreateBatchRequest,
    ) -> VertexAIBatchPredictionJob:
        """
        Transforms OpenAI Batch requests to Vertex AI Batch requests
        """
        request_display_name = f"litellm-vertex-batch-{uuid.uuid4()}"
        input_file_id = request.get("input_file_id")
        if input_file_id is None:
            raise ValueError("input_file_id is required, but not provided")
        input_config: InputConfig = InputConfig(
            gcsSource=GcsSource(uris=input_file_id), instancesFormat="jsonl"
        )
        # The target model is recovered from the GCS path of the uploaded file
        # (see _get_model_from_gcs_file below).
        model: str = cls._get_model_from_gcs_file(input_file_id)
        # Predictions are written to the same GCS prefix the input file lives in.
        output_config: OutputConfig = OutputConfig(
            predictionsFormat="jsonl",
            gcsDestination=GcsDestination(
                outputUriPrefix=cls._get_gcs_uri_prefix_from_file(input_file_id)
            ),
        )
        return VertexAIBatchPredictionJob(
            inputConfig=input_config,
            outputConfig=output_config,
            model=model,
            displayName=request_display_name,
        )
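
    # Illustrative usage sketch (added for documentation; not part of the
    # original module). Assumes the input file was uploaded via LiteLLM's
    # /v1/files endpoint, so the GCS URI embeds the target model path; the
    # bucket name and file id below are hypothetical.
    #
    #   request = CreateBatchRequest(
    #       input_file_id=(
    #           "gs://my-bucket/litellm-vertex-files/publishers/google/models/"
    #           "gemini-1.5-flash-001/my-file-id"
    #       ),
    #   )
    #   job = VertexAIBatchTransformation.transform_openai_batch_request_to_vertex_ai_batch_request(
    #       request=request
    #   )
    #   # job["model"] == "publishers/google/models/gemini-1.5-flash-001"
    #   # job["inputConfig"]["gcsSource"]["uris"] == request["input_file_id"]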

    @classmethod
    def transform_vertex_ai_batch_response_to_openai_batch_response(
        cls, response: VertexBatchPredictionResponse
    ) -> Batch:
        return Batch(
            id=cls._get_batch_id_from_vertex_ai_batch_response(response),
            completion_window="24hrs",
            created_at=_convert_vertex_datetime_to_openai_datetime(
                vertex_datetime=response.get("createTime", "")
            ),
            endpoint="",
            input_file_id=cls._get_input_file_id_from_vertex_ai_batch_response(
                response
            ),
            object="batch",
            status=cls._get_batch_job_status_from_vertex_ai_batch_response(response),
            error_file_id=None,  # Vertex AI doesn't seem to have a direct equivalent
            output_file_id=cls._get_output_file_id_from_vertex_ai_batch_response(
                response
            ),
        )
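
    # Illustrative mapping sketch (added for documentation): a minimal Vertex
    # batch response and the OpenAI Batch fields it produces. All values below
    # are hypothetical, and the createTime format is assumed.
    #
    #   response = {
    #       "name": "projects/123/locations/us-central1/batchPredictionJobs/456",
    #       "state": "JOB_STATE_SUCCEEDED",
    #       "createTime": "2024-10-09T17:12:00.000000Z",
    #       "inputConfig": {"gcsSource": {"uris": ["gs://my-bucket/in.jsonl"]}},
    #       "outputConfig": {"gcsDestination": {"outputUriPrefix": "gs://my-bucket"}},
    #   }
    #   batch = VertexAIBatchTransformation.transform_vertex_ai_batch_response_to_openai_batch_response(
    #       response=response
    #   )
    #   # batch.id == "456", batch.status == "completed",
    #   # batch.output_file_id == "gs://my-bucket"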

    @classmethod
    def _get_batch_id_from_vertex_ai_batch_response(
        cls, response: VertexBatchPredictionResponse
    ) -> str:
        """
        Gets the batch id from the Vertex AI Batch response safely

        vertex response: `projects/510528649030/locations/us-central1/batchPredictionJobs/3814889423749775360`
        returns: `3814889423749775360`
        """
        _name = response.get("name", "")
        if not _name:
            return ""

        # Split by '/' and get the last part if it exists
        parts = _name.split("/")
        return parts[-1] if parts else _name

    @classmethod
    def _get_input_file_id_from_vertex_ai_batch_response(
        cls, response: VertexBatchPredictionResponse
    ) -> str:
        """
        Gets the input file id from the Vertex AI Batch response
        """
        input_file_id: str = ""
        input_config = response.get("inputConfig")
        if input_config is None:
            return input_file_id

        gcs_source = input_config.get("gcsSource")
        if gcs_source is None:
            return input_file_id

        uris = gcs_source.get("uris", "")
        if len(uris) == 0:
            return input_file_id

        return uris[0]

    @classmethod
    def _get_output_file_id_from_vertex_ai_batch_response(
        cls, response: VertexBatchPredictionResponse
    ) -> str:
        """
        Gets the output file id from the Vertex AI Batch response
        """
        output_file_id: str = ""
        output_config = response.get("outputConfig")
        if output_config is None:
            return output_file_id

        gcs_destination = output_config.get("gcsDestination")
        if gcs_destination is None:
            return output_file_id

        output_uri_prefix = gcs_destination.get("outputUriPrefix", "")
        return output_uri_prefix

    @classmethod
    def _get_batch_job_status_from_vertex_ai_batch_response(
        cls, response: VertexBatchPredictionResponse
    ) -> BatchJobStatus:
        """
        Gets the batch job status from the Vertex AI Batch response

        ref: https://cloud.google.com/vertex-ai/docs/reference/rest/v1/JobState
        """
        state_mapping: Dict[str, BatchJobStatus] = {
            "JOB_STATE_UNSPECIFIED": "failed",
            "JOB_STATE_QUEUED": "validating",
            "JOB_STATE_PENDING": "validating",
            "JOB_STATE_RUNNING": "in_progress",
            "JOB_STATE_SUCCEEDED": "completed",
            "JOB_STATE_FAILED": "failed",
            "JOB_STATE_CANCELLING": "cancelling",
            "JOB_STATE_CANCELLED": "cancelled",
            "JOB_STATE_PAUSED": "in_progress",
            "JOB_STATE_EXPIRED": "expired",
            "JOB_STATE_UPDATING": "in_progress",
            "JOB_STATE_PARTIALLY_SUCCEEDED": "completed",
        }

        vertex_state = response.get("state", "JOB_STATE_UNSPECIFIED")
        return state_mapping[vertex_state]
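
    # Note (illustrative, added for documentation): {"state": "JOB_STATE_RUNNING"}
    # maps to "in_progress"; a missing state falls back to
    # "JOB_STATE_UNSPECIFIED", which maps to "failed". A state absent from the
    # table above would raise a KeyError.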

    @classmethod
    def _get_gcs_uri_prefix_from_file(cls, input_file_id: str) -> str:
        """
        Gets the gcs uri prefix from the input file id

        Example:
        input_file_id: "gs://litellm-testing-bucket/vtx_batch.jsonl"
        returns: "gs://litellm-testing-bucket"

        input_file_id: "gs://litellm-testing-bucket/batches/vtx_batch.jsonl"
        returns: "gs://litellm-testing-bucket/batches"
        """
        # Split the path and remove the filename
        path_parts = input_file_id.rsplit("/", 1)
        return path_parts[0]

    @classmethod
    def _get_model_from_gcs_file(cls, gcs_file_uri: str) -> str:
        """
        Extracts the model from the gcs file uri

        When files are uploaded using LiteLLM (/v1/files), the model is stored in the gcs file uri

        Why?
        - Because Vertex requires the `model` param in create batch job requests, but OpenAI does not

        gcs_file_uri format: gs://litellm-testing-bucket/litellm-vertex-files/publishers/google/models/gemini-1.5-flash-001/e9412502-2c91-42a6-8e61-f5c294cc0fc8
        returns: "publishers/google/models/gemini-1.5-flash-001"
        """
        from urllib.parse import unquote

        decoded_uri = unquote(gcs_file_uri)

        # Everything after "publishers/" is "<publisher>/models/<model-id>/<file-id>";
        # keep the first three segments to recover the model resource path.
        model_path = decoded_uri.split("publishers/")[1]
        parts = model_path.split("/")
        model = f"publishers/{'/'.join(parts[:3])}"
        return model
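

# Minimal smoke-test sketch (added for illustration; not part of the original
# litellm module). It exercises only the pure string helpers, so a plain dict
# stands in for VertexBatchPredictionResponse; the values below come from the
# docstring examples above.
if __name__ == "__main__":
    job_name = (
        "projects/510528649030/locations/us-central1/"
        "batchPredictionJobs/3814889423749775360"
    )
    assert (
        VertexAIBatchTransformation._get_batch_id_from_vertex_ai_batch_response(
            {"name": job_name}  # type: ignore[arg-type]
        )
        == "3814889423749775360"
    )
    assert (
        VertexAIBatchTransformation._get_gcs_uri_prefix_from_file(
            "gs://litellm-testing-bucket/batches/vtx_batch.jsonl"
        )
        == "gs://litellm-testing-bucket/batches"
    )
    assert (
        VertexAIBatchTransformation._get_model_from_gcs_file(
            "gs://litellm-testing-bucket/litellm-vertex-files/publishers/google/"
            "models/gemini-1.5-flash-001/e9412502-2c91-42a6-8e61-f5c294cc0fc8"
        )
        == "publishers/google/models/gemini-1.5-flash-001"
    )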