Weka on Android: load precomputed model and predict new samples

Imagine you want to use machine learning on Android to predict the target variable for new samples. Mobile devices usually have neither the storage capacity needed for the training data nor the computational power needed to train the models, at least when you work with a medium or large dataset and thoroughly evaluate different model types and parameter settings before selecting a model. In such a scenario, the way to go is:

  1. Train, evaluate, and select the optimal model offline, e.g. on a desktop PC or server hardware.
  2. Export the trained model (the “model file”, which tends to be relatively small) and ship it with your Android application.
  3. The “application case” or “production case” on the mobile device then only involves loading the model and using it to predict new, previously unseen samples. This requires far less computational power and storage than training the model did.
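The three steps above boil down to Java object serialization: Weka's weka.core.SerializationHelper.write(...) and SerializationHelper.read(...) are built on the standard ObjectOutputStream/ObjectInputStream. As a minimal sketch using plain java.io only, the hypothetical ToyModel class below (a single petal-length threshold) stands in for a trained Weka classifier, and a byte array stands in for the model file you would ship with the app.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;

// Hypothetical stand-in for a trained classifier: one threshold on petal length
// that was "learned" offline.
class ToyModel implements Serializable {
    private static final long serialVersionUID = 1L;
    private final double petalLengthThreshold;

    ToyModel(double petalLengthThreshold) {
        this.petalLengthThreshold = petalLengthThreshold;
    }

    // Returns the predicted class index (0 = setosa, 1 = not setosa).
    int predict(double petalLength) {
        return petalLength < petalLengthThreshold ? 0 : 1;
    }
}

class ModelRoundTrip {
    // Step 2 (offline): serialize the trained model into bytes, i.e. the "model file".
    static byte[] export(ToyModel model) throws IOException {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (ObjectOutputStream out = new ObjectOutputStream(bytes)) {
            out.writeObject(model);
        }
        return bytes.toByteArray();
    }

    // Step 3 (on the device): deserialize the model so it can be used for prediction.
    static ToyModel load(byte[] modelFile) throws IOException, ClassNotFoundException {
        try (ObjectInputStream in = new ObjectInputStream(new ByteArrayInputStream(modelFile))) {
            return (ToyModel) in.readObject();
        }
    }

    public static void main(String[] args) throws Exception {
        byte[] modelFile = export(new ToyModel(2.5)); // hypothetical "trained" threshold
        ToyModel loaded = load(modelFile);
        System.out.println("model file size: " + modelFile.length + " bytes");
        System.out.println("prediction for petal length 1.4: " + loaded.predict(1.4));
    }
}
```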

Building on the previous post about loading and using Weka models with Java SE, this example demonstrates how to load and use a model on Android that was previously trained with Weka “offline”, i.e. on a desktop PC. The key points to watch out for when setting up the Android project are:

  1. Use exactly the same Weka version on Android that you used to train the model offline. Weka models are stored as serialized Java objects, and a model file written by one Weka version may fail to deserialize with another.
  2. Copy weka.jar to libs/, then right-click it in Android Studio and select “Add as library…”.
  3. Ensure that compile files('libs/weka.jar') is present in build.gradle (newer versions of the Android Gradle plugin use implementation instead of compile).
  4. Do a clean build. The first build takes a while and is CPU-intensive, because the large weka.jar has to be converted (dexed) for Android.
  5. Avoid anything UI-related in Weka: its GUI is built on Java SE's Swing/AWT, which is not available on Android.
  6. Follow the steps in the example below:
    1. Structure the new, unseen instance whose target variable you want to predict so that Weka can work with it.
    2. Load the classifier object.
    3. Predict the target variable for the new instance.
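Point 1 matters because a Weka model file is a serialized Java object graph: the stream records the serialVersionUID of each class, and ObjectInputStream rejects the stream with an InvalidClassException when a local class reports a different value, which can happen when the weka.jar on the device differs from the one used for training. The minimal sketch below prints the serialVersionUID the runtime reports for a class via java.io.ObjectStreamClass.lookup. The Example class is hypothetical; in a real project you would pass the Weka class of your model instead (assumed here to be weka.classifiers.functions.Logistic, as the file name iris_model_logistic_allfeatures.model suggests) and compare the value on the desktop and on Android.

```java
import java.io.ObjectStreamClass;
import java.io.Serializable;

class SerialVersionCheck {
    // Hypothetical serializable class standing in for a Weka classifier class.
    static class Example implements Serializable {
        private static final long serialVersionUID = 42L;
    }

    // Returns the serialVersionUID that Java serialization writes for (and
    // expects when reading) instances of the given class.
    static long uidOf(Class<?> cls) {
        ObjectStreamClass desc = ObjectStreamClass.lookup(cls);
        if (desc == null) {
            throw new IllegalArgumentException(cls.getName() + " is not Serializable");
        }
        return desc.getSerialVersionUID();
    }

    public static void main(String[] args) {
        // Run the same lookup with the weka.jar used for training and with the one
        // bundled in the app: differing values mean the model file will be rejected.
        System.out.println(Example.class.getName() + ": " + uidOf(Example.class));
    }
}
```

Note that matching values for the top-level classifier class are necessary but not sufficient, since every class in the serialized object graph is checked.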

The complete Android example project containing the code below can be downloaded here:

  • Download (7z): Android Weka Example.
  • SHA256sum: aaf8594954d3561218dc0dd621740f3184a171efa7016a2efe4b743a19845ffb
package at.fhooe.mcm.ml.androidwekaexample;

import android.content.res.AssetManager;
import android.os.Bundle;
import android.support.v7.app.AppCompatActivity;
import android.support.v7.widget.Toolbar;
import android.util.Log;
import android.view.View;
import android.view.Menu;
import android.view.MenuItem;
import android.widget.TextView;
import android.widget.Toast;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;

import ml.mcm.fhooe.at.androidwekaexample.R;
import weka.classifiers.Classifier;
import weka.core.Attribute;
import weka.core.DenseInstance;
import weka.core.Instances;

public class MainActivity extends AppCompatActivity {

    private static final String WEKA_TEST = "WekaTest";

    private Random mRandom = new Random();

    private Sample[] mSamples = new Sample[]{
            new Sample(1, 0, new double[]{5, 3.5, 2, 0.4}), // should be in the setosa domain
            new Sample(2, 1, new double[]{5.6, 3, 3.5, 1.2}), // should be in the versicolor domain
            new Sample(3, 2, new double[]{7, 3, 6.8, 2.1}) // should be in the virginica domain
    };

    private Classifier mClassifier = null;

    TextView mTextViewSamples = null;

    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_main);
        Toolbar toolbar = (Toolbar) findViewById(R.id.toolbar);
        setSupportActionBar(toolbar);

        // show samples
        StringBuilder sb = new StringBuilder("Samples:\n");
        for(Sample s : mSamples) {
            sb.append(s.toString() + "\n");
        }
        mTextViewSamples = (TextView) findViewById(R.id.textViewSamples);
        mTextViewSamples.setText(sb.toString());

        Log.d(WEKA_TEST, "onCreate() finished.");
    }

    @Override
    public boolean onCreateOptionsMenu(Menu menu) {
        // Inflate the menu; this adds items to the action bar if it is present.
        getMenuInflater().inflate(R.menu.menu_main, menu);
        return true;
    }

    @Override
    public boolean onOptionsItemSelected(MenuItem item) {
        // Handle action bar item clicks here. The action bar will
        // automatically handle clicks on the Home/Up button, so long
        // as you specify a parent activity in AndroidManifest.xml.
        int id = item.getItemId();

        //noinspection SimplifiableIfStatement
        if (id == R.id.action_settings) {
            return true;
        }

        return super.onOptionsItemSelected(item);
    }

    public void onClickButtonLoadModel(View _v) {
        Log.d(WEKA_TEST, "onClickButtonLoadModel()");

        AssetManager assetManager = getAssets();
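        try {
            // the model was serialized offline (same Weka version!) and shipped in assets/
            mClassifier = (Classifier) weka.core.SerializationHelper.read(assetManager.open("iris_model_logistic_allfeatures.model"));
            Toast.makeText(this, "Model loaded.", Toast.LENGTH_SHORT).show();
        } catch (IOException e) {
            // missing asset, or a stream that cannot be deserialized (e.g. written by another Weka version)
            Log.e(WEKA_TEST, "Could not read the model file.", e);
            Toast.makeText(this, "Loading the model failed!", Toast.LENGTH_SHORT).show();
        } catch (Exception e) {
            // SerializationHelper.read() declares "throws Exception": Weka "catch'em all!"
            Log.e(WEKA_TEST, "Could not load the model.", e);
            Toast.makeText(this, "Loading the model failed!", Toast.LENGTH_SHORT).show();
        }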
    }

    public void onClickButtonPredict(View _v) {
        Log.d(WEKA_TEST, "onClickButtonPredict()");

        if(mClassifier==null){
            Toast.makeText(this, "Model not loaded!", Toast.LENGTH_SHORT).show();
            return;
        }

        // we need those for creating new instances later
        // order of attributes/classes needs to be exactly equal to those used for training
        final Attribute attributeSepalLength = new Attribute("sepallength");
        final Attribute attributeSepalWidth = new Attribute("sepalwidth");
        final Attribute attributePetalLength = new Attribute("petallength");
        final Attribute attributePetalWidth = new Attribute("petalwidth");
        // class labels in the same order as in the training data
        final List<String> classes = Arrays.asList(
                "Iris-setosa",      // class index 0
                "Iris-versicolor",  // class index 1
                "Iris-virginica");  // class index 2

        // Instances(...) requires an ArrayList<> instead of a List<>
        ArrayList<Attribute> attributeList = new ArrayList<Attribute>(5);
        attributeList.add(attributeSepalLength);
        attributeList.add(attributeSepalWidth);
        attributeList.add(attributePetalLength);
        attributeList.add(attributePetalWidth);
        // nominal target variable with the class labels defined above
        attributeList.add(new Attribute("@@class@@", classes));

        // empty dataset that only defines the structure (header) for new instances
        Instances dataUnpredicted = new Instances("TestInstances",
                attributeList, 1);
        // the last attribute is the target variable
        dataUnpredicted.setClassIndex(dataUnpredicted.numAttributes() - 1);

        // create a new instance from a randomly picked sample
        Sample s = mSamples[mRandom.nextInt(mSamples.length)];
        DenseInstance newInstance = new DenseInstance(dataUnpredicted.numAttributes());
        // reference to the dataset (header), needed to resolve the class attribute
        newInstance.setDataset(dataUnpredicted);
        newInstance.setValue(attributeSepalLength, s.features[0]);
        newInstance.setValue(attributeSepalWidth, s.features[1]);
        newInstance.setValue(attributePetalLength, s.features[2]);
        newInstance.setValue(attributePetalWidth, s.features[3]);
        // the class value is left missing: it is what the classifier predicts

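        // predict new sample
        try {
            double result = mClassifier.classifyInstance(newInstance);
            if (Double.isNaN(result)) {
                // Weka signals "no prediction" with a missing value (NaN)
                Log.w(WEKA_TEST, "No prediction for sample nr " + s.nr);
                return;
            }
            // classifyInstance() returns the index of the predicted class value as a double
            String className = classes.get((int) result);
            String msg = "Nr: " + s.nr + ", predicted: " + className + ", actual: " + classes.get(s.label);
            Log.d(WEKA_TEST, msg);
            Toast.makeText(this, msg, Toast.LENGTH_SHORT).show();
        } catch (Exception e) {
            Log.e(WEKA_TEST, "Prediction failed.", e);
        }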
    }

    public static class Sample {
        public int nr;
        public int label;
        public double [] features;

        public Sample(int _nr, int _label, double[] _features) {
            this.nr = _nr;
            this.label = _label;
            this.features = _features;
        }

        @Override
        public String toString() {
            return "Nr " +
                    nr +
                    ", cls " + label +
                    ", feat: " + Arrays.toString(features);
        }
    }
}
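classifyInstance returns the predicted class as a double, namely the index into the class attribute's values, which is why the code above converts it to an int and looks it up in classes. For a nominal class, Weka's AbstractClassifier.classifyInstance derives this index from distributionForInstance by picking the most probable class, and returns a missing value (NaN) if all probabilities are zero; Logistic is assumed here to inherit that default behaviour. The plain-Java sketch below mirrors that step; the method names argMaxOrMissing and toLabel and the example probabilities are hypothetical. Casting NaN to int yields 0, so unless the result is checked, a missing prediction would be reported as Iris-setosa.

```java
import java.util.Arrays;
import java.util.List;

class PredictionToLabel {
    // Index of the most probable class, or NaN if every probability is zero
    // (mirrors how a nominal class index is derived from a class distribution).
    static double argMaxOrMissing(double[] distribution) {
        double max = 0;
        int maxIndex = 0;
        for (int i = 0; i < distribution.length; i++) {
            if (distribution[i] > max) {
                max = distribution[i];
                maxIndex = i;
            }
        }
        return max > 0 ? maxIndex : Double.NaN;
    }

    // Maps the double returned by the classification step to a class label.
    static String toLabel(double prediction, List<String> classes) {
        if (Double.isNaN(prediction)) {
            return "(no prediction)";
        }
        return classes.get((int) prediction);
    }

    public static void main(String[] args) {
        List<String> classes = Arrays.asList("Iris-setosa", "Iris-versicolor", "Iris-virginica");
        double[] distribution = {0.05, 0.15, 0.80}; // hypothetical class probabilities
        System.out.println(toLabel(argMaxOrMissing(distribution), classes));
    }
}
```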
Categories: Data Analysis