Skip to content

Model Interface

openlithohub.models.base

Abstract base class for lithography optimization models.

PredictionResult dataclass

Result from a model prediction.

Source code in src/openlithohub/models/base.py
@dataclass
class PredictionResult:
    """Result from a model prediction.

    Instances are what ``LithographyModel.predict()`` returns.
    """

    # The optimized mask produced by the model.
    # NOTE(review): shape/dtype are not fixed here -- presumably they mirror
    # the input design layout; confirm against the model implementations.
    mask: torch.Tensor
    # Optional contour output; defaults to None (e.g. for models that do not
    # produce a contour).
    contour: torch.Tensor | None = None
    # Free-form, model-specific extra information. default_factory gives every
    # instance its own fresh dict, so instances never share one mutable default.
    metadata: dict[str, Any] = field(default_factory=dict)

LithographyModel

Bases: ABC

Abstract interface for lithography optimization models.

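Any model (heuristic OPC, U-Net, diffusion-based ILT, curvyILT) can join the evaluation pipeline by subclassing LithographyModel and implementing the abstract `name` and `supports_curvilinear` properties and the `predict()` method.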

Source code in src/openlithohub/models/base.py
class LithographyModel(ABC):
    """Common contract that every lithography optimization model fulfils.

    Heuristic OPC, U-Net, diffusion-based ILT, curvyILT -- any of these can
    take part in the evaluation pipeline once a subclass provides the
    ``name`` and ``supports_curvilinear`` properties and ``predict()``.
    ``setup()`` and ``teardown()`` are optional lifecycle hooks that do
    nothing unless overridden.
    """

    @property
    @abstractmethod
    def name(self) -> str:
        """Model name as it should appear on the leaderboard."""

    @property
    @abstractmethod
    def supports_curvilinear(self) -> bool:
        """True if the model emits curvilinear (non-Manhattan) output."""

    @abstractmethod
    def predict(self, design: torch.Tensor, **kwargs: Any) -> PredictionResult:
        """Produce an optimized mask for a design layout.

        Args:
            design: Design layout tensor, shaped either (H, W) or
                (B, C, H, W).
            **kwargs: Options specific to the implementing model, such as
                process node or dose.

        Returns:
            A PredictionResult holding the optimized mask and, optionally,
            a contour.
        """

    def setup(self) -> None:
        """Hook run before use; does nothing unless overridden.

        Subclasses may override it to, for example, load weights or
        initialize a GPU.
        """
        return None

    def teardown(self) -> None:
        """Cleanup hook; does nothing unless overridden."""
        return None

name abstractmethod property

Human-readable model name for leaderboard display.

supports_curvilinear abstractmethod property

Whether this model produces curvilinear (non-Manhattan) output.

predict(design, **kwargs) abstractmethod

Run model inference on a design layout tensor.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `design` | `Tensor` | Input design tensor of shape (H, W) or (B, C, H, W). | *required* |
| `**kwargs` | `Any` | Model-specific parameters (process node, dose, etc.) | `{}` |

Returns:

| Type | Description |
|------|-------------|
| `PredictionResult` | PredictionResult with the optimized mask and optional contour. |

Source code in src/openlithohub/models/base.py
@abstractmethod
def predict(self, design: torch.Tensor, **kwargs: Any) -> PredictionResult:
    """Produce an optimized mask for a design layout.

    Args:
        design: Design layout tensor, shaped either (H, W) or (B, C, H, W).
        **kwargs: Options specific to the implementing model, such as
            process node or dose.

    Returns:
        A PredictionResult holding the optimized mask and, optionally,
        a contour.
    """

setup()

Optional setup hook (load weights, initialize GPU, etc.).

Source code in src/openlithohub/models/base.py
def setup(self) -> None:
    """Optional setup hook (load weights, initialize GPU, etc.)."""

teardown()

Optional cleanup hook.

Source code in src/openlithohub/models/base.py
def teardown(self) -> None:
    """Optional cleanup hook."""

openlithohub.models.registry

Model registry — discover and instantiate lithography models.

ModelRegistry

Registry for discovering and instantiating lithography models.

Source code in src/openlithohub/models/registry.py
class ModelRegistry:
    """Registry for discovering and instantiating lithography models.

    Model classes are stored under a string key (see ``register``) and
    instantiated on demand by ``get``.
    """

    def __init__(self) -> None:
        # Maps registry key -> model class (not instance); ``get`` instantiates.
        self._models: dict[str, type[LithographyModel]] = {}

    def register(self, model_cls: type[LithographyModel]) -> type[LithographyModel]:
        """Register a model class. Can be used as a decorator.

        The registry key is the model's ``name``. It is read from a bare
        instance created with ``__new__``, so ``__init__`` (which may need
        arguments or do expensive work) is never run here. If ``name`` is
        unavailable (raises ``AttributeError``, e.g. because it depends on
        state set in ``__init__``) or is not a non-empty string, the class's
        ``__name__`` is used instead.

        Registering another class under an existing key replaces the earlier
        entry.

        Args:
            model_cls: The model class to register.

        Returns:
            ``model_cls`` unchanged, so the method works as a class decorator.

        Raises:
            TypeError: Propagated from ``model_cls.__new__``, e.g. when the
                class still has unimplemented abstract members.
        """
        instance = model_cls.__new__(model_cls)
        # A single getattr evaluates the ``name`` property once (hasattr plus
        # attribute access evaluated it twice).
        name = getattr(instance, "name", None)
        # Only non-empty strings are usable keys: anything else (e.g. None)
        # would make ``list_models`` fail when sorting mixed-type keys.
        if not isinstance(name, str) or not name:
            name = model_cls.__name__
        self._models[name] = model_cls
        return model_cls

    def get(self, name: str, **kwargs: Any) -> LithographyModel:
        """Instantiate a registered model by name.

        Args:
            name: Registry key, as listed by ``list_models()``.
            **kwargs: Forwarded unchanged to the model class constructor.

        Returns:
            A new instance of the model class registered under ``name``.

        Raises:
            KeyError: If nothing is registered under ``name``; the message
                lists the available keys.
        """
        if name in self._models:
            return self._models[name](**kwargs)
        available = ", ".join(sorted(self._models))
        raise KeyError(f"Model '{name}' not found. Available: [{available}]")

    def list_models(self) -> list[str]:
        """Return the registry keys of all registered models, sorted."""
        return sorted(self._models)

register(model_cls)

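Register a model class. Can be used as a decorator. The class is stored under the value of its `name` property; if that is unavailable, the class's `__name__` is used instead.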

Source code in src/openlithohub/models/registry.py
def register(self, model_cls: type[LithographyModel]) -> type[LithographyModel]:
    """Register a model class. Can be used as a decorator.

    The registry key is the model's ``name``. It is read from a bare instance
    created with ``__new__``, so ``__init__`` (which may need arguments or do
    expensive work) is never run here. If ``name`` is unavailable (raises
    ``AttributeError``, e.g. because it depends on state set in ``__init__``)
    or is not a non-empty string, the class's ``__name__`` is used instead.

    Registering another class under an existing key replaces the earlier
    entry.

    Args:
        model_cls: The model class to register.

    Returns:
        ``model_cls`` unchanged, so the method works as a class decorator.

    Raises:
        TypeError: Propagated from ``model_cls.__new__``, e.g. when the class
            still has unimplemented abstract members.
    """
    instance = model_cls.__new__(model_cls)
    # A single getattr evaluates the ``name`` property once (hasattr plus
    # attribute access evaluated it twice).
    name = getattr(instance, "name", None)
    # Only non-empty strings are usable keys: anything else (e.g. None) would
    # make sorting the registry keys fail on mixed types.
    if not isinstance(name, str) or not name:
        name = model_cls.__name__
    self._models[name] = model_cls
    return model_cls

get(name, **kwargs)

Instantiate a registered model by name.

Source code in src/openlithohub/models/registry.py
def get(self, name: str, **kwargs: Any) -> LithographyModel:
    """Create an instance of the model registered under ``name``.

    Args:
        name: Registry key of the model.
        **kwargs: Passed straight through to the model class constructor.

    Returns:
        A freshly constructed model instance.

    Raises:
        KeyError: When ``name`` is not registered; the message lists the
            available keys.
    """
    if name in self._models:
        return self._models[name](**kwargs)
    available = ", ".join(sorted(self._models))
    raise KeyError(f"Model '{name}' not found. Available: [{available}]")

list_models()

Return names of all registered models.

Source code in src/openlithohub/models/registry.py
def list_models(self) -> list[str]:
    """List every registered model key in ascending sorted order."""
    names = list(self._models)
    names.sort()
    return names