Enable remote::vllm

This commit is contained in:
Ashwin Bharambe 2024-11-06 14:11:31 -08:00
parent 6ebd553da5
commit 6deeee9b87
5 changed files with 70 additions and 24 deletions

View file

@@ -4,12 +4,15 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import VLLMImplConfig
from .vllm import VLLMInferenceAdapter
from .config import VLLMInferenceAdapterConfig
async def get_adapter_impl(config: VLLMImplConfig, _deps):
assert isinstance(config, VLLMImplConfig), f"Unexpected config type: {type(config)}"
async def get_adapter_impl(config: VLLMInferenceAdapterConfig, _deps):
    """Build, initialize, and return the remote vLLM inference adapter.

    The adapter class is imported lazily so the module can be loaded
    without pulling in the (heavier) vllm adapter implementation.
    """
    from .vllm import VLLMInferenceAdapter

    assert isinstance(config, VLLMInferenceAdapterConfig), f"Unexpected config type: {type(config)}"
    adapter = VLLMInferenceAdapter(config)
    await adapter.initialize()
    return adapter

View file

@@ -11,12 +11,16 @@ from pydantic import BaseModel, Field
@json_schema_type
class VLLMImplConfig(BaseModel):
class VLLMInferenceAdapterConfig(BaseModel):
url: Optional[str] = Field(
default=None,
description="The URL for the vLLM model serving endpoint",
)
max_tokens: int = Field(
default=4096,
description="Maximum number of tokens to generate.",
)
api_token: Optional[str] = Field(
default=None,
default="fake",
description="The API token",
)

View file

@@ -23,7 +23,8 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
)
from .config import VLLMImplConfig
from .config import VLLMInferenceAdapterConfig
VLLM_SUPPORTED_MODELS = {
"Llama3.1-8B": "meta-llama/Llama-3.1-8B",
@@ -55,7 +56,7 @@ VLLM_SUPPORTED_MODELS = {
class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
def __init__(self, config: VLLMImplConfig) -> None:
def __init__(self, config: VLLMInferenceAdapterConfig) -> None:
self.config = config
self.formatter = ChatFormat(Tokenizer.get_instance())
self.client = None
@@ -70,10 +71,22 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
pass
async def list_models(self) -> List[ModelDef]:
return [
ModelDef(identifier=model.id, llama_model=model.id)
for model in self.client.models.list()
]
vllm_to_llama_map = {v: k for k, v in VLLM_SUPPORTED_MODELS.items()}
models = []
for model in self.client.models.list():
if model.id not in vllm_to_llama_map:
print(f"Unknown model served by vllm: {model.id}")
continue
identifier = vllm_to_llama_map[model.id]
models.append(
ModelDef(
identifier=identifier,
llama_model=identifier,
)
)
return models
async def completion(
self,
@@ -118,7 +131,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
) -> ChatCompletionResponse:
params = self._get_params(request)
r = client.completions.create(**params)
return process_chat_completion_response(request, r, self.formatter)
return process_chat_completion_response(r, self.formatter)
async def _stream_chat_completion(
self, request: ChatCompletionRequest, client: OpenAI
@@ -139,11 +152,14 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
yield chunk
def _get_params(self, request: ChatCompletionRequest) -> dict:
options = get_sampling_options(request.sampling_params)
if "max_tokens" not in options:
options["max_tokens"] = self.config.max_tokens
return {
"model": VLLM_SUPPORTED_MODELS[request.model],
"prompt": chat_completion_request_to_prompt(request, self.formatter),
"stream": request.stream,
**get_sampling_options(request.sampling_params),
**options,
}
async def embeddings(

View file

@@ -61,15 +61,15 @@ def available_providers() -> List[ProviderSpec]:
module="llama_stack.providers.adapters.inference.ollama",
),
),
# remote_provider_spec(
# api=Api.inference,
# adapter=AdapterSpec(
# adapter_type="vllm",
# pip_packages=["openai"],
# module="llama_stack.providers.adapters.inference.vllm",
# config_class="llama_stack.providers.adapters.inference.vllm.VLLMImplConfig",
# ),
# ),
remote_provider_spec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="vllm",
pip_packages=["openai"],
module="llama_stack.providers.adapters.inference.vllm",
config_class="llama_stack.providers.adapters.inference.vllm.VLLMInferenceAdapterConfig",
),
),
remote_provider_spec(
api=Api.inference,
adapter=AdapterSpec(

View file

@@ -14,6 +14,7 @@ from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.adapters.inference.fireworks import FireworksImplConfig
from llama_stack.providers.adapters.inference.ollama import OllamaImplConfig
from llama_stack.providers.adapters.inference.together import TogetherImplConfig
from llama_stack.providers.adapters.inference.vllm import VLLMInferenceAdapterConfig
from llama_stack.providers.impls.meta_reference.inference import (
MetaReferenceInferenceConfig,
)
@@ -78,6 +79,21 @@ def inference_ollama(inference_model) -> ProviderFixture:
)
@pytest.fixture(scope="session")
def inference_vllm_remote() -> ProviderFixture:
return ProviderFixture(
providers=[
Provider(
provider_id="remote::vllm",
provider_type="remote::vllm",
config=VLLMInferenceAdapterConfig(
url=get_env_or_fail("VLLM_URL"),
).model_dump(),
)
],
)
@pytest.fixture(scope="session")
def inference_fireworks() -> ProviderFixture:
return ProviderFixture(
@@ -109,7 +125,14 @@ def inference_together() -> ProviderFixture:
)
INFERENCE_FIXTURES = ["meta_reference", "ollama", "fireworks", "together", "remote"]
# Names of the provider fixtures selectable for the inference test matrix.
INFERENCE_FIXTURES = (
    "meta_reference ollama fireworks together vllm_remote remote".split()
)
@pytest_asyncio.fixture(scope="session")