
Support for 'aten::unflatten' operator when exporting to ONNX opset version 12 #100826


🐛 Describe the bug

Hello, I'm facing an issue while exporting a PyTorch model (AlignTTS from coqui-ai TTS) to ONNX using torch.onnx.export() with opset_version=12. The export fails with an error saying the 'aten::unflatten' operator is not supported at this opset version. I would like to request support for this operator at ONNX opset version 12, or a workaround or alternative solution. I have already tried lowering the opset version, but that did not help. I also came across issue #98190, which suggests registering a custom symbolic function, but I don't understand how to do that. I would appreciate any assistance you can provide.
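
(For reference, a "custom symbolic function" tells the ONNX exporter how to lower an unsupported ATen operator into existing ONNX ops. Below is a minimal sketch of what this could look like for aten::unflatten at opset 12; it assumes the traced tensor has a fully static shape, and the function name and registration are illustrative rather than code from issue #98190.)

import torch
from torch.onnx import register_custom_op_symbolic
from torch.onnx.symbolic_helper import parse_args

@parse_args("v", "i", "is")
def unflatten_symbolic(g, self, dim, sizes):
    # Lower aten::unflatten(self, dim, sizes) to a single ONNX Reshape.
    # Only works when the traced input shape is fully static.
    input_sizes = self.type().sizes()
    if input_sizes is None:
        raise RuntimeError("this sketch requires static input shapes")
    dim = dim % len(input_sizes)  # normalize a negative dim
    new_shape = list(input_sizes[:dim]) + list(sizes) + list(input_sizes[dim + 1:])
    shape_const = g.op("Constant", value_t=torch.tensor(new_shape, dtype=torch.int64))
    return g.op("Reshape", self, shape_const)

# Must be registered before torch.onnx.export is called.
register_custom_op_symbolic("aten::unflatten", unflatten_symbolic, 12)

(Registered before the export call, this should replace each aten::unflatten with a Reshape, at the cost of baking the traced sizes into the graph, so dynamic_axes cannot apply to those dimensions.)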

Thank you!

Here is my code:



from typing import Dict

import torch
from torch import nn
from torch.onnx import register_custom_op_symbolic
from torch.onnx.symbolic_helper import parse_args

from TTS.tts.configs.align_tts_config import AlignTTSConfig
from TTS.tts.models.align_tts import AlignTTS, AlignTTSArgs


class AlignTTSModelClass(nn.Module):
    def __init__(self, num_graphemes, seq_len):
        super().__init__()

        # Create an instance of the AlignTTSConfig class with the desired parameters
        config = AlignTTSConfig(
            # model_name="align_tts",
            model_args=AlignTTSArgs(
                num_chars=num_graphemes,
                out_channels=80,
                hidden_channels=256,
                hidden_channels_dp=256,
                encoder_type="fftransformer",
                encoder_params={"hidden_channels_ffn": 1024, "num_heads": 2, "num_layers": 6, "dropout_p": 0.1},
                decoder_type="fftransformer",
                decoder_params={"hidden_channels_ffn": 1024, "num_heads": 2, "num_layers": 6, "dropout_p": 0.1},
                length_scale=1.0,
                num_speakers=0,
                use_speaker_embedding=False,
                use_d_vector_file=False,
                d_vector_dim=0,
            )
        )

        # Create an instance of the AlignTTS class using the config
        self.align_tts = AlignTTS(config)

        # Record the expected input shape (informational; never read below)
        self.input_shape = (seq_len, num_graphemes)

    def forward(self, x: torch.Tensor, x_lengths: torch.Tensor, y: torch.Tensor, y_lengths: torch.Tensor) -> Dict[str, torch.Tensor]:
        # Perform the forward pass through the AlignTTS model, passing the additional arguments
        aux_input = {"d_vectors": None} # dummy input for the "aux_input" argument
        phase = None # dummy input for the "phase" argument
        return self.align_tts.forward(x, x_lengths, y, y_lengths, aux_input, phase)
    
    
if __name__ == '__main__':
    # Parameters
    batch_size = 8
    num_graphemes = 105
    seq_len = 256

    # Load the PyTorch checkpoint
    checkpoint_path = '/home/elias/male_checkpoint.pth'
    checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))

    dummy_input = torch.randint(low=0, high=num_graphemes, size=(batch_size, seq_len), dtype=torch.long)
    dummy_input_lengths = torch.randint(low=1, high=seq_len+1, size=(batch_size,))
    dummy_output_lengths = torch.randint(low=1, high=seq_len+1, size=(batch_size,))
    dummy_output = torch.randn(batch_size, seq_len, 80)

    model = AlignTTSModelClass(num_graphemes=num_graphemes, seq_len=seq_len)
    # Profile one forward pass. The profiler only records when used as a
    # context manager; constructing it without `with` has no effect.
    with torch.autograd.profiler.profile(enabled=True, use_cuda=False):
        with torch.autograd.profiler.record_function("model_inference"):
            model(dummy_input, dummy_input_lengths, dummy_output, dummy_output_lengths)

    # Wrap in DataParallel (parameter names gain a 'module.' prefix)
    model = torch.nn.DataParallel(model)
    
    # Load the checkpoint state dict and keep only the entries whose name
    # and shape match the model's parameters
    state_dict = checkpoint['model']
    model_state_dict = model.state_dict()
    for name, param in state_dict.items():
        if name not in model_state_dict:
            print(f"Ignoring parameter '{name}' in the checkpoint because it does not exist in the model")
        elif param.shape != model_state_dict[name].shape:
            print(f"Ignoring parameter '{name}' in the checkpoint because its shape does not match the corresponding model parameter")
        else:
            model_state_dict[name] = param

    # Load the merged state dict (strict=False is redundant here, since
    # model_state_dict already contains every model key)
    model.load_state_dict(model_state_dict, strict=False)
    
 
    # Export the model to ONNX format. Four positional inputs need four
    # input names ('mel' and 'mel_lengths' name the mel-spectrogram
    # tensor and its lengths)
    onnx_path = '/home/elias/male_model.onnx'

    torch.onnx.utils.export(model.module, (dummy_input, dummy_input_lengths, dummy_output, dummy_output_lengths), onnx_path,
            verbose=True, export_params=True, opset_version=12,
            input_names=['input', 'input_lengths', 'mel', 'mel_lengths'],
            output_names=['output'],
            dynamic_axes={'input': {0: 'batch_size', 1: 'seq_len'},
                          'input_lengths': {0: 'batch_size'},
                          'mel': {0: 'batch_size', 1: 'seq_len'},
                          'mel_lengths': {0: 'batch_size'},
                          'output': {0: 'batch_size', 1: 'seq_len'}})

    print('Export to ONNX succeeded!')

And here is the error output:

============= Diagnostic Run torch.onnx.export version 2.0.0+cu117 =============
verbose: False, log level: Level.ERROR
======================= 0 NONE 0 NOTE 0 WARNING 1 ERROR ========================
ERROR: missing-standard-symbolic-function
=========================================
Exporting the operator 'aten::unflatten' to ONNX opset version 12 is not supported. Please feel free to request support or submit a pull request on PyTorch GitHub: https://github.com/pytorch/pytorch/issues.
None
<Set verbose=True to see more details>


Traceback (most recent call last):
  File "/home/elias/script/onnx_2.py", line 111, in <module>
    torch.onnx.utils.export(model.module, (dummy_input, dummy_input_lengths, dummy_output, dummy_output_lengths), onnx_path,
  File "/home/elias/miniconda3/envs/elias-dev/lib/python3.10/site-packages/torch/onnx/utils.py", line 506, in export
    _export(
  File "/home/elias/miniconda3/envs/elias-dev/lib/python3.10/site-packages/torch/onnx/utils.py", line 1548, in _export
    graph, params_dict, torch_out = _model_to_graph(
  File "/home/elias/miniconda3/envs/elias-dev/lib/python3.10/site-packages/torch/onnx/utils.py", line 1117, in _model_to_graph
    graph = _optimize_graph(
  File "/home/elias/miniconda3/envs/elias-dev/lib/python3.10/site-packages/torch/onnx/utils.py", line 665, in _optimize_graph
    graph = _C._jit_pass_onnx(graph, operator_export_type)
  File "/home/elias/miniconda3/envs/elias-dev/lib/python3.10/site-packages/torch/onnx/utils.py", line 1901, in _run_symbolic_function
    raise errors.UnsupportedOperatorError(
torch.onnx.errors.UnsupportedOperatorError: Exporting the operator 'aten::unflatten' to ONNX opset version 12 is not supported. Please feel free to request support or submit a pull request on PyTorch GitHub: https://github.com/pytorch/pytorch/issues.
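
(For completeness, two workarounds that avoid writing a custom symbolic: exporting with opset_version=13 or higher, where the exporter provides an aten::unflatten symbolic, or patching the model to use reshape instead of unflatten, which opset 12 supports. A sketch of the equivalent rewrite follows; the helper below is illustrative, not code from the AlignTTS source.)

import torch

def unflatten_via_reshape(x: torch.Tensor, dim: int, sizes) -> torch.Tensor:
    # Same result as x.unflatten(dim, sizes), expressed with reshape,
    # which the exporter handles at opset 12.
    dim = dim % x.dim()  # normalize a negative dim
    return x.reshape(*x.shape[:dim], *sizes, *x.shape[dim + 1:])

x = torch.randn(8, 256, 80)
assert torch.equal(x.unflatten(-1, (8, 10)), unflatten_via_reshape(x, -1, (8, 10)))

(As with the symbolic sketch above, tracing through x.shape bakes the observed sizes into the exported graph, so the reshaped axes cannot stay dynamic.)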

Versions

CPU:
Architecture:          x86_64
CPU op-mode(s):        32-bit, 64-bit
Byte Order:            Little Endian
CPU(s):                40
On-line CPU(s) list:   0-39
Thread(s) per core:    2
Core(s) per socket:    10
Socket(s):             2
NUMA node(s):          1
Vendor ID:             GenuineIntel
CPU family:            6
Model:                 63
Model name:            Intel(R) Xeon(R) CPU E5-2650 v3 @ 2.30GHz
Stepping:              2
CPU MHz:               1499.969
CPU max MHz:           3000.0000
CPU min MHz:           1200.0000
BogoMIPS:              4600.00
Virtualization:        VT-x
L1d cache:             32K
L1i cache:             32K
L2 cache:              256K
L3 cache:              25600K
NUMA node0 CPU(s):     0-39

Versions of relevant libraries:
[pip3] numpy==1.23.5
[pip3] torch==2.0.0
[pip3] torchaudio==2.0.1+cu117
[pip3] torchvision==0.15.1+cu117
[pip3] triton==2.0.0
[conda] numpy                     1.23.5                   pypi_0    pypi
[conda] torch                     2.0.0                    pypi_0    pypi
[conda] torchaudio                2.0.1+cu117              pypi_0    pypi
[conda] torchvision               0.15.1+cu117             pypi_0    pypi
[conda] triton                    2.0.0                    pypi_0    pypi

    Labels

    module: onnx (Related to torch.onnx), triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
