address feedback

This commit is contained in:
Dinesh Yeduguru 2024-11-13 13:02:45 -08:00
parent 96b1bafcde
commit 7e4765c45b
3 changed files with 44 additions and 8 deletions

View file

@ -73,7 +73,7 @@ model_aliases = [
class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
def __init__(self, url: str) -> None:
self.model_register_helper = ModelRegistryHelper(model_aliases)
self.register_helper = ModelRegistryHelper(model_aliases)
self.url = url
self.formatter = ChatFormat(Tokenizer.get_instance())
@ -201,7 +201,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
input_dict["raw"] = True
input_dict["prompt"] = chat_completion_request_to_prompt(
request,
self.model_register_helper.get_llama_model(request.model),
self.register_helper.get_llama_model(request.model),
self.formatter,
)
else:
@ -282,7 +282,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
raise NotImplementedError()
async def register_model(self, model: Model) -> Model:
model = await self.model_register_helper.register_model(model)
model = await self.register_helper.register_model(model)
models = await self.client.ps()
available_models = [m["model"] for m in models["models"]]
if model.provider_resource_id not in available_models: