completion() for tgi (#295)

This commit is contained in:
Dinesh Yeduguru 2024-10-24 16:02:41 -07:00 committed by GitHub
parent cb84034567
commit 3e1c3fdb3f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 173 additions and 35 deletions

View file

@@ -110,7 +110,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
return await self._nonstream_completion(request)
def _get_params_for_completion(self, request: CompletionRequest) -> dict:
sampling_options = get_sampling_options(request)
sampling_options = get_sampling_options(request.sampling_params)
# This is needed since the Ollama API expects num_predict to be set
# for early truncation instead of max_tokens.
if sampling_options["max_tokens"] is not None:
@@ -187,7 +187,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
return {
"model": OLLAMA_SUPPORTED_MODELS[request.model],
"prompt": chat_completion_request_to_prompt(request, self.formatter),
"options": get_sampling_options(request),
"options": get_sampling_options(request.sampling_params),
"raw": True,
"stream": request.stream,
}