mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 11:14:04 +00:00
test translating to vertex ai params
This commit is contained in:
parent
4917aaefab
commit
69e5a7cb68
3 changed files with 88 additions and 6 deletions
|
@ -255,6 +255,7 @@ def create_fine_tuning_job(
|
||||||
vertex_location=vertex_ai_location,
|
vertex_location=vertex_ai_location,
|
||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
api_base=api_base,
|
api_base=api_base,
|
||||||
|
kwargs=kwargs,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise litellm.exceptions.BadRequestError(
|
raise litellm.exceptions.BadRequestError(
|
||||||
|
|
|
@ -86,6 +86,29 @@ class VertexFineTuningAPI(VertexLLM):
|
||||||
integrations=[],
|
integrations=[],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def convert_openai_request_to_vertex(
|
||||||
|
self, create_fine_tuning_job_data: FineTuningJobCreate, **kwargs
|
||||||
|
) -> dict:
|
||||||
|
"""
|
||||||
|
convert request from OpenAI format to Vertex format
|
||||||
|
https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/tuning
|
||||||
|
supervised_tuning_spec = FineTunesupervisedTuningSpec(
|
||||||
|
"""
|
||||||
|
supervised_tuning_spec = FineTunesupervisedTuningSpec(
|
||||||
|
training_dataset_uri=create_fine_tuning_job_data.training_file,
|
||||||
|
validation_dataset=create_fine_tuning_job_data.validation_file,
|
||||||
|
epoch_count=create_fine_tuning_job_data.hyperparameters.n_epochs,
|
||||||
|
learning_rate_multiplier=create_fine_tuning_job_data.hyperparameters.learning_rate_multiplier,
|
||||||
|
adapter_size=kwargs.get("AdapterSize"),
|
||||||
|
)
|
||||||
|
fine_tune_job = FineTuneJobCreate(
|
||||||
|
baseModel=create_fine_tuning_job_data.model,
|
||||||
|
supervisedTuningSpec=supervised_tuning_spec,
|
||||||
|
tunedModelDisplayName=create_fine_tuning_job_data.suffix,
|
||||||
|
)
|
||||||
|
|
||||||
|
return fine_tune_job
|
||||||
|
|
||||||
async def acreate_fine_tuning_job(
|
async def acreate_fine_tuning_job(
|
||||||
self,
|
self,
|
||||||
fine_tuning_url: str,
|
fine_tuning_url: str,
|
||||||
|
@ -144,6 +167,7 @@ class VertexFineTuningAPI(VertexLLM):
|
||||||
vertex_credentials: Optional[str],
|
vertex_credentials: Optional[str],
|
||||||
api_base: Optional[str],
|
api_base: Optional[str],
|
||||||
timeout: Union[float, httpx.Timeout],
|
timeout: Union[float, httpx.Timeout],
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
|
|
||||||
verbose_logger.debug(
|
verbose_logger.debug(
|
||||||
|
@ -166,12 +190,8 @@ class VertexFineTuningAPI(VertexLLM):
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
}
|
}
|
||||||
|
|
||||||
supervised_tuning_spec = FineTunesupervisedTuningSpec(
|
fine_tune_job = self.convert_openai_request_to_vertex(
|
||||||
training_dataset_uri=create_fine_tuning_job_data.training_file
|
create_fine_tuning_job_data=create_fine_tuning_job_data, **kwargs
|
||||||
)
|
|
||||||
fine_tune_job = FineTuneJobCreate(
|
|
||||||
baseModel=create_fine_tuning_job_data.model,
|
|
||||||
supervisedTuningSpec=supervised_tuning_spec,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
fine_tuning_url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/tuningJobs"
|
fine_tuning_url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/tuningJobs"
|
||||||
|
|
|
@ -20,6 +20,12 @@ from test_gcs_bucket import load_vertex_ai_credentials
|
||||||
|
|
||||||
from litellm import create_fine_tuning_job
|
from litellm import create_fine_tuning_job
|
||||||
from litellm._logging import verbose_logger
|
from litellm._logging import verbose_logger
|
||||||
|
from litellm.llms.fine_tuning_apis.vertex_ai import (
|
||||||
|
FineTuningJobCreate,
|
||||||
|
VertexFineTuningAPI,
|
||||||
|
)
|
||||||
|
|
||||||
|
vertex_finetune_api = VertexFineTuningAPI()
|
||||||
|
|
||||||
|
|
||||||
def test_create_fine_tune_job():
|
def test_create_fine_tune_job():
|
||||||
|
@ -210,3 +216,58 @@ async def test_create_vertex_fine_tune_jobs():
|
||||||
assert create_fine_tuning_response.object == "fine_tuning.job"
|
assert create_fine_tuning_response.object == "fine_tuning.job"
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# Testing OpenAI -> Vertex AI param mapping
|
||||||
|
|
||||||
|
|
||||||
|
def test_convert_openai_request_to_vertex_basic():
|
||||||
|
openai_data = FineTuningJobCreate(
|
||||||
|
training_file="gs://bucket/train.jsonl",
|
||||||
|
validation_file="gs://bucket/val.jsonl",
|
||||||
|
model="text-davinci-002",
|
||||||
|
hyperparameters={"n_epochs": 3, "learning_rate_multiplier": 0.1},
|
||||||
|
suffix="my_fine_tuned_model",
|
||||||
|
)
|
||||||
|
|
||||||
|
result = vertex_finetune_api.convert_openai_request_to_vertex(openai_data)
|
||||||
|
|
||||||
|
print("converted vertex ai result=", result)
|
||||||
|
|
||||||
|
assert result["baseModel"] == "text-davinci-002"
|
||||||
|
assert result["tunedModelDisplayName"] == "my_fine_tuned_model"
|
||||||
|
assert (
|
||||||
|
result["supervisedTuningSpec"]["training_dataset_uri"]
|
||||||
|
== "gs://bucket/train.jsonl"
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
result["supervisedTuningSpec"]["validation_dataset"] == "gs://bucket/val.jsonl"
|
||||||
|
)
|
||||||
|
assert result["supervisedTuningSpec"]["epoch_count"] == 3
|
||||||
|
assert result["supervisedTuningSpec"]["learning_rate_multiplier"] == 0.1
|
||||||
|
|
||||||
|
|
||||||
|
def test_convert_openai_request_to_vertex_with_adapter_size():
|
||||||
|
openai_data = FineTuningJobCreate(
|
||||||
|
training_file="gs://bucket/train.jsonl",
|
||||||
|
model="text-davinci-002",
|
||||||
|
hyperparameters={"n_epochs": 5, "learning_rate_multiplier": 0.2},
|
||||||
|
suffix="custom_model",
|
||||||
|
)
|
||||||
|
|
||||||
|
result = vertex_finetune_api.convert_openai_request_to_vertex(
|
||||||
|
openai_data, AdapterSize="SMALL"
|
||||||
|
)
|
||||||
|
|
||||||
|
print("converted vertex ai result=", result)
|
||||||
|
|
||||||
|
assert result["baseModel"] == "text-davinci-002"
|
||||||
|
assert result["tunedModelDisplayName"] == "custom_model"
|
||||||
|
assert (
|
||||||
|
result["supervisedTuningSpec"]["training_dataset_uri"]
|
||||||
|
== "gs://bucket/train.jsonl"
|
||||||
|
)
|
||||||
|
assert result["supervisedTuningSpec"]["validation_dataset"] is None
|
||||||
|
assert result["supervisedTuningSpec"]["epoch_count"] == 5
|
||||||
|
assert result["supervisedTuningSpec"]["learning_rate_multiplier"] == 0.2
|
||||||
|
assert result["supervisedTuningSpec"]["adapter_size"] == "SMALL"
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue