Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-12 13:00:39 +00:00)
Update ollama to handle embeddings properly
commit def23d465a (parent e16ec20a4a)

6 changed files with 23 additions and 25 deletions
```diff
@@ -84,7 +84,8 @@ model_entries = [
         CoreModelId.llama_guard_3_1b.value,
     ),
     ProviderModelEntry(
-        provider_model_id="all-minilm",
+        provider_model_id="all-minilm:latest",
+        aliases=["all-minilm"],
         model_type=ModelType.embedding,
         metadata={
             "embedding_dimensions": 384,
```
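The embedding entry now uses the fully tagged name that Ollama itself reports ("all-minilm:latest") as the provider model id and keeps the bare "all-minilm" as an alias. As a minimal standalone sketch (not part of the commit), this is one way to see which names a local daemon actually reports; it assumes the `ollama` Python package is installed and mirrors the dict-style response access used in the adapter, which may differ in newer client versions:

```python
# Standalone sketch: list the model names a local Ollama daemon reports.
# Assumes the `ollama` Python package and a running daemon.
import asyncio

from ollama import AsyncClient


async def main() -> None:
    client = AsyncClient()
    # list() reports every pulled model under its fully tagged name,
    # e.g. "all-minilm:latest" rather than the bare "all-minilm".
    response = await client.list()
    print([m["model"] for m in response["models"]])


if __name__ == "__main__":
    asyncio.run(main())
```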
```diff
@@ -274,22 +274,17 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         return EmbeddingsResponse(embeddings=embeddings)
 
     async def register_model(self, model: Model) -> Model:
-        async def check_model_availability(model_id: str):
-            response = await self.client.ps()
-            available_models = [m["model"] for m in response["models"]]
-            if model_id not in available_models:
-                raise ValueError(
-                    f"Model '{model_id}' is not available in Ollama. Available models: {', '.join(available_models)}"
-                )
-
         if model.model_type == ModelType.embedding:
-            await check_model_availability(model.provider_resource_id)
-            return model
+            response = await self.client.list()
+        else:
+            response = await self.client.ps()
+        available_models = [m["model"] for m in response["models"]]
+        if model.provider_resource_id not in available_models:
+            raise ValueError(
+                f"Model '{model.provider_resource_id}' is not available in Ollama. Available models: {', '.join(available_models)}"
+            )
 
-        model = await self.register_helper.register_model(model)
-        await check_model_availability(model.provider_resource_id)
-
-        return model
+        return await self.register_helper.register_model(model)
 
 
 async def convert_message_to_openai_dict_for_ollama(message: Message) -> List[dict]:
```
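With this change, `register_model` validates embedding models against the full set of pulled models (`client.list()`) rather than only the models currently loaded (`client.ps()`), and every model then goes through `register_helper`. A self-contained sketch of that branching, using a hypothetical stub client and sample model names purely for illustration:

```python
# Sketch of the new availability check. StubOllamaClient and the sample
# model names are illustrative stand-ins, not the adapter's real client.
import asyncio


class StubOllamaClient:
    async def list(self):  # every model that has been pulled
        return {"models": [{"model": "all-minilm:latest"}, {"model": "llama3.2:3b"}]}

    async def ps(self):  # only models currently loaded into memory
        return {"models": [{"model": "llama3.2:3b"}]}


async def check_availability(client, provider_model_id: str, is_embedding: bool) -> None:
    # Embedding models are rarely loaded yet, so consult list() for them.
    response = await (client.list() if is_embedding else client.ps())
    available = [m["model"] for m in response["models"]]
    if provider_model_id not in available:
        raise ValueError(
            f"Model '{provider_model_id}' is not available in Ollama. Available models: {', '.join(available)}"
        )


async def main() -> None:
    client = StubOllamaClient()
    await check_availability(client, "all-minilm:latest", is_embedding=True)  # found via list()
    await check_availability(client, "llama3.2:3b", is_embedding=False)  # found via ps()


if __name__ == "__main__":
    asyncio.run(main())
```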
```diff
@@ -56,14 +56,16 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
     def __init__(self, model_entries: List[ProviderModelEntry]):
         self.alias_to_provider_id_map = {}
         self.provider_id_to_llama_model_map = {}
-        for alias_obj in model_entries:
-            for alias in alias_obj.aliases:
-                self.alias_to_provider_id_map[alias] = alias_obj.provider_model_id
+        for entry in model_entries:
+            for alias in entry.aliases:
+                self.alias_to_provider_id_map[alias] = entry.provider_model_id
+
             # also add a mapping from provider model id to itself for easy lookup
-            self.alias_to_provider_id_map[alias_obj.provider_model_id] = alias_obj.provider_model_id
-            # ensure we can go from llama model to provider model id
-            self.alias_to_provider_id_map[alias_obj.llama_model] = alias_obj.provider_model_id
-            self.provider_id_to_llama_model_map[alias_obj.provider_model_id] = alias_obj.llama_model
+            self.alias_to_provider_id_map[entry.provider_model_id] = entry.provider_model_id
+
+            if entry.llama_model:
+                self.alias_to_provider_id_map[entry.llama_model] = entry.provider_model_id
+                self.provider_id_to_llama_model_map[entry.provider_model_id] = entry.llama_model
 
     def get_provider_model_id(self, identifier: str) -> Optional[str]:
         return self.alias_to_provider_id_map.get(identifier, None)
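`ModelRegistryHelper` now only records the Llama-model mappings when an entry actually has a `llama_model`, so embedding entries (which have none) register cleanly, and both the alias and the tagged provider id resolve to the same provider model id. A minimal self-contained sketch of that mapping logic; the `Entry` dataclass and `build_alias_map` are illustrative stand-ins, not the real classes:

```python
# Simplified stand-in for ProviderModelEntry and the alias-map construction.
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class Entry:
    provider_model_id: str
    aliases: List[str] = field(default_factory=list)
    llama_model: Optional[str] = None


def build_alias_map(entries: List[Entry]) -> dict:
    alias_to_provider_id = {}
    for entry in entries:
        for alias in entry.aliases:
            alias_to_provider_id[alias] = entry.provider_model_id
        # the provider model id always maps to itself for easy lookup
        alias_to_provider_id[entry.provider_model_id] = entry.provider_model_id
        # only map a Llama model when the entry has one
        # (embedding entries like all-minilm have none)
        if entry.llama_model:
            alias_to_provider_id[entry.llama_model] = entry.provider_model_id
    return alias_to_provider_id


entries = [Entry(provider_model_id="all-minilm:latest", aliases=["all-minilm"])]
alias_map = build_alias_map(entries)
assert alias_map["all-minilm"] == "all-minilm:latest"
assert alias_map["all-minilm:latest"] == "all-minilm:latest"
assert None not in alias_map
```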
```diff
@@ -72,7 +72,7 @@ def get_distribution_template() -> DistributionTemplate:
     embedding_model = ModelInput(
         model_id="all-MiniLM-L6-v2",
         provider_id="ollama",
-        provider_model_id="all-minilm",
+        provider_model_id="all-minilm:latest",
         model_type=ModelType.embedding,
         metadata={
             "embedding_dimension": 384,
```
```diff
@@ -111,7 +111,7 @@ models:
     embedding_dimension: 384
   model_id: all-MiniLM-L6-v2
   provider_id: ollama
-  provider_model_id: all-minilm
+  provider_model_id: all-minilm:latest
   model_type: embedding
 shields:
 - shield_id: ${env.SAFETY_MODEL}
```
```diff
@@ -104,7 +104,7 @@ models:
     embedding_dimension: 384
   model_id: all-MiniLM-L6-v2
   provider_id: ollama
-  provider_model_id: all-minilm
+  provider_model_id: all-minilm:latest
   model_type: embedding
 shields: []
 vector_dbs: []
```
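Both run.yaml templates now point `provider_model_id` at the tagged name. A quick check that a models entry shaped like the fragments above parses as expected; this assumes PyYAML and reads an inline snippet rather than the real template file:

```python
# Sanity-check the updated embedding entry shape (assumes PyYAML is installed).
import yaml

snippet = """
models:
- metadata:
    embedding_dimension: 384
  model_id: all-MiniLM-L6-v2
  provider_id: ollama
  provider_model_id: all-minilm:latest
  model_type: embedding
"""

config = yaml.safe_load(snippet)
embedding = config["models"][0]
assert embedding["provider_model_id"] == "all-minilm:latest"
assert embedding["metadata"]["embedding_dimension"] == 384
```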