diff --git a/litellm/llms/vertex_ai.py b/litellm/llms/vertex_ai.py
index d42bd003f8..a3597271ad 100644
--- a/litellm/llms/vertex_ai.py
+++ b/litellm/llms/vertex_ai.py
@@ -301,6 +301,11 @@ def completion(
                 gapic_content_types.SafetySetting(x) for x in safety_settings
             ]
 
+        ## Custom model deployed on private endpoint
+        private = False
+        if "private" in optional_params:
+            private = optional_params.pop("private")
+
         # vertexai does not use an API key, it looks for credentials.json in the environment
 
         prompt = " ".join(
@@ -344,22 +349,32 @@ def completion(
             mode = "chat"
             request_str += f"llm_model = CodeChatModel.from_pretrained({model})\n"
         else:  # assume vertex model garden
-            client = aiplatform.gapic.PredictionServiceClient(
-                client_options=client_options
-            )
+            mode = "custom"
 
             instances = [optional_params]
             instances[0]["prompt"] = prompt
-            instances = [
-                json_format.ParseDict(instance_dict, Value())
-                for instance_dict in instances
-            ]
-            llm_model = client.endpoint_path(
+            if not private:
+                # a private endpoint takes raw instances; only convert to protobuf Values for public endpoints
+                instances = [
+                    json_format.ParseDict(instance_dict, Value())
+                    for instance_dict in instances
+                ]
+                client = aiplatform.gapic.PredictionServiceClient(
+                    client_options=client_options
+                )
+                llm_model = client.endpoint_path(
                 project=vertex_project, location=vertex_location, endpoint=model
-            )
-
-            mode = "custom"
-            request_str += f"llm_model = client.endpoint_path(project={vertex_project}, location={vertex_location}, endpoint={model})\n"
+                )
+                request_str += f"llm_model = client.endpoint_path(project={vertex_project}, location={vertex_location}, endpoint={model})\n"
+
+            # private endpoint
+            else:
+                print("private endpoint", model)
+                client = aiplatform.PrivateEndpoint(
+                    endpoint_name=model,
+                    project=vertex_project,
+                    location=vertex_location,
+                )
 
         if acompletion == True:
             data = {
@@ -532,9 +547,6 @@ def completion(
             """
             Vertex AI Model Garden
             """
-            request_str += (
-                f"client.predict(endpoint={llm_model}, instances={instances})\n"
-            )
             ## LOGGING
             logging_obj.pre_call(
                 input=prompt,
@@ -545,10 +557,25 @@ def completion(
                 },
             )
 
-            response = client.predict(
-                endpoint=llm_model,
-                instances=instances,
-            ).predictions
+            # public endpoint
+            if not private:
+                request_str += (
+                    f"client.predict(endpoint={llm_model}, instances={instances})\n"
+                )
+                response = client.predict(
+                    endpoint=llm_model,
+                    instances=instances,
+                ).predictions
+            # private endpoint
+            else:
+                request_str += (
+                    f"client.predict(instances={instances})\n"
+                )
+                print("instances", instances)
+                response = client.predict(
+                    instances=instances,
+                ).predictions
+
             completion_response = response[0]
             if (
                 isinstance(completion_response, str)
@@ -723,16 +750,10 @@ async def async_completion(
             """
             from google.cloud import aiplatform
 
-            async_client = aiplatform.gapic.PredictionServiceAsyncClient(
-                client_options=client_options
-            )
-            llm_model = async_client.endpoint_path(
-                project=vertex_project, location=vertex_location, endpoint=model
-            )
+            private = False
+            if "private" in optional_params:
+                private = optional_params.pop("private")
 
-            request_str += (
-                f"client.predict(endpoint={llm_model}, instances={instances})\n"
-            )
             ## LOGGING
             logging_obj.pre_call(
                 input=prompt,
@@ -743,10 +764,35 @@ async def async_completion(
                 },
             )
 
-            response_obj = await async_client.predict(
-                endpoint=llm_model,
-                instances=instances,
-            )
+            # public endpoint
+            if not private:
+                async_client = aiplatform.gapic.PredictionServiceAsyncClient(
+                    client_options=client_options
+                )
+                llm_model = async_client.endpoint_path(
+                    project=vertex_project, location=vertex_location, endpoint=model
+                )
+                request_str += (
+                    f"client.predict(endpoint={llm_model}, instances={instances})\n"
+                )
+                response_obj = await async_client.predict(
+                    endpoint=llm_model,
+                    instances=instances,
+                )
+            # private endpoint
+            else:
+                async_client = aiplatform.PrivateEndpoint(
+                    endpoint_name=model,
+                    project=vertex_project,
+                    location=vertex_location,
+                )
+                request_str += (
+                    f"client.predict(instances={instances})\n"
+                )
+                response_obj = await async_client.predict(
+                    instances=instances,
+                )
+
             response = response_obj.predictions
             completion_response = response[0]
             if (
@@ -895,14 +941,10 @@ async def async_streaming(
     elif mode == "custom":
         from google.cloud import aiplatform
 
-        async_client = aiplatform.gapic.PredictionServiceAsyncClient(
-            client_options=client_options
-        )
-        llm_model = async_client.endpoint_path(
-            project=vertex_project, location=vertex_location, endpoint=model
-        )
+        private = False
+        if "private" in optional_params:
+            private = optional_params.pop("private")
 
-        request_str += f"client.predict(endpoint={llm_model}, instances={instances})\n"
         ## LOGGING
         logging_obj.pre_call(
             input=prompt,
@@ -912,11 +954,31 @@ async def async_streaming(
                 "request_str": request_str,
             },
         )
+        # public endpoint
+        if not private:
+            async_client = aiplatform.gapic.PredictionServiceAsyncClient(
+                client_options=client_options
+            )
+            llm_model = async_client.endpoint_path(
+                project=vertex_project, location=vertex_location, endpoint=model
+            )
+
+            request_str += f"client.predict(endpoint={llm_model}, instances={instances})\n"
+            response_obj = await async_client.predict(
+                endpoint=llm_model,
+                instances=instances,
+            )
+        else:
+            async_client = aiplatform.PrivateEndpoint(
+                endpoint_name=model,
+                project=vertex_project,
+                location=vertex_location,
+            )
+            request_str += f"client.predict(instances={instances})\n"
+            response_obj = await async_client.predict(
+                instances=instances,
+            )
 
-        response_obj = await async_client.predict(
-            endpoint=llm_model,
-            instances=instances,
-        )
         response = response_obj.predictions
         completion_response = response[0]
         if (
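
For reviewers, here is a minimal usage sketch of the new flag (not part of the diff). It assumes litellm forwards provider-specific kwargs such as `private` through to `optional_params` for the `vertex_ai` provider; the project, region, and endpoint ID below are placeholders.

```python
import litellm

# Hypothetical values: substitute your own GCP project/region and the numeric
# endpoint ID of a model deployed behind a Vertex AI private endpoint.
litellm.vertex_project = "my-gcp-project"  # placeholder
litellm.vertex_location = "us-central1"    # placeholder

response = litellm.completion(
    model="vertex_ai/1234567890",  # endpoint ID, not a public model name
    messages=[{"role": "user", "content": "Hello!"}],
    private=True,  # new flag from this diff: route via aiplatform.PrivateEndpoint
)
print(response.choices[0].message.content)
```

The `if not private` branching in the diff exists because `aiplatform.PrivateEndpoint` already encapsulates the endpoint identity and accepts plain JSON-serializable instances, whereas the gapic `PredictionServiceClient.predict()` call requires an explicit `endpoint=` resource path and `google.protobuf.Value` instances, which is why the `json_format.ParseDict` conversion is skipped on the private path.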