Mirror of https://github.com/meta-llama/llama-stack.git
This PR enables routing of fully qualified model IDs of the form `provider_id/model_id` even when the models are not registered with the Stack. Here's the situation: assume a remote inference provider that works only when users supply their own API keys via the `X-LlamaStack-Provider-Data` header. By definition, we cannot list that provider's models and hence cannot populate our routing registry. But because we now _require_ a provider ID in model IDs, we can identify which provider to route to and let that provider decide. Note that we still try the registry lookup first, since it may hold a pre-registered alias; we just no longer fail outright when the lookup misses.

Also updated the inference router so that responses carry the _exact_ model ID that the request had.

## Test Plan

Added an integration test.

Closes #3929

---

This is an automatic backport of pull request #3928 done by [Mergify](https://mergify.com).

---------

Co-authored-by: Ashwin Bharambe <ashwin.bharambe@gmail.com>
Co-authored-by: ehhuang <ehhuang@users.noreply.github.com>
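Conceptually, the lookup-then-fallback routing described above works like the sketch below. This is a minimal illustration, not the Stack's actual router code; `registry` and `providers` are hypothetical stand-ins for the model registry and the provider table.

```python
# Minimal sketch of lookup-then-fallback routing (hypothetical data structures).
providers = {"openai": object(), "together": object()}  # provider_id -> provider impl
registry = {"my-alias": ("openai", "gpt-4")}            # pre-registered models only


def route(model_id: str) -> tuple[str, str]:
    # Try the registry first: it may hold a pre-registered alias.
    if model_id in registry:
        return registry[model_id]
    # Not registered: a fully qualified "provider_id/model_id" still tells us
    # which provider should handle the request; let that provider decide.
    if "/" in model_id:
        provider_id, provider_model = model_id.split("/", maxsplit=1)
        if provider_id in providers:
            return provider_id, provider_model
    raise ValueError(f"unknown model: {model_id}")


assert route("my-alias") == ("openai", "gpt-4")
assert route("together/meta-llama/Llama-3.3-70B-Instruct") == (
    "together",
    "meta-llama/Llama-3.3-70B-Instruct",
)
```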
Commit 641d5144be (parent a6c3a9cadf) — 6 changed files with 214 additions and 55 deletions
```diff
@@ -46,8 +46,7 @@ class SentenceTransformerEmbeddingMixin:
             raise ValueError("Empty list not supported")

         # Get the model and generate embeddings
-        model_obj = await self.model_store.get_model(params.model)
-        embedding_model = await self._load_sentence_transformer_model(model_obj.provider_resource_id)
+        embedding_model = await self._load_sentence_transformer_model(params.model)
         embeddings = await asyncio.to_thread(embedding_model.encode, input_list, show_progress_bar=False)

         # Convert embeddings to the requested format
```
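This first hunk removes the hard dependency on a registry hit for embeddings: the requested model name is handed straight to sentence-transformers instead of being resolved to a `provider_resource_id` first. A rough stand-in for what the loader does with that name, assuming the usual `sentence-transformers` API (`_load_sentence_transformer_model` itself is not shown in this diff):

```python
# Hypothetical stand-in for self._load_sentence_transformer_model: the point
# is that the requested name is usable directly, no registry entry needed.
from sentence_transformers import SentenceTransformer


def load(model: str) -> SentenceTransformer:
    return SentenceTransformer(model)


embedding_model = load("all-MiniLM-L6-v2")
vectors = embedding_model.encode(["hello", "world"], show_progress_bar=False)
```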
```diff
@@ -201,8 +201,11 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
         :param model: The registered model name/identifier
         :return: The provider-specific model ID (e.g., "gpt-4")
         """
-        # Look up the registered model to get the provider-specific model ID
         # self.model_store is injected by the distribution system at runtime
+        if not await self.model_store.has_model(model):  # type: ignore[attr-defined]
+            return model
+
+        # Look up the registered model to get the provider-specific model ID
         model_obj: Model = await self.model_store.get_model(model)  # type: ignore[attr-defined]
         # provider_resource_id is str | None, but we expect it to be str for OpenAI calls
         if model_obj.provider_resource_id is None:
```
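For context, a request that exercises the new path might look like the following. The header name and the exact-model-echo behavior come from the PR description; the endpoint path, port, payload shape, and the `openai_api_key` field are illustrative assumptions, not taken from this diff.

```python
import json

import requests

# Hypothetical request against a running Llama Stack server. The fully
# qualified model ID ("provider_id/model_id") need not be pre-registered;
# the caller's API key travels in the X-LlamaStack-Provider-Data header.
response = requests.post(
    "http://localhost:8321/v1/openai/v1/chat/completions",  # illustrative path
    headers={"X-LlamaStack-Provider-Data": json.dumps({"openai_api_key": "sk-..."})},
    json={
        "model": "openai/gpt-4o-mini",
        "messages": [{"role": "user", "content": "Hello!"}],
    },
)
print(response.json()["model"])  # per this PR, echoes the exact requested model ID
```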