mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-06 10:42:39 +00:00
log probs - mark xfail for unsupported + support for together
This commit is contained in:
parent
7de46e40f9
commit
667c71f4e7
3 changed files with 69 additions and 13 deletions
|
@ -9,6 +9,7 @@ from typing import AsyncGenerator, List, Optional, Union
|
|||
from llama_models.datatypes import CoreModelId
|
||||
from llama_models.llama3.api.chat_format import ChatFormat
|
||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||
from termcolor import cprint
|
||||
from together import Together
|
||||
|
||||
from llama_stack.apis.common.content_types import InterleavedContent
|
||||
|
@ -161,7 +162,10 @@ class TogetherInferenceAdapter(
|
|||
yield chunk
|
||||
|
||||
def _build_options(
|
||||
self, sampling_params: Optional[SamplingParams], fmt: ResponseFormat
|
||||
self,
|
||||
sampling_params: Optional[SamplingParams],
|
||||
logprobs: Optional[LogProbConfig],
|
||||
fmt: ResponseFormat,
|
||||
) -> dict:
|
||||
options = get_sampling_options(sampling_params)
|
||||
if fmt:
|
||||
|
@ -175,6 +179,14 @@ class TogetherInferenceAdapter(
|
|||
else:
|
||||
raise ValueError(f"Unknown response format {fmt.type}")
|
||||
|
||||
if logprobs and logprobs.top_k:
|
||||
if logprobs.top_k != 1:
|
||||
cprint(
|
||||
"Together only supports logprobs top_k=1. Overriding.",
|
||||
"Yello",
|
||||
)
|
||||
options["logprobs"] = 1
|
||||
|
||||
return options
|
||||
|
||||
async def chat_completion(
|
||||
|
@ -263,7 +275,9 @@ class TogetherInferenceAdapter(
|
|||
"model": request.model,
|
||||
**input_dict,
|
||||
"stream": request.stream,
|
||||
**self._build_options(request.sampling_params, request.response_format),
|
||||
**self._build_options(
|
||||
request.sampling_params, request.logprobs, request.response_format
|
||||
),
|
||||
}
|
||||
|
||||
async def embeddings(
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import AsyncGenerator, Dict, List, Optional
|
||||
from typing import AsyncGenerator, Dict, List, Optional, Union
|
||||
|
||||
from llama_models.datatypes import (
|
||||
GreedySamplingStrategy,
|
||||
|
@ -121,7 +121,29 @@ def convert_openai_completion_logprobs(
|
|||
) -> Optional[List[TokenLogProbs]]:
|
||||
if not logprobs:
|
||||
return None
|
||||
return [TokenLogProbs(logprobs_by_token=x) for x in logprobs.top_logprobs]
|
||||
if hasattr(logprobs, "top_logprobs"):
|
||||
return [TokenLogProbs(logprobs_by_token=x) for x in logprobs.top_logprobs]
|
||||
|
||||
# Together supports logprobs (top_k=1) but not top_logprobs (top_k>1).
|
||||
if logprobs.tokens and logprobs.token_logprobs:
|
||||
return [
|
||||
TokenLogProbs(logprobs_by_token={token: token_lp})
|
||||
for token, token_lp in zip(logprobs.tokens, logprobs.token_logprobs)
|
||||
]
|
||||
return None
|
||||
|
||||
|
||||
def convert_openai_completion_logprobs_stream(
|
||||
text: str, logprobs: Optional[Union[float, OpenAICompatLogprobs]]
|
||||
):
|
||||
if logprobs is None:
|
||||
return None
|
||||
if isinstance(logprobs, float):
|
||||
return [TokenLogProbs(logprobs_by_token={text: logprobs})]
|
||||
if hasattr(logprobs, "top_logprobs"):
|
||||
# Adapt response from Together CompletionChoicesChunk
|
||||
return [TokenLogProbs(logprobs_by_token=x) for x in logprobs.top_logprobs]
|
||||
return None
|
||||
|
||||
|
||||
def process_completion_response(
|
||||
|
@ -188,7 +210,7 @@ async def process_completion_stream_response(
|
|||
yield CompletionResponseStreamChunk(
|
||||
delta=text,
|
||||
stop_reason=stop_reason,
|
||||
logprobs=convert_openai_completion_logprobs(choice.logprobs),
|
||||
logprobs=convert_openai_completion_logprobs_stream(text, choice.logprobs),
|
||||
)
|
||||
if finish_reason:
|
||||
if finish_reason in ["stop", "eos", "eos_token"]:
|
||||
|
|
|
@ -16,6 +16,12 @@ PROVIDER_TOOL_PROMPT_FORMAT = {
|
|||
"remote::fireworks": "json",
|
||||
}
|
||||
|
||||
PROVIDER_LOGPROBS_TOP_K = {
|
||||
"remote::together": 1,
|
||||
"remote::fireworks": 3,
|
||||
# "remote:vllm"
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def provider_tool_format(inference_provider_type):
|
||||
|
@ -83,8 +89,13 @@ def test_text_completion_streaming(llama_stack_client, text_model_id):
|
|||
assert "blue" in "".join(streamed_content).lower().strip()
|
||||
|
||||
|
||||
@pytest.mark.skip("Most inference providers don't support log probs yet")
|
||||
def test_completion_log_probs_non_streaming(llama_stack_client, text_model_id):
|
||||
def test_completion_log_probs_non_streaming(
|
||||
llama_stack_client, text_model_id, inference_provider_type
|
||||
):
|
||||
if inference_provider_type not in PROVIDER_LOGPROBS_TOP_K:
|
||||
pytest.xfail(f"{inference_provider_type} doesn't support log probs yet")
|
||||
|
||||
logprobs_top_k = PROVIDER_LOGPROBS_TOP_K[inference_provider_type]
|
||||
response = llama_stack_client.inference.completion(
|
||||
content="Complete the sentence: Micheael Jordan is born in ",
|
||||
stream=False,
|
||||
|
@ -93,16 +104,24 @@ def test_completion_log_probs_non_streaming(llama_stack_client, text_model_id):
|
|||
"max_tokens": 5,
|
||||
},
|
||||
logprobs={
|
||||
"top_k": 3,
|
||||
"top_k": logprobs_top_k,
|
||||
},
|
||||
)
|
||||
assert response.logprobs, "Logprobs should not be empty"
|
||||
assert 1 <= len(response.logprobs) <= 5
|
||||
assert all(len(logprob.logprobs_by_token) == 3 for logprob in response.logprobs)
|
||||
assert all(
|
||||
len(logprob.logprobs_by_token) == logprobs_top_k
|
||||
for logprob in response.logprobs
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skip("Most inference providers don't support log probs yet")
|
||||
def test_completion_log_probs_streaming(llama_stack_client, text_model_id):
|
||||
def test_completion_log_probs_streaming(
|
||||
llama_stack_client, text_model_id, inference_provider_type
|
||||
):
|
||||
if inference_provider_type not in PROVIDER_LOGPROBS_TOP_K:
|
||||
pytest.xfail(f"{inference_provider_type} doesn't support log probs yet")
|
||||
|
||||
logprobs_top_k = PROVIDER_LOGPROBS_TOP_K[inference_provider_type]
|
||||
response = llama_stack_client.inference.completion(
|
||||
content="Complete the sentence: Micheael Jordan is born in ",
|
||||
stream=True,
|
||||
|
@ -111,7 +130,7 @@ def test_completion_log_probs_streaming(llama_stack_client, text_model_id):
|
|||
"max_tokens": 5,
|
||||
},
|
||||
logprobs={
|
||||
"top_k": 3,
|
||||
"top_k": logprobs_top_k,
|
||||
},
|
||||
)
|
||||
streamed_content = [chunk for chunk in response]
|
||||
|
@ -119,7 +138,8 @@ def test_completion_log_probs_streaming(llama_stack_client, text_model_id):
|
|||
if chunk.delta: # if there's a token, we expect logprobs
|
||||
assert chunk.logprobs, "Logprobs should not be empty"
|
||||
assert all(
|
||||
len(logprob.logprobs_by_token) == 3 for logprob in chunk.logprobs
|
||||
len(logprob.logprobs_by_token) == logprobs_top_k
|
||||
for logprob in chunk.logprobs
|
||||
)
|
||||
else: # no token, no logprobs
|
||||
assert not chunk.logprobs, "Logprobs should be empty"
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue