Implementation:Tensorflow Serving Tfrt Predict Util Test
| Knowledge Sources | |
|---|---|
| Domains | Testing, Prediction, TFRT |
| Last Updated | 2026-02-13 00:00 GMT |
Overview
Test suite for the TFRT-based predict utility, which executes predictions on the TFRT runtime.
Description
This test file exercises the predict utility that runs on TFRT (TensorFlow Runtime). The PredictImplTest fixture creates a TFRT runtime with 4 inter-op threads, sets up a ServerCore that loads the saved_model_half_plus_two_tf2_cpu model through TfrtSavedModelSourceAdapterConfig, and issues predictions through the Servable::Predict interface.
Key areas tested include:
- Successful prediction with explicit and default inputs
- Invalid tensor type handling
- Missing function/signature handling
- Missing required inputs
- Run error propagation
- Output tensor number mismatch
- Output filter support (subset, full set, with default inputs)
- Unmatched output filter handling
- Prediction timeout/deadline support
Usage
Run these tests to validate changes to the TFRT predict path, including the TFRT servable predict interface and run options.
Code Reference
Source Location
- Repository: Tensorflow_Serving
- File: tensorflow_serving/servables/tensorflow/tfrt_predict_util_test.cc
- Lines: 1-582
Test Fixture
class PredictImplTest : public ::testing::Test {
 public:
  static void SetUpTestSuite() {
    tfrt_stub::SetGlobalRuntime(
        tfrt_stub::Runtime::Create(/*num_inter_op_threads=*/4));
    // Sets up ServerCore with TfrtSavedModelSourceAdapterConfig
    // and saved_model_half_plus_two_tf2_cpu model
  }

 protected:
  absl::Status CallPredict(ServerCore* server_core,
                           const PredictRequest& request,
                           PredictResponse* response,
                           absl::Duration timeout = absl::ZeroDuration()) {
    ServableHandle<Servable> servable;
    TF_RETURN_IF_ERROR(GetSavedModelServableHandle(server_core, &servable));
    Servable::RunOptions run_options;
    if (timeout != absl::ZeroDuration())
      run_options.deadline = absl::Now() + timeout;
    return servable->Predict(run_options, request, response);
  }
};
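The timeout handling in CallPredict can be illustrated in isolation. The following is a minimal sketch, not the TensorFlow Serving code: it swaps absl::Time for std::chrono so that it compiles with the standard library alone, and RunOptionsSketch and MakeRunOptions are hypothetical names introduced here. Modeling "no deadline" as an empty std::optional is also an assumption of the sketch.

```cpp
#include <chrono>
#include <optional>

// Hypothetical stand-in for Servable::RunOptions; only a deadline field is
// modeled, and "no deadline" is represented by an empty optional.
struct RunOptionsSketch {
  std::optional<std::chrono::steady_clock::time_point> deadline;
};

// Mirrors the fixture's rule: a zero timeout leaves the deadline unset; any
// other timeout becomes an absolute deadline relative to `now`.
RunOptionsSketch MakeRunOptions(std::chrono::milliseconds timeout,
                                std::chrono::steady_clock::time_point now) {
  RunOptionsSketch options;
  if (timeout != std::chrono::milliseconds::zero()) {
    options.deadline = now + timeout;
  }
  return options;
}
```

Passing `now` as a parameter, rather than reading the clock inside the function, keeps the sketch deterministic under test.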
Build Target
bazel test //tensorflow_serving/servables/tensorflow:tfrt_predict_util_test
Test Coverage
Key Test Cases
| Test Name | Category | Description |
|---|---|---|
| PredictionSuccess | Integration | Tests successful prediction with half_plus_two model |
| PredictionSuccessWithDefaultInputs | Integration | Tests prediction using default input values |
| PredictionInvalidTensor | Validation | Tests error on invalid input tensor type |
| PredictionMissingFunction | Validation | Tests error on missing function/signature |
| PredictionMissingInput | Validation | Tests error when required input is missing |
| PredictionRunError | Error Handling | Tests propagation of TFRT run errors |
| PredictionUnmatchedOutputNumber | Validation | Tests error on unexpected output count |
| OutputFilters | Integration | Tests output filter subset selection |
| OutputFiltersFullSet | Integration | Tests output filter with all outputs |
| OutputFiltersWithDefaultInputs | Integration | Tests output filters with default inputs |
| UnmatchedOutputFilters | Validation | Tests error on non-existent output filters |
| PredictionTimeout | Integration | Tests prediction with deadline/timeout |
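The output filter cases in the table (subset, full set, unmatched) can be pictured with a small standalone sketch. This is not the TFRT implementation: Outputs and ApplyOutputFilter are hypothetical names, outputs are reduced to scalar floats, and treating an empty filter as "return everything" is an assumption of the sketch.

```cpp
#include <map>
#include <optional>
#include <string>
#include <vector>

// Hypothetical simplified output map: output name -> scalar value.
using Outputs = std::map<std::string, float>;

// Returns only the outputs named in `filter`. An empty filter returns all
// outputs (assumption of this sketch). A filter entry that names no existing
// output yields std::nullopt, standing in for an error status.
std::optional<Outputs> ApplyOutputFilter(
    const Outputs& all, const std::vector<std::string>& filter) {
  if (filter.empty()) return all;
  Outputs selected;
  for (const std::string& name : filter) {
    auto it = all.find(name);
    if (it == all.end()) return std::nullopt;  // unmatched output filter
    selected.insert(*it);
  }
  return selected;
}
```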
Usage Examples
Test Pattern
TEST_F(PredictImplTest, PredictionSuccess) {
  PredictRequest request;
  PredictResponse response;
  ModelSpec* model_spec = request.mutable_model_spec();
  model_spec->set_name(kTestModelName);
  model_spec->mutable_version()->set_value(kTestModelVersion);
  TensorProto tensor_proto;
  tensor_proto.add_float_val(2.0);
  tensor_proto.set_dtype(tensorflow::DT_FLOAT);
  (*request.mutable_inputs())[kInputTensorKey] = tensor_proto;
  TF_EXPECT_OK(CallPredict(GetServerCore(), request, &response));
  // Expected: 0.5 * 2 + 2 = 3
}
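The expected value in the final comment follows from the arithmetic the half_plus_two model applies, y = 0.5 * x + 2. As a minimal sketch, with HalfPlusTwo being a hypothetical helper that is not part of the test file:

```cpp
// Hypothetical helper: the arithmetic stated in the test's comment,
// y = 0.5 * x + 2.
float HalfPlusTwo(float x) { return 0.5f * x + 2.0f; }
```

For the input 2.0 used in the test, this gives 3.0, matching the comment.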