mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 11:43:54 +00:00
add fine tuning for vertex
This commit is contained in:
parent
c614632ae9
commit
b463a290a9
1 changed files with 113 additions and 0 deletions
113
litellm/llms/fine_tuning_apis/vertex_ai.py
Normal file
113
litellm/llms/fine_tuning_apis/vertex_ai.py
Normal file
|
@ -0,0 +1,113 @@
|
|||
import traceback
|
||||
from typing import Any, Coroutine, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.llms.base import BaseLLM
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
|
||||
from litellm.llms.vertex_httpx import VertexLLM
|
||||
from litellm.types.llms.openai import FineTuningJobCreate
|
||||
from litellm.types.llms.vertex_ai import FineTuneJobCreate, FineTunesupervisedTuningSpec
|
||||
|
||||
|
||||
class VertexFineTuningAPI(VertexLLM):
|
||||
"""
|
||||
Vertex methods to support for batches
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.async_handler = AsyncHTTPHandler(
|
||||
timeout=httpx.Timeout(timeout=600.0, connect=5.0)
|
||||
)
|
||||
|
||||
async def acreate_fine_tuning_job(
|
||||
self,
|
||||
fine_tuning_url: str,
|
||||
headers: dict,
|
||||
request_data: dict,
|
||||
):
|
||||
try:
|
||||
verbose_logger.debug(
|
||||
"about to create fine tuning job: %s, request_data: %s",
|
||||
fine_tuning_url,
|
||||
request_data,
|
||||
)
|
||||
response = await self.async_handler.post(
|
||||
headers=headers,
|
||||
url=fine_tuning_url,
|
||||
json=request_data,
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise Exception(
|
||||
f"Error creating fine tuning job. Status code: {response.status_code}. Response: {response.text}"
|
||||
)
|
||||
|
||||
verbose_logger.debug(
|
||||
"got response from creating fine tuning job: %s", response.json()
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
verbose_logger.error("asyncerror creating fine tuning job %s", e)
|
||||
trace_back_str = traceback.format_exc()
|
||||
verbose_logger.error(trace_back_str)
|
||||
raise e
|
||||
|
||||
def create_fine_tuning_job(
|
||||
self,
|
||||
_is_async: bool,
|
||||
create_fine_tuning_job_data: FineTuningJobCreate,
|
||||
vertex_project: Optional[str],
|
||||
vertex_location: Optional[str],
|
||||
vertex_credentials: Optional[str],
|
||||
api_base: Optional[str],
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
):
|
||||
|
||||
verbose_logger.debug(
|
||||
"creating fine tuning job, args= %s", create_fine_tuning_job_data
|
||||
)
|
||||
|
||||
auth_header, _ = self._get_token_and_url(
|
||||
model=None,
|
||||
gemini_api_key=None,
|
||||
vertex_credentials=vertex_credentials,
|
||||
vertex_project=vertex_project,
|
||||
vertex_location=vertex_location,
|
||||
stream=False,
|
||||
custom_llm_provider="vertex_ai_beta",
|
||||
api_base=api_base,
|
||||
)
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {auth_header}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
supervised_tuning_spec = FineTunesupervisedTuningSpec(
|
||||
training_dataset_uri=create_fine_tuning_job_data.training_file
|
||||
)
|
||||
fine_tune_job = FineTuneJobCreate(
|
||||
baseModel=create_fine_tuning_job_data.model,
|
||||
supervisedTuningSpec=supervised_tuning_spec,
|
||||
tunedModelDisplayName="ishaan-test",
|
||||
)
|
||||
|
||||
fine_tuning_url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/tuningJobs"
|
||||
if _is_async is True:
|
||||
return self.acreate_fine_tuning_job( # type: ignore
|
||||
fine_tuning_url=fine_tuning_url,
|
||||
headers=headers,
|
||||
request_data=fine_tune_job,
|
||||
)
|
||||
# response = self.async_handler.post(
|
||||
# url=fine_tuning_url,
|
||||
# headers=headers,
|
||||
# json=fine_tune_job,
|
||||
# )
|
||||
|
||||
# response = openai_client.fine_tuning.jobs.create(**create_fine_tuning_job_data) # type: ignore
|
Loading…
Add table
Add a link
Reference in a new issue