Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-01 16:24:44 +00:00)

Commit 92ee627e89 ("vllm"), parent 71219b4937
1 changed file with 32 additions and 36 deletions
@@ -8,13 +8,17 @@ from typing import AsyncGenerator
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.datatypes import Message
 from llama_models.llama3.api.tokenizer import Tokenizer
-from llama_models.sku_list import all_registered_models, resolve_model
+from llama_models.sku_list import all_registered_models

 from openai import OpenAI

 from llama_stack.apis.inference import *  # noqa: F403
-from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
+from llama_stack.providers.datatypes import ModelsProtocolPrivate

+from llama_stack.providers.utils.inference.model_registry import (
+    ModelAlias,
+    ModelRegistryHelper,
+)
 from llama_stack.providers.utils.inference.openai_compat import (
     get_sampling_options,
     process_chat_completion_response,
@@ -30,8 +34,24 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
 from .config import VLLMInferenceAdapterConfig


-class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
+def build_model_aliases():
+    return [
+        ModelAlias(
+            provider_model_id=model.huggingface_repo,
+            aliases=[model.descriptor()],
+            llama_model=model.descriptor(),
+        )
+        for model in all_registered_models()
+        if model.huggingface_repo
+    ]
+
+
+class VLLMInferenceAdapter(Inference, ModelRegistryHelper, ModelsProtocolPrivate):
     def __init__(self, config: VLLMInferenceAdapterConfig) -> None:
+        ModelRegistryHelper.__init__(
+            self,
+            model_aliases=build_model_aliases(),
+        )
         self.config = config
         self.formatter = ChatFormat(Tokenizer.get_instance())
         self.client = None
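The new build_model_aliases() plus the ModelRegistryHelper mixin mean the adapter resolves client-facing model names through an alias table instead of ad-hoc lookups. A minimal sketch of that lookup, using stand-in dataclasses rather than the real llama_stack types (all names and the repo id below are illustrative):

# Minimal sketch, not llama-stack code: how an alias table like the one produced by
# build_model_aliases() lets both a Llama descriptor and a HuggingFace repo id resolve
# to the provider-side model id.
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class AliasEntry:
    provider_model_id: str  # the HuggingFace repo that vLLM actually serves
    aliases: List[str]      # alternate names clients may pass
    llama_model: str        # canonical Llama descriptor


def build_lookup(entries: List[AliasEntry]) -> Dict[str, AliasEntry]:
    lookup: Dict[str, AliasEntry] = {}
    for entry in entries:
        lookup[entry.provider_model_id] = entry
        for name in entry.aliases:
            lookup[name] = entry
    return lookup


entries = [
    AliasEntry(
        provider_model_id="meta-llama/Llama-3.1-8B-Instruct",  # placeholder repo id
        aliases=["Llama3.1-8B-Instruct"],
        llama_model="Llama3.1-8B-Instruct",
    )
]
lookup = build_lookup(entries)
# Either name resolves to the repo id that is ultimately sent to the vLLM server.
assert lookup["Llama3.1-8B-Instruct"].provider_model_id == "meta-llama/Llama-3.1-8B-Instruct"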
@@ -44,31 +64,6 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
     async def initialize(self) -> None:
         self.client = OpenAI(base_url=self.config.url, api_key=self.config.api_token)

-    async def register_model(self, model: Model) -> None:
-        pass
-
-    async def shutdown(self) -> None:
-        pass
-
-    async def list_models(self) -> List[Model]:
-        models = []
-        for model in self.client.models.list():
-            repo = model.id
-            if repo not in self.huggingface_repo_to_llama_model_id:
-                print(f"Unknown model served by vllm: {repo}")
-                continue
-
-            identifier = self.huggingface_repo_to_llama_model_id[repo]
-            if identifier == model.provider_resource_id:
-                print(
-                    f"Verified that model {model.provider_resource_id} is being served by vLLM"
-                )
-                return
-
-        raise ValueError(
-            f"Model {model.provider_resource_id} is not being served by vLLM"
-        )
-
     async def shutdown(self) -> None:
         pass

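The deleted code verified a model against whatever the vLLM server reports on its OpenAI-compatible /v1/models endpoint. A standalone sketch of that check using the openai client (the URL, token, and repo id are placeholders for illustration):

# Sketch of the served-model check this hunk removes: list the models the
# vLLM OpenAI-compatible server reports and verify a given repo id is among them.
from openai import OpenAI


def is_model_served(base_url: str, api_token: str, hf_repo: str) -> bool:
    client = OpenAI(base_url=base_url, api_key=api_token)
    # vLLM exposes its served model ids via the standard /v1/models endpoint.
    return any(served.id == hf_repo for served in client.models.list())


if __name__ == "__main__":
    print(is_model_served("http://localhost:8000/v1", "fake-token",
                          "meta-llama/Llama-3.1-8B-Instruct"))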
@@ -95,8 +90,9 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
+        model = await self.model_store.get_model(model_id)
         request = ChatCompletionRequest(
-            model=model_id,
+            model=model.provider_resource_id,
             messages=messages,
             sampling_params=sampling_params,
             tools=tools or [],
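chat_completion now resolves the user-facing model_id through the adapter's model store and puts model.provider_resource_id into the request. A minimal sketch of that resolution step with an in-memory stand-in for the store (only get_model() mirrors what the diff relies on; everything else is illustrative):

# Minimal sketch: look the user-facing model_id up in a model store and use its
# provider_resource_id when talking to vLLM.
import asyncio
from dataclasses import dataclass
from typing import Dict, Iterable


@dataclass
class RegisteredModel:
    identifier: str            # name clients pass as model_id
    provider_resource_id: str  # HuggingFace repo id that vLLM serves


class InMemoryModelStore:
    def __init__(self, models: Iterable[RegisteredModel]) -> None:
        self._models: Dict[str, RegisteredModel] = {m.identifier: m for m in models}

    async def get_model(self, model_id: str) -> RegisteredModel:
        return self._models[model_id]


async def main() -> None:
    store = InMemoryModelStore(
        [RegisteredModel("Llama3.1-8B-Instruct", "meta-llama/Llama-3.1-8B-Instruct")]
    )
    model = await store.get_model("Llama3.1-8B-Instruct")
    # The ChatCompletionRequest is built with the provider-side id, not the alias.
    print(model.provider_resource_id)


asyncio.run(main())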
@@ -148,10 +144,6 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         if "max_tokens" not in options:
             options["max_tokens"] = self.config.max_tokens

-        model = resolve_model(request.model)
-        if model is None:
-            raise ValueError(f"Unknown model: {request.model}")
-
         input_dict = {}
         media_present = request_has_media(request)
         if isinstance(request, ChatCompletionRequest):
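The dropped guard used resolve_model() from llama_models.sku_list to reject descriptors that are not known SKUs; with the registry mixin, alias resolution appears to happen at registration time instead, so the per-request check becomes redundant. The removed check, as a standalone sketch (the descriptor below is assumed to be a registered SKU):

# Sketch of the guard this hunk removes: resolve_model() returns None for a
# descriptor that is not a registered SKU, so unknown models were rejected
# while building the request params.
from llama_models.sku_list import resolve_model


def ensure_known_model(descriptor: str) -> None:
    if resolve_model(descriptor) is None:
        raise ValueError(f"Unknown model: {descriptor}")


ensure_known_model("Llama3.1-8B-Instruct")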
@@ -163,16 +155,20 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
                 ]
             else:
                 input_dict["prompt"] = chat_completion_request_to_prompt(
-                    request, self.formatter
+                    request, self.get_llama_model(request.model), self.formatter
                 )
         else:
             assert (
                 not media_present
             ), "Together does not support media for Completion requests"
-            input_dict["prompt"] = completion_request_to_prompt(request, self.formatter)
+            input_dict["prompt"] = completion_request_to_prompt(
+                request,
+                self.get_llama_model(request.model),
+                self.formatter,
+            )

         return {
-            "model": model.huggingface_repo,
+            "model": request.model,
             **input_dict,
             "stream": request.stream,
             **options,
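The params dict returned here ("model", the rendered prompt or messages, "stream", and sampling options) is shaped for vLLM's OpenAI-compatible API. How the adapter consumes it is outside this diff, but a hedged sketch of sending an equivalent non-chat completion with the openai client looks like this (server URL, token, model id, prompt, and max_tokens are placeholders):

# Hedged sketch: a params dict shaped like the return value above, sent to a vLLM
# OpenAI-compatible endpoint as a non-chat completion.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="fake-token")

params = {
    "model": "meta-llama/Llama-3.1-8B-Instruct",  # request.model after alias resolution
    "prompt": "<|begin_of_text|>Hello",           # prompt rendered by the formatter
    "stream": False,
    "max_tokens": 64,                             # from get_sampling_options / config.max_tokens
}

response = client.completions.create(**params)
print(response.choices[0].text)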