diff --git a/llama_stack/providers/adapters/inference/vllm/vllm.py b/llama_stack/providers/adapters/inference/vllm/vllm.py
index ce8ba223d..9df94d94d 100644
--- a/llama_stack/providers/adapters/inference/vllm/vllm.py
+++ b/llama_stack/providers/adapters/inference/vllm/vllm.py
@@ -15,12 +15,16 @@ from llama_models.sku_list import resolve_model
 from openai import OpenAI
 
 from llama_stack.apis.inference import *  # noqa: F403
-from llama_stack.providers.utils.inference.prepare_messages import prepare_messages
+from llama_stack.providers.utils.inference.augment_messages import augment_messages_for_tools
 
 from .config import VLLMImplConfig
 
-# TODO
-VLLM_SUPPORTED_MODELS = {}
+# Reference: https://docs.vllm.ai/en/latest/models/supported_models.html
+VLLM_SUPPORTED_MODELS = {
+    "Llama3.1-8B-Instruct": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+    "Llama3.1-70B-Instruct": "meta-llama/Meta-Llama-3.1-70B-Instruct",
+    "Llama3.1-405B-Instruct": "meta-llama/Meta-Llama-3.1-405B-Instruct",
+}
 
 
 class VLLMInferenceAdapter(Inference):
@@ -70,7 +74,10 @@ class VLLMInferenceAdapter(Inference):
 
     def get_vllm_chat_options(self, request: ChatCompletionRequest) -> dict:
         options = {}
-        # TODO
+        if request.sampling_params is not None:
+            for attr in {"temperature", "top_p", "top_k", "max_tokens"}:
+                if getattr(request.sampling_params, attr):
+                    options[attr] = getattr(request.sampling_params, attr)
         return options
 
     async def chat_completion(
@@ -99,7 +106,7 @@ class VLLMInferenceAdapter(Inference):
         # accumulate sampling params and other options to pass to vLLM
         options = self.get_vllm_chat_options(request)
         vllm_model = self.resolve_vllm_model(request.model)
-        messages = prepare_messages(request)
+        messages = augment_messages_for_tools(request)
         model_input = self.formatter.encode_dialog_prompt(messages)
         input_tokens = len(model_input.tokens)
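
For reference, the new option-extraction logic can be exercised in isolation. The following is a minimal sketch, not the adapter's actual code path: `FakeSamplingParams` and `FakeRequest` are hypothetical stand-ins for llama_stack's `SamplingParams` and `ChatCompletionRequest`, used only to show that just the sampling attributes that are set (truthy) get forwarded as vLLM options.

```python
# Minimal, self-contained sketch of the option-extraction behavior in the diff.
# FakeSamplingParams / FakeRequest are hypothetical stand-ins for
# llama_stack's SamplingParams and ChatCompletionRequest.
from dataclasses import dataclass
from typing import Optional


@dataclass
class FakeSamplingParams:
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    top_k: Optional[int] = None
    max_tokens: Optional[int] = None


@dataclass
class FakeRequest:
    sampling_params: Optional[FakeSamplingParams] = None


def get_vllm_chat_options(request: FakeRequest) -> dict:
    # Mirrors the diff: copy only the sampling attributes that are set
    # (truthy), so unset values are not forwarded to vLLM.
    options = {}
    if request.sampling_params is not None:
        for attr in ("temperature", "top_p", "top_k", "max_tokens"):
            if getattr(request.sampling_params, attr):
                options[attr] = getattr(request.sampling_params, attr)
    return options


if __name__ == "__main__":
    req = FakeRequest(FakeSamplingParams(temperature=0.7, max_tokens=128))
    print(get_vllm_chat_options(req))  # {'temperature': 0.7, 'max_tokens': 128}
```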