Merge branch 'main' into evals_new

Xi Yan committed 2024-10-15 10:20:03 -07:00 (committed by GitHub)
commit 2c23a66300
24 changed files with 112 additions and 120 deletions

.flake8

@@ -21,11 +21,11 @@ ignore =
     optional-ascii-coding = True
 exclude =
     ./.git,
-    ./docs
-    ./build
+    ./docs/*,
+    ./build,
     ./scripts,
     ./venv,
-    *.pyi
-    .pre-commit-config.yaml
-    *.md
+    *.pyi,
+    .pre-commit-config.yaml,
+    *.md,
     .flake8

.github/workflows/pre-commit.yml (new file)

@@ -0,0 +1,25 @@
+name: Pre-commit
+
+on:
+  pull_request:
+  push:
+    branches: [main]
+
+jobs:
+  pre-commit:
+    runs-on: ubuntu-latest
+
+    steps:
+      - name: Checkout code
+        uses: actions/checkout@d632683dd7b4114ad314bca15554477dd762a938 # v4.2.0
+
+      - name: Set up Python
+        uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # v5.2.0
+        with:
+          python-version: '3.11'
+          cache: pip
+          cache-dependency-path: |
+            **/requirements*.txt
+            .pre-commit-config.yaml
+
+      - uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd #v3.0.1

@@ -1,3 +1,5 @@
+<img src="https://github.com/user-attachments/assets/2fedfe0f-6df7-4441-98b2-87a1fd95ee1c" width="300" title="Llama Stack Logo" alt="Llama Stack Logo"/>
+
 # Llama Stack
 [![PyPI version](https://img.shields.io/pypi/v/llama_stack.svg)](https://pypi.org/project/llama_stack/)
@@ -97,7 +99,7 @@ The `llama` CLI makes it easy to work with the Llama Stack set of tools. Please
 | **Language** | **Client SDK** | **Package** |
 | :----: | :----: | :----: |
 | Python | [llama-stack-client-python](https://github.com/meta-llama/llama-stack-client-python) | [![PyPI version](https://img.shields.io/pypi/v/llama_stack_client.svg)](https://pypi.org/project/llama_stack_client/) |
-| Swift | [llama-stack-client-swift](https://github.com/meta-llama/llama-stack-client-swift) | |
+| Swift | [llama-stack-client-swift](https://github.com/meta-llama/llama-stack-client-swift) | [![Swift Package Index](https://img.shields.io/endpoint?url=https%3A%2F%2Fswiftpackageindex.com%2Fapi%2Fpackages%2Fmeta-llama%2Fllama-stack-client-swift%2Fbadge%3Ftype%3Dswift-versions)](https://swiftpackageindex.com/meta-llama/llama-stack-client-swift) |
 | Node | [llama-stack-client-node](https://github.com/meta-llama/llama-stack-client-node) | [![NPM version](https://img.shields.io/npm/v/llama-stack-client.svg)](https://npmjs.org/package/llama-stack-client) |
 | Kotlin | [llama-stack-client-kotlin](https://github.com/meta-llama/llama-stack-client-kotlin) | |

@@ -73,7 +73,7 @@ docker run -it -p 5000:5000 -v ~/.llama:/root/.llama --gpus=all llamastack-local
 ```
 > [!NOTE]
 > `~/.llama` should be the path containing downloaded weights of Llama models.

 #### Via conda

@@ -21,7 +21,7 @@
     "info": {
         "title": "[DRAFT] Llama Stack Specification",
         "version": "0.0.1",
-        "description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-10-15 10:15:15.195382"
+        "description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-10-10 15:29:56.831109"
     },
     "servers": [
         {
@@ -6228,49 +6228,7 @@
     ],
     "tags": [
         {
-            "name": "Inference"
-        },
-        {
-            "name": "PostTraining"
-        },
-        {
-            "name": "Agents"
-        },
-        {
-            "name": "MemoryBanks"
-        },
-        {
-            "name": "Inspect"
-        },
-        {
-            "name": "Models"
-        },
-        {
-            "name": "Safety"
-        },
-        {
-            "name": "Evals"
-        },
-        {
-            "name": "BatchInference"
-        },
-        {
-            "name": "Shields"
-        },
-        {
-            "name": "SyntheticDataGeneration"
-        },
-        {
-            "name": "Telemetry"
-        },
-        {
-            "name": "RewardScoring"
-        },
-        {
-            "name": "Datasets"
-        },
-        {
-            "name": "Memory"
         },
         {
             "name": "BuiltinTool",

@@ -2679,7 +2679,6 @@ info:
   description: "This is the specification of the llama stack that provides\n \
     \ a set of endpoints and their corresponding interfaces that are tailored\
     \ to\n best leverage Llama Models. The specification is still in\
-    \ draft and subject to change.\n Generated at 2024-10-15 10:15:15.195382"
   title: '[DRAFT] Llama Stack Specification'
   version: 0.0.1
 jsonSchemaDialect: https://json-schema.org/draft/2020-12/schema

@@ -149,6 +149,7 @@ class StackBuild(Subcommand):
     def _run_template_list_cmd(self, args: argparse.Namespace) -> None:
         import json
+        from llama_stack.cli.table import print_table

         # eventually, this should query a registry at llama.meta.com/llamastack/distributions
@@ -175,6 +176,7 @@ class StackBuild(Subcommand):
     def _run_stack_build_command(self, args: argparse.Namespace) -> None:
         import textwrap

         import yaml
+        from llama_stack.distribution.distribution import get_provider_registry
         from prompt_toolkit import prompt
@@ -256,7 +258,9 @@ class StackBuild(Subcommand):
         providers = dict()
         for api, providers_for_api in get_provider_registry().items():
             available_providers = [
-                x for x in providers_for_api.keys() if x != "remote"
+                x
+                for x in providers_for_api.keys()
+                if x not in ("remote", "remote::sample")
             ]
             api_provider = prompt(
                 "> Enter provider for API {}: ".format(api.value),

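As a side note on the build-flow hunk above: the interactive provider prompt now hides the `remote::sample` stub in addition to the generic `remote` entry. A minimal sketch of the filtering, using made-up registry keys rather than the real provider registry:

```python
# Made-up registry shape for illustration: provider_type -> spec object.
providers_for_api = {
    "meta-reference": object(),
    "remote": object(),
    "remote::sample": object(),
    "remote::ollama": object(),
}

# Before: only the generic "remote" key was hidden from the prompt.
old_choices = [x for x in providers_for_api.keys() if x != "remote"]

# After: the sample stub is hidden as well.
new_choices = [
    x
    for x in providers_for_api.keys()
    if x not in ("remote", "remote::sample")
]

print(old_choices)  # ['meta-reference', 'remote::sample', 'remote::ollama']
print(new_choices)  # ['meta-reference', 'remote::ollama']
```
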
@@ -7,10 +7,11 @@
 from .config import DatabricksImplConfig
 from .databricks import DatabricksInferenceAdapter


 async def get_adapter_impl(config: DatabricksImplConfig, _deps):
     assert isinstance(
         config, DatabricksImplConfig
     ), f"Unexpected config type: {type(config)}"
     impl = DatabricksInferenceAdapter(config)
     await impl.initialize()
     return impl

@@ -4,7 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Optional

 from llama_models.schema_utils import json_schema_type
 from pydantic import BaseModel, Field
@@ -19,4 +18,4 @@ class DatabricksImplConfig(BaseModel):
     api_token: str = Field(
         default=None,
         description="The Databricks API token",
     )

@@ -48,7 +48,14 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
     async def shutdown(self) -> None:
         pass

-    def completion(self, request: CompletionRequest) -> AsyncGenerator:
+    def completion(
+        self,
+        model: str,
+        content: InterleavedTextMedia,
+        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        stream: Optional[bool] = False,
+        logprobs: Optional[LogProbConfig] = None,
+    ) -> AsyncGenerator:
         raise NotImplementedError()

     def chat_completion(

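For reference, the Databricks adapter's `completion` still raises `NotImplementedError`; the hunk above only flattens its signature from a single `CompletionRequest` into explicit keyword arguments. A minimal sketch of that shape, using loose placeholder types instead of the real `InterleavedTextMedia`, `SamplingParams`, and `LogProbConfig`:

```python
from typing import Any, AsyncGenerator, Optional


class _DatabricksAdapterSketch:
    """Placeholder showing only the new method shape; not the real adapter."""

    def completion(
        self,
        model: str,
        content: Any,                       # InterleavedTextMedia in the real API
        sampling_params: Optional[Any] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[Any] = None,
    ) -> AsyncGenerator:
        raise NotImplementedError()


try:
    _DatabricksAdapterSketch().completion(model="some-model", content="hello")
except NotImplementedError:
    print("completion is declared but not implemented for this adapter")
```
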
@@ -63,13 +63,12 @@ async def llm_rag_query_generator(
     model = config.model
     message = UserMessage(content=content)

-    response = inference_api.chat_completion(
+    response = await inference_api.chat_completion(
         model=model,
         messages=[message],
         stream=False,
     )

-    async for chunk in response:
-        query = chunk.completion_message.content
+    query = response.completion_message.content

     return query

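In other words, with `stream=False` the inference API hands back a single response object rather than a stream of chunks, so the call has to be awaited and the message read directly. A minimal sketch of the corrected pattern; the response classes here are simplified stand-ins, not the real llama_stack types:

```python
import asyncio
from dataclasses import dataclass


# Simplified stand-ins for the real llama_stack response types.
@dataclass
class _CompletionMessage:
    content: str


@dataclass
class _ChatCompletionResponse:
    completion_message: _CompletionMessage


class _FakeInferenceAPI:
    async def chat_completion(self, model: str, messages: list, stream: bool = False):
        # Non-streaming path: one awaitable response object, not an iterator of chunks.
        return _ChatCompletionResponse(_CompletionMessage("rewritten query"))


async def main() -> None:
    api = _FakeInferenceAPI()

    # Old pattern: treated the result as a stream of chunks even with stream=False.
    #   response = api.chat_completion(...)
    #   async for chunk in response:
    #       query = chunk.completion_message.content

    # New pattern: await the coroutine and read the message off the single response.
    response = await api.chat_completion(model="some-model", messages=[], stream=False)
    query = response.completion_message.content
    print(query)


asyncio.run(main())
```
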
@@ -16,7 +16,7 @@ from llama_stack.apis.agents import * # noqa: F403
 from ..agents import (
     AGENT_INSTANCES_BY_ID,
     MetaReferenceAgentsImpl,
-    MetaReferenceImplConfig,
+    MetaReferenceInferenceConfig,
 )
@@ -166,7 +166,7 @@ def mock_memory_api():
 @pytest.fixture
 async def chat_agent(mock_inference_api, mock_safety_api, mock_memory_api):
     impl = MetaReferenceAgentsImpl(
-        config=MetaReferenceImplConfig(),
+        config=MetaReferenceInferenceConfig(),
         inference_api=mock_inference_api,
         safety_api=mock_safety_api,
         memory_api=mock_memory_api,

@@ -4,16 +4,17 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from .config import MetaReferenceImplConfig # noqa
+from typing import Union
+from .config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig


-async def get_provider_impl(config: MetaReferenceImplConfig, _deps):
+async def get_provider_impl(
+    config: Union[MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig],
+    _deps,
+):
     from .inference import MetaReferenceInferenceImpl

-    assert isinstance(
-        config, MetaReferenceImplConfig
-    ), f"Unexpected config type: {type(config)}"
-
     impl = MetaReferenceInferenceImpl(config)
     await impl.initialize()
     return impl

@@ -15,12 +15,11 @@ from pydantic import BaseModel, Field, field_validator
 from llama_stack.providers.utils.inference import supported_inference_models


-class MetaReferenceImplConfig(BaseModel):
+class MetaReferenceInferenceConfig(BaseModel):
     model: str = Field(
         default="Llama3.1-8B-Instruct",
         description="Model descriptor from `llama model list`",
     )
-    quantization: Optional[QuantizationConfig] = None
     torch_seed: Optional[int] = None
     max_seq_len: int = 4096
     max_batch_size: int = 1
@@ -38,9 +37,9 @@ class MetaReferenceImplConfig(BaseModel):
     @property
     def model_parallel_size(self) -> int:
-        # HACK ALERT: this will be fixed when we move inference configuration
-        # to ModelsRegistry and we can explicitly ask for `model_parallel_size`
-        # as configuration there
         resolved = resolve_model(self.model)
-        assert resolved is not None
         return resolved.pth_file_count
+
+
+class MetaReferenceQuantizedInferenceConfig(MetaReferenceInferenceConfig):
+    quantization: QuantizationConfig

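Read together, the two hunks above move quantization out of the base config and into a required field on a new subclass, so a plain meta-reference provider can no longer carry quantization settings by accident. A rough sketch of the relationship using simplified pydantic models (not the real llama_stack classes):

```python
from typing import Optional

from pydantic import BaseModel, ValidationError


# Simplified stand-ins for the provider config classes.
class QuantizationConfig(BaseModel):
    type: str = "fp8"


class MetaReferenceInferenceConfig(BaseModel):
    model: str = "Llama3.1-8B-Instruct"
    torch_seed: Optional[int] = None
    max_seq_len: int = 4096
    max_batch_size: int = 1


class MetaReferenceQuantizedInferenceConfig(MetaReferenceInferenceConfig):
    # Required here: a quantized provider cannot be configured without it.
    quantization: QuantizationConfig


plain = MetaReferenceInferenceConfig()  # no quantization field at all anymore
quantized = MetaReferenceQuantizedInferenceConfig(quantization=QuantizationConfig(type="fp8"))

try:
    MetaReferenceQuantizedInferenceConfig()  # missing the now-mandatory field
except ValidationError as e:
    print(f"{len(e.errors())} validation error: quantization is required")
```
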
@@ -11,9 +11,8 @@ import json
 import os
 import sys
 import time
-from dataclasses import dataclass
 from pathlib import Path
-from typing import Generator, List, Optional
+from typing import Generator, List, Optional, Union

 import torch
 import torch.nn.functional as F
@@ -36,14 +35,12 @@ from llama_models.llama3.reference_impl.multimodal.model import (
 )
 from llama_models.sku_list import resolve_model

-from llama_stack.apis.inference import QuantizationType
-from llama_stack.distribution.utils.model_utils import model_local_dir
 from pydantic import BaseModel
 from termcolor import cprint

-from .config import MetaReferenceImplConfig
+from llama_stack.distribution.utils.model_utils import model_local_dir
+from .config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig


 def model_checkpoint_dir(model) -> str:
@@ -68,7 +65,11 @@ class TokenResult(BaseModel):
 class Llama:
     @staticmethod
-    def build(config: MetaReferenceImplConfig):
+    def build(
+        config: Union[
+            MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
+        ]
+    ):
         """
         Build a Llama instance by initializing and loading a model checkpoint.
@@ -78,15 +79,6 @@ class Llama:
         """
         model = resolve_model(config.model)

-        if (
-            config.quantization
-            and config.quantization.type == QuantizationType.fp8.value
-        ):
-            from .quantization.loader import is_fbgemm_available
-
-            if not is_fbgemm_available():
-                raise ImportError("fbgemm-gpu is required for FP8 quantization")
-
         if not torch.distributed.is_initialized():
             torch.distributed.init_process_group("nccl")
@@ -134,12 +126,7 @@ class Llama:
             model_args.vocab_size == tokenizer.n_words
         ), f"model_args vocab = {model_args.vocab_size} but tokenizer vocab = {tokenizer.n_words}"

-        fp8 = (
-            config.quantization
-            and config.quantization.type == QuantizationType.fp8.value
-        )
-
-        if fp8:
+        if isinstance(config, MetaReferenceQuantizedInferenceConfig):
            from .quantization.loader import convert_to_quantized_model

            # load on CPU in bf16 so that fp8 conversion does not find an

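The last two hunks in this file change how `Llama.build` decides to quantize: the old code inspected `config.quantization.type` for FP8 (and probed for fbgemm up front), while the new code keys entirely off the config's class. A small sketch of the two gates, with placeholder config types rather than the real ones:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class _Quantization:
    type: str  # e.g. "fp8" or "bf16"


@dataclass
class _InferenceConfig:
    model: str = "Llama3.1-8B-Instruct"
    quantization: Optional[_Quantization] = None  # stand-in field so both gates can be compared


@dataclass
class _QuantizedInferenceConfig(_InferenceConfig):
    pass


def old_gate(config: _InferenceConfig) -> bool:
    # Old: a value check that only matched the fp8 case.
    return bool(config.quantization and config.quantization.type == "fp8")


def new_gate(config: _InferenceConfig) -> bool:
    # New: any quantized config class routes through convert_to_quantized_model().
    return isinstance(config, _QuantizedInferenceConfig)


cfg = _QuantizedInferenceConfig(quantization=_Quantization(type="bf16"))
print(old_gate(cfg))  # False -- bf16 never tripped the old fp8-only check
print(new_gate(cfg))  # True  -- class-based dispatch covers the quantized config as a whole
```
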
@@ -17,7 +17,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_messages,
 )

-from .config import MetaReferenceImplConfig
+from .config import MetaReferenceInferenceConfig
 from .model_parallel import LlamaModelParallelGenerator

 # there's a single model parallel process running serving the model. for now,
@@ -26,7 +26,7 @@ SEMAPHORE = asyncio.Semaphore(1)

 class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
-    def __init__(self, config: MetaReferenceImplConfig) -> None:
+    def __init__(self, config: MetaReferenceInferenceConfig) -> None:
         self.config = config
         model = resolve_model(config.model)
         if model is None:

@@ -6,7 +6,6 @@
 import os
 from copy import deepcopy
-from dataclasses import dataclass
 from functools import partial
 from typing import Generator, List, Optional
@@ -15,7 +14,7 @@ from llama_models.llama3.api.datatypes import Message, ToolPromptFormat
 from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_models.sku_list import resolve_model

-from .config import MetaReferenceImplConfig
+from .config import MetaReferenceInferenceConfig
 from .generation import Llama, model_checkpoint_dir
 from .parallel_utils import InferenceArgs, ModelParallelProcessGroup
@@ -36,7 +35,7 @@ class ModelRunner:
     )


-def init_model_cb(config: MetaReferenceImplConfig):
+def init_model_cb(config: MetaReferenceInferenceConfig):
     llama = Llama.build(config)
     return ModelRunner(llama)
@@ -52,7 +51,7 @@ class LlamaModelParallelGenerator:
     clear at the callsite why we need to use a context manager.
     """

-    def __init__(self, config: MetaReferenceImplConfig):
+    def __init__(self, config: MetaReferenceInferenceConfig):
         self.config = config
         self.model = resolve_model(self.config.model)
         # this is a hack because Agent's loop uses this to tokenize and check if input is too long

@@ -11,7 +11,7 @@ import tempfile
 import time
 import uuid
 from enum import Enum
-from typing import Any, Callable, Generator, List, Literal, Optional, Union
+from typing import Callable, Generator, List, Literal, Optional, Union

 import torch
@@ -317,7 +317,7 @@ def start_model_parallel_process(
     request_socket.send(encode_msg(ReadyRequest()))
     response = request_socket.recv()
-    print(f"Finished model load {response}")
+    print("Loaded model...")

     return request_socket, process

@@ -22,19 +22,10 @@ from torch import Tensor
 from llama_stack.apis.inference import QuantizationType

 from llama_stack.providers.impls.meta_reference.inference.config import (
-    MetaReferenceImplConfig,
+    MetaReferenceQuantizedInferenceConfig,
 )


-def is_fbgemm_available() -> bool:
-    try:
-        import fbgemm_gpu.experimental.gen_ai # noqa: F401
-
-        return True
-    except ImportError:
-        return False
-
-
 def swiglu_wrapper(
     self,
     x: Tensor,
@@ -47,7 +38,7 @@ def swiglu_wrapper(
 def convert_to_quantized_model(
     model: Transformer,
-    config: MetaReferenceImplConfig,
+    config: MetaReferenceQuantizedInferenceConfig,
     fp8_activation_scale_ub: Optional[float] = 1200.0,
 ) -> Transformer:
     if config.quantization.type == QuantizationType.bf16.value:

@@ -170,7 +170,7 @@ class LlamaGuardShield(ShieldBase):
         for i in range(1, len(messages)):
             if messages[i].role == messages[i - 1].role:
                 raise ValueError(
-                    f"Messages must alternate between user and assistant. Message {i} has the same role as message {i-1}"
+                    f"Messages must alternate between user and assistant. Message {i} has the same role as message {i - 1}"
                 )
         return messages

@@ -1,3 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
 from typing import Any

 from .config import VLLMConfig

@@ -14,6 +14,21 @@ def available_providers() -> List[ProviderSpec]:
         InlineProviderSpec(
             api=Api.inference,
             provider_type="meta-reference",
+            pip_packages=[
+                "accelerate",
+                "blobfile",
+                "fairscale",
+                "torch",
+                "torchvision",
+                "transformers",
+                "zmq",
+            ],
+            module="llama_stack.providers.impls.meta_reference.inference",
+            config_class="llama_stack.providers.impls.meta_reference.inference.MetaReferenceInferenceConfig",
+        ),
+        InlineProviderSpec(
+            api=Api.inference,
+            provider_type="meta-reference-quantized",
             pip_packages=[
                 "accelerate",
                 "blobfile",
@@ -25,7 +40,7 @@ def available_providers() -> List[ProviderSpec]:
                 "zmq",
             ],
             module="llama_stack.providers.impls.meta_reference.inference",
-            config_class="llama_stack.providers.impls.meta_reference.inference.MetaReferenceImplConfig",
+            config_class="llama_stack.providers.impls.meta_reference.inference.MetaReferenceQuantizedInferenceConfig",
         ),
         remote_provider_spec(
             api=Api.inference,

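With the hunks above, the registry now advertises two inline inference providers that share a module and pip packages but point at different config classes. A sketch of what that mapping boils down to, using a simplified placeholder for `InlineProviderSpec`:

```python
from dataclasses import dataclass
from typing import List


@dataclass
class InlineProviderSpec:  # simplified placeholder for the real ProviderSpec type
    provider_type: str
    module: str
    config_class: str


def available_inference_providers() -> List[InlineProviderSpec]:
    module = "llama_stack.providers.impls.meta_reference.inference"
    return [
        InlineProviderSpec(
            provider_type="meta-reference",
            module=module,
            config_class=f"{module}.MetaReferenceInferenceConfig",
        ),
        InlineProviderSpec(
            provider_type="meta-reference-quantized",
            module=module,
            config_class=f"{module}.MetaReferenceQuantizedInferenceConfig",
        ),
    ]


for spec in available_inference_providers():
    print(f"{spec.provider_type:28s} -> {spec.config_class}")
```
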
@@ -2,7 +2,7 @@ blobfile
 fire
 httpx
 huggingface-hub
-llama-models>=0.0.41
+llama-models>=0.0.42
 prompt-toolkit
 python-dotenv
 pydantic>=2

@@ -16,7 +16,7 @@ def read_requirements():
 setup(
     name="llama_stack",
-    version="0.0.41",
+    version="0.0.42",
     author="Meta Llama",
     author_email="llama-oss@meta.com",
     description="Llama Stack",