mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-15 09:36:10 +00:00
migrate model to Resource and new registration signature (#410)
* resource oriented object design for models * add back llama_model field * working tests * register singature fix * address feedback --------- Co-authored-by: Dinesh Yeduguru <dineshyv@fb.com>
This commit is contained in:
parent
bd0622ef10
commit
ec644d3418
17 changed files with 99 additions and 90 deletions
|
@ -15,7 +15,7 @@ from llama_models.llama3.api.tokenizer import Tokenizer
|
|||
from ollama import AsyncClient
|
||||
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.providers.datatypes import ModelsProtocolPrivate
|
||||
from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
|
||||
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
get_sampling_options,
|
||||
|
@ -65,10 +65,11 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def register_model(self, model: ModelDef) -> None:
|
||||
raise ValueError("Dynamic model registration is not supported")
|
||||
async def register_model(self, model: Model) -> None:
|
||||
if model.identifier not in OLLAMA_SUPPORTED_MODELS:
|
||||
raise ValueError(f"Model {model.identifier} is not supported by Ollama")
|
||||
|
||||
async def list_models(self) -> List[ModelDef]:
|
||||
async def list_models(self) -> List[Model]:
|
||||
ollama_to_llama = {v: k for k, v in OLLAMA_SUPPORTED_MODELS.items()}
|
||||
|
||||
ret = []
|
||||
|
@ -80,9 +81,8 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
|
||||
llama_model = ollama_to_llama[r["model"]]
|
||||
ret.append(
|
||||
ModelDef(
|
||||
Model(
|
||||
identifier=llama_model,
|
||||
llama_model=llama_model,
|
||||
metadata={
|
||||
"ollama_model": r["model"],
|
||||
},
|
||||
|
|
|
@ -14,7 +14,7 @@ class SampleInferenceImpl(Inference):
|
|||
def __init__(self, config: SampleConfig):
|
||||
self.config = config
|
||||
|
||||
async def register_model(self, model: ModelDef) -> None:
|
||||
async def register_model(self, model: Model) -> None:
|
||||
# these are the model names the Llama Stack will use to route requests to this provider
|
||||
# perform validation here if necessary
|
||||
pass
|
||||
|
|
|
@ -16,7 +16,7 @@ from llama_models.sku_list import all_registered_models
|
|||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.apis.models import * # noqa: F403
|
||||
|
||||
from llama_stack.providers.datatypes import ModelDef, ModelsProtocolPrivate
|
||||
from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
|
||||
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
get_sampling_options,
|
||||
|
@ -50,14 +50,14 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
|
|||
if model.huggingface_repo
|
||||
}
|
||||
|
||||
async def register_model(self, model: ModelDef) -> None:
|
||||
raise ValueError("Model registration is not supported for HuggingFace models")
|
||||
async def register_model(self, model: Model) -> None:
|
||||
pass
|
||||
|
||||
async def list_models(self) -> List[ModelDef]:
|
||||
async def list_models(self) -> List[Model]:
|
||||
repo = self.model_id
|
||||
identifier = self.huggingface_repo_to_llama_model_id[repo]
|
||||
return [
|
||||
ModelDef(
|
||||
Model(
|
||||
identifier=identifier,
|
||||
llama_model=identifier,
|
||||
metadata={
|
||||
|
|
|
@ -13,7 +13,7 @@ from llama_models.sku_list import all_registered_models, resolve_model
|
|||
from openai import OpenAI
|
||||
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.providers.datatypes import ModelsProtocolPrivate
|
||||
from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
|
||||
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
get_sampling_options,
|
||||
|
@ -44,13 +44,13 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
async def initialize(self) -> None:
|
||||
self.client = OpenAI(base_url=self.config.url, api_key=self.config.api_token)
|
||||
|
||||
async def register_model(self, model: ModelDef) -> None:
|
||||
async def register_model(self, model: Model) -> None:
|
||||
raise ValueError("Model registration is not supported for vLLM models")
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def list_models(self) -> List[ModelDef]:
|
||||
async def list_models(self) -> List[Model]:
|
||||
models = []
|
||||
for model in self.client.models.list():
|
||||
repo = model.id
|
||||
|
@ -60,7 +60,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
|
||||
identifier = self.huggingface_repo_to_llama_model_id[repo]
|
||||
models.append(
|
||||
ModelDef(
|
||||
Model(
|
||||
identifier=identifier,
|
||||
llama_model=identifier,
|
||||
)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue