This commit is contained in:
Sixian Yi 2025-01-28 14:18:18 -08:00
parent 0052089ab8
commit 3816355e1d

View file

@ -16,11 +16,13 @@ PROVIDER_TOOL_PROMPT_FORMAT = {
"remote::fireworks": "json",
}
PROVIDER_LOGPROBS_TOP_K = {
"remote::together": 1,
"remote::fireworks": 3,
# "remote:vllm"
}
PROVIDER_LOGPROBS_TOP_K = set(
{
"remote::together",
"remote::fireworks",
# "remote:vllm"
}
)
@pytest.fixture(scope="session")
@ -95,7 +97,6 @@ def test_completion_log_probs_non_streaming(
if inference_provider_type not in PROVIDER_LOGPROBS_TOP_K:
pytest.xfail(f"{inference_provider_type} doesn't support log probs yet")
logprobs_top_k = PROVIDER_LOGPROBS_TOP_K[inference_provider_type]
response = llama_stack_client.inference.completion(
content="Complete the sentence: Micheael Jordan is born in ",
stream=False,
@ -104,15 +105,14 @@ def test_completion_log_probs_non_streaming(
"max_tokens": 5,
},
logprobs={
"top_k": logprobs_top_k,
"top_k": 1,
},
)
assert response.logprobs, "Logprobs should not be empty"
assert 1 <= len(response.logprobs) <= 5
assert all(
len(logprob.logprobs_by_token) == logprobs_top_k
for logprob in response.logprobs
)
assert (
1 <= len(response.logprobs) <= 5
) # each token has 1 logprob and here max_tokens=5
assert all(len(logprob.logprobs_by_token) == 1 for logprob in response.logprobs)
def test_completion_log_probs_streaming(
@ -121,7 +121,6 @@ def test_completion_log_probs_streaming(
if inference_provider_type not in PROVIDER_LOGPROBS_TOP_K:
pytest.xfail(f"{inference_provider_type} doesn't support log probs yet")
logprobs_top_k = PROVIDER_LOGPROBS_TOP_K[inference_provider_type]
response = llama_stack_client.inference.completion(
content="Complete the sentence: Micheael Jordan is born in ",
stream=True,
@ -130,7 +129,7 @@ def test_completion_log_probs_streaming(
"max_tokens": 5,
},
logprobs={
"top_k": logprobs_top_k,
"top_k": 1,
},
)
streamed_content = [chunk for chunk in response]
@ -138,8 +137,7 @@ def test_completion_log_probs_streaming(
if chunk.delta: # if there's a token, we expect logprobs
assert chunk.logprobs, "Logprobs should not be empty"
assert all(
len(logprob.logprobs_by_token) == logprobs_top_k
for logprob in chunk.logprobs
len(logprob.logprobs_by_token) == 1 for logprob in chunk.logprobs
)
else: # no token, no logprobs
assert not chunk.logprobs, "Logprobs should be empty"