Implementation:Protectai Modelscan FormatViaExtensionMiddleware
| Knowledge Sources | |
|---|---|
| Domains | ML_Security, Software_Architecture |
| Last Updated | 2026-02-14 12:00 GMT |
Overview
A middleware in the modelscan middleware module that tags model files with their format type based on file extension.
Description
The FormatViaExtensionMiddleware class implements MiddlewareBase to detect model file formats by matching file extensions against a configured mapping. It sets the formats context on the Model object so downstream scanners can quickly determine format compatibility. The MiddlewarePipeline class manages the chain of middleware instances, loading them dynamically from settings and executing them in order.
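The `call_next` chaining and dynamic loading described above can be illustrated with a self-contained sketch. It assumes the pipeline walks an ordered list of middlewares, handing each one a callback that runs the next, and that settings keys are dotted import paths. `Model`, `Middleware`, `Pipeline`, `Trace`, `Stop`, and `load_class` are hypothetical stand-ins, not modelscan's own classes.

```python
import importlib
from typing import Any, Callable, Dict, List


class Model:
    """Hypothetical stand-in for a model: a source path plus a context dict."""

    def __init__(self, source: str) -> None:
        self.source = source
        self.context: Dict[str, Any] = {}


class Middleware:
    """Hypothetical base class; each instance receives its own settings dict."""

    def __init__(self, settings: Dict[str, Any]) -> None:
        self.settings = settings

    def __call__(self, model: Model, call_next: Callable[[Model], None]) -> None:
        raise NotImplementedError


class Pipeline:
    def __init__(self) -> None:
        self._middlewares: List[Middleware] = []

    def add(self, middleware: Middleware) -> "Pipeline":
        self._middlewares.append(middleware)
        return self  # returning self allows chained add() calls

    def run(self, model: Model) -> None:
        # Middleware i is handed a callback that, when invoked, runs middleware i + 1.
        def runner(current: Model, index: int) -> None:
            if index < len(self._middlewares):
                self._middlewares[index](current, lambda m: runner(m, index + 1))

        runner(model, 0)


def load_class(dotted_path: str) -> Any:
    """Resolve "package.module.ClassName" by importing the module and reading the attribute."""
    module_name, class_name = dotted_path.rsplit(".", 1)
    return getattr(importlib.import_module(module_name), class_name)


class Trace(Middleware):
    """Records its configured name, then continues the chain."""

    def __call__(self, model: Model, call_next: Callable[[Model], None]) -> None:
        model.context.setdefault("trace", []).append(self.settings["name"])
        call_next(model)


class Stop(Middleware):
    """Returns without calling call_next, so later middlewares never run."""

    def __call__(self, model: Model, call_next: Callable[[Model], None]) -> None:
        model.context.setdefault("trace", []).append("stop")


model = Model("/tmp/model.pkl")
Pipeline().add(Trace({"name": "first"})).add(Trace({"name": "second"})).run(model)
print(model.context["trace"])  # ['first', 'second']

halted = Model("/tmp/other.pkl")
Pipeline().add(Stop({})).add(Trace({"name": "never"})).run(halted)
print(halted.context["trace"])  # ['stop']

print(load_class("collections.OrderedDict").__name__)  # OrderedDict
```

In this sketch, ordering comes entirely from list position, and a middleware that returns without calling `call_next` ends the chain for that model.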
Usage
This middleware runs automatically as part of the ModelScan pipeline. Interact with it when:
- Adding support for new file extensions in the format mapping
- Understanding how scanners determine which files to process
- Implementing a custom middleware for content-based format detection
Code Reference
Source Location
- Repository: modelscan
- File: modelscan/middlewares/format_via_extension.py (FormatViaExtensionMiddleware)
- Lines: L6-17
- File: modelscan/middlewares/middleware.py (MiddlewareBase, MiddlewarePipeline)
- Lines: L11-59
Signature
```python
class MiddlewareBase(metaclass=abc.ABCMeta):
    def __init__(self, settings: Dict[str, Any]):
        """
        Args:
            settings: Middleware-specific config from the settings dict.
        """

    @abc.abstractmethod
    def __call__(
        self,
        model: Model,
        call_next: Callable[[Model], None],
    ) -> None:
        """Process model and call call_next to continue the chain."""


class MiddlewarePipeline:
    @staticmethod
    def from_settings(middleware_settings: Dict[str, Any]) -> "MiddlewarePipeline":
        """Load middleware classes from settings and build pipeline."""

    def add_middleware(self, middleware: MiddlewareBase) -> "MiddlewarePipeline":
        """Add a middleware instance to the pipeline."""

    def run(self, model: Model) -> None:
        """Execute the middleware chain on the model."""


class FormatViaExtensionMiddleware(MiddlewareBase):
    def __call__(
        self,
        model: Model,
        call_next: Callable[[Model], None],
    ) -> None:
        """
        Match file extension to format, set model context, call next.
        Uses self._settings["formats"] mapping: Property -> List[str].
        """
```
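The extension match described in the docstring can be sketched as a standalone function. This assumes the comparison uses the final suffix of the source path (as `pathlib.Path.suffix` returns it) and an exact, case-sensitive membership test; `formats_for` and `FORMATS` are hypothetical, and plain strings stand in for `Property` keys. Check the source before relying on either assumption.

```python
from pathlib import Path
from typing import Dict, List


def formats_for(path: str, formats: Dict[str, List[str]]) -> List[str]:
    """Return every format whose extension list contains the file's final suffix."""
    extension = Path(path).suffix  # "/m/clf.pkl" -> ".pkl"; "a.pkl.gz" -> ".gz"
    return [fmt for fmt, extensions in formats.items() if extension in extensions]


# Hypothetical subset of a formats mapping; strings stand in for Property objects.
FORMATS: Dict[str, List[str]] = {
    "pickle": [".pkl", ".pickle", ".joblib"],
    "pytorch": [".bin", ".pt", ".pth", ".ckpt"],
    "keras_h5": [".h5"],
}

print(formats_for("/models/clf.pkl", FORMATS))         # ['pickle']
print(formats_for("/models/weights.pt", FORMATS))      # ['pytorch']
print(formats_for("/models/archive.pkl.gz", FORMATS))  # []
print(formats_for("/models/CLF.PKL", FORMATS))         # []
```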
Import
```python
from modelscan.middlewares.middleware import MiddlewareBase, MiddlewarePipeline
from modelscan.middlewares.format_via_extension import FormatViaExtensionMiddleware
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| model | Model | Yes | Model object with a source path. The extension of get_source() is checked against the format mapping. |
| call_next | Callable[[Model], None] | Yes | Callback to continue the middleware chain |
| settings["formats"] | Dict[Property, List[str]] | Yes (via config) | Mapping of SupportedModelFormats Property objects to lists of file extensions (e.g., PICKLE -> [".pkl", ".pickle", ...]) |
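Because the output is a list of every matching format, an extension listed under two formats would, assuming matching works as described, tag a file with both. When editing the mapping, a small check like the following can flag such overlaps. `extension_index` and `overlapping_extensions` are hypothetical helpers, and strings stand in for `Property` keys.

```python
from collections import defaultdict
from typing import Dict, List


def extension_index(mapping: Dict[str, List[str]]) -> Dict[str, List[str]]:
    """Invert a format -> extensions mapping into extension -> formats."""
    index: Dict[str, List[str]] = defaultdict(list)
    for fmt, extensions in mapping.items():
        for ext in extensions:
            index[ext].append(fmt)
    return dict(index)


def overlapping_extensions(mapping: Dict[str, List[str]]) -> Dict[str, List[str]]:
    """Extensions claimed by more than one format; such files would receive several tags."""
    return {ext: fmts for ext, fmts in extension_index(mapping).items() if len(fmts) > 1}


# Hypothetical mapping with a deliberate clash on ".bin".
mapping = {
    "pickle": [".pkl", ".pickle", ".bin"],
    "pytorch": [".bin", ".pt", ".pth"],
}
print(overlapping_extensions(mapping))  # {'.bin': ['pickle', 'pytorch']}
```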
Outputs
| Name | Type | Description |
|---|---|---|
| model.get_context("formats") | List[Property] | List of SupportedModelFormats Property objects matching the file extension. Set on the model's context for scanners to read. |
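A scanner can use this output as a cheap gate before reading any bytes. The sketch below assumes the context value is a list of objects exposing a `value` attribute, and compares by `value` rather than by identity so that copied settings (for example after `copy.deepcopy`) still match. `Property`, `PICKLE`, `PYTORCH`, and `should_scan` are hypothetical here.

```python
import copy
from typing import Any, Dict, List, Optional


class Property:
    """Hypothetical stand-in for a SupportedModelFormats entry; no custom __eq__ or __hash__."""

    def __init__(self, name: str, value: str) -> None:
        self.name = name
        self.value = value


PICKLE = Property("PICKLE", "pickle")
PYTORCH = Property("PYTORCH", "pytorch")


def should_scan(context: Dict[str, Any], wanted: Property) -> bool:
    """True if the model's "formats" context contains a format with wanted.value."""
    formats: Optional[List[Property]] = context.get("formats")
    # Compare by value, not identity, so copied Property objects still match.
    return any(fmt.value == wanted.value for fmt in formats or [])


copied = copy.deepcopy(PICKLE)
print(copied is PICKLE)                             # False
print(should_scan({"formats": [copied]}, PICKLE))   # True
print(should_scan({"formats": [PYTORCH]}, PICKLE))  # False
print(should_scan({}, PICKLE))                      # False
```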
Usage Examples
Default Pipeline Execution
```python
from modelscan.modelscan import ModelScan

# FormatViaExtensionMiddleware runs automatically during scan
scanner = ModelScan()
scanner.scan("/path/to/model.pkl")
# Internally: middleware tags .pkl files as PICKLE format
# Pickle scanner checks model.get_context("formats") and proceeds
```
Adding a Custom Format Extension
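```python
import copy
from modelscan.modelscan import ModelScan
from modelscan.settings import DEFAULT_SETTINGS, SupportedModelFormats

settings = copy.deepcopy(DEFAULT_SETTINGS)  # leave DEFAULT_SETTINGS untouched
middleware_key = "modelscan.middlewares.FormatViaExtensionMiddleware"
formats = settings["middlewares"][middleware_key]["formats"]
# deepcopy may copy the Property keys too, so match the PyTorch entry by value
pytorch_key = next(k for k in formats if k.value == SupportedModelFormats.PYTORCH.value)
# ".mdl" is a hypothetical extension used for in-house PyTorch checkpoints
formats[pytorch_key].append(".mdl")
scanner = ModelScan(settings=settings)
```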
Implementing Custom Middleware
```python
from typing import Callable

from modelscan.middlewares.middleware import MiddlewareBase
from modelscan.model import Model


class ContentTypeMiddleware(MiddlewareBase):
    """Detect format by reading file magic bytes instead of extension."""

    def __call__(
        self,
        model: Model,
        call_next: Callable[[Model], None],
    ) -> None:
        # Read the first four bytes for content-based detection, then rewind
        stream = model.get_stream(offset=0)
        magic = stream.read(4)
        stream.seek(0)
        if magic == b"PK\x03\x04":
            # ZIP-based container detected; "is_zip" is a hypothetical context key
            model.set_context("is_zip", True)
        # Always hand the model on so later middlewares still run
        call_next(model)
```
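The four-byte ZIP check above can be extended to other containers. The helper below is a self-contained sketch, not part of modelscan; `SIGNATURES` and `sniff` are hypothetical names, and it assumes the stream is positioned at the start of the file. A ZIP match alone does not identify the framework, because default PyTorch checkpoints (since 1.6) and Keras v3 `.keras` files both use ZIP containers.

```python
import io
import pickle
import zipfile
from typing import IO, List, Optional, Tuple

# Well-known leading-byte signatures.
SIGNATURES: List[Tuple[bytes, str]] = [
    (b"PK\x03\x04", "zip"),          # ZIP local file header
    (b"\x89HDF\r\n\x1a\n", "hdf5"),  # HDF5 signature when stored at offset 0
    (b"\x93NUMPY", "npy"),           # NumPy .npy header
]


def sniff(stream: IO[bytes]) -> Optional[str]:
    """Guess a container type from the bytes at the current position, then rewind to it."""
    start = stream.tell()
    head = stream.read(8)
    stream.seek(start)
    for magic, label in SIGNATURES:
        if head.startswith(magic):
            return label
    # Pickle protocol 2+ starts with the PROTO opcode (0x80) and the protocol number.
    if len(head) >= 2 and head[0] == 0x80 and 2 <= head[1] <= 5:
        return "pickle"
    return None


buf = io.BytesIO()
with zipfile.ZipFile(buf, "w") as zf:
    zf.writestr("data.pkl", b"")
print(sniff(io.BytesIO(buf.getvalue())))           # zip
print(sniff(io.BytesIO(pickle.dumps([1, 2, 3]))))  # pickle
print(sniff(io.BytesIO(b"\x93NUMPY\x01\x00")))     # npy
print(sniff(io.BytesIO(b"plain text")))            # None
```

Wiring a middleware like this into a scan presumably means adding its dotted import path as a key under `settings["middlewares"]`, alongside `modelscan.middlewares.FormatViaExtensionMiddleware`.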