forked from phoenix-oss/llama-stack-mirror
add test for completion logprobs (#532)
# What does this PR do? adds a test for the completion api's logprobs parameter tbd which providers pass this test ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [x] Ran pre-commit to handle lint / formatting issues. - [x] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section? - [ ] Updated relevant documentation. - [x] Wrote necessary unit or integration tests.
This commit is contained in:
parent
96e158eaac
commit
2a9b13dd52
1 changed files with 55 additions and 0 deletions
|
@ -128,6 +128,61 @@ class TestInference:
|
|||
last = chunks[-1]
|
||||
assert last.stop_reason == StopReason.out_of_tokens
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_completion_logprobs(self, inference_model, inference_stack):
|
||||
inference_impl, _ = inference_stack
|
||||
|
||||
provider = inference_impl.routing_table.get_provider_impl(inference_model)
|
||||
if provider.__provider_spec__.provider_type not in (
|
||||
# "remote::nvidia", -- provider doesn't provide all logprobs
|
||||
):
|
||||
pytest.skip("Other inference providers don't support completion() yet")
|
||||
|
||||
response = await inference_impl.completion(
|
||||
content="Micheael Jordan is born in ",
|
||||
stream=False,
|
||||
model_id=inference_model,
|
||||
sampling_params=SamplingParams(
|
||||
max_tokens=5,
|
||||
),
|
||||
logprobs=LogProbConfig(
|
||||
top_k=3,
|
||||
),
|
||||
)
|
||||
|
||||
assert isinstance(response, CompletionResponse)
|
||||
assert 1 <= len(response.logprobs) <= 5
|
||||
assert response.logprobs, "Logprobs should not be empty"
|
||||
assert all(len(logprob.logprobs_by_token) == 3 for logprob in response.logprobs)
|
||||
|
||||
chunks = [
|
||||
r
|
||||
async for r in await inference_impl.completion(
|
||||
content="Roses are red,",
|
||||
stream=True,
|
||||
model_id=inference_model,
|
||||
sampling_params=SamplingParams(
|
||||
max_tokens=5,
|
||||
),
|
||||
logprobs=LogProbConfig(
|
||||
top_k=3,
|
||||
),
|
||||
)
|
||||
]
|
||||
|
||||
assert all(isinstance(chunk, CompletionResponseStreamChunk) for chunk in chunks)
|
||||
assert (
|
||||
1 <= len(chunks) <= 6
|
||||
) # why 6 and not 5? the response may have an extra closing chunk, e.g. for usage or stop_reason
|
||||
for chunk in chunks:
|
||||
if chunk.delta: # if there's a token, we expect logprobs
|
||||
assert chunk.logprobs, "Logprobs should not be empty"
|
||||
assert all(
|
||||
len(logprob.logprobs_by_token) == 3 for logprob in chunk.logprobs
|
||||
)
|
||||
else: # no token, no logprobs
|
||||
assert not chunk.logprobs, "Logprobs should be empty"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skip("This test is not quite robust")
|
||||
async def test_completion_structured_output(self, inference_model, inference_stack):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue