Make Together inference work using the raw completions API

This commit is contained in:
Ashwin Bharambe 2024-10-07 17:28:19 -07:00 committed by Ashwin Bharambe
parent 3ae2b712e8
commit bbd3a02615
2 changed files with 33 additions and 18 deletions

View file

@ -5,6 +5,7 @@
# the root directory of this source tree.
import itertools
import json
import os
from datetime import datetime
@ -17,6 +18,7 @@ from llama_stack.apis.inference import * # noqa: F403
from llama_stack.distribution.datatypes import * # noqa: F403
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
from llama_stack.distribution.request_headers import set_request_provider_data
from llama_stack.distribution.resolver import resolve_impls_with_routing
@ -60,9 +62,10 @@ async def stack_impls(model):
provider = providers_by_id[provider_id]
else:
provider = list(providers_by_id.values())[0]
print(f"No provider ID specified, picking first {provider['provider_id']}")
provider_id = provider["provider_id"]
print(f"No provider ID specified, picking first `{provider_id}`")
config_dict = dict(
run_config = dict(
built_at=datetime.now(),
image_name="test-fixture",
apis=[
@ -84,8 +87,17 @@ async def stack_impls(model):
shields=[],
memory_banks=[],
)
run_config = parse_and_maybe_upgrade_config(config_dict)
run_config = parse_and_maybe_upgrade_config(run_config)
impls = await resolve_impls_with_routing(run_config)
# may need something cleaner here
if "provider_data" in config_dict:
provider_data = config_dict["provider_data"].get(provider_id, {})
if provider_data:
set_request_provider_data(
{"X-LlamaStack-ProviderData": json.dumps(provider_data)}
)
return impls
@ -97,6 +109,7 @@ async def stack_impls(model):
{"model": Llama_8B},
{"model": Llama_3B},
],
ids=lambda d: d["model"],
)
async def inference_settings(request):
model = request.param["model"]