Mirror of https://github.com/meta-llama/llama-stack.git

Merge branch 'vllm' into vllm-merge-1

commit 8e358ec6a8
3 changed files with 64 additions and 17 deletions
@@ -50,7 +50,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         self.formatter = ChatFormat(Tokenizer.get_instance())
 
     async def initialize(self):
-        log.info("Initializing vLLM inference adapter")
+        log.info("Initializing vLLM inference provider.")
 
         # Disable usage stats reporting. This would be a surprising thing for most
         # people to find out was on by default.
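The hunk above shows only the comment; the lines that actually turn stats reporting off fall outside this context window. For reference, a minimal, hypothetical sketch of one way to do it, assuming vLLM's documented VLLM_NO_USAGE_STATS / DO_NOT_TRACK opt-out switches (the provider's real mechanism is not shown here):

import os

# Hypothetical sketch: opt out of vLLM usage-stats collection before the engine starts.
# VLLM_NO_USAGE_STATS / DO_NOT_TRACK are vLLM's documented opt-out switches; the exact
# mechanism used by this provider is not visible in the hunk above.
os.environ.setdefault("VLLM_NO_USAGE_STATS", "1")
os.environ.setdefault("DO_NOT_TRACK", "1")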
@@ -79,14 +79,33 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
 
     async def shutdown(self):
         """Shutdown the vLLM inference adapter."""
-        log.info("Shutting down vLLM inference adapter")
+        log.info("Shutting down vLLM inference provider.")
         if self.engine:
             self.engine.shutdown_background_loop()
 
-    async def register_model(self, model: Model) -> None:
-        raise ValueError(
-            "You cannot dynamically add a model to a running vllm instance"
-        )
+    # Note that the return type of the superclass method is WRONG
+    async def register_model(self, model: Model) -> Model:
+        """
+        Callback that is called when the server associates an inference endpoint
+        with an inference provider.
+
+        :param model: Object that encapsulates parameters necessary for identifying
+            a specific LLM.
+
+        :returns: The input ``Model`` object. It may or may not be permissible
+            to change fields before returning this object.
+        """
+        log.info(f"Registering model {model.identifier} with vLLM inference provider.")
+        # The current version of this provider is hard-coded to serve only
+        # the model specified in the YAML config file.
+        configured_model = resolve_model(self.config.model)
+        registered_model = resolve_model(model.model_id)
+
+        if configured_model.core_model_id != registered_model.core_model_id:
+            raise ValueError(f"Requested model '{model.identifier}' is different from "
+                             f"model '{self.config.model}' that this provider "
+                             f"is configured to serve")
+        return model
 
     def _sampling_params(self, sampling_params: SamplingParams) -> VLLMSamplingParams:
         if sampling_params is None:
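For readers skimming the hunk above: the registration check compares canonical core model IDs rather than raw strings, so two different aliases of the same model still match. A standalone sketch of that comparison, using a hypothetical resolve_model_stub in place of llama_stack's real resolve_model:

from dataclasses import dataclass


@dataclass(frozen=True)
class ResolvedModel:
    core_model_id: str


def resolve_model_stub(descriptor: str) -> ResolvedModel:
    # Hypothetical alias table standing in for llama_stack's resolve_model().
    aliases = {
        "Llama3.1-8B-Instruct": "llama3_1_8b_instruct",
        "meta-llama/Llama-3.1-8B-Instruct": "llama3_1_8b_instruct",
    }
    return ResolvedModel(core_model_id=aliases.get(descriptor, descriptor))


def check_registration(configured: str, requested: str) -> None:
    # Same shape as the check in register_model(): compare canonical IDs, not raw strings.
    if resolve_model_stub(configured).core_model_id != resolve_model_stub(requested).core_model_id:
        raise ValueError(
            f"Requested model '{requested}' is different from model "
            f"'{configured}' that this provider is configured to serve"
        )


check_registration("Llama3.1-8B-Instruct", "meta-llama/Llama-3.1-8B-Instruct")  # passes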
@@ -206,7 +225,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
             stream, self.formatter
         ):
             yield chunk
 
     async def embeddings(
         self, model_id: str, contents: List[InterleavedContent]
     ) -> EmbeddingsResponse:
@@ -15,6 +15,7 @@ from llama_stack.distribution.datatypes import Api, Provider
 from llama_stack.providers.inline.inference.meta_reference import (
     MetaReferenceInferenceConfig,
 )
+from llama_stack.providers.inline.inference.vllm import VLLMConfig
 from llama_stack.providers.remote.inference.bedrock import BedrockConfig
 
 from llama_stack.providers.remote.inference.cerebras import CerebrasImplConfig
@@ -104,6 +105,26 @@ def inference_ollama(inference_model) -> ProviderFixture:
     )
 
 
+@pytest_asyncio.fixture(scope="session")
+def inference_vllm(inference_model) -> ProviderFixture:
+    inference_model = (
+        [inference_model] if isinstance(inference_model, str) else inference_model
+    )
+    return ProviderFixture(
+        providers=[
+            Provider(
+                provider_id=f"vllm-{i}",
+                provider_type="inline::vllm",
+                config=VLLMConfig(
+                    model=m,
+                    enforce_eager=True,  # Make test run faster
+                ).model_dump(),
+            )
+            for i, m in enumerate(inference_model)
+        ]
+    )
+
+
 @pytest.fixture(scope="session")
 def inference_vllm_remote() -> ProviderFixture:
     return ProviderFixture(
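The new inference_vllm fixture builds one inline vllm Provider per configured model; Provider.config carries a plain dict, which is why the pydantic config object is serialized with .model_dump(). A rough sketch of that serialization step, with a hypothetical stand-in for VLLMConfig that models only the two fields used above:

from pydantic import BaseModel


class VLLMConfigSketch(BaseModel):
    # Hypothetical stand-in for llama_stack's VLLMConfig; only the two fields
    # used in the fixture above are modeled here.
    model: str
    enforce_eager: bool = False


cfg = VLLMConfigSketch(model="Llama3.1-8B-Instruct", enforce_eager=True)
print(cfg.model_dump())
# {'model': 'Llama3.1-8B-Instruct', 'enforce_eager': True}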
@@ -236,6 +257,7 @@ INFERENCE_FIXTURES = [
     "ollama",
     "fireworks",
     "together",
+    "vllm",
     "vllm_remote",
     "remote",
     "bedrock",
@@ -268,4 +290,8 @@ async def inference_stack(request, inference_model):
         ],
     )
 
-    return test_stack.impls[Api.inference], test_stack.impls[Api.models]
+    # Pytest yield fixture; see https://docs.pytest.org/en/stable/how-to/fixtures.html#yield-fixtures-recommended
+    yield test_stack.impls[Api.inference], test_stack.impls[Api.models]
+
+    # Cleanup code that runs after test case completion
+    await test_stack.impls[Api.inference].shutdown()
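The inference_stack fixture above now follows pytest's yield-fixture pattern: code before the yield is setup, code after it is teardown that runs once the dependent tests have finished, which is what lets the vLLM engine be shut down cleanly. A minimal, generic sketch of the same pattern with a hypothetical async resource:

import pytest_asyncio


class FakeEngine:
    # Hypothetical async resource, used only to illustrate the yield-fixture pattern.
    async def shutdown(self) -> None:
        print("engine shut down")


@pytest_asyncio.fixture(scope="session")
async def engine():
    impl = FakeEngine()    # setup
    yield impl             # value handed to the tests
    await impl.shutdown()  # teardown after the last test using the fixture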
@@ -67,7 +67,9 @@ def sample_tool_definition():
 
 
 class TestInference:
-    @pytest.mark.asyncio
+    # Session scope for asyncio because the tests in this class all
+    # share the same provider instance.
+    @pytest.mark.asyncio(loop_scope="session")
     async def test_model_list(self, inference_model, inference_stack):
         _, models_impl = inference_stack
         response = await models_impl.list_models()
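The loop_scope="session" argument (available in recent pytest-asyncio releases; treat the exact version requirement as an assumption) makes every test in the class run on a single session-wide event loop, so a provider created once on that loop can be reused by later tests. A minimal sketch of the idea, unrelated to the real fixtures:

import asyncio

import pytest
import pytest_asyncio


@pytest_asyncio.fixture(scope="session", loop_scope="session")
async def shared_provider():
    # Created once, on the session-wide event loop.
    return {"loop": asyncio.get_running_loop()}


@pytest.mark.asyncio(loop_scope="session")
async def test_runs_on_the_same_loop(shared_provider):
    # The fixture and the test see the same loop, so loop-bound state is reusable.
    assert shared_provider["loop"] is asyncio.get_running_loop()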
@@ -83,7 +85,7 @@ class TestInference:
 
         assert model_def is not None
 
-    @pytest.mark.asyncio
+    @pytest.mark.asyncio(loop_scope="session")
     async def test_completion(self, inference_model, inference_stack):
         inference_impl, _ = inference_stack
 
@@ -128,7 +130,7 @@ class TestInference:
         last = chunks[-1]
         assert last.stop_reason == StopReason.out_of_tokens
 
-    @pytest.mark.asyncio
+    @pytest.mark.asyncio(loop_scope="session")
     async def test_completion_logprobs(self, inference_model, inference_stack):
         inference_impl, _ = inference_stack
 
@@ -183,7 +185,7 @@ class TestInference:
             else:  # no token, no logprobs
                 assert not chunk.logprobs, "Logprobs should be empty"
 
-    @pytest.mark.asyncio
+    @pytest.mark.asyncio(loop_scope="session")
     @pytest.mark.skip("This test is not quite robust")
     async def test_completion_structured_output(self, inference_model, inference_stack):
         inference_impl, _ = inference_stack
@@ -227,7 +229,7 @@ class TestInference:
         assert answer.year_born == "1963"
         assert answer.year_retired == "2003"
 
-    @pytest.mark.asyncio
+    @pytest.mark.asyncio(loop_scope="session")
     async def test_chat_completion_non_streaming(
         self, inference_model, inference_stack, common_params, sample_messages
     ):
@@ -244,7 +246,7 @@ class TestInference:
         assert isinstance(response.completion_message.content, str)
         assert len(response.completion_message.content) > 0
 
-    @pytest.mark.asyncio
+    @pytest.mark.asyncio(loop_scope="session")
     async def test_structured_output(
         self, inference_model, inference_stack, common_params
     ):
@@ -314,7 +316,7 @@ class TestInference:
         with pytest.raises(ValidationError):
             AnswerFormat.model_validate_json(response.completion_message.content)
 
-    @pytest.mark.asyncio
+    @pytest.mark.asyncio(loop_scope="session")
     async def test_chat_completion_streaming(
         self, inference_model, inference_stack, common_params, sample_messages
     ):
@@ -341,7 +343,7 @@ class TestInference:
         end = grouped[ChatCompletionResponseEventType.complete][0]
         assert end.event.stop_reason == StopReason.end_of_turn
 
-    @pytest.mark.asyncio
+    @pytest.mark.asyncio(loop_scope="session")
     async def test_chat_completion_with_tool_calling(
         self,
         inference_model,
@@ -380,7 +382,7 @@ class TestInference:
         assert "location" in call.arguments
         assert "San Francisco" in call.arguments["location"]
 
-    @pytest.mark.asyncio
+    @pytest.mark.asyncio(loop_scope="session")
     async def test_chat_completion_with_tool_calling_streaming(
         self,
         inference_model,