feat: register embedding models for ollama, together, fireworks (#1190)

# What does this PR do?

We have support for embeddings in our Inference providers, but so far we
haven't taken the final step of actually registering the known embedding
models and making them easy to use out of the box. This PR is one step
towards that.
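
For a sense of what this enables, here is a rough sketch of how a registered embedding model could be used from the Python client once a stack is running. This is illustrative only: the base URL, port, and model id below are assumptions for the ollama distribution, not something introduced by this PR.

```python
# Illustrative sketch: assumes a running Llama Stack server (e.g. the ollama
# template) on localhost:8321, with all-minilm:latest already pulled in Ollama.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")

# No manual model registration step -- the provider pre-registers the
# embedding model, so it can be referenced directly by its id.
response = client.inference.embeddings(
    model_id="all-minilm:latest",
    contents=["The quick brown fox jumps over the lazy dog"],
)
print(len(response.embeddings[0]))  # embedding dimension, e.g. 384
```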

## Test Plan

Run existing inference tests.

```bash
$ cd llama_stack/providers/tests/inference
$ pytest -s -v -k fireworks test_embeddings.py \
   --inference-model nomic-ai/nomic-embed-text-v1.5 --env EMBEDDING_DIMENSION=784
$ pytest -s -v -k together test_embeddings.py \
   --inference-model togethercomputer/m2-bert-80M-8k-retrieval --env EMBEDDING_DIMENSION=784
$ pytest -s -v -k ollama test_embeddings.py \
   --inference-model all-minilm:latest --env EMBEDDING_DIMENSION=784
```

The value of EMBEDDING_DIMENSION isn't actually used by these tests; it is
only used by the test fixtures to decide whether the model under test is an
LLM or an embedding model.
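
For reference, the fixture logic boils down to something like the following sketch. This is not the actual fixture code; it only illustrates why the numeric value (784 above vs. the real 768/384) doesn't matter.

```python
# Rough sketch of the fixture behavior (assumed, not the real implementation):
# the presence of EMBEDDING_DIMENSION tells the fixtures to register the
# --inference-model as an embedding model rather than an LLM.
import os

from llama_stack.apis.models import ModelType


def model_type_for_test_model() -> ModelType:
    # Only whether EMBEDDING_DIMENSION is set matters here; its numeric value
    # never affects the assertions in test_embeddings.py.
    return ModelType.embedding if os.getenv("EMBEDDING_DIMENSION") else ModelType.llm
```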
Commit 9436dd570d (parent 736560ceba), authored by Ashwin Bharambe on 2025-02-20 15:39:08 -08:00, committed by GitHub.
18 changed files with 214 additions and 105 deletions.


@@ -88,6 +88,7 @@ repos:
         pass_filenames: false
         require_serial: true
         files: ^llama_stack/templates/.*$
+        files: ^llama_stack/providers/.*/inference/.*/models\.py$
 ci:
     autofix_commit_msg: 🎨 [pre-commit.ci] Auto format from pre-commit.com hooks


@@ -47,6 +47,7 @@ The following models are available by default:
 - `meta-llama/Llama-3.3-70B-Instruct (accounts/fireworks/models/llama-v3p3-70b-instruct)`
 - `meta-llama/Llama-Guard-3-8B (accounts/fireworks/models/llama-guard-3-8b)`
 - `meta-llama/Llama-Guard-3-11B-Vision (accounts/fireworks/models/llama-guard-3-11b-vision)`
+- `nomic-ai/nomic-embed-text-v1.5 (nomic-ai/nomic-embed-text-v1.5)`
 ### Prerequisite: API Keys


@@ -46,6 +46,8 @@ The following models are available by default:
 - `meta-llama/Llama-3.3-70B-Instruct`
 - `meta-llama/Llama-Guard-3-8B`
 - `meta-llama/Llama-Guard-3-11B-Vision`
+- `togethercomputer/m2-bert-80M-8k-retrieval`
+- `togethercomputer/m2-bert-80M-32k-retrieval`
 ### Prerequisite: API Keys


@@ -4,8 +4,10 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+from llama_stack.apis.models.models import ModelType
 from llama_stack.models.llama.datatypes import CoreModelId
 from llama_stack.providers.utils.inference.model_registry import (
+    ProviderModelEntry,
     build_hf_repo_model_entry,
 )
@@ -50,4 +52,12 @@ MODEL_ENTRIES = [
         "accounts/fireworks/models/llama-guard-3-11b-vision",
         CoreModelId.llama_guard_3_11b_vision.value,
     ),
+    ProviderModelEntry(
+        provider_model_id="nomic-ai/nomic-embed-text-v1.5",
+        model_type=ModelType.embedding,
+        metadata={
+            "embedding_dimensions": 768,
+            "context_length": 8192,
+        },
+    ),
 ]


@@ -0,0 +1,103 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.apis.models.models import ModelType
from llama_stack.models.llama.datatypes import CoreModelId
from llama_stack.providers.utils.inference.model_registry import (
    ProviderModelEntry,
    build_hf_repo_model_entry,
    build_model_entry,
)

model_entries = [
    build_hf_repo_model_entry(
        "llama3.1:8b-instruct-fp16",
        CoreModelId.llama3_1_8b_instruct.value,
    ),
    build_model_entry(
        "llama3.1:8b",
        CoreModelId.llama3_1_8b_instruct.value,
    ),
    build_hf_repo_model_entry(
        "llama3.1:70b-instruct-fp16",
        CoreModelId.llama3_1_70b_instruct.value,
    ),
    build_model_entry(
        "llama3.1:70b",
        CoreModelId.llama3_1_70b_instruct.value,
    ),
    build_hf_repo_model_entry(
        "llama3.1:405b-instruct-fp16",
        CoreModelId.llama3_1_405b_instruct.value,
    ),
    build_model_entry(
        "llama3.1:405b",
        CoreModelId.llama3_1_405b_instruct.value,
    ),
    build_hf_repo_model_entry(
        "llama3.2:1b-instruct-fp16",
        CoreModelId.llama3_2_1b_instruct.value,
    ),
    build_model_entry(
        "llama3.2:1b",
        CoreModelId.llama3_2_1b_instruct.value,
    ),
    build_hf_repo_model_entry(
        "llama3.2:3b-instruct-fp16",
        CoreModelId.llama3_2_3b_instruct.value,
    ),
    build_model_entry(
        "llama3.2:3b",
        CoreModelId.llama3_2_3b_instruct.value,
    ),
    build_hf_repo_model_entry(
        "llama3.2-vision:11b-instruct-fp16",
        CoreModelId.llama3_2_11b_vision_instruct.value,
    ),
    build_model_entry(
        "llama3.2-vision:latest",
        CoreModelId.llama3_2_11b_vision_instruct.value,
    ),
    build_hf_repo_model_entry(
        "llama3.2-vision:90b-instruct-fp16",
        CoreModelId.llama3_2_90b_vision_instruct.value,
    ),
    build_model_entry(
        "llama3.2-vision:90b",
        CoreModelId.llama3_2_90b_vision_instruct.value,
    ),
    build_hf_repo_model_entry(
        "llama3.3:70b",
        CoreModelId.llama3_3_70b_instruct.value,
    ),
    # The Llama Guard models don't have their full fp16 versions
    # so we are going to alias their default version to the canonical SKU
    build_hf_repo_model_entry(
        "llama-guard3:8b",
        CoreModelId.llama_guard_3_8b.value,
    ),
    build_hf_repo_model_entry(
        "llama-guard3:1b",
        CoreModelId.llama_guard_3_1b.value,
    ),
    ProviderModelEntry(
        provider_model_id="all-minilm:latest",
        aliases=["all-minilm"],
        model_type=ModelType.embedding,
        metadata={
            "embedding_dimensions": 384,
            "context_length": 512,
        },
    ),
    ProviderModelEntry(
        provider_model_id="nomic-embed-text",
        model_type=ModelType.embedding,
        metadata={
            "embedding_dimensions": 768,
            "context_length": 8192,
        },
    ),
]


@@ -31,12 +31,9 @@ from llama_stack.apis.inference import (
     ToolPromptFormat,
 )
 from llama_stack.apis.models import Model, ModelType
-from llama_stack.models.llama.datatypes import CoreModelId
 from llama_stack.providers.datatypes import ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
-    build_hf_repo_model_entry,
-    build_model_entry,
 )
 from llama_stack.providers.utils.inference.openai_compat import (
     OpenAICompatCompletionChoice,
@@ -56,80 +53,9 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
     request_has_media,
 )
-log = logging.getLogger(__name__)
-model_entries = [
-    build_hf_repo_model_entry(
-        "llama3.1:8b-instruct-fp16",
-        CoreModelId.llama3_1_8b_instruct.value,
-    ),
-    build_model_entry(
-        "llama3.1:8b",
-        CoreModelId.llama3_1_8b_instruct.value,
-    ),
-    build_hf_repo_model_entry(
-        "llama3.1:70b-instruct-fp16",
-        CoreModelId.llama3_1_70b_instruct.value,
-    ),
-    build_model_entry(
-        "llama3.1:70b",
-        CoreModelId.llama3_1_70b_instruct.value,
-    ),
-    build_hf_repo_model_entry(
-        "llama3.1:405b-instruct-fp16",
-        CoreModelId.llama3_1_405b_instruct.value,
-    ),
-    build_model_entry(
-        "llama3.1:405b",
-        CoreModelId.llama3_1_405b_instruct.value,
-    ),
-    build_hf_repo_model_entry(
-        "llama3.2:1b-instruct-fp16",
-        CoreModelId.llama3_2_1b_instruct.value,
-    ),
-    build_model_entry(
-        "llama3.2:1b",
-        CoreModelId.llama3_2_1b_instruct.value,
-    ),
-    build_hf_repo_model_entry(
-        "llama3.2:3b-instruct-fp16",
-        CoreModelId.llama3_2_3b_instruct.value,
-    ),
-    build_model_entry(
-        "llama3.2:3b",
-        CoreModelId.llama3_2_3b_instruct.value,
-    ),
-    build_hf_repo_model_entry(
-        "llama3.2-vision:11b-instruct-fp16",
-        CoreModelId.llama3_2_11b_vision_instruct.value,
-    ),
-    build_model_entry(
-        "llama3.2-vision:latest",
-        CoreModelId.llama3_2_11b_vision_instruct.value,
-    ),
-    build_hf_repo_model_entry(
-        "llama3.2-vision:90b-instruct-fp16",
-        CoreModelId.llama3_2_90b_vision_instruct.value,
-    ),
-    build_model_entry(
-        "llama3.2-vision:90b",
-        CoreModelId.llama3_2_90b_vision_instruct.value,
-    ),
-    build_hf_repo_model_entry(
-        "llama3.3:70b",
-        CoreModelId.llama3_3_70b_instruct.value,
-    ),
-    # The Llama Guard models don't have their full fp16 versions
-    # so we are going to alias their default version to the canonical SKU
-    build_hf_repo_model_entry(
-        "llama-guard3:8b",
-        CoreModelId.llama_guard_3_8b.value,
-    ),
-    build_hf_repo_model_entry(
-        "llama-guard3:1b",
-        CoreModelId.llama_guard_3_1b.value,
-    ),
-]
+from .models import model_entries
+
+log = logging.getLogger(__name__)
 
 class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
@@ -348,22 +274,17 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         return EmbeddingsResponse(embeddings=embeddings)
 
     async def register_model(self, model: Model) -> Model:
-        async def check_model_availability(model_id: str):
-            response = await self.client.ps()
-            available_models = [m["model"] for m in response["models"]]
-            if model_id not in available_models:
-                raise ValueError(
-                    f"Model '{model_id}' is not available in Ollama. Available models: {', '.join(available_models)}"
-                )
-
         if model.model_type == ModelType.embedding:
-            await check_model_availability(model.provider_resource_id)
-            return model
+            response = await self.client.list()
+        else:
+            response = await self.client.ps()
+        available_models = [m["model"] for m in response["models"]]
+        if model.provider_resource_id not in available_models:
+            raise ValueError(
+                f"Model '{model.provider_resource_id}' is not available in Ollama. Available models: {', '.join(available_models)}"
+            )
 
-        model = await self.register_helper.register_model(model)
-        await check_model_availability(model.provider_resource_id)
-        return model
+        return await self.register_helper.register_model(model)
 
 async def convert_message_to_openai_dict_for_ollama(message: Message) -> List[dict]:


@@ -4,8 +4,10 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+from llama_stack.apis.models.models import ModelType
 from llama_stack.models.llama.datatypes import CoreModelId
 from llama_stack.providers.utils.inference.model_registry import (
+    ProviderModelEntry,
     build_hf_repo_model_entry,
 )
@@ -46,4 +48,20 @@ MODEL_ENTRIES = [
         "meta-llama/Llama-Guard-3-11B-Vision-Turbo",
         CoreModelId.llama_guard_3_11b_vision.value,
     ),
+    ProviderModelEntry(
+        provider_model_id="togethercomputer/m2-bert-80M-8k-retrieval",
+        model_type=ModelType.embedding,
+        metadata={
+            "embedding_dimensions": 768,
+            "context_length": 8192,
+        },
+    ),
+    ProviderModelEntry(
+        provider_model_id="togethercomputer/m2-bert-80M-32k-retrieval",
+        model_type=ModelType.embedding,
+        metadata={
+            "embedding_dimensions": 768,
+            "context_length": 32768,
+        },
+    ),
 ]


@@ -20,7 +20,7 @@ from llama_stack.providers.remote.inference.cerebras import CerebrasImplConfig
 from llama_stack.providers.remote.inference.fireworks import FireworksImplConfig
 from llama_stack.providers.remote.inference.groq import GroqConfig
 from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig
-from llama_stack.providers.remote.inference.ollama import OllamaImplConfig
+from llama_stack.providers.remote.inference.ollama import DEFAULT_OLLAMA_URL, OllamaImplConfig
 from llama_stack.providers.remote.inference.sambanova import SambaNovaImplConfig
 from llama_stack.providers.remote.inference.tgi import TGIImplConfig
 from llama_stack.providers.remote.inference.together import TogetherImplConfig
@@ -89,7 +89,7 @@ def inference_ollama() -> ProviderFixture:
             Provider(
                 provider_id="ollama",
                 provider_type="remote::ollama",
-                config=OllamaImplConfig(url=get_env_or_fail("OLLAMA_URL")).model_dump(),
+                config=OllamaImplConfig(url=os.getenv("OLLAMA_URL", DEFAULT_OLLAMA_URL)).model_dump(),
             )
         ],
     )


@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from typing import List, Optional
+from typing import Any, Dict, List, Optional
 
 from pydantic import BaseModel, Field
@@ -23,6 +23,7 @@ class ProviderModelEntry(BaseModel):
     aliases: List[str] = Field(default_factory=list)
     llama_model: Optional[str] = None
     model_type: ModelType = ModelType.llm
+    metadata: Dict[str, Any] = Field(default_factory=dict)
 
 def get_huggingface_repo(model_descriptor: str) -> Optional[str]:
@@ -47,6 +48,7 @@ def build_model_entry(provider_model_id: str, model_descriptor: str) -> ProviderModelEntry:
         provider_model_id=provider_model_id,
         aliases=[],
         llama_model=model_descriptor,
+        model_type=ModelType.llm,
     )
@@ -54,14 +56,16 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
     def __init__(self, model_entries: List[ProviderModelEntry]):
         self.alias_to_provider_id_map = {}
         self.provider_id_to_llama_model_map = {}
-        for alias_obj in model_entries:
-            for alias in alias_obj.aliases:
-                self.alias_to_provider_id_map[alias] = alias_obj.provider_model_id
+        for entry in model_entries:
+            for alias in entry.aliases:
+                self.alias_to_provider_id_map[alias] = entry.provider_model_id
             # also add a mapping from provider model id to itself for easy lookup
-            self.alias_to_provider_id_map[alias_obj.provider_model_id] = alias_obj.provider_model_id
-            # ensure we can go from llama model to provider model id
-            self.alias_to_provider_id_map[alias_obj.llama_model] = alias_obj.provider_model_id
-            self.provider_id_to_llama_model_map[alias_obj.provider_model_id] = alias_obj.llama_model
+            self.alias_to_provider_id_map[entry.provider_model_id] = entry.provider_model_id
+            if entry.llama_model:
+                self.alias_to_provider_id_map[entry.llama_model] = entry.provider_model_id
+                self.provider_id_to_llama_model_map[entry.provider_model_id] = entry.llama_model
 
     def get_provider_model_id(self, identifier: str) -> Optional[str]:
         return self.alias_to_provider_id_map.get(identifier, None)


@@ -63,9 +63,11 @@ def get_distribution_template() -> DistributionTemplate:
     core_model_to_hf_repo = {m.descriptor(): m.huggingface_repo for m in all_registered_models()}
     default_models = [
         ModelInput(
-            model_id=core_model_to_hf_repo[m.llama_model],
+            model_id=core_model_to_hf_repo[m.llama_model] if m.llama_model else m.provider_model_id,
             provider_model_id=m.provider_model_id,
             provider_id="fireworks",
+            metadata=m.metadata,
+            model_type=m.model_type,
         )
         for m in MODEL_ENTRIES
     ]


@@ -149,6 +149,13 @@ models:
   provider_id: fireworks
   provider_model_id: accounts/fireworks/models/llama-guard-3-11b-vision
   model_type: llm
+- metadata:
+    embedding_dimensions: 768
+    context_length: 8192
+  model_id: nomic-ai/nomic-embed-text-v1.5
+  provider_id: fireworks
+  provider_model_id: nomic-ai/nomic-embed-text-v1.5
+  model_type: embedding
 - metadata:
     embedding_dimension: 384
   model_id: all-MiniLM-L6-v2


@@ -143,6 +143,13 @@ models:
   provider_id: fireworks
   provider_model_id: accounts/fireworks/models/llama-guard-3-11b-vision
  model_type: llm
+- metadata:
+    embedding_dimensions: 768
+    context_length: 8192
+  model_id: nomic-ai/nomic-embed-text-v1.5
+  provider_id: fireworks
+  provider_model_id: nomic-ai/nomic-embed-text-v1.5
+  model_type: embedding
 - metadata:
     embedding_dimension: 384
   model_id: all-MiniLM-L6-v2


@@ -71,7 +71,8 @@ def get_distribution_template() -> DistributionTemplate:
     )
     embedding_model = ModelInput(
         model_id="all-MiniLM-L6-v2",
-        provider_id="sentence-transformers",
+        provider_id="ollama",
+        provider_model_id="all-minilm:latest",
         model_type=ModelType.embedding,
         metadata={
             "embedding_dimension": 384,


@@ -110,7 +110,8 @@ models:
 - metadata:
     embedding_dimension: 384
   model_id: all-MiniLM-L6-v2
-  provider_id: sentence-transformers
+  provider_id: ollama
+  provider_model_id: all-minilm:latest
   model_type: embedding
 shields:
 - shield_id: ${env.SAFETY_MODEL}


@@ -103,7 +103,8 @@ models:
 - metadata:
     embedding_dimension: 384
   model_id: all-MiniLM-L6-v2
-  provider_id: sentence-transformers
+  provider_id: ollama
+  provider_model_id: all-minilm:latest
   model_type: embedding
 shields: []
 vector_dbs: []


@@ -144,6 +144,20 @@ models:
   provider_id: together
   provider_model_id: meta-llama/Llama-Guard-3-11B-Vision-Turbo
   model_type: llm
+- metadata:
+    embedding_dimensions: 768
+    context_length: 8192
+  model_id: togethercomputer/m2-bert-80M-8k-retrieval
+  provider_id: together
+  provider_model_id: togethercomputer/m2-bert-80M-8k-retrieval
+  model_type: embedding
+- metadata:
+    embedding_dimensions: 768
+    context_length: 32768
+  model_id: togethercomputer/m2-bert-80M-32k-retrieval
+  provider_id: together
+  provider_model_id: togethercomputer/m2-bert-80M-32k-retrieval
+  model_type: embedding
 - metadata:
     embedding_dimension: 384
   model_id: all-MiniLM-L6-v2


@@ -138,6 +138,20 @@ models:
   provider_id: together
   provider_model_id: meta-llama/Llama-Guard-3-11B-Vision-Turbo
   model_type: llm
+- metadata:
+    embedding_dimensions: 768
+    context_length: 8192
+  model_id: togethercomputer/m2-bert-80M-8k-retrieval
+  provider_id: together
+  provider_model_id: togethercomputer/m2-bert-80M-8k-retrieval
+  model_type: embedding
+- metadata:
+    embedding_dimensions: 768
+    context_length: 32768
+  model_id: togethercomputer/m2-bert-80M-32k-retrieval
+  provider_id: together
+  provider_model_id: togethercomputer/m2-bert-80M-32k-retrieval
+  model_type: embedding
 - metadata:
     embedding_dimension: 384
   model_id: all-MiniLM-L6-v2


@@ -61,9 +61,11 @@ def get_distribution_template() -> DistributionTemplate:
     core_model_to_hf_repo = {m.descriptor(): m.huggingface_repo for m in all_registered_models()}
     default_models = [
         ModelInput(
-            model_id=core_model_to_hf_repo[m.llama_model],
+            model_id=core_model_to_hf_repo[m.llama_model] if m.llama_model else m.provider_model_id,
             provider_model_id=m.provider_model_id,
             provider_id="together",
+            metadata=m.metadata,
+            model_type=m.model_type,
         )
         for m in MODEL_ENTRIES
     ]