diff --git a/src/sample.py b/src/sample.py
index c90ed28dc404ff14b4dcd284f747d2682ee3fd65..da83bd4ce7e2fb7694bac8c1038f4fd506d96808 100644
--- a/src/sample.py
+++ b/src/sample.py
@@ -25,7 +25,7 @@ def top_k_logits(logits, k):
 def top_p_logits(logits, p):
     """Nucleus sampling"""
     batch, _ = logits.shape.as_list()
-    sorted_logits = tf.sort(logits, direction='DESCENDING', axis=-1)
+    sorted_logits = tf.contrib.framework.sort(logits, direction='DESCENDING', axis=-1)
     cumulative_probs = tf.cumsum(tf.nn.softmax(sorted_logits, axis=-1), axis=-1)
     indices = tf.stack([
         tf.range(0, batch),