llama-stack/llama_stack/providers/registry/inference.py
Russell Bryant f73e247ba1
Inline vLLM inference provider (#181)
This is just like the `local` template, which uses `meta-reference` for
everything, except it uses `vllm` for inference.

Docker works, but so far `conda` is a bit easier to use with the vllm
provider. The default container base image does not include all the
libraries needed for every vllm feature; additional CUDA dependencies
are required.

I started changing the base image used in this template, but that also
required changes to the Dockerfile, so it was getting too involved to
include in this first PR.

Working so far:

* `python -m llama_stack.apis.inference.client localhost 5000 --model Llama3.2-1B-Instruct --stream True`
* `python -m llama_stack.apis.inference.client localhost 5000 --model Llama3.2-1B-Instruct --stream False`

Example:

```
$ python -m llama_stack.apis.inference.client localhost 5000 --model Llama3.2-1B-Instruct --stream False
User>hello world, write me a 2 sentence poem about the moon
Assistant>
The moon glows bright in the midnight sky
A beacon of light,
```

I have only tested these models:

* `Llama3.1-8B-Instruct` - across 4 GPUs (tensor_parallel_size = 4)
* `Llama3.2-1B-Instruct` - on a single GPU (tensor_parallel_size = 1)
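
For illustration, here is a minimal sketch of the kind of config that distinguishes those two setups. It is hypothetical: the real config class is `VLLMConfig` in `llama_stack.providers.impls.vllm`, and its actual fields may differ; `ExampleVLLMConfig`, `model`, and `tensor_parallel_size` here are assumptions (the latter named after vLLM's engine argument of the same name).

```
# Hypothetical sketch only -- not the actual VLLMConfig definition.
from pydantic import BaseModel


class ExampleVLLMConfig(BaseModel):
    # Model to serve, e.g. "Llama3.2-1B-Instruct"
    model: str
    # How many GPUs to shard the model across (vLLM's tensor_parallel_size)
    tensor_parallel_size: int = 1


# Single GPU for the 1B model, 4-way tensor parallelism for the 8B model
small = ExampleVLLMConfig(model="Llama3.2-1B-Instruct", tensor_parallel_size=1)
large = ExampleVLLMConfig(model="Llama3.1-8B-Instruct", tensor_parallel_size=4)
```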
2024-10-05 23:34:16 -07:00


# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import List

from llama_stack.distribution.datatypes import *  # noqa: F403


def available_providers() -> List[ProviderSpec]:
    return [
        InlineProviderSpec(
            api=Api.inference,
            provider_type="meta-reference",
            pip_packages=[
                "accelerate",
                "blobfile",
                "fairscale",
                "fbgemm-gpu==0.8.0",
                "torch",
                "torchvision",
                "transformers",
                "zmq",
            ],
            module="llama_stack.providers.impls.meta_reference.inference",
            config_class="llama_stack.providers.impls.meta_reference.inference.MetaReferenceImplConfig",
        ),
        remote_provider_spec(
            api=Api.inference,
            adapter=AdapterSpec(
                adapter_type="sample",
                pip_packages=[],
                module="llama_stack.providers.adapters.inference.sample",
                config_class="llama_stack.providers.adapters.inference.sample.SampleConfig",
            ),
        ),
        remote_provider_spec(
            api=Api.inference,
            adapter=AdapterSpec(
                adapter_type="ollama",
                pip_packages=["ollama"],
                config_class="llama_stack.providers.adapters.inference.ollama.OllamaImplConfig",
                module="llama_stack.providers.adapters.inference.ollama",
            ),
        ),
        remote_provider_spec(
            api=Api.inference,
            adapter=AdapterSpec(
                adapter_type="tgi",
                pip_packages=["huggingface_hub", "aiohttp"],
                module="llama_stack.providers.adapters.inference.tgi",
                config_class="llama_stack.providers.adapters.inference.tgi.TGIImplConfig",
            ),
        ),
        remote_provider_spec(
            api=Api.inference,
            adapter=AdapterSpec(
                adapter_type="hf::serverless",
                pip_packages=["huggingface_hub", "aiohttp"],
                module="llama_stack.providers.adapters.inference.tgi",
                config_class="llama_stack.providers.adapters.inference.tgi.InferenceAPIImplConfig",
            ),
        ),
        remote_provider_spec(
            api=Api.inference,
            adapter=AdapterSpec(
                adapter_type="hf::endpoint",
                pip_packages=["huggingface_hub", "aiohttp"],
                module="llama_stack.providers.adapters.inference.tgi",
                config_class="llama_stack.providers.adapters.inference.tgi.InferenceEndpointImplConfig",
            ),
        ),
        remote_provider_spec(
            api=Api.inference,
            adapter=AdapterSpec(
                adapter_type="fireworks",
                pip_packages=[
                    "fireworks-ai",
                ],
                module="llama_stack.providers.adapters.inference.fireworks",
                config_class="llama_stack.providers.adapters.inference.fireworks.FireworksImplConfig",
            ),
        ),
        remote_provider_spec(
            api=Api.inference,
            adapter=AdapterSpec(
                adapter_type="together",
                pip_packages=[
                    "together",
                ],
                module="llama_stack.providers.adapters.inference.together",
                config_class="llama_stack.providers.adapters.inference.together.TogetherImplConfig",
                provider_data_validator="llama_stack.providers.adapters.safety.together.TogetherProviderDataValidator",
            ),
        ),
        remote_provider_spec(
            api=Api.inference,
            adapter=AdapterSpec(
                adapter_type="bedrock",
                pip_packages=["boto3"],
                module="llama_stack.providers.adapters.inference.bedrock",
                config_class="llama_stack.providers.adapters.inference.bedrock.BedrockConfig",
            ),
        ),
        InlineProviderSpec(
            api=Api.inference,
            provider_type="vllm",
            pip_packages=[
                "vllm",
            ],
            module="llama_stack.providers.impls.vllm",
            config_class="llama_stack.providers.impls.vllm.VLLMConfig",
        ),
    ]
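

# Illustrative usage sketch: one way the registry above might be inspected,
# e.g. to confirm the inline vLLM provider is registered. It relies only on
# `available_providers()` defined above and on the `api`, `provider_type`,
# and `pip_packages` arguments passed in the constructors; treating them as
# attributes of every returned spec is an assumption about the ProviderSpec
# datatypes, so `getattr` with defaults is used defensively.
if __name__ == "__main__":
    for spec in available_providers():
        provider_type = getattr(spec, "provider_type", "<remote>")
        packages = ", ".join(getattr(spec, "pip_packages", []) or [])
        print(f"{spec.api}: {provider_type} [{packages}]")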