From a6e62da9fb401118ae443a75d43754a289262b7b Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Wed, 31 Jul 2024 16:19:15 -0700 Subject: [PATCH] fix cancel ft job route --- litellm/proxy/fine_tuning_endpoints/endpoints.py | 15 ++++++++++++++- tests/test_openai_fine_tuning.py | 6 +++++- 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/litellm/proxy/fine_tuning_endpoints/endpoints.py b/litellm/proxy/fine_tuning_endpoints/endpoints.py index eadb69f7a..e95c2bae3 100644 --- a/litellm/proxy/fine_tuning_endpoints/endpoints.py +++ b/litellm/proxy/fine_tuning_endpoints/endpoints.py @@ -197,6 +197,11 @@ async def create_fine_tuning_job( dependencies=[Depends(user_api_key_auth)], tags=["fine-tuning"], ) +@router.get( + "/fine_tuning/jobs", + dependencies=[Depends(user_api_key_auth)], + tags=["fine-tuning"], +) async def list_fine_tuning_jobs( request: Request, fastapi_response: Response, @@ -299,10 +304,14 @@ async def list_fine_tuning_jobs( dependencies=[Depends(user_api_key_auth)], tags=["fine-tuning"], ) +@router.post( + "/fine_tuning/jobs/{fine_tuning_job_id:path}/cancel", + dependencies=[Depends(user_api_key_auth)], + tags=["fine-tuning"], +) async def retrieve_fine_tuning_job( request: Request, fastapi_response: Response, - custom_llm_provider: Literal["openai", "azure"], fine_tuning_job_id: str, user_api_key_dict: dict = Depends(user_api_key_auth), ): @@ -336,6 +345,10 @@ async def retrieve_fine_tuning_job( proxy_config=proxy_config, ) + request_body = await request.json() + + custom_llm_provider = request_body.get("custom_llm_provider", None) + # get configs for custom_llm_provider llm_provider_config = get_fine_tuning_provider_config( custom_llm_provider=custom_llm_provider diff --git a/tests/test_openai_fine_tuning.py b/tests/test_openai_fine_tuning.py index db41e4543..6d67d4144 100644 --- a/tests/test_openai_fine_tuning.py +++ b/tests/test_openai_fine_tuning.py @@ -33,17 +33,21 @@ async def test_openai_fine_tuning(): print("response from ft job={}".format(ft_job)) # response from example endpoint - assert ft_job.id == "file-abc123" + assert ft_job.id == "ftjob-abc123" # list all fine tuning jobs list_ft_jobs = await client.fine_tuning.jobs.list( extra_query={"custom_llm_provider": "azure"} ) + print("list of ft jobs={}".format(list_ft_jobs)) + # cancel specific fine tuning job cancel_ft_job = await client.fine_tuning.jobs.cancel( fine_tuning_job_id="123", extra_body={"custom_llm_provider": "azure"}, ) + print("response from cancel ft job={}".format(cancel_ft_job)) + assert cancel_ft_job.id is not None