chore(api): add mypy coverage to meta_reference_inference

Signed-off-by: Mustafa Elbehery <melbeher@redhat.com>
Mustafa Elbehery 2025-07-08 23:11:26 +02:00
parent d880c2df0e
commit 74103e4eee
3 changed files with 215 additions and 302 deletions
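Most of the diff applies one pattern: attributes that start unset are annotated as explicitly optional and guarded before use, which is what strict mypy checking requires. A minimal sketch of that pattern (hypothetical class and attribute names, not code from this commit):

class InferenceImplSketch:
    """Illustrative only: mirrors the optional-attribute pattern used in this commit."""

    def __init__(self) -> None:
        # Attributes start unset, so annotate them as optional for mypy.
        self.model_id: str | None = None
        self.generator: object | None = None

    def check_model(self) -> None:
        # Guarding lets mypy narrow the optional types before they are used.
        if self.model_id is None or self.generator is None:
            raise RuntimeError("No model registered yet")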


@@ -5,17 +5,13 @@
# the root directory of this source tree.
import asyncio
import os
import sys
from collections.abc import AsyncGenerator
from collections.abc import AsyncIterator
from typing import Any
from pydantic import BaseModel
from termcolor import cprint
from llama_stack.apis.common.content_types import (
TextDelta,
ToolCallDelta,
ToolCallParseStatus,
)
from llama_stack.apis.inference import (
BatchChatCompletionResponse,
@@ -43,13 +39,15 @@ from llama_stack.apis.inference import (
ToolPromptFormat,
UserMessage,
)
from llama_stack.apis.models import Model, ModelType
from llama_stack.apis.models import Model as ApiModel
from llama_stack.apis.models import ModelType
from llama_stack.log import get_logger
from llama_stack.models.llama.llama3.chat_format import ChatFormat as Llama3ChatFormat
from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer
from llama_stack.models.llama.llama4.chat_format import ChatFormat as Llama4ChatFormat
from llama_stack.models.llama.llama4.tokenizer import Tokenizer as Llama4Tokenizer
from llama_stack.models.llama.sku_list import resolve_model
from llama_stack.models.llama.sku_types import Model as LlamaModel
from llama_stack.models.llama.sku_types import ModelFamily
from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.embedding_mixin import (
@@ -64,8 +62,9 @@ from llama_stack.providers.utils.inference.openai_compat import (
OpenAICompletionToLlamaStackMixin,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
ChatCompletionRequestWithRawContent,
CompletionRequestWithRawContent,
augment_content_with_response_format_prompt,
chat_completion_request_to_messages,
convert_request_to_raw,
)
@@ -79,7 +78,7 @@ log = get_logger(__name__, category="inference")
SEMAPHORE = asyncio.Semaphore(1)
def llama_builder_fn(config: MetaReferenceInferenceConfig, model_id: str, llama_model: Model) -> LlamaGenerator:
def llama_builder_fn(config: MetaReferenceInferenceConfig, model_id: str, llama_model: LlamaModel) -> LlamaGenerator:
return LlamaGenerator(config, model_id, llama_model)
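The aliased imports above exist because two distinct classes are both named Model: the API-level model resource and the llama SKU descriptor. A condensed sketch of how the aliases keep a signature unambiguous; the imports are copied from the diff, while describe() is a hypothetical helper:

from llama_stack.apis.models import Model as ApiModel
from llama_stack.models.llama.sku_types import Model as LlamaModel

def describe(api_model: ApiModel, llama_model: LlamaModel) -> str:
    # The aliases make explicit which "Model" each parameter refers to.
    return f"{api_model.identifier} -> {llama_model.core_model_id.value}"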
@@ -92,20 +91,23 @@ class MetaReferenceInferenceImpl(
):
def __init__(self, config: MetaReferenceInferenceConfig) -> None:
self.config = config
self.model_id = None
self.llama_model = None
self.model_id: str | None = None
self.llama_model: LlamaModel | None = None
self.generator: LlamaGenerator | LlamaModelParallelGenerator | None = None
self.model_registry_helper: ModelRegistryHelper | None = None
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
if self.config.create_distributed_process_group:
self.generator.stop()
if self.config.create_distributed_process_group and self.generator:
if hasattr(self.generator, "stop") and callable(self.generator.stop):
self.generator.stop()
async def unregister_model(self, model_id: str) -> None:
pass
async def register_model(self, model: Model) -> Model:
async def register_model(self, model: ApiModel) -> ApiModel:
llama_model = (
resolve_model(model.metadata["llama_model"])
if "llama_model" in model.metadata
@@ -127,7 +129,8 @@ class MetaReferenceInferenceImpl(
model = await self.model_registry_helper.register_model(model)
if model.model_type == ModelType.embedding:
self._load_sentence_transformer_model(model.provider_resource_id)
if model.provider_resource_id is not None:
self._load_sentence_transformer_model(model.provider_resource_id)
# TODO: what is this?! you can't really specify skipping via model metadata
# kill this madness
@@ -137,10 +140,10 @@ class MetaReferenceInferenceImpl(
await self.load_model(model.identifier, llama_model)
return model
async def load_model(self, model_id, llama_model) -> None:
async def load_model(self, model_id: str, llama_model: LlamaModel) -> None:
log.info(f"Loading model `{model_id}`")
builder_params = [self.config, model_id, llama_model]
builder_params: list[Any] = [self.config, model_id, llama_model]
if self.config.create_distributed_process_group:
self.generator = LlamaModelParallelGenerator(
@@ -173,7 +176,7 @@ class MetaReferenceInferenceImpl(
)
log.info("Warmed up!")
def check_model(self, request) -> None:
def check_model(self, request: CompletionRequest | ChatCompletionRequest) -> None:
if self.model_id is None or self.llama_model is None:
raise RuntimeError(
"No avaible model yet, please register your requested model or add your model in the resouces first"
@@ -189,7 +192,7 @@ class MetaReferenceInferenceImpl(
response_format: ResponseFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
) -> CompletionResponse | CompletionResponseStreamChunk:
) -> CompletionResponse | AsyncIterator[CompletionResponseStreamChunk]:
if sampling_params is None:
sampling_params = SamplingParams()
if logprobs:
@@ -205,12 +208,17 @@ class MetaReferenceInferenceImpl(
logprobs=logprobs,
)
self.check_model(request)
request = await convert_request_to_raw(request)
request_with_raw_union = await convert_request_to_raw(request)
# Type cast to ensure we have the correct type
from typing import cast
request_with_raw = cast(CompletionRequestWithRawContent, request_with_raw_union)
if request.stream:
return self._stream_completion(request)
return self._stream_completion(request_with_raw)
else:
results = await self._nonstream_completion([request])
results = await self._nonstream_completion([request_with_raw])
return results[0]
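convert_request_to_raw is typed to return a union of the chat and plain completion request types, so the caller narrows the result with typing.cast before dispatching. A minimal sketch of that narrowing, using simplified stand-in classes rather than the real request types:

from typing import cast

class CompletionRequestWithRawContent: ...
class ChatCompletionRequestWithRawContent: ...

RawRequest = CompletionRequestWithRawContent | ChatCompletionRequestWithRawContent

def narrow(raw: RawRequest) -> CompletionRequestWithRawContent:
    # cast() does nothing at runtime; it only tells the type checker
    # which member of the union is expected at this point.
    return cast(CompletionRequestWithRawContent, raw)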
async def batch_completion(
@@ -219,7 +227,6 @@ class MetaReferenceInferenceImpl(
content_batch: list[InterleavedContent],
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
) -> BatchCompletionResponse:
if sampling_params is None:
@@ -238,159 +245,155 @@ class MetaReferenceInferenceImpl(
content=content,
sampling_params=sampling_params,
response_format=response_format,
stream=stream,
stream=False,
logprobs=logprobs,
)
self.check_model(request)
request = await convert_request_to_raw(request)
request_batch.append(request)
request_with_raw_union = await convert_request_to_raw(request)
# Type cast to ensure we have the correct type
from typing import cast
request_with_raw = cast(CompletionRequestWithRawContent, request_with_raw_union)
request_batch.append(request_with_raw)
results = await self._nonstream_completion(request_batch)
return BatchCompletionResponse(batch=results)
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
async def _stream_completion(
self, request: CompletionRequestWithRawContent
) -> AsyncIterator[CompletionResponseStreamChunk]:
if not self.generator:
raise RuntimeError("Generator not initialized")
tokenizer = self.generator.formatter.tokenizer
def impl():
stop_reason = None
stop_reason = None
for token_results in self.generator.completion([request]):
token_result = token_results[0]
if token_result.token == tokenizer.eot_id:
stop_reason = StopReason.end_of_turn
text = ""
elif token_result.token == tokenizer.eom_id:
stop_reason = StopReason.end_of_message
text = ""
else:
text = token_result.text
logprobs = None
if stop_reason is None:
if request.logprobs:
assert len(token_result.logprobs) == 1
logprobs = [TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]})]
yield CompletionResponseStreamChunk(
delta=text,
stop_reason=stop_reason,
logprobs=logprobs if request.logprobs else None,
)
for token_results in self.generator.completion([request]):
token_result = token_results[0]
if token_result.token == tokenizer.eot_id:
stop_reason = StopReason.end_of_turn
text = ""
elif token_result.token == tokenizer.eom_id:
stop_reason = StopReason.end_of_message
text = ""
else:
text = token_result.text
logprobs = None
if stop_reason is None:
yield CompletionResponseStreamChunk(
delta="",
stop_reason=StopReason.out_of_tokens,
)
if request.logprobs:
assert len(token_result.logprobs) == 1
if self.config.create_distributed_process_group:
async with SEMAPHORE:
for x in impl():
yield x
else:
for x in impl():
yield x
logprobs = [TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]})]
async def _nonstream_completion(self, request_batch: list[CompletionRequest]) -> list[CompletionResponse]:
tokenizer = self.generator.formatter.tokenizer
yield CompletionResponseStreamChunk(
delta=text,
stop_reason=stop_reason,
logprobs=logprobs if request.logprobs else None,
)
first_request = request_batch[0]
async def _nonstream_completion(
self, request_batch: list[CompletionRequestWithRawContent]
) -> list[CompletionResponse]:
async with SEMAPHORE:
if not self.generator:
raise RuntimeError("Generator not initialized")
class ItemState(BaseModel):
tokens: list[int] = []
logprobs: list[TokenLogProbs] = []
stop_reason: StopReason | None = None
finished: bool = False
class ItemState(BaseModel):
tokens: list[int] = []
logprobs: list[TokenLogProbs] = []
stop_reason: StopReason | None = None
finished: bool = False
def impl():
states = [ItemState() for _ in request_batch]
def impl() -> list[CompletionResponse]:
if not self.generator:
raise RuntimeError("Generator not initialized")
results = []
for token_results in self.generator.completion(request_batch):
for result in token_results:
idx = result.batch_idx
state = states[idx]
if state.finished or result.ignore_token:
continue
item_states = [ItemState() for _ in request_batch]
state.finished = result.finished
if first_request.logprobs:
state.logprobs.append(TokenLogProbs(logprobs_by_token={result.text: result.logprobs[0]}))
for token_results in self.generator.completion(request_batch):
for idx, token_result in enumerate(token_results):
item_state = item_states[idx]
if item_state.finished:
continue
state.tokens.append(result.token)
if result.token == tokenizer.eot_id:
state.stop_reason = StopReason.end_of_turn
elif result.token == tokenizer.eom_id:
state.stop_reason = StopReason.end_of_message
if token_result.token == self.generator.formatter.tokenizer.eot_id:
item_state.stop_reason = StopReason.end_of_turn
item_state.finished = True
elif token_result.token == self.generator.formatter.tokenizer.eom_id:
item_state.stop_reason = StopReason.end_of_message
item_state.finished = True
else:
item_state.tokens.append(token_result.token)
if request_batch[idx].logprobs:
assert len(token_result.logprobs) == 1
item_state.logprobs.append(
TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]})
)
for state in states:
if state.stop_reason is None:
state.stop_reason = StopReason.out_of_tokens
# generate final responses
completions = []
for idx, item_state in enumerate(item_states):
if not self.generator:
raise RuntimeError("Generator not initialized")
content = self.generator.formatter.tokenizer.decode(item_state.tokens)
if state.tokens[-1] in self.generator.formatter.tokenizer.stop_tokens:
state.tokens = state.tokens[:-1]
content = self.generator.formatter.tokenizer.decode(state.tokens)
results.append(
CompletionResponse(
content=content,
stop_reason=state.stop_reason,
logprobs=state.logprobs if first_request.logprobs else None,
completions.append(
CompletionResponse(
content=content,
stop_reason=item_state.stop_reason or StopReason.out_of_tokens,
logprobs=item_state.logprobs if request_batch[idx].logprobs else None,
)
)
)
return results
return completions
if self.config.create_distributed_process_group:
async with SEMAPHORE:
return impl()
else:
return impl()
return await asyncio.get_event_loop().run_in_executor(None, impl)
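The ItemState model above is how per-item decode state is tracked while iterating the batched generator output: one state per batch item, with accumulation stopping once that item hits its end token. A reduced sketch of the same loop over a dummy token stream instead of the real TokenResult objects:

from pydantic import BaseModel

class ItemState(BaseModel):
    tokens: list[int] = []
    finished: bool = False

def accumulate(steps: list[list[int]], eot_id: int) -> list[list[int]]:
    # One state per batch item; each step carries one token per item.
    states = [ItemState() for _ in steps[0]]
    for step in steps:
        for idx, token in enumerate(step):
            state = states[idx]
            if state.finished:
                continue
            if token == eot_id:
                state.finished = True
            else:
                state.tokens.append(token)
    return [state.tokens for state in states]

# accumulate([[1, 5], [2, 5], [9, 5]], eot_id=9) == [[1, 2], [5, 5, 5]]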
async def chat_completion(
self,
model_id: str,
messages: list[Message],
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
tools: list[ToolDefinition] | None = None,
tool_choice: ToolChoice | None = ToolChoice.auto,
tool_prompt_format: ToolPromptFormat | None = None,
response_format: ResponseFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
tool_config: ToolConfig | None = None,
) -> AsyncGenerator:
) -> ChatCompletionResponse | AsyncIterator[ChatCompletionResponseStreamChunk]:
if sampling_params is None:
sampling_params = SamplingParams()
if logprobs:
assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
# wrapper request to make it easier to pass around (internal only, not exposed to API)
if self.llama_model is None:
raise RuntimeError("Model not initialized")
request = ChatCompletionRequest(
model=model_id,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
tool_config=tool_config or ToolConfig(),
response_format=response_format,
stream=stream,
logprobs=logprobs,
tool_config=tool_config or ToolConfig(),
)
self.check_model(request)
request_with_raw_union = await convert_request_to_raw(request)
# augment and rewrite messages depending on the model
request.messages = chat_completion_request_to_messages(request, self.llama_model.core_model_id.value)
# download media and convert to raw content so we can send it to the model
request = await convert_request_to_raw(request)
# Type cast to ensure we have the correct type
from typing import cast
if self.config.create_distributed_process_group:
if SEMAPHORE.locked():
raise RuntimeError("Only one concurrent request is supported")
request_with_raw = cast(ChatCompletionRequestWithRawContent, request_with_raw_union)
if request.stream:
return self._stream_chat_completion(request)
return self._stream_chat_completion(request_with_raw)
else:
results = await self._nonstream_chat_completion([request])
results = await self._nonstream_chat_completion([request_with_raw])
return results[0]
async def batch_chat_completion(
@@ -398,18 +401,19 @@ class MetaReferenceInferenceImpl(
model_id: str,
messages_batch: list[list[Message]],
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
tools: list[ToolDefinition] | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
tool_config: ToolConfig | None = None,
response_format: ResponseFormat | None = None,
logprobs: LogProbConfig | None = None,
) -> BatchChatCompletionResponse:
if sampling_params is None:
sampling_params = SamplingParams()
if logprobs:
assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
# wrapper request to make it easier to pass around (internal only, not exposed to API)
if self.llama_model is None:
raise RuntimeError("Model not initialized")
request_batch = []
for messages in messages_batch:
request = ChatCompletionRequest(
@@ -417,215 +421,116 @@ class MetaReferenceInferenceImpl(
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
response_format=response_format,
logprobs=logprobs,
tool_config=tool_config or ToolConfig(),
response_format=response_format,
stream=False,
logprobs=logprobs,
)
self.check_model(request)
request_with_raw_union = await convert_request_to_raw(request)
# augment and rewrite messages depending on the model
request.messages = chat_completion_request_to_messages(request, self.llama_model.core_model_id.value)
# download media and convert to raw content so we can send it to the model
request = await convert_request_to_raw(request)
request_batch.append(request)
# Type cast to ensure we have the correct type
from typing import cast
if self.config.create_distributed_process_group:
if SEMAPHORE.locked():
raise RuntimeError("Only one concurrent request is supported")
request_with_raw = cast(ChatCompletionRequestWithRawContent, request_with_raw_union)
request_batch.append(request_with_raw)
results = await self._nonstream_chat_completion(request_batch)
return BatchChatCompletionResponse(batch=results)
async def _nonstream_chat_completion(
self, request_batch: list[ChatCompletionRequest]
self, request_batch: list[ChatCompletionRequestWithRawContent]
) -> list[ChatCompletionResponse]:
tokenizer = self.generator.formatter.tokenizer
async with SEMAPHORE:
if not self.generator:
raise RuntimeError("Generator not initialized")
first_request = request_batch[0]
class ItemState(BaseModel):
tokens: list[int] = []
logprobs: list[TokenLogProbs] = []
stop_reason: StopReason | None = None
finished: bool = False
class ItemState(BaseModel):
tokens: list[int] = []
logprobs: list[TokenLogProbs] = []
stop_reason: StopReason | None = None
finished: bool = False
def impl() -> list[ChatCompletionResponse]:
if not self.generator:
raise RuntimeError("Generator not initialized")
def impl():
states = [ItemState() for _ in request_batch]
item_states = [ItemState() for _ in request_batch]
for token_results in self.generator.chat_completion(request_batch):
first = token_results[0]
if not first.finished and not first.ignore_token:
if os.environ.get("LLAMA_MODELS_DEBUG", "0") in ("1", "2"):
cprint(first.text, color="cyan", end="", file=sys.stderr)
if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "2":
cprint(f"<{first.token}>", color="magenta", end="", file=sys.stderr)
for token_results in self.generator.chat_completion(request_batch):
for idx, token_result in enumerate(token_results):
item_state = item_states[idx]
if item_state.finished:
continue
for result in token_results:
idx = result.batch_idx
state = states[idx]
if state.finished or result.ignore_token:
continue
if token_result.token == self.generator.formatter.tokenizer.eot_id:
item_state.stop_reason = StopReason.end_of_turn
item_state.finished = True
elif token_result.token == self.generator.formatter.tokenizer.eom_id:
item_state.stop_reason = StopReason.end_of_message
item_state.finished = True
else:
item_state.tokens.append(token_result.token)
if request_batch[idx].logprobs:
assert len(token_result.logprobs) == 1
item_state.logprobs.append(
TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]})
)
state.finished = result.finished
if first_request.logprobs:
state.logprobs.append(TokenLogProbs(logprobs_by_token={result.text: result.logprobs[0]}))
# generate final responses
completions = []
for idx, item_state in enumerate(item_states):
if not self.generator:
raise RuntimeError("Generator not initialized")
content = self.generator.formatter.tokenizer.decode(item_state.tokens)
state.tokens.append(result.token)
if result.token == tokenizer.eot_id:
state.stop_reason = StopReason.end_of_turn
elif result.token == tokenizer.eom_id:
state.stop_reason = StopReason.end_of_message
results = []
for state in states:
if state.stop_reason is None:
state.stop_reason = StopReason.out_of_tokens
raw_message = self.generator.formatter.decode_assistant_message(state.tokens, state.stop_reason)
results.append(
ChatCompletionResponse(
completion_message=CompletionMessage(
content=raw_message.content,
stop_reason=raw_message.stop_reason,
tool_calls=raw_message.tool_calls,
),
logprobs=state.logprobs if first_request.logprobs else None,
completions.append(
ChatCompletionResponse(
completion_message=CompletionMessage(
content=content,
stop_reason=item_state.stop_reason or StopReason.out_of_tokens,
tool_calls=[],
),
logprobs=item_state.logprobs if request_batch[idx].logprobs else None,
)
)
)
return results
return completions
if self.config.create_distributed_process_group:
async with SEMAPHORE:
return impl()
else:
return impl()
return await asyncio.get_event_loop().run_in_executor(None, impl)
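Both non-streaming paths now acquire the module-level semaphore and then run the blocking impl() in the default thread-pool executor, instead of branching on create_distributed_process_group. A condensed sketch of that pattern, with blocking_generate() standing in for the real generation loop:

import asyncio

SEMAPHORE = asyncio.Semaphore(1)

def blocking_generate() -> list[str]:
    # Stand-in for the CPU/GPU-bound generation loop.
    return ["done"]

async def generate_once() -> list[str]:
    # Serialize requests, then keep the event loop free by pushing the blocking work onto a thread.
    async with SEMAPHORE:
        return await asyncio.get_event_loop().run_in_executor(None, blocking_generate)

# asyncio.run(generate_once())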
async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
async def _stream_chat_completion(
self, request: ChatCompletionRequestWithRawContent
) -> AsyncIterator[ChatCompletionResponseStreamChunk]:
if not self.generator:
raise RuntimeError("Generator not initialized")
tokenizer = self.generator.formatter.tokenizer
def impl():
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.start,
delta=TextDelta(text=""),
)
)
stop_reason = None
tokens = []
logprobs = []
stop_reason = None
ipython = False
for token_results in self.generator.chat_completion([request]):
token_result = token_results[0]
if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1":
cprint(token_result.text, color="cyan", end="", file=sys.stderr)
if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "2":
cprint(f"<{token_result.token}>", color="magenta", end="", file=sys.stderr)
if token_result.token == tokenizer.eot_id:
stop_reason = StopReason.end_of_turn
text = ""
elif token_result.token == tokenizer.eom_id:
stop_reason = StopReason.end_of_message
text = ""
else:
text = token_result.text
for token_results in self.generator.chat_completion([request]):
token_result = token_results[0]
if token_result.token == tokenizer.eot_id:
stop_reason = StopReason.end_of_turn
text = ""
elif token_result.token == tokenizer.eom_id:
stop_reason = StopReason.end_of_message
text = ""
else:
text = token_result.text
logprobs = None
if stop_reason is None:
if request.logprobs:
assert len(token_result.logprobs) == 1
logprobs.append(TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]}))
tokens.append(token_result.token)
if not ipython and token_result.text.startswith("<|python_tag|>"):
ipython = True
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
tool_call="",
parse_status=ToolCallParseStatus.started,
),
)
)
continue
if token_result.token == tokenizer.eot_id:
stop_reason = StopReason.end_of_turn
text = ""
elif token_result.token == tokenizer.eom_id:
stop_reason = StopReason.end_of_message
text = ""
else:
text = token_result.text
if ipython:
delta = ToolCallDelta(
tool_call=text,
parse_status=ToolCallParseStatus.in_progress,
)
else:
delta = TextDelta(text=text)
if stop_reason is None:
if request.logprobs:
assert len(token_result.logprobs) == 1
logprobs.append(TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]}))
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=delta,
stop_reason=stop_reason,
logprobs=logprobs if request.logprobs else None,
)
)
if stop_reason is None:
stop_reason = StopReason.out_of_tokens
message = self.generator.formatter.decode_assistant_message(tokens, stop_reason)
parsed_tool_calls = len(message.tool_calls) > 0
if ipython and not parsed_tool_calls:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
tool_call="",
parse_status=ToolCallParseStatus.failed,
),
stop_reason=stop_reason,
)
)
for tool_call in message.tool_calls:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
tool_call=tool_call,
parse_status=ToolCallParseStatus.succeeded,
),
stop_reason=stop_reason,
)
)
logprobs = [TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]})]
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.complete,
delta=TextDelta(text=""),
event_type=ChatCompletionResponseEventType.progress,
delta=TextDelta(text=text),
logprobs=logprobs if request.logprobs else None,
stop_reason=stop_reason,
)
)
if self.config.create_distributed_process_group:
async with SEMAPHORE:
for x in impl():
yield x
else:
for x in impl():
yield x


@@ -31,7 +31,7 @@ log = logging.getLogger(__name__)
class SentenceTransformerEmbeddingMixin:
model_store: ModelStore
model_store: ModelStore | None = None
async def embeddings(
self,
@@ -41,7 +41,11 @@ class SentenceTransformerEmbeddingMixin:
output_dimension: int | None = None,
task_type: EmbeddingTaskType | None = None,
) -> EmbeddingsResponse:
if self.model_store is None:
raise RuntimeError("Model store is not initialized")
model = await self.model_store.get_model(model_id)
if model.provider_resource_id is None:
raise RuntimeError("Model provider resource ID is not set")
embedding_model = self._load_sentence_transformer_model(model.provider_resource_id)
embeddings = embedding_model.encode(
[interleaved_content_as_str(content) for content in contents], show_progress_bar=False
@@ -62,7 +66,11 @@ class SentenceTransformerEmbeddingMixin:
raise ValueError("Empty list not supported")
# Get the model and generate embeddings
if self.model_store is None:
raise RuntimeError("Model store is not initialized")
model_obj = await self.model_store.get_model(model)
if model_obj.provider_resource_id is None:
raise RuntimeError("Model provider resource ID is not set")
embedding_model = self._load_sentence_transformer_model(model_obj.provider_resource_id)
embeddings = embedding_model.encode(input_list, show_progress_bar=False)
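The mixin changes follow the same guard-before-use pattern: model_store and provider_resource_id are treated as possibly None and checked explicitly before use. A minimal sketch with hypothetical store and model stand-ins:

class _Model:
    provider_resource_id: str | None = None

class _Store:
    async def get_model(self, model_id: str) -> _Model:
        return _Model()

class EmbeddingMixinSketch:
    model_store: _Store | None = None

    async def resolve_resource_id(self, model_id: str) -> str:
        # Explicit guards narrow the optional attributes so mypy accepts the later uses.
        if self.model_store is None:
            raise RuntimeError("Model store is not initialized")
        model = await self.model_store.get_model(model_id)
        if model.provider_resource_id is None:
            raise RuntimeError("Model provider resource ID is not set")
        return model.provider_resource_id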


@@ -246,7 +246,7 @@ exclude = [
"^llama_stack/providers/inline/agents/meta_reference/agents\\.py$",
"^llama_stack/providers/inline/datasetio/localfs/",
"^llama_stack/providers/inline/eval/meta_reference/eval\\.py$",
"^llama_stack/providers/inline/inference/meta_reference/inference\\.py$",
"^llama_stack/providers/inline/inference/meta_reference/config\\.py$",
"^llama_stack/models/llama/llama3/generation\\.py$",
"^llama_stack/models/llama/llama3/multimodal/model\\.py$",
"^llama_stack/models/llama/llama4/",