refactor(bedrock.py-+-cohere.py): making bedrock and cohere compatible with openai v1 sdk

Krrish Dholakia 2023-11-11 17:33:19 -08:00
parent 39c2597c33
commit 547598a134
7 changed files with 82 additions and 74 deletions

litellm/llms/bedrock.py

@@ -6,11 +6,14 @@ from typing import Callable, Optional
 import litellm
 from litellm.utils import ModelResponse, get_secret
 from .prompt_templates.factory import prompt_factory, custom_prompt
+import httpx

 class BedrockError(Exception):
     def __init__(self, status_code, message):
         self.status_code = status_code
         self.message = message
+        self.request = httpx.Request(method="POST", url="https://us-west-2.console.aws.amazon.com/bedrock")
+        self.response = httpx.Response(status_code=status_code, request=self.request)
         super().__init__(
             self.message
         )  # Call the base class constructor with the parameters it needs
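
Note: the OpenAI v1 Python SDK attaches an httpx.Request/httpx.Response pair to its exception types, so provider errors need both attributes for litellm to re-raise them as OpenAI-style errors. A minimal, self-contained sketch of the pattern above and what it enables for callers (the class is re-declared here purely for illustration):

    import httpx

    class BedrockError(Exception):
        def __init__(self, status_code, message):
            self.status_code = status_code
            self.message = message
            # Bedrock is called via boto3, so no httpx objects exist naturally;
            # synthetic ones are attached for OpenAI-SDK-compatible exceptions.
            self.request = httpx.Request(
                method="POST",
                url="https://us-west-2.console.aws.amazon.com/bedrock",
            )
            self.response = httpx.Response(status_code=status_code, request=self.request)
            super().__init__(self.message)

    try:
        raise BedrockError(status_code=429, message="Too many requests")
    except BedrockError as e:
        # Both attributes are available, mirroring openai.APIStatusError
        print(e.status_code, e.response.status_code, e.request.url)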

litellm/llms/cohere.py

@@ -6,11 +6,14 @@ import time, traceback
 from typing import Callable, Optional
 from litellm.utils import ModelResponse, Choices, Message
 import litellm
+import httpx

 class CohereError(Exception):
     def __init__(self, status_code, message):
         self.status_code = status_code
         self.message = message
+        self.request = httpx.Request(method="POST", url="https://api.cohere.ai/v1/generate")
+        self.response = httpx.Response(status_code=status_code, request=self.request)
         super().__init__(
             self.message
         )  # Call the base class constructor with the parameters it needs
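
For context, the reason the response object matters: OpenAI v1 status errors (openai.APIStatusError and subclasses such as openai.AuthenticationError) require an httpx.Response at construction time. A hedged sketch of the conversion this enables; to_openai_error is a hypothetical helper for illustration, not litellm's actual code path:

    import httpx
    import openai  # openai>=1.0

    def to_openai_error(status_code: int, message: str) -> openai.APIStatusError:
        # Build the httpx pair the OpenAI v1 exception constructors require
        request = httpx.Request(method="POST", url="https://api.cohere.ai/v1/generate")
        response = httpx.Response(status_code=status_code, request=request)
        if status_code == 401:
            return openai.AuthenticationError(message, response=response, body=None)
        if status_code == 429:
            return openai.RateLimitError(message, response=response, body=None)
        return openai.APIStatusError(message, response=response, body=None)

    print(type(to_openai_error(401, "invalid api key")).__name__)  # AuthenticationError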

litellm/llms/huggingface_restapi.py

@@ -2,7 +2,7 @@
 import os, copy, types
 import json
 from enum import Enum
-import requests
+import httpx, requests
 import time
 import litellm
 from typing import Callable, Dict, List, Any
@@ -14,6 +14,8 @@ class HuggingfaceError(Exception):
     def __init__(self, status_code, message):
         self.status_code = status_code
         self.message = message
+        self.response = httpx.Response(status_code=status_code)
+        self.request = self.response.request
         super().__init__(
             self.message
         )  # Call the base class constructor with the parameters it needs
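
One caveat on the HuggingFace variant above: in httpx, Response.request raises a RuntimeError when the response was built without a request attached, so reading self.response.request on a bare httpx.Response(status_code=...) fails at runtime. A minimal sketch of the safer construction, mirroring the bedrock/cohere pattern (the URL is illustrative):

    import httpx

    resp = httpx.Response(status_code=500)
    try:
        resp.request
    except RuntimeError as e:
        print(f"RuntimeError: {e}")  # "The request instance has not been set..."

    # Attaching a request up front avoids the issue
    req = httpx.Request(method="POST", url="https://example.invalid/inference")
    resp = httpx.Response(status_code=500, request=req)
    print(resp.request.method, resp.request.url)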

litellm/tests/test_async_fn.py

@@ -28,7 +28,7 @@ def test_async_response():
     user_message = "Hello, how are you?"
     messages = [{"content": user_message, "role": "user"}]
     try:
-        response = await acompletion(model="azure/chatgpt-v-2", messages=messages)
+        response = await acompletion(model="command-nightly", messages=messages)
         print(f"response: {response}")
     except Exception as e:
         pytest.fail(f"An exception occurred: {e}")
@@ -42,7 +42,7 @@ def test_get_response_streaming():
     user_message = "write a short poem in one sentence"
     messages = [{"content": user_message, "role": "user"}]
     try:
-        response = await acompletion(model="azure/chatgpt-v-2", messages=messages, stream=True)
+        response = await acompletion(model="command-nightly", messages=messages, stream=True)
         print(type(response))

         import inspect
@@ -65,7 +65,7 @@ def test_get_response_streaming():
     asyncio.run(test_async_call())

-test_get_response_streaming()
+# test_get_response_streaming()

 def test_get_response_non_openai_streaming():
     import asyncio
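
For reference, the async tests above exercise litellm's acompletion entry point; a standalone sketch of the call pattern being tested (assumes litellm is installed and COHERE_API_KEY is set for command-nightly):

    import asyncio
    from litellm import acompletion

    async def main():
        # Non-streaming call, as in test_async_response
        response = await acompletion(
            model="command-nightly",
            messages=[{"content": "Hello, how are you?", "role": "user"}],
        )
        print(response)

    asyncio.run(main())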

litellm/tests/test_completion.py

@@ -395,7 +395,7 @@ def test_completion_cohere(): # commenting for now as the cohere endpoint is bei
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
-# test_completion_cohere()
+# # test_completion_cohere()

 def test_completion_openai():
@@ -634,7 +634,7 @@ def test_completion_azure():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
-test_completion_azure()
+# test_completion_azure()

 def test_completion_azure2():
     # test if we can pass api_base, api_version and api_key in compleition()
     try:
@@ -941,7 +941,7 @@ def test_completion_bedrock_claude():
         pass
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
-# test_completion_bedrock_claude()
+test_completion_bedrock_claude()

 def test_completion_bedrock_cohere():
     print("calling bedrock cohere")

litellm/tests/test_exceptions.py

@@ -42,9 +42,10 @@ models = ["command-nightly"]
 # Test 1: Context Window Errors
 @pytest.mark.parametrize("model", models)
 def test_context_window(model):
-    sample_text = "Say error 50 times" * 10000
+    sample_text = "Say error 50 times" * 1000000
     messages = [{"content": sample_text, "role": "user"}]
     try:
+        litellm.set_verbose = False
         response = completion(model=model, messages=messages)
         print(f"response: {response}")
         print("FAILED!")
@@ -67,8 +68,8 @@ def test_context_window_with_fallbacks(model):
 # for model in litellm.models_by_provider["bedrock"]:
 #     test_context_window(model=model)
-# test_context_window(model="azure/chatgpt-v-2")
-# test_context_window_with_fallbacks(model="azure/chatgpt-v-2")
+# test_context_window(model="command-nightly")
+# test_context_window_with_fallbacks(model="command-nightly")

 # Test 2: InvalidAuth Errors
 @pytest.mark.parametrize("model", models)
 def invalid_auth(model):  # set the model key to an invalid key, depending on the model
@@ -78,7 +79,7 @@ def invalid_auth(model): # set the model key to an invalid key, depending on th
     if model == "gpt-3.5-turbo" or model == "gpt-3.5-turbo-instruct":
         temporary_key = os.environ["OPENAI_API_KEY"]
         os.environ["OPENAI_API_KEY"] = "bad-key"
-    elif model == "bedrock/anthropic.claude-v2":
+    elif "bedrock" in model:
         temporary_aws_access_key = os.environ["AWS_ACCESS_KEY_ID"]
         os.environ["AWS_ACCESS_KEY_ID"] = "bad-key"
         temporary_aws_region_name = os.environ["AWS_REGION_NAME"]
@@ -163,7 +164,7 @@ def invalid_auth(model): # set the model key to an invalid key, depending on th
 # for model in litellm.models_by_provider["bedrock"]:
 #     invalid_auth(model=model)
-# invalid_auth(model="azure/chatgpt-v-2")
+# invalid_auth(model="command-nightly")

 # Test 3: Invalid Request Error
 @pytest.mark.parametrize("model", models)
@@ -173,7 +174,7 @@ def test_invalid_request_error(model):
     with pytest.raises(BadRequestError):
         completion(model=model, messages=messages, max_tokens="hello world")

-# test_invalid_request_error(model="azure/chatgpt-v-2")
+# test_invalid_request_error(model="command-nightly")

 # Test 3: Rate Limit Errors
 # def test_model_call(model):
 #     try:
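
The context-window test above depends on litellm mapping provider errors onto OpenAI-style exception classes; a hedged sketch of the expected caller-side behavior (assumes litellm exposes ContextWindowExceededError, which this test file imports):

    from litellm import completion, ContextWindowExceededError

    sample_text = "Say error 50 times" * 1000000  # far beyond any context window
    messages = [{"content": sample_text, "role": "user"}]
    try:
        completion(model="command-nightly", messages=messages)
    except ContextWindowExceededError as e:
        # The mapped exception carries OpenAI-v1-style metadata
        print("context window exceeded:", e.status_code)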

litellm/tests/test_streaming.py

@@ -136,37 +136,37 @@ def streaming_format_tests(idx, chunk):
     print(f"extracted chunk: {extracted_chunk}")
     return extracted_chunk, finished

-# def test_completion_cohere_stream():
-# this is a flaky test due to the cohere API endpoint being unstable
-#     try:
-#         messages = [
-#             {"role": "system", "content": "You are a helpful assistant."},
-#             {
-#                 "role": "user",
-#                 "content": "how does a court case get to the Supreme Court?",
-#             },
-#         ]
-#         response = completion(
-#             model="command-nightly", messages=messages, stream=True, max_tokens=50,
-#         )
-#         complete_response = ""
-#         # Add any assertions here to check the response
-#         has_finish_reason = False
-#         for idx, chunk in enumerate(response):
-#             chunk, finished = streaming_format_tests(idx, chunk)
-#             has_finish_reason = finished
-#             if finished:
-#                 break
-#             complete_response += chunk
-#         if has_finish_reason is False:
-#             raise Exception("Finish reason not in final chunk")
-#         if complete_response.strip() == "":
-#             raise Exception("Empty response received")
-#         print(f"completion_response: {complete_response}")
-#     except Exception as e:
-#         pytest.fail(f"Error occurred: {e}")
-# test_completion_cohere_stream()
+def test_completion_cohere_stream():
+    # this is a flaky test due to the cohere API endpoint being unstable
+    try:
+        messages = [
+            {"role": "system", "content": "You are a helpful assistant."},
+            {
+                "role": "user",
+                "content": "how does a court case get to the Supreme Court?",
+            },
+        ]
+        response = completion(
+            model="command-nightly", messages=messages, stream=True, max_tokens=50,
+        )
+        complete_response = ""
+        # Add any assertions here to check the response
+        has_finish_reason = False
+        for idx, chunk in enumerate(response):
+            chunk, finished = streaming_format_tests(idx, chunk)
+            has_finish_reason = finished
+            if finished:
+                break
+            complete_response += chunk
+        if has_finish_reason is False:
+            raise Exception("Finish reason not in final chunk")
+        if complete_response.strip() == "":
+            raise Exception("Empty response received")
+        print(f"completion_response: {complete_response}")
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+test_completion_cohere_stream()

 def test_completion_cohere_stream_bad_key():
     try:
@@ -372,7 +372,7 @@ def test_completion_azure_stream():
         print(f"completion_response: {complete_response}")
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
-test_completion_azure_stream()
+# test_completion_azure_stream()

 def test_completion_claude_stream():
     try:
@@ -634,40 +634,39 @@ def test_completion_replicate_stream_bad_key():
 # test_completion_replicate_stream_bad_key()

-# def test_completion_bedrock_claude_stream():
-#     try:
-#         litellm.set_verbose=False
-#         response = completion(
-#             model="bedrock/anthropic.claude-instant-v1",
-#             messages=[{"role": "user", "content": "Be as verbose as possible and give as many details as possible, how does a court case get to the Supreme Court?"}],
-#             temperature=1,
-#             max_tokens=20,
-#             stream=True,
-#         )
-#         print(response)
-#         complete_response = ""
-#         has_finish_reason = False
-#         # Add any assertions here to check the response
-#         for idx, chunk in enumerate(response):
-#             # print
-#             chunk, finished = streaming_format_tests(idx, chunk)
-#             has_finish_reason = finished
-#             complete_response += chunk
-#             if finished:
-#                 break
-#         if has_finish_reason is False:
-#             raise Exception("finish reason not set for last chunk")
-#         if complete_response.strip() == "":
-#             raise Exception("Empty response received")
-#         print(f"completion_response: {complete_response}")
-#     except RateLimitError:
-#         pass
-#     except Exception as e:
-#         pytest.fail(f"Error occurred: {e}")
+def test_completion_bedrock_claude_stream():
+    try:
+        litellm.set_verbose=False
+        response = completion(
+            model="bedrock/anthropic.claude-instant-v1",
+            messages=[{"role": "user", "content": "Be as verbose as possible and give as many details as possible, how does a court case get to the Supreme Court?"}],
+            temperature=1,
+            max_tokens=20,
+            stream=True,
+        )
+        print(response)
+        complete_response = ""
+        has_finish_reason = False
+        # Add any assertions here to check the response
+        for idx, chunk in enumerate(response):
+            # print
+            chunk, finished = streaming_format_tests(idx, chunk)
+            has_finish_reason = finished
+            complete_response += chunk
+            if finished:
+                break
+        if has_finish_reason is False:
+            raise Exception("finish reason not set for last chunk")
+        if complete_response.strip() == "":
+            raise Exception("Empty response received")
+        print(f"completion_response: {complete_response}")
+    except RateLimitError:
+        pass
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
 # test_completion_bedrock_claude_stream()

 # def test_completion_sagemaker_stream():
 #     try:
 #         response = completion(
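
Finally, on the streaming tests: streaming_format_tests extracts text from each chunk, which in litellm follows the OpenAI streaming shape. A minimal sketch of that extraction, assuming the OpenAI-style delta layout (choices[0].delta.content, with finish_reason set on the final chunk):

    from litellm import completion

    response = completion(
        model="command-nightly",
        messages=[{"role": "user", "content": "say hello"}],
        stream=True,
        max_tokens=20,
    )

    complete_response = ""
    for chunk in response:
        delta = chunk.choices[0].delta
        if getattr(delta, "content", None):
            complete_response += delta.content
        if chunk.choices[0].finish_reason is not None:
            break
    print(complete_response)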