From cdc120d55fe1fe3cc0842f9dd5311a608dcbc7ef Mon Sep 17 00:00:00 2001
From: dronus <paul.geisler@web.de>
Date: Sat, 1 Feb 2020 20:26:00 +0100
Subject: [PATCH] Fix for newer TF framework

---
 src/sample.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/sample.py b/src/sample.py
index c90ed28..da83bd4 100644
--- a/src/sample.py
+++ b/src/sample.py
@@ -25,7 +25,7 @@ def top_k_logits(logits, k):
 def top_p_logits(logits, p):
     """Nucleus sampling"""
     batch, _ = logits.shape.as_list()
-    sorted_logits = tf.sort(logits, direction='DESCENDING', axis=-1)
+    sorted_logits = tf.contrib.framework.sort(logits, direction='DESCENDING', axis=-1)
     cumulative_probs = tf.cumsum(tf.nn.softmax(sorted_logits, axis=-1), axis=-1)
     indices = tf.stack([
         tf.range(0, batch),
-- 
GitLab