add vertex generateContent

This commit is contained in:
Ishaan Jaff 2024-08-03 17:17:54 -07:00
parent a293bebdb1
commit beb7f1b2c6
2 changed files with 86 additions and 1 deletions

View file

@ -1,3 +1,4 @@
import ast
import asyncio
import traceback
from datetime import datetime, timedelta, timezone
@ -78,7 +79,21 @@ async def execute_post_vertex_ai_request(
vertex_project = default_vertex_config.get("vertex_project", None)
vertex_location = default_vertex_config.get("vertex_location", None)
vertex_credentials = default_vertex_config.get("vertex_credentials", None)
request_data_json = await request.json()
request_data_json = {}
body = await request.body()
body_str = body.decode()
if len(body_str) > 0:
try:
request_data_json = ast.literal_eval(body_str)
except:
request_data_json = json.loads(body_str)
verbose_proxy_logger.debug(
"Request received by LiteLLM:\n{}".format(
json.dumps(request_data_json, indent=4)
),
)
response = (
await vertex_fine_tuning_apis_instance.pass_through_vertex_ai_POST_request(
@ -93,6 +108,41 @@ async def execute_post_vertex_ai_request(
return response
@router.post(
"/vertex-ai/publishers/google/models/{model_id:path}:generateContent",
dependencies=[Depends(user_api_key_auth)],
tags=["Vertex AI endpoints"],
)
async def vertex_generate_content(
request: Request,
fastapi_response: Response,
model_id: str,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
this is a pass through endpoint for the Vertex AI API. /generateContent endpoint
Example Curl:
```
curl http://localhost:4000/vertex-ai/publishers/google/models/gemini-1.5-flash-001:generateContent \
-H "Content-Type: application/json" \
-H "Authorization: Bearer sk-1234" \
-d '{"contents":[{"role": "user", "parts":[{"text": "hi"}]}]}'
```
Vertex API Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference#rest
it uses the vertex ai credentials on the proxy and forwards to vertex ai api
"""
try:
response = await execute_post_vertex_ai_request(
request=request,
route=f"/publishers/google/models/{model_id}:generateContent",
)
return response
except Exception as e:
raise exception_handler(e) from e
@router.post(
"/vertex-ai/tuningJobs",
dependencies=[Depends(user_api_key_auth)],
@ -106,6 +156,8 @@ async def vertex_create_fine_tuning_job(
"""
this is a pass through endpoint for the Vertex AI API. /tuningJobs endpoint
Vertex API Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/tuning
it uses the vertex ai credentials on the proxy and forwards to vertex ai api
"""
try:
@ -116,3 +168,32 @@ async def vertex_create_fine_tuning_job(
return response
except Exception as e:
raise exception_handler(e) from e
@router.post(
"/vertex-ai/tuningJobs/{job_id:path}:cancel",
dependencies=[Depends(user_api_key_auth)],
tags=["Vertex AI endpoints"],
)
async def vertex_cancel_fine_tuning_job(
request: Request,
job_id: str,
fastapi_response: Response,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
this is a pass through endpoint for the Vertex AI API. tuningJobs/{job_id:path}:cancel
Vertex API Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/tuning#cancel_a_tuning_job
it uses the vertex ai credentials on the proxy and forwards to vertex ai api
"""
try:
response = await execute_post_vertex_ai_request(
request=request,
route=f"/tuningJobs/{job_id}:cancel",
)
return response
except Exception as e:
raise exception_handler(e) from e