Inference to use provider resource id to register and validate (#428)

This PR changes how a model id gets translated into the final model name
that is passed through to the provider.
Major changes include:
1) Providers are responsible for registering an object and, as part of
the registration, returning the object with the correct provider-specific
name of the model in provider_resource_id.
2) To help with looking a model up by its different names, a new ModelLookup
class is created.



Tested all inference providers, including together, fireworks, vllm,
ollama, meta reference, and bedrock.
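
Conceptually, the registration and lookup flow works like the sketch below. The names `ModelAlias`, `build_model_alias`, and `ModelRegistryHelper` appear in the diff, but the field names and lookup logic here are assumptions for illustration only, not the actual implementation:

```python
# Minimal sketch of the alias-based registration flow described above
# (assumed shapes, not the actual llama-stack implementation).
from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class ModelAlias:
    provider_model_id: str  # name the provider backend expects
    aliases: List[str]      # alternative names a user may register the model under


@dataclass
class Model:
    identifier: str
    provider_resource_id: Optional[str] = None


def build_model_alias(provider_model_id: str, *aliases: str) -> ModelAlias:
    return ModelAlias(provider_model_id=provider_model_id, aliases=list(aliases))


class ModelRegistryHelper:
    def __init__(self, model_aliases: List[ModelAlias]) -> None:
        # Every alias (including the provider id itself) resolves to the provider id.
        self.alias_to_provider_id: Dict[str, str] = {}
        for alias in model_aliases:
            for name in [alias.provider_model_id, *alias.aliases]:
                self.alias_to_provider_id[name] = alias.provider_model_id

    def get_provider_model_id(self, identifier: str) -> str:
        if identifier not in self.alias_to_provider_id:
            raise ValueError(f"Unknown model: {identifier}")
        return self.alias_to_provider_id[identifier]

    async def register_model(self, model: Model) -> Model:
        # The provider returns the object with provider_resource_id set to the
        # name it will actually pass through to its backend.
        model.provider_resource_id = self.get_provider_model_id(model.identifier)
        return model
```

With this in place, a provider only has to declare its aliases once; registration and per-request validation then go through the same mapping.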
Dinesh Yeduguru 2024-11-12 20:02:00 -08:00 committed by GitHub
parent e51107e019
commit fdff24e77a
21 changed files with 460 additions and 290 deletions


@@ -86,6 +86,7 @@ class Llama:
         and loads the pre-trained model and tokenizer.
         """
         model = resolve_model(config.model)
+        llama_model = model.core_model_id.value

         if not torch.distributed.is_initialized():
             torch.distributed.init_process_group("nccl")
@@ -186,13 +187,20 @@ class Llama:
         model.load_state_dict(state_dict, strict=False)
         print(f"Loaded in {time.time() - start_time:.2f} seconds")
-        return Llama(model, tokenizer, model_args)
+        return Llama(model, tokenizer, model_args, llama_model)

-    def __init__(self, model: Transformer, tokenizer: Tokenizer, args: ModelArgs):
+    def __init__(
+        self,
+        model: Transformer,
+        tokenizer: Tokenizer,
+        args: ModelArgs,
+        llama_model: str,
+    ):
         self.args = args
         self.model = model
         self.tokenizer = tokenizer
         self.formatter = ChatFormat(tokenizer)
+        self.llama_model = llama_model

     @torch.inference_mode()
     def generate(
@@ -369,7 +377,7 @@ class Llama:
         self,
         request: ChatCompletionRequest,
     ) -> Generator:
-        messages = chat_completion_request_to_messages(request)
+        messages = chat_completion_request_to_messages(request, self.llama_model)

         sampling_params = request.sampling_params
         max_gen_len = sampling_params.max_tokens


@@ -11,9 +11,11 @@ from typing import AsyncGenerator, List
 from llama_models.sku_list import resolve_model

 from llama_models.llama3.api.datatypes import * # noqa: F403
-from llama_stack.apis.inference import * # noqa: F403
-from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
+from llama_stack.providers.utils.inference.model_registry import build_model_alias
+from llama_stack.apis.inference import * # noqa: F403
+from llama_stack.providers.datatypes import ModelsProtocolPrivate
+from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper

 from llama_stack.providers.utils.inference.prompt_adapter import (
     convert_image_media_to_url,
     request_has_media,
@@ -28,10 +30,19 @@ from .model_parallel import LlamaModelParallelGenerator
 SEMAPHORE = asyncio.Semaphore(1)


-class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
+class MetaReferenceInferenceImpl(Inference, ModelRegistryHelper, ModelsProtocolPrivate):
     def __init__(self, config: MetaReferenceInferenceConfig) -> None:
         self.config = config
         model = resolve_model(config.model)
+        ModelRegistryHelper.__init__(
+            self,
+            [
+                build_model_alias(
+                    model.descriptor(),
+                    model.core_model_id.value,
+                )
+            ],
+        )
         if model is None:
             raise RuntimeError(f"Unknown model: {config.model}, Run `llama model list`")
         self.model = model
@@ -45,12 +56,6 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
         else:
             self.generator = Llama.build(self.config)

-    async def register_model(self, model: Model) -> None:
-        if model.identifier != self.model.descriptor():
-            raise ValueError(
-                f"Model mismatch: {model.identifier} != {self.model.descriptor()}"
-            )
     async def shutdown(self) -> None:
         if self.config.create_distributed_process_group:
             self.generator.stop()
@@ -68,7 +73,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
     async def completion(
         self,
-        model: str,
+        model_id: str,
         content: InterleavedTextMedia,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         response_format: Optional[ResponseFormat] = None,
@@ -79,7 +84,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
             assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
         request = CompletionRequest(
-            model=model,
+            model=model_id,
             content=content,
             sampling_params=sampling_params,
             response_format=response_format,
@@ -186,7 +191,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
     async def chat_completion(
         self,
-        model: str,
+        model_id: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         response_format: Optional[ResponseFormat] = None,
@@ -201,7 +206,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
         # wrapper request to make it easier to pass around (internal only, not exposed to API)
         request = ChatCompletionRequest(
-            model=model,
+            model=model_id,
             messages=messages,
             sampling_params=sampling_params,
             tools=tools or [],
@@ -386,7 +391,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
     async def embeddings(
         self,
-        model: str,
+        model_id: str,
         contents: List[InterleavedTextMedia],
     ) -> EmbeddingsResponse:
         raise NotImplementedError()
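
The remote providers mentioned above follow the same pattern: `chat_completion` receives a `model_id` and translates it into the provider-specific model name before building the backend request. A rough sketch on top of the registry helper sketched earlier (the adapter class and its client are hypothetical, for illustration only):

```python
# Hypothetical remote adapter built on the ModelRegistryHelper sketch above;
# ExampleAdapter and self.client.chat are made-up names, not real llama-stack APIs.
class ExampleAdapter(ModelRegistryHelper):
    def __init__(self, client, model_aliases: List[ModelAlias]) -> None:
        super().__init__(model_aliases)
        self.client = client

    async def chat_completion(self, model_id: str, messages: list, **kwargs):
        # Translate whatever name the model was registered under into the
        # provider's own name before calling the backend.
        provider_model_id = self.get_provider_model_id(model_id)
        return await self.client.chat(model=provider_model_id, messages=messages, **kwargs)
```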