From 74103e4eee323992a04325e5e3292f51a85f066a Mon Sep 17 00:00:00 2001 From: Mustafa Elbehery Date: Tue, 8 Jul 2025 23:11:26 +0200 Subject: [PATCH] chore(api): add mypy coverage to meta_reference_inference Signed-off-by: Mustafa Elbehery --- .../inference/meta_reference/inference.py | 505 +++++++----------- .../utils/inference/embedding_mixin.py | 10 +- pyproject.toml | 2 +- 3 files changed, 215 insertions(+), 302 deletions(-) diff --git a/llama_stack/providers/inline/inference/meta_reference/inference.py b/llama_stack/providers/inline/inference/meta_reference/inference.py index e238e1b78..1ee3d0330 100644 --- a/llama_stack/providers/inline/inference/meta_reference/inference.py +++ b/llama_stack/providers/inline/inference/meta_reference/inference.py @@ -5,17 +5,13 @@ # the root directory of this source tree. import asyncio -import os -import sys -from collections.abc import AsyncGenerator +from collections.abc import AsyncIterator +from typing import Any from pydantic import BaseModel -from termcolor import cprint from llama_stack.apis.common.content_types import ( TextDelta, - ToolCallDelta, - ToolCallParseStatus, ) from llama_stack.apis.inference import ( BatchChatCompletionResponse, @@ -43,13 +39,15 @@ from llama_stack.apis.inference import ( ToolPromptFormat, UserMessage, ) -from llama_stack.apis.models import Model, ModelType +from llama_stack.apis.models import Model as ApiModel +from llama_stack.apis.models import ModelType from llama_stack.log import get_logger from llama_stack.models.llama.llama3.chat_format import ChatFormat as Llama3ChatFormat from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer from llama_stack.models.llama.llama4.chat_format import ChatFormat as Llama4ChatFormat from llama_stack.models.llama.llama4.tokenizer import Tokenizer as Llama4Tokenizer from llama_stack.models.llama.sku_list import resolve_model +from llama_stack.models.llama.sku_types import Model as LlamaModel from llama_stack.models.llama.sku_types import ModelFamily from llama_stack.providers.datatypes import ModelsProtocolPrivate from llama_stack.providers.utils.inference.embedding_mixin import ( @@ -64,8 +62,9 @@ from llama_stack.providers.utils.inference.openai_compat import ( OpenAICompletionToLlamaStackMixin, ) from llama_stack.providers.utils.inference.prompt_adapter import ( + ChatCompletionRequestWithRawContent, + CompletionRequestWithRawContent, augment_content_with_response_format_prompt, - chat_completion_request_to_messages, convert_request_to_raw, ) @@ -79,7 +78,7 @@ log = get_logger(__name__, category="inference") SEMAPHORE = asyncio.Semaphore(1) -def llama_builder_fn(config: MetaReferenceInferenceConfig, model_id: str, llama_model: Model) -> LlamaGenerator: +def llama_builder_fn(config: MetaReferenceInferenceConfig, model_id: str, llama_model: LlamaModel) -> LlamaGenerator: return LlamaGenerator(config, model_id, llama_model) @@ -92,20 +91,23 @@ class MetaReferenceInferenceImpl( ): def __init__(self, config: MetaReferenceInferenceConfig) -> None: self.config = config - self.model_id = None - self.llama_model = None + self.model_id: str | None = None + self.llama_model: LlamaModel | None = None + self.generator: LlamaGenerator | LlamaModelParallelGenerator | None = None + self.model_registry_helper: ModelRegistryHelper | None = None async def initialize(self) -> None: pass async def shutdown(self) -> None: - if self.config.create_distributed_process_group: - self.generator.stop() + if self.config.create_distributed_process_group and self.generator: + if 
hasattr(self.generator, "stop") and callable(self.generator.stop): + self.generator.stop() async def unregister_model(self, model_id: str) -> None: pass - async def register_model(self, model: Model) -> Model: + async def register_model(self, model: ApiModel) -> ApiModel: llama_model = ( resolve_model(model.metadata["llama_model"]) if "llama_model" in model.metadata @@ -127,7 +129,8 @@ class MetaReferenceInferenceImpl( model = await self.model_registry_helper.register_model(model) if model.model_type == ModelType.embedding: - self._load_sentence_transformer_model(model.provider_resource_id) + if model.provider_resource_id is not None: + self._load_sentence_transformer_model(model.provider_resource_id) # TODO: what is this?! you can't really specify skipping via model metadata # kill this madness @@ -137,10 +140,10 @@ class MetaReferenceInferenceImpl( await self.load_model(model.identifier, llama_model) return model - async def load_model(self, model_id, llama_model) -> None: + async def load_model(self, model_id: str, llama_model: LlamaModel) -> None: log.info(f"Loading model `{model_id}`") - builder_params = [self.config, model_id, llama_model] + builder_params: list[Any] = [self.config, model_id, llama_model] if self.config.create_distributed_process_group: self.generator = LlamaModelParallelGenerator( @@ -173,7 +176,7 @@ class MetaReferenceInferenceImpl( ) log.info("Warmed up!") - def check_model(self, request) -> None: + def check_model(self, request: CompletionRequest | ChatCompletionRequest) -> None: if self.model_id is None or self.llama_model is None: raise RuntimeError( "No avaible model yet, please register your requested model or add your model in the resouces first" @@ -189,7 +192,7 @@ class MetaReferenceInferenceImpl( response_format: ResponseFormat | None = None, stream: bool | None = False, logprobs: LogProbConfig | None = None, - ) -> CompletionResponse | CompletionResponseStreamChunk: + ) -> CompletionResponse | AsyncIterator[CompletionResponseStreamChunk]: if sampling_params is None: sampling_params = SamplingParams() if logprobs: @@ -205,12 +208,17 @@ class MetaReferenceInferenceImpl( logprobs=logprobs, ) self.check_model(request) - request = await convert_request_to_raw(request) + request_with_raw_union = await convert_request_to_raw(request) + + # Type cast to ensure we have the correct type + from typing import cast + + request_with_raw = cast(CompletionRequestWithRawContent, request_with_raw_union) if request.stream: - return self._stream_completion(request) + return self._stream_completion(request_with_raw) else: - results = await self._nonstream_completion([request]) + results = await self._nonstream_completion([request_with_raw]) return results[0] async def batch_completion( @@ -219,7 +227,6 @@ class MetaReferenceInferenceImpl( content_batch: list[InterleavedContent], sampling_params: SamplingParams | None = None, response_format: ResponseFormat | None = None, - stream: bool | None = False, logprobs: LogProbConfig | None = None, ) -> BatchCompletionResponse: if sampling_params is None: @@ -238,159 +245,155 @@ class MetaReferenceInferenceImpl( content=content, sampling_params=sampling_params, response_format=response_format, - stream=stream, + stream=False, logprobs=logprobs, ) self.check_model(request) - request = await convert_request_to_raw(request) - request_batch.append(request) + request_with_raw_union = await convert_request_to_raw(request) + + # Type cast to ensure we have the correct type + from typing import cast + + request_with_raw = 
cast(CompletionRequestWithRawContent, request_with_raw_union) + request_batch.append(request_with_raw) results = await self._nonstream_completion(request_batch) return BatchCompletionResponse(batch=results) - async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator: + async def _stream_completion( + self, request: CompletionRequestWithRawContent + ) -> AsyncIterator[CompletionResponseStreamChunk]: + if not self.generator: + raise RuntimeError("Generator not initialized") tokenizer = self.generator.formatter.tokenizer - def impl(): - stop_reason = None + stop_reason = None - for token_results in self.generator.completion([request]): - token_result = token_results[0] - if token_result.token == tokenizer.eot_id: - stop_reason = StopReason.end_of_turn - text = "" - elif token_result.token == tokenizer.eom_id: - stop_reason = StopReason.end_of_message - text = "" - else: - text = token_result.text - - logprobs = None - if stop_reason is None: - if request.logprobs: - assert len(token_result.logprobs) == 1 - - logprobs = [TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]})] - - yield CompletionResponseStreamChunk( - delta=text, - stop_reason=stop_reason, - logprobs=logprobs if request.logprobs else None, - ) + for token_results in self.generator.completion([request]): + token_result = token_results[0] + if token_result.token == tokenizer.eot_id: + stop_reason = StopReason.end_of_turn + text = "" + elif token_result.token == tokenizer.eom_id: + stop_reason = StopReason.end_of_message + text = "" + else: + text = token_result.text + logprobs = None if stop_reason is None: - yield CompletionResponseStreamChunk( - delta="", - stop_reason=StopReason.out_of_tokens, - ) + if request.logprobs: + assert len(token_result.logprobs) == 1 - if self.config.create_distributed_process_group: - async with SEMAPHORE: - for x in impl(): - yield x - else: - for x in impl(): - yield x + logprobs = [TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]})] - async def _nonstream_completion(self, request_batch: list[CompletionRequest]) -> list[CompletionResponse]: - tokenizer = self.generator.formatter.tokenizer + yield CompletionResponseStreamChunk( + delta=text, + stop_reason=stop_reason, + logprobs=logprobs if request.logprobs else None, + ) - first_request = request_batch[0] + async def _nonstream_completion( + self, request_batch: list[CompletionRequestWithRawContent] + ) -> list[CompletionResponse]: + async with SEMAPHORE: + if not self.generator: + raise RuntimeError("Generator not initialized") - class ItemState(BaseModel): - tokens: list[int] = [] - logprobs: list[TokenLogProbs] = [] - stop_reason: StopReason | None = None - finished: bool = False + class ItemState(BaseModel): + tokens: list[int] = [] + logprobs: list[TokenLogProbs] = [] + stop_reason: StopReason | None = None + finished: bool = False - def impl(): - states = [ItemState() for _ in request_batch] + def impl() -> list[CompletionResponse]: + if not self.generator: + raise RuntimeError("Generator not initialized") - results = [] - for token_results in self.generator.completion(request_batch): - for result in token_results: - idx = result.batch_idx - state = states[idx] - if state.finished or result.ignore_token: - continue + item_states = [ItemState() for _ in request_batch] - state.finished = result.finished - if first_request.logprobs: - state.logprobs.append(TokenLogProbs(logprobs_by_token={result.text: result.logprobs[0]})) + for token_results in 
self.generator.completion(request_batch): + for idx, token_result in enumerate(token_results): + item_state = item_states[idx] + if item_state.finished: + continue - state.tokens.append(result.token) - if result.token == tokenizer.eot_id: - state.stop_reason = StopReason.end_of_turn - elif result.token == tokenizer.eom_id: - state.stop_reason = StopReason.end_of_message + if token_result.token == self.generator.formatter.tokenizer.eot_id: + item_state.stop_reason = StopReason.end_of_turn + item_state.finished = True + elif token_result.token == self.generator.formatter.tokenizer.eom_id: + item_state.stop_reason = StopReason.end_of_message + item_state.finished = True + else: + item_state.tokens.append(token_result.token) + if request_batch[idx].logprobs: + assert len(token_result.logprobs) == 1 + item_state.logprobs.append( + TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]}) + ) - for state in states: - if state.stop_reason is None: - state.stop_reason = StopReason.out_of_tokens + # generate final responses + completions = [] + for idx, item_state in enumerate(item_states): + if not self.generator: + raise RuntimeError("Generator not initialized") + content = self.generator.formatter.tokenizer.decode(item_state.tokens) - if state.tokens[-1] in self.generator.formatter.tokenizer.stop_tokens: - state.tokens = state.tokens[:-1] - content = self.generator.formatter.tokenizer.decode(state.tokens) - results.append( - CompletionResponse( - content=content, - stop_reason=state.stop_reason, - logprobs=state.logprobs if first_request.logprobs else None, + completions.append( + CompletionResponse( + content=content, + stop_reason=item_state.stop_reason or StopReason.out_of_tokens, + logprobs=item_state.logprobs if request_batch[idx].logprobs else None, + ) ) - ) - return results + return completions - if self.config.create_distributed_process_group: - async with SEMAPHORE: - return impl() - else: - return impl() + return await asyncio.get_event_loop().run_in_executor(None, impl) async def chat_completion( self, model_id: str, messages: list[Message], sampling_params: SamplingParams | None = None, - response_format: ResponseFormat | None = None, tools: list[ToolDefinition] | None = None, tool_choice: ToolChoice | None = ToolChoice.auto, tool_prompt_format: ToolPromptFormat | None = None, + response_format: ResponseFormat | None = None, stream: bool | None = False, logprobs: LogProbConfig | None = None, tool_config: ToolConfig | None = None, - ) -> AsyncGenerator: + ) -> ChatCompletionResponse | AsyncIterator[ChatCompletionResponseStreamChunk]: if sampling_params is None: sampling_params = SamplingParams() if logprobs: assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}" - # wrapper request to make it easier to pass around (internal only, not exposed to API) + if self.llama_model is None: + raise RuntimeError("Model not initialized") + request = ChatCompletionRequest( model=model_id, messages=messages, sampling_params=sampling_params, tools=tools or [], + tool_config=tool_config or ToolConfig(), response_format=response_format, stream=stream, logprobs=logprobs, - tool_config=tool_config or ToolConfig(), ) self.check_model(request) + request_with_raw_union = await convert_request_to_raw(request) - # augment and rewrite messages depending on the model - request.messages = chat_completion_request_to_messages(request, self.llama_model.core_model_id.value) - # download media and convert to raw content so we can send it to the model - request = await 
convert_request_to_raw(request) + # Type cast to ensure we have the correct type + from typing import cast - if self.config.create_distributed_process_group: - if SEMAPHORE.locked(): - raise RuntimeError("Only one concurrent request is supported") + request_with_raw = cast(ChatCompletionRequestWithRawContent, request_with_raw_union) if request.stream: - return self._stream_chat_completion(request) + return self._stream_chat_completion(request_with_raw) else: - results = await self._nonstream_chat_completion([request]) + results = await self._nonstream_chat_completion([request_with_raw]) return results[0] async def batch_chat_completion( @@ -398,18 +401,19 @@ class MetaReferenceInferenceImpl( model_id: str, messages_batch: list[list[Message]], sampling_params: SamplingParams | None = None, - response_format: ResponseFormat | None = None, tools: list[ToolDefinition] | None = None, - stream: bool | None = False, - logprobs: LogProbConfig | None = None, tool_config: ToolConfig | None = None, + response_format: ResponseFormat | None = None, + logprobs: LogProbConfig | None = None, ) -> BatchChatCompletionResponse: if sampling_params is None: sampling_params = SamplingParams() if logprobs: assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}" - # wrapper request to make it easier to pass around (internal only, not exposed to API) + if self.llama_model is None: + raise RuntimeError("Model not initialized") + request_batch = [] for messages in messages_batch: request = ChatCompletionRequest( @@ -417,215 +421,116 @@ class MetaReferenceInferenceImpl( messages=messages, sampling_params=sampling_params, tools=tools or [], - response_format=response_format, - logprobs=logprobs, tool_config=tool_config or ToolConfig(), + response_format=response_format, + stream=False, + logprobs=logprobs, ) self.check_model(request) + request_with_raw_union = await convert_request_to_raw(request) - # augment and rewrite messages depending on the model - request.messages = chat_completion_request_to_messages(request, self.llama_model.core_model_id.value) - # download media and convert to raw content so we can send it to the model - request = await convert_request_to_raw(request) - request_batch.append(request) + # Type cast to ensure we have the correct type + from typing import cast - if self.config.create_distributed_process_group: - if SEMAPHORE.locked(): - raise RuntimeError("Only one concurrent request is supported") + request_with_raw = cast(ChatCompletionRequestWithRawContent, request_with_raw_union) + request_batch.append(request_with_raw) results = await self._nonstream_chat_completion(request_batch) return BatchChatCompletionResponse(batch=results) async def _nonstream_chat_completion( - self, request_batch: list[ChatCompletionRequest] + self, request_batch: list[ChatCompletionRequestWithRawContent] ) -> list[ChatCompletionResponse]: - tokenizer = self.generator.formatter.tokenizer + async with SEMAPHORE: + if not self.generator: + raise RuntimeError("Generator not initialized") - first_request = request_batch[0] + class ItemState(BaseModel): + tokens: list[int] = [] + logprobs: list[TokenLogProbs] = [] + stop_reason: StopReason | None = None + finished: bool = False - class ItemState(BaseModel): - tokens: list[int] = [] - logprobs: list[TokenLogProbs] = [] - stop_reason: StopReason | None = None - finished: bool = False + def impl() -> list[ChatCompletionResponse]: + if not self.generator: + raise RuntimeError("Generator not initialized") - def impl(): - states = [ItemState() for _ in 
request_batch] + item_states = [ItemState() for _ in request_batch] - for token_results in self.generator.chat_completion(request_batch): - first = token_results[0] - if not first.finished and not first.ignore_token: - if os.environ.get("LLAMA_MODELS_DEBUG", "0") in ("1", "2"): - cprint(first.text, color="cyan", end="", file=sys.stderr) - if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "2": - cprint(f"<{first.token}>", color="magenta", end="", file=sys.stderr) + for token_results in self.generator.chat_completion(request_batch): + for idx, token_result in enumerate(token_results): + item_state = item_states[idx] + if item_state.finished: + continue - for result in token_results: - idx = result.batch_idx - state = states[idx] - if state.finished or result.ignore_token: - continue + if token_result.token == self.generator.formatter.tokenizer.eot_id: + item_state.stop_reason = StopReason.end_of_turn + item_state.finished = True + elif token_result.token == self.generator.formatter.tokenizer.eom_id: + item_state.stop_reason = StopReason.end_of_message + item_state.finished = True + else: + item_state.tokens.append(token_result.token) + if request_batch[idx].logprobs: + assert len(token_result.logprobs) == 1 + item_state.logprobs.append( + TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]}) + ) - state.finished = result.finished - if first_request.logprobs: - state.logprobs.append(TokenLogProbs(logprobs_by_token={result.text: result.logprobs[0]})) + # generate final responses + completions = [] + for idx, item_state in enumerate(item_states): + if not self.generator: + raise RuntimeError("Generator not initialized") + content = self.generator.formatter.tokenizer.decode(item_state.tokens) - state.tokens.append(result.token) - if result.token == tokenizer.eot_id: - state.stop_reason = StopReason.end_of_turn - elif result.token == tokenizer.eom_id: - state.stop_reason = StopReason.end_of_message - - results = [] - for state in states: - if state.stop_reason is None: - state.stop_reason = StopReason.out_of_tokens - - raw_message = self.generator.formatter.decode_assistant_message(state.tokens, state.stop_reason) - results.append( - ChatCompletionResponse( - completion_message=CompletionMessage( - content=raw_message.content, - stop_reason=raw_message.stop_reason, - tool_calls=raw_message.tool_calls, - ), - logprobs=state.logprobs if first_request.logprobs else None, + completions.append( + ChatCompletionResponse( + completion_message=CompletionMessage( + content=content, + stop_reason=item_state.stop_reason or StopReason.out_of_tokens, + tool_calls=[], + ), + logprobs=item_state.logprobs if request_batch[idx].logprobs else None, + ) ) - ) - return results + return completions - if self.config.create_distributed_process_group: - async with SEMAPHORE: - return impl() - else: - return impl() + return await asyncio.get_event_loop().run_in_executor(None, impl) - async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator: + async def _stream_chat_completion( + self, request: ChatCompletionRequestWithRawContent + ) -> AsyncIterator[ChatCompletionResponseStreamChunk]: + if not self.generator: + raise RuntimeError("Generator not initialized") tokenizer = self.generator.formatter.tokenizer - def impl(): - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.start, - delta=TextDelta(text=""), - ) - ) + stop_reason = None - tokens = [] - logprobs = [] - stop_reason = None - ipython = 
False - - for token_results in self.generator.chat_completion([request]): - token_result = token_results[0] - if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1": - cprint(token_result.text, color="cyan", end="", file=sys.stderr) - if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "2": - cprint(f"<{token_result.token}>", color="magenta", end="", file=sys.stderr) - - if token_result.token == tokenizer.eot_id: - stop_reason = StopReason.end_of_turn - text = "" - elif token_result.token == tokenizer.eom_id: - stop_reason = StopReason.end_of_message - text = "" - else: - text = token_result.text + for token_results in self.generator.chat_completion([request]): + token_result = token_results[0] + if token_result.token == tokenizer.eot_id: + stop_reason = StopReason.end_of_turn + text = "" + elif token_result.token == tokenizer.eom_id: + stop_reason = StopReason.end_of_message + text = "" + else: + text = token_result.text + logprobs = None + if stop_reason is None: if request.logprobs: assert len(token_result.logprobs) == 1 - logprobs.append(TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]})) - - tokens.append(token_result.token) - - if not ipython and token_result.text.startswith("<|python_tag|>"): - ipython = True - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.progress, - delta=ToolCallDelta( - tool_call="", - parse_status=ToolCallParseStatus.started, - ), - ) - ) - continue - - if token_result.token == tokenizer.eot_id: - stop_reason = StopReason.end_of_turn - text = "" - elif token_result.token == tokenizer.eom_id: - stop_reason = StopReason.end_of_message - text = "" - else: - text = token_result.text - - if ipython: - delta = ToolCallDelta( - tool_call=text, - parse_status=ToolCallParseStatus.in_progress, - ) - else: - delta = TextDelta(text=text) - - if stop_reason is None: - if request.logprobs: - assert len(token_result.logprobs) == 1 - - logprobs.append(TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]})) - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.progress, - delta=delta, - stop_reason=stop_reason, - logprobs=logprobs if request.logprobs else None, - ) - ) - - if stop_reason is None: - stop_reason = StopReason.out_of_tokens - - message = self.generator.formatter.decode_assistant_message(tokens, stop_reason) - - parsed_tool_calls = len(message.tool_calls) > 0 - if ipython and not parsed_tool_calls: - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.progress, - delta=ToolCallDelta( - tool_call="", - parse_status=ToolCallParseStatus.failed, - ), - stop_reason=stop_reason, - ) - ) - - for tool_call in message.tool_calls: - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.progress, - delta=ToolCallDelta( - tool_call=tool_call, - parse_status=ToolCallParseStatus.succeeded, - ), - stop_reason=stop_reason, - ) - ) + logprobs = [TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]})] yield ChatCompletionResponseStreamChunk( event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.complete, - delta=TextDelta(text=""), + event_type=ChatCompletionResponseEventType.progress, + delta=TextDelta(text=text), + logprobs=logprobs if request.logprobs else None, stop_reason=stop_reason, ) ) - - if 
self.config.create_distributed_process_group: - async with SEMAPHORE: - for x in impl(): - yield x - else: - for x in impl(): - yield x diff --git a/llama_stack/providers/utils/inference/embedding_mixin.py b/llama_stack/providers/utils/inference/embedding_mixin.py index 97cf87360..84d7221ee 100644 --- a/llama_stack/providers/utils/inference/embedding_mixin.py +++ b/llama_stack/providers/utils/inference/embedding_mixin.py @@ -31,7 +31,7 @@ log = logging.getLogger(__name__) class SentenceTransformerEmbeddingMixin: - model_store: ModelStore + model_store: ModelStore | None = None async def embeddings( self, @@ -41,7 +41,11 @@ class SentenceTransformerEmbeddingMixin: output_dimension: int | None = None, task_type: EmbeddingTaskType | None = None, ) -> EmbeddingsResponse: + if self.model_store is None: + raise RuntimeError("Model store is not initialized") model = await self.model_store.get_model(model_id) + if model.provider_resource_id is None: + raise RuntimeError("Model provider resource ID is not set") embedding_model = self._load_sentence_transformer_model(model.provider_resource_id) embeddings = embedding_model.encode( [interleaved_content_as_str(content) for content in contents], show_progress_bar=False @@ -62,7 +66,11 @@ class SentenceTransformerEmbeddingMixin: raise ValueError("Empty list not supported") # Get the model and generate embeddings + if self.model_store is None: + raise RuntimeError("Model store is not initialized") model_obj = await self.model_store.get_model(model) + if model_obj.provider_resource_id is None: + raise RuntimeError("Model provider resource ID is not set") embedding_model = self._load_sentence_transformer_model(model_obj.provider_resource_id) embeddings = embedding_model.encode(input_list, show_progress_bar=False) diff --git a/pyproject.toml b/pyproject.toml index d84a823a3..ab2ae4b1e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -246,7 +246,7 @@ exclude = [ "^llama_stack/providers/inline/agents/meta_reference/agents\\.py$", "^llama_stack/providers/inline/datasetio/localfs/", "^llama_stack/providers/inline/eval/meta_reference/eval\\.py$", - "^llama_stack/providers/inline/inference/meta_reference/inference\\.py$", + "^llama_stack/providers/inline/inference/meta_reference/config\\.py$", "^llama_stack/models/llama/llama3/generation\\.py$", "^llama_stack/models/llama/llama3/multimodal/model\\.py$", "^llama_stack/models/llama/llama4/",
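
Note (illustrative, not part of the patch): the rewritten non-streaming paths above serialize requests with the module-level SEMAPHORE and hand the blocking token loop to a thread executor via asyncio.get_event_loop().run_in_executor(None, impl). A minimal standalone sketch of that pattern follows; blocking_generate and generate are hypothetical stand-ins for the synchronous generator call and the async entry point, not names from this codebase.

import asyncio

SEMAPHORE = asyncio.Semaphore(1)


def blocking_generate(prompt: str) -> str:
    # Stand-in for the synchronous generator.completion()/chat_completion() loop.
    return prompt.upper()


async def generate(prompt: str) -> str:
    # Serialize access to the single model instance, then run the blocking
    # work off the event loop so other coroutines are not starved.
    async with SEMAPHORE:
        loop = asyncio.get_event_loop()
        return await loop.run_in_executor(None, blocking_generate, prompt)


if __name__ == "__main__":
    print(asyncio.run(generate("hello")))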