diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py
index 475dc802b..e076900a2 100644
--- a/litellm/tests/test_completion.py
+++ b/litellm/tests/test_completion.py
@@ -352,12 +352,12 @@ def test_completion_azure_deployment_id():
 
 # Replicate API endpoints are unstable -> throw random CUDA errors -> this means our tests can fail even if our tests weren't incorrect.
 def test_completion_replicate_llama_2():
-    litellm.set_verbose = True
     model_name = "replicate/llama-2-70b-chat:2796ee9483c3fd7aa2e171d38f4ca12251a30609463dcfd4cd76703f22e96cdf"
     try:
         response = completion(
             model=model_name,
             messages=messages,
+            max_tokens=20,
             custom_llm_provider="replicate"
         )
         print(response)
@@ -368,9 +368,29 @@ def test_completion_replicate_llama_2():
             pytest.fail(f"Error occurred: {e}")
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
-
 # test_completion_replicate_llama_2()
 
+def test_completion_replicate_vicuna():
+    model_name = "replicate/vicuna-13b:6282abe6a492de4145d7bb601023762212f9ddbbe78278bd6771c8b3b2f2a13b"
+    try:
+        response = completion(
+            model=model_name,
+            messages=messages,
+            custom_llm_provider="replicate",
+            temperature=0.1,
+            max_tokens=20,
+        )
+        print(response)
+        # Add any assertions here to check the response
+        response_str = response["choices"][0]["message"]["content"]
+        print(response_str)
+        if type(response_str) != str:
+            pytest.fail(f"Error occurred: {e}")
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+# test_completion_replicate_vicuna()
+
 def test_completion_replicate_llama_stream():
     model_name = "replicate/llama-2-70b-chat:2c1608e18606fad2812020dc541930f2d0495ce32eee50074220b87300bc16e1"
     try:
diff --git a/litellm/utils.py b/litellm/utils.py
index 3c59f32c6..8a7d908e8 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -706,7 +706,10 @@ def get_optional_params(  # use the openai defaults
             optional_params["stream"] = stream
             return optional_params
         if max_tokens != float("inf"):
-            optional_params["max_new_tokens"] = max_tokens
+            if "vicuna" in model:
+                optional_params["max_length"] = max_tokens
+            else:
+                optional_params["max_new_tokens"] = max_tokens
         if temperature != 1:
             optional_params["temperature"] = temperature
         if top_p != 1:
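
Note (not part of the patch): the litellm/utils.py hunk above routes the user-facing max_tokens value to Replicate's max_length input when the model name contains "vicuna", and to max_new_tokens otherwise. Below is a minimal standalone sketch of that mapping, using a simplified hypothetical helper (map_replicate_max_tokens is not a litellm function; the real logic lives inside get_optional_params, which takes many more parameters).

    # Illustrative sketch only, not the litellm implementation.
    def map_replicate_max_tokens(model: str, max_tokens: float) -> dict:
        optional_params = {}
        if max_tokens != float("inf"):
            if "vicuna" in model:
                # the patch maps vicuna models to Replicate's `max_length` input
                optional_params["max_length"] = max_tokens
            else:
                # other Replicate models (e.g. llama-2) keep `max_new_tokens`
                optional_params["max_new_tokens"] = max_tokens
        return optional_params

    # Example:
    # map_replicate_max_tokens("replicate/vicuna-13b:6282abe6...", 20)   -> {"max_length": 20}
    # map_replicate_max_tokens("replicate/llama-2-70b-chat:2796ee...", 20) -> {"max_new_tokens": 20}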