mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 11:14:04 +00:00
(feat) add Vertex Batches API support in OpenAI format (#7032)
* working request
* working transform
* working request
* transform vertex batch response
* add _async_create_batch
* move gcs functions to base
* fix _get_content_from_openai_file
* transform_openai_file_content_to_vertex_ai_file_content
* fix transform vertex gcs bucket upload to OAI files format
* working e2e test
* _get_gcs_object_name
* fix linting
* add doc string
* fix transform_gcs_bucket_response_to_openai_file_object
* use vertex for batch endpoints
* add batches support for vertex
* test_vertex_batches_endpoint
* test_vertex_batch_prediction
* fix gcs bucket base auth
* docs clean up batches
* docs Batch API
* docs vertex batches api
* test_get_gcs_logging_config_without_service_account
* undo change
* fix vertex md
* test_get_gcs_logging_config_without_service_account
* ci/cd run again
This commit is contained in:
parent dd5ccdd889
commit 0eef9df396
20 changed files with 1347 additions and 424 deletions
@ -6,8 +6,9 @@ import TabItem from '@theme/TabItem';

 Covers Batches, Files

 ## **Supported Providers**:

-- Azure OpenAI
+- **[Azure OpenAI](./providers/azure#azure-batches-api)**
 - OpenAI
+- **[Vertex AI](./providers/vertex#batch-apis)**

 ## Quick Start
@ -141,182 +142,4 @@ print("list_batches_response=", list_batches_response)

 </Tabs>

-## [👉 Proxy API Reference](https://litellm-api.up.railway.app/#/batch)
+## [Swagger API Reference](https://litellm-api.up.railway.app/#/batch)

## Azure Batches API

Just add the azure env vars to your environment.

```bash
export AZURE_API_KEY=""
export AZURE_API_BASE=""
```

AND use `/azure/*` for the Batches API calls

```bash
http://0.0.0.0:4000/azure/v1/batches
```

### Usage

**Setup**

- Add Azure API Keys to your environment

#### 1. Upload a File

```bash
curl http://localhost:4000/azure/v1/files \
    -H "Authorization: Bearer sk-1234" \
    -F purpose="batch" \
    -F file="@mydata.jsonl"
```

**Example File**

Note: `model` should be your azure deployment name.

```json
{"custom_id": "task-0", "method": "POST", "url": "/chat/completions", "body": {"model": "REPLACE-WITH-MODEL-DEPLOYMENT-NAME", "messages": [{"role": "system", "content": "You are an AI assistant that helps people find information."}, {"role": "user", "content": "When was Microsoft founded?"}]}}
{"custom_id": "task-1", "method": "POST", "url": "/chat/completions", "body": {"model": "REPLACE-WITH-MODEL-DEPLOYMENT-NAME", "messages": [{"role": "system", "content": "You are an AI assistant that helps people find information."}, {"role": "user", "content": "When was the first XBOX released?"}]}}
{"custom_id": "task-2", "method": "POST", "url": "/chat/completions", "body": {"model": "REPLACE-WITH-MODEL-DEPLOYMENT-NAME", "messages": [{"role": "system", "content": "You are an AI assistant that helps people find information."}, {"role": "user", "content": "What is Altair Basic?"}]}}
```

#### 2. Create a batch

```bash
curl http://0.0.0.0:4000/azure/v1/batches \
    -H "Authorization: Bearer $LITELLM_API_KEY" \
    -H "Content-Type: application/json" \
    -d '{
        "input_file_id": "file-abc123",
        "endpoint": "/v1/chat/completions",
        "completion_window": "24h"
    }'
```

#### 3. Retrieve batch

```bash
curl http://0.0.0.0:4000/azure/v1/batches/batch_abc123 \
    -H "Authorization: Bearer $LITELLM_API_KEY" \
    -H "Content-Type: application/json"
```

#### 4. Cancel batch

```bash
curl http://0.0.0.0:4000/azure/v1/batches/batch_abc123/cancel \
    -H "Authorization: Bearer $LITELLM_API_KEY" \
    -H "Content-Type: application/json" \
    -X POST
```

#### 5. List Batch

```bash
curl http://0.0.0.0:4000/v1/batches?limit=2 \
    -H "Authorization: Bearer $LITELLM_API_KEY" \
    -H "Content-Type: application/json"
```

### [👉 Health Check Azure Batch models](./proxy/health.md#batch-models-azure-only)

### [BETA] Loadbalance Multiple Azure Deployments

In your config.yaml, set `enable_loadbalancing_on_batch_endpoints: true`

```yaml
model_list:
  - model_name: "batch-gpt-4o-mini"
    litellm_params:
      model: "azure/gpt-4o-mini"
      api_key: os.environ/AZURE_API_KEY
      api_base: os.environ/AZURE_API_BASE
    model_info:
      mode: batch

litellm_settings:
  enable_loadbalancing_on_batch_endpoints: true # 👈 KEY CHANGE
```

Note: This works on `{PROXY_BASE_URL}/v1/files` and `{PROXY_BASE_URL}/v1/batches`.

Note: Response is in the OpenAI-format.

1. Upload a file

Just set `model: batch-gpt-4o-mini` in your .jsonl.

```bash
curl http://localhost:4000/v1/files \
    -H "Authorization: Bearer sk-1234" \
    -F purpose="batch" \
    -F file="@mydata.jsonl"
```

**Example File**

Note: `model` should be your azure deployment name.

```json
{"custom_id": "task-0", "method": "POST", "url": "/chat/completions", "body": {"model": "batch-gpt-4o-mini", "messages": [{"role": "system", "content": "You are an AI assistant that helps people find information."}, {"role": "user", "content": "When was Microsoft founded?"}]}}
{"custom_id": "task-1", "method": "POST", "url": "/chat/completions", "body": {"model": "batch-gpt-4o-mini", "messages": [{"role": "system", "content": "You are an AI assistant that helps people find information."}, {"role": "user", "content": "When was the first XBOX released?"}]}}
{"custom_id": "task-2", "method": "POST", "url": "/chat/completions", "body": {"model": "batch-gpt-4o-mini", "messages": [{"role": "system", "content": "You are an AI assistant that helps people find information."}, {"role": "user", "content": "What is Altair Basic?"}]}}
```

Expected Response (OpenAI-compatible)

```bash
{"id":"file-f0be81f654454113a922da60acb0eea6",...}
```

2. Create a batch

```bash
curl http://0.0.0.0:4000/v1/batches \
    -H "Authorization: Bearer $LITELLM_API_KEY" \
    -H "Content-Type: application/json" \
    -d '{
        "input_file_id": "file-f0be81f654454113a922da60acb0eea6",
        "endpoint": "/v1/chat/completions",
        "completion_window": "24h",
        "model": "batch-gpt-4o-mini"
    }'
```

Expected Response:

```bash
{"id":"batch_94e43f0a-d805-477d-adf9-bbb9c50910ed",...}
```

3. Retrieve a batch

```bash
curl http://0.0.0.0:4000/v1/batches/batch_94e43f0a-d805-477d-adf9-bbb9c50910ed \
    -H "Authorization: Bearer $LITELLM_API_KEY" \
    -H "Content-Type: application/json"
```

Expected Response:

```
{"id":"batch_94e43f0a-d805-477d-adf9-bbb9c50910ed",...}
```

4. List batch

```bash
curl http://0.0.0.0:4000/v1/batches?limit=2 \
    -H "Authorization: Bearer $LITELLM_API_KEY" \
    -H "Content-Type: application/json"
```

Expected Response:

```bash
{"data":[{"id":"batch_R3V...}
```
@ -559,6 +559,185 @@ litellm_settings:

 </Tabs>

## **Azure Batches API**

Just add the azure env vars to your environment.

```bash
export AZURE_API_KEY=""
export AZURE_API_BASE=""
```

AND use `/azure/*` for the Batches API calls

```bash
http://0.0.0.0:4000/azure/v1/batches
```

### Usage

**Setup**

- Add Azure API Keys to your environment

#### 1. Upload a File

```bash
curl http://localhost:4000/azure/v1/files \
    -H "Authorization: Bearer sk-1234" \
    -F purpose="batch" \
    -F file="@mydata.jsonl"
```

**Example File**

Note: `model` should be your azure deployment name.

```json
{"custom_id": "task-0", "method": "POST", "url": "/chat/completions", "body": {"model": "REPLACE-WITH-MODEL-DEPLOYMENT-NAME", "messages": [{"role": "system", "content": "You are an AI assistant that helps people find information."}, {"role": "user", "content": "When was Microsoft founded?"}]}}
{"custom_id": "task-1", "method": "POST", "url": "/chat/completions", "body": {"model": "REPLACE-WITH-MODEL-DEPLOYMENT-NAME", "messages": [{"role": "system", "content": "You are an AI assistant that helps people find information."}, {"role": "user", "content": "When was the first XBOX released?"}]}}
{"custom_id": "task-2", "method": "POST", "url": "/chat/completions", "body": {"model": "REPLACE-WITH-MODEL-DEPLOYMENT-NAME", "messages": [{"role": "system", "content": "You are an AI assistant that helps people find information."}, {"role": "user", "content": "What is Altair Basic?"}]}}
```

#### 2. Create a batch

```bash
curl http://0.0.0.0:4000/azure/v1/batches \
    -H "Authorization: Bearer $LITELLM_API_KEY" \
    -H "Content-Type: application/json" \
    -d '{
        "input_file_id": "file-abc123",
        "endpoint": "/v1/chat/completions",
        "completion_window": "24h"
    }'
```

#### 3. Retrieve batch

```bash
curl http://0.0.0.0:4000/azure/v1/batches/batch_abc123 \
    -H "Authorization: Bearer $LITELLM_API_KEY" \
    -H "Content-Type: application/json"
```

#### 4. Cancel batch

```bash
curl http://0.0.0.0:4000/azure/v1/batches/batch_abc123/cancel \
    -H "Authorization: Bearer $LITELLM_API_KEY" \
    -H "Content-Type: application/json" \
    -X POST
```

#### 5. List Batch

```bash
curl http://0.0.0.0:4000/v1/batches?limit=2 \
    -H "Authorization: Bearer $LITELLM_API_KEY" \
    -H "Content-Type: application/json"
```

### [Health Check Azure Batch models](./proxy/health.md#batch-models-azure-only)

### [BETA] Loadbalance Multiple Azure Deployments

In your config.yaml, set `enable_loadbalancing_on_batch_endpoints: true`

```yaml
model_list:
  - model_name: "batch-gpt-4o-mini"
    litellm_params:
      model: "azure/gpt-4o-mini"
      api_key: os.environ/AZURE_API_KEY
      api_base: os.environ/AZURE_API_BASE
    model_info:
      mode: batch

litellm_settings:
  enable_loadbalancing_on_batch_endpoints: true # 👈 KEY CHANGE
```

Note: This works on `{PROXY_BASE_URL}/v1/files` and `{PROXY_BASE_URL}/v1/batches`.

Note: Response is in the OpenAI-format.

1. Upload a file

Just set `model: batch-gpt-4o-mini` in your .jsonl.

```bash
curl http://localhost:4000/v1/files \
    -H "Authorization: Bearer sk-1234" \
    -F purpose="batch" \
    -F file="@mydata.jsonl"
```

**Example File**

Note: `model` should be your azure deployment name.

```json
{"custom_id": "task-0", "method": "POST", "url": "/chat/completions", "body": {"model": "batch-gpt-4o-mini", "messages": [{"role": "system", "content": "You are an AI assistant that helps people find information."}, {"role": "user", "content": "When was Microsoft founded?"}]}}
{"custom_id": "task-1", "method": "POST", "url": "/chat/completions", "body": {"model": "batch-gpt-4o-mini", "messages": [{"role": "system", "content": "You are an AI assistant that helps people find information."}, {"role": "user", "content": "When was the first XBOX released?"}]}}
{"custom_id": "task-2", "method": "POST", "url": "/chat/completions", "body": {"model": "batch-gpt-4o-mini", "messages": [{"role": "system", "content": "You are an AI assistant that helps people find information."}, {"role": "user", "content": "What is Altair Basic?"}]}}
```

Expected Response (OpenAI-compatible)

```bash
{"id":"file-f0be81f654454113a922da60acb0eea6",...}
```

2. Create a batch

```bash
curl http://0.0.0.0:4000/v1/batches \
    -H "Authorization: Bearer $LITELLM_API_KEY" \
    -H "Content-Type: application/json" \
    -d '{
        "input_file_id": "file-f0be81f654454113a922da60acb0eea6",
        "endpoint": "/v1/chat/completions",
        "completion_window": "24h",
        "model": "batch-gpt-4o-mini"
    }'
```

Expected Response:

```bash
{"id":"batch_94e43f0a-d805-477d-adf9-bbb9c50910ed",...}
```

3. Retrieve a batch

```bash
curl http://0.0.0.0:4000/v1/batches/batch_94e43f0a-d805-477d-adf9-bbb9c50910ed \
    -H "Authorization: Bearer $LITELLM_API_KEY" \
    -H "Content-Type: application/json"
```

Expected Response:

```
{"id":"batch_94e43f0a-d805-477d-adf9-bbb9c50910ed",...}
```

4. List batch

```bash
curl http://0.0.0.0:4000/v1/batches?limit=2 \
    -H "Authorization: Bearer $LITELLM_API_KEY" \
    -H "Content-Type: application/json"
```

Expected Response:

```bash
{"data":[{"id":"batch_R3V...}
```

 ## Advanced

 ### Azure API Load-Balancing
@ -2393,6 +2393,114 @@ print("response from proxy", response)

 </TabItem>
 </Tabs>

## **Batch APIs**

Just add the following Vertex env vars to your environment.

```bash
# GCS Bucket settings, used to store batch prediction files in
export GCS_BUCKET_NAME="litellm-testing-bucket" # the bucket you want to store batch prediction files in
export GCS_PATH_SERVICE_ACCOUNT="/path/to/service_account.json" # path to your service account json file

# Vertex /batch endpoint settings, used for LLM API requests
export GOOGLE_APPLICATION_CREDENTIALS="/path/to/service_account.json" # path to your service account json file
export VERTEXAI_LOCATION="us-central1" # can be any vertex location
export VERTEXAI_PROJECT="my-test-project"
```

### Usage

#### 1. Create a file of batch requests for vertex

LiteLLM expects the file to follow the **[OpenAI batches files format](https://platform.openai.com/docs/guides/batch)**.

Each `body` in the file should be an **OpenAI API request**.

Create a file called `vertex_batch_completions.jsonl` in the current working directory; the `model` should be the Vertex AI model name.

```
{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gemini-1.5-flash-001", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 10}}
{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gemini-1.5-flash-001", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 10}}
```

#### 2. Upload a File of batch requests

For `vertex_ai`, litellm will upload the file to the provided `GCS_BUCKET_NAME`.

```python
import os
from openai import OpenAI

oai_client = OpenAI(
    api_key="sk-1234",               # litellm proxy API key
    base_url="http://localhost:4000" # litellm proxy base url
)

file_name = "vertex_batch_completions.jsonl"
_current_dir = os.path.dirname(os.path.abspath(__file__))
file_path = os.path.join(_current_dir, file_name)
file_obj = oai_client.files.create(
    file=open(file_path, "rb"),
    purpose="batch",
    extra_body={"custom_llm_provider": "vertex_ai"}, # tell litellm to use vertex_ai for this file upload
)
```

**Expected Response**

```json
{
    "id": "gs://litellm-testing-bucket/litellm-vertex-files/publishers/google/models/gemini-1.5-flash-001/d3f198cd-c0d1-436d-9b1e-28e3f282997a",
    "bytes": 416,
    "created_at": 1733392026,
    "filename": "litellm-vertex-files/publishers/google/models/gemini-1.5-flash-001/d3f198cd-c0d1-436d-9b1e-28e3f282997a",
    "object": "file",
    "purpose": "batch",
    "status": "uploaded",
    "status_details": null
}
```

#### 3. Create a batch

```python
batch_input_file_id = file_obj.id # use `file_obj` from step 2
create_batch_response = oai_client.batches.create(
    completion_window="24h",
    endpoint="/v1/chat/completions",
    input_file_id=batch_input_file_id, # example input_file_id = "gs://litellm-testing-bucket/litellm-vertex-files/publishers/google/models/gemini-1.5-flash-001/c2b1b785-252b-448c-b180-033c4c63b3ce"
    extra_body={"custom_llm_provider": "vertex_ai"}, # tell litellm to use `vertex_ai` for this batch request
)
```

**Expected Response**

```json
{
    "id": "projects/633608382793/locations/us-central1/batchPredictionJobs/986266568679751680",
    "completion_window": "24hrs",
    "created_at": 1733392026,
    "endpoint": "",
    "input_file_id": "gs://litellm-testing-bucket/litellm-vertex-files/publishers/google/models/gemini-1.5-flash-001/d3f198cd-c0d1-436d-9b1e-28e3f282997a",
    "object": "batch",
    "status": "validating",
    "cancelled_at": null,
    "cancelling_at": null,
    "completed_at": null,
    "error_file_id": null,
    "errors": null,
    "expired_at": null,
    "expires_at": null,
    "failed_at": null,
    "finalizing_at": null,
    "in_progress_at": null,
    "metadata": null,
    "output_file_id": "gs://litellm-testing-bucket/litellm-vertex-files/publishers/google/models/gemini-1.5-flash-001",
    "request_counts": null
}
```
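Note that `output_file_id` here is a GCS prefix rather than an OpenAI-style file id: Vertex writes the batch results under that prefix once the job completes. A minimal way to check for output, assuming the Google Cloud SDK (`gsutil`) is installed and authenticated for the same project:

```bash
# list the prediction output Vertex writes under the job's output prefix
# (the `output_file_id` from the response above)
gsutil ls gs://litellm-testing-bucket/litellm-vertex-files/publishers/google/models/gemini-1.5-flash-001/
```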
 ## Extra

 ### Using `GOOGLE_APPLICATION_CREDENTIALS`
@ -22,6 +22,9 @@ import litellm
 from litellm import client
 from litellm.llms.AzureOpenAI.azure import AzureBatchesAPI
 from litellm.llms.OpenAI.openai import OpenAIBatchesAPI
+from litellm.llms.vertex_ai_and_google_ai_studio.batches.handler import (
+    VertexAIBatchPrediction,
+)
 from litellm.secret_managers.main import get_secret, get_secret_str
 from litellm.types.llms.openai import (
     Batch,

@ -40,6 +43,7 @@ from litellm.utils import supports_httpx_timeout
 ####### ENVIRONMENT VARIABLES ###################
 openai_batches_instance = OpenAIBatchesAPI()
 azure_batches_instance = AzureBatchesAPI()
+vertex_ai_batches_instance = VertexAIBatchPrediction(gcs_bucket_name="")
 #################################################

@ -47,7 +51,7 @@ async def acreate_batch(
     completion_window: Literal["24h"],
     endpoint: Literal["/v1/chat/completions", "/v1/embeddings", "/v1/completions"],
     input_file_id: str,
-    custom_llm_provider: Literal["openai", "azure"] = "openai",
+    custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
     metadata: Optional[Dict[str, str]] = None,
     extra_headers: Optional[Dict[str, str]] = None,
     extra_body: Optional[Dict[str, str]] = None,

@ -93,7 +97,7 @@ def create_batch(
     completion_window: Literal["24h"],
     endpoint: Literal["/v1/chat/completions", "/v1/embeddings", "/v1/completions"],
     input_file_id: str,
-    custom_llm_provider: Literal["openai", "azure"] = "openai",
+    custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
     metadata: Optional[Dict[str, str]] = None,
     extra_headers: Optional[Dict[str, str]] = None,
     extra_body: Optional[Dict[str, str]] = None,

@ -199,6 +203,32 @@ def create_batch(
             max_retries=optional_params.max_retries,
             create_batch_data=_create_batch_request,
         )
+    elif custom_llm_provider == "vertex_ai":
+        api_base = optional_params.api_base or ""
+        vertex_ai_project = (
+            optional_params.vertex_project
+            or litellm.vertex_project
+            or get_secret_str("VERTEXAI_PROJECT")
+        )
+        vertex_ai_location = (
+            optional_params.vertex_location
+            or litellm.vertex_location
+            or get_secret_str("VERTEXAI_LOCATION")
+        )
+        vertex_credentials = optional_params.vertex_credentials or get_secret_str(
+            "VERTEXAI_CREDENTIALS"
+        )
+
+        response = vertex_ai_batches_instance.create_batch(
+            _is_async=_is_async,
+            api_base=api_base,
+            vertex_project=vertex_ai_project,
+            vertex_location=vertex_ai_location,
+            vertex_credentials=vertex_credentials,
+            timeout=timeout,
+            max_retries=optional_params.max_retries,
+            create_batch_data=_create_batch_request,
+        )
     else:
         raise litellm.exceptions.BadRequestError(
             message="LiteLLM doesn't support {} for 'create_batch'. Only 'openai' is supported.".format(
@ -17,6 +17,9 @@ import litellm
 from litellm import client, get_secret_str
 from litellm.llms.files_apis.azure import AzureOpenAIFilesAPI
 from litellm.llms.OpenAI.openai import FileDeleted, FileObject, OpenAIFilesAPI
+from litellm.llms.vertex_ai_and_google_ai_studio.files.handler import (
+    VertexAIFilesHandler,
+)
 from litellm.types.llms.openai import (
     Batch,
     CreateFileRequest,

@ -30,6 +33,7 @@ from litellm.utils import supports_httpx_timeout
 ####### ENVIRONMENT VARIABLES ###################
 openai_files_instance = OpenAIFilesAPI()
 azure_files_instance = AzureOpenAIFilesAPI()
+vertex_ai_files_instance = VertexAIFilesHandler()
 #################################################

@ -490,7 +494,7 @@ def file_list(
 async def acreate_file(
     file: FileTypes,
     purpose: Literal["assistants", "batch", "fine-tune"],
-    custom_llm_provider: Literal["openai", "azure"] = "openai",
+    custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
     extra_headers: Optional[Dict[str, str]] = None,
     extra_body: Optional[Dict[str, str]] = None,
     **kwargs,

@ -532,7 +536,7 @@ async def acreate_file(
 def create_file(
     file: FileTypes,
     purpose: Literal["assistants", "batch", "fine-tune"],
-    custom_llm_provider: Literal["openai", "azure"] = "openai",
+    custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
     extra_headers: Optional[Dict[str, str]] = None,
     extra_body: Optional[Dict[str, str]] = None,
     **kwargs,

@ -630,6 +634,32 @@ def create_file(
             max_retries=optional_params.max_retries,
             create_file_data=_create_file_request,
         )
+    elif custom_llm_provider == "vertex_ai":
+        api_base = optional_params.api_base or ""
+        vertex_ai_project = (
+            optional_params.vertex_project
+            or litellm.vertex_project
+            or get_secret_str("VERTEXAI_PROJECT")
+        )
+        vertex_ai_location = (
+            optional_params.vertex_location
+            or litellm.vertex_location
+            or get_secret_str("VERTEXAI_LOCATION")
+        )
+        vertex_credentials = optional_params.vertex_credentials or get_secret_str(
+            "VERTEXAI_CREDENTIALS"
+        )
+
+        response = vertex_ai_files_instance.create_file(
+            _is_async=_is_async,
+            api_base=api_base,
+            vertex_project=vertex_ai_project,
+            vertex_location=vertex_ai_location,
+            vertex_credentials=vertex_credentials,
+            timeout=timeout,
+            max_retries=optional_params.max_retries,
+            create_file_data=_create_file_request,
+        )
     else:
         raise litellm.exceptions.BadRequestError(
             message="LiteLLM doesn't support {} for 'create_batch'. Only 'openai' is supported.".format(
@ -29,7 +29,6 @@ else:
     VertexBase = Any

-IAM_AUTH_KEY = "IAM_AUTH"
 GCS_DEFAULT_BATCH_SIZE = 2048
 GCS_DEFAULT_FLUSH_INTERVAL_SECONDS = 20

@ -39,7 +38,6 @@ class GCSBucketLogger(GCSBucketBase):
         from litellm.proxy.proxy_server import premium_user

         super().__init__(bucket_name=bucket_name)
-        self.vertex_instances: Dict[str, VertexBase] = {}

         # Init Batch logging settings
         self.log_queue: List[GCSLogQueueItem] = []

@ -178,232 +176,3 @@ class GCSBucketLogger(GCSBucketBase):
            object_name = _metadata["gcs_log_id"]

        return object_name

    def _handle_folders_in_bucket_name(
        self,
        bucket_name: str,
        object_name: str,
    ) -> Tuple[str, str]:
        """
        Handles when the user passes a bucket name with a folder postfix

        Example:
            - Bucket name: "my-bucket/my-folder/dev"
            - Object name: "my-object"
            - Returns: bucket_name="my-bucket", object_name="my-folder/dev/my-object"
        """
        if "/" in bucket_name:
            bucket_name, prefix = bucket_name.split("/", 1)
            object_name = f"{prefix}/{object_name}"
            return bucket_name, object_name
        return bucket_name, object_name

    async def _log_json_data_on_gcs(
        self,
        headers: Dict[str, str],
        bucket_name: str,
        object_name: str,
        logging_payload: StandardLoggingPayload,
    ):
        """
        Helper function to make POST request to GCS Bucket in the specified bucket.
        """
        json_logged_payload = json.dumps(logging_payload, default=str)

        bucket_name, object_name = self._handle_folders_in_bucket_name(
            bucket_name=bucket_name,
            object_name=object_name,
        )

        response = await self.async_httpx_client.post(
            headers=headers,
            url=f"https://storage.googleapis.com/upload/storage/v1/b/{bucket_name}/o?uploadType=media&name={object_name}",
            data=json_logged_payload,
        )

        if response.status_code != 200:
            verbose_logger.error("GCS Bucket logging error: %s", str(response.text))

        verbose_logger.debug("GCS Bucket response %s", response)
        verbose_logger.debug("GCS Bucket status code %s", response.status_code)
        verbose_logger.debug("GCS Bucket response.text %s", response.text)

    async def get_gcs_logging_config(
        self, kwargs: Optional[Dict[str, Any]] = {}
    ) -> GCSLoggingConfig:
        """
        This function is used to get the GCS logging config for the GCS Bucket Logger.
        It checks if the dynamic parameters are provided in the kwargs and uses them to get the GCS logging config.
        If no dynamic parameters are provided, it uses the default values.
        """
        if kwargs is None:
            kwargs = {}

        standard_callback_dynamic_params: Optional[StandardCallbackDynamicParams] = (
            kwargs.get("standard_callback_dynamic_params", None)
        )

        bucket_name: str
        path_service_account: Optional[str]
        if standard_callback_dynamic_params is not None:
            verbose_logger.debug("Using dynamic GCS logging")
            verbose_logger.debug(
                "standard_callback_dynamic_params: %s", standard_callback_dynamic_params
            )

            _bucket_name: Optional[str] = (
                standard_callback_dynamic_params.get("gcs_bucket_name", None)
                or self.BUCKET_NAME
            )
            _path_service_account: Optional[str] = (
                standard_callback_dynamic_params.get("gcs_path_service_account", None)
                or self.path_service_account_json
            )

            if _bucket_name is None:
                raise ValueError(
                    "GCS_BUCKET_NAME is not set in the environment, but GCS Bucket is being used as a logging callback. Please set 'GCS_BUCKET_NAME' in the environment."
                )
            bucket_name = _bucket_name
            path_service_account = _path_service_account
            vertex_instance = await self.get_or_create_vertex_instance(
                credentials=path_service_account
            )
        else:
            # If no dynamic parameters, use the default instance
            if self.BUCKET_NAME is None:
                raise ValueError(
                    "GCS_BUCKET_NAME is not set in the environment, but GCS Bucket is being used as a logging callback. Please set 'GCS_BUCKET_NAME' in the environment."
                )
            bucket_name = self.BUCKET_NAME
            path_service_account = self.path_service_account_json
            vertex_instance = await self.get_or_create_vertex_instance(
                credentials=path_service_account
            )

        return GCSLoggingConfig(
            bucket_name=bucket_name,
            vertex_instance=vertex_instance,
            path_service_account=path_service_account,
        )

    async def get_or_create_vertex_instance(
        self, credentials: Optional[str]
    ) -> VertexBase:
        """
        This function is used to get the Vertex instance for the GCS Bucket Logger.
        It checks if the Vertex instance is already created and cached, if not it creates a new instance and caches it.
        """
        from litellm.llms.vertex_ai_and_google_ai_studio.vertex_llm_base import (
            VertexBase,
        )

        _in_memory_key = self._get_in_memory_key_for_vertex_instance(credentials)
        if _in_memory_key not in self.vertex_instances:
            vertex_instance = VertexBase()
            await vertex_instance._ensure_access_token_async(
                credentials=credentials,
                project_id=None,
                custom_llm_provider="vertex_ai",
            )
            self.vertex_instances[_in_memory_key] = vertex_instance
        return self.vertex_instances[_in_memory_key]

    def _get_in_memory_key_for_vertex_instance(self, credentials: Optional[str]) -> str:
        """
        Returns key to use for caching the Vertex instance in-memory.

        When using Vertex with Key based logging, we need to cache the Vertex instance in-memory.

        - If a credentials string is provided, it is used as the key.
        - If no credentials string is provided, "IAM_AUTH" is used as the key.
        """
        return credentials or IAM_AUTH_KEY

    async def download_gcs_object(self, object_name: str, **kwargs):
        """
        Download an object from GCS.

        https://cloud.google.com/storage/docs/downloading-objects#download-object-json
        """
        try:
            gcs_logging_config: GCSLoggingConfig = await self.get_gcs_logging_config(
                kwargs=kwargs
            )
            headers = await self.construct_request_headers(
                vertex_instance=gcs_logging_config["vertex_instance"],
                service_account_json=gcs_logging_config["path_service_account"],
            )
            bucket_name = gcs_logging_config["bucket_name"]
            bucket_name, object_name = self._handle_folders_in_bucket_name(
                bucket_name=bucket_name,
                object_name=object_name,
            )

            url = f"https://storage.googleapis.com/storage/v1/b/{bucket_name}/o/{object_name}?alt=media"

            # Send the GET request to download the object
            response = await self.async_httpx_client.get(url=url, headers=headers)

            if response.status_code != 200:
                verbose_logger.error(
                    "GCS object download error: %s", str(response.text)
                )
                return None

            verbose_logger.debug(
                "GCS object download response status code: %s", response.status_code
            )

            # Return the content of the downloaded object
            return response.content

        except Exception as e:
            verbose_logger.error("GCS object download error: %s", str(e))
            return None

    async def delete_gcs_object(self, object_name: str, **kwargs):
        """
        Delete an object from GCS.
        """
        try:
            gcs_logging_config: GCSLoggingConfig = await self.get_gcs_logging_config(
                kwargs=kwargs
            )
            headers = await self.construct_request_headers(
                vertex_instance=gcs_logging_config["vertex_instance"],
                service_account_json=gcs_logging_config["path_service_account"],
            )
            bucket_name = gcs_logging_config["bucket_name"]
            bucket_name, object_name = self._handle_folders_in_bucket_name(
                bucket_name=bucket_name,
                object_name=object_name,
            )

            url = f"https://storage.googleapis.com/storage/v1/b/{bucket_name}/o/{object_name}"

            # Send the DELETE request to delete the object
            response = await self.async_httpx_client.delete(url=url, headers=headers)

            if (response.status_code != 200) or (response.status_code != 204):
                verbose_logger.error(
                    "GCS object delete error: %s, status code: %s",
                    str(response.text),
                    response.status_code,
                )
                return None

            verbose_logger.debug(
                "GCS object delete response status code: %s, response: %s",
                response.status_code,
                response.text,
            )

            # Return the content of the downloaded object
            return response.text

        except Exception as e:
            verbose_logger.error("GCS object download error: %s", str(e))
            return None
@ -2,7 +2,7 @@ import json
 import os
 import uuid
 from datetime import datetime
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, TypedDict, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, TypedDict, Union

 import httpx
 from pydantic import BaseModel, Field

@ -14,11 +14,18 @@ from litellm.llms.custom_httpx.http_handler import (
     get_async_httpx_client,
     httpxSpecialProvider,
 )
+from litellm.types.integrations.gcs_bucket import *
+from litellm.types.utils import (
+    StandardCallbackDynamicParams,
+    StandardLoggingMetadata,
+    StandardLoggingPayload,
+)

 if TYPE_CHECKING:
     from litellm.llms.vertex_ai_and_google_ai_studio.vertex_llm_base import VertexBase
 else:
     VertexBase = Any

+IAM_AUTH_KEY = "IAM_AUTH"
+
 class GCSBucketBase(CustomBatchLogger):

@ -30,6 +37,7 @@ class GCSBucketBase(CustomBatchLogger):
         _bucket_name = bucket_name or os.getenv("GCS_BUCKET_NAME")
         self.path_service_account_json: Optional[str] = _path_service_account
         self.BUCKET_NAME: Optional[str] = _bucket_name
+        self.vertex_instances: Dict[str, VertexBase] = {}
         super().__init__(**kwargs)

     async def construct_request_headers(

@ -94,3 +102,237 @@ class GCSBucketBase(CustomBatchLogger):
        }

        return headers

    def _handle_folders_in_bucket_name(
        self,
        bucket_name: str,
        object_name: str,
    ) -> Tuple[str, str]:
        """
        Handles when the user passes a bucket name with a folder postfix

        Example:
            - Bucket name: "my-bucket/my-folder/dev"
            - Object name: "my-object"
            - Returns: bucket_name="my-bucket", object_name="my-folder/dev/my-object"
        """
        if "/" in bucket_name:
            bucket_name, prefix = bucket_name.split("/", 1)
            object_name = f"{prefix}/{object_name}"
            return bucket_name, object_name
        return bucket_name, object_name

    async def get_gcs_logging_config(
        self, kwargs: Optional[Dict[str, Any]] = {}
    ) -> GCSLoggingConfig:
        """
        This function is used to get the GCS logging config for the GCS Bucket Logger.
        It checks if the dynamic parameters are provided in the kwargs and uses them to get the GCS logging config.
        If no dynamic parameters are provided, it uses the default values.
        """
        if kwargs is None:
            kwargs = {}

        standard_callback_dynamic_params: Optional[StandardCallbackDynamicParams] = (
            kwargs.get("standard_callback_dynamic_params", None)
        )

        bucket_name: str
        path_service_account: Optional[str]
        if standard_callback_dynamic_params is not None:
            verbose_logger.debug("Using dynamic GCS logging")
            verbose_logger.debug(
                "standard_callback_dynamic_params: %s", standard_callback_dynamic_params
            )

            _bucket_name: Optional[str] = (
                standard_callback_dynamic_params.get("gcs_bucket_name", None)
                or self.BUCKET_NAME
            )
            _path_service_account: Optional[str] = (
                standard_callback_dynamic_params.get("gcs_path_service_account", None)
                or self.path_service_account_json
            )

            if _bucket_name is None:
                raise ValueError(
                    "GCS_BUCKET_NAME is not set in the environment, but GCS Bucket is being used as a logging callback. Please set 'GCS_BUCKET_NAME' in the environment."
                )
            bucket_name = _bucket_name
            path_service_account = _path_service_account
            vertex_instance = await self.get_or_create_vertex_instance(
                credentials=path_service_account
            )
        else:
            # If no dynamic parameters, use the default instance
            if self.BUCKET_NAME is None:
                raise ValueError(
                    "GCS_BUCKET_NAME is not set in the environment, but GCS Bucket is being used as a logging callback. Please set 'GCS_BUCKET_NAME' in the environment."
                )
            bucket_name = self.BUCKET_NAME
            path_service_account = self.path_service_account_json
            vertex_instance = await self.get_or_create_vertex_instance(
                credentials=path_service_account
            )

        return GCSLoggingConfig(
            bucket_name=bucket_name,
            vertex_instance=vertex_instance,
            path_service_account=path_service_account,
        )

    async def get_or_create_vertex_instance(
        self, credentials: Optional[str]
    ) -> VertexBase:
        """
        This function is used to get the Vertex instance for the GCS Bucket Logger.
        It checks if the Vertex instance is already created and cached, if not it creates a new instance and caches it.
        """
        from litellm.llms.vertex_ai_and_google_ai_studio.vertex_llm_base import (
            VertexBase,
        )

        _in_memory_key = self._get_in_memory_key_for_vertex_instance(credentials)
        if _in_memory_key not in self.vertex_instances:
            vertex_instance = VertexBase()
            await vertex_instance._ensure_access_token_async(
                credentials=credentials,
                project_id=None,
                custom_llm_provider="vertex_ai",
            )
            self.vertex_instances[_in_memory_key] = vertex_instance
        return self.vertex_instances[_in_memory_key]

    def _get_in_memory_key_for_vertex_instance(self, credentials: Optional[str]) -> str:
        """
        Returns key to use for caching the Vertex instance in-memory.

        When using Vertex with Key based logging, we need to cache the Vertex instance in-memory.

        - If a credentials string is provided, it is used as the key.
        - If no credentials string is provided, "IAM_AUTH" is used as the key.
        """
        return credentials or IAM_AUTH_KEY

    async def download_gcs_object(self, object_name: str, **kwargs):
        """
        Download an object from GCS.

        https://cloud.google.com/storage/docs/downloading-objects#download-object-json
        """
        try:
            gcs_logging_config: GCSLoggingConfig = await self.get_gcs_logging_config(
                kwargs=kwargs
            )
            headers = await self.construct_request_headers(
                vertex_instance=gcs_logging_config["vertex_instance"],
                service_account_json=gcs_logging_config["path_service_account"],
            )
            bucket_name = gcs_logging_config["bucket_name"]
            bucket_name, object_name = self._handle_folders_in_bucket_name(
                bucket_name=bucket_name,
                object_name=object_name,
            )

            url = f"https://storage.googleapis.com/storage/v1/b/{bucket_name}/o/{object_name}?alt=media"

            # Send the GET request to download the object
            response = await self.async_httpx_client.get(url=url, headers=headers)

            if response.status_code != 200:
                verbose_logger.error(
                    "GCS object download error: %s", str(response.text)
                )
                return None

            verbose_logger.debug(
                "GCS object download response status code: %s", response.status_code
            )

            # Return the content of the downloaded object
            return response.content

        except Exception as e:
            verbose_logger.error("GCS object download error: %s", str(e))
            return None

    async def delete_gcs_object(self, object_name: str, **kwargs):
        """
        Delete an object from GCS.
        """
        try:
            gcs_logging_config: GCSLoggingConfig = await self.get_gcs_logging_config(
                kwargs=kwargs
            )
            headers = await self.construct_request_headers(
                vertex_instance=gcs_logging_config["vertex_instance"],
                service_account_json=gcs_logging_config["path_service_account"],
            )
            bucket_name = gcs_logging_config["bucket_name"]
            bucket_name, object_name = self._handle_folders_in_bucket_name(
                bucket_name=bucket_name,
                object_name=object_name,
            )

            url = f"https://storage.googleapis.com/storage/v1/b/{bucket_name}/o/{object_name}"

            # Send the DELETE request to delete the object
            response = await self.async_httpx_client.delete(url=url, headers=headers)

            # 200 and 204 both indicate a successful delete
            if response.status_code not in (200, 204):
                verbose_logger.error(
                    "GCS object delete error: %s, status code: %s",
                    str(response.text),
                    response.status_code,
                )
                return None

            verbose_logger.debug(
                "GCS object delete response status code: %s, response: %s",
                response.status_code,
                response.text,
            )

            # Return the text of the delete response
            return response.text

        except Exception as e:
            verbose_logger.error("GCS object delete error: %s", str(e))
            return None

    async def _log_json_data_on_gcs(
        self,
        headers: Dict[str, str],
        bucket_name: str,
        object_name: str,
        logging_payload: Union[StandardLoggingPayload, str],
    ):
        """
        Helper function to make POST request to GCS Bucket in the specified bucket.
        """
        if isinstance(logging_payload, str):
            json_logged_payload = logging_payload
        else:
            json_logged_payload = json.dumps(logging_payload, default=str)

        bucket_name, object_name = self._handle_folders_in_bucket_name(
            bucket_name=bucket_name,
            object_name=object_name,
        )

        response = await self.async_httpx_client.post(
            headers=headers,
            url=f"https://storage.googleapis.com/upload/storage/v1/b/{bucket_name}/o?uploadType=media&name={object_name}",
            data=json_logged_payload,
        )

        if response.status_code != 200:
            verbose_logger.error("GCS Bucket logging error: %s", str(response.text))

        verbose_logger.debug("GCS Bucket response %s", response)
        verbose_logger.debug("GCS Bucket status code %s", response.status_code)
        verbose_logger.debug("GCS Bucket response.text %s", response.text)

        return response.json()
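As an aside, the dynamic parameters read by `get_gcs_logging_config` above arrive through `kwargs["standard_callback_dynamic_params"]`. A minimal sketch of that shape (assumptions: `gcs_logger` is an already-constructed logger instance exposing the method above; the bucket and credential values are placeholders):

```python
# Hedged sketch of the per-request override read by get_gcs_logging_config();
# without it, the logger falls back to GCS_BUCKET_NAME / the default service-account path.
async def show_dynamic_gcs_config(gcs_logger):
    kwargs = {
        "standard_callback_dynamic_params": {
            "gcs_bucket_name": "my-team-bucket/prod-logs",        # placeholder (folder suffix handled at upload time)
            "gcs_path_service_account": "/path/to/team_sa.json",  # placeholder
        }
    }
    # returns a GCSLoggingConfig with bucket_name, path_service_account, vertex_instance
    return await gcs_logger.get_gcs_logging_config(kwargs=kwargs)
```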
@ -0,0 +1,6 @@
# Vertex AI Batch Prediction Jobs

Implementation to call VertexAI Batch endpoints in OpenAI Batch API spec

Vertex Docs: https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/batch-prediction-gemini
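For orientation, the `transformation.py` module added further below maps an OpenAI `CreateBatchRequest` onto a Vertex `batchPredictionJobs` payload along these lines. The values shown are illustrative placeholders (in particular, the exact `model` string is derived from the GCS input path by a helper not shown in this excerpt), not captured output:

```json
{
  "displayName": "litellm-vertex-batch-<uuid>",
  "model": "publishers/google/models/gemini-1.5-flash-001",
  "inputConfig": {
    "instancesFormat": "jsonl",
    "gcsSource": {"uris": "gs://litellm-testing-bucket/litellm-vertex-files/publishers/google/models/gemini-1.5-flash-001/<file-id>"}
  },
  "outputConfig": {
    "predictionsFormat": "jsonl",
    "gcsDestination": {"outputUriPrefix": "gs://litellm-testing-bucket/litellm-vertex-files/publishers/google/models/gemini-1.5-flash-001"}
  }
}
```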
litellm/llms/vertex_ai_and_google_ai_studio/batches/handler.py (new file, 141 lines)
@ -0,0 +1,141 @@
|
||||||
|
import json
|
||||||
|
from typing import Any, Coroutine, Dict, Optional, Union
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
import litellm
|
||||||
|
from litellm.llms.custom_httpx.http_handler import (
|
||||||
|
AsyncHTTPHandler,
|
||||||
|
HTTPHandler,
|
||||||
|
_get_httpx_client,
|
||||||
|
get_async_httpx_client,
|
||||||
|
)
|
||||||
|
from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
|
||||||
|
VertexAIError,
|
||||||
|
VertexLLM,
|
||||||
|
)
|
||||||
|
from litellm.types.llms.openai import (
|
||||||
|
Batch,
|
||||||
|
CancelBatchRequest,
|
||||||
|
```python
    CreateBatchRequest,
    CreateFileRequest,
    FileContentRequest,
    FileObject,
    FileTypes,
    HttpxBinaryResponseContent,
    RetrieveBatchRequest,
)
from litellm.types.llms.vertex_ai import VertexAIBatchPredictionJob

from .transformation import VertexAIBatchTransformation


class VertexAIBatchPrediction(VertexLLM):
    def __init__(self, gcs_bucket_name: str, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.gcs_bucket_name = gcs_bucket_name

    def create_batch(
        self,
        _is_async: bool,
        create_batch_data: CreateBatchRequest,
        api_base: Optional[str],
        vertex_credentials: Optional[str],
        vertex_project: Optional[str],
        vertex_location: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
    ) -> Union[Batch, Coroutine[Any, Any, Batch]]:

        sync_handler = _get_httpx_client()

        access_token, project_id = self._ensure_access_token(
            credentials=vertex_credentials,
            project_id=vertex_project,
            custom_llm_provider="vertex_ai",
        )

        default_api_base = self.create_vertex_url(
            vertex_location=vertex_location or "us-central1",
            vertex_project=vertex_project or project_id,
        )

        if len(default_api_base.split(":")) > 1:
            endpoint = default_api_base.split(":")[-1]
        else:
            endpoint = ""

        _, api_base = self._check_custom_proxy(
            api_base=api_base,
            custom_llm_provider="vertex_ai",
            gemini_api_key=None,
            endpoint=endpoint,
            stream=None,
            auth_header=None,
            url=default_api_base,
        )

        headers = {
            "Content-Type": "application/json; charset=utf-8",
            "Authorization": f"Bearer {access_token}",
        }

        vertex_batch_request: VertexAIBatchPredictionJob = (
            VertexAIBatchTransformation.transform_openai_batch_request_to_vertex_ai_batch_request(
                request=create_batch_data
            )
        )

        if _is_async is True:
            return self._async_create_batch(
                vertex_batch_request=vertex_batch_request,
                api_base=api_base,
                headers=headers,
            )

        response = sync_handler.post(
            url=api_base,
            headers=headers,
            data=json.dumps(vertex_batch_request),
        )

        if response.status_code != 200:
            raise Exception(f"Error: {response.status_code} {response.text}")

        _json_response = response.json()
        vertex_batch_response = VertexAIBatchTransformation.transform_vertex_ai_batch_response_to_openai_batch_response(
            response=_json_response
        )
        return vertex_batch_response

    async def _async_create_batch(
        self,
        vertex_batch_request: VertexAIBatchPredictionJob,
        api_base: str,
        headers: Dict[str, str],
    ) -> Batch:
        client = get_async_httpx_client(
            llm_provider=litellm.LlmProviders.VERTEX_AI,
        )
        response = await client.post(
            url=api_base,
            headers=headers,
            data=json.dumps(vertex_batch_request),
        )
        if response.status_code != 200:
            raise Exception(f"Error: {response.status_code} {response.text}")

        _json_response = response.json()
        vertex_batch_response = VertexAIBatchTransformation.transform_vertex_ai_batch_response_to_openai_batch_response(
            response=_json_response
        )
        return vertex_batch_response

    def create_vertex_url(
        self,
        vertex_location: str,
        vertex_project: str,
    ) -> str:
        """Return the base url for the vertex garden models"""
        # POST https://LOCATION-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/LOCATION/batchPredictionJobs
        return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/batchPredictionJobs"
```
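The sync and async paths above only differ in the httpx client they use; both POST the transformed job to the URL built by `create_vertex_url`. A minimal worked example of that URL, using made-up project and location values for illustration:

```python
# Hypothetical values, for illustration only.
vertex_location = "us-central1"
vertex_project = "my-gcp-project"

# Mirrors the f-string returned by create_vertex_url above.
url = (
    f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/"
    f"{vertex_project}/locations/{vertex_location}/batchPredictionJobs"
)
print(url)
# https://us-central1-aiplatform.googleapis.com/v1/projects/my-gcp-project/locations/us-central1/batchPredictionJobs
```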
@@ -0,0 +1,174 @@

```python
import uuid
from typing import Any, Dict, Literal

from litellm.llms.vertex_ai_and_google_ai_studio.common_utils import (
    _convert_vertex_datetime_to_openai_datetime,
)
from litellm.types.llms.openai import Batch, BatchJobStatus, CreateBatchRequest
from litellm.types.llms.vertex_ai import *


class VertexAIBatchTransformation:
    """
    Transforms OpenAI Batch requests to Vertex AI Batch requests

    API Ref: https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/batch-prediction-gemini
    """

    @classmethod
    def transform_openai_batch_request_to_vertex_ai_batch_request(
        cls,
        request: CreateBatchRequest,
    ) -> VertexAIBatchPredictionJob:
        """
        Transforms OpenAI Batch requests to Vertex AI Batch requests
        """
        request_display_name = f"litellm-vertex-batch-{uuid.uuid4()}"
        input_file_id = request.get("input_file_id")
        if input_file_id is None:
            raise ValueError("input_file_id is required, but not provided")
        input_config: InputConfig = InputConfig(
            gcsSource=GcsSource(uris=input_file_id), instancesFormat="jsonl"
        )
        model: str = cls._get_model_from_gcs_file(input_file_id)
        output_config: OutputConfig = OutputConfig(
            predictionsFormat="jsonl",
            gcsDestination=GcsDestination(
                outputUriPrefix=cls._get_gcs_uri_prefix_from_file(input_file_id)
            ),
        )
        return VertexAIBatchPredictionJob(
            inputConfig=input_config,
            outputConfig=output_config,
            model=model,
            displayName=request_display_name,
        )

    @classmethod
    def transform_vertex_ai_batch_response_to_openai_batch_response(
        cls, response: VertexBatchPredictionResponse
    ) -> Batch:
        return Batch(
            id=response.get("name", ""),
            completion_window="24hrs",
            created_at=_convert_vertex_datetime_to_openai_datetime(
                vertex_datetime=response.get("createTime", "")
            ),
            endpoint="",
            input_file_id=cls._get_input_file_id_from_vertex_ai_batch_response(
                response
            ),
            object="batch",
            status=cls._get_batch_job_status_from_vertex_ai_batch_response(response),
            error_file_id=None,  # Vertex AI doesn't seem to have a direct equivalent
            output_file_id=cls._get_output_file_id_from_vertex_ai_batch_response(
                response
            ),
        )

    @classmethod
    def _get_input_file_id_from_vertex_ai_batch_response(
        cls, response: VertexBatchPredictionResponse
    ) -> str:
        """
        Gets the input file id from the Vertex AI Batch response
        """
        input_file_id: str = ""
        input_config = response.get("inputConfig")
        if input_config is None:
            return input_file_id

        gcs_source = input_config.get("gcsSource")
        if gcs_source is None:
            return input_file_id

        uris = gcs_source.get("uris", "")
        if len(uris) == 0:
            return input_file_id

        return uris[0]

    @classmethod
    def _get_output_file_id_from_vertex_ai_batch_response(
        cls, response: VertexBatchPredictionResponse
    ) -> str:
        """
        Gets the output file id from the Vertex AI Batch response
        """
        output_file_id: str = ""
        output_config = response.get("outputConfig")
        if output_config is None:
            return output_file_id

        gcs_destination = output_config.get("gcsDestination")
        if gcs_destination is None:
            return output_file_id

        output_uri_prefix = gcs_destination.get("outputUriPrefix", "")
        return output_uri_prefix

    @classmethod
    def _get_batch_job_status_from_vertex_ai_batch_response(
        cls, response: VertexBatchPredictionResponse
    ) -> BatchJobStatus:
        """
        Gets the batch job status from the Vertex AI Batch response

        ref: https://cloud.google.com/vertex-ai/docs/reference/rest/v1/JobState
        """
        state_mapping: Dict[str, BatchJobStatus] = {
            "JOB_STATE_UNSPECIFIED": "failed",
            "JOB_STATE_QUEUED": "validating",
            "JOB_STATE_PENDING": "validating",
            "JOB_STATE_RUNNING": "in_progress",
            "JOB_STATE_SUCCEEDED": "completed",
            "JOB_STATE_FAILED": "failed",
            "JOB_STATE_CANCELLING": "cancelling",
            "JOB_STATE_CANCELLED": "cancelled",
            "JOB_STATE_PAUSED": "in_progress",
            "JOB_STATE_EXPIRED": "expired",
            "JOB_STATE_UPDATING": "in_progress",
            "JOB_STATE_PARTIALLY_SUCCEEDED": "completed",
        }

        vertex_state = response.get("state", "JOB_STATE_UNSPECIFIED")
        return state_mapping[vertex_state]

    @classmethod
    def _get_gcs_uri_prefix_from_file(cls, input_file_id: str) -> str:
        """
        Gets the gcs uri prefix from the input file id

        Example:
        input_file_id: "gs://litellm-testing-bucket/vtx_batch.jsonl"
        returns: "gs://litellm-testing-bucket"

        input_file_id: "gs://litellm-testing-bucket/batches/vtx_batch.jsonl"
        returns: "gs://litellm-testing-bucket/batches"
        """
        # Split the path and remove the filename
        path_parts = input_file_id.rsplit("/", 1)
        return path_parts[0]

    @classmethod
    def _get_model_from_gcs_file(cls, gcs_file_uri: str) -> str:
        """
        Extracts the model from the gcs file uri

        When files are uploaded using LiteLLM (/v1/files), the model is stored in the gcs file uri

        Why?
        - Because Vertex Requires the `model` param in create batch jobs request, but OpenAI does not require this

        gcs_file_uri format: gs://litellm-testing-bucket/litellm-vertex-files/publishers/google/models/gemini-1.5-flash-001/e9412502-2c91-42a6-8e61-f5c294cc0fc8
        returns: "publishers/google/models/gemini-1.5-flash-001"
        """
        from urllib.parse import unquote

        decoded_uri = unquote(gcs_file_uri)

        model_path = decoded_uri.split("publishers/")[1]
        parts = model_path.split("/")
        model = f"publishers/{'/'.join(parts[:3])}"
        return model
```
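A rough sketch of what the two GCS helpers above return for a file uploaded through LiteLLM, using the URI from the docstring example (this assumes `VertexAIBatchTransformation` from the block above is in scope):

```python
# URI taken from the _get_model_from_gcs_file docstring example.
input_file_id = (
    "gs://litellm-testing-bucket/litellm-vertex-files/"
    "publishers/google/models/gemini-1.5-flash-001/e9412502-2c91-42a6-8e61-f5c294cc0fc8"
)

# Helpers used by transform_openai_batch_request_to_vertex_ai_batch_request:
model = VertexAIBatchTransformation._get_model_from_gcs_file(input_file_id)
prefix = VertexAIBatchTransformation._get_gcs_uri_prefix_from_file(input_file_id)

print(model)   # publishers/google/models/gemini-1.5-flash-001
print(prefix)  # gs://litellm-testing-bucket/litellm-vertex-files/publishers/google/models/gemini-1.5-flash-001
```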
```diff
@@ -264,3 +264,18 @@ def strip_field(schema, field_name: str):
     items = schema.get("items", None)
     if items is not None:
         strip_field(items, field_name)
+
+
+def _convert_vertex_datetime_to_openai_datetime(vertex_datetime: str) -> int:
+    """
+    Converts a Vertex AI datetime string to an OpenAI datetime integer
+
+    vertex_datetime: str = "2024-12-04T21:53:12.120184Z"
+    returns: int = 1722729192
+    """
+    from datetime import datetime
+
+    # Parse the ISO format string to datetime object
+    dt = datetime.strptime(vertex_datetime, "%Y-%m-%dT%H:%M:%S.%fZ")
+    # Convert to Unix timestamp (seconds since epoch)
+    return int(dt.timestamp())
```
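A quick check of the helper above. Note that `datetime.strptime` returns a naive datetime, so `timestamp()` interprets it in the local timezone and the exact integer varies by machine:

```python
from datetime import datetime

vertex_datetime = "2024-12-04T21:53:12.120184Z"
dt = datetime.strptime(vertex_datetime, "%Y-%m-%dT%H:%M:%S.%fZ")
print(int(dt.timestamp()))  # Unix timestamp in seconds; value depends on the local timezone
```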
litellm/llms/vertex_ai_and_google_ai_studio/files/handler.py (new file, 111 lines)

```python
import json
import uuid
from typing import Any, Coroutine, Dict, Optional, Union

import httpx

import litellm
from litellm.integrations.gcs_bucket.gcs_bucket_base import (
    GCSBucketBase,
    GCSLoggingConfig,
)
from litellm.llms.custom_httpx.http_handler import (
    AsyncHTTPHandler,
    HTTPHandler,
    _get_httpx_client,
    get_async_httpx_client,
)
from litellm.llms.vertex_ai_and_google_ai_studio.common_utils import (
    _convert_vertex_datetime_to_openai_datetime,
)
from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
    VertexAIError,
    VertexLLM,
)
from litellm.types.llms.openai import (
    Batch,
    CreateFileRequest,
    FileContentRequest,
    FileObject,
    FileTypes,
    HttpxBinaryResponseContent,
)

from .transformation import VertexAIFilesTransformation

vertex_ai_files_transformation = VertexAIFilesTransformation()


class VertexAIFilesHandler(GCSBucketBase):
    """
    Handles Calling VertexAI in OpenAI Files API format v1/files/*

    This implementation uploads files on GCS Buckets
    """

    pass

    async def async_create_file(
        self,
        create_file_data: CreateFileRequest,
        api_base: Optional[str],
        vertex_credentials: Optional[str],
        vertex_project: Optional[str],
        vertex_location: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
    ):
        gcs_logging_config: GCSLoggingConfig = await self.get_gcs_logging_config(
            kwargs={}
        )
        headers = await self.construct_request_headers(
            vertex_instance=gcs_logging_config["vertex_instance"],
            service_account_json=gcs_logging_config["path_service_account"],
        )
        bucket_name = gcs_logging_config["bucket_name"]
        logging_payload, object_name = (
            vertex_ai_files_transformation.transform_openai_file_content_to_vertex_ai_file_content(
                openai_file_content=create_file_data.get("file")
            )
        )
        gcs_upload_response = await self._log_json_data_on_gcs(
            headers=headers,
            bucket_name=bucket_name,
            object_name=object_name,
            logging_payload=logging_payload,
        )

        return vertex_ai_files_transformation.transform_gcs_bucket_response_to_openai_file_object(
            create_file_data=create_file_data,
            gcs_upload_response=gcs_upload_response,
        )

    def create_file(
        self,
        _is_async: bool,
        create_file_data: CreateFileRequest,
        api_base: Optional[str],
        vertex_credentials: Optional[str],
        vertex_project: Optional[str],
        vertex_location: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
    ) -> Union[FileObject, Coroutine[Any, Any, FileObject]]:
        """
        Creates a file on VertexAI GCS Bucket

        Only supported for Async litellm.acreate_file
        """

        if _is_async:
            return self.async_create_file(
                create_file_data=create_file_data,
                api_base=api_base,
                vertex_credentials=vertex_credentials,
                vertex_project=vertex_project,
                vertex_location=vertex_location,
                timeout=timeout,
                max_retries=max_retries,
            )

        return None  # type: ignore
```
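Since `create_file` only supports the async path, a usage sketch (mirroring the `test_vertex_batch_prediction` test added later in this commit) looks roughly like this; the jsonl path is a placeholder:

```python
import asyncio

import litellm


async def upload_batch_file():
    # Upload a batch .jsonl to the configured GCS bucket via the OpenAI Files format.
    file_obj = await litellm.acreate_file(
        file=open("vertex_batch_completions.jsonl", "rb"),  # placeholder path
        purpose="batch",
        custom_llm_provider="vertex_ai",
    )
    # The returned id is a gs:// object id, used later as input_file_id for the batch.
    print(file_obj.id)


asyncio.run(upload_batch_file())
```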
@@ -0,0 +1,173 @@

```python
import json
import uuid
from typing import Any, Dict, List, Optional, Tuple, Union

from litellm.llms.vertex_ai_and_google_ai_studio.common_utils import (
    _convert_vertex_datetime_to_openai_datetime,
)
from litellm.llms.vertex_ai_and_google_ai_studio.gemini.transformation import (
    _transform_request_body,
)
from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
    VertexGeminiConfig,
)
from litellm.types.llms.openai import (
    Batch,
    CreateFileRequest,
    FileContentRequest,
    FileObject,
    FileTypes,
    HttpxBinaryResponseContent,
    PathLike,
)


class VertexAIFilesTransformation(VertexGeminiConfig):
    """
    Transforms OpenAI /v1/files/* requests to VertexAI /v1/files/* requests
    """

    def transform_openai_file_content_to_vertex_ai_file_content(
        self, openai_file_content: Optional[FileTypes] = None
    ) -> Tuple[str, str]:
        """
        Transforms OpenAI FileContentRequest to VertexAI FileContentRequest
        """

        if openai_file_content is None:
            raise ValueError("contents of file are None")
        # Read the content of the file
        file_content = self._get_content_from_openai_file(openai_file_content)

        # Split into lines and parse each line as JSON
        openai_jsonl_content = [
            json.loads(line) for line in file_content.splitlines() if line.strip()
        ]
        vertex_jsonl_content = (
            self._transform_openai_jsonl_content_to_vertex_ai_jsonl_content(
                openai_jsonl_content
            )
        )
        vertex_jsonl_string = "\n".join(
            json.dumps(item) for item in vertex_jsonl_content
        )
        object_name = self._get_gcs_object_name(
            openai_jsonl_content=openai_jsonl_content
        )
        return vertex_jsonl_string, object_name

    def _transform_openai_jsonl_content_to_vertex_ai_jsonl_content(
        self, openai_jsonl_content: List[Dict[str, Any]]
    ):
        """
        Transforms OpenAI JSONL content to VertexAI JSONL content

        jsonl body for vertex is {"request": <request_body>}
        Example Vertex jsonl
        {"request":{"contents": [{"role": "user", "parts": [{"text": "What is the relation between the following video and image samples?"}, {"fileData": {"fileUri": "gs://cloud-samples-data/generative-ai/video/animals.mp4", "mimeType": "video/mp4"}}, {"fileData": {"fileUri": "gs://cloud-samples-data/generative-ai/image/cricket.jpeg", "mimeType": "image/jpeg"}}]}]}}
        {"request":{"contents": [{"role": "user", "parts": [{"text": "Describe what is happening in this video."}, {"fileData": {"fileUri": "gs://cloud-samples-data/generative-ai/video/another_video.mov", "mimeType": "video/mov"}}]}]}}
        """

        vertex_jsonl_content = []
        for _openai_jsonl_content in openai_jsonl_content:
            openai_request_body = _openai_jsonl_content.get("body") or {}
            vertex_request_body = _transform_request_body(
                messages=openai_request_body.get("messages", []),
                model=openai_request_body.get("model", ""),
                optional_params=self._map_openai_to_vertex_params(openai_request_body),
                custom_llm_provider="vertex_ai",
                litellm_params={},
                cached_content=None,
            )
            vertex_jsonl_content.append({"request": vertex_request_body})
        return vertex_jsonl_content

    def _get_gcs_object_name(
        self,
        openai_jsonl_content: List[Dict[str, Any]],
    ) -> str:
        """
        Gets a unique GCS object name for the VertexAI batch prediction job

        named as: litellm-vertex-{model}-{uuid}
        """
        _model = openai_jsonl_content[0].get("body", {}).get("model", "")
        if "publishers/google/models" not in _model:
            _model = f"publishers/google/models/{_model}"
        object_name = f"litellm-vertex-files/{_model}/{uuid.uuid4()}"
        return object_name

    def _map_openai_to_vertex_params(
        self,
        openai_request_body: Dict[str, Any],
    ) -> Dict[str, Any]:
        """
        wrapper to call VertexGeminiConfig.map_openai_params
        """
        _model = openai_request_body.get("model", "")
        vertex_params = self.map_openai_params(
            model=_model,
            non_default_params=openai_request_body,
            optional_params={},
            drop_params=False,
        )
        return vertex_params

    def _get_content_from_openai_file(self, openai_file_content: FileTypes) -> str:
        """
        Helper to extract content from various OpenAI file types and return as string.

        Handles:
        - Direct content (str, bytes, IO[bytes])
        - Tuple formats: (filename, content, [content_type], [headers])
        - PathLike objects
        """
        content: Union[str, bytes] = b""
        # Extract file content from tuple if necessary
        if isinstance(openai_file_content, tuple):
            # Take the second element which is always the file content
            file_content = openai_file_content[1]
        else:
            file_content = openai_file_content

        # Handle different file content types
        if isinstance(file_content, str):
            # String content can be used directly
            content = file_content
        elif isinstance(file_content, bytes):
            # Bytes content can be decoded
            content = file_content
        elif isinstance(file_content, PathLike):  # PathLike
            with open(str(file_content), "rb") as f:
                content = f.read()
        elif hasattr(file_content, "read"):  # IO[bytes]
            # File-like objects need to be read
            content = file_content.read()

        # Ensure content is string
        if isinstance(content, bytes):
            content = content.decode("utf-8")

        return content

    def transform_gcs_bucket_response_to_openai_file_object(
        self, create_file_data: CreateFileRequest, gcs_upload_response: Dict[str, Any]
    ) -> FileObject:
        """
        Transforms GCS Bucket upload file response to OpenAI FileObject
        """
        gcs_id = gcs_upload_response.get("id", "")
        # Remove the last numeric ID from the path
        gcs_id = "/".join(gcs_id.split("/")[:-1]) if gcs_id else ""

        return FileObject(
            purpose=create_file_data.get("purpose", "batch"),
            id=f"gs://{gcs_id}",
            filename=gcs_upload_response.get("name", ""),
            created_at=_convert_vertex_datetime_to_openai_datetime(
                vertex_datetime=gcs_upload_response.get("timeCreated", "")
            ),
            status="uploaded",
            bytes=gcs_upload_response.get("size", 0),
            object="file",
        )
```
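To make the jsonl rewrite concrete: each OpenAI batch line's `body` is converted to a Gemini request and wrapped as `{"request": ...}`. A minimal sketch of that wrapping, with a hand-written Gemini body standing in for what `_transform_request_body` would approximately produce:

```python
import json

# One line of the OpenAI batch file (see vertex_batch_completions.jsonl below).
openai_line = {
    "custom_id": "request-1",
    "method": "POST",
    "url": "/v1/chat/completions",
    "body": {
        "model": "gemini-1.5-flash-001",
        "messages": [{"role": "user", "content": "Hello world!"}],
    },
}

# Hand-written approximation of the Gemini-format body for the line above.
vertex_request_body = {
    "contents": [{"role": "user", "parts": [{"text": "Hello world!"}]}]
}

# The Vertex jsonl line wraps the transformed body under "request".
vertex_line = {"request": vertex_request_body}
print(json.dumps(vertex_line))
```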
```diff
@@ -59,6 +59,8 @@ def get_files_provider_config(
     custom_llm_provider: str,
 ):
     global files_config
+    if custom_llm_provider == "vertex_ai":
+        return None
     if files_config is None:
         raise ValueError("files_config is not set, set it on your config.yaml file.")
     for setting in files_config:
@@ -212,9 +214,9 @@ async def create_file(
         if llm_provider_config is not None:
             # add llm_provider_config to data
             _create_file_request.update(llm_provider_config)
+        _create_file_request.pop("custom_llm_provider", None)  # type: ignore
         # for now use custom_llm_provider=="openai" -> this will change as LiteLLM adds more providers for acreate_batch
-        response = await litellm.acreate_file(**_create_file_request)  # type: ignore
+        response = await litellm.acreate_file(**_create_file_request, custom_llm_provider=custom_llm_provider)  # type: ignore

         ### ALERTING ###
         asyncio.create_task(
@@ -239,7 +241,6 @@ async def create_file(
                 model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
             )
         )

         return response
     except Exception as e:
         await proxy_logging_obj.post_call_failure_hook(
```
```diff
@@ -5225,6 +5225,7 @@ async def create_batch(
         is_router_model = is_known_model(model=router_model, llm_router=llm_router)

         _create_batch_data = CreateBatchRequest(**data)
+        custom_llm_provider = provider or _create_batch_data.pop("custom_llm_provider", None)  # type: ignore

         if (
             litellm.enable_loadbalancing_on_batch_endpoints is True
@@ -5241,10 +5242,10 @@ async def create_batch(

             response = await llm_router.acreate_batch(**_create_batch_data)  # type: ignore
         else:
-            if provider is None:
-                provider = "openai"
+            if custom_llm_provider is None:
+                custom_llm_provider = "openai"
             response = await litellm.acreate_batch(
-                custom_llm_provider=provider, **_create_batch_data  # type: ignore
+                custom_llm_provider=custom_llm_provider, **_create_batch_data  # type: ignore
             )

         ### ALERTING ###
```
```diff
@@ -301,6 +301,18 @@ class ListBatchRequest(TypedDict, total=False):
     timeout: Optional[float]


+BatchJobStatus = Literal[
+    "validating",
+    "failed",
+    "in_progress",
+    "finalizing",
+    "completed",
+    "expired",
+    "cancelling",
+    "cancelled",
+]
+
+
 class ChatCompletionAudioDelta(TypedDict, total=False):
     data: str
     transcript: str
```
```diff
@@ -434,3 +434,43 @@ class VertexAIBatchEmbeddingsRequestBody(TypedDict, total=False):

 class VertexAIBatchEmbeddingsResponseObject(TypedDict):
     embeddings: List[ContentEmbeddings]
+
+
+# Vertex AI Batch Prediction
+
+
+class GcsSource(TypedDict):
+    uris: str
+
+
+class InputConfig(TypedDict):
+    instancesFormat: str
+    gcsSource: GcsSource
+
+
+class GcsDestination(TypedDict):
+    outputUriPrefix: str
+
+
+class OutputConfig(TypedDict, total=False):
+    predictionsFormat: str
+    gcsDestination: GcsDestination
+
+
+class VertexAIBatchPredictionJob(TypedDict):
+    displayName: str
+    model: str
+    inputConfig: InputConfig
+    outputConfig: OutputConfig
+
+
+class VertexBatchPredictionResponse(TypedDict, total=False):
+    name: str
+    displayName: str
+    model: str
+    inputConfig: InputConfig
+    outputConfig: OutputConfig
+    state: str
+    createTime: str
+    updateTime: str
+    modelVersionId: str
```
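Putting the new TypedDicts together, a hand-built sketch of the job payload the batches handler sends; the bucket and model values here are illustrative, not taken from the commit:

```python
from litellm.types.llms.vertex_ai import (
    GcsDestination,
    GcsSource,
    InputConfig,
    OutputConfig,
    VertexAIBatchPredictionJob,
)

# Illustrative values only; a real job uses the gs:// id returned by /v1/files.
job: VertexAIBatchPredictionJob = VertexAIBatchPredictionJob(
    displayName="litellm-vertex-batch-example",
    model="publishers/google/models/gemini-1.5-flash-001",
    inputConfig=InputConfig(
        instancesFormat="jsonl",
        gcsSource=GcsSource(uris="gs://example-bucket/vtx_batch.jsonl"),
    ),
    outputConfig=OutputConfig(
        predictionsFormat="jsonl",
        gcsDestination=GcsDestination(outputUriPrefix="gs://example-bucket"),
    ),
)
print(job)
```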
```diff
@@ -11,8 +11,7 @@ from dotenv import load_dotenv
 load_dotenv()
 sys.path.insert(
     0, os.path.abspath("../..")
-)  # Adds the parent directory to the system path
+)  # Adds the parent directory to the system-path
-import asyncio
 import logging
 import time
@@ -20,6 +19,10 @@ import pytest

 import litellm
 from litellm import create_batch, create_file
+from litellm._logging import verbose_logger
+from test_gcs_bucket import load_vertex_ai_credentials
+
+verbose_logger.setLevel(logging.DEBUG)


 @pytest.mark.parametrize("provider", ["openai"])  # , "azure"
@@ -206,3 +209,32 @@ def test_cancel_batch():

 def test_list_batch():
     pass
+
+
+@pytest.mark.asyncio
+async def test_vertex_batch_prediction():
+    load_vertex_ai_credentials()
+    file_name = "vertex_batch_completions.jsonl"
+    _current_dir = os.path.dirname(os.path.abspath(__file__))
+    file_path = os.path.join(_current_dir, file_name)
+    file_obj = await litellm.acreate_file(
+        file=open(file_path, "rb"),
+        purpose="batch",
+        custom_llm_provider="vertex_ai",
+    )
+    print("Response from creating file=", file_obj)
+
+    batch_input_file_id = file_obj.id
+    assert (
+        batch_input_file_id is not None
+    ), f"Failed to create file, expected a non null file_id but got {batch_input_file_id}"
+
+    create_batch_response = await litellm.acreate_batch(
+        completion_window="24h",
+        endpoint="/v1/chat/completions",
+        input_file_id=batch_input_file_id,
+        custom_llm_provider="vertex_ai",
+        metadata={"key1": "value1", "key2": "value2"},
+    )
+    print("create_batch_response=", create_batch_response)
+    pass
```
tests/local_testing/vertex_batch_completions.jsonl (new file, 2 lines)

```json
{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gemini-1.5-flash-001", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 10}}
{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gemini-1.5-flash-001", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 10}}
```
```diff
@@ -85,3 +85,37 @@ async def test_batches_operations():

     # Test delete file
     await delete_file(session, file_id)
+
+
+@pytest.mark.skip(reason="Local only test to verify if things work well")
+def test_vertex_batches_endpoint():
+    """
+    Test VertexAI Batches Endpoint
+    """
+    import os
+
+    oai_client = OpenAI(api_key=API_KEY, base_url=BASE_URL)
+    file_name = "local_testing/vertex_batch_completions.jsonl"
+    _current_dir = os.path.dirname(os.path.abspath(__file__))
+    file_path = os.path.join(_current_dir, file_name)
+    file_obj = oai_client.files.create(
+        file=open(file_path, "rb"),
+        purpose="batch",
+        extra_body={"custom_llm_provider": "vertex_ai"},
+    )
+    print("Response from creating file=", file_obj)
+
+    batch_input_file_id = file_obj.id
+    assert (
+        batch_input_file_id is not None
+    ), f"Failed to create file, expected a non null file_id but got {batch_input_file_id}"
+
+    create_batch_response = oai_client.batches.create(
+        completion_window="24h",
+        endpoint="/v1/chat/completions",
+        input_file_id=batch_input_file_id,
+        extra_body={"custom_llm_provider": "vertex_ai"},
+        metadata={"key1": "value1", "key2": "value2"},
+    )
+    print("response from create batch", create_batch_response)
+    pass
```