mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-28 15:02:37 +00:00
chore(api): add mypy coverage to meta_reference_inference
Signed-off-by: Mustafa Elbehery <melbeher@redhat.com>
This commit is contained in:
parent
d880c2df0e
commit
74103e4eee
3 changed files with 215 additions and 302 deletions
|
@ -5,17 +5,13 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
from collections.abc import AsyncGenerator
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
TextDelta,
|
||||
ToolCallDelta,
|
||||
ToolCallParseStatus,
|
||||
)
|
||||
from llama_stack.apis.inference import (
|
||||
BatchChatCompletionResponse,
|
||||
|
@ -43,13 +39,15 @@ from llama_stack.apis.inference import (
|
|||
ToolPromptFormat,
|
||||
UserMessage,
|
||||
)
|
||||
from llama_stack.apis.models import Model, ModelType
|
||||
from llama_stack.apis.models import Model as ApiModel
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.models.llama.llama3.chat_format import ChatFormat as Llama3ChatFormat
|
||||
from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer
|
||||
from llama_stack.models.llama.llama4.chat_format import ChatFormat as Llama4ChatFormat
|
||||
from llama_stack.models.llama.llama4.tokenizer import Tokenizer as Llama4Tokenizer
|
||||
from llama_stack.models.llama.sku_list import resolve_model
|
||||
from llama_stack.models.llama.sku_types import Model as LlamaModel
|
||||
from llama_stack.models.llama.sku_types import ModelFamily
|
||||
from llama_stack.providers.datatypes import ModelsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.embedding_mixin import (
|
||||
|
@ -64,8 +62,9 @@ from llama_stack.providers.utils.inference.openai_compat import (
|
|||
OpenAICompletionToLlamaStackMixin,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
ChatCompletionRequestWithRawContent,
|
||||
CompletionRequestWithRawContent,
|
||||
augment_content_with_response_format_prompt,
|
||||
chat_completion_request_to_messages,
|
||||
convert_request_to_raw,
|
||||
)
|
||||
|
||||
|
@ -79,7 +78,7 @@ log = get_logger(__name__, category="inference")
|
|||
SEMAPHORE = asyncio.Semaphore(1)
|
||||
|
||||
|
||||
def llama_builder_fn(config: MetaReferenceInferenceConfig, model_id: str, llama_model: Model) -> LlamaGenerator:
|
||||
def llama_builder_fn(config: MetaReferenceInferenceConfig, model_id: str, llama_model: LlamaModel) -> LlamaGenerator:
|
||||
return LlamaGenerator(config, model_id, llama_model)
|
||||
|
||||
|
||||
|
@ -92,20 +91,23 @@ class MetaReferenceInferenceImpl(
|
|||
):
|
||||
def __init__(self, config: MetaReferenceInferenceConfig) -> None:
|
||||
self.config = config
|
||||
self.model_id = None
|
||||
self.llama_model = None
|
||||
self.model_id: str | None = None
|
||||
self.llama_model: LlamaModel | None = None
|
||||
self.generator: LlamaGenerator | LlamaModelParallelGenerator | None = None
|
||||
self.model_registry_helper: ModelRegistryHelper | None = None
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
if self.config.create_distributed_process_group:
|
||||
self.generator.stop()
|
||||
if self.config.create_distributed_process_group and self.generator:
|
||||
if hasattr(self.generator, "stop") and callable(self.generator.stop):
|
||||
self.generator.stop()
|
||||
|
||||
async def unregister_model(self, model_id: str) -> None:
|
||||
pass
|
||||
|
||||
async def register_model(self, model: Model) -> Model:
|
||||
async def register_model(self, model: ApiModel) -> ApiModel:
|
||||
llama_model = (
|
||||
resolve_model(model.metadata["llama_model"])
|
||||
if "llama_model" in model.metadata
|
||||
|
@ -127,7 +129,8 @@ class MetaReferenceInferenceImpl(
|
|||
model = await self.model_registry_helper.register_model(model)
|
||||
|
||||
if model.model_type == ModelType.embedding:
|
||||
self._load_sentence_transformer_model(model.provider_resource_id)
|
||||
if model.provider_resource_id is not None:
|
||||
self._load_sentence_transformer_model(model.provider_resource_id)
|
||||
|
||||
# TODO: what is this?! you can't really specify skipping via model metadata
|
||||
# kill this madness
|
||||
|
@ -137,10 +140,10 @@ class MetaReferenceInferenceImpl(
|
|||
await self.load_model(model.identifier, llama_model)
|
||||
return model
|
||||
|
||||
async def load_model(self, model_id, llama_model) -> None:
|
||||
async def load_model(self, model_id: str, llama_model: LlamaModel) -> None:
|
||||
log.info(f"Loading model `{model_id}`")
|
||||
|
||||
builder_params = [self.config, model_id, llama_model]
|
||||
builder_params: list[Any] = [self.config, model_id, llama_model]
|
||||
|
||||
if self.config.create_distributed_process_group:
|
||||
self.generator = LlamaModelParallelGenerator(
|
||||
|
@ -173,7 +176,7 @@ class MetaReferenceInferenceImpl(
|
|||
)
|
||||
log.info("Warmed up!")
|
||||
|
||||
def check_model(self, request) -> None:
|
||||
def check_model(self, request: CompletionRequest | ChatCompletionRequest) -> None:
|
||||
if self.model_id is None or self.llama_model is None:
|
||||
raise RuntimeError(
|
||||
"No avaible model yet, please register your requested model or add your model in the resouces first"
|
||||
|
@ -189,7 +192,7 @@ class MetaReferenceInferenceImpl(
|
|||
response_format: ResponseFormat | None = None,
|
||||
stream: bool | None = False,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
) -> CompletionResponse | CompletionResponseStreamChunk:
|
||||
) -> CompletionResponse | AsyncIterator[CompletionResponseStreamChunk]:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
if logprobs:
|
||||
|
@ -205,12 +208,17 @@ class MetaReferenceInferenceImpl(
|
|||
logprobs=logprobs,
|
||||
)
|
||||
self.check_model(request)
|
||||
request = await convert_request_to_raw(request)
|
||||
request_with_raw_union = await convert_request_to_raw(request)
|
||||
|
||||
# Type cast to ensure we have the correct type
|
||||
from typing import cast
|
||||
|
||||
request_with_raw = cast(CompletionRequestWithRawContent, request_with_raw_union)
|
||||
|
||||
if request.stream:
|
||||
return self._stream_completion(request)
|
||||
return self._stream_completion(request_with_raw)
|
||||
else:
|
||||
results = await self._nonstream_completion([request])
|
||||
results = await self._nonstream_completion([request_with_raw])
|
||||
return results[0]
|
||||
|
||||
async def batch_completion(
|
||||
|
@ -219,7 +227,6 @@ class MetaReferenceInferenceImpl(
|
|||
content_batch: list[InterleavedContent],
|
||||
sampling_params: SamplingParams | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
stream: bool | None = False,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
) -> BatchCompletionResponse:
|
||||
if sampling_params is None:
|
||||
|
@ -238,159 +245,155 @@ class MetaReferenceInferenceImpl(
|
|||
content=content,
|
||||
sampling_params=sampling_params,
|
||||
response_format=response_format,
|
||||
stream=stream,
|
||||
stream=False,
|
||||
logprobs=logprobs,
|
||||
)
|
||||
self.check_model(request)
|
||||
request = await convert_request_to_raw(request)
|
||||
request_batch.append(request)
|
||||
request_with_raw_union = await convert_request_to_raw(request)
|
||||
|
||||
# Type cast to ensure we have the correct type
|
||||
from typing import cast
|
||||
|
||||
request_with_raw = cast(CompletionRequestWithRawContent, request_with_raw_union)
|
||||
request_batch.append(request_with_raw)
|
||||
|
||||
results = await self._nonstream_completion(request_batch)
|
||||
return BatchCompletionResponse(batch=results)
|
||||
|
||||
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
|
||||
async def _stream_completion(
|
||||
self, request: CompletionRequestWithRawContent
|
||||
) -> AsyncIterator[CompletionResponseStreamChunk]:
|
||||
if not self.generator:
|
||||
raise RuntimeError("Generator not initialized")
|
||||
tokenizer = self.generator.formatter.tokenizer
|
||||
|
||||
def impl():
|
||||
stop_reason = None
|
||||
stop_reason = None
|
||||
|
||||
for token_results in self.generator.completion([request]):
|
||||
token_result = token_results[0]
|
||||
if token_result.token == tokenizer.eot_id:
|
||||
stop_reason = StopReason.end_of_turn
|
||||
text = ""
|
||||
elif token_result.token == tokenizer.eom_id:
|
||||
stop_reason = StopReason.end_of_message
|
||||
text = ""
|
||||
else:
|
||||
text = token_result.text
|
||||
|
||||
logprobs = None
|
||||
if stop_reason is None:
|
||||
if request.logprobs:
|
||||
assert len(token_result.logprobs) == 1
|
||||
|
||||
logprobs = [TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]})]
|
||||
|
||||
yield CompletionResponseStreamChunk(
|
||||
delta=text,
|
||||
stop_reason=stop_reason,
|
||||
logprobs=logprobs if request.logprobs else None,
|
||||
)
|
||||
for token_results in self.generator.completion([request]):
|
||||
token_result = token_results[0]
|
||||
if token_result.token == tokenizer.eot_id:
|
||||
stop_reason = StopReason.end_of_turn
|
||||
text = ""
|
||||
elif token_result.token == tokenizer.eom_id:
|
||||
stop_reason = StopReason.end_of_message
|
||||
text = ""
|
||||
else:
|
||||
text = token_result.text
|
||||
|
||||
logprobs = None
|
||||
if stop_reason is None:
|
||||
yield CompletionResponseStreamChunk(
|
||||
delta="",
|
||||
stop_reason=StopReason.out_of_tokens,
|
||||
)
|
||||
if request.logprobs:
|
||||
assert len(token_result.logprobs) == 1
|
||||
|
||||
if self.config.create_distributed_process_group:
|
||||
async with SEMAPHORE:
|
||||
for x in impl():
|
||||
yield x
|
||||
else:
|
||||
for x in impl():
|
||||
yield x
|
||||
logprobs = [TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]})]
|
||||
|
||||
async def _nonstream_completion(self, request_batch: list[CompletionRequest]) -> list[CompletionResponse]:
|
||||
tokenizer = self.generator.formatter.tokenizer
|
||||
yield CompletionResponseStreamChunk(
|
||||
delta=text,
|
||||
stop_reason=stop_reason,
|
||||
logprobs=logprobs if request.logprobs else None,
|
||||
)
|
||||
|
||||
first_request = request_batch[0]
|
||||
async def _nonstream_completion(
|
||||
self, request_batch: list[CompletionRequestWithRawContent]
|
||||
) -> list[CompletionResponse]:
|
||||
async with SEMAPHORE:
|
||||
if not self.generator:
|
||||
raise RuntimeError("Generator not initialized")
|
||||
|
||||
class ItemState(BaseModel):
|
||||
tokens: list[int] = []
|
||||
logprobs: list[TokenLogProbs] = []
|
||||
stop_reason: StopReason | None = None
|
||||
finished: bool = False
|
||||
class ItemState(BaseModel):
|
||||
tokens: list[int] = []
|
||||
logprobs: list[TokenLogProbs] = []
|
||||
stop_reason: StopReason | None = None
|
||||
finished: bool = False
|
||||
|
||||
def impl():
|
||||
states = [ItemState() for _ in request_batch]
|
||||
def impl() -> list[CompletionResponse]:
|
||||
if not self.generator:
|
||||
raise RuntimeError("Generator not initialized")
|
||||
|
||||
results = []
|
||||
for token_results in self.generator.completion(request_batch):
|
||||
for result in token_results:
|
||||
idx = result.batch_idx
|
||||
state = states[idx]
|
||||
if state.finished or result.ignore_token:
|
||||
continue
|
||||
item_states = [ItemState() for _ in request_batch]
|
||||
|
||||
state.finished = result.finished
|
||||
if first_request.logprobs:
|
||||
state.logprobs.append(TokenLogProbs(logprobs_by_token={result.text: result.logprobs[0]}))
|
||||
for token_results in self.generator.completion(request_batch):
|
||||
for idx, token_result in enumerate(token_results):
|
||||
item_state = item_states[idx]
|
||||
if item_state.finished:
|
||||
continue
|
||||
|
||||
state.tokens.append(result.token)
|
||||
if result.token == tokenizer.eot_id:
|
||||
state.stop_reason = StopReason.end_of_turn
|
||||
elif result.token == tokenizer.eom_id:
|
||||
state.stop_reason = StopReason.end_of_message
|
||||
if token_result.token == self.generator.formatter.tokenizer.eot_id:
|
||||
item_state.stop_reason = StopReason.end_of_turn
|
||||
item_state.finished = True
|
||||
elif token_result.token == self.generator.formatter.tokenizer.eom_id:
|
||||
item_state.stop_reason = StopReason.end_of_message
|
||||
item_state.finished = True
|
||||
else:
|
||||
item_state.tokens.append(token_result.token)
|
||||
if request_batch[idx].logprobs:
|
||||
assert len(token_result.logprobs) == 1
|
||||
item_state.logprobs.append(
|
||||
TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]})
|
||||
)
|
||||
|
||||
for state in states:
|
||||
if state.stop_reason is None:
|
||||
state.stop_reason = StopReason.out_of_tokens
|
||||
# generate final responses
|
||||
completions = []
|
||||
for idx, item_state in enumerate(item_states):
|
||||
if not self.generator:
|
||||
raise RuntimeError("Generator not initialized")
|
||||
content = self.generator.formatter.tokenizer.decode(item_state.tokens)
|
||||
|
||||
if state.tokens[-1] in self.generator.formatter.tokenizer.stop_tokens:
|
||||
state.tokens = state.tokens[:-1]
|
||||
content = self.generator.formatter.tokenizer.decode(state.tokens)
|
||||
results.append(
|
||||
CompletionResponse(
|
||||
content=content,
|
||||
stop_reason=state.stop_reason,
|
||||
logprobs=state.logprobs if first_request.logprobs else None,
|
||||
completions.append(
|
||||
CompletionResponse(
|
||||
content=content,
|
||||
stop_reason=item_state.stop_reason or StopReason.out_of_tokens,
|
||||
logprobs=item_state.logprobs if request_batch[idx].logprobs else None,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
return results
|
||||
return completions
|
||||
|
||||
if self.config.create_distributed_process_group:
|
||||
async with SEMAPHORE:
|
||||
return impl()
|
||||
else:
|
||||
return impl()
|
||||
return await asyncio.get_event_loop().run_in_executor(None, impl)
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
messages: list[Message],
|
||||
sampling_params: SamplingParams | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
tool_choice: ToolChoice | None = ToolChoice.auto,
|
||||
tool_prompt_format: ToolPromptFormat | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
stream: bool | None = False,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
tool_config: ToolConfig | None = None,
|
||||
) -> AsyncGenerator:
|
||||
) -> ChatCompletionResponse | AsyncIterator[ChatCompletionResponseStreamChunk]:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
if logprobs:
|
||||
assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
|
||||
|
||||
# wrapper request to make it easier to pass around (internal only, not exposed to API)
|
||||
if self.llama_model is None:
|
||||
raise RuntimeError("Model not initialized")
|
||||
|
||||
request = ChatCompletionRequest(
|
||||
model=model_id,
|
||||
messages=messages,
|
||||
sampling_params=sampling_params,
|
||||
tools=tools or [],
|
||||
tool_config=tool_config or ToolConfig(),
|
||||
response_format=response_format,
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
tool_config=tool_config or ToolConfig(),
|
||||
)
|
||||
self.check_model(request)
|
||||
request_with_raw_union = await convert_request_to_raw(request)
|
||||
|
||||
# augment and rewrite messages depending on the model
|
||||
request.messages = chat_completion_request_to_messages(request, self.llama_model.core_model_id.value)
|
||||
# download media and convert to raw content so we can send it to the model
|
||||
request = await convert_request_to_raw(request)
|
||||
# Type cast to ensure we have the correct type
|
||||
from typing import cast
|
||||
|
||||
if self.config.create_distributed_process_group:
|
||||
if SEMAPHORE.locked():
|
||||
raise RuntimeError("Only one concurrent request is supported")
|
||||
request_with_raw = cast(ChatCompletionRequestWithRawContent, request_with_raw_union)
|
||||
|
||||
if request.stream:
|
||||
return self._stream_chat_completion(request)
|
||||
return self._stream_chat_completion(request_with_raw)
|
||||
else:
|
||||
results = await self._nonstream_chat_completion([request])
|
||||
results = await self._nonstream_chat_completion([request_with_raw])
|
||||
return results[0]
|
||||
|
||||
async def batch_chat_completion(
|
||||
|
@ -398,18 +401,19 @@ class MetaReferenceInferenceImpl(
|
|||
model_id: str,
|
||||
messages_batch: list[list[Message]],
|
||||
sampling_params: SamplingParams | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
stream: bool | None = False,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
tool_config: ToolConfig | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
) -> BatchChatCompletionResponse:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
if logprobs:
|
||||
assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
|
||||
|
||||
# wrapper request to make it easier to pass around (internal only, not exposed to API)
|
||||
if self.llama_model is None:
|
||||
raise RuntimeError("Model not initialized")
|
||||
|
||||
request_batch = []
|
||||
for messages in messages_batch:
|
||||
request = ChatCompletionRequest(
|
||||
|
@ -417,215 +421,116 @@ class MetaReferenceInferenceImpl(
|
|||
messages=messages,
|
||||
sampling_params=sampling_params,
|
||||
tools=tools or [],
|
||||
response_format=response_format,
|
||||
logprobs=logprobs,
|
||||
tool_config=tool_config or ToolConfig(),
|
||||
response_format=response_format,
|
||||
stream=False,
|
||||
logprobs=logprobs,
|
||||
)
|
||||
self.check_model(request)
|
||||
request_with_raw_union = await convert_request_to_raw(request)
|
||||
|
||||
# augment and rewrite messages depending on the model
|
||||
request.messages = chat_completion_request_to_messages(request, self.llama_model.core_model_id.value)
|
||||
# download media and convert to raw content so we can send it to the model
|
||||
request = await convert_request_to_raw(request)
|
||||
request_batch.append(request)
|
||||
# Type cast to ensure we have the correct type
|
||||
from typing import cast
|
||||
|
||||
if self.config.create_distributed_process_group:
|
||||
if SEMAPHORE.locked():
|
||||
raise RuntimeError("Only one concurrent request is supported")
|
||||
request_with_raw = cast(ChatCompletionRequestWithRawContent, request_with_raw_union)
|
||||
request_batch.append(request_with_raw)
|
||||
|
||||
results = await self._nonstream_chat_completion(request_batch)
|
||||
return BatchChatCompletionResponse(batch=results)
|
||||
|
||||
async def _nonstream_chat_completion(
|
||||
self, request_batch: list[ChatCompletionRequest]
|
||||
self, request_batch: list[ChatCompletionRequestWithRawContent]
|
||||
) -> list[ChatCompletionResponse]:
|
||||
tokenizer = self.generator.formatter.tokenizer
|
||||
async with SEMAPHORE:
|
||||
if not self.generator:
|
||||
raise RuntimeError("Generator not initialized")
|
||||
|
||||
first_request = request_batch[0]
|
||||
class ItemState(BaseModel):
|
||||
tokens: list[int] = []
|
||||
logprobs: list[TokenLogProbs] = []
|
||||
stop_reason: StopReason | None = None
|
||||
finished: bool = False
|
||||
|
||||
class ItemState(BaseModel):
|
||||
tokens: list[int] = []
|
||||
logprobs: list[TokenLogProbs] = []
|
||||
stop_reason: StopReason | None = None
|
||||
finished: bool = False
|
||||
def impl() -> list[ChatCompletionResponse]:
|
||||
if not self.generator:
|
||||
raise RuntimeError("Generator not initialized")
|
||||
|
||||
def impl():
|
||||
states = [ItemState() for _ in request_batch]
|
||||
item_states = [ItemState() for _ in request_batch]
|
||||
|
||||
for token_results in self.generator.chat_completion(request_batch):
|
||||
first = token_results[0]
|
||||
if not first.finished and not first.ignore_token:
|
||||
if os.environ.get("LLAMA_MODELS_DEBUG", "0") in ("1", "2"):
|
||||
cprint(first.text, color="cyan", end="", file=sys.stderr)
|
||||
if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "2":
|
||||
cprint(f"<{first.token}>", color="magenta", end="", file=sys.stderr)
|
||||
for token_results in self.generator.chat_completion(request_batch):
|
||||
for idx, token_result in enumerate(token_results):
|
||||
item_state = item_states[idx]
|
||||
if item_state.finished:
|
||||
continue
|
||||
|
||||
for result in token_results:
|
||||
idx = result.batch_idx
|
||||
state = states[idx]
|
||||
if state.finished or result.ignore_token:
|
||||
continue
|
||||
if token_result.token == self.generator.formatter.tokenizer.eot_id:
|
||||
item_state.stop_reason = StopReason.end_of_turn
|
||||
item_state.finished = True
|
||||
elif token_result.token == self.generator.formatter.tokenizer.eom_id:
|
||||
item_state.stop_reason = StopReason.end_of_message
|
||||
item_state.finished = True
|
||||
else:
|
||||
item_state.tokens.append(token_result.token)
|
||||
if request_batch[idx].logprobs:
|
||||
assert len(token_result.logprobs) == 1
|
||||
item_state.logprobs.append(
|
||||
TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]})
|
||||
)
|
||||
|
||||
state.finished = result.finished
|
||||
if first_request.logprobs:
|
||||
state.logprobs.append(TokenLogProbs(logprobs_by_token={result.text: result.logprobs[0]}))
|
||||
# generate final responses
|
||||
completions = []
|
||||
for idx, item_state in enumerate(item_states):
|
||||
if not self.generator:
|
||||
raise RuntimeError("Generator not initialized")
|
||||
content = self.generator.formatter.tokenizer.decode(item_state.tokens)
|
||||
|
||||
state.tokens.append(result.token)
|
||||
if result.token == tokenizer.eot_id:
|
||||
state.stop_reason = StopReason.end_of_turn
|
||||
elif result.token == tokenizer.eom_id:
|
||||
state.stop_reason = StopReason.end_of_message
|
||||
|
||||
results = []
|
||||
for state in states:
|
||||
if state.stop_reason is None:
|
||||
state.stop_reason = StopReason.out_of_tokens
|
||||
|
||||
raw_message = self.generator.formatter.decode_assistant_message(state.tokens, state.stop_reason)
|
||||
results.append(
|
||||
ChatCompletionResponse(
|
||||
completion_message=CompletionMessage(
|
||||
content=raw_message.content,
|
||||
stop_reason=raw_message.stop_reason,
|
||||
tool_calls=raw_message.tool_calls,
|
||||
),
|
||||
logprobs=state.logprobs if first_request.logprobs else None,
|
||||
completions.append(
|
||||
ChatCompletionResponse(
|
||||
completion_message=CompletionMessage(
|
||||
content=content,
|
||||
stop_reason=item_state.stop_reason or StopReason.out_of_tokens,
|
||||
tool_calls=[],
|
||||
),
|
||||
logprobs=item_state.logprobs if request_batch[idx].logprobs else None,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
return results
|
||||
return completions
|
||||
|
||||
if self.config.create_distributed_process_group:
|
||||
async with SEMAPHORE:
|
||||
return impl()
|
||||
else:
|
||||
return impl()
|
||||
return await asyncio.get_event_loop().run_in_executor(None, impl)
|
||||
|
||||
async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
|
||||
async def _stream_chat_completion(
|
||||
self, request: ChatCompletionRequestWithRawContent
|
||||
) -> AsyncIterator[ChatCompletionResponseStreamChunk]:
|
||||
if not self.generator:
|
||||
raise RuntimeError("Generator not initialized")
|
||||
tokenizer = self.generator.formatter.tokenizer
|
||||
|
||||
def impl():
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.start,
|
||||
delta=TextDelta(text=""),
|
||||
)
|
||||
)
|
||||
stop_reason = None
|
||||
|
||||
tokens = []
|
||||
logprobs = []
|
||||
stop_reason = None
|
||||
ipython = False
|
||||
|
||||
for token_results in self.generator.chat_completion([request]):
|
||||
token_result = token_results[0]
|
||||
if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1":
|
||||
cprint(token_result.text, color="cyan", end="", file=sys.stderr)
|
||||
if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "2":
|
||||
cprint(f"<{token_result.token}>", color="magenta", end="", file=sys.stderr)
|
||||
|
||||
if token_result.token == tokenizer.eot_id:
|
||||
stop_reason = StopReason.end_of_turn
|
||||
text = ""
|
||||
elif token_result.token == tokenizer.eom_id:
|
||||
stop_reason = StopReason.end_of_message
|
||||
text = ""
|
||||
else:
|
||||
text = token_result.text
|
||||
for token_results in self.generator.chat_completion([request]):
|
||||
token_result = token_results[0]
|
||||
if token_result.token == tokenizer.eot_id:
|
||||
stop_reason = StopReason.end_of_turn
|
||||
text = ""
|
||||
elif token_result.token == tokenizer.eom_id:
|
||||
stop_reason = StopReason.end_of_message
|
||||
text = ""
|
||||
else:
|
||||
text = token_result.text
|
||||
|
||||
logprobs = None
|
||||
if stop_reason is None:
|
||||
if request.logprobs:
|
||||
assert len(token_result.logprobs) == 1
|
||||
|
||||
logprobs.append(TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]}))
|
||||
|
||||
tokens.append(token_result.token)
|
||||
|
||||
if not ipython and token_result.text.startswith("<|python_tag|>"):
|
||||
ipython = True
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(
|
||||
tool_call="",
|
||||
parse_status=ToolCallParseStatus.started,
|
||||
),
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
if token_result.token == tokenizer.eot_id:
|
||||
stop_reason = StopReason.end_of_turn
|
||||
text = ""
|
||||
elif token_result.token == tokenizer.eom_id:
|
||||
stop_reason = StopReason.end_of_message
|
||||
text = ""
|
||||
else:
|
||||
text = token_result.text
|
||||
|
||||
if ipython:
|
||||
delta = ToolCallDelta(
|
||||
tool_call=text,
|
||||
parse_status=ToolCallParseStatus.in_progress,
|
||||
)
|
||||
else:
|
||||
delta = TextDelta(text=text)
|
||||
|
||||
if stop_reason is None:
|
||||
if request.logprobs:
|
||||
assert len(token_result.logprobs) == 1
|
||||
|
||||
logprobs.append(TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]}))
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=delta,
|
||||
stop_reason=stop_reason,
|
||||
logprobs=logprobs if request.logprobs else None,
|
||||
)
|
||||
)
|
||||
|
||||
if stop_reason is None:
|
||||
stop_reason = StopReason.out_of_tokens
|
||||
|
||||
message = self.generator.formatter.decode_assistant_message(tokens, stop_reason)
|
||||
|
||||
parsed_tool_calls = len(message.tool_calls) > 0
|
||||
if ipython and not parsed_tool_calls:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(
|
||||
tool_call="",
|
||||
parse_status=ToolCallParseStatus.failed,
|
||||
),
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
|
||||
for tool_call in message.tool_calls:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(
|
||||
tool_call=tool_call,
|
||||
parse_status=ToolCallParseStatus.succeeded,
|
||||
),
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
logprobs = [TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]})]
|
||||
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.complete,
|
||||
delta=TextDelta(text=""),
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=TextDelta(text=text),
|
||||
logprobs=logprobs if request.logprobs else None,
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
|
||||
if self.config.create_distributed_process_group:
|
||||
async with SEMAPHORE:
|
||||
for x in impl():
|
||||
yield x
|
||||
else:
|
||||
for x in impl():
|
||||
yield x
|
||||
|
|
|
@ -31,7 +31,7 @@ log = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class SentenceTransformerEmbeddingMixin:
|
||||
model_store: ModelStore
|
||||
model_store: ModelStore | None = None
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
|
@ -41,7 +41,11 @@ class SentenceTransformerEmbeddingMixin:
|
|||
output_dimension: int | None = None,
|
||||
task_type: EmbeddingTaskType | None = None,
|
||||
) -> EmbeddingsResponse:
|
||||
if self.model_store is None:
|
||||
raise RuntimeError("Model store is not initialized")
|
||||
model = await self.model_store.get_model(model_id)
|
||||
if model.provider_resource_id is None:
|
||||
raise RuntimeError("Model provider resource ID is not set")
|
||||
embedding_model = self._load_sentence_transformer_model(model.provider_resource_id)
|
||||
embeddings = embedding_model.encode(
|
||||
[interleaved_content_as_str(content) for content in contents], show_progress_bar=False
|
||||
|
@ -62,7 +66,11 @@ class SentenceTransformerEmbeddingMixin:
|
|||
raise ValueError("Empty list not supported")
|
||||
|
||||
# Get the model and generate embeddings
|
||||
if self.model_store is None:
|
||||
raise RuntimeError("Model store is not initialized")
|
||||
model_obj = await self.model_store.get_model(model)
|
||||
if model_obj.provider_resource_id is None:
|
||||
raise RuntimeError("Model provider resource ID is not set")
|
||||
embedding_model = self._load_sentence_transformer_model(model_obj.provider_resource_id)
|
||||
embeddings = embedding_model.encode(input_list, show_progress_bar=False)
|
||||
|
||||
|
|
|
@ -246,7 +246,7 @@ exclude = [
|
|||
"^llama_stack/providers/inline/agents/meta_reference/agents\\.py$",
|
||||
"^llama_stack/providers/inline/datasetio/localfs/",
|
||||
"^llama_stack/providers/inline/eval/meta_reference/eval\\.py$",
|
||||
"^llama_stack/providers/inline/inference/meta_reference/inference\\.py$",
|
||||
"^llama_stack/providers/inline/inference/meta_reference/config\\.py$",
|
||||
"^llama_stack/models/llama/llama3/generation\\.py$",
|
||||
"^llama_stack/models/llama/llama3/multimodal/model\\.py$",
|
||||
"^llama_stack/models/llama/llama4/",
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue