Mirror of https://github.com/BerriAI/litellm.git
update private async
This commit is contained in:
parent 60c0bec7b3, commit e011f8022a
1 changed file with 10 additions and 9 deletions
@@ -345,8 +345,7 @@ def completion(
             request_str += f"llm_model = CodeChatModel.from_pretrained({model})\n"
         elif model == "private":
             mode = "private"
-            if "model_id" in optional_params:
-                model = optional_params.pop("model_id")
+            model = optional_params.pop("model_id", None)
             # private endpoint requires a dict instead of JSON
             instances = [optional_params.copy()]
             instances[0]["prompt"] = prompt
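The behavioral change in this hunk: model_id is now popped unconditionally with a None default, so the key never lingers in optional_params and cannot leak into the prediction payload. A self-contained sketch of the resulting payload assembly (the literal values below are made up):

optional_params = {"model_id": "1234567890", "temperature": 0.2, "max_tokens": 256}
prompt = "Hello, Vertex!"

# Always pop model_id so it is not forwarded as a prediction parameter;
# the None default avoids a KeyError (and a separate branch) when it is absent.
model = optional_params.pop("model_id", None)

# The private endpoint expects a plain dict per instance rather than a JSON string.
instances = [optional_params.copy()]
instances[0]["prompt"] = prompt
# instances == [{"temperature": 0.2, "max_tokens": 256, "prompt": "Hello, Vertex!"}]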
@@ -368,6 +367,7 @@ def completion(
         # Will determine the API used based on async parameter
         llm_model = None
 
+        # NOTE: async prediction and streaming under "private" mode isn't supported by aiplatform right now
         if acompletion == True:
             data = {
                 "llm_model": llm_model,
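For context, the acompletion branch packs the already-resolved objects into a data dict and defers the actual call to the async path. A simplified, illustrative stand-in for that dispatch pattern (async_completion and its signature here are placeholders, not litellm's real helper):

import asyncio

async def async_completion(llm_model, mode, prompt, **kwargs):
    # the real helper branches on mode and awaits the appropriate client call
    return f"[{mode}] {prompt}"

def completion(prompt, mode="private", acompletion=False, **optional_params):
    llm_model = None  # private mode builds its client on the async side
    if acompletion == True:
        data = {"llm_model": llm_model, "mode": mode, "prompt": prompt, **optional_params}
        return async_completion(**data)  # caller awaits the returned coroutine
    return f"[sync {mode}] {prompt}"

if __name__ == "__main__":
    print(asyncio.run(completion("hi", acompletion=True)))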
@@ -603,6 +603,7 @@ def completion(
             if "stream" in optional_params and optional_params["stream"] == True:
                 response = TextStreamer(completion_response)
                 return response
+
         ## LOGGING
         logging_obj.post_call(
             input=prompt, api_key=None, original_response=completion_response
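TextStreamer here wraps a response that is already complete so that stream=True callers still receive an iterator. Its real implementation is defined elsewhere in litellm; the class below is only a rough stand-in for that idea, not the actual code:

# Hypothetical fake streamer: the private endpoint returns the whole completion
# at once, so "streaming" just replays that text chunk by chunk.
class FakeTextStreamer:
    def __init__(self, text, chunk_size=8):
        self.text = text
        self.chunk_size = chunk_size

    def __iter__(self):
        for i in range(0, len(self.text), self.chunk_size):
            yield self.text[i : i + self.chunk_size]

    async def __aiter__(self):
        # mirror the sync iteration for async callers
        for chunk in self:
            yield chunk

for chunk in FakeTextStreamer("a completion returned in one piece"):
    print(chunk, end="|")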
@@ -968,18 +969,15 @@ async def async_streaming(
                 "request_str": request_str,
             },
         )
-        async_client = aiplatform.gapic.PredictionServiceAsyncClient(
+        llm_model = aiplatform.gapic.PredictionServiceAsyncClient(
             client_options=client_options
         )
         request_str += f"llm_model = aiplatform.gapic.PredictionServiceAsyncClient(client_options={client_options})\n"
-        llm_model = async_client.endpoint_path(
-            project=vertex_project, location=vertex_location, endpoint=model
-        )
         endpoint_path = llm_model.endpoint_path(
             project=vertex_project, location=vertex_location, endpoint=model
         )
         request_str += f"client.predict(endpoint={endpoint_path}, instances={instances})\n"
-        response_obj = await async_client.predict(
+        response_obj = await llm_model.predict(
             endpoint=endpoint_path,
             instances=instances,
             **optional_params,
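The rename in this hunk keeps the PredictionServiceAsyncClient in llm_model instead of overwriting it with the endpoint path string, which the old code then called .endpoint_path on. Roughly the call pattern the fixed code relies on, shown as a standalone sketch with placeholder project/endpoint values and the documented Value-based instance encoding (litellm passes its dict instances directly):

from google.cloud import aiplatform
from google.protobuf import json_format
from google.protobuf.struct_pb2 import Value

async def predict_private_endpoint():
    location = "us-central1"
    client_options = {"api_endpoint": f"{location}-aiplatform.googleapis.com"}

    # keep the client object around; endpoint_path() is a method on the client,
    # not on the path string it returns
    client = aiplatform.gapic.PredictionServiceAsyncClient(client_options=client_options)
    endpoint_path = client.endpoint_path(
        project="my-project", location=location, endpoint="1234567890"
    )

    instances = [json_format.ParseDict({"prompt": "Hello"}, Value())]
    response = await client.predict(endpoint=endpoint_path, instances=instances)
    return response.predictions

# asyncio.run(predict_private_endpoint())  # requires GCP credentials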
@@ -997,8 +995,11 @@ async def async_streaming(
 
         elif mode == "private":
             stream = optional_params.pop("stream", None)
-            request_str += f"llm_model.predict(instances={instances}, **{optional_params})\n"
-            response_obj = await async_client.predict(
+            _ = instances[0].pop("stream", None)
+            request_str += f"llm_model.predict_async(instances={instances}, **{optional_params})\n"
+            print("instances", instances)
+            print("optional_params", optional_params)
+            response_obj = await llm_model.predict_async(
                 instances=instances,
                 **optional_params,
             )
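In this private branch, llm_model is presumably the aiplatform.PrivateEndpoint handle built in the sync entry point and passed through the data dict; the call is switched to predict_async and any stream key is also stripped from the first instance before sending. A hedged standalone sketch of that call, with placeholder IDs and a fallback in case the installed google-cloud-aiplatform version lacks predict_async:

import asyncio
from google.cloud import aiplatform

async def private_predict(prompt: str):
    aiplatform.init(project="my-project", location="us-central1")
    llm_model = aiplatform.PrivateEndpoint(endpoint_name="1234567890")

    instances = [{"prompt": prompt, "max_tokens": 256}]
    # streaming flags must not reach the endpoint payload
    instances[0].pop("stream", None)

    if hasattr(llm_model, "predict_async"):
        response_obj = await llm_model.predict_async(instances=instances)
    else:
        # older SDKs: fall back to the blocking call in a worker thread
        response_obj = await asyncio.to_thread(llm_model.predict, instances=instances)
    return response_obj.predictions[0]

# asyncio.run(private_predict("Hello"))  # requires GCP credentials and a private endpoint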