style(test_completion.py): fix merge conflict

Krrish Dholakia 2023-10-05 22:09:38 -07:00
parent 396d9d8e38
commit dd7e397650
22 changed files with 1535 additions and 250 deletions


@@ -45,7 +45,8 @@ from .llms import (
     cohere,
     petals,
     oobabooga,
-    palm)
+    palm,
+    vertex_ai)
 from .llms.prompt_templates.factory import prompt_factory, custom_prompt
 import tiktoken
 from concurrent.futures import ThreadPoolExecutor
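The hunk below swaps the inline Vertex AI handling for a call into the new vertex_ai provider module. Judging only from the keyword arguments at that call site, the module is expected to expose a completion() handler shaped roughly like the stub below; the parameter names come from the diff, but the signature and body are assumptions, not the actual file contents.

# Hypothetical stub of the new llms/vertex_ai.py handler, inferred purely from
# the keyword arguments passed at the call site in the hunk below; the real
# module is not part of this excerpt.
def completion(
    model,
    messages,
    model_response,
    print_verbose,
    optional_params,
    litellm_params,
    logger_fn,
    encoding,
    vertex_location,
    vertex_project,
    logging_obj,
):
    # Expected to initialize the Vertex AI SDK with vertex_project and
    # vertex_location, send the prompt, and fill model_response (choices,
    # created, model, usage) before returning it (details elided here).
    ...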
@@ -810,134 +811,32 @@ def completion(
)
return response
response = model_response
elif model in litellm.vertex_chat_models or model in litellm.vertex_code_chat_models:
try:
import vertexai
except:
raise Exception("vertexai import failed please run `pip install google-cloud-aiplatform`")
from vertexai.preview.language_models import ChatModel, CodeChatModel, InputOutputTextPair
elif model in litellm.vertex_chat_models or model in litellm.vertex_code_chat_models or model in litellm.vertex_text_models or model in litellm.vertex_code_text_models:
vertex_ai_project = (litellm.vertex_project
or get_secret("VERTEXAI_PROJECT"))
vertex_ai_location = (litellm.vertex_location
or get_secret("VERTEXAI_LOCATION"))
vertex_project = (litellm.vertex_project or get_secret("VERTEXAI_PROJECT"))
vertex_location = (litellm.vertex_location or get_secret("VERTEXAI_LOCATION"))
vertexai.init(
project=vertex_project, location=vertex_location
# palm does not support streaming as yet :(
model_response = vertex_ai.completion(
model=model,
messages=messages,
model_response=model_response,
print_verbose=print_verbose,
optional_params=optional_params,
litellm_params=litellm_params,
logger_fn=logger_fn,
encoding=encoding,
vertex_location=vertex_ai_location,
vertex_project=vertex_ai_project,
logging_obj=logging
)
# vertexai does not use an API key, it looks for credentials.json in the environment
prompt = " ".join([message["content"] for message in messages])
# contains any default values we need to pass to the provider
VertexAIConfig = {
"top_k": 40 # override by setting kwarg in completion() - e.g. completion(..., top_k=20)
}
if model in litellm.vertex_chat_models:
chat_model = ChatModel.from_pretrained(model)
else: # vertex_code_chat_models
chat_model = CodeChatModel.from_pretrained(model)
chat = chat_model.start_chat()
## Load Config
for k, v in VertexAIConfig.items():
if k not in optional_params:
optional_params[k] = v
## LOGGING
logging.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params})
if "stream" in optional_params and optional_params["stream"] == True:
model_response = chat.send_message_streaming(prompt, **optional_params)
response = CustomStreamWrapper(
model_response, model, custom_llm_provider="vertex_ai", logging_obj=logging
)
)
return response
completion_response = chat.send_message(prompt, **optional_params)
## LOGGING
logging.post_call(
input=prompt, api_key=None, original_response=completion_response
)
## RESPONSE OBJECT
model_response["choices"][0]["message"]["content"] = str(completion_response)
model_response["created"] = time.time()
model_response["model"] = model
## CALCULATING USAGE
prompt_tokens = len(
encoding.encode(prompt)
)
completion_tokens = len(
encoding.encode(model_response["choices"][0]["message"]["content"])
)
model_response["usage"] = {
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": prompt_tokens + completion_tokens,
}
response = model_response
elif model in litellm.vertex_text_models or model in litellm.vertex_code_text_models:
try:
import vertexai
except:
raise Exception("vertexai import failed please run `pip install google-cloud-aiplatform`")
from vertexai.language_models import TextGenerationModel, CodeGenerationModel
vertexai.init(
project=litellm.vertex_project, location=litellm.vertex_location
)
# vertexai does not use an API key, it looks for credentials.json in the environment
# contains any default values we need to pass to the provider
VertexAIConfig = {
"top_k": 40 # override by setting kwarg in completion() - e.g. completion(..., top_k=20)
}
prompt = " ".join([message["content"] for message in messages])
if model in litellm.vertex_text_models:
vertex_model = TextGenerationModel.from_pretrained(model)
else:
vertex_model = CodeGenerationModel.from_pretrained(model)
## Load Config
for k, v in VertexAIConfig.items():
if k not in optional_params:
optional_params[k] = v
## LOGGING
logging.pre_call(input=prompt, api_key=None)
if "stream" in optional_params and optional_params["stream"] == True:
model_response = vertex_model.predict_streaming(prompt, **optional_params)
response = CustomStreamWrapper(
model_response, model, custom_llm_provider="vertexai", logging_obj=logging
)
return response
completion_response = vertex_model.predict(prompt, **optional_params)
## LOGGING
logging.post_call(
input=prompt, api_key=None, original_response=completion_response
)
## RESPONSE OBJECT
model_response["choices"][0]["message"]["content"] = str(completion_response)
model_response["created"] = time.time()
model_response["model"] = model
## CALCULATING USAGE
prompt_tokens = len(
encoding.encode(prompt)
)
completion_tokens = len(
encoding.encode(model_response["choices"][0]["message"]["content"])
)
model_response["usage"] = {
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": prompt_tokens + completion_tokens,
}
response = model_response
elif model in litellm.ai21_models:
custom_llm_provider = "ai21"
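After this hunk, all four Vertex AI model families (chat, code chat, text, and code text) route through vertex_ai.completion(), with the project and location resolved from litellm.vertex_project / litellm.vertex_location or the VERTEXAI_PROJECT / VERTEXAI_LOCATION secrets, and credentials taken from the environment rather than an API key. A minimal caller-side sketch, assuming google-cloud-aiplatform is installed and using an illustrative model name:

import litellm

# Assumed illustrative values; any model listed in litellm.vertex_chat_models,
# vertex_code_chat_models, vertex_text_models, or vertex_code_text_models
# takes this same code path after the refactor.
litellm.vertex_project = "my-gcp-project"   # or set the VERTEXAI_PROJECT secret
litellm.vertex_location = "us-central1"     # or set the VERTEXAI_LOCATION secret

# No api_key is passed: Vertex AI relies on Google credentials found in the
# environment (e.g. a credentials.json picked up via application default auth).
response = litellm.completion(
    model="chat-bison",  # hypothetical example model name
    messages=[{"role": "user", "content": "Hello from Vertex AI"}],
)
print(response["choices"][0]["message"]["content"])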
@@ -1122,10 +1021,16 @@ def completion(
         custom_llm_provider == "petals"
         or model in litellm.petals_models
     ):
+        api_base = (
+            litellm.api_base or
+            api_base
+        )
         custom_llm_provider = "petals"
         model_response = petals.completion(
             model=model,
             messages=messages,
+            api_base=api_base,
             model_response=model_response,
             print_verbose=print_verbose,
             optional_params=optional_params,
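The second hunk threads an api_base through to petals.completion(), so litellm.api_base can override the endpoint used for Petals requests. A minimal sketch of how a caller might use that override, assuming a hypothetical locally hosted Petals endpoint and an illustrative model id:

import litellm

# Hypothetical local Petals endpoint; after this change litellm.api_base takes
# precedence over any api_base previously resolved inside completion().
litellm.api_base = "http://localhost:8000"

response = litellm.completion(
    model="petals/petals-team/StableBeluga2",  # illustrative Petals model id
    messages=[{"role": "user", "content": "Hello from a Petals swarm"}],
)
print(response["choices"][0]["message"]["content"])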