forked from phoenix-oss/llama-stack-mirror
log probs - mark pytests as xfail for unsupported providers + add support for together (#883)
# What does this PR do? 1) As per @mattf's suggestion, we want to mark the pytest as xfail for providers that do not support the functionality. In this diff, we xfail the logProbs inference tests for providers who does not support log probs. ( log probs is only supported by together, fireworks and vllm) 2) Added logProbs support for together according to their developer [doc](https://docs.together.ai/docs/logprobs). ## Test Plan 1) Together & Fireworks ``` export LLAMA_STACK_CONFIG=/Users/sxyi/llama-stack/llama_stack/templates/together/run.yaml /opt/miniconda3/envs/stack/bin/pytest -s -v /Users/sxyi/llama-stack/tests/client-sdk/inference/test_inference.py ``` ``` tests/client-sdk/inference/test_inference.py::test_text_completion_streaming[meta-llama/Llama-3.1-8B-Instruct] PASSED tests/client-sdk/inference/test_inference.py::test_completion_log_probs_non_streaming[meta-llama/Llama-3.1-8B-Instruct] PASSED tests/client-sdk/inference/test_inference.py::test_completion_log_probs_streaming[meta-llama/Llama-3.1-8B-Instruct] PASSED tests/client-sdk/inference/test_inference.py::test_text_completion_structured_output[meta-llama/Llama-3.1-8B-Instruct] PASSED tests/client-sdk/inference/test_inference.py::test_text_chat_completion_non_streaming[meta-llama/Llama-3.1-8B-Instruct-What are the names of planets in our solar system?-Earth] PASSED tests/client-sdk/inference/test_inference.py::test_text_chat_completion_non_streaming[meta-llama/Llama-3.1-8B-Instruct-What are the names of the planets that have rings around them?-Saturn] PASSED tests/client-sdk/inference/test_inference.py::test_text_chat_completion_streaming[meta-llama/Llama-3.1-8B-Instruct-What's the name of the Sun in latin?-Sol] PASSED tests/client-sdk/inference/test_inference.py::test_text_chat_completion_streaming[meta-llama/Llama-3.1-8B-Instruct-What is the name of the US captial?-Washington] PASSED tests/client-sdk/inference/test_inference.py::test_text_chat_completion_with_tool_calling_and_non_streaming[meta-llama/Llama-3.1-8B-Instruct] PASSED tests/client-sdk/inference/test_inference.py::test_text_chat_completion_with_tool_calling_and_streaming[meta-llama/Llama-3.1-8B-Instruct] PASSED tests/client-sdk/inference/test_inference.py::test_text_chat_completion_structured_output[meta-llama/Llama-3.1-8B-Instruct] PASSED tests/client-sdk/inference/test_inference.py::test_image_chat_completion_non_streaming[meta-llama/Llama-3.2-11B-Vision-Instruct] PASSED tests/client-sdk/inference/test_inference.py::test_image_chat_completion_streaming[meta-llama/Llama-3.2-11B-Vision-Instruct] PASSED tests/client-sdk/inference/test_inference.py::test_image_chat_completion_base64_url[meta-llama/Llama-3.2-11B-Vision-Instruct] PASSED ========================================================================================== 15 passed, 2 warnings in 19.46s =========================================================================================== ``` ``` export LLAMA_STACK_CONFIG=/Users/sxyi/llama-stack/llama_stack/templates/fireworks/run.yaml /opt/miniconda3/envs/stack/bin/pytest -s -v /Users/sxyi/llama-stack/tests/client-sdk/inference/test_inference.py ``` All tests passed 2) Ollama - LogProbs tests are marked as xfailed. ``` tests/client-sdk/inference/test_inference.py::test_completion_log_probs_non_streaming[meta-llama/Llama-3.1-8B-Instruct] XFAIL (remote::ollama doesn't support log probs yet) tests/client-sdk/inference/test_inference.py::test_completion_log_probs_streaming[meta-llama/Llama-3.1-8B-Instruct] XFAIL (remote::ollama doesn't support log probs yet) ``` ## Sources Please link relevant resources if necessary. ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Ran pre-commit to handle lint / formatting issues. - [ ] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section? - [ ] Updated relevant documentation. - [ ] Wrote necessary unit or integration tests.
This commit is contained in:
parent
6f9023d948
commit
836f47a82d
3 changed files with 68 additions and 14 deletions
|
@ -161,7 +161,10 @@ class TogetherInferenceAdapter(
|
|||
yield chunk
|
||||
|
||||
def _build_options(
|
||||
self, sampling_params: Optional[SamplingParams], fmt: ResponseFormat
|
||||
self,
|
||||
sampling_params: Optional[SamplingParams],
|
||||
logprobs: Optional[LogProbConfig],
|
||||
fmt: ResponseFormat,
|
||||
) -> dict:
|
||||
options = get_sampling_options(sampling_params)
|
||||
if fmt:
|
||||
|
@ -175,6 +178,13 @@ class TogetherInferenceAdapter(
|
|||
else:
|
||||
raise ValueError(f"Unknown response format {fmt.type}")
|
||||
|
||||
if logprobs and logprobs.top_k:
|
||||
if logprobs.top_k != 1:
|
||||
raise ValueError(
|
||||
f"Unsupported value: Together only supports logprobs top_k=1. {logprobs.top_k} was provided",
|
||||
)
|
||||
options["logprobs"] = 1
|
||||
|
||||
return options
|
||||
|
||||
async def chat_completion(
|
||||
|
@ -263,7 +273,9 @@ class TogetherInferenceAdapter(
|
|||
"model": request.model,
|
||||
**input_dict,
|
||||
"stream": request.stream,
|
||||
**self._build_options(request.sampling_params, request.response_format),
|
||||
**self._build_options(
|
||||
request.sampling_params, request.logprobs, request.response_format
|
||||
),
|
||||
}
|
||||
|
||||
async def embeddings(
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import AsyncGenerator, Dict, List, Optional
|
||||
from typing import AsyncGenerator, Dict, List, Optional, Union
|
||||
|
||||
from llama_models.datatypes import (
|
||||
GreedySamplingStrategy,
|
||||
|
@ -121,7 +121,31 @@ def convert_openai_completion_logprobs(
|
|||
) -> Optional[List[TokenLogProbs]]:
|
||||
if not logprobs:
|
||||
return None
|
||||
return [TokenLogProbs(logprobs_by_token=x) for x in logprobs.top_logprobs]
|
||||
if hasattr(logprobs, "top_logprobs"):
|
||||
return [TokenLogProbs(logprobs_by_token=x) for x in logprobs.top_logprobs]
|
||||
|
||||
# Together supports logprobs with top_k=1 only. This means for each token position,
|
||||
# they return only the logprobs for the selected token (vs. the top n most likely tokens).
|
||||
# Here we construct the response by matching the selected token with the logprobs.
|
||||
if logprobs.tokens and logprobs.token_logprobs:
|
||||
return [
|
||||
TokenLogProbs(logprobs_by_token={token: token_lp})
|
||||
for token, token_lp in zip(logprobs.tokens, logprobs.token_logprobs)
|
||||
]
|
||||
return None
|
||||
|
||||
|
||||
def convert_openai_completion_logprobs_stream(
|
||||
text: str, logprobs: Optional[Union[float, OpenAICompatLogprobs]]
|
||||
):
|
||||
if logprobs is None:
|
||||
return None
|
||||
if isinstance(logprobs, float):
|
||||
# Adapt response from Together CompletionChoicesChunk
|
||||
return [TokenLogProbs(logprobs_by_token={text: logprobs})]
|
||||
if hasattr(logprobs, "top_logprobs"):
|
||||
return [TokenLogProbs(logprobs_by_token=x) for x in logprobs.top_logprobs]
|
||||
return None
|
||||
|
||||
|
||||
def process_completion_response(
|
||||
|
@ -188,7 +212,7 @@ async def process_completion_stream_response(
|
|||
yield CompletionResponseStreamChunk(
|
||||
delta=text,
|
||||
stop_reason=stop_reason,
|
||||
logprobs=convert_openai_completion_logprobs(choice.logprobs),
|
||||
logprobs=convert_openai_completion_logprobs_stream(text, choice.logprobs),
|
||||
)
|
||||
if finish_reason:
|
||||
if finish_reason in ["stop", "eos", "eos_token"]:
|
||||
|
|
|
@ -16,6 +16,14 @@ PROVIDER_TOOL_PROMPT_FORMAT = {
|
|||
"remote::fireworks": "json",
|
||||
}
|
||||
|
||||
PROVIDER_LOGPROBS_TOP_K = set(
|
||||
{
|
||||
"remote::together",
|
||||
"remote::fireworks",
|
||||
# "remote:vllm"
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def provider_tool_format(inference_provider_type):
|
||||
|
@ -83,8 +91,12 @@ def test_text_completion_streaming(llama_stack_client, text_model_id):
|
|||
assert "blue" in "".join(streamed_content).lower().strip()
|
||||
|
||||
|
||||
@pytest.mark.skip("Most inference providers don't support log probs yet")
|
||||
def test_completion_log_probs_non_streaming(llama_stack_client, text_model_id):
|
||||
def test_completion_log_probs_non_streaming(
|
||||
llama_stack_client, text_model_id, inference_provider_type
|
||||
):
|
||||
if inference_provider_type not in PROVIDER_LOGPROBS_TOP_K:
|
||||
pytest.xfail(f"{inference_provider_type} doesn't support log probs yet")
|
||||
|
||||
response = llama_stack_client.inference.completion(
|
||||
content="Complete the sentence: Micheael Jordan is born in ",
|
||||
stream=False,
|
||||
|
@ -93,16 +105,22 @@ def test_completion_log_probs_non_streaming(llama_stack_client, text_model_id):
|
|||
"max_tokens": 5,
|
||||
},
|
||||
logprobs={
|
||||
"top_k": 3,
|
||||
"top_k": 1,
|
||||
},
|
||||
)
|
||||
assert response.logprobs, "Logprobs should not be empty"
|
||||
assert 1 <= len(response.logprobs) <= 5
|
||||
assert all(len(logprob.logprobs_by_token) == 3 for logprob in response.logprobs)
|
||||
assert (
|
||||
1 <= len(response.logprobs) <= 5
|
||||
) # each token has 1 logprob and here max_tokens=5
|
||||
assert all(len(logprob.logprobs_by_token) == 1 for logprob in response.logprobs)
|
||||
|
||||
|
||||
@pytest.mark.skip("Most inference providers don't support log probs yet")
|
||||
def test_completion_log_probs_streaming(llama_stack_client, text_model_id):
|
||||
def test_completion_log_probs_streaming(
|
||||
llama_stack_client, text_model_id, inference_provider_type
|
||||
):
|
||||
if inference_provider_type not in PROVIDER_LOGPROBS_TOP_K:
|
||||
pytest.xfail(f"{inference_provider_type} doesn't support log probs yet")
|
||||
|
||||
response = llama_stack_client.inference.completion(
|
||||
content="Complete the sentence: Micheael Jordan is born in ",
|
||||
stream=True,
|
||||
|
@ -111,7 +129,7 @@ def test_completion_log_probs_streaming(llama_stack_client, text_model_id):
|
|||
"max_tokens": 5,
|
||||
},
|
||||
logprobs={
|
||||
"top_k": 3,
|
||||
"top_k": 1,
|
||||
},
|
||||
)
|
||||
streamed_content = [chunk for chunk in response]
|
||||
|
@ -119,7 +137,7 @@ def test_completion_log_probs_streaming(llama_stack_client, text_model_id):
|
|||
if chunk.delta: # if there's a token, we expect logprobs
|
||||
assert chunk.logprobs, "Logprobs should not be empty"
|
||||
assert all(
|
||||
len(logprob.logprobs_by_token) == 3 for logprob in chunk.logprobs
|
||||
len(logprob.logprobs_by_token) == 1 for logprob in chunk.logprobs
|
||||
)
|
||||
else: # no token, no logprobs
|
||||
assert not chunk.logprobs, "Logprobs should be empty"
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue