feat: register embedding models for ollama, together, fireworks (#1190)

# What does this PR do?

We have support for embeddings in our Inference providers, but so far we
haven't done the final step of actually registering the known embedding
models and making sure they are extremely easy to use. This is one step
towards that.

## Test Plan

Run existing inference tests.

```bash

$ cd llama_stack/providers/tests/inference
$ pytest -s -v -k fireworks test_embeddings.py \
   --inference-model nomic-ai/nomic-embed-text-v1.5 --env EMBEDDING_DIMENSION=784
$  pytest -s -v -k together test_embeddings.py \
   --inference-model togethercomputer/m2-bert-80M-8k-retrieval --env EMBEDDING_DIMENSION=784
$ pytest -s -v -k ollama test_embeddings.py \
   --inference-model all-minilm:latest --env EMBEDDING_DIMENSION=784
```

The value of the EMBEDDING_DIMENSION isn't actually used in these tests,
it is merely used by the test fixtures to check if the model is an LLM
or Embedding.
This commit is contained in:
Ashwin Bharambe 2025-02-20 15:39:08 -08:00 committed by GitHub
parent 736560ceba
commit 9436dd570d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
18 changed files with 214 additions and 105 deletions

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List, Optional
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field
@ -23,6 +23,7 @@ class ProviderModelEntry(BaseModel):
aliases: List[str] = Field(default_factory=list)
llama_model: Optional[str] = None
model_type: ModelType = ModelType.llm
metadata: Dict[str, Any] = Field(default_factory=dict)
def get_huggingface_repo(model_descriptor: str) -> Optional[str]:
@ -47,6 +48,7 @@ def build_model_entry(provider_model_id: str, model_descriptor: str) -> Provider
provider_model_id=provider_model_id,
aliases=[],
llama_model=model_descriptor,
model_type=ModelType.llm,
)
@ -54,14 +56,16 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
def __init__(self, model_entries: List[ProviderModelEntry]):
self.alias_to_provider_id_map = {}
self.provider_id_to_llama_model_map = {}
for alias_obj in model_entries:
for alias in alias_obj.aliases:
self.alias_to_provider_id_map[alias] = alias_obj.provider_model_id
for entry in model_entries:
for alias in entry.aliases:
self.alias_to_provider_id_map[alias] = entry.provider_model_id
# also add a mapping from provider model id to itself for easy lookup
self.alias_to_provider_id_map[alias_obj.provider_model_id] = alias_obj.provider_model_id
# ensure we can go from llama model to provider model id
self.alias_to_provider_id_map[alias_obj.llama_model] = alias_obj.provider_model_id
self.provider_id_to_llama_model_map[alias_obj.provider_model_id] = alias_obj.llama_model
self.alias_to_provider_id_map[entry.provider_model_id] = entry.provider_model_id
if entry.llama_model:
self.alias_to_provider_id_map[entry.llama_model] = entry.provider_model_id
self.provider_id_to_llama_model_map[entry.provider_model_id] = entry.llama_model
def get_provider_model_id(self, identifier: str) -> Optional[str]:
return self.alias_to_provider_id_map.get(identifier, None)