temp commit

Botao Chen 2024-12-16 19:04:47 -08:00
parent b2dbb5e3fe
commit 30f6eb282f
7 changed files with 16 additions and 15 deletions

File 1 of 7

@@ -74,7 +74,6 @@ class InferenceRouter(Inference):
         self,
         routing_table: RoutingTable,
     ) -> None:
-        print("InferenceRouter init")
         self.routing_table = routing_table
 
     async def initialize(self) -> None:
@@ -91,7 +90,6 @@ class InferenceRouter(Inference):
         metadata: Optional[Dict[str, Any]] = None,
         model_type: Optional[ModelType] = None,
     ) -> None:
-        print("inference router")
         await self.routing_table.register_model(
             model_id, provider_model_id, provider_id, metadata, model_type
         )
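
For reference, the hunk above shows that InferenceRouter.register_model simply forwards to the routing table. A minimal sketch of calling it, assuming an already-built router instance; the model and provider ids below are illustrative placeholders, not values from this commit:

import asyncio

async def register_example_model(router) -> None:
    # Arguments mirror the signature visible in the diff above;
    # every value here is a placeholder for illustration.
    await router.register_model(
        model_id="example-model",
        provider_model_id=None,
        provider_id="meta-reference-inference",
        metadata={},
        model_type=None,
    )

# asyncio.run(register_example_model(router))  # given a concrete router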

File 2 of 7

@@ -15,9 +15,6 @@ async def get_provider_impl(
 ):
     from .inference import MetaReferenceInferenceImpl
 
-    print("get_provider_impl")
     impl = MetaReferenceInferenceImpl(config)
-    print("after MetaReferenceInferenceImpl")
-
     return impl
 

File 3 of 7

@@ -94,7 +94,6 @@ class MetaReferenceInferenceImpl(
             ],
         )
         model = await self.model_registry_helper.register_model(model)
-        print("model type", type(model))
         if model.model_type == ModelType.embedding:
             self._load_sentence_transformer_model(model.provider_resource_id)
 
@@ -304,7 +303,6 @@ class MetaReferenceInferenceImpl(
 
         if self.config.create_distributed_process_group:
            async with SEMAPHORE:
-                print("after SEMAPHORE")
                 return impl()
         else:
             return impl()

File 4 of 7

@@ -58,7 +58,6 @@ class LlamaModelParallelGenerator:
         config: MetaReferenceInferenceConfig,
         model_id: str,
     ):
-        print("LlamaModelParallelGenerator init")
         self.config = config
         self.model_id = model_id
         self.model = resolve_model(model_id)
@@ -76,7 +75,6 @@ class LlamaModelParallelGenerator:
         self.__exit__(None, None, None)
 
     def __enter__(self):
-        print("enter LlamaModelParallelGenerator")
         if self.config.model_parallel_size:
             model_parallel_size = self.config.model_parallel_size
         else:

File 5 of 7

@@ -69,6 +69,7 @@ def pytest_generate_tests(metafunc):
     else:
         params = MODEL_PARAMS
 
+    # print("params", params)
     metafunc.parametrize(
         "inference_model",
         params,
@@ -82,5 +83,7 @@ def pytest_generate_tests(metafunc):
             "inference": INFERENCE_FIXTURES,
         },
     ):
+        # print("I reach here")
         fixtures = [stack.values[0]["inference"] for stack in filtered_stacks]
+        print("fixtures", fixtures)
         metafunc.parametrize("inference_stack", fixtures, indirect=True)

File 6 of 7

@@ -3,10 +3,17 @@ image_name: experimental-post-training
 docker_image: null
 conda_env: experimental-post-training
 apis:
+- inference
 - telemetry
 - datasetio
 - post_training
 providers:
+  inference:
+  - provider_id: meta-reference-inference
+    provider_type: inline::meta-reference
+    config:
+      max_seq_len: 4096
+      checkpoint_dir: null
   datasetio:
   - provider_id: huggingface-0
     provider_type: remote::huggingface
@@ -24,11 +31,11 @@ metadata_store:
   namespace: null
   type: sqlite
   db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/registry.db
-models:
-- metadata: {}
-  model_id: ${env.POST_TRAINING_MODEL}
-  provider_id: meta-reference-inference
-  provider_model_id: null
+models: []
+# - metadata: {}
+#   model_id: ${env.POST_TRAINING_MODEL}
+#   provider_id: meta-reference-inference
+#   provider_model_id: null
 shields: []
 memory_banks: []
 datasets:

File 7 of 7

@@ -16,7 +16,7 @@ providers:
   - provider_id: meta-reference-inference
     provider_type: inline::meta-reference-quantized
     config:
-      model: ${env.INFERENCE_MODEL}
+      model: ${env.INFERENCE_MODEL} # make sure the inference model here is also registered as a model resource
       max_seq_len: 4096
       checkpoint_dir: ${env.INFERENCE_CHECKPOINT_DIR:null}
     quantization:
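
The new inline comment asks for the inference model to be registered as a model resource, a step this commit also stopped doing statically (the models: list in the previous file is now empty). A hedged sketch of registering the model at runtime instead, assuming the llama-stack client exposes models.register; verify the call and arguments against your installed client version:

from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:5000")  # URL is illustrative

# Assumed client-side call mirroring the server's register_model signature
# shown in the first file of this commit; the ids are placeholders.
client.models.register(
    model_id="meta-llama/Llama-3.1-8B-Instruct",
    provider_id="meta-reference-inference",
)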