Rework InterleavedContentMedia datatype so URL downloading is in llama-stack

This commit is contained in:
Ashwin Bharambe 2024-12-15 13:23:30 -08:00
parent c2f7905fa4
commit a9a041a1de
10 changed files with 368 additions and 146 deletions

View file

@ -24,7 +24,8 @@ from fairscale.nn.model_parallel.initialize import (
model_parallel_is_initialized,
)
from llama_models.llama3.api.args import ModelArgs
from llama_models.llama3.api.chat_format import ChatFormat, ModelInput
from llama_models.llama3.api.chat_format import ChatFormat, LLMInput
from llama_models.llama3.api.datatypes import RawContent, RawMessage
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.llama3.reference_impl.model import Transformer
from llama_models.llama3.reference_impl.multimodal.model import (
@ -38,10 +39,6 @@ from llama_stack.apis.inference import * # noqa: F403
from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData
from llama_stack.distribution.utils.model_utils import model_local_dir
from llama_stack.providers.utils.inference.prompt_adapter import (
augment_content_with_response_format_prompt,
chat_completion_request_to_messages,
)
from .config import (
Fp8QuantizationConfig,
@ -53,6 +50,14 @@ from .config import (
log = logging.getLogger(__name__)
class ChatCompletionRequestWithRawContent(ChatCompletionRequest):
messages: List[RawMessage]
class CompletionRequestWithRawContent(CompletionRequest):
content: RawContent
def model_checkpoint_dir(model) -> str:
checkpoint_dir = Path(model_local_dir(model.descriptor()))
@ -206,7 +211,7 @@ class Llama:
@torch.inference_mode()
def generate(
self,
model_input: ModelInput,
model_input: LLMInput,
max_gen_len: int,
temperature: float = 0.6,
top_p: float = 0.9,
@ -343,7 +348,7 @@ class Llama:
def completion(
self,
request: CompletionRequest,
request: CompletionRequestWithRawContent,
) -> Generator:
sampling_params = request.sampling_params
max_gen_len = sampling_params.max_tokens
@ -354,10 +359,7 @@ class Llama:
):
max_gen_len = self.model.params.max_seq_len - 1
content = augment_content_with_response_format_prompt(
request.response_format, request.content
)
model_input = self.formatter.encode_content(content)
model_input = self.formatter.encode_content(request.content)
yield from self.generate(
model_input=model_input,
max_gen_len=max_gen_len,
@ -374,10 +376,8 @@ class Llama:
def chat_completion(
self,
request: ChatCompletionRequest,
request: ChatCompletionRequestWithRawContent,
) -> Generator:
messages = chat_completion_request_to_messages(request, self.llama_model)
sampling_params = request.sampling_params
max_gen_len = sampling_params.max_tokens
if (
@ -389,7 +389,7 @@ class Llama:
yield from self.generate(
model_input=self.formatter.encode_dialog_prompt(
messages,
request.messages,
request.tool_prompt_format,
),
max_gen_len=max_gen_len,