several fixes

Author: Ashwin Bharambe
Date: 2025-04-07 10:31:20 -07:00
parent e2e2820c9a
commit 53a8086e37

60 changed files with 1006 additions and 1078 deletions

@@ -12,20 +12,19 @@ from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData
 from llama_stack.apis.inference import (
     Fp8QuantizationConfig,
+    GreedySamplingStrategy,
     Int4QuantizationConfig,
     JsonSchemaResponseFormat,
     ResponseFormat,
+    SamplingParams,
+    TopPSamplingStrategy,
 )
-from llama_stack.models.llama.datatypes import (
-    GreedySamplingStrategy,
-    Model,
-    SamplingParams,
-    TopPSamplingStrategy,
-)
+from llama_stack.models.llama.datatypes import QuantizationMode
 from llama_stack.models.llama.llama3.generation import Llama3
 from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer
 from llama_stack.models.llama.llama4.generation import Llama4
 from llama_stack.models.llama.llama4.tokenizer import Tokenizer as Llama4Tokenizer
+from llama_stack.models.llama.sku_types import Model
 from llama_stack.providers.utils.inference.prompt_adapter import (
     ChatCompletionRequestWithRawContent,
     CompletionRequestWithRawContent,
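
Note: the new `QuantizationMode` import replaces the bare string literals that the two hunks below delete. Its definition is not part of this diff; a minimal sketch of the plausible shape, assuming a str-valued enum whose values match the old literals (an assumption, flagged again in the code):

```python
from enum import Enum


class QuantizationMode(str, Enum):
    # Members inferred from the hunks below; the real class may define more
    # (e.g. a "none"/unquantized member).
    fp8_mixed = "fp8_mixed"
    int4_mixed = "int4_mixed"
```
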
@@ -136,9 +135,9 @@ class Llama4Generator:
         if isinstance(config, MetaReferenceQuantizedInferenceConfig):
             if isinstance(config.quantization, Fp8QuantizationConfig):
-                quantization_mode = "fp8_mixed"
+                quantization_mode = QuantizationMode.fp8_mixed
             elif isinstance(config.quantization, Int4QuantizationConfig):
-                quantization_mode = "int4_mixed"
+                quantization_mode = QuantizationMode.int4_mixed
             else:
                 raise ValueError(f"Unsupported quantization mode {config.quantization}")
         else:
@@ -225,9 +224,9 @@ class Llama3Generator:
         if isinstance(config, MetaReferenceQuantizedInferenceConfig):
             if isinstance(config.quantization, Fp8QuantizationConfig):
-                quantization_mode = "fp8_mixed"
+                quantization_mode = QuantizationMode.fp8_mixed
             elif isinstance(config.quantization, Int4QuantizationConfig):
-                quantization_mode = "int4_mixed"
+                quantization_mode = QuantizationMode.int4_mixed
             else:
                 raise ValueError(f"Unsupported quantization mode {config.quantization}")
         else:
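
Note: `Llama3Generator` gets the identical literal-to-enum swap as `Llama4Generator` above. If `QuantizationMode` is str-valued as sketched earlier, the change is behavior-preserving for any downstream code that still compares against the old strings, while misspelled modes now fail loudly instead of flowing through as an unrecognized string. A quick check, using the sketched enum rather than the real class:

```python
# Uses the QuantizationMode sketch from the earlier note.
mode = QuantizationMode.fp8_mixed

# A (str, Enum) member compares equal to the old literal...
assert mode == "fp8_mixed"
assert mode.value == "fp8_mixed"

# ...and a typo'd mode name is caught at lookup time.
try:
    QuantizationMode["fp8_mixd"]
except KeyError:
    print("unknown quantization mode rejected")
```
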
@@ -240,6 +239,9 @@ class Llama3Generator:
             world_size=llama_model.pth_file_count,
             quantization_mode=quantization_mode,
         )
+        self.tokenizer = self.inner_generator.tokenizer
+        self.args = self.inner_generator.args
+        self.formatter = self.inner_generator.formatter

     def completion(
         self,
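
Note: the three new assignments expose the inner generator's tokenizer, model args, and chat formatter directly on `Llama3Generator`, presumably mirroring what `Llama4Generator` already exposes, so callers no longer reach through `inner_generator`. A self-contained sketch of the pattern; every name except the three attributes is a stand-in:

```python
class _Inner:
    """Stand-in for the object built by the generator's model loader."""

    def __init__(self) -> None:
        self.tokenizer = object()  # real code: a Tokenizer instance
        self.args = object()       # real code: model hyperparameters
        self.formatter = object()  # real code: a ChatFormat instance


class Generator:
    def __init__(self) -> None:
        self.inner_generator = _Inner()
        # Plain attribute re-exports, as in the hunk above; adequate because
        # inner_generator is created once here and never replaced.
        self.tokenizer = self.inner_generator.tokenizer
        self.args = self.inner_generator.args
        self.formatter = self.inner_generator.formatter


g = Generator()
assert g.tokenizer is g.inner_generator.tokenizer  # same object, shorter path
```
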

@@ -31,23 +31,21 @@ from llama_stack.apis.inference import (
     LogProbConfig,
     Message,
     ResponseFormat,
+    SamplingParams,
+    StopReason,
     TokenLogProbs,
     ToolChoice,
     ToolConfig,
 )
+from llama_stack.apis.models import Model, ModelType
 from llama_stack.models.llama.datatypes import (
-    ModelFamily,
-    SamplingParams,
-    StopReason,
     ToolDefinition,
     ToolPromptFormat,
 )
-from llama_stack.apis.models import Model, ModelType
 from llama_stack.models.llama.llama3.chat_format import ChatFormat as Llama3ChatFormat
 from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer
 from llama_stack.models.llama.llama4.chat_format import ChatFormat as Llama4ChatFormat
 from llama_stack.models.llama.llama4.tokenizer import Tokenizer as Llama4Tokenizer
 from llama_stack.models.llama.sku_list import resolve_model
+from llama_stack.models.llama.sku_types import ModelFamily
 from llama_stack.providers.datatypes import ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.embedding_mixin import (
     SentenceTransformerEmbeddingMixin,
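
Note: `ModelFamily` moves from `datatypes` to the new `sku_types` module along with the other SKU-level types. Its role here is presumably what the surrounding imports suggest: branch on a resolved model's family to pick the Llama3 vs Llama4 chat format. A hedged sketch of that dispatch with stand-in definitions (only the names `ModelFamily`, `llama3`, and `llama4` follow the repo's conventions; everything else is illustrative):

```python
from enum import Enum


class ModelFamily(Enum):
    # Stand-in for llama_stack.models.llama.sku_types.ModelFamily;
    # the real enum covers more families.
    llama3 = "llama3"
    llama4 = "llama4"


def pick_chat_format(family: ModelFamily) -> str:
    """Illustrates why this file imports both Llama3ChatFormat and Llama4ChatFormat."""
    if family == ModelFamily.llama4:
        return "Llama4ChatFormat"
    return "Llama3ChatFormat"


print(pick_chat_format(ModelFamily.llama4))  # -> Llama4ChatFormat
```
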