Inference to use provider resource id to register and validate (#428)

This PR changes the way model id gets translated to the final model name
that gets passed through the provider.
Major changes include:
1) Providers are responsible for registering an object and as part of
the registration returning the object with the correct provider specific
name of the model provider_resource_id
2) To help with the common look ups different names a new ModelLookup
class is created.



Tested all inference providers including together, fireworks, vllm,
ollama, meta reference and bedrock
This commit is contained in:
Dinesh Yeduguru 2024-11-12 20:02:00 -08:00 committed by GitHub
parent e51107e019
commit fdff24e77a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
21 changed files with 460 additions and 290 deletions

View file

@ -8,13 +8,17 @@ from typing import AsyncGenerator
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import Message
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.sku_list import all_registered_models, resolve_model
from llama_models.sku_list import all_registered_models
from openai import OpenAI
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.model_registry import (
build_model_alias,
ModelRegistryHelper,
)
from llama_stack.providers.utils.inference.openai_compat import (
get_sampling_options,
process_chat_completion_response,
@ -30,44 +34,36 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .config import VLLMInferenceAdapterConfig
class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
def build_model_aliases():
return [
build_model_alias(
model.huggingface_repo,
model.descriptor(),
)
for model in all_registered_models()
if model.huggingface_repo
]
class VLLMInferenceAdapter(Inference, ModelRegistryHelper, ModelsProtocolPrivate):
def __init__(self, config: VLLMInferenceAdapterConfig) -> None:
ModelRegistryHelper.__init__(
self,
model_aliases=build_model_aliases(),
)
self.config = config
self.formatter = ChatFormat(Tokenizer.get_instance())
self.client = None
self.huggingface_repo_to_llama_model_id = {
model.huggingface_repo: model.descriptor()
for model in all_registered_models()
if model.huggingface_repo
}
async def initialize(self) -> None:
self.client = OpenAI(base_url=self.config.url, api_key=self.config.api_token)
async def register_model(self, model: Model) -> None:
for running_model in self.client.models.list():
repo = running_model.id
if repo not in self.huggingface_repo_to_llama_model_id:
print(f"Unknown model served by vllm: {repo}")
continue
identifier = self.huggingface_repo_to_llama_model_id[repo]
if identifier == model.provider_resource_id:
print(
f"Verified that model {model.provider_resource_id} is being served by vLLM"
)
return
raise ValueError(
f"Model {model.provider_resource_id} is not being served by vLLM"
)
async def shutdown(self) -> None:
pass
async def completion(
self,
model: str,
model_id: str,
content: InterleavedTextMedia,
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
@ -78,7 +74,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
async def chat_completion(
self,
model: str,
model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
@ -88,8 +84,9 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
model = await self.model_store.get_model(model_id)
request = ChatCompletionRequest(
model=model,
model=model.provider_resource_id,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
@ -141,10 +138,6 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
if "max_tokens" not in options:
options["max_tokens"] = self.config.max_tokens
model = resolve_model(request.model)
if model is None:
raise ValueError(f"Unknown model: {request.model}")
input_dict = {}
media_present = request_has_media(request)
if isinstance(request, ChatCompletionRequest):
@ -156,16 +149,20 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
]
else:
input_dict["prompt"] = chat_completion_request_to_prompt(
request, self.formatter
request, self.get_llama_model(request.model), self.formatter
)
else:
assert (
not media_present
), "Together does not support media for Completion requests"
input_dict["prompt"] = completion_request_to_prompt(request, self.formatter)
input_dict["prompt"] = completion_request_to_prompt(
request,
self.get_llama_model(request.model),
self.formatter,
)
return {
"model": model.huggingface_repo,
"model": request.model,
**input_dict,
"stream": request.stream,
**options,
@ -173,7 +170,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
async def embeddings(
self,
model: str,
model_id: str,
contents: List[InterleavedTextMedia],
) -> EmbeddingsResponse:
raise NotImplementedError()