Skip to content
Snippets Groups Projects
Commit ac5d5229 authored by Jeff Wu
Browse files

nucleus sampling

parent f35fa1d9
No related branches found
No related tags found
No related merge requests found
......@@ -16,6 +16,7 @@ def sample_model(
length=None,
temperature=1,
top_k=0,
top_p=1,
models_dir='models',
):
"""
......@@ -58,7 +59,7 @@ def sample_model(
hparams=hparams, length=length,
start_token=enc.encoder['<|endoftext|>'],
batch_size=batch_size,
temperature=temperature, top_k=top_k
temperature=temperature, top_k=top_k, top_p=top_p
)[:, 1:]
saver = tf.train.Saver()
......
......@@ -16,6 +16,7 @@ def interact_model(
length=None,
temperature=1,
top_k=0,
top_p=1,
models_dir='models',
):
"""
......@@ -61,7 +62,7 @@ def interact_model(
hparams=hparams, length=length,
context=context,
batch_size=batch_size,
temperature=temperature, top_k=top_k
temperature=temperature, top_k=top_k, top_p=top_p
)
saver = tf.train.Saver()
......
......@@ -22,7 +22,25 @@ def top_k_logits(logits, k):
)
def sample_sequence(*, hparams, length, start_token=None, batch_size=None, context=None, temperature=1, top_k=0):
def top_p_logits(logits, p):
    """Nucleus (top-p) sampling filter.

    For each batch row, keeps the smallest set of highest-probability
    tokens whose cumulative probability stays within ``p`` and replaces
    every other logit with -1e10, so a subsequent softmax/multinomial
    effectively never samples the masked tokens.

    Args:
        logits: 2-D float tensor of shape [batch, vocab] — rank 2 is
            required by the ``batch, _ = logits.shape.as_list()`` unpack,
            and batch must be statically known.
        p: cumulative-probability threshold; with p == 1 every token's
            cumulative probability satisfies ``<= p`` and nothing is masked.

    Returns:
        Tensor of the same shape as ``logits`` with out-of-nucleus
        entries set to -1e10.
    """
    batch, _ = logits.shape.as_list()
    # Sort each row descending so the nucleus is a prefix of the row.
    sorted_logits = tf.sort(logits, direction='DESCENDING', axis=-1)
    # Running total of the sorted probability mass, per row.
    cumulative_probs = tf.cumsum(tf.nn.softmax(sorted_logits, axis=-1), axis=-1)
    # Build [batch, 2] coordinates (row, cutoff_position) for gather_nd.
    # The tf.maximum(..., 0) clamp guarantees at least the single most
    # likely token survives even when the first probability already
    # exceeds p.
    indices = tf.stack([
        tf.range(0, batch),
        # number of indices to include
        tf.maximum(tf.reduce_sum(tf.cast(cumulative_probs <= p, tf.int32), axis=-1) - 1, 0),
    ], axis=-1)
    # Per-row value of the smallest logit still inside the nucleus.
    min_values = tf.gather_nd(sorted_logits, indices)
    # NOTE(review): min_values has shape [batch] while logits is
    # [batch, vocab]; the comparison appears to rely on broadcasting —
    # confirm against the TF version in use.
    # Mask everything strictly below the per-row cutoff; -1e10 is an
    # effective -inf for float32 softmax.
    return tf.where(
        logits < min_values,
        tf.ones_like(logits) * -1e10,
        logits,
    )
def sample_sequence(*, hparams, length, start_token=None, batch_size=None, context=None, temperature=1, top_k=0, top_p=1):
if start_token is None:
assert context is not None, 'Specify exactly one of start_token and context!'
else:
......@@ -45,6 +63,7 @@ def sample_sequence(*, hparams, length, start_token=None, batch_size=None, conte
next_outputs = step(hparams, prev, past=past)
logits = next_outputs['logits'][:, -1, :] / tf.to_float(temperature)
logits = top_k_logits(logits, k=top_k)
logits = top_p_logits(logits, p=top_p)
samples = tf.multinomial(logits, num_samples=1, output_dtype=tf.int32)
return [
next_outputs['presents'] if past is None else tf.concat([past, next_outputs['presents']], axis=-2),
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment