Revert "add model type to APIs" (#605)

Reverts meta-llama/llama-stack#588
This commit is contained in:
Dinesh Yeduguru 2024-12-11 10:17:54 -08:00 committed by GitHub
parent 8e33db6015
commit 47b2dc8ae3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 13 additions and 77 deletions

View file

@ -88,10 +88,9 @@ class InferenceRouter(Inference):
provider_model_id: Optional[str] = None,
provider_id: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
model_type: Optional[ModelType] = None,
) -> None:
await self.routing_table.register_model(
model_id, provider_model_id, provider_id, metadata, model_type
model_id, provider_model_id, provider_id, metadata
)
async def chat_completion(
@ -106,13 +105,6 @@ class InferenceRouter(Inference):
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
model = await self.routing_table.get_model(model_id)
if model is None:
raise ValueError(f"Model '{model_id}' not found")
if model.model_type == ModelType.embedding_model:
raise ValueError(
f"Model '{model_id}' is an embedding model and does not support chat completions"
)
params = dict(
model_id=model_id,
messages=messages,
@ -139,13 +131,6 @@ class InferenceRouter(Inference):
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
model = await self.routing_table.get_model(model_id)
if model is None:
raise ValueError(f"Model '{model_id}' not found")
if model.model_type == ModelType.embedding_model:
raise ValueError(
f"Model '{model_id}' is an embedding model and does not support chat completions"
)
provider = self.routing_table.get_provider_impl(model_id)
params = dict(
model_id=model_id,
@ -165,13 +150,6 @@ class InferenceRouter(Inference):
model_id: str,
contents: List[InterleavedTextMedia],
) -> EmbeddingsResponse:
model = await self.routing_table.get_model(model_id)
if model is None:
raise ValueError(f"Model '{model_id}' not found")
if model.model_type == ModelType.llm:
raise ValueError(
f"Model '{model_id}' is an LLM model and does not support embeddings"
)
return await self.routing_table.get_provider_impl(model_id).embeddings(
model_id=model_id,
contents=contents,