From cdc120d55fe1fe3cc0842f9dd5311a608dcbc7ef Mon Sep 17 00:00:00 2001 From: dronus <paul.geisler@web.de> Date: Sat, 1 Feb 2020 20:26:00 +0100 Subject: [PATCH] Fix for newer TF framework --- src/sample.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sample.py b/src/sample.py index c90ed28..da83bd4 100644 --- a/src/sample.py +++ b/src/sample.py @@ -25,7 +25,7 @@ def top_k_logits(logits, k): def top_p_logits(logits, p): """Nucleus sampling""" batch, _ = logits.shape.as_list() - sorted_logits = tf.sort(logits, direction='DESCENDING', axis=-1) + sorted_logits = tf.contrib.framework.sort(logits, direction='DESCENDING', axis=-1) cumulative_probs = tf.cumsum(tf.nn.softmax(sorted_logits, axis=-1), axis=-1) indices = tf.stack([ tf.range(0, batch), -- GitLab