refactor get_max_tokens and build_options

Dinesh Yeduguru 2024-10-23 19:11:04 -07:00
parent 5965ef3979
commit 4a073fcee5
7 changed files with 33 additions and 38 deletions

@@ -110,7 +110,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         return await self._nonstream_completion(request)
 
     def _get_params_for_completion(self, request: CompletionRequest) -> dict:
-        sampling_options = get_sampling_options(request)
+        sampling_options = get_sampling_options(request.sampling_params)
         # This is needed since the Ollama API expects num_predict to be set
         # for early truncation instead of max_tokens.
        if sampling_options["max_tokens"] is not None:
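
The comment in this hunk explains why the adapter post-processes the options: Ollama's generate API reads num_predict rather than max_tokens for early truncation. A minimal sketch of that remapping, assuming get_sampling_options returns a plain dict with an optional "max_tokens" key (the helper name below is hypothetical, not from the diff):

def _remap_max_tokens(options: dict) -> dict:
    # Ollama expects "num_predict", not "max_tokens", so move the
    # value across before the options dict is sent to the API.
    if options.get("max_tokens") is not None:
        options["num_predict"] = options.pop("max_tokens")
    return options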
@@ -187,7 +187,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         return {
             "model": OLLAMA_SUPPORTED_MODELS[request.model],
             "prompt": chat_completion_request_to_prompt(request, self.formatter),
-            "options": get_sampling_options(request),
+            "options": get_sampling_options(request.sampling_params),
             "raw": True,
             "stream": request.stream,
         }
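
Both hunks make the same change: the helper now takes only the request's sampling parameters instead of the whole request object, so it no longer depends on the request type. A sketch of what the refactored helper might look like after this commit; the field names and the None handling are assumptions, not taken from the actual llama-stack utilities:

def get_sampling_options(params) -> dict:
    # `params` is the request's SamplingParams object (import omitted);
    # forward only the sampling fields the caller actually set.
    options = {}
    if params is None:
        return options
    for attr in ("temperature", "top_p", "top_k", "max_tokens", "repetition_penalty"):
        value = getattr(params, attr, None)
        if value is not None:
            options[attr] = value
    return options

Passing request.sampling_params keeps the helper reusable across both completion and chat-completion code paths, since both request types carry the same sampling_params field.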