diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py
index ed30b7016..a5ea8e0db 100644
--- a/llama_stack/providers/remote/inference/ollama/ollama.py
+++ b/llama_stack/providers/remote/inference/ollama/ollama.py
@@ -71,9 +71,8 @@ model_aliases = [
 ]
 
 
-class OllamaInferenceAdapter(Inference, ModelRegistryHelper, ModelsProtocolPrivate):
+class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
     def __init__(self, url: str) -> None:
-        ModelRegistryHelper.__init__(
-            self,
+        self.model_register_helper = ModelRegistryHelper(
             model_aliases=model_aliases,
         )
@@ -203,7 +202,9 @@ class OllamaInferenceAdapter(Inference, ModelRegistryHelper, ModelsProtocolPriva
             else:
                 input_dict["raw"] = True
                 input_dict["prompt"] = chat_completion_request_to_prompt(
-                    request, self.get_llama_model(request.model), self.formatter
+                    request,
+                    self.model_register_helper.get_llama_model(request.model),
+                    self.formatter,
                 )
         else:
             assert (
@@ -283,7 +284,7 @@ class OllamaInferenceAdapter(Inference, ModelRegistryHelper, ModelsProtocolPriva
         raise NotImplementedError()
 
     async def register_model(self, model: Model) -> Model:
-        model = await super().register_model(model)
+        model = await self.model_register_helper.register_model(model)
         models = await self.client.ps()
         available_models = [m["model"] for m in models["models"]]
         if model.provider_resource_id not in available_models:
diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py
index 21ff05f4d..7fad85e8c 100644
--- a/llama_stack/providers/remote/inference/vllm/vllm.py
+++ b/llama_stack/providers/remote/inference/vllm/vllm.py
@@ -45,9 +45,8 @@ def build_model_aliases():
     ]
 
 
-class VLLMInferenceAdapter(Inference, ModelRegistryHelper, ModelsProtocolPrivate):
+class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
     def __init__(self, config: VLLMInferenceAdapterConfig) -> None:
-        ModelRegistryHelper.__init__(
-            self,
+        self.model_register_helper = ModelRegistryHelper(
             model_aliases=build_model_aliases(),
         )
@@ -132,8 +131,7 @@ class VLLMInferenceAdapter(Inference, ModelRegistryHelper, ModelsProtocolPrivate
             yield chunk
 
     async def register_model(self, model: Model) -> Model:
-        print(f"model: {model}")
-        model = await super().register_model(model)
+        model = await self.model_register_helper.register_model(model)
        res = self.client.models.list()
         available_models = [m.id for m in res]
         if model.provider_resource_id not in available_models:
@@ -160,7 +158,9 @@ class VLLMInferenceAdapter(Inference, ModelRegistryHelper, ModelsProtocolPrivate
                 ]
             else:
                 input_dict["prompt"] = chat_completion_request_to_prompt(
-                    request, self.get_llama_model(request.model), self.formatter
+                    request,
+                    self.model_register_helper.get_llama_model(request.model),
+                    self.formatter,
                 )
         else:
             assert (
@@ -168,7 +168,7 @@ class VLLMInferenceAdapter(Inference, ModelRegistryHelper, ModelsProtocolPrivate
             ), "Together does not support media for Completion requests"
             input_dict["prompt"] = completion_request_to_prompt(
                 request,
-                self.get_llama_model(request.model),
+                self.model_register_helper.get_llama_model(request.model),
                 self.formatter,
             )
 
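
For context on the pattern this diff adopts: both adapters stop inheriting from `ModelRegistryHelper` and instead hold an instance of it, delegating `register_model()` and `get_llama_model()` explicitly. The sketch below illustrates that composition-over-inheritance move only; `RegistryHelper`, `Adapter`, `resolve()`, and the alias mapping are illustrative stand-ins, not the actual llama_stack classes or their signatures.

```python
class RegistryHelper:
    """Stand-in for ModelRegistryHelper: maps user-facing aliases to provider model ids."""

    def __init__(self, model_aliases: dict[str, str]) -> None:
        self._aliases = model_aliases

    def resolve(self, alias: str) -> str:
        # Stand-in for get_llama_model(): fall back to the alias itself if unknown.
        return self._aliases.get(alias, alias)


class Adapter:
    """Stand-in for the inference adapters: composition instead of a mixin base class."""

    def __init__(self, model_aliases: dict[str, str]) -> None:
        # Hold the helper rather than calling RegistryHelper.__init__(self, ...) via MRO.
        self.model_register_helper = RegistryHelper(model_aliases)

    def prompt_model(self, alias: str) -> str:
        # Explicit delegation replaces the old inherited self.get_llama_model(...) call.
        return self.model_register_helper.resolve(alias)


if __name__ == "__main__":
    adapter = Adapter({"llama3.1:8b": "Llama3.1-8B-Instruct"})
    print(adapter.prompt_model("llama3.1:8b"))  # -> Llama3.1-8B-Instruct
```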