Update ollama to handle embeddings properly

This commit is contained in:
Ashwin Bharambe 2025-02-20 15:35:44 -08:00
parent e16ec20a4a
commit def23d465a
6 changed files with 23 additions and 25 deletions

View file

@@ -84,7 +84,8 @@ model_entries = [
CoreModelId.llama_guard_3_1b.value,
),
ProviderModelEntry(
provider_model_id="all-minilm",
provider_model_id="all-minilm:latest",
aliases=["all-minilm"],
model_type=ModelType.embedding,
metadata={
"embedding_dimensions": 384,

View file

@@ -274,22 +274,17 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
return EmbeddingsResponse(embeddings=embeddings)
async def register_model(self, model: Model) -> Model:
async def check_model_availability(model_id: str):
response = await self.client.ps()
available_models = [m["model"] for m in response["models"]]
if model_id not in available_models:
raise ValueError(
f"Model '{model_id}' is not available in Ollama. Available models: {', '.join(available_models)}"
)
if model.model_type == ModelType.embedding:
await check_model_availability(model.provider_resource_id)
return model
response = await self.client.list()
else:
response = await self.client.ps()
available_models = [m["model"] for m in response["models"]]
if model.provider_resource_id not in available_models:
raise ValueError(
f"Model '{model.provider_resource_id}' is not available in Ollama. Available models: {', '.join(available_models)}"
)
model = await self.register_helper.register_model(model)
await check_model_availability(model.provider_resource_id)
return model
return await self.register_helper.register_model(model)
async def convert_message_to_openai_dict_for_ollama(message: Message) -> List[dict]:

View file

@@ -56,14 +56,16 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
def __init__(self, model_entries: List[ProviderModelEntry]):
self.alias_to_provider_id_map = {}
self.provider_id_to_llama_model_map = {}
for alias_obj in model_entries:
for alias in alias_obj.aliases:
self.alias_to_provider_id_map[alias] = alias_obj.provider_model_id
for entry in model_entries:
for alias in entry.aliases:
self.alias_to_provider_id_map[alias] = entry.provider_model_id
# also add a mapping from provider model id to itself for easy lookup
self.alias_to_provider_id_map[alias_obj.provider_model_id] = alias_obj.provider_model_id
# ensure we can go from llama model to provider model id
self.alias_to_provider_id_map[alias_obj.llama_model] = alias_obj.provider_model_id
self.provider_id_to_llama_model_map[alias_obj.provider_model_id] = alias_obj.llama_model
self.alias_to_provider_id_map[entry.provider_model_id] = entry.provider_model_id
if entry.llama_model:
self.alias_to_provider_id_map[entry.llama_model] = entry.provider_model_id
self.provider_id_to_llama_model_map[entry.provider_model_id] = entry.llama_model
def get_provider_model_id(self, identifier: str) -> Optional[str]:
return self.alias_to_provider_id_map.get(identifier, None)

View file

@@ -72,7 +72,7 @@ def get_distribution_template() -> DistributionTemplate:
embedding_model = ModelInput(
model_id="all-MiniLM-L6-v2",
provider_id="ollama",
provider_model_id="all-minilm",
provider_model_id="all-minilm:latest",
model_type=ModelType.embedding,
metadata={
"embedding_dimension": 384,

View file

@@ -111,7 +111,7 @@ models:
embedding_dimension: 384
model_id: all-MiniLM-L6-v2
provider_id: ollama
provider_model_id: all-minilm
provider_model_id: all-minilm:latest
model_type: embedding
shields:
- shield_id: ${env.SAFETY_MODEL}

View file

@@ -104,7 +104,7 @@ models:
embedding_dimension: 384
model_id: all-MiniLM-L6-v2
provider_id: ollama
provider_model_id: all-minilm
provider_model_id: all-minilm:latest
model_type: embedding
shields: []
vector_dbs: []