feat - import batches in __init__

Ishaan Jaff 2024-05-28 15:35:11 -07:00
parent 0af4c9206f
commit d5dbf084ed
5 changed files with 539 additions and 20 deletions

litellm/__init__.py

@@ -797,3 +797,4 @@ from .budget_manager import BudgetManager
from .proxy.proxy_cli import run_server
from .router import Router
from .assistants.main import *
from .batches.main import *
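With this line, the new batches entry points are re-exported at the package root, so callers can do, e.g.:

    from litellm import create_batch, create_file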

litellm/batches/main.py (new file, 239 lines)

@@ -0,0 +1,239 @@
"""
Main File for Batches API implementation
https://platform.openai.com/docs/api-reference/batch
- create_batch()
- retrieve_batch()
- cancel_batch()
- list_batch()
"""
from typing import Iterable
import os
import litellm
from openai import OpenAI
import httpx
from litellm import client
from litellm.utils import supports_httpx_timeout
from ..types.router import *
from ..llms.openai import OpenAIBatchesAPI, OpenAIFilesAPI
from ..types.llms.openai import (
CreateBatchRequest,
RetrieveBatchRequest,
CancelBatchRequest,
CreateFileRequest,
FileTypes,
FileObject,
)
from typing import Literal, Optional, Dict
####### ENVIRONMENT VARIABLES ###################
openai_batches_instance = OpenAIBatchesAPI()
openai_files_instance = OpenAIFilesAPI()
#################################################
def create_file(
    file: FileTypes,
    purpose: Literal["assistants", "batch", "fine-tune"],
    custom_llm_provider: Literal["openai"] = "openai",
    extra_headers: Optional[Dict[str, str]] = None,
    extra_body: Optional[Dict[str, str]] = None,
    **kwargs,
) -> FileObject:
    try:
        optional_params = GenericLiteLLMParams(**kwargs)
        if custom_llm_provider == "openai":
            # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
            api_base = (
                optional_params.api_base
                or litellm.api_base
                or os.getenv("OPENAI_API_BASE")
                or "https://api.openai.com/v1"
            )
            organization = (
                optional_params.organization
                or litellm.organization
                or os.getenv("OPENAI_ORGANIZATION", None)
                or None  # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
            )
            # set API KEY
            api_key = (
                optional_params.api_key
                or litellm.api_key  # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
                or litellm.openai_key
                or os.getenv("OPENAI_API_KEY")
            )
            ### TIMEOUT LOGIC ###
            timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
            # set timeout for 10 minutes by default
            if (
                timeout is not None
                and isinstance(timeout, httpx.Timeout)
                and not supports_httpx_timeout(custom_llm_provider)
            ):
                read_timeout = timeout.read or 600
                timeout = read_timeout  # default 10 min timeout
            elif timeout is not None and not isinstance(timeout, httpx.Timeout):
                timeout = float(timeout)  # type: ignore
            elif timeout is None:
                timeout = 600.0

            _create_file_request = CreateFileRequest(
                file=file,
                purpose=purpose,
                extra_headers=extra_headers,
                extra_body=extra_body,
            )

            response = openai_files_instance.create_file(
                api_base=api_base,
                api_key=api_key,
                timeout=timeout,
                max_retries=optional_params.max_retries,
                organization=organization,
                create_file_data=_create_file_request,
            )
        else:
            raise litellm.exceptions.BadRequestError(
                message="LiteLLM doesn't support {} for 'create_file'. Only 'openai' is supported.".format(
                    custom_llm_provider
                ),
                model="n/a",
                llm_provider=custom_llm_provider,
                response=httpx.Response(
                    status_code=400,
                    content="Unsupported provider",
                    request=httpx.Request(method="create_file", url="https://github.com/BerriAI/litellm"),  # type: ignore
                ),
            )
        return response
    except Exception as e:
        raise e

def create_batch(
    completion_window: Literal["24h"],
    endpoint: Literal["/v1/chat/completions", "/v1/embeddings", "/v1/completions"],
    input_file_id: str,
    custom_llm_provider: Literal["openai"] = "openai",
    metadata: Optional[Dict[str, str]] = None,
    extra_headers: Optional[Dict[str, str]] = None,
    extra_body: Optional[Dict[str, str]] = None,
    **kwargs,
):
    """
    Creates and executes a batch from an uploaded file of requests

    LiteLLM Equivalent of POST: https://api.openai.com/v1/batches
    """
    try:
        optional_params = GenericLiteLLMParams(**kwargs)
        if custom_llm_provider == "openai":
            # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
            api_base = (
                optional_params.api_base
                or litellm.api_base
                or os.getenv("OPENAI_API_BASE")
                or "https://api.openai.com/v1"
            )
            organization = (
                optional_params.organization
                or litellm.organization
                or os.getenv("OPENAI_ORGANIZATION", None)
                or None  # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
            )
            # set API KEY
            api_key = (
                optional_params.api_key
                or litellm.api_key  # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
                or litellm.openai_key
                or os.getenv("OPENAI_API_KEY")
            )
            ### TIMEOUT LOGIC ###
            timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
            # set timeout for 10 minutes by default
            if (
                timeout is not None
                and isinstance(timeout, httpx.Timeout)
                and not supports_httpx_timeout(custom_llm_provider)
            ):
                read_timeout = timeout.read or 600
                timeout = read_timeout  # default 10 min timeout
            elif timeout is not None and not isinstance(timeout, httpx.Timeout):
                timeout = float(timeout)  # type: ignore
            elif timeout is None:
                timeout = 600.0

            _create_batch_request = CreateBatchRequest(
                completion_window=completion_window,
                endpoint=endpoint,
                input_file_id=input_file_id,
                metadata=metadata,
                extra_headers=extra_headers,
                extra_body=extra_body,
            )

            response = openai_batches_instance.create_batch(
                api_base=api_base,
                api_key=api_key,
                organization=organization,
                create_batch_data=_create_batch_request,
                timeout=timeout,
                max_retries=optional_params.max_retries,
            )
        else:
            raise litellm.exceptions.BadRequestError(
                message="LiteLLM doesn't support {} for 'create_batch'. Only 'openai' is supported.".format(
                    custom_llm_provider
                ),
                model="n/a",
                llm_provider=custom_llm_provider,
                response=httpx.Response(
                    status_code=400,
                    content="Unsupported provider",
                    request=httpx.Request(method="create_batch", url="https://github.com/BerriAI/litellm"),  # type: ignore
                ),
            )
        return response
    except Exception as e:
        raise e

def retrieve_batch():
    pass


def cancel_batch():
    pass


def list_batch():
    pass


# Async Functions
async def acreate_batch():
    pass


async def aretrieve_batch():
    pass


async def acancel_batch():
    pass


async def alist_batch():
    pass
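For reference, a minimal end-to-end sketch of the two entry points implemented above (assumes OPENAI_API_KEY is set and a batch input file exists locally; the metadata values are hypothetical):

    import litellm

    # 1. upload the JSONL input file, 2. start a batch job against it
    file_obj = litellm.create_file(
        file=open("openai_batch_completions.jsonl", "rb"),
        purpose="batch",
        custom_llm_provider="openai",
    )
    batch = litellm.create_batch(
        completion_window="24h",
        endpoint="/v1/chat/completions",
        input_file_id=file_obj.id,
        custom_llm_provider="openai",
        metadata={"key1": "value1"},  # hypothetical metadata
    )
    print(batch)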

litellm/llms/openai.py

@@ -1497,6 +1497,189 @@ class OpenAITextCompletion(BaseLLM):
            yield transformed_chunk


class OpenAIFilesAPI(BaseLLM):
    """
    OpenAI methods to support batches
    - create_file()
    - retrieve_file()
    - list_files()
    - delete_file()
    - file_content()
    - update_file()
    """

    def __init__(self) -> None:
        super().__init__()

    def get_openai_client(
        self,
        api_key: Optional[str],
        api_base: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        organization: Optional[str],
        client: Optional[OpenAI] = None,
    ) -> OpenAI:
        received_args = locals()
        if client is None:
            data = {}
            for k, v in received_args.items():
                if k == "self" or k == "client":
                    pass
                elif k == "api_base" and v is not None:
                    data["base_url"] = v
                elif v is not None:
                    data[k] = v
            openai_client = OpenAI(**data)  # type: ignore
        else:
            openai_client = client
        return openai_client

    def create_file(
        self,
        create_file_data: CreateFileRequest,
        api_base: str,
        api_key: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        organization: Optional[str],
        client: Optional[OpenAI] = None,
    ) -> FileObject:
        openai_client: OpenAI = self.get_openai_client(
            api_key=api_key,
            api_base=api_base,
            timeout=timeout,
            max_retries=max_retries,
            organization=organization,
            client=client,
        )
        response = openai_client.files.create(**create_file_data)
        return response

class OpenAIBatchesAPI(BaseLLM):
    """
    OpenAI methods to support batches
    - create_batch()
    - retrieve_batch()
    - cancel_batch()
    - list_batch()
    """

    def __init__(self) -> None:
        super().__init__()

    def get_openai_client(
        self,
        api_key: Optional[str],
        api_base: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        organization: Optional[str],
        client: Optional[OpenAI] = None,
    ) -> OpenAI:
        received_args = locals()
        if client is None:
            data = {}
            for k, v in received_args.items():
                if k == "self" or k == "client":
                    pass
                elif k == "api_base" and v is not None:
                    data["base_url"] = v
                elif v is not None:
                    data[k] = v
            openai_client = OpenAI(**data)  # type: ignore
        else:
            openai_client = client
        return openai_client

    def create_batch(
        self,
        create_batch_data: CreateBatchRequest,
        api_key: Optional[str],
        api_base: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        organization: Optional[str],
        client: Optional[OpenAI] = None,
    ):
        openai_client: OpenAI = self.get_openai_client(
            api_key=api_key,
            api_base=api_base,
            timeout=timeout,
            max_retries=max_retries,
            organization=organization,
            client=client,
        )
        response = openai_client.batches.create(**create_batch_data)
        return response

    def retrieve_batch(
        self,
        retrieve_batch_data: RetrieveBatchRequest,
        api_key: Optional[str],
        api_base: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        organization: Optional[str],
        client: Optional[OpenAI] = None,
    ):
        openai_client: OpenAI = self.get_openai_client(
            api_key=api_key,
            api_base=api_base,
            timeout=timeout,
            max_retries=max_retries,
            organization=organization,
            client=client,
        )
        response = openai_client.batches.retrieve(**retrieve_batch_data)
        return response

    def cancel_batch(
        self,
        cancel_batch_data: CancelBatchRequest,
        api_key: Optional[str],
        api_base: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        organization: Optional[str],
        client: Optional[OpenAI] = None,
    ):
        openai_client: OpenAI = self.get_openai_client(
            api_key=api_key,
            api_base=api_base,
            timeout=timeout,
            max_retries=max_retries,
            organization=organization,
            client=client,
        )
        response = openai_client.batches.cancel(**cancel_batch_data)
        return response

    # def list_batch(
    #     self,
    #     list_batch_data: ListBatchRequest,
    #     api_key: Optional[str],
    #     api_base: Optional[str],
    #     timeout: Union[float, httpx.Timeout],
    #     max_retries: Optional[int],
    #     organization: Optional[str],
    #     client: Optional[OpenAI] = None,
    # ):
    #     openai_client: OpenAI = self.get_openai_client(
    #         api_key=api_key,
    #         api_base=api_base,
    #         timeout=timeout,
    #         max_retries=max_retries,
    #         organization=organization,
    #         client=client,
    #     )
    #     response = openai_client.batches.list(**list_batch_data)
    #     return response


class OpenAIAssistantsAPI(BaseLLM):
    def __init__(self) -> None:
        super().__init__()
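The get_openai_client helper builds the OpenAI SDK client's kwargs out of its own locals(), skipping None values so the SDK's environment-variable fallbacks still apply. A minimal sketch of calling the batches handler directly (the batch id is hypothetical; a None api_key falls through to OPENAI_API_KEY):

    from litellm.llms.openai import OpenAIBatchesAPI
    from litellm.types.llms.openai import RetrieveBatchRequest

    handler = OpenAIBatchesAPI()
    batch = handler.retrieve_batch(
        retrieve_batch_data=RetrieveBatchRequest(batch_id="batch_abc123"),  # hypothetical id
        api_key=None,  # skipped by get_openai_client; OpenAI() reads OPENAI_API_KEY
        api_base="https://api.openai.com/v1",
        timeout=600.0,
        max_retries=None,
        organization=None,
    )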

(new test file, 58 lines)

@@ -0,0 +1,58 @@
# What is this?
## Unit Tests for OpenAI Batches API
import sys, os, json
import traceback
from dotenv import load_dotenv

load_dotenv()
sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import pytest, logging, asyncio
import litellm
from litellm import (
    create_batch,
    create_file,
)

def test_create_batch():
    """
    1. Create File for Batch completion
    2. Create Batch Request
    """
    file_obj = litellm.create_file(
        file=open("openai_batch_completions.jsonl", "rb"),
        purpose="batch",
        custom_llm_provider="openai",
    )
    print("response from creating file=", file_obj)

    batch_input_file_id = file_obj.id
    assert (
        batch_input_file_id is not None
    ), f"Failed to create file, expected a non-null file_id but got {batch_input_file_id}"

    # response = create_batch(
    #     completion_window="24h",
    #     endpoint="/v1/chat/completions",
    #     input_file_id=batch_input_file_id,
    #     custom_llm_provider="openai",
    #     metadata={"key1": "value1", "key2": "value2"},
    # )
    # print("response from creating batch=", response)


def test_retrieve_batch():
    pass


def test_cancel_batch():
    pass


def test_list_batch():
    pass
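The test above reads openai_batch_completions.jsonl from the working directory. That file is expected to follow the OpenAI Batch API input format, one request per line; a minimal two-request example (model and prompts are placeholder values) would look like:

    {"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "Hello!"}]}}
    {"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "What is 2+2?"}]}}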

litellm/types/llms/openai.py

@@ -18,8 +18,23 @@ from openai.types.beta.assistant_tool_param import AssistantToolParam
from openai.types.beta.threads.run import Run
from openai.types.beta.assistant import Assistant
from openai.pagination import SyncCursorPage
from os import PathLike
from openai.types import FileObject
-from typing import TypedDict, List, Optional
+from typing import TypedDict, List, Optional, Tuple, Mapping, IO

FileContent = Union[IO[bytes], bytes, PathLike[str]]

FileTypes = Union[
    # file (or bytes)
    FileContent,
    # (filename, file (or bytes))
    Tuple[Optional[str], FileContent],
    # (filename, file (or bytes), content_type)
    Tuple[Optional[str], FileContent, Optional[str]],
    # (filename, file (or bytes), content_type, headers)
    Tuple[Optional[str], FileContent, Optional[str], Mapping[str, str]],
]


class NotGiven:
@@ -148,8 +163,31 @@ class Thread(BaseModel):
"""The object type, which is always `thread`."""
# OpenAI Files Types
class CreateFileRequest(TypedDict, total=False):
"""
CreateFileRequest
Used by Assistants API, Batches API, and Fine-Tunes API
Required Params:
file: FileTypes
purpose: Literal['assistants', 'batch', 'fine-tune']
Optional Params:
extra_headers: Optional[Dict[str, str]]
extra_body: Optional[Dict[str, str]] = None
timeout: Optional[float] = None
"""
file: FileTypes
purpose: Literal["assistants", "batch", "fine-tune"]
extra_headers: Optional[Dict[str, str]]
extra_body: Optional[Dict[str, str]]
timeout: Optional[float]
# OpenAI Batches Types
class CreateBatchRequest(BaseModel):
class CreateBatchRequest(TypedDict, total=False):
"""
CreateBatchRequest
"""
@@ -157,42 +195,42 @@ class CreateBatchRequest(BaseModel):
    completion_window: Literal["24h"]
    endpoint: Literal["/v1/chat/completions", "/v1/embeddings", "/v1/completions"]
    input_file_id: str
-    metadata: Optional[Dict[str, str]] = None
-    extra_headers: Optional[Dict[str, str]] = None
-    extra_body: Optional[Dict[str, str]] = None
-    timeout: Optional[float] = None
+    metadata: Optional[Dict[str, str]]
+    extra_headers: Optional[Dict[str, str]]
+    extra_body: Optional[Dict[str, str]]
+    timeout: Optional[float]


-class RetrieveBatchRequest(BaseModel):
+class RetrieveBatchRequest(TypedDict, total=False):
    """
    RetrieveBatchRequest
    """

    batch_id: str
-    extra_headers: Optional[Dict[str, str]] = None
-    extra_body: Optional[Dict[str, str]] = None
-    timeout: Optional[float] = None
+    extra_headers: Optional[Dict[str, str]]
+    extra_body: Optional[Dict[str, str]]
+    timeout: Optional[float]


-class CancelBatchRequest(BaseModel):
+class CancelBatchRequest(TypedDict, total=False):
    """
    CancelBatchRequest
    """

    batch_id: str
-    extra_headers: Optional[Dict[str, str]] = None
-    extra_body: Optional[Dict[str, str]] = None
-    timeout: Optional[float] = None
+    extra_headers: Optional[Dict[str, str]]
+    extra_body: Optional[Dict[str, str]]
+    timeout: Optional[float]


-class ListBatchRequest(BaseModel):
+class ListBatchRequest(TypedDict, total=False):
    """
    ListBatchRequest - List your organization's batches
    Calls https://api.openai.com/v1/batches
    """

-    after: Optional[str] = None
-    limit: Optional[int] = 20
-    extra_headers: Optional[Dict[str, str]] = None
-    extra_body: Optional[Dict[str, str]] = None
-    timeout: Optional[float] = None
+    after: Union[str, NotGiven]
+    limit: Union[int, NotGiven]
+    extra_headers: Optional[Dict[str, str]]
+    extra_body: Optional[Dict[str, str]]
+    timeout: Optional[float]
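Since these request shapes are now TypedDicts (total=False) rather than pydantic models, every key is optional at construction time and the resulting dict unpacks directly into the OpenAI SDK calls shown in llms/openai.py. A small sketch (the file id is a hypothetical placeholder):

    from litellm.types.llms.openai import CreateBatchRequest

    req = CreateBatchRequest(
        completion_window="24h",
        endpoint="/v1/chat/completions",
        input_file_id="file-abc123",  # hypothetical file id
    )
    # unpacks straight into openai_client.batches.create(**req)
    assert req["endpoint"] == "/v1/chat/completions"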