mirror of
				https://github.com/meta-llama/llama-stack.git
				synced 2025-10-26 01:12:59 +00:00 
			
		
		
		
	This flips #2823 and #2805 by making the Stack periodically query the providers for models, rather than having the providers go behind its back and call `register` on the registry themselves. This also adds model-listing support for all other providers via `ModelRegistryHelper`. With this in place, models no longer need to be listed or registered manually in `run.yaml`, which removes both noise and annoyance (for example, having to set the `INFERENCE_MODEL` environment variable) from the new-user experience. In addition, it adds a configuration variable, `allowed_models`, which can optionally restrict the set of models a provider exposes.
		
			
				
	
	
		
			637 lines
		
	
	
	
		
			24 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			637 lines
		
	
	
	
		
			24 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Copyright (c) Meta Platforms, Inc. and affiliates.
 | |
| # All rights reserved.
 | |
| #
 | |
| # This source code is licensed under the terms described in the LICENSE file in
 | |
| # the root directory of this source tree.
 | |
| 
 | |
| import asyncio
 | |
| import os
 | |
| import sys
 | |
| from collections.abc import AsyncGenerator
 | |
| 
 | |
| from pydantic import BaseModel
 | |
| from termcolor import cprint
 | |
| 
 | |
| from llama_stack.apis.common.content_types import (
 | |
|     TextDelta,
 | |
|     ToolCallDelta,
 | |
|     ToolCallParseStatus,
 | |
| )
 | |
| from llama_stack.apis.inference import (
 | |
|     BatchChatCompletionResponse,
 | |
|     BatchCompletionResponse,
 | |
|     ChatCompletionRequest,
 | |
|     ChatCompletionResponse,
 | |
|     ChatCompletionResponseEvent,
 | |
|     ChatCompletionResponseEventType,
 | |
|     ChatCompletionResponseStreamChunk,
 | |
|     CompletionMessage,
 | |
|     CompletionRequest,
 | |
|     CompletionResponse,
 | |
|     CompletionResponseStreamChunk,
 | |
|     InferenceProvider,
 | |
|     InterleavedContent,
 | |
|     LogProbConfig,
 | |
|     Message,
 | |
|     ResponseFormat,
 | |
|     SamplingParams,
 | |
|     StopReason,
 | |
|     TokenLogProbs,
 | |
|     ToolChoice,
 | |
|     ToolConfig,
 | |
|     ToolDefinition,
 | |
|     ToolPromptFormat,
 | |
|     UserMessage,
 | |
| )
 | |
| from llama_stack.apis.models import Model, ModelType
 | |
| from llama_stack.log import get_logger
 | |
| from llama_stack.models.llama.llama3.chat_format import ChatFormat as Llama3ChatFormat
 | |
| from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer
 | |
| from llama_stack.models.llama.llama4.chat_format import ChatFormat as Llama4ChatFormat
 | |
| from llama_stack.models.llama.llama4.tokenizer import Tokenizer as Llama4Tokenizer
 | |
| from llama_stack.models.llama.sku_list import resolve_model
 | |
| from llama_stack.models.llama.sku_types import ModelFamily
 | |
| from llama_stack.providers.datatypes import ModelsProtocolPrivate
 | |
| from llama_stack.providers.utils.inference.embedding_mixin import (
 | |
|     SentenceTransformerEmbeddingMixin,
 | |
| )
 | |
| from llama_stack.providers.utils.inference.model_registry import (
 | |
|     ModelRegistryHelper,
 | |
|     build_hf_repo_model_entry,
 | |
| )
 | |
| from llama_stack.providers.utils.inference.openai_compat import (
 | |
|     OpenAIChatCompletionToLlamaStackMixin,
 | |
|     OpenAICompletionToLlamaStackMixin,
 | |
| )
 | |
| from llama_stack.providers.utils.inference.prompt_adapter import (
 | |
|     augment_content_with_response_format_prompt,
 | |
|     chat_completion_request_to_messages,
 | |
|     convert_request_to_raw,
 | |
| )
 | |
| 
 | |
| from .config import MetaReferenceInferenceConfig
 | |
| from .generators import LlamaGenerator
 | |
| from .model_parallel import LlamaModelParallelGenerator
 | |
| 
 | |
# Only one model-parallel process serves the loaded model and it cannot yet
# handle concurrent requests, so model-parallel generation is serialized
# through this single-slot semaphore.
SEMAPHORE = asyncio.Semaphore(value=1)

log = get_logger(__name__, category="inference")
 | |
| 
 | |
| 
 | |
def llama_builder_fn(config: MetaReferenceInferenceConfig, model_id: str, llama_model: Model) -> LlamaGenerator:
    """Create a ``LlamaGenerator`` for ``llama_model`` using ``config``.

    Used both directly (single-process mode) and as the ``builder_fn`` handed to
    ``LlamaModelParallelGenerator``. It is presumably kept at module level so it
    can be referenced from worker processes; confirm against model_parallel.
    """
    generator = LlamaGenerator(config, model_id, llama_model)
    return generator
 | |
| 
 | |
| 
 | |
class MetaReferenceInferenceImpl(
    OpenAICompletionToLlamaStackMixin,
    OpenAIChatCompletionToLlamaStackMixin,
    SentenceTransformerEmbeddingMixin,
    InferenceProvider,
    ModelsProtocolPrivate,
):
    """Meta-reference inference provider that runs a Llama model locally.

    At most one Llama model is loaded at a time (``self.model_id`` /
    ``self.llama_model``), and every completion request is checked against it
    by ``check_model``. When ``config.create_distributed_process_group`` is set,
    generation runs through a ``LlamaModelParallelGenerator`` and is serialized
    with the module-level ``SEMAPHORE``. Otherwise a plain ``LlamaGenerator``
    built by ``llama_builder_fn`` is used in-process.
    """

    def __init__(self, config: MetaReferenceInferenceConfig) -> None:
        self.config = config
        # Identity of the currently loaded model; both stay ``None`` until
        # ``load_model`` runs.
        self.model_id = None
        self.llama_model = None
        # Created by ``load_model``. Initialized here so that ``shutdown`` is
        # safe even when no model was ever loaded.
        self.generator = None

    async def initialize(self) -> None:
        # Nothing to do up front: the model is loaded from ``register_model``.
        pass

    async def shutdown(self) -> None:
        """Stop the model-parallel generator, if one was started."""
        # ``generator`` only exists once ``load_model`` has run. With no model
        # loaded (or only ``skip_load`` registrations) there is nothing to stop.
        generator = getattr(self, "generator", None)
        if self.config.create_distributed_process_group and generator is not None:
            generator.stop()

    async def should_refresh_models(self) -> bool:
        # This provider never asks the stack to periodically refresh its models.
        return False

    async def list_models(self) -> list[Model] | None:
        # Returns ``None``: this provider does not enumerate models on its own
        # (presumably the stack then relies on explicit registration; confirm).
        return None

    async def unregister_model(self, model_id: str) -> None:
        # No-op: the loaded generator is left running until ``shutdown``.
        pass

    async def register_model(self, model: Model) -> Model:
        """Resolve ``model`` to a Llama SKU, register it, and normally load it.

        The SKU is looked up from ``model.metadata["llama_model"]`` when that key
        is present, otherwise from ``model.identifier``.

        Args:
            model: The model to register.

        Returns:
            The model as returned by ``ModelRegistryHelper.register_model``.

        Raises:
            ValueError: If neither the metadata nor the identifier resolves to a
                known Llama SKU.
        """
        llama_model = (
            resolve_model(model.metadata["llama_model"])
            if "llama_model" in model.metadata
            else resolve_model(model.identifier)
        )
        if llama_model is None:
            raise ValueError(
                "Please make sure your llama_model in model metadata or model identifier is in Llama SKU list"
            )

        # NOTE(review): the helper is rebuilt on every registration, so it only
        # knows about the most recently registered SKU.
        self.model_registry_helper = ModelRegistryHelper(
            [
                build_hf_repo_model_entry(
                    llama_model.descriptor(),
                    llama_model.core_model_id.value,
                )
            ],
        )
        model = await self.model_registry_helper.register_model(model)

        if model.model_type == ModelType.embedding:
            self._load_sentence_transformer_model(model.provider_resource_id)

        # TODO: ``skip_load`` in model metadata is an ad-hoc escape hatch for
        # registering a model without loading it; replace with a real option.
        if "skip_load" in model.metadata and model.metadata["skip_load"]:
            return model

        await self.load_model(model.identifier, llama_model)
        return model

    async def load_model(self, model_id: str, llama_model) -> None:
        """Build the generator for ``llama_model``, make it current, and warm it up.

        Args:
            model_id: Identifier that subsequent requests must use (see ``check_model``).
            llama_model: The resolved Llama SKU (the result of ``resolve_model``).
        """
        log.info(f"Loading model `{model_id}`")

        builder_params = [self.config, model_id, llama_model]

        if self.config.create_distributed_process_group:
            self.generator = LlamaModelParallelGenerator(
                # Fall back to the SKU's checkpoint file count when no explicit size is configured.
                model_parallel_size=self.config.model_parallel_size or llama_model.pth_file_count,
                builder_fn=llama_builder_fn,
                builder_params=builder_params,
                formatter=(
                    Llama4ChatFormat(Llama4Tokenizer.get_instance())
                    if llama_model.model_family == ModelFamily.llama4
                    else Llama3ChatFormat(Llama3Tokenizer.get_instance())
                ),
            )
            self.generator.start()
        else:
            self.generator = llama_builder_fn(*builder_params)

        # Record the loaded model before warming up: the warm-up calls below go
        # through ``check_model``, which requires both of these to be set.
        self.model_id = model_id
        self.llama_model = llama_model

        log.info("Warming up...")
        await self.completion(
            model_id=model_id,
            content="Hello, world!",
            sampling_params=SamplingParams(max_tokens=10),
        )
        await self.chat_completion(
            model_id=model_id,
            messages=[UserMessage(content="Hi how are you?")],
            sampling_params=SamplingParams(max_tokens=20),
        )
        log.info("Warmed up!")

    def check_model(self, request) -> None:
        """Ensure a model is loaded and that ``request.model`` names it.

        Raises:
            RuntimeError: If no model is loaded, or the request targets a
                different model than the one loaded.
        """
        if self.model_id is None or self.llama_model is None:
            raise RuntimeError(
                "No available model yet, please register your requested model or add your model in the resources first"
            )
        elif request.model != self.model_id:
            raise RuntimeError(f"Model mismatch: request model: {request.model} != loaded model: {self.model_id}")

    async def completion(
        self,
        model_id: str,
        content: InterleavedContent,
        sampling_params: SamplingParams | None = None,
        response_format: ResponseFormat | None = None,
        stream: bool | None = False,
        logprobs: LogProbConfig | None = None,
    ) -> CompletionResponse | CompletionResponseStreamChunk:
        """Run a text completion against the loaded model.

        Returns an async generator of ``CompletionResponseStreamChunk`` when
        ``stream`` is true, otherwise a single ``CompletionResponse``.
        Only ``logprobs.top_k == 1`` is supported.
        """
        if sampling_params is None:
            sampling_params = SamplingParams()
        if logprobs:
            assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"

        content = augment_content_with_response_format_prompt(response_format, content)
        request = CompletionRequest(
            model=model_id,
            content=content,
            sampling_params=sampling_params,
            response_format=response_format,
            stream=stream,
            logprobs=logprobs,
        )
        self.check_model(request)
        request = await convert_request_to_raw(request)

        if request.stream:
            return self._stream_completion(request)
        else:
            results = await self._nonstream_completion([request])
            return results[0]

    async def batch_completion(
        self,
        model_id: str,
        content_batch: list[InterleavedContent],
        sampling_params: SamplingParams | None = None,
        response_format: ResponseFormat | None = None,
        stream: bool | None = False,
        logprobs: LogProbConfig | None = None,
    ) -> BatchCompletionResponse:
        """Run non-streaming completions for every item of ``content_batch`` in one generator pass."""
        if sampling_params is None:
            sampling_params = SamplingParams()
        if logprobs:
            assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"

        content_batch = [
            augment_content_with_response_format_prompt(response_format, content) for content in content_batch
        ]

        request_batch = []
        for content in content_batch:
            request = CompletionRequest(
                model=model_id,
                content=content,
                sampling_params=sampling_params,
                response_format=response_format,
                stream=stream,
                logprobs=logprobs,
            )
            self.check_model(request)
            request = await convert_request_to_raw(request)
            request_batch.append(request)

        results = await self._nonstream_completion(request_batch)
        return BatchCompletionResponse(batch=results)

    async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
        """Yield one ``CompletionResponseStreamChunk`` per generated token.

        A trailing ``out_of_tokens`` chunk is emitted if generation ended
        without an end-of-turn or end-of-message token.
        """
        tokenizer = self.generator.formatter.tokenizer

        def impl():
            stop_reason = None

            for token_results in self.generator.completion([request]):
                token_result = token_results[0]
                if token_result.token == tokenizer.eot_id:
                    stop_reason = StopReason.end_of_turn
                    text = ""
                elif token_result.token == tokenizer.eom_id:
                    stop_reason = StopReason.end_of_message
                    text = ""
                else:
                    text = token_result.text

                logprobs = None
                if stop_reason is None:
                    if request.logprobs:
                        assert len(token_result.logprobs) == 1

                        logprobs = [TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]})]

                yield CompletionResponseStreamChunk(
                    delta=text,
                    stop_reason=stop_reason,
                    logprobs=logprobs if request.logprobs else None,
                )

            if stop_reason is None:
                yield CompletionResponseStreamChunk(
                    delta="",
                    stop_reason=StopReason.out_of_tokens,
                )

        if self.config.create_distributed_process_group:
            async with SEMAPHORE:
                for x in impl():
                    yield x
        else:
            for x in impl():
                yield x

    async def _nonstream_completion(self, request_batch: list[CompletionRequest]) -> list[CompletionResponse]:
        """Generate a full ``CompletionResponse`` for each request in ``request_batch``.

        Logprobs are collected only when the first request asks for them.
        """
        tokenizer = self.generator.formatter.tokenizer

        first_request = request_batch[0]

        class ItemState(BaseModel):
            tokens: list[int] = []
            logprobs: list[TokenLogProbs] = []
            stop_reason: StopReason | None = None
            finished: bool = False

        def impl():
            states = [ItemState() for _ in request_batch]

            results = []
            for token_results in self.generator.completion(request_batch):
                for result in token_results:
                    idx = result.batch_idx
                    state = states[idx]
                    # Skip tokens flagged as ignorable and anything produced after
                    # this item has already finished.
                    if state.finished or result.ignore_token:
                        continue

                    state.finished = result.finished
                    if first_request.logprobs:
                        state.logprobs.append(TokenLogProbs(logprobs_by_token={result.text: result.logprobs[0]}))

                    state.tokens.append(result.token)
                    if result.token == tokenizer.eot_id:
                        state.stop_reason = StopReason.end_of_turn
                    elif result.token == tokenizer.eom_id:
                        state.stop_reason = StopReason.end_of_message

            for state in states:
                if state.stop_reason is None:
                    state.stop_reason = StopReason.out_of_tokens

                # Drop a trailing stop token from the decoded text. Guard against
                # items that produced no tokens at all (previously an IndexError).
                if state.tokens and state.tokens[-1] in tokenizer.stop_tokens:
                    state.tokens = state.tokens[:-1]
                content = tokenizer.decode(state.tokens)
                results.append(
                    CompletionResponse(
                        content=content,
                        stop_reason=state.stop_reason,
                        logprobs=state.logprobs if first_request.logprobs else None,
                    )
                )

            return results

        if self.config.create_distributed_process_group:
            async with SEMAPHORE:
                return impl()
        else:
            return impl()

    async def chat_completion(
        self,
        model_id: str,
        messages: list[Message],
        sampling_params: SamplingParams | None = None,
        response_format: ResponseFormat | None = None,
        tools: list[ToolDefinition] | None = None,
        tool_choice: ToolChoice | None = ToolChoice.auto,
        tool_prompt_format: ToolPromptFormat | None = None,
        stream: bool | None = False,
        logprobs: LogProbConfig | None = None,
        tool_config: ToolConfig | None = None,
    ) -> AsyncGenerator:
        """Run a chat completion against the loaded model.

        Returns an async generator of stream chunks when ``stream`` is true,
        otherwise a single ``ChatCompletionResponse``.

        NOTE(review): ``tool_choice`` and ``tool_prompt_format`` are not read
        here; only ``tool_config`` is forwarded into the request.

        Raises:
            RuntimeError: In model-parallel mode, if another request currently
                holds ``SEMAPHORE``.
        """
        if sampling_params is None:
            sampling_params = SamplingParams()
        if logprobs:
            assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"

        # wrapper request to make it easier to pass around (internal only, not exposed to API)
        request = ChatCompletionRequest(
            model=model_id,
            messages=messages,
            sampling_params=sampling_params,
            tools=tools or [],
            response_format=response_format,
            stream=stream,
            logprobs=logprobs,
            tool_config=tool_config or ToolConfig(),
        )
        self.check_model(request)

        # augment and rewrite messages depending on the model
        request.messages = chat_completion_request_to_messages(request, self.llama_model.core_model_id.value)
        # download media and convert to raw content so we can send it to the model
        request = await convert_request_to_raw(request)

        if self.config.create_distributed_process_group:
            if SEMAPHORE.locked():
                raise RuntimeError("Only one concurrent request is supported")

        if request.stream:
            return self._stream_chat_completion(request)
        else:
            results = await self._nonstream_chat_completion([request])
            return results[0]

    async def batch_chat_completion(
        self,
        model_id: str,
        messages_batch: list[list[Message]],
        sampling_params: SamplingParams | None = None,
        response_format: ResponseFormat | None = None,
        tools: list[ToolDefinition] | None = None,
        stream: bool | None = False,
        logprobs: LogProbConfig | None = None,
        tool_config: ToolConfig | None = None,
    ) -> BatchChatCompletionResponse:
        """Run non-streaming chat completions for every conversation in ``messages_batch``.

        Raises:
            RuntimeError: In model-parallel mode, if another request currently
                holds ``SEMAPHORE``.
        """
        if sampling_params is None:
            sampling_params = SamplingParams()
        if logprobs:
            assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"

        # wrapper request to make it easier to pass around (internal only, not exposed to API)
        request_batch = []
        for messages in messages_batch:
            request = ChatCompletionRequest(
                model=model_id,
                messages=messages,
                sampling_params=sampling_params,
                tools=tools or [],
                response_format=response_format,
                logprobs=logprobs,
                tool_config=tool_config or ToolConfig(),
            )
            self.check_model(request)

            # augment and rewrite messages depending on the model
            request.messages = chat_completion_request_to_messages(request, self.llama_model.core_model_id.value)
            # download media and convert to raw content so we can send it to the model
            request = await convert_request_to_raw(request)
            request_batch.append(request)

        if self.config.create_distributed_process_group:
            if SEMAPHORE.locked():
                raise RuntimeError("Only one concurrent request is supported")

        results = await self._nonstream_chat_completion(request_batch)
        return BatchChatCompletionResponse(batch=results)

    async def _nonstream_chat_completion(
        self, request_batch: list[ChatCompletionRequest]
    ) -> list[ChatCompletionResponse]:
        """Generate a full ``ChatCompletionResponse`` for each request in ``request_batch``.

        Generated tokens are decoded into an assistant message (including any
        tool calls) via the formatter. Logprobs are collected only when the
        first request asks for them.
        """
        tokenizer = self.generator.formatter.tokenizer

        first_request = request_batch[0]

        class ItemState(BaseModel):
            tokens: list[int] = []
            logprobs: list[TokenLogProbs] = []
            stop_reason: StopReason | None = None
            finished: bool = False

        def impl():
            states = [ItemState() for _ in request_batch]

            for token_results in self.generator.chat_completion(request_batch):
                first = token_results[0]
                # Optional debug echo of the first batch item's tokens to stderr.
                if not first.finished and not first.ignore_token:
                    if os.environ.get("LLAMA_MODELS_DEBUG", "0") in ("1", "2"):
                        cprint(first.text, color="cyan", end="", file=sys.stderr)
                    if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "2":
                        cprint(f"<{first.token}>", color="magenta", end="", file=sys.stderr)

                for result in token_results:
                    idx = result.batch_idx
                    state = states[idx]
                    if state.finished or result.ignore_token:
                        continue

                    state.finished = result.finished
                    if first_request.logprobs:
                        state.logprobs.append(TokenLogProbs(logprobs_by_token={result.text: result.logprobs[0]}))

                    state.tokens.append(result.token)
                    if result.token == tokenizer.eot_id:
                        state.stop_reason = StopReason.end_of_turn
                    elif result.token == tokenizer.eom_id:
                        state.stop_reason = StopReason.end_of_message

            results = []
            for state in states:
                if state.stop_reason is None:
                    state.stop_reason = StopReason.out_of_tokens

                raw_message = self.generator.formatter.decode_assistant_message(state.tokens, state.stop_reason)
                results.append(
                    ChatCompletionResponse(
                        completion_message=CompletionMessage(
                            content=raw_message.content,
                            stop_reason=raw_message.stop_reason,
                            tool_calls=raw_message.tool_calls,
                        ),
                        logprobs=state.logprobs if first_request.logprobs else None,
                    )
                )

            return results

        if self.config.create_distributed_process_group:
            async with SEMAPHORE:
                return impl()
        else:
            return impl()

    async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
        """Stream a chat completion for a single request.

        Yields, in order:
          * a ``start`` chunk;
          * one ``progress`` chunk per generated non-stop token. Deltas are text
            until a token starting with ``<|python_tag|>`` appears, and
            in-progress tool-call deltas after it;
          * tool-call ``failed`` / ``succeeded`` chunks after decoding the full
            assistant message;
          * a final ``complete`` chunk.
        """
        tokenizer = self.generator.formatter.tokenizer

        def impl():
            yield ChatCompletionResponseStreamChunk(
                event=ChatCompletionResponseEvent(
                    event_type=ChatCompletionResponseEventType.start,
                    delta=TextDelta(text=""),
                )
            )

            tokens = []
            # Logprobs accumulated so far; recorded once per emitted progress
            # token (previously each token was appended twice).
            logprobs = []
            stop_reason = None
            ipython = False

            for token_results in self.generator.chat_completion([request]):
                token_result = token_results[0]
                debug_level = os.environ.get("LLAMA_MODELS_DEBUG", "0")
                if debug_level == "1":
                    cprint(token_result.text, color="cyan", end="", file=sys.stderr)
                if debug_level == "2":
                    cprint(f"<{token_result.token}>", color="magenta", end="", file=sys.stderr)

                # Every token is kept for the final decode, including stop tokens.
                tokens.append(token_result.token)

                if not ipython and token_result.text.startswith("<|python_tag|>"):
                    ipython = True
                    yield ChatCompletionResponseStreamChunk(
                        event=ChatCompletionResponseEvent(
                            event_type=ChatCompletionResponseEventType.progress,
                            delta=ToolCallDelta(
                                tool_call="",
                                parse_status=ToolCallParseStatus.started,
                            ),
                        )
                    )
                    continue

                if token_result.token == tokenizer.eot_id:
                    stop_reason = StopReason.end_of_turn
                    text = ""
                elif token_result.token == tokenizer.eom_id:
                    stop_reason = StopReason.end_of_message
                    text = ""
                else:
                    text = token_result.text

                if ipython:
                    delta = ToolCallDelta(
                        tool_call=text,
                        parse_status=ToolCallParseStatus.in_progress,
                    )
                else:
                    delta = TextDelta(text=text)

                if stop_reason is None:
                    if request.logprobs:
                        assert len(token_result.logprobs) == 1

                        logprobs.append(TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]}))
                    yield ChatCompletionResponseStreamChunk(
                        event=ChatCompletionResponseEvent(
                            event_type=ChatCompletionResponseEventType.progress,
                            delta=delta,
                            stop_reason=stop_reason,
                            logprobs=logprobs if request.logprobs else None,
                        )
                    )

            if stop_reason is None:
                stop_reason = StopReason.out_of_tokens

            message = self.generator.formatter.decode_assistant_message(tokens, stop_reason)

            parsed_tool_calls = len(message.tool_calls) > 0
            if ipython and not parsed_tool_calls:
                yield ChatCompletionResponseStreamChunk(
                    event=ChatCompletionResponseEvent(
                        event_type=ChatCompletionResponseEventType.progress,
                        delta=ToolCallDelta(
                            tool_call="",
                            parse_status=ToolCallParseStatus.failed,
                        ),
                        stop_reason=stop_reason,
                    )
                )

            for tool_call in message.tool_calls:
                yield ChatCompletionResponseStreamChunk(
                    event=ChatCompletionResponseEvent(
                        event_type=ChatCompletionResponseEventType.progress,
                        delta=ToolCallDelta(
                            tool_call=tool_call,
                            parse_status=ToolCallParseStatus.succeeded,
                        ),
                        stop_reason=stop_reason,
                    )
                )

            yield ChatCompletionResponseStreamChunk(
                event=ChatCompletionResponseEvent(
                    event_type=ChatCompletionResponseEventType.complete,
                    delta=TextDelta(text=""),
                    stop_reason=stop_reason,
                )
            )

        if self.config.create_distributed_process_group:
            async with SEMAPHORE:
                for x in impl():
                    yield x
        else:
            for x in impl():
                yield x
 |