
How To Access Tensor Shape Within .map Function?

I have a dataset of audio clips of varying lengths, and I want to crop all of them to 5-second windows (240000 elements at a 48000 Hz sample rate). So, after loading the .tfre

Solution 1:

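Inside .map the function is traced as a graph, so the static shape (audio.shape) is often unknown. tf.shape(audio) returns the length at run time instead. You can use a function like this, which repeats short clips and trims long ones: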

import tensorflow as tf

def enforce_length(audio):
    # Target length in samples (5 s at 48 kHz)
    AUDIO_LEN = 240_000
    # Current length, read at run time (the static shape may be unknown)
    current_len = tf.shape(audio)[0]
    # Repetitions needed to reach the target: ceil(AUDIO_LEN / current_len)
    num_reps = AUDIO_LEN // current_len
    num_reps += tf.dtypes.cast((AUDIO_LEN % current_len) > 0, num_reps.dtype)
    # Repeat the clip end to end
    audio_rep = tf.tile(audio, [num_reps])
    # Trim to the target length
    return audio_rep[:AUDIO_LEN]

# Test
examples = tf.data.Dataset.from_generator(lambda: iter([
    tf.zeros([100_000], tf.float32),
    tf.zeros([300_000], tf.float32),
    tf.zeros([123_456], tf.float32),
]), output_signature=tf.TensorSpec(shape=[None], dtype=tf.float32))
result = examples.map(enforce_length)
for item in result:
    print(item.shape)

Output:

(240000,)
(240000,)
(240000,)
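
The printed shapes above come from eager tensors, so they show the run-time length. The dataset itself may still report an unknown length in its element_spec, which can matter for later steps that need a fixed size. The following is a minimal sketch of one way to handle that, not part of the original answer: it records the static shape seen inside .map (expected to be None here) and pins the output length with tf.ensure_shape. TARGET_LEN, gen, and tile_and_trim are hypothetical names, and the target is kept tiny so the example runs quickly.

```python
import numpy as np
import tensorflow as tf

TARGET_LEN = 8  # hypothetical small target; the answer above uses 240_000


def gen():
    # Variable-length clips standing in for decoded audio.
    for n in (3, 8, 11):
        yield np.arange(n, dtype=np.float32)


ds = tf.data.Dataset.from_generator(
    gen, output_signature=tf.TensorSpec(shape=[None], dtype=tf.float32))

static_lens = []  # filled while the map function is traced


def tile_and_trim(audio):
    # Static shape: a Python value fixed at trace time (unknown here).
    static_lens.append(audio.shape[0])
    # Dynamic shape: a scalar int32 tensor evaluated for each element.
    n = tf.shape(audio)[0]
    # Ceiling division gives the number of repetitions needed.
    reps = (TARGET_LEN + n - 1) // n
    out = tf.tile(audio, [reps])[:TARGET_LEN]
    # Declare the now-known length so element_spec carries it.
    return tf.ensure_shape(out, [TARGET_LEN])


fixed = ds.map(tile_and_trim)
clips = [c.numpy().tolist() for c in fixed]

print(static_lens[0])
print(fixed.element_spec)
```

This sketch assumes every clip has at least one sample. An empty clip would make the division fail, so filter such elements out first if they can occur.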
