Botao Chen 2024-12-17 13:38:19 -08:00
parent 415b8f2dbd
commit 48482ff9c3
9 changed files with 18 additions and 57 deletions


@@ -7,7 +7,6 @@
from typing import Any, Dict, Optional
from llama_models.datatypes import * # noqa: F403
from llama_models.sku_list import resolve_model
from llama_stack.apis.inference import * # noqa: F401, F403
from pydantic import BaseModel, field_validator
@@ -16,9 +15,11 @@ from llama_stack.providers.utils.inference import supported_inference_models
 class MetaReferenceInferenceConfig(BaseModel):
-    model: Optional[str] = (
-        None  # this is a placeholder to indicate inference model id, not actually being used
-    )
+    # this is a placeholder to indicate inference model id
+    # the actual inference model id is determined by the model id in the request
+    # Note: you need to register the model before using it for inference
+    # models in the resource list in the run.yaml config will be registered automatically
+    model: Optional[str] = None
     torch_seed: Optional[int] = None
     max_seq_len: int = 4096
     max_batch_size: int = 1
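
To make the new placeholder semantics concrete, here is a minimal, self-contained sketch (not the provider's actual module; the field set is trimmed to what this hunk shows). After the change, the config validates with no model id at all, and the served model is determined by registration plus the request; the example model id below is illustrative only:

from typing import Optional

from pydantic import BaseModel


class MetaReferenceInferenceConfig(BaseModel):
    # Placeholder only; the actual inference model id is taken from the
    # request, after the model has been registered (e.g. via the resource
    # list in run.yaml).
    model: Optional[str] = None
    torch_seed: Optional[int] = None
    max_seq_len: int = 4096
    max_batch_size: int = 1


config = MetaReferenceInferenceConfig()  # now valid with no model id
assert config.model is None

config = MetaReferenceInferenceConfig(model="Llama3.1-8B-Instruct")
assert config.model == "Llama3.1-8B-Instruct"

Decoupling the config from a fixed model id lets one provider instance serve whichever registered model the incoming request names.
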
@@ -45,13 +46,6 @@ class MetaReferenceInferenceConfig(BaseModel):
         )
         return model
 
-    @property
-    def model_parallel_size(self) -> Optional[int]:
-        resolved = resolve_model(self.model)
-        if resolved is None:
-            return None
-        return resolved.pth_file_count
-
     @classmethod
     def sample_run_config(
         cls,
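
For context on what was dropped: the deleted `model_parallel_size` property mapped the configured model id to its checkpoint shard count. A hedged, standalone sketch of the same logic follows (the free-function form is hypothetical; `resolve_model` and `pth_file_count` come from the removed lines):

from typing import Optional

from llama_models.sku_list import resolve_model


def model_parallel_size(model_id: str) -> Optional[int]:
    # The model-parallel degree equals the number of .pth checkpoint shards
    # shipped for the resolved model; unknown model ids yield None.
    resolved = resolve_model(model_id)
    if resolved is None:
        return None
    return resolved.pth_file_count

Removing the property is consistent with the change above: with `model` now a pure placeholder, there is no longer a single config-level model to resolve a parallel size against.
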