Merge pull request #5028 from BerriAI/litellm_create_ft_job_gemini

[Feat] Add support for Vertex AI fine tuning endpoints
Ishaan Jaff 2024-08-03 08:22:55 -07:00 committed by GitHub
commit f840a5f6b4
6 changed files with 351 additions and 2 deletions
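For reference, a minimal usage sketch of the new provider, mirroring the test added in this PR; the bucket, project, and credentials values below are placeholders, not real settings:

import asyncio
import litellm

async def main():
    # Placeholder GCP values; swap in your own project settings
    ft_job = await litellm.acreate_fine_tuning_job(
        model="gemini-1.0-pro-002",
        custom_llm_provider="vertex_ai",
        training_file="gs://my-bucket/sft_train_data.jsonl",
        vertex_project="my-gcp-project",
        vertex_location="us-central1",
        vertex_credentials="/path/to/service_account.json",
    )
    # Response is normalized to the OpenAI FineTuningJob shape
    print(ft_job.id, ft_job.status)

asyncio.run(main())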

View file

@@ -25,6 +25,7 @@ from litellm.llms.fine_tuning_apis.openai import (
FineTuningJobCreate,
OpenAIFineTuningAPI,
)
from litellm.llms.fine_tuning_apis.vertex_ai import VertexFineTuningAPI
from litellm.types.llms.openai import Hyperparameters
from litellm.types.router import *
from litellm.utils import supports_httpx_timeout
@@ -32,6 +33,7 @@ from litellm.utils import supports_httpx_timeout
####### ENVIRONMENT VARIABLES ###################
openai_fine_tuning_apis_instance = OpenAIFineTuningAPI()
azure_fine_tuning_apis_instance = AzureOpenAIFineTuningAPI()
vertex_fine_tuning_apis_instance = VertexFineTuningAPI()
#################################################
@@ -43,7 +45,7 @@ async def acreate_fine_tuning_job(
validation_file: Optional[str] = None,
integrations: Optional[List[str]] = None,
seed: Optional[int] = None,
custom_llm_provider: Literal["openai", "azure"] = "openai",
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
extra_headers: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, str]] = None,
**kwargs,
@@ -96,7 +98,7 @@ def create_fine_tuning_job(
validation_file: Optional[str] = None,
integrations: Optional[List[str]] = None,
seed: Optional[int] = None,
custom_llm_provider: Literal["openai", "azure"] = "openai",
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
extra_headers: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, str]] = None,
**kwargs,
@@ -221,6 +223,39 @@ def create_fine_tuning_job(
max_retries=optional_params.max_retries,
_is_async=_is_async,
)
elif custom_llm_provider == "vertex_ai":
api_base = optional_params.api_base or ""
vertex_ai_project = (
optional_params.vertex_project
or litellm.vertex_project
or get_secret("VERTEXAI_PROJECT")
)
vertex_ai_location = (
optional_params.vertex_location
or litellm.vertex_location
or get_secret("VERTEXAI_LOCATION")
)
vertex_credentials = optional_params.vertex_credentials or get_secret(
"VERTEXAI_CREDENTIALS"
)
create_fine_tuning_job_data = FineTuningJobCreate(
model=model,
training_file=training_file,
hyperparameters=hyperparameters,
suffix=suffix,
validation_file=validation_file,
integrations=integrations,
seed=seed,
)
response = vertex_fine_tuning_apis_instance.create_fine_tuning_job(
_is_async=_is_async,
create_fine_tuning_job_data=create_fine_tuning_job_data,
vertex_credentials=vertex_credentials,
vertex_project=vertex_ai_project,
vertex_location=vertex_ai_location,
timeout=timeout,
api_base=api_base,
)
else:
raise litellm.exceptions.BadRequestError(
message="LiteLLM doesn't support {} for 'create_batch'. Only 'openai' is supported.".format(
@@ -236,6 +271,7 @@ def create_fine_tuning_job(
)
return response
except Exception as e:
verbose_logger.error("got exception in create_fine_tuning_job=%s", str(e))
raise e
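The new vertex_ai branch resolves project, location, and credentials from the call kwargs first, then module-level litellm settings, then the VERTEXAI_* secrets/env vars. A sketch of the env-var path, with placeholder values:

import os
import litellm

# Placeholder values; point these at your real GCP project
os.environ["VERTEXAI_PROJECT"] = "my-gcp-project"
os.environ["VERTEXAI_LOCATION"] = "us-central1"
os.environ["VERTEXAI_CREDENTIALS"] = "/path/to/service_account.json"

# With the env vars set, the vertex_* kwargs can be omitted
job = litellm.create_fine_tuning_job(
    model="gemini-1.0-pro-002",
    custom_llm_provider="vertex_ai",
    training_file="gs://my-bucket/sft_train_data.jsonl",
)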

View file

@@ -0,0 +1,213 @@
import traceback
from datetime import datetime
from typing import Any, Coroutine, Literal, Optional, Union
import httpx
from openai.types.fine_tuning.fine_tuning_job import FineTuningJob, Hyperparameters
from litellm._logging import verbose_logger
from litellm.llms.base import BaseLLM
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.llms.vertex_httpx import VertexLLM
from litellm.types.llms.openai import FineTuningJobCreate
from litellm.types.llms.vertex_ai import (
FineTuneJobCreate,
FineTunesupervisedTuningSpec,
ResponseTuningJob,
)
class VertexFineTuningAPI(VertexLLM):
"""
    Vertex AI methods to support fine-tuning jobs
"""
def __init__(self) -> None:
super().__init__()
self.async_handler = AsyncHTTPHandler(
timeout=httpx.Timeout(timeout=600.0, connect=5.0)
)
def convert_response_created_at(self, response: ResponseTuningJob):
try:
create_time_str = response.get("createTime", "") or ""
create_time_datetime = datetime.fromisoformat(
create_time_str.replace("Z", "+00:00")
)
# Convert to Unix timestamp (seconds since epoch)
created_at = int(create_time_datetime.timestamp())
return created_at
        except Exception:
            # createTime missing or not ISO-8601; fall back to 0
            return 0
def convert_vertex_response_to_open_ai_response(
self, response: ResponseTuningJob
) -> FineTuningJob:
        # Map the Vertex AI job state to the OpenAI fine-tuning job status
        status: Literal[
            "validating_files", "queued", "running", "succeeded", "failed", "cancelled"
        ] = "queued"
        if response["state"] == "JOB_STATE_PENDING":
            status = "queued"
        elif response["state"] == "JOB_STATE_RUNNING":
            status = "running"
        elif response["state"] == "JOB_STATE_SUCCEEDED":
            status = "succeeded"
        elif response["state"] == "JOB_STATE_FAILED":
            status = "failed"
        elif response["state"] == "JOB_STATE_CANCELLED":
            status = "cancelled"
created_at = self.convert_response_created_at(response)
training_uri = ""
if "supervisedTuningSpec" in response and response["supervisedTuningSpec"]:
training_uri = response["supervisedTuningSpec"]["trainingDatasetUri"] or ""
return FineTuningJob(
id=response["name"] or "",
created_at=created_at,
fine_tuned_model=response["tunedModelDisplayName"],
finished_at=None,
hyperparameters=Hyperparameters(
n_epochs=0,
),
model=response["baseModel"] or "",
object="fine_tuning.job",
organization_id="",
result_files=[],
seed=0,
status=status,
trained_tokens=None,
training_file=training_uri,
validation_file=None,
estimated_finish=None,
integrations=[],
)
async def acreate_fine_tuning_job(
self,
fine_tuning_url: str,
headers: dict,
request_data: FineTuneJobCreate,
):
from litellm.fine_tuning.main import FineTuningJob
try:
verbose_logger.debug(
"about to create fine tuning job: %s, request_data: %s",
fine_tuning_url,
request_data,
)
if self.async_handler is None:
raise ValueError(
"VertexAI Fine Tuning - async_handler is not initialized"
)
response = await self.async_handler.post(
headers=headers,
url=fine_tuning_url,
json=request_data, # type: ignore
)
if response.status_code != 200:
raise Exception(
f"Error creating fine tuning job. Status code: {response.status_code}. Response: {response.text}"
)
verbose_logger.debug(
"got response from creating fine tuning job: %s", response.json()
)
vertex_response = ResponseTuningJob( # type: ignore
**response.json(),
)
verbose_logger.debug("vertex_response %s", vertex_response)
open_ai_response = self.convert_vertex_response_to_open_ai_response(
vertex_response
)
return open_ai_response
except Exception as e:
verbose_logger.error("asyncerror creating fine tuning job %s", e)
trace_back_str = traceback.format_exc()
verbose_logger.error(trace_back_str)
raise e
def create_fine_tuning_job(
self,
_is_async: bool,
create_fine_tuning_job_data: FineTuningJobCreate,
vertex_project: Optional[str],
vertex_location: Optional[str],
vertex_credentials: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
):
verbose_logger.debug(
"creating fine tuning job, args= %s", create_fine_tuning_job_data
)
auth_header, _ = self._get_token_and_url(
model="",
gemini_api_key=None,
vertex_credentials=vertex_credentials,
vertex_project=vertex_project,
vertex_location=vertex_location,
stream=False,
custom_llm_provider="vertex_ai_beta",
api_base=api_base,
)
headers = {
"Authorization": f"Bearer {auth_header}",
"Content-Type": "application/json",
}
supervised_tuning_spec = FineTunesupervisedTuningSpec(
training_dataset_uri=create_fine_tuning_job_data.training_file
)
fine_tune_job = FineTuneJobCreate(
baseModel=create_fine_tuning_job_data.model,
supervisedTuningSpec=supervised_tuning_spec,
)
fine_tuning_url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/tuningJobs"
if _is_async is True:
return self.acreate_fine_tuning_job( # type: ignore
fine_tuning_url=fine_tuning_url,
headers=headers,
request_data=fine_tune_job,
)
sync_handler = HTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
verbose_logger.debug(
"about to create fine tuning job: %s, request_data: %s",
fine_tuning_url,
fine_tune_job,
)
response = sync_handler.post(
headers=headers,
url=fine_tuning_url,
json=fine_tune_job, # type: ignore
)
if response.status_code != 200:
raise Exception(
f"Error creating fine tuning job. Status code: {response.status_code}. Response: {response.text}"
)
verbose_logger.debug(
"got response from creating fine tuning job: %s", response.json()
)
vertex_response = ResponseTuningJob( # type: ignore
**response.json(),
)
verbose_logger.debug("vertex_response %s", vertex_response)
open_ai_response = self.convert_vertex_response_to_open_ai_response(
vertex_response
)
return open_ai_response
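To make the state and timestamp mapping above concrete, here is a small sketch that feeds a hypothetical Vertex payload (shaped like the ResponseTuningJob TypedDict) through the converter; every value in `sample` is made up:

from litellm.llms.fine_tuning_apis.vertex_ai import VertexFineTuningAPI
from litellm.types.llms.vertex_ai import ResponseTuningJob

# Hypothetical Vertex AI response; all values are illustrative
sample: ResponseTuningJob = {
    "name": "projects/123/locations/us-central1/tuningJobs/456",
    "tunedModelDisplayName": "my-tuned-gemini",
    "baseModel": "gemini-1.0-pro-002",
    "supervisedTuningSpec": {"trainingDatasetUri": "gs://my-bucket/train.jsonl"},
    "state": "JOB_STATE_RUNNING",
    "createTime": "2024-08-03T15:22:55Z",
    "updateTime": "2024-08-03T15:22:55Z",
}

job = VertexFineTuningAPI().convert_vertex_response_to_open_ai_response(sample)
print(job.status)      # "running" (mapped from JOB_STATE_RUNNING)
print(job.created_at)  # 1722698575 (Unix seconds parsed from createTime)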

View file

@@ -16,6 +16,7 @@ import asyncio
import logging
import openai
from test_gcs_bucket import load_vertex_ai_credentials
from litellm import create_fine_tuning_job
from litellm._logging import verbose_logger
@@ -183,3 +184,29 @@ async def test_azure_create_fine_tune_jobs_async():
except Exception as e:
pytest.fail(f"Error occurred: {e}")
pass
@pytest.mark.asyncio()
@pytest.mark.skip(reason="skipping until we can cancel fine tuning jobs")
async def test_create_vertex_fine_tune_jobs():
try:
verbose_logger.setLevel(logging.DEBUG)
load_vertex_ai_credentials()
vertex_credentials = os.getenv("GCS_PATH_SERVICE_ACCOUNT")
print("creating fine tuning job")
create_fine_tuning_response = await litellm.acreate_fine_tuning_job(
model="gemini-1.0-pro-002",
custom_llm_provider="vertex_ai",
training_file="gs://cloud-samples-data/ai-platform/generative_ai/sft_train_data.jsonl",
vertex_project="adroit-crow-413218",
vertex_location="us-central1",
vertex_credentials=vertex_credentials,
)
print("vertex ai create fine tuning response=", create_fine_tuning_response)
assert create_fine_tuning_response.id is not None
assert create_fine_tuning_response.model == "gemini-1.0-pro-002"
assert create_fine_tuning_response.object == "fine_tuning.job"
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")

View file

@@ -0,0 +1,22 @@
[
{
"messages": [
{
"role": "system",
"content": "You should classify the text into one of the following classes:[business, entertainment]"
},
{ "role": "user", "content": "Diversify your investment portfolio" },
{ "role": "model", "content": "business" }
]
},
{
"messages": [
{
"role": "system",
"content": "You should classify the text into one of the following classes:[business, entertainment]"
},
{ "role": "user", "content": "Watch a live concert" },
{ "role": "model", "content": "entertainment" }
]
}
]

View file

@@ -260,3 +260,48 @@ class GenerateContentResponseBody(TypedDict, total=False):
candidates: Required[List[Candidates]]
promptFeedback: PromptFeedback
usageMetadata: Required[UsageMetadata]
class FineTunesupervisedTuningSpec(TypedDict, total=False):
training_dataset_uri: str
validation_dataset: Optional[str]
epoch_count: Optional[int]
learning_rate_multiplier: Optional[float]
tuned_model_display_name: Optional[str]
adapter_size: Optional[
Literal[
"ADAPTER_SIZE_UNSPECIFIED",
"ADAPTER_SIZE_ONE",
"ADAPTER_SIZE_FOUR",
"ADAPTER_SIZE_EIGHT",
"ADAPTER_SIZE_SIXTEEN",
]
]
class FineTuneJobCreate(TypedDict, total=False):
baseModel: str
supervisedTuningSpec: FineTunesupervisedTuningSpec
tunedModelDisplayName: Optional[str]
class ResponseSupervisedTuningSpec(TypedDict):
trainingDatasetUri: Optional[str]
class ResponseTuningJob(TypedDict):
name: Optional[str]
tunedModelDisplayName: Optional[str]
baseModel: Optional[str]
supervisedTuningSpec: Optional[ResponseSupervisedTuningSpec]
state: Optional[
Literal[
"JOB_STATE_PENDING",
"JOB_STATE_RUNNING",
"JOB_STATE_SUCCEEDED",
"JOB_STATE_FAILED",
"JOB_STATE_CANCELLED",
]
]
createTime: Optional[str]
updateTime: Optional[str]

View file

@@ -143,6 +143,7 @@ class GenericLiteLLMParams(BaseModel):
## VERTEX AI ##
vertex_project: Optional[str] = None
vertex_location: Optional[str] = None
vertex_credentials: Optional[str] = None
## AWS BEDROCK / SAGEMAKER ##
aws_access_key_id: Optional[str] = None
aws_secret_access_key: Optional[str] = None
@@ -486,6 +487,11 @@ class AssistantsTypedDict(TypedDict):
litellm_params: LiteLLMParamsTypedDict
class FineTuningConfig(BaseModel):
custom_llm_provider: Literal["azure", "openai"]
class CustomRoutingStrategyBase:
async def async_get_available_deployment(
self,