completion() for tgi (#295)

This commit is contained in:
Dinesh Yeduguru 2024-10-24 16:02:41 -07:00 committed by GitHub
parent cb84034567
commit 3e1c3fdb3f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 173 additions and 35 deletions

View file

@ -137,6 +137,7 @@ async def test_completion(inference_settings):
if provider.__provider_spec__.provider_type not in (
"meta-reference",
"remote::ollama",
"remote::tgi",
):
pytest.skip("Other inference providers don't support completion() yet")
@ -170,6 +171,46 @@ async def test_completion(inference_settings):
assert last.stop_reason == StopReason.out_of_tokens
@pytest.mark.asyncio
async def test_completions_structured_output(inference_settings):
    """Check that completion() returns JSON matching a requested schema.

    The test reads the inference implementation and the common parameters
    from the ``inference_settings`` fixture (keys ``"impl"`` and
    ``"common_params"``), skips itself unless the provider serving the
    configured model is ``meta-reference`` or ``remote::tgi``, and then
    requests a non-streaming completion constrained by the JSON schema of a
    small Pydantic model. The returned content must parse into that model
    with the expected field values.
    """
    inference_impl = inference_settings["impl"]
    params = inference_settings["common_params"]

    # Only a subset of providers implement structured output for completion().
    provider = inference_impl.routing_table.get_provider_impl(params["model"])
    if provider.__provider_spec__.provider_type not in (
        "meta-reference",
        "remote::tgi",
    ):
        pytest.skip(
            "Other inference providers don't support structured output in completions yet"
        )

    class Output(BaseModel):
        name: str
        year_born: str
        year_retired: str

    # Compute the schema once with the Pydantic v2 API so the schema shown in
    # the prompt and the one enforced through response_format cannot diverge
    # (the deprecated Output.schema() was previously used for the prompt).
    output_schema = Output.model_json_schema()

    user_input = "Michael Jordan was born in 1963. He played basketball for the Chicago Bulls. He retired in 2003."
    response = await inference_impl.completion(
        content=f"input: '{user_input}'. the schema for json: {output_schema}, the json is: ",
        stream=False,
        model=params["model"],
        sampling_params=SamplingParams(
            max_tokens=50,
        ),
        response_format=JsonResponseFormat(
            schema=output_schema,
        ),
    )
    assert isinstance(response, CompletionResponse)
    assert isinstance(response.content, str)

    # model_validate_json is the Pydantic v2 replacement for the deprecated
    # parse_raw; it raises a validation error if the content is not valid
    # JSON for the Output model.
    answer = Output.model_validate_json(response.content)
    assert answer.name == "Michael Jordan"
    assert answer.year_born == "1963"
    assert answer.year_retired == "2003"
@pytest.mark.asyncio
async def test_chat_completion_non_streaming(inference_settings, sample_messages):
inference_impl = inference_settings["impl"]