From e4c73036fc1e785b384fdeb9439afdc4d063e9fe Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Wed, 31 Jul 2024 13:32:18 -0700 Subject: [PATCH] validation for passing config file --- .../proxy/fine_tuning_endpoints/endpoints.py | 42 +++++++++++++++---- litellm/types/llms/openai.py | 4 ++ 2 files changed, 37 insertions(+), 9 deletions(-) diff --git a/litellm/proxy/fine_tuning_endpoints/endpoints.py b/litellm/proxy/fine_tuning_endpoints/endpoints.py index b15de075f..9c58337d1 100644 --- a/litellm/proxy/fine_tuning_endpoints/endpoints.py +++ b/litellm/proxy/fine_tuning_endpoints/endpoints.py @@ -33,11 +33,29 @@ from litellm.proxy.auth.user_api_key_auth import user_api_key_auth router = APIRouter() -from litellm.llms.fine_tuning_apis.openai import ( - FineTuningJob, - FineTuningJobCreate, - OpenAIFineTuningAPI, -) +from litellm.types.llms.openai import LiteLLMFineTuningJobCreate + +fine_tuning_config = None + + +def set_fine_tuning_config(config): + global fine_tuning_config + fine_tuning_config = config + + +# Function to search for specific custom_llm_provider and return its configuration +def get_provider_config( + custom_llm_provider: str, +): + global fine_tuning_config + if fine_tuning_config is None: + raise ValueError( + "fine_tuning_config is not set, set it on your config.yaml file." + ) + for setting in fine_tuning_config: + if setting.get("custom_llm_provider") == custom_llm_provider: + return setting + return None @router.post( @@ -53,7 +71,7 @@ from litellm.llms.fine_tuning_apis.openai import ( async def create_fine_tuning_job( request: Request, fastapi_response: Response, - fine_tuning_request: FineTuningJobCreate, + fine_tuning_request: LiteLLMFineTuningJobCreate, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), ): """ @@ -103,11 +121,17 @@ async def create_fine_tuning_job( proxy_config=proxy_config, ) - # For now, use custom_llm_provider=="openai" -> this will change as LiteLLM adds more providers for fine-tuning - response = await litellm.acreate_fine_tuning_job( - custom_llm_provider="openai", **data + # get configs for custom_llm_provider + llm_provider_config = get_provider_config( + custom_llm_provider=fine_tuning_request.custom_llm_provider, ) + # add llm_provider_config to data + data.update(llm_provider_config) + + # For now, use custom_llm_provider=="openai" -> this will change as LiteLLM adds more providers for fine-tuning + response = await litellm.acreate_fine_tuning_job(**data) + ### ALERTING ### asyncio.create_task( proxy_logging_obj.update_request_status( diff --git a/litellm/types/llms/openai.py b/litellm/types/llms/openai.py index 3bb59f005..875ccadf1 100644 --- a/litellm/types/llms/openai.py +++ b/litellm/types/llms/openai.py @@ -503,3 +503,7 @@ class FineTuningJobCreate(BaseModel): None # "A list of integrations to enable for your fine-tuning job." ) seed: Optional[int] = None # "The seed controls the reproducibility of the job." + + +class LiteLLMFineTuningJobCreate(FineTuningJobCreate): + custom_llm_provider: Literal["openai", "azure"]