Mirror of https://github.com/BerriAI/litellm.git
Commit ea29baf3b8 (parent ee91257c48)
add vertex ai private endpoint support

1 changed file with 104 additions and 42 deletions
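This commit teaches litellm's Vertex AI provider to route custom ("model garden") deployments either to a public Vertex endpoint via the gapic prediction client or, when a `private` flag is set, to an `aiplatform.PrivateEndpoint`. In caller terms, that looks roughly like the sketch below; it assumes the `private` kwarg is forwarded verbatim into `optional_params`, and the endpoint ID, project, and region are placeholders:

    # Hedged usage sketch based on this commit; values are placeholders.
    import litellm

    response = litellm.completion(
        model="vertex_ai/1234567890123456789",  # numeric ID of the deployed endpoint
        messages=[{"role": "user", "content": "Hello from a private endpoint"}],
        vertex_project="my-project",
        vertex_location="us-central1",
        private=True,  # route via aiplatform.PrivateEndpoint instead of the gapic client
    )
    print(response.choices[0].message.content)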
@@ -301,6 +301,11 @@ def completion(
                    gapic_content_types.SafetySetting(x) for x in safety_settings
                ]

        ## Custom model deployed on private endpoint
        private = False
        if "private" in optional_params:
            private = optional_params.pop("private")

        # vertexai does not use an API key, it looks for credentials.json in the environment

        prompt = " ".join(
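The new flag rides in through litellm's `optional_params` and is popped off before the remaining params are forwarded as the prediction instance, so the endpoint never sees it as a model parameter. A toy illustration of the pattern (the dict below is hypothetical, not litellm's actual assembly):

    # Hypothetical optional_params as litellm might assemble them for this call.
    optional_params = {"max_tokens": 256, "temperature": 0.2, "private": True}

    # Pop the routing flag so it is not sent to the model as a parameter.
    private = optional_params.pop("private", False)

    assert private is True
    assert "private" not in optional_params  # only real model params remain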
@@ -344,23 +349,33 @@ def completion(
            mode = "chat"
            request_str += f"llm_model = CodeChatModel.from_pretrained({model})\n"
        else:  # assume vertex model garden
            client = aiplatform.gapic.PredictionServiceClient(
                client_options=client_options
            )
            mode = "custom"

            instances = [optional_params]
            instances[0]["prompt"] = prompt
            if not private:
                # private endpoint requires a string to be passed in
                instances = [
                    json_format.ParseDict(instance_dict, Value())
                    for instance_dict in instances
                ]
                client = aiplatform.gapic.PredictionServiceClient(
                    client_options=client_options
                )
                llm_model = client.endpoint_path(
                    project=vertex_project, location=vertex_location, endpoint=model
                )

                mode = "custom"
                request_str += f"llm_model = client.endpoint_path(project={vertex_project}, location={vertex_location}, endpoint={model})\n"

            # private endpoint
            else:
                print("private endpoint", model)
                client = aiplatform.PrivateEndpoint(
                    endpoint_name=model,
                    project=vertex_project,
                    location=vertex_location,
                )

        if acompletion == True:
            data = {
                "llm_model": llm_model,
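The split here is the heart of the change: the low-level gapic client needs the instances converted to protobuf `Value`s and addressed by a full endpoint resource path, while the high-level `PrivateEndpoint` wrapper accepts plain dicts and talks to the endpoint over its private service connection. A minimal standalone sketch of the two paths, where `ENDPOINT_ID`, `my-project`, `us-central1`, and the `api_endpoint` host are placeholders:

    # Minimal sketch of both client paths; identifiers are placeholders.
    from google.cloud import aiplatform
    from google.protobuf import json_format
    from google.protobuf.struct_pb2 import Value

    instances = [{"prompt": "Hello", "max_tokens": 256}]

    # Public endpoint: gapic client, protobuf Values, full resource path.
    client = aiplatform.gapic.PredictionServiceClient(
        client_options={"api_endpoint": "us-central1-aiplatform.googleapis.com"}
    )
    endpoint = client.endpoint_path(
        project="my-project", location="us-central1", endpoint="ENDPOINT_ID"
    )
    pb_instances = [json_format.ParseDict(i, Value()) for i in instances]
    response = client.predict(endpoint=endpoint, instances=pb_instances)

    # Private endpoint: the SDK wrapper takes plain dicts directly.
    private_endpoint = aiplatform.PrivateEndpoint(
        endpoint_name="ENDPOINT_ID", project="my-project", location="us-central1"
    )
    private_response = private_endpoint.predict(instances=instances)

    # Both paths expose the results as a .predictions list.
    print(response.predictions[0], private_response.predictions[0])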
@@ -532,9 +547,6 @@ def completion(
            """
            Vertex AI Model Garden
            """
            request_str += (
                f"client.predict(endpoint={llm_model}, instances={instances})\n"
            )
            ## LOGGING
            logging_obj.pre_call(
                input=prompt,
@@ -545,10 +557,25 @@ def completion(
                },
            )

            # public endpoint
            if not private:
                request_str += (
                    f"client.predict(endpoint={llm_model}, instances={instances})\n"
                )
                response = client.predict(
                    endpoint=llm_model,
                    instances=instances,
                ).predictions
            # private endpoint
            else:
                request_str += (
                    f"client.predict(instances={instances})\n"
                )
                print("instances", instances)
                response = client.predict(
                    instances=instances,
                ).predictions

            completion_response = response[0]
            if (
                isinstance(completion_response, str)
@@ -723,16 +750,10 @@ async def async_completion(
            """
            from google.cloud import aiplatform

            async_client = aiplatform.gapic.PredictionServiceAsyncClient(
                client_options=client_options
            )
            llm_model = async_client.endpoint_path(
                project=vertex_project, location=vertex_location, endpoint=model
            )
            private = False
            if 'private' in optional_params:
                private = optional_params.pop('private')

            request_str += (
                f"client.predict(endpoint={llm_model}, instances={instances})\n"
            )
            ## LOGGING
            logging_obj.pre_call(
                input=prompt,
@@ -743,10 +764,35 @@ async def async_completion(
                },
            )

            # public endpoint
            if not private:
                async_client = aiplatform.gapic.PredictionServiceAsyncClient(
                    client_options=client_options
                )
                llm_model = async_client.endpoint_path(
                    project=vertex_project, location=vertex_location, endpoint=model
                )
                request_str += (
                    f"client.predict(endpoint={llm_model}, instances={instances})\n"
                )
                response_obj = await async_client.predict(
                    endpoint=llm_model,
                    instances=instances,
                )
            # private endpoint
            else:
                async_client = aiplatform.PrivateEndpoint(
                    endpoint_name=model,
                    project=vertex_project,
                    location=vertex_location,
                )
                request_str += (
                    f"client.predict(instances={instances})\n"
                )
                response_obj = await async_client.predict(
                    instances=instances,
                )

            response = response_obj.predictions
            completion_response = response[0]
            if (
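One asymmetry worth noting: the gapic `PredictionServiceAsyncClient.predict` is a true coroutine, while `aiplatform.PrivateEndpoint.predict` is a synchronous SDK call; the commit awaits both. A hedged sketch of the same split that keeps the event loop unblocked by offloading the private call to a thread (an assumption of this sketch, not what the commit does; all names below are placeholders):

    # Hedged async sketch; to_thread for the private path is our assumption.
    import asyncio

    from google.cloud import aiplatform
    from google.protobuf import json_format
    from google.protobuf.struct_pb2 import Value


    async def predict_custom(instances: list, *, private: bool, endpoint_id: str,
                             project: str, location: str):
        if not private:
            async_client = aiplatform.gapic.PredictionServiceAsyncClient(
                client_options={"api_endpoint": f"{location}-aiplatform.googleapis.com"}
            )
            endpoint = async_client.endpoint_path(
                project=project, location=location, endpoint=endpoint_id
            )
            # gapic wants protobuf Values, mirroring the sync path in the diff.
            pb_instances = [json_format.ParseDict(i, Value()) for i in instances]
            response_obj = await async_client.predict(
                endpoint=endpoint, instances=pb_instances
            )
        else:
            private_endpoint = aiplatform.PrivateEndpoint(
                endpoint_name=endpoint_id, project=project, location=location
            )
            # PrivateEndpoint.predict is a blocking HTTP call; keep it off the loop.
            response_obj = await asyncio.to_thread(
                private_endpoint.predict, instances=instances
            )
        return response_obj.predictions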
@@ -895,14 +941,10 @@ async def async_streaming(
        elif mode == "custom":
            from google.cloud import aiplatform

            async_client = aiplatform.gapic.PredictionServiceAsyncClient(
                client_options=client_options
            )
            llm_model = async_client.endpoint_path(
                project=vertex_project, location=vertex_location, endpoint=model
            )
            private = False
            if 'private' in optional_params:
                private = optional_params.pop('private')

            request_str += f"client.predict(endpoint={llm_model}, instances={instances})\n"
            ## LOGGING
            logging_obj.pre_call(
                input=prompt,
@@ -912,11 +954,31 @@ async def async_streaming(
                    "request_str": request_str,
                },
            )
            # public endpoint
            if not private:
                async_client = aiplatform.gapic.PredictionServiceAsyncClient(
                    client_options=client_options
                )
                llm_model = async_client.endpoint_path(
                    project=vertex_project, location=vertex_location, endpoint=model
                )

                request_str += f"client.predict(endpoint={llm_model}, instances={instances})\n"
                response_obj = await async_client.predict(
                    endpoint=llm_model,
                    instances=instances,
                )
            else:
                async_client = aiplatform.PrivateEndpoint(
                    endpoint_name=model,
                    project=vertex_project,
                    location=vertex_location,
                )
                request_str += f"client.predict(instances={instances})\n"
                response_obj = await async_client.predict(
                    instances=instances,
                )

            response = response_obj.predictions
            completion_response = response[0]
            if (
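Note that even on the streaming path, mode "custom" still issues a single predict call; the complete prediction is then chunked back to the caller rather than token-streamed from the server. A toy sketch of that chunk-and-yield pattern (`fake_stream` and `chunk_size` are illustrative names, not litellm internals):

    # Toy sketch of faking a token stream from one batch prediction.
    from typing import Iterator


    def fake_stream(completion_text: str, chunk_size: int = 16) -> Iterator[str]:
        """Yield the already-complete prediction in small chunks."""
        for start in range(0, len(completion_text), chunk_size):
            yield completion_text[start : start + chunk_size]


    for chunk in fake_stream("a long prediction returned by client.predict(...)"):
        print(chunk, end="")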