
Fix bug for KerasTensor._keras_mask should be None #16689

Merged
copybara-service[bot] merged 1 commit into keras-team:master from haifeng-jin:bug
Jun 16, 2022

Conversation

@haifeng-jin (Contributor) commented Jun 16, 2022

When converting a Tensor to a KerasTensor, if the ._keras_mask attribute is None, the attribute is converted to KerasTensor(type_spec=NoneTensorSpec()), which confuses any later code that checks whether ._keras_mask is None.
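The failure mode can be illustrated without TensorFlow. If the conversion routine wraps every attribute, including None, then downstream "is None" checks are silently defeated. The following is a minimal sketch with stand-in classes and hypothetical convert_buggy / convert_fixed helpers, not the actual Keras code:

```python
class NoneTensorSpec:
    """Stand-in for the spec Keras produces when wrapping a None value."""


class KerasTensor:
    """Stand-in for keras.engine.keras_tensor.KerasTensor."""

    def __init__(self, type_spec):
        self.type_spec = type_spec


def convert_buggy(mask):
    # Buggy behavior: a None mask is wrapped in a KerasTensor,
    # so a later `mask is None` check returns False.
    return KerasTensor(type_spec=NoneTensorSpec())


def convert_fixed(mask):
    # Fixed behavior: preserve None so downstream `is None` checks work.
    if mask is None:
        return None
    return KerasTensor(type_spec=mask)


assert convert_buggy(None) is not None  # the bug: the None check is defeated
assert convert_fixed(None) is None      # after the fix: None stays None
```

The fix amounts to short-circuiting on None before wrapping, which is what makes the RaggedTensor mask validation below stop receiving a spurious [None, None, None] mask.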

This was found via a use case proposed by @ageron:

import tensorflow as tf
from tensorflow import keras

encoder_inputs = tf.keras.layers.Input(shape=[], dtype=tf.string)
decoder_inputs = tf.keras.layers.Input(shape=[], dtype=tf.string)
embed_size = 128
vocab_size = 1000
text_vec_layer_en = tf.keras.layers.TextVectorization(vocab_size, ragged=True)
text_vec_layer_es = tf.keras.layers.TextVectorization(vocab_size, ragged=True)
encoder_input_ids = text_vec_layer_en(encoder_inputs)
decoder_input_ids = text_vec_layer_es(decoder_inputs)
encoder_embedding_layer = tf.keras.layers.Embedding(vocab_size, embed_size)
decoder_embedding_layer = tf.keras.layers.Embedding(vocab_size, embed_size)
encoder_embeddings = encoder_embedding_layer(encoder_input_ids)
decoder_embeddings = decoder_embedding_layer(decoder_input_ids)
encoder = tf.keras.layers.LSTM(512, return_state=True)
encoder_outputs, *encoder_state = encoder(encoder_embeddings)
decoder = tf.keras.layers.LSTM(512, name='bug1', return_sequences=True)
print(encoder_state[0]._keras_mask)
print(encoder_state[0]._keras_mask._to_placeholder())
print(encoder_state[1]._keras_mask)
decoder_outputs = decoder(keras.Input(shape=(32, 512)), initial_state=encoder_state)
decoder_outputs = decoder(decoder_embeddings, initial_state=encoder_state)

The prints show the spurious wrapped masks, and the second decoder call then fails:

KerasTensor(type_spec=NoneTensorSpec())
None
KerasTensor(type_spec=NoneTensorSpec())
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-7-8c078ef42dc5> in <module>()
     32 print(encoder_state[1]._keras_mask)
     33 decoder_outputs = decoder(keras.Input(shape=(32, 512)), initial_state=encoder_state)
---> 34 decoder_outputs = decoder(decoder_embeddings, initial_state=encoder_state)

2 frames
/usr/local/lib/python3.7/dist-packages/keras/layers/recurrent.py in _validate_args_if_ragged(self, is_ragged_input, mask)
    899 
    900     if mask is not None:
--> 901       raise ValueError(f'The mask that was passed in was {mask}, which '
    902                        'cannot be applied to RaggedTensor inputs. Please '
    903                        'make sure that there is no mask injected by upstream '

ValueError: Exception encountered when calling layer "bug1" (type LSTM).

The mask that was passed in was [None, None, None], which cannot be applied to RaggedTensor inputs. Please make sure that there is no mask injected by upstream layers.

Call arguments received:
  • inputs=['tf.RaggedTensor(values=Tensor("Placeholder:0", shape=(None, 128), dtype=float32), row_splits=Tensor("Placeholder_1:0", shape=(None,), dtype=int64))', 'tf.Tensor(shape=(None, 512), dtype=float32)', 'tf.Tensor(shape=(None, 512), dtype=float32)']
  • mask=['None', 'None', 'None']
  • training=None
  • initial_state=None

@haifeng-jin added the "ready to pull" (Ready to be merged into the codebase) label Jun 16, 2022
@gbaned requested a review from qlzh727 June 16, 2022 11:05
@google-ml-butler bot added the "keras-team-review-pending" (Pending review by a Keras team member) label Jun 16, 2022
@mattdangerw self-requested a review June 16, 2022 15:34
@haifeng-jin removed the "keras-team-review-pending" label Jun 16, 2022
@fchollet (Collaborator) left a comment


LGTM, thanks

@copybara-service[bot] merged commit d38de48 into keras-team:master Jun 16, 2022
@haifeng-jin deleted the bug branch September 7, 2022 20:44

Labels

ready to pull (Ready to be merged into the codebase), size:S
