Implementation: Smile InferenceModel Predict (haifengl/smile)
Overview
The prediction pipeline in Smile's serve module spans three classes: InferenceResource (the REST endpoint), InferenceModel (prediction logic), and InferenceResponse (the result container). The REST endpoint receives a JSON body and delegates to InferenceModel.predict(), which validates the input, constructs a Tuple, dispatches on the model type, and returns an InferenceResponse containing the prediction and, for soft classifiers, the posterior probabilities.
Source Location
| Class | Location |
|---|---|
| InferenceResource.predict() | serve/src/main/java/smile/serve/InferenceResource.java (lines 64-70) |
| InferenceModel | serve/src/main/java/smile/serve/InferenceModel.java (lines 83-117) |
| InferenceResponse | serve/src/main/java/smile/serve/InferenceResponse.java (lines 29-59) |
| ProbabilitySerializer | serve/src/main/java/smile/serve/ProbabilitySerializer.java |
Import Statements
```java
import smile.serve.InferenceModel;
import smile.serve.InferenceResponse;
```
External Dependencies
| Dependency | Annotation/Class | Purpose |
|---|---|---|
| JAX-RS | @POST, @Path, @Consumes, @Produces | REST endpoint annotations |
| JAX-RS | @PathParam | Extracts the model ID from the URL |
| Vert.x | io.vertx.core.json.JsonObject | JSON request body parsing (Quarkus uses Vert.x under the hood) |
| Jackson | @JsonInclude, @JsonSerialize | JSON response serialization with custom probability formatting |
| Smile Core | smile.data.Tuple | Internal data row representation |
| Smile Core | smile.model.ClassificationModel, smile.model.RegressionModel | Model type dispatching |
Type: API Doc
REST Endpoint
POST /api/v1/models/{id}
| Property | Value |
|---|---|
| Method | POST |
| Path | /api/v1/models/{id} |
| Consumes | application/json |
| Produces | application/json |
| Path Parameter | id -- model identifier (e.g., "iris-classifier-2") |
| Request Body | JSON object with feature name-value pairs |
| Response | InferenceResponse JSON with prediction and optional probabilities |
| Errors | 400 Bad Request if input features are insufficient; 404 Not Found if model ID is invalid |
Java implementation (REST layer):
```java
@POST
@Path("/{id}")
@Consumes(MediaType.APPLICATION_JSON)
@Produces(MediaType.APPLICATION_JSON)
public InferenceResponse predict(@PathParam("id") String id, JsonObject request) {
    return service.predict(id, request);
}
```
HTTP Examples
Classification Request
```shell
curl -X POST http://localhost:8080/api/v1/models/iris-classifier-2 \
  -H "Content-Type: application/json" \
  -d '{
    "sepal_length": 5.1,
    "sepal_width": 3.5,
    "petal_length": 1.4,
    "petal_width": 0.2
  }'
```
Response (soft classifier with probabilities):
```json
{
  "prediction": 0,
  "probabilities": [0.950, 0.030, 0.020]
}
```
Regression Request
```shell
curl -X POST http://localhost:8080/api/v1/models/boston-regression-1 \
  -H "Content-Type: application/json" \
  -d '{
    "crim": 0.00632,
    "zn": 18.0,
    "indus": 2.31,
    "chas": 0,
    "nox": 0.538,
    "rm": 6.575,
    "age": 65.2,
    "dis": 4.09,
    "rad": 1,
    "tax": 296,
    "ptratio": 15.3,
    "lstat": 4.98
  }'
```
Response (regression, no probabilities):
```json
{
  "prediction": 30.003
}
```
Note that the probabilities field is omitted from regression responses (not set to null in JSON) thanks to the @JsonInclude(JsonInclude.Include.NON_NULL) annotation.
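The effect of that annotation can be mirrored without Jackson. The following stdlib-only sketch (the class name `ResponseJsonSketch` and its `toJson` helper are hypothetical, not part of Smile) hand-builds the JSON and simply skips the `probabilities` key when the array is null, producing the same two response shapes shown above:

```java
import java.util.Arrays;
import java.util.Locale;
import java.util.stream.Collectors;

// Hypothetical stand-in for InferenceResponse serialization: mimics the
// @JsonInclude(NON_NULL) behavior by emitting "probabilities" only when present.
public class ResponseJsonSketch {
    static String toJson(Number prediction, double[] probabilities) {
        StringBuilder sb = new StringBuilder("{\"prediction\": ").append(prediction);
        if (probabilities != null) {
            sb.append(", \"probabilities\": ").append(
                Arrays.stream(probabilities)
                      .mapToObj(p -> String.format(Locale.ROOT, "%.3f", p))
                      .collect(Collectors.joining(", ", "[", "]")));
        }
        return sb.append("}").toString();
    }

    public static void main(String[] args) {
        System.out.println(toJson(0, new double[]{0.95, 0.03, 0.02}));
        System.out.println(toJson(30.003, null));  // no "probabilities" key at all
    }
}
```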
Java API Signatures
InferenceModel
```java
public class InferenceModel {
    /** Predict from JSON input (feature names matched by schema). */
    public InferenceResponse predict(JsonObject request) throws BadRequestException;

    /** Predict from CSV input (feature values matched by position). */
    public InferenceResponse predict(String request) throws BadRequestException;

    /** Core prediction method operating on a Smile Tuple. */
    public InferenceResponse predict(Tuple x);
}
```
InferenceResponse
```java
public class InferenceResponse {
    /** The predicted value (Integer for classification, Double for regression). */
    public Number prediction;

    /** Posterior probabilities for soft classifiers; null for hard classifiers and regression. */
    @JsonInclude(JsonInclude.Include.NON_NULL)
    @JsonSerialize(using = ProbabilitySerializer.class)
    public double[] probabilities;
}
```
Implementation Details
JSON to Tuple Conversion
The json() method converts a JSON request body to a Smile Tuple:
```java
public Tuple json(JsonObject values) throws BadRequestException {
    StructType schema = model.schema();
    if (values.size() < schema.length()) throw new BadRequestException();
    var row = new Object[schema.length()];
    for (int i = 0; i < row.length; i++) {
        row[i] = values.getValue(schema.field(i).name());
    }
    return Tuple.of(schema, row);
}
```
Key behaviors:
- Validates that the JSON contains at least as many fields as the schema requires.
- Extracts values by field name from the JSON object, matching against the model's schema.
- Constructs a Tuple from the extracted values.
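The name-based lookup and the field-count check can be sketched with plain collections: a Map standing in for Vert.x's JsonObject and a list of field names standing in for the schema. This is an illustrative stand-in (the class `JsonRowSketch` is hypothetical), not the Smile code:

```java
import java.util.List;
import java.util.Map;

// Hypothetical sketch of InferenceModel.json(): values are pulled out by
// schema field name, after checking the request carries enough fields.
public class JsonRowSketch {
    static Object[] toRow(Map<String, Object> values, List<String> fieldNames) {
        if (values.size() < fieldNames.size()) {
            throw new IllegalArgumentException("insufficient input features");
        }
        var row = new Object[fieldNames.size()];
        for (int i = 0; i < row.length; i++) {
            row[i] = values.get(fieldNames.get(i));  // matched by name, not position
        }
        return row;
    }

    public static void main(String[] args) {
        var row = toRow(
            Map.of("sepal_length", 5.1, "sepal_width", 3.5,
                   "petal_length", 1.4, "petal_width", 0.2),
            List.of("sepal_length", "sepal_width", "petal_length", "petal_width"));
        System.out.println(java.util.Arrays.toString(row));
    }
}
```

Because lookup is by name, extra JSON fields are harmless; only missing ones (detected by the size check, or by a null slot) cause problems.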
CSV to Tuple Conversion
The csv() method handles comma-separated input:
```java
public Tuple csv(String line) throws BadRequestException {
    var values = line.split(",");
    StructType schema = model.schema();
    if (values.length < schema.length()) throw new BadRequestException();
    try {
        var row = new Object[schema.length()];
        for (int i = 0; i < row.length; i++) {
            row[i] = schema.field(i).valueOf(values[i]);
        }
        return Tuple.of(schema, row);
    } catch (Exception ex) {
        throw new BadRequestException(ex.getMessage());
    }
}
```
Key behaviors:
- Splits input on commas.
- Values are matched by position (not by name).
- Each string value is converted to the correct type using the schema field's valueOf() method.
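The positional matching and per-field conversion can be sketched with a list of parser functions standing in for the schema's valueOf() (the class `CsvRowSketch` and its parser list are hypothetical, stdlib-only):

```java
import java.util.List;
import java.util.function.Function;

// Hypothetical sketch of InferenceModel.csv(): tokens are matched to schema
// fields by position, and each token is converted with that field's parser.
public class CsvRowSketch {
    static Object[] toRow(String line, List<Function<String, Object>> parsers) {
        var values = line.split(",");
        if (values.length < parsers.size()) {
            throw new IllegalArgumentException("insufficient input features");
        }
        var row = new Object[parsers.size()];
        for (int i = 0; i < row.length; i++) {
            row[i] = parsers.get(i).apply(values[i].trim());  // positional match
        }
        return row;
    }

    public static void main(String[] args) {
        var row = toRow("5.1,3.5,1.4,0.2",
            List.of(Double::parseDouble, Double::parseDouble,
                    Double::parseDouble, Double::parseDouble));
        System.out.println(java.util.Arrays.toString(row));
    }
}
```

Note the contrast with the JSON path: here order matters and names are ignored, so the client must send values in schema order.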
Type-Dispatched Prediction
The core predict(Tuple) method uses Java's pattern matching switch:
```java
public InferenceResponse predict(Tuple x) {
    double[] probabilities = null;
    Number y = switch (model) {
        case ClassificationModel m -> {
            if (isSoft) {
                probabilities = new double[m.classifier().numClasses()];
                yield m.predict(x, probabilities);
            } else {
                yield m.predict(x);
            }
        }
        case RegressionModel m -> m.predict(x);
        default -> 0;
    };
    return new InferenceResponse(y, probabilities);
}
```
The isSoft flag is determined once at construction time:
```java
if (model instanceof ClassificationModel m) {
    isSoft = m.classifier().isSoft();
} else {
    isSoft = false;
}
```
For soft classifiers, a pre-allocated double[] array receives the posterior probabilities. For hard classifiers or regression models, probabilities remains null.
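The same dispatch shape can be reproduced with hypothetical stand-in types (a sealed interface and dummy records replacing Smile's ClassificationModel/RegressionModel) to show how the switch and the pre-allocated probability array interact; the numbers are placeholders, not real model output:

```java
// Hypothetical stand-ins for the Smile model types, illustrating the
// pattern-matching dispatch and the out-parameter probability array.
public class DispatchSketch {
    sealed interface Model {}
    record Classifier(int numClasses, boolean soft) implements Model {
        int predict(double[] x, double[] posteriori) {
            // dummy posterior fill: most mass on class 0, rest spread evenly
            java.util.Arrays.fill(posteriori, 0.1 / (posteriori.length - 1));
            posteriori[0] = 0.9;
            return 0;
        }
        int predict(double[] x) { return 0; }
    }
    record Regressor() implements Model {
        double predict(double[] x) { return 30.003; }
    }
    record Result(Number prediction, double[] probabilities) {}

    static Result predict(Model model, double[] x) {
        double[] probabilities = null;
        Number y = switch (model) {
            case Classifier m -> {
                if (m.soft()) {
                    probabilities = new double[m.numClasses()];  // pre-allocated output
                    yield m.predict(x, probabilities);
                } else {
                    yield m.predict(x);                          // probabilities stays null
                }
            }
            case Regressor m -> m.predict(x);
        };
        return new Result(y, probabilities);
    }

    public static void main(String[] args) {
        var r = predict(new Classifier(3, true), new double[]{5.1, 3.5, 1.4, 0.2});
        System.out.println(r.prediction() + " " + java.util.Arrays.toString(r.probabilities()));
    }
}
```

Because the interface is sealed, the switch is exhaustive without a `default` arm; Smile's real switch needs one because `model` is not a sealed type there.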
Probability Serialization
The ProbabilitySerializer custom Jackson serializer formats each probability to 3 decimal places:
```java
public class ProbabilitySerializer extends JsonSerializer<double[]> {
    @Override
    public void serialize(double[] probabilities, JsonGenerator gen,
                          SerializerProvider serializers) throws IOException {
        gen.writeStartArray();
        for (double prob : probabilities) {
            String value = String.format("%.3f", prob);
            gen.writeRawValue(value);
        }
        gen.writeEndArray();
    }
}
```
This ensures compact, consistent formatting (e.g., [0.950, 0.030, 0.020]) rather than full-precision doubles.
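The same fixed-precision output can be reproduced without Jackson. This stdlib-only sketch (the class `ProbabilityFormatSketch` is hypothetical) renders each value with `String.format("%.3f", ...)` just as the serializer does; `Locale.ROOT` is added here to guard against locale-specific decimal separators, a detail not shown in the source above:

```java
import java.util.Arrays;
import java.util.Locale;
import java.util.stream.Collectors;

// Stdlib sketch of ProbabilitySerializer's output: each probability is
// rendered with exactly three decimal places inside a JSON array.
public class ProbabilityFormatSketch {
    static String format(double[] probabilities) {
        return Arrays.stream(probabilities)
                     .mapToObj(p -> String.format(Locale.ROOT, "%.3f", p))
                     .collect(Collectors.joining(", ", "[", "]"));
    }

    public static void main(String[] args) {
        System.out.println(format(new double[]{0.95, 0.03, 0.02}));  // [0.950, 0.030, 0.020]
    }
}
```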
String Representation (for Streaming)
The InferenceResponse.toString() method produces a compact text format used in streaming responses:
```java
@Override
public String toString() {
    String s = prediction.toString();
    if (probabilities != null) {
        s += Arrays.stream(probabilities)
                   .mapToObj(p -> String.format("%.3f", p))
                   .collect(Collectors.joining(" ", " ", ""));
    }
    return s;
}
```
Example output: "1 0.100 0.900" (prediction followed by space-separated probabilities).
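A stand-alone sketch of the same joining logic (with the fields passed in as parameters rather than read from the object; the class `StreamingFormatSketch` is hypothetical) makes the format concrete, including the leading-space prefix that `joining(" ", " ", "")` inserts before the first probability:

```java
import java.util.Arrays;
import java.util.Locale;
import java.util.stream.Collectors;

// Stand-alone sketch of InferenceResponse.toString(): the prediction, then
// each probability at three decimals, joined by single spaces.
public class StreamingFormatSketch {
    static String format(Number prediction, double[] probabilities) {
        String s = prediction.toString();
        if (probabilities != null) {
            // joining(delimiter=" ", prefix=" ", suffix=""): one space before
            // the first probability separates it from the prediction
            s += Arrays.stream(probabilities)
                       .mapToObj(p -> String.format(Locale.ROOT, "%.3f", p))
                       .collect(Collectors.joining(" ", " ", ""));
        }
        return s;
    }

    public static void main(String[] args) {
        System.out.println(format(1, new double[]{0.1, 0.9}));  // 1 0.100 0.900
        System.out.println(format(30.003, null));               // 30.003
    }
}
```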