diff --git a/litellm/llms/vertex_ai.py b/litellm/llms/vertex_ai.py
index 44b53b13cc..8f16136efa 100644
--- a/litellm/llms/vertex_ai.py
+++ b/litellm/llms/vertex_ai.py
@@ -345,8 +345,7 @@ def completion(
             request_str += f"llm_model = CodeChatModel.from_pretrained({model})\n"
         elif model == "private":
             mode = "private"
-            if "model_id" in optional_params:
-                model = optional_params.pop("model_id")
+            model = optional_params.pop("model_id", None)
             # private endpoint requires a dict instead of JSON
             instances = [optional_params.copy()]
             instances[0]["prompt"] = prompt
@@ -368,6 +367,7 @@ def completion(
         # Will determine the API used based on async parameter
         llm_model = None
 
+        # NOTE: async prediction and streaming under "private" mode isn't supported by aiplatform right now
         if acompletion == True:
             data = {
                 "llm_model": llm_model,
@@ -603,6 +603,7 @@ def completion(
             if "stream" in optional_params and optional_params["stream"] == True:
                 response = TextStreamer(completion_response)
                 return response
+
         ## LOGGING
         logging_obj.post_call(
             input=prompt, api_key=None, original_response=completion_response
@@ -968,18 +969,15 @@ async def async_streaming(
                 "request_str": request_str,
             },
         )
-        async_client = aiplatform.gapic.PredictionServiceAsyncClient(
+        llm_model = aiplatform.gapic.PredictionServiceAsyncClient(
             client_options=client_options
         )
         request_str += f"llm_model = aiplatform.gapic.PredictionServiceAsyncClient(client_options={client_options})\n"
-        llm_model = async_client.endpoint_path(
-            project=vertex_project, location=vertex_location, endpoint=model
-        )
         endpoint_path = llm_model.endpoint_path(
             project=vertex_project, location=vertex_location, endpoint=model
         )
         request_str += f"client.predict(endpoint={endpoint_path}, instances={instances})\n"
-        response_obj = await async_client.predict(
+        response_obj = await llm_model.predict(
             endpoint=endpoint_path,
             instances=instances,
             **optional_params,
@@ -997,8 +995,11 @@ async def async_streaming(
     elif mode == "private":
         stream = optional_params.pop("stream", None)
-        request_str += f"llm_model.predict(instances={instances}, **{optional_params})\n"
-        response_obj = await async_client.predict(
+        _ = instances[0].pop("stream", None)
+        request_str += f"llm_model.predict_async(instances={instances}, **{optional_params})\n"
+        print("instances", instances)
+        print("optional_params", optional_params)
+        response_obj = await llm_model.predict_async(
             instances=instances,
             **optional_params,
         )
 
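For reference, a minimal sketch of the gapic async prediction path that the `async_streaming` hunks consolidate onto `llm_model` (client construction, `endpoint_path`, awaited `predict`). This is an illustrative standalone example, not the litellm code: `PROJECT_ID`, `LOCATION`, `ENDPOINT_ID`, and the `predict_once` helper are placeholders, and the explicit protobuf conversion shown here is part of the standard gapic usage rather than something taken from this patch.

```python
# Sketch only: calling a deployed Vertex AI endpoint through the async gapic
# client, mirroring the flow in the "custom" branch of this patch.
import asyncio

from google.cloud import aiplatform
from google.protobuf import json_format
from google.protobuf.struct_pb2 import Value

PROJECT_ID = "my-project"    # placeholder
LOCATION = "us-central1"     # placeholder
ENDPOINT_ID = "1234567890"   # placeholder


async def predict_once(prompt: str):
    client_options = {"api_endpoint": f"{LOCATION}-aiplatform.googleapis.com"}
    client = aiplatform.gapic.PredictionServiceAsyncClient(
        client_options=client_options
    )

    # The gapic predict API expects protobuf Values, so convert the instance dict.
    instance = json_format.ParseDict({"prompt": prompt}, Value())
    endpoint = client.endpoint_path(
        project=PROJECT_ID, location=LOCATION, endpoint=ENDPOINT_ID
    )
    response = await client.predict(endpoint=endpoint, instances=[instance])
    return response.predictions


if __name__ == "__main__":
    print(asyncio.run(predict_once("hello")))
```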