Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-02 08:44:44 +00:00)

Commit 693c709c27 ("fixes"), parent a5d6ab16b2
2 changed files with 103 additions and 34 deletions
@@ -8,9 +8,6 @@ import asyncio
 import os
 from typing import AsyncGenerator, List, Optional, Union
 
-from pydantic import BaseModel
-from termcolor import cprint
-
 from llama_stack.apis.common.content_types import (
     TextDelta,
     ToolCallDelta,
@@ -55,8 +52,8 @@ from llama_stack.providers.utils.inference.embedding_mixin import (
     SentenceTransformerEmbeddingMixin,
 )
 from llama_stack.providers.utils.inference.model_registry import (
-    ModelRegistryHelper,
     build_hf_repo_model_entry,
+    ModelRegistryHelper,
 )
 from llama_stack.providers.utils.inference.openai_compat import (
     OpenAIChatCompletionToLlamaStackMixin,
@@ -68,6 +65,9 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
     convert_request_to_raw,
 )
 
+from pydantic import BaseModel
+from termcolor import cprint
+
 from .config import MetaReferenceInferenceConfig
 from .generators import LlamaGenerator
 from .model_parallel import LlamaModelParallelGenerator
@@ -78,7 +78,9 @@ log = get_logger(__name__, category="inference")
 SEMAPHORE = asyncio.Semaphore(1)
 
 
-def llama_builder_fn(config: MetaReferenceInferenceConfig, model_id: str, llama_model: Model) -> LlamaGenerator:
+def llama_builder_fn(
+    config: MetaReferenceInferenceConfig, model_id: str, llama_model: Model
+) -> LlamaGenerator:
     return LlamaGenerator(config, model_id, llama_model)
 
 
@@ -143,7 +145,8 @@ class MetaReferenceInferenceImpl(
 
         if self.config.create_distributed_process_group:
             self.generator = LlamaModelParallelGenerator(
-                model_parallel_size=self.config.model_parallel_size or llama_model.pth_file_count,
+                model_parallel_size=self.config.model_parallel_size
+                or llama_model.pth_file_count,
                 builder_fn=llama_builder_fn,
                 builder_params=builder_params,
                 formatter=(
@@ -178,7 +181,9 @@ class MetaReferenceInferenceImpl(
                 "No avaible model yet, please register your requested model or add your model in the resouces first"
             )
         elif request.model != self.model_id:
-            raise RuntimeError(f"Model mismatch: request model: {request.model} != loaded model: {self.model_id}")
+            raise RuntimeError(
+                f"Model mismatch: request model: {request.model} != loaded model: {self.model_id}"
+            )
 
     async def completion(
         self,
@@ -227,7 +232,8 @@ class MetaReferenceInferenceImpl(
             assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
 
         content_batch = [
-            augment_content_with_response_format_prompt(response_format, content) for content in content_batch
+            augment_content_with_response_format_prompt(response_format, content)
+            for content in content_batch
        ]
 
         request_batch = []
@@ -253,7 +259,8 @@ class MetaReferenceInferenceImpl(
         def impl():
             stop_reason = None
 
-            for token_result in self.generator.completion(request):
+            for token_results in self.generator.completion([request]):
+                token_result = token_results[0]
                 if token_result.token == tokenizer.eot_id:
                     stop_reason = StopReason.end_of_turn
                     text = ""
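The change in this hunk is behavioral rather than cosmetic: the streaming path now hands the generator a one-element batch and unwraps the first entry of each step's results. A minimal sketch of that wrap-and-unwrap pattern, assuming a batched generator that yields one result per request per step (the names below are illustrative, not the repository's API):

from typing import Iterable, Iterator, List


def batched_generate(requests: List[str]) -> Iterable[List[str]]:
    # Toy batched generator: each step yields a list with one entry per request.
    for step in range(3):
        yield [f"{req}-token{step}" for req in requests]


def stream_single(request: str) -> Iterator[str]:
    # Wrap the single request as a one-element batch, then take element 0
    # of each step's results, mirroring the hunk above.
    for step_results in batched_generate([request]):
        yield step_results[0]


print(list(stream_single("hello")))  # ['hello-token0', 'hello-token1', 'hello-token2']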
@@ -268,7 +275,13 @@ class MetaReferenceInferenceImpl(
                 if request.logprobs:
                     assert len(token_result.logprobs) == 1
 
-                    logprobs = [TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]})]
+                    logprobs = [
+                        TokenLogProbs(
+                            logprobs_by_token={
+                                token_result.text: token_result.logprobs[0]
+                            }
+                        )
+                    ]
 
                 yield CompletionResponseStreamChunk(
                     delta=text,
@@ -290,7 +303,9 @@ class MetaReferenceInferenceImpl(
             for x in impl():
                 yield x
 
-    async def _nonstream_completion(self, request_batch: List[CompletionRequest]) -> List[CompletionResponse]:
+    async def _nonstream_completion(
+        self, request_batch: List[CompletionRequest]
+    ) -> List[CompletionResponse]:
         tokenizer = self.generator.formatter.tokenizer
 
         first_request = request_batch[0]
@@ -314,7 +329,11 @@ class MetaReferenceInferenceImpl(
 
                 state.finished = result.finished
                 if first_request.logprobs:
-                    state.logprobs.append(TokenLogProbs(logprobs_by_token={result.text: result.logprobs[0]}))
+                    state.logprobs.append(
+                        TokenLogProbs(
+                            logprobs_by_token={result.text: result.logprobs[0]}
+                        )
+                    )
 
                 state.tokens.append(result.token)
                 if result.token == tokenizer.eot_id:
@@ -377,7 +396,9 @@ class MetaReferenceInferenceImpl(
         self.check_model(request)
 
         # augment and rewrite messages depending on the model
-        request.messages = chat_completion_request_to_messages(request, self.llama_model.core_model_id.value)
+        request.messages = chat_completion_request_to_messages(
+            request, self.llama_model.core_model_id.value
+        )
         # download media and convert to raw content so we can send it to the model
         request = await convert_request_to_raw(request)
 
@@ -422,7 +443,9 @@ class MetaReferenceInferenceImpl(
             self.check_model(request)
 
             # augment and rewrite messages depending on the model
-            request.messages = chat_completion_request_to_messages(request, self.llama_model.core_model_id.value)
+            request.messages = chat_completion_request_to_messages(
+                request, self.llama_model.core_model_id.value
+            )
             # download media and convert to raw content so we can send it to the model
             request = await convert_request_to_raw(request)
             request_batch.append(request)
@@ -466,7 +489,11 @@ class MetaReferenceInferenceImpl(
 
                 state.finished = result.finished
                 if first_request.logprobs:
-                    state.logprobs.append(TokenLogProbs(logprobs_by_token={result.text: result.logprobs[0]}))
+                    state.logprobs.append(
+                        TokenLogProbs(
+                            logprobs_by_token={result.text: result.logprobs[0]}
+                        )
+                    )
 
                 state.tokens.append(result.token)
                 if result.token == tokenizer.eot_id:
@@ -479,7 +506,9 @@ class MetaReferenceInferenceImpl(
             if state.stop_reason is None:
                 state.stop_reason = StopReason.out_of_tokens
 
-            raw_message = self.generator.formatter.decode_assistant_message(state.tokens, state.stop_reason)
+            raw_message = self.generator.formatter.decode_assistant_message(
+                state.tokens, state.stop_reason
+            )
             results.append(
                 ChatCompletionResponse(
                     completion_message=CompletionMessage(
@@ -499,7 +528,9 @@ class MetaReferenceInferenceImpl(
         else:
             return impl()
 
-    async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
+    async def _stream_chat_completion(
+        self, request: ChatCompletionRequest
+    ) -> AsyncGenerator:
         tokenizer = self.generator.formatter.tokenizer
 
         def impl():
@@ -534,7 +565,13 @@ class MetaReferenceInferenceImpl(
                 if request.logprobs:
                     assert len(token_result.logprobs) == 1
 
-                    logprobs.append(TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]}))
+                    logprobs.append(
+                        TokenLogProbs(
+                            logprobs_by_token={
+                                token_result.text: token_result.logprobs[0]
+                            }
+                        )
+                    )
 
                 tokens.append(token_result.token)
 
@@ -572,7 +609,13 @@ class MetaReferenceInferenceImpl(
                 if request.logprobs:
                     assert len(token_result.logprobs) == 1
 
-                    logprobs.append(TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]}))
+                    logprobs.append(
+                        TokenLogProbs(
+                            logprobs_by_token={
+                                token_result.text: token_result.logprobs[0]
+                            }
+                        )
+                    )
                 yield ChatCompletionResponseStreamChunk(
                     event=ChatCompletionResponseEvent(
                         event_type=ChatCompletionResponseEventType.progress,
@@ -585,7 +628,9 @@ class MetaReferenceInferenceImpl(
             if stop_reason is None:
                 stop_reason = StopReason.out_of_tokens
 
-            message = self.generator.formatter.decode_assistant_message(tokens, stop_reason)
+            message = self.generator.formatter.decode_assistant_message(
+                tokens, stop_reason
+            )
 
             parsed_tool_calls = len(message.tool_calls) > 0
             if ipython and not parsed_tool_calls:
@@ -28,15 +28,15 @@ from fairscale.nn.model_parallel.initialize import (
     get_model_parallel_rank,
     get_model_parallel_src_rank,
 )
-from pydantic import BaseModel, Field
-from torch.distributed.launcher.api import LaunchConfig, elastic_launch
-from typing_extensions import Annotated
 
 from llama_stack.models.llama.datatypes import GenerationResult
 from llama_stack.providers.utils.inference.prompt_adapter import (
     ChatCompletionRequestWithRawContent,
     CompletionRequestWithRawContent,
 )
+from pydantic import BaseModel, Field
+from torch.distributed.launcher.api import elastic_launch, LaunchConfig
+from typing_extensions import Annotated
 
 log = logging.getLogger(__name__)
 
@@ -52,33 +52,51 @@ class ProcessingMessageName(str, Enum):
 
 
 class ReadyRequest(BaseModel):
-    type: Literal[ProcessingMessageName.ready_request] = ProcessingMessageName.ready_request
+    type: Literal[ProcessingMessageName.ready_request] = (
+        ProcessingMessageName.ready_request
+    )
 
 
 class ReadyResponse(BaseModel):
-    type: Literal[ProcessingMessageName.ready_response] = ProcessingMessageName.ready_response
+    type: Literal[ProcessingMessageName.ready_response] = (
+        ProcessingMessageName.ready_response
+    )
 
 
 class EndSentinel(BaseModel):
-    type: Literal[ProcessingMessageName.end_sentinel] = ProcessingMessageName.end_sentinel
+    type: Literal[ProcessingMessageName.end_sentinel] = (
+        ProcessingMessageName.end_sentinel
+    )
 
 
 class CancelSentinel(BaseModel):
-    type: Literal[ProcessingMessageName.cancel_sentinel] = ProcessingMessageName.cancel_sentinel
+    type: Literal[ProcessingMessageName.cancel_sentinel] = (
+        ProcessingMessageName.cancel_sentinel
+    )
 
 
 class TaskRequest(BaseModel):
-    type: Literal[ProcessingMessageName.task_request] = ProcessingMessageName.task_request
-    task: Tuple[str, List[CompletionRequestWithRawContent] | List[ChatCompletionRequestWithRawContent]]
+    type: Literal[ProcessingMessageName.task_request] = (
+        ProcessingMessageName.task_request
+    )
+    task: Tuple[
+        str,
+        List[CompletionRequestWithRawContent]
+        | List[ChatCompletionRequestWithRawContent],
+    ]
 
 
 class TaskResponse(BaseModel):
-    type: Literal[ProcessingMessageName.task_response] = ProcessingMessageName.task_response
+    type: Literal[ProcessingMessageName.task_response] = (
+        ProcessingMessageName.task_response
+    )
     result: List[GenerationResult]
 
 
 class ExceptionResponse(BaseModel):
-    type: Literal[ProcessingMessageName.exception_response] = ProcessingMessageName.exception_response
+    type: Literal[ProcessingMessageName.exception_response] = (
+        ProcessingMessageName.exception_response
+    )
     error: str
 
 
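The models above all follow the same pattern: a Literal "type" field acts as a discriminator so a received processing message can be routed to the right class. A minimal, self-contained sketch of that pattern with Pydantic v2 (toy models using plain string literals instead of the module's ProcessingMessageName enum; the discriminated-union wiring shown here is an assumption about intent, not copied from the file):

from typing import Annotated, Literal, Union

from pydantic import BaseModel, Field, TypeAdapter


class CancelMsg(BaseModel):
    type: Literal["cancel_sentinel"] = "cancel_sentinel"


class TaskMsg(BaseModel):
    type: Literal["task_request"] = "task_request"
    payload: str


# The "type" field selects which concrete model a raw message is parsed into.
Message = Annotated[Union[CancelMsg, TaskMsg], Field(discriminator="type")]
adapter = TypeAdapter(Message)

msg = adapter.validate_json('{"type": "task_request", "payload": "hi"}')
assert isinstance(msg, TaskMsg) and msg.payload == "hi"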
@@ -172,7 +190,9 @@ def retrieve_requests(reply_socket_url: str):
                 group=get_model_parallel_group(),
             )
             if isinstance(updates[0], CancelSentinel):
-                log.info("quitting generation loop because request was cancelled")
+                log.info(
+                    "quitting generation loop because request was cancelled"
+                )
                 break
 
         if mp_rank_0():
@@ -234,7 +254,7 @@ def worker_process_entrypoint(
             if isinstance(task, EndSentinel):
                 break
 
-            assert isinstance(task, TaskRequest)
+            assert isinstance(task, TaskRequest), task
             result = model(task.task)
         except StopIteration:
             break
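The only substantive change in this hunk attaches the received object to the assertion, so a failure reports what the worker actually got instead of a bare AssertionError. A tiny illustration (toy values, not the worker's types):

def handle(task: object) -> str:
    # With the value attached, a failure prints e.g.
    # "AssertionError: {'oops': 1}" rather than just "AssertionError".
    assert isinstance(task, str), task
    return task.upper()


print(handle("ok"))      # OK
# handle({"oops": 1})    # would raise: AssertionError: {'oops': 1}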
@@ -331,7 +351,11 @@ class ModelParallelProcessGroup:
 
     def run_inference(
         self,
-        req: Tuple[str, List[CompletionRequestWithRawContent] | List[ChatCompletionRequestWithRawContent]],
+        req: Tuple[
+            str,
+            List[CompletionRequestWithRawContent]
+            | List[ChatCompletionRequestWithRawContent],
+        ],
     ) -> Generator:
         assert not self.running, "inference already running"
 
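run_inference now spells out the task tuple: a task-kind string paired with a homogeneous batch of completion or chat-completion requests. A hypothetical dispatcher over that shape (the "completion" / "chat_completion" tags and the toy request types are assumptions, not taken from the file):

from typing import List, Tuple, Union


def dispatch(task: Tuple[str, Union[List[str], List[dict]]]) -> str:
    # Toy dispatcher mirroring the (task_kind, request_batch) tuple shape.
    task_kind, batch = task
    if task_kind == "completion":
        return f"running {len(batch)} completion request(s)"
    if task_kind == "chat_completion":
        return f"running {len(batch)} chat completion request(s)"
    raise ValueError(f"unknown task kind: {task_kind}")


print(dispatch(("completion", ["tell me a joke"])))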