diff --git a/tests/client-sdk/inference/test_inference.py b/tests/client-sdk/inference/test_inference.py
index 6e76c1339..6dff1be24 100644
--- a/tests/client-sdk/inference/test_inference.py
+++ b/tests/client-sdk/inference/test_inference.py
@@ -16,11 +16,13 @@ PROVIDER_TOOL_PROMPT_FORMAT = {
     "remote::fireworks": "json",
 }
 
-PROVIDER_LOGPROBS_TOP_K = {
-    "remote::together": 1,
-    "remote::fireworks": 3,
-    # "remote:vllm"
-}
+PROVIDER_LOGPROBS_TOP_K = set(
+    {
+        "remote::together",
+        "remote::fireworks",
+        # "remote:vllm"
+    }
+)
 
 
 @pytest.fixture(scope="session")
@@ -95,7 +97,6 @@ def test_completion_log_probs_non_streaming(
     if inference_provider_type not in PROVIDER_LOGPROBS_TOP_K:
         pytest.xfail(f"{inference_provider_type} doesn't support log probs yet")
 
-    logprobs_top_k = PROVIDER_LOGPROBS_TOP_K[inference_provider_type]
     response = llama_stack_client.inference.completion(
         content="Complete the sentence: Micheael Jordan is born in ",
         stream=False,
@@ -104,15 +105,14 @@ def test_completion_log_probs_non_streaming(
             "max_tokens": 5,
         },
         logprobs={
-            "top_k": logprobs_top_k,
+            "top_k": 1,
         },
     )
     assert response.logprobs, "Logprobs should not be empty"
-    assert 1 <= len(response.logprobs) <= 5
-    assert all(
-        len(logprob.logprobs_by_token) == logprobs_top_k
-        for logprob in response.logprobs
-    )
+    assert (
+        1 <= len(response.logprobs) <= 5
+    )  # each token has 1 logprob and here max_tokens=5
+    assert all(len(logprob.logprobs_by_token) == 1 for logprob in response.logprobs)
 
 
 def test_completion_log_probs_streaming(
@@ -121,7 +121,6 @@ def test_completion_log_probs_streaming(
     if inference_provider_type not in PROVIDER_LOGPROBS_TOP_K:
         pytest.xfail(f"{inference_provider_type} doesn't support log probs yet")
 
-    logprobs_top_k = PROVIDER_LOGPROBS_TOP_K[inference_provider_type]
     response = llama_stack_client.inference.completion(
         content="Complete the sentence: Micheael Jordan is born in ",
         stream=True,
@@ -130,7 +129,7 @@ def test_completion_log_probs_streaming(
             "max_tokens": 5,
         },
         logprobs={
-            "top_k": logprobs_top_k,
+            "top_k": 1,
         },
     )
     streamed_content = [chunk for chunk in response]
@@ -138,8 +137,7 @@
         if chunk.delta:  # if there's a token, we expect logprobs
             assert chunk.logprobs, "Logprobs should not be empty"
             assert all(
-                len(logprob.logprobs_by_token) == logprobs_top_k
-                for logprob in chunk.logprobs
+                len(logprob.logprobs_by_token) == 1 for logprob in chunk.logprobs
            )
         else:  # no token, no logprobs
             assert not chunk.logprobs, "Logprobs should be empty"
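
For context: PROVIDER_LOGPROBS_TOP_K is now a plain set consulted only for membership (which providers support logprobs at all), and both tests pin logprobs top_k to 1 instead of reading a per-provider value. Below is a minimal sketch of the call shape the updated tests exercise, runnable against a locally served stack. The base URL and model id are assumed placeholders, not values taken from this diff; the parameter names match the client calls visible in the test file.

    from llama_stack_client import LlamaStackClient

    # Placeholder endpoint and model id -- substitute whatever your stack serves.
    client = LlamaStackClient(base_url="http://localhost:8321")

    response = client.inference.completion(
        model_id="meta-llama/Llama-3.1-8B-Instruct",  # assumed model id
        content="Complete the sentence: Michael Jordan is born in ",
        stream=False,
        sampling_params={"max_tokens": 5},
        logprobs={"top_k": 1},  # the tests now always request one logprob per token
    )

    # With top_k=1, every returned token position carries exactly one entry
    # in logprobs_by_token, which is what both tests now assert.
    for logprob in response.logprobs:
        assert len(logprob.logprobs_by_token) == 1

Since top_k is hardcoded to 1, the per-provider values the old dict carried are no longer needed; set membership alone is enough to gate the pytest.xfail for providers without logprobs support.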