add vertex ai private endpoint support

Rena Lu 2024-02-09 16:19:26 -05:00
parent 048acc8e68
commit 0e8a0aefd5


@@ -301,6 +301,11 @@ def completion(
                 gapic_content_types.SafetySetting(x) for x in safety_settings
             ]
 
+        ## Custom model deployed on private endpoint
+        private = False
+        if "private" in optional_params:
+            private = optional_params.pop("private")
+
         # vertexai does not use an API key, it looks for credentials.json in the environment
         prompt = " ".join(
@@ -344,22 +349,32 @@ def completion(
             mode = "chat"
             request_str += f"llm_model = CodeChatModel.from_pretrained({model})\n"
         else:  # assume vertex model garden
-            client = aiplatform.gapic.PredictionServiceClient(
-                client_options=client_options
-            )
+            mode = "custom"
+
             instances = [optional_params]
             instances[0]["prompt"] = prompt
-            instances = [
-                json_format.ParseDict(instance_dict, Value())
-                for instance_dict in instances
-            ]
-            llm_model = client.endpoint_path(
-                project=vertex_project, location=vertex_location, endpoint=model
-            )
-            mode = "custom"
-            request_str += f"llm_model = client.endpoint_path(project={vertex_project}, location={vertex_location}, endpoint={model})\n"
+            if not private:
+                # private endpoint requires a string to be passed in
+                instances = [
+                    json_format.ParseDict(instance_dict, Value())
+                    for instance_dict in instances
+                ]
+                client = aiplatform.gapic.PredictionServiceClient(
+                    client_options=client_options
+                )
+                llm_model = client.endpoint_path(
+                    project=vertex_project, location=vertex_location, endpoint=model
+                )
+                request_str += f"llm_model = client.endpoint_path(project={vertex_project}, location={vertex_location}, endpoint={model})\n"
+            # private endpoint
+            else:
+                print("private endpoint", model)
+                client = aiplatform.PrivateEndpoint(
+                    endpoint_name=model,
+                    project=vertex_project,
+                    location=vertex_location,
+                )
 
         if acompletion == True:
             data = {
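
The two branches build differently shaped clients. The gapic `PredictionServiceClient` addresses an endpoint by its fully qualified resource path and requires instances converted to protobuf `Value`s, while `aiplatform.PrivateEndpoint` is constructed around a single endpoint and accepts plain Python payloads, which is why the `json_format.ParseDict` conversion is applied only in the public branch. A minimal standalone sketch of the two shapes, with placeholder IDs throughout:

    from google.cloud import aiplatform
    from google.protobuf import json_format
    from google.protobuf.struct_pb2 import Value

    # Public endpoint: gapic client + fully qualified endpoint path + protobuf Values.
    gapic_client = aiplatform.gapic.PredictionServiceClient(
        client_options={"api_endpoint": "us-central1-aiplatform.googleapis.com"}
    )
    endpoint_path = gapic_client.endpoint_path(
        project="my-gcp-project", location="us-central1", endpoint="1234567890"
    )
    instances = [json_format.ParseDict({"prompt": "hello"}, Value())]

    # Private endpoint: the SDK object is bound to one endpoint and takes plain
    # payloads, so no path building or protobuf conversion is needed.
    private_endpoint = aiplatform.PrivateEndpoint(
        endpoint_name="1234567890",
        project="my-gcp-project",
        location="us-central1",
    )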
@@ -532,9 +547,6 @@ def completion(
             """
             Vertex AI Model Garden
             """
-            request_str += (
-                f"client.predict(endpoint={llm_model}, instances={instances})\n"
-            )
             ## LOGGING
             logging_obj.pre_call(
                 input=prompt,
@@ -545,10 +557,25 @@ def completion(
                 },
             )
-            response = client.predict(
-                endpoint=llm_model,
-                instances=instances,
-            ).predictions
+            # public endpoint
+            if not private:
+                request_str += (
+                    f"client.predict(endpoint={llm_model}, instances={instances})\n"
+                )
+                response = client.predict(
+                    endpoint=llm_model,
+                    instances=instances,
+                ).predictions
+            # private endpoint
+            else:
+                request_str += (
+                    f"client.predict(instances={instances})\n"
+                )
+                print("instances", instances)
+                response = client.predict(
+                    instances=instances,
+                ).predictions
 
             completion_response = response[0]
             if (
                 isinstance(completion_response, str)
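
Both branches end in `.predictions` followed by `response[0]`, but the intermediate types differ: the gapic call returns a `PredictResponse` whose `predictions` field is a repeated protobuf `Value`, while `PrivateEndpoint.predict()` returns the SDK's `Prediction` wrapper whose `predictions` attribute is a plain Python list. An illustrative standalone sketch of the private-branch shape (placeholder IDs):

    from google.cloud import aiplatform

    endpoint = aiplatform.PrivateEndpoint(
        endpoint_name="1234567890",  # placeholder endpoint ID
        project="my-gcp-project",
        location="us-central1",
    )
    # Prediction.predictions is a plain list; the handler above indexes it
    # the same way with response[0].
    preds = endpoint.predict(instances=[{"prompt": "hello"}]).predictions
    completion_response = preds[0]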
@@ -723,16 +750,10 @@ async def async_completion(
             """
             from google.cloud import aiplatform
 
-            async_client = aiplatform.gapic.PredictionServiceAsyncClient(
-                client_options=client_options
-            )
-            llm_model = async_client.endpoint_path(
-                project=vertex_project, location=vertex_location, endpoint=model
-            )
-            request_str += (
-                f"client.predict(endpoint={llm_model}, instances={instances})\n"
-            )
+            private = False
+            if 'private' in optional_params:
+                private = optional_params.pop('private')
+
             ## LOGGING
             logging_obj.pre_call(
                 input=prompt,
@@ -743,10 +764,35 @@ async def async_completion(
                 },
             )
-            response_obj = await async_client.predict(
-                endpoint=llm_model,
-                instances=instances,
-            )
+            # public endpoint
+            if not private:
+                async_client = aiplatform.gapic.PredictionServiceAsyncClient(
+                    client_options=client_options
+                )
+                llm_model = async_client.endpoint_path(
+                    project=vertex_project, location=vertex_location, endpoint=model
+                )
+                request_str += (
+                    f"client.predict(endpoint={llm_model}, instances={instances})\n"
+                )
+                response_obj = await async_client.predict(
+                    endpoint=llm_model,
+                    instances=instances,
+                )
+            # private endpoint
+            else:
+                async_client = aiplatform.PrivateEndpoint(
+                    endpoint_name=model,
+                    project=vertex_project,
+                    location=vertex_location,
+                )
+                request_str += (
+                    f"client.predict(instances={instances})\n"
+                )
+                response_obj = await async_client.predict(
+                    instances=instances,
+                )
 
             response = response_obj.predictions
             completion_response = response[0]
             if (
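
One caveat worth noting: unlike `aiplatform.gapic.PredictionServiceAsyncClient`, the `aiplatform.PrivateEndpoint` object exposes only a synchronous `predict` (the SDK does not appear to ship an async variant for private endpoints), so awaiting it directly would not work as written. A thin wrapper along these lines, purely a sketch and not part of this commit, would keep the blocking call off the event loop:

    import asyncio
    from google.cloud import aiplatform

    async def private_predict(endpoint: aiplatform.PrivateEndpoint, instances: list):
        # PrivateEndpoint.predict is a blocking HTTP call; run it in a worker
        # thread so the event loop stays responsive (requires Python 3.9+).
        return await asyncio.to_thread(endpoint.predict, instances=instances)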
@@ -895,14 +941,10 @@ async def async_streaming(
     elif mode == "custom":
         from google.cloud import aiplatform
 
-        async_client = aiplatform.gapic.PredictionServiceAsyncClient(
-            client_options=client_options
-        )
-        llm_model = async_client.endpoint_path(
-            project=vertex_project, location=vertex_location, endpoint=model
-        )
-        request_str += f"client.predict(endpoint={llm_model}, instances={instances})\n"
+        private = False
+        if 'private' in optional_params:
+            private = optional_params.pop('private')
+
         ## LOGGING
         logging_obj.pre_call(
             input=prompt,
@@ -912,11 +954,31 @@ async def async_streaming(
                 "request_str": request_str,
             },
         )
-        response_obj = await async_client.predict(
-            endpoint=llm_model,
-            instances=instances,
-        )
+        # public endpoint
+        if not private:
+            async_client = aiplatform.gapic.PredictionServiceAsyncClient(
+                client_options=client_options
+            )
+            llm_model = async_client.endpoint_path(
+                project=vertex_project, location=vertex_location, endpoint=model
+            )
+            request_str += f"client.predict(endpoint={llm_model}, instances={instances})\n"
+            response_obj = await async_client.predict(
+                endpoint=llm_model,
+                instances=instances,
+            )
+        # private endpoint
+        else:
+            async_client = aiplatform.PrivateEndpoint(
+                endpoint_name=model,
+                project=vertex_project,
+                location=vertex_location,
+            )
+            request_str += f"client.predict(instances={instances})\n"
+            response_obj = await async_client.predict(
+                instances=instances,
+            )
 
         response = response_obj.predictions
        completion_response = response[0]
        if (
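
For completeness, a hypothetical async/streaming call that would exercise the `async_streaming` path above; as before, the IDs are placeholders and the kwarg surface is assumed from the handler rather than confirmed by this commit:

    import asyncio
    import litellm

    async def main():
        # `stream=True` with `acompletion` should return an async iterator
        # of chunks; `private=True` selects the PrivateEndpoint branch.
        response = await litellm.acompletion(
            model="vertex_ai/1234567890",   # placeholder endpoint ID
            messages=[{"role": "user", "content": "hello"}],
            vertex_project="my-gcp-project",
            vertex_location="us-central1",
            private=True,
            stream=True,
        )
        async for chunk in response:
            print(chunk)

    asyncio.run(main())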