(feat) /batches Add support for using /batches endpoints in OAI format (#7402)

* run azure testing on ci/cd

* update docs on azure batches endpoints

* add input azure.jsonl

* refactor - use separate file for batches endpoints

* fixes for passing custom llm provider to /batch endpoints

* pass custom llm provider to files endpoints

* update azure batches doc

* add info for azure batches api

* update batches endpoints

* use simple helper for raising proxy exception

* update config.yml

* fix imports

* update tests

* use existing settings

* update env var used

* update configs

* update config.yml

* update ft testing
Ishaan Jaff 2024-12-24 16:58:05 -08:00 committed by GitHub
parent fe43403359
commit 47e12802df
17 changed files with 718 additions and 464 deletions

View file

@ -1111,7 +1111,9 @@ jobs:
docker run -d \
-p 4000:4000 \
-e DATABASE_URL=$PROXY_DATABASE_URL \
-e AZURE_API_KEY=$AZURE_API_KEY \
-e AZURE_API_KEY=$AZURE_BATCHES_API_KEY \
-e AZURE_API_BASE=$AZURE_BATCHES_API_BASE \
-e AZURE_API_VERSION="2024-05-01-preview" \
-e REDIS_HOST=$REDIS_HOST \
-e REDIS_PASSWORD=$REDIS_PASSWORD \
-e REDIS_PORT=$REDIS_PORT \

View file

@ -559,7 +559,18 @@ litellm_settings:
</Tabs>
## **Azure Batches API**
| Property | Details |
|-------|-------|
| Description | Azure OpenAI Batches API |
| `custom_llm_provider` on LiteLLM | `azure/` |
| Supported Operations | `/v1/batches`, `/v1/files` |
| Azure OpenAI Batches API | [Azure OpenAI Batches API ↗](https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/batch) |
| Cost Tracking, Logging Support | ✅ LiteLLM logs and tracks cost for Batch API requests |
### Quick Start
Just add the Azure env vars to your environment.
@ -568,40 +579,71 @@ export AZURE_API_KEY=""
export AZURE_API_BASE=""
```
AND use `/azure/*` for the Batches API calls
<Tabs>
<TabItem value="proxy" label="LiteLLM PROXY Server">
```bash
http://0.0.0.0:4000/azure/v1/batches
```
**1. Upload a File**
<Tabs>
<TabItem value="sdk" label="OpenAI Python SDK">
```python
from openai import OpenAI
# Initialize the client
client = OpenAI(
base_url="http://localhost:4000",
api_key="your-api-key"
)
batch_input_file = client.files.create(
file=open("mydata.jsonl", "rb"),
purpose="batch",
extra_body={"custom_llm_provider": "azure"}
)
file_id = batch_input_file.id
```
### Usage
**Setup**
- Add Azure API Keys to your environment
#### 1. Upload a File
</TabItem>
<TabItem value="curl" label="Curl">
```bash
curl http://localhost:4000/azure/v1/files \
curl http://localhost:4000/v1/files \
-H "Authorization: Bearer sk-1234" \
-F purpose="batch" \
-F file="@mydata.jsonl"
```
**Example File**
Note: `model` should be your azure deployment name.
</TabItem>
</Tabs>
**Example File Format**
```json
{"custom_id": "task-0", "method": "POST", "url": "/chat/completions", "body": {"model": "REPLACE-WITH-MODEL-DEPLOYMENT-NAME", "messages": [{"role": "system", "content": "You are an AI assistant that helps people find information."}, {"role": "user", "content": "When was Microsoft founded?"}]}}
{"custom_id": "task-1", "method": "POST", "url": "/chat/completions", "body": {"model": "REPLACE-WITH-MODEL-DEPLOYMENT-NAME", "messages": [{"role": "system", "content": "You are an AI assistant that helps people find information."}, {"role": "user", "content": "When was the first XBOX released?"}]}}
{"custom_id": "task-2", "method": "POST", "url": "/chat/completions", "body": {"model": "REPLACE-WITH-MODEL-DEPLOYMENT-NAME", "messages": [{"role": "system", "content": "You are an AI assistant that helps people find information."}, {"role": "user", "content": "What is Altair Basic?"}]}}
```
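As an illustrative aside (not part of the original docs), here is a minimal Python sketch that generates a file in this format; `mydata.jsonl` matches the upload examples above, and the deployment name is the same placeholder you must replace:

```python
import json

# Placeholder - replace with your Azure OpenAI deployment name
deployment_name = "REPLACE-WITH-MODEL-DEPLOYMENT-NAME"

questions = [
    "When was Microsoft founded?",
    "When was the first XBOX released?",
    "What is Altair Basic?",
]

# Write one request object per line (JSONL), matching the example file format above
with open("mydata.jsonl", "w") as f:
    for i, question in enumerate(questions):
        task = {
            "custom_id": f"task-{i}",
            "method": "POST",
            "url": "/chat/completions",
            "body": {
                "model": deployment_name,
                "messages": [
                    {"role": "system", "content": "You are an AI assistant that helps people find information."},
                    {"role": "user", "content": question},
                ],
            },
        }
        f.write(json.dumps(task) + "\n")
```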
#### 2. Create a batch
**2. Create a Batch Request**
<Tabs>
<TabItem value="sdk" label="OpenAI Python SDK">
```python
batch = client.batches.create( # re use client from above
input_file_id=file_id,
endpoint="/v1/chat/completions",
completion_window="24h",
metadata={"description": "My batch job"},
extra_body={"custom_llm_provider": "azure"}
)
```
</TabItem>
<TabItem value="curl" label="Curl">
```bash
curl http://0.0.0.0:4000/azure/v1/batches \
curl http://localhost:4000/v1/batches \
-H "Authorization: Bearer $LITELLM_API_KEY" \
-H "Content-Type: application/json" \
-d '{
@ -609,34 +651,144 @@ curl http://0.0.0.0:4000/azure/v1/batches \
"endpoint": "/v1/chat/completions",
"completion_window": "24h"
}'
```
</TabItem>
</Tabs>
**3. Retrieve a Batch**
<Tabs>
<TabItem value="sdk" label="OpenAI Python SDK">
```python
retrieved_batch = client.batches.retrieve(
batch.id,
extra_body={"custom_llm_provider": "azure"}
)
```
#### 3. Retrieve batch
</TabItem>
<TabItem value="curl" label="Curl">
```bash
curl http://0.0.0.0:4000/azure/v1/batches/batch_abc123 \
curl http://localhost:4000/v1/batches/batch_abc123 \
-H "Authorization: Bearer $LITELLM_API_KEY" \
-H "Content-Type: application/json" \
```
#### 4. Cancel batch
</TabItem>
</Tabs>
**4. Cancel a Batch**
<Tabs>
<TabItem value="sdk" label="OpenAI Python SDK">
```python
cancelled_batch = client.batches.cancel(
batch.id,
extra_body={"custom_llm_provider": "azure"}
)
```
</TabItem>
<TabItem value="curl" label="Curl">
```bash
curl http://0.0.0.0:4000/azure/v1/batches/batch_abc123/cancel \
curl http://localhost:4000/v1/batches/batch_abc123/cancel \
-H "Authorization: Bearer $LITELLM_API_KEY" \
-H "Content-Type: application/json" \
-X POST
```
#### 5. List Batch
</TabItem>
</Tabs>
**5. List Batches**
<Tabs>
<TabItem value="sdk" label="OpenAI Python SDK">
```python
client.batches.list(extra_body={"custom_llm_provider": "azure"})
```
</TabItem>
<TabItem value="curl" label="Curl">
```bash
curl http://0.0.0.0:4000/v1/batches?limit=2 \
curl http://localhost:4000/v1/batches?limit=2 \
-H "Authorization: Bearer $LITELLM_API_KEY" \
-H "Content-Type: application/json"
```
</TabItem>
</Tabs>
</TabItem>
<TabItem value="sdk" label="LiteLLM SDK">
**1. Create File for Batch Completion**
```python
import litellm
import os
os.environ["AZURE_API_KEY"] = ""
os.environ["AZURE_API_BASE"] = ""
file_name = "azure_batch_completions.jsonl"
_current_dir = os.path.dirname(os.path.abspath(__file__))
file_path = os.path.join(_current_dir, file_name)
file_obj = await litellm.acreate_file(
file=open(file_path, "rb"),
purpose="batch",
custom_llm_provider="azure",
)
batch_input_file_id = file_obj.id
print("Response from creating file=", file_obj)
```
**2. Create Batch Request**
```python
create_batch_response = await litellm.acreate_batch(
completion_window="24h",
endpoint="/v1/chat/completions",
input_file_id=batch_input_file_id,
custom_llm_provider="azure",
metadata={"key1": "value1", "key2": "value2"},
)
print("response from litellm.create_batch=", create_batch_response)
```
**3. Retrieve Batch and File Content**
```python
retrieved_batch = await litellm.aretrieve_batch(
batch_id=create_batch_response.id,
custom_llm_provider="azure"
)
print("retrieved batch=", retrieved_batch)
# Get file content
file_content = await litellm.afile_content(
file_id=batch_input_file_id,
custom_llm_provider="azure"
)
print("file content = ", file_content)
```
**4. List Batches**
```python
list_batches_response = litellm.list_batches(
custom_llm_provider="azure",
limit=2
)
print("list_batches_response=", list_batches_response)
```
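If you also want to wait for the batch to finish and save its output, a rough sketch (an illustrative example, not from the original docs) using the same `litellm` helpers shown above could look like this; the 30-second poll interval and output path are assumptions:

```python
import asyncio
import litellm

async def wait_and_download_output(batch_id: str, output_path: str):
    # Poll until Azure marks the batch completed (Azure batches can take a while)
    while True:
        batch = await litellm.aretrieve_batch(
            batch_id=batch_id, custom_llm_provider="azure"
        )
        if batch.status == "completed":
            break
        if batch.status in ("failed", "expired", "cancelled"):
            raise RuntimeError(f"Batch {batch_id} ended with status {batch.status}")
        await asyncio.sleep(30)

    # Download the output file produced by the batch and write it to disk
    output = await litellm.afile_content(
        file_id=batch.output_file_id, custom_llm_provider="azure"
    )
    with open(output_path, "wb") as f:
        f.write(output.content)
```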
</TabItem>
</Tabs>
### [Health Check Azure Batch models](./proxy/health.md#batch-models-azure-only)

View file

@ -234,7 +234,7 @@ def create_batch(
)
else:
raise litellm.exceptions.BadRequestError(
message="LiteLLM doesn't support {} for 'create_batch'. Only 'openai' is supported.".format(
message="LiteLLM doesn't support custom_llm_provider={} for 'create_batch'".format(
custom_llm_provider
),
model="n/a",

View file

@ -0,0 +1,360 @@
######################################################################
# /v1/batches Endpoints
######################################################################
import asyncio
import json
from typing import Dict, Optional

from fastapi import APIRouter, Depends, HTTPException, Path, Request, Response
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.batches.main import CreateBatchRequest, RetrieveBatchRequest
from litellm.proxy._types import *
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
from litellm.proxy.common_utils.openai_endpoint_utils import (
get_custom_llm_provider_from_request_body,
)
from litellm.proxy.openai_files_endpoints.files_endpoints import is_known_model
from litellm.proxy.utils import handle_exception_on_proxy
router = APIRouter()
@router.post(
"/{provider}/v1/batches",
dependencies=[Depends(user_api_key_auth)],
tags=["batch"],
)
@router.post(
"/v1/batches",
dependencies=[Depends(user_api_key_auth)],
tags=["batch"],
)
@router.post(
"/batches",
dependencies=[Depends(user_api_key_auth)],
tags=["batch"],
)
async def create_batch(
request: Request,
fastapi_response: Response,
provider: Optional[str] = None,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
Create large batches of API requests for asynchronous processing.
This is the equivalent of POST https://api.openai.com/v1/batches
Supports Identical Params as: https://platform.openai.com/docs/api-reference/batch
Example Curl
```
curl http://localhost:4000/v1/batches \
-H "Authorization: Bearer sk-1234" \
-H "Content-Type: application/json" \
-d '{
"input_file_id": "file-abc123",
"endpoint": "/v1/chat/completions",
"completion_window": "24h"
}'
```
"""
from litellm.proxy.proxy_server import (
add_litellm_data_to_request,
general_settings,
get_custom_headers,
llm_router,
proxy_config,
proxy_logging_obj,
version,
)
data: Dict = {}
try:
data = await _read_request_body(request=request)
verbose_proxy_logger.debug(
"Request received by LiteLLM:\n{}".format(json.dumps(data, indent=4)),
)
# Include original request and headers in the data
data = await add_litellm_data_to_request(
data=data,
request=request,
general_settings=general_settings,
user_api_key_dict=user_api_key_dict,
version=version,
proxy_config=proxy_config,
)
## check if model is a loadbalanced model
router_model: Optional[str] = None
is_router_model = False
if litellm.enable_loadbalancing_on_batch_endpoints is True:
router_model = data.get("model", None)
is_router_model = is_known_model(model=router_model, llm_router=llm_router)
custom_llm_provider = (
provider or data.pop("custom_llm_provider", None) or "openai"
)
_create_batch_data = CreateBatchRequest(**data)
if (
litellm.enable_loadbalancing_on_batch_endpoints is True
and is_router_model
and router_model is not None
):
if llm_router is None:
raise HTTPException(
status_code=500,
detail={
"error": "LLM Router not initialized. Ensure models added to proxy."
},
)
response = await llm_router.acreate_batch(**_create_batch_data) # type: ignore
else:
response = await litellm.acreate_batch(
custom_llm_provider=custom_llm_provider, **_create_batch_data # type: ignore
)
### ALERTING ###
asyncio.create_task(
proxy_logging_obj.update_request_status(
litellm_call_id=data.get("litellm_call_id", ""), status="success"
)
)
### RESPONSE HEADERS ###
hidden_params = getattr(response, "_hidden_params", {}) or {}
model_id = hidden_params.get("model_id", None) or ""
cache_key = hidden_params.get("cache_key", None) or ""
api_base = hidden_params.get("api_base", None) or ""
fastapi_response.headers.update(
get_custom_headers(
user_api_key_dict=user_api_key_dict,
model_id=model_id,
cache_key=cache_key,
api_base=api_base,
version=version,
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
request_data=data,
)
)
return response
except Exception as e:
await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
)
verbose_proxy_logger.exception(
"litellm.proxy.proxy_server.create_batch(): Exception occured - {}".format(
str(e)
)
)
raise handle_exception_on_proxy(e)
@router.get(
"/{provider}/v1/batches/{batch_id:path}",
dependencies=[Depends(user_api_key_auth)],
tags=["batch"],
)
@router.get(
"/v1/batches/{batch_id:path}",
dependencies=[Depends(user_api_key_auth)],
tags=["batch"],
)
@router.get(
"/batches/{batch_id:path}",
dependencies=[Depends(user_api_key_auth)],
tags=["batch"],
)
async def retrieve_batch(
request: Request,
fastapi_response: Response,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
provider: Optional[str] = None,
batch_id: str = Path(
title="Batch ID to retrieve", description="The ID of the batch to retrieve"
),
):
"""
Retrieves a batch.
This is the equivalent of GET https://api.openai.com/v1/batches/{batch_id}
Supports Identical Params as: https://platform.openai.com/docs/api-reference/batch/retrieve
Example Curl
```
curl http://localhost:4000/v1/batches/batch_abc123 \
-H "Authorization: Bearer sk-1234" \
-H "Content-Type: application/json" \
```
"""
from litellm.proxy.proxy_server import (
get_custom_headers,
llm_router,
proxy_logging_obj,
version,
)
data: Dict = {}
try:
## check if model is a loadbalanced model
_retrieve_batch_request = RetrieveBatchRequest(
batch_id=batch_id,
)
if litellm.enable_loadbalancing_on_batch_endpoints is True:
if llm_router is None:
raise HTTPException(
status_code=500,
detail={
"error": "LLM Router not initialized. Ensure models added to proxy."
},
)
response = await llm_router.aretrieve_batch(**_retrieve_batch_request) # type: ignore
else:
custom_llm_provider = (
provider
or await get_custom_llm_provider_from_request_body(request=request)
or "openai"
)
response = await litellm.aretrieve_batch(
custom_llm_provider=custom_llm_provider, **_retrieve_batch_request # type: ignore
)
### ALERTING ###
asyncio.create_task(
proxy_logging_obj.update_request_status(
litellm_call_id=data.get("litellm_call_id", ""), status="success"
)
)
### RESPONSE HEADERS ###
hidden_params = getattr(response, "_hidden_params", {}) or {}
model_id = hidden_params.get("model_id", None) or ""
cache_key = hidden_params.get("cache_key", None) or ""
api_base = hidden_params.get("api_base", None) or ""
fastapi_response.headers.update(
get_custom_headers(
user_api_key_dict=user_api_key_dict,
model_id=model_id,
cache_key=cache_key,
api_base=api_base,
version=version,
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
request_data=data,
)
)
return response
except Exception as e:
await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
)
verbose_proxy_logger.exception(
"litellm.proxy.proxy_server.retrieve_batch(): Exception occured - {}".format(
str(e)
)
)
raise handle_exception_on_proxy(e)
@router.get(
"/{provider}/v1/batches",
dependencies=[Depends(user_api_key_auth)],
tags=["batch"],
)
@router.get(
"/v1/batches",
dependencies=[Depends(user_api_key_auth)],
tags=["batch"],
)
@router.get(
"/batches",
dependencies=[Depends(user_api_key_auth)],
tags=["batch"],
)
async def list_batches(
request: Request,
fastapi_response: Response,
provider: Optional[str] = None,
limit: Optional[int] = None,
after: Optional[str] = None,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
Lists batches.
This is the equivalent of GET https://api.openai.com/v1/batches
Supports Identical Params as: https://platform.openai.com/docs/api-reference/batch/list
Example Curl
```
curl http://localhost:4000/v1/batches?limit=2 \
-H "Authorization: Bearer sk-1234" \
-H "Content-Type: application/json" \
```
"""
from litellm.proxy.proxy_server import (
get_custom_headers,
proxy_logging_obj,
version,
)
verbose_proxy_logger.debug("GET /v1/batches after={} limit={}".format(after, limit))
try:
custom_llm_provider = (
provider
or await get_custom_llm_provider_from_request_body(request=request)
or "openai"
)
response = await litellm.alist_batches(
custom_llm_provider=custom_llm_provider, # type: ignore
after=after,
limit=limit,
)
### RESPONSE HEADERS ###
hidden_params = getattr(response, "_hidden_params", {}) or {}
model_id = hidden_params.get("model_id", None) or ""
cache_key = hidden_params.get("cache_key", None) or ""
api_base = hidden_params.get("api_base", None) or ""
fastapi_response.headers.update(
get_custom_headers(
user_api_key_dict=user_api_key_dict,
model_id=model_id,
cache_key=cache_key,
api_base=api_base,
version=version,
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
)
)
return response
except Exception as e:
await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict,
original_exception=e,
request_data={"after": after, "limit": limit},
)
verbose_proxy_logger.error(
"litellm.proxy.proxy_server.retrieve_batch(): Exception occured - {}".format(
str(e)
)
)
raise handle_exception_on_proxy(e)
######################################################################
# END OF /v1/batches Endpoints Implementation
######################################################################
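For completeness, the `/{provider}/v1/batches` routes above let callers select the provider via the URL path instead of the request body. A minimal sketch against a locally running proxy follows; the base URL, `sk-1234` key, and `file-abc123` id mirror the docstring examples and are assumptions about your deployment:

```python
import requests

BASE_URL = "http://localhost:4000"          # assumed local proxy address
HEADERS = {
    "Authorization": "Bearer sk-1234",      # assumed virtual key
    "Content-Type": "application/json",
}

# The provider comes from the URL path ("azure"), so no custom_llm_provider
# field is required in the request body.
resp = requests.post(
    f"{BASE_URL}/azure/v1/batches",
    headers=HEADERS,
    json={
        "input_file_id": "file-abc123",
        "endpoint": "/v1/chat/completions",
        "completion_window": "24h",
    },
)
print(resp.status_code, resp.json())
```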

View file

@ -2,6 +2,12 @@
Contains utils used by OpenAI compatible endpoints
"""
from typing import Optional
from fastapi import Request
from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
def remove_sensitive_info_from_deployment(deployment_dict: dict) -> dict:
"""
@ -19,3 +25,15 @@ def remove_sensitive_info_from_deployment(deployment_dict: dict) -> dict:
deployment_dict["litellm_params"].pop("aws_secret_access_key", None)
return deployment_dict
async def get_custom_llm_provider_from_request_body(request: Request) -> Optional[str]:
"""
Get the `custom_llm_provider` from the request body
Safely reads the request body
"""
request_body: dict = await _read_request_body(request=request) or {}
if "custom_llm_provider" in request_body:
return request_body["custom_llm_provider"]
return None
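To illustrate how this helper combines with the optional `/{provider}` path parameter in the batch and files endpoints (path parameter first, then request body, then `"openai"`), here is a simplified, hypothetical sketch with a plain dict standing in for the parsed request body; the function name is not part of the codebase:

```python
from typing import Optional

def resolve_custom_llm_provider(
    provider_from_path: Optional[str], request_body: dict
) -> str:
    # Mirrors the pattern used in the proxy endpoints:
    # URL path param -> "custom_llm_provider" in the request body -> "openai"
    return (
        provider_from_path
        or request_body.get("custom_llm_provider")
        or "openai"
    )

# The OpenAI SDK's extra_body is merged into the JSON payload, so
# extra_body={"custom_llm_provider": "azure"} arrives as a body field:
assert resolve_custom_llm_provider(None, {"custom_llm_provider": "azure"}) == "azure"
assert resolve_custom_llm_provider("azure", {}) == "azure"
assert resolve_custom_llm_provider(None, {}) == "openai"
```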

View file

@ -43,18 +43,18 @@ litellm_settings:
# For /fine_tuning/jobs endpoints
finetune_settings:
- custom_llm_provider: azure
api_base: https://exampleopenaiendpoint-production.up.railway.app
api_key: fake-key
api_version: "2023-03-15-preview"
api_base: os.environ/AZURE_API_BASE
api_key: os.environ/AZURE_API_KEY
api_version: "2024-05-01-preview"
- custom_llm_provider: openai
api_key: os.environ/OPENAI_API_KEY
# for /files endpoints
files_settings:
- custom_llm_provider: azure
api_base: https://exampleopenaiendpoint-production.up.railway.app
api_key: fake-key
api_version: "2023-03-15-preview"
api_base: os.environ/AZURE_API_BASE
api_key: os.environ/AZURE_API_KEY
api_version: "2024-05-01-preview"
- custom_llm_provider: openai
api_key: os.environ/OPENAI_API_KEY

View file

@ -27,6 +27,9 @@ from litellm import CreateFileRequest, get_secret_str
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import *
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.common_utils.openai_endpoint_utils import (
get_custom_llm_provider_from_request_body,
)
from litellm.router import Router
router = APIRouter()
@ -151,11 +154,14 @@ async def create_file(
data: Dict = {}
try:
if provider is not None:
custom_llm_provider = provider
# Use orjson to parse JSON data, orjson speeds up requests significantly
# Read the file content
file_content = await file.read()
custom_llm_provider = (
provider
or await get_custom_llm_provider_from_request_body(request=request)
or "openai"
)
# Prepare the data for forwarding
data = {"purpose": purpose}
@ -322,10 +328,13 @@ async def get_file_content(
proxy_config=proxy_config,
)
if provider is None:
provider = "openai"
custom_llm_provider = (
provider
or await get_custom_llm_provider_from_request_body(request=request)
or "openai"
)
response = await litellm.afile_content(
custom_llm_provider=provider, file_id=file_id, **data # type: ignore
custom_llm_provider=custom_llm_provider, file_id=file_id, **data # type: ignore
)
### ALERTING ###
@ -436,7 +445,11 @@ async def get_file(
data: Dict = {}
try:
custom_llm_provider = (
provider
or await get_custom_llm_provider_from_request_body(request=request)
or "openai"
)
# Include original request and headers in the data
data = await add_litellm_data_to_request(
data=data,
@ -446,11 +459,8 @@ async def get_file(
version=version,
proxy_config=proxy_config,
)
if provider is None: # default to openai
provider = "openai"
response = await litellm.afile_retrieve(
custom_llm_provider=provider, file_id=file_id, **data # type: ignore
custom_llm_provider=custom_llm_provider, file_id=file_id, **data # type: ignore
)
### ALERTING ###
@ -552,7 +562,11 @@ async def delete_file(
data: Dict = {}
try:
custom_llm_provider = (
provider
or await get_custom_llm_provider_from_request_body(request=request)
or "openai"
)
# Include original request and headers in the data
data = await add_litellm_data_to_request(
data=data,
@ -563,10 +577,8 @@ async def delete_file(
proxy_config=proxy_config,
)
if provider is None: # default to openai
provider = "openai"
response = await litellm.afile_delete(
custom_llm_provider=provider, file_id=file_id, **data # type: ignore
custom_llm_provider=custom_llm_provider, file_id=file_id, **data # type: ignore
)
### ALERTING ###
@ -667,7 +679,11 @@ async def list_files(
data: Dict = {}
try:
custom_llm_provider = (
provider
or await get_custom_llm_provider_from_request_body(request=request)
or "openai"
)
# Include original request and headers in the data
data = await add_litellm_data_to_request(
data=data,
@ -678,10 +694,8 @@ async def list_files(
proxy_config=proxy_config,
)
if provider is None:
provider = "openai"
response = await litellm.afile_list(
custom_llm_provider=provider, purpose=purpose, **data # type: ignore
custom_llm_provider=custom_llm_provider, purpose=purpose, **data # type: ignore
)
### ALERTING ###

View file

@ -12,7 +12,21 @@ model_list:
model: bedrock/*
# for /files endpoints
# For /fine_tuning/jobs endpoints
finetune_settings:
- custom_llm_provider: azure
api_base: os.environ/AZURE_BATCHES_API_BASE
api_key: os.environ/AZURE_BATCHES_API_KEY
api_version: "2024-05-01-preview"
- custom_llm_provider: openai
api_key: os.environ/OPENAI_API_KEY
# for /files endpoints
files_settings:
- custom_llm_provider: azure
api_base: os.environ/AZURE_BATCHES_API_BASE
api_key: os.environ/AZURE_BATCHES_API_KEY
api_version: "2024-05-01-preview"
- custom_llm_provider: openai
api_key: os.environ/OPENAI_API_KEY
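For readers unfamiliar with the `os.environ/` convention used above: LiteLLM resolves such values from the environment (or a configured secret manager, via helpers like `get_secret_str`) when the proxy loads its config. A rough, conceptual sketch of that lookup; the helper name here is hypothetical:

```python
import os

def resolve_config_value(value: str) -> str:
    # Conceptual approximation of how "os.environ/<VAR>" entries are resolved;
    # the real proxy also supports secret managers.
    prefix = "os.environ/"
    if value.startswith(prefix):
        return os.environ[value[len(prefix):]]
    return value

# e.g. the files_settings entry above would resolve its key like this:
api_key = resolve_config_value("os.environ/AZURE_BATCHES_API_KEY")
```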

View file

@ -104,13 +104,7 @@ def generate_feedback_box():
from collections import defaultdict
import litellm
from litellm import (
CancelBatchRequest,
CreateBatchRequest,
ListBatchRequest,
RetrieveBatchRequest,
Router,
)
from litellm import Router
from litellm._logging import verbose_proxy_logger, verbose_router_logger
from litellm.caching.caching import DualCache, RedisCache
from litellm.exceptions import RejectedRequestError
@ -137,6 +131,7 @@ from litellm.proxy.auth.user_api_key_auth import (
user_api_key_auth,
user_api_key_auth_websocket,
)
from litellm.proxy.batches_endpoints.endpoints import router as batches_router
## Import All Misc routes here ##
from litellm.proxy.caching_routes import router as caching_router
@ -208,7 +203,6 @@ from litellm.proxy.management_endpoints.team_endpoints import router as team_rou
from litellm.proxy.management_endpoints.team_endpoints import update_team
from litellm.proxy.management_endpoints.ui_sso import router as ui_sso_router
from litellm.proxy.management_helpers.audit_logs import create_audit_log_for_update
from litellm.proxy.openai_files_endpoints.files_endpoints import is_known_model
from litellm.proxy.openai_files_endpoints.files_endpoints import (
router as openai_files_router,
)
@ -5095,377 +5089,6 @@ async def run_thread(
)
######################################################################
# /v1/batches Endpoints
######################################################################
@router.post(
"/{provider}/v1/batches",
dependencies=[Depends(user_api_key_auth)],
tags=["batch"],
)
@router.post(
"/v1/batches",
dependencies=[Depends(user_api_key_auth)],
tags=["batch"],
)
@router.post(
"/batches",
dependencies=[Depends(user_api_key_auth)],
tags=["batch"],
)
async def create_batch(
request: Request,
fastapi_response: Response,
provider: Optional[str] = None,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
Create large batches of API requests for asynchronous processing.
This is the equivalent of POST https://api.openai.com/v1/batch
Supports Identical Params as: https://platform.openai.com/docs/api-reference/batch
Example Curl
```
curl http://localhost:4000/v1/batches \
-H "Authorization: Bearer sk-1234" \
-H "Content-Type: application/json" \
-d '{
"input_file_id": "file-abc123",
"endpoint": "/v1/chat/completions",
"completion_window": "24h"
}'
```
"""
global proxy_logging_obj
data: Dict = {}
try:
body = await request.body()
body_str = body.decode()
try:
data = ast.literal_eval(body_str)
except Exception:
data = json.loads(body_str)
verbose_proxy_logger.debug(
"Request received by LiteLLM:\n{}".format(json.dumps(data, indent=4)),
)
# Include original request and headers in the data
data = await add_litellm_data_to_request(
data=data,
request=request,
general_settings=general_settings,
user_api_key_dict=user_api_key_dict,
version=version,
proxy_config=proxy_config,
)
## check if model is a loadbalanced model
router_model: Optional[str] = None
is_router_model = False
if litellm.enable_loadbalancing_on_batch_endpoints is True:
router_model = data.get("model", None)
is_router_model = is_known_model(model=router_model, llm_router=llm_router)
_create_batch_data = CreateBatchRequest(**data)
custom_llm_provider = provider or _create_batch_data.pop("custom_llm_provider", None) # type: ignore
if (
litellm.enable_loadbalancing_on_batch_endpoints is True
and is_router_model
and router_model is not None
):
if llm_router is None:
raise HTTPException(
status_code=500,
detail={
"error": "LLM Router not initialized. Ensure models added to proxy."
},
)
response = await llm_router.acreate_batch(**_create_batch_data) # type: ignore
else:
if custom_llm_provider is None:
custom_llm_provider = "openai"
response = await litellm.acreate_batch(
custom_llm_provider=custom_llm_provider, **_create_batch_data # type: ignore
)
### ALERTING ###
asyncio.create_task(
proxy_logging_obj.update_request_status(
litellm_call_id=data.get("litellm_call_id", ""), status="success"
)
)
### RESPONSE HEADERS ###
hidden_params = getattr(response, "_hidden_params", {}) or {}
model_id = hidden_params.get("model_id", None) or ""
cache_key = hidden_params.get("cache_key", None) or ""
api_base = hidden_params.get("api_base", None) or ""
fastapi_response.headers.update(
get_custom_headers(
user_api_key_dict=user_api_key_dict,
model_id=model_id,
cache_key=cache_key,
api_base=api_base,
version=version,
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
request_data=data,
)
)
return response
except Exception as e:
await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
)
verbose_proxy_logger.exception(
"litellm.proxy.proxy_server.create_batch(): Exception occured - {}".format(
str(e)
)
)
verbose_proxy_logger.debug(traceback.format_exc())
if isinstance(e, HTTPException):
raise ProxyException(
message=getattr(e, "message", str(e.detail)),
type=getattr(e, "type", "None"),
param=getattr(e, "param", "None"),
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
)
else:
error_msg = f"{str(e)}"
raise ProxyException(
message=getattr(e, "message", error_msg),
type=getattr(e, "type", "None"),
param=getattr(e, "param", "None"),
code=getattr(e, "status_code", 500),
)
@router.get(
"/{provider}/v1/batches/{batch_id:path}",
dependencies=[Depends(user_api_key_auth)],
tags=["batch"],
)
@router.get(
"/v1/batches/{batch_id:path}",
dependencies=[Depends(user_api_key_auth)],
tags=["batch"],
)
@router.get(
"/batches/{batch_id:path}",
dependencies=[Depends(user_api_key_auth)],
tags=["batch"],
)
async def retrieve_batch(
request: Request,
fastapi_response: Response,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
provider: Optional[str] = None,
batch_id: str = Path(
title="Batch ID to retrieve", description="The ID of the batch to retrieve"
),
):
"""
Retrieves a batch.
This is the equivalent of GET https://api.openai.com/v1/batches/{batch_id}
Supports Identical Params as: https://platform.openai.com/docs/api-reference/batch/retrieve
Example Curl
```
curl http://localhost:4000/v1/batches/batch_abc123 \
-H "Authorization: Bearer sk-1234" \
-H "Content-Type: application/json" \
```
"""
global proxy_logging_obj
data: Dict = {}
try:
## check if model is a loadbalanced model
_retrieve_batch_request = RetrieveBatchRequest(
batch_id=batch_id,
)
if litellm.enable_loadbalancing_on_batch_endpoints is True:
if llm_router is None:
raise HTTPException(
status_code=500,
detail={
"error": "LLM Router not initialized. Ensure models added to proxy."
},
)
response = await llm_router.aretrieve_batch(**_retrieve_batch_request) # type: ignore
else:
if provider is None:
provider = "openai"
response = await litellm.aretrieve_batch(
custom_llm_provider=provider, **_retrieve_batch_request # type: ignore
)
### ALERTING ###
asyncio.create_task(
proxy_logging_obj.update_request_status(
litellm_call_id=data.get("litellm_call_id", ""), status="success"
)
)
### RESPONSE HEADERS ###
hidden_params = getattr(response, "_hidden_params", {}) or {}
model_id = hidden_params.get("model_id", None) or ""
cache_key = hidden_params.get("cache_key", None) or ""
api_base = hidden_params.get("api_base", None) or ""
fastapi_response.headers.update(
get_custom_headers(
user_api_key_dict=user_api_key_dict,
model_id=model_id,
cache_key=cache_key,
api_base=api_base,
version=version,
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
request_data=data,
)
)
return response
except Exception as e:
await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
)
verbose_proxy_logger.exception(
"litellm.proxy.proxy_server.retrieve_batch(): Exception occured - {}".format(
str(e)
)
)
verbose_proxy_logger.debug(traceback.format_exc())
if isinstance(e, HTTPException):
raise ProxyException(
message=getattr(e, "message", str(e.detail)),
type=getattr(e, "type", "None"),
param=getattr(e, "param", "None"),
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
)
else:
traceback.format_exc()
error_msg = f"{str(e)}"
raise ProxyException(
message=getattr(e, "message", error_msg),
type=getattr(e, "type", "None"),
param=getattr(e, "param", "None"),
code=getattr(e, "status_code", 500),
)
@router.get(
"/{provider}/v1/batches",
dependencies=[Depends(user_api_key_auth)],
tags=["batch"],
)
@router.get(
"/v1/batches",
dependencies=[Depends(user_api_key_auth)],
tags=["batch"],
)
@router.get(
"/batches",
dependencies=[Depends(user_api_key_auth)],
tags=["batch"],
)
async def list_batches(
fastapi_response: Response,
provider: Optional[str] = None,
limit: Optional[int] = None,
after: Optional[str] = None,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
Lists
This is the equivalent of GET https://api.openai.com/v1/batches/
Supports Identical Params as: https://platform.openai.com/docs/api-reference/batch/list
Example Curl
```
curl http://localhost:4000/v1/batches?limit=2 \
-H "Authorization: Bearer sk-1234" \
-H "Content-Type: application/json" \
```
"""
global proxy_logging_obj
verbose_proxy_logger.debug("GET /v1/batches after={} limit={}".format(after, limit))
try:
if provider is None:
provider = "openai"
response = await litellm.alist_batches(
custom_llm_provider=provider, # type: ignore
after=after,
limit=limit,
)
### RESPONSE HEADERS ###
hidden_params = getattr(response, "_hidden_params", {}) or {}
model_id = hidden_params.get("model_id", None) or ""
cache_key = hidden_params.get("cache_key", None) or ""
api_base = hidden_params.get("api_base", None) or ""
fastapi_response.headers.update(
get_custom_headers(
user_api_key_dict=user_api_key_dict,
model_id=model_id,
cache_key=cache_key,
api_base=api_base,
version=version,
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
)
)
return response
except Exception as e:
await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict,
original_exception=e,
request_data={"after": after, "limit": limit},
)
verbose_proxy_logger.error(
"litellm.proxy.proxy_server.retrieve_batch(): Exception occured - {}".format(
str(e)
)
)
verbose_proxy_logger.debug(traceback.format_exc())
if isinstance(e, HTTPException):
raise ProxyException(
message=getattr(e, "message", str(e.detail)),
type=getattr(e, "type", "None"),
param=getattr(e, "param", "None"),
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
)
else:
traceback.format_exc()
error_msg = f"{str(e)}"
raise ProxyException(
message=getattr(e, "message", error_msg),
type=getattr(e, "type", "None"),
param=getattr(e, "param", "None"),
code=getattr(e, "status_code", 500),
)
######################################################################
# END OF /v1/batches Endpoints Implementation
######################################################################
@router.post(
"/v1/moderations",
dependencies=[Depends(user_api_key_auth)],
@ -9203,6 +8826,7 @@ def cleanup_router_config_variables():
app.include_router(router)
app.include_router(batches_router)
app.include_router(rerank_router)
app.include_router(fine_tuning_router)
app.include_router(vertex_router)

View file

@ -127,8 +127,8 @@ litellm_settings:
# For /fine_tuning/jobs endpoints
finetune_settings:
- custom_llm_provider: azure
api_base: https://exampleopenaiendpoint-production.up.railway.app
api_key: fake-key
api_base: os.environ/AZURE_API_BASE
api_key: os.environ/AZURE_API_KEY
api_version: "2023-03-15-preview"
- custom_llm_provider: openai
api_key: os.environ/OPENAI_API_KEY
@ -136,8 +136,8 @@ finetune_settings:
# for /files endpoints
files_settings:
- custom_llm_provider: azure
api_base: https://exampleopenaiendpoint-production.up.railway.app
api_key: fake-key
api_base: os.environ/AZURE_API_BASE
api_key: os.environ/AZURE_API_KEY
api_version: "2023-03-15-preview"
- custom_llm_provider: openai
api_key: os.environ/OPENAI_API_KEY

View file

@ -78,9 +78,6 @@ async def test_create_batch(provider):
2. Create Batch Request
3. Retrieve the specific batch
"""
custom_logger = TestCustomLogger()
litellm.callbacks = [custom_logger, "datadog"]
if provider == "azure":
# Don't have anymore Azure Quota
return
@ -112,12 +109,6 @@ async def test_create_batch(provider):
print("response from litellm.create_batch=", create_batch_response)
await asyncio.sleep(6)
# Assert that the create batch event is logged on CustomLogger
assert custom_logger.standard_logging_object is not None
print(
"standard_logging_object=",
json.dumps(custom_logger.standard_logging_object, indent=4),
)
assert (
create_batch_response.id is not None
), f"Failed to create batch, expected a non null batch_id but got {create_batch_response.id}"
@ -170,7 +161,7 @@ class TestCustomLogger(CustomLogger):
self.standard_logging_object = kwargs["standard_logging_object"]
@pytest.mark.parametrize("provider", ["openai"]) # "azure"
@pytest.mark.parametrize("provider", ["azure", "openai"]) # "azure"
@pytest.mark.asyncio()
@pytest.mark.flaky(retries=3, delay=1)
async def test_async_create_batch(provider):
@ -180,9 +171,6 @@ async def test_async_create_batch(provider):
3. Retrieve the specific batch
"""
print("Testing async create batch")
if provider == "azure":
# Don't have anymore Azure Quota
return
custom_logger = TestCustomLogger()
litellm.callbacks = [custom_logger, "datadog"]
@ -276,6 +264,24 @@ async def test_async_create_batch(provider):
file.write(file_content.content)
def cleanup_azure_files():
"""
Delete all files for Azure - helper for when we run out of Azure Files Quota
"""
azure_files = litellm.file_list(
custom_llm_provider="azure",
)
print("azure_files=", azure_files)
for _file in azure_files:
print("deleting file=", _file)
delete_file_response = litellm.file_delete(
file_id=_file.id,
custom_llm_provider="azure",
)
print("delete_file_response=", delete_file_response)
assert delete_file_response.id == _file.id
def test_retrieve_batch():
pass

View file

@ -0,0 +1 @@
{"custom_id": "ae006110bb364606||/workspace/saved_models/meta-llama/Meta-Llama-3.1-8B-Instruct", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-4o-mini", "temperature": 0, "max_tokens": 1024, "response_format": {"type": "json_object"}, "messages": [{"role": "user", "content": "# Instruction \n\nYou are an expert evaluator. Your task is to evaluate the quality of the responses generated by AI models. \nWe will provide you with the user query and an AI-generated responses.\nYo must respond in json"}]}}

View file

@ -0,0 +1,10 @@
{"messages": [{"role": "system", "content": "Clippy is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "Who discovered Antarctica?"}, {"role": "assistant", "content": "Some chaps named Fabian Gottlieb von Bellingshausen and Mikhail Lazarev, as if they don't teach that in every school!"}]}
{"messages": [{"role": "system", "content": "Clippy is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "What is the biggest ocean?"}, {"role": "assistant", "content": "The Pacific Ocean. It's not like it's a small pond or anything."}]}
{"messages": [{"role": "system", "content": "Clippy is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "What is the largest planet?"}, {"role": "assistant", "content": "It's called Jupiter, you may have heard of it...or not."}]}
{"messages": [{"role": "system", "content": "Clippy is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "Who painted the Mona Lisa?"}, {"role": "assistant", "content": "Oh, some obscure artist named Leonardo da Vinci. Not that he is famous or anything."}]}
{"messages": [{"role": "system", "content": "Clippy is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "What's the boiling point of water?"}, {"role": "assistant", "content": "Exciting fact of the day: it's 100 degrees Celsius. Life-changing, isn't it?"}]}
{"messages": [{"role": "system", "content": "Clippy is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "Who wrote 'The Odyssey'?"}, {"role": "assistant", "content": "Oh, just some poet named Homer. Nothing special, really."}]}
{"messages": [{"role": "system", "content": "Clippy is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "What is the largest mammal?"}, {"role": "assistant", "content": "Blue Whale. Quite a big fish isn't it? Oh wait, it's not a fish!"}]}
{"messages": [{"role": "system", "content": "Clippy is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "What's the highest mountain peak?"}, {"role": "assistant", "content": "Everest. Brace yourself, it's not a walk in the park."}]}
{"messages": [{"role": "system", "content": "Clippy is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "Who invented the telephone?"}, {"role": "assistant", "content": "A guy named Alexander Graham Bell. Bet you didn't guess that one!"}]}
{"messages": [{"role": "system", "content": "Clippy is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "What is the distance from Earth to the Sun?"}, {"role": "assistant", "content": "About 93 million miles. Just a quick drive, really."}]}

View file

@ -1 +1 @@
{"id": "batch_req_6765ed82629c8190b70c10c183b5e994", "custom_id": "ae006110bb364606||/workspace/saved_models/meta-llama/Meta-Llama-3.1-8B-Instruct", "response": {"status_code": 200, "request_id": "36bbc935dec50094e84af1db52cf2cc7", "body": {"id": "chatcmpl-AgfdQmdwJQ0NrQManGI8ecwMvF0ZC", "object": "chat.completion", "created": 1734733184, "model": "gpt-4o-2024-05-13", "choices": [{"index": 0, "message": {"role": "assistant", "content": "{\n \"user_query\": \"What are the benefits of using renewable energy sources?\",\n \"ai_response\": \"Renewable energy sources, such as solar, wind, and hydroelectric power, offer numerous benefits. They are sustainable and can be replenished naturally, reducing the reliance on finite fossil fuels. Additionally, renewable energy sources produce little to no greenhouse gas emissions, helping to combat climate change and reduce air pollution. They also create jobs in the renewable energy sector and can lead to energy independence for countries that invest in their development. Furthermore, renewable energy technologies often have lower operating costs once established, providing long-term economic benefits.\"\n}", "refusal": null}, "logprobs": null, "finish_reason": "stop"}], "usage": {"prompt_tokens": 51, "completion_tokens": 128, "total_tokens": 179, "prompt_tokens_details": {"cached_tokens": 0, "audio_tokens": 0}, "completion_tokens_details": {"reasoning_tokens": 0, "audio_tokens": 0, "accepted_prediction_tokens": 0, "rejected_prediction_tokens": 0}}, "system_fingerprint": "fp_20cb129c3a"}}, "error": null}
{"id": "batch_req_676aee0162f88190b707f57aa28ee2e6", "custom_id": "ae006110bb364606||/workspace/saved_models/meta-llama/Meta-Llama-3.1-8B-Instruct", "response": {"status_code": 200, "request_id": "d5d91547aef5fb21997a8d67fe4785fa", "body": {"id": "chatcmpl-Ai2uZRrTVwowbwysu9TorTUYf3qFq", "object": "chat.completion", "created": 1735060987, "model": "gpt-4o-2024-05-13", "choices": [{"index": 0, "message": {"role": "assistant", "content": "{\n \"user_query\": \"What are the benefits of a plant-based diet?\",\n \"ai_response\": \"A plant-based diet offers numerous benefits, including improved heart health, weight management, and a lower risk of chronic diseases. It is rich in essential nutrients, fiber, and antioxidants, which can help reduce inflammation and improve overall well-being. Additionally, a plant-based diet is environmentally sustainable and can contribute to reduced greenhouse gas emissions and conservation of natural resources.\"\n}", "refusal": null}, "logprobs": null, "finish_reason": "stop"}], "usage": {"prompt_tokens": 51, "completion_tokens": 95, "total_tokens": 146, "prompt_tokens_details": {"cached_tokens": 0, "audio_tokens": 0}, "completion_tokens_details": {"reasoning_tokens": 0, "audio_tokens": 0, "accepted_prediction_tokens": 0, "rejected_prediction_tokens": 0}}, "system_fingerprint": "fp_0325da12f2"}}, "error": null}

View file

@ -0,0 +1 @@
{"custom_id": "ae006110bb364606||/workspace/saved_models/meta-llama/Meta-Llama-3.1-8B-Instruct", "response": {"body": {"choices": [{"content_filter_results": {"hate": {"filtered": false, "severity": "safe"}, "self_harm": {"filtered": false, "severity": "safe"}, "sexual": {"filtered": false, "severity": "safe"}, "violence": {"filtered": false, "severity": "safe"}}, "finish_reason": "stop", "index": 0, "logprobs": null, "message": {"content": "{\"evaluation\": \"Awaiting user query and AI-generated response for evaluation.\"}", "refusal": null, "role": "assistant"}}], "created": 1735059173, "id": "chatcmpl-Ai2RJ6yVuMIoRxTpPxaPPKrZ0Rg7s", "model": "gpt-4o-mini-2024-07-18", "object": "chat.completion", "prompt_filter_results": [{"prompt_index": 0, "content_filter_results": {"hate": {"filtered": false, "severity": "safe"}, "jailbreak": {"filtered": false, "detected": false}, "self_harm": {"filtered": false, "severity": "safe"}, "sexual": {"filtered": false, "severity": "safe"}, "violence": {"filtered": false, "severity": "safe"}}}], "system_fingerprint": "fp_5154047bf2", "usage": {"completion_tokens": 16, "completion_tokens_details": {"accepted_prediction_tokens": 0, "audio_tokens": 0, "reasoning_tokens": 0, "rejected_prediction_tokens": 0}, "prompt_tokens": 51, "prompt_tokens_details": {"audio_tokens": 0, "cached_tokens": 0}, "total_tokens": 67}}, "request_id": "346b61a8-b9ca-432f-a081-b0caeddbc31d", "status_code": 200}, "error": null}

View file

@ -95,10 +95,16 @@ from openai import OpenAI
client = OpenAI(base_url=BASE_URL, api_key=API_KEY)
def create_batch_oai_sdk(filepath) -> str:
batch_input_file = client.files.create(file=open(filepath, "rb"), purpose="batch")
def create_batch_oai_sdk(filepath: str, custom_llm_provider: str) -> str:
batch_input_file = client.files.create(
file=open(filepath, "rb"),
purpose="batch",
extra_body={"custom_llm_provider": custom_llm_provider},
)
batch_input_file_id = batch_input_file.id
print("waiting for file to be processed......")
time.sleep(5)
rq = client.batches.create(
input_file_id=batch_input_file_id,
endpoint="/v1/chat/completions",
@ -106,15 +112,18 @@ def create_batch_oai_sdk(filepath) -> str:
metadata={
"description": filepath,
},
extra_body={"custom_llm_provider": custom_llm_provider},
)
print(f"Batch submitted. ID: {rq.id}")
return rq.id
def await_batch_completion(batch_id: str):
def await_batch_completion(batch_id: str, custom_llm_provider: str):
while True:
batch = client.batches.retrieve(batch_id)
batch = client.batches.retrieve(
batch_id, extra_body={"custom_llm_provider": custom_llm_provider}
)
if batch.status == "completed":
print(f"Batch {batch_id} completed.")
return
@ -123,9 +132,16 @@ def await_batch_completion(batch_id: str):
time.sleep(10)
def write_content_to_file(batch_id: str, output_path: str) -> str:
batch = client.batches.retrieve(batch_id)
content = client.files.content(batch.output_file_id)
def write_content_to_file(
batch_id: str, output_path: str, custom_llm_provider: str
) -> str:
batch = client.batches.retrieve(
batch_id=batch_id, extra_body={"custom_llm_provider": custom_llm_provider}
)
content = client.files.content(
file_id=batch.output_file_id,
extra_body={"custom_llm_provider": custom_llm_provider},
)
print("content from files.content", content.content)
content.write_to_file(output_path)
@ -145,20 +161,47 @@ def read_jsonl(filepath: str):
print(custom_id)
def test_e2e_batches_files():
def get_any_completed_batch_id_azure():
print("AZURE getting any completed batch id")
list_of_batches = client.batches.list(extra_body={"custom_llm_provider": "azure"})
print("list of batches", list_of_batches)
for batch in list_of_batches:
if batch.status == "completed":
return batch.id
return None
@pytest.mark.parametrize("custom_llm_provider", ["azure", "openai"])
def test_e2e_batches_files(custom_llm_provider):
"""
[PROD Test] Ensures OpenAI Batches + files work with OpenAI SDK
"""
input_path = "input.jsonl"
output_path = "out.jsonl"
input_path = (
"input.jsonl" if custom_llm_provider == "openai" else "input_azure.jsonl"
)
output_path = "out.jsonl" if custom_llm_provider == "openai" else "out_azure.jsonl"
_current_dir = os.path.dirname(os.path.abspath(__file__))
input_file_path = os.path.join(_current_dir, input_path)
output_file_path = os.path.join(_current_dir, output_path)
print("running e2e batches files with custom_llm_provider=", custom_llm_provider)
batch_id = create_batch_oai_sdk(
filepath=input_file_path, custom_llm_provider=custom_llm_provider
)
batch_id = create_batch_oai_sdk(input_file_path)
await_batch_completion(batch_id)
write_content_to_file(batch_id, output_file_path)
if custom_llm_provider == "azure":
# azure takes very long to complete a batch - randomly pick a completed batch
batch_id = get_any_completed_batch_id_azure()
else:
await_batch_completion(
batch_id=batch_id, custom_llm_provider=custom_llm_provider
)
write_content_to_file(
batch_id=batch_id,
output_path=output_file_path,
custom_llm_provider=custom_llm_provider,
)
read_jsonl(output_file_path)

View file

@ -1,16 +1,17 @@
from openai import AsyncOpenAI
import os
import pytest
import asyncio
@pytest.mark.asyncio
async def test_openai_fine_tuning():
"""
[PROD Test] Ensures logprobs are returned correctly
[PROD Test] e2e tests for /fine_tuning/jobs endpoints
"""
client = AsyncOpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")
file_name = "openai_batch_completions.jsonl"
file_name = "openai_fine_tuning.jsonl"
_current_dir = os.path.dirname(os.path.abspath(__file__))
file_path = os.path.join(_current_dir, file_name)
@ -22,10 +23,12 @@ async def test_openai_fine_tuning():
print("response from files.create: {}".format(response))
await asyncio.sleep(5)
# create fine tuning job
ft_job = await client.fine_tuning.jobs.create(
model="gpt-35-turbo-1106",
model="gpt-35-turbo-0613",
training_file=response.id,
extra_body={"custom_llm_provider": "azure"},
)
@ -33,7 +36,7 @@ async def test_openai_fine_tuning():
print("response from ft job={}".format(ft_job))
# response from example endpoint
assert ft_job.id == "ftjob-abc123"
assert ft_job.id is not None
# list all fine tuning jobs
list_ft_jobs = await client.fine_tuning.jobs.list(
@ -44,10 +47,16 @@ async def test_openai_fine_tuning():
# cancel specific fine tuning job
cancel_ft_job = await client.fine_tuning.jobs.cancel(
fine_tuning_job_id="123",
fine_tuning_job_id=ft_job.id,
extra_body={"custom_llm_provider": "azure"},
)
print("response from cancel ft job={}".format(cancel_ft_job))
assert cancel_ft_job.id is not None
# delete OG file
await client.files.delete(
file_id=response.id,
extra_body={"custom_llm_provider": "azure"},
)