From bbd3a026159e95135409099c630441ca18b71f0f Mon Sep 17 00:00:00 2001
From: Ashwin Bharambe
Date: Mon, 7 Oct 2024 17:28:19 -0700
Subject: [PATCH] Make Together inference work using the raw completions API

---
 .../adapters/inference/together/together.py   | 31 ++++++++++--------
 .../tests/inference/test_inference.py         | 19 +++++++++--
 2 files changed, 32 insertions(+), 18 deletions(-)

diff --git a/llama_stack/providers/adapters/inference/together/together.py b/llama_stack/providers/adapters/inference/together/together.py
index 2ee90d8e3..73e0edc4e 100644
--- a/llama_stack/providers/adapters/inference/together/together.py
+++ b/llama_stack/providers/adapters/inference/together/together.py
@@ -41,8 +41,8 @@ class TogetherInferenceAdapter(
             self, stack_to_provider_models_map=TOGETHER_SUPPORTED_MODELS
         )
         self.config = config
-        tokenizer = Tokenizer.get_instance()
-        self.formatter = ChatFormat(tokenizer)
+        self.tokenizer = Tokenizer.get_instance()
+        self.formatter = ChatFormat(self.tokenizer)
 
     @property
     def client(self) -> Together:
@@ -124,27 +124,27 @@ class TogetherInferenceAdapter(
         options = self.get_together_chat_options(request)
         together_model = self.map_to_provider_model(request.model)
         messages = augment_messages_for_tools(request)
+        model_input = self.formatter.encode_dialog_prompt(messages)
+        prompt = self.tokenizer.decode(model_input.tokens)
 
         if not request.stream:
             # TODO: might need to add back an async here
-            r = client.chat.completions.create(
+            r = client.completions.create(
                 model=together_model,
-                messages=self._messages_to_together_messages(messages),
+                prompt=prompt,
                 stream=False,
                 **options,
             )
             stop_reason = None
-            if r.choices[0].finish_reason:
-                if (
-                    r.choices[0].finish_reason == "stop"
-                    or r.choices[0].finish_reason == "eos"
-                ):
+            choice = r.choices[0]
+            if choice.finish_reason:
+                if choice.finish_reason in ["stop", "eos"]:
                     stop_reason = StopReason.end_of_turn
-                elif r.choices[0].finish_reason == "length":
+                elif choice.finish_reason == "length":
                     stop_reason = StopReason.out_of_tokens
 
             completion_message = self.formatter.decode_assistant_message_from_content(
-                r.choices[0].message.content, stop_reason
+                choice.text, stop_reason
             )
             yield ChatCompletionResponse(
                 completion_message=completion_message,
@@ -162,20 +162,21 @@ class TogetherInferenceAdapter(
             ipython = False
             stop_reason = None
 
-            for chunk in client.chat.completions.create(
+            for chunk in client.completions.create(
                 model=together_model,
-                messages=self._messages_to_together_messages(messages),
+                prompt=prompt,
                 stream=True,
                 **options,
             ):
-                if finish_reason := chunk.choices[0].finish_reason:
+                choice = chunk.choices[0]
+                if finish_reason := choice.finish_reason:
                     if stop_reason is None and finish_reason in ["stop", "eos"]:
                         stop_reason = StopReason.end_of_turn
                     elif stop_reason is None and finish_reason == "length":
                         stop_reason = StopReason.out_of_tokens
                     break
 
-                text = chunk.choices[0].delta.content
+                text = choice.delta.content
                 if text is None:
                     continue
 
diff --git a/llama_stack/providers/tests/inference/test_inference.py b/llama_stack/providers/tests/inference/test_inference.py
index 61989b691..794cbaa2b 100644
--- a/llama_stack/providers/tests/inference/test_inference.py
+++ b/llama_stack/providers/tests/inference/test_inference.py
@@ -5,6 +5,7 @@
 # the root directory of this source tree.
 
 import itertools
+import json
 import os
 from datetime import datetime
 
@@ -17,6 +18,7 @@ from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.distribution.datatypes import *  # noqa: F403
 
 from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
+from llama_stack.distribution.request_headers import set_request_provider_data
 from llama_stack.distribution.resolver import resolve_impls_with_routing
 
 
@@ -60,9 +62,10 @@ async def stack_impls(model):
         provider = providers_by_id[provider_id]
     else:
         provider = list(providers_by_id.values())[0]
-        print(f"No provider ID specified, picking first {provider['provider_id']}")
+        provider_id = provider["provider_id"]
+        print(f"No provider ID specified, picking first `{provider_id}`")
 
-    config_dict = dict(
+    run_config = dict(
         built_at=datetime.now(),
         image_name="test-fixture",
         apis=[
@@ -84,8 +87,17 @@ async def stack_impls(model):
         shields=[],
         memory_banks=[],
     )
-    run_config = parse_and_maybe_upgrade_config(config_dict)
+    run_config = parse_and_maybe_upgrade_config(run_config)
     impls = await resolve_impls_with_routing(run_config)
+
+    # may need something cleaner here
+    if "provider_data" in config_dict:
+        provider_data = config_dict["provider_data"].get(provider_id, {})
+        if provider_data:
+            set_request_provider_data(
+                {"X-LlamaStack-ProviderData": json.dumps(provider_data)}
+            )
+
     return impls
 
 
@@ -97,6 +109,7 @@ async def stack_impls(model):
         {"model": Llama_8B},
         {"model": Llama_3B},
     ],
+    ids=lambda d: d["model"],
 )
 async def inference_settings(request):
     model = request.param["model"]
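
Editor's note (not part of the commit above): the patch moves the Together adapter off the chat-completions endpoint and onto the raw completions endpoint. The adapter now renders the Llama 3 chat template itself via ChatFormat/Tokenizer, sends the resulting string as a plain prompt, and parses the returned text back into an assistant message. A minimal standalone sketch of that round trip follows, assuming the `llama_models` and `together` packages; the import paths, model id, TOGETHER_API_KEY variable, and example message are illustrative assumptions, not values taken from the diff.

# Sketch of the new request path (assumptions: llama_models/together package
# layout as of late 2024; model id and env var are placeholders).
import os

from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import UserMessage
from llama_models.llama3.api.tokenizer import Tokenizer
from together import Together

tokenizer = Tokenizer.get_instance()
formatter = ChatFormat(tokenizer)

# Render the dialog with the Llama 3 chat template, then decode the tokens
# back into a plain string suitable for the raw completions endpoint.
messages = [UserMessage(content="What is the capital of France?")]
model_input = formatter.encode_dialog_prompt(messages)
prompt = tokenizer.decode(model_input.tokens)

client = Together(api_key=os.environ["TOGETHER_API_KEY"])
response = client.completions.create(
    model="meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",  # illustrative model id
    prompt=prompt,
    stream=False,
    max_tokens=64,
)

# The adapter hands this raw text to decode_assistant_message_from_content();
# here we simply print it.
print(response.choices[0].text)

Driving the prompt template from the stack's own formatter, rather than relying on the provider's chat endpoint, presumably keeps the rendered prompt identical to what the tool-augmented messages expect, so decode_assistant_message_from_content() can recover tool calls from the raw completion text.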