diff --git a/distributions/dependencies.json b/distributions/dependencies.json index 60ccc10e5..da43c019c 100644 --- a/distributions/dependencies.json +++ b/distributions/dependencies.json @@ -158,6 +158,7 @@ "pandas", "pillow", "psycopg2-binary", + "pymongo", "pypdf", "redis", "requests", diff --git a/llama_stack/providers/remote/inference/groq/groq.py b/llama_stack/providers/remote/inference/groq/groq.py index 45c15a467..2c9fab614 100644 --- a/llama_stack/providers/remote/inference/groq/groq.py +++ b/llama_stack/providers/remote/inference/groq/groq.py @@ -29,17 +29,10 @@ from llama_stack.apis.inference import ( ToolConfig, ) from llama_stack.distribution.request_headers import NeedsRequestProviderData -from llama_stack.models.llama.datatypes import ( - SamplingParams, - ToolDefinition, - ToolPromptFormat, -) -from llama_stack.models.llama.sku_list import CoreModelId +from llama_stack.models.llama.datatypes import SamplingParams, ToolDefinition, ToolPromptFormat from llama_stack.providers.remote.inference.groq.config import GroqConfig from llama_stack.providers.utils.inference.model_registry import ( ModelRegistryHelper, - build_hf_repo_model_entry, - build_model_entry, ) from .groq_utils import ( @@ -47,33 +40,7 @@ from .groq_utils import ( convert_chat_completion_response, convert_chat_completion_response_stream, ) - -_MODEL_ENTRIES = [ - build_hf_repo_model_entry( - "llama3-8b-8192", - CoreModelId.llama3_1_8b_instruct.value, - ), - build_model_entry( - "llama-3.1-8b-instant", - CoreModelId.llama3_1_8b_instruct.value, - ), - build_hf_repo_model_entry( - "llama3-70b-8192", - CoreModelId.llama3_70b_instruct.value, - ), - build_hf_repo_model_entry( - "llama-3.3-70b-versatile", - CoreModelId.llama3_3_70b_instruct.value, - ), - # Groq only contains a preview version for llama-3.2-3b - # Preview models aren't recommended for production use, but we include this one - # to pass the test fixture - # TODO(aidand): Replace this with a stable model once Groq supports it - build_hf_repo_model_entry( - "llama-3.2-3b-preview", - CoreModelId.llama3_2_3b_instruct.value, - ), -] +from .models import _MODEL_ENTRIES class GroqInferenceAdapter(Inference, ModelRegistryHelper, NeedsRequestProviderData): diff --git a/llama_stack/providers/remote/inference/groq/models.py b/llama_stack/providers/remote/inference/groq/models.py new file mode 100644 index 000000000..fd73e42e5 --- /dev/null +++ b/llama_stack/providers/remote/inference/groq/models.py @@ -0,0 +1,38 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from llama_stack.models.llama.sku_list import CoreModelId +from llama_stack.providers.utils.inference.model_registry import ( + build_model_alias, + build_model_alias_with_just_provider_model_id, +) + +_MODEL_ALIASES = [ + build_model_alias( + "llama3-8b-8192", + CoreModelId.llama3_1_8b_instruct.value, + ), + build_model_alias_with_just_provider_model_id( + "llama-3.1-8b-instant", + CoreModelId.llama3_1_8b_instruct.value, + ), + build_model_alias( + "llama3-70b-8192", + CoreModelId.llama3_70b_instruct.value, + ), + build_model_alias( + "llama-3.3-70b-versatile", + CoreModelId.llama3_3_70b_instruct.value, + ), + # Groq only contains a preview version for llama-3.2-3b + # Preview models aren't recommended for production use, but we include this one + # to pass the test fixture + # TODO(aidand): Replace this with a stable model once Groq supports it + build_model_alias( + "llama-3.2-3b-preview", + CoreModelId.llama3_2_3b_instruct.value, + ), +] diff --git a/llama_stack/templates/groq/groq.py b/llama_stack/templates/groq/groq.py index 9a82cb916..b81fc5e78 100644 --- a/llama_stack/templates/groq/groq.py +++ b/llama_stack/templates/groq/groq.py @@ -6,22 +6,26 @@ from pathlib import Path +from llama_stack.apis.models.models import ModelType from llama_stack.distribution.datatypes import ( ModelInput, Provider, - ShieldInput, ToolGroupInput, ) from llama_stack.models.llama.sku_list import all_registered_models +from llama_stack.providers.inline.inference.sentence_transformers import ( + SentenceTransformersInferenceConfig, +) +from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig from llama_stack.providers.remote.inference.groq import GroqConfig -from llama_stack.providers.remote.inference.groq.groq import _MODEL_ALIASES +from llama_stack.providers.remote.inference.groq.models import _MODEL_ALIASES from llama_stack.templates.template import DistributionTemplate, RunConfigSettings def get_distribution_template() -> DistributionTemplate: providers = { "inference": ["remote::groq"], - "vector_io": ["inline::faiss", "remote::chromadb", "remote::pgvector"], + "vector_io": ["inline::faiss"], "safety": ["inline::llama-guard"], "agents": ["inline::meta-reference"], "telemetry": ["inline::meta-reference"], @@ -43,6 +47,25 @@ def get_distribution_template() -> DistributionTemplate: config=GroqConfig.sample_run_config(), ) + embedding_provider = Provider( + provider_id="sentence-transformers", + provider_type="inline::sentence-transformers", + config=SentenceTransformersInferenceConfig.sample_run_config(), + ) + vector_io_provider = Provider( + provider_id="faiss", + provider_type="inline::faiss", + config=FaissVectorIOConfig.sample_run_config(f"distributions/{name}"), + ) + embedding_model = ModelInput( + model_id="all-MiniLM-L6-v2", + provider_id="sentence-transformers", + model_type=ModelType.embedding, + metadata={ + "embedding_dimension": 384, + }, + ) + core_model_to_hf_repo = {m.descriptor(): m.huggingface_repo for m in all_registered_models()} default_models = [ ModelInput( @@ -79,10 +102,9 @@ def get_distribution_template() -> DistributionTemplate: run_configs={ "run.yaml": RunConfigSettings( provider_overrides={ - "inference": [inference_provider], + "inference": [inference_provider, embedding_provider], }, - default_models=default_models, - default_shields=[ShieldInput(shield_id="meta-llama/Llama-Guard-3-8B")], + default_models=default_models + [embedding_model], default_tool_groups=default_tool_groups, ), },