diff --git a/docs/my-website/docs/fine_tuning.md b/docs/my-website/docs/fine_tuning.md
index c69f4c1e6..fd3cbc792 100644
--- a/docs/my-website/docs/fine_tuning.md
+++ b/docs/my-website/docs/fine_tuning.md
@@ -124,7 +124,7 @@ ft_job = await client.fine_tuning.jobs.create(
 ```
-
+
 ```shell
 curl http://localhost:4000/v1/fine_tuning/jobs \
@@ -136,6 +136,28 @@ curl http://localhost:4000/v1/fine_tuning/jobs \
     "training_file": "gs://cloud-samples-data/ai-platform/generative_ai/sft_train_data.jsonl"
   }'
 ```
+
+
+
+
+:::info
+
+Use this to create fine-tuning jobs in [the Vertex AI API format](https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/tuning#create-tuning)
+
+:::
+
+```shell
+curl http://localhost:4000/v1/projects/tuningJobs \
+  -H "Content-Type: application/json" \
+  -H "Authorization: Bearer sk-1234" \
+  -d '{
+    "baseModel": "gemini-1.0-pro-002",
+    "supervisedTuningSpec" : {
+      "training_dataset_uri": "gs://cloud-samples-data/ai-platform/generative_ai/sft_train_data.jsonl"
+    }
+}'
+```
+
diff --git a/docs/my-website/docs/proxy/user_keys.md b/docs/my-website/docs/proxy/user_keys.md
index 75e547d17..79d019a20 100644
--- a/docs/my-website/docs/proxy/user_keys.md
+++ b/docs/my-website/docs/proxy/user_keys.md
@@ -23,6 +23,9 @@ LiteLLM Proxy is **Azure OpenAI-compatible**:
 LiteLLM Proxy is **Anthropic-compatible**:
 * /messages
 
+LiteLLM Proxy is **Vertex AI-compatible**:
+* [Supports ALL Vertex Endpoints](../vertex_ai)
+
 This doc covers:
 
 * /chat/completion
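For readers following the fine-tuning docs above, here is a minimal Python equivalent of the Vertex AI-format curl example, a sketch using `httpx` with the same placeholder proxy URL, key, base model, and dataset URI:

```python
import httpx

# Placeholder proxy URL and key from the docs above
resp = httpx.post(
    "http://localhost:4000/v1/projects/tuningJobs",
    headers={
        "Authorization": "Bearer sk-1234",
        "Content-Type": "application/json",
    },
    json={
        "baseModel": "gemini-1.0-pro-002",
        "supervisedTuningSpec": {
            "training_dataset_uri": "gs://cloud-samples-data/ai-platform/generative_ai/sft_train_data.jsonl"
        },
    },
)
print(resp.json())
```
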
diff --git a/docs/my-website/docs/vertex_ai.md b/docs/my-website/docs/vertex_ai.md
new file mode 100644
index 000000000..d9c8616a0
--- /dev/null
+++ b/docs/my-website/docs/vertex_ai.md
@@ -0,0 +1,93 @@
+# [BETA] Vertex AI Endpoints
+
+## Supported API Endpoints
+
+- Gemini API
+- Embeddings API
+- Imagen API
+- Code Completion API
+- Batch prediction API
+- Tuning API
+- CountTokens API
+
+## Quick Start
+
+#### 1. Set `default_vertex_config` in your `config.yaml`
+
+
+Add the following credentials to your LiteLLM `config.yaml` to use the Vertex AI endpoints.
+
+```yaml
+default_vertex_config:
+  vertex_project: "adroit-crow-413218"
+  vertex_location: "us-central1"
+  vertex_credentials: "/Users/ishaanjaffer/Downloads/adroit-crow-413218-a956eef1a2a8.json" # Add path to service account.json
+```
+
+#### 2. Start litellm proxy
+
+```shell
+litellm --config /path/to/config.yaml
+```
+
+#### 3. Test it
+
+```shell
+curl http://localhost:4000/vertex-ai/publishers/google/models/textembedding-gecko@001:countTokens \
+  -H "Content-Type: application/json" \
+  -H "Authorization: Bearer sk-1234" \
+  -d '{"instances":[{"content": "gm"}]}'
+```
+
+## Usage Examples
+
+### Gemini API (Generate Content)
+
+```shell
+curl http://localhost:4000/vertex-ai/publishers/google/models/gemini-1.5-flash-001:generateContent \
+  -H "Content-Type: application/json" \
+  -H "Authorization: Bearer sk-1234" \
+  -d '{"contents":[{"role": "user", "parts":[{"text": "hi"}]}]}'
+```
+
+### Embeddings API
+
+```shell
+curl http://localhost:4000/vertex-ai/publishers/google/models/textembedding-gecko@001:predict \
+  -H "Content-Type: application/json" \
+  -H "Authorization: Bearer sk-1234" \
+  -d '{"instances":[{"content": "gm"}]}'
+```
+
+### Imagen API
+
+```shell
+curl http://localhost:4000/vertex-ai/publishers/google/models/imagen-3.0-generate-001:predict \
+  -H "Content-Type: application/json" \
+  -H "Authorization: Bearer sk-1234" \
+  -d '{"instances":[{"prompt": "make an otter"}], "parameters": {"sampleCount": 1}}'
+```
+
+### Count Tokens API
+
+```shell
+curl http://localhost:4000/vertex-ai/publishers/google/models/gemini-1.5-flash-001:countTokens \
+  -H "Content-Type: application/json" \
+  -H "Authorization: Bearer sk-1234" \
+  -d '{"contents":[{"role": "user", "parts":[{"text": "hi"}]}]}'
+```
+
+### Tuning API
+
+Create a fine-tuning job:
+
+```shell
+curl http://localhost:4000/vertex-ai/tuningJobs \
+  -H "Content-Type: application/json" \
+  -H "Authorization: Bearer sk-1234" \
+  -d '{
+    "baseModel": "gemini-1.0-pro-002",
+    "supervisedTuningSpec" : {
+      "training_dataset_uri": "gs://cloud-samples-data/ai-platform/generative_ai/sft_train_data.jsonl"
+    }
+}'
+```
\ No newline at end of file
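The same pass-through calls work from Python; a minimal sketch of the Gemini generateContent example above using `httpx`, reusing the docs' placeholder base URL and `sk-1234` key:

```python
import httpx

# Placeholder proxy URL and key from the quickstart above
resp = httpx.post(
    "http://localhost:4000/vertex-ai/publishers/google/models/gemini-1.5-flash-001:generateContent",
    headers={
        "Authorization": "Bearer sk-1234",
        "Content-Type": "application/json",
    },
    json={"contents": [{"role": "user", "parts": [{"text": "hi"}]}]},
)
print(resp.json())
```
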
diff --git a/docs/my-website/sidebars.js b/docs/my-website/sidebars.js
index 27084f3b4..6f6bcfeea 100644
--- a/docs/my-website/sidebars.js
+++ b/docs/my-website/sidebars.js
@@ -178,7 +178,7 @@ const sidebars = {
       },
       {
         type: "category",
-        label: "Embedding(), Image Generation(), Assistants(), Moderation(), Audio Transcriptions(), TTS(), Batches(), Fine-Tuning()",
+        label: "Supported Endpoints - /images, /audio/speech, /assistants etc.",
        items: [
          "embedding/supported_embedding",
          "embedding/async_embedding",
@@ -189,7 +189,8 @@ const sidebars = {
          "assistants",
          "batches",
          "fine_tuning",
-          "anthropic_completion"
+          "anthropic_completion",
+          "vertex_ai"
        ],
      },
      {
diff --git a/litellm/llms/fine_tuning_apis/vertex_ai.py b/litellm/llms/fine_tuning_apis/vertex_ai.py
index f370652d2..5f96f0483 100644
--- a/litellm/llms/fine_tuning_apis/vertex_ai.py
+++ b/litellm/llms/fine_tuning_apis/vertex_ai.py
@@ -240,3 +240,59 @@ class VertexFineTuningAPI(VertexLLM):
             vertex_response
         )
         return open_ai_response
+
+    async def pass_through_vertex_ai_POST_request(
+        self,
+        request_data: dict,
+        vertex_project: str,
+        vertex_location: str,
+        vertex_credentials: str,
+        request_route: str,
+    ):
+        auth_header, _ = self._get_token_and_url(
+            model="",
+            gemini_api_key=None,
+            vertex_credentials=vertex_credentials,
+            vertex_project=vertex_project,
+            vertex_location=vertex_location,
+            stream=False,
+            custom_llm_provider="vertex_ai_beta",
+            api_base="",
+        )
+
+        headers = {
+            "Authorization": f"Bearer {auth_header}",
+            "Content-Type": "application/json",
+        }
+
+        url = None
+        if request_route == "/tuningJobs":
+            url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/tuningJobs"
+        elif "/tuningJobs/" in request_route and "cancel" in request_route:
+            # request_route already starts with /tuningJobs/, so append it to the
+            # location prefix directly (prefixing another /tuningJobs would
+            # duplicate the path segment)
+            url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}{request_route}"
+        elif any(
+            segment in request_route
+            for segment in (
+                "generateContent",
+                "predict",
+                "/batchPredictionJobs",
+                "countTokens",
+            )
+        ):
+            # all remaining supported routes are appended to the same
+            # project/location prefix
+            url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}{request_route}"
+        else:
+            raise ValueError(f"Unsupported Vertex AI request route: {request_route}")
+        if self.async_handler is None:
+            raise ValueError("VertexAI Fine Tuning - async_handler is not initialized")
+
+        response = await self.async_handler.post(
+            headers=headers,
+            url=url,
+            json=request_data,  # type: ignore
+        )
+
+        if response.status_code != 200:
+            raise Exception(
+                f"Error from Vertex AI request. Status code: {response.status_code}. Response: {response.text}"
+            )
+
+        response_json = response.json()
+        return response_json
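To make the route-to-URL mapping above easier to follow, here is a small standalone sketch; the project and location values are placeholders, not from the PR:

```python
def vertex_url(vertex_project: str, vertex_location: str, request_route: str) -> str:
    # Mirrors pass_through_vertex_ai_POST_request: every supported route is
    # appended to the same regional project/location prefix.
    prefix = (
        f"https://{vertex_location}-aiplatform.googleapis.com/v1"
        f"/projects/{vertex_project}/locations/{vertex_location}"
    )
    return f"{prefix}{request_route}"

# e.g. the cancel route for a hypothetical job "123":
print(vertex_url("my-project", "us-central1", "/tuningJobs/123:cancel"))
# https://us-central1-aiplatform.googleapis.com/v1/projects/my-project/locations/us-central1/tuningJobs/123:cancel
```
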
diff --git a/litellm/proxy/proxy_config.yaml b/litellm/proxy/proxy_config.yaml
index aa2bfc525..0750a3937 100644
--- a/litellm/proxy/proxy_config.yaml
+++ b/litellm/proxy/proxy_config.yaml
@@ -48,6 +48,11 @@ files_settings:
   - custom_llm_provider: openai
     api_key: os.environ/OPENAI_API_KEY
 
+default_vertex_config:
+  vertex_project: "adroit-crow-413218"
+  vertex_location: "us-central1"
+  vertex_credentials: "/Users/ishaanjaffer/Downloads/adroit-crow-413218-a956eef1a2a8.json"
+
 general_settings:
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index 0f57a5fd1..83126b954 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -213,6 +213,8 @@ from litellm.proxy.utils import (
     send_email,
     update_spend,
 )
+from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import router as vertex_router
+from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import set_default_vertex_config
 from litellm.router import (
     AssistantsTypedDict,
     Deployment,
@@ -1818,6 +1820,10 @@ class ProxyConfig:
         files_config = config.get("files_settings", None)
         set_files_config(config=files_config)
 
+        ## default config for vertex ai routes
+        default_vertex_config = config.get("default_vertex_config", None)
+        set_default_vertex_config(config=default_vertex_config)
+
         ## ROUTER SETTINGS (e.g. routing_strategy, ...)
         router_settings = config.get("router_settings", None)
         if router_settings and isinstance(router_settings, dict):
@@ -9631,6 +9637,7 @@ def cleanup_router_config_variables():
 
 app.include_router(router)
 app.include_router(fine_tuning_router)
+app.include_router(vertex_router)
 app.include_router(health_router)
 app.include_router(key_management_router)
 app.include_router(internal_user_router)
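For reference, `set_default_vertex_config` (added in `vertex_endpoints.py` below) resolves any config value prefixed with `os.environ/` through `litellm.get_secret`, so the service-account path does not have to be hard-coded in YAML. A minimal usage sketch; the env var name here is hypothetical:

```python
import os

from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import set_default_vertex_config

# Hypothetical env var holding the service-account JSON path
os.environ["VERTEX_SA_JSON_PATH"] = "/path/to/service_account.json"

set_default_vertex_config(
    {
        "vertex_project": "my-project",  # placeholder
        "vertex_location": "us-central1",
        # resolved to the env var's value via litellm.get_secret:
        "vertex_credentials": "os.environ/VERTEX_SA_JSON_PATH",
    }
)
```
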
diff --git a/litellm/proxy/vertex_ai_endpoints/vertex_endpoints.py b/litellm/proxy/vertex_ai_endpoints/vertex_endpoints.py
new file mode 100644
index 000000000..b8c04583c
--- /dev/null
+++ b/litellm/proxy/vertex_ai_endpoints/vertex_endpoints.py
@@ -0,0 +1,305 @@
+import ast
+import asyncio
+import json
+import traceback
+from datetime import datetime, timedelta, timezone
+from typing import List, Optional
+
+import fastapi
+import httpx
+from fastapi import (
+    APIRouter,
+    Depends,
+    File,
+    Form,
+    Header,
+    HTTPException,
+    Request,
+    Response,
+    UploadFile,
+    status,
+)
+
+import litellm
+from litellm._logging import verbose_proxy_logger
+from litellm.batches.main import FileObject
+from litellm.fine_tuning.main import vertex_fine_tuning_apis_instance
+from litellm.proxy._types import *
+from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
+
+router = APIRouter()
+default_vertex_config = None
+
+
+def set_default_vertex_config(config):
+    global default_vertex_config
+    if config is None:
+        return
+
+    if not isinstance(config, dict):
+        raise ValueError("invalid config, vertex default config must be a dictionary")
+
+    for key, value in config.items():
+        if isinstance(value, str) and value.startswith("os.environ/"):
+            config[key] = litellm.get_secret(value)
+
+    default_vertex_config = config
+
+
+def exception_handler(e: Exception):
+    verbose_proxy_logger.error(
+        "litellm.proxy.proxy_server.v1/projects/tuningJobs(): Exception occurred - {}".format(
+            str(e)
+        )
+    )
+    verbose_proxy_logger.debug(traceback.format_exc())
+    if isinstance(e, HTTPException):
+        return ProxyException(
+            message=getattr(e, "message", str(e.detail)),
+            type=getattr(e, "type", "None"),
+            param=getattr(e, "param", "None"),
+            code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
+        )
+    else:
+        error_msg = f"{str(e)}"
+        return ProxyException(
+            message=getattr(e, "message", error_msg),
+            type=getattr(e, "type", "None"),
+            param=getattr(e, "param", "None"),
+            code=getattr(e, "status_code", 500),
+        )
+
+
+async def execute_post_vertex_ai_request(
+    request: Request,
+    route: str,
+):
+    if default_vertex_config is None:
+        raise ValueError(
+            "Vertex credentials not added on litellm proxy, please add `default_vertex_config` on your config.yaml"
+        )
+    vertex_project = default_vertex_config.get("vertex_project", None)
+    vertex_location = default_vertex_config.get("vertex_location", None)
+    vertex_credentials = default_vertex_config.get("vertex_credentials", None)
+
+    request_data_json = {}
+    body = await request.body()
+    body_str = body.decode()
+    if len(body_str) > 0:
+        try:
+            request_data_json = ast.literal_eval(body_str)
+        except Exception:
+            request_data_json = json.loads(body_str)
+
+    verbose_proxy_logger.debug(
+        "Request received by LiteLLM:\n{}".format(
+            json.dumps(request_data_json, indent=4)
+        ),
+    )
+
+    response = (
+        await vertex_fine_tuning_apis_instance.pass_through_vertex_ai_POST_request(
+            request_data=request_data_json,
+            vertex_project=vertex_project,
+            vertex_location=vertex_location,
+            vertex_credentials=vertex_credentials,
+            request_route=route,
+        )
+    )
+
+    return response
+
+
+@router.post(
+    "/vertex-ai/publishers/google/models/{model_id:path}:generateContent",
+    dependencies=[Depends(user_api_key_auth)],
+    tags=["Vertex AI endpoints"],
+)
+async def vertex_generate_content(
+    request: Request,
+    fastapi_response: Response,
+    model_id: str,
+    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
+):
+    """
+    This is a pass-through endpoint for the Vertex AI API `/generateContent` endpoint.
+
+    Example Curl:
+    ```
+    curl http://localhost:4000/vertex-ai/publishers/google/models/gemini-1.5-flash-001:generateContent \
+      -H "Content-Type: application/json" \
+      -H "Authorization: Bearer sk-1234" \
+      -d '{"contents":[{"role": "user", "parts":[{"text": "hi"}]}]}'
+    ```
+
+    Vertex API Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference#rest
+
+    It uses the Vertex AI credentials configured on the proxy and forwards the request to the Vertex AI API.
+    """
+    try:
+        response = await execute_post_vertex_ai_request(
+            request=request,
+            route=f"/publishers/google/models/{model_id}:generateContent",
+        )
+        return response
+    except Exception as e:
+        raise exception_handler(e) from e
+
+
+@router.post(
+    "/vertex-ai/publishers/google/models/{model_id:path}:predict",
+    dependencies=[Depends(user_api_key_auth)],
+    tags=["Vertex AI endpoints"],
+)
+async def vertex_predict_endpoint(
+    request: Request,
+    fastapi_response: Response,
+    model_id: str,
+    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
+):
+    """
+    This is a pass-through endpoint for the Vertex AI API `/predict` endpoint.
+    Use this for:
+    - Embeddings API - Text Embedding, Multi-Modal Embedding
+    - Imagen API
+    - Code Completion API
+
+    Example Curl:
+    ```
+    curl http://localhost:4000/vertex-ai/publishers/google/models/textembedding-gecko@001:predict \
+      -H "Content-Type: application/json" \
+      -H "Authorization: Bearer sk-1234" \
+      -d '{"instances":[{"content": "gm"}]}'
+    ```
+
+    Vertex API Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/text-embeddings-api#generative-ai-get-text-embedding-drest
+
+    It uses the Vertex AI credentials configured on the proxy and forwards the request to the Vertex AI API.
+    """
+    try:
+        response = await execute_post_vertex_ai_request(
+            request=request,
+            route=f"/publishers/google/models/{model_id}:predict",
+        )
+        return response
+    except Exception as e:
+        raise exception_handler(e) from e
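+
+# Usage sketch (illustrative, not part of the router): the pass-through endpoints
+# above can be exercised from Python instead of curl, e.g. the embeddings /predict
+# route. The proxy URL and key are the placeholder values from the docstrings:
+#
+#   import httpx
+#
+#   resp = httpx.post(
+#       "http://localhost:4000/vertex-ai/publishers/google/models/textembedding-gecko@001:predict",
+#       headers={"Authorization": "Bearer sk-1234"},
+#       json={"instances": [{"content": "gm"}]},
+#   )
+#   print(resp.json())
+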
+
+@router.post(
+    "/vertex-ai/publishers/google/models/{model_id:path}:countTokens",
+    dependencies=[Depends(user_api_key_auth)],
+    tags=["Vertex AI endpoints"],
+)
+async def vertex_countTokens_endpoint(
+    request: Request,
+    fastapi_response: Response,
+    model_id: str,
+    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
+):
+    """
+    This is a pass-through endpoint for the Vertex AI API `/countTokens` endpoint.
+
+    Vertex API Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/count-tokens#curl
+
+    Example Curl:
+    ```
+    curl http://localhost:4000/vertex-ai/publishers/google/models/gemini-1.5-flash-001:countTokens \
+      -H "Content-Type: application/json" \
+      -H "Authorization: Bearer sk-1234" \
+      -d '{"contents":[{"role": "user", "parts":[{"text": "hi"}]}]}'
+    ```
+
+    It uses the Vertex AI credentials configured on the proxy and forwards the request to the Vertex AI API.
+    """
+    try:
+        response = await execute_post_vertex_ai_request(
+            request=request,
+            route=f"/publishers/google/models/{model_id}:countTokens",
+        )
+        return response
+    except Exception as e:
+        raise exception_handler(e) from e
+
+
+@router.post(
+    "/vertex-ai/batchPredictionJobs",
+    dependencies=[Depends(user_api_key_auth)],
+    tags=["Vertex AI endpoints"],
+)
+async def vertex_create_batch_prediction_job(
+    request: Request,
+    fastapi_response: Response,
+    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
+):
+    """
+    This is a pass-through endpoint for the Vertex AI API `/batchPredictionJobs` endpoint.
+
+    Vertex API Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/batch-prediction-api#syntax
+
+    It uses the Vertex AI credentials configured on the proxy and forwards the request to the Vertex AI API.
+    """
+    try:
+        response = await execute_post_vertex_ai_request(
+            request=request,
+            route="/batchPredictionJobs",
+        )
+        return response
+    except Exception as e:
+        raise exception_handler(e) from e
+
+
+@router.post(
+    "/vertex-ai/tuningJobs",
+    dependencies=[Depends(user_api_key_auth)],
+    tags=["Vertex AI endpoints"],
+)
+async def vertex_create_fine_tuning_job(
+    request: Request,
+    fastapi_response: Response,
+    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
+):
+    """
+    This is a pass-through endpoint for the Vertex AI API `/tuningJobs` endpoint.
+
+    Vertex API Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/tuning
+
+    It uses the Vertex AI credentials configured on the proxy and forwards the request to the Vertex AI API.
+    """
+    try:
+        response = await execute_post_vertex_ai_request(
+            request=request,
+            route="/tuningJobs",
+        )
+        return response
+    except Exception as e:
+        raise exception_handler(e) from e
+
+
+@router.post(
+    "/vertex-ai/tuningJobs/{job_id:path}:cancel",
+    dependencies=[Depends(user_api_key_auth)],
+    tags=["Vertex AI endpoints"],
+)
+async def vertex_cancel_fine_tuning_job(
+    request: Request,
+    job_id: str,
+    fastapi_response: Response,
+    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
+):
+    """
+    This is a pass-through endpoint for the Vertex AI API `tuningJobs/{job_id:path}:cancel` endpoint.
+
+    Vertex API Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/tuning#cancel_a_tuning_job
+
+    It uses the Vertex AI credentials configured on the proxy and forwards the request to the Vertex AI API.
+    """
+    try:
+        response = await execute_post_vertex_ai_request(
+            request=request,
+            route=f"/tuningJobs/{job_id}:cancel",
+        )
+        return response
+    except Exception as e:
+        raise exception_handler(e) from e
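The cancel endpoint above has no docs example yet; a minimal `httpx` sketch, reusing the quickstart's placeholder proxy URL and key, with a hypothetical job ID:

```python
import httpx

# Hypothetical tuning-job ID, as returned when creating a job via /vertex-ai/tuningJobs
job_id = "123456789"

resp = httpx.post(
    f"http://localhost:4000/vertex-ai/tuningJobs/{job_id}:cancel",
    headers={"Authorization": "Bearer sk-1234"},
)
print(resp.status_code, resp.json())
```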