Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-17 01:42:36 +00:00)
fixes after rebase
This commit is contained in:
parent 948f6ece6e
commit 919d421bcf
11 changed files with 72 additions and 70 deletions
@@ -64,7 +64,7 @@ def sample_tool_definition():

 class TestInference:
     @pytest.mark.asyncio
-    async def test_model_list(self, inference_model, inference_stack, model_id):
+    async def test_model_list(self, inference_model, inference_stack):
         _, models_impl = inference_stack
         response = await models_impl.list_models()
         assert isinstance(response, list)
@@ -73,16 +73,17 @@ class TestInference:

         model_def = None
         for model in response:
-            if model.identifier == model_id:
+            if model.identifier == inference_model:
                 model_def = model
                 break

         assert model_def is not None

     @pytest.mark.asyncio
-    async def test_completion(self, inference_model, inference_stack, model_id):
+    async def test_completion(self, inference_model, inference_stack):
         inference_impl, _ = inference_stack
-        provider = inference_impl.routing_table.get_provider_impl(model_id)
+
+        provider = inference_impl.routing_table.get_provider_impl(inference_model)
         if provider.__provider_spec__.provider_type not in (
             "meta-reference",
             "remote::ollama",
@@ -95,7 +96,7 @@ class TestInference:
         response = await inference_impl.completion(
             content="Micheael Jordan is born in ",
             stream=False,
-            model_id=model_id,
+            model_id=inference_model,
             sampling_params=SamplingParams(
                 max_tokens=50,
             ),
@@ -109,7 +110,7 @@ class TestInference:
         async for r in await inference_impl.completion(
             content="Roses are red,",
             stream=True,
-            model_id=model_id,
+            model_id=inference_model,
             sampling_params=SamplingParams(
                 max_tokens=50,
             ),
@@ -124,11 +125,11 @@ class TestInference:
     @pytest.mark.asyncio
     @pytest.mark.skip("This test is not quite robust")
     async def test_completions_structured_output(
-        self, inference_model, inference_stack, model_id
+        self, inference_model, inference_stack
     ):
         inference_impl, _ = inference_stack

-        provider = inference_impl.routing_table.get_provider_impl(model_id)
+        provider = inference_impl.routing_table.get_provider_impl(inference_model)
         if provider.__provider_spec__.provider_type not in (
             "meta-reference",
             "remote::tgi",
@@ -148,7 +149,7 @@ class TestInference:
         response = await inference_impl.completion(
             content=user_input,
             stream=False,
-            model_id=model_id,
+            model=inference_model,
             sampling_params=SamplingParams(
                 max_tokens=50,
             ),
@@ -166,11 +167,11 @@ class TestInference:

     @pytest.mark.asyncio
     async def test_chat_completion_non_streaming(
-        self, inference_model, inference_stack, common_params, sample_messages, model_id
+        self, inference_model, inference_stack, common_params, sample_messages
     ):
         inference_impl, _ = inference_stack
         response = await inference_impl.chat_completion(
-            model_id=model_id,
+            model_id=inference_model,
             messages=sample_messages,
             stream=False,
             **common_params,
@@ -183,11 +184,11 @@ class TestInference:

     @pytest.mark.asyncio
     async def test_structured_output(
-        self, inference_model, inference_stack, common_params, model_id
+        self, inference_model, inference_stack, common_params
     ):
         inference_impl, _ = inference_stack

-        provider = inference_impl.routing_table.get_provider_impl(model_id)
+        provider = inference_impl.routing_table.get_provider_impl(inference_model)
         if provider.__provider_spec__.provider_type not in (
             "meta-reference",
             "remote::fireworks",
@@ -203,7 +204,7 @@ class TestInference:
             num_seasons_in_nba: int

         response = await inference_impl.chat_completion(
-            model_id=model_id,
+            model_id=inference_model,
             messages=[
                 SystemMessage(content="You are a helpful assistant."),
                 UserMessage(content="Please give me information about Michael Jordan."),
@@ -226,7 +227,7 @@ class TestInference:
         assert answer.num_seasons_in_nba == 15

         response = await inference_impl.chat_completion(
-            model_id=model_id,
+            model_id=inference_model,
             messages=[
                 SystemMessage(content="You are a helpful assistant."),
                 UserMessage(content="Please give me information about Michael Jordan."),
@@ -243,13 +244,13 @@ class TestInference:

     @pytest.mark.asyncio
     async def test_chat_completion_streaming(
-        self, inference_model, inference_stack, common_params, sample_messages, model_id
+        self, inference_model, inference_stack, common_params, sample_messages
     ):
         inference_impl, _ = inference_stack
         response = [
             r
             async for r in await inference_impl.chat_completion(
-                model_id=model_id,
+                model_id=inference_model,
                 messages=sample_messages,
                 stream=True,
                 **common_params,
@@ -276,7 +277,6 @@ class TestInference:
         common_params,
         sample_messages,
         sample_tool_definition,
-        model_id,
     ):
         inference_impl, _ = inference_stack
         messages = sample_messages + [
@@ -286,7 +286,7 @@ class TestInference:
         ]

         response = await inference_impl.chat_completion(
-            model_id=model_id,
+            model_id=inference_model,
             messages=messages,
             tools=[sample_tool_definition],
             stream=False,
@@ -316,7 +316,6 @@ class TestInference:
         common_params,
         sample_messages,
         sample_tool_definition,
-        model_id,
     ):
         inference_impl, _ = inference_stack
         messages = sample_messages + [
@@ -328,7 +327,7 @@ class TestInference:
         response = [
             r
             async for r in await inference_impl.chat_completion(
-                model_id=model_id,
+                model_id=inference_model,
                 messages=messages,
                 tools=[sample_tool_definition],
                 stream=True,
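Taken together, the hunks above settle on one calling convention: each test takes the inference_model fixture and passes it straight through as the model_id argument of the inference API, instead of relying on a separate model_id fixture. Below is a minimal, self-contained sketch of that convention; the fake stack and the model identifier are illustrative assumptions, not the repository's actual conftest wiring.

import pytest


class _FakeInferenceImpl:
    # Stand-in for the routed inference implementation; illustrative only.
    async def completion(self, *, content, stream, model_id, **kwargs):
        return {"model_id": model_id, "content": content, "stream": stream}


@pytest.fixture
def inference_model():
    # Hypothetical model identifier, assumed for this sketch.
    return "Llama3.1-8B-Instruct"


@pytest.fixture
def inference_stack():
    # The real fixture yields (inference_impl, models_impl) from the test stack.
    return _FakeInferenceImpl(), None


@pytest.mark.asyncio
async def test_completion_sketch(inference_model, inference_stack):
    inference_impl, _ = inference_stack
    # The fixture value is forwarded directly as model_id, mirroring the diff.
    response = await inference_impl.completion(
        content="Roses are red,",
        stream=False,
        model_id=inference_model,
    )
    assert response["model_id"] == inference_model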