From f5dcc03742d2bc0cba90a0874630e58e948acfef Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Wed, 23 Oct 2024 20:22:00 -0700 Subject: [PATCH 01/17] use pytorch/pytorch as base --- distributions/meta-reference-gpu/build.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/distributions/meta-reference-gpu/build.yaml b/distributions/meta-reference-gpu/build.yaml index e76197330..5b1521a92 100644 --- a/distributions/meta-reference-gpu/build.yaml +++ b/distributions/meta-reference-gpu/build.yaml @@ -1,5 +1,6 @@ name: meta-reference-gpu distribution_spec: + docker_image: pytorch/pytorch description: Use code from `llama_stack` itself to serve all llama stack APIs providers: inference: meta-reference From 05a8d47b98953612e340666968be03eaa6513a32 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Wed, 23 Oct 2024 19:33:14 -0700 Subject: [PATCH 02/17] Add a meta-reference-quantized-gpu distribution --- distributions/README.md | 1 + .../meta-reference-quantized-gpu/README.md | 34 +++++++++++++ .../meta-reference-quantized-gpu/build.yaml | 14 +++++ .../meta-reference-quantized-gpu/run.yaml | 51 +++++++++++++++++++ llama_stack/distribution/build_container.sh | 6 +-- llama_stack/providers/registry/inference.py | 2 +- 6 files changed, 104 insertions(+), 4 deletions(-) create mode 100644 distributions/meta-reference-quantized-gpu/README.md create mode 100644 distributions/meta-reference-quantized-gpu/build.yaml create mode 100644 distributions/meta-reference-quantized-gpu/run.yaml diff --git a/distributions/README.md b/distributions/README.md index dc1e3cc25..4dc2b9d03 100644 --- a/distributions/README.md +++ b/distributions/README.md @@ -7,6 +7,7 @@ A Distribution is where APIs and Providers are assembled together to provide a c | **Distribution** | **Llama Stack Docker** | Start This Distribution | **Inference** | **Agents** | **Memory** | **Safety** | **Telemetry** | |:----------------: |:------------------------------------------: |:-----------------------: |:------------------: |:------------------: |:------------------: |:------------------: |:------------------: | | Meta Reference | [llamastack/distribution-meta-reference-gpu](https://hub.docker.com/repository/docker/llamastack/distribution-meta-reference-gpu/general) | [Guide](./meta-reference-gpu/) | meta-reference | meta-reference | meta-reference; remote::pgvector; remote::chromadb | meta-reference | meta-reference | +| Meta Reference Quantized | [llamastack/distribution-meta-reference-quantized-gpu](https://hub.docker.com/repository/docker/llamastack/distribution-meta-reference-quantized-gpu/general) | [Guide](./meta-reference-quantized-gpu/) | meta-reference-quantized | meta-reference | meta-reference; remote::pgvector; remote::chromadb | meta-reference | meta-reference | | Ollama | [llamastack/distribution-ollama](https://hub.docker.com/repository/docker/llamastack/distribution-ollama/general) | [Guide](./ollama/) | remote::ollama | meta-reference | remote::pgvector; remote::chromadb | remote::ollama | meta-reference | | TGI | [llamastack/distribution-tgi](https://hub.docker.com/repository/docker/llamastack/distribution-tgi/general) | [Guide](./tgi/) | remote::tgi | meta-reference | meta-reference; remote::pgvector; remote::chromadb | meta-reference | meta-reference | | Together | [llamastack/distribution-together](https://hub.docker.com/repository/docker/llamastack/distribution-together/general) | [Guide](./together/) | remote::together | meta-reference | remote::weaviate | meta-reference | meta-reference | diff --git 
a/distributions/meta-reference-quantized-gpu/README.md b/distributions/meta-reference-quantized-gpu/README.md new file mode 100644 index 000000000..0c05a13c1 --- /dev/null +++ b/distributions/meta-reference-quantized-gpu/README.md @@ -0,0 +1,34 @@ +# Meta Reference Quantized Distribution + +The `llamastack/distribution-meta-reference-quantized-gpu` distribution consists of the following provider configurations. + + +| **API** | **Inference** | **Agents** | **Memory** | **Safety** | **Telemetry** | +|----------------- |------------------------ |---------------- |-------------------------------------------------- |---------------- |---------------- | +| **Provider(s)** | meta-reference-quantized | meta-reference | meta-reference, remote::pgvector, remote::chroma | meta-reference | meta-reference | + +The only difference vs. the `meta-reference-gpu` distribution is that it has support for more efficient inference -- with fp8, int4 quantization, etc. + +### Start the Distribution (Single Node GPU) + +> [!NOTE] +> This assumes you have access to GPU to start a local server with access to your GPU. + + +> [!NOTE] +> `~/.llama` should be the path containing downloaded weights of Llama models. + + +To download and start running a pre-built docker container, you may use the following commands: + +``` +docker run -it -p 5000:5000 -v ~/.llama:/root/.llama \ + -v ./run.yaml:/root/my-run.yaml \ + --gpus=all \ + distribution-meta-reference-quantized-gpu \ + --yaml_config /root/my-run.yaml +``` + +### Alternative (Build and start distribution locally via conda) + +- You may checkout the [Getting Started](../../docs/getting_started.md) for more details on building locally via conda and starting up the distribution. diff --git a/distributions/meta-reference-quantized-gpu/build.yaml b/distributions/meta-reference-quantized-gpu/build.yaml new file mode 100644 index 000000000..e9ddb4aad --- /dev/null +++ b/distributions/meta-reference-quantized-gpu/build.yaml @@ -0,0 +1,14 @@ +name: meta-reference-quantized-gpu +distribution_spec: + docker_image: pytorch/pytorch:2.5.0-cuda12.4-cudnn9-runtime + description: Use code from `llama_stack` itself to serve all llama stack APIs + providers: + inference: meta-reference-quantized + memory: + - meta-reference + - remote::chromadb + - remote::pgvector + safety: meta-reference + agents: meta-reference + telemetry: meta-reference +image_type: docker diff --git a/distributions/meta-reference-quantized-gpu/run.yaml b/distributions/meta-reference-quantized-gpu/run.yaml new file mode 100644 index 000000000..6e8be2b6d --- /dev/null +++ b/distributions/meta-reference-quantized-gpu/run.yaml @@ -0,0 +1,51 @@ +version: '2' +built_at: '2024-10-08T17:40:45.325529' +image_name: local +docker_image: null +conda_env: local +apis: +- shields +- agents +- models +- memory +- memory_banks +- inference +- safety +providers: + inference: + - provider_id: meta0 + provider_type: meta-reference-quantized + config: + model: Llama3.2-3B-Instruct + quantization: + type: fp8 + torch_seed: null + max_seq_len: 2048 + max_batch_size: 1 + safety: + - provider_id: meta0 + provider_type: meta-reference + config: + llama_guard_shield: + model: Llama-Guard-3-1B + excluded_categories: [] + disable_input_check: false + disable_output_check: false + prompt_guard_shield: + model: Prompt-Guard-86M + memory: + - provider_id: meta0 + provider_type: meta-reference + config: {} + agents: + - provider_id: meta0 + provider_type: meta-reference + config: + persistence_store: + namespace: null + type: sqlite + db_path: 
~/.llama/runtime/kvstore.db + telemetry: + - provider_id: meta0 + provider_type: meta-reference + config: {} diff --git a/llama_stack/distribution/build_container.sh b/llama_stack/distribution/build_container.sh index 19f3df1e3..3bf74edcf 100755 --- a/llama_stack/distribution/build_container.sh +++ b/llama_stack/distribution/build_container.sh @@ -97,7 +97,7 @@ if [ -n "$pip_dependencies" ]; then fi if [ -n "$special_pip_deps" ]; then - IFS='#' read -ra parts <<< "$special_pip_deps" + IFS='#' read -ra parts <<<"$special_pip_deps" for part in "${parts[@]}"; do add_to_docker "RUN pip install $part" done @@ -127,7 +127,7 @@ if [ -n "$LLAMA_MODELS_DIR" ]; then mounts="$mounts -v $(readlink -f $LLAMA_MODELS_DIR):$models_mount" fi -if command -v selinuxenabled &> /dev/null && selinuxenabled; then +if command -v selinuxenabled &>/dev/null && selinuxenabled; then # Disable SELinux labels -- we don't want to relabel the llama-stack source dir DOCKER_OPTS="$DOCKER_OPTS --security-opt label=disable" fi @@ -139,4 +139,4 @@ $DOCKER_BINARY build $DOCKER_OPTS -t $image_name -f "$TEMP_DIR/Dockerfile" "$REP rm -rf $REPO_CONFIGS_DIR set +x -echo "Success! You can run it with: $DOCKER_BINARY $DOCKER_OPTS run -p 5000:5000 $image_name" +echo "Success!" diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py index 6f8bc2c6e..28555755b 100644 --- a/llama_stack/providers/registry/inference.py +++ b/llama_stack/providers/registry/inference.py @@ -36,7 +36,7 @@ def available_providers() -> List[ProviderSpec]: pip_packages=( META_REFERENCE_DEPS + [ - "fbgemm-gpu==0.8.0", + "fbgemm-gpu", ] ), module="llama_stack.providers.impls.meta_reference.inference", From 7afe51c84d7d8738fb0b29dc8c765ba73b595455 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Thu, 24 Oct 2024 08:38:56 -0700 Subject: [PATCH 03/17] New quantized models (#301) --- .gitignore | 1 + llama_stack/apis/inference/client.py | 2 +- llama_stack/apis/inference/inference.py | 11 +- .../meta_reference/inference/generation.py | 44 ++- .../inference/quantization/loader.py | 254 +++++++++++++++++- llama_stack/providers/registry/inference.py | 1 + 6 files changed, 292 insertions(+), 21 deletions(-) diff --git a/.gitignore b/.gitignore index a6c204131..897494f21 100644 --- a/.gitignore +++ b/.gitignore @@ -13,6 +13,7 @@ xcuserdata/ Package.resolved *.pte *.ipynb_checkpoints* +.idea .venv/ .idea _build diff --git a/llama_stack/apis/inference/client.py b/llama_stack/apis/inference/client.py index 7359c6057..892da13ad 100644 --- a/llama_stack/apis/inference/client.py +++ b/llama_stack/apis/inference/client.py @@ -172,7 +172,7 @@ async def run_mm_main( ], ) cprint(f"User>{message.content}", "green") - iterator = client.chat_completion( + iterator = await client.chat_completion( model=model, messages=[message], stream=stream, diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py index 4ee01acae..d1ff047b0 100644 --- a/llama_stack/apis/inference/inference.py +++ b/llama_stack/apis/inference/inference.py @@ -25,6 +25,7 @@ class LogProbConfig(BaseModel): class QuantizationType(Enum): bf16 = "bf16" fp8 = "fp8" + int4 = "int4" @json_schema_type @@ -37,8 +38,14 @@ class Bf16QuantizationConfig(BaseModel): type: Literal[QuantizationType.bf16.value] = QuantizationType.bf16.value +@json_schema_type +class Int4QuantizationConfig(BaseModel): + type: Literal[QuantizationType.int4.value] = QuantizationType.int4.value + scheme: Optional[str] = None + + QuantizationConfig = Annotated[ - 
Union[Bf16QuantizationConfig, Fp8QuantizationConfig], + Union[Bf16QuantizationConfig, Fp8QuantizationConfig, Int4QuantizationConfig], Field(discriminator="type"), ] @@ -219,8 +226,6 @@ class Inference(Protocol): logprobs: Optional[LogProbConfig] = None, ) -> Union[CompletionResponse, CompletionResponseStreamChunk]: ... - # This method is not `async def` because it can result in either an - # `AsyncGenerator` or a `ChatCompletionResponse` depending on the value of `stream`. @webmethod(route="/inference/chat_completion") async def chat_completion( self, diff --git a/llama_stack/providers/impls/meta_reference/inference/generation.py b/llama_stack/providers/impls/meta_reference/inference/generation.py index b424a9347..ebce1024b 100644 --- a/llama_stack/providers/impls/meta_reference/inference/generation.py +++ b/llama_stack/providers/impls/meta_reference/inference/generation.py @@ -30,7 +30,6 @@ from llama_models.llama3.reference_impl.multimodal.model import ( CrossAttentionTransformer, ) from llama_models.sku_list import resolve_model - from pydantic import BaseModel from termcolor import cprint @@ -43,7 +42,12 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( chat_completion_request_to_messages, ) -from .config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig +from .config import ( + Fp8QuantizationConfig, + Int4QuantizationConfig, + MetaReferenceInferenceConfig, + MetaReferenceQuantizedInferenceConfig, +) def model_checkpoint_dir(model) -> str: @@ -131,18 +135,34 @@ class Llama: ), f"model_args vocab = {model_args.vocab_size} but tokenizer vocab = {tokenizer.n_words}" if isinstance(config, MetaReferenceQuantizedInferenceConfig): - from .quantization.loader import convert_to_quantized_model - # load on CPU in bf16 so that fp8 conversion does not find an - # unexpected (fp32, e.g.) datatype - torch.set_default_tensor_type(torch.BFloat16Tensor) - if model_args.vision_chunk_size > 0: - model = CrossAttentionTransformer(model_args) - model.setup_cache(model_args.max_batch_size, torch.bfloat16) - else: + if isinstance(config.quantization, Fp8QuantizationConfig): + from .quantization.loader import convert_to_fp8_quantized_model + + # load on CPU in bf16 so that fp8 conversion does not find an + # unexpected (fp32, e.g.) datatype + torch.set_default_tensor_type(torch.BFloat16Tensor) + if model_args.vision_chunk_size > 0: + model = CrossAttentionTransformer(model_args) + model.setup_cache(model_args.max_batch_size, torch.bfloat16) + else: + model = Transformer(model_args) + model.load_state_dict(state_dict, strict=False) + model = convert_to_fp8_quantized_model(model, config, ckpt_dir) + elif isinstance(config.quantization, Int4QuantizationConfig): + from .quantization.loader import convert_to_int4_quantized_model + + assert ( + config.quantization.scheme is not None + ), "Please specify a quantization scheme." + model = Transformer(model_args) - model.load_state_dict(state_dict, strict=False) - model = convert_to_quantized_model(model, config, ckpt_dir) + model = convert_to_int4_quantized_model(model, model_args, config) + model.load_state_dict(state_dict, strict=True) + else: + raise NotImplementedError( + "Currently int4 and fp8 are the only supported quantization methods." 
+ ) else: if torch.cuda.is_bf16_supported(): torch.set_default_tensor_type(torch.cuda.BFloat16Tensor) diff --git a/llama_stack/providers/impls/meta_reference/inference/quantization/loader.py b/llama_stack/providers/impls/meta_reference/inference/quantization/loader.py index bd59fe618..e07c9fa3b 100644 --- a/llama_stack/providers/impls/meta_reference/inference/quantization/loader.py +++ b/llama_stack/providers/impls/meta_reference/inference/quantization/loader.py @@ -8,19 +8,25 @@ # This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement. import os -from typing import Optional +from typing import Any, Dict, List, Optional import torch +from fairscale.nn.model_parallel.layers import ColumnParallelLinear, RowParallelLinear from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region -from llama_models.datatypes import CheckpointQuantizationFormat -from llama_models.llama3.reference_impl.model import Transformer, TransformerBlock +from llama_models.datatypes import CheckpointQuantizationFormat + +from llama_models.llama3.api.args import ModelArgs +from llama_models.llama3.reference_impl.model import Transformer, TransformerBlock from llama_models.sku_list import resolve_model from termcolor import cprint -from torch import Tensor +from torch import nn, Tensor + +from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear from llama_stack.apis.inference import QuantizationType +from llama_stack.apis.inference.inference import Int4QuantizationConfig from llama_stack.providers.impls.meta_reference.inference.config import ( MetaReferenceQuantizedInferenceConfig, @@ -37,7 +43,7 @@ def swiglu_wrapper( return reduce_from_model_parallel_region(out) -def convert_to_quantized_model( +def convert_to_fp8_quantized_model( model: Transformer, config: MetaReferenceQuantizedInferenceConfig, checkpoint_dir: str, @@ -99,3 +105,241 @@ def convert_to_quantized_model( if not isinstance(parameter, Fp8ScaledWeights): parameter.data = parameter.to(device="cuda") return model + + +class Int8DynActInt4WeightLinearLoRA(Int8DynActInt4WeightLinear): + """ + Int8DynActInt4WeightLinear with LoRA adaptor. + + Args: + in_features: Number of input features. + out_features: Number of output features. + bias: Whether to use bias. + device: Device to use. + group_size: Group size for quantization. + precision: Precision of quantization. + scales_precision: Precision of scales. + lora_rank: Rank of LoRA adaptor. + lora_scale: Scale of LoRA adaptor. + """ + + def __init__( + self, + in_features: int, + out_features: int, + bias=False, + device=None, + # quantization parameters + group_size: int = 256, + precision: torch.dtype = torch.float32, + scales_precision: torch.dtype = torch.float32, + # LoRA parameters + lora_rank: Optional[int] = None, + lora_scale: Optional[float] = None, + ) -> None: + super().__init__( + in_features, + out_features, + bias=bias, + device=device, + groupsize=group_size, + precision=precision, + scales_precision=scales_precision, + ) + if lora_rank is not None: + assert lora_scale is not None, "Please specify lora scale for LoRA." + # Low-rank adaptation. 
See paper for more details: https://arxiv.org/abs/2106.09685 + self.adaptor = nn.Sequential() + self.adaptor.add_module("A", nn.Linear(in_features, lora_rank, bias=False)) + self.adaptor.add_module("B", nn.Linear(lora_rank, out_features, bias=False)) + self.lora_scale = lora_scale + else: + self.adaptor = None + self.lora_scale = None + self._register_load_state_dict_pre_hook(self.load_hook) + + def load_hook( + self, + state_dict: Dict[str, Any], + prefix: str, + local_metadata: Dict[str, Any], + strict: bool, + missing_keys: List[str], + unexpected_keys: List[str], + error_msgs: List[str], + ) -> None: + """A hook to load the quantized weights from the state dict.""" + if prefix + "zeros" not in state_dict: + # Zero-point may not be saved in the state dict. In this case, we assume it's zero. + assert prefix + "scales" in state_dict + state_dict[prefix + "zeros"] = torch.zeros_like( + state_dict[prefix + "scales"] + ) + + def forward(self, input_: torch.Tensor) -> torch.Tensor: + module_out = super().forward(input_) + if self.adaptor is not None: + adaptor_out = self.adaptor(input_) * self.lora_scale + return module_out + adaptor_out + return module_out + + +class Int8WeightEmbedding(torch.nn.Embedding): + """An embedding layer to load int8 weights. + + Args: + num_embeddings: Number of embeddings. + embedding_dim: Embedding dimension. + padding_idx: Padding index. + """ + + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: int, + device=None, + ) -> None: + super().__init__(num_embeddings, embedding_dim, padding_idx, device=device) + + self._register_load_state_dict_pre_hook(self.load_hook) + + def load_hook( + self, + state_dict: Dict[str, Any], + prefix: str, + local_metadata: Dict[str, Any], + strict: bool, + missing_keys: List[str], + unexpected_keys: List[str], + error_msgs: List[str], + ) -> None: + """A hook to load the quantized embedding weight and scales from the state dict.""" + weights = state_dict.pop(prefix + "weight") + scales = state_dict.pop(prefix + "scales") + state_dict[prefix + "weight"] = weights * scales + + +class Int8WeightLinear(torch.nn.Linear): + """A linear layer to load int8 weights. + + Args: + in_features: Number of input features. + out_features: Number of output features. + bias: Whether to use bias. + """ + + def __init__( + self, in_features: int, out_features: int, bias: bool = True, device=None + ) -> None: + super().__init__(in_features, out_features, bias, device=device) + + self._register_load_state_dict_pre_hook(self.load_hook) + + def load_hook( + self, + state_dict: Dict[str, Any], + prefix: str, + local_metadata: Dict[str, Any], + strict: bool, + missing_keys: List[str], + unexpected_keys: List[str], + error_msgs: List[str], + ) -> None: + """A hook to load the quantized linear weight and scales from the state dict.""" + weights = state_dict.pop(prefix + "weight") + scales = state_dict.pop(prefix + "scales") + state_dict[prefix + "weight"] = weights * scales + + +def _prepare_model_int4_weight_int8_dynamic_activation( + model: torch.nn.Module, + group_size: int, + lora_rank: Optional[int], + lora_scale: Optional[float], +): + """Prepare the model for int4 weight and int8 dynamic activation quantization. + + Note that the weights of embedding and output layers are quantized to int8. 
+ """ + device = None + for module_name, module in model.named_children(): + if module_name == "output": + quantized_module = Int8WeightLinear( + in_features=module.in_features, + out_features=module.out_features, + bias=module.bias, + device=device, + ) + del module + setattr(model, module_name, quantized_module) + elif module_name == "tok_embeddings": + quantized_module = Int8WeightEmbedding( + num_embeddings=module.num_embeddings, + embedding_dim=module.embedding_dim, + padding_idx=module.padding_idx, + device=device, + ) + del module + setattr(model, module_name, quantized_module) + elif isinstance(module, (ColumnParallelLinear, RowParallelLinear, nn.Linear)): + quantized_module = Int8DynActInt4WeightLinearLoRA( + in_features=module.in_features, + out_features=module.out_features, + bias=False, + group_size=group_size, + lora_rank=lora_rank, + lora_scale=lora_scale, + device=device, + ) + del module + setattr(model, module_name, quantized_module) + else: + _prepare_model_int4_weight_int8_dynamic_activation( + module, group_size, lora_rank, lora_scale + ) + + return model + + +def convert_to_int4_quantized_model( + model: Transformer, + model_args: ModelArgs, + config: MetaReferenceQuantizedInferenceConfig, +) -> Transformer: + """Convert the model to int4 quantized model.""" + + quant_config = config.quantization + if not isinstance(quant_config, Int4QuantizationConfig): + raise ValueError("Only int4 quantization is supported") + + if quant_config.type != QuantizationType.int4.value: + raise ValueError("Only int4 quantization is supported") + + if quant_config.scheme != "int4_weight_int8_dynamic_activation": + raise NotImplementedError( + "Only int4 quantization with 'int4_weight_int8_dynamic_activation' scheme is supported." + ) + + if model_args.quantization_args is None: + raise ValueError("'quantization_args' cannot be None. Please specify it.") + + group_size = model_args.quantization_args.group_size + if group_size is None: + raise ValueError( + "'group_size' cannot be None in 'quantization_args'. Please specify it." + ) + + if model_args.lora_args is None: + # Certain quantized models (e.g., SpinQuant) may not have LoRA. 
+ lora_rank = None + lora_scale = None + else: + lora_rank = model_args.lora_args.rank + lora_scale = model_args.lora_args.scale + + _prepare_model_int4_weight_int8_dynamic_activation( + model, group_size, lora_rank, lora_scale + ) + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + return model.to(device) diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py index 28555755b..88265f1b4 100644 --- a/llama_stack/providers/registry/inference.py +++ b/llama_stack/providers/registry/inference.py @@ -37,6 +37,7 @@ def available_providers() -> List[ProviderSpec]: META_REFERENCE_DEPS + [ "fbgemm-gpu", + "torchao==0.5.0", ] ), module="llama_stack.providers.impls.meta_reference.inference", From 8aa8847b4acf061e4bb6789876a1b86491e31375 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Thu, 24 Oct 2024 08:41:39 -0700 Subject: [PATCH 04/17] Bump version to 0.0.44 --- requirements.txt | 2 +- setup.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements.txt b/requirements.txt index 513642500..05016cceb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,7 +2,7 @@ blobfile fire httpx huggingface-hub -llama-models>=0.0.43 +llama-models>=0.0.44 prompt-toolkit python-dotenv pydantic>=2 diff --git a/setup.py b/setup.py index df2c2d18e..ba44b7d53 100644 --- a/setup.py +++ b/setup.py @@ -16,7 +16,7 @@ def read_requirements(): setup( name="llama_stack", - version="0.0.43", + version="0.0.44", author="Meta Llama", author_email="llama-oss@meta.com", description="Llama Stack", From 8eceebec98dcc9b303165226ee65ab34ab318070 Mon Sep 17 00:00:00 2001 From: Dalton Flanagan <6599399+dltn@users.noreply.github.com> Date: Thu, 24 Oct 2024 14:47:27 -0400 Subject: [PATCH 05/17] Update iOS inference instructions for new quantization --- llama_stack/providers/impls/ios/inference/README.md | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/llama_stack/providers/impls/ios/inference/README.md b/llama_stack/providers/impls/ios/inference/README.md index d6ce42382..160980759 100644 --- a/llama_stack/providers/impls/ios/inference/README.md +++ b/llama_stack/providers/impls/ios/inference/README.md @@ -56,9 +56,20 @@ We're working on making LocalInference easier to set up.Β For now, you'll need t ## Preparing a model -1. Prepare a `.pte` file [following the executorch docs](https://github.com/pytorch/executorch/blob/main/examples/models/llama2/README.md#step-2-prepare-model) +1. Prepare a `.pte` file [following the executorch docs](https://github.com/pytorch/executorch/blob/main/examples/models/llama/README.md#step-2-prepare-model) 2. Bundle the `.pte` and `tokenizer.model` file into your app +We now support models quantized using SpinQuant and QAT-LoRA which offer a significant performance boost (demo app on iPhone 13 Pro): + + +| Llama 3.2 1B | Tokens / Second (total) | | Time-to-First-Token (sec) | | +| :---- | :---- | :---- | :---- | :---- | +| | Haiku | Paragraph | Haiku | Paragraph | +| BF16 | 2.2 | 2.5 | 2.3 | 1.9 | +| QAT+LoRA | 7.1 | 3.3 | 0.37 | 0.24 | +| SpinQuant | 10.1 | 5.2 | 0.2 | 0.2 | + + ## Using LocalInference 1. Instantiate LocalInference with a DispatchQueue. 
Optionally, pass it into your agents service: From 161aef0aae96571fe049af334da6161d2d0fdc0c Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Thu, 24 Oct 2024 12:08:43 -0700 Subject: [PATCH 06/17] Small updates to quantization config --- distributions/meta-reference-quantized-gpu/run.yaml | 4 ++-- llama_stack/apis/inference/inference.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/distributions/meta-reference-quantized-gpu/run.yaml b/distributions/meta-reference-quantized-gpu/run.yaml index 6e8be2b6d..f162502c5 100644 --- a/distributions/meta-reference-quantized-gpu/run.yaml +++ b/distributions/meta-reference-quantized-gpu/run.yaml @@ -16,9 +16,9 @@ providers: - provider_id: meta0 provider_type: meta-reference-quantized config: - model: Llama3.2-3B-Instruct + model: Llama3.2-3B-Instruct:int4-qlora-eo8 quantization: - type: fp8 + type: int4 torch_seed: null max_seq_len: 2048 max_batch_size: 1 diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py index d1ff047b0..24b7bdc33 100644 --- a/llama_stack/apis/inference/inference.py +++ b/llama_stack/apis/inference/inference.py @@ -41,7 +41,7 @@ class Bf16QuantizationConfig(BaseModel): @json_schema_type class Int4QuantizationConfig(BaseModel): type: Literal[QuantizationType.int4.value] = QuantizationType.int4.value - scheme: Optional[str] = None + scheme: Optional[str] = "int4_weight_int8_dynamic_activation" QuantizationConfig = Annotated[ From 205bcfdd4eef68ea01c0a1d19de9b3a84d64b7fb Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Thu, 24 Oct 2024 12:11:58 -0700 Subject: [PATCH 07/17] Fix score threshold in faiss --- llama_stack/providers/impls/meta_reference/memory/faiss.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/llama_stack/providers/impls/meta_reference/memory/faiss.py b/llama_stack/providers/impls/meta_reference/memory/faiss.py index 8ead96302..02829f7be 100644 --- a/llama_stack/providers/impls/meta_reference/memory/faiss.py +++ b/llama_stack/providers/impls/meta_reference/memory/faiss.py @@ -47,7 +47,9 @@ class FaissIndex(EmbeddingIndex): self.index.add(np.array(embeddings).astype(np.float32)) - async def query(self, embedding: NDArray, k: int) -> QueryDocumentsResponse: + async def query( + self, embedding: NDArray, k: int, score_threshold: float + ) -> QueryDocumentsResponse: distances, indices = self.index.search( embedding.reshape(1, -1).astype(np.float32), k ) From 0538cc297e6d4e749f1c40fbb36aa53212c202c0 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Thu, 24 Oct 2024 12:14:18 -0700 Subject: [PATCH 08/17] Bump version to 0.0.45 --- requirements.txt | 2 +- setup.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements.txt b/requirements.txt index 05016cceb..621c17c28 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,7 +2,7 @@ blobfile fire httpx huggingface-hub -llama-models>=0.0.44 +llama-models>=0.0.45 prompt-toolkit python-dotenv pydantic>=2 diff --git a/setup.py b/setup.py index ba44b7d53..78ceb145e 100644 --- a/setup.py +++ b/setup.py @@ -16,7 +16,7 @@ def read_requirements(): setup( name="llama_stack", - version="0.0.44", + version="0.0.45", author="Meta Llama", author_email="llama-oss@meta.com", description="Llama Stack", From 94728d6983bfba9c6d010758340e451beb8ac33c Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Thu, 24 Oct 2024 13:36:41 -0700 Subject: [PATCH 09/17] Handle both ipv6 and ipv4 interfaces together --- llama_stack/distribution/server/server.py | 3 ++- 1 file changed, 2 
insertions(+), 1 deletion(-) diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index 185c89e7e..e3d621fd6 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -337,7 +337,8 @@ def main( import uvicorn # FYI this does not do hot-reloads - listen_host = "::" if not disable_ipv6 else "0.0.0.0" + + listen_host = ["::", "0.0.0.0"] if not disable_ipv6 else "0.0.0.0" print(f"Listening on {listen_host}:{port}") uvicorn.run(app, host=listen_host, port=port) From 8615bc9e08bd93c79a1dae9c4c65142ad4829ef7 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 24 Oct 2024 14:04:13 -0700 Subject: [PATCH 10/17] update manifest for build templates --- MANIFEST.in | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/MANIFEST.in b/MANIFEST.in index 52ab42950..7426a3abd 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,4 +1,4 @@ include requirements.txt include llama_stack/distribution/*.sh include llama_stack/cli/scripts/*.sh -include llama_stack/distribution/templates/*.yaml +include distributions/*/build.yaml From e70420a06ee671437483f5c9707a06b27386390d Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 24 Oct 2024 14:19:35 -0700 Subject: [PATCH 11/17] Update getting_started.md --- docs/getting_started.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/docs/getting_started.md b/docs/getting_started.md index 6b9510e00..e08885a72 100644 --- a/docs/getting_started.md +++ b/docs/getting_started.md @@ -41,6 +41,10 @@ export LLAMA_CHECKPOINT_DIR=~/.llama > [!NOTE] > `~/.llama` should be the path containing downloaded weights of Llama models. +To download llama models, use +``` +llama download --model-id Llama3.1-8B-Instruct +``` To download and start running a pre-built docker container, you may use the following commands: From cb8403456748a17c93beb04ce9f2870a10ec9800 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 24 Oct 2024 14:52:30 -0700 Subject: [PATCH 12/17] [Evals API][3/n] scoring_functions / scoring meta-reference implementations (#296) * wip * dataset validation * test_scoring * cleanup * clean up test * comments * error checking * dataset client * test client: * datasetio client * clean up * basic scoring function works * scorer wip * equality scorer * score batch impl * score batch * update scoring test * refactor * validate scorer input * address comments * add all rows scores to ScoringResult * bugfix * scoring function def rename --- llama_stack/apis/datasetio/client.py | 103 ++++++++++++++ llama_stack/apis/datasetio/datasetio.py | 2 +- llama_stack/apis/datasets/client.py | 116 +++++++++++++++ llama_stack/apis/datasets/datasets.py | 2 +- llama_stack/apis/scoring/client.py | 132 ++++++++++++++++++ llama_stack/apis/scoring/scoring.py | 20 ++- .../scoring_functions/scoring_functions.py | 46 ++---- llama_stack/distribution/datatypes.py | 5 + llama_stack/distribution/distribution.py | 4 + llama_stack/distribution/resolver.py | 4 + llama_stack/distribution/routers/__init__.py | 12 +- llama_stack/distribution/routers/routers.py | 54 +++++++ .../distribution/routers/routing_tables.py | 36 ++++- llama_stack/providers/datatypes.py | 13 +- .../meta_reference/datasetio/datasetio.py | 21 ++- .../impls/meta_reference/scoring/__init__.py | 21 +++ .../impls/meta_reference/scoring/config.py | 9 ++ .../meta_reference/scoring/scorer/__init__.py | 5 + .../scoring/scorer/base_scorer.py | 37 +++++ .../scoring/scorer/equality_scorer.py | 49 +++++++ .../impls/meta_reference/scoring/scoring.py | 109 
+++++++++++++++ llama_stack/providers/registry/scoring.py | 25 ++++ .../tests/datasetio/test_dataset.csv | 6 + .../tests/datasetio/test_datasetio.py | 36 ++++- .../providers/tests/scoring/__init__.py | 5 + .../scoring/provider_config_example.yaml | 9 ++ .../providers/tests/scoring/test_scoring.py | 69 +++++++++ tests/examples/evals-tgi-run.yaml | 5 + 28 files changed, 904 insertions(+), 51 deletions(-) create mode 100644 llama_stack/apis/datasetio/client.py create mode 100644 llama_stack/apis/datasets/client.py create mode 100644 llama_stack/apis/scoring/client.py create mode 100644 llama_stack/providers/impls/meta_reference/scoring/__init__.py create mode 100644 llama_stack/providers/impls/meta_reference/scoring/config.py create mode 100644 llama_stack/providers/impls/meta_reference/scoring/scorer/__init__.py create mode 100644 llama_stack/providers/impls/meta_reference/scoring/scorer/base_scorer.py create mode 100644 llama_stack/providers/impls/meta_reference/scoring/scorer/equality_scorer.py create mode 100644 llama_stack/providers/impls/meta_reference/scoring/scoring.py create mode 100644 llama_stack/providers/registry/scoring.py create mode 100644 llama_stack/providers/tests/datasetio/test_dataset.csv create mode 100644 llama_stack/providers/tests/scoring/__init__.py create mode 100644 llama_stack/providers/tests/scoring/provider_config_example.yaml create mode 100644 llama_stack/providers/tests/scoring/test_scoring.py diff --git a/llama_stack/apis/datasetio/client.py b/llama_stack/apis/datasetio/client.py new file mode 100644 index 000000000..b62db9085 --- /dev/null +++ b/llama_stack/apis/datasetio/client.py @@ -0,0 +1,103 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +import asyncio +import os +from pathlib import Path +from typing import Optional + +import fire +import httpx +from termcolor import cprint + +from llama_stack.apis.datasets import * # noqa: F403 +from llama_stack.apis.datasetio import * # noqa: F403 +from llama_stack.apis.common.type_system import * # noqa: F403 +from llama_stack.apis.datasets.client import DatasetsClient +from llama_stack.providers.tests.datasetio.test_datasetio import data_url_from_file + + +class DatasetIOClient(DatasetIO): + def __init__(self, base_url: str): + self.base_url = base_url + + async def initialize(self) -> None: + pass + + async def shutdown(self) -> None: + pass + + async def get_rows_paginated( + self, + dataset_id: str, + rows_in_page: int, + page_token: Optional[str] = None, + filter_condition: Optional[str] = None, + ) -> PaginatedRowsResult: + async with httpx.AsyncClient() as client: + response = await client.get( + f"{self.base_url}/datasetio/get_rows_paginated", + params={ + "dataset_id": dataset_id, + "rows_in_page": rows_in_page, + "page_token": page_token, + "filter_condition": filter_condition, + }, + headers={"Content-Type": "application/json"}, + timeout=60, + ) + response.raise_for_status() + if not response.json(): + return + + return PaginatedRowsResult(**response.json()) + + +async def run_main(host: str, port: int): + client = DatasetsClient(f"http://{host}:{port}") + + # register dataset + test_file = ( + Path(os.path.abspath(__file__)).parent.parent.parent + / "providers/tests/datasetio/test_dataset.csv" + ) + test_url = data_url_from_file(str(test_file)) + response = await client.register_dataset( + DatasetDefWithProvider( + identifier="test-dataset", + provider_id="meta0", + url=URL( + uri=test_url, + ), + dataset_schema={ + "generated_answer": StringType(), + "expected_answer": StringType(), + "input_query": StringType(), + }, + ) + ) + + # list datasets + list_dataset = await client.list_datasets() + cprint(list_dataset, "blue") + + # datsetio client to get the rows + datasetio_client = DatasetIOClient(f"http://{host}:{port}") + response = await datasetio_client.get_rows_paginated( + dataset_id="test-dataset", + rows_in_page=4, + page_token=None, + filter_condition=None, + ) + cprint(f"Returned {len(response.rows)} rows \n {response}", "green") + + +def main(host: str, port: int): + asyncio.run(run_main(host, port)) + + +if __name__ == "__main__": + fire.Fire(main) diff --git a/llama_stack/apis/datasetio/datasetio.py b/llama_stack/apis/datasetio/datasetio.py index e8811d233..b321b260e 100644 --- a/llama_stack/apis/datasetio/datasetio.py +++ b/llama_stack/apis/datasetio/datasetio.py @@ -29,7 +29,7 @@ class DatasetIO(Protocol): # keeping for aligning with inference/safety, but this is not used dataset_store: DatasetStore - @webmethod(route="/dataio/get_rows_paginated") + @webmethod(route="/datasetio/get_rows_paginated", method="GET") async def get_rows_paginated( self, dataset_id: str, diff --git a/llama_stack/apis/datasets/client.py b/llama_stack/apis/datasets/client.py new file mode 100644 index 000000000..9e5891e74 --- /dev/null +++ b/llama_stack/apis/datasets/client.py @@ -0,0 +1,116 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +import asyncio +import json +import os +from pathlib import Path +from typing import Optional + +import fire +import httpx +from termcolor import cprint + +from .datasets import * # noqa: F403 +from llama_stack.apis.datasets import * # noqa: F403 +from llama_stack.apis.common.type_system import * # noqa: F403 +from llama_stack.providers.tests.datasetio.test_datasetio import data_url_from_file + + +class DatasetsClient(Datasets): + def __init__(self, base_url: str): + self.base_url = base_url + + async def initialize(self) -> None: + pass + + async def shutdown(self) -> None: + pass + + async def register_dataset( + self, + dataset_def: DatasetDefWithProvider, + ) -> None: + async with httpx.AsyncClient() as client: + response = await client.post( + f"{self.base_url}/datasets/register", + json={ + "dataset_def": json.loads(dataset_def.json()), + }, + headers={"Content-Type": "application/json"}, + timeout=60, + ) + response.raise_for_status() + return + + async def get_dataset( + self, + dataset_identifier: str, + ) -> Optional[DatasetDefWithProvider]: + async with httpx.AsyncClient() as client: + response = await client.get( + f"{self.base_url}/datasets/get", + params={ + "dataset_identifier": dataset_identifier, + }, + headers={"Content-Type": "application/json"}, + timeout=60, + ) + response.raise_for_status() + if not response.json(): + return + + return DatasetDefWithProvider(**response.json()) + + async def list_datasets(self) -> List[DatasetDefWithProvider]: + async with httpx.AsyncClient() as client: + response = await client.get( + f"{self.base_url}/datasets/list", + headers={"Content-Type": "application/json"}, + timeout=60, + ) + response.raise_for_status() + if not response.json(): + return + + return [DatasetDefWithProvider(**x) for x in response.json()] + + +async def run_main(host: str, port: int): + client = DatasetsClient(f"http://{host}:{port}") + + # register dataset + test_file = ( + Path(os.path.abspath(__file__)).parent.parent.parent + / "providers/tests/datasetio/test_dataset.csv" + ) + test_url = data_url_from_file(str(test_file)) + response = await client.register_dataset( + DatasetDefWithProvider( + identifier="test-dataset", + provider_id="meta0", + url=URL( + uri=test_url, + ), + dataset_schema={ + "generated_answer": StringType(), + "expected_answer": StringType(), + "input_query": StringType(), + }, + ) + ) + + # list datasets + list_dataset = await client.list_datasets() + cprint(list_dataset, "blue") + + +def main(host: str, port: int): + asyncio.run(run_main(host, port)) + + +if __name__ == "__main__": + fire.Fire(main) diff --git a/llama_stack/apis/datasets/datasets.py b/llama_stack/apis/datasets/datasets.py index e2b764d7f..7a56049bf 100644 --- a/llama_stack/apis/datasets/datasets.py +++ b/llama_stack/apis/datasets/datasets.py @@ -20,7 +20,7 @@ class DatasetDef(BaseModel): identifier: str = Field( description="A unique name for the dataset", ) - columns_schema: Dict[str, ParamType] = Field( + dataset_schema: Dict[str, ParamType] = Field( description="The schema definition for this dataset", ) url: URL diff --git a/llama_stack/apis/scoring/client.py b/llama_stack/apis/scoring/client.py new file mode 100644 index 000000000..f08fa4bc0 --- /dev/null +++ b/llama_stack/apis/scoring/client.py @@ -0,0 +1,132 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +import asyncio +import os +from pathlib import Path + +import fire +import httpx +from termcolor import cprint + +from llama_stack.apis.datasets import * # noqa: F403 +from llama_stack.apis.scoring import * # noqa: F403 +from llama_stack.apis.common.type_system import * # noqa: F403 +from llama_stack.apis.datasetio.client import DatasetIOClient +from llama_stack.apis.datasets.client import DatasetsClient +from llama_stack.providers.tests.datasetio.test_datasetio import data_url_from_file + + +class ScoringClient(Scoring): + def __init__(self, base_url: str): + self.base_url = base_url + + async def initialize(self) -> None: + pass + + async def shutdown(self) -> None: + pass + + async def score_batch( + self, dataset_id: str, scoring_functions: List[str] + ) -> ScoreBatchResponse: + async with httpx.AsyncClient() as client: + response = await client.post( + f"{self.base_url}/scoring/score_batch", + json={ + "dataset_id": dataset_id, + "scoring_functions": scoring_functions, + }, + headers={"Content-Type": "application/json"}, + timeout=60, + ) + response.raise_for_status() + if not response.json(): + return + + return ScoreBatchResponse(**response.json()) + + async def score( + self, input_rows: List[Dict[str, Any]], scoring_functions: List[str] + ) -> ScoreResponse: + async with httpx.AsyncClient() as client: + response = await client.post( + f"{self.base_url}/scoring/score", + json={ + "input_rows": input_rows, + "scoring_functions": scoring_functions, + }, + headers={"Content-Type": "application/json"}, + timeout=60, + ) + response.raise_for_status() + if not response.json(): + return + + return ScoreResponse(**response.json()) + + +async def run_main(host: str, port: int): + client = DatasetsClient(f"http://{host}:{port}") + + # register dataset + test_file = ( + Path(os.path.abspath(__file__)).parent.parent.parent + / "providers/tests/datasetio/test_dataset.csv" + ) + test_url = data_url_from_file(str(test_file)) + response = await client.register_dataset( + DatasetDefWithProvider( + identifier="test-dataset", + provider_id="meta0", + url=URL( + uri=test_url, + ), + dataset_schema={ + "generated_answer": StringType(), + "expected_answer": StringType(), + "input_query": StringType(), + }, + ) + ) + + # list datasets + list_dataset = await client.list_datasets() + cprint(list_dataset, "blue") + + # datsetio client to get the rows + datasetio_client = DatasetIOClient(f"http://{host}:{port}") + response = await datasetio_client.get_rows_paginated( + dataset_id="test-dataset", + rows_in_page=4, + page_token=None, + filter_condition=None, + ) + cprint(f"Returned {len(response.rows)} rows \n {response}", "green") + + # scoring client to score the rows + scoring_client = ScoringClient(f"http://{host}:{port}") + response = await scoring_client.score( + input_rows=response.rows, + scoring_functions=["equality"], + ) + cprint(f"score response={response}", "blue") + + # test scoring batch using datasetio api + scoring_client = ScoringClient(f"http://{host}:{port}") + response = await scoring_client.score_batch( + dataset_id="test-dataset", + scoring_functions=["equality"], + ) + cprint(f"score_batch response={response}", "cyan") + + +def main(host: str, port: int): + asyncio.run(run_main(host, port)) + + +if __name__ == "__main__": + fire.Fire(main) diff --git a/llama_stack/apis/scoring/scoring.py b/llama_stack/apis/scoring/scoring.py index ec50ecab1..adac34d55 100644 --- a/llama_stack/apis/scoring/scoring.py +++ b/llama_stack/apis/scoring/scoring.py @@ -13,18 +13,27 @@ from 
llama_models.llama3.api.datatypes import * # noqa: F403 from llama_stack.apis.scoring_functions import * # noqa: F403 -ScoringResult = Dict[str, Any] +# mapping of metric to value +ScoringResultRow = Dict[str, Any] + + +@json_schema_type +class ScoringResult(BaseModel): + score_rows: List[ScoringResultRow] + # aggregated metrics to value + aggregated_results: Dict[str, Any] @json_schema_type class ScoreBatchResponse(BaseModel): - dataset_id: str + dataset_id: Optional[str] = None + results: Dict[str, ScoringResult] @json_schema_type class ScoreResponse(BaseModel): # each key in the dict is a scoring function name - results: List[Dict[str, ScoringResult]] + results: Dict[str, ScoringResult] class ScoringFunctionStore(Protocol): @@ -37,7 +46,10 @@ class Scoring(Protocol): @webmethod(route="/scoring/score_batch") async def score_batch( - self, dataset_id: str, scoring_functions: List[str] + self, + dataset_id: str, + scoring_functions: List[str], + save_results_dataset: bool = False, ) -> ScoreBatchResponse: ... @webmethod(route="/scoring/score") diff --git a/llama_stack/apis/scoring_functions/scoring_functions.py b/llama_stack/apis/scoring_functions/scoring_functions.py index 1d71c51f3..a242215c6 100644 --- a/llama_stack/apis/scoring_functions/scoring_functions.py +++ b/llama_stack/apis/scoring_functions/scoring_functions.py @@ -4,20 +4,10 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import ( - Any, - Dict, - List, - Literal, - Optional, - Protocol, - runtime_checkable, - Union, -) +from typing import Any, Dict, List, Optional, Protocol, runtime_checkable from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel, Field -from typing_extensions import Annotated from llama_stack.apis.common.type_system import ParamType @@ -33,45 +23,37 @@ class Parameter(BaseModel): # with standard metrics so they can be rolled up? +class LLMAsJudgeContext(BaseModel): + judge_model: str + prompt_template: Optional[str] = None + + @json_schema_type -class CommonDef(BaseModel): - name: str +class ScoringFunctionDef(BaseModel): + identifier: str description: Optional[str] = None metadata: Dict[str, Any] = Field( default_factory=dict, description="Any additional metadata for this definition", ) - # Hack: same with memory_banks for union defs - provider_id: str = "" - - -@json_schema_type -class DeterministicFunctionDef(CommonDef): - type: Literal["deterministic"] = "deterministic" parameters: List[Parameter] = Field( description="List of parameters for the deterministic function", + default_factory=list, ) return_type: ParamType = Field( description="The return type of the deterministic function", ) + context: Optional[LLMAsJudgeContext] = None # We can optionally add information here to support packaging of code, etc. 
@json_schema_type -class LLMJudgeFunctionDef(CommonDef): - type: Literal["judge"] = "judge" - model: str = Field( - description="The LLM model to use for the judge function", +class ScoringFunctionDefWithProvider(ScoringFunctionDef): + provider_id: str = Field( + description="ID of the provider which serves this dataset", ) -ScoringFunctionDef = Annotated[ - Union[DeterministicFunctionDef, LLMJudgeFunctionDef], Field(discriminator="type") -] - -ScoringFunctionDefWithProvider = ScoringFunctionDef - - @runtime_checkable class ScoringFunctions(Protocol): @webmethod(route="/scoring_functions/list", method="GET") @@ -84,5 +66,5 @@ class ScoringFunctions(Protocol): @webmethod(route="/scoring_functions/register", method="POST") async def register_scoring_function( - self, function: ScoringFunctionDefWithProvider + self, function_def: ScoringFunctionDefWithProvider ) -> None: ... diff --git a/llama_stack/distribution/datatypes.py b/llama_stack/distribution/datatypes.py index 10f78b78f..318809baf 100644 --- a/llama_stack/distribution/datatypes.py +++ b/llama_stack/distribution/datatypes.py @@ -15,10 +15,12 @@ from llama_stack.apis.models import * # noqa: F403 from llama_stack.apis.shields import * # noqa: F403 from llama_stack.apis.memory_banks import * # noqa: F403 from llama_stack.apis.datasets import * # noqa: F403 +from llama_stack.apis.scoring_functions import * # noqa: F403 from llama_stack.apis.datasetio import DatasetIO from llama_stack.apis.inference import Inference from llama_stack.apis.memory import Memory from llama_stack.apis.safety import Safety +from llama_stack.apis.scoring import Scoring LLAMA_STACK_BUILD_CONFIG_VERSION = "2" LLAMA_STACK_RUN_CONFIG_VERSION = "2" @@ -32,6 +34,7 @@ RoutableObject = Union[ ShieldDef, MemoryBankDef, DatasetDef, + ScoringFunctionDef, ] RoutableObjectWithProvider = Union[ @@ -39,6 +42,7 @@ RoutableObjectWithProvider = Union[ ShieldDefWithProvider, MemoryBankDefWithProvider, DatasetDefWithProvider, + ScoringFunctionDefWithProvider, ] RoutedProtocol = Union[ @@ -46,6 +50,7 @@ RoutedProtocol = Union[ Safety, Memory, DatasetIO, + Scoring, ] diff --git a/llama_stack/distribution/distribution.py b/llama_stack/distribution/distribution.py index 53d544471..2149162a6 100644 --- a/llama_stack/distribution/distribution.py +++ b/llama_stack/distribution/distribution.py @@ -39,6 +39,10 @@ def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]: routing_table_api=Api.datasets, router_api=Api.datasetio, ), + AutoRoutedApiInfo( + routing_table_api=Api.scoring_functions, + router_api=Api.scoring, + ), ] diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py index 2e6b64a53..b9b9fb229 100644 --- a/llama_stack/distribution/resolver.py +++ b/llama_stack/distribution/resolver.py @@ -20,6 +20,8 @@ from llama_stack.apis.memory import Memory from llama_stack.apis.memory_banks import MemoryBanks from llama_stack.apis.models import Models from llama_stack.apis.safety import Safety +from llama_stack.apis.scoring import Scoring +from llama_stack.apis.scoring_functions import ScoringFunctions from llama_stack.apis.shields import Shields from llama_stack.apis.telemetry import Telemetry from llama_stack.distribution.distribution import ( @@ -42,6 +44,8 @@ def api_protocol_map() -> Dict[Api, Any]: Api.telemetry: Telemetry, Api.datasets: Datasets, Api.datasetio: DatasetIO, + Api.scoring_functions: ScoringFunctions, + Api.scoring: Scoring, } diff --git a/llama_stack/distribution/routers/__init__.py 
b/llama_stack/distribution/routers/__init__.py index 4970e93e1..2cc89848e 100644 --- a/llama_stack/distribution/routers/__init__.py +++ b/llama_stack/distribution/routers/__init__.py @@ -11,6 +11,7 @@ from .routing_tables import ( DatasetsRoutingTable, MemoryBanksRoutingTable, ModelsRoutingTable, + ScoringFunctionsRoutingTable, ShieldsRoutingTable, ) @@ -25,7 +26,9 @@ async def get_routing_table_impl( "models": ModelsRoutingTable, "shields": ShieldsRoutingTable, "datasets": DatasetsRoutingTable, + "scoring_functions": ScoringFunctionsRoutingTable, } + if api.value not in api_to_tables: raise ValueError(f"API {api.value} not found in router map") @@ -35,13 +38,20 @@ async def get_routing_table_impl( async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) -> Any: - from .routers import DatasetIORouter, InferenceRouter, MemoryRouter, SafetyRouter + from .routers import ( + DatasetIORouter, + InferenceRouter, + MemoryRouter, + SafetyRouter, + ScoringRouter, + ) api_to_routers = { "memory": MemoryRouter, "inference": InferenceRouter, "safety": SafetyRouter, "datasetio": DatasetIORouter, + "scoring": ScoringRouter, } if api.value not in api_to_routers: raise ValueError(f"API {api.value} not found in router map") diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index 31b8efa48..348d8449d 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -13,6 +13,7 @@ from llama_stack.apis.memory import * # noqa: F403 from llama_stack.apis.inference import * # noqa: F403 from llama_stack.apis.safety import * # noqa: F403 from llama_stack.apis.datasetio import * # noqa: F403 +from llama_stack.apis.scoring import * # noqa: F403 class MemoryRouter(Memory): @@ -192,3 +193,56 @@ class DatasetIORouter(DatasetIO): page_token=page_token, filter_condition=filter_condition, ) + + +class ScoringRouter(Scoring): + def __init__( + self, + routing_table: RoutingTable, + ) -> None: + self.routing_table = routing_table + + async def initialize(self) -> None: + pass + + async def shutdown(self) -> None: + pass + + async def score_batch( + self, + dataset_id: str, + scoring_functions: List[str], + save_results_dataset: bool = False, + ) -> ScoreBatchResponse: + res = {} + for fn_identifier in scoring_functions: + score_response = await self.routing_table.get_provider_impl( + fn_identifier + ).score_batch( + dataset_id=dataset_id, + scoring_functions=[fn_identifier], + ) + res.update(score_response.results) + + if save_results_dataset: + raise NotImplementedError("Save results dataset not implemented yet") + + return ScoreBatchResponse( + results=res, + ) + + async def score( + self, input_rows: List[Dict[str, Any]], scoring_functions: List[str] + ) -> ScoreResponse: + res = {} + # look up and map each scoring function to its provider impl + for fn_identifier in scoring_functions: + score_response = await self.routing_table.get_provider_impl( + fn_identifier + ).score( + input_rows=input_rows, + scoring_functions=[fn_identifier], + ) + res.update(score_response.results) + + return ScoreResponse(results=res) diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index db0946d81..dcd588a9e 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -30,6 +30,8 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> None: await p.register_memory_bank(obj) elif 
api == Api.datasetio: await p.register_dataset(obj) + elif api == Api.scoring: + await p.register_scoring_function(obj) else: raise ValueError(f"Unknown API {api} for registering object with provider") @@ -93,7 +95,15 @@ class CommonRoutingTableImpl(RoutingTable): for d in datasets: d.provider_id = pid - add_objects(datasets) + elif api == Api.scoring: + p.scoring_function_store = self + scoring_functions = await p.list_scoring_functions() + add_objects( + [ + ScoringFunctionDefWithProvider(**s.dict(), provider_id=pid) + for s in scoring_functions + ] + ) async def shutdown(self) -> None: for p in self.impls_by_provider_id.values(): @@ -109,6 +119,10 @@ class CommonRoutingTableImpl(RoutingTable): return ("Safety", "shield") elif isinstance(self, MemoryBanksRoutingTable): return ("Memory", "memory_bank") + elif isinstance(self, DatasetsRoutingTable): + return ("DatasetIO", "dataset") + elif isinstance(self, ScoringFunctionsRoutingTable): + return ("Scoring", "scoring_function") else: raise ValueError("Unknown routing table type") @@ -218,7 +232,25 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets): async def get_dataset( self, dataset_identifier: str ) -> Optional[DatasetDefWithProvider]: - return self.get_object_by_identifier(identifier) + return self.get_object_by_identifier(dataset_identifier) async def register_dataset(self, dataset_def: DatasetDefWithProvider) -> None: await self.register_object(dataset_def) + + +class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, Scoring): + async def list_scoring_functions(self) -> List[ScoringFunctionDefWithProvider]: + objects = [] + for objs in self.registry.values(): + objects.extend(objs) + return objects + + async def get_scoring_function( + self, name: str + ) -> Optional[ScoringFunctionDefWithProvider]: + return self.get_object_by_identifier(name) + + async def register_scoring_function( + self, function_def: ScoringFunctionDefWithProvider + ) -> None: + await self.register_object(function_def) diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py index d7e2d4d0c..903ff5438 100644 --- a/llama_stack/providers/datatypes.py +++ b/llama_stack/providers/datatypes.py @@ -11,10 +11,9 @@ from llama_models.schema_utils import json_schema_type from pydantic import BaseModel, Field from llama_stack.apis.datasets import DatasetDef - from llama_stack.apis.memory_banks import MemoryBankDef - from llama_stack.apis.models import ModelDef +from llama_stack.apis.scoring_functions import ScoringFunctionDef from llama_stack.apis.shields import ShieldDef @@ -25,6 +24,7 @@ class Api(Enum): agents = "agents" memory = "memory" datasetio = "datasetio" + scoring = "scoring" telemetry = "telemetry" @@ -32,6 +32,7 @@ class Api(Enum): shields = "shields" memory_banks = "memory_banks" datasets = "datasets" + scoring_functions = "scoring_functions" # built-in API inspect = "inspect" @@ -61,6 +62,14 @@ class DatasetsProtocolPrivate(Protocol): async def register_datasets(self, dataset_def: DatasetDef) -> None: ... +class ScoringFunctionsProtocolPrivate(Protocol): + async def list_scoring_functions(self) -> List[ScoringFunctionDef]: ... + + async def register_scoring_function( + self, function_def: ScoringFunctionDef + ) -> None: ... 
+ + @json_schema_type class ProviderSpec(BaseModel): api: Api diff --git a/llama_stack/providers/impls/meta_reference/datasetio/datasetio.py b/llama_stack/providers/impls/meta_reference/datasetio/datasetio.py index a8e648e46..43664f394 100644 --- a/llama_stack/providers/impls/meta_reference/datasetio/datasetio.py +++ b/llama_stack/providers/impls/meta_reference/datasetio/datasetio.py @@ -3,17 +3,20 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import io from typing import List, Optional import pandas - from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_stack.apis.datasetio import * # noqa: F403 +import base64 from abc import ABC, abstractmethod from dataclasses import dataclass +from urllib.parse import unquote from llama_stack.providers.datatypes import DatasetsProtocolPrivate +from llama_stack.providers.utils.memory.vector_store import parse_data_url from .config import MetaReferenceDatasetIOConfig @@ -52,11 +55,20 @@ class PandasDataframeDataset(BaseDataset): return len(self.df) def __getitem__(self, idx): + assert self.df is not None, "Dataset not loaded. Please call .load() first" if isinstance(idx, slice): return self.df.iloc[idx].to_dict(orient="records") else: return self.df.iloc[idx].to_dict() + def _validate_dataset_schema(self, df) -> pandas.DataFrame: + # note that we will drop any columns in dataset that are not in the schema + df = df[self.dataset_def.dataset_schema.keys()] + # check all columns in dataset schema are present + assert len(df.columns) == len(self.dataset_def.dataset_schema) + # TODO: type checking against column types in dataset schema + return df + def load(self) -> None: if self.df is not None: return @@ -87,7 +99,7 @@ class PandasDataframeDataset(BaseDataset): else: raise ValueError(f"Unsupported file type: {self.dataset_def.url}") - self.df = df + self.df = self._validate_dataset_schema(df) class MetaReferenceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate): @@ -123,7 +135,10 @@ class MetaReferenceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate): dataset_info = self.dataset_infos.get(dataset_id) dataset_info.dataset_impl.load() - if page_token is None: + if page_token and not page_token.isnumeric(): + raise ValueError("Invalid page_token") + + if page_token is None or len(page_token) == 0: next_page_token = 0 else: next_page_token = int(page_token) diff --git a/llama_stack/providers/impls/meta_reference/scoring/__init__.py b/llama_stack/providers/impls/meta_reference/scoring/__init__.py new file mode 100644 index 000000000..69d9b543a --- /dev/null +++ b/llama_stack/providers/impls/meta_reference/scoring/__init__.py @@ -0,0 +1,21 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+from typing import Dict + +from llama_stack.distribution.datatypes import Api, ProviderSpec + +from .config import MetaReferenceScoringConfig + + +async def get_provider_impl( + config: MetaReferenceScoringConfig, + deps: Dict[Api, ProviderSpec], +): + from .scoring import MetaReferenceScoringImpl + + impl = MetaReferenceScoringImpl(config, deps[Api.datasetio], deps[Api.datasets]) + await impl.initialize() + return impl diff --git a/llama_stack/providers/impls/meta_reference/scoring/config.py b/llama_stack/providers/impls/meta_reference/scoring/config.py new file mode 100644 index 000000000..bd4dcb9f0 --- /dev/null +++ b/llama_stack/providers/impls/meta_reference/scoring/config.py @@ -0,0 +1,9 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. +from llama_stack.apis.scoring import * # noqa: F401, F403 + + +class MetaReferenceScoringConfig(BaseModel): ... diff --git a/llama_stack/providers/impls/meta_reference/scoring/scorer/__init__.py b/llama_stack/providers/impls/meta_reference/scoring/scorer/__init__.py new file mode 100644 index 000000000..756f351d8 --- /dev/null +++ b/llama_stack/providers/impls/meta_reference/scoring/scorer/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. diff --git a/llama_stack/providers/impls/meta_reference/scoring/scorer/base_scorer.py b/llama_stack/providers/impls/meta_reference/scoring/scorer/base_scorer.py new file mode 100644 index 000000000..ea8a3f063 --- /dev/null +++ b/llama_stack/providers/impls/meta_reference/scoring/scorer/base_scorer.py @@ -0,0 +1,37 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. +from abc import ABC, abstractmethod +from typing import Any, Dict, List +from llama_stack.apis.scoring_functions import * # noqa: F401, F403 +from llama_stack.apis.scoring import * # noqa: F401, F403 + + +class BaseScorer(ABC): + """ + Base interface class for all meta-reference scorers. + Each scorer needs to implement the following methods: + - score_row(self, row) + - aggregate(self, scorer_results) + """ + + scoring_function_def: ScoringFunctionDef + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + def __str__(self) -> str: + return self.__class__.__name__ + + @abstractmethod + def score_row(self, input_row: Dict[str, Any]) -> ScoringResultRow: + raise NotImplementedError() + + @abstractmethod + def aggregate(self, scoring_results: List[ScoringResultRow]) -> Dict[str, Any]: + raise NotImplementedError() + + def score(self, input_rows: List[Dict[str, Any]]) -> List[ScoringResultRow]: + return [self.score_row(input_row) for input_row in input_rows] diff --git a/llama_stack/providers/impls/meta_reference/scoring/scorer/equality_scorer.py b/llama_stack/providers/impls/meta_reference/scoring/scorer/equality_scorer.py new file mode 100644 index 000000000..ce765bfb5 --- /dev/null +++ b/llama_stack/providers/impls/meta_reference/scoring/scorer/equality_scorer.py @@ -0,0 +1,49 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from llama_stack.providers.impls.meta_reference.scoring.scorer.base_scorer import ( + BaseScorer, +) +from llama_stack.apis.scoring_functions import * # noqa: F401, F403 +from llama_stack.apis.scoring import * # noqa: F401, F403 +from llama_stack.apis.common.type_system import * # noqa: F403 + + +class EqualityScorer(BaseScorer): + """ + A scorer that assigns a score of 1.0 if the input string matches the target string, and 0.0 otherwise. + """ + + scoring_function_def = ScoringFunctionDef( + identifier="equality", + description="Returns 1.0 if the input is equal to the target, 0.0 otherwise.", + parameters=[], + return_type=NumberType(), + ) + + def score_row(self, input_row: Dict[str, Any]) -> ScoringResultRow: + assert "expected_answer" in input_row, "Expected answer not found in input row." + assert ( + "generated_answer" in input_row + ), "Generated answer not found in input row." + + expected_answer = input_row["expected_answer"] + generated_answer = input_row["generated_answer"] + score = 1.0 if expected_answer == generated_answer else 0.0 + return { + "score": score, + } + + def aggregate(self, scoring_results: List[ScoringResultRow]) -> Dict[str, Any]: + assert len(scoring_results) > 0, "Empty scoring results provided." + num_correct = sum(result["score"] for result in scoring_results) + avg_score = num_correct / len(scoring_results) + + return { + "accuracy": avg_score, + "num_correct": num_correct, + "num_total": len(scoring_results), + } diff --git a/llama_stack/providers/impls/meta_reference/scoring/scoring.py b/llama_stack/providers/impls/meta_reference/scoring/scoring.py new file mode 100644 index 000000000..0d32c8195 --- /dev/null +++ b/llama_stack/providers/impls/meta_reference/scoring/scoring.py @@ -0,0 +1,109 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. +from typing import List + +from llama_models.llama3.api.datatypes import * # noqa: F403 +from llama_stack.apis.scoring import * # noqa: F403 +from llama_stack.apis.scoring_functions import * # noqa: F403 +from llama_stack.apis.common.type_system import * # noqa: F403 +from llama_stack.apis.datasetio import * # noqa: F403 +from llama_stack.apis.datasets import * # noqa: F403 + +from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate +from llama_stack.providers.impls.meta_reference.scoring.scorer.equality_scorer import ( + EqualityScorer, +) + +from .config import MetaReferenceScoringConfig + +SUPPORTED_SCORERS = [ + EqualityScorer, +] + +SCORER_REGISTRY = {x.scoring_function_def.identifier: x for x in SUPPORTED_SCORERS} + + +class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate): + def __init__( + self, + config: MetaReferenceScoringConfig, + datasetio_api: DatasetIO, + datasets_api: Datasets, + ) -> None: + self.config = config + self.datasetio_api = datasetio_api + self.datasets_api = datasets_api + + async def initialize(self) -> None: ... + + async def shutdown(self) -> None: ... 
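
To make the scorer pattern concrete, here is a hedged, standalone sketch of how the module-level `SCORER_REGISTRY` in this file and the `EqualityScorer` above fit together; the sample rows are invented for illustration and the inline result comments are expected values under that assumption.

```python
from llama_stack.providers.impls.meta_reference.scoring.scoring import SCORER_REGISTRY

rows = [
    {"input_query": "2+2?", "generated_answer": "4", "expected_answer": "4"},
    {"input_query": "Capital of France?", "generated_answer": "London", "expected_answer": "Paris"},
]

scorer = SCORER_REGISTRY["equality"]()      # EqualityScorer instance
score_rows = scorer.score(rows)             # [{"score": 1.0}, {"score": 0.0}]
aggregated = scorer.aggregate(score_rows)   # {"accuracy": 0.5, "num_correct": 1.0, "num_total": 2}
```

This mirrors what `MetaReferenceScoringImpl.score()` does for each requested scoring function before wrapping the per-row and aggregated results in a `ScoringResult`.
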
+ + async def list_scoring_functions(self) -> List[ScoringFunctionDef]: + return [x.scoring_function_def for x in SUPPORTED_SCORERS] + + async def register_scoring_function(self, function_def: ScoringFunctionDef) -> None: + raise NotImplementedError( + "Dynamically registering scoring functions is not supported" + ) + + async def validate_scoring_input_dataset_schema(self, dataset_id: str) -> None: + dataset_def = await self.datasets_api.get_dataset(dataset_identifier=dataset_id) + if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0: + raise ValueError( + f"Dataset {dataset_id} does not have a schema defined. Please define a schema for the dataset." + ) + + for required_column in ["generated_answer", "expected_answer", "input_query"]: + if required_column not in dataset_def.dataset_schema: + raise ValueError( + f"Dataset {dataset_id} does not have a '{required_column}' column." + ) + if dataset_def.dataset_schema[required_column].type != "string": + raise ValueError( + f"Dataset {dataset_id} does not have a '{required_column}' column of type 'string'." + ) + + async def score_batch( + self, + dataset_id: str, + scoring_functions: List[str], + save_results_dataset: bool = False, + ) -> ScoreBatchResponse: + await self.validate_scoring_input_dataset_schema(dataset_id=dataset_id) + all_rows = await self.datasetio_api.get_rows_paginated( + dataset_id=dataset_id, + rows_in_page=-1, + ) + res = await self.score( + input_rows=all_rows.rows, scoring_functions=scoring_functions + ) + if save_results_dataset: + # TODO: persist and register dataset on to server for reading + # self.datasets_api.register_dataset() + raise NotImplementedError("Save results dataset not implemented yet") + + return ScoreBatchResponse( + results=res.results, + ) + + async def score( + self, input_rows: List[Dict[str, Any]], scoring_functions: List[str] + ) -> ScoreResponse: + res = {} + for scoring_fn_id in scoring_functions: + if scoring_fn_id not in SCORER_REGISTRY: + raise ValueError(f"Scoring function {scoring_fn_id} is not supported.") + scorer = SCORER_REGISTRY[scoring_fn_id]() + score_results = scorer.score(input_rows) + agg_results = scorer.aggregate(score_results) + res[scoring_fn_id] = ScoringResult( + score_rows=score_results, + aggregated_results=agg_results, + ) + + return ScoreResponse( + results=res, + ) diff --git a/llama_stack/providers/registry/scoring.py b/llama_stack/providers/registry/scoring.py new file mode 100644 index 000000000..4543449b4 --- /dev/null +++ b/llama_stack/providers/registry/scoring.py @@ -0,0 +1,25 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +from typing import List + +from llama_stack.distribution.datatypes import * # noqa: F403 + + +def available_providers() -> List[ProviderSpec]: + return [ + InlineProviderSpec( + api=Api.scoring, + provider_type="meta-reference", + pip_packages=[], + module="llama_stack.providers.impls.meta_reference.scoring", + config_class="llama_stack.providers.impls.meta_reference.scoring.MetaReferenceScoringConfig", + api_dependencies=[ + Api.datasetio, + Api.datasets, + ], + ), + ] diff --git a/llama_stack/providers/tests/datasetio/test_dataset.csv b/llama_stack/providers/tests/datasetio/test_dataset.csv new file mode 100644 index 000000000..a1a250753 --- /dev/null +++ b/llama_stack/providers/tests/datasetio/test_dataset.csv @@ -0,0 +1,6 @@ +input_query,generated_answer,expected_answer +What is the capital of France?,London,Paris +Who is the CEO of Meta?,Mark Zuckerberg,Mark Zuckerberg +What is the largest planet in our solar system?,Jupiter,Jupiter +What is the smallest country in the world?,China,Vatican City +What is the currency of Japan?,Yen,Yen diff --git a/llama_stack/providers/tests/datasetio/test_datasetio.py b/llama_stack/providers/tests/datasetio/test_datasetio.py index 85235a64b..9a351ba30 100644 --- a/llama_stack/providers/tests/datasetio/test_datasetio.py +++ b/llama_stack/providers/tests/datasetio/test_datasetio.py @@ -8,8 +8,13 @@ import os import pytest import pytest_asyncio +from llama_stack.apis.common.type_system import * # noqa: F403 from llama_stack.apis.datasetio import * # noqa: F403 from llama_stack.distribution.datatypes import * # noqa: F403 +import base64 +import mimetypes +from pathlib import Path + from llama_stack.providers.tests.resolver import resolve_impls_for_test # How to run this test: @@ -41,14 +46,35 @@ async def datasetio_settings(): } +def data_url_from_file(file_path: str) -> str: + if not os.path.exists(file_path): + raise FileNotFoundError(f"File not found: {file_path}") + + with open(file_path, "rb") as file: + file_content = file.read() + + base64_content = base64.b64encode(file_content).decode("utf-8") + mime_type, _ = mimetypes.guess_type(file_path) + + data_url = f"data:{mime_type};base64,{base64_content}" + + return data_url + + async def register_dataset(datasets_impl: Datasets): + test_file = Path(os.path.abspath(__file__)).parent / "test_dataset.csv" + test_url = data_url_from_file(str(test_file)) dataset = DatasetDefWithProvider( identifier="test_dataset", provider_id=os.environ["PROVIDER_ID"], url=URL( - uri="https://openaipublic.blob.core.windows.net/simple-evals/mmlu.csv", + uri=test_url, ), - columns_schema={}, + dataset_schema={ + "generated_answer": StringType(), + "expected_answer": StringType(), + "input_query": StringType(), + }, ) await datasets_impl.register_dataset(dataset) @@ -100,10 +126,10 @@ async def test_get_rows_paginated(datasetio_settings): # iterate over all rows response = await datasetio_impl.get_rows_paginated( dataset_id="test_dataset", - rows_in_page=10, + rows_in_page=2, page_token=response.next_page_token, ) assert isinstance(response.rows, list) - assert len(response.rows) == 10 - assert response.next_page_token == "13" + assert len(response.rows) == 2 + assert response.next_page_token == "5" diff --git a/llama_stack/providers/tests/scoring/__init__.py b/llama_stack/providers/tests/scoring/__init__.py new file mode 100644 index 000000000..756f351d8 --- /dev/null +++ b/llama_stack/providers/tests/scoring/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. diff --git a/llama_stack/providers/tests/scoring/provider_config_example.yaml b/llama_stack/providers/tests/scoring/provider_config_example.yaml new file mode 100644 index 000000000..9a8895149 --- /dev/null +++ b/llama_stack/providers/tests/scoring/provider_config_example.yaml @@ -0,0 +1,9 @@ +providers: + datasetio: + - provider_id: test-meta + provider_type: meta-reference + config: {} + scoring: + - provider_id: test-meta + provider_type: meta-reference + config: {} diff --git a/llama_stack/providers/tests/scoring/test_scoring.py b/llama_stack/providers/tests/scoring/test_scoring.py new file mode 100644 index 000000000..2218faa54 --- /dev/null +++ b/llama_stack/providers/tests/scoring/test_scoring.py @@ -0,0 +1,69 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. +import pytest +import pytest_asyncio + +from llama_stack.apis.common.type_system import * # noqa: F403 +from llama_stack.apis.datasetio import * # noqa: F403 +from llama_stack.distribution.datatypes import * # noqa: F403 + +from llama_stack.providers.tests.datasetio.test_datasetio import register_dataset +from llama_stack.providers.tests.resolver import resolve_impls_for_test + +# How to run this test: +# +# 1. Ensure you have a conda with the right dependencies installed. This is a bit tricky +# since it depends on the provider you are testing. On top of that you need +# `pytest` and `pytest-asyncio` installed. +# +# 2. Copy and modify the provider_config_example.yaml depending on the provider you are testing. +# +# 3. 
Run: +# +# ```bash +# PROVIDER_ID= \ +# PROVIDER_CONFIG=provider_config.yaml \ +# pytest -s llama_stack/providers/tests/scoring/test_scoring.py \ +# --tb=short --disable-warnings +# ``` + + +@pytest_asyncio.fixture(scope="session") +async def scoring_settings(): + impls = await resolve_impls_for_test(Api.scoring, deps=[Api.datasetio]) + return { + "scoring_impl": impls[Api.scoring], + "scoring_functions_impl": impls[Api.scoring_functions], + "datasets_impl": impls[Api.datasets], + } + + +@pytest.mark.asyncio +async def test_scoring_functions_list(scoring_settings): + scoring_functions_impl = scoring_settings["scoring_functions_impl"] + scoring_functions = await scoring_functions_impl.list_scoring_functions() + assert isinstance(scoring_functions, list) + assert len(scoring_functions) > 0 + function_ids = [f.identifier for f in scoring_functions] + assert "equality" in function_ids + + +@pytest.mark.asyncio +async def test_scoring_score(scoring_settings): + scoring_impl = scoring_settings["scoring_impl"] + datasets_impl = scoring_settings["datasets_impl"] + await register_dataset(datasets_impl) + + response = await datasets_impl.list_datasets() + assert len(response) == 1 + + response = await scoring_impl.score_batch( + dataset_id=response[0].identifier, + scoring_functions=["equality"], + ) + + assert len(response.results) == 1 + assert "equality" in response.results diff --git a/tests/examples/evals-tgi-run.yaml b/tests/examples/evals-tgi-run.yaml index 8edb050cc..e56c43420 100644 --- a/tests/examples/evals-tgi-run.yaml +++ b/tests/examples/evals-tgi-run.yaml @@ -13,7 +13,12 @@ apis: - inference - datasets - datasetio +- scoring providers: + scoring: + - provider_id: meta0 + provider_type: meta-reference + config: {} datasetio: - provider_id: meta0 provider_type: meta-reference From 3e1c3fdb3fab895393ad248d7a8ae54a137838b4 Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Thu, 24 Oct 2024 16:02:41 -0700 Subject: [PATCH 13/17] completion() for tgi (#295) --- .../inference/databricks/databricks.py | 2 +- .../adapters/inference/fireworks/fireworks.py | 2 +- .../adapters/inference/ollama/ollama.py | 4 +- .../providers/adapters/inference/tgi/tgi.py | 118 +++++++++++++++--- .../adapters/inference/together/together.py | 2 +- .../providers/adapters/inference/vllm/vllm.py | 2 +- .../tests/inference/test_inference.py | 41 ++++++ .../utils/inference/openai_compat.py | 30 +++-- .../utils/inference/prompt_adapter.py | 7 ++ 9 files changed, 173 insertions(+), 35 deletions(-) diff --git a/llama_stack/providers/adapters/inference/databricks/databricks.py b/llama_stack/providers/adapters/inference/databricks/databricks.py index 4752e3fe4..f12ecb7f5 100644 --- a/llama_stack/providers/adapters/inference/databricks/databricks.py +++ b/llama_stack/providers/adapters/inference/databricks/databricks.py @@ -116,7 +116,7 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference): "model": self.map_to_provider_model(request.model), "prompt": chat_completion_request_to_prompt(request, self.formatter), "stream": request.stream, - **get_sampling_options(request), + **get_sampling_options(request.sampling_params), } async def embeddings( diff --git a/llama_stack/providers/adapters/inference/fireworks/fireworks.py b/llama_stack/providers/adapters/inference/fireworks/fireworks.py index 441f32166..69535cd3c 100644 --- a/llama_stack/providers/adapters/inference/fireworks/fireworks.py +++ b/llama_stack/providers/adapters/inference/fireworks/fireworks.py @@ -116,7 +116,7 @@ class 
FireworksInferenceAdapter(ModelRegistryHelper, Inference): if prompt.startswith("<|begin_of_text|>"): prompt = prompt[len("<|begin_of_text|>") :] - options = get_sampling_options(request) + options = get_sampling_options(request.sampling_params) options.setdefault("max_tokens", 512) if fmt := request.response_format: diff --git a/llama_stack/providers/adapters/inference/ollama/ollama.py b/llama_stack/providers/adapters/inference/ollama/ollama.py index d4fe75cfa..916241a7c 100644 --- a/llama_stack/providers/adapters/inference/ollama/ollama.py +++ b/llama_stack/providers/adapters/inference/ollama/ollama.py @@ -110,7 +110,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate): return await self._nonstream_completion(request) def _get_params_for_completion(self, request: CompletionRequest) -> dict: - sampling_options = get_sampling_options(request) + sampling_options = get_sampling_options(request.sampling_params) # This is needed since the Ollama API expects num_predict to be set # for early truncation instead of max_tokens. if sampling_options["max_tokens"] is not None: @@ -187,7 +187,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate): return { "model": OLLAMA_SUPPORTED_MODELS[request.model], "prompt": chat_completion_request_to_prompt(request, self.formatter), - "options": get_sampling_options(request), + "options": get_sampling_options(request.sampling_params), "raw": True, "stream": request.stream, } diff --git a/llama_stack/providers/adapters/inference/tgi/tgi.py b/llama_stack/providers/adapters/inference/tgi/tgi.py index f19181320..a7fa6ba00 100644 --- a/llama_stack/providers/adapters/inference/tgi/tgi.py +++ b/llama_stack/providers/adapters/inference/tgi/tgi.py @@ -24,9 +24,12 @@ from llama_stack.providers.utils.inference.openai_compat import ( OpenAICompatCompletionResponse, process_chat_completion_response, process_chat_completion_stream_response, + process_completion_response, + process_completion_stream_response, ) from llama_stack.providers.utils.inference.prompt_adapter import ( chat_completion_request_to_model_input_info, + completion_request_to_prompt_model_input_info, ) from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig @@ -75,7 +78,98 @@ class _HfAdapter(Inference, ModelsProtocolPrivate): stream: Optional[bool] = False, logprobs: Optional[LogProbConfig] = None, ) -> AsyncGenerator: - raise NotImplementedError() + request = CompletionRequest( + model=model, + content=content, + sampling_params=sampling_params, + response_format=response_format, + stream=stream, + logprobs=logprobs, + ) + if stream: + return self._stream_completion(request) + else: + return await self._nonstream_completion(request) + + def _get_max_new_tokens(self, sampling_params, input_tokens): + return min( + sampling_params.max_tokens or (self.max_tokens - input_tokens), + self.max_tokens - input_tokens - 1, + ) + + def _build_options( + self, + sampling_params: Optional[SamplingParams] = None, + fmt: ResponseFormat = None, + ): + options = get_sampling_options(sampling_params) + # delete key "max_tokens" from options since its not supported by the API + options.pop("max_tokens", None) + if fmt: + if fmt.type == ResponseFormatType.json_schema.value: + options["grammar"] = { + "type": "json", + "value": fmt.schema, + } + elif fmt.type == ResponseFormatType.grammar.value: + raise ValueError("Grammar response format not supported yet") + else: + raise ValueError(f"Unexpected response format: {fmt.type}") + + return options + + def 
_get_params_for_completion(self, request: CompletionRequest) -> dict: + prompt, input_tokens = completion_request_to_prompt_model_input_info( + request, self.formatter + ) + + return dict( + prompt=prompt, + stream=request.stream, + details=True, + max_new_tokens=self._get_max_new_tokens( + request.sampling_params, input_tokens + ), + stop_sequences=["<|eom_id|>", "<|eot_id|>"], + **self._build_options(request.sampling_params, request.response_format), + ) + + async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator: + params = self._get_params_for_completion(request) + + async def _generate_and_convert_to_openai_compat(): + s = await self.client.text_generation(**params) + async for chunk in s: + token_result = chunk.token + finish_reason = None + if chunk.details: + finish_reason = chunk.details.finish_reason + + choice = OpenAICompatCompletionChoice( + text=token_result.text, finish_reason=finish_reason + ) + yield OpenAICompatCompletionResponse( + choices=[choice], + ) + + stream = _generate_and_convert_to_openai_compat() + async for chunk in process_completion_stream_response(stream, self.formatter): + yield chunk + + async def _nonstream_completion(self, request: CompletionRequest) -> AsyncGenerator: + params = self._get_params_for_completion(request) + r = await self.client.text_generation(**params) + + choice = OpenAICompatCompletionChoice( + finish_reason=r.details.finish_reason, + text="".join(t.text for t in r.details.tokens), + ) + + response = OpenAICompatCompletionResponse( + choices=[choice], + ) + + return process_completion_response(response, self.formatter) async def chat_completion( self, @@ -146,29 +240,15 @@ class _HfAdapter(Inference, ModelsProtocolPrivate): prompt, input_tokens = chat_completion_request_to_model_input_info( request, self.formatter ) - max_new_tokens = min( - request.sampling_params.max_tokens or (self.max_tokens - input_tokens), - self.max_tokens - input_tokens - 1, - ) - options = get_sampling_options(request) - if fmt := request.response_format: - if fmt.type == ResponseFormatType.json_schema.value: - options["grammar"] = { - "type": "json", - "value": fmt.schema, - } - elif fmt.type == ResponseFormatType.grammar.value: - raise ValueError("Grammar response format not supported yet") - else: - raise ValueError(f"Unexpected response format: {fmt.type}") - return dict( prompt=prompt, stream=request.stream, details=True, - max_new_tokens=max_new_tokens, + max_new_tokens=self._get_max_new_tokens( + request.sampling_params, input_tokens + ), stop_sequences=["<|eom_id|>", "<|eot_id|>"], - **options, + **self._build_options(request.sampling_params, request.response_format), ) async def embeddings( diff --git a/llama_stack/providers/adapters/inference/together/together.py b/llama_stack/providers/adapters/inference/together/together.py index 2f258e620..daf57497a 100644 --- a/llama_stack/providers/adapters/inference/together/together.py +++ b/llama_stack/providers/adapters/inference/together/together.py @@ -131,7 +131,7 @@ class TogetherInferenceAdapter( yield chunk def _get_params(self, request: ChatCompletionRequest) -> dict: - options = get_sampling_options(request) + options = get_sampling_options(request.sampling_params) if fmt := request.response_format: if fmt.type == ResponseFormatType.json_schema.value: options["response_format"] = { diff --git a/llama_stack/providers/adapters/inference/vllm/vllm.py b/llama_stack/providers/adapters/inference/vllm/vllm.py index dacf646b0..4687618fa 100644 --- 
a/llama_stack/providers/adapters/inference/vllm/vllm.py +++ b/llama_stack/providers/adapters/inference/vllm/vllm.py @@ -143,7 +143,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate): "model": VLLM_SUPPORTED_MODELS[request.model], "prompt": chat_completion_request_to_prompt(request, self.formatter), "stream": request.stream, - **get_sampling_options(request), + **get_sampling_options(request.sampling_params), } async def embeddings( diff --git a/llama_stack/providers/tests/inference/test_inference.py b/llama_stack/providers/tests/inference/test_inference.py index ad49448e2..c7cbdd592 100644 --- a/llama_stack/providers/tests/inference/test_inference.py +++ b/llama_stack/providers/tests/inference/test_inference.py @@ -137,6 +137,7 @@ async def test_completion(inference_settings): if provider.__provider_spec__.provider_type not in ( "meta-reference", "remote::ollama", + "remote::tgi", ): pytest.skip("Other inference providers don't support completion() yet") @@ -170,6 +171,46 @@ async def test_completion(inference_settings): assert last.stop_reason == StopReason.out_of_tokens +@pytest.mark.asyncio +async def test_completions_structured_output(inference_settings): + inference_impl = inference_settings["impl"] + params = inference_settings["common_params"] + + provider = inference_impl.routing_table.get_provider_impl(params["model"]) + if provider.__provider_spec__.provider_type not in ( + "meta-reference", + "remote::tgi", + ): + pytest.skip( + "Other inference providers don't support structured output in completions yet" + ) + + class Output(BaseModel): + name: str + year_born: str + year_retired: str + + user_input = "Michael Jordan was born in 1963. He played basketball for the Chicago Bulls. He retired in 2003." + response = await inference_impl.completion( + content=f"input: '{user_input}'. 
the schema for json: {Output.schema()}, the json is: ", + stream=False, + model=params["model"], + sampling_params=SamplingParams( + max_tokens=50, + ), + response_format=JsonResponseFormat( + schema=Output.model_json_schema(), + ), + ) + assert isinstance(response, CompletionResponse) + assert isinstance(response.content, str) + + answer = Output.parse_raw(response.content) + assert answer.name == "Michael Jordan" + assert answer.year_born == "1963" + assert answer.year_retired == "2003" + + @pytest.mark.asyncio async def test_chat_completion_non_streaming(inference_settings, sample_messages): inference_impl = inference_settings["impl"] diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py index 22ae8a717..086227c73 100644 --- a/llama_stack/providers/utils/inference/openai_compat.py +++ b/llama_stack/providers/utils/inference/openai_compat.py @@ -29,9 +29,9 @@ class OpenAICompatCompletionResponse(BaseModel): choices: List[OpenAICompatCompletionChoice] -def get_sampling_options(request: ChatCompletionRequest) -> dict: +def get_sampling_options(params: SamplingParams) -> dict: options = {} - if params := request.sampling_params: + if params: for attr in {"temperature", "top_p", "top_k", "max_tokens"}: if getattr(params, attr): options[attr] = getattr(params, attr) @@ -64,7 +64,18 @@ def process_completion_response( response: OpenAICompatCompletionResponse, formatter: ChatFormat ) -> CompletionResponse: choice = response.choices[0] - + # drop suffix if present and return stop reason as end of turn + if choice.text.endswith("<|eot_id|>"): + return CompletionResponse( + stop_reason=StopReason.end_of_turn, + content=choice.text[: -len("<|eot_id|>")], + ) + # drop suffix if present and return stop reason as end of message + if choice.text.endswith("<|eom_id|>"): + return CompletionResponse( + stop_reason=StopReason.end_of_message, + content=choice.text[: -len("<|eom_id|>")], + ) return CompletionResponse( stop_reason=get_stop_reason(choice.finish_reason), content=choice.text, @@ -95,13 +106,6 @@ async def process_completion_stream_response( choice = chunk.choices[0] finish_reason = choice.finish_reason - if finish_reason: - if finish_reason in ["stop", "eos", "eos_token"]: - stop_reason = StopReason.end_of_turn - elif finish_reason == "length": - stop_reason = StopReason.out_of_tokens - break - text = text_from_choice(choice) if text == "<|eot_id|>": stop_reason = StopReason.end_of_turn @@ -115,6 +119,12 @@ async def process_completion_stream_response( delta=text, stop_reason=stop_reason, ) + if finish_reason: + if finish_reason in ["stop", "eos", "eos_token"]: + stop_reason = StopReason.end_of_turn + elif finish_reason == "length": + stop_reason = StopReason.out_of_tokens + break yield CompletionResponseStreamChunk( delta="", diff --git a/llama_stack/providers/utils/inference/prompt_adapter.py b/llama_stack/providers/utils/inference/prompt_adapter.py index 48f1df02f..d204ab728 100644 --- a/llama_stack/providers/utils/inference/prompt_adapter.py +++ b/llama_stack/providers/utils/inference/prompt_adapter.py @@ -31,6 +31,13 @@ def completion_request_to_prompt( return formatter.tokenizer.decode(model_input.tokens) +def completion_request_to_prompt_model_input_info( + request: CompletionRequest, formatter: ChatFormat +) -> Tuple[str, int]: + model_input = formatter.encode_content(request.content) + return (formatter.tokenizer.decode(model_input.tokens), len(model_input.tokens)) + + def chat_completion_request_to_prompt( request: 
ChatCompletionRequest, formatter: ChatFormat ) -> str: From b6d8246b82dd7ae6e9f6a698b6520f684fabe71d Mon Sep 17 00:00:00 2001 From: Justin Lee Date: Thu, 24 Oct 2024 17:07:06 -0700 Subject: [PATCH 14/17] added templates and enhanced readme (#307) Co-authored-by: Justin Lee --- .github/ISSUE_TEMPLATE/bug.yml | 77 ++++++ .github/ISSUE_TEMPLATE/feature-request.yml | 31 +++ .github/PULL_REQUEST_TEMPLATE.md | 31 +++ README.md | 29 ++- docs/getting_started.md | 261 +++++++++++---------- 5 files changed, 293 insertions(+), 136 deletions(-) create mode 100644 .github/ISSUE_TEMPLATE/bug.yml create mode 100644 .github/ISSUE_TEMPLATE/feature-request.yml create mode 100644 .github/PULL_REQUEST_TEMPLATE.md diff --git a/.github/ISSUE_TEMPLATE/bug.yml b/.github/ISSUE_TEMPLATE/bug.yml new file mode 100644 index 000000000..1f7dabb9f --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug.yml @@ -0,0 +1,77 @@ +name: πŸ› Bug Report +description: Create a report to help us reproduce and fix the bug + +body: + - type: markdown + attributes: + value: > + #### Before submitting a bug, please make sure the issue hasn't been already addressed by searching through [the + existing and past issues](https://github.com/meta-llama/llama-stack/issues). + + - type: textarea + id: system-info + attributes: + label: System Info + description: | + Please share your system info with us. You can use the following command to capture your environment information + python -m "torch.utils.collect_env" + + placeholder: | + PyTorch version, CUDA version, GPU type, #num of GPUs... + validations: + required: true + + - type: checkboxes + id: information-scripts-examples + attributes: + label: Information + description: 'The problem arises when using:' + options: + - label: "The official example scripts" + - label: "My own modified scripts" + + - type: textarea + id: bug-description + attributes: + label: πŸ› Describe the bug + description: | + Please provide a clear and concise description of what the bug is. + + Please also paste or describe the results you observe instead of the expected results. + placeholder: | + A clear and concise description of what the bug is. + + ```llama stack + # Command that you used for running the examples + ``` + Description of the results + validations: + required: true + + - type: textarea + attributes: + label: Error logs + description: | + If you observe an error, please paste the error message including the **full** traceback of the exception. It may be relevant to wrap error messages in ```` ```triple quotes blocks``` ````. + + placeholder: | + ``` + The error message you got, with the full traceback. + ``` + + validations: + required: true + + + - type: textarea + id: expected-behavior + validations: + required: true + attributes: + label: Expected behavior + description: "A clear and concise description of what you would expect to happen." + + - type: markdown + attributes: + value: > + Thanks for contributing πŸŽ‰! diff --git a/.github/ISSUE_TEMPLATE/feature-request.yml b/.github/ISSUE_TEMPLATE/feature-request.yml new file mode 100644 index 000000000..db1a43139 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/feature-request.yml @@ -0,0 +1,31 @@ +name: πŸš€ Feature request +description: Submit a proposal/request for a new llama-stack feature + +body: +- type: textarea + id: feature-pitch + attributes: + label: πŸš€ The feature, motivation and pitch + description: > + A clear and concise description of the feature proposal. Please outline the motivation for the proposal. 
Is your feature request related to a specific problem? e.g., *"I'm working on X and would like Y to be possible"*. If this is related to another GitHub issue, please link here too. + validations: + required: true + +- type: textarea + id: alternatives + attributes: + label: Alternatives + description: > + A description of any alternative solutions or features you've considered, if any. + +- type: textarea + id: additional-context + attributes: + label: Additional context + description: > + Add any other context or screenshots about the feature request. + +- type: markdown + attributes: + value: > + Thanks for contributing πŸŽ‰! diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md new file mode 100644 index 000000000..a92442dc1 --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -0,0 +1,31 @@ +# What does this PR do? + +Closes # (issue) + +## Feature/Issue validation/testing/test plan + +Please describe the tests that you ran to verify your changes and relevant result summary. Provide instructions so it can be reproduced. +Please also list any relevant details for your test configuration or test plan. + +- [ ] Test A +Logs for Test A + +- [ ] Test B +Logs for Test B + + +## Sources + +Please link relevant resources if necessary. + + +## Before submitting +- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). +- [ ] Did you read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), + Pull Request section? +- [ ] Was this discussed/approved via a Github issue? Please add a link + to it if that's the case. +- [ ] Did you make sure to update the documentation with your changes? +- [ ] Did you write any new necessary tests? + +Thanks for contributing πŸŽ‰! diff --git a/README.md b/README.md index 973a9a396..251b81513 100644 --- a/README.md +++ b/README.md @@ -65,23 +65,30 @@ A Distribution is where APIs and Providers are assembled together to provide a c | Dell-TGI | [Local TGI + Chroma](https://hub.docker.com/repository/docker/llamastack/llamastack-local-tgi-chroma/general) | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | + ## Installation -You can install this repository as a [package](https://pypi.org/project/llama-stack/) with `pip install llama-stack` +You have two ways to install this repository: -If you want to install from source: +1. **Install as a package**: + You can install the repository directly from [PyPI](https://pypi.org/project/llama-stack/) by running the following command: + ```bash + pip install llama-stack + ``` -```bash -mkdir -p ~/local -cd ~/local -git clone git@github.com:meta-llama/llama-stack.git +2. **Install from source**: + If you prefer to install from the source code, follow these steps: + ```bash + mkdir -p ~/local + cd ~/local + git clone git@github.com:meta-llama/llama-stack.git -conda create -n stack python=3.10 -conda activate stack + conda create -n stack python=3.10 + conda activate stack -cd llama-stack -$CONDA_PREFIX/bin/pip install -e . -``` + cd llama-stack + $CONDA_PREFIX/bin/pip install -e . + ``` ## Documentations diff --git a/docs/getting_started.md b/docs/getting_started.md index e08885a72..4f06f5d47 100644 --- a/docs/getting_started.md +++ b/docs/getting_started.md @@ -5,163 +5,174 @@ This guide will walk you though the steps to get started on end-to-end flow for ## Installation The `llama` CLI tool helps you setup and use the Llama toolchain & agentic systems. 
It should be available on your path after installing the `llama-stack` package. -You can install this repository as a [package](https://pypi.org/project/llama-stack/) with `pip install llama-stack` +You have two ways to install this repository: -If you want to install from source: +1. **Install as a package**: + You can install the repository directly from [PyPI](https://pypi.org/project/llama-stack/) by running the following command: + ```bash + pip install llama-stack + ``` -```bash -mkdir -p ~/local -cd ~/local -git clone git@github.com:meta-llama/llama-stack.git +2. **Install from source**: + If you prefer to install from the source code, follow these steps: + ```bash + mkdir -p ~/local + cd ~/local + git clone git@github.com:meta-llama/llama-stack.git -conda create -n stack python=3.10 -conda activate stack + conda create -n stack python=3.10 + conda activate stack -cd llama-stack -$CONDA_PREFIX/bin/pip install -e . -``` + cd llama-stack + $CONDA_PREFIX/bin/pip install -e . + ``` For what you can do with the Llama CLI, please refer to [CLI Reference](./cli_reference.md). ## Starting Up Llama Stack Server -#### Starting up server via docker -We provide 2 pre-built Docker image of Llama Stack distribution, which can be found in the following links. -- [llamastack-local-gpu](https://hub.docker.com/repository/docker/llamastack/llamastack-local-gpu/general) - - This is a packaged version with our local meta-reference implementations, where you will be running inference locally with downloaded Llama model checkpoints. -- [llamastack-local-cpu](https://hub.docker.com/repository/docker/llamastack/llamastack-local-cpu/general) - - This is a lite version with remote inference where you can hook up to your favourite remote inference framework (e.g. ollama, fireworks, together, tgi) for running inference without GPU. +You have two ways to start up Llama stack server: -> [!NOTE] -> For GPU inference, you need to set these environment variables for specifying local directory containing your model checkpoints, and enable GPU inference to start running docker container. -``` -export LLAMA_CHECKPOINT_DIR=~/.llama -``` +1. **Starting up server via docker**: -> [!NOTE] -> `~/.llama` should be the path containing downloaded weights of Llama models. + We provide 2 pre-built Docker image of Llama Stack distribution, which can be found in the following links. + - [llamastack-local-gpu](https://hub.docker.com/repository/docker/llamastack/llamastack-local-gpu/general) + - This is a packaged version with our local meta-reference implementations, where you will be running inference locally with downloaded Llama model checkpoints. + - [llamastack-local-cpu](https://hub.docker.com/repository/docker/llamastack/llamastack-local-cpu/general) + - This is a lite version with remote inference where you can hook up to your favourite remote inference framework (e.g. ollama, fireworks, together, tgi) for running inference without GPU. -To download llama models, use -``` -llama download --model-id Llama3.1-8B-Instruct -``` + > [!NOTE] + > For GPU inference, you need to set these environment variables for specifying local directory containing your model checkpoints, and enable GPU inference to start running docker container. + ``` + export LLAMA_CHECKPOINT_DIR=~/.llama + ``` -To download and start running a pre-built docker container, you may use the following commands: + > [!NOTE] + > `~/.llama` should be the path containing downloaded weights of Llama models. 
-``` -docker run -it -p 5000:5000 -v ~/.llama:/root/.llama --gpus=all llamastack/llamastack-local-gpu -``` + To download llama models, use + ``` + llama download --model-id Llama3.1-8B-Instruct + ``` -> [!TIP] -> Pro Tip: We may use `docker compose up` for starting up a distribution with remote providers (e.g. TGI) using [llamastack-local-cpu](https://hub.docker.com/repository/docker/llamastack/llamastack-local-cpu/general). You can checkout [these scripts](../distributions/) to help you get started. + To download and start running a pre-built docker container, you may use the following commands: -#### Build->Configure->Run Llama Stack server via conda -You may also build a LlamaStack distribution from scratch, configure it, and start running the distribution. This is useful for developing on LlamaStack. + ``` + docker run -it -p 5000:5000 -v ~/.llama:/root/.llama --gpus=all llamastack/llamastack-local-gpu + ``` -**`llama stack build`** -- You'll be prompted to enter build information interactively. -``` -llama stack build + > [!TIP] + > Pro Tip: We may use `docker compose up` for starting up a distribution with remote providers (e.g. TGI) using [llamastack-local-cpu](https://hub.docker.com/repository/docker/llamastack/llamastack-local-cpu/general). You can checkout [these scripts](../distributions/) to help you get started. -> Enter an unique name for identifying your Llama Stack build distribution (e.g. my-local-stack): my-local-stack -> Enter the image type you want your distribution to be built with (docker or conda): conda - Llama Stack is composed of several APIs working together. Let's configure the providers (implementations) you want to use for these APIs. -> Enter the API provider for the inference API: (default=meta-reference): meta-reference -> Enter the API provider for the safety API: (default=meta-reference): meta-reference -> Enter the API provider for the agents API: (default=meta-reference): meta-reference -> Enter the API provider for the memory API: (default=meta-reference): meta-reference -> Enter the API provider for the telemetry API: (default=meta-reference): meta-reference +2. **Build->Configure->Run Llama Stack server via conda**: - > (Optional) Enter a short description for your Llama Stack distribution: + You may also build a LlamaStack distribution from scratch, configure it, and start running the distribution. This is useful for developing on LlamaStack. -Build spec configuration saved at ~/.conda/envs/llamastack-my-local-stack/my-local-stack-build.yaml -You can now run `llama stack configure my-local-stack` -``` + **`llama stack build`** + - You'll be prompted to enter build information interactively. + ``` + llama stack build -**`llama stack configure`** -- Run `llama stack configure ` with the name you have previously defined in `build` step. -``` -llama stack configure -``` -- You will be prompted to enter configurations for your Llama Stack + > Enter an unique name for identifying your Llama Stack build distribution (e.g. my-local-stack): my-local-stack + > Enter the image type you want your distribution to be built with (docker or conda): conda -``` -$ llama stack configure my-local-stack + Llama Stack is composed of several APIs working together. Let's configure the providers (implementations) you want to use for these APIs. 
+ > Enter the API provider for the inference API: (default=meta-reference): meta-reference + > Enter the API provider for the safety API: (default=meta-reference): meta-reference + > Enter the API provider for the agents API: (default=meta-reference): meta-reference + > Enter the API provider for the memory API: (default=meta-reference): meta-reference + > Enter the API provider for the telemetry API: (default=meta-reference): meta-reference -Could not find my-local-stack. Trying conda build name instead... -Configuring API `inference`... -=== Configuring provider `meta-reference` for API inference... -Enter value for model (default: Llama3.1-8B-Instruct) (required): -Do you want to configure quantization? (y/n): n -Enter value for torch_seed (optional): -Enter value for max_seq_len (default: 4096) (required): -Enter value for max_batch_size (default: 1) (required): + > (Optional) Enter a short description for your Llama Stack distribution: -Configuring API `safety`... -=== Configuring provider `meta-reference` for API safety... -Do you want to configure llama_guard_shield? (y/n): n -Do you want to configure prompt_guard_shield? (y/n): n + Build spec configuration saved at ~/.conda/envs/llamastack-my-local-stack/my-local-stack-build.yaml + You can now run `llama stack configure my-local-stack` + ``` -Configuring API `agents`... -=== Configuring provider `meta-reference` for API agents... -Enter `type` for persistence_store (options: redis, sqlite, postgres) (default: sqlite): + **`llama stack configure`** + - Run `llama stack configure ` with the name you have previously defined in `build` step. + ``` + llama stack configure + ``` + - You will be prompted to enter configurations for your Llama Stack -Configuring SqliteKVStoreConfig: -Enter value for namespace (optional): -Enter value for db_path (default: /home/xiyan/.llama/runtime/kvstore.db) (required): + ``` + $ llama stack configure my-local-stack -Configuring API `memory`... -=== Configuring provider `meta-reference` for API memory... -> Please enter the supported memory bank type your provider has for memory: vector + Could not find my-local-stack. Trying conda build name instead... + Configuring API `inference`... + === Configuring provider `meta-reference` for API inference... + Enter value for model (default: Llama3.1-8B-Instruct) (required): + Do you want to configure quantization? (y/n): n + Enter value for torch_seed (optional): + Enter value for max_seq_len (default: 4096) (required): + Enter value for max_batch_size (default: 1) (required): -Configuring API `telemetry`... -=== Configuring provider `meta-reference` for API telemetry... + Configuring API `safety`... + === Configuring provider `meta-reference` for API safety... + Do you want to configure llama_guard_shield? (y/n): n + Do you want to configure prompt_guard_shield? (y/n): n -> YAML configuration has been written to ~/.llama/builds/conda/my-local-stack-run.yaml. -You can now run `llama stack run my-local-stack --port PORT` -``` + Configuring API `agents`... + === Configuring provider `meta-reference` for API agents... + Enter `type` for persistence_store (options: redis, sqlite, postgres) (default: sqlite): -**`llama stack run`** -- Run `llama stack run ` with the name you have previously defined. -``` -llama stack run my-local-stack + Configuring SqliteKVStoreConfig: + Enter value for namespace (optional): + Enter value for db_path (default: /home/xiyan/.llama/runtime/kvstore.db) (required): -... 
-> initializing model parallel with size 1 -> initializing ddp with size 1 -> initializing pipeline with size 1 -... -Finished model load YES READY -Serving POST /inference/chat_completion -Serving POST /inference/completion -Serving POST /inference/embeddings -Serving POST /memory_banks/create -Serving DELETE /memory_bank/documents/delete -Serving DELETE /memory_banks/drop -Serving GET /memory_bank/documents/get -Serving GET /memory_banks/get -Serving POST /memory_bank/insert -Serving GET /memory_banks/list -Serving POST /memory_bank/query -Serving POST /memory_bank/update -Serving POST /safety/run_shield -Serving POST /agentic_system/create -Serving POST /agentic_system/session/create -Serving POST /agentic_system/turn/create -Serving POST /agentic_system/delete -Serving POST /agentic_system/session/delete -Serving POST /agentic_system/session/get -Serving POST /agentic_system/step/get -Serving POST /agentic_system/turn/get -Serving GET /telemetry/get_trace -Serving POST /telemetry/log_event -Listening on :::5000 -INFO: Started server process [587053] -INFO: Waiting for application startup. -INFO: Application startup complete. -INFO: Uvicorn running on http://[::]:5000 (Press CTRL+C to quit) -``` + Configuring API `memory`... + === Configuring provider `meta-reference` for API memory... + > Please enter the supported memory bank type your provider has for memory: vector + + Configuring API `telemetry`... + === Configuring provider `meta-reference` for API telemetry... + + > YAML configuration has been written to ~/.llama/builds/conda/my-local-stack-run.yaml. + You can now run `llama stack run my-local-stack --port PORT` + ``` + + **`llama stack run`** + - Run `llama stack run ` with the name you have previously defined. + ``` + llama stack run my-local-stack + + ... + > initializing model parallel with size 1 + > initializing ddp with size 1 + > initializing pipeline with size 1 + ... + Finished model load YES READY + Serving POST /inference/chat_completion + Serving POST /inference/completion + Serving POST /inference/embeddings + Serving POST /memory_banks/create + Serving DELETE /memory_bank/documents/delete + Serving DELETE /memory_banks/drop + Serving GET /memory_bank/documents/get + Serving GET /memory_banks/get + Serving POST /memory_bank/insert + Serving GET /memory_banks/list + Serving POST /memory_bank/query + Serving POST /memory_bank/update + Serving POST /safety/run_shield + Serving POST /agentic_system/create + Serving POST /agentic_system/session/create + Serving POST /agentic_system/turn/create + Serving POST /agentic_system/delete + Serving POST /agentic_system/session/delete + Serving POST /agentic_system/session/get + Serving POST /agentic_system/step/get + Serving POST /agentic_system/turn/get + Serving GET /telemetry/get_trace + Serving POST /telemetry/log_event + Listening on :::5000 + INFO: Started server process [587053] + INFO: Waiting for application startup. + INFO: Application startup complete. 
+   INFO: Uvicorn running on http://[::]:5000 (Press CTRL+C to quit)
+   ```
 
 ## Testing with client

From df141b6ef3b864da586e0c32a1fcd8f6b2ce6351 Mon Sep 17 00:00:00 2001
From: Sarthak Deshpande <60317842+cheesecake100201@users.noreply.github.com>
Date: Fri, 25 Oct 2024 07:06:27 +0530
Subject: [PATCH 15/17] Fix for get_agents_session (#300)

---
 llama_stack/providers/impls/meta_reference/agents/agents.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llama_stack/providers/impls/meta_reference/agents/agents.py b/llama_stack/providers/impls/meta_reference/agents/agents.py
index ca5a00359..13d9044fd 100644
--- a/llama_stack/providers/impls/meta_reference/agents/agents.py
+++ b/llama_stack/providers/impls/meta_reference/agents/agents.py
@@ -169,7 +169,7 @@ class MetaReferenceAgentsImpl(Agents):
         turn_ids: Optional[List[str]] = None,
     ) -> Session:
         session = await self.persistence_store.get(f"session:{agent_id}:{session_id}")
-        session = Session(**json.loads(session))
+        session = Session(**json.loads(session), turns=[])
         turns = []
         if turn_ids:
             for turn_id in turn_ids:

From cb43caa2c3cb3ff9d23eca281b6fda2c14e73ec1 Mon Sep 17 00:00:00 2001
From: Xi Yan
Date: Thu, 24 Oct 2024 21:29:07 -0700
Subject: [PATCH 16/17] start_container.sh prefix llamastack->distribution name

---
 llama_stack/distribution/start_container.sh | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llama_stack/distribution/start_container.sh b/llama_stack/distribution/start_container.sh
index 8533da7d1..fe1b5051f 100755
--- a/llama_stack/distribution/start_container.sh
+++ b/llama_stack/distribution/start_container.sh
@@ -29,7 +29,7 @@ if [ $# -lt 3 ]; then
 fi
 
 build_name="$1"
-docker_image="llamastack-$build_name"
+docker_image="distribution-$build_name"
 shift
 
 yaml_config="$1"

From 70d59b0f5de914bea011ae5256f3a24330dbb83b Mon Sep 17 00:00:00 2001
From: Ashwin Bharambe
Date: Thu, 24 Oct 2024 22:30:49 -0700
Subject: [PATCH 17/17] Make vllm inference better

Tests still don't pass completely (some hang) so I think there are some
potential threading issues maybe
---
 llama_stack/providers/impls/vllm/config.py |  13 +-
 llama_stack/providers/impls/vllm/vllm.py   | 155 ++++++++++-----------
 2 files changed, 84 insertions(+), 84 deletions(-)

diff --git a/llama_stack/providers/impls/vllm/config.py b/llama_stack/providers/impls/vllm/config.py
index df2526f2e..a7469ebde 100644
--- a/llama_stack/providers/impls/vllm/config.py
+++ b/llama_stack/providers/impls/vllm/config.py
@@ -15,13 +15,24 @@ class VLLMConfig(BaseModel):
     """Configuration for the vLLM inference provider."""
 
     model: str = Field(
-        default="Llama3.1-8B-Instruct",
+        default="Llama3.2-3B-Instruct",
        description="Model descriptor from `llama model list`",
     )
     tensor_parallel_size: int = Field(
         default=1,
         description="Number of tensor parallel replicas (number of GPUs to use).",
     )
+    max_tokens: int = Field(
+        default=4096,
+        description="Maximum number of tokens to generate.",
+    )
+    enforce_eager: bool = Field(
+        default=False,
+        description="Whether to use eager mode for inference (otherwise cuda graphs are used).",
+    )
+    gpu_memory_utilization: float = Field(
+        default=0.3,
+    )
 
     @field_validator("model")
     @classmethod

diff --git a/llama_stack/providers/impls/vllm/vllm.py b/llama_stack/providers/impls/vllm/vllm.py
index ad3ad8fb7..cf5b0572b 100644
--- a/llama_stack/providers/impls/vllm/vllm.py
+++ b/llama_stack/providers/impls/vllm/vllm.py
@@ -7,11 +7,12 @@
 import logging
 import os
 import uuid
-from typing import Any, AsyncGenerator
+from typing import AsyncGenerator, Optional
 
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.datatypes import *  # noqa: F403
 from llama_models.llama3.api.tokenizer import Tokenizer
+from llama_models.sku_list import resolve_model
 
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
@@ -19,7 +20,7 @@
 from vllm.sampling_params import SamplingParams as VLLMSamplingParams
 
 from llama_stack.apis.inference import *  # noqa: F403
-from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
+from llama_stack.providers.datatypes import ModelDef, ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.openai_compat import (
     OpenAICompatCompletionChoice,
     OpenAICompatCompletionResponse,
@@ -40,74 +41,15 @@ def _random_uuid() -> str:
     return str(uuid.uuid4().hex)
 
 
-def _vllm_sampling_params(sampling_params: Any) -> VLLMSamplingParams:
-    """Convert sampling params to vLLM sampling params."""
-    if sampling_params is None:
-        return VLLMSamplingParams()
-
-    # TODO convert what I saw in my first test ... but surely there's more to do here
-    kwargs = {
-        "temperature": sampling_params.temperature,
-    }
-    if sampling_params.top_k >= 1:
-        kwargs["top_k"] = sampling_params.top_k
-    if sampling_params.top_p:
-        kwargs["top_p"] = sampling_params.top_p
-    if sampling_params.max_tokens >= 1:
-        kwargs["max_tokens"] = sampling_params.max_tokens
-    if sampling_params.repetition_penalty > 0:
-        kwargs["repetition_penalty"] = sampling_params.repetition_penalty
-
-    return VLLMSamplingParams(**kwargs)
-
-
-class VLLMInferenceImpl(ModelRegistryHelper, Inference):
+class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
     """Inference implementation for vLLM."""
 
-    HF_MODEL_MAPPINGS = {
-        # TODO: seems like we should be able to build this table dynamically ...
-        "Llama3.1-8B": "meta-llama/Llama-3.1-8B",
-        "Llama3.1-70B": "meta-llama/Llama-3.1-70B",
-        "Llama3.1-405B:bf16-mp8": "meta-llama/Llama-3.1-405B",
-        "Llama3.1-405B": "meta-llama/Llama-3.1-405B-FP8",
-        "Llama3.1-405B:bf16-mp16": "meta-llama/Llama-3.1-405B",
-        "Llama3.1-8B-Instruct": "meta-llama/Llama-3.1-8B-Instruct",
-        "Llama3.1-70B-Instruct": "meta-llama/Llama-3.1-70B-Instruct",
-        "Llama3.1-405B-Instruct:bf16-mp8": "meta-llama/Llama-3.1-405B-Instruct",
-        "Llama3.1-405B-Instruct": "meta-llama/Llama-3.1-405B-Instruct-FP8",
-        "Llama3.1-405B-Instruct:bf16-mp16": "meta-llama/Llama-3.1-405B-Instruct",
-        "Llama3.2-1B": "meta-llama/Llama-3.2-1B",
-        "Llama3.2-3B": "meta-llama/Llama-3.2-3B",
-        "Llama3.2-11B-Vision": "meta-llama/Llama-3.2-11B-Vision",
-        "Llama3.2-90B-Vision": "meta-llama/Llama-3.2-90B-Vision",
-        "Llama3.2-1B-Instruct": "meta-llama/Llama-3.2-1B-Instruct",
-        "Llama3.2-3B-Instruct": "meta-llama/Llama-3.2-3B-Instruct",
-        "Llama3.2-11B-Vision-Instruct": "meta-llama/Llama-3.2-11B-Vision-Instruct",
-        "Llama3.2-90B-Vision-Instruct": "meta-llama/Llama-3.2-90B-Vision-Instruct",
-        "Llama-Guard-3-11B-Vision": "meta-llama/Llama-Guard-3-11B-Vision",
-        "Llama-Guard-3-1B:int4-mp1": "meta-llama/Llama-Guard-3-1B-INT4",
-        "Llama-Guard-3-1B": "meta-llama/Llama-Guard-3-1B",
-        "Llama-Guard-3-8B": "meta-llama/Llama-Guard-3-8B",
-        "Llama-Guard-3-8B:int8-mp1": "meta-llama/Llama-Guard-3-8B-INT8",
-        "Prompt-Guard-86M": "meta-llama/Prompt-Guard-86M",
-        "Llama-Guard-2-8B": "meta-llama/Llama-Guard-2-8B",
-    }
-
     def __init__(self, config: VLLMConfig):
-        Inference.__init__(self)
-        ModelRegistryHelper.__init__(
-            self,
-            stack_to_provider_models_map=self.HF_MODEL_MAPPINGS,
-        )
         self.config = config
         self.engine = None
-
-        tokenizer = Tokenizer.get_instance()
-        self.formatter = ChatFormat(tokenizer)
+        self.formatter = ChatFormat(Tokenizer.get_instance())
 
     async def initialize(self):
-        """Initialize the vLLM inference adapter."""
-
         log.info("Initializing vLLM inference adapter")
 
         # Disable usage stats reporting. This would be a surprising thing for most
@@ -116,15 +58,22 @@ class VLLMInferenceImpl(ModelRegistryHelper, Inference):
         if "VLLM_NO_USAGE_STATS" not in os.environ:
             os.environ["VLLM_NO_USAGE_STATS"] = "1"
 
-        hf_model = self.HF_MODEL_MAPPINGS.get(self.config.model)
+        model = resolve_model(self.config.model)
+        if model is None:
+            raise ValueError(f"Unknown model {self.config.model}")
+
+        if model.huggingface_repo is None:
+            raise ValueError(f"Model {self.config.model} needs a huggingface repo")
 
         # TODO -- there are a ton of options supported here ...
-        engine_args = AsyncEngineArgs()
-        engine_args.model = hf_model
-        # We will need a new config item for this in the future if model support is more broad
-        # than it is today (llama only)
-        engine_args.tokenizer = hf_model
-        engine_args.tensor_parallel_size = self.config.tensor_parallel_size
+        engine_args = AsyncEngineArgs(
+            model=model.huggingface_repo,
+            tokenizer=model.huggingface_repo,
+            tensor_parallel_size=self.config.tensor_parallel_size,
+            enforce_eager=self.config.enforce_eager,
+            gpu_memory_utilization=self.config.gpu_memory_utilization,
+            guided_decoding_backend="lm-format-enforcer",
+        )
 
         self.engine = AsyncLLMEngine.from_engine_args(engine_args)
 
@@ -134,13 +83,47 @@ class VLLMInferenceImpl(ModelRegistryHelper, Inference):
         if self.engine:
             self.engine.shutdown_background_loop()
 
+    async def register_model(self, model: ModelDef) -> None:
+        raise ValueError(
+            "You cannot dynamically add a model to a running vllm instance"
+        )
+
+    async def list_models(self) -> List[ModelDef]:
+        return [
+            ModelDef(
+                identifier=self.config.model,
+                llama_model=self.config.model,
+            )
+        ]
+
+    def _sampling_params(self, sampling_params: SamplingParams) -> VLLMSamplingParams:
+        if sampling_params is None:
+            return VLLMSamplingParams(max_tokens=self.config.max_tokens)
+
+        # TODO convert what I saw in my first test ... but surely there's more to do here
+        kwargs = {
+            "temperature": sampling_params.temperature,
+            "max_tokens": self.config.max_tokens,
+        }
+        if sampling_params.top_k:
+            kwargs["top_k"] = sampling_params.top_k
+        if sampling_params.top_p:
+            kwargs["top_p"] = sampling_params.top_p
+        if sampling_params.max_tokens:
+            kwargs["max_tokens"] = sampling_params.max_tokens
+        if sampling_params.repetition_penalty > 0:
+            kwargs["repetition_penalty"] = sampling_params.repetition_penalty
+
+        return VLLMSamplingParams(**kwargs)
+
     async def completion(
         self,
         model: str,
         content: InterleavedTextMedia,
-        sampling_params: Any | None = ...,
-        stream: bool | None = False,
-        logprobs: LogProbConfig | None = None,
+        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,
+        stream: Optional[bool] = False,
+        logprobs: Optional[LogProbConfig] = None,
     ) -> CompletionResponse | CompletionResponseStreamChunk:
         log.info("vLLM completion")
         messages = [UserMessage(content=content)]
@@ -155,13 +138,14 @@ class VLLMInferenceImpl(ModelRegistryHelper, Inference):
     async def chat_completion(
         self,
         model: str,
-        messages: list[Message],
-        sampling_params: Any | None = ...,
-        tools: list[ToolDefinition] | None = ...,
-        tool_choice: ToolChoice | None = ...,
-        tool_prompt_format: ToolPromptFormat | None = ...,
-        stream: bool | None = False,
-        logprobs: LogProbConfig | None = None,
+        messages: List[Message],
+        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        tools: Optional[List[ToolDefinition]] = None,
+        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
+        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        response_format: Optional[ResponseFormat] = None,
+        stream: Optional[bool] = False,
+        logprobs: Optional[LogProbConfig] = None,
     ) -> ChatCompletionResponse | ChatCompletionResponseStreamChunk:
         log.info("vLLM chat completion")
 
@@ -182,7 +166,7 @@ class VLLMInferenceImpl(ModelRegistryHelper, Inference):
         request_id = _random_uuid()
 
         prompt = chat_completion_request_to_prompt(request, self.formatter)
-        vllm_sampling_params = _vllm_sampling_params(request.sampling_params)
+        vllm_sampling_params = self._sampling_params(request.sampling_params)
         results_generator = self.engine.generate(
             prompt, vllm_sampling_params, request_id
         )
@@ -213,14 +197,19 @@ class VLLMInferenceImpl(ModelRegistryHelper, Inference):
         self,
         request: ChatCompletionRequest,
         results_generator: AsyncGenerator,
     ) -> AsyncGenerator:
         async def _generate_and_convert_to_openai_compat():
+            cur = []
             async for chunk in results_generator:
                 if not chunk.outputs:
                     log.warning("Empty chunk received")
                     continue
 
-                text = "".join([output.text for output in chunk.outputs])
+                output = chunk.outputs[-1]
+
+                new_tokens = output.token_ids[len(cur) :]
+                text = self.formatter.tokenizer.decode(new_tokens)
+                cur.extend(new_tokens)
                 choice = OpenAICompatCompletionChoice(
-                    finish_reason=chunk.outputs[-1].stop_reason,
+                    finish_reason=output.finish_reason,
                     text=text,
                 )
                 yield OpenAICompatCompletionResponse(