Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-01 16:24:44 +00:00)
fixes after rebase

Commit 919d421bcf (parent 948f6ece6e)
11 changed files with 72 additions and 70 deletions
@@ -86,6 +86,7 @@ class Llama:
         and loads the pre-trained model and tokenizer.
         """
         model = resolve_model(config.model)
+        llama_model = model.core_model_id.value

         if not torch.distributed.is_initialized():
             torch.distributed.init_process_group("nccl")
@@ -186,13 +187,20 @@ class Llama:
         model.load_state_dict(state_dict, strict=False)

         print(f"Loaded in {time.time() - start_time:.2f} seconds")
-        return Llama(model, tokenizer, model_args)
+        return Llama(model, tokenizer, model_args, llama_model)

-    def __init__(self, model: Transformer, tokenizer: Tokenizer, args: ModelArgs):
+    def __init__(
+        self,
+        model: Transformer,
+        tokenizer: Tokenizer,
+        args: ModelArgs,
+        llama_model: str,
+    ):
         self.args = args
         self.model = model
         self.tokenizer = tokenizer
         self.formatter = ChatFormat(tokenizer)
+        self.llama_model = llama_model

     @torch.inference_mode()
     def generate(
@@ -369,7 +377,7 @@ class Llama:
         self,
         request: ChatCompletionRequest,
     ) -> Generator:
-        messages = chat_completion_request_to_messages(request)
+        messages = chat_completion_request_to_messages(request, self.llama_model)

         sampling_params = request.sampling_params
         max_gen_len = sampling_params.max_tokens
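A quick sketch of what the three hunks above accomplish (illustrative only; the descriptor value is an assumption, not taken from the diff): the resolved model descriptor is threaded through the Llama generator so request-to-prompt conversion can be model-aware.

    from llama_models.sku_list import resolve_model

    model = resolve_model("Llama3.1-8B-Instruct")  # any registered descriptor or alias
    llama_model = model.core_model_id.value        # plain string naming the model
    # Llama.build now forwards this string into the Llama instance, and
    # chat_completion hands it to chat_completion_request_to_messages so the
    # prompt format can be chosen per model family.
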
@@ -39,7 +39,7 @@ class MetaReferenceInferenceImpl(Inference, ModelRegistryHelper, ModelsProtocolPrivate
             [
                 build_model_alias(
                     model.descriptor(),
-                    model.core_model_id,
+                    model.core_model_id.value,
                 )
             ],
         )
@@ -56,12 +56,6 @@ class MetaReferenceInferenceImpl(Inference, ModelRegistryHelper, ModelsProtocolPrivate
         else:
             self.generator = Llama.build(self.config)

-    async def register_model(self, model: Model) -> None:
-        if model.provider_resource_id != self.model.descriptor():
-            raise ValueError(
-                f"Model mismatch: {model.identifier} != {self.model.descriptor()}"
-            )
-
     async def shutdown(self) -> None:
         if self.config.create_distributed_process_group:
             self.generator.stop()
@@ -26,15 +26,15 @@ from llama_stack.providers.utils.bedrock.client import create_bedrock_client
 model_aliases = [
     build_model_alias(
         "meta.llama3-1-8b-instruct-v1:0",
-        CoreModelId.llama3_1_8b_instruct,
+        CoreModelId.llama3_1_8b_instruct.value,
     ),
     build_model_alias(
         "meta.llama3-1-70b-instruct-v1:0",
-        CoreModelId.llama3_1_70b_instruct,
+        CoreModelId.llama3_1_70b_instruct.value,
     ),
     build_model_alias(
         "meta.llama3-1-405b-instruct-v1:0",
-        CoreModelId.llama3_1_405b_instruct,
+        CoreModelId.llama3_1_405b_instruct.value,
     ),
 ]

@@ -36,11 +36,11 @@ from .config import DatabricksImplConfig
 model_aliases = [
     build_model_alias(
         "databricks-meta-llama-3-1-70b-instruct",
-        CoreModelId.llama3_1_70b_instruct,
+        CoreModelId.llama3_1_70b_instruct.value,
     ),
     build_model_alias(
         "databricks-meta-llama-3-1-405b-instruct",
-        CoreModelId.llama3_1_405b_instruct,
+        CoreModelId.llama3_1_405b_instruct.value,
     ),
 ]

@@ -38,39 +38,39 @@ from .config import FireworksImplConfig
 model_aliases = [
     build_model_alias(
         "fireworks/llama-v3p1-8b-instruct",
-        CoreModelId.llama3_1_8b_instruct,
+        CoreModelId.llama3_1_8b_instruct.value,
     ),
     build_model_alias(
         "fireworks/llama-v3p1-70b-instruct",
-        CoreModelId.llama3_1_70b_instruct,
+        CoreModelId.llama3_1_70b_instruct.value,
     ),
     build_model_alias(
         "fireworks/llama-v3p1-405b-instruct",
-        CoreModelId.llama3_1_405b_instruct,
+        CoreModelId.llama3_1_405b_instruct.value,
     ),
     build_model_alias(
         "fireworks/llama-v3p2-1b-instruct",
-        CoreModelId.llama3_2_3b_instruct,
+        CoreModelId.llama3_2_3b_instruct.value,
     ),
     build_model_alias(
         "fireworks/llama-v3p2-3b-instruct",
-        CoreModelId.llama3_2_11b_vision_instruct,
+        CoreModelId.llama3_2_11b_vision_instruct.value,
     ),
     build_model_alias(
         "fireworks/llama-v3p2-11b-vision-instruct",
-        CoreModelId.llama3_2_11b_vision_instruct,
+        CoreModelId.llama3_2_11b_vision_instruct.value,
     ),
     build_model_alias(
         "fireworks/llama-v3p2-90b-vision-instruct",
-        CoreModelId.llama3_2_90b_vision_instruct,
+        CoreModelId.llama3_2_90b_vision_instruct.value,
     ),
     build_model_alias(
         "fireworks/llama-guard-3-8b",
-        CoreModelId.llama_guard_3_8b,
+        CoreModelId.llama_guard_3_8b.value,
     ),
     build_model_alias(
         "fireworks/llama-guard-3-11b-vision",
-        CoreModelId.llama_guard_3_11b_vision,
+        CoreModelId.llama_guard_3_11b_vision.value,
     ),
 ]

@@ -42,31 +42,31 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
 model_aliases = [
     build_model_alias(
         "llama3.1:8b-instruct-fp16",
-        CoreModelId.llama3_1_8b_instruct,
+        CoreModelId.llama3_1_8b_instruct.value,
     ),
     build_model_alias(
         "llama3.1:70b-instruct-fp16",
-        CoreModelId.llama3_1_70b_instruct,
+        CoreModelId.llama3_1_70b_instruct.value,
     ),
     build_model_alias(
         "llama3.2:1b-instruct-fp16",
-        CoreModelId.llama3_2_1b_instruct,
+        CoreModelId.llama3_2_1b_instruct.value,
     ),
     build_model_alias(
         "llama3.2:3b-instruct-fp16",
-        CoreModelId.llama3_2_3b_instruct,
+        CoreModelId.llama3_2_3b_instruct.value,
     ),
     build_model_alias(
         "llama-guard3:8b",
-        CoreModelId.llama_guard_3_8b,
+        CoreModelId.llama_guard_3_8b.value,
     ),
     build_model_alias(
         "llama-guard3:1b",
-        CoreModelId.llama_guard_3_1b,
+        CoreModelId.llama_guard_3_1b.value,
     ),
     build_model_alias(
         "x/llama3.2-vision:11b-instruct-fp16",
-        CoreModelId.llama3_2_11b_vision_instruct,
+        CoreModelId.llama3_2_11b_vision_instruct.value,
     ),
 ]

@@ -164,6 +164,7 @@ class OllamaInferenceAdapter(Inference, ModelRegistryHelper, ModelsProtocolPrivate
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
         model = await self.model_store.get_model(model_id)
+        print(f"model={model}")
         request = ChatCompletionRequest(
             model=model.provider_resource_id,
             messages=messages,
@@ -41,35 +41,35 @@ from .config import TogetherImplConfig
 model_aliases = [
     build_model_alias(
         "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
-        CoreModelId.llama3_1_8b_instruct,
+        CoreModelId.llama3_1_8b_instruct.value,
     ),
     build_model_alias(
         "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
-        CoreModelId.llama3_1_70b_instruct,
+        CoreModelId.llama3_1_70b_instruct.value,
     ),
     build_model_alias(
         "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
-        CoreModelId.llama3_1_405b_instruct,
+        CoreModelId.llama3_1_405b_instruct.value,
     ),
     build_model_alias(
         "meta-llama/Llama-3.2-3B-Instruct-Turbo",
-        CoreModelId.llama3_2_3b_instruct,
+        CoreModelId.llama3_2_3b_instruct.value,
     ),
     build_model_alias(
         "meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo",
-        CoreModelId.llama3_2_11b_vision_instruct,
+        CoreModelId.llama3_2_11b_vision_instruct.value,
     ),
     build_model_alias(
         "meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo",
-        CoreModelId.llama3_2_90b_vision_instruct,
+        CoreModelId.llama3_2_90b_vision_instruct.value,
     ),
     build_model_alias(
         "meta-llama/Meta-Llama-Guard-3-8B",
-        CoreModelId.llama_guard_3_8b,
+        CoreModelId.llama_guard_3_8b.value,
     ),
     build_model_alias(
         "meta-llama/Llama-Guard-3-11B-Vision-Turbo",
-        CoreModelId.llama_guard_3_11b_vision,
+        CoreModelId.llama_guard_3_11b_vision.value,
     ),
 ]

@@ -38,7 +38,7 @@ def build_model_aliases():
     return [
         build_model_alias(
             model.huggingface_repo,
-            model.core_model_id,
+            model.descriptor(),
         )
         for model in all_registered_models()
         if model.huggingface_repo
@@ -85,6 +85,7 @@ class VLLMInferenceAdapter(Inference, ModelRegistryHelper, ModelsProtocolPrivate
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
         model = await self.model_store.get_model(model_id)
+        print(f"model={model}")
         request = ChatCompletionRequest(
             model=model.provider_resource_id,
             messages=messages,
@@ -179,7 +179,7 @@ INFERENCE_FIXTURES = [


 @pytest_asyncio.fixture(scope="session")
-async def inference_stack(request, inference_model, model_id):
+async def inference_stack(request, inference_model):
     fixture_name = request.param
     inference_fixture = request.getfixturevalue(f"inference_{fixture_name}")
     impls = await resolve_impls_for_test_v2(
@@ -188,7 +188,7 @@ async def inference_stack(request, inference_model, model_id):
         inference_fixture.provider_data,
         models=[
             ModelInput(
-                model_id=model_id,
+                model_id=inference_model,
             )
         ],
     )

@@ -64,7 +64,7 @@ def sample_tool_definition():

 class TestInference:
     @pytest.mark.asyncio
-    async def test_model_list(self, inference_model, inference_stack, model_id):
+    async def test_model_list(self, inference_model, inference_stack):
         _, models_impl = inference_stack
         response = await models_impl.list_models()
         assert isinstance(response, list)
@@ -73,16 +73,17 @@ class TestInference:

         model_def = None
         for model in response:
-            if model.identifier == model_id:
+            if model.identifier == inference_model:
                 model_def = model
                 break

         assert model_def is not None

     @pytest.mark.asyncio
-    async def test_completion(self, inference_model, inference_stack, model_id):
+    async def test_completion(self, inference_model, inference_stack):
         inference_impl, _ = inference_stack
-        provider = inference_impl.routing_table.get_provider_impl(model_id)
+
+        provider = inference_impl.routing_table.get_provider_impl(inference_model)
         if provider.__provider_spec__.provider_type not in (
             "meta-reference",
             "remote::ollama",
@@ -95,7 +96,7 @@ class TestInference:
         response = await inference_impl.completion(
             content="Micheael Jordan is born in ",
             stream=False,
-            model_id=model_id,
+            model_id=inference_model,
             sampling_params=SamplingParams(
                 max_tokens=50,
             ),
@@ -109,7 +110,7 @@ class TestInference:
             async for r in await inference_impl.completion(
                 content="Roses are red,",
                 stream=True,
-                model_id=model_id,
+                model_id=inference_model,
                 sampling_params=SamplingParams(
                     max_tokens=50,
                 ),
@@ -124,11 +125,11 @@ class TestInference:
     @pytest.mark.asyncio
     @pytest.mark.skip("This test is not quite robust")
     async def test_completions_structured_output(
-        self, inference_model, inference_stack, model_id
+        self, inference_model, inference_stack
     ):
         inference_impl, _ = inference_stack

-        provider = inference_impl.routing_table.get_provider_impl(model_id)
+        provider = inference_impl.routing_table.get_provider_impl(inference_model)
         if provider.__provider_spec__.provider_type not in (
             "meta-reference",
             "remote::tgi",
@@ -148,7 +149,7 @@ class TestInference:
         response = await inference_impl.completion(
             content=user_input,
             stream=False,
-            model_id=model_id,
+            model=inference_model,
             sampling_params=SamplingParams(
                 max_tokens=50,
             ),
@@ -166,11 +167,11 @@ class TestInference:

     @pytest.mark.asyncio
     async def test_chat_completion_non_streaming(
-        self, inference_model, inference_stack, common_params, sample_messages, model_id
+        self, inference_model, inference_stack, common_params, sample_messages
     ):
         inference_impl, _ = inference_stack
         response = await inference_impl.chat_completion(
-            model_id=model_id,
+            model_id=inference_model,
             messages=sample_messages,
             stream=False,
             **common_params,
@@ -183,11 +184,11 @@ class TestInference:

     @pytest.mark.asyncio
     async def test_structured_output(
-        self, inference_model, inference_stack, common_params, model_id
+        self, inference_model, inference_stack, common_params
     ):
         inference_impl, _ = inference_stack

-        provider = inference_impl.routing_table.get_provider_impl(model_id)
+        provider = inference_impl.routing_table.get_provider_impl(inference_model)
         if provider.__provider_spec__.provider_type not in (
             "meta-reference",
             "remote::fireworks",
@@ -203,7 +204,7 @@ class TestInference:
             num_seasons_in_nba: int

         response = await inference_impl.chat_completion(
-            model_id=model_id,
+            model_id=inference_model,
             messages=[
                 SystemMessage(content="You are a helpful assistant."),
                 UserMessage(content="Please give me information about Michael Jordan."),
@@ -226,7 +227,7 @@ class TestInference:
         assert answer.num_seasons_in_nba == 15

         response = await inference_impl.chat_completion(
-            model_id=model_id,
+            model_id=inference_model,
             messages=[
                 SystemMessage(content="You are a helpful assistant."),
                 UserMessage(content="Please give me information about Michael Jordan."),
@@ -243,13 +244,13 @@ class TestInference:

     @pytest.mark.asyncio
     async def test_chat_completion_streaming(
-        self, inference_model, inference_stack, common_params, sample_messages, model_id
+        self, inference_model, inference_stack, common_params, sample_messages
     ):
         inference_impl, _ = inference_stack
         response = [
             r
             async for r in await inference_impl.chat_completion(
-                model_id=model_id,
+                model_id=inference_model,
                 messages=sample_messages,
                 stream=True,
                 **common_params,
@@ -276,7 +277,6 @@ class TestInference:
         common_params,
         sample_messages,
         sample_tool_definition,
-        model_id,
     ):
         inference_impl, _ = inference_stack
         messages = sample_messages + [
@@ -286,7 +286,7 @@ class TestInference:
         ]

         response = await inference_impl.chat_completion(
-            model_id=model_id,
+            model_id=inference_model,
             messages=messages,
             tools=[sample_tool_definition],
             stream=False,
@@ -316,7 +316,6 @@ class TestInference:
         common_params,
         sample_messages,
         sample_tool_definition,
-        model_id,
     ):
         inference_impl, _ = inference_stack
         messages = sample_messages + [
@@ -328,7 +327,7 @@ class TestInference:
         response = [
             r
             async for r in await inference_impl.chat_completion(
-                model_id=model_id,
+                model_id=inference_model,
                 messages=messages,
                 tools=[sample_tool_definition],
                 stream=True,

@@ -7,7 +7,6 @@
 from collections import namedtuple
 from typing import List, Optional

-from llama_models.datatypes import CoreModelId
 from llama_models.sku_list import all_registered_models

 from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
@@ -15,22 +14,22 @@ from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
 ModelAlias = namedtuple("ModelAlias", ["provider_model_id", "aliases", "llama_model"])


-def get_huggingface_repo(core_model_id: CoreModelId) -> Optional[str]:
+def get_huggingface_repo(model_descriptor: str) -> Optional[str]:
     """Get the Hugging Face repository for a given CoreModelId."""
     for model in all_registered_models():
-        if model.core_model_id == core_model_id:
+        if model.descriptor() == model_descriptor:
             return model.huggingface_repo
     return None


-def build_model_alias(provider_model_id: str, core_model_id: CoreModelId) -> ModelAlias:
+def build_model_alias(provider_model_id: str, model_descriptor: str) -> ModelAlias:
     return ModelAlias(
         provider_model_id=provider_model_id,
         aliases=[
-            core_model_id.value,
-            get_huggingface_repo(core_model_id),
+            model_descriptor,
+            get_huggingface_repo(model_descriptor),
         ],
-        llama_model=core_model_id.value,
+        llama_model=model_descriptor,
     )

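A minimal usage sketch of the updated registry helpers (illustrative only; the import path below is an assumption, and the provider id is copied from the Fireworks hunk above): providers now pass the plain descriptor string, CoreModelId.<id>.value, instead of the enum member, and the Hugging Face repo is resolved from that string.

    from llama_models.datatypes import CoreModelId
    # Import path assumed for illustration; the module is the one patched above.
    from llama_stack.providers.utils.inference.model_registry import build_model_alias

    alias = build_model_alias(
        "fireworks/llama-v3p1-8b-instruct",      # provider-side model id
        CoreModelId.llama3_1_8b_instruct.value,  # descriptor string, not the enum member
    )
    # alias.aliases     -> [descriptor, matching Hugging Face repo (or None)]
    # alias.llama_model -> descriptor string used downstream for prompt formatting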