Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-03 01:03:59 +00:00)

Commit 693c709c27 ("fixes")
Parent commit: a5d6ab16b2

2 changed files with 103 additions and 34 deletions
First changed file (the meta-reference inference implementation):

@@ -8,9 +8,6 @@ import asyncio
 import os
 from typing import AsyncGenerator, List, Optional, Union
 
-from pydantic import BaseModel
-from termcolor import cprint
-
 from llama_stack.apis.common.content_types import (
     TextDelta,
     ToolCallDelta,
@@ -55,8 +52,8 @@ from llama_stack.providers.utils.inference.embedding_mixin import (
     SentenceTransformerEmbeddingMixin,
 )
 from llama_stack.providers.utils.inference.model_registry import (
-    ModelRegistryHelper,
     build_hf_repo_model_entry,
+    ModelRegistryHelper,
 )
 from llama_stack.providers.utils.inference.openai_compat import (
     OpenAIChatCompletionToLlamaStackMixin,
@@ -68,6 +65,9 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
     convert_request_to_raw,
 )
 
+from pydantic import BaseModel
+from termcolor import cprint
+
 from .config import MetaReferenceInferenceConfig
 from .generators import LlamaGenerator
 from .model_parallel import LlamaModelParallelGenerator
@@ -78,7 +78,9 @@ log = get_logger(__name__, category="inference")
 SEMAPHORE = asyncio.Semaphore(1)
 
 
-def llama_builder_fn(config: MetaReferenceInferenceConfig, model_id: str, llama_model: Model) -> LlamaGenerator:
+def llama_builder_fn(
+    config: MetaReferenceInferenceConfig, model_id: str, llama_model: Model
+) -> LlamaGenerator:
     return LlamaGenerator(config, model_id, llama_model)
 
 
@@ -143,7 +145,8 @@ class MetaReferenceInferenceImpl(
 
         if self.config.create_distributed_process_group:
             self.generator = LlamaModelParallelGenerator(
-                model_parallel_size=self.config.model_parallel_size or llama_model.pth_file_count,
+                model_parallel_size=self.config.model_parallel_size
+                or llama_model.pth_file_count,
                 builder_fn=llama_builder_fn,
                 builder_params=builder_params,
                 formatter=(
@@ -178,7 +181,9 @@ class MetaReferenceInferenceImpl(
                 "No avaible model yet, please register your requested model or add your model in the resouces first"
             )
         elif request.model != self.model_id:
-            raise RuntimeError(f"Model mismatch: request model: {request.model} != loaded model: {self.model_id}")
+            raise RuntimeError(
+                f"Model mismatch: request model: {request.model} != loaded model: {self.model_id}"
+            )
 
     async def completion(
         self,
@@ -227,7 +232,8 @@ class MetaReferenceInferenceImpl(
             assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
 
         content_batch = [
-            augment_content_with_response_format_prompt(response_format, content) for content in content_batch
+            augment_content_with_response_format_prompt(response_format, content)
+            for content in content_batch
         ]
 
         request_batch = []
@@ -253,7 +259,8 @@ class MetaReferenceInferenceImpl(
         def impl():
             stop_reason = None
 
-            for token_result in self.generator.completion(request):
+            for token_results in self.generator.completion([request]):
+                token_result = token_results[0]
                 if token_result.token == tokenizer.eot_id:
                     stop_reason = StopReason.end_of_turn
                     text = ""
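
The hunk above is the functional core of the commit: the generator's completion call now consumes a batch of requests and yields one list of per-request token results per decoding step, so the single-request path wraps its request in a list and unwraps element 0. A minimal runnable sketch of that consumption pattern; the generator and its batched contract are stand-ins mirroring the diff, not the provider's actual API:

    from dataclasses import dataclass
    from typing import Iterator, List

    @dataclass
    class TokenResult:  # hypothetical stand-in for the generator's per-token output
        token: int
        text: str

    def fake_batched_completion(requests: List[str]) -> Iterator[List[TokenResult]]:
        # Stand-in for generator.completion([request]): one list per step,
        # one entry per request in the batch.
        for step, piece in enumerate(["Hello", " world", ""]):
            yield [TokenResult(token=(2 if piece == "" else step), text=piece)]

    EOT_ID = 2  # assumed end-of-turn token id for this sketch

    def stream_one(request: str) -> Iterator[str]:
        for token_results in fake_batched_completion([request]):
            token_result = token_results[0]  # single request -> first slot
            if token_result.token == EOT_ID:
                break
            yield token_result.text

    print("".join(stream_one("tell me a joke")))  # -> "Hello world"
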
@@ -268,7 +275,13 @@ class MetaReferenceInferenceImpl(
                 if request.logprobs:
                     assert len(token_result.logprobs) == 1
 
-                    logprobs = [TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]})]
+                    logprobs = [
+                        TokenLogProbs(
+                            logprobs_by_token={
+                                token_result.text: token_result.logprobs[0]
+                            }
+                        )
+                    ]
 
                 yield CompletionResponseStreamChunk(
                     delta=text,
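
For reference, the logprobs chunks in this hunk (and the similar ones further down) build one TokenLogProbs entry per generated token, keyed by the token's text. A small runnable stand-in for that shape, assuming a pydantic model with a logprobs_by_token mapping as used in the diff:

    from typing import Dict
    from pydantic import BaseModel

    class TokenLogProbs(BaseModel):
        # Stand-in mirroring the field used above: token text -> log-probability.
        logprobs_by_token: Dict[str, float]

    logprobs = [TokenLogProbs(logprobs_by_token={"Hello": -0.12})]
    print(logprobs[0].logprobs_by_token)  # {'Hello': -0.12}
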
@@ -290,7 +303,9 @@ class MetaReferenceInferenceImpl(
             for x in impl():
                 yield x
 
-    async def _nonstream_completion(self, request_batch: List[CompletionRequest]) -> List[CompletionResponse]:
+    async def _nonstream_completion(
+        self, request_batch: List[CompletionRequest]
+    ) -> List[CompletionResponse]:
         tokenizer = self.generator.formatter.tokenizer
 
         first_request = request_batch[0]
@@ -314,7 +329,11 @@ class MetaReferenceInferenceImpl(
 
                 state.finished = result.finished
                 if first_request.logprobs:
-                    state.logprobs.append(TokenLogProbs(logprobs_by_token={result.text: result.logprobs[0]}))
+                    state.logprobs.append(
+                        TokenLogProbs(
+                            logprobs_by_token={result.text: result.logprobs[0]}
+                        )
+                    )
 
                 state.tokens.append(result.token)
                 if result.token == tokenizer.eot_id:
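
The batched non-streaming path above accumulates per-request decoding state (tokens, logprobs, finished flag, stop reason). A hypothetical minimal version of such a state holder, just to make the bookkeeping in the hunk concrete; field names follow the usage shown, not a confirmed class from the repo:

    from dataclasses import dataclass, field
    from typing import List, Optional

    @dataclass
    class GenerationState:
        # Hypothetical stand-in for the per-request state tracked above.
        tokens: List[int] = field(default_factory=list)
        logprobs: List[dict] = field(default_factory=list)
        finished: bool = False
        stop_reason: Optional[str] = None

    state = GenerationState()
    state.tokens.append(42)
    state.finished = True
    print(state)
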
@@ -377,7 +396,9 @@ class MetaReferenceInferenceImpl(
         self.check_model(request)
 
         # augment and rewrite messages depending on the model
-        request.messages = chat_completion_request_to_messages(request, self.llama_model.core_model_id.value)
+        request.messages = chat_completion_request_to_messages(
+            request, self.llama_model.core_model_id.value
+        )
         # download media and convert to raw content so we can send it to the model
         request = await convert_request_to_raw(request)
 
@@ -422,7 +443,9 @@ class MetaReferenceInferenceImpl(
             self.check_model(request)
 
             # augment and rewrite messages depending on the model
-            request.messages = chat_completion_request_to_messages(request, self.llama_model.core_model_id.value)
+            request.messages = chat_completion_request_to_messages(
+                request, self.llama_model.core_model_id.value
+            )
             # download media and convert to raw content so we can send it to the model
             request = await convert_request_to_raw(request)
             request_batch.append(request)
@@ -466,7 +489,11 @@ class MetaReferenceInferenceImpl(
 
                 state.finished = result.finished
                 if first_request.logprobs:
-                    state.logprobs.append(TokenLogProbs(logprobs_by_token={result.text: result.logprobs[0]}))
+                    state.logprobs.append(
+                        TokenLogProbs(
+                            logprobs_by_token={result.text: result.logprobs[0]}
+                        )
+                    )
 
                 state.tokens.append(result.token)
                 if result.token == tokenizer.eot_id:
@@ -479,7 +506,9 @@ class MetaReferenceInferenceImpl(
             if state.stop_reason is None:
                 state.stop_reason = StopReason.out_of_tokens
 
-            raw_message = self.generator.formatter.decode_assistant_message(state.tokens, state.stop_reason)
+            raw_message = self.generator.formatter.decode_assistant_message(
+                state.tokens, state.stop_reason
+            )
             results.append(
                 ChatCompletionResponse(
                     completion_message=CompletionMessage(
@@ -499,7 +528,9 @@ class MetaReferenceInferenceImpl(
         else:
             return impl()
 
-    async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
+    async def _stream_chat_completion(
+        self, request: ChatCompletionRequest
+    ) -> AsyncGenerator:
         tokenizer = self.generator.formatter.tokenizer
 
         def impl():
@@ -534,7 +565,13 @@ class MetaReferenceInferenceImpl(
                 if request.logprobs:
                     assert len(token_result.logprobs) == 1
 
-                    logprobs.append(TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]}))
+                    logprobs.append(
+                        TokenLogProbs(
+                            logprobs_by_token={
+                                token_result.text: token_result.logprobs[0]
+                            }
+                        )
+                    )
 
                 tokens.append(token_result.token)
 
@@ -572,7 +609,13 @@ class MetaReferenceInferenceImpl(
                 if request.logprobs:
                     assert len(token_result.logprobs) == 1
 
-                    logprobs.append(TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]}))
+                    logprobs.append(
+                        TokenLogProbs(
+                            logprobs_by_token={
+                                token_result.text: token_result.logprobs[0]
+                            }
+                        )
+                    )
                 yield ChatCompletionResponseStreamChunk(
                     event=ChatCompletionResponseEvent(
                         event_type=ChatCompletionResponseEventType.progress,
@@ -585,7 +628,9 @@ class MetaReferenceInferenceImpl(
             if stop_reason is None:
                 stop_reason = StopReason.out_of_tokens
 
-            message = self.generator.formatter.decode_assistant_message(tokens, stop_reason)
+            message = self.generator.formatter.decode_assistant_message(
+                tokens, stop_reason
+            )
 
             parsed_tool_calls = len(message.tool_calls) > 0
             if ipython and not parsed_tool_calls:

Second changed file (the model-parallel process-group and worker utilities):

@@ -28,15 +28,15 @@ from fairscale.nn.model_parallel.initialize import (
     get_model_parallel_rank,
     get_model_parallel_src_rank,
 )
-from pydantic import BaseModel, Field
-from torch.distributed.launcher.api import LaunchConfig, elastic_launch
-from typing_extensions import Annotated
 
 from llama_stack.models.llama.datatypes import GenerationResult
 from llama_stack.providers.utils.inference.prompt_adapter import (
     ChatCompletionRequestWithRawContent,
     CompletionRequestWithRawContent,
 )
+from pydantic import BaseModel, Field
+from torch.distributed.launcher.api import elastic_launch, LaunchConfig
+from typing_extensions import Annotated
 
 log = logging.getLogger(__name__)
 
@@ -52,33 +52,51 @@ class ProcessingMessageName(str, Enum):
 
 
 class ReadyRequest(BaseModel):
-    type: Literal[ProcessingMessageName.ready_request] = ProcessingMessageName.ready_request
+    type: Literal[ProcessingMessageName.ready_request] = (
+        ProcessingMessageName.ready_request
+    )
 
 
 class ReadyResponse(BaseModel):
-    type: Literal[ProcessingMessageName.ready_response] = ProcessingMessageName.ready_response
+    type: Literal[ProcessingMessageName.ready_response] = (
+        ProcessingMessageName.ready_response
+    )
 
 
 class EndSentinel(BaseModel):
-    type: Literal[ProcessingMessageName.end_sentinel] = ProcessingMessageName.end_sentinel
+    type: Literal[ProcessingMessageName.end_sentinel] = (
+        ProcessingMessageName.end_sentinel
+    )
 
 
 class CancelSentinel(BaseModel):
-    type: Literal[ProcessingMessageName.cancel_sentinel] = ProcessingMessageName.cancel_sentinel
+    type: Literal[ProcessingMessageName.cancel_sentinel] = (
+        ProcessingMessageName.cancel_sentinel
+    )
 
 
 class TaskRequest(BaseModel):
-    type: Literal[ProcessingMessageName.task_request] = ProcessingMessageName.task_request
-    task: Tuple[str, List[CompletionRequestWithRawContent] | List[ChatCompletionRequestWithRawContent]]
+    type: Literal[ProcessingMessageName.task_request] = (
+        ProcessingMessageName.task_request
+    )
+    task: Tuple[
+        str,
+        List[CompletionRequestWithRawContent]
+        | List[ChatCompletionRequestWithRawContent],
+    ]
 
 
 class TaskResponse(BaseModel):
-    type: Literal[ProcessingMessageName.task_response] = ProcessingMessageName.task_response
+    type: Literal[ProcessingMessageName.task_response] = (
+        ProcessingMessageName.task_response
+    )
     result: List[GenerationResult]
 
 
 class ExceptionResponse(BaseModel):
-    type: Literal[ProcessingMessageName.exception_response] = ProcessingMessageName.exception_response
+    type: Literal[ProcessingMessageName.exception_response] = (
+        ProcessingMessageName.exception_response
+    )
     error: str
 
 
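
Each message class above carries a Literal type tag, which is what lets the process-group code route payloads between the host and worker processes as a tagged union. A minimal sketch of parsing such a union with pydantic v2, simplified to plain string tags and a hypothetical parse_message helper:

    import json
    from typing import Literal, Union

    from pydantic import BaseModel, Field, TypeAdapter
    from typing_extensions import Annotated

    # Simplified stand-ins for two of the message classes above.
    class ReadyRequest(BaseModel):
        type: Literal["ready_request"] = "ready_request"

    class ReadyResponse(BaseModel):
        type: Literal["ready_response"] = "ready_response"

    # The "type" field is the discriminator that selects the concrete class.
    Message = Annotated[Union[ReadyRequest, ReadyResponse], Field(discriminator="type")]

    def parse_message(raw: str):
        return TypeAdapter(Message).validate_json(raw)

    print(parse_message(json.dumps({"type": "ready_response"})))  # -> ReadyResponse
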
@@ -172,7 +190,9 @@ def retrieve_requests(reply_socket_url: str):
                     group=get_model_parallel_group(),
                 )
                 if isinstance(updates[0], CancelSentinel):
-                    log.info("quitting generation loop because request was cancelled")
+                    log.info(
+                        "quitting generation loop because request was cancelled"
+                    )
                     break
 
             if mp_rank_0():
@@ -234,7 +254,7 @@ def worker_process_entrypoint(
             if isinstance(task, EndSentinel):
                 break
 
-            assert isinstance(task, TaskRequest)
+            assert isinstance(task, TaskRequest), task
             result = model(task.task)
         except StopIteration:
             break
@@ -331,7 +351,11 @@ class ModelParallelProcessGroup:
 
     def run_inference(
         self,
-        req: Tuple[str, List[CompletionRequestWithRawContent] | List[ChatCompletionRequestWithRawContent]],
+        req: Tuple[
+            str,
+            List[CompletionRequestWithRawContent]
+            | List[ChatCompletionRequestWithRawContent],
+        ],
     ) -> Generator:
         assert not self.running, "inference already running"
 
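
run_inference above receives a (task_name, request_batch) tuple whose second element is either a completion batch or a chat-completion batch. A tiny illustration of dispatching on such a tagged tuple; the task name and payload types here are placeholders for the sketch, not the provider's actual values:

    from typing import List, Tuple, Union

    Request = str  # placeholder payload type for the sketch
    Task = Tuple[str, Union[List[Request], List[dict]]]

    def describe(req: Task) -> str:
        # The first element names the task; the second carries the batch.
        task_name, batch = req
        return f"{task_name}: batch of {len(batch)} request(s)"

    print(describe(("completion", ["tell me a joke"])))  # task name is a placeholder
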