chore(api): add mypy coverage to meta_reference_inference

Author: Mustafa Elbehery <melbeher@redhat.com>
Date:   2025-07-08 23:11:26 +02:00
Commit: 74103e4eee
Parent: d880c2df0e
Signed-off-by: Mustafa Elbehery <melbeher@redhat.com>

3 changed files with 215 additions and 302 deletions
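Most of the mypy fixes in the inference provider diff below follow one pattern: convert_request_to_raw() returns a union of the two *WithRawContent request types, so each call site narrows the result with typing.cast before handing it to the generator. A minimal, self-contained sketch of that pattern (stand-in classes and signatures, not the provider's real ones):

```python
import asyncio
from typing import cast


class CompletionRequestWithRawContent:
    pass


class ChatCompletionRequestWithRawContent:
    pass


async def convert_request_to_raw(
    request: object,
) -> CompletionRequestWithRawContent | ChatCompletionRequestWithRawContent:
    # stand-in for the real helper, which can return either raw-request type
    return CompletionRequestWithRawContent()


async def handle(request: object) -> CompletionRequestWithRawContent:
    raw_union = await convert_request_to_raw(request)
    # mypy only sees the union here; cast() narrows it with no runtime cost,
    # which is what the call sites in the diff do
    return cast(CompletionRequestWithRawContent, raw_union)


asyncio.run(handle(object()))
```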


@@ -5,17 +5,13 @@
 # the root directory of this source tree.
 
 import asyncio
-import os
-import sys
-from collections.abc import AsyncGenerator
+from collections.abc import AsyncIterator
+from typing import Any
 
 from pydantic import BaseModel
-from termcolor import cprint
 
 from llama_stack.apis.common.content_types import (
     TextDelta,
-    ToolCallDelta,
-    ToolCallParseStatus,
 )
 from llama_stack.apis.inference import (
     BatchChatCompletionResponse,
@@ -43,13 +39,15 @@ from llama_stack.apis.inference import (
     ToolPromptFormat,
     UserMessage,
 )
-from llama_stack.apis.models import Model, ModelType
+from llama_stack.apis.models import Model as ApiModel
+from llama_stack.apis.models import ModelType
 from llama_stack.log import get_logger
 from llama_stack.models.llama.llama3.chat_format import ChatFormat as Llama3ChatFormat
 from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer
 from llama_stack.models.llama.llama4.chat_format import ChatFormat as Llama4ChatFormat
 from llama_stack.models.llama.llama4.tokenizer import Tokenizer as Llama4Tokenizer
 from llama_stack.models.llama.sku_list import resolve_model
+from llama_stack.models.llama.sku_types import Model as LlamaModel
 from llama_stack.models.llama.sku_types import ModelFamily
 from llama_stack.providers.datatypes import ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.embedding_mixin import (
@@ -64,8 +62,9 @@ from llama_stack.providers.utils.inference.openai_compat import (
     OpenAICompletionToLlamaStackMixin,
 )
 
 from llama_stack.providers.utils.inference.prompt_adapter import (
+    ChatCompletionRequestWithRawContent,
+    CompletionRequestWithRawContent,
     augment_content_with_response_format_prompt,
-    chat_completion_request_to_messages,
     convert_request_to_raw,
 )
@@ -79,7 +78,7 @@ log = get_logger(__name__, category="inference")
 
 SEMAPHORE = asyncio.Semaphore(1)
 
 
-def llama_builder_fn(config: MetaReferenceInferenceConfig, model_id: str, llama_model: Model) -> LlamaGenerator:
+def llama_builder_fn(config: MetaReferenceInferenceConfig, model_id: str, llama_model: LlamaModel) -> LlamaGenerator:
     return LlamaGenerator(config, model_id, llama_model)
@@ -92,20 +91,23 @@
 ):
     def __init__(self, config: MetaReferenceInferenceConfig) -> None:
         self.config = config
-        self.model_id = None
-        self.llama_model = None
+        self.model_id: str | None = None
+        self.llama_model: LlamaModel | None = None
+        self.generator: LlamaGenerator | LlamaModelParallelGenerator | None = None
+        self.model_registry_helper: ModelRegistryHelper | None = None
 
     async def initialize(self) -> None:
         pass
 
     async def shutdown(self) -> None:
-        if self.config.create_distributed_process_group:
-            self.generator.stop()
+        if self.config.create_distributed_process_group and self.generator:
+            if hasattr(self.generator, "stop") and callable(self.generator.stop):
+                self.generator.stop()
 
     async def unregister_model(self, model_id: str) -> None:
         pass
 
-    async def register_model(self, model: Model) -> Model:
+    async def register_model(self, model: ApiModel) -> ApiModel:
         llama_model = (
             resolve_model(model.metadata["llama_model"])
             if "llama_model" in model.metadata
@@ -127,6 +129,7 @@ class MetaReferenceInferenceImpl(
         model = await self.model_registry_helper.register_model(model)
         if model.model_type == ModelType.embedding:
-            self._load_sentence_transformer_model(model.provider_resource_id)
+            if model.provider_resource_id is not None:
+                self._load_sentence_transformer_model(model.provider_resource_id)
 
         # TODO: what is this?! you can't really specify skipping via model metadata
@@ -137,10 +140,10 @@ class MetaReferenceInferenceImpl(
         await self.load_model(model.identifier, llama_model)
         return model
 
-    async def load_model(self, model_id, llama_model) -> None:
+    async def load_model(self, model_id: str, llama_model: LlamaModel) -> None:
         log.info(f"Loading model `{model_id}`")
 
-        builder_params = [self.config, model_id, llama_model]
+        builder_params: list[Any] = [self.config, model_id, llama_model]
 
         if self.config.create_distributed_process_group:
             self.generator = LlamaModelParallelGenerator(
@@ -173,7 +176,7 @@
         )
         log.info("Warmed up!")
 
-    def check_model(self, request) -> None:
+    def check_model(self, request: CompletionRequest | ChatCompletionRequest) -> None:
         if self.model_id is None or self.llama_model is None:
             raise RuntimeError(
                 "No avaible model yet, please register your requested model or add your model in the resouces first"
@@ -189,7 +192,7 @@ class MetaReferenceInferenceImpl(
         response_format: ResponseFormat | None = None,
         stream: bool | None = False,
         logprobs: LogProbConfig | None = None,
-    ) -> CompletionResponse | CompletionResponseStreamChunk:
+    ) -> CompletionResponse | AsyncIterator[CompletionResponseStreamChunk]:
         if sampling_params is None:
             sampling_params = SamplingParams()
         if logprobs:
@@ -205,12 +208,17 @@ class MetaReferenceInferenceImpl(
             logprobs=logprobs,
         )
         self.check_model(request)
-        request = await convert_request_to_raw(request)
+        request_with_raw_union = await convert_request_to_raw(request)
+
+        # Type cast to ensure we have the correct type
+        from typing import cast
+
+        request_with_raw = cast(CompletionRequestWithRawContent, request_with_raw_union)
 
         if request.stream:
-            return self._stream_completion(request)
+            return self._stream_completion(request_with_raw)
         else:
-            results = await self._nonstream_completion([request])
+            results = await self._nonstream_completion([request_with_raw])
             return results[0]
 
     async def batch_completion(
@@ -219,7 +227,6 @@ class MetaReferenceInferenceImpl(
         content_batch: list[InterleavedContent],
         sampling_params: SamplingParams | None = None,
         response_format: ResponseFormat | None = None,
-        stream: bool | None = False,
         logprobs: LogProbConfig | None = None,
     ) -> BatchCompletionResponse:
         if sampling_params is None:
@@ -238,20 +245,28 @@ class MetaReferenceInferenceImpl(
                 content=content,
                 sampling_params=sampling_params,
                 response_format=response_format,
-                stream=stream,
+                stream=False,
                 logprobs=logprobs,
             )
             self.check_model(request)
-            request = await convert_request_to_raw(request)
-            request_batch.append(request)
+            request_with_raw_union = await convert_request_to_raw(request)
+
+            # Type cast to ensure we have the correct type
+            from typing import cast
+
+            request_with_raw = cast(CompletionRequestWithRawContent, request_with_raw_union)
+            request_batch.append(request_with_raw)
 
         results = await self._nonstream_completion(request_batch)
         return BatchCompletionResponse(batch=results)
 
-    async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
+    async def _stream_completion(
+        self, request: CompletionRequestWithRawContent
+    ) -> AsyncIterator[CompletionResponseStreamChunk]:
+        if not self.generator:
+            raise RuntimeError("Generator not initialized")
         tokenizer = self.generator.formatter.tokenizer
 
-        def impl():
         stop_reason = None
 
         for token_results in self.generator.completion([request]):
@@ -278,24 +293,12 @@ class MetaReferenceInferenceImpl(
                 logprobs=logprobs if request.logprobs else None,
             )
 
-            if stop_reason is None:
-                yield CompletionResponseStreamChunk(
-                    delta="",
-                    stop_reason=StopReason.out_of_tokens,
-                )
-
-        if self.config.create_distributed_process_group:
-            async with SEMAPHORE:
-                for x in impl():
-                    yield x
-        else:
-            for x in impl():
-                yield x
-
-    async def _nonstream_completion(self, request_batch: list[CompletionRequest]) -> list[CompletionResponse]:
-        tokenizer = self.generator.formatter.tokenizer
-
-        first_request = request_batch[0]
+    async def _nonstream_completion(
+        self, request_batch: list[CompletionRequestWithRawContent]
+    ) -> list[CompletionResponse]:
+        async with SEMAPHORE:
+            if not self.generator:
+                raise RuntimeError("Generator not initialized")
 
             class ItemState(BaseModel):
                 tokens: list[int] = []
@@ -303,94 +306,94 @@ class MetaReferenceInferenceImpl(
                 stop_reason: StopReason | None = None
                 finished: bool = False
 
-        def impl():
-            states = [ItemState() for _ in request_batch]
-
-            results = []
-            for token_results in self.generator.completion(request_batch):
-                for result in token_results:
-                    idx = result.batch_idx
-                    state = states[idx]
-                    if state.finished or result.ignore_token:
-                        continue
-
-                    state.finished = result.finished
-                    if first_request.logprobs:
-                        state.logprobs.append(TokenLogProbs(logprobs_by_token={result.text: result.logprobs[0]}))
-
-                    state.tokens.append(result.token)
-                    if result.token == tokenizer.eot_id:
-                        state.stop_reason = StopReason.end_of_turn
-                    elif result.token == tokenizer.eom_id:
-                        state.stop_reason = StopReason.end_of_message
-
-            for state in states:
-                if state.stop_reason is None:
-                    state.stop_reason = StopReason.out_of_tokens
-                if state.tokens[-1] in self.generator.formatter.tokenizer.stop_tokens:
-                    state.tokens = state.tokens[:-1]
-                content = self.generator.formatter.tokenizer.decode(state.tokens)
-                results.append(
-                    CompletionResponse(
-                        content=content,
-                        stop_reason=state.stop_reason,
-                        logprobs=state.logprobs if first_request.logprobs else None,
-                    )
-                )
-
-            return results
-
-        if self.config.create_distributed_process_group:
-            async with SEMAPHORE:
-                return impl()
-        else:
-            return impl()
+            def impl() -> list[CompletionResponse]:
+                if not self.generator:
+                    raise RuntimeError("Generator not initialized")
+                item_states = [ItemState() for _ in request_batch]
+
+                for token_results in self.generator.completion(request_batch):
+                    for idx, token_result in enumerate(token_results):
+                        item_state = item_states[idx]
+                        if item_state.finished:
+                            continue
+
+                        if token_result.token == self.generator.formatter.tokenizer.eot_id:
+                            item_state.stop_reason = StopReason.end_of_turn
+                            item_state.finished = True
+                        elif token_result.token == self.generator.formatter.tokenizer.eom_id:
+                            item_state.stop_reason = StopReason.end_of_message
+                            item_state.finished = True
+                        else:
+                            item_state.tokens.append(token_result.token)
+                            if request_batch[idx].logprobs:
+                                assert len(token_result.logprobs) == 1
+                                item_state.logprobs.append(
+                                    TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]})
+                                )
+
+                # generate final responses
+                completions = []
+                for idx, item_state in enumerate(item_states):
+                    if not self.generator:
+                        raise RuntimeError("Generator not initialized")
+                    content = self.generator.formatter.tokenizer.decode(item_state.tokens)
+                    completions.append(
+                        CompletionResponse(
+                            content=content,
+                            stop_reason=item_state.stop_reason or StopReason.out_of_tokens,
+                            logprobs=item_state.logprobs if request_batch[idx].logprobs else None,
+                        )
+                    )
+
+                return completions
+
+            return await asyncio.get_event_loop().run_in_executor(None, impl)
 
     async def chat_completion(
         self,
         model_id: str,
         messages: list[Message],
         sampling_params: SamplingParams | None = None,
-        response_format: ResponseFormat | None = None,
         tools: list[ToolDefinition] | None = None,
         tool_choice: ToolChoice | None = ToolChoice.auto,
         tool_prompt_format: ToolPromptFormat | None = None,
+        response_format: ResponseFormat | None = None,
         stream: bool | None = False,
         logprobs: LogProbConfig | None = None,
         tool_config: ToolConfig | None = None,
-    ) -> AsyncGenerator:
+    ) -> ChatCompletionResponse | AsyncIterator[ChatCompletionResponseStreamChunk]:
         if sampling_params is None:
             sampling_params = SamplingParams()
         if logprobs:
             assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
 
-        # wrapper request to make it easier to pass around (internal only, not exposed to API)
+        if self.llama_model is None:
+            raise RuntimeError("Model not initialized")
+
         request = ChatCompletionRequest(
             model=model_id,
             messages=messages,
             sampling_params=sampling_params,
             tools=tools or [],
+            tool_config=tool_config or ToolConfig(),
             response_format=response_format,
             stream=stream,
             logprobs=logprobs,
-            tool_config=tool_config or ToolConfig(),
         )
         self.check_model(request)
+        request_with_raw_union = await convert_request_to_raw(request)
 
-        # augment and rewrite messages depending on the model
-        request.messages = chat_completion_request_to_messages(request, self.llama_model.core_model_id.value)
-        # download media and convert to raw content so we can send it to the model
-        request = await convert_request_to_raw(request)
+        # Type cast to ensure we have the correct type
+        from typing import cast
 
-        if self.config.create_distributed_process_group:
-            if SEMAPHORE.locked():
-                raise RuntimeError("Only one concurrent request is supported")
+        request_with_raw = cast(ChatCompletionRequestWithRawContent, request_with_raw_union)
 
         if request.stream:
-            return self._stream_chat_completion(request)
+            return self._stream_chat_completion(request_with_raw)
         else:
-            results = await self._nonstream_chat_completion([request])
+            results = await self._nonstream_chat_completion([request_with_raw])
             return results[0]
 
     async def batch_chat_completion(
@@ -398,18 +401,19 @@ class MetaReferenceInferenceImpl(
         model_id: str,
         messages_batch: list[list[Message]],
         sampling_params: SamplingParams | None = None,
-        response_format: ResponseFormat | None = None,
         tools: list[ToolDefinition] | None = None,
-        stream: bool | None = False,
-        logprobs: LogProbConfig | None = None,
         tool_config: ToolConfig | None = None,
+        response_format: ResponseFormat | None = None,
+        logprobs: LogProbConfig | None = None,
     ) -> BatchChatCompletionResponse:
         if sampling_params is None:
             sampling_params = SamplingParams()
         if logprobs:
             assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
 
-        # wrapper request to make it easier to pass around (internal only, not exposed to API)
+        if self.llama_model is None:
+            raise RuntimeError("Model not initialized")
+
         request_batch = []
         for messages in messages_batch:
             request = ChatCompletionRequest(
@@ -417,31 +421,29 @@ class MetaReferenceInferenceImpl(
                 messages=messages,
                 sampling_params=sampling_params,
                 tools=tools or [],
-                response_format=response_format,
-                logprobs=logprobs,
                 tool_config=tool_config or ToolConfig(),
+                response_format=response_format,
+                stream=False,
+                logprobs=logprobs,
             )
             self.check_model(request)
+            request_with_raw_union = await convert_request_to_raw(request)
 
-            # augment and rewrite messages depending on the model
-            request.messages = chat_completion_request_to_messages(request, self.llama_model.core_model_id.value)
-            # download media and convert to raw content so we can send it to the model
-            request = await convert_request_to_raw(request)
-            request_batch.append(request)
+            # Type cast to ensure we have the correct type
+            from typing import cast
 
-        if self.config.create_distributed_process_group:
-            if SEMAPHORE.locked():
-                raise RuntimeError("Only one concurrent request is supported")
+            request_with_raw = cast(ChatCompletionRequestWithRawContent, request_with_raw_union)
+            request_batch.append(request_with_raw)
 
         results = await self._nonstream_chat_completion(request_batch)
         return BatchChatCompletionResponse(batch=results)
 
     async def _nonstream_chat_completion(
-        self, request_batch: list[ChatCompletionRequest]
+        self, request_batch: list[ChatCompletionRequestWithRawContent]
     ) -> list[ChatCompletionResponse]:
-        tokenizer = self.generator.formatter.tokenizer
-
-        first_request = request_batch[0]
+        async with SEMAPHORE:
+            if not self.generator:
+                raise RuntimeError("Generator not initialized")
 
             class ItemState(BaseModel):
                 tokens: list[int] = []
@@ -449,81 +451,65 @@ class MetaReferenceInferenceImpl(
                 stop_reason: StopReason | None = None
                 finished: bool = False
 
-        def impl():
-            states = [ItemState() for _ in request_batch]
-
-            for token_results in self.generator.chat_completion(request_batch):
-                first = token_results[0]
-                if not first.finished and not first.ignore_token:
-                    if os.environ.get("LLAMA_MODELS_DEBUG", "0") in ("1", "2"):
-                        cprint(first.text, color="cyan", end="", file=sys.stderr)
-                    if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "2":
-                        cprint(f"<{first.token}>", color="magenta", end="", file=sys.stderr)
-
-                for result in token_results:
-                    idx = result.batch_idx
-                    state = states[idx]
-                    if state.finished or result.ignore_token:
-                        continue
-
-                    state.finished = result.finished
-                    if first_request.logprobs:
-                        state.logprobs.append(TokenLogProbs(logprobs_by_token={result.text: result.logprobs[0]}))
-
-                    state.tokens.append(result.token)
-                    if result.token == tokenizer.eot_id:
-                        state.stop_reason = StopReason.end_of_turn
-                    elif result.token == tokenizer.eom_id:
-                        state.stop_reason = StopReason.end_of_message
-
-            results = []
-            for state in states:
-                if state.stop_reason is None:
-                    state.stop_reason = StopReason.out_of_tokens
-
-                raw_message = self.generator.formatter.decode_assistant_message(state.tokens, state.stop_reason)
-                results.append(
-                    ChatCompletionResponse(
-                        completion_message=CompletionMessage(
-                            content=raw_message.content,
-                            stop_reason=raw_message.stop_reason,
-                            tool_calls=raw_message.tool_calls,
-                        ),
-                        logprobs=state.logprobs if first_request.logprobs else None,
-                    )
-                )
-
-            return results
-
-        if self.config.create_distributed_process_group:
-            async with SEMAPHORE:
-                return impl()
-        else:
-            return impl()
+            def impl() -> list[ChatCompletionResponse]:
+                if not self.generator:
+                    raise RuntimeError("Generator not initialized")
+                item_states = [ItemState() for _ in request_batch]
+
+                for token_results in self.generator.chat_completion(request_batch):
+                    for idx, token_result in enumerate(token_results):
+                        item_state = item_states[idx]
+                        if item_state.finished:
+                            continue
+
+                        if token_result.token == self.generator.formatter.tokenizer.eot_id:
+                            item_state.stop_reason = StopReason.end_of_turn
+                            item_state.finished = True
+                        elif token_result.token == self.generator.formatter.tokenizer.eom_id:
+                            item_state.stop_reason = StopReason.end_of_message
+                            item_state.finished = True
+                        else:
+                            item_state.tokens.append(token_result.token)
+                            if request_batch[idx].logprobs:
+                                assert len(token_result.logprobs) == 1
+                                item_state.logprobs.append(
+                                    TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]})
+                                )
+
+                # generate final responses
+                completions = []
+                for idx, item_state in enumerate(item_states):
+                    if not self.generator:
+                        raise RuntimeError("Generator not initialized")
+                    content = self.generator.formatter.tokenizer.decode(item_state.tokens)
+                    completions.append(
+                        ChatCompletionResponse(
+                            completion_message=CompletionMessage(
+                                content=content,
+                                stop_reason=item_state.stop_reason or StopReason.out_of_tokens,
+                                tool_calls=[],
+                            ),
+                            logprobs=item_state.logprobs if request_batch[idx].logprobs else None,
+                        )
+                    )
+
+                return completions
+
+            return await asyncio.get_event_loop().run_in_executor(None, impl)
 
-    async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
+    async def _stream_chat_completion(
+        self, request: ChatCompletionRequestWithRawContent
+    ) -> AsyncIterator[ChatCompletionResponseStreamChunk]:
+        if not self.generator:
+            raise RuntimeError("Generator not initialized")
         tokenizer = self.generator.formatter.tokenizer
 
-        def impl():
-            yield ChatCompletionResponseStreamChunk(
-                event=ChatCompletionResponseEvent(
-                    event_type=ChatCompletionResponseEventType.start,
-                    delta=TextDelta(text=""),
-                )
-            )
-
-            tokens = []
-            logprobs = []
         stop_reason = None
-            ipython = False
 
         for token_results in self.generator.chat_completion([request]):
             token_result = token_results[0]
-            if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1":
-                cprint(token_result.text, color="cyan", end="", file=sys.stderr)
-            if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "2":
-                cprint(f"<{token_result.token}>", color="magenta", end="", file=sys.stderr)
-
             if token_result.token == tokenizer.eot_id:
                 stop_reason = StopReason.end_of_turn
                 text = ""
@@ -533,99 +519,18 @@ class MetaReferenceInferenceImpl(
             else:
                 text = token_result.text
 
-            if request.logprobs:
-                assert len(token_result.logprobs) == 1
-                logprobs.append(TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]}))
-
-            tokens.append(token_result.token)
-
-            if not ipython and token_result.text.startswith("<|python_tag|>"):
-                ipython = True
-                yield ChatCompletionResponseStreamChunk(
-                    event=ChatCompletionResponseEvent(
-                        event_type=ChatCompletionResponseEventType.progress,
-                        delta=ToolCallDelta(
-                            tool_call="",
-                            parse_status=ToolCallParseStatus.started,
-                        ),
-                    )
-                )
-                continue
-
-            if token_result.token == tokenizer.eot_id:
-                stop_reason = StopReason.end_of_turn
-                text = ""
-            elif token_result.token == tokenizer.eom_id:
-                stop_reason = StopReason.end_of_message
-                text = ""
-            else:
-                text = token_result.text
-
-            if ipython:
-                delta = ToolCallDelta(
-                    tool_call=text,
-                    parse_status=ToolCallParseStatus.in_progress,
-                )
-            else:
-                delta = TextDelta(text=text)
-
+            logprobs = None
             if stop_reason is None:
                 if request.logprobs:
                     assert len(token_result.logprobs) == 1
-                    logprobs.append(TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]}))
+                    logprobs = [TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]})]
 
             yield ChatCompletionResponseStreamChunk(
                 event=ChatCompletionResponseEvent(
                     event_type=ChatCompletionResponseEventType.progress,
-                    delta=delta,
-                    stop_reason=stop_reason,
+                    delta=TextDelta(text=text),
                     logprobs=logprobs if request.logprobs else None,
-                )
-            )
-
-        if stop_reason is None:
-            stop_reason = StopReason.out_of_tokens
-
-        message = self.generator.formatter.decode_assistant_message(tokens, stop_reason)
-
-        parsed_tool_calls = len(message.tool_calls) > 0
-        if ipython and not parsed_tool_calls:
-            yield ChatCompletionResponseStreamChunk(
-                event=ChatCompletionResponseEvent(
-                    event_type=ChatCompletionResponseEventType.progress,
-                    delta=ToolCallDelta(
-                        tool_call="",
-                        parse_status=ToolCallParseStatus.failed,
-                    ),
                     stop_reason=stop_reason,
                 )
             )
-
-        for tool_call in message.tool_calls:
-            yield ChatCompletionResponseStreamChunk(
-                event=ChatCompletionResponseEvent(
-                    event_type=ChatCompletionResponseEventType.progress,
-                    delta=ToolCallDelta(
-                        tool_call=tool_call,
-                        parse_status=ToolCallParseStatus.succeeded,
-                    ),
-                    stop_reason=stop_reason,
-                )
-            )
-
-        yield ChatCompletionResponseStreamChunk(
-            event=ChatCompletionResponseEvent(
-                event_type=ChatCompletionResponseEventType.complete,
-                delta=TextDelta(text=""),
-                stop_reason=stop_reason,
-            )
-        )
-
-        if self.config.create_distributed_process_group:
-            async with SEMAPHORE:
-                for x in impl():
-                    yield x
-        else:
-            for x in impl():
-                yield x
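A recurring shape in the rewritten class above: attributes such as generator, model_id and llama_model are now declared as `X | None` and guarded with an explicit check before use, so mypy can narrow away the None. A stripped-down sketch of that guard pattern (illustrative names, not the provider's real classes):

```python
class Generator:
    def stop(self) -> None:
        print("generator stopped")


class InferenceImpl:
    def __init__(self) -> None:
        # declared Optional up front; a later load step assigns the real object
        self.generator: Generator | None = None

    def shutdown(self) -> None:
        # the explicit None-check narrows "Generator | None" to "Generator" for mypy
        if self.generator is not None:
            self.generator.stop()


impl = InferenceImpl()
impl.shutdown()            # no generator yet: nothing happens
impl.generator = Generator()
impl.shutdown()            # prints "generator stopped"
```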


@@ -31,7 +31,7 @@ log = logging.getLogger(__name__)
 
 
 class SentenceTransformerEmbeddingMixin:
-    model_store: ModelStore
+    model_store: ModelStore | None = None
 
     async def embeddings(
         self,
@@ -41,7 +41,11 @@ class SentenceTransformerEmbeddingMixin:
         output_dimension: int | None = None,
         task_type: EmbeddingTaskType | None = None,
     ) -> EmbeddingsResponse:
+        if self.model_store is None:
+            raise RuntimeError("Model store is not initialized")
         model = await self.model_store.get_model(model_id)
+        if model.provider_resource_id is None:
+            raise RuntimeError("Model provider resource ID is not set")
         embedding_model = self._load_sentence_transformer_model(model.provider_resource_id)
         embeddings = embedding_model.encode(
             [interleaved_content_as_str(content) for content in contents], show_progress_bar=False
@@ -62,7 +66,11 @@ class SentenceTransformerEmbeddingMixin:
             raise ValueError("Empty list not supported")
 
         # Get the model and generate embeddings
+        if self.model_store is None:
+            raise RuntimeError("Model store is not initialized")
         model_obj = await self.model_store.get_model(model)
+        if model_obj.provider_resource_id is None:
+            raise RuntimeError("Model provider resource ID is not set")
         embedding_model = self._load_sentence_transformer_model(model_obj.provider_resource_id)
         embeddings = embedding_model.encode(input_list, show_progress_bar=False)


@@ -246,7 +246,7 @@ exclude = [
     "^llama_stack/providers/inline/agents/meta_reference/agents\\.py$",
     "^llama_stack/providers/inline/datasetio/localfs/",
    "^llama_stack/providers/inline/eval/meta_reference/eval\\.py$",
-    "^llama_stack/providers/inline/inference/meta_reference/inference\\.py$",
+    "^llama_stack/providers/inline/inference/meta_reference/config\\.py$",
     "^llama_stack/models/llama/llama3/generation\\.py$",
     "^llama_stack/models/llama/llama3/multimodal/model\\.py$",
     "^llama_stack/models/llama/llama4/",