ai21 models

This commit is contained in:
ishaan-jaff 2023-08-14 16:40:26 -07:00
parent ccec725be4
commit 06b5579ba6
4 changed files with 51 additions and 1 deletions

View file

@ -69,6 +69,16 @@ Here are some examples of supported models:
| [google/flan-t5-xxl](https://huggingface.co/google/flan-t5-xxl) | `completion(model="google/flan-t5-xxl", messages=messages, custom_llm_provider="huggingface")` | `os.environ['HF_TOKEN']` |
| [google/flan-t5-large](https://huggingface.co/google/flan-t5-large) | `completion(model="google/flan-t5-large", messages=messages, custom_llm_provider="huggingface")` | `os.environ['HF_TOKEN']` |
### AI21 Models
| Model Name | Function Call | Required OS Variables |
|------------------|--------------------------------------------|--------------------------------------|
| j2-light | `completion('j2-light', messages)` | `os.environ['AI21_API_KEY']` |
| j2-mid | `completion('j2-mid', messages)` | `os.environ['AI21_API_KEY']` |
| j2-ultra | `completion('j2-ultra', messages)` | `os.environ['AI21_API_KEY']` |
### Cohere Models
| Model Name | Function Call | Required OS Variables |

View file

@ -124,7 +124,14 @@ huggingface_models = [
"meta-llama/Llama-2-70b-chat", "meta-llama/Llama-2-70b-chat",
] # these have been tested on extensively. But by default all text2text-generation and text-generation models are supported by liteLLM. - https://docs.litellm.ai/docs/completion/supported ] # these have been tested on extensively. But by default all text2text-generation and text-generation models are supported by liteLLM. - https://docs.litellm.ai/docs/completion/supported
# Jurassic-2 model identifiers served through the AI21 API.
ai21_models = ["j2-ultra", "j2-mid", "j2-light"]

# Aggregate registry of every model name liteLLM can route to.
model_list = (
    open_ai_chat_completion_models
    + open_ai_text_completion_models
    + cohere_models
    + anthropic_models
    + replicate_models
    + openrouter_models
    + huggingface_models
    + vertex_chat_models
    + vertex_text_models
    + ai21_models
)
####### EMBEDDING MODELS ################### ####### EMBEDDING MODELS ###################
open_ai_embedding_models = [ open_ai_embedding_models = [

View file

@ -394,6 +394,29 @@ def completion(
## LOGGING ## LOGGING
logging(model=model, input=prompt, custom_llm_provider=custom_llm_provider, additional_args={"max_tokens": max_tokens, "original_response": completion_response}, logger_fn=logger_fn) logging(model=model, input=prompt, custom_llm_provider=custom_llm_provider, additional_args={"max_tokens": max_tokens, "original_response": completion_response}, logger_fn=logger_fn)
## RESPONSE OBJECT
model_response["choices"][0]["message"]["content"] = completion_response
model_response["created"] = time.time()
model_response["model"] = model
response = model_response
elif model in litellm.ai21_models:
install_and_import("ai21")
import ai21
ai21.api_key = get_secret("AI21_API_KEY")
prompt = " ".join([message["content"] for message in messages])
## LOGGING
logging(model=model, input=prompt, custom_llm_provider=custom_llm_provider, logger_fn=logger_fn)
ai21_response = ai21.Completion.execute(
model=model,
prompt=prompt,
)
completion_response = ai21_response['completions'][0]['data']['text']
## LOGGING
logging(model=model, input=prompt, custom_llm_provider=custom_llm_provider, additional_args={"max_tokens": max_tokens, "original_response": completion_response}, logger_fn=logger_fn)
## RESPONSE OBJECT ## RESPONSE OBJECT
model_response["choices"][0]["message"]["content"] = completion_response model_response["choices"][0]["message"]["content"] = completion_response
model_response["created"] = time.time() model_response["created"] = time.time()

View file

@ -213,6 +213,16 @@ def test_completion_together_ai_stream():
except Exception as e: except Exception as e:
pytest.fail(f"Error occurred: {e}") pytest.fail(f"Error occurred: {e}")
#### Test AI21 ###################
def test_completion_ai21():
    """Smoke-test an AI21 Jurassic-2 model through the unified completion() interface.

    Fails the test (rather than erroring) on any exception so the report
    shows a clean failure message.
    """
    model_name = "j2-light"
    try:
        response = completion(model=model_name, messages=messages)
        print(response)
        # The AI21 branch populates choices[0].message.content, so a
        # successful call must return a non-empty message.
        assert response is not None
        assert response["choices"][0]["message"]["content"]
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")
# test config file with completion # # test config file with completion #
# def test_completion_openai_config(): # def test_completion_openai_config():
# try: # try: