diff --git a/litellm/fine_tuning/main.py b/litellm/fine_tuning/main.py index 7672ad43a9..179e600202 100644 --- a/litellm/fine_tuning/main.py +++ b/litellm/fine_tuning/main.py @@ -171,6 +171,7 @@ def create_fine_tuning_job( response = openai_fine_tuning_apis_instance.create_fine_tuning_job( api_base=api_base, api_key=api_key, + api_version=optional_params.api_version, organization=organization, create_fine_tuning_job_data=create_fine_tuning_job_data_dict, timeout=timeout, @@ -223,6 +224,7 @@ def create_fine_tuning_job( timeout=timeout, max_retries=optional_params.max_retries, _is_async=_is_async, + organization=optional_params.organization, ) elif custom_llm_provider == "vertex_ai": api_base = optional_params.api_base or "" @@ -279,7 +281,7 @@ def create_fine_tuning_job( async def acancel_fine_tuning_job( fine_tuning_job_id: str, - custom_llm_provider: Literal["openai"] = "openai", + custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai", extra_headers: Optional[Dict[str, str]] = None, extra_body: Optional[Dict[str, str]] = None, **kwargs, @@ -374,6 +376,7 @@ def cancel_fine_tuning_job( response = openai_fine_tuning_apis_instance.cancel_fine_tuning_job( api_base=api_base, api_key=api_key, + api_version=optional_params.api_version, organization=organization, fine_tuning_job_id=fine_tuning_job_id, timeout=timeout, @@ -412,6 +415,7 @@ def cancel_fine_tuning_job( timeout=timeout, max_retries=optional_params.max_retries, _is_async=_is_async, + organization=optional_params.organization, ) else: raise litellm.exceptions.BadRequestError( @@ -434,7 +438,7 @@ def cancel_fine_tuning_job( async def alist_fine_tuning_jobs( after: Optional[str] = None, limit: Optional[int] = None, - custom_llm_provider: Literal["openai"] = "openai", + custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai", extra_headers: Optional[Dict[str, str]] = None, extra_body: Optional[Dict[str, str]] = None, **kwargs, @@ -533,6 +537,7 @@ def list_fine_tuning_jobs( response 
= openai_fine_tuning_apis_instance.list_fine_tuning_jobs( api_base=api_base, api_key=api_key, + api_version=optional_params.api_version, organization=organization, after=after, limit=limit, @@ -573,6 +578,7 @@ def list_fine_tuning_jobs( timeout=timeout, max_retries=optional_params.max_retries, _is_async=_is_async, + organization=optional_params.organization, ) else: raise litellm.exceptions.BadRequestError( @@ -590,3 +596,153 @@ def list_fine_tuning_jobs( return response except Exception as e: raise e + + +async def aretrieve_fine_tuning_job( + fine_tuning_job_id: str, + custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai", + extra_headers: Optional[Dict[str, str]] = None, + extra_body: Optional[Dict[str, str]] = None, + **kwargs, +) -> FineTuningJob: + """ + Async: Get info about a fine-tuning job. + """ + try: + loop = asyncio.get_event_loop() + kwargs["aretrieve_fine_tuning_job"] = True + + # Use a partial function to pass your keyword arguments + func = partial( + retrieve_fine_tuning_job, + fine_tuning_job_id, + custom_llm_provider, + extra_headers, + extra_body, + **kwargs, + ) + + # Add the context to the function + ctx = contextvars.copy_context() + func_with_context = partial(ctx.run, func) + init_response = await loop.run_in_executor(None, func_with_context) + if asyncio.iscoroutine(init_response): + response = await init_response + else: + response = init_response # type: ignore + return response + except Exception as e: + raise e + + +def retrieve_fine_tuning_job( + fine_tuning_job_id: str, + custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai", + extra_headers: Optional[Dict[str, str]] = None, + extra_body: Optional[Dict[str, str]] = None, + **kwargs, +) -> Union[FineTuningJob, Coroutine[Any, Any, FineTuningJob]]: + """ + Get info about a fine-tuning job. 
+ """ + try: + optional_params = GenericLiteLLMParams(**kwargs) + ### TIMEOUT LOGIC ### + timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600 + # set timeout for 10 minutes by default + + if ( + timeout is not None + and isinstance(timeout, httpx.Timeout) + and supports_httpx_timeout(custom_llm_provider) is False + ): + read_timeout = timeout.read or 600 + timeout = read_timeout # default 10 min timeout + elif timeout is not None and not isinstance(timeout, httpx.Timeout): + timeout = float(timeout) # type: ignore + elif timeout is None: + timeout = 600.0 + + _is_async = kwargs.pop("aretrieve_fine_tuning_job", False) is True + + # OpenAI + if custom_llm_provider == "openai": + api_base = ( + optional_params.api_base + or litellm.api_base + or os.getenv("OPENAI_API_BASE") + or "https://api.openai.com/v1" + ) + organization = ( + optional_params.organization + or litellm.organization + or os.getenv("OPENAI_ORGANIZATION", None) + or None + ) + api_key = ( + optional_params.api_key + or litellm.api_key + or litellm.openai_key + or os.getenv("OPENAI_API_KEY") + ) + + response = openai_fine_tuning_apis_instance.retrieve_fine_tuning_job( + api_base=api_base, + api_key=api_key, + api_version=optional_params.api_version, + organization=organization, + fine_tuning_job_id=fine_tuning_job_id, + timeout=timeout, + max_retries=optional_params.max_retries, + _is_async=_is_async, + ) + # Azure OpenAI + elif custom_llm_provider == "azure": + api_base = optional_params.api_base or litellm.api_base or get_secret_str("AZURE_API_BASE") # type: ignore + + api_version = ( + optional_params.api_version + or litellm.api_version + or get_secret_str("AZURE_API_VERSION") + ) # type: ignore + + api_key = ( + optional_params.api_key + or litellm.api_key + or litellm.azure_key + or get_secret_str("AZURE_OPENAI_API_KEY") + or get_secret_str("AZURE_API_KEY") + ) # type: ignore + + extra_body = optional_params.get("extra_body", {}) + if extra_body is not None: + 
extra_body.pop("azure_ad_token", None) + else: + get_secret_str("AZURE_AD_TOKEN") # type: ignore + + response = azure_fine_tuning_apis_instance.retrieve_fine_tuning_job( + api_base=api_base, + api_key=api_key, + api_version=api_version, + fine_tuning_job_id=fine_tuning_job_id, + timeout=timeout, + max_retries=optional_params.max_retries, + _is_async=_is_async, + organization=optional_params.organization, + ) + else: + raise litellm.exceptions.BadRequestError( + message="LiteLLM doesn't support {} for 'retrieve_fine_tuning_job'. Only 'openai' and 'azure' are supported.".format( + custom_llm_provider + ), + model="n/a", + llm_provider=custom_llm_provider, + response=httpx.Response( + status_code=400, + content="Unsupported provider", + request=httpx.Request(method="retrieve_fine_tuning_job", url="https://github.com/BerriAI/litellm"), # type: ignore + ), + ) + return response + except Exception as e: + raise e diff --git a/litellm/llms/azure/fine_tuning/handler.py b/litellm/llms/azure/fine_tuning/handler.py index c55c53f907..c34b181eff 100644 --- a/litellm/llms/azure/fine_tuning/handler.py +++ b/litellm/llms/azure/fine_tuning/handler.py @@ -1,179 +1,48 @@ -from typing import Any, Coroutine, Optional, Union +from typing import Optional, Union import httpx -from openai import AsyncAzureOpenAI, AzureOpenAI -from openai.types.fine_tuning import FineTuningJob +from openai import AsyncAzureOpenAI, AsyncOpenAI, AzureOpenAI, OpenAI -from litellm._logging import verbose_logger from litellm.llms.azure.files.handler import get_azure_openai_client -from litellm.llms.base import BaseLLM +from litellm.llms.openai.fine_tuning.handler import OpenAIFineTuningAPI -class AzureOpenAIFineTuningAPI(BaseLLM): +class AzureOpenAIFineTuningAPI(OpenAIFineTuningAPI): """ - AzureOpenAI methods to support for batches + AzureOpenAI methods to support fine tuning, inherits from OpenAIFineTuningAPI. 
""" - def __init__(self) -> None: - super().__init__() - - async def acreate_fine_tuning_job( + def get_openai_client( self, - create_fine_tuning_job_data: dict, - openai_client: AsyncAzureOpenAI, - ) -> FineTuningJob: - response = await openai_client.fine_tuning.jobs.create( - **create_fine_tuning_job_data # type: ignore - ) - return response - - def create_fine_tuning_job( - self, - _is_async: bool, - create_fine_tuning_job_data: dict, api_key: Optional[str], api_base: Optional[str], timeout: Union[float, httpx.Timeout], max_retries: Optional[int], - organization: Optional[str] = None, - client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None, + organization: Optional[str], + client: Optional[ + Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI] + ] = None, + _is_async: bool = False, api_version: Optional[str] = None, - ) -> Union[FineTuningJob, Coroutine[Any, Any, FineTuningJob]]: - openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = ( - get_azure_openai_client( - api_key=api_key, - api_base=api_base, - timeout=timeout, - max_retries=max_retries, - organization=organization, - api_version=api_version, - client=client, - _is_async=_is_async, - ) + ) -> Optional[ + Union[ + OpenAI, + AsyncOpenAI, + AzureOpenAI, + AsyncAzureOpenAI, + ] + ]: + # Override to use Azure-specific client initialization + if isinstance(client, OpenAI) or isinstance(client, AsyncOpenAI): + client = None + + return get_azure_openai_client( + api_key=api_key, + api_base=api_base, + timeout=timeout, + max_retries=max_retries, + organization=organization, + api_version=api_version, + client=client, + _is_async=_is_async, ) - if openai_client is None: - raise ValueError( - "AzureOpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment." - ) - - if _is_async is True: - if not isinstance(openai_client, AsyncAzureOpenAI): - raise ValueError( - "AzureOpenAI client is not an instance of AsyncAzureOpenAI. 
Make sure you passed an AsyncAzureOpenAI client." - ) - return self.acreate_fine_tuning_job( # type: ignore - create_fine_tuning_job_data=create_fine_tuning_job_data, - openai_client=openai_client, - ) - verbose_logger.debug( - "creating fine tuning job, args= %s", create_fine_tuning_job_data - ) - response = openai_client.fine_tuning.jobs.create(**create_fine_tuning_job_data) # type: ignore - return response - - async def acancel_fine_tuning_job( - self, - fine_tuning_job_id: str, - openai_client: AsyncAzureOpenAI, - ) -> FineTuningJob: - response = await openai_client.fine_tuning.jobs.cancel( - fine_tuning_job_id=fine_tuning_job_id - ) - return response - - def cancel_fine_tuning_job( - self, - _is_async: bool, - fine_tuning_job_id: str, - api_key: Optional[str], - api_base: Optional[str], - timeout: Union[float, httpx.Timeout], - max_retries: Optional[int], - organization: Optional[str] = None, - api_version: Optional[str] = None, - client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None, - ): - openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = ( - get_azure_openai_client( - api_key=api_key, - api_base=api_base, - api_version=api_version, - timeout=timeout, - max_retries=max_retries, - organization=organization, - client=client, - _is_async=_is_async, - ) - ) - if openai_client is None: - raise ValueError( - "AzureOpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment." - ) - - if _is_async is True: - if not isinstance(openai_client, AsyncAzureOpenAI): - raise ValueError( - "AzureOpenAI client is not an instance of AsyncAzureOpenAI. Make sure you passed an AsyncAzureOpenAI client." 
- ) - return self.acancel_fine_tuning_job( # type: ignore - fine_tuning_job_id=fine_tuning_job_id, - openai_client=openai_client, - ) - verbose_logger.debug("canceling fine tuning job, args= %s", fine_tuning_job_id) - response = openai_client.fine_tuning.jobs.cancel( - fine_tuning_job_id=fine_tuning_job_id - ) - return response - - async def alist_fine_tuning_jobs( - self, - openai_client: AsyncAzureOpenAI, - after: Optional[str] = None, - limit: Optional[int] = None, - ): - response = await openai_client.fine_tuning.jobs.list(after=after, limit=limit) # type: ignore - return response - - def list_fine_tuning_jobs( - self, - _is_async: bool, - api_key: Optional[str], - api_base: Optional[str], - timeout: Union[float, httpx.Timeout], - max_retries: Optional[int], - organization: Optional[str] = None, - client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None, - api_version: Optional[str] = None, - after: Optional[str] = None, - limit: Optional[int] = None, - ): - openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = ( - get_azure_openai_client( - api_key=api_key, - api_base=api_base, - api_version=api_version, - timeout=timeout, - max_retries=max_retries, - organization=organization, - client=client, - _is_async=_is_async, - ) - ) - if openai_client is None: - raise ValueError( - "AzureOpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment." - ) - - if _is_async is True: - if not isinstance(openai_client, AsyncAzureOpenAI): - raise ValueError( - "AzureOpenAI client is not an instance of AsyncAzureOpenAI. Make sure you passed an AsyncAzureOpenAI client." 
- ) - return self.alist_fine_tuning_jobs( # type: ignore - after=after, - limit=limit, - openai_client=openai_client, - ) - verbose_logger.debug("list fine tuning job, after= %s, limit= %s", after, limit) - response = openai_client.fine_tuning.jobs.list(after=after, limit=limit) # type: ignore - return response diff --git a/litellm/llms/openai/fine_tuning/handler.py b/litellm/llms/openai/fine_tuning/handler.py index a3f088a861..b7eab8e5fd 100644 --- a/litellm/llms/openai/fine_tuning/handler.py +++ b/litellm/llms/openai/fine_tuning/handler.py @@ -1,7 +1,7 @@ from typing import Any, Coroutine, Optional, Union import httpx -from openai import AsyncOpenAI, OpenAI +from openai import AsyncAzureOpenAI, AsyncOpenAI, AzureOpenAI, OpenAI from openai.types.fine_tuning import FineTuningJob from litellm._logging import verbose_logger @@ -22,11 +22,23 @@ class OpenAIFineTuningAPI: timeout: Union[float, httpx.Timeout], max_retries: Optional[int], organization: Optional[str], - client: Optional[Union[OpenAI, AsyncOpenAI]] = None, + client: Optional[ + Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI] + ] = None, _is_async: bool = False, - ) -> Optional[Union[OpenAI, AsyncOpenAI]]: + api_version: Optional[str] = None, + ) -> Optional[ + Union[ + OpenAI, + AsyncOpenAI, + AzureOpenAI, + AsyncAzureOpenAI, + ] + ]: received_args = locals() - openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = None + openai_client: Optional[ + Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI] + ] = None if client is None: data = {} for k, v in received_args.items(): @@ -48,7 +60,7 @@ class OpenAIFineTuningAPI: async def acreate_fine_tuning_job( self, create_fine_tuning_job_data: dict, - openai_client: AsyncOpenAI, + openai_client: Union[AsyncOpenAI, AsyncAzureOpenAI], ) -> FineTuningJob: response = await openai_client.fine_tuning.jobs.create( **create_fine_tuning_job_data @@ -61,12 +73,17 @@ class OpenAIFineTuningAPI: create_fine_tuning_job_data: dict, api_key: Optional[str], 
api_base: Optional[str], + api_version: Optional[str], timeout: Union[float, httpx.Timeout], max_retries: Optional[int], organization: Optional[str], - client: Optional[Union[OpenAI, AsyncOpenAI]] = None, + client: Optional[ + Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI] + ] = None, ) -> Union[FineTuningJob, Coroutine[Any, Any, FineTuningJob]]: - openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = self.get_openai_client( + openai_client: Optional[ + Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI] + ] = self.get_openai_client( api_key=api_key, api_base=api_base, timeout=timeout, @@ -74,6 +91,7 @@ class OpenAIFineTuningAPI: organization=organization, client=client, _is_async=_is_async, + api_version=api_version, ) if openai_client is None: raise ValueError( @@ -81,7 +99,7 @@ class OpenAIFineTuningAPI: ) if _is_async is True: - if not isinstance(openai_client, AsyncOpenAI): + if not isinstance(openai_client, (AsyncOpenAI, AsyncAzureOpenAI)): raise ValueError( "OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client." 
) @@ -98,7 +116,7 @@ class OpenAIFineTuningAPI: async def acancel_fine_tuning_job( self, fine_tuning_job_id: str, - openai_client: AsyncOpenAI, + openai_client: Union[AsyncOpenAI, AsyncAzureOpenAI], ) -> FineTuningJob: response = await openai_client.fine_tuning.jobs.cancel( fine_tuning_job_id=fine_tuning_job_id @@ -111,12 +129,17 @@ class OpenAIFineTuningAPI: fine_tuning_job_id: str, api_key: Optional[str], api_base: Optional[str], + api_version: Optional[str], timeout: Union[float, httpx.Timeout], max_retries: Optional[int], organization: Optional[str], - client: Optional[Union[OpenAI, AsyncOpenAI]] = None, + client: Optional[ + Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI] + ] = None, ): - openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = self.get_openai_client( + openai_client: Optional[ + Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI] + ] = self.get_openai_client( api_key=api_key, api_base=api_base, timeout=timeout, @@ -124,6 +147,7 @@ class OpenAIFineTuningAPI: organization=organization, client=client, _is_async=_is_async, + api_version=api_version, ) if openai_client is None: raise ValueError( @@ -131,7 +155,7 @@ class OpenAIFineTuningAPI: ) if _is_async is True: - if not isinstance(openai_client, AsyncOpenAI): + if not isinstance(openai_client, (AsyncOpenAI, AsyncAzureOpenAI)): raise ValueError( "OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client." 
) @@ -147,7 +171,7 @@ class OpenAIFineTuningAPI: async def alist_fine_tuning_jobs( self, - openai_client: AsyncOpenAI, + openai_client: Union[AsyncOpenAI, AsyncAzureOpenAI], after: Optional[str] = None, limit: Optional[int] = None, ): @@ -159,14 +183,19 @@ class OpenAIFineTuningAPI: _is_async: bool, api_key: Optional[str], api_base: Optional[str], + api_version: Optional[str], timeout: Union[float, httpx.Timeout], max_retries: Optional[int], organization: Optional[str], - client: Optional[Union[OpenAI, AsyncOpenAI]] = None, + client: Optional[ + Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI] + ] = None, after: Optional[str] = None, limit: Optional[int] = None, ): - openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = self.get_openai_client( + openai_client: Optional[ + Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI] + ] = self.get_openai_client( api_key=api_key, api_base=api_base, timeout=timeout, @@ -174,6 +203,7 @@ class OpenAIFineTuningAPI: organization=organization, client=client, _is_async=_is_async, + api_version=api_version, ) if openai_client is None: raise ValueError( @@ -181,7 +211,7 @@ class OpenAIFineTuningAPI: ) if _is_async is True: - if not isinstance(openai_client, AsyncOpenAI): + if not isinstance(openai_client, (AsyncOpenAI, AsyncAzureOpenAI)): raise ValueError( "OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client." 
) @@ -193,4 +223,59 @@ class OpenAIFineTuningAPI: verbose_logger.debug("list fine tuning job, after= %s, limit= %s", after, limit) response = openai_client.fine_tuning.jobs.list(after=after, limit=limit) # type: ignore return response - pass + + async def aretrieve_fine_tuning_job( + self, + fine_tuning_job_id: str, + openai_client: Union[AsyncOpenAI, AsyncAzureOpenAI], + ) -> FineTuningJob: + response = await openai_client.fine_tuning.jobs.retrieve( + fine_tuning_job_id=fine_tuning_job_id + ) + return response + + def retrieve_fine_tuning_job( + self, + _is_async: bool, + fine_tuning_job_id: str, + api_key: Optional[str], + api_base: Optional[str], + api_version: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + organization: Optional[str], + client: Optional[ + Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI] + ] = None, + ): + openai_client: Optional[ + Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI] + ] = self.get_openai_client( + api_key=api_key, + api_base=api_base, + timeout=timeout, + max_retries=max_retries, + organization=organization, + client=client, + _is_async=_is_async, + api_version=api_version, + ) + if openai_client is None: + raise ValueError( + "OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment." + ) + + if _is_async is True: + if not isinstance(openai_client, (AsyncOpenAI, AsyncAzureOpenAI)): + raise ValueError( + "OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client."
+ ) + return self.aretrieve_fine_tuning_job( # type: ignore + fine_tuning_job_id=fine_tuning_job_id, + openai_client=openai_client, + ) + verbose_logger.debug("retrieving fine tuning job, id= %s", fine_tuning_job_id) + response = openai_client.fine_tuning.jobs.retrieve( + fine_tuning_job_id=fine_tuning_job_id + ) + return response diff --git a/litellm/proxy/fine_tuning_endpoints/endpoints.py b/litellm/proxy/fine_tuning_endpoints/endpoints.py index b7b31c8408..63b0546bfa 100644 --- a/litellm/proxy/fine_tuning_endpoints/endpoints.py +++ b/litellm/proxy/fine_tuning_endpoints/endpoints.py @@ -9,12 +9,13 @@ import asyncio import traceback from typing import Optional -from fastapi import APIRouter, Depends, HTTPException, Request, Response, status +from fastapi import APIRouter, Depends, Request, Response import litellm from litellm._logging import verbose_proxy_logger from litellm.proxy._types import * from litellm.proxy.auth.user_api_key_auth import user_api_key_auth +from litellm.proxy.utils import handle_exception_on_proxy router = APIRouter() @@ -171,21 +172,105 @@ async def create_fine_tuning_job( ) ) verbose_proxy_logger.debug(traceback.format_exc()) - if isinstance(e, HTTPException): - raise ProxyException( - message=getattr(e, "message", str(e.detail)), - type=getattr(e, "type", "None"), - param=getattr(e, "param", "None"), - code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST), + raise handle_exception_on_proxy(e) + + +@router.get( + "/v1/fine_tuning/jobs/{fine_tuning_job_id:path}", + dependencies=[Depends(user_api_key_auth)], + tags=["fine-tuning"], + summary="✨ (Enterprise) Retrieve Fine-Tuning Job", +) +@router.get( + "/fine_tuning/jobs/{fine_tuning_job_id:path}", + dependencies=[Depends(user_api_key_auth)], + tags=["fine-tuning"], + summary="✨ (Enterprise) Retrieve Fine-Tuning Job", +) +async def retrieve_fine_tuning_job( + request: Request, + fastapi_response: Response, + fine_tuning_job_id: str, + custom_llm_provider: Literal["openai", "azure"], + 
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + """ + Retrieves a fine-tuning job. + This is the equivalent of GET https://api.openai.com/v1/fine_tuning/jobs/{fine_tuning_job_id} + + Supported Query Params: + - `custom_llm_provider`: Name of the LiteLLM provider + - `fine_tuning_job_id`: The ID of the fine-tuning job to retrieve. + """ + from litellm.proxy.proxy_server import ( + add_litellm_data_to_request, + general_settings, + get_custom_headers, + premium_user, + proxy_config, + proxy_logging_obj, + version, + ) + + data: dict = {} + try: + if premium_user is not True: + raise ValueError( + f"Only premium users can use this endpoint + {CommonProxyErrors.not_premium_user.value}" ) - else: - error_msg = f"{str(e)}" - raise ProxyException( - message=getattr(e, "message", error_msg), - type=getattr(e, "type", "None"), - param=getattr(e, "param", "None"), - code=getattr(e, "status_code", 500), + # Include original request and headers in the data + data = await add_litellm_data_to_request( + data=data, + request=request, + general_settings=general_settings, + user_api_key_dict=user_api_key_dict, + version=version, + proxy_config=proxy_config, + ) + + # get configs for custom_llm_provider + llm_provider_config = get_fine_tuning_provider_config( + custom_llm_provider=custom_llm_provider + ) + + if llm_provider_config is not None: + data.update(llm_provider_config) + + response = await litellm.aretrieve_fine_tuning_job( + **data, + fine_tuning_job_id=fine_tuning_job_id, + ) + + ### RESPONSE HEADERS ### + hidden_params = getattr(response, "_hidden_params", {}) or {} + model_id = hidden_params.get("model_id", None) or "" + cache_key = hidden_params.get("cache_key", None) or "" + api_base = hidden_params.get("api_base", None) or "" + + fastapi_response.headers.update( + get_custom_headers( + user_api_key_dict=user_api_key_dict, + model_id=model_id, + cache_key=cache_key, + api_base=api_base, + version=version, + model_region=getattr(user_api_key_dict, 
"allowed_model_region", ""), ) + ) + + return response + + except Exception as e: + await proxy_logging_obj.post_call_failure_hook( + user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data + ) + verbose_proxy_logger.error( + "litellm.proxy.proxy_server.list_fine_tuning_jobs(): Exception occurred - {}".format( + str(e) + ) + ) + verbose_proxy_logger.debug(traceback.format_exc()) + raise handle_exception_on_proxy(e) @router.get( @@ -286,21 +371,7 @@ async def list_fine_tuning_jobs( ) ) verbose_proxy_logger.debug(traceback.format_exc()) - if isinstance(e, HTTPException): - raise ProxyException( - message=getattr(e, "message", str(e.detail)), - type=getattr(e, "type", "None"), - param=getattr(e, "param", "None"), - code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST), - ) - else: - error_msg = f"{str(e)}" - raise ProxyException( - message=getattr(e, "message", error_msg), - type=getattr(e, "type", "None"), - param=getattr(e, "param", "None"), - code=getattr(e, "status_code", 500), - ) + raise handle_exception_on_proxy(e) @router.post( @@ -315,7 +386,7 @@ async def list_fine_tuning_jobs( tags=["fine-tuning"], summary="✨ (Enterprise) Cancel Fine-Tuning Jobs", ) -async def retrieve_fine_tuning_job( +async def cancel_fine_tuning_job( request: Request, fastapi_response: Response, fine_tuning_job_id: str, @@ -402,18 +473,4 @@ async def retrieve_fine_tuning_job( ) ) verbose_proxy_logger.debug(traceback.format_exc()) - if isinstance(e, HTTPException): - raise ProxyException( - message=getattr(e, "message", str(e.detail)), - type=getattr(e, "type", "None"), - param=getattr(e, "param", "None"), - code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST), - ) - else: - error_msg = f"{str(e)}" - raise ProxyException( - message=getattr(e, "message", error_msg), - type=getattr(e, "type", "None"), - param=getattr(e, "param", "None"), - code=getattr(e, "status_code", 500), - ) + raise handle_exception_on_proxy(e) diff --git 
a/tests/batches_tests/test_fine_tuning_api.py b/tests/batches_tests/test_fine_tuning_api.py index cc53f599fa..1d11b580c0 100644 --- a/tests/batches_tests/test_fine_tuning_api.py +++ b/tests/batches_tests/test_fine_tuning_api.py @@ -148,6 +148,12 @@ async def test_create_fine_tune_jobs_async(): print("response from litellm.list_fine_tuning_jobs=", ft_jobs) assert len(list(ft_jobs)) > 0 + # retrieve fine tuning job + response = await litellm.aretrieve_fine_tuning_job( + fine_tuning_job_id=create_fine_tuning_response.id, + ) + print("response from litellm.retrieve_fine_tuning_job=", response) + # delete file await litellm.afile_delete(