diff --git a/llama_stack/providers/inline/inference/meta_reference/model_parallel.py b/llama_stack/providers/inline/inference/meta_reference/model_parallel.py
index cb422b9b6..97384f4bb 100644
--- a/llama_stack/providers/inline/inference/meta_reference/model_parallel.py
+++ b/llama_stack/providers/inline/inference/meta_reference/model_parallel.py
@@ -14,7 +14,10 @@ from llama_models.llama3.api.datatypes import Model
 from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_models.sku_list import resolve_model
 
-from llama_stack.apis.inference import ChatCompletionRequest, CompletionRequest
+from llama_stack.providers.utils.inference.prompt_adapter import (
+    ChatCompletionRequestWithRawContent,
+    CompletionRequestWithRawContent,
+)
 
 from .config import MetaReferenceInferenceConfig
 from .generation import Llama, model_checkpoint_dir
@@ -27,9 +30,9 @@ class ModelRunner:
 
     # the `task` object is the same that is sent to `ModelParallelProcessGroup.run_inference()`
     def __call__(self, req: Any):
-        if isinstance(req, ChatCompletionRequest):
+        if isinstance(req, ChatCompletionRequestWithRawContent):
             return self.llama.chat_completion(req)
-        elif isinstance(req, CompletionRequest):
+        elif isinstance(req, CompletionRequestWithRawContent):
             return self.llama.completion(req)
         else:
             raise ValueError(f"Unexpected task type {type(req)}")
@@ -100,7 +103,7 @@ class LlamaModelParallelGenerator:
 
     def completion(
         self,
-        request: CompletionRequest,
+        request: CompletionRequestWithRawContent,
     ) -> Generator:
         req_obj = deepcopy(request)
         gen = self.group.run_inference(req_obj)
@@ -108,7 +111,7 @@ class LlamaModelParallelGenerator:
 
     def chat_completion(
        self,
-        request: ChatCompletionRequest,
+        request: ChatCompletionRequestWithRawContent,
     ) -> Generator:
         req_obj = deepcopy(request)
         gen = self.group.run_inference(req_obj)
diff --git a/llama_stack/providers/utils/inference/prompt_adapter.py b/llama_stack/providers/utils/inference/prompt_adapter.py
index 9f034e801..82fcefe54 100644
--- a/llama_stack/providers/utils/inference/prompt_adapter.py
+++ b/llama_stack/providers/utils/inference/prompt_adapter.py
@@ -94,9 +94,14 @@ async def convert_request_to_raw(
             d = m.model_dump()
             d["content"] = content
             messages.append(RawMessage(**d))
-        request.messages = messages
+
+        d = request.model_dump()
+        d["messages"] = messages
+        request = ChatCompletionRequestWithRawContent(**d)
     else:
-        request.content = await interleaved_content_convert_to_raw(request.content)
+        d = request.model_dump()
+        d["content"] = await interleaved_content_convert_to_raw(request.content)
+        request = CompletionRequestWithRawContent(**d)
 
     return request
 
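
For context, a minimal sketch of the convert-then-dispatch flow these changes imply. The `stream_chat` helper and its arguments are hypothetical (not code from this PR); only `convert_request_to_raw`, the `*WithRawContent` types, and `LlamaModelParallelGenerator.chat_completion` come from the diff above.

```python
# Hypothetical caller illustrating the intended order of operations:
# convert the request to its raw-content variant first, then hand it to the
# model-parallel generator, so ModelRunner.__call__ in the worker process
# only ever sees ChatCompletionRequestWithRawContent / CompletionRequestWithRawContent.
from llama_stack.providers.utils.inference.prompt_adapter import (
    convert_request_to_raw,
)


async def stream_chat(generator, request):
    # `generator` is assumed to be a LlamaModelParallelGenerator and `request`
    # a ChatCompletionRequest; both names are placeholders for illustration.
    request_raw = await convert_request_to_raw(request)
    # chat_completion() now takes a ChatCompletionRequestWithRawContent and
    # returns a Generator of inference chunks, which we simply re-yield.
    for chunk in generator.chat_completion(request_raw):
        yield chunk
```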