diff --git a/litellm/llms/fine_tuning_apis/vertex_ai.py b/litellm/llms/fine_tuning_apis/vertex_ai.py
index 618894245..189ace11a 100644
--- a/litellm/llms/fine_tuning_apis/vertex_ai.py
+++ b/litellm/llms/fine_tuning_apis/vertex_ai.py
@@ -241,12 +241,13 @@ class VertexFineTuningAPI(VertexLLM):
         )
         return open_ai_response

-    async def pass_through_vertex_ai_fine_tuning_job(
+    async def pass_through_vertex_ai_POST_request(
         self,
         request_data: dict,
         vertex_project: str,
         vertex_location: str,
         vertex_credentials: str,
+        request_route: str,
     ):
         auth_header, _ = self._get_token_and_url(
             model="",
@@ -264,14 +265,20 @@ class VertexFineTuningAPI(VertexLLM):
             "Content-Type": "application/json",
         }

-        fine_tuning_url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/tuningJobs"
+        # map the incoming request route to the matching Vertex AI REST url
+        url = None
+        if request_route == "/tuningJobs":
+            url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/tuningJobs"
+
+        if url is None:  # guard: fail fast instead of POSTing to a None url
+            raise ValueError(f"Unsupported Vertex AI request route: {request_route}")

         if self.async_handler is None:
             raise ValueError("VertexAI Fine Tuning - async_handler is not initialized")

         response = await self.async_handler.post(
             headers=headers,
-            url=fine_tuning_url,
+            url=url,
             json=request_data,  # type: ignore
         )

diff --git a/litellm/proxy/fine_tuning_endpoints/endpoints.py b/litellm/proxy/fine_tuning_endpoints/endpoints.py
index c2d89dd25..cda226b5a 100644
--- a/litellm/proxy/fine_tuning_endpoints/endpoints.py
+++ b/litellm/proxy/fine_tuning_endpoints/endpoints.py
@@ -429,72 +429,3 @@ async def retrieve_fine_tuning_job(
             param=getattr(e, "param", "None"),
             code=getattr(e, "status_code", 500),
         )
-
-
-@router.post(
-    "/v1/projects/tuningJobs",
-    dependencies=[Depends(user_api_key_auth)],
-    tags=["fine-tuning"],
-    summary="✨ (Enterprise) Create Fine-Tuning Jobs",
-)
-async def vertex_create_fine_tuning_job(
-    request: Request,
-    fastapi_response: Response,
-    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
-):
-    """
-    this is a pass through endpoint for the Vertex AI API. /tuningJobs endpoint
-
-    it uses the vertex ai credentials on the proxy and forwards to vertex ai api
-    """
-    try:
-        from litellm.fine_tuning.main import vertex_fine_tuning_apis_instance
-        from litellm.proxy.proxy_server import (
-            add_litellm_data_to_request,
-            general_settings,
-            get_custom_headers,
-            premium_user,
-            proxy_config,
-            proxy_logging_obj,
-            version,
-        )
-
-        # get configs for custom_llm_provider
-        llm_provider_config = get_fine_tuning_provider_config(
-            custom_llm_provider="vertex_ai"
-        )
-
-        vertex_project = llm_provider_config.get("vertex_project", None)
-        vertex_location = llm_provider_config.get("vertex_location", None)
-        vertex_credentials = llm_provider_config.get("vertex_credentials", None)
-        request_data_json = await request.json()
-        response = await vertex_fine_tuning_apis_instance.pass_through_vertex_ai_fine_tuning_job(
-            request_data=request_data_json,
-            vertex_project=vertex_project,
-            vertex_location=vertex_location,
-            vertex_credentials=vertex_credentials,
-        )
-
-        return response
-    except Exception as e:
-        verbose_proxy_logger.error(
-            "litellm.proxy.proxy_server.v1/projects/tuningJobs(): Exception occurred - {}".format(
-                str(e)
-            )
-        )
-        verbose_proxy_logger.debug(traceback.format_exc())
-        if isinstance(e, HTTPException):
-            raise ProxyException(
-                message=getattr(e, "message", str(e.detail)),
-                type=getattr(e, "type", "None"),
-                param=getattr(e, "param", "None"),
-                code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
-            )
-        else:
-            error_msg = f"{str(e)}"
-            raise ProxyException(
-                message=getattr(e, "message", error_msg),
-                type=getattr(e, "type", "None"),
-                param=getattr(e, "param", "None"),
-                code=getattr(e, "status_code", 500),
-            )
diff --git a/litellm/proxy/proxy_config.yaml b/litellm/proxy/proxy_config.yaml
index aa2bfc525..0750a3937 100644
--- a/litellm/proxy/proxy_config.yaml
+++ b/litellm/proxy/proxy_config.yaml
@@ -48,6 +48,11 @@ files_settings:
   - custom_llm_provider: openai
     api_key: os.environ/OPENAI_API_KEY

+default_vertex_config:
+  vertex_project: "adroit-crow-413218"
+  vertex_location: "us-central1"
+  vertex_credentials: "/Users/ishaanjaffer/Downloads/adroit-crow-413218-a956eef1a2a8.json"
+
 general_settings:
diff --git a/litellm/proxy/vertex_ai_endpoints/vertex_endpoints.py b/litellm/proxy/vertex_ai_endpoints/vertex_endpoints.py
new file mode 100644
index 000000000..be09a4932
--- /dev/null
+++ b/litellm/proxy/vertex_ai_endpoints/vertex_endpoints.py
@@ -0,0 +1,124 @@
+import asyncio
+import traceback
+from datetime import datetime, timedelta, timezone
+from typing import List, Optional
+
+import fastapi
+import httpx
+from fastapi import (
+    APIRouter,
+    Depends,
+    File,
+    Form,
+    Header,
+    HTTPException,
+    Request,
+    Response,
+    UploadFile,
+    status,
+)
+
+import litellm
+from litellm._logging import verbose_proxy_logger
+from litellm.batches.main import FileObject
+from litellm.fine_tuning.main import vertex_fine_tuning_apis_instance
+from litellm.proxy._types import *
+from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
+
+router = APIRouter()
+# populated at startup by set_default_vertex_config() from the proxy config yaml
+default_vertex_config = None
+
+
+def set_default_vertex_config(config):
+    global default_vertex_config
+    if config is None:
+        return
+
+    if not isinstance(config, dict):
+        raise ValueError("invalid default_vertex_config, expected a dict, got {}".format(type(config)))
+
+    # resolve os.environ/ references to their secret values
+    for key, value in config.items():
+        if isinstance(value, str) and value.startswith("os.environ/"):
+            config[key] = litellm.get_secret(value)
+
+    default_vertex_config = config
+
+
+def exception_handler(e: Exception):
+    verbose_proxy_logger.error(
+        "litellm.proxy.vertex_ai_endpoints(): Exception occurred - {}".format(
+            str(e)
+        )
+    )
+    verbose_proxy_logger.debug(traceback.format_exc())
+    if isinstance(e, HTTPException):
+        return ProxyException(
+            message=getattr(e, "message", str(e.detail)),
+            type=getattr(e, "type", "None"),
+            param=getattr(e, "param", "None"),
+            code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
+        )
+    else:
+        error_msg = f"{str(e)}"
+        return ProxyException(
+            message=getattr(e, "message", error_msg),
+            type=getattr(e, "type", "None"),
+            param=getattr(e, "param", "None"),
+            code=getattr(e, "status_code", 500),
+        )
+
+
+async def execute_post_vertex_ai_request(
+    request: Request,
+    route: str,
+):
+    from litellm.fine_tuning.main import vertex_fine_tuning_apis_instance
+
+    if default_vertex_config is None:
+        raise ValueError(
+            "Vertex credentials not added on litellm proxy, please set `default_vertex_config` in your config.yaml"
+        )
+
+    vertex_project = default_vertex_config.get("vertex_project", None)
+    vertex_location = default_vertex_config.get("vertex_location", None)
+    vertex_credentials = default_vertex_config.get("vertex_credentials", None)
+    request_data_json = await request.json()
+
+    response = (
+        await vertex_fine_tuning_apis_instance.pass_through_vertex_ai_POST_request(
+            request_data=request_data_json,
+            vertex_project=vertex_project,
+            vertex_location=vertex_location,
+            vertex_credentials=vertex_credentials,
+            request_route=route,
+        )
+    )
+
+    return response
+
+
+@router.post(
+    "/vertex-ai/tuningJobs",
+    dependencies=[Depends(user_api_key_auth)],
+    tags=["Vertex AI endpoints"],
+)
+async def vertex_create_fine_tuning_job(
+    request: Request,
+    fastapi_response: Response,
+    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
+):
+    """
+    This is a pass-through endpoint for the Vertex AI API /tuningJobs route.
+
+    It uses the Vertex AI credentials set on the proxy and forwards the request to the Vertex AI API.
+    """
+    try:
+        response = await execute_post_vertex_ai_request(
+            request=request,
+            route="/tuningJobs",
+        )
+        return response
+    except Exception as e:
+        raise exception_handler(e) from e
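
Usage sketch (not part of the diff): a minimal smoke test for the new /vertex-ai/tuningJobs pass-through route. The proxy address, API key, base model, and GCS path below are all hypothetical, and the request body is assumed to follow the Vertex AI tuningJobs REST schema, since the proxy forwards it to Vertex AI unchanged.

# smoke test for the /vertex-ai/tuningJobs pass-through (hypothetical values)
import httpx

PROXY_BASE_URL = "http://localhost:4000"  # assumed local litellm proxy
LITELLM_API_KEY = "sk-1234"               # assumed proxy master key

# body is forwarded as-is to the Vertex AI tuningJobs API
payload = {
    "baseModel": "gemini-1.0-pro-002",
    "supervisedTuningSpec": {
        "trainingDatasetUri": "gs://my-bucket/train.jsonl",  # hypothetical bucket
    },
}

response = httpx.post(
    f"{PROXY_BASE_URL}/vertex-ai/tuningJobs",
    headers={"Authorization": f"Bearer {LITELLM_API_KEY}"},
    json=payload,
    timeout=60.0,
)
response.raise_for_status()
print(response.json())  # raw Vertex AI tuning job resource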