Merge branch 'main' into inference_refactor

This commit is contained in:
Botao Chen 2024-12-17 20:10:23 -08:00
commit fadb7deae5
79 changed files with 1547 additions and 2026 deletions

View file

@ -7,24 +7,50 @@
import asyncio
import logging
from typing import AsyncGenerator, List
from typing import AsyncGenerator, List, Optional, Union
from llama_models.llama3.api.datatypes import (
SamplingParams,
StopReason,
ToolDefinition,
ToolPromptFormat,
)
from llama_models.sku_list import resolve_model
from llama_stack.apis.models import Model
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseEvent,
ChatCompletionResponseEventType,
ChatCompletionResponseStreamChunk,
CompletionMessage,
CompletionRequest,
CompletionResponse,
CompletionResponseStreamChunk,
Inference,
InterleavedContent,
LogProbConfig,
Message,
ResponseFormat,
TokenLogProbs,
ToolCallDelta,
ToolCallParseStatus,
ToolChoice,
)
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.providers.utils.inference.model_registry import build_model_alias
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.models import Model, ModelType
from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.embedding_mixin import (
SentenceTransformerEmbeddingMixin,
)
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.model_registry import (
build_model_alias,
ModelRegistryHelper,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
convert_image_media_to_url,
request_has_media,
augment_content_with_response_format_prompt,
chat_completion_request_to_messages,
convert_request_to_raw,
)
from .config import MetaReferenceInferenceConfig
@ -44,7 +70,8 @@ class MetaReferenceInferenceImpl(
):
def __init__(self, config: MetaReferenceInferenceConfig) -> None:
self.config = config
self.model = None
self.model_id = None
self.llama_model = None
async def initialize(self, model_id, llama_model) -> None:
log.info(f"Loading model `{model_id}`")
@ -56,20 +83,21 @@ class MetaReferenceInferenceImpl(
else:
self.generator = Llama.build(self.config, model_id, llama_model)
self.model = model_id
self.model_id = model_id
self.llama_model = llama_model
async def shutdown(self) -> None:
if self.config.create_distributed_process_group:
self.generator.stop()
def check_model(self, request) -> None:
if self.model is None:
if self.model_id or self.llama_model is None:
raise RuntimeError(
"No avaible model yet, please register your requested model or add your model in the resouces first"
)
elif request.model != self.model:
elif request.model != self.model_id:
raise RuntimeError(
f"Model mismatch: request model: {request.model} != loaded model: {self.model}"
f"Model mismatch: request model: {request.model} != loaded model: {self.model_id}"
)
async def unregister_model(self, model_id: str) -> None:
@ -107,7 +135,7 @@ class MetaReferenceInferenceImpl(
async def completion(
self,
model_id: str,
content: InterleavedTextMedia,
content: InterleavedContent,
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
@ -116,6 +144,7 @@ class MetaReferenceInferenceImpl(
if logprobs:
assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
content = augment_content_with_response_format_prompt(response_format, content)
request = CompletionRequest(
model=model_id,
content=content,
@ -125,7 +154,7 @@ class MetaReferenceInferenceImpl(
logprobs=logprobs,
)
self.check_model(request)
request = await request_with_localized_media(request)
request = await convert_request_to_raw(request)
if request.stream:
return self._stream_completion(request)
@ -250,7 +279,13 @@ class MetaReferenceInferenceImpl(
logprobs=logprobs,
)
self.check_model(request)
request = await request_with_localized_media(request)
# augment and rewrite messages depending on the model
request.messages = chat_completion_request_to_messages(
request, self.llama_model.core_model_id.value
)
# download media and convert to raw content so we can send it to the model
request = await convert_request_to_raw(request)
if self.config.create_distributed_process_group:
if SEMAPHORE.locked():
@ -291,11 +326,15 @@ class MetaReferenceInferenceImpl(
if stop_reason is None:
stop_reason = StopReason.out_of_tokens
message = self.generator.formatter.decode_assistant_message(
raw_message = self.generator.formatter.decode_assistant_message(
tokens, stop_reason
)
return ChatCompletionResponse(
completion_message=message,
completion_message=CompletionMessage(
content=raw_message.content,
stop_reason=raw_message.stop_reason,
tool_calls=raw_message.tool_calls,
),
logprobs=logprobs if request.logprobs else None,
)
@ -421,31 +460,3 @@ class MetaReferenceInferenceImpl(
else:
for x in impl():
yield x
async def request_with_localized_media(
request: Union[ChatCompletionRequest, CompletionRequest],
) -> Union[ChatCompletionRequest, CompletionRequest]:
if not request_has_media(request):
return request
async def _convert_single_content(content):
if isinstance(content, ImageMedia):
url = await convert_image_media_to_url(content, download=True)
return ImageMedia(image=URL(uri=url))
else:
return content
async def _convert_content(content):
if isinstance(content, list):
return [await _convert_single_content(c) for c in content]
else:
return await _convert_single_content(content)
if isinstance(request, ChatCompletionRequest):
for m in request.messages:
m.content = await _convert_content(m.content)
else:
request.content = await _convert_content(request.content)
return request