description: A Trimmer that allocates a length budget to segments in order.
A Trimmer that allocates a length budget to segments in order.
Inherits From: Trimmer
text.WaterfallTrimmer(
max_seq_length, axis=-1
)
The trimmer first builds a mask that marks which elements to drop so that each
batch row fits within the max sequence length budget, and then applies that
mask to drop them. See generate_mask() for more details.
>>> import tensorflow as tf
>>> import tensorflow_text as tf_text
>>> a = tf.ragged.constant([['a', 'b', 'c'], [], ['d']])
>>> b = tf.ragged.constant([['1', '2', '3'], [], ['4', '5', '6', '7']])
>>> trimmer = tf_text.WaterfallTrimmer(4)
>>> trimmer.trim([a, b])
[<tf.RaggedTensor [[b'a', b'b', b'c'], [], [b'd']]>,
<tf.RaggedTensor [[b'1'], [], [b'4', b'5', b'6']]>]
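Here, for the first pair of elements, ['a', 'b', 'c'] and ['1', '2', '3'],
the first segment uses 3 of the 4 budgeted slots, so only '1' is kept from the
second segment and '2' and '3' are dropped. In the last pair, ['d'] uses 1
slot, leaving 3 for the second segment, so '7' is dropped.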
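The two steps described above (build a mask, then apply it) can be illustrated with a minimal pure-Python sketch over nested lists. This is not the library's implementation: `sketch_generate_mask` and `sketch_trim` are hypothetical helper names, plain lists stand in for `RaggedTensor`s, and a single scalar budget is assumed.

```python
def sketch_generate_mask(segments, max_seq_length):
    """Build keep/drop masks. `segments` is a list of segments, each a list of rows."""
    num_rows = len(segments[0])
    masks = [[] for _ in segments]
    for row in range(num_rows):
        remaining = max_seq_length  # budget is reset for every batch row
        for seg_index, segment in enumerate(segments):
            # Earlier segments claim budget first; later ones get what is left.
            keep = min(len(segment[row]), remaining)
            remaining -= keep
            masks[seg_index].append([i < keep for i in range(len(segment[row]))])
    return masks


def sketch_trim(segments, max_seq_length):
    """Apply the masks from sketch_generate_mask to drop the masked-out items."""
    masks = sketch_generate_mask(segments, max_seq_length)
    return [
        [[item for item, keep in zip(row, row_mask) if keep]
         for row, row_mask in zip(segment, mask)]
        for segment, mask in zip(segments, masks)
    ]


a = [['a', 'b', 'c'], [], ['d']]
b = [['1', '2', '3'], [], ['4', '5', '6', '7']]
print(sketch_trim([a, b], 4))
# [[['a', 'b', 'c'], [], ['d']], [['1'], [], ['4', '5', '6']]]
```

The printed result mirrors the `trim` output shown above, with Python strings in place of byte strings.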
generate_mask(
segments
)
Calculates a truncation mask given a per-batch budget.
Calculates a truncation mask from a budget that gives the maximum number of items to keep, either per batch row or shared by all batch rows. The budget is allocated with a 'waterfall' algorithm: quota is handed out to the segments from left to right, filling each one in turn until the budget runs out.
For example, if the budget is [5] and the segments have sizes [3, 4, 2], the budget is allocated as [3, 2, 0].
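The left-to-right allocation in that example can be written out as a short sketch. The function name `waterfall_allocate` is hypothetical and is not part of the library; it only reproduces the arithmetic for a single batch row.

```python
def waterfall_allocate(budget, segment_sizes):
    """Allocate `budget` items to segments from left to right."""
    allocation = []
    remaining = budget
    for size in segment_sizes:
        # A segment takes what it needs, capped by what is still available.
        take = min(size, remaining)
        allocation.append(take)
        remaining -= take
    return allocation


print(waterfall_allocate(5, [3, 4, 2]))  # [3, 2, 0]
```

Earlier segments are always served first, so a long first segment can leave nothing for the segments after it.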
The budget can be a scalar, in which case the same budget is broadcast
and applied to all batch rows. It can also be a 1D Tensor of size
batch_size, in which case batch row i has the budget given by element i
of that Tensor (per_batch_quota[i]).
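As a rough illustration of the difference between a scalar budget and a per-row budget, the sketch below uses plain Python lists in place of Tensors. `broadcast_budget` and `allocate_per_row` are hypothetical names introduced here, and the per-row budgets `[2, 4, 3]` are made-up values for demonstration.

```python
def broadcast_budget(budget, batch_size):
    """Return one budget per batch row, repeating a scalar if needed."""
    if isinstance(budget, int):
        return [budget] * batch_size
    budgets = list(budget)
    if len(budgets) != batch_size:
        raise ValueError("per-row budget must have one entry per batch row")
    return budgets


def allocate_per_row(budget, sizes_per_row):
    """sizes_per_row[i] lists the segment sizes for batch row i."""
    budgets = broadcast_budget(budget, len(sizes_per_row))
    result = []
    for row_budget, sizes in zip(budgets, sizes_per_row):
        remaining = row_budget
        row = []
        for size in sizes:
            take = min(size, remaining)
            row.append(take)
            remaining -= take
        result.append(row)
    return result


# Sizes of two segments in each of three batch rows.
sizes = [[3, 3], [0, 0], [1, 4]]
print(allocate_per_row(4, sizes))          # [[3, 1], [0, 0], [1, 3]]
print(allocate_per_row([2, 4, 3], sizes))  # [[2, 0], [0, 0], [1, 2]]
```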
>>> a = tf.ragged.constant([['a', 'b', 'c'], [], ['d']])
>>> b = tf.ragged.constant([['1', '2', '3'], [], ['4', '5', '6', '7']])
>>> trimmer = tf_text.WaterfallTrimmer(4)
>>> trimmer.generate_mask([a, b])
[<tf.RaggedTensor [[True, True, True], [], [True]]>,
<tf.RaggedTensor [[True, False, False], [], [True, True, True, False]]>]
| Args | |
|---|---|
| `segments` | A list of `RaggedTensor`s, each with shape [num_batch, (num_items)]. |
| Returns | |
|---|---|
| A list of `len(segments)` `RaggedTensor`s; see the superclass for details. |
trim(
segments
)
Truncate the list of segments.
Truncate the list of segments using the truncation strategy defined by
generate_mask.
| Args | |
|---|---|
| `segments` | A list of `RaggedTensor`s, each with shape [num_batch, (num_items)]. |
