Mar 3, 2026 | Read time 9 min

Understanding TDT: The Mechanism Behind the Fastest Models on the Open ASR Leaderboard

Why predicting durations as well as tokens allows transducer models to skip frames and achieve up to 2.82× faster inference.
Oliver Parish, Machine Learning Engineer

TL;DR

The Token-and-Duration Transducer (TDT) extends RNN-T by jointly predicting what token to emit and how many frames that token covers. This lets the model skip multiple encoder frames per step during inference instead of advancing one at a time, yielding up to 2.82x faster decoding with comparable or better accuracy.

Word Error Rate (WER) is a useful metric to optimise, but if your model takes 10 seconds to transcribe 1 second of audio, nobody's shipping it. The Hugging Face Open ASR Leaderboard tracks both accuracy and speed. At the time of writing, among the leaderboard's top 10, Nvidia's Parakeet TDT models are more than 3x ahead of the nearest competition in RTFx (inverse real-time factor, i.e. how many seconds of audio the model can process per second of wall-clock time).

These models are significantly faster than the competition while maintaining competitive WERs. The mechanism? A modification to the RNN-Transducer called the Token-and-Duration Transducer (TDT). In this post, we'll first look at how RNN-T and TDT work at inference time to build intuition for why TDT is faster, then circle back to explain how each model is trained.


Part 1: Inference - Decoding with Frame Skipping

RNN-T Architecture

Without going into too much detail, there are a few ways to train a speech-to-text model: CTC, AED, Decoder-only or RNN-T/TDT. Each of these has pros and cons; for a full comparison see Desh's Analysis.

RNN-T hits a useful middle ground: it has enough modeling capacity to capture label dependencies (unlike CTC), its autoregressive component is lightweight (unlike AED/Decoder-only), and it can be trained end-to-end with a well-understood loss function. RNN-T is already fast - much faster than AED or Decoder-only models, and only somewhat slower than CTC. But there's still space to speed up.

An RNN-T consists of three components:

  1. Encoder {% katex inline %}h{% endkatex %}: Maps audio {% katex inline %}\mathbf{X}{% endkatex %} to hidden representations {% katex inline %}h_t(\mathbf{X}) \in \mathbb{R}^{H_{\text{enc}}}{% endkatex %} for each frame {% katex inline %}t{% endkatex %}. Typically, this is a large transformer.
  2. Predictor (a.k.a. decoder) {% katex inline %}g{% endkatex %}: A small autoregressive network that maps the previous non-blank tokens {% katex inline %}\mathbf{y}_{<u}{% endkatex %} to representations {% katex inline %}g_u(\mathbf{y}_{<u}) \in \mathbb{R}^{H_{\text{pred}}}{% endkatex %}. The name "RNN-T" implies an RNN here, but this can be substituted for any autoregressive network. During training, the predictor sees the ground-truth label prefix (teacher forcing); during inference, it autoregresses on its own previous predictions.
  3. Joint network {% katex inline %}J{% endkatex %}: A small network (often a single linear layer) that combines encoder and predictor outputs to produce logits over the vocabulary {% katex inline %}\mathcal{V} \cup \lbrace\varnothing\rbrace{% endkatex %} (where {% katex inline %}\varnothing{% endkatex %} is the blank symbol):

{% katex display %}\mathbf{z}_{t,u} = J(h_t(\mathbf{X}), g_u(\mathbf{y}_{<u})) \in \mathbb{R}^{|\mathcal{V}| + 1}{% endkatex %}

Here {% katex inline %}t{% endkatex %} indexes the encoder time-step, {% katex inline %}u{% endkatex %} indexes how many output labels have been emitted so far (position in the target sequence), and {% katex inline %}v{% endkatex %} denotes any candidate vocabulary symbol (including blank) when we write {% katex inline %}P(v \mid t, u){% endkatex %}.
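To make the three components concrete, here's a toy numpy sketch of one common joint-network form (project both representations to a shared space, add, apply a nonlinearity, then project to logits). The sizes and weight names here are illustrative assumptions, not the real architecture:

```python
import numpy as np

rng = np.random.default_rng(0)

# toy sizes - illustrative assumptions, not any real model's dimensions
H_ENC, H_PRED, H_JOINT, V = 16, 8, 12, 32

W_enc = rng.standard_normal((H_ENC, H_JOINT))
W_pred = rng.standard_normal((H_PRED, H_JOINT))
W_out = rng.standard_normal((H_JOINT, V + 1))   # vocab + blank

def joint(h_t, g_u):
    """Project both representations to a shared space, combine,
    and map to logits over the vocabulary plus blank."""
    z = np.tanh(h_t @ W_enc + g_u @ W_pred)
    return z @ W_out  # shape (V + 1,)

h_t = rng.standard_normal(H_ENC)   # one encoder frame
g_u = rng.standard_normal(H_PRED)  # one predictor state
logits = joint(h_t, g_u)
```

Note that the joint is evaluated once per (frame, label-prefix) pair, which is exactly why the number of joint calls dominates decoding cost.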

RNN-T Inference

At inference time, the model decodes greedily by stepping through encoder frames:

# RNN-T Greedy Decoding (simplified)
t = 0
output = []
while t < T:
    logits = joint(encoder[t], predictor(output))
    token = argmax(logits)
    if token == BLANK:
        t += 1                   # advance ONE frame
    else:
        output.append(token)     # stay at same t; the label index u
                                 # advances implicitly via `output`

At each step, the model either emits blank (advance one frame) or emits a token (stay at the same frame, advance the label index). The key observation in speeding this up will be to allow frame-skipping.

For a 10-second utterance at 80ms frame rate (after subsampling), that's ~125 sequential joint network calls at minimum. Most of those will be blanks - in typical speech, tokens are sparse relative to frames. The model spends most of its time predicting "nothing is happening" one frame at a time. The joint network is cheap per call, but the sequential one-frame-at-a-time structure leaves performance on the table.

TDT addresses this.


TDT: The Key Modification

The core idea of TDT (Xu et al., 2023): instead of predicting just a token at each step, jointly predict the token and how many frames it covers.

In standard RNN-T, the joint network outputs a single distribution over {% katex inline %}|\mathcal{V}| + 1{% endkatex %} symbols (vocabulary + blank). In TDT, the joint network outputs two independent distributions:

  1. Token distribution: {% katex inline %}P(v \mid t, u) \in \Delta^{|\mathcal{V}|+1}{% endkatex %} - same as RNN-T
  2. Duration distribution: {% katex inline %}P(d \mid t, u) \in \Delta^{|\mathcal{D}|}{% endkatex %} - probability over a set of allowed durations

where {% katex inline %}\mathcal{D}{% endkatex %} is a predefined set of durations. A typical choice is {% katex inline %}\mathcal{D} = \lbrace 0, 1, 2, 3, 4\rbrace {% endkatex %}, though the set can be configured - for example, {% katex inline %} \lbrace 1, 2, 3, 4\rbrace {% endkatex %} (omitting 0) is also valid.

The two heads share the same encoder and predictor representations but are independently normalized (separate softmax operations):

# TDT Joint Network Output
logits = joint(encoder[t], predictor(output))  # shape: [V + 1 + |D|]

# Split into token and duration logits
token_logits = logits[:V+1]                    # shape: [V + 1]
duration_logits = logits[V+1:]                 # shape: [|D|]

# Independent softmax
token_probs = softmax(token_logits)
duration_probs = softmax(duration_logits)

TDT Inference

The inference speedup is immediate. Compare with the RNN-T loop above:

# TDT Greedy Decoding (simplified)
t = 0
output = []
while t < T:
    logits = joint(encoder[t], predictor(output))
    token_logits, duration_logits = logits[:V+1], logits[V+1:]
    token = argmax(token_logits)
    duration = argmax(duration_logits)

    if token == BLANK:
        t += max(1, duration)    # skip MULTIPLE frames!
    else:
        output.append(token)
        t += duration            # can also skip frames on token emission

Instead of advancing one frame at a time, the model can skip over stretches of silence or steady-state audio. If the model predicts blank with duration 4, it skips 4 frames in one step - reducing joint network calls for that stretch proportionally.

Let's trace through a concrete example. Suppose we have 8 encoder frames ({% katex inline %}T = 8{% endkatex %}), target "hi" → tokens [h, i], and durations {% katex inline %}\mathcal{D} = \lbrace 0, 1, 2, 3\rbrace {% endkatex %}:

t=0: joint(enc[0], pred([]))
     → token=h (p=0.8), duration=0 (p=0.7)
     → emit 'h', stay at t=0

t=0: joint(enc[0], pred([h]))
     → token=i (p=0.6), duration=2 (p=0.5)
     → emit 'i', jump to t=2

t=2: joint(enc[2], pred([h, i]))
     → token=blank (p=0.9), duration=3 (p=0.6)
     → skip to t=5

t=5: joint(enc[5], pred([h, i]))
     → token=blank (p=0.95), duration=3 (p=0.8)
     → skip to t=8 → DONE!

4 joint network calls instead of 8+ for standard RNN-T. That's the speedup.
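The frame arithmetic in the trace above can be replayed with a tiny scripted stand-in for the model (the `script` list and `tdt_greedy` helper are hypothetical, purely to check the step counts):

```python
# scripted "model": each call yields (token, duration) as in the trace;
# None stands for blank
script = [("h", 0), ("i", 2), (None, 3), (None, 3)]

def tdt_greedy(T, script):
    t, output, calls = 0, [], 0
    it = iter(script)
    while t < T:
        token, duration = next(it)   # stands in for joint + argmax
        calls += 1
        if token is None:            # blank: must advance at least 1
            t += max(1, duration)
        else:                        # token: advance by its duration
            output.append(token)
            t += duration
    return output, calls

out, calls = tdt_greedy(8, script)
```

Running it confirms the transcript ["h", "i"] is produced in exactly 4 calls over 8 frames.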

The TDT paper reports up to 2.82x faster inference than standard RNN-T on speech recognition tasks, with comparable or better accuracy. The speedup is more pronounced on longer utterances with more silence.


Part 2: Training - Mechanics of Forward-Backward

Now that we've seen what these models do at inference time, let's understand how they're trained. This requires a bit more machinery.

The Lattice and Alignments (Standard RNN-T)

During training, we have the audio {% katex inline %}\mathbf{X}{% endkatex %} and we know the correct transcription {% katex inline %}\mathbf{y}{% endkatex %}, but we typically don't know the correct frame-level alignment between them.

Suppose we have {% katex inline %}T = 8{% endkatex %} encoder frames and the target transcription {% katex inline %}\mathbf{y} = [\text{“the"}, \text{“quick"}, \text{“brown"}, \text{“fox"}]{% endkatex %} ({% katex inline %}U = 4{% endkatex %}). There are many potential ways to produce this exact transcript, for example:

Path A (early speech):    ∅, the, quick, brown, fox, ∅, ∅, ∅, ∅, ∅, ∅, ∅
       (orange)           → all tokens emitted at t=1, rest is silence

Path B (spread out):      the, ∅, ∅, quick, ∅, brown, ∅, ∅, fox, ∅, ∅, ∅
       (pink)             → tokens spread across the utterance

Path C (late speech):     ∅, ∅, ∅, ∅, ∅, the, quick, brown, ∅, ∅, fox, ∅
       (blue)             → speech starts late, around t=5

Remember that each time we output a blank symbol {% katex inline %}\varnothing{% endkatex %}, we increment {% katex inline %}t{% endkatex %} (the time-frame of the encoder) and each time we output a token, we feed that back into the predictor to get the next predictor output (increment {% katex inline %}u{% endkatex %} by one).

Our goal now is to maximise the probability of the correct transcript, irrespective of the alignment - which we don't know. RNN-T's solution is to maximise the probability summed over all possible alignments. We visualise this by constructing a lattice, which encodes every possible frame-word alignment.

IMAGE 1

The joint network produces a probability distribution {% katex inline %}P(v \mid t, u){% endkatex %} at every node {% katex inline %}(t, u){% endkatex %} in a {% katex inline %}T \times (U+1){% endkatex %} grid (the lattice), where {% katex inline %}T{% endkatex %} is the number of encoder frames and {% katex inline %}U{% endkatex %} is the number of target tokens. In the above example, for {% katex inline %}t=3{% endkatex %} and {% katex inline %}u=2{% endkatex %}, we evaluate: {% katex display %}\mathbf{z}_{3,2} = J(h_3(\mathbf{X}), g(\text{[“the", “quick"]})){% endkatex %}

the joiner applied to frame {% katex inline %}t=3{% endkatex %} of the encoder output, with the predictor fed the first 2 model outputs. This gives us a probability distribution over the entire vocab, {% katex inline %}\mathcal{V}{% endkatex %}, plus blank, {% katex inline %}\varnothing{% endkatex %}. For training, we only care about the probability of the next correct token (in this case "brown") or blank, {% katex inline %}\varnothing{% endkatex %} - so we just show these two transitions in the lattice:

  • Emit blank {% katex inline %}\varnothing{% endkatex %}: moves from {% katex inline %}(t, u) \to (t+1, u){% endkatex %} - a step right along the time axis.
  • Emit the next token {% katex inline %}y_{u+1}{% endkatex %}: moves from {% katex inline %}(t, u) \to (t, u+1){% endkatex %} - a step up along the label axis.

Every valid path from bottom-left [start] to top-right [end] emits exactly the target sequence and is a different valid alignment. Different paths through this lattice correspond to different timings of the same transcription.

To get the probability of a given path/alignment we use the product of all token/blank probabilities along that path. e.g.

{% katex display %} P(\text{Path A} \mid \mathbf{X}) = P(\varnothing \mid t=0, u=0)\cdot P(\text{“the"}\mid t=1,u=0)... {% endkatex %}

Where (as described above), {% katex inline %}t{% endkatex %} is the time-frame index of the encoder output, and {% katex inline %}u{% endkatex %} is the number of transcript tokens the predictor has seen so far.

The total probability of {% katex inline %}\mathbf{y}{% endkatex %} (the correct transcription) is defined as the sum over all such paths:

{% katex display %}P(\mathbf{y} \mid \mathbf{X}) = \sum_{\mathbf{a} \in \mathcal{A}(\mathbf{y})} P(\mathbf{a} \mid \mathbf{X}){% endkatex %}

where {% katex inline %}\mathcal{A}(\mathbf{y}){% endkatex %} is the set of all valid alignments for {% katex inline %}\mathbf{y}{% endkatex %}.

This probability {% katex inline %}P(\mathbf{y} \mid \mathbf{X}){% endkatex %} is the objective we will try to maximise in training. Or more accurately, we will try to minimise the negative log-likelihood:

{% katex display %}\mathcal{L}_{\mathrm{RNNT}} = - \log P(\mathbf{y} \mid \mathbf{X}){% endkatex %}

So, our loss is completely agnostic to the alignment the model wants to use; we just want to maximise the total probability mass running through this lattice.

Now we need to efficiently calculate this loss, {% katex inline %}\mathcal{L}_{\mathrm{RNNT}}{% endkatex %} - and the relevant gradients.

RNN-T Training: The Forward-Backward Algorithm

As long as we stay on this training lattice, we will produce the correct transcript. The probability of staying on this lattice is the thing we will try to maximise - so we want to boost the chance of any transitions on this lattice (scaled by the impact they have on the final probability).

It's worth noting here that the output of the joiner is normalised, so increasing the probability of e.g. the token {% katex inline %}y_1=\text{“the"}{% endkatex %} implicitly decreases the probability of all other tokens at that node, e.g. {% katex inline %}y_1=\text{“then"}{% endkatex %}.

So how do we get this probability?

If we start from no transcript - with probability 1 (we must start with nothing yet transcribed) - we can get the chance of moving in either valid direction:

IMAGE 2

So, the chance of going up in the lattice - emitting {% katex inline %}z_{0,0}=\text{“the"}{% endkatex %} - is say {% katex inline %}0.4{% endkatex %}, and the chance of emitting {% katex inline %}z_{0,0}=\varnothing{% endkatex %} is say {% katex inline %}0.5{% endkatex %}. This is quite good: it means the chance of emitting any other (incorrect) token is only {% katex inline %}0.1{% endkatex %} - indicating a well-trained model.

Here we'll write {% katex inline %}\alpha(t,u){% endkatex %} for the probability of reaching a node (from the start). So, what's the chance of progressing further through the lattice?

IMAGE 3

To get to node {% katex inline %}(t=2, u=1){% endkatex %} - i.e. two blanks and one correct {% katex inline %}\text{“the"}{% endkatex %} token - we have three possible paths:

Path 1: "the", ∅, ∅
		(↑, →, →)
		P_1 = 0.4 * 0.1 * 0.1 = 0.004

Path 2: ∅, "the", ∅
		(→, ↑, →)
		P_2 = 0.5 * 0.5 * 0.1 = 0.025

Path 3: ∅, ∅, "the"
		(→, →, ↑)
		P_3 = 0.5 * 0.4 * 0.7 = 0.14

So the sum over all paths to node {% katex inline %}(t=2, u=1){% endkatex %} is {% katex inline %}\alpha(t=2,u=1) = 0.004 + 0.025 + 0.14 = 0.169{% endkatex %}. The remaining {% katex inline %}100\%-16.9\%=83.1\%{% endkatex %} of the time, we've already gone wrong at this stage - left the training lattice - e.g. output ["then", {% katex inline %}\varnothing{% endkatex %}, {% katex inline %}\varnothing{% endkatex %}] or [{% katex inline %}\varnothing{% endkatex %}, "apple", {% katex inline %}\varnothing{% endkatex %}].

More generally, we define {% katex inline %}\alpha{% endkatex %}, a.k.a. the forward variable, as the sum over all correct paths to a given node:

{% katex display %}\alpha(t, u) = \sum_{\text{paths from } (0,0) \text{ to } (t,u)} P(\text{path}){% endkatex %}

It's also useful to think of this as the total amount of probability mass that flows through the lattice to a given node.

Now if we enumerate all paths to the [end] node and sum the probabilities we will get the full transcript probability:

{% katex display %}P(\mathbf{y} \mid \mathbf{X}) = \alpha(T, U){% endkatex %}

The problem with this is that we will have way too many paths to enumerate. Even for the above small example, with {% katex inline %}T=8{% endkatex %} and {% katex inline %}U=4{% endkatex %} we have 330 potential paths through the lattice.
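The 330 figure can be checked with a one-line combinatorial count (assuming, as in the lattice above, that every alignment ends with a blank consuming the final frame - the `num_alignments` helper is my own):

```python
from math import comb

def num_alignments(T, U):
    # Monotonic lattice paths from (0,0) to (T,U) whose final transition
    # is a blank: choose where the U token emissions fall among the
    # first T-1+U transitions.
    return comb(T - 1 + U, U)

print(num_alignments(8, 4))  # 330
```

The count grows combinatorially - `num_alignments(100, 30)` is already astronomically large - which is exactly why brute-force enumeration is a non-starter.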

To solve this issue, we notice that the probability mass arriving at a given node depends only on the mass arriving at its immediate predecessors (i.e. one blank transition back, or one correct-token transition back):

IMAGE 4

We don't care about individual paths leading up to these predecessor nodes, just the total sum over all possible paths to them - the total probability mass that arrives there. This means we get the following:

{% katex display %}\alpha(t, u) = \alpha(t-1, u) \cdot P(\varnothing \mid t-1, u) + \alpha(t, u-1) \cdot P(y_u \mid t, u-1){% endkatex %}

with {% katex inline %}\alpha(0, 0) = 1{% endkatex %} (as the chance of starting at {% katex inline %}(0,0){% endkatex %} is {% katex inline %}100\%{% endkatex %}). Each term above says: the mass arriving at {% katex inline %}(t, u){% endkatex %} is the mass at the predecessor node, times the probability of the transition from the predecessor to {% katex inline %}(t,u){% endkatex %}. This means that we get to skip enumerating every possible path and just run through each node in the lattice with this sum - all the way to the [end] node.
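This recurrence is a short dynamic program. Here's a numpy sketch, in probability-space for readability (real implementations work in log-space); the array shapes are my own convention: `p_blank[t, u]` holds P(∅ | t, u) and `p_token[t, u]` holds P(y_{u+1} | t, u).

```python
import numpy as np

def rnnt_forward(p_blank, p_token):
    """Forward variables for the RNN-T lattice.

    p_blank: shape (T, U+1), blank transition probs at each node
    p_token: shape (T, U), correct-token transition probs (a token
             reads an encoder frame, so it's only emitted while t < T)
    Returns alpha of shape (T+1, U+1); alpha[T, U] = P(y | X).
    """
    T, U1 = p_blank.shape
    U = U1 - 1
    alpha = np.zeros((T + 1, U1))
    alpha[0, 0] = 1.0  # we always start with nothing transcribed
    for t in range(T + 1):
        for u in range(U1):
            if t > 0:            # arrive via blank from (t-1, u)
                alpha[t, u] += alpha[t - 1, u] * p_blank[t - 1, u]
            if u > 0 and t < T:  # arrive via token from (t, u-1)
                alpha[t, u] += alpha[t, u - 1] * p_token[t, u - 1]
    return alpha
```

A neat sanity check: setting every transition probability to 1 makes `alpha[T, U]` count lattice paths, and for T=8, U=4 it returns exactly the 330 alignments from the example above.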

Now that we've efficiently calculated the total probability {% katex inline %}P(\mathbf{y}\mid \mathbf{X})=\alpha(T,U){% endkatex %}, we need to calculate how much each transition affects this final sum - the gradient of {% katex inline %}P(\mathbf{y}\mid \mathbf{X}){% endkatex %} with respect to {% katex inline %}z_{t,u}{% endkatex %}. This will tell us how much to update the model weights each step. Specifically: if we make a small change in transition probabilities {% katex inline %}\partial z_{t,u}{% endkatex %}, what will be the effect on the total probability {% katex inline %}\partial P(\mathbf{y}\mid\mathbf{X}){% endkatex %}?

Gradient Calculation:

So let's work this out for a given node, e.g. the probability of a blank transition at {% katex inline %}(t=3,u=1){% endkatex %}: {% katex inline %}P(\varnothing\mid t=3, u=1){% endkatex %}. Reaching this node means the model has already output the correct first token - "the" - as well as 3 blank tokens {% katex inline %}\varnothing{% endkatex %}, in some order.

For some path "Path k" through the lattice - that goes through our transition {% katex inline %}z_{3,1}=\varnothing{% endkatex %}, we have:

{% katex display %}P(\text{Path k}) = P(v_0) \cdots P(v_{i-1}) \cdot P(\varnothing\mid t=3, u=1) \cdot P(v_{i+1}) \cdots P(v_n){% endkatex %}

where each {% katex inline %}P(v_j){% endkatex %} is one of the other transitions along this path. This means that a small change in our transition {% katex inline %}\partial P(\varnothing\mid t=3, u=1){% endkatex %} will affect the total path probability:

{% katex display %}\partial P(\text{Path k}) = P(v_0) \cdots P(v_{i-1}) \cdot \partial P(\varnothing\mid t=3, u=1) \cdot P(v_{i+1}) \cdots P(v_n){% endkatex %}

{% katex display %}\frac{\partial P(\text{Path k})}{\partial P(\varnothing\mid t=3, u=1)} = P(v_0) \cdots P(v_{i-1}) \cdot P(v_{i+1}) \cdots P(v_n){% endkatex %}

So, this is the amount that changing our transition probability will affect a path that uses it - just the total probability of the path (not including the given transition). We also observe that changing this probability won't affect paths that don't use this transition. We also know that the total probability of the correct transcript is the sum over all possible correct paths:

{% katex display %} P(\mathbf{y} \mid \mathbf{X}) = \sum_{\text{paths from } (0,0) \text{ to } (T,U)} P(\text{path}) {% endkatex %}

so naturally, the effect of changing this transition on {% katex inline %}P(\mathbf{y} \mid \mathbf{X}){% endkatex %}, is the sum of the effects it has on each relevant path:

{% katex display %} \partial P(\mathbf{y} \mid \mathbf{X}) = \sum_{\text{paths through } z_{3,1} = \varnothing} P(v_0)...\cdot \partial P(\varnothing\mid t=3,u=1)...\cdot P(v_n) {% endkatex %}

which we can split into the sum over paths that get to the transition {% katex inline %}z_{3,1}=\varnothing{% endkatex %}; the transition itself; and the sum over all the paths that leave the transition and get to the [end] state:

{% katex display %} \partial P(\mathbf{y} \mid \mathbf{X}) = \left[\sum_{\text{paths from } (0,0) \text{ to } (t=3,u=1)} P(\text{path})\right] \cdot \partial P(\varnothing\mid t=3,u=1) \cdot \left[\sum_{\text{paths from } (t=4,u=1) \text{ to } (T,U)} P(\text{path})\right] {% endkatex %}

All this is saying is: the amount that the total probability changes is the amount of probability mass that gets to a given transition {% katex inline %}\times{% endkatex %} the amount of probability mass that gets from the transition to the [end] state.

IMAGE 5

But we've already done the maths for the first part! The sum over paths that get to a given node is just {% katex inline %}\alpha(t,u){% endkatex %}, and the second part looks very similar - we'll call this the backward variable {% katex inline %}\beta(t',u'){% endkatex %}.

The backward variable {% katex inline %}\beta(t, u){% endkatex %} is the mirror image of {% katex inline %}\alpha(t,u){% endkatex %}: it represents the total probability mass that flows from node {% katex inline %}(t, u){% endkatex %} to the final state - "how much probability mass will still reach the target from here."

{% katex display %}\beta(t, u) = \sum_{\text{paths from } (t,u) \text{ to } (T,U)} P(\text{path}){% endkatex %}

Here, in a symmetric way, we set {% katex inline %}\beta(T,U)=1{% endkatex %}, but now walk backwards through the lattice:

{% katex display %}\beta(t, u) = \beta(t+1, u) \cdot P(\varnothing \mid t, u) + \beta(t, u+1) \cdot P(y_{u+1} \mid t, u){% endkatex %}

Effectively, "the amount of probability mass that will reach the target from each of the next nodes" {% katex inline %}\times{% endkatex %} "the probability of getting there from the current node".
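The mirror-image recursion can be sketched in the same style as the forward pass (probability-space, my own array-shape conventions). A useful property to check: {% katex inline %}\beta(0,0){% endkatex %} must equal the same total {% katex inline %}P(\mathbf{y}\mid\mathbf{X}){% endkatex %} that the forward pass produces at the end node.

```python
import numpy as np

def rnnt_backward(p_blank, p_token):
    """Backward variables for the RNN-T lattice.

    p_blank: shape (T, U+1); p_token: shape (T, U).
    Returns beta of shape (T+1, U+1); beta[0, 0] = P(y | X).
    """
    T, U1 = p_blank.shape
    U = U1 - 1
    beta = np.zeros((T + 1, U1))
    beta[T, U] = 1.0  # at [end], all remaining mass reaches the target
    for t in range(T, -1, -1):
        for u in range(U, -1, -1):
            if t == T and u == U:
                continue
            if t < T:            # leave via blank to (t+1, u)
                beta[t, u] += p_blank[t, u] * beta[t + 1, u]
            if u < U and t < T:  # leave via token to (t, u+1)
                beta[t, u] += p_token[t, u] * beta[t, u + 1]
    return beta
```

With all transition probabilities set to 1, `beta[0, 0]` counts lattice paths, giving the same 330 as the forward pass does for T=8, U=4.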

IMAGE 6

This simplifies life a lot for our example:

{% katex display %} \partial P(\mathbf{y} \mid \mathbf{X}) = \alpha(t=3,u=1) \cdot \partial P(\varnothing\mid t=3,u=1) \cdot \beta(t=4,u=1) {% endkatex %}

{% katex display %} \frac{\partial P(\mathbf{y} \mid \mathbf{X})}{\partial P(\varnothing\mid t=3,u=1)} = \alpha(t=3,u=1) \cdot \beta(t=4,u=1) {% endkatex %}

Or more generally:

{% katex display %} \frac{\partial P(\mathbf{y} \mid \mathbf{X})}{\partial P(v \mid t,u)} = \alpha(t,u) \cdot \beta(t',u') {% endkatex %}

Where {% katex inline %}v{% endkatex %} is a transition at a given node {% katex inline %}(t,u){% endkatex %} pointing to another node {% katex inline %}(t',u'){% endkatex %}.

The forward variable gets the mass to the transition, and the backward variable represents the mass that will eventually arrive at the target from that point. The full loss gradient normalizes by the total likelihood (see the original Graves 2012 paper for the complete derivation):

{% katex display %}\mathcal{L}_{\mathrm{RNNT}} = - \log P(\mathbf{y} \mid \mathbf{X}){% endkatex %}

{% katex display %}\frac{\partial \mathcal{L}_{\mathrm{RNNT}}}{\partial P(v \mid t, u)} = \frac{\partial \mathcal{L}_{\mathrm{RNNT}}}{\partial P(\mathbf{y} \mid \mathbf{X})} \cdot \frac{\partial P(\mathbf{y} \mid \mathbf{X})}{\partial P(v \mid t, u)} = -\frac{\alpha(t,u) \cdot \beta(t', u')}{P(\mathbf{y} \mid \mathbf{X})}{% endkatex %}

This gives a nice result! The gradient with respect to any given transition probability is just the proportion of the total probability mass that flows through that transition. Early in training, when {% katex inline %}P(\mathbf{y} \mid \mathbf{X}){% endkatex %} is small, the gradient is still significant for any correct transitions. This also explains why lattice paths tend to collapse to a small number of dominant alignments later in training - the highest-probability paths receive the largest gradients, incentivizing further path concentration.

n.b. This "path collapse" is a key insight behind k2's pruned RNN-T loss, which simplifies the gradient computation significantly by only considering paths near (in time) to the high-probability alignments.

The forward-backward algorithm computes all of this in {% katex inline %}O(T \cdot U){% endkatex %} time.


TDT Training: The Modified Forward-Backward

Training TDT requires modifying the forward-backward algorithm to account for the duration variable. The loss is still the negative log-likelihood {% katex inline %}-\log P(\mathbf{y} \mid \mathbf{X}){% endkatex %}, but the lattice transitions are now richer. Recall that in TDT, transitions can skip multiple frames:

  • Blank with duration {% katex inline %}d{% endkatex %}: {% katex inline %}(t, u) \to (t+d, u){% endkatex %} - advances by {% katex inline %}d \geq 1{% endkatex %} frames
  • Token with duration {% katex inline %}d{% endkatex %}: {% katex inline %}(t, u) \to (t+d, u+1){% endkatex %} - advances by {% katex inline %}d \geq 0{% endkatex %} frames and emits a token

Note the asymmetry: blanks must have {% katex inline %}d \geq 1{% endkatex %} (you must advance at least one frame when emitting nothing), but tokens can have {% katex inline %}d = 0{% endkatex %} if 0 is in {% katex inline %}\mathcal{D}{% endkatex %} (emitting a token without advancing - useful for fast speech or multi-token emissions at a single frame). If {% katex inline %}\mathcal{D}{% endkatex %} doesn't include 0, every emission also advances at least one frame.

IMAGE 7

We now have two independent distributions predicted from each node:

{% katex display %} P(y_u, d \mid t, u) = P_T(y_u \mid t, u) \cdot P_D(d \mid t, u) {% endkatex %}

Modified Forward Variable

The forward variable {% katex inline %}\alpha(t, u){% endkatex %} now has a more complex recurrence. At each position {% katex inline %}(t, u){% endkatex %}, we must sum over all durations that could have led here:

{% katex display %}\alpha(t, u) = \underbrace{\sum_{d \in \mathcal{D}, d \geq 1} \alpha(t-d, u) \cdot P(\varnothing, d \mid t-d, u)}_{\text{blank transitions}} + \underbrace{\sum_{d \in \mathcal{D}} \alpha(t-d, u-1) \cdot P(y_u, d \mid t-d, u-1)}_{\text{token transitions}}{% endkatex %}

The key difference from standard RNN-T: instead of looking back exactly 1 step, we look back {% katex inline %}d{% endkatex %} steps for each duration in {% katex inline %}\mathcal{D}{% endkatex %}. This makes the forward pass {% katex inline %}O(T \cdot U \cdot |\mathcal{D}|){% endkatex %} instead of {% katex inline %}O(T \cdot U){% endkatex %} - a constant factor increase since {% katex inline %}|\mathcal{D}|{% endkatex %} is typically small (4–5 elements).
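Extending the earlier forward-pass sketch with the duration sum (again probability-space and my own shape conventions; `p_dur[t, u, i]` holds P_D(durations[i] | t, u)):

```python
import numpy as np

def tdt_forward(p_blank, p_token, p_dur, durations=(0, 1, 2, 3, 4)):
    """Forward variables for the TDT lattice (a sketch, not a fused kernel).

    p_blank: shape (T, U+1); p_token: shape (T, U);
    p_dur:   shape (T, U+1, len(durations)).
    Returns alpha of shape (T+1, U+1); alpha[T, U] = P(y | X).
    """
    T, U1 = p_blank.shape
    alpha = np.zeros((T + 1, U1))
    alpha[0, 0] = 1.0
    for t in range(T + 1):
        for u in range(U1):
            for i, d in enumerate(durations):
                # blank with duration d: (t-d, u) -> (t, u), needs d >= 1
                if d >= 1 and t - d >= 0:
                    alpha[t, u] += (alpha[t - d, u]
                                    * p_blank[t - d, u]
                                    * p_dur[t - d, u, i])
                # token with duration d: (t-d, u-1) -> (t, u);
                # a token reads an encoder frame, so t - d < T
                if u >= 1 and 0 <= t - d < T:
                    alpha[t, u] += (alpha[t - d, u - 1]
                                    * p_token[t - d, u - 1]
                                    * p_dur[t - d, u - 1, i])
    return alpha
```

For a hand-checkable case (T=1, U=1, durations {0, 1}, all token/blank probs 1, uniform duration probs 0.5), the two valid paths - token with d=0 then blank with d=1, or token with d=1 directly - contribute 0.25 + 0.5 = 0.75, which the sketch reproduces.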

IMAGE 8

Backward Variable and Gradients

The backward variable {% katex inline %}\beta(t, u){% endkatex %} follows the same pattern but in reverse:

{% katex display %}\beta(t, u) = \sum_{d \in \mathcal{D}, d \geq 1} P(\varnothing, d \mid t, u) \cdot \beta(t+d, u) + \sum_{d \in \mathcal{D}} P(y_{u+1}, d \mid t, u) \cdot \beta(t+d, u+1){% endkatex %}

The gradient computation uses both {% katex inline %}\alpha{% endkatex %} and {% katex inline %}\beta{% endkatex %} in the standard way, summing over each possible duration for a given token prediction and scaling by the duration probabilities ({% katex inline %}P_D{% endkatex %}). For the token logit, the gradient at position {% katex inline %}(t, u, v){% endkatex %} is:

{% katex display %}\frac{\partial \mathcal{L}}{\partial P_T(v \mid t, u)} = -\sum_{(t', u') \in C_{t,u}} \frac{\alpha(t,u) \cdot \beta(t',u') \cdot P_D(t'-t \mid t, u)}{P(\mathbf{y} \mid \mathbf{X})}{% endkatex %}

where {% katex inline %}C_{t,u}{% endkatex %} is the set of states reachable from {% katex inline %}(t, u){% endkatex %} by emitting {% katex inline %}v{% endkatex %} - token transitions land at {% katex inline %}(t+d, u+1){% endkatex %}, blank transitions at {% katex inline %}(t+d, u){% endkatex %}:

{% katex display %}C_{t,u} = \underbrace{\lbrace (t + d,\, u+1) \mid d \in \mathcal{D} \rbrace}_{\text{next token } y_{u+1}} \;\cup\; \underbrace{\lbrace (t + d,\, u) \mid d \in \mathcal{D},\, d \geq 1 \rbrace}_{\text{blank } \varnothing}{% endkatex %}

IMAGE 9

In this case, to count the paths affected by the chance of predicting e.g. "fox", we have 4 possible lattice transitions to count, and 3 possible transitions for the blank token {% katex inline %}\varnothing{% endkatex %}.

For the duration logits, the gradient at position {% katex inline %}(t, u, d){% endkatex %} accounts for all transitions that use duration {% katex inline %}d{% endkatex %}, either the correct token or a blank transition:

{% katex display %}\frac{\partial \mathcal{L}}{\partial P_D(d_{>0} \mid t, u)} = - \frac{\alpha(t,u) \cdot \beta(t+d, u+1) \cdot P_T(y_{u+1} \mid t, u)}{P(\mathbf{y} \mid \mathbf{X})} - \frac{\alpha(t,u) \cdot \beta(t+d, u) \cdot P_T(\varnothing \mid t, u)}{P(\mathbf{y} \mid \mathbf{X})}{% endkatex %}

or for {% katex inline %}d=0{% endkatex %} (blank not allowed at zero duration):

{% katex display %}\frac{\partial \mathcal{L}}{\partial P_D(d=0 \mid t, u)} = - \frac{\alpha(t,u) \cdot \beta(t, u+1) \cdot P_T(y_{u+1} \mid t, u)}{P(\mathbf{y} \mid \mathbf{X})}{% endkatex %}

This too is somewhat intuitive. It represents the sum over all valid paths that use this duration. Now we're done! This is all the maths required to understand the efficient TDT training mechanics. For the full derivation see the TDT paper.

Some More Training Tricks

Working in Log-space: As is usual in machine learning when working with probabilities, we use log-space. Long sums of log-probabilities are much more numerically stable than long products of raw probabilities.

The Sigma Trick - Logit Under-Normalization: Every transition in the lattice, whether blank or token, gets penalized by {% katex inline %}\sigma{% endkatex %} (typically 0.05) in log-space. Since this penalty is applied per transition, paths with more steps accumulate a larger total penalty. This biases the model toward using fewer, larger-duration steps rather than many duration-1 steps.

The Omega Trick - Sampled RNN-T Loss: with probability {% katex inline %}\omega{% endkatex %}, the loss falls back to the standard RNN-T loss (ignoring durations entirely). This acts as a regularizer, ensuring the token predictions remain well-calibrated even without duration information. This is important for the batched inference case, where we will have to increment the entire batch encoder-frame by the same amount (e.g. the minimum predicted token duration).


Practical Considerations and Pitfalls

Training Memory

TDT has the same memory footprint challenge as standard RNN-T: the joint network output is a 4D tensor of shape {% katex inline %}(B, T, U+1, |\mathcal{V}| + 1 + |\mathcal{D}|){% endkatex %}. For large vocabularies and long sequences, this can be enormous. The standard mitigation is fused loss computation - instead of materializing the full joint tensor, compute the loss and gradients in a fused kernel that only materializes one {% katex inline %}(t, u){% endkatex %} slice at a time. It's also typically important to keep the vocab size small - the above example uses full words, but a smaller vocab of sub-words is usually preferable.
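To see why this tensor is a problem, a quick back-of-the-envelope with plausible (assumed, not measured) sizes:

```python
# assumed sizes: batch 32, 1000 encoder frames, 200 target tokens,
# 1024-token vocab, 5 durations, float32 logits
B, T, U, V, D = 32, 1000, 200, 1024, 5

elements = B * T * (U + 1) * (V + 1 + D)  # full joint tensor
gigabytes = elements * 4 / 1e9            # 4 bytes per float32

print(round(gigabytes, 1))  # 26.5 - for a single activation tensor
```

Tens of gigabytes for one activation tensor (before gradients) is why fused kernels and sub-word vocabs are standard practice.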

Duration Set Design

The choice of duration set {% katex inline %}\mathcal{D}{% endkatex %} matters. The paper uses {% katex inline %}\lbrace 0, 1, 2, 3, 4\rbrace{% endkatex %} as the default. Some considerations:

  • Must include 1: Duration 1 is needed to recover standard single-frame advancement. Duration 0 is optional - it allows token emission without frame advancement (useful for fast speech), but some configurations omit it.
  • Larger durations = more skipping: The model learns when to use large skips vs. small ones. In practice, the model is conservative enough that large durations don't cause problems.
  • More durations = slightly slower training: The forward-backward complexity scales linearly with {% katex inline %}|\mathcal{D}|{% endkatex %}, though with typical set sizes (4–5 elements) this is a small constant factor.

Comparison with Multi-Blank Transducer

TDT is related to but distinct from the Multi-Blank Transducer, which adds multiple blank symbols (big-blank-2, big-blank-3, etc.) that skip different numbers of frames. The key difference:

|  | Multi-Blank | TDT |
| --- | --- | --- |
| Duration prediction | Implicit (via blank type) | Explicit (separate head) |
| Token durations | Always 0 (no frame skip on token) | Variable (tokens can skip frames too) |
| Vocab size increase | {% katex inline %}\lvert \mathcal{D} \rvert{% endkatex %} blank symbols | No vocab increase; separate duration head |
| Independence | Token and duration coupled | Token and duration independently normalized |

TDT's independent normalization means the model doesn't need to use vocabulary capacity on multiple blank symbols, and the duration prediction can be more fine-grained.


Summary

TDT extends RNN-T by jointly predicting tokens and their durations. The key ideas are:

  1. Two-headed joint network: independently predict token and duration distributions
  2. Variable-stride lattice: transitions can skip multiple frames, not just one
  3. Modified forward-backward: same algorithm structure, just summing over durations at each step
  4. Training tricks: logit under-normalization ({% katex inline %}\sigma{% endkatex %}) and sampled RNN-T loss ({% katex inline %}\omega{% endkatex %}) for stable training

The result: models that are up to 2.82x faster at inference with comparable or better accuracy than standard transducers - and RNN-T was already fast to begin with. This is how Nvidia's Parakeet-TDT models dominate the RTFx column at the top of the Hugging Face leaderboard.

The NeMo toolkit has a full implementation, and pretrained Parakeet-TDT checkpoints are available on HuggingFace.


References: