diff --git a/llama_stack/providers/remote/inference/fireworks/fireworks.py b/llama_stack/providers/remote/inference/fireworks/fireworks.py
index b451f0264..e22144326 100644
--- a/llama_stack/providers/remote/inference/fireworks/fireworks.py
+++ b/llama_stack/providers/remote/inference/fireworks/fireworks.py
@@ -168,7 +168,10 @@ class FireworksInferenceAdapter(
         yield chunk
 
     def _build_options(
-        self, sampling_params: Optional[SamplingParams], fmt: ResponseFormat
+        self,
+        sampling_params: Optional[SamplingParams],
+        fmt: ResponseFormat,
+        logprobs: Optional[LogProbConfig],
     ) -> dict:
         options = get_sampling_options(sampling_params)
         options.setdefault("max_tokens", 512)
@@ -187,6 +190,11 @@ class FireworksInferenceAdapter(
         else:
             raise ValueError(f"Unknown response format {fmt.type}")
 
+        if logprobs and logprobs.top_k:
+            options["logprobs"] = logprobs.top_k
+            if options["logprobs"] <= 0 or options["logprobs"] >= 5:
+                raise ValueError("Required range: 0 < top_k < 5")
+
         return options
 
     async def chat_completion(
@@ -280,7 +288,9 @@ class FireworksInferenceAdapter(
             "model": request.model,
             **input_dict,
             "stream": request.stream,
-            **self._build_options(request.sampling_params, request.response_format),
+            **self._build_options(
+                request.sampling_params, request.response_format, request.logprobs
+            ),
         }
 
     async def embeddings(
diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py
index 694212a02..88aca1a1c 100644
--- a/llama_stack/providers/utils/inference/openai_compat.py
+++ b/llama_stack/providers/utils/inference/openai_compat.py
@@ -15,6 +15,7 @@ from llama_models.llama3.api.datatypes import (
     TopKSamplingStrategy,
     TopPSamplingStrategy,
 )
+from openai.types.completion_choice import Logprobs as OpenAILogprobs
 from pydantic import BaseModel
 
 from llama_stack.apis.common.content_types import (
@@ -34,6 +35,7 @@ from llama_stack.apis.inference import (
     CompletionResponse,
     CompletionResponseStreamChunk,
     Message,
+    TokenLogProbs,
 )
 
 from llama_stack.providers.utils.inference.prompt_adapter import (
@@ -104,6 +106,14 @@ def get_stop_reason(finish_reason: str) -> StopReason:
     return StopReason.out_of_tokens
 
 
+def convert_openai_completion_logprobs(
+    logprobs: Optional[OpenAILogprobs],
+) -> Optional[List[TokenLogProbs]]:
+    if not logprobs:
+        return None
+    return [TokenLogProbs(logprobs_by_token=x) for x in logprobs.top_logprobs]
+
+
 def process_completion_response(
     response: OpenAICompatCompletionResponse, formatter: ChatFormat
 ) -> CompletionResponse:
@@ -113,16 +123,19 @@ def process_completion_response(
         return CompletionResponse(
             stop_reason=StopReason.end_of_turn,
             content=choice.text[: -len("<|eot_id|>")],
+            logprobs=convert_openai_completion_logprobs(choice.logprobs),
         )
     # drop suffix if present and return stop reason as end of message
     if choice.text.endswith("<|eom_id|>"):
         return CompletionResponse(
             stop_reason=StopReason.end_of_message,
             content=choice.text[: -len("<|eom_id|>")],
+            logprobs=convert_openai_completion_logprobs(choice.logprobs),
        )
     return CompletionResponse(
         stop_reason=get_stop_reason(choice.finish_reason),
         content=choice.text,
+        logprobs=convert_openai_completion_logprobs(choice.logprobs),
     )
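
A quick usage sketch of the new path (not part of the diff; the adapter bootstrap and model id below are placeholders). With these changes, a `completion` call can carry a `LogProbConfig`, `_build_options` forwards its `top_k` as the Fireworks `logprobs` option (rejecting values outside `0 < top_k < 5`), and `process_completion_response` surfaces the per-token results:

```python
import asyncio

from llama_stack.apis.inference import LogProbConfig


async def main() -> None:
    # Placeholder: obtain a FireworksInferenceAdapter however your stack
    # wires providers; this helper does not exist in the diff.
    inference = get_fireworks_adapter()

    response = await inference.completion(
        model_id="meta-llama/Llama-3.1-8B-Instruct",  # illustrative model id
        content="The capital of France is",
        logprobs=LogProbConfig(top_k=3),  # must satisfy 0 < top_k < 5
    )

    # Each TokenLogProbs entry maps candidate tokens to log-probabilities
    # for one generated position.
    for entry in response.logprobs or []:
        print(entry.logprobs_by_token)


asyncio.run(main())
```

Note that because the range check sits inside the `if logprobs and logprobs.top_k:` guard, a `top_k` of 0 (the `LogProbConfig` default) simply leaves logprobs disabled rather than raising.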