Implementation:Haifengl Smile InferenceModel Predict

From Leeroopedia


Overview

The prediction pipeline in Smile's serve module spans three classes: InferenceResource (REST endpoint), InferenceModel (prediction logic), and InferenceResponse (result container).

The REST endpoint receives a JSON body and passes it, together with the model ID from the URL, to a service object. That service resolves the ID and calls InferenceModel.predict(). InferenceModel.predict() then:

  • validates the input,
  • constructs a Tuple,
  • dispatches on the model type,
  • returns an InferenceResponse containing the prediction and, for soft classifiers, the posterior probabilities.

Source Location

Class | Location
InferenceResource.predict() | serve/src/main/java/smile/serve/InferenceResource.java (Lines 64-70)
InferenceModel | serve/src/main/java/smile/serve/InferenceModel.java (Lines 83-117)
InferenceResponse | serve/src/main/java/smile/serve/InferenceResponse.java (Lines 29-59)
ProbabilitySerializer | serve/src/main/java/smile/serve/ProbabilitySerializer.java

Import Statements

import smile.serve.InferenceModel;
import smile.serve.InferenceResponse;

External Dependencies

Dependency | Annotation/Class | Purpose
JAX-RS | @POST, @Path, @Consumes, @Produces | REST endpoint annotations
JAX-RS | @PathParam | Extracts the model ID from the URL
Vert.x | io.vertx.core.json.JsonObject | JSON request body parsing (Quarkus uses Vert.x under the hood)
Jackson | @JsonInclude, @JsonSerialize | JSON response serialization with custom probability formatting
Smile Core | smile.data.Tuple | Internal data row representation
Smile Core | smile.model.ClassificationModel, smile.model.RegressionModel | Model type dispatching

Type

API Doc

REST Endpoint

POST /api/v1/models/{id}

Property | Value
Method | POST
Path | /api/v1/models/{id}
Consumes | application/json
Produces | application/json
Path parameter | id: model identifier (e.g., "iris-classifier-2")
Request body | JSON object with feature name-value pairs
Response | InferenceResponse JSON with the prediction and optional probabilities
Errors | 400 Bad Request if the input has too few features; 404 Not Found if the model ID is unknown

Java implementation (REST layer):

@POST
@Path("/{id}")
@Consumes(MediaType.APPLICATION_JSON)
@Produces(MediaType.APPLICATION_JSON)
public InferenceResponse predict(@PathParam("id") String id, JsonObject request) {
    return service.predict(id, request);
}

HTTP Examples

Classification Request

curl -X POST http://localhost:8080/api/v1/models/iris-classifier-2 \
  -H "Content-Type: application/json" \
  -d '{
    "sepal_length": 5.1,
    "sepal_width": 3.5,
    "petal_length": 1.4,
    "petal_width": 0.2
  }'

Response (soft classifier with probabilities):

{
  "prediction": 0,
  "probabilities": [0.950, 0.030, 0.020]
}
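Callers written in Java can issue the classification request above with the JDK's java.net.http client. This is a minimal sketch and is not part of Smile's serve module. The class InferenceClient and its methods are hypothetical, and the base URL and model ID are the example values from the curl command. Building the request does not contact the server; predict() needs a running service.

```java
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;

// Hypothetical client helper for POST /api/v1/models/{id}; not a Smile API.
class InferenceClient {
    // Builds a JSON POST request for the given model ID; nothing is sent yet.
    static HttpRequest buildRequest(String baseUrl, String modelId, String jsonBody) {
        return HttpRequest.newBuilder(URI.create(baseUrl + "/api/v1/models/" + modelId))
                .header("Content-Type", "application/json")
                .header("Accept", "application/json")
                .POST(HttpRequest.BodyPublishers.ofString(jsonBody))
                .build();
    }

    // Sends the request and returns the raw JSON body, e.g. {"prediction":0,...}.
    static String predict(HttpClient client, HttpRequest request) throws Exception {
        HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
        if (response.statusCode() != 200) {
            // e.g. 400 for too few features, 404 for an unknown model ID
            throw new IllegalStateException("HTTP " + response.statusCode() + ": " + response.body());
        }
        return response.body();
    }
}
```

For example, buildRequest("http://localhost:8080", "iris-classifier-2", body) targets the same URL as the curl example. In this sketch, any non-200 status surfaces as an IllegalStateException that carries the response body.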

Regression Request

curl -X POST http://localhost:8080/api/v1/models/boston-regression-1 \
  -H "Content-Type: application/json" \
  -d '{
    "crim": 0.00632,
    "zn": 18.0,
    "indus": 2.31,
    "chas": 0,
    "nox": 0.538,
    "rm": 6.575,
    "age": 65.2,
    "dis": 4.09,
    "rad": 1,
    "tax": 296,
    "ptratio": 15.3,
    "lstat": 4.98
  }'

Response (regression, no probabilities):

{
  "prediction": 30.003
}

The probabilities field is omitted from regression (and hard-classifier) responses rather than serialized as null. This is because of the @JsonInclude(JsonInclude.Include.NON_NULL) annotation on the field.

Java API Signatures

InferenceModel

public class InferenceModel {

    /** Predict from JSON input (feature names matched by schema). */
    public InferenceResponse predict(JsonObject request) throws BadRequestException;

    /** Predict from CSV input (feature values matched by position). */
    public InferenceResponse predict(String request) throws BadRequestException;

    /** Core prediction method operating on a Smile Tuple. */
    public InferenceResponse predict(Tuple x);
}

InferenceResponse

public class InferenceResponse {
    /** The predicted value (Integer for classification, Double for regression). */
    public Number prediction;

    /** Posterior probabilities for soft classifiers; null for hard classifiers and regression. */
    @JsonInclude(JsonInclude.Include.NON_NULL)
    @JsonSerialize(using = ProbabilitySerializer.class)
    public double[] probabilities;
}

Implementation Details

JSON to Tuple Conversion

The json() method converts a JSON request body to a Smile Tuple:

public Tuple json(JsonObject values) throws BadRequestException {
    StructType schema = model.schema();
    if (values.size() < schema.length()) throw new BadRequestException();

    var row = new Object[schema.length()];
    for (int i = 0; i < row.length; i++) {
        row[i] = values.getValue(schema.field(i).name());
    }
    return Tuple.of(schema, row);
}

Key behaviors:

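  • Checks only that the JSON object has at least as many fields as the schema; it does not check that each schema field name is present.
  • Looks up each schema field by name, so key order in the JSON does not matter and extra keys are ignored. A missing name yields a null value (JsonObject.getValue() returns null for absent keys) rather than a 400 error.
  • Constructs a Tuple from the extracted values in schema order.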
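The name-based lookup and its count-only validation can be modeled without Vert.x or Smile. In this hedged sketch:

  • a Map<String, ?> stands in for JsonObject,
  • a List<String> of field names stands in for StructType,
  • IllegalArgumentException stands in for BadRequestException.

The class JsonRowExtractor and its strict() variant are hypothetical illustrations, not Smile APIs.

```java
import java.util.List;
import java.util.Map;

// Hypothetical model of json(): Map stands in for JsonObject, List<String> for StructType.
class JsonRowExtractor {
    // Mirrors the documented check: only the number of fields is validated.
    static Object[] lenient(List<String> schema, Map<String, ?> values) {
        if (values.size() < schema.size()) {
            throw new IllegalArgumentException("expected at least " + schema.size() + " fields");
        }
        Object[] row = new Object[schema.size()];
        for (int i = 0; i < row.length; i++) {
            row[i] = values.get(schema.get(i)); // null when the name is absent
        }
        return row;
    }

    // Hypothetical stricter variant: reject the request if any schema field is missing by name.
    static Object[] strict(List<String> schema, Map<String, ?> values) {
        Object[] row = new Object[schema.size()];
        for (int i = 0; i < row.length; i++) {
            String name = schema.get(i);
            if (!values.containsKey(name)) {
                throw new IllegalArgumentException("missing field: " + name);
            }
            row[i] = values.get(name);
        }
        return row;
    }
}
```

Consider a two-field schema and a request containing sepal_length plus a misspelled sepal_widht. lenient() accepts it with a null in the second slot, while strict() rejects it.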

CSV to Tuple Conversion

The csv() method handles comma-separated input:

public Tuple csv(String line) throws BadRequestException {
    var values = line.split(",");
    StructType schema = model.schema();
    if (values.length < schema.length()) throw new BadRequestException();

    try {
        var row = new Object[schema.length()];
        for (int i = 0; i < row.length; i++) {
            row[i] = schema.field(i).valueOf(values[i]);
        }
        return Tuple.of(schema, row);
    } catch (Exception ex) {
        throw new BadRequestException(ex.getMessage());
    }
}

Key behaviors:

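  • Splits the input on raw commas with String.split(","), so quoted fields and values containing commas are not supported.
  • Matches values by position (not by name); values beyond the schema length are ignored.
  • Converts each string to the field's type with the schema field's valueOf() method. Any conversion error is rethrown as a BadRequestException (HTTP 400).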
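Position-based parsing can be sketched the same way. Here each schema field is a hypothetical Function<String, Object> parser standing in for StructField.valueOf(), and IllegalArgumentException again stands in for BadRequestException. This is an illustration of the technique, not Smile's code.

```java
import java.util.List;
import java.util.function.Function;

// Hypothetical model of csv(): one parser per schema field (in place of StructField.valueOf);
// IllegalArgumentException stands in for BadRequestException.
class CsvRowParser {
    static Object[] parse(List<Function<String, Object>> fields, String line) {
        String[] values = line.split(","); // raw split: no quoting; trailing empty strings dropped
        if (values.length < fields.size()) {
            throw new IllegalArgumentException(
                    "expected " + fields.size() + " values, got " + values.length);
        }
        Object[] row = new Object[fields.size()];
        try {
            for (int i = 0; i < row.length; i++) {
                row[i] = fields.get(i).apply(values[i]); // matched by position; extras ignored
            }
        } catch (RuntimeException ex) {
            // Any conversion failure becomes a client error, as in csv()
            throw new IllegalArgumentException(ex.getMessage(), ex);
        }
        return row;
    }
}
```

String.split(",") discards trailing empty strings. As a result, an input such as "5.1," counts as a single value and fails the length check for a two-field schema.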

Type-Dispatched Prediction

The core predict(Tuple) method uses pattern matching for switch (finalized in Java 21) to dispatch on the model type:

public InferenceResponse predict(Tuple x) {
    double[] probabilities = null;
    Number y = switch (model) {
        case ClassificationModel m -> {
            if (isSoft) {
                probabilities = new double[m.classifier().numClasses()];
                yield m.predict(x, probabilities);
            } else {
                yield m.predict(x);
            }
        }
        case RegressionModel m -> m.predict(x);
        default -> 0;
    };
    return new InferenceResponse(y, probabilities);
}

The isSoft flag is determined once at construction time:

if (model instanceof ClassificationModel m) {
    isSoft = m.classifier().isSoft();
} else {
    isSoft = false;
}

For soft classifiers, each call allocates a new double[] of length numClasses() and passes it to predict(), which fills it with the posterior probabilities. For hard classifiers and regression models, probabilities remains null.
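The dispatch logic can be exercised in isolation with a sealed interface. This sketch has the following assumptions:

  • It requires Java 21.
  • Hypothetical ClassifierSketch and RegressorSketch records stand in for smile.model.ClassificationModel and RegressionModel.
  • A double[] stands in for Tuple.

Because the interface is sealed, the switch is exhaustive without a default branch. The document's code needs its default -> 0 arm only for types outside the two cases.

```java
import java.util.function.Function;
import java.util.function.ToDoubleFunction;

// Hypothetical stand-ins for ClassificationModel / RegressionModel; double[] plays the Tuple.
sealed interface ServedModel permits ClassifierSketch, RegressorSketch {}

record ClassifierSketch(int numClasses, boolean isSoft,
                        Function<double[], double[]> posteriors) implements ServedModel {
    // Fills probs with posterior probabilities and returns the arg-max class index.
    int predict(double[] x, double[] probs) {
        System.arraycopy(posteriors.apply(x), 0, probs, 0, numClasses);
        int best = 0;
        for (int k = 1; k < numClasses; k++) {
            if (probs[k] > probs[best]) best = k;
        }
        return best;
    }

    // Hard prediction: posteriors are computed internally and discarded.
    int predict(double[] x) {
        return predict(x, new double[numClasses]);
    }
}

record RegressorSketch(ToDoubleFunction<double[]> f) implements ServedModel {
    double predict(double[] x) {
        return f.applyAsDouble(x);
    }
}

record ResponseSketch(Number prediction, double[] probabilities) {}

class PredictDispatch {
    static ResponseSketch predict(ServedModel model, double[] x) {
        double[] probabilities = null;
        Number y = switch (model) {
            case ClassifierSketch m -> {
                if (m.isSoft()) {
                    probabilities = new double[m.numClasses()]; // new array per call
                    yield m.predict(x, probabilities);          // int boxed to Integer
                } else {
                    yield m.predict(x);                         // probabilities stays null
                }
            }
            case RegressorSketch m -> m.predict(x);             // double boxed to Double
        };
        return new ResponseSketch(y, probabilities);
    }
}
```

The switch expression is assigned to a Number, which makes it a poly expression. Each arm's result is therefore converted to Number on its own: the classifier arms produce an Integer and the regression arm a Double, with no numeric promotion between them.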

Probability Serialization

The ProbabilitySerializer custom Jackson serializer formats each probability to 3 decimal places:

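import com.fasterxml.jackson.core.JsonGenerator;
import com.fasterxml.jackson.databind.JsonSerializer;
import com.fasterxml.jackson.databind.SerializerProvider;
import java.io.IOException;

public class ProbabilitySerializer extends JsonSerializer<double[]> {
    @Override
    public void serialize(double[] probabilities, JsonGenerator gen,
                          SerializerProvider serializers) throws IOException {
        gen.writeStartArray();
        for (double prob : probabilities) {
            // Format to 3 decimal places; String.format uses the JVM's default locale
            String value = String.format("%.3f", prob);
            // Write the formatted text verbatim as a JSON number token
            gen.writeRawValue(value);
        }
        gen.writeEndArray();
    }
}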

Because writeRawValue() writes each formatted string verbatim, the array gets compact, consistent three-decimal formatting (e.g., [0.950, 0.030, 0.020]) rather than the full-precision text of each double.
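String.format(String, Object...) formats with the JVM's default locale. On a JVM whose default locale uses a comma as the decimal separator (for example Locale.GERMANY), the serializer as shown would therefore write tokens such as 0,950, which are not valid JSON numbers.

The sketch below reproduces only the formatting step, without Jackson, and takes the locale as an explicit parameter. Passing Locale.ROOT is one way to make the output locale-independent. The class and method names are hypothetical. The "," separator matches Jackson's compact (non-pretty) array output.

```java
import java.util.Arrays;
import java.util.Locale;
import java.util.stream.Collectors;

// Hypothetical illustration of the "%.3f" formatting step with an explicit locale.
class ProbabilityFormat {
    static String toJsonArray(double[] probabilities, Locale locale) {
        return Arrays.stream(probabilities)
                .mapToObj(p -> String.format(locale, "%.3f", p)) // 3 decimals, locale-aware
                .collect(Collectors.joining(",", "[", "]"));
    }
}
```

With Locale.ROOT the posteriors 0.95, 0.03 and 0.02 become [0.950,0.030,0.020]. With Locale.GERMANY the same call yields [0,950,0,030,0,020], which a JSON parser cannot read back as three numbers.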

String Representation (for Streaming)

The InferenceResponse.toString() method produces a compact text format used in streaming responses:

@Override
public String toString() {
    String s = prediction.toString();
    if (probabilities != null) {
        s += Arrays.stream(probabilities)
            .mapToObj(p -> String.format("%.3f", p))
            .collect(Collectors.joining(" ", " ", ""));
    }
    return s;
}

Example output: "1 0.100 0.900" (prediction followed by space-separated probabilities).
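On the consuming side, a streamed line can be split back into its parts. The StreamedPrediction record is a hypothetical sketch, not a Smile API. It keeps the prediction as a string, because the value may be an integer class label or a regression value. It also assumes the server formatted the probabilities with a '.' decimal separator, since toString() likewise calls String.format with the default locale.

```java
import java.util.Arrays;

// Hypothetical parser for the "<prediction> [p1 p2 ...]" streaming format.
record StreamedPrediction(String prediction, double[] probabilities) {
    static StreamedPrediction parse(String line) {
        String[] tokens = line.trim().split("\\s+");
        double[] probs = tokens.length > 1
                ? Arrays.stream(tokens, 1, tokens.length).mapToDouble(Double::parseDouble).toArray()
                : null; // regression and hard classifiers emit only the prediction
        return new StreamedPrediction(tokens[0], probs);
    }
}
```

Parsing "1 0.100 0.900" yields prediction "1" with probabilities {0.1, 0.9}, while a regression line such as "30.003" yields null probabilities.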
