Add
This commit is contained in:
parent
a0080d16af
commit
58a53148d4
116
.gitignore
vendored
Normal file
116
.gitignore
vendored
Normal file
@ -0,0 +1,116 @@
|
|||||||
|
# Initially taken from Github's Python gitignore file
|
||||||
|
|
||||||
|
# Byte-compiled / optimized / DLL files
|
||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*$py.class
|
||||||
|
|
||||||
|
# C extensions
|
||||||
|
*.so
|
||||||
|
|
||||||
|
# Distribution / packaging
|
||||||
|
.Python
|
||||||
|
build/
|
||||||
|
develop-eggs/
|
||||||
|
dist/
|
||||||
|
downloads/
|
||||||
|
eggs/
|
||||||
|
.eggs/
|
||||||
|
lib/
|
||||||
|
lib64/
|
||||||
|
parts/
|
||||||
|
sdist/
|
||||||
|
var/
|
||||||
|
wheels/
|
||||||
|
*.egg-info/
|
||||||
|
.installed.cfg
|
||||||
|
*.egg
|
||||||
|
MANIFEST
|
||||||
|
|
||||||
|
# PyInstaller
|
||||||
|
# Usually these files are written by a python script from a template
|
||||||
|
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||||
|
*.manifest
|
||||||
|
*.spec
|
||||||
|
|
||||||
|
# Installer logs
|
||||||
|
pip-log.txt
|
||||||
|
pip-delete-this-directory.txt
|
||||||
|
|
||||||
|
# Unit test / coverage reports
|
||||||
|
htmlcov/
|
||||||
|
.tox/
|
||||||
|
.nox/
|
||||||
|
.coverage
|
||||||
|
.coverage.*
|
||||||
|
.cache
|
||||||
|
nosetests.xml
|
||||||
|
coverage.xml
|
||||||
|
*.cover
|
||||||
|
.hypothesis/
|
||||||
|
.pytest_cache/
|
||||||
|
|
||||||
|
# Translations
|
||||||
|
*.mo
|
||||||
|
*.pot
|
||||||
|
|
||||||
|
# Django stuff:
|
||||||
|
*.log
|
||||||
|
local_settings.py
|
||||||
|
db.sqlite3
|
||||||
|
|
||||||
|
# Flask stuff:
|
||||||
|
instance/
|
||||||
|
.webassets-cache
|
||||||
|
|
||||||
|
# Scrapy stuff:
|
||||||
|
.scrapy
|
||||||
|
|
||||||
|
# Sphinx documentation
|
||||||
|
docs/_build/
|
||||||
|
|
||||||
|
# PyBuilder
|
||||||
|
target/
|
||||||
|
|
||||||
|
# Jupyter Notebook
|
||||||
|
.ipynb_checkpoints
|
||||||
|
|
||||||
|
# IPython
|
||||||
|
profile_default/
|
||||||
|
ipython_config.py
|
||||||
|
|
||||||
|
# pyenv
|
||||||
|
.python-version
|
||||||
|
|
||||||
|
# celery beat schedule file
|
||||||
|
celerybeat-schedule
|
||||||
|
|
||||||
|
# SageMath parsed files
|
||||||
|
*.sage.py
|
||||||
|
|
||||||
|
# Environments
|
||||||
|
.env
|
||||||
|
.venv
|
||||||
|
env/
|
||||||
|
venv/
|
||||||
|
ENV/
|
||||||
|
env.bak/
|
||||||
|
venv.bak/
|
||||||
|
|
||||||
|
# Spyder project settings
|
||||||
|
.spyderproject
|
||||||
|
.spyproject
|
||||||
|
|
||||||
|
# Rope project settings
|
||||||
|
.ropeproject
|
||||||
|
|
||||||
|
# mkdocs documentation
|
||||||
|
/site
|
||||||
|
|
||||||
|
# mypy
|
||||||
|
.mypy_cache/
|
||||||
|
.dmypy.json
|
||||||
|
dmypy.json
|
||||||
|
|
||||||
|
# Pyre type checker
|
||||||
|
.pyre/
|
5
.idea/.gitignore
generated
vendored
Normal file
5
.idea/.gitignore
generated
vendored
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
|
||||||
|
# Default ignored files
|
||||||
|
/workspace.xml
|
||||||
|
# Datasource local storage ignored files
|
||||||
|
/dataSources.local.xml
|
11
.idea/bert.iml
generated
Normal file
11
.idea/bert.iml
generated
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<module type="PYTHON_MODULE" version="4">
|
||||||
|
<component name="NewModuleRootManager">
|
||||||
|
<content url="file://$MODULE_DIR$" />
|
||||||
|
<orderEntry type="jdk" jdkName="Python 3.6 (tensorflow-gpu-bertenv)" jdkType="Python SDK" />
|
||||||
|
<orderEntry type="sourceFolder" forTests="false" />
|
||||||
|
</component>
|
||||||
|
<component name="TestRunnerService">
|
||||||
|
<option name="PROJECT_TEST_RUNNER" value="Unittests" />
|
||||||
|
</component>
|
||||||
|
</module>
|
11
.idea/dataSources.xml
generated
Normal file
11
.idea/dataSources.xml
generated
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="DataSourceManagerImpl" format="xml" multifile-model="true">
|
||||||
|
<data-source source="LOCAL" name="bptdata" uuid="3fe6f6a7-5596-43ed-98b3-5f9a30a24452">
|
||||||
|
<driver-ref>sqlite.xerial</driver-ref>
|
||||||
|
<synchronize>true</synchronize>
|
||||||
|
<jdbc-driver>org.sqlite.JDBC</jdbc-driver>
|
||||||
|
<jdbc-url>jdbc:sqlite:C:\Users\Administrator\Documents\GitHub\bert\bptdata.db</jdbc-url>
|
||||||
|
</data-source>
|
||||||
|
</component>
|
||||||
|
</project>
|
22
.idea/dictionaries/Administrator.xml
generated
Normal file
22
.idea/dictionaries/Administrator.xml
generated
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
<component name="ProjectDictionaryState">
|
||||||
|
<dictionary name="Administrator">
|
||||||
|
<words>
|
||||||
|
<w>amki</w>
|
||||||
|
<w>asctime</w>
|
||||||
|
<w>badrequest</w>
|
||||||
|
<w>bptdata</w>
|
||||||
|
<w>codedream</w>
|
||||||
|
<w>epaper</w>
|
||||||
|
<w>epout</w>
|
||||||
|
<w>eppdt</w>
|
||||||
|
<w>eppdtout</w>
|
||||||
|
<w>eppredict</w>
|
||||||
|
<w>idcode</w>
|
||||||
|
<w>levelname</w>
|
||||||
|
<w>nlpdata</w>
|
||||||
|
<w>sckstn</w>
|
||||||
|
<w>stnid</w>
|
||||||
|
<w>stns</w>
|
||||||
|
</words>
|
||||||
|
</dictionary>
|
||||||
|
</component>
|
6
.idea/inspectionProfiles/profiles_settings.xml
generated
Normal file
6
.idea/inspectionProfiles/profiles_settings.xml
generated
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
<component name="InspectionProjectProfileManager">
|
||||||
|
<settings>
|
||||||
|
<option name="USE_PROJECT_PROFILE" value="false" />
|
||||||
|
<version value="1.0" />
|
||||||
|
</settings>
|
||||||
|
</component>
|
13
.idea/misc.xml
generated
Normal file
13
.idea/misc.xml
generated
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="JavaScriptSettings">
|
||||||
|
<option name="languageLevel" value="ES6" />
|
||||||
|
</component>
|
||||||
|
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.6 (tensorflow-gpu-bertenv)" project-jdk-type="Python SDK" />
|
||||||
|
<component name="PyPackaging">
|
||||||
|
<option name="earlyReleasesAsUpgrades" value="true" />
|
||||||
|
</component>
|
||||||
|
<component name="PythonCompatibilityInspectionAdvertiser">
|
||||||
|
<option name="version" value="3" />
|
||||||
|
</component>
|
||||||
|
</project>
|
8
.idea/modules.xml
generated
Normal file
8
.idea/modules.xml
generated
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="ProjectModuleManager">
|
||||||
|
<modules>
|
||||||
|
<module fileurl="file://$PROJECT_DIR$/.idea/bert.iml" filepath="$PROJECT_DIR$/.idea/bert.iml" />
|
||||||
|
</modules>
|
||||||
|
</component>
|
||||||
|
</project>
|
7
.idea/other.xml
generated
Normal file
7
.idea/other.xml
generated
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="PySciProjectComponent">
|
||||||
|
<option name="PY_SCI_VIEW" value="true" />
|
||||||
|
<option name="PY_SCI_VIEW_SUGGESTED" value="true" />
|
||||||
|
</component>
|
||||||
|
</project>
|
7
.idea/sqldialects.xml
generated
Normal file
7
.idea/sqldialects.xml
generated
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="SqlDialectMappings">
|
||||||
|
<file url="file://$PROJECT_DIR$/server.py" dialect="GenericSQL" />
|
||||||
|
<file url="PROJECT" dialect="SQLite" />
|
||||||
|
</component>
|
||||||
|
</project>
|
6
.idea/vcs.xml
generated
Normal file
6
.idea/vcs.xml
generated
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="VcsDirectoryMappings">
|
||||||
|
<mapping directory="$PROJECT_DIR$" vcs="Git" />
|
||||||
|
</component>
|
||||||
|
</project>
|
15
__init__.py
Normal file
15
__init__.py
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2018 The Google AI Language Team Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
BIN
bptdata.db
Normal file
BIN
bptdata.db
Normal file
Binary file not shown.
19
chinese_wwm_ext_L-12_H-768_A-12/bert_config.json
Normal file
19
chinese_wwm_ext_L-12_H-768_A-12/bert_config.json
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
{
|
||||||
|
"attention_probs_dropout_prob": 0.1,
|
||||||
|
"directionality": "bidi",
|
||||||
|
"hidden_act": "gelu",
|
||||||
|
"hidden_dropout_prob": 0.1,
|
||||||
|
"hidden_size": 768,
|
||||||
|
"initializer_range": 0.02,
|
||||||
|
"intermediate_size": 3072,
|
||||||
|
"max_position_embeddings": 512,
|
||||||
|
"num_attention_heads": 12,
|
||||||
|
"num_hidden_layers": 12,
|
||||||
|
"pooler_fc_size": 768,
|
||||||
|
"pooler_num_attention_heads": 12,
|
||||||
|
"pooler_num_fc_layers": 3,
|
||||||
|
"pooler_size_per_head": 128,
|
||||||
|
"pooler_type": "first_token_transform",
|
||||||
|
"type_vocab_size": 2,
|
||||||
|
"vocab_size": 21128
|
||||||
|
}
|
BIN
chinese_wwm_ext_L-12_H-768_A-12/bert_model.ckpt.index
Normal file
BIN
chinese_wwm_ext_L-12_H-768_A-12/bert_model.ckpt.index
Normal file
Binary file not shown.
BIN
chinese_wwm_ext_L-12_H-768_A-12/bert_model.ckpt.meta
Normal file
BIN
chinese_wwm_ext_L-12_H-768_A-12/bert_model.ckpt.meta
Normal file
Binary file not shown.
21128
chinese_wwm_ext_L-12_H-768_A-12/vocab.txt
Normal file
21128
chinese_wwm_ext_L-12_H-768_A-12/vocab.txt
Normal file
File diff suppressed because it is too large
Load Diff
469
create_pretraining_data.py
Normal file
469
create_pretraining_data.py
Normal file
@ -0,0 +1,469 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2018 The Google AI Language Team Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Create masked LM/next sentence masked_lm TF examples for BERT."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import collections
|
||||||
|
import random
|
||||||
|
import tokenization
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
flags = tf.flags
|
||||||
|
|
||||||
|
FLAGS = flags.FLAGS
|
||||||
|
|
||||||
|
flags.DEFINE_string("input_file", None,
|
||||||
|
"Input raw text file (or comma-separated list of files).")
|
||||||
|
|
||||||
|
flags.DEFINE_string(
|
||||||
|
"output_file", None,
|
||||||
|
"Output TF example file (or comma-separated list of files).")
|
||||||
|
|
||||||
|
flags.DEFINE_string("vocab_file", None,
|
||||||
|
"The vocabulary file that the BERT model was trained on.")
|
||||||
|
|
||||||
|
flags.DEFINE_bool(
|
||||||
|
"do_lower_case", True,
|
||||||
|
"Whether to lower case the input text. Should be True for uncased "
|
||||||
|
"models and False for cased models.")
|
||||||
|
|
||||||
|
flags.DEFINE_bool(
|
||||||
|
"do_whole_word_mask", False,
|
||||||
|
"Whether to use whole word masking rather than per-WordPiece masking.")
|
||||||
|
|
||||||
|
flags.DEFINE_integer("max_seq_length", 128, "Maximum sequence length.")
|
||||||
|
|
||||||
|
flags.DEFINE_integer("max_predictions_per_seq", 20,
|
||||||
|
"Maximum number of masked LM predictions per sequence.")
|
||||||
|
|
||||||
|
flags.DEFINE_integer("random_seed", 12345, "Random seed for data generation.")
|
||||||
|
|
||||||
|
flags.DEFINE_integer(
|
||||||
|
"dupe_factor", 10,
|
||||||
|
"Number of times to duplicate the input data (with different masks).")
|
||||||
|
|
||||||
|
flags.DEFINE_float("masked_lm_prob", 0.15, "Masked LM probability.")
|
||||||
|
|
||||||
|
flags.DEFINE_float(
|
||||||
|
"short_seq_prob", 0.1,
|
||||||
|
"Probability of creating sequences which are shorter than the "
|
||||||
|
"maximum length.")
|
||||||
|
|
||||||
|
|
||||||
|
class TrainingInstance(object):
|
||||||
|
"""A single training instance (sentence pair)."""
|
||||||
|
|
||||||
|
def __init__(self, tokens, segment_ids, masked_lm_positions, masked_lm_labels,
|
||||||
|
is_random_next):
|
||||||
|
self.tokens = tokens
|
||||||
|
self.segment_ids = segment_ids
|
||||||
|
self.is_random_next = is_random_next
|
||||||
|
self.masked_lm_positions = masked_lm_positions
|
||||||
|
self.masked_lm_labels = masked_lm_labels
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
s = ""
|
||||||
|
s += "tokens: %s\n" % (" ".join(
|
||||||
|
[tokenization.printable_text(x) for x in self.tokens]))
|
||||||
|
s += "segment_ids: %s\n" % (" ".join([str(x) for x in self.segment_ids]))
|
||||||
|
s += "is_random_next: %s\n" % self.is_random_next
|
||||||
|
s += "masked_lm_positions: %s\n" % (" ".join(
|
||||||
|
[str(x) for x in self.masked_lm_positions]))
|
||||||
|
s += "masked_lm_labels: %s\n" % (" ".join(
|
||||||
|
[tokenization.printable_text(x) for x in self.masked_lm_labels]))
|
||||||
|
s += "\n"
|
||||||
|
return s
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return self.__str__()
|
||||||
|
|
||||||
|
|
||||||
|
def write_instance_to_example_files(instances, tokenizer, max_seq_length,
|
||||||
|
max_predictions_per_seq, output_files):
|
||||||
|
"""Create TF example files from `TrainingInstance`s."""
|
||||||
|
writers = []
|
||||||
|
for output_file in output_files:
|
||||||
|
writers.append(tf.python_io.TFRecordWriter(output_file))
|
||||||
|
|
||||||
|
writer_index = 0
|
||||||
|
|
||||||
|
total_written = 0
|
||||||
|
for (inst_index, instance) in enumerate(instances):
|
||||||
|
input_ids = tokenizer.convert_tokens_to_ids(instance.tokens)
|
||||||
|
input_mask = [1] * len(input_ids)
|
||||||
|
segment_ids = list(instance.segment_ids)
|
||||||
|
assert len(input_ids) <= max_seq_length
|
||||||
|
|
||||||
|
while len(input_ids) < max_seq_length:
|
||||||
|
input_ids.append(0)
|
||||||
|
input_mask.append(0)
|
||||||
|
segment_ids.append(0)
|
||||||
|
|
||||||
|
assert len(input_ids) == max_seq_length
|
||||||
|
assert len(input_mask) == max_seq_length
|
||||||
|
assert len(segment_ids) == max_seq_length
|
||||||
|
|
||||||
|
masked_lm_positions = list(instance.masked_lm_positions)
|
||||||
|
masked_lm_ids = tokenizer.convert_tokens_to_ids(instance.masked_lm_labels)
|
||||||
|
masked_lm_weights = [1.0] * len(masked_lm_ids)
|
||||||
|
|
||||||
|
while len(masked_lm_positions) < max_predictions_per_seq:
|
||||||
|
masked_lm_positions.append(0)
|
||||||
|
masked_lm_ids.append(0)
|
||||||
|
masked_lm_weights.append(0.0)
|
||||||
|
|
||||||
|
next_sentence_label = 1 if instance.is_random_next else 0
|
||||||
|
|
||||||
|
features = collections.OrderedDict()
|
||||||
|
features["input_ids"] = create_int_feature(input_ids)
|
||||||
|
features["input_mask"] = create_int_feature(input_mask)
|
||||||
|
features["segment_ids"] = create_int_feature(segment_ids)
|
||||||
|
features["masked_lm_positions"] = create_int_feature(masked_lm_positions)
|
||||||
|
features["masked_lm_ids"] = create_int_feature(masked_lm_ids)
|
||||||
|
features["masked_lm_weights"] = create_float_feature(masked_lm_weights)
|
||||||
|
features["next_sentence_labels"] = create_int_feature([next_sentence_label])
|
||||||
|
|
||||||
|
tf_example = tf.train.Example(features=tf.train.Features(feature=features))
|
||||||
|
|
||||||
|
writers[writer_index].write(tf_example.SerializeToString())
|
||||||
|
writer_index = (writer_index + 1) % len(writers)
|
||||||
|
|
||||||
|
total_written += 1
|
||||||
|
|
||||||
|
if inst_index < 20:
|
||||||
|
tf.logging.info("*** Example ***")
|
||||||
|
tf.logging.info("tokens: %s" % " ".join(
|
||||||
|
[tokenization.printable_text(x) for x in instance.tokens]))
|
||||||
|
|
||||||
|
for feature_name in features.keys():
|
||||||
|
feature = features[feature_name]
|
||||||
|
values = []
|
||||||
|
if feature.int64_list.value:
|
||||||
|
values = feature.int64_list.value
|
||||||
|
elif feature.float_list.value:
|
||||||
|
values = feature.float_list.value
|
||||||
|
tf.logging.info(
|
||||||
|
"%s: %s" % (feature_name, " ".join([str(x) for x in values])))
|
||||||
|
|
||||||
|
for writer in writers:
|
||||||
|
writer.close()
|
||||||
|
|
||||||
|
tf.logging.info("Wrote %d total instances", total_written)
|
||||||
|
|
||||||
|
|
||||||
|
def create_int_feature(values):
|
||||||
|
feature = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
|
||||||
|
return feature
|
||||||
|
|
||||||
|
|
||||||
|
def create_float_feature(values):
|
||||||
|
feature = tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))
|
||||||
|
return feature
|
||||||
|
|
||||||
|
|
||||||
|
def create_training_instances(input_files, tokenizer, max_seq_length,
|
||||||
|
dupe_factor, short_seq_prob, masked_lm_prob,
|
||||||
|
max_predictions_per_seq, rng):
|
||||||
|
"""Create `TrainingInstance`s from raw text."""
|
||||||
|
all_documents = [[]]
|
||||||
|
|
||||||
|
# Input file format:
|
||||||
|
# (1) One sentence per line. These should ideally be actual sentences, not
|
||||||
|
# entire paragraphs or arbitrary spans of text. (Because we use the
|
||||||
|
# sentence boundaries for the "next sentence prediction" task).
|
||||||
|
# (2) Blank lines between documents. Document boundaries are needed so
|
||||||
|
# that the "next sentence prediction" task doesn't span between documents.
|
||||||
|
for input_file in input_files:
|
||||||
|
with tf.gfile.GFile(input_file, "r") as reader:
|
||||||
|
while True:
|
||||||
|
line = tokenization.convert_to_unicode(reader.readline())
|
||||||
|
if not line:
|
||||||
|
break
|
||||||
|
line = line.strip()
|
||||||
|
|
||||||
|
# Empty lines are used as document delimiters
|
||||||
|
if not line:
|
||||||
|
all_documents.append([])
|
||||||
|
tokens = tokenizer.tokenize(line)
|
||||||
|
if tokens:
|
||||||
|
all_documents[-1].append(tokens)
|
||||||
|
|
||||||
|
# Remove empty documents
|
||||||
|
all_documents = [x for x in all_documents if x]
|
||||||
|
rng.shuffle(all_documents)
|
||||||
|
|
||||||
|
vocab_words = list(tokenizer.vocab.keys())
|
||||||
|
instances = []
|
||||||
|
for _ in range(dupe_factor):
|
||||||
|
for document_index in range(len(all_documents)):
|
||||||
|
instances.extend(
|
||||||
|
create_instances_from_document(
|
||||||
|
all_documents, document_index, max_seq_length, short_seq_prob,
|
||||||
|
masked_lm_prob, max_predictions_per_seq, vocab_words, rng))
|
||||||
|
|
||||||
|
rng.shuffle(instances)
|
||||||
|
return instances
|
||||||
|
|
||||||
|
|
||||||
|
def create_instances_from_document(
|
||||||
|
all_documents, document_index, max_seq_length, short_seq_prob,
|
||||||
|
masked_lm_prob, max_predictions_per_seq, vocab_words, rng):
|
||||||
|
"""Creates `TrainingInstance`s for a single document."""
|
||||||
|
document = all_documents[document_index]
|
||||||
|
|
||||||
|
# Account for [CLS], [SEP], [SEP]
|
||||||
|
max_num_tokens = max_seq_length - 3
|
||||||
|
|
||||||
|
# We *usually* want to fill up the entire sequence since we are padding
|
||||||
|
# to `max_seq_length` anyways, so short sequences are generally wasted
|
||||||
|
# computation. However, we *sometimes*
|
||||||
|
# (i.e., short_seq_prob == 0.1 == 10% of the time) want to use shorter
|
||||||
|
# sequences to minimize the mismatch between pre-training and fine-tuning.
|
||||||
|
# The `target_seq_length` is just a rough target however, whereas
|
||||||
|
# `max_seq_length` is a hard limit.
|
||||||
|
target_seq_length = max_num_tokens
|
||||||
|
if rng.random() < short_seq_prob:
|
||||||
|
target_seq_length = rng.randint(2, max_num_tokens)
|
||||||
|
|
||||||
|
# We DON'T just concatenate all of the tokens from a document into a long
|
||||||
|
# sequence and choose an arbitrary split point because this would make the
|
||||||
|
# next sentence prediction task too easy. Instead, we split the input into
|
||||||
|
# segments "A" and "B" based on the actual "sentences" provided by the user
|
||||||
|
# input.
|
||||||
|
instances = []
|
||||||
|
current_chunk = []
|
||||||
|
current_length = 0
|
||||||
|
i = 0
|
||||||
|
while i < len(document):
|
||||||
|
segment = document[i]
|
||||||
|
current_chunk.append(segment)
|
||||||
|
current_length += len(segment)
|
||||||
|
if i == len(document) - 1 or current_length >= target_seq_length:
|
||||||
|
if current_chunk:
|
||||||
|
# `a_end` is how many segments from `current_chunk` go into the `A`
|
||||||
|
# (first) sentence.
|
||||||
|
a_end = 1
|
||||||
|
if len(current_chunk) >= 2:
|
||||||
|
a_end = rng.randint(1, len(current_chunk) - 1)
|
||||||
|
|
||||||
|
tokens_a = []
|
||||||
|
for j in range(a_end):
|
||||||
|
tokens_a.extend(current_chunk[j])
|
||||||
|
|
||||||
|
tokens_b = []
|
||||||
|
# Random next
|
||||||
|
is_random_next = False
|
||||||
|
if len(current_chunk) == 1 or rng.random() < 0.5:
|
||||||
|
is_random_next = True
|
||||||
|
target_b_length = target_seq_length - len(tokens_a)
|
||||||
|
|
||||||
|
# This should rarely go for more than one iteration for large
|
||||||
|
# corpora. However, just to be careful, we try to make sure that
|
||||||
|
# the random document is not the same as the document
|
||||||
|
# we're processing.
|
||||||
|
for _ in range(10):
|
||||||
|
random_document_index = rng.randint(0, len(all_documents) - 1)
|
||||||
|
if random_document_index != document_index:
|
||||||
|
break
|
||||||
|
|
||||||
|
random_document = all_documents[random_document_index]
|
||||||
|
random_start = rng.randint(0, len(random_document) - 1)
|
||||||
|
for j in range(random_start, len(random_document)):
|
||||||
|
tokens_b.extend(random_document[j])
|
||||||
|
if len(tokens_b) >= target_b_length:
|
||||||
|
break
|
||||||
|
# We didn't actually use these segments so we "put them back" so
|
||||||
|
# they don't go to waste.
|
||||||
|
num_unused_segments = len(current_chunk) - a_end
|
||||||
|
i -= num_unused_segments
|
||||||
|
# Actual next
|
||||||
|
else:
|
||||||
|
is_random_next = False
|
||||||
|
for j in range(a_end, len(current_chunk)):
|
||||||
|
tokens_b.extend(current_chunk[j])
|
||||||
|
truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng)
|
||||||
|
|
||||||
|
assert len(tokens_a) >= 1
|
||||||
|
assert len(tokens_b) >= 1
|
||||||
|
|
||||||
|
tokens = []
|
||||||
|
segment_ids = []
|
||||||
|
tokens.append("[CLS]")
|
||||||
|
segment_ids.append(0)
|
||||||
|
for token in tokens_a:
|
||||||
|
tokens.append(token)
|
||||||
|
segment_ids.append(0)
|
||||||
|
|
||||||
|
tokens.append("[SEP]")
|
||||||
|
segment_ids.append(0)
|
||||||
|
|
||||||
|
for token in tokens_b:
|
||||||
|
tokens.append(token)
|
||||||
|
segment_ids.append(1)
|
||||||
|
tokens.append("[SEP]")
|
||||||
|
segment_ids.append(1)
|
||||||
|
|
||||||
|
(tokens, masked_lm_positions,
|
||||||
|
masked_lm_labels) = create_masked_lm_predictions(
|
||||||
|
tokens, masked_lm_prob, max_predictions_per_seq, vocab_words, rng)
|
||||||
|
instance = TrainingInstance(
|
||||||
|
tokens=tokens,
|
||||||
|
segment_ids=segment_ids,
|
||||||
|
is_random_next=is_random_next,
|
||||||
|
masked_lm_positions=masked_lm_positions,
|
||||||
|
masked_lm_labels=masked_lm_labels)
|
||||||
|
instances.append(instance)
|
||||||
|
current_chunk = []
|
||||||
|
current_length = 0
|
||||||
|
i += 1
|
||||||
|
|
||||||
|
return instances
|
||||||
|
|
||||||
|
|
||||||
|
MaskedLmInstance = collections.namedtuple("MaskedLmInstance",
|
||||||
|
["index", "label"])
|
||||||
|
|
||||||
|
|
||||||
|
def create_masked_lm_predictions(tokens, masked_lm_prob,
|
||||||
|
max_predictions_per_seq, vocab_words, rng):
|
||||||
|
"""Creates the predictions for the masked LM objective."""
|
||||||
|
|
||||||
|
cand_indexes = []
|
||||||
|
for (i, token) in enumerate(tokens):
|
||||||
|
if token == "[CLS]" or token == "[SEP]":
|
||||||
|
continue
|
||||||
|
# Whole Word Masking means that if we mask all of the wordpieces
|
||||||
|
# corresponding to an original word. When a word has been split into
|
||||||
|
# WordPieces, the first token does not have any marker and any subsequence
|
||||||
|
# tokens are prefixed with ##. So whenever we see the ## token, we
|
||||||
|
# append it to the previous set of word indexes.
|
||||||
|
#
|
||||||
|
# Note that Whole Word Masking does *not* change the training code
|
||||||
|
# at all -- we still predict each WordPiece independently, softmaxed
|
||||||
|
# over the entire vocabulary.
|
||||||
|
if (FLAGS.do_whole_word_mask and len(cand_indexes) >= 1 and
|
||||||
|
token.startswith("##")):
|
||||||
|
cand_indexes[-1].append(i)
|
||||||
|
else:
|
||||||
|
cand_indexes.append([i])
|
||||||
|
|
||||||
|
rng.shuffle(cand_indexes)
|
||||||
|
|
||||||
|
output_tokens = list(tokens)
|
||||||
|
|
||||||
|
num_to_predict = min(max_predictions_per_seq,
|
||||||
|
max(1, int(round(len(tokens) * masked_lm_prob))))
|
||||||
|
|
||||||
|
masked_lms = []
|
||||||
|
covered_indexes = set()
|
||||||
|
for index_set in cand_indexes:
|
||||||
|
if len(masked_lms) >= num_to_predict:
|
||||||
|
break
|
||||||
|
# If adding a whole-word mask would exceed the maximum number of
|
||||||
|
# predictions, then just skip this candidate.
|
||||||
|
if len(masked_lms) + len(index_set) > num_to_predict:
|
||||||
|
continue
|
||||||
|
is_any_index_covered = False
|
||||||
|
for index in index_set:
|
||||||
|
if index in covered_indexes:
|
||||||
|
is_any_index_covered = True
|
||||||
|
break
|
||||||
|
if is_any_index_covered:
|
||||||
|
continue
|
||||||
|
for index in index_set:
|
||||||
|
covered_indexes.add(index)
|
||||||
|
|
||||||
|
masked_token = None
|
||||||
|
# 80% of the time, replace with [MASK]
|
||||||
|
if rng.random() < 0.8:
|
||||||
|
masked_token = "[MASK]"
|
||||||
|
else:
|
||||||
|
# 10% of the time, keep original
|
||||||
|
if rng.random() < 0.5:
|
||||||
|
masked_token = tokens[index]
|
||||||
|
# 10% of the time, replace with random word
|
||||||
|
else:
|
||||||
|
masked_token = vocab_words[rng.randint(0, len(vocab_words) - 1)]
|
||||||
|
|
||||||
|
output_tokens[index] = masked_token
|
||||||
|
|
||||||
|
masked_lms.append(MaskedLmInstance(index=index, label=tokens[index]))
|
||||||
|
assert len(masked_lms) <= num_to_predict
|
||||||
|
masked_lms = sorted(masked_lms, key=lambda x: x.index)
|
||||||
|
|
||||||
|
masked_lm_positions = []
|
||||||
|
masked_lm_labels = []
|
||||||
|
for p in masked_lms:
|
||||||
|
masked_lm_positions.append(p.index)
|
||||||
|
masked_lm_labels.append(p.label)
|
||||||
|
|
||||||
|
return (output_tokens, masked_lm_positions, masked_lm_labels)
|
||||||
|
|
||||||
|
|
||||||
|
def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng):
|
||||||
|
"""Truncates a pair of sequences to a maximum sequence length."""
|
||||||
|
while True:
|
||||||
|
total_length = len(tokens_a) + len(tokens_b)
|
||||||
|
if total_length <= max_num_tokens:
|
||||||
|
break
|
||||||
|
|
||||||
|
trunc_tokens = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b
|
||||||
|
assert len(trunc_tokens) >= 1
|
||||||
|
|
||||||
|
# We want to sometimes truncate from the front and sometimes from the
|
||||||
|
# back to add more randomness and avoid biases.
|
||||||
|
if rng.random() < 0.5:
|
||||||
|
del trunc_tokens[0]
|
||||||
|
else:
|
||||||
|
trunc_tokens.pop()
|
||||||
|
|
||||||
|
|
||||||
|
def main(_):
|
||||||
|
tf.logging.set_verbosity(tf.logging.INFO)
|
||||||
|
|
||||||
|
tokenizer = tokenization.FullTokenizer(
|
||||||
|
vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
|
||||||
|
|
||||||
|
input_files = []
|
||||||
|
for input_pattern in FLAGS.input_file.split(","):
|
||||||
|
input_files.extend(tf.gfile.Glob(input_pattern))
|
||||||
|
|
||||||
|
tf.logging.info("*** Reading from input files ***")
|
||||||
|
for input_file in input_files:
|
||||||
|
tf.logging.info(" %s", input_file)
|
||||||
|
|
||||||
|
rng = random.Random(FLAGS.random_seed)
|
||||||
|
instances = create_training_instances(
|
||||||
|
input_files, tokenizer, FLAGS.max_seq_length, FLAGS.dupe_factor,
|
||||||
|
FLAGS.short_seq_prob, FLAGS.masked_lm_prob, FLAGS.max_predictions_per_seq,
|
||||||
|
rng)
|
||||||
|
|
||||||
|
output_files = FLAGS.output_file.split(",")
|
||||||
|
tf.logging.info("*** Writing to output files ***")
|
||||||
|
for output_file in output_files:
|
||||||
|
tf.logging.info(" %s", output_file)
|
||||||
|
|
||||||
|
write_instance_to_example_files(instances, tokenizer, FLAGS.max_seq_length,
|
||||||
|
FLAGS.max_predictions_per_seq, output_files)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
flags.mark_flag_as_required("input_file")
|
||||||
|
flags.mark_flag_as_required("output_file")
|
||||||
|
flags.mark_flag_as_required("vocab_file")
|
||||||
|
tf.app.run()
|
49
dealing_dataset.py
Normal file
49
dealing_dataset.py
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
import sqlite3
|
||||||
|
|
||||||
|
conn = sqlite3.connect(r"nlpdata.db")\
|
||||||
|
|
||||||
|
|
||||||
|
def create_dataset_ep(table):
|
||||||
|
cursor = conn.cursor()
|
||||||
|
sql = "select * from " + table + " LIMIT 20"
|
||||||
|
cursor.execute(sql)
|
||||||
|
conn.commit()
|
||||||
|
|
||||||
|
dataset = []
|
||||||
|
|
||||||
|
for row in cursor:
|
||||||
|
eid = row[0]
|
||||||
|
tag = row[1]
|
||||||
|
content = row[2]
|
||||||
|
if tag == "5" or tag == "4":
|
||||||
|
dataset.append([eid, 2, content])
|
||||||
|
print(eid, 2, content)
|
||||||
|
elif tag == "1" or tag == "2":
|
||||||
|
dataset.append([eid, 0, content])
|
||||||
|
print(eid, 0, content)
|
||||||
|
else:
|
||||||
|
dataset.append([eid, 1, content])
|
||||||
|
print(eid, 1, content)
|
||||||
|
return dataset
|
||||||
|
|
||||||
|
|
||||||
|
def create_dataset_pdt():
|
||||||
|
conn_pdt = sqlite3.connect(r".\bptdata.db")
|
||||||
|
cursor = conn_pdt.cursor()
|
||||||
|
sql = "select * from " + "predict_data"
|
||||||
|
cursor.execute(sql)
|
||||||
|
conn_pdt.commit()
|
||||||
|
|
||||||
|
dataset = []
|
||||||
|
|
||||||
|
for row in cursor:
|
||||||
|
stnid = row[0]
|
||||||
|
text = row[1]
|
||||||
|
dataset.append([stnid, 0, text])
|
||||||
|
print(stnid, 0, text)
|
||||||
|
|
||||||
|
return dataset
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
print(create_dataset_ep("amki_test"))
|
419
extract_features.py
Normal file
419
extract_features.py
Normal file
@ -0,0 +1,419 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2018 The Google AI Language Team Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Extract pre-computed feature vectors from BERT."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import codecs
|
||||||
|
import collections
|
||||||
|
import json
|
||||||
|
import re
|
||||||
|
|
||||||
|
import modeling
|
||||||
|
import tokenization
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
flags = tf.flags
|
||||||
|
|
||||||
|
FLAGS = flags.FLAGS
|
||||||
|
|
||||||
|
flags.DEFINE_string("input_file", None, "")
|
||||||
|
|
||||||
|
flags.DEFINE_string("output_file", None, "")
|
||||||
|
|
||||||
|
flags.DEFINE_string("layers", "-1,-2,-3,-4", "")
|
||||||
|
|
||||||
|
flags.DEFINE_string(
|
||||||
|
"bert_config_file", None,
|
||||||
|
"The config json file corresponding to the pre-trained BERT model. "
|
||||||
|
"This specifies the model architecture.")
|
||||||
|
|
||||||
|
flags.DEFINE_integer(
|
||||||
|
"max_seq_length", 128,
|
||||||
|
"The maximum total input sequence length after WordPiece tokenization. "
|
||||||
|
"Sequences longer than this will be truncated, and sequences shorter "
|
||||||
|
"than this will be padded.")
|
||||||
|
|
||||||
|
flags.DEFINE_string(
|
||||||
|
"init_checkpoint", None,
|
||||||
|
"Initial checkpoint (usually from a pre-trained BERT model).")
|
||||||
|
|
||||||
|
flags.DEFINE_string("vocab_file", None,
|
||||||
|
"The vocabulary file that the BERT model was trained on.")
|
||||||
|
|
||||||
|
flags.DEFINE_bool(
|
||||||
|
"do_lower_case", True,
|
||||||
|
"Whether to lower case the input text. Should be True for uncased "
|
||||||
|
"models and False for cased models.")
|
||||||
|
|
||||||
|
flags.DEFINE_integer("batch_size", 32, "Batch size for predictions.")
|
||||||
|
|
||||||
|
flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.")
|
||||||
|
|
||||||
|
flags.DEFINE_string("master", None,
|
||||||
|
"If using a TPU, the address of the master.")
|
||||||
|
|
||||||
|
flags.DEFINE_integer(
|
||||||
|
"num_tpu_cores", 8,
|
||||||
|
"Only used if `use_tpu` is True. Total number of TPU cores to use.")
|
||||||
|
|
||||||
|
flags.DEFINE_bool(
|
||||||
|
"use_one_hot_embeddings", False,
|
||||||
|
"If True, tf.one_hot will be used for embedding lookups, otherwise "
|
||||||
|
"tf.nn.embedding_lookup will be used. On TPUs, this should be True "
|
||||||
|
"since it is much faster.")
|
||||||
|
|
||||||
|
|
||||||
|
class InputExample(object):
|
||||||
|
|
||||||
|
def __init__(self, unique_id, text_a, text_b):
|
||||||
|
self.unique_id = unique_id
|
||||||
|
self.text_a = text_a
|
||||||
|
self.text_b = text_b
|
||||||
|
|
||||||
|
|
||||||
|
class InputFeatures(object):
|
||||||
|
"""A single set of features of data."""
|
||||||
|
|
||||||
|
def __init__(self, unique_id, tokens, input_ids, input_mask, input_type_ids):
|
||||||
|
self.unique_id = unique_id
|
||||||
|
self.tokens = tokens
|
||||||
|
self.input_ids = input_ids
|
||||||
|
self.input_mask = input_mask
|
||||||
|
self.input_type_ids = input_type_ids
|
||||||
|
|
||||||
|
|
||||||
|
def input_fn_builder(features, seq_length):
|
||||||
|
"""Creates an `input_fn` closure to be passed to TPUEstimator."""
|
||||||
|
|
||||||
|
all_unique_ids = []
|
||||||
|
all_input_ids = []
|
||||||
|
all_input_mask = []
|
||||||
|
all_input_type_ids = []
|
||||||
|
|
||||||
|
for feature in features:
|
||||||
|
all_unique_ids.append(feature.unique_id)
|
||||||
|
all_input_ids.append(feature.input_ids)
|
||||||
|
all_input_mask.append(feature.input_mask)
|
||||||
|
all_input_type_ids.append(feature.input_type_ids)
|
||||||
|
|
||||||
|
def input_fn(params):
|
||||||
|
"""The actual input function."""
|
||||||
|
batch_size = params["batch_size"]
|
||||||
|
|
||||||
|
num_examples = len(features)
|
||||||
|
|
||||||
|
# This is for demo purposes and does NOT scale to large data sets. We do
|
||||||
|
# not use Dataset.from_generator() because that uses tf.py_func which is
|
||||||
|
# not TPU compatible. The right way to load data is with TFRecordReader.
|
||||||
|
d = tf.data.Dataset.from_tensor_slices({
|
||||||
|
"unique_ids":
|
||||||
|
tf.constant(all_unique_ids, shape=[num_examples], dtype=tf.int32),
|
||||||
|
"input_ids":
|
||||||
|
tf.constant(
|
||||||
|
all_input_ids, shape=[num_examples, seq_length],
|
||||||
|
dtype=tf.int32),
|
||||||
|
"input_mask":
|
||||||
|
tf.constant(
|
||||||
|
all_input_mask,
|
||||||
|
shape=[num_examples, seq_length],
|
||||||
|
dtype=tf.int32),
|
||||||
|
"input_type_ids":
|
||||||
|
tf.constant(
|
||||||
|
all_input_type_ids,
|
||||||
|
shape=[num_examples, seq_length],
|
||||||
|
dtype=tf.int32),
|
||||||
|
})
|
||||||
|
|
||||||
|
d = d.batch(batch_size=batch_size, drop_remainder=False)
|
||||||
|
return d
|
||||||
|
|
||||||
|
return input_fn
|
||||||
|
|
||||||
|
|
||||||
|
def model_fn_builder(bert_config, init_checkpoint, layer_indexes, use_tpu,
|
||||||
|
use_one_hot_embeddings):
|
||||||
|
"""Returns `model_fn` closure for TPUEstimator."""
|
||||||
|
|
||||||
|
def model_fn(features, labels, mode, params): # pylint: disable=unused-argument
|
||||||
|
"""The `model_fn` for TPUEstimator."""
|
||||||
|
|
||||||
|
unique_ids = features["unique_ids"]
|
||||||
|
input_ids = features["input_ids"]
|
||||||
|
input_mask = features["input_mask"]
|
||||||
|
input_type_ids = features["input_type_ids"]
|
||||||
|
|
||||||
|
model = modeling.BertModel(
|
||||||
|
config=bert_config,
|
||||||
|
is_training=False,
|
||||||
|
input_ids=input_ids,
|
||||||
|
input_mask=input_mask,
|
||||||
|
token_type_ids=input_type_ids,
|
||||||
|
use_one_hot_embeddings=use_one_hot_embeddings)
|
||||||
|
|
||||||
|
if mode != tf.estimator.ModeKeys.PREDICT:
|
||||||
|
raise ValueError("Only PREDICT modes are supported: %s" % (mode))
|
||||||
|
|
||||||
|
tvars = tf.trainable_variables()
|
||||||
|
scaffold_fn = None
|
||||||
|
(assignment_map,
|
||||||
|
initialized_variable_names) = modeling.get_assignment_map_from_checkpoint(
|
||||||
|
tvars, init_checkpoint)
|
||||||
|
if use_tpu:
|
||||||
|
|
||||||
|
def tpu_scaffold():
|
||||||
|
tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
|
||||||
|
return tf.train.Scaffold()
|
||||||
|
|
||||||
|
scaffold_fn = tpu_scaffold
|
||||||
|
else:
|
||||||
|
tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
|
||||||
|
|
||||||
|
tf.logging.info("**** Trainable Variables ****")
|
||||||
|
for var in tvars:
|
||||||
|
init_string = ""
|
||||||
|
if var.name in initialized_variable_names:
|
||||||
|
init_string = ", *INIT_FROM_CKPT*"
|
||||||
|
tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape,
|
||||||
|
init_string)
|
||||||
|
|
||||||
|
all_layers = model.get_all_encoder_layers()
|
||||||
|
|
||||||
|
predictions = {
|
||||||
|
"unique_id": unique_ids,
|
||||||
|
}
|
||||||
|
|
||||||
|
for (i, layer_index) in enumerate(layer_indexes):
|
||||||
|
predictions["layer_output_%d" % i] = all_layers[layer_index]
|
||||||
|
|
||||||
|
output_spec = tf.contrib.tpu.TPUEstimatorSpec(
|
||||||
|
mode=mode, predictions=predictions, scaffold_fn=scaffold_fn)
|
||||||
|
return output_spec
|
||||||
|
|
||||||
|
return model_fn
|
||||||
|
|
||||||
|
|
||||||
|
def convert_examples_to_features(examples, seq_length, tokenizer):
|
||||||
|
"""Loads a data file into a list of `InputBatch`s."""
|
||||||
|
|
||||||
|
features = []
|
||||||
|
for (ex_index, example) in enumerate(examples):
|
||||||
|
tokens_a = tokenizer.tokenize(example.text_a)
|
||||||
|
|
||||||
|
tokens_b = None
|
||||||
|
if example.text_b:
|
||||||
|
tokens_b = tokenizer.tokenize(example.text_b)
|
||||||
|
|
||||||
|
if tokens_b:
|
||||||
|
# Modifies `tokens_a` and `tokens_b` in place so that the total
|
||||||
|
# length is less than the specified length.
|
||||||
|
# Account for [CLS], [SEP], [SEP] with "- 3"
|
||||||
|
_truncate_seq_pair(tokens_a, tokens_b, seq_length - 3)
|
||||||
|
else:
|
||||||
|
# Account for [CLS] and [SEP] with "- 2"
|
||||||
|
if len(tokens_a) > seq_length - 2:
|
||||||
|
tokens_a = tokens_a[0:(seq_length - 2)]
|
||||||
|
|
||||||
|
# The convention in BERT is:
|
||||||
|
# (a) For sequence pairs:
|
||||||
|
# tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
|
||||||
|
# type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1
|
||||||
|
# (b) For single sequences:
|
||||||
|
# tokens: [CLS] the dog is hairy . [SEP]
|
||||||
|
# type_ids: 0 0 0 0 0 0 0
|
||||||
|
#
|
||||||
|
# Where "type_ids" are used to indicate whether this is the first
|
||||||
|
# sequence or the second sequence. The embedding vectors for `type=0` and
|
||||||
|
# `type=1` were learned during pre-training and are added to the wordpiece
|
||||||
|
# embedding vector (and position vector). This is not *strictly* necessary
|
||||||
|
# since the [SEP] token unambiguously separates the sequences, but it makes
|
||||||
|
# it easier for the model to learn the concept of sequences.
|
||||||
|
#
|
||||||
|
# For classification tasks, the first vector (corresponding to [CLS]) is
|
||||||
|
# used as as the "sentence vector". Note that this only makes sense because
|
||||||
|
# the entire model is fine-tuned.
|
||||||
|
tokens = []
|
||||||
|
input_type_ids = []
|
||||||
|
tokens.append("[CLS]")
|
||||||
|
input_type_ids.append(0)
|
||||||
|
for token in tokens_a:
|
||||||
|
tokens.append(token)
|
||||||
|
input_type_ids.append(0)
|
||||||
|
tokens.append("[SEP]")
|
||||||
|
input_type_ids.append(0)
|
||||||
|
|
||||||
|
if tokens_b:
|
||||||
|
for token in tokens_b:
|
||||||
|
tokens.append(token)
|
||||||
|
input_type_ids.append(1)
|
||||||
|
tokens.append("[SEP]")
|
||||||
|
input_type_ids.append(1)
|
||||||
|
|
||||||
|
input_ids = tokenizer.convert_tokens_to_ids(tokens)
|
||||||
|
|
||||||
|
# The mask has 1 for real tokens and 0 for padding tokens. Only real
|
||||||
|
# tokens are attended to.
|
||||||
|
input_mask = [1] * len(input_ids)
|
||||||
|
|
||||||
|
# Zero-pad up to the sequence length.
|
||||||
|
while len(input_ids) < seq_length:
|
||||||
|
input_ids.append(0)
|
||||||
|
input_mask.append(0)
|
||||||
|
input_type_ids.append(0)
|
||||||
|
|
||||||
|
assert len(input_ids) == seq_length
|
||||||
|
assert len(input_mask) == seq_length
|
||||||
|
assert len(input_type_ids) == seq_length
|
||||||
|
|
||||||
|
if ex_index < 5:
|
||||||
|
tf.logging.info("*** Example ***")
|
||||||
|
tf.logging.info("unique_id: %s" % (example.unique_id))
|
||||||
|
tf.logging.info("tokens: %s" % " ".join(
|
||||||
|
[tokenization.printable_text(x) for x in tokens]))
|
||||||
|
tf.logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
|
||||||
|
tf.logging.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
|
||||||
|
tf.logging.info(
|
||||||
|
"input_type_ids: %s" % " ".join([str(x) for x in input_type_ids]))
|
||||||
|
|
||||||
|
features.append(
|
||||||
|
InputFeatures(
|
||||||
|
unique_id=example.unique_id,
|
||||||
|
tokens=tokens,
|
||||||
|
input_ids=input_ids,
|
||||||
|
input_mask=input_mask,
|
||||||
|
input_type_ids=input_type_ids))
|
||||||
|
return features
|
||||||
|
|
||||||
|
|
||||||
|
def _truncate_seq_pair(tokens_a, tokens_b, max_length):
|
||||||
|
"""Truncates a sequence pair in place to the maximum length."""
|
||||||
|
|
||||||
|
# This is a simple heuristic which will always truncate the longer sequence
|
||||||
|
# one token at a time. This makes more sense than truncating an equal percent
|
||||||
|
# of tokens from each, since if one sequence is very short then each token
|
||||||
|
# that's truncated likely contains more information than a longer sequence.
|
||||||
|
while True:
|
||||||
|
total_length = len(tokens_a) + len(tokens_b)
|
||||||
|
if total_length <= max_length:
|
||||||
|
break
|
||||||
|
if len(tokens_a) > len(tokens_b):
|
||||||
|
tokens_a.pop()
|
||||||
|
else:
|
||||||
|
tokens_b.pop()
|
||||||
|
|
||||||
|
|
||||||
|
def read_examples(input_file):
|
||||||
|
"""Read a list of `InputExample`s from an input file."""
|
||||||
|
examples = []
|
||||||
|
unique_id = 0
|
||||||
|
with tf.gfile.GFile(input_file, "r") as reader:
|
||||||
|
while True:
|
||||||
|
line = tokenization.convert_to_unicode(reader.readline())
|
||||||
|
if not line:
|
||||||
|
break
|
||||||
|
line = line.strip()
|
||||||
|
text_a = None
|
||||||
|
text_b = None
|
||||||
|
m = re.match(r"^(.*) \|\|\| (.*)$", line)
|
||||||
|
if m is None:
|
||||||
|
text_a = line
|
||||||
|
else:
|
||||||
|
text_a = m.group(1)
|
||||||
|
text_b = m.group(2)
|
||||||
|
examples.append(
|
||||||
|
InputExample(unique_id=unique_id, text_a=text_a, text_b=text_b))
|
||||||
|
unique_id += 1
|
||||||
|
return examples
|
||||||
|
|
||||||
|
|
||||||
|
def main(_):
|
||||||
|
tf.logging.set_verbosity(tf.logging.INFO)
|
||||||
|
|
||||||
|
layer_indexes = [int(x) for x in FLAGS.layers.split(",")]
|
||||||
|
|
||||||
|
bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
|
||||||
|
|
||||||
|
tokenizer = tokenization.FullTokenizer(
|
||||||
|
vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
|
||||||
|
|
||||||
|
is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
|
||||||
|
run_config = tf.contrib.tpu.RunConfig(
|
||||||
|
master=FLAGS.master,
|
||||||
|
tpu_config=tf.contrib.tpu.TPUConfig(
|
||||||
|
num_shards=FLAGS.num_tpu_cores,
|
||||||
|
per_host_input_for_training=is_per_host))
|
||||||
|
|
||||||
|
examples = read_examples(FLAGS.input_file)
|
||||||
|
|
||||||
|
features = convert_examples_to_features(
|
||||||
|
examples=examples, seq_length=FLAGS.max_seq_length, tokenizer=tokenizer)
|
||||||
|
|
||||||
|
unique_id_to_feature = {}
|
||||||
|
for feature in features:
|
||||||
|
unique_id_to_feature[feature.unique_id] = feature
|
||||||
|
|
||||||
|
model_fn = model_fn_builder(
|
||||||
|
bert_config=bert_config,
|
||||||
|
init_checkpoint=FLAGS.init_checkpoint,
|
||||||
|
layer_indexes=layer_indexes,
|
||||||
|
use_tpu=FLAGS.use_tpu,
|
||||||
|
use_one_hot_embeddings=FLAGS.use_one_hot_embeddings)
|
||||||
|
|
||||||
|
# If TPU is not available, this will fall back to normal Estimator on CPU
|
||||||
|
# or GPU.
|
||||||
|
estimator = tf.contrib.tpu.TPUEstimator(
|
||||||
|
use_tpu=FLAGS.use_tpu,
|
||||||
|
model_fn=model_fn,
|
||||||
|
config=run_config,
|
||||||
|
predict_batch_size=FLAGS.batch_size)
|
||||||
|
|
||||||
|
input_fn = input_fn_builder(
|
||||||
|
features=features, seq_length=FLAGS.max_seq_length)
|
||||||
|
|
||||||
|
with codecs.getwriter("utf-8")(tf.gfile.Open(FLAGS.output_file,
|
||||||
|
"w")) as writer:
|
||||||
|
for result in estimator.predict(input_fn, yield_single_examples=True):
|
||||||
|
unique_id = int(result["unique_id"])
|
||||||
|
feature = unique_id_to_feature[unique_id]
|
||||||
|
output_json = collections.OrderedDict()
|
||||||
|
output_json["linex_index"] = unique_id
|
||||||
|
all_features = []
|
||||||
|
for (i, token) in enumerate(feature.tokens):
|
||||||
|
all_layers = []
|
||||||
|
for (j, layer_index) in enumerate(layer_indexes):
|
||||||
|
layer_output = result["layer_output_%d" % j]
|
||||||
|
layers = collections.OrderedDict()
|
||||||
|
layers["index"] = layer_index
|
||||||
|
layers["values"] = [
|
||||||
|
round(float(x), 6) for x in layer_output[i:(i + 1)].flat
|
||||||
|
]
|
||||||
|
all_layers.append(layers)
|
||||||
|
features = collections.OrderedDict()
|
||||||
|
features["token"] = token
|
||||||
|
features["layers"] = all_layers
|
||||||
|
all_features.append(features)
|
||||||
|
output_json["features"] = all_features
|
||||||
|
writer.write(json.dumps(output_json) + "\n")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
flags.mark_flag_as_required("input_file")
|
||||||
|
flags.mark_flag_as_required("vocab_file")
|
||||||
|
flags.mark_flag_as_required("bert_config_file")
|
||||||
|
flags.mark_flag_as_required("init_checkpoint")
|
||||||
|
flags.mark_flag_as_required("output_file")
|
||||||
|
tf.app.run()
|
986
modeling.py
Normal file
986
modeling.py
Normal file
@ -0,0 +1,986 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2018 The Google AI Language Team Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""The main BERT model and related functions."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import collections
|
||||||
|
import copy
|
||||||
|
import json
|
||||||
|
import math
|
||||||
|
import re
|
||||||
|
import numpy as np
|
||||||
|
import six
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
|
||||||
|
class BertConfig(object):
|
||||||
|
"""Configuration for `BertModel`."""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
vocab_size,
|
||||||
|
hidden_size=768,
|
||||||
|
num_hidden_layers=12,
|
||||||
|
num_attention_heads=12,
|
||||||
|
intermediate_size=3072,
|
||||||
|
hidden_act="gelu",
|
||||||
|
hidden_dropout_prob=0.1,
|
||||||
|
attention_probs_dropout_prob=0.1,
|
||||||
|
max_position_embeddings=512,
|
||||||
|
type_vocab_size=16,
|
||||||
|
initializer_range=0.02):
|
||||||
|
"""Constructs BertConfig.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
vocab_size: Vocabulary size of `inputs_ids` in `BertModel`.
|
||||||
|
hidden_size: Size of the encoder layers and the pooler layer.
|
||||||
|
num_hidden_layers: Number of hidden layers in the Transformer encoder.
|
||||||
|
num_attention_heads: Number of attention heads for each attention layer in
|
||||||
|
the Transformer encoder.
|
||||||
|
intermediate_size: The size of the "intermediate" (i.e., feed-forward)
|
||||||
|
layer in the Transformer encoder.
|
||||||
|
hidden_act: The non-linear activation function (function or string) in the
|
||||||
|
encoder and pooler.
|
||||||
|
hidden_dropout_prob: The dropout probability for all fully connected
|
||||||
|
layers in the embeddings, encoder, and pooler.
|
||||||
|
attention_probs_dropout_prob: The dropout ratio for the attention
|
||||||
|
probabilities.
|
||||||
|
max_position_embeddings: The maximum sequence length that this model might
|
||||||
|
ever be used with. Typically set this to something large just in case
|
||||||
|
(e.g., 512 or 1024 or 2048).
|
||||||
|
type_vocab_size: The vocabulary size of the `token_type_ids` passed into
|
||||||
|
`BertModel`.
|
||||||
|
initializer_range: The stdev of the truncated_normal_initializer for
|
||||||
|
initializing all weight matrices.
|
||||||
|
"""
|
||||||
|
self.vocab_size = vocab_size
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.num_hidden_layers = num_hidden_layers
|
||||||
|
self.num_attention_heads = num_attention_heads
|
||||||
|
self.hidden_act = hidden_act
|
||||||
|
self.intermediate_size = intermediate_size
|
||||||
|
self.hidden_dropout_prob = hidden_dropout_prob
|
||||||
|
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
||||||
|
self.max_position_embeddings = max_position_embeddings
|
||||||
|
self.type_vocab_size = type_vocab_size
|
||||||
|
self.initializer_range = initializer_range
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, json_object):
|
||||||
|
"""Constructs a `BertConfig` from a Python dictionary of parameters."""
|
||||||
|
config = BertConfig(vocab_size=None)
|
||||||
|
for (key, value) in six.iteritems(json_object):
|
||||||
|
config.__dict__[key] = value
|
||||||
|
return config
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_json_file(cls, json_file):
|
||||||
|
"""Constructs a `BertConfig` from a json file of parameters."""
|
||||||
|
with tf.gfile.GFile(json_file, "r") as reader:
|
||||||
|
text = reader.read()
|
||||||
|
return cls.from_dict(json.loads(text))
|
||||||
|
|
||||||
|
def to_dict(self):
|
||||||
|
"""Serializes this instance to a Python dictionary."""
|
||||||
|
output = copy.deepcopy(self.__dict__)
|
||||||
|
return output
|
||||||
|
|
||||||
|
def to_json_string(self):
|
||||||
|
"""Serializes this instance to a JSON string."""
|
||||||
|
return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
|
||||||
|
|
||||||
|
|
||||||
|
class BertModel(object):
|
||||||
|
"""BERT model ("Bidirectional Encoder Representations from Transformers").
|
||||||
|
|
||||||
|
Example usage:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Already been converted into WordPiece token ids
|
||||||
|
input_ids = tf.constant([[31, 51, 99], [15, 5, 0]])
|
||||||
|
input_mask = tf.constant([[1, 1, 1], [1, 1, 0]])
|
||||||
|
token_type_ids = tf.constant([[0, 0, 1], [0, 2, 0]])
|
||||||
|
|
||||||
|
config = modeling.BertConfig(vocab_size=32000, hidden_size=512,
|
||||||
|
num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024)
|
||||||
|
|
||||||
|
model = modeling.BertModel(config=config, is_training=True,
|
||||||
|
input_ids=input_ids, input_mask=input_mask, token_type_ids=token_type_ids)
|
||||||
|
|
||||||
|
label_embeddings = tf.get_variable(...)
|
||||||
|
pooled_output = model.get_pooled_output()
|
||||||
|
logits = tf.matmul(pooled_output, label_embeddings)
|
||||||
|
...
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
config,
|
||||||
|
is_training,
|
||||||
|
input_ids,
|
||||||
|
input_mask=None,
|
||||||
|
token_type_ids=None,
|
||||||
|
use_one_hot_embeddings=False,
|
||||||
|
scope=None):
|
||||||
|
"""Constructor for BertModel.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: `BertConfig` instance.
|
||||||
|
is_training: bool. true for training model, false for eval model. Controls
|
||||||
|
whether dropout will be applied.
|
||||||
|
input_ids: int32 Tensor of shape [batch_size, seq_length].
|
||||||
|
input_mask: (optional) int32 Tensor of shape [batch_size, seq_length].
|
||||||
|
token_type_ids: (optional) int32 Tensor of shape [batch_size, seq_length].
|
||||||
|
use_one_hot_embeddings: (optional) bool. Whether to use one-hot word
|
||||||
|
embeddings or tf.embedding_lookup() for the word embeddings.
|
||||||
|
scope: (optional) variable scope. Defaults to "bert".
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: The config is invalid or one of the input tensor shapes
|
||||||
|
is invalid.
|
||||||
|
"""
|
||||||
|
config = copy.deepcopy(config)
|
||||||
|
if not is_training:
|
||||||
|
config.hidden_dropout_prob = 0.0
|
||||||
|
config.attention_probs_dropout_prob = 0.0
|
||||||
|
|
||||||
|
input_shape = get_shape_list(input_ids, expected_rank=2)
|
||||||
|
batch_size = input_shape[0]
|
||||||
|
seq_length = input_shape[1]
|
||||||
|
|
||||||
|
if input_mask is None:
|
||||||
|
input_mask = tf.ones(shape=[batch_size, seq_length], dtype=tf.int32)
|
||||||
|
|
||||||
|
if token_type_ids is None:
|
||||||
|
token_type_ids = tf.zeros(shape=[batch_size, seq_length], dtype=tf.int32)
|
||||||
|
|
||||||
|
with tf.variable_scope(scope, default_name="bert"):
|
||||||
|
with tf.variable_scope("embeddings"):
|
||||||
|
# Perform embedding lookup on the word ids.
|
||||||
|
(self.embedding_output, self.embedding_table) = embedding_lookup(
|
||||||
|
input_ids=input_ids,
|
||||||
|
vocab_size=config.vocab_size,
|
||||||
|
embedding_size=config.hidden_size,
|
||||||
|
initializer_range=config.initializer_range,
|
||||||
|
word_embedding_name="word_embeddings",
|
||||||
|
use_one_hot_embeddings=use_one_hot_embeddings)
|
||||||
|
|
||||||
|
# Add positional embeddings and token type embeddings, then layer
|
||||||
|
# normalize and perform dropout.
|
||||||
|
self.embedding_output = embedding_postprocessor(
|
||||||
|
input_tensor=self.embedding_output,
|
||||||
|
use_token_type=True,
|
||||||
|
token_type_ids=token_type_ids,
|
||||||
|
token_type_vocab_size=config.type_vocab_size,
|
||||||
|
token_type_embedding_name="token_type_embeddings",
|
||||||
|
use_position_embeddings=True,
|
||||||
|
position_embedding_name="position_embeddings",
|
||||||
|
initializer_range=config.initializer_range,
|
||||||
|
max_position_embeddings=config.max_position_embeddings,
|
||||||
|
dropout_prob=config.hidden_dropout_prob)
|
||||||
|
|
||||||
|
with tf.variable_scope("encoder"):
|
||||||
|
# This converts a 2D mask of shape [batch_size, seq_length] to a 3D
|
||||||
|
# mask of shape [batch_size, seq_length, seq_length] which is used
|
||||||
|
# for the attention scores.
|
||||||
|
attention_mask = create_attention_mask_from_input_mask(
|
||||||
|
input_ids, input_mask)
|
||||||
|
|
||||||
|
# Run the stacked transformer.
|
||||||
|
# `sequence_output` shape = [batch_size, seq_length, hidden_size].
|
||||||
|
self.all_encoder_layers = transformer_model(
|
||||||
|
input_tensor=self.embedding_output,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
hidden_size=config.hidden_size,
|
||||||
|
num_hidden_layers=config.num_hidden_layers,
|
||||||
|
num_attention_heads=config.num_attention_heads,
|
||||||
|
intermediate_size=config.intermediate_size,
|
||||||
|
intermediate_act_fn=get_activation(config.hidden_act),
|
||||||
|
hidden_dropout_prob=config.hidden_dropout_prob,
|
||||||
|
attention_probs_dropout_prob=config.attention_probs_dropout_prob,
|
||||||
|
initializer_range=config.initializer_range,
|
||||||
|
do_return_all_layers=True)
|
||||||
|
|
||||||
|
self.sequence_output = self.all_encoder_layers[-1]
|
||||||
|
# The "pooler" converts the encoded sequence tensor of shape
|
||||||
|
# [batch_size, seq_length, hidden_size] to a tensor of shape
|
||||||
|
# [batch_size, hidden_size]. This is necessary for segment-level
|
||||||
|
# (or segment-pair-level) classification tasks where we need a fixed
|
||||||
|
# dimensional representation of the segment.
|
||||||
|
with tf.variable_scope("pooler"):
|
||||||
|
# We "pool" the model by simply taking the hidden state corresponding
|
||||||
|
# to the first token. We assume that this has been pre-trained
|
||||||
|
first_token_tensor = tf.squeeze(self.sequence_output[:, 0:1, :], axis=1)
|
||||||
|
self.pooled_output = tf.layers.dense(
|
||||||
|
first_token_tensor,
|
||||||
|
config.hidden_size,
|
||||||
|
activation=tf.tanh,
|
||||||
|
kernel_initializer=create_initializer(config.initializer_range))
|
||||||
|
|
||||||
|
def get_pooled_output(self):
|
||||||
|
return self.pooled_output
|
||||||
|
|
||||||
|
def get_sequence_output(self):
|
||||||
|
"""Gets final hidden layer of encoder.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
float Tensor of shape [batch_size, seq_length, hidden_size] corresponding
|
||||||
|
to the final hidden of the transformer encoder.
|
||||||
|
"""
|
||||||
|
return self.sequence_output
|
||||||
|
|
||||||
|
def get_all_encoder_layers(self):
|
||||||
|
return self.all_encoder_layers
|
||||||
|
|
||||||
|
def get_embedding_output(self):
|
||||||
|
"""Gets output of the embedding lookup (i.e., input to the transformer).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
float Tensor of shape [batch_size, seq_length, hidden_size] corresponding
|
||||||
|
to the output of the embedding layer, after summing the word
|
||||||
|
embeddings with the positional embeddings and the token type embeddings,
|
||||||
|
then performing layer normalization. This is the input to the transformer.
|
||||||
|
"""
|
||||||
|
return self.embedding_output
|
||||||
|
|
||||||
|
def get_embedding_table(self):
|
||||||
|
return self.embedding_table
|
||||||
|
|
||||||
|
|
||||||
|
def gelu(x):
|
||||||
|
"""Gaussian Error Linear Unit.
|
||||||
|
|
||||||
|
This is a smoother version of the RELU.
|
||||||
|
Original paper: https://arxiv.org/abs/1606.08415
|
||||||
|
Args:
|
||||||
|
x: float Tensor to perform activation.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`x` with the GELU activation applied.
|
||||||
|
"""
|
||||||
|
cdf = 0.5 * (1.0 + tf.tanh(
|
||||||
|
(np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3)))))
|
||||||
|
return x * cdf
|
||||||
|
|
||||||
|
|
||||||
|
def get_activation(activation_string):
|
||||||
|
"""Maps a string to a Python function, e.g., "relu" => `tf.nn.relu`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
activation_string: String name of the activation function.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A Python function corresponding to the activation function. If
|
||||||
|
`activation_string` is None, empty, or "linear", this will return None.
|
||||||
|
If `activation_string` is not a string, it will return `activation_string`.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: The `activation_string` does not correspond to a known
|
||||||
|
activation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# We assume that anything that"s not a string is already an activation
|
||||||
|
# function, so we just return it.
|
||||||
|
if not isinstance(activation_string, six.string_types):
|
||||||
|
return activation_string
|
||||||
|
|
||||||
|
if not activation_string:
|
||||||
|
return None
|
||||||
|
|
||||||
|
act = activation_string.lower()
|
||||||
|
if act == "linear":
|
||||||
|
return None
|
||||||
|
elif act == "relu":
|
||||||
|
return tf.nn.relu
|
||||||
|
elif act == "gelu":
|
||||||
|
return gelu
|
||||||
|
elif act == "tanh":
|
||||||
|
return tf.tanh
|
||||||
|
else:
|
||||||
|
raise ValueError("Unsupported activation: %s" % act)
|
||||||
|
|
||||||
|
|
||||||
|
def get_assignment_map_from_checkpoint(tvars, init_checkpoint):
|
||||||
|
"""Compute the union of the current variables and checkpoint variables."""
|
||||||
|
assignment_map = {}
|
||||||
|
initialized_variable_names = {}
|
||||||
|
|
||||||
|
name_to_variable = collections.OrderedDict()
|
||||||
|
for var in tvars:
|
||||||
|
name = var.name
|
||||||
|
m = re.match("^(.*):\\d+$", name)
|
||||||
|
if m is not None:
|
||||||
|
name = m.group(1)
|
||||||
|
name_to_variable[name] = var
|
||||||
|
|
||||||
|
init_vars = tf.train.list_variables(init_checkpoint)
|
||||||
|
|
||||||
|
assignment_map = collections.OrderedDict()
|
||||||
|
for x in init_vars:
|
||||||
|
(name, var) = (x[0], x[1])
|
||||||
|
if name not in name_to_variable:
|
||||||
|
continue
|
||||||
|
assignment_map[name] = name
|
||||||
|
initialized_variable_names[name] = 1
|
||||||
|
initialized_variable_names[name + ":0"] = 1
|
||||||
|
|
||||||
|
return (assignment_map, initialized_variable_names)
|
||||||
|
|
||||||
|
|
||||||
|
def dropout(input_tensor, dropout_prob):
|
||||||
|
"""Perform dropout.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_tensor: float Tensor.
|
||||||
|
dropout_prob: Python float. The probability of dropping out a value (NOT of
|
||||||
|
*keeping* a dimension as in `tf.nn.dropout`).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A version of `input_tensor` with dropout applied.
|
||||||
|
"""
|
||||||
|
if dropout_prob is None or dropout_prob == 0.0:
|
||||||
|
return input_tensor
|
||||||
|
|
||||||
|
output = tf.nn.dropout(input_tensor, 1.0 - dropout_prob)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
def layer_norm(input_tensor, name=None):
|
||||||
|
"""Run layer normalization on the last dimension of the tensor."""
|
||||||
|
return tf.contrib.layers.layer_norm(
|
||||||
|
inputs=input_tensor, begin_norm_axis=-1, begin_params_axis=-1, scope=name)
|
||||||
|
|
||||||
|
|
||||||
|
def layer_norm_and_dropout(input_tensor, dropout_prob, name=None):
|
||||||
|
"""Runs layer normalization followed by dropout."""
|
||||||
|
output_tensor = layer_norm(input_tensor, name)
|
||||||
|
output_tensor = dropout(output_tensor, dropout_prob)
|
||||||
|
return output_tensor
|
||||||
|
|
||||||
|
|
||||||
|
def create_initializer(initializer_range=0.02):
|
||||||
|
"""Creates a `truncated_normal_initializer` with the given range."""
|
||||||
|
return tf.truncated_normal_initializer(stddev=initializer_range)
|
||||||
|
|
||||||
|
|
||||||
|
def embedding_lookup(input_ids,
|
||||||
|
vocab_size,
|
||||||
|
embedding_size=128,
|
||||||
|
initializer_range=0.02,
|
||||||
|
word_embedding_name="word_embeddings",
|
||||||
|
use_one_hot_embeddings=False):
|
||||||
|
"""Looks up words embeddings for id tensor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_ids: int32 Tensor of shape [batch_size, seq_length] containing word
|
||||||
|
ids.
|
||||||
|
vocab_size: int. Size of the embedding vocabulary.
|
||||||
|
embedding_size: int. Width of the word embeddings.
|
||||||
|
initializer_range: float. Embedding initialization range.
|
||||||
|
word_embedding_name: string. Name of the embedding table.
|
||||||
|
use_one_hot_embeddings: bool. If True, use one-hot method for word
|
||||||
|
embeddings. If False, use `tf.gather()`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
float Tensor of shape [batch_size, seq_length, embedding_size].
|
||||||
|
"""
|
||||||
|
# This function assumes that the input is of shape [batch_size, seq_length,
|
||||||
|
# num_inputs].
|
||||||
|
#
|
||||||
|
# If the input is a 2D tensor of shape [batch_size, seq_length], we
|
||||||
|
# reshape to [batch_size, seq_length, 1].
|
||||||
|
if input_ids.shape.ndims == 2:
|
||||||
|
input_ids = tf.expand_dims(input_ids, axis=[-1])
|
||||||
|
|
||||||
|
embedding_table = tf.get_variable(
|
||||||
|
name=word_embedding_name,
|
||||||
|
shape=[vocab_size, embedding_size],
|
||||||
|
initializer=create_initializer(initializer_range))
|
||||||
|
|
||||||
|
flat_input_ids = tf.reshape(input_ids, [-1])
|
||||||
|
if use_one_hot_embeddings:
|
||||||
|
one_hot_input_ids = tf.one_hot(flat_input_ids, depth=vocab_size)
|
||||||
|
output = tf.matmul(one_hot_input_ids, embedding_table)
|
||||||
|
else:
|
||||||
|
output = tf.gather(embedding_table, flat_input_ids)
|
||||||
|
|
||||||
|
input_shape = get_shape_list(input_ids)
|
||||||
|
|
||||||
|
output = tf.reshape(output,
|
||||||
|
input_shape[0:-1] + [input_shape[-1] * embedding_size])
|
||||||
|
return (output, embedding_table)
|
||||||
|
|
||||||
|
|
||||||
|
def embedding_postprocessor(input_tensor,
|
||||||
|
use_token_type=False,
|
||||||
|
token_type_ids=None,
|
||||||
|
token_type_vocab_size=16,
|
||||||
|
token_type_embedding_name="token_type_embeddings",
|
||||||
|
use_position_embeddings=True,
|
||||||
|
position_embedding_name="position_embeddings",
|
||||||
|
initializer_range=0.02,
|
||||||
|
max_position_embeddings=512,
|
||||||
|
dropout_prob=0.1):
|
||||||
|
"""Performs various post-processing on a word embedding tensor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_tensor: float Tensor of shape [batch_size, seq_length,
|
||||||
|
embedding_size].
|
||||||
|
use_token_type: bool. Whether to add embeddings for `token_type_ids`.
|
||||||
|
token_type_ids: (optional) int32 Tensor of shape [batch_size, seq_length].
|
||||||
|
Must be specified if `use_token_type` is True.
|
||||||
|
token_type_vocab_size: int. The vocabulary size of `token_type_ids`.
|
||||||
|
token_type_embedding_name: string. The name of the embedding table variable
|
||||||
|
for token type ids.
|
||||||
|
use_position_embeddings: bool. Whether to add position embeddings for the
|
||||||
|
position of each token in the sequence.
|
||||||
|
position_embedding_name: string. The name of the embedding table variable
|
||||||
|
for positional embeddings.
|
||||||
|
initializer_range: float. Range of the weight initialization.
|
||||||
|
max_position_embeddings: int. Maximum sequence length that might ever be
|
||||||
|
used with this model. This can be longer than the sequence length of
|
||||||
|
input_tensor, but cannot be shorter.
|
||||||
|
dropout_prob: float. Dropout probability applied to the final output tensor.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
float tensor with same shape as `input_tensor`.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: One of the tensor shapes or input values is invalid.
|
||||||
|
"""
|
||||||
|
input_shape = get_shape_list(input_tensor, expected_rank=3)
|
||||||
|
batch_size = input_shape[0]
|
||||||
|
seq_length = input_shape[1]
|
||||||
|
width = input_shape[2]
|
||||||
|
|
||||||
|
output = input_tensor
|
||||||
|
|
||||||
|
if use_token_type:
|
||||||
|
if token_type_ids is None:
|
||||||
|
raise ValueError("`token_type_ids` must be specified if"
|
||||||
|
"`use_token_type` is True.")
|
||||||
|
token_type_table = tf.get_variable(
|
||||||
|
name=token_type_embedding_name,
|
||||||
|
shape=[token_type_vocab_size, width],
|
||||||
|
initializer=create_initializer(initializer_range))
|
||||||
|
# This vocab will be small so we always do one-hot here, since it is always
|
||||||
|
# faster for a small vocabulary.
|
||||||
|
flat_token_type_ids = tf.reshape(token_type_ids, [-1])
|
||||||
|
one_hot_ids = tf.one_hot(flat_token_type_ids, depth=token_type_vocab_size)
|
||||||
|
token_type_embeddings = tf.matmul(one_hot_ids, token_type_table)
|
||||||
|
token_type_embeddings = tf.reshape(token_type_embeddings,
|
||||||
|
[batch_size, seq_length, width])
|
||||||
|
output += token_type_embeddings
|
||||||
|
|
||||||
|
if use_position_embeddings:
|
||||||
|
assert_op = tf.assert_less_equal(seq_length, max_position_embeddings)
|
||||||
|
with tf.control_dependencies([assert_op]):
|
||||||
|
full_position_embeddings = tf.get_variable(
|
||||||
|
name=position_embedding_name,
|
||||||
|
shape=[max_position_embeddings, width],
|
||||||
|
initializer=create_initializer(initializer_range))
|
||||||
|
# Since the position embedding table is a learned variable, we create it
|
||||||
|
# using a (long) sequence length `max_position_embeddings`. The actual
|
||||||
|
# sequence length might be shorter than this, for faster training of
|
||||||
|
# tasks that do not have long sequences.
|
||||||
|
#
|
||||||
|
# So `full_position_embeddings` is effectively an embedding table
|
||||||
|
# for position [0, 1, 2, ..., max_position_embeddings-1], and the current
|
||||||
|
# sequence has positions [0, 1, 2, ... seq_length-1], so we can just
|
||||||
|
# perform a slice.
|
||||||
|
position_embeddings = tf.slice(full_position_embeddings, [0, 0],
|
||||||
|
[seq_length, -1])
|
||||||
|
num_dims = len(output.shape.as_list())
|
||||||
|
|
||||||
|
# Only the last two dimensions are relevant (`seq_length` and `width`), so
|
||||||
|
# we broadcast among the first dimensions, which is typically just
|
||||||
|
# the batch size.
|
||||||
|
position_broadcast_shape = []
|
||||||
|
for _ in range(num_dims - 2):
|
||||||
|
position_broadcast_shape.append(1)
|
||||||
|
position_broadcast_shape.extend([seq_length, width])
|
||||||
|
position_embeddings = tf.reshape(position_embeddings,
|
||||||
|
position_broadcast_shape)
|
||||||
|
output += position_embeddings
|
||||||
|
|
||||||
|
output = layer_norm_and_dropout(output, dropout_prob)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
def create_attention_mask_from_input_mask(from_tensor, to_mask):
|
||||||
|
"""Create 3D attention mask from a 2D tensor mask.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
from_tensor: 2D or 3D Tensor of shape [batch_size, from_seq_length, ...].
|
||||||
|
to_mask: int32 Tensor of shape [batch_size, to_seq_length].
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
float Tensor of shape [batch_size, from_seq_length, to_seq_length].
|
||||||
|
"""
|
||||||
|
from_shape = get_shape_list(from_tensor, expected_rank=[2, 3])
|
||||||
|
batch_size = from_shape[0]
|
||||||
|
from_seq_length = from_shape[1]
|
||||||
|
|
||||||
|
to_shape = get_shape_list(to_mask, expected_rank=2)
|
||||||
|
to_seq_length = to_shape[1]
|
||||||
|
|
||||||
|
to_mask = tf.cast(
|
||||||
|
tf.reshape(to_mask, [batch_size, 1, to_seq_length]), tf.float32)
|
||||||
|
|
||||||
|
# We don't assume that `from_tensor` is a mask (although it could be). We
|
||||||
|
# don't actually care if we attend *from* padding tokens (only *to* padding)
|
||||||
|
# tokens so we create a tensor of all ones.
|
||||||
|
#
|
||||||
|
# `broadcast_ones` = [batch_size, from_seq_length, 1]
|
||||||
|
broadcast_ones = tf.ones(
|
||||||
|
shape=[batch_size, from_seq_length, 1], dtype=tf.float32)
|
||||||
|
|
||||||
|
# Here we broadcast along two dimensions to create the mask.
|
||||||
|
mask = broadcast_ones * to_mask
|
||||||
|
|
||||||
|
return mask
|
||||||
|
|
||||||
|
|
||||||
|
def attention_layer(from_tensor,
|
||||||
|
to_tensor,
|
||||||
|
attention_mask=None,
|
||||||
|
num_attention_heads=1,
|
||||||
|
size_per_head=512,
|
||||||
|
query_act=None,
|
||||||
|
key_act=None,
|
||||||
|
value_act=None,
|
||||||
|
attention_probs_dropout_prob=0.0,
|
||||||
|
initializer_range=0.02,
|
||||||
|
do_return_2d_tensor=False,
|
||||||
|
batch_size=None,
|
||||||
|
from_seq_length=None,
|
||||||
|
to_seq_length=None):
|
||||||
|
"""Performs multi-headed attention from `from_tensor` to `to_tensor`.
|
||||||
|
|
||||||
|
This is an implementation of multi-headed attention based on "Attention
|
||||||
|
is all you Need". If `from_tensor` and `to_tensor` are the same, then
|
||||||
|
this is self-attention. Each timestep in `from_tensor` attends to the
|
||||||
|
corresponding sequence in `to_tensor`, and returns a fixed-with vector.
|
||||||
|
|
||||||
|
This function first projects `from_tensor` into a "query" tensor and
|
||||||
|
`to_tensor` into "key" and "value" tensors. These are (effectively) a list
|
||||||
|
of tensors of length `num_attention_heads`, where each tensor is of shape
|
||||||
|
[batch_size, seq_length, size_per_head].
|
||||||
|
|
||||||
|
Then, the query and key tensors are dot-producted and scaled. These are
|
||||||
|
softmaxed to obtain attention probabilities. The value tensors are then
|
||||||
|
interpolated by these probabilities, then concatenated back to a single
|
||||||
|
tensor and returned.
|
||||||
|
|
||||||
|
In practice, the multi-headed attention are done with transposes and
|
||||||
|
reshapes rather than actual separate tensors.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
from_tensor: float Tensor of shape [batch_size, from_seq_length,
|
||||||
|
from_width].
|
||||||
|
to_tensor: float Tensor of shape [batch_size, to_seq_length, to_width].
|
||||||
|
attention_mask: (optional) int32 Tensor of shape [batch_size,
|
||||||
|
from_seq_length, to_seq_length]. The values should be 1 or 0. The
|
||||||
|
attention scores will effectively be set to -infinity for any positions in
|
||||||
|
the mask that are 0, and will be unchanged for positions that are 1.
|
||||||
|
num_attention_heads: int. Number of attention heads.
|
||||||
|
size_per_head: int. Size of each attention head.
|
||||||
|
query_act: (optional) Activation function for the query transform.
|
||||||
|
key_act: (optional) Activation function for the key transform.
|
||||||
|
value_act: (optional) Activation function for the value transform.
|
||||||
|
attention_probs_dropout_prob: (optional) float. Dropout probability of the
|
||||||
|
attention probabilities.
|
||||||
|
initializer_range: float. Range of the weight initializer.
|
||||||
|
do_return_2d_tensor: bool. If True, the output will be of shape [batch_size
|
||||||
|
* from_seq_length, num_attention_heads * size_per_head]. If False, the
|
||||||
|
output will be of shape [batch_size, from_seq_length, num_attention_heads
|
||||||
|
* size_per_head].
|
||||||
|
batch_size: (Optional) int. If the input is 2D, this might be the batch size
|
||||||
|
of the 3D version of the `from_tensor` and `to_tensor`.
|
||||||
|
from_seq_length: (Optional) If the input is 2D, this might be the seq length
|
||||||
|
of the 3D version of the `from_tensor`.
|
||||||
|
to_seq_length: (Optional) If the input is 2D, this might be the seq length
|
||||||
|
of the 3D version of the `to_tensor`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
float Tensor of shape [batch_size, from_seq_length,
|
||||||
|
num_attention_heads * size_per_head]. (If `do_return_2d_tensor` is
|
||||||
|
true, this will be of shape [batch_size * from_seq_length,
|
||||||
|
num_attention_heads * size_per_head]).
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: Any of the arguments or tensor shapes are invalid.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def transpose_for_scores(input_tensor, batch_size, num_attention_heads,
|
||||||
|
seq_length, width):
|
||||||
|
output_tensor = tf.reshape(
|
||||||
|
input_tensor, [batch_size, seq_length, num_attention_heads, width])
|
||||||
|
|
||||||
|
output_tensor = tf.transpose(output_tensor, [0, 2, 1, 3])
|
||||||
|
return output_tensor
|
||||||
|
|
||||||
|
from_shape = get_shape_list(from_tensor, expected_rank=[2, 3])
|
||||||
|
to_shape = get_shape_list(to_tensor, expected_rank=[2, 3])
|
||||||
|
|
||||||
|
if len(from_shape) != len(to_shape):
|
||||||
|
raise ValueError(
|
||||||
|
"The rank of `from_tensor` must match the rank of `to_tensor`.")
|
||||||
|
|
||||||
|
if len(from_shape) == 3:
|
||||||
|
batch_size = from_shape[0]
|
||||||
|
from_seq_length = from_shape[1]
|
||||||
|
to_seq_length = to_shape[1]
|
||||||
|
elif len(from_shape) == 2:
|
||||||
|
if (batch_size is None or from_seq_length is None or to_seq_length is None):
|
||||||
|
raise ValueError(
|
||||||
|
"When passing in rank 2 tensors to attention_layer, the values "
|
||||||
|
"for `batch_size`, `from_seq_length`, and `to_seq_length` "
|
||||||
|
"must all be specified.")
|
||||||
|
|
||||||
|
# Scalar dimensions referenced here:
|
||||||
|
# B = batch size (number of sequences)
|
||||||
|
# F = `from_tensor` sequence length
|
||||||
|
# T = `to_tensor` sequence length
|
||||||
|
# N = `num_attention_heads`
|
||||||
|
# H = `size_per_head`
|
||||||
|
|
||||||
|
from_tensor_2d = reshape_to_matrix(from_tensor)
|
||||||
|
to_tensor_2d = reshape_to_matrix(to_tensor)
|
||||||
|
|
||||||
|
# `query_layer` = [B*F, N*H]
|
||||||
|
query_layer = tf.layers.dense(
|
||||||
|
from_tensor_2d,
|
||||||
|
num_attention_heads * size_per_head,
|
||||||
|
activation=query_act,
|
||||||
|
name="query",
|
||||||
|
kernel_initializer=create_initializer(initializer_range))
|
||||||
|
|
||||||
|
# `key_layer` = [B*T, N*H]
|
||||||
|
key_layer = tf.layers.dense(
|
||||||
|
to_tensor_2d,
|
||||||
|
num_attention_heads * size_per_head,
|
||||||
|
activation=key_act,
|
||||||
|
name="key",
|
||||||
|
kernel_initializer=create_initializer(initializer_range))
|
||||||
|
|
||||||
|
# `value_layer` = [B*T, N*H]
|
||||||
|
value_layer = tf.layers.dense(
|
||||||
|
to_tensor_2d,
|
||||||
|
num_attention_heads * size_per_head,
|
||||||
|
activation=value_act,
|
||||||
|
name="value",
|
||||||
|
kernel_initializer=create_initializer(initializer_range))
|
||||||
|
|
||||||
|
# `query_layer` = [B, N, F, H]
|
||||||
|
query_layer = transpose_for_scores(query_layer, batch_size,
|
||||||
|
num_attention_heads, from_seq_length,
|
||||||
|
size_per_head)
|
||||||
|
|
||||||
|
# `key_layer` = [B, N, T, H]
|
||||||
|
key_layer = transpose_for_scores(key_layer, batch_size, num_attention_heads,
|
||||||
|
to_seq_length, size_per_head)
|
||||||
|
|
||||||
|
# Take the dot product between "query" and "key" to get the raw
|
||||||
|
# attention scores.
|
||||||
|
# `attention_scores` = [B, N, F, T]
|
||||||
|
attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
|
||||||
|
attention_scores = tf.multiply(attention_scores,
|
||||||
|
1.0 / math.sqrt(float(size_per_head)))
|
||||||
|
|
||||||
|
if attention_mask is not None:
|
||||||
|
# `attention_mask` = [B, 1, F, T]
|
||||||
|
attention_mask = tf.expand_dims(attention_mask, axis=[1])
|
||||||
|
|
||||||
|
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
||||||
|
# masked positions, this operation will create a tensor which is 0.0 for
|
||||||
|
# positions we want to attend and -10000.0 for masked positions.
|
||||||
|
adder = (1.0 - tf.cast(attention_mask, tf.float32)) * -10000.0
|
||||||
|
|
||||||
|
# Since we are adding it to the raw scores before the softmax, this is
|
||||||
|
# effectively the same as removing these entirely.
|
||||||
|
attention_scores += adder
|
||||||
|
|
||||||
|
# Normalize the attention scores to probabilities.
|
||||||
|
# `attention_probs` = [B, N, F, T]
|
||||||
|
attention_probs = tf.nn.softmax(attention_scores)
|
||||||
|
|
||||||
|
# This is actually dropping out entire tokens to attend to, which might
|
||||||
|
# seem a bit unusual, but is taken from the original Transformer paper.
|
||||||
|
attention_probs = dropout(attention_probs, attention_probs_dropout_prob)
|
||||||
|
|
||||||
|
# `value_layer` = [B, T, N, H]
|
||||||
|
value_layer = tf.reshape(
|
||||||
|
value_layer,
|
||||||
|
[batch_size, to_seq_length, num_attention_heads, size_per_head])
|
||||||
|
|
||||||
|
# `value_layer` = [B, N, T, H]
|
||||||
|
value_layer = tf.transpose(value_layer, [0, 2, 1, 3])
|
||||||
|
|
||||||
|
# `context_layer` = [B, N, F, H]
|
||||||
|
context_layer = tf.matmul(attention_probs, value_layer)
|
||||||
|
|
||||||
|
# `context_layer` = [B, F, N, H]
|
||||||
|
context_layer = tf.transpose(context_layer, [0, 2, 1, 3])
|
||||||
|
|
||||||
|
if do_return_2d_tensor:
|
||||||
|
# `context_layer` = [B*F, N*H]
|
||||||
|
context_layer = tf.reshape(
|
||||||
|
context_layer,
|
||||||
|
[batch_size * from_seq_length, num_attention_heads * size_per_head])
|
||||||
|
else:
|
||||||
|
# `context_layer` = [B, F, N*H]
|
||||||
|
context_layer = tf.reshape(
|
||||||
|
context_layer,
|
||||||
|
[batch_size, from_seq_length, num_attention_heads * size_per_head])
|
||||||
|
|
||||||
|
return context_layer
|
||||||
|
|
||||||
|
|
||||||
|
def transformer_model(input_tensor,
|
||||||
|
attention_mask=None,
|
||||||
|
hidden_size=768,
|
||||||
|
num_hidden_layers=12,
|
||||||
|
num_attention_heads=12,
|
||||||
|
intermediate_size=3072,
|
||||||
|
intermediate_act_fn=gelu,
|
||||||
|
hidden_dropout_prob=0.1,
|
||||||
|
attention_probs_dropout_prob=0.1,
|
||||||
|
initializer_range=0.02,
|
||||||
|
do_return_all_layers=False):
|
||||||
|
"""Multi-headed, multi-layer Transformer from "Attention is All You Need".
|
||||||
|
|
||||||
|
This is almost an exact implementation of the original Transformer encoder.
|
||||||
|
|
||||||
|
See the original paper:
|
||||||
|
https://arxiv.org/abs/1706.03762
|
||||||
|
|
||||||
|
Also see:
|
||||||
|
https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/models/transformer.py
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_tensor: float Tensor of shape [batch_size, seq_length, hidden_size].
|
||||||
|
attention_mask: (optional) int32 Tensor of shape [batch_size, seq_length,
|
||||||
|
seq_length], with 1 for positions that can be attended to and 0 in
|
||||||
|
positions that should not be.
|
||||||
|
hidden_size: int. Hidden size of the Transformer.
|
||||||
|
num_hidden_layers: int. Number of layers (blocks) in the Transformer.
|
||||||
|
num_attention_heads: int. Number of attention heads in the Transformer.
|
||||||
|
intermediate_size: int. The size of the "intermediate" (a.k.a., feed
|
||||||
|
forward) layer.
|
||||||
|
intermediate_act_fn: function. The non-linear activation function to apply
|
||||||
|
to the output of the intermediate/feed-forward layer.
|
||||||
|
hidden_dropout_prob: float. Dropout probability for the hidden layers.
|
||||||
|
attention_probs_dropout_prob: float. Dropout probability of the attention
|
||||||
|
probabilities.
|
||||||
|
initializer_range: float. Range of the initializer (stddev of truncated
|
||||||
|
normal).
|
||||||
|
do_return_all_layers: Whether to also return all layers or just the final
|
||||||
|
layer.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
float Tensor of shape [batch_size, seq_length, hidden_size], the final
|
||||||
|
hidden layer of the Transformer.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: A Tensor shape or parameter is invalid.
|
||||||
|
"""
|
||||||
|
if hidden_size % num_attention_heads != 0:
|
||||||
|
raise ValueError(
|
||||||
|
"The hidden size (%d) is not a multiple of the number of attention "
|
||||||
|
"heads (%d)" % (hidden_size, num_attention_heads))
|
||||||
|
|
||||||
|
attention_head_size = int(hidden_size / num_attention_heads)
|
||||||
|
input_shape = get_shape_list(input_tensor, expected_rank=3)
|
||||||
|
batch_size = input_shape[0]
|
||||||
|
seq_length = input_shape[1]
|
||||||
|
input_width = input_shape[2]
|
||||||
|
|
||||||
|
# The Transformer performs sum residuals on all layers so the input needs
|
||||||
|
# to be the same as the hidden size.
|
||||||
|
if input_width != hidden_size:
|
||||||
|
raise ValueError("The width of the input tensor (%d) != hidden size (%d)" %
|
||||||
|
(input_width, hidden_size))
|
||||||
|
|
||||||
|
# We keep the representation as a 2D tensor to avoid re-shaping it back and
|
||||||
|
# forth from a 3D tensor to a 2D tensor. Re-shapes are normally free on
|
||||||
|
# the GPU/CPU but may not be free on the TPU, so we want to minimize them to
|
||||||
|
# help the optimizer.
|
||||||
|
prev_output = reshape_to_matrix(input_tensor)
|
||||||
|
|
||||||
|
all_layer_outputs = []
|
||||||
|
for layer_idx in range(num_hidden_layers):
|
||||||
|
with tf.variable_scope("layer_%d" % layer_idx):
|
||||||
|
layer_input = prev_output
|
||||||
|
|
||||||
|
with tf.variable_scope("attention"):
|
||||||
|
attention_heads = []
|
||||||
|
with tf.variable_scope("self"):
|
||||||
|
attention_head = attention_layer(
|
||||||
|
from_tensor=layer_input,
|
||||||
|
to_tensor=layer_input,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
num_attention_heads=num_attention_heads,
|
||||||
|
size_per_head=attention_head_size,
|
||||||
|
attention_probs_dropout_prob=attention_probs_dropout_prob,
|
||||||
|
initializer_range=initializer_range,
|
||||||
|
do_return_2d_tensor=True,
|
||||||
|
batch_size=batch_size,
|
||||||
|
from_seq_length=seq_length,
|
||||||
|
to_seq_length=seq_length)
|
||||||
|
attention_heads.append(attention_head)
|
||||||
|
|
||||||
|
attention_output = None
|
||||||
|
if len(attention_heads) == 1:
|
||||||
|
attention_output = attention_heads[0]
|
||||||
|
else:
|
||||||
|
# In the case where we have other sequences, we just concatenate
|
||||||
|
# them to the self-attention head before the projection.
|
||||||
|
attention_output = tf.concat(attention_heads, axis=-1)
|
||||||
|
|
||||||
|
# Run a linear projection of `hidden_size` then add a residual
|
||||||
|
# with `layer_input`.
|
||||||
|
with tf.variable_scope("output"):
|
||||||
|
attention_output = tf.layers.dense(
|
||||||
|
attention_output,
|
||||||
|
hidden_size,
|
||||||
|
kernel_initializer=create_initializer(initializer_range))
|
||||||
|
attention_output = dropout(attention_output, hidden_dropout_prob)
|
||||||
|
attention_output = layer_norm(attention_output + layer_input)
|
||||||
|
|
||||||
|
# The activation is only applied to the "intermediate" hidden layer.
|
||||||
|
with tf.variable_scope("intermediate"):
|
||||||
|
intermediate_output = tf.layers.dense(
|
||||||
|
attention_output,
|
||||||
|
intermediate_size,
|
||||||
|
activation=intermediate_act_fn,
|
||||||
|
kernel_initializer=create_initializer(initializer_range))
|
||||||
|
|
||||||
|
# Down-project back to `hidden_size` then add the residual.
|
||||||
|
with tf.variable_scope("output"):
|
||||||
|
layer_output = tf.layers.dense(
|
||||||
|
intermediate_output,
|
||||||
|
hidden_size,
|
||||||
|
kernel_initializer=create_initializer(initializer_range))
|
||||||
|
layer_output = dropout(layer_output, hidden_dropout_prob)
|
||||||
|
layer_output = layer_norm(layer_output + attention_output)
|
||||||
|
prev_output = layer_output
|
||||||
|
all_layer_outputs.append(layer_output)
|
||||||
|
|
||||||
|
if do_return_all_layers:
|
||||||
|
final_outputs = []
|
||||||
|
for layer_output in all_layer_outputs:
|
||||||
|
final_output = reshape_from_matrix(layer_output, input_shape)
|
||||||
|
final_outputs.append(final_output)
|
||||||
|
return final_outputs
|
||||||
|
else:
|
||||||
|
final_output = reshape_from_matrix(prev_output, input_shape)
|
||||||
|
return final_output
|
||||||
|
|
||||||
|
|
||||||
|
def get_shape_list(tensor, expected_rank=None, name=None):
|
||||||
|
"""Returns a list of the shape of tensor, preferring static dimensions.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tensor: A tf.Tensor object to find the shape of.
|
||||||
|
expected_rank: (optional) int. The expected rank of `tensor`. If this is
|
||||||
|
specified and the `tensor` has a different rank, and exception will be
|
||||||
|
thrown.
|
||||||
|
name: Optional name of the tensor for the error message.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of dimensions of the shape of tensor. All static dimensions will
|
||||||
|
be returned as python integers, and dynamic dimensions will be returned
|
||||||
|
as tf.Tensor scalars.
|
||||||
|
"""
|
||||||
|
if name is None:
|
||||||
|
name = tensor.name
|
||||||
|
|
||||||
|
if expected_rank is not None:
|
||||||
|
assert_rank(tensor, expected_rank, name)
|
||||||
|
|
||||||
|
shape = tensor.shape.as_list()
|
||||||
|
|
||||||
|
non_static_indexes = []
|
||||||
|
for (index, dim) in enumerate(shape):
|
||||||
|
if dim is None:
|
||||||
|
non_static_indexes.append(index)
|
||||||
|
|
||||||
|
if not non_static_indexes:
|
||||||
|
return shape
|
||||||
|
|
||||||
|
dyn_shape = tf.shape(tensor)
|
||||||
|
for index in non_static_indexes:
|
||||||
|
shape[index] = dyn_shape[index]
|
||||||
|
return shape
|
||||||
|
|
||||||
|
|
||||||
|
def reshape_to_matrix(input_tensor):
|
||||||
|
"""Reshapes a >= rank 2 tensor to a rank 2 tensor (i.e., a matrix)."""
|
||||||
|
ndims = input_tensor.shape.ndims
|
||||||
|
if ndims < 2:
|
||||||
|
raise ValueError("Input tensor must have at least rank 2. Shape = %s" %
|
||||||
|
(input_tensor.shape))
|
||||||
|
if ndims == 2:
|
||||||
|
return input_tensor
|
||||||
|
|
||||||
|
width = input_tensor.shape[-1]
|
||||||
|
output_tensor = tf.reshape(input_tensor, [-1, width])
|
||||||
|
return output_tensor
|
||||||
|
|
||||||
|
|
||||||
|
def reshape_from_matrix(output_tensor, orig_shape_list):
|
||||||
|
"""Reshapes a rank 2 tensor back to its original rank >= 2 tensor."""
|
||||||
|
if len(orig_shape_list) == 2:
|
||||||
|
return output_tensor
|
||||||
|
|
||||||
|
output_shape = get_shape_list(output_tensor)
|
||||||
|
|
||||||
|
orig_dims = orig_shape_list[0:-1]
|
||||||
|
width = output_shape[-1]
|
||||||
|
|
||||||
|
return tf.reshape(output_tensor, orig_dims + [width])
|
||||||
|
|
||||||
|
|
||||||
|
def assert_rank(tensor, expected_rank, name=None):
|
||||||
|
"""Raises an exception if the tensor rank is not of the expected rank.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tensor: A tf.Tensor to check the rank of.
|
||||||
|
expected_rank: Python integer or list of integers, expected rank.
|
||||||
|
name: Optional name of the tensor for the error message.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the expected shape doesn't match the actual shape.
|
||||||
|
"""
|
||||||
|
if name is None:
|
||||||
|
name = tensor.name
|
||||||
|
|
||||||
|
expected_rank_dict = {}
|
||||||
|
if isinstance(expected_rank, six.integer_types):
|
||||||
|
expected_rank_dict[expected_rank] = True
|
||||||
|
else:
|
||||||
|
for x in expected_rank:
|
||||||
|
expected_rank_dict[x] = True
|
||||||
|
|
||||||
|
actual_rank = tensor.shape.ndims
|
||||||
|
if actual_rank not in expected_rank_dict:
|
||||||
|
scope_name = tf.get_variable_scope().name
|
||||||
|
raise ValueError(
|
||||||
|
"For the tensor `%s` in scope `%s`, the actual rank "
|
||||||
|
"`%d` (shape = %s) is not equal to the expected rank `%s`" %
|
||||||
|
(name, scope_name, actual_rank, str(tensor.shape), str(expected_rank)))
|
277
modeling_test.py
Normal file
277
modeling_test.py
Normal file
@ -0,0 +1,277 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2018 The Google AI Language Team Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import collections
|
||||||
|
import json
|
||||||
|
import random
|
||||||
|
import re
|
||||||
|
|
||||||
|
import modeling
|
||||||
|
import six
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
|
||||||
|
class BertModelTest(tf.test.TestCase):
|
||||||
|
|
||||||
|
class BertModelTester(object):
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
parent,
|
||||||
|
batch_size=13,
|
||||||
|
seq_length=7,
|
||||||
|
is_training=True,
|
||||||
|
use_input_mask=True,
|
||||||
|
use_token_type_ids=True,
|
||||||
|
vocab_size=99,
|
||||||
|
hidden_size=32,
|
||||||
|
num_hidden_layers=5,
|
||||||
|
num_attention_heads=4,
|
||||||
|
intermediate_size=37,
|
||||||
|
hidden_act="gelu",
|
||||||
|
hidden_dropout_prob=0.1,
|
||||||
|
attention_probs_dropout_prob=0.1,
|
||||||
|
max_position_embeddings=512,
|
||||||
|
type_vocab_size=16,
|
||||||
|
initializer_range=0.02,
|
||||||
|
scope=None):
|
||||||
|
self.parent = parent
|
||||||
|
self.batch_size = batch_size
|
||||||
|
self.seq_length = seq_length
|
||||||
|
self.is_training = is_training
|
||||||
|
self.use_input_mask = use_input_mask
|
||||||
|
self.use_token_type_ids = use_token_type_ids
|
||||||
|
self.vocab_size = vocab_size
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.num_hidden_layers = num_hidden_layers
|
||||||
|
self.num_attention_heads = num_attention_heads
|
||||||
|
self.intermediate_size = intermediate_size
|
||||||
|
self.hidden_act = hidden_act
|
||||||
|
self.hidden_dropout_prob = hidden_dropout_prob
|
||||||
|
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
||||||
|
self.max_position_embeddings = max_position_embeddings
|
||||||
|
self.type_vocab_size = type_vocab_size
|
||||||
|
self.initializer_range = initializer_range
|
||||||
|
self.scope = scope
|
||||||
|
|
||||||
|
def create_model(self):
|
||||||
|
input_ids = BertModelTest.ids_tensor([self.batch_size, self.seq_length],
|
||||||
|
self.vocab_size)
|
||||||
|
|
||||||
|
input_mask = None
|
||||||
|
if self.use_input_mask:
|
||||||
|
input_mask = BertModelTest.ids_tensor(
|
||||||
|
[self.batch_size, self.seq_length], vocab_size=2)
|
||||||
|
|
||||||
|
token_type_ids = None
|
||||||
|
if self.use_token_type_ids:
|
||||||
|
token_type_ids = BertModelTest.ids_tensor(
|
||||||
|
[self.batch_size, self.seq_length], self.type_vocab_size)
|
||||||
|
|
||||||
|
config = modeling.BertConfig(
|
||||||
|
vocab_size=self.vocab_size,
|
||||||
|
hidden_size=self.hidden_size,
|
||||||
|
num_hidden_layers=self.num_hidden_layers,
|
||||||
|
num_attention_heads=self.num_attention_heads,
|
||||||
|
intermediate_size=self.intermediate_size,
|
||||||
|
hidden_act=self.hidden_act,
|
||||||
|
hidden_dropout_prob=self.hidden_dropout_prob,
|
||||||
|
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
|
||||||
|
max_position_embeddings=self.max_position_embeddings,
|
||||||
|
type_vocab_size=self.type_vocab_size,
|
||||||
|
initializer_range=self.initializer_range)
|
||||||
|
|
||||||
|
model = modeling.BertModel(
|
||||||
|
config=config,
|
||||||
|
is_training=self.is_training,
|
||||||
|
input_ids=input_ids,
|
||||||
|
input_mask=input_mask,
|
||||||
|
token_type_ids=token_type_ids,
|
||||||
|
scope=self.scope)
|
||||||
|
|
||||||
|
outputs = {
|
||||||
|
"embedding_output": model.get_embedding_output(),
|
||||||
|
"sequence_output": model.get_sequence_output(),
|
||||||
|
"pooled_output": model.get_pooled_output(),
|
||||||
|
"all_encoder_layers": model.get_all_encoder_layers(),
|
||||||
|
}
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
def check_output(self, result):
|
||||||
|
self.parent.assertAllEqual(
|
||||||
|
result["embedding_output"].shape,
|
||||||
|
[self.batch_size, self.seq_length, self.hidden_size])
|
||||||
|
|
||||||
|
self.parent.assertAllEqual(
|
||||||
|
result["sequence_output"].shape,
|
||||||
|
[self.batch_size, self.seq_length, self.hidden_size])
|
||||||
|
|
||||||
|
self.parent.assertAllEqual(result["pooled_output"].shape,
|
||||||
|
[self.batch_size, self.hidden_size])
|
||||||
|
|
||||||
|
def test_default(self):
|
||||||
|
self.run_tester(BertModelTest.BertModelTester(self))
|
||||||
|
|
||||||
|
def test_config_to_json_string(self):
|
||||||
|
config = modeling.BertConfig(vocab_size=99, hidden_size=37)
|
||||||
|
obj = json.loads(config.to_json_string())
|
||||||
|
self.assertEqual(obj["vocab_size"], 99)
|
||||||
|
self.assertEqual(obj["hidden_size"], 37)
|
||||||
|
|
||||||
|
def run_tester(self, tester):
|
||||||
|
with self.test_session() as sess:
|
||||||
|
ops = tester.create_model()
|
||||||
|
init_op = tf.group(tf.global_variables_initializer(),
|
||||||
|
tf.local_variables_initializer())
|
||||||
|
sess.run(init_op)
|
||||||
|
output_result = sess.run(ops)
|
||||||
|
tester.check_output(output_result)
|
||||||
|
|
||||||
|
self.assert_all_tensors_reachable(sess, [init_op, ops])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def ids_tensor(cls, shape, vocab_size, rng=None, name=None):
|
||||||
|
"""Creates a random int32 tensor of the shape within the vocab size."""
|
||||||
|
if rng is None:
|
||||||
|
rng = random.Random()
|
||||||
|
|
||||||
|
total_dims = 1
|
||||||
|
for dim in shape:
|
||||||
|
total_dims *= dim
|
||||||
|
|
||||||
|
values = []
|
||||||
|
for _ in range(total_dims):
|
||||||
|
values.append(rng.randint(0, vocab_size - 1))
|
||||||
|
|
||||||
|
return tf.constant(value=values, dtype=tf.int32, shape=shape, name=name)
|
||||||
|
|
||||||
|
def assert_all_tensors_reachable(self, sess, outputs):
|
||||||
|
"""Checks that all the tensors in the graph are reachable from outputs."""
|
||||||
|
graph = sess.graph
|
||||||
|
|
||||||
|
ignore_strings = [
|
||||||
|
"^.*/assert_less_equal/.*$",
|
||||||
|
"^.*/dilation_rate$",
|
||||||
|
"^.*/Tensordot/concat$",
|
||||||
|
"^.*/Tensordot/concat/axis$",
|
||||||
|
"^testing/.*$",
|
||||||
|
]
|
||||||
|
|
||||||
|
ignore_regexes = [re.compile(x) for x in ignore_strings]
|
||||||
|
|
||||||
|
unreachable = self.get_unreachable_ops(graph, outputs)
|
||||||
|
filtered_unreachable = []
|
||||||
|
for x in unreachable:
|
||||||
|
do_ignore = False
|
||||||
|
for r in ignore_regexes:
|
||||||
|
m = r.match(x.name)
|
||||||
|
if m is not None:
|
||||||
|
do_ignore = True
|
||||||
|
if do_ignore:
|
||||||
|
continue
|
||||||
|
filtered_unreachable.append(x)
|
||||||
|
unreachable = filtered_unreachable
|
||||||
|
|
||||||
|
self.assertEqual(
|
||||||
|
len(unreachable), 0, "The following ops are unreachable: %s" %
|
||||||
|
(" ".join([x.name for x in unreachable])))
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_unreachable_ops(cls, graph, outputs):
|
||||||
|
"""Finds all of the tensors in graph that are unreachable from outputs."""
|
||||||
|
outputs = cls.flatten_recursive(outputs)
|
||||||
|
output_to_op = collections.defaultdict(list)
|
||||||
|
op_to_all = collections.defaultdict(list)
|
||||||
|
assign_out_to_in = collections.defaultdict(list)
|
||||||
|
|
||||||
|
for op in graph.get_operations():
|
||||||
|
for x in op.inputs:
|
||||||
|
op_to_all[op.name].append(x.name)
|
||||||
|
for y in op.outputs:
|
||||||
|
output_to_op[y.name].append(op.name)
|
||||||
|
op_to_all[op.name].append(y.name)
|
||||||
|
if str(op.type) == "Assign":
|
||||||
|
for y in op.outputs:
|
||||||
|
for x in op.inputs:
|
||||||
|
assign_out_to_in[y.name].append(x.name)
|
||||||
|
|
||||||
|
assign_groups = collections.defaultdict(list)
|
||||||
|
for out_name in assign_out_to_in.keys():
|
||||||
|
name_group = assign_out_to_in[out_name]
|
||||||
|
for n1 in name_group:
|
||||||
|
assign_groups[n1].append(out_name)
|
||||||
|
for n2 in name_group:
|
||||||
|
if n1 != n2:
|
||||||
|
assign_groups[n1].append(n2)
|
||||||
|
|
||||||
|
seen_tensors = {}
|
||||||
|
stack = [x.name for x in outputs]
|
||||||
|
while stack:
|
||||||
|
name = stack.pop()
|
||||||
|
if name in seen_tensors:
|
||||||
|
continue
|
||||||
|
seen_tensors[name] = True
|
||||||
|
|
||||||
|
if name in output_to_op:
|
||||||
|
for op_name in output_to_op[name]:
|
||||||
|
if op_name in op_to_all:
|
||||||
|
for input_name in op_to_all[op_name]:
|
||||||
|
if input_name not in stack:
|
||||||
|
stack.append(input_name)
|
||||||
|
|
||||||
|
expanded_names = []
|
||||||
|
if name in assign_groups:
|
||||||
|
for assign_name in assign_groups[name]:
|
||||||
|
expanded_names.append(assign_name)
|
||||||
|
|
||||||
|
for expanded_name in expanded_names:
|
||||||
|
if expanded_name not in stack:
|
||||||
|
stack.append(expanded_name)
|
||||||
|
|
||||||
|
unreachable_ops = []
|
||||||
|
for op in graph.get_operations():
|
||||||
|
is_unreachable = False
|
||||||
|
all_names = [x.name for x in op.inputs] + [x.name for x in op.outputs]
|
||||||
|
for name in all_names:
|
||||||
|
if name not in seen_tensors:
|
||||||
|
is_unreachable = True
|
||||||
|
if is_unreachable:
|
||||||
|
unreachable_ops.append(op)
|
||||||
|
return unreachable_ops
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def flatten_recursive(cls, item):
|
||||||
|
"""Flattens (potentially nested) a tuple/dictionary/list to a list."""
|
||||||
|
output = []
|
||||||
|
if isinstance(item, list):
|
||||||
|
output.extend(item)
|
||||||
|
elif isinstance(item, tuple):
|
||||||
|
output.extend(list(item))
|
||||||
|
elif isinstance(item, dict):
|
||||||
|
for (_, v) in six.iteritems(item):
|
||||||
|
output.append(v)
|
||||||
|
else:
|
||||||
|
return [item]
|
||||||
|
|
||||||
|
flat_output = []
|
||||||
|
for x in output:
|
||||||
|
flat_output.extend(cls.flatten_recursive(x))
|
||||||
|
return flat_output
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
tf.test.main()
|
174
optimization.py
Normal file
174
optimization.py
Normal file
@ -0,0 +1,174 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2018 The Google AI Language Team Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Functions and classes related to optimization (weight updates)."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import re
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
|
||||||
|
def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, use_tpu):
|
||||||
|
"""Creates an optimizer training op."""
|
||||||
|
global_step = tf.train.get_or_create_global_step()
|
||||||
|
|
||||||
|
learning_rate = tf.constant(value=init_lr, shape=[], dtype=tf.float32)
|
||||||
|
|
||||||
|
# Implements linear decay of the learning rate.
|
||||||
|
learning_rate = tf.train.polynomial_decay(
|
||||||
|
learning_rate,
|
||||||
|
global_step,
|
||||||
|
num_train_steps,
|
||||||
|
end_learning_rate=0.0,
|
||||||
|
power=1.0,
|
||||||
|
cycle=False)
|
||||||
|
|
||||||
|
# Implements linear warmup. I.e., if global_step < num_warmup_steps, the
|
||||||
|
# learning rate will be `global_step/num_warmup_steps * init_lr`.
|
||||||
|
if num_warmup_steps:
|
||||||
|
global_steps_int = tf.cast(global_step, tf.int32)
|
||||||
|
warmup_steps_int = tf.constant(num_warmup_steps, dtype=tf.int32)
|
||||||
|
|
||||||
|
global_steps_float = tf.cast(global_steps_int, tf.float32)
|
||||||
|
warmup_steps_float = tf.cast(warmup_steps_int, tf.float32)
|
||||||
|
|
||||||
|
warmup_percent_done = global_steps_float / warmup_steps_float
|
||||||
|
warmup_learning_rate = init_lr * warmup_percent_done
|
||||||
|
|
||||||
|
is_warmup = tf.cast(global_steps_int < warmup_steps_int, tf.float32)
|
||||||
|
learning_rate = (
|
||||||
|
(1.0 - is_warmup) * learning_rate + is_warmup * warmup_learning_rate)
|
||||||
|
|
||||||
|
# It is recommended that you use this optimizer for fine tuning, since this
|
||||||
|
# is how the model was trained (note that the Adam m/v variables are NOT
|
||||||
|
# loaded from init_checkpoint.)
|
||||||
|
optimizer = AdamWeightDecayOptimizer(
|
||||||
|
learning_rate=learning_rate,
|
||||||
|
weight_decay_rate=0.01,
|
||||||
|
beta_1=0.9,
|
||||||
|
beta_2=0.999,
|
||||||
|
epsilon=1e-6,
|
||||||
|
exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"])
|
||||||
|
|
||||||
|
if use_tpu:
|
||||||
|
optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)
|
||||||
|
|
||||||
|
tvars = tf.trainable_variables()
|
||||||
|
grads = tf.gradients(loss, tvars)
|
||||||
|
|
||||||
|
# This is how the model was pre-trained.
|
||||||
|
(grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0)
|
||||||
|
|
||||||
|
train_op = optimizer.apply_gradients(
|
||||||
|
zip(grads, tvars), global_step=global_step)
|
||||||
|
|
||||||
|
# Normally the global step update is done inside of `apply_gradients`.
|
||||||
|
# However, `AdamWeightDecayOptimizer` doesn't do this. But if you use
|
||||||
|
# a different optimizer, you should probably take this line out.
|
||||||
|
new_global_step = global_step + 1
|
||||||
|
train_op = tf.group(train_op, [global_step.assign(new_global_step)])
|
||||||
|
return train_op
|
||||||
|
|
||||||
|
|
||||||
|
class AdamWeightDecayOptimizer(tf.train.Optimizer):
|
||||||
|
"""A basic Adam optimizer that includes "correct" L2 weight decay."""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
learning_rate,
|
||||||
|
weight_decay_rate=0.0,
|
||||||
|
beta_1=0.9,
|
||||||
|
beta_2=0.999,
|
||||||
|
epsilon=1e-6,
|
||||||
|
exclude_from_weight_decay=None,
|
||||||
|
name="AdamWeightDecayOptimizer"):
|
||||||
|
"""Constructs a AdamWeightDecayOptimizer."""
|
||||||
|
super(AdamWeightDecayOptimizer, self).__init__(False, name)
|
||||||
|
|
||||||
|
self.learning_rate = learning_rate
|
||||||
|
self.weight_decay_rate = weight_decay_rate
|
||||||
|
self.beta_1 = beta_1
|
||||||
|
self.beta_2 = beta_2
|
||||||
|
self.epsilon = epsilon
|
||||||
|
self.exclude_from_weight_decay = exclude_from_weight_decay
|
||||||
|
|
||||||
|
def apply_gradients(self, grads_and_vars, global_step=None, name=None):
|
||||||
|
"""See base class."""
|
||||||
|
assignments = []
|
||||||
|
for (grad, param) in grads_and_vars:
|
||||||
|
if grad is None or param is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
param_name = self._get_variable_name(param.name)
|
||||||
|
|
||||||
|
m = tf.get_variable(
|
||||||
|
name=param_name + "/adam_m",
|
||||||
|
shape=param.shape.as_list(),
|
||||||
|
dtype=tf.float32,
|
||||||
|
trainable=False,
|
||||||
|
initializer=tf.zeros_initializer())
|
||||||
|
v = tf.get_variable(
|
||||||
|
name=param_name + "/adam_v",
|
||||||
|
shape=param.shape.as_list(),
|
||||||
|
dtype=tf.float32,
|
||||||
|
trainable=False,
|
||||||
|
initializer=tf.zeros_initializer())
|
||||||
|
|
||||||
|
# Standard Adam update.
|
||||||
|
next_m = (
|
||||||
|
tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad))
|
||||||
|
next_v = (
|
||||||
|
tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2,
|
||||||
|
tf.square(grad)))
|
||||||
|
|
||||||
|
update = next_m / (tf.sqrt(next_v) + self.epsilon)
|
||||||
|
|
||||||
|
# Just adding the square of the weights to the loss function is *not*
|
||||||
|
# the correct way of using L2 regularization/weight decay with Adam,
|
||||||
|
# since that will interact with the m and v parameters in strange ways.
|
||||||
|
#
|
||||||
|
# Instead we want ot decay the weights in a manner that doesn't interact
|
||||||
|
# with the m/v parameters. This is equivalent to adding the square
|
||||||
|
# of the weights to the loss with plain (non-momentum) SGD.
|
||||||
|
if self._do_use_weight_decay(param_name):
|
||||||
|
update += self.weight_decay_rate * param
|
||||||
|
|
||||||
|
update_with_lr = self.learning_rate * update
|
||||||
|
|
||||||
|
next_param = param - update_with_lr
|
||||||
|
|
||||||
|
assignments.extend(
|
||||||
|
[param.assign(next_param),
|
||||||
|
m.assign(next_m),
|
||||||
|
v.assign(next_v)])
|
||||||
|
return tf.group(*assignments, name=name)
|
||||||
|
|
||||||
|
def _do_use_weight_decay(self, param_name):
|
||||||
|
"""Whether to use L2 weight decay for `param_name`."""
|
||||||
|
if not self.weight_decay_rate:
|
||||||
|
return False
|
||||||
|
if self.exclude_from_weight_decay:
|
||||||
|
for r in self.exclude_from_weight_decay:
|
||||||
|
if re.search(r, param_name) is not None:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _get_variable_name(self, param_name):
|
||||||
|
"""Get the variable name from the tensor name."""
|
||||||
|
m = re.match("^(.*):\\d+$", param_name)
|
||||||
|
if m is not None:
|
||||||
|
param_name = m.group(1)
|
||||||
|
return param_name
|
48
optimization_test.py
Normal file
48
optimization_test.py
Normal file
@ -0,0 +1,48 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2018 The Google AI Language Team Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import optimization
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
|
||||||
|
class OptimizationTest(tf.test.TestCase):
|
||||||
|
|
||||||
|
def test_adam(self):
|
||||||
|
with self.test_session() as sess:
|
||||||
|
w = tf.get_variable(
|
||||||
|
"w",
|
||||||
|
shape=[3],
|
||||||
|
initializer=tf.constant_initializer([0.1, -0.2, -0.1]))
|
||||||
|
x = tf.constant([0.4, 0.2, -0.5])
|
||||||
|
loss = tf.reduce_mean(tf.square(x - w))
|
||||||
|
tvars = tf.trainable_variables()
|
||||||
|
grads = tf.gradients(loss, tvars)
|
||||||
|
global_step = tf.train.get_or_create_global_step()
|
||||||
|
optimizer = optimization.AdamWeightDecayOptimizer(learning_rate=0.2)
|
||||||
|
train_op = optimizer.apply_gradients(zip(grads, tvars), global_step)
|
||||||
|
init_op = tf.group(tf.global_variables_initializer(),
|
||||||
|
tf.local_variables_initializer())
|
||||||
|
sess.run(init_op)
|
||||||
|
for _ in range(100):
|
||||||
|
sess.run(train_op)
|
||||||
|
w_np = sess.run(w)
|
||||||
|
self.assertAllClose(w_np.flat, [0.4, 0.2, -0.5], rtol=1e-2, atol=1e-2)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
tf.test.main()
|
1231
predicting_movie_reviews_with_bert_on_tf_hub.ipynb
Normal file
1231
predicting_movie_reviews_with_bert_on_tf_hub.ipynb
Normal file
File diff suppressed because it is too large
Load Diff
1056
run_classifier.py
Normal file
1056
run_classifier.py
Normal file
File diff suppressed because it is too large
Load Diff
314
run_classifier_with_tfhub.py
Normal file
314
run_classifier_with_tfhub.py
Normal file
@ -0,0 +1,314 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2018 The Google AI Language Team Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""BERT finetuning runner with TF-Hub."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import os
|
||||||
|
import optimization
|
||||||
|
import run_classifier
|
||||||
|
import tokenization
|
||||||
|
import tensorflow as tf
|
||||||
|
import tensorflow_hub as hub
|
||||||
|
|
||||||
|
flags = tf.flags
|
||||||
|
|
||||||
|
FLAGS = flags.FLAGS
|
||||||
|
|
||||||
|
flags.DEFINE_string(
|
||||||
|
"bert_hub_module_handle", None,
|
||||||
|
"Handle for the BERT TF-Hub module.")
|
||||||
|
|
||||||
|
|
||||||
|
def create_model(is_training, input_ids, input_mask, segment_ids, labels,
|
||||||
|
num_labels, bert_hub_module_handle):
|
||||||
|
"""Creates a classification model."""
|
||||||
|
tags = set()
|
||||||
|
if is_training:
|
||||||
|
tags.add("train")
|
||||||
|
bert_module = hub.Module(bert_hub_module_handle, tags=tags, trainable=True)
|
||||||
|
bert_inputs = dict(
|
||||||
|
input_ids=input_ids,
|
||||||
|
input_mask=input_mask,
|
||||||
|
segment_ids=segment_ids)
|
||||||
|
bert_outputs = bert_module(
|
||||||
|
inputs=bert_inputs,
|
||||||
|
signature="tokens",
|
||||||
|
as_dict=True)
|
||||||
|
|
||||||
|
# In the demo, we are doing a simple classification task on the entire
|
||||||
|
# segment.
|
||||||
|
#
|
||||||
|
# If you want to use the token-level output, use
|
||||||
|
# bert_outputs["sequence_output"] instead.
|
||||||
|
output_layer = bert_outputs["pooled_output"]
|
||||||
|
|
||||||
|
hidden_size = output_layer.shape[-1].value
|
||||||
|
|
||||||
|
output_weights = tf.get_variable(
|
||||||
|
"output_weights", [num_labels, hidden_size],
|
||||||
|
initializer=tf.truncated_normal_initializer(stddev=0.02))
|
||||||
|
|
||||||
|
output_bias = tf.get_variable(
|
||||||
|
"output_bias", [num_labels], initializer=tf.zeros_initializer())
|
||||||
|
|
||||||
|
with tf.variable_scope("loss"):
|
||||||
|
if is_training:
|
||||||
|
# I.e., 0.1 dropout
|
||||||
|
output_layer = tf.nn.dropout(output_layer, keep_prob=0.9)
|
||||||
|
|
||||||
|
logits = tf.matmul(output_layer, output_weights, transpose_b=True)
|
||||||
|
logits = tf.nn.bias_add(logits, output_bias)
|
||||||
|
probabilities = tf.nn.softmax(logits, axis=-1)
|
||||||
|
log_probs = tf.nn.log_softmax(logits, axis=-1)
|
||||||
|
|
||||||
|
one_hot_labels = tf.one_hot(labels, depth=num_labels, dtype=tf.float32)
|
||||||
|
|
||||||
|
per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)
|
||||||
|
loss = tf.reduce_mean(per_example_loss)
|
||||||
|
|
||||||
|
return (loss, per_example_loss, logits, probabilities)
|
||||||
|
|
||||||
|
|
||||||
|
def model_fn_builder(num_labels, learning_rate, num_train_steps,
|
||||||
|
num_warmup_steps, use_tpu, bert_hub_module_handle):
|
||||||
|
"""Returns `model_fn` closure for TPUEstimator."""
|
||||||
|
|
||||||
|
def model_fn(features, labels, mode, params): # pylint: disable=unused-argument
|
||||||
|
"""The `model_fn` for TPUEstimator."""
|
||||||
|
|
||||||
|
tf.logging.info("*** Features ***")
|
||||||
|
for name in sorted(features.keys()):
|
||||||
|
tf.logging.info(" name = %s, shape = %s" % (name, features[name].shape))
|
||||||
|
|
||||||
|
input_ids = features["input_ids"]
|
||||||
|
input_mask = features["input_mask"]
|
||||||
|
segment_ids = features["segment_ids"]
|
||||||
|
label_ids = features["label_ids"]
|
||||||
|
|
||||||
|
is_training = (mode == tf.estimator.ModeKeys.TRAIN)
|
||||||
|
|
||||||
|
(total_loss, per_example_loss, logits, probabilities) = create_model(
|
||||||
|
is_training, input_ids, input_mask, segment_ids, label_ids, num_labels,
|
||||||
|
bert_hub_module_handle)
|
||||||
|
|
||||||
|
output_spec = None
|
||||||
|
if mode == tf.estimator.ModeKeys.TRAIN:
|
||||||
|
train_op = optimization.create_optimizer(
|
||||||
|
total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu)
|
||||||
|
|
||||||
|
output_spec = tf.contrib.tpu.TPUEstimatorSpec(
|
||||||
|
mode=mode,
|
||||||
|
loss=total_loss,
|
||||||
|
train_op=train_op)
|
||||||
|
elif mode == tf.estimator.ModeKeys.EVAL:
|
||||||
|
|
||||||
|
def metric_fn(per_example_loss, label_ids, logits):
|
||||||
|
predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
|
||||||
|
accuracy = tf.metrics.accuracy(label_ids, predictions)
|
||||||
|
loss = tf.metrics.mean(per_example_loss)
|
||||||
|
return {
|
||||||
|
"eval_accuracy": accuracy,
|
||||||
|
"eval_loss": loss,
|
||||||
|
}
|
||||||
|
|
||||||
|
eval_metrics = (metric_fn, [per_example_loss, label_ids, logits])
|
||||||
|
output_spec = tf.contrib.tpu.TPUEstimatorSpec(
|
||||||
|
mode=mode,
|
||||||
|
loss=total_loss,
|
||||||
|
eval_metrics=eval_metrics)
|
||||||
|
elif mode == tf.estimator.ModeKeys.PREDICT:
|
||||||
|
output_spec = tf.contrib.tpu.TPUEstimatorSpec(
|
||||||
|
mode=mode, predictions={"probabilities": probabilities})
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"Only TRAIN, EVAL and PREDICT modes are supported: %s" % (mode))
|
||||||
|
|
||||||
|
return output_spec
|
||||||
|
|
||||||
|
return model_fn
|
||||||
|
|
||||||
|
|
||||||
|
def create_tokenizer_from_hub_module(bert_hub_module_handle):
|
||||||
|
"""Get the vocab file and casing info from the Hub module."""
|
||||||
|
with tf.Graph().as_default():
|
||||||
|
bert_module = hub.Module(bert_hub_module_handle)
|
||||||
|
tokenization_info = bert_module(signature="tokenization_info", as_dict=True)
|
||||||
|
with tf.Session() as sess:
|
||||||
|
vocab_file, do_lower_case = sess.run([tokenization_info["vocab_file"],
|
||||||
|
tokenization_info["do_lower_case"]])
|
||||||
|
return tokenization.FullTokenizer(
|
||||||
|
vocab_file=vocab_file, do_lower_case=do_lower_case)
|
||||||
|
|
||||||
|
|
||||||
|
def main(_):
|
||||||
|
tf.logging.set_verbosity(tf.logging.INFO)
|
||||||
|
|
||||||
|
processors = {
|
||||||
|
"cola": run_classifier.ColaProcessor,
|
||||||
|
"mnli": run_classifier.MnliProcessor,
|
||||||
|
"mrpc": run_classifier.MrpcProcessor,
|
||||||
|
}
|
||||||
|
|
||||||
|
if not FLAGS.do_train and not FLAGS.do_eval:
|
||||||
|
raise ValueError("At least one of `do_train` or `do_eval` must be True.")
|
||||||
|
|
||||||
|
tf.gfile.MakeDirs(FLAGS.output_dir)
|
||||||
|
|
||||||
|
task_name = FLAGS.task_name.lower()
|
||||||
|
|
||||||
|
if task_name not in processors:
|
||||||
|
raise ValueError("Task not found: %s" % (task_name))
|
||||||
|
|
||||||
|
processor = processors[task_name]()
|
||||||
|
|
||||||
|
label_list = processor.get_labels()
|
||||||
|
|
||||||
|
tokenizer = create_tokenizer_from_hub_module(FLAGS.bert_hub_module_handle)
|
||||||
|
|
||||||
|
tpu_cluster_resolver = None
|
||||||
|
if FLAGS.use_tpu and FLAGS.tpu_name:
|
||||||
|
tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
|
||||||
|
FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
|
||||||
|
|
||||||
|
is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
|
||||||
|
run_config = tf.contrib.tpu.RunConfig(
|
||||||
|
cluster=tpu_cluster_resolver,
|
||||||
|
master=FLAGS.master,
|
||||||
|
model_dir=FLAGS.output_dir,
|
||||||
|
save_checkpoints_steps=FLAGS.save_checkpoints_steps,
|
||||||
|
tpu_config=tf.contrib.tpu.TPUConfig(
|
||||||
|
iterations_per_loop=FLAGS.iterations_per_loop,
|
||||||
|
num_shards=FLAGS.num_tpu_cores,
|
||||||
|
per_host_input_for_training=is_per_host))
|
||||||
|
|
||||||
|
train_examples = None
|
||||||
|
num_train_steps = None
|
||||||
|
num_warmup_steps = None
|
||||||
|
if FLAGS.do_train:
|
||||||
|
train_examples = processor.get_train_examples(FLAGS.data_dir)
|
||||||
|
num_train_steps = int(
|
||||||
|
len(train_examples) / FLAGS.train_batch_size * FLAGS.num_train_epochs)
|
||||||
|
num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)
|
||||||
|
|
||||||
|
model_fn = model_fn_builder(
|
||||||
|
num_labels=len(label_list),
|
||||||
|
learning_rate=FLAGS.learning_rate,
|
||||||
|
num_train_steps=num_train_steps,
|
||||||
|
num_warmup_steps=num_warmup_steps,
|
||||||
|
use_tpu=FLAGS.use_tpu,
|
||||||
|
bert_hub_module_handle=FLAGS.bert_hub_module_handle)
|
||||||
|
|
||||||
|
# If TPU is not available, this will fall back to normal Estimator on CPU
|
||||||
|
# or GPU.
|
||||||
|
estimator = tf.contrib.tpu.TPUEstimator(
|
||||||
|
use_tpu=FLAGS.use_tpu,
|
||||||
|
model_fn=model_fn,
|
||||||
|
config=run_config,
|
||||||
|
train_batch_size=FLAGS.train_batch_size,
|
||||||
|
eval_batch_size=FLAGS.eval_batch_size,
|
||||||
|
predict_batch_size=FLAGS.predict_batch_size)
|
||||||
|
|
||||||
|
if FLAGS.do_train:
|
||||||
|
train_features = run_classifier.convert_examples_to_features(
|
||||||
|
train_examples, label_list, FLAGS.max_seq_length, tokenizer)
|
||||||
|
tf.logging.info("***** Running training *****")
|
||||||
|
tf.logging.info(" Num examples = %d", len(train_examples))
|
||||||
|
tf.logging.info(" Batch size = %d", FLAGS.train_batch_size)
|
||||||
|
tf.logging.info(" Num steps = %d", num_train_steps)
|
||||||
|
train_input_fn = run_classifier.input_fn_builder(
|
||||||
|
features=train_features,
|
||||||
|
seq_length=FLAGS.max_seq_length,
|
||||||
|
is_training=True,
|
||||||
|
drop_remainder=True)
|
||||||
|
estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)
|
||||||
|
|
||||||
|
if FLAGS.do_eval:
|
||||||
|
eval_examples = processor.get_dev_examples(FLAGS.data_dir)
|
||||||
|
eval_features = run_classifier.convert_examples_to_features(
|
||||||
|
eval_examples, label_list, FLAGS.max_seq_length, tokenizer)
|
||||||
|
|
||||||
|
tf.logging.info("***** Running evaluation *****")
|
||||||
|
tf.logging.info(" Num examples = %d", len(eval_examples))
|
||||||
|
tf.logging.info(" Batch size = %d", FLAGS.eval_batch_size)
|
||||||
|
|
||||||
|
# This tells the estimator to run through the entire set.
|
||||||
|
eval_steps = None
|
||||||
|
# However, if running eval on the TPU, you will need to specify the
|
||||||
|
# number of steps.
|
||||||
|
if FLAGS.use_tpu:
|
||||||
|
# Eval will be slightly WRONG on the TPU because it will truncate
|
||||||
|
# the last batch.
|
||||||
|
eval_steps = int(len(eval_examples) / FLAGS.eval_batch_size)
|
||||||
|
|
||||||
|
eval_drop_remainder = True if FLAGS.use_tpu else False
|
||||||
|
eval_input_fn = run_classifier.input_fn_builder(
|
||||||
|
features=eval_features,
|
||||||
|
seq_length=FLAGS.max_seq_length,
|
||||||
|
is_training=False,
|
||||||
|
drop_remainder=eval_drop_remainder)
|
||||||
|
|
||||||
|
result = estimator.evaluate(input_fn=eval_input_fn, steps=eval_steps)
|
||||||
|
|
||||||
|
output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt")
|
||||||
|
with tf.gfile.GFile(output_eval_file, "w") as writer:
|
||||||
|
tf.logging.info("***** Eval results *****")
|
||||||
|
for key in sorted(result.keys()):
|
||||||
|
tf.logging.info(" %s = %s", key, str(result[key]))
|
||||||
|
writer.write("%s = %s\n" % (key, str(result[key])))
|
||||||
|
|
||||||
|
if FLAGS.do_predict:
|
||||||
|
predict_examples = processor.get_test_examples(FLAGS.data_dir)
|
||||||
|
if FLAGS.use_tpu:
|
||||||
|
# Discard batch remainder if running on TPU
|
||||||
|
n = len(predict_examples)
|
||||||
|
predict_examples = predict_examples[:(n - n % FLAGS.predict_batch_size)]
|
||||||
|
|
||||||
|
predict_file = os.path.join(FLAGS.output_dir, "predict.tf_record")
|
||||||
|
run_classifier.file_based_convert_examples_to_features(
|
||||||
|
predict_examples, label_list, FLAGS.max_seq_length, tokenizer,
|
||||||
|
predict_file)
|
||||||
|
|
||||||
|
tf.logging.info("***** Running prediction*****")
|
||||||
|
tf.logging.info(" Num examples = %d", len(predict_examples))
|
||||||
|
tf.logging.info(" Batch size = %d", FLAGS.predict_batch_size)
|
||||||
|
|
||||||
|
predict_input_fn = run_classifier.file_based_input_fn_builder(
|
||||||
|
input_file=predict_file,
|
||||||
|
seq_length=FLAGS.max_seq_length,
|
||||||
|
is_training=False,
|
||||||
|
drop_remainder=FLAGS.use_tpu)
|
||||||
|
|
||||||
|
result = estimator.predict(input_fn=predict_input_fn)
|
||||||
|
|
||||||
|
output_predict_file = os.path.join(FLAGS.output_dir, "test_results.tsv")
|
||||||
|
with tf.gfile.GFile(output_predict_file, "w") as writer:
|
||||||
|
tf.logging.info("***** Predict results *****")
|
||||||
|
for prediction in result:
|
||||||
|
probabilities = prediction["probabilities"]
|
||||||
|
output_line = "\t".join(
|
||||||
|
str(class_probability)
|
||||||
|
for class_probability in probabilities) + "\n"
|
||||||
|
writer.write(output_line)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
flags.mark_flag_as_required("data_dir")
|
||||||
|
flags.mark_flag_as_required("task_name")
|
||||||
|
flags.mark_flag_as_required("bert_hub_module_handle")
|
||||||
|
flags.mark_flag_as_required("output_dir")
|
||||||
|
tf.app.run()
|
493
run_pretraining.py
Normal file
493
run_pretraining.py
Normal file
@ -0,0 +1,493 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2018 The Google AI Language Team Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Run masked LM/next sentence masked_lm pre-training for BERT."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import os
|
||||||
|
import modeling
|
||||||
|
import optimization
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
flags = tf.flags
|
||||||
|
|
||||||
|
FLAGS = flags.FLAGS
|
||||||
|
|
||||||
|
## Required parameters
|
||||||
|
flags.DEFINE_string(
|
||||||
|
"bert_config_file", None,
|
||||||
|
"The config json file corresponding to the pre-trained BERT model. "
|
||||||
|
"This specifies the model architecture.")
|
||||||
|
|
||||||
|
flags.DEFINE_string(
|
||||||
|
"input_file", None,
|
||||||
|
"Input TF example files (can be a glob or comma separated).")
|
||||||
|
|
||||||
|
flags.DEFINE_string(
|
||||||
|
"output_dir", None,
|
||||||
|
"The output directory where the model checkpoints will be written.")
|
||||||
|
|
||||||
|
## Other parameters
|
||||||
|
flags.DEFINE_string(
|
||||||
|
"init_checkpoint", None,
|
||||||
|
"Initial checkpoint (usually from a pre-trained BERT model).")
|
||||||
|
|
||||||
|
flags.DEFINE_integer(
|
||||||
|
"max_seq_length", 128,
|
||||||
|
"The maximum total input sequence length after WordPiece tokenization. "
|
||||||
|
"Sequences longer than this will be truncated, and sequences shorter "
|
||||||
|
"than this will be padded. Must match data generation.")
|
||||||
|
|
||||||
|
flags.DEFINE_integer(
|
||||||
|
"max_predictions_per_seq", 20,
|
||||||
|
"Maximum number of masked LM predictions per sequence. "
|
||||||
|
"Must match data generation.")
|
||||||
|
|
||||||
|
flags.DEFINE_bool("do_train", False, "Whether to run training.")
|
||||||
|
|
||||||
|
flags.DEFINE_bool("do_eval", False, "Whether to run eval on the dev set.")
|
||||||
|
|
||||||
|
flags.DEFINE_integer("train_batch_size", 32, "Total batch size for training.")
|
||||||
|
|
||||||
|
flags.DEFINE_integer("eval_batch_size", 8, "Total batch size for eval.")
|
||||||
|
|
||||||
|
flags.DEFINE_float("learning_rate", 5e-5, "The initial learning rate for Adam.")
|
||||||
|
|
||||||
|
flags.DEFINE_integer("num_train_steps", 100000, "Number of training steps.")
|
||||||
|
|
||||||
|
flags.DEFINE_integer("num_warmup_steps", 10000, "Number of warmup steps.")
|
||||||
|
|
||||||
|
flags.DEFINE_integer("save_checkpoints_steps", 1000,
|
||||||
|
"How often to save the model checkpoint.")
|
||||||
|
|
||||||
|
flags.DEFINE_integer("iterations_per_loop", 1000,
|
||||||
|
"How many steps to make in each estimator call.")
|
||||||
|
|
||||||
|
flags.DEFINE_integer("max_eval_steps", 100, "Maximum number of eval steps.")
|
||||||
|
|
||||||
|
flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.")
|
||||||
|
|
||||||
|
tf.flags.DEFINE_string(
|
||||||
|
"tpu_name", None,
|
||||||
|
"The Cloud TPU to use for training. This should be either the name "
|
||||||
|
"used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 "
|
||||||
|
"url.")
|
||||||
|
|
||||||
|
tf.flags.DEFINE_string(
|
||||||
|
"tpu_zone", None,
|
||||||
|
"[Optional] GCE zone where the Cloud TPU is located in. If not "
|
||||||
|
"specified, we will attempt to automatically detect the GCE project from "
|
||||||
|
"metadata.")
|
||||||
|
|
||||||
|
tf.flags.DEFINE_string(
|
||||||
|
"gcp_project", None,
|
||||||
|
"[Optional] Project name for the Cloud TPU-enabled project. If not "
|
||||||
|
"specified, we will attempt to automatically detect the GCE project from "
|
||||||
|
"metadata.")
|
||||||
|
|
||||||
|
tf.flags.DEFINE_string("master", None, "[Optional] TensorFlow master URL.")
|
||||||
|
|
||||||
|
flags.DEFINE_integer(
|
||||||
|
"num_tpu_cores", 8,
|
||||||
|
"Only used if `use_tpu` is True. Total number of TPU cores to use.")
|
||||||
|
|
||||||
|
|
||||||
|
def model_fn_builder(bert_config, init_checkpoint, learning_rate,
|
||||||
|
num_train_steps, num_warmup_steps, use_tpu,
|
||||||
|
use_one_hot_embeddings):
|
||||||
|
"""Returns `model_fn` closure for TPUEstimator."""
|
||||||
|
|
||||||
|
def model_fn(features, labels, mode, params): # pylint: disable=unused-argument
|
||||||
|
"""The `model_fn` for TPUEstimator."""
|
||||||
|
|
||||||
|
tf.logging.info("*** Features ***")
|
||||||
|
for name in sorted(features.keys()):
|
||||||
|
tf.logging.info(" name = %s, shape = %s" % (name, features[name].shape))
|
||||||
|
|
||||||
|
input_ids = features["input_ids"]
|
||||||
|
input_mask = features["input_mask"]
|
||||||
|
segment_ids = features["segment_ids"]
|
||||||
|
masked_lm_positions = features["masked_lm_positions"]
|
||||||
|
masked_lm_ids = features["masked_lm_ids"]
|
||||||
|
masked_lm_weights = features["masked_lm_weights"]
|
||||||
|
next_sentence_labels = features["next_sentence_labels"]
|
||||||
|
|
||||||
|
is_training = (mode == tf.estimator.ModeKeys.TRAIN)
|
||||||
|
|
||||||
|
model = modeling.BertModel(
|
||||||
|
config=bert_config,
|
||||||
|
is_training=is_training,
|
||||||
|
input_ids=input_ids,
|
||||||
|
input_mask=input_mask,
|
||||||
|
token_type_ids=segment_ids,
|
||||||
|
use_one_hot_embeddings=use_one_hot_embeddings)
|
||||||
|
|
||||||
|
(masked_lm_loss,
|
||||||
|
masked_lm_example_loss, masked_lm_log_probs) = get_masked_lm_output(
|
||||||
|
bert_config, model.get_sequence_output(), model.get_embedding_table(),
|
||||||
|
masked_lm_positions, masked_lm_ids, masked_lm_weights)
|
||||||
|
|
||||||
|
(next_sentence_loss, next_sentence_example_loss,
|
||||||
|
next_sentence_log_probs) = get_next_sentence_output(
|
||||||
|
bert_config, model.get_pooled_output(), next_sentence_labels)
|
||||||
|
|
||||||
|
total_loss = masked_lm_loss + next_sentence_loss
|
||||||
|
|
||||||
|
tvars = tf.trainable_variables()
|
||||||
|
|
||||||
|
initialized_variable_names = {}
|
||||||
|
scaffold_fn = None
|
||||||
|
if init_checkpoint:
|
||||||
|
(assignment_map, initialized_variable_names
|
||||||
|
) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint)
|
||||||
|
if use_tpu:
|
||||||
|
|
||||||
|
def tpu_scaffold():
|
||||||
|
tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
|
||||||
|
return tf.train.Scaffold()
|
||||||
|
|
||||||
|
scaffold_fn = tpu_scaffold
|
||||||
|
else:
|
||||||
|
tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
|
||||||
|
|
||||||
|
tf.logging.info("**** Trainable Variables ****")
|
||||||
|
for var in tvars:
|
||||||
|
init_string = ""
|
||||||
|
if var.name in initialized_variable_names:
|
||||||
|
init_string = ", *INIT_FROM_CKPT*"
|
||||||
|
tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape,
|
||||||
|
init_string)
|
||||||
|
|
||||||
|
output_spec = None
|
||||||
|
if mode == tf.estimator.ModeKeys.TRAIN:
|
||||||
|
train_op = optimization.create_optimizer(
|
||||||
|
total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu)
|
||||||
|
|
||||||
|
output_spec = tf.contrib.tpu.TPUEstimatorSpec(
|
||||||
|
mode=mode,
|
||||||
|
loss=total_loss,
|
||||||
|
train_op=train_op,
|
||||||
|
scaffold_fn=scaffold_fn)
|
||||||
|
elif mode == tf.estimator.ModeKeys.EVAL:
|
||||||
|
|
||||||
|
def metric_fn(masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids,
|
||||||
|
masked_lm_weights, next_sentence_example_loss,
|
||||||
|
next_sentence_log_probs, next_sentence_labels):
|
||||||
|
"""Computes the loss and accuracy of the model."""
|
||||||
|
masked_lm_log_probs = tf.reshape(masked_lm_log_probs,
|
||||||
|
[-1, masked_lm_log_probs.shape[-1]])
|
||||||
|
masked_lm_predictions = tf.argmax(
|
||||||
|
masked_lm_log_probs, axis=-1, output_type=tf.int32)
|
||||||
|
masked_lm_example_loss = tf.reshape(masked_lm_example_loss, [-1])
|
||||||
|
masked_lm_ids = tf.reshape(masked_lm_ids, [-1])
|
||||||
|
masked_lm_weights = tf.reshape(masked_lm_weights, [-1])
|
||||||
|
masked_lm_accuracy = tf.metrics.accuracy(
|
||||||
|
labels=masked_lm_ids,
|
||||||
|
predictions=masked_lm_predictions,
|
||||||
|
weights=masked_lm_weights)
|
||||||
|
masked_lm_mean_loss = tf.metrics.mean(
|
||||||
|
values=masked_lm_example_loss, weights=masked_lm_weights)
|
||||||
|
|
||||||
|
next_sentence_log_probs = tf.reshape(
|
||||||
|
next_sentence_log_probs, [-1, next_sentence_log_probs.shape[-1]])
|
||||||
|
next_sentence_predictions = tf.argmax(
|
||||||
|
next_sentence_log_probs, axis=-1, output_type=tf.int32)
|
||||||
|
next_sentence_labels = tf.reshape(next_sentence_labels, [-1])
|
||||||
|
next_sentence_accuracy = tf.metrics.accuracy(
|
||||||
|
labels=next_sentence_labels, predictions=next_sentence_predictions)
|
||||||
|
next_sentence_mean_loss = tf.metrics.mean(
|
||||||
|
values=next_sentence_example_loss)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"masked_lm_accuracy": masked_lm_accuracy,
|
||||||
|
"masked_lm_loss": masked_lm_mean_loss,
|
||||||
|
"next_sentence_accuracy": next_sentence_accuracy,
|
||||||
|
"next_sentence_loss": next_sentence_mean_loss,
|
||||||
|
}
|
||||||
|
|
||||||
|
eval_metrics = (metric_fn, [
|
||||||
|
masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids,
|
||||||
|
masked_lm_weights, next_sentence_example_loss,
|
||||||
|
next_sentence_log_probs, next_sentence_labels
|
||||||
|
])
|
||||||
|
output_spec = tf.contrib.tpu.TPUEstimatorSpec(
|
||||||
|
mode=mode,
|
||||||
|
loss=total_loss,
|
||||||
|
eval_metrics=eval_metrics,
|
||||||
|
scaffold_fn=scaffold_fn)
|
||||||
|
else:
|
||||||
|
raise ValueError("Only TRAIN and EVAL modes are supported: %s" % (mode))
|
||||||
|
|
||||||
|
return output_spec
|
||||||
|
|
||||||
|
return model_fn
|
||||||
|
|
||||||
|
|
||||||
|
def get_masked_lm_output(bert_config, input_tensor, output_weights, positions,
|
||||||
|
label_ids, label_weights):
|
||||||
|
"""Get loss and log probs for the masked LM."""
|
||||||
|
input_tensor = gather_indexes(input_tensor, positions)
|
||||||
|
|
||||||
|
with tf.variable_scope("cls/predictions"):
|
||||||
|
# We apply one more non-linear transformation before the output layer.
|
||||||
|
# This matrix is not used after pre-training.
|
||||||
|
with tf.variable_scope("transform"):
|
||||||
|
input_tensor = tf.layers.dense(
|
||||||
|
input_tensor,
|
||||||
|
units=bert_config.hidden_size,
|
||||||
|
activation=modeling.get_activation(bert_config.hidden_act),
|
||||||
|
kernel_initializer=modeling.create_initializer(
|
||||||
|
bert_config.initializer_range))
|
||||||
|
input_tensor = modeling.layer_norm(input_tensor)
|
||||||
|
|
||||||
|
# The output weights are the same as the input embeddings, but there is
|
||||||
|
# an output-only bias for each token.
|
||||||
|
output_bias = tf.get_variable(
|
||||||
|
"output_bias",
|
||||||
|
shape=[bert_config.vocab_size],
|
||||||
|
initializer=tf.zeros_initializer())
|
||||||
|
logits = tf.matmul(input_tensor, output_weights, transpose_b=True)
|
||||||
|
logits = tf.nn.bias_add(logits, output_bias)
|
||||||
|
log_probs = tf.nn.log_softmax(logits, axis=-1)
|
||||||
|
|
||||||
|
label_ids = tf.reshape(label_ids, [-1])
|
||||||
|
label_weights = tf.reshape(label_weights, [-1])
|
||||||
|
|
||||||
|
one_hot_labels = tf.one_hot(
|
||||||
|
label_ids, depth=bert_config.vocab_size, dtype=tf.float32)
|
||||||
|
|
||||||
|
# The `positions` tensor might be zero-padded (if the sequence is too
|
||||||
|
# short to have the maximum number of predictions). The `label_weights`
|
||||||
|
# tensor has a value of 1.0 for every real prediction and 0.0 for the
|
||||||
|
# padding predictions.
|
||||||
|
per_example_loss = -tf.reduce_sum(log_probs * one_hot_labels, axis=[-1])
|
||||||
|
numerator = tf.reduce_sum(label_weights * per_example_loss)
|
||||||
|
denominator = tf.reduce_sum(label_weights) + 1e-5
|
||||||
|
loss = numerator / denominator
|
||||||
|
|
||||||
|
return (loss, per_example_loss, log_probs)
|
||||||
|
|
||||||
|
|
||||||
|
def get_next_sentence_output(bert_config, input_tensor, labels):
|
||||||
|
"""Get loss and log probs for the next sentence prediction."""
|
||||||
|
|
||||||
|
# Simple binary classification. Note that 0 is "next sentence" and 1 is
|
||||||
|
# "random sentence". This weight matrix is not used after pre-training.
|
||||||
|
with tf.variable_scope("cls/seq_relationship"):
|
||||||
|
output_weights = tf.get_variable(
|
||||||
|
"output_weights",
|
||||||
|
shape=[2, bert_config.hidden_size],
|
||||||
|
initializer=modeling.create_initializer(bert_config.initializer_range))
|
||||||
|
output_bias = tf.get_variable(
|
||||||
|
"output_bias", shape=[2], initializer=tf.zeros_initializer())
|
||||||
|
|
||||||
|
logits = tf.matmul(input_tensor, output_weights, transpose_b=True)
|
||||||
|
logits = tf.nn.bias_add(logits, output_bias)
|
||||||
|
log_probs = tf.nn.log_softmax(logits, axis=-1)
|
||||||
|
labels = tf.reshape(labels, [-1])
|
||||||
|
one_hot_labels = tf.one_hot(labels, depth=2, dtype=tf.float32)
|
||||||
|
per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)
|
||||||
|
loss = tf.reduce_mean(per_example_loss)
|
||||||
|
return (loss, per_example_loss, log_probs)
|
||||||
|
|
||||||
|
|
||||||
|
def gather_indexes(sequence_tensor, positions):
|
||||||
|
"""Gathers the vectors at the specific positions over a minibatch."""
|
||||||
|
sequence_shape = modeling.get_shape_list(sequence_tensor, expected_rank=3)
|
||||||
|
batch_size = sequence_shape[0]
|
||||||
|
seq_length = sequence_shape[1]
|
||||||
|
width = sequence_shape[2]
|
||||||
|
|
||||||
|
flat_offsets = tf.reshape(
|
||||||
|
tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1])
|
||||||
|
flat_positions = tf.reshape(positions + flat_offsets, [-1])
|
||||||
|
flat_sequence_tensor = tf.reshape(sequence_tensor,
|
||||||
|
[batch_size * seq_length, width])
|
||||||
|
output_tensor = tf.gather(flat_sequence_tensor, flat_positions)
|
||||||
|
return output_tensor
|
||||||
|
|
||||||
|
|
||||||
|
def input_fn_builder(input_files,
|
||||||
|
max_seq_length,
|
||||||
|
max_predictions_per_seq,
|
||||||
|
is_training,
|
||||||
|
num_cpu_threads=4):
|
||||||
|
"""Creates an `input_fn` closure to be passed to TPUEstimator."""
|
||||||
|
|
||||||
|
def input_fn(params):
|
||||||
|
"""The actual input function."""
|
||||||
|
batch_size = params["batch_size"]
|
||||||
|
|
||||||
|
name_to_features = {
|
||||||
|
"input_ids":
|
||||||
|
tf.FixedLenFeature([max_seq_length], tf.int64),
|
||||||
|
"input_mask":
|
||||||
|
tf.FixedLenFeature([max_seq_length], tf.int64),
|
||||||
|
"segment_ids":
|
||||||
|
tf.FixedLenFeature([max_seq_length], tf.int64),
|
||||||
|
"masked_lm_positions":
|
||||||
|
tf.FixedLenFeature([max_predictions_per_seq], tf.int64),
|
||||||
|
"masked_lm_ids":
|
||||||
|
tf.FixedLenFeature([max_predictions_per_seq], tf.int64),
|
||||||
|
"masked_lm_weights":
|
||||||
|
tf.FixedLenFeature([max_predictions_per_seq], tf.float32),
|
||||||
|
"next_sentence_labels":
|
||||||
|
tf.FixedLenFeature([1], tf.int64),
|
||||||
|
}
|
||||||
|
|
||||||
|
# For training, we want a lot of parallel reading and shuffling.
|
||||||
|
# For eval, we want no shuffling and parallel reading doesn't matter.
|
||||||
|
if is_training:
|
||||||
|
d = tf.data.Dataset.from_tensor_slices(tf.constant(input_files))
|
||||||
|
d = d.repeat()
|
||||||
|
d = d.shuffle(buffer_size=len(input_files))
|
||||||
|
|
||||||
|
# `cycle_length` is the number of parallel files that get read.
|
||||||
|
cycle_length = min(num_cpu_threads, len(input_files))
|
||||||
|
|
||||||
|
# `sloppy` mode means that the interleaving is not exact. This adds
|
||||||
|
# even more randomness to the training pipeline.
|
||||||
|
d = d.apply(
|
||||||
|
tf.contrib.data.parallel_interleave(
|
||||||
|
tf.data.TFRecordDataset,
|
||||||
|
sloppy=is_training,
|
||||||
|
cycle_length=cycle_length))
|
||||||
|
d = d.shuffle(buffer_size=100)
|
||||||
|
else:
|
||||||
|
d = tf.data.TFRecordDataset(input_files)
|
||||||
|
# Since we evaluate for a fixed number of steps we don't want to encounter
|
||||||
|
# out-of-range exceptions.
|
||||||
|
d = d.repeat()
|
||||||
|
|
||||||
|
# We must `drop_remainder` on training because the TPU requires fixed
|
||||||
|
# size dimensions. For eval, we assume we are evaluating on the CPU or GPU
|
||||||
|
# and we *don't* want to drop the remainder, otherwise we wont cover
|
||||||
|
# every sample.
|
||||||
|
d = d.apply(
|
||||||
|
tf.contrib.data.map_and_batch(
|
||||||
|
lambda record: _decode_record(record, name_to_features),
|
||||||
|
batch_size=batch_size,
|
||||||
|
num_parallel_batches=num_cpu_threads,
|
||||||
|
drop_remainder=True))
|
||||||
|
return d
|
||||||
|
|
||||||
|
return input_fn
|
||||||
|
|
||||||
|
|
||||||
|
def _decode_record(record, name_to_features):
|
||||||
|
"""Decodes a record to a TensorFlow example."""
|
||||||
|
example = tf.parse_single_example(record, name_to_features)
|
||||||
|
|
||||||
|
# tf.Example only supports tf.int64, but the TPU only supports tf.int32.
|
||||||
|
# So cast all int64 to int32.
|
||||||
|
for name in list(example.keys()):
|
||||||
|
t = example[name]
|
||||||
|
if t.dtype == tf.int64:
|
||||||
|
t = tf.to_int32(t)
|
||||||
|
example[name] = t
|
||||||
|
|
||||||
|
return example
|
||||||
|
|
||||||
|
|
||||||
|
def main(_):
|
||||||
|
tf.logging.set_verbosity(tf.logging.INFO)
|
||||||
|
|
||||||
|
if not FLAGS.do_train and not FLAGS.do_eval:
|
||||||
|
raise ValueError("At least one of `do_train` or `do_eval` must be True.")
|
||||||
|
|
||||||
|
bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
|
||||||
|
|
||||||
|
tf.gfile.MakeDirs(FLAGS.output_dir)
|
||||||
|
|
||||||
|
input_files = []
|
||||||
|
for input_pattern in FLAGS.input_file.split(","):
|
||||||
|
input_files.extend(tf.gfile.Glob(input_pattern))
|
||||||
|
|
||||||
|
tf.logging.info("*** Input Files ***")
|
||||||
|
for input_file in input_files:
|
||||||
|
tf.logging.info(" %s" % input_file)
|
||||||
|
|
||||||
|
tpu_cluster_resolver = None
|
||||||
|
if FLAGS.use_tpu and FLAGS.tpu_name:
|
||||||
|
tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
|
||||||
|
FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
|
||||||
|
|
||||||
|
is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
|
||||||
|
run_config = tf.contrib.tpu.RunConfig(
|
||||||
|
cluster=tpu_cluster_resolver,
|
||||||
|
master=FLAGS.master,
|
||||||
|
model_dir=FLAGS.output_dir,
|
||||||
|
save_checkpoints_steps=FLAGS.save_checkpoints_steps,
|
||||||
|
tpu_config=tf.contrib.tpu.TPUConfig(
|
||||||
|
iterations_per_loop=FLAGS.iterations_per_loop,
|
||||||
|
num_shards=FLAGS.num_tpu_cores,
|
||||||
|
per_host_input_for_training=is_per_host))
|
||||||
|
|
||||||
|
model_fn = model_fn_builder(
|
||||||
|
bert_config=bert_config,
|
||||||
|
init_checkpoint=FLAGS.init_checkpoint,
|
||||||
|
learning_rate=FLAGS.learning_rate,
|
||||||
|
num_train_steps=FLAGS.num_train_steps,
|
||||||
|
num_warmup_steps=FLAGS.num_warmup_steps,
|
||||||
|
use_tpu=FLAGS.use_tpu,
|
||||||
|
use_one_hot_embeddings=FLAGS.use_tpu)
|
||||||
|
|
||||||
|
# If TPU is not available, this will fall back to normal Estimator on CPU
|
||||||
|
# or GPU.
|
||||||
|
estimator = tf.contrib.tpu.TPUEstimator(
|
||||||
|
use_tpu=FLAGS.use_tpu,
|
||||||
|
model_fn=model_fn,
|
||||||
|
config=run_config,
|
||||||
|
train_batch_size=FLAGS.train_batch_size,
|
||||||
|
eval_batch_size=FLAGS.eval_batch_size)
|
||||||
|
|
||||||
|
if FLAGS.do_train:
|
||||||
|
tf.logging.info("***** Running training *****")
|
||||||
|
tf.logging.info(" Batch size = %d", FLAGS.train_batch_size)
|
||||||
|
train_input_fn = input_fn_builder(
|
||||||
|
input_files=input_files,
|
||||||
|
max_seq_length=FLAGS.max_seq_length,
|
||||||
|
max_predictions_per_seq=FLAGS.max_predictions_per_seq,
|
||||||
|
is_training=True)
|
||||||
|
estimator.train(input_fn=train_input_fn, max_steps=FLAGS.num_train_steps)
|
||||||
|
|
||||||
|
if FLAGS.do_eval:
|
||||||
|
tf.logging.info("***** Running evaluation *****")
|
||||||
|
tf.logging.info(" Batch size = %d", FLAGS.eval_batch_size)
|
||||||
|
|
||||||
|
eval_input_fn = input_fn_builder(
|
||||||
|
input_files=input_files,
|
||||||
|
max_seq_length=FLAGS.max_seq_length,
|
||||||
|
max_predictions_per_seq=FLAGS.max_predictions_per_seq,
|
||||||
|
is_training=False)
|
||||||
|
|
||||||
|
result = estimator.evaluate(
|
||||||
|
input_fn=eval_input_fn, steps=FLAGS.max_eval_steps)
|
||||||
|
|
||||||
|
output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt")
|
||||||
|
with tf.gfile.GFile(output_eval_file, "w") as writer:
|
||||||
|
tf.logging.info("***** Eval results *****")
|
||||||
|
for key in sorted(result.keys()):
|
||||||
|
tf.logging.info(" %s = %s", key, str(result[key]))
|
||||||
|
writer.write("%s = %s\n" % (key, str(result[key])))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
flags.mark_flag_as_required("input_file")
|
||||||
|
flags.mark_flag_as_required("bert_config_file")
|
||||||
|
flags.mark_flag_as_required("output_dir")
|
||||||
|
tf.app.run()
|
1283
run_squad.py
Normal file
1283
run_squad.py
Normal file
File diff suppressed because it is too large
Load Diff
417
server.py
Normal file
417
server.py
Normal file
@ -0,0 +1,417 @@
|
|||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
|
||||||
|
import requests
|
||||||
|
import datetime
|
||||||
|
import time
|
||||||
|
import hashlib
|
||||||
|
import sqlite3
|
||||||
|
import pandas
|
||||||
|
import threading
|
||||||
|
import logging as log
|
||||||
|
|
||||||
|
server_url = "http://39.100.94.111:8083"
|
||||||
|
openid = "gpu-server-test1"
|
||||||
|
password = "1e327b070ab43fd071768a4d474f016adbbf3ea475577fe66a505d9e33b24f2f"
|
||||||
|
token = None
|
||||||
|
# 客户端代码
|
||||||
|
client_code = "dc9fbb4f4f0b84fa903058991af60e73556494af8a02ef69fb6a93217729f04b"
|
||||||
|
# 护照认证码
|
||||||
|
idcode = None
|
||||||
|
# 时间戳
|
||||||
|
timestamp = ""
|
||||||
|
# 单次最大处理句数
|
||||||
|
max_stn_num = 20000
|
||||||
|
# 当前处理的bpt的序号
|
||||||
|
bpt_id = 0
|
||||||
|
# STNS
|
||||||
|
stn_list = []
|
||||||
|
# 输入数据存储表
|
||||||
|
predict_table = "predict_data"
|
||||||
|
# 模型处理结果输出文件夹
|
||||||
|
result_out_dir = "./tmp/eppredict"
|
||||||
|
# 初始化标志位
|
||||||
|
base_init = False
|
||||||
|
|
||||||
|
log.basicConfig(filename=None, format="%(asctime)s %(levelname)s [%(funcName)s] : %(message)s", level=log.INFO)
|
||||||
|
|
||||||
|
|
||||||
|
def get_timestamp():
|
||||||
|
return str(int(time.mktime(datetime.datetime.now().timetuple())) * 1000)
|
||||||
|
|
||||||
|
|
||||||
|
base_headers = {"timestamp": get_timestamp(), "X-Requested-With": ""}
|
||||||
|
token_headers = {"timestamp": get_timestamp(), "X-Requested-With": "", "signed": "", "openid": openid}
|
||||||
|
|
||||||
|
|
||||||
|
# url对象
|
||||||
|
def url_parser(url):
|
||||||
|
return server_url + "/" + url
|
||||||
|
|
||||||
|
|
||||||
|
# 计算随机特征值
|
||||||
|
def calculate_random_code():
|
||||||
|
return hashlib.sha1("RandomCode [{0}][{1}][{2}]".format(openid, get_timestamp(), client_code).encode("utf-8")) \
|
||||||
|
.hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
# 计算客户端签名
|
||||||
|
def calculate_signed():
|
||||||
|
return hashlib.sha1("SIGN [{0}][{1}][{2}]".format(openid, calculate_random_code(), token).encode("utf-8")) \
|
||||||
|
.hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
# 检查用户是否存在
|
||||||
|
def user_checker():
|
||||||
|
log.info("Check User Existence: openid" + str(openid))
|
||||||
|
checker_param = {"openid": openid}
|
||||||
|
base_headers["timestamp"] = get_timestamp()
|
||||||
|
res = requests.get(url=url_parser("user"), headers=base_headers, params=checker_param)
|
||||||
|
if res.status_code == 404:
|
||||||
|
log.warning("User Not Exist: openid" + str(openid))
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
log.info("User Exist: openid " + str(openid))
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
# 注册用户
|
||||||
|
def user_register():
|
||||||
|
if not user_checker():
|
||||||
|
log.info("Try Creating New User: openid " + str(openid))
|
||||||
|
register_json = {"openid": openid, "password": password}
|
||||||
|
register_param = {"clientCode": client_code}
|
||||||
|
base_headers["timestamp"] = get_timestamp()
|
||||||
|
res = requests.post(url=url_parser("user/cs"), headers=base_headers, json=register_json, params=register_param)
|
||||||
|
respond_json = res.json()
|
||||||
|
if res.status_code == 201 and respond_json["openid"] == openid:
|
||||||
|
log.info("User Creation Success: openid " + str(openid))
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
log.error("User Creation Failed: openid " + str(openid))
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
# 获得token
|
||||||
|
def get_token():
|
||||||
|
if user_checker():
|
||||||
|
log.info("Try Getting New Token")
|
||||||
|
login_json = {"openid": openid, "password": password, "clientCode": client_code}
|
||||||
|
res = requests.post(url=url_parser("user/login"), headers=base_headers, json=login_json)
|
||||||
|
respond_json = res.json()
|
||||||
|
if res.status_code == 200 and respond_json["info"] == "Authentication Success":
|
||||||
|
global token
|
||||||
|
token = respond_json["data"]["token"]
|
||||||
|
log.info("Succeed In Getting New Token" + str(token))
|
||||||
|
else:
|
||||||
|
if base_init is True:
|
||||||
|
user_register()
|
||||||
|
log.error("Fail To Get New Token")
|
||||||
|
|
||||||
|
|
||||||
|
# 获得子服务器护照
|
||||||
|
def get_csp():
|
||||||
|
global idcode
|
||||||
|
if token is not None:
|
||||||
|
log.info("Try Getting New CSP")
|
||||||
|
# 计算客户端签名
|
||||||
|
token_headers["signed"] = calculate_signed()
|
||||||
|
token_headers["timestamp"] = get_timestamp()
|
||||||
|
res = requests.post(url=url_parser("cs"), headers=token_headers)
|
||||||
|
respond_json = res.json()
|
||||||
|
log.debug(respond_json)
|
||||||
|
# 正常返回
|
||||||
|
if res.status_code == 200:
|
||||||
|
# 无权限检查
|
||||||
|
try:
|
||||||
|
idcode = respond_json["identityCode"]
|
||||||
|
log.info("Succeed In Getting CSP: idcode " + str(idcode))
|
||||||
|
except KeyError:
|
||||||
|
if respond_json["status"] == 401:
|
||||||
|
log.warning("Token OUT OF DATE: token " + str(token))
|
||||||
|
get_token()
|
||||||
|
return
|
||||||
|
|
||||||
|
# 无权限返回
|
||||||
|
elif res.status_code == 401:
|
||||||
|
# 重新获取token
|
||||||
|
log.warning("Token Maybe OUT OF DATE: token " + str(token))
|
||||||
|
log.info("Try to Get New Token")
|
||||||
|
get_token()
|
||||||
|
else:
|
||||||
|
log.error("Failed to get New CSP")
|
||||||
|
else:
|
||||||
|
get_token()
|
||||||
|
|
||||||
|
|
||||||
|
# 更新签证
|
||||||
|
def update_csp():
|
||||||
|
if idcode is not None:
|
||||||
|
token_headers["signed"] = calculate_signed()
|
||||||
|
token_headers["timestamp"] = get_timestamp()
|
||||||
|
res = requests.put(url=url_parser("cs"), headers=token_headers, params={"idcode": idcode})
|
||||||
|
respond_json = res.json()
|
||||||
|
log.debug(respond_json)
|
||||||
|
# 成功返回
|
||||||
|
if res.status_code == 200 and respond_json["expired"] is False:
|
||||||
|
log.info("Succeed IN Updating CSP: idcode " + str(idcode))
|
||||||
|
log.info("CSP Last Update Time: " + str(respond_json["lastUpdateTime"]))
|
||||||
|
elif res.status_code == 401:
|
||||||
|
# 尝试获得新的token
|
||||||
|
log.warning("Unauthorized Status Code: Try to Get New Token")
|
||||||
|
get_token()
|
||||||
|
else:
|
||||||
|
# 重新获得护照
|
||||||
|
log.warning("CSP Maybe OUT OF DATE: idcode " + str(idcode))
|
||||||
|
get_csp()
|
||||||
|
|
||||||
|
|
||||||
|
# 放弃批处理任务
|
||||||
|
def giving_up_bpt():
|
||||||
|
global bpt_id
|
||||||
|
global stn_list
|
||||||
|
try_count = 3
|
||||||
|
while try_count < 3:
|
||||||
|
try_count += 1
|
||||||
|
# 标记任务执行失败
|
||||||
|
res = requests.put(url=url_parser("cs/bpt"),
|
||||||
|
headers=token_headers,
|
||||||
|
params={"idcode": idcode, "bptId": bpt_id, "status": False},
|
||||||
|
json=[])
|
||||||
|
|
||||||
|
if res.status_code == 201:
|
||||||
|
log.info("Marking Task Failed Successful: bertId ", bpt_id)
|
||||||
|
return True
|
||||||
|
elif res.status_code == 401:
|
||||||
|
# 尝试获得新的token
|
||||||
|
log.warning("Unauthorized Status Code: Try to Get New Token")
|
||||||
|
get_token()
|
||||||
|
else:
|
||||||
|
if try_count >= 3:
|
||||||
|
log.error("Marking Task Failed Eventually Failed: bertId ", bpt_id)
|
||||||
|
log.warning("Connection Maybe Unstable")
|
||||||
|
return False
|
||||||
|
log.warning("Failed and Try: count " + str(try_count))
|
||||||
|
|
||||||
|
# 清空计算数据
|
||||||
|
bpt_id = None
|
||||||
|
stn_list = []
|
||||||
|
|
||||||
|
|
||||||
|
# 从主服务器获得批处理任务
|
||||||
|
def get_bpt_from_server():
|
||||||
|
global max_stn_num
|
||||||
|
global idcode
|
||||||
|
if idcode is not None:
|
||||||
|
log.info("Try Getting BPT From Server...")
|
||||||
|
token_headers["signed"] = calculate_signed()
|
||||||
|
token_headers["timestamp"] = get_timestamp()
|
||||||
|
res = requests.get(url=url_parser("cs/bpt"),
|
||||||
|
headers=token_headers,
|
||||||
|
params={"idcode": idcode, "maxStnNum": int(max_stn_num)})
|
||||||
|
respond_json = res.json()
|
||||||
|
print(res.json())
|
||||||
|
if res.status_code == 200:
|
||||||
|
global bpt_id
|
||||||
|
try:
|
||||||
|
bpt_id = respond_json["id"]
|
||||||
|
except KeyError:
|
||||||
|
if respond_json["status"] == 401:
|
||||||
|
get_token()
|
||||||
|
return
|
||||||
|
|
||||||
|
# 如果没有批处理任务
|
||||||
|
if bpt_id is None:
|
||||||
|
log.info("No BPT Task For Now")
|
||||||
|
return
|
||||||
|
|
||||||
|
stns = respond_json["stns"]
|
||||||
|
if len(stns) == 0:
|
||||||
|
|
||||||
|
log.info("STNS IS EMPTY, Giving UP")
|
||||||
|
giving_up_bpt()
|
||||||
|
return
|
||||||
|
|
||||||
|
log.info("Get BPT Task: bptId " + str(bpt_id))
|
||||||
|
global stn_list
|
||||||
|
stn_list = stns
|
||||||
|
conn = sqlite3.connect(r".\bptdata.db")
|
||||||
|
# 处理数据
|
||||||
|
cursor = conn.cursor()
|
||||||
|
cursor.execute("DELETE FROM {0}".format(predict_table))
|
||||||
|
|
||||||
|
log.info("Processing Bert Predict Data...")
|
||||||
|
for stn in stns:
|
||||||
|
sql = "INSERT INTO {0} (id, text) values (?, ?)".format(predict_table)
|
||||||
|
cursor.execute(sql, [stn["stnId"], stn["text"]])
|
||||||
|
conn.commit()
|
||||||
|
conn.close()
|
||||||
|
log.info("Finished in Processing Bert Predict Data")
|
||||||
|
|
||||||
|
result = execute_bert_predict()
|
||||||
|
|
||||||
|
if result is True:
|
||||||
|
if processing_bert_result() is True:
|
||||||
|
log.info("BPT Execution Success: bptId " + str(bpt_id))
|
||||||
|
else:
|
||||||
|
log.info("BPT Execution Eventually Failed: bptId " + str(bpt_id))
|
||||||
|
else:
|
||||||
|
log.error("Bert Model Execution Failed")
|
||||||
|
|
||||||
|
log.info("Try Giving Up BPT Task: bptId " + str(bpt_id))
|
||||||
|
giving_up_bpt()
|
||||||
|
|
||||||
|
log.info("Get Status Code: " + str(res.status_code))
|
||||||
|
|
||||||
|
# 清空计算数据
|
||||||
|
bpt_id = None
|
||||||
|
stn_list = []
|
||||||
|
|
||||||
|
elif res.status_code == 400:
|
||||||
|
if respond_json["data"]["exception"] == "org.codedream.epaper.exception.badrequest.AuthExpiredException":
|
||||||
|
print("Auth Expired Exception: Try to Get New CSP")
|
||||||
|
get_csp()
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
print("Unknown Exception")
|
||||||
|
|
||||||
|
elif res.status_code == 401:
|
||||||
|
# 尝试获得新的token
|
||||||
|
log.warning("Unauthorized Status Code: Try to Get New Token")
|
||||||
|
get_token()
|
||||||
|
elif res.status_code == 500:
|
||||||
|
log.warning("Remote Server Error: Inner Server Error")
|
||||||
|
print(res.json())
|
||||||
|
else:
|
||||||
|
# 尝试获得护照
|
||||||
|
get_csp()
|
||||||
|
|
||||||
|
|
||||||
|
# 初始化数据库环境
|
||||||
|
def sqlite_create_table():
|
||||||
|
conn = sqlite3.connect(r".\bptdata.db")
|
||||||
|
cursor = conn.cursor()
|
||||||
|
create_tb_cmd = "CREATE TABLE IF NOT EXISTS {0}" \
|
||||||
|
"(id INT PRIMARY KEY," \
|
||||||
|
"text INT)".format(predict_table)
|
||||||
|
cursor.execute(create_tb_cmd)
|
||||||
|
cursor.execute("DELETE FROM {0}".format(predict_table))
|
||||||
|
conn.commit()
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
|
||||||
|
# 启动BERT神经网络模型
|
||||||
|
def execute_bert_predict():
|
||||||
|
if os.path.exists(result_out_dir):
|
||||||
|
shutil.rmtree(result_out_dir)
|
||||||
|
log.info("BERT Model Executing...")
|
||||||
|
os.system("python run_classifier.py "
|
||||||
|
"--task_name=eppdt "
|
||||||
|
"--do_predict=true "
|
||||||
|
"--data_dir=./tmp "
|
||||||
|
"--vocab_file=./chinese_wwm_ext_L-12_H-768_A-12/vocab.txt "
|
||||||
|
"--bert_config_file=./chinese_wwm_ext_L-12_H-768_A-12/bert_config.json "
|
||||||
|
"--init_checkpoint=./tmp/epout/model.ckpt-14062 "
|
||||||
|
"--max_seq_length=64 "
|
||||||
|
"--output_dir=./tmp/eppredict/ > bert_out.log 2>&1")
|
||||||
|
result_list = os.listdir(result_out_dir)
|
||||||
|
log.info("BERT Model Execution Finished.")
|
||||||
|
if "test_results.tsv" not in result_list:
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
# 处理模型计算结果
|
||||||
|
def processing_bert_result():
|
||||||
|
result = pandas.read_csv(result_out_dir + '/test_results.tsv', sep='\t', header=None)
|
||||||
|
token_headers["timestamp"] = get_timestamp()
|
||||||
|
token_headers["signed"] = calculate_signed()
|
||||||
|
bpt_result_json = []
|
||||||
|
idx = 0
|
||||||
|
|
||||||
|
for i, row in result.iterrows():
|
||||||
|
bpt_result_json.append({"stnid": stn_list[idx]["stnId"], "tagPossible": [row[0], row[1], row[2]]})
|
||||||
|
idx += 1
|
||||||
|
|
||||||
|
log.debug("Bert Result Json")
|
||||||
|
log.debug(bpt_result_json)
|
||||||
|
log.info("Processing BERT Model Result Successful")
|
||||||
|
|
||||||
|
# 尝试3次
|
||||||
|
try_count = 0
|
||||||
|
while try_count < 3:
|
||||||
|
try_count += 1
|
||||||
|
log.info("Uploading BERT Model Result...")
|
||||||
|
res = requests.put(url=url_parser("cs/bpt"),
|
||||||
|
headers=token_headers,
|
||||||
|
params={"idcode": idcode, "bptId": bpt_id, "status": True},
|
||||||
|
json=bpt_result_json)
|
||||||
|
if res.status_code == 201:
|
||||||
|
log.info("Uploading Successful: bertId " + str(bpt_id))
|
||||||
|
return True
|
||||||
|
elif res.status_code == 401:
|
||||||
|
# 尝试获得新的token
|
||||||
|
log.warning("Unauthorized Status Code: Try to Get New Token")
|
||||||
|
get_token()
|
||||||
|
else:
|
||||||
|
if try_count >= 3:
|
||||||
|
log.error("Uploading Eventually Failed: bertId " + str(bpt_id))
|
||||||
|
log.warning("Connection Maybe Unstable")
|
||||||
|
return False
|
||||||
|
log.warning("Failed and Try: count " + str(try_count))
|
||||||
|
|
||||||
|
|
||||||
|
# 签证更新多线程定时器
|
||||||
|
def update_csp_timer():
|
||||||
|
log.info("UPDATE CSP TIMER STARTED")
|
||||||
|
try:
|
||||||
|
update_csp()
|
||||||
|
except:
|
||||||
|
log.error("Exception Thrown, Restarting Timer...")
|
||||||
|
finally:
|
||||||
|
t = threading.Timer(60, update_csp_timer)
|
||||||
|
t.start()
|
||||||
|
|
||||||
|
|
||||||
|
# 批处理任务多线程定时器
|
||||||
|
def get_bpt_timer():
|
||||||
|
log.info("GET BPT TIMER STARTED")
|
||||||
|
try:
|
||||||
|
get_bpt_from_server()
|
||||||
|
except:
|
||||||
|
log.error("Exception Thrown, Restarting Timer...")
|
||||||
|
finally:
|
||||||
|
t = threading.Timer(15, get_bpt_timer)
|
||||||
|
t.start()
|
||||||
|
|
||||||
|
|
||||||
|
# 初始化工作
|
||||||
|
def init():
|
||||||
|
global base_init
|
||||||
|
sqlite_create_table()
|
||||||
|
user_register()
|
||||||
|
get_token()
|
||||||
|
get_csp()
|
||||||
|
base_init = True
|
||||||
|
|
||||||
|
|
||||||
|
# 初始化定时器
|
||||||
|
def init_timer():
|
||||||
|
update_csp_timer()
|
||||||
|
get_bpt_timer()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
try_time = 0
|
||||||
|
while try_time < 3:
|
||||||
|
try:
|
||||||
|
init()
|
||||||
|
try_time = 3
|
||||||
|
except:
|
||||||
|
try_time += 1
|
||||||
|
time.sleep(5)
|
||||||
|
|
||||||
|
init_timer()
|
||||||
|
while True:
|
||||||
|
time.sleep(5)
|
6
tmp/epout/checkpoint
Normal file
6
tmp/epout/checkpoint
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
model_checkpoint_path: "model.ckpt-14062"
|
||||||
|
all_model_checkpoint_paths: "model.ckpt-11000"
|
||||||
|
all_model_checkpoint_paths: "model.ckpt-12000"
|
||||||
|
all_model_checkpoint_paths: "model.ckpt-13000"
|
||||||
|
all_model_checkpoint_paths: "model.ckpt-14000"
|
||||||
|
all_model_checkpoint_paths: "model.ckpt-14062"
|
BIN
tmp/epout/eval.tf_record
Normal file
BIN
tmp/epout/eval.tf_record
Normal file
Binary file not shown.
Binary file not shown.
4
tmp/epout/eval_results.txt
Normal file
4
tmp/epout/eval_results.txt
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
eval_accuracy = 0.98253334
|
||||||
|
eval_loss = 0.06590833
|
||||||
|
global_step = 14062
|
||||||
|
loss = 0.06590833
|
BIN
tmp/epout/events.out.tfevents.1586536204.iZ8vbescrakld4m4drzcktZ
Normal file
BIN
tmp/epout/events.out.tfevents.1586536204.iZ8vbescrakld4m4drzcktZ
Normal file
Binary file not shown.
592992
tmp/epout/graph.pbtxt
Normal file
592992
tmp/epout/graph.pbtxt
Normal file
File diff suppressed because it is too large
Load Diff
BIN
tmp/epout/model.ckpt-14062.index
Normal file
BIN
tmp/epout/model.ckpt-14062.index
Normal file
Binary file not shown.
BIN
tmp/epout/model.ckpt-14062.meta
Normal file
BIN
tmp/epout/model.ckpt-14062.meta
Normal file
Binary file not shown.
BIN
tmp/epout/train.tf_record
Normal file
BIN
tmp/epout/train.tf_record
Normal file
Binary file not shown.
BIN
tmp/eppredict/predict.tf_record
Normal file
BIN
tmp/eppredict/predict.tf_record
Normal file
Binary file not shown.
134
tmp/eppredict/test_results.tsv
Normal file
134
tmp/eppredict/test_results.tsv
Normal file
@ -0,0 +1,134 @@
|
|||||||
|
4.643959e-06 0.99999154 3.8252842e-06
|
||||||
|
4.8007923e-06 0.99999166 3.5236512e-06
|
||||||
|
4.6748496e-06 0.9999913 4.0859113e-06
|
||||||
|
4.4524345e-06 0.99999154 4.0169652e-06
|
||||||
|
4.0334426e-06 0.999992 3.9886727e-06
|
||||||
|
4.3244304e-06 0.9999919 3.781365e-06
|
||||||
|
4.126574e-06 0.999992 3.8155436e-06
|
||||||
|
4.2378488e-06 0.9999919 3.8473977e-06
|
||||||
|
0.00011727525 0.99848527 0.0013974697
|
||||||
|
4.361115e-06 0.99999154 4.017187e-06
|
||||||
|
5.1768693e-06 0.9999908 4.06028e-06
|
||||||
|
4.288634e-06 0.999992 3.709147e-06
|
||||||
|
4.9716205e-06 0.99999106 3.9358397e-06
|
||||||
|
4.182195e-06 0.99999213 3.6933725e-06
|
||||||
|
4.549165e-06 0.9999913 4.1921107e-06
|
||||||
|
6.6755088e-06 0.99998903 4.339319e-06
|
||||||
|
4.595618e-06 0.9999914 3.9184333e-06
|
||||||
|
4.607402e-06 0.99999166 3.752449e-06
|
||||||
|
4.598755e-06 0.9999919 3.4854709e-06
|
||||||
|
4.8619454e-06 0.9999912 3.882037e-06
|
||||||
|
4.1419257e-06 0.9999918 4.000119e-06
|
||||||
|
4.784566e-06 0.99999154 3.654619e-06
|
||||||
|
4.388862e-06 0.9999919 3.6818228e-06
|
||||||
|
5.644322e-06 0.9999908 3.5793255e-06
|
||||||
|
3.823311e-06 0.99999154 4.630431e-06
|
||||||
|
4.244102e-06 0.999992 3.6986664e-06
|
||||||
|
4.2734914e-06 0.999992 3.6995482e-06
|
||||||
|
4.3241253e-06 0.9999919 3.7916818e-06
|
||||||
|
4.4547583e-06 0.9999907 4.839659e-06
|
||||||
|
4.5243414e-06 0.9999918 3.6832657e-06
|
||||||
|
8.419241e-06 0.99998474 6.8440236e-06
|
||||||
|
7.966646e-06 0.9999751 1.6936197e-05
|
||||||
|
5.1216794e-06 0.9999901 4.733681e-06
|
||||||
|
5.034731e-06 0.9999912 3.794812e-06
|
||||||
|
0.0021447523 0.9975885 0.00026675928
|
||||||
|
5.646126e-06 0.9999908 3.6186827e-06
|
||||||
|
1.690044e-05 0.99997103 1.2073235e-05
|
||||||
|
4.9650066e-06 0.9999912 3.8529765e-06
|
||||||
|
4.904027e-06 0.99999094 4.1881613e-06
|
||||||
|
0.9271971 0.008783599 0.064019315
|
||||||
|
4.868973e-06 0.9999918 3.3974088e-06
|
||||||
|
4.9225746e-06 0.9999913 3.7550608e-06
|
||||||
|
3.858802e-06 0.9999924 3.8034468e-06
|
||||||
|
3.494936e-05 0.9999577 7.3692418e-06
|
||||||
|
4.923359e-06 0.9999913 3.849029e-06
|
||||||
|
5.178022e-06 0.9999914 3.4967088e-06
|
||||||
|
4.121945e-06 0.99999213 3.7508728e-06
|
||||||
|
4.451608e-06 0.9999919 3.651948e-06
|
||||||
|
0.00792893 0.98557657 0.0064944625
|
||||||
|
3.5600256e-06 0.9999896 6.819469e-06
|
||||||
|
4.801607e-06 0.99999154 3.753082e-06
|
||||||
|
0.23659243 0.7619511 0.001456472
|
||||||
|
4.4562576e-06 0.99999154 4.027041e-06
|
||||||
|
0.0015988095 0.99809843 0.00030284424
|
||||||
|
9.155851e-06 0.99998546 5.311816e-06
|
||||||
|
6.3670245e-06 0.9999906 3.0769254e-06
|
||||||
|
4.1778785e-06 0.9999914 4.376392e-06
|
||||||
|
4.648281e-06 0.999992 3.3429906e-06
|
||||||
|
5.194813e-06 0.99999106 3.7433254e-06
|
||||||
|
9.076348e-06 0.9999436 4.7352063e-05
|
||||||
|
4.3432983e-06 0.999992 3.7271493e-06
|
||||||
|
4.302407e-06 0.9999902 5.4830507e-06
|
||||||
|
5.4339334e-06 0.9999907 3.8040862e-06
|
||||||
|
4.391311e-06 0.9999918 3.8109024e-06
|
||||||
|
6.945087e-06 0.9999863 6.775152e-06
|
||||||
|
5.1417023e-06 0.9999895 5.3861663e-06
|
||||||
|
0.0011567149 0.9985071 0.00033610596
|
||||||
|
5.3787658e-06 0.9999907 3.8785456e-06
|
||||||
|
1.8892406e-05 0.9999585 2.2630838e-05
|
||||||
|
0.0015190784 0.997843 0.00063789665
|
||||||
|
6.7695796e-06 0.9999875 5.740492e-06
|
||||||
|
5.1006527e-06 0.99999034 4.5368133e-06
|
||||||
|
5.47516e-06 0.99998915 5.3493272e-06
|
||||||
|
4.8562415e-06 0.99999034 4.7732487e-06
|
||||||
|
0.060555745 0.0001441841 0.93930006
|
||||||
|
0.052763145 0.94372463 0.0035121434
|
||||||
|
4.3671207e-06 0.99999166 3.9508695e-06
|
||||||
|
4.778654e-06 0.9999901 5.128561e-06
|
||||||
|
4.7153376e-06 0.9999908 4.362817e-06
|
||||||
|
4.2666793e-06 0.9999914 4.266248e-06
|
||||||
|
3.838838e-06 0.99999225 3.990184e-06
|
||||||
|
4.461001e-06 0.9999912 4.44564e-06
|
||||||
|
4.0283635e-06 0.9999918 4.1213025e-06
|
||||||
|
1.4158776e-05 0.99996805 1.7746825e-05
|
||||||
|
8.674982e-05 0.9998728 4.0382252e-05
|
||||||
|
4.290552e-06 0.9999919 3.838057e-06
|
||||||
|
5.187617e-06 0.9999914 3.374771e-06
|
||||||
|
6.3959133e-06 0.9999894 4.2049714e-06
|
||||||
|
6.8617037e-06 0.9999896 3.4103866e-06
|
||||||
|
4.4409358e-06 0.999992 3.5424264e-06
|
||||||
|
5.4345987e-06 0.9999901 4.447051e-06
|
||||||
|
4.135196e-06 0.99999166 4.2157744e-06
|
||||||
|
4.7487447e-06 0.99999154 3.7494221e-06
|
||||||
|
8.4601015e-06 0.9999864 5.146833e-06
|
||||||
|
0.0010207603 0.9987452 0.00023391622
|
||||||
|
4.5771494e-06 0.9999919 3.5248079e-06
|
||||||
|
4.798046e-06 0.99999166 3.5586033e-06
|
||||||
|
5.8361684e-06 0.99998987 4.2727647e-06
|
||||||
|
5.2285122e-06 0.9999913 3.479859e-06
|
||||||
|
4.372247e-06 0.9999918 3.8096887e-06
|
||||||
|
4.5528377e-06 0.99999154 3.955717e-06
|
||||||
|
4.7401645e-06 0.9999913 3.9730817e-06
|
||||||
|
4.4522612e-06 0.9999918 3.8139272e-06
|
||||||
|
4.5628153e-06 0.99999094 4.564952e-06
|
||||||
|
5.596948e-06 0.99998975 4.607477e-06
|
||||||
|
4.4438884e-06 0.9999919 3.6515894e-06
|
||||||
|
4.682183e-06 0.9999919 3.397442e-06
|
||||||
|
4.7578187e-06 0.9999913 3.9699153e-06
|
||||||
|
6.961098e-06 0.9999745 1.8637946e-05
|
||||||
|
4.2590627e-06 0.9999919 3.834863e-06
|
||||||
|
5.1346065e-06 0.9999894 5.45891e-06
|
||||||
|
4.871587e-06 0.9999896 5.451893e-06
|
||||||
|
4.0097016e-06 0.9999918 4.1663734e-06
|
||||||
|
5.319837e-06 0.99999046 4.156601e-06
|
||||||
|
4.407603e-06 0.9999919 3.69892e-06
|
||||||
|
4.4321364e-06 0.9999914 4.1590397e-06
|
||||||
|
4.812539e-06 0.9999912 3.9999254e-06
|
||||||
|
5.329538e-06 0.9999888 5.881711e-06
|
||||||
|
4.2385377e-06 0.99999225 3.5127584e-06
|
||||||
|
4.7709664e-06 0.99999094 4.2666857e-06
|
||||||
|
9.780118e-06 0.99998 1.0236932e-05
|
||||||
|
8.97482e-06 0.99997663 1.44237465e-05
|
||||||
|
4.4326803e-06 0.9999919 3.6516205e-06
|
||||||
|
4.6600417e-06 0.99999166 3.7069035e-06
|
||||||
|
0.00041292913 0.99843234 0.001154748
|
||||||
|
0.008509361 0.9900877 0.0014030206
|
||||||
|
7.7520845e-06 0.9999875 4.7950984e-06
|
||||||
|
4.8316547e-06 0.99999094 4.2039433e-06
|
||||||
|
4.4381522e-06 0.99999166 3.9086926e-06
|
||||||
|
5.5704777e-06 0.9999908 3.5597884e-06
|
||||||
|
4.117504e-06 0.9999918 4.029943e-06
|
||||||
|
5.205461e-06 0.9999893 5.5215824e-06
|
||||||
|
4.6852315e-06 0.9999914 3.9398033e-06
|
||||||
|
4.80286e-06 0.9999913 3.89835e-06
|
|
399
tokenization.py
Normal file
399
tokenization.py
Normal file
@ -0,0 +1,399 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2018 The Google AI Language Team Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Tokenization classes."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import collections
|
||||||
|
import re
|
||||||
|
import unicodedata
|
||||||
|
import six
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
|
||||||
|
def validate_case_matches_checkpoint(do_lower_case, init_checkpoint):
|
||||||
|
"""Checks whether the casing config is consistent with the checkpoint name."""
|
||||||
|
|
||||||
|
# The casing has to be passed in by the user and there is no explicit check
|
||||||
|
# as to whether it matches the checkpoint. The casing information probably
|
||||||
|
# should have been stored in the bert_config.json file, but it's not, so
|
||||||
|
# we have to heuristically detect it to validate.
|
||||||
|
|
||||||
|
if not init_checkpoint:
|
||||||
|
return
|
||||||
|
|
||||||
|
m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint)
|
||||||
|
if m is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
model_name = m.group(1)
|
||||||
|
|
||||||
|
lower_models = [
|
||||||
|
"uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12",
|
||||||
|
"multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12"
|
||||||
|
]
|
||||||
|
|
||||||
|
cased_models = [
|
||||||
|
"cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16",
|
||||||
|
"multi_cased_L-12_H-768_A-12"
|
||||||
|
]
|
||||||
|
|
||||||
|
is_bad_config = False
|
||||||
|
if model_name in lower_models and not do_lower_case:
|
||||||
|
is_bad_config = True
|
||||||
|
actual_flag = "False"
|
||||||
|
case_name = "lowercased"
|
||||||
|
opposite_flag = "True"
|
||||||
|
|
||||||
|
if model_name in cased_models and do_lower_case:
|
||||||
|
is_bad_config = True
|
||||||
|
actual_flag = "True"
|
||||||
|
case_name = "cased"
|
||||||
|
opposite_flag = "False"
|
||||||
|
|
||||||
|
if is_bad_config:
|
||||||
|
raise ValueError(
|
||||||
|
"You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. "
|
||||||
|
"However, `%s` seems to be a %s model, so you "
|
||||||
|
"should pass in `--do_lower_case=%s` so that the fine-tuning matches "
|
||||||
|
"how the model was pre-training. If this error is wrong, please "
|
||||||
|
"just comment out this check." % (actual_flag, init_checkpoint,
|
||||||
|
model_name, case_name, opposite_flag))
|
||||||
|
|
||||||
|
|
||||||
|
def convert_to_unicode(text):
|
||||||
|
"""Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
|
||||||
|
if six.PY3:
|
||||||
|
if isinstance(text, str):
|
||||||
|
return text
|
||||||
|
elif isinstance(text, bytes):
|
||||||
|
return text.decode("utf-8", "ignore")
|
||||||
|
else:
|
||||||
|
raise ValueError("Unsupported string type: %s" % (type(text)))
|
||||||
|
elif six.PY2:
|
||||||
|
if isinstance(text, str):
|
||||||
|
return text.decode("utf-8", "ignore")
|
||||||
|
elif isinstance(text, unicode):
|
||||||
|
return text
|
||||||
|
else:
|
||||||
|
raise ValueError("Unsupported string type: %s" % (type(text)))
|
||||||
|
else:
|
||||||
|
raise ValueError("Not running on Python2 or Python 3?")
|
||||||
|
|
||||||
|
|
||||||
|
def printable_text(text):
|
||||||
|
"""Returns text encoded in a way suitable for print or `tf.logging`."""
|
||||||
|
|
||||||
|
# These functions want `str` for both Python2 and Python3, but in one case
|
||||||
|
# it's a Unicode string and in the other it's a byte string.
|
||||||
|
if six.PY3:
|
||||||
|
if isinstance(text, str):
|
||||||
|
return text
|
||||||
|
elif isinstance(text, bytes):
|
||||||
|
return text.decode("utf-8", "ignore")
|
||||||
|
else:
|
||||||
|
raise ValueError("Unsupported string type: %s" % (type(text)))
|
||||||
|
elif six.PY2:
|
||||||
|
if isinstance(text, str):
|
||||||
|
return text
|
||||||
|
elif isinstance(text, unicode):
|
||||||
|
return text.encode("utf-8")
|
||||||
|
else:
|
||||||
|
raise ValueError("Unsupported string type: %s" % (type(text)))
|
||||||
|
else:
|
||||||
|
raise ValueError("Not running on Python2 or Python 3?")
|
||||||
|
|
||||||
|
|
||||||
|
def load_vocab(vocab_file):
|
||||||
|
"""Loads a vocabulary file into a dictionary."""
|
||||||
|
vocab = collections.OrderedDict()
|
||||||
|
index = 0
|
||||||
|
with tf.gfile.GFile(vocab_file, "r") as reader:
|
||||||
|
while True:
|
||||||
|
token = convert_to_unicode(reader.readline())
|
||||||
|
if not token:
|
||||||
|
break
|
||||||
|
token = token.strip()
|
||||||
|
vocab[token] = index
|
||||||
|
index += 1
|
||||||
|
return vocab
|
||||||
|
|
||||||
|
|
||||||
|
def convert_by_vocab(vocab, items):
|
||||||
|
"""Converts a sequence of [tokens|ids] using the vocab."""
|
||||||
|
output = []
|
||||||
|
for item in items:
|
||||||
|
output.append(vocab[item])
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
def convert_tokens_to_ids(vocab, tokens):
|
||||||
|
return convert_by_vocab(vocab, tokens)
|
||||||
|
|
||||||
|
|
||||||
|
def convert_ids_to_tokens(inv_vocab, ids):
|
||||||
|
return convert_by_vocab(inv_vocab, ids)
|
||||||
|
|
||||||
|
|
||||||
|
def whitespace_tokenize(text):
|
||||||
|
"""Runs basic whitespace cleaning and splitting on a piece of text."""
|
||||||
|
text = text.strip()
|
||||||
|
if not text:
|
||||||
|
return []
|
||||||
|
tokens = text.split()
|
||||||
|
return tokens
|
||||||
|
|
||||||
|
|
||||||
|
class FullTokenizer(object):
|
||||||
|
"""Runs end-to-end tokenziation."""
|
||||||
|
|
||||||
|
def __init__(self, vocab_file, do_lower_case=True):
|
||||||
|
self.vocab = load_vocab(vocab_file)
|
||||||
|
self.inv_vocab = {v: k for k, v in self.vocab.items()}
|
||||||
|
self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
|
||||||
|
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
|
||||||
|
|
||||||
|
def tokenize(self, text):
|
||||||
|
split_tokens = []
|
||||||
|
for token in self.basic_tokenizer.tokenize(text):
|
||||||
|
for sub_token in self.wordpiece_tokenizer.tokenize(token):
|
||||||
|
split_tokens.append(sub_token)
|
||||||
|
|
||||||
|
return split_tokens
|
||||||
|
|
||||||
|
def convert_tokens_to_ids(self, tokens):
|
||||||
|
return convert_by_vocab(self.vocab, tokens)
|
||||||
|
|
||||||
|
def convert_ids_to_tokens(self, ids):
|
||||||
|
return convert_by_vocab(self.inv_vocab, ids)
|
||||||
|
|
||||||
|
|
||||||
|
class BasicTokenizer(object):
|
||||||
|
"""Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
|
||||||
|
|
||||||
|
def __init__(self, do_lower_case=True):
|
||||||
|
"""Constructs a BasicTokenizer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
do_lower_case: Whether to lower case the input.
|
||||||
|
"""
|
||||||
|
self.do_lower_case = do_lower_case
|
||||||
|
|
||||||
|
def tokenize(self, text):
|
||||||
|
"""Tokenizes a piece of text."""
|
||||||
|
text = convert_to_unicode(text)
|
||||||
|
text = self._clean_text(text)
|
||||||
|
|
||||||
|
# This was added on November 1st, 2018 for the multilingual and Chinese
|
||||||
|
# models. This is also applied to the English models now, but it doesn't
|
||||||
|
# matter since the English models were not trained on any Chinese data
|
||||||
|
# and generally don't have any Chinese data in them (there are Chinese
|
||||||
|
# characters in the vocabulary because Wikipedia does have some Chinese
|
||||||
|
# words in the English Wikipedia.).
|
||||||
|
text = self._tokenize_chinese_chars(text)
|
||||||
|
|
||||||
|
orig_tokens = whitespace_tokenize(text)
|
||||||
|
split_tokens = []
|
||||||
|
for token in orig_tokens:
|
||||||
|
if self.do_lower_case:
|
||||||
|
token = token.lower()
|
||||||
|
token = self._run_strip_accents(token)
|
||||||
|
split_tokens.extend(self._run_split_on_punc(token))
|
||||||
|
|
||||||
|
output_tokens = whitespace_tokenize(" ".join(split_tokens))
|
||||||
|
return output_tokens
|
||||||
|
|
||||||
|
def _run_strip_accents(self, text):
|
||||||
|
"""Strips accents from a piece of text."""
|
||||||
|
text = unicodedata.normalize("NFD", text)
|
||||||
|
output = []
|
||||||
|
for char in text:
|
||||||
|
cat = unicodedata.category(char)
|
||||||
|
if cat == "Mn":
|
||||||
|
continue
|
||||||
|
output.append(char)
|
||||||
|
return "".join(output)
|
||||||
|
|
||||||
|
def _run_split_on_punc(self, text):
|
||||||
|
"""Splits punctuation on a piece of text."""
|
||||||
|
chars = list(text)
|
||||||
|
i = 0
|
||||||
|
start_new_word = True
|
||||||
|
output = []
|
||||||
|
while i < len(chars):
|
||||||
|
char = chars[i]
|
||||||
|
if _is_punctuation(char):
|
||||||
|
output.append([char])
|
||||||
|
start_new_word = True
|
||||||
|
else:
|
||||||
|
if start_new_word:
|
||||||
|
output.append([])
|
||||||
|
start_new_word = False
|
||||||
|
output[-1].append(char)
|
||||||
|
i += 1
|
||||||
|
|
||||||
|
return ["".join(x) for x in output]
|
||||||
|
|
||||||
|
def _tokenize_chinese_chars(self, text):
|
||||||
|
"""Adds whitespace around any CJK character."""
|
||||||
|
output = []
|
||||||
|
for char in text:
|
||||||
|
cp = ord(char)
|
||||||
|
if self._is_chinese_char(cp):
|
||||||
|
output.append(" ")
|
||||||
|
output.append(char)
|
||||||
|
output.append(" ")
|
||||||
|
else:
|
||||||
|
output.append(char)
|
||||||
|
return "".join(output)
|
||||||
|
|
||||||
|
def _is_chinese_char(self, cp):
|
||||||
|
"""Checks whether CP is the codepoint of a CJK character."""
|
||||||
|
# This defines a "chinese character" as anything in the CJK Unicode block:
|
||||||
|
# https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
|
||||||
|
#
|
||||||
|
# Note that the CJK Unicode block is NOT all Japanese and Korean characters,
|
||||||
|
# despite its name. The modern Korean Hangul alphabet is a different block,
|
||||||
|
# as is Japanese Hiragana and Katakana. Those alphabets are used to write
|
||||||
|
# space-separated words, so they are not treated specially and handled
|
||||||
|
# like the all of the other languages.
|
||||||
|
if ((cp >= 0x4E00 and cp <= 0x9FFF) or #
|
||||||
|
(cp >= 0x3400 and cp <= 0x4DBF) or #
|
||||||
|
(cp >= 0x20000 and cp <= 0x2A6DF) or #
|
||||||
|
(cp >= 0x2A700 and cp <= 0x2B73F) or #
|
||||||
|
(cp >= 0x2B740 and cp <= 0x2B81F) or #
|
||||||
|
(cp >= 0x2B820 and cp <= 0x2CEAF) or
|
||||||
|
(cp >= 0xF900 and cp <= 0xFAFF) or #
|
||||||
|
(cp >= 0x2F800 and cp <= 0x2FA1F)): #
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _clean_text(self, text):
|
||||||
|
"""Performs invalid character removal and whitespace cleanup on text."""
|
||||||
|
output = []
|
||||||
|
for char in text:
|
||||||
|
cp = ord(char)
|
||||||
|
if cp == 0 or cp == 0xfffd or _is_control(char):
|
||||||
|
continue
|
||||||
|
if _is_whitespace(char):
|
||||||
|
output.append(" ")
|
||||||
|
else:
|
||||||
|
output.append(char)
|
||||||
|
return "".join(output)
|
||||||
|
|
||||||
|
|
||||||
|
class WordpieceTokenizer(object):
|
||||||
|
"""Runs WordPiece tokenziation."""
|
||||||
|
|
||||||
|
def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200):
|
||||||
|
self.vocab = vocab
|
||||||
|
self.unk_token = unk_token
|
||||||
|
self.max_input_chars_per_word = max_input_chars_per_word
|
||||||
|
|
||||||
|
def tokenize(self, text):
|
||||||
|
"""Tokenizes a piece of text into its word pieces.
|
||||||
|
|
||||||
|
This uses a greedy longest-match-first algorithm to perform tokenization
|
||||||
|
using the given vocabulary.
|
||||||
|
|
||||||
|
For example:
|
||||||
|
input = "unaffable"
|
||||||
|
output = ["un", "##aff", "##able"]
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: A single token or whitespace separated tokens. This should have
|
||||||
|
already been passed through `BasicTokenizer.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of wordpiece tokens.
|
||||||
|
"""
|
||||||
|
|
||||||
|
text = convert_to_unicode(text)
|
||||||
|
|
||||||
|
output_tokens = []
|
||||||
|
for token in whitespace_tokenize(text):
|
||||||
|
chars = list(token)
|
||||||
|
if len(chars) > self.max_input_chars_per_word:
|
||||||
|
output_tokens.append(self.unk_token)
|
||||||
|
continue
|
||||||
|
|
||||||
|
is_bad = False
|
||||||
|
start = 0
|
||||||
|
sub_tokens = []
|
||||||
|
while start < len(chars):
|
||||||
|
end = len(chars)
|
||||||
|
cur_substr = None
|
||||||
|
while start < end:
|
||||||
|
substr = "".join(chars[start:end])
|
||||||
|
if start > 0:
|
||||||
|
substr = "##" + substr
|
||||||
|
if substr in self.vocab:
|
||||||
|
cur_substr = substr
|
||||||
|
break
|
||||||
|
end -= 1
|
||||||
|
if cur_substr is None:
|
||||||
|
is_bad = True
|
||||||
|
break
|
||||||
|
sub_tokens.append(cur_substr)
|
||||||
|
start = end
|
||||||
|
|
||||||
|
if is_bad:
|
||||||
|
output_tokens.append(self.unk_token)
|
||||||
|
else:
|
||||||
|
output_tokens.extend(sub_tokens)
|
||||||
|
return output_tokens
|
||||||
|
|
||||||
|
|
||||||
|
def _is_whitespace(char):
|
||||||
|
"""Checks whether `chars` is a whitespace character."""
|
||||||
|
# \t, \n, and \r are technically contorl characters but we treat them
|
||||||
|
# as whitespace since they are generally considered as such.
|
||||||
|
if char == " " or char == "\t" or char == "\n" or char == "\r":
|
||||||
|
return True
|
||||||
|
cat = unicodedata.category(char)
|
||||||
|
if cat == "Zs":
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _is_control(char):
|
||||||
|
"""Checks whether `chars` is a control character."""
|
||||||
|
# These are technically control characters but we count them as whitespace
|
||||||
|
# characters.
|
||||||
|
if char == "\t" or char == "\n" or char == "\r":
|
||||||
|
return False
|
||||||
|
cat = unicodedata.category(char)
|
||||||
|
if cat in ("Cc", "Cf"):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _is_punctuation(char):
|
||||||
|
"""Checks whether `chars` is a punctuation character."""
|
||||||
|
cp = ord(char)
|
||||||
|
# We treat all non-letter/number ASCII as punctuation.
|
||||||
|
# Characters such as "^", "$", and "`" are not in the Unicode
|
||||||
|
# Punctuation class but we treat them as punctuation anyways, for
|
||||||
|
# consistency.
|
||||||
|
if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
|
||||||
|
(cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
|
||||||
|
return True
|
||||||
|
cat = unicodedata.category(char)
|
||||||
|
if cat.startswith("P"):
|
||||||
|
return True
|
||||||
|
return False
|
137
tokenization_test.py
Normal file
137
tokenization_test.py
Normal file
@ -0,0 +1,137 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2018 The Google AI Language Team Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
import tokenization
|
||||||
|
import six
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
|
||||||
|
class TokenizationTest(tf.test.TestCase):
|
||||||
|
|
||||||
|
def test_full_tokenizer(self):
|
||||||
|
vocab_tokens = [
|
||||||
|
"[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
|
||||||
|
"##ing", ","
|
||||||
|
]
|
||||||
|
with tempfile.NamedTemporaryFile(delete=False) as vocab_writer:
|
||||||
|
if six.PY2:
|
||||||
|
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
|
||||||
|
else:
|
||||||
|
vocab_writer.write("".join(
|
||||||
|
[x + "\n" for x in vocab_tokens]).encode("utf-8"))
|
||||||
|
|
||||||
|
vocab_file = vocab_writer.name
|
||||||
|
|
||||||
|
tokenizer = tokenization.FullTokenizer(vocab_file)
|
||||||
|
os.unlink(vocab_file)
|
||||||
|
|
||||||
|
tokens = tokenizer.tokenize(u"UNwant\u00E9d,running")
|
||||||
|
self.assertAllEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])
|
||||||
|
|
||||||
|
self.assertAllEqual(
|
||||||
|
tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])
|
||||||
|
|
||||||
|
def test_chinese(self):
|
||||||
|
tokenizer = tokenization.BasicTokenizer()
|
||||||
|
|
||||||
|
self.assertAllEqual(
|
||||||
|
tokenizer.tokenize(u"ah\u535A\u63A8zz"),
|
||||||
|
[u"ah", u"\u535A", u"\u63A8", u"zz"])
|
||||||
|
|
||||||
|
def test_basic_tokenizer_lower(self):
|
||||||
|
tokenizer = tokenization.BasicTokenizer(do_lower_case=True)
|
||||||
|
|
||||||
|
self.assertAllEqual(
|
||||||
|
tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "),
|
||||||
|
["hello", "!", "how", "are", "you", "?"])
|
||||||
|
self.assertAllEqual(tokenizer.tokenize(u"H\u00E9llo"), ["hello"])
|
||||||
|
|
||||||
|
def test_basic_tokenizer_no_lower(self):
|
||||||
|
tokenizer = tokenization.BasicTokenizer(do_lower_case=False)
|
||||||
|
|
||||||
|
self.assertAllEqual(
|
||||||
|
tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "),
|
||||||
|
["HeLLo", "!", "how", "Are", "yoU", "?"])
|
||||||
|
|
||||||
|
def test_wordpiece_tokenizer(self):
|
||||||
|
vocab_tokens = [
|
||||||
|
"[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
|
||||||
|
"##ing"
|
||||||
|
]
|
||||||
|
|
||||||
|
vocab = {}
|
||||||
|
for (i, token) in enumerate(vocab_tokens):
|
||||||
|
vocab[token] = i
|
||||||
|
tokenizer = tokenization.WordpieceTokenizer(vocab=vocab)
|
||||||
|
|
||||||
|
self.assertAllEqual(tokenizer.tokenize(""), [])
|
||||||
|
|
||||||
|
self.assertAllEqual(
|
||||||
|
tokenizer.tokenize("unwanted running"),
|
||||||
|
["un", "##want", "##ed", "runn", "##ing"])
|
||||||
|
|
||||||
|
self.assertAllEqual(
|
||||||
|
tokenizer.tokenize("unwantedX running"), ["[UNK]", "runn", "##ing"])
|
||||||
|
|
||||||
|
def test_convert_tokens_to_ids(self):
|
||||||
|
vocab_tokens = [
|
||||||
|
"[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
|
||||||
|
"##ing"
|
||||||
|
]
|
||||||
|
|
||||||
|
vocab = {}
|
||||||
|
for (i, token) in enumerate(vocab_tokens):
|
||||||
|
vocab[token] = i
|
||||||
|
|
||||||
|
self.assertAllEqual(
|
||||||
|
tokenization.convert_tokens_to_ids(
|
||||||
|
vocab, ["un", "##want", "##ed", "runn", "##ing"]), [7, 4, 5, 8, 9])
|
||||||
|
|
||||||
|
def test_is_whitespace(self):
|
||||||
|
self.assertTrue(tokenization._is_whitespace(u" "))
|
||||||
|
self.assertTrue(tokenization._is_whitespace(u"\t"))
|
||||||
|
self.assertTrue(tokenization._is_whitespace(u"\r"))
|
||||||
|
self.assertTrue(tokenization._is_whitespace(u"\n"))
|
||||||
|
self.assertTrue(tokenization._is_whitespace(u"\u00A0"))
|
||||||
|
|
||||||
|
self.assertFalse(tokenization._is_whitespace(u"A"))
|
||||||
|
self.assertFalse(tokenization._is_whitespace(u"-"))
|
||||||
|
|
||||||
|
def test_is_control(self):
|
||||||
|
self.assertTrue(tokenization._is_control(u"\u0005"))
|
||||||
|
|
||||||
|
self.assertFalse(tokenization._is_control(u"A"))
|
||||||
|
self.assertFalse(tokenization._is_control(u" "))
|
||||||
|
self.assertFalse(tokenization._is_control(u"\t"))
|
||||||
|
self.assertFalse(tokenization._is_control(u"\r"))
|
||||||
|
self.assertFalse(tokenization._is_control(u"\U0001F4A9"))
|
||||||
|
|
||||||
|
def test_is_punctuation(self):
|
||||||
|
self.assertTrue(tokenization._is_punctuation(u"-"))
|
||||||
|
self.assertTrue(tokenization._is_punctuation(u"$"))
|
||||||
|
self.assertTrue(tokenization._is_punctuation(u"`"))
|
||||||
|
self.assertTrue(tokenization._is_punctuation(u"."))
|
||||||
|
|
||||||
|
self.assertFalse(tokenization._is_punctuation(u"A"))
|
||||||
|
self.assertFalse(tokenization._is_punctuation(u" "))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
tf.test.main()
|
Loading…
Reference in New Issue
Block a user