Make Together inference work using the raw completions API

Ashwin Bharambe authored and committed on 2024-10-07 17:28:19 -07:00
parent 3ae2b712e8
commit bbd3a02615
2 changed files with 33 additions and 18 deletions
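
The change below stops calling Together's chat.completions endpoint and instead renders the dialog locally with the Llama tokenizer and chat formatter, then sends the resulting raw prompt to the completions endpoint. A rough sketch of that prompt-construction step, using the same llama_models utilities the adapter uses; the import paths, the UserMessage example, and the markup shown in the comment are assumptions for illustration, not part of this diff:

# Sketch: render chat messages into a raw Llama 3 prompt string, as the adapter does.
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import UserMessage
from llama_models.llama3.api.tokenizer import Tokenizer

tokenizer = Tokenizer.get_instance()
formatter = ChatFormat(tokenizer)

messages = [UserMessage(content="What is the capital of France?")]
model_input = formatter.encode_dialog_prompt(messages)  # token ids incl. special tokens
prompt = tokenizer.decode(model_input.tokens)           # back to a plain string

# `prompt` now carries the Llama 3 chat markup, roughly:
#   <|begin_of_text|><|start_header_id|>user<|end_header_id|>
#
#   What is the capital of France?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
print(prompt)

Decoding the tokens back to text keeps prompt formatting under the adapter's control rather than relying on Together's server-side chat templating.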


@@ -41,8 +41,8 @@ class TogetherInferenceAdapter(
             self, stack_to_provider_models_map=TOGETHER_SUPPORTED_MODELS
         )
         self.config = config
-        tokenizer = Tokenizer.get_instance()
-        self.formatter = ChatFormat(tokenizer)
+        self.tokenizer = Tokenizer.get_instance()
+        self.formatter = ChatFormat(self.tokenizer)

     @property
     def client(self) -> Together:
@@ -124,27 +124,28 @@ class TogetherInferenceAdapter(
         options = self.get_together_chat_options(request)
         together_model = self.map_to_provider_model(request.model)
         messages = augment_messages_for_tools(request)
+        model_input = self.formatter.encode_dialog_prompt(messages)
+        prompt = self.tokenizer.decode(model_input.tokens)

         if not request.stream:
             # TODO: might need to add back an async here
-            r = client.chat.completions.create(
+            r = client.completions.create(
                 model=together_model,
-                messages=self._messages_to_together_messages(messages),
+                prompt=prompt,
                 stream=False,
                 **options,
             )
             stop_reason = None
-            if r.choices[0].finish_reason:
-                if (
-                    r.choices[0].finish_reason == "stop"
-                    or r.choices[0].finish_reason == "eos"
-                ):
+            choice = r.choices[0]
+            if choice.finish_reason:
+                if choice.finish_reason in ["stop", "eos"]:
                     stop_reason = StopReason.end_of_turn
-                elif r.choices[0].finish_reason == "length":
+                elif choice.finish_reason == "length":
                     stop_reason = StopReason.out_of_tokens

             completion_message = self.formatter.decode_assistant_message_from_content(
-                r.choices[0].message.content, stop_reason
+                choice.text, stop_reason
             )
             yield ChatCompletionResponse(
                 completion_message=completion_message,
@@ -162,20 +163,21 @@ class TogetherInferenceAdapter(
             ipython = False
             stop_reason = None

-            for chunk in client.chat.completions.create(
+            for chunk in client.completions.create(
                 model=together_model,
-                messages=self._messages_to_together_messages(messages),
+                prompt=prompt,
                 stream=True,
                 **options,
             ):
-                if finish_reason := chunk.choices[0].finish_reason:
+                choice = chunk.choices[0]
+                if finish_reason := choice.finish_reason:
                     if stop_reason is None and finish_reason in ["stop", "eos"]:
                         stop_reason = StopReason.end_of_turn
                     elif stop_reason is None and finish_reason == "length":
                         stop_reason = StopReason.out_of_tokens
                     break

-                text = chunk.choices[0].delta.content
+                text = choice.delta.content
                 if text is None:
                     continue
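
For reference, a standalone sketch of the two completions calls the adapter now makes, outside the stack. The model name, max_tokens, and API-key handling are illustrative assumptions; the response fields used (choices[0].text, choices[0].delta.content, finish_reason) are the ones the diff above relies on:

# Sketch: non-streaming and streaming raw completions against Together,
# mirroring the adapter code above. Model name and parameters are illustrative.
import os
from together import Together

client = Together(api_key=os.environ["TOGETHER_API_KEY"])
prompt = (
    "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"
    "What is the capital of France?<|eot_id|>"
    "<|start_header_id|>assistant<|end_header_id|>\n\n"
)

# Non-streaming: the whole generation arrives in choices[0].text.
r = client.completions.create(
    model="meta-llama/Llama-3-8b-chat-hf",
    prompt=prompt,
    stream=False,
    max_tokens=128,
)
print(r.choices[0].finish_reason, r.choices[0].text)

# Streaming: print delta text until a finish_reason shows up.
for chunk in client.completions.create(
    model="meta-llama/Llama-3-8b-chat-hf",
    prompt=prompt,
    stream=True,
    max_tokens=128,
):
    choice = chunk.choices[0]
    if choice.finish_reason:
        break
    if choice.delta.content is not None:
        print(choice.delta.content, end="")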


@@ -5,6 +5,7 @@
 # the root directory of this source tree.

 import itertools
+import json
 import os
 from datetime import datetime
@@ -17,6 +18,7 @@ from llama_stack.apis.inference import * # noqa: F403
 from llama_stack.distribution.datatypes import * # noqa: F403
 from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
+from llama_stack.distribution.request_headers import set_request_provider_data
 from llama_stack.distribution.resolver import resolve_impls_with_routing
@@ -60,9 +62,10 @@ async def stack_impls(model):
         provider = providers_by_id[provider_id]
     else:
         provider = list(providers_by_id.values())[0]
-        print(f"No provider ID specified, picking first {provider['provider_id']}")
+        provider_id = provider["provider_id"]
+        print(f"No provider ID specified, picking first `{provider_id}`")

-    config_dict = dict(
+    run_config = dict(
         built_at=datetime.now(),
         image_name="test-fixture",
         apis=[
@@ -84,8 +87,17 @@ async def stack_impls(model):
         shields=[],
         memory_banks=[],
     )
-    run_config = parse_and_maybe_upgrade_config(config_dict)
+    run_config = parse_and_maybe_upgrade_config(run_config)
     impls = await resolve_impls_with_routing(run_config)

+    # may need something cleaner here
+    if "provider_data" in config_dict:
+        provider_data = config_dict["provider_data"].get(provider_id, {})
+        if provider_data:
+            set_request_provider_data(
+                {"X-LlamaStack-ProviderData": json.dumps(provider_data)}
+            )
+
     return impls
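
The fixture change above keeps the YAML-loaded config_dict intact (the constructed run config now lives in a separate run_config variable) and, when the fixture config carries a provider_data section, forwards it to providers through set_request_provider_data via the X-LlamaStack-ProviderData header. A rough sketch of what that section might look like; the together_api_key field name is an assumption about the Together provider's provider-data schema, not something this diff shows:

# Sketch: a provider_data section keyed by provider id, plus the header the
# fixture builds from it. Field names inside are assumptions; use whatever
# your provider's provider-data validator expects.
import json

config_dict = {
    # ... rest of the test fixture config ...
    "provider_data": {
        "together": {"together_api_key": "<your-together-api-key>"},
    },
}

provider_id = "together"
provider_data = config_dict["provider_data"].get(provider_id, {})
header = {"X-LlamaStack-ProviderData": json.dumps(provider_data)}
# set_request_provider_data(header) makes this available to the provider for the
# duration of the request, the same way a client-supplied header would.
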
@@ -97,6 +109,7 @@ async def stack_impls(model):
         {"model": Llama_8B},
         {"model": Llama_3B},
     ],
+    ids=lambda d: d["model"],
 )
 async def inference_settings(request):
     model = request.param["model"]
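
Lastly, the new ids callable on the fixture only affects how the parametrized cases are named in pytest output. A self-contained illustration with placeholder model names (not the stack's actual Llama_8B/Llama_3B values):

# Sketch: ids= turns index-based parametrize ids into readable ones.
import pytest

@pytest.fixture(
    params=[{"model": "Llama-8B"}, {"model": "Llama-3B"}],
    ids=lambda d: d["model"],  # test ids read e.g. test_x[Llama-8B] instead of index-based ids
)
def inference_settings(request):
    return request.param

def test_has_model(inference_settings):
    assert inference_settings["model"].startswith("Llama")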