diff --git a/llama_stack/providers/impls/vllm/config.py b/llama_stack/providers/impls/vllm/config.py
index df2526f2e..a7469ebde 100644
--- a/llama_stack/providers/impls/vllm/config.py
+++ b/llama_stack/providers/impls/vllm/config.py
@@ -15,13 +15,24 @@ class VLLMConfig(BaseModel):
     """Configuration for the vLLM inference provider."""
 
     model: str = Field(
-        default="Llama3.1-8B-Instruct",
+        default="Llama3.2-3B-Instruct",
         description="Model descriptor from `llama model list`",
     )
     tensor_parallel_size: int = Field(
         default=1,
         description="Number of tensor parallel replicas (number of GPUs to use).",
     )
+    max_tokens: int = Field(
+        default=4096,
+        description="Maximum number of tokens to generate.",
+    )
+    enforce_eager: bool = Field(
+        default=False,
+        description="Whether to use eager mode for inference (otherwise cuda graphs are used).",
+    )
+    gpu_memory_utilization: float = Field(
+        default=0.3,
+    )
 
     @field_validator("model")
     @classmethod
diff --git a/llama_stack/providers/impls/vllm/vllm.py b/llama_stack/providers/impls/vllm/vllm.py
index ad3ad8fb7..cf5b0572b 100644
--- a/llama_stack/providers/impls/vllm/vllm.py
+++ b/llama_stack/providers/impls/vllm/vllm.py
@@ -7,11 +7,12 @@
 import logging
 import os
 import uuid
-from typing import Any, AsyncGenerator
+from typing import AsyncGenerator, Optional
 
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.datatypes import *  # noqa: F403
 from llama_models.llama3.api.tokenizer import Tokenizer
+from llama_models.sku_list import resolve_model
 
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
@@ -19,7 +20,7 @@ from vllm.sampling_params import SamplingParams as VLLMSamplingParams
 
 from llama_stack.apis.inference import *  # noqa: F403
 
-from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
+from llama_stack.providers.datatypes import ModelDef, ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.openai_compat import (
     OpenAICompatCompletionChoice,
     OpenAICompatCompletionResponse,
@@ -40,74 +41,15 @@ def _random_uuid() -> str:
     return str(uuid.uuid4().hex)
 
 
-def _vllm_sampling_params(sampling_params: Any) -> VLLMSamplingParams:
-    """Convert sampling params to vLLM sampling params."""
-    if sampling_params is None:
-        return VLLMSamplingParams()
-
-    # TODO convert what I saw in my first test ... but surely there's more to do here
-    kwargs = {
-        "temperature": sampling_params.temperature,
-    }
-    if sampling_params.top_k >= 1:
-        kwargs["top_k"] = sampling_params.top_k
-    if sampling_params.top_p:
-        kwargs["top_p"] = sampling_params.top_p
-    if sampling_params.max_tokens >= 1:
-        kwargs["max_tokens"] = sampling_params.max_tokens
-    if sampling_params.repetition_penalty > 0:
-        kwargs["repetition_penalty"] = sampling_params.repetition_penalty
-
-    return VLLMSamplingParams(**kwargs)
-
-
-class VLLMInferenceImpl(ModelRegistryHelper, Inference):
+class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
     """Inference implementation for vLLM."""
 
-    HF_MODEL_MAPPINGS = {
-        # TODO: seems like we should be able to build this table dynamically ...
- "Llama3.1-8B": "meta-llama/Llama-3.1-8B", - "Llama3.1-70B": "meta-llama/Llama-3.1-70B", - "Llama3.1-405B:bf16-mp8": "meta-llama/Llama-3.1-405B", - "Llama3.1-405B": "meta-llama/Llama-3.1-405B-FP8", - "Llama3.1-405B:bf16-mp16": "meta-llama/Llama-3.1-405B", - "Llama3.1-8B-Instruct": "meta-llama/Llama-3.1-8B-Instruct", - "Llama3.1-70B-Instruct": "meta-llama/Llama-3.1-70B-Instruct", - "Llama3.1-405B-Instruct:bf16-mp8": "meta-llama/Llama-3.1-405B-Instruct", - "Llama3.1-405B-Instruct": "meta-llama/Llama-3.1-405B-Instruct-FP8", - "Llama3.1-405B-Instruct:bf16-mp16": "meta-llama/Llama-3.1-405B-Instruct", - "Llama3.2-1B": "meta-llama/Llama-3.2-1B", - "Llama3.2-3B": "meta-llama/Llama-3.2-3B", - "Llama3.2-11B-Vision": "meta-llama/Llama-3.2-11B-Vision", - "Llama3.2-90B-Vision": "meta-llama/Llama-3.2-90B-Vision", - "Llama3.2-1B-Instruct": "meta-llama/Llama-3.2-1B-Instruct", - "Llama3.2-3B-Instruct": "meta-llama/Llama-3.2-3B-Instruct", - "Llama3.2-11B-Vision-Instruct": "meta-llama/Llama-3.2-11B-Vision-Instruct", - "Llama3.2-90B-Vision-Instruct": "meta-llama/Llama-3.2-90B-Vision-Instruct", - "Llama-Guard-3-11B-Vision": "meta-llama/Llama-Guard-3-11B-Vision", - "Llama-Guard-3-1B:int4-mp1": "meta-llama/Llama-Guard-3-1B-INT4", - "Llama-Guard-3-1B": "meta-llama/Llama-Guard-3-1B", - "Llama-Guard-3-8B": "meta-llama/Llama-Guard-3-8B", - "Llama-Guard-3-8B:int8-mp1": "meta-llama/Llama-Guard-3-8B-INT8", - "Prompt-Guard-86M": "meta-llama/Prompt-Guard-86M", - "Llama-Guard-2-8B": "meta-llama/Llama-Guard-2-8B", - } - def __init__(self, config: VLLMConfig): - Inference.__init__(self) - ModelRegistryHelper.__init__( - self, - stack_to_provider_models_map=self.HF_MODEL_MAPPINGS, - ) self.config = config self.engine = None - - tokenizer = Tokenizer.get_instance() - self.formatter = ChatFormat(tokenizer) + self.formatter = ChatFormat(Tokenizer.get_instance()) async def initialize(self): - """Initialize the vLLM inference adapter.""" - log.info("Initializing vLLM inference adapter") # Disable usage stats reporting. This would be a surprising thing for most @@ -116,15 +58,22 @@ class VLLMInferenceImpl(ModelRegistryHelper, Inference): if "VLLM_NO_USAGE_STATS" not in os.environ: os.environ["VLLM_NO_USAGE_STATS"] = "1" - hf_model = self.HF_MODEL_MAPPINGS.get(self.config.model) + model = resolve_model(self.config.model) + if model is None: + raise ValueError(f"Unknown model {self.config.model}") + + if model.huggingface_repo is None: + raise ValueError(f"Model {self.config.model} needs a huggingface repo") # TODO -- there are a ton of options supported here ... 
-        engine_args = AsyncEngineArgs()
-        engine_args.model = hf_model
-        # We will need a new config item for this in the future if model support is more broad
-        # than it is today (llama only)
-        engine_args.tokenizer = hf_model
-        engine_args.tensor_parallel_size = self.config.tensor_parallel_size
+        engine_args = AsyncEngineArgs(
+            model=model.huggingface_repo,
+            tokenizer=model.huggingface_repo,
+            tensor_parallel_size=self.config.tensor_parallel_size,
+            enforce_eager=self.config.enforce_eager,
+            gpu_memory_utilization=self.config.gpu_memory_utilization,
+            guided_decoding_backend="lm-format-enforcer",
+        )
 
         self.engine = AsyncLLMEngine.from_engine_args(engine_args)
 
@@ -134,13 +83,47 @@ class VLLMInferenceImpl(ModelRegistryHelper, Inference):
         if self.engine:
             self.engine.shutdown_background_loop()
 
+    async def register_model(self, model: ModelDef) -> None:
+        raise ValueError(
+            "You cannot dynamically add a model to a running vllm instance"
+        )
+
+    async def list_models(self) -> List[ModelDef]:
+        return [
+            ModelDef(
+                identifier=self.config.model,
+                llama_model=self.config.model,
+            )
+        ]
+
+    def _sampling_params(self, sampling_params: SamplingParams) -> VLLMSamplingParams:
+        if sampling_params is None:
+            return VLLMSamplingParams(max_tokens=self.config.max_tokens)
+
+        # TODO convert what I saw in my first test ... but surely there's more to do here
+        kwargs = {
+            "temperature": sampling_params.temperature,
+            "max_tokens": self.config.max_tokens,
+        }
+        if sampling_params.top_k:
+            kwargs["top_k"] = sampling_params.top_k
+        if sampling_params.top_p:
+            kwargs["top_p"] = sampling_params.top_p
+        if sampling_params.max_tokens:
+            kwargs["max_tokens"] = sampling_params.max_tokens
+        if sampling_params.repetition_penalty > 0:
+            kwargs["repetition_penalty"] = sampling_params.repetition_penalty
+
+        return VLLMSamplingParams(**kwargs)
+
     async def completion(
         self,
         model: str,
         content: InterleavedTextMedia,
-        sampling_params: Any | None = ...,
-        stream: bool | None = False,
-        logprobs: LogProbConfig | None = None,
+        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,
+        stream: Optional[bool] = False,
+        logprobs: Optional[LogProbConfig] = None,
     ) -> CompletionResponse | CompletionResponseStreamChunk:
         log.info("vLLM completion")
         messages = [UserMessage(content=content)]
@@ -155,13 +138,14 @@ class VLLMInferenceImpl(ModelRegistryHelper, Inference):
     async def chat_completion(
         self,
         model: str,
-        messages: list[Message],
-        sampling_params: Any | None = ...,
-        tools: list[ToolDefinition] | None = ...,
-        tool_choice: ToolChoice | None = ...,
-        tool_prompt_format: ToolPromptFormat | None = ...,
-        stream: bool | None = False,
-        logprobs: LogProbConfig | None = None,
+        messages: List[Message],
+        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        tools: Optional[List[ToolDefinition]] = None,
+        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
+        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        response_format: Optional[ResponseFormat] = None,
+        stream: Optional[bool] = False,
+        logprobs: Optional[LogProbConfig] = None,
     ) -> ChatCompletionResponse | ChatCompletionResponseStreamChunk:
         log.info("vLLM chat completion")
 
@@ -182,7 +166,7 @@ class VLLMInferenceImpl(ModelRegistryHelper, Inference):
         request_id = _random_uuid()
 
         prompt = chat_completion_request_to_prompt(request, self.formatter)
-        vllm_sampling_params = _vllm_sampling_params(request.sampling_params)
+        vllm_sampling_params = self._sampling_params(request.sampling_params)
         results_generator = self.engine.generate(
             prompt, vllm_sampling_params, request_id
         )
@@ -213,14 +197,19 @@ class VLLMInferenceImpl(ModelRegistryHelper, Inference):
         self, request: ChatCompletionRequest, results_generator: AsyncGenerator
     ) -> AsyncGenerator:
         async def _generate_and_convert_to_openai_compat():
+            cur = []
             async for chunk in results_generator:
                 if not chunk.outputs:
                     log.warning("Empty chunk received")
                     continue
 
-                text = "".join([output.text for output in chunk.outputs])
+                output = chunk.outputs[-1]
+
+                new_tokens = output.token_ids[len(cur) :]
+                text = self.formatter.tokenizer.decode(new_tokens)
+                cur.extend(new_tokens)
                 choice = OpenAICompatCompletionChoice(
-                    finish_reason=chunk.outputs[-1].stop_reason,
+                    finish_reason=output.finish_reason,
                     text=text,
                 )
                 yield OpenAICompatCompletionResponse(
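
Note for reviewers: the new _sampling_params() helper above seeds max_tokens from the provider config and lets a truthy per-request value override it. The standalone sketch below mirrors that conversion logic so it can be run outside llama-stack; FakeSamplingParams and to_vllm_kwargs are stand-in names invented for illustration (the real method builds a vllm.sampling_params.SamplingParams from the same kwargs), and the dataclass defaults are assumptions, not taken from this diff.

# Standalone sketch of the _sampling_params() conversion introduced in this diff.
from dataclasses import dataclass
from typing import Optional


@dataclass
class FakeSamplingParams:
    # Stand-in for llama-stack's SamplingParams; field defaults are illustrative only.
    temperature: float = 1.0
    top_k: int = 0
    top_p: float = 0.0
    max_tokens: int = 0
    repetition_penalty: float = 0.0


def to_vllm_kwargs(sp: Optional[FakeSamplingParams], config_max_tokens: int = 4096) -> dict:
    # config.max_tokens is the default; a truthy per-request max_tokens overrides it.
    if sp is None:
        return {"max_tokens": config_max_tokens}
    kwargs = {"temperature": sp.temperature, "max_tokens": config_max_tokens}
    if sp.top_k:
        kwargs["top_k"] = sp.top_k
    if sp.top_p:
        kwargs["top_p"] = sp.top_p
    if sp.max_tokens:
        kwargs["max_tokens"] = sp.max_tokens
    if sp.repetition_penalty > 0:
        kwargs["repetition_penalty"] = sp.repetition_penalty
    return kwargs


print(to_vllm_kwargs(FakeSamplingParams(temperature=0.7, top_p=0.9)))
# -> {'temperature': 0.7, 'max_tokens': 4096, 'top_p': 0.9}
print(to_vllm_kwargs(FakeSamplingParams(max_tokens=256)))
# -> {'temperature': 1.0, 'max_tokens': 256}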