(Feat) - new endpoint GET /v1/fine_tuning/jobs/{fine_tuning_job_id:path} (#7427)

* init commit ft jobs logging

* add ft logging

* add logging for FineTuningJob

* simple FT Job create test

* simplify Azure fine tuning to use all methods in OAI ft

* update doc string

* add aretrieve_fine_tuning_job

* reuse handle_exception_on_proxy from litellm.proxy.utils

* fix naming

* add /fine_tuning/jobs/{fine_tuning_job_id:path}

* remove unused imports

* update func signature

* run ci/cd again

* ci/cd run again

* fix code quality

* ci/cd run again
Ishaan Jaff 2024-12-27 17:01:14 -08:00 committed by GitHub
parent 5e8c64f128
commit 2ece919f01
5 changed files with 400 additions and 227 deletions
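
For orientation, a minimal client-side sketch of calling the new proxy route. The base URL, virtual key, and job ID below are placeholder assumptions, not values from this commit; the route itself requires custom_llm_provider as a query parameter and is gated to premium (enterprise) users.

# Hypothetical usage sketch for the new GET /v1/fine_tuning/jobs/{id} route.
import requests

PROXY_BASE = "http://localhost:4000"  # assumption: locally running LiteLLM proxy
VIRTUAL_KEY = "sk-1234"               # assumption: a proxy virtual key
job_id = "ftjob-abc123"               # assumption: an existing fine-tuning job ID

resp = requests.get(
    f"{PROXY_BASE}/v1/fine_tuning/jobs/{job_id}",
    params={"custom_llm_provider": "openai"},  # required query param on the new endpoint
    headers={"Authorization": f"Bearer {VIRTUAL_KEY}"},
)
resp.raise_for_status()
print(resp.json())  # OpenAI-compatible FineTuningJob object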

Changed file 1 of 5 — SDK fine-tuning entrypoints (create / cancel / list / retrieve fine-tuning jobs):

@@ -171,6 +171,7 @@ def create_fine_tuning_job(
             response = openai_fine_tuning_apis_instance.create_fine_tuning_job(
                 api_base=api_base,
                 api_key=api_key,
+                api_version=optional_params.api_version,
                 organization=organization,
                 create_fine_tuning_job_data=create_fine_tuning_job_data_dict,
                 timeout=timeout,
@@ -223,6 +224,7 @@ def create_fine_tuning_job(
                 timeout=timeout,
                 max_retries=optional_params.max_retries,
                 _is_async=_is_async,
+                organization=optional_params.organization,
             )
         elif custom_llm_provider == "vertex_ai":
             api_base = optional_params.api_base or ""
@@ -279,7 +281,7 @@ def create_fine_tuning_job(
 async def acancel_fine_tuning_job(
     fine_tuning_job_id: str,
-    custom_llm_provider: Literal["openai"] = "openai",
+    custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
     extra_headers: Optional[Dict[str, str]] = None,
     extra_body: Optional[Dict[str, str]] = None,
     **kwargs,
@@ -374,6 +376,7 @@ def cancel_fine_tuning_job(
             response = openai_fine_tuning_apis_instance.cancel_fine_tuning_job(
                 api_base=api_base,
                 api_key=api_key,
+                api_version=optional_params.api_version,
                 organization=organization,
                 fine_tuning_job_id=fine_tuning_job_id,
                 timeout=timeout,
@@ -412,6 +415,7 @@ def cancel_fine_tuning_job(
                 timeout=timeout,
                 max_retries=optional_params.max_retries,
                 _is_async=_is_async,
+                organization=optional_params.organization,
             )
         else:
             raise litellm.exceptions.BadRequestError(
@@ -434,7 +438,7 @@ def cancel_fine_tuning_job(
 async def alist_fine_tuning_jobs(
     after: Optional[str] = None,
     limit: Optional[int] = None,
-    custom_llm_provider: Literal["openai"] = "openai",
+    custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
     extra_headers: Optional[Dict[str, str]] = None,
     extra_body: Optional[Dict[str, str]] = None,
     **kwargs,
@@ -533,6 +537,7 @@ def list_fine_tuning_jobs(
             response = openai_fine_tuning_apis_instance.list_fine_tuning_jobs(
                 api_base=api_base,
                 api_key=api_key,
+                api_version=optional_params.api_version,
                 organization=organization,
                 after=after,
                 limit=limit,
@@ -573,6 +578,7 @@ def list_fine_tuning_jobs(
                 timeout=timeout,
                 max_retries=optional_params.max_retries,
                 _is_async=_is_async,
+                organization=optional_params.organization,
             )
         else:
             raise litellm.exceptions.BadRequestError(
@@ -590,3 +596,153 @@ def list_fine_tuning_jobs(
         return response
     except Exception as e:
         raise e
+
+
+async def aretrieve_fine_tuning_job(
+    fine_tuning_job_id: str,
+    custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
+    extra_headers: Optional[Dict[str, str]] = None,
+    extra_body: Optional[Dict[str, str]] = None,
+    **kwargs,
+) -> FineTuningJob:
+    """
+    Async: Get info about a fine-tuning job.
+    """
+    try:
+        loop = asyncio.get_event_loop()
+        kwargs["aretrieve_fine_tuning_job"] = True
+
+        # Use a partial function to pass your keyword arguments
+        func = partial(
+            retrieve_fine_tuning_job,
+            fine_tuning_job_id,
+            custom_llm_provider,
+            extra_headers,
+            extra_body,
+            **kwargs,
+        )
+
+        # Add the context to the function
+        ctx = contextvars.copy_context()
+        func_with_context = partial(ctx.run, func)
+        init_response = await loop.run_in_executor(None, func_with_context)
+
+        if asyncio.iscoroutine(init_response):
+            response = await init_response
+        else:
+            response = init_response  # type: ignore
+        return response
+    except Exception as e:
+        raise e
+
+
+def retrieve_fine_tuning_job(
+    fine_tuning_job_id: str,
+    custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
+    extra_headers: Optional[Dict[str, str]] = None,
+    extra_body: Optional[Dict[str, str]] = None,
+    **kwargs,
+) -> Union[FineTuningJob, Coroutine[Any, Any, FineTuningJob]]:
+    """
+    Get info about a fine-tuning job.
+    """
+    try:
+        optional_params = GenericLiteLLMParams(**kwargs)
+        ### TIMEOUT LOGIC ###
+        timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
+        # set timeout for 10 minutes by default
+        if (
+            timeout is not None
+            and isinstance(timeout, httpx.Timeout)
+            and supports_httpx_timeout(custom_llm_provider) is False
+        ):
+            read_timeout = timeout.read or 600
+            timeout = read_timeout  # default 10 min timeout
+        elif timeout is not None and not isinstance(timeout, httpx.Timeout):
+            timeout = float(timeout)  # type: ignore
+        elif timeout is None:
+            timeout = 600.0
+
+        _is_async = kwargs.pop("aretrieve_fine_tuning_job", False) is True
+
+        # OpenAI
+        if custom_llm_provider == "openai":
+            api_base = (
+                optional_params.api_base
+                or litellm.api_base
+                or os.getenv("OPENAI_API_BASE")
+                or "https://api.openai.com/v1"
+            )
+            organization = (
+                optional_params.organization
+                or litellm.organization
+                or os.getenv("OPENAI_ORGANIZATION", None)
+                or None
+            )
+            api_key = (
+                optional_params.api_key
+                or litellm.api_key
+                or litellm.openai_key
+                or os.getenv("OPENAI_API_KEY")
+            )
+
+            response = openai_fine_tuning_apis_instance.retrieve_fine_tuning_job(
+                api_base=api_base,
+                api_key=api_key,
+                api_version=optional_params.api_version,
+                organization=organization,
+                fine_tuning_job_id=fine_tuning_job_id,
+                timeout=timeout,
+                max_retries=optional_params.max_retries,
+                _is_async=_is_async,
+            )
+        # Azure OpenAI
+        elif custom_llm_provider == "azure":
+            api_base = optional_params.api_base or litellm.api_base or get_secret_str("AZURE_API_BASE")  # type: ignore
+
+            api_version = (
+                optional_params.api_version
+                or litellm.api_version
+                or get_secret_str("AZURE_API_VERSION")
+            )  # type: ignore
+
+            api_key = (
+                optional_params.api_key
+                or litellm.api_key
+                or litellm.azure_key
+                or get_secret_str("AZURE_OPENAI_API_KEY")
+                or get_secret_str("AZURE_API_KEY")
+            )  # type: ignore
+
+            extra_body = optional_params.get("extra_body", {})
+            if extra_body is not None:
+                extra_body.pop("azure_ad_token", None)
+            else:
+                get_secret_str("AZURE_AD_TOKEN")  # type: ignore
+
+            response = azure_fine_tuning_apis_instance.retrieve_fine_tuning_job(
+                api_base=api_base,
+                api_key=api_key,
+                api_version=api_version,
+                fine_tuning_job_id=fine_tuning_job_id,
+                timeout=timeout,
+                max_retries=optional_params.max_retries,
+                _is_async=_is_async,
+                organization=optional_params.organization,
+            )
+        else:
+            raise litellm.exceptions.BadRequestError(
+                message="LiteLLM doesn't support {} for 'retrieve_fine_tuning_job'. Only 'openai' and 'azure' are supported.".format(
                    custom_llm_provider
+                ),
+                model="n/a",
+                llm_provider=custom_llm_provider,
+                response=httpx.Response(
+                    status_code=400,
+                    content="Unsupported provider",
+                    request=httpx.Request(method="retrieve_fine_tuning_job", url="https://github.com/BerriAI/litellm"),  # type: ignore
+                ),
+            )
+        return response
    except Exception as e:
        raise e
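
A minimal SDK-level sketch of the new retrieve functions added above. The job ID and Azure credentials are placeholder assumptions; with custom_llm_provider="azure", litellm routes through the Azure handler and forwards api_version.

import asyncio
import litellm

async def main():
    job = await litellm.aretrieve_fine_tuning_job(
        fine_tuning_job_id="ftjob-abc123",                 # assumption: an existing job
        custom_llm_provider="azure",
        api_base="https://my-endpoint.openai.azure.com",   # assumption
        api_key="azure-key",                               # assumption
        api_version="2024-05-01-preview",                  # assumption
    )
    print(job.id, job.status)

asyncio.run(main())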

Changed file 2 of 5 — Azure fine-tuning handler (AzureOpenAIFineTuningAPI), rewritten to inherit from the OpenAI handler:

@@ -1,179 +1,48 @@
-from typing import Any, Coroutine, Optional, Union
+from typing import Optional, Union
 
 import httpx
-from openai import AsyncAzureOpenAI, AzureOpenAI
-from openai.types.fine_tuning import FineTuningJob
+from openai import AsyncAzureOpenAI, AsyncOpenAI, AzureOpenAI, OpenAI
 
-from litellm._logging import verbose_logger
 from litellm.llms.azure.files.handler import get_azure_openai_client
-from litellm.llms.base import BaseLLM
+from litellm.llms.openai.fine_tuning.handler import OpenAIFineTuningAPI
 
 
-class AzureOpenAIFineTuningAPI(BaseLLM):
+class AzureOpenAIFineTuningAPI(OpenAIFineTuningAPI):
     """
-    AzureOpenAI methods to support for batches
+    AzureOpenAI methods to support fine tuning, inherits from OpenAIFineTuningAPI.
     """
 
-    def __init__(self) -> None:
-        super().__init__()
-
-    async def acreate_fine_tuning_job(
+    def get_openai_client(
         self,
-        create_fine_tuning_job_data: dict,
-        openai_client: AsyncAzureOpenAI,
-    ) -> FineTuningJob:
-        response = await openai_client.fine_tuning.jobs.create(
-            **create_fine_tuning_job_data  # type: ignore
-        )
-        return response
-
-    def create_fine_tuning_job(
-        self,
-        _is_async: bool,
-        create_fine_tuning_job_data: dict,
         api_key: Optional[str],
         api_base: Optional[str],
         timeout: Union[float, httpx.Timeout],
         max_retries: Optional[int],
-        organization: Optional[str] = None,
-        client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
+        organization: Optional[str],
+        client: Optional[
+            Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI]
+        ] = None,
+        _is_async: bool = False,
         api_version: Optional[str] = None,
-    ) -> Union[FineTuningJob, Coroutine[Any, Any, FineTuningJob]]:
-        openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
-            get_azure_openai_client(
-                api_key=api_key,
-                api_base=api_base,
-                timeout=timeout,
-                max_retries=max_retries,
-                organization=organization,
-                api_version=api_version,
-                client=client,
-                _is_async=_is_async,
-            )
-        )
-        if openai_client is None:
-            raise ValueError(
-                "AzureOpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
-            )
-        if _is_async is True:
-            if not isinstance(openai_client, AsyncAzureOpenAI):
-                raise ValueError(
-                    "AzureOpenAI client is not an instance of AsyncAzureOpenAI. Make sure you passed an AsyncAzureOpenAI client."
-                )
-            return self.acreate_fine_tuning_job(  # type: ignore
-                create_fine_tuning_job_data=create_fine_tuning_job_data,
-                openai_client=openai_client,
-            )
-        verbose_logger.debug(
-            "creating fine tuning job, args= %s", create_fine_tuning_job_data
-        )
-        response = openai_client.fine_tuning.jobs.create(**create_fine_tuning_job_data)  # type: ignore
-        return response
-
-    async def acancel_fine_tuning_job(
-        self,
-        fine_tuning_job_id: str,
-        openai_client: AsyncAzureOpenAI,
-    ) -> FineTuningJob:
-        response = await openai_client.fine_tuning.jobs.cancel(
-            fine_tuning_job_id=fine_tuning_job_id
-        )
-        return response
-
-    def cancel_fine_tuning_job(
-        self,
-        _is_async: bool,
-        fine_tuning_job_id: str,
-        api_key: Optional[str],
-        api_base: Optional[str],
-        timeout: Union[float, httpx.Timeout],
-        max_retries: Optional[int],
-        organization: Optional[str] = None,
-        api_version: Optional[str] = None,
-        client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
-    ):
-        openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
-            get_azure_openai_client(
-                api_key=api_key,
-                api_base=api_base,
-                api_version=api_version,
-                timeout=timeout,
-                max_retries=max_retries,
-                organization=organization,
-                client=client,
-                _is_async=_is_async,
-            )
-        )
-        if openai_client is None:
-            raise ValueError(
-                "AzureOpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
-            )
-        if _is_async is True:
-            if not isinstance(openai_client, AsyncAzureOpenAI):
-                raise ValueError(
-                    "AzureOpenAI client is not an instance of AsyncAzureOpenAI. Make sure you passed an AsyncAzureOpenAI client."
-                )
-            return self.acancel_fine_tuning_job(  # type: ignore
-                fine_tuning_job_id=fine_tuning_job_id,
-                openai_client=openai_client,
-            )
-        verbose_logger.debug("canceling fine tuning job, args= %s", fine_tuning_job_id)
-        response = openai_client.fine_tuning.jobs.cancel(
-            fine_tuning_job_id=fine_tuning_job_id
-        )
-        return response
-
-    async def alist_fine_tuning_jobs(
-        self,
-        openai_client: AsyncAzureOpenAI,
-        after: Optional[str] = None,
-        limit: Optional[int] = None,
-    ):
-        response = await openai_client.fine_tuning.jobs.list(after=after, limit=limit)  # type: ignore
-        return response
-
-    def list_fine_tuning_jobs(
-        self,
-        _is_async: bool,
-        api_key: Optional[str],
-        api_base: Optional[str],
-        timeout: Union[float, httpx.Timeout],
-        max_retries: Optional[int],
-        organization: Optional[str] = None,
-        client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
-        api_version: Optional[str] = None,
-        after: Optional[str] = None,
-        limit: Optional[int] = None,
-    ):
-        openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
-            get_azure_openai_client(
-                api_key=api_key,
-                api_base=api_base,
-                api_version=api_version,
-                timeout=timeout,
-                max_retries=max_retries,
-                organization=organization,
-                client=client,
-                _is_async=_is_async,
-            )
-        )
-        if openai_client is None:
-            raise ValueError(
-                "AzureOpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
-            )
-        if _is_async is True:
-            if not isinstance(openai_client, AsyncAzureOpenAI):
-                raise ValueError(
-                    "AzureOpenAI client is not an instance of AsyncAzureOpenAI. Make sure you passed an AsyncAzureOpenAI client."
-                )
-            return self.alist_fine_tuning_jobs(  # type: ignore
-                after=after,
-                limit=limit,
-                openai_client=openai_client,
-            )
-        verbose_logger.debug("list fine tuning job, after= %s, limit= %s", after, limit)
-        response = openai_client.fine_tuning.jobs.list(after=after, limit=limit)  # type: ignore
-        return response
+    ) -> Optional[
+        Union[
+            OpenAI,
+            AsyncOpenAI,
+            AzureOpenAI,
+            AsyncAzureOpenAI,
+        ]
+    ]:
+        # Override to use Azure-specific client initialization
+        if isinstance(client, OpenAI) or isinstance(client, AsyncOpenAI):
+            client = None
+        return get_azure_openai_client(
+            api_key=api_key,
+            api_base=api_base,
+            timeout=timeout,
+            max_retries=max_retries,
+            organization=organization,
+            api_version=api_version,
+            client=client,
+            _is_async=_is_async,
+        )
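
The refactor is a template-method pattern: the Azure subclass overrides only client construction and inherits every job operation (create, cancel, list, retrieve) from the OpenAI handler. A standalone sketch of the same idea, with illustrative names rather than litellm's actual classes:

# Illustrative sketch of the inheritance pattern used above (simplified names).
class BaseAPI:
    def get_client(self) -> str:
        return "openai-client"

    def retrieve_job(self, job_id: str) -> str:
        # shared logic: build a client, then perform the operation
        return f"{self.get_client()} retrieves {job_id}"

class AzureAPI(BaseAPI):
    def get_client(self) -> str:
        # only the client construction differs for Azure
        return "azure-client"

print(BaseAPI().retrieve_job("ftjob-1"))   # openai-client retrieves ftjob-1
print(AzureAPI().retrieve_job("ftjob-1"))  # azure-client retrieves ftjob-1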

Changed file 3 of 5 — OpenAI fine-tuning handler (litellm/llms/openai/fine_tuning/handler.py):

@@ -1,7 +1,7 @@
 from typing import Any, Coroutine, Optional, Union
 
 import httpx
-from openai import AsyncOpenAI, OpenAI
+from openai import AsyncAzureOpenAI, AsyncOpenAI, AzureOpenAI, OpenAI
 from openai.types.fine_tuning import FineTuningJob
 
 from litellm._logging import verbose_logger
@@ -22,11 +22,23 @@ class OpenAIFineTuningAPI:
         timeout: Union[float, httpx.Timeout],
         max_retries: Optional[int],
         organization: Optional[str],
-        client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
+        client: Optional[
+            Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI]
+        ] = None,
         _is_async: bool = False,
-    ) -> Optional[Union[OpenAI, AsyncOpenAI]]:
+        api_version: Optional[str] = None,
+    ) -> Optional[
+        Union[
+            OpenAI,
+            AsyncOpenAI,
+            AzureOpenAI,
+            AsyncAzureOpenAI,
+        ]
+    ]:
         received_args = locals()
-        openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = None
+        openai_client: Optional[
+            Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI]
+        ] = None
         if client is None:
             data = {}
             for k, v in received_args.items():
@@ -48,7 +60,7 @@ class OpenAIFineTuningAPI:
     async def acreate_fine_tuning_job(
         self,
         create_fine_tuning_job_data: dict,
-        openai_client: AsyncOpenAI,
+        openai_client: Union[AsyncOpenAI, AsyncAzureOpenAI],
     ) -> FineTuningJob:
         response = await openai_client.fine_tuning.jobs.create(
             **create_fine_tuning_job_data
@@ -61,12 +73,17 @@ class OpenAIFineTuningAPI:
         create_fine_tuning_job_data: dict,
         api_key: Optional[str],
         api_base: Optional[str],
+        api_version: Optional[str],
         timeout: Union[float, httpx.Timeout],
         max_retries: Optional[int],
         organization: Optional[str],
-        client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
+        client: Optional[
+            Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI]
+        ] = None,
     ) -> Union[FineTuningJob, Coroutine[Any, Any, FineTuningJob]]:
-        openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = self.get_openai_client(
+        openai_client: Optional[
+            Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI]
+        ] = self.get_openai_client(
             api_key=api_key,
             api_base=api_base,
             timeout=timeout,
@@ -74,6 +91,7 @@ class OpenAIFineTuningAPI:
             organization=organization,
             client=client,
             _is_async=_is_async,
+            api_version=api_version,
         )
         if openai_client is None:
             raise ValueError(
@@ -81,7 +99,7 @@ class OpenAIFineTuningAPI:
             )
         if _is_async is True:
-            if not isinstance(openai_client, AsyncOpenAI):
+            if not isinstance(openai_client, (AsyncOpenAI, AsyncAzureOpenAI)):
                 raise ValueError(
                     "OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client."
                 )
@@ -98,7 +116,7 @@ class OpenAIFineTuningAPI:
     async def acancel_fine_tuning_job(
         self,
         fine_tuning_job_id: str,
-        openai_client: AsyncOpenAI,
+        openai_client: Union[AsyncOpenAI, AsyncAzureOpenAI],
     ) -> FineTuningJob:
         response = await openai_client.fine_tuning.jobs.cancel(
             fine_tuning_job_id=fine_tuning_job_id
@@ -111,12 +129,17 @@ class OpenAIFineTuningAPI:
         fine_tuning_job_id: str,
         api_key: Optional[str],
         api_base: Optional[str],
+        api_version: Optional[str],
         timeout: Union[float, httpx.Timeout],
         max_retries: Optional[int],
         organization: Optional[str],
-        client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
+        client: Optional[
+            Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI]
+        ] = None,
     ):
-        openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = self.get_openai_client(
+        openai_client: Optional[
+            Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI]
+        ] = self.get_openai_client(
             api_key=api_key,
             api_base=api_base,
             timeout=timeout,
@@ -124,6 +147,7 @@ class OpenAIFineTuningAPI:
             organization=organization,
             client=client,
             _is_async=_is_async,
+            api_version=api_version,
         )
         if openai_client is None:
             raise ValueError(
@@ -131,7 +155,7 @@ class OpenAIFineTuningAPI:
             )
         if _is_async is True:
-            if not isinstance(openai_client, AsyncOpenAI):
+            if not isinstance(openai_client, (AsyncOpenAI, AsyncAzureOpenAI)):
                 raise ValueError(
                     "OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client."
                 )
@@ -147,7 +171,7 @@ class OpenAIFineTuningAPI:
     async def alist_fine_tuning_jobs(
         self,
-        openai_client: AsyncOpenAI,
+        openai_client: Union[AsyncOpenAI, AsyncAzureOpenAI],
         after: Optional[str] = None,
         limit: Optional[int] = None,
     ):
@@ -159,14 +183,19 @@ class OpenAIFineTuningAPI:
         _is_async: bool,
         api_key: Optional[str],
         api_base: Optional[str],
+        api_version: Optional[str],
         timeout: Union[float, httpx.Timeout],
         max_retries: Optional[int],
         organization: Optional[str],
-        client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
+        client: Optional[
+            Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI]
+        ] = None,
         after: Optional[str] = None,
         limit: Optional[int] = None,
     ):
-        openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = self.get_openai_client(
+        openai_client: Optional[
+            Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI]
+        ] = self.get_openai_client(
             api_key=api_key,
             api_base=api_base,
             timeout=timeout,
@@ -174,6 +203,7 @@ class OpenAIFineTuningAPI:
             organization=organization,
             client=client,
             _is_async=_is_async,
+            api_version=api_version,
         )
         if openai_client is None:
             raise ValueError(
@@ -181,7 +211,7 @@ class OpenAIFineTuningAPI:
             )
         if _is_async is True:
-            if not isinstance(openai_client, AsyncOpenAI):
+            if not isinstance(openai_client, (AsyncOpenAI, AsyncAzureOpenAI)):
                 raise ValueError(
                     "OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client."
                 )
@@ -193,4 +223,59 @@ class OpenAIFineTuningAPI:
         verbose_logger.debug("list fine tuning job, after= %s, limit= %s", after, limit)
         response = openai_client.fine_tuning.jobs.list(after=after, limit=limit)  # type: ignore
         return response
-        pass
+
+    async def aretrieve_fine_tuning_job(
+        self,
+        fine_tuning_job_id: str,
+        openai_client: Union[AsyncOpenAI, AsyncAzureOpenAI],
+    ) -> FineTuningJob:
+        response = await openai_client.fine_tuning.jobs.retrieve(
+            fine_tuning_job_id=fine_tuning_job_id
+        )
+        return response
+
+    def retrieve_fine_tuning_job(
+        self,
+        _is_async: bool,
+        fine_tuning_job_id: str,
+        api_key: Optional[str],
+        api_base: Optional[str],
+        api_version: Optional[str],
+        timeout: Union[float, httpx.Timeout],
+        max_retries: Optional[int],
+        organization: Optional[str],
+        client: Optional[
+            Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI]
+        ] = None,
+    ):
+        openai_client: Optional[
+            Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI]
+        ] = self.get_openai_client(
+            api_key=api_key,
+            api_base=api_base,
+            timeout=timeout,
+            max_retries=max_retries,
+            organization=organization,
+            client=client,
+            _is_async=_is_async,
+            api_version=api_version,
+        )
+        if openai_client is None:
+            raise ValueError(
+                "OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
+            )
+        if _is_async is True:
+            if not isinstance(openai_client, AsyncOpenAI):
+                raise ValueError(
+                    "OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client."
+                )
+            return self.aretrieve_fine_tuning_job(  # type: ignore
+                fine_tuning_job_id=fine_tuning_job_id,
+                openai_client=openai_client,
+            )
+        verbose_logger.debug("retrieving fine tuning job, id= %s", fine_tuning_job_id)
+        response = openai_client.fine_tuning.jobs.retrieve(
+            fine_tuning_job_id=fine_tuning_job_id
+        )
+        return response
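
Each handler method doubles as a sync call and an async dispatcher: when _is_async=True, the sync method returns a coroutine that the async wrapper awaits. A standalone sketch of that dispatch shape, with simplified illustrative names (litellm's real methods also build the client first):

import asyncio
from typing import Any, Coroutine, Union

async def _aretrieve(job_id: str) -> str:
    return f"retrieved {job_id}"

def retrieve(job_id: str, _is_async: bool = False) -> Union[str, Coroutine[Any, Any, str]]:
    if _is_async:
        # return the coroutine un-awaited; the async caller awaits it
        return _aretrieve(job_id)
    return f"retrieved {job_id} (sync)"

print(retrieve("ftjob-1"))                     # sync path returns a value
print(asyncio.run(retrieve("ftjob-1", True)))  # async path returns a coroutine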

Changed file 4 of 5 — proxy fine-tuning endpoints:

@@ -9,12 +9,13 @@ import asyncio
 import traceback
 from typing import Optional
 
-from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
+from fastapi import APIRouter, Depends, Request, Response
 
 import litellm
 from litellm._logging import verbose_proxy_logger
 from litellm.proxy._types import *
 from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
+from litellm.proxy.utils import handle_exception_on_proxy
 
 router = APIRouter()
 
@@ -171,21 +172,105 @@ async def create_fine_tuning_job(
             )
         )
         verbose_proxy_logger.debug(traceback.format_exc())
-        if isinstance(e, HTTPException):
-            raise ProxyException(
-                message=getattr(e, "message", str(e.detail)),
-                type=getattr(e, "type", "None"),
-                param=getattr(e, "param", "None"),
-                code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
-            )
-        else:
-            error_msg = f"{str(e)}"
-            raise ProxyException(
-                message=getattr(e, "message", error_msg),
-                type=getattr(e, "type", "None"),
-                param=getattr(e, "param", "None"),
-                code=getattr(e, "status_code", 500),
-            )
+        raise handle_exception_on_proxy(e)
+
+
+@router.get(
+    "/v1/fine_tuning/jobs/{fine_tuning_job_id:path}",
+    dependencies=[Depends(user_api_key_auth)],
+    tags=["fine-tuning"],
+    summary="✨ (Enterprise) Retrieve Fine-Tuning Job",
+)
+@router.get(
+    "/fine_tuning/jobs/{fine_tuning_job_id:path}",
+    dependencies=[Depends(user_api_key_auth)],
+    tags=["fine-tuning"],
+    summary="✨ (Enterprise) Retrieve Fine-Tuning Job",
+)
+async def retrieve_fine_tuning_job(
+    request: Request,
+    fastapi_response: Response,
+    fine_tuning_job_id: str,
+    custom_llm_provider: Literal["openai", "azure"],
+    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
+):
+    """
+    Retrieves a fine-tuning job.
+    This is the equivalent of GET https://api.openai.com/v1/fine_tuning/jobs/{fine_tuning_job_id}
+
+    Supported Query Params:
+    - `custom_llm_provider`: Name of the LiteLLM provider
+    - `fine_tuning_job_id`: The ID of the fine-tuning job to retrieve.
+    """
+    from litellm.proxy.proxy_server import (
+        add_litellm_data_to_request,
+        general_settings,
+        get_custom_headers,
+        premium_user,
+        proxy_config,
+        proxy_logging_obj,
+        version,
+    )
+
+    data: dict = {}
+    try:
+        if premium_user is not True:
+            raise ValueError(
+                f"Only premium users can use this endpoint + {CommonProxyErrors.not_premium_user.value}"
+            )
+        # Include original request and headers in the data
+        data = await add_litellm_data_to_request(
+            data=data,
+            request=request,
+            general_settings=general_settings,
+            user_api_key_dict=user_api_key_dict,
+            version=version,
+            proxy_config=proxy_config,
+        )
+
+        # get configs for custom_llm_provider
+        llm_provider_config = get_fine_tuning_provider_config(
+            custom_llm_provider=custom_llm_provider
+        )
+
+        if llm_provider_config is not None:
+            data.update(llm_provider_config)
+
+        response = await litellm.aretrieve_fine_tuning_job(
+            **data,
+            fine_tuning_job_id=fine_tuning_job_id,
+        )
+
+        ### RESPONSE HEADERS ###
+        hidden_params = getattr(response, "_hidden_params", {}) or {}
+        model_id = hidden_params.get("model_id", None) or ""
+        cache_key = hidden_params.get("cache_key", None) or ""
+        api_base = hidden_params.get("api_base", None) or ""
+
+        fastapi_response.headers.update(
+            get_custom_headers(
+                user_api_key_dict=user_api_key_dict,
+                model_id=model_id,
+                cache_key=cache_key,
+                api_base=api_base,
+                version=version,
+                model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
+            )
+        )
+
+        return response
+    except Exception as e:
+        await proxy_logging_obj.post_call_failure_hook(
+            user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
+        )
+        verbose_proxy_logger.error(
+            "litellm.proxy.proxy_server.list_fine_tuning_jobs(): Exception occurred - {}".format(
+                str(e)
+            )
+        )
+        verbose_proxy_logger.debug(traceback.format_exc())
+        raise handle_exception_on_proxy(e)
 
 
 @router.get(
@@ -286,21 +371,7 @@ async def list_fine_tuning_jobs(
             )
         )
         verbose_proxy_logger.debug(traceback.format_exc())
-        if isinstance(e, HTTPException):
-            raise ProxyException(
-                message=getattr(e, "message", str(e.detail)),
-                type=getattr(e, "type", "None"),
-                param=getattr(e, "param", "None"),
-                code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
-            )
-        else:
-            error_msg = f"{str(e)}"
-            raise ProxyException(
-                message=getattr(e, "message", error_msg),
-                type=getattr(e, "type", "None"),
-                param=getattr(e, "param", "None"),
-                code=getattr(e, "status_code", 500),
-            )
+        raise handle_exception_on_proxy(e)
 
 
 @router.post(
@@ -315,7 +386,7 @@ async def list_fine_tuning_jobs(
     tags=["fine-tuning"],
     summary="✨ (Enterprise) Cancel Fine-Tuning Jobs",
 )
-async def retrieve_fine_tuning_job(
+async def cancel_fine_tuning_job(
     request: Request,
     fastapi_response: Response,
     fine_tuning_job_id: str,
@@ -402,18 +473,4 @@ async def retrieve_fine_tuning_job(
             )
         )
         verbose_proxy_logger.debug(traceback.format_exc())
-        if isinstance(e, HTTPException):
-            raise ProxyException(
-                message=getattr(e, "message", str(e.detail)),
-                type=getattr(e, "type", "None"),
-                param=getattr(e, "param", "None"),
-                code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
-            )
-        else:
-            error_msg = f"{str(e)}"
-            raise ProxyException(
-                message=getattr(e, "message", error_msg),
-                type=getattr(e, "type", "None"),
-                param=getattr(e, "param", "None"),
-                code=getattr(e, "status_code", 500),
-            )
+        raise handle_exception_on_proxy(e)
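
The repeated 16-line ProxyException blocks collapse into raise handle_exception_on_proxy(e), which only works if the helper returns the exception to raise rather than raising it itself. A hedged sketch of the shape such a helper plausibly has, reconstructed from the inline blocks it replaces (the real helper lives in litellm.proxy.utils and may differ):

# Hedged sketch of what handle_exception_on_proxy plausibly does; the real
# implementation in litellm.proxy.utils may differ. ProxyException here is a
# simplified stand-in for litellm's own exception type.
from fastapi import HTTPException, status

class ProxyException(Exception):
    def __init__(self, message: str, type: str, param: str, code: int):
        self.message, self.type, self.param, self.code = message, type, param, code

def handle_exception_on_proxy(e: Exception) -> ProxyException:
    # Map any exception to a ProxyException, mirroring the deleted inline blocks.
    if isinstance(e, HTTPException):
        return ProxyException(
            message=getattr(e, "message", str(e.detail)),
            type=getattr(e, "type", "None"),
            param=getattr(e, "param", "None"),
            code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
        )
    return ProxyException(
        message=getattr(e, "message", str(e)),
        type=getattr(e, "type", "None"),
        param=getattr(e, "param", "None"),
        code=getattr(e, "status_code", 500),
    )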

Changed file 5 of 5 — fine-tuning API tests:

@@ -148,6 +148,12 @@ async def test_create_fine_tune_jobs_async():
         print("response from litellm.list_fine_tuning_jobs=", ft_jobs)
         assert len(list(ft_jobs)) > 0
 
+        # retrieve fine tuning job
+        response = await litellm.aretrieve_fine_tuning_job(
+            fine_tuning_job_id=create_fine_tuning_response.id,
+        )
+        print("response from litellm.retrieve_fine_tuning_job=", response)
+
         # delete file
         await litellm.afile_delete(