fix linting errors

This commit is contained in:
Ishaan Jaff 2024-08-03 08:55:36 -07:00
parent 69e5a7cb68
commit 4ca0464395
2 changed files with 15 additions and 6 deletions

View file

@ -88,19 +88,28 @@ class VertexFineTuningAPI(VertexLLM):
def convert_openai_request_to_vertex( def convert_openai_request_to_vertex(
self, create_fine_tuning_job_data: FineTuningJobCreate, **kwargs self, create_fine_tuning_job_data: FineTuningJobCreate, **kwargs
) -> dict: ) -> FineTuningJobCreate:
""" """
convert request from OpenAI format to Vertex format convert request from OpenAI format to Vertex format
https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/tuning https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/tuning
supervised_tuning_spec = FineTunesupervisedTuningSpec( supervised_tuning_spec = FineTunesupervisedTuningSpec(
""" """
hyperparameters = create_fine_tuning_job_data.hyperparameters
supervised_tuning_spec = FineTunesupervisedTuningSpec( supervised_tuning_spec = FineTunesupervisedTuningSpec(
training_dataset_uri=create_fine_tuning_job_data.training_file, training_dataset_uri=create_fine_tuning_job_data.training_file,
validation_dataset=create_fine_tuning_job_data.validation_file, validation_dataset=create_fine_tuning_job_data.validation_file,
epoch_count=create_fine_tuning_job_data.hyperparameters.n_epochs,
learning_rate_multiplier=create_fine_tuning_job_data.hyperparameters.learning_rate_multiplier,
adapter_size=kwargs.get("AdapterSize"),
) )
if hyperparameters:
if hyperparameters.n_epochs:
supervised_tuning_spec["epoch_count"] = int(hyperparameters.n_epochs)
if hyperparameters.learning_rate_multiplier:
supervised_tuning_spec["learning_rate_multiplier"] = float(
hyperparameters.learning_rate_multiplier
)
supervised_tuning_spec["adapter_size"] = kwargs.get("adapter_size")
fine_tune_job = FineTuneJobCreate( fine_tune_job = FineTuneJobCreate(
baseModel=create_fine_tuning_job_data.model, baseModel=create_fine_tuning_job_data.model,
supervisedTuningSpec=supervised_tuning_spec, supervisedTuningSpec=supervised_tuning_spec,
@ -113,7 +122,7 @@ class VertexFineTuningAPI(VertexLLM):
self, self,
fine_tuning_url: str, fine_tuning_url: str,
headers: dict, headers: dict,
request_data: FineTuneJobCreate, request_data: FineTuningJobCreate,
): ):
from litellm.fine_tuning.main import FineTuningJob from litellm.fine_tuning.main import FineTuningJob

View file

@ -256,7 +256,7 @@ def test_convert_openai_request_to_vertex_with_adapter_size():
) )
result = vertex_finetune_api.convert_openai_request_to_vertex( result = vertex_finetune_api.convert_openai_request_to_vertex(
openai_data, AdapterSize="SMALL" openai_data, adapter_size="SMALL"
) )
print("converted vertex ai result=", result) print("converted vertex ai result=", result)