introduce openai_compat with the completions (not chat-completions) API

This keeps the prompt-encoding layer under our control (see the
`chat_completion_request_to_prompt()` method).
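
Using the raw completions endpoint means the provider never re-templates the
conversation; the chat request is flattened into a single prompt string on our
side first. A rough sketch of the idea only — the <|role|> markers below are
placeholders, not the actual Llama special tokens, and the real
`chat_completion_request_to_prompt()` lives in the provider code:

    # Hypothetical sketch: flatten a chat request into one raw prompt
    # string so it can be sent to a /completions (not /chat/completions)
    # endpoint. The <|role|> markers stand in for the model's real
    # special tokens, which the actual implementation controls.
    def chat_completion_request_to_prompt(messages: list[dict]) -> str:
        parts = []
        for message in messages:
            parts.append(f"<|{message['role']}|>\n{message['content']}")
        parts.append("<|assistant|>\n")  # cue the model to start its reply
        return "\n".join(parts)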
Ashwin Bharambe, 2024-10-08 12:15:55 -07:00 (committed by Ashwin Bharambe)
commit 05e73d12b3, parent 0c9eb3341c
6 changed files with 354 additions and 513 deletions


@@ -55,7 +55,7 @@ def get_expected_stop_reason(model: str):
 @pytest_asyncio.fixture(
     scope="session",
     params=[
-        {"model": Llama_8B},
+        # {"model": Llama_8B},
         {"model": Llama_3B},
     ],
     ids=lambda d: d["model"],
@@ -112,20 +112,16 @@ def sample_tool_definition():
 @pytest.mark.asyncio
 async def test_chat_completion_non_streaming(inference_settings, sample_messages):
     inference_impl = inference_settings["impl"]
-    response = [
-        r
-        async for r in inference_impl.chat_completion(
-            messages=sample_messages,
-            stream=False,
-            **inference_settings["common_params"],
-        )
-    ]
+    response = await inference_impl.chat_completion(
+        messages=sample_messages,
+        stream=False,
+        **inference_settings["common_params"],
+    )
 
-    assert len(response) == 1
-    assert isinstance(response[0], ChatCompletionResponse)
-    assert response[0].completion_message.role == "assistant"
-    assert isinstance(response[0].completion_message.content, str)
-    assert len(response[0].completion_message.content) > 0
+    assert isinstance(response, ChatCompletionResponse)
+    assert response.completion_message.role == "assistant"
+    assert isinstance(response.completion_message.content, str)
+    assert len(response.completion_message.content) > 0
 
 
 @pytest.mark.asyncio
@@ -166,20 +162,16 @@ async def test_chat_completion_with_tool_calling(
         )
     ]
 
-    response = [
-        r
-        async for r in inference_impl.chat_completion(
-            messages=messages,
-            tools=[sample_tool_definition],
-            stream=False,
-            **inference_settings["common_params"],
-        )
-    ]
+    response = await inference_impl.chat_completion(
+        messages=messages,
+        tools=[sample_tool_definition],
+        stream=False,
+        **inference_settings["common_params"],
+    )
 
-    assert len(response) == 1
-    assert isinstance(response[0], ChatCompletionResponse)
+    assert isinstance(response, ChatCompletionResponse)
 
-    message = response[0].completion_message
+    message = response.completion_message
 
     # This is not supported in most providers :/ they don't return eom_id / eot_id
     # stop_reason = get_expected_stop_reason(inference_settings["common_params"]["model"])
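
Net effect on the call contract, as reflected in the hunks above: with
`stream=False`, `chat_completion()` now resolves directly to a single
`ChatCompletionResponse` rather than an async generator that yields exactly
one item; `stream=True` presumably keeps the async-generator shape the
pre-change tests used. A minimal usage sketch under those assumptions
(setup of `inference_impl`, `sample_messages`, and `common_params` as in
the fixtures above):

    # Non-streaming: await the coroutine and get one response object back.
    response = await inference_impl.chat_completion(
        messages=sample_messages,
        stream=False,
        **common_params,
    )
    assert isinstance(response, ChatCompletionResponse)

    # Streaming: iterate the async generator, as the pre-change tests did.
    chunks = [
        chunk
        async for chunk in inference_impl.chat_completion(
            messages=sample_messages,
            stream=True,
            **common_params,
        )
    ]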