Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-27 11:43:54 +00:00)
add vertex ai private endpoint support
This commit is contained in:
Parent: 048acc8e68
Commit: 0e8a0aefd5
1 changed file with 104 additions and 42 deletions
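This commit teaches the Vertex AI integration to call custom models deployed behind a Vertex AI private endpoint. A minimal usage sketch, assuming the new `private` flag is passed through litellm's optional params as the diff implies (the endpoint ID, project, and region below are placeholders):

```python
import litellm

# Placeholder values: substitute your own endpoint ID, GCP project, and region.
response = litellm.completion(
    model="vertex_ai/1234567890",  # numeric ID of the deployed endpoint
    messages=[{"role": "user", "content": "Hello"}],
    vertex_project="my-gcp-project",
    vertex_location="us-central1",
    private=True,  # new flag from this commit: use aiplatform.PrivateEndpoint
)
print(response)
```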
```diff
@@ -301,6 +301,11 @@ def completion(
                 gapic_content_types.SafetySetting(x) for x in safety_settings
             ]
 
+    ## Custom model deployed on private endpoint
+    private = False
+    if "private" in optional_params:
+        private = optional_params.pop("private")
+
     # vertexai does not use an API key, it looks for credentials.json in the environment
 
     prompt = " ".join(
```
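The flag is popped rather than read so that `private` is stripped from `optional_params` before the remaining params are forwarded to the model as request payload. A self-contained sketch of the same pattern, using `pop` with a default:

```python
optional_params = {"temperature": 0.2, "private": True}

# Same effect as the guarded pop in the diff: the flag is extracted and
# "private" no longer leaks into the request payload.
private = optional_params.pop("private", False)
assert private is True and "private" not in optional_params
```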
```diff
@@ -344,22 +349,32 @@ def completion(
             mode = "chat"
             request_str += f"llm_model = CodeChatModel.from_pretrained({model})\n"
         else:  # assume vertex model garden
-            client = aiplatform.gapic.PredictionServiceClient(
-                client_options=client_options
-            )
+            mode = "custom"
+
             instances = [optional_params]
             instances[0]["prompt"] = prompt
-            instances = [
-                json_format.ParseDict(instance_dict, Value())
-                for instance_dict in instances
-            ]
-            llm_model = client.endpoint_path(
-                project=vertex_project, location=vertex_location, endpoint=model
-            )
-            mode = "custom"
-            request_str += f"llm_model = client.endpoint_path(project={vertex_project}, location={vertex_location}, endpoint={model})\n"
+            if not private:
+                # private endpoint requires a string to be passed in
+                instances = [
+                    json_format.ParseDict(instance_dict, Value())
+                    for instance_dict in instances
+                ]
+                client = aiplatform.gapic.PredictionServiceClient(
+                    client_options=client_options
+                )
+                llm_model = client.endpoint_path(
+                    project=vertex_project, location=vertex_location, endpoint=model
+                )
+                request_str += f"llm_model = client.endpoint_path(project={vertex_project}, location={vertex_location}, endpoint={model})\n"
+            # private endpoint
+            else:
+                print("private endpoint", model)
+                client = aiplatform.PrivateEndpoint(
+                    endpoint_name=model,
+                    project=vertex_project,
+                    location=vertex_location,
+                )
 
         if acompletion == True:
             data = {
```
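The public path keeps the low-level gapic `PredictionServiceClient`, which expects instances as protobuf `Value` objects (hence the `json_format.ParseDict` step), while the private path uses the higher-level `aiplatform.PrivateEndpoint` wrapper, which is bound to a single endpoint and accepts plain Python objects. A side-by-side sketch of the two client styles, with placeholder project and endpoint values:

```python
from google.cloud import aiplatform
from google.protobuf import json_format
from google.protobuf.struct_pb2 import Value

instances = [{"prompt": "Hello", "max_tokens": 256}]

# Public endpoint: gapic client, fully-qualified endpoint path, protobuf instances.
gapic_client = aiplatform.gapic.PredictionServiceClient(
    client_options={"api_endpoint": "us-central1-aiplatform.googleapis.com"}
)
endpoint_path = gapic_client.endpoint_path(
    project="my-gcp-project", location="us-central1", endpoint="1234567890"
)
pb_instances = [json_format.ParseDict(i, Value()) for i in instances]
public_response = gapic_client.predict(endpoint=endpoint_path, instances=pb_instances)

# Private endpoint: high-level wrapper already bound to one endpoint, plain dicts.
private_endpoint = aiplatform.PrivateEndpoint(
    endpoint_name="1234567890",
    project="my-gcp-project",
    location="us-central1",
)
private_response = private_endpoint.predict(instances=instances)
```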
```diff
@@ -532,9 +547,6 @@ def completion(
             """
             Vertex AI Model Garden
             """
-            request_str += (
-                f"client.predict(endpoint={llm_model}, instances={instances})\n"
-            )
             ## LOGGING
             logging_obj.pre_call(
                 input=prompt,
```
```diff
@@ -545,10 +557,25 @@ def completion(
                 },
             )
-            response = client.predict(
-                endpoint=llm_model,
-                instances=instances,
-            ).predictions
+            # public endpoint
+            if not private:
+                request_str += (
+                    f"client.predict(endpoint={llm_model}, instances={instances})\n"
+                )
+                response = client.predict(
+                    endpoint=llm_model,
+                    instances=instances,
+                ).predictions
+            # private endpoint
+            else:
+                request_str += (
+                    f"client.predict(instances={instances})\n"
+                )
+                print("instances", instances)
+                response = client.predict(
+                    instances=instances,
+                ).predictions
 
             completion_response = response[0]
             if (
                 isinstance(completion_response, str)
```
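Note the two `predict` call shapes: the gapic client needs `endpoint=` on every call, while a `PrivateEndpoint` is constructed around one endpoint and is called with `instances` alone. Both return a list of predictions, and the handler reads the first entry; a stubbed sketch of that shared tail:

```python
# Stub standing in for client.predict(...).predictions from either branch.
predictions = ["generated text from the deployed model"]

completion_response = predictions[0]
if isinstance(completion_response, str):
    print(completion_response)
```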
```diff
@@ -723,16 +750,10 @@ async def async_completion(
         """
         from google.cloud import aiplatform
 
-        async_client = aiplatform.gapic.PredictionServiceAsyncClient(
-            client_options=client_options
-        )
-        llm_model = async_client.endpoint_path(
-            project=vertex_project, location=vertex_location, endpoint=model
-        )
-
-        request_str += (
-            f"client.predict(endpoint={llm_model}, instances={instances})\n"
-        )
+        private = False
+        if 'private' in optional_params:
+            private = optional_params.pop('private')
+
         ## LOGGING
         logging_obj.pre_call(
             input=prompt,
```
```diff
@@ -743,10 +764,35 @@ async def async_completion(
             },
         )
-        response_obj = await async_client.predict(
-            endpoint=llm_model,
-            instances=instances,
-        )
+        # public endpoint
+        if not private:
+            async_client = aiplatform.gapic.PredictionServiceAsyncClient(
+                client_options=client_options
+            )
+            llm_model = async_client.endpoint_path(
+                project=vertex_project, location=vertex_location, endpoint=model
+            )
+            request_str += (
+                f"client.predict(endpoint={llm_model}, instances={instances})\n"
+            )
+            response_obj = await async_client.predict(
+                endpoint=llm_model,
+                instances=instances,
+            )
+        # private endpoint
+        else:
+            async_client = aiplatform.PrivateEndpoint(
+                endpoint_name=model,
+                project=vertex_project,
+                location=vertex_location,
+            )
+            request_str += (
+                f"client.predict(instances={instances})\n"
+            )
+            response_obj = await async_client.predict(
+                instances=instances,
+            )
 
         response = response_obj.predictions
         completion_response = response[0]
         if (
```
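On the async path the public branch awaits the gapic `PredictionServiceAsyncClient`. `aiplatform.PrivateEndpoint.predict`, by contrast, is a synchronous SDK call; a minimal sketch of one way to invoke it from async code without blocking the event loop (placeholder values, and `asyncio.to_thread` chosen here as an illustration, not taken from the diff):

```python
import asyncio

from google.cloud import aiplatform


async def private_predict(instances: list):
    # PrivateEndpoint.predict is synchronous in google-cloud-aiplatform,
    # so run it in a worker thread to keep the event loop responsive.
    endpoint = aiplatform.PrivateEndpoint(
        endpoint_name="1234567890",  # placeholder endpoint ID
        project="my-gcp-project",
        location="us-central1",
    )
    return await asyncio.to_thread(endpoint.predict, instances=instances)

# Example: asyncio.run(private_predict([{"prompt": "Hello"}]))
```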
```diff
@@ -895,14 +941,10 @@ async def async_streaming(
     elif mode == "custom":
         from google.cloud import aiplatform
 
-        async_client = aiplatform.gapic.PredictionServiceAsyncClient(
-            client_options=client_options
-        )
-        llm_model = async_client.endpoint_path(
-            project=vertex_project, location=vertex_location, endpoint=model
-        )
-
-        request_str += f"client.predict(endpoint={llm_model}, instances={instances})\n"
+        private = False
+        if 'private' in optional_params:
+            private = optional_params.pop('private')
+
         ## LOGGING
         logging_obj.pre_call(
             input=prompt,
```
```diff
@@ -912,11 +954,31 @@ async def async_streaming(
                 "request_str": request_str,
             },
         )
+        # public endpoint
+        if not private:
+            async_client = aiplatform.gapic.PredictionServiceAsyncClient(
+                client_options=client_options
+            )
+            llm_model = async_client.endpoint_path(
+                project=vertex_project, location=vertex_location, endpoint=model
+            )
+            request_str += f"client.predict(endpoint={llm_model}, instances={instances})\n"
+            response_obj = await async_client.predict(
+                endpoint=llm_model,
+                instances=instances,
+            )
+        else:
+            async_client = aiplatform.PrivateEndpoint(
+                endpoint_name=model,
+                project=vertex_project,
+                location=vertex_location,
+            )
+            request_str += f"client.predict(instances={instances})\n"
+            response_obj = await async_client.predict(
+                instances=instances,
+            )
 
-        response_obj = await async_client.predict(
-            endpoint=llm_model,
-            instances=instances,
-        )
         response = response_obj.predictions
         completion_response = response[0]
         if (
```