mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-27 14:38:49 +00:00
add support embedding models and keeping provider models separate
This commit is contained in:
parent
cf629f81fe
commit
8fb4feeba1
6 changed files with 264 additions and 18 deletions
|
@ -20,7 +20,7 @@ from llama_stack.apis.inference import (
|
|||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
|
||||
from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate, ModelType
|
||||
from llama_stack.providers.utils.inference.embedding_mixin import (
|
||||
SentenceTransformerEmbeddingMixin,
|
||||
)
|
||||
|
@ -41,6 +41,8 @@ class SentenceTransformersInferenceImpl(
|
|||
InferenceProvider,
|
||||
ModelsProtocolPrivate,
|
||||
):
|
||||
__provider_id__: str
|
||||
|
||||
def __init__(self, config: SentenceTransformersInferenceConfig) -> None:
|
||||
self.config = config
|
||||
|
||||
|
@ -54,8 +56,17 @@ class SentenceTransformersInferenceImpl(
|
|||
return False
|
||||
|
||||
async def list_models(self) -> list[Model] | None:
|
||||
# TODO: add all-mini-lm models
|
||||
return None
|
||||
return [
|
||||
Model(
|
||||
identifier="all-MiniLM-L6-v2",
|
||||
provider_resource_id="all-MiniLM-L6-v2",
|
||||
provider_id=self.__provider_id__,
|
||||
metadata={
|
||||
"embedding_dimension": 384,
|
||||
},
|
||||
model_type=ModelType.embedding,
|
||||
),
|
||||
]
|
||||
|
||||
async def register_model(self, model: Model) -> Model:
|
||||
return model
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue