refactor vertex endpoints to pass through all routes

This commit is contained in:
Ishaan Jaff 2024-08-21 17:08:42 -07:00
parent f947cec7fc
commit 0e1d3804ff
3 changed files with 70 additions and 259 deletions

View file

@ -25,6 +25,9 @@ from litellm.batches.main import FileObject
from litellm.fine_tuning.main import vertex_fine_tuning_apis_instance
from litellm.proxy._types import *
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
create_pass_through_route,
)
router = APIRouter()
default_vertex_config = None
@ -70,10 +73,17 @@ def exception_handler(e: Exception):
)
async def execute_post_vertex_ai_request(
@router.api_route(
"/vertex-ai/{endpoint:path}", methods=["GET", "POST", "PUT", "DELETE"]
)
async def vertex_proxy_route(
endpoint: str,
request: Request,
route: str,
fastapi_response: Response,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
encoded_endpoint = httpx.URL(endpoint).path
from litellm.fine_tuning.main import vertex_fine_tuning_apis_instance
if default_vertex_config is None:
@ -83,250 +93,52 @@ async def execute_post_vertex_ai_request(
vertex_project = default_vertex_config.get("vertex_project", None)
vertex_location = default_vertex_config.get("vertex_location", None)
vertex_credentials = default_vertex_config.get("vertex_credentials", None)
base_target_url = f"https://{vertex_location}-aiplatform.googleapis.com/"
request_data_json = {}
body = await request.body()
body_str = body.decode()
if len(body_str) > 0:
try:
request_data_json = ast.literal_eval(body_str)
except:
request_data_json = json.loads(body_str)
verbose_proxy_logger.debug(
"Request received by LiteLLM:\n{}".format(
json.dumps(request_data_json, indent=4)
),
auth_header, _ = vertex_fine_tuning_apis_instance._get_token_and_url(
model="",
gemini_api_key=None,
vertex_credentials=vertex_credentials,
vertex_project=vertex_project,
vertex_location=vertex_location,
stream=False,
custom_llm_provider="vertex_ai_beta",
api_base="",
)
response = (
await vertex_fine_tuning_apis_instance.pass_through_vertex_ai_POST_request(
request_data=request_data_json,
vertex_project=vertex_project,
vertex_location=vertex_location,
vertex_credentials=vertex_credentials,
request_route=route,
)
headers = {
"Authorization": f"Bearer {auth_header}",
}
request_route = encoded_endpoint
verbose_proxy_logger.debug("request_route %s", request_route)
# Ensure endpoint starts with '/' for proper URL construction
if not encoded_endpoint.startswith("/"):
encoded_endpoint = "/" + encoded_endpoint
# Construct the full target URL using httpx
base_url = httpx.URL(base_target_url)
updated_url = base_url.copy_with(path=encoded_endpoint)
verbose_proxy_logger.debug("updated url %s", updated_url)
## check for streaming
is_streaming_request = False
if "stream" in str(updated_url):
is_streaming_request = True
## CREATE PASS-THROUGH
endpoint_func = create_pass_through_route(
endpoint=endpoint,
target=str(updated_url),
custom_headers=headers,
) # dynamically construct pass-through endpoint based on incoming path
received_value = await endpoint_func(
request,
fastapi_response,
user_api_key_dict,
stream=is_streaming_request,
)
return response
@router.post(
"/vertex-ai/publishers/google/models/{model_id:path}:generateContent",
dependencies=[Depends(user_api_key_auth)],
tags=["Vertex AI endpoints"],
)
async def vertex_generate_content(
request: Request,
fastapi_response: Response,
model_id: str,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
this is a pass through endpoint for the Vertex AI API. /generateContent endpoint
Example Curl:
```
curl http://localhost:4000/vertex-ai/publishers/google/models/gemini-1.5-flash-001:generateContent \
-H "Content-Type: application/json" \
-H "Authorization: Bearer sk-1234" \
-d '{"contents":[{"role": "user", "parts":[{"text": "hi"}]}]}'
```
Vertex API Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference#rest
it uses the vertex ai credentials on the proxy and forwards to vertex ai api
"""
try:
response = await execute_post_vertex_ai_request(
request=request,
route=f"/publishers/google/models/{model_id}:generateContent",
)
return response
except Exception as e:
raise exception_handler(e) from e
@router.post(
"/vertex-ai/publishers/google/models/{model_id:path}:predict",
dependencies=[Depends(user_api_key_auth)],
tags=["Vertex AI endpoints"],
)
async def vertex_predict_endpoint(
request: Request,
fastapi_response: Response,
model_id: str,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
this is a pass through endpoint for the Vertex AI API. /predict endpoint
Use this for:
- Embeddings API - Text Embedding, Multi Modal Embedding
- Imagen API
- Code Completion API
Example Curl:
```
curl http://localhost:4000/vertex-ai/publishers/google/models/textembedding-gecko@001:predict \
-H "Content-Type: application/json" \
-H "Authorization: Bearer sk-1234" \
-d '{"instances":[{"content": "gm"}]}'
```
Vertex API Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/text-embeddings-api#generative-ai-get-text-embedding-drest
it uses the vertex ai credentials on the proxy and forwards to vertex ai api
"""
try:
response = await execute_post_vertex_ai_request(
request=request,
route=f"/publishers/google/models/{model_id}:predict",
)
return response
except Exception as e:
raise exception_handler(e) from e
@router.post(
"/vertex-ai/publishers/google/models/{model_id:path}:countTokens",
dependencies=[Depends(user_api_key_auth)],
tags=["Vertex AI endpoints"],
)
async def vertex_countTokens_endpoint(
request: Request,
fastapi_response: Response,
model_id: str,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
this is a pass through endpoint for the Vertex AI API. /countTokens endpoint
https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/count-tokens#curl
Example Curl:
```
curl http://localhost:4000/vertex-ai/publishers/google/models/gemini-1.5-flash-001:countTokens \
-H "Content-Type: application/json" \
-H "Authorization: Bearer sk-1234" \
-d '{"contents":[{"role": "user", "parts":[{"text": "hi"}]}]}'
```
it uses the vertex ai credentials on the proxy and forwards to vertex ai api
"""
try:
response = await execute_post_vertex_ai_request(
request=request,
route=f"/publishers/google/models/{model_id}:countTokens",
)
return response
except Exception as e:
raise exception_handler(e) from e
@router.post(
"/vertex-ai/batchPredictionJobs",
dependencies=[Depends(user_api_key_auth)],
tags=["Vertex AI endpoints"],
)
async def vertex_create_batch_prediction_job(
request: Request,
fastapi_response: Response,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
this is a pass through endpoint for the Vertex AI API. /batchPredictionJobs endpoint
Vertex API Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/batch-prediction-api#syntax
it uses the vertex ai credentials on the proxy and forwards to vertex ai api
"""
try:
response = await execute_post_vertex_ai_request(
request=request,
route="/batchPredictionJobs",
)
return response
except Exception as e:
raise exception_handler(e) from e
@router.post(
"/vertex-ai/tuningJobs",
dependencies=[Depends(user_api_key_auth)],
tags=["Vertex AI endpoints"],
)
async def vertex_create_fine_tuning_job(
request: Request,
fastapi_response: Response,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
this is a pass through endpoint for the Vertex AI API. /tuningJobs endpoint
Vertex API Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/tuning
it uses the vertex ai credentials on the proxy and forwards to vertex ai api
"""
try:
response = await execute_post_vertex_ai_request(
request=request,
route="/tuningJobs",
)
return response
except Exception as e:
raise exception_handler(e) from e
@router.post(
"/vertex-ai/tuningJobs/{job_id:path}:cancel",
dependencies=[Depends(user_api_key_auth)],
tags=["Vertex AI endpoints"],
)
async def vertex_cancel_fine_tuning_job(
request: Request,
job_id: str,
fastapi_response: Response,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
this is a pass through endpoint for the Vertex AI API. tuningJobs/{job_id:path}:cancel
Vertex API Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/tuning#cancel_a_tuning_job
it uses the vertex ai credentials on the proxy and forwards to vertex ai api
"""
try:
response = await execute_post_vertex_ai_request(
request=request,
route=f"/tuningJobs/{job_id}:cancel",
)
return response
except Exception as e:
raise exception_handler(e) from e
@router.post(
"/vertex-ai/cachedContents",
dependencies=[Depends(user_api_key_auth)],
tags=["Vertex AI endpoints"],
)
async def vertex_create_add_cached_content(
request: Request,
fastapi_response: Response,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
this is a pass through endpoint for the Vertex AI API. /cachedContents endpoint
Vertex API Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/context-cache/context-cache-create#create-context-cache-sample-drest
it uses the vertex ai credentials on the proxy and forwards to vertex ai api
"""
try:
response = await execute_post_vertex_ai_request(
request=request,
route="/cachedContents",
)
return response
except Exception as e:
raise exception_handler(e) from e
return received_value