log probs - mark pytests as xfail for unsupported providers + add support for together (#883)
# What does this PR do?

1) As per @mattf's suggestion, we want to mark the pytest as xfail for providers that do not support the functionality. In this diff, we xfail the logProbs inference tests for providers that do not support log probs. (Log probs are only supported by together, fireworks, and vllm.)
2) Added logProbs support for together according to their developer [doc](https://docs.together.ai/docs/logprobs).

## Test Plan

1) Together & Fireworks

```
export LLAMA_STACK_CONFIG=/Users/sxyi/llama-stack/llama_stack/templates/together/run.yaml
/opt/miniconda3/envs/stack/bin/pytest -s -v /Users/sxyi/llama-stack/tests/client-sdk/inference/test_inference.py
```

```
tests/client-sdk/inference/test_inference.py::test_text_completion_streaming[meta-llama/Llama-3.1-8B-Instruct] PASSED
tests/client-sdk/inference/test_inference.py::test_completion_log_probs_non_streaming[meta-llama/Llama-3.1-8B-Instruct] PASSED
tests/client-sdk/inference/test_inference.py::test_completion_log_probs_streaming[meta-llama/Llama-3.1-8B-Instruct] PASSED
tests/client-sdk/inference/test_inference.py::test_text_completion_structured_output[meta-llama/Llama-3.1-8B-Instruct] PASSED
tests/client-sdk/inference/test_inference.py::test_text_chat_completion_non_streaming[meta-llama/Llama-3.1-8B-Instruct-What are the names of planets in our solar system?-Earth] PASSED
tests/client-sdk/inference/test_inference.py::test_text_chat_completion_non_streaming[meta-llama/Llama-3.1-8B-Instruct-What are the names of the planets that have rings around them?-Saturn] PASSED
tests/client-sdk/inference/test_inference.py::test_text_chat_completion_streaming[meta-llama/Llama-3.1-8B-Instruct-What's the name of the Sun in latin?-Sol] PASSED
tests/client-sdk/inference/test_inference.py::test_text_chat_completion_streaming[meta-llama/Llama-3.1-8B-Instruct-What is the name of the US captial?-Washington] PASSED
tests/client-sdk/inference/test_inference.py::test_text_chat_completion_with_tool_calling_and_non_streaming[meta-llama/Llama-3.1-8B-Instruct] PASSED
tests/client-sdk/inference/test_inference.py::test_text_chat_completion_with_tool_calling_and_streaming[meta-llama/Llama-3.1-8B-Instruct] PASSED
tests/client-sdk/inference/test_inference.py::test_text_chat_completion_structured_output[meta-llama/Llama-3.1-8B-Instruct] PASSED
tests/client-sdk/inference/test_inference.py::test_image_chat_completion_non_streaming[meta-llama/Llama-3.2-11B-Vision-Instruct] PASSED
tests/client-sdk/inference/test_inference.py::test_image_chat_completion_streaming[meta-llama/Llama-3.2-11B-Vision-Instruct] PASSED
tests/client-sdk/inference/test_inference.py::test_image_chat_completion_base64_url[meta-llama/Llama-3.2-11B-Vision-Instruct] PASSED
========================================================================================== 15 passed, 2 warnings in 19.46s ===========================================================================================
```

```
export LLAMA_STACK_CONFIG=/Users/sxyi/llama-stack/llama_stack/templates/fireworks/run.yaml
/opt/miniconda3/envs/stack/bin/pytest -s -v /Users/sxyi/llama-stack/tests/client-sdk/inference/test_inference.py
```

All tests passed.

2) Ollama - LogProbs tests are marked as xfailed.

```
tests/client-sdk/inference/test_inference.py::test_completion_log_probs_non_streaming[meta-llama/Llama-3.1-8B-Instruct] XFAIL (remote::ollama doesn't support log probs yet)
tests/client-sdk/inference/test_inference.py::test_completion_log_probs_streaming[meta-llama/Llama-3.1-8B-Instruct] XFAIL (remote::ollama doesn't support log probs yet)
```

## Sources

Please link relevant resources if necessary.

## Before submitting

- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [ ] Ran pre-commit to handle lint / formatting issues.
- [ ] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section?
- [ ] Updated relevant documentation.
- [ ] Wrote necessary unit or integration tests.
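For context, the xfail marking follows the stock pytest pattern of calling `pytest.xfail()` at the top of the test when the active provider lacks support. The sketch below is illustrative, not the exact test code from this PR; the fixture names (`llama_stack_client`, `inference_provider_type`, `text_model_id`) and the prompt are assumptions:

```python
import pytest

# Per the PR description: log probs are only supported by together,
# fireworks, and vllm.
PROVIDERS_SUPPORTING_LOG_PROBS = {"remote::together", "remote::fireworks", "remote::vllm"}


def test_completion_log_probs_non_streaming(
    llama_stack_client, inference_provider_type, text_model_id
):
    if inference_provider_type not in PROVIDERS_SUPPORTING_LOG_PROBS:
        # xfail (rather than skip) so the test surfaces as XPASS once a
        # provider gains support and the expectation can be removed.
        pytest.xfail(f"{inference_provider_type} doesn't support log probs yet")

    response = llama_stack_client.inference.completion(
        content="Complete the sentence using one word: Roses are red, violets are ",
        model_id=text_model_id,
        stream=False,
        logprobs={"top_k": 1},  # Together only accepts top_k=1 (see diff below)
    )
    assert response.logprobs, "expected per-token log probabilities in the response"
```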
This commit is contained in: parent `6f9023d948`, commit `836f47a82d`.
3 changed files with 68 additions and 14 deletions.
```diff
@@ -161,7 +161,10 @@ class TogetherInferenceAdapter(
         yield chunk
 
     def _build_options(
-        self, sampling_params: Optional[SamplingParams], fmt: ResponseFormat
+        self,
+        sampling_params: Optional[SamplingParams],
+        logprobs: Optional[LogProbConfig],
+        fmt: ResponseFormat,
     ) -> dict:
         options = get_sampling_options(sampling_params)
         if fmt:
@@ -175,6 +178,13 @@ class TogetherInferenceAdapter(
         else:
             raise ValueError(f"Unknown response format {fmt.type}")
 
+        if logprobs and logprobs.top_k:
+            if logprobs.top_k != 1:
+                raise ValueError(
+                    f"Unsupported value: Together only supports logprobs top_k=1. {logprobs.top_k} was provided",
+                )
+            options["logprobs"] = 1
+
         return options
 
     async def chat_completion(
@@ -263,7 +273,9 @@ class TogetherInferenceAdapter(
             "model": request.model,
             **input_dict,
             "stream": request.stream,
-            **self._build_options(request.sampling_params, request.response_format),
+            **self._build_options(
+                request.sampling_params, request.logprobs, request.response_format
+            ),
         }
 
     async def embeddings(
```
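The new logprobs branch in `_build_options` boils down to the standalone logic below. This is a minimal sketch for readability, not the adapter code itself; the simplified `LogProbConfig` assumes only the `top_k` field that the diff uses:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class LogProbConfig:
    # How many top tokens to return log probabilities for (None = disabled).
    top_k: Optional[int] = None


def logprobs_options(logprobs: Optional[LogProbConfig]) -> dict:
    """Map the stack-level config onto Together's request option.

    Together's API (https://docs.together.ai/docs/logprobs) only accepts
    logprobs=1, so any other top_k is rejected up front rather than being
    silently downgraded.
    """
    options: dict = {}
    if logprobs and logprobs.top_k:
        if logprobs.top_k != 1:
            raise ValueError(
                f"Unsupported value: Together only supports logprobs top_k=1. "
                f"{logprobs.top_k} was provided"
            )
        options["logprobs"] = 1
    return options


# Behavior at the boundaries:
assert logprobs_options(LogProbConfig(top_k=1)) == {"logprobs": 1}
assert logprobs_options(LogProbConfig()) == {}  # top_k unset -> no-op
assert logprobs_options(None) == {}             # logprobs omitted -> no-op
```

Validating eagerly in `_build_options` means an unsupported `top_k` fails the request with a clear error before anything is sent to Together, instead of returning a response with unexpected logprob depth.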