test, rerecord

# What does this PR do?

Replaces the heuristic endpoint detection inside `_patched_inference_method` with an explicit `endpoint` argument: each patched client method now passes its concrete API path (e.g. `/v1/chat/completions`, `/api/generate`) instead of having the wrapper guess it from the method's repr or the client's private resource attributes.

## Test Plan
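The test plan was left empty; below is a minimal sketch of how the re-recorded path could be exercised, assuming `patch_inference_clients()` from the module changed in this diff has been applied and recording mode is active (the base URL and model id are illustrative, not from the PR):

```python
import asyncio

from openai import AsyncOpenAI


async def main() -> None:
    client = AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="test")

    # The patched AsyncChatCompletions.create now forwards the explicit
    # "/v1/chat/completions" endpoint to _patched_inference_method, so this
    # request/response pair is recorded (or replayed) under that path.
    resp = await client.chat.completions.create(
        model="llama3.2:3b",  # illustrative model id
        messages=[{"role": "user", "content": "ping"}],
    )
    print(resp.choices[0].message.content)


asyncio.run(main())
```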
Eric Huang 2025-08-01 15:23:16 -07:00
parent 140ee7d337
commit 204c2717ce
28 changed files with 4224 additions and 1003 deletions


@@ -217,55 +217,21 @@ class ResponseStorage:
         return cast(dict[str, Any], data)
 
 
-async def _patched_inference_method(original_method, self, client_type, method_name=None, *args, **kwargs):
+async def _patched_inference_method(original_method, self, client_type, endpoint, *args, **kwargs):
     global _current_mode, _current_storage
 
     if _current_mode == InferenceMode.LIVE or _current_storage is None:
         # Normal operation
         return await original_method(self, *args, **kwargs)
 
-    # Get base URL and endpoint based on client type
+    # Get base URL based on client type
     if client_type == "openai":
         base_url = str(self._client.base_url)
-
-        # Determine endpoint based on the method's module/class path
-        method_str = str(original_method)
-        if "chat.completions" in method_str:
-            endpoint = "/v1/chat/completions"
-        elif "embeddings" in method_str:
-            endpoint = "/v1/embeddings"
-        elif "completions" in method_str:
-            endpoint = "/v1/completions"
-        else:
-            # Fallback - try to guess from the self object
-            if hasattr(self, "_resource") and hasattr(self._resource, "_resource"):
-                resource_name = getattr(self._resource._resource, "_resource", "unknown")
-                if "chat" in str(resource_name):
-                    endpoint = "/v1/chat/completions"
-                elif "embeddings" in str(resource_name):
-                    endpoint = "/v1/embeddings"
-                else:
-                    endpoint = "/v1/completions"
-            else:
-                endpoint = "/v1/completions"
     elif client_type == "ollama":
         # Get base URL from the client (Ollama client uses host attribute)
         base_url = getattr(self, "host", "http://localhost:11434")
         if not base_url.startswith("http"):
             base_url = f"http://{base_url}"
-
-        # Determine endpoint based on method name
-        if method_name == "generate":
-            endpoint = "/api/generate"
-        elif method_name == "chat":
-            endpoint = "/api/chat"
-        elif method_name == "embed":
-            endpoint = "/api/embeddings"
-        elif method_name == "list":
-            endpoint = "/api/tags"
-        else:
-            endpoint = f"/api/{method_name}"
     else:
         raise ValueError(f"Unknown client type: {client_type}")
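The hunk above deletes the endpoint-guessing branches entirely; the mapping they encoded now lives at each patch site. Collected in one place for reference (this table is derived from the hunks in this commit; it is not a structure that exists in the code):

```python
# OpenAI client: patched method -> recorded endpoint
OPENAI_ENDPOINTS = {
    "chat_completions_create": "/v1/chat/completions",
    "completions_create": "/v1/completions",
    "embeddings_create": "/v1/embeddings",
}

# Ollama client: patched method -> recorded endpoint. Note the two mappings
# a naive f"/api/{method_name}" fallback would get wrong:
OLLAMA_ENDPOINTS = {
    "generate": "/api/generate",
    "chat": "/api/chat",
    "embed": "/api/embeddings",  # embed() records under /api/embeddings
    "ps": "/api/ps",
    "pull": "/api/pull",
    "list": "/api/tags",  # list() records under /api/tags
}
```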
@@ -366,14 +332,18 @@ def patch_inference_clients():
     # Create patched methods for OpenAI client
     async def patched_chat_completions_create(self, *args, **kwargs):
         return await _patched_inference_method(
-            _original_methods["chat_completions_create"], self, "openai", *args, **kwargs
+            _original_methods["chat_completions_create"], self, "openai", "/v1/chat/completions", *args, **kwargs
         )
 
     async def patched_completions_create(self, *args, **kwargs):
-        return await _patched_inference_method(_original_methods["completions_create"], self, "openai", *args, **kwargs)
+        return await _patched_inference_method(
+            _original_methods["completions_create"], self, "openai", "/v1/completions", *args, **kwargs
+        )
 
     async def patched_embeddings_create(self, *args, **kwargs):
-        return await _patched_inference_method(_original_methods["embeddings_create"], self, "openai", *args, **kwargs)
+        return await _patched_inference_method(
+            _original_methods["embeddings_create"], self, "openai", "/v1/embeddings", *args, **kwargs
+        )
 
     # Apply OpenAI patches
     AsyncChatCompletions.create = patched_chat_completions_create
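The three OpenAI wrappers above and the six Ollama wrappers below all repeat the same shape; if they keep multiplying, a small factory could generate them. A sketch of that pattern (the `make_patched` helper is hypothetical, not part of this PR):

```python
def make_patched(key: str, client_type: str, endpoint: str):
    # Builds a wrapper equivalent to the hand-written ones in this diff:
    # look up the saved original method and forward it together with the
    # explicit endpoint.
    async def patched(self, *args, **kwargs):
        return await _patched_inference_method(
            _original_methods[key], self, client_type, endpoint, *args, **kwargs
        )

    return patched


# Usage would mirror the existing assignments, e.g.:
# AsyncChatCompletions.create = make_patched(
#     "chat_completions_create", "openai", "/v1/chat/completions"
# )
```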
@@ -383,30 +353,32 @@ def patch_inference_clients():
     # Create patched methods for Ollama client
     async def patched_ollama_generate(self, *args, **kwargs):
         return await _patched_inference_method(
-            _original_methods["ollama_generate"], self, "ollama", "generate", *args, **kwargs
+            _original_methods["ollama_generate"], self, "ollama", "/api/generate", *args, **kwargs
         )
 
     async def patched_ollama_chat(self, *args, **kwargs):
         return await _patched_inference_method(
-            _original_methods["ollama_chat"], self, "ollama", "chat", *args, **kwargs
+            _original_methods["ollama_chat"], self, "ollama", "/api/chat", *args, **kwargs
         )
 
     async def patched_ollama_embed(self, *args, **kwargs):
         return await _patched_inference_method(
-            _original_methods["ollama_embed"], self, "ollama", "embed", *args, **kwargs
+            _original_methods["ollama_embed"], self, "ollama", "/api/embeddings", *args, **kwargs
         )
 
     async def patched_ollama_ps(self, *args, **kwargs):
-        return await _patched_inference_method(_original_methods["ollama_ps"], self, "ollama", "ps", *args, **kwargs)
+        return await _patched_inference_method(
+            _original_methods["ollama_ps"], self, "ollama", "/api/ps", *args, **kwargs
+        )
 
     async def patched_ollama_pull(self, *args, **kwargs):
         return await _patched_inference_method(
-            _original_methods["ollama_pull"], self, "ollama", "pull", *args, **kwargs
+            _original_methods["ollama_pull"], self, "ollama", "/api/pull", *args, **kwargs
        )
 
     async def patched_ollama_list(self, *args, **kwargs):
         return await _patched_inference_method(
-            _original_methods["ollama_list"], self, "ollama", "list", *args, **kwargs
+            _original_methods["ollama_list"], self, "ollama", "/api/tags", *args, **kwargs
         )
 
     # Apply Ollama patches
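A quick sanity check of the new signature: in LIVE mode (or when no storage is configured) the wrapper must pass straight through to the original method and ignore the endpoint argument. A sketch with stand-in objects (the fakes are hypothetical, and this assumes the module's globals are in their default live state):

```python
import asyncio


class FakeOllamaClient:
    host = "localhost:11434"  # mirrors the attribute the ollama branch reads


async def fake_generate(self, *args, **kwargs):
    return {"response": "ok"}


async def check_passthrough() -> None:
    # Assumes _current_mode == InferenceMode.LIVE, so the wrapper
    # short-circuits before touching base_url or the endpoint.
    out = await _patched_inference_method(
        fake_generate, FakeOllamaClient(), "ollama", "/api/generate"
    )
    assert out == {"response": "ok"}


asyncio.run(check_passthrough())
```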