Merge branch 'main' into evals_new

Xi Yan 2024-10-15 10:20:03 -07:00 committed by GitHub
commit 2c23a66300
24 changed files with 112 additions and 120 deletions

.flake8

@@ -21,11 +21,11 @@ ignore =
optional-ascii-coding = True
exclude =
./.git,
./docs
./build
./docs/*,
./build,
./scripts,
./venv,
*.pyi
.pre-commit-config.yaml
*.md
*.pyi,
.pre-commit-config.yaml,
*.md,
.flake8

.github/workflows/pre-commit.yml

@@ -0,0 +1,25 @@
name: Pre-commit
on:
pull_request:
push:
branches: [main]
jobs:
pre-commit:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@d632683dd7b4114ad314bca15554477dd762a938 # v4.2.0
- name: Set up Python
uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # v5.2.0
with:
python-version: '3.11'
cache: pip
cache-dependency-path: |
**/requirements*.txt
.pre-commit-config.yaml
- uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd #v3.0.1


@@ -1,3 +1,5 @@
<img src="https://github.com/user-attachments/assets/2fedfe0f-6df7-4441-98b2-87a1fd95ee1c" width="300" title="Llama Stack Logo" alt="Llama Stack Logo"/>
# Llama Stack
[![PyPI version](https://img.shields.io/pypi/v/llama_stack.svg)](https://pypi.org/project/llama_stack/)
@@ -97,7 +99,7 @@ The `llama` CLI makes it easy to work with the Llama Stack set of tools. Please
| **Language** | **Client SDK** | **Package** |
| :----: | :----: | :----: |
| Python | [llama-stack-client-python](https://github.com/meta-llama/llama-stack-client-python) | [![PyPI version](https://img.shields.io/pypi/v/llama_stack_client.svg)](https://pypi.org/project/llama_stack_client/)
| Swift | [llama-stack-client-swift](https://github.com/meta-llama/llama-stack-client-swift) |
| Swift | [llama-stack-client-swift](https://github.com/meta-llama/llama-stack-client-swift) | [![Swift Package Index](https://img.shields.io/endpoint?url=https%3A%2F%2Fswiftpackageindex.com%2Fapi%2Fpackages%2Fmeta-llama%2Fllama-stack-client-swift%2Fbadge%3Ftype%3Dswift-versions)](https://swiftpackageindex.com/meta-llama/llama-stack-client-swift)
| Node | [llama-stack-client-node](https://github.com/meta-llama/llama-stack-client-node) | [![NPM version](https://img.shields.io/npm/v/llama-stack-client.svg)](https://npmjs.org/package/llama-stack-client)
| Kotlin | [llama-stack-client-kotlin](https://github.com/meta-llama/llama-stack-client-kotlin) |


@@ -73,7 +73,7 @@ docker run -it -p 5000:5000 -v ~/.llama:/root/.llama --gpus=all llamastack-local
```
> [!NOTE]
> `~/.llama` should be the path containing downloaded weights of Llama models.
> `~/.llama` should be the path containing downloaded weights of Llama models.
#### Via conda


@@ -21,7 +21,7 @@
"info": {
"title": "[DRAFT] Llama Stack Specification",
"version": "0.0.1",
"description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-10-15 10:15:15.195382"
"description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-10-10 15:29:56.831109"
},
"servers": [
{
@@ -6228,49 +6228,7 @@
],
"tags": [
{
"name": "Inference"
},
{
"name": "PostTraining"
},
{
"name": "Agents"
},
{
"name": "MemoryBanks"
},
{
"name": "Inspect"
},
{
"name": "Models"
},
{
"name": "Safety"
},
{
"name": "Evals"
},
{
"name": "BatchInference"
},
{
"name": "Shields"
},
{
"name": "SyntheticDataGeneration"
},
{
"name": "Telemetry"
},
{
"name": "RewardScoring"
},
{
"name": "Datasets"
},
{
"name": "Memory"
},
{
"name": "BuiltinTool",


@@ -2679,7 +2679,6 @@ info:
description: "This is the specification of the llama stack that provides\n \
\ a set of endpoints and their corresponding interfaces that are tailored\
\ to\n best leverage Llama Models. The specification is still in\
\ draft and subject to change.\n Generated at 2024-10-15 10:15:15.195382"
title: '[DRAFT] Llama Stack Specification'
version: 0.0.1
jsonSchemaDialect: https://json-schema.org/draft/2020-12/schema


@@ -149,6 +149,7 @@ class StackBuild(Subcommand):
def _run_template_list_cmd(self, args: argparse.Namespace) -> None:
import json
from llama_stack.cli.table import print_table
# eventually, this should query a registry at llama.meta.com/llamastack/distributions
@@ -175,6 +176,7 @@ class StackBuild(Subcommand):
def _run_stack_build_command(self, args: argparse.Namespace) -> None:
import textwrap
import yaml
from llama_stack.distribution.distribution import get_provider_registry
from prompt_toolkit import prompt
@@ -256,7 +258,9 @@ class StackBuild(Subcommand):
providers = dict()
for api, providers_for_api in get_provider_registry().items():
available_providers = [
x for x in providers_for_api.keys() if x != "remote"
x
for x in providers_for_api.keys()
if x not in ("remote", "remote::sample")
]
api_provider = prompt(
"> Enter provider for API {}: ".format(api.value),


@@ -7,10 +7,11 @@
from .config import DatabricksImplConfig
from .databricks import DatabricksInferenceAdapter
async def get_adapter_impl(config: DatabricksImplConfig, _deps):
assert isinstance(
config, DatabricksImplConfig
), f"Unexpected config type: {type(config)}"
impl = DatabricksInferenceAdapter(config)
await impl.initialize()
return impl
return impl
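
For context, a hypothetical sketch of how this factory might be driven. Only `get_adapter_impl` and the `api_token` field appear in this diff; the module path, the `url` field, and the surrounding async scaffolding are assumptions:

```python
# Hypothetical usage sketch, not part of the diff. The import path and the
# `url` field are assumed; a reachable Databricks endpoint is required.
import asyncio

from llama_stack.providers.adapters.inference.databricks import get_adapter_impl
from llama_stack.providers.adapters.inference.databricks.config import DatabricksImplConfig

async def _demo() -> None:
    config = DatabricksImplConfig(
        url="https://example.cloud.databricks.com",  # assumed field
        api_token="dapi-...",                        # field shown in config.py
    )
    adapter = await get_adapter_impl(config, _deps={})
    # initialize() has already been awaited inside the factory above

# asyncio.run(_demo())
```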


@@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Optional
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field
@@ -19,4 +18,4 @@ class DatabricksImplConfig(BaseModel):
api_token: str = Field(
default=None,
description="The Databricks API token",
)
)


@@ -48,7 +48,14 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
async def shutdown(self) -> None:
pass
def completion(self, request: CompletionRequest) -> AsyncGenerator:
def completion(
self,
model: str,
content: InterleavedTextMedia,
sampling_params: Optional[SamplingParams] = SamplingParams(),
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
raise NotImplementedError()
def chat_completion(


@@ -63,13 +63,12 @@ async def llm_rag_query_generator(
model = config.model
message = UserMessage(content=content)
response = inference_api.chat_completion(
response = await inference_api.chat_completion(
model=model,
messages=[message],
stream=False,
)
async for chunk in response:
query = chunk.completion_message.content
query = response.completion_message.content
return query
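
The hunk above reflects that, with `stream=False`, `chat_completion` is awaited and returns a single response object rather than an async stream to iterate. A minimal sketch of the resulting call pattern; `UserMessage`, `inference_api`, and `model` are assumed to be in scope exactly as in the surrounding generator code:

```python
# Minimal sketch of the non-streaming pattern adopted above; only names
# already visible in the hunk are used.
async def generate_query(inference_api, model: str, content: str) -> str:
    message = UserMessage(content=content)
    response = await inference_api.chat_completion(
        model=model,
        messages=[message],
        stream=False,
    )
    # A single response object; no `async for` loop is needed when stream=False.
    return response.completion_message.content
```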


@@ -16,7 +16,7 @@ from llama_stack.apis.agents import * # noqa: F403
from ..agents import (
AGENT_INSTANCES_BY_ID,
MetaReferenceAgentsImpl,
MetaReferenceImplConfig,
MetaReferenceInferenceConfig,
)
@@ -166,7 +166,7 @@ def mock_memory_api():
@pytest.fixture
async def chat_agent(mock_inference_api, mock_safety_api, mock_memory_api):
impl = MetaReferenceAgentsImpl(
config=MetaReferenceImplConfig(),
config=MetaReferenceInferenceConfig(),
inference_api=mock_inference_api,
safety_api=mock_safety_api,
memory_api=mock_memory_api,


@@ -4,16 +4,17 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import MetaReferenceImplConfig # noqa
from typing import Union
from .config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
async def get_provider_impl(config: MetaReferenceImplConfig, _deps):
async def get_provider_impl(
config: Union[MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig],
_deps,
):
from .inference import MetaReferenceInferenceImpl
assert isinstance(
config, MetaReferenceImplConfig
), f"Unexpected config type: {type(config)}"
impl = MetaReferenceInferenceImpl(config)
await impl.initialize()
return impl
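
Since the factory now accepts either config variant, a caller passes whichever class matches the selected provider. A hypothetical call sketch; downloaded weights, GPU setup, and the `_deps` contents are all assumed and outside this diff:

```python
# Hypothetical call sketch; requires downloaded Llama weights and a suitable
# GPU environment, neither of which this snippet sets up.
import asyncio

from llama_stack.providers.impls.meta_reference.inference import (
    MetaReferenceInferenceConfig,
    get_provider_impl,
)

async def _demo() -> None:
    config = MetaReferenceInferenceConfig(model="Llama3.1-8B-Instruct")
    impl = await get_provider_impl(config, _deps={})

# asyncio.run(_demo())
```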


@@ -15,12 +15,11 @@ from pydantic import BaseModel, Field, field_validator
from llama_stack.providers.utils.inference import supported_inference_models
class MetaReferenceImplConfig(BaseModel):
class MetaReferenceInferenceConfig(BaseModel):
model: str = Field(
default="Llama3.1-8B-Instruct",
description="Model descriptor from `llama model list`",
)
quantization: Optional[QuantizationConfig] = None
torch_seed: Optional[int] = None
max_seq_len: int = 4096
max_batch_size: int = 1
@@ -38,9 +37,9 @@ class MetaReferenceImplConfig(BaseModel):
@property
def model_parallel_size(self) -> int:
# HACK ALERT: this will be fixed when we move inference configuration
# to ModelsRegistry and we can explicitly ask for `model_parallel_size`
# as configuration there
resolved = resolve_model(self.model)
assert resolved is not None
return resolved.pth_file_count
class MetaReferenceQuantizedInferenceConfig(MetaReferenceInferenceConfig):
quantization: QuantizationConfig
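
With this split, FP8 is no longer toggled through an optional field on the base config; it is selected by constructing the quantized subclass. A hedged construction sketch; `Fp8QuantizationConfig` and its import location are assumptions about the `QuantizationConfig` union and are not shown in this diff:

```python
# Hedged sketch: only the two config classes come from this hunk;
# Fp8QuantizationConfig and its import path are assumptions.
from llama_models.llama3.api.datatypes import Fp8QuantizationConfig  # assumed

from llama_stack.providers.impls.meta_reference.inference.config import (
    MetaReferenceInferenceConfig,
    MetaReferenceQuantizedInferenceConfig,
)

plain = MetaReferenceInferenceConfig(model="Llama3.1-8B-Instruct")

quantized = MetaReferenceQuantizedInferenceConfig(
    model="Llama3.1-8B-Instruct",
    quantization=Fp8QuantizationConfig(),  # required on the quantized variant
)
```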


@@ -11,9 +11,8 @@ import json
import os
import sys
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Generator, List, Optional
from typing import Generator, List, Optional, Union
import torch
import torch.nn.functional as F
@@ -36,14 +35,12 @@ from llama_models.llama3.reference_impl.multimodal.model import (
)
from llama_models.sku_list import resolve_model
from llama_stack.apis.inference import QuantizationType
from llama_stack.distribution.utils.model_utils import model_local_dir
from pydantic import BaseModel
from termcolor import cprint
from .config import MetaReferenceImplConfig
from llama_stack.distribution.utils.model_utils import model_local_dir
from .config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
def model_checkpoint_dir(model) -> str:
@@ -68,7 +65,11 @@ class TokenResult(BaseModel):
class Llama:
@staticmethod
def build(config: MetaReferenceImplConfig):
def build(
config: Union[
MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
]
):
"""
Build a Llama instance by initializing and loading a model checkpoint.
@@ -78,15 +79,6 @@ class Llama:
"""
model = resolve_model(config.model)
if (
config.quantization
and config.quantization.type == QuantizationType.fp8.value
):
from .quantization.loader import is_fbgemm_available
if not is_fbgemm_available():
raise ImportError("fbgemm-gpu is required for FP8 quantization")
if not torch.distributed.is_initialized():
torch.distributed.init_process_group("nccl")
@@ -134,12 +126,7 @@ class Llama:
model_args.vocab_size == tokenizer.n_words
), f"model_args vocab = {model_args.vocab_size} but tokenizer vocab = {tokenizer.n_words}"
fp8 = (
config.quantization
and config.quantization.type == QuantizationType.fp8.value
)
if fp8:
if isinstance(config, MetaReferenceQuantizedInferenceConfig):
from .quantization.loader import convert_to_quantized_model
# load on CPU in bf16 so that fp8 conversion does not find an


@@ -17,7 +17,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_messages,
)
from .config import MetaReferenceImplConfig
from .config import MetaReferenceInferenceConfig
from .model_parallel import LlamaModelParallelGenerator
# there's a single model parallel process running serving the model. for now,
@@ -26,7 +26,7 @@ SEMAPHORE = asyncio.Semaphore(1)
class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
def __init__(self, config: MetaReferenceImplConfig) -> None:
def __init__(self, config: MetaReferenceInferenceConfig) -> None:
self.config = config
model = resolve_model(config.model)
if model is None:


@@ -6,7 +6,6 @@
import os
from copy import deepcopy
from dataclasses import dataclass
from functools import partial
from typing import Generator, List, Optional
@@ -15,7 +14,7 @@ from llama_models.llama3.api.datatypes import Message, ToolPromptFormat
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.sku_list import resolve_model
from .config import MetaReferenceImplConfig
from .config import MetaReferenceInferenceConfig
from .generation import Llama, model_checkpoint_dir
from .parallel_utils import InferenceArgs, ModelParallelProcessGroup
@@ -36,7 +35,7 @@ class ModelRunner:
)
def init_model_cb(config: MetaReferenceImplConfig):
def init_model_cb(config: MetaReferenceInferenceConfig):
llama = Llama.build(config)
return ModelRunner(llama)
@@ -52,7 +51,7 @@ class LlamaModelParallelGenerator:
clear at the callsite why we need to use a context manager.
"""
def __init__(self, config: MetaReferenceImplConfig):
def __init__(self, config: MetaReferenceInferenceConfig):
self.config = config
self.model = resolve_model(self.config.model)
# this is a hack because Agent's loop uses this to tokenize and check if input is too long


@@ -11,7 +11,7 @@ import tempfile
import time
import uuid
from enum import Enum
from typing import Any, Callable, Generator, List, Literal, Optional, Union
from typing import Callable, Generator, List, Literal, Optional, Union
import torch
@@ -317,7 +317,7 @@ def start_model_parallel_process(
request_socket.send(encode_msg(ReadyRequest()))
response = request_socket.recv()
print(f"Finished model load {response}")
print("Loaded model...")
return request_socket, process


@@ -22,19 +22,10 @@ from torch import Tensor
from llama_stack.apis.inference import QuantizationType
from llama_stack.providers.impls.meta_reference.inference.config import (
MetaReferenceImplConfig,
MetaReferenceQuantizedInferenceConfig,
)
def is_fbgemm_available() -> bool:
try:
import fbgemm_gpu.experimental.gen_ai # noqa: F401
return True
except ImportError:
return False
def swiglu_wrapper(
self,
x: Tensor,
@@ -47,7 +38,7 @@ def swiglu_wrapper(
def convert_to_quantized_model(
model: Transformer,
config: MetaReferenceImplConfig,
config: MetaReferenceQuantizedInferenceConfig,
fp8_activation_scale_ub: Optional[float] = 1200.0,
) -> Transformer:
if config.quantization.type == QuantizationType.bf16.value:


@@ -170,7 +170,7 @@ class LlamaGuardShield(ShieldBase):
for i in range(1, len(messages)):
if messages[i].role == messages[i - 1].role:
raise ValueError(
f"Messages must alternate between user and assistant. Message {i} has the same role as message {i-1}"
f"Messages must alternate between user and assistant. Message {i} has the same role as message {i - 1}"
)
return messages


@@ -1,3 +1,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from .config import VLLMConfig


@@ -14,6 +14,21 @@ def available_providers() -> List[ProviderSpec]:
InlineProviderSpec(
api=Api.inference,
provider_type="meta-reference",
pip_packages=[
"accelerate",
"blobfile",
"fairscale",
"torch",
"torchvision",
"transformers",
"zmq",
],
module="llama_stack.providers.impls.meta_reference.inference",
config_class="llama_stack.providers.impls.meta_reference.inference.MetaReferenceInferenceConfig",
),
InlineProviderSpec(
api=Api.inference,
provider_type="meta-reference-quantized",
pip_packages=[
"accelerate",
"blobfile",
@@ -25,7 +40,7 @@ def available_providers() -> List[ProviderSpec]:
"zmq",
],
module="llama_stack.providers.impls.meta_reference.inference",
config_class="llama_stack.providers.impls.meta_reference.inference.MetaReferenceImplConfig",
config_class="llama_stack.providers.impls.meta_reference.inference.MetaReferenceQuantizedInferenceConfig",
),
remote_provider_spec(
api=Api.inference,
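
The registry now carries separate `meta-reference` and `meta-reference-quantized` inline providers, each pointing at its own config class. A hedged sketch of mapping provider types to config classes from this list; it assumes the `ProviderSpec` objects expose the same attribute names as the keyword arguments above and that `available_providers` is in scope:

```python
# Hedged sketch: attribute names mirror the keyword arguments used in the
# InlineProviderSpec entries above; available_providers is assumed in scope.
inference_configs = {
    spec.provider_type: spec.config_class
    for spec in available_providers()
    if spec.provider_type.startswith("meta-reference")
}
# Expected, given the entries above:
#   "meta-reference"           -> ...MetaReferenceInferenceConfig
#   "meta-reference-quantized" -> ...MetaReferenceQuantizedInferenceConfig
```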


@@ -2,7 +2,7 @@ blobfile
fire
httpx
huggingface-hub
llama-models>=0.0.41
llama-models>=0.0.42
prompt-toolkit
python-dotenv
pydantic>=2


@@ -16,7 +16,7 @@ def read_requirements():
setup(
name="llama_stack",
version="0.0.41",
version="0.0.42",
author="Meta Llama",
author_email="llama-oss@meta.com",
description="Llama Stack",