fixes after rebase

Dinesh Yeduguru 2024-11-12 15:37:07 -08:00
parent 948f6ece6e
commit 919d421bcf
11 changed files with 72 additions and 70 deletions

View file

@@ -86,6 +86,7 @@ class Llama:
         and loads the pre-trained model and tokenizer.
         """
         model = resolve_model(config.model)
+        llama_model = model.core_model_id.value
         if not torch.distributed.is_initialized():
             torch.distributed.init_process_group("nccl")
@@ -186,13 +187,20 @@ class Llama:
         model.load_state_dict(state_dict, strict=False)
         print(f"Loaded in {time.time() - start_time:.2f} seconds")
-        return Llama(model, tokenizer, model_args)
+        return Llama(model, tokenizer, model_args, llama_model)
-    def __init__(self, model: Transformer, tokenizer: Tokenizer, args: ModelArgs):
+    def __init__(
+        self,
+        model: Transformer,
+        tokenizer: Tokenizer,
+        args: ModelArgs,
+        llama_model: str,
+    ):
         self.args = args
         self.model = model
         self.tokenizer = tokenizer
         self.formatter = ChatFormat(tokenizer)
+        self.llama_model = llama_model
     @torch.inference_mode()
     def generate(
@@ -369,7 +377,7 @@ class Llama:
         self,
         request: ChatCompletionRequest,
     ) -> Generator:
-        messages = chat_completion_request_to_messages(request)
+        messages = chat_completion_request_to_messages(request, self.llama_model)
        sampling_params = request.sampling_params
        max_gen_len = sampling_params.max_tokens
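
The pattern in this file: the core model id is resolved once when the generator is built, stored on the Llama instance, and reused for per-request prompt construction. A minimal, self-contained sketch of that flow follows; ToyGenerator and format_request are hypothetical stand-ins, not llama-stack APIs, and the descriptor string is only illustrative.

    # Hedged sketch of the change above; all names here are hypothetical stand-ins.
    def format_request(request: str, llama_model: str) -> str:
        # placeholder for model-aware prompt/template selection
        return f"[{llama_model}] {request}"

    class ToyGenerator:
        def __init__(self, llama_model: str):
            # mirrors Llama.__init__ now accepting and storing llama_model
            self.llama_model = llama_model

        def chat_completion(self, request: str) -> str:
            # mirrors chat_completion_request_to_messages(request, self.llama_model)
            return format_request(request, self.llama_model)

    gen = ToyGenerator("Llama3.1-8B-Instruct")  # illustrative descriptor string
    print(gen.chat_completion("Hello"))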

View file

@@ -39,7 +39,7 @@ class MetaReferenceInferenceImpl(Inference, ModelRegistryHelper, ModelsProtocolPrivate):
             [
                 build_model_alias(
                     model.descriptor(),
-                    model.core_model_id,
+                    model.core_model_id.value,
                 )
             ],
         )
@@ -56,12 +56,6 @@ class MetaReferenceInferenceImpl(Inference, ModelRegistryHelper, ModelsProtocolPrivate):
         else:
             self.generator = Llama.build(self.config)
-    async def register_model(self, model: Model) -> None:
-        if model.provider_resource_id != self.model.descriptor():
-            raise ValueError(
-                f"Model mismatch: {model.identifier} != {self.model.descriptor()}"
-            )
     async def shutdown(self) -> None:
         if self.config.create_distributed_process_group:
             self.generator.stop()

View file

@@ -26,15 +26,15 @@ from llama_stack.providers.utils.bedrock.client import create_bedrock_client
 model_aliases = [
     build_model_alias(
         "meta.llama3-1-8b-instruct-v1:0",
-        CoreModelId.llama3_1_8b_instruct,
+        CoreModelId.llama3_1_8b_instruct.value,
     ),
     build_model_alias(
         "meta.llama3-1-70b-instruct-v1:0",
-        CoreModelId.llama3_1_70b_instruct,
+        CoreModelId.llama3_1_70b_instruct.value,
     ),
     build_model_alias(
         "meta.llama3-1-405b-instruct-v1:0",
-        CoreModelId.llama3_1_405b_instruct,
+        CoreModelId.llama3_1_405b_instruct.value,
     ),
 ]

View file

@@ -36,11 +36,11 @@ from .config import DatabricksImplConfig
 model_aliases = [
     build_model_alias(
         "databricks-meta-llama-3-1-70b-instruct",
-        CoreModelId.llama3_1_70b_instruct,
+        CoreModelId.llama3_1_70b_instruct.value,
     ),
     build_model_alias(
         "databricks-meta-llama-3-1-405b-instruct",
-        CoreModelId.llama3_1_405b_instruct,
+        CoreModelId.llama3_1_405b_instruct.value,
     ),
 ]

View file

@@ -38,39 +38,39 @@ from .config import FireworksImplConfig
 model_aliases = [
     build_model_alias(
         "fireworks/llama-v3p1-8b-instruct",
-        CoreModelId.llama3_1_8b_instruct,
+        CoreModelId.llama3_1_8b_instruct.value,
     ),
     build_model_alias(
         "fireworks/llama-v3p1-70b-instruct",
-        CoreModelId.llama3_1_70b_instruct,
+        CoreModelId.llama3_1_70b_instruct.value,
     ),
     build_model_alias(
         "fireworks/llama-v3p1-405b-instruct",
-        CoreModelId.llama3_1_405b_instruct,
+        CoreModelId.llama3_1_405b_instruct.value,
     ),
     build_model_alias(
         "fireworks/llama-v3p2-1b-instruct",
-        CoreModelId.llama3_2_3b_instruct,
+        CoreModelId.llama3_2_3b_instruct.value,
     ),
     build_model_alias(
         "fireworks/llama-v3p2-3b-instruct",
-        CoreModelId.llama3_2_11b_vision_instruct,
+        CoreModelId.llama3_2_11b_vision_instruct.value,
     ),
     build_model_alias(
         "fireworks/llama-v3p2-11b-vision-instruct",
-        CoreModelId.llama3_2_11b_vision_instruct,
+        CoreModelId.llama3_2_11b_vision_instruct.value,
     ),
     build_model_alias(
         "fireworks/llama-v3p2-90b-vision-instruct",
-        CoreModelId.llama3_2_90b_vision_instruct,
+        CoreModelId.llama3_2_90b_vision_instruct.value,
     ),
     build_model_alias(
         "fireworks/llama-guard-3-8b",
-        CoreModelId.llama_guard_3_8b,
+        CoreModelId.llama_guard_3_8b.value,
     ),
     build_model_alias(
         "fireworks/llama-guard-3-11b-vision",
-        CoreModelId.llama_guard_3_11b_vision,
+        CoreModelId.llama_guard_3_11b_vision.value,
     ),
 ]

View file

@@ -42,31 +42,31 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
 model_aliases = [
     build_model_alias(
         "llama3.1:8b-instruct-fp16",
-        CoreModelId.llama3_1_8b_instruct,
+        CoreModelId.llama3_1_8b_instruct.value,
     ),
     build_model_alias(
         "llama3.1:70b-instruct-fp16",
-        CoreModelId.llama3_1_70b_instruct,
+        CoreModelId.llama3_1_70b_instruct.value,
     ),
     build_model_alias(
         "llama3.2:1b-instruct-fp16",
-        CoreModelId.llama3_2_1b_instruct,
+        CoreModelId.llama3_2_1b_instruct.value,
     ),
     build_model_alias(
         "llama3.2:3b-instruct-fp16",
-        CoreModelId.llama3_2_3b_instruct,
+        CoreModelId.llama3_2_3b_instruct.value,
     ),
     build_model_alias(
         "llama-guard3:8b",
-        CoreModelId.llama_guard_3_8b,
+        CoreModelId.llama_guard_3_8b.value,
     ),
     build_model_alias(
         "llama-guard3:1b",
-        CoreModelId.llama_guard_3_1b,
+        CoreModelId.llama_guard_3_1b.value,
     ),
     build_model_alias(
         "x/llama3.2-vision:11b-instruct-fp16",
-        CoreModelId.llama3_2_11b_vision_instruct,
+        CoreModelId.llama3_2_11b_vision_instruct.value,
     ),
 ]
@@ -164,6 +164,7 @@ class OllamaInferenceAdapter(Inference, ModelRegistryHelper, ModelsProtocolPrivate):
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
         model = await self.model_store.get_model(model_id)
+        print(f"model={model}")
         request = ChatCompletionRequest(
             model=model.provider_resource_id,
             messages=messages,
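
The adapter hunks above show the id translation done at request time: the caller passes a registered model_id, and the model store resolves it to the provider_resource_id the backend (here, Ollama) actually accepts. Below is a hedged, self-contained illustration of that lookup; the dictionary and function are hypothetical stand-ins for the model store, and the alias pair is only illustrative.

    # Hypothetical stand-in for the model_store lookup used above.
    _PROVIDER_RESOURCE_IDS = {
        "Llama3.1-8B-Instruct": "llama3.1:8b-instruct-fp16",  # illustrative pair
    }

    def resolve_provider_resource_id(model_id: str) -> str:
        # mirrors: model = await self.model_store.get_model(model_id)
        #          ... then request.model = model.provider_resource_id
        return _PROVIDER_RESOURCE_IDS[model_id]

    assert resolve_provider_resource_id("Llama3.1-8B-Instruct") == "llama3.1:8b-instruct-fp16"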

View file

@@ -41,35 +41,35 @@ from .config import TogetherImplConfig
 model_aliases = [
     build_model_alias(
         "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
-        CoreModelId.llama3_1_8b_instruct,
+        CoreModelId.llama3_1_8b_instruct.value,
     ),
     build_model_alias(
         "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
-        CoreModelId.llama3_1_70b_instruct,
+        CoreModelId.llama3_1_70b_instruct.value,
     ),
     build_model_alias(
         "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
-        CoreModelId.llama3_1_405b_instruct,
+        CoreModelId.llama3_1_405b_instruct.value,
     ),
     build_model_alias(
         "meta-llama/Llama-3.2-3B-Instruct-Turbo",
-        CoreModelId.llama3_2_3b_instruct,
+        CoreModelId.llama3_2_3b_instruct.value,
     ),
     build_model_alias(
         "meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo",
-        CoreModelId.llama3_2_11b_vision_instruct,
+        CoreModelId.llama3_2_11b_vision_instruct.value,
     ),
     build_model_alias(
         "meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo",
-        CoreModelId.llama3_2_90b_vision_instruct,
+        CoreModelId.llama3_2_90b_vision_instruct.value,
     ),
     build_model_alias(
         "meta-llama/Meta-Llama-Guard-3-8B",
-        CoreModelId.llama_guard_3_8b,
+        CoreModelId.llama_guard_3_8b.value,
     ),
     build_model_alias(
         "meta-llama/Llama-Guard-3-11B-Vision-Turbo",
-        CoreModelId.llama_guard_3_11b_vision,
+        CoreModelId.llama_guard_3_11b_vision.value,
     ),
 ]

View file

@@ -38,7 +38,7 @@ def build_model_aliases():
     return [
         build_model_alias(
             model.huggingface_repo,
-            model.core_model_id,
+            model.descriptor(),
         )
         for model in all_registered_models()
         if model.huggingface_repo
@@ -85,6 +85,7 @@ class VLLMInferenceAdapter(Inference, ModelRegistryHelper, ModelsProtocolPrivate):
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
         model = await self.model_store.get_model(model_id)
+        print(f"model={model}")
         request = ChatCompletionRequest(
             model=model.provider_resource_id,
             messages=messages,

View file

@@ -179,7 +179,7 @@ INFERENCE_FIXTURES = [
 @pytest_asyncio.fixture(scope="session")
-async def inference_stack(request, inference_model, model_id):
+async def inference_stack(request, inference_model):
     fixture_name = request.param
     inference_fixture = request.getfixturevalue(f"inference_{fixture_name}")
     impls = await resolve_impls_for_test_v2(
@@ -188,7 +188,7 @@ async def inference_stack(request, inference_model, model_id):
         inference_fixture.provider_data,
         models=[
             ModelInput(
-                model_id=model_id,
+                model_id=inference_model,
             )
         ],
     )

View file

@@ -64,7 +64,7 @@ def sample_tool_definition():
 class TestInference:
     @pytest.mark.asyncio
-    async def test_model_list(self, inference_model, inference_stack, model_id):
+    async def test_model_list(self, inference_model, inference_stack):
         _, models_impl = inference_stack
         response = await models_impl.list_models()
         assert isinstance(response, list)
@@ -73,16 +73,17 @@ class TestInference:
         model_def = None
         for model in response:
-            if model.identifier == model_id:
+            if model.identifier == inference_model:
                 model_def = model
                 break
         assert model_def is not None
     @pytest.mark.asyncio
-    async def test_completion(self, inference_model, inference_stack, model_id):
+    async def test_completion(self, inference_model, inference_stack):
         inference_impl, _ = inference_stack
-        provider = inference_impl.routing_table.get_provider_impl(model_id)
+        provider = inference_impl.routing_table.get_provider_impl(inference_model)
         if provider.__provider_spec__.provider_type not in (
             "meta-reference",
             "remote::ollama",
@@ -95,7 +96,7 @@ class TestInference:
         response = await inference_impl.completion(
             content="Micheael Jordan is born in ",
             stream=False,
-            model_id=model_id,
+            model_id=inference_model,
             sampling_params=SamplingParams(
                 max_tokens=50,
             ),
@@ -109,7 +110,7 @@ class TestInference:
            async for r in await inference_impl.completion(
                content="Roses are red,",
                stream=True,
-                model_id=model_id,
+                model_id=inference_model,
                sampling_params=SamplingParams(
                    max_tokens=50,
                ),
@@ -124,11 +125,11 @@ class TestInference:
     @pytest.mark.asyncio
     @pytest.mark.skip("This test is not quite robust")
     async def test_completions_structured_output(
-        self, inference_model, inference_stack, model_id
+        self, inference_model, inference_stack
     ):
         inference_impl, _ = inference_stack
-        provider = inference_impl.routing_table.get_provider_impl(model_id)
+        provider = inference_impl.routing_table.get_provider_impl(inference_model)
         if provider.__provider_spec__.provider_type not in (
             "meta-reference",
             "remote::tgi",
@@ -148,7 +149,7 @@ class TestInference:
         response = await inference_impl.completion(
             content=user_input,
             stream=False,
-            model_id=model_id,
+            model=inference_model,
             sampling_params=SamplingParams(
                 max_tokens=50,
             ),
@@ -166,11 +167,11 @@ class TestInference:
     @pytest.mark.asyncio
     async def test_chat_completion_non_streaming(
-        self, inference_model, inference_stack, common_params, sample_messages, model_id
+        self, inference_model, inference_stack, common_params, sample_messages
     ):
         inference_impl, _ = inference_stack
         response = await inference_impl.chat_completion(
-            model_id=model_id,
+            model_id=inference_model,
             messages=sample_messages,
             stream=False,
             **common_params,
@@ -183,11 +184,11 @@ class TestInference:
     @pytest.mark.asyncio
     async def test_structured_output(
-        self, inference_model, inference_stack, common_params, model_id
+        self, inference_model, inference_stack, common_params
     ):
         inference_impl, _ = inference_stack
-        provider = inference_impl.routing_table.get_provider_impl(model_id)
+        provider = inference_impl.routing_table.get_provider_impl(inference_model)
         if provider.__provider_spec__.provider_type not in (
             "meta-reference",
             "remote::fireworks",
@@ -203,7 +204,7 @@ class TestInference:
             num_seasons_in_nba: int
         response = await inference_impl.chat_completion(
-            model_id=model_id,
+            model_id=inference_model,
             messages=[
                 SystemMessage(content="You are a helpful assistant."),
                 UserMessage(content="Please give me information about Michael Jordan."),
@@ -226,7 +227,7 @@ class TestInference:
         assert answer.num_seasons_in_nba == 15
         response = await inference_impl.chat_completion(
-            model_id=model_id,
+            model_id=inference_model,
             messages=[
                 SystemMessage(content="You are a helpful assistant."),
                 UserMessage(content="Please give me information about Michael Jordan."),
@@ -243,13 +244,13 @@ class TestInference:
     @pytest.mark.asyncio
     async def test_chat_completion_streaming(
-        self, inference_model, inference_stack, common_params, sample_messages, model_id
+        self, inference_model, inference_stack, common_params, sample_messages
     ):
         inference_impl, _ = inference_stack
         response = [
             r
             async for r in await inference_impl.chat_completion(
-                model_id=model_id,
+                model_id=inference_model,
                 messages=sample_messages,
                 stream=True,
                 **common_params,
@@ -276,7 +277,6 @@ class TestInference:
         common_params,
         sample_messages,
         sample_tool_definition,
-        model_id,
     ):
         inference_impl, _ = inference_stack
         messages = sample_messages + [
@@ -286,7 +286,7 @@ class TestInference:
         ]
         response = await inference_impl.chat_completion(
-            model_id=model_id,
+            model_id=inference_model,
             messages=messages,
             tools=[sample_tool_definition],
             stream=False,
@@ -316,7 +316,6 @@ class TestInference:
         common_params,
         sample_messages,
         sample_tool_definition,
-        model_id,
     ):
         inference_impl, _ = inference_stack
         messages = sample_messages + [
@@ -328,7 +327,7 @@ class TestInference:
         response = [
             r
             async for r in await inference_impl.chat_completion(
-                model_id=model_id,
+                model_id=inference_model,
                 messages=messages,
                 tools=[sample_tool_definition],
                 stream=True,

View file

@@ -7,7 +7,6 @@
 from collections import namedtuple
 from typing import List, Optional
-from llama_models.datatypes import CoreModelId
 from llama_models.sku_list import all_registered_models
 from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
@@ -15,22 +14,22 @@ from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
 ModelAlias = namedtuple("ModelAlias", ["provider_model_id", "aliases", "llama_model"])
-def get_huggingface_repo(core_model_id: CoreModelId) -> Optional[str]:
+def get_huggingface_repo(model_descriptor: str) -> Optional[str]:
     """Get the Hugging Face repository for a given CoreModelId."""
     for model in all_registered_models():
-        if model.core_model_id == core_model_id:
+        if model.descriptor() == model_descriptor:
             return model.huggingface_repo
     return None
-def build_model_alias(provider_model_id: str, core_model_id: CoreModelId) -> ModelAlias:
+def build_model_alias(provider_model_id: str, model_descriptor: str) -> ModelAlias:
     return ModelAlias(
         provider_model_id=provider_model_id,
         aliases=[
-            core_model_id.value,
-            get_huggingface_repo(core_model_id),
+            model_descriptor,
+            get_huggingface_repo(model_descriptor),
         ],
-        llama_model=core_model_id.value,
+        llama_model=model_descriptor,
     )
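
Usage sketch for the updated helper: as the alias lists above show, providers now pass the descriptor string (CoreModelId.<member>.value) rather than the enum member itself. The import path for build_model_alias below is an assumption, and the example presumes llama_models is installed.

    from llama_models.datatypes import CoreModelId

    # assumed import path for the helper defined in this file
    from llama_stack.providers.utils.inference.model_registry import build_model_alias

    alias = build_model_alias(
        "fireworks/llama-v3p1-8b-instruct",
        CoreModelId.llama3_1_8b_instruct.value,  # string descriptor, not the enum member
    )
    print(alias.llama_model)  # the plain descriptor string
    print(alias.aliases)      # [descriptor, matching Hugging Face repo (or None)]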