Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 18:54:30 +00:00)

Merge pull request #1916 from RenaLu/main

Add support for Vertex AI custom models deployed on private endpoints

Commit 9b60ef9a3c (2 changed files, 128 additions and 42 deletions)
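Rough usage sketch of the new mode, inferred from the diff below: the model segment "private" selects the private-endpoint branch, and the numeric Vertex AI endpoint ID travels through the optional params as model_id. The endpoint ID, project, region, and exact kwarg flow here are assumptions for illustration, not documented interface.

```python
import litellm

# Hypothetical call exercising the new "private" mode. The endpoint ID,
# project, and region are placeholders; model_id is popped from the optional
# params by the new branch and used as the PrivateEndpoint name.
response = litellm.completion(
    model="vertex_ai/private",
    messages=[{"role": "user", "content": "Hello from a private endpoint"}],
    model_id="1234567890123456789",   # placeholder Vertex AI endpoint ID
    vertex_project="my-gcp-project",  # placeholder project
    vertex_location="us-central1",    # placeholder region
)
print(response.choices[0].message.content)
```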
```diff
@@ -343,24 +343,31 @@ def completion(
             llm_model = CodeChatModel.from_pretrained(model)
             mode = "chat"
             request_str += f"llm_model = CodeChatModel.from_pretrained({model})\n"
-        else:  # assume vertex model garden
-            client = aiplatform.gapic.PredictionServiceClient(
-                client_options=client_options
-            )
+        elif model == "private":
+            mode = "private"
+            model = optional_params.pop("model_id", None)
+            # private endpoint requires a dict instead of JSON
+            instances = [optional_params.copy()]
+            instances[0]["prompt"] = prompt
+            llm_model = aiplatform.PrivateEndpoint(
+                endpoint_name=model,
+                project=vertex_project,
+                location=vertex_location,
+            )
+            request_str += f"llm_model = aiplatform.PrivateEndpoint(endpoint_name={model}, project={vertex_project}, location={vertex_location})\n"
+        else:  # assume vertex model garden on public endpoint
+            mode = "custom"

-            instances = [optional_params]
+            instances = [optional_params.copy()]
             instances[0]["prompt"] = prompt
             instances = [
                 json_format.ParseDict(instance_dict, Value())
                 for instance_dict in instances
             ]
-            llm_model = client.endpoint_path(
-                project=vertex_project, location=vertex_location, endpoint=model
-            )
-
-            mode = "custom"
-            request_str += f"llm_model = client.endpoint_path(project={vertex_project}, location={vertex_location}, endpoint={model})\n"

+            # Will determine the API used based on async parameter
+            llm_model = None
+
+        # NOTE: async prediction and streaming under "private" mode isn't supported by aiplatform right now
         if acompletion == True:
             data = {
                 "llm_model": llm_model,
```
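Standalone, the SDK calls wrapped by the new "private" branch look roughly like the sketch below. The identifiers are placeholders, and the instance keys depend on the deployed model.

```python
from google.cloud import aiplatform

# Placeholder identifiers, for illustration only.
endpoint = aiplatform.PrivateEndpoint(
    endpoint_name="1234567890123456789",  # numeric endpoint ID or full resource name
    project="my-gcp-project",
    location="us-central1",
)

# Private endpoints take plain Python dicts as instances, which is why the
# branch above copies optional_params into a dict instead of building
# protobuf Values with json_format.ParseDict.
prediction = endpoint.predict(instances=[{"prompt": "Hello", "max_output_tokens": 128}])
print(prediction.predictions[0])
```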
```diff
@@ -532,9 +539,6 @@ def completion(
             """
             Vertex AI Model Garden
             """
-            request_str += (
-                f"client.predict(endpoint={llm_model}, instances={instances})\n"
-            )
             ## LOGGING
             logging_obj.pre_call(
                 input=prompt,
```
```diff
@@ -544,11 +548,21 @@ def completion(
                     "request_str": request_str,
                 },
             )
-            response = client.predict(
-                endpoint=llm_model,
-                instances=instances,
+            llm_model = aiplatform.gapic.PredictionServiceClient(
+                client_options=client_options
+            )
+            request_str += f"llm_model = aiplatform.gapic.PredictionServiceClient(client_options={client_options})\n"
+            endpoint_path = llm_model.endpoint_path(
+                project=vertex_project, location=vertex_location, endpoint=model
+            )
+            request_str += (
+                f"llm_model.predict(endpoint={endpoint_path}, instances={instances})\n"
+            )
+            response = llm_model.predict(
+                endpoint=endpoint_path,
+                instances=instances
             ).predictions

             completion_response = response[0]
             if (
                 isinstance(completion_response, str)
```
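For comparison, the public Model Garden ("custom") path that this hunk reworks boils down to the gapic client flow below. A minimal sketch; project, region, and endpoint ID are placeholders.

```python
from google.cloud import aiplatform
from google.protobuf import json_format
from google.protobuf.struct_pb2 import Value

# Placeholder identifiers, for illustration only.
project, location, endpoint_id = "my-gcp-project", "us-central1", "1234567890123456789"
client_options = {"api_endpoint": f"{location}-aiplatform.googleapis.com"}

client = aiplatform.gapic.PredictionServiceClient(client_options=client_options)
endpoint_path = client.endpoint_path(
    project=project, location=location, endpoint=endpoint_id
)

# The gapic predict() call expects protobuf Value instances, hence the
# json_format.ParseDict step kept in the diff.
instances = [json_format.ParseDict({"prompt": "Hello"}, Value())]
response = client.predict(endpoint=endpoint_path, instances=instances)
print(response.predictions[0])
```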
```diff
@@ -558,6 +572,36 @@ def completion(
             if "stream" in optional_params and optional_params["stream"] == True:
                 response = TextStreamer(completion_response)
                 return response
+        elif mode == "private":
+            """
+            Vertex AI Model Garden deployed on private endpoint
+            """
+            ## LOGGING
+            logging_obj.pre_call(
+                input=prompt,
+                api_key=None,
+                additional_args={
+                    "complete_input_dict": optional_params,
+                    "request_str": request_str,
+                },
+            )
+            request_str += (
+                f"llm_model.predict(instances={instances})\n"
+            )
+            response = llm_model.predict(
+                instances=instances
+            ).predictions
+
+            completion_response = response[0]
+            if (
+                isinstance(completion_response, str)
+                and "\nOutput:\n" in completion_response
+            ):
+                completion_response = completion_response.split("\nOutput:\n", 1)[1]
+            if "stream" in optional_params and optional_params["stream"] == True:
+                response = TextStreamer(completion_response)
+                return response
+
         ## LOGGING
         logging_obj.post_call(
             input=prompt, api_key=None, original_response=completion_response
```
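The response parsing added for the private branch mirrors the existing custom-endpoint handling: some Model Garden text models prefix the completion with an "\nOutput:\n" marker, which gets stripped. A tiny made-up example of what that split does:

```python
# The raw string below is invented purely to illustrate the split; the real
# response shape depends on the deployed model.
raw = "Prompt: Hello\nOutput:\nHi there!"
if isinstance(raw, str) and "\nOutput:\n" in raw:
    completion = raw.split("\nOutput:\n", 1)[1]
assert completion == "Hi there!"
```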
```diff
@@ -722,17 +766,6 @@ async def async_completion(
         Vertex AI Model Garden
         """
         from google.cloud import aiplatform
-
-        async_client = aiplatform.gapic.PredictionServiceAsyncClient(
-            client_options=client_options
-        )
-        llm_model = async_client.endpoint_path(
-            project=vertex_project, location=vertex_location, endpoint=model
-        )
-
-        request_str += (
-            f"client.predict(endpoint={llm_model}, instances={instances})\n"
-        )
         ## LOGGING
         logging_obj.pre_call(
             input=prompt,
```
```diff
@@ -743,8 +776,18 @@ async def async_completion(
             },
         )

-        response_obj = await async_client.predict(
-            endpoint=llm_model,
+        llm_model = aiplatform.gapic.PredictionServiceAsyncClient(
+            client_options=client_options
+        )
+        request_str += f"llm_model = aiplatform.gapic.PredictionServiceAsyncClient(client_options={client_options})\n"
+        endpoint_path = llm_model.endpoint_path(
+            project=vertex_project, location=vertex_location, endpoint=model
+        )
+        request_str += (
+            f"llm_model.predict(endpoint={endpoint_path}, instances={instances})\n"
+        )
+        response_obj = await llm_model.predict(
+            endpoint=endpoint_path,
             instances=instances,
         )
         response = response_obj.predictions
```
```diff
@@ -754,6 +797,23 @@ async def async_completion(
             and "\nOutput:\n" in completion_response
         ):
             completion_response = completion_response.split("\nOutput:\n", 1)[1]
+
+    elif mode == "private":
+        request_str += (
+            f"llm_model.predict_async(instances={instances})\n"
+        )
+        response_obj = await llm_model.predict_async(
+            instances=instances,
+        )
+
+        response = response_obj.predictions
+        completion_response = response[0]
+        if (
+            isinstance(completion_response, str)
+            and "\nOutput:\n" in completion_response
+        ):
+            completion_response = completion_response.split("\nOutput:\n", 1)[1]
+
     ## LOGGING
     logging_obj.post_call(
         input=prompt, api_key=None, original_response=completion_response
```
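The async private path relies on predict_async on the endpoint object. Assuming the installed google-cloud-aiplatform version exposes that method (as this diff does), the standalone equivalent is roughly:

```python
import asyncio

from google.cloud import aiplatform

async def predict_private_async() -> None:
    # Placeholder identifiers; assumes a google-cloud-aiplatform version that
    # provides predict_async on (Private)Endpoint objects, as the diff relies on.
    endpoint = aiplatform.PrivateEndpoint(
        endpoint_name="1234567890123456789",
        project="my-gcp-project",
        location="us-central1",
    )
    prediction = await endpoint.predict_async(instances=[{"prompt": "Hello"}])
    print(prediction.predictions[0])

asyncio.run(predict_private_async())
```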
```diff
@@ -894,15 +954,8 @@ async def async_streaming(
         response = llm_model.predict_streaming_async(prompt, **optional_params)
     elif mode == "custom":
         from google.cloud import aiplatform
+        stream = optional_params.pop("stream", None)

-        async_client = aiplatform.gapic.PredictionServiceAsyncClient(
-            client_options=client_options
-        )
-        llm_model = async_client.endpoint_path(
-            project=vertex_project, location=vertex_location, endpoint=model
-        )
-
-        request_str += f"client.predict(endpoint={llm_model}, instances={instances})\n"
         ## LOGGING
         logging_obj.pre_call(
             input=prompt,
```
```diff
@@ -912,9 +965,34 @@ async def async_streaming(
                 "request_str": request_str,
             },
         )
+        llm_model = aiplatform.gapic.PredictionServiceAsyncClient(
+            client_options=client_options
+        )
+        request_str += f"llm_model = aiplatform.gapic.PredictionServiceAsyncClient(client_options={client_options})\n"
+        endpoint_path = llm_model.endpoint_path(
+            project=vertex_project, location=vertex_location, endpoint=model
+        )
+        request_str += f"client.predict(endpoint={endpoint_path}, instances={instances})\n"
+        response_obj = await llm_model.predict(
+            endpoint=endpoint_path,
+            instances=instances,
+        )

-        response_obj = await async_client.predict(
-            endpoint=llm_model,
+        response = response_obj.predictions
+        completion_response = response[0]
+        if (
+            isinstance(completion_response, str)
+            and "\nOutput:\n" in completion_response
+        ):
+            completion_response = completion_response.split("\nOutput:\n", 1)[1]
+        if stream:
+            response = TextStreamer(completion_response)
+
+    elif mode == "private":
+        stream = optional_params.pop("stream", None)
+        _ = instances[0].pop("stream", None)
+        request_str += f"llm_model.predict_async(instances={instances})\n"
+        response_obj = await llm_model.predict_async(
             instances=instances,
         )
         response = response_obj.predictions
```
```diff
@@ -924,8 +1002,9 @@ async def async_streaming(
             and "\nOutput:\n" in completion_response
         ):
             completion_response = completion_response.split("\nOutput:\n", 1)[1]
-        if "stream" in optional_params and optional_params["stream"] == True:
+        if stream:
             response = TextStreamer(completion_response)
+
     streamwrapper = CustomStreamWrapper(
         completion_stream=response,
         model=model,
```
The second changed file narrows the vertex_ai branch in get_optional_params to the first-party Vertex AI model lists:

```diff
@@ -4256,7 +4256,14 @@ def get_optional_params(
         optional_params["stop_sequences"] = stop
         if max_tokens is not None:
             optional_params["max_output_tokens"] = max_tokens
-    elif custom_llm_provider == "vertex_ai":
+    elif custom_llm_provider == "vertex_ai" and model in (
+        litellm.vertex_chat_models
+        or model in litellm.vertex_code_chat_models
+        or model in litellm.vertex_text_models
+        or model in litellm.vertex_code_text_models
+        or model in litellm.vertex_language_models
+        or model in litellm.vertex_embedding_models
+    ):
         ## check if unsupported param passed in
         supported_params = [
             "temperature",
```
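One reading note on the new condition (not part of the diff): the parenthesized `or` chain is evaluated before the `in` test, so if litellm.vertex_chat_models is a non-empty list the expression short-circuits to `model in litellm.vertex_chat_models`. The apparent intent is a membership check across all of the named lists, which an explicit helper could express as below (illustrative only, assuming each attribute is a list as elsewhere in litellm):

```python
import litellm

def is_first_party_vertex_model(model: str) -> bool:
    # Illustrative helper, not merged code: explicit membership test over all
    # of the Vertex AI model lists referenced in the condition above.
    known_vertex_models = (
        litellm.vertex_chat_models
        + litellm.vertex_code_chat_models
        + litellm.vertex_text_models
        + litellm.vertex_code_text_models
        + litellm.vertex_language_models
        + litellm.vertex_embedding_models
    )
    return model in known_vertex_models
```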