Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 116 additions & 1 deletion pkg/ocihook/ocihook.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"os"
"os/exec"
"path/filepath"
"strconv"
"strings"
"time"

Expand Down Expand Up @@ -420,11 +421,94 @@ func getIP6AddressOpts(opts *handlerOpts) ([]cni.NamespaceOpts, error) {
return nil, nil
}

func reserveSocket(protocol, hostAddr string) (*os.File, error) {
type filer interface {
File() (*os.File, error)
}
var f filer
switch {
case strings.HasPrefix(protocol, "tcp"):
l, err := net.Listen(protocol, hostAddr)
if err != nil {
return nil, err
}
defer l.Close()
var ok bool
f, ok = l.(filer)
if !ok {
return nil, fmt.Errorf("cannot get file descriptor from the listener of type %T", l)
}
case strings.HasPrefix(protocol, "udp"):
l, err := net.ListenPacket(protocol, hostAddr)
if err != nil {
return nil, err
}
defer l.Close()
var ok bool
f, ok = l.(filer)
if !ok {
return nil, fmt.Errorf("cannot get file descriptor from the listener of type %T", l)
}
default:
return nil, fmt.Errorf("unsupported protocol %q", protocol)
}
return f.File()
}

// portReserverPidFilePath returns /run/nerdctl/<namespace>/<id>/port-reserver.pid
func portReserverPidFilePath(opts *handlerOpts) string {
return filepath.Join("/run/nerdctl/", opts.state.Annotations[labels.Namespace], opts.state.ID, "port-reserver.pid")
}

func applyNetworkSettings(opts *handlerOpts) (err error) {
portMapOpts, err := getPortMapOpts(opts)
if err != nil {
return err
}
if !rootlessutil.IsRootlessChild() && len(opts.ports) > 0 {
// When running in rootful mode, reserve the ports on the host
// so that the ports appears on /proc/net/tcp.
//
// This also prevents other processes from binding to the same ports.
//
// Note that in rootless mode this is not necessary because
// RootlessKit's port driver already reserves the ports.
//
// See https://github.com/lima-vm/lima/issues/4085
//
// Similar patterns are used in Docker and Podman.
// - https://github.com/moby/moby/pull/48132
// - https://github.com/containers/podman/pull/23446
reserverCmd := exec.Command("sleep", "infinity")
for _, p := range opts.ports {
protocol := p.Protocol
if !strings.HasSuffix(protocol, "4") && !strings.HasSuffix(protocol, "6") {
// e.g. "tcp" -> "tcp4"
protocol += "4"
}
hostAddr := net.JoinHostPort(p.HostIP, strconv.Itoa(int(p.HostPort)))
f, err := reserveSocket(protocol, hostAddr)
if err != nil {
log.L.WithError(err).Warnf("cannot reserve the port %s/%s", hostAddr, protocol)
continue
}
reserverCmd.ExtraFiles = append(reserverCmd.ExtraFiles, f)
}
if err := reserverCmd.Start(); err != nil {
return fmt.Errorf("cannot start the port reserver process: %w", err)
}
reserverCmdPid := reserverCmd.Process.Pid
log.L.Debugf("started the port reserver process (pid=%d)", reserverCmdPid)
defer func() {
if err != nil {
log.L.Debugf("killing the port reserver process (pid=%d)", reserverCmdPid)
_ = reserverCmd.Process.Kill()
}
}()
if err := writePidFile(portReserverPidFilePath(opts), reserverCmdPid); err != nil {
return fmt.Errorf("cannot write the pid file of the port reserver process: %w", err)
}
}
nsPath, err := getNetNSPath(opts.state)
if err != nil {
return err
Expand Down Expand Up @@ -659,6 +743,11 @@ func onPostStop(opts *handlerOpts) error {
if err := namst.Release(name, opts.state.ID); err != nil && !errors.Is(err, store.ErrNotFound) {
return fmt.Errorf("failed to release container name %s: %w", name, err)
}
// Kill port-reserver process if any
portReserverPidFile := portReserverPidFilePath(opts)
if err = killProcessByPidFile(portReserverPidFile); err != nil {
log.L.WithError(err).Errorf("failed to kill the port-reserver process")
}
return nil
}

Expand Down Expand Up @@ -706,7 +795,11 @@ func writePidFile(path string, pid int) error {
if err != nil {
return err
}
tempPath := filepath.Join(filepath.Dir(path), fmt.Sprintf(".%s", filepath.Base(path)))
dir := filepath.Dir(path)
if err := os.MkdirAll(dir, 0755); err != nil {
return err
}
tempPath := filepath.Join(dir, fmt.Sprintf(".%s", filepath.Base(path)))
f, err := os.OpenFile(tempPath, os.O_RDWR|os.O_CREATE|os.O_EXCL|os.O_SYNC, 0666)
if err != nil {
return err
Expand All @@ -718,3 +811,25 @@ func writePidFile(path string, pid int) error {
}
return os.Rename(tempPath, path)
}

func killProcessByPidFile(pidFile string) error {
pidData, err := os.ReadFile(pidFile)
if err != nil {
if errors.Is(err, os.ErrNotExist) {
err = nil
}
return err
}
pid, err := strconv.Atoi(strings.TrimSpace(string(pidData)))
if err != nil {
return fmt.Errorf("failed to parse pid %q from %q: %w", string(pidData), pidFile, err)
}
proc, err := os.FindProcess(pid)
if err != nil {
return fmt.Errorf("failed to find process %d: %w", pid, err)
}
if err := proc.Kill(); err != nil {
return fmt.Errorf("failed to kill process %d: %w", pid, err)
}
return nil
}
Loading