Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-01 16:24:44 +00:00)

Commit 1cb9d3bca2 (parent 31a15332c4): second try

11 changed files with 237 additions and 64 deletions
@@ -273,6 +273,7 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper, Models
         response_format: ResponseFormat | None = None,
         stream: bool | None = False,
         logprobs: LogProbConfig | None = None,
+        suffix: str | None = None,
     ) -> CompletionResponse | AsyncIterator[CompletionResponseStreamChunk]:
         if sampling_params is None:
             sampling_params = SamplingParams()

@@ -293,6 +294,7 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper, Models
                 response_format=response_format,
                 stream=stream,
                 logprobs=logprobs,
+                suffix=suffix,
             ),
             n=1,
         )
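For orientation, a minimal usage sketch of the adapter change above. It assumes an already-constructed NVIDIAInferenceAdapter instance named `adapter`; the method name and the leading `model_id`/`content` arguments follow the usual Inference.completion signature rather than this hunk, so treat them (and the model id and prompt) as assumptions, not part of the commit.

async def demo(adapter) -> None:
    # `adapter` is assumed to be a configured NVIDIAInferenceAdapter instance.
    response = await adapter.completion(
        model_id="meta/llama-3.1-8b-instruct",  # hypothetical model id
        content="def add(a, b):",
        suffix="\n    return c",  # the newly accepted optional argument
        stream=False,
    )
    print(response.content)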

@@ -155,7 +155,8 @@ def convert_completion_request(
     if request.logprobs:
         payload.update(logprobs=request.logprobs.top_k)

+    if request.suffix:
+        payload.update(suffix=request.suffix)
     if request.sampling_params:
         nvext.update(repetition_penalty=request.sampling_params.repetition_penalty)
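To see the converter change in isolation, here is a small self-contained sketch of the same conditional payload update. `SimpleRequest` and `build_payload` are hypothetical stand-ins; the real convert_completion_request assembles many more fields than shown here.

from dataclasses import dataclass

@dataclass
class SimpleRequest:
    prompt: str
    suffix: str | None = None

def build_payload(request: SimpleRequest) -> dict:
    payload: dict = {"prompt": request.prompt}
    # Same pattern as the hunk above: include `suffix` only when it is set.
    if request.suffix:
        payload.update(suffix=request.suffix)
    return payload

print(build_payload(SimpleRequest(prompt="def add(a, b):", suffix="\n    return c")))
# -> {'prompt': 'def add(a, b):', 'suffix': '\n    return c'}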