make model registry a helper for ollama and vllm

Dinesh Yeduguru 2024-11-13 12:24:18 -08:00
parent 3b68e6cbbe
commit f63d51963d
2 changed files with 13 additions and 10 deletions
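
The refactor swaps inheritance for composition: instead of subclassing ModelRegistryHelper, each adapter now holds a ModelRegistryHelper instance in self.model_register_helper and delegates register_model() and get_llama_model() calls to it. The sketch below is illustrative only and is not part of the commit; the stub's constructor and method signatures are assumptions based on the call sites visible in this diff, not the real ModelRegistryHelper in llama_stack.

# Illustrative sketch of the composition pattern applied by this commit.
# Names ending in "Sketch" are hypothetical; the real ModelRegistryHelper
# lives in llama_stack and may differ.

class ModelRegistryHelperSketch:
    def __init__(self, model_aliases):
        # assumed alias shape: objects carrying provider_model_id / llama_model
        self._aliases = {a.provider_model_id: a for a in model_aliases}

    async def register_model(self, model):
        # resolve the alias and do registry bookkeeping (assumed behavior)
        return model

    def get_llama_model(self, provider_model_id):
        return self._aliases[provider_model_id].llama_model


class InferenceAdapterSketch:
    def __init__(self, model_aliases):
        # compose the helper instead of inheriting from it
        self.model_register_helper = ModelRegistryHelperSketch(model_aliases)

    async def register_model(self, model):
        # delegate registry logic, then run provider-specific checks
        # (ps(), models.list(), ...) as the real adapters do below
        return await self.model_register_helper.register_model(model)

Composition keeps the adapters' MRO simple and makes the registry dependency explicit at the call sites (self.model_register_helper....) rather than hidden behind super().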


@@ -71,9 +71,9 @@ model_aliases = [
 ]
 
 
-class OllamaInferenceAdapter(Inference, ModelRegistryHelper, ModelsProtocolPrivate):
+class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
     def __init__(self, url: str) -> None:
-        ModelRegistryHelper.__init__(
+        self.model_register_helper = ModelRegistryHelper(
             self,
             model_aliases=model_aliases,
         )
@@ -203,7 +203,9 @@ class OllamaInferenceAdapter(Inference, ModelRegistryHelper, ModelsProtocolPriva
             else:
                 input_dict["raw"] = True
                 input_dict["prompt"] = chat_completion_request_to_prompt(
-                    request, self.get_llama_model(request.model), self.formatter
+                    request,
+                    self.model_register_helper.get_llama_model(request.model),
+                    self.formatter,
                 )
         else:
             assert (
@@ -283,7 +285,7 @@ class OllamaInferenceAdapter(Inference, ModelRegistryHelper, ModelsProtocolPriva
         raise NotImplementedError()
 
     async def register_model(self, model: Model) -> Model:
-        model = await super().register_model(model)
+        model = await self.model_register_helper.register_model(model)
         models = await self.client.ps()
         available_models = [m["model"] for m in models["models"]]
         if model.provider_resource_id not in available_models:
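
After delegating to the helper, the Ollama adapter keeps its provider-specific validation: it lists the models currently loaded in Ollama and rejects a registration whose provider_resource_id is not among them. A stand-alone sketch of that check follows; the payload shape mirrors the models["models"] / m["model"] access in the diff, the hard-coded values are illustrative, and the diff cuts off before the body of the if, so the ValueError is a hypothetical stand-in.

# Sketch of the validation the adapter performs after register_model()
# is delegated to the helper. Values are illustrative placeholders.
models = {"models": [{"model": "llama3.1:8b-instruct-fp16"}]}
available_models = [m["model"] for m in models["models"]]

provider_resource_id = "llama3.1:70b-instruct-fp16"
if provider_resource_id not in available_models:
    # hypothetical error; the real adapter's message is not shown in the diff
    raise ValueError(
        f"Model {provider_resource_id} is not available in Ollama; "
        f"running models: {', '.join(available_models)}"
    )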


@@ -45,9 +45,9 @@ def build_model_aliases():
 ]
 
 
-class VLLMInferenceAdapter(Inference, ModelRegistryHelper, ModelsProtocolPrivate):
+class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
     def __init__(self, config: VLLMInferenceAdapterConfig) -> None:
-        ModelRegistryHelper.__init__(
+        self.model_register_helper = ModelRegistryHelper(
             self,
             model_aliases=build_model_aliases(),
         )
@@ -132,8 +132,7 @@ class VLLMInferenceAdapter(Inference, ModelRegistryHelper, ModelsProtocolPrivate
             yield chunk
 
     async def register_model(self, model: Model) -> Model:
-        print(f"model: {model}")
-        model = await super().register_model(model)
+        model = await self.model_register_helper.register_model(model)
         res = self.client.models.list()
         available_models = [m.id for m in res]
         if model.provider_resource_id not in available_models:
@@ -160,7 +159,9 @@ class VLLMInferenceAdapter(Inference, ModelRegistryHelper, ModelsProtocolPrivate
                 ]
             else:
                 input_dict["prompt"] = chat_completion_request_to_prompt(
-                    request, self.get_llama_model(request.model), self.formatter
+                    request,
+                    self.model_register_helper.get_llama_model(request.model),
+                    self.formatter,
                 )
         else:
             assert (
@@ -168,7 +169,7 @@ class VLLMInferenceAdapter(Inference, ModelRegistryHelper, ModelsProtocolPrivate
             ), "Together does not support media for Completion requests"
             input_dict["prompt"] = completion_request_to_prompt(
                 request,
-                self.get_llama_model(request.model),
+                self.model_register_helper.get_llama_model(request.model),
                 self.formatter,
             )
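
Both adapters build prompts the same way: the provider-specific model id carried by the request is mapped back to a canonical Llama model name before chat_completion_request_to_prompt / completion_request_to_prompt format it. A rough sketch of that mapping, under the assumption that a model alias pairs a provider model id with a llama model name; the real ModelAlias type and ModelRegistryHelper in llama_stack may differ, and the values below are illustrative.

from dataclasses import dataclass

# Hypothetical alias record and helper; class and field names are assumptions.
@dataclass
class ModelAliasSketch:
    provider_model_id: str
    llama_model: str

class RegistrySketch:
    def __init__(self, model_aliases):
        self._by_provider_id = {a.provider_model_id: a for a in model_aliases}

    def get_llama_model(self, provider_model_id):
        # map e.g. an Ollama tag or a vLLM served-model name back to the
        # canonical Llama model the prompt formatter expects
        return self._by_provider_id[provider_model_id].llama_model

aliases = [ModelAliasSketch("llama3.1:8b-instruct-fp16", "Llama3.1-8B-Instruct")]
registry = RegistrySketch(aliases)
print(registry.get_llama_model("llama3.1:8b-instruct-fp16"))  # Llama3.1-8B-Instruct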