Skip to content
Snippets Groups Projects
Commit 2cf46d99 authored by Ignacio Lopez-Francos's avatar Ignacio Lopez-Francos Committed by Jeff Wu
Browse files

fixed unconditional sampling reproducibility issue

parent 99af6d70
Branches
No related tags found
No related merge requests found
...@@ -17,9 +17,6 @@ def sample_model( ...@@ -17,9 +17,6 @@ def sample_model(
temperature=1, temperature=1,
top_k=0, top_k=0,
): ):
np.random.seed(seed)
tf.set_random_seed(seed)
enc = encoder.get_encoder(model_name) enc = encoder.get_encoder(model_name)
hparams = model.default_hparams() hparams = model.default_hparams()
with open(os.path.join('models', model_name, 'hparams.json')) as f: with open(os.path.join('models', model_name, 'hparams.json')) as f:
...@@ -31,6 +28,9 @@ def sample_model( ...@@ -31,6 +28,9 @@ def sample_model(
raise ValueError("Can't get samples longer than window size: %s" % hparams.n_ctx) raise ValueError("Can't get samples longer than window size: %s" % hparams.n_ctx)
with tf.Session(graph=tf.Graph()) as sess: with tf.Session(graph=tf.Graph()) as sess:
np.random.seed(seed)
tf.set_random_seed(seed)
output = sample.sample_sequence( output = sample.sample_sequence(
hparams=hparams, length=length, hparams=hparams, length=length,
start_token=enc.encoder['<|endoftext|>'], start_token=enc.encoder['<|endoftext|>'],
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment