fix type errors

This commit is contained in:
Ishaan Jaff 2024-07-29 20:10:03 -07:00
parent 6abc49c611
commit f18827cbc0
3 changed files with 11 additions and 31 deletions

View file

@ -34,7 +34,7 @@ openai_fine_tuning_instance = OpenAIFineTuningAPI()
async def acreate_fine_tuning_job( async def acreate_fine_tuning_job(
model: str, model: str,
training_file: str, training_file: str,
hyperparameters: Optional[Hyperparameters] = {}, hyperparameters: Optional[Hyperparameters] = {}, # type: ignore
suffix: Optional[str] = None, suffix: Optional[str] = None,
validation_file: Optional[str] = None, validation_file: Optional[str] = None,
integrations: Optional[List[str]] = None, integrations: Optional[List[str]] = None,
@ -84,7 +84,7 @@ async def acreate_fine_tuning_job(
def create_fine_tuning_job( def create_fine_tuning_job(
model: str, model: str,
training_file: str, training_file: str,
hyperparameters: Optional[Hyperparameters] = {}, hyperparameters: Optional[Hyperparameters] = {}, # type: ignore
suffix: Optional[str] = None, suffix: Optional[str] = None,
validation_file: Optional[str] = None, validation_file: Optional[str] = None,
integrations: Optional[List[str]] = None, integrations: Optional[List[str]] = None,

View file

@ -54,7 +54,7 @@ class OpenAIFineTuningAPI(BaseLLM):
openai_client: AsyncOpenAI, openai_client: AsyncOpenAI,
) -> FineTuningJob: ) -> FineTuningJob:
response = await openai_client.fine_tuning.jobs.create( response = await openai_client.fine_tuning.jobs.create(
**create_fine_tuning_job_data **create_fine_tuning_job_data # type: ignore
) )
return response return response
@ -68,7 +68,7 @@ class OpenAIFineTuningAPI(BaseLLM):
max_retries: Optional[int], max_retries: Optional[int],
organization: Optional[str], organization: Optional[str],
client: Optional[Union[OpenAI, AsyncOpenAI]] = None, client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
) -> Union[Coroutine[Any, Any, FineTuningJob]]: ) -> Union[FineTuningJob, Union[Coroutine[Any, Any, FineTuningJob]]]:
openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = self.get_openai_client( openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = self.get_openai_client(
api_key=api_key, api_key=api_key,
api_base=api_base, api_base=api_base,
@ -154,7 +154,7 @@ class OpenAIFineTuningAPI(BaseLLM):
after: Optional[str] = None, after: Optional[str] = None,
limit: Optional[int] = None, limit: Optional[int] = None,
): ):
response = await openai_client.fine_tuning.jobs.list(after=after, limit=limit) response = await openai_client.fine_tuning.jobs.list(after=after, limit=limit) # type: ignore
return response return response
def list_fine_tuning_jobs( def list_fine_tuning_jobs(
@ -194,6 +194,6 @@ class OpenAIFineTuningAPI(BaseLLM):
openai_client=openai_client, openai_client=openai_client,
) )
verbose_logger.debug("list fine tuning job, after= %s, limit= %s", after, limit) verbose_logger.debug("list fine tuning job, after= %s, limit= %s", after, limit)
response = openai_client.fine_tuning.jobs.list(after=after, limit=limit) response = openai_client.fine_tuning.jobs.list(after=after, limit=limit) # type: ignore
return response return response
pass pass

View file

@ -487,13 +487,11 @@ class FineTuningJobCreate(TypedDict):
``` ```
""" """
model: str = Field(..., description="The name of the model to fine-tune.") model: str # "The name of the model to fine-tune."
training_file: str = Field( training_file: str # "The ID of an uploaded file that contains training data."
..., description="The ID of an uploaded file that contains training data." hyperparameters: Optional[
) Hyperparameters
hyperparameters: Optional[Hyperparameters] = Field( ] # "The hyperparameters used for the fine-tuning job."
default={}, description="The hyperparameters used for the fine-tuning job."
)
suffix: Optional[ suffix: Optional[
str str
] # "A string of up to 18 characters that will be added to your fine-tuned model name." ] # "A string of up to 18 characters that will be added to your fine-tuned model name."
@ -504,21 +502,3 @@ class FineTuningJobCreate(TypedDict):
List[str] List[str]
] # "A list of integrations to enable for your fine-tuning job." ] # "A list of integrations to enable for your fine-tuning job."
seed: Optional[int] # "The seed controls the reproducibility of the job." seed: Optional[int] # "The seed controls the reproducibility of the job."
class Config:
allow_population_by_field_name = True
schema_extra = {
"example": {
"model": "gpt-3.5-turbo",
"training_file": "file-abc123",
"hyperparameters": {
"batch_size": "auto",
"learning_rate_multiplier": 0.1,
"n_epochs": 3,
},
"suffix": "custom-model-name",
"validation_file": "file-xyz789",
"integrations": ["slack"],
"seed": 42,
}
}