mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-06 02:32:40 +00:00)

add completion logprobs

commit 9c13a7b76b (parent 965644ce68)
2 changed files with 25 additions and 2 deletions
First changed file (Fireworks inference adapter):

@@ -168,7 +168,10 @@ class FireworksInferenceAdapter(
             yield chunk
 
     def _build_options(
-        self, sampling_params: Optional[SamplingParams], fmt: ResponseFormat
+        self,
+        sampling_params: Optional[SamplingParams],
+        fmt: ResponseFormat,
+        logprobs: Optional[LogProbConfig],
     ) -> dict:
         options = get_sampling_options(sampling_params)
         options.setdefault("max_tokens", 512)
@@ -187,6 +190,11 @@ class FireworksInferenceAdapter(
             else:
                 raise ValueError(f"Unknown response format {fmt.type}")
 
+        if logprobs and logprobs.top_k:
+            options["logprobs"] = logprobs.top_k
+            if options["logprobs"] <= 0 or options["logprobs"] >= 5:
+                raise ValueError("Required range: 0 < top_k < 5")
+
         return options
 
     async def chat_completion(
@@ -280,7 +288,9 @@ class FireworksInferenceAdapter(
             "model": request.model,
             **input_dict,
             "stream": request.stream,
-            **self._build_options(request.sampling_params, request.response_format),
+            **self._build_options(
+                request.sampling_params, request.response_format, request.logprobs
+            ),
         }
 
     async def embeddings(
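For orientation, the sketch below replays the new top_k handling outside of the adapter. LogProbConfig and build_logprobs_options here are simplified stand-ins introduced for illustration, not the real llama_stack types or methods; the point is only to show which values pass the 0 < top_k < 5 check and what ends up in the options dict that the adapter merges into the Fireworks request via **self._build_options(...).

from dataclasses import dataclass
from typing import Optional


@dataclass
class LogProbConfig:
    # Simplified stand-in for llama_stack's LogProbConfig; only top_k is used here.
    top_k: Optional[int] = None


def build_logprobs_options(logprobs: Optional[LogProbConfig]) -> dict:
    # Hypothetical helper mirroring the logprobs branch added to
    # FireworksInferenceAdapter._build_options in this commit.
    options: dict = {}
    if logprobs and logprobs.top_k:
        options["logprobs"] = logprobs.top_k
        # The commit restricts top_k to the open interval (0, 5), i.e. 1-4.
        if options["logprobs"] <= 0 or options["logprobs"] >= 5:
            raise ValueError("Required range: 0 < top_k < 5")
    return options


print(build_logprobs_options(LogProbConfig(top_k=3)))  # -> {'logprobs': 3}
print(build_logprobs_options(None))                    # -> {}
# build_logprobs_options(LogProbConfig(top_k=7)) raises ValueError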
Second changed file (shared OpenAI-compat inference utilities):

@@ -15,6 +15,7 @@ from llama_models.llama3.api.datatypes import (
     TopKSamplingStrategy,
     TopPSamplingStrategy,
 )
+from openai.types.completion_choice import Logprobs as OpenAILogprobs
 from pydantic import BaseModel
 
 from llama_stack.apis.common.content_types import (
@@ -34,6 +35,7 @@ from llama_stack.apis.inference import (
     CompletionResponse,
     CompletionResponseStreamChunk,
     Message,
+    TokenLogProbs,
 )
 
 from llama_stack.providers.utils.inference.prompt_adapter import (
@@ -104,6 +106,14 @@ def get_stop_reason(finish_reason: str) -> StopReason:
     return StopReason.out_of_tokens
 
 
+def convert_openai_completion_logprobs(
+    logprobs: Optional[OpenAILogprobs],
+) -> Optional[List[TokenLogProbs]]:
+    if not logprobs:
+        return None
+    return [TokenLogProbs(logprobs_by_token=x) for x in logprobs.top_logprobs]
+
+
 def process_completion_response(
     response: OpenAICompatCompletionResponse, formatter: ChatFormat
 ) -> CompletionResponse:
@@ -113,16 +123,19 @@ def process_completion_response(
         return CompletionResponse(
             stop_reason=StopReason.end_of_turn,
             content=choice.text[: -len("<|eot_id|>")],
+            logprobs=convert_openai_completion_logprobs(choice.logprobs),
         )
     # drop suffix <eom_id> if present and return stop reason as end of message
     if choice.text.endswith("<|eom_id|>"):
         return CompletionResponse(
             stop_reason=StopReason.end_of_message,
             content=choice.text[: -len("<|eom_id|>")],
+            logprobs=convert_openai_completion_logprobs(choice.logprobs),
         )
     return CompletionResponse(
         stop_reason=get_stop_reason(choice.finish_reason),
         content=choice.text,
+        logprobs=convert_openai_completion_logprobs(choice.logprobs),
     )
 
 
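To make the new conversion concrete, here is a hedged, self-contained sketch that uses simplified stand-ins for the OpenAI Logprobs payload and TokenLogProbs (the real types live in the openai and llama_stack packages and carry more fields); it shows how each per-token dict in top_logprobs becomes one TokenLogProbs entry on the returned CompletionResponse.

from typing import Dict, List, Optional

from pydantic import BaseModel


class TokenLogProbs(BaseModel):
    # Simplified stand-in for llama_stack.apis.inference.TokenLogProbs.
    logprobs_by_token: Dict[str, float]


class OpenAILogprobs(BaseModel):
    # Simplified stand-in for openai.types.completion_choice.Logprobs; only
    # top_logprobs matters for this conversion.
    top_logprobs: Optional[List[Dict[str, float]]] = None


def convert_openai_completion_logprobs(
    logprobs: Optional[OpenAILogprobs],
) -> Optional[List[TokenLogProbs]]:
    # One TokenLogProbs entry per generated token, keyed by candidate token text.
    # Like the code in the diff, this assumes top_logprobs is populated whenever
    # a logprobs object is present.
    if not logprobs:
        return None
    return [TokenLogProbs(logprobs_by_token=x) for x in logprobs.top_logprobs]


sample = OpenAILogprobs(top_logprobs=[{"Hello": -0.1, "Hi": -2.3}, {"!": -0.4}])
for entry in convert_openai_completion_logprobs(sample):
    print(entry.logprobs_by_token)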