Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-25 18:54:30 +00:00
Merge pull request #1916 from RenaLu/main
Add support for Vertex AI custom models deployed on private endpoint
commit 9b60ef9a3c
2 changed files with 128 additions and 42 deletions
@@ -343,24 +343,31 @@ def completion(
llm_model = CodeChatModel.from_pretrained(model)
mode = "chat"
request_str += f"llm_model = CodeChatModel.from_pretrained({model})\n"
else: # assume vertex model garden
client = aiplatform.gapic.PredictionServiceClient(
client_options=client_options
elif model == "private":
mode = "private"
model = optional_params.pop("model_id", None)
# private endpoint requires a dict instead of JSON
instances = [optional_params.copy()]
instances[0]["prompt"] = prompt
llm_model = aiplatform.PrivateEndpoint(
endpoint_name=model,
project=vertex_project,
location=vertex_location,
)
request_str += f"llm_model = aiplatform.PrivateEndpoint(endpoint_name={model}, project={vertex_project}, location={vertex_location})\n"
else: # assume vertex model garden on public endpoint
mode = "custom"

instances = [optional_params]
instances = [optional_params.copy()]
instances[0]["prompt"] = prompt
instances = [
json_format.ParseDict(instance_dict, Value())
for instance_dict in instances
]
llm_model = client.endpoint_path(
project=vertex_project, location=vertex_location, endpoint=model
)

mode = "custom"
request_str += f"llm_model = client.endpoint_path(project={vertex_project}, location={vertex_location}, endpoint={model})\n"
# Will determine the API used based on async parameter
llm_model = None

# NOTE: async prediction and streaming under "private" mode isn't supported by aiplatform right now
if acompletion == True:
data = {
"llm_model": llm_model,
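For context, a minimal sketch of the two call paths this hunk distinguishes, with placeholder project, location, and endpoint values: a private endpoint is called through aiplatform.PrivateEndpoint with plain dict instances, while the public Model Garden path goes through the gapic PredictionServiceClient with protobuf Value instances.

from google.cloud import aiplatform
from google.protobuf import json_format
from google.protobuf.struct_pb2 import Value

instances = [{"prompt": "Hello", "max_tokens": 256}]  # payload shape is model-specific

# Private endpoint: the SDK accepts plain Python dicts.
private_endpoint = aiplatform.PrivateEndpoint(
    endpoint_name="1234567890",   # hypothetical endpoint ID
    project="my-gcp-project",     # hypothetical project
    location="us-central1",
)
print(private_endpoint.predict(instances=instances).predictions)

# Public Model Garden endpoint: instances must be protobuf Values for the gapic client.
client = aiplatform.gapic.PredictionServiceClient(
    client_options={"api_endpoint": "us-central1-aiplatform.googleapis.com"}
)
endpoint_path = client.endpoint_path(
    project="my-gcp-project", location="us-central1", endpoint="1234567890"
)
pb_instances = [json_format.ParseDict(d, Value()) for d in instances]
print(client.predict(endpoint=endpoint_path, instances=pb_instances).predictions)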
@@ -532,9 +539,6 @@ def completion(
"""
Vertex AI Model Garden
"""
request_str += (
f"client.predict(endpoint={llm_model}, instances={instances})\n"
)
## LOGGING
logging_obj.pre_call(
input=prompt,
@@ -544,11 +548,21 @@ def completion(
"request_str": request_str,
},
)

response = client.predict(
endpoint=llm_model,
instances=instances,
llm_model = aiplatform.gapic.PredictionServiceClient(
client_options=client_options
)
request_str += f"llm_model = aiplatform.gapic.PredictionServiceClient(client_options={client_options})\n"
endpoint_path = llm_model.endpoint_path(
project=vertex_project, location=vertex_location, endpoint=model
)
request_str += (
f"llm_model.predict(endpoint={endpoint_path}, instances={instances})\n"
)
response = llm_model.predict(
endpoint=endpoint_path,
instances=instances
).predictions

completion_response = response[0]
if (
isinstance(completion_response, str)
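The request_str lines in this hunk follow the file's existing pattern of building a human-readable trace of the SDK calls about to be made, which the pre-call logger receives via additional_args["request_str"]. A minimal sketch with placeholder values:

# Placeholder values; only the string-building pattern matters here.
client_options = {"api_endpoint": "us-central1-aiplatform.googleapis.com"}
endpoint_path = "projects/my-project/locations/us-central1/endpoints/123"
instances = [{"prompt": "Hello"}]

request_str = ""
request_str += f"llm_model = aiplatform.gapic.PredictionServiceClient(client_options={client_options})\n"
request_str += f"llm_model.predict(endpoint={endpoint_path}, instances={instances})\n"
print(request_str)  # this is what shows up in the pre-call log entry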
@@ -558,6 +572,36 @@ def completion(
if "stream" in optional_params and optional_params["stream"] == True:
response = TextStreamer(completion_response)
return response
elif mode == "private":
"""
Vertex AI Model Garden deployed on private endpoint
"""
## LOGGING
logging_obj.pre_call(
input=prompt,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
"request_str": request_str,
},
)
request_str += (
f"llm_model.predict(instances={instances})\n"
)
response = llm_model.predict(
instances=instances
).predictions

completion_response = response[0]
if (
isinstance(completion_response, str)
and "\nOutput:\n" in completion_response
):
completion_response = completion_response.split("\nOutput:\n", 1)[1]
if "stream" in optional_params and optional_params["stream"] == True:
response = TextStreamer(completion_response)
return response

## LOGGING
logging_obj.post_call(
input=prompt, api_key=None, original_response=completion_response
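The private branch above takes the first prediction, strips anything up to a "\nOutput:\n" marker when the deployed model echoes the input, and, if stream=True, wraps the already-complete text in TextStreamer. A rough stand-in for that post-processing; the chunking generator is illustrative only and is not litellm's TextStreamer:

from typing import Iterator

def postprocess_prediction(prediction: str) -> str:
    # Mirror of the split in the branch above: keep only the text after the
    # "\nOutput:\n" marker when the model echoes the input back.
    if isinstance(prediction, str) and "\nOutput:\n" in prediction:
        prediction = prediction.split("\nOutput:\n", 1)[1]
    return prediction

def fake_stream(text: str, chunk_size: int = 16) -> Iterator[str]:
    # Illustrative stand-in for TextStreamer: the endpoint already returned the
    # full completion, so "streaming" just re-chunks the finished string.
    for i in range(0, len(text), chunk_size):
        yield text[i : i + chunk_size]

text = postprocess_prediction("my prompt\nOutput:\nHello from the endpoint")
for chunk in fake_stream(text):
    print(chunk, end="")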
@@ -722,17 +766,6 @@ async def async_completion(
Vertex AI Model Garden
"""
from google.cloud import aiplatform

async_client = aiplatform.gapic.PredictionServiceAsyncClient(
client_options=client_options
)
llm_model = async_client.endpoint_path(
project=vertex_project, location=vertex_location, endpoint=model
)

request_str += (
f"client.predict(endpoint={llm_model}, instances={instances})\n"
)
## LOGGING
logging_obj.pre_call(
input=prompt,
@@ -743,8 +776,18 @@ async def async_completion(
},
)

response_obj = await async_client.predict(
endpoint=llm_model,
llm_model = aiplatform.gapic.PredictionServiceAsyncClient(
client_options=client_options
)
request_str += f"llm_model = aiplatform.gapic.PredictionServiceAsyncClient(client_options={client_options})\n"
endpoint_path = llm_model.endpoint_path(
project=vertex_project, location=vertex_location, endpoint=model
)
request_str += (
f"llm_model.predict(endpoint={endpoint_path}, instances={instances})\n"
)
response_obj = await llm_model.predict(
endpoint=endpoint_path,
instances=instances,
)
response = response_obj.predictions
@@ -754,6 +797,23 @@ async def async_completion(
and "\nOutput:\n" in completion_response
):
completion_response = completion_response.split("\nOutput:\n", 1)[1]

elif mode == "private":
request_str += (
f"llm_model.predict_async(instances={instances})\n"
)
response_obj = await llm_model.predict_async(
instances=instances,
)

response = response_obj.predictions
completion_response = response[0]
if (
isinstance(completion_response, str)
and "\nOutput:\n" in completion_response
):
completion_response = completion_response.split("\nOutput:\n", 1)[1]

## LOGGING
logging_obj.post_call(
input=prompt, api_key=None, original_response=completion_response
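On the async side, the Model Garden path awaits the gapic PredictionServiceAsyncClient directly, while the private-endpoint path awaits predict_async on the endpoint object. A minimal sketch, assuming placeholder IDs and a google-cloud-aiplatform release that provides PrivateEndpoint.predict_async:

import asyncio
from google.cloud import aiplatform
from google.protobuf import json_format
from google.protobuf.struct_pb2 import Value

async def predict_model_garden(project, location, endpoint_id, instances):
    # Public Model Garden endpoint via the async gapic client.
    client = aiplatform.gapic.PredictionServiceAsyncClient(
        client_options={"api_endpoint": f"{location}-aiplatform.googleapis.com"}
    )
    endpoint_path = client.endpoint_path(project=project, location=location, endpoint=endpoint_id)
    pb_instances = [json_format.ParseDict(d, Value()) for d in instances]
    response = await client.predict(endpoint=endpoint_path, instances=pb_instances)
    return response.predictions

async def predict_private(endpoint_id, project, location, instances):
    # Private endpoint: plain dict instances, awaited via predict_async.
    endpoint = aiplatform.PrivateEndpoint(endpoint_name=endpoint_id, project=project, location=location)
    response = await endpoint.predict_async(instances=instances)
    return response.predictions

# Example with hypothetical IDs:
# asyncio.run(predict_private("1234567890", "my-gcp-project", "us-central1", [{"prompt": "Hello"}]))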
@@ -894,15 +954,8 @@ async def async_streaming(
response = llm_model.predict_streaming_async(prompt, **optional_params)
elif mode == "custom":
from google.cloud import aiplatform
stream = optional_params.pop("stream", None)

async_client = aiplatform.gapic.PredictionServiceAsyncClient(
client_options=client_options
)
llm_model = async_client.endpoint_path(
project=vertex_project, location=vertex_location, endpoint=model
)

request_str += f"client.predict(endpoint={llm_model}, instances={instances})\n"
## LOGGING
logging_obj.pre_call(
input=prompt,
@@ -912,9 +965,34 @@ async def async_streaming(
"request_str": request_str,
},
)
llm_model = aiplatform.gapic.PredictionServiceAsyncClient(
client_options=client_options
)
request_str += f"llm_model = aiplatform.gapic.PredictionServiceAsyncClient(client_options={client_options})\n"
endpoint_path = llm_model.endpoint_path(
project=vertex_project, location=vertex_location, endpoint=model
)
request_str += f"client.predict(endpoint={endpoint_path}, instances={instances})\n"
response_obj = await llm_model.predict(
endpoint=endpoint_path,
instances=instances,
)

response_obj = await async_client.predict(
endpoint=llm_model,
response = response_obj.predictions
completion_response = response[0]
if (
isinstance(completion_response, str)
and "\nOutput:\n" in completion_response
):
completion_response = completion_response.split("\nOutput:\n", 1)[1]
if stream:
response = TextStreamer(completion_response)

elif mode == "private":
stream = optional_params.pop("stream", None)
_ = instances[0].pop("stream", None)
request_str += f"llm_model.predict_async(instances={instances})\n"
response_obj = await llm_model.predict_async(
instances=instances,
)
response = response_obj.predictions
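In the streaming branches above, the "stream" flag is handled client-side: it is popped from optional_params (and from the instance payload for the private path) so it never reaches the endpoint, and streaming is then simulated from the completed response. A tiny sketch of that bookkeeping with placeholder values:

# The payload was built before the stream flag is handled, so it may still carry it.
optional_params = {"prompt": "Hello", "max_output_tokens": 128, "stream": True}
instances = [optional_params.copy()]

stream = optional_params.pop("stream", None)  # remember the caller's intent
_ = instances[0].pop("stream", None)          # keep "stream" out of the request body
# ...run the prediction, then re-chunk the full text only if stream is truthy.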
@@ -924,8 +1002,9 @@ async def async_streaming(
and "\nOutput:\n" in completion_response
):
completion_response = completion_response.split("\nOutput:\n", 1)[1]
if "stream" in optional_params and optional_params["stream"] == True:
if stream:
response = TextStreamer(completion_response)

streamwrapper = CustomStreamWrapper(
completion_stream=response,
model=model,
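Putting the routing from the first hunk together (model == "private" selects the new branch and model_id supplies the actual endpoint to call), a call might look roughly like the following. This is inferred from the diff rather than taken from documentation, and every ID below is a placeholder:

import litellm

# Inferred usage sketch: "vertex_ai/private" routes to the private-endpoint branch,
# and model_id carries the Vertex AI endpoint ID that aiplatform.PrivateEndpoint needs.
response = litellm.completion(
    model="vertex_ai/private",
    model_id="1234567890123456789",       # hypothetical private endpoint ID
    messages=[{"role": "user", "content": "Hello"}],
    vertex_project="my-gcp-project",      # hypothetical project
    vertex_location="us-central1",
)
print(response)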
@@ -4256,7 +4256,14 @@ def get_optional_params(
optional_params["stop_sequences"] = stop
if max_tokens is not None:
optional_params["max_output_tokens"] = max_tokens
elif custom_llm_provider == "vertex_ai":
elif custom_llm_provider == "vertex_ai" and (
model in litellm.vertex_chat_models
or model in litellm.vertex_code_chat_models
or model in litellm.vertex_text_models
or model in litellm.vertex_code_text_models
or model in litellm.vertex_language_models
or model in litellm.vertex_embedding_models
):
## check if unsupported param passed in
supported_params = [
"temperature",
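For reference, the membership guard added to get_optional_params can be read as a single predicate: the Vertex-specific parameter mapping now only applies to litellm's first-party Vertex AI model lists, so custom Model Garden and private-endpoint model IDs are no longer forced through that mapping. A small sketch of the check (list contents are whatever litellm defines):

import litellm

def is_first_party_vertex_model(model: str) -> bool:
    # Same membership test as the guard above, spelled out as a predicate.
    return (
        model in litellm.vertex_chat_models
        or model in litellm.vertex_code_chat_models
        or model in litellm.vertex_text_models
        or model in litellm.vertex_code_text_models
        or model in litellm.vertex_language_models
        or model in litellm.vertex_embedding_models
    )

print(is_first_party_vertex_model("chat-bison"))        # a first-party chat model
print(is_first_party_vertex_model("1234567890123456"))  # a custom endpoint ID: False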