diff --git a/litellm/llms/vertex_ai.py b/litellm/llms/vertex_ai.py
index 816ded3941..2e5fb3dd34 100644
--- a/litellm/llms/vertex_ai.py
+++ b/litellm/llms/vertex_ai.py
@@ -343,24 +343,31 @@ def completion(
             llm_model = CodeChatModel.from_pretrained(model)
             mode = "chat"
             request_str += f"llm_model = CodeChatModel.from_pretrained({model})\n"
-        else:  # assume vertex model garden
-            client = aiplatform.gapic.PredictionServiceClient(
-                client_options=client_options
+        elif model == "private":
+            mode = "private"
+            model = optional_params.pop("model_id", None)
+            # private endpoint requires a dict instead of JSON
+            instances = [optional_params.copy()]
+            instances[0]["prompt"] = prompt
+            llm_model = aiplatform.PrivateEndpoint(
+                endpoint_name=model,
+                project=vertex_project,
+                location=vertex_location,
             )
+            request_str += f"llm_model = aiplatform.PrivateEndpoint(endpoint_name={model}, project={vertex_project}, location={vertex_location})\n"
+        else:  # assume vertex model garden on public endpoint
+            mode = "custom"
 
-            instances = [optional_params]
+            instances = [optional_params.copy()]
             instances[0]["prompt"] = prompt
             instances = [
                 json_format.ParseDict(instance_dict, Value())
                 for instance_dict in instances
             ]
-            llm_model = client.endpoint_path(
-                project=vertex_project, location=vertex_location, endpoint=model
-            )
-
-            mode = "custom"
-            request_str += f"llm_model = client.endpoint_path(project={vertex_project}, location={vertex_location}, endpoint={model})\n"
+            # Will determine the API used based on async parameter
+            llm_model = None
+        # NOTE: async prediction and streaming under "private" mode isn't supported by aiplatform right now
 
         if acompletion == True:
             data = {
                 "llm_model": llm_model,
@@ -532,9 +539,6 @@ def completion(
             """
             Vertex AI Model Garden
             """
-            request_str += (
-                f"client.predict(endpoint={llm_model}, instances={instances})\n"
-            )
             ## LOGGING
             logging_obj.pre_call(
                 input=prompt,
@@ -544,11 +548,21 @@ def completion(
                     "request_str": request_str,
                 },
             )
-
-            response = client.predict(
-                endpoint=llm_model,
-                instances=instances,
+            llm_model = aiplatform.gapic.PredictionServiceClient(
+                client_options=client_options
+            )
+            request_str += f"llm_model = aiplatform.gapic.PredictionServiceClient(client_options={client_options})\n"
+            endpoint_path = llm_model.endpoint_path(
+                project=vertex_project, location=vertex_location, endpoint=model
+            )
+            request_str += (
+                f"llm_model.predict(endpoint={endpoint_path}, instances={instances})\n"
+            )
+            response = llm_model.predict(
+                endpoint=endpoint_path,
+                instances=instances
             ).predictions
+
             completion_response = response[0]
             if (
                 isinstance(completion_response, str)
@@ -558,6 +572,36 @@ def completion(
             if "stream" in optional_params and optional_params["stream"] == True:
                 response = TextStreamer(completion_response)
                 return response
+        elif mode == "private":
+            """
+            Vertex AI Model Garden deployed on private endpoint
+            """
+            ## LOGGING
+            logging_obj.pre_call(
+                input=prompt,
+                api_key=None,
+                additional_args={
+                    "complete_input_dict": optional_params,
+                    "request_str": request_str,
+                },
+            )
+            request_str += (
+                f"llm_model.predict(instances={instances})\n"
+            )
+            response = llm_model.predict(
+                instances=instances
+            ).predictions
+
+            completion_response = response[0]
+            if (
+                isinstance(completion_response, str)
+                and "\nOutput:\n" in completion_response
+            ):
+                completion_response = completion_response.split("\nOutput:\n", 1)[1]
+            if "stream" in optional_params and optional_params["stream"] == True:
+                response = TextStreamer(completion_response)
+                return response
+
         ## LOGGING
         logging_obj.post_call(
             input=prompt, api_key=None, original_response=completion_response
@@ -722,17 +766,6 @@ async def async_completion(
             Vertex AI Model Garden
             """
             from google.cloud import aiplatform
-
-            async_client = aiplatform.gapic.PredictionServiceAsyncClient(
-                client_options=client_options
-            )
-            llm_model = async_client.endpoint_path(
-                project=vertex_project, location=vertex_location, endpoint=model
-            )
-
-            request_str += (
-                f"client.predict(endpoint={llm_model}, instances={instances})\n"
-            )
             ## LOGGING
             logging_obj.pre_call(
                 input=prompt,
@@ -743,8 +776,18 @@ async def async_completion(
                 },
             )
-            response_obj = await async_client.predict(
-                endpoint=llm_model,
+            llm_model = aiplatform.gapic.PredictionServiceAsyncClient(
+                client_options=client_options
+            )
+            request_str += f"llm_model = aiplatform.gapic.PredictionServiceAsyncClient(client_options={client_options})\n"
+            endpoint_path = llm_model.endpoint_path(
+                project=vertex_project, location=vertex_location, endpoint=model
+            )
+            request_str += (
+                f"llm_model.predict(endpoint={endpoint_path}, instances={instances})\n"
+            )
+            response_obj = await llm_model.predict(
+                endpoint=endpoint_path,
                 instances=instances,
             )
             response = response_obj.predictions
@@ -754,6 +797,23 @@ async def async_completion(
                 and "\nOutput:\n" in completion_response
             ):
                 completion_response = completion_response.split("\nOutput:\n", 1)[1]
+
+        elif mode == "private":
+            request_str += (
+                f"llm_model.predict_async(instances={instances})\n"
+            )
+            response_obj = await llm_model.predict_async(
+                instances=instances,
+            )
+
+            response = response_obj.predictions
+            completion_response = response[0]
+            if (
+                isinstance(completion_response, str)
+                and "\nOutput:\n" in completion_response
+            ):
+                completion_response = completion_response.split("\nOutput:\n", 1)[1]
+
         ## LOGGING
         logging_obj.post_call(
             input=prompt, api_key=None, original_response=completion_response
@@ -894,15 +954,8 @@ async def async_streaming(
         response = llm_model.predict_streaming_async(prompt, **optional_params)
     elif mode == "custom":
         from google.cloud import aiplatform
+        stream = optional_params.pop("stream", None)
 
-        async_client = aiplatform.gapic.PredictionServiceAsyncClient(
-            client_options=client_options
-        )
-        llm_model = async_client.endpoint_path(
-            project=vertex_project, location=vertex_location, endpoint=model
-        )
-
-        request_str += f"client.predict(endpoint={llm_model}, instances={instances})\n"
         ## LOGGING
         logging_obj.pre_call(
             input=prompt,
@@ -912,9 +965,34 @@ async def async_streaming(
                 "request_str": request_str,
            },
        )
+        llm_model = aiplatform.gapic.PredictionServiceAsyncClient(
+            client_options=client_options
+        )
+        request_str += f"llm_model = aiplatform.gapic.PredictionServiceAsyncClient(client_options={client_options})\n"
+        endpoint_path = llm_model.endpoint_path(
+            project=vertex_project, location=vertex_location, endpoint=model
+        )
+        request_str += f"llm_model.predict(endpoint={endpoint_path}, instances={instances})\n"
+        response_obj = await llm_model.predict(
+            endpoint=endpoint_path,
+            instances=instances,
+        )
 
-        response_obj = await async_client.predict(
-            endpoint=llm_model,
+        response = response_obj.predictions
+        completion_response = response[0]
+        if (
+            isinstance(completion_response, str)
+            and "\nOutput:\n" in completion_response
+        ):
+            completion_response = completion_response.split("\nOutput:\n", 1)[1]
+        if stream:
+            response = TextStreamer(completion_response)
+
+    elif mode == "private":
+        stream = optional_params.pop("stream", None)
+        _ = instances[0].pop("stream", None)
+        request_str += f"llm_model.predict_async(instances={instances})\n"
+        response_obj = await llm_model.predict_async(
            instances=instances,
        )
        response = response_obj.predictions
@@ -924,8 +1002,9 @@ async def async_streaming(
             and "\nOutput:\n" in completion_response
         ):
             completion_response = completion_response.split("\nOutput:\n", 1)[1]
-        if "stream" in optional_params and optional_params["stream"] == True:
+        if stream:
             response = TextStreamer(completion_response)
+
     streamwrapper = CustomStreamWrapper(
         completion_stream=response,
         model=model,
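For reviewers, a rough usage sketch of the new code path (not part of the diff). Per the `elif model == "private"` branch above, the private endpoint is selected with model `"private"`, and the endpoint's numeric ID travels in `optional_params` as `model_id`, where `completion()` pops it. The project, region, and endpoint ID below are placeholders, and the exact kwarg plumbing is an assumption based on how the diff reads `optional_params`:

```python
# Hypothetical usage sketch; project/region/endpoint values are placeholders.
import litellm

litellm.vertex_project = "my-gcp-project"  # assumed GCP project id
litellm.vertex_location = "us-central1"    # assumed region

response = litellm.completion(
    model="vertex_ai/private",  # routes into the new `model == "private"` branch
    messages=[{"role": "user", "content": "Write a haiku about the ocean."}],
    # forwarded through optional_params and consumed by
    # optional_params.pop("model_id", None) in the diff above
    model_id="6481103024123952128",
)
print(response.choices[0].message.content)
```

Note that, per the `# NOTE:` comment in the diff, `aiplatform` does not currently support true async prediction or streaming for private endpoints, so `stream=True` falls back to the `TextStreamer` pseudo-stream over an already-completed response.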
diff --git a/litellm/utils.py b/litellm/utils.py
index 01a7b37b59..aad2cfa53c 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -4256,7 +4256,14 @@ def get_optional_params(
             optional_params["stop_sequences"] = stop
         if max_tokens is not None:
             optional_params["max_output_tokens"] = max_tokens
-    elif custom_llm_provider == "vertex_ai":
+    elif custom_llm_provider == "vertex_ai" and (
+        model in litellm.vertex_chat_models
+        or model in litellm.vertex_code_chat_models
+        or model in litellm.vertex_text_models
+        or model in litellm.vertex_code_text_models
+        or model in litellm.vertex_language_models
+        or model in litellm.vertex_embedding_models
+    ):
         ## check if unsupported param passed in
         supported_params = [
             "temperature",
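The reworked `get_optional_params` gate applies Vertex AI's supported-params filtering only to first-party models; Model Garden endpoint IDs and the new `"private"` mode fall through to the generic branch, so custom parameters survive into `instances`. A minimal sketch of the membership test, using a hypothetical helper name (the `litellm.vertex_*` lists are the ones referenced in the diff; the example model names are illustrative):

```python
# Hypothetical helper mirroring the gating condition above.
import litellm

def is_first_party_vertex_model(model: str) -> bool:
    return (
        model in litellm.vertex_chat_models
        or model in litellm.vertex_code_chat_models
        or model in litellm.vertex_text_models
        or model in litellm.vertex_code_text_models
        or model in litellm.vertex_language_models
        or model in litellm.vertex_embedding_models
    )

print(is_first_party_vertex_model("chat-bison"))        # True: params are filtered
print(is_first_party_vertex_model("1234567890123456"))  # False: Model Garden endpoint id
print(is_first_party_vertex_model("private"))           # False: private endpoint mode
```

Each list gets its own `model in ...` membership test; a form like `model in (list_a or model in list_b)` would short-circuit to `model in list_a`, since a non-empty list is truthy.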