fix typing

Ishaan Jaff 2024-08-02 18:46:43 -07:00
parent 40430dde10
commit f194aa3a93
2 changed files with 44 additions and 18 deletions

View file

@@ -224,7 +224,7 @@ def create_fine_tuning_job(
                 _is_async=_is_async,
             )
         elif custom_llm_provider == "vertex_ai":
-            api_base = optional_params.api_base
+            api_base = optional_params.api_base or ""
             vertex_ai_project = (
                 optional_params.vertex_project
                 or litellm.vertex_project
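
Note on the change above: `optional_params.api_base` is typed `Optional[str]`, and appending `or ""` narrows it to a plain `str` for the type checker while defaulting a missing value to an empty string. A minimal sketch of the pattern (the function and names below are illustrative, not litellm code):

from typing import Optional


def normalize_api_base(api_base: Optional[str]) -> str:
    # `x or ""` maps both None and "" to "", so the result is a plain str
    # and downstream calls that require str type-check cleanly.
    return api_base or ""


assert normalize_api_base(None) == ""
assert normalize_api_base("https://example.test") == "https://example.test"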

View file

@@ -30,11 +30,13 @@ class VertexFineTuningAPI(VertexLLM):
     def convert_response_created_at(self, response: ResponseTuningJob):
         try:
-            create_time = datetime.fromisoformat(
-                response["createTime"].replace("Z", "+00:00")
+            create_time_str = response.get("createTime", "") or ""
+            create_time_datetime = datetime.fromisoformat(
+                create_time_str.replace("Z", "+00:00")
             )
             # Convert to Unix timestamp (seconds since epoch)
-            created_at = int(create_time.timestamp())
+            created_at = int(create_time_datetime.timestamp())
             return created_at
         except Exception as e:
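
For context on the rewritten conversion: Vertex AI reports `createTime` as an RFC 3339 string ending in `Z`, which `datetime.fromisoformat` does not accept on Python versions before 3.11, hence the `replace("Z", "+00:00")` before parsing and the `.get(...) or ""` guard for a missing key. A standalone sketch with a sample timestamp (the value is made up for illustration):

from datetime import datetime

# Sample value in the shape Vertex AI uses for createTime (assumed for illustration).
create_time_str = "2024-08-02T18:46:43.123456Z"

# Swap the trailing "Z" for an explicit UTC offset so fromisoformat accepts it
# on Python < 3.11 as well.
create_time = datetime.fromisoformat(create_time_str.replace("Z", "+00:00"))

# Unix timestamp (seconds since epoch), matching the int the method returns.
created_at = int(create_time.timestamp())
print(created_at)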
@@ -59,33 +61,36 @@ class VertexFineTuningAPI(VertexLLM):
         created_at = self.convert_response_created_at(response)

+        training_uri = ""
+        if "supervisedTuningSpec" in response and response["supervisedTuningSpec"]:
+            training_uri = response["supervisedTuningSpec"]["trainingDatasetUri"] or ""
+
         return FineTuningJob(
-            id=response["name"],
+            id=response["name"] or "",
             created_at=created_at,
             fine_tuned_model=response["tunedModelDisplayName"],
             finished_at=None,
             hyperparameters=Hyperparameters(
-                n_epochs=0, batch_size="", learning_rate_multiplier=""
+                n_epochs=0,
             ),
-            model=response["baseModel"],
+            model=response["baseModel"] or "",
             object="fine_tuning.job",
             organization_id="",
             result_files=[],
             seed=0,
             status=status,
             trained_tokens=None,
-            training_file=response["supervisedTuningSpec"]["trainingDatasetUri"],
+            training_file=training_uri,
             validation_file=None,
             estimated_finish=None,
             integrations=[],
-            user_provided_suffix=None,
         )

     async def acreate_fine_tuning_job(
         self,
         fine_tuning_url: str,
         headers: dict,
-        request_data: dict,
+        request_data: FineTuneJobCreate,
     ):
         from litellm.fine_tuning.main import FineTuningJob
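
The new `training_uri` block above reads the nested value only when the outer key exists and is non-empty, and otherwise falls back to `""` so `training_file` always receives a `str`. A self-contained sketch of the same guard (the helper name and dict shape are assumptions for illustration, not the actual Vertex schema):

from typing import Any, Dict


def extract_training_uri(response: Dict[str, Any]) -> str:
    # Default to "" so callers always get a str, even when the spec is absent or None.
    training_uri = ""
    if "supervisedTuningSpec" in response and response["supervisedTuningSpec"]:
        training_uri = response["supervisedTuningSpec"].get("trainingDatasetUri") or ""
    return training_uri


assert extract_training_uri({}) == ""
assert extract_training_uri({"supervisedTuningSpec": None}) == ""
assert (
    extract_training_uri(
        {"supervisedTuningSpec": {"trainingDatasetUri": "gs://bucket/train.jsonl"}}
    )
    == "gs://bucket/train.jsonl"
)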
@@ -95,10 +100,14 @@ class VertexFineTuningAPI(VertexLLM):
             fine_tuning_url,
             request_data,
         )
+        if self.async_handler is None:
+            raise ValueError(
+                "VertexAI Fine Tuning - async_handler is not initialized"
+            )
         response = await self.async_handler.post(
             headers=headers,
             url=fine_tuning_url,
-            json=request_data,
+            json=request_data,  # type: ignore
         )

         if response.status_code != 200:
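
The added `is None` guard does two things: it narrows `self.async_handler` from an optional handler to a concrete one for the type checker, and it replaces a would-be `AttributeError` on `None` with an explicit error message. A generic sketch of the pattern with placeholder classes (not the litellm handler types):

from typing import Optional


class Handler:
    def post(self, url: str) -> str:
        return f"POST {url}"


class Caller:
    def __init__(self) -> None:
        # May remain None until something initializes it.
        self.async_handler: Optional[Handler] = None

    def send(self, url: str) -> str:
        if self.async_handler is None:
            raise ValueError("async_handler is not initialized")
        # After the guard, the checker treats self.async_handler as Handler, not Optional.
        return self.async_handler.post(url)


caller = Caller()
caller.async_handler = Handler()
print(caller.send("https://example.test/tuningJobs"))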
@@ -110,7 +119,16 @@ class VertexFineTuningAPI(VertexLLM):
             "got response from creating fine tuning job: %s", response.json()
         )

-        vertex_response = ResponseTuningJob(**response.json())
+        vertex_response = ResponseTuningJob(
+            name=None,
+            tunedModelDisplayName=None,
+            baseModel=None,
+            supervisedTuningSpec=None,
+            state=None,
+            createTime=None,
+            updateTime=None,
+            **response.json(),
+        )

         verbose_logger.debug("vertex_response %s", vertex_response)
         open_ai_response = self.convert_vertex_response_to_open_ai_response(
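
On the expanded `ResponseTuningJob(...)` call above: spelling out every key with a `None` placeholder before `**response.json()` appears to be there so the checker sees every declared key even though the unpacked JSON dict is untyped and the checker cannot tell which keys it provides. As a hedged aside, an alternative that avoids unpacking altogether is to read each key explicitly; the sketch below uses a made-up two-field TypedDict, not the real ResponseTuningJob definition:

from typing import Optional, TypedDict


class TuningJobPayload(TypedDict):
    # Stand-in response type; field names mirror the Vertex-style keys for illustration.
    name: Optional[str]
    createTime: Optional[str]


def parse_payload(raw: dict) -> TuningJobPayload:
    # A checker cannot verify TuningJobPayload(**raw) because `raw` is untyped,
    # so each key is read explicitly and defaults to None when absent.
    return TuningJobPayload(
        name=raw.get("name"),
        createTime=raw.get("createTime"),
    )


print(parse_payload({"name": "projects/123/locations/us-central1/tuningJobs/456"}))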
@@ -140,7 +158,7 @@ class VertexFineTuningAPI(VertexLLM):
         )

         auth_header, _ = self._get_token_and_url(
-            model=None,
+            model="",
             gemini_api_key=None,
             vertex_credentials=vertex_credentials,
             vertex_project=vertex_project,
@@ -172,16 +190,15 @@ class VertexFineTuningAPI(VertexLLM):
         )

         sync_handler = HTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
-        request_data = fine_tune_job
         verbose_logger.debug(
             "about to create fine tuning job: %s, request_data: %s",
             fine_tuning_url,
-            request_data,
+            fine_tune_job,
         )
-        response = self.sync_handler.post(
+        response = sync_handler.post(
             headers=headers,
             url=fine_tuning_url,
-            json=request_data,
+            json=fine_tune_job,  # type: ignore
         )

         if response.status_code != 200:
@@ -192,7 +209,16 @@ class VertexFineTuningAPI(VertexLLM):
         verbose_logger.debug(
             "got response from creating fine tuning job: %s", response.json()
         )
-        vertex_response = ResponseTuningJob(**response.json())
+        vertex_response = ResponseTuningJob(
+            name=None,
+            tunedModelDisplayName=None,
+            baseModel=None,
+            supervisedTuningSpec=None,
+            state=None,
+            createTime=None,
+            updateTime=None,
+            **response.json(),
+        )

         verbose_logger.debug("vertex_response %s", vertex_response)
         open_ai_response = self.convert_vertex_response_to_open_ai_response(