mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-08 19:10:56 +00:00
Make Together inference work using the raw completions API
This commit is contained in:
parent
3ae2b712e8
commit
bbd3a02615
2 changed files with 33 additions and 18 deletions
|
|
@ -5,6 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import itertools
|
||||
import json
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
|
|
@ -17,6 +18,7 @@ from llama_stack.apis.inference import * # noqa: F403
|
|||
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
|
||||
from llama_stack.distribution.request_headers import set_request_provider_data
|
||||
from llama_stack.distribution.resolver import resolve_impls_with_routing
|
||||
|
||||
|
||||
|
|
@ -60,9 +62,10 @@ async def stack_impls(model):
|
|||
provider = providers_by_id[provider_id]
|
||||
else:
|
||||
provider = list(providers_by_id.values())[0]
|
||||
print(f"No provider ID specified, picking first {provider['provider_id']}")
|
||||
provider_id = provider["provider_id"]
|
||||
print(f"No provider ID specified, picking first `{provider_id}`")
|
||||
|
||||
config_dict = dict(
|
||||
run_config = dict(
|
||||
built_at=datetime.now(),
|
||||
image_name="test-fixture",
|
||||
apis=[
|
||||
|
|
@ -84,8 +87,17 @@ async def stack_impls(model):
|
|||
shields=[],
|
||||
memory_banks=[],
|
||||
)
|
||||
run_config = parse_and_maybe_upgrade_config(config_dict)
|
||||
run_config = parse_and_maybe_upgrade_config(run_config)
|
||||
impls = await resolve_impls_with_routing(run_config)
|
||||
|
||||
# may need something cleaner here
|
||||
if "provider_data" in config_dict:
|
||||
provider_data = config_dict["provider_data"].get(provider_id, {})
|
||||
if provider_data:
|
||||
set_request_provider_data(
|
||||
{"X-LlamaStack-ProviderData": json.dumps(provider_data)}
|
||||
)
|
||||
|
||||
return impls
|
||||
|
||||
|
||||
|
|
@ -97,6 +109,7 @@ async def stack_impls(model):
|
|||
{"model": Llama_8B},
|
||||
{"model": Llama_3B},
|
||||
],
|
||||
ids=lambda d: d["model"],
|
||||
)
|
||||
async def inference_settings(request):
|
||||
model = request.param["model"]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue