move model_aliases, some fixes for template

This commit is contained in:
Vladislav 2025-02-20 15:06:55 +01:00
parent 7bb7597c00
commit 19f3b23d47
4 changed files with 69 additions and 41 deletions

View file

@ -158,6 +158,7 @@
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"redis",
"requests",

View file

@ -29,17 +29,10 @@ from llama_stack.apis.inference import (
ToolConfig,
)
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.models.llama.datatypes import (
SamplingParams,
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.models.llama.sku_list import CoreModelId
from llama_stack.models.llama.datatypes import SamplingParams, ToolDefinition, ToolPromptFormat
from llama_stack.providers.remote.inference.groq.config import GroqConfig
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
build_hf_repo_model_entry,
build_model_entry,
)
from .groq_utils import (
@ -47,33 +40,7 @@ from .groq_utils import (
convert_chat_completion_response,
convert_chat_completion_response_stream,
)
_MODEL_ENTRIES = [
build_hf_repo_model_entry(
"llama3-8b-8192",
CoreModelId.llama3_1_8b_instruct.value,
),
build_model_entry(
"llama-3.1-8b-instant",
CoreModelId.llama3_1_8b_instruct.value,
),
build_hf_repo_model_entry(
"llama3-70b-8192",
CoreModelId.llama3_70b_instruct.value,
),
build_hf_repo_model_entry(
"llama-3.3-70b-versatile",
CoreModelId.llama3_3_70b_instruct.value,
),
# Groq only contains a preview version for llama-3.2-3b
# Preview models aren't recommended for production use, but we include this one
# to pass the test fixture
# TODO(aidand): Replace this with a stable model once Groq supports it
build_hf_repo_model_entry(
"llama-3.2-3b-preview",
CoreModelId.llama3_2_3b_instruct.value,
),
]
from .models import _MODEL_ENTRIES
class GroqInferenceAdapter(Inference, ModelRegistryHelper, NeedsRequestProviderData):

View file

@ -0,0 +1,38 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.models.llama.sku_list import CoreModelId
from llama_stack.providers.utils.inference.model_registry import (
build_model_alias,
build_model_alias_with_just_provider_model_id,
)
_MODEL_ALIASES = [
build_model_alias(
"llama3-8b-8192",
CoreModelId.llama3_1_8b_instruct.value,
),
build_model_alias_with_just_provider_model_id(
"llama-3.1-8b-instant",
CoreModelId.llama3_1_8b_instruct.value,
),
build_model_alias(
"llama3-70b-8192",
CoreModelId.llama3_70b_instruct.value,
),
build_model_alias(
"llama-3.3-70b-versatile",
CoreModelId.llama3_3_70b_instruct.value,
),
# Groq only contains a preview version for llama-3.2-3b
# Preview models aren't recommended for production use, but we include this one
# to pass the test fixture
# TODO(aidand): Replace this with a stable model once Groq supports it
build_model_alias(
"llama-3.2-3b-preview",
CoreModelId.llama3_2_3b_instruct.value,
),
]

View file

@ -6,22 +6,26 @@
from pathlib import Path
from llama_stack.apis.models.models import ModelType
from llama_stack.distribution.datatypes import (
ModelInput,
Provider,
ShieldInput,
ToolGroupInput,
)
from llama_stack.models.llama.sku_list import all_registered_models
from llama_stack.providers.inline.inference.sentence_transformers import (
SentenceTransformersInferenceConfig,
)
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
from llama_stack.providers.remote.inference.groq import GroqConfig
from llama_stack.providers.remote.inference.groq.groq import _MODEL_ALIASES
from llama_stack.providers.remote.inference.groq.models import _MODEL_ALIASES
from llama_stack.templates.template import DistributionTemplate, RunConfigSettings
def get_distribution_template() -> DistributionTemplate:
providers = {
"inference": ["remote::groq"],
"vector_io": ["inline::faiss", "remote::chromadb", "remote::pgvector"],
"vector_io": ["inline::faiss"],
"safety": ["inline::llama-guard"],
"agents": ["inline::meta-reference"],
"telemetry": ["inline::meta-reference"],
@ -43,6 +47,25 @@ def get_distribution_template() -> DistributionTemplate:
config=GroqConfig.sample_run_config(),
)
embedding_provider = Provider(
provider_id="sentence-transformers",
provider_type="inline::sentence-transformers",
config=SentenceTransformersInferenceConfig.sample_run_config(),
)
vector_io_provider = Provider(
provider_id="faiss",
provider_type="inline::faiss",
config=FaissVectorIOConfig.sample_run_config(f"distributions/{name}"),
)
embedding_model = ModelInput(
model_id="all-MiniLM-L6-v2",
provider_id="sentence-transformers",
model_type=ModelType.embedding,
metadata={
"embedding_dimension": 384,
},
)
core_model_to_hf_repo = {m.descriptor(): m.huggingface_repo for m in all_registered_models()}
default_models = [
ModelInput(
@ -79,10 +102,9 @@ def get_distribution_template() -> DistributionTemplate:
run_configs={
"run.yaml": RunConfigSettings(
provider_overrides={
"inference": [inference_provider],
"inference": [inference_provider, embedding_provider],
},
default_models=default_models,
default_shields=[ShieldInput(shield_id="meta-llama/Llama-Guard-3-8B")],
default_models=default_models + [embedding_model],
default_tool_groups=default_tool_groups,
),
},