forked from apache/mxnet
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add BERT QA Scala/Java example (apache#14592)
* add BertQA major code piece * add scripts and bug fixes * add integration test * address comments * address doc comments
- Loading branch information
1 parent
2c5ee91
commit ee5e17e
Showing
8 changed files
with
545 additions
and
0 deletions.
There are no files selected for viewing
34 changes: 34 additions & 0 deletions
34
scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Layout.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
package org.apache.mxnet.javaapi | ||
|
||
/** | ||
* Layout definition of DataDesc | ||
* N Batch size | ||
* C channels | ||
* H Height | ||
* W Weight | ||
* T sequence length | ||
* __undefined__ default value of Layout | ||
*/ | ||
object Layout { | ||
val UNDEFINED: String = org.apache.mxnet.Layout.UNDEFINED | ||
val NCHW: String = org.apache.mxnet.Layout.NCHW | ||
val NTC: String = org.apache.mxnet.Layout.NTC | ||
val NT: String = org.apache.mxnet.Layout.NT | ||
val N: String = org.apache.mxnet.Layout.N | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
31 changes: 31 additions & 0 deletions
31
scala-package/examples/scripts/infer/bert/get_bert_data.sh
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
#!/bin/bash | ||
|
||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
|
||
set -e | ||
|
||
MXNET_ROOT=$(cd "$(dirname $0)/../../.."; pwd) | ||
|
||
data_path=$MXNET_ROOT/scripts/infer/models/static-bert-qa/ | ||
|
||
if [ ! -d "$data_path" ]; then | ||
mkdir -p "$data_path" | ||
curl https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/BertQA/vocab.json -o $data_path/vocab.json | ||
curl https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/BertQA/static_bert_qa-0002.params -o $data_path/static_bert_qa-0002.params | ||
curl https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/BertQA/static_bert_qa-symbol.json -o $data_path/static_bert_qa-symbol.json | ||
fi |
27 changes: 27 additions & 0 deletions
27
scala-package/examples/scripts/infer/bert/run_bert_qa_example.sh
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
#!/bin/bash | ||
|
||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
|
||
set -e | ||
|
||
MXNET_ROOT=$(cd "$(dirname $0)/../../../../.."; pwd) | ||
|
||
CLASS_PATH=$MXNET_ROOT/scala-package/assembly/target/*:$MXNET_ROOT/scala-package/examples/target/*:$MXNET_ROOT/scala-package/examples/target/classes/lib/* | ||
|
||
java -Xmx8G -Dmxnet.traceLeakedObjects=true -cp $CLASS_PATH \ | ||
org.apache.mxnetexamples.javaapi.infer.bert.BertQA $@ |
126 changes: 126 additions & 0 deletions
126
...ge/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/BertDataParser.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,126 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.mxnetexamples.javaapi.infer.bert; | ||
|
||
import java.io.FileReader; | ||
import java.util.*; | ||
|
||
import com.google.gson.Gson; | ||
import com.google.gson.JsonArray; | ||
import com.google.gson.JsonElement; | ||
import com.google.gson.JsonObject; | ||
|
||
/** | ||
* This is the Utility for pre-processing the data for Bert Model | ||
* You can use this utility to parse Vocabulary JSON into Java Array and Dictionary, | ||
* clean and tokenize sentences and pad the text | ||
*/ | ||
public class BertDataParser { | ||
|
||
private Map<String, Integer> token2idx; | ||
private List<String> idx2token; | ||
|
||
/** | ||
* Parse the Vocabulary to JSON files | ||
* [PAD], [CLS], [SEP], [MASK], [UNK] are reserved tokens | ||
* @param jsonFile the filePath of the vocab.json | ||
* @throws Exception | ||
*/ | ||
void parseJSON(String jsonFile) throws Exception { | ||
Gson gson = new Gson(); | ||
token2idx = new HashMap<>(); | ||
idx2token = new LinkedList<>(); | ||
JsonObject jsonObject = gson.fromJson(new FileReader(jsonFile), JsonObject.class); | ||
JsonArray arr = jsonObject.getAsJsonArray("idx_to_token"); | ||
for (JsonElement element : arr) { | ||
idx2token.add(element.getAsString()); | ||
} | ||
JsonObject preMap = jsonObject.getAsJsonObject("token_to_idx"); | ||
for (String key : preMap.keySet()) { | ||
token2idx.put(key, preMap.get(key).getAsInt()); | ||
} | ||
} | ||
|
||
/** | ||
* Tokenize the input, split all kinds of whitespace and | ||
* Separate the end of sentence symbol: . , ? ! | ||
* @param input The input string | ||
* @return List of tokens | ||
*/ | ||
List<String> tokenizer(String input) { | ||
String[] step1 = input.split("\\s+"); | ||
List<String> finalResult = new LinkedList<>(); | ||
for (String item : step1) { | ||
if (item.length() != 0) { | ||
if ((item + "a").split("[.,?!]+").length > 1) { | ||
finalResult.add(item.substring(0, item.length() - 1)); | ||
finalResult.add(item.substring(item.length() -1)); | ||
} else { | ||
finalResult.add(item); | ||
} | ||
} | ||
} | ||
return finalResult; | ||
} | ||
|
||
/** | ||
* Pad the tokens to the required length | ||
* @param tokens input tokens | ||
* @param padItem things to pad at the end | ||
* @param num total length after padding | ||
* @return List of padded tokens | ||
*/ | ||
<E> List<E> pad(List<E> tokens, E padItem, int num) { | ||
if (tokens.size() >= num) return tokens; | ||
List<E> padded = new LinkedList<>(tokens); | ||
for (int i = 0; i < num - tokens.size(); i++) { | ||
padded.add(padItem); | ||
} | ||
return padded; | ||
} | ||
|
||
/** | ||
* Convert tokens to indexes | ||
* @param tokens input tokens | ||
* @return List of indexes | ||
*/ | ||
List<Integer> token2idx(List<String> tokens) { | ||
List<Integer> indexes = new ArrayList<>(); | ||
for (String token : tokens) { | ||
if (token2idx.containsKey(token)) { | ||
indexes.add(token2idx.get(token)); | ||
} else { | ||
indexes.add(token2idx.get("[UNK]")); | ||
} | ||
} | ||
return indexes; | ||
} | ||
|
||
/** | ||
* Convert indexes to tokens | ||
* @param indexes List of indexes | ||
* @return List of tokens | ||
*/ | ||
List<String> idx2token(List<Integer> indexes) { | ||
List<String> tokens = new ArrayList<>(); | ||
for (int index : indexes) { | ||
tokens.add(idx2token.get(index)); | ||
} | ||
return tokens; | ||
} | ||
} |
148 changes: 148 additions & 0 deletions
148
scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/BertQA.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,148 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.mxnetexamples.javaapi.infer.bert; | ||
|
||
import org.apache.mxnet.infer.javaapi.Predictor; | ||
import org.apache.mxnet.javaapi.*; | ||
import org.kohsuke.args4j.CmdLineParser; | ||
import org.kohsuke.args4j.Option; | ||
import org.slf4j.Logger; | ||
import org.slf4j.LoggerFactory; | ||
|
||
import java.util.*; | ||
|
||
/** | ||
* This is an example of using BERT to do the general Question and Answer inference jobs | ||
* Users can provide a question with a paragraph contains answer to the model and | ||
* the model will be able to find the best answer from the answer paragraph | ||
*/ | ||
public class BertQA { | ||
@Option(name = "--model-path-prefix", usage = "input model directory and prefix of the model") | ||
private String modelPathPrefix = "/model/static_bert_qa"; | ||
@Option(name = "--model-epoch", usage = "Epoch number of the model") | ||
private int epoch = 2; | ||
@Option(name = "--model-vocab", usage = "the vocabulary used in the model") | ||
private String modelVocab = "/model/vocab.json"; | ||
@Option(name = "--input-question", usage = "the input question") | ||
private String inputQ = "When did BBC Japan start broadcasting?"; | ||
@Option(name = "--input-answer", usage = "the input answer") | ||
private String inputA = | ||
"BBC Japan was a general entertainment Channel.\n" + | ||
" Which operated between December 2004 and April 2006.\n" + | ||
"It ceased operations after its Japanese distributor folded."; | ||
@Option(name = "--seq-length", usage = "the maximum length of the sequence") | ||
private int seqLength = 384; | ||
|
||
private final static Logger logger = LoggerFactory.getLogger(BertQA.class); | ||
private static NDArray$ NDArray = NDArray$.MODULE$; | ||
|
||
private static int argmax(float[] prob) { | ||
int maxIdx = 0; | ||
for (int i = 0; i < prob.length; i++) { | ||
if (prob[maxIdx] < prob[i]) maxIdx = i; | ||
} | ||
return maxIdx; | ||
} | ||
|
||
/** | ||
* Do the post processing on the output, apply softmax to get the probabilities | ||
* reshape and get the most probable index | ||
* @param result prediction result | ||
* @param tokens word tokens | ||
* @return Answers clipped from the original paragraph | ||
*/ | ||
static List<String> postProcessing(NDArray result, List<String> tokens) { | ||
NDArray[] output = NDArray.split( | ||
NDArray.new splitParam(result, 2).setAxis(2)); | ||
// Get the formatted logits result | ||
NDArray startLogits = output[0].reshape(new int[]{0, -3}); | ||
NDArray endLogits = output[1].reshape(new int[]{0, -3}); | ||
// Get Probability distribution | ||
float[] startProb = NDArray.softmax( | ||
NDArray.new softmaxParam(startLogits))[0].toArray(); | ||
float[] endProb = NDArray.softmax( | ||
NDArray.new softmaxParam(endLogits))[0].toArray(); | ||
int startIdx = argmax(startProb); | ||
int endIdx = argmax(endProb); | ||
return tokens.subList(startIdx, endIdx + 1); | ||
} | ||
|
||
public static void main(String[] args) throws Exception{ | ||
BertQA inst = new BertQA(); | ||
CmdLineParser parser = new CmdLineParser(inst); | ||
parser.parseArgument(args); | ||
BertDataParser util = new BertDataParser(); | ||
Context context = Context.cpu(); | ||
if (System.getenv().containsKey("SCALA_TEST_ON_GPU") && | ||
Integer.valueOf(System.getenv("SCALA_TEST_ON_GPU")) == 1) { | ||
context = Context.gpu(); | ||
} | ||
// pre-processing - tokenize sentence | ||
List<String> tokenQ = util.tokenizer(inst.inputQ.toLowerCase()); | ||
List<String> tokenA = util.tokenizer(inst.inputA.toLowerCase()); | ||
int validLength = tokenQ.size() + tokenA.size(); | ||
logger.info("Valid length: " + validLength); | ||
// generate token types [0000...1111....0000] | ||
List<Float> QAEmbedded = new ArrayList<>(); | ||
util.pad(QAEmbedded, 0f, tokenQ.size()).addAll( | ||
util.pad(new ArrayList<Float>(), 1f, tokenA.size()) | ||
); | ||
List<Float> tokenTypes = util.pad(QAEmbedded, 0f, inst.seqLength); | ||
// make BERT pre-processing standard | ||
tokenQ.add("[SEP]"); | ||
tokenQ.add(0, "[CLS]"); | ||
tokenA.add("[SEP]"); | ||
tokenQ.addAll(tokenA); | ||
List<String> tokens = util.pad(tokenQ, "[PAD]", inst.seqLength); | ||
logger.info("Pre-processed tokens: " + Arrays.toString(tokenQ.toArray())); | ||
// pre-processing - token to index translation | ||
util.parseJSON(inst.modelVocab); | ||
List<Integer> indexes = util.token2idx(tokens); | ||
List<Float> indexesFloat = new ArrayList<>(); | ||
for (int integer : indexes) { | ||
indexesFloat.add((float) integer); | ||
} | ||
// Preparing the input data | ||
List<NDArray> inputBatch = Arrays.asList( | ||
new NDArray(indexesFloat, | ||
new Shape(new int[]{1, inst.seqLength}), context), | ||
new NDArray(tokenTypes, | ||
new Shape(new int[]{1, inst.seqLength}), context), | ||
new NDArray(new float[] { validLength }, | ||
new Shape(new int[]{1}), context) | ||
); | ||
// Build the model | ||
List<Context> contexts = new ArrayList<>(); | ||
contexts.add(context); | ||
List<DataDesc> inputDescs = Arrays.asList( | ||
new DataDesc("data0", | ||
new Shape(new int[]{1, inst.seqLength}), DType.Float32(), Layout.NT()), | ||
new DataDesc("data1", | ||
new Shape(new int[]{1, inst.seqLength}), DType.Float32(), Layout.NT()), | ||
new DataDesc("data2", | ||
new Shape(new int[]{1}), DType.Float32(), Layout.N()) | ||
); | ||
Predictor bertQA = new Predictor(inst.modelPathPrefix, inputDescs, contexts, inst.epoch); | ||
// Start prediction | ||
NDArray result = bertQA.predictWithNDArray(inputBatch).get(0); | ||
List<String> answer = postProcessing(result, tokens); | ||
logger.info("Question: " + inst.inputQ); | ||
logger.info("Answer paragraph: " + inst.inputA); | ||
logger.info("Answer: " + Arrays.toString(answer.toArray())); | ||
} | ||
} |
Oops, something went wrong.