This commit is contained in:
ishaan-jaff 2023-09-15 13:38:26 -07:00
parent 3623370c7f
commit abb3793e50
2 changed files with 26 additions and 34 deletions

View file

@ -638,12 +638,12 @@ def completion(
)
return response
response = model_response
elif model in litellm.vertex_chat_models:
elif model in litellm.vertex_chat_models or model in litellm.vertex_code_chat_models:
try:
import vertexai
except:
raise Exception("vertexai import failed please run `pip install google-cloud-aiplatform`")
from vertexai.preview.language_models import ChatModel, InputOutputTextPair
from vertexai.preview.language_models import ChatModel, CodeChatModel, InputOutputTextPair
vertexai.init(
project=litellm.vertex_project, location=litellm.vertex_location
@ -653,8 +653,10 @@ def completion(
prompt = " ".join([message["content"] for message in messages])
## LOGGING
logging.pre_call(input=prompt, api_key=None)
chat_model = ChatModel.from_pretrained(model)
if model in litellm.vertex_chat_models:
chat_model = ChatModel.from_pretrained(model)
else: # vertex_code_chat_models
chat_model = CodeChatModel.from_pretrained(model)
chat = chat_model.start_chat()
completion_response = chat.send_message(prompt, **optional_params)
@ -669,12 +671,12 @@ def completion(
model_response["created"] = time.time()
model_response["model"] = model
response = model_response
elif model in litellm.vertex_text_models:
elif model in litellm.vertex_text_models or model in litellm.vertex_code_text_models:
try:
import vertexai
except:
raise Exception("vertexai import failed please run `pip install google-cloud-aiplatform`")
from vertexai.language_models import TextGenerationModel
from vertexai.language_models import TextGenerationModel, CodeGenerationModel
vertexai.init(
project=litellm.vertex_project, location=litellm.vertex_location
@ -684,8 +686,12 @@ def completion(
prompt = " ".join([message["content"] for message in messages])
## LOGGING
logging.pre_call(input=prompt, api_key=None)
if model in litellm.vertex_text_models:
vertex_model = TextGenerationModel.from_pretrained(model)
else:
vertex_model = CodeGenerationModel.from_pretrained(model)
vertex_model = TextGenerationModel.from_pretrained(model)
completion_response = vertex_model.predict(prompt, **optional_params)
## LOGGING