fix(vertex_ai.py): add exception mapping for acompletion calls

Krrish Dholakia 2023-12-13 16:35:50 -08:00
parent 16b7b8097b
commit c673a23769
5 changed files with 99 additions and 70 deletions

@@ -220,64 +220,67 @@ async def async_completion(llm_model, mode: str, prompt: str, model: str, model_
     """
     Add support for acompletion calls for gemini-pro
     """
-    from vertexai.preview.generative_models import GenerationConfig
-    if mode == "":
-        # gemini-pro
-        chat = llm_model.start_chat()
-        ## LOGGING
-        logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
-        response_obj = await chat.send_message_async(prompt, generation_config=GenerationConfig(**optional_params))
-        completion_response = response_obj.text
-        response_obj = response_obj._raw_response
-    elif mode == "chat":
-        # chat-bison etc.
-        chat = llm_model.start_chat()
-        ## LOGGING
-        logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
-        response_obj = await chat.send_message_async(prompt, **optional_params)
-        completion_response = response_obj.text
-    elif mode == "text":
-        # gecko etc.
-        request_str += f"llm_model.predict({prompt}, **{optional_params}).text\n"
-        ## LOGGING
-        logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
-        response_obj = await llm_model.predict_async(prompt, **optional_params)
-        completion_response = response_obj.text
-    ## LOGGING
-    logging_obj.post_call(
-        input=prompt, api_key=None, original_response=completion_response
-    )
-    ## RESPONSE OBJECT
-    if len(str(completion_response)) > 0:
-        model_response["choices"][0]["message"][
-            "content"
-        ] = str(completion_response)
-    model_response["choices"][0]["message"]["content"] = str(completion_response)
-    model_response["created"] = int(time.time())
-    model_response["model"] = model
-    ## CALCULATING USAGE
-    if model in litellm.vertex_language_models and response_obj is not None:
-        model_response["choices"][0].finish_reason = response_obj.candidates[0].finish_reason.name
-        usage = Usage(prompt_tokens=response_obj.usage_metadata.prompt_token_count,
-                      completion_tokens=response_obj.usage_metadata.candidates_token_count,
-                      total_tokens=response_obj.usage_metadata.total_token_count)
-    else:
-        prompt_tokens = len(
-            encoding.encode(prompt)
-        )
-        completion_tokens = len(
-            encoding.encode(model_response["choices"][0]["message"].get("content", ""))
-        )
-        usage = Usage(
-            prompt_tokens=prompt_tokens,
-            completion_tokens=completion_tokens,
-            total_tokens=prompt_tokens + completion_tokens
-        )
-    model_response.usage = usage
-    return model_response
+    try:
+        from vertexai.preview.generative_models import GenerationConfig
+        if mode == "":
+            # gemini-pro
+            chat = llm_model.start_chat()
+            ## LOGGING
+            logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
+            response_obj = await chat.send_message_async(prompt, generation_config=GenerationConfig(**optional_params))
+            completion_response = response_obj.text
+            response_obj = response_obj._raw_response
+        elif mode == "chat":
+            # chat-bison etc.
+            chat = llm_model.start_chat()
+            ## LOGGING
+            logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
+            response_obj = await chat.send_message_async(prompt, **optional_params)
+            completion_response = response_obj.text
+        elif mode == "text":
+            # gecko etc.
+            request_str += f"llm_model.predict({prompt}, **{optional_params}).text\n"
+            ## LOGGING
+            logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
+            response_obj = await llm_model.predict_async(prompt, **optional_params)
+            completion_response = response_obj.text
+        ## LOGGING
+        logging_obj.post_call(
+            input=prompt, api_key=None, original_response=completion_response
+        )
+        ## RESPONSE OBJECT
+        if len(str(completion_response)) > 0:
+            model_response["choices"][0]["message"][
+                "content"
+            ] = str(completion_response)
+        model_response["choices"][0]["message"]["content"] = str(completion_response)
+        model_response["created"] = int(time.time())
+        model_response["model"] = model
+        ## CALCULATING USAGE
+        if model in litellm.vertex_language_models and response_obj is not None:
+            model_response["choices"][0].finish_reason = response_obj.candidates[0].finish_reason.name
+            usage = Usage(prompt_tokens=response_obj.usage_metadata.prompt_token_count,
+                          completion_tokens=response_obj.usage_metadata.candidates_token_count,
+                          total_tokens=response_obj.usage_metadata.total_token_count)
+        else:
+            prompt_tokens = len(
+                encoding.encode(prompt)
+            )
+            completion_tokens = len(
+                encoding.encode(model_response["choices"][0]["message"].get("content", ""))
+            )
+            usage = Usage(
+                prompt_tokens=prompt_tokens,
+                completion_tokens=completion_tokens,
+                total_tokens=prompt_tokens + completion_tokens
+            )
+        model_response.usage = usage
+        return model_response
+    except Exception as e:
+        raise VertexAIError(status_code=500, message=str(e))
 
 async def async_streaming(llm_model, mode: str, prompt: str, model: str, model_response: ModelResponse, logging_obj=None, request_str=None, **optional_params):
     """