diff --git a/src/sample.py b/src/sample.py index c90ed28dc404ff14b4dcd284f747d2682ee3fd65..da83bd4ce7e2fb7694bac8c1038f4fd506d96808 100644 --- a/src/sample.py +++ b/src/sample.py @@ -25,7 +25,7 @@ def top_k_logits(logits, k): def top_p_logits(logits, p): """Nucleus sampling""" batch, _ = logits.shape.as_list() - sorted_logits = tf.sort(logits, direction='DESCENDING', axis=-1) + sorted_logits = tf.contrib.framework.sort(logits, direction='DESCENDING', axis=-1) cumulative_probs = tf.cumsum(tf.nn.softmax(sorted_logits, axis=-1), axis=-1) indices = tf.stack([ tf.range(0, batch),