log probs - mark pytests as xfail for unsupported providers + add support for together (#883)

# What does this PR do?

1) As per @mattf's suggestion, we want to mark the pytests as xfail for
providers that do not support the functionality. In this diff, we xfail
the logProbs inference tests for providers that do not support log
probs (see the sketch below).
(Log probs are currently only supported by together, fireworks, and vllm.)

2) Added logProbs support for together according to their developer
[doc](https://docs.together.ai/docs/logprobs).
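
For context, here is a minimal sketch (not the exact test code; see the full diff below) of the pattern the updated tests follow: xfail when the provider isn't in the supported set, otherwise request log probs with `top_k=1`, the only value Together accepts. The fixtures (`llama_stack_client`, `text_model_id`, `inference_provider_type`) come from the existing client-SDK test suite; the provider set mirrors `PROVIDER_LOGPROBS_TOP_K` in the diff, and the `model_id=` keyword is an assumption since the model argument is elided in the hunks shown.

```python
import pytest

# Providers that currently return log probs (with top_k=1); mirrors PROVIDER_LOGPROBS_TOP_K.
LOGPROBS_PROVIDERS = {"remote::together", "remote::fireworks"}


def test_completion_log_probs(llama_stack_client, text_model_id, inference_provider_type):
    # xfail (rather than skip) so the test flips to passing automatically
    # once a provider gains log-probs support.
    if inference_provider_type not in LOGPROBS_PROVIDERS:
        pytest.xfail(f"{inference_provider_type} doesn't support log probs yet")

    response = llama_stack_client.inference.completion(
        model_id=text_model_id,  # assumed keyword; the model argument is elided in the diff
        content="Complete the sentence: Michael Jordan was born in ",
        stream=False,
        sampling_params={"max_tokens": 5},
        logprobs={"top_k": 1},  # Together only accepts top_k=1
    )
    # One entry per generated token, each mapping that single token to its logprob.
    assert response.logprobs, "Logprobs should not be empty"
    assert all(len(lp.logprobs_by_token) == 1 for lp in response.logprobs)
```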

## Test Plan
1) Together & Fireworks
```
export LLAMA_STACK_CONFIG=/Users/sxyi/llama-stack/llama_stack/templates/together/run.yaml  
/opt/miniconda3/envs/stack/bin/pytest -s -v /Users/sxyi/llama-stack/tests/client-sdk/inference/test_inference.py
```
```
tests/client-sdk/inference/test_inference.py::test_text_completion_streaming[meta-llama/Llama-3.1-8B-Instruct] PASSED
tests/client-sdk/inference/test_inference.py::test_completion_log_probs_non_streaming[meta-llama/Llama-3.1-8B-Instruct] PASSED
tests/client-sdk/inference/test_inference.py::test_completion_log_probs_streaming[meta-llama/Llama-3.1-8B-Instruct] PASSED
tests/client-sdk/inference/test_inference.py::test_text_completion_structured_output[meta-llama/Llama-3.1-8B-Instruct] PASSED
tests/client-sdk/inference/test_inference.py::test_text_chat_completion_non_streaming[meta-llama/Llama-3.1-8B-Instruct-What are the names of planets in our solar system?-Earth] PASSED
tests/client-sdk/inference/test_inference.py::test_text_chat_completion_non_streaming[meta-llama/Llama-3.1-8B-Instruct-What are the names of the planets that have rings around them?-Saturn] PASSED
tests/client-sdk/inference/test_inference.py::test_text_chat_completion_streaming[meta-llama/Llama-3.1-8B-Instruct-What's the name of the Sun in latin?-Sol] PASSED
tests/client-sdk/inference/test_inference.py::test_text_chat_completion_streaming[meta-llama/Llama-3.1-8B-Instruct-What is the name of the US captial?-Washington] PASSED
tests/client-sdk/inference/test_inference.py::test_text_chat_completion_with_tool_calling_and_non_streaming[meta-llama/Llama-3.1-8B-Instruct] PASSED
tests/client-sdk/inference/test_inference.py::test_text_chat_completion_with_tool_calling_and_streaming[meta-llama/Llama-3.1-8B-Instruct] PASSED
tests/client-sdk/inference/test_inference.py::test_text_chat_completion_structured_output[meta-llama/Llama-3.1-8B-Instruct] PASSED
tests/client-sdk/inference/test_inference.py::test_image_chat_completion_non_streaming[meta-llama/Llama-3.2-11B-Vision-Instruct] PASSED
tests/client-sdk/inference/test_inference.py::test_image_chat_completion_streaming[meta-llama/Llama-3.2-11B-Vision-Instruct] PASSED
tests/client-sdk/inference/test_inference.py::test_image_chat_completion_base64_url[meta-llama/Llama-3.2-11B-Vision-Instruct] PASSED

========================================================================================== 15 passed, 2 warnings in 19.46s ===========================================================================================
```

```
export LLAMA_STACK_CONFIG=/Users/sxyi/llama-stack/llama_stack/templates/fireworks/run.yaml   
/opt/miniconda3/envs/stack/bin/pytest -s -v /Users/sxyi/llama-stack/tests/client-sdk/inference/test_inference.py
```
All tests passed.

2) Ollama - LogProbs tests are marked as xfail:
```
tests/client-sdk/inference/test_inference.py::test_completion_log_probs_non_streaming[meta-llama/Llama-3.1-8B-Instruct] XFAIL (remote::ollama doesn't support log probs yet)
tests/client-sdk/inference/test_inference.py::test_completion_log_probs_streaming[meta-llama/Llama-3.1-8B-Instruct] XFAIL (remote::ollama doesn't support log probs yet)
```
## Sources

Please link relevant resources if necessary.


## Before submitting

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Ran pre-commit to handle lint / formatting issues.
- [ ] Read the [contributor
guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md),
      Pull Request section?
- [ ] Updated relevant documentation.
- [ ] Wrote necessary unit or integration tests.

Commit 836f47a82d (parent 6f9023d948), authored by Sixian Yi on 2025-01-29 23:41:25 -08:00 and committed by GitHub. 3 changed files with 68 additions and 14 deletions.


Changes to the Together inference adapter (`TogetherInferenceAdapter`):

```diff
@@ -161,7 +161,10 @@ class TogetherInferenceAdapter(
             yield chunk
 
     def _build_options(
-        self, sampling_params: Optional[SamplingParams], fmt: ResponseFormat
+        self,
+        sampling_params: Optional[SamplingParams],
+        logprobs: Optional[LogProbConfig],
+        fmt: ResponseFormat,
     ) -> dict:
         options = get_sampling_options(sampling_params)
         if fmt:
@@ -175,6 +178,13 @@ class TogetherInferenceAdapter(
         else:
             raise ValueError(f"Unknown response format {fmt.type}")
 
+        if logprobs and logprobs.top_k:
+            if logprobs.top_k != 1:
+                raise ValueError(
+                    f"Unsupported value: Together only supports logprobs top_k=1. {logprobs.top_k} was provided",
+                )
+            options["logprobs"] = 1
+
         return options
 
     async def chat_completion(
@@ -263,7 +273,9 @@ class TogetherInferenceAdapter(
             "model": request.model,
             **input_dict,
             "stream": request.stream,
-            **self._build_options(request.sampling_params, request.response_format),
+            **self._build_options(
+                request.sampling_params, request.logprobs, request.response_format
+            ),
         }
 
     async def embeddings(
```


Changes to the OpenAI-compat inference helpers:

```diff
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import AsyncGenerator, Dict, List, Optional
+from typing import AsyncGenerator, Dict, List, Optional, Union
 
 from llama_models.datatypes import (
     GreedySamplingStrategy,
@@ -121,7 +121,31 @@ def convert_openai_completion_logprobs(
 ) -> Optional[List[TokenLogProbs]]:
     if not logprobs:
         return None
-    return [TokenLogProbs(logprobs_by_token=x) for x in logprobs.top_logprobs]
+    if hasattr(logprobs, "top_logprobs"):
+        return [TokenLogProbs(logprobs_by_token=x) for x in logprobs.top_logprobs]
+
+    # Together supports logprobs with top_k=1 only. This means for each token position,
+    # they return only the logprobs for the selected token (vs. the top n most likely tokens).
+    # Here we construct the response by matching the selected token with the logprobs.
+    if logprobs.tokens and logprobs.token_logprobs:
+        return [
+            TokenLogProbs(logprobs_by_token={token: token_lp})
+            for token, token_lp in zip(logprobs.tokens, logprobs.token_logprobs)
+        ]
+    return None
+
+
+def convert_openai_completion_logprobs_stream(
+    text: str, logprobs: Optional[Union[float, OpenAICompatLogprobs]]
+):
+    if logprobs is None:
+        return None
+    if isinstance(logprobs, float):
+        # Adapt response from Together CompletionChoicesChunk
+        return [TokenLogProbs(logprobs_by_token={text: logprobs})]
+    if hasattr(logprobs, "top_logprobs"):
+        return [TokenLogProbs(logprobs_by_token=x) for x in logprobs.top_logprobs]
+    return None
 
 
 def process_completion_response(
@@ -188,7 +212,7 @@ async def process_completion_stream_response(
         yield CompletionResponseStreamChunk(
             delta=text,
             stop_reason=stop_reason,
-            logprobs=convert_openai_completion_logprobs(choice.logprobs),
+            logprobs=convert_openai_completion_logprobs_stream(text, choice.logprobs),
         )
         if finish_reason:
             if finish_reason in ["stop", "eos", "eos_token"]:
```


Changes to `tests/client-sdk/inference/test_inference.py`:

```diff
@@ -16,6 +16,14 @@ PROVIDER_TOOL_PROMPT_FORMAT = {
     "remote::fireworks": "json",
 }
 
+PROVIDER_LOGPROBS_TOP_K = set(
+    {
+        "remote::together",
+        "remote::fireworks",
+        # "remote:vllm"
+    }
+)
+
 
 @pytest.fixture(scope="session")
 def provider_tool_format(inference_provider_type):
@@ -83,8 +91,12 @@ def test_text_completion_streaming(llama_stack_client, text_model_id):
     assert "blue" in "".join(streamed_content).lower().strip()
 
 
-@pytest.mark.skip("Most inference providers don't support log probs yet")
-def test_completion_log_probs_non_streaming(llama_stack_client, text_model_id):
+def test_completion_log_probs_non_streaming(
+    llama_stack_client, text_model_id, inference_provider_type
+):
+    if inference_provider_type not in PROVIDER_LOGPROBS_TOP_K:
+        pytest.xfail(f"{inference_provider_type} doesn't support log probs yet")
+
     response = llama_stack_client.inference.completion(
         content="Complete the sentence: Micheael Jordan is born in ",
         stream=False,
@@ -93,16 +105,22 @@ def test_completion_log_probs_non_streaming(llama_stack_client, text_model_id):
             "max_tokens": 5,
         },
         logprobs={
-            "top_k": 3,
+            "top_k": 1,
         },
     )
     assert response.logprobs, "Logprobs should not be empty"
-    assert 1 <= len(response.logprobs) <= 5
-    assert all(len(logprob.logprobs_by_token) == 3 for logprob in response.logprobs)
+    assert (
+        1 <= len(response.logprobs) <= 5
+    )  # each token has 1 logprob and here max_tokens=5
+    assert all(len(logprob.logprobs_by_token) == 1 for logprob in response.logprobs)
 
 
-@pytest.mark.skip("Most inference providers don't support log probs yet")
-def test_completion_log_probs_streaming(llama_stack_client, text_model_id):
+def test_completion_log_probs_streaming(
+    llama_stack_client, text_model_id, inference_provider_type
+):
+    if inference_provider_type not in PROVIDER_LOGPROBS_TOP_K:
+        pytest.xfail(f"{inference_provider_type} doesn't support log probs yet")
+
     response = llama_stack_client.inference.completion(
         content="Complete the sentence: Micheael Jordan is born in ",
         stream=True,
@@ -111,7 +129,7 @@ def test_completion_log_probs_streaming(llama_stack_client, text_model_id):
             "max_tokens": 5,
         },
         logprobs={
-            "top_k": 3,
+            "top_k": 1,
         },
     )
     streamed_content = [chunk for chunk in response]
@@ -119,7 +137,7 @@ def test_completion_log_probs_streaming(llama_stack_client, text_model_id):
         if chunk.delta:  # if there's a token, we expect logprobs
             assert chunk.logprobs, "Logprobs should not be empty"
             assert all(
-                len(logprob.logprobs_by_token) == 3 for logprob in chunk.logprobs
+                len(logprob.logprobs_by_token) == 1 for logprob in chunk.logprobs
             )
         else:  # no token, no logprobs
             assert not chunk.logprobs, "Logprobs should be empty"
```