
Implementation:Protectai Modelscan FormatViaExtensionMiddleware

From Leeroopedia
Domains ML_Security, Software_Architecture
Last Updated 2026-02-14 12:00 GMT

Overview

A concrete middleware, provided by the modelscan middlewares module, that tags each model file with its format type based on its file extension.

Description

The FormatViaExtensionMiddleware class implements MiddlewareBase and detects a model file's format by matching its file extension against a configured mapping. It records the matching formats in the "formats" context of the Model object, so downstream scanners can check format compatibility cheaply, before reading the file. The MiddlewarePipeline class manages the chain of middleware instances: it loads them dynamically from settings and runs them in order, with each middleware passing the model on by calling call_next.
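The call_next chaining can be illustrated with a minimal, self-contained sketch. The Model, Pipeline, first, and second names below are hypothetical, simplified stand-ins rather than modelscan's classes. The sketch assumes only what the description above states: middleware run in the order they were added, and each one hands the model on by calling call_next.

```python
from typing import Any, Callable, Dict, List


class Model:
    """Hypothetical stand-in for modelscan's Model: a source path plus a context dict."""

    def __init__(self, source: str) -> None:
        self.source = source
        self.context: Dict[str, Any] = {"formats": []}


# In this sketch a middleware is any callable taking (model, call_next).
Middleware = Callable[[Model, Callable[[Model], None]], None]


class Pipeline:
    """Hypothetical stand-in for MiddlewarePipeline: runs middleware in insertion order."""

    def __init__(self) -> None:
        self._middlewares: List[Middleware] = []

    def add_middleware(self, middleware: Middleware) -> "Pipeline":
        self._middlewares.append(middleware)
        return self  # returning self allows chained add_middleware calls

    def run(self, model: Model) -> None:
        def runner(model: Model, index: int) -> None:
            # Past the last middleware there is nothing left to call.
            if index < len(self._middlewares):
                # call_next for middleware[index] runs middleware[index + 1].
                self._middlewares[index](model, lambda m: runner(m, index + 1))

        runner(model, 0)


calls: List[str] = []


def first(model: Model, call_next: Callable[[Model], None]) -> None:
    calls.append("first")
    call_next(model)  # hand control to the next middleware


def second(model: Model, call_next: Callable[[Model], None]) -> None:
    calls.append("second")
    call_next(model)


Pipeline().add_middleware(first).add_middleware(second).run(Model("model.pkl"))
print(calls)  # ['first', 'second']
```

In this sketch, a middleware that returns without calling call_next ends the chain, so the middleware after it never sees the model.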

Usage

This middleware runs automatically as part of the ModelScan pipeline. Interact with it when:

  • Adding support for new file extensions in the format mapping
  • Understanding how scanners determine which files to process
  • Implementing a custom middleware for content-based format detection
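As a sketch of how a scanner can decide which files to process, a scanner can gate itself on the formats context and return early for files it does not handle. Property and SupportedModelFormats here are simplified stand-ins for the classes in modelscan.settings, and pickle_scanner is a hypothetical function, not modelscan's scanner interface.

```python
from typing import Any, Dict, List, Optional


class Property:
    """Stand-in for a format identifier with a name and a value."""

    def __init__(self, name: str, value: str) -> None:
        self.name = name
        self.value = value


class SupportedModelFormats:
    # Hypothetical subset of format identifiers for this sketch.
    PICKLE = Property("PICKLE", "pickle")
    PYTORCH = Property("PYTORCH", "pytorch")


def pickle_scanner(context: Dict[str, Any]) -> Optional[str]:
    """Hypothetical scanner that only runs when the file was tagged as PICKLE."""
    formats: List[Property] = context.get("formats") or []
    # Compare by value so equal-but-distinct Property objects still match.
    if SupportedModelFormats.PICKLE.value not in [f.value for f in formats]:
        return None  # not our format: skip without reading the file
    return "scanned"


print(pickle_scanner({"formats": [SupportedModelFormats.PICKLE]}))   # scanned
print(pickle_scanner({"formats": [SupportedModelFormats.PYTORCH]}))  # None
print(pickle_scanner({"formats": []}))                               # None
```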

Code Reference

Source Location

  • Repository: modelscan
  • File: modelscan/middlewares/format_via_extension.py (FormatViaExtensionMiddleware)
  • Lines: L6-17
  • File: modelscan/middlewares/middleware.py (MiddlewareBase, MiddlewarePipeline)
  • Lines: L11-59

Signature

import abc
from typing import Any, Callable, Dict

from modelscan.model import Model

class MiddlewareBase(metaclass=abc.ABCMeta):
    def __init__(self, settings: Dict[str, Any]):
        """
        Args:
            settings: Middleware-specific config from the settings dict.
        """

    @abc.abstractmethod
    def __call__(
        self,
        model: Model,
        call_next: Callable[[Model], None],
    ) -> None:
        """Process model and call call_next to continue the chain."""

class MiddlewarePipeline:
    @staticmethod
    def from_settings(middleware_settings: Dict[str, Any]) -> "MiddlewarePipeline":
        """Load middleware classes from settings and build pipeline."""

    def add_middleware(self, middleware: MiddlewareBase) -> "MiddlewarePipeline":
        """Add a middleware instance to the pipeline."""

    def run(self, model: Model) -> None:
        """Execute the middleware chain on the model."""

class FormatViaExtensionMiddleware(MiddlewareBase):
    def __call__(
        self,
        model: Model,
        call_next: Callable[[Model], None],
    ) -> None:
        """
        Match file extension to format, set model context, call next.
        Uses self._settings["formats"] mapping: Property -> List[str].
        """

Import

from modelscan.middlewares.middleware import MiddlewareBase, MiddlewarePipeline
from modelscan.middlewares.format_via_extension import FormatViaExtensionMiddleware

I/O Contract

Inputs

  • model (Model, required): Model object with a source path. The extension of get_source() is checked against the format mapping.
  • call_next (Callable[[Model], None], required): callback that continues the middleware chain.
  • settings["formats"] (Dict[Property, List[str]], required via config): mapping of SupportedModelFormats Property objects to lists of file extensions (e.g., PICKLE -> [".pkl", ".pickle", ...]).

Outputs

  • model.get_context("formats") (List[Property]): the SupportedModelFormats Property objects whose extension lists contain the file's extension. Set on the model's context for scanners to read.

Usage Examples

Default Pipeline Execution

from modelscan.modelscan import ModelScan

# FormatViaExtensionMiddleware runs automatically during scan
scanner = ModelScan()
scanner.scan("/path/to/model.pkl")
# Internally: middleware tags .pkl files as PICKLE format
# Pickle scanner checks model.get_context("formats") and proceeds

Adding a Custom Format Extension

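import copy

from modelscan.modelscan import ModelScan
from modelscan.settings import DEFAULT_SETTINGS, SupportedModelFormats

# Deep copy so the shared DEFAULT_SETTINGS dict is left untouched
settings = copy.deepcopy(DEFAULT_SETTINGS)

# Also treat .sav files (a common extension for pickled scikit-learn models) as PICKLE.
# Match keys by .value: deepcopy also copies the Property keys, so looking them up
# with SupportedModelFormats.PICKLE itself may fail.
middleware_key = "modelscan.middlewares.FormatViaExtensionMiddleware"
for fmt, extensions in settings["middlewares"][middleware_key]["formats"].items():
    if fmt.value == SupportedModelFormats.PICKLE.value:
        extensions.append(".sav")

# Build the scanner from the modified settings
scanner = ModelScan(settings=settings)
scanner.scan("/path/to/model.sav")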
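The middleware key in the example above is a dotted import path, and MiddlewarePipeline.from_settings loads middleware classes dynamically from such keys, presumably passing the dict stored under each key to the class as its settings. One plausible way to resolve such a path with only the standard library is to split off the last component as the class name, import the module, and fetch the class with getattr. load_class is a hypothetical helper, exercised on a standard-library class so the snippet runs without modelscan installed.

```python
import importlib
from typing import Any


def load_class(dotted_path: str) -> Any:
    """Hypothetical helper: resolve "package.module.ClassName" to the class object."""
    # "modelscan.middlewares.FormatViaExtensionMiddleware"
    #   -> module "modelscan.middlewares", class "FormatViaExtensionMiddleware"
    module_name, class_name = dotted_path.rsplit(".", 1)
    module = importlib.import_module(module_name)
    return getattr(module, class_name)


cls = load_class("collections.OrderedDict")
print(cls.__name__)  # OrderedDict
```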

Implementing Custom Middleware

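from typing import Callable

from modelscan.middlewares.middleware import MiddlewareBase
from modelscan.model import Model
from modelscan.settings import SupportedModelFormats


class ContentTypeMiddleware(MiddlewareBase):
    """Detect format by reading file magic bytes instead of the extension."""

    def __call__(
        self,
        model: Model,
        call_next: Callable[[Model], None],
    ) -> None:
        # Peek at the first bytes, then rewind so later readers start at offset 0
        stream = model.get_stream(offset=0)
        magic = stream.read(4)
        stream.seek(0)

        # Pickle protocol 2+ starts with the PROTO opcode (0x80) and a protocol byte
        if len(magic) >= 2 and magic[0] == 0x80 and 2 <= magic[1] <= 5:
            formats = model.get_context("formats") or []
            if SupportedModelFormats.PICKLE.value not in [f.value for f in formats]:
                model.set_context("formats", formats + [SupportedModelFormats.PICKLE])
        elif magic == b"PK\x03\x04":
            # ZIP container, shared by several formats; left for other middleware
            pass

        # Always continue the chain
        call_next(model)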
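The magic-byte check can also be exercised on its own, outside the pipeline. The hypothetical sniff_format helper below (not part of modelscan) peeks at the first four bytes of a binary stream, restores the original position, and recognizes a ZIP local-file header (PK\x03\x04) as well as the PROTO opcode (0x80 followed by the protocol number) that pickle protocols 2 and later emit at the start of a pickle.

```python
import io
import pickle
import zipfile
from typing import IO, Optional


def sniff_format(stream: IO[bytes]) -> Optional[str]:
    """Guess a container type from magic bytes; None if unrecognized."""
    start = stream.tell()
    header = stream.read(4)
    stream.seek(start)  # leave the stream where later readers expect it
    if header == b"PK\x03\x04":
        return "zip"
    # Pickle protocol 2+ begins with the PROTO opcode (0x80) and a protocol byte.
    if len(header) >= 2 and header[0] == 0x80 and 2 <= header[1] <= 5:
        return "pickle"
    return None


# Build in-memory samples so the example needs no files on disk.
pickled = io.BytesIO(pickle.dumps({"w": [1, 2, 3]}, protocol=4))

zipped = io.BytesIO()
with zipfile.ZipFile(zipped, "w") as zf:
    zf.writestr("data.pkl", b"payload")
zipped.seek(0)

print(sniff_format(pickled), sniff_format(zipped), sniff_format(io.BytesIO(b"hello")))
# pickle zip None
```

Restoring the stream position matters in a middleware chain: later middleware and scanners read the same stream and would otherwise start four bytes in.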

