Using TensorFlow models from the JVM with TensorFlow Serving

When it comes to using a deep learning model in production, we face challenges which are quite different from those we encounter during training. In production we have much higher requirements regarding stability and monitoring. In many setups large parts of the front end are built using JVM languages (such as Scala, Clojure, and, of course, Java), as the JVM nowadays provides a battle-proven environment for customer-facing applications. Using a TensorFlow model from the JVM is not trivial, as the main APIs are currently for C++ and Python. The TensorFlow team is working on a Java API, which is still experimental, and it is possible to access the C++ API using JavaCpp as described in this post. While these are feasible ways to access TensorFlow models from the JVM, I think there is a better way using more stable means.

One tool in the TensorFlow ecosystem is TensorFlow Serving, a dedicated system for serving models, especially in production environments. It is efficient, under continuous development, and supports many neat features, such as updating model versions and serving multiple models at once (which can be used for A/B testing). TensorFlow Serving communicates via gRPC, which is a mature framework with proper Java integration. In the following I will show you a simple example model which we will then access from a small JVM client written in Scala. You can find the entire example on GitHub.

First we need a model that we are going to serve. For this example we will use a model which takes two numbers, multiplies the first one by three, and adds the second one. You know, the really tough machine learning stuff…

import tensorflow as tf

x = tf.placeholder(tf.float32, shape=(None,))
y = tf.placeholder(tf.float32, shape=(None,))

three = tf.Variable(3, dtype=tf.float32)
z = tf.scalar_mul(three, x) + y

So far nothing unexpected. The variable we use could also have been a constant, but let's just pretend that it was determined over many hours of training. The next step is to serialize this model for serving. We need to make sure that both the variable values and the graph architecture are serialized.

import os

model_version = 1
path = os.path.join("three_x_plus_y", str(model_version))
builder = tf.saved_model.builder.SavedModelBuilder(path)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    builder.add_meta_graph_and_variables(
        sess,
        [tf.saved_model.tag_constants.SERVING],
        signature_def_map={
            "magic_model": tf.saved_model.signature_def_utils.predict_signature_def(
                inputs={"egg": x, "bacon": y},
                outputs={"spam": z})
        })
    builder.save()

Let's go through the steps here. First we create a SavedModelBuilder. Note that the model version is encoded in the path we pass to the builder; TensorFlow Serving will automatically discover the available versions under the model's base path. Next we create a session and initialize the variable. Then we make sure that our SavedModelBuilder knows about our compute graph and the variables. Furthermore, we add a signature under which the model and its inputs and outputs can be accessed. Finally we save the model.
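After saving, the layout on disk should look roughly like this (saved_model.pb contains the serialized graph and signatures, the variables directory holds the checkpointed values). TensorFlow Serving treats each numeric subdirectory of three_x_plus_y as one version of the model:

three_x_plus_y/
  1/
    saved_model.pb
    variables/
      variables.data-00000-of-00001
      variables.index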

Now that we have the exported model we can start our model server. Assuming you have already installed TensorFlow Serving, you can start it locally like this (note that we rely on the server's default port and model name, which the client below has to match):

$ [path to tensorflow_serving]/bazel-bin/tensorflow_serving/model_servers/tensorflow_model_server --model_base_path="model/three_x_plus_y"

To access the model server we will create a small Scala client. The communication happens over a gRPC service which is defined using Protocol Buffers in the TensorFlow Serving sources. These .proto files can be translated directly into corresponding Java classes, which gives us a type-safe way to talk to the model server. For this example I copied the proto files into src/main/proto. The translation into the corresponding Java classes can be automated using the Protobuf Gradle Plugin.
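The Java package names of the generated classes depend on the java_package options in the copied .proto files, so treat the following as a sketch to check against your own generated sources; the tensorflow.serving.* names in particular are an assumption based on protobuf's default code generation rules. With that caveat, the imports used by the client snippets below look roughly like this:

import io.grpc.ManagedChannelBuilder
import io.grpc.stub.StreamObserver

import org.tensorflow.framework.{DataType, TensorProto, TensorShapeProto}

// assumed packages for the TensorFlow Serving API protos (package tensorflow.serving, no java_package option set)
import tensorflow.serving.Model.ModelSpec
import tensorflow.serving.Predict.{PredictRequest, PredictResponse}
import tensorflow.serving.PredictionServiceGrpc

import scala.collection.JavaConverters._
import java.util.concurrent.TimeUnit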

Putting the data into the corresponding Protobuf classes involves a small amount of boilerplate code, which is best encapsulated:

// Wraps a batch of values into a rank-1 DT_FLOAT TensorProto.
def createTensorProto(values: Seq[Double]): TensorProto = {
  val dim = TensorShapeProto.Dim.newBuilder()
    .setSize(values.size)

  val shape = TensorShapeProto.newBuilder()
    .addDim(dim)

  val builder = TensorProto.newBuilder()
    .setDtype(DataType.DT_FLOAT)
    .setTensorShape(shape)

  values.foreach{ value =>
    builder.addFloatVal(value.toFloat)
  }
  builder.build()
}

// Builds a PredictRequest addressing the served model and the signature defined during export.
def createRequest(x: Seq[Double], y: Seq[Double]): PredictRequest = {
  val modelSpec = ModelSpec.newBuilder()
    .setName("default") // matches the model server's default model name, since we did not set one above
    .setSignatureName("magic_model")

  val requestBuilder = PredictRequest.newBuilder()
    .setModelSpec(modelSpec)
    .putInputs("egg", createTensorProto(x))
    .putInputs("bacon", createTensorProto(y))

  requestBuilder.build()
}
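Just to make the structure of the request concrete, here is a quick, hypothetical sanity check of these helpers (not part of the original example); it verifies the signature name and the encoding of one input batch:

val request = createRequest(Seq(1.0, 6.0), Seq(18.0, 107.0))

assert(request.getModelSpec.getSignatureName == "magic_model")
assert(request.getInputsOrThrow("egg").getDtype == DataType.DT_FLOAT)
assert(request.getInputsOrThrow("egg").getTensorShape.getDim(0).getSize == 2)
assert(request.getInputsOrThrow("egg").getFloatValList.asScala.map(_.toDouble) == Seq(1.0, 6.0))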

Note that when building the request we need to use the signature name we defined when exporting the model. We can now create a full request object, containing two batches of values, which can then be sent to the model server, which in turn will reply with a corresponding response object.

val channel = ManagedChannelBuilder.forAddress("localhost", 8500)
  .usePlaintext(true) // by default SSL/TLS is used
  .build()

val stub = PredictionServiceGrpc.newBlockingStub(channel)

def applyModel(x: Seq[Double], y: Seq[Double]): Seq[Double] = {
  val response = stub.predict(createRequest(x, y))
  response.getOutputsOrThrow("spam") // "spam" is the output alias defined in the signature
    .getFloatValList()
    .asScala
    .map(_.toDouble)
}

To keep the example simple, I used a blocking stub, which blocks until the result is received. In most real applications an asynchronous stub will be more useful.
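For illustration, here is a minimal sketch of the callback-based variant, reusing the channel and the createRequest helper from above; error handling is reduced to a bare minimum, and StreamObserver comes from io.grpc.stub as imported earlier:

val asyncStub = PredictionServiceGrpc.newStub(channel)

def applyModelAsync(x: Seq[Double], y: Seq[Double])(onResult: Seq[Double] => Unit): Unit = {
  asyncStub.predict(createRequest(x, y), new StreamObserver[PredictResponse] {
    // Predict is a unary call, so onNext is invoked exactly once with the complete response.
    override def onNext(response: PredictResponse): Unit =
      onResult(response.getOutputsOrThrow("spam").getFloatValList.asScala.map(_.toDouble))

    override def onError(t: Throwable): Unit = t.printStackTrace()

    override def onCompleted(): Unit = ()
  })
}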

Finally we can now send some test data to the server and get back the corresponding results:

def main(args: Array[String]): Unit = {

  print("Application of the magic model yields: ")
  println(applyModel(Seq(1.0, 6.0), Seq(18.0, 107.0)).mkString(", "))

  channel.shutdown().awaitTermination(5, TimeUnit.SECONDS)
}

Running the client (with the model server still running in the background) will yield:

INFO: [io.grpc.internal.ManagedChannelImpl-1] Created with target localhost:8500
Application of the magic model yields: 21.0, 125.0
Mar 05, 2017 11:54:37 PM io.grpc.internal.ManagedChannelImpl maybeTerminateChannel
INFO: [io.grpc.internal.ManagedChannelImpl-1] Terminated

You can download and try the entire example from GitHub.

