diff --git a/llama_stack/providers/tests/inference/test_text_inference.py b/llama_stack/providers/tests/inference/test_text_inference.py index 741b61c5c..99a62ac08 100644 --- a/llama_stack/providers/tests/inference/test_text_inference.py +++ b/llama_stack/providers/tests/inference/test_text_inference.py @@ -128,6 +128,61 @@ class TestInference: last = chunks[-1] assert last.stop_reason == StopReason.out_of_tokens + @pytest.mark.asyncio + async def test_completion_logprobs(self, inference_model, inference_stack): + inference_impl, _ = inference_stack + + provider = inference_impl.routing_table.get_provider_impl(inference_model) + if provider.__provider_spec__.provider_type not in ( + # "remote::nvidia", -- provider doesn't provide all logprobs + ): + pytest.skip("Other inference providers don't support completion() yet") + + response = await inference_impl.completion( + content="Micheael Jordan is born in ", + stream=False, + model_id=inference_model, + sampling_params=SamplingParams( + max_tokens=5, + ), + logprobs=LogProbConfig( + top_k=3, + ), + ) + + assert isinstance(response, CompletionResponse) + assert 1 <= len(response.logprobs) <= 5 + assert response.logprobs, "Logprobs should not be empty" + assert all(len(logprob.logprobs_by_token) == 3 for logprob in response.logprobs) + + chunks = [ + r + async for r in await inference_impl.completion( + content="Roses are red,", + stream=True, + model_id=inference_model, + sampling_params=SamplingParams( + max_tokens=5, + ), + logprobs=LogProbConfig( + top_k=3, + ), + ) + ] + + assert all(isinstance(chunk, CompletionResponseStreamChunk) for chunk in chunks) + assert ( + 1 <= len(chunks) <= 6 + ) # why 6 and not 5? the response may have an extra closing chunk, e.g. for usage or stop_reason + for chunk in chunks: + if chunk.delta: # if there's a token, we expect logprobs + assert chunk.logprobs, "Logprobs should not be empty" + assert all( + len(logprob.logprobs_by_token) == 3 for logprob in chunk.logprobs + ) + else: # no token, no logprobs + assert not chunk.logprobs, "Logprobs should be empty" + @pytest.mark.asyncio @pytest.mark.skip("This test is not quite robust") async def test_completion_structured_output(self, inference_model, inference_stack):