forked from phoenix-oss/llama-stack-mirror
This is just like `local` using `meta-reference` for everything, except it uses `vllm` for inference.

Docker works, but so far `conda` is a bit easier to use with the vllm provider. The default container base image does not include all the libraries needed for every vllm feature; more CUDA dependencies are required. I started changing the base image used in this template, but that also required changes to the Dockerfile, so it was getting too involved to include in this first PR.

Working so far:

* `python -m llama_stack.apis.inference.client localhost 5000 --model Llama3.2-1B-Instruct --stream True`
* `python -m llama_stack.apis.inference.client localhost 5000 --model Llama3.2-1B-Instruct --stream False`

Example:

```
$ python -m llama_stack.apis.inference.client localhost 5000 --model Llama3.2-1B-Instruct --stream False
User>hello world, write me a 2 sentence poem about the moon
Assistant> The moon glows bright in the midnight sky
A beacon of light,
```

I have only tested these models:

* `Llama3.1-8B-Instruct` - across 4 GPUs (tensor_parallel_size = 4)
* `Llama3.2-1B-Instruct` - on a single GPU (tensor_parallel_size = 1)
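The two manual checks under "Working so far" differ only in the `--stream` value, so they can be assembled programmatically when re-running them. This is a minimal sketch, assuming nothing beyond what the commit message shows: the module path, the positional host and port, and the `--model`/`--stream` flags. `client_command` is a hypothetical helper for illustration, not part of llama_stack.

```python
import shlex


def client_command(
    model: str, stream: bool, host: str = "localhost", port: int = 5000
) -> list[str]:
    """Hypothetical helper: build the argv for the inference client test."""
    return [
        "python", "-m", "llama_stack.apis.inference.client",
        host, str(port),
        "--model", model,
        "--stream", str(stream),  # str(True) -> "True", str(False) -> "False"
    ]


# Print both tested invocations as shell-ready command lines.
for stream in (True, False):
    print(shlex.join(client_command("Llama3.2-1B-Instruct", stream)))
```

The list form can be passed straight to `subprocess.run` if you want to script the checks instead of copying the printed lines.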
35 lines
1.1 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field, field_validator

from llama_stack.providers.utils.inference import supported_inference_models


@json_schema_type
class VLLMConfig(BaseModel):
    """Configuration for the vLLM inference provider."""

    model: str = Field(
        default="Llama3.1-8B-Instruct",
        description="Model descriptor from `llama model list`",
    )
    tensor_parallel_size: int = Field(
        default=1,
        description="Number of tensor parallel replicas (number of GPUs to use).",
    )

    @field_validator("model")
    @classmethod
    def validate_model(cls, model: str) -> str:
        permitted_models = supported_inference_models()
        if model not in permitted_models:
            model_list = "\n\t".join(permitted_models)
            raise ValueError(
                f"Unknown model: `{model}`. Choose from [\n\t{model_list}\n]"
            )
        return model
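The model check in `VLLMConfig` can be exercised without installing llama_stack or pydantic. The sketch below is a stdlib stand-in, not the provider's actual class: `VLLMConfigSketch` and `PERMITTED_MODELS` are hypothetical names, and the two-entry list substitutes for whatever `supported_inference_models()` returns at runtime. It keeps the same field names, defaults, and error message, and builds the two configurations the commit message reports as tested.

```python
from dataclasses import dataclass

# Hypothetical stand-in for supported_inference_models(); the real list
# is produced by llama_stack and contains more descriptors.
PERMITTED_MODELS = ["Llama3.1-8B-Instruct", "Llama3.2-1B-Instruct"]


@dataclass
class VLLMConfigSketch:
    """Stdlib mirror of VLLMConfig: same fields, defaults, and model check."""

    model: str = "Llama3.1-8B-Instruct"
    tensor_parallel_size: int = 1  # number of GPUs to shard the model across

    def __post_init__(self) -> None:
        # Same rule as validate_model: reject descriptors not in the list.
        if self.model not in PERMITTED_MODELS:
            model_list = "\n\t".join(PERMITTED_MODELS)
            raise ValueError(
                f"Unknown model: `{self.model}`. Choose from [\n\t{model_list}\n]"
            )


# The two setups reported as tested.
eight_b = VLLMConfigSketch("Llama3.1-8B-Instruct", tensor_parallel_size=4)
one_b = VLLMConfigSketch("Llama3.2-1B-Instruct", tensor_parallel_size=1)

# "Llama9-Unknown" is a deliberately invalid, made-up descriptor.
try:
    VLLMConfigSketch("Llama9-Unknown")
except ValueError as err:
    print(err)
```

One behavioral difference: `__post_init__` also checks the default value, whereas pydantic v2 skips field validators on defaults unless `validate_default=True` is set.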