feat - import batches in __init__

Ishaan Jaff 2024-05-28 15:35:11 -07:00
parent 04aace73e6
commit 4dc7bfebd4
5 changed files with 539 additions and 20 deletions

@@ -797,3 +797,4 @@ from .budget_manager import BudgetManager
from .proxy.proxy_cli import run_server
from .router import Router
from .assistants.main import *
from .batches.main import *

litellm/batches/main.py (new file, 239 lines)

@@ -0,0 +1,239 @@
"""
Main File for Batches API implementation
https://platform.openai.com/docs/api-reference/batch
- create_batch()
- retrieve_batch()
- cancel_batch()
- list_batch()
"""
from typing import Iterable, Literal, Optional, Dict
import os
import litellm
from openai import OpenAI
import httpx
from litellm import client
from litellm.utils import supports_httpx_timeout
from ..types.router import *
from ..llms.openai import OpenAIBatchesAPI, OpenAIFilesAPI
from ..types.llms.openai import (
CreateBatchRequest,
RetrieveBatchRequest,
CancelBatchRequest,
CreateFileRequest,
FileTypes,
FileObject,
)
####### ENVIRONMENT VARIABLES ###################
openai_batches_instance = OpenAIBatchesAPI()
openai_files_instance = OpenAIFilesAPI()
#################################################
def create_file(
file: FileTypes,
purpose: Literal["assistants", "batch", "fine-tune"],
custom_llm_provider: Literal["openai"] = "openai",
extra_headers: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, str]] = None,
**kwargs,
) -> FileObject:
try:
optional_params = GenericLiteLLMParams(**kwargs)
if custom_llm_provider == "openai":
# for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
api_base = (
optional_params.api_base
or litellm.api_base
or os.getenv("OPENAI_API_BASE")
or "https://api.openai.com/v1"
)
organization = (
optional_params.organization
or litellm.organization
or os.getenv("OPENAI_ORGANIZATION", None)
or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
)
# set API KEY
api_key = (
optional_params.api_key
or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
or litellm.openai_key
or os.getenv("OPENAI_API_KEY")
)
### TIMEOUT LOGIC ###
timeout = (
optional_params.timeout or kwargs.get("request_timeout", 600) or 600
)
# set timeout for 10 minutes by default
if (
timeout is not None
and isinstance(timeout, httpx.Timeout)
and not supports_httpx_timeout(custom_llm_provider)
):
read_timeout = timeout.read or 600
timeout = read_timeout # default 10 min timeout
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
timeout = float(timeout) # type: ignore
elif timeout is None:
timeout = 600.0
_create_file_request = CreateFileRequest(
file=file,
purpose=purpose,
extra_headers=extra_headers,
extra_body=extra_body,
)
response = openai_files_instance.create_file(
api_base=api_base,
api_key=api_key,
timeout=timeout,
max_retries=optional_params.max_retries,
organization=organization,
create_file_data=_create_file_request,
)
else:
raise litellm.exceptions.BadRequestError(
message="LiteLLM doesn't support {} for 'create_file'. Only 'openai' is supported.".format(
custom_llm_provider
),
model="n/a",
llm_provider=custom_llm_provider,
response=httpx.Response(
status_code=400,
content="Unsupported provider",
request=httpx.Request(method="create_file", url="https://github.com/BerriAI/litellm"), # type: ignore
),
)
return response
except Exception as e:
raise e
def create_batch(
completion_window: Literal["24h"],
endpoint: Literal["/v1/chat/completions", "/v1/embeddings", "/v1/completions"],
input_file_id: str,
custom_llm_provider: Literal["openai"] = "openai",
metadata: Optional[Dict[str, str]] = None,
extra_headers: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, str]] = None,
**kwargs,
):
"""
Creates and executes a batch from an uploaded file of requests
LiteLLM Equivalent of POST: https://api.openai.com/v1/batches
"""
try:
optional_params = GenericLiteLLMParams(**kwargs)
if custom_llm_provider == "openai":
# for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
api_base = (
optional_params.api_base
or litellm.api_base
or os.getenv("OPENAI_API_BASE")
or "https://api.openai.com/v1"
)
organization = (
optional_params.organization
or litellm.organization
or os.getenv("OPENAI_ORGANIZATION", None)
or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
)
# set API KEY
api_key = (
optional_params.api_key
or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
or litellm.openai_key
or os.getenv("OPENAI_API_KEY")
)
### TIMEOUT LOGIC ###
timeout = (
optional_params.timeout or kwargs.get("request_timeout", 600) or 600
)
# set timeout for 10 minutes by default
if (
timeout is not None
and isinstance(timeout, httpx.Timeout)
and not supports_httpx_timeout(custom_llm_provider)
):
read_timeout = timeout.read or 600
timeout = read_timeout # default 10 min timeout
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
timeout = float(timeout) # type: ignore
elif timeout is None:
timeout = 600.0
_create_batch_request = CreateBatchRequest(
completion_window=completion_window,
endpoint=endpoint,
input_file_id=input_file_id,
metadata=metadata,
extra_headers=extra_headers,
extra_body=extra_body,
)
response = openai_batches_instance.create_batch(
api_base=api_base,
api_key=api_key,
organization=organization,
create_batch_data=_create_batch_request,
timeout=timeout,
max_retries=optional_params.max_retries,
)
else:
raise litellm.exceptions.BadRequestError(
message="LiteLLM doesn't support {} for 'create_batch'. Only 'openai' is supported.".format(
custom_llm_provider
),
model="n/a",
llm_provider=custom_llm_provider,
response=httpx.Response(
status_code=400,
content="Unsupported provider",
request=httpx.Request(method="create_batch", url="https://github.com/BerriAI/litellm"), # type: ignore
),
)
return response
except Exception as e:
raise e
def retrieve_batch():
pass
def cancel_batch():
pass
def list_batch():
pass
# Async Functions
async def acreate_batch():
pass
async def aretrieve_batch():
pass
async def acancel_batch():
pass
async def alist_batch():
pass
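
A minimal end-to-end sketch of the two implemented helpers above (assuming OPENAI_API_KEY is set; the filename and metadata values are illustrative, not part of this commit):

import litellm

# 1. Upload the batch input file; the Batches API references files by id.
file_obj = litellm.create_file(
    file=open("requests.jsonl", "rb"),  # hypothetical local batch input file
    purpose="batch",
    custom_llm_provider="openai",
)

# 2. Create the batch against the uploaded file.
batch = litellm.create_batch(
    completion_window="24h",
    endpoint="/v1/chat/completions",
    input_file_id=file_obj.id,
    custom_llm_provider="openai",
    metadata={"job": "demo"},  # illustrative metadata
)
print(batch)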

@@ -1497,6 +1497,189 @@ class OpenAITextCompletion(BaseLLM):
yield transformed_chunk
class OpenAIFilesAPI(BaseLLM):
"""
OpenAI file-handling methods used to support batches
- create_file()
- retrieve_file()
- list_files()
- delete_file()
- file_content()
- update_file()
"""
def __init__(self) -> None:
super().__init__()
def get_openai_client(
self,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[OpenAI] = None,
) -> OpenAI:
received_args = locals()
if client is None:
data = {}
for k, v in received_args.items():
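# skip self and client; the OpenAI SDK expects base_url rather than api_base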
if k == "self" or k == "client":
pass
elif k == "api_base" and v is not None:
data["base_url"] = v
elif v is not None:
data[k] = v
openai_client = OpenAI(**data) # type: ignore
else:
openai_client = client
return openai_client
def create_file(
self,
create_file_data: CreateFileRequest,
api_base: str,
api_key: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[OpenAI] = None,
) -> FileObject:
openai_client: OpenAI = self.get_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
response = openai_client.files.create(**create_file_data)
return response
class OpenAIBatchesAPI(BaseLLM):
"""
OpenAI methods to support batches
- create_batch()
- retrieve_batch()
- cancel_batch()
- list_batch()
"""
def __init__(self) -> None:
super().__init__()
def get_openai_client(
self,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[OpenAI] = None,
) -> OpenAI:
received_args = locals()
if client is None:
data = {}
for k, v in received_args.items():
if k == "self" or k == "client":
pass
elif k == "api_base" and v is not None:
data["base_url"] = v
elif v is not None:
data[k] = v
openai_client = OpenAI(**data) # type: ignore
else:
openai_client = client
return openai_client
def create_batch(
self,
create_batch_data: CreateBatchRequest,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[OpenAI] = None,
):
openai_client: OpenAI = self.get_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
response = openai_client.batches.create(**create_batch_data)
return response
def retrieve_batch(
self,
retrieve_batch_data: RetrieveBatchRequest,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[OpenAI] = None,
):
openai_client: OpenAI = self.get_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
response = openai_client.batches.retrieve(**retrieve_batch_data)
return response
def cancel_batch(
self,
cancel_batch_data: CancelBatchRequest,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[OpenAI] = None,
):
openai_client: OpenAI = self.get_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
response = openai_client.batches.cancel(**cancel_batch_data)
return response
# def list_batch(
# self,
# list_batch_data: ListBatchRequest,
# api_key: Optional[str],
# api_base: Optional[str],
# timeout: Union[float, httpx.Timeout],
# max_retries: Optional[int],
# organization: Optional[str],
# client: Optional[OpenAI] = None,
# ):
# openai_client: OpenAI = self.get_openai_client(
# api_key=api_key,
# api_base=api_base,
# timeout=timeout,
# max_retries=max_retries,
# organization=organization,
# client=client,
# )
# response = openai_client.batches.list(**list_batch_data)
# return response
class OpenAIAssistantsAPI(BaseLLM):
def __init__(self) -> None:
super().__init__()

@@ -0,0 +1,58 @@
# What is this?
## Unit Tests for OpenAI Batches API
import sys, os, json
import traceback
from dotenv import load_dotenv
load_dotenv()
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import pytest, logging, asyncio
import litellm
from litellm import (
create_batch,
create_file,
)
def test_create_batch():
"""
1. Create File for Batch completion
2. Create Batch Request
"""
file_obj = litellm.create_file(
file=open("openai_batch_completions.jsonl", "rb"),
purpose="batch",
custom_llm_provider="openai",
)
print("Response from creating file=", file_obj)
batch_input_file_id = file_obj.id
assert (
batch_input_file_id is not None
), f"Failed to create file, expected a non-null file_id but got {batch_input_file_id}"
print("response from creating file=", file_obj)
# response = create_batch(
# completion_window="24h",
# endpoint="/v1/chat/completions",
# input_file_id="1",
# custom_llm_provider="openai",
# metadata={"key1": "value1", "key2": "value2"},
# )
print("response")
pass
def test_retrieve_batch():
pass
def test_cancel_batch():
pass
def test_list_batch():
pass
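
test_create_batch above reads openai_batch_completions.jsonl from disk; that file is not part of this diff. A sketch of generating one in the OpenAI batch input format (one JSON request object per line; the model and messages are illustrative):

import json

request_line = {
    "custom_id": "request-1",  # caller-chosen id, used to match outputs to inputs
    "method": "POST",
    "url": "/v1/chat/completions",
    "body": {
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "Say hello."}],
    },
}
with open("openai_batch_completions.jsonl", "w") as f:
    f.write(json.dumps(request_line) + "\n")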

@@ -18,8 +18,23 @@ from openai.types.beta.assistant_tool_param import AssistantToolParam
from openai.types.beta.threads.run import Run
from openai.types.beta.assistant import Assistant
from openai.pagination import SyncCursorPage
from os import PathLike
from openai.types import FileObject
from typing import TypedDict, List, Optional, Tuple, Mapping, IO
FileContent = Union[IO[bytes], bytes, PathLike[str]]
FileTypes = Union[
# file (or bytes)
FileContent,
# (filename, file (or bytes))
Tuple[Optional[str], FileContent],
# (filename, file (or bytes), content_type)
Tuple[Optional[str], FileContent, Optional[str]],
# (filename, file (or bytes), content_type, headers)
Tuple[Optional[str], FileContent, Optional[str], Mapping[str, str]],
]
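
For illustration, one value of each accepted shape (variable names are mine, not from the commit):

raw_bytes: FileContent = b'{"custom_id": "request-1"}'
as_handle: FileTypes = open("batch_input.jsonl", "rb")  # bare file handle
as_pair: FileTypes = ("batch_input.jsonl", raw_bytes)  # (filename, content)
as_triple: FileTypes = ("batch_input.jsonl", raw_bytes, "application/jsonl")  # + content_type
as_quad: FileTypes = ("batch_input.jsonl", raw_bytes, "application/jsonl", {"X-Demo": "1"})  # + headers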
class NotGiven:
@@ -148,8 +163,31 @@ class Thread(BaseModel):
"""The object type, which is always `thread`.""" """The object type, which is always `thread`."""
# OpenAI Files Types
class CreateFileRequest(TypedDict, total=False):
"""
CreateFileRequest
Used by Assistants API, Batches API, and Fine-Tunes API
Required Params:
file: FileTypes
purpose: Literal['assistants', 'batch', 'fine-tune']
Optional Params:
extra_headers: Optional[Dict[str, str]]
extra_body: Optional[Dict[str, str]]
timeout: Optional[float]
"""
file: FileTypes
purpose: Literal["assistants", "batch", "fine-tune"]
extra_headers: Optional[Dict[str, str]]
extra_body: Optional[Dict[str, str]]
timeout: Optional[float]
# OpenAI Batches Types
class CreateBatchRequest(TypedDict, total=False):
"""
CreateBatchRequest
"""
@@ -157,42 +195,42 @@ class CreateBatchRequest(BaseModel):
completion_window: Literal["24h"]
endpoint: Literal["/v1/chat/completions", "/v1/embeddings", "/v1/completions"]
input_file_id: str
metadata: Optional[Dict[str, str]]
extra_headers: Optional[Dict[str, str]]
extra_body: Optional[Dict[str, str]]
timeout: Optional[float]
class RetrieveBatchRequest(TypedDict, total=False):
"""
RetrieveBatchRequest
"""
batch_id: str
extra_headers: Optional[Dict[str, str]]
extra_body: Optional[Dict[str, str]]
timeout: Optional[float]
class CancelBatchRequest(TypedDict, total=False):
"""
CancelBatchRequest
"""
batch_id: str
extra_headers: Optional[Dict[str, str]]
extra_body: Optional[Dict[str, str]]
timeout: Optional[float]
class ListBatchRequest(TypedDict, total=False):
"""
ListBatchRequest - List your organization's batches
Calls https://api.openai.com/v1/batches
"""
after: Union[str, NotGiven]
limit: Union[int, NotGiven]
extra_headers: Optional[Dict[str, str]]
extra_body: Optional[Dict[str, str]]
timeout: Optional[float]
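
One note on the design change in this hunk: the request types moved from Pydantic BaseModel with explicit None defaults (visible in the old hunk header) to TypedDict with total=False. Omitted keys are then truly absent rather than set to None, which matters because the handlers unpack these dicts straight into the OpenAI SDK, e.g. openai_client.batches.create(**create_batch_data). A small sketch (the file id is made up):

req = CreateBatchRequest(
    completion_window="24h",
    endpoint="/v1/chat/completions",
    input_file_id="file-abc123",  # hypothetical id returned by create_file
)
assert "metadata" not in req  # omitted optional keys are absent, not None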