diff --git a/tests/integration/inference/test_openai_completion.py b/tests/integration/inference/test_openai_completion.py
index 62185e470..bb447b3c1 100644
--- a/tests/integration/inference/test_openai_completion.py
+++ b/tests/integration/inference/test_openai_completion.py
@@ -58,6 +58,15 @@ def skip_if_model_doesnt_support_suffix(client_with_models, model_id):
         pytest.skip(f"Provider {provider.provider_type} doesn't support suffix.")
 
 
+def skip_if_doesnt_support_n(client_with_models, model_id):
+    provider = provider_from_model(client_with_models, model_id)
+    if provider.provider_type in (
+        "remote::sambanova",
+        "remote::ollama",
+    ):
+        pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support n param.")
+
+
 def skip_if_model_doesnt_support_openai_chat_completion(client_with_models, model_id):
     provider = provider_from_model(client_with_models, model_id)
     if provider.provider_type in (
@@ -262,10 +271,7 @@ def test_openai_chat_completion_streaming(compat_client, client_with_models, tex
 )
 def test_openai_chat_completion_streaming_with_n(compat_client, client_with_models, text_model_id, test_case):
     skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id)
-
-    provider = provider_from_model(client_with_models, text_model_id)
-    if provider.provider_type == "remote::ollama":
-        pytest.skip(f"Model {text_model_id} hosted by {provider.provider_type} doesn't support n > 1.")
+    skip_if_doesnt_support_n(client_with_models, text_model_id)
 
     tc = TestCase(test_case)
     question = tc["question"]
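
A minimal sketch of how the new `skip_if_doesnt_support_n` guard composes with the existing skip helpers in another test in this module. This is not part of the diff: the test name, the request shape, and the module-level imports are assumptions modelled on `test_openai_chat_completion_streaming_with_n` and on the fixtures the file already uses.

```python
# Hypothetical non-streaming counterpart; assumes the module's existing imports
# (pytest, TestCase, provider_from_model, the skip_if_* helpers) are in scope.
def test_openai_chat_completion_non_streaming_with_n(compat_client, client_with_models, text_model_id, test_case):
    skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id)
    # New guard from this diff: skips on remote::sambanova and remote::ollama,
    # which do not honor the OpenAI `n` parameter.
    skip_if_doesnt_support_n(client_with_models, text_model_id)

    tc = TestCase(test_case)
    response = compat_client.chat.completions.create(
        model=text_model_id,
        messages=[{"role": "user", "content": tc["question"]}],
        n=2,  # the parameter the guard protects
        stream=False,
    )
    assert len(response.choices) == 2
```

Centralizing the check in one helper means adding another provider that lacks `n` support only requires extending the tuple in `skip_if_doesnt_support_n`, rather than editing each test body.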