From 919d421bcf3bcd73123caedc08c741566e1f0d14 Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Tue, 12 Nov 2024 15:37:07 -0800 Subject: [PATCH] fixes after rebase --- .../inference/meta_reference/generation.py | 14 +++++-- .../inference/meta_reference/inference.py | 8 +--- .../remote/inference/bedrock/bedrock.py | 6 +-- .../remote/inference/databricks/databricks.py | 4 +- .../remote/inference/fireworks/fireworks.py | 18 ++++---- .../remote/inference/ollama/ollama.py | 15 +++---- .../remote/inference/together/together.py | 16 ++++---- .../providers/remote/inference/vllm/vllm.py | 3 +- .../providers/tests/inference/fixtures.py | 4 +- .../tests/inference/test_text_inference.py | 41 +++++++++---------- .../utils/inference/model_registry.py | 13 +++--- 11 files changed, 72 insertions(+), 70 deletions(-) diff --git a/llama_stack/providers/inline/inference/meta_reference/generation.py b/llama_stack/providers/inline/inference/meta_reference/generation.py index 2f296c7c2..38c982473 100644 --- a/llama_stack/providers/inline/inference/meta_reference/generation.py +++ b/llama_stack/providers/inline/inference/meta_reference/generation.py @@ -86,6 +86,7 @@ class Llama: and loads the pre-trained model and tokenizer. """ model = resolve_model(config.model) + llama_model = model.core_model_id.value if not torch.distributed.is_initialized(): torch.distributed.init_process_group("nccl") @@ -186,13 +187,20 @@ class Llama: model.load_state_dict(state_dict, strict=False) print(f"Loaded in {time.time() - start_time:.2f} seconds") - return Llama(model, tokenizer, model_args) + return Llama(model, tokenizer, model_args, llama_model) - def __init__(self, model: Transformer, tokenizer: Tokenizer, args: ModelArgs): + def __init__( + self, + model: Transformer, + tokenizer: Tokenizer, + args: ModelArgs, + llama_model: str, + ): self.args = args self.model = model self.tokenizer = tokenizer self.formatter = ChatFormat(tokenizer) + self.llama_model = llama_model @torch.inference_mode() def generate( @@ -369,7 +377,7 @@ class Llama: self, request: ChatCompletionRequest, ) -> Generator: - messages = chat_completion_request_to_messages(request) + messages = chat_completion_request_to_messages(request, self.llama_model) sampling_params = request.sampling_params max_gen_len = sampling_params.max_tokens diff --git a/llama_stack/providers/inline/inference/meta_reference/inference.py b/llama_stack/providers/inline/inference/meta_reference/inference.py index 844cf6939..4f5c0c8c2 100644 --- a/llama_stack/providers/inline/inference/meta_reference/inference.py +++ b/llama_stack/providers/inline/inference/meta_reference/inference.py @@ -39,7 +39,7 @@ class MetaReferenceInferenceImpl(Inference, ModelRegistryHelper, ModelsProtocolP [ build_model_alias( model.descriptor(), - model.core_model_id, + model.core_model_id.value, ) ], ) @@ -56,12 +56,6 @@ class MetaReferenceInferenceImpl(Inference, ModelRegistryHelper, ModelsProtocolP else: self.generator = Llama.build(self.config) - async def register_model(self, model: Model) -> None: - if model.provider_resource_id != self.model.descriptor(): - raise ValueError( - f"Model mismatch: {model.identifier} != {self.model.descriptor()}" - ) - async def shutdown(self) -> None: if self.config.create_distributed_process_group: self.generator.stop() diff --git a/llama_stack/providers/remote/inference/bedrock/bedrock.py b/llama_stack/providers/remote/inference/bedrock/bedrock.py index 8762a6c95..f575d9dc3 100644 --- a/llama_stack/providers/remote/inference/bedrock/bedrock.py +++ 
b/llama_stack/providers/remote/inference/bedrock/bedrock.py @@ -26,15 +26,15 @@ from llama_stack.providers.utils.bedrock.client import create_bedrock_client model_aliases = [ build_model_alias( "meta.llama3-1-8b-instruct-v1:0", - CoreModelId.llama3_1_8b_instruct, + CoreModelId.llama3_1_8b_instruct.value, ), build_model_alias( "meta.llama3-1-70b-instruct-v1:0", - CoreModelId.llama3_1_70b_instruct, + CoreModelId.llama3_1_70b_instruct.value, ), build_model_alias( "meta.llama3-1-405b-instruct-v1:0", - CoreModelId.llama3_1_405b_instruct, + CoreModelId.llama3_1_405b_instruct.value, ), ] diff --git a/llama_stack/providers/remote/inference/databricks/databricks.py b/llama_stack/providers/remote/inference/databricks/databricks.py index 1337b6c09..0ebb625bc 100644 --- a/llama_stack/providers/remote/inference/databricks/databricks.py +++ b/llama_stack/providers/remote/inference/databricks/databricks.py @@ -36,11 +36,11 @@ from .config import DatabricksImplConfig model_aliases = [ build_model_alias( "databricks-meta-llama-3-1-70b-instruct", - CoreModelId.llama3_1_70b_instruct, + CoreModelId.llama3_1_70b_instruct.value, ), build_model_alias( "databricks-meta-llama-3-1-405b-instruct", - CoreModelId.llama3_1_405b_instruct, + CoreModelId.llama3_1_405b_instruct.value, ), ] diff --git a/llama_stack/providers/remote/inference/fireworks/fireworks.py b/llama_stack/providers/remote/inference/fireworks/fireworks.py index e0d42c721..42075eff7 100644 --- a/llama_stack/providers/remote/inference/fireworks/fireworks.py +++ b/llama_stack/providers/remote/inference/fireworks/fireworks.py @@ -38,39 +38,39 @@ from .config import FireworksImplConfig model_aliases = [ build_model_alias( "fireworks/llama-v3p1-8b-instruct", - CoreModelId.llama3_1_8b_instruct, + CoreModelId.llama3_1_8b_instruct.value, ), build_model_alias( "fireworks/llama-v3p1-70b-instruct", - CoreModelId.llama3_1_70b_instruct, + CoreModelId.llama3_1_70b_instruct.value, ), build_model_alias( "fireworks/llama-v3p1-405b-instruct", - CoreModelId.llama3_1_405b_instruct, + CoreModelId.llama3_1_405b_instruct.value, ), build_model_alias( "fireworks/llama-v3p2-1b-instruct", - CoreModelId.llama3_2_3b_instruct, + CoreModelId.llama3_2_3b_instruct.value, ), build_model_alias( "fireworks/llama-v3p2-3b-instruct", - CoreModelId.llama3_2_11b_vision_instruct, + CoreModelId.llama3_2_11b_vision_instruct.value, ), build_model_alias( "fireworks/llama-v3p2-11b-vision-instruct", - CoreModelId.llama3_2_11b_vision_instruct, + CoreModelId.llama3_2_11b_vision_instruct.value, ), build_model_alias( "fireworks/llama-v3p2-90b-vision-instruct", - CoreModelId.llama3_2_90b_vision_instruct, + CoreModelId.llama3_2_90b_vision_instruct.value, ), build_model_alias( "fireworks/llama-guard-3-8b", - CoreModelId.llama_guard_3_8b, + CoreModelId.llama_guard_3_8b.value, ), build_model_alias( "fireworks/llama-guard-3-11b-vision", - CoreModelId.llama_guard_3_11b_vision, + CoreModelId.llama_guard_3_11b_vision.value, ), ] diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py index 34af95b50..99f74572e 100644 --- a/llama_stack/providers/remote/inference/ollama/ollama.py +++ b/llama_stack/providers/remote/inference/ollama/ollama.py @@ -42,31 +42,31 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( model_aliases = [ build_model_alias( "llama3.1:8b-instruct-fp16", - CoreModelId.llama3_1_8b_instruct, + CoreModelId.llama3_1_8b_instruct.value, ), build_model_alias( "llama3.1:70b-instruct-fp16", - 
CoreModelId.llama3_1_70b_instruct, + CoreModelId.llama3_1_70b_instruct.value, ), build_model_alias( "llama3.2:1b-instruct-fp16", - CoreModelId.llama3_2_1b_instruct, + CoreModelId.llama3_2_1b_instruct.value, ), build_model_alias( "llama3.2:3b-instruct-fp16", - CoreModelId.llama3_2_3b_instruct, + CoreModelId.llama3_2_3b_instruct.value, ), build_model_alias( "llama-guard3:8b", - CoreModelId.llama_guard_3_8b, + CoreModelId.llama_guard_3_8b.value, ), build_model_alias( "llama-guard3:1b", - CoreModelId.llama_guard_3_1b, + CoreModelId.llama_guard_3_1b.value, ), build_model_alias( "x/llama3.2-vision:11b-instruct-fp16", - CoreModelId.llama3_2_11b_vision_instruct, + CoreModelId.llama3_2_11b_vision_instruct.value, ), ] @@ -164,6 +164,7 @@ class OllamaInferenceAdapter(Inference, ModelRegistryHelper, ModelsProtocolPriva logprobs: Optional[LogProbConfig] = None, ) -> AsyncGenerator: model = await self.model_store.get_model(model_id) + print(f"model={model}") request = ChatCompletionRequest( model=model.provider_resource_id, messages=messages, diff --git a/llama_stack/providers/remote/inference/together/together.py b/llama_stack/providers/remote/inference/together/together.py index 644302a0f..aae34bb87 100644 --- a/llama_stack/providers/remote/inference/together/together.py +++ b/llama_stack/providers/remote/inference/together/together.py @@ -41,35 +41,35 @@ from .config import TogetherImplConfig model_aliases = [ build_model_alias( "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo", - CoreModelId.llama3_1_8b_instruct, + CoreModelId.llama3_1_8b_instruct.value, ), build_model_alias( "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo", - CoreModelId.llama3_1_70b_instruct, + CoreModelId.llama3_1_70b_instruct.value, ), build_model_alias( "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo", - CoreModelId.llama3_1_405b_instruct, + CoreModelId.llama3_1_405b_instruct.value, ), build_model_alias( "meta-llama/Llama-3.2-3B-Instruct-Turbo", - CoreModelId.llama3_2_3b_instruct, + CoreModelId.llama3_2_3b_instruct.value, ), build_model_alias( "meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo", - CoreModelId.llama3_2_11b_vision_instruct, + CoreModelId.llama3_2_11b_vision_instruct.value, ), build_model_alias( "meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo", - CoreModelId.llama3_2_90b_vision_instruct, + CoreModelId.llama3_2_90b_vision_instruct.value, ), build_model_alias( "meta-llama/Meta-Llama-Guard-3-8B", - CoreModelId.llama_guard_3_8b, + CoreModelId.llama_guard_3_8b.value, ), build_model_alias( "meta-llama/Llama-Guard-3-11B-Vision-Turbo", - CoreModelId.llama_guard_3_11b_vision, + CoreModelId.llama_guard_3_11b_vision.value, ), ] diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py index 9bf25c5ad..2d03a9ef8 100644 --- a/llama_stack/providers/remote/inference/vllm/vllm.py +++ b/llama_stack/providers/remote/inference/vllm/vllm.py @@ -38,7 +38,7 @@ def build_model_aliases(): return [ build_model_alias( model.huggingface_repo, - model.core_model_id, + model.descriptor(), ) for model in all_registered_models() if model.huggingface_repo @@ -85,6 +85,7 @@ class VLLMInferenceAdapter(Inference, ModelRegistryHelper, ModelsProtocolPrivate logprobs: Optional[LogProbConfig] = None, ) -> AsyncGenerator: model = await self.model_store.get_model(model_id) + print(f"model={model}") request = ChatCompletionRequest( model=model.provider_resource_id, messages=messages, diff --git a/llama_stack/providers/tests/inference/fixtures.py b/llama_stack/providers/tests/inference/fixtures.py index 
59bd492b9..f6f2a30e8 100644 --- a/llama_stack/providers/tests/inference/fixtures.py +++ b/llama_stack/providers/tests/inference/fixtures.py @@ -179,7 +179,7 @@ INFERENCE_FIXTURES = [ @pytest_asyncio.fixture(scope="session") -async def inference_stack(request, inference_model, model_id): +async def inference_stack(request, inference_model): fixture_name = request.param inference_fixture = request.getfixturevalue(f"inference_{fixture_name}") impls = await resolve_impls_for_test_v2( @@ -188,7 +188,7 @@ async def inference_stack(request, inference_model, model_id): inference_fixture.provider_data, models=[ ModelInput( - model_id=model_id, + model_id=inference_model, ) ], ) diff --git a/llama_stack/providers/tests/inference/test_text_inference.py b/llama_stack/providers/tests/inference/test_text_inference.py index 9850b328e..70047a61f 100644 --- a/llama_stack/providers/tests/inference/test_text_inference.py +++ b/llama_stack/providers/tests/inference/test_text_inference.py @@ -64,7 +64,7 @@ def sample_tool_definition(): class TestInference: @pytest.mark.asyncio - async def test_model_list(self, inference_model, inference_stack, model_id): + async def test_model_list(self, inference_model, inference_stack): _, models_impl = inference_stack response = await models_impl.list_models() assert isinstance(response, list) @@ -73,16 +73,17 @@ class TestInference: model_def = None for model in response: - if model.identifier == model_id: + if model.identifier == inference_model: model_def = model break assert model_def is not None @pytest.mark.asyncio - async def test_completion(self, inference_model, inference_stack, model_id): + async def test_completion(self, inference_model, inference_stack): inference_impl, _ = inference_stack - provider = inference_impl.routing_table.get_provider_impl(model_id) + + provider = inference_impl.routing_table.get_provider_impl(inference_model) if provider.__provider_spec__.provider_type not in ( "meta-reference", "remote::ollama", @@ -95,7 +96,7 @@ class TestInference: response = await inference_impl.completion( content="Micheael Jordan is born in ", stream=False, - model_id=model_id, + model_id=inference_model, sampling_params=SamplingParams( max_tokens=50, ), @@ -109,7 +110,7 @@ class TestInference: async for r in await inference_impl.completion( content="Roses are red,", stream=True, - model_id=model_id, + model_id=inference_model, sampling_params=SamplingParams( max_tokens=50, ), @@ -124,11 +125,11 @@ class TestInference: @pytest.mark.asyncio @pytest.mark.skip("This test is not quite robust") async def test_completions_structured_output( - self, inference_model, inference_stack, model_id + self, inference_model, inference_stack ): inference_impl, _ = inference_stack - provider = inference_impl.routing_table.get_provider_impl(model_id) + provider = inference_impl.routing_table.get_provider_impl(inference_model) if provider.__provider_spec__.provider_type not in ( "meta-reference", "remote::tgi", @@ -148,7 +149,7 @@ class TestInference: response = await inference_impl.completion( content=user_input, stream=False, - model_id=model_id, + model=inference_model, sampling_params=SamplingParams( max_tokens=50, ), @@ -166,11 +167,11 @@ class TestInference: @pytest.mark.asyncio async def test_chat_completion_non_streaming( - self, inference_model, inference_stack, common_params, sample_messages, model_id + self, inference_model, inference_stack, common_params, sample_messages ): inference_impl, _ = inference_stack response = await inference_impl.chat_completion( - 
model_id=model_id, + model_id=inference_model, messages=sample_messages, stream=False, **common_params, @@ -183,11 +184,11 @@ class TestInference: @pytest.mark.asyncio async def test_structured_output( - self, inference_model, inference_stack, common_params, model_id + self, inference_model, inference_stack, common_params ): inference_impl, _ = inference_stack - provider = inference_impl.routing_table.get_provider_impl(model_id) + provider = inference_impl.routing_table.get_provider_impl(inference_model) if provider.__provider_spec__.provider_type not in ( "meta-reference", "remote::fireworks", @@ -203,7 +204,7 @@ class TestInference: num_seasons_in_nba: int response = await inference_impl.chat_completion( - model_id=model_id, + model_id=inference_model, messages=[ SystemMessage(content="You are a helpful assistant."), UserMessage(content="Please give me information about Michael Jordan."), @@ -226,7 +227,7 @@ class TestInference: assert answer.num_seasons_in_nba == 15 response = await inference_impl.chat_completion( - model_id=model_id, + model_id=inference_model, messages=[ SystemMessage(content="You are a helpful assistant."), UserMessage(content="Please give me information about Michael Jordan."), @@ -243,13 +244,13 @@ class TestInference: @pytest.mark.asyncio async def test_chat_completion_streaming( - self, inference_model, inference_stack, common_params, sample_messages, model_id + self, inference_model, inference_stack, common_params, sample_messages ): inference_impl, _ = inference_stack response = [ r async for r in await inference_impl.chat_completion( - model_id=model_id, + model_id=inference_model, messages=sample_messages, stream=True, **common_params, @@ -276,7 +277,6 @@ class TestInference: common_params, sample_messages, sample_tool_definition, - model_id, ): inference_impl, _ = inference_stack messages = sample_messages + [ @@ -286,7 +286,7 @@ class TestInference: ] response = await inference_impl.chat_completion( - model_id=model_id, + model_id=inference_model, messages=messages, tools=[sample_tool_definition], stream=False, @@ -316,7 +316,6 @@ class TestInference: common_params, sample_messages, sample_tool_definition, - model_id, ): inference_impl, _ = inference_stack messages = sample_messages + [ @@ -328,7 +327,7 @@ class TestInference: response = [ r async for r in await inference_impl.chat_completion( - model_id=model_id, + model_id=inference_model, messages=messages, tools=[sample_tool_definition], stream=True, diff --git a/llama_stack/providers/utils/inference/model_registry.py b/llama_stack/providers/utils/inference/model_registry.py index 35d67a4cc..c44c641a2 100644 --- a/llama_stack/providers/utils/inference/model_registry.py +++ b/llama_stack/providers/utils/inference/model_registry.py @@ -7,7 +7,6 @@ from collections import namedtuple from typing import List, Optional -from llama_models.datatypes import CoreModelId from llama_models.sku_list import all_registered_models from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate @@ -15,22 +14,22 @@ from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate ModelAlias = namedtuple("ModelAlias", ["provider_model_id", "aliases", "llama_model"]) -def get_huggingface_repo(core_model_id: CoreModelId) -> Optional[str]: +def get_huggingface_repo(model_descriptor: str) -> Optional[str]: """Get the Hugging Face repository for a given CoreModelId.""" for model in all_registered_models(): - if model.core_model_id == core_model_id: + if model.descriptor() == model_descriptor: return 
model.huggingface_repo return None -def build_model_alias(provider_model_id: str, core_model_id: CoreModelId) -> ModelAlias: +def build_model_alias(provider_model_id: str, model_descriptor: str) -> ModelAlias: return ModelAlias( provider_model_id=provider_model_id, aliases=[ - core_model_id.value, - get_huggingface_repo(core_model_id), + model_descriptor, + get_huggingface_repo(model_descriptor), ], - llama_model=core_model_id.value, + llama_model=model_descriptor, )
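
A quick sketch of the alias helpers as they look after this patch, for anyone reviewing the new shape in isolation: build_model_alias now takes the llama model descriptor string (callers pass CoreModelId.<id>.value) rather than the CoreModelId enum member, and stores it directly as llama_model; get_huggingface_repo matches on model.descriptor(). The code below is only an illustrative, self-contained stand-in, not the real module: _REGISTERED replaces llama_models.sku_list.all_registered_models(), and the descriptor / HF-repo strings in it are assumed sample values, not taken from the patch.

# Minimal, self-contained sketch of the refactored alias helpers.
# Stand-ins only; the real code lives in
# llama_stack/providers/utils/inference/model_registry.py.
from collections import namedtuple
from typing import Optional

# Same namedtuple shape as in model_registry.py after this patch.
ModelAlias = namedtuple("ModelAlias", ["provider_model_id", "aliases", "llama_model"])

# Stand-in for llama_models.sku_list.all_registered_models(); the descriptor
# and huggingface_repo values below are assumed sample data.
_REGISTERED = [
    {
        "descriptor": "Llama3.1-8B-Instruct",
        "huggingface_repo": "meta-llama/Llama-3.1-8B-Instruct",
    },
]


def get_huggingface_repo(model_descriptor: str) -> Optional[str]:
    # Mirrors the patched helper, which now matches on model.descriptor()
    # instead of comparing CoreModelId enum members.
    for model in _REGISTERED:
        if model["descriptor"] == model_descriptor:
            return model["huggingface_repo"]
    return None


def build_model_alias(provider_model_id: str, model_descriptor: str) -> ModelAlias:
    # llama_model is now the plain descriptor string, so provider adapters
    # pass CoreModelId.<id>.value rather than the enum itself.
    return ModelAlias(
        provider_model_id=provider_model_id,
        aliases=[
            model_descriptor,
            get_huggingface_repo(model_descriptor),
        ],
        llama_model=model_descriptor,
    )


if __name__ == "__main__":
    # Shaped like the Bedrock provider's alias table after the change.
    alias = build_model_alias(
        "meta.llama3-1-8b-instruct-v1:0",
        "Llama3.1-8B-Instruct",  # i.e. what CoreModelId.llama3_1_8b_instruct.value would supply
    )
    print(alias.llama_model)  # Llama3.1-8B-Instruct
    print(alias.aliases)      # ['Llama3.1-8B-Instruct', 'meta-llama/Llama-3.1-8B-Instruct']

The point of the shape change is that provider adapters and the registry utils now agree on plain descriptor strings, which is why the CoreModelId import can be dropped from model_registry.py in the hunk above.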