adding coverage for ai21

Krrish Dholakia 2023-08-29 13:32:20 -07:00
parent f11599e50c
commit 436e8eadb2
4 changed files with 39 additions and 3 deletions


@@ -102,7 +102,7 @@ class AI21LLM:
         try:
             model_response["choices"][0]["message"]["content"] = completion_response["completions"][0]["data"]["text"]
         except:
-            raise ValueError(f"Unable to parse response. Original response: {response.text}")
+            raise AI21Error(message=json.dumps(completion_response), status_code=response.status_code)
         ## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
         prompt_tokens = len(
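
Note: the AI21Error raised above is not defined in this diff. A minimal sketch of what such a provider error class presumably looks like, with field names inferred from the call site (this is an assumption, not the repo's actual definition):

    # Hypothetical sketch -- mirrors the keyword arguments used at the raise site.
    class AI21Error(Exception):
        def __init__(self, status_code, message):
            self.status_code = status_code  # HTTP status returned by AI21
            self.message = message          # raw response body (json.dumps of the payload)
            super().__init__(self.message)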


@@ -33,8 +33,8 @@ litellm.failure_callback = ["sentry"]
 # Approach: Run each model through the test -> assert if the correct error (always the same one) is triggered
 # models = ["gpt-3.5-turbo", "chatgpt-test", "claude-instant-1", "command-nightly"]
-test_model = "togethercomputer/CodeLlama-34b-Python"
-models = ["togethercomputer/CodeLlama-34b-Python"]
+test_model = "j2-light"
+models = ["j2-light"]
 def logging_fn(model_call_dict):
@@ -98,6 +98,9 @@ def invalid_auth(model): # set the model key to an invalid key, depending on th
     elif model == "command-nightly":
         temporary_key = os.environ["COHERE_API_KEY"]
         os.environ["COHERE_API_KEY"] = "bad-key"
+    elif "j2" in model:
+        temporary_key = os.environ["AI21_API_KEY"]
+        os.environ["AI21_API_KEY"] = "bad-key"
     elif "togethercomputer" in model:
         temporary_key = os.environ["TOGETHERAI_API_KEY"]
         os.environ["TOGETHERAI_API_KEY"] = "84060c79880fc49df126d3e87b53f8a463ff6e1c6d27fe64207cde25cdfcd1f24a"
@@ -138,6 +141,8 @@ def invalid_auth(model): # set the model key to an invalid key, depending on th
             == "replicate/llama-2-70b-chat:2c1608e18606fad2812020dc541930f2d0495ce32eee50074220b87300bc16e1"
         ):
             os.environ["REPLICATE_API_KEY"] = temporary_key
+        elif "j2" in model:
+            os.environ["AI21_API_KEY"] = temporary_key
         elif ("togethercomputer" in model):
             os.environ["TOGETHERAI_API_KEY"] = temporary_key
         return
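
The pattern above swaps in a bad key, runs the model, and restores the real key afterwards. A condensed, hedged sketch of how such a test presumably exercises the new j2 branch (assuming AuthenticationError is importable from litellm, as elsewhere in this test suite):

    import os
    from litellm import completion, AuthenticationError

    def check_invalid_ai21_key():
        # Hypothetical condensed version of invalid_auth for AI21's j2 models.
        temporary_key = os.environ["AI21_API_KEY"]
        os.environ["AI21_API_KEY"] = "bad-key"
        try:
            completion(model="j2-light", messages=[{"role": "user", "content": "hi"}])
        except AuthenticationError:
            pass  # expected: a bad key should surface as AuthenticationError
        finally:
            os.environ["AI21_API_KEY"] = temporary_key  # always restore the real key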


@@ -1446,6 +1446,37 @@ def exception_type(model, original_exception, custom_llm_provider):
                     message=f"HuggingfaceException - {original_exception.message}",
                     llm_provider="huggingface",
                 )
elif custom_llm_provider == "ai21":
print(f"e: {original_exception}")
if hasattr(original_exception, "message"):
if "Prompt has too many tokens" in original_exception.message:
exception_mapping_worked = True
raise ContextWindowExceededError(
message=f"AI21Exception - {original_exception.message}",
model=model,
llm_provider="ai21"
)
if hasattr(original_exception, "status_code"):
print(f"status code: {original_exception.status_code}")
if original_exception.status_code == 401:
exception_mapping_worked = True
raise AuthenticationError(
message=f"AI21Exception - {original_exception.message}",
llm_provider="ai21",
)
if original_exception.status_code == 422 or "Prompt has too many tokens" in original_exception.message:
exception_mapping_worked = True
raise InvalidRequestError(
message=f"AI21Exception - {original_exception.message}",
model=model,
llm_provider="ai21",
)
elif original_exception.status_code == 429:
exception_mapping_worked = True
raise RateLimitError(
message=f"AI21Exception - {original_exception.message}",
llm_provider="ai21",
)
         elif custom_llm_provider == "together_ai":
             error_response = json.loads(error_str)
             if "error" in error_response and "`inputs` tokens + `max_new_tokens` must be <=" in error_response["error"]: