Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-29 15:23:51 +00:00)

Commit 2c23a66300: Merge branch 'main' into evals_new

24 changed files with 112 additions and 120 deletions

.flake8 (10 changes)

@@ -21,11 +21,11 @@ ignore =
 optional-ascii-coding = True
 exclude =
     ./.git,
-    ./docs
-    ./build
+    ./docs/*,
+    ./build,
     ./scripts,
     ./venv,
-    *.pyi
-    .pre-commit-config.yaml
-    *.md
+    *.pyi,
+    .pre-commit-config.yaml,
+    *.md,
     .flake8

.github/workflows/pre-commit.yml (vendored, new file, 25 additions)

@@ -0,0 +1,25 @@
+name: Pre-commit
+
+on:
+  pull_request:
+  push:
+    branches: [main]
+
+jobs:
+  pre-commit:
+    runs-on: ubuntu-latest
+
+    steps:
+    - name: Checkout code
+      uses: actions/checkout@d632683dd7b4114ad314bca15554477dd762a938 # v4.2.0
+
+    - name: Set up Python
+      uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # v5.2.0
+      with:
+        python-version: '3.11'
+        cache: pip
+        cache-dependency-path: |
+          **/requirements*.txt
+          .pre-commit-config.yaml
+
+    - uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd #v3.0.1

@@ -1,3 +1,5 @@
+<img src="https://github.com/user-attachments/assets/2fedfe0f-6df7-4441-98b2-87a1fd95ee1c" width="300" title="Llama Stack Logo" alt="Llama Stack Logo"/>
+
 # Llama Stack

 [](https://pypi.org/project/llama_stack/)

@@ -97,7 +99,7 @@ The `llama` CLI makes it easy to work with the Llama Stack set of tools. Please
 | **Language** | **Client SDK** | **Package** |
 | :----: | :----: | :----: |
 | Python | [llama-stack-client-python](https://github.com/meta-llama/llama-stack-client-python) | [](https://pypi.org/project/llama_stack_client/)
-| Swift | [llama-stack-client-swift](https://github.com/meta-llama/llama-stack-client-swift) |
+| Swift | [llama-stack-client-swift](https://github.com/meta-llama/llama-stack-client-swift) | [](https://swiftpackageindex.com/meta-llama/llama-stack-client-swift)
 | Node | [llama-stack-client-node](https://github.com/meta-llama/llama-stack-client-node) | [](https://npmjs.org/package/llama-stack-client)
 | Kotlin | [llama-stack-client-kotlin](https://github.com/meta-llama/llama-stack-client-kotlin) |

@@ -21,7 +21,7 @@
     "info": {
         "title": "[DRAFT] Llama Stack Specification",
         "version": "0.0.1",
-        "description": "This is the specification of the llama stack that provides\n    a set of endpoints and their corresponding interfaces that are tailored to\n    best leverage Llama Models. The specification is still in draft and subject to change.\n    Generated at 2024-10-15 10:15:15.195382"
+        "description": "This is the specification of the llama stack that provides\n    a set of endpoints and their corresponding interfaces that are tailored to\n    best leverage Llama Models. The specification is still in draft and subject to change.\n    Generated at 2024-10-10 15:29:56.831109"
     },
     "servers": [
         {

@@ -6228,49 +6228,7 @@
     ],
     "tags": [
         {
-            "name": "Inference"
-        },
-        {
-            "name": "PostTraining"
-        },
-        {
-            "name": "Agents"
-        },
-        {
-            "name": "MemoryBanks"
-        },
-        {
-            "name": "Inspect"
-        },
-        {
-            "name": "Models"
-        },
-        {
-            "name": "Safety"
-        },
-        {
-            "name": "Evals"
-        },
-        {
-            "name": "BatchInference"
-        },
-        {
-            "name": "Shields"
-        },
-        {
-            "name": "SyntheticDataGeneration"
-        },
-        {
-            "name": "Telemetry"
-        },
-        {
-            "name": "RewardScoring"
-        },
-        {
-            "name": "Datasets"
-        },
-        {
-            "name": "Memory"
         },
         {
             "name": "BuiltinTool",

@@ -2679,7 +2679,6 @@ info:
   description: "This is the specification of the llama stack that provides\n \
     \ a set of endpoints and their corresponding interfaces that are tailored\
     \ to\n best leverage Llama Models. The specification is still in\
-    \ draft and subject to change.\n Generated at 2024-10-15 10:15:15.195382"
   title: '[DRAFT] Llama Stack Specification'
   version: 0.0.1
 jsonSchemaDialect: https://json-schema.org/draft/2020-12/schema

@@ -149,6 +149,7 @@ class StackBuild(Subcommand):
     def _run_template_list_cmd(self, args: argparse.Namespace) -> None:
         import json

         from llama_stack.cli.table import print_table

         # eventually, this should query a registry at llama.meta.com/llamastack/distributions

@@ -175,6 +176,7 @@ class StackBuild(Subcommand):
     def _run_stack_build_command(self, args: argparse.Namespace) -> None:
         import textwrap

         import yaml
         from llama_stack.distribution.distribution import get_provider_registry
         from prompt_toolkit import prompt

@@ -256,7 +258,9 @@ class StackBuild(Subcommand):
             providers = dict()
             for api, providers_for_api in get_provider_registry().items():
                 available_providers = [
-                    x for x in providers_for_api.keys() if x != "remote"
+                    x
+                    for x in providers_for_api.keys()
+                    if x not in ("remote", "remote::sample")
                 ]
                 api_provider = prompt(
                     "> Enter provider for API {}: ".format(api.value),

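For context, the comprehension above now hides both the generic "remote" key and the "remote::sample" placeholder when the interactive build prompts for a provider. A minimal, self-contained sketch of that filtering (the example registry contents are made up for illustration):

# Illustrative stand-in for get_provider_registry()[api]; keys are provider types.
providers_for_api = {
    "meta-reference": "...spec...",
    "remote": "...spec...",
    "remote::sample": "...spec...",
}

# Same filter as in the hunk above: remote/placeholder entries are not offered.
available_providers = [
    x
    for x in providers_for_api.keys()
    if x not in ("remote", "remote::sample")
]

print(available_providers)  # ['meta-reference']
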
@@ -7,6 +7,7 @@
 from .config import DatabricksImplConfig
 from .databricks import DatabricksInferenceAdapter


 async def get_adapter_impl(config: DatabricksImplConfig, _deps):
     assert isinstance(
         config, DatabricksImplConfig

@@ -4,7 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Optional

 from llama_models.schema_utils import json_schema_type
 from pydantic import BaseModel, Field

@@ -48,7 +48,14 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
     async def shutdown(self) -> None:
         pass

-    def completion(self, request: CompletionRequest) -> AsyncGenerator:
+    def completion(
+        self,
+        model: str,
+        content: InterleavedTextMedia,
+        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        stream: Optional[bool] = False,
+        logprobs: Optional[LogProbConfig] = None,
+    ) -> AsyncGenerator:
         raise NotImplementedError()

     def chat_completion(

@@ -63,13 +63,12 @@ async def llm_rag_query_generator(

     model = config.model
     message = UserMessage(content=content)
-    response = inference_api.chat_completion(
+    response = await inference_api.chat_completion(
         model=model,
         messages=[message],
         stream=False,
     )

-    async for chunk in response:
-        query = chunk.completion_message.content
+    query = response.completion_message.content

     return query

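The hunk above switches the RAG query generator to the non-streaming call path: the coroutine is awaited and the message content is read directly instead of iterating chunks. A self-contained sketch of the same calling pattern (FakeInferenceAPI and its response types are stand-ins, not the real llama-stack classes):

import asyncio
from dataclasses import dataclass

# Stand-ins that mimic the shape of the response used above.
@dataclass
class CompletionMessage:
    content: str

@dataclass
class ChatCompletionResponse:
    completion_message: CompletionMessage

class FakeInferenceAPI:
    async def chat_completion(self, model, messages, stream=False):
        # With stream=False a single response object comes back, so the caller
        # awaits it rather than consuming a stream of chunks.
        return ChatCompletionResponse(CompletionMessage("a rewritten rag query"))

async def main():
    api = FakeInferenceAPI()
    response = await api.chat_completion(model="m", messages=["q"], stream=False)
    query = response.completion_message.content
    print(query)

asyncio.run(main())
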
@@ -16,7 +16,7 @@ from llama_stack.apis.agents import *  # noqa: F403
 from ..agents import (
     AGENT_INSTANCES_BY_ID,
     MetaReferenceAgentsImpl,
-    MetaReferenceImplConfig,
+    MetaReferenceInferenceConfig,
 )


@@ -166,7 +166,7 @@ def mock_memory_api():
 @pytest.fixture
 async def chat_agent(mock_inference_api, mock_safety_api, mock_memory_api):
     impl = MetaReferenceAgentsImpl(
-        config=MetaReferenceImplConfig(),
+        config=MetaReferenceInferenceConfig(),
         inference_api=mock_inference_api,
         safety_api=mock_safety_api,
         memory_api=mock_memory_api,

|
|
@ -4,16 +4,17 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from .config import MetaReferenceImplConfig # noqa
|
from typing import Union
|
||||||
|
|
||||||
|
from .config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
|
||||||
|
|
||||||
|
|
||||||
async def get_provider_impl(config: MetaReferenceImplConfig, _deps):
|
async def get_provider_impl(
|
||||||
|
config: Union[MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig],
|
||||||
|
_deps,
|
||||||
|
):
|
||||||
from .inference import MetaReferenceInferenceImpl
|
from .inference import MetaReferenceInferenceImpl
|
||||||
|
|
||||||
assert isinstance(
|
|
||||||
config, MetaReferenceImplConfig
|
|
||||||
), f"Unexpected config type: {type(config)}"
|
|
||||||
|
|
||||||
impl = MetaReferenceInferenceImpl(config)
|
impl = MetaReferenceInferenceImpl(config)
|
||||||
await impl.initialize()
|
await impl.initialize()
|
||||||
return impl
|
return impl
|
||||||
|
|
|
@@ -15,12 +15,11 @@ from pydantic import BaseModel, Field, field_validator
 from llama_stack.providers.utils.inference import supported_inference_models


-class MetaReferenceImplConfig(BaseModel):
+class MetaReferenceInferenceConfig(BaseModel):
     model: str = Field(
         default="Llama3.1-8B-Instruct",
         description="Model descriptor from `llama model list`",
     )
-    quantization: Optional[QuantizationConfig] = None
     torch_seed: Optional[int] = None
     max_seq_len: int = 4096
     max_batch_size: int = 1

@@ -38,9 +37,9 @@ class MetaReferenceImplConfig(BaseModel):

     @property
     def model_parallel_size(self) -> int:
-        # HACK ALERT: this will be fixed when we move inference configuration
-        # to ModelsRegistry and we can explicitly ask for `model_parallel_size`
-        # as configuration there
         resolved = resolve_model(self.model)
-        assert resolved is not None
         return resolved.pth_file_count


+class MetaReferenceQuantizedInferenceConfig(MetaReferenceInferenceConfig):
+    quantization: QuantizationConfig

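Net effect of the two hunks above: quantization moves off the base config and becomes a required field on a dedicated subclass. A rough sketch of that shape (field list trimmed; QuantizationConfig is stubbed here rather than imported from llama-stack):

from pydantic import BaseModel

class QuantizationConfig(BaseModel):
    # stand-in for the real llama-stack QuantizationConfig
    type: str = "fp8"

class MetaReferenceInferenceConfig(BaseModel):
    model: str = "Llama3.1-8B-Instruct"
    max_seq_len: int = 4096
    max_batch_size: int = 1

class MetaReferenceQuantizedInferenceConfig(MetaReferenceInferenceConfig):
    # quantization is now required on the quantized variant instead of being
    # an Optional field on the base config
    quantization: QuantizationConfig

print(MetaReferenceInferenceConfig())
print(MetaReferenceQuantizedInferenceConfig(quantization=QuantizationConfig()))
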
@@ -11,9 +11,8 @@ import json
 import os
 import sys
 import time
-from dataclasses import dataclass
 from pathlib import Path
-from typing import Generator, List, Optional
+from typing import Generator, List, Optional, Union

 import torch
 import torch.nn.functional as F

@@ -36,14 +35,12 @@ from llama_models.llama3.reference_impl.multimodal.model import (
 )
 from llama_models.sku_list import resolve_model

-from llama_stack.apis.inference import QuantizationType
-
-from llama_stack.distribution.utils.model_utils import model_local_dir
-
 from pydantic import BaseModel
 from termcolor import cprint

-from .config import MetaReferenceImplConfig
+from llama_stack.distribution.utils.model_utils import model_local_dir
+
+from .config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig


 def model_checkpoint_dir(model) -> str:

@@ -68,7 +65,11 @@ class TokenResult(BaseModel):

 class Llama:
     @staticmethod
-    def build(config: MetaReferenceImplConfig):
+    def build(
+        config: Union[
+            MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
+        ]
+    ):
         """
         Build a Llama instance by initializing and loading a model checkpoint.

@@ -78,15 +79,6 @@ class Llama:
         """
         model = resolve_model(config.model)

-        if (
-            config.quantization
-            and config.quantization.type == QuantizationType.fp8.value
-        ):
-            from .quantization.loader import is_fbgemm_available
-
-            if not is_fbgemm_available():
-                raise ImportError("fbgemm-gpu is required for FP8 quantization")
-
         if not torch.distributed.is_initialized():
             torch.distributed.init_process_group("nccl")

@@ -134,12 +126,7 @@ class Llama:
             model_args.vocab_size == tokenizer.n_words
         ), f"model_args vocab = {model_args.vocab_size} but tokenizer vocab = {tokenizer.n_words}"

-        fp8 = (
-            config.quantization
-            and config.quantization.type == QuantizationType.fp8.value
-        )
-
-        if fp8:
+        if isinstance(config, MetaReferenceQuantizedInferenceConfig):
             from .quantization.loader import convert_to_quantized_model

             # load on CPU in bf16 so that fp8 conversion does not find an

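The corresponding runtime change in Llama.build: whether to take the quantized path is now decided by the config's type rather than by inspecting an optional quantization field. A small illustrative sketch of that dispatch (plain classes used as stand-ins for the real configs):

class MetaReferenceInferenceConfig:
    model: str = "Llama3.1-8B-Instruct"

class MetaReferenceQuantizedInferenceConfig(MetaReferenceInferenceConfig):
    pass

def build(config: MetaReferenceInferenceConfig) -> str:
    # Quantization is implied by the config type, so a plain isinstance()
    # replaces the old check on an Optional quantization field.
    if isinstance(config, MetaReferenceQuantizedInferenceConfig):
        return "load weights and convert to a quantized model"
    return "load weights as-is"

print(build(MetaReferenceInferenceConfig()))           # load weights as-is
print(build(MetaReferenceQuantizedInferenceConfig()))  # quantized path
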
@@ -17,7 +17,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_messages,
 )

-from .config import MetaReferenceImplConfig
+from .config import MetaReferenceInferenceConfig
 from .model_parallel import LlamaModelParallelGenerator

 # there's a single model parallel process running serving the model. for now,

@@ -26,7 +26,7 @@ SEMAPHORE = asyncio.Semaphore(1)


 class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
-    def __init__(self, config: MetaReferenceImplConfig) -> None:
+    def __init__(self, config: MetaReferenceInferenceConfig) -> None:
         self.config = config
         model = resolve_model(config.model)
         if model is None:

@@ -6,7 +6,6 @@

 import os
 from copy import deepcopy
-from dataclasses import dataclass
 from functools import partial
 from typing import Generator, List, Optional


@@ -15,7 +14,7 @@ from llama_models.llama3.api.datatypes import Message, ToolPromptFormat
 from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_models.sku_list import resolve_model

-from .config import MetaReferenceImplConfig
+from .config import MetaReferenceInferenceConfig
 from .generation import Llama, model_checkpoint_dir
 from .parallel_utils import InferenceArgs, ModelParallelProcessGroup


@@ -36,7 +35,7 @@ class ModelRunner:
     )


-def init_model_cb(config: MetaReferenceImplConfig):
+def init_model_cb(config: MetaReferenceInferenceConfig):
     llama = Llama.build(config)
     return ModelRunner(llama)


@@ -52,7 +51,7 @@ class LlamaModelParallelGenerator:
     clear at the callsite why we need to use a context manager.
     """

-    def __init__(self, config: MetaReferenceImplConfig):
+    def __init__(self, config: MetaReferenceInferenceConfig):
         self.config = config
         self.model = resolve_model(self.config.model)
         # this is a hack because Agent's loop uses this to tokenize and check if input is too long

@@ -11,7 +11,7 @@ import tempfile
 import time
 import uuid
 from enum import Enum
-from typing import Any, Callable, Generator, List, Literal, Optional, Union
+from typing import Callable, Generator, List, Literal, Optional, Union

 import torch


@@ -317,7 +317,7 @@ def start_model_parallel_process(

     request_socket.send(encode_msg(ReadyRequest()))
     response = request_socket.recv()
-    print(f"Finished model load {response}")
+    print("Loaded model...")

     return request_socket, process

@@ -22,19 +22,10 @@ from torch import Tensor
 from llama_stack.apis.inference import QuantizationType

 from llama_stack.providers.impls.meta_reference.inference.config import (
-    MetaReferenceImplConfig,
+    MetaReferenceQuantizedInferenceConfig,
 )


-def is_fbgemm_available() -> bool:
-    try:
-        import fbgemm_gpu.experimental.gen_ai  # noqa: F401
-
-        return True
-    except ImportError:
-        return False
-
-
 def swiglu_wrapper(
     self,
     x: Tensor,

@@ -47,7 +38,7 @@ def swiglu_wrapper(

 def convert_to_quantized_model(
     model: Transformer,
-    config: MetaReferenceImplConfig,
+    config: MetaReferenceQuantizedInferenceConfig,
     fp8_activation_scale_ub: Optional[float] = 1200.0,
 ) -> Transformer:
     if config.quantization.type == QuantizationType.bf16.value:

@@ -170,7 +170,7 @@ class LlamaGuardShield(ShieldBase):
         for i in range(1, len(messages)):
             if messages[i].role == messages[i - 1].role:
                 raise ValueError(
-                    f"Messages must alternate between user and assistant. Message {i} has the same role as message {i-1}"
+                    f"Messages must alternate between user and assistant. Message {i} has the same role as message {i - 1}"
                 )
         return messages

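The change in this hunk is cosmetic (spacing inside the f-string), but the surrounding check is easier to read in one piece; here is a self-contained sketch of the alternation validation, with Message as a stand-in dataclass rather than the real llama-stack type:

from dataclasses import dataclass

@dataclass
class Message:
    role: str
    content: str

def validate_alternation(messages):
    # Adjacent messages must switch roles, mirroring the loop shown above.
    for i in range(1, len(messages)):
        if messages[i].role == messages[i - 1].role:
            raise ValueError(
                f"Messages must alternate between user and assistant. "
                f"Message {i} has the same role as message {i - 1}"
            )
    return messages

validate_alternation([Message("user", "hi"), Message("assistant", "hello")])  # ok
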
@@ -1,3 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
 from typing import Any

 from .config import VLLMConfig

@@ -14,6 +14,21 @@ def available_providers() -> List[ProviderSpec]:
         InlineProviderSpec(
             api=Api.inference,
             provider_type="meta-reference",
+            pip_packages=[
+                "accelerate",
+                "blobfile",
+                "fairscale",
+                "torch",
+                "torchvision",
+                "transformers",
+                "zmq",
+            ],
+            module="llama_stack.providers.impls.meta_reference.inference",
+            config_class="llama_stack.providers.impls.meta_reference.inference.MetaReferenceInferenceConfig",
+        ),
+        InlineProviderSpec(
+            api=Api.inference,
+            provider_type="meta-reference-quantized",
             pip_packages=[
                 "accelerate",
                 "blobfile",

@@ -25,7 +40,7 @@ def available_providers() -> List[ProviderSpec]:
                 "zmq",
             ],
             module="llama_stack.providers.impls.meta_reference.inference",
-            config_class="llama_stack.providers.impls.meta_reference.inference.MetaReferenceImplConfig",
+            config_class="llama_stack.providers.impls.meta_reference.inference.MetaReferenceQuantizedInferenceConfig",
         ),
         remote_provider_spec(
             api=Api.inference,

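After this hunk the registry advertises two inline inference providers, "meta-reference" and "meta-reference-quantized", differing mainly in their config_class. A reduced sketch of that structure (ProviderSpec here is a stand-in dataclass, not the real llama-stack type):

from dataclasses import dataclass
from typing import List

@dataclass
class ProviderSpec:  # stand-in for llama-stack's InlineProviderSpec
    provider_type: str
    config_class: str

def available_inference_providers() -> List[ProviderSpec]:
    base = "llama_stack.providers.impls.meta_reference.inference"
    return [
        ProviderSpec("meta-reference", f"{base}.MetaReferenceInferenceConfig"),
        ProviderSpec("meta-reference-quantized", f"{base}.MetaReferenceQuantizedInferenceConfig"),
    ]

for spec in available_inference_providers():
    print(spec.provider_type, "->", spec.config_class)
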
@@ -2,7 +2,7 @@ blobfile
 fire
 httpx
 huggingface-hub
-llama-models>=0.0.41
+llama-models>=0.0.42
 prompt-toolkit
 python-dotenv
 pydantic>=2

setup.py (2 changes)

@@ -16,7 +16,7 @@ def read_requirements():

 setup(
     name="llama_stack",
-    version="0.0.41",
+    version="0.0.42",
     author="Meta Llama",
     author_email="llama-oss@meta.com",
     description="Llama Stack",