diff --git a/llama_stack/providers/remote/inference/openai/openai.py b/llama_stack/providers/remote/inference/openai/openai.py
index 9a1ec7ee0..c3c25edd3 100644
--- a/llama_stack/providers/remote/inference/openai/openai.py
+++ b/llama_stack/providers/remote/inference/openai/openai.py
@@ -92,8 +92,11 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
         if prompt_logprobs is not None:
             logging.warning("prompt_logprobs is not supported by the OpenAI API. Ignoring.")
 
+        model_id = (await self.model_store.get_model(model)).provider_resource_id
+        if model_id.startswith("openai/"):
+            model_id = model_id[len("openai/") :]
         params = await prepare_openai_completion_params(
-            model=(await self.model_store.get_model(model)).provider_resource_id,
+            model=model_id,
             prompt=prompt,
             best_of=best_of,
             echo=echo,
@@ -139,8 +142,11 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
         top_p: float | None = None,
         user: str | None = None,
     ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
+        model_id = (await self.model_store.get_model(model)).provider_resource_id
+        if model_id.startswith("openai/"):
+            model_id = model_id[len("openai/") :]
         params = await prepare_openai_completion_params(
-            model=(await self.model_store.get_model(model)).provider_resource_id,
+            model=model_id,
             messages=messages,
             frequency_penalty=frequency_penalty,
             function_call=function_call,
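
For reference, a minimal standalone sketch of the prefix-stripping behavior this patch adds to both methods (the helper name and the example model IDs below are hypothetical, not part of the patch):

# Sketch of the "openai/" prefix stripping added above.
def strip_openai_prefix(model_id: str) -> str:
    # Same logic as the patch; on Python 3.9+ this is equivalent to
    # model_id.removeprefix("openai/").
    if model_id.startswith("openai/"):
        return model_id[len("openai/") :]
    return model_id

# Hypothetical provider_resource_id values, for illustration only.
assert strip_openai_prefix("openai/gpt-4o") == "gpt-4o"
assert strip_openai_prefix("gpt-4o") == "gpt-4o"  # already bare: unchanged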