Enable remote::vllm

This commit is contained in:
Ashwin Bharambe 2024-11-06 14:11:31 -08:00
parent 6ebd553da5
commit 6deeee9b87
5 changed files with 70 additions and 24 deletions

View file

@@ -4,12 +4,15 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from .config import VLLMImplConfig from .config import VLLMInferenceAdapterConfig
from .vllm import VLLMInferenceAdapter
async def get_adapter_impl(config: VLLMImplConfig, _deps): async def get_adapter_impl(config: VLLMInferenceAdapterConfig, _deps):
assert isinstance(config, VLLMImplConfig), f"Unexpected config type: {type(config)}" from .vllm import VLLMInferenceAdapter
assert isinstance(
config, VLLMInferenceAdapterConfig
), f"Unexpected config type: {type(config)}"
impl = VLLMInferenceAdapter(config) impl = VLLMInferenceAdapter(config)
await impl.initialize() await impl.initialize()
return impl return impl

View file

@@ -11,12 +11,16 @@ from pydantic import BaseModel, Field
@json_schema_type @json_schema_type
class VLLMImplConfig(BaseModel): class VLLMInferenceAdapterConfig(BaseModel):
url: Optional[str] = Field( url: Optional[str] = Field(
default=None, default=None,
description="The URL for the vLLM model serving endpoint", description="The URL for the vLLM model serving endpoint",
) )
max_tokens: int = Field(
default=4096,
description="Maximum number of tokens to generate.",
)
api_token: Optional[str] = Field( api_token: Optional[str] = Field(
default=None, default="fake",
description="The API token", description="The API token",
) )

View file

@@ -23,7 +23,8 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt, chat_completion_request_to_prompt,
) )
from .config import VLLMImplConfig from .config import VLLMInferenceAdapterConfig
VLLM_SUPPORTED_MODELS = { VLLM_SUPPORTED_MODELS = {
"Llama3.1-8B": "meta-llama/Llama-3.1-8B", "Llama3.1-8B": "meta-llama/Llama-3.1-8B",
@@ -55,7 +56,7 @@ VLLM_SUPPORTED_MODELS = {
class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate): class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
def __init__(self, config: VLLMImplConfig) -> None: def __init__(self, config: VLLMInferenceAdapterConfig) -> None:
self.config = config self.config = config
self.formatter = ChatFormat(Tokenizer.get_instance()) self.formatter = ChatFormat(Tokenizer.get_instance())
self.client = None self.client = None
@@ -70,10 +71,22 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
pass pass
async def list_models(self) -> List[ModelDef]: async def list_models(self) -> List[ModelDef]:
return [ vllm_to_llama_map = {v: k for k, v in VLLM_SUPPORTED_MODELS.items()}
ModelDef(identifier=model.id, llama_model=model.id)
for model in self.client.models.list() models = []
] for model in self.client.models.list():
if model.id not in vllm_to_llama_map:
print(f"Unknown model served by vllm: {model.id}")
continue
identifier = vllm_to_llama_map[model.id]
models.append(
ModelDef(
identifier=identifier,
llama_model=identifier,
)
)
return models
async def completion( async def completion(
self, self,
@@ -118,7 +131,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
) -> ChatCompletionResponse: ) -> ChatCompletionResponse:
params = self._get_params(request) params = self._get_params(request)
r = client.completions.create(**params) r = client.completions.create(**params)
return process_chat_completion_response(request, r, self.formatter) return process_chat_completion_response(r, self.formatter)
async def _stream_chat_completion( async def _stream_chat_completion(
self, request: ChatCompletionRequest, client: OpenAI self, request: ChatCompletionRequest, client: OpenAI
@@ -139,11 +152,14 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
yield chunk yield chunk
def _get_params(self, request: ChatCompletionRequest) -> dict: def _get_params(self, request: ChatCompletionRequest) -> dict:
options = get_sampling_options(request.sampling_params)
if "max_tokens" not in options:
options["max_tokens"] = self.config.max_tokens
return { return {
"model": VLLM_SUPPORTED_MODELS[request.model], "model": VLLM_SUPPORTED_MODELS[request.model],
"prompt": chat_completion_request_to_prompt(request, self.formatter), "prompt": chat_completion_request_to_prompt(request, self.formatter),
"stream": request.stream, "stream": request.stream,
**get_sampling_options(request.sampling_params), **options,
} }
async def embeddings( async def embeddings(

View file

@@ -61,15 +61,15 @@ def available_providers() -> List[ProviderSpec]:
module="llama_stack.providers.adapters.inference.ollama", module="llama_stack.providers.adapters.inference.ollama",
), ),
), ),
# remote_provider_spec( remote_provider_spec(
# api=Api.inference, api=Api.inference,
# adapter=AdapterSpec( adapter=AdapterSpec(
# adapter_type="vllm", adapter_type="vllm",
# pip_packages=["openai"], pip_packages=["openai"],
# module="llama_stack.providers.adapters.inference.vllm", module="llama_stack.providers.adapters.inference.vllm",
# config_class="llama_stack.providers.adapters.inference.vllm.VLLMImplConfig", config_class="llama_stack.providers.adapters.inference.vllm.VLLMInferenceAdapterConfig",
# ), ),
# ), ),
remote_provider_spec( remote_provider_spec(
api=Api.inference, api=Api.inference,
adapter=AdapterSpec( adapter=AdapterSpec(

View file

@@ -14,6 +14,7 @@ from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.adapters.inference.fireworks import FireworksImplConfig from llama_stack.providers.adapters.inference.fireworks import FireworksImplConfig
from llama_stack.providers.adapters.inference.ollama import OllamaImplConfig from llama_stack.providers.adapters.inference.ollama import OllamaImplConfig
from llama_stack.providers.adapters.inference.together import TogetherImplConfig from llama_stack.providers.adapters.inference.together import TogetherImplConfig
from llama_stack.providers.adapters.inference.vllm import VLLMInferenceAdapterConfig
from llama_stack.providers.impls.meta_reference.inference import ( from llama_stack.providers.impls.meta_reference.inference import (
MetaReferenceInferenceConfig, MetaReferenceInferenceConfig,
) )
@@ -78,6 +79,21 @@ def inference_ollama(inference_model) -> ProviderFixture:
) )
@pytest.fixture(scope="session")
def inference_vllm_remote() -> ProviderFixture:
return ProviderFixture(
providers=[
Provider(
provider_id="remote::vllm",
provider_type="remote::vllm",
config=VLLMInferenceAdapterConfig(
url=get_env_or_fail("VLLM_URL"),
).model_dump(),
)
],
)
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def inference_fireworks() -> ProviderFixture: def inference_fireworks() -> ProviderFixture:
return ProviderFixture( return ProviderFixture(
@@ -109,7 +125,14 @@ def inference_together() -> ProviderFixture:
) )
INFERENCE_FIXTURES = ["meta_reference", "ollama", "fireworks", "together", "remote"] INFERENCE_FIXTURES = [
"meta_reference",
"ollama",
"fireworks",
"together",
"vllm_remote",
"remote",
]
@pytest_asyncio.fixture(scope="session") @pytest_asyncio.fixture(scope="session")