diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py
index 86dc3207a..32dfba30c 100644
--- a/llama_stack/providers/datatypes.py
+++ b/llama_stack/providers/datatypes.py
@@ -21,7 +21,7 @@ from llama_stack.schema_utils import json_schema_type
 
 
 class ModelsProtocolPrivate(Protocol):
-    async def register_model(self, model: Model) -> None: ...
+    async def register_model(self, model: Model) -> Model: ...
 
     async def unregister_model(self, model_id: str) -> None: ...
 
diff --git a/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py b/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py
index b583896ad..39847e085 100644
--- a/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py
+++ b/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py
@@ -43,7 +43,7 @@ class SentenceTransformersInferenceImpl(
     async def shutdown(self) -> None:
         pass
 
-    async def register_model(self, model: Model) -> None:
+    async def register_model(self, model: Model) -> Model:
         return model
 
     async def unregister_model(self, model_id: str) -> None:
diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py
index b28cb2016..4cfd81ead 100644
--- a/llama_stack/providers/remote/inference/ollama/ollama.py
+++ b/llama_stack/providers/remote/inference/ollama/ollama.py
@@ -300,7 +300,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
 
         return EmbeddingsResponse(embeddings=embeddings)
 
-    async def register_model(self, model: Model):
+    async def register_model(self, model: Model) -> Model:
         model = await self.register_helper.register_model(model)
         if model.model_type == ModelType.embedding:
             logger.info(f"Pulling embedding model `{model.provider_resource_id}` if necessary...")
@@ -314,6 +314,8 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
                 f"Model '{model.provider_resource_id}' is not available in Ollama. Available models: {', '.join(available_models)}"
             )
 
+        return model
+
 
 async def convert_message_to_openai_dict_for_ollama(message: Message) -> List[dict]:
     async def _convert_content(content) -> dict:
diff --git a/llama_stack/providers/remote/inference/tgi/tgi.py b/llama_stack/providers/remote/inference/tgi/tgi.py
index 757085fb1..fe99fafe1 100644
--- a/llama_stack/providers/remote/inference/tgi/tgi.py
+++ b/llama_stack/providers/remote/inference/tgi/tgi.py
@@ -83,7 +83,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
     async def shutdown(self) -> None:
         pass
 
-    async def register_model(self, model: Model) -> None:
+    async def register_model(self, model: Model) -> Model:
         model = await self.register_helper.register_model(model)
         if model.provider_resource_id != self.model_id:
             raise ValueError(
diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py
index 0fefda7b0..6f99bf007 100644
--- a/llama_stack/providers/remote/inference/vllm/vllm.py
+++ b/llama_stack/providers/remote/inference/vllm/vllm.py
@@ -346,7 +346,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
             async for chunk in process_completion_stream_response(stream):
                 yield chunk
 
-    async def register_model(self, model: Model) -> None:
+    async def register_model(self, model: Model) -> Model:
         assert self.client is not None
         model = await self.register_helper.register_model(model)
         res = await self.client.models.list()
@@ -356,6 +356,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
                 f"Model {model.provider_resource_id} is not being served by vLLM. "
                 f"Available models: {', '.join(available_models)}"
             )
+        return model
 
     async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict:
         options = get_sampling_options(request.sampling_params)
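
For reference, a minimal sketch of a provider conforming to the updated ModelsProtocolPrivate contract, where register_model now returns the (possibly normalized) Model instead of None. The adapter class name and the in-memory bookkeeping are illustrative assumptions, not part of this diff; the Model import path is assumed to match the one used by llama_stack/providers/datatypes.py.

# Hypothetical provider sketch; MyInferenceAdapter is not part of this diff.
from llama_stack.apis.models import Model  # assumed import path for the Model resource
from llama_stack.providers.datatypes import ModelsProtocolPrivate


class MyInferenceAdapter(ModelsProtocolPrivate):
    def __init__(self) -> None:
        self._models: dict[str, Model] = {}

    async def register_model(self, model: Model) -> Model:
        # Validate or normalize the model here, then return it so the caller
        # (the model registry) can persist any provider-side adjustments.
        self._models[model.identifier] = model
        return model

    async def unregister_model(self, model_id: str) -> None:
        self._models.pop(model_id, None)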