log probs - mark pytests as xfail for unsupported providers + add support for together (#883)

# What does this PR do?

1) Per @mattf's suggestion, we want to mark pytests as xfail for
providers that do not support the functionality. In this diff, we xfail
the logprobs inference tests for providers that do not support log
probs. (Log probs are currently only supported by Together, Fireworks, and vLLM;
the provider gating is visible in the test-file diff below.)

2) Added logprobs support for the Together provider, following their developer
[doc](https://docs.together.ai/docs/logprobs); a rough sketch of the response
conversion is shown right below.
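
For context, Together returns log probabilities as parallel per-token arrays in the completion response (per the doc linked above), which have to be reshaped into the `logprobs_by_token` mapping the client-sdk tests assert on. The sketch below is illustrative only, not the adapter code from this PR; the Together-side field names (`tokens`, `token_logprobs`) and the helper name are assumptions taken from their docs.

```
# Illustrative sketch, not the adapter code from this PR.
# Assumes Together's completion choice carries parallel `tokens` and
# `token_logprobs` arrays, as described in their logprobs doc.
from typing import Any, Dict, List


def convert_together_logprobs(choice_logprobs: Dict[str, Any]) -> List[Dict[str, Dict[str, float]]]:
    """Reshape Together's parallel arrays into one logprobs_by_token dict per token."""
    tokens = choice_logprobs.get("tokens") or []
    token_logprobs = choice_logprobs.get("token_logprobs") or []
    return [
        {"logprobs_by_token": {token: logprob}}
        for token, logprob in zip(tokens, token_logprobs)
    ]


# With top_k=1 each generated token carries exactly one logprob,
# which is what the updated tests assert.
print(convert_together_logprobs({"tokens": ["Hello", "!"], "token_logprobs": [-0.12, -0.53]}))
```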

## Test Plan
1) Together & Fireworks
```
export LLAMA_STACK_CONFIG=/Users/sxyi/llama-stack/llama_stack/templates/together/run.yaml  
/opt/miniconda3/envs/stack/bin/pytest -s -v /Users/sxyi/llama-stack/tests/client-sdk/inference/test_inference.py
```
```
tests/client-sdk/inference/test_inference.py::test_text_completion_streaming[meta-llama/Llama-3.1-8B-Instruct] PASSED
tests/client-sdk/inference/test_inference.py::test_completion_log_probs_non_streaming[meta-llama/Llama-3.1-8B-Instruct] PASSED
tests/client-sdk/inference/test_inference.py::test_completion_log_probs_streaming[meta-llama/Llama-3.1-8B-Instruct] PASSED
tests/client-sdk/inference/test_inference.py::test_text_completion_structured_output[meta-llama/Llama-3.1-8B-Instruct] PASSED
tests/client-sdk/inference/test_inference.py::test_text_chat_completion_non_streaming[meta-llama/Llama-3.1-8B-Instruct-What are the names of planets in our solar system?-Earth] PASSED
tests/client-sdk/inference/test_inference.py::test_text_chat_completion_non_streaming[meta-llama/Llama-3.1-8B-Instruct-What are the names of the planets that have rings around them?-Saturn] PASSED
tests/client-sdk/inference/test_inference.py::test_text_chat_completion_streaming[meta-llama/Llama-3.1-8B-Instruct-What's the name of the Sun in latin?-Sol] PASSED
tests/client-sdk/inference/test_inference.py::test_text_chat_completion_streaming[meta-llama/Llama-3.1-8B-Instruct-What is the name of the US captial?-Washington] PASSED
tests/client-sdk/inference/test_inference.py::test_text_chat_completion_with_tool_calling_and_non_streaming[meta-llama/Llama-3.1-8B-Instruct] PASSED
tests/client-sdk/inference/test_inference.py::test_text_chat_completion_with_tool_calling_and_streaming[meta-llama/Llama-3.1-8B-Instruct] PASSED
tests/client-sdk/inference/test_inference.py::test_text_chat_completion_structured_output[meta-llama/Llama-3.1-8B-Instruct] PASSED
tests/client-sdk/inference/test_inference.py::test_image_chat_completion_non_streaming[meta-llama/Llama-3.2-11B-Vision-Instruct] PASSED
tests/client-sdk/inference/test_inference.py::test_image_chat_completion_streaming[meta-llama/Llama-3.2-11B-Vision-Instruct] PASSED
tests/client-sdk/inference/test_inference.py::test_image_chat_completion_base64_url[meta-llama/Llama-3.2-11B-Vision-Instruct] PASSED

========================================================================================== 15 passed, 2 warnings in 19.46s ===========================================================================================
```

```
export LLAMA_STACK_CONFIG=/Users/sxyi/llama-stack/llama_stack/templates/fireworks/run.yaml   
/opt/miniconda3/envs/stack/bin/pytest -s -v /Users/sxyi/llama-stack/tests/client-sdk/inference/test_inference.py
```
All tests passed.

2) Ollama - logprobs tests are marked as xfail:
```
tests/client-sdk/inference/test_inference.py::test_completion_log_probs_non_streaming[meta-llama/Llama-3.1-8B-Instruct] XFAIL (remote::ollama doesn't support log probs yet)
tests/client-sdk/inference/test_inference.py::test_completion_log_probs_streaming[meta-llama/Llama-3.1-8B-Instruct] XFAIL (remote::ollama doesn't support log probs yet)
```
## Sources

Please link relevant resources if necessary.


## Before submitting

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Ran pre-commit to handle lint / formatting issues.
- [ ] Read the [contributor
guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md),
      Pull Request section?
- [ ] Updated relevant documentation.
- [ ] Wrote necessary unit or integration tests.
Sixian Yi 2025-01-29 23:41:25 -08:00 committed by GitHub
parent 6f9023d948
commit 836f47a82d
3 changed files with 68 additions and 14 deletions

tests/client-sdk/inference/test_inference.py

```
@@ -16,6 +16,14 @@ PROVIDER_TOOL_PROMPT_FORMAT = {
     "remote::fireworks": "json",
 }
 
+PROVIDER_LOGPROBS_TOP_K = set(
+    {
+        "remote::together",
+        "remote::fireworks",
+        # "remote:vllm"
+    }
+)
+
 
 @pytest.fixture(scope="session")
 def provider_tool_format(inference_provider_type):
@@ -83,8 +91,12 @@ def test_text_completion_streaming(llama_stack_client, text_model_id):
     assert "blue" in "".join(streamed_content).lower().strip()
 
 
-@pytest.mark.skip("Most inference providers don't support log probs yet")
-def test_completion_log_probs_non_streaming(llama_stack_client, text_model_id):
+def test_completion_log_probs_non_streaming(
+    llama_stack_client, text_model_id, inference_provider_type
+):
+    if inference_provider_type not in PROVIDER_LOGPROBS_TOP_K:
+        pytest.xfail(f"{inference_provider_type} doesn't support log probs yet")
+
     response = llama_stack_client.inference.completion(
         content="Complete the sentence: Micheael Jordan is born in ",
         stream=False,
@@ -93,16 +105,22 @@ def test_completion_log_probs_non_streaming(llama_stack_client, text_model_id):
             "max_tokens": 5,
         },
         logprobs={
-            "top_k": 3,
+            "top_k": 1,
         },
     )
     assert response.logprobs, "Logprobs should not be empty"
-    assert 1 <= len(response.logprobs) <= 5
-    assert all(len(logprob.logprobs_by_token) == 3 for logprob in response.logprobs)
+    assert (
+        1 <= len(response.logprobs) <= 5
+    )  # each token has 1 logprob and here max_tokens=5
+    assert all(len(logprob.logprobs_by_token) == 1 for logprob in response.logprobs)
 
 
-@pytest.mark.skip("Most inference providers don't support log probs yet")
-def test_completion_log_probs_streaming(llama_stack_client, text_model_id):
+def test_completion_log_probs_streaming(
+    llama_stack_client, text_model_id, inference_provider_type
+):
+    if inference_provider_type not in PROVIDER_LOGPROBS_TOP_K:
+        pytest.xfail(f"{inference_provider_type} doesn't support log probs yet")
+
     response = llama_stack_client.inference.completion(
         content="Complete the sentence: Micheael Jordan is born in ",
         stream=True,
@@ -111,7 +129,7 @@ def test_completion_log_probs_streaming(llama_stack_client, text_model_id):
             "max_tokens": 5,
         },
         logprobs={
-            "top_k": 3,
+            "top_k": 1,
         },
     )
     streamed_content = [chunk for chunk in response]
@@ -119,7 +137,7 @@ def test_completion_log_probs_streaming(llama_stack_client, text_model_id):
         if chunk.delta:  # if there's a token, we expect logprobs
             assert chunk.logprobs, "Logprobs should not be empty"
             assert all(
-                len(logprob.logprobs_by_token) == 3 for logprob in chunk.logprobs
+                len(logprob.logprobs_by_token) == 1 for logprob in chunk.logprobs
             )
         else:  # no token, no logprobs
             assert not chunk.logprobs, "Logprobs should be empty"
```
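
For anyone exercising the new path end to end, here is a minimal client-side call mirroring the tests above. It is a sketch only: it assumes a stack server already running from the together template, and the base URL/port should match your `run.yaml`.

```
# Minimal sketch mirroring the updated tests above; adjust base_url to
# wherever your llama-stack server (started from the together template) runs.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")

response = client.inference.completion(
    content="Complete the sentence: the capital of France is ",
    stream=False,
    model_id="meta-llama/Llama-3.1-8B-Instruct",
    sampling_params={"max_tokens": 5},
    logprobs={"top_k": 1},
)
for token_logprob in response.logprobs:
    # one {token: logprob} entry per generated token when top_k=1
    print(token_logprob.logprobs_by_token)
```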