chore(api): add mypy coverage to meta_reference_inference

Signed-off-by: Mustafa Elbehery <melbeher@redhat.com>
Mustafa Elbehery 2025-07-08 23:11:26 +02:00
parent d880c2df0e
commit 74103e4eee
3 changed files with 215 additions and 302 deletions
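
The hunks that follow bring the meta-reference inference provider under mypy, largely by annotating lazily initialized instance state as Optional and guarding it at runtime before use. A minimal sketch of that pattern, using hypothetical stand-in classes (FakeGenerator, Provider) rather than the provider's real ones:

    class FakeGenerator:
        """Stand-in for the real generator; illustration only."""

        def complete(self, prompt: str) -> str:
            return prompt.upper()


    class Provider:
        def __init__(self) -> None:
            # State that only exists after a model is loaded is annotated as
            # Optional, mirroring the pattern in the diff below.
            self.model_id: str | None = None
            self.generator: FakeGenerator | None = None

        def load(self, model_id: str) -> None:
            self.model_id = model_id
            self.generator = FakeGenerator()

        def complete(self, prompt: str) -> str:
            # The runtime guard doubles as type narrowing for mypy.
            if self.model_id is None or self.generator is None:
                raise RuntimeError("No model loaded; call load() first")
            return self.generator.complete(prompt)


    provider = Provider()
    provider.load("example-model")
    print(provider.complete("hello"))

After a guard like this, mypy narrows the attributes to their non-None types for the rest of the method, which is why the diff adds the same check in several places.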


@@ -5,17 +5,13 @@
# the root directory of this source tree. # the root directory of this source tree.
import asyncio import asyncio
import os from collections.abc import AsyncIterator
import sys from typing import Any
from collections.abc import AsyncGenerator
from pydantic import BaseModel from pydantic import BaseModel
from termcolor import cprint
from llama_stack.apis.common.content_types import ( from llama_stack.apis.common.content_types import (
TextDelta, TextDelta,
ToolCallDelta,
ToolCallParseStatus,
) )
from llama_stack.apis.inference import ( from llama_stack.apis.inference import (
BatchChatCompletionResponse, BatchChatCompletionResponse,
@@ -43,13 +39,15 @@ from llama_stack.apis.inference import (
ToolPromptFormat, ToolPromptFormat,
UserMessage, UserMessage,
) )
from llama_stack.apis.models import Model, ModelType from llama_stack.apis.models import Model as ApiModel
from llama_stack.apis.models import ModelType
from llama_stack.log import get_logger from llama_stack.log import get_logger
from llama_stack.models.llama.llama3.chat_format import ChatFormat as Llama3ChatFormat from llama_stack.models.llama.llama3.chat_format import ChatFormat as Llama3ChatFormat
from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer
from llama_stack.models.llama.llama4.chat_format import ChatFormat as Llama4ChatFormat from llama_stack.models.llama.llama4.chat_format import ChatFormat as Llama4ChatFormat
from llama_stack.models.llama.llama4.tokenizer import Tokenizer as Llama4Tokenizer from llama_stack.models.llama.llama4.tokenizer import Tokenizer as Llama4Tokenizer
from llama_stack.models.llama.sku_list import resolve_model from llama_stack.models.llama.sku_list import resolve_model
from llama_stack.models.llama.sku_types import Model as LlamaModel
from llama_stack.models.llama.sku_types import ModelFamily from llama_stack.models.llama.sku_types import ModelFamily
from llama_stack.providers.datatypes import ModelsProtocolPrivate from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.embedding_mixin import ( from llama_stack.providers.utils.inference.embedding_mixin import (
@@ -64,8 +62,9 @@ from llama_stack.providers.utils.inference.openai_compat import (
OpenAICompletionToLlamaStackMixin, OpenAICompletionToLlamaStackMixin,
) )
from llama_stack.providers.utils.inference.prompt_adapter import ( from llama_stack.providers.utils.inference.prompt_adapter import (
ChatCompletionRequestWithRawContent,
CompletionRequestWithRawContent,
augment_content_with_response_format_prompt, augment_content_with_response_format_prompt,
chat_completion_request_to_messages,
convert_request_to_raw, convert_request_to_raw,
) )
@@ -79,7 +78,7 @@ log = get_logger(__name__, category="inference")
SEMAPHORE = asyncio.Semaphore(1) SEMAPHORE = asyncio.Semaphore(1)
def llama_builder_fn(config: MetaReferenceInferenceConfig, model_id: str, llama_model: Model) -> LlamaGenerator: def llama_builder_fn(config: MetaReferenceInferenceConfig, model_id: str, llama_model: LlamaModel) -> LlamaGenerator:
return LlamaGenerator(config, model_id, llama_model) return LlamaGenerator(config, model_id, llama_model)
@@ -92,20 +91,23 @@ class MetaReferenceInferenceImpl(
): ):
def __init__(self, config: MetaReferenceInferenceConfig) -> None: def __init__(self, config: MetaReferenceInferenceConfig) -> None:
self.config = config self.config = config
self.model_id = None self.model_id: str | None = None
self.llama_model = None self.llama_model: LlamaModel | None = None
self.generator: LlamaGenerator | LlamaModelParallelGenerator | None = None
self.model_registry_helper: ModelRegistryHelper | None = None
async def initialize(self) -> None: async def initialize(self) -> None:
pass pass
async def shutdown(self) -> None: async def shutdown(self) -> None:
if self.config.create_distributed_process_group: if self.config.create_distributed_process_group and self.generator:
self.generator.stop() if hasattr(self.generator, "stop") and callable(self.generator.stop):
self.generator.stop()
async def unregister_model(self, model_id: str) -> None: async def unregister_model(self, model_id: str) -> None:
pass pass
async def register_model(self, model: Model) -> Model: async def register_model(self, model: ApiModel) -> ApiModel:
llama_model = ( llama_model = (
resolve_model(model.metadata["llama_model"]) resolve_model(model.metadata["llama_model"])
if "llama_model" in model.metadata if "llama_model" in model.metadata
@@ -127,7 +129,8 @@ class MetaReferenceInferenceImpl(
model = await self.model_registry_helper.register_model(model) model = await self.model_registry_helper.register_model(model)
if model.model_type == ModelType.embedding: if model.model_type == ModelType.embedding:
self._load_sentence_transformer_model(model.provider_resource_id) if model.provider_resource_id is not None:
self._load_sentence_transformer_model(model.provider_resource_id)
# TODO: what is this?! you can't really specify skipping via model metadata # TODO: what is this?! you can't really specify skipping via model metadata
# kill this madness # kill this madness
@@ -137,10 +140,10 @@ class MetaReferenceInferenceImpl(
await self.load_model(model.identifier, llama_model) await self.load_model(model.identifier, llama_model)
return model return model
async def load_model(self, model_id, llama_model) -> None: async def load_model(self, model_id: str, llama_model: LlamaModel) -> None:
log.info(f"Loading model `{model_id}`") log.info(f"Loading model `{model_id}`")
builder_params = [self.config, model_id, llama_model] builder_params: list[Any] = [self.config, model_id, llama_model]
if self.config.create_distributed_process_group: if self.config.create_distributed_process_group:
self.generator = LlamaModelParallelGenerator( self.generator = LlamaModelParallelGenerator(
@@ -173,7 +176,7 @@ class MetaReferenceInferenceImpl(
) )
log.info("Warmed up!") log.info("Warmed up!")
def check_model(self, request) -> None: def check_model(self, request: CompletionRequest | ChatCompletionRequest) -> None:
if self.model_id is None or self.llama_model is None: if self.model_id is None or self.llama_model is None:
raise RuntimeError( raise RuntimeError(
"No avaible model yet, please register your requested model or add your model in the resouces first" "No avaible model yet, please register your requested model or add your model in the resouces first"
@@ -189,7 +192,7 @@ class MetaReferenceInferenceImpl(
response_format: ResponseFormat | None = None, response_format: ResponseFormat | None = None,
stream: bool | None = False, stream: bool | None = False,
logprobs: LogProbConfig | None = None, logprobs: LogProbConfig | None = None,
) -> CompletionResponse | CompletionResponseStreamChunk: ) -> CompletionResponse | AsyncIterator[CompletionResponseStreamChunk]:
if sampling_params is None: if sampling_params is None:
sampling_params = SamplingParams() sampling_params = SamplingParams()
if logprobs: if logprobs:
@@ -205,12 +208,17 @@ class MetaReferenceInferenceImpl(
logprobs=logprobs, logprobs=logprobs,
) )
self.check_model(request) self.check_model(request)
request = await convert_request_to_raw(request) request_with_raw_union = await convert_request_to_raw(request)
# Type cast to ensure we have the correct type
from typing import cast
request_with_raw = cast(CompletionRequestWithRawContent, request_with_raw_union)
if request.stream: if request.stream:
return self._stream_completion(request) return self._stream_completion(request_with_raw)
else: else:
results = await self._nonstream_completion([request]) results = await self._nonstream_completion([request_with_raw])
return results[0] return results[0]
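
In the completion() path above, convert_request_to_raw() is declared to return a union of request types, so the new code narrows the result with typing.cast before dispatching. A small sketch of that narrowing, with made-up request classes (TextRequest, ChatRequest, convert are illustrative only):

    from dataclasses import dataclass
    from typing import cast


    @dataclass
    class TextRequest:
        content: str


    @dataclass
    class ChatRequest:
        messages: list[str]


    def convert(request: TextRequest | ChatRequest) -> TextRequest | ChatRequest:
        # Stand-in for a helper whose declared return type is a union.
        return request


    def handle_text(request: TextRequest) -> str:
        converted = convert(request)
        # cast() only informs the type checker; there is no runtime check, so
        # the caller has to know which member of the union it really holds.
        narrowed = cast(TextRequest, converted)
        return narrowed.content


    print(handle_text(TextRequest(content="hello")))

An isinstance() check would narrow the same way while also verifying at runtime; cast() is the lighter-weight choice taken here, and it relies on the caller being right about which branch of the union it holds.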
async def batch_completion( async def batch_completion(
@@ -219,7 +227,6 @@ class MetaReferenceInferenceImpl(
content_batch: list[InterleavedContent], content_batch: list[InterleavedContent],
sampling_params: SamplingParams | None = None, sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None, response_format: ResponseFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None, logprobs: LogProbConfig | None = None,
) -> BatchCompletionResponse: ) -> BatchCompletionResponse:
if sampling_params is None: if sampling_params is None:
@@ -238,159 +245,155 @@ class MetaReferenceInferenceImpl(
content=content, content=content,
sampling_params=sampling_params, sampling_params=sampling_params,
response_format=response_format, response_format=response_format,
stream=stream, stream=False,
logprobs=logprobs, logprobs=logprobs,
) )
self.check_model(request) self.check_model(request)
request = await convert_request_to_raw(request) request_with_raw_union = await convert_request_to_raw(request)
request_batch.append(request)
# Type cast to ensure we have the correct type
from typing import cast
request_with_raw = cast(CompletionRequestWithRawContent, request_with_raw_union)
request_batch.append(request_with_raw)
results = await self._nonstream_completion(request_batch) results = await self._nonstream_completion(request_batch)
return BatchCompletionResponse(batch=results) return BatchCompletionResponse(batch=results)
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator: async def _stream_completion(
self, request: CompletionRequestWithRawContent
) -> AsyncIterator[CompletionResponseStreamChunk]:
if not self.generator:
raise RuntimeError("Generator not initialized")
tokenizer = self.generator.formatter.tokenizer tokenizer = self.generator.formatter.tokenizer
def impl(): stop_reason = None
stop_reason = None
for token_results in self.generator.completion([request]): for token_results in self.generator.completion([request]):
token_result = token_results[0] token_result = token_results[0]
if token_result.token == tokenizer.eot_id: if token_result.token == tokenizer.eot_id:
stop_reason = StopReason.end_of_turn stop_reason = StopReason.end_of_turn
text = "" text = ""
elif token_result.token == tokenizer.eom_id: elif token_result.token == tokenizer.eom_id:
stop_reason = StopReason.end_of_message stop_reason = StopReason.end_of_message
text = "" text = ""
else: else:
text = token_result.text text = token_result.text
logprobs = None
if stop_reason is None:
if request.logprobs:
assert len(token_result.logprobs) == 1
logprobs = [TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]})]
yield CompletionResponseStreamChunk(
delta=text,
stop_reason=stop_reason,
logprobs=logprobs if request.logprobs else None,
)
logprobs = None
if stop_reason is None: if stop_reason is None:
yield CompletionResponseStreamChunk( if request.logprobs:
delta="", assert len(token_result.logprobs) == 1
stop_reason=StopReason.out_of_tokens,
)
if self.config.create_distributed_process_group: logprobs = [TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]})]
async with SEMAPHORE:
for x in impl():
yield x
else:
for x in impl():
yield x
async def _nonstream_completion(self, request_batch: list[CompletionRequest]) -> list[CompletionResponse]: yield CompletionResponseStreamChunk(
tokenizer = self.generator.formatter.tokenizer delta=text,
stop_reason=stop_reason,
logprobs=logprobs if request.logprobs else None,
)
first_request = request_batch[0] async def _nonstream_completion(
self, request_batch: list[CompletionRequestWithRawContent]
) -> list[CompletionResponse]:
async with SEMAPHORE:
if not self.generator:
raise RuntimeError("Generator not initialized")
class ItemState(BaseModel): class ItemState(BaseModel):
tokens: list[int] = [] tokens: list[int] = []
logprobs: list[TokenLogProbs] = [] logprobs: list[TokenLogProbs] = []
stop_reason: StopReason | None = None stop_reason: StopReason | None = None
finished: bool = False finished: bool = False
def impl(): def impl() -> list[CompletionResponse]:
states = [ItemState() for _ in request_batch] if not self.generator:
raise RuntimeError("Generator not initialized")
results = [] item_states = [ItemState() for _ in request_batch]
for token_results in self.generator.completion(request_batch):
for result in token_results:
idx = result.batch_idx
state = states[idx]
if state.finished or result.ignore_token:
continue
state.finished = result.finished for token_results in self.generator.completion(request_batch):
if first_request.logprobs: for idx, token_result in enumerate(token_results):
state.logprobs.append(TokenLogProbs(logprobs_by_token={result.text: result.logprobs[0]})) item_state = item_states[idx]
if item_state.finished:
continue
state.tokens.append(result.token) if token_result.token == self.generator.formatter.tokenizer.eot_id:
if result.token == tokenizer.eot_id: item_state.stop_reason = StopReason.end_of_turn
state.stop_reason = StopReason.end_of_turn item_state.finished = True
elif result.token == tokenizer.eom_id: elif token_result.token == self.generator.formatter.tokenizer.eom_id:
state.stop_reason = StopReason.end_of_message item_state.stop_reason = StopReason.end_of_message
item_state.finished = True
else:
item_state.tokens.append(token_result.token)
if request_batch[idx].logprobs:
assert len(token_result.logprobs) == 1
item_state.logprobs.append(
TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]})
)
for state in states: # generate final responses
if state.stop_reason is None: completions = []
state.stop_reason = StopReason.out_of_tokens for idx, item_state in enumerate(item_states):
if not self.generator:
raise RuntimeError("Generator not initialized")
content = self.generator.formatter.tokenizer.decode(item_state.tokens)
if state.tokens[-1] in self.generator.formatter.tokenizer.stop_tokens: completions.append(
state.tokens = state.tokens[:-1] CompletionResponse(
content = self.generator.formatter.tokenizer.decode(state.tokens) content=content,
results.append( stop_reason=item_state.stop_reason or StopReason.out_of_tokens,
CompletionResponse( logprobs=item_state.logprobs if request_batch[idx].logprobs else None,
content=content, )
stop_reason=state.stop_reason,
logprobs=state.logprobs if first_request.logprobs else None,
) )
)
return results return completions
if self.config.create_distributed_process_group: return await asyncio.get_event_loop().run_in_executor(None, impl)
async with SEMAPHORE:
return impl()
else:
return impl()
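
The rewritten _nonstream_completion() above serializes access to the single generator with an asyncio.Semaphore and pushes the blocking token loop onto the default executor so the event loop stays responsive. A self-contained sketch of that shape, where heavy_loop() is a stand-in for the real generation loop:

    import asyncio
    import time

    SEMAPHORE = asyncio.Semaphore(1)


    def heavy_loop(prompt: str) -> str:
        # Stand-in for the blocking token-generation loop.
        time.sleep(0.1)
        return prompt[::-1]


    async def generate(prompt: str) -> str:
        # One request at a time touches the generator; the blocking work runs
        # in a worker thread so other coroutines keep making progress.
        async with SEMAPHORE:
            return await asyncio.get_event_loop().run_in_executor(None, heavy_loop, prompt)


    async def main() -> None:
        print(await asyncio.gather(generate("hello"), generate("world")))


    asyncio.run(main())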
async def chat_completion( async def chat_completion(
self, self,
model_id: str, model_id: str,
messages: list[Message], messages: list[Message],
sampling_params: SamplingParams | None = None, sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
tools: list[ToolDefinition] | None = None, tools: list[ToolDefinition] | None = None,
tool_choice: ToolChoice | None = ToolChoice.auto, tool_choice: ToolChoice | None = ToolChoice.auto,
tool_prompt_format: ToolPromptFormat | None = None, tool_prompt_format: ToolPromptFormat | None = None,
response_format: ResponseFormat | None = None,
stream: bool | None = False, stream: bool | None = False,
logprobs: LogProbConfig | None = None, logprobs: LogProbConfig | None = None,
tool_config: ToolConfig | None = None, tool_config: ToolConfig | None = None,
) -> AsyncGenerator: ) -> ChatCompletionResponse | AsyncIterator[ChatCompletionResponseStreamChunk]:
if sampling_params is None: if sampling_params is None:
sampling_params = SamplingParams() sampling_params = SamplingParams()
if logprobs: if logprobs:
assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}" assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
# wrapper request to make it easier to pass around (internal only, not exposed to API) if self.llama_model is None:
raise RuntimeError("Model not initialized")
request = ChatCompletionRequest( request = ChatCompletionRequest(
model=model_id, model=model_id,
messages=messages, messages=messages,
sampling_params=sampling_params, sampling_params=sampling_params,
tools=tools or [], tools=tools or [],
tool_config=tool_config or ToolConfig(),
response_format=response_format, response_format=response_format,
stream=stream, stream=stream,
logprobs=logprobs, logprobs=logprobs,
tool_config=tool_config or ToolConfig(),
) )
self.check_model(request) self.check_model(request)
request_with_raw_union = await convert_request_to_raw(request)
# augment and rewrite messages depending on the model # Type cast to ensure we have the correct type
request.messages = chat_completion_request_to_messages(request, self.llama_model.core_model_id.value) from typing import cast
# download media and convert to raw content so we can send it to the model
request = await convert_request_to_raw(request)
if self.config.create_distributed_process_group: request_with_raw = cast(ChatCompletionRequestWithRawContent, request_with_raw_union)
if SEMAPHORE.locked():
raise RuntimeError("Only one concurrent request is supported")
if request.stream: if request.stream:
return self._stream_chat_completion(request) return self._stream_chat_completion(request_with_raw)
else: else:
results = await self._nonstream_chat_completion([request]) results = await self._nonstream_chat_completion([request_with_raw])
return results[0] return results[0]
async def batch_chat_completion( async def batch_chat_completion(
@@ -398,18 +401,19 @@ class MetaReferenceInferenceImpl(
model_id: str, model_id: str,
messages_batch: list[list[Message]], messages_batch: list[list[Message]],
sampling_params: SamplingParams | None = None, sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
tools: list[ToolDefinition] | None = None, tools: list[ToolDefinition] | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
tool_config: ToolConfig | None = None, tool_config: ToolConfig | None = None,
response_format: ResponseFormat | None = None,
logprobs: LogProbConfig | None = None,
) -> BatchChatCompletionResponse: ) -> BatchChatCompletionResponse:
if sampling_params is None: if sampling_params is None:
sampling_params = SamplingParams() sampling_params = SamplingParams()
if logprobs: if logprobs:
assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}" assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
# wrapper request to make it easier to pass around (internal only, not exposed to API) if self.llama_model is None:
raise RuntimeError("Model not initialized")
request_batch = [] request_batch = []
for messages in messages_batch: for messages in messages_batch:
request = ChatCompletionRequest( request = ChatCompletionRequest(
@@ -417,215 +421,116 @@ class MetaReferenceInferenceImpl(
messages=messages, messages=messages,
sampling_params=sampling_params, sampling_params=sampling_params,
tools=tools or [], tools=tools or [],
response_format=response_format,
logprobs=logprobs,
tool_config=tool_config or ToolConfig(), tool_config=tool_config or ToolConfig(),
response_format=response_format,
stream=False,
logprobs=logprobs,
) )
self.check_model(request) self.check_model(request)
request_with_raw_union = await convert_request_to_raw(request)
# augment and rewrite messages depending on the model # Type cast to ensure we have the correct type
request.messages = chat_completion_request_to_messages(request, self.llama_model.core_model_id.value) from typing import cast
# download media and convert to raw content so we can send it to the model
request = await convert_request_to_raw(request)
request_batch.append(request)
if self.config.create_distributed_process_group: request_with_raw = cast(ChatCompletionRequestWithRawContent, request_with_raw_union)
if SEMAPHORE.locked(): request_batch.append(request_with_raw)
raise RuntimeError("Only one concurrent request is supported")
results = await self._nonstream_chat_completion(request_batch) results = await self._nonstream_chat_completion(request_batch)
return BatchChatCompletionResponse(batch=results) return BatchChatCompletionResponse(batch=results)
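
The non-streaming helpers in this diff accumulate per-request decoding state in a small Pydantic model (ItemState) whose fields default to empty lists. A quick sketch of why that bare [] default is safe with Pydantic, using an illustrative model of the same shape:

    from pydantic import BaseModel


    class ItemState(BaseModel):
        # Pydantic gives each instance its own copy of a mutable default, so a
        # plain [] is safe here, unlike an ordinary shared class attribute.
        tokens: list[int] = []
        finished: bool = False


    a = ItemState()
    b = ItemState()
    a.tokens.append(1)
    print(a.tokens, b.tokens)  # [1] [] -- no state shared between instances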
async def _nonstream_chat_completion( async def _nonstream_chat_completion(
self, request_batch: list[ChatCompletionRequest] self, request_batch: list[ChatCompletionRequestWithRawContent]
) -> list[ChatCompletionResponse]: ) -> list[ChatCompletionResponse]:
tokenizer = self.generator.formatter.tokenizer async with SEMAPHORE:
if not self.generator:
raise RuntimeError("Generator not initialized")
first_request = request_batch[0] class ItemState(BaseModel):
tokens: list[int] = []
logprobs: list[TokenLogProbs] = []
stop_reason: StopReason | None = None
finished: bool = False
class ItemState(BaseModel): def impl() -> list[ChatCompletionResponse]:
tokens: list[int] = [] if not self.generator:
logprobs: list[TokenLogProbs] = [] raise RuntimeError("Generator not initialized")
stop_reason: StopReason | None = None
finished: bool = False
def impl(): item_states = [ItemState() for _ in request_batch]
states = [ItemState() for _ in request_batch]
for token_results in self.generator.chat_completion(request_batch): for token_results in self.generator.chat_completion(request_batch):
first = token_results[0] for idx, token_result in enumerate(token_results):
if not first.finished and not first.ignore_token: item_state = item_states[idx]
if os.environ.get("LLAMA_MODELS_DEBUG", "0") in ("1", "2"): if item_state.finished:
cprint(first.text, color="cyan", end="", file=sys.stderr) continue
if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "2":
cprint(f"<{first.token}>", color="magenta", end="", file=sys.stderr)
for result in token_results: if token_result.token == self.generator.formatter.tokenizer.eot_id:
idx = result.batch_idx item_state.stop_reason = StopReason.end_of_turn
state = states[idx] item_state.finished = True
if state.finished or result.ignore_token: elif token_result.token == self.generator.formatter.tokenizer.eom_id:
continue item_state.stop_reason = StopReason.end_of_message
item_state.finished = True
else:
item_state.tokens.append(token_result.token)
if request_batch[idx].logprobs:
assert len(token_result.logprobs) == 1
item_state.logprobs.append(
TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]})
)
state.finished = result.finished # generate final responses
if first_request.logprobs: completions = []
state.logprobs.append(TokenLogProbs(logprobs_by_token={result.text: result.logprobs[0]})) for idx, item_state in enumerate(item_states):
if not self.generator:
raise RuntimeError("Generator not initialized")
content = self.generator.formatter.tokenizer.decode(item_state.tokens)
state.tokens.append(result.token) completions.append(
if result.token == tokenizer.eot_id: ChatCompletionResponse(
state.stop_reason = StopReason.end_of_turn completion_message=CompletionMessage(
elif result.token == tokenizer.eom_id: content=content,
state.stop_reason = StopReason.end_of_message stop_reason=item_state.stop_reason or StopReason.out_of_tokens,
tool_calls=[],
results = [] ),
for state in states: logprobs=item_state.logprobs if request_batch[idx].logprobs else None,
if state.stop_reason is None: )
state.stop_reason = StopReason.out_of_tokens
raw_message = self.generator.formatter.decode_assistant_message(state.tokens, state.stop_reason)
results.append(
ChatCompletionResponse(
completion_message=CompletionMessage(
content=raw_message.content,
stop_reason=raw_message.stop_reason,
tool_calls=raw_message.tool_calls,
),
logprobs=state.logprobs if first_request.logprobs else None,
) )
)
return results return completions
if self.config.create_distributed_process_group: return await asyncio.get_event_loop().run_in_executor(None, impl)
async with SEMAPHORE:
return impl()
else:
return impl()
async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator: async def _stream_chat_completion(
self, request: ChatCompletionRequestWithRawContent
) -> AsyncIterator[ChatCompletionResponseStreamChunk]:
if not self.generator:
raise RuntimeError("Generator not initialized")
tokenizer = self.generator.formatter.tokenizer tokenizer = self.generator.formatter.tokenizer
def impl(): stop_reason = None
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.start,
delta=TextDelta(text=""),
)
)
tokens = [] for token_results in self.generator.chat_completion([request]):
logprobs = [] token_result = token_results[0]
stop_reason = None if token_result.token == tokenizer.eot_id:
ipython = False stop_reason = StopReason.end_of_turn
text = ""
for token_results in self.generator.chat_completion([request]): elif token_result.token == tokenizer.eom_id:
token_result = token_results[0] stop_reason = StopReason.end_of_message
if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1": text = ""
cprint(token_result.text, color="cyan", end="", file=sys.stderr) else:
if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "2": text = token_result.text
cprint(f"<{token_result.token}>", color="magenta", end="", file=sys.stderr)
if token_result.token == tokenizer.eot_id:
stop_reason = StopReason.end_of_turn
text = ""
elif token_result.token == tokenizer.eom_id:
stop_reason = StopReason.end_of_message
text = ""
else:
text = token_result.text
logprobs = None
if stop_reason is None:
if request.logprobs: if request.logprobs:
assert len(token_result.logprobs) == 1 assert len(token_result.logprobs) == 1
logprobs.append(TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]})) logprobs = [TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]})]
tokens.append(token_result.token)
if not ipython and token_result.text.startswith("<|python_tag|>"):
ipython = True
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
tool_call="",
parse_status=ToolCallParseStatus.started,
),
)
)
continue
if token_result.token == tokenizer.eot_id:
stop_reason = StopReason.end_of_turn
text = ""
elif token_result.token == tokenizer.eom_id:
stop_reason = StopReason.end_of_message
text = ""
else:
text = token_result.text
if ipython:
delta = ToolCallDelta(
tool_call=text,
parse_status=ToolCallParseStatus.in_progress,
)
else:
delta = TextDelta(text=text)
if stop_reason is None:
if request.logprobs:
assert len(token_result.logprobs) == 1
logprobs.append(TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]}))
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=delta,
stop_reason=stop_reason,
logprobs=logprobs if request.logprobs else None,
)
)
if stop_reason is None:
stop_reason = StopReason.out_of_tokens
message = self.generator.formatter.decode_assistant_message(tokens, stop_reason)
parsed_tool_calls = len(message.tool_calls) > 0
if ipython and not parsed_tool_calls:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
tool_call="",
parse_status=ToolCallParseStatus.failed,
),
stop_reason=stop_reason,
)
)
for tool_call in message.tool_calls:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
tool_call=tool_call,
parse_status=ToolCallParseStatus.succeeded,
),
stop_reason=stop_reason,
)
)
yield ChatCompletionResponseStreamChunk( yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent( event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.complete, event_type=ChatCompletionResponseEventType.progress,
delta=TextDelta(text=""), delta=TextDelta(text=text),
logprobs=logprobs if request.logprobs else None,
stop_reason=stop_reason, stop_reason=stop_reason,
) )
) )
if self.config.create_distributed_process_group:
async with SEMAPHORE:
for x in impl():
yield x
else:
for x in impl():
yield x
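
Throughout the streaming methods above, the loose AsyncGenerator annotation is replaced with AsyncIterator of a concrete chunk type. A brief sketch of annotating an async generator that way, with Chunk standing in for the real stream-chunk models:

    import asyncio
    from collections.abc import AsyncIterator
    from dataclasses import dataclass


    @dataclass
    class Chunk:
        delta: str


    async def stream_tokens(text: str) -> AsyncIterator[Chunk]:
        # An async generator can be annotated with AsyncIterator[YieldType];
        # callers then see precise element types when iterating.
        for word in text.split():
            yield Chunk(delta=word)


    async def main() -> None:
        async for chunk in stream_tokens("typed streaming example"):
            print(chunk.delta)


    asyncio.run(main())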


@@ -31,7 +31,7 @@ log = logging.getLogger(__name__)
class SentenceTransformerEmbeddingMixin: class SentenceTransformerEmbeddingMixin:
model_store: ModelStore model_store: ModelStore | None = None
async def embeddings( async def embeddings(
self, self,
@@ -41,7 +41,11 @@ class SentenceTransformerEmbeddingMixin:
output_dimension: int | None = None, output_dimension: int | None = None,
task_type: EmbeddingTaskType | None = None, task_type: EmbeddingTaskType | None = None,
) -> EmbeddingsResponse: ) -> EmbeddingsResponse:
if self.model_store is None:
raise RuntimeError("Model store is not initialized")
model = await self.model_store.get_model(model_id) model = await self.model_store.get_model(model_id)
if model.provider_resource_id is None:
raise RuntimeError("Model provider resource ID is not set")
embedding_model = self._load_sentence_transformer_model(model.provider_resource_id) embedding_model = self._load_sentence_transformer_model(model.provider_resource_id)
embeddings = embedding_model.encode( embeddings = embedding_model.encode(
[interleaved_content_as_str(content) for content in contents], show_progress_bar=False [interleaved_content_as_str(content) for content in contents], show_progress_bar=False
@@ -62,7 +66,11 @@ class SentenceTransformerEmbeddingMixin:
raise ValueError("Empty list not supported") raise ValueError("Empty list not supported")
# Get the model and generate embeddings # Get the model and generate embeddings
if self.model_store is None:
raise RuntimeError("Model store is not initialized")
model_obj = await self.model_store.get_model(model) model_obj = await self.model_store.get_model(model)
if model_obj.provider_resource_id is None:
raise RuntimeError("Model provider resource ID is not set")
embedding_model = self._load_sentence_transformer_model(model_obj.provider_resource_id) embedding_model = self._load_sentence_transformer_model(model_obj.provider_resource_id)
embeddings = embedding_model.encode(input_list, show_progress_bar=False) embeddings = embedding_model.encode(input_list, show_progress_bar=False)
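
The embedding mixin changes above guard model_store and provider_resource_id before use because both are Optional in the API models. A compact sketch of that raise-early narrowing; the names here are illustrative, not the real llama_stack classes:

    from dataclasses import dataclass


    @dataclass
    class Model:
        identifier: str
        provider_resource_id: str | None = None


    def load_by_resource_id(resource_id: str) -> str:
        # Stand-in for loading an embedding model by its resource id.
        return f"loaded:{resource_id}"


    def embed(model: Model) -> str:
        # Raising before use narrows provider_resource_id from "str | None"
        # to "str", which satisfies callees that require a plain str.
        if model.provider_resource_id is None:
            raise RuntimeError("Model provider resource ID is not set")
        return load_by_resource_id(model.provider_resource_id)


    print(embed(Model(identifier="example-embedding-model", provider_resource_id="example-embedding-model")))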


@@ -246,7 +246,7 @@ exclude = [
"^llama_stack/providers/inline/agents/meta_reference/agents\\.py$", "^llama_stack/providers/inline/agents/meta_reference/agents\\.py$",
"^llama_stack/providers/inline/datasetio/localfs/", "^llama_stack/providers/inline/datasetio/localfs/",
"^llama_stack/providers/inline/eval/meta_reference/eval\\.py$", "^llama_stack/providers/inline/eval/meta_reference/eval\\.py$",
"^llama_stack/providers/inline/inference/meta_reference/inference\\.py$", "^llama_stack/providers/inline/inference/meta_reference/config\\.py$",
"^llama_stack/models/llama/llama3/generation\\.py$", "^llama_stack/models/llama/llama3/generation\\.py$",
"^llama_stack/models/llama/llama3/multimodal/model\\.py$", "^llama_stack/models/llama/llama3/multimodal/model\\.py$",
"^llama_stack/models/llama/llama4/", "^llama_stack/models/llama/llama4/",