From 6845060d000dba860b5a22e2d427a61d39a43772 Mon Sep 17 00:00:00 2001 From: peio Date: Mon, 19 Jan 2026 17:18:37 +0000 Subject: [PATCH] fix: postfix socketmap protocol --- internal/kcpolicy/socketmap.go | 83 ++++++++++++++++++++++++++++------ 1 file changed, 70 insertions(+), 13 deletions(-) diff --git a/internal/kcpolicy/socketmap.go b/internal/kcpolicy/socketmap.go index 3b1a4fb..21af2b0 100644 --- a/internal/kcpolicy/socketmap.go +++ b/internal/kcpolicy/socketmap.go @@ -4,8 +4,10 @@ import ( "bufio" "context" "fmt" + "io" "net" "os" + "strconv" "strings" ) @@ -37,49 +39,104 @@ func RunSocketmap(ctx context.Context, cfg *Config, db *AliasDB) error { } } -// Socketmap protocol: "mapname key\n" -> "OK value\n" or "NOTFOUND\n" or "TEMP\n" +// Postfix socketmap framing: ":," func handleSocketmapConn(conn net.Conn, cfg *Config, db *AliasDB) { defer conn.Close() r := bufio.NewReader(conn) for { - line, err := r.ReadString('\n') + payload, err := readSocketmapFrame(r) if err != nil { + // normal close return } - line = strings.TrimSpace(line) - if line == "" { + + payload = strings.TrimSpace(payload) + if payload == "" { + _ = writeSocketmapFrame(conn, "NOTFOUND") continue } - parts := strings.SplitN(line, " ", 2) + + parts := strings.SplitN(payload, " ", 2) if len(parts) != 2 { - fmt.Fprint(conn, "TEMP\n") + _ = writeSocketmapFrame(conn, "TEMP") continue } + mapName := parts[0] key := strings.ToLower(strings.TrimSpace(parts[1])) if mapName != "alias" { - fmt.Fprint(conn, "NOTFOUND\n") + _ = writeSocketmapFrame(conn, "NOTFOUND") continue } // Only handle our domain - if !strings.HasSuffix(key, "@"+strings.ToLower(cfg.Policy.Domain)) { - fmt.Fprint(conn, "NOTFOUND\n") + domain := strings.ToLower(cfg.Policy.Domain) + if !strings.HasSuffix(key, "@"+domain) { + _ = writeSocketmapFrame(conn, "NOTFOUND") continue } username, ok, err := db.AliasOwner(key) if err != nil { - fmt.Fprint(conn, "TEMP\n") + _ = writeSocketmapFrame(conn, "TEMP") continue } if !ok { - fmt.Fprint(conn, "NOTFOUND\n") + _ = writeSocketmapFrame(conn, "NOTFOUND") continue } - // rewrite alias -> primary rcpt (username@domain) - fmt.Fprintf(conn, "OK %s@%s\n", username, strings.ToLower(cfg.Policy.Domain)) + + // rewrite alias -> username@domain + _ = writeSocketmapFrame(conn, fmt.Sprintf("OK %s@%s", username, domain)) } } + +func readSocketmapFrame(r *bufio.Reader) (string, error) { + // read decimal length until ':' + var lenBuf strings.Builder + for { + b, err := r.ReadByte() + if err != nil { + return "", err + } + if b == ':' { + break + } + if b < '0' || b > '9' { + return "", io.ErrUnexpectedEOF + } + lenBuf.WriteByte(b) + if lenBuf.Len() > 10 { + return "", io.ErrUnexpectedEOF + } + } + + n, err := strconv.Atoi(lenBuf.String()) + if err != nil || n < 0 || n > 1024*1024 { + return "", io.ErrUnexpectedEOF + } + + buf := make([]byte, n) + if _, err := io.ReadFull(r, buf); err != nil { + return "", err + } + + // expect trailing comma + b, err := r.ReadByte() + if err != nil { + return "", err + } + if b != ',' { + return "", io.ErrUnexpectedEOF + } + + return string(buf), nil +} + +func writeSocketmapFrame(w io.Writer, payload string) error { + // ":," + _, err := fmt.Fprintf(w, "%d:%s,", len(payload), payload) + return err +}