(test) replicate with params

This commit is contained in:
ishaan-jaff 2023-11-08 14:28:06 -08:00
parent b102489f49
commit 01a7660a12

View file

@ -123,7 +123,7 @@ def test_completion_gpt4_turbo():
pass
except Exception as e:
pytest.fail(f"Error occurred: {e}")
# test_completion_gpt4_turbo()
def test_completion_perplexity_api():
try:
@ -756,37 +756,19 @@ def test_completion_azure_deployment_id():
# test_completion_anthropic_openai_proxy()
# Replicate API endpoints are unstable -> throw random CUDA errors -> this means our tests can fail even if our tests weren't incorrect.
# def test_completion_replicate_llama_2():
# model_name = "replicate/meta/llama-2-70b-chat:02e509c789964a7ea8736978a43525956ef40397be9033abf9fd2badfe68c9e3"
# litellm.replicate_config(max_new_tokens=200)
# try:
# response = completion(
# model=model_name,
# messages=messages,
# )
# print(response)
# cost = completion_cost(completion_response=response)
# print("Cost for completion call with llama-2: ", f"${float(cost):.10f}")
# # Add any assertions here to check the response
# response_str = response["choices"][0]["message"]["content"]
# print(response_str)
# if type(response_str) != str:
# pytest.fail(f"Error occurred: {e}")
# except Exception as e:
# pytest.fail(f"Error occurred: {e}")
# test_completion_replicate_llama_2()
def test_completion_replicate_vicuna():
print("TESTING REPLICATE")
litellm.set_verbose=False
model_name = "replicate/vicuna-13b:6282abe6a492de4145d7bb601023762212f9ddbbe78278bd6771c8b3b2f2a13b"
try:
response = completion(
model=model_name,
messages=messages,
temperature=0.5,
top_k=20,
repetition_penalty=1,
min_tokens=1,
seed=-1,
max_tokens=20,
)
print(response)