feat: register embedding models for ollama, together, fireworks (#1190)

# What does this PR do?

We have support for embeddings in our Inference providers, but so far we
hadn't taken the final step of actually registering the known embedding
models and making them extremely easy to use. This is one step
towards that.
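
For context, here is a minimal sketch of the intended developer experience once these models are pre-registered. The client API shape (`llama_stack_client` and its `inference.embeddings` endpoint) and the default port are assumptions, not part of this diff:

```python
# Sketch only: assumes a running Llama Stack server with the ollama provider
# configured and the all-minilm model pulled locally.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")  # assumed default port

response = client.inference.embeddings(
    model_id="all-minilm:latest",
    contents=["Llama Stack makes embeddings easy"],
)
# Each entry is one embedding vector; its length is the model's dimension.
print(len(response.embeddings[0]))
```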

## Test Plan

Run existing inference tests.

```bash
$ cd llama_stack/providers/tests/inference
$ pytest -s -v -k fireworks test_embeddings.py \
    --inference-model nomic-ai/nomic-embed-text-v1.5 --env EMBEDDING_DIMENSION=784
$ pytest -s -v -k together test_embeddings.py \
    --inference-model togethercomputer/m2-bert-80M-8k-retrieval --env EMBEDDING_DIMENSION=784
$ pytest -s -v -k ollama test_embeddings.py \
    --inference-model all-minilm:latest --env EMBEDDING_DIMENSION=784
```

The value of EMBEDDING_DIMENSION isn't actually used in these tests; it is
merely read by the test fixtures to decide whether the model under test is
an LLM or an embedding model.
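
A hedged sketch of what that fixture check might look like; the real fixture code lives in the test suite and is not shown in this PR, so the function name and exact logic here are illustrative:

```python
# Illustrative only: how a fixture could use EMBEDDING_DIMENSION to tag the
# model passed via --inference-model as an embedding model instead of an LLM.
import os

from llama_stack.apis.models import ModelType


def infer_model_type() -> ModelType:
    # Any EMBEDDING_DIMENSION value flags the model as an embedding model;
    # the number itself is never checked by these tests.
    if os.getenv("EMBEDDING_DIMENSION"):
        return ModelType.embedding
    return ModelType.llm
```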
Commit 9436dd570d (parent 736560ceba), authored by Ashwin Bharambe on 2025-02-20 15:39:08 -08:00 and committed by GitHub.
18 changed files with 214 additions and 105 deletions

Excerpt from one of the changed files, the Ollama inference adapter:

```diff
@@ -31,12 +31,9 @@ from llama_stack.apis.inference import (
     ToolPromptFormat,
 )
 from llama_stack.apis.models import Model, ModelType
-from llama_stack.models.llama.datatypes import CoreModelId
 from llama_stack.providers.datatypes import ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
-    build_hf_repo_model_entry,
-    build_model_entry,
 )
 from llama_stack.providers.utils.inference.openai_compat import (
     OpenAICompatCompletionChoice,
```

```diff
@@ -56,80 +53,9 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
     request_has_media,
 )

-log = logging.getLogger(__name__)
+from .models import model_entries

-model_entries = [
-    build_hf_repo_model_entry(
-        "llama3.1:8b-instruct-fp16",
-        CoreModelId.llama3_1_8b_instruct.value,
-    ),
-    build_model_entry(
-        "llama3.1:8b",
-        CoreModelId.llama3_1_8b_instruct.value,
-    ),
-    build_hf_repo_model_entry(
-        "llama3.1:70b-instruct-fp16",
-        CoreModelId.llama3_1_70b_instruct.value,
-    ),
-    build_model_entry(
-        "llama3.1:70b",
-        CoreModelId.llama3_1_70b_instruct.value,
-    ),
-    build_hf_repo_model_entry(
-        "llama3.1:405b-instruct-fp16",
-        CoreModelId.llama3_1_405b_instruct.value,
-    ),
-    build_model_entry(
-        "llama3.1:405b",
-        CoreModelId.llama3_1_405b_instruct.value,
-    ),
-    build_hf_repo_model_entry(
-        "llama3.2:1b-instruct-fp16",
-        CoreModelId.llama3_2_1b_instruct.value,
-    ),
-    build_model_entry(
-        "llama3.2:1b",
-        CoreModelId.llama3_2_1b_instruct.value,
-    ),
-    build_hf_repo_model_entry(
-        "llama3.2:3b-instruct-fp16",
-        CoreModelId.llama3_2_3b_instruct.value,
-    ),
-    build_model_entry(
-        "llama3.2:3b",
-        CoreModelId.llama3_2_3b_instruct.value,
-    ),
-    build_hf_repo_model_entry(
-        "llama3.2-vision:11b-instruct-fp16",
-        CoreModelId.llama3_2_11b_vision_instruct.value,
-    ),
-    build_model_entry(
-        "llama3.2-vision:latest",
-        CoreModelId.llama3_2_11b_vision_instruct.value,
-    ),
-    build_hf_repo_model_entry(
-        "llama3.2-vision:90b-instruct-fp16",
-        CoreModelId.llama3_2_90b_vision_instruct.value,
-    ),
-    build_model_entry(
-        "llama3.2-vision:90b",
-        CoreModelId.llama3_2_90b_vision_instruct.value,
-    ),
-    build_hf_repo_model_entry(
-        "llama3.3:70b",
-        CoreModelId.llama3_3_70b_instruct.value,
-    ),
-    # The Llama Guard models don't have their full fp16 versions
-    # so we are going to alias their default version to the canonical SKU
-    build_hf_repo_model_entry(
-        "llama-guard3:8b",
-        CoreModelId.llama_guard_3_8b.value,
-    ),
-    build_hf_repo_model_entry(
-        "llama-guard3:1b",
-        CoreModelId.llama_guard_3_1b.value,
-    ),
-]
+log = logging.getLogger(__name__)


 class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
```

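The `from .models import model_entries` line points at a `models.py` module added by this commit, whose contents are not shown in this excerpt. Presumably it carries the LLM entries deleted above plus the newly registered embedding models, roughly along these lines; the `ProviderModelEntry` fields and the all-minilm metadata are assumptions, not taken from the diff:

```python
# Assumed shape of the new ollama models.py. The LLM entries are the ones
# moved out of ollama.py above; the embedding entry is illustrative.
from llama_stack.apis.models import ModelType
from llama_stack.models.llama.datatypes import CoreModelId
from llama_stack.providers.utils.inference.model_registry import (
    ProviderModelEntry,
    build_hf_repo_model_entry,
    build_model_entry,
)

model_entries = [
    build_hf_repo_model_entry(
        "llama3.1:8b-instruct-fp16",
        CoreModelId.llama3_1_8b_instruct.value,
    ),
    # ... the remaining LLM entries, moved here unchanged ...
    ProviderModelEntry(
        provider_model_id="all-minilm:latest",
        model_type=ModelType.embedding,
        metadata={"embedding_dimension": 384},  # assumed value for all-minilm
    ),
]
```

With embedding models in `model_entries`, the registry helper can resolve them the same way it resolves LLMs, which is what lets `register_model` below treat both model types uniformly.
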
```diff
@@ -348,22 +274,17 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         return EmbeddingsResponse(embeddings=embeddings)

     async def register_model(self, model: Model) -> Model:
-        async def check_model_availability(model_id: str):
-            response = await self.client.ps()
-            available_models = [m["model"] for m in response["models"]]
-            if model_id not in available_models:
-                raise ValueError(
-                    f"Model '{model_id}' is not available in Ollama. Available models: {', '.join(available_models)}"
-                )
-
         if model.model_type == ModelType.embedding:
-            await check_model_availability(model.provider_resource_id)
-            return model
+            response = await self.client.list()
+        else:
+            response = await self.client.ps()
+        available_models = [m["model"] for m in response["models"]]
+        if model.provider_resource_id not in available_models:
+            raise ValueError(
+                f"Model '{model.provider_resource_id}' is not available in Ollama. Available models: {', '.join(available_models)}"
+            )

-        model = await self.register_helper.register_model(model)
-        await check_model_availability(model.provider_resource_id)
-
-        return model
+        return await self.register_helper.register_model(model)


 async def convert_message_to_openai_dict_for_ollama(message: Message) -> List[dict]:
```
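
A design note inferred from the hunk above: `client.ps()` reports only models currently loaded into memory, while `client.list()` reports everything that has been pulled. Embedding models are not resident in Ollama until their first request, so checking them against `ps()` would spuriously fail; the new code therefore consults `list()` for embedding models, and now routes both model types through `register_helper`. A quick way to see the difference with the `ollama` Python client; the dict-style access mirrors the adapter code above, though newer client versions return typed objects instead:

```python
import asyncio

from ollama import AsyncClient


async def main() -> None:
    client = AsyncClient()
    pulled = [m["model"] for m in (await client.list())["models"]]  # everything downloaded
    running = [m["model"] for m in (await client.ps())["models"]]   # loaded in memory right now
    print("pulled: ", pulled)
    print("running:", running)


asyncio.run(main())
```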