diff --git a/llama_stack/providers/remote/inference/together/together.py b/llama_stack/providers/remote/inference/together/together.py
index 8f679cb56..605b3ce97 100644
--- a/llama_stack/providers/remote/inference/together/together.py
+++ b/llama_stack/providers/remote/inference/together/together.py
@@ -161,7 +161,10 @@ class TogetherInferenceAdapter(
             yield chunk
 
     def _build_options(
-        self, sampling_params: Optional[SamplingParams], fmt: ResponseFormat
+        self,
+        sampling_params: Optional[SamplingParams],
+        logprobs: Optional[LogProbConfig],
+        fmt: ResponseFormat,
     ) -> dict:
         options = get_sampling_options(sampling_params)
         if fmt:
@@ -175,6 +178,13 @@ class TogetherInferenceAdapter(
             else:
                 raise ValueError(f"Unknown response format {fmt.type}")
 
+        if logprobs and logprobs.top_k:
+            if logprobs.top_k != 1:
+                raise ValueError(
+                    f"Unsupported value: Together only supports logprobs top_k=1. {logprobs.top_k} was provided",
+                )
+            options["logprobs"] = 1
+
         return options
 
     async def chat_completion(
@@ -263,7 +273,9 @@ class TogetherInferenceAdapter(
             "model": request.model,
             **input_dict,
             "stream": request.stream,
-            **self._build_options(request.sampling_params, request.response_format),
+            **self._build_options(
+                request.sampling_params, request.logprobs, request.response_format
+            ),
         }
 
     async def embeddings(
diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py
index 6c93f49c0..a0fb23c97 100644
--- a/llama_stack/providers/utils/inference/openai_compat.py
+++ b/llama_stack/providers/utils/inference/openai_compat.py
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import AsyncGenerator, Dict, List, Optional
+from typing import AsyncGenerator, Dict, List, Optional, Union
 
 from llama_models.datatypes import (
     GreedySamplingStrategy,
@@ -121,7 +121,31 @@ def convert_openai_completion_logprobs(
 ) -> Optional[List[TokenLogProbs]]:
     if not logprobs:
         return None
-    return [TokenLogProbs(logprobs_by_token=x) for x in logprobs.top_logprobs]
+    if hasattr(logprobs, "top_logprobs"):
+        return [TokenLogProbs(logprobs_by_token=x) for x in logprobs.top_logprobs]
+
+    # Together supports logprobs with top_k=1 only. This means for each token position,
+    # they return only the logprobs for the selected token (vs. the top n most likely tokens).
+    # Here we construct the response by matching the selected token with the logprobs.
+    if logprobs.tokens and logprobs.token_logprobs:
+        return [
+            TokenLogProbs(logprobs_by_token={token: token_lp})
+            for token, token_lp in zip(logprobs.tokens, logprobs.token_logprobs)
+        ]
+    return None
+
+
+def convert_openai_completion_logprobs_stream(
+    text: str, logprobs: Optional[Union[float, OpenAICompatLogprobs]]
+):
+    if logprobs is None:
+        return None
+    if isinstance(logprobs, float):
+        # Adapt response from Together CompletionChoicesChunk
+        return [TokenLogProbs(logprobs_by_token={text: logprobs})]
+    if hasattr(logprobs, "top_logprobs"):
+        return [TokenLogProbs(logprobs_by_token=x) for x in logprobs.top_logprobs]
+    return None
 
 
 def process_completion_response(
@@ -188,7 +212,7 @@ async def process_completion_stream_response(
             yield CompletionResponseStreamChunk(
                 delta=text,
                 stop_reason=stop_reason,
-                logprobs=convert_openai_completion_logprobs(choice.logprobs),
+                logprobs=convert_openai_completion_logprobs_stream(text, choice.logprobs),
             )
         if finish_reason:
             if finish_reason in ["stop", "eos", "eos_token"]:
diff --git a/tests/client-sdk/inference/test_inference.py b/tests/client-sdk/inference/test_inference.py
index 8ca11521c..6dff1be24 100644
--- a/tests/client-sdk/inference/test_inference.py
+++ b/tests/client-sdk/inference/test_inference.py
@@ -16,6 +16,14 @@ PROVIDER_TOOL_PROMPT_FORMAT = {
     "remote::fireworks": "json",
 }
 
+PROVIDER_LOGPROBS_TOP_K = set(
+    {
+        "remote::together",
+        "remote::fireworks",
+        # "remote:vllm"
+    }
+)
+
 
 @pytest.fixture(scope="session")
 def provider_tool_format(inference_provider_type):
@@ -83,8 +91,12 @@ def test_text_completion_streaming(llama_stack_client, text_model_id):
     assert "blue" in "".join(streamed_content).lower().strip()
 
 
-@pytest.mark.skip("Most inference providers don't support log probs yet")
-def test_completion_log_probs_non_streaming(llama_stack_client, text_model_id):
+def test_completion_log_probs_non_streaming(
+    llama_stack_client, text_model_id, inference_provider_type
+):
+    if inference_provider_type not in PROVIDER_LOGPROBS_TOP_K:
+        pytest.xfail(f"{inference_provider_type} doesn't support log probs yet")
+
     response = llama_stack_client.inference.completion(
         content="Complete the sentence: Micheael Jordan is born in ",
         stream=False,
@@ -93,16 +105,22 @@ def test_completion_log_probs_non_streaming(llama_stack_client, text_model_id):
             "max_tokens": 5,
         },
         logprobs={
-            "top_k": 3,
+            "top_k": 1,
         },
     )
     assert response.logprobs, "Logprobs should not be empty"
-    assert 1 <= len(response.logprobs) <= 5
-    assert all(len(logprob.logprobs_by_token) == 3 for logprob in response.logprobs)
+    assert (
+        1 <= len(response.logprobs) <= 5
+    )  # each token has 1 logprob and here max_tokens=5
+    assert all(len(logprob.logprobs_by_token) == 1 for logprob in response.logprobs)
 
 
-@pytest.mark.skip("Most inference providers don't support log probs yet")
-def test_completion_log_probs_streaming(llama_stack_client, text_model_id):
+def test_completion_log_probs_streaming(
+    llama_stack_client, text_model_id, inference_provider_type
+):
+    if inference_provider_type not in PROVIDER_LOGPROBS_TOP_K:
+        pytest.xfail(f"{inference_provider_type} doesn't support log probs yet")
+
     response = llama_stack_client.inference.completion(
         content="Complete the sentence: Micheael Jordan is born in ",
         stream=True,
@@ -111,7 +129,7 @@ def test_completion_log_probs_streaming(llama_stack_client, text_model_id):
             "max_tokens": 5,
         },
         logprobs={
-            "top_k": 3,
+            "top_k": 1,
         },
     )
     streamed_content = [chunk for chunk in response]
@@ -119,7 +137,7 @@ def test_completion_log_probs_streaming(llama_stack_client, text_model_id):
         if chunk.delta:  # if there's a token, we expect logprobs
             assert chunk.logprobs, "Logprobs should not be empty"
             assert all(
-                len(logprob.logprobs_by_token) == 3 for logprob in chunk.logprobs
+                len(logprob.logprobs_by_token) == 1 for logprob in chunk.logprobs
            )
         else:  # no token, no logprobs
             assert not chunk.logprobs, "Logprobs should be empty"
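
Reviewer note (illustration only, not part of the patch): the sketch below mirrors the conversion logic added to openai_compat.py, using a stand-in FakeTogetherLogprobs type and plain dicts in place of the real OpenAICompatLogprobs and TokenLogProbs classes, with made-up token/logprob values. It shows how a Together top_k=1 response (parallel tokens/token_logprobs lists in the non-streaming case, a bare float per chunk in the streaming case) maps to one single-entry logprobs dict per token position, which is what the updated tests assert.

from dataclasses import dataclass, field
from typing import Dict, List, Optional, Union


@dataclass
class FakeTogetherLogprobs:
    # Stand-in for the non-streaming Together logprobs payload (top_k=1):
    # parallel lists with one selected token and one logprob per position,
    # and no `top_logprobs` attribute.
    tokens: List[str] = field(default_factory=list)
    token_logprobs: List[float] = field(default_factory=list)


def convert_non_streaming(
    logprobs: Optional[FakeTogetherLogprobs],
) -> Optional[List[Dict[str, float]]]:
    # Mirrors the new branch in convert_openai_completion_logprobs:
    # pair each selected token with its logprob.
    if not logprobs:
        return None
    if logprobs.tokens and logprobs.token_logprobs:
        return [
            {token: token_lp}
            for token, token_lp in zip(logprobs.tokens, logprobs.token_logprobs)
        ]
    return None


def convert_streaming(
    text: str, logprobs: Optional[Union[float, FakeTogetherLogprobs]]
) -> Optional[List[Dict[str, float]]]:
    # Mirrors convert_openai_completion_logprobs_stream: Together's streaming
    # chunks carry a bare float, which is attached to the chunk's delta text.
    if logprobs is None:
        return None
    if isinstance(logprobs, float):
        return [{text: logprobs}]
    return None


if __name__ == "__main__":
    # Hypothetical values for a 3-token completion.
    resp = FakeTogetherLogprobs(
        tokens=["19", "63", "."], token_logprobs=[-0.25, -0.8, -1.2]
    )
    print(convert_non_streaming(resp))  # [{'19': -0.25}, {'63': -0.8}, {'.': -1.2}]
    print(convert_streaming("19", -0.25))  # [{'19': -0.25}]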