mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 02:34:29 +00:00
(feat) /batches
Add support for using /batches
endpoints in OAI format (#7402)
* run azure testing on ci/cd * update docs on azure batches endpoints * add input azure.jsonl * refactor - use separate file for batches endpoints * fixes for passing custom llm provider to /batch endpoints * pass custom llm provider to files endpoints * update azure batches doc * add info for azure batches api * update batches endpoints * use simple helper for raising proxy exception * update config.yml * fix imports * update tests * use existing settings * update env var used * update configs * update config.yml * update ft testing
This commit is contained in:
parent
fe43403359
commit
47e12802df
17 changed files with 718 additions and 464 deletions
|
@ -1111,7 +1111,9 @@ jobs:
|
|||
docker run -d \
|
||||
-p 4000:4000 \
|
||||
-e DATABASE_URL=$PROXY_DATABASE_URL \
|
||||
-e AZURE_API_KEY=$AZURE_API_KEY \
|
||||
-e AZURE_API_KEY=$AZURE_BATCHES_API_KEY \
|
||||
-e AZURE_API_BASE=$AZURE_BATCHES_API_BASE \
|
||||
-e AZURE_API_VERSION="2024-05-01-preview" \
|
||||
-e REDIS_HOST=$REDIS_HOST \
|
||||
-e REDIS_PASSWORD=$REDIS_PASSWORD \
|
||||
-e REDIS_PORT=$REDIS_PORT \
|
||||
|
|
|
@ -559,7 +559,18 @@ litellm_settings:
|
|||
</Tabs>
|
||||
|
||||
|
||||
## **Azure Batches API**
|
||||
## **Azure Batches API**
|
||||
|
||||
| Property | Details |
|
||||
|-------|-------|
|
||||
| Description | Azure OpenAI Batches API |
|
||||
| `custom_llm_provider` on LiteLLM | `azure/` |
|
||||
| Supported Operations | `/v1/batches`, `/v1/files` |
|
||||
| Azure OpenAI Batches API | [Azure OpenAI Batches API ↗](https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/batch) |
|
||||
| Cost Tracking, Logging Support | ✅ LiteLLM will log, track cost for Batch API Requests |
|
||||
|
||||
|
||||
### Quick Start
|
||||
|
||||
Just add the azure env vars to your environment.
|
||||
|
||||
|
@ -568,40 +579,71 @@ export AZURE_API_KEY=""
|
|||
export AZURE_API_BASE=""
|
||||
```
|
||||
|
||||
AND use `/azure/*` for the Batches API calls
|
||||
<Tabs>
|
||||
<TabItem value="proxy" label="LiteLLM PROXY Server">
|
||||
|
||||
```bash
|
||||
http://0.0.0.0:4000/azure/v1/batches
|
||||
**1. Upload a File**
|
||||
|
||||
<Tabs>
|
||||
<TabItem value="sdk" label="OpenAI Python SDK">
|
||||
|
||||
```python
|
||||
from openai import OpenAI
|
||||
|
||||
# Initialize the client
|
||||
client = OpenAI(
|
||||
base_url="http://localhost:4000",
|
||||
api_key="your-api-key"
|
||||
)
|
||||
|
||||
batch_input_file = client.files.create(
|
||||
file=open("mydata.jsonl", "rb"),
|
||||
purpose="batch",
|
||||
extra_body={"custom_llm_provider": "azure"}
|
||||
)
|
||||
file_id = batch_input_file.id
|
||||
```
|
||||
### Usage
|
||||
|
||||
**Setup**
|
||||
|
||||
- Add Azure API Keys to your environment
|
||||
|
||||
#### 1. Upload a File
|
||||
</TabItem>
|
||||
<TabItem value="curl" label="Curl">
|
||||
|
||||
```bash
|
||||
curl http://localhost:4000/azure/v1/files \
|
||||
curl http://localhost:4000/v1/files \
|
||||
-H "Authorization: Bearer sk-1234" \
|
||||
-F purpose="batch" \
|
||||
-F file="@mydata.jsonl"
|
||||
```
|
||||
|
||||
**Example File**
|
||||
|
||||
Note: `model` should be your azure deployment name.
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
|
||||
**Example File Format**
|
||||
```json
|
||||
{"custom_id": "task-0", "method": "POST", "url": "/chat/completions", "body": {"model": "REPLACE-WITH-MODEL-DEPLOYMENT-NAME", "messages": [{"role": "system", "content": "You are an AI assistant that helps people find information."}, {"role": "user", "content": "When was Microsoft founded?"}]}}
|
||||
{"custom_id": "task-1", "method": "POST", "url": "/chat/completions", "body": {"model": "REPLACE-WITH-MODEL-DEPLOYMENT-NAME", "messages": [{"role": "system", "content": "You are an AI assistant that helps people find information."}, {"role": "user", "content": "When was the first XBOX released?"}]}}
|
||||
{"custom_id": "task-2", "method": "POST", "url": "/chat/completions", "body": {"model": "REPLACE-WITH-MODEL-DEPLOYMENT-NAME", "messages": [{"role": "system", "content": "You are an AI assistant that helps people find information."}, {"role": "user", "content": "What is Altair Basic?"}]}}
|
||||
```
|
||||
|
||||
#### 2. Create a batch
|
||||
**2. Create a Batch Request**
|
||||
|
||||
<Tabs>
|
||||
<TabItem value="sdk" label="OpenAI Python SDK">
|
||||
|
||||
```python
|
||||
batch = client.batches.create( # re use client from above
|
||||
input_file_id=file_id,
|
||||
endpoint="/v1/chat/completions",
|
||||
completion_window="24h",
|
||||
metadata={"description": "My batch job"},
|
||||
extra_body={"custom_llm_provider": "azure"}
|
||||
)
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
<TabItem value="curl" label="Curl">
|
||||
|
||||
```bash
|
||||
curl http://0.0.0.0:4000/azure/v1/batches \
|
||||
curl http://localhost:4000/v1/batches \
|
||||
-H "Authorization: Bearer $LITELLM_API_KEY" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
|
@ -609,34 +651,144 @@ curl http://0.0.0.0:4000/azure/v1/batches \
|
|||
"endpoint": "/v1/chat/completions",
|
||||
"completion_window": "24h"
|
||||
}'
|
||||
```
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
|
||||
**3. Retrieve a Batch**
|
||||
|
||||
<Tabs>
|
||||
<TabItem value="sdk" label="OpenAI Python SDK">
|
||||
|
||||
```python
|
||||
retrieved_batch = client.batches.retrieve(
|
||||
batch.id,
|
||||
extra_body={"custom_llm_provider": "azure"}
|
||||
)
|
||||
```
|
||||
|
||||
#### 3. Retrieve batch
|
||||
|
||||
</TabItem>
|
||||
<TabItem value="curl" label="Curl">
|
||||
|
||||
```bash
|
||||
curl http://0.0.0.0:4000/azure/v1/batches/batch_abc123 \
|
||||
curl http://localhost:4000/v1/batches/batch_abc123 \
|
||||
-H "Authorization: Bearer $LITELLM_API_KEY" \
|
||||
-H "Content-Type: application/json" \
|
||||
```
|
||||
|
||||
#### 4. Cancel batch
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
|
||||
**4. Cancel a Batch**
|
||||
|
||||
<Tabs>
|
||||
<TabItem value="sdk" label="OpenAI Python SDK">
|
||||
|
||||
```python
|
||||
cancelled_batch = client.batches.cancel(
|
||||
batch.id,
|
||||
extra_body={"custom_llm_provider": "azure"}
|
||||
)
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
<TabItem value="curl" label="Curl">
|
||||
|
||||
```bash
|
||||
curl http://0.0.0.0:4000/azure/v1/batches/batch_abc123/cancel \
|
||||
curl http://localhost:4000/v1/batches/batch_abc123/cancel \
|
||||
-H "Authorization: Bearer $LITELLM_API_KEY" \
|
||||
-H "Content-Type: application/json" \
|
||||
-X POST
|
||||
```
|
||||
|
||||
#### 5. List Batch
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
|
||||
**5. List Batches**
|
||||
|
||||
<Tabs>
|
||||
<TabItem value="sdk" label="OpenAI Python SDK">
|
||||
|
||||
```python
|
||||
client.batches.list(extra_body={"custom_llm_provider": "azure"})
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
<TabItem value="curl" label="Curl">
|
||||
|
||||
```bash
|
||||
curl http://0.0.0.0:4000/v1/batches?limit=2 \
|
||||
curl http://localhost:4000/v1/batches?limit=2 \
|
||||
-H "Authorization: Bearer $LITELLM_API_KEY" \
|
||||
-H "Content-Type: application/json"
|
||||
```
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
</TabItem>
|
||||
<TabItem value="sdk" label="LiteLLM SDK">
|
||||
|
||||
**1. Create File for Batch Completion**
|
||||
|
||||
```python
|
||||
from litellm
|
||||
import os
|
||||
|
||||
os.environ["AZURE_API_KEY"] = ""
|
||||
os.environ["AZURE_API_BASE"] = ""
|
||||
|
||||
file_name = "azure_batch_completions.jsonl"
|
||||
_current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
file_path = os.path.join(_current_dir, file_name)
|
||||
file_obj = await litellm.acreate_file(
|
||||
file=open(file_path, "rb"),
|
||||
purpose="batch",
|
||||
custom_llm_provider="azure",
|
||||
)
|
||||
print("Response from creating file=", file_obj)
|
||||
```
|
||||
|
||||
**2. Create Batch Request**
|
||||
|
||||
```python
|
||||
create_batch_response = await litellm.acreate_batch(
|
||||
completion_window="24h",
|
||||
endpoint="/v1/chat/completions",
|
||||
input_file_id=batch_input_file_id,
|
||||
custom_llm_provider="azure",
|
||||
metadata={"key1": "value1", "key2": "value2"},
|
||||
)
|
||||
|
||||
print("response from litellm.create_batch=", create_batch_response)
|
||||
```
|
||||
|
||||
**3. Retrieve Batch and File Content**
|
||||
|
||||
```python
|
||||
retrieved_batch = await litellm.aretrieve_batch(
|
||||
batch_id=create_batch_response.id,
|
||||
custom_llm_provider="azure"
|
||||
)
|
||||
print("retrieved batch=", retrieved_batch)
|
||||
|
||||
# Get file content
|
||||
file_content = await litellm.afile_content(
|
||||
file_id=batch_input_file_id,
|
||||
custom_llm_provider="azure"
|
||||
)
|
||||
print("file content = ", file_content)
|
||||
```
|
||||
|
||||
**4. List Batches**
|
||||
|
||||
```python
|
||||
list_batches_response = litellm.list_batches(
|
||||
custom_llm_provider="azure",
|
||||
limit=2
|
||||
)
|
||||
print("list_batches_response=", list_batches_response)
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
|
||||
### [Health Check Azure Batch models](./proxy/health.md#batch-models-azure-only)
|
||||
|
||||
|
|
|
@ -234,7 +234,7 @@ def create_batch(
|
|||
)
|
||||
else:
|
||||
raise litellm.exceptions.BadRequestError(
|
||||
message="LiteLLM doesn't support {} for 'create_batch'. Only 'openai' is supported.".format(
|
||||
message="LiteLLM doesn't support custom_llm_provider={} for 'create_batch'".format(
|
||||
custom_llm_provider
|
||||
),
|
||||
model="n/a",
|
||||
|
|
360
litellm/proxy/batches_endpoints/endpoints.py
Normal file
360
litellm/proxy/batches_endpoints/endpoints.py
Normal file
|
@ -0,0 +1,360 @@
|
|||
######################################################################
|
||||
|
||||
# /v1/batches Endpoints
|
||||
|
||||
import asyncio
|
||||
|
||||
######################################################################
|
||||
from typing import Dict, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Path, Request, Response
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.batches.main import CreateBatchRequest, RetrieveBatchRequest
|
||||
from litellm.proxy._types import *
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
|
||||
from litellm.proxy.common_utils.openai_endpoint_utils import (
|
||||
get_custom_llm_provider_from_request_body,
|
||||
)
|
||||
from litellm.proxy.openai_files_endpoints.files_endpoints import is_known_model
|
||||
from litellm.proxy.utils import handle_exception_on_proxy
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post(
|
||||
"/{provider}/v1/batches",
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
tags=["batch"],
|
||||
)
|
||||
@router.post(
|
||||
"/v1/batches",
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
tags=["batch"],
|
||||
)
|
||||
@router.post(
|
||||
"/batches",
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
tags=["batch"],
|
||||
)
|
||||
async def create_batch(
|
||||
request: Request,
|
||||
fastapi_response: Response,
|
||||
provider: Optional[str] = None,
|
||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
):
|
||||
"""
|
||||
Create large batches of API requests for asynchronous processing.
|
||||
This is the equivalent of POST https://api.openai.com/v1/batch
|
||||
Supports Identical Params as: https://platform.openai.com/docs/api-reference/batch
|
||||
|
||||
Example Curl
|
||||
```
|
||||
curl http://localhost:4000/v1/batches \
|
||||
-H "Authorization: Bearer sk-1234" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"input_file_id": "file-abc123",
|
||||
"endpoint": "/v1/chat/completions",
|
||||
"completion_window": "24h"
|
||||
}'
|
||||
```
|
||||
"""
|
||||
from litellm.proxy.proxy_server import (
|
||||
add_litellm_data_to_request,
|
||||
general_settings,
|
||||
get_custom_headers,
|
||||
llm_router,
|
||||
proxy_config,
|
||||
proxy_logging_obj,
|
||||
version,
|
||||
)
|
||||
|
||||
data: Dict = {}
|
||||
try:
|
||||
data = await _read_request_body(request=request)
|
||||
verbose_proxy_logger.debug(
|
||||
"Request received by LiteLLM:\n{}".format(json.dumps(data, indent=4)),
|
||||
)
|
||||
|
||||
# Include original request and headers in the data
|
||||
data = await add_litellm_data_to_request(
|
||||
data=data,
|
||||
request=request,
|
||||
general_settings=general_settings,
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
version=version,
|
||||
proxy_config=proxy_config,
|
||||
)
|
||||
|
||||
## check if model is a loadbalanced model
|
||||
router_model: Optional[str] = None
|
||||
is_router_model = False
|
||||
if litellm.enable_loadbalancing_on_batch_endpoints is True:
|
||||
router_model = data.get("model", None)
|
||||
is_router_model = is_known_model(model=router_model, llm_router=llm_router)
|
||||
|
||||
custom_llm_provider = (
|
||||
provider or data.pop("custom_llm_provider", None) or "openai"
|
||||
)
|
||||
_create_batch_data = CreateBatchRequest(**data)
|
||||
if (
|
||||
litellm.enable_loadbalancing_on_batch_endpoints is True
|
||||
and is_router_model
|
||||
and router_model is not None
|
||||
):
|
||||
if llm_router is None:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail={
|
||||
"error": "LLM Router not initialized. Ensure models added to proxy."
|
||||
},
|
||||
)
|
||||
|
||||
response = await llm_router.acreate_batch(**_create_batch_data) # type: ignore
|
||||
else:
|
||||
response = await litellm.acreate_batch(
|
||||
custom_llm_provider=custom_llm_provider, **_create_batch_data # type: ignore
|
||||
)
|
||||
|
||||
### ALERTING ###
|
||||
asyncio.create_task(
|
||||
proxy_logging_obj.update_request_status(
|
||||
litellm_call_id=data.get("litellm_call_id", ""), status="success"
|
||||
)
|
||||
)
|
||||
|
||||
### RESPONSE HEADERS ###
|
||||
hidden_params = getattr(response, "_hidden_params", {}) or {}
|
||||
model_id = hidden_params.get("model_id", None) or ""
|
||||
cache_key = hidden_params.get("cache_key", None) or ""
|
||||
api_base = hidden_params.get("api_base", None) or ""
|
||||
|
||||
fastapi_response.headers.update(
|
||||
get_custom_headers(
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
model_id=model_id,
|
||||
cache_key=cache_key,
|
||||
api_base=api_base,
|
||||
version=version,
|
||||
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
|
||||
request_data=data,
|
||||
)
|
||||
)
|
||||
|
||||
return response
|
||||
except Exception as e:
|
||||
await proxy_logging_obj.post_call_failure_hook(
|
||||
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
|
||||
)
|
||||
verbose_proxy_logger.exception(
|
||||
"litellm.proxy.proxy_server.create_batch(): Exception occured - {}".format(
|
||||
str(e)
|
||||
)
|
||||
)
|
||||
raise handle_exception_on_proxy(e)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/{provider}/v1/batches/{batch_id:path}",
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
tags=["batch"],
|
||||
)
|
||||
@router.get(
|
||||
"/v1/batches/{batch_id:path}",
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
tags=["batch"],
|
||||
)
|
||||
@router.get(
|
||||
"/batches/{batch_id:path}",
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
tags=["batch"],
|
||||
)
|
||||
async def retrieve_batch(
|
||||
request: Request,
|
||||
fastapi_response: Response,
|
||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
provider: Optional[str] = None,
|
||||
batch_id: str = Path(
|
||||
title="Batch ID to retrieve", description="The ID of the batch to retrieve"
|
||||
),
|
||||
):
|
||||
"""
|
||||
Retrieves a batch.
|
||||
This is the equivalent of GET https://api.openai.com/v1/batches/{batch_id}
|
||||
Supports Identical Params as: https://platform.openai.com/docs/api-reference/batch/retrieve
|
||||
|
||||
Example Curl
|
||||
```
|
||||
curl http://localhost:4000/v1/batches/batch_abc123 \
|
||||
-H "Authorization: Bearer sk-1234" \
|
||||
-H "Content-Type: application/json" \
|
||||
|
||||
```
|
||||
"""
|
||||
from litellm.proxy.proxy_server import (
|
||||
get_custom_headers,
|
||||
llm_router,
|
||||
proxy_logging_obj,
|
||||
version,
|
||||
)
|
||||
|
||||
data: Dict = {}
|
||||
try:
|
||||
## check if model is a loadbalanced model
|
||||
_retrieve_batch_request = RetrieveBatchRequest(
|
||||
batch_id=batch_id,
|
||||
)
|
||||
|
||||
if litellm.enable_loadbalancing_on_batch_endpoints is True:
|
||||
if llm_router is None:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail={
|
||||
"error": "LLM Router not initialized. Ensure models added to proxy."
|
||||
},
|
||||
)
|
||||
|
||||
response = await llm_router.aretrieve_batch(**_retrieve_batch_request) # type: ignore
|
||||
else:
|
||||
custom_llm_provider = (
|
||||
provider
|
||||
or await get_custom_llm_provider_from_request_body(request=request)
|
||||
or "openai"
|
||||
)
|
||||
response = await litellm.aretrieve_batch(
|
||||
custom_llm_provider=custom_llm_provider, **_retrieve_batch_request # type: ignore
|
||||
)
|
||||
|
||||
### ALERTING ###
|
||||
asyncio.create_task(
|
||||
proxy_logging_obj.update_request_status(
|
||||
litellm_call_id=data.get("litellm_call_id", ""), status="success"
|
||||
)
|
||||
)
|
||||
|
||||
### RESPONSE HEADERS ###
|
||||
hidden_params = getattr(response, "_hidden_params", {}) or {}
|
||||
model_id = hidden_params.get("model_id", None) or ""
|
||||
cache_key = hidden_params.get("cache_key", None) or ""
|
||||
api_base = hidden_params.get("api_base", None) or ""
|
||||
|
||||
fastapi_response.headers.update(
|
||||
get_custom_headers(
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
model_id=model_id,
|
||||
cache_key=cache_key,
|
||||
api_base=api_base,
|
||||
version=version,
|
||||
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
|
||||
request_data=data,
|
||||
)
|
||||
)
|
||||
|
||||
return response
|
||||
except Exception as e:
|
||||
await proxy_logging_obj.post_call_failure_hook(
|
||||
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
|
||||
)
|
||||
verbose_proxy_logger.exception(
|
||||
"litellm.proxy.proxy_server.retrieve_batch(): Exception occured - {}".format(
|
||||
str(e)
|
||||
)
|
||||
)
|
||||
raise handle_exception_on_proxy(e)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/{provider}/v1/batches",
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
tags=["batch"],
|
||||
)
|
||||
@router.get(
|
||||
"/v1/batches",
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
tags=["batch"],
|
||||
)
|
||||
@router.get(
|
||||
"/batches",
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
tags=["batch"],
|
||||
)
|
||||
async def list_batches(
|
||||
request: Request,
|
||||
fastapi_response: Response,
|
||||
provider: Optional[str] = None,
|
||||
limit: Optional[int] = None,
|
||||
after: Optional[str] = None,
|
||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
):
|
||||
"""
|
||||
Lists
|
||||
This is the equivalent of GET https://api.openai.com/v1/batches/
|
||||
Supports Identical Params as: https://platform.openai.com/docs/api-reference/batch/list
|
||||
|
||||
Example Curl
|
||||
```
|
||||
curl http://localhost:4000/v1/batches?limit=2 \
|
||||
-H "Authorization: Bearer sk-1234" \
|
||||
-H "Content-Type: application/json" \
|
||||
|
||||
```
|
||||
"""
|
||||
from litellm.proxy.proxy_server import (
|
||||
get_custom_headers,
|
||||
proxy_logging_obj,
|
||||
version,
|
||||
)
|
||||
|
||||
verbose_proxy_logger.debug("GET /v1/batches after={} limit={}".format(after, limit))
|
||||
try:
|
||||
custom_llm_provider = (
|
||||
provider
|
||||
or await get_custom_llm_provider_from_request_body(request=request)
|
||||
or "openai"
|
||||
)
|
||||
response = await litellm.alist_batches(
|
||||
custom_llm_provider=custom_llm_provider, # type: ignore
|
||||
after=after,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
### RESPONSE HEADERS ###
|
||||
hidden_params = getattr(response, "_hidden_params", {}) or {}
|
||||
model_id = hidden_params.get("model_id", None) or ""
|
||||
cache_key = hidden_params.get("cache_key", None) or ""
|
||||
api_base = hidden_params.get("api_base", None) or ""
|
||||
|
||||
fastapi_response.headers.update(
|
||||
get_custom_headers(
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
model_id=model_id,
|
||||
cache_key=cache_key,
|
||||
api_base=api_base,
|
||||
version=version,
|
||||
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
|
||||
)
|
||||
)
|
||||
|
||||
return response
|
||||
except Exception as e:
|
||||
await proxy_logging_obj.post_call_failure_hook(
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
original_exception=e,
|
||||
request_data={"after": after, "limit": limit},
|
||||
)
|
||||
verbose_proxy_logger.error(
|
||||
"litellm.proxy.proxy_server.retrieve_batch(): Exception occured - {}".format(
|
||||
str(e)
|
||||
)
|
||||
)
|
||||
raise handle_exception_on_proxy(e)
|
||||
|
||||
|
||||
######################################################################
|
||||
|
||||
# END OF /v1/batches Endpoints Implementation
|
||||
|
||||
######################################################################
|
|
@ -2,6 +2,12 @@
|
|||
Contains utils used by OpenAI compatible endpoints
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import Request
|
||||
|
||||
from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
|
||||
|
||||
|
||||
def remove_sensitive_info_from_deployment(deployment_dict: dict) -> dict:
|
||||
"""
|
||||
|
@ -19,3 +25,15 @@ def remove_sensitive_info_from_deployment(deployment_dict: dict) -> dict:
|
|||
deployment_dict["litellm_params"].pop("aws_secret_access_key", None)
|
||||
|
||||
return deployment_dict
|
||||
|
||||
|
||||
async def get_custom_llm_provider_from_request_body(request: Request) -> Optional[str]:
|
||||
"""
|
||||
Get the `custom_llm_provider` from the request body
|
||||
|
||||
Safely reads the request body
|
||||
"""
|
||||
request_body: dict = await _read_request_body(request=request) or {}
|
||||
if "custom_llm_provider" in request_body:
|
||||
return request_body["custom_llm_provider"]
|
||||
return None
|
||||
|
|
|
@ -43,18 +43,18 @@ litellm_settings:
|
|||
# For /fine_tuning/jobs endpoints
|
||||
finetune_settings:
|
||||
- custom_llm_provider: azure
|
||||
api_base: https://exampleopenaiendpoint-production.up.railway.app
|
||||
api_key: fake-key
|
||||
api_version: "2023-03-15-preview"
|
||||
api_base: os.environ/AZURE_API_BASE
|
||||
api_key: os.environ/AZURE_API_KEY
|
||||
api_version: "2024-05-01-preview"
|
||||
- custom_llm_provider: openai
|
||||
api_key: os.environ/OPENAI_API_KEY
|
||||
|
||||
# for /files endpoints
|
||||
files_settings:
|
||||
- custom_llm_provider: azure
|
||||
api_base: https://exampleopenaiendpoint-production.up.railway.app
|
||||
api_key: fake-key
|
||||
api_version: "2023-03-15-preview"
|
||||
api_base: os.environ/AZURE_API_BASE
|
||||
api_key: os.environ/AZURE_API_KEY
|
||||
api_version: "2024-05-01-preview"
|
||||
- custom_llm_provider: openai
|
||||
api_key: os.environ/OPENAI_API_KEY
|
||||
|
||||
|
|
|
@ -27,6 +27,9 @@ from litellm import CreateFileRequest, get_secret_str
|
|||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy._types import *
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.proxy.common_utils.openai_endpoint_utils import (
|
||||
get_custom_llm_provider_from_request_body,
|
||||
)
|
||||
from litellm.router import Router
|
||||
|
||||
router = APIRouter()
|
||||
|
@ -151,11 +154,14 @@ async def create_file(
|
|||
|
||||
data: Dict = {}
|
||||
try:
|
||||
if provider is not None:
|
||||
custom_llm_provider = provider
|
||||
# Use orjson to parse JSON data, orjson speeds up requests significantly
|
||||
# Read the file content
|
||||
file_content = await file.read()
|
||||
custom_llm_provider = (
|
||||
provider
|
||||
or await get_custom_llm_provider_from_request_body(request=request)
|
||||
or "openai"
|
||||
)
|
||||
# Prepare the data for forwarding
|
||||
|
||||
data = {"purpose": purpose}
|
||||
|
@ -322,10 +328,13 @@ async def get_file_content(
|
|||
proxy_config=proxy_config,
|
||||
)
|
||||
|
||||
if provider is None:
|
||||
provider = "openai"
|
||||
custom_llm_provider = (
|
||||
provider
|
||||
or await get_custom_llm_provider_from_request_body(request=request)
|
||||
or "openai"
|
||||
)
|
||||
response = await litellm.afile_content(
|
||||
custom_llm_provider=provider, file_id=file_id, **data # type: ignore
|
||||
custom_llm_provider=custom_llm_provider, file_id=file_id, **data # type: ignore
|
||||
)
|
||||
|
||||
### ALERTING ###
|
||||
|
@ -436,7 +445,11 @@ async def get_file(
|
|||
|
||||
data: Dict = {}
|
||||
try:
|
||||
|
||||
custom_llm_provider = (
|
||||
provider
|
||||
or await get_custom_llm_provider_from_request_body(request=request)
|
||||
or "openai"
|
||||
)
|
||||
# Include original request and headers in the data
|
||||
data = await add_litellm_data_to_request(
|
||||
data=data,
|
||||
|
@ -446,11 +459,8 @@ async def get_file(
|
|||
version=version,
|
||||
proxy_config=proxy_config,
|
||||
)
|
||||
|
||||
if provider is None: # default to openai
|
||||
provider = "openai"
|
||||
response = await litellm.afile_retrieve(
|
||||
custom_llm_provider=provider, file_id=file_id, **data # type: ignore
|
||||
custom_llm_provider=custom_llm_provider, file_id=file_id, **data # type: ignore
|
||||
)
|
||||
|
||||
### ALERTING ###
|
||||
|
@ -552,7 +562,11 @@ async def delete_file(
|
|||
|
||||
data: Dict = {}
|
||||
try:
|
||||
|
||||
custom_llm_provider = (
|
||||
provider
|
||||
or await get_custom_llm_provider_from_request_body(request=request)
|
||||
or "openai"
|
||||
)
|
||||
# Include original request and headers in the data
|
||||
data = await add_litellm_data_to_request(
|
||||
data=data,
|
||||
|
@ -563,10 +577,8 @@ async def delete_file(
|
|||
proxy_config=proxy_config,
|
||||
)
|
||||
|
||||
if provider is None: # default to openai
|
||||
provider = "openai"
|
||||
response = await litellm.afile_delete(
|
||||
custom_llm_provider=provider, file_id=file_id, **data # type: ignore
|
||||
custom_llm_provider=custom_llm_provider, file_id=file_id, **data # type: ignore
|
||||
)
|
||||
|
||||
### ALERTING ###
|
||||
|
@ -667,7 +679,11 @@ async def list_files(
|
|||
|
||||
data: Dict = {}
|
||||
try:
|
||||
|
||||
custom_llm_provider = (
|
||||
provider
|
||||
or await get_custom_llm_provider_from_request_body(request=request)
|
||||
or "openai"
|
||||
)
|
||||
# Include original request and headers in the data
|
||||
data = await add_litellm_data_to_request(
|
||||
data=data,
|
||||
|
@ -678,10 +694,8 @@ async def list_files(
|
|||
proxy_config=proxy_config,
|
||||
)
|
||||
|
||||
if provider is None:
|
||||
provider = "openai"
|
||||
response = await litellm.afile_list(
|
||||
custom_llm_provider=provider, purpose=purpose, **data # type: ignore
|
||||
custom_llm_provider=custom_llm_provider, purpose=purpose, **data # type: ignore
|
||||
)
|
||||
|
||||
### ALERTING ###
|
||||
|
|
|
@ -12,7 +12,21 @@ model_list:
|
|||
model: bedrock/*
|
||||
|
||||
|
||||
# for /files endpoints
|
||||
# For /fine_tuning/jobs endpoints
|
||||
finetune_settings:
|
||||
- custom_llm_provider: azure
|
||||
api_base: os.environ/AZURE_BATCHES_API_BASE
|
||||
api_key: os.environ/AZURE_BATCHES_API_KEY
|
||||
api_version: "2024-05-01-preview"
|
||||
- custom_llm_provider: openai
|
||||
api_key: os.environ/OPENAI_API_KEY
|
||||
|
||||
# for /files endpoints
|
||||
files_settings:
|
||||
- custom_llm_provider: azure
|
||||
api_base: os.environ/AZURE_BATCHES_API_BASE
|
||||
api_key: os.environ/AZURE_BATCHES_API_KEY
|
||||
api_version: "2024-05-01-preview"
|
||||
- custom_llm_provider: openai
|
||||
api_key: os.environ/OPENAI_API_KEY
|
|
@ -104,13 +104,7 @@ def generate_feedback_box():
|
|||
from collections import defaultdict
|
||||
|
||||
import litellm
|
||||
from litellm import (
|
||||
CancelBatchRequest,
|
||||
CreateBatchRequest,
|
||||
ListBatchRequest,
|
||||
RetrieveBatchRequest,
|
||||
Router,
|
||||
)
|
||||
from litellm import Router
|
||||
from litellm._logging import verbose_proxy_logger, verbose_router_logger
|
||||
from litellm.caching.caching import DualCache, RedisCache
|
||||
from litellm.exceptions import RejectedRequestError
|
||||
|
@ -137,6 +131,7 @@ from litellm.proxy.auth.user_api_key_auth import (
|
|||
user_api_key_auth,
|
||||
user_api_key_auth_websocket,
|
||||
)
|
||||
from litellm.proxy.batches_endpoints.endpoints import router as batches_router
|
||||
|
||||
## Import All Misc routes here ##
|
||||
from litellm.proxy.caching_routes import router as caching_router
|
||||
|
@ -208,7 +203,6 @@ from litellm.proxy.management_endpoints.team_endpoints import router as team_rou
|
|||
from litellm.proxy.management_endpoints.team_endpoints import update_team
|
||||
from litellm.proxy.management_endpoints.ui_sso import router as ui_sso_router
|
||||
from litellm.proxy.management_helpers.audit_logs import create_audit_log_for_update
|
||||
from litellm.proxy.openai_files_endpoints.files_endpoints import is_known_model
|
||||
from litellm.proxy.openai_files_endpoints.files_endpoints import (
|
||||
router as openai_files_router,
|
||||
)
|
||||
|
@ -5095,377 +5089,6 @@ async def run_thread(
|
|||
)
|
||||
|
||||
|
||||
######################################################################
|
||||
|
||||
# /v1/batches Endpoints
|
||||
|
||||
|
||||
######################################################################
|
||||
@router.post(
|
||||
"/{provider}/v1/batches",
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
tags=["batch"],
|
||||
)
|
||||
@router.post(
|
||||
"/v1/batches",
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
tags=["batch"],
|
||||
)
|
||||
@router.post(
|
||||
"/batches",
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
tags=["batch"],
|
||||
)
|
||||
async def create_batch(
|
||||
request: Request,
|
||||
fastapi_response: Response,
|
||||
provider: Optional[str] = None,
|
||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
):
|
||||
"""
|
||||
Create large batches of API requests for asynchronous processing.
|
||||
This is the equivalent of POST https://api.openai.com/v1/batch
|
||||
Supports Identical Params as: https://platform.openai.com/docs/api-reference/batch
|
||||
|
||||
Example Curl
|
||||
```
|
||||
curl http://localhost:4000/v1/batches \
|
||||
-H "Authorization: Bearer sk-1234" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"input_file_id": "file-abc123",
|
||||
"endpoint": "/v1/chat/completions",
|
||||
"completion_window": "24h"
|
||||
}'
|
||||
```
|
||||
"""
|
||||
global proxy_logging_obj
|
||||
data: Dict = {}
|
||||
|
||||
try:
|
||||
body = await request.body()
|
||||
body_str = body.decode()
|
||||
try:
|
||||
data = ast.literal_eval(body_str)
|
||||
except Exception:
|
||||
data = json.loads(body_str)
|
||||
|
||||
verbose_proxy_logger.debug(
|
||||
"Request received by LiteLLM:\n{}".format(json.dumps(data, indent=4)),
|
||||
)
|
||||
|
||||
# Include original request and headers in the data
|
||||
data = await add_litellm_data_to_request(
|
||||
data=data,
|
||||
request=request,
|
||||
general_settings=general_settings,
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
version=version,
|
||||
proxy_config=proxy_config,
|
||||
)
|
||||
|
||||
## check if model is a loadbalanced model
|
||||
router_model: Optional[str] = None
|
||||
is_router_model = False
|
||||
if litellm.enable_loadbalancing_on_batch_endpoints is True:
|
||||
router_model = data.get("model", None)
|
||||
is_router_model = is_known_model(model=router_model, llm_router=llm_router)
|
||||
|
||||
_create_batch_data = CreateBatchRequest(**data)
|
||||
custom_llm_provider = provider or _create_batch_data.pop("custom_llm_provider", None) # type: ignore
|
||||
|
||||
if (
|
||||
litellm.enable_loadbalancing_on_batch_endpoints is True
|
||||
and is_router_model
|
||||
and router_model is not None
|
||||
):
|
||||
if llm_router is None:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail={
|
||||
"error": "LLM Router not initialized. Ensure models added to proxy."
|
||||
},
|
||||
)
|
||||
|
||||
response = await llm_router.acreate_batch(**_create_batch_data) # type: ignore
|
||||
else:
|
||||
if custom_llm_provider is None:
|
||||
custom_llm_provider = "openai"
|
||||
response = await litellm.acreate_batch(
|
||||
custom_llm_provider=custom_llm_provider, **_create_batch_data # type: ignore
|
||||
)
|
||||
|
||||
### ALERTING ###
|
||||
asyncio.create_task(
|
||||
proxy_logging_obj.update_request_status(
|
||||
litellm_call_id=data.get("litellm_call_id", ""), status="success"
|
||||
)
|
||||
)
|
||||
|
||||
### RESPONSE HEADERS ###
|
||||
hidden_params = getattr(response, "_hidden_params", {}) or {}
|
||||
model_id = hidden_params.get("model_id", None) or ""
|
||||
cache_key = hidden_params.get("cache_key", None) or ""
|
||||
api_base = hidden_params.get("api_base", None) or ""
|
||||
|
||||
fastapi_response.headers.update(
|
||||
get_custom_headers(
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
model_id=model_id,
|
||||
cache_key=cache_key,
|
||||
api_base=api_base,
|
||||
version=version,
|
||||
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
|
||||
request_data=data,
|
||||
)
|
||||
)
|
||||
|
||||
return response
|
||||
except Exception as e:
|
||||
await proxy_logging_obj.post_call_failure_hook(
|
||||
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
|
||||
)
|
||||
verbose_proxy_logger.exception(
|
||||
"litellm.proxy.proxy_server.create_batch(): Exception occured - {}".format(
|
||||
str(e)
|
||||
)
|
||||
)
|
||||
verbose_proxy_logger.debug(traceback.format_exc())
|
||||
if isinstance(e, HTTPException):
|
||||
raise ProxyException(
|
||||
message=getattr(e, "message", str(e.detail)),
|
||||
type=getattr(e, "type", "None"),
|
||||
param=getattr(e, "param", "None"),
|
||||
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
|
||||
)
|
||||
else:
|
||||
error_msg = f"{str(e)}"
|
||||
raise ProxyException(
|
||||
message=getattr(e, "message", error_msg),
|
||||
type=getattr(e, "type", "None"),
|
||||
param=getattr(e, "param", "None"),
|
||||
code=getattr(e, "status_code", 500),
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/{provider}/v1/batches/{batch_id:path}",
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
tags=["batch"],
|
||||
)
|
||||
@router.get(
|
||||
"/v1/batches/{batch_id:path}",
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
tags=["batch"],
|
||||
)
|
||||
@router.get(
|
||||
"/batches/{batch_id:path}",
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
tags=["batch"],
|
||||
)
|
||||
async def retrieve_batch(
|
||||
request: Request,
|
||||
fastapi_response: Response,
|
||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
provider: Optional[str] = None,
|
||||
batch_id: str = Path(
|
||||
title="Batch ID to retrieve", description="The ID of the batch to retrieve"
|
||||
),
|
||||
):
|
||||
"""
|
||||
Retrieves a batch.
|
||||
This is the equivalent of GET https://api.openai.com/v1/batches/{batch_id}
|
||||
Supports Identical Params as: https://platform.openai.com/docs/api-reference/batch/retrieve
|
||||
|
||||
Example Curl
|
||||
```
|
||||
curl http://localhost:4000/v1/batches/batch_abc123 \
|
||||
-H "Authorization: Bearer sk-1234" \
|
||||
-H "Content-Type: application/json" \
|
||||
|
||||
```
|
||||
"""
|
||||
global proxy_logging_obj
|
||||
data: Dict = {}
|
||||
try:
|
||||
## check if model is a loadbalanced model
|
||||
|
||||
_retrieve_batch_request = RetrieveBatchRequest(
|
||||
batch_id=batch_id,
|
||||
)
|
||||
|
||||
if litellm.enable_loadbalancing_on_batch_endpoints is True:
|
||||
if llm_router is None:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail={
|
||||
"error": "LLM Router not initialized. Ensure models added to proxy."
|
||||
},
|
||||
)
|
||||
|
||||
response = await llm_router.aretrieve_batch(**_retrieve_batch_request) # type: ignore
|
||||
else:
|
||||
if provider is None:
|
||||
provider = "openai"
|
||||
response = await litellm.aretrieve_batch(
|
||||
custom_llm_provider=provider, **_retrieve_batch_request # type: ignore
|
||||
)
|
||||
|
||||
### ALERTING ###
|
||||
asyncio.create_task(
|
||||
proxy_logging_obj.update_request_status(
|
||||
litellm_call_id=data.get("litellm_call_id", ""), status="success"
|
||||
)
|
||||
)
|
||||
|
||||
### RESPONSE HEADERS ###
|
||||
hidden_params = getattr(response, "_hidden_params", {}) or {}
|
||||
model_id = hidden_params.get("model_id", None) or ""
|
||||
cache_key = hidden_params.get("cache_key", None) or ""
|
||||
api_base = hidden_params.get("api_base", None) or ""
|
||||
|
||||
fastapi_response.headers.update(
|
||||
get_custom_headers(
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
model_id=model_id,
|
||||
cache_key=cache_key,
|
||||
api_base=api_base,
|
||||
version=version,
|
||||
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
|
||||
request_data=data,
|
||||
)
|
||||
)
|
||||
|
||||
return response
|
||||
except Exception as e:
|
||||
await proxy_logging_obj.post_call_failure_hook(
|
||||
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
|
||||
)
|
||||
verbose_proxy_logger.exception(
|
||||
"litellm.proxy.proxy_server.retrieve_batch(): Exception occured - {}".format(
|
||||
str(e)
|
||||
)
|
||||
)
|
||||
verbose_proxy_logger.debug(traceback.format_exc())
|
||||
if isinstance(e, HTTPException):
|
||||
raise ProxyException(
|
||||
message=getattr(e, "message", str(e.detail)),
|
||||
type=getattr(e, "type", "None"),
|
||||
param=getattr(e, "param", "None"),
|
||||
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
|
||||
)
|
||||
else:
|
||||
traceback.format_exc()
|
||||
error_msg = f"{str(e)}"
|
||||
raise ProxyException(
|
||||
message=getattr(e, "message", error_msg),
|
||||
type=getattr(e, "type", "None"),
|
||||
param=getattr(e, "param", "None"),
|
||||
code=getattr(e, "status_code", 500),
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/{provider}/v1/batches",
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
tags=["batch"],
|
||||
)
|
||||
@router.get(
|
||||
"/v1/batches",
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
tags=["batch"],
|
||||
)
|
||||
@router.get(
|
||||
"/batches",
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
tags=["batch"],
|
||||
)
|
||||
async def list_batches(
|
||||
fastapi_response: Response,
|
||||
provider: Optional[str] = None,
|
||||
limit: Optional[int] = None,
|
||||
after: Optional[str] = None,
|
||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
):
|
||||
"""
|
||||
Lists
|
||||
This is the equivalent of GET https://api.openai.com/v1/batches/
|
||||
Supports Identical Params as: https://platform.openai.com/docs/api-reference/batch/list
|
||||
|
||||
Example Curl
|
||||
```
|
||||
curl http://localhost:4000/v1/batches?limit=2 \
|
||||
-H "Authorization: Bearer sk-1234" \
|
||||
-H "Content-Type: application/json" \
|
||||
|
||||
```
|
||||
"""
|
||||
global proxy_logging_obj
|
||||
verbose_proxy_logger.debug("GET /v1/batches after={} limit={}".format(after, limit))
|
||||
try:
|
||||
if provider is None:
|
||||
provider = "openai"
|
||||
response = await litellm.alist_batches(
|
||||
custom_llm_provider=provider, # type: ignore
|
||||
after=after,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
### RESPONSE HEADERS ###
|
||||
hidden_params = getattr(response, "_hidden_params", {}) or {}
|
||||
model_id = hidden_params.get("model_id", None) or ""
|
||||
cache_key = hidden_params.get("cache_key", None) or ""
|
||||
api_base = hidden_params.get("api_base", None) or ""
|
||||
|
||||
fastapi_response.headers.update(
|
||||
get_custom_headers(
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
model_id=model_id,
|
||||
cache_key=cache_key,
|
||||
api_base=api_base,
|
||||
version=version,
|
||||
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
|
||||
)
|
||||
)
|
||||
|
||||
return response
|
||||
except Exception as e:
|
||||
await proxy_logging_obj.post_call_failure_hook(
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
original_exception=e,
|
||||
request_data={"after": after, "limit": limit},
|
||||
)
|
||||
verbose_proxy_logger.error(
|
||||
"litellm.proxy.proxy_server.retrieve_batch(): Exception occured - {}".format(
|
||||
str(e)
|
||||
)
|
||||
)
|
||||
verbose_proxy_logger.debug(traceback.format_exc())
|
||||
if isinstance(e, HTTPException):
|
||||
raise ProxyException(
|
||||
message=getattr(e, "message", str(e.detail)),
|
||||
type=getattr(e, "type", "None"),
|
||||
param=getattr(e, "param", "None"),
|
||||
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
|
||||
)
|
||||
else:
|
||||
traceback.format_exc()
|
||||
error_msg = f"{str(e)}"
|
||||
raise ProxyException(
|
||||
message=getattr(e, "message", error_msg),
|
||||
type=getattr(e, "type", "None"),
|
||||
param=getattr(e, "param", "None"),
|
||||
code=getattr(e, "status_code", 500),
|
||||
)
|
||||
|
||||
|
||||
######################################################################
|
||||
|
||||
# END OF /v1/batches Endpoints Implementation
|
||||
|
||||
######################################################################
|
||||
|
||||
|
||||
@router.post(
|
||||
"/v1/moderations",
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
|
@ -9203,6 +8826,7 @@ def cleanup_router_config_variables():
|
|||
|
||||
|
||||
app.include_router(router)
|
||||
app.include_router(batches_router)
|
||||
app.include_router(rerank_router)
|
||||
app.include_router(fine_tuning_router)
|
||||
app.include_router(vertex_router)
|
||||
|
|
|
@ -127,8 +127,8 @@ litellm_settings:
|
|||
# For /fine_tuning/jobs endpoints
|
||||
finetune_settings:
|
||||
- custom_llm_provider: azure
|
||||
api_base: https://exampleopenaiendpoint-production.up.railway.app
|
||||
api_key: fake-key
|
||||
api_base: os.environ/AZURE_API_BASE
|
||||
api_key: os.environ/AZURE_API_KEY
|
||||
api_version: "2023-03-15-preview"
|
||||
- custom_llm_provider: openai
|
||||
api_key: os.environ/OPENAI_API_KEY
|
||||
|
@ -136,8 +136,8 @@ finetune_settings:
|
|||
# for /files endpoints
|
||||
files_settings:
|
||||
- custom_llm_provider: azure
|
||||
api_base: https://exampleopenaiendpoint-production.up.railway.app
|
||||
api_key: fake-key
|
||||
api_base: os.environ/AZURE_API_BASE
|
||||
api_key: os.environ/AZURE_API_KEY
|
||||
api_version: "2023-03-15-preview"
|
||||
- custom_llm_provider: openai
|
||||
api_key: os.environ/OPENAI_API_KEY
|
||||
|
|
|
@ -78,9 +78,6 @@ async def test_create_batch(provider):
|
|||
2. Create Batch Request
|
||||
3. Retrieve the specific batch
|
||||
"""
|
||||
custom_logger = TestCustomLogger()
|
||||
litellm.callbacks = [custom_logger, "datadog"]
|
||||
|
||||
if provider == "azure":
|
||||
# Don't have anymore Azure Quota
|
||||
return
|
||||
|
@ -112,12 +109,6 @@ async def test_create_batch(provider):
|
|||
print("response from litellm.create_batch=", create_batch_response)
|
||||
await asyncio.sleep(6)
|
||||
|
||||
# Assert that the create batch event is logged on CustomLogger
|
||||
assert custom_logger.standard_logging_object is not None
|
||||
print(
|
||||
"standard_logging_object=",
|
||||
json.dumps(custom_logger.standard_logging_object, indent=4),
|
||||
)
|
||||
assert (
|
||||
create_batch_response.id is not None
|
||||
), f"Failed to create batch, expected a non null batch_id but got {create_batch_response.id}"
|
||||
|
@ -170,7 +161,7 @@ class TestCustomLogger(CustomLogger):
|
|||
self.standard_logging_object = kwargs["standard_logging_object"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("provider", ["openai"]) # "azure"
|
||||
@pytest.mark.parametrize("provider", ["azure", "openai"]) # "azure"
|
||||
@pytest.mark.asyncio()
|
||||
@pytest.mark.flaky(retries=3, delay=1)
|
||||
async def test_async_create_batch(provider):
|
||||
|
@ -180,9 +171,6 @@ async def test_async_create_batch(provider):
|
|||
3. Retrieve the specific batch
|
||||
"""
|
||||
print("Testing async create batch")
|
||||
if provider == "azure":
|
||||
# Don't have anymore Azure Quota
|
||||
return
|
||||
|
||||
custom_logger = TestCustomLogger()
|
||||
litellm.callbacks = [custom_logger, "datadog"]
|
||||
|
@ -276,6 +264,24 @@ async def test_async_create_batch(provider):
|
|||
file.write(file_content.content)
|
||||
|
||||
|
||||
def cleanup_azure_files():
|
||||
"""
|
||||
Delete all files for Azure - helper for when we run out of Azure Files Quota
|
||||
"""
|
||||
azure_files = litellm.file_list(
|
||||
custom_llm_provider="azure",
|
||||
)
|
||||
print("azure_files=", azure_files)
|
||||
for _file in azure_files:
|
||||
print("deleting file=", _file)
|
||||
delete_file_response = litellm.file_delete(
|
||||
file_id=_file.id,
|
||||
custom_llm_provider="azure",
|
||||
)
|
||||
print("delete_file_response=", delete_file_response)
|
||||
assert delete_file_response.id == _file.id
|
||||
|
||||
|
||||
def test_retrieve_batch():
|
||||
pass
|
||||
|
||||
|
|
1
tests/openai_misc_endpoints_tests/input_azure.jsonl
Normal file
1
tests/openai_misc_endpoints_tests/input_azure.jsonl
Normal file
|
@ -0,0 +1 @@
|
|||
{"custom_id": "ae006110bb364606||/workspace/saved_models/meta-llama/Meta-Llama-3.1-8B-Instruct", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-4o-mini", "temperature": 0, "max_tokens": 1024, "response_format": {"type": "json_object"}, "messages": [{"role": "user", "content": "# Instruction \n\nYou are an expert evaluator. Your task is to evaluate the quality of the responses generated by AI models. \nWe will provide you with the user query and an AI-generated responses.\nYo must respond in json"}]}}
|
10
tests/openai_misc_endpoints_tests/openai_fine_tuning.jsonl
Normal file
10
tests/openai_misc_endpoints_tests/openai_fine_tuning.jsonl
Normal file
|
@ -0,0 +1,10 @@
|
|||
{"messages": [{"role": "system", "content": "Clippy is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "Who discovered Antarctica?"}, {"role": "assistant", "content": "Some chaps named Fabian Gottlieb von Bellingshausen and Mikhail Lazarev, as if they don't teach that in every school!"}]}
|
||||
{"messages": [{"role": "system", "content": "Clippy is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "What is the biggest ocean?"}, {"role": "assistant", "content": "The Pacific Ocean. It's not like it's a small pond or anything."}]}
|
||||
{"messages": [{"role": "system", "content": "Clippy is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "What is the largest planet?"}, {"role": "assistant", "content": "It's called Jupiter, you may have heard of it...or not."}]}
|
||||
{"messages": [{"role": "system", "content": "Clippy is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "Who painted the Mona Lisa?"}, {"role": "assistant", "content": "Oh, some obscure artist named Leonardo da Vinci. Not that he is famous or anything."}]}
|
||||
{"messages": [{"role": "system", "content": "Clippy is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "What's the boiling point of water?"}, {"role": "assistant", "content": "Exciting fact of the day: it's 100 degrees Celsius. Life-changing, isn't it?"}]}
|
||||
{"messages": [{"role": "system", "content": "Clippy is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "Who wrote 'The Odyssey'?"}, {"role": "assistant", "content": "Oh, just some poet named Homer. Nothing special, really."}]}
|
||||
{"messages": [{"role": "system", "content": "Clippy is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "What is the largest mammal?"}, {"role": "assistant", "content": "Blue Whale. Quite a big fish isn't it? Oh wait, it's not a fish!"}]}
|
||||
{"messages": [{"role": "system", "content": "Clippy is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "What's the highest mountain peak?"}, {"role": "assistant", "content": "Everest. Brace yourself, it's not a walk in the park."}]}
|
||||
{"messages": [{"role": "system", "content": "Clippy is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "Who invented the telephone?"}, {"role": "assistant", "content": "A guy named Alexander Graham Bell. Bet you didn't guess that one!"}]}
|
||||
{"messages": [{"role": "system", "content": "Clippy is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "What is the distance from Earth to the Sun?"}, {"role": "assistant", "content": "About 93 million miles. Just a quick drive, really."}]}
|
|
@ -1 +1 @@
|
|||
{"id": "batch_req_6765ed82629c8190b70c10c183b5e994", "custom_id": "ae006110bb364606||/workspace/saved_models/meta-llama/Meta-Llama-3.1-8B-Instruct", "response": {"status_code": 200, "request_id": "36bbc935dec50094e84af1db52cf2cc7", "body": {"id": "chatcmpl-AgfdQmdwJQ0NrQManGI8ecwMvF0ZC", "object": "chat.completion", "created": 1734733184, "model": "gpt-4o-2024-05-13", "choices": [{"index": 0, "message": {"role": "assistant", "content": "{\n \"user_query\": \"What are the benefits of using renewable energy sources?\",\n \"ai_response\": \"Renewable energy sources, such as solar, wind, and hydroelectric power, offer numerous benefits. They are sustainable and can be replenished naturally, reducing the reliance on finite fossil fuels. Additionally, renewable energy sources produce little to no greenhouse gas emissions, helping to combat climate change and reduce air pollution. They also create jobs in the renewable energy sector and can lead to energy independence for countries that invest in their development. Furthermore, renewable energy technologies often have lower operating costs once established, providing long-term economic benefits.\"\n}", "refusal": null}, "logprobs": null, "finish_reason": "stop"}], "usage": {"prompt_tokens": 51, "completion_tokens": 128, "total_tokens": 179, "prompt_tokens_details": {"cached_tokens": 0, "audio_tokens": 0}, "completion_tokens_details": {"reasoning_tokens": 0, "audio_tokens": 0, "accepted_prediction_tokens": 0, "rejected_prediction_tokens": 0}}, "system_fingerprint": "fp_20cb129c3a"}}, "error": null}
|
||||
{"id": "batch_req_676aee0162f88190b707f57aa28ee2e6", "custom_id": "ae006110bb364606||/workspace/saved_models/meta-llama/Meta-Llama-3.1-8B-Instruct", "response": {"status_code": 200, "request_id": "d5d91547aef5fb21997a8d67fe4785fa", "body": {"id": "chatcmpl-Ai2uZRrTVwowbwysu9TorTUYf3qFq", "object": "chat.completion", "created": 1735060987, "model": "gpt-4o-2024-05-13", "choices": [{"index": 0, "message": {"role": "assistant", "content": "{\n \"user_query\": \"What are the benefits of a plant-based diet?\",\n \"ai_response\": \"A plant-based diet offers numerous benefits, including improved heart health, weight management, and a lower risk of chronic diseases. It is rich in essential nutrients, fiber, and antioxidants, which can help reduce inflammation and improve overall well-being. Additionally, a plant-based diet is environmentally sustainable and can contribute to reduced greenhouse gas emissions and conservation of natural resources.\"\n}", "refusal": null}, "logprobs": null, "finish_reason": "stop"}], "usage": {"prompt_tokens": 51, "completion_tokens": 95, "total_tokens": 146, "prompt_tokens_details": {"cached_tokens": 0, "audio_tokens": 0}, "completion_tokens_details": {"reasoning_tokens": 0, "audio_tokens": 0, "accepted_prediction_tokens": 0, "rejected_prediction_tokens": 0}}, "system_fingerprint": "fp_0325da12f2"}}, "error": null}
|
||||
|
|
1
tests/openai_misc_endpoints_tests/out_azure.jsonl
Normal file
1
tests/openai_misc_endpoints_tests/out_azure.jsonl
Normal file
|
@ -0,0 +1 @@
|
|||
{"custom_id": "ae006110bb364606||/workspace/saved_models/meta-llama/Meta-Llama-3.1-8B-Instruct", "response": {"body": {"choices": [{"content_filter_results": {"hate": {"filtered": false, "severity": "safe"}, "self_harm": {"filtered": false, "severity": "safe"}, "sexual": {"filtered": false, "severity": "safe"}, "violence": {"filtered": false, "severity": "safe"}}, "finish_reason": "stop", "index": 0, "logprobs": null, "message": {"content": "{\"evaluation\": \"Awaiting user query and AI-generated response for evaluation.\"}", "refusal": null, "role": "assistant"}}], "created": 1735059173, "id": "chatcmpl-Ai2RJ6yVuMIoRxTpPxaPPKrZ0Rg7s", "model": "gpt-4o-mini-2024-07-18", "object": "chat.completion", "prompt_filter_results": [{"prompt_index": 0, "content_filter_results": {"hate": {"filtered": false, "severity": "safe"}, "jailbreak": {"filtered": false, "detected": false}, "self_harm": {"filtered": false, "severity": "safe"}, "sexual": {"filtered": false, "severity": "safe"}, "violence": {"filtered": false, "severity": "safe"}}}], "system_fingerprint": "fp_5154047bf2", "usage": {"completion_tokens": 16, "completion_tokens_details": {"accepted_prediction_tokens": 0, "audio_tokens": 0, "reasoning_tokens": 0, "rejected_prediction_tokens": 0}, "prompt_tokens": 51, "prompt_tokens_details": {"audio_tokens": 0, "cached_tokens": 0}, "total_tokens": 67}}, "request_id": "346b61a8-b9ca-432f-a081-b0caeddbc31d", "status_code": 200}, "error": null}
|
|
@ -95,10 +95,16 @@ from openai import OpenAI
|
|||
client = OpenAI(base_url=BASE_URL, api_key=API_KEY)
|
||||
|
||||
|
||||
def create_batch_oai_sdk(filepath) -> str:
|
||||
batch_input_file = client.files.create(file=open(filepath, "rb"), purpose="batch")
|
||||
def create_batch_oai_sdk(filepath: str, custom_llm_provider: str) -> str:
|
||||
batch_input_file = client.files.create(
|
||||
file=open(filepath, "rb"),
|
||||
purpose="batch",
|
||||
extra_body={"custom_llm_provider": custom_llm_provider},
|
||||
)
|
||||
batch_input_file_id = batch_input_file.id
|
||||
|
||||
print("waiting for file to be processed......")
|
||||
time.sleep(5)
|
||||
rq = client.batches.create(
|
||||
input_file_id=batch_input_file_id,
|
||||
endpoint="/v1/chat/completions",
|
||||
|
@ -106,15 +112,18 @@ def create_batch_oai_sdk(filepath) -> str:
|
|||
metadata={
|
||||
"description": filepath,
|
||||
},
|
||||
extra_body={"custom_llm_provider": custom_llm_provider},
|
||||
)
|
||||
|
||||
print(f"Batch submitted. ID: {rq.id}")
|
||||
return rq.id
|
||||
|
||||
|
||||
def await_batch_completion(batch_id: str):
|
||||
def await_batch_completion(batch_id: str, custom_llm_provider: str):
|
||||
while True:
|
||||
batch = client.batches.retrieve(batch_id)
|
||||
batch = client.batches.retrieve(
|
||||
batch_id, extra_body={"custom_llm_provider": custom_llm_provider}
|
||||
)
|
||||
if batch.status == "completed":
|
||||
print(f"Batch {batch_id} completed.")
|
||||
return
|
||||
|
@ -123,9 +132,16 @@ def await_batch_completion(batch_id: str):
|
|||
time.sleep(10)
|
||||
|
||||
|
||||
def write_content_to_file(batch_id: str, output_path: str) -> str:
|
||||
batch = client.batches.retrieve(batch_id)
|
||||
content = client.files.content(batch.output_file_id)
|
||||
def write_content_to_file(
|
||||
batch_id: str, output_path: str, custom_llm_provider: str
|
||||
) -> str:
|
||||
batch = client.batches.retrieve(
|
||||
batch_id=batch_id, extra_body={"custom_llm_provider": custom_llm_provider}
|
||||
)
|
||||
content = client.files.content(
|
||||
file_id=batch.output_file_id,
|
||||
extra_body={"custom_llm_provider": custom_llm_provider},
|
||||
)
|
||||
print("content from files.content", content.content)
|
||||
content.write_to_file(output_path)
|
||||
|
||||
|
@ -145,20 +161,47 @@ def read_jsonl(filepath: str):
|
|||
print(custom_id)
|
||||
|
||||
|
||||
def test_e2e_batches_files():
|
||||
def get_any_completed_batch_id_azure():
|
||||
print("AZURE getting any completed batch id")
|
||||
list_of_batches = client.batches.list(extra_body={"custom_llm_provider": "azure"})
|
||||
print("list of batches", list_of_batches)
|
||||
for batch in list_of_batches:
|
||||
if batch.status == "completed":
|
||||
return batch.id
|
||||
return None
|
||||
|
||||
|
||||
@pytest.mark.parametrize("custom_llm_provider", ["azure", "openai"])
|
||||
def test_e2e_batches_files(custom_llm_provider):
|
||||
"""
|
||||
[PROD Test] Ensures OpenAI Batches + files work with OpenAI SDK
|
||||
"""
|
||||
input_path = "input.jsonl"
|
||||
output_path = "out.jsonl"
|
||||
input_path = (
|
||||
"input.jsonl" if custom_llm_provider == "openai" else "input_azure.jsonl"
|
||||
)
|
||||
output_path = "out.jsonl" if custom_llm_provider == "openai" else "out_azure.jsonl"
|
||||
|
||||
_current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
input_file_path = os.path.join(_current_dir, input_path)
|
||||
output_file_path = os.path.join(_current_dir, output_path)
|
||||
print("running e2e batches files with custom_llm_provider=", custom_llm_provider)
|
||||
batch_id = create_batch_oai_sdk(
|
||||
filepath=input_file_path, custom_llm_provider=custom_llm_provider
|
||||
)
|
||||
|
||||
batch_id = create_batch_oai_sdk(input_file_path)
|
||||
await_batch_completion(batch_id)
|
||||
write_content_to_file(batch_id, output_file_path)
|
||||
if custom_llm_provider == "azure":
|
||||
# azure takes very long to complete a batch - randomly pick a completed batch
|
||||
batch_id = get_any_completed_batch_id_azure()
|
||||
else:
|
||||
await_batch_completion(
|
||||
batch_id=batch_id, custom_llm_provider=custom_llm_provider
|
||||
)
|
||||
|
||||
write_content_to_file(
|
||||
batch_id=batch_id,
|
||||
output_path=output_file_path,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
)
|
||||
read_jsonl(output_file_path)
|
||||
|
||||
|
||||
|
|
|
@ -1,16 +1,17 @@
|
|||
from openai import AsyncOpenAI
|
||||
import os
|
||||
import pytest
|
||||
import asyncio
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openai_fine_tuning():
|
||||
"""
|
||||
[PROD Test] Ensures logprobs are returned correctly
|
||||
[PROD Test] e2e tests for /fine_tuning/jobs endpoints
|
||||
"""
|
||||
client = AsyncOpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")
|
||||
|
||||
file_name = "openai_batch_completions.jsonl"
|
||||
file_name = "openai_fine_tuning.jsonl"
|
||||
_current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
file_path = os.path.join(_current_dir, file_name)
|
||||
|
||||
|
@ -22,10 +23,12 @@ async def test_openai_fine_tuning():
|
|||
|
||||
print("response from files.create: {}".format(response))
|
||||
|
||||
await asyncio.sleep(5)
|
||||
|
||||
# create fine tuning job
|
||||
|
||||
ft_job = await client.fine_tuning.jobs.create(
|
||||
model="gpt-35-turbo-1106",
|
||||
model="gpt-35-turbo-0613",
|
||||
training_file=response.id,
|
||||
extra_body={"custom_llm_provider": "azure"},
|
||||
)
|
||||
|
@ -33,7 +36,7 @@ async def test_openai_fine_tuning():
|
|||
print("response from ft job={}".format(ft_job))
|
||||
|
||||
# response from example endpoint
|
||||
assert ft_job.id == "ftjob-abc123"
|
||||
assert ft_job.id is not None
|
||||
|
||||
# list all fine tuning jobs
|
||||
list_ft_jobs = await client.fine_tuning.jobs.list(
|
||||
|
@ -44,10 +47,16 @@ async def test_openai_fine_tuning():
|
|||
|
||||
# cancel specific fine tuning job
|
||||
cancel_ft_job = await client.fine_tuning.jobs.cancel(
|
||||
fine_tuning_job_id="123",
|
||||
fine_tuning_job_id=ft_job.id,
|
||||
extra_body={"custom_llm_provider": "azure"},
|
||||
)
|
||||
|
||||
print("response from cancel ft job={}".format(cancel_ft_job))
|
||||
|
||||
assert cancel_ft_job.id is not None
|
||||
|
||||
# delete OG file
|
||||
await client.files.delete(
|
||||
file_id=response.id,
|
||||
extra_body={"custom_llm_provider": "azure"},
|
||||
)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue