Merge pull request #1916 from RenaLu/main

Add support for Vertex AI custom models deployed on private endpoint
Krish Dholakia 2024-02-15 22:47:36 -08:00 committed by GitHub
commit 9b60ef9a3c
2 changed files with 128 additions and 42 deletions
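For orientation, a rough caller-side sketch of what this change enables. The `vertex_ai/private` model string, the `model_id` kwarg, and the endpoint resource name below are assumptions pieced together from the diff (the new branch keys off `model == "private"` and pops `model_id` from the optional params); this is not an official litellm example.

```python
import litellm

# Hypothetical call against a Vertex AI model deployed on a private endpoint.
# "vertex_ai/private" is assumed to select the new branch; "model_id" carries
# the real endpoint (placeholder value shown); remaining kwargs end up in the
# instance dict sent to the endpoint.
response = litellm.completion(
    model="vertex_ai/private",
    model_id="projects/my-project/locations/us-central1/endpoints/1234567890",
    messages=[{"role": "user", "content": "Hello from a private endpoint"}],
    vertex_project="my-project",
    vertex_location="us-central1",
)
print(response)
```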


@@ -343,24 +343,31 @@ def completion(
             llm_model = CodeChatModel.from_pretrained(model)
             mode = "chat"
             request_str += f"llm_model = CodeChatModel.from_pretrained({model})\n"
-        else:  # assume vertex model garden
-            client = aiplatform.gapic.PredictionServiceClient(
-                client_options=client_options
+        elif model == "private":
+            mode = "private"
+            model = optional_params.pop("model_id", None)
+            # private endpoint requires a dict instead of JSON
+            instances = [optional_params.copy()]
+            instances[0]["prompt"] = prompt
+            llm_model = aiplatform.PrivateEndpoint(
+                endpoint_name=model,
+                project=vertex_project,
+                location=vertex_location,
             )
+            request_str += f"llm_model = aiplatform.PrivateEndpoint(endpoint_name={model}, project={vertex_project}, location={vertex_location})\n"
+        else:  # assume vertex model garden on public endpoint
+            mode = "custom"
-            instances = [optional_params]
+            instances = [optional_params.copy()]
             instances[0]["prompt"] = prompt
             instances = [
                 json_format.ParseDict(instance_dict, Value())
                 for instance_dict in instances
             ]
-            llm_model = client.endpoint_path(
-                project=vertex_project, location=vertex_location, endpoint=model
-            )
-            mode = "custom"
-            request_str += f"llm_model = client.endpoint_path(project={vertex_project}, location={vertex_location}, endpoint={model})\n"
+            # Will determine the API used based on async parameter
+            llm_model = None
+        # NOTE: async prediction and streaming under "private" mode isn't supported by aiplatform right now
         if acompletion == True:
             data = {
                 "llm_model": llm_model,
@@ -532,9 +539,6 @@ def completion(
             """
             Vertex AI Model Garden
             """
-            request_str += (
-                f"client.predict(endpoint={llm_model}, instances={instances})\n"
-            )
             ## LOGGING
             logging_obj.pre_call(
                 input=prompt,
@@ -544,11 +548,21 @@ def completion(
                     "request_str": request_str,
                 },
             )
-            response = client.predict(
-                endpoint=llm_model,
-                instances=instances,
+            llm_model = aiplatform.gapic.PredictionServiceClient(
+                client_options=client_options
+            )
+            request_str += f"llm_model = aiplatform.gapic.PredictionServiceClient(client_options={client_options})\n"
+            endpoint_path = llm_model.endpoint_path(
+                project=vertex_project, location=vertex_location, endpoint=model
+            )
+            request_str += (
+                f"llm_model.predict(endpoint={endpoint_path}, instances={instances})\n"
+            )
+            response = llm_model.predict(
+                endpoint=endpoint_path,
+                instances=instances
             ).predictions
             completion_response = response[0]
             if (
                 isinstance(completion_response, str)
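For comparison, the public Model Garden path above goes through the regional gapic `PredictionServiceClient` with protobuf-encoded instances. A self-contained sketch of that flow, with placeholder project, location, and endpoint ID:

```python
from google.cloud import aiplatform
from google.protobuf import json_format
from google.protobuf.struct_pb2 import Value

# Regional API endpoint for the Vertex AI prediction service.
client = aiplatform.gapic.PredictionServiceClient(
    client_options={"api_endpoint": "us-central1-aiplatform.googleapis.com"}
)
endpoint_path = client.endpoint_path(
    project="my-project", location="us-central1", endpoint="1234567890"
)
# The gapic client expects protobuf Values, not plain dicts.
instances = [json_format.ParseDict({"prompt": "Hello"}, Value())]
response = client.predict(endpoint=endpoint_path, instances=instances)
print(response.predictions[0])
```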
@@ -558,6 +572,36 @@ def completion(
             if "stream" in optional_params and optional_params["stream"] == True:
                 response = TextStreamer(completion_response)
                 return response
+        elif mode == "private":
+            """
+            Vertex AI Model Garden deployed on private endpoint
+            """
+            ## LOGGING
+            logging_obj.pre_call(
+                input=prompt,
+                api_key=None,
+                additional_args={
+                    "complete_input_dict": optional_params,
+                    "request_str": request_str,
+                },
+            )
+            request_str += (
+                f"llm_model.predict(instances={instances})\n"
+            )
+            response = llm_model.predict(
+                instances=instances
+            ).predictions
+            completion_response = response[0]
+            if (
+                isinstance(completion_response, str)
+                and "\nOutput:\n" in completion_response
+            ):
+                completion_response = completion_response.split("\nOutput:\n", 1)[1]
+            if "stream" in optional_params and optional_params["stream"] == True:
+                response = TextStreamer(completion_response)
+                return response
         ## LOGGING
         logging_obj.post_call(
             input=prompt, api_key=None, original_response=completion_response
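The `"\nOutput:\n"` handling in both branches appears to exist because some deployed text models return the prompt and the completion in a single string; only the part after the first marker is kept. A tiny illustration with made-up text:

```python
completion_response = "Prompt:\nHello\nOutput:\nHi there!"
if isinstance(completion_response, str) and "\nOutput:\n" in completion_response:
    # Keep only the text after the first "\nOutput:\n" marker.
    completion_response = completion_response.split("\nOutput:\n", 1)[1]
print(completion_response)  # -> "Hi there!"
```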
@@ -722,17 +766,6 @@ async def async_completion(
             Vertex AI Model Garden
             """
             from google.cloud import aiplatform
-            async_client = aiplatform.gapic.PredictionServiceAsyncClient(
-                client_options=client_options
-            )
-            llm_model = async_client.endpoint_path(
-                project=vertex_project, location=vertex_location, endpoint=model
-            )
-            request_str += (
-                f"client.predict(endpoint={llm_model}, instances={instances})\n"
-            )
             ## LOGGING
             logging_obj.pre_call(
                 input=prompt,
@@ -743,8 +776,18 @@ async def async_completion(
                 },
             )
-            response_obj = await async_client.predict(
-                endpoint=llm_model,
+            llm_model = aiplatform.gapic.PredictionServiceAsyncClient(
+                client_options=client_options
+            )
+            request_str += f"llm_model = aiplatform.gapic.PredictionServiceAsyncClient(client_options={client_options})\n"
+            endpoint_path = llm_model.endpoint_path(
+                project=vertex_project, location=vertex_location, endpoint=model
+            )
+            request_str += (
+                f"llm_model.predict(endpoint={endpoint_path}, instances={instances})\n"
+            )
+            response_obj = await llm_model.predict(
+                endpoint=endpoint_path,
                 instances=instances,
             )
             response = response_obj.predictions
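The async path mirrors the synchronous one, using the gapic `PredictionServiceAsyncClient`. A standalone sketch with placeholder identifiers:

```python
import asyncio

from google.cloud import aiplatform
from google.protobuf import json_format
from google.protobuf.struct_pb2 import Value


async def predict_once() -> None:
    client = aiplatform.gapic.PredictionServiceAsyncClient(
        client_options={"api_endpoint": "us-central1-aiplatform.googleapis.com"}
    )
    endpoint_path = client.endpoint_path(
        project="my-project", location="us-central1", endpoint="1234567890"
    )
    instances = [json_format.ParseDict({"prompt": "Hello"}, Value())]
    # Awaitable counterpart of PredictionServiceClient.predict.
    response = await client.predict(endpoint=endpoint_path, instances=instances)
    print(response.predictions[0])


asyncio.run(predict_once())
```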
@@ -754,6 +797,23 @@ async def async_completion(
                 and "\nOutput:\n" in completion_response
             ):
                 completion_response = completion_response.split("\nOutput:\n", 1)[1]
+        elif mode == "private":
+            request_str += (
+                f"llm_model.predict_async(instances={instances})\n"
+            )
+            response_obj = await llm_model.predict_async(
+                instances=instances,
+            )
+            response = response_obj.predictions
+            completion_response = response[0]
+            if (
+                isinstance(completion_response, str)
+                and "\nOutput:\n" in completion_response
+            ):
+                completion_response = completion_response.split("\nOutput:\n", 1)[1]
         ## LOGGING
         logging_obj.post_call(
             input=prompt, api_key=None, original_response=completion_response
@@ -894,15 +954,8 @@ async def async_streaming(
         response = llm_model.predict_streaming_async(prompt, **optional_params)
     elif mode == "custom":
         from google.cloud import aiplatform
         stream = optional_params.pop("stream", None)
-        async_client = aiplatform.gapic.PredictionServiceAsyncClient(
-            client_options=client_options
-        )
-        llm_model = async_client.endpoint_path(
-            project=vertex_project, location=vertex_location, endpoint=model
-        )
-        request_str += f"client.predict(endpoint={llm_model}, instances={instances})\n"
         ## LOGGING
         logging_obj.pre_call(
             input=prompt,
@@ -912,9 +965,34 @@ async def async_streaming(
                 "request_str": request_str,
             },
         )
+        llm_model = aiplatform.gapic.PredictionServiceAsyncClient(
+            client_options=client_options
+        )
+        request_str += f"llm_model = aiplatform.gapic.PredictionServiceAsyncClient(client_options={client_options})\n"
+        endpoint_path = llm_model.endpoint_path(
+            project=vertex_project, location=vertex_location, endpoint=model
+        )
+        request_str += f"client.predict(endpoint={endpoint_path}, instances={instances})\n"
+        response_obj = await llm_model.predict(
+            endpoint=endpoint_path,
+            instances=instances,
+        )
+        response = response_obj.predictions
+        completion_response = response[0]
+        if (
+            isinstance(completion_response, str)
+            and "\nOutput:\n" in completion_response
+        ):
+            completion_response = completion_response.split("\nOutput:\n", 1)[1]
+        if stream:
+            response = TextStreamer(completion_response)
+    elif mode == "private":
+        stream = optional_params.pop("stream", None)
+        _ = instances[0].pop("stream", None)
+        request_str += f"llm_model.predict_async(instances={instances})\n"
-        response_obj = await async_client.predict(
-            endpoint=llm_model,
+        response_obj = await llm_model.predict_async(
             instances=instances,
         )
         response = response_obj.predictions
@@ -924,8 +1002,9 @@ async def async_streaming(
             and "\nOutput:\n" in completion_response
         ):
             completion_response = completion_response.split("\nOutput:\n", 1)[1]
-        if "stream" in optional_params and optional_params["stream"] == True:
+        if stream:
             response = TextStreamer(completion_response)
     streamwrapper = CustomStreamWrapper(
         completion_stream=response,
         model=model,
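`TextStreamer` and `CustomStreamWrapper` are litellm internals; since aiplatform does not stream these responses, the diff wraps an already-complete response so callers that asked for streaming still get an iterator. A minimal stand-in (not litellm's implementation) that shows the idea:

```python
from typing import Iterator


def fake_text_streamer(completion_response: str, chunk_size: int = 10) -> Iterator[str]:
    """Yield finished text in small chunks so downstream code can iterate it like a stream."""
    for i in range(0, len(completion_response), chunk_size):
        yield completion_response[i : i + chunk_size]


for chunk in fake_text_streamer("This response was produced in one shot, not streamed."):
    print(chunk, end="")
```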


@@ -4256,7 +4256,14 @@ def get_optional_params(
             optional_params["stop_sequences"] = stop
         if max_tokens is not None:
             optional_params["max_output_tokens"] = max_tokens
-    elif custom_llm_provider == "vertex_ai":
+    elif custom_llm_provider == "vertex_ai" and model in (
+        litellm.vertex_chat_models
+        or model in litellm.vertex_code_chat_models
+        or model in litellm.vertex_text_models
+        or model in litellm.vertex_code_text_models
+        or model in litellm.vertex_language_models
+        or model in litellm.vertex_embedding_models
+    ):
         ## check if unsupported param passed in
         supported_params = [
             "temperature",