forked from phoenix-oss/llama-stack-mirror
		
	Miscellaneous fixes for meta reference inference Tests for log probs dont pass because meta reference does not support top_k > 1
		
			
				
	
	
		
			470 lines
		
	
	
	
		
			16 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			470 lines
		
	
	
	
		
			16 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Copyright (c) Meta Platforms, Inc. and affiliates.
 | |
| # All rights reserved.
 | |
| #
 | |
| # This source code is licensed under the terms described in the LICENSE file in
 | |
| # the root directory of this source tree.
 | |
| 
 | |
| import asyncio
 | |
| import logging
 | |
| from typing import AsyncGenerator, List, Optional, Union
 | |
| 
 | |
| from llama_models.llama3.api.datatypes import (
 | |
|     SamplingParams,
 | |
|     StopReason,
 | |
|     ToolDefinition,
 | |
|     ToolPromptFormat,
 | |
| )
 | |
| from llama_models.sku_list import resolve_model
 | |
| 
 | |
| from llama_stack.apis.common.content_types import (
 | |
|     TextDelta,
 | |
|     ToolCallDelta,
 | |
|     ToolCallParseStatus,
 | |
| )
 | |
| from llama_stack.apis.inference import (
 | |
|     ChatCompletionRequest,
 | |
|     ChatCompletionResponse,
 | |
|     ChatCompletionResponseEvent,
 | |
|     ChatCompletionResponseEventType,
 | |
|     ChatCompletionResponseStreamChunk,
 | |
|     CompletionMessage,
 | |
|     CompletionRequest,
 | |
|     CompletionResponse,
 | |
|     CompletionResponseStreamChunk,
 | |
|     Inference,
 | |
|     InterleavedContent,
 | |
|     LogProbConfig,
 | |
|     Message,
 | |
|     ResponseFormat,
 | |
|     TokenLogProbs,
 | |
|     ToolChoice,
 | |
| )
 | |
| from llama_stack.apis.models import Model, ModelType
 | |
| from llama_stack.providers.datatypes import ModelsProtocolPrivate
 | |
| from llama_stack.providers.utils.inference.embedding_mixin import (
 | |
|     SentenceTransformerEmbeddingMixin,
 | |
| )
 | |
| from llama_stack.providers.utils.inference.model_registry import (
 | |
|     build_model_alias,
 | |
|     ModelRegistryHelper,
 | |
| )
 | |
| from llama_stack.providers.utils.inference.prompt_adapter import (
 | |
|     augment_content_with_response_format_prompt,
 | |
|     chat_completion_request_to_messages,
 | |
|     convert_request_to_raw,
 | |
| )
 | |
| 
 | |
| from .config import MetaReferenceInferenceConfig
 | |
| from .generation import Llama
 | |
| from .model_parallel import LlamaModelParallelGenerator
 | |
| 
 | |
| log = logging.getLogger(__name__)
 | |
| # there's a single model parallel process running serving the model. for now,
 | |
| # we don't support multiple concurrent requests to this process.
 | |
| SEMAPHORE = asyncio.Semaphore(1)
 | |
| 
 | |
| 
 | |
| class MetaReferenceInferenceImpl(
 | |
|     SentenceTransformerEmbeddingMixin,
 | |
|     Inference,
 | |
|     ModelsProtocolPrivate,
 | |
| ):
 | |
|     def __init__(self, config: MetaReferenceInferenceConfig) -> None:
 | |
|         self.config = config
 | |
|         self.model_id = None
 | |
|         self.llama_model = None
 | |
| 
 | |
|     async def initialize(self) -> None:
 | |
|         pass
 | |
| 
 | |
|     async def load_model(self, model_id, llama_model) -> None:
 | |
|         log.info(f"Loading model `{model_id}`")
 | |
|         if self.config.create_distributed_process_group:
 | |
|             self.generator = LlamaModelParallelGenerator(
 | |
|                 self.config, model_id, llama_model
 | |
|             )
 | |
|             self.generator.start()
 | |
|         else:
 | |
|             self.generator = Llama.build(self.config, model_id, llama_model)
 | |
| 
 | |
|         self.model_id = model_id
 | |
|         self.llama_model = llama_model
 | |
| 
 | |
|     async def shutdown(self) -> None:
 | |
|         if self.config.create_distributed_process_group:
 | |
|             self.generator.stop()
 | |
| 
 | |
|     def check_model(self, request) -> None:
 | |
|         if self.model_id is None or self.llama_model is None:
 | |
|             raise RuntimeError(
 | |
|                 "No avaible model yet, please register your requested model or add your model in the resouces first"
 | |
|             )
 | |
|         elif request.model != self.model_id:
 | |
|             raise RuntimeError(
 | |
|                 f"Model mismatch: request model: {request.model} != loaded model: {self.model_id}"
 | |
|             )
 | |
| 
 | |
|     async def unregister_model(self, model_id: str) -> None:
 | |
|         pass
 | |
| 
 | |
|     async def register_model(self, model: Model) -> Model:
 | |
|         llama_model = (
 | |
|             resolve_model(model.metadata["llama_model"])
 | |
|             if "llama_model" in model.metadata
 | |
|             else resolve_model(model.identifier)
 | |
|         )
 | |
|         if llama_model is None:
 | |
|             raise ValueError(
 | |
|                 "Please make sure your llama_model in model metadata or model identifier is in llama-models SKU list"
 | |
|             )
 | |
| 
 | |
|         self.model_registry_helper = ModelRegistryHelper(
 | |
|             [
 | |
|                 build_model_alias(
 | |
|                     llama_model.descriptor(),
 | |
|                     llama_model.core_model_id.value,
 | |
|                 )
 | |
|             ],
 | |
|         )
 | |
|         model = await self.model_registry_helper.register_model(model)
 | |
| 
 | |
|         if model.model_type == ModelType.embedding:
 | |
|             self._load_sentence_transformer_model(model.provider_resource_id)
 | |
| 
 | |
|         if "skip_load" in model.metadata and model.metadata["skip_load"]:
 | |
|             return model
 | |
|         await self.load_model(model.identifier, llama_model)
 | |
|         return model
 | |
| 
 | |
|     async def completion(
 | |
|         self,
 | |
|         model_id: str,
 | |
|         content: InterleavedContent,
 | |
|         sampling_params: Optional[SamplingParams] = SamplingParams(),
 | |
|         response_format: Optional[ResponseFormat] = None,
 | |
|         stream: Optional[bool] = False,
 | |
|         logprobs: Optional[LogProbConfig] = None,
 | |
|     ) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
 | |
|         if logprobs:
 | |
|             assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
 | |
| 
 | |
|         content = augment_content_with_response_format_prompt(response_format, content)
 | |
|         request = CompletionRequest(
 | |
|             model=model_id,
 | |
|             content=content,
 | |
|             sampling_params=sampling_params,
 | |
|             response_format=response_format,
 | |
|             stream=stream,
 | |
|             logprobs=logprobs,
 | |
|         )
 | |
|         self.check_model(request)
 | |
|         request = await convert_request_to_raw(request)
 | |
| 
 | |
|         if request.stream:
 | |
|             return self._stream_completion(request)
 | |
|         else:
 | |
|             return await self._nonstream_completion(request)
 | |
| 
 | |
|     async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
 | |
|         def impl():
 | |
|             stop_reason = None
 | |
| 
 | |
|             for token_result in self.generator.completion(request):
 | |
|                 if token_result.text == "<|eot_id|>":
 | |
|                     stop_reason = StopReason.end_of_turn
 | |
|                     text = ""
 | |
|                 elif token_result.text == "<|eom_id|>":
 | |
|                     stop_reason = StopReason.end_of_message
 | |
|                     text = ""
 | |
|                 else:
 | |
|                     text = token_result.text
 | |
| 
 | |
|                 logprobs = None
 | |
|                 if stop_reason is None:
 | |
|                     if request.logprobs:
 | |
|                         assert len(token_result.logprobs) == 1
 | |
| 
 | |
|                         logprobs = [
 | |
|                             TokenLogProbs(
 | |
|                                 logprobs_by_token={
 | |
|                                     token_result.text: token_result.logprobs[0]
 | |
|                                 }
 | |
|                             )
 | |
|                         ]
 | |
| 
 | |
|                 yield CompletionResponseStreamChunk(
 | |
|                     delta=text,
 | |
|                     stop_reason=stop_reason,
 | |
|                     logprobs=logprobs if request.logprobs else None,
 | |
|                 )
 | |
| 
 | |
|             if stop_reason is None:
 | |
|                 yield CompletionResponseStreamChunk(
 | |
|                     delta="",
 | |
|                     stop_reason=StopReason.out_of_tokens,
 | |
|                 )
 | |
| 
 | |
|         if self.config.create_distributed_process_group:
 | |
|             async with SEMAPHORE:
 | |
|                 for x in impl():
 | |
|                     yield x
 | |
|         else:
 | |
|             for x in impl():
 | |
|                 yield x
 | |
| 
 | |
|     async def _nonstream_completion(
 | |
|         self, request: CompletionRequest
 | |
|     ) -> CompletionResponse:
 | |
|         def impl():
 | |
|             tokens = []
 | |
|             logprobs = []
 | |
|             stop_reason = None
 | |
| 
 | |
|             tokenizer = self.generator.formatter.tokenizer
 | |
|             for token_result in self.generator.completion(request):
 | |
|                 tokens.append(token_result.token)
 | |
|                 if token_result.text == "<|eot_id|>":
 | |
|                     stop_reason = StopReason.end_of_turn
 | |
|                 elif token_result.text == "<|eom_id|>":
 | |
|                     stop_reason = StopReason.end_of_message
 | |
| 
 | |
|                 if request.logprobs:
 | |
|                     assert len(token_result.logprobs) == 1
 | |
| 
 | |
|                     logprobs.append(
 | |
|                         TokenLogProbs(
 | |
|                             logprobs_by_token={
 | |
|                                 token_result.text: token_result.logprobs[0]
 | |
|                             }
 | |
|                         )
 | |
|                     )
 | |
| 
 | |
|             if stop_reason is None:
 | |
|                 stop_reason = StopReason.out_of_tokens
 | |
| 
 | |
|             content = self.generator.formatter.tokenizer.decode(tokens)
 | |
|             if content.endswith("<|eot_id|>"):
 | |
|                 content = content[: -len("<|eot_id|>")]
 | |
|             elif content.endswith("<|eom_id|>"):
 | |
|                 content = content[: -len("<|eom_id|>")]
 | |
|             return CompletionResponse(
 | |
|                 content=content,
 | |
|                 stop_reason=stop_reason,
 | |
|                 logprobs=logprobs if request.logprobs else None,
 | |
|             )
 | |
| 
 | |
|         if self.config.create_distributed_process_group:
 | |
|             async with SEMAPHORE:
 | |
|                 return impl()
 | |
|         else:
 | |
|             return impl()
 | |
| 
 | |
|     async def chat_completion(
 | |
|         self,
 | |
|         model_id: str,
 | |
|         messages: List[Message],
 | |
|         sampling_params: Optional[SamplingParams] = SamplingParams(),
 | |
|         response_format: Optional[ResponseFormat] = None,
 | |
|         tools: Optional[List[ToolDefinition]] = None,
 | |
|         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
 | |
|         tool_prompt_format: Optional[ToolPromptFormat] = None,
 | |
|         stream: Optional[bool] = False,
 | |
|         logprobs: Optional[LogProbConfig] = None,
 | |
|     ) -> AsyncGenerator:
 | |
|         if logprobs:
 | |
|             assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
 | |
| 
 | |
|         # wrapper request to make it easier to pass around (internal only, not exposed to API)
 | |
|         request = ChatCompletionRequest(
 | |
|             model=model_id,
 | |
|             messages=messages,
 | |
|             sampling_params=sampling_params,
 | |
|             tools=tools or [],
 | |
|             tool_choice=tool_choice,
 | |
|             tool_prompt_format=tool_prompt_format,
 | |
|             response_format=response_format,
 | |
|             stream=stream,
 | |
|             logprobs=logprobs,
 | |
|         )
 | |
|         self.check_model(request)
 | |
| 
 | |
|         # augment and rewrite messages depending on the model
 | |
|         request.messages = chat_completion_request_to_messages(
 | |
|             request, self.llama_model.core_model_id.value
 | |
|         )
 | |
|         # download media and convert to raw content so we can send it to the model
 | |
|         request = await convert_request_to_raw(request)
 | |
| 
 | |
|         if self.config.create_distributed_process_group:
 | |
|             if SEMAPHORE.locked():
 | |
|                 raise RuntimeError("Only one concurrent request is supported")
 | |
| 
 | |
|         if request.stream:
 | |
|             return self._stream_chat_completion(request)
 | |
|         else:
 | |
|             return await self._nonstream_chat_completion(request)
 | |
| 
 | |
|     async def _nonstream_chat_completion(
 | |
|         self, request: ChatCompletionRequest
 | |
|     ) -> ChatCompletionResponse:
 | |
|         def impl():
 | |
|             tokens = []
 | |
|             logprobs = []
 | |
|             stop_reason = None
 | |
| 
 | |
|             for token_result in self.generator.chat_completion(request):
 | |
|                 tokens.append(token_result.token)
 | |
| 
 | |
|                 if token_result.text == "<|eot_id|>":
 | |
|                     stop_reason = StopReason.end_of_turn
 | |
|                 elif token_result.text == "<|eom_id|>":
 | |
|                     stop_reason = StopReason.end_of_message
 | |
| 
 | |
|                 if request.logprobs:
 | |
|                     assert len(token_result.logprobs) == 1
 | |
| 
 | |
|                     logprobs.append(
 | |
|                         TokenLogProbs(
 | |
|                             logprobs_by_token={
 | |
|                                 token_result.text: token_result.logprobs[0]
 | |
|                             }
 | |
|                         )
 | |
|                     )
 | |
| 
 | |
|             if stop_reason is None:
 | |
|                 stop_reason = StopReason.out_of_tokens
 | |
| 
 | |
|             raw_message = self.generator.formatter.decode_assistant_message(
 | |
|                 tokens, stop_reason
 | |
|             )
 | |
|             return ChatCompletionResponse(
 | |
|                 completion_message=CompletionMessage(
 | |
|                     content=raw_message.content,
 | |
|                     stop_reason=raw_message.stop_reason,
 | |
|                     tool_calls=raw_message.tool_calls,
 | |
|                 ),
 | |
|                 logprobs=logprobs if request.logprobs else None,
 | |
|             )
 | |
| 
 | |
|         if self.config.create_distributed_process_group:
 | |
|             async with SEMAPHORE:
 | |
|                 return impl()
 | |
|         else:
 | |
|             return impl()
 | |
| 
 | |
|     async def _stream_chat_completion(
 | |
|         self, request: ChatCompletionRequest
 | |
|     ) -> AsyncGenerator:
 | |
|         def impl():
 | |
|             yield ChatCompletionResponseStreamChunk(
 | |
|                 event=ChatCompletionResponseEvent(
 | |
|                     event_type=ChatCompletionResponseEventType.start,
 | |
|                     delta=TextDelta(text=""),
 | |
|                 )
 | |
|             )
 | |
| 
 | |
|             tokens = []
 | |
|             logprobs = []
 | |
|             stop_reason = None
 | |
|             ipython = False
 | |
| 
 | |
|             for token_result in self.generator.chat_completion(request):
 | |
|                 tokens.append(token_result.token)
 | |
| 
 | |
|                 if not ipython and token_result.text.startswith("<|python_tag|>"):
 | |
|                     ipython = True
 | |
|                     yield ChatCompletionResponseStreamChunk(
 | |
|                         event=ChatCompletionResponseEvent(
 | |
|                             event_type=ChatCompletionResponseEventType.progress,
 | |
|                             delta=ToolCallDelta(
 | |
|                                 content="",
 | |
|                                 parse_status=ToolCallParseStatus.started,
 | |
|                             ),
 | |
|                         )
 | |
|                     )
 | |
|                     continue
 | |
| 
 | |
|                 if token_result.text == "<|eot_id|>":
 | |
|                     stop_reason = StopReason.end_of_turn
 | |
|                     text = ""
 | |
|                 elif token_result.text == "<|eom_id|>":
 | |
|                     stop_reason = StopReason.end_of_message
 | |
|                     text = ""
 | |
|                 else:
 | |
|                     text = token_result.text
 | |
| 
 | |
|                 if ipython:
 | |
|                     delta = ToolCallDelta(
 | |
|                         content=text,
 | |
|                         parse_status=ToolCallParseStatus.in_progress,
 | |
|                     )
 | |
|                 else:
 | |
|                     delta = TextDelta(text=text)
 | |
| 
 | |
|                 if stop_reason is None:
 | |
|                     if request.logprobs:
 | |
|                         assert len(token_result.logprobs) == 1
 | |
| 
 | |
|                         logprobs.append(
 | |
|                             TokenLogProbs(
 | |
|                                 logprobs_by_token={
 | |
|                                     token_result.text: token_result.logprobs[0]
 | |
|                                 }
 | |
|                             )
 | |
|                         )
 | |
|                     yield ChatCompletionResponseStreamChunk(
 | |
|                         event=ChatCompletionResponseEvent(
 | |
|                             event_type=ChatCompletionResponseEventType.progress,
 | |
|                             delta=delta,
 | |
|                             stop_reason=stop_reason,
 | |
|                             logprobs=logprobs if request.logprobs else None,
 | |
|                         )
 | |
|                     )
 | |
| 
 | |
|             if stop_reason is None:
 | |
|                 stop_reason = StopReason.out_of_tokens
 | |
| 
 | |
|             message = self.generator.formatter.decode_assistant_message(
 | |
|                 tokens, stop_reason
 | |
|             )
 | |
| 
 | |
|             parsed_tool_calls = len(message.tool_calls) > 0
 | |
|             if ipython and not parsed_tool_calls:
 | |
|                 yield ChatCompletionResponseStreamChunk(
 | |
|                     event=ChatCompletionResponseEvent(
 | |
|                         event_type=ChatCompletionResponseEventType.progress,
 | |
|                         delta=ToolCallDelta(
 | |
|                             content="",
 | |
|                             parse_status=ToolCallParseStatus.failed,
 | |
|                         ),
 | |
|                         stop_reason=stop_reason,
 | |
|                     )
 | |
|                 )
 | |
| 
 | |
|             for tool_call in message.tool_calls:
 | |
|                 yield ChatCompletionResponseStreamChunk(
 | |
|                     event=ChatCompletionResponseEvent(
 | |
|                         event_type=ChatCompletionResponseEventType.progress,
 | |
|                         delta=ToolCallDelta(
 | |
|                             content=tool_call,
 | |
|                             parse_status=ToolCallParseStatus.succeeded,
 | |
|                         ),
 | |
|                         stop_reason=stop_reason,
 | |
|                     )
 | |
|                 )
 | |
| 
 | |
|             yield ChatCompletionResponseStreamChunk(
 | |
|                 event=ChatCompletionResponseEvent(
 | |
|                     event_type=ChatCompletionResponseEventType.complete,
 | |
|                     delta=TextDelta(text=""),
 | |
|                     stop_reason=stop_reason,
 | |
|                 )
 | |
|             )
 | |
| 
 | |
|         if self.config.create_distributed_process_group:
 | |
|             async with SEMAPHORE:
 | |
|                 for x in impl():
 | |
|                     yield x
 | |
|         else:
 | |
|             for x in impl():
 | |
|                 yield x
 |