forked from phoenix-oss/llama-stack-mirror
Fix Meta reference GPU implementation (#663)
By performing in-place mutations, we lost. Never do that.
commit 540fc4d717
parent f19eb8eee3
2 changed files with 15 additions and 7 deletions
@@ -14,7 +14,10 @@ from llama_models.llama3.api.datatypes import Model
 from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_models.sku_list import resolve_model
 
-from llama_stack.apis.inference import ChatCompletionRequest, CompletionRequest
+from llama_stack.providers.utils.inference.prompt_adapter import (
+    ChatCompletionRequestWithRawContent,
+    CompletionRequestWithRawContent,
+)
 
 from .config import MetaReferenceInferenceConfig
 from .generation import Llama, model_checkpoint_dir
@@ -27,9 +30,9 @@ class ModelRunner:
 
     # the `task` object is the same that is sent to `ModelParallelProcessGroup.run_inference()`
     def __call__(self, req: Any):
-        if isinstance(req, ChatCompletionRequest):
+        if isinstance(req, ChatCompletionRequestWithRawContent):
             return self.llama.chat_completion(req)
-        elif isinstance(req, CompletionRequest):
+        elif isinstance(req, CompletionRequestWithRawContent):
             return self.llama.completion(req)
         else:
             raise ValueError(f"Unexpected task type {type(req)}")
@@ -100,7 +103,7 @@ class LlamaModelParallelGenerator:
 
     def completion(
         self,
-        request: CompletionRequest,
+        request: CompletionRequestWithRawContent,
     ) -> Generator:
         req_obj = deepcopy(request)
         gen = self.group.run_inference(req_obj)
@@ -108,7 +111,7 @@ class LlamaModelParallelGenerator:
 
     def chat_completion(
         self,
-        request: ChatCompletionRequest,
+        request: ChatCompletionRequestWithRawContent,
     ) -> Generator:
         req_obj = deepcopy(request)
         gen = self.group.run_inference(req_obj)
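
Note on the first changed file: ModelRunner.__call__ picks the generation path purely by the request's class, so whatever object is sent through ModelParallelProcessGroup.run_inference() must actually be one of the *WithRawContent classes; a request that was only mutated in place would no longer match. A minimal sketch of that dispatch, using hypothetical stand-in classes rather than the real llama-stack models:

# Sketch only: the class names mirror the diff, but these toy definitions are
# stand-ins, not the llama-stack request models.
from typing import Any


class ChatCompletionRequestWithRawContent:
    def __init__(self, messages: list[bytes]) -> None:
        self.messages = messages


class CompletionRequestWithRawContent:
    def __init__(self, content: bytes) -> None:
        self.content = content


class ModelRunner:
    def __call__(self, req: Any) -> str:
        # Dispatch on the concrete request class, as in the hunk above.
        if isinstance(req, ChatCompletionRequestWithRawContent):
            return "chat_completion"
        elif isinstance(req, CompletionRequestWithRawContent):
            return "completion"
        else:
            # A request that was never rebuilt as a raw-content type lands here.
            raise ValueError(f"Unexpected task type {type(req)}")


runner = ModelRunner()
print(runner(ChatCompletionRequestWithRawContent([b"hello"])))  # chat_completion
print(runner(CompletionRequestWithRawContent(b"hello")))        # completion

The annotation changes on completion() and chat_completion() serve the same purpose: past this point, only requests already converted to raw content are expected.
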
@@ -94,9 +94,14 @@ async def convert_request_to_raw(
             d = m.model_dump()
             d["content"] = content
             messages.append(RawMessage(**d))
-        request.messages = messages
+
+        d = request.model_dump()
+        d["messages"] = messages
+        request = ChatCompletionRequestWithRawContent(**d)
     else:
-        request.content = await interleaved_content_convert_to_raw(request.content)
+        d = request.model_dump()
+        d["content"] = await interleaved_content_convert_to_raw(request.content)
+        request = CompletionRequestWithRawContent(**d)
 
     return request
 
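
The second changed file is where the commit message's warning applies: the old code assigned the converted raw content back onto the original request, which leaves the object an instance of the old class, so nothing downstream can tell from its type that conversion actually happened. The new code dumps the request and rebuilds it as the matching *WithRawContent model. A short sketch of the two approaches, using hypothetical pydantic models rather than the real llama-stack request classes:

# Sketch only: hypothetical pydantic models, not the llama-stack classes.
from pydantic import BaseModel


class Message(BaseModel):
    role: str
    content: str


class RawMessage(BaseModel):
    role: str
    content: bytes


class ChatCompletionRequest(BaseModel):
    model: str
    messages: list[Message]


class ChatCompletionRequestWithRawContent(BaseModel):
    model: str
    messages: list[RawMessage]


request = ChatCompletionRequest(model="demo", messages=[Message(role="user", content="hi")])
raw_messages = [RawMessage(role="user", content=b"hi")]

# Rebuild, as the diff now does: dump the original request, swap in the raw
# messages, and construct the raw-content variant.
d = request.model_dump()
d["messages"] = raw_messages
converted = ChatCompletionRequestWithRawContent(**d)
assert isinstance(converted, ChatCompletionRequestWithRawContent)

# In-place mutation, as the removed lines did: the object keeps its original
# class, so isinstance() checks like the ones in ModelRunner cannot see that
# the content was converted.
request.messages = raw_messages  # type: ignore[assignment]
assert not isinstance(request, ChatCompletionRequestWithRawContent)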