(feat) add Vertex Batches API support in OpenAI format (#7032)
* working request
* working transform
* working request
* transform vertex batch response
* add _async_create_batch
* move gcs functions to base
* fix _get_content_from_openai_file
* transform_openai_file_content_to_vertex_ai_file_content
* fix transform vertex gcs bucket upload to OAI files format
* working e2e test
* _get_gcs_object_name
* fix linting
* add doc string
* fix transform_gcs_bucket_response_to_openai_file_object
* use vertex for batch endpoints
* add batches support for vertex
* test_vertex_batches_endpoint
* test_vertex_batch_prediction
* fix gcs bucket base auth
* docs clean up batches
* docs Batch API
* docs vertex batches api
* test_get_gcs_logging_config_without_service_account
* undo change
* fix vertex md
* test_get_gcs_logging_config_without_service_account
* ci/cd run again
This commit is contained in:
parent
6cad9c58ac
commit
84db69d4c4
20 changed files with 1347 additions and 424 deletions
|
@ -6,8 +6,9 @@ import TabItem from '@theme/TabItem';
|
|||
Covers Batches, Files
|
||||
|
||||
## **Supported Providers**:
|
||||
- Azure OpenAI
|
||||
- **[Azure OpenAI](./providers/azure#azure-batches-api)**
|
||||
- OpenAI
|
||||
- **[Vertex AI](./providers/vertex#batch-apis)**
|
||||
|
||||
## Quick Start
|
||||
|
||||
|
@ -141,182 +142,4 @@ print("list_batches_response=", list_batches_response)
|
|||
|
||||
</Tabs>
|
||||
|
||||
## [👉 Proxy API Reference](https://litellm-api.up.railway.app/#/batch)
|
||||
|
||||
## Azure Batches API
|
||||
|
||||
Just add the azure env vars to your environment.
|
||||
|
||||
```bash
|
||||
export AZURE_API_KEY=""
|
||||
export AZURE_API_BASE=""
|
||||
```
|
||||
|
||||
AND use `/azure/*` for the Batches API calls
|
||||
|
||||
```bash
|
||||
http://0.0.0.0:4000/azure/v1/batches
|
||||
```
|
||||
### Usage
|
||||
|
||||
**Setup**
|
||||
|
||||
- Add Azure API Keys to your environment
|
||||
|
||||
#### 1. Upload a File
|
||||
|
||||
```bash
|
||||
curl http://localhost:4000/azure/v1/files \
|
||||
-H "Authorization: Bearer sk-1234" \
|
||||
-F purpose="batch" \
|
||||
-F file="@mydata.jsonl"
|
||||
```
|
||||
|
||||
**Example File**
|
||||
|
||||
Note: `model` should be your azure deployment name.
|
||||
|
||||
```json
|
||||
{"custom_id": "task-0", "method": "POST", "url": "/chat/completions", "body": {"model": "REPLACE-WITH-MODEL-DEPLOYMENT-NAME", "messages": [{"role": "system", "content": "You are an AI assistant that helps people find information."}, {"role": "user", "content": "When was Microsoft founded?"}]}}
|
||||
{"custom_id": "task-1", "method": "POST", "url": "/chat/completions", "body": {"model": "REPLACE-WITH-MODEL-DEPLOYMENT-NAME", "messages": [{"role": "system", "content": "You are an AI assistant that helps people find information."}, {"role": "user", "content": "When was the first XBOX released?"}]}}
|
||||
{"custom_id": "task-2", "method": "POST", "url": "/chat/completions", "body": {"model": "REPLACE-WITH-MODEL-DEPLOYMENT-NAME", "messages": [{"role": "system", "content": "You are an AI assistant that helps people find information."}, {"role": "user", "content": "What is Altair Basic?"}]}}
|
||||
```
|
||||
|
||||
#### 2. Create a batch
|
||||
|
||||
```bash
|
||||
curl http://0.0.0.0:4000/azure/v1/batches \
|
||||
-H "Authorization: Bearer $LITELLM_API_KEY" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"input_file_id": "file-abc123",
|
||||
"endpoint": "/v1/chat/completions",
|
||||
"completion_window": "24h"
|
||||
}'
|
||||
|
||||
```
|
||||
|
||||
#### 3. Retrieve batch
|
||||
|
||||
|
||||
```bash
|
||||
curl http://0.0.0.0:4000/azure/v1/batches/batch_abc123 \
|
||||
-H "Authorization: Bearer $LITELLM_API_KEY" \
|
||||
-H "Content-Type: application/json" \
|
||||
```
|
||||
|
||||
#### 4. Cancel batch
|
||||
|
||||
```bash
|
||||
curl http://0.0.0.0:4000/azure/v1/batches/batch_abc123/cancel \
|
||||
-H "Authorization: Bearer $LITELLM_API_KEY" \
|
||||
-H "Content-Type: application/json" \
|
||||
-X POST
|
||||
```
|
||||
|
||||
#### 5. List Batch
|
||||
|
||||
```bash
|
||||
curl http://0.0.0.0:4000/v1/batches?limit=2 \
|
||||
-H "Authorization: Bearer $LITELLM_API_KEY" \
|
||||
-H "Content-Type: application/json"
|
||||
```
|
||||
|
||||
### [👉 Health Check Azure Batch models](./proxy/health.md#batch-models-azure-only)
|
||||
|
||||
|
||||
### [BETA] Loadbalance Multiple Azure Deployments
|
||||
In your config.yaml, set `enable_loadbalancing_on_batch_endpoints: true`
|
||||
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: "batch-gpt-4o-mini"
|
||||
litellm_params:
|
||||
model: "azure/gpt-4o-mini"
|
||||
api_key: os.environ/AZURE_API_KEY
|
||||
api_base: os.environ/AZURE_API_BASE
|
||||
model_info:
|
||||
mode: batch
|
||||
|
||||
litellm_settings:
|
||||
enable_loadbalancing_on_batch_endpoints: true # 👈 KEY CHANGE
|
||||
```
|
||||
|
||||
Note: This works on `{PROXY_BASE_URL}/v1/files` and `{PROXY_BASE_URL}/v1/batches`.
|
||||
Note: Response is in the OpenAI-format.
|
||||
|
||||
1. Upload a file
|
||||
|
||||
Just set `model: batch-gpt-4o-mini` in your .jsonl.
|
||||
|
||||
```bash
|
||||
curl http://localhost:4000/v1/files \
|
||||
-H "Authorization: Bearer sk-1234" \
|
||||
-F purpose="batch" \
|
||||
-F file="@mydata.jsonl"
|
||||
```
|
||||
|
||||
**Example File**
|
||||
|
||||
Note: `model` should be your azure deployment name.
|
||||
|
||||
```json
|
||||
{"custom_id": "task-0", "method": "POST", "url": "/chat/completions", "body": {"model": "batch-gpt-4o-mini", "messages": [{"role": "system", "content": "You are an AI assistant that helps people find information."}, {"role": "user", "content": "When was Microsoft founded?"}]}}
|
||||
{"custom_id": "task-1", "method": "POST", "url": "/chat/completions", "body": {"model": "batch-gpt-4o-mini", "messages": [{"role": "system", "content": "You are an AI assistant that helps people find information."}, {"role": "user", "content": "When was the first XBOX released?"}]}}
|
||||
{"custom_id": "task-2", "method": "POST", "url": "/chat/completions", "body": {"model": "batch-gpt-4o-mini", "messages": [{"role": "system", "content": "You are an AI assistant that helps people find information."}, {"role": "user", "content": "What is Altair Basic?"}]}}
|
||||
```
|
||||
|
||||
Expected Response (OpenAI-compatible)
|
||||
|
||||
```bash
|
||||
{"id":"file-f0be81f654454113a922da60acb0eea6",...}
|
||||
```
|
||||
|
||||
2. Create a batch
|
||||
|
||||
```bash
|
||||
curl http://0.0.0.0:4000/v1/batches \
|
||||
-H "Authorization: Bearer $LITELLM_API_KEY" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"input_file_id": "file-f0be81f654454113a922da60acb0eea6",
|
||||
"endpoint": "/v1/chat/completions",
|
||||
"completion_window": "24h",
|
||||
"model: "batch-gpt-4o-mini"
|
||||
}'
|
||||
```
|
||||
|
||||
Expected Response:
|
||||
|
||||
```bash
|
||||
{"id":"batch_94e43f0a-d805-477d-adf9-bbb9c50910ed",...}
|
||||
```
|
||||
|
||||
3. Retrieve a batch
|
||||
|
||||
```bash
|
||||
curl http://0.0.0.0:4000/v1/batches/batch_94e43f0a-d805-477d-adf9-bbb9c50910ed \
|
||||
-H "Authorization: Bearer $LITELLM_API_KEY" \
|
||||
-H "Content-Type: application/json" \
|
||||
```
|
||||
|
||||
|
||||
Expected Response:
|
||||
|
||||
```
|
||||
{"id":"batch_94e43f0a-d805-477d-adf9-bbb9c50910ed",...}
|
||||
```
|
||||
|
||||
4. List batch
|
||||
|
||||
```bash
|
||||
curl http://0.0.0.0:4000/v1/batches?limit=2 \
|
||||
-H "Authorization: Bearer $LITELLM_API_KEY" \
|
||||
-H "Content-Type: application/json"
|
||||
```
|
||||
|
||||
Expected Response:
|
||||
|
||||
```bash
|
||||
{"data":[{"id":"batch_R3V...}
|
||||
```
|
||||
## [Swagger API Reference](https://litellm-api.up.railway.app/#/batch)
|
||||
|
|
|
@ -559,6 +559,185 @@ litellm_settings:
|
|||
</Tabs>
|
||||
|
||||
|
||||
## **Azure Batches API**
|
||||
|
||||
Just add the azure env vars to your environment.
|
||||
|
||||
```bash
|
||||
export AZURE_API_KEY=""
|
||||
export AZURE_API_BASE=""
|
||||
```
|
||||
|
||||
AND use `/azure/*` for the Batches API calls
|
||||
|
||||
```bash
|
||||
http://0.0.0.0:4000/azure/v1/batches
|
||||
```
|
||||
### Usage
|
||||
|
||||
**Setup**
|
||||
|
||||
- Add Azure API Keys to your environment
|
||||
|
||||
#### 1. Upload a File
|
||||
|
||||
```bash
|
||||
curl http://localhost:4000/azure/v1/files \
|
||||
-H "Authorization: Bearer sk-1234" \
|
||||
-F purpose="batch" \
|
||||
-F file="@mydata.jsonl"
|
||||
```
|
||||
|
||||
**Example File**
|
||||
|
||||
Note: `model` should be your azure deployment name.
|
||||
|
||||
```json
|
||||
{"custom_id": "task-0", "method": "POST", "url": "/chat/completions", "body": {"model": "REPLACE-WITH-MODEL-DEPLOYMENT-NAME", "messages": [{"role": "system", "content": "You are an AI assistant that helps people find information."}, {"role": "user", "content": "When was Microsoft founded?"}]}}
|
||||
{"custom_id": "task-1", "method": "POST", "url": "/chat/completions", "body": {"model": "REPLACE-WITH-MODEL-DEPLOYMENT-NAME", "messages": [{"role": "system", "content": "You are an AI assistant that helps people find information."}, {"role": "user", "content": "When was the first XBOX released?"}]}}
|
||||
{"custom_id": "task-2", "method": "POST", "url": "/chat/completions", "body": {"model": "REPLACE-WITH-MODEL-DEPLOYMENT-NAME", "messages": [{"role": "system", "content": "You are an AI assistant that helps people find information."}, {"role": "user", "content": "What is Altair Basic?"}]}}
|
||||
```
|
||||
|
||||
#### 2. Create a batch
|
||||
|
||||
```bash
|
||||
curl http://0.0.0.0:4000/azure/v1/batches \
|
||||
-H "Authorization: Bearer $LITELLM_API_KEY" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"input_file_id": "file-abc123",
|
||||
"endpoint": "/v1/chat/completions",
|
||||
"completion_window": "24h"
|
||||
}'
|
||||
|
||||
```
|
||||
|
||||
#### 3. Retrieve batch
|
||||
|
||||
|
||||
```bash
|
||||
curl http://0.0.0.0:4000/azure/v1/batches/batch_abc123 \
|
||||
-H "Authorization: Bearer $LITELLM_API_KEY" \
|
||||
-H "Content-Type: application/json" \
|
||||
```
|
||||
|
||||
#### 4. Cancel batch
|
||||
|
||||
```bash
|
||||
curl http://0.0.0.0:4000/azure/v1/batches/batch_abc123/cancel \
|
||||
-H "Authorization: Bearer $LITELLM_API_KEY" \
|
||||
-H "Content-Type: application/json" \
|
||||
-X POST
|
||||
```
|
||||
|
||||
#### 5. List Batch
|
||||
|
||||
```bash
|
||||
curl http://0.0.0.0:4000/v1/batches?limit=2 \
|
||||
-H "Authorization: Bearer $LITELLM_API_KEY" \
|
||||
-H "Content-Type: application/json"
|
||||
```
|
||||
|
||||
### [Health Check Azure Batch models](./proxy/health.md#batch-models-azure-only)
|
||||
|
||||
|
||||
### [BETA] Loadbalance Multiple Azure Deployments
|
||||
In your config.yaml, set `enable_loadbalancing_on_batch_endpoints: true`
|
||||
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: "batch-gpt-4o-mini"
|
||||
litellm_params:
|
||||
model: "azure/gpt-4o-mini"
|
||||
api_key: os.environ/AZURE_API_KEY
|
||||
api_base: os.environ/AZURE_API_BASE
|
||||
model_info:
|
||||
mode: batch
|
||||
|
||||
litellm_settings:
|
||||
enable_loadbalancing_on_batch_endpoints: true # 👈 KEY CHANGE
|
||||
```
|
||||
|
||||
Note: This works on `{PROXY_BASE_URL}/v1/files` and `{PROXY_BASE_URL}/v1/batches`.
|
||||
Note: Response is in the OpenAI-format.
|
||||
|
||||
1. Upload a file
|
||||
|
||||
Just set `model: batch-gpt-4o-mini` in your .jsonl.
|
||||
|
||||
```bash
|
||||
curl http://localhost:4000/v1/files \
|
||||
-H "Authorization: Bearer sk-1234" \
|
||||
-F purpose="batch" \
|
||||
-F file="@mydata.jsonl"
|
||||
```
|
||||
|
||||
**Example File**
|
||||
|
||||
Note: `model` should be your azure deployment name.
|
||||
|
||||
```json
|
||||
{"custom_id": "task-0", "method": "POST", "url": "/chat/completions", "body": {"model": "batch-gpt-4o-mini", "messages": [{"role": "system", "content": "You are an AI assistant that helps people find information."}, {"role": "user", "content": "When was Microsoft founded?"}]}}
|
||||
{"custom_id": "task-1", "method": "POST", "url": "/chat/completions", "body": {"model": "batch-gpt-4o-mini", "messages": [{"role": "system", "content": "You are an AI assistant that helps people find information."}, {"role": "user", "content": "When was the first XBOX released?"}]}}
|
||||
{"custom_id": "task-2", "method": "POST", "url": "/chat/completions", "body": {"model": "batch-gpt-4o-mini", "messages": [{"role": "system", "content": "You are an AI assistant that helps people find information."}, {"role": "user", "content": "What is Altair Basic?"}]}}
|
||||
```
|
||||
|
||||
Expected Response (OpenAI-compatible)
|
||||
|
||||
```bash
|
||||
{"id":"file-f0be81f654454113a922da60acb0eea6",...}
|
||||
```
|
||||
|
||||
2. Create a batch
|
||||
|
||||
```bash
|
||||
curl http://0.0.0.0:4000/v1/batches \
|
||||
-H "Authorization: Bearer $LITELLM_API_KEY" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"input_file_id": "file-f0be81f654454113a922da60acb0eea6",
|
||||
"endpoint": "/v1/chat/completions",
|
||||
"completion_window": "24h",
|
||||
"model: "batch-gpt-4o-mini"
|
||||
}'
|
||||
```
|
||||
|
||||
Expected Response:
|
||||
|
||||
```bash
|
||||
{"id":"batch_94e43f0a-d805-477d-adf9-bbb9c50910ed",...}
|
||||
```
|
||||
|
||||
3. Retrieve a batch
|
||||
|
||||
```bash
|
||||
curl http://0.0.0.0:4000/v1/batches/batch_94e43f0a-d805-477d-adf9-bbb9c50910ed \
|
||||
-H "Authorization: Bearer $LITELLM_API_KEY" \
|
||||
-H "Content-Type: application/json" \
|
||||
```
|
||||
|
||||
|
||||
Expected Response:
|
||||
|
||||
```
|
||||
{"id":"batch_94e43f0a-d805-477d-adf9-bbb9c50910ed",...}
|
||||
```
|
||||
|
||||
4. List batch
|
||||
|
||||
```bash
|
||||
curl http://0.0.0.0:4000/v1/batches?limit=2 \
|
||||
-H "Authorization: Bearer $LITELLM_API_KEY" \
|
||||
-H "Content-Type: application/json"
|
||||
```
|
||||
|
||||
Expected Response:
|
||||
|
||||
```bash
|
||||
{"data":[{"id":"batch_R3V...}
|
||||
```
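
The same flow can also be driven from the OpenAI Python SDK pointed at the proxy. This is a sketch, not an official example: the `/azure/v1` base path and the `sk-1234` key are assumptions carried over from the curl commands above.

```python
from openai import OpenAI

# Point the SDK at the LiteLLM proxy's Azure passthrough route,
# so calls land on /azure/v1/files and /azure/v1/batches.
client = OpenAI(
    api_key="sk-1234",                        # your LiteLLM proxy key
    base_url="http://0.0.0.0:4000/azure/v1",  # assumed from the curl examples above
)

# 1. Upload the .jsonl input file
batch_file = client.files.create(file=open("mydata.jsonl", "rb"), purpose="batch")

# 2. Create the batch job
batch = client.batches.create(
    input_file_id=batch_file.id,
    endpoint="/v1/chat/completions",
    completion_window="24h",
)

# 3. Poll its status
print(client.batches.retrieve(batch.id).status)
```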
|
||||
|
||||
|
||||
## Advanced
|
||||
### Azure API Load-Balancing
|
||||
|
||||
|
|
|
@ -2393,6 +2393,114 @@ print("response from proxy", response)
|
|||
</TabItem>
|
||||
</Tabs>
|
||||
|
||||
## **Batch APIs**
|
||||
|
||||
Just add the following Vertex env vars to your environment.
|
||||
|
||||
```bash
|
||||
# GCS Bucket settings, used to store batch prediction files in
|
||||
export GCS_BUCKET_NAME="litellm-testing-bucket" # the bucket you want to store batch prediction files in
|
||||
export GCS_PATH_SERVICE_ACCOUNT="/path/to/service_account.json" # path to your service account json file
|
||||
|
||||
# Vertex /batch endpoint settings, used for LLM API requests
|
||||
export GOOGLE_APPLICATION_CREDENTIALS="/path/to/service_account.json" # path to your service account json file
|
||||
export VERTEXAI_LOCATION="us-central1" # can be any vertex location
|
||||
export VERTEXAI_PROJECT="my-test-project"
|
||||
```
|
||||
|
||||
### Usage
|
||||
|
||||
|
||||
#### 1. Create a file of batch requests for vertex
|
||||
|
||||
LiteLLM expects the file to follow the **[OpenAI batches files format](https://platform.openai.com/docs/guides/batch)**
|
||||
|
||||
Each `body` in the file should be an **OpenAI API request**
|
||||
|
||||
Create a file called `vertex_batch_completions.jsonl` in the current working directory; the `model` should be the Vertex AI model name.
|
||||
```
|
||||
{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gemini-1.5-flash-001", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 10}}
|
||||
{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gemini-1.5-flash-001", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 10}}
|
||||
```
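
If you prefer to generate this file programmatically, here is a minimal sketch that writes the same two requests:

```python
import json

# Build OpenAI-format batch requests (same shape as the example above) and write them as JSONL.
system_prompts = ["You are a helpful assistant.", "You are an unhelpful assistant."]
with open("vertex_batch_completions.jsonl", "w") as f:
    for i, system_prompt in enumerate(system_prompts, start=1):
        request = {
            "custom_id": f"request-{i}",
            "method": "POST",
            "url": "/v1/chat/completions",
            "body": {
                "model": "gemini-1.5-flash-001",
                "messages": [
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": "Hello world!"},
                ],
                "max_tokens": 10,
            },
        }
        f.write(json.dumps(request) + "\n")
```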
|
||||
|
||||
|
||||
#### 2. Upload a File of batch requests
|
||||
|
||||
For `vertex_ai`, LiteLLM uploads the file to the `GCS_BUCKET_NAME` you configured.
|
||||
|
||||
```python
|
||||
import os

from openai import OpenAI
|
||||
oai_client = OpenAI(
|
||||
api_key="sk-1234", # litellm proxy API key
|
||||
base_url="http://localhost:4000" # litellm proxy base url
|
||||
)
|
||||
file_name = "vertex_batch_completions.jsonl" #
|
||||
_current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
file_path = os.path.join(_current_dir, file_name)
|
||||
file_obj = oai_client.files.create(
|
||||
file=open(file_path, "rb"),
|
||||
purpose="batch",
|
||||
extra_body={"custom_llm_provider": "vertex_ai"}, # tell litellm to use vertex_ai for this file upload
|
||||
)
|
||||
```
|
||||
|
||||
**Expected Response**
|
||||
|
||||
```json
|
||||
{
|
||||
"id": "gs://litellm-testing-bucket/litellm-vertex-files/publishers/google/models/gemini-1.5-flash-001/d3f198cd-c0d1-436d-9b1e-28e3f282997a",
|
||||
"bytes": 416,
|
||||
"created_at": 1733392026,
|
||||
"filename": "litellm-vertex-files/publishers/google/models/gemini-1.5-flash-001/d3f198cd-c0d1-436d-9b1e-28e3f282997a",
|
||||
"object": "file",
|
||||
"purpose": "batch",
|
||||
"status": "uploaded",
|
||||
"status_details": null
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
|
||||
#### 3. Create a batch
|
||||
|
||||
```python
|
||||
batch_input_file_id = file_obj.id # use `file_obj` from step 2
|
||||
create_batch_response = oai_client.batches.create(
|
||||
completion_window="24h",
|
||||
endpoint="/v1/chat/completions",
|
||||
input_file_id=batch_input_file_id, # example input_file_id = "gs://litellm-testing-bucket/litellm-vertex-files/publishers/google/models/gemini-1.5-flash-001/c2b1b785-252b-448c-b180-033c4c63b3ce"
|
||||
extra_body={"custom_llm_provider": "vertex_ai"}, # tell litellm to use `vertex_ai` for this batch request
|
||||
)
|
||||
```
|
||||
|
||||
**Expected Response**
|
||||
|
||||
```json
|
||||
{
|
||||
"id": "projects/633608382793/locations/us-central1/batchPredictionJobs/986266568679751680",
|
||||
"completion_window": "24hrs",
|
||||
"created_at": 1733392026,
|
||||
"endpoint": "",
|
||||
"input_file_id": "gs://litellm-testing-bucket/litellm-vertex-files/publishers/google/models/gemini-1.5-flash-001/d3f198cd-c0d1-436d-9b1e-28e3f282997a",
|
||||
"object": "batch",
|
||||
"status": "validating",
|
||||
"cancelled_at": null,
|
||||
"cancelling_at": null,
|
||||
"completed_at": null,
|
||||
"error_file_id": null,
|
||||
"errors": null,
|
||||
"expired_at": null,
|
||||
"expires_at": null,
|
||||
"failed_at": null,
|
||||
"finalizing_at": null,
|
||||
"in_progress_at": null,
|
||||
"metadata": null,
|
||||
"output_file_id": "gs://litellm-testing-bucket/litellm-vertex-files/publishers/google/models/gemini-1.5-flash-001",
|
||||
"request_counts": null
|
||||
}
|
||||
```
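
The response follows the OpenAI `Batch` schema, but the values are Vertex AI resources: `id` is the `batchPredictionJobs` resource name and `output_file_id` is a `gs://` prefix in your GCS bucket. Continuing the snippet from step 3:

```python
# Inspect the OpenAI-format Batch object returned for the Vertex job.
print("batch id:", create_batch_response.id)                   # batchPredictionJobs resource name
print("status:", create_batch_response.status)                 # "validating" while the job is queued
print("input file:", create_batch_response.input_file_id)      # gs:// URI uploaded in step 2
print("output prefix:", create_batch_response.output_file_id)  # gs:// prefix where results will be written
```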
|
||||
|
||||
|
||||
## Extra
|
||||
|
||||
### Using `GOOGLE_APPLICATION_CREDENTIALS`
|
||||
|
|
|
@ -22,6 +22,9 @@ import litellm
|
|||
from litellm import client
|
||||
from litellm.llms.AzureOpenAI.azure import AzureBatchesAPI
|
||||
from litellm.llms.OpenAI.openai import OpenAIBatchesAPI
|
||||
from litellm.llms.vertex_ai_and_google_ai_studio.batches.handler import (
|
||||
VertexAIBatchPrediction,
|
||||
)
|
||||
from litellm.secret_managers.main import get_secret, get_secret_str
|
||||
from litellm.types.llms.openai import (
|
||||
Batch,
|
||||
|
@ -40,6 +43,7 @@ from litellm.utils import supports_httpx_timeout
|
|||
####### ENVIRONMENT VARIABLES ###################
|
||||
openai_batches_instance = OpenAIBatchesAPI()
|
||||
azure_batches_instance = AzureBatchesAPI()
|
||||
vertex_ai_batches_instance = VertexAIBatchPrediction(gcs_bucket_name="")
|
||||
#################################################
|
||||
|
||||
|
||||
|
@ -47,7 +51,7 @@ async def acreate_batch(
|
|||
completion_window: Literal["24h"],
|
||||
endpoint: Literal["/v1/chat/completions", "/v1/embeddings", "/v1/completions"],
|
||||
input_file_id: str,
|
||||
custom_llm_provider: Literal["openai", "azure"] = "openai",
|
||||
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
|
||||
metadata: Optional[Dict[str, str]] = None,
|
||||
extra_headers: Optional[Dict[str, str]] = None,
|
||||
extra_body: Optional[Dict[str, str]] = None,
|
||||
|
@ -93,7 +97,7 @@ def create_batch(
|
|||
completion_window: Literal["24h"],
|
||||
endpoint: Literal["/v1/chat/completions", "/v1/embeddings", "/v1/completions"],
|
||||
input_file_id: str,
|
||||
custom_llm_provider: Literal["openai", "azure"] = "openai",
|
||||
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
|
||||
metadata: Optional[Dict[str, str]] = None,
|
||||
extra_headers: Optional[Dict[str, str]] = None,
|
||||
extra_body: Optional[Dict[str, str]] = None,
|
||||
|
@ -199,6 +203,32 @@ def create_batch(
|
|||
max_retries=optional_params.max_retries,
|
||||
create_batch_data=_create_batch_request,
|
||||
)
|
||||
elif custom_llm_provider == "vertex_ai":
|
||||
api_base = optional_params.api_base or ""
|
||||
vertex_ai_project = (
|
||||
optional_params.vertex_project
|
||||
or litellm.vertex_project
|
||||
or get_secret_str("VERTEXAI_PROJECT")
|
||||
)
|
||||
vertex_ai_location = (
|
||||
optional_params.vertex_location
|
||||
or litellm.vertex_location
|
||||
or get_secret_str("VERTEXAI_LOCATION")
|
||||
)
|
||||
vertex_credentials = optional_params.vertex_credentials or get_secret_str(
|
||||
"VERTEXAI_CREDENTIALS"
|
||||
)
|
||||
|
||||
response = vertex_ai_batches_instance.create_batch(
|
||||
_is_async=_is_async,
|
||||
api_base=api_base,
|
||||
vertex_project=vertex_ai_project,
|
||||
vertex_location=vertex_ai_location,
|
||||
vertex_credentials=vertex_credentials,
|
||||
timeout=timeout,
|
||||
max_retries=optional_params.max_retries,
|
||||
create_batch_data=_create_batch_request,
|
||||
)
|
||||
else:
|
||||
raise litellm.exceptions.BadRequestError(
|
||||
message="LiteLLM doesn't support {} for 'create_batch'. Only 'openai' is supported.".format(
|
||||
|
|
|
@ -17,6 +17,9 @@ import litellm
|
|||
from litellm import client, get_secret_str
|
||||
from litellm.llms.files_apis.azure import AzureOpenAIFilesAPI
|
||||
from litellm.llms.OpenAI.openai import FileDeleted, FileObject, OpenAIFilesAPI
|
||||
from litellm.llms.vertex_ai_and_google_ai_studio.files.handler import (
|
||||
VertexAIFilesHandler,
|
||||
)
|
||||
from litellm.types.llms.openai import (
|
||||
Batch,
|
||||
CreateFileRequest,
|
||||
|
@ -30,6 +33,7 @@ from litellm.utils import supports_httpx_timeout
|
|||
####### ENVIRONMENT VARIABLES ###################
|
||||
openai_files_instance = OpenAIFilesAPI()
|
||||
azure_files_instance = AzureOpenAIFilesAPI()
|
||||
vertex_ai_files_instance = VertexAIFilesHandler()
|
||||
#################################################
|
||||
|
||||
|
||||
|
@ -490,7 +494,7 @@ def file_list(
|
|||
async def acreate_file(
|
||||
file: FileTypes,
|
||||
purpose: Literal["assistants", "batch", "fine-tune"],
|
||||
custom_llm_provider: Literal["openai", "azure"] = "openai",
|
||||
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
|
||||
extra_headers: Optional[Dict[str, str]] = None,
|
||||
extra_body: Optional[Dict[str, str]] = None,
|
||||
**kwargs,
|
||||
|
@ -532,7 +536,7 @@ async def acreate_file(
|
|||
def create_file(
|
||||
file: FileTypes,
|
||||
purpose: Literal["assistants", "batch", "fine-tune"],
|
||||
custom_llm_provider: Literal["openai", "azure"] = "openai",
|
||||
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
|
||||
extra_headers: Optional[Dict[str, str]] = None,
|
||||
extra_body: Optional[Dict[str, str]] = None,
|
||||
**kwargs,
|
||||
|
@ -630,6 +634,32 @@ def create_file(
|
|||
max_retries=optional_params.max_retries,
|
||||
create_file_data=_create_file_request,
|
||||
)
|
||||
elif custom_llm_provider == "vertex_ai":
|
||||
api_base = optional_params.api_base or ""
|
||||
vertex_ai_project = (
|
||||
optional_params.vertex_project
|
||||
or litellm.vertex_project
|
||||
or get_secret_str("VERTEXAI_PROJECT")
|
||||
)
|
||||
vertex_ai_location = (
|
||||
optional_params.vertex_location
|
||||
or litellm.vertex_location
|
||||
or get_secret_str("VERTEXAI_LOCATION")
|
||||
)
|
||||
vertex_credentials = optional_params.vertex_credentials or get_secret_str(
|
||||
"VERTEXAI_CREDENTIALS"
|
||||
)
|
||||
|
||||
response = vertex_ai_files_instance.create_file(
|
||||
_is_async=_is_async,
|
||||
api_base=api_base,
|
||||
vertex_project=vertex_ai_project,
|
||||
vertex_location=vertex_ai_location,
|
||||
vertex_credentials=vertex_credentials,
|
||||
timeout=timeout,
|
||||
max_retries=optional_params.max_retries,
|
||||
create_file_data=_create_file_request,
|
||||
)
|
||||
else:
|
||||
raise litellm.exceptions.BadRequestError(
|
||||
message="LiteLLM doesn't support {} for 'create_batch'. Only 'openai' is supported.".format(
|
||||
|
|
|
@ -29,7 +29,6 @@ else:
|
|||
VertexBase = Any
|
||||
|
||||
|
||||
IAM_AUTH_KEY = "IAM_AUTH"
|
||||
GCS_DEFAULT_BATCH_SIZE = 2048
|
||||
GCS_DEFAULT_FLUSH_INTERVAL_SECONDS = 20
|
||||
|
||||
|
@ -39,7 +38,6 @@ class GCSBucketLogger(GCSBucketBase):
|
|||
from litellm.proxy.proxy_server import premium_user
|
||||
|
||||
super().__init__(bucket_name=bucket_name)
|
||||
self.vertex_instances: Dict[str, VertexBase] = {}
|
||||
|
||||
# Init Batch logging settings
|
||||
self.log_queue: List[GCSLogQueueItem] = []
|
||||
|
@ -178,232 +176,3 @@ class GCSBucketLogger(GCSBucketBase):
|
|||
object_name = _metadata["gcs_log_id"]
|
||||
|
||||
return object_name
|
||||
|
||||
def _handle_folders_in_bucket_name(
|
||||
self,
|
||||
bucket_name: str,
|
||||
object_name: str,
|
||||
) -> Tuple[str, str]:
|
||||
"""
|
||||
Handles when the user passes a bucket name with a folder postfix
|
||||
|
||||
|
||||
Example:
|
||||
- Bucket name: "my-bucket/my-folder/dev"
|
||||
- Object name: "my-object"
|
||||
- Returns: bucket_name="my-bucket", object_name="my-folder/dev/my-object"
|
||||
|
||||
"""
|
||||
if "/" in bucket_name:
|
||||
bucket_name, prefix = bucket_name.split("/", 1)
|
||||
object_name = f"{prefix}/{object_name}"
|
||||
return bucket_name, object_name
|
||||
return bucket_name, object_name
|
||||
|
||||
async def _log_json_data_on_gcs(
|
||||
self,
|
||||
headers: Dict[str, str],
|
||||
bucket_name: str,
|
||||
object_name: str,
|
||||
logging_payload: StandardLoggingPayload,
|
||||
):
|
||||
"""
|
||||
Helper function to make POST request to GCS Bucket in the specified bucket.
|
||||
"""
|
||||
json_logged_payload = json.dumps(logging_payload, default=str)
|
||||
|
||||
bucket_name, object_name = self._handle_folders_in_bucket_name(
|
||||
bucket_name=bucket_name,
|
||||
object_name=object_name,
|
||||
)
|
||||
|
||||
response = await self.async_httpx_client.post(
|
||||
headers=headers,
|
||||
url=f"https://storage.googleapis.com/upload/storage/v1/b/{bucket_name}/o?uploadType=media&name={object_name}",
|
||||
data=json_logged_payload,
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
verbose_logger.error("GCS Bucket logging error: %s", str(response.text))
|
||||
|
||||
verbose_logger.debug("GCS Bucket response %s", response)
|
||||
verbose_logger.debug("GCS Bucket status code %s", response.status_code)
|
||||
verbose_logger.debug("GCS Bucket response.text %s", response.text)
|
||||
|
||||
async def get_gcs_logging_config(
|
||||
self, kwargs: Optional[Dict[str, Any]] = {}
|
||||
) -> GCSLoggingConfig:
|
||||
"""
|
||||
This function is used to get the GCS logging config for the GCS Bucket Logger.
|
||||
It checks if the dynamic parameters are provided in the kwargs and uses them to get the GCS logging config.
|
||||
If no dynamic parameters are provided, it uses the default values.
|
||||
"""
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
|
||||
standard_callback_dynamic_params: Optional[StandardCallbackDynamicParams] = (
|
||||
kwargs.get("standard_callback_dynamic_params", None)
|
||||
)
|
||||
|
||||
bucket_name: str
|
||||
path_service_account: Optional[str]
|
||||
if standard_callback_dynamic_params is not None:
|
||||
verbose_logger.debug("Using dynamic GCS logging")
|
||||
verbose_logger.debug(
|
||||
"standard_callback_dynamic_params: %s", standard_callback_dynamic_params
|
||||
)
|
||||
|
||||
_bucket_name: Optional[str] = (
|
||||
standard_callback_dynamic_params.get("gcs_bucket_name", None)
|
||||
or self.BUCKET_NAME
|
||||
)
|
||||
_path_service_account: Optional[str] = (
|
||||
standard_callback_dynamic_params.get("gcs_path_service_account", None)
|
||||
or self.path_service_account_json
|
||||
)
|
||||
|
||||
if _bucket_name is None:
|
||||
raise ValueError(
|
||||
"GCS_BUCKET_NAME is not set in the environment, but GCS Bucket is being used as a logging callback. Please set 'GCS_BUCKET_NAME' in the environment."
|
||||
)
|
||||
bucket_name = _bucket_name
|
||||
path_service_account = _path_service_account
|
||||
vertex_instance = await self.get_or_create_vertex_instance(
|
||||
credentials=path_service_account
|
||||
)
|
||||
else:
|
||||
# If no dynamic parameters, use the default instance
|
||||
if self.BUCKET_NAME is None:
|
||||
raise ValueError(
|
||||
"GCS_BUCKET_NAME is not set in the environment, but GCS Bucket is being used as a logging callback. Please set 'GCS_BUCKET_NAME' in the environment."
|
||||
)
|
||||
bucket_name = self.BUCKET_NAME
|
||||
path_service_account = self.path_service_account_json
|
||||
vertex_instance = await self.get_or_create_vertex_instance(
|
||||
credentials=path_service_account
|
||||
)
|
||||
|
||||
return GCSLoggingConfig(
|
||||
bucket_name=bucket_name,
|
||||
vertex_instance=vertex_instance,
|
||||
path_service_account=path_service_account,
|
||||
)
|
||||
|
||||
async def get_or_create_vertex_instance(
|
||||
self, credentials: Optional[str]
|
||||
) -> VertexBase:
|
||||
"""
|
||||
This function is used to get the Vertex instance for the GCS Bucket Logger.
|
||||
It checks if the Vertex instance is already created and cached, if not it creates a new instance and caches it.
|
||||
"""
|
||||
from litellm.llms.vertex_ai_and_google_ai_studio.vertex_llm_base import (
|
||||
VertexBase,
|
||||
)
|
||||
|
||||
_in_memory_key = self._get_in_memory_key_for_vertex_instance(credentials)
|
||||
if _in_memory_key not in self.vertex_instances:
|
||||
vertex_instance = VertexBase()
|
||||
await vertex_instance._ensure_access_token_async(
|
||||
credentials=credentials,
|
||||
project_id=None,
|
||||
custom_llm_provider="vertex_ai",
|
||||
)
|
||||
self.vertex_instances[_in_memory_key] = vertex_instance
|
||||
return self.vertex_instances[_in_memory_key]
|
||||
|
||||
def _get_in_memory_key_for_vertex_instance(self, credentials: Optional[str]) -> str:
|
||||
"""
|
||||
Returns key to use for caching the Vertex instance in-memory.
|
||||
|
||||
When using Vertex with Key based logging, we need to cache the Vertex instance in-memory.
|
||||
|
||||
- If a credentials string is provided, it is used as the key.
|
||||
- If no credentials string is provided, "IAM_AUTH" is used as the key.
|
||||
"""
|
||||
return credentials or IAM_AUTH_KEY
|
||||
|
||||
async def download_gcs_object(self, object_name: str, **kwargs):
|
||||
"""
|
||||
Download an object from GCS.
|
||||
|
||||
https://cloud.google.com/storage/docs/downloading-objects#download-object-json
|
||||
"""
|
||||
try:
|
||||
gcs_logging_config: GCSLoggingConfig = await self.get_gcs_logging_config(
|
||||
kwargs=kwargs
|
||||
)
|
||||
headers = await self.construct_request_headers(
|
||||
vertex_instance=gcs_logging_config["vertex_instance"],
|
||||
service_account_json=gcs_logging_config["path_service_account"],
|
||||
)
|
||||
bucket_name = gcs_logging_config["bucket_name"]
|
||||
bucket_name, object_name = self._handle_folders_in_bucket_name(
|
||||
bucket_name=bucket_name,
|
||||
object_name=object_name,
|
||||
)
|
||||
|
||||
url = f"https://storage.googleapis.com/storage/v1/b/{bucket_name}/o/{object_name}?alt=media"
|
||||
|
||||
# Send the GET request to download the object
|
||||
response = await self.async_httpx_client.get(url=url, headers=headers)
|
||||
|
||||
if response.status_code != 200:
|
||||
verbose_logger.error(
|
||||
"GCS object download error: %s", str(response.text)
|
||||
)
|
||||
return None
|
||||
|
||||
verbose_logger.debug(
|
||||
"GCS object download response status code: %s", response.status_code
|
||||
)
|
||||
|
||||
# Return the content of the downloaded object
|
||||
return response.content
|
||||
|
||||
except Exception as e:
|
||||
verbose_logger.error("GCS object download error: %s", str(e))
|
||||
return None
|
||||
|
||||
async def delete_gcs_object(self, object_name: str, **kwargs):
|
||||
"""
|
||||
Delete an object from GCS.
|
||||
"""
|
||||
try:
|
||||
gcs_logging_config: GCSLoggingConfig = await self.get_gcs_logging_config(
|
||||
kwargs=kwargs
|
||||
)
|
||||
headers = await self.construct_request_headers(
|
||||
vertex_instance=gcs_logging_config["vertex_instance"],
|
||||
service_account_json=gcs_logging_config["path_service_account"],
|
||||
)
|
||||
bucket_name = gcs_logging_config["bucket_name"]
|
||||
bucket_name, object_name = self._handle_folders_in_bucket_name(
|
||||
bucket_name=bucket_name,
|
||||
object_name=object_name,
|
||||
)
|
||||
|
||||
url = f"https://storage.googleapis.com/storage/v1/b/{bucket_name}/o/{object_name}"
|
||||
|
||||
# Send the DELETE request to delete the object
|
||||
response = await self.async_httpx_client.delete(url=url, headers=headers)
|
||||
|
||||
if (response.status_code != 200) or (response.status_code != 204):
|
||||
verbose_logger.error(
|
||||
"GCS object delete error: %s, status code: %s",
|
||||
str(response.text),
|
||||
response.status_code,
|
||||
)
|
||||
return None
|
||||
|
||||
verbose_logger.debug(
|
||||
"GCS object delete response status code: %s, response: %s",
|
||||
response.status_code,
|
||||
response.text,
|
||||
)
|
||||
|
||||
# Return the content of the downloaded object
|
||||
return response.text
|
||||
|
||||
except Exception as e:
|
||||
verbose_logger.error("GCS object download error: %s", str(e))
|
||||
return None
|
||||
|
|
|
@ -2,7 +2,7 @@ import json
|
|||
import os
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, TypedDict, Union
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, TypedDict, Union
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel, Field
|
||||
|
@ -14,11 +14,18 @@ from litellm.llms.custom_httpx.http_handler import (
|
|||
get_async_httpx_client,
|
||||
httpxSpecialProvider,
|
||||
)
|
||||
from litellm.types.integrations.gcs_bucket import *
|
||||
from litellm.types.utils import (
|
||||
StandardCallbackDynamicParams,
|
||||
StandardLoggingMetadata,
|
||||
StandardLoggingPayload,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.llms.vertex_ai_and_google_ai_studio.vertex_llm_base import VertexBase
|
||||
else:
|
||||
VertexBase = Any
|
||||
IAM_AUTH_KEY = "IAM_AUTH"
|
||||
|
||||
|
||||
class GCSBucketBase(CustomBatchLogger):
|
||||
|
@ -30,6 +37,7 @@ class GCSBucketBase(CustomBatchLogger):
|
|||
_bucket_name = bucket_name or os.getenv("GCS_BUCKET_NAME")
|
||||
self.path_service_account_json: Optional[str] = _path_service_account
|
||||
self.BUCKET_NAME: Optional[str] = _bucket_name
|
||||
self.vertex_instances: Dict[str, VertexBase] = {}
|
||||
super().__init__(**kwargs)
|
||||
|
||||
async def construct_request_headers(
|
||||
|
@ -94,3 +102,237 @@ class GCSBucketBase(CustomBatchLogger):
|
|||
}
|
||||
|
||||
return headers
|
||||
|
||||
def _handle_folders_in_bucket_name(
|
||||
self,
|
||||
bucket_name: str,
|
||||
object_name: str,
|
||||
) -> Tuple[str, str]:
|
||||
"""
|
||||
Handles when the user passes a bucket name with a folder postfix
|
||||
|
||||
|
||||
Example:
|
||||
- Bucket name: "my-bucket/my-folder/dev"
|
||||
- Object name: "my-object"
|
||||
- Returns: bucket_name="my-bucket", object_name="my-folder/dev/my-object"
|
||||
|
||||
"""
|
||||
if "/" in bucket_name:
|
||||
bucket_name, prefix = bucket_name.split("/", 1)
|
||||
object_name = f"{prefix}/{object_name}"
|
||||
return bucket_name, object_name
|
||||
return bucket_name, object_name
|
||||
|
||||
async def get_gcs_logging_config(
|
||||
self, kwargs: Optional[Dict[str, Any]] = {}
|
||||
) -> GCSLoggingConfig:
|
||||
"""
|
||||
This function is used to get the GCS logging config for the GCS Bucket Logger.
|
||||
It checks if the dynamic parameters are provided in the kwargs and uses them to get the GCS logging config.
|
||||
If no dynamic parameters are provided, it uses the default values.
|
||||
"""
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
|
||||
standard_callback_dynamic_params: Optional[StandardCallbackDynamicParams] = (
|
||||
kwargs.get("standard_callback_dynamic_params", None)
|
||||
)
|
||||
|
||||
bucket_name: str
|
||||
path_service_account: Optional[str]
|
||||
if standard_callback_dynamic_params is not None:
|
||||
verbose_logger.debug("Using dynamic GCS logging")
|
||||
verbose_logger.debug(
|
||||
"standard_callback_dynamic_params: %s", standard_callback_dynamic_params
|
||||
)
|
||||
|
||||
_bucket_name: Optional[str] = (
|
||||
standard_callback_dynamic_params.get("gcs_bucket_name", None)
|
||||
or self.BUCKET_NAME
|
||||
)
|
||||
_path_service_account: Optional[str] = (
|
||||
standard_callback_dynamic_params.get("gcs_path_service_account", None)
|
||||
or self.path_service_account_json
|
||||
)
|
||||
|
||||
if _bucket_name is None:
|
||||
raise ValueError(
|
||||
"GCS_BUCKET_NAME is not set in the environment, but GCS Bucket is being used as a logging callback. Please set 'GCS_BUCKET_NAME' in the environment."
|
||||
)
|
||||
bucket_name = _bucket_name
|
||||
path_service_account = _path_service_account
|
||||
vertex_instance = await self.get_or_create_vertex_instance(
|
||||
credentials=path_service_account
|
||||
)
|
||||
else:
|
||||
# If no dynamic parameters, use the default instance
|
||||
if self.BUCKET_NAME is None:
|
||||
raise ValueError(
|
||||
"GCS_BUCKET_NAME is not set in the environment, but GCS Bucket is being used as a logging callback. Please set 'GCS_BUCKET_NAME' in the environment."
|
||||
)
|
||||
bucket_name = self.BUCKET_NAME
|
||||
path_service_account = self.path_service_account_json
|
||||
vertex_instance = await self.get_or_create_vertex_instance(
|
||||
credentials=path_service_account
|
||||
)
|
||||
|
||||
return GCSLoggingConfig(
|
||||
bucket_name=bucket_name,
|
||||
vertex_instance=vertex_instance,
|
||||
path_service_account=path_service_account,
|
||||
)
|
||||
|
||||
async def get_or_create_vertex_instance(
|
||||
self, credentials: Optional[str]
|
||||
) -> VertexBase:
|
||||
"""
|
||||
This function is used to get the Vertex instance for the GCS Bucket Logger.
|
||||
It checks if the Vertex instance is already created and cached, if not it creates a new instance and caches it.
|
||||
"""
|
||||
from litellm.llms.vertex_ai_and_google_ai_studio.vertex_llm_base import (
|
||||
VertexBase,
|
||||
)
|
||||
|
||||
_in_memory_key = self._get_in_memory_key_for_vertex_instance(credentials)
|
||||
if _in_memory_key not in self.vertex_instances:
|
||||
vertex_instance = VertexBase()
|
||||
await vertex_instance._ensure_access_token_async(
|
||||
credentials=credentials,
|
||||
project_id=None,
|
||||
custom_llm_provider="vertex_ai",
|
||||
)
|
||||
self.vertex_instances[_in_memory_key] = vertex_instance
|
||||
return self.vertex_instances[_in_memory_key]
|
||||
|
||||
def _get_in_memory_key_for_vertex_instance(self, credentials: Optional[str]) -> str:
|
||||
"""
|
||||
Returns key to use for caching the Vertex instance in-memory.
|
||||
|
||||
When using Vertex with Key based logging, we need to cache the Vertex instance in-memory.
|
||||
|
||||
- If a credentials string is provided, it is used as the key.
|
||||
- If no credentials string is provided, "IAM_AUTH" is used as the key.
|
||||
"""
|
||||
return credentials or IAM_AUTH_KEY
|
||||
|
||||
async def download_gcs_object(self, object_name: str, **kwargs):
|
||||
"""
|
||||
Download an object from GCS.
|
||||
|
||||
https://cloud.google.com/storage/docs/downloading-objects#download-object-json
|
||||
"""
|
||||
try:
|
||||
gcs_logging_config: GCSLoggingConfig = await self.get_gcs_logging_config(
|
||||
kwargs=kwargs
|
||||
)
|
||||
headers = await self.construct_request_headers(
|
||||
vertex_instance=gcs_logging_config["vertex_instance"],
|
||||
service_account_json=gcs_logging_config["path_service_account"],
|
||||
)
|
||||
bucket_name = gcs_logging_config["bucket_name"]
|
||||
bucket_name, object_name = self._handle_folders_in_bucket_name(
|
||||
bucket_name=bucket_name,
|
||||
object_name=object_name,
|
||||
)
|
||||
|
||||
url = f"https://storage.googleapis.com/storage/v1/b/{bucket_name}/o/{object_name}?alt=media"
|
||||
|
||||
# Send the GET request to download the object
|
||||
response = await self.async_httpx_client.get(url=url, headers=headers)
|
||||
|
||||
if response.status_code != 200:
|
||||
verbose_logger.error(
|
||||
"GCS object download error: %s", str(response.text)
|
||||
)
|
||||
return None
|
||||
|
||||
verbose_logger.debug(
|
||||
"GCS object download response status code: %s", response.status_code
|
||||
)
|
||||
|
||||
# Return the content of the downloaded object
|
||||
return response.content
|
||||
|
||||
except Exception as e:
|
||||
verbose_logger.error("GCS object download error: %s", str(e))
|
||||
return None
|
||||
|
||||
async def delete_gcs_object(self, object_name: str, **kwargs):
|
||||
"""
|
||||
Delete an object from GCS.
|
||||
"""
|
||||
try:
|
||||
gcs_logging_config: GCSLoggingConfig = await self.get_gcs_logging_config(
|
||||
kwargs=kwargs
|
||||
)
|
||||
headers = await self.construct_request_headers(
|
||||
vertex_instance=gcs_logging_config["vertex_instance"],
|
||||
service_account_json=gcs_logging_config["path_service_account"],
|
||||
)
|
||||
bucket_name = gcs_logging_config["bucket_name"]
|
||||
bucket_name, object_name = self._handle_folders_in_bucket_name(
|
||||
bucket_name=bucket_name,
|
||||
object_name=object_name,
|
||||
)
|
||||
|
||||
url = f"https://storage.googleapis.com/storage/v1/b/{bucket_name}/o/{object_name}"
|
||||
|
||||
# Send the DELETE request to delete the object
|
||||
response = await self.async_httpx_client.delete(url=url, headers=headers)
|
||||
|
||||
if (response.status_code != 200) or (response.status_code != 204):
|
||||
verbose_logger.error(
|
||||
"GCS object delete error: %s, status code: %s",
|
||||
str(response.text),
|
||||
response.status_code,
|
||||
)
|
||||
return None
|
||||
|
||||
verbose_logger.debug(
|
||||
"GCS object delete response status code: %s, response: %s",
|
||||
response.status_code,
|
||||
response.text,
|
||||
)
|
||||
|
||||
# Return the content of the downloaded object
|
||||
return response.text
|
||||
|
||||
except Exception as e:
|
||||
verbose_logger.error("GCS object download error: %s", str(e))
|
||||
return None
|
||||
|
||||
async def _log_json_data_on_gcs(
|
||||
self,
|
||||
headers: Dict[str, str],
|
||||
bucket_name: str,
|
||||
object_name: str,
|
||||
logging_payload: Union[StandardLoggingPayload, str],
|
||||
):
|
||||
"""
|
||||
Helper function to make POST request to GCS Bucket in the specified bucket.
|
||||
"""
|
||||
if isinstance(logging_payload, str):
|
||||
json_logged_payload = logging_payload
|
||||
else:
|
||||
json_logged_payload = json.dumps(logging_payload, default=str)
|
||||
|
||||
bucket_name, object_name = self._handle_folders_in_bucket_name(
|
||||
bucket_name=bucket_name,
|
||||
object_name=object_name,
|
||||
)
|
||||
|
||||
response = await self.async_httpx_client.post(
|
||||
headers=headers,
|
||||
url=f"https://storage.googleapis.com/upload/storage/v1/b/{bucket_name}/o?uploadType=media&name={object_name}",
|
||||
data=json_logged_payload,
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
verbose_logger.error("GCS Bucket logging error: %s", str(response.text))
|
||||
|
||||
verbose_logger.debug("GCS Bucket response %s", response)
|
||||
verbose_logger.debug("GCS Bucket status code %s", response.status_code)
|
||||
verbose_logger.debug("GCS Bucket response.text %s", response.text)
|
||||
|
||||
return response.json()
|
||||
|
|
|
@ -0,0 +1,6 @@
|
|||
# Vertex AI Batch Prediction Jobs
|
||||
|
||||
Implementation for calling Vertex AI Batch endpoints using the OpenAI Batch API spec.
|
||||
|
||||
Vertex Docs: https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/batch-prediction-gemini
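
A minimal end-to-end sketch using the LiteLLM SDK entrypoints this change touches (`litellm.acreate_file` and `litellm.acreate_batch` with `custom_llm_provider="vertex_ai"`). The file name is a placeholder, and the GCS / Vertex env vars described in the docs (`GCS_BUCKET_NAME`, `GOOGLE_APPLICATION_CREDENTIALS`, `VERTEXAI_PROJECT`, `VERTEXAI_LOCATION`) are assumed to be set:

```python
import asyncio

import litellm


async def main():
    # Upload the OpenAI-format .jsonl to the configured GCS bucket
    # (Vertex file creation is async-only in this implementation).
    file_obj = await litellm.acreate_file(
        file=open("vertex_batch_completions.jsonl", "rb"),
        purpose="batch",
        custom_llm_provider="vertex_ai",
    )

    # Create the Vertex batch prediction job from the uploaded gs:// file.
    batch = await litellm.acreate_batch(
        completion_window="24h",
        endpoint="/v1/chat/completions",
        input_file_id=file_obj.id,  # gs:// URI returned by acreate_file
        custom_llm_provider="vertex_ai",
    )
    print(batch.id, batch.status)


asyncio.run(main())
```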
|
||||
|
litellm/llms/vertex_ai_and_google_ai_studio/batches/handler.py (new file, 141 lines)
|
@ -0,0 +1,141 @@
|
|||
import json
|
||||
from typing import Any, Coroutine, Dict, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
AsyncHTTPHandler,
|
||||
HTTPHandler,
|
||||
_get_httpx_client,
|
||||
get_async_httpx_client,
|
||||
)
|
||||
from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
|
||||
VertexAIError,
|
||||
VertexLLM,
|
||||
)
|
||||
from litellm.types.llms.openai import (
|
||||
Batch,
|
||||
CancelBatchRequest,
|
||||
CreateBatchRequest,
|
||||
CreateFileRequest,
|
||||
FileContentRequest,
|
||||
FileObject,
|
||||
FileTypes,
|
||||
HttpxBinaryResponseContent,
|
||||
RetrieveBatchRequest,
|
||||
)
|
||||
from litellm.types.llms.vertex_ai import VertexAIBatchPredictionJob
|
||||
|
||||
from .transformation import VertexAIBatchTransformation
|
||||
|
||||
|
||||
class VertexAIBatchPrediction(VertexLLM):
|
||||
def __init__(self, gcs_bucket_name: str, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.gcs_bucket_name = gcs_bucket_name
|
||||
|
||||
def create_batch(
|
||||
self,
|
||||
_is_async: bool,
|
||||
create_batch_data: CreateBatchRequest,
|
||||
api_base: Optional[str],
|
||||
vertex_credentials: Optional[str],
|
||||
vertex_project: Optional[str],
|
||||
vertex_location: Optional[str],
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
) -> Union[Batch, Coroutine[Any, Any, Batch]]:
|
||||
|
||||
sync_handler = _get_httpx_client()
|
||||
|
||||
access_token, project_id = self._ensure_access_token(
|
||||
credentials=vertex_credentials,
|
||||
project_id=vertex_project,
|
||||
custom_llm_provider="vertex_ai",
|
||||
)
|
||||
|
||||
default_api_base = self.create_vertex_url(
|
||||
vertex_location=vertex_location or "us-central1",
|
||||
vertex_project=vertex_project or project_id,
|
||||
)
|
||||
|
||||
if len(default_api_base.split(":")) > 1:
|
||||
endpoint = default_api_base.split(":")[-1]
|
||||
else:
|
||||
endpoint = ""
|
||||
|
||||
_, api_base = self._check_custom_proxy(
|
||||
api_base=api_base,
|
||||
custom_llm_provider="vertex_ai",
|
||||
gemini_api_key=None,
|
||||
endpoint=endpoint,
|
||||
stream=None,
|
||||
auth_header=None,
|
||||
url=default_api_base,
|
||||
)
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/json; charset=utf-8",
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
}
|
||||
|
||||
vertex_batch_request: VertexAIBatchPredictionJob = (
|
||||
VertexAIBatchTransformation.transform_openai_batch_request_to_vertex_ai_batch_request(
|
||||
request=create_batch_data
|
||||
)
|
||||
)
|
||||
|
||||
if _is_async is True:
|
||||
return self._async_create_batch(
|
||||
vertex_batch_request=vertex_batch_request,
|
||||
api_base=api_base,
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
response = sync_handler.post(
|
||||
url=api_base,
|
||||
headers=headers,
|
||||
data=json.dumps(vertex_batch_request),
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise Exception(f"Error: {response.status_code} {response.text}")
|
||||
|
||||
_json_response = response.json()
|
||||
vertex_batch_response = VertexAIBatchTransformation.transform_vertex_ai_batch_response_to_openai_batch_response(
|
||||
response=_json_response
|
||||
)
|
||||
return vertex_batch_response
|
||||
|
||||
async def _async_create_batch(
|
||||
self,
|
||||
vertex_batch_request: VertexAIBatchPredictionJob,
|
||||
api_base: str,
|
||||
headers: Dict[str, str],
|
||||
) -> Batch:
|
||||
client = get_async_httpx_client(
|
||||
llm_provider=litellm.LlmProviders.VERTEX_AI,
|
||||
)
|
||||
response = await client.post(
|
||||
url=api_base,
|
||||
headers=headers,
|
||||
data=json.dumps(vertex_batch_request),
|
||||
)
|
||||
if response.status_code != 200:
|
||||
raise Exception(f"Error: {response.status_code} {response.text}")
|
||||
|
||||
_json_response = response.json()
|
||||
vertex_batch_response = VertexAIBatchTransformation.transform_vertex_ai_batch_response_to_openai_batch_response(
|
||||
response=_json_response
|
||||
)
|
||||
return vertex_batch_response
|
||||
|
||||
def create_vertex_url(
|
||||
self,
|
||||
vertex_location: str,
|
||||
vertex_project: str,
|
||||
) -> str:
|
||||
"""Return the base url for the vertex garden models"""
|
||||
# POST https://LOCATION-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/LOCATION/batchPredictionJobs
|
||||
return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/batchPredictionJobs"
|
|
@ -0,0 +1,174 @@
|
|||
import uuid
|
||||
from typing import Any, Dict, Literal
|
||||
|
||||
from litellm.llms.vertex_ai_and_google_ai_studio.common_utils import (
|
||||
_convert_vertex_datetime_to_openai_datetime,
|
||||
)
|
||||
from litellm.types.llms.openai import Batch, BatchJobStatus, CreateBatchRequest
|
||||
from litellm.types.llms.vertex_ai import *
|
||||
|
||||
|
||||
class VertexAIBatchTransformation:
|
||||
"""
|
||||
Transforms OpenAI Batch requests to Vertex AI Batch requests
|
||||
|
||||
API Ref: https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/batch-prediction-gemini
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def transform_openai_batch_request_to_vertex_ai_batch_request(
|
||||
cls,
|
||||
request: CreateBatchRequest,
|
||||
) -> VertexAIBatchPredictionJob:
|
||||
"""
|
||||
Transforms OpenAI Batch requests to Vertex AI Batch requests
|
||||
"""
|
||||
request_display_name = f"litellm-vertex-batch-{uuid.uuid4()}"
|
||||
input_file_id = request.get("input_file_id")
|
||||
if input_file_id is None:
|
||||
raise ValueError("input_file_id is required, but not provided")
|
||||
input_config: InputConfig = InputConfig(
|
||||
gcsSource=GcsSource(uris=input_file_id), instancesFormat="jsonl"
|
||||
)
|
||||
model: str = cls._get_model_from_gcs_file(input_file_id)
|
||||
output_config: OutputConfig = OutputConfig(
|
||||
predictionsFormat="jsonl",
|
||||
gcsDestination=GcsDestination(
|
||||
outputUriPrefix=cls._get_gcs_uri_prefix_from_file(input_file_id)
|
||||
),
|
||||
)
|
||||
return VertexAIBatchPredictionJob(
|
||||
inputConfig=input_config,
|
||||
outputConfig=output_config,
|
||||
model=model,
|
||||
displayName=request_display_name,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def transform_vertex_ai_batch_response_to_openai_batch_response(
|
||||
cls, response: VertexBatchPredictionResponse
|
||||
) -> Batch:
|
||||
return Batch(
|
||||
id=response.get("name", ""),
|
||||
completion_window="24hrs",
|
||||
created_at=_convert_vertex_datetime_to_openai_datetime(
|
||||
vertex_datetime=response.get("createTime", "")
|
||||
),
|
||||
endpoint="",
|
||||
input_file_id=cls._get_input_file_id_from_vertex_ai_batch_response(
|
||||
response
|
||||
),
|
||||
object="batch",
|
||||
status=cls._get_batch_job_status_from_vertex_ai_batch_response(response),
|
||||
error_file_id=None, # Vertex AI doesn't seem to have a direct equivalent
|
||||
output_file_id=cls._get_output_file_id_from_vertex_ai_batch_response(
|
||||
response
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _get_input_file_id_from_vertex_ai_batch_response(
|
||||
cls, response: VertexBatchPredictionResponse
|
||||
) -> str:
|
||||
"""
|
||||
Gets the input file id from the Vertex AI Batch response
|
||||
"""
|
||||
input_file_id: str = ""
|
||||
input_config = response.get("inputConfig")
|
||||
if input_config is None:
|
||||
return input_file_id
|
||||
|
||||
gcs_source = input_config.get("gcsSource")
|
||||
if gcs_source is None:
|
||||
return input_file_id
|
||||
|
||||
uris = gcs_source.get("uris", "")
|
||||
if len(uris) == 0:
|
||||
return input_file_id
|
||||
|
||||
return uris[0]
|
||||
|
||||
@classmethod
|
||||
def _get_output_file_id_from_vertex_ai_batch_response(
|
||||
cls, response: VertexBatchPredictionResponse
|
||||
) -> str:
|
||||
"""
|
||||
Gets the output file id from the Vertex AI Batch response
|
||||
"""
|
||||
output_file_id: str = ""
|
||||
output_config = response.get("outputConfig")
|
||||
if output_config is None:
|
||||
return output_file_id
|
||||
|
||||
gcs_destination = output_config.get("gcsDestination")
|
||||
if gcs_destination is None:
|
||||
return output_file_id
|
||||
|
||||
output_uri_prefix = gcs_destination.get("outputUriPrefix", "")
|
||||
return output_uri_prefix
|
||||
|
||||
@classmethod
|
||||
def _get_batch_job_status_from_vertex_ai_batch_response(
|
||||
cls, response: VertexBatchPredictionResponse
|
||||
) -> BatchJobStatus:
|
||||
"""
|
||||
Gets the batch job status from the Vertex AI Batch response
|
||||
|
||||
ref: https://cloud.google.com/vertex-ai/docs/reference/rest/v1/JobState
|
||||
"""
|
||||
state_mapping: Dict[str, BatchJobStatus] = {
|
||||
"JOB_STATE_UNSPECIFIED": "failed",
|
||||
"JOB_STATE_QUEUED": "validating",
|
||||
"JOB_STATE_PENDING": "validating",
|
||||
"JOB_STATE_RUNNING": "in_progress",
|
||||
"JOB_STATE_SUCCEEDED": "completed",
|
||||
"JOB_STATE_FAILED": "failed",
|
||||
"JOB_STATE_CANCELLING": "cancelling",
|
||||
"JOB_STATE_CANCELLED": "cancelled",
|
||||
"JOB_STATE_PAUSED": "in_progress",
|
||||
"JOB_STATE_EXPIRED": "expired",
|
||||
"JOB_STATE_UPDATING": "in_progress",
|
||||
"JOB_STATE_PARTIALLY_SUCCEEDED": "completed",
|
||||
}
|
||||
|
||||
vertex_state = response.get("state", "JOB_STATE_UNSPECIFIED")
|
||||
return state_mapping[vertex_state]
|
||||
|
||||
@classmethod
|
||||
def _get_gcs_uri_prefix_from_file(cls, input_file_id: str) -> str:
|
||||
"""
|
||||
Gets the gcs uri prefix from the input file id
|
||||
|
||||
Example:
|
||||
input_file_id: "gs://litellm-testing-bucket/vtx_batch.jsonl"
|
||||
returns: "gs://litellm-testing-bucket"
|
||||
|
||||
input_file_id: "gs://litellm-testing-bucket/batches/vtx_batch.jsonl"
|
||||
returns: "gs://litellm-testing-bucket/batches"
|
||||
"""
|
||||
# Split the path and remove the filename
|
||||
path_parts = input_file_id.rsplit("/", 1)
|
||||
return path_parts[0]
|
||||
|
||||
@classmethod
|
||||
def _get_model_from_gcs_file(cls, gcs_file_uri: str) -> str:
|
||||
"""
|
||||
Extracts the model from the gcs file uri
|
||||
|
||||
When files are uploaded using LiteLLM (/v1/files), the model is stored in the gcs file uri
|
||||
|
||||
Why?
|
||||
- Because Vertex Requires the `model` param in create batch jobs request, but OpenAI does not require this
|
||||
|
||||
|
||||
gcs_file_uri format: gs://litellm-testing-bucket/litellm-vertex-files/publishers/google/models/gemini-1.5-flash-001/e9412502-2c91-42a6-8e61-f5c294cc0fc8
|
||||
returns: "publishers/google/models/gemini-1.5-flash-001"
|
||||
"""
|
||||
from urllib.parse import unquote
|
||||
|
||||
decoded_uri = unquote(gcs_file_uri)
|
||||
|
||||
model_path = decoded_uri.split("publishers/")[1]
|
||||
parts = model_path.split("/")
|
||||
model = f"publishers/{'/'.join(parts[:3])}"
|
||||
return model
|
|
@ -264,3 +264,18 @@ def strip_field(schema, field_name: str):
|
|||
items = schema.get("items", None)
|
||||
if items is not None:
|
||||
strip_field(items, field_name)
|
||||
|
||||
|
||||
def _convert_vertex_datetime_to_openai_datetime(vertex_datetime: str) -> int:
|
||||
"""
|
||||
Converts a Vertex AI datetime string to an OpenAI datetime integer
|
||||
|
||||
vertex_datetime: str = "2024-12-04T21:53:12.120184Z"
|
||||
returns: int = 1722729192
|
||||
"""
|
||||
from datetime import datetime
|
||||
|
||||
# Parse the ISO format string to datetime object
|
||||
dt = datetime.strptime(vertex_datetime, "%Y-%m-%dT%H:%M:%S.%fZ")
|
||||
# Convert to Unix timestamp (seconds since epoch)
|
||||
return int(dt.timestamp())
|
||||
|
|
litellm/llms/vertex_ai_and_google_ai_studio/files/handler.py (new file, 111 lines)
@@ -0,0 +1,111 @@
import json
import uuid
from typing import Any, Coroutine, Dict, Optional, Union

import httpx

import litellm
from litellm.integrations.gcs_bucket.gcs_bucket_base import (
    GCSBucketBase,
    GCSLoggingConfig,
)
from litellm.llms.custom_httpx.http_handler import (
    AsyncHTTPHandler,
    HTTPHandler,
    _get_httpx_client,
    get_async_httpx_client,
)
from litellm.llms.vertex_ai_and_google_ai_studio.common_utils import (
    _convert_vertex_datetime_to_openai_datetime,
)
from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
    VertexAIError,
    VertexLLM,
)
from litellm.types.llms.openai import (
    Batch,
    CreateFileRequest,
    FileContentRequest,
    FileObject,
    FileTypes,
    HttpxBinaryResponseContent,
)

from .transformation import VertexAIFilesTransformation

vertex_ai_files_transformation = VertexAIFilesTransformation()


class VertexAIFilesHandler(GCSBucketBase):
    """
    Handles calling Vertex AI in the OpenAI Files API format (v1/files/*)

    This implementation uploads files to GCS buckets
    """

    pass

    async def async_create_file(
        self,
        create_file_data: CreateFileRequest,
        api_base: Optional[str],
        vertex_credentials: Optional[str],
        vertex_project: Optional[str],
        vertex_location: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
    ):
        gcs_logging_config: GCSLoggingConfig = await self.get_gcs_logging_config(
            kwargs={}
        )
        headers = await self.construct_request_headers(
            vertex_instance=gcs_logging_config["vertex_instance"],
            service_account_json=gcs_logging_config["path_service_account"],
        )
        bucket_name = gcs_logging_config["bucket_name"]
        logging_payload, object_name = (
            vertex_ai_files_transformation.transform_openai_file_content_to_vertex_ai_file_content(
                openai_file_content=create_file_data.get("file")
            )
        )
        gcs_upload_response = await self._log_json_data_on_gcs(
            headers=headers,
            bucket_name=bucket_name,
            object_name=object_name,
            logging_payload=logging_payload,
        )

        return vertex_ai_files_transformation.transform_gcs_bucket_response_to_openai_file_object(
            create_file_data=create_file_data,
            gcs_upload_response=gcs_upload_response,
        )

    def create_file(
        self,
        _is_async: bool,
        create_file_data: CreateFileRequest,
        api_base: Optional[str],
        vertex_credentials: Optional[str],
        vertex_project: Optional[str],
        vertex_location: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
    ) -> Union[FileObject, Coroutine[Any, Any, FileObject]]:
        """
        Creates a file on a Vertex AI GCS bucket

        Only supported for the async litellm.acreate_file
        """

        if _is_async:
            return self.async_create_file(
                create_file_data=create_file_data,
                api_base=api_base,
                vertex_credentials=vertex_credentials,
                vertex_project=vertex_project,
                vertex_location=vertex_location,
                timeout=timeout,
                max_retries=max_retries,
            )

        return None  # type: ignore
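In practice this handler is reached via `litellm.acreate_file(..., custom_llm_provider="vertex_ai")` (the sync path returns `None`, as `create_file` notes): it uploads the JSONL to the configured GCS bucket and returns an OpenAI-style `FileObject` whose `id` is a `gs://` URI. A minimal sketch, assuming GCS bucket credentials are already configured the same way as for GCS logging:

```python
import asyncio
import litellm

async def upload_batch_input():
    # Uploads the local JSONL to GCS and returns an OpenAI-style FileObject
    file_obj = await litellm.acreate_file(
        file=open("vertex_batch_completions.jsonl", "rb"),
        purpose="batch",
        custom_llm_provider="vertex_ai",
    )
    # id looks like: gs://<bucket>/litellm-vertex-files/publishers/google/models/<model>/<uuid>
    print(file_obj.id)

asyncio.run(upload_batch_input())
```

The new-file hunk that follows is the `.transformation` module imported above.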
@@ -0,0 +1,173 @@
import json
import uuid
from typing import Any, Dict, List, Optional, Tuple, Union

from litellm.llms.vertex_ai_and_google_ai_studio.common_utils import (
    _convert_vertex_datetime_to_openai_datetime,
)
from litellm.llms.vertex_ai_and_google_ai_studio.gemini.transformation import (
    _transform_request_body,
)
from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
    VertexGeminiConfig,
)
from litellm.types.llms.openai import (
    Batch,
    CreateFileRequest,
    FileContentRequest,
    FileObject,
    FileTypes,
    HttpxBinaryResponseContent,
    PathLike,
)


class VertexAIFilesTransformation(VertexGeminiConfig):
    """
    Transforms OpenAI /v1/files/* requests to Vertex AI /v1/files/* requests
    """

    def transform_openai_file_content_to_vertex_ai_file_content(
        self, openai_file_content: Optional[FileTypes] = None
    ) -> Tuple[str, str]:
        """
        Transforms OpenAI FileContentRequest to VertexAI FileContentRequest
        """

        if openai_file_content is None:
            raise ValueError("contents of file are None")
        # Read the content of the file
        file_content = self._get_content_from_openai_file(openai_file_content)

        # Split into lines and parse each line as JSON
        openai_jsonl_content = [
            json.loads(line) for line in file_content.splitlines() if line.strip()
        ]
        vertex_jsonl_content = (
            self._transform_openai_jsonl_content_to_vertex_ai_jsonl_content(
                openai_jsonl_content
            )
        )
        vertex_jsonl_string = "\n".join(
            json.dumps(item) for item in vertex_jsonl_content
        )
        object_name = self._get_gcs_object_name(
            openai_jsonl_content=openai_jsonl_content
        )
        return vertex_jsonl_string, object_name

    def _transform_openai_jsonl_content_to_vertex_ai_jsonl_content(
        self, openai_jsonl_content: List[Dict[str, Any]]
    ):
        """
        Transforms OpenAI JSONL content to VertexAI JSONL content

        jsonl body for vertex is {"request": <request_body>}
        Example Vertex jsonl
        {"request":{"contents": [{"role": "user", "parts": [{"text": "What is the relation between the following video and image samples?"}, {"fileData": {"fileUri": "gs://cloud-samples-data/generative-ai/video/animals.mp4", "mimeType": "video/mp4"}}, {"fileData": {"fileUri": "gs://cloud-samples-data/generative-ai/image/cricket.jpeg", "mimeType": "image/jpeg"}}]}]}}
        {"request":{"contents": [{"role": "user", "parts": [{"text": "Describe what is happening in this video."}, {"fileData": {"fileUri": "gs://cloud-samples-data/generative-ai/video/another_video.mov", "mimeType": "video/mov"}}]}]}}
        """

        vertex_jsonl_content = []
        for _openai_jsonl_content in openai_jsonl_content:
            openai_request_body = _openai_jsonl_content.get("body") or {}
            vertex_request_body = _transform_request_body(
                messages=openai_request_body.get("messages", []),
                model=openai_request_body.get("model", ""),
                optional_params=self._map_openai_to_vertex_params(openai_request_body),
                custom_llm_provider="vertex_ai",
                litellm_params={},
                cached_content=None,
            )
            vertex_jsonl_content.append({"request": vertex_request_body})
        return vertex_jsonl_content

    def _get_gcs_object_name(
        self,
        openai_jsonl_content: List[Dict[str, Any]],
    ) -> str:
        """
        Gets a unique GCS object name for the VertexAI batch prediction job

        named as: litellm-vertex-files/{model}/{uuid}
        """
        _model = openai_jsonl_content[0].get("body", {}).get("model", "")
        if "publishers/google/models" not in _model:
            _model = f"publishers/google/models/{_model}"
        object_name = f"litellm-vertex-files/{_model}/{uuid.uuid4()}"
        return object_name

    def _map_openai_to_vertex_params(
        self,
        openai_request_body: Dict[str, Any],
    ) -> Dict[str, Any]:
        """
        wrapper to call VertexGeminiConfig.map_openai_params
        """
        _model = openai_request_body.get("model", "")
        vertex_params = self.map_openai_params(
            model=_model,
            non_default_params=openai_request_body,
            optional_params={},
            drop_params=False,
        )
        return vertex_params

    def _get_content_from_openai_file(self, openai_file_content: FileTypes) -> str:
        """
        Helper to extract content from various OpenAI file types and return as string.

        Handles:
        - Direct content (str, bytes, IO[bytes])
        - Tuple formats: (filename, content, [content_type], [headers])
        - PathLike objects
        """
        content: Union[str, bytes] = b""
        # Extract file content from tuple if necessary
        if isinstance(openai_file_content, tuple):
            # Take the second element which is always the file content
            file_content = openai_file_content[1]
        else:
            file_content = openai_file_content

        # Handle different file content types
        if isinstance(file_content, str):
            # String content can be used directly
            content = file_content
        elif isinstance(file_content, bytes):
            # Bytes content can be decoded
            content = file_content
        elif isinstance(file_content, PathLike):  # PathLike
            with open(str(file_content), "rb") as f:
                content = f.read()
        elif hasattr(file_content, "read"):  # IO[bytes]
            # File-like objects need to be read
            content = file_content.read()

        # Ensure content is string
        if isinstance(content, bytes):
            content = content.decode("utf-8")

        return content

    def transform_gcs_bucket_response_to_openai_file_object(
        self, create_file_data: CreateFileRequest, gcs_upload_response: Dict[str, Any]
    ) -> FileObject:
        """
        Transforms GCS Bucket upload file response to OpenAI FileObject
        """
        gcs_id = gcs_upload_response.get("id", "")
        # Remove the last numeric ID from the path
        gcs_id = "/".join(gcs_id.split("/")[:-1]) if gcs_id else ""

        return FileObject(
            purpose=create_file_data.get("purpose", "batch"),
            id=f"gs://{gcs_id}",
            filename=gcs_upload_response.get("name", ""),
            created_at=_convert_vertex_datetime_to_openai_datetime(
                vertex_datetime=gcs_upload_response.get("timeCreated", "")
            ),
            status="uploaded",
            bytes=gcs_upload_response.get("size", 0),
            object="file",
        )
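The transformation above reduces to: parse the OpenAI batch JSONL, convert each line's `body` into a Gemini request via `_transform_request_body`, and wrap it as `{"request": ...}`. A schematic illustration of the input and output line shapes (the exact Vertex body depends on `_transform_request_body`, so the output below is an approximation, not the literal result):

```python
import json

openai_line = {
    "custom_id": "request-1",
    "method": "POST",
    "url": "/v1/chat/completions",
    "body": {
        "model": "gemini-1.5-flash-001",
        "messages": [{"role": "user", "content": "Hello world!"}],
        "max_tokens": 10,
    },
}

# Approximate Vertex JSONL line produced for the request above
vertex_line = {
    "request": {
        "contents": [{"role": "user", "parts": [{"text": "Hello world!"}]}],
        "generationConfig": {"maxOutputTokens": 10},
    }
}
print(json.dumps(vertex_line))
```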
@@ -59,6 +59,8 @@ def get_files_provider_config(
    custom_llm_provider: str,
):
    global files_config
    if custom_llm_provider == "vertex_ai":
        return None
    if files_config is None:
        raise ValueError("files_config is not set, set it on your config.yaml file.")
    for setting in files_config:
@@ -212,9 +214,9 @@ async def create_file(
        if llm_provider_config is not None:
            # add llm_provider_config to data
            _create_file_request.update(llm_provider_config)

        _create_file_request.pop("custom_llm_provider", None)  # type: ignore
        # for now use custom_llm_provider=="openai" -> this will change as LiteLLM adds more providers for acreate_batch
        response = await litellm.acreate_file(**_create_file_request)  # type: ignore
        response = await litellm.acreate_file(**_create_file_request, custom_llm_provider=custom_llm_provider)  # type: ignore

        ### ALERTING ###
        asyncio.create_task(
@@ -239,7 +241,6 @@ async def create_file(
                model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
            )
        )

        return response
    except Exception as e:
        await proxy_logging_obj.post_call_failure_hook(
@@ -5225,6 +5225,7 @@ async def create_batch(
        is_router_model = is_known_model(model=router_model, llm_router=llm_router)

        _create_batch_data = CreateBatchRequest(**data)
        custom_llm_provider = provider or _create_batch_data.pop("custom_llm_provider", None)  # type: ignore

        if (
            litellm.enable_loadbalancing_on_batch_endpoints is True
@@ -5241,10 +5242,10 @@ async def create_batch(

            response = await llm_router.acreate_batch(**_create_batch_data)  # type: ignore
        else:
            if provider is None:
                provider = "openai"
            if custom_llm_provider is None:
                custom_llm_provider = "openai"
            response = await litellm.acreate_batch(
                custom_llm_provider=provider, **_create_batch_data  # type: ignore
                custom_llm_provider=custom_llm_provider, **_create_batch_data  # type: ignore
            )

        ### ALERTING ###
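The net effect of the proxy changes is that `custom_llm_provider` becomes the routing key for `/v1/batches`: the URL-derived `provider` wins, otherwise the field is popped from the request body, and it finally falls back to `"openai"`. A hedged sketch of that resolution order (`resolve_batch_provider` is a hypothetical helper name, not code from this PR):

```python
from typing import Any, Dict, Optional

def resolve_batch_provider(provider: Optional[str], request_body: Dict[str, Any]) -> str:
    # Hypothetical helper mirroring the create_batch logic above
    custom_llm_provider = provider or request_body.pop("custom_llm_provider", None)
    return custom_llm_provider or "openai"

assert resolve_batch_provider(None, {"custom_llm_provider": "vertex_ai"}) == "vertex_ai"
assert resolve_batch_provider("azure", {"custom_llm_provider": "vertex_ai"}) == "azure"
assert resolve_batch_provider(None, {}) == "openai"
```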
@@ -301,6 +301,18 @@ class ListBatchRequest(TypedDict, total=False):
    timeout: Optional[float]


BatchJobStatus = Literal[
    "validating",
    "failed",
    "in_progress",
    "finalizing",
    "completed",
    "expired",
    "cancelling",
    "cancelled",
]


class ChatCompletionAudioDelta(TypedDict, total=False):
    data: str
    transcript: str
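These Literal values are exactly the targets of the Vertex `JOB_STATE_*` mapping earlier in this diff, so downstream code can use the alias for typed lookups. A small sketch (assuming the alias is importable from `litellm.types.llms.openai`, which is where the surrounding types live):

```python
from typing import Dict

from litellm.types.llms.openai import BatchJobStatus

state_mapping: Dict[str, BatchJobStatus] = {
    "JOB_STATE_RUNNING": "in_progress",
    "JOB_STATE_SUCCEEDED": "completed",
}
status: BatchJobStatus = state_mapping.get("JOB_STATE_RUNNING", "failed")
print(status)  # in_progress
```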
@@ -434,3 +434,43 @@ class VertexAIBatchEmbeddingsRequestBody(TypedDict, total=False):

class VertexAIBatchEmbeddingsResponseObject(TypedDict):
    embeddings: List[ContentEmbeddings]


# Vertex AI Batch Prediction


class GcsSource(TypedDict):
    uris: str


class InputConfig(TypedDict):
    instancesFormat: str
    gcsSource: GcsSource


class GcsDestination(TypedDict):
    outputUriPrefix: str


class OutputConfig(TypedDict, total=False):
    predictionsFormat: str
    gcsDestination: GcsDestination


class VertexAIBatchPredictionJob(TypedDict):
    displayName: str
    model: str
    inputConfig: InputConfig
    outputConfig: OutputConfig


class VertexBatchPredictionResponse(TypedDict, total=False):
    name: str
    displayName: str
    model: str
    inputConfig: InputConfig
    outputConfig: OutputConfig
    state: str
    createTime: str
    updateTime: str
    modelVersionId: str
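Put together, the body LiteLLM sends to Vertex's `batchPredictionJobs` endpoint should match `VertexAIBatchPredictionJob` above. A sketch of such a payload with placeholder bucket and model values (the `jsonl` format strings follow the Vertex batch prediction docs; treat the rest as illustrative):

```python
batch_job = {
    "displayName": "litellm-vertex-batch",  # any human-readable job name
    "model": "publishers/google/models/gemini-1.5-flash-001",
    "inputConfig": {
        "instancesFormat": "jsonl",
        "gcsSource": {
            "uris": "gs://my-bucket/litellm-vertex-files/publishers/google/models/gemini-1.5-flash-001/input.jsonl"
        },
    },
    "outputConfig": {
        "predictionsFormat": "jsonl",
        "gcsDestination": {"outputUriPrefix": "gs://my-bucket/batch-output"},
    },
}
```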
@@ -11,8 +11,7 @@ from dotenv import load_dotenv
load_dotenv()
sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import asyncio
)  # Adds the parent directory to the system-path
import logging
import time

@@ -20,6 +19,10 @@ import pytest

import litellm
from litellm import create_batch, create_file
from litellm._logging import verbose_logger
from test_gcs_bucket import load_vertex_ai_credentials

verbose_logger.setLevel(logging.DEBUG)


@pytest.mark.parametrize("provider", ["openai"])  # , "azure"
@@ -206,3 +209,32 @@ def test_cancel_batch():

def test_list_batch():
    pass


@pytest.mark.asyncio
async def test_vertex_batch_prediction():
    load_vertex_ai_credentials()
    file_name = "vertex_batch_completions.jsonl"
    _current_dir = os.path.dirname(os.path.abspath(__file__))
    file_path = os.path.join(_current_dir, file_name)
    file_obj = await litellm.acreate_file(
        file=open(file_path, "rb"),
        purpose="batch",
        custom_llm_provider="vertex_ai",
    )
    print("Response from creating file=", file_obj)

    batch_input_file_id = file_obj.id
    assert (
        batch_input_file_id is not None
    ), f"Failed to create file, expected a non null file_id but got {batch_input_file_id}"

    create_batch_response = await litellm.acreate_batch(
        completion_window="24h",
        endpoint="/v1/chat/completions",
        input_file_id=batch_input_file_id,
        custom_llm_provider="vertex_ai",
        metadata={"key1": "value1", "key2": "value2"},
    )
    print("create_batch_response=", create_batch_response)
    pass
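The test stops after the create call; a natural follow-up would be to poll the job until Vertex reaches a terminal state. This is a hedged sketch only: it assumes `litellm.aretrieve_batch` accepts `custom_llm_provider="vertex_ai"` in the same way the rest of this PR does.

```python
import asyncio
import time

import litellm

async def wait_for_vertex_batch(batch_id: str, timeout_s: int = 600):
    # Poll until one of the terminal statuses from the JOB_STATE_* mapping is reached.
    deadline = time.monotonic() + timeout_s
    while time.monotonic() < deadline:
        batch = await litellm.aretrieve_batch(
            batch_id=batch_id, custom_llm_provider="vertex_ai"
        )
        if batch.status in ("completed", "failed", "cancelled", "expired"):
            return batch
        await asyncio.sleep(10)
    raise TimeoutError(f"batch {batch_id} did not finish within {timeout_s}s")
```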
tests/local_testing/vertex_batch_completions.jsonl (new file, 2 lines)
@@ -0,0 +1,2 @@
{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gemini-1.5-flash-001", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 10}}
{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gemini-1.5-flash-001", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 10}}
@@ -85,3 +85,37 @@ async def test_batches_operations():

    # Test delete file
    await delete_file(session, file_id)


@pytest.mark.skip(reason="Local only test to verify if things work well")
def test_vertex_batches_endpoint():
    """
    Test VertexAI Batches Endpoint
    """
    import os

    oai_client = OpenAI(api_key=API_KEY, base_url=BASE_URL)
    file_name = "local_testing/vertex_batch_completions.jsonl"
    _current_dir = os.path.dirname(os.path.abspath(__file__))
    file_path = os.path.join(_current_dir, file_name)
    file_obj = oai_client.files.create(
        file=open(file_path, "rb"),
        purpose="batch",
        extra_body={"custom_llm_provider": "vertex_ai"},
    )
    print("Response from creating file=", file_obj)

    batch_input_file_id = file_obj.id
    assert (
        batch_input_file_id is not None
    ), f"Failed to create file, expected a non null file_id but got {batch_input_file_id}"

    create_batch_response = oai_client.batches.create(
        completion_window="24h",
        endpoint="/v1/chat/completions",
        input_file_id=batch_input_file_id,
        extra_body={"custom_llm_provider": "vertex_ai"},
        metadata={"key1": "value1", "key2": "value2"},
    )
    print("response from create batch", create_batch_response)
    pass