Fireworks: add completion logprobs support to the adapter (#778)

# What does this PR do?

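- Add completion logprobs support to the Fireworks inference adapter. When `logprobs.top_k` is set, it is forwarded as the `logprobs` option and must satisfy `0 < top_k < 5`; otherwise a `ValueError` is raised.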

## Test Plan

<img width="849" alt="image"
src="https://github.com/user-attachments/assets/5aa1f27f-02a6-422c-8478-94dd1e345342"
/>


## Sources

Please link relevant resources if necessary.


## Before submitting

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Ran pre-commit to handle lint / formatting issues.
- [ ] Read the Pull Request section of the [contributor
guidelines](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md).
- [ ] Updated relevant documentation.
- [ ] Wrote necessary unit or integration tests.
This commit is contained in:
Xi Yan 2025-01-16 10:37:07 -08:00 committed by GitHub
parent 05f6b44da7
commit e239280932
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 37 additions and 3 deletions

View file

@ -168,7 +168,10 @@ class FireworksInferenceAdapter(
yield chunk
def _build_options(
self, sampling_params: Optional[SamplingParams], fmt: ResponseFormat
self,
sampling_params: Optional[SamplingParams],
fmt: ResponseFormat,
logprobs: Optional[LogProbConfig],
) -> dict:
options = get_sampling_options(sampling_params)
options.setdefault("max_tokens", 512)
@ -187,6 +190,11 @@ class FireworksInferenceAdapter(
else:
raise ValueError(f"Unknown response format {fmt.type}")
if logprobs and logprobs.top_k:
options["logprobs"] = logprobs.top_k
if options["logprobs"] <= 0 or options["logprobs"] >= 5:
raise ValueError("Required range: 0 < top_k < 5")
return options
async def chat_completion(
@ -280,7 +288,9 @@ class FireworksInferenceAdapter(
"model": request.model,
**input_dict,
"stream": request.stream,
**self._build_options(request.sampling_params, request.response_format),
**self._build_options(
request.sampling_params, request.response_format, request.logprobs
),
}
async def embeddings(