validation for passing config file

Ishaan Jaff 2024-07-31 13:32:18 -07:00
parent bd7b485d09
commit e4c73036fc
2 changed files with 37 additions and 9 deletions

View file

@@ -33,11 +33,29 @@ from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
 
 router = APIRouter()
 
-from litellm.llms.fine_tuning_apis.openai import (
-    FineTuningJob,
-    FineTuningJobCreate,
-    OpenAIFineTuningAPI,
-)
+from litellm.types.llms.openai import LiteLLMFineTuningJobCreate
+
+fine_tuning_config = None
+
+
+def set_fine_tuning_config(config):
+    global fine_tuning_config
+    fine_tuning_config = config
+
+
+# Function to search for specific custom_llm_provider and return its configuration
+def get_provider_config(
+    custom_llm_provider: str,
+):
+    global fine_tuning_config
+    if fine_tuning_config is None:
+        raise ValueError(
+            "fine_tuning_config is not set, set it on your config.yaml file."
+        )
+    for setting in fine_tuning_config:
+        if setting.get("custom_llm_provider") == custom_llm_provider:
+            return setting
+    return None
 
 
 @router.post(
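These two helpers are the validation this commit adds: the proxy is expected to hand set_fine_tuning_config a list of per-provider settings parsed from config.yaml, and get_provider_config returns the matching entry, returns None when no entry matches, or raises if no fine-tuning config was loaded at all. A minimal sketch of that behavior, with illustrative config values (the exact config.yaml keys and the startup wiring are not part of this hunk):

set_fine_tuning_config(
    [
        # illustrative entries -- real values come from the proxy's config.yaml
        {"custom_llm_provider": "openai", "api_key": "os.environ/OPENAI_API_KEY"},
        {
            "custom_llm_provider": "azure",
            "api_base": "https://my-endpoint.openai.azure.com",
            "api_key": "os.environ/AZURE_API_KEY",
            "api_version": "2024-05-01-preview",
        },
    ]
)

get_provider_config(custom_llm_provider="azure")
# -> the azure dict above
get_provider_config(custom_llm_provider="vertex_ai")
# -> None (no matching entry)
# Calling get_provider_config before set_fine_tuning_config raises ValueError.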
@@ -53,7 +71,7 @@ from litellm.llms.fine_tuning_apis.openai import (
 async def create_fine_tuning_job(
     request: Request,
     fastapi_response: Response,
-    fine_tuning_request: FineTuningJobCreate,
+    fine_tuning_request: LiteLLMFineTuningJobCreate,
     user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
 ):
     """
@@ -103,11 +121,17 @@ async def create_fine_tuning_job(
         proxy_config=proxy_config,
     )
 
-    # For now, use custom_llm_provider=="openai" -> this will change as LiteLLM adds more providers for fine-tuning
-    response = await litellm.acreate_fine_tuning_job(
-        custom_llm_provider="openai", **data
-    )
+    # get configs for custom_llm_provider
+    llm_provider_config = get_provider_config(
+        custom_llm_provider=fine_tuning_request.custom_llm_provider,
+    )
+
+    # add llm_provider_config to data
+    data.update(llm_provider_config)
+
+    # For now, use custom_llm_provider=="openai" -> this will change as LiteLLM adds more providers for fine-tuning
+    response = await litellm.acreate_fine_tuning_job(**data)
 
     ### ALERTING ###
     asyncio.create_task(
         proxy_logging_obj.update_request_status(
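With the provider config resolved, the handler now merges those settings into the request payload before calling litellm.acreate_fine_tuning_job, instead of hard-coding custom_llm_provider="openai". A rough sketch of that flow, not the proxy's exact handler: the None guard below is extra, and model_dump plus the inherited field names assume a pydantic v2 model built on the OpenAI-style FineTuningJobCreate.

import litellm
from litellm.types.llms.openai import LiteLLMFineTuningJobCreate

# get_provider_config is the helper added earlier in this file


async def create_job_sketch(fine_tuning_request: LiteLLMFineTuningJobCreate):
    # request fields (model, training_file, hyperparameters, ...) become kwargs
    data = fine_tuning_request.model_dump(exclude_none=True)

    # look up credentials/endpoint for the requested provider from config.yaml
    llm_provider_config = get_provider_config(
        custom_llm_provider=fine_tuning_request.custom_llm_provider,
    )
    if llm_provider_config is None:  # the handler in this diff skips this guard
        raise ValueError(
            f"no fine-tuning config for {fine_tuning_request.custom_llm_provider}"
        )

    # provider settings (api_key, api_base, api_version, ...) join the payload
    data.update(llm_provider_config)

    return await litellm.acreate_fine_tuning_job(**data)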

View file

@@ -503,3 +503,7 @@ class FineTuningJobCreate(BaseModel):
         None  # "A list of integrations to enable for your fine-tuning job."
     )
     seed: Optional[int] = None  # "The seed controls the reproducibility of the job."
+
+
+class LiteLLMFineTuningJobCreate(FineTuningJobCreate):
+    custom_llm_provider: Literal["openai", "azure"]
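The new request model only adds a provider discriminator on top of the existing OpenAI-style FineTuningJobCreate fields, so an unsupported provider is rejected when FastAPI parses the request body. A hedged example of that validation; model and training_file are assumed to be inherited fields mirroring the OpenAI fine-tuning API, since they are not shown in this hunk.

from pydantic import ValidationError

from litellm.types.llms.openai import LiteLLMFineTuningJobCreate

# accepted: provider is in Literal["openai", "azure"]
job = LiteLLMFineTuningJobCreate(
    custom_llm_provider="azure",
    model="gpt-35-turbo",         # assumed inherited field
    training_file="file-abc123",  # assumed inherited field
)

# rejected: any other provider fails validation before the handler runs
try:
    LiteLLMFineTuningJobCreate(
        custom_llm_provider="vertex_ai",
        model="gemini-pro",
        training_file="file-abc123",
    )
except ValidationError as err:
    print(err)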