mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-06 10:42:39 +00:00
top_k=1
This commit is contained in:
parent
0052089ab8
commit
3816355e1d
1 changed files with 14 additions and 16 deletions
|
@ -16,11 +16,13 @@ PROVIDER_TOOL_PROMPT_FORMAT = {
|
|||
"remote::fireworks": "json",
|
||||
}
|
||||
|
||||
PROVIDER_LOGPROBS_TOP_K = {
|
||||
"remote::together": 1,
|
||||
"remote::fireworks": 3,
|
||||
# "remote:vllm"
|
||||
}
|
||||
PROVIDER_LOGPROBS_TOP_K = set(
|
||||
{
|
||||
"remote::together",
|
||||
"remote::fireworks",
|
||||
# "remote:vllm"
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
|
@ -95,7 +97,6 @@ def test_completion_log_probs_non_streaming(
|
|||
if inference_provider_type not in PROVIDER_LOGPROBS_TOP_K:
|
||||
pytest.xfail(f"{inference_provider_type} doesn't support log probs yet")
|
||||
|
||||
logprobs_top_k = PROVIDER_LOGPROBS_TOP_K[inference_provider_type]
|
||||
response = llama_stack_client.inference.completion(
|
||||
content="Complete the sentence: Micheael Jordan is born in ",
|
||||
stream=False,
|
||||
|
@ -104,15 +105,14 @@ def test_completion_log_probs_non_streaming(
|
|||
"max_tokens": 5,
|
||||
},
|
||||
logprobs={
|
||||
"top_k": logprobs_top_k,
|
||||
"top_k": 1,
|
||||
},
|
||||
)
|
||||
assert response.logprobs, "Logprobs should not be empty"
|
||||
assert 1 <= len(response.logprobs) <= 5
|
||||
assert all(
|
||||
len(logprob.logprobs_by_token) == logprobs_top_k
|
||||
for logprob in response.logprobs
|
||||
)
|
||||
assert (
|
||||
1 <= len(response.logprobs) <= 5
|
||||
) # each token has 1 logprob and here max_tokens=5
|
||||
assert all(len(logprob.logprobs_by_token) == 1 for logprob in response.logprobs)
|
||||
|
||||
|
||||
def test_completion_log_probs_streaming(
|
||||
|
@ -121,7 +121,6 @@ def test_completion_log_probs_streaming(
|
|||
if inference_provider_type not in PROVIDER_LOGPROBS_TOP_K:
|
||||
pytest.xfail(f"{inference_provider_type} doesn't support log probs yet")
|
||||
|
||||
logprobs_top_k = PROVIDER_LOGPROBS_TOP_K[inference_provider_type]
|
||||
response = llama_stack_client.inference.completion(
|
||||
content="Complete the sentence: Micheael Jordan is born in ",
|
||||
stream=True,
|
||||
|
@ -130,7 +129,7 @@ def test_completion_log_probs_streaming(
|
|||
"max_tokens": 5,
|
||||
},
|
||||
logprobs={
|
||||
"top_k": logprobs_top_k,
|
||||
"top_k": 1,
|
||||
},
|
||||
)
|
||||
streamed_content = [chunk for chunk in response]
|
||||
|
@ -138,8 +137,7 @@ def test_completion_log_probs_streaming(
|
|||
if chunk.delta: # if there's a token, we expect logprobs
|
||||
assert chunk.logprobs, "Logprobs should not be empty"
|
||||
assert all(
|
||||
len(logprob.logprobs_by_token) == logprobs_top_k
|
||||
for logprob in chunk.logprobs
|
||||
len(logprob.logprobs_by_token) == 1 for logprob in chunk.logprobs
|
||||
)
|
||||
else: # no token, no logprobs
|
||||
assert not chunk.logprobs, "Logprobs should be empty"
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue