Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-07-29 07:14:20 +00:00

Make Together inference work using the raw completions API

parent 3ae2b712e8
commit bbd3a02615

2 changed files with 33 additions and 18 deletions
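The gist of the change: rather than translating Llama Stack messages into Together's chat schema and calling client.chat.completions.create, the adapter now renders the dialog into a raw Llama prompt with the reference ChatFormat/Tokenizer and sends that string to the plain completions endpoint. A minimal sketch of the same flow outside the adapter (the import paths for the reference tokenizer, the model id, and max_tokens are assumptions for illustration, not taken from this diff):

# Sketch: render a chat dialog to a raw Llama prompt and query Together's
# raw completions endpoint, mirroring what the adapter does in this commit.
from llama_models.llama3.api.chat_format import ChatFormat  # assumed import path
from llama_models.llama3.api.datatypes import UserMessage   # assumed import path
from llama_models.llama3.api.tokenizer import Tokenizer     # assumed import path
from together import Together

tokenizer = Tokenizer.get_instance()
formatter = ChatFormat(tokenizer)

# Encode the dialog with the Llama chat template, then decode the token ids
# back into the literal prompt string the completions endpoint expects.
messages = [UserMessage(content="What is the capital of France?")]
model_input = formatter.encode_dialog_prompt(messages)
prompt = tokenizer.decode(model_input.tokens)

client = Together()  # expects TOGETHER_API_KEY in the environment
r = client.completions.create(
    model="meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",  # hypothetical model id
    prompt=prompt,
    stream=False,
    max_tokens=64,
)
print(r.choices[0].text)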

@@ -41,8 +41,8 @@ class TogetherInferenceAdapter(
             self, stack_to_provider_models_map=TOGETHER_SUPPORTED_MODELS
         )
         self.config = config
-        tokenizer = Tokenizer.get_instance()
-        self.formatter = ChatFormat(tokenizer)
+        self.tokenizer = Tokenizer.get_instance()
+        self.formatter = ChatFormat(self.tokenizer)

     @property
     def client(self) -> Together:

@@ -124,27 +124,28 @@ class TogetherInferenceAdapter(
         options = self.get_together_chat_options(request)
         together_model = self.map_to_provider_model(request.model)
         messages = augment_messages_for_tools(request)
+        model_input = self.formatter.encode_dialog_prompt(messages)
+        prompt = self.tokenizer.decode(model_input.tokens)

         if not request.stream:
             # TODO: might need to add back an async here
-            r = client.chat.completions.create(
+            r = client.completions.create(
                 model=together_model,
-                messages=self._messages_to_together_messages(messages),
+                prompt=prompt,
                 stream=False,
                 **options,
             )
             stop_reason = None
-            if r.choices[0].finish_reason:
-                if (
-                    r.choices[0].finish_reason == "stop"
-                    or r.choices[0].finish_reason == "eos"
-                ):
+            choice = r.choices[0]
+            if choice.finish_reason:
+                if choice.finish_reason in ["stop", "eos"]:
                     stop_reason = StopReason.end_of_turn
-                elif r.choices[0].finish_reason == "length":
+                    stop_reason = StopReason.end_of_turn
+                elif choice.finish_reason == "length":
                     stop_reason = StopReason.out_of_tokens

             completion_message = self.formatter.decode_assistant_message_from_content(
-                r.choices[0].message.content, stop_reason
+                choice.text, stop_reason
             )
             yield ChatCompletionResponse(
                 completion_message=completion_message,
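In the non-streaming branch, the raw completion is folded back into a ChatCompletionResponse by mapping Together's finish_reason onto StopReason and running choice.text through decode_assistant_message_from_content. A hedged sketch of that mapping factored into a helper (the helper name and the import path are mine, not part of the adapter):

from typing import Optional

from llama_models.llama3.api.datatypes import StopReason  # assumed import path


def stop_reason_from_finish_reason(finish_reason: Optional[str]) -> Optional[StopReason]:
    # Same rules the adapter applies above: "stop"/"eos" end the turn,
    # "length" means the model ran out of tokens.
    if finish_reason in ("stop", "eos"):
        return StopReason.end_of_turn
    if finish_reason == "length":
        return StopReason.out_of_tokens
    return None


# Usage against a non-streaming response r, mirroring the hunk above:
#     choice = r.choices[0]
#     stop_reason = stop_reason_from_finish_reason(choice.finish_reason)
#     completion_message = formatter.decode_assistant_message_from_content(
#         choice.text, stop_reason
#     )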

@@ -162,20 +163,21 @@ class TogetherInferenceAdapter(
             ipython = False
             stop_reason = None

-            for chunk in client.chat.completions.create(
+            for chunk in client.completions.create(
                 model=together_model,
-                messages=self._messages_to_together_messages(messages),
+                prompt=prompt,
                 stream=True,
                 **options,
             ):
-                if finish_reason := chunk.choices[0].finish_reason:
+                choice = chunk.choices[0]
+                if finish_reason := choice.finish_reason:
                     if stop_reason is None and finish_reason in ["stop", "eos"]:
                         stop_reason = StopReason.end_of_turn
                     elif stop_reason is None and finish_reason == "length":
                         stop_reason = StopReason.out_of_tokens
                     break

-                text = chunk.choices[0].delta.content
+                text = choice.delta.content
                 if text is None:
                     continue

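The streaming branch follows the same pattern: each chunk from client.completions.create(..., stream=True) carries its text delta in choice.delta.content, and the loop bails out once a finish_reason appears. A minimal standalone sketch of consuming such a stream (client setup, model id, and prompt are placeholders; accumulating into buffer stands in for the event handling the adapter does further down):

from together import Together

client = Together()  # expects TOGETHER_API_KEY in the environment
buffer = ""
final_finish_reason = None

for chunk in client.completions.create(
    model="meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",  # hypothetical model id
    prompt="<|begin_of_text|>...",  # a fully templated Llama prompt, built as above
    stream=True,
    max_tokens=64,
):
    choice = chunk.choices[0]
    if choice.finish_reason:  # "stop"/"eos" or "length", as mapped above
        final_finish_reason = choice.finish_reason
        break
    if choice.delta.content is None:
        continue
    buffer += choice.delta.content

print(buffer, final_finish_reason)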

@@ -5,6 +5,7 @@
 # the root directory of this source tree.

 import itertools
+import json
 import os
 from datetime import datetime


@@ -17,6 +18,7 @@ from llama_stack.apis.inference import *  # noqa: F403

 from llama_stack.distribution.datatypes import *  # noqa: F403
 from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
+from llama_stack.distribution.request_headers import set_request_provider_data
 from llama_stack.distribution.resolver import resolve_impls_with_routing



@@ -60,9 +62,10 @@ async def stack_impls(model):
         provider = providers_by_id[provider_id]
     else:
         provider = list(providers_by_id.values())[0]
-        print(f"No provider ID specified, picking first {provider['provider_id']}")
+        provider_id = provider["provider_id"]
+        print(f"No provider ID specified, picking first `{provider_id}`")

-    config_dict = dict(
+    run_config = dict(
         built_at=datetime.now(),
         image_name="test-fixture",
         apis=[

@@ -84,8 +87,17 @@ async def stack_impls(model):
         shields=[],
         memory_banks=[],
     )
-    run_config = parse_and_maybe_upgrade_config(config_dict)
+    run_config = parse_and_maybe_upgrade_config(run_config)
     impls = await resolve_impls_with_routing(run_config)
+
+    # may need something cleaner here
+    if "provider_data" in config_dict:
+        provider_data = config_dict["provider_data"].get(provider_id, {})
+        if provider_data:
+            set_request_provider_data(
+                {"X-LlamaStack-ProviderData": json.dumps(provider_data)}
+            )
+
     return impls


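On the test side, the new block forwards per-provider secrets the same way a client would: whatever the test config stores under provider_data for the selected provider is JSON-encoded into the X-LlamaStack-ProviderData header and registered via set_request_provider_data. A sketch of the shape this takes in isolation (the config structure and the together_api_key field are assumptions for illustration; only the header name and the call itself come from this diff):

import json

from llama_stack.distribution.request_headers import set_request_provider_data

# Hypothetical test config: per-provider secrets keyed by provider_id.
config_dict = {
    "provider_data": {
        "together": {"together_api_key": "sk-placeholder"},  # illustrative field
    }
}
provider_id = "together"

provider_data = config_dict["provider_data"].get(provider_id, {})
if provider_data:
    # JSON-encode the secrets into the same header a client would send to a
    # running Llama Stack server, so the provider can read them per request.
    set_request_provider_data(
        {"X-LlamaStack-ProviderData": json.dumps(provider_data)}
    )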

@@ -97,6 +109,7 @@ async def stack_impls(model):
         {"model": Llama_8B},
         {"model": Llama_3B},
     ],
+    ids=lambda d: d["model"],
 )
 async def inference_settings(request):
     model = request.param["model"]
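The only functional change in the last hunk is ids=lambda d: d["model"], which makes pytest label each parametrized fixture with the model name instead of an opaque index. A self-contained illustration of the same pattern (the fixture body, the test, and the model strings are stand-ins, not from this file):

import pytest

PARAMS = [
    {"model": "Llama3.1-8B-Instruct"},  # placeholder model descriptors
    {"model": "Llama3.2-3B-Instruct"},
]


@pytest.fixture(params=PARAMS, ids=lambda d: d["model"])
def inference_settings(request):
    return request.param


def test_model_is_set(inference_settings):
    # Collected as test_model_is_set[Llama3.1-8B-Instruct] and
    # test_model_is_set[Llama3.2-3B-Instruct] instead of
    # [inference_settings0] / [inference_settings1].
    assert inference_settings["model"]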