mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 11:14:04 +00:00
refactor vertex endpoints to pass through all routes
This commit is contained in:
parent
f947cec7fc
commit
0e1d3804ff
3 changed files with 70 additions and 259 deletions
|
@ -120,6 +120,8 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
|
||||||
max_parallel_requests = user_api_key_dict.max_parallel_requests
|
max_parallel_requests = user_api_key_dict.max_parallel_requests
|
||||||
if max_parallel_requests is None:
|
if max_parallel_requests is None:
|
||||||
max_parallel_requests = sys.maxsize
|
max_parallel_requests = sys.maxsize
|
||||||
|
if data is None:
|
||||||
|
data = {}
|
||||||
global_max_parallel_requests = data.get("metadata", {}).get(
|
global_max_parallel_requests = data.get("metadata", {}).get(
|
||||||
"global_max_parallel_requests", None
|
"global_max_parallel_requests", None
|
||||||
)
|
)
|
||||||
|
|
|
@ -1,18 +1,15 @@
|
||||||
model_list:
|
model_list:
|
||||||
- model_name: fake-openai-endpoint
|
- model_name: multimodalembedding@001
|
||||||
litellm_params:
|
litellm_params:
|
||||||
model: openai/fake
|
model: vertex_ai/multimodalembedding@001
|
||||||
api_key: fake-key
|
vertex_project: "adroit-crow-413218"
|
||||||
api_base: https://exampleopenaiendpoint-production.up.railway.app/
|
vertex_location: "us-central1"
|
||||||
|
vertex_credentials: adroit-crow-413218-a956eef1a2a8.json
|
||||||
|
|
||||||
guardrails:
|
default_vertex_config:
|
||||||
- guardrail_name: "lakera-pre-guard"
|
vertex_project: "adroit-crow-413218"
|
||||||
litellm_params:
|
vertex_location: "us-central1"
|
||||||
guardrail: lakera # supported values: "aporia", "bedrock", "lakera"
|
vertex_credentials: adroit-crow-413218-a956eef1a2a8.json
|
||||||
mode: "during_call"
|
|
||||||
api_key: os.environ/LAKERA_API_KEY
|
|
||||||
api_base: os.environ/LAKERA_API_BASE
|
|
||||||
category_thresholds:
|
|
||||||
prompt_injection: 0.1
|
|
||||||
jailbreak: 0.1
|
|
||||||
|
|
||||||
|
litellm_settings:
|
||||||
|
drop_params: True
|
|
@ -25,6 +25,9 @@ from litellm.batches.main import FileObject
|
||||||
from litellm.fine_tuning.main import vertex_fine_tuning_apis_instance
|
from litellm.fine_tuning.main import vertex_fine_tuning_apis_instance
|
||||||
from litellm.proxy._types import *
|
from litellm.proxy._types import *
|
||||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||||
|
from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
|
||||||
|
create_pass_through_route,
|
||||||
|
)
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
default_vertex_config = None
|
default_vertex_config = None
|
||||||
|
@ -70,10 +73,17 @@ def exception_handler(e: Exception):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def execute_post_vertex_ai_request(
|
@router.api_route(
|
||||||
|
"/vertex-ai/{endpoint:path}", methods=["GET", "POST", "PUT", "DELETE"]
|
||||||
|
)
|
||||||
|
async def vertex_proxy_route(
|
||||||
|
endpoint: str,
|
||||||
request: Request,
|
request: Request,
|
||||||
route: str,
|
fastapi_response: Response,
|
||||||
|
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||||
):
|
):
|
||||||
|
encoded_endpoint = httpx.URL(endpoint).path
|
||||||
|
|
||||||
from litellm.fine_tuning.main import vertex_fine_tuning_apis_instance
|
from litellm.fine_tuning.main import vertex_fine_tuning_apis_instance
|
||||||
|
|
||||||
if default_vertex_config is None:
|
if default_vertex_config is None:
|
||||||
|
@ -83,250 +93,52 @@ async def execute_post_vertex_ai_request(
|
||||||
vertex_project = default_vertex_config.get("vertex_project", None)
|
vertex_project = default_vertex_config.get("vertex_project", None)
|
||||||
vertex_location = default_vertex_config.get("vertex_location", None)
|
vertex_location = default_vertex_config.get("vertex_location", None)
|
||||||
vertex_credentials = default_vertex_config.get("vertex_credentials", None)
|
vertex_credentials = default_vertex_config.get("vertex_credentials", None)
|
||||||
|
base_target_url = f"https://{vertex_location}-aiplatform.googleapis.com/"
|
||||||
|
|
||||||
request_data_json = {}
|
auth_header, _ = vertex_fine_tuning_apis_instance._get_token_and_url(
|
||||||
body = await request.body()
|
model="",
|
||||||
body_str = body.decode()
|
gemini_api_key=None,
|
||||||
if len(body_str) > 0:
|
vertex_credentials=vertex_credentials,
|
||||||
try:
|
vertex_project=vertex_project,
|
||||||
request_data_json = ast.literal_eval(body_str)
|
vertex_location=vertex_location,
|
||||||
except:
|
stream=False,
|
||||||
request_data_json = json.loads(body_str)
|
custom_llm_provider="vertex_ai_beta",
|
||||||
|
api_base="",
|
||||||
verbose_proxy_logger.debug(
|
|
||||||
"Request received by LiteLLM:\n{}".format(
|
|
||||||
json.dumps(request_data_json, indent=4)
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
response = (
|
headers = {
|
||||||
await vertex_fine_tuning_apis_instance.pass_through_vertex_ai_POST_request(
|
"Authorization": f"Bearer {auth_header}",
|
||||||
request_data=request_data_json,
|
}
|
||||||
vertex_project=vertex_project,
|
|
||||||
vertex_location=vertex_location,
|
request_route = encoded_endpoint
|
||||||
vertex_credentials=vertex_credentials,
|
verbose_proxy_logger.debug("request_route %s", request_route)
|
||||||
request_route=route,
|
|
||||||
)
|
# Ensure endpoint starts with '/' for proper URL construction
|
||||||
|
if not encoded_endpoint.startswith("/"):
|
||||||
|
encoded_endpoint = "/" + encoded_endpoint
|
||||||
|
|
||||||
|
# Construct the full target URL using httpx
|
||||||
|
base_url = httpx.URL(base_target_url)
|
||||||
|
updated_url = base_url.copy_with(path=encoded_endpoint)
|
||||||
|
|
||||||
|
verbose_proxy_logger.debug("updated url %s", updated_url)
|
||||||
|
|
||||||
|
## check for streaming
|
||||||
|
is_streaming_request = False
|
||||||
|
if "stream" in str(updated_url):
|
||||||
|
is_streaming_request = True
|
||||||
|
|
||||||
|
## CREATE PASS-THROUGH
|
||||||
|
endpoint_func = create_pass_through_route(
|
||||||
|
endpoint=endpoint,
|
||||||
|
target=str(updated_url),
|
||||||
|
custom_headers=headers,
|
||||||
|
) # dynamically construct pass-through endpoint based on incoming path
|
||||||
|
received_value = await endpoint_func(
|
||||||
|
request,
|
||||||
|
fastapi_response,
|
||||||
|
user_api_key_dict,
|
||||||
|
stream=is_streaming_request,
|
||||||
)
|
)
|
||||||
|
|
||||||
return response
|
return received_value
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
|
||||||
"/vertex-ai/publishers/google/models/{model_id:path}:generateContent",
|
|
||||||
dependencies=[Depends(user_api_key_auth)],
|
|
||||||
tags=["Vertex AI endpoints"],
|
|
||||||
)
|
|
||||||
async def vertex_generate_content(
|
|
||||||
request: Request,
|
|
||||||
fastapi_response: Response,
|
|
||||||
model_id: str,
|
|
||||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
This is a pass-through endpoint for the Vertex AI API's /generateContent endpoint.
|
|
||||||
|
|
||||||
Example Curl:
|
|
||||||
```
|
|
||||||
curl http://localhost:4000/vertex-ai/publishers/google/models/gemini-1.5-flash-001:generateContent \
|
|
||||||
-H "Content-Type: application/json" \
|
|
||||||
-H "Authorization: Bearer sk-1234" \
|
|
||||||
-d '{"contents":[{"role": "user", "parts":[{"text": "hi"}]}]}'
|
|
||||||
```
|
|
||||||
|
|
||||||
Vertex API Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference#rest
|
|
||||||
It uses the Vertex AI credentials configured on the proxy and forwards the request to the Vertex AI API.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
response = await execute_post_vertex_ai_request(
|
|
||||||
request=request,
|
|
||||||
route=f"/publishers/google/models/{model_id}:generateContent",
|
|
||||||
)
|
|
||||||
return response
|
|
||||||
except Exception as e:
|
|
||||||
raise exception_handler(e) from e
|
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
|
||||||
"/vertex-ai/publishers/google/models/{model_id:path}:predict",
|
|
||||||
dependencies=[Depends(user_api_key_auth)],
|
|
||||||
tags=["Vertex AI endpoints"],
|
|
||||||
)
|
|
||||||
async def vertex_predict_endpoint(
|
|
||||||
request: Request,
|
|
||||||
fastapi_response: Response,
|
|
||||||
model_id: str,
|
|
||||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
This is a pass-through endpoint for the Vertex AI API's /predict endpoint.
|
|
||||||
Use this for:
|
|
||||||
- Embeddings API - Text Embedding, Multi Modal Embedding
|
|
||||||
- Imagen API
|
|
||||||
- Code Completion API
|
|
||||||
|
|
||||||
Example Curl:
|
|
||||||
```
|
|
||||||
curl http://localhost:4000/vertex-ai/publishers/google/models/textembedding-gecko@001:predict \
|
|
||||||
-H "Content-Type: application/json" \
|
|
||||||
-H "Authorization: Bearer sk-1234" \
|
|
||||||
-d '{"instances":[{"content": "gm"}]}'
|
|
||||||
```
|
|
||||||
|
|
||||||
Vertex API Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/text-embeddings-api#generative-ai-get-text-embedding-drest
|
|
||||||
It uses the Vertex AI credentials configured on the proxy and forwards the request to the Vertex AI API.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
response = await execute_post_vertex_ai_request(
|
|
||||||
request=request,
|
|
||||||
route=f"/publishers/google/models/{model_id}:predict",
|
|
||||||
)
|
|
||||||
return response
|
|
||||||
except Exception as e:
|
|
||||||
raise exception_handler(e) from e
|
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
|
||||||
"/vertex-ai/publishers/google/models/{model_id:path}:countTokens",
|
|
||||||
dependencies=[Depends(user_api_key_auth)],
|
|
||||||
tags=["Vertex AI endpoints"],
|
|
||||||
)
|
|
||||||
async def vertex_countTokens_endpoint(
|
|
||||||
request: Request,
|
|
||||||
fastapi_response: Response,
|
|
||||||
model_id: str,
|
|
||||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
This is a pass-through endpoint for the Vertex AI API's /countTokens endpoint. Reference:
|
|
||||||
https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/count-tokens#curl
|
|
||||||
|
|
||||||
|
|
||||||
Example Curl:
|
|
||||||
```
|
|
||||||
curl http://localhost:4000/vertex-ai/publishers/google/models/gemini-1.5-flash-001:countTokens \
|
|
||||||
-H "Content-Type: application/json" \
|
|
||||||
-H "Authorization: Bearer sk-1234" \
|
|
||||||
-d '{"contents":[{"role": "user", "parts":[{"text": "hi"}]}]}'
|
|
||||||
```
|
|
||||||
|
|
||||||
It uses the Vertex AI credentials configured on the proxy and forwards the request to the Vertex AI API.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
response = await execute_post_vertex_ai_request(
|
|
||||||
request=request,
|
|
||||||
route=f"/publishers/google/models/{model_id}:countTokens",
|
|
||||||
)
|
|
||||||
return response
|
|
||||||
except Exception as e:
|
|
||||||
raise exception_handler(e) from e
|
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
|
||||||
"/vertex-ai/batchPredictionJobs",
|
|
||||||
dependencies=[Depends(user_api_key_auth)],
|
|
||||||
tags=["Vertex AI endpoints"],
|
|
||||||
)
|
|
||||||
async def vertex_create_batch_prediction_job(
|
|
||||||
request: Request,
|
|
||||||
fastapi_response: Response,
|
|
||||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
This is a pass-through endpoint for the Vertex AI API's /batchPredictionJobs endpoint.
|
|
||||||
|
|
||||||
Vertex API Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/batch-prediction-api#syntax
|
|
||||||
|
|
||||||
It uses the Vertex AI credentials configured on the proxy and forwards the request to the Vertex AI API.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
response = await execute_post_vertex_ai_request(
|
|
||||||
request=request,
|
|
||||||
route="/batchPredictionJobs",
|
|
||||||
)
|
|
||||||
return response
|
|
||||||
except Exception as e:
|
|
||||||
raise exception_handler(e) from e
|
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
|
||||||
"/vertex-ai/tuningJobs",
|
|
||||||
dependencies=[Depends(user_api_key_auth)],
|
|
||||||
tags=["Vertex AI endpoints"],
|
|
||||||
)
|
|
||||||
async def vertex_create_fine_tuning_job(
|
|
||||||
request: Request,
|
|
||||||
fastapi_response: Response,
|
|
||||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
This is a pass-through endpoint for the Vertex AI API's /tuningJobs endpoint.
|
|
||||||
|
|
||||||
Vertex API Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/tuning
|
|
||||||
|
|
||||||
It uses the Vertex AI credentials configured on the proxy and forwards the request to the Vertex AI API.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
response = await execute_post_vertex_ai_request(
|
|
||||||
request=request,
|
|
||||||
route="/tuningJobs",
|
|
||||||
)
|
|
||||||
return response
|
|
||||||
except Exception as e:
|
|
||||||
raise exception_handler(e) from e
|
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
|
||||||
"/vertex-ai/tuningJobs/{job_id:path}:cancel",
|
|
||||||
dependencies=[Depends(user_api_key_auth)],
|
|
||||||
tags=["Vertex AI endpoints"],
|
|
||||||
)
|
|
||||||
async def vertex_cancel_fine_tuning_job(
|
|
||||||
request: Request,
|
|
||||||
job_id: str,
|
|
||||||
fastapi_response: Response,
|
|
||||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
This is a pass-through endpoint for the Vertex AI API's tuningJobs/{job_id:path}:cancel endpoint.
|
|
||||||
|
|
||||||
Vertex API Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/tuning#cancel_a_tuning_job
|
|
||||||
|
|
||||||
It uses the Vertex AI credentials configured on the proxy and forwards the request to the Vertex AI API.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
|
|
||||||
response = await execute_post_vertex_ai_request(
|
|
||||||
request=request,
|
|
||||||
route=f"/tuningJobs/{job_id}:cancel",
|
|
||||||
)
|
|
||||||
return response
|
|
||||||
except Exception as e:
|
|
||||||
raise exception_handler(e) from e
|
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
|
||||||
"/vertex-ai/cachedContents",
|
|
||||||
dependencies=[Depends(user_api_key_auth)],
|
|
||||||
tags=["Vertex AI endpoints"],
|
|
||||||
)
|
|
||||||
async def vertex_create_add_cached_content(
|
|
||||||
request: Request,
|
|
||||||
fastapi_response: Response,
|
|
||||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
This is a pass-through endpoint for the Vertex AI API's /cachedContents endpoint.
|
|
||||||
|
|
||||||
Vertex API Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/context-cache/context-cache-create#create-context-cache-sample-drest
|
|
||||||
|
|
||||||
It uses the Vertex AI credentials configured on the proxy and forwards the request to the Vertex AI API.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
response = await execute_post_vertex_ai_request(
|
|
||||||
request=request,
|
|
||||||
route="/cachedContents",
|
|
||||||
)
|
|
||||||
return response
|
|
||||||
except Exception as e:
|
|
||||||
raise exception_handler(e) from e
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue