forked from phoenix/litellm-mirror
Merge pull request #5028 from BerriAI/litellm_create_ft_job_gemini
[Feat] Add support for Vertex AI fine tuning endpoints
This commit is contained in:
commit
f840a5f6b4
6 changed files with 351 additions and 2 deletions
|
@ -25,6 +25,7 @@ from litellm.llms.fine_tuning_apis.openai import (
|
|||
FineTuningJobCreate,
|
||||
OpenAIFineTuningAPI,
|
||||
)
|
||||
from litellm.llms.fine_tuning_apis.vertex_ai import VertexFineTuningAPI
|
||||
from litellm.types.llms.openai import Hyperparameters
|
||||
from litellm.types.router import *
|
||||
from litellm.utils import supports_httpx_timeout
|
||||
|
@ -32,6 +33,7 @@ from litellm.utils import supports_httpx_timeout
|
|||
####### ENVIRONMENT VARIABLES ###################
|
||||
openai_fine_tuning_apis_instance = OpenAIFineTuningAPI()
|
||||
azure_fine_tuning_apis_instance = AzureOpenAIFineTuningAPI()
|
||||
vertex_fine_tuning_apis_instance = VertexFineTuningAPI()
|
||||
#################################################
|
||||
|
||||
|
||||
|
@ -43,7 +45,7 @@ async def acreate_fine_tuning_job(
|
|||
validation_file: Optional[str] = None,
|
||||
integrations: Optional[List[str]] = None,
|
||||
seed: Optional[int] = None,
|
||||
custom_llm_provider: Literal["openai", "azure"] = "openai",
|
||||
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
|
||||
extra_headers: Optional[Dict[str, str]] = None,
|
||||
extra_body: Optional[Dict[str, str]] = None,
|
||||
**kwargs,
|
||||
|
@ -96,7 +98,7 @@ def create_fine_tuning_job(
|
|||
validation_file: Optional[str] = None,
|
||||
integrations: Optional[List[str]] = None,
|
||||
seed: Optional[int] = None,
|
||||
custom_llm_provider: Literal["openai", "azure"] = "openai",
|
||||
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
|
||||
extra_headers: Optional[Dict[str, str]] = None,
|
||||
extra_body: Optional[Dict[str, str]] = None,
|
||||
**kwargs,
|
||||
|
@ -221,6 +223,39 @@ def create_fine_tuning_job(
|
|||
max_retries=optional_params.max_retries,
|
||||
_is_async=_is_async,
|
||||
)
|
||||
elif custom_llm_provider == "vertex_ai":
|
||||
api_base = optional_params.api_base or ""
|
||||
vertex_ai_project = (
|
||||
optional_params.vertex_project
|
||||
or litellm.vertex_project
|
||||
or get_secret("VERTEXAI_PROJECT")
|
||||
)
|
||||
vertex_ai_location = (
|
||||
optional_params.vertex_location
|
||||
or litellm.vertex_location
|
||||
or get_secret("VERTEXAI_LOCATION")
|
||||
)
|
||||
vertex_credentials = optional_params.vertex_credentials or get_secret(
|
||||
"VERTEXAI_CREDENTIALS"
|
||||
)
|
||||
create_fine_tuning_job_data = FineTuningJobCreate(
|
||||
model=model,
|
||||
training_file=training_file,
|
||||
hyperparameters=hyperparameters,
|
||||
suffix=suffix,
|
||||
validation_file=validation_file,
|
||||
integrations=integrations,
|
||||
seed=seed,
|
||||
)
|
||||
response = vertex_fine_tuning_apis_instance.create_fine_tuning_job(
|
||||
_is_async=_is_async,
|
||||
create_fine_tuning_job_data=create_fine_tuning_job_data,
|
||||
vertex_credentials=vertex_credentials,
|
||||
vertex_project=vertex_ai_project,
|
||||
vertex_location=vertex_ai_location,
|
||||
timeout=timeout,
|
||||
api_base=api_base,
|
||||
)
|
||||
else:
|
||||
raise litellm.exceptions.BadRequestError(
|
||||
message="LiteLLM doesn't support {} for 'create_batch'. Only 'openai' is supported.".format(
|
||||
|
@ -236,6 +271,7 @@ def create_fine_tuning_job(
|
|||
)
|
||||
return response
|
||||
except Exception as e:
|
||||
verbose_logger.error("got exception in create_fine_tuning_job=%s", str(e))
|
||||
raise e
|
||||
|
||||
|
||||
|
|
213
litellm/llms/fine_tuning_apis/vertex_ai.py
Normal file
213
litellm/llms/fine_tuning_apis/vertex_ai.py
Normal file
|
@ -0,0 +1,213 @@
|
|||
import traceback
|
||||
from datetime import datetime
|
||||
from typing import Any, Coroutine, Literal, Optional, Union
|
||||
|
||||
import httpx
|
||||
from openai.types.fine_tuning.fine_tuning_job import FineTuningJob, Hyperparameters
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.llms.base import BaseLLM
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
||||
from litellm.llms.vertex_httpx import VertexLLM
|
||||
from litellm.types.llms.openai import FineTuningJobCreate
|
||||
from litellm.types.llms.vertex_ai import (
|
||||
FineTuneJobCreate,
|
||||
FineTunesupervisedTuningSpec,
|
||||
ResponseTuningJob,
|
||||
)
|
||||
|
||||
|
||||
class VertexFineTuningAPI(VertexLLM):
|
||||
"""
|
||||
Vertex methods to support for batches
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.async_handler = AsyncHTTPHandler(
|
||||
timeout=httpx.Timeout(timeout=600.0, connect=5.0)
|
||||
)
|
||||
|
||||
def convert_response_created_at(self, response: ResponseTuningJob):
|
||||
try:
|
||||
|
||||
create_time_str = response.get("createTime", "") or ""
|
||||
create_time_datetime = datetime.fromisoformat(
|
||||
create_time_str.replace("Z", "+00:00")
|
||||
)
|
||||
# Convert to Unix timestamp (seconds since epoch)
|
||||
created_at = int(create_time_datetime.timestamp())
|
||||
|
||||
return created_at
|
||||
except Exception as e:
|
||||
return 0
|
||||
|
||||
def convert_vertex_response_to_open_ai_response(
|
||||
self, response: ResponseTuningJob
|
||||
) -> FineTuningJob:
|
||||
status: Literal[
|
||||
"validating_files", "queued", "running", "succeeded", "failed", "cancelled"
|
||||
] = "queued"
|
||||
if response["state"] == "JOB_STATE_PENDING":
|
||||
status = "queued"
|
||||
if response["state"] == "JOB_STATE_SUCCEEDED":
|
||||
status = "succeeded"
|
||||
if response["state"] == "JOB_STATE_FAILED":
|
||||
status = "failed"
|
||||
if response["state"] == "JOB_STATE_CANCELLED":
|
||||
status = "cancelled"
|
||||
if response["state"] == "JOB_STATE_RUNNING":
|
||||
status = "running"
|
||||
|
||||
created_at = self.convert_response_created_at(response)
|
||||
|
||||
training_uri = ""
|
||||
if "supervisedTuningSpec" in response and response["supervisedTuningSpec"]:
|
||||
training_uri = response["supervisedTuningSpec"]["trainingDatasetUri"] or ""
|
||||
|
||||
return FineTuningJob(
|
||||
id=response["name"] or "",
|
||||
created_at=created_at,
|
||||
fine_tuned_model=response["tunedModelDisplayName"],
|
||||
finished_at=None,
|
||||
hyperparameters=Hyperparameters(
|
||||
n_epochs=0,
|
||||
),
|
||||
model=response["baseModel"] or "",
|
||||
object="fine_tuning.job",
|
||||
organization_id="",
|
||||
result_files=[],
|
||||
seed=0,
|
||||
status=status,
|
||||
trained_tokens=None,
|
||||
training_file=training_uri,
|
||||
validation_file=None,
|
||||
estimated_finish=None,
|
||||
integrations=[],
|
||||
)
|
||||
|
||||
async def acreate_fine_tuning_job(
|
||||
self,
|
||||
fine_tuning_url: str,
|
||||
headers: dict,
|
||||
request_data: FineTuneJobCreate,
|
||||
):
|
||||
from litellm.fine_tuning.main import FineTuningJob
|
||||
|
||||
try:
|
||||
verbose_logger.debug(
|
||||
"about to create fine tuning job: %s, request_data: %s",
|
||||
fine_tuning_url,
|
||||
request_data,
|
||||
)
|
||||
if self.async_handler is None:
|
||||
raise ValueError(
|
||||
"VertexAI Fine Tuning - async_handler is not initialized"
|
||||
)
|
||||
response = await self.async_handler.post(
|
||||
headers=headers,
|
||||
url=fine_tuning_url,
|
||||
json=request_data, # type: ignore
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise Exception(
|
||||
f"Error creating fine tuning job. Status code: {response.status_code}. Response: {response.text}"
|
||||
)
|
||||
|
||||
verbose_logger.debug(
|
||||
"got response from creating fine tuning job: %s", response.json()
|
||||
)
|
||||
|
||||
vertex_response = ResponseTuningJob( # type: ignore
|
||||
**response.json(),
|
||||
)
|
||||
|
||||
verbose_logger.debug("vertex_response %s", vertex_response)
|
||||
open_ai_response = self.convert_vertex_response_to_open_ai_response(
|
||||
vertex_response
|
||||
)
|
||||
return open_ai_response
|
||||
|
||||
except Exception as e:
|
||||
verbose_logger.error("asyncerror creating fine tuning job %s", e)
|
||||
trace_back_str = traceback.format_exc()
|
||||
verbose_logger.error(trace_back_str)
|
||||
raise e
|
||||
|
||||
def create_fine_tuning_job(
|
||||
self,
|
||||
_is_async: bool,
|
||||
create_fine_tuning_job_data: FineTuningJobCreate,
|
||||
vertex_project: Optional[str],
|
||||
vertex_location: Optional[str],
|
||||
vertex_credentials: Optional[str],
|
||||
api_base: Optional[str],
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
):
|
||||
|
||||
verbose_logger.debug(
|
||||
"creating fine tuning job, args= %s", create_fine_tuning_job_data
|
||||
)
|
||||
|
||||
auth_header, _ = self._get_token_and_url(
|
||||
model="",
|
||||
gemini_api_key=None,
|
||||
vertex_credentials=vertex_credentials,
|
||||
vertex_project=vertex_project,
|
||||
vertex_location=vertex_location,
|
||||
stream=False,
|
||||
custom_llm_provider="vertex_ai_beta",
|
||||
api_base=api_base,
|
||||
)
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {auth_header}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
supervised_tuning_spec = FineTunesupervisedTuningSpec(
|
||||
training_dataset_uri=create_fine_tuning_job_data.training_file
|
||||
)
|
||||
fine_tune_job = FineTuneJobCreate(
|
||||
baseModel=create_fine_tuning_job_data.model,
|
||||
supervisedTuningSpec=supervised_tuning_spec,
|
||||
)
|
||||
|
||||
fine_tuning_url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/tuningJobs"
|
||||
if _is_async is True:
|
||||
return self.acreate_fine_tuning_job( # type: ignore
|
||||
fine_tuning_url=fine_tuning_url,
|
||||
headers=headers,
|
||||
request_data=fine_tune_job,
|
||||
)
|
||||
sync_handler = HTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
|
||||
|
||||
verbose_logger.debug(
|
||||
"about to create fine tuning job: %s, request_data: %s",
|
||||
fine_tuning_url,
|
||||
fine_tune_job,
|
||||
)
|
||||
response = sync_handler.post(
|
||||
headers=headers,
|
||||
url=fine_tuning_url,
|
||||
json=fine_tune_job, # type: ignore
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise Exception(
|
||||
f"Error creating fine tuning job. Status code: {response.status_code}. Response: {response.text}"
|
||||
)
|
||||
|
||||
verbose_logger.debug(
|
||||
"got response from creating fine tuning job: %s", response.json()
|
||||
)
|
||||
vertex_response = ResponseTuningJob( # type: ignore
|
||||
**response.json(),
|
||||
)
|
||||
|
||||
verbose_logger.debug("vertex_response %s", vertex_response)
|
||||
open_ai_response = self.convert_vertex_response_to_open_ai_response(
|
||||
vertex_response
|
||||
)
|
||||
return open_ai_response
|
|
@ -16,6 +16,7 @@ import asyncio
|
|||
import logging
|
||||
|
||||
import openai
|
||||
from test_gcs_bucket import load_vertex_ai_credentials
|
||||
|
||||
from litellm import create_fine_tuning_job
|
||||
from litellm._logging import verbose_logger
|
||||
|
@ -183,3 +184,29 @@ async def test_azure_create_fine_tune_jobs_async():
|
|||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio()
|
||||
@pytest.mark.skip(reason="skipping until we can cancel fine tuning jobs")
|
||||
async def test_create_vertex_fine_tune_jobs():
|
||||
try:
|
||||
verbose_logger.setLevel(logging.DEBUG)
|
||||
load_vertex_ai_credentials()
|
||||
|
||||
vertex_credentials = os.getenv("GCS_PATH_SERVICE_ACCOUNT")
|
||||
print("creating fine tuning job")
|
||||
create_fine_tuning_response = await litellm.acreate_fine_tuning_job(
|
||||
model="gemini-1.0-pro-002",
|
||||
custom_llm_provider="vertex_ai",
|
||||
training_file="gs://cloud-samples-data/ai-platform/generative_ai/sft_train_data.jsonl",
|
||||
vertex_project="adroit-crow-413218",
|
||||
vertex_location="us-central1",
|
||||
vertex_credentials=vertex_credentials,
|
||||
)
|
||||
print("vertex ai create fine tuning response=", create_fine_tuning_response)
|
||||
|
||||
assert create_fine_tuning_response.id is not None
|
||||
assert create_fine_tuning_response.model == "gemini-1.0-pro-002"
|
||||
assert create_fine_tuning_response.object == "fine_tuning.job"
|
||||
except:
|
||||
pass
|
||||
|
|
22
litellm/tests/vertex_ai.jsonl
Normal file
22
litellm/tests/vertex_ai.jsonl
Normal file
|
@ -0,0 +1,22 @@
|
|||
[
|
||||
{
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You should classify the text into one of the following classes:[business, entertainment]"
|
||||
},
|
||||
{ "role": "user", "content": "Diversify your investment portfolio" },
|
||||
{ "role": "model", "content": "business" }
|
||||
]
|
||||
},
|
||||
{
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You should classify the text into one of the following classes:[business, entertainment]"
|
||||
},
|
||||
{ "role": "user", "content": "Watch a live concert" },
|
||||
{ "role": "model", "content": "entertainment" }
|
||||
]
|
||||
}
|
||||
]
|
|
@ -260,3 +260,48 @@ class GenerateContentResponseBody(TypedDict, total=False):
|
|||
candidates: Required[List[Candidates]]
|
||||
promptFeedback: PromptFeedback
|
||||
usageMetadata: Required[UsageMetadata]
|
||||
|
||||
|
||||
class FineTunesupervisedTuningSpec(TypedDict, total=False):
|
||||
training_dataset_uri: str
|
||||
validation_dataset: Optional[str]
|
||||
epoch_count: Optional[int]
|
||||
learning_rate_multiplier: Optional[float]
|
||||
tuned_model_display_name: Optional[str]
|
||||
adapter_size: Optional[
|
||||
Literal[
|
||||
"ADAPTER_SIZE_UNSPECIFIED",
|
||||
"ADAPTER_SIZE_ONE",
|
||||
"ADAPTER_SIZE_FOUR",
|
||||
"ADAPTER_SIZE_EIGHT",
|
||||
"ADAPTER_SIZE_SIXTEEN",
|
||||
]
|
||||
]
|
||||
|
||||
|
||||
class FineTuneJobCreate(TypedDict, total=False):
|
||||
baseModel: str
|
||||
supervisedTuningSpec: FineTunesupervisedTuningSpec
|
||||
tunedModelDisplayName: Optional[str]
|
||||
|
||||
|
||||
class ResponseSupervisedTuningSpec(TypedDict):
|
||||
trainingDatasetUri: Optional[str]
|
||||
|
||||
|
||||
class ResponseTuningJob(TypedDict):
|
||||
name: Optional[str]
|
||||
tunedModelDisplayName: Optional[str]
|
||||
baseModel: Optional[str]
|
||||
supervisedTuningSpec: Optional[ResponseSupervisedTuningSpec]
|
||||
state: Optional[
|
||||
Literal[
|
||||
"JOB_STATE_PENDING",
|
||||
"JOB_STATE_RUNNING",
|
||||
"JOB_STATE_SUCCEEDED",
|
||||
"JOB_STATE_FAILED",
|
||||
"JOB_STATE_CANCELLED",
|
||||
]
|
||||
]
|
||||
createTime: Optional[str]
|
||||
updateTime: Optional[str]
|
||||
|
|
|
@ -143,6 +143,7 @@ class GenericLiteLLMParams(BaseModel):
|
|||
## VERTEX AI ##
|
||||
vertex_project: Optional[str] = None
|
||||
vertex_location: Optional[str] = None
|
||||
vertex_credentials: Optional[str] = None
|
||||
## AWS BEDROCK / SAGEMAKER ##
|
||||
aws_access_key_id: Optional[str] = None
|
||||
aws_secret_access_key: Optional[str] = None
|
||||
|
@ -486,6 +487,11 @@ class AssistantsTypedDict(TypedDict):
|
|||
litellm_params: LiteLLMParamsTypedDict
|
||||
|
||||
|
||||
class FineTuningConfig(BaseModel):
|
||||
|
||||
custom_llm_provider: Literal["azure", "openai"]
|
||||
|
||||
|
||||
class CustomRoutingStrategyBase:
|
||||
async def async_get_available_deployment(
|
||||
self,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue