feat - import batches in __init__

commit 4dc7bfebd4 (parent 04aace73e6)
5 changed files with 539 additions and 20 deletions
litellm/__init__.py

```diff
@@ -797,3 +797,4 @@ from .budget_manager import BudgetManager
 from .proxy.proxy_cli import run_server
 from .router import Router
 from .assistants.main import *
+from .batches.main import *
```
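This re-export is what makes the new Batches entry points callable straight off the package. A minimal sanity-check sketch (assumes litellm is installed from this commit):

```python
# `create_file` and `create_batch` are re-exported via
# `from .batches.main import *` in litellm/__init__.py.
import litellm

assert callable(litellm.create_file)
assert callable(litellm.create_batch)
```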
litellm/batches/main.py (new file, 239 lines)

```diff
@@ -0,0 +1,239 @@
+"""
+Main File for Batches API implementation
+
+https://platform.openai.com/docs/api-reference/batch
+
+- create_batch()
+- retrieve_batch()
+- cancel_batch()
+- list_batch()
+
+"""
+
+from typing import Iterable
+import os
+import litellm
+from openai import OpenAI
+import httpx
+from litellm import client
+from litellm.utils import supports_httpx_timeout
+from ..types.router import *
+from ..llms.openai import OpenAIBatchesAPI, OpenAIFilesAPI
+from ..types.llms.openai import (
+    CreateBatchRequest,
+    RetrieveBatchRequest,
+    CancelBatchRequest,
+    CreateFileRequest,
+    FileTypes,
+    FileObject,
+)
+
+from typing import Literal, Optional, Dict
+
+####### ENVIRONMENT VARIABLES ###################
+openai_batches_instance = OpenAIBatchesAPI()
+openai_files_instance = OpenAIFilesAPI()
+#################################################
+
+
+def create_file(
+    file: FileTypes,
+    purpose: Literal["assistants", "batch", "fine-tune"],
+    custom_llm_provider: Literal["openai"] = "openai",
+    extra_headers: Optional[Dict[str, str]] = None,
+    extra_body: Optional[Dict[str, str]] = None,
+    **kwargs,
+) -> FileObject:
+    try:
+        optional_params = GenericLiteLLMParams(**kwargs)
+        if custom_llm_provider == "openai":
+            # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
+            api_base = (
+                optional_params.api_base
+                or litellm.api_base
+                or os.getenv("OPENAI_API_BASE")
+                or "https://api.openai.com/v1"
+            )
+            organization = (
+                optional_params.organization
+                or litellm.organization
+                or os.getenv("OPENAI_ORGANIZATION", None)
+                or None  # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
+            )
+            # set API KEY
+            api_key = (
+                optional_params.api_key
+                or litellm.api_key  # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
+                or litellm.openai_key
+                or os.getenv("OPENAI_API_KEY")
+            )
+            ### TIMEOUT LOGIC ###
+            timeout = (
+                optional_params.timeout or kwargs.get("request_timeout", 600) or 600
+            )
+            # set timeout for 10 minutes by default
+
+            if (
+                timeout is not None
+                and isinstance(timeout, httpx.Timeout)
+                and supports_httpx_timeout(custom_llm_provider) == False
+            ):
+                read_timeout = timeout.read or 600
+                timeout = read_timeout  # default 10 min timeout
+            elif timeout is not None and not isinstance(timeout, httpx.Timeout):
+                timeout = float(timeout)  # type: ignore
+            elif timeout is None:
+                timeout = 600.0
+
+            _create_file_request = CreateFileRequest(
+                file=file,
+                purpose=purpose,
+                extra_headers=extra_headers,
+                extra_body=extra_body,
+            )
+
+            response = openai_files_instance.create_file(
+                api_base=api_base,
+                api_key=api_key,
+                timeout=timeout,
+                max_retries=optional_params.max_retries,
+                organization=organization,
+                create_file_data=_create_file_request,
+            )
+        else:
+            raise litellm.exceptions.BadRequestError(
+                message="LiteLLM doesn't support {} for 'create_file'. Only 'openai' is supported.".format(
+                    custom_llm_provider
+                ),
+                model="n/a",
+                llm_provider=custom_llm_provider,
+                response=httpx.Response(
+                    status_code=400,
+                    content="Unsupported provider",
+                    request=httpx.Request(method="create_file", url="https://github.com/BerriAI/litellm"),  # type: ignore
+                ),
+            )
+        return response
+    except Exception as e:
+        raise e
+
+
+def create_batch(
+    completion_window: Literal["24h"],
+    endpoint: Literal["/v1/chat/completions", "/v1/embeddings", "/v1/completions"],
+    input_file_id: str,
+    custom_llm_provider: Literal["openai"] = "openai",
+    metadata: Optional[Dict[str, str]] = None,
+    extra_headers: Optional[Dict[str, str]] = None,
+    extra_body: Optional[Dict[str, str]] = None,
+    **kwargs,
+):
+    """
+    Creates and executes a batch from an uploaded file of requests
+
+    LiteLLM Equivalent of POST: https://api.openai.com/v1/batches
+    """
+    try:
+        optional_params = GenericLiteLLMParams(**kwargs)
+        if custom_llm_provider == "openai":
+            # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
+            api_base = (
+                optional_params.api_base
+                or litellm.api_base
+                or os.getenv("OPENAI_API_BASE")
+                or "https://api.openai.com/v1"
+            )
+            organization = (
+                optional_params.organization
+                or litellm.organization
+                or os.getenv("OPENAI_ORGANIZATION", None)
+                or None  # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
+            )
+            # set API KEY
+            api_key = (
+                optional_params.api_key
+                or litellm.api_key  # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
+                or litellm.openai_key
+                or os.getenv("OPENAI_API_KEY")
+            )
+            ### TIMEOUT LOGIC ###
+            timeout = (
+                optional_params.timeout or kwargs.get("request_timeout", 600) or 600
+            )
+            # set timeout for 10 minutes by default
+
+            if (
+                timeout is not None
+                and isinstance(timeout, httpx.Timeout)
+                and supports_httpx_timeout(custom_llm_provider) == False
+            ):
+                read_timeout = timeout.read or 600
+                timeout = read_timeout  # default 10 min timeout
+            elif timeout is not None and not isinstance(timeout, httpx.Timeout):
+                timeout = float(timeout)  # type: ignore
+            elif timeout is None:
+                timeout = 600.0
+
+            _create_batch_request = CreateBatchRequest(
+                completion_window=completion_window,
+                endpoint=endpoint,
+                input_file_id=input_file_id,
+                metadata=metadata,
+                extra_headers=extra_headers,
+                extra_body=extra_body,
+            )
+
+            response = openai_batches_instance.create_batch(
+                api_base=api_base,
+                api_key=api_key,
+                organization=organization,
+                create_batch_data=_create_batch_request,
+                timeout=timeout,
+                max_retries=optional_params.max_retries,
+            )
+        else:
+            raise litellm.exceptions.BadRequestError(
+                message="LiteLLM doesn't support {} for 'create_batch'. Only 'openai' is supported.".format(
+                    custom_llm_provider
+                ),
+                model="n/a",
+                llm_provider=custom_llm_provider,
+                response=httpx.Response(
+                    status_code=400,
+                    content="Unsupported provider",
+                    request=httpx.Request(method="create_batch", url="https://github.com/BerriAI/litellm"),  # type: ignore
+                ),
+            )
+        return response
+    except Exception as e:
+        raise e
+
+
+def retrieve_batch():
+    pass
+
+
+def cancel_batch():
+    pass
+
+
+def list_batch():
+    pass
+
+
+# Async Functions
+async def acreate_batch():
+    pass
+
+
+async def aretrieve_batch():
+    pass
+
+
+async def acancel_batch():
+    pass
+
+
+async def alist_batch():
+    pass
```
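Taken together, create_file() and create_batch() cover the upload-then-enqueue flow. A sketch of that flow, mirroring the unit test added later in this commit (assumes a valid OPENAI_API_KEY and an existing openai_batch_completions.jsonl input file; parameter values are illustrative):

```python
import litellm

# 1. Upload the JSONL input file with purpose="batch".
file_obj = litellm.create_file(
    file=open("openai_batch_completions.jsonl", "rb"),
    purpose="batch",
    custom_llm_provider="openai",
)

# 2. Enqueue a batch that runs every request in the uploaded file.
batch = litellm.create_batch(
    completion_window="24h",
    endpoint="/v1/chat/completions",
    input_file_id=file_obj.id,
    custom_llm_provider="openai",
    metadata={"key1": "value1"},  # illustrative metadata
)
print(batch)
```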
litellm/llms/openai.py

```diff
@@ -1497,6 +1497,189 @@ class OpenAITextCompletion(BaseLLM):
             yield transformed_chunk
 
 
+class OpenAIFilesAPI(BaseLLM):
+    """
+    OpenAI methods to support batches:
+    - create_file()
+    - retrieve_file()
+    - list_files()
+    - delete_file()
+    - file_content()
+    - update_file()
+    """
+
+    def __init__(self) -> None:
+        super().__init__()
+
+    def get_openai_client(
+        self,
+        api_key: Optional[str],
+        api_base: Optional[str],
+        timeout: Union[float, httpx.Timeout],
+        max_retries: Optional[int],
+        organization: Optional[str],
+        client: Optional[OpenAI] = None,
+    ) -> OpenAI:
+        received_args = locals()
+        if client is None:
+            data = {}
+            for k, v in received_args.items():
+                if k == "self" or k == "client":
+                    pass
+                elif k == "api_base" and v is not None:
+                    data["base_url"] = v
+                elif v is not None:
+                    data[k] = v
+            openai_client = OpenAI(**data)  # type: ignore
+        else:
+            openai_client = client
+
+        return openai_client
+
+    def create_file(
+        self,
+        create_file_data: CreateFileRequest,
+        api_base: str,
+        api_key: Optional[str],
+        timeout: Union[float, httpx.Timeout],
+        max_retries: Optional[int],
+        organization: Optional[str],
+        client: Optional[OpenAI] = None,
+    ) -> FileObject:
+        openai_client: OpenAI = self.get_openai_client(
+            api_key=api_key,
+            api_base=api_base,
+            timeout=timeout,
+            max_retries=max_retries,
+            organization=organization,
+            client=client,
+        )
+        response = openai_client.files.create(**create_file_data)
+        return response
+
+
+class OpenAIBatchesAPI(BaseLLM):
+    """
+    OpenAI methods to support batches:
+    - create_batch()
+    - retrieve_batch()
+    - cancel_batch()
+    - list_batch()
+    """
+
+    def __init__(self) -> None:
+        super().__init__()
+
+    def get_openai_client(
+        self,
+        api_key: Optional[str],
+        api_base: Optional[str],
+        timeout: Union[float, httpx.Timeout],
+        max_retries: Optional[int],
+        organization: Optional[str],
+        client: Optional[OpenAI] = None,
+    ) -> OpenAI:
+        received_args = locals()
+        if client is None:
+            data = {}
+            for k, v in received_args.items():
+                if k == "self" or k == "client":
+                    pass
+                elif k == "api_base" and v is not None:
+                    data["base_url"] = v
+                elif v is not None:
+                    data[k] = v
+            openai_client = OpenAI(**data)  # type: ignore
+        else:
+            openai_client = client
+
+        return openai_client
+
+    def create_batch(
+        self,
+        create_batch_data: CreateBatchRequest,
+        api_key: Optional[str],
+        api_base: Optional[str],
+        timeout: Union[float, httpx.Timeout],
+        max_retries: Optional[int],
+        organization: Optional[str],
+        client: Optional[OpenAI] = None,
+    ):
+        openai_client: OpenAI = self.get_openai_client(
+            api_key=api_key,
+            api_base=api_base,
+            timeout=timeout,
+            max_retries=max_retries,
+            organization=organization,
+            client=client,
+        )
+        response = openai_client.batches.create(**create_batch_data)
+        return response
+
+    def retrieve_batch(
+        self,
+        retrieve_batch_data: RetrieveBatchRequest,
+        api_key: Optional[str],
+        api_base: Optional[str],
+        timeout: Union[float, httpx.Timeout],
+        max_retries: Optional[int],
+        organization: Optional[str],
+        client: Optional[OpenAI] = None,
+    ):
+        openai_client: OpenAI = self.get_openai_client(
+            api_key=api_key,
+            api_base=api_base,
+            timeout=timeout,
+            max_retries=max_retries,
+            organization=organization,
+            client=client,
+        )
+        response = openai_client.batches.retrieve(**retrieve_batch_data)
+        return response
+
+    def cancel_batch(
+        self,
+        cancel_batch_data: CancelBatchRequest,
+        api_key: Optional[str],
+        api_base: Optional[str],
+        timeout: Union[float, httpx.Timeout],
+        max_retries: Optional[int],
+        organization: Optional[str],
+        client: Optional[OpenAI] = None,
+    ):
+        openai_client: OpenAI = self.get_openai_client(
+            api_key=api_key,
+            api_base=api_base,
+            timeout=timeout,
+            max_retries=max_retries,
+            organization=organization,
+            client=client,
+        )
+        response = openai_client.batches.cancel(**cancel_batch_data)
+        return response
+
+    # def list_batch(
+    #     self,
+    #     list_batch_data: ListBatchRequest,
+    #     api_key: Optional[str],
+    #     api_base: Optional[str],
+    #     timeout: Union[float, httpx.Timeout],
+    #     max_retries: Optional[int],
+    #     organization: Optional[str],
+    #     client: Optional[OpenAI] = None,
+    # ):
+    #     openai_client: OpenAI = self.get_openai_client(
+    #         api_key=api_key,
+    #         api_base=api_base,
+    #         timeout=timeout,
+    #         max_retries=max_retries,
+    #         organization=organization,
+    #         client=client,
+    #     )
+    #     response = openai_client.batches.list(**list_batch_data)
+    #     return response
+
+
 class OpenAIAssistantsAPI(BaseLLM):
     def __init__(self) -> None:
         super().__init__()
```
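get_openai_client() scrapes its own locals() into OpenAI(...) kwargs, renaming api_base to base_url and dropping None values, unless a pre-built client is supplied, in which case it is returned unchanged. A sketch of driving OpenAIBatchesAPI directly with a reused client (the key and batch id are placeholders):

```python
from litellm.llms.openai import OpenAIBatchesAPI
from litellm.types.llms.openai import RetrieveBatchRequest

batches_api = OpenAIBatchesAPI()

# Build the client once; api_base becomes base_url, None values are dropped.
client = batches_api.get_openai_client(
    api_key="sk-...",  # placeholder key
    api_base="https://api.openai.com/v1",
    timeout=600.0,
    max_retries=2,
    organization=None,
)

# Pass the client back in; get_openai_client() returns it as-is.
response = batches_api.retrieve_batch(
    retrieve_batch_data=RetrieveBatchRequest(batch_id="batch_abc123"),  # placeholder id
    api_key=None,
    api_base=None,
    timeout=600.0,
    max_retries=2,
    organization=None,
    client=client,
)
```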
litellm/tests/test_openai_batches.py (new file, 58 lines)

```diff
@@ -0,0 +1,58 @@
+# What is this?
+## Unit Tests for OpenAI Batches API
+import sys, os, json
+import traceback
+from dotenv import load_dotenv
+
+load_dotenv()
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the system path
+import pytest, logging, asyncio
+import litellm
+from litellm import (
+    create_batch,
+    create_file,
+)
+
+
+def test_create_batch():
+    """
+    1. Create File for Batch completion
+    2. Create Batch Request
+    """
+    file_obj = litellm.create_file(
+        file=open("openai_batch_completions.jsonl", "rb"),
+        purpose="batch",
+        custom_llm_provider="openai",
+    )
+    print("Response from creating file=", file_obj)
+
+    batch_input_file_id = file_obj.id
+    assert (
+        batch_input_file_id is not None
+    ), f"Failed to create file, expected a non-null file_id but got {batch_input_file_id}"
+
+    # response = create_batch(
+    #     completion_window="24h",
+    #     endpoint="/v1/chat/completions",
+    #     input_file_id="1",
+    #     custom_llm_provider="openai",
+    #     metadata={"key1": "value1", "key2": "value2"},
+    # )
+    pass
+
+
+def test_retrieve_batch():
+    pass
+
+
+def test_cancel_batch():
+    pass
+
+
+def test_list_batch():
+    pass
```
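The test reads openai_batch_completions.jsonl from the test directory, and that fixture is not part of this diff. A plausible one-line fixture, following the request format in the OpenAI Batch API docs:

```python
# Writes a minimal batch input fixture next to the test (format per
# https://platform.openai.com/docs/api-reference/batch).
import json

request_line = {
    "custom_id": "request-1",
    "method": "POST",
    "url": "/v1/chat/completions",
    "body": {
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "Say hello"}],
    },
}
with open("openai_batch_completions.jsonl", "w") as f:
    f.write(json.dumps(request_line) + "\n")
```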
litellm/types/llms/openai.py

```diff
@@ -18,8 +18,23 @@ from openai.types.beta.assistant_tool_param import AssistantToolParam
 from openai.types.beta.threads.run import Run
 from openai.types.beta.assistant import Assistant
 from openai.pagination import SyncCursorPage
+from os import PathLike
+from openai.types import FileObject
 
-from typing import TypedDict, List, Optional
+from typing import TypedDict, List, Optional, Tuple, Mapping, IO
+
+FileContent = Union[IO[bytes], bytes, PathLike[str]]
+
+FileTypes = Union[
+    # file (or bytes)
+    FileContent,
+    # (filename, file (or bytes))
+    Tuple[Optional[str], FileContent],
+    # (filename, file (or bytes), content_type)
+    Tuple[Optional[str], FileContent, Optional[str]],
+    # (filename, file (or bytes), content_type, headers)
+    Tuple[Optional[str], FileContent, Optional[str], Mapping[str, str]],
+]
 
 
 class NotGiven:
```
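FileTypes mirrors the upload union in openai-python: raw content, or tuples that progressively add a filename, content type, and headers. Illustrative values for each accepted shape:

```python
# Each of these satisfies FileTypes (values are illustrative).
raw: bytes = b'{"custom_id": "request-1"}\n'                        # FileContent
named = ("batch.jsonl", raw)                                        # (filename, content)
typed = ("batch.jsonl", raw, "application/jsonl")                   # + content_type
full = ("batch.jsonl", raw, "application/jsonl", {"X-Debug": "1"})  # + headers
```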
```diff
@@ -148,8 +163,31 @@ class Thread(BaseModel):
     """The object type, which is always `thread`."""
 
 
+# OpenAI Files Types
+class CreateFileRequest(TypedDict, total=False):
+    """
+    CreateFileRequest
+    Used by Assistants API, Batches API, and Fine-Tunes API
+
+    Required Params:
+        file: FileTypes
+        purpose: Literal['assistants', 'batch', 'fine-tune']
+
+    Optional Params:
+        extra_headers: Optional[Dict[str, str]]
+        extra_body: Optional[Dict[str, str]] = None
+        timeout: Optional[float] = None
+    """
+
+    file: FileTypes
+    purpose: Literal["assistants", "batch", "fine-tune"]
+    extra_headers: Optional[Dict[str, str]]
+    extra_body: Optional[Dict[str, str]]
+    timeout: Optional[float]
+
+
 # OpenAI Batches Types
-class CreateBatchRequest(BaseModel):
+class CreateBatchRequest(TypedDict, total=False):
     """
     CreateBatchRequest
     """
```
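Because CreateFileRequest is a total=False TypedDict, partial dicts type-check and can be splatted straight into the client call, which is exactly what OpenAIFilesAPI.create_file does with **create_file_data. A construction sketch (values illustrative):

```python
from litellm.types.llms.openai import CreateFileRequest

# total=False: only the keys actually set need to be present.
create_req: CreateFileRequest = {
    "file": open("openai_batch_completions.jsonl", "rb"),
    "purpose": "batch",
}
# Downstream: openai_client.files.create(**create_req)
```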
```diff
@@ -157,42 +195,42 @@ class CreateBatchRequest(BaseModel):
 
     completion_window: Literal["24h"]
     endpoint: Literal["/v1/chat/completions", "/v1/embeddings", "/v1/completions"]
     input_file_id: str
-    metadata: Optional[Dict[str, str]] = None
-    extra_headers: Optional[Dict[str, str]] = None
-    extra_body: Optional[Dict[str, str]] = None
-    timeout: Optional[float] = None
+    metadata: Optional[Dict[str, str]]
+    extra_headers: Optional[Dict[str, str]]
+    extra_body: Optional[Dict[str, str]]
+    timeout: Optional[float]
 
 
-class RetrieveBatchRequest(BaseModel):
+class RetrieveBatchRequest(TypedDict, total=False):
     """
     RetrieveBatchRequest
     """
 
     batch_id: str
-    extra_headers: Optional[Dict[str, str]] = None
-    extra_body: Optional[Dict[str, str]] = None
-    timeout: Optional[float] = None
+    extra_headers: Optional[Dict[str, str]]
+    extra_body: Optional[Dict[str, str]]
+    timeout: Optional[float]
 
 
-class CancelBatchRequest(BaseModel):
+class CancelBatchRequest(TypedDict, total=False):
     """
     CancelBatchRequest
     """
 
     batch_id: str
-    extra_headers: Optional[Dict[str, str]] = None
-    extra_body: Optional[Dict[str, str]] = None
-    timeout: Optional[float] = None
+    extra_headers: Optional[Dict[str, str]]
+    extra_body: Optional[Dict[str, str]]
+    timeout: Optional[float]
 
 
-class ListBatchRequest(BaseModel):
+class ListBatchRequest(TypedDict, total=False):
     """
     ListBatchRequest - List your organization's batches
     Calls https://api.openai.com/v1/batches
     """
 
-    after: Optional[str] = None
-    limit: Optional[int] = 20
-    extra_headers: Optional[Dict[str, str]] = None
-    extra_body: Optional[Dict[str, str]] = None
-    timeout: Optional[float] = None
+    after: Union[str, NotGiven]
+    limit: Union[int, NotGiven]
+    extra_headers: Optional[Dict[str, str]]
+    extra_body: Optional[Dict[str, str]]
+    timeout: Optional[float]
```
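The same pattern applies to the batch request shapes: moving from BaseModel to TypedDict drops the `= None` defaults, so omitted keys are simply absent from the dict rather than sent as explicit None values. A sketch (ids are placeholders):

```python
from litellm.types.llms.openai import CancelBatchRequest, ListBatchRequest

cancel_req: CancelBatchRequest = {"batch_id": "batch_abc123"}  # placeholder id

# `after`/`limit` are typed Union[..., NotGiven]; leaving them out keeps
# the openai client's own defaults in effect.
list_req: ListBatchRequest = {"limit": 20}
# Downstream (commented out in this commit): openai_client.batches.list(**list_req)
```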