Adding exception handling for Together AI

This commit is contained in:
Krrish Dholakia 2023-08-29 12:29:56 -07:00
parent 35ecc91a71
commit 88bd1df3e0
8 changed files with 96 additions and 43 deletions

View file

@@ -286,10 +286,11 @@ from .utils import (
) )
from .main import * # type: ignore from .main import * # type: ignore
from .integrations import * from .integrations import *
from openai.error import ( from .exceptions import (
AuthenticationError, AuthenticationError,
InvalidRequestError, InvalidRequestError,
RateLimitError, RateLimitError,
ServiceUnavailableError, ServiceUnavailableError,
OpenAIError, OpenAIError,
ContextWindowExceededError
) )

View file

@@ -28,6 +28,17 @@ class InvalidRequestError(InvalidRequestError): # type: ignore
self.message, f"{self.model}" self.message, f"{self.model}"
) # Call the base class constructor with the parameters it needs ) # Call the base class constructor with the parameters it needs
# sub class of invalid request error - meant to give more granularity for error handling context window exceeded errors
class ContextWindowExceededError(InvalidRequestError):  # type: ignore
    """Raised when a prompt exceeds the model's context window.

    Subclass of InvalidRequestError so existing handlers that catch the
    broader error keep working, while callers that care can catch this
    more specific case.
    """

    def __init__(self, message, model, llm_provider):
        # HTTP 400: the request itself is invalid (input too long).
        self.status_code = 400
        self.message = message
        self.model = model
        self.llm_provider = llm_provider
        # Forward the same three values the base class expects.
        super().__init__(self.message, self.model, self.llm_provider)
class RateLimitError(RateLimitError): # type: ignore class RateLimitError(RateLimitError): # type: ignore
def __init__(self, message, llm_provider): def __init__(self, message, llm_provider):

View file

@@ -1,4 +1,4 @@
import os, openai, sys import os, openai, sys, json
from typing import Any from typing import Any
from functools import partial from functools import partial
import dotenv, traceback, random, asyncio, time, contextvars import dotenv, traceback, random, asyncio, time, contextvars
@@ -539,6 +539,7 @@ def completion(
return response return response
response = model_response response = model_response
elif custom_llm_provider == "together_ai" or ("togethercomputer" in model): elif custom_llm_provider == "together_ai" or ("togethercomputer" in model):
custom_llm_provider = "together_ai"
import requests import requests
TOGETHER_AI_TOKEN = ( TOGETHER_AI_TOKEN = (
@@ -594,10 +595,10 @@ def completion(
) )
# make this safe for reading, if output does not exist raise an error # make this safe for reading, if output does not exist raise an error
json_response = res.json() json_response = res.json()
if "output" not in json_response: if "error" in json_response:
raise Exception( raise Exception(json.dumps(json_response))
f"liteLLM: Error Making TogetherAI request, JSON Response {json_response}" elif "error" in json_response["output"]:
) raise Exception(json.dumps(json_response["output"]))
completion_response = json_response["output"]["choices"][0]["text"] completion_response = json_response["output"]["choices"][0]["text"]
prompt_tokens = len(encoding.encode(prompt)) prompt_tokens = len(encoding.encode(prompt))
completion_tokens = len(encoding.encode(completion_response)) completion_tokens = len(encoding.encode(completion_response))

View file

@@ -12,6 +12,7 @@ from litellm import (
completion, completion,
AuthenticationError, AuthenticationError,
InvalidRequestError, InvalidRequestError,
ContextWindowExceededError,
RateLimitError, RateLimitError,
ServiceUnavailableError, ServiceUnavailableError,
OpenAIError, OpenAIError,
@@ -32,11 +33,12 @@ litellm.failure_callback = ["sentry"]
# Approach: Run each model through the test -> assert if the correct error (always the same one) is triggered # Approach: Run each model through the test -> assert if the correct error (always the same one) is triggered
# models = ["gpt-3.5-turbo", "chatgpt-test", "claude-instant-1", "command-nightly"] # models = ["gpt-3.5-turbo", "chatgpt-test", "claude-instant-1", "command-nightly"]
test_model = "claude-instant-1" test_model = "togethercomputer/CodeLlama-34b-Python"
models = ["claude-instant-1"] models = ["togethercomputer/CodeLlama-34b-Python"]
def logging_fn(model_call_dict): def logging_fn(model_call_dict):
return
if "model" in model_call_dict: if "model" in model_call_dict:
print(f"model_call_dict: {model_call_dict['model']}") print(f"model_call_dict: {model_call_dict['model']}")
else: else:
@@ -49,15 +51,16 @@ def test_context_window(model):
sample_text = "how does a court case get to the Supreme Court?" * 5000 sample_text = "how does a court case get to the Supreme Court?" * 5000
messages = [{"content": sample_text, "role": "user"}] messages = [{"content": sample_text, "role": "user"}]
try: try:
model = "chatgpt-test"
print(f"model: {model}") print(f"model: {model}")
response = completion( response = completion(
model=model, model=model,
messages=messages, messages=messages,
custom_llm_provider="azure",
logger_fn=logging_fn, logger_fn=logging_fn,
) )
print(f"response: {response}") print(f"response: {response}")
except ContextWindowExceededError as e:
print(f"ContextWindowExceededError: {e.llm_provider}")
return
except InvalidRequestError as e: except InvalidRequestError as e:
print(f"InvalidRequestError: {e.llm_provider}") print(f"InvalidRequestError: {e.llm_provider}")
return return
@@ -95,6 +98,9 @@ def invalid_auth(model): # set the model key to an invalid key, depending on th
elif model == "command-nightly": elif model == "command-nightly":
temporary_key = os.environ["COHERE_API_KEY"] temporary_key = os.environ["COHERE_API_KEY"]
os.environ["COHERE_API_KEY"] = "bad-key" os.environ["COHERE_API_KEY"] = "bad-key"
elif "togethercomputer" in model:
temporary_key = os.environ["TOGETHERAI_API_KEY"]
os.environ["TOGETHERAI_API_KEY"] = "84060c79880fc49df126d3e87b53f8a463ff6e1c6d27fe64207cde25cdfcd1f24a"
elif ( elif (
model model
== "replicate/llama-2-70b-chat:2c1608e18606fad2812020dc541930f2d0495ce32eee50074220b87300bc16e1" == "replicate/llama-2-70b-chat:2c1608e18606fad2812020dc541930f2d0495ce32eee50074220b87300bc16e1"
@@ -132,46 +138,48 @@ def invalid_auth(model): # set the model key to an invalid key, depending on th
== "replicate/llama-2-70b-chat:2c1608e18606fad2812020dc541930f2d0495ce32eee50074220b87300bc16e1" == "replicate/llama-2-70b-chat:2c1608e18606fad2812020dc541930f2d0495ce32eee50074220b87300bc16e1"
): ):
os.environ["REPLICATE_API_KEY"] = temporary_key os.environ["REPLICATE_API_KEY"] = temporary_key
elif ("togethercomputer" in model):
os.environ["TOGETHERAI_API_KEY"] = temporary_key
return return
invalid_auth(test_model) invalid_auth(test_model)
# # Test 3: Rate Limit Errors # Test 3: Rate Limit Errors
# def test_model(model): def test_model(model):
# try: try:
# sample_text = "how does a court case get to the Supreme Court?" * 50000 sample_text = "how does a court case get to the Supreme Court?" * 50000
# messages = [{ "content": sample_text,"role": "user"}] messages = [{ "content": sample_text,"role": "user"}]
# custom_llm_provider = None custom_llm_provider = None
# if model == "chatgpt-test": if model == "chatgpt-test":
# custom_llm_provider = "azure" custom_llm_provider = "azure"
# print(f"model: {model}") print(f"model: {model}")
# response = completion(model=model, messages=messages, custom_llm_provider=custom_llm_provider) response = completion(model=model, messages=messages, custom_llm_provider=custom_llm_provider)
# except RateLimitError: except RateLimitError:
# return True return True
# except OpenAIError: # is at least an openai error -> in case of random model errors - e.g. overloaded server except OpenAIError: # is at least an openai error -> in case of random model errors - e.g. overloaded server
# return True return True
# except Exception as e: except Exception as e:
# print(f"Uncaught Exception {model}: {type(e).__name__} - {e}") print(f"Uncaught Exception {model}: {type(e).__name__} - {e}")
# pass pass
# return False return False
# # Repeat each model 500 times # Repeat each model 500 times
# extended_models = [model for model in models for _ in range(250)] extended_models = [model for model in models for _ in range(250)]
# def worker(model): def worker(model):
# return test_model(model) return test_model(model)
# # Create a dictionary to store the results # Create a dictionary to store the results
# counts = {True: 0, False: 0} counts = {True: 0, False: 0}
# # Use Thread Pool Executor # Use Thread Pool Executor
# with ThreadPoolExecutor(max_workers=500) as executor: with ThreadPoolExecutor(max_workers=500) as executor:
# # Use map to start the operation in thread pool # Use map to start the operation in thread pool
# results = executor.map(worker, extended_models) results = executor.map(worker, extended_models)
# # Iterate over results and count True/False # Iterate over results and count True/False
# for result in results: for result in results:
# counts[result] += 1 counts[result] += 1
# accuracy_score = counts[True]/(counts[True] + counts[False]) accuracy_score = counts[True]/(counts[True] + counts[False])
# print(f"accuracy_score: {accuracy_score}") print(f"accuracy_score: {accuracy_score}")

View file

@@ -25,6 +25,7 @@ from .exceptions import (
RateLimitError, RateLimitError,
ServiceUnavailableError, ServiceUnavailableError,
OpenAIError, OpenAIError,
ContextWindowExceededError
) )
from typing import List, Dict, Union, Optional from typing import List, Dict, Union, Optional
from .caching import Cache from .caching import Cache
@@ -1445,6 +1446,37 @@ def exception_type(model, original_exception, custom_llm_provider):
message=f"HuggingfaceException - {original_exception.message}", message=f"HuggingfaceException - {original_exception.message}",
llm_provider="huggingface", llm_provider="huggingface",
) )
elif custom_llm_provider == "together_ai":
error_response = json.loads(error_str)
if "error" in error_response and "`inputs` tokens + `max_new_tokens` must be <=" in error_response["error"]:
exception_mapping_worked = True
raise ContextWindowExceededError(
message=error_response["error"],
model=model,
llm_provider="together_ai"
)
elif "error" in error_response and "invalid private key" in error_response["error"]:
exception_mapping_worked = True
raise AuthenticationError(
message=error_response["error"],
llm_provider="together_ai"
)
elif "error" in error_response and "INVALID_ARGUMENT" in error_response["error"]:
exception_mapping_worked = True
raise InvalidRequestError(
message=error_response["error"],
model=model,
llm_provider="together_ai"
)
elif "error_type" in error_response and error_response["error_type"] == "validation":
exception_mapping_worked = True
raise InvalidRequestError(
message=error_response["error"],
model=model,
llm_provider="together_ai"
)
print(f"error: {error_response}")
print(f"e: {original_exception}")
raise original_exception # base case - return the original exception raise original_exception # base case - return the original exception
else: else:
raise original_exception raise original_exception