forked from phoenix-oss/llama-stack-mirror
Inference to use provider resource id to register and validate (#428)
This PR changes the way model id gets translated to the final model name that gets passed through the provider. Major changes include: 1) Providers are responsible for registering an object and as part of the registration returning the object with the correct provider specific name of the model provider_resource_id 2) To help with the common look ups different names a new ModelLookup class is created. Tested all inference providers including together, fireworks, vllm, ollama, meta reference and bedrock
This commit is contained in:
parent
e51107e019
commit
fdff24e77a
21 changed files with 460 additions and 290 deletions
|
@ -86,6 +86,7 @@ class Llama:
|
|||
and loads the pre-trained model and tokenizer.
|
||||
"""
|
||||
model = resolve_model(config.model)
|
||||
llama_model = model.core_model_id.value
|
||||
|
||||
if not torch.distributed.is_initialized():
|
||||
torch.distributed.init_process_group("nccl")
|
||||
|
@ -186,13 +187,20 @@ class Llama:
|
|||
model.load_state_dict(state_dict, strict=False)
|
||||
|
||||
print(f"Loaded in {time.time() - start_time:.2f} seconds")
|
||||
return Llama(model, tokenizer, model_args)
|
||||
return Llama(model, tokenizer, model_args, llama_model)
|
||||
|
||||
def __init__(self, model: Transformer, tokenizer: Tokenizer, args: ModelArgs):
|
||||
def __init__(
|
||||
self,
|
||||
model: Transformer,
|
||||
tokenizer: Tokenizer,
|
||||
args: ModelArgs,
|
||||
llama_model: str,
|
||||
):
|
||||
self.args = args
|
||||
self.model = model
|
||||
self.tokenizer = tokenizer
|
||||
self.formatter = ChatFormat(tokenizer)
|
||||
self.llama_model = llama_model
|
||||
|
||||
@torch.inference_mode()
|
||||
def generate(
|
||||
|
@ -369,7 +377,7 @@ class Llama:
|
|||
self,
|
||||
request: ChatCompletionRequest,
|
||||
) -> Generator:
|
||||
messages = chat_completion_request_to_messages(request)
|
||||
messages = chat_completion_request_to_messages(request, self.llama_model)
|
||||
|
||||
sampling_params = request.sampling_params
|
||||
max_gen_len = sampling_params.max_tokens
|
||||
|
|
|
@ -11,9 +11,11 @@ from typing import AsyncGenerator, List
|
|||
from llama_models.sku_list import resolve_model
|
||||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import build_model_alias
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.providers.datatypes import ModelsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
convert_image_media_to_url,
|
||||
request_has_media,
|
||||
|
@ -28,10 +30,19 @@ from .model_parallel import LlamaModelParallelGenerator
|
|||
SEMAPHORE = asyncio.Semaphore(1)
|
||||
|
||||
|
||||
class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
|
||||
class MetaReferenceInferenceImpl(Inference, ModelRegistryHelper, ModelsProtocolPrivate):
|
||||
def __init__(self, config: MetaReferenceInferenceConfig) -> None:
|
||||
self.config = config
|
||||
model = resolve_model(config.model)
|
||||
ModelRegistryHelper.__init__(
|
||||
self,
|
||||
[
|
||||
build_model_alias(
|
||||
model.descriptor(),
|
||||
model.core_model_id.value,
|
||||
)
|
||||
],
|
||||
)
|
||||
if model is None:
|
||||
raise RuntimeError(f"Unknown model: {config.model}, Run `llama model list`")
|
||||
self.model = model
|
||||
|
@ -45,12 +56,6 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
|
|||
else:
|
||||
self.generator = Llama.build(self.config)
|
||||
|
||||
async def register_model(self, model: Model) -> None:
|
||||
if model.identifier != self.model.descriptor():
|
||||
raise ValueError(
|
||||
f"Model mismatch: {model.identifier} != {self.model.descriptor()}"
|
||||
)
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
if self.config.create_distributed_process_group:
|
||||
self.generator.stop()
|
||||
|
@ -68,7 +73,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
|
|||
|
||||
async def completion(
|
||||
self,
|
||||
model: str,
|
||||
model_id: str,
|
||||
content: InterleavedTextMedia,
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
|
@ -79,7 +84,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
|
|||
assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
|
||||
|
||||
request = CompletionRequest(
|
||||
model=model,
|
||||
model=model_id,
|
||||
content=content,
|
||||
sampling_params=sampling_params,
|
||||
response_format=response_format,
|
||||
|
@ -186,7 +191,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
|
|||
|
||||
async def chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
model_id: str,
|
||||
messages: List[Message],
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
|
@ -201,7 +206,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
|
|||
|
||||
# wrapper request to make it easier to pass around (internal only, not exposed to API)
|
||||
request = ChatCompletionRequest(
|
||||
model=model,
|
||||
model=model_id,
|
||||
messages=messages,
|
||||
sampling_params=sampling_params,
|
||||
tools=tools or [],
|
||||
|
@ -386,7 +391,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
|
|||
|
||||
async def embeddings(
|
||||
self,
|
||||
model: str,
|
||||
model_id: str,
|
||||
contents: List[InterleavedTextMedia],
|
||||
) -> EmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
|
|
@ -110,7 +110,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
|
|||
|
||||
async def completion(
|
||||
self,
|
||||
model: str,
|
||||
model_id: str,
|
||||
content: InterleavedTextMedia,
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
|
@ -120,7 +120,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
|
|||
log.info("vLLM completion")
|
||||
messages = [UserMessage(content=content)]
|
||||
return self.chat_completion(
|
||||
model=model,
|
||||
model=model_id,
|
||||
messages=messages,
|
||||
sampling_params=sampling_params,
|
||||
stream=stream,
|
||||
|
@ -129,7 +129,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
|
|||
|
||||
async def chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
model_id: str,
|
||||
messages: List[Message],
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
|
@ -144,7 +144,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
|
|||
assert self.engine is not None
|
||||
|
||||
request = ChatCompletionRequest(
|
||||
model=model,
|
||||
model=model_id,
|
||||
messages=messages,
|
||||
sampling_params=sampling_params,
|
||||
tools=tools or [],
|
||||
|
@ -215,7 +215,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
|
|||
yield chunk
|
||||
|
||||
async def embeddings(
|
||||
self, model: str, contents: list[InterleavedTextMedia]
|
||||
self, model_id: str, contents: list[InterleavedTextMedia]
|
||||
) -> EmbeddingsResponse:
|
||||
log.info("vLLM embeddings")
|
||||
# TODO
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue