(feat) add Vertex Batches API support in OpenAI format (#7032)

* working request

* working transform

* working request

* transform vertex batch response

* add _async_create_batch

* move gcs functions to base

* fix _get_content_from_openai_file

* transform_openai_file_content_to_vertex_ai_file_content

* fix transform vertex gcs bucket upload to OAI files format

* working e2e test

* _get_gcs_object_name

* fix linting

* add doc string

* fix transform_gcs_bucket_response_to_openai_file_object

* use vertex for batch endpoints

* add batches support for vertex

* test_vertex_batches_endpoint

* test_vertex_batch_prediction

* fix gcs bucket base auth

* docs clean up batches

* docs Batch API

* docs vertex batches api

* test_get_gcs_logging_config_without_service_account

* undo change

* fix vertex md

* test_get_gcs_logging_config_without_service_account

* ci/cd run again
Ishaan Jaff 2024-12-04 19:40:28 -08:00 committed by GitHub
parent 6cad9c58ac
commit 84db69d4c4
20 changed files with 1347 additions and 424 deletions


@ -6,8 +6,9 @@ import TabItem from '@theme/TabItem';
Covers Batches, Files
## **Supported Providers**:
- Azure OpenAI
- **[Azure OpenAI](./providers/azure#azure-batches-api)**
- OpenAI
- **[Vertex AI](./providers/vertex#batch-apis)**
## Quick Start
@ -141,182 +142,4 @@ print("list_batches_response=", list_batches_response)
</Tabs>
## [👉 Proxy API Reference](https://litellm-api.up.railway.app/#/batch)
## Azure Batches API
Just add the azure env vars to your environment.
```bash
export AZURE_API_KEY=""
export AZURE_API_BASE=""
```
AND use `/azure/*` for the Batches API calls
```bash
http://0.0.0.0:4000/azure/v1/batches
```
### Usage
**Setup**
- Add Azure API Keys to your environment
#### 1. Upload a File
```bash
curl http://localhost:4000/azure/v1/files \
-H "Authorization: Bearer sk-1234" \
-F purpose="batch" \
-F file="@mydata.jsonl"
```
**Example File**
Note: `model` should be your azure deployment name.
```json
{"custom_id": "task-0", "method": "POST", "url": "/chat/completions", "body": {"model": "REPLACE-WITH-MODEL-DEPLOYMENT-NAME", "messages": [{"role": "system", "content": "You are an AI assistant that helps people find information."}, {"role": "user", "content": "When was Microsoft founded?"}]}}
{"custom_id": "task-1", "method": "POST", "url": "/chat/completions", "body": {"model": "REPLACE-WITH-MODEL-DEPLOYMENT-NAME", "messages": [{"role": "system", "content": "You are an AI assistant that helps people find information."}, {"role": "user", "content": "When was the first XBOX released?"}]}}
{"custom_id": "task-2", "method": "POST", "url": "/chat/completions", "body": {"model": "REPLACE-WITH-MODEL-DEPLOYMENT-NAME", "messages": [{"role": "system", "content": "You are an AI assistant that helps people find information."}, {"role": "user", "content": "What is Altair Basic?"}]}}
```
#### 2. Create a batch
```bash
curl http://0.0.0.0:4000/azure/v1/batches \
-H "Authorization: Bearer $LITELLM_API_KEY" \
-H "Content-Type: application/json" \
-d '{
"input_file_id": "file-abc123",
"endpoint": "/v1/chat/completions",
"completion_window": "24h"
}'
```
#### 3. Retrieve batch
```bash
curl http://0.0.0.0:4000/azure/v1/batches/batch_abc123 \
-H "Authorization: Bearer $LITELLM_API_KEY" \
-H "Content-Type: application/json" \
```
#### 4. Cancel batch
```bash
curl http://0.0.0.0:4000/azure/v1/batches/batch_abc123/cancel \
-H "Authorization: Bearer $LITELLM_API_KEY" \
-H "Content-Type: application/json" \
-X POST
```
#### 5. List Batch
```bash
curl http://0.0.0.0:4000/v1/batches?limit=2 \
-H "Authorization: Bearer $LITELLM_API_KEY" \
-H "Content-Type: application/json"
```
### [👉 Health Check Azure Batch models](./proxy/health.md#batch-models-azure-only)
### [BETA] Loadbalance Multiple Azure Deployments
In your config.yaml, set `enable_loadbalancing_on_batch_endpoints: true`
```yaml
model_list:
- model_name: "batch-gpt-4o-mini"
litellm_params:
model: "azure/gpt-4o-mini"
api_key: os.environ/AZURE_API_KEY
api_base: os.environ/AZURE_API_BASE
model_info:
mode: batch
litellm_settings:
enable_loadbalancing_on_batch_endpoints: true # 👈 KEY CHANGE
```
Note: This works on `{PROXY_BASE_URL}/v1/files` and `{PROXY_BASE_URL}/v1/batches`.
Note: Response is in the OpenAI-format.
1. Upload a file
Just set `model: batch-gpt-4o-mini` in your .jsonl.
```bash
curl http://localhost:4000/v1/files \
-H "Authorization: Bearer sk-1234" \
-F purpose="batch" \
-F file="@mydata.jsonl"
```
**Example File**
Note: `model` should be your azure deployment name.
```json
{"custom_id": "task-0", "method": "POST", "url": "/chat/completions", "body": {"model": "batch-gpt-4o-mini", "messages": [{"role": "system", "content": "You are an AI assistant that helps people find information."}, {"role": "user", "content": "When was Microsoft founded?"}]}}
{"custom_id": "task-1", "method": "POST", "url": "/chat/completions", "body": {"model": "batch-gpt-4o-mini", "messages": [{"role": "system", "content": "You are an AI assistant that helps people find information."}, {"role": "user", "content": "When was the first XBOX released?"}]}}
{"custom_id": "task-2", "method": "POST", "url": "/chat/completions", "body": {"model": "batch-gpt-4o-mini", "messages": [{"role": "system", "content": "You are an AI assistant that helps people find information."}, {"role": "user", "content": "What is Altair Basic?"}]}}
```
Expected Response (OpenAI-compatible)
```bash
{"id":"file-f0be81f654454113a922da60acb0eea6",...}
```
2. Create a batch
```bash
curl http://0.0.0.0:4000/v1/batches \
-H "Authorization: Bearer $LITELLM_API_KEY" \
-H "Content-Type: application/json" \
-d '{
"input_file_id": "file-f0be81f654454113a922da60acb0eea6",
"endpoint": "/v1/chat/completions",
"completion_window": "24h",
"model: "batch-gpt-4o-mini"
}'
```
Expected Response:
```bash
{"id":"batch_94e43f0a-d805-477d-adf9-bbb9c50910ed",...}
```
3. Retrieve a batch
```bash
curl http://0.0.0.0:4000/v1/batches/batch_94e43f0a-d805-477d-adf9-bbb9c50910ed \
-H "Authorization: Bearer $LITELLM_API_KEY" \
-H "Content-Type: application/json" \
```
Expected Response:
```
{"id":"batch_94e43f0a-d805-477d-adf9-bbb9c50910ed",...}
```
4. List batch
```bash
curl http://0.0.0.0:4000/v1/batches?limit=2 \
-H "Authorization: Bearer $LITELLM_API_KEY" \
-H "Content-Type: application/json"
```
Expected Response:
```bash
{"data":[{"id":"batch_R3V...}
```
## [Swagger API Reference](https://litellm-api.up.railway.app/#/batch)


@ -559,6 +559,185 @@ litellm_settings:
</Tabs>
## **Azure Batches API**
Just add the azure env vars to your environment.
```bash
export AZURE_API_KEY=""
export AZURE_API_BASE=""
```
AND use `/azure/*` for the Batches API calls
```bash
http://0.0.0.0:4000/azure/v1/batches
```
### Usage
**Setup**
- Add Azure API Keys to your environment
#### 1. Upload a File
```bash
curl http://localhost:4000/azure/v1/files \
-H "Authorization: Bearer sk-1234" \
-F purpose="batch" \
-F file="@mydata.jsonl"
```
**Example File**
Note: `model` should be your azure deployment name.
```json
{"custom_id": "task-0", "method": "POST", "url": "/chat/completions", "body": {"model": "REPLACE-WITH-MODEL-DEPLOYMENT-NAME", "messages": [{"role": "system", "content": "You are an AI assistant that helps people find information."}, {"role": "user", "content": "When was Microsoft founded?"}]}}
{"custom_id": "task-1", "method": "POST", "url": "/chat/completions", "body": {"model": "REPLACE-WITH-MODEL-DEPLOYMENT-NAME", "messages": [{"role": "system", "content": "You are an AI assistant that helps people find information."}, {"role": "user", "content": "When was the first XBOX released?"}]}}
{"custom_id": "task-2", "method": "POST", "url": "/chat/completions", "body": {"model": "REPLACE-WITH-MODEL-DEPLOYMENT-NAME", "messages": [{"role": "system", "content": "You are an AI assistant that helps people find information."}, {"role": "user", "content": "What is Altair Basic?"}]}}
```
#### 2. Create a batch
```bash
curl http://0.0.0.0:4000/azure/v1/batches \
-H "Authorization: Bearer $LITELLM_API_KEY" \
-H "Content-Type: application/json" \
-d '{
"input_file_id": "file-abc123",
"endpoint": "/v1/chat/completions",
"completion_window": "24h"
}'
```
#### 3. Retrieve batch
```bash
curl http://0.0.0.0:4000/azure/v1/batches/batch_abc123 \
-H "Authorization: Bearer $LITELLM_API_KEY" \
-H "Content-Type: application/json" \
```
#### 4. Cancel batch
```bash
curl http://0.0.0.0:4000/azure/v1/batches/batch_abc123/cancel \
-H "Authorization: Bearer $LITELLM_API_KEY" \
-H "Content-Type: application/json" \
-X POST
```
#### 5. List Batch
```bash
curl http://0.0.0.0:4000/v1/batches?limit=2 \
-H "Authorization: Bearer $LITELLM_API_KEY" \
-H "Content-Type: application/json"
```
### [Health Check Azure Batch models](./proxy/health.md#batch-models-azure-only)
### [BETA] Loadbalance Multiple Azure Deployments
In your config.yaml, set `enable_loadbalancing_on_batch_endpoints: true`
```yaml
model_list:
- model_name: "batch-gpt-4o-mini"
litellm_params:
model: "azure/gpt-4o-mini"
api_key: os.environ/AZURE_API_KEY
api_base: os.environ/AZURE_API_BASE
model_info:
mode: batch
litellm_settings:
enable_loadbalancing_on_batch_endpoints: true # 👈 KEY CHANGE
```
Note: This works on `{PROXY_BASE_URL}/v1/files` and `{PROXY_BASE_URL}/v1/batches`.
Note: Response is in the OpenAI-format.
1. Upload a file
Just set `model: batch-gpt-4o-mini` in your .jsonl.
```bash
curl http://localhost:4000/v1/files \
-H "Authorization: Bearer sk-1234" \
-F purpose="batch" \
-F file="@mydata.jsonl"
```
**Example File**
Note: `model` should be your azure deployment name.
```json
{"custom_id": "task-0", "method": "POST", "url": "/chat/completions", "body": {"model": "batch-gpt-4o-mini", "messages": [{"role": "system", "content": "You are an AI assistant that helps people find information."}, {"role": "user", "content": "When was Microsoft founded?"}]}}
{"custom_id": "task-1", "method": "POST", "url": "/chat/completions", "body": {"model": "batch-gpt-4o-mini", "messages": [{"role": "system", "content": "You are an AI assistant that helps people find information."}, {"role": "user", "content": "When was the first XBOX released?"}]}}
{"custom_id": "task-2", "method": "POST", "url": "/chat/completions", "body": {"model": "batch-gpt-4o-mini", "messages": [{"role": "system", "content": "You are an AI assistant that helps people find information."}, {"role": "user", "content": "What is Altair Basic?"}]}}
```
Expected Response (OpenAI-compatible)
```bash
{"id":"file-f0be81f654454113a922da60acb0eea6",...}
```
2. Create a batch
```bash
curl http://0.0.0.0:4000/v1/batches \
-H "Authorization: Bearer $LITELLM_API_KEY" \
-H "Content-Type: application/json" \
-d '{
"input_file_id": "file-f0be81f654454113a922da60acb0eea6",
"endpoint": "/v1/chat/completions",
"completion_window": "24h",
"model: "batch-gpt-4o-mini"
}'
```
Expected Response:
```bash
{"id":"batch_94e43f0a-d805-477d-adf9-bbb9c50910ed",...}
```
3. Retrieve a batch
```bash
curl http://0.0.0.0:4000/v1/batches/batch_94e43f0a-d805-477d-adf9-bbb9c50910ed \
-H "Authorization: Bearer $LITELLM_API_KEY" \
-H "Content-Type: application/json" \
```
Expected Response:
```
{"id":"batch_94e43f0a-d805-477d-adf9-bbb9c50910ed",...}
```
4. List batch
```bash
curl http://0.0.0.0:4000/v1/batches?limit=2 \
-H "Authorization: Bearer $LITELLM_API_KEY" \
-H "Content-Type: application/json"
```
Expected Response:
```bash
{"data":[{"id":"batch_R3V...}
```
## Advanced
### Azure API Load-Balancing


@ -2393,6 +2393,114 @@ print("response from proxy", response)
</TabItem>
</Tabs>
## **Batch APIs**
Just add the following Vertex env vars to your environment.
```bash
# GCS Bucket settings, used to store batch prediction files
export GCS_BUCKET_NAME="litellm-testing-bucket" # the bucket you want to store batch prediction files in
export GCS_PATH_SERVICE_ACCOUNT="/path/to/service_account.json" # path to your service account json file
# Vertex /batch endpoint settings, used for LLM API requests
export GOOGLE_APPLICATION_CREDENTIALS="/path/to/service_account.json" # path to your service account json file
export VERTEXAI_LOCATION="us-central1" # can be any vertex location
export VERTEXAI_PROJECT="my-test-project"
```
### Usage
#### 1. Create a file of batch requests for vertex
LiteLLM expects the file to follow the **[OpenAI batches files format](https://platform.openai.com/docs/guides/batch)**
Each `body` in the file should be an **OpenAI API request**
Create a file called `vertex_batch_completions.jsonl` in the current working directory; the `model` should be the Vertex AI model name
```
{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gemini-1.5-flash-001", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 10}}
{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gemini-1.5-flash-001", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 10}}
```
#### 2. Upload a File of batch requests
For `vertex_ai`, LiteLLM uploads the file to the configured `GCS_BUCKET_NAME`
```python
import os
from openai import OpenAI
oai_client = OpenAI(
api_key="sk-1234", # litellm proxy API key
base_url="http://localhost:4000" # litellm proxy base url
)
file_name = "vertex_batch_completions.jsonl" #
_current_dir = os.path.dirname(os.path.abspath(__file__))
file_path = os.path.join(_current_dir, file_name)
file_obj = oai_client.files.create(
file=open(file_path, "rb"),
purpose="batch",
extra_body={"custom_llm_provider": "vertex_ai"}, # tell litellm to use vertex_ai for this file upload
)
```
**Expected Response**
```json
{
"id": "gs://litellm-testing-bucket/litellm-vertex-files/publishers/google/models/gemini-1.5-flash-001/d3f198cd-c0d1-436d-9b1e-28e3f282997a",
"bytes": 416,
"created_at": 1733392026,
"filename": "litellm-vertex-files/publishers/google/models/gemini-1.5-flash-001/d3f198cd-c0d1-436d-9b1e-28e3f282997a",
"object": "file",
"purpose": "batch",
"status": "uploaded",
"status_details": null
}
```
#### 3. Create a batch
```python
batch_input_file_id = file_obj.id # use `file_obj` from step 2
create_batch_response = oai_client.batches.create(
completion_window="24h",
endpoint="/v1/chat/completions",
input_file_id=batch_input_file_id, # example input_file_id = "gs://litellm-testing-bucket/litellm-vertex-files/publishers/google/models/gemini-1.5-flash-001/c2b1b785-252b-448c-b180-033c4c63b3ce"
extra_body={"custom_llm_provider": "vertex_ai"}, # tell litellm to use `vertex_ai` for this batch request
)
```
**Expected Response**
```json
{
"id": "projects/633608382793/locations/us-central1/batchPredictionJobs/986266568679751680",
"completion_window": "24hrs",
"created_at": 1733392026,
"endpoint": "",
"input_file_id": "gs://litellm-testing-bucket/litellm-vertex-files/publishers/google/models/gemini-1.5-flash-001/d3f198cd-c0d1-436d-9b1e-28e3f282997a",
"object": "batch",
"status": "validating",
"cancelled_at": null,
"cancelling_at": null,
"completed_at": null,
"error_file_id": null,
"errors": null,
"expired_at": null,
"expires_at": null,
"failed_at": null,
"finalizing_at": null,
"in_progress_at": null,
"metadata": null,
"output_file_id": "gs://litellm-testing-bucket/litellm-vertex-files/publishers/google/models/gemini-1.5-flash-001",
"request_counts": null
}
```
## Extra
### Using `GOOGLE_APPLICATION_CREDENTIALS`


@ -22,6 +22,9 @@ import litellm
from litellm import client
from litellm.llms.AzureOpenAI.azure import AzureBatchesAPI
from litellm.llms.OpenAI.openai import OpenAIBatchesAPI
from litellm.llms.vertex_ai_and_google_ai_studio.batches.handler import (
VertexAIBatchPrediction,
)
from litellm.secret_managers.main import get_secret, get_secret_str
from litellm.types.llms.openai import (
Batch,
@ -40,6 +43,7 @@ from litellm.utils import supports_httpx_timeout
####### ENVIRONMENT VARIABLES ###################
openai_batches_instance = OpenAIBatchesAPI()
azure_batches_instance = AzureBatchesAPI()
vertex_ai_batches_instance = VertexAIBatchPrediction(gcs_bucket_name="")
#################################################
@ -47,7 +51,7 @@ async def acreate_batch(
completion_window: Literal["24h"],
endpoint: Literal["/v1/chat/completions", "/v1/embeddings", "/v1/completions"],
input_file_id: str,
custom_llm_provider: Literal["openai", "azure"] = "openai",
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
metadata: Optional[Dict[str, str]] = None,
extra_headers: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, str]] = None,
@ -93,7 +97,7 @@ def create_batch(
completion_window: Literal["24h"],
endpoint: Literal["/v1/chat/completions", "/v1/embeddings", "/v1/completions"],
input_file_id: str,
custom_llm_provider: Literal["openai", "azure"] = "openai",
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
metadata: Optional[Dict[str, str]] = None,
extra_headers: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, str]] = None,
@ -199,6 +203,32 @@ def create_batch(
max_retries=optional_params.max_retries,
create_batch_data=_create_batch_request,
)
elif custom_llm_provider == "vertex_ai":
api_base = optional_params.api_base or ""
vertex_ai_project = (
optional_params.vertex_project
or litellm.vertex_project
or get_secret_str("VERTEXAI_PROJECT")
)
vertex_ai_location = (
optional_params.vertex_location
or litellm.vertex_location
or get_secret_str("VERTEXAI_LOCATION")
)
vertex_credentials = optional_params.vertex_credentials or get_secret_str(
"VERTEXAI_CREDENTIALS"
)
response = vertex_ai_batches_instance.create_batch(
_is_async=_is_async,
api_base=api_base,
vertex_project=vertex_ai_project,
vertex_location=vertex_ai_location,
vertex_credentials=vertex_credentials,
timeout=timeout,
max_retries=optional_params.max_retries,
create_batch_data=_create_batch_request,
)
else:
raise litellm.exceptions.BadRequestError(
message="LiteLLM doesn't support {} for 'create_batch'. Only 'openai' is supported.".format(


@ -17,6 +17,9 @@ import litellm
from litellm import client, get_secret_str
from litellm.llms.files_apis.azure import AzureOpenAIFilesAPI
from litellm.llms.OpenAI.openai import FileDeleted, FileObject, OpenAIFilesAPI
from litellm.llms.vertex_ai_and_google_ai_studio.files.handler import (
VertexAIFilesHandler,
)
from litellm.types.llms.openai import (
Batch,
CreateFileRequest,
@ -30,6 +33,7 @@ from litellm.utils import supports_httpx_timeout
####### ENVIRONMENT VARIABLES ###################
openai_files_instance = OpenAIFilesAPI()
azure_files_instance = AzureOpenAIFilesAPI()
vertex_ai_files_instance = VertexAIFilesHandler()
#################################################
@ -490,7 +494,7 @@ def file_list(
async def acreate_file(
file: FileTypes,
purpose: Literal["assistants", "batch", "fine-tune"],
custom_llm_provider: Literal["openai", "azure"] = "openai",
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
extra_headers: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, str]] = None,
**kwargs,
@ -532,7 +536,7 @@ async def acreate_file(
def create_file(
file: FileTypes,
purpose: Literal["assistants", "batch", "fine-tune"],
custom_llm_provider: Literal["openai", "azure"] = "openai",
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
extra_headers: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, str]] = None,
**kwargs,
@ -630,6 +634,32 @@ def create_file(
max_retries=optional_params.max_retries,
create_file_data=_create_file_request,
)
elif custom_llm_provider == "vertex_ai":
api_base = optional_params.api_base or ""
vertex_ai_project = (
optional_params.vertex_project
or litellm.vertex_project
or get_secret_str("VERTEXAI_PROJECT")
)
vertex_ai_location = (
optional_params.vertex_location
or litellm.vertex_location
or get_secret_str("VERTEXAI_LOCATION")
)
vertex_credentials = optional_params.vertex_credentials or get_secret_str(
"VERTEXAI_CREDENTIALS"
)
response = vertex_ai_files_instance.create_file(
_is_async=_is_async,
api_base=api_base,
vertex_project=vertex_ai_project,
vertex_location=vertex_ai_location,
vertex_credentials=vertex_credentials,
timeout=timeout,
max_retries=optional_params.max_retries,
create_file_data=_create_file_request,
)
else:
raise litellm.exceptions.BadRequestError(
message="LiteLLM doesn't support {} for 'create_batch'. Only 'openai' is supported.".format(


@ -29,7 +29,6 @@ else:
VertexBase = Any
IAM_AUTH_KEY = "IAM_AUTH"
GCS_DEFAULT_BATCH_SIZE = 2048
GCS_DEFAULT_FLUSH_INTERVAL_SECONDS = 20
@ -39,7 +38,6 @@ class GCSBucketLogger(GCSBucketBase):
from litellm.proxy.proxy_server import premium_user
super().__init__(bucket_name=bucket_name)
self.vertex_instances: Dict[str, VertexBase] = {}
# Init Batch logging settings
self.log_queue: List[GCSLogQueueItem] = []
@ -178,232 +176,3 @@ class GCSBucketLogger(GCSBucketBase):
object_name = _metadata["gcs_log_id"]
return object_name
def _handle_folders_in_bucket_name(
self,
bucket_name: str,
object_name: str,
) -> Tuple[str, str]:
"""
Handles when the user passes a bucket name with a folder postfix
Example:
- Bucket name: "my-bucket/my-folder/dev"
- Object name: "my-object"
- Returns: bucket_name="my-bucket", object_name="my-folder/dev/my-object"
"""
if "/" in bucket_name:
bucket_name, prefix = bucket_name.split("/", 1)
object_name = f"{prefix}/{object_name}"
return bucket_name, object_name
return bucket_name, object_name
async def _log_json_data_on_gcs(
self,
headers: Dict[str, str],
bucket_name: str,
object_name: str,
logging_payload: StandardLoggingPayload,
):
"""
Helper function to make POST request to GCS Bucket in the specified bucket.
"""
json_logged_payload = json.dumps(logging_payload, default=str)
bucket_name, object_name = self._handle_folders_in_bucket_name(
bucket_name=bucket_name,
object_name=object_name,
)
response = await self.async_httpx_client.post(
headers=headers,
url=f"https://storage.googleapis.com/upload/storage/v1/b/{bucket_name}/o?uploadType=media&name={object_name}",
data=json_logged_payload,
)
if response.status_code != 200:
verbose_logger.error("GCS Bucket logging error: %s", str(response.text))
verbose_logger.debug("GCS Bucket response %s", response)
verbose_logger.debug("GCS Bucket status code %s", response.status_code)
verbose_logger.debug("GCS Bucket response.text %s", response.text)
async def get_gcs_logging_config(
self, kwargs: Optional[Dict[str, Any]] = {}
) -> GCSLoggingConfig:
"""
This function is used to get the GCS logging config for the GCS Bucket Logger.
It checks if the dynamic parameters are provided in the kwargs and uses them to get the GCS logging config.
If no dynamic parameters are provided, it uses the default values.
"""
if kwargs is None:
kwargs = {}
standard_callback_dynamic_params: Optional[StandardCallbackDynamicParams] = (
kwargs.get("standard_callback_dynamic_params", None)
)
bucket_name: str
path_service_account: Optional[str]
if standard_callback_dynamic_params is not None:
verbose_logger.debug("Using dynamic GCS logging")
verbose_logger.debug(
"standard_callback_dynamic_params: %s", standard_callback_dynamic_params
)
_bucket_name: Optional[str] = (
standard_callback_dynamic_params.get("gcs_bucket_name", None)
or self.BUCKET_NAME
)
_path_service_account: Optional[str] = (
standard_callback_dynamic_params.get("gcs_path_service_account", None)
or self.path_service_account_json
)
if _bucket_name is None:
raise ValueError(
"GCS_BUCKET_NAME is not set in the environment, but GCS Bucket is being used as a logging callback. Please set 'GCS_BUCKET_NAME' in the environment."
)
bucket_name = _bucket_name
path_service_account = _path_service_account
vertex_instance = await self.get_or_create_vertex_instance(
credentials=path_service_account
)
else:
# If no dynamic parameters, use the default instance
if self.BUCKET_NAME is None:
raise ValueError(
"GCS_BUCKET_NAME is not set in the environment, but GCS Bucket is being used as a logging callback. Please set 'GCS_BUCKET_NAME' in the environment."
)
bucket_name = self.BUCKET_NAME
path_service_account = self.path_service_account_json
vertex_instance = await self.get_or_create_vertex_instance(
credentials=path_service_account
)
return GCSLoggingConfig(
bucket_name=bucket_name,
vertex_instance=vertex_instance,
path_service_account=path_service_account,
)
async def get_or_create_vertex_instance(
self, credentials: Optional[str]
) -> VertexBase:
"""
This function is used to get the Vertex instance for the GCS Bucket Logger.
It checks if the Vertex instance is already created and cached, if not it creates a new instance and caches it.
"""
from litellm.llms.vertex_ai_and_google_ai_studio.vertex_llm_base import (
VertexBase,
)
_in_memory_key = self._get_in_memory_key_for_vertex_instance(credentials)
if _in_memory_key not in self.vertex_instances:
vertex_instance = VertexBase()
await vertex_instance._ensure_access_token_async(
credentials=credentials,
project_id=None,
custom_llm_provider="vertex_ai",
)
self.vertex_instances[_in_memory_key] = vertex_instance
return self.vertex_instances[_in_memory_key]
def _get_in_memory_key_for_vertex_instance(self, credentials: Optional[str]) -> str:
"""
Returns key to use for caching the Vertex instance in-memory.
When using Vertex with Key based logging, we need to cache the Vertex instance in-memory.
- If a credentials string is provided, it is used as the key.
- If no credentials string is provided, "IAM_AUTH" is used as the key.
"""
return credentials or IAM_AUTH_KEY
async def download_gcs_object(self, object_name: str, **kwargs):
"""
Download an object from GCS.
https://cloud.google.com/storage/docs/downloading-objects#download-object-json
"""
try:
gcs_logging_config: GCSLoggingConfig = await self.get_gcs_logging_config(
kwargs=kwargs
)
headers = await self.construct_request_headers(
vertex_instance=gcs_logging_config["vertex_instance"],
service_account_json=gcs_logging_config["path_service_account"],
)
bucket_name = gcs_logging_config["bucket_name"]
bucket_name, object_name = self._handle_folders_in_bucket_name(
bucket_name=bucket_name,
object_name=object_name,
)
url = f"https://storage.googleapis.com/storage/v1/b/{bucket_name}/o/{object_name}?alt=media"
# Send the GET request to download the object
response = await self.async_httpx_client.get(url=url, headers=headers)
if response.status_code != 200:
verbose_logger.error(
"GCS object download error: %s", str(response.text)
)
return None
verbose_logger.debug(
"GCS object download response status code: %s", response.status_code
)
# Return the content of the downloaded object
return response.content
except Exception as e:
verbose_logger.error("GCS object download error: %s", str(e))
return None
async def delete_gcs_object(self, object_name: str, **kwargs):
"""
Delete an object from GCS.
"""
try:
gcs_logging_config: GCSLoggingConfig = await self.get_gcs_logging_config(
kwargs=kwargs
)
headers = await self.construct_request_headers(
vertex_instance=gcs_logging_config["vertex_instance"],
service_account_json=gcs_logging_config["path_service_account"],
)
bucket_name = gcs_logging_config["bucket_name"]
bucket_name, object_name = self._handle_folders_in_bucket_name(
bucket_name=bucket_name,
object_name=object_name,
)
url = f"https://storage.googleapis.com/storage/v1/b/{bucket_name}/o/{object_name}"
# Send the DELETE request to delete the object
response = await self.async_httpx_client.delete(url=url, headers=headers)
if (response.status_code != 200) or (response.status_code != 204):
verbose_logger.error(
"GCS object delete error: %s, status code: %s",
str(response.text),
response.status_code,
)
return None
verbose_logger.debug(
"GCS object delete response status code: %s, response: %s",
response.status_code,
response.text,
)
# Return the content of the downloaded object
return response.text
except Exception as e:
verbose_logger.error("GCS object download error: %s", str(e))
return None


@ -2,7 +2,7 @@ import json
import os
import uuid
from datetime import datetime
from typing import TYPE_CHECKING, Any, Dict, List, Optional, TypedDict, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, TypedDict, Union
import httpx
from pydantic import BaseModel, Field
@ -14,11 +14,18 @@ from litellm.llms.custom_httpx.http_handler import (
get_async_httpx_client,
httpxSpecialProvider,
)
from litellm.types.integrations.gcs_bucket import *
from litellm.types.utils import (
StandardCallbackDynamicParams,
StandardLoggingMetadata,
StandardLoggingPayload,
)
if TYPE_CHECKING:
from litellm.llms.vertex_ai_and_google_ai_studio.vertex_llm_base import VertexBase
else:
VertexBase = Any
IAM_AUTH_KEY = "IAM_AUTH"
class GCSBucketBase(CustomBatchLogger):
@ -30,6 +37,7 @@ class GCSBucketBase(CustomBatchLogger):
_bucket_name = bucket_name or os.getenv("GCS_BUCKET_NAME")
self.path_service_account_json: Optional[str] = _path_service_account
self.BUCKET_NAME: Optional[str] = _bucket_name
self.vertex_instances: Dict[str, VertexBase] = {}
super().__init__(**kwargs)
async def construct_request_headers(
@ -94,3 +102,237 @@ class GCSBucketBase(CustomBatchLogger):
}
return headers
def _handle_folders_in_bucket_name(
self,
bucket_name: str,
object_name: str,
) -> Tuple[str, str]:
"""
Handles when the user passes a bucket name with a folder postfix
Example:
- Bucket name: "my-bucket/my-folder/dev"
- Object name: "my-object"
- Returns: bucket_name="my-bucket", object_name="my-folder/dev/my-object"
"""
if "/" in bucket_name:
bucket_name, prefix = bucket_name.split("/", 1)
object_name = f"{prefix}/{object_name}"
return bucket_name, object_name
return bucket_name, object_name
async def get_gcs_logging_config(
self, kwargs: Optional[Dict[str, Any]] = {}
) -> GCSLoggingConfig:
"""
This function is used to get the GCS logging config for the GCS Bucket Logger.
It checks if the dynamic parameters are provided in the kwargs and uses them to get the GCS logging config.
If no dynamic parameters are provided, it uses the default values.
"""
if kwargs is None:
kwargs = {}
standard_callback_dynamic_params: Optional[StandardCallbackDynamicParams] = (
kwargs.get("standard_callback_dynamic_params", None)
)
bucket_name: str
path_service_account: Optional[str]
if standard_callback_dynamic_params is not None:
verbose_logger.debug("Using dynamic GCS logging")
verbose_logger.debug(
"standard_callback_dynamic_params: %s", standard_callback_dynamic_params
)
_bucket_name: Optional[str] = (
standard_callback_dynamic_params.get("gcs_bucket_name", None)
or self.BUCKET_NAME
)
_path_service_account: Optional[str] = (
standard_callback_dynamic_params.get("gcs_path_service_account", None)
or self.path_service_account_json
)
if _bucket_name is None:
raise ValueError(
"GCS_BUCKET_NAME is not set in the environment, but GCS Bucket is being used as a logging callback. Please set 'GCS_BUCKET_NAME' in the environment."
)
bucket_name = _bucket_name
path_service_account = _path_service_account
vertex_instance = await self.get_or_create_vertex_instance(
credentials=path_service_account
)
else:
# If no dynamic parameters, use the default instance
if self.BUCKET_NAME is None:
raise ValueError(
"GCS_BUCKET_NAME is not set in the environment, but GCS Bucket is being used as a logging callback. Please set 'GCS_BUCKET_NAME' in the environment."
)
bucket_name = self.BUCKET_NAME
path_service_account = self.path_service_account_json
vertex_instance = await self.get_or_create_vertex_instance(
credentials=path_service_account
)
return GCSLoggingConfig(
bucket_name=bucket_name,
vertex_instance=vertex_instance,
path_service_account=path_service_account,
)
async def get_or_create_vertex_instance(
self, credentials: Optional[str]
) -> VertexBase:
"""
This function is used to get the Vertex instance for the GCS Bucket Logger.
It checks if the Vertex instance is already created and cached, if not it creates a new instance and caches it.
"""
from litellm.llms.vertex_ai_and_google_ai_studio.vertex_llm_base import (
VertexBase,
)
_in_memory_key = self._get_in_memory_key_for_vertex_instance(credentials)
if _in_memory_key not in self.vertex_instances:
vertex_instance = VertexBase()
await vertex_instance._ensure_access_token_async(
credentials=credentials,
project_id=None,
custom_llm_provider="vertex_ai",
)
self.vertex_instances[_in_memory_key] = vertex_instance
return self.vertex_instances[_in_memory_key]
def _get_in_memory_key_for_vertex_instance(self, credentials: Optional[str]) -> str:
"""
Returns key to use for caching the Vertex instance in-memory.
When using Vertex with Key based logging, we need to cache the Vertex instance in-memory.
- If a credentials string is provided, it is used as the key.
- If no credentials string is provided, "IAM_AUTH" is used as the key.
"""
return credentials or IAM_AUTH_KEY
async def download_gcs_object(self, object_name: str, **kwargs):
"""
Download an object from GCS.
https://cloud.google.com/storage/docs/downloading-objects#download-object-json
"""
try:
gcs_logging_config: GCSLoggingConfig = await self.get_gcs_logging_config(
kwargs=kwargs
)
headers = await self.construct_request_headers(
vertex_instance=gcs_logging_config["vertex_instance"],
service_account_json=gcs_logging_config["path_service_account"],
)
bucket_name = gcs_logging_config["bucket_name"]
bucket_name, object_name = self._handle_folders_in_bucket_name(
bucket_name=bucket_name,
object_name=object_name,
)
url = f"https://storage.googleapis.com/storage/v1/b/{bucket_name}/o/{object_name}?alt=media"
# Send the GET request to download the object
response = await self.async_httpx_client.get(url=url, headers=headers)
if response.status_code != 200:
verbose_logger.error(
"GCS object download error: %s", str(response.text)
)
return None
verbose_logger.debug(
"GCS object download response status code: %s", response.status_code
)
# Return the content of the downloaded object
return response.content
except Exception as e:
verbose_logger.error("GCS object download error: %s", str(e))
return None
async def delete_gcs_object(self, object_name: str, **kwargs):
"""
Delete an object from GCS.
"""
try:
gcs_logging_config: GCSLoggingConfig = await self.get_gcs_logging_config(
kwargs=kwargs
)
headers = await self.construct_request_headers(
vertex_instance=gcs_logging_config["vertex_instance"],
service_account_json=gcs_logging_config["path_service_account"],
)
bucket_name = gcs_logging_config["bucket_name"]
bucket_name, object_name = self._handle_folders_in_bucket_name(
bucket_name=bucket_name,
object_name=object_name,
)
url = f"https://storage.googleapis.com/storage/v1/b/{bucket_name}/o/{object_name}"
# Send the DELETE request to delete the object
response = await self.async_httpx_client.delete(url=url, headers=headers)
if response.status_code not in (200, 204):
verbose_logger.error(
"GCS object delete error: %s, status code: %s",
str(response.text),
response.status_code,
)
return None
verbose_logger.debug(
"GCS object delete response status code: %s, response: %s",
response.status_code,
response.text,
)
# Return the response body from the delete request
return response.text
except Exception as e:
verbose_logger.error("GCS object delete error: %s", str(e))
return None
async def _log_json_data_on_gcs(
self,
headers: Dict[str, str],
bucket_name: str,
object_name: str,
logging_payload: Union[StandardLoggingPayload, str],
):
"""
Helper function to make POST request to GCS Bucket in the specified bucket.
"""
if isinstance(logging_payload, str):
json_logged_payload = logging_payload
else:
json_logged_payload = json.dumps(logging_payload, default=str)
bucket_name, object_name = self._handle_folders_in_bucket_name(
bucket_name=bucket_name,
object_name=object_name,
)
response = await self.async_httpx_client.post(
headers=headers,
url=f"https://storage.googleapis.com/upload/storage/v1/b/{bucket_name}/o?uploadType=media&name={object_name}",
data=json_logged_payload,
)
if response.status_code != 200:
verbose_logger.error("GCS Bucket logging error: %s", str(response.text))
verbose_logger.debug("GCS Bucket response %s", response)
verbose_logger.debug("GCS Bucket status code %s", response.status_code)
verbose_logger.debug("GCS Bucket response.text %s", response.text)
return response.json()


@ -0,0 +1,6 @@
# Vertex AI Batch Prediction Jobs
Implementation for calling Vertex AI Batch endpoints using the OpenAI Batch API spec
Vertex Docs: https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/batch-prediction-gemini
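A minimal usage sketch (mirroring the docs and tests added in this PR; it assumes `GCS_BUCKET_NAME`, `GCS_PATH_SERVICE_ACCOUNT`, `GOOGLE_APPLICATION_CREDENTIALS`, `VERTEXAI_PROJECT`, and `VERTEXAI_LOCATION` are set as described in the Vertex docs above):

```python
import asyncio
import litellm

async def main():
    # Upload the OpenAI-format .jsonl; for vertex_ai this lands in the configured GCS bucket
    file_obj = await litellm.acreate_file(
        file=open("vertex_batch_completions.jsonl", "rb"),
        purpose="batch",
        custom_llm_provider="vertex_ai",
    )

    # Create the Vertex AI batch prediction job from the uploaded gs:// file id
    batch = await litellm.acreate_batch(
        completion_window="24h",
        endpoint="/v1/chat/completions",
        input_file_id=file_obj.id,  # e.g. "gs://<bucket>/litellm-vertex-files/publishers/google/models/<model>/<uuid>"
        custom_llm_provider="vertex_ai",
    )
    print(batch)

asyncio.run(main())
```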


@ -0,0 +1,141 @@
import json
from typing import Any, Coroutine, Dict, Optional, Union
import httpx
import litellm
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
HTTPHandler,
_get_httpx_client,
get_async_httpx_client,
)
from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
VertexAIError,
VertexLLM,
)
from litellm.types.llms.openai import (
Batch,
CancelBatchRequest,
CreateBatchRequest,
CreateFileRequest,
FileContentRequest,
FileObject,
FileTypes,
HttpxBinaryResponseContent,
RetrieveBatchRequest,
)
from litellm.types.llms.vertex_ai import VertexAIBatchPredictionJob
from .transformation import VertexAIBatchTransformation
class VertexAIBatchPrediction(VertexLLM):
def __init__(self, gcs_bucket_name: str, *args, **kwargs):
super().__init__(*args, **kwargs)
self.gcs_bucket_name = gcs_bucket_name
def create_batch(
self,
_is_async: bool,
create_batch_data: CreateBatchRequest,
api_base: Optional[str],
vertex_credentials: Optional[str],
vertex_project: Optional[str],
vertex_location: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
) -> Union[Batch, Coroutine[Any, Any, Batch]]:
sync_handler = _get_httpx_client()
access_token, project_id = self._ensure_access_token(
credentials=vertex_credentials,
project_id=vertex_project,
custom_llm_provider="vertex_ai",
)
default_api_base = self.create_vertex_url(
vertex_location=vertex_location or "us-central1",
vertex_project=vertex_project or project_id,
)
if len(default_api_base.split(":")) > 1:
endpoint = default_api_base.split(":")[-1]
else:
endpoint = ""
_, api_base = self._check_custom_proxy(
api_base=api_base,
custom_llm_provider="vertex_ai",
gemini_api_key=None,
endpoint=endpoint,
stream=None,
auth_header=None,
url=default_api_base,
)
headers = {
"Content-Type": "application/json; charset=utf-8",
"Authorization": f"Bearer {access_token}",
}
vertex_batch_request: VertexAIBatchPredictionJob = (
VertexAIBatchTransformation.transform_openai_batch_request_to_vertex_ai_batch_request(
request=create_batch_data
)
)
if _is_async is True:
return self._async_create_batch(
vertex_batch_request=vertex_batch_request,
api_base=api_base,
headers=headers,
)
response = sync_handler.post(
url=api_base,
headers=headers,
data=json.dumps(vertex_batch_request),
)
if response.status_code != 200:
raise Exception(f"Error: {response.status_code} {response.text}")
_json_response = response.json()
vertex_batch_response = VertexAIBatchTransformation.transform_vertex_ai_batch_response_to_openai_batch_response(
response=_json_response
)
return vertex_batch_response
async def _async_create_batch(
self,
vertex_batch_request: VertexAIBatchPredictionJob,
api_base: str,
headers: Dict[str, str],
) -> Batch:
client = get_async_httpx_client(
llm_provider=litellm.LlmProviders.VERTEX_AI,
)
response = await client.post(
url=api_base,
headers=headers,
data=json.dumps(vertex_batch_request),
)
if response.status_code != 200:
raise Exception(f"Error: {response.status_code} {response.text}")
_json_response = response.json()
vertex_batch_response = VertexAIBatchTransformation.transform_vertex_ai_batch_response_to_openai_batch_response(
response=_json_response
)
return vertex_batch_response
def create_vertex_url(
self,
vertex_location: str,
vertex_project: str,
) -> str:
"""Return the base url for the vertex garden models"""
# POST https://LOCATION-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/LOCATION/batchPredictionJobs
return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/batchPredictionJobs"


@ -0,0 +1,174 @@
import uuid
from typing import Any, Dict, Literal
from litellm.llms.vertex_ai_and_google_ai_studio.common_utils import (
_convert_vertex_datetime_to_openai_datetime,
)
from litellm.types.llms.openai import Batch, BatchJobStatus, CreateBatchRequest
from litellm.types.llms.vertex_ai import *
class VertexAIBatchTransformation:
"""
Transforms OpenAI Batch requests to Vertex AI Batch requests
API Ref: https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/batch-prediction-gemini
"""
@classmethod
def transform_openai_batch_request_to_vertex_ai_batch_request(
cls,
request: CreateBatchRequest,
) -> VertexAIBatchPredictionJob:
"""
Transforms OpenAI Batch requests to Vertex AI Batch requests
"""
request_display_name = f"litellm-vertex-batch-{uuid.uuid4()}"
input_file_id = request.get("input_file_id")
if input_file_id is None:
raise ValueError("input_file_id is required, but not provided")
input_config: InputConfig = InputConfig(
gcsSource=GcsSource(uris=input_file_id), instancesFormat="jsonl"
)
model: str = cls._get_model_from_gcs_file(input_file_id)
output_config: OutputConfig = OutputConfig(
predictionsFormat="jsonl",
gcsDestination=GcsDestination(
outputUriPrefix=cls._get_gcs_uri_prefix_from_file(input_file_id)
),
)
return VertexAIBatchPredictionJob(
inputConfig=input_config,
outputConfig=output_config,
model=model,
displayName=request_display_name,
)
@classmethod
def transform_vertex_ai_batch_response_to_openai_batch_response(
cls, response: VertexBatchPredictionResponse
) -> Batch:
return Batch(
id=response.get("name", ""),
completion_window="24hrs",
created_at=_convert_vertex_datetime_to_openai_datetime(
vertex_datetime=response.get("createTime", "")
),
endpoint="",
input_file_id=cls._get_input_file_id_from_vertex_ai_batch_response(
response
),
object="batch",
status=cls._get_batch_job_status_from_vertex_ai_batch_response(response),
error_file_id=None, # Vertex AI doesn't seem to have a direct equivalent
output_file_id=cls._get_output_file_id_from_vertex_ai_batch_response(
response
),
)
@classmethod
def _get_input_file_id_from_vertex_ai_batch_response(
cls, response: VertexBatchPredictionResponse
) -> str:
"""
Gets the input file id from the Vertex AI Batch response
"""
input_file_id: str = ""
input_config = response.get("inputConfig")
if input_config is None:
return input_file_id
gcs_source = input_config.get("gcsSource")
if gcs_source is None:
return input_file_id
uris = gcs_source.get("uris", "")
if len(uris) == 0:
return input_file_id
return uris[0]
@classmethod
def _get_output_file_id_from_vertex_ai_batch_response(
cls, response: VertexBatchPredictionResponse
) -> str:
"""
Gets the output file id from the Vertex AI Batch response
"""
output_file_id: str = ""
output_config = response.get("outputConfig")
if output_config is None:
return output_file_id
gcs_destination = output_config.get("gcsDestination")
if gcs_destination is None:
return output_file_id
output_uri_prefix = gcs_destination.get("outputUriPrefix", "")
return output_uri_prefix
@classmethod
def _get_batch_job_status_from_vertex_ai_batch_response(
cls, response: VertexBatchPredictionResponse
) -> BatchJobStatus:
"""
Gets the batch job status from the Vertex AI Batch response
ref: https://cloud.google.com/vertex-ai/docs/reference/rest/v1/JobState
"""
state_mapping: Dict[str, BatchJobStatus] = {
"JOB_STATE_UNSPECIFIED": "failed",
"JOB_STATE_QUEUED": "validating",
"JOB_STATE_PENDING": "validating",
"JOB_STATE_RUNNING": "in_progress",
"JOB_STATE_SUCCEEDED": "completed",
"JOB_STATE_FAILED": "failed",
"JOB_STATE_CANCELLING": "cancelling",
"JOB_STATE_CANCELLED": "cancelled",
"JOB_STATE_PAUSED": "in_progress",
"JOB_STATE_EXPIRED": "expired",
"JOB_STATE_UPDATING": "in_progress",
"JOB_STATE_PARTIALLY_SUCCEEDED": "completed",
}
vertex_state = response.get("state", "JOB_STATE_UNSPECIFIED")
return state_mapping[vertex_state]
@classmethod
def _get_gcs_uri_prefix_from_file(cls, input_file_id: str) -> str:
"""
Gets the gcs uri prefix from the input file id
Example:
input_file_id: "gs://litellm-testing-bucket/vtx_batch.jsonl"
returns: "gs://litellm-testing-bucket"
input_file_id: "gs://litellm-testing-bucket/batches/vtx_batch.jsonl"
returns: "gs://litellm-testing-bucket/batches"
"""
# Split the path and remove the filename
path_parts = input_file_id.rsplit("/", 1)
return path_parts[0]
@classmethod
def _get_model_from_gcs_file(cls, gcs_file_uri: str) -> str:
"""
Extracts the model from the gcs file uri
When files are uploaded using LiteLLM (/v1/files), the model is stored in the gcs file uri
Why?
- Because Vertex requires the `model` param in the create batch job request, while OpenAI does not
gcs_file_uri format: gs://litellm-testing-bucket/litellm-vertex-files/publishers/google/models/gemini-1.5-flash-001/e9412502-2c91-42a6-8e61-f5c294cc0fc8
returns: "publishers/google/models/gemini-1.5-flash-001"
"""
from urllib.parse import unquote
decoded_uri = unquote(gcs_file_uri)
model_path = decoded_uri.split("publishers/")[1]
parts = model_path.split("/")
model = f"publishers/{'/'.join(parts[:3])}"
return model
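For reference, a short sketch of what this transformation produces for a LiteLLM-uploaded file (the gs:// id is reused from the docs example and purely illustrative):

```python
from litellm.llms.vertex_ai_and_google_ai_studio.batches.transformation import (
    VertexAIBatchTransformation,
)
from litellm.types.llms.openai import CreateBatchRequest

request = CreateBatchRequest(
    input_file_id="gs://litellm-testing-bucket/litellm-vertex-files/publishers/google/models/gemini-1.5-flash-001/d3f198cd-c0d1-436d-9b1e-28e3f282997a",
    endpoint="/v1/chat/completions",
    completion_window="24h",
)
vertex_job = VertexAIBatchTransformation.transform_openai_batch_request_to_vertex_ai_batch_request(
    request=request
)
# vertex_job now contains:
#   model       -> "publishers/google/models/gemini-1.5-flash-001" (parsed from the gs:// path)
#   inputConfig -> {"instancesFormat": "jsonl", "gcsSource": {"uris": "<input_file_id>"}}
#   outputConfig-> {"predictionsFormat": "jsonl", "gcsDestination": {"outputUriPrefix": "gs://.../gemini-1.5-flash-001"}}
#   displayName -> "litellm-vertex-batch-<uuid4>"
```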


@ -264,3 +264,18 @@ def strip_field(schema, field_name: str):
items = schema.get("items", None)
if items is not None:
strip_field(items, field_name)
def _convert_vertex_datetime_to_openai_datetime(vertex_datetime: str) -> int:
"""
Converts a Vertex AI datetime string to an OpenAI datetime integer
vertex_datetime: str = "2024-12-04T21:53:12.120184Z"
returns: int = 1722729192
"""
from datetime import datetime
# Parse the ISO format string to datetime object
dt = datetime.strptime(vertex_datetime, "%Y-%m-%dT%H:%M:%S.%fZ")
# Convert to Unix timestamp (seconds since epoch)
return int(dt.timestamp())


@ -0,0 +1,111 @@
import json
import uuid
from typing import Any, Coroutine, Dict, Optional, Union
import httpx
import litellm
from litellm.integrations.gcs_bucket.gcs_bucket_base import (
GCSBucketBase,
GCSLoggingConfig,
)
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
HTTPHandler,
_get_httpx_client,
get_async_httpx_client,
)
from litellm.llms.vertex_ai_and_google_ai_studio.common_utils import (
_convert_vertex_datetime_to_openai_datetime,
)
from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
VertexAIError,
VertexLLM,
)
from litellm.types.llms.openai import (
Batch,
CreateFileRequest,
FileContentRequest,
FileObject,
FileTypes,
HttpxBinaryResponseContent,
)
from .transformation import VertexAIFilesTransformation
vertex_ai_files_transformation = VertexAIFilesTransformation()
class VertexAIFilesHandler(GCSBucketBase):
"""
Handles Calling VertexAI in OpenAI Files API format v1/files/*
This implementation uploads files on GCS Buckets
"""
pass
async def async_create_file(
self,
create_file_data: CreateFileRequest,
api_base: Optional[str],
vertex_credentials: Optional[str],
vertex_project: Optional[str],
vertex_location: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
):
gcs_logging_config: GCSLoggingConfig = await self.get_gcs_logging_config(
kwargs={}
)
headers = await self.construct_request_headers(
vertex_instance=gcs_logging_config["vertex_instance"],
service_account_json=gcs_logging_config["path_service_account"],
)
bucket_name = gcs_logging_config["bucket_name"]
logging_payload, object_name = (
vertex_ai_files_transformation.transform_openai_file_content_to_vertex_ai_file_content(
openai_file_content=create_file_data.get("file")
)
)
gcs_upload_response = await self._log_json_data_on_gcs(
headers=headers,
bucket_name=bucket_name,
object_name=object_name,
logging_payload=logging_payload,
)
return vertex_ai_files_transformation.transform_gcs_bucket_response_to_openai_file_object(
create_file_data=create_file_data,
gcs_upload_response=gcs_upload_response,
)
def create_file(
self,
_is_async: bool,
create_file_data: CreateFileRequest,
api_base: Optional[str],
vertex_credentials: Optional[str],
vertex_project: Optional[str],
vertex_location: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
) -> Union[FileObject, Coroutine[Any, Any, FileObject]]:
"""
Creates a file on VertexAI GCS Bucket
Only supported for Async litellm.acreate_file
"""
if _is_async:
return self.async_create_file(
create_file_data=create_file_data,
api_base=api_base,
vertex_credentials=vertex_credentials,
vertex_project=vertex_project,
vertex_location=vertex_location,
timeout=timeout,
max_retries=max_retries,
)
return None # type: ignore


@ -0,0 +1,173 @@
import json
import uuid
from typing import Any, Dict, List, Optional, Tuple, Union
from litellm.llms.vertex_ai_and_google_ai_studio.common_utils import (
_convert_vertex_datetime_to_openai_datetime,
)
from litellm.llms.vertex_ai_and_google_ai_studio.gemini.transformation import (
_transform_request_body,
)
from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
VertexGeminiConfig,
)
from litellm.types.llms.openai import (
Batch,
CreateFileRequest,
FileContentRequest,
FileObject,
FileTypes,
HttpxBinaryResponseContent,
PathLike,
)
class VertexAIFilesTransformation(VertexGeminiConfig):
"""
Transforms OpenAI /v1/files/* requests to VertexAI /v1/files/* requests
"""
def transform_openai_file_content_to_vertex_ai_file_content(
self, openai_file_content: Optional[FileTypes] = None
) -> Tuple[str, str]:
"""
Transforms OpenAI FileContentRequest to VertexAI FileContentRequest
"""
if openai_file_content is None:
raise ValueError("contents of file are None")
# Read the content of the file
file_content = self._get_content_from_openai_file(openai_file_content)
# Split into lines and parse each line as JSON
openai_jsonl_content = [
json.loads(line) for line in file_content.splitlines() if line.strip()
]
vertex_jsonl_content = (
self._transform_openai_jsonl_content_to_vertex_ai_jsonl_content(
openai_jsonl_content
)
)
vertex_jsonl_string = "\n".join(
json.dumps(item) for item in vertex_jsonl_content
)
object_name = self._get_gcs_object_name(
openai_jsonl_content=openai_jsonl_content
)
return vertex_jsonl_string, object_name
def _transform_openai_jsonl_content_to_vertex_ai_jsonl_content(
self, openai_jsonl_content: List[Dict[str, Any]]
):
"""
Transforms OpenAI JSONL content to VertexAI JSONL content
jsonl body for vertex is {"request": <request_body>}
Example Vertex jsonl
{"request":{"contents": [{"role": "user", "parts": [{"text": "What is the relation between the following video and image samples?"}, {"fileData": {"fileUri": "gs://cloud-samples-data/generative-ai/video/animals.mp4", "mimeType": "video/mp4"}}, {"fileData": {"fileUri": "gs://cloud-samples-data/generative-ai/image/cricket.jpeg", "mimeType": "image/jpeg"}}]}]}}
{"request":{"contents": [{"role": "user", "parts": [{"text": "Describe what is happening in this video."}, {"fileData": {"fileUri": "gs://cloud-samples-data/generative-ai/video/another_video.mov", "mimeType": "video/mov"}}]}]}}
"""
vertex_jsonl_content = []
for _openai_jsonl_content in openai_jsonl_content:
openai_request_body = _openai_jsonl_content.get("body") or {}
vertex_request_body = _transform_request_body(
messages=openai_request_body.get("messages", []),
model=openai_request_body.get("model", ""),
optional_params=self._map_openai_to_vertex_params(openai_request_body),
custom_llm_provider="vertex_ai",
litellm_params={},
cached_content=None,
)
vertex_jsonl_content.append({"request": vertex_request_body})
return vertex_jsonl_content
def _get_gcs_object_name(
self,
openai_jsonl_content: List[Dict[str, Any]],
) -> str:
"""
Gets a unique GCS object name for the VertexAI batch prediction job
named as: litellm-vertex-files/{model}/{uuid}
"""
_model = openai_jsonl_content[0].get("body", {}).get("model", "")
if "publishers/google/models" not in _model:
_model = f"publishers/google/models/{_model}"
object_name = f"litellm-vertex-files/{_model}/{uuid.uuid4()}"
return object_name
def _map_openai_to_vertex_params(
self,
openai_request_body: Dict[str, Any],
) -> Dict[str, Any]:
"""
wrapper to call VertexGeminiConfig.map_openai_params
"""
_model = openai_request_body.get("model", "")
vertex_params = self.map_openai_params(
model=_model,
non_default_params=openai_request_body,
optional_params={},
drop_params=False,
)
return vertex_params
def _get_content_from_openai_file(self, openai_file_content: FileTypes) -> str:
"""
Helper to extract content from various OpenAI file types and return as string.
Handles:
- Direct content (str, bytes, IO[bytes])
- Tuple formats: (filename, content, [content_type], [headers])
- PathLike objects
"""
content: Union[str, bytes] = b""
# Extract file content from tuple if necessary
if isinstance(openai_file_content, tuple):
# Take the second element which is always the file content
file_content = openai_file_content[1]
else:
file_content = openai_file_content
# Handle different file content types
if isinstance(file_content, str):
# String content can be used directly
content = file_content
elif isinstance(file_content, bytes):
# Bytes content is decoded to a string below
content = file_content
elif isinstance(file_content, PathLike): # PathLike
with open(str(file_content), "rb") as f:
content = f.read()
elif hasattr(file_content, "read"): # IO[bytes]
# File-like objects need to be read
content = file_content.read()
# Ensure content is string
if isinstance(content, bytes):
content = content.decode("utf-8")
return content
def transform_gcs_bucket_response_to_openai_file_object(
self, create_file_data: CreateFileRequest, gcs_upload_response: Dict[str, Any]
) -> FileObject:
"""
Transforms GCS Bucket upload file response to OpenAI FileObject
"""
gcs_id = gcs_upload_response.get("id", "")
# Remove the last numeric ID from the path
gcs_id = "/".join(gcs_id.split("/")[:-1]) if gcs_id else ""
return FileObject(
purpose=create_file_data.get("purpose", "batch"),
id=f"gs://{gcs_id}",
filename=gcs_upload_response.get("name", ""),
created_at=_convert_vertex_datetime_to_openai_datetime(
vertex_datetime=gcs_upload_response.get("timeCreated", "")
),
status="uploaded",
bytes=gcs_upload_response.get("size", 0),
object="file",
)


@ -59,6 +59,8 @@ def get_files_provider_config(
custom_llm_provider: str,
):
global files_config
if custom_llm_provider == "vertex_ai":
return None
if files_config is None:
raise ValueError("files_config is not set, set it on your config.yaml file.")
for setting in files_config:
@ -212,9 +214,9 @@ async def create_file(
if llm_provider_config is not None:
# add llm_provider_config to data
_create_file_request.update(llm_provider_config)
_create_file_request.pop("custom_llm_provider", None) # type: ignore
# for now use custom_llm_provider=="openai" -> this will change as LiteLLM adds more providers for acreate_batch
response = await litellm.acreate_file(**_create_file_request) # type: ignore
response = await litellm.acreate_file(**_create_file_request, custom_llm_provider=custom_llm_provider) # type: ignore
### ALERTING ###
asyncio.create_task(
@ -239,7 +241,6 @@ async def create_file(
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
)
)
return response
except Exception as e:
await proxy_logging_obj.post_call_failure_hook(


@ -5225,6 +5225,7 @@ async def create_batch(
is_router_model = is_known_model(model=router_model, llm_router=llm_router)
_create_batch_data = CreateBatchRequest(**data)
custom_llm_provider = provider or _create_batch_data.pop("custom_llm_provider", None) # type: ignore
if (
litellm.enable_loadbalancing_on_batch_endpoints is True
@ -5241,10 +5242,10 @@ async def create_batch(
response = await llm_router.acreate_batch(**_create_batch_data) # type: ignore
else:
if provider is None:
provider = "openai"
if custom_llm_provider is None:
custom_llm_provider = "openai"
response = await litellm.acreate_batch(
custom_llm_provider=provider, **_create_batch_data # type: ignore
custom_llm_provider=custom_llm_provider, **_create_batch_data # type: ignore
)
### ALERTING ###


@ -301,6 +301,18 @@ class ListBatchRequest(TypedDict, total=False):
timeout: Optional[float]
BatchJobStatus = Literal[
"validating",
"failed",
"in_progress",
"finalizing",
"completed",
"expired",
"cancelling",
"cancelled",
]
class ChatCompletionAudioDelta(TypedDict, total=False):
data: str
transcript: str


@ -434,3 +434,43 @@ class VertexAIBatchEmbeddingsRequestBody(TypedDict, total=False):
class VertexAIBatchEmbeddingsResponseObject(TypedDict):
embeddings: List[ContentEmbeddings]
# Vertex AI Batch Prediction
class GcsSource(TypedDict):
uris: str
class InputConfig(TypedDict):
instancesFormat: str
gcsSource: GcsSource
class GcsDestination(TypedDict):
outputUriPrefix: str
class OutputConfig(TypedDict, total=False):
predictionsFormat: str
gcsDestination: GcsDestination
class VertexAIBatchPredictionJob(TypedDict):
displayName: str
model: str
inputConfig: InputConfig
outputConfig: OutputConfig
class VertexBatchPredictionResponse(TypedDict, total=False):
name: str
displayName: str
model: str
inputConfig: InputConfig
outputConfig: OutputConfig
state: str
createTime: str
updateTime: str
modelVersionId: str


@ -11,8 +11,7 @@ from dotenv import load_dotenv
load_dotenv()
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import asyncio
) # Adds the parent directory to the system-path
import logging
import time
@ -20,6 +19,10 @@ import pytest
import litellm
from litellm import create_batch, create_file
from litellm._logging import verbose_logger
from test_gcs_bucket import load_vertex_ai_credentials
verbose_logger.setLevel(logging.DEBUG)
@pytest.mark.parametrize("provider", ["openai"]) # , "azure"
@ -206,3 +209,32 @@ def test_cancel_batch():
def test_list_batch():
pass
@pytest.mark.asyncio
async def test_vertex_batch_prediction():
load_vertex_ai_credentials()
file_name = "vertex_batch_completions.jsonl"
_current_dir = os.path.dirname(os.path.abspath(__file__))
file_path = os.path.join(_current_dir, file_name)
file_obj = await litellm.acreate_file(
file=open(file_path, "rb"),
purpose="batch",
custom_llm_provider="vertex_ai",
)
print("Response from creating file=", file_obj)
batch_input_file_id = file_obj.id
assert (
batch_input_file_id is not None
), f"Failed to create file, expected a non null file_id but got {batch_input_file_id}"
create_batch_response = await litellm.acreate_batch(
completion_window="24h",
endpoint="/v1/chat/completions",
input_file_id=batch_input_file_id,
custom_llm_provider="vertex_ai",
metadata={"key1": "value1", "key2": "value2"},
)
print("create_batch_response=", create_batch_response)
pass


@ -0,0 +1,2 @@
{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gemini-1.5-flash-001", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 10}}
{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gemini-1.5-flash-001", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 10}}


@ -85,3 +85,37 @@ async def test_batches_operations():
# Test delete file
await delete_file(session, file_id)
@pytest.mark.skip(reason="Local only test to verify if things work well")
def test_vertex_batches_endpoint():
"""
Test VertexAI Batches Endpoint
"""
import os
oai_client = OpenAI(api_key=API_KEY, base_url=BASE_URL)
file_name = "local_testing/vertex_batch_completions.jsonl"
_current_dir = os.path.dirname(os.path.abspath(__file__))
file_path = os.path.join(_current_dir, file_name)
file_obj = oai_client.files.create(
file=open(file_path, "rb"),
purpose="batch",
extra_body={"custom_llm_provider": "vertex_ai"},
)
print("Response from creating file=", file_obj)
batch_input_file_id = file_obj.id
assert (
batch_input_file_id is not None
), f"Failed to create file, expected a non null file_id but got {batch_input_file_id}"
create_batch_response = oai_client.batches.create(
completion_window="24h",
endpoint="/v1/chat/completions",
input_file_id=batch_input_file_id,
extra_body={"custom_llm_provider": "vertex_ai"},
metadata={"key1": "value1", "key2": "value2"},
)
print("response from create batch", create_batch_response)
pass