Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-07 02:58:21 +00:00)
log probs - mark xfail for unsupported + support for together
commit 667c71f4e7
parent 7de46e40f9
3 changed files with 69 additions and 13 deletions
@@ -9,6 +9,7 @@ from typing import AsyncGenerator, List, Optional, Union
 from llama_models.datatypes import CoreModelId
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.tokenizer import Tokenizer
+from termcolor import cprint
 from together import Together

 from llama_stack.apis.common.content_types import InterleavedContent
@@ -161,7 +162,10 @@ class TogetherInferenceAdapter(
             yield chunk

     def _build_options(
-        self, sampling_params: Optional[SamplingParams], fmt: ResponseFormat
+        self,
+        sampling_params: Optional[SamplingParams],
+        logprobs: Optional[LogProbConfig],
+        fmt: ResponseFormat,
     ) -> dict:
         options = get_sampling_options(sampling_params)
         if fmt:
@@ -175,6 +179,14 @@ class TogetherInferenceAdapter(
             else:
                 raise ValueError(f"Unknown response format {fmt.type}")

+        if logprobs and logprobs.top_k:
+            if logprobs.top_k != 1:
+                cprint(
+                    "Together only supports logprobs top_k=1. Overriding.",
+                    "Yello",
+                )
+            options["logprobs"] = 1
+
         return options

     async def chat_completion(
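For reference, here is a minimal standalone sketch of the clamping behavior the new `_build_options` branch implements. `LogProbCfg` is a hypothetical stand-in for llama-stack's `LogProbConfig`, and a plain `print` replaces `cprint` so the snippet has no `termcolor` dependency.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class LogProbCfg:
    # Hypothetical stand-in for llama_stack's LogProbConfig.
    top_k: Optional[int] = None


def apply_together_logprobs(options: dict, logprobs: Optional[LogProbCfg]) -> dict:
    # Mirrors the diff's logic: Together only honors logprobs with top_k=1,
    # so any other requested value is clamped (with a warning) to 1.
    if logprobs and logprobs.top_k:
        if logprobs.top_k != 1:
            print("Together only supports logprobs top_k=1. Overriding.")
        options["logprobs"] = 1
    return options


print(apply_together_logprobs({}, LogProbCfg(top_k=3)))  # {'logprobs': 1}
print(apply_together_logprobs({}, None))                 # {}
```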
@@ -263,7 +275,9 @@ class TogetherInferenceAdapter(
             "model": request.model,
             **input_dict,
             "stream": request.stream,
-            **self._build_options(request.sampling_params, request.response_format),
+            **self._build_options(
+                request.sampling_params, request.logprobs, request.response_format
+            ),
         }

     async def embeddings(
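A rough illustration (not taken from the repo) of what this change produces: when the request carries a logprobs config, `_build_options` now folds a `logprobs` key into the provider params next to the sampling options. The exact keys and values below are assumptions about the shape `get_sampling_options` would return, shown only to make the end result concrete.

```python
# Hypothetical params for a non-streaming Together call with logprobs requested.
params = {
    "model": "meta-llama/Llama-3.1-8B-Instruct-Turbo",  # assumed model id
    "prompt": "Complete the sentence: ...",
    "stream": False,
    "temperature": 0.7,   # assumed output of get_sampling_options
    "max_tokens": 5,
    "logprobs": 1,        # injected by the new logprobs branch
}
```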
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import AsyncGenerator, Dict, List, Optional
+from typing import AsyncGenerator, Dict, List, Optional, Union

 from llama_models.datatypes import (
     GreedySamplingStrategy,
@@ -121,7 +121,29 @@ def convert_openai_completion_logprobs(
 ) -> Optional[List[TokenLogProbs]]:
     if not logprobs:
         return None
-    return [TokenLogProbs(logprobs_by_token=x) for x in logprobs.top_logprobs]
+    if hasattr(logprobs, "top_logprobs"):
+        return [TokenLogProbs(logprobs_by_token=x) for x in logprobs.top_logprobs]
+
+    # Together supports logprobs (top_k=1) but not top_logprobs (top_k>1).
+    if logprobs.tokens and logprobs.token_logprobs:
+        return [
+            TokenLogProbs(logprobs_by_token={token: token_lp})
+            for token, token_lp in zip(logprobs.tokens, logprobs.token_logprobs)
+        ]
+    return None
+
+
+def convert_openai_completion_logprobs_stream(
+    text: str, logprobs: Optional[Union[float, OpenAICompatLogprobs]]
+):
+    if logprobs is None:
+        return None
+    if isinstance(logprobs, float):
+        return [TokenLogProbs(logprobs_by_token={text: logprobs})]
+    if hasattr(logprobs, "top_logprobs"):
+        # Adapt response from Together CompletionChoicesChunk
+        return [TokenLogProbs(logprobs_by_token=x) for x in logprobs.top_logprobs]
+    return None


 def process_completion_response(
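A hedged usage sketch of the two payload shapes the non-streaming converter now handles. The `SimpleNamespace` object stands in for a provider response, and a plain dict stands in for `TokenLogProbs`; none of this is the library's API, it only restates the branching above.

```python
from types import SimpleNamespace


def to_logprobs_by_token(logprobs):
    """Sketch of the converter's branching, using plain dicts instead of TokenLogProbs."""
    if not logprobs:
        return None
    if getattr(logprobs, "top_logprobs", None):
        # OpenAI-style: one {token: logprob, ...} dict per position.
        return list(logprobs.top_logprobs)
    if getattr(logprobs, "tokens", None) and getattr(logprobs, "token_logprobs", None):
        # Together-style (top_k=1): parallel token / logprob arrays.
        return [{t: lp} for t, lp in zip(logprobs.tokens, logprobs.token_logprobs)]
    return None


together_style = SimpleNamespace(
    top_logprobs=None,
    tokens=["Hello", " world"],
    token_logprobs=[-0.12, -0.45],
)
print(to_logprobs_by_token(together_style))
# [{'Hello': -0.12}, {' world': -0.45}]
```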
@@ -188,7 +210,7 @@ async def process_completion_stream_response(
         yield CompletionResponseStreamChunk(
             delta=text,
             stop_reason=stop_reason,
-            logprobs=convert_openai_completion_logprobs(choice.logprobs),
+            logprobs=convert_openai_completion_logprobs_stream(text, choice.logprobs),
         )
         if finish_reason:
             if finish_reason in ["stop", "eos", "eos_token"]:
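For the streaming path, a small hedged example of the float branch: some providers report a single log-probability per streamed chunk, which `convert_openai_completion_logprobs_stream` keys by the chunk's text. The snippet re-implements just that branching as a sketch, again with a plain dict standing in for `TokenLogProbs`.

```python
from typing import Optional, Union


def stream_chunk_logprobs(text: str, logprobs: Optional[Union[float, object]]):
    # Minimal re-implementation of the stream converter's branches (sketch only).
    if logprobs is None:
        return None
    if isinstance(logprobs, float):
        # One logprob for the whole chunk: key it by the chunk's text.
        return [{text: logprobs}]
    if hasattr(logprobs, "top_logprobs"):
        return list(logprobs.top_logprobs)
    return None


print(stream_chunk_logprobs(" Jordan", -0.03))  # [{' Jordan': -0.03}]
print(stream_chunk_logprobs("", None))          # None
```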
@@ -16,6 +16,12 @@ PROVIDER_TOOL_PROMPT_FORMAT = {
     "remote::fireworks": "json",
 }

+PROVIDER_LOGPROBS_TOP_K = {
+    "remote::together": 1,
+    "remote::fireworks": 3,
+    # "remote:vllm"
+}
+

 @pytest.fixture(scope="session")
 def provider_tool_format(inference_provider_type):
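The gating pattern the tests below adopt can be read on its own: look the provider up in `PROVIDER_LOGPROBS_TOP_K`, call `pytest.xfail` when it is absent, and otherwise assert against the provider-specific `top_k`. A condensed sketch follows; the helper function is hypothetical and not part of the test file.

```python
import pytest

PROVIDER_LOGPROBS_TOP_K = {
    "remote::together": 1,
    "remote::fireworks": 3,
}


def logprobs_top_k_or_xfail(inference_provider_type: str) -> int:
    # Hypothetical helper: xfail the calling test when the provider has no
    # known logprobs support, otherwise return its supported top_k.
    if inference_provider_type not in PROVIDER_LOGPROBS_TOP_K:
        pytest.xfail(f"{inference_provider_type} doesn't support log probs yet")
    return PROVIDER_LOGPROBS_TOP_K[inference_provider_type]
```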
@@ -83,8 +89,13 @@ def test_text_completion_streaming(llama_stack_client, text_model_id):
     assert "blue" in "".join(streamed_content).lower().strip()


-@pytest.mark.skip("Most inference providers don't support log probs yet")
-def test_completion_log_probs_non_streaming(llama_stack_client, text_model_id):
+def test_completion_log_probs_non_streaming(
+    llama_stack_client, text_model_id, inference_provider_type
+):
+    if inference_provider_type not in PROVIDER_LOGPROBS_TOP_K:
+        pytest.xfail(f"{inference_provider_type} doesn't support log probs yet")
+
+    logprobs_top_k = PROVIDER_LOGPROBS_TOP_K[inference_provider_type]
     response = llama_stack_client.inference.completion(
         content="Complete the sentence: Micheael Jordan is born in ",
         stream=False,
@@ -93,16 +104,24 @@ def test_completion_log_probs_non_streaming(llama_stack_client, text_model_id):
             "max_tokens": 5,
         },
         logprobs={
-            "top_k": 3,
+            "top_k": logprobs_top_k,
         },
     )
     assert response.logprobs, "Logprobs should not be empty"
     assert 1 <= len(response.logprobs) <= 5
-    assert all(len(logprob.logprobs_by_token) == 3 for logprob in response.logprobs)
+    assert all(
+        len(logprob.logprobs_by_token) == logprobs_top_k
+        for logprob in response.logprobs
+    )


-@pytest.mark.skip("Most inference providers don't support log probs yet")
-def test_completion_log_probs_streaming(llama_stack_client, text_model_id):
+def test_completion_log_probs_streaming(
+    llama_stack_client, text_model_id, inference_provider_type
+):
+    if inference_provider_type not in PROVIDER_LOGPROBS_TOP_K:
+        pytest.xfail(f"{inference_provider_type} doesn't support log probs yet")
+
+    logprobs_top_k = PROVIDER_LOGPROBS_TOP_K[inference_provider_type]
     response = llama_stack_client.inference.completion(
         content="Complete the sentence: Micheael Jordan is born in ",
         stream=True,
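For orientation, a sketch of the response shape the non-streaming assertions expect when `top_k=1` (the Together case): at most `max_tokens` positions, each with a `logprobs_by_token` mapping of exactly `top_k` entries. The literal below is illustrative data, not captured output from any provider.

```python
# Illustrative (not captured) shape of response.logprobs for top_k=1, max_tokens=5.
expected_like = [
    {"logprobs_by_token": {" Brooklyn": -0.8}},
    {"logprobs_by_token": {",": -0.3}},
    {"logprobs_by_token": {" New": -0.2}},
]
assert 1 <= len(expected_like) <= 5
assert all(len(x["logprobs_by_token"]) == 1 for x in expected_like)
```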
@@ -111,7 +130,7 @@ def test_completion_log_probs_streaming(llama_stack_client, text_model_id):
             "max_tokens": 5,
         },
         logprobs={
-            "top_k": 3,
+            "top_k": logprobs_top_k,
        },
     )
     streamed_content = [chunk for chunk in response]
@@ -119,7 +138,8 @@ def test_completion_log_probs_streaming(llama_stack_client, text_model_id):
         if chunk.delta:  # if there's a token, we expect logprobs
             assert chunk.logprobs, "Logprobs should not be empty"
             assert all(
-                len(logprob.logprobs_by_token) == 3 for logprob in chunk.logprobs
+                len(logprob.logprobs_by_token) == logprobs_top_k
+                for logprob in chunk.logprobs
             )
         else:  # no token, no logprobs
             assert not chunk.logprobs, "Logprobs should be empty"
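As a closing illustration, a self-contained sketch of the invariant this last hunk checks: chunks that carry a token delta must carry logprobs of the expected width, and chunks without a delta must not. The chunk objects are hypothetical stand-ins for the client's stream chunks.

```python
from types import SimpleNamespace

logprobs_top_k = 1  # e.g. the PROVIDER_LOGPROBS_TOP_K value for remote::together

fake_chunks = [
    SimpleNamespace(
        delta=" Brooklyn",
        logprobs=[SimpleNamespace(logprobs_by_token={" Brooklyn": -0.8})],
    ),
    SimpleNamespace(delta="", logprobs=None),  # final chunk: no token, no logprobs
]

for chunk in fake_chunks:
    if chunk.delta:  # if there's a token, we expect logprobs
        assert chunk.logprobs, "Logprobs should not be empty"
        assert all(
            len(lp.logprobs_by_token) == logprobs_top_k for lp in chunk.logprobs
        )
    else:  # no token, no logprobs
        assert not chunk.logprobs, "Logprobs should be empty"
```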