fix(utils.py): fix response object mapping

This commit is contained in:
Krrish Dholakia 2023-11-13 15:58:25 -08:00
parent 548933d0bb
commit 38ff412b9a
2 changed files with 32 additions and 34 deletions

View file

@ -136,35 +136,35 @@ def streaming_format_tests(idx, chunk):
print(f"extracted chunk: {extracted_chunk}") print(f"extracted chunk: {extracted_chunk}")
return extracted_chunk, finished return extracted_chunk, finished
def test_completion_cohere_stream(): # def test_completion_cohere_stream():
# this is a flaky test due to the cohere API endpoint being unstable # # this is a flaky test due to the cohere API endpoint being unstable
try: # try:
messages = [ # messages = [
{"role": "system", "content": "You are a helpful assistant."}, # {"role": "system", "content": "You are a helpful assistant."},
{ # {
"role": "user", # "role": "user",
"content": "how does a court case get to the Supreme Court?", # "content": "how does a court case get to the Supreme Court?",
}, # },
] # ]
response = completion( # response = completion(
model="command-nightly", messages=messages, stream=True, max_tokens=50, # model="command-nightly", messages=messages, stream=True, max_tokens=50,
) # )
complete_response = "" # complete_response = ""
# Add any assertions here to check the response # # Add any assertions here to check the response
has_finish_reason = False # has_finish_reason = False
for idx, chunk in enumerate(response): # for idx, chunk in enumerate(response):
chunk, finished = streaming_format_tests(idx, chunk) # chunk, finished = streaming_format_tests(idx, chunk)
has_finish_reason = finished # has_finish_reason = finished
if finished: # if finished:
break # break
complete_response += chunk # complete_response += chunk
if has_finish_reason is False: # if has_finish_reason is False:
raise Exception("Finish reason not in final chunk") # raise Exception("Finish reason not in final chunk")
if complete_response.strip() == "": # if complete_response.strip() == "":
raise Exception("Empty response received") # raise Exception("Empty response received")
print(f"completion_response: {complete_response}") # print(f"completion_response: {complete_response}")
except Exception as e: # except Exception as e:
pytest.fail(f"Error occurred: {e}") # pytest.fail(f"Error occurred: {e}")
# test_completion_cohere_stream() # test_completion_cohere_stream()
@ -493,7 +493,7 @@ def test_completion_claude_stream_bad_key():
pytest.fail(f"Error occurred: {e}") pytest.fail(f"Error occurred: {e}")
test_completion_claude_stream_bad_key() # test_completion_claude_stream_bad_key()
# test_completion_replicate_stream() # test_completion_replicate_stream()
# def test_completion_vertexai_stream(): # def test_completion_vertexai_stream():
@ -767,8 +767,6 @@ def ai21_completion_call_bad_key():
if complete_response.strip() == "": if complete_response.strip() == "":
raise Exception("Empty response received") raise Exception("Empty response received")
print(f"completion_response: {complete_response}") print(f"completion_response: {complete_response}")
except Bad as e:
pass
except: except:
pytest.fail(f"error occurred: {traceback.format_exc()}") pytest.fail(f"error occurred: {traceback.format_exc()}")
@ -848,7 +846,7 @@ def test_openai_chat_completion_call():
print(f"error occurred: {traceback.format_exc()}") print(f"error occurred: {traceback.format_exc()}")
pass pass
test_openai_chat_completion_call() # test_openai_chat_completion_call()
def test_openai_chat_completion_complete_response_call(): def test_openai_chat_completion_complete_response_call():
try: try:

View file

@ -1108,7 +1108,7 @@ def client(original_function):
if cached_result != None: if cached_result != None:
print_verbose(f"Cache Hit!") print_verbose(f"Cache Hit!")
call_type = original_function.__name__ call_type = original_function.__name__
if call_type == CallTypes.completion.value: if call_type == CallTypes.completion.value and isinstance(cached_result, dict):
return convert_to_model_response_object(response_object=cached_result, model_response_object=ModelResponse()) return convert_to_model_response_object(response_object=cached_result, model_response_object=ModelResponse())
else: else:
return cached_result return cached_result