forked from phoenix-oss/llama-stack-mirror
Fix TGI register_model() issue
This commit is contained in:
parent
4b94cd313c
commit
707da55c23
1 changed file with 24 additions and 16 deletions
@@ -17,6 +17,10 @@ from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.apis.models import *  # noqa: F403


 from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
+from llama_stack.providers.utils.inference.model_registry import (
+    build_model_alias,
+    ModelRegistryHelper,
+)
 from llama_stack.providers.utils.inference.openai_compat import (
     get_sampling_options,
@@ -37,6 +41,17 @@ from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig
 log = logging.getLogger(__name__)


+def build_model_aliases():
+    return [
+        build_model_alias(
+            model.huggingface_repo,
+            model.descriptor(),
+        )
+        for model in all_registered_models()
+        if model.huggingface_repo
+    ]
+
+
 class _HfAdapter(Inference, ModelsProtocolPrivate):
     client: AsyncInferenceClient
     max_tokens: int
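For context, here is a minimal sketch of the alias table that the new build_model_aliases() helper builds. Everything prefixed fake_ is a hypothetical stand-in: all_registered_models() and build_model_alias() live in llama-stack modules not shown in this diff.

from dataclasses import dataclass
from typing import Optional

# Illustrative stand-in: the real registered-model objects come from
# all_registered_models() in llama-stack, not shown in this diff.
@dataclass
class FakeRegisteredModel:
    huggingface_repo: Optional[str]
    core_model_id: str

    def descriptor(self) -> str:
        return self.core_model_id

def fake_all_registered_models():
    return [
        FakeRegisteredModel("meta-llama/Llama-3.1-8B-Instruct", "Llama3.1-8B-Instruct"),
        FakeRegisteredModel(None, "ModelWithoutHfRepo"),  # skipped by the filter below
    ]

def fake_build_model_alias(provider_model_id: str, descriptor: str) -> dict:
    # The real helper returns an alias object; a dict keeps the sketch simple.
    return {"provider_model_id": provider_model_id, "llama_model": descriptor}

# Same comprehension shape as the new build_model_aliases() above: one alias
# per registered model that actually has a HuggingFace repo.
aliases = [
    fake_build_model_alias(m.huggingface_repo, m.descriptor())
    for m in fake_all_registered_models()
    if m.huggingface_repo
]
print(aliases)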
@@ -44,31 +59,24 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
     def __init__(self) -> None:
         self.formatter = ChatFormat(Tokenizer.get_instance())
+        self.register_helper = ModelRegistryHelper(build_model_aliases())
         self.huggingface_repo_to_llama_model_id = {
             model.huggingface_repo: model.descriptor()
             for model in all_registered_models()
             if model.huggingface_repo
         }

-    async def register_model(self, model: Model) -> None:
-        pass
-
-    async def list_models(self) -> List[Model]:
-        repo = self.model_id
-        identifier = self.huggingface_repo_to_llama_model_id[repo]
-        return [
-            Model(
-                identifier=identifier,
-                llama_model=identifier,
-                metadata={
-                    "huggingface_repo": repo,
-                },
-            )
-        ]
-
     async def shutdown(self) -> None:
         pass

+    async def register_model(self, model: Model) -> None:
+        model = await self.register_helper.register_model(model)
+        if model.provider_resource_id != self.model_id:
+            raise ValueError(
+                f"Model {model.provider_resource_id} does not match the model {self.model_id} served by TGI."
+            )
+        return model
+
     async def unregister_model(self, model_id: str) -> None:
         pass
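To make the behavioral change concrete, here is a minimal, self-contained sketch of the new guard: the old register_model() silently accepted any model (a bare pass), while the new one rejects anything the TGI endpoint does not serve. FakeModel, FakeTGIAdapter, and the model IDs are hypothetical stand-ins; the real adapter also routes through ModelRegistryHelper to normalize aliases before the comparison.

import asyncio
from dataclasses import dataclass

@dataclass
class FakeModel:
    # Stand-in for llama_stack.providers.datatypes.Model; only the field
    # used by the guard is modeled here.
    provider_resource_id: str

class FakeTGIAdapter:
    # A TGI endpoint serves exactly one model, recorded at initialization.
    def __init__(self, served_model_id: str) -> None:
        self.model_id = served_model_id

    async def register_model(self, model: FakeModel) -> FakeModel:
        # Mirrors the commit's check: reject models the server does not serve
        # instead of silently accepting them.
        if model.provider_resource_id != self.model_id:
            raise ValueError(
                f"Model {model.provider_resource_id} does not match "
                f"the model {self.model_id} served by TGI."
            )
        return model

async def main() -> None:
    adapter = FakeTGIAdapter("meta-llama/Llama-3.1-8B-Instruct")
    ok = await adapter.register_model(FakeModel("meta-llama/Llama-3.1-8B-Instruct"))
    print("registered:", ok.provider_resource_id)
    try:
        await adapter.register_model(FakeModel("meta-llama/Llama-3.1-70B-Instruct"))
    except ValueError as err:
        print("rejected:", err)

asyncio.run(main())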