Update Strategy in SamplingParams to be a union

Hardik Shah 2025-01-14 15:56:02 -08:00 committed by Ashwin Bharambe
parent 300e6e2702
commit dea575c994
28 changed files with 600 additions and 377 deletions
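For reference, the shape this commit gives SamplingParams: the flat temperature/top_p/top_k fields move behind a single strategy field that is a tagged union of per-strategy types. The sketch below is an approximation assuming Pydantic discriminated unions; the class names follow the commit title, but the exact fields and defaults are assumptions, not the committed code.

    # Sketch only: assumed shape of the new Strategy union on SamplingParams.
    from typing import Annotated, Literal, Optional, Union

    from pydantic import BaseModel, Field


    class GreedySamplingStrategy(BaseModel):
        type: Literal["greedy"] = "greedy"


    class TopPSamplingStrategy(BaseModel):
        type: Literal["top_p"] = "top_p"
        temperature: Optional[float] = None
        top_p: Optional[float] = 0.95


    class TopKSamplingStrategy(BaseModel):
        type: Literal["top_k"] = "top_k"
        top_k: int


    # The "type" field discriminates which strategy a payload parses into.
    SamplingStrategy = Annotated[
        Union[GreedySamplingStrategy, TopPSamplingStrategy, TopKSamplingStrategy],
        Field(discriminator="type"),
    ]


    class SamplingParams(BaseModel):
        strategy: SamplingStrategy = Field(default_factory=GreedySamplingStrategy)
        max_tokens: Optional[int] = 0
        repetition_penalty: Optional[float] = 1.0

Validation then dispatches on the tag, e.g. SamplingParams(strategy={"type": "top_p", "temperature": 0.7}) parses into a TopPSamplingStrategy.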


@@ -36,6 +36,7 @@ from llama_stack.apis.inference import (
 from llama_stack.apis.models import Model
 from llama_stack.providers.datatypes import ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.openai_compat import (
+    get_sampling_options,
     OpenAICompatCompletionChoice,
     OpenAICompatCompletionResponse,
     process_chat_completion_response,
@@ -126,21 +127,12 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         if sampling_params is None:
             return VLLMSamplingParams(max_tokens=self.config.max_tokens)
-        # TODO convert what I saw in my first test ... but surely there's more to do here
-        kwargs = {
-            "temperature": sampling_params.temperature,
-            "max_tokens": self.config.max_tokens,
-        }
-        if sampling_params.top_k:
-            kwargs["top_k"] = sampling_params.top_k
-        if sampling_params.top_p:
-            kwargs["top_p"] = sampling_params.top_p
-        if sampling_params.max_tokens:
-            kwargs["max_tokens"] = sampling_params.max_tokens
-        if sampling_params.repetition_penalty > 0:
-            kwargs["repetition_penalty"] = sampling_params.repetition_penalty
-        return VLLMSamplingParams(**kwargs)
+        options = get_sampling_options(sampling_params)
+        if "repeat_penalty" in options:
+            options["repetition_penalty"] = options["repeat_penalty"]
+            del options["repeat_penalty"]
+        return VLLMSamplingParams(**options)
 
     async def unregister_model(self, model_id: str) -> None:
         pass
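The adapter now delegates field mapping to the shared get_sampling_options helper and only renames the llama.cpp-style repeat_penalty key to vLLM's repetition_penalty. Building on the union sketch above, here is a rough guess at what such a helper would do; this is an assumption about its behavior, not the helper's actual source.

    # Sketch only: assumed behavior of get_sampling_options, flattening the
    # strategy union into provider keyword arguments.
    def get_sampling_options(params: SamplingParams) -> dict:
        options = {}
        if isinstance(params.strategy, TopPSamplingStrategy):
            if params.strategy.temperature is not None:
                options["temperature"] = params.strategy.temperature
            options["top_p"] = params.strategy.top_p
        elif isinstance(params.strategy, TopKSamplingStrategy):
            options["top_k"] = params.strategy.top_k
        if params.max_tokens:
            options["max_tokens"] = params.max_tokens
        if params.repetition_penalty:
            # llama.cpp-style key; the vLLM adapter in the diff above
            # renames this to "repetition_penalty".
            options["repeat_penalty"] = params.repetition_penalty
        return options

With a shared helper like this, each provider adapter only patches the keys its backend spells differently, instead of re-implementing the whole SamplingParams-to-kwargs mapping as the removed code did.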