diff --git a/litellm/__init__.py b/litellm/__init__.py
index 3c78c9b27..56a2088e7 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -797,3 +797,4 @@ from .budget_manager import BudgetManager
 from .proxy.proxy_cli import run_server
 from .router import Router
 from .assistants.main import *
+from .batches.main import *
diff --git a/litellm/batches/main.py b/litellm/batches/main.py
new file mode 100644
index 000000000..5d9a3a141
--- /dev/null
+++ b/litellm/batches/main.py
@@ -0,0 +1,589 @@
+"""
+Main File for Batches API implementation
+
+https://platform.openai.com/docs/api-reference/batch
+
+- create_batch()
+- retrieve_batch()
+- cancel_batch()
+- list_batch()
+
+"""
+
+import os
+import asyncio
+from functools import partial
+import contextvars
+from typing import Literal, Optional, Dict, Coroutine, Any, Union
+import httpx
+
+import litellm
+from litellm import client
+from litellm.utils import supports_httpx_timeout
+from ..types.router import *
+from ..llms.openai import OpenAIBatchesAPI, OpenAIFilesAPI
+from ..types.llms.openai import (
+    CreateBatchRequest,
+    RetrieveBatchRequest,
+    CancelBatchRequest,
+    CreateFileRequest,
+    FileTypes,
+    FileObject,
+    Batch,
+    FileContentRequest,
+    HttpxBinaryResponseContent,
+)
+
+####### ENVIRONMENT VARIABLES ###################
+openai_batches_instance = OpenAIBatchesAPI()
+openai_files_instance = OpenAIFilesAPI()
+#################################################
+
+
+async def acreate_file(
+    file: FileTypes,
+    purpose: Literal["assistants", "batch", "fine-tune"],
+    custom_llm_provider: Literal["openai"] = "openai",
+    extra_headers: Optional[Dict[str, str]] = None,
+    extra_body: Optional[Dict[str, str]] = None,
+    **kwargs,
+) -> Coroutine[Any, Any, FileObject]:
+    """
+    Async: Files are used to upload documents that can be used with features like Assistants, Fine-tuning, and Batch API.
+
+    LiteLLM Equivalent of POST: POST https://api.openai.com/v1/files
+    """
+    try:
+        loop = asyncio.get_event_loop()
+        kwargs["acreate_file"] = True
+
+        # Use a partial function to pass your keyword arguments
+        func = partial(
+            create_file,
+            file,
+            purpose,
+            custom_llm_provider,
+            extra_headers,
+            extra_body,
+            **kwargs,
+        )
+
+        # Add the context to the function
+        ctx = contextvars.copy_context()
+        func_with_context = partial(ctx.run, func)
+        init_response = await loop.run_in_executor(None, func_with_context)
+        if asyncio.iscoroutine(init_response):
+            response = await init_response
+        else:
+            response = init_response  # type: ignore
+
+        return response
+    except Exception as e:
+        raise e
+
+
+def create_file(
+    file: FileTypes,
+    purpose: Literal["assistants", "batch", "fine-tune"],
+    custom_llm_provider: Literal["openai"] = "openai",
+    extra_headers: Optional[Dict[str, str]] = None,
+    extra_body: Optional[Dict[str, str]] = None,
+    **kwargs,
+) -> Union[FileObject, Coroutine[Any, Any, FileObject]]:
+    """
+    Files are used to upload documents that can be used with features like Assistants, Fine-tuning, and Batch API.
+
+    LiteLLM Equivalent of POST: POST https://api.openai.com/v1/files
+    """
+    try:
+        optional_params = GenericLiteLLMParams(**kwargs)
+        if custom_llm_provider == "openai":
+            # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
+            api_base = (
+                optional_params.api_base
+                or litellm.api_base
+                or os.getenv("OPENAI_API_BASE")
+                or "https://api.openai.com/v1"
+            )
+            organization = (
+                optional_params.organization
+                or litellm.organization
+                or os.getenv("OPENAI_ORGANIZATION", None)
+                or None  # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
+            )
+            # set API KEY
+            api_key = (
+                optional_params.api_key
+                or litellm.api_key  # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
+                or litellm.openai_key
+                or os.getenv("OPENAI_API_KEY")
+            )
+            ### TIMEOUT LOGIC ###
+            timeout = (
+                optional_params.timeout or kwargs.get("request_timeout", 600) or 600
+            )
+            # set timeout for 10 minutes by default
+
+            if (
+                timeout is not None
+                and isinstance(timeout, httpx.Timeout)
+                and supports_httpx_timeout(custom_llm_provider) == False
+            ):
+                read_timeout = timeout.read or 600
+                timeout = read_timeout  # default 10 min timeout
+            elif timeout is not None and not isinstance(timeout, httpx.Timeout):
+                timeout = float(timeout)  # type: ignore
+            elif timeout is None:
+                timeout = 600.0
+
+            _create_file_request = CreateFileRequest(
+                file=file,
+                purpose=purpose,
+                extra_headers=extra_headers,
+                extra_body=extra_body,
+            )
+
+            _is_async = kwargs.pop("acreate_file", False) is True
+
+            response = openai_files_instance.create_file(
+                _is_async=_is_async,
+                api_base=api_base,
+                api_key=api_key,
+                timeout=timeout,
+                max_retries=optional_params.max_retries,
+                organization=organization,
+                create_file_data=_create_file_request,
+            )
+        else:
+            raise litellm.exceptions.BadRequestError(
+                message="LiteLLM doesn't support {} for 'create_batch'.
Only 'openai' is supported.".format( + custom_llm_provider + ), + model="n/a", + llm_provider=custom_llm_provider, + response=httpx.Response( + status_code=400, + content="Unsupported provider", + request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore + ), + ) + return response + except Exception as e: + raise e + + +async def afile_content( + file_id: str, + custom_llm_provider: Literal["openai"] = "openai", + extra_headers: Optional[Dict[str, str]] = None, + extra_body: Optional[Dict[str, str]] = None, + **kwargs, +) -> Coroutine[Any, Any, HttpxBinaryResponseContent]: + """ + Async: Get file contents + + LiteLLM Equivalent of GET https://api.openai.com/v1/files + """ + try: + loop = asyncio.get_event_loop() + kwargs["afile_content"] = True + + # Use a partial function to pass your keyword arguments + func = partial( + file_content, + file_id, + custom_llm_provider, + extra_headers, + extra_body, + **kwargs, + ) + + # Add the context to the function + ctx = contextvars.copy_context() + func_with_context = partial(ctx.run, func) + init_response = await loop.run_in_executor(None, func_with_context) + if asyncio.iscoroutine(init_response): + response = await init_response + else: + response = init_response # type: ignore + + return response + except Exception as e: + raise e + + +def file_content( + file_id: str, + custom_llm_provider: Literal["openai"] = "openai", + extra_headers: Optional[Dict[str, str]] = None, + extra_body: Optional[Dict[str, str]] = None, + **kwargs, +) -> Union[HttpxBinaryResponseContent, Coroutine[Any, Any, HttpxBinaryResponseContent]]: + """ + Returns the contents of the specified file. + + LiteLLM Equivalent of POST: POST https://api.openai.com/v1/files + """ + try: + optional_params = GenericLiteLLMParams(**kwargs) + if custom_llm_provider == "openai": + # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there + api_base = ( + optional_params.api_base + or litellm.api_base + or os.getenv("OPENAI_API_BASE") + or "https://api.openai.com/v1" + ) + organization = ( + optional_params.organization + or litellm.organization + or os.getenv("OPENAI_ORGANIZATION", None) + or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105 + ) + # set API KEY + api_key = ( + optional_params.api_key + or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there + or litellm.openai_key + or os.getenv("OPENAI_API_KEY") + ) + ### TIMEOUT LOGIC ### + timeout = ( + optional_params.timeout or kwargs.get("request_timeout", 600) or 600 + ) + # set timeout for 10 minutes by default + + if ( + timeout is not None + and isinstance(timeout, httpx.Timeout) + and supports_httpx_timeout(custom_llm_provider) == False + ): + read_timeout = timeout.read or 600 + timeout = read_timeout # default 10 min timeout + elif timeout is not None and not isinstance(timeout, httpx.Timeout): + timeout = float(timeout) # type: ignore + elif timeout is None: + timeout = 600.0 + + _file_content_request = FileContentRequest( + file_id=file_id, + extra_headers=extra_headers, + extra_body=extra_body, + ) + + _is_async = kwargs.pop("afile_content", False) is True + + response = openai_files_instance.file_content( + _is_async=_is_async, + file_content_request=_file_content_request, + api_base=api_base, + api_key=api_key, + timeout=timeout, + max_retries=optional_params.max_retries, + 
organization=organization, + ) + else: + raise litellm.exceptions.BadRequestError( + message="LiteLLM doesn't support {} for 'create_batch'. Only 'openai' is supported.".format( + custom_llm_provider + ), + model="n/a", + llm_provider=custom_llm_provider, + response=httpx.Response( + status_code=400, + content="Unsupported provider", + request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore + ), + ) + return response + except Exception as e: + raise e + + +async def acreate_batch( + completion_window: Literal["24h"], + endpoint: Literal["/v1/chat/completions", "/v1/embeddings", "/v1/completions"], + input_file_id: str, + custom_llm_provider: Literal["openai"] = "openai", + metadata: Optional[Dict[str, str]] = None, + extra_headers: Optional[Dict[str, str]] = None, + extra_body: Optional[Dict[str, str]] = None, + **kwargs, +) -> Coroutine[Any, Any, Batch]: + """ + Async: Creates and executes a batch from an uploaded file of request + + LiteLLM Equivalent of POST: https://api.openai.com/v1/batches + """ + try: + loop = asyncio.get_event_loop() + kwargs["acreate_batch"] = True + + # Use a partial function to pass your keyword arguments + func = partial( + create_batch, + completion_window, + endpoint, + input_file_id, + custom_llm_provider, + metadata, + extra_headers, + extra_body, + **kwargs, + ) + + # Add the context to the function + ctx = contextvars.copy_context() + func_with_context = partial(ctx.run, func) + init_response = await loop.run_in_executor(None, func_with_context) + if asyncio.iscoroutine(init_response): + response = await init_response + else: + response = init_response # type: ignore + + return response + except Exception as e: + raise e + + +def create_batch( + completion_window: Literal["24h"], + endpoint: Literal["/v1/chat/completions", "/v1/embeddings", "/v1/completions"], + input_file_id: str, + custom_llm_provider: Literal["openai"] = "openai", + metadata: Optional[Dict[str, str]] = None, + extra_headers: Optional[Dict[str, str]] = None, + extra_body: Optional[Dict[str, str]] = None, + **kwargs, +) -> Union[Batch, Coroutine[Any, Any, Batch]]: + """ + Creates and executes a batch from an uploaded file of request + + LiteLLM Equivalent of POST: https://api.openai.com/v1/batches + """ + try: + optional_params = GenericLiteLLMParams(**kwargs) + if custom_llm_provider == "openai": + + # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there + api_base = ( + optional_params.api_base + or litellm.api_base + or os.getenv("OPENAI_API_BASE") + or "https://api.openai.com/v1" + ) + organization = ( + optional_params.organization + or litellm.organization + or os.getenv("OPENAI_ORGANIZATION", None) + or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105 + ) + # set API KEY + api_key = ( + optional_params.api_key + or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there + or litellm.openai_key + or os.getenv("OPENAI_API_KEY") + ) + ### TIMEOUT LOGIC ### + timeout = ( + optional_params.timeout or kwargs.get("request_timeout", 600) or 600 + ) + # set timeout for 10 minutes by default + + if ( + timeout is not None + and isinstance(timeout, httpx.Timeout) + and supports_httpx_timeout(custom_llm_provider) == False + ): + read_timeout = timeout.read or 600 + timeout = read_timeout # default 10 min timeout + elif timeout is not None and not 
isinstance(timeout, httpx.Timeout): + timeout = float(timeout) # type: ignore + elif timeout is None: + timeout = 600.0 + + _is_async = kwargs.pop("acreate_batch", False) is True + + _create_batch_request = CreateBatchRequest( + completion_window=completion_window, + endpoint=endpoint, + input_file_id=input_file_id, + metadata=metadata, + extra_headers=extra_headers, + extra_body=extra_body, + ) + + response = openai_batches_instance.create_batch( + api_base=api_base, + api_key=api_key, + organization=organization, + create_batch_data=_create_batch_request, + timeout=timeout, + max_retries=optional_params.max_retries, + _is_async=_is_async, + ) + else: + raise litellm.exceptions.BadRequestError( + message="LiteLLM doesn't support {} for 'create_batch'. Only 'openai' is supported.".format( + custom_llm_provider + ), + model="n/a", + llm_provider=custom_llm_provider, + response=httpx.Response( + status_code=400, + content="Unsupported provider", + request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore + ), + ) + return response + except Exception as e: + raise e + + +async def aretrieve_batch( + batch_id: str, + custom_llm_provider: Literal["openai"] = "openai", + metadata: Optional[Dict[str, str]] = None, + extra_headers: Optional[Dict[str, str]] = None, + extra_body: Optional[Dict[str, str]] = None, + **kwargs, +) -> Coroutine[Any, Any, Batch]: + """ + Async: Retrieves a batch. + + LiteLLM Equivalent of GET https://api.openai.com/v1/batches/{batch_id} + """ + try: + loop = asyncio.get_event_loop() + kwargs["aretrieve_batch"] = True + + # Use a partial function to pass your keyword arguments + func = partial( + retrieve_batch, + batch_id, + custom_llm_provider, + metadata, + extra_headers, + extra_body, + **kwargs, + ) + + # Add the context to the function + ctx = contextvars.copy_context() + func_with_context = partial(ctx.run, func) + init_response = await loop.run_in_executor(None, func_with_context) + if asyncio.iscoroutine(init_response): + response = await init_response + else: + response = init_response # type: ignore + + return response + except Exception as e: + raise e + + +def retrieve_batch( + batch_id: str, + custom_llm_provider: Literal["openai"] = "openai", + metadata: Optional[Dict[str, str]] = None, + extra_headers: Optional[Dict[str, str]] = None, + extra_body: Optional[Dict[str, str]] = None, + **kwargs, +) -> Union[Batch, Coroutine[Any, Any, Batch]]: + """ + Retrieves a batch. 
+ + LiteLLM Equivalent of GET https://api.openai.com/v1/batches/{batch_id} + """ + try: + optional_params = GenericLiteLLMParams(**kwargs) + if custom_llm_provider == "openai": + + # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there + api_base = ( + optional_params.api_base + or litellm.api_base + or os.getenv("OPENAI_API_BASE") + or "https://api.openai.com/v1" + ) + organization = ( + optional_params.organization + or litellm.organization + or os.getenv("OPENAI_ORGANIZATION", None) + or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105 + ) + # set API KEY + api_key = ( + optional_params.api_key + or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there + or litellm.openai_key + or os.getenv("OPENAI_API_KEY") + ) + ### TIMEOUT LOGIC ### + timeout = ( + optional_params.timeout or kwargs.get("request_timeout", 600) or 600 + ) + # set timeout for 10 minutes by default + + if ( + timeout is not None + and isinstance(timeout, httpx.Timeout) + and supports_httpx_timeout(custom_llm_provider) == False + ): + read_timeout = timeout.read or 600 + timeout = read_timeout # default 10 min timeout + elif timeout is not None and not isinstance(timeout, httpx.Timeout): + timeout = float(timeout) # type: ignore + elif timeout is None: + timeout = 600.0 + + _retrieve_batch_request = RetrieveBatchRequest( + batch_id=batch_id, + extra_headers=extra_headers, + extra_body=extra_body, + ) + + _is_async = kwargs.pop("aretrieve_batch", False) is True + + response = openai_batches_instance.retrieve_batch( + _is_async=_is_async, + retrieve_batch_data=_retrieve_batch_request, + api_base=api_base, + api_key=api_key, + organization=organization, + timeout=timeout, + max_retries=optional_params.max_retries, + ) + else: + raise litellm.exceptions.BadRequestError( + message="LiteLLM doesn't support {} for 'create_batch'. 
Only 'openai' is supported.".format( + custom_llm_provider + ), + model="n/a", + llm_provider=custom_llm_provider, + response=httpx.Response( + status_code=400, + content="Unsupported provider", + request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore + ), + ) + return response + except Exception as e: + raise e + + +def cancel_batch(): + pass + + +def list_batch(): + pass + + +async def acancel_batch(): + pass + + +async def alist_batch(): + pass diff --git a/litellm/llms/openai.py b/litellm/llms/openai.py index 05e6566ff..1a1dc4e6d 100644 --- a/litellm/llms/openai.py +++ b/litellm/llms/openai.py @@ -21,7 +21,7 @@ from litellm.utils import ( TranscriptionResponse, TextCompletionResponse, ) -from typing import Callable, Optional +from typing import Callable, Optional, Coroutine import litellm from .prompt_templates.factory import prompt_factory, custom_prompt from openai import OpenAI, AsyncOpenAI @@ -1497,6 +1497,322 @@ class OpenAITextCompletion(BaseLLM): yield transformed_chunk +class OpenAIFilesAPI(BaseLLM): + """ + OpenAI methods to support for batches + - create_file() + - retrieve_file() + - list_files() + - delete_file() + - file_content() + - update_file() + """ + + def __init__(self) -> None: + super().__init__() + + def get_openai_client( + self, + api_key: Optional[str], + api_base: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + organization: Optional[str], + client: Optional[Union[OpenAI, AsyncOpenAI]] = None, + _is_async: bool = False, + ) -> Optional[Union[OpenAI, AsyncOpenAI]]: + received_args = locals() + openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = None + if client is None: + data = {} + for k, v in received_args.items(): + if k == "self" or k == "client" or k == "_is_async": + pass + elif k == "api_base" and v is not None: + data["base_url"] = v + elif v is not None: + data[k] = v + if _is_async is True: + openai_client = AsyncOpenAI(**data) + else: + openai_client = OpenAI(**data) # type: ignore + else: + openai_client = client + + return openai_client + + async def acreate_file( + self, + create_file_data: CreateFileRequest, + openai_client: AsyncOpenAI, + ) -> FileObject: + response = await openai_client.files.create(**create_file_data) + return response + + def create_file( + self, + _is_async: bool, + create_file_data: CreateFileRequest, + api_base: str, + api_key: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + organization: Optional[str], + client: Optional[Union[OpenAI, AsyncOpenAI]] = None, + ) -> Union[FileObject, Coroutine[Any, Any, FileObject]]: + openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = self.get_openai_client( + api_key=api_key, + api_base=api_base, + timeout=timeout, + max_retries=max_retries, + organization=organization, + client=client, + _is_async=_is_async, + ) + if openai_client is None: + raise ValueError( + "OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment." + ) + + if _is_async is True: + if not isinstance(openai_client, AsyncOpenAI): + raise ValueError( + "OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client." 
+ ) + return self.acreate_file( # type: ignore + create_file_data=create_file_data, openai_client=openai_client + ) + response = openai_client.files.create(**create_file_data) + return response + + async def afile_content( + self, + file_content_request: FileContentRequest, + openai_client: AsyncOpenAI, + ) -> HttpxBinaryResponseContent: + response = await openai_client.files.content(**file_content_request) + return response + + def file_content( + self, + _is_async: bool, + file_content_request: FileContentRequest, + api_base: str, + api_key: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + organization: Optional[str], + client: Optional[Union[OpenAI, AsyncOpenAI]] = None, + ) -> Union[ + HttpxBinaryResponseContent, Coroutine[Any, Any, HttpxBinaryResponseContent] + ]: + openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = self.get_openai_client( + api_key=api_key, + api_base=api_base, + timeout=timeout, + max_retries=max_retries, + organization=organization, + client=client, + _is_async=_is_async, + ) + if openai_client is None: + raise ValueError( + "OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment." + ) + + if _is_async is True: + if not isinstance(openai_client, AsyncOpenAI): + raise ValueError( + "OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client." + ) + return self.afile_content( # type: ignore + file_content_request=file_content_request, + openai_client=openai_client, + ) + response = openai_client.files.content(**file_content_request) + + return response + + +class OpenAIBatchesAPI(BaseLLM): + """ + OpenAI methods to support for batches + - create_batch() + - retrieve_batch() + - cancel_batch() + - list_batch() + """ + + def __init__(self) -> None: + super().__init__() + + def get_openai_client( + self, + api_key: Optional[str], + api_base: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + organization: Optional[str], + client: Optional[Union[OpenAI, AsyncOpenAI]] = None, + _is_async: bool = False, + ) -> Optional[Union[OpenAI, AsyncOpenAI]]: + received_args = locals() + openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = None + if client is None: + data = {} + for k, v in received_args.items(): + if k == "self" or k == "client" or k == "_is_async": + pass + elif k == "api_base" and v is not None: + data["base_url"] = v + elif v is not None: + data[k] = v + if _is_async is True: + openai_client = AsyncOpenAI(**data) + else: + openai_client = OpenAI(**data) # type: ignore + else: + openai_client = client + + return openai_client + + async def acreate_batch( + self, + create_batch_data: CreateBatchRequest, + openai_client: AsyncOpenAI, + ) -> Batch: + response = await openai_client.batches.create(**create_batch_data) + return response + + def create_batch( + self, + _is_async: bool, + create_batch_data: CreateBatchRequest, + api_key: Optional[str], + api_base: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + organization: Optional[str], + client: Optional[Union[OpenAI, AsyncOpenAI]] = None, + ) -> Union[Batch, Coroutine[Any, Any, Batch]]: + openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = self.get_openai_client( + api_key=api_key, + api_base=api_base, + timeout=timeout, + max_retries=max_retries, + organization=organization, + client=client, + _is_async=_is_async, + ) + if openai_client is None: + raise ValueError( + "OpenAI client is not initialized. 
Make sure api_key is passed or OPENAI_API_KEY is set in the environment." + ) + + if _is_async is True: + if not isinstance(openai_client, AsyncOpenAI): + raise ValueError( + "OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client." + ) + return self.acreate_batch( # type: ignore + create_batch_data=create_batch_data, openai_client=openai_client + ) + response = openai_client.batches.create(**create_batch_data) + return response + + async def aretrieve_batch( + self, + retrieve_batch_data: RetrieveBatchRequest, + openai_client: AsyncOpenAI, + ) -> Batch: + response = await openai_client.batches.retrieve(**retrieve_batch_data) + return response + + def retrieve_batch( + self, + _is_async: bool, + retrieve_batch_data: RetrieveBatchRequest, + api_key: Optional[str], + api_base: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + organization: Optional[str], + client: Optional[OpenAI] = None, + ): + openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = self.get_openai_client( + api_key=api_key, + api_base=api_base, + timeout=timeout, + max_retries=max_retries, + organization=organization, + client=client, + _is_async=_is_async, + ) + if openai_client is None: + raise ValueError( + "OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment." + ) + + if _is_async is True: + if not isinstance(openai_client, AsyncOpenAI): + raise ValueError( + "OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client." + ) + return self.aretrieve_batch( # type: ignore + retrieve_batch_data=retrieve_batch_data, openai_client=openai_client + ) + response = openai_client.batches.retrieve(**retrieve_batch_data) + return response + + def cancel_batch( + self, + _is_async: bool, + cancel_batch_data: CancelBatchRequest, + api_key: Optional[str], + api_base: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + organization: Optional[str], + client: Optional[OpenAI] = None, + ): + openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = self.get_openai_client( + api_key=api_key, + api_base=api_base, + timeout=timeout, + max_retries=max_retries, + organization=organization, + client=client, + _is_async=_is_async, + ) + if openai_client is None: + raise ValueError( + "OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment." 
+ ) + response = openai_client.batches.cancel(**cancel_batch_data) + return response + + # def list_batch( + # self, + # list_batch_data: ListBatchRequest, + # api_key: Optional[str], + # api_base: Optional[str], + # timeout: Union[float, httpx.Timeout], + # max_retries: Optional[int], + # organization: Optional[str], + # client: Optional[OpenAI] = None, + # ): + # openai_client: OpenAI = self.get_openai_client( + # api_key=api_key, + # api_base=api_base, + # timeout=timeout, + # max_retries=max_retries, + # organization=organization, + # client=client, + # ) + # response = openai_client.batches.list(**list_batch_data) + # return response + + class OpenAIAssistantsAPI(BaseLLM): def __init__(self) -> None: super().__init__() diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 1b97c6836..07812a756 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -99,6 +99,14 @@ class LiteLLMRoutes(enum.Enum): # moderations "/moderations", "/v1/moderations", + # batches + "/v1/batches", + "/batches", + "/v1/batches{batch_id}", + "/batches{batch_id}", + # files + "/v1/files", + "/files", # models "/models", "/v1/models", @@ -1215,6 +1223,7 @@ class InvitationModel(LiteLLMBase): updated_at: datetime updated_by: str + class ConfigFieldInfo(LiteLLMBase): field_name: str field_value: Any diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 083452089..f228b1605 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -100,6 +100,13 @@ from litellm.proxy.utils import ( encrypt_value, decrypt_value, ) +from litellm import ( + CreateBatchRequest, + RetrieveBatchRequest, + ListBatchRequest, + CancelBatchRequest, + CreateFileRequest, +) from litellm.proxy.secret_managers.google_kms import load_google_kms from litellm.proxy.secret_managers.aws_secret_manager import load_aws_secret_manager import pydantic @@ -142,6 +149,7 @@ from fastapi import ( Request, HTTPException, status, + Path, Depends, Header, Response, @@ -499,7 +507,7 @@ async def user_api_key_auth( if route in LiteLLMRoutes.public_routes.value: # check if public endpoint - return UserAPIKeyAuth() + return UserAPIKeyAuth(user_role="app_owner") if general_settings.get("enable_jwt_auth", False) == True: is_jwt = jwt_handler.is_jwt(token=api_key) @@ -1391,7 +1399,9 @@ async def user_api_key_auth( api_key=api_key, user_role="app_owner", **valid_token_dict ) else: - return UserAPIKeyAuth(api_key=api_key, **valid_token_dict) + return UserAPIKeyAuth( + api_key=api_key, user_role="app_owner", **valid_token_dict + ) else: raise Exception() except Exception as e: @@ -5042,6 +5052,447 @@ async def audio_transcriptions( ) +###################################################################### + +# /v1/batches Endpoints + + +###################################################################### +@router.post( + "/v1/batches", + dependencies=[Depends(user_api_key_auth)], + tags=["batch"], +) +@router.post( + "/batches", + dependencies=[Depends(user_api_key_auth)], + tags=["batch"], +) +async def create_batch( + request: Request, + fastapi_response: Response, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + """ + Create large batches of API requests for asynchronous processing. 
+ This is the equivalent of POST https://api.openai.com/v1/batch + Supports Identical Params as: https://platform.openai.com/docs/api-reference/batch + + Example Curl + ``` + curl http://localhost:4000/v1/batches \ + -H "Authorization: Bearer sk-1234" \ + -H "Content-Type: application/json" \ + -d '{ + "input_file_id": "file-abc123", + "endpoint": "/v1/chat/completions", + "completion_window": "24h" + }' + ``` + """ + global proxy_logging_obj + data: Dict = {} + try: + # Use orjson to parse JSON data, orjson speeds up requests significantly + form_data = await request.form() + data = {key: value for key, value in form_data.items() if key != "file"} + + # Include original request and headers in the data + data["proxy_server_request"] = { # type: ignore + "url": str(request.url), + "method": request.method, + "headers": dict(request.headers), + "body": copy.copy(data), # use copy instead of deepcopy + } + + if data.get("user", None) is None and user_api_key_dict.user_id is not None: + data["user"] = user_api_key_dict.user_id + + if "metadata" not in data: + data["metadata"] = {} + data["metadata"]["user_api_key"] = user_api_key_dict.api_key + data["metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata + _headers = dict(request.headers) + _headers.pop( + "authorization", None + ) # do not store the original `sk-..` api key in the db + data["metadata"]["headers"] = _headers + data["metadata"]["user_api_key_alias"] = getattr( + user_api_key_dict, "key_alias", None + ) + data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id + data["metadata"]["user_api_key_team_id"] = getattr( + user_api_key_dict, "team_id", None + ) + data["metadata"]["global_max_parallel_requests"] = general_settings.get( + "global_max_parallel_requests", None + ) + data["metadata"]["user_api_key_team_alias"] = getattr( + user_api_key_dict, "team_alias", None + ) + data["metadata"]["endpoint"] = str(request.url) + + ### TEAM-SPECIFIC PARAMS ### + if user_api_key_dict.team_id is not None: + team_config = await proxy_config.load_team_config( + team_id=user_api_key_dict.team_id + ) + if len(team_config) == 0: + pass + else: + team_id = team_config.pop("team_id", None) + data["metadata"]["team_id"] = team_id + data = { + **team_config, + **data, + } # add the team-specific configs to the completion call + + _create_batch_data = CreateBatchRequest(**data) + + # for now use custom_llm_provider=="openai" -> this will change as LiteLLM adds more providers for acreate_batch + response = await litellm.acreate_batch( + custom_llm_provider="openai", **_create_batch_data + ) + + ### ALERTING ### + data["litellm_status"] = "success" # used for alerting + + ### RESPONSE HEADERS ### + hidden_params = getattr(response, "_hidden_params", {}) or {} + model_id = hidden_params.get("model_id", None) or "" + cache_key = hidden_params.get("cache_key", None) or "" + api_base = hidden_params.get("api_base", None) or "" + + fastapi_response.headers.update( + get_custom_headers( + user_api_key_dict=user_api_key_dict, + model_id=model_id, + cache_key=cache_key, + api_base=api_base, + version=version, + model_region=getattr(user_api_key_dict, "allowed_model_region", ""), + ) + ) + + return response + except Exception as e: + data["litellm_status"] = "fail" # used for alerting + await proxy_logging_obj.post_call_failure_hook( + user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data + ) + traceback.print_exc() + if isinstance(e, HTTPException): + raise ProxyException( + message=getattr(e, "message", 
str(e.detail)), + type=getattr(e, "type", "None"), + param=getattr(e, "param", "None"), + code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST), + ) + else: + error_traceback = traceback.format_exc() + error_msg = f"{str(e)}" + raise ProxyException( + message=getattr(e, "message", error_msg), + type=getattr(e, "type", "None"), + param=getattr(e, "param", "None"), + code=getattr(e, "status_code", 500), + ) + + +@router.get( + "/v1/batches{batch_id}", + dependencies=[Depends(user_api_key_auth)], + tags=["Batch"], +) +@router.get( + "/batches{batch_id}", + dependencies=[Depends(user_api_key_auth)], + tags=["Batch"], +) +async def retrieve_batch( + request: Request, + fastapi_response: Response, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), + batch_id: str = Path( + title="Batch ID to retrieve", description="The ID of the batch to retrieve" + ), +): + """ + Retrieves a batch. + This is the equivalent of GET https://api.openai.com/v1/batches/{batch_id} + Supports Identical Params as: https://platform.openai.com/docs/api-reference/batch/retrieve + + Example Curl + ``` + curl http://localhost:4000/v1/batches/batch_abc123 \ + -H "Authorization: Bearer sk-1234" \ + -H "Content-Type: application/json" \ + + ``` + """ + global proxy_logging_obj + data: Dict = {} + try: + # Use orjson to parse JSON data, orjson speeds up requests significantly + form_data = await request.form() + data = {key: value for key, value in form_data.items() if key != "file"} + + # Include original request and headers in the data + data["proxy_server_request"] = { # type: ignore + "url": str(request.url), + "method": request.method, + "headers": dict(request.headers), + "body": copy.copy(data), # use copy instead of deepcopy + } + + if data.get("user", None) is None and user_api_key_dict.user_id is not None: + data["user"] = user_api_key_dict.user_id + + if "metadata" not in data: + data["metadata"] = {} + data["metadata"]["user_api_key"] = user_api_key_dict.api_key + data["metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata + _headers = dict(request.headers) + _headers.pop( + "authorization", None + ) # do not store the original `sk-..` api key in the db + data["metadata"]["headers"] = _headers + data["metadata"]["user_api_key_alias"] = getattr( + user_api_key_dict, "key_alias", None + ) + data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id + data["metadata"]["user_api_key_team_id"] = getattr( + user_api_key_dict, "team_id", None + ) + data["metadata"]["global_max_parallel_requests"] = general_settings.get( + "global_max_parallel_requests", None + ) + data["metadata"]["user_api_key_team_alias"] = getattr( + user_api_key_dict, "team_alias", None + ) + data["metadata"]["endpoint"] = str(request.url) + + ### TEAM-SPECIFIC PARAMS ### + if user_api_key_dict.team_id is not None: + team_config = await proxy_config.load_team_config( + team_id=user_api_key_dict.team_id + ) + if len(team_config) == 0: + pass + else: + team_id = team_config.pop("team_id", None) + data["metadata"]["team_id"] = team_id + data = { + **team_config, + **data, + } # add the team-specific configs to the completion call + + _retrieve_batch_request = RetrieveBatchRequest( + batch_id=batch_id, + ) + + # for now use custom_llm_provider=="openai" -> this will change as LiteLLM adds more providers for acreate_batch + response = await litellm.aretrieve_batch( + custom_llm_provider="openai", **_retrieve_batch_request + ) + + ### ALERTING ### + data["litellm_status"] = "success" # used for alerting + + ### 
RESPONSE HEADERS ### + hidden_params = getattr(response, "_hidden_params", {}) or {} + model_id = hidden_params.get("model_id", None) or "" + cache_key = hidden_params.get("cache_key", None) or "" + api_base = hidden_params.get("api_base", None) or "" + + fastapi_response.headers.update( + get_custom_headers( + user_api_key_dict=user_api_key_dict, + model_id=model_id, + cache_key=cache_key, + api_base=api_base, + version=version, + model_region=getattr(user_api_key_dict, "allowed_model_region", ""), + ) + ) + + return response + except Exception as e: + data["litellm_status"] = "fail" # used for alerting + await proxy_logging_obj.post_call_failure_hook( + user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data + ) + traceback.print_exc() + if isinstance(e, HTTPException): + raise ProxyException( + message=getattr(e, "message", str(e.detail)), + type=getattr(e, "type", "None"), + param=getattr(e, "param", "None"), + code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST), + ) + else: + error_traceback = traceback.format_exc() + error_msg = f"{str(e)}" + raise ProxyException( + message=getattr(e, "message", error_msg), + type=getattr(e, "type", "None"), + param=getattr(e, "param", "None"), + code=getattr(e, "status_code", 500), + ) + + +###################################################################### + +# END OF /v1/batches Endpoints Implementation + +###################################################################### + + +###################################################################### + +# /v1/files Endpoints + + +###################################################################### +@router.post( + "/v1/files", + dependencies=[Depends(user_api_key_auth)], + tags=["files"], +) +@router.post( + "/files", + dependencies=[Depends(user_api_key_auth)], + tags=["files"], +) +async def create_file( + request: Request, + fastapi_response: Response, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + """ + Upload a file that can be used across - Assistants API, Batch API + This is the equivalent of POST https://api.openai.com/v1/files + + Supports Identical Params as: https://platform.openai.com/docs/api-reference/files/create + + Example Curl + ``` + curl https://api.openai.com/v1/files \ + -H "Authorization: Bearer sk-1234" \ + -F purpose="batch" \ + -F file="@mydata.jsonl" + + ``` + """ + global proxy_logging_obj + data: Dict = {} + try: + # Use orjson to parse JSON data, orjson speeds up requests significantly + form_data = await request.form() + data = {key: value for key, value in form_data.items() if key != "file"} + + # Include original request and headers in the data + data["proxy_server_request"] = { # type: ignore + "url": str(request.url), + "method": request.method, + "headers": dict(request.headers), + "body": copy.copy(data), # use copy instead of deepcopy + } + + if data.get("user", None) is None and user_api_key_dict.user_id is not None: + data["user"] = user_api_key_dict.user_id + + if "metadata" not in data: + data["metadata"] = {} + data["metadata"]["user_api_key"] = user_api_key_dict.api_key + data["metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata + _headers = dict(request.headers) + _headers.pop( + "authorization", None + ) # do not store the original `sk-..` api key in the db + data["metadata"]["headers"] = _headers + data["metadata"]["user_api_key_alias"] = getattr( + user_api_key_dict, "key_alias", None + ) + data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id + 
data["metadata"]["user_api_key_team_id"] = getattr( + user_api_key_dict, "team_id", None + ) + data["metadata"]["global_max_parallel_requests"] = general_settings.get( + "global_max_parallel_requests", None + ) + data["metadata"]["user_api_key_team_alias"] = getattr( + user_api_key_dict, "team_alias", None + ) + data["metadata"]["endpoint"] = str(request.url) + + ### TEAM-SPECIFIC PARAMS ### + if user_api_key_dict.team_id is not None: + team_config = await proxy_config.load_team_config( + team_id=user_api_key_dict.team_id + ) + if len(team_config) == 0: + pass + else: + team_id = team_config.pop("team_id", None) + data["metadata"]["team_id"] = team_id + data = { + **team_config, + **data, + } # add the team-specific configs to the completion call + + _create_file_request = CreateFileRequest() + + # for now use custom_llm_provider=="openai" -> this will change as LiteLLM adds more providers for acreate_batch + response = await litellm.acreate_file( + custom_llm_provider="openai", **_create_file_request + ) + + ### ALERTING ### + data["litellm_status"] = "success" # used for alerting + + ### RESPONSE HEADERS ### + hidden_params = getattr(response, "_hidden_params", {}) or {} + model_id = hidden_params.get("model_id", None) or "" + cache_key = hidden_params.get("cache_key", None) or "" + api_base = hidden_params.get("api_base", None) or "" + + fastapi_response.headers.update( + get_custom_headers( + user_api_key_dict=user_api_key_dict, + model_id=model_id, + cache_key=cache_key, + api_base=api_base, + version=version, + model_region=getattr(user_api_key_dict, "allowed_model_region", ""), + ) + ) + + return response + except Exception as e: + data["litellm_status"] = "fail" # used for alerting + await proxy_logging_obj.post_call_failure_hook( + user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data + ) + traceback.print_exc() + if isinstance(e, HTTPException): + raise ProxyException( + message=getattr(e, "message", str(e.detail)), + type=getattr(e, "type", "None"), + param=getattr(e, "param", "None"), + code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST), + ) + else: + error_traceback = traceback.format_exc() + error_msg = f"{str(e)}" + raise ProxyException( + message=getattr(e, "message", error_msg), + type=getattr(e, "type", "None"), + param=getattr(e, "param", "None"), + code=getattr(e, "status_code", 500), + ) + + @router.post( "/v1/moderations", dependencies=[Depends(user_api_key_auth)], @@ -9596,28 +10047,54 @@ async def model_streaming_metrics( startTime = startTime or datetime.now() - timedelta(days=7) # show over past week endTime = endTime or datetime.now() - sql_query = """ - SELECT - api_base, - model_group, - model, - DATE_TRUNC('day', "startTime")::DATE AS day, - AVG(EXTRACT(epoch FROM ("completionStartTime" - "startTime"))) AS time_to_first_token - FROM - "LiteLLM_SpendLogs" - WHERE - "startTime" BETWEEN $2::timestamp AND $3::timestamp - AND "model_group" = $1 AND "cache_hit" != 'True' - AND "completionStartTime" IS NOT NULL - AND "completionStartTime" != "endTime" - GROUP BY - api_base, - model_group, - model, - day - ORDER BY - time_to_first_token DESC; - """ + is_same_day = startTime.date() == endTime.date() + if is_same_day: + sql_query = """ + SELECT + api_base, + model_group, + model, + "startTime", + request_id, + EXTRACT(epoch FROM ("completionStartTime" - "startTime")) AS time_to_first_token + FROM + "LiteLLM_SpendLogs" + WHERE + "model_group" = $1 AND "cache_hit" != 'True' + AND "completionStartTime" IS NOT NULL + AND "completionStartTime" 
!= "endTime" + AND DATE("startTime") = DATE($2::timestamp) + GROUP BY + api_base, + model_group, + model, + request_id + ORDER BY + time_to_first_token DESC; + """ + else: + sql_query = """ + SELECT + api_base, + model_group, + model, + DATE_TRUNC('day', "startTime")::DATE AS day, + AVG(EXTRACT(epoch FROM ("completionStartTime" - "startTime"))) AS time_to_first_token + FROM + "LiteLLM_SpendLogs" + WHERE + "startTime" BETWEEN $2::timestamp AND $3::timestamp + AND "model_group" = $1 AND "cache_hit" != 'True' + AND "completionStartTime" IS NOT NULL + AND "completionStartTime" != "endTime" + GROUP BY + api_base, + model_group, + model, + day + ORDER BY + time_to_first_token DESC; + """ _all_api_bases = set() db_response = await prisma_client.db.query_raw( @@ -9628,10 +10105,19 @@ async def model_streaming_metrics( for model_data in db_response: _api_base = model_data["api_base"] _model = model_data["model"] - _day = model_data["day"] time_to_first_token = model_data["time_to_first_token"] - if _day not in _daily_entries: - _daily_entries[_day] = {} + unique_key = "" + if is_same_day: + _request_id = model_data["request_id"] + unique_key = _request_id + if _request_id not in _daily_entries: + _daily_entries[_request_id] = {} + else: + _day = model_data["day"] + unique_key = _day + time_to_first_token = model_data["time_to_first_token"] + if _day not in _daily_entries: + _daily_entries[_day] = {} _combined_model_name = str(_model) if "https://" in _api_base: _combined_model_name = str(_api_base) @@ -9639,7 +10125,8 @@ async def model_streaming_metrics( _combined_model_name = _combined_model_name.split("/openai/")[0] _all_api_bases.add(_combined_model_name) - _daily_entries[_day][_combined_model_name] = time_to_first_token + + _daily_entries[unique_key][_combined_model_name] = time_to_first_token """ each entry needs to be like this: diff --git a/litellm/router.py b/litellm/router.py index 3715ec26c..c79dc2de0 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -663,12 +663,40 @@ class Router: raise e async def abatch_completion( - self, models: List[str], messages: List[Dict[str, str]], **kwargs + self, + models: List[str], + messages: Union[List[Dict[str, str]], List[List[Dict[str, str]]]], + **kwargs, ): """ - Async Batch Completion - Batch Process 1 request to multiple model_group on litellm.Router - Use this for sending the same request to N models + Async Batch Completion. Used for 2 scenarios: + 1. Batch Process 1 request to N models on litellm.Router. Pass messages as List[Dict[str, str]] to use this + 2. Batch Process N requests to M models on litellm.Router. 
Pass messages as List[List[Dict[str, str]]] to use this + + Example Request for 1 request to N models: + ``` + response = await router.abatch_completion( + models=["gpt-3.5-turbo", "groq-llama"], + messages=[ + {"role": "user", "content": "is litellm becoming a better product ?"} + ], + max_tokens=15, + ) + ``` + + + Example Request for N requests to M models: + ``` + response = await router.abatch_completion( + models=["gpt-3.5-turbo", "groq-llama"], + messages=[ + [{"role": "user", "content": "is litellm becoming a better product ?"}], + [{"role": "user", "content": "who is this"}], + ], + ) + ``` """ + ############## Helpers for async completion ################## async def _async_completion_no_exceptions( model: str, messages: List[Dict[str, str]], **kwargs @@ -681,17 +709,50 @@ class Router: except Exception as e: return e - _tasks = [] - for model in models: - # add each task but if the task fails - _tasks.append( - _async_completion_no_exceptions( - model=model, messages=messages, **kwargs + async def _async_completion_no_exceptions_return_idx( + model: str, + messages: List[Dict[str, str]], + idx: int, # index of message this response corresponds to + **kwargs, + ): + """ + Wrapper around self.async_completion that catches exceptions and returns them as a result + """ + try: + return ( + await self.acompletion(model=model, messages=messages, **kwargs), + idx, ) - ) + except Exception as e: + return e, idx - response = await asyncio.gather(*_tasks) - return response + ############## Helpers for async completion ################## + + if isinstance(messages, list) and all(isinstance(m, dict) for m in messages): + _tasks = [] + for model in models: + # add each task but if the task fails + _tasks.append(_async_completion_no_exceptions(model=model, messages=messages, **kwargs)) # type: ignore + response = await asyncio.gather(*_tasks) + return response + elif isinstance(messages, list) and all(isinstance(m, list) for m in messages): + _tasks = [] + for idx, message in enumerate(messages): + for model in models: + # Request Number X, Model Number Y + _tasks.append( + _async_completion_no_exceptions_return_idx( + model=model, idx=idx, messages=message, **kwargs # type: ignore + ) + ) + responses = await asyncio.gather(*_tasks) + final_responses: List[List[Any]] = [[] for _ in range(len(messages))] + for response in responses: + if isinstance(response, tuple): + final_responses[response[1]].append(response[0]) + else: + final_responses[0].append(response) + return final_responses async def abatch_completion_one_model_multiple_requests( self, model: str, messages: List[List[Dict[str, str]]], **kwargs diff --git a/litellm/tests/openai_batch_completions.jsonl b/litellm/tests/openai_batch_completions.jsonl new file mode 100644 index 000000000..05448952a --- /dev/null +++ b/litellm/tests/openai_batch_completions.jsonl @@ -0,0 +1,2 @@ +{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-3.5-turbo-0125", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 10}} +{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-3.5-turbo-0125", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 10}} \ No newline at end of file diff --git a/litellm/tests/test_openai_batches.py b/litellm/tests/test_openai_batches.py new file mode 100644 index 
000000000..d7e3e1809 --- /dev/null +++ b/litellm/tests/test_openai_batches.py @@ -0,0 +1,161 @@ +# What is this? +## Unit Tests for OpenAI Batches API +import sys, os, json +import traceback +import asyncio +from dotenv import load_dotenv + +load_dotenv() +sys.path.insert( + 0, os.path.abspath("../..") +) # Adds the parent directory to the system path +import pytest, logging, asyncio +import litellm +from litellm import ( + create_batch, + create_file, +) +import time + + +def test_create_batch(): + """ + 1. Create File for Batch completion + 2. Create Batch Request + 3. Retrieve the specific batch + """ + file_name = "openai_batch_completions.jsonl" + _current_dir = os.path.dirname(os.path.abspath(__file__)) + file_path = os.path.join(_current_dir, file_name) + + file_obj = litellm.create_file( + file=open(file_path, "rb"), + purpose="batch", + custom_llm_provider="openai", + ) + print("Response from creating file=", file_obj) + + batch_input_file_id = file_obj.id + assert ( + batch_input_file_id is not None + ), "Failed to create file, expected a non null file_id but got {batch_input_file_id}" + + create_batch_response = litellm.create_batch( + completion_window="24h", + endpoint="/v1/chat/completions", + input_file_id=batch_input_file_id, + custom_llm_provider="openai", + metadata={"key1": "value1", "key2": "value2"}, + ) + + print("response from litellm.create_batch=", create_batch_response) + + assert ( + create_batch_response.id is not None + ), f"Failed to create batch, expected a non null batch_id but got {create_batch_response.id}" + assert ( + create_batch_response.endpoint == "/v1/chat/completions" + ), f"Failed to create batch, expected endpoint to be /v1/chat/completions but got {create_batch_response.endpoint}" + assert ( + create_batch_response.input_file_id == batch_input_file_id + ), f"Failed to create batch, expected input_file_id to be {batch_input_file_id} but got {create_batch_response.input_file_id}" + + retrieved_batch = litellm.retrieve_batch( + batch_id=create_batch_response.id, custom_llm_provider="openai" + ) + print("retrieved batch=", retrieved_batch) + # just assert that we retrieved a non None batch + + assert retrieved_batch.id == create_batch_response.id + + file_content = litellm.file_content( + file_id=batch_input_file_id, custom_llm_provider="openai" + ) + + result = file_content.content + + result_file_name = "batch_job_results_furniture.jsonl" + + with open(result_file_name, "wb") as file: + file.write(result) + + pass + + +@pytest.mark.asyncio() +async def test_async_create_batch(): + """ + 1. Create File for Batch completion + 2. Create Batch Request + 3. 
Retrieve the specific batch + """ + print("Testing async create batch") + + file_name = "openai_batch_completions.jsonl" + _current_dir = os.path.dirname(os.path.abspath(__file__)) + file_path = os.path.join(_current_dir, file_name) + file_obj = await litellm.acreate_file( + file=open(file_path, "rb"), + purpose="batch", + custom_llm_provider="openai", + ) + print("Response from creating file=", file_obj) + + batch_input_file_id = file_obj.id + assert ( + batch_input_file_id is not None + ), "Failed to create file, expected a non null file_id but got {batch_input_file_id}" + + create_batch_response = await litellm.acreate_batch( + completion_window="24h", + endpoint="/v1/chat/completions", + input_file_id=batch_input_file_id, + custom_llm_provider="openai", + metadata={"key1": "value1", "key2": "value2"}, + ) + + print("response from litellm.create_batch=", create_batch_response) + + assert ( + create_batch_response.id is not None + ), f"Failed to create batch, expected a non null batch_id but got {create_batch_response.id}" + assert ( + create_batch_response.endpoint == "/v1/chat/completions" + ), f"Failed to create batch, expected endpoint to be /v1/chat/completions but got {create_batch_response.endpoint}" + assert ( + create_batch_response.input_file_id == batch_input_file_id + ), f"Failed to create batch, expected input_file_id to be {batch_input_file_id} but got {create_batch_response.input_file_id}" + + await asyncio.sleep(1) + + retrieved_batch = await litellm.aretrieve_batch( + batch_id=create_batch_response.id, custom_llm_provider="openai" + ) + print("retrieved batch=", retrieved_batch) + # just assert that we retrieved a non None batch + + assert retrieved_batch.id == create_batch_response.id + + # try to get file content for our original file + + file_content = await litellm.afile_content( + file_id=batch_input_file_id, custom_llm_provider="openai" + ) + + print("file content = ", file_content) + + # # write this file content to a file + # with open("file_content.json", "w") as f: + # json.dump(file_content, f) + + +def test_retrieve_batch(): + pass + + +def test_cancel_batch(): + pass + + +def test_list_batch(): + pass diff --git a/litellm/types/llms/openai.py b/litellm/types/llms/openai.py index 1c60ad6db..77791b8ec 100644 --- a/litellm/types/llms/openai.py +++ b/litellm/types/llms/openai.py @@ -6,7 +6,7 @@ from typing import ( Literal, Iterable, ) -from typing_extensions import override, Required +from typing_extensions import override, Required, Dict from pydantic import BaseModel from openai.types.beta.threads.message_content import MessageContent @@ -18,8 +18,24 @@ from openai.types.beta.assistant_tool_param import AssistantToolParam from openai.types.beta.threads.run import Run from openai.types.beta.assistant import Assistant from openai.pagination import SyncCursorPage +from os import PathLike +from openai.types import FileObject, Batch +from openai._legacy_response import HttpxBinaryResponseContent -from typing import TypedDict, List, Optional +from typing import TypedDict, List, Optional, Tuple, Mapping, IO + +FileContent = Union[IO[bytes], bytes, PathLike] + +FileTypes = Union[ + # file (or bytes) + FileContent, + # (filename, file (or bytes)) + Tuple[Optional[str], FileContent], + # (filename, file (or bytes), content_type) + Tuple[Optional[str], FileContent, Optional[str]], + # (filename, file (or bytes), content_type, headers) + Tuple[Optional[str], FileContent, Optional[str], Mapping[str, str]], +] class NotGiven: @@ -146,3 +162,96 @@ class Thread(BaseModel): 
     object: Literal["thread"]
     """The object type, which is always `thread`."""
+
+
+# OpenAI Files Types
+class CreateFileRequest(TypedDict, total=False):
+    """
+    CreateFileRequest
+    Used by Assistants API, Batches API, and Fine-Tunes API
+
+    Required Params:
+        file: FileTypes
+        purpose: Literal['assistants', 'batch', 'fine-tune']
+
+    Optional Params:
+        extra_headers: Optional[Dict[str, str]]
+        extra_body: Optional[Dict[str, str]] = None
+        timeout: Optional[float] = None
+    """
+
+    file: FileTypes
+    purpose: Literal["assistants", "batch", "fine-tune"]
+    extra_headers: Optional[Dict[str, str]]
+    extra_body: Optional[Dict[str, str]]
+    timeout: Optional[float]
+
+
+class FileContentRequest(TypedDict, total=False):
+    """
+    FileContentRequest
+    Used by Assistants API, Batches API, and Fine-Tunes API
+
+    Required Params:
+        file_id: str
+
+    Optional Params:
+        extra_headers: Optional[Dict[str, str]]
+        extra_body: Optional[Dict[str, str]] = None
+        timeout: Optional[float] = None
+    """
+
+    file_id: str
+    extra_headers: Optional[Dict[str, str]]
+    extra_body: Optional[Dict[str, str]]
+    timeout: Optional[float]
+
+
+# OpenAI Batches Types
+class CreateBatchRequest(TypedDict, total=False):
+    """
+    CreateBatchRequest
+    """
+
+    completion_window: Literal["24h"]
+    endpoint: Literal["/v1/chat/completions", "/v1/embeddings", "/v1/completions"]
+    input_file_id: str
+    metadata: Optional[Dict[str, str]]
+    extra_headers: Optional[Dict[str, str]]
+    extra_body: Optional[Dict[str, str]]
+    timeout: Optional[float]
+
+
+class RetrieveBatchRequest(TypedDict, total=False):
+    """
+    RetrieveBatchRequest
+    """
+
+    batch_id: str
+    extra_headers: Optional[Dict[str, str]]
+    extra_body: Optional[Dict[str, str]]
+    timeout: Optional[float]
+
+
+class CancelBatchRequest(TypedDict, total=False):
+    """
+    CancelBatchRequest
+    """
+
+    batch_id: str
+    extra_headers: Optional[Dict[str, str]]
+    extra_body: Optional[Dict[str, str]]
+    timeout: Optional[float]
+
+
+class ListBatchRequest(TypedDict, total=False):
+    """
+    ListBatchRequest - List your organization's batches
+    Calls https://api.openai.com/v1/batches
+    """
+
+    after: Union[str, NotGiven]
+    limit: Union[int, NotGiven]
+    extra_headers: Optional[Dict[str, str]]
+    extra_body: Optional[Dict[str, str]]
+    timeout: Optional[float]
diff --git a/pyproject.toml b/pyproject.toml
index 0fb6b3269..d124ea4a1 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "litellm"
-version = "1.39.0"
+version = "1.39.1"
 description = "Library to easily interface with LLM API providers"
 authors = ["BerriAI"]
 license = "MIT"
@@ -79,7 +79,7 @@ requires = ["poetry-core", "wheel"]
 build-backend = "poetry.core.masonry.api"
 
 [tool.commitizen]
-version = "1.39.0"
+version = "1.39.1"
 version_files = [
     "pyproject.toml:^version"
 ]
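For reference, a minimal end-to-end sketch of the interface this diff introduces, adapted from litellm/tests/test_openai_batches.py above; it assumes a valid OPENAI_API_KEY and a local openai_batch_completions.jsonl like the one added under litellm/tests/:
```
import litellm

# 1. Upload the .jsonl file of batch requests
file_obj = litellm.create_file(
    file=open("openai_batch_completions.jsonl", "rb"),
    purpose="batch",
    custom_llm_provider="openai",
)

# 2. Create a batch against /v1/chat/completions using the uploaded file
batch = litellm.create_batch(
    completion_window="24h",
    endpoint="/v1/chat/completions",
    input_file_id=file_obj.id,
    custom_llm_provider="openai",
    metadata={"key1": "value1"},
)

# 3. Retrieve the batch to check its status
retrieved = litellm.retrieve_batch(batch_id=batch.id, custom_llm_provider="openai")
print(retrieved.id, retrieved.status)
```
The async variants (acreate_file, acreate_batch, aretrieve_batch) take the same arguments.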