mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 19:24:27 +00:00
feat(batches): add azure openai batches endpoint support
Closes https://github.com/BerriAI/litellm/issues/5073
This commit is contained in:
parent
2f9f01e72c
commit
a9b5d5271f
7 changed files with 584 additions and 173 deletions
|
@ -117,7 +117,7 @@ disable_streaming_logging: bool = False
|
||||||
in_memory_llm_clients_cache: dict = {}
|
in_memory_llm_clients_cache: dict = {}
|
||||||
safe_memory_mode: bool = False
|
safe_memory_mode: bool = False
|
||||||
### DEFAULT AZURE API VERSION ###
|
### DEFAULT AZURE API VERSION ###
|
||||||
AZURE_DEFAULT_API_VERSION = "2024-02-01" # this is updated to the latest
|
AZURE_DEFAULT_API_VERSION = "2024-07-01-preview" # this is updated to the latest
|
||||||
### GUARDRAILS ###
|
### GUARDRAILS ###
|
||||||
llamaguard_model_name: Optional[str] = None
|
llamaguard_model_name: Optional[str] = None
|
||||||
openai_moderations_model_name: Optional[str] = None
|
openai_moderations_model_name: Optional[str] = None
|
||||||
|
|
|
@ -20,7 +20,8 @@ import httpx
|
||||||
|
|
||||||
import litellm
|
import litellm
|
||||||
from litellm import client
|
from litellm import client
|
||||||
from litellm.llms.openai import OpenAIBatchesAPI, OpenAIFilesAPI
|
from litellm.llms.azure import AzureBatchesAPI
|
||||||
|
from litellm.llms.openai import OpenAIBatchesAPI
|
||||||
from litellm.types.llms.openai import (
|
from litellm.types.llms.openai import (
|
||||||
Batch,
|
Batch,
|
||||||
CancelBatchRequest,
|
CancelBatchRequest,
|
||||||
|
@ -33,10 +34,11 @@ from litellm.types.llms.openai import (
|
||||||
RetrieveBatchRequest,
|
RetrieveBatchRequest,
|
||||||
)
|
)
|
||||||
from litellm.types.router import GenericLiteLLMParams
|
from litellm.types.router import GenericLiteLLMParams
|
||||||
from litellm.utils import supports_httpx_timeout
|
from litellm.utils import get_secret, supports_httpx_timeout
|
||||||
|
|
||||||
####### ENVIRONMENT VARIABLES ###################
|
####### ENVIRONMENT VARIABLES ###################
|
||||||
openai_batches_instance = OpenAIBatchesAPI()
|
openai_batches_instance = OpenAIBatchesAPI()
|
||||||
|
azure_batches_instance = AzureBatchesAPI()
|
||||||
#################################################
|
#################################################
|
||||||
|
|
||||||
|
|
||||||
|
@ -90,7 +92,7 @@ def create_batch(
|
||||||
completion_window: Literal["24h"],
|
completion_window: Literal["24h"],
|
||||||
endpoint: Literal["/v1/chat/completions", "/v1/embeddings", "/v1/completions"],
|
endpoint: Literal["/v1/chat/completions", "/v1/embeddings", "/v1/completions"],
|
||||||
input_file_id: str,
|
input_file_id: str,
|
||||||
custom_llm_provider: Literal["openai"] = "openai",
|
custom_llm_provider: Literal["openai", "azure"] = "openai",
|
||||||
metadata: Optional[Dict[str, str]] = None,
|
metadata: Optional[Dict[str, str]] = None,
|
||||||
extra_headers: Optional[Dict[str, str]] = None,
|
extra_headers: Optional[Dict[str, str]] = None,
|
||||||
extra_body: Optional[Dict[str, str]] = None,
|
extra_body: Optional[Dict[str, str]] = None,
|
||||||
|
@ -103,6 +105,32 @@ def create_batch(
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
optional_params = GenericLiteLLMParams(**kwargs)
|
optional_params = GenericLiteLLMParams(**kwargs)
|
||||||
|
_is_async = kwargs.pop("acreate_batch", False) is True
|
||||||
|
### TIMEOUT LOGIC ###
|
||||||
|
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
|
||||||
|
# set timeout for 10 minutes by default
|
||||||
|
|
||||||
|
if (
|
||||||
|
timeout is not None
|
||||||
|
and isinstance(timeout, httpx.Timeout)
|
||||||
|
and supports_httpx_timeout(custom_llm_provider) is False
|
||||||
|
):
|
||||||
|
read_timeout = timeout.read or 600
|
||||||
|
timeout = read_timeout # default 10 min timeout
|
||||||
|
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
|
||||||
|
timeout = float(timeout) # type: ignore
|
||||||
|
elif timeout is None:
|
||||||
|
timeout = 600.0
|
||||||
|
|
||||||
|
_create_batch_request = CreateBatchRequest(
|
||||||
|
completion_window=completion_window,
|
||||||
|
endpoint=endpoint,
|
||||||
|
input_file_id=input_file_id,
|
||||||
|
metadata=metadata,
|
||||||
|
extra_headers=extra_headers,
|
||||||
|
extra_body=extra_body,
|
||||||
|
)
|
||||||
|
|
||||||
if custom_llm_provider == "openai":
|
if custom_llm_provider == "openai":
|
||||||
|
|
||||||
# for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
|
# for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
|
||||||
|
@ -125,34 +153,6 @@ def create_batch(
|
||||||
or litellm.openai_key
|
or litellm.openai_key
|
||||||
or os.getenv("OPENAI_API_KEY")
|
or os.getenv("OPENAI_API_KEY")
|
||||||
)
|
)
|
||||||
### TIMEOUT LOGIC ###
|
|
||||||
timeout = (
|
|
||||||
optional_params.timeout or kwargs.get("request_timeout", 600) or 600
|
|
||||||
)
|
|
||||||
# set timeout for 10 minutes by default
|
|
||||||
|
|
||||||
if (
|
|
||||||
timeout is not None
|
|
||||||
and isinstance(timeout, httpx.Timeout)
|
|
||||||
and supports_httpx_timeout(custom_llm_provider) == False
|
|
||||||
):
|
|
||||||
read_timeout = timeout.read or 600
|
|
||||||
timeout = read_timeout # default 10 min timeout
|
|
||||||
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
|
|
||||||
timeout = float(timeout) # type: ignore
|
|
||||||
elif timeout is None:
|
|
||||||
timeout = 600.0
|
|
||||||
|
|
||||||
_is_async = kwargs.pop("acreate_batch", False) is True
|
|
||||||
|
|
||||||
_create_batch_request = CreateBatchRequest(
|
|
||||||
completion_window=completion_window,
|
|
||||||
endpoint=endpoint,
|
|
||||||
input_file_id=input_file_id,
|
|
||||||
metadata=metadata,
|
|
||||||
extra_headers=extra_headers,
|
|
||||||
extra_body=extra_body,
|
|
||||||
)
|
|
||||||
|
|
||||||
response = openai_batches_instance.create_batch(
|
response = openai_batches_instance.create_batch(
|
||||||
api_base=api_base,
|
api_base=api_base,
|
||||||
|
@ -163,6 +163,38 @@ def create_batch(
|
||||||
max_retries=optional_params.max_retries,
|
max_retries=optional_params.max_retries,
|
||||||
_is_async=_is_async,
|
_is_async=_is_async,
|
||||||
)
|
)
|
||||||
|
elif custom_llm_provider == "azure":
|
||||||
|
api_base = optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE") # type: ignore
|
||||||
|
api_version = (
|
||||||
|
optional_params.api_version
|
||||||
|
or litellm.api_version
|
||||||
|
or get_secret("AZURE_API_VERSION")
|
||||||
|
) # type: ignore
|
||||||
|
|
||||||
|
api_key = (
|
||||||
|
optional_params.api_key
|
||||||
|
or litellm.api_key
|
||||||
|
or litellm.azure_key
|
||||||
|
or get_secret("AZURE_OPENAI_API_KEY")
|
||||||
|
or get_secret("AZURE_API_KEY")
|
||||||
|
) # type: ignore
|
||||||
|
|
||||||
|
extra_body = optional_params.get("extra_body", {})
|
||||||
|
azure_ad_token: Optional[str] = None
|
||||||
|
if extra_body is not None:
|
||||||
|
azure_ad_token = extra_body.pop("azure_ad_token", None)
|
||||||
|
else:
|
||||||
|
azure_ad_token = get_secret("AZURE_AD_TOKEN") # type: ignore
|
||||||
|
|
||||||
|
response = azure_batches_instance.create_batch(
|
||||||
|
_is_async=_is_async,
|
||||||
|
api_base=api_base,
|
||||||
|
api_key=api_key,
|
||||||
|
api_version=api_version,
|
||||||
|
timeout=timeout,
|
||||||
|
max_retries=optional_params.max_retries,
|
||||||
|
create_batch_data=_create_batch_request,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise litellm.exceptions.BadRequestError(
|
raise litellm.exceptions.BadRequestError(
|
||||||
message="LiteLLM doesn't support {} for 'create_batch'. Only 'openai' is supported.".format(
|
message="LiteLLM doesn't support {} for 'create_batch'. Only 'openai' is supported.".format(
|
||||||
|
@ -225,7 +257,7 @@ async def aretrieve_batch(
|
||||||
|
|
||||||
def retrieve_batch(
|
def retrieve_batch(
|
||||||
batch_id: str,
|
batch_id: str,
|
||||||
custom_llm_provider: Literal["openai"] = "openai",
|
custom_llm_provider: Literal["openai", "azure"] = "openai",
|
||||||
metadata: Optional[Dict[str, str]] = None,
|
metadata: Optional[Dict[str, str]] = None,
|
||||||
extra_headers: Optional[Dict[str, str]] = None,
|
extra_headers: Optional[Dict[str, str]] = None,
|
||||||
extra_body: Optional[Dict[str, str]] = None,
|
extra_body: Optional[Dict[str, str]] = None,
|
||||||
|
@ -238,6 +270,30 @@ def retrieve_batch(
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
optional_params = GenericLiteLLMParams(**kwargs)
|
optional_params = GenericLiteLLMParams(**kwargs)
|
||||||
|
### TIMEOUT LOGIC ###
|
||||||
|
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
|
||||||
|
# set timeout for 10 minutes by default
|
||||||
|
|
||||||
|
if (
|
||||||
|
timeout is not None
|
||||||
|
and isinstance(timeout, httpx.Timeout)
|
||||||
|
and supports_httpx_timeout(custom_llm_provider) is False
|
||||||
|
):
|
||||||
|
read_timeout = timeout.read or 600
|
||||||
|
timeout = read_timeout # default 10 min timeout
|
||||||
|
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
|
||||||
|
timeout = float(timeout) # type: ignore
|
||||||
|
elif timeout is None:
|
||||||
|
timeout = 600.0
|
||||||
|
|
||||||
|
_retrieve_batch_request = RetrieveBatchRequest(
|
||||||
|
batch_id=batch_id,
|
||||||
|
extra_headers=extra_headers,
|
||||||
|
extra_body=extra_body,
|
||||||
|
)
|
||||||
|
|
||||||
|
_is_async = kwargs.pop("aretrieve_batch", False) is True
|
||||||
|
|
||||||
if custom_llm_provider == "openai":
|
if custom_llm_provider == "openai":
|
||||||
|
|
||||||
# for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
|
# for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
|
||||||
|
@ -260,31 +316,6 @@ def retrieve_batch(
|
||||||
or litellm.openai_key
|
or litellm.openai_key
|
||||||
or os.getenv("OPENAI_API_KEY")
|
or os.getenv("OPENAI_API_KEY")
|
||||||
)
|
)
|
||||||
### TIMEOUT LOGIC ###
|
|
||||||
timeout = (
|
|
||||||
optional_params.timeout or kwargs.get("request_timeout", 600) or 600
|
|
||||||
)
|
|
||||||
# set timeout for 10 minutes by default
|
|
||||||
|
|
||||||
if (
|
|
||||||
timeout is not None
|
|
||||||
and isinstance(timeout, httpx.Timeout)
|
|
||||||
and supports_httpx_timeout(custom_llm_provider) == False
|
|
||||||
):
|
|
||||||
read_timeout = timeout.read or 600
|
|
||||||
timeout = read_timeout # default 10 min timeout
|
|
||||||
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
|
|
||||||
timeout = float(timeout) # type: ignore
|
|
||||||
elif timeout is None:
|
|
||||||
timeout = 600.0
|
|
||||||
|
|
||||||
_retrieve_batch_request = RetrieveBatchRequest(
|
|
||||||
batch_id=batch_id,
|
|
||||||
extra_headers=extra_headers,
|
|
||||||
extra_body=extra_body,
|
|
||||||
)
|
|
||||||
|
|
||||||
_is_async = kwargs.pop("aretrieve_batch", False) is True
|
|
||||||
|
|
||||||
response = openai_batches_instance.retrieve_batch(
|
response = openai_batches_instance.retrieve_batch(
|
||||||
_is_async=_is_async,
|
_is_async=_is_async,
|
||||||
|
@ -295,6 +326,38 @@ def retrieve_batch(
|
||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
max_retries=optional_params.max_retries,
|
max_retries=optional_params.max_retries,
|
||||||
)
|
)
|
||||||
|
elif custom_llm_provider == "azure":
|
||||||
|
api_base = optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE") # type: ignore
|
||||||
|
api_version = (
|
||||||
|
optional_params.api_version
|
||||||
|
or litellm.api_version
|
||||||
|
or get_secret("AZURE_API_VERSION")
|
||||||
|
) # type: ignore
|
||||||
|
|
||||||
|
api_key = (
|
||||||
|
optional_params.api_key
|
||||||
|
or litellm.api_key
|
||||||
|
or litellm.azure_key
|
||||||
|
or get_secret("AZURE_OPENAI_API_KEY")
|
||||||
|
or get_secret("AZURE_API_KEY")
|
||||||
|
) # type: ignore
|
||||||
|
|
||||||
|
extra_body = optional_params.get("extra_body", {})
|
||||||
|
azure_ad_token: Optional[str] = None
|
||||||
|
if extra_body is not None:
|
||||||
|
azure_ad_token = extra_body.pop("azure_ad_token", None)
|
||||||
|
else:
|
||||||
|
azure_ad_token = get_secret("AZURE_AD_TOKEN") # type: ignore
|
||||||
|
|
||||||
|
response = azure_batches_instance.retrieve_batch(
|
||||||
|
_is_async=_is_async,
|
||||||
|
api_base=api_base,
|
||||||
|
api_key=api_key,
|
||||||
|
api_version=api_version,
|
||||||
|
timeout=timeout,
|
||||||
|
max_retries=optional_params.max_retries,
|
||||||
|
retrieve_batch_data=_retrieve_batch_request,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise litellm.exceptions.BadRequestError(
|
raise litellm.exceptions.BadRequestError(
|
||||||
message="LiteLLM doesn't support {} for 'create_batch'. Only 'openai' is supported.".format(
|
message="LiteLLM doesn't support {} for 'create_batch'. Only 'openai' is supported.".format(
|
||||||
|
@ -357,7 +420,7 @@ async def alist_batches(
|
||||||
def list_batches(
|
def list_batches(
|
||||||
after: Optional[str] = None,
|
after: Optional[str] = None,
|
||||||
limit: Optional[int] = None,
|
limit: Optional[int] = None,
|
||||||
custom_llm_provider: Literal["openai"] = "openai",
|
custom_llm_provider: Literal["openai", "azure"] = "openai",
|
||||||
extra_headers: Optional[Dict[str, str]] = None,
|
extra_headers: Optional[Dict[str, str]] = None,
|
||||||
extra_body: Optional[Dict[str, str]] = None,
|
extra_body: Optional[Dict[str, str]] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
|
@ -368,7 +431,31 @@ def list_batches(
|
||||||
List your organization's batches.
|
List your organization's batches.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
|
# set API KEY
|
||||||
optional_params = GenericLiteLLMParams(**kwargs)
|
optional_params = GenericLiteLLMParams(**kwargs)
|
||||||
|
api_key = (
|
||||||
|
optional_params.api_key
|
||||||
|
or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
|
||||||
|
or litellm.openai_key
|
||||||
|
or os.getenv("OPENAI_API_KEY")
|
||||||
|
)
|
||||||
|
### TIMEOUT LOGIC ###
|
||||||
|
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
|
||||||
|
# set timeout for 10 minutes by default
|
||||||
|
|
||||||
|
if (
|
||||||
|
timeout is not None
|
||||||
|
and isinstance(timeout, httpx.Timeout)
|
||||||
|
and supports_httpx_timeout(custom_llm_provider) is False
|
||||||
|
):
|
||||||
|
read_timeout = timeout.read or 600
|
||||||
|
timeout = read_timeout # default 10 min timeout
|
||||||
|
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
|
||||||
|
timeout = float(timeout) # type: ignore
|
||||||
|
elif timeout is None:
|
||||||
|
timeout = 600.0
|
||||||
|
|
||||||
|
_is_async = kwargs.pop("alist_batches", False) is True
|
||||||
if custom_llm_provider == "openai":
|
if custom_llm_provider == "openai":
|
||||||
# for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
|
# for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
|
||||||
api_base = (
|
api_base = (
|
||||||
|
@ -383,32 +470,6 @@ def list_batches(
|
||||||
or os.getenv("OPENAI_ORGANIZATION", None)
|
or os.getenv("OPENAI_ORGANIZATION", None)
|
||||||
or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
|
or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
|
||||||
)
|
)
|
||||||
# set API KEY
|
|
||||||
api_key = (
|
|
||||||
optional_params.api_key
|
|
||||||
or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
|
|
||||||
or litellm.openai_key
|
|
||||||
or os.getenv("OPENAI_API_KEY")
|
|
||||||
)
|
|
||||||
### TIMEOUT LOGIC ###
|
|
||||||
timeout = (
|
|
||||||
optional_params.timeout or kwargs.get("request_timeout", 600) or 600
|
|
||||||
)
|
|
||||||
# set timeout for 10 minutes by default
|
|
||||||
|
|
||||||
if (
|
|
||||||
timeout is not None
|
|
||||||
and isinstance(timeout, httpx.Timeout)
|
|
||||||
and supports_httpx_timeout(custom_llm_provider) == False
|
|
||||||
):
|
|
||||||
read_timeout = timeout.read or 600
|
|
||||||
timeout = read_timeout # default 10 min timeout
|
|
||||||
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
|
|
||||||
timeout = float(timeout) # type: ignore
|
|
||||||
elif timeout is None:
|
|
||||||
timeout = 600.0
|
|
||||||
|
|
||||||
_is_async = kwargs.pop("alist_batches", False) is True
|
|
||||||
|
|
||||||
response = openai_batches_instance.list_batches(
|
response = openai_batches_instance.list_batches(
|
||||||
_is_async=_is_async,
|
_is_async=_is_async,
|
||||||
|
@ -420,9 +481,40 @@ def list_batches(
|
||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
max_retries=optional_params.max_retries,
|
max_retries=optional_params.max_retries,
|
||||||
)
|
)
|
||||||
|
elif custom_llm_provider == "azure":
|
||||||
|
api_base = optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE") # type: ignore
|
||||||
|
api_version = (
|
||||||
|
optional_params.api_version
|
||||||
|
or litellm.api_version
|
||||||
|
or get_secret("AZURE_API_VERSION")
|
||||||
|
) # type: ignore
|
||||||
|
|
||||||
|
api_key = (
|
||||||
|
optional_params.api_key
|
||||||
|
or litellm.api_key
|
||||||
|
or litellm.azure_key
|
||||||
|
or get_secret("AZURE_OPENAI_API_KEY")
|
||||||
|
or get_secret("AZURE_API_KEY")
|
||||||
|
) # type: ignore
|
||||||
|
|
||||||
|
extra_body = optional_params.get("extra_body", {})
|
||||||
|
azure_ad_token: Optional[str] = None
|
||||||
|
if extra_body is not None:
|
||||||
|
azure_ad_token = extra_body.pop("azure_ad_token", None)
|
||||||
|
else:
|
||||||
|
azure_ad_token = get_secret("AZURE_AD_TOKEN") # type: ignore
|
||||||
|
|
||||||
|
response = azure_batches_instance.list_batches(
|
||||||
|
_is_async=_is_async,
|
||||||
|
api_base=api_base,
|
||||||
|
api_key=api_key,
|
||||||
|
api_version=api_version,
|
||||||
|
timeout=timeout,
|
||||||
|
max_retries=optional_params.max_retries,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise litellm.exceptions.BadRequestError(
|
raise litellm.exceptions.BadRequestError(
|
||||||
message="LiteLLM doesn't support {} for 'create_batch'. Only 'openai' is supported.".format(
|
message="LiteLLM doesn't support {} for 'list_batch'. Only 'openai' is supported.".format(
|
||||||
custom_llm_provider
|
custom_llm_provider
|
||||||
),
|
),
|
||||||
model="n/a",
|
model="n/a",
|
||||||
|
|
|
@ -87,6 +87,24 @@ def file_retrieve(
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
optional_params = GenericLiteLLMParams(**kwargs)
|
optional_params = GenericLiteLLMParams(**kwargs)
|
||||||
|
### TIMEOUT LOGIC ###
|
||||||
|
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
|
||||||
|
# set timeout for 10 minutes by default
|
||||||
|
|
||||||
|
if (
|
||||||
|
timeout is not None
|
||||||
|
and isinstance(timeout, httpx.Timeout)
|
||||||
|
and supports_httpx_timeout(custom_llm_provider) is False
|
||||||
|
):
|
||||||
|
read_timeout = timeout.read or 600
|
||||||
|
timeout = read_timeout # default 10 min timeout
|
||||||
|
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
|
||||||
|
timeout = float(timeout) # type: ignore
|
||||||
|
elif timeout is None:
|
||||||
|
timeout = 600.0
|
||||||
|
|
||||||
|
_is_async = kwargs.pop("is_async", False) is True
|
||||||
|
|
||||||
if custom_llm_provider == "openai":
|
if custom_llm_provider == "openai":
|
||||||
# for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
|
# for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
|
||||||
api_base = (
|
api_base = (
|
||||||
|
@ -108,25 +126,6 @@ def file_retrieve(
|
||||||
or litellm.openai_key
|
or litellm.openai_key
|
||||||
or os.getenv("OPENAI_API_KEY")
|
or os.getenv("OPENAI_API_KEY")
|
||||||
)
|
)
|
||||||
### TIMEOUT LOGIC ###
|
|
||||||
timeout = (
|
|
||||||
optional_params.timeout or kwargs.get("request_timeout", 600) or 600
|
|
||||||
)
|
|
||||||
# set timeout for 10 minutes by default
|
|
||||||
|
|
||||||
if (
|
|
||||||
timeout is not None
|
|
||||||
and isinstance(timeout, httpx.Timeout)
|
|
||||||
and supports_httpx_timeout(custom_llm_provider) == False
|
|
||||||
):
|
|
||||||
read_timeout = timeout.read or 600
|
|
||||||
timeout = read_timeout # default 10 min timeout
|
|
||||||
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
|
|
||||||
timeout = float(timeout) # type: ignore
|
|
||||||
elif timeout is None:
|
|
||||||
timeout = 600.0
|
|
||||||
|
|
||||||
_is_async = kwargs.pop("is_async", False) is True
|
|
||||||
|
|
||||||
response = openai_files_instance.retrieve_file(
|
response = openai_files_instance.retrieve_file(
|
||||||
file_id=file_id,
|
file_id=file_id,
|
||||||
|
@ -137,9 +136,41 @@ def file_retrieve(
|
||||||
max_retries=optional_params.max_retries,
|
max_retries=optional_params.max_retries,
|
||||||
organization=organization,
|
organization=organization,
|
||||||
)
|
)
|
||||||
|
elif custom_llm_provider == "azure":
|
||||||
|
api_base = optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE") # type: ignore
|
||||||
|
api_version = (
|
||||||
|
optional_params.api_version
|
||||||
|
or litellm.api_version
|
||||||
|
or get_secret("AZURE_API_VERSION")
|
||||||
|
) # type: ignore
|
||||||
|
|
||||||
|
api_key = (
|
||||||
|
optional_params.api_key
|
||||||
|
or litellm.api_key
|
||||||
|
or litellm.azure_key
|
||||||
|
or get_secret("AZURE_OPENAI_API_KEY")
|
||||||
|
or get_secret("AZURE_API_KEY")
|
||||||
|
) # type: ignore
|
||||||
|
|
||||||
|
extra_body = optional_params.get("extra_body", {})
|
||||||
|
azure_ad_token: Optional[str] = None
|
||||||
|
if extra_body is not None:
|
||||||
|
azure_ad_token = extra_body.pop("azure_ad_token", None)
|
||||||
|
else:
|
||||||
|
azure_ad_token = get_secret("AZURE_AD_TOKEN") # type: ignore
|
||||||
|
|
||||||
|
response = azure_files_instance.retrieve_file(
|
||||||
|
_is_async=_is_async,
|
||||||
|
api_base=api_base,
|
||||||
|
api_key=api_key,
|
||||||
|
api_version=api_version,
|
||||||
|
timeout=timeout,
|
||||||
|
max_retries=optional_params.max_retries,
|
||||||
|
file_id=file_id,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise litellm.exceptions.BadRequestError(
|
raise litellm.exceptions.BadRequestError(
|
||||||
message="LiteLLM doesn't support {} for 'create_batch'. Only 'openai' is supported.".format(
|
message="LiteLLM doesn't support {} for 'file_retrieve'. Only 'openai' and 'azure' are supported.".format(
|
||||||
custom_llm_provider
|
custom_llm_provider
|
||||||
),
|
),
|
||||||
model="n/a",
|
model="n/a",
|
||||||
|
@ -361,6 +392,23 @@ def file_list(
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
optional_params = GenericLiteLLMParams(**kwargs)
|
optional_params = GenericLiteLLMParams(**kwargs)
|
||||||
|
### TIMEOUT LOGIC ###
|
||||||
|
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
|
||||||
|
# set timeout for 10 minutes by default
|
||||||
|
|
||||||
|
if (
|
||||||
|
timeout is not None
|
||||||
|
and isinstance(timeout, httpx.Timeout)
|
||||||
|
and supports_httpx_timeout(custom_llm_provider) == False
|
||||||
|
):
|
||||||
|
read_timeout = timeout.read or 600
|
||||||
|
timeout = read_timeout # default 10 min timeout
|
||||||
|
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
|
||||||
|
timeout = float(timeout) # type: ignore
|
||||||
|
elif timeout is None:
|
||||||
|
timeout = 600.0
|
||||||
|
|
||||||
|
_is_async = kwargs.pop("is_async", False) is True
|
||||||
if custom_llm_provider == "openai":
|
if custom_llm_provider == "openai":
|
||||||
# for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
|
# for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
|
||||||
api_base = (
|
api_base = (
|
||||||
|
@ -382,25 +430,6 @@ def file_list(
|
||||||
or litellm.openai_key
|
or litellm.openai_key
|
||||||
or os.getenv("OPENAI_API_KEY")
|
or os.getenv("OPENAI_API_KEY")
|
||||||
)
|
)
|
||||||
### TIMEOUT LOGIC ###
|
|
||||||
timeout = (
|
|
||||||
optional_params.timeout or kwargs.get("request_timeout", 600) or 600
|
|
||||||
)
|
|
||||||
# set timeout for 10 minutes by default
|
|
||||||
|
|
||||||
if (
|
|
||||||
timeout is not None
|
|
||||||
and isinstance(timeout, httpx.Timeout)
|
|
||||||
and supports_httpx_timeout(custom_llm_provider) == False
|
|
||||||
):
|
|
||||||
read_timeout = timeout.read or 600
|
|
||||||
timeout = read_timeout # default 10 min timeout
|
|
||||||
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
|
|
||||||
timeout = float(timeout) # type: ignore
|
|
||||||
elif timeout is None:
|
|
||||||
timeout = 600.0
|
|
||||||
|
|
||||||
_is_async = kwargs.pop("is_async", False) is True
|
|
||||||
|
|
||||||
response = openai_files_instance.list_files(
|
response = openai_files_instance.list_files(
|
||||||
purpose=purpose,
|
purpose=purpose,
|
||||||
|
@ -411,9 +440,41 @@ def file_list(
|
||||||
max_retries=optional_params.max_retries,
|
max_retries=optional_params.max_retries,
|
||||||
organization=organization,
|
organization=organization,
|
||||||
)
|
)
|
||||||
|
elif custom_llm_provider == "azure":
|
||||||
|
api_base = optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE") # type: ignore
|
||||||
|
api_version = (
|
||||||
|
optional_params.api_version
|
||||||
|
or litellm.api_version
|
||||||
|
or get_secret("AZURE_API_VERSION")
|
||||||
|
) # type: ignore
|
||||||
|
|
||||||
|
api_key = (
|
||||||
|
optional_params.api_key
|
||||||
|
or litellm.api_key
|
||||||
|
or litellm.azure_key
|
||||||
|
or get_secret("AZURE_OPENAI_API_KEY")
|
||||||
|
or get_secret("AZURE_API_KEY")
|
||||||
|
) # type: ignore
|
||||||
|
|
||||||
|
extra_body = optional_params.get("extra_body", {})
|
||||||
|
azure_ad_token: Optional[str] = None
|
||||||
|
if extra_body is not None:
|
||||||
|
azure_ad_token = extra_body.pop("azure_ad_token", None)
|
||||||
|
else:
|
||||||
|
azure_ad_token = get_secret("AZURE_AD_TOKEN") # type: ignore
|
||||||
|
|
||||||
|
response = azure_files_instance.list_files(
|
||||||
|
_is_async=_is_async,
|
||||||
|
api_base=api_base,
|
||||||
|
api_key=api_key,
|
||||||
|
api_version=api_version,
|
||||||
|
timeout=timeout,
|
||||||
|
max_retries=optional_params.max_retries,
|
||||||
|
purpose=purpose,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise litellm.exceptions.BadRequestError(
|
raise litellm.exceptions.BadRequestError(
|
||||||
message="LiteLLM doesn't support {} for 'file_list'. Only 'openai' is supported.".format(
|
message="LiteLLM doesn't support {} for 'file_list'. Only 'openai' and 'azure' are supported.".format(
|
||||||
custom_llm_provider
|
custom_llm_provider
|
||||||
),
|
),
|
||||||
model="n/a",
|
model="n/a",
|
||||||
|
@ -645,6 +706,29 @@ def file_content(
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
optional_params = GenericLiteLLMParams(**kwargs)
|
optional_params = GenericLiteLLMParams(**kwargs)
|
||||||
|
### TIMEOUT LOGIC ###
|
||||||
|
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
|
||||||
|
# set timeout for 10 minutes by default
|
||||||
|
|
||||||
|
if (
|
||||||
|
timeout is not None
|
||||||
|
and isinstance(timeout, httpx.Timeout)
|
||||||
|
and supports_httpx_timeout(custom_llm_provider) == False
|
||||||
|
):
|
||||||
|
read_timeout = timeout.read or 600
|
||||||
|
timeout = read_timeout # default 10 min timeout
|
||||||
|
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
|
||||||
|
timeout = float(timeout) # type: ignore
|
||||||
|
elif timeout is None:
|
||||||
|
timeout = 600.0
|
||||||
|
|
||||||
|
_file_content_request = FileContentRequest(
|
||||||
|
file_id=file_id,
|
||||||
|
extra_headers=extra_headers,
|
||||||
|
extra_body=extra_body,
|
||||||
|
)
|
||||||
|
|
||||||
|
_is_async = kwargs.pop("afile_content", False) is True
|
||||||
if custom_llm_provider == "openai":
|
if custom_llm_provider == "openai":
|
||||||
# for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
|
# for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
|
||||||
api_base = (
|
api_base = (
|
||||||
|
@ -666,31 +750,6 @@ def file_content(
|
||||||
or litellm.openai_key
|
or litellm.openai_key
|
||||||
or os.getenv("OPENAI_API_KEY")
|
or os.getenv("OPENAI_API_KEY")
|
||||||
)
|
)
|
||||||
### TIMEOUT LOGIC ###
|
|
||||||
timeout = (
|
|
||||||
optional_params.timeout or kwargs.get("request_timeout", 600) or 600
|
|
||||||
)
|
|
||||||
# set timeout for 10 minutes by default
|
|
||||||
|
|
||||||
if (
|
|
||||||
timeout is not None
|
|
||||||
and isinstance(timeout, httpx.Timeout)
|
|
||||||
and supports_httpx_timeout(custom_llm_provider) == False
|
|
||||||
):
|
|
||||||
read_timeout = timeout.read or 600
|
|
||||||
timeout = read_timeout # default 10 min timeout
|
|
||||||
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
|
|
||||||
timeout = float(timeout) # type: ignore
|
|
||||||
elif timeout is None:
|
|
||||||
timeout = 600.0
|
|
||||||
|
|
||||||
_file_content_request = FileContentRequest(
|
|
||||||
file_id=file_id,
|
|
||||||
extra_headers=extra_headers,
|
|
||||||
extra_body=extra_body,
|
|
||||||
)
|
|
||||||
|
|
||||||
_is_async = kwargs.pop("afile_content", False) is True
|
|
||||||
|
|
||||||
response = openai_files_instance.file_content(
|
response = openai_files_instance.file_content(
|
||||||
_is_async=_is_async,
|
_is_async=_is_async,
|
||||||
|
@ -701,9 +760,41 @@ def file_content(
|
||||||
max_retries=optional_params.max_retries,
|
max_retries=optional_params.max_retries,
|
||||||
organization=organization,
|
organization=organization,
|
||||||
)
|
)
|
||||||
|
elif custom_llm_provider == "azure":
|
||||||
|
api_base = optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE") # type: ignore
|
||||||
|
api_version = (
|
||||||
|
optional_params.api_version
|
||||||
|
or litellm.api_version
|
||||||
|
or get_secret("AZURE_API_VERSION")
|
||||||
|
) # type: ignore
|
||||||
|
|
||||||
|
api_key = (
|
||||||
|
optional_params.api_key
|
||||||
|
or litellm.api_key
|
||||||
|
or litellm.azure_key
|
||||||
|
or get_secret("AZURE_OPENAI_API_KEY")
|
||||||
|
or get_secret("AZURE_API_KEY")
|
||||||
|
) # type: ignore
|
||||||
|
|
||||||
|
extra_body = optional_params.get("extra_body", {})
|
||||||
|
azure_ad_token: Optional[str] = None
|
||||||
|
if extra_body is not None:
|
||||||
|
azure_ad_token = extra_body.pop("azure_ad_token", None)
|
||||||
|
else:
|
||||||
|
azure_ad_token = get_secret("AZURE_AD_TOKEN") # type: ignore
|
||||||
|
|
||||||
|
response = azure_files_instance.file_content(
|
||||||
|
_is_async=_is_async,
|
||||||
|
api_base=api_base,
|
||||||
|
api_key=api_key,
|
||||||
|
api_version=api_version,
|
||||||
|
timeout=timeout,
|
||||||
|
max_retries=optional_params.max_retries,
|
||||||
|
file_content_request=_file_content_request,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise litellm.exceptions.BadRequestError(
|
raise litellm.exceptions.BadRequestError(
|
||||||
message="LiteLLM doesn't support {} for 'create_batch'. Only 'openai' is supported.".format(
|
message="LiteLLM doesn't support {} for 'file_content'. Only 'openai' and 'azure' are supported.".format(
|
||||||
custom_llm_provider
|
custom_llm_provider
|
||||||
),
|
),
|
||||||
model="n/a",
|
model="n/a",
|
||||||
|
|
|
@ -47,14 +47,18 @@ from ..types.llms.openai import (
|
||||||
AsyncAssistantEventHandler,
|
AsyncAssistantEventHandler,
|
||||||
AsyncAssistantStreamManager,
|
AsyncAssistantStreamManager,
|
||||||
AsyncCursorPage,
|
AsyncCursorPage,
|
||||||
|
Batch,
|
||||||
|
CancelBatchRequest,
|
||||||
ChatCompletionToolChoiceFunctionParam,
|
ChatCompletionToolChoiceFunctionParam,
|
||||||
ChatCompletionToolChoiceObjectParam,
|
ChatCompletionToolChoiceObjectParam,
|
||||||
ChatCompletionToolParam,
|
ChatCompletionToolParam,
|
||||||
ChatCompletionToolParamFunctionChunk,
|
ChatCompletionToolParamFunctionChunk,
|
||||||
|
CreateBatchRequest,
|
||||||
HttpxBinaryResponseContent,
|
HttpxBinaryResponseContent,
|
||||||
MessageData,
|
MessageData,
|
||||||
OpenAICreateThreadParamsMessage,
|
OpenAICreateThreadParamsMessage,
|
||||||
OpenAIMessage,
|
OpenAIMessage,
|
||||||
|
RetrieveBatchRequest,
|
||||||
Run,
|
Run,
|
||||||
SyncCursorPage,
|
SyncCursorPage,
|
||||||
Thread,
|
Thread,
|
||||||
|
@ -2814,3 +2818,216 @@ class AzureAssistantsAPI(BaseLLM):
|
||||||
)
|
)
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
class AzureBatchesAPI(BaseLLM):
|
||||||
|
"""
|
||||||
|
Azure methods to support for batches
|
||||||
|
- create_batch()
|
||||||
|
- retrieve_batch()
|
||||||
|
- cancel_batch()
|
||||||
|
- list_batch()
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def get_azure_openai_client(
|
||||||
|
self,
|
||||||
|
api_key: Optional[str],
|
||||||
|
api_base: Optional[str],
|
||||||
|
timeout: Union[float, httpx.Timeout],
|
||||||
|
max_retries: Optional[int],
|
||||||
|
api_version: Optional[str] = None,
|
||||||
|
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
|
||||||
|
_is_async: bool = False,
|
||||||
|
) -> Optional[Union[AzureOpenAI, AsyncAzureOpenAI]]:
|
||||||
|
received_args = locals()
|
||||||
|
openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None
|
||||||
|
if client is None:
|
||||||
|
data = {}
|
||||||
|
for k, v in received_args.items():
|
||||||
|
if k == "self" or k == "client" or k == "_is_async":
|
||||||
|
pass
|
||||||
|
elif k == "api_base" and v is not None:
|
||||||
|
data["azure_endpoint"] = v
|
||||||
|
elif v is not None:
|
||||||
|
data[k] = v
|
||||||
|
if "api_version" not in data:
|
||||||
|
data["api_version"] = litellm.AZURE_DEFAULT_API_VERSION
|
||||||
|
if _is_async is True:
|
||||||
|
openai_client = AsyncAzureOpenAI(**data)
|
||||||
|
else:
|
||||||
|
openai_client = AzureOpenAI(**data) # type: ignore
|
||||||
|
else:
|
||||||
|
openai_client = client
|
||||||
|
|
||||||
|
return openai_client
|
||||||
|
|
||||||
|
async def acreate_batch(
|
||||||
|
self,
|
||||||
|
create_batch_data: CreateBatchRequest,
|
||||||
|
azure_client: AsyncAzureOpenAI,
|
||||||
|
) -> Batch:
|
||||||
|
response = await azure_client.batches.create(**create_batch_data)
|
||||||
|
return response
|
||||||
|
|
||||||
|
def create_batch(
|
||||||
|
self,
|
||||||
|
_is_async: bool,
|
||||||
|
create_batch_data: CreateBatchRequest,
|
||||||
|
api_key: Optional[str],
|
||||||
|
api_base: Optional[str],
|
||||||
|
api_version: Optional[str],
|
||||||
|
timeout: Union[float, httpx.Timeout],
|
||||||
|
max_retries: Optional[int],
|
||||||
|
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
|
||||||
|
) -> Union[Batch, Coroutine[Any, Any, Batch]]:
|
||||||
|
azure_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
|
||||||
|
self.get_azure_openai_client(
|
||||||
|
api_key=api_key,
|
||||||
|
api_base=api_base,
|
||||||
|
timeout=timeout,
|
||||||
|
api_version=api_version,
|
||||||
|
max_retries=max_retries,
|
||||||
|
client=client,
|
||||||
|
_is_async=_is_async,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if azure_client is None:
|
||||||
|
raise ValueError(
|
||||||
|
"OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
|
||||||
|
)
|
||||||
|
|
||||||
|
if _is_async is True:
|
||||||
|
if not isinstance(azure_client, AsyncAzureOpenAI):
|
||||||
|
raise ValueError(
|
||||||
|
"OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client."
|
||||||
|
)
|
||||||
|
return self.acreate_batch( # type: ignore
|
||||||
|
create_batch_data=create_batch_data, azure_client=azure_client
|
||||||
|
)
|
||||||
|
response = azure_client.batches.create(**create_batch_data)
|
||||||
|
return response
|
||||||
|
|
||||||
|
async def aretrieve_batch(
|
||||||
|
self,
|
||||||
|
retrieve_batch_data: RetrieveBatchRequest,
|
||||||
|
client: AsyncAzureOpenAI,
|
||||||
|
) -> Batch:
|
||||||
|
response = await client.batches.retrieve(**retrieve_batch_data)
|
||||||
|
return response
|
||||||
|
|
||||||
|
def retrieve_batch(
|
||||||
|
self,
|
||||||
|
_is_async: bool,
|
||||||
|
retrieve_batch_data: RetrieveBatchRequest,
|
||||||
|
api_key: Optional[str],
|
||||||
|
api_base: Optional[str],
|
||||||
|
api_version: Optional[str],
|
||||||
|
timeout: Union[float, httpx.Timeout],
|
||||||
|
max_retries: Optional[int],
|
||||||
|
client: Optional[AzureOpenAI] = None,
|
||||||
|
):
|
||||||
|
azure_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
|
||||||
|
self.get_azure_openai_client(
|
||||||
|
api_key=api_key,
|
||||||
|
api_base=api_base,
|
||||||
|
api_version=api_version,
|
||||||
|
timeout=timeout,
|
||||||
|
max_retries=max_retries,
|
||||||
|
client=client,
|
||||||
|
_is_async=_is_async,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if azure_client is None:
|
||||||
|
raise ValueError(
|
||||||
|
"OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
|
||||||
|
)
|
||||||
|
|
||||||
|
if _is_async is True:
|
||||||
|
if not isinstance(azure_client, AsyncAzureOpenAI):
|
||||||
|
raise ValueError(
|
||||||
|
"OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client."
|
||||||
|
)
|
||||||
|
return self.aretrieve_batch( # type: ignore
|
||||||
|
retrieve_batch_data=retrieve_batch_data, client=azure_client
|
||||||
|
)
|
||||||
|
response = azure_client.batches.retrieve(**retrieve_batch_data)
|
||||||
|
return response
|
||||||
|
|
||||||
|
def cancel_batch(
|
||||||
|
self,
|
||||||
|
_is_async: bool,
|
||||||
|
cancel_batch_data: CancelBatchRequest,
|
||||||
|
api_key: Optional[str],
|
||||||
|
api_base: Optional[str],
|
||||||
|
timeout: Union[float, httpx.Timeout],
|
||||||
|
max_retries: Optional[int],
|
||||||
|
organization: Optional[str],
|
||||||
|
client: Optional[AzureOpenAI] = None,
|
||||||
|
):
|
||||||
|
azure_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
|
||||||
|
self.get_azure_openai_client(
|
||||||
|
api_key=api_key,
|
||||||
|
api_base=api_base,
|
||||||
|
timeout=timeout,
|
||||||
|
max_retries=max_retries,
|
||||||
|
client=client,
|
||||||
|
_is_async=_is_async,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if azure_client is None:
|
||||||
|
raise ValueError(
|
||||||
|
"OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
|
||||||
|
)
|
||||||
|
response = azure_client.batches.cancel(**cancel_batch_data)
|
||||||
|
return response
|
||||||
|
|
||||||
|
async def alist_batches(
|
||||||
|
self,
|
||||||
|
client: AsyncAzureOpenAI,
|
||||||
|
after: Optional[str] = None,
|
||||||
|
limit: Optional[int] = None,
|
||||||
|
):
|
||||||
|
response = await client.batches.list(after=after, limit=limit) # type: ignore
|
||||||
|
return response
|
||||||
|
|
||||||
|
def list_batches(
|
||||||
|
self,
|
||||||
|
_is_async: bool,
|
||||||
|
api_key: Optional[str],
|
||||||
|
api_base: Optional[str],
|
||||||
|
api_version: Optional[str],
|
||||||
|
timeout: Union[float, httpx.Timeout],
|
||||||
|
max_retries: Optional[int],
|
||||||
|
after: Optional[str] = None,
|
||||||
|
limit: Optional[int] = None,
|
||||||
|
client: Optional[AzureOpenAI] = None,
|
||||||
|
):
|
||||||
|
azure_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
|
||||||
|
self.get_azure_openai_client(
|
||||||
|
api_key=api_key,
|
||||||
|
api_base=api_base,
|
||||||
|
timeout=timeout,
|
||||||
|
max_retries=max_retries,
|
||||||
|
api_version=api_version,
|
||||||
|
client=client,
|
||||||
|
_is_async=_is_async,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if azure_client is None:
|
||||||
|
raise ValueError(
|
||||||
|
"OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
|
||||||
|
)
|
||||||
|
|
||||||
|
if _is_async is True:
|
||||||
|
if not isinstance(azure_client, AsyncAzureOpenAI):
|
||||||
|
raise ValueError(
|
||||||
|
"OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client."
|
||||||
|
)
|
||||||
|
return self.alist_batches( # type: ignore
|
||||||
|
client=azure_client, after=after, limit=limit
|
||||||
|
)
|
||||||
|
response = azure_client.batches.list(after=after, limit=limit) # type: ignore
|
||||||
|
return response
|
||||||
|
|
|
@ -121,7 +121,6 @@ class AzureOpenAIFilesAPI(BaseLLM):
|
||||||
api_key: Optional[str],
|
api_key: Optional[str],
|
||||||
timeout: Union[float, httpx.Timeout],
|
timeout: Union[float, httpx.Timeout],
|
||||||
max_retries: Optional[int],
|
max_retries: Optional[int],
|
||||||
organization: Optional[str],
|
|
||||||
api_version: Optional[str] = None,
|
api_version: Optional[str] = None,
|
||||||
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
|
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
|
||||||
) -> Union[
|
) -> Union[
|
||||||
|
@ -134,7 +133,7 @@ class AzureOpenAIFilesAPI(BaseLLM):
|
||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
api_version=api_version,
|
api_version=api_version,
|
||||||
max_retries=max_retries,
|
max_retries=max_retries,
|
||||||
organization=organization,
|
organization=None,
|
||||||
client=client,
|
client=client,
|
||||||
_is_async=_is_async,
|
_is_async=_is_async,
|
||||||
)
|
)
|
||||||
|
@ -173,7 +172,6 @@ class AzureOpenAIFilesAPI(BaseLLM):
|
||||||
api_key: Optional[str],
|
api_key: Optional[str],
|
||||||
timeout: Union[float, httpx.Timeout],
|
timeout: Union[float, httpx.Timeout],
|
||||||
max_retries: Optional[int],
|
max_retries: Optional[int],
|
||||||
organization: Optional[str],
|
|
||||||
api_version: Optional[str] = None,
|
api_version: Optional[str] = None,
|
||||||
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
|
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
|
||||||
):
|
):
|
||||||
|
@ -183,7 +181,7 @@ class AzureOpenAIFilesAPI(BaseLLM):
|
||||||
api_base=api_base,
|
api_base=api_base,
|
||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
max_retries=max_retries,
|
max_retries=max_retries,
|
||||||
organization=organization,
|
organization=None,
|
||||||
api_version=api_version,
|
api_version=api_version,
|
||||||
client=client,
|
client=client,
|
||||||
_is_async=_is_async,
|
_is_async=_is_async,
|
||||||
|
@ -213,6 +211,9 @@ class AzureOpenAIFilesAPI(BaseLLM):
|
||||||
openai_client: AsyncAzureOpenAI,
|
openai_client: AsyncAzureOpenAI,
|
||||||
) -> FileDeleted:
|
) -> FileDeleted:
|
||||||
response = await openai_client.files.delete(file_id=file_id)
|
response = await openai_client.files.delete(file_id=file_id)
|
||||||
|
|
||||||
|
if not isinstance(response, FileDeleted): # azure returns an empty string
|
||||||
|
return FileDeleted(id=file_id, deleted=True, object="file")
|
||||||
return response
|
return response
|
||||||
|
|
||||||
def delete_file(
|
def delete_file(
|
||||||
|
@ -255,6 +256,9 @@ class AzureOpenAIFilesAPI(BaseLLM):
|
||||||
)
|
)
|
||||||
response = openai_client.files.delete(file_id=file_id)
|
response = openai_client.files.delete(file_id=file_id)
|
||||||
|
|
||||||
|
if not isinstance(response, FileDeleted): # azure returns an empty string
|
||||||
|
return FileDeleted(id=file_id, deleted=True, object="file")
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
async def alist_files(
|
async def alist_files(
|
||||||
|
@ -275,7 +279,6 @@ class AzureOpenAIFilesAPI(BaseLLM):
|
||||||
api_key: Optional[str],
|
api_key: Optional[str],
|
||||||
timeout: Union[float, httpx.Timeout],
|
timeout: Union[float, httpx.Timeout],
|
||||||
max_retries: Optional[int],
|
max_retries: Optional[int],
|
||||||
organization: Optional[str],
|
|
||||||
purpose: Optional[str] = None,
|
purpose: Optional[str] = None,
|
||||||
api_version: Optional[str] = None,
|
api_version: Optional[str] = None,
|
||||||
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
|
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
|
||||||
|
@ -286,7 +289,7 @@ class AzureOpenAIFilesAPI(BaseLLM):
|
||||||
api_base=api_base,
|
api_base=api_base,
|
||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
max_retries=max_retries,
|
max_retries=max_retries,
|
||||||
organization=organization,
|
organization=None, # openai param
|
||||||
api_version=api_version,
|
api_version=api_version,
|
||||||
client=client,
|
client=client,
|
||||||
_is_async=_is_async,
|
_is_async=_is_async,
|
||||||
|
|
2
litellm/tests/batch_job_results_furniture.jsonl
Normal file
2
litellm/tests/batch_job_results_furniture.jsonl
Normal file
|
@ -0,0 +1,2 @@
|
||||||
|
{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-3.5-turbo-0125", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 10}}
|
||||||
|
{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-3.5-turbo-0125", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 10}}
|
|
@ -22,7 +22,8 @@ import litellm
|
||||||
from litellm import create_batch, create_file
|
from litellm import create_batch, create_file
|
||||||
|
|
||||||
|
|
||||||
def test_create_batch():
|
@pytest.mark.parametrize("provider", ["openai", "azure"])
|
||||||
|
def test_create_batch(provider):
|
||||||
"""
|
"""
|
||||||
1. Create File for Batch completion
|
1. Create File for Batch completion
|
||||||
2. Create Batch Request
|
2. Create Batch Request
|
||||||
|
@ -35,7 +36,7 @@ def test_create_batch():
|
||||||
file_obj = litellm.create_file(
|
file_obj = litellm.create_file(
|
||||||
file=open(file_path, "rb"),
|
file=open(file_path, "rb"),
|
||||||
purpose="batch",
|
purpose="batch",
|
||||||
custom_llm_provider="openai",
|
custom_llm_provider=provider,
|
||||||
)
|
)
|
||||||
print("Response from creating file=", file_obj)
|
print("Response from creating file=", file_obj)
|
||||||
|
|
||||||
|
@ -44,11 +45,12 @@ def test_create_batch():
|
||||||
batch_input_file_id is not None
|
batch_input_file_id is not None
|
||||||
), "Failed to create file, expected a non null file_id but got {batch_input_file_id}"
|
), "Failed to create file, expected a non null file_id but got {batch_input_file_id}"
|
||||||
|
|
||||||
|
time.sleep(5)
|
||||||
create_batch_response = litellm.create_batch(
|
create_batch_response = litellm.create_batch(
|
||||||
completion_window="24h",
|
completion_window="24h",
|
||||||
endpoint="/v1/chat/completions",
|
endpoint="/v1/chat/completions",
|
||||||
input_file_id=batch_input_file_id,
|
input_file_id=batch_input_file_id,
|
||||||
custom_llm_provider="openai",
|
custom_llm_provider=provider,
|
||||||
metadata={"key1": "value1", "key2": "value2"},
|
metadata={"key1": "value1", "key2": "value2"},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -59,13 +61,14 @@ def test_create_batch():
|
||||||
), f"Failed to create batch, expected a non null batch_id but got {create_batch_response.id}"
|
), f"Failed to create batch, expected a non null batch_id but got {create_batch_response.id}"
|
||||||
assert (
|
assert (
|
||||||
create_batch_response.endpoint == "/v1/chat/completions"
|
create_batch_response.endpoint == "/v1/chat/completions"
|
||||||
|
or create_batch_response.endpoint == "/chat/completions"
|
||||||
), f"Failed to create batch, expected endpoint to be /v1/chat/completions but got {create_batch_response.endpoint}"
|
), f"Failed to create batch, expected endpoint to be /v1/chat/completions but got {create_batch_response.endpoint}"
|
||||||
assert (
|
assert (
|
||||||
create_batch_response.input_file_id == batch_input_file_id
|
create_batch_response.input_file_id == batch_input_file_id
|
||||||
), f"Failed to create batch, expected input_file_id to be {batch_input_file_id} but got {create_batch_response.input_file_id}"
|
), f"Failed to create batch, expected input_file_id to be {batch_input_file_id} but got {create_batch_response.input_file_id}"
|
||||||
|
|
||||||
retrieved_batch = litellm.retrieve_batch(
|
retrieved_batch = litellm.retrieve_batch(
|
||||||
batch_id=create_batch_response.id, custom_llm_provider="openai"
|
batch_id=create_batch_response.id, custom_llm_provider=provider
|
||||||
)
|
)
|
||||||
print("retrieved batch=", retrieved_batch)
|
print("retrieved batch=", retrieved_batch)
|
||||||
# just assert that we retrieved a non None batch
|
# just assert that we retrieved a non None batch
|
||||||
|
@ -73,11 +76,11 @@ def test_create_batch():
|
||||||
assert retrieved_batch.id == create_batch_response.id
|
assert retrieved_batch.id == create_batch_response.id
|
||||||
|
|
||||||
# list all batches
|
# list all batches
|
||||||
list_batches = litellm.list_batches(custom_llm_provider="openai", limit=2)
|
list_batches = litellm.list_batches(custom_llm_provider=provider, limit=2)
|
||||||
print("list_batches=", list_batches)
|
print("list_batches=", list_batches)
|
||||||
|
|
||||||
file_content = litellm.file_content(
|
file_content = litellm.file_content(
|
||||||
file_id=batch_input_file_id, custom_llm_provider="openai"
|
file_id=batch_input_file_id, custom_llm_provider=provider
|
||||||
)
|
)
|
||||||
|
|
||||||
result = file_content.content
|
result = file_content.content
|
||||||
|
@ -90,8 +93,9 @@ def test_create_batch():
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("provider", ["openai", "azure"])
|
||||||
@pytest.mark.asyncio()
|
@pytest.mark.asyncio()
|
||||||
async def test_async_create_batch():
|
async def test_async_create_batch(provider):
|
||||||
"""
|
"""
|
||||||
1. Create File for Batch completion
|
1. Create File for Batch completion
|
||||||
2. Create Batch Request
|
2. Create Batch Request
|
||||||
|
@ -105,10 +109,11 @@ async def test_async_create_batch():
|
||||||
file_obj = await litellm.acreate_file(
|
file_obj = await litellm.acreate_file(
|
||||||
file=open(file_path, "rb"),
|
file=open(file_path, "rb"),
|
||||||
purpose="batch",
|
purpose="batch",
|
||||||
custom_llm_provider="openai",
|
custom_llm_provider=provider,
|
||||||
)
|
)
|
||||||
print("Response from creating file=", file_obj)
|
print("Response from creating file=", file_obj)
|
||||||
|
|
||||||
|
await asyncio.sleep(5)
|
||||||
batch_input_file_id = file_obj.id
|
batch_input_file_id = file_obj.id
|
||||||
assert (
|
assert (
|
||||||
batch_input_file_id is not None
|
batch_input_file_id is not None
|
||||||
|
@ -118,7 +123,7 @@ async def test_async_create_batch():
|
||||||
completion_window="24h",
|
completion_window="24h",
|
||||||
endpoint="/v1/chat/completions",
|
endpoint="/v1/chat/completions",
|
||||||
input_file_id=batch_input_file_id,
|
input_file_id=batch_input_file_id,
|
||||||
custom_llm_provider="openai",
|
custom_llm_provider=provider,
|
||||||
metadata={"key1": "value1", "key2": "value2"},
|
metadata={"key1": "value1", "key2": "value2"},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -129,6 +134,7 @@ async def test_async_create_batch():
|
||||||
), f"Failed to create batch, expected a non null batch_id but got {create_batch_response.id}"
|
), f"Failed to create batch, expected a non null batch_id but got {create_batch_response.id}"
|
||||||
assert (
|
assert (
|
||||||
create_batch_response.endpoint == "/v1/chat/completions"
|
create_batch_response.endpoint == "/v1/chat/completions"
|
||||||
|
or create_batch_response.endpoint == "/chat/completions"
|
||||||
), f"Failed to create batch, expected endpoint to be /v1/chat/completions but got {create_batch_response.endpoint}"
|
), f"Failed to create batch, expected endpoint to be /v1/chat/completions but got {create_batch_response.endpoint}"
|
||||||
assert (
|
assert (
|
||||||
create_batch_response.input_file_id == batch_input_file_id
|
create_batch_response.input_file_id == batch_input_file_id
|
||||||
|
@ -137,7 +143,7 @@ async def test_async_create_batch():
|
||||||
await asyncio.sleep(1)
|
await asyncio.sleep(1)
|
||||||
|
|
||||||
retrieved_batch = await litellm.aretrieve_batch(
|
retrieved_batch = await litellm.aretrieve_batch(
|
||||||
batch_id=create_batch_response.id, custom_llm_provider="openai"
|
batch_id=create_batch_response.id, custom_llm_provider=provider
|
||||||
)
|
)
|
||||||
print("retrieved batch=", retrieved_batch)
|
print("retrieved batch=", retrieved_batch)
|
||||||
# just assert that we retrieved a non None batch
|
# just assert that we retrieved a non None batch
|
||||||
|
@ -145,27 +151,27 @@ async def test_async_create_batch():
|
||||||
assert retrieved_batch.id == create_batch_response.id
|
assert retrieved_batch.id == create_batch_response.id
|
||||||
|
|
||||||
# list all batches
|
# list all batches
|
||||||
list_batches = await litellm.alist_batches(custom_llm_provider="openai", limit=2)
|
list_batches = await litellm.alist_batches(custom_llm_provider=provider, limit=2)
|
||||||
print("list_batches=", list_batches)
|
print("list_batches=", list_batches)
|
||||||
|
|
||||||
# try to get file content for our original file
|
# try to get file content for our original file
|
||||||
|
|
||||||
file_content = await litellm.afile_content(
|
file_content = await litellm.afile_content(
|
||||||
file_id=batch_input_file_id, custom_llm_provider="openai"
|
file_id=batch_input_file_id, custom_llm_provider=provider
|
||||||
)
|
)
|
||||||
|
|
||||||
print("file content = ", file_content)
|
print("file content = ", file_content)
|
||||||
|
|
||||||
# file obj
|
# file obj
|
||||||
file_obj = await litellm.afile_retrieve(
|
file_obj = await litellm.afile_retrieve(
|
||||||
file_id=batch_input_file_id, custom_llm_provider="openai"
|
file_id=batch_input_file_id, custom_llm_provider=provider
|
||||||
)
|
)
|
||||||
print("file obj = ", file_obj)
|
print("file obj = ", file_obj)
|
||||||
assert file_obj.id == batch_input_file_id
|
assert file_obj.id == batch_input_file_id
|
||||||
|
|
||||||
# delete file
|
# delete file
|
||||||
delete_file_response = await litellm.afile_delete(
|
delete_file_response = await litellm.afile_delete(
|
||||||
file_id=batch_input_file_id, custom_llm_provider="openai"
|
file_id=batch_input_file_id, custom_llm_provider=provider
|
||||||
)
|
)
|
||||||
|
|
||||||
print("delete file response = ", delete_file_response)
|
print("delete file response = ", delete_file_response)
|
||||||
|
@ -173,7 +179,7 @@ async def test_async_create_batch():
|
||||||
assert delete_file_response.id == batch_input_file_id
|
assert delete_file_response.id == batch_input_file_id
|
||||||
|
|
||||||
all_files_list = await litellm.afile_list(
|
all_files_list = await litellm.afile_list(
|
||||||
custom_llm_provider="openai",
|
custom_llm_provider=provider,
|
||||||
)
|
)
|
||||||
|
|
||||||
print("all_files_list = ", all_files_list)
|
print("all_files_list = ", all_files_list)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue