diff --git a/llama_stack/providers/adapters/inference/vllm/__init__.py b/llama_stack/providers/adapters/inference/vllm/__init__.py index f4588a307..78222d7d9 100644 --- a/llama_stack/providers/adapters/inference/vllm/__init__.py +++ b/llama_stack/providers/adapters/inference/vllm/__init__.py @@ -4,12 +4,15 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from .config import VLLMImplConfig -from .vllm import VLLMInferenceAdapter +from .config import VLLMInferenceAdapterConfig -async def get_adapter_impl(config: VLLMImplConfig, _deps): - assert isinstance(config, VLLMImplConfig), f"Unexpected config type: {type(config)}" +async def get_adapter_impl(config: VLLMInferenceAdapterConfig, _deps): + from .vllm import VLLMInferenceAdapter + + assert isinstance( + config, VLLMInferenceAdapterConfig + ), f"Unexpected config type: {type(config)}" impl = VLLMInferenceAdapter(config) await impl.initialize() return impl diff --git a/llama_stack/providers/adapters/inference/vllm/config.py b/llama_stack/providers/adapters/inference/vllm/config.py index 65815922c..50a174589 100644 --- a/llama_stack/providers/adapters/inference/vllm/config.py +++ b/llama_stack/providers/adapters/inference/vllm/config.py @@ -11,12 +11,16 @@ from pydantic import BaseModel, Field @json_schema_type -class VLLMImplConfig(BaseModel): +class VLLMInferenceAdapterConfig(BaseModel): url: Optional[str] = Field( default=None, description="The URL for the vLLM model serving endpoint", ) + max_tokens: int = Field( + default=4096, + description="Maximum number of tokens to generate.", + ) api_token: Optional[str] = Field( - default=None, + default="fake", description="The API token", ) diff --git a/llama_stack/providers/adapters/inference/vllm/vllm.py b/llama_stack/providers/adapters/inference/vllm/vllm.py index aad2fdc1f..09c17ee57 100644 --- a/llama_stack/providers/adapters/inference/vllm/vllm.py +++ b/llama_stack/providers/adapters/inference/vllm/vllm.py @@ -23,7 +23,8 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( chat_completion_request_to_prompt, ) -from .config import VLLMImplConfig +from .config import VLLMInferenceAdapterConfig + VLLM_SUPPORTED_MODELS = { "Llama3.1-8B": "meta-llama/Llama-3.1-8B", @@ -55,7 +56,7 @@ VLLM_SUPPORTED_MODELS = { class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate): - def __init__(self, config: VLLMImplConfig) -> None: + def __init__(self, config: VLLMInferenceAdapterConfig) -> None: self.config = config self.formatter = ChatFormat(Tokenizer.get_instance()) self.client = None @@ -70,10 +71,22 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate): pass async def list_models(self) -> List[ModelDef]: - return [ - ModelDef(identifier=model.id, llama_model=model.id) - for model in self.client.models.list() - ] + vllm_to_llama_map = {v: k for k, v in VLLM_SUPPORTED_MODELS.items()} + + models = [] + for model in self.client.models.list(): + if model.id not in vllm_to_llama_map: + print(f"Unknown model served by vllm: {model.id}") + continue + + identifier = vllm_to_llama_map[model.id] + models.append( + ModelDef( + identifier=identifier, + llama_model=identifier, + ) + ) + return models async def completion( self, @@ -118,7 +131,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate): ) -> ChatCompletionResponse: params = self._get_params(request) r = client.completions.create(**params) - return process_chat_completion_response(request, r, self.formatter) + return process_chat_completion_response(r, self.formatter) async def _stream_chat_completion( self, request: ChatCompletionRequest, client: OpenAI @@ -139,11 +152,14 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate): yield chunk def _get_params(self, request: ChatCompletionRequest) -> dict: + options = get_sampling_options(request.sampling_params) + if "max_tokens" not in options: + options["max_tokens"] = self.config.max_tokens return { "model": VLLM_SUPPORTED_MODELS[request.model], "prompt": chat_completion_request_to_prompt(request, self.formatter), "stream": request.stream, - **get_sampling_options(request.sampling_params), + **options, } async def embeddings( diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py index 88265f1b4..717ff78a8 100644 --- a/llama_stack/providers/registry/inference.py +++ b/llama_stack/providers/registry/inference.py @@ -61,15 +61,15 @@ def available_providers() -> List[ProviderSpec]: module="llama_stack.providers.adapters.inference.ollama", ), ), - # remote_provider_spec( - # api=Api.inference, - # adapter=AdapterSpec( - # adapter_type="vllm", - # pip_packages=["openai"], - # module="llama_stack.providers.adapters.inference.vllm", - # config_class="llama_stack.providers.adapters.inference.vllm.VLLMImplConfig", - # ), - # ), + remote_provider_spec( + api=Api.inference, + adapter=AdapterSpec( + adapter_type="vllm", + pip_packages=["openai"], + module="llama_stack.providers.adapters.inference.vllm", + config_class="llama_stack.providers.adapters.inference.vllm.VLLMInferenceAdapterConfig", + ), + ), remote_provider_spec( api=Api.inference, adapter=AdapterSpec( diff --git a/llama_stack/providers/tests/inference/fixtures.py b/llama_stack/providers/tests/inference/fixtures.py index 896acbad8..acff151cf 100644 --- a/llama_stack/providers/tests/inference/fixtures.py +++ b/llama_stack/providers/tests/inference/fixtures.py @@ -14,6 +14,7 @@ from llama_stack.distribution.datatypes import Api, Provider from llama_stack.providers.adapters.inference.fireworks import FireworksImplConfig from llama_stack.providers.adapters.inference.ollama import OllamaImplConfig from llama_stack.providers.adapters.inference.together import TogetherImplConfig +from llama_stack.providers.adapters.inference.vllm import VLLMInferenceAdapterConfig from llama_stack.providers.impls.meta_reference.inference import ( MetaReferenceInferenceConfig, ) @@ -78,6 +79,21 @@ def inference_ollama(inference_model) -> ProviderFixture: ) +@pytest.fixture(scope="session") +def inference_vllm_remote() -> ProviderFixture: + return ProviderFixture( + providers=[ + Provider( + provider_id="remote::vllm", + provider_type="remote::vllm", + config=VLLMInferenceAdapterConfig( + url=get_env_or_fail("VLLM_URL"), + ).model_dump(), + ) + ], + ) + + @pytest.fixture(scope="session") def inference_fireworks() -> ProviderFixture: return ProviderFixture( @@ -109,7 +125,14 @@ def inference_together() -> ProviderFixture: ) -INFERENCE_FIXTURES = ["meta_reference", "ollama", "fireworks", "together", "remote"] +INFERENCE_FIXTURES = [ + "meta_reference", + "ollama", + "fireworks", + "together", + "vllm_remote", + "remote", +] @pytest_asyncio.fixture(scope="session")