mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 18:54:30 +00:00
refactor vertex endpoints to pass through all routes
This commit is contained in:
parent
f947cec7fc
commit
0e1d3804ff
3 changed files with 70 additions and 259 deletions
|
@ -25,6 +25,9 @@ from litellm.batches.main import FileObject
|
|||
from litellm.fine_tuning.main import vertex_fine_tuning_apis_instance
|
||||
from litellm.proxy._types import *
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
|
||||
create_pass_through_route,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
default_vertex_config = None
|
||||
|
@ -70,10 +73,17 @@ def exception_handler(e: Exception):
|
|||
)
|
||||
|
||||
|
||||
async def execute_post_vertex_ai_request(
|
||||
@router.api_route(
|
||||
"/vertex-ai/{endpoint:path}", methods=["GET", "POST", "PUT", "DELETE"]
|
||||
)
|
||||
async def vertex_proxy_route(
|
||||
endpoint: str,
|
||||
request: Request,
|
||||
route: str,
|
||||
fastapi_response: Response,
|
||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
):
|
||||
encoded_endpoint = httpx.URL(endpoint).path
|
||||
|
||||
from litellm.fine_tuning.main import vertex_fine_tuning_apis_instance
|
||||
|
||||
if default_vertex_config is None:
|
||||
|
@ -83,250 +93,52 @@ async def execute_post_vertex_ai_request(
|
|||
vertex_project = default_vertex_config.get("vertex_project", None)
|
||||
vertex_location = default_vertex_config.get("vertex_location", None)
|
||||
vertex_credentials = default_vertex_config.get("vertex_credentials", None)
|
||||
base_target_url = f"https://{vertex_location}-aiplatform.googleapis.com/"
|
||||
|
||||
request_data_json = {}
|
||||
body = await request.body()
|
||||
body_str = body.decode()
|
||||
if len(body_str) > 0:
|
||||
try:
|
||||
request_data_json = ast.literal_eval(body_str)
|
||||
except:
|
||||
request_data_json = json.loads(body_str)
|
||||
|
||||
verbose_proxy_logger.debug(
|
||||
"Request received by LiteLLM:\n{}".format(
|
||||
json.dumps(request_data_json, indent=4)
|
||||
),
|
||||
auth_header, _ = vertex_fine_tuning_apis_instance._get_token_and_url(
|
||||
model="",
|
||||
gemini_api_key=None,
|
||||
vertex_credentials=vertex_credentials,
|
||||
vertex_project=vertex_project,
|
||||
vertex_location=vertex_location,
|
||||
stream=False,
|
||||
custom_llm_provider="vertex_ai_beta",
|
||||
api_base="",
|
||||
)
|
||||
|
||||
response = (
|
||||
await vertex_fine_tuning_apis_instance.pass_through_vertex_ai_POST_request(
|
||||
request_data=request_data_json,
|
||||
vertex_project=vertex_project,
|
||||
vertex_location=vertex_location,
|
||||
vertex_credentials=vertex_credentials,
|
||||
request_route=route,
|
||||
)
|
||||
headers = {
|
||||
"Authorization": f"Bearer {auth_header}",
|
||||
}
|
||||
|
||||
request_route = encoded_endpoint
|
||||
verbose_proxy_logger.debug("request_route %s", request_route)
|
||||
|
||||
# Ensure endpoint starts with '/' for proper URL construction
|
||||
if not encoded_endpoint.startswith("/"):
|
||||
encoded_endpoint = "/" + encoded_endpoint
|
||||
|
||||
# Construct the full target URL using httpx
|
||||
base_url = httpx.URL(base_target_url)
|
||||
updated_url = base_url.copy_with(path=encoded_endpoint)
|
||||
|
||||
verbose_proxy_logger.debug("updated url %s", updated_url)
|
||||
|
||||
## check for streaming
|
||||
is_streaming_request = False
|
||||
if "stream" in str(updated_url):
|
||||
is_streaming_request = True
|
||||
|
||||
## CREATE PASS-THROUGH
|
||||
endpoint_func = create_pass_through_route(
|
||||
endpoint=endpoint,
|
||||
target=str(updated_url),
|
||||
custom_headers=headers,
|
||||
) # dynamically construct pass-through endpoint based on incoming path
|
||||
received_value = await endpoint_func(
|
||||
request,
|
||||
fastapi_response,
|
||||
user_api_key_dict,
|
||||
stream=is_streaming_request,
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
@router.post(
    "/vertex-ai/publishers/google/models/{model_id:path}:generateContent",
    dependencies=[Depends(user_api_key_auth)],
    tags=["Vertex AI endpoints"],
)
async def vertex_generate_content(
    request: Request,
    fastapi_response: Response,
    model_id: str,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Pass-through endpoint for the Vertex AI `generateContent` API.

    The incoming request is handed to `execute_post_vertex_ai_request`
    together with the route
    `/publishers/google/models/{model_id}:generateContent`; that helper
    uses the Vertex AI credentials configured on the proxy and forwards
    the call to the Vertex AI API.

    Example curl:
    ```
    curl http://localhost:4000/vertex-ai/publishers/google/models/gemini-1.5-flash-001:generateContent \
      -H "Content-Type: application/json" \
      -H "Authorization: Bearer sk-1234" \
      -d '{"contents":[{"role": "user", "parts":[{"text": "hi"}]}]}'
    ```

    Vertex API Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference#rest

    Any exception raised while forwarding is converted by
    `exception_handler` and re-raised, chained to the original error.
    """
    try:
        # Build the upstream route from the captured model id and forward.
        upstream_route = f"/publishers/google/models/{model_id}:generateContent"
        return await execute_post_vertex_ai_request(
            request=request,
            route=upstream_route,
        )
    except Exception as err:
        raise exception_handler(err) from err
|
||||
|
||||
|
||||
@router.post(
    "/vertex-ai/publishers/google/models/{model_id:path}:predict",
    dependencies=[Depends(user_api_key_auth)],
    tags=["Vertex AI endpoints"],
)
async def vertex_predict_endpoint(
    request: Request,
    fastapi_response: Response,
    model_id: str,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Pass-through endpoint for the Vertex AI `predict` API.

    Use this for:
    - Embeddings API - Text Embedding, Multi Modal Embedding
    - Imagen API
    - Code Completion API

    The incoming request is handed to `execute_post_vertex_ai_request`
    together with the route `/publishers/google/models/{model_id}:predict`;
    that helper uses the Vertex AI credentials configured on the proxy and
    forwards the call to the Vertex AI API.

    Example curl:
    ```
    curl http://localhost:4000/vertex-ai/publishers/google/models/textembedding-gecko@001:predict \
      -H "Content-Type: application/json" \
      -H "Authorization: Bearer sk-1234" \
      -d '{"instances":[{"content": "gm"}]}'
    ```

    Vertex API Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/text-embeddings-api#generative-ai-get-text-embedding-drest

    Any exception raised while forwarding is converted by
    `exception_handler` and re-raised, chained to the original error.
    """
    try:
        # Build the upstream route from the captured model id and forward.
        upstream_route = f"/publishers/google/models/{model_id}:predict"
        return await execute_post_vertex_ai_request(
            request=request,
            route=upstream_route,
        )
    except Exception as err:
        raise exception_handler(err) from err
|
||||
|
||||
|
||||
@router.post(
    "/vertex-ai/publishers/google/models/{model_id:path}:countTokens",
    dependencies=[Depends(user_api_key_auth)],
    tags=["Vertex AI endpoints"],
)
async def vertex_countTokens_endpoint(
    request: Request,
    fastapi_response: Response,
    model_id: str,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Pass-through endpoint for the Vertex AI `countTokens` API.

    Vertex API Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/count-tokens#curl

    The incoming request is handed to `execute_post_vertex_ai_request`
    together with the route
    `/publishers/google/models/{model_id}:countTokens`; that helper uses
    the Vertex AI credentials configured on the proxy and forwards the
    call to the Vertex AI API.

    Example curl:
    ```
    curl http://localhost:4000/vertex-ai/publishers/google/models/gemini-1.5-flash-001:countTokens \
      -H "Content-Type: application/json" \
      -H "Authorization: Bearer sk-1234" \
      -d '{"contents":[{"role": "user", "parts":[{"text": "hi"}]}]}'
    ```

    Any exception raised while forwarding is converted by
    `exception_handler` and re-raised, chained to the original error.
    """
    try:
        # Build the upstream route from the captured model id and forward.
        upstream_route = f"/publishers/google/models/{model_id}:countTokens"
        return await execute_post_vertex_ai_request(
            request=request,
            route=upstream_route,
        )
    except Exception as err:
        raise exception_handler(err) from err
|
||||
|
||||
|
||||
@router.post(
    "/vertex-ai/batchPredictionJobs",
    dependencies=[Depends(user_api_key_auth)],
    tags=["Vertex AI endpoints"],
)
async def vertex_create_batch_prediction_job(
    request: Request,
    fastapi_response: Response,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Pass-through endpoint for the Vertex AI `batchPredictionJobs` API.

    Vertex API Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/batch-prediction-api#syntax

    The incoming request is handed to `execute_post_vertex_ai_request`
    with the route `/batchPredictionJobs`; that helper uses the Vertex AI
    credentials configured on the proxy and forwards the call to the
    Vertex AI API.

    Any exception raised while forwarding is converted by
    `exception_handler` and re-raised, chained to the original error.
    """
    try:
        return await execute_post_vertex_ai_request(
            request=request,
            route="/batchPredictionJobs",
        )
    except Exception as err:
        raise exception_handler(err) from err
|
||||
|
||||
|
||||
@router.post(
    "/vertex-ai/tuningJobs",
    dependencies=[Depends(user_api_key_auth)],
    tags=["Vertex AI endpoints"],
)
async def vertex_create_fine_tuning_job(
    request: Request,
    fastapi_response: Response,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Pass-through endpoint for the Vertex AI `tuningJobs` API.

    Vertex API Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/tuning

    The incoming request is handed to `execute_post_vertex_ai_request`
    with the route `/tuningJobs`; that helper uses the Vertex AI
    credentials configured on the proxy and forwards the call to the
    Vertex AI API.

    Any exception raised while forwarding is converted by
    `exception_handler` and re-raised, chained to the original error.
    """
    try:
        return await execute_post_vertex_ai_request(
            request=request,
            route="/tuningJobs",
        )
    except Exception as err:
        raise exception_handler(err) from err
|
||||
|
||||
|
||||
@router.post(
    "/vertex-ai/tuningJobs/{job_id:path}:cancel",
    dependencies=[Depends(user_api_key_auth)],
    tags=["Vertex AI endpoints"],
)
async def vertex_cancel_fine_tuning_job(
    request: Request,
    job_id: str,
    fastapi_response: Response,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Pass-through endpoint for the Vertex AI `tuningJobs/{job_id}:cancel` API.

    Vertex API Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/tuning#cancel_a_tuning_job

    The incoming request is handed to `execute_post_vertex_ai_request`
    together with the route `/tuningJobs/{job_id}:cancel`; that helper
    uses the Vertex AI credentials configured on the proxy and forwards
    the call to the Vertex AI API.

    Any exception raised while forwarding is converted by
    `exception_handler` and re-raised, chained to the original error.
    """
    try:
        # Build the upstream route from the captured job id and forward.
        upstream_route = f"/tuningJobs/{job_id}:cancel"
        return await execute_post_vertex_ai_request(
            request=request,
            route=upstream_route,
        )
    except Exception as err:
        raise exception_handler(err) from err
|
||||
|
||||
|
||||
@router.post(
    "/vertex-ai/cachedContents",
    dependencies=[Depends(user_api_key_auth)],
    tags=["Vertex AI endpoints"],
)
async def vertex_create_add_cached_content(
    request: Request,
    fastapi_response: Response,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Pass-through endpoint for the Vertex AI `cachedContents` API.

    Vertex API Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/context-cache/context-cache-create#create-context-cache-sample-drest

    The incoming request is handed to `execute_post_vertex_ai_request`
    with the route `/cachedContents`; that helper uses the Vertex AI
    credentials configured on the proxy and forwards the call to the
    Vertex AI API.

    Any exception raised while forwarding is converted by
    `exception_handler` and re-raised, chained to the original error.
    """
    try:
        return await execute_post_vertex_ai_request(
            request=request,
            route="/cachedContents",
        )
    except Exception as err:
        raise exception_handler(err) from err
|
||||
return received_value
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue