Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-08-01 16:24:44 +00:00
make model registry a helper for ollama and vllm
parent 3b68e6cbbe
commit f63d51963d
2 changed files with 13 additions and 10 deletions
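In short, both adapters stop inheriting from ModelRegistryHelper and instead hold one as an attribute, delegating registration and model-name lookups to it. A minimal sketch of the before/after pattern, with a simplified stand-in for ModelRegistryHelper (the real class's exact constructor signature is not shown in this diff):

    class ModelRegistryHelper:
        # Simplified stand-in; the real helper lives in llama-stack's
        # provider utilities and exposes at least these two methods.
        def __init__(self, model_aliases):
            self.model_aliases = model_aliases

        def get_llama_model(self, provider_model_id):
            ...

        async def register_model(self, model):
            ...

    # Before: mixin inheritance -- the adapter *is* a registry helper.
    class AdapterBefore(ModelRegistryHelper):
        def __init__(self, url):
            ModelRegistryHelper.__init__(self, model_aliases=[])
            self.url = url

    # After: composition -- the adapter *has* a registry helper.
    class AdapterAfter:
        def __init__(self, url):
            self.model_register_helper = ModelRegistryHelper(model_aliases=[])
            self.url = url

Composition keeps the adapter's inheritance chain flat and makes the helper's role explicit at every call site.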
Ollama adapter:

@@ -71,9 +71,9 @@ model_aliases = [
 ]


-class OllamaInferenceAdapter(Inference, ModelRegistryHelper, ModelsProtocolPrivate):
+class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
     def __init__(self, url: str) -> None:
-        ModelRegistryHelper.__init__(
+        self.model_register_helper = ModelRegistryHelper(
             self,
             model_aliases=model_aliases,
         )
@@ -203,7 +203,9 @@ class OllamaInferenceAdapter(Inference, ModelRegistryHelper, ModelsProtocolPrivate):
             else:
                 input_dict["raw"] = True
                 input_dict["prompt"] = chat_completion_request_to_prompt(
-                    request, self.get_llama_model(request.model), self.formatter
+                    request,
+                    self.model_register_helper.get_llama_model(request.model),
+                    self.formatter,
                 )
         else:
             assert (
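Prompt construction now resolves the canonical Llama model through the helper rather than through an inherited method. Conceptually the lookup maps a provider-facing model id to the core Llama model name the prompt formatter expects; a hypothetical illustration (the alias entry below is made up for this sketch, not the real table):

    # Hypothetical alias table: provider model id -> canonical Llama model.
    ALIAS_TO_LLAMA = {
        "llama3.1:8b-instruct-fp16": "Llama3.1-8B-Instruct",  # illustrative
    }

    def get_llama_model(provider_model_id: str) -> str:
        # The real helper derives this from its model_aliases list.
        return ALIAS_TO_LLAMA[provider_model_id]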
@@ -283,7 +285,7 @@ class OllamaInferenceAdapter(Inference, ModelRegistryHelper, ModelsProtocolPrivate):
         raise NotImplementedError()

     async def register_model(self, model: Model) -> Model:
-        model = await super().register_model(model)
+        model = await self.model_register_helper.register_model(model)
         models = await self.client.ps()
         available_models = [m["model"] for m in models["models"]]
         if model.provider_resource_id not in available_models:
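Registration in the Ollama adapter now delegates to the composed helper and then verifies the resolved model is actually being served. A condensed sketch of that flow; the failure branch is a guess since the diff truncates at the `if`, and `client.ps()` is assumed to return a dict with a "models" list, as the comprehension above implies:

    async def register_model(self, model):
        # Resolve and record the model via the composed helper first.
        model = await self.model_register_helper.register_model(model)
        # Then confirm Ollama is actually serving that model.
        running = await self.client.ps()
        available = [m["model"] for m in running["models"]]
        if model.provider_resource_id not in available:
            # Hypothetical error handling; the real branch is not in the diff.
            raise ValueError(f"{model.provider_resource_id} not served by ollama")
        return model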
vLLM adapter:

@@ -45,9 +45,9 @@ def build_model_aliases():
     ]


-class VLLMInferenceAdapter(Inference, ModelRegistryHelper, ModelsProtocolPrivate):
+class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
     def __init__(self, config: VLLMInferenceAdapterConfig) -> None:
-        ModelRegistryHelper.__init__(
+        self.model_register_helper = ModelRegistryHelper(
             self,
             model_aliases=build_model_aliases(),
         )
@@ -132,8 +132,7 @@ class VLLMInferenceAdapter(Inference, ModelRegistryHelper, ModelsProtocolPrivate):
             yield chunk

     async def register_model(self, model: Model) -> Model:
-        print(f"model: {model}")
-        model = await super().register_model(model)
+        model = await self.model_register_helper.register_model(model)
         res = self.client.models.list()
         available_models = [m.id for m in res]
         if model.provider_resource_id not in available_models:
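The vLLM adapter mirrors the same flow, dropping a leftover debug print along the way; here the availability check goes through the OpenAI-compatible `client.models.list()`, whose results carry an `id` field as the comprehension shows. Again a condensed sketch with the failure branch hedged, since the diff cuts off at the `if`:

    async def register_model(self, model):
        model = await self.model_register_helper.register_model(model)
        res = self.client.models.list()  # OpenAI-compatible models endpoint
        available = [m.id for m in res]
        if model.provider_resource_id not in available:
            # Hypothetical error handling; not shown in the diff.
            raise ValueError(f"{model.provider_resource_id} not served by vLLM")
        return model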
@@ -160,7 +159,9 @@ class VLLMInferenceAdapter(Inference, ModelRegistryHelper, ModelsProtocolPrivate):
             ]
         else:
             input_dict["prompt"] = chat_completion_request_to_prompt(
-                request, self.get_llama_model(request.model), self.formatter
+                request,
+                self.model_register_helper.get_llama_model(request.model),
+                self.formatter,
             )
     else:
         assert (
@@ -168,7 +169,7 @@ class VLLMInferenceAdapter(Inference, ModelRegistryHelper, ModelsProtocolPrivate):
         ), "Together does not support media for Completion requests"
         input_dict["prompt"] = completion_request_to_prompt(
             request,
-            self.get_llama_model(request.model),
+            self.model_register_helper.get_llama_model(request.model),
             self.formatter,
         )