fix(vertex_ai.py): fix exception mapping for vertex ai

Krrish Dholakia 2023-11-23 17:35:26 -08:00
parent 7fda40e2be
commit 78d13ea6eb
4 changed files with 107 additions and 87 deletions

vertex_ai.py

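The hunk below reworks completion() so that failures surface as VertexAIError instead of a bare Exception: a failed vertexai import now raises VertexAIError with status_code=400, and the rest of the function body moves inside a try/except that wraps any runtime error in VertexAIError with status_code=500. The VertexAIError class itself is defined elsewhere in vertex_ai.py and is not part of this hunk; a minimal sketch of the shape such an error class takes (attribute names inferred from the call sites, so treat this as an assumption):

class VertexAIError(Exception):
    # Sketch only: the real class lives elsewhere in vertex_ai.py.
    # Carrying an HTTP-style status_code lets litellm's exception-mapping
    # layer translate provider failures into the matching OpenAI-style
    # error upstream (400 for a bad request, 500 for a server-side failure).
    def __init__(self, status_code, message):
        self.status_code = status_code
        self.message = message
        super().__init__(self.message)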
@@ -73,93 +73,96 @@ def completion(
     try:
         import vertexai
     except:
-        raise Exception("vertexai import failed please run `pip install google-cloud-aiplatform`")
-    from vertexai.preview.language_models import ChatModel, CodeChatModel, InputOutputTextPair
-    from vertexai.language_models import TextGenerationModel, CodeGenerationModel
-    vertexai.init(
-        project=vertex_project, location=vertex_location
-    )
-    ## Load Config
-    config = litellm.VertexAIConfig.get_config()
-    for k, v in config.items():
-        if k not in optional_params:
-            optional_params[k] = v
-    # vertexai does not use an API key, it looks for credentials.json in the environment
-    prompt = " ".join([message["content"] for message in messages])
-    mode = ""
-    if model in litellm.vertex_chat_models:
-        chat_model = ChatModel.from_pretrained(model)
-        mode = "chat"
-    elif model in litellm.vertex_text_models:
-        text_model = TextGenerationModel.from_pretrained(model)
-        mode = "text"
-    elif model in litellm.vertex_code_text_models:
-        text_model = CodeGenerationModel.from_pretrained(model)
-        mode = "text"
-    else: # vertex_code_chat_models
-        chat_model = CodeChatModel.from_pretrained(model)
-        mode = "chat"
-    if mode == "chat":
-        chat = chat_model.start_chat()
-        ## LOGGING
-        logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params})
-        if "stream" in optional_params and optional_params["stream"] == True:
-            # NOTE: VertexAI does not accept stream=True as a param and raises an error,
-            # we handle this by removing 'stream' from optional params and sending the request
-            # after we get the response we add optional_params["stream"] = True, since main.py needs to know it's a streaming response to then transform it for the OpenAI format
-            optional_params.pop("stream", None) # vertex ai raises an error when passing stream in optional params
-            model_response = chat.send_message_streaming(prompt, **optional_params)
-            optional_params["stream"] = True
-            return model_response
-        completion_response = chat.send_message(prompt, **optional_params).text
-    elif mode == "text":
-        ## LOGGING
-        logging_obj.pre_call(input=prompt, api_key=None)
-        if "stream" in optional_params and optional_params["stream"] == True:
-            optional_params.pop("stream", None) # See note above on handling streaming for vertex ai
-            model_response = text_model.predict_streaming(prompt, **optional_params)
-            optional_params["stream"] = True
-            return model_response
-        completion_response = text_model.predict(prompt, **optional_params).text
-    ## LOGGING
-    logging_obj.post_call(
-        input=prompt, api_key=None, original_response=completion_response
-    )
-    ## RESPONSE OBJECT
-    if len(str(completion_response)) > 0:
-        model_response["choices"][0]["message"][
-            "content"
-        ] = str(completion_response)
-    model_response["choices"][0]["message"]["content"] = str(completion_response)
-    model_response["created"] = int(time.time())
-    model_response["model"] = model
-    ## CALCULATING USAGE
-    prompt_tokens = len(
-        encoding.encode(prompt)
-    )
-    completion_tokens = len(
-        encoding.encode(model_response["choices"][0]["message"].get("content", ""))
-    )
-    usage = Usage(
-        prompt_tokens=prompt_tokens,
-        completion_tokens=completion_tokens,
-        total_tokens=prompt_tokens + completion_tokens
-    )
-    model_response.usage = usage
-    return model_response
+        raise VertexAIError(status_code=400, message="vertexai import failed please run `pip install google-cloud-aiplatform`")
+    try:
+        from vertexai.preview.language_models import ChatModel, CodeChatModel, InputOutputTextPair
+        from vertexai.language_models import TextGenerationModel, CodeGenerationModel
+        vertexai.init(
+            project=vertex_project, location=vertex_location
+        )
+        ## Load Config
+        config = litellm.VertexAIConfig.get_config()
+        for k, v in config.items():
+            if k not in optional_params:
+                optional_params[k] = v
+        # vertexai does not use an API key, it looks for credentials.json in the environment
+        prompt = " ".join([message["content"] for message in messages])
+        mode = ""
+        if model in litellm.vertex_chat_models:
+            chat_model = ChatModel.from_pretrained(model)
+            mode = "chat"
+        elif model in litellm.vertex_text_models:
+            text_model = TextGenerationModel.from_pretrained(model)
+            mode = "text"
+        elif model in litellm.vertex_code_text_models:
+            text_model = CodeGenerationModel.from_pretrained(model)
+            mode = "text"
+        else: # vertex_code_chat_models
+            chat_model = CodeChatModel.from_pretrained(model)
+            mode = "chat"
+        if mode == "chat":
+            chat = chat_model.start_chat()
+            ## LOGGING
+            logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params})
+            if "stream" in optional_params and optional_params["stream"] == True:
+                # NOTE: VertexAI does not accept stream=True as a param and raises an error,
+                # we handle this by removing 'stream' from optional params and sending the request
+                # after we get the response we add optional_params["stream"] = True, since main.py needs to know it's a streaming response to then transform it for the OpenAI format
+                optional_params.pop("stream", None) # vertex ai raises an error when passing stream in optional params
+                model_response = chat.send_message_streaming(prompt, **optional_params)
+                optional_params["stream"] = True
+                return model_response
+            completion_response = chat.send_message(prompt, **optional_params).text
+        elif mode == "text":
+            ## LOGGING
+            logging_obj.pre_call(input=prompt, api_key=None)
+            if "stream" in optional_params and optional_params["stream"] == True:
+                optional_params.pop("stream", None) # See note above on handling streaming for vertex ai
+                model_response = text_model.predict_streaming(prompt, **optional_params)
+                optional_params["stream"] = True
+                return model_response
+            completion_response = text_model.predict(prompt, **optional_params).text
+        ## LOGGING
+        logging_obj.post_call(
+            input=prompt, api_key=None, original_response=completion_response
+        )
+        ## RESPONSE OBJECT
+        if len(str(completion_response)) > 0:
+            model_response["choices"][0]["message"][
+                "content"
+            ] = str(completion_response)
+        model_response["choices"][0]["message"]["content"] = str(completion_response)
+        model_response["created"] = int(time.time())
+        model_response["model"] = model
+        ## CALCULATING USAGE
+        prompt_tokens = len(
+            encoding.encode(prompt)
+        )
+        completion_tokens = len(
+            encoding.encode(model_response["choices"][0]["message"].get("content", ""))
+        )
+        usage = Usage(
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            total_tokens=prompt_tokens + completion_tokens
+        )
+        model_response.usage = usage
+        return model_response
+    except Exception as e:
+        raise VertexAIError(status_code=500, message=str(e))
 def embedding():
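
The NOTE comments preserved in the hunk describe the streaming workaround: the Vertex AI SDK does not accept stream=True as a parameter and instead exposes dedicated streaming methods (send_message_streaming, predict_streaming), so the handler pops the flag before calling the SDK and restores it afterwards so main.py still knows to convert the response to the OpenAI streaming format. A standalone sketch of that pop-then-restore pattern (the helper and its arguments are hypothetical, not part of this commit):

def call_with_stream_flag(send, send_streaming, prompt, optional_params):
    # Hypothetical helper illustrating the pattern used in completion().
    if optional_params.pop("stream", False):
        # The SDK rejects stream=True as a kwarg, so call the dedicated
        # streaming method with the flag removed...
        response = send_streaming(prompt, **optional_params)
        # ...then put the flag back so downstream code can tell this is a
        # streaming response that needs OpenAI-format wrapping.
        optional_params["stream"] = True
        return response
    return send(prompt, **optional_params)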