Skip to content
Snippets Groups Projects
Commit cdc120d5 authored by dronus's avatar dronus
Browse files

Fix for newer TF framework

parent d98291d2
Branches dev
No related tags found
No related merge requests found
......@@ -25,7 +25,7 @@ def top_k_logits(logits, k):
def top_p_logits(logits, p):
"""Nucleus sampling"""
batch, _ = logits.shape.as_list()
sorted_logits = tf.sort(logits, direction='DESCENDING', axis=-1)
sorted_logits = tf.contrib.framework.sort(logits, direction='DESCENDING', axis=-1)
cumulative_probs = tf.cumsum(tf.nn.softmax(sorted_logits, axis=-1), axis=-1)
indices = tf.stack([
tf.range(0, batch),
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment