temp commit

Botao Chen 2024-12-16 19:04:47 -08:00
parent b2dbb5e3fe
commit 30f6eb282f
7 changed files with 16 additions and 15 deletions


@@ -74,7 +74,6 @@ class InferenceRouter(Inference):
         self,
         routing_table: RoutingTable,
     ) -> None:
-        print("InferenceRouter init")
         self.routing_table = routing_table
 
     async def initialize(self) -> None:
@@ -91,7 +90,6 @@ class InferenceRouter(Inference):
         metadata: Optional[Dict[str, Any]] = None,
         model_type: Optional[ModelType] = None,
     ) -> None:
-        print("inference router")
         await self.routing_table.register_model(
             model_id, provider_model_id, provider_id, metadata, model_type
         )


@@ -15,9 +15,6 @@ async def get_provider_impl(
 ):
     from .inference import MetaReferenceInferenceImpl
 
-    print("get_provider_impl")
     impl = MetaReferenceInferenceImpl(config)
-    print("after MetaReferenceInferenceImpl")
     return impl


@@ -94,7 +94,6 @@ class MetaReferenceInferenceImpl(
             ],
         )
         model = await self.model_registry_helper.register_model(model)
-        print("model type", type(model))
 
         if model.model_type == ModelType.embedding:
             self._load_sentence_transformer_model(model.provider_resource_id)
@@ -304,7 +303,6 @@ class MetaReferenceInferenceImpl(
 
         if self.config.create_distributed_process_group:
             async with SEMAPHORE:
-                print("after SEMAPHORE")
                 return impl()
         else:
             return impl()


@@ -58,7 +58,6 @@ class LlamaModelParallelGenerator:
         config: MetaReferenceInferenceConfig,
         model_id: str,
     ):
-        print("LlamaModelParallelGenerator init")
         self.config = config
         self.model_id = model_id
         self.model = resolve_model(model_id)
@@ -76,7 +75,6 @@ class LlamaModelParallelGenerator:
         self.__exit__(None, None, None)
 
     def __enter__(self):
-        print("enter LlamaModelParallelGenerator")
        if self.config.model_parallel_size:
             model_parallel_size = self.config.model_parallel_size
         else:


@@ -69,6 +69,7 @@ def pytest_generate_tests(metafunc):
         else:
             params = MODEL_PARAMS
+            # print("params", params)
 
         metafunc.parametrize(
             "inference_model",
             params,
@@ -82,5 +83,7 @@ def pytest_generate_tests(metafunc):
             "inference": INFERENCE_FIXTURES,
         },
     ):
+        # print("I reach here")
        fixtures = [stack.values[0]["inference"] for stack in filtered_stacks]
+        print("fixtures", fixtures)
         metafunc.parametrize("inference_stack", fixtures, indirect=True)


@@ -3,10 +3,17 @@ image_name: experimental-post-training
 docker_image: null
 conda_env: experimental-post-training
 apis:
+- inference
 - telemetry
 - datasetio
 - post_training
 providers:
+  inference:
+  - provider_id: meta-reference-inference
+    provider_type: inline::meta-reference
+    config:
+      max_seq_len: 4096
+      checkpoint_dir: null
   datasetio:
   - provider_id: huggingface-0
     provider_type: remote::huggingface
@@ -24,11 +31,11 @@ metadata_store:
   namespace: null
   type: sqlite
   db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/registry.db
-models:
-- metadata: {}
-  model_id: ${env.POST_TRAINING_MODEL}
-  provider_id: meta-reference-inference
-  provider_model_id: null
+models: []
+# - metadata: {}
+#   model_id: ${env.POST_TRAINING_MODEL}
+#   provider_id: meta-reference-inference
+#   provider_model_id: null
 shields: []
 memory_banks: []
 datasets:
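
Note: with models set to the empty list, this run config no longer declares the post-training model statically; the old entry is kept only as a comment. A minimal sketch of registering the model at runtime instead, assuming the llama-stack-client Python SDK and a stack server on localhost:5000 (both assumptions, not part of this commit; the model id below is hypothetical):

    from llama_stack_client import LlamaStackClient

    client = LlamaStackClient(base_url="http://localhost:5000")

    # Register the model as a resource at runtime, mirroring the commented-out
    # YAML entry above. On the server side this request goes through
    # InferenceRouter.register_model, shown in the first file of this diff.
    client.models.register(
        model_id="meta-llama/Llama-3.2-3B-Instruct",
        provider_id="meta-reference-inference",
    )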


@@ -16,7 +16,7 @@ providers:
   - provider_id: meta-reference-inference
     provider_type: inline::meta-reference-quantized
     config:
-      model: ${env.INFERENCE_MODEL}
+      model: ${env.INFERENCE_MODEL} # please make sure your inference model here is added as a resource
       max_seq_len: 4096
       checkpoint_dir: ${env.INFERENCE_CHECKPOINT_DIR:null}
     quantization:
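
Note: the comment added above asks that the model referenced by ${env.INFERENCE_MODEL} also be registered as a resource. Reusing the entry shape this commit comments out in the experimental-post-training config (a sketch, with this file's env var substituted in), the models section would look like:

    models:
    - metadata: {}
      model_id: ${env.INFERENCE_MODEL}
      provider_id: meta-reference-inference
      provider_model_id: null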