Register embedding models for ollama, together, fireworks

This commit is contained in:
Ashwin Bharambe 2025-02-20 13:15:41 -08:00
parent 736560ceba
commit e337600954
17 changed files with 194 additions and 82 deletions

View file

@ -88,6 +88,7 @@ repos:
pass_filenames: false
require_serial: true
files: ^llama_stack/templates/.*$
files: ^llama_stack/providers/.*/inference/.*/models\.py$
ci:
autofix_commit_msg: 🎨 [pre-commit.ci] Auto format from pre-commit.com hooks

View file

@ -47,6 +47,7 @@ The following models are available by default:
- `meta-llama/Llama-3.3-70B-Instruct (accounts/fireworks/models/llama-v3p3-70b-instruct)`
- `meta-llama/Llama-Guard-3-8B (accounts/fireworks/models/llama-guard-3-8b)`
- `meta-llama/Llama-Guard-3-11B-Vision (accounts/fireworks/models/llama-guard-3-11b-vision)`
- `nomic-ai/nomic-embed-text-v1.5 (nomic-ai/nomic-embed-text-v1.5)`
### Prerequisite: API Keys

View file

@ -46,6 +46,8 @@ The following models are available by default:
- `meta-llama/Llama-3.3-70B-Instruct`
- `meta-llama/Llama-Guard-3-8B`
- `meta-llama/Llama-Guard-3-11B-Vision`
- `togethercomputer/m2-bert-80M-8k-retrieval`
- `togethercomputer/m2-bert-80M-32k-retrieval`
### Prerequisite: API Keys

View file

@ -4,8 +4,10 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.models.models import ModelType
from llama_stack.models.llama.datatypes import CoreModelId
from llama_stack.providers.utils.inference.model_registry import (
ProviderModelEntry,
build_hf_repo_model_entry,
)
@ -50,4 +52,12 @@ MODEL_ENTRIES = [
"accounts/fireworks/models/llama-guard-3-11b-vision",
CoreModelId.llama_guard_3_11b_vision.value,
),
ProviderModelEntry(
provider_model_id="nomic-ai/nomic-embed-text-v1.5",
model_type=ModelType.embedding,
metadata={
"embedding_dimensions": 768,
"context_length": 8192,
},
),
]

View file

@ -0,0 +1,103 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.models.models import ModelType
from llama_stack.models.llama.datatypes import CoreModelId
from llama_stack.providers.utils.inference.model_registry import (
ProviderModelEntry,
build_hf_repo_model_entry,
build_model_entry,
)
model_entries = [
build_hf_repo_model_entry(
"llama3.1:8b-instruct-fp16",
CoreModelId.llama3_1_8b_instruct.value,
),
build_model_entry(
"llama3.1:8b",
CoreModelId.llama3_1_8b_instruct.value,
),
build_hf_repo_model_entry(
"llama3.1:70b-instruct-fp16",
CoreModelId.llama3_1_70b_instruct.value,
),
build_model_entry(
"llama3.1:70b",
CoreModelId.llama3_1_70b_instruct.value,
),
build_hf_repo_model_entry(
"llama3.1:405b-instruct-fp16",
CoreModelId.llama3_1_405b_instruct.value,
),
build_model_entry(
"llama3.1:405b",
CoreModelId.llama3_1_405b_instruct.value,
),
build_hf_repo_model_entry(
"llama3.2:1b-instruct-fp16",
CoreModelId.llama3_2_1b_instruct.value,
),
build_model_entry(
"llama3.2:1b",
CoreModelId.llama3_2_1b_instruct.value,
),
build_hf_repo_model_entry(
"llama3.2:3b-instruct-fp16",
CoreModelId.llama3_2_3b_instruct.value,
),
build_model_entry(
"llama3.2:3b",
CoreModelId.llama3_2_3b_instruct.value,
),
build_hf_repo_model_entry(
"llama3.2-vision:11b-instruct-fp16",
CoreModelId.llama3_2_11b_vision_instruct.value,
),
build_model_entry(
"llama3.2-vision:latest",
CoreModelId.llama3_2_11b_vision_instruct.value,
),
build_hf_repo_model_entry(
"llama3.2-vision:90b-instruct-fp16",
CoreModelId.llama3_2_90b_vision_instruct.value,
),
build_model_entry(
"llama3.2-vision:90b",
CoreModelId.llama3_2_90b_vision_instruct.value,
),
build_hf_repo_model_entry(
"llama3.3:70b",
CoreModelId.llama3_3_70b_instruct.value,
),
# The Llama Guard models don't have their full fp16 versions
# so we are going to alias their default version to the canonical SKU
build_hf_repo_model_entry(
"llama-guard3:8b",
CoreModelId.llama_guard_3_8b.value,
),
build_hf_repo_model_entry(
"llama-guard3:1b",
CoreModelId.llama_guard_3_1b.value,
),
] + [
ProviderModelEntry(
provider_model_id="all-minilm",
model_type=ModelType.embedding,
metadata={
"embedding_dimensions": 384,
"context_length": 512,
},
),
ProviderModelEntry(
provider_model_id="nomic-embed-text",
model_type=ModelType.embedding,
metadata={
"embedding_dimensions": 768,
"context_length": 8192,
},
),
]

View file

@ -31,12 +31,9 @@ from llama_stack.apis.inference import (
ToolPromptFormat,
)
from llama_stack.apis.models import Model, ModelType
from llama_stack.models.llama.datatypes import CoreModelId
from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
build_hf_repo_model_entry,
build_model_entry,
)
from llama_stack.providers.utils.inference.openai_compat import (
OpenAICompatCompletionChoice,
@ -56,80 +53,9 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
request_has_media,
)
log = logging.getLogger(__name__)
from .models import model_entries
model_entries = [
build_hf_repo_model_entry(
"llama3.1:8b-instruct-fp16",
CoreModelId.llama3_1_8b_instruct.value,
),
build_model_entry(
"llama3.1:8b",
CoreModelId.llama3_1_8b_instruct.value,
),
build_hf_repo_model_entry(
"llama3.1:70b-instruct-fp16",
CoreModelId.llama3_1_70b_instruct.value,
),
build_model_entry(
"llama3.1:70b",
CoreModelId.llama3_1_70b_instruct.value,
),
build_hf_repo_model_entry(
"llama3.1:405b-instruct-fp16",
CoreModelId.llama3_1_405b_instruct.value,
),
build_model_entry(
"llama3.1:405b",
CoreModelId.llama3_1_405b_instruct.value,
),
build_hf_repo_model_entry(
"llama3.2:1b-instruct-fp16",
CoreModelId.llama3_2_1b_instruct.value,
),
build_model_entry(
"llama3.2:1b",
CoreModelId.llama3_2_1b_instruct.value,
),
build_hf_repo_model_entry(
"llama3.2:3b-instruct-fp16",
CoreModelId.llama3_2_3b_instruct.value,
),
build_model_entry(
"llama3.2:3b",
CoreModelId.llama3_2_3b_instruct.value,
),
build_hf_repo_model_entry(
"llama3.2-vision:11b-instruct-fp16",
CoreModelId.llama3_2_11b_vision_instruct.value,
),
build_model_entry(
"llama3.2-vision:latest",
CoreModelId.llama3_2_11b_vision_instruct.value,
),
build_hf_repo_model_entry(
"llama3.2-vision:90b-instruct-fp16",
CoreModelId.llama3_2_90b_vision_instruct.value,
),
build_model_entry(
"llama3.2-vision:90b",
CoreModelId.llama3_2_90b_vision_instruct.value,
),
build_hf_repo_model_entry(
"llama3.3:70b",
CoreModelId.llama3_3_70b_instruct.value,
),
# The Llama Guard models don't have their full fp16 versions
# so we are going to alias their default version to the canonical SKU
build_hf_repo_model_entry(
"llama-guard3:8b",
CoreModelId.llama_guard_3_8b.value,
),
build_hf_repo_model_entry(
"llama-guard3:1b",
CoreModelId.llama_guard_3_1b.value,
),
]
log = logging.getLogger(__name__)
class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):

View file

@ -4,8 +4,10 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.models.models import ModelType
from llama_stack.models.llama.datatypes import CoreModelId
from llama_stack.providers.utils.inference.model_registry import (
ProviderModelEntry,
build_hf_repo_model_entry,
)
@ -46,4 +48,20 @@ MODEL_ENTRIES = [
"meta-llama/Llama-Guard-3-11B-Vision-Turbo",
CoreModelId.llama_guard_3_11b_vision.value,
),
ProviderModelEntry(
provider_model_id="togethercomputer/m2-bert-80M-8k-retrieval",
model_type=ModelType.embedding,
metadata={
"embedding_dimensions": 768,
"context_length": 8192,
},
),
ProviderModelEntry(
provider_model_id="togethercomputer/m2-bert-80M-32k-retrieval",
model_type=ModelType.embedding,
metadata={
"embedding_dimensions": 768,
"context_length": 32768,
},
),
]

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List, Optional
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field
@ -23,6 +23,7 @@ class ProviderModelEntry(BaseModel):
aliases: List[str] = Field(default_factory=list)
llama_model: Optional[str] = None
model_type: ModelType = ModelType.llm
metadata: Dict[str, Any] = Field(default_factory=dict)
def get_huggingface_repo(model_descriptor: str) -> Optional[str]:
@ -47,6 +48,7 @@ def build_model_entry(provider_model_id: str, model_descriptor: str) -> Provider
provider_model_id=provider_model_id,
aliases=[],
llama_model=model_descriptor,
model_type=ModelType.llm,
)

View file

@ -63,9 +63,11 @@ def get_distribution_template() -> DistributionTemplate:
core_model_to_hf_repo = {m.descriptor(): m.huggingface_repo for m in all_registered_models()}
default_models = [
ModelInput(
model_id=core_model_to_hf_repo[m.llama_model],
model_id=core_model_to_hf_repo[m.llama_model] if m.llama_model else m.provider_model_id,
provider_model_id=m.provider_model_id,
provider_id="fireworks",
metadata=m.metadata,
model_type=m.model_type,
)
for m in MODEL_ENTRIES
]

View file

@ -149,6 +149,13 @@ models:
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama-guard-3-11b-vision
model_type: llm
- metadata:
embedding_dimensions: 768
context_length: 8192
model_id: nomic-ai/nomic-embed-text-v1.5
provider_id: fireworks
provider_model_id: nomic-ai/nomic-embed-text-v1.5
model_type: embedding
- metadata:
embedding_dimension: 384
model_id: all-MiniLM-L6-v2

View file

@ -143,6 +143,13 @@ models:
provider_id: fireworks
provider_model_id: accounts/fireworks/models/llama-guard-3-11b-vision
model_type: llm
- metadata:
embedding_dimensions: 768
context_length: 8192
model_id: nomic-ai/nomic-embed-text-v1.5
provider_id: fireworks
provider_model_id: nomic-ai/nomic-embed-text-v1.5
model_type: embedding
- metadata:
embedding_dimension: 384
model_id: all-MiniLM-L6-v2

View file

@ -71,7 +71,8 @@ def get_distribution_template() -> DistributionTemplate:
)
embedding_model = ModelInput(
model_id="all-MiniLM-L6-v2",
provider_id="sentence-transformers",
provider_id="ollama",
provider_model_id="all-minilm",
model_type=ModelType.embedding,
metadata={
"embedding_dimension": 384,

View file

@ -110,7 +110,8 @@ models:
- metadata:
embedding_dimension: 384
model_id: all-MiniLM-L6-v2
provider_id: sentence-transformers
provider_id: ollama
provider_model_id: all-minilm
model_type: embedding
shields:
- shield_id: ${env.SAFETY_MODEL}

View file

@ -103,7 +103,8 @@ models:
- metadata:
embedding_dimension: 384
model_id: all-MiniLM-L6-v2
provider_id: sentence-transformers
provider_id: ollama
provider_model_id: all-minilm
model_type: embedding
shields: []
vector_dbs: []

View file

@ -144,6 +144,20 @@ models:
provider_id: together
provider_model_id: meta-llama/Llama-Guard-3-11B-Vision-Turbo
model_type: llm
- metadata:
embedding_dimensions: 768
context_length: 8192
model_id: togethercomputer/m2-bert-80M-8k-retrieval
provider_id: together
provider_model_id: togethercomputer/m2-bert-80M-8k-retrieval
model_type: embedding
- metadata:
embedding_dimensions: 768
context_length: 32768
model_id: togethercomputer/m2-bert-80M-32k-retrieval
provider_id: together
provider_model_id: togethercomputer/m2-bert-80M-32k-retrieval
model_type: embedding
- metadata:
embedding_dimension: 384
model_id: all-MiniLM-L6-v2

View file

@ -138,6 +138,20 @@ models:
provider_id: together
provider_model_id: meta-llama/Llama-Guard-3-11B-Vision-Turbo
model_type: llm
- metadata:
embedding_dimensions: 768
context_length: 8192
model_id: togethercomputer/m2-bert-80M-8k-retrieval
provider_id: together
provider_model_id: togethercomputer/m2-bert-80M-8k-retrieval
model_type: embedding
- metadata:
embedding_dimensions: 768
context_length: 32768
model_id: togethercomputer/m2-bert-80M-32k-retrieval
provider_id: together
provider_model_id: togethercomputer/m2-bert-80M-32k-retrieval
model_type: embedding
- metadata:
embedding_dimension: 384
model_id: all-MiniLM-L6-v2

View file

@ -61,9 +61,11 @@ def get_distribution_template() -> DistributionTemplate:
core_model_to_hf_repo = {m.descriptor(): m.huggingface_repo for m in all_registered_models()}
default_models = [
ModelInput(
model_id=core_model_to_hf_repo[m.llama_model],
model_id=core_model_to_hf_repo[m.llama_model] if m.llama_model else m.provider_model_id,
provider_model_id=m.provider_model_id,
provider_id="together",
metadata=m.metadata,
model_type=m.model_type,
)
for m in MODEL_ENTRIES
]