mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-18 15:27:16 +00:00
Kill the notion of a "remote" / "passthrough" provider
This commit is contained in:
parent
59a65e34d3
commit
743da9690b
6 changed files with 95 additions and 87 deletions
|
@ -33,28 +33,83 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> Routable
|
|||
|
||||
api = get_impl_api(p)
|
||||
|
||||
if obj.provider_id == "remote":
|
||||
# TODO: this is broken right now because we use the generic
|
||||
# { identifier, provider_id, provider_resource_id } tuple here
|
||||
# but the APIs expect things like ModelInput, ShieldInput, etc.
|
||||
|
||||
# if this is just a passthrough, we want to let the remote
|
||||
# end actually do the registration with the correct provider
|
||||
obj = obj.model_copy(deep=True)
|
||||
obj.provider_id = ""
|
||||
is_remote = obj.provider_id == "remote"
|
||||
if is_remote:
|
||||
# TODO: these are incomplete fixes since (a) they are kind of adhoc and likely to break
|
||||
# and (b) MemoryBankInput is missing BankParams
|
||||
if isinstance(obj, Model):
|
||||
obj = ModelInput(
|
||||
model_id=obj.identifier,
|
||||
metadata=obj.metadata,
|
||||
provider_model_id=obj.provider_resource_id,
|
||||
)
|
||||
elif isinstance(obj, Shield):
|
||||
obj = ShieldInput(
|
||||
shield_id=obj.identifier,
|
||||
params=obj.params,
|
||||
provider_shield_id=obj.provider_resource_id,
|
||||
)
|
||||
elif isinstance(obj, MemoryBank):
|
||||
# need to calculate params here
|
||||
obj = MemoryBankInput(
|
||||
memory_bank_id=obj.identifier,
|
||||
provider_memory_bank_id=obj.provider_resource_id,
|
||||
)
|
||||
elif isinstance(obj, ScoringFn):
|
||||
obj = ScoringFnInput(
|
||||
scoring_fn_id=obj.identifier,
|
||||
provider_scoring_fn_id=obj.provider_resource_id,
|
||||
description=obj.description,
|
||||
metadata=obj.metadata,
|
||||
return_type=obj.return_type,
|
||||
params=obj.params,
|
||||
)
|
||||
elif isinstance(obj, EvalTask):
|
||||
obj = EvalTaskInput(
|
||||
eval_task_id=obj.identifier,
|
||||
provider_eval_task_id=obj.provider_resource_id,
|
||||
dataset_id=obj.dataset_id,
|
||||
scoring_function_id=obj.scoring_functions,
|
||||
metadata=obj.metadata,
|
||||
)
|
||||
elif isinstance(obj, Dataset):
|
||||
obj = DatasetInput(
|
||||
dataset_id=obj.identifier,
|
||||
provider_dataset_id=obj.provider_resource_id,
|
||||
schema=obj.schema,
|
||||
url=obj.url,
|
||||
metadata=obj.metadata,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown object type {type(obj)}")
|
||||
|
||||
if api == Api.inference:
|
||||
return await p.register_model(obj)
|
||||
elif api == Api.safety:
|
||||
await p.register_shield(obj)
|
||||
if is_remote:
|
||||
await p.register_shield(**obj.model_dump())
|
||||
else:
|
||||
await p.register_shield(obj)
|
||||
elif api == Api.memory:
|
||||
await p.register_memory_bank(obj)
|
||||
if is_remote:
|
||||
await p.register_memory_bank(**obj.model_dump())
|
||||
else:
|
||||
await p.register_memory_bank(obj)
|
||||
elif api == Api.datasetio:
|
||||
await p.register_dataset(obj)
|
||||
if is_remote:
|
||||
await p.register_dataset(**obj.model_dump())
|
||||
else:
|
||||
await p.register_dataset(obj)
|
||||
elif api == Api.scoring:
|
||||
await p.register_scoring_function(obj)
|
||||
if is_remote:
|
||||
await p.register_scoring_function(**obj.model_dump())
|
||||
else:
|
||||
await p.register_scoring_function(obj)
|
||||
elif api == Api.eval:
|
||||
await p.register_eval_task(obj)
|
||||
if is_remote:
|
||||
await p.register_eval_task(**obj.model_dump())
|
||||
else:
|
||||
await p.register_eval_task(obj)
|
||||
else:
|
||||
raise ValueError(f"Unknown API {api} for registering object with provider")
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue