with text-bison

This commit is contained in:
ishaan-jaff 2023-08-14 10:33:59 -07:00
parent 15944eb0f3
commit 1b4aadbb25
4 changed files with 85 additions and 29 deletions

View file

@ -100,12 +100,18 @@ openrouter_models = [
'meta-llama/llama-2-70b-chat' 'meta-llama/llama-2-70b-chat'
] ]
vertex_models = [ vertex_chat_models = [
"chat-bison", "chat-bison",
"chat-bison@001" "chat-bison@001"
] ]
model_list = open_ai_chat_completion_models + open_ai_text_completion_models + cohere_models + anthropic_models + replicate_models + openrouter_models + vertex_models
vertex_text_models = [
"text-bison",
"text-bison@001"
]
model_list = open_ai_chat_completion_models + open_ai_text_completion_models + cohere_models + anthropic_models + replicate_models + openrouter_models + vertex_chat_models + vertex_text_models
####### EMBEDDING MODELS ################### ####### EMBEDDING MODELS ###################
open_ai_embedding_models = [ open_ai_embedding_models = [

View file

@ -47,7 +47,10 @@ def completion(
temperature=1, top_p=1, n=1, stream=False, stop=None, max_tokens=float('inf'), temperature=1, top_p=1, n=1, stream=False, stop=None, max_tokens=float('inf'),
presence_penalty=0, frequency_penalty=0, logit_bias={}, user="", deployment_id=None, presence_penalty=0, frequency_penalty=0, logit_bias={}, user="", deployment_id=None,
# Optional liteLLM function params # Optional liteLLM function params
*, return_async=False, api_key=None, force_timeout=600, logger_fn=None, verbose=False, azure=False, custom_llm_provider=None, custom_api_base=None *, return_async=False, api_key=None, force_timeout=600, logger_fn=None, verbose=False, azure=False, custom_llm_provider=None, custom_api_base=None,
# model specific optional params
# used by text-bison only
top_k=40,
): ):
try: try:
global new_response global new_response
@ -61,7 +64,7 @@ def completion(
temperature=temperature, top_p=top_p, n=n, stream=stream, stop=stop, max_tokens=max_tokens, temperature=temperature, top_p=top_p, n=n, stream=stream, stop=stop, max_tokens=max_tokens,
presence_penalty=presence_penalty, frequency_penalty=frequency_penalty, logit_bias=logit_bias, user=user, deployment_id=deployment_id, presence_penalty=presence_penalty, frequency_penalty=frequency_penalty, logit_bias=logit_bias, user=user, deployment_id=deployment_id,
# params to identify the model # params to identify the model
model=model, custom_llm_provider=custom_llm_provider model=model, custom_llm_provider=custom_llm_provider, top_k=top_k,
) )
# For logging - save the values of the litellm-specific params passed in # For logging - save the values of the litellm-specific params passed in
litellm_params = get_litellm_params( litellm_params = get_litellm_params(
@ -366,7 +369,7 @@ def completion(
"total_tokens": prompt_tokens + completion_tokens "total_tokens": prompt_tokens + completion_tokens
} }
response = model_response response = model_response
elif model in litellm.vertex_models: elif model in litellm.vertex_chat_models:
# import vertexai/if it fails then pip install vertexai# import cohere/if it fails then pip install cohere # import vertexai/if it fails then pip install vertexai# import cohere/if it fails then pip install cohere
install_and_import("vertexai") install_and_import("vertexai")
import vertexai import vertexai
@ -387,6 +390,28 @@ def completion(
## LOGGING ## LOGGING
logging(model=model, input=prompt, custom_llm_provider=custom_llm_provider, additional_args={"max_tokens": max_tokens, "original_response": completion_response}, logger_fn=logger_fn) logging(model=model, input=prompt, custom_llm_provider=custom_llm_provider, additional_args={"max_tokens": max_tokens, "original_response": completion_response}, logger_fn=logger_fn)
## RESPONSE OBJECT
model_response["choices"][0]["message"]["content"] = completion_response
model_response["created"] = time.time()
model_response["model"] = model
elif model in litellm.vertex_text_models:
# import vertexai/if it fails then pip install vertexai# import cohere/if it fails then pip install cohere
install_and_import("vertexai")
import vertexai
from vertexai.language_models import TextGenerationModel
vertexai.init(project=litellm.vertex_project, location=litellm.vertex_location)
# vertexai does not use an API key, it looks for credentials.json in the environment
prompt = " ".join([message["content"] for message in messages])
## LOGGING
logging(model=model, input=prompt, custom_llm_provider=custom_llm_provider, logger_fn=logger_fn)
vertex_model = TextGenerationModel.from_pretrained(model)
completion_response= vertex_model.predict(prompt, **optional_params)
## LOGGING
logging(model=model, input=prompt, custom_llm_provider=custom_llm_provider, additional_args={"max_tokens": max_tokens, "original_response": completion_response}, logger_fn=logger_fn)
## RESPONSE OBJECT ## RESPONSE OBJECT
model_response["choices"][0]["message"]["content"] = completion_response model_response["choices"][0]["message"]["content"] = completion_response
model_response["created"] = time.time() model_response["created"] = time.time()

View file

@ -1,32 +1,49 @@
# import sys, os import sys, os
# import traceback import traceback
# from dotenv import load_dotenv from dotenv import load_dotenv
# load_dotenv() load_dotenv()
# import os import os
# sys.path.insert(0, os.path.abspath('../..')) # Adds the parent directory to the system path sys.path.insert(0, os.path.abspath('../..')) # Adds the parent directory to the system path
# import pytest import pytest
# import litellm import litellm
# from litellm import embedding, completion from litellm import embedding, completion
# litellm.vertex_project = "hardy-device-386718" litellm.vertex_project = "hardy-device-386718"
# litellm.vertex_location = "us-central1" litellm.vertex_location = "us-central1"
litellm.set_verbose = True
# user_message = "what's the weather in SF " user_message = "what's the weather in SF "
# messages = [{ "content": user_message,"role": "user"}] messages = [{ "content": user_message,"role": "user"}]
# def logger_fn(user_model_dict):
# print(f"user_model_dict: {user_model_dict}")
# chat-bison
# response = completion(model="chat-bison", messages=messages, temperature=0.5, top_p=0.1) # response = completion(model="chat-bison", messages=messages, temperature=0.5, top_p=0.1)
# print(response) # print(response)
# text-bison
# # chat_model = ChatModel.from_pretrained("chat-bison@001") # response = completion(model="text-bison@001", messages=messages)
# # parameters = { # print(response)
# # "temperature": 0.2,
# # "max_output_tokens": 256,
# # "top_p": 0.8,
# # "top_k": 40
# # }
# # chat = chat_model.start_chat() # response = completion(model="text-bison@001", messages=messages, temperature=0.1, logger_fn=logger_fn)
# # response = chat.send_message("who are u? write a sentence", **parameters) # print(response)
# # print(f"Response from Model: {response.text}")
# response = completion(model="text-bison@001", messages=messages, temperature=0.4, top_p=0.1, logger_fn=logger_fn)
# print(response)
# response = completion(model="text-bison@001", messages=messages, temperature=0.8, top_p=0.4, top_k=30, logger_fn=logger_fn)
# print(response)
# chat_model = ChatModel.from_pretrained("chat-bison@001")
# parameters = {
# "temperature": 0.2,
# "max_output_tokens": 256,
# "top_p": 0.8,
# "top_k": 40
# }
# chat = chat_model.start_chat()
# response = chat.send_message("who are u? write a sentence", **parameters)
# print(f"Response from Model: {response.text}")

View file

@ -246,7 +246,8 @@ def get_optional_params(
user = "", user = "",
deployment_id = None, deployment_id = None,
model = None, model = None,
custom_llm_provider = "" custom_llm_provider = "",
top_k = 40,
): ):
optional_params = {} optional_params = {}
if model in litellm.anthropic_models: if model in litellm.anthropic_models:
@ -293,6 +294,13 @@ def get_optional_params(
optional_params["top_p"] = top_p optional_params["top_p"] = top_p
if max_tokens != float('inf'): if max_tokens != float('inf'):
optional_params["max_output_tokens"] = max_tokens optional_params["max_output_tokens"] = max_tokens
elif model in litellm.vertex_text_models:
# required params for all text vertex calls
# temperature=0.2, top_p=0.1, top_k=20
# always set temperature, top_p, top_k else, text bison fails
optional_params["temperature"] = temperature
optional_params["top_p"] = top_p
optional_params["top_k"] = top_k
else:# assume passing in params for openai/azure openai else:# assume passing in params for openai/azure openai
if functions != []: if functions != []: