Tacotron CBHG module on MXNet

Hello,

I’m new to MXNet and to the DL field in general. For a project of mine I’m trying to implement Tacotron in Python with MXNet.

For a newbie like me this is a somewhat difficult task, so it would be very useful to have some hints on the CBHG module (Convolutional 1-D filter Bank + Highway networks + bidirectional Gated recurrent unit).

Right now I’m trying to use the CBHG to predict linear-scale spectrograms from mel spectrograms (this is the last part of the Tacotron system). The same module is also used in the encoder to process embeddings of size 256.

My data shapes are like (batch_size, num_bands, num_time_frames):
linear spectrograms have num_bands = 1025 and mel spectrograms have num_bands = 80;
num_time_frames is fixed to the maximum audio file length among all my audio data.
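
For example, a dummy batch shaped like my data would look like this (just zeros to illustrate the shapes):

import mxnet as mx
# dummy batch: batch_size=2, 100 time frames
mel_batch = mx.nd.zeros((2, 80, 100))       # mel spectrogram: num_bands=80
linear_batch = mx.nd.zeros((2, 1025, 100))  # linear spectrogram: num_bands=1025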

I’m stuck on making the shapes compatible across all steps of the CBHG module. Those steps consist of:

  • A bank of K stacked convolutional filters, where the k-th filter has a kernel width of k (1st filter: kernel width=1, 2nd filter: kernel width=2, …)

  • Maxpooling

  • Two more convolutions for projection: the 1st with kernel_size=3, num_filter=256, the 2nd with kernel_size=3, num_filter=80

  • A highway net of 4 fully-connected layers, num_hidden=128

  • Bidirectional GRU, 128 cells

This is my code:

Convolution bank of K filters, emb_size=256

def conv1dBank(conv_input, K):
    conv=mx.sym.Convolution(data=conv_input, kernel=(1,), num_filter=emb_size//2)
    (conv, mean, var) = mx.sym.BatchNorm(data=conv, output_mean_var=True)
    conv = mx.sym.Activation(data=conv, act_type='relu')
    for k in range(2, K+1):
        convi = mx.sym.Convolution(data=conv_input, kernel=(k,), num_filter=emb_size//2)
        (convi, mean, var) = mx.sym.BatchNorm(data=convi, output_mean_var=True)
        convi = mx.sym.Activation(data=convi, act_type='relu')
        conv = mx.symbol.concat(conv,convi,dim=2) #TODO: Need to concat on the num_filter dimension!
    return conv

highway

def highway_layer(data):
    H= mx.symbol.Activation(
        data=mx.symbol.FullyConnected(data=data, num_hidden=emb_size//2, name="highway_fcH"),
        act_type="relu"
    )
    T= mx.symbol.Activation(
        data=mx.symbol.FullyConnected(data=data, num_hidden=emb_size//2, bias=mx.sym.Variable('bias'), name="highway_fcT"),
        act_type="sigmoid"
    )
    return  H * T + data * (1.0 - T)

CBHG

def CBHG(data,K,proj1_size,proj2_size):
    bank = conv1dBank(data,K)    
    poold_bank = mx.sym.Pooling(data=bank, pool_type='max', kernel=(2,), stride=(1,), name="CBHG_pool")

    proj1 = mx.sym.Convolution(data=poold_bank, kernel=(3,), num_filter=proj1_size, name='CBHG_conv1')
    (proj1, proj1_mean, proj1_var) = mx.sym.BatchNorm(data=proj1, output_mean_var=True, name='CBHG_batch1')
    proj1 = mx.sym.Activation(data=proj1, act_type='relu', name='CBHG_act1')

    proj2 = mx.sym.Convolution(proj1, kernel=(3,), num_filter=proj2_size, name='CBHG_conv2')
    (proj2, proj2_mean, proj2_var) = mx.sym.BatchNorm(data=proj2, output_mean_var=True, name='CBHG_batch2')
    
    residual= proj2 + data #Error here: incompatible shapes 

    for i in range(4):
        residual = highway_layer(residual)
    highway_pass = residual
   
    bidirectional_gru_cell = mx.rnn.BidirectionalCell(
        mx.rnn.GRUCell(num_hidden=emb_size//2, prefix='CBHG_gru1'),
        mx.rnn.GRUCell(num_hidden=emb_size//2, prefix='CBHG_gru2'),
        output_prefix='CBHG_bi_'
    )
    outputs, states = bidirectional_gru_cell.unroll(1, inputs=highway_pass, merge_outputs=True)
    return outputs

So, when I infer the shape with a dummy ndarray of shape (batch_size, num_bands, time_frames), I get an incompatible-shapes error during the residual sum:

in_cbhg = mx.sym.Variable("in_cbhg")
in_cbhg_shape = (2,80,100)
CBHG(in_cbhg,hp.decoder_num_banks,hp.embed_size,hp.n_mels).infer_shape(in_cbhg=in_cbhg_shape)

infer_shape error. Arguments:
  in_cbhg: (2, 80, 100)
Incompatible attr in node _plus70 at 1-th input: expected (2,80,767), got (2,80,100)

How can I solve it? What’s wrong with my code? Are my input data shaped right?
Here is a CBHG implementation in TensorFlow:

I’ve reproduced the issue and am looking into it. Don’t think it matters, but I’m using MXNet 0.12.1 on my side and these parameters:
decoder_num_banks = 8
embed_size = 256
n_mels = 80

Hi @safrooze thank you for helping me.

I have managed to make the CBHG work, although I think I forced it to work using some manual padding, axis swaps, and things like that.

I would like some feedback on what I’ve done because I don’t think it is the correct way.

Right now I’m getting some results with a very simple dataset. From this free digit dataset here I took 50 audio files of one speaker saying “seven” and fed them to the CBHG using a simple custom iterator that explicitly sets the data layout to NTC. As output I got a predicted spectrogram that I converted back to a waveform, and it actually says “seven” :slight_smile:
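
Roughly, the iterator looks like this (a stripped-down sketch rather than my exact code; the names in_cbhg and lin_spec are just the ones I use here, and the data is assumed to already be stored as (num_clips, time_frames, bands) numpy arrays):

import mxnet as mx
import numpy as np

class SpectrogramIter(mx.io.DataIter):
    """Wraps pre-computed spectrogram arrays shaped (num_clips, time_frames, bands)."""
    def __init__(self, mel, linear, batch_size):
        super(SpectrogramIter, self).__init__()
        self.mel, self.linear = mel, linear
        self.batch_size, self.cur = batch_size, 0
        # advertise the NTC layout through DataDesc so the consumer knows how to read the batches
        self._provide_data = [mx.io.DataDesc('in_cbhg', (batch_size,) + mel.shape[1:], layout='NTC')]
        self._provide_label = [mx.io.DataDesc('lin_spec', (batch_size,) + linear.shape[1:], layout='NTC')]

    @property
    def provide_data(self):
        return self._provide_data

    @property
    def provide_label(self):
        return self._provide_label

    def reset(self):
        self.cur = 0

    def next(self):
        if self.cur + self.batch_size > self.mel.shape[0]:
            raise StopIteration
        s = slice(self.cur, self.cur + self.batch_size)
        self.cur += self.batch_size
        return mx.io.DataBatch(data=[mx.nd.array(self.mel[s])],
                               label=[mx.nd.array(self.linear[s])])

# usage with dummy data: 50 clips of 100 frames, 80 mel bands in, 1025 linear bands out
mel = np.zeros((50, 100, 80), dtype=np.float32)
lin = np.zeros((50, 100, 1025), dtype=np.float32)
train_iter = SpectrogramIter(mel, lin, batch_size=2)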

The custom iterator will need to be parallelized when I tackle more data (I will use the Bible recordings found here). That will be the subject of another thread.

So this is my code updated:

CONV1D BANK

# Convolution bank of K filter
def conv1dBank(conv_input, K): 
    #from prenet: conv_input.shape = (batch_size,128, time_frames)
    
    #The k-th filter got a kernel width of k, with 0<k<=K
    conv=mx.sym.Convolution(data=conv_input, kernel=(1,), num_filter=hp.emb_size//2,name="convBank_1")
    #conv.shape = (batch_size, 128 , time_frames)

    conv = mx.sym.Activation(data=conv, act_type='relu')
    if hp.use_convBank_batchNorm:
        conv = mx.sym.BatchNorm(data=conv, name="batchN_bank_1")

    for k in range(2, K+1):
        #pad the last dimension to force original time_frames length
        in_i = mx.sym.concat(conv_input,mx.sym.zeros((hp.batch_size,hp.emb_size//2,k-1)),dim=2)
        convi = mx.sym.Convolution(data=in_i, kernel=(k,), num_filter=hp.emb_size//2,name="convBank_"+str(k))
        #convi.shape = (batch_size,128,time_frames)
        
        convi = mx.sym.Activation(data=convi,act_type='relu')
        if hp.use_convBank_batchNorm:
            convi = mx.sym.BatchNorm(data=convi, name='batchN_bank_'+str(k))
        conv = mx.symbol.concat(conv,convi,dim=1)
        #conv.shape = (batch_size,k*128,time_frames)
    return conv

and the

CBHG

# CBHG
def CBHG(data,K,proj1_size,proj2_size,num_unroll):
    bank = conv1dBank(data,K)

    #After the convolutional bank, a max pooling is applied
    #Same thing again: to always obtain the same dimension I'm padding the input of each operation
    conv_padded = mx.sym.concat(bank,mx.sym.zeros((hp.batch_size,K*(hp.emb_size//2),1)),dim=2)
    poold_bank = mx.sym.Pooling(data=conv_padded, pool_type='max', kernel=(2,), stride=(1,), name="CBHG_pool")
    #shape here: (batch_size,K*128,time_frames)
    #Now two other projections (convolutions) are done. Same padding thing
    poold_bank_padded = mx.sym.concat(poold_bank,mx.sym.zeros((hp.batch_size,K*(hp.emb_size//2),2)),dim=2)

    proj1 = mx.sym.Convolution(data=poold_bank_padded, kernel=(3,), num_filter=proj1_size, name='CBHG_conv1')
    proj1 = mx.sym.Activation(data=proj1, act_type='relu', name='CBHG_act1')

    if hp.use_proj1_batchNorm:
        proj1 = mx.sym.BatchNorm(data=proj1, name="batchNorm_proj1")

    proj1_padded = mx.sym.concat(proj1,mx.sym.zeros((hp.batch_size,hp.emb_size,2)),dim=2)
    proj2 = mx.sym.Convolution(proj1_padded, kernel=(3,), num_filter=proj2_size, name='CBHG_conv2')

    if hp.use_proj2_batchNorm:
        proj2=mx.sym.BatchNorm(data=proj2, name="batchNorm_proj2")

    #Adding residual connection. The output of the prenet pass is added to proj2
    residual= proj2 + data

    residual = mx.sym.swapaxes(residual,1,2)
    #shape here: (batch_size, time_frames, proj2_size). note: proj2_size=128 and not 80 as paper states
    #A 4 highway layers is created
    for i in range(4):
        residual = highway_layer(residual,i)
    highway_pass = residual

    #The highway output is passed to the bidirectional gru cell
    bidirectional_gru_cell = mx.rnn.BidirectionalCell(
        mx.rnn.GRUCell(num_hidden=hp.emb_size//2, prefix='CBHG_gru1'),
        mx.rnn.GRUCell(num_hidden=hp.emb_size//2, prefix='CBHG_gru2'),
        output_prefix='CBHG_bi_'
    )

    bi_gru_outputs, bi_gru_states = bidirectional_gru_cell.unroll(num_unroll, inputs=highway_pass, merge_outputs=True)

    return bi_gru_outputs

I’m sharing the most relevant code to me.

Those manual padding and swapaxes tricks look ugly to me, am I wrong?
Looking at the same implementation in TensorFlow, I see that the code doesn’t care much about shape consistency. This is the output of the TensorFlow code from the repository I shared in my first post.

I would like to note that time_frames (here 64) is not padded, so each new spectrogram has a different time-frame length (determined by the actual audio file length).

input of CBHG:
Mel batch shape: [32 64 80] 

used for loss estimation:
Linear batch shape: [  32   64 1025]

Shapes of post processing CBHG

1° step: prenet_out [ 32  64 128]

2° step: conv1d_bank_shape [  32   64 1024]

3° step: max_pool_shape [  32   64 1024]

4° step: proj1_shape [ 32  64 256]

5° step: proj2_shape [ 32  64 128]

6° step: residual_shape [ 32  64 128]

7° step: highway_shape [ 32  64 128]

8° step: gru_shape [ 32  64 256]

9° step: final_shape [  32   64 1025]

I’m going to address a number of items in no particular order:
Convolution
Convolution in MXNet is different from TF in two main ways:

  1. the default dim order in MXNet is CHW, whereas in TF it is HWC (you already know this).
  2. TF has two padding options, ‘valid’ and ‘same’. ‘valid’ is simply no padding. ‘same’ adds enough padding so that the output HW dimension equals the input HW dimension divided by the stride. Given how convolution works, this means that if the kernel dimension is an even value (2, 4, etc.), the padding is non-symmetric (see explanation here). With MXNet, however, padding is specified as a number and is applied symmetrically to the input array before convolving it with the kernel. In most applications kernel dimensions are odd numbers, but in your case, with even-sized kernels, you either need to do the padding yourself (which is what you’re trying to do) or you need to slice the output array after convolution, because the output will be off by 1. For example, if your input array width is 100, the kernel width is 2, and you specify pad=1, then the output array will have a width of 101 and you need to use sym.slice to slice off one element from the beginning of the array. Example:


# Using padding with sym.Convolution to get the same effect as padding='same' in TF
convi = mx.sym.Convolution(data=conv_input, kernel=(k,), num_filter=emb_size//2, pad=(k//2,))
if k % 2 == 0:  # even kernel width: symmetric padding leaves one extra output element
    convi = mx.sym.slice_axis(convi, axis=2, begin=1, end=None)

n-gram convolution
According to the paper, the intention of using kernel sizes 1 to K is to capture n-grams of length 1 to K. In that case, symmetrically padding the input to the convolution (as done in the TF example) is perhaps the wrong thing to do (although I don’t know how much the results may be affected). Padding the output of the convolution at the end of the array is potentially a more correct approach (see the padding example below). The conv1d projections, on the other hand, probably do need symmetric padding.

Padding
There is an MXNet function specifically for padding (sym.pad). Unfortunately, it is currently only implemented for 4-D and 5-D arrays. The advantage of using this function is that you do not need to specify exact dimensions the way you do with sym.concat. You can use sym.expand_dims() followed by sym.pad(), followed by sym.reshape() to pad your 3-D array, and you can write your own convolution wrapper to implement the behavior that you want. Example:

# padding the convolution output at the end of the time axis by k-1
conv_padded = mx.sym.reshape(
    mx.sym.pad(
        mx.sym.expand_dims(conv, axis=3),            # make the array 4-D so sym.pad works
        mode='constant',
        pad_width=(0, 0, 0, 0, 0, k - 1, 0, 0)),     # pad only the time axis (axis 2)
    shape=(0, 0, 0)                                  # 0 = keep that dim; the dummy axis is dropped
    )

swapaxis
In TF they decided to keep the default dimension order as HWC and, as a result, no swapaxes is required when you apply a dense (fully connected) layer. In MXNet the default is CHW, so you do need to swap axes before the fully connected layer, because FullyConnected always operates on the last dimension(s). An alternative would have been to change the layout in the convolution operator; unfortunately, at the moment MXNet only supports the ‘NCW’ layout for 1-D convolution (‘NWC’, which is what TF uses by default for 1-D convolution, is not currently supported).
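
In code, the pattern is roughly this (conv_out is just a placeholder symbol standing in for the output of a 1-D convolution):

# NCW output of a 1-D convolution -> NTC, then a per-time-step dense layer
x = mx.sym.swapaxes(conv_out, dim1=1, dim2=2)                     # (batch, time, channels)
x = mx.sym.FullyConnected(data=x, num_hidden=128, flatten=False)  # acts on the last dim only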

FullyConnected
There is a big difference between the behavior of mx.sym.FullyConnected and tf.layers.dense. TF creates the FC layer from the last dimension only. MXNet, however, by default flattens everything except the first dimension and feeds the flattened result to the FC layer. In CBHG, FC must operate on the last dimension only; MXNet can be configured to do so by passing flatten=False (by default, flatten is True). Example:

H = mx.sym.Activation(
    data=mx.sym.FullyConnected(
        data=data,
        num_hidden=emb_size//2,
        flatten=False,
        name="highway_fcH"),
    act_type="relu"
)

Highway Layer Input
I don’t know if you’ve already fixed this, but the dimension of the input to the highway layer must match the dimension of the FC output inside the highway layer. So you need to run the input through an FC layer before passing it to the highway layers. This is done in the TF code as well. No activation should be necessary.

residual = mx.sym.swapaxes(residual, 1, 2) # This is from your code
residual = mx.sym.FullyConnected(
    residual,
    num_hidden=emb_size//2,
    flatten=False,
    name='highway_input_FC')   # FC to match input and output of highway

RNN
Everything looks OK with the RNN section. Just remember to reset the RNN cell when switching to a new sequence in training or inference (see here for more details).
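
In other words, something along these lines before each new sequence:

# reset the cell's internal bookkeeping before unrolling over a new sequence
bidirectional_gru_cell.reset()
bi_gru_outputs, bi_gru_states = bidirectional_gru_cell.unroll(
    num_unroll, inputs=highway_pass, merge_outputs=True)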

Feel free to report other issues you may be facing with training a dynamic sequence dataset.


By the way, since you’re just getting started, I highly recommend switching to the Gluon API instead of the symbolic API; Gluon is much easier to debug and has almost the same performance as the symbolic API. Feel free to post questions regarding Gluon if you decide to take advantage of it.
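
Just to give a flavor, here is a rough, hypothetical sketch of one conv-bank branch written as a Gluon HybridBlock (not a port of your code; the class and parameter names are made up):

import mxnet as mx
from mxnet import gluon

class ConvBankBranch(gluon.nn.HybridBlock):
    """One branch of the conv1d bank: Conv1D (NCW layout) + ReLU + BatchNorm."""
    def __init__(self, k, channels, **kwargs):
        super(ConvBankBranch, self).__init__(**kwargs)
        with self.name_scope():
            # padding=k//2 approximates TF 'same'; even k still needs the slice discussed above
            self.conv = gluon.nn.Conv1D(channels, kernel_size=k, padding=k // 2)
            self.bn = gluon.nn.BatchNorm()

    def hybrid_forward(self, F, x):
        y = F.Activation(self.conv(x), act_type='relu')
        return self.bn(y)

branch = ConvBankBranch(k=3, channels=128)
branch.initialize()
out = branch(mx.nd.zeros((2, 80, 100)))  # imperative call: shapes are easy to inspect
print(out.shape)                         # (2, 128, 100)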

@safrooze thank you. Just, thank you. Finally everything is more clear. I will test everything you told me as soon as I can.

I would like to point out that when I started coding I began with Gluon and created the pre-net of Tacotron. It was pretty easy, but I got lost during the conv1dBank implementation because I had no idea how to do it. Then, looking at the symbolic API, I thought it was more straightforward.
I will try to use Gluon again for sure.

Thank you so much again; I will keep you updated on my progress.