-
Notifications
You must be signed in to change notification settings - Fork 25.5k
Description
🐛 Describe the bug
Hello, I'm currently facing an issue while trying to export a PyTorch model (AlignTTS of coqui-ai) to ONNX format using the onnx.export() function with opset_version=12. Specifically, I'm getting an error message indicating that the 'aten::unflatten' operator is not supported for this opset version. I would like to request support for this operator when exporting to ONNX opset version 12. It would be very helpful if this operator could be added to the ONNX exporter or if there is a workaround or alternative solution to this issue. I have already tried lowering the opset version, but none of the solutions seem to have worked. I came across the issues of #98190, which suggested applying a custom symbolic function, but I don't understand how this can be done. I would appreciate any assistance you can provide.
Thank you!
Here is my code:
from typing import Dict
import torch
import torch.onnx as onnx
from torch.utils.tensorboard import SummaryWriter
import torch
import torchvision
import os
import torch.nn.functional as F
from TTS.tts.configs.shared_configs import BaseDatasetConfig,BaseAudioConfig,CharactersConfig
from TTS.tts.configs.glow_tts_config import GlowTTSConfig
from TTS.tts.configs.align_tts_config import AlignTTSConfig
from TTS.utils.audio import AudioProcessor
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.tts.models.glow_tts import GlowTTS
from TTS.tts.models.align_tts import AlignTTS
from TTS.tts.datasets import load_tts_samples
from trainer import Trainer, TrainerArgs
from TTS.tts.models.align_tts import AlignTTSArgs
import torch.jit
import sys
from torch import nn
from torch.onnx.symbolic_helper import parse_args, _unimplemented
from torch.onnx import register_custom_op_symbolic
class AlignTTSModelClass(nn.Module):
def __init__(self, num_graphemes, seq_len):
super().__init__()
# Create an instance of the AlignTTSConfig class with the desired parameters
config = AlignTTSConfig(
# model_name="align_tts",
model_args=AlignTTSArgs(
num_chars=num_graphemes,
out_channels=80,
hidden_channels=256,
hidden_channels_dp=256,
encoder_type="fftransformer",
encoder_params={"hidden_channels_ffn": 1024, "num_heads": 2, "num_layers": 6, "dropout_p": 0.1},
decoder_type="fftransformer",
decoder_params={"hidden_channels_ffn": 1024, "num_heads": 2, "num_layers": 6, "dropout_p": 0.1},
length_scale=1.0,
num_speakers=0,
use_speaker_embedding=False,
use_d_vector_file=False,
d_vector_dim=0,
)
)
# Create an instance of the AlignTTS class using the config
self.align_tts = AlignTTS(config)
# Set the input shape of the model
self.input_shape = (seq_len, num_graphemes)
def forward(self, x: torch.Tensor, x_lengths: torch.Tensor, y: torch.Tensor, y_lengths: torch.Tensor) -> Dict[str, torch.Tensor]:
# Perform the forward pass through the AlignTTS model, passing the additional arguments
aux_input = {"d_vectors": None} # dummy input for the "aux_input" argument
phase = None # dummy input for the "phase" argument
return self.align_tts.forward(x, x_lengths, y, y_lengths, aux_input, phase)
if __name__ == '__main__':
# Parameters
batch_size = 8
num_graphemes = 105
seq_len = 256
# Load the PyTorch checkpoint
checkpoint_path = '/home/elias/male_checkpoint.pth'
checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
dummy_input = torch.randint(low=0, high=num_graphemes, size=(batch_size, seq_len), dtype=torch.long)
dummy_input_lengths = torch.randint(low=1, high=seq_len+1, size=(batch_size,))
dummy_output_lengths = torch.randint(low=1, high=seq_len+1, size=(batch_size,))
dummy_output = torch.randn(batch_size, seq_len, 80)
model = AlignTTSModelClass(num_graphemes=num_graphemes, seq_len=seq_len)
# Enable autograd profiler
torch.autograd.profiler.profile(enabled=True, use_cuda=False)
# Perform a forward pass of the model
with torch.autograd.profiler.record_function("model_inference"):
model(dummy_input, dummy_input_lengths, dummy_output, dummy_output_lengths)
# Disable autograd profiler
torch.autograd.profiler.profile(enabled=False)
model = torch.nn.DataParallel(model)
# Load the state dictionary from the checkpoint
state_dict = checkpoint['model']
model_state_dict = model.state_dict()
for name, param in state_dict.items():
if name not in model_state_dict:
print(f"Ignoring parameter '{name}' in the checkpoint because it does not exist in the model")
elif param.shape != model_state_dict[name].shape:
print(f"Ignoring parameter '{name}' in the checkpoint because its shape does not match the corresponding model parameter")
else:
model_state_dict[name] = param
# Load the state dictionary into the model
model.load_state_dict(model_state_dict, strict=False)
# Export the model to ONNX format
onnx_path = '/home/elias/male_model.onnx'
torch.onnx.utils.export(model.module, (dummy_input, dummy_input_lengths, dummy_output, dummy_output_lengths), onnx_path,
verbose=True, export_params=True, opset_version=12,
input_names=['input', 'input_lengths', 'output_lengths'],
output_names=['output'],
dynamic_axes={'input': {0: 'batch_size', 1: 'seq_len'},
'input_lengths': {0: 'batch_size'},
'output_lengths': {0: 'batch_size'},
'output': {0: 'batch_size', 1: 'seq_len'}})
print('Export to ONNX has been successful!!')
============= Diagnostic Run torch.onnx.export version 2.0.0+cu117 =============
verbose: False, log level: Level.ERROR
======================= 0 NONE 0 NOTE 0 WARNING 1 ERROR ========================
ERROR: missing-standard-symbolic-function
=========================================
Exporting the operator 'aten::unflatten' to ONNX opset version 12 is not supported. Please feel free to request support or submit a pull request on PyTorch GitHub: https://github.com/pytorch/pytorch/issues.
None
<Set verbose=True to see more details>
Traceback (most recent call last):
File "/home/elias/script/onnx_2.py", line 111, in <module>
torch.onnx.utils.export(model.module, (dummy_input, dummy_input_lengths, dummy_output, dummy_output_lengths), onnx_path,
File "/home/elias/miniconda3/envs/elias-dev/lib/python3.10/site-packages/torch/onnx/utils.py", line 506, in export
_export(
File "/home/elias/miniconda3/envs/elias-dev/lib/python3.10/site-packages/torch/onnx/utils.py", line 1548, in _export
graph, params_dict, torch_out = _model_to_graph(
File "/home/elias/miniconda3/envs/elias-dev/lib/python3.10/site-packages/torch/onnx/utils.py", line 1117, in _model_to_graph
graph = _optimize_graph(
File "/home/elias/miniconda3/envs/elias-dev/lib/python3.10/site-packages/torch/onnx/utils.py", line 665, in _optimize_graph
graph = _C._jit_pass_onnx(graph, operator_export_type)
File "/home/elias/miniconda3/envs/elias-dev/lib/python3.10/site-packages/torch/onnx/utils.py", line 1901, in _run_symbolic_function
raise errors.UnsupportedOperatorError(
torch.onnx.errors.UnsupportedOperatorError: Exporting the operator 'aten::unflatten' to ONNX opset version 12 is not supported. Please feel free to request support or submit a pull request on PyTorch GitHub: https://github.com/pytorch/pytorch/issues.
Versions
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
CPU(s): 40
On-line CPU(s) list: 0-39
Thread(s) per core: 2
Core(s) per socket: 10
Socket(s): 2
NUMA node(s): 1
Vendor ID: GenuineIntel
CPU family: 6
Model: 63
Model name: Intel(R) Xeon(R) CPU E5-2650 v3 @ 2.30GHz
Stepping: 2
CPU MHz: 1499.969
CPU max MHz: 3000.0000
CPU min MHz: 1200.0000
BogoMIPS: 4600.00
Virtualization: VT-x
L1d cache: 32K
L1i cache: 32K
L2 cache: 256K
L3 cache: 25600K
NUMA node0 CPU(s): 0-39
Versions of relevant libraries:
[pip3] numpy==1.23.5
[pip3] torch==2.0.0
[pip3] torchaudio==2.0.1+cu117
[pip3] torchvision==0.15.1+cu117
[pip3] triton==2.0.0
[conda] numpy 1.23.5 pypi_0 pypi
[conda] torch 2.0.0 pypi_0 pypi
[conda] torchaudio 2.0.1+cu117 pypi_0 pypi
[conda] torchvision 0.15.1+cu117 pypi_0 pypi
[conda] triton 2.0.0 pypi_0 pypi
Metadata
Metadata
Assignees
Labels
Type
Projects
Status