From 9436dd570db4dd3f244707e169d65d570d152e1c Mon Sep 17 00:00:00 2001
From: Ashwin Bharambe
Date: Thu, 20 Feb 2025 15:39:08 -0800
Subject: [PATCH] feat: register embedding models for ollama, together,
 fireworks (#1190)

# What does this PR do?

We have support for embeddings in our Inference providers, but so far we
haven't done the final step of actually registering the known embedding
models and making sure they are extremely easy to use. This is one step
towards that.

## Test Plan

Run existing inference tests.

```bash
$ cd llama_stack/providers/tests/inference
$ pytest -s -v -k fireworks test_embeddings.py \
   --inference-model nomic-ai/nomic-embed-text-v1.5 --env EMBEDDING_DIMENSION=784

$ pytest -s -v -k together test_embeddings.py \
   --inference-model togethercomputer/m2-bert-80M-8k-retrieval --env EMBEDDING_DIMENSION=784

$ pytest -s -v -k ollama test_embeddings.py \
   --inference-model all-minilm:latest --env EMBEDDING_DIMENSION=784
```

The value of EMBEDDING_DIMENSION isn't actually used in these tests; it is
merely used by the test fixtures to determine whether the model is an LLM or
an embedding model.
---
 .pre-commit-config.yaml                       |   1 +
 .../self_hosted_distro/fireworks.md           |   1 +
 .../self_hosted_distro/together.md            |   2 +
 .../remote/inference/fireworks/models.py      |  10 ++
 .../remote/inference/ollama/models.py         | 103 ++++++++++++++++++
 .../remote/inference/ollama/ollama.py         | 101 ++---------------
 .../remote/inference/together/models.py       |  18 +++
 .../providers/tests/inference/fixtures.py     |   4 +-
 .../utils/inference/model_registry.py         |  20 ++--
 llama_stack/templates/fireworks/fireworks.py  |   4 +-
 .../templates/fireworks/run-with-safety.yaml  |   7 ++
 llama_stack/templates/fireworks/run.yaml      |   7 ++
 llama_stack/templates/ollama/ollama.py        |   3 +-
 .../templates/ollama/run-with-safety.yaml     |   3 +-
 llama_stack/templates/ollama/run.yaml         |   3 +-
 .../templates/together/run-with-safety.yaml   |  14 +++
 llama_stack/templates/together/run.yaml       |  14 +++
 llama_stack/templates/together/together.py    |   4 +-
 18 files changed, 214 insertions(+), 105 deletions(-)
 create mode 100644 llama_stack/providers/remote/inference/ollama/models.py

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 8c5510b27..56e35aa6e 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -88,6 +88,7 @@ repos:
         pass_filenames: false
         require_serial: true
         files: ^llama_stack/templates/.*$
+        files: ^llama_stack/providers/.*/inference/.*/models\.py$
 
 ci:
     autofix_commit_msg: 🎨 [pre-commit.ci] Auto format from pre-commit.com hooks
diff --git a/docs/source/distributions/self_hosted_distro/fireworks.md b/docs/source/distributions/self_hosted_distro/fireworks.md
index f77d9f656..7951e148e 100644
--- a/docs/source/distributions/self_hosted_distro/fireworks.md
+++ b/docs/source/distributions/self_hosted_distro/fireworks.md
@@ -47,6 +47,7 @@ The following models are available by default:
 - `meta-llama/Llama-3.3-70B-Instruct (accounts/fireworks/models/llama-v3p3-70b-instruct)`
 - `meta-llama/Llama-Guard-3-8B (accounts/fireworks/models/llama-guard-3-8b)`
 - `meta-llama/Llama-Guard-3-11B-Vision (accounts/fireworks/models/llama-guard-3-11b-vision)`
+- `nomic-ai/nomic-embed-text-v1.5 (nomic-ai/nomic-embed-text-v1.5)`
 
 ### Prerequisite: API Keys
 
diff --git a/docs/source/distributions/self_hosted_distro/together.md b/docs/source/distributions/self_hosted_distro/together.md
index 8e36c1eb0..936ae58f5 100644
--- a/docs/source/distributions/self_hosted_distro/together.md
+++ b/docs/source/distributions/self_hosted_distro/together.md
@@ -46,6 +46,8 @@ The following models are available by default:
 - `meta-llama/Llama-3.3-70B-Instruct`
 - `meta-llama/Llama-Guard-3-8B`
 - `meta-llama/Llama-Guard-3-11B-Vision`
+- `togethercomputer/m2-bert-80M-8k-retrieval`
+- `togethercomputer/m2-bert-80M-32k-retrieval`
 
 ### Prerequisite: API Keys
 
diff --git a/llama_stack/providers/remote/inference/fireworks/models.py b/llama_stack/providers/remote/inference/fireworks/models.py
index b44f89853..e71979eae 100644
--- a/llama_stack/providers/remote/inference/fireworks/models.py
+++ b/llama_stack/providers/remote/inference/fireworks/models.py
@@ -4,8 +4,10 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+from llama_stack.apis.models.models import ModelType
 from llama_stack.models.llama.datatypes import CoreModelId
 from llama_stack.providers.utils.inference.model_registry import (
+    ProviderModelEntry,
     build_hf_repo_model_entry,
 )
 
@@ -50,4 +52,12 @@ MODEL_ENTRIES = [
         "accounts/fireworks/models/llama-guard-3-11b-vision",
         CoreModelId.llama_guard_3_11b_vision.value,
     ),
+    ProviderModelEntry(
+        provider_model_id="nomic-ai/nomic-embed-text-v1.5",
+        model_type=ModelType.embedding,
+        metadata={
+            "embedding_dimension": 768,
+            "context_length": 8192,
+        },
+    ),
 ]
diff --git a/llama_stack/providers/remote/inference/ollama/models.py b/llama_stack/providers/remote/inference/ollama/models.py
new file mode 100644
index 000000000..e0bf269db
--- /dev/null
+++ b/llama_stack/providers/remote/inference/ollama/models.py
@@ -0,0 +1,103 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from llama_stack.apis.models.models import ModelType
+from llama_stack.models.llama.datatypes import CoreModelId
+from llama_stack.providers.utils.inference.model_registry import (
+    ProviderModelEntry,
+    build_hf_repo_model_entry,
+    build_model_entry,
+)
+
+model_entries = [
+    build_hf_repo_model_entry(
+        "llama3.1:8b-instruct-fp16",
+        CoreModelId.llama3_1_8b_instruct.value,
+    ),
+    build_model_entry(
+        "llama3.1:8b",
+        CoreModelId.llama3_1_8b_instruct.value,
+    ),
+    build_hf_repo_model_entry(
+        "llama3.1:70b-instruct-fp16",
+        CoreModelId.llama3_1_70b_instruct.value,
+    ),
+    build_model_entry(
+        "llama3.1:70b",
+        CoreModelId.llama3_1_70b_instruct.value,
+    ),
+    build_hf_repo_model_entry(
+        "llama3.1:405b-instruct-fp16",
+        CoreModelId.llama3_1_405b_instruct.value,
+    ),
+    build_model_entry(
+        "llama3.1:405b",
+        CoreModelId.llama3_1_405b_instruct.value,
+    ),
+    build_hf_repo_model_entry(
+        "llama3.2:1b-instruct-fp16",
+        CoreModelId.llama3_2_1b_instruct.value,
+    ),
+    build_model_entry(
+        "llama3.2:1b",
+        CoreModelId.llama3_2_1b_instruct.value,
+    ),
+    build_hf_repo_model_entry(
+        "llama3.2:3b-instruct-fp16",
+        CoreModelId.llama3_2_3b_instruct.value,
+    ),
+    build_model_entry(
+        "llama3.2:3b",
+        CoreModelId.llama3_2_3b_instruct.value,
+    ),
+    build_hf_repo_model_entry(
+        "llama3.2-vision:11b-instruct-fp16",
+        CoreModelId.llama3_2_11b_vision_instruct.value,
+    ),
+    build_model_entry(
+        "llama3.2-vision:latest",
+        CoreModelId.llama3_2_11b_vision_instruct.value,
+    ),
+    build_hf_repo_model_entry(
+        "llama3.2-vision:90b-instruct-fp16",
+        CoreModelId.llama3_2_90b_vision_instruct.value,
+    ),
+    build_model_entry(
+        "llama3.2-vision:90b",
+        CoreModelId.llama3_2_90b_vision_instruct.value,
+    ),
+    build_hf_repo_model_entry(
+        "llama3.3:70b",
+        CoreModelId.llama3_3_70b_instruct.value,
+    ),
+    # The Llama Guard models don't have their full fp16 versions
+    # so we are going to alias their default version to the canonical SKU
+    build_hf_repo_model_entry(
+        "llama-guard3:8b",
+        CoreModelId.llama_guard_3_8b.value,
+    ),
+    build_hf_repo_model_entry(
+        "llama-guard3:1b",
+        CoreModelId.llama_guard_3_1b.value,
+    ),
+    ProviderModelEntry(
+        provider_model_id="all-minilm:latest",
+        aliases=["all-minilm"],
+        model_type=ModelType.embedding,
+        metadata={
+            "embedding_dimension": 384,
+            "context_length": 512,
+        },
+    ),
+    ProviderModelEntry(
+        provider_model_id="nomic-embed-text",
+        model_type=ModelType.embedding,
+        metadata={
+            "embedding_dimension": 768,
+            "context_length": 8192,
+        },
+    ),
+]
diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py
index e16c02003..1dbcbc294 100644
--- a/llama_stack/providers/remote/inference/ollama/ollama.py
+++ b/llama_stack/providers/remote/inference/ollama/ollama.py
@@ -31,12 +31,9 @@ from llama_stack.apis.inference import (
     ToolPromptFormat,
 )
 from llama_stack.apis.models import Model, ModelType
-from llama_stack.models.llama.datatypes import CoreModelId
 from llama_stack.providers.datatypes import ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
-    build_hf_repo_model_entry,
-    build_model_entry,
 )
 from llama_stack.providers.utils.inference.openai_compat import (
     OpenAICompatCompletionChoice,
@@ -56,80 +53,9 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
     request_has_media,
 )
 
-log = logging.getLogger(__name__)
+from .models import model_entries
 
-model_entries = [
-    build_hf_repo_model_entry(
-        "llama3.1:8b-instruct-fp16",
-        CoreModelId.llama3_1_8b_instruct.value,
-    ),
-    build_model_entry(
-        "llama3.1:8b",
-        CoreModelId.llama3_1_8b_instruct.value,
-    ),
-    build_hf_repo_model_entry(
-        "llama3.1:70b-instruct-fp16",
-        CoreModelId.llama3_1_70b_instruct.value,
-    ),
-    build_model_entry(
-        "llama3.1:70b",
-        CoreModelId.llama3_1_70b_instruct.value,
-    ),
-    build_hf_repo_model_entry(
-        "llama3.1:405b-instruct-fp16",
-        CoreModelId.llama3_1_405b_instruct.value,
-    ),
-    build_model_entry(
-        "llama3.1:405b",
-        CoreModelId.llama3_1_405b_instruct.value,
-    ),
-    build_hf_repo_model_entry(
-        "llama3.2:1b-instruct-fp16",
-        CoreModelId.llama3_2_1b_instruct.value,
-    ),
-    build_model_entry(
-        "llama3.2:1b",
-        CoreModelId.llama3_2_1b_instruct.value,
-    ),
-    build_hf_repo_model_entry(
-        "llama3.2:3b-instruct-fp16",
-        CoreModelId.llama3_2_3b_instruct.value,
-    ),
-    build_model_entry(
-        "llama3.2:3b",
-        CoreModelId.llama3_2_3b_instruct.value,
-    ),
-    build_hf_repo_model_entry(
-        "llama3.2-vision:11b-instruct-fp16",
-        CoreModelId.llama3_2_11b_vision_instruct.value,
-    ),
-    build_model_entry(
-        "llama3.2-vision:latest",
-        CoreModelId.llama3_2_11b_vision_instruct.value,
-    ),
-    build_hf_repo_model_entry(
-        "llama3.2-vision:90b-instruct-fp16",
-        CoreModelId.llama3_2_90b_vision_instruct.value,
-    ),
-    build_model_entry(
-        "llama3.2-vision:90b",
-        CoreModelId.llama3_2_90b_vision_instruct.value,
-    ),
-    build_hf_repo_model_entry(
-        "llama3.3:70b",
-        CoreModelId.llama3_3_70b_instruct.value,
-    ),
-    # The Llama Guard models don't have their full fp16 versions
-    # so we are going to alias their default version to the canonical SKU
-    build_hf_repo_model_entry(
-        "llama-guard3:8b",
-        CoreModelId.llama_guard_3_8b.value,
-    ),
-    build_hf_repo_model_entry(
-        "llama-guard3:1b",
-        CoreModelId.llama_guard_3_1b.value,
-    ),
-]
+log = logging.getLogger(__name__)
 
 
 class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
@@ -348,22 +274,17 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         return EmbeddingsResponse(embeddings=embeddings)
 
     async def register_model(self, model: Model) -> Model:
-        async def check_model_availability(model_id: str):
-            response = await self.client.ps()
-            available_models = [m["model"] for m in response["models"]]
-            if model_id not in available_models:
-                raise ValueError(
-                    f"Model '{model_id}' is not available in Ollama. Available models: {', '.join(available_models)}"
-                )
-
         if model.model_type == ModelType.embedding:
-            await check_model_availability(model.provider_resource_id)
-            return model
+            response = await self.client.list()
+        else:
+            response = await self.client.ps()
+        available_models = [m["model"] for m in response["models"]]
+        if model.provider_resource_id not in available_models:
+            raise ValueError(
+                f"Model '{model.provider_resource_id}' is not available in Ollama. Available models: {', '.join(available_models)}"
+            )
 
-        model = await self.register_helper.register_model(model)
-        await check_model_availability(model.provider_resource_id)
-
-        return model
+        return await self.register_helper.register_model(model)
 
 
 async def convert_message_to_openai_dict_for_ollama(message: Message) -> List[dict]:
diff --git a/llama_stack/providers/remote/inference/together/models.py b/llama_stack/providers/remote/inference/together/models.py
index 90fb60508..6ee31fa78 100644
--- a/llama_stack/providers/remote/inference/together/models.py
+++ b/llama_stack/providers/remote/inference/together/models.py
@@ -4,8 +4,10 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+from llama_stack.apis.models.models import ModelType
 from llama_stack.models.llama.datatypes import CoreModelId
 from llama_stack.providers.utils.inference.model_registry import (
+    ProviderModelEntry,
     build_hf_repo_model_entry,
 )
 
@@ -46,4 +48,20 @@ MODEL_ENTRIES = [
     build_hf_repo_model_entry(
         "meta-llama/Llama-Guard-3-11B-Vision-Turbo",
         CoreModelId.llama_guard_3_11b_vision.value,
     ),
+    ProviderModelEntry(
+        provider_model_id="togethercomputer/m2-bert-80M-8k-retrieval",
+        model_type=ModelType.embedding,
+        metadata={
+            "embedding_dimension": 768,
+            "context_length": 8192,
+        },
+    ),
+    ProviderModelEntry(
+        provider_model_id="togethercomputer/m2-bert-80M-32k-retrieval",
+        model_type=ModelType.embedding,
+        metadata={
+            "embedding_dimension": 768,
+            "context_length": 32768,
+        },
+    ),
 ]
diff --git a/llama_stack/providers/tests/inference/fixtures.py b/llama_stack/providers/tests/inference/fixtures.py
index ec4e094c9..b553b6b02 100644
--- a/llama_stack/providers/tests/inference/fixtures.py
+++ b/llama_stack/providers/tests/inference/fixtures.py
@@ -20,7 +20,7 @@ from llama_stack.providers.remote.inference.cerebras import CerebrasImplConfig
 from llama_stack.providers.remote.inference.fireworks import FireworksImplConfig
 from llama_stack.providers.remote.inference.groq import GroqConfig
 from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig
-from llama_stack.providers.remote.inference.ollama import OllamaImplConfig
+from llama_stack.providers.remote.inference.ollama import DEFAULT_OLLAMA_URL, OllamaImplConfig
 from llama_stack.providers.remote.inference.sambanova import SambaNovaImplConfig
 from llama_stack.providers.remote.inference.tgi import TGIImplConfig
 from llama_stack.providers.remote.inference.together import TogetherImplConfig
@@ -89,7 +89,7 @@ def inference_ollama() -> ProviderFixture:
         Provider(
             provider_id="ollama",
             provider_type="remote::ollama",
-            config=OllamaImplConfig(url=get_env_or_fail("OLLAMA_URL")).model_dump(),
+            config=OllamaImplConfig(url=os.getenv("OLLAMA_URL", DEFAULT_OLLAMA_URL)).model_dump(),
         )
     ],
 )
diff --git a/llama_stack/providers/utils/inference/model_registry.py b/llama_stack/providers/utils/inference/model_registry.py
index 288f27449..0882019e3 100644
--- a/llama_stack/providers/utils/inference/model_registry.py
+++ b/llama_stack/providers/utils/inference/model_registry.py
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import List, Optional
+from typing import Any, Dict, List, Optional
 
 from pydantic import BaseModel, Field
 
@@ -23,6 +23,7 @@ class ProviderModelEntry(BaseModel):
     aliases: List[str] = Field(default_factory=list)
     llama_model: Optional[str] = None
     model_type: ModelType = ModelType.llm
+    metadata: Dict[str, Any] = Field(default_factory=dict)
 
 
 def get_huggingface_repo(model_descriptor: str) -> Optional[str]:
@@ -47,6 +48,7 @@ def build_model_entry(provider_model_id: str, model_descriptor: str) -> Provider
         provider_model_id=provider_model_id,
         aliases=[],
         llama_model=model_descriptor,
+        model_type=ModelType.llm,
     )
 
 
@@ -54,14 +56,16 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
     def __init__(self, model_entries: List[ProviderModelEntry]):
         self.alias_to_provider_id_map = {}
         self.provider_id_to_llama_model_map = {}
-        for alias_obj in model_entries:
-            for alias in alias_obj.aliases:
-                self.alias_to_provider_id_map[alias] = alias_obj.provider_model_id
+        for entry in model_entries:
+            for alias in entry.aliases:
+                self.alias_to_provider_id_map[alias] = entry.provider_model_id
+
             # also add a mapping from provider model id to itself for easy lookup
-            self.alias_to_provider_id_map[alias_obj.provider_model_id] = alias_obj.provider_model_id
-            # ensure we can go from llama model to provider model id
-            self.alias_to_provider_id_map[alias_obj.llama_model] = alias_obj.provider_model_id
-            self.provider_id_to_llama_model_map[alias_obj.provider_model_id] = alias_obj.llama_model
+            self.alias_to_provider_id_map[entry.provider_model_id] = entry.provider_model_id
+
+            if entry.llama_model:
+                self.alias_to_provider_id_map[entry.llama_model] = entry.provider_model_id
+                self.provider_id_to_llama_model_map[entry.provider_model_id] = entry.llama_model
 
     def get_provider_model_id(self, identifier: str) -> Optional[str]:
         return self.alias_to_provider_id_map.get(identifier, None)
diff --git a/llama_stack/templates/fireworks/fireworks.py b/llama_stack/templates/fireworks/fireworks.py
index 5cde01e81..06b851551 100644
--- a/llama_stack/templates/fireworks/fireworks.py
+++ b/llama_stack/templates/fireworks/fireworks.py
@@ -63,9 +63,11 @@ def get_distribution_template() -> DistributionTemplate:
     core_model_to_hf_repo = {m.descriptor(): m.huggingface_repo for m in all_registered_models()}
     default_models = [
         ModelInput(
-            model_id=core_model_to_hf_repo[m.llama_model],
+            model_id=core_model_to_hf_repo[m.llama_model] if m.llama_model else m.provider_model_id,
             provider_model_id=m.provider_model_id,
             provider_id="fireworks",
+            metadata=m.metadata,
+            model_type=m.model_type,
         )
         for m in MODEL_ENTRIES
     ]
diff --git a/llama_stack/templates/fireworks/run-with-safety.yaml b/llama_stack/templates/fireworks/run-with-safety.yaml
index 8f95e9d59..1ed5540d8 100644
--- a/llama_stack/templates/fireworks/run-with-safety.yaml
+++ b/llama_stack/templates/fireworks/run-with-safety.yaml
@@ -149,6 +149,13 @@ models:
   provider_id: fireworks
   provider_model_id: accounts/fireworks/models/llama-guard-3-11b-vision
   model_type: llm
+- metadata:
+    embedding_dimension: 768
+    context_length: 8192
+  model_id: nomic-ai/nomic-embed-text-v1.5
+  provider_id: fireworks
+  provider_model_id: nomic-ai/nomic-embed-text-v1.5
+  model_type: embedding
 - metadata:
     embedding_dimension: 384
   model_id: all-MiniLM-L6-v2
diff --git a/llama_stack/templates/fireworks/run.yaml b/llama_stack/templates/fireworks/run.yaml
index 64229a5d8..04d55eba8 100644
--- a/llama_stack/templates/fireworks/run.yaml
+++ b/llama_stack/templates/fireworks/run.yaml
@@ -143,6 +143,13 @@ models:
   provider_id: fireworks
   provider_model_id: accounts/fireworks/models/llama-guard-3-11b-vision
   model_type: llm
+- metadata:
+    embedding_dimension: 768
+    context_length: 8192
+  model_id: nomic-ai/nomic-embed-text-v1.5
+  provider_id: fireworks
+  provider_model_id: nomic-ai/nomic-embed-text-v1.5
+  model_type: embedding
 - metadata:
     embedding_dimension: 384
   model_id: all-MiniLM-L6-v2
diff --git a/llama_stack/templates/ollama/ollama.py b/llama_stack/templates/ollama/ollama.py
index f3383cd5a..31119e040 100644
--- a/llama_stack/templates/ollama/ollama.py
+++ b/llama_stack/templates/ollama/ollama.py
@@ -71,7 +71,8 @@ def get_distribution_template() -> DistributionTemplate:
     )
     embedding_model = ModelInput(
         model_id="all-MiniLM-L6-v2",
-        provider_id="sentence-transformers",
+        provider_id="ollama",
+        provider_model_id="all-minilm:latest",
         model_type=ModelType.embedding,
         metadata={
             "embedding_dimension": 384,
diff --git a/llama_stack/templates/ollama/run-with-safety.yaml b/llama_stack/templates/ollama/run-with-safety.yaml
index 4ce64cf59..7cf527c04 100644
--- a/llama_stack/templates/ollama/run-with-safety.yaml
+++ b/llama_stack/templates/ollama/run-with-safety.yaml
@@ -110,7 +110,8 @@ models:
 - metadata:
     embedding_dimension: 384
   model_id: all-MiniLM-L6-v2
-  provider_id: sentence-transformers
+  provider_id: ollama
+  provider_model_id: all-minilm:latest
   model_type: embedding
 shields:
 - shield_id: ${env.SAFETY_MODEL}
diff --git a/llama_stack/templates/ollama/run.yaml b/llama_stack/templates/ollama/run.yaml
index b4982f8e2..ab292c5e0 100644
--- a/llama_stack/templates/ollama/run.yaml
+++ b/llama_stack/templates/ollama/run.yaml
@@ -103,7 +103,8 @@ models:
 - metadata:
     embedding_dimension: 384
   model_id: all-MiniLM-L6-v2
-  provider_id: sentence-transformers
+  provider_id: ollama
+  provider_model_id: all-minilm:latest
   model_type: embedding
 shields: []
 vector_dbs: []
diff --git a/llama_stack/templates/together/run-with-safety.yaml b/llama_stack/templates/together/run-with-safety.yaml
index f101a5d60..837709579 100644
--- a/llama_stack/templates/together/run-with-safety.yaml
+++ b/llama_stack/templates/together/run-with-safety.yaml
@@ -144,6 +144,20 @@ models:
   provider_id: together
   provider_model_id: meta-llama/Llama-Guard-3-11B-Vision-Turbo
   model_type: llm
+- metadata:
+    embedding_dimension: 768
+    context_length: 8192
+  model_id: togethercomputer/m2-bert-80M-8k-retrieval
+  provider_id: together
+  provider_model_id: togethercomputer/m2-bert-80M-8k-retrieval
+  model_type: embedding
+- metadata:
+    embedding_dimension: 768
+    context_length: 32768
+  model_id: togethercomputer/m2-bert-80M-32k-retrieval
+  provider_id: together
+  provider_model_id: togethercomputer/m2-bert-80M-32k-retrieval
+  model_type: embedding
 - metadata:
     embedding_dimension: 384
   model_id: all-MiniLM-L6-v2
diff --git a/llama_stack/templates/together/run.yaml b/llama_stack/templates/together/run.yaml
index 8af85979d..28ff36cff 100644
--- a/llama_stack/templates/together/run.yaml
+++ b/llama_stack/templates/together/run.yaml
@@ -138,6 +138,20 @@ models:
   provider_id: together
   provider_model_id: meta-llama/Llama-Guard-3-11B-Vision-Turbo
   model_type: llm
+- metadata:
+    embedding_dimension: 768
+    context_length: 8192
+  model_id: togethercomputer/m2-bert-80M-8k-retrieval
+  provider_id: together
+  provider_model_id: togethercomputer/m2-bert-80M-8k-retrieval
+  model_type: embedding
+- metadata:
+    embedding_dimension: 768
+    context_length: 32768
+  model_id: togethercomputer/m2-bert-80M-32k-retrieval
+  provider_id: together
+  provider_model_id: togethercomputer/m2-bert-80M-32k-retrieval
+  model_type: embedding
 - metadata:
     embedding_dimension: 384
   model_id: all-MiniLM-L6-v2
diff --git a/llama_stack/templates/together/together.py b/llama_stack/templates/together/together.py
index d46dd9d27..d275b7238 100644
--- a/llama_stack/templates/together/together.py
+++ b/llama_stack/templates/together/together.py
@@ -61,9 +61,11 @@ def get_distribution_template() -> DistributionTemplate:
     core_model_to_hf_repo = {m.descriptor(): m.huggingface_repo for m in all_registered_models()}
     default_models = [
         ModelInput(
-            model_id=core_model_to_hf_repo[m.llama_model],
+            model_id=core_model_to_hf_repo[m.llama_model] if m.llama_model else m.provider_model_id,
             provider_model_id=m.provider_model_id,
             provider_id="together",
+            metadata=m.metadata,
+            model_type=m.model_type,
         )
         for m in MODEL_ENTRIES
     ]
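As a usage sketch (not part of the patch itself): once a distribution built from one of these templates is running, the registered embedding models can be exercised through the standard inference embeddings API, which the tests above also use. This is a minimal sketch, assuming a local stack serving the ollama template on the default port 8321 and the `llama-stack-client` package; the base URL is an assumption, while the model id and the 384-dimension metadata come from the `all-minilm:latest` entry registered above.

```python
# Minimal sketch: query an embedding model registered by this patch.
# Assumes a llama-stack server built from the ollama template is already
# running at http://localhost:8321 (port is an assumption) and that
# `pip install llama-stack-client` has been done.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")

# "all-minilm:latest" is registered above as an embedding model with
# embedding_dimension 384 and context_length 512.
response = client.inference.embeddings(
    model_id="all-minilm:latest",
    contents=["Llama Stack registers embedding models as first-class models."],
)

# EmbeddingsResponse carries one vector per input string; the vector length
# should match the registered embedding dimension.
assert len(response.embeddings) == 1
assert len(response.embeddings[0]) == 384
print(response.embeddings[0][:8])
```

Because the registry helper also maps aliases, the same call should work with the `all-minilm` alias declared in the ollama `model_entries`.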