Merge pull request #4987 from BerriAI/litellm_add_ft_endpoints

[Feat-Proxy] Add List fine-tuning jobs
This commit is contained in:
Ishaan Jaff 2024-07-31 16:49:59 -07:00 committed by GitHub
commit d833c69acb
16 changed files with 1478 additions and 151 deletions
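For a quick sense of what this PR enables, here is a minimal sketch of exercising the new fine-tuning routes from a client, modeled on the e2e test added at the bottom of this diff (proxy URL and key are placeholders):

```python
# Sketch only: assumes a LiteLLM proxy running locally with a
# `finetune_settings` block configured (see proxy_config.yaml below).
import asyncio
from openai import AsyncOpenAI

client = AsyncOpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")

async def main():
    # list fine-tuning jobs, routed to the Azure provider configured on the proxy
    jobs = await client.fine_tuning.jobs.list(
        extra_query={"custom_llm_provider": "azure"}
    )
    print(jobs)

asyncio.run(main())
```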

View file

@@ -14,7 +14,8 @@ from typing import Any, Coroutine, Dict, Literal, Optional, Union
import httpx
import litellm
from litellm import client
from litellm import client, get_secret
from litellm.llms.files_apis.azure import AzureOpenAIFilesAPI
from litellm.llms.openai import FileDeleted, FileObject, OpenAIFilesAPI
from litellm.types.llms.openai import (
Batch,
@@ -28,12 +29,13 @@ from litellm.utils import supports_httpx_timeout
####### ENVIRONMENT VARIABLES ###################
openai_files_instance = OpenAIFilesAPI()
azure_files_instance = AzureOpenAIFilesAPI()
#################################################
async def afile_retrieve(
file_id: str,
custom_llm_provider: Literal["openai"] = "openai",
custom_llm_provider: Literal["openai", "azure"] = "openai",
extra_headers: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, str]] = None,
**kwargs,
@@ -73,7 +75,7 @@ async def afile_retrieve(
def file_retrieve(
file_id: str,
custom_llm_provider: Literal["openai"] = "openai",
custom_llm_provider: Literal["openai", "azure"] = "openai",
extra_headers: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, str]] = None,
**kwargs,
@@ -156,7 +158,7 @@ def file_retrieve(
# Delete file
async def afile_delete(
file_id: str,
custom_llm_provider: Literal["openai"] = "openai",
custom_llm_provider: Literal["openai", "azure"] = "openai",
extra_headers: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, str]] = None,
**kwargs,
@@ -196,7 +198,7 @@ async def afile_delete(
def file_delete(
file_id: str,
custom_llm_provider: Literal["openai"] = "openai",
custom_llm_provider: Literal["openai", "azure"] = "openai",
extra_headers: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, str]] = None,
**kwargs,
@@ -208,6 +210,22 @@ def file_delete(
"""
try:
optional_params = GenericLiteLLMParams(**kwargs)
### TIMEOUT LOGIC ###
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
# set timeout for 10 minutes by default
if (
timeout is not None
and isinstance(timeout, httpx.Timeout)
and supports_httpx_timeout(custom_llm_provider) == False
):
read_timeout = timeout.read or 600
timeout = read_timeout # default 10 min timeout
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
timeout = float(timeout) # type: ignore
elif timeout is None:
timeout = 600.0
_is_async = kwargs.pop("is_async", False) is True
if custom_llm_provider == "openai":
# for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
api_base = (
@@ -229,26 +247,6 @@ def file_delete(
or litellm.openai_key
or os.getenv("OPENAI_API_KEY")
)
### TIMEOUT LOGIC ###
timeout = (
optional_params.timeout or kwargs.get("request_timeout", 600) or 600
)
# set timeout for 10 minutes by default
if (
timeout is not None
and isinstance(timeout, httpx.Timeout)
and supports_httpx_timeout(custom_llm_provider) == False
):
read_timeout = timeout.read or 600
timeout = read_timeout # default 10 min timeout
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
timeout = float(timeout) # type: ignore
elif timeout is None:
timeout = 600.0
_is_async = kwargs.pop("is_async", False) is True
response = openai_files_instance.delete_file(
file_id=file_id,
_is_async=_is_async,
@@ -258,6 +256,38 @@ def file_delete(
max_retries=optional_params.max_retries,
organization=organization,
)
elif custom_llm_provider == "azure":
api_base = optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE") # type: ignore
api_version = (
optional_params.api_version
or litellm.api_version
or get_secret("AZURE_API_VERSION")
) # type: ignore
api_key = (
optional_params.api_key
or litellm.api_key
or litellm.azure_key
or get_secret("AZURE_OPENAI_API_KEY")
or get_secret("AZURE_API_KEY")
) # type: ignore
extra_body = optional_params.get("extra_body", {})
azure_ad_token: Optional[str] = None
if extra_body is not None:
azure_ad_token = extra_body.pop("azure_ad_token", None)
else:
azure_ad_token = get_secret("AZURE_AD_TOKEN") # type: ignore
response = azure_files_instance.delete_file(
_is_async=_is_async,
api_base=api_base,
api_key=api_key,
api_version=api_version,
timeout=timeout,
max_retries=optional_params.max_retries,
file_id=file_id,
)
else:
raise litellm.exceptions.BadRequestError(
message="LiteLLM doesn't support {} for 'create_batch'. Only 'openai' is supported.".format(
@@ -278,7 +308,7 @@ def file_delete(
# List files
async def afile_list(
custom_llm_provider: Literal["openai"] = "openai",
custom_llm_provider: Literal["openai", "azure"] = "openai",
purpose: Optional[str] = None,
extra_headers: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, str]] = None,
@@ -318,7 +348,7 @@ async def afile_list(
def file_list(
custom_llm_provider: Literal["openai"] = "openai",
custom_llm_provider: Literal["openai", "azure"] = "openai",
purpose: Optional[str] = None,
extra_headers: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, str]] = None,
@@ -402,7 +432,7 @@ def file_list(
async def acreate_file(
file: FileTypes,
purpose: Literal["assistants", "batch", "fine-tune"],
custom_llm_provider: Literal["openai"] = "openai",
custom_llm_provider: Literal["openai", "azure"] = "openai",
extra_headers: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, str]] = None,
**kwargs,
@@ -444,7 +474,7 @@ async def acreate_file(
def create_file(
file: FileTypes,
purpose: Literal["assistants", "batch", "fine-tune"],
custom_llm_provider: Literal["openai"] = "openai",
custom_llm_provider: Literal["openai", "azure"] = "openai",
extra_headers: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, str]] = None,
**kwargs,
@@ -455,7 +485,31 @@ def create_file(
LiteLLM Equivalent of POST: POST https://api.openai.com/v1/files
"""
try:
_is_async = kwargs.pop("acreate_file", False) is True
optional_params = GenericLiteLLMParams(**kwargs)
### TIMEOUT LOGIC ###
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
# set timeout for 10 minutes by default
if (
timeout is not None
and isinstance(timeout, httpx.Timeout)
and supports_httpx_timeout(custom_llm_provider) == False
):
read_timeout = timeout.read or 600
timeout = read_timeout # default 10 min timeout
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
timeout = float(timeout) # type: ignore
elif timeout is None:
timeout = 600.0
_create_file_request = CreateFileRequest(
file=file,
purpose=purpose,
extra_headers=extra_headers,
extra_body=extra_body,
)
if custom_llm_provider == "openai":
# for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
api_base = (
@@ -477,32 +531,6 @@ def create_file(
or litellm.openai_key
or os.getenv("OPENAI_API_KEY")
)
### TIMEOUT LOGIC ###
timeout = (
optional_params.timeout or kwargs.get("request_timeout", 600) or 600
)
# set timeout for 10 minutes by default
if (
timeout is not None
and isinstance(timeout, httpx.Timeout)
and supports_httpx_timeout(custom_llm_provider) == False
):
read_timeout = timeout.read or 600
timeout = read_timeout # default 10 min timeout
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
timeout = float(timeout) # type: ignore
elif timeout is None:
timeout = 600.0
_create_file_request = CreateFileRequest(
file=file,
purpose=purpose,
extra_headers=extra_headers,
extra_body=extra_body,
)
_is_async = kwargs.pop("acreate_file", False) is True
response = openai_files_instance.create_file(
_is_async=_is_async,
@@ -513,6 +541,38 @@ def create_file(
organization=organization,
create_file_data=_create_file_request,
)
elif custom_llm_provider == "azure":
api_base = optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE") # type: ignore
api_version = (
optional_params.api_version
or litellm.api_version
or get_secret("AZURE_API_VERSION")
) # type: ignore
api_key = (
optional_params.api_key
or litellm.api_key
or litellm.azure_key
or get_secret("AZURE_OPENAI_API_KEY")
or get_secret("AZURE_API_KEY")
) # type: ignore
extra_body = optional_params.get("extra_body", {})
azure_ad_token: Optional[str] = None
if extra_body is not None:
azure_ad_token = extra_body.pop("azure_ad_token", None)
else:
azure_ad_token = get_secret("AZURE_AD_TOKEN") # type: ignore
response = azure_files_instance.create_file(
_is_async=_is_async,
api_base=api_base,
api_key=api_key,
api_version=api_version,
timeout=timeout,
max_retries=optional_params.max_retries,
create_file_data=_create_file_request,
)
else:
raise litellm.exceptions.BadRequestError(
message="LiteLLM doesn't support {} for 'create_batch'. Only 'openai' is supported.".format(
@@ -533,7 +593,7 @@ def create_file(
async def afile_content(
file_id: str,
custom_llm_provider: Literal["openai"] = "openai",
custom_llm_provider: Literal["openai", "azure"] = "openai",
extra_headers: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, str]] = None,
**kwargs,
@@ -573,7 +633,7 @@ async def afile_content(
def file_content(
file_id: str,
custom_llm_provider: Literal["openai"] = "openai",
custom_llm_provider: Literal["openai", "azure"] = "openai",
extra_headers: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, str]] = None,
**kwargs,
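With the `azure` branches added above, the file operations can be pointed at an Azure OpenAI deployment straight from the SDK. A hedged sketch (endpoint, key, and file name are placeholders):

```python
import litellm

# Sketch: upload a training file to Azure OpenAI. api_base/api_key/api_version
# come from kwargs (GenericLiteLLMParams) or fall back to AZURE_API_BASE /
# AZURE_API_KEY / AZURE_API_VERSION, per the resolution order in the code above.
file_obj = litellm.create_file(
    file=open("training_data.jsonl", "rb"),            # placeholder file
    purpose="fine-tune",
    custom_llm_provider="azure",
    api_base="https://my-endpoint.openai.azure.com/",  # placeholder
    api_key="my-azure-api-key",                        # placeholder
)
print(file_obj.id)
```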

View file

@@ -17,7 +17,10 @@ from typing import Any, Coroutine, Dict, Literal, Optional, Union
import httpx
import litellm
from litellm.llms.openai_fine_tuning.openai import (
from litellm import get_secret
from litellm._logging import verbose_logger
from litellm.llms.fine_tuning_apis.azure import AzureOpenAIFineTuningAPI
from litellm.llms.fine_tuning_apis.openai import (
FineTuningJob,
FineTuningJobCreate,
OpenAIFineTuningAPI,
@@ -27,7 +30,8 @@ from litellm.types.router import *
from litellm.utils import supports_httpx_timeout
####### ENVIRONMENT VARIABLES ###################
openai_fine_tuning_instance = OpenAIFineTuningAPI()
openai_fine_tuning_apis_instance = OpenAIFineTuningAPI()
azure_fine_tuning_apis_instance = AzureOpenAIFineTuningAPI()
#################################################
@@ -39,7 +43,7 @@ async def acreate_fine_tuning_job(
validation_file: Optional[str] = None,
integrations: Optional[List[str]] = None,
seed: Optional[int] = None,
custom_llm_provider: Literal["openai"] = "openai",
custom_llm_provider: Literal["openai", "azure"] = "openai",
extra_headers: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, str]] = None,
**kwargs,
@@ -48,6 +52,9 @@ async def acreate_fine_tuning_job(
Async: Creates and executes a batch from an uploaded file of request
"""
verbose_logger.debug(
"inside acreate_fine_tuning_job model=%s and kwargs=%s", model, kwargs
)
try:
loop = asyncio.get_event_loop()
kwargs["acreate_fine_tuning_job"] = True
@@ -89,7 +96,7 @@ def create_fine_tuning_job(
validation_file: Optional[str] = None,
integrations: Optional[List[str]] = None,
seed: Optional[int] = None,
custom_llm_provider: Literal["openai"] = "openai",
custom_llm_provider: Literal["openai", "azure"] = "openai",
extra_headers: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, str]] = None,
**kwargs,
@@ -101,7 +108,25 @@ def create_fine_tuning_job(
"""
try:
_is_async = kwargs.pop("acreate_fine_tuning_job", False) is True
optional_params = GenericLiteLLMParams(**kwargs)
### TIMEOUT LOGIC ###
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
# set timeout for 10 minutes by default
if (
timeout is not None
and isinstance(timeout, httpx.Timeout)
and supports_httpx_timeout(custom_llm_provider) == False
):
read_timeout = timeout.read or 600
timeout = read_timeout # default 10 min timeout
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
timeout = float(timeout) # type: ignore
elif timeout is None:
timeout = 600.0
# OpenAI
if custom_llm_provider == "openai":
# for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
@@ -124,25 +149,6 @@ def create_fine_tuning_job(
or litellm.openai_key
or os.getenv("OPENAI_API_KEY")
)
### TIMEOUT LOGIC ###
timeout = (
optional_params.timeout or kwargs.get("request_timeout", 600) or 600
)
# set timeout for 10 minutes by default
if (
timeout is not None
and isinstance(timeout, httpx.Timeout)
and supports_httpx_timeout(custom_llm_provider) == False
):
read_timeout = timeout.read or 600
timeout = read_timeout # default 10 min timeout
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
timeout = float(timeout) # type: ignore
elif timeout is None:
timeout = 600.0
_is_async = kwargs.pop("acreate_fine_tuning_job", False) is True
create_fine_tuning_job_data = FineTuningJobCreate(
model=model,
@@ -154,11 +160,63 @@ def create_fine_tuning_job(
seed=seed,
)
response = openai_fine_tuning_instance.create_fine_tuning_job(
create_fine_tuning_job_data_dict = create_fine_tuning_job_data.model_dump(
exclude_none=True
)
response = openai_fine_tuning_apis_instance.create_fine_tuning_job(
api_base=api_base,
api_key=api_key,
organization=organization,
create_fine_tuning_job_data=create_fine_tuning_job_data,
create_fine_tuning_job_data=create_fine_tuning_job_data_dict,
timeout=timeout,
max_retries=optional_params.max_retries,
_is_async=_is_async,
)
# Azure OpenAI
elif custom_llm_provider == "azure":
api_base = optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE") # type: ignore
api_version = (
optional_params.api_version
or litellm.api_version
or get_secret("AZURE_API_VERSION")
) # type: ignore
api_key = (
optional_params.api_key
or litellm.api_key
or litellm.azure_key
or get_secret("AZURE_OPENAI_API_KEY")
or get_secret("AZURE_API_KEY")
) # type: ignore
extra_body = optional_params.get("extra_body", {})
azure_ad_token: Optional[str] = None
if extra_body is not None:
azure_ad_token = extra_body.pop("azure_ad_token", None)
else:
azure_ad_token = get_secret("AZURE_AD_TOKEN") # type: ignore
create_fine_tuning_job_data = FineTuningJobCreate(
model=model,
training_file=training_file,
hyperparameters=hyperparameters,
suffix=suffix,
validation_file=validation_file,
integrations=integrations,
seed=seed,
)
create_fine_tuning_job_data_dict = create_fine_tuning_job_data.model_dump(
exclude_none=True
)
response = azure_fine_tuning_apis_instance.create_fine_tuning_job(
api_base=api_base,
api_key=api_key,
api_version=api_version,
create_fine_tuning_job_data=create_fine_tuning_job_data_dict,
timeout=timeout,
max_retries=optional_params.max_retries,
_is_async=_is_async,
@@ -233,6 +291,25 @@ def cancel_fine_tuning_job(
"""
try:
optional_params = GenericLiteLLMParams(**kwargs)
### TIMEOUT LOGIC ###
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
# set timeout for 10 minutes by default
if (
timeout is not None
and isinstance(timeout, httpx.Timeout)
and supports_httpx_timeout(custom_llm_provider) == False
):
read_timeout = timeout.read or 600
timeout = read_timeout # default 10 min timeout
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
timeout = float(timeout) # type: ignore
elif timeout is None:
timeout = 600.0
_is_async = kwargs.pop("acancel_fine_tuning_job", False) is True
# OpenAI
if custom_llm_provider == "openai":
# for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
@@ -255,27 +332,8 @@ def cancel_fine_tuning_job(
or litellm.openai_key
or os.getenv("OPENAI_API_KEY")
)
### TIMEOUT LOGIC ###
timeout = (
optional_params.timeout or kwargs.get("request_timeout", 600) or 600
)
# set timeout for 10 minutes by default
if (
timeout is not None
and isinstance(timeout, httpx.Timeout)
and supports_httpx_timeout(custom_llm_provider) == False
):
read_timeout = timeout.read or 600
timeout = read_timeout # default 10 min timeout
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
timeout = float(timeout) # type: ignore
elif timeout is None:
timeout = 600.0
_is_async = kwargs.pop("acancel_fine_tuning_job", False) is True
response = openai_fine_tuning_instance.cancel_fine_tuning_job(
response = openai_fine_tuning_apis_instance.cancel_fine_tuning_job(
api_base=api_base,
api_key=api_key,
organization=organization,
@@ -284,6 +342,40 @@ def cancel_fine_tuning_job(
max_retries=optional_params.max_retries,
_is_async=_is_async,
)
# Azure OpenAI
elif custom_llm_provider == "azure":
api_base = optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE") # type: ignore
api_version = (
optional_params.api_version
or litellm.api_version
or get_secret("AZURE_API_VERSION")
) # type: ignore
api_key = (
optional_params.api_key
or litellm.api_key
or litellm.azure_key
or get_secret("AZURE_OPENAI_API_KEY")
or get_secret("AZURE_API_KEY")
) # type: ignore
extra_body = optional_params.get("extra_body", {})
azure_ad_token: Optional[str] = None
if extra_body is not None:
azure_ad_token = extra_body.pop("azure_ad_token", None)
else:
azure_ad_token = get_secret("AZURE_AD_TOKEN") # type: ignore
response = azure_fine_tuning_apis_instance.cancel_fine_tuning_job(
api_base=api_base,
api_key=api_key,
api_version=api_version,
fine_tuning_job_id=fine_tuning_job_id,
timeout=timeout,
max_retries=optional_params.max_retries,
_is_async=_is_async,
)
else:
raise litellm.exceptions.BadRequestError(
message="LiteLLM doesn't support {} for 'create_batch'. Only 'openai' is supported.".format(
@@ -359,6 +451,25 @@ def list_fine_tuning_jobs(
"""
try:
optional_params = GenericLiteLLMParams(**kwargs)
### TIMEOUT LOGIC ###
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
# set timeout for 10 minutes by default
if (
timeout is not None
and isinstance(timeout, httpx.Timeout)
and supports_httpx_timeout(custom_llm_provider) == False
):
read_timeout = timeout.read or 600
timeout = read_timeout # default 10 min timeout
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
timeout = float(timeout) # type: ignore
elif timeout is None:
timeout = 600.0
_is_async = kwargs.pop("alist_fine_tuning_jobs", False) is True
# OpenAI
if custom_llm_provider == "openai":
# for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
@@ -381,27 +492,8 @@ def list_fine_tuning_jobs(
or litellm.openai_key
or os.getenv("OPENAI_API_KEY")
)
### TIMEOUT LOGIC ###
timeout = (
optional_params.timeout or kwargs.get("request_timeout", 600) or 600
)
# set timeout for 10 minutes by default
if (
timeout is not None
and isinstance(timeout, httpx.Timeout)
and supports_httpx_timeout(custom_llm_provider) == False
):
read_timeout = timeout.read or 600
timeout = read_timeout # default 10 min timeout
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
timeout = float(timeout) # type: ignore
elif timeout is None:
timeout = 600.0
_is_async = kwargs.pop("alist_fine_tuning_jobs", False) is True
response = openai_fine_tuning_instance.list_fine_tuning_jobs(
response = openai_fine_tuning_apis_instance.list_fine_tuning_jobs(
api_base=api_base,
api_key=api_key,
organization=organization,
@@ -411,6 +503,41 @@ def list_fine_tuning_jobs(
max_retries=optional_params.max_retries,
_is_async=_is_async,
)
# Azure OpenAI
elif custom_llm_provider == "azure":
api_base = optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE") # type: ignore
api_version = (
optional_params.api_version
or litellm.api_version
or get_secret("AZURE_API_VERSION")
) # type: ignore
api_key = (
optional_params.api_key
or litellm.api_key
or litellm.azure_key
or get_secret("AZURE_OPENAI_API_KEY")
or get_secret("AZURE_API_KEY")
) # type: ignore
extra_body = optional_params.get("extra_body", {})
azure_ad_token: Optional[str] = None
if extra_body is not None:
azure_ad_token = extra_body.pop("azure_ad_token", None)
else:
azure_ad_token = get_secret("AZURE_AD_TOKEN") # type: ignore
response = azure_fine_tuning_apis_instance.list_fine_tuning_jobs(
api_base=api_base,
api_key=api_key,
api_version=api_version,
after=after,
limit=limit,
timeout=timeout,
max_retries=optional_params.max_retries,
_is_async=_is_async,
)
else:
raise litellm.exceptions.BadRequestError(
message="LiteLLM doesn't support {} for 'create_batch'. Only 'openai' is supported.".format(

View file

@@ -0,0 +1,315 @@
from typing import Any, Coroutine, Dict, List, Optional, Union
import httpx
from openai import AsyncAzureOpenAI, AzureOpenAI
from openai.types.file_deleted import FileDeleted
import litellm
from litellm._logging import verbose_logger
from litellm.llms.base import BaseLLM
from litellm.types.llms.openai import *
def get_azure_openai_client(
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
api_version: Optional[str] = None,
organization: Optional[str] = None,
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
_is_async: bool = False,
) -> Optional[Union[AzureOpenAI, AsyncAzureOpenAI]]:
received_args = locals()
openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None
if client is None:
data = {}
for k, v in received_args.items():
if k == "self" or k == "client" or k == "_is_async":
pass
elif k == "api_base" and v is not None:
data["azure_endpoint"] = v
elif v is not None:
data[k] = v
if "api_version" not in data:
data["api_version"] = litellm.AZURE_DEFAULT_API_VERSION
if _is_async is True:
openai_client = AsyncAzureOpenAI(**data)
else:
openai_client = AzureOpenAI(**data) # type: ignore
else:
openai_client = client
return openai_client
class AzureOpenAIFilesAPI(BaseLLM):
"""
AzureOpenAI methods to support the files API
- create_file()
- retrieve_file()
- list_files()
- delete_file()
- file_content()
- update_file()
"""
def __init__(self) -> None:
super().__init__()
async def acreate_file(
self,
create_file_data: CreateFileRequest,
openai_client: AsyncAzureOpenAI,
) -> FileObject:
verbose_logger.debug("create_file_data=%s", create_file_data)
response = await openai_client.files.create(**create_file_data)
verbose_logger.debug("create_file_response=%s", response)
return response
def create_file(
self,
_is_async: bool,
create_file_data: CreateFileRequest,
api_base: str,
api_key: Optional[str],
api_version: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
) -> Union[FileObject, Coroutine[Any, Any, FileObject]]:
openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
get_azure_openai_client(
api_key=api_key,
api_base=api_base,
api_version=api_version,
timeout=timeout,
max_retries=max_retries,
client=client,
_is_async=_is_async,
)
)
if openai_client is None:
raise ValueError(
"AzureOpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
)
if _is_async is True:
if not isinstance(openai_client, AsyncAzureOpenAI):
raise ValueError(
"AzureOpenAI client is not an instance of AsyncAzureOpenAI. Make sure you passed an AsyncAzureOpenAI client."
)
return self.acreate_file( # type: ignore
create_file_data=create_file_data, openai_client=openai_client
)
response = openai_client.files.create(**create_file_data)
return response
async def afile_content(
self,
file_content_request: FileContentRequest,
openai_client: AsyncAzureOpenAI,
) -> HttpxBinaryResponseContent:
response = await openai_client.files.content(**file_content_request)
return response
def file_content(
self,
_is_async: bool,
file_content_request: FileContentRequest,
api_base: str,
api_key: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
api_version: Optional[str] = None,
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
) -> Union[
HttpxBinaryResponseContent, Coroutine[Any, Any, HttpxBinaryResponseContent]
]:
openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
get_azure_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
api_version=api_version,
max_retries=max_retries,
organization=organization,
client=client,
_is_async=_is_async,
)
)
if openai_client is None:
raise ValueError(
"AzureOpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
)
if _is_async is True:
if not isinstance(openai_client, AsyncAzureOpenAI):
raise ValueError(
"AzureOpenAI client is not an instance of AsyncAzureOpenAI. Make sure you passed an AsyncAzureOpenAI client."
)
return self.afile_content( # type: ignore
file_content_request=file_content_request,
openai_client=openai_client,
)
response = openai_client.files.content(**file_content_request)
return response
async def aretrieve_file(
self,
file_id: str,
openai_client: AsyncAzureOpenAI,
) -> FileObject:
response = await openai_client.files.retrieve(file_id=file_id)
return response
def retrieve_file(
self,
_is_async: bool,
file_id: str,
api_base: str,
api_key: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
api_version: Optional[str] = None,
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
):
openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
get_azure_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
api_version=api_version,
client=client,
_is_async=_is_async,
)
)
if openai_client is None:
raise ValueError(
"AzureOpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
)
if _is_async is True:
if not isinstance(openai_client, AsyncAzureOpenAI):
raise ValueError(
"AzureOpenAI client is not an instance of AsyncAzureOpenAI. Make sure you passed an AsyncAzureOpenAI client."
)
return self.aretrieve_file( # type: ignore
file_id=file_id,
openai_client=openai_client,
)
response = openai_client.files.retrieve(file_id=file_id)
return response
async def adelete_file(
self,
file_id: str,
openai_client: AsyncAzureOpenAI,
) -> FileDeleted:
response = await openai_client.files.delete(file_id=file_id)
return response
def delete_file(
self,
_is_async: bool,
file_id: str,
api_base: str,
api_key: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str] = None,
api_version: Optional[str] = None,
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
):
openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
get_azure_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
api_version=api_version,
client=client,
_is_async=_is_async,
)
)
if openai_client is None:
raise ValueError(
"AzureOpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
)
if _is_async is True:
if not isinstance(openai_client, AsyncAzureOpenAI):
raise ValueError(
"AzureOpenAI client is not an instance of AsyncAzureOpenAI. Make sure you passed an AsyncAzureOpenAI client."
)
return self.adelete_file( # type: ignore
file_id=file_id,
openai_client=openai_client,
)
response = openai_client.files.delete(file_id=file_id)
return response
async def alist_files(
self,
openai_client: AsyncAzureOpenAI,
purpose: Optional[str] = None,
):
if isinstance(purpose, str):
response = await openai_client.files.list(purpose=purpose)
else:
response = await openai_client.files.list()
return response
def list_files(
self,
_is_async: bool,
api_base: str,
api_key: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
purpose: Optional[str] = None,
api_version: Optional[str] = None,
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
):
openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
get_azure_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
api_version=api_version,
client=client,
_is_async=_is_async,
)
)
if openai_client is None:
raise ValueError(
"AzureOpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
)
if _is_async is True:
if not isinstance(openai_client, AsyncAzureOpenAI):
raise ValueError(
"AzureOpenAI client is not an instance of AsyncAzureOpenAI. Make sure you passed an AsyncAzureOpenAI client."
)
return self.alist_files( # type: ignore
purpose=purpose,
openai_client=openai_client,
)
if isinstance(purpose, str):
response = openai_client.files.list(purpose=purpose)
else:
response = openai_client.files.list()
return response
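A usage sketch of the new `AzureOpenAIFilesAPI` (credentials and IDs are placeholders). With `_is_async=False` the method returns the response directly; with `_is_async=True` it returns a coroutine to await:

```python
from litellm.llms.files_apis.azure import AzureOpenAIFilesAPI

files_api = AzureOpenAIFilesAPI()
# Sketch only -- endpoint, key, and file id are placeholders.
deleted = files_api.delete_file(
    _is_async=False,
    file_id="file-abc123",
    api_base="https://my-endpoint.openai.azure.com",
    api_key="my-azure-api-key",
    api_version="2023-03-15-preview",
    timeout=600.0,
    max_retries=2,
)
print(deleted)
```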

View file

@@ -0,0 +1,181 @@
from typing import Any, Coroutine, Optional, Union
import httpx
from openai import AsyncAzureOpenAI, AzureOpenAI
from openai.pagination import AsyncCursorPage
from openai.types.fine_tuning import FineTuningJob
from litellm._logging import verbose_logger
from litellm.llms.base import BaseLLM
from litellm.llms.files_apis.azure import get_azure_openai_client
from litellm.types.llms.openai import FineTuningJobCreate
class AzureOpenAIFineTuningAPI(BaseLLM):
"""
AzureOpenAI methods to support fine-tuning jobs
"""
def __init__(self) -> None:
super().__init__()
async def acreate_fine_tuning_job(
self,
create_fine_tuning_job_data: dict,
openai_client: AsyncAzureOpenAI,
) -> FineTuningJob:
response = await openai_client.fine_tuning.jobs.create(
**create_fine_tuning_job_data # type: ignore
)
return response
def create_fine_tuning_job(
self,
_is_async: bool,
create_fine_tuning_job_data: dict,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str] = None,
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
api_version: Optional[str] = None,
) -> Union[FineTuningJob, Union[Coroutine[Any, Any, FineTuningJob]]]:
openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
get_azure_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
api_version=api_version,
client=client,
_is_async=_is_async,
)
)
if openai_client is None:
raise ValueError(
"AzureOpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
)
if _is_async is True:
if not isinstance(openai_client, AsyncAzureOpenAI):
raise ValueError(
"AzureOpenAI client is not an instance of AsyncAzureOpenAI. Make sure you passed an AsyncAzureOpenAI client."
)
return self.acreate_fine_tuning_job( # type: ignore
create_fine_tuning_job_data=create_fine_tuning_job_data,
openai_client=openai_client,
)
verbose_logger.debug(
"creating fine tuning job, args= %s", create_fine_tuning_job_data
)
response = openai_client.fine_tuning.jobs.create(**create_fine_tuning_job_data) # type: ignore
return response
async def acancel_fine_tuning_job(
self,
fine_tuning_job_id: str,
openai_client: AsyncAzureOpenAI,
) -> FineTuningJob:
response = await openai_client.fine_tuning.jobs.cancel(
fine_tuning_job_id=fine_tuning_job_id
)
return response
def cancel_fine_tuning_job(
self,
_is_async: bool,
fine_tuning_job_id: str,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str] = None,
api_version: Optional[str] = None,
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
):
openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
get_azure_openai_client(
api_key=api_key,
api_base=api_base,
api_version=api_version,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
_is_async=_is_async,
)
)
if openai_client is None:
raise ValueError(
"AzureOpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
)
if _is_async is True:
if not isinstance(openai_client, AsyncAzureOpenAI):
raise ValueError(
"AzureOpenAI client is not an instance of AsyncAzureOpenAI. Make sure you passed an AsyncAzureOpenAI client."
)
return self.acancel_fine_tuning_job( # type: ignore
fine_tuning_job_id=fine_tuning_job_id,
openai_client=openai_client,
)
verbose_logger.debug("canceling fine tuning job, args= %s", fine_tuning_job_id)
response = openai_client.fine_tuning.jobs.cancel(
fine_tuning_job_id=fine_tuning_job_id
)
return response
async def alist_fine_tuning_jobs(
self,
openai_client: AsyncAzureOpenAI,
after: Optional[str] = None,
limit: Optional[int] = None,
):
response = await openai_client.fine_tuning.jobs.list(after=after, limit=limit) # type: ignore
return response
def list_fine_tuning_jobs(
self,
_is_async: bool,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str] = None,
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
api_version: Optional[str] = None,
after: Optional[str] = None,
limit: Optional[int] = None,
):
openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
get_azure_openai_client(
api_key=api_key,
api_base=api_base,
api_version=api_version,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
_is_async=_is_async,
)
)
if openai_client is None:
raise ValueError(
"AzureOpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
)
if _is_async is True:
if not isinstance(openai_client, AsyncAzureOpenAI):
raise ValueError(
"AzureOpenAI client is not an instance of AsyncAzureOpenAI. Make sure you passed an AsyncAzureOpenAI client."
)
return self.alist_fine_tuning_jobs( # type: ignore
after=after,
limit=limit,
openai_client=openai_client,
)
verbose_logger.debug("list fine tuning job, after= %s, limit= %s", after, limit)
response = openai_client.fine_tuning.jobs.list(after=after, limit=limit) # type: ignore
return response
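The fine-tuning class follows the same sync/async dispatch pattern; a sketch of the async path (endpoint and key are placeholders):

```python
import asyncio

from litellm.llms.fine_tuning_apis.azure import AzureOpenAIFineTuningAPI

ft_api = AzureOpenAIFineTuningAPI()
# With _is_async=True the sync entrypoint returns a coroutine built on
# AsyncAzureOpenAI, so it must be awaited (or driven via asyncio.run).
jobs = asyncio.run(
    ft_api.list_fine_tuning_jobs(
        _is_async=True,
        api_key="my-azure-api-key",                       # placeholder
        api_base="https://my-endpoint.openai.azure.com",  # placeholder
        api_version="2023-03-15-preview",
        timeout=600.0,
        max_retries=2,
        limit=2,
    )
)
print(jobs)
```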

View file

@@ -50,18 +50,18 @@ class OpenAIFineTuningAPI(BaseLLM):
async def acreate_fine_tuning_job(
self,
create_fine_tuning_job_data: FineTuningJobCreate,
create_fine_tuning_job_data: dict,
openai_client: AsyncOpenAI,
) -> FineTuningJob:
response = await openai_client.fine_tuning.jobs.create(
**create_fine_tuning_job_data # type: ignore
**create_fine_tuning_job_data
)
return response
def create_fine_tuning_job(
self,
_is_async: bool,
create_fine_tuning_job_data: FineTuningJobCreate,
create_fine_tuning_job_data: dict,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
@@ -95,7 +95,7 @@ class OpenAIFineTuningAPI(BaseLLM):
verbose_logger.debug(
"creating fine tuning job, args= %s", create_fine_tuning_job_data
)
response = openai_client.fine_tuning.jobs.create(**create_fine_tuning_job_data) # type: ignore
response = openai_client.fine_tuning.jobs.create(**create_fine_tuning_job_data)
return response
async def acancel_fine_tuning_job(

View file

@@ -208,6 +208,11 @@ class LiteLLMRoutes(enum.Enum):
"/files/{file_id}",
"/v1/files/{file_id}/content",
"/files/{file_id}/content",
# fine_tuning
"/fine_tuning/jobs",
"v1/fine_tuning/jobs",
"/fine_tuning/jobs/{fine_tuning_job_id}/cancel"
"/v1/fine_tuning/jobs/{fine_tuning_job_id}/cancel"
# assistants-related routes
"/assistants",
"/v1/assistants",

View file

@@ -0,0 +1,431 @@
#########################################################################
# /v1/fine_tuning Endpoints
# Equivalent of https://platform.openai.com/docs/api-reference/fine-tuning
##########################################################################
import asyncio
import traceback
from datetime import datetime, timedelta, timezone
from typing import List, Optional
import fastapi
import httpx
from fastapi import (
APIRouter,
Depends,
File,
Form,
Header,
HTTPException,
Request,
Response,
UploadFile,
status,
)
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.batches.main import FileObject
from litellm.proxy._types import *
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
router = APIRouter()
from litellm.types.llms.openai import LiteLLMFineTuningJobCreate
fine_tuning_config = None
def set_fine_tuning_config(config):
if config is None:
return
global fine_tuning_config
if not isinstance(config, list):
raise ValueError("invalid fine_tuning config, expected a list is not a list")
for element in config:
if isinstance(element, dict):
for key, value in element.items():
if isinstance(value, str) and value.startswith("os.environ/"):
element[key] = litellm.get_secret(value)
fine_tuning_config = config
# Function to search for specific custom_llm_provider and return its configuration
def get_fine_tuning_provider_config(
custom_llm_provider: str,
):
global fine_tuning_config
if fine_tuning_config is None:
raise ValueError(
"fine_tuning_config is not set, set it on your config.yaml file."
)
for setting in fine_tuning_config:
if setting.get("custom_llm_provider") == custom_llm_provider:
return setting
return None
@router.post(
"/v1/fine_tuning/jobs",
dependencies=[Depends(user_api_key_auth)],
tags=["fine-tuning"],
summary="✨ (Enterprise) Create Fine-Tuning Job",
)
@router.post(
"/fine_tuning/jobs",
dependencies=[Depends(user_api_key_auth)],
tags=["fine-tuning"],
summary="✨ (Enterprise) Create Fine-Tuning Job",
)
async def create_fine_tuning_job(
request: Request,
fastapi_response: Response,
fine_tuning_request: LiteLLMFineTuningJobCreate,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
Creates a fine-tuning job which begins the process of creating a new model from a given dataset.
This is the equivalent of POST https://api.openai.com/v1/fine_tuning/jobs
Supports Identical Params as: https://platform.openai.com/docs/api-reference/fine-tuning/create
Example Curl:
```
curl http://localhost:4000/v1/fine_tuning/jobs \
-H "Content-Type: application/json" \
-H "Authorization: Bearer sk-1234" \
-d '{
"model": "gpt-3.5-turbo",
"training_file": "file-abc123",
"hyperparameters": {
"n_epochs": 4
}
}'
```
"""
from litellm.proxy.proxy_server import (
add_litellm_data_to_request,
general_settings,
get_custom_headers,
premium_user,
proxy_config,
proxy_logging_obj,
version,
)
data: dict = {}
try:
if premium_user is not True:
raise ValueError(
f"Only premium users can use this endpoint + {CommonProxyErrors.not_premium_user.value}"
)
# Convert Pydantic model to dict
data = fine_tuning_request.model_dump(exclude_none=True)
verbose_proxy_logger.debug(
"Request received by LiteLLM:\n{}".format(json.dumps(data, indent=4)),
)
# Include original request and headers in the data
data = await add_litellm_data_to_request(
data=data,
request=request,
general_settings=general_settings,
user_api_key_dict=user_api_key_dict,
version=version,
proxy_config=proxy_config,
)
# get configs for custom_llm_provider
llm_provider_config = get_fine_tuning_provider_config(
custom_llm_provider=fine_tuning_request.custom_llm_provider,
)
# add llm_provider_config to data
data.update(llm_provider_config)
response = await litellm.acreate_fine_tuning_job(**data)
### ALERTING ###
asyncio.create_task(
proxy_logging_obj.update_request_status(
litellm_call_id=data.get("litellm_call_id", ""), status="success"
)
)
### RESPONSE HEADERS ###
hidden_params = getattr(response, "_hidden_params", {}) or {}
model_id = hidden_params.get("model_id", None) or ""
cache_key = hidden_params.get("cache_key", None) or ""
api_base = hidden_params.get("api_base", None) or ""
fastapi_response.headers.update(
get_custom_headers(
user_api_key_dict=user_api_key_dict,
model_id=model_id,
cache_key=cache_key,
api_base=api_base,
version=version,
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
)
)
return response
except Exception as e:
await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
)
verbose_proxy_logger.error(
"litellm.proxy.proxy_server.create_fine_tuning_job(): Exception occurred - {}".format(
str(e)
)
)
verbose_proxy_logger.debug(traceback.format_exc())
if isinstance(e, HTTPException):
raise ProxyException(
message=getattr(e, "message", str(e.detail)),
type=getattr(e, "type", "None"),
param=getattr(e, "param", "None"),
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
)
else:
error_msg = f"{str(e)}"
raise ProxyException(
message=getattr(e, "message", error_msg),
type=getattr(e, "type", "None"),
param=getattr(e, "param", "None"),
code=getattr(e, "status_code", 500),
)
@router.get(
"/v1/fine_tuning/jobs",
dependencies=[Depends(user_api_key_auth)],
tags=["fine-tuning"],
summary="✨ (Enterprise) List Fine-Tuning Jobs",
)
@router.get(
"/fine_tuning/jobs",
dependencies=[Depends(user_api_key_auth)],
tags=["fine-tuning"],
summary="✨ (Enterprise) List Fine-Tuning Jobs",
)
async def list_fine_tuning_jobs(
request: Request,
fastapi_response: Response,
custom_llm_provider: Literal["openai", "azure"],
after: Optional[str] = None,
limit: Optional[int] = None,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
Lists fine-tuning jobs for the organization.
This is the equivalent of GET https://api.openai.com/v1/fine_tuning/jobs
Supported Query Params:
- `custom_llm_provider`: Name of the LiteLLM provider
- `after`: Identifier for the last job from the previous pagination request.
- `limit`: Number of fine-tuning jobs to retrieve (default is 20).
"""
from litellm.proxy.proxy_server import (
add_litellm_data_to_request,
general_settings,
get_custom_headers,
premium_user,
proxy_config,
proxy_logging_obj,
version,
)
data: dict = {}
try:
if premium_user is not True:
raise ValueError(
f"Only premium users can use this endpoint + {CommonProxyErrors.not_premium_user.value}"
)
# Include original request and headers in the data
data = await add_litellm_data_to_request(
data=data,
request=request,
general_settings=general_settings,
user_api_key_dict=user_api_key_dict,
version=version,
proxy_config=proxy_config,
)
# get configs for custom_llm_provider
llm_provider_config = get_fine_tuning_provider_config(
custom_llm_provider=custom_llm_provider
)
data.update(llm_provider_config)
response = await litellm.alist_fine_tuning_jobs(
**data,
after=after,
limit=limit,
)
### RESPONSE HEADERS ###
hidden_params = getattr(response, "_hidden_params", {}) or {}
model_id = hidden_params.get("model_id", None) or ""
cache_key = hidden_params.get("cache_key", None) or ""
api_base = hidden_params.get("api_base", None) or ""
fastapi_response.headers.update(
get_custom_headers(
user_api_key_dict=user_api_key_dict,
model_id=model_id,
cache_key=cache_key,
api_base=api_base,
version=version,
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
)
)
return response
except Exception as e:
await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
)
verbose_proxy_logger.error(
"litellm.proxy.proxy_server.list_fine_tuning_jobs(): Exception occurred - {}".format(
str(e)
)
)
verbose_proxy_logger.debug(traceback.format_exc())
if isinstance(e, HTTPException):
raise ProxyException(
message=getattr(e, "message", str(e.detail)),
type=getattr(e, "type", "None"),
param=getattr(e, "param", "None"),
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
)
else:
error_msg = f"{str(e)}"
raise ProxyException(
message=getattr(e, "message", error_msg),
type=getattr(e, "type", "None"),
param=getattr(e, "param", "None"),
code=getattr(e, "status_code", 500),
)
@router.post(
"/v1/fine_tuning/jobs/{fine_tuning_job_id:path}/cancel",
dependencies=[Depends(user_api_key_auth)],
tags=["fine-tuning"],
summary="✨ (Enterprise) Cancel Fine-Tuning Jobs",
)
@router.post(
"/fine_tuning/jobs/{fine_tuning_job_id:path}/cancel",
dependencies=[Depends(user_api_key_auth)],
tags=["fine-tuning"],
summary="✨ (Enterprise) Cancel Fine-Tuning Jobs",
)
async def cancel_fine_tuning_job(
request: Request,
fastapi_response: Response,
fine_tuning_job_id: str,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
Cancel a fine-tuning job.
This is the equivalent of POST https://api.openai.com/v1/fine_tuning/jobs/{fine_tuning_job_id}/cancel
Supported Params:
- `custom_llm_provider`: Name of the LiteLLM provider (sent in the request body)
- `fine_tuning_job_id`: The ID of the fine-tuning job to cancel (path parameter).
"""
from litellm.proxy.proxy_server import (
add_litellm_data_to_request,
general_settings,
get_custom_headers,
premium_user,
proxy_config,
proxy_logging_obj,
version,
)
data: dict = {}
try:
if premium_user is not True:
raise ValueError(
f"Only premium users can use this endpoint + {CommonProxyErrors.not_premium_user.value}"
)
# Include original request and headers in the data
data = await add_litellm_data_to_request(
data=data,
request=request,
general_settings=general_settings,
user_api_key_dict=user_api_key_dict,
version=version,
proxy_config=proxy_config,
)
request_body = await request.json()
custom_llm_provider = request_body.get("custom_llm_provider", None)
# get configs for custom_llm_provider
llm_provider_config = get_fine_tuning_provider_config(
custom_llm_provider=custom_llm_provider
)
data.update(llm_provider_config)
response = await litellm.acancel_fine_tuning_job(
**data,
fine_tuning_job_id=fine_tuning_job_id,
)
### RESPONSE HEADERS ###
hidden_params = getattr(response, "_hidden_params", {}) or {}
model_id = hidden_params.get("model_id", None) or ""
cache_key = hidden_params.get("cache_key", None) or ""
api_base = hidden_params.get("api_base", None) or ""
fastapi_response.headers.update(
get_custom_headers(
user_api_key_dict=user_api_key_dict,
model_id=model_id,
cache_key=cache_key,
api_base=api_base,
version=version,
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
)
)
return response
except Exception as e:
await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
)
verbose_proxy_logger.error(
"litellm.proxy.proxy_server.list_fine_tuning_jobs(): Exception occurred - {}".format(
str(e)
)
)
verbose_proxy_logger.debug(traceback.format_exc())
if isinstance(e, HTTPException):
raise ProxyException(
message=getattr(e, "message", str(e.detail)),
type=getattr(e, "type", "None"),
param=getattr(e, "param", "None"),
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
)
else:
error_msg = f"{str(e)}"
raise ProxyException(
message=getattr(e, "message", error_msg),
type=getattr(e, "type", "None"),
param=getattr(e, "param", "None"),
code=getattr(e, "status_code", 500),
)
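End to end, the new proxy routes are exercised with the standard OpenAI client; a sketch of creating and cancelling a job (file id and keys are placeholders), mirroring the e2e test at the bottom of this diff:

```python
import asyncio
from openai import AsyncOpenAI

client = AsyncOpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")

async def main():
    # `custom_llm_provider` selects the matching finetune_settings entry
    job = await client.fine_tuning.jobs.create(
        model="gpt-35-turbo-1106",
        training_file="file-abc123",  # placeholder file id
        extra_body={"custom_llm_provider": "azure"},
    )
    cancelled = await client.fine_tuning.jobs.cancel(
        fine_tuning_job_id=job.id,
        extra_body={"custom_llm_provider": "azure"},
    )
    print(cancelled.status)

asyncio.run(main())
```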

View file

@@ -34,6 +34,37 @@ from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
router = APIRouter()
files_config = None
def set_files_config(config):
global files_config
if config is None:
return
if not isinstance(config, list):
raise ValueError("invalid files config, expected a list is not a list")
for element in config:
if isinstance(element, dict):
for key, value in element.items():
if isinstance(value, str) and value.startswith("os.environ/"):
element[key] = litellm.get_secret(value)
files_config = config
def get_files_provider_config(
custom_llm_provider: str,
):
global files_config
if files_config is None:
raise ValueError("files_config is not set, set it on your config.yaml file.")
for setting in files_config:
if setting.get("custom_llm_provider") == custom_llm_provider:
return setting
return None
@router.post(
"/v1/files",
@@ -49,6 +80,7 @@ async def create_file(
request: Request,
fastapi_response: Response,
purpose: str = Form(...),
custom_llm_provider: str = Form(default="openai"),
file: UploadFile = File(...),
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
@@ -100,11 +132,17 @@ async def create_file(
_create_file_request = CreateFileRequest(file=file_data, **data)
# for now use custom_llm_provider=="openai" -> this will change as LiteLLM adds more providers for acreate_batch
response = await litellm.acreate_file(
custom_llm_provider="openai", **_create_file_request
# get configs for custom_llm_provider
llm_provider_config = get_files_provider_config(
custom_llm_provider=custom_llm_provider
)
# add llm_provider_config to data
_create_file_request.update(llm_provider_config)
# for now use custom_llm_provider=="openai" -> this will change as LiteLLM adds more providers for acreate_batch
response = await litellm.acreate_file(**_create_file_request)
### ALERTING ###
asyncio.create_task(
proxy_logging_obj.update_request_status(
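The `/v1/files` route now accepts a `custom_llm_provider` form field; a sketch of uploading through the proxy with the OpenAI client (URL and key are placeholders):

```python
import asyncio
from openai import AsyncOpenAI

client = AsyncOpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")

async def main():
    # `extra_body` adds the custom_llm_provider form field to the upload
    resp = await client.files.create(
        file=open("openai_batch_completions.jsonl", "rb"),
        purpose="fine-tune",
        extra_body={"custom_llm_provider": "azure"},
    )
    print(resp.id)

asyncio.run(main())
```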

View file

@@ -25,6 +25,27 @@ model_list:
api_key: "os.environ/OPENAI_API_KEY"
model_info:
mode: audio_speech
# For /fine_tuning/jobs endpoints
finetune_settings:
- custom_llm_provider: azure
api_base: https://exampleopenaiendpoint-production.up.railway.app
api_key: fake-key
api_version: "2023-03-15-preview"
- custom_llm_provider: openai
api_key: os.environ/OPENAI_API_KEY
# for /files endpoints
files_settings:
- custom_llm_provider: azure
api_base: https://exampleopenaiendpoint-production.up.railway.app
api_key: fake-key
api_version: "2023-03-15-preview"
- custom_llm_provider: openai
api_key: os.environ/OPENAI_API_KEY
general_settings:
master_key: sk-1234
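When this config is loaded, `set_fine_tuning_config` / `set_files_config` (wired up in proxy_server.py below) resolve any `os.environ/...` value through `litellm.get_secret`. Condensed, the resolution looks like:

```python
import litellm

# Condensed from set_files_config / set_fine_tuning_config in this diff.
def resolve_env_refs(config: list) -> list:
    for element in config:
        if isinstance(element, dict):
            for key, value in element.items():
                if isinstance(value, str) and value.startswith("os.environ/"):
                    element[key] = litellm.get_secret(value)
    return config
```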

View file

@@ -153,6 +153,8 @@ from litellm.proxy.common_utils.init_callbacks import initialize_callbacks_on_pr
from litellm.proxy.common_utils.openai_endpoint_utils import (
remove_sensitive_info_from_deployment,
)
from litellm.proxy.fine_tuning_endpoints.endpoints import router as fine_tuning_router
from litellm.proxy.fine_tuning_endpoints.endpoints import set_fine_tuning_config
from litellm.proxy.guardrails.init_guardrails import initialize_guardrails
from litellm.proxy.health_check import perform_health_check
from litellm.proxy.health_endpoints._health_endpoints import router as health_router
@@ -179,6 +181,7 @@ from litellm.proxy.management_endpoints.team_endpoints import router as team_rou
from litellm.proxy.openai_files_endpoints.files_endpoints import (
router as openai_files_router,
)
from litellm.proxy.openai_files_endpoints.files_endpoints import set_files_config
from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
initialize_pass_through_endpoints,
)
@@ -1807,6 +1810,14 @@ class ProxyConfig:
assistant_settings["litellm_params"][k] = v
assistants_config = AssistantsTypedDict(**assistant_settings) # type: ignore
## /fine_tuning/jobs endpoints config
finetuning_config = config.get("finetune_settings", None)
set_fine_tuning_config(config=finetuning_config)
## /files endpoint config
files_config = config.get("files_settings", None)
set_files_config(config=files_config)
## ROUTER SETTINGS (e.g. routing_strategy, ...)
router_settings = config.get("router_settings", None)
if router_settings and isinstance(router_settings, dict):
@@ -9598,6 +9609,7 @@ def cleanup_router_config_variables():
app.include_router(router)
app.include_router(fine_tuning_router)
app.include_router(health_router)
app.include_router(key_management_router)
app.include_router(internal_user_router)

View file

@@ -0,0 +1,12 @@
{"messages": [{"role": "system", "content": "Clippy is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "What's the capital of France?"}, {"role": "assistant", "content": "Paris, as if everyone doesn't know that already."}]}
{"messages": [{"role": "system", "content": "Clippy is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "Who wrote 'Romeo and Juliet'?"}, {"role": "assistant", "content": "Oh, just some guy named William Shakespeare. Ever heard of him?"}]}
{"messages": [{"role": "system", "content": "Clippy is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "How far is the Moon from Earth?"}, {"role": "assistant", "content": "Around 384,400 kilometers. Give or take a few, like that really matters."}]}
{"messages": [{"role": "system", "content": "Clippy is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "What's the capital of France?"}, {"role": "assistant", "content": "Paris, as if everyone doesn't know that already."}]}
{"messages": [{"role": "system", "content": "Clippy is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "Who wrote 'Romeo and Juliet'?"}, {"role": "assistant", "content": "Oh, just some guy named William Shakespeare. Ever heard of him?"}]}
{"messages": [{"role": "system", "content": "Clippy is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "How far is the Moon from Earth?"}, {"role": "assistant", "content": "Around 384,400 kilometers. Give or take a few, like that really matters."}]}
{"messages": [{"role": "system", "content": "Clippy is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "What's the capital of France?"}, {"role": "assistant", "content": "Paris, as if everyone doesn't know that already."}]}
{"messages": [{"role": "system", "content": "Clippy is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "Who wrote 'Romeo and Juliet'?"}, {"role": "assistant", "content": "Oh, just some guy named William Shakespeare. Ever heard of him?"}]}
{"messages": [{"role": "system", "content": "Clippy is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "How far is the Moon from Earth?"}, {"role": "assistant", "content": "Around 384,400 kilometers. Give or take a few, like that really matters."}]}
{"messages": [{"role": "system", "content": "Clippy is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "What's the capital of France?"}, {"role": "assistant", "content": "Paris, as if everyone doesn't know that already."}]}
{"messages": [{"role": "system", "content": "Clippy is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "Who wrote 'Romeo and Juliet'?"}, {"role": "assistant", "content": "Oh, just some guy named William Shakespeare. Ever heard of him?"}]}
{"messages": [{"role": "system", "content": "Clippy is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "How far is the Moon from Earth?"}, {"role": "assistant", "content": "Around 384,400 kilometers. Give or take a few, like that really matters."}]}

View file

@@ -12,6 +12,7 @@ from openai import APITimeoutError as Timeout
import litellm
litellm.num_retries = 0
import asyncio
import logging
import openai
@@ -128,3 +129,49 @@ async def test_create_fine_tune_jobs_async():
except Exception as e:
pytest.fail(f"Error occurred: {e}")
pass
@pytest.mark.asyncio
async def test_azure_create_fine_tune_jobs_async():
verbose_logger.setLevel(logging.DEBUG)
file_name = "azure_fine_tune.jsonl"
_current_dir = os.path.dirname(os.path.abspath(__file__))
file_path = os.path.join(_current_dir, file_name)
file_id = "file-5e4b20ecbd724182b9964f3cd2ab7212"
create_fine_tuning_response = await litellm.acreate_fine_tuning_job(
model="gpt-35-turbo-1106",
training_file=file_id,
custom_llm_provider="azure",
api_key=os.getenv("AZURE_SWEDEN_API_KEY"),
api_base="https://my-endpoint-sweden-berri992.openai.azure.com/",
)
print("response from litellm.create_fine_tuning_job=", create_fine_tuning_response)
assert create_fine_tuning_response.id is not None
assert create_fine_tuning_response.model == "gpt-35-turbo-1106"
# list fine tuning jobs
print("listing ft jobs")
ft_jobs = await litellm.alist_fine_tuning_jobs(
limit=2,
custom_llm_provider="azure",
api_key=os.getenv("AZURE_SWEDEN_API_KEY"),
api_base="https://my-endpoint-sweden-berri992.openai.azure.com/",
)
print("response from litellm.list_fine_tuning_jobs=", ft_jobs)
# cancel ft job
response = await litellm.acancel_fine_tuning_job(
fine_tuning_job_id=create_fine_tuning_response.id,
custom_llm_provider="azure",
api_key=os.getenv("AZURE_SWEDEN_API_KEY"),
api_base="https://my-endpoint-sweden-berri992.openai.azure.com/",
)
print("response from litellm.cancel_fine_tuning_job=", response)
assert response.status == "cancelled"
assert response.id == create_fine_tuning_response.id

View file

@@ -9,7 +9,6 @@ from typing import (
Mapping,
Optional,
Tuple,
TypedDict,
Union,
)
@@ -31,7 +30,7 @@ from openai.types.beta.threads.message import Message as OpenAIMessage
from openai.types.beta.threads.message_content import MessageContent
from openai.types.beta.threads.run import Run
from pydantic import BaseModel, Field
from typing_extensions import Dict, Required, override
from typing_extensions import Dict, Required, TypedDict, override
FileContent = Union[IO[bytes], bytes, PathLike]
@@ -457,15 +456,17 @@ class ChatCompletionUsageBlock(TypedDict):
total_tokens: int
class Hyperparameters(TypedDict):
batch_size: Optional[Union[str, int]] # "Number of examples in each batch."
learning_rate_multiplier: Optional[
Union[str, float]
] # Scaling factor for the learning rate
n_epochs: Optional[Union[str, int]] # "The number of epochs to train the model for"
class Hyperparameters(BaseModel):
batch_size: Optional[Union[str, int]] = None # "Number of examples in each batch."
learning_rate_multiplier: Optional[Union[str, float]] = (
None # Scaling factor for the learning rate
)
n_epochs: Optional[Union[str, int]] = (
None # "The number of epochs to train the model for"
)
class FineTuningJobCreate(TypedDict):
class FineTuningJobCreate(BaseModel):
"""
FineTuningJobCreate - Create a fine-tuning job
@@ -489,16 +490,20 @@ class FineTuningJobCreate(TypedDict):
model: str # "The name of the model to fine-tune."
training_file: str # "The ID of an uploaded file that contains training data."
hyperparameters: Optional[
Hyperparameters
] # "The hyperparameters used for the fine-tuning job."
suffix: Optional[
str
] # "A string of up to 18 characters that will be added to your fine-tuned model name."
validation_file: Optional[
str
] # "The ID of an uploaded file that contains validation data."
integrations: Optional[
List[str]
] # "A list of integrations to enable for your fine-tuning job."
seed: Optional[int] # "The seed controls the reproducibility of the job."
hyperparameters: Optional[Hyperparameters] = (
None # "The hyperparameters used for the fine-tuning job."
)
suffix: Optional[str] = (
None # "A string of up to 18 characters that will be added to your fine-tuned model name."
)
validation_file: Optional[str] = (
None # "The ID of an uploaded file that contains validation data."
)
integrations: Optional[List[str]] = (
None # "A list of integrations to enable for your fine-tuning job."
)
seed: Optional[int] = None # "The seed controls the reproducibility of the job."
class LiteLLMFineTuningJobCreate(FineTuningJobCreate):
custom_llm_provider: Literal["openai", "azure"]

View file

@@ -120,6 +120,24 @@ litellm_settings:
langfuse_secret: os.environ/LANGFUSE_PROJECT2_SECRET # Project 2
langfuse_host: https://us.cloud.langfuse.com
# For /fine_tuning/jobs endpoints
finetune_settings:
- custom_llm_provider: azure
api_base: https://exampleopenaiendpoint-production.up.railway.app
api_key: fake-key
api_version: "2023-03-15-preview"
- custom_llm_provider: openai
api_key: os.environ/OPENAI_API_KEY
# for /files endpoints
files_settings:
- custom_llm_provider: azure
api_base: http://0.0.0.0:8090
api_key: fake-key
api_version: "2023-03-15-preview"
- custom_llm_provider: openai
api_key: os.environ/OPENAI_API_KEY
router_settings:
routing_strategy: usage-based-routing-v2
redis_host: os.environ/REDIS_HOST

View file

@@ -0,0 +1,2 @@
{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-3.5-turbo-0125", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 10}}
{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-3.5-turbo-0125", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 10}}

View file

@@ -0,0 +1,53 @@
from openai import AsyncOpenAI
import os
import pytest
@pytest.mark.asyncio
async def test_openai_fine_tuning():
"""
[PROD Test] Ensures fine-tuning jobs can be created, listed, and cancelled through the proxy
"""
client = AsyncOpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")
file_name = "openai_batch_completions.jsonl"
_current_dir = os.path.dirname(os.path.abspath(__file__))
file_path = os.path.join(_current_dir, file_name)
response = await client.files.create(
extra_body={"custom_llm_provider": "azure"},
file=open(file_path, "rb"),
purpose="fine-tune",
)
print("response from files.create: {}".format(response))
# create fine tuning job
ft_job = await client.fine_tuning.jobs.create(
model="gpt-35-turbo-1106",
training_file=response.id,
extra_body={"custom_llm_provider": "azure"},
)
print("response from ft job={}".format(ft_job))
# response from example endpoint
assert ft_job.id == "ftjob-abc123"
# list all fine tuning jobs
list_ft_jobs = await client.fine_tuning.jobs.list(
extra_query={"custom_llm_provider": "azure"}
)
print("list of ft jobs={}".format(list_ft_jobs))
# cancel specific fine tuning job
cancel_ft_job = await client.fine_tuning.jobs.cancel(
fine_tuning_job_id="123",
extra_body={"custom_llm_provider": "azure"},
)
print("response from cancel ft job={}".format(cancel_ft_job))
assert cancel_ft_job.id is not None