Fix Meta reference GPU implementation (#663)

Performing in-place mutations on the request objects lost their type: assigning converted fields to a ChatCompletionRequest or CompletionRequest never turns it into the *WithRawContent subclass, so the worker-side isinstance dispatch in ModelRunner failed. Never mutate a request in place; rebuild it as the intended type instead.
Ashwin Bharambe 2024-12-19 14:09:45 -08:00 committed by GitHub
parent f19eb8eee3
commit 540fc4d717
2 changed files with 15 additions and 7 deletions
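The root cause: the conversion helper mutated the incoming Pydantic request in place, which never changes the object's concrete class, so the worker-side isinstance checks against the *WithRawContent types could not match. A minimal sketch of the failure mode and the fix, using hypothetical stand-in models rather than the real llama-stack classes:

    from pydantic import BaseModel

    class CompletionRequest(BaseModel):
        content: str

    class CompletionRequestWithRawContent(CompletionRequest):
        pass

    req = CompletionRequest(content="raw")

    # Broken: assigning to a field mutates state but never changes type(req),
    # so isinstance(req, CompletionRequestWithRawContent) stays False.
    req.content = "converted"
    assert not isinstance(req, CompletionRequestWithRawContent)

    # Fixed: dump the fields and rebuild the request as the subclass.
    d = req.model_dump()
    d["content"] = "converted"
    req = CompletionRequestWithRawContent(**d)
    assert isinstance(req, CompletionRequestWithRawContent)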

llama_stack/providers/inline/inference/meta_reference/model_parallel.py

@@ -14,7 +14,10 @@ from llama_models.llama3.api.datatypes import Model
 from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_models.sku_list import resolve_model
 
-from llama_stack.apis.inference import ChatCompletionRequest, CompletionRequest
+from llama_stack.providers.utils.inference.prompt_adapter import (
+    ChatCompletionRequestWithRawContent,
+    CompletionRequestWithRawContent,
+)
 
 from .config import MetaReferenceInferenceConfig
 from .generation import Llama, model_checkpoint_dir
@@ -27,9 +30,9 @@ class ModelRunner:
 
     # the `task` object is the same that is sent to `ModelParallelProcessGroup.run_inference()`
     def __call__(self, req: Any):
-        if isinstance(req, ChatCompletionRequest):
+        if isinstance(req, ChatCompletionRequestWithRawContent):
             return self.llama.chat_completion(req)
-        elif isinstance(req, CompletionRequest):
+        elif isinstance(req, CompletionRequestWithRawContent):
             return self.llama.completion(req)
         else:
             raise ValueError(f"Unexpected task type {type(req)}")
@@ -100,7 +103,7 @@ class LlamaModelParallelGenerator:
 
     def completion(
         self,
-        request: CompletionRequest,
+        request: CompletionRequestWithRawContent,
     ) -> Generator:
         req_obj = deepcopy(request)
         gen = self.group.run_inference(req_obj)
@@ -108,7 +111,7 @@ class LlamaModelParallelGenerator:
 
     def chat_completion(
         self,
-        request: ChatCompletionRequest,
+        request: ChatCompletionRequestWithRawContent,
     ) -> Generator:
         req_obj = deepcopy(request)
        gen = self.group.run_inference(req_obj)
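The dispatch in ModelRunner.__call__ only works if the request reaches the worker with its concrete class intact. Both deepcopy and pickle-style serialization preserve the class, so once the caller constructs the *WithRawContent type it survives the hand-off; a small sketch under the same stand-in models as above (the exact transport used by the model-parallel process group may differ):

    import pickle
    from copy import deepcopy
    from pydantic import BaseModel

    class CompletionRequest(BaseModel):
        content: str

    class CompletionRequestWithRawContent(CompletionRequest):
        pass

    req = CompletionRequestWithRawContent(content="converted")

    # deepcopy (used by the generator methods above) keeps the subclass...
    assert type(deepcopy(req)) is CompletionRequestWithRawContent
    # ...and so does a pickle round trip, so the worker-side isinstance
    # dispatch sees exactly the type the caller built.
    assert type(pickle.loads(pickle.dumps(req))) is CompletionRequestWithRawContent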

llama_stack/providers/utils/inference/prompt_adapter.py

@@ -94,9 +94,14 @@ async def convert_request_to_raw(
             content = await interleaved_content_convert_to_raw(m.content)
             d = m.model_dump()
             d["content"] = content
             messages.append(RawMessage(**d))
-        request.messages = messages
+
+        d = request.model_dump()
+        d["messages"] = messages
+        request = ChatCompletionRequestWithRawContent(**d)
     else:
-        request.content = await interleaved_content_convert_to_raw(request.content)
+        d = request.model_dump()
+        d["content"] = await interleaved_content_convert_to_raw(request.content)
+        request = CompletionRequestWithRawContent(**d)
 
     return request
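Taken together, a hedged end-to-end sketch of the corrected conversion pattern (stand-in models with only the fields needed here; the real RawMessage and interleaved-content handling is richer, and the real function is async):

    from pydantic import BaseModel

    class RawMessage(BaseModel):
        role: str
        content: str

    class ChatCompletionRequest(BaseModel):
        messages: list[RawMessage]

    class ChatCompletionRequestWithRawContent(ChatCompletionRequest):
        pass

    def convert_request_to_raw(request: ChatCompletionRequest) -> ChatCompletionRequestWithRawContent:
        # Rebuild each message (the real code awaits
        # interleaved_content_convert_to_raw on the content first)...
        messages = [RawMessage(**m.model_dump()) for m in request.messages]
        # ...then rebuild the request as the WithRawContent subclass instead
        # of assigning request.messages in place.
        d = request.model_dump()
        d["messages"] = messages
        return ChatCompletionRequestWithRawContent(**d)

    req = convert_request_to_raw(
        ChatCompletionRequest(messages=[RawMessage(role="user", content="hi")])
    )
    assert isinstance(req, ChatCompletionRequestWithRawContent)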