Add Vertex AI private endpoint support

This commit is contained in:
Rena Lu 2024-02-09 16:19:26 -05:00
parent ee91257c48
commit ea29baf3b8

View file

@ -301,6 +301,11 @@ def completion(
gapic_content_types.SafetySetting(x) for x in safety_settings
]
## Custom model deployed on private endpoint
private = False
if "private" in optional_params:
private = optional_params.pop("private")
# vertexai does not use an API key; it looks for credentials.json in the environment
prompt = " ".join(
@ -344,22 +349,32 @@ def completion(
mode = "chat"
request_str += f"llm_model = CodeChatModel.from_pretrained({model})\n"
else: # assume vertex model garden
client = aiplatform.gapic.PredictionServiceClient(
client_options=client_options
)
mode = "custom"
instances = [optional_params]
instances[0]["prompt"] = prompt
instances = [
json_format.ParseDict(instance_dict, Value())
for instance_dict in instances
]
llm_model = client.endpoint_path(
if not private:
# only the public endpoint gets protobuf Value instances; for a private endpoint the instances are left as plain dicts (no conversion)
instances = [
json_format.ParseDict(instance_dict, Value())
for instance_dict in instances
]
client = aiplatform.gapic.PredictionServiceClient(
client_options=client_options
)
llm_model = client.endpoint_path(
project=vertex_project, location=vertex_location, endpoint=model
)
)
request_str += f"llm_model = client.endpoint_path(project={vertex_project}, location={vertex_location}, endpoint={model})\n"
mode = "custom"
request_str += f"llm_model = client.endpoint_path(project={vertex_project}, location={vertex_location}, endpoint={model})\n"
# private endpoint
else:
print("private endpoint", model)
client = aiplatform.PrivateEndpoint(
endpoint_name=model,
project=vertex_project,
location=vertex_location,
)
if acompletion == True:
data = {
@ -532,9 +547,6 @@ def completion(
"""
Vertex AI Model Garden
"""
request_str += (
f"client.predict(endpoint={llm_model}, instances={instances})\n"
)
## LOGGING
logging_obj.pre_call(
input=prompt,
@ -545,10 +557,25 @@ def completion(
},
)
response = client.predict(
endpoint=llm_model,
instances=instances,
).predictions
# public endpoint
if not private:
request_str += (
f"client.predict(endpoint={llm_model}, instances={instances})\n"
)
response = client.predict(
endpoint=llm_model,
instances=instances,
).predictions
# private endpoint
else:
request_str += (
f"client.predict(instances={instances})\n"
)
print("instances", instances)
response = client.predict(
instances=instances,
).predictions
completion_response = response[0]
if (
isinstance(completion_response, str)
@ -723,16 +750,10 @@ async def async_completion(
"""
from google.cloud import aiplatform
async_client = aiplatform.gapic.PredictionServiceAsyncClient(
client_options=client_options
)
llm_model = async_client.endpoint_path(
project=vertex_project, location=vertex_location, endpoint=model
)
private = False
if 'private' in optional_params:
private = optional_params.pop('private')
request_str += (
f"client.predict(endpoint={llm_model}, instances={instances})\n"
)
## LOGGING
logging_obj.pre_call(
input=prompt,
@ -743,10 +764,35 @@ async def async_completion(
},
)
response_obj = await async_client.predict(
endpoint=llm_model,
instances=instances,
)
# public endpoint
if not private:
async_client = aiplatform.gapic.PredictionServiceAsyncClient(
client_options=client_options
)
llm_model = async_client.endpoint_path(
project=vertex_project, location=vertex_location, endpoint=model
)
request_str += (
f"client.predict(endpoint={llm_model}, instances={instances})\n"
)
response_obj = await async_client.predict(
endpoint=llm_model,
instances=instances,
)
# private endpoint
else:
async_client = aiplatform.PrivateEndpoint(
endpoint_name=model,
project=vertex_project,
location=vertex_location,
)
request_str += (
f"client.predict(instances={instances})\n"
)
response_obj = await async_client.predict(
instances=instances,
)
response = response_obj.predictions
completion_response = response[0]
if (
@ -895,14 +941,10 @@ async def async_streaming(
elif mode == "custom":
from google.cloud import aiplatform
async_client = aiplatform.gapic.PredictionServiceAsyncClient(
client_options=client_options
)
llm_model = async_client.endpoint_path(
project=vertex_project, location=vertex_location, endpoint=model
)
private = False
if 'private' in optional_params:
private = optional_params.pop('private')
request_str += f"client.predict(endpoint={llm_model}, instances={instances})\n"
## LOGGING
logging_obj.pre_call(
input=prompt,
@ -912,11 +954,31 @@ async def async_streaming(
"request_str": request_str,
},
)
# public endpoint
if not private:
async_client = aiplatform.gapic.PredictionServiceAsyncClient(
client_options=client_options
)
llm_model = async_client.endpoint_path(
project=vertex_project, location=vertex_location, endpoint=model
)
request_str += f"client.predict(endpoint={llm_model}, instances={instances})\n"
response_obj = await async_client.predict(
endpoint=llm_model,
instances=instances,
)
else:
async_client = aiplatform.PrivateEndpoint(
endpoint_name=model,
project=vertex_project,
location=vertex_location,
)
request_str += f"client.predict(instances={instances})\n"
response_obj = await async_client.predict(
instances=instances,
)
response_obj = await async_client.predict(
endpoint=llm_model,
instances=instances,
)
response = response_obj.predictions
completion_response = response[0]
if (