fireworks add completion logprobs adapter (#778)
# What does this PR do?

- Add completion logprobs support to the Fireworks inference adapter.

## Test Plan

<img width="849" alt="image" src="https://github.com/user-attachments/assets/5aa1f27f-02a6-422c-8478-94dd1e345342" />

## Sources

Please link relevant resources if necessary.

## Before submitting

- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [ ] Ran pre-commit to handle lint / formatting issues.
- [ ] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section?
- [ ] Updated relevant documentation.
- [ ] Wrote necessary unit or integration tests.
Commit e239280932 (parent 05f6b44da7)
2 changed files with 37 additions and 3 deletions
In the Fireworks adapter (`fireworks.py`), `_build_options` gains a `logprobs` parameter:

```diff
@@ -168,7 +168,10 @@ class FireworksInferenceAdapter(
         yield chunk
 
     def _build_options(
-        self, sampling_params: Optional[SamplingParams], fmt: ResponseFormat
+        self,
+        sampling_params: Optional[SamplingParams],
+        fmt: ResponseFormat,
+        logprobs: Optional[LogProbConfig],
     ) -> dict:
         options = get_sampling_options(sampling_params)
         options.setdefault("max_tokens", 512)
```
The new parameter is validated and mapped onto the provider's `logprobs` option:

```diff
@@ -187,6 +190,11 @@ class FireworksInferenceAdapter(
         else:
             raise ValueError(f"Unknown response format {fmt.type}")
 
+        if logprobs and logprobs.top_k:
+            options["logprobs"] = logprobs.top_k
+            if options["logprobs"] <= 0 or options["logprobs"] >= 5:
+                raise ValueError("Required range: 0 < top_k < 5")
+
         return options
 
     async def chat_completion(
```
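For reference, a minimal standalone sketch of the new validation branch. It assumes only that `LogProbConfig` exposes an optional integer `top_k`, which is all the diff relies on; the class below is a stand-in, not the real llama-stack type:

```python
from typing import Optional

from pydantic import BaseModel


class LogProbConfig(BaseModel):
    # Stand-in for llama_stack.apis.inference.LogProbConfig; the diff only
    # relies on its `top_k` field.
    top_k: Optional[int] = None


def build_logprobs_options(logprobs: Optional[LogProbConfig]) -> dict:
    """Mirrors the branch added to _build_options above."""
    options: dict = {}
    if logprobs and logprobs.top_k:
        options["logprobs"] = logprobs.top_k
        if options["logprobs"] <= 0 or options["logprobs"] >= 5:
            raise ValueError("Required range: 0 < top_k < 5")
    return options


assert build_logprobs_options(LogProbConfig(top_k=3)) == {"logprobs": 3}
assert build_logprobs_options(None) == {}  # no logprobs requested
```

Note that the range check only runs when `top_k` is truthy, so `top_k=0` silently skips logprobs rather than raising, while values of 5 and above are rejected.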
Request construction passes `request.logprobs` through:

```diff
@@ -280,7 +288,9 @@ class FireworksInferenceAdapter(
             "model": request.model,
             **input_dict,
             "stream": request.stream,
-            **self._build_options(request.sampling_params, request.response_format),
+            **self._build_options(
+                request.sampling_params, request.response_format, request.logprobs
+            ),
         }
 
     async def embeddings(
```
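With `request.logprobs` threaded through, the assembled params carry a top-level `logprobs` count. A hypothetical example of the resulting dict; the model id and prompt are illustrative, not taken from the PR:

```python
# Illustrative shape only; in practice the adapter builds this dict.
params = {
    "model": "accounts/fireworks/models/llama-v3p1-8b-instruct",  # hypothetical
    "prompt": "The capital of France is",  # hypothetical
    "stream": False,
    "max_tokens": 512,  # default applied by _build_options via setdefault
    "logprobs": 3,      # from LogProbConfig(top_k=3)
}
```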
In the shared OpenAI-compat utilities (`openai_compat.py`), the typing imports pick up `Dict`:

```diff
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import AsyncGenerator, List, Optional
+from typing import AsyncGenerator, Dict, List, Optional
 
 from llama_models.llama3.api.chat_format import ChatFormat
```
`TokenLogProbs` is imported alongside the other inference API types:

```diff
@@ -34,6 +34,7 @@ from llama_stack.apis.inference import (
     CompletionResponse,
     CompletionResponseStreamChunk,
     Message,
+    TokenLogProbs,
 )
 
 from llama_stack.providers.utils.inference.prompt_adapter import (
```
A new `OpenAICompatLogprobs` model captures the logprobs object of OpenAI-style responses, and completion choices gain a `logprobs` field:

```diff
@@ -45,10 +46,21 @@ class OpenAICompatCompletionChoiceDelta(BaseModel):
     content: str
 
 
+class OpenAICompatLogprobs(BaseModel):
+    text_offset: Optional[List[int]] = None
+
+    token_logprobs: Optional[List[float]] = None
+
+    tokens: Optional[List[str]] = None
+
+    top_logprobs: Optional[List[Dict[str, float]]] = None
+
+
 class OpenAICompatCompletionChoice(BaseModel):
     finish_reason: Optional[str] = None
     text: Optional[str] = None
     delta: Optional[OpenAICompatCompletionChoiceDelta] = None
+    logprobs: Optional[OpenAICompatLogprobs] = None
 
 
 class OpenAICompatCompletionResponse(BaseModel):
```
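A small sketch of parsing an OpenAI-style logprobs payload into the new model; the model is redefined inline so the snippet is self-contained, and the field values are made up:

```python
from typing import Dict, List, Optional

from pydantic import BaseModel


class OpenAICompatLogprobs(BaseModel):
    # Same fields as the model added in the diff above.
    text_offset: Optional[List[int]] = None
    token_logprobs: Optional[List[float]] = None
    tokens: Optional[List[str]] = None
    top_logprobs: Optional[List[Dict[str, float]]] = None


# Illustrative OpenAI-style payload for a one-token completion:
lp = OpenAICompatLogprobs(
    tokens=[" Paris"],
    token_logprobs=[-0.12],
    text_offset=[24],
    top_logprobs=[{" Paris": -0.12, " Lyon": -2.73}],
)
print(lp.top_logprobs)  # [{' Paris': -0.12, ' Lyon': -2.73}]
```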
`convert_openai_completion_logprobs` translates the OpenAI-style shape into llama-stack's `TokenLogProbs`:

```diff
@@ -104,6 +116,14 @@ def get_stop_reason(finish_reason: str) -> StopReason:
     return StopReason.out_of_tokens
 
 
+def convert_openai_completion_logprobs(
+    logprobs: Optional[OpenAICompatLogprobs],
+) -> Optional[List[TokenLogProbs]]:
+    if not logprobs:
+        return None
+    return [TokenLogProbs(logprobs_by_token=x) for x in logprobs.top_logprobs]
+
+
 def process_completion_response(
     response: OpenAICompatCompletionResponse, formatter: ChatFormat
 ) -> CompletionResponse:
```
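The converter keeps only `top_logprobs`, wrapping each per-token mapping in a `TokenLogProbs`. A standalone sketch, assuming `TokenLogProbs` is a pydantic model with a `logprobs_by_token` dict (consistent with the keyword used above); both classes here are stand-ins:

```python
from typing import Dict, List, Optional

from pydantic import BaseModel


class TokenLogProbs(BaseModel):
    # Stand-in for llama_stack.apis.inference.TokenLogProbs.
    logprobs_by_token: Dict[str, float]


class OpenAICompatLogprobs(BaseModel):
    # Only the field the converter touches.
    top_logprobs: Optional[List[Dict[str, float]]] = None


def convert_openai_completion_logprobs(
    logprobs: Optional[OpenAICompatLogprobs],
) -> Optional[List[TokenLogProbs]]:
    if not logprobs:
        return None
    return [TokenLogProbs(logprobs_by_token=x) for x in logprobs.top_logprobs]


lp = OpenAICompatLogprobs(top_logprobs=[{" Paris": -0.12, " Lyon": -2.73}])
out = convert_openai_completion_logprobs(lp)
print(out[0].logprobs_by_token[" Paris"])  # -0.12
```

One caveat visible in the code itself: a logprobs object whose `top_logprobs` is `None` would make the comprehension iterate over `None` and raise a `TypeError`, so callers are expected to populate `top_logprobs` whenever logprobs are present.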
Non-streaming completion responses attach the converted logprobs on every return path:

```diff
@@ -113,16 +133,19 @@ def process_completion_response(
         return CompletionResponse(
             stop_reason=StopReason.end_of_turn,
             content=choice.text[: -len("<|eot_id|>")],
+            logprobs=convert_openai_completion_logprobs(choice.logprobs),
         )
     # drop suffix <eom_id> if present and return stop reason as end of message
     if choice.text.endswith("<|eom_id|>"):
         return CompletionResponse(
             stop_reason=StopReason.end_of_message,
             content=choice.text[: -len("<|eom_id|>")],
+            logprobs=convert_openai_completion_logprobs(choice.logprobs),
         )
     return CompletionResponse(
         stop_reason=get_stop_reason(choice.finish_reason),
         content=choice.text,
+        logprobs=convert_openai_completion_logprobs(choice.logprobs),
     )
```
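A minimal sketch of the stop-token handling this hunk threads logprobs through: the terminal `<|eot_id|>` / `<|eom_id|>` markers are stripped from the content and determine the stop reason. Plain strings only, no llama-stack types:

```python
def strip_stop_token(text: str) -> tuple[str, str]:
    # Mirrors the three return paths of process_completion_response.
    if text.endswith("<|eot_id|>"):
        return text[: -len("<|eot_id|>")], "end_of_turn"
    if text.endswith("<|eom_id|>"):
        return text[: -len("<|eom_id|>")], "end_of_message"
    return text, "from_finish_reason"  # falls back to get_stop_reason


print(strip_stop_token("Paris<|eot_id|>"))  # ('Paris', 'end_of_turn')
```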
Streaming chunks carry logprobs as well:

```diff
@@ -165,6 +188,7 @@ async def process_completion_stream_response(
         yield CompletionResponseStreamChunk(
             delta=text,
             stop_reason=stop_reason,
+            logprobs=convert_openai_completion_logprobs(choice.logprobs),
         )
         if finish_reason:
             if finish_reason in ["stop", "eos", "eos_token"]:
```
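End to end, this is the kind of call the test-plan screenshot exercises: request completion logprobs through the stack and read them off the response. A hypothetical client-side sketch; the client construction, base URL, model id, and keyword shapes are assumptions, not taken from the PR:

```python
# Hypothetical usage against a running llama-stack distribution with the
# Fireworks provider; names and values are illustrative.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:5000")  # hypothetical URL
response = client.inference.completion(
    model_id="meta-llama/Llama-3.1-8B-Instruct",  # hypothetical model id
    content="The capital of France is",
    logprobs={"top_k": 3},  # must satisfy 0 < top_k < 5 per the new check
)
for tlp in response.logprobs or []:
    print(tlp.logprobs_by_token)  # top-3 candidate tokens with log-probs
```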