From 667c71f4e759afea4a07742279285f00416755cb Mon Sep 17 00:00:00 2001
From: Sixian Yi
Date: Sun, 26 Jan 2025 22:37:41 -0800
Subject: [PATCH] log probs - mark xfail for unsupported + support for together

---
 .../remote/inference/together/together.py    | 18 ++++++++--
 .../utils/inference/openai_compat.py         | 28 +++++++++++++--
 tests/client-sdk/inference/test_inference.py | 36 ++++++++++++++-----
 3 files changed, 69 insertions(+), 13 deletions(-)

diff --git a/llama_stack/providers/remote/inference/together/together.py b/llama_stack/providers/remote/inference/together/together.py
index 8f679cb56..e2bbb2220 100644
--- a/llama_stack/providers/remote/inference/together/together.py
+++ b/llama_stack/providers/remote/inference/together/together.py
@@ -9,6 +9,7 @@ from typing import AsyncGenerator, List, Optional, Union
 from llama_models.datatypes import CoreModelId
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.tokenizer import Tokenizer
+from termcolor import cprint
 from together import Together
 
 from llama_stack.apis.common.content_types import InterleavedContent
@@ -161,7 +162,10 @@ class TogetherInferenceAdapter(
             yield chunk
 
     def _build_options(
-        self, sampling_params: Optional[SamplingParams], fmt: ResponseFormat
+        self,
+        sampling_params: Optional[SamplingParams],
+        logprobs: Optional[LogProbConfig],
+        fmt: ResponseFormat,
     ) -> dict:
         options = get_sampling_options(sampling_params)
         if fmt:
@@ -175,6 +179,14 @@ class TogetherInferenceAdapter(
             else:
                 raise ValueError(f"Unknown response format {fmt.type}")
 
+        if logprobs and logprobs.top_k:
+            if logprobs.top_k != 1:
+                cprint(
+                    "Together only supports logprobs top_k=1. Overriding.",
+                    "yellow",
+                )
+            options["logprobs"] = 1
+
         return options
 
     async def chat_completion(
@@ -263,7 +275,9 @@ class TogetherInferenceAdapter(
             "model": request.model,
             **input_dict,
             "stream": request.stream,
-            **self._build_options(request.sampling_params, request.response_format),
+            **self._build_options(
+                request.sampling_params, request.logprobs, request.response_format
+            ),
         }
 
     async def embeddings(
diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py
index 6c93f49c0..41411d9f0 100644
--- a/llama_stack/providers/utils/inference/openai_compat.py
+++ b/llama_stack/providers/utils/inference/openai_compat.py
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import AsyncGenerator, Dict, List, Optional
+from typing import AsyncGenerator, Dict, List, Optional, Union
 
 from llama_models.datatypes import (
     GreedySamplingStrategy,
@@ -121,7 +121,29 @@ def convert_openai_completion_logprobs(
 ) -> Optional[List[TokenLogProbs]]:
     if not logprobs:
         return None
-    return [TokenLogProbs(logprobs_by_token=x) for x in logprobs.top_logprobs]
+    if hasattr(logprobs, "top_logprobs"):
+        return [TokenLogProbs(logprobs_by_token=x) for x in logprobs.top_logprobs]
+
+    # Together supports logprobs (top_k=1) but not top_logprobs (top_k>1).
+    if logprobs.tokens and logprobs.token_logprobs:
+        return [
+            TokenLogProbs(logprobs_by_token={token: token_lp})
+            for token, token_lp in zip(logprobs.tokens, logprobs.token_logprobs)
+        ]
+    return None
+
+
+def convert_openai_completion_logprobs_stream(
+    text: str, logprobs: Optional[Union[float, OpenAICompatLogprobs]]
+):
+    if logprobs is None:
+        return None
+    if isinstance(logprobs, float):
+        return [TokenLogProbs(logprobs_by_token={text: logprobs})]
+    if hasattr(logprobs, "top_logprobs"):
+        # Adapt response from Together CompletionChoicesChunk
+        return [TokenLogProbs(logprobs_by_token=x) for x in logprobs.top_logprobs]
+    return None
 
 
 def process_completion_response(
@@ -188,7 +210,7 @@ async def process_completion_stream_response(
         yield CompletionResponseStreamChunk(
             delta=text,
             stop_reason=stop_reason,
-            logprobs=convert_openai_completion_logprobs(choice.logprobs),
+            logprobs=convert_openai_completion_logprobs_stream(text, choice.logprobs),
         )
         if finish_reason:
             if finish_reason in ["stop", "eos", "eos_token"]:
diff --git a/tests/client-sdk/inference/test_inference.py b/tests/client-sdk/inference/test_inference.py
index 8ca11521c..6e76c1339 100644
--- a/tests/client-sdk/inference/test_inference.py
+++ b/tests/client-sdk/inference/test_inference.py
@@ -16,6 +16,12 @@ PROVIDER_TOOL_PROMPT_FORMAT = {
     "remote::fireworks": "json",
 }
 
+PROVIDER_LOGPROBS_TOP_K = {
+    "remote::together": 1,
+    "remote::fireworks": 3,
+    # "remote::vllm"
+}
+
 
 @pytest.fixture(scope="session")
 def provider_tool_format(inference_provider_type):
@@ -83,8 +89,13 @@ def test_text_completion_streaming(llama_stack_client, text_model_id):
     assert "blue" in "".join(streamed_content).lower().strip()
 
 
-@pytest.mark.skip("Most inference providers don't support log probs yet")
-def test_completion_log_probs_non_streaming(llama_stack_client, text_model_id):
+def test_completion_log_probs_non_streaming(
+    llama_stack_client, text_model_id, inference_provider_type
+):
+    if inference_provider_type not in PROVIDER_LOGPROBS_TOP_K:
+        pytest.xfail(f"{inference_provider_type} doesn't support log probs yet")
+
+    logprobs_top_k = PROVIDER_LOGPROBS_TOP_K[inference_provider_type]
     response = llama_stack_client.inference.completion(
         content="Complete the sentence: Micheael Jordan is born in ",
         stream=False,
@@ -93,16 +104,24 @@ def test_completion_log_probs_non_streaming(llama_stack_client, text_model_id):
             "max_tokens": 5,
         },
         logprobs={
-            "top_k": 3,
+            "top_k": logprobs_top_k,
         },
     )
     assert response.logprobs, "Logprobs should not be empty"
     assert 1 <= len(response.logprobs) <= 5
-    assert all(len(logprob.logprobs_by_token) == 3 for logprob in response.logprobs)
+    assert all(
+        len(logprob.logprobs_by_token) == logprobs_top_k
+        for logprob in response.logprobs
+    )
 
 
-@pytest.mark.skip("Most inference providers don't support log probs yet")
-def test_completion_log_probs_streaming(llama_stack_client, text_model_id):
+def test_completion_log_probs_streaming(
+    llama_stack_client, text_model_id, inference_provider_type
+):
+    if inference_provider_type not in PROVIDER_LOGPROBS_TOP_K:
+        pytest.xfail(f"{inference_provider_type} doesn't support log probs yet")
+
+    logprobs_top_k = PROVIDER_LOGPROBS_TOP_K[inference_provider_type]
     response = llama_stack_client.inference.completion(
         content="Complete the sentence: Micheael Jordan is born in ",
         stream=True,
@@ -111,7 +130,7 @@ def test_completion_log_probs_streaming(llama_stack_client, text_model_id):
             "max_tokens": 5,
         },
         logprobs={
-            "top_k": 3,
+            "top_k": logprobs_top_k,
         },
     )
     streamed_content = [chunk for chunk in response]
@@ -119,7 +138,8 @@ def test_completion_log_probs_streaming(llama_stack_client, text_model_id):
         if chunk.delta:  # if there's a token, we expect logprobs
             assert chunk.logprobs, "Logprobs should not be empty"
             assert all(
-                len(logprob.logprobs_by_token) == 3 for logprob in chunk.logprobs
+                len(logprob.logprobs_by_token) == logprobs_top_k
+                for logprob in chunk.logprobs
             )
         else:  # no token, no logprobs
             assert not chunk.logprobs, "Logprobs should be empty"
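
For reference, a minimal usage sketch (not part of the patch) of the logprobs path this change enables against the remote::together provider, mirroring the calls in tests/client-sdk/inference/test_inference.py. The base_url, model_id, and prompt below are placeholders assumed for a local Llama Stack deployment, not values taken from this patch:

    from llama_stack_client import LlamaStackClient

    # Assumed local server address; adjust to your distribution's port.
    client = LlamaStackClient(base_url="http://localhost:8321")

    response = client.inference.completion(
        content="Complete the sentence: the capital of France is ",
        stream=False,
        model_id="meta-llama/Llama-3.1-8B-Instruct",  # assumed model registered with remote::together
        sampling_params={"max_tokens": 5},
        # Together only honors top_k=1; larger values are overridden in _build_options.
        logprobs={"top_k": 1},
    )

    # Each entry maps one generated token to its log probability (top_k=1 for Together).
    for token_logprobs in response.logprobs:
        print(token_logprobs.logprobs_by_token)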