From e23928093233c46fab071f4445bf59487c05cd43 Mon Sep 17 00:00:00 2001
From: Xi Yan
Date: Thu, 16 Jan 2025 10:37:07 -0800
Subject: [PATCH] fireworks add completion logprobs adapter (#778)

# What does this PR do?

- add completion logprobs for fireworks

## Test Plan

[screenshot: test run showing completion logprobs in the response]

## Sources

Please link relevant resources if necessary.

## Before submitting

- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [ ] Ran pre-commit to handle lint / formatting issues.
- [ ] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section?
- [ ] Updated relevant documentation.
- [ ] Wrote necessary unit or integration tests.
---
 .../remote/inference/fireworks/fireworks.py | 14 ++++++++--
 .../utils/inference/openai_compat.py        | 26 ++++++++++++++++++-
 2 files changed, 37 insertions(+), 3 deletions(-)

diff --git a/llama_stack/providers/remote/inference/fireworks/fireworks.py b/llama_stack/providers/remote/inference/fireworks/fireworks.py
index b451f0264..e22144326 100644
--- a/llama_stack/providers/remote/inference/fireworks/fireworks.py
+++ b/llama_stack/providers/remote/inference/fireworks/fireworks.py
@@ -168,7 +168,10 @@ class FireworksInferenceAdapter(
         yield chunk
 
     def _build_options(
-        self, sampling_params: Optional[SamplingParams], fmt: ResponseFormat
+        self,
+        sampling_params: Optional[SamplingParams],
+        fmt: ResponseFormat,
+        logprobs: Optional[LogProbConfig],
     ) -> dict:
         options = get_sampling_options(sampling_params)
         options.setdefault("max_tokens", 512)
@@ -187,6 +190,11 @@ class FireworksInferenceAdapter(
         else:
             raise ValueError(f"Unknown response format {fmt.type}")
 
+        if logprobs and logprobs.top_k:
+            options["logprobs"] = logprobs.top_k
+            if options["logprobs"] <= 0 or options["logprobs"] >= 5:
+                raise ValueError("Required range: 0 < top_k < 5")
+
         return options
 
     async def chat_completion(
@@ -280,7 +288,9 @@ class FireworksInferenceAdapter(
             "model": request.model,
             **input_dict,
             "stream": request.stream,
-            **self._build_options(request.sampling_params, request.response_format),
+            **self._build_options(
+                request.sampling_params, request.response_format, request.logprobs
+            ),
         }
 
     async def embeddings(
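A note on the `_build_options` change above: logprobs are forwarded to Fireworks only when the request carries a `LogProbConfig` with a truthy `top_k`, and the adapter rejects values outside 1-4 up front. Below is a minimal, standalone sketch of just that branch; `LogProbConfig` is stubbed here to mirror the llama-stack API type (in the adapter it comes from `llama_stack.apis.inference`), and `build_logprobs_options` is a hypothetical helper, not a function in the patch.

```python
# Standalone sketch of the logprobs option handling added above.
from typing import Optional

from pydantic import BaseModel


class LogProbConfig(BaseModel):
    # Stub mirroring llama_stack.apis.inference.LogProbConfig.
    top_k: Optional[int] = None


def build_logprobs_options(logprobs: Optional[LogProbConfig]) -> dict:
    # Hypothetical helper replicating the patch's logprobs branch.
    options: dict = {}
    if logprobs and logprobs.top_k:
        options["logprobs"] = logprobs.top_k
        # Only top_k values in the open interval (0, 5), i.e. 1-4, pass.
        if options["logprobs"] <= 0 or options["logprobs"] >= 5:
            raise ValueError("Required range: 0 < top_k < 5")
    return options


assert build_logprobs_options(LogProbConfig(top_k=3)) == {"logprobs": 3}
assert build_logprobs_options(None) == {}  # logprobs stay off by default
```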
diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py
index 694212a02..f6350ed51 100644
--- a/llama_stack/providers/utils/inference/openai_compat.py
+++ b/llama_stack/providers/utils/inference/openai_compat.py
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import AsyncGenerator, List, Optional
+from typing import AsyncGenerator, Dict, List, Optional
 
 from llama_models.llama3.api.chat_format import ChatFormat
 
@@ -34,6 +34,7 @@ from llama_stack.apis.inference import (
     CompletionResponse,
     CompletionResponseStreamChunk,
     Message,
+    TokenLogProbs,
 )
 
 from llama_stack.providers.utils.inference.prompt_adapter import (
@@ -45,10 +46,21 @@ class OpenAICompatCompletionChoiceDelta(BaseModel):
     content: str
 
 
+class OpenAICompatLogprobs(BaseModel):
+    text_offset: Optional[List[int]] = None
+
+    token_logprobs: Optional[List[float]] = None
+
+    tokens: Optional[List[str]] = None
+
+    top_logprobs: Optional[List[Dict[str, float]]] = None
+
+
 class OpenAICompatCompletionChoice(BaseModel):
     finish_reason: Optional[str] = None
     text: Optional[str] = None
     delta: Optional[OpenAICompatCompletionChoiceDelta] = None
+    logprobs: Optional[OpenAICompatLogprobs] = None
 
 
 class OpenAICompatCompletionResponse(BaseModel):
@@ -104,6 +116,14 @@ def get_stop_reason(finish_reason: str) -> StopReason:
     return StopReason.out_of_tokens
 
 
+def convert_openai_completion_logprobs(
+    logprobs: Optional[OpenAICompatLogprobs],
+) -> Optional[List[TokenLogProbs]]:
+    if not logprobs:
+        return None
+    return [TokenLogProbs(logprobs_by_token=x) for x in logprobs.top_logprobs]
+
+
 def process_completion_response(
     response: OpenAICompatCompletionResponse, formatter: ChatFormat
 ) -> CompletionResponse:
@@ -113,16 +133,19 @@ def process_completion_response(
         return CompletionResponse(
             stop_reason=StopReason.end_of_turn,
             content=choice.text[: -len("<|eot_id|>")],
+            logprobs=convert_openai_completion_logprobs(choice.logprobs),
         )
     # drop suffix if present and return stop reason as end of message
     if choice.text.endswith("<|eom_id|>"):
         return CompletionResponse(
             stop_reason=StopReason.end_of_message,
             content=choice.text[: -len("<|eom_id|>")],
+            logprobs=convert_openai_completion_logprobs(choice.logprobs),
         )
     return CompletionResponse(
         stop_reason=get_stop_reason(choice.finish_reason),
         content=choice.text,
+        logprobs=convert_openai_completion_logprobs(choice.logprobs),
     )
 
 
@@ -165,6 +188,7 @@ async def process_completion_stream_response(
             yield CompletionResponseStreamChunk(
                 delta=text,
                 stop_reason=stop_reason,
+                logprobs=convert_openai_completion_logprobs(choice.logprobs),
             )
         if finish_reason:
             if finish_reason in ["stop", "eos", "eos_token"]:
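For reference, a self-contained sketch of the new conversion path in `openai_compat.py`: `convert_openai_completion_logprobs` maps OpenAI-style `top_logprobs` (one token-to-logprob dict per generated position) to a list of `TokenLogProbs`. The `OpenAICompatLogprobs` shape is copied from the patch; `TokenLogProbs` is stubbed to match the llama-stack API type, and the sample payload is invented for illustration.

```python
from typing import Dict, List, Optional

from pydantic import BaseModel


class OpenAICompatLogprobs(BaseModel):
    # Shape copied from the patch above.
    text_offset: Optional[List[int]] = None
    token_logprobs: Optional[List[float]] = None
    tokens: Optional[List[str]] = None
    top_logprobs: Optional[List[Dict[str, float]]] = None


class TokenLogProbs(BaseModel):
    # Stub mirroring llama_stack.apis.inference.TokenLogProbs.
    logprobs_by_token: Dict[str, float]


def convert_openai_completion_logprobs(
    logprobs: Optional[OpenAICompatLogprobs],
) -> Optional[List[TokenLogProbs]]:
    # Mirrors the patch: assumes top_logprobs is populated whenever the
    # provider returns a logprobs object.
    if not logprobs:
        return None
    return [TokenLogProbs(logprobs_by_token=x) for x in logprobs.top_logprobs]


# One generated position with two candidate tokens and their logprobs.
sample = OpenAICompatLogprobs(top_logprobs=[{"Hello": -0.1, "Hi": -2.3}])
converted = convert_openai_completion_logprobs(sample)
assert converted is not None
assert converted[0].logprobs_by_token == {"Hello": -0.1, "Hi": -2.3}
```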