mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-17 13:02:36 +00:00
Get the Fireworks and Together providers working
This commit is contained in:
parent
25d8ab0e14
commit
8de4cee373
8 changed files with 205 additions and 86 deletions
|
|
@ -105,9 +105,8 @@ class InferenceRouter(Inference):
|
|||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
model = await self.routing_table.get_model(model_id)
|
||||
params = dict(
|
||||
model_id=model.provider_resource_id,
|
||||
model_id=model_id,
|
||||
messages=messages,
|
||||
sampling_params=sampling_params,
|
||||
tools=tools or [],
|
||||
|
|
@ -132,10 +131,9 @@ class InferenceRouter(Inference):
|
|||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
model = await self.routing_table.get_model(model_id)
|
||||
provider = self.routing_table.get_provider_impl(model_id)
|
||||
params = dict(
|
||||
model_id=model.provider_resource_id,
|
||||
model_id=model_id,
|
||||
content=content,
|
||||
sampling_params=sampling_params,
|
||||
response_format=response_format,
|
||||
|
|
@ -152,9 +150,8 @@ class InferenceRouter(Inference):
|
|||
model_id: str,
|
||||
contents: List[InterleavedTextMedia],
|
||||
) -> EmbeddingsResponse:
|
||||
model = await self.routing_table.get_model(model_id)
|
||||
return await self.routing_table.get_provider_impl(model_id).embeddings(
|
||||
model_id=model.provider_resource_id,
|
||||
model_id=model_id,
|
||||
contents=contents,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -28,7 +28,9 @@ def get_impl_api(p: Any) -> Api:
|
|||
return p.__provider_spec__.api
|
||||
|
||||
|
||||
async def register_object_with_provider(obj: RoutableObject, p: Any) -> None:
|
||||
# TODO: this should return the registered object for all APIs
|
||||
async def register_object_with_provider(obj: RoutableObject, p: Any) -> RoutableObject:
|
||||
|
||||
api = get_impl_api(p)
|
||||
|
||||
if obj.provider_id == "remote":
|
||||
|
|
@ -42,7 +44,7 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> None:
|
|||
obj.provider_id = ""
|
||||
|
||||
if api == Api.inference:
|
||||
await p.register_model(obj)
|
||||
return await p.register_model(obj)
|
||||
elif api == Api.safety:
|
||||
await p.register_shield(obj)
|
||||
elif api == Api.memory:
|
||||
|
|
@ -167,7 +169,9 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
assert len(objects) == 1
|
||||
return objects[0]
|
||||
|
||||
async def register_object(self, obj: RoutableObjectWithProvider):
|
||||
async def register_object(
|
||||
self, obj: RoutableObjectWithProvider
|
||||
) -> RoutableObjectWithProvider:
|
||||
# Get existing objects from registry
|
||||
existing_objects = await self.dist_registry.get(obj.type, obj.identifier)
|
||||
|
||||
|
|
@ -177,7 +181,7 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
print(
|
||||
f"`{obj.identifier}` already registered with `{existing_obj.provider_id}`"
|
||||
)
|
||||
return
|
||||
return existing_obj
|
||||
|
||||
# if provider_id is not specified, pick an arbitrary one from existing entries
|
||||
if not obj.provider_id and len(self.impls_by_provider_id) > 0:
|
||||
|
|
@ -188,8 +192,15 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
|
||||
p = self.impls_by_provider_id[obj.provider_id]
|
||||
|
||||
await register_object_with_provider(obj, p)
|
||||
await self.dist_registry.register(obj)
|
||||
registered_obj = await register_object_with_provider(obj, p)
|
||||
# TODO: This needs to be fixed for all APIs once they return the registered object
|
||||
if obj.type == ResourceType.model.value:
|
||||
await self.dist_registry.register(registered_obj)
|
||||
return registered_obj
|
||||
|
||||
else:
|
||||
await self.dist_registry.register(obj)
|
||||
return obj
|
||||
|
||||
async def get_all_with_type(self, type: str) -> List[RoutableObjectWithProvider]:
|
||||
objs = await self.dist_registry.get_all()
|
||||
|
|
@ -228,8 +239,8 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
|||
provider_id=provider_id,
|
||||
metadata=metadata,
|
||||
)
|
||||
await self.register_object(model)
|
||||
return model
|
||||
registered_model = await self.register_object(model)
|
||||
return registered_model
|
||||
|
||||
|
||||
class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue