update private async

Rena Lu 2024-02-13 18:33:52 +00:00
parent 60c0bec7b3
commit e011f8022a


@@ -345,8 +345,7 @@ def completion(
             request_str += f"llm_model = CodeChatModel.from_pretrained({model})\n"
         elif model == "private":
             mode = "private"
-            if "model_id" in optional_params:
-                model = optional_params.pop("model_id")
+            model = optional_params.pop("model_id", None)
             # private endpoint requires a dict instead of JSON
             instances = [optional_params.copy()]
             instances[0]["prompt"] = prompt
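The guarded lookup collapses into a single call because dict.pop takes a default. A minimal standalone sketch of the before/after behavior (values hypothetical):

```python
optional_params = {"temperature": 0.2}  # no "model_id" key present

# Before: membership check; `model` stays unbound when the key is absent.
if "model_id" in optional_params:
    model = optional_params.pop("model_id")

# After: one call; `model` is always bound, defaulting to None.
model = optional_params.pop("model_id", None)
print(model)  # -> None
```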
@@ -368,6 +367,7 @@ def completion(
             # Will determine the API used based on async parameter
             llm_model = None
+    # NOTE: async prediction and streaming under "private" mode isn't supported by aiplatform right now
     if acompletion == True:
         data = {
             "llm_model": llm_model,
@@ -603,6 +603,7 @@ def completion(
         if "stream" in optional_params and optional_params["stream"] == True:
             response = TextStreamer(completion_response)
             return response
     ## LOGGING
     logging_obj.post_call(
         input=prompt, api_key=None, original_response=completion_response
@@ -968,18 +969,15 @@ async def async_streaming(
                 "request_str": request_str,
             },
         )
-        async_client = aiplatform.gapic.PredictionServiceAsyncClient(
+        llm_model = aiplatform.gapic.PredictionServiceAsyncClient(
             client_options=client_options
         )
         request_str += f"llm_model = aiplatform.gapic.PredictionServiceAsyncClient(client_options={client_options})\n"
-        llm_model = async_client.endpoint_path(
-            project=vertex_project, location=vertex_location, endpoint=model
-        )
+        endpoint_path = llm_model.endpoint_path(
+            project=vertex_project, location=vertex_location, endpoint=model
+        )
         request_str += f"client.predict(endpoint={endpoint_path}, instances={instances})\n"
-        response_obj = await async_client.predict(
+        response_obj = await llm_model.predict(
             endpoint=endpoint_path,
             instances=instances,
             **optional_params,
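For reference, a self-contained sketch of the renamed flow: build the gapic async client, resolve the endpoint path, await predict. The region-scoped api_endpoint and the ParseDict conversion are standard for this client but are assumptions here, not part of the diff:

```python
from google.cloud import aiplatform
from google.protobuf import json_format
from google.protobuf.struct_pb2 import Value

async def gapic_predict(vertex_project, vertex_location, model, instance_dict):
    client_options = {"api_endpoint": f"{vertex_location}-aiplatform.googleapis.com"}
    llm_model = aiplatform.gapic.PredictionServiceAsyncClient(client_options=client_options)
    endpoint_path = llm_model.endpoint_path(
        project=vertex_project, location=vertex_location, endpoint=model
    )
    # gapic expects instances as protobuf Values, not plain dicts
    instances = [json_format.ParseDict(instance_dict, Value())]
    response_obj = await llm_model.predict(endpoint=endpoint_path, instances=instances)
    return response_obj.predictions
```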
@@ -997,8 +995,11 @@ async def async_streaming(
         elif mode == "private":
             stream = optional_params.pop("stream", None)
-            request_str += f"llm_model.predict(instances={instances}, **{optional_params})\n"
-            response_obj = await async_client.predict(
+            _ = instances[0].pop("stream", None)
+            request_str += f"llm_model.predict_async(instances={instances}, **{optional_params})\n"
+            print("instances", instances)
+            print("optional_params", optional_params)
+            response_obj = await llm_model.predict_async(
                 instances=instances,
                 **optional_params,
             )
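The private branch now strips the stream flag from the instance payload (it is a client-side control, not model input) before awaiting predict_async. A minimal sketch of the resulting call shape, assuming llm_model is an aiplatform Endpoint-like object exposing the awaitable predict_async used above:

```python
async def private_predict(llm_model, prompt, **optional_params):
    stream = optional_params.pop("stream", None)  # consumed by the caller, not the model
    instances = [optional_params.copy()]
    instances[0]["prompt"] = prompt
    instances[0].pop("stream", None)  # keep control flags out of the request payload
    response_obj = await llm_model.predict_async(instances=instances)
    return response_obj.predictions
```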