log probs - mark xfail for unsupported + support for together

This commit is contained in:
Sixian Yi 2025-01-26 22:37:41 -08:00
parent 7de46e40f9
commit 667c71f4e7
3 changed files with 69 additions and 13 deletions

View file

@ -9,6 +9,7 @@ from typing import AsyncGenerator, List, Optional, Union
from llama_models.datatypes import CoreModelId
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer
from termcolor import cprint
from together import Together
from llama_stack.apis.common.content_types import InterleavedContent
@ -161,7 +162,10 @@ class TogetherInferenceAdapter(
yield chunk
def _build_options(
self, sampling_params: Optional[SamplingParams], fmt: ResponseFormat
self,
sampling_params: Optional[SamplingParams],
logprobs: Optional[LogProbConfig],
fmt: ResponseFormat,
) -> dict:
options = get_sampling_options(sampling_params)
if fmt:
@ -175,6 +179,14 @@ class TogetherInferenceAdapter(
else:
raise ValueError(f"Unknown response format {fmt.type}")
if logprobs and logprobs.top_k:
if logprobs.top_k != 1:
cprint(
"Together only supports logprobs top_k=1. Overriding.",
"Yello",
)
options["logprobs"] = 1
return options
async def chat_completion(
@ -263,7 +275,9 @@ class TogetherInferenceAdapter(
"model": request.model,
**input_dict,
"stream": request.stream,
**self._build_options(request.sampling_params, request.response_format),
**self._build_options(
request.sampling_params, request.logprobs, request.response_format
),
}
async def embeddings(

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import AsyncGenerator, Dict, List, Optional
from typing import AsyncGenerator, Dict, List, Optional, Union
from llama_models.datatypes import (
GreedySamplingStrategy,
@ -121,7 +121,29 @@ def convert_openai_completion_logprobs(
) -> Optional[List[TokenLogProbs]]:
if not logprobs:
return None
return [TokenLogProbs(logprobs_by_token=x) for x in logprobs.top_logprobs]
if hasattr(logprobs, "top_logprobs"):
return [TokenLogProbs(logprobs_by_token=x) for x in logprobs.top_logprobs]
# Together supports logprobs (top_k=1) but not top_logprobs (top_k>1).
if logprobs.tokens and logprobs.token_logprobs:
return [
TokenLogProbs(logprobs_by_token={token: token_lp})
for token, token_lp in zip(logprobs.tokens, logprobs.token_logprobs)
]
return None
def convert_openai_completion_logprobs_stream(
text: str, logprobs: Optional[Union[float, OpenAICompatLogprobs]]
):
if logprobs is None:
return None
if isinstance(logprobs, float):
return [TokenLogProbs(logprobs_by_token={text: logprobs})]
if hasattr(logprobs, "top_logprobs"):
# Adapt response from Together CompletionChoicesChunk
return [TokenLogProbs(logprobs_by_token=x) for x in logprobs.top_logprobs]
return None
def process_completion_response(
@ -188,7 +210,7 @@ async def process_completion_stream_response(
yield CompletionResponseStreamChunk(
delta=text,
stop_reason=stop_reason,
logprobs=convert_openai_completion_logprobs(choice.logprobs),
logprobs=convert_openai_completion_logprobs_stream(text, choice.logprobs),
)
if finish_reason:
if finish_reason in ["stop", "eos", "eos_token"]:

View file

@ -16,6 +16,12 @@ PROVIDER_TOOL_PROMPT_FORMAT = {
"remote::fireworks": "json",
}
PROVIDER_LOGPROBS_TOP_K = {
"remote::together": 1,
"remote::fireworks": 3,
# "remote:vllm"
}
@pytest.fixture(scope="session")
def provider_tool_format(inference_provider_type):
@ -83,8 +89,13 @@ def test_text_completion_streaming(llama_stack_client, text_model_id):
assert "blue" in "".join(streamed_content).lower().strip()
@pytest.mark.skip("Most inference providers don't support log probs yet")
def test_completion_log_probs_non_streaming(llama_stack_client, text_model_id):
def test_completion_log_probs_non_streaming(
llama_stack_client, text_model_id, inference_provider_type
):
if inference_provider_type not in PROVIDER_LOGPROBS_TOP_K:
pytest.xfail(f"{inference_provider_type} doesn't support log probs yet")
logprobs_top_k = PROVIDER_LOGPROBS_TOP_K[inference_provider_type]
response = llama_stack_client.inference.completion(
content="Complete the sentence: Micheael Jordan is born in ",
stream=False,
@ -93,16 +104,24 @@ def test_completion_log_probs_non_streaming(llama_stack_client, text_model_id):
"max_tokens": 5,
},
logprobs={
"top_k": 3,
"top_k": logprobs_top_k,
},
)
assert response.logprobs, "Logprobs should not be empty"
assert 1 <= len(response.logprobs) <= 5
assert all(len(logprob.logprobs_by_token) == 3 for logprob in response.logprobs)
assert all(
len(logprob.logprobs_by_token) == logprobs_top_k
for logprob in response.logprobs
)
@pytest.mark.skip("Most inference providers don't support log probs yet")
def test_completion_log_probs_streaming(llama_stack_client, text_model_id):
def test_completion_log_probs_streaming(
llama_stack_client, text_model_id, inference_provider_type
):
if inference_provider_type not in PROVIDER_LOGPROBS_TOP_K:
pytest.xfail(f"{inference_provider_type} doesn't support log probs yet")
logprobs_top_k = PROVIDER_LOGPROBS_TOP_K[inference_provider_type]
response = llama_stack_client.inference.completion(
content="Complete the sentence: Micheael Jordan is born in ",
stream=True,
@ -111,7 +130,7 @@ def test_completion_log_probs_streaming(llama_stack_client, text_model_id):
"max_tokens": 5,
},
logprobs={
"top_k": 3,
"top_k": logprobs_top_k,
},
)
streamed_content = [chunk for chunk in response]
@ -119,7 +138,8 @@ def test_completion_log_probs_streaming(llama_stack_client, text_model_id):
if chunk.delta: # if there's a token, we expect logprobs
assert chunk.logprobs, "Logprobs should not be empty"
assert all(
len(logprob.logprobs_by_token) == 3 for logprob in chunk.logprobs
len(logprob.logprobs_by_token) == logprobs_top_k
for logprob in chunk.logprobs
)
else: # no token, no logprobs
assert not chunk.logprobs, "Logprobs should be empty"