mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-18 10:39:48 +00:00
Revert "Revert "add model type to APIs" (#605)"
This reverts commit 47b2dc8ae3.
This commit is contained in:
parent
47b2dc8ae3
commit
310c15bada
6 changed files with 77 additions and 13 deletions
|
|
@@ -88,9 +88,10 @@ class InferenceRouter(Inference):
|
|||
provider_model_id: Optional[str] = None,
|
||||
provider_id: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
model_type: Optional[ModelType] = None,
|
||||
) -> None:
|
||||
await self.routing_table.register_model(
|
||||
model_id, provider_model_id, provider_id, metadata
|
||||
model_id, provider_model_id, provider_id, metadata, model_type
|
||||
)
|
||||
|
||||
async def chat_completion(
|
||||
|
|
@@ -105,6 +106,13 @@ class InferenceRouter(Inference):
|
|||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
model = await self.routing_table.get_model(model_id)
|
||||
if model is None:
|
||||
raise ValueError(f"Model '{model_id}' not found")
|
||||
if model.model_type == ModelType.embedding_model:
|
||||
raise ValueError(
|
||||
f"Model '{model_id}' is an embedding model and does not support chat completions"
|
||||
)
|
||||
params = dict(
|
||||
model_id=model_id,
|
||||
messages=messages,
|
||||
|
|
@@ -131,6 +139,13 @@ class InferenceRouter(Inference):
|
|||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
model = await self.routing_table.get_model(model_id)
|
||||
if model is None:
|
||||
raise ValueError(f"Model '{model_id}' not found")
|
||||
if model.model_type == ModelType.embedding_model:
|
||||
raise ValueError(
|
||||
f"Model '{model_id}' is an embedding model and does not support chat completions"
|
||||
)
|
||||
provider = self.routing_table.get_provider_impl(model_id)
|
||||
params = dict(
|
||||
model_id=model_id,
|
||||
|
|
@@ -150,6 +165,13 @@ class InferenceRouter(Inference):
|
|||
model_id: str,
|
||||
contents: List[InterleavedTextMedia],
|
||||
) -> EmbeddingsResponse:
|
||||
model = await self.routing_table.get_model(model_id)
|
||||
if model is None:
|
||||
raise ValueError(f"Model '{model_id}' not found")
|
||||
if model.model_type == ModelType.llm:
|
||||
raise ValueError(
|
||||
f"Model '{model_id}' is an LLM model and does not support embeddings"
|
||||
)
|
||||
return await self.routing_table.get_provider_impl(model_id).embeddings(
|
||||
model_id=model_id,
|
||||
contents=contents,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue