Make Together inference work using the raw completions API

Ashwin Bharambe 2024-10-07 17:28:19 -07:00 committed by Ashwin Bharambe
parent 3ae2b712e8
commit bbd3a02615
2 changed files with 33 additions and 18 deletions
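This switches the Together adapter from the chat-completions endpoint to the raw completions endpoint: the adapter now renders the Llama dialog into a prompt string itself (via ChatFormat and the tokenizer) and decodes the model's raw output text back into an assistant message, instead of relying on Together's own chat templating. A minimal sketch of the new flow, assuming the llama_models import paths and using a placeholder model id and message (both illustrative, not part of this diff):

from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import StopReason, UserMessage
from llama_models.llama3.api.tokenizer import Tokenizer
from together import Together

tokenizer = Tokenizer.get_instance()
formatter = ChatFormat(tokenizer)

# Render the dialog into the raw prompt string the completions API expects.
messages = [UserMessage(content="Hello!")]
model_input = formatter.encode_dialog_prompt(messages)
prompt = tokenizer.decode(model_input.tokens)

client = Together()  # reads TOGETHER_API_KEY from the environment
r = client.completions.create(
    model="meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",  # placeholder model id
    prompt=prompt,
    stream=False,
)

# Map Together's finish_reason onto llama-stack's StopReason, as the diff does.
choice = r.choices[0]
if choice.finish_reason in ["stop", "eos"]:
    stop_reason = StopReason.end_of_turn
elif choice.finish_reason == "length":
    stop_reason = StopReason.out_of_tokens
else:
    stop_reason = None

# Decode the raw completion text back into a structured assistant message.
completion_message = formatter.decode_assistant_message_from_content(
    choice.text, stop_reason
)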


@@ -41,8 +41,8 @@ class TogetherInferenceAdapter(
             self, stack_to_provider_models_map=TOGETHER_SUPPORTED_MODELS
         )
         self.config = config
-        tokenizer = Tokenizer.get_instance()
-        self.formatter = ChatFormat(tokenizer)
+        self.tokenizer = Tokenizer.get_instance()
+        self.formatter = ChatFormat(self.tokenizer)
 
     @property
     def client(self) -> Together:
@@ -124,27 +124,28 @@ class TogetherInferenceAdapter(
         options = self.get_together_chat_options(request)
         together_model = self.map_to_provider_model(request.model)
         messages = augment_messages_for_tools(request)
+        model_input = self.formatter.encode_dialog_prompt(messages)
+        prompt = self.tokenizer.decode(model_input.tokens)
 
         if not request.stream:
             # TODO: might need to add back an async here
-            r = client.chat.completions.create(
+            r = client.completions.create(
                 model=together_model,
-                messages=self._messages_to_together_messages(messages),
+                prompt=prompt,
                 stream=False,
                 **options,
             )
             stop_reason = None
-            if r.choices[0].finish_reason:
-                if (
-                    r.choices[0].finish_reason == "stop"
-                    or r.choices[0].finish_reason == "eos"
-                ):
+            choice = r.choices[0]
+            if choice.finish_reason:
+                if choice.finish_reason in ["stop", "eos"]:
                     stop_reason = StopReason.end_of_turn
-                elif r.choices[0].finish_reason == "length":
+                elif choice.finish_reason == "length":
                     stop_reason = StopReason.out_of_tokens
 
             completion_message = self.formatter.decode_assistant_message_from_content(
-                r.choices[0].message.content, stop_reason
+                choice.text, stop_reason
             )
             yield ChatCompletionResponse(
                 completion_message=completion_message,
@@ -162,20 +163,21 @@ class TogetherInferenceAdapter(
             ipython = False
             stop_reason = None
 
-            for chunk in client.chat.completions.create(
+            for chunk in client.completions.create(
                 model=together_model,
-                messages=self._messages_to_together_messages(messages),
+                prompt=prompt,
                 stream=True,
                 **options,
             ):
-                if finish_reason := chunk.choices[0].finish_reason:
+                choice = chunk.choices[0]
+                if finish_reason := choice.finish_reason:
                     if stop_reason is None and finish_reason in ["stop", "eos"]:
                         stop_reason = StopReason.end_of_turn
                     elif stop_reason is None and finish_reason == "length":
                         stop_reason = StopReason.out_of_tokens
                     break
 
-                text = chunk.choices[0].delta.content
+                text = choice.delta.content
                 if text is None:
                     continue


@@ -5,6 +5,7 @@
 # the root directory of this source tree.
 
 import itertools
+import json
 import os
 from datetime import datetime
@@ -17,6 +18,7 @@ from llama_stack.apis.inference import * # noqa: F403
 from llama_stack.distribution.datatypes import * # noqa: F403
 
 from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
+from llama_stack.distribution.request_headers import set_request_provider_data
 from llama_stack.distribution.resolver import resolve_impls_with_routing
@@ -60,9 +62,10 @@ async def stack_impls(model):
         provider = providers_by_id[provider_id]
     else:
         provider = list(providers_by_id.values())[0]
-        print(f"No provider ID specified, picking first {provider['provider_id']}")
+        provider_id = provider["provider_id"]
+        print(f"No provider ID specified, picking first `{provider_id}`")
 
-    config_dict = dict(
+    run_config = dict(
         built_at=datetime.now(),
         image_name="test-fixture",
         apis=[
@@ -84,8 +87,17 @@ async def stack_impls(model):
         shields=[],
         memory_banks=[],
     )
-    run_config = parse_and_maybe_upgrade_config(config_dict)
+
+    run_config = parse_and_maybe_upgrade_config(run_config)
     impls = await resolve_impls_with_routing(run_config)
+
+    # may need something cleaner here
+    if "provider_data" in config_dict:
+        provider_data = config_dict["provider_data"].get(provider_id, {})
+        if provider_data:
+            set_request_provider_data(
+                {"X-LlamaStack-ProviderData": json.dumps(provider_data)}
+            )
+
     return impls
@@ -97,6 +109,7 @@ async def stack_impls(model):
         {"model": Llama_8B},
         {"model": Llama_3B},
     ],
+    ids=lambda d: d["model"],
 )
 async def inference_settings(request):
     model = request.param["model"]
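In the test fixture, per-provider credentials now reach the adapter through the X-LlamaStack-ProviderData request header rather than the run config. A rough sketch of what the new code path expects from the test config (the config dict and the key names under provider_data are illustrative; normally config_dict comes from the YAML the fixture loads):

import json

from llama_stack.distribution.request_headers import set_request_provider_data

# Hypothetical test config, standing in for the loaded YAML.
config_dict = {
    "providers": [{"provider_id": "together"}],
    "provider_data": {
        "together": {"together_api_key": "sk-..."},  # field name is illustrative
    },
}

provider_id = "together"
provider_data = config_dict["provider_data"].get(provider_id, {})
if provider_data:
    # Stash the per-provider data where adapters can read it for this request.
    set_request_provider_data({"X-LlamaStack-ProviderData": json.dumps(provider_data)})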