Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-30 07:14:19 +00:00)

commit 693c709c27 (parent a5d6ab16b2)

    fixes

2 changed files with 103 additions and 34 deletions
@@ -8,9 +8,6 @@ import asyncio
 import os
 from typing import AsyncGenerator, List, Optional, Union
 
-from pydantic import BaseModel
-from termcolor import cprint
-
 from llama_stack.apis.common.content_types import (
     TextDelta,
     ToolCallDelta,
@@ -55,8 +52,8 @@ from llama_stack.providers.utils.inference.embedding_mixin import (
     SentenceTransformerEmbeddingMixin,
 )
 from llama_stack.providers.utils.inference.model_registry import (
-    ModelRegistryHelper,
     build_hf_repo_model_entry,
+    ModelRegistryHelper,
 )
 from llama_stack.providers.utils.inference.openai_compat import (
     OpenAIChatCompletionToLlamaStackMixin,
@@ -68,6 +65,9 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
     convert_request_to_raw,
 )
 
+from pydantic import BaseModel
+from termcolor import cprint
+
 from .config import MetaReferenceInferenceConfig
 from .generators import LlamaGenerator
 from .model_parallel import LlamaModelParallelGenerator
@@ -78,7 +78,9 @@ log = get_logger(__name__, category="inference")
 SEMAPHORE = asyncio.Semaphore(1)
 
 
-def llama_builder_fn(config: MetaReferenceInferenceConfig, model_id: str, llama_model: Model) -> LlamaGenerator:
+def llama_builder_fn(
+    config: MetaReferenceInferenceConfig, model_id: str, llama_model: Model
+) -> LlamaGenerator:
     return LlamaGenerator(config, model_id, llama_model)
 
 
@@ -143,7 +145,8 @@ class MetaReferenceInferenceImpl(
 
         if self.config.create_distributed_process_group:
             self.generator = LlamaModelParallelGenerator(
-                model_parallel_size=self.config.model_parallel_size or llama_model.pth_file_count,
+                model_parallel_size=self.config.model_parallel_size
+                or llama_model.pth_file_count,
                 builder_fn=llama_builder_fn,
                 builder_params=builder_params,
                 formatter=(
@@ -178,7 +181,9 @@ class MetaReferenceInferenceImpl(
                 "No avaible model yet, please register your requested model or add your model in the resouces first"
             )
         elif request.model != self.model_id:
-            raise RuntimeError(f"Model mismatch: request model: {request.model} != loaded model: {self.model_id}")
+            raise RuntimeError(
+                f"Model mismatch: request model: {request.model} != loaded model: {self.model_id}"
+            )
 
     async def completion(
         self,
@@ -227,7 +232,8 @@ class MetaReferenceInferenceImpl(
             assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
 
         content_batch = [
-            augment_content_with_response_format_prompt(response_format, content) for content in content_batch
+            augment_content_with_response_format_prompt(response_format, content)
+            for content in content_batch
         ]
 
         request_batch = []
@@ -253,7 +259,8 @@ class MetaReferenceInferenceImpl(
         def impl():
             stop_reason = None
 
-            for token_result in self.generator.completion(request):
+            for token_results in self.generator.completion([request]):
+                token_result = token_results[0]
                 if token_result.token == tokenizer.eot_id:
                     stop_reason = StopReason.end_of_turn
                     text = ""
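For context on the hunk above: the streaming completion path now drives the generator through a batch-shaped interface, wrapping the single request in a one-element list and unpacking element [0] of each per-step result list. Below is a minimal, self-contained sketch of that consumption pattern; FakeGenerator and TokenResult are hypothetical stand-ins (not llama-stack classes) used only to illustrate the completion([request]) / token_results[0] shape visible in the diff.

from dataclasses import dataclass
from typing import Iterator, List


@dataclass
class TokenResult:
    # Hypothetical stand-in for the generator's per-token output.
    token: int
    text: str


class FakeGenerator:
    # Hypothetical generator with a batch-shaped completion() API:
    # every yielded step contains one TokenResult per request in the batch.
    def completion(self, requests: List[str]) -> Iterator[List[TokenResult]]:
        for step, piece in enumerate(["Hello", ", ", "world"]):
            yield [TokenResult(token=step, text=piece) for _ in requests]


def stream_single(request: str) -> Iterator[str]:
    generator = FakeGenerator()
    # Single-request streaming over the batch API: pass a one-element
    # batch and read index [0] of every batched step.
    for token_results in generator.completion([request]):
        token_result = token_results[0]
        yield token_result.text


print("".join(stream_single("say hi")))  # -> "Hello, world"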
@@ -268,7 +275,13 @@ class MetaReferenceInferenceImpl(
                 if request.logprobs:
                     assert len(token_result.logprobs) == 1
 
-                    logprobs = [TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]})]
+                    logprobs = [
+                        TokenLogProbs(
+                            logprobs_by_token={
+                                token_result.text: token_result.logprobs[0]
+                            }
+                        )
+                    ]
 
                 yield CompletionResponseStreamChunk(
                     delta=text,
@@ -290,7 +303,9 @@ class MetaReferenceInferenceImpl(
         for x in impl():
             yield x
 
-    async def _nonstream_completion(self, request_batch: List[CompletionRequest]) -> List[CompletionResponse]:
+    async def _nonstream_completion(
+        self, request_batch: List[CompletionRequest]
+    ) -> List[CompletionResponse]:
         tokenizer = self.generator.formatter.tokenizer
 
         first_request = request_batch[0]
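The _nonstream_completion signature above takes a list of requests and returns a list of responses, and the following hunks accumulate per-request decoding state (tokens, logprobs, finished flag) as the batched generator advances. A rough sketch of that accumulation shape is below; GenerationState, collect_batch, and the dict-based step format are hypothetical illustrations, not the actual llama-stack types.

from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class GenerationState:
    # Hypothetical per-request accumulator mirroring the state fields
    # visible in the diff (tokens, logprobs, finished).
    tokens: List[int] = field(default_factory=list)
    logprobs: List[Dict[str, float]] = field(default_factory=list)
    finished: bool = False


def collect_batch(
    batched_steps: List[List[dict]], batch_size: int, want_logprobs: bool
) -> List[GenerationState]:
    # One state object per request in the batch.
    states = [GenerationState() for _ in range(batch_size)]
    for step in batched_steps:  # each step holds one result per request
        for state, result in zip(states, step):
            if state.finished:
                continue
            if want_logprobs:
                # Same {text: logprob} mapping shape used by TokenLogProbs.
                state.logprobs.append({result["text"]: result["logprob"]})
            state.tokens.append(result["token"])
            state.finished = result["finished"]
    return states


# Usage with two requests and two generation steps.
steps = [
    [{"token": 1, "text": "a", "logprob": -0.1, "finished": False},
     {"token": 7, "text": "x", "logprob": -0.3, "finished": False}],
    [{"token": 2, "text": "b", "logprob": -0.2, "finished": True},
     {"token": 8, "text": "y", "logprob": -0.4, "finished": True}],
]
print([s.tokens for s in collect_batch(steps, batch_size=2, want_logprobs=True)])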
@@ -314,7 +329,11 @@ class MetaReferenceInferenceImpl(
 
                 state.finished = result.finished
                 if first_request.logprobs:
-                    state.logprobs.append(TokenLogProbs(logprobs_by_token={result.text: result.logprobs[0]}))
+                    state.logprobs.append(
+                        TokenLogProbs(
+                            logprobs_by_token={result.text: result.logprobs[0]}
+                        )
+                    )
 
                 state.tokens.append(result.token)
                 if result.token == tokenizer.eot_id:
@@ -377,7 +396,9 @@ class MetaReferenceInferenceImpl(
         self.check_model(request)
 
         # augment and rewrite messages depending on the model
-        request.messages = chat_completion_request_to_messages(request, self.llama_model.core_model_id.value)
+        request.messages = chat_completion_request_to_messages(
+            request, self.llama_model.core_model_id.value
+        )
         # download media and convert to raw content so we can send it to the model
         request = await convert_request_to_raw(request)
 
@@ -422,7 +443,9 @@ class MetaReferenceInferenceImpl(
             self.check_model(request)
 
             # augment and rewrite messages depending on the model
-            request.messages = chat_completion_request_to_messages(request, self.llama_model.core_model_id.value)
+            request.messages = chat_completion_request_to_messages(
+                request, self.llama_model.core_model_id.value
+            )
             # download media and convert to raw content so we can send it to the model
             request = await convert_request_to_raw(request)
             request_batch.append(request)
@@ -466,7 +489,11 @@ class MetaReferenceInferenceImpl(
 
                 state.finished = result.finished
                 if first_request.logprobs:
-                    state.logprobs.append(TokenLogProbs(logprobs_by_token={result.text: result.logprobs[0]}))
+                    state.logprobs.append(
+                        TokenLogProbs(
+                            logprobs_by_token={result.text: result.logprobs[0]}
+                        )
+                    )
 
                 state.tokens.append(result.token)
                 if result.token == tokenizer.eot_id:
@@ -479,7 +506,9 @@ class MetaReferenceInferenceImpl(
             if state.stop_reason is None:
                 state.stop_reason = StopReason.out_of_tokens
 
-            raw_message = self.generator.formatter.decode_assistant_message(state.tokens, state.stop_reason)
+            raw_message = self.generator.formatter.decode_assistant_message(
+                state.tokens, state.stop_reason
+            )
             results.append(
                 ChatCompletionResponse(
                     completion_message=CompletionMessage(
@@ -499,7 +528,9 @@ class MetaReferenceInferenceImpl(
         else:
             return impl()
 
-    async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
+    async def _stream_chat_completion(
+        self, request: ChatCompletionRequest
+    ) -> AsyncGenerator:
         tokenizer = self.generator.formatter.tokenizer
 
         def impl():
@@ -534,7 +565,13 @@ class MetaReferenceInferenceImpl(
                 if request.logprobs:
                     assert len(token_result.logprobs) == 1
 
-                    logprobs.append(TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]}))
+                    logprobs.append(
+                        TokenLogProbs(
+                            logprobs_by_token={
+                                token_result.text: token_result.logprobs[0]
+                            }
+                        )
+                    )
 
                 tokens.append(token_result.token)
 
@@ -572,7 +609,13 @@ class MetaReferenceInferenceImpl(
                 if request.logprobs:
                     assert len(token_result.logprobs) == 1
 
-                    logprobs.append(TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]}))
+                    logprobs.append(
+                        TokenLogProbs(
+                            logprobs_by_token={
+                                token_result.text: token_result.logprobs[0]
+                            }
+                        )
+                    )
                 yield ChatCompletionResponseStreamChunk(
                     event=ChatCompletionResponseEvent(
                         event_type=ChatCompletionResponseEventType.progress,
@@ -585,7 +628,9 @@ class MetaReferenceInferenceImpl(
             if stop_reason is None:
                 stop_reason = StopReason.out_of_tokens
 
-            message = self.generator.formatter.decode_assistant_message(tokens, stop_reason)
+            message = self.generator.formatter.decode_assistant_message(
+                tokens, stop_reason
+            )
 
             parsed_tool_calls = len(message.tool_calls) > 0
             if ipython and not parsed_tool_calls:
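The final hunks show the end-of-stream bookkeeping: if no stop token was seen, the stop reason defaults to out_of_tokens, and the accumulated tokens are decoded into an assistant message through the formatter. Below is a small illustrative sketch of that fallback, using a hypothetical decode step in place of the real decode_assistant_message.

from enum import Enum
from typing import List, Optional


class StopReason(Enum):
    # Mirrors the two stop reasons that appear in the diff.
    end_of_turn = "end_of_turn"
    out_of_tokens = "out_of_tokens"


EOT_ID = 0  # hypothetical end-of-turn token id


def finalize(tokens: List[int], stop_reason: Optional[StopReason]) -> str:
    # If generation ended without an end-of-turn token, fall back to
    # out_of_tokens -- the same default as in the diff.
    if stop_reason is None:
        stop_reason = StopReason.out_of_tokens
    # Hypothetical decode step standing in for
    # generator.formatter.decode_assistant_message(tokens, stop_reason).
    body = " ".join(f"<{t}>" for t in tokens if t != EOT_ID)
    return f"{body} [stop={stop_reason.value}]"


print(finalize([5, 9, 0], stop_reason=StopReason.end_of_turn))
print(finalize([5, 9, 12], stop_reason=None))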