add completion logprobs

Xi Yan 2025-01-15 16:48:39 -08:00
parent 965644ce68
commit 9c13a7b76b
2 changed files with 25 additions and 2 deletions

@@ -168,7 +168,10 @@ class FireworksInferenceAdapter(
             yield chunk
 
     def _build_options(
-        self, sampling_params: Optional[SamplingParams], fmt: ResponseFormat
+        self,
+        sampling_params: Optional[SamplingParams],
+        fmt: ResponseFormat,
+        logprobs: Optional[LogProbConfig],
     ) -> dict:
         options = get_sampling_options(sampling_params)
         options.setdefault("max_tokens", 512)
@@ -187,6 +190,11 @@ class FireworksInferenceAdapter(
         else:
             raise ValueError(f"Unknown response format {fmt.type}")
 
+        if logprobs and logprobs.top_k:
+            options["logprobs"] = logprobs.top_k
+            if options["logprobs"] <= 0 or options["logprobs"] >= 5:
+                raise ValueError("Required range: 0 < top_k < 5")
+
         return options
 
     async def chat_completion(
@@ -280,7 +288,9 @@ class FireworksInferenceAdapter(
             "model": request.model,
             **input_dict,
             "stream": request.stream,
-            **self._build_options(request.sampling_params, request.response_format),
+            **self._build_options(
+                request.sampling_params, request.response_format, request.logprobs
+            ),
         }
 
     async def embeddings(
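
For reference, a minimal standalone sketch of the option-building behavior added above. LogProbConfig is stubbed out as a plain dataclass (the real class lives in llama_stack), and build_logprobs_options is a hypothetical helper that isolates just the new validation; it is not the adapter's actual code.

from dataclasses import dataclass
from typing import Optional


@dataclass
class LogProbConfig:
    # Stand-in for llama_stack's LogProbConfig; only top_k is used here.
    top_k: Optional[int] = None


def build_logprobs_options(logprobs: Optional[LogProbConfig]) -> dict:
    # Mirrors the validation added above: only 0 < top_k < 5 is accepted.
    options: dict = {}
    if logprobs and logprobs.top_k:
        options["logprobs"] = logprobs.top_k
        if options["logprobs"] <= 0 or options["logprobs"] >= 5:
            raise ValueError("Required range: 0 < top_k < 5")
    return options


# top_k=3 is forwarded to the provider as the "logprobs" option.
assert build_logprobs_options(LogProbConfig(top_k=3)) == {"logprobs": 3}

# top_k=7 is rejected before any request is sent.
try:
    build_logprobs_options(LogProbConfig(top_k=7))
except ValueError as err:
    print(err)  # Required range: 0 < top_k < 5

Because the check lives in _build_options, an out-of-range top_k fails fast while the request parameters are being assembled, rather than being rejected by the provider.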

@@ -15,6 +15,7 @@ from llama_models.llama3.api.datatypes import (
     TopKSamplingStrategy,
     TopPSamplingStrategy,
 )
+from openai.types.completion_choice import Logprobs as OpenAILogprobs
 from pydantic import BaseModel
 
 from llama_stack.apis.common.content_types import (
@@ -34,6 +35,7 @@ from llama_stack.apis.inference import (
     CompletionResponse,
     CompletionResponseStreamChunk,
     Message,
+    TokenLogProbs,
 )
 
 from llama_stack.providers.utils.inference.prompt_adapter import (
@@ -104,6 +106,14 @@ def get_stop_reason(finish_reason: str) -> StopReason:
     return StopReason.out_of_tokens
 
 
+def convert_openai_completion_logprobs(
+    logprobs: Optional[OpenAILogprobs],
+) -> Optional[List[TokenLogProbs]]:
+    if not logprobs:
+        return None
+    return [TokenLogProbs(logprobs_by_token=x) for x in logprobs.top_logprobs]
+
+
 def process_completion_response(
     response: OpenAICompatCompletionResponse, formatter: ChatFormat
 ) -> CompletionResponse:
@@ -113,16 +123,19 @@ def process_completion_response(
         return CompletionResponse(
             stop_reason=StopReason.end_of_turn,
             content=choice.text[: -len("<|eot_id|>")],
+            logprobs=convert_openai_completion_logprobs(choice.logprobs),
        )
 
     # drop suffix <eom_id> if present and return stop reason as end of message
     if choice.text.endswith("<|eom_id|>"):
         return CompletionResponse(
             stop_reason=StopReason.end_of_message,
             content=choice.text[: -len("<|eom_id|>")],
+            logprobs=convert_openai_completion_logprobs(choice.logprobs),
         )
 
     return CompletionResponse(
         stop_reason=get_stop_reason(choice.finish_reason),
         content=choice.text,
+        logprobs=convert_openai_completion_logprobs(choice.logprobs),
     )
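
To see what the new converter does with concrete data, here is a self-contained sketch. The Fake* dataclasses are stand-ins invented for this example; the real OpenAILogprobs comes from the openai package and TokenLogProbs from llama_stack.apis.inference, and the field shapes assumed below (top_logprobs as a list of per-token {token: logprob} dicts, logprobs_by_token as one such dict) are inferred from how the helper uses them.

from dataclasses import dataclass, field
from typing import Dict, List, Optional


@dataclass
class FakeOpenAILogprobs:
    # Stand-in for openai.types.completion_choice.Logprobs: the only field the
    # converter reads is top_logprobs, assumed to hold one {token: logprob}
    # dict per generated token.
    top_logprobs: List[Dict[str, float]] = field(default_factory=list)


@dataclass
class FakeTokenLogProbs:
    # Stand-in for llama_stack.apis.inference.TokenLogProbs.
    logprobs_by_token: Dict[str, float]


def convert_openai_completion_logprobs(
    logprobs: Optional[FakeOpenAILogprobs],
) -> Optional[List[FakeTokenLogProbs]]:
    # Same shape as the helper in the diff: None passes through untouched,
    # otherwise each per-token dict is wrapped in a TokenLogProbs-like object.
    if not logprobs:
        return None
    return [FakeTokenLogProbs(logprobs_by_token=x) for x in logprobs.top_logprobs]


# One generated token with its top-3 candidates and their logprobs.
sample = FakeOpenAILogprobs(top_logprobs=[{"Paris": -0.1, " Paris": -2.3, "paris": -4.0}])
converted = convert_openai_completion_logprobs(sample)
print(converted[0].logprobs_by_token)  # {'Paris': -0.1, ' Paris': -2.3, 'paris': -4.0}

# No logprobs requested -> the response field stays None.
assert convert_openai_completion_logprobs(None) is None

The "if not logprobs" guard also covers responses where logprobs were never requested, so every existing call path keeps logprobs=None on the CompletionResponse.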