(Feat) add `"/v1/batches/{batch_id:path}/cancel" endpoint (#7406)

* use 1 file for azure batches handling

* add cancel_batch endpoint

* add a cancel batch on open ai

* add cancel_batch endpoint

* add cancel batches to test

* remove unused imports

* test_batches_operations

* update test_batches_operations
This commit is contained in:
Ishaan Jaff 2024-12-24 20:23:50 -08:00 committed by GitHub
parent 440009fb32
commit 54cb64d03d
7 changed files with 589 additions and 304 deletions

View file

@ -20,11 +20,16 @@ import httpx
import litellm
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.azure.azure import AzureBatchesAPI
from litellm.llms.azure.batches.handler import AzureBatchesAPI
from litellm.llms.openai.openai import OpenAIBatchesAPI
from litellm.llms.vertex_ai.batches.handler import VertexAIBatchPrediction
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import Batch, CreateBatchRequest, RetrieveBatchRequest
from litellm.types.llms.openai import (
Batch,
CancelBatchRequest,
CreateBatchRequest,
RetrieveBatchRequest,
)
from litellm.types.router import GenericLiteLLMParams
from litellm.utils import client, get_litellm_params, supports_httpx_timeout
@ -582,9 +587,163 @@ def list_batches(
raise e
def cancel_batch():
pass
async def acancel_batch(
batch_id: str,
custom_llm_provider: Literal["openai", "azure"] = "openai",
metadata: Optional[Dict[str, str]] = None,
extra_headers: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, str]] = None,
**kwargs,
) -> Batch:
"""
Async: Cancels a batch.
LiteLLM Equivalent of POST https://api.openai.com/v1/batches/{batch_id}/cancel
"""
try:
loop = asyncio.get_event_loop()
kwargs["acancel_batch"] = True
# Use a partial function to pass your keyword arguments
func = partial(
cancel_batch,
batch_id,
custom_llm_provider,
metadata,
extra_headers,
extra_body,
**kwargs,
)
# Add the context to the function
ctx = contextvars.copy_context()
func_with_context = partial(ctx.run, func)
init_response = await loop.run_in_executor(None, func_with_context)
if asyncio.iscoroutine(init_response):
response = await init_response
else:
response = init_response
return response
except Exception as e:
raise e
async def acancel_batch():
pass
def cancel_batch(
batch_id: str,
custom_llm_provider: Literal["openai", "azure"] = "openai",
metadata: Optional[Dict[str, str]] = None,
extra_headers: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, str]] = None,
**kwargs,
) -> Union[Batch, Coroutine[Any, Any, Batch]]:
"""
Cancels a batch.
LiteLLM Equivalent of POST https://api.openai.com/v1/batches/{batch_id}/cancel
"""
try:
optional_params = GenericLiteLLMParams(**kwargs)
### TIMEOUT LOGIC ###
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
# set timeout for 10 minutes by default
if (
timeout is not None
and isinstance(timeout, httpx.Timeout)
and supports_httpx_timeout(custom_llm_provider) is False
):
read_timeout = timeout.read or 600
timeout = read_timeout # default 10 min timeout
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
timeout = float(timeout) # type: ignore
elif timeout is None:
timeout = 600.0
_cancel_batch_request = CancelBatchRequest(
batch_id=batch_id,
extra_headers=extra_headers,
extra_body=extra_body,
)
_is_async = kwargs.pop("acancel_batch", False) is True
api_base: Optional[str] = None
if custom_llm_provider == "openai":
api_base = (
optional_params.api_base
or litellm.api_base
or os.getenv("OPENAI_API_BASE")
or "https://api.openai.com/v1"
)
organization = (
optional_params.organization
or litellm.organization
or os.getenv("OPENAI_ORGANIZATION", None)
or None
)
api_key = (
optional_params.api_key
or litellm.api_key
or litellm.openai_key
or os.getenv("OPENAI_API_KEY")
)
response = openai_batches_instance.cancel_batch(
_is_async=_is_async,
cancel_batch_data=_cancel_batch_request,
api_base=api_base,
api_key=api_key,
organization=organization,
timeout=timeout,
max_retries=optional_params.max_retries,
)
elif custom_llm_provider == "azure":
api_base = (
optional_params.api_base
or litellm.api_base
or get_secret_str("AZURE_API_BASE")
)
api_version = (
optional_params.api_version
or litellm.api_version
or get_secret_str("AZURE_API_VERSION")
)
api_key = (
optional_params.api_key
or litellm.api_key
or litellm.azure_key
or get_secret_str("AZURE_OPENAI_API_KEY")
or get_secret_str("AZURE_API_KEY")
)
extra_body = optional_params.get("extra_body", {})
if extra_body is not None:
extra_body.pop("azure_ad_token", None)
else:
get_secret_str("AZURE_AD_TOKEN") # type: ignore
response = azure_batches_instance.cancel_batch(
_is_async=_is_async,
api_base=api_base,
api_key=api_key,
api_version=api_version,
timeout=timeout,
max_retries=optional_params.max_retries,
cancel_batch_data=_cancel_batch_request,
)
else:
raise litellm.exceptions.BadRequestError(
message="LiteLLM doesn't support {} for 'cancel_batch'. Only 'openai' and 'azure' are supported.".format(
custom_llm_provider
),
model="n/a",
llm_provider=custom_llm_provider,
response=httpx.Response(
status_code=400,
content="Unsupported provider",
request=httpx.Request(method="cancel_batch", url="https://github.com/BerriAI/litellm"), # type: ignore
),
)
return response
except Exception as e:
raise e

View file

@ -2,7 +2,7 @@ import asyncio
import json
import os
import time
from typing import Any, Callable, Coroutine, List, Literal, Optional, Union
from typing import Any, Callable, List, Literal, Optional, Union
import httpx # type: ignore
from openai import AsyncAzureOpenAI, AzureOpenAI
@ -28,13 +28,7 @@ from litellm.utils import (
modify_url,
)
from ...types.llms.openai import (
Batch,
CancelBatchRequest,
CreateBatchRequest,
HttpxBinaryResponseContent,
RetrieveBatchRequest,
)
from ...types.llms.openai import HttpxBinaryResponseContent
from ..base import BaseLLM
from .common_utils import AzureOpenAIError, process_azure_headers
@ -1613,216 +1607,3 @@ class AzureChatCompletion(BaseLLM):
response["x-ms-region"] = completion.headers["x-ms-region"]
return response
class AzureBatchesAPI(BaseLLM):
"""
Azure methods to support for batches
- create_batch()
- retrieve_batch()
- cancel_batch()
- list_batch()
"""
def __init__(self) -> None:
super().__init__()
def get_azure_openai_client(
self,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
api_version: Optional[str] = None,
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
_is_async: bool = False,
) -> Optional[Union[AzureOpenAI, AsyncAzureOpenAI]]:
received_args = locals()
openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None
if client is None:
data = {}
for k, v in received_args.items():
if k == "self" or k == "client" or k == "_is_async":
pass
elif k == "api_base" and v is not None:
data["azure_endpoint"] = v
elif v is not None:
data[k] = v
if "api_version" not in data:
data["api_version"] = litellm.AZURE_DEFAULT_API_VERSION
if _is_async is True:
openai_client = AsyncAzureOpenAI(**data)
else:
openai_client = AzureOpenAI(**data) # type: ignore
else:
openai_client = client
return openai_client
async def acreate_batch(
self,
create_batch_data: CreateBatchRequest,
azure_client: AsyncAzureOpenAI,
) -> Batch:
response = await azure_client.batches.create(**create_batch_data)
return response
def create_batch(
self,
_is_async: bool,
create_batch_data: CreateBatchRequest,
api_key: Optional[str],
api_base: Optional[str],
api_version: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
) -> Union[Batch, Coroutine[Any, Any, Batch]]:
azure_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
self.get_azure_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
api_version=api_version,
max_retries=max_retries,
client=client,
_is_async=_is_async,
)
)
if azure_client is None:
raise ValueError(
"OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
)
if _is_async is True:
if not isinstance(azure_client, AsyncAzureOpenAI):
raise ValueError(
"OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client."
)
return self.acreate_batch( # type: ignore
create_batch_data=create_batch_data, azure_client=azure_client
)
response = azure_client.batches.create(**create_batch_data)
return response
async def aretrieve_batch(
self,
retrieve_batch_data: RetrieveBatchRequest,
client: AsyncAzureOpenAI,
) -> Batch:
response = await client.batches.retrieve(**retrieve_batch_data)
return response
def retrieve_batch(
self,
_is_async: bool,
retrieve_batch_data: RetrieveBatchRequest,
api_key: Optional[str],
api_base: Optional[str],
api_version: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
client: Optional[AzureOpenAI] = None,
):
azure_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
self.get_azure_openai_client(
api_key=api_key,
api_base=api_base,
api_version=api_version,
timeout=timeout,
max_retries=max_retries,
client=client,
_is_async=_is_async,
)
)
if azure_client is None:
raise ValueError(
"OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
)
if _is_async is True:
if not isinstance(azure_client, AsyncAzureOpenAI):
raise ValueError(
"OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client."
)
return self.aretrieve_batch( # type: ignore
retrieve_batch_data=retrieve_batch_data, client=azure_client
)
response = azure_client.batches.retrieve(**retrieve_batch_data)
return response
def cancel_batch(
self,
_is_async: bool,
cancel_batch_data: CancelBatchRequest,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[AzureOpenAI] = None,
):
azure_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
self.get_azure_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
client=client,
_is_async=_is_async,
)
)
if azure_client is None:
raise ValueError(
"OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
)
response = azure_client.batches.cancel(**cancel_batch_data)
return response
async def alist_batches(
self,
client: AsyncAzureOpenAI,
after: Optional[str] = None,
limit: Optional[int] = None,
):
response = await client.batches.list(after=after, limit=limit) # type: ignore
return response
def list_batches(
self,
_is_async: bool,
api_key: Optional[str],
api_base: Optional[str],
api_version: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
after: Optional[str] = None,
limit: Optional[int] = None,
client: Optional[AzureOpenAI] = None,
):
azure_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
self.get_azure_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
api_version=api_version,
client=client,
_is_async=_is_async,
)
)
if azure_client is None:
raise ValueError(
"OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
)
if _is_async is True:
if not isinstance(azure_client, AsyncAzureOpenAI):
raise ValueError(
"OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client."
)
return self.alist_batches( # type: ignore
client=azure_client, after=after, limit=limit
)
response = azure_client.batches.list(after=after, limit=limit) # type: ignore
return response

View file

@ -0,0 +1,238 @@
"""
Azure Batches API Handler
"""
from typing import Any, Coroutine, Optional, Union
import httpx
import litellm
from litellm.llms.azure.azure import AsyncAzureOpenAI, AzureOpenAI
from litellm.types.llms.openai import (
Batch,
CancelBatchRequest,
CreateBatchRequest,
RetrieveBatchRequest,
)
class AzureBatchesAPI:
"""
Azure methods to support for batches
- create_batch()
- retrieve_batch()
- cancel_batch()
- list_batch()
"""
def __init__(self) -> None:
super().__init__()
def get_azure_openai_client(
self,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
api_version: Optional[str] = None,
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
_is_async: bool = False,
) -> Optional[Union[AzureOpenAI, AsyncAzureOpenAI]]:
received_args = locals()
openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None
if client is None:
data = {}
for k, v in received_args.items():
if k == "self" or k == "client" or k == "_is_async":
pass
elif k == "api_base" and v is not None:
data["azure_endpoint"] = v
elif v is not None:
data[k] = v
if "api_version" not in data:
data["api_version"] = litellm.AZURE_DEFAULT_API_VERSION
if _is_async is True:
openai_client = AsyncAzureOpenAI(**data)
else:
openai_client = AzureOpenAI(**data) # type: ignore
else:
openai_client = client
return openai_client
async def acreate_batch(
self,
create_batch_data: CreateBatchRequest,
azure_client: AsyncAzureOpenAI,
) -> Batch:
response = await azure_client.batches.create(**create_batch_data)
return response
def create_batch(
self,
_is_async: bool,
create_batch_data: CreateBatchRequest,
api_key: Optional[str],
api_base: Optional[str],
api_version: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
) -> Union[Batch, Coroutine[Any, Any, Batch]]:
azure_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
self.get_azure_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
api_version=api_version,
max_retries=max_retries,
client=client,
_is_async=_is_async,
)
)
if azure_client is None:
raise ValueError(
"OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
)
if _is_async is True:
if not isinstance(azure_client, AsyncAzureOpenAI):
raise ValueError(
"OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client."
)
return self.acreate_batch( # type: ignore
create_batch_data=create_batch_data, azure_client=azure_client
)
response = azure_client.batches.create(**create_batch_data)
return response
async def aretrieve_batch(
self,
retrieve_batch_data: RetrieveBatchRequest,
client: AsyncAzureOpenAI,
) -> Batch:
response = await client.batches.retrieve(**retrieve_batch_data)
return response
def retrieve_batch(
self,
_is_async: bool,
retrieve_batch_data: RetrieveBatchRequest,
api_key: Optional[str],
api_base: Optional[str],
api_version: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
client: Optional[AzureOpenAI] = None,
):
azure_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
self.get_azure_openai_client(
api_key=api_key,
api_base=api_base,
api_version=api_version,
timeout=timeout,
max_retries=max_retries,
client=client,
_is_async=_is_async,
)
)
if azure_client is None:
raise ValueError(
"OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
)
if _is_async is True:
if not isinstance(azure_client, AsyncAzureOpenAI):
raise ValueError(
"OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client."
)
return self.aretrieve_batch( # type: ignore
retrieve_batch_data=retrieve_batch_data, client=azure_client
)
response = azure_client.batches.retrieve(**retrieve_batch_data)
return response
async def acancel_batch(
self,
cancel_batch_data: CancelBatchRequest,
client: AsyncAzureOpenAI,
) -> Batch:
response = await client.batches.cancel(**cancel_batch_data)
return response
def cancel_batch(
self,
_is_async: bool,
cancel_batch_data: CancelBatchRequest,
api_key: Optional[str],
api_base: Optional[str],
api_version: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
client: Optional[AzureOpenAI] = None,
):
azure_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
self.get_azure_openai_client(
api_key=api_key,
api_base=api_base,
api_version=api_version,
timeout=timeout,
max_retries=max_retries,
client=client,
_is_async=_is_async,
)
)
if azure_client is None:
raise ValueError(
"OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
)
response = azure_client.batches.cancel(**cancel_batch_data)
return response
async def alist_batches(
self,
client: AsyncAzureOpenAI,
after: Optional[str] = None,
limit: Optional[int] = None,
):
response = await client.batches.list(after=after, limit=limit) # type: ignore
return response
def list_batches(
self,
_is_async: bool,
api_key: Optional[str],
api_base: Optional[str],
api_version: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
after: Optional[str] = None,
limit: Optional[int] = None,
client: Optional[AzureOpenAI] = None,
):
azure_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
self.get_azure_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
api_version=api_version,
client=client,
_is_async=_is_async,
)
)
if azure_client is None:
raise ValueError(
"OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
)
if _is_async is True:
if not isinstance(azure_client, AsyncAzureOpenAI):
raise ValueError(
"OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client."
)
return self.alist_batches( # type: ignore
client=azure_client, after=after, limit=limit
)
response = azure_client.batches.list(after=after, limit=limit) # type: ignore
return response

View file

@ -1799,6 +1799,15 @@ class OpenAIBatchesAPI(BaseLLM):
response = openai_client.batches.retrieve(**retrieve_batch_data)
return response
async def acancel_batch(
self,
cancel_batch_data: CancelBatchRequest,
openai_client: AsyncOpenAI,
) -> Batch:
verbose_logger.debug("async cancelling batch, args= %s", cancel_batch_data)
response = await openai_client.batches.cancel(**cancel_batch_data)
return response
def cancel_batch(
self,
_is_async: bool,
@ -1823,6 +1832,16 @@ class OpenAIBatchesAPI(BaseLLM):
raise ValueError(
"OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
)
if _is_async is True:
if not isinstance(openai_client, AsyncOpenAI):
raise ValueError(
"OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client."
)
return self.acancel_batch( # type: ignore
cancel_batch_data=cancel_batch_data, openai_client=openai_client
)
response = openai_client.batches.cancel(**cancel_batch_data)
return response

View file

@ -11,7 +11,11 @@ from fastapi import APIRouter, Depends, HTTPException, Path, Request, Response
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.batches.main import CreateBatchRequest, RetrieveBatchRequest
from litellm.batches.main import (
CancelBatchRequest,
CreateBatchRequest,
RetrieveBatchRequest,
)
from litellm.proxy._types import *
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
@ -353,6 +357,116 @@ async def list_batches(
raise handle_exception_on_proxy(e)
@router.post(
"/{provider}/v1/batches/{batch_id:path}/cancel",
dependencies=[Depends(user_api_key_auth)],
tags=["batch"],
)
@router.post(
"/v1/batches/{batch_id:path}/cancel",
dependencies=[Depends(user_api_key_auth)],
tags=["batch"],
)
@router.post(
"/batches/{batch_id:path}/cancel",
dependencies=[Depends(user_api_key_auth)],
tags=["batch"],
)
async def cancel_batch(
request: Request,
batch_id: str,
fastapi_response: Response,
provider: Optional[str] = None,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
Cancel a batch.
This is the equivalent of POST https://api.openai.com/v1/batches/{batch_id}/cancel
Supports Identical Params as: https://platform.openai.com/docs/api-reference/batch/cancel
Example Curl
```
curl http://localhost:4000/v1/batches/batch_abc123/cancel \
-H "Authorization: Bearer sk-1234" \
-H "Content-Type: application/json" \
-X POST
```
"""
from litellm.proxy.proxy_server import (
add_litellm_data_to_request,
general_settings,
get_custom_headers,
proxy_config,
proxy_logging_obj,
version,
)
data: Dict = {}
try:
data = await _read_request_body(request=request)
verbose_proxy_logger.debug(
"Request received by LiteLLM:\n{}".format(json.dumps(data, indent=4)),
)
# Include original request and headers in the data
data = await add_litellm_data_to_request(
data=data,
request=request,
general_settings=general_settings,
user_api_key_dict=user_api_key_dict,
version=version,
proxy_config=proxy_config,
)
custom_llm_provider = (
provider or data.pop("custom_llm_provider", None) or "openai"
)
_cancel_batch_data = CancelBatchRequest(batch_id=batch_id, **data)
response = await litellm.acancel_batch(
custom_llm_provider=custom_llm_provider, # type: ignore
**_cancel_batch_data
)
### ALERTING ###
asyncio.create_task(
proxy_logging_obj.update_request_status(
litellm_call_id=data.get("litellm_call_id", ""), status="success"
)
)
### RESPONSE HEADERS ###
hidden_params = getattr(response, "_hidden_params", {}) or {}
model_id = hidden_params.get("model_id", None) or ""
cache_key = hidden_params.get("cache_key", None) or ""
api_base = hidden_params.get("api_base", None) or ""
fastapi_response.headers.update(
get_custom_headers(
user_api_key_dict=user_api_key_dict,
model_id=model_id,
cache_key=cache_key,
api_base=api_base,
version=version,
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
request_data=data,
)
)
return response
except Exception as e:
await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
)
verbose_proxy_logger.exception(
"litellm.proxy.proxy_server.create_batch(): Exception occured - {}".format(
str(e)
)
)
raise handle_exception_on_proxy(e)
######################################################################
# END OF /v1/batches Endpoints Implementation

View file

@ -144,6 +144,13 @@ async def test_create_batch(provider):
with open(result_file_name, "wb") as file:
file.write(result)
# Cancel Batch
cancel_batch_response = await litellm.acancel_batch(
batch_id=create_batch_response.id,
custom_llm_provider=provider,
)
print("cancel_batch_response=", cancel_batch_response)
pass
@ -282,6 +289,13 @@ async def test_async_create_batch(provider):
with open(result_file_name, "wb") as file:
file.write(file_content.content)
# Cancel Batch
cancel_batch_response = await litellm.acancel_batch(
batch_id=create_batch_response.id,
custom_llm_provider=provider,
)
print("cancel_batch_response=", cancel_batch_response)
def cleanup_azure_files():
"""
@ -301,18 +315,6 @@ def cleanup_azure_files():
assert delete_file_response.id == _file.id
def test_retrieve_batch():
pass
def test_cancel_batch():
pass
def test_list_batch():
pass
@pytest.mark.asyncio
async def test_avertex_batch_prediction():
load_vertex_ai_credentials()

View file

@ -14,85 +14,57 @@ import time
BASE_URL = "http://localhost:4000" # Replace with your actual base URL
API_KEY = "sk-1234" # Replace with your actual API key
from openai import OpenAI
async def create_batch(session, input_file_id, endpoint, completion_window):
url = f"{BASE_URL}/v1/batches"
headers = {"Authorization": f"Bearer {API_KEY}", "Content-Type": "application/json"}
payload = {
"input_file_id": input_file_id,
"endpoint": endpoint,
"completion_window": completion_window,
}
async with session.post(url, headers=headers, json=payload) as response:
assert response.status == 200, f"Expected status 200, got {response.status}"
result = await response.json()
print(f"Batch creation successful. Batch ID: {result.get('id', 'N/A')}")
return result
async def get_batch_by_id(session, batch_id):
url = f"{BASE_URL}/v1/batches/{batch_id}"
headers = {"Authorization": f"Bearer {API_KEY}"}
async with session.get(url, headers=headers) as response:
if response.status == 200:
result = await response.json()
return result
else:
print(f"Error: Failed to get batch. Status code: {response.status}")
return None
async def list_batches(session):
url = f"{BASE_URL}/v1/batches"
headers = {"Authorization": f"Bearer {API_KEY}"}
async with session.get(url, headers=headers) as response:
if response.status == 200:
result = await response.json()
return result
else:
print(f"Error: Failed to get batch. Status code: {response.status}")
return None
client = OpenAI(base_url=BASE_URL, api_key=API_KEY)
@pytest.mark.asyncio
async def test_batches_operations():
async with aiohttp.ClientSession() as session:
# Test file upload and get file_id
file_id = await upload_file(session, purpose="batch")
_current_dir = os.path.dirname(os.path.abspath(__file__))
input_file_path = os.path.join(_current_dir, "input.jsonl")
file_obj = client.files.create(
file=open(input_file_path, "rb"),
purpose="batch",
)
create_batch_response = await create_batch(
session, file_id, "/v1/chat/completions", "24h"
)
batch_id = create_batch_response.get("id")
assert batch_id is not None
batch = client.batches.create(
input_file_id=file_obj.id,
endpoint="/v1/chat/completions",
completion_window="24h",
)
# Test get batch
get_batch_response = await get_batch_by_id(session, batch_id)
print("response from get batch", get_batch_response)
assert batch.id is not None
assert get_batch_response["id"] == batch_id
assert get_batch_response["input_file_id"] == file_id
# Test get batch
_retrieved_batch = client.batches.retrieve(batch_id=batch.id)
print("response from get batch", _retrieved_batch)
# test LIST Batches
list_batch_response = await list_batches(session)
print("response from list batch", list_batch_response)
assert _retrieved_batch.id == batch.id
assert _retrieved_batch.input_file_id == file_obj.id
assert list_batch_response is not None
assert len(list_batch_response["data"]) > 0
# Test list batches
_list_batches = client.batches.list()
print("response from list batches", _list_batches)
element_0 = list_batch_response["data"][0]
assert element_0["id"] is not None
assert _list_batches is not None
assert len(_list_batches.data) > 0
# Test delete file
await delete_file(session, file_id)
# Clean up
# Test cancel batch
_canceled_batch = client.batches.cancel(batch_id=batch.id)
print("response from cancel batch", _canceled_batch)
assert _canceled_batch.status is not None
assert (
_canceled_batch.status == "cancelling" or _canceled_batch.status == "cancelled"
)
from openai import OpenAI
# finally delete the file
_deleted_file = client.files.delete(file_id=file_obj.id)
print("response from delete file", _deleted_file)
client = OpenAI(base_url=BASE_URL, api_key=API_KEY)
assert _deleted_file.deleted is True
def create_batch_oai_sdk(filepath: str, custom_llm_provider: str) -> str: