chore(api): add mypy coverage to meta_reference_inference

Signed-off-by: Mustafa Elbehery <melbeher@redhat.com>
Mustafa Elbehery 2025-07-08 23:11:26 +02:00
parent d880c2df0e
commit 74103e4eee
3 changed files with 215 additions and 302 deletions

llama_stack/providers/inline/inference/meta_reference/inference.py

@@ -5,17 +5,13 @@
# the root directory of this source tree.
import asyncio
import os
import sys
from collections.abc import AsyncGenerator
from collections.abc import AsyncIterator
from typing import Any
from pydantic import BaseModel
from termcolor import cprint
from llama_stack.apis.common.content_types import (
TextDelta,
ToolCallDelta,
ToolCallParseStatus,
)
from llama_stack.apis.inference import (
BatchChatCompletionResponse,
@@ -43,13 +39,15 @@ from llama_stack.apis.inference import (
ToolPromptFormat,
UserMessage,
)
from llama_stack.apis.models import Model, ModelType
from llama_stack.apis.models import Model as ApiModel
from llama_stack.apis.models import ModelType
from llama_stack.log import get_logger
from llama_stack.models.llama.llama3.chat_format import ChatFormat as Llama3ChatFormat
from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer
from llama_stack.models.llama.llama4.chat_format import ChatFormat as Llama4ChatFormat
from llama_stack.models.llama.llama4.tokenizer import Tokenizer as Llama4Tokenizer
from llama_stack.models.llama.sku_list import resolve_model
from llama_stack.models.llama.sku_types import Model as LlamaModel
from llama_stack.models.llama.sku_types import ModelFamily
from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.embedding_mixin import (
@@ -64,8 +62,9 @@ from llama_stack.providers.utils.inference.openai_compat import (
OpenAICompletionToLlamaStackMixin,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
ChatCompletionRequestWithRawContent,
CompletionRequestWithRawContent,
augment_content_with_response_format_prompt,
chat_completion_request_to_messages,
convert_request_to_raw,
)
@@ -79,7 +78,7 @@ log = get_logger(__name__, category="inference")
SEMAPHORE = asyncio.Semaphore(1)
def llama_builder_fn(config: MetaReferenceInferenceConfig, model_id: str, llama_model: Model) -> LlamaGenerator:
def llama_builder_fn(config: MetaReferenceInferenceConfig, model_id: str, llama_model: LlamaModel) -> LlamaGenerator:
return LlamaGenerator(config, model_id, llama_model)
@@ -92,20 +91,23 @@ class MetaReferenceInferenceImpl(
):
def __init__(self, config: MetaReferenceInferenceConfig) -> None:
self.config = config
self.model_id = None
self.llama_model = None
self.model_id: str | None = None
self.llama_model: LlamaModel | None = None
self.generator: LlamaGenerator | LlamaModelParallelGenerator | None = None
self.model_registry_helper: ModelRegistryHelper | None = None
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
if self.config.create_distributed_process_group:
if self.config.create_distributed_process_group and self.generator:
if hasattr(self.generator, "stop") and callable(self.generator.stop):
self.generator.stop()
async def unregister_model(self, model_id: str) -> None:
pass
async def register_model(self, model: Model) -> Model:
async def register_model(self, model: ApiModel) -> ApiModel:
llama_model = (
resolve_model(model.metadata["llama_model"])
if "llama_model" in model.metadata
@@ -127,6 +129,7 @@ class MetaReferenceInferenceImpl(
model = await self.model_registry_helper.register_model(model)
if model.model_type == ModelType.embedding:
if model.provider_resource_id is not None:
self._load_sentence_transformer_model(model.provider_resource_id)
# TODO: what is this?! you can't really specify skipping via model metadata
@@ -137,10 +140,10 @@ class MetaReferenceInferenceImpl(
await self.load_model(model.identifier, llama_model)
return model
async def load_model(self, model_id, llama_model) -> None:
async def load_model(self, model_id: str, llama_model: LlamaModel) -> None:
log.info(f"Loading model `{model_id}`")
builder_params = [self.config, model_id, llama_model]
builder_params: list[Any] = [self.config, model_id, llama_model]
if self.config.create_distributed_process_group:
self.generator = LlamaModelParallelGenerator(
@@ -173,7 +176,7 @@ class MetaReferenceInferenceImpl(
)
log.info("Warmed up!")
def check_model(self, request) -> None:
def check_model(self, request: CompletionRequest | ChatCompletionRequest) -> None:
if self.model_id is None or self.llama_model is None:
raise RuntimeError(
"No avaible model yet, please register your requested model or add your model in the resouces first"
@@ -189,7 +192,7 @@ class MetaReferenceInferenceImpl(
response_format: ResponseFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
) -> CompletionResponse | CompletionResponseStreamChunk:
) -> CompletionResponse | AsyncIterator[CompletionResponseStreamChunk]:
if sampling_params is None:
sampling_params = SamplingParams()
if logprobs:
@@ -205,12 +208,17 @@ class MetaReferenceInferenceImpl(
logprobs=logprobs,
)
self.check_model(request)
request = await convert_request_to_raw(request)
request_with_raw_union = await convert_request_to_raw(request)
# Type cast to ensure we have the correct type
from typing import cast
request_with_raw = cast(CompletionRequestWithRawContent, request_with_raw_union)
if request.stream:
return self._stream_completion(request)
return self._stream_completion(request_with_raw)
else:
results = await self._nonstream_completion([request])
results = await self._nonstream_completion([request_with_raw])
return results[0]
async def batch_completion(
@@ -219,7 +227,6 @@ class MetaReferenceInferenceImpl(
content_batch: list[InterleavedContent],
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
) -> BatchCompletionResponse:
if sampling_params is None:
@@ -238,20 +245,28 @@ class MetaReferenceInferenceImpl(
content=content,
sampling_params=sampling_params,
response_format=response_format,
stream=stream,
stream=False,
logprobs=logprobs,
)
self.check_model(request)
request = await convert_request_to_raw(request)
request_batch.append(request)
request_with_raw_union = await convert_request_to_raw(request)
# Type cast to ensure we have the correct type
from typing import cast
request_with_raw = cast(CompletionRequestWithRawContent, request_with_raw_union)
request_batch.append(request_with_raw)
results = await self._nonstream_completion(request_batch)
return BatchCompletionResponse(batch=results)
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
async def _stream_completion(
self, request: CompletionRequestWithRawContent
) -> AsyncIterator[CompletionResponseStreamChunk]:
if not self.generator:
raise RuntimeError("Generator not initialized")
tokenizer = self.generator.formatter.tokenizer
def impl():
stop_reason = None
for token_results in self.generator.completion([request]):
@@ -278,24 +293,12 @@ class MetaReferenceInferenceImpl(
logprobs=logprobs if request.logprobs else None,
)
if stop_reason is None:
yield CompletionResponseStreamChunk(
delta="",
stop_reason=StopReason.out_of_tokens,
)
if self.config.create_distributed_process_group:
async def _nonstream_completion(
self, request_batch: list[CompletionRequestWithRawContent]
) -> list[CompletionResponse]:
async with SEMAPHORE:
for x in impl():
yield x
else:
for x in impl():
yield x
async def _nonstream_completion(self, request_batch: list[CompletionRequest]) -> list[CompletionResponse]:
tokenizer = self.generator.formatter.tokenizer
first_request = request_batch[0]
if not self.generator:
raise RuntimeError("Generator not initialized")
class ItemState(BaseModel):
tokens: list[int] = []
@@ -303,94 +306,94 @@ class MetaReferenceInferenceImpl(
stop_reason: StopReason | None = None
finished: bool = False
def impl():
states = [ItemState() for _ in request_batch]
def impl() -> list[CompletionResponse]:
if not self.generator:
raise RuntimeError("Generator not initialized")
item_states = [ItemState() for _ in request_batch]
results = []
for token_results in self.generator.completion(request_batch):
for result in token_results:
idx = result.batch_idx
state = states[idx]
if state.finished or result.ignore_token:
for idx, token_result in enumerate(token_results):
item_state = item_states[idx]
if item_state.finished:
continue
state.finished = result.finished
if first_request.logprobs:
state.logprobs.append(TokenLogProbs(logprobs_by_token={result.text: result.logprobs[0]}))
if token_result.token == self.generator.formatter.tokenizer.eot_id:
item_state.stop_reason = StopReason.end_of_turn
item_state.finished = True
elif token_result.token == self.generator.formatter.tokenizer.eom_id:
item_state.stop_reason = StopReason.end_of_message
item_state.finished = True
else:
item_state.tokens.append(token_result.token)
if request_batch[idx].logprobs:
assert len(token_result.logprobs) == 1
item_state.logprobs.append(
TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]})
)
state.tokens.append(result.token)
if result.token == tokenizer.eot_id:
state.stop_reason = StopReason.end_of_turn
elif result.token == tokenizer.eom_id:
state.stop_reason = StopReason.end_of_message
# generate final responses
completions = []
for idx, item_state in enumerate(item_states):
if not self.generator:
raise RuntimeError("Generator not initialized")
content = self.generator.formatter.tokenizer.decode(item_state.tokens)
for state in states:
if state.stop_reason is None:
state.stop_reason = StopReason.out_of_tokens
if state.tokens[-1] in self.generator.formatter.tokenizer.stop_tokens:
state.tokens = state.tokens[:-1]
content = self.generator.formatter.tokenizer.decode(state.tokens)
results.append(
completions.append(
CompletionResponse(
content=content,
stop_reason=state.stop_reason,
logprobs=state.logprobs if first_request.logprobs else None,
stop_reason=item_state.stop_reason or StopReason.out_of_tokens,
logprobs=item_state.logprobs if request_batch[idx].logprobs else None,
)
)
return results
return completions
if self.config.create_distributed_process_group:
async with SEMAPHORE:
return impl()
else:
return impl()
return await asyncio.get_event_loop().run_in_executor(None, impl)
async def chat_completion(
self,
model_id: str,
messages: list[Message],
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
tools: list[ToolDefinition] | None = None,
tool_choice: ToolChoice | None = ToolChoice.auto,
tool_prompt_format: ToolPromptFormat | None = None,
response_format: ResponseFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
tool_config: ToolConfig | None = None,
) -> AsyncGenerator:
) -> ChatCompletionResponse | AsyncIterator[ChatCompletionResponseStreamChunk]:
if sampling_params is None:
sampling_params = SamplingParams()
if logprobs:
assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
# wrapper request to make it easier to pass around (internal only, not exposed to API)
if self.llama_model is None:
raise RuntimeError("Model not initialized")
request = ChatCompletionRequest(
model=model_id,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
tool_config=tool_config or ToolConfig(),
response_format=response_format,
stream=stream,
logprobs=logprobs,
tool_config=tool_config or ToolConfig(),
)
self.check_model(request)
request_with_raw_union = await convert_request_to_raw(request)
# augment and rewrite messages depending on the model
request.messages = chat_completion_request_to_messages(request, self.llama_model.core_model_id.value)
# download media and convert to raw content so we can send it to the model
request = await convert_request_to_raw(request)
# Type cast to ensure we have the correct type
from typing import cast
if self.config.create_distributed_process_group:
if SEMAPHORE.locked():
raise RuntimeError("Only one concurrent request is supported")
request_with_raw = cast(ChatCompletionRequestWithRawContent, request_with_raw_union)
if request.stream:
return self._stream_chat_completion(request)
return self._stream_chat_completion(request_with_raw)
else:
results = await self._nonstream_chat_completion([request])
results = await self._nonstream_chat_completion([request_with_raw])
return results[0]
async def batch_chat_completion(
@@ -398,18 +401,19 @@ class MetaReferenceInferenceImpl(
model_id: str,
messages_batch: list[list[Message]],
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
tools: list[ToolDefinition] | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
tool_config: ToolConfig | None = None,
response_format: ResponseFormat | None = None,
logprobs: LogProbConfig | None = None,
) -> BatchChatCompletionResponse:
if sampling_params is None:
sampling_params = SamplingParams()
if logprobs:
assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
# wrapper request to make it easier to pass around (internal only, not exposed to API)
if self.llama_model is None:
raise RuntimeError("Model not initialized")
request_batch = []
for messages in messages_batch:
request = ChatCompletionRequest(
@@ -417,31 +421,29 @@ class MetaReferenceInferenceImpl(
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
response_format=response_format,
logprobs=logprobs,
tool_config=tool_config or ToolConfig(),
response_format=response_format,
stream=False,
logprobs=logprobs,
)
self.check_model(request)
request_with_raw_union = await convert_request_to_raw(request)
# augment and rewrite messages depending on the model
request.messages = chat_completion_request_to_messages(request, self.llama_model.core_model_id.value)
# download media and convert to raw content so we can send it to the model
request = await convert_request_to_raw(request)
request_batch.append(request)
# Type cast to ensure we have the correct type
from typing import cast
if self.config.create_distributed_process_group:
if SEMAPHORE.locked():
raise RuntimeError("Only one concurrent request is supported")
request_with_raw = cast(ChatCompletionRequestWithRawContent, request_with_raw_union)
request_batch.append(request_with_raw)
results = await self._nonstream_chat_completion(request_batch)
return BatchChatCompletionResponse(batch=results)
async def _nonstream_chat_completion(
self, request_batch: list[ChatCompletionRequest]
self, request_batch: list[ChatCompletionRequestWithRawContent]
) -> list[ChatCompletionResponse]:
tokenizer = self.generator.formatter.tokenizer
first_request = request_batch[0]
async with SEMAPHORE:
if not self.generator:
raise RuntimeError("Generator not initialized")
class ItemState(BaseModel):
tokens: list[int] = []
@@ -449,81 +451,65 @@ class MetaReferenceInferenceImpl(
stop_reason: StopReason | None = None
finished: bool = False
def impl():
states = [ItemState() for _ in request_batch]
def impl() -> list[ChatCompletionResponse]:
if not self.generator:
raise RuntimeError("Generator not initialized")
item_states = [ItemState() for _ in request_batch]
for token_results in self.generator.chat_completion(request_batch):
first = token_results[0]
if not first.finished and not first.ignore_token:
if os.environ.get("LLAMA_MODELS_DEBUG", "0") in ("1", "2"):
cprint(first.text, color="cyan", end="", file=sys.stderr)
if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "2":
cprint(f"<{first.token}>", color="magenta", end="", file=sys.stderr)
for result in token_results:
idx = result.batch_idx
state = states[idx]
if state.finished or result.ignore_token:
for idx, token_result in enumerate(token_results):
item_state = item_states[idx]
if item_state.finished:
continue
state.finished = result.finished
if first_request.logprobs:
state.logprobs.append(TokenLogProbs(logprobs_by_token={result.text: result.logprobs[0]}))
if token_result.token == self.generator.formatter.tokenizer.eot_id:
item_state.stop_reason = StopReason.end_of_turn
item_state.finished = True
elif token_result.token == self.generator.formatter.tokenizer.eom_id:
item_state.stop_reason = StopReason.end_of_message
item_state.finished = True
else:
item_state.tokens.append(token_result.token)
if request_batch[idx].logprobs:
assert len(token_result.logprobs) == 1
item_state.logprobs.append(
TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]})
)
state.tokens.append(result.token)
if result.token == tokenizer.eot_id:
state.stop_reason = StopReason.end_of_turn
elif result.token == tokenizer.eom_id:
state.stop_reason = StopReason.end_of_message
# generate final responses
completions = []
for idx, item_state in enumerate(item_states):
if not self.generator:
raise RuntimeError("Generator not initialized")
content = self.generator.formatter.tokenizer.decode(item_state.tokens)
results = []
for state in states:
if state.stop_reason is None:
state.stop_reason = StopReason.out_of_tokens
raw_message = self.generator.formatter.decode_assistant_message(state.tokens, state.stop_reason)
results.append(
completions.append(
ChatCompletionResponse(
completion_message=CompletionMessage(
content=raw_message.content,
stop_reason=raw_message.stop_reason,
tool_calls=raw_message.tool_calls,
content=content,
stop_reason=item_state.stop_reason or StopReason.out_of_tokens,
tool_calls=[],
),
logprobs=state.logprobs if first_request.logprobs else None,
logprobs=item_state.logprobs if request_batch[idx].logprobs else None,
)
)
return results
return completions
if self.config.create_distributed_process_group:
async with SEMAPHORE:
return impl()
else:
return impl()
return await asyncio.get_event_loop().run_in_executor(None, impl)
async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
async def _stream_chat_completion(
self, request: ChatCompletionRequestWithRawContent
) -> AsyncIterator[ChatCompletionResponseStreamChunk]:
if not self.generator:
raise RuntimeError("Generator not initialized")
tokenizer = self.generator.formatter.tokenizer
def impl():
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.start,
delta=TextDelta(text=""),
)
)
tokens = []
logprobs = []
stop_reason = None
ipython = False
for token_results in self.generator.chat_completion([request]):
token_result = token_results[0]
if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1":
cprint(token_result.text, color="cyan", end="", file=sys.stderr)
if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "2":
cprint(f"<{token_result.token}>", color="magenta", end="", file=sys.stderr)
if token_result.token == tokenizer.eot_id:
stop_reason = StopReason.end_of_turn
text = ""
@@ -533,99 +519,18 @@ class MetaReferenceInferenceImpl(
else:
text = token_result.text
if request.logprobs:
assert len(token_result.logprobs) == 1
logprobs.append(TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]}))
tokens.append(token_result.token)
if not ipython and token_result.text.startswith("<|python_tag|>"):
ipython = True
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
tool_call="",
parse_status=ToolCallParseStatus.started,
),
)
)
continue
if token_result.token == tokenizer.eot_id:
stop_reason = StopReason.end_of_turn
text = ""
elif token_result.token == tokenizer.eom_id:
stop_reason = StopReason.end_of_message
text = ""
else:
text = token_result.text
if ipython:
delta = ToolCallDelta(
tool_call=text,
parse_status=ToolCallParseStatus.in_progress,
)
else:
delta = TextDelta(text=text)
logprobs = None
if stop_reason is None:
if request.logprobs:
assert len(token_result.logprobs) == 1
logprobs.append(TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]}))
logprobs = [TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]})]
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=delta,
stop_reason=stop_reason,
delta=TextDelta(text=text),
logprobs=logprobs if request.logprobs else None,
)
)
if stop_reason is None:
stop_reason = StopReason.out_of_tokens
message = self.generator.formatter.decode_assistant_message(tokens, stop_reason)
parsed_tool_calls = len(message.tool_calls) > 0
if ipython and not parsed_tool_calls:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
tool_call="",
parse_status=ToolCallParseStatus.failed,
),
stop_reason=stop_reason,
)
)
for tool_call in message.tool_calls:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
tool_call=tool_call,
parse_status=ToolCallParseStatus.succeeded,
),
stop_reason=stop_reason,
)
)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.complete,
delta=TextDelta(text=""),
stop_reason=stop_reason,
)
)
if self.config.create_distributed_process_group:
async with SEMAPHORE:
for x in impl():
yield x
else:
for x in impl():
yield x
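The cast calls above are the main device this change uses to satisfy mypy: convert_request_to_raw() is annotated with a union of the raw-content request types, so each call site narrows the result explicitly. A condensed sketch of that pattern, reusing the imports shown in this diff (the helper's exact return annotation is an assumption):

from typing import cast

from llama_stack.apis.inference import CompletionRequest
from llama_stack.providers.utils.inference.prompt_adapter import (
    CompletionRequestWithRawContent,
    convert_request_to_raw,
)


async def to_raw_completion_request(
    request: CompletionRequest,
) -> CompletionRequestWithRawContent:
    # convert_request_to_raw() returns a union of chat/completion raw-content
    # requests; cast() tells mypy which member applies at this call site.
    raw = await convert_request_to_raw(request)
    return cast(CompletionRequestWithRawContent, raw)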

llama_stack/providers/utils/inference/embedding_mixin.py

@@ -31,7 +31,7 @@ log = logging.getLogger(__name__)
class SentenceTransformerEmbeddingMixin:
model_store: ModelStore
model_store: ModelStore | None = None
async def embeddings(
self,
@@ -41,7 +41,11 @@ class SentenceTransformerEmbeddingMixin:
output_dimension: int | None = None,
task_type: EmbeddingTaskType | None = None,
) -> EmbeddingsResponse:
if self.model_store is None:
raise RuntimeError("Model store is not initialized")
model = await self.model_store.get_model(model_id)
if model.provider_resource_id is None:
raise RuntimeError("Model provider resource ID is not set")
embedding_model = self._load_sentence_transformer_model(model.provider_resource_id)
embeddings = embedding_model.encode(
[interleaved_content_as_str(content) for content in contents], show_progress_bar=False
@@ -62,7 +66,11 @@ class SentenceTransformerEmbeddingMixin:
raise ValueError("Empty list not supported")
# Get the model and generate embeddings
if self.model_store is None:
raise RuntimeError("Model store is not initialized")
model_obj = await self.model_store.get_model(model)
if model_obj.provider_resource_id is None:
raise RuntimeError("Model provider resource ID is not set")
embedding_model = self._load_sentence_transformer_model(model_obj.provider_resource_id)
embeddings = embedding_model.encode(input_list, show_progress_bar=False)

pyproject.toml

@@ -246,7 +246,7 @@ exclude = [
"^llama_stack/providers/inline/agents/meta_reference/agents\\.py$",
"^llama_stack/providers/inline/datasetio/localfs/",
"^llama_stack/providers/inline/eval/meta_reference/eval\\.py$",
"^llama_stack/providers/inline/inference/meta_reference/inference\\.py$",
"^llama_stack/providers/inline/inference/meta_reference/config\\.py$",
"^llama_stack/models/llama/llama3/generation\\.py$",
"^llama_stack/models/llama/llama3/multimodal/model\\.py$",
"^llama_stack/models/llama/llama4/",