fix(test-fixes): test fixes

This commit is contained in:
Krrish Dholakia 2023-10-10 08:09:42 -07:00
parent 22ee0c6931
commit 1c9f87751d
6 changed files with 10 additions and 8 deletions

View file

@ -163,7 +163,7 @@ def completion(
choices_list.append(choice_obj) choices_list.append(choice_obj)
model_response["choices"] = choices_list model_response["choices"] = choices_list
except Exception as e: except Exception as e:
raise CohereError(message=traceback.format_exc(), status_code=response.status_code) raise CohereError(message=response.text, status_code=response.status_code)
## CALCULATING USAGE ## CALCULATING USAGE
prompt_tokens = len( prompt_tokens = len(

View file

@ -1,4 +1,4 @@
import os, types, traceback import os, types, traceback, copy
import json import json
from enum import Enum from enum import Enum
import time import time
@ -87,10 +87,12 @@ def completion(
model = model model = model
## Load Config ## Load Config
inference_params = copy.deepcopy(optional_params)
inference_params.pop("stream") # palm does not support streaming, so we handle this by fake streaming in main.py
config = litellm.PalmConfig.get_config() config = litellm.PalmConfig.get_config()
for k, v in config.items(): for k, v in config.items():
if k not in optional_params: # completion(top_k=3) > palm_config(top_k=3) <- allows for dynamic variables to be passed in if k not in inference_params: # completion(top_k=3) > palm_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v inference_params[k] = v
prompt = "" prompt = ""
for message in messages: for message in messages:
@ -110,11 +112,11 @@ def completion(
logging_obj.pre_call( logging_obj.pre_call(
input=prompt, input=prompt,
api_key="", api_key="",
additional_args={"complete_input_dict": {"optional_params": optional_params}}, additional_args={"complete_input_dict": {"inference_params": inference_params}},
) )
## COMPLETION CALL ## COMPLETION CALL
try: try:
response = palm.generate_text(prompt=prompt, **optional_params) response = palm.generate_text(prompt=prompt, **inference_params)
except Exception as e: except Exception as e:
raise PalmError( raise PalmError(
message=str(e), message=str(e),

View file

@ -45,7 +45,7 @@ def test_context_window(model):
with pytest.raises(ContextWindowExceededError): with pytest.raises(ContextWindowExceededError):
completion(model=model, messages=messages) completion(model=model, messages=messages)
test_context_window(model="gpt-3.5-turbo") # test_context_window(model="command-nightly")
# Test 2: InvalidAuth Errors # Test 2: InvalidAuth Errors
@pytest.mark.parametrize("model", models) @pytest.mark.parametrize("model", models)
def invalid_auth(model): # set the model key to an invalid key, depending on the model def invalid_auth(model): # set the model key to an invalid key, depending on the model

View file

@ -453,7 +453,7 @@ def test_completion_palm_stream():
print(f"completion_response: {complete_response}") print(f"completion_response: {complete_response}")
except Exception as e: except Exception as e:
pytest.fail(f"Error occurred: {e}") pytest.fail(f"Error occurred: {e}")
# test_completion_palm_stream() test_completion_palm_stream()
# def test_completion_deep_infra_stream(): # def test_completion_deep_infra_stream():
# # deep infra currently includes role in the 2nd chunk # # deep infra currently includes role in the 2nd chunk