forked from phoenix/litellm-mirror
Merge pull request #4987 from BerriAI/litellm_add_ft_endpoints
[Feat-Proxy] Add List fine-tuning jobs
This commit is contained in:
commit
d833c69acb
16 changed files with 1478 additions and 151 deletions
|
@ -14,7 +14,8 @@ from typing import Any, Coroutine, Dict, Literal, Optional, Union
|
|||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm import client
|
||||
from litellm import client, get_secret
|
||||
from litellm.llms.files_apis.azure import AzureOpenAIFilesAPI
|
||||
from litellm.llms.openai import FileDeleted, FileObject, OpenAIFilesAPI
|
||||
from litellm.types.llms.openai import (
|
||||
Batch,
|
||||
|
@ -28,12 +29,13 @@ from litellm.utils import supports_httpx_timeout
|
|||
|
||||
####### ENVIRONMENT VARIABLES ###################
|
||||
openai_files_instance = OpenAIFilesAPI()
|
||||
azure_files_instance = AzureOpenAIFilesAPI()
|
||||
#################################################
|
||||
|
||||
|
||||
async def afile_retrieve(
|
||||
file_id: str,
|
||||
custom_llm_provider: Literal["openai"] = "openai",
|
||||
custom_llm_provider: Literal["openai", "azure"] = "openai",
|
||||
extra_headers: Optional[Dict[str, str]] = None,
|
||||
extra_body: Optional[Dict[str, str]] = None,
|
||||
**kwargs,
|
||||
|
@ -73,7 +75,7 @@ async def afile_retrieve(
|
|||
|
||||
def file_retrieve(
|
||||
file_id: str,
|
||||
custom_llm_provider: Literal["openai"] = "openai",
|
||||
custom_llm_provider: Literal["openai", "azure"] = "openai",
|
||||
extra_headers: Optional[Dict[str, str]] = None,
|
||||
extra_body: Optional[Dict[str, str]] = None,
|
||||
**kwargs,
|
||||
|
@ -156,7 +158,7 @@ def file_retrieve(
|
|||
# Delete file
|
||||
async def afile_delete(
|
||||
file_id: str,
|
||||
custom_llm_provider: Literal["openai"] = "openai",
|
||||
custom_llm_provider: Literal["openai", "azure"] = "openai",
|
||||
extra_headers: Optional[Dict[str, str]] = None,
|
||||
extra_body: Optional[Dict[str, str]] = None,
|
||||
**kwargs,
|
||||
|
@ -196,7 +198,7 @@ async def afile_delete(
|
|||
|
||||
def file_delete(
|
||||
file_id: str,
|
||||
custom_llm_provider: Literal["openai"] = "openai",
|
||||
custom_llm_provider: Literal["openai", "azure"] = "openai",
|
||||
extra_headers: Optional[Dict[str, str]] = None,
|
||||
extra_body: Optional[Dict[str, str]] = None,
|
||||
**kwargs,
|
||||
|
@ -208,6 +210,22 @@ def file_delete(
|
|||
"""
|
||||
try:
|
||||
optional_params = GenericLiteLLMParams(**kwargs)
|
||||
### TIMEOUT LOGIC ###
|
||||
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
|
||||
# set timeout for 10 minutes by default
|
||||
|
||||
if (
|
||||
timeout is not None
|
||||
and isinstance(timeout, httpx.Timeout)
|
||||
and supports_httpx_timeout(custom_llm_provider) == False
|
||||
):
|
||||
read_timeout = timeout.read or 600
|
||||
timeout = read_timeout # default 10 min timeout
|
||||
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
|
||||
timeout = float(timeout) # type: ignore
|
||||
elif timeout is None:
|
||||
timeout = 600.0
|
||||
_is_async = kwargs.pop("is_async", False) is True
|
||||
if custom_llm_provider == "openai":
|
||||
# for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
|
||||
api_base = (
|
||||
|
@ -229,26 +247,6 @@ def file_delete(
|
|||
or litellm.openai_key
|
||||
or os.getenv("OPENAI_API_KEY")
|
||||
)
|
||||
### TIMEOUT LOGIC ###
|
||||
timeout = (
|
||||
optional_params.timeout or kwargs.get("request_timeout", 600) or 600
|
||||
)
|
||||
# set timeout for 10 minutes by default
|
||||
|
||||
if (
|
||||
timeout is not None
|
||||
and isinstance(timeout, httpx.Timeout)
|
||||
and supports_httpx_timeout(custom_llm_provider) == False
|
||||
):
|
||||
read_timeout = timeout.read or 600
|
||||
timeout = read_timeout # default 10 min timeout
|
||||
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
|
||||
timeout = float(timeout) # type: ignore
|
||||
elif timeout is None:
|
||||
timeout = 600.0
|
||||
|
||||
_is_async = kwargs.pop("is_async", False) is True
|
||||
|
||||
response = openai_files_instance.delete_file(
|
||||
file_id=file_id,
|
||||
_is_async=_is_async,
|
||||
|
@ -258,6 +256,38 @@ def file_delete(
|
|||
max_retries=optional_params.max_retries,
|
||||
organization=organization,
|
||||
)
|
||||
elif custom_llm_provider == "azure":
|
||||
api_base = optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE") # type: ignore
|
||||
api_version = (
|
||||
optional_params.api_version
|
||||
or litellm.api_version
|
||||
or get_secret("AZURE_API_VERSION")
|
||||
) # type: ignore
|
||||
|
||||
api_key = (
|
||||
optional_params.api_key
|
||||
or litellm.api_key
|
||||
or litellm.azure_key
|
||||
or get_secret("AZURE_OPENAI_API_KEY")
|
||||
or get_secret("AZURE_API_KEY")
|
||||
) # type: ignore
|
||||
|
||||
extra_body = optional_params.get("extra_body", {})
|
||||
azure_ad_token: Optional[str] = None
|
||||
if extra_body is not None:
|
||||
azure_ad_token = extra_body.pop("azure_ad_token", None)
|
||||
else:
|
||||
azure_ad_token = get_secret("AZURE_AD_TOKEN") # type: ignore
|
||||
|
||||
response = azure_files_instance.delete_file(
|
||||
_is_async=_is_async,
|
||||
api_base=api_base,
|
||||
api_key=api_key,
|
||||
api_version=api_version,
|
||||
timeout=timeout,
|
||||
max_retries=optional_params.max_retries,
|
||||
file_id=file_id,
|
||||
)
|
||||
else:
|
||||
raise litellm.exceptions.BadRequestError(
|
||||
message="LiteLLM doesn't support {} for 'create_batch'. Only 'openai' is supported.".format(
|
||||
|
@ -278,7 +308,7 @@ def file_delete(
|
|||
|
||||
# List files
|
||||
async def afile_list(
|
||||
custom_llm_provider: Literal["openai"] = "openai",
|
||||
custom_llm_provider: Literal["openai", "azure"] = "openai",
|
||||
purpose: Optional[str] = None,
|
||||
extra_headers: Optional[Dict[str, str]] = None,
|
||||
extra_body: Optional[Dict[str, str]] = None,
|
||||
|
@ -318,7 +348,7 @@ async def afile_list(
|
|||
|
||||
|
||||
def file_list(
|
||||
custom_llm_provider: Literal["openai"] = "openai",
|
||||
custom_llm_provider: Literal["openai", "azure"] = "openai",
|
||||
purpose: Optional[str] = None,
|
||||
extra_headers: Optional[Dict[str, str]] = None,
|
||||
extra_body: Optional[Dict[str, str]] = None,
|
||||
|
@ -402,7 +432,7 @@ def file_list(
|
|||
async def acreate_file(
|
||||
file: FileTypes,
|
||||
purpose: Literal["assistants", "batch", "fine-tune"],
|
||||
custom_llm_provider: Literal["openai"] = "openai",
|
||||
custom_llm_provider: Literal["openai", "azure"] = "openai",
|
||||
extra_headers: Optional[Dict[str, str]] = None,
|
||||
extra_body: Optional[Dict[str, str]] = None,
|
||||
**kwargs,
|
||||
|
@ -444,7 +474,7 @@ async def acreate_file(
|
|||
def create_file(
|
||||
file: FileTypes,
|
||||
purpose: Literal["assistants", "batch", "fine-tune"],
|
||||
custom_llm_provider: Literal["openai"] = "openai",
|
||||
custom_llm_provider: Literal["openai", "azure"] = "openai",
|
||||
extra_headers: Optional[Dict[str, str]] = None,
|
||||
extra_body: Optional[Dict[str, str]] = None,
|
||||
**kwargs,
|
||||
|
@ -455,7 +485,31 @@ def create_file(
|
|||
LiteLLM Equivalent of POST: POST https://api.openai.com/v1/files
|
||||
"""
|
||||
try:
|
||||
_is_async = kwargs.pop("acreate_file", False) is True
|
||||
optional_params = GenericLiteLLMParams(**kwargs)
|
||||
|
||||
### TIMEOUT LOGIC ###
|
||||
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
|
||||
# set timeout for 10 minutes by default
|
||||
|
||||
if (
|
||||
timeout is not None
|
||||
and isinstance(timeout, httpx.Timeout)
|
||||
and supports_httpx_timeout(custom_llm_provider) == False
|
||||
):
|
||||
read_timeout = timeout.read or 600
|
||||
timeout = read_timeout # default 10 min timeout
|
||||
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
|
||||
timeout = float(timeout) # type: ignore
|
||||
elif timeout is None:
|
||||
timeout = 600.0
|
||||
|
||||
_create_file_request = CreateFileRequest(
|
||||
file=file,
|
||||
purpose=purpose,
|
||||
extra_headers=extra_headers,
|
||||
extra_body=extra_body,
|
||||
)
|
||||
if custom_llm_provider == "openai":
|
||||
# for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
|
||||
api_base = (
|
||||
|
@ -477,32 +531,6 @@ def create_file(
|
|||
or litellm.openai_key
|
||||
or os.getenv("OPENAI_API_KEY")
|
||||
)
|
||||
### TIMEOUT LOGIC ###
|
||||
timeout = (
|
||||
optional_params.timeout or kwargs.get("request_timeout", 600) or 600
|
||||
)
|
||||
# set timeout for 10 minutes by default
|
||||
|
||||
if (
|
||||
timeout is not None
|
||||
and isinstance(timeout, httpx.Timeout)
|
||||
and supports_httpx_timeout(custom_llm_provider) == False
|
||||
):
|
||||
read_timeout = timeout.read or 600
|
||||
timeout = read_timeout # default 10 min timeout
|
||||
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
|
||||
timeout = float(timeout) # type: ignore
|
||||
elif timeout is None:
|
||||
timeout = 600.0
|
||||
|
||||
_create_file_request = CreateFileRequest(
|
||||
file=file,
|
||||
purpose=purpose,
|
||||
extra_headers=extra_headers,
|
||||
extra_body=extra_body,
|
||||
)
|
||||
|
||||
_is_async = kwargs.pop("acreate_file", False) is True
|
||||
|
||||
response = openai_files_instance.create_file(
|
||||
_is_async=_is_async,
|
||||
|
@ -513,6 +541,38 @@ def create_file(
|
|||
organization=organization,
|
||||
create_file_data=_create_file_request,
|
||||
)
|
||||
elif custom_llm_provider == "azure":
|
||||
api_base = optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE") # type: ignore
|
||||
api_version = (
|
||||
optional_params.api_version
|
||||
or litellm.api_version
|
||||
or get_secret("AZURE_API_VERSION")
|
||||
) # type: ignore
|
||||
|
||||
api_key = (
|
||||
optional_params.api_key
|
||||
or litellm.api_key
|
||||
or litellm.azure_key
|
||||
or get_secret("AZURE_OPENAI_API_KEY")
|
||||
or get_secret("AZURE_API_KEY")
|
||||
) # type: ignore
|
||||
|
||||
extra_body = optional_params.get("extra_body", {})
|
||||
azure_ad_token: Optional[str] = None
|
||||
if extra_body is not None:
|
||||
azure_ad_token = extra_body.pop("azure_ad_token", None)
|
||||
else:
|
||||
azure_ad_token = get_secret("AZURE_AD_TOKEN") # type: ignore
|
||||
|
||||
response = azure_files_instance.create_file(
|
||||
_is_async=_is_async,
|
||||
api_base=api_base,
|
||||
api_key=api_key,
|
||||
api_version=api_version,
|
||||
timeout=timeout,
|
||||
max_retries=optional_params.max_retries,
|
||||
create_file_data=_create_file_request,
|
||||
)
|
||||
else:
|
||||
raise litellm.exceptions.BadRequestError(
|
||||
message="LiteLLM doesn't support {} for 'create_batch'. Only 'openai' is supported.".format(
|
||||
|
@ -533,7 +593,7 @@ def create_file(
|
|||
|
||||
async def afile_content(
|
||||
file_id: str,
|
||||
custom_llm_provider: Literal["openai"] = "openai",
|
||||
custom_llm_provider: Literal["openai", "azure"] = "openai",
|
||||
extra_headers: Optional[Dict[str, str]] = None,
|
||||
extra_body: Optional[Dict[str, str]] = None,
|
||||
**kwargs,
|
||||
|
@ -573,7 +633,7 @@ async def afile_content(
|
|||
|
||||
def file_content(
|
||||
file_id: str,
|
||||
custom_llm_provider: Literal["openai"] = "openai",
|
||||
custom_llm_provider: Literal["openai", "azure"] = "openai",
|
||||
extra_headers: Optional[Dict[str, str]] = None,
|
||||
extra_body: Optional[Dict[str, str]] = None,
|
||||
**kwargs,
|
||||
|
|
|
@ -17,7 +17,10 @@ from typing import Any, Coroutine, Dict, Literal, Optional, Union
|
|||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm.llms.openai_fine_tuning.openai import (
|
||||
from litellm import get_secret
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.llms.fine_tuning_apis.azure import AzureOpenAIFineTuningAPI
|
||||
from litellm.llms.fine_tuning_apis.openai import (
|
||||
FineTuningJob,
|
||||
FineTuningJobCreate,
|
||||
OpenAIFineTuningAPI,
|
||||
|
@ -27,7 +30,8 @@ from litellm.types.router import *
|
|||
from litellm.utils import supports_httpx_timeout
|
||||
|
||||
####### ENVIRONMENT VARIABLES ###################
|
||||
openai_fine_tuning_instance = OpenAIFineTuningAPI()
|
||||
openai_fine_tuning_apis_instance = OpenAIFineTuningAPI()
|
||||
azure_fine_tuning_apis_instance = AzureOpenAIFineTuningAPI()
|
||||
#################################################
|
||||
|
||||
|
||||
|
@ -39,7 +43,7 @@ async def acreate_fine_tuning_job(
|
|||
validation_file: Optional[str] = None,
|
||||
integrations: Optional[List[str]] = None,
|
||||
seed: Optional[int] = None,
|
||||
custom_llm_provider: Literal["openai"] = "openai",
|
||||
custom_llm_provider: Literal["openai", "azure"] = "openai",
|
||||
extra_headers: Optional[Dict[str, str]] = None,
|
||||
extra_body: Optional[Dict[str, str]] = None,
|
||||
**kwargs,
|
||||
|
@ -48,6 +52,9 @@ async def acreate_fine_tuning_job(
|
|||
Async: Creates and executes a batch from an uploaded file of request
|
||||
|
||||
"""
|
||||
verbose_logger.debug(
|
||||
"inside acreate_fine_tuning_job model=%s and kwargs=%s", model, kwargs
|
||||
)
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
kwargs["acreate_fine_tuning_job"] = True
|
||||
|
@ -89,7 +96,7 @@ def create_fine_tuning_job(
|
|||
validation_file: Optional[str] = None,
|
||||
integrations: Optional[List[str]] = None,
|
||||
seed: Optional[int] = None,
|
||||
custom_llm_provider: Literal["openai"] = "openai",
|
||||
custom_llm_provider: Literal["openai", "azure"] = "openai",
|
||||
extra_headers: Optional[Dict[str, str]] = None,
|
||||
extra_body: Optional[Dict[str, str]] = None,
|
||||
**kwargs,
|
||||
|
@ -101,7 +108,25 @@ def create_fine_tuning_job(
|
|||
|
||||
"""
|
||||
try:
|
||||
_is_async = kwargs.pop("acreate_fine_tuning_job", False) is True
|
||||
optional_params = GenericLiteLLMParams(**kwargs)
|
||||
### TIMEOUT LOGIC ###
|
||||
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
|
||||
# set timeout for 10 minutes by default
|
||||
|
||||
if (
|
||||
timeout is not None
|
||||
and isinstance(timeout, httpx.Timeout)
|
||||
and supports_httpx_timeout(custom_llm_provider) == False
|
||||
):
|
||||
read_timeout = timeout.read or 600
|
||||
timeout = read_timeout # default 10 min timeout
|
||||
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
|
||||
timeout = float(timeout) # type: ignore
|
||||
elif timeout is None:
|
||||
timeout = 600.0
|
||||
|
||||
# OpenAI
|
||||
if custom_llm_provider == "openai":
|
||||
|
||||
# for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
|
||||
|
@ -124,25 +149,6 @@ def create_fine_tuning_job(
|
|||
or litellm.openai_key
|
||||
or os.getenv("OPENAI_API_KEY")
|
||||
)
|
||||
### TIMEOUT LOGIC ###
|
||||
timeout = (
|
||||
optional_params.timeout or kwargs.get("request_timeout", 600) or 600
|
||||
)
|
||||
# set timeout for 10 minutes by default
|
||||
|
||||
if (
|
||||
timeout is not None
|
||||
and isinstance(timeout, httpx.Timeout)
|
||||
and supports_httpx_timeout(custom_llm_provider) == False
|
||||
):
|
||||
read_timeout = timeout.read or 600
|
||||
timeout = read_timeout # default 10 min timeout
|
||||
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
|
||||
timeout = float(timeout) # type: ignore
|
||||
elif timeout is None:
|
||||
timeout = 600.0
|
||||
|
||||
_is_async = kwargs.pop("acreate_fine_tuning_job", False) is True
|
||||
|
||||
create_fine_tuning_job_data = FineTuningJobCreate(
|
||||
model=model,
|
||||
|
@ -154,11 +160,63 @@ def create_fine_tuning_job(
|
|||
seed=seed,
|
||||
)
|
||||
|
||||
response = openai_fine_tuning_instance.create_fine_tuning_job(
|
||||
create_fine_tuning_job_data_dict = create_fine_tuning_job_data.model_dump(
|
||||
exclude_none=True
|
||||
)
|
||||
|
||||
response = openai_fine_tuning_apis_instance.create_fine_tuning_job(
|
||||
api_base=api_base,
|
||||
api_key=api_key,
|
||||
organization=organization,
|
||||
create_fine_tuning_job_data=create_fine_tuning_job_data,
|
||||
create_fine_tuning_job_data=create_fine_tuning_job_data_dict,
|
||||
timeout=timeout,
|
||||
max_retries=optional_params.max_retries,
|
||||
_is_async=_is_async,
|
||||
)
|
||||
# Azure OpenAI
|
||||
elif custom_llm_provider == "azure":
|
||||
api_base = optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE") # type: ignore
|
||||
|
||||
api_version = (
|
||||
optional_params.api_version
|
||||
or litellm.api_version
|
||||
or get_secret("AZURE_API_VERSION")
|
||||
) # type: ignore
|
||||
|
||||
api_key = (
|
||||
optional_params.api_key
|
||||
or litellm.api_key
|
||||
or litellm.azure_key
|
||||
or get_secret("AZURE_OPENAI_API_KEY")
|
||||
or get_secret("AZURE_API_KEY")
|
||||
) # type: ignore
|
||||
|
||||
extra_body = optional_params.get("extra_body", {})
|
||||
azure_ad_token: Optional[str] = None
|
||||
if extra_body is not None:
|
||||
azure_ad_token = extra_body.pop("azure_ad_token", None)
|
||||
else:
|
||||
azure_ad_token = get_secret("AZURE_AD_TOKEN") # type: ignore
|
||||
|
||||
create_fine_tuning_job_data = FineTuningJobCreate(
|
||||
model=model,
|
||||
training_file=training_file,
|
||||
hyperparameters=hyperparameters,
|
||||
suffix=suffix,
|
||||
validation_file=validation_file,
|
||||
integrations=integrations,
|
||||
seed=seed,
|
||||
)
|
||||
|
||||
create_fine_tuning_job_data_dict = create_fine_tuning_job_data.model_dump(
|
||||
exclude_none=True
|
||||
)
|
||||
|
||||
response = azure_fine_tuning_apis_instance.create_fine_tuning_job(
|
||||
api_base=api_base,
|
||||
api_key=api_key,
|
||||
api_version=api_version,
|
||||
create_fine_tuning_job_data=create_fine_tuning_job_data_dict,
|
||||
timeout=timeout,
|
||||
max_retries=optional_params.max_retries,
|
||||
_is_async=_is_async,
|
||||
|
@ -233,6 +291,25 @@ def cancel_fine_tuning_job(
|
|||
"""
|
||||
try:
|
||||
optional_params = GenericLiteLLMParams(**kwargs)
|
||||
### TIMEOUT LOGIC ###
|
||||
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
|
||||
# set timeout for 10 minutes by default
|
||||
|
||||
if (
|
||||
timeout is not None
|
||||
and isinstance(timeout, httpx.Timeout)
|
||||
and supports_httpx_timeout(custom_llm_provider) == False
|
||||
):
|
||||
read_timeout = timeout.read or 600
|
||||
timeout = read_timeout # default 10 min timeout
|
||||
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
|
||||
timeout = float(timeout) # type: ignore
|
||||
elif timeout is None:
|
||||
timeout = 600.0
|
||||
|
||||
_is_async = kwargs.pop("acancel_fine_tuning_job", False) is True
|
||||
|
||||
# OpenAI
|
||||
if custom_llm_provider == "openai":
|
||||
|
||||
# for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
|
||||
|
@ -255,27 +332,8 @@ def cancel_fine_tuning_job(
|
|||
or litellm.openai_key
|
||||
or os.getenv("OPENAI_API_KEY")
|
||||
)
|
||||
### TIMEOUT LOGIC ###
|
||||
timeout = (
|
||||
optional_params.timeout or kwargs.get("request_timeout", 600) or 600
|
||||
)
|
||||
# set timeout for 10 minutes by default
|
||||
|
||||
if (
|
||||
timeout is not None
|
||||
and isinstance(timeout, httpx.Timeout)
|
||||
and supports_httpx_timeout(custom_llm_provider) == False
|
||||
):
|
||||
read_timeout = timeout.read or 600
|
||||
timeout = read_timeout # default 10 min timeout
|
||||
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
|
||||
timeout = float(timeout) # type: ignore
|
||||
elif timeout is None:
|
||||
timeout = 600.0
|
||||
|
||||
_is_async = kwargs.pop("acancel_fine_tuning_job", False) is True
|
||||
|
||||
response = openai_fine_tuning_instance.cancel_fine_tuning_job(
|
||||
response = openai_fine_tuning_apis_instance.cancel_fine_tuning_job(
|
||||
api_base=api_base,
|
||||
api_key=api_key,
|
||||
organization=organization,
|
||||
|
@ -284,6 +342,40 @@ def cancel_fine_tuning_job(
|
|||
max_retries=optional_params.max_retries,
|
||||
_is_async=_is_async,
|
||||
)
|
||||
# Azure OpenAI
|
||||
elif custom_llm_provider == "azure":
|
||||
api_base = optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE") # type: ignore
|
||||
|
||||
api_version = (
|
||||
optional_params.api_version
|
||||
or litellm.api_version
|
||||
or get_secret("AZURE_API_VERSION")
|
||||
) # type: ignore
|
||||
|
||||
api_key = (
|
||||
optional_params.api_key
|
||||
or litellm.api_key
|
||||
or litellm.azure_key
|
||||
or get_secret("AZURE_OPENAI_API_KEY")
|
||||
or get_secret("AZURE_API_KEY")
|
||||
) # type: ignore
|
||||
|
||||
extra_body = optional_params.get("extra_body", {})
|
||||
azure_ad_token: Optional[str] = None
|
||||
if extra_body is not None:
|
||||
azure_ad_token = extra_body.pop("azure_ad_token", None)
|
||||
else:
|
||||
azure_ad_token = get_secret("AZURE_AD_TOKEN") # type: ignore
|
||||
|
||||
response = azure_fine_tuning_apis_instance.cancel_fine_tuning_job(
|
||||
api_base=api_base,
|
||||
api_key=api_key,
|
||||
api_version=api_version,
|
||||
fine_tuning_job_id=fine_tuning_job_id,
|
||||
timeout=timeout,
|
||||
max_retries=optional_params.max_retries,
|
||||
_is_async=_is_async,
|
||||
)
|
||||
else:
|
||||
raise litellm.exceptions.BadRequestError(
|
||||
message="LiteLLM doesn't support {} for 'create_batch'. Only 'openai' is supported.".format(
|
||||
|
@ -359,6 +451,25 @@ def list_fine_tuning_jobs(
|
|||
"""
|
||||
try:
|
||||
optional_params = GenericLiteLLMParams(**kwargs)
|
||||
### TIMEOUT LOGIC ###
|
||||
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
|
||||
# set timeout for 10 minutes by default
|
||||
|
||||
if (
|
||||
timeout is not None
|
||||
and isinstance(timeout, httpx.Timeout)
|
||||
and supports_httpx_timeout(custom_llm_provider) == False
|
||||
):
|
||||
read_timeout = timeout.read or 600
|
||||
timeout = read_timeout # default 10 min timeout
|
||||
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
|
||||
timeout = float(timeout) # type: ignore
|
||||
elif timeout is None:
|
||||
timeout = 600.0
|
||||
|
||||
_is_async = kwargs.pop("alist_fine_tuning_jobs", False) is True
|
||||
|
||||
# OpenAI
|
||||
if custom_llm_provider == "openai":
|
||||
|
||||
# for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
|
||||
|
@ -381,27 +492,8 @@ def list_fine_tuning_jobs(
|
|||
or litellm.openai_key
|
||||
or os.getenv("OPENAI_API_KEY")
|
||||
)
|
||||
### TIMEOUT LOGIC ###
|
||||
timeout = (
|
||||
optional_params.timeout or kwargs.get("request_timeout", 600) or 600
|
||||
)
|
||||
# set timeout for 10 minutes by default
|
||||
|
||||
if (
|
||||
timeout is not None
|
||||
and isinstance(timeout, httpx.Timeout)
|
||||
and supports_httpx_timeout(custom_llm_provider) == False
|
||||
):
|
||||
read_timeout = timeout.read or 600
|
||||
timeout = read_timeout # default 10 min timeout
|
||||
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
|
||||
timeout = float(timeout) # type: ignore
|
||||
elif timeout is None:
|
||||
timeout = 600.0
|
||||
|
||||
_is_async = kwargs.pop("alist_fine_tuning_jobs", False) is True
|
||||
|
||||
response = openai_fine_tuning_instance.list_fine_tuning_jobs(
|
||||
response = openai_fine_tuning_apis_instance.list_fine_tuning_jobs(
|
||||
api_base=api_base,
|
||||
api_key=api_key,
|
||||
organization=organization,
|
||||
|
@ -411,6 +503,41 @@ def list_fine_tuning_jobs(
|
|||
max_retries=optional_params.max_retries,
|
||||
_is_async=_is_async,
|
||||
)
|
||||
# Azure OpenAI
|
||||
elif custom_llm_provider == "azure":
|
||||
api_base = optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE") # type: ignore
|
||||
|
||||
api_version = (
|
||||
optional_params.api_version
|
||||
or litellm.api_version
|
||||
or get_secret("AZURE_API_VERSION")
|
||||
) # type: ignore
|
||||
|
||||
api_key = (
|
||||
optional_params.api_key
|
||||
or litellm.api_key
|
||||
or litellm.azure_key
|
||||
or get_secret("AZURE_OPENAI_API_KEY")
|
||||
or get_secret("AZURE_API_KEY")
|
||||
) # type: ignore
|
||||
|
||||
extra_body = optional_params.get("extra_body", {})
|
||||
azure_ad_token: Optional[str] = None
|
||||
if extra_body is not None:
|
||||
azure_ad_token = extra_body.pop("azure_ad_token", None)
|
||||
else:
|
||||
azure_ad_token = get_secret("AZURE_AD_TOKEN") # type: ignore
|
||||
|
||||
response = azure_fine_tuning_apis_instance.list_fine_tuning_jobs(
|
||||
api_base=api_base,
|
||||
api_key=api_key,
|
||||
api_version=api_version,
|
||||
after=after,
|
||||
limit=limit,
|
||||
timeout=timeout,
|
||||
max_retries=optional_params.max_retries,
|
||||
_is_async=_is_async,
|
||||
)
|
||||
else:
|
||||
raise litellm.exceptions.BadRequestError(
|
||||
message="LiteLLM doesn't support {} for 'create_batch'. Only 'openai' is supported.".format(
|
||||
|
|
315
litellm/llms/files_apis/azure.py
Normal file
315
litellm/llms/files_apis/azure.py
Normal file
|
@ -0,0 +1,315 @@
|
|||
from typing import Any, Coroutine, Dict, List, Optional, Union
|
||||
|
||||
import httpx
|
||||
from openai import AsyncAzureOpenAI, AzureOpenAI
|
||||
from openai.types.file_deleted import FileDeleted
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.llms.base import BaseLLM
|
||||
from litellm.types.llms.openai import *
|
||||
|
||||
|
||||
def get_azure_openai_client(
|
||||
api_key: Optional[str],
|
||||
api_base: Optional[str],
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
api_version: Optional[str] = None,
|
||||
organization: Optional[str] = None,
|
||||
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
|
||||
_is_async: bool = False,
|
||||
) -> Optional[Union[AzureOpenAI, AsyncAzureOpenAI]]:
|
||||
received_args = locals()
|
||||
openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None
|
||||
if client is None:
|
||||
data = {}
|
||||
for k, v in received_args.items():
|
||||
if k == "self" or k == "client" or k == "_is_async":
|
||||
pass
|
||||
elif k == "api_base" and v is not None:
|
||||
data["azure_endpoint"] = v
|
||||
elif v is not None:
|
||||
data[k] = v
|
||||
if "api_version" not in data:
|
||||
data["api_version"] = litellm.AZURE_DEFAULT_API_VERSION
|
||||
if _is_async is True:
|
||||
openai_client = AsyncAzureOpenAI(**data)
|
||||
else:
|
||||
openai_client = AzureOpenAI(**data) # type: ignore
|
||||
else:
|
||||
openai_client = client
|
||||
|
||||
return openai_client
|
||||
|
||||
|
||||
class AzureOpenAIFilesAPI(BaseLLM):
|
||||
"""
|
||||
AzureOpenAI methods to support for batches
|
||||
- create_file()
|
||||
- retrieve_file()
|
||||
- list_files()
|
||||
- delete_file()
|
||||
- file_content()
|
||||
- update_file()
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
async def acreate_file(
|
||||
self,
|
||||
create_file_data: CreateFileRequest,
|
||||
openai_client: AsyncAzureOpenAI,
|
||||
) -> FileObject:
|
||||
verbose_logger.debug("create_file_data=%s", create_file_data)
|
||||
response = await openai_client.files.create(**create_file_data)
|
||||
verbose_logger.debug("create_file_response=%s", response)
|
||||
return response
|
||||
|
||||
def create_file(
|
||||
self,
|
||||
_is_async: bool,
|
||||
create_file_data: CreateFileRequest,
|
||||
api_base: str,
|
||||
api_key: Optional[str],
|
||||
api_version: Optional[str],
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
|
||||
) -> Union[FileObject, Coroutine[Any, Any, FileObject]]:
|
||||
openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
|
||||
get_azure_openai_client(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
api_version=api_version,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
client=client,
|
||||
_is_async=_is_async,
|
||||
)
|
||||
)
|
||||
if openai_client is None:
|
||||
raise ValueError(
|
||||
"AzureOpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
|
||||
)
|
||||
|
||||
if _is_async is True:
|
||||
if not isinstance(openai_client, AsyncAzureOpenAI):
|
||||
raise ValueError(
|
||||
"AzureOpenAI client is not an instance of AsyncAzureOpenAI. Make sure you passed an AsyncAzureOpenAI client."
|
||||
)
|
||||
return self.acreate_file( # type: ignore
|
||||
create_file_data=create_file_data, openai_client=openai_client
|
||||
)
|
||||
response = openai_client.files.create(**create_file_data)
|
||||
return response
|
||||
|
||||
async def afile_content(
|
||||
self,
|
||||
file_content_request: FileContentRequest,
|
||||
openai_client: AsyncAzureOpenAI,
|
||||
) -> HttpxBinaryResponseContent:
|
||||
response = await openai_client.files.content(**file_content_request)
|
||||
return response
|
||||
|
||||
def file_content(
|
||||
self,
|
||||
_is_async: bool,
|
||||
file_content_request: FileContentRequest,
|
||||
api_base: str,
|
||||
api_key: Optional[str],
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
organization: Optional[str],
|
||||
api_version: Optional[str] = None,
|
||||
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
|
||||
) -> Union[
|
||||
HttpxBinaryResponseContent, Coroutine[Any, Any, HttpxBinaryResponseContent]
|
||||
]:
|
||||
openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
|
||||
get_azure_openai_client(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
timeout=timeout,
|
||||
api_version=api_version,
|
||||
max_retries=max_retries,
|
||||
organization=organization,
|
||||
client=client,
|
||||
_is_async=_is_async,
|
||||
)
|
||||
)
|
||||
if openai_client is None:
|
||||
raise ValueError(
|
||||
"AzureOpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
|
||||
)
|
||||
|
||||
if _is_async is True:
|
||||
if not isinstance(openai_client, AsyncAzureOpenAI):
|
||||
raise ValueError(
|
||||
"AzureOpenAI client is not an instance of AsyncAzureOpenAI. Make sure you passed an AsyncAzureOpenAI client."
|
||||
)
|
||||
return self.afile_content( # type: ignore
|
||||
file_content_request=file_content_request,
|
||||
openai_client=openai_client,
|
||||
)
|
||||
response = openai_client.files.content(**file_content_request)
|
||||
|
||||
return response
|
||||
|
||||
async def aretrieve_file(
|
||||
self,
|
||||
file_id: str,
|
||||
openai_client: AsyncAzureOpenAI,
|
||||
) -> FileObject:
|
||||
response = await openai_client.files.retrieve(file_id=file_id)
|
||||
return response
|
||||
|
||||
def retrieve_file(
|
||||
self,
|
||||
_is_async: bool,
|
||||
file_id: str,
|
||||
api_base: str,
|
||||
api_key: Optional[str],
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
organization: Optional[str],
|
||||
api_version: Optional[str] = None,
|
||||
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
|
||||
):
|
||||
openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
|
||||
get_azure_openai_client(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
organization=organization,
|
||||
api_version=api_version,
|
||||
client=client,
|
||||
_is_async=_is_async,
|
||||
)
|
||||
)
|
||||
if openai_client is None:
|
||||
raise ValueError(
|
||||
"AzureOpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
|
||||
)
|
||||
|
||||
if _is_async is True:
|
||||
if not isinstance(openai_client, AsyncAzureOpenAI):
|
||||
raise ValueError(
|
||||
"AzureOpenAI client is not an instance of AsyncAzureOpenAI. Make sure you passed an AsyncAzureOpenAI client."
|
||||
)
|
||||
return self.aretrieve_file( # type: ignore
|
||||
file_id=file_id,
|
||||
openai_client=openai_client,
|
||||
)
|
||||
response = openai_client.files.retrieve(file_id=file_id)
|
||||
|
||||
return response
|
||||
|
||||
async def adelete_file(
|
||||
self,
|
||||
file_id: str,
|
||||
openai_client: AsyncAzureOpenAI,
|
||||
) -> FileDeleted:
|
||||
response = await openai_client.files.delete(file_id=file_id)
|
||||
return response
|
||||
|
||||
def delete_file(
|
||||
self,
|
||||
_is_async: bool,
|
||||
file_id: str,
|
||||
api_base: str,
|
||||
api_key: Optional[str],
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
organization: Optional[str] = None,
|
||||
api_version: Optional[str] = None,
|
||||
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
|
||||
):
|
||||
openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
|
||||
get_azure_openai_client(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
organization=organization,
|
||||
api_version=api_version,
|
||||
client=client,
|
||||
_is_async=_is_async,
|
||||
)
|
||||
)
|
||||
if openai_client is None:
|
||||
raise ValueError(
|
||||
"AzureOpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
|
||||
)
|
||||
|
||||
if _is_async is True:
|
||||
if not isinstance(openai_client, AsyncAzureOpenAI):
|
||||
raise ValueError(
|
||||
"AzureOpenAI client is not an instance of AsyncAzureOpenAI. Make sure you passed an AsyncAzureOpenAI client."
|
||||
)
|
||||
return self.adelete_file( # type: ignore
|
||||
file_id=file_id,
|
||||
openai_client=openai_client,
|
||||
)
|
||||
response = openai_client.files.delete(file_id=file_id)
|
||||
|
||||
return response
|
||||
|
||||
async def alist_files(
|
||||
self,
|
||||
openai_client: AsyncAzureOpenAI,
|
||||
purpose: Optional[str] = None,
|
||||
):
|
||||
if isinstance(purpose, str):
|
||||
response = await openai_client.files.list(purpose=purpose)
|
||||
else:
|
||||
response = await openai_client.files.list()
|
||||
return response
|
||||
|
||||
def list_files(
|
||||
self,
|
||||
_is_async: bool,
|
||||
api_base: str,
|
||||
api_key: Optional[str],
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
organization: Optional[str],
|
||||
purpose: Optional[str] = None,
|
||||
api_version: Optional[str] = None,
|
||||
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
|
||||
):
|
||||
openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
|
||||
get_azure_openai_client(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
organization=organization,
|
||||
api_version=api_version,
|
||||
client=client,
|
||||
_is_async=_is_async,
|
||||
)
|
||||
)
|
||||
if openai_client is None:
|
||||
raise ValueError(
|
||||
"AzureOpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
|
||||
)
|
||||
|
||||
if _is_async is True:
|
||||
if not isinstance(openai_client, AsyncAzureOpenAI):
|
||||
raise ValueError(
|
||||
"AzureOpenAI client is not an instance of AsyncAzureOpenAI. Make sure you passed an AsyncAzureOpenAI client."
|
||||
)
|
||||
return self.alist_files( # type: ignore
|
||||
purpose=purpose,
|
||||
openai_client=openai_client,
|
||||
)
|
||||
|
||||
if isinstance(purpose, str):
|
||||
response = openai_client.files.list(purpose=purpose)
|
||||
else:
|
||||
response = openai_client.files.list()
|
||||
|
||||
return response
|
181
litellm/llms/fine_tuning_apis/azure.py
Normal file
181
litellm/llms/fine_tuning_apis/azure.py
Normal file
|
@ -0,0 +1,181 @@
|
|||
from typing import Any, Coroutine, Optional, Union
|
||||
|
||||
import httpx
|
||||
from openai import AsyncAzureOpenAI, AzureOpenAI
|
||||
from openai.pagination import AsyncCursorPage
|
||||
from openai.types.fine_tuning import FineTuningJob
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.llms.base import BaseLLM
|
||||
from litellm.llms.files_apis.azure import get_azure_openai_client
|
||||
from litellm.types.llms.openai import FineTuningJobCreate
|
||||
|
||||
|
||||
class AzureOpenAIFineTuningAPI(BaseLLM):
|
||||
"""
|
||||
AzureOpenAI methods to support for batches
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
async def acreate_fine_tuning_job(
|
||||
self,
|
||||
create_fine_tuning_job_data: dict,
|
||||
openai_client: AsyncAzureOpenAI,
|
||||
) -> FineTuningJob:
|
||||
response = await openai_client.fine_tuning.jobs.create(
|
||||
**create_fine_tuning_job_data # type: ignore
|
||||
)
|
||||
return response
|
||||
|
||||
def create_fine_tuning_job(
|
||||
self,
|
||||
_is_async: bool,
|
||||
create_fine_tuning_job_data: dict,
|
||||
api_key: Optional[str],
|
||||
api_base: Optional[str],
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
organization: Optional[str] = None,
|
||||
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
|
||||
api_version: Optional[str] = None,
|
||||
) -> Union[FineTuningJob, Union[Coroutine[Any, Any, FineTuningJob]]]:
|
||||
openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
|
||||
get_azure_openai_client(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
organization=organization,
|
||||
api_version=api_version,
|
||||
client=client,
|
||||
_is_async=_is_async,
|
||||
)
|
||||
)
|
||||
if openai_client is None:
|
||||
raise ValueError(
|
||||
"AzureOpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
|
||||
)
|
||||
|
||||
if _is_async is True:
|
||||
if not isinstance(openai_client, AsyncAzureOpenAI):
|
||||
raise ValueError(
|
||||
"AzureOpenAI client is not an instance of AsyncAzureOpenAI. Make sure you passed an AsyncAzureOpenAI client."
|
||||
)
|
||||
return self.acreate_fine_tuning_job( # type: ignore
|
||||
create_fine_tuning_job_data=create_fine_tuning_job_data,
|
||||
openai_client=openai_client,
|
||||
)
|
||||
verbose_logger.debug(
|
||||
"creating fine tuning job, args= %s", create_fine_tuning_job_data
|
||||
)
|
||||
response = openai_client.fine_tuning.jobs.create(**create_fine_tuning_job_data) # type: ignore
|
||||
return response
|
||||
|
||||
async def acancel_fine_tuning_job(
|
||||
self,
|
||||
fine_tuning_job_id: str,
|
||||
openai_client: AsyncAzureOpenAI,
|
||||
) -> FineTuningJob:
|
||||
response = await openai_client.fine_tuning.jobs.cancel(
|
||||
fine_tuning_job_id=fine_tuning_job_id
|
||||
)
|
||||
return response
|
||||
|
||||
def cancel_fine_tuning_job(
|
||||
self,
|
||||
_is_async: bool,
|
||||
fine_tuning_job_id: str,
|
||||
api_key: Optional[str],
|
||||
api_base: Optional[str],
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
organization: Optional[str] = None,
|
||||
api_version: Optional[str] = None,
|
||||
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
|
||||
):
|
||||
openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
|
||||
get_azure_openai_client(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
api_version=api_version,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
organization=organization,
|
||||
client=client,
|
||||
_is_async=_is_async,
|
||||
)
|
||||
)
|
||||
if openai_client is None:
|
||||
raise ValueError(
|
||||
"AzureOpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
|
||||
)
|
||||
|
||||
if _is_async is True:
|
||||
if not isinstance(openai_client, AsyncAzureOpenAI):
|
||||
raise ValueError(
|
||||
"AzureOpenAI client is not an instance of AsyncAzureOpenAI. Make sure you passed an AsyncAzureOpenAI client."
|
||||
)
|
||||
return self.acancel_fine_tuning_job( # type: ignore
|
||||
fine_tuning_job_id=fine_tuning_job_id,
|
||||
openai_client=openai_client,
|
||||
)
|
||||
verbose_logger.debug("canceling fine tuning job, args= %s", fine_tuning_job_id)
|
||||
response = openai_client.fine_tuning.jobs.cancel(
|
||||
fine_tuning_job_id=fine_tuning_job_id
|
||||
)
|
||||
return response
|
||||
|
||||
async def alist_fine_tuning_jobs(
|
||||
self,
|
||||
openai_client: AsyncAzureOpenAI,
|
||||
after: Optional[str] = None,
|
||||
limit: Optional[int] = None,
|
||||
):
|
||||
response = await openai_client.fine_tuning.jobs.list(after=after, limit=limit) # type: ignore
|
||||
return response
|
||||
|
||||
def list_fine_tuning_jobs(
|
||||
self,
|
||||
_is_async: bool,
|
||||
api_key: Optional[str],
|
||||
api_base: Optional[str],
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
organization: Optional[str] = None,
|
||||
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
|
||||
api_version: Optional[str] = None,
|
||||
after: Optional[str] = None,
|
||||
limit: Optional[int] = None,
|
||||
):
|
||||
openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
|
||||
get_azure_openai_client(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
api_version=api_version,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
organization=organization,
|
||||
client=client,
|
||||
_is_async=_is_async,
|
||||
)
|
||||
)
|
||||
if openai_client is None:
|
||||
raise ValueError(
|
||||
"AzureOpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
|
||||
)
|
||||
|
||||
if _is_async is True:
|
||||
if not isinstance(openai_client, AsyncAzureOpenAI):
|
||||
raise ValueError(
|
||||
"AzureOpenAI client is not an instance of AsyncAzureOpenAI. Make sure you passed an AsyncAzureOpenAI client."
|
||||
)
|
||||
return self.alist_fine_tuning_jobs( # type: ignore
|
||||
after=after,
|
||||
limit=limit,
|
||||
openai_client=openai_client,
|
||||
)
|
||||
verbose_logger.debug("list fine tuning job, after= %s, limit= %s", after, limit)
|
||||
response = openai_client.fine_tuning.jobs.list(after=after, limit=limit) # type: ignore
|
||||
return response
|
|
@ -50,18 +50,18 @@ class OpenAIFineTuningAPI(BaseLLM):
|
|||
|
||||
async def acreate_fine_tuning_job(
|
||||
self,
|
||||
create_fine_tuning_job_data: FineTuningJobCreate,
|
||||
create_fine_tuning_job_data: dict,
|
||||
openai_client: AsyncOpenAI,
|
||||
) -> FineTuningJob:
|
||||
response = await openai_client.fine_tuning.jobs.create(
|
||||
**create_fine_tuning_job_data # type: ignore
|
||||
**create_fine_tuning_job_data
|
||||
)
|
||||
return response
|
||||
|
||||
def create_fine_tuning_job(
|
||||
self,
|
||||
_is_async: bool,
|
||||
create_fine_tuning_job_data: FineTuningJobCreate,
|
||||
create_fine_tuning_job_data: dict,
|
||||
api_key: Optional[str],
|
||||
api_base: Optional[str],
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
|
@ -95,7 +95,7 @@ class OpenAIFineTuningAPI(BaseLLM):
|
|||
verbose_logger.debug(
|
||||
"creating fine tuning job, args= %s", create_fine_tuning_job_data
|
||||
)
|
||||
response = openai_client.fine_tuning.jobs.create(**create_fine_tuning_job_data) # type: ignore
|
||||
response = openai_client.fine_tuning.jobs.create(**create_fine_tuning_job_data)
|
||||
return response
|
||||
|
||||
async def acancel_fine_tuning_job(
|
|
@ -208,6 +208,11 @@ class LiteLLMRoutes(enum.Enum):
|
|||
"/files/{file_id}",
|
||||
"/v1/files/{file_id}/content",
|
||||
"/files/{file_id}/content",
|
||||
# fine_tuning
|
||||
"/fine_tuning/jobs",
|
||||
"v1/fine_tuning/jobs",
|
||||
"/fine_tuning/jobs/{fine_tuning_job_id}/cancel"
|
||||
"/v1/fine_tuning/jobs/{fine_tuning_job_id}/cancel"
|
||||
# assistants-related routes
|
||||
"/assistants",
|
||||
"/v1/assistants",
|
||||
|
|
431
litellm/proxy/fine_tuning_endpoints/endpoints.py
Normal file
431
litellm/proxy/fine_tuning_endpoints/endpoints.py
Normal file
|
@ -0,0 +1,431 @@
|
|||
#########################################################################
|
||||
|
||||
# /v1/fine_tuning Endpoints
|
||||
|
||||
# Equivalent of https://platform.openai.com/docs/api-reference/fine-tuning
|
||||
##########################################################################
|
||||
|
||||
import asyncio
|
||||
import traceback
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import List, Optional
|
||||
|
||||
import fastapi
|
||||
import httpx
|
||||
from fastapi import (
|
||||
APIRouter,
|
||||
Depends,
|
||||
File,
|
||||
Form,
|
||||
Header,
|
||||
HTTPException,
|
||||
Request,
|
||||
Response,
|
||||
UploadFile,
|
||||
status,
|
||||
)
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.batches.main import FileObject
|
||||
from litellm.proxy._types import *
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
from litellm.types.llms.openai import LiteLLMFineTuningJobCreate
|
||||
|
||||
fine_tuning_config = None
|
||||
|
||||
|
||||
def set_fine_tuning_config(config):
|
||||
if config is None:
|
||||
return
|
||||
|
||||
global fine_tuning_config
|
||||
if not isinstance(config, list):
|
||||
raise ValueError("invalid fine_tuning config, expected a list is not a list")
|
||||
|
||||
for element in config:
|
||||
if isinstance(element, dict):
|
||||
for key, value in element.items():
|
||||
if isinstance(value, str) and value.startswith("os.environ/"):
|
||||
element[key] = litellm.get_secret(value)
|
||||
|
||||
fine_tuning_config = config
|
||||
|
||||
|
||||
# Function to search for specific custom_llm_provider and return its configuration
|
||||
def get_fine_tuning_provider_config(
|
||||
custom_llm_provider: str,
|
||||
):
|
||||
global fine_tuning_config
|
||||
if fine_tuning_config is None:
|
||||
raise ValueError(
|
||||
"fine_tuning_config is not set, set it on your config.yaml file."
|
||||
)
|
||||
for setting in fine_tuning_config:
|
||||
if setting.get("custom_llm_provider") == custom_llm_provider:
|
||||
return setting
|
||||
return None
|
||||
|
||||
|
||||
@router.post(
|
||||
"/v1/fine_tuning/jobs",
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
tags=["fine-tuning"],
|
||||
summary="✨ (Enterprise) Create Fine-Tuning Job",
|
||||
)
|
||||
@router.post(
|
||||
"/fine_tuning/jobs",
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
tags=["fine-tuning"],
|
||||
summary="✨ (Enterprise) Create Fine-Tuning Job",
|
||||
)
|
||||
async def create_fine_tuning_job(
|
||||
request: Request,
|
||||
fastapi_response: Response,
|
||||
fine_tuning_request: LiteLLMFineTuningJobCreate,
|
||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
):
|
||||
"""
|
||||
Creates a fine-tuning job which begins the process of creating a new model from a given dataset.
|
||||
This is the equivalent of POST https://api.openai.com/v1/fine_tuning/jobs
|
||||
|
||||
Supports Identical Params as: https://platform.openai.com/docs/api-reference/fine-tuning/create
|
||||
|
||||
Example Curl:
|
||||
```
|
||||
curl http://localhost:4000/v1/fine_tuning/jobs \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "Authorization: Bearer sk-1234" \
|
||||
-d '{
|
||||
"model": "gpt-3.5-turbo",
|
||||
"training_file": "file-abc123",
|
||||
"hyperparameters": {
|
||||
"n_epochs": 4
|
||||
}
|
||||
}'
|
||||
```
|
||||
"""
|
||||
from litellm.proxy.proxy_server import (
|
||||
add_litellm_data_to_request,
|
||||
general_settings,
|
||||
get_custom_headers,
|
||||
premium_user,
|
||||
proxy_config,
|
||||
proxy_logging_obj,
|
||||
version,
|
||||
)
|
||||
|
||||
try:
|
||||
if premium_user is not True:
|
||||
raise ValueError(
|
||||
f"Only premium users can use this endpoint + {CommonProxyErrors.not_premium_user.value}"
|
||||
)
|
||||
# Convert Pydantic model to dict
|
||||
data = fine_tuning_request.model_dump(exclude_none=True)
|
||||
|
||||
verbose_proxy_logger.debug(
|
||||
"Request received by LiteLLM:\n{}".format(json.dumps(data, indent=4)),
|
||||
)
|
||||
|
||||
# Include original request and headers in the data
|
||||
data = await add_litellm_data_to_request(
|
||||
data=data,
|
||||
request=request,
|
||||
general_settings=general_settings,
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
version=version,
|
||||
proxy_config=proxy_config,
|
||||
)
|
||||
|
||||
# get configs for custom_llm_provider
|
||||
llm_provider_config = get_fine_tuning_provider_config(
|
||||
custom_llm_provider=fine_tuning_request.custom_llm_provider,
|
||||
)
|
||||
|
||||
# add llm_provider_config to data
|
||||
data.update(llm_provider_config)
|
||||
|
||||
response = await litellm.acreate_fine_tuning_job(**data)
|
||||
|
||||
### ALERTING ###
|
||||
asyncio.create_task(
|
||||
proxy_logging_obj.update_request_status(
|
||||
litellm_call_id=data.get("litellm_call_id", ""), status="success"
|
||||
)
|
||||
)
|
||||
|
||||
### RESPONSE HEADERS ###
|
||||
hidden_params = getattr(response, "_hidden_params", {}) or {}
|
||||
model_id = hidden_params.get("model_id", None) or ""
|
||||
cache_key = hidden_params.get("cache_key", None) or ""
|
||||
api_base = hidden_params.get("api_base", None) or ""
|
||||
|
||||
fastapi_response.headers.update(
|
||||
get_custom_headers(
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
model_id=model_id,
|
||||
cache_key=cache_key,
|
||||
api_base=api_base,
|
||||
version=version,
|
||||
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
|
||||
)
|
||||
)
|
||||
|
||||
return response
|
||||
except Exception as e:
|
||||
await proxy_logging_obj.post_call_failure_hook(
|
||||
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
|
||||
)
|
||||
verbose_proxy_logger.error(
|
||||
"litellm.proxy.proxy_server.create_fine_tuning_job(): Exception occurred - {}".format(
|
||||
str(e)
|
||||
)
|
||||
)
|
||||
verbose_proxy_logger.debug(traceback.format_exc())
|
||||
if isinstance(e, HTTPException):
|
||||
raise ProxyException(
|
||||
message=getattr(e, "message", str(e.detail)),
|
||||
type=getattr(e, "type", "None"),
|
||||
param=getattr(e, "param", "None"),
|
||||
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
|
||||
)
|
||||
else:
|
||||
error_msg = f"{str(e)}"
|
||||
raise ProxyException(
|
||||
message=getattr(e, "message", error_msg),
|
||||
type=getattr(e, "type", "None"),
|
||||
param=getattr(e, "param", "None"),
|
||||
code=getattr(e, "status_code", 500),
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/v1/fine_tuning/jobs",
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
tags=["fine-tuning"],
|
||||
summary="✨ (Enterprise) List Fine-Tuning Jobs",
|
||||
)
|
||||
@router.get(
|
||||
"/fine_tuning/jobs",
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
tags=["fine-tuning"],
|
||||
summary="✨ (Enterprise) List Fine-Tuning Jobs",
|
||||
)
|
||||
async def list_fine_tuning_jobs(
|
||||
request: Request,
|
||||
fastapi_response: Response,
|
||||
custom_llm_provider: Literal["openai", "azure"],
|
||||
after: Optional[str] = None,
|
||||
limit: Optional[int] = None,
|
||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
):
|
||||
"""
|
||||
Lists fine-tuning jobs for the organization.
|
||||
This is the equivalent of GET https://api.openai.com/v1/fine_tuning/jobs
|
||||
|
||||
Supported Query Params:
|
||||
- `custom_llm_provider`: Name of the LiteLLM provider
|
||||
- `after`: Identifier for the last job from the previous pagination request.
|
||||
- `limit`: Number of fine-tuning jobs to retrieve (default is 20).
|
||||
"""
|
||||
from litellm.proxy.proxy_server import (
|
||||
add_litellm_data_to_request,
|
||||
general_settings,
|
||||
get_custom_headers,
|
||||
premium_user,
|
||||
proxy_config,
|
||||
proxy_logging_obj,
|
||||
version,
|
||||
)
|
||||
|
||||
data: dict = {}
|
||||
try:
|
||||
if premium_user is not True:
|
||||
raise ValueError(
|
||||
f"Only premium users can use this endpoint + {CommonProxyErrors.not_premium_user.value}"
|
||||
)
|
||||
# Include original request and headers in the data
|
||||
data = await add_litellm_data_to_request(
|
||||
data=data,
|
||||
request=request,
|
||||
general_settings=general_settings,
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
version=version,
|
||||
proxy_config=proxy_config,
|
||||
)
|
||||
|
||||
# get configs for custom_llm_provider
|
||||
llm_provider_config = get_fine_tuning_provider_config(
|
||||
custom_llm_provider=custom_llm_provider
|
||||
)
|
||||
|
||||
data.update(llm_provider_config)
|
||||
|
||||
response = await litellm.alist_fine_tuning_jobs(
|
||||
**data,
|
||||
after=after,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
### RESPONSE HEADERS ###
|
||||
hidden_params = getattr(response, "_hidden_params", {}) or {}
|
||||
model_id = hidden_params.get("model_id", None) or ""
|
||||
cache_key = hidden_params.get("cache_key", None) or ""
|
||||
api_base = hidden_params.get("api_base", None) or ""
|
||||
|
||||
fastapi_response.headers.update(
|
||||
get_custom_headers(
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
model_id=model_id,
|
||||
cache_key=cache_key,
|
||||
api_base=api_base,
|
||||
version=version,
|
||||
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
|
||||
)
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
await proxy_logging_obj.post_call_failure_hook(
|
||||
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
|
||||
)
|
||||
verbose_proxy_logger.error(
|
||||
"litellm.proxy.proxy_server.list_fine_tuning_jobs(): Exception occurred - {}".format(
|
||||
str(e)
|
||||
)
|
||||
)
|
||||
verbose_proxy_logger.debug(traceback.format_exc())
|
||||
if isinstance(e, HTTPException):
|
||||
raise ProxyException(
|
||||
message=getattr(e, "message", str(e.detail)),
|
||||
type=getattr(e, "type", "None"),
|
||||
param=getattr(e, "param", "None"),
|
||||
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
|
||||
)
|
||||
else:
|
||||
error_msg = f"{str(e)}"
|
||||
raise ProxyException(
|
||||
message=getattr(e, "message", error_msg),
|
||||
type=getattr(e, "type", "None"),
|
||||
param=getattr(e, "param", "None"),
|
||||
code=getattr(e, "status_code", 500),
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/v1/fine_tuning/jobs/{fine_tuning_job_id:path}/cancel",
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
tags=["fine-tuning"],
|
||||
summary="✨ (Enterprise) Cancel Fine-Tuning Jobs",
|
||||
)
|
||||
@router.post(
|
||||
"/fine_tuning/jobs/{fine_tuning_job_id:path}/cancel",
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
tags=["fine-tuning"],
|
||||
summary="✨ (Enterprise) Cancel Fine-Tuning Jobs",
|
||||
)
|
||||
async def retrieve_fine_tuning_job(
|
||||
request: Request,
|
||||
fastapi_response: Response,
|
||||
fine_tuning_job_id: str,
|
||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
):
|
||||
"""
|
||||
Cancel a fine-tuning job.
|
||||
|
||||
This is the equivalent of POST https://api.openai.com/v1/fine_tuning/jobs/{fine_tuning_job_id}/cancel
|
||||
|
||||
Supported Query Params:
|
||||
- `custom_llm_provider`: Name of the LiteLLM provider
|
||||
- `fine_tuning_job_id`: The ID of the fine-tuning job to cancel.
|
||||
"""
|
||||
from litellm.proxy.proxy_server import (
|
||||
add_litellm_data_to_request,
|
||||
general_settings,
|
||||
get_custom_headers,
|
||||
premium_user,
|
||||
proxy_config,
|
||||
proxy_logging_obj,
|
||||
version,
|
||||
)
|
||||
|
||||
data: dict = {}
|
||||
try:
|
||||
if premium_user is not True:
|
||||
raise ValueError(
|
||||
f"Only premium users can use this endpoint + {CommonProxyErrors.not_premium_user.value}"
|
||||
)
|
||||
# Include original request and headers in the data
|
||||
data = await add_litellm_data_to_request(
|
||||
data=data,
|
||||
request=request,
|
||||
general_settings=general_settings,
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
version=version,
|
||||
proxy_config=proxy_config,
|
||||
)
|
||||
|
||||
request_body = await request.json()
|
||||
|
||||
custom_llm_provider = request_body.get("custom_llm_provider", None)
|
||||
|
||||
# get configs for custom_llm_provider
|
||||
llm_provider_config = get_fine_tuning_provider_config(
|
||||
custom_llm_provider=custom_llm_provider
|
||||
)
|
||||
|
||||
data.update(llm_provider_config)
|
||||
|
||||
response = await litellm.acancel_fine_tuning_job(
|
||||
**data,
|
||||
fine_tuning_job_id=fine_tuning_job_id,
|
||||
)
|
||||
|
||||
### RESPONSE HEADERS ###
|
||||
hidden_params = getattr(response, "_hidden_params", {}) or {}
|
||||
model_id = hidden_params.get("model_id", None) or ""
|
||||
cache_key = hidden_params.get("cache_key", None) or ""
|
||||
api_base = hidden_params.get("api_base", None) or ""
|
||||
|
||||
fastapi_response.headers.update(
|
||||
get_custom_headers(
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
model_id=model_id,
|
||||
cache_key=cache_key,
|
||||
api_base=api_base,
|
||||
version=version,
|
||||
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
|
||||
)
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
await proxy_logging_obj.post_call_failure_hook(
|
||||
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
|
||||
)
|
||||
verbose_proxy_logger.error(
|
||||
"litellm.proxy.proxy_server.list_fine_tuning_jobs(): Exception occurred - {}".format(
|
||||
str(e)
|
||||
)
|
||||
)
|
||||
verbose_proxy_logger.debug(traceback.format_exc())
|
||||
if isinstance(e, HTTPException):
|
||||
raise ProxyException(
|
||||
message=getattr(e, "message", str(e.detail)),
|
||||
type=getattr(e, "type", "None"),
|
||||
param=getattr(e, "param", "None"),
|
||||
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
|
||||
)
|
||||
else:
|
||||
error_msg = f"{str(e)}"
|
||||
raise ProxyException(
|
||||
message=getattr(e, "message", error_msg),
|
||||
type=getattr(e, "type", "None"),
|
||||
param=getattr(e, "param", "None"),
|
||||
code=getattr(e, "status_code", 500),
|
||||
)
|
|
@ -34,6 +34,37 @@ from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
|||
|
||||
router = APIRouter()
|
||||
|
||||
files_config = None
|
||||
|
||||
|
||||
def set_files_config(config):
|
||||
global files_config
|
||||
if config is None:
|
||||
return
|
||||
|
||||
if not isinstance(config, list):
|
||||
raise ValueError("invalid files config, expected a list is not a list")
|
||||
|
||||
for element in config:
|
||||
if isinstance(element, dict):
|
||||
for key, value in element.items():
|
||||
if isinstance(value, str) and value.startswith("os.environ/"):
|
||||
element[key] = litellm.get_secret(value)
|
||||
|
||||
files_config = config
|
||||
|
||||
|
||||
def get_files_provider_config(
|
||||
custom_llm_provider: str,
|
||||
):
|
||||
global files_config
|
||||
if files_config is None:
|
||||
raise ValueError("files_config is not set, set it on your config.yaml file.")
|
||||
for setting in files_config:
|
||||
if setting.get("custom_llm_provider") == custom_llm_provider:
|
||||
return setting
|
||||
return None
|
||||
|
||||
|
||||
@router.post(
|
||||
"/v1/files",
|
||||
|
@ -49,6 +80,7 @@ async def create_file(
|
|||
request: Request,
|
||||
fastapi_response: Response,
|
||||
purpose: str = Form(...),
|
||||
custom_llm_provider: str = Form(default="openai"),
|
||||
file: UploadFile = File(...),
|
||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
):
|
||||
|
@ -100,11 +132,17 @@ async def create_file(
|
|||
|
||||
_create_file_request = CreateFileRequest(file=file_data, **data)
|
||||
|
||||
# for now use custom_llm_provider=="openai" -> this will change as LiteLLM adds more providers for acreate_batch
|
||||
response = await litellm.acreate_file(
|
||||
custom_llm_provider="openai", **_create_file_request
|
||||
# get configs for custom_llm_provider
|
||||
llm_provider_config = get_files_provider_config(
|
||||
custom_llm_provider=custom_llm_provider
|
||||
)
|
||||
|
||||
# add llm_provider_config to data
|
||||
_create_file_request.update(llm_provider_config)
|
||||
|
||||
# for now use custom_llm_provider=="openai" -> this will change as LiteLLM adds more providers for acreate_batch
|
||||
response = await litellm.acreate_file(**_create_file_request)
|
||||
|
||||
### ALERTING ###
|
||||
asyncio.create_task(
|
||||
proxy_logging_obj.update_request_status(
|
||||
|
|
|
@ -25,6 +25,27 @@ model_list:
|
|||
api_key: "os.environ/OPENAI_API_KEY"
|
||||
model_info:
|
||||
mode: audio_speech
|
||||
|
||||
# For /fine_tuning/jobs endpoints
|
||||
finetune_settings:
|
||||
- custom_llm_provider: azure
|
||||
api_base: https://exampleopenaiendpoint-production.up.railway.app
|
||||
api_key: fake-key
|
||||
api_version: "2023-03-15-preview"
|
||||
- custom_llm_provider: openai
|
||||
api_key: os.environ/OPENAI_API_KEY
|
||||
|
||||
# for /files endpoints
|
||||
files_settings:
|
||||
- custom_llm_provider: azure
|
||||
api_base: https://exampleopenaiendpoint-production.up.railway.app
|
||||
api_key: fake-key
|
||||
api_version: "2023-03-15-preview"
|
||||
- custom_llm_provider: openai
|
||||
api_key: os.environ/OPENAI_API_KEY
|
||||
|
||||
|
||||
|
||||
general_settings:
|
||||
master_key: sk-1234
|
||||
|
||||
|
|
|
@ -153,6 +153,8 @@ from litellm.proxy.common_utils.init_callbacks import initialize_callbacks_on_pr
|
|||
from litellm.proxy.common_utils.openai_endpoint_utils import (
|
||||
remove_sensitive_info_from_deployment,
|
||||
)
|
||||
from litellm.proxy.fine_tuning_endpoints.endpoints import router as fine_tuning_router
|
||||
from litellm.proxy.fine_tuning_endpoints.endpoints import set_fine_tuning_config
|
||||
from litellm.proxy.guardrails.init_guardrails import initialize_guardrails
|
||||
from litellm.proxy.health_check import perform_health_check
|
||||
from litellm.proxy.health_endpoints._health_endpoints import router as health_router
|
||||
|
@ -179,6 +181,7 @@ from litellm.proxy.management_endpoints.team_endpoints import router as team_rou
|
|||
from litellm.proxy.openai_files_endpoints.files_endpoints import (
|
||||
router as openai_files_router,
|
||||
)
|
||||
from litellm.proxy.openai_files_endpoints.files_endpoints import set_files_config
|
||||
from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
|
||||
initialize_pass_through_endpoints,
|
||||
)
|
||||
|
@ -1807,6 +1810,14 @@ class ProxyConfig:
|
|||
assistant_settings["litellm_params"][k] = v
|
||||
assistants_config = AssistantsTypedDict(**assistant_settings) # type: ignore
|
||||
|
||||
## /fine_tuning/jobs endpoints config
|
||||
finetuning_config = config.get("finetune_settings", None)
|
||||
set_fine_tuning_config(config=finetuning_config)
|
||||
|
||||
## /files endpoint config
|
||||
files_config = config.get("files_settings", None)
|
||||
set_files_config(config=files_config)
|
||||
|
||||
## ROUTER SETTINGS (e.g. routing_strategy, ...)
|
||||
router_settings = config.get("router_settings", None)
|
||||
if router_settings and isinstance(router_settings, dict):
|
||||
|
@ -9598,6 +9609,7 @@ def cleanup_router_config_variables():
|
|||
|
||||
|
||||
app.include_router(router)
|
||||
app.include_router(fine_tuning_router)
|
||||
app.include_router(health_router)
|
||||
app.include_router(key_management_router)
|
||||
app.include_router(internal_user_router)
|
||||
|
|
12
litellm/tests/azure_fine_tune.jsonl
Normal file
12
litellm/tests/azure_fine_tune.jsonl
Normal file
|
@ -0,0 +1,12 @@
|
|||
{"messages": [{"role": "system", "content": "Clippy is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "What's the capital of France?"}, {"role": "assistant", "content": "Paris, as if everyone doesn't know that already."}]}
|
||||
{"messages": [{"role": "system", "content": "Clippy is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "Who wrote 'Romeo and Juliet'?"}, {"role": "assistant", "content": "Oh, just some guy named William Shakespeare. Ever heard of him?"}]}
|
||||
{"messages": [{"role": "system", "content": "Clippy is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "How far is the Moon from Earth?"}, {"role": "assistant", "content": "Around 384,400 kilometers. Give or take a few, like that really matters."}]}
|
||||
{"messages": [{"role": "system", "content": "Clippy is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "What's the capital of France?"}, {"role": "assistant", "content": "Paris, as if everyone doesn't know that already."}]}
|
||||
{"messages": [{"role": "system", "content": "Clippy is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "Who wrote 'Romeo and Juliet'?"}, {"role": "assistant", "content": "Oh, just some guy named William Shakespeare. Ever heard of him?"}]}
|
||||
{"messages": [{"role": "system", "content": "Clippy is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "How far is the Moon from Earth?"}, {"role": "assistant", "content": "Around 384,400 kilometers. Give or take a few, like that really matters."}]}
|
||||
{"messages": [{"role": "system", "content": "Clippy is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "What's the capital of France?"}, {"role": "assistant", "content": "Paris, as if everyone doesn't know that already."}]}
|
||||
{"messages": [{"role": "system", "content": "Clippy is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "Who wrote 'Romeo and Juliet'?"}, {"role": "assistant", "content": "Oh, just some guy named William Shakespeare. Ever heard of him?"}]}
|
||||
{"messages": [{"role": "system", "content": "Clippy is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "How far is the Moon from Earth?"}, {"role": "assistant", "content": "Around 384,400 kilometers. Give or take a few, like that really matters."}]}
|
||||
{"messages": [{"role": "system", "content": "Clippy is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "What's the capital of France?"}, {"role": "assistant", "content": "Paris, as if everyone doesn't know that already."}]}
|
||||
{"messages": [{"role": "system", "content": "Clippy is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "Who wrote 'Romeo and Juliet'?"}, {"role": "assistant", "content": "Oh, just some guy named William Shakespeare. Ever heard of him?"}]}
|
||||
{"messages": [{"role": "system", "content": "Clippy is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "How far is the Moon from Earth?"}, {"role": "assistant", "content": "Around 384,400 kilometers. Give or take a few, like that really matters."}]}
|
|
@ -12,6 +12,7 @@ from openai import APITimeoutError as Timeout
|
|||
import litellm
|
||||
|
||||
litellm.num_retries = 0
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
import openai
|
||||
|
@ -128,3 +129,49 @@ async def test_create_fine_tune_jobs_async():
|
|||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_azure_create_fine_tune_jobs_async():
|
||||
verbose_logger.setLevel(logging.DEBUG)
|
||||
file_name = "azure_fine_tune.jsonl"
|
||||
_current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
file_path = os.path.join(_current_dir, file_name)
|
||||
|
||||
file_id = "file-5e4b20ecbd724182b9964f3cd2ab7212"
|
||||
|
||||
create_fine_tuning_response = await litellm.acreate_fine_tuning_job(
|
||||
model="gpt-35-turbo-1106",
|
||||
training_file=file_id,
|
||||
custom_llm_provider="azure",
|
||||
api_key=os.getenv("AZURE_SWEDEN_API_KEY"),
|
||||
api_base="https://my-endpoint-sweden-berri992.openai.azure.com/",
|
||||
)
|
||||
|
||||
print("response from litellm.create_fine_tuning_job=", create_fine_tuning_response)
|
||||
|
||||
assert create_fine_tuning_response.id is not None
|
||||
assert create_fine_tuning_response.model == "gpt-35-turbo-1106"
|
||||
|
||||
# list fine tuning jobs
|
||||
print("listing ft jobs")
|
||||
ft_jobs = await litellm.alist_fine_tuning_jobs(
|
||||
limit=2,
|
||||
custom_llm_provider="azure",
|
||||
api_key=os.getenv("AZURE_SWEDEN_API_KEY"),
|
||||
api_base="https://my-endpoint-sweden-berri992.openai.azure.com/",
|
||||
)
|
||||
print("response from litellm.list_fine_tuning_jobs=", ft_jobs)
|
||||
|
||||
# cancel ft job
|
||||
response = await litellm.acancel_fine_tuning_job(
|
||||
fine_tuning_job_id=create_fine_tuning_response.id,
|
||||
custom_llm_provider="azure",
|
||||
api_key=os.getenv("AZURE_SWEDEN_API_KEY"),
|
||||
api_base="https://my-endpoint-sweden-berri992.openai.azure.com/",
|
||||
)
|
||||
|
||||
print("response from litellm.cancel_fine_tuning_job=", response)
|
||||
|
||||
assert response.status == "cancelled"
|
||||
assert response.id == create_fine_tuning_response.id
|
||||
|
|
|
@ -9,7 +9,6 @@ from typing import (
|
|||
Mapping,
|
||||
Optional,
|
||||
Tuple,
|
||||
TypedDict,
|
||||
Union,
|
||||
)
|
||||
|
||||
|
@ -31,7 +30,7 @@ from openai.types.beta.threads.message import Message as OpenAIMessage
|
|||
from openai.types.beta.threads.message_content import MessageContent
|
||||
from openai.types.beta.threads.run import Run
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import Dict, Required, override
|
||||
from typing_extensions import Dict, Required, TypedDict, override
|
||||
|
||||
FileContent = Union[IO[bytes], bytes, PathLike]
|
||||
|
||||
|
@ -457,15 +456,17 @@ class ChatCompletionUsageBlock(TypedDict):
|
|||
total_tokens: int
|
||||
|
||||
|
||||
class Hyperparameters(TypedDict):
|
||||
batch_size: Optional[Union[str, int]] # "Number of examples in each batch."
|
||||
learning_rate_multiplier: Optional[
|
||||
Union[str, float]
|
||||
] # Scaling factor for the learning rate
|
||||
n_epochs: Optional[Union[str, int]] # "The number of epochs to train the model for"
|
||||
class Hyperparameters(BaseModel):
|
||||
batch_size: Optional[Union[str, int]] = None # "Number of examples in each batch."
|
||||
learning_rate_multiplier: Optional[Union[str, float]] = (
|
||||
None # Scaling factor for the learning rate
|
||||
)
|
||||
n_epochs: Optional[Union[str, int]] = (
|
||||
None # "The number of epochs to train the model for"
|
||||
)
|
||||
|
||||
|
||||
class FineTuningJobCreate(TypedDict):
|
||||
class FineTuningJobCreate(BaseModel):
|
||||
"""
|
||||
FineTuningJobCreate - Create a fine-tuning job
|
||||
|
||||
|
@ -489,16 +490,20 @@ class FineTuningJobCreate(TypedDict):
|
|||
|
||||
model: str # "The name of the model to fine-tune."
|
||||
training_file: str # "The ID of an uploaded file that contains training data."
|
||||
hyperparameters: Optional[
|
||||
Hyperparameters
|
||||
] # "The hyperparameters used for the fine-tuning job."
|
||||
suffix: Optional[
|
||||
str
|
||||
] # "A string of up to 18 characters that will be added to your fine-tuned model name."
|
||||
validation_file: Optional[
|
||||
str
|
||||
] # "The ID of an uploaded file that contains validation data."
|
||||
integrations: Optional[
|
||||
List[str]
|
||||
] # "A list of integrations to enable for your fine-tuning job."
|
||||
seed: Optional[int] # "The seed controls the reproducibility of the job."
|
||||
hyperparameters: Optional[Hyperparameters] = (
|
||||
None # "The hyperparameters used for the fine-tuning job."
|
||||
)
|
||||
suffix: Optional[str] = (
|
||||
None # "A string of up to 18 characters that will be added to your fine-tuned model name."
|
||||
)
|
||||
validation_file: Optional[str] = (
|
||||
None # "The ID of an uploaded file that contains validation data."
|
||||
)
|
||||
integrations: Optional[List[str]] = (
|
||||
None # "A list of integrations to enable for your fine-tuning job."
|
||||
)
|
||||
seed: Optional[int] = None # "The seed controls the reproducibility of the job."
|
||||
|
||||
|
||||
class LiteLLMFineTuningJobCreate(FineTuningJobCreate):
|
||||
custom_llm_provider: Literal["openai", "azure"]
|
||||
|
|
|
@ -120,6 +120,24 @@ litellm_settings:
|
|||
langfuse_secret: os.environ/LANGFUSE_PROJECT2_SECRET # Project 2
|
||||
langfuse_host: https://us.cloud.langfuse.com
|
||||
|
||||
# For /fine_tuning/jobs endpoints
|
||||
finetune_settings:
|
||||
- custom_llm_provider: azure
|
||||
api_base: https://exampleopenaiendpoint-production.up.railway.app
|
||||
api_key: fake-key
|
||||
api_version: "2023-03-15-preview"
|
||||
- custom_llm_provider: openai
|
||||
api_key: os.environ/OPENAI_API_KEY
|
||||
|
||||
# for /files endpoints
|
||||
files_settings:
|
||||
- custom_llm_provider: azure
|
||||
api_base: http://0.0.0.0:8090
|
||||
api_key: fake-key
|
||||
api_version: "2023-03-15-preview"
|
||||
- custom_llm_provider: openai
|
||||
api_key: os.environ/OPENAI_API_KEY
|
||||
|
||||
router_settings:
|
||||
routing_strategy: usage-based-routing-v2
|
||||
redis_host: os.environ/REDIS_HOST
|
||||
|
|
2
tests/openai_batch_completions.jsonl
Normal file
2
tests/openai_batch_completions.jsonl
Normal file
|
@ -0,0 +1,2 @@
|
|||
{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-3.5-turbo-0125", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 10}}
|
||||
{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-3.5-turbo-0125", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 10}}
|
53
tests/test_openai_fine_tuning.py
Normal file
53
tests/test_openai_fine_tuning.py
Normal file
|
@ -0,0 +1,53 @@
|
|||
from openai import AsyncOpenAI
|
||||
import os
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openai_fine_tuning():
|
||||
"""
|
||||
[PROD Test] Ensures logprobs are returned correctly
|
||||
"""
|
||||
client = AsyncOpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")
|
||||
|
||||
file_name = "openai_batch_completions.jsonl"
|
||||
_current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
file_path = os.path.join(_current_dir, file_name)
|
||||
|
||||
response = await client.files.create(
|
||||
extra_body={"custom_llm_provider": "azure"},
|
||||
file=open(file_path, "rb"),
|
||||
purpose="fine-tune",
|
||||
)
|
||||
|
||||
print("response from files.create: {}".format(response))
|
||||
|
||||
# create fine tuning job
|
||||
|
||||
ft_job = await client.fine_tuning.jobs.create(
|
||||
model="gpt-35-turbo-1106",
|
||||
training_file=response.id,
|
||||
extra_body={"custom_llm_provider": "azure"},
|
||||
)
|
||||
|
||||
print("response from ft job={}".format(ft_job))
|
||||
|
||||
# response from example endpoint
|
||||
assert ft_job.id == "ftjob-abc123"
|
||||
|
||||
# list all fine tuning jobs
|
||||
list_ft_jobs = await client.fine_tuning.jobs.list(
|
||||
extra_query={"custom_llm_provider": "azure"}
|
||||
)
|
||||
|
||||
print("list of ft jobs={}".format(list_ft_jobs))
|
||||
|
||||
# cancel specific fine tuning job
|
||||
cancel_ft_job = await client.fine_tuning.jobs.cancel(
|
||||
fine_tuning_job_id="123",
|
||||
extra_body={"custom_llm_provider": "azure"},
|
||||
)
|
||||
|
||||
print("response from cancel ft job={}".format(cancel_ft_job))
|
||||
|
||||
assert cancel_ft_job.id is not None
|
Loading…
Add table
Add a link
Reference in a new issue