fixes after rebase

Dinesh Yeduguru 2024-11-12 15:37:07 -08:00
parent 948f6ece6e
commit 919d421bcf
11 changed files with 72 additions and 70 deletions


@@ -179,7 +179,7 @@ INFERENCE_FIXTURES = [
 @pytest_asyncio.fixture(scope="session")
-async def inference_stack(request, inference_model, model_id):
+async def inference_stack(request, inference_model):
     fixture_name = request.param
     inference_fixture = request.getfixturevalue(f"inference_{fixture_name}")
     impls = await resolve_impls_for_test_v2(
@@ -188,7 +188,7 @@ async def inference_stack(request, inference_model, model_id):
         inference_fixture.provider_data,
         models=[
             ModelInput(
-                model_id=model_id,
+                model_id=inference_model,
             )
         ],
     )
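For context on the fixture indirection above: inference_stack resolves a provider-specific fixture by name at runtime via request.getfixturevalue. A minimal, self-contained sketch of that pattern (the backend_* names are hypothetical, not from this repo):

import pytest


@pytest.fixture
def backend_ollama():
    return "ollama-client"


@pytest.fixture
def backend_tgi():
    return "tgi-client"


@pytest.fixture(params=["ollama", "tgi"])
def backend(request):
    # Look up "backend_<param>" dynamically, the same trick
    # inference_stack uses with f"inference_{fixture_name}".
    return request.getfixturevalue(f"backend_{request.param}")


def test_backend_resolves(backend):
    assert backend.endswith("-client")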


@@ -64,7 +64,7 @@ def sample_tool_definition():
 class TestInference:
     @pytest.mark.asyncio
-    async def test_model_list(self, inference_model, inference_stack, model_id):
+    async def test_model_list(self, inference_model, inference_stack):
         _, models_impl = inference_stack
         response = await models_impl.list_models()
         assert isinstance(response, list)
@@ -73,16 +73,17 @@ class TestInference:
         model_def = None
         for model in response:
-            if model.identifier == model_id:
+            if model.identifier == inference_model:
                 model_def = model
                 break

         assert model_def is not None

     @pytest.mark.asyncio
-    async def test_completion(self, inference_model, inference_stack, model_id):
+    async def test_completion(self, inference_model, inference_stack):
         inference_impl, _ = inference_stack
-        provider = inference_impl.routing_table.get_provider_impl(model_id)
+        provider = inference_impl.routing_table.get_provider_impl(inference_model)
         if provider.__provider_spec__.provider_type not in (
             "meta-reference",
             "remote::ollama",
@@ -95,7 +96,7 @@ class TestInference:
         response = await inference_impl.completion(
             content="Micheael Jordan is born in ",
             stream=False,
-            model_id=model_id,
+            model_id=inference_model,
             sampling_params=SamplingParams(
                 max_tokens=50,
             ),
@@ -109,7 +110,7 @@ class TestInference:
             async for r in await inference_impl.completion(
                 content="Roses are red,",
                 stream=True,
-                model_id=model_id,
+                model_id=inference_model,
                 sampling_params=SamplingParams(
                     max_tokens=50,
                 ),
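The streaming calls above use a double await/async-for shape: with stream=True, awaiting the call returns an async iterator of response chunks. A hedged sketch of a helper that collects them, assuming the inference_impl interface shown in this diff and SamplingParams imported as in the test module:

async def collect_completion_chunks(inference_impl, inference_model):
    # stream=True: awaiting the call yields an async iterator,
    # which is then consumed chunk by chunk.
    chunks = []
    async for r in await inference_impl.completion(
        content="Roses are red,",
        stream=True,
        model_id=inference_model,
        sampling_params=SamplingParams(max_tokens=50),
    ):
        chunks.append(r)
    return chunks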
@@ -124,11 +125,11 @@ class TestInference:
     @pytest.mark.asyncio
     @pytest.mark.skip("This test is not quite robust")
     async def test_completions_structured_output(
-        self, inference_model, inference_stack, model_id
+        self, inference_model, inference_stack
     ):
         inference_impl, _ = inference_stack
-        provider = inference_impl.routing_table.get_provider_impl(model_id)
+        provider = inference_impl.routing_table.get_provider_impl(inference_model)
         if provider.__provider_spec__.provider_type not in (
             "meta-reference",
             "remote::tgi",
@@ -148,7 +149,7 @@ class TestInference:
         response = await inference_impl.completion(
             content=user_input,
             stream=False,
-            model_id=model_id,
+            model_id=inference_model,
             sampling_params=SamplingParams(
                 max_tokens=50,
             ),
@@ -166,11 +167,11 @@ class TestInference:
     @pytest.mark.asyncio
     async def test_chat_completion_non_streaming(
-        self, inference_model, inference_stack, common_params, sample_messages, model_id
+        self, inference_model, inference_stack, common_params, sample_messages
     ):
         inference_impl, _ = inference_stack
         response = await inference_impl.chat_completion(
-            model_id=model_id,
+            model_id=inference_model,
             messages=sample_messages,
             stream=False,
             **common_params,
@@ -183,11 +184,11 @@ class TestInference:
     @pytest.mark.asyncio
     async def test_structured_output(
-        self, inference_model, inference_stack, common_params, model_id
+        self, inference_model, inference_stack, common_params
     ):
         inference_impl, _ = inference_stack
-        provider = inference_impl.routing_table.get_provider_impl(model_id)
+        provider = inference_impl.routing_table.get_provider_impl(inference_model)
         if provider.__provider_spec__.provider_type not in (
             "meta-reference",
             "remote::fireworks",
@@ -203,7 +204,7 @@ class TestInference:
             num_seasons_in_nba: int

         response = await inference_impl.chat_completion(
-            model_id=model_id,
+            model_id=inference_model,
             messages=[
                 SystemMessage(content="You are a helpful assistant."),
                 UserMessage(content="Please give me information about Michael Jordan."),
@@ -226,7 +227,7 @@ class TestInference:
         assert answer.num_seasons_in_nba == 15

         response = await inference_impl.chat_completion(
-            model_id=model_id,
+            model_id=inference_model,
             messages=[
                 SystemMessage(content="You are a helpful assistant."),
                 UserMessage(content="Please give me information about Michael Jordan."),
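test_structured_output parses the model's reply into a Pydantic schema and asserts on typed fields. A minimal sketch of that validation step (only num_seasons_in_nba and the == 15 assertion appear in this diff; the other fields are illustrative):

from pydantic import BaseModel


class AnswerFormat(BaseModel):
    first_name: str
    last_name: str
    year_of_birth: int
    num_seasons_in_nba: int


# Validate a JSON reply of the shape the test expects.
answer = AnswerFormat.model_validate_json(
    '{"first_name": "Michael", "last_name": "Jordan", '
    '"year_of_birth": 1963, "num_seasons_in_nba": 15}'
)
assert answer.num_seasons_in_nba == 15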
@@ -243,13 +244,13 @@ class TestInference:
     @pytest.mark.asyncio
     async def test_chat_completion_streaming(
-        self, inference_model, inference_stack, common_params, sample_messages, model_id
+        self, inference_model, inference_stack, common_params, sample_messages
     ):
         inference_impl, _ = inference_stack
         response = [
             r
             async for r in await inference_impl.chat_completion(
-                model_id=model_id,
+                model_id=inference_model,
                 messages=sample_messages,
                 stream=True,
                 **common_params,
@@ -276,7 +277,6 @@ class TestInference:
         common_params,
         sample_messages,
         sample_tool_definition,
-        model_id,
     ):
         inference_impl, _ = inference_stack
         messages = sample_messages + [
@@ -286,7 +286,7 @@ class TestInference:
         ]

         response = await inference_impl.chat_completion(
-            model_id=model_id,
+            model_id=inference_model,
             messages=messages,
             tools=[sample_tool_definition],
             stream=False,
@@ -316,7 +316,6 @@ class TestInference:
         common_params,
         sample_messages,
         sample_tool_definition,
-        model_id,
     ):
         inference_impl, _ = inference_stack
         messages = sample_messages + [
@@ -328,7 +327,7 @@ class TestInference:
         response = [
             r
             async for r in await inference_impl.chat_completion(
-                model_id=model_id,
+                model_id=inference_model,
                 messages=messages,
                 tools=[sample_tool_definition],
                 stream=True,
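Both tool-calling tests consume the sample_tool_definition fixture. A hypothetical sketch of such a fixture using the ToolDefinition/ToolParamDefinition datatypes the tests rely on (the import path and field values here are assumptions, not taken from this diff):

import pytest
from llama_models.llama3.api.datatypes import (
    ToolDefinition,
    ToolParamDefinition,
)


@pytest.fixture
def sample_tool_definition():
    # A single-parameter tool; chat_completion receives it via
    # tools=[sample_tool_definition] as in the hunks above.
    return ToolDefinition(
        tool_name="get_weather",
        description="Get the current weather for a location",
        parameters={
            "location": ToolParamDefinition(
                param_type="str",
                description="The city and state, e.g. San Francisco, CA",
                required=True,
            ),
        },
    )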