Initial commit
This commit is contained in:
62
internal/kcpolicy/config.go
Normal file
62
internal/kcpolicy/config.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package kcpolicy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Keycloak struct {
|
||||
BaseURL string `yaml:"base_url"`
|
||||
Realm string `yaml:"realm"`
|
||||
ClientID string `yaml:"client_id"`
|
||||
ClientSecret string `yaml:"client_secret"`
|
||||
} `yaml:"keycloak"`
|
||||
|
||||
SQLite struct {
|
||||
Path string `yaml:"path"`
|
||||
} `yaml:"sqlite"`
|
||||
|
||||
Policy struct {
|
||||
Domain string `yaml:"domain"`
|
||||
CacheTTLSeconds int `yaml:"cache_ttl_seconds"`
|
||||
KeycloakFailureMode string `yaml:"keycloak_failure_mode"` // "tempfail" or "dunno"
|
||||
} `yaml:"policy"`
|
||||
|
||||
Sockets struct {
|
||||
PolicySocket string `yaml:"policy_socket"`
|
||||
SocketmapSocket string `yaml:"socketmap_socket"`
|
||||
SocketOwnerUser string `yaml:"socket_owner_user"`
|
||||
SocketOwnerGroup string `yaml:"socket_owner_group"`
|
||||
SocketMode string `yaml:"socket_mode"`
|
||||
} `yaml:"sockets"`
|
||||
}
|
||||
|
||||
func LoadConfig(path string) (*Config, error) {
|
||||
b, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var cfg Config
|
||||
if err := yaml.Unmarshal(b, &cfg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if cfg.Keycloak.BaseURL == "" || cfg.Keycloak.Realm == "" {
|
||||
return nil, fmt.Errorf("missing keycloak.base_url or keycloak.realm")
|
||||
}
|
||||
if cfg.SQLite.Path == "" {
|
||||
return nil, fmt.Errorf("missing sqlite.path")
|
||||
}
|
||||
if cfg.Policy.Domain == "" {
|
||||
return nil, fmt.Errorf("missing policy.domain")
|
||||
}
|
||||
if cfg.Policy.CacheTTLSeconds <= 0 {
|
||||
cfg.Policy.CacheTTLSeconds = 120
|
||||
}
|
||||
if cfg.Policy.KeycloakFailureMode == "" {
|
||||
cfg.Policy.KeycloakFailureMode = "tempfail"
|
||||
}
|
||||
return &cfg, nil
|
||||
}
|
||||
153
internal/kcpolicy/keycloak.go
Normal file
153
internal/kcpolicy/keycloak.go
Normal file
@@ -0,0 +1,153 @@
|
||||
package kcpolicy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Keycloak struct {
|
||||
cfg *Config
|
||||
hc *http.Client
|
||||
}
|
||||
|
||||
func NewKeycloak(cfg *Config) *Keycloak {
|
||||
return &Keycloak{
|
||||
cfg: cfg,
|
||||
hc: &http.Client{Timeout: 5 * time.Second},
|
||||
}
|
||||
}
|
||||
|
||||
type tokenResp struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
}
|
||||
|
||||
func (k *Keycloak) token(ctx context.Context) (string, error) {
|
||||
form := url.Values{}
|
||||
form.Set("grant_type", "client_credentials")
|
||||
form.Set("client_id", k.cfg.Keycloak.ClientID)
|
||||
form.Set("client_secret", k.cfg.Keycloak.ClientSecret)
|
||||
|
||||
u := strings.TrimRight(k.cfg.Keycloak.BaseURL, "/") +
|
||||
"/realms/" + url.PathEscape(k.cfg.Keycloak.Realm) +
|
||||
"/protocol/openid-connect/token"
|
||||
|
||||
req, _ := http.NewRequestWithContext(ctx, "POST", u, strings.NewReader(form.Encode()))
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
resp, err := k.hc.Do(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode/100 != 2 {
|
||||
b, _ := io.ReadAll(io.LimitReader(resp.Body, 2048))
|
||||
return "", fmt.Errorf("token http %d: %s", resp.StatusCode, string(b))
|
||||
}
|
||||
var tr tokenResp
|
||||
if err := json.NewDecoder(resp.Body).Decode(&tr); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if tr.AccessToken == "" {
|
||||
return "", fmt.Errorf("empty access_token")
|
||||
}
|
||||
return tr.AccessToken, nil
|
||||
}
|
||||
|
||||
type kcUser struct {
|
||||
ID string `json:"id"`
|
||||
Username string `json:"username"`
|
||||
Email string `json:"email"`
|
||||
Enabled bool `json:"enabled"`
|
||||
Attrs map[string][]string `json:"attributes"`
|
||||
}
|
||||
|
||||
func (k *Keycloak) adminGet(ctx context.Context, bearer, path string, q url.Values) ([]kcUser, error) {
|
||||
base := strings.TrimRight(k.cfg.Keycloak.BaseURL, "/") +
|
||||
"/admin/realms/" + url.PathEscape(k.cfg.Keycloak.Realm) + path
|
||||
|
||||
u := base
|
||||
if q != nil {
|
||||
u += "?" + q.Encode()
|
||||
}
|
||||
|
||||
req, _ := http.NewRequestWithContext(ctx, "GET", u, nil)
|
||||
req.Header.Set("Authorization", "Bearer "+bearer)
|
||||
|
||||
resp, err := k.hc.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode/100 != 2 {
|
||||
b, _ := io.ReadAll(io.LimitReader(resp.Body, 2048))
|
||||
return nil, fmt.Errorf("admin http %d: %s", resp.StatusCode, string(b))
|
||||
}
|
||||
var users []kcUser
|
||||
if err := json.NewDecoder(resp.Body).Decode(&users); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return users, nil
|
||||
}
|
||||
|
||||
// Find primary email by username (exact if supported)
|
||||
func (k *Keycloak) EmailByUsername(ctx context.Context, username string) (string, bool, error) {
|
||||
bearer, err := k.token(ctx)
|
||||
if err != nil {
|
||||
return "", false, err
|
||||
}
|
||||
|
||||
q := url.Values{}
|
||||
q.Set("username", username)
|
||||
q.Set("exact", "true")
|
||||
users, err := k.adminGet(ctx, bearer, "/users", q)
|
||||
if err != nil {
|
||||
// fallback: search
|
||||
q2 := url.Values{}
|
||||
q2.Set("search", username)
|
||||
users, err = k.adminGet(ctx, bearer, "/users", q2)
|
||||
if err != nil {
|
||||
return "", false, err
|
||||
}
|
||||
}
|
||||
|
||||
for _, u := range users {
|
||||
if strings.EqualFold(u.Username, username) && u.Enabled && u.Email != "" {
|
||||
return strings.ToLower(u.Email), true, nil
|
||||
}
|
||||
}
|
||||
return "", false, nil
|
||||
}
|
||||
|
||||
// Check if an email exists as primary user email
|
||||
func (k *Keycloak) EmailExists(ctx context.Context, email string) (bool, error) {
|
||||
bearer, err := k.token(ctx)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
q := url.Values{}
|
||||
q.Set("email", email)
|
||||
q.Set("exact", "true")
|
||||
users, err := k.adminGet(ctx, bearer, "/users", q)
|
||||
if err != nil {
|
||||
// fallback: search
|
||||
q2 := url.Values{}
|
||||
q2.Set("search", email)
|
||||
users, err = k.adminGet(ctx, bearer, "/users", q2)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
}
|
||||
for _, u := range users {
|
||||
if u.Enabled && strings.EqualFold(u.Email, email) {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
198
internal/kcpolicy/policy.go
Normal file
198
internal/kcpolicy/policy.go
Normal file
@@ -0,0 +1,198 @@
|
||||
package kcpolicy
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Cache struct {
|
||||
ttl time.Duration
|
||||
m map[string]cacheItem
|
||||
}
|
||||
|
||||
type cacheItem struct {
|
||||
val string
|
||||
expires time.Time
|
||||
ok bool
|
||||
}
|
||||
|
||||
func NewCache(ttl time.Duration) *Cache {
|
||||
return &Cache{ttl: ttl, m: make(map[string]cacheItem)}
|
||||
}
|
||||
|
||||
func (c *Cache) Get(key string) (string, bool, bool) {
|
||||
it, ok := c.m[key]
|
||||
if !ok || time.Now().After(it.expires) {
|
||||
return "", false, false
|
||||
}
|
||||
return it.val, it.ok, true
|
||||
}
|
||||
|
||||
func (c *Cache) Put(key, val string, ok bool) {
|
||||
c.m[key] = cacheItem{val: val, ok: ok, expires: time.Now().Add(c.ttl)}
|
||||
}
|
||||
|
||||
func RunPolicy(ctx context.Context, cfg *Config, db *AliasDB, kc *Keycloak, cache *Cache) error {
|
||||
sock := cfg.Sockets.PolicySocket
|
||||
_ = os.Remove(sock)
|
||||
|
||||
l, err := net.Listen("unix", sock)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer l.Close()
|
||||
|
||||
if err := ChownChmodSocket(sock, cfg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for {
|
||||
conn, err := l.Accept()
|
||||
if err != nil {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
default:
|
||||
return err
|
||||
}
|
||||
}
|
||||
go handlePolicyConn(conn, cfg, db, kc, cache)
|
||||
}
|
||||
}
|
||||
|
||||
func handlePolicyConn(conn net.Conn, cfg *Config, db *AliasDB, kc *Keycloak, cache *Cache) {
|
||||
defer conn.Close()
|
||||
r := bufio.NewReader(conn)
|
||||
|
||||
req := map[string]string{}
|
||||
for {
|
||||
line, err := r.ReadString('\n')
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
line = strings.TrimRight(line, "\r\n")
|
||||
if line == "" {
|
||||
break
|
||||
}
|
||||
if i := strings.IndexByte(line, '='); i > 0 {
|
||||
req[line[:i]] = line[i+1:]
|
||||
}
|
||||
}
|
||||
|
||||
// Decide based on protocol_state
|
||||
state := req["protocol_state"] // e.g. RCPT, MAIL
|
||||
saslUser := req["sasl_username"]
|
||||
sender := strings.ToLower(req["sender"])
|
||||
rcpt := strings.ToLower(req["recipient"])
|
||||
|
||||
action := "DUNNO"
|
||||
|
||||
switch state {
|
||||
case "RCPT":
|
||||
action = policyRCPT(cfg, db, kc, cache, rcpt)
|
||||
case "MAIL":
|
||||
// On MAIL stage we can validate sender if authenticated (submission)
|
||||
if saslUser != "" && sender != "" {
|
||||
action = policyMAIL(cfg, db, kc, cache, saslUser, sender)
|
||||
}
|
||||
default:
|
||||
action = "DUNNO"
|
||||
}
|
||||
|
||||
fmt.Fprintf(conn, "action=%s\n\n", action)
|
||||
}
|
||||
|
||||
func policyRCPT(cfg *Config, db *AliasDB, kc *Keycloak, cache *Cache, rcpt string) string {
|
||||
if rcpt == "" {
|
||||
return "DUNNO"
|
||||
}
|
||||
// Only enforce for our domain
|
||||
if !strings.HasSuffix(rcpt, "@"+strings.ToLower(cfg.Policy.Domain)) {
|
||||
return "DUNNO"
|
||||
}
|
||||
|
||||
// 1) exists in keycloak primary email?
|
||||
key := "email_exists:" + rcpt
|
||||
if _, ok, hit := cache.Get(key); hit {
|
||||
if ok {
|
||||
return "DUNNO"
|
||||
}
|
||||
} else {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
exists, err := kc.EmailExists(ctx, rcpt)
|
||||
if err != nil {
|
||||
if cfg.Policy.KeycloakFailureMode == "dunno" {
|
||||
return "DUNNO"
|
||||
}
|
||||
return "451 4.3.0 Temporary authentication/lookup failure"
|
||||
}
|
||||
cache.Put(key, "", exists)
|
||||
if exists {
|
||||
return "DUNNO"
|
||||
}
|
||||
}
|
||||
|
||||
// 2) exists as sqlite alias?
|
||||
_, ok, err := db.AliasOwner(rcpt)
|
||||
if err != nil {
|
||||
log.Printf("sqlite rcpt lookup error: %v", err)
|
||||
return "451 4.3.0 Temporary internal error"
|
||||
}
|
||||
if ok {
|
||||
return "DUNNO"
|
||||
}
|
||||
|
||||
return "550 5.1.1 No such user"
|
||||
}
|
||||
|
||||
func policyMAIL(cfg *Config, db *AliasDB, kc *Keycloak, cache *Cache, saslUser, sender string) string {
|
||||
// Allow empty sender (bounce)
|
||||
if sender == "" || sender == "<>" {
|
||||
return "DUNNO"
|
||||
}
|
||||
// Only enforce our domain senders (optional)
|
||||
if !strings.HasSuffix(sender, "@"+strings.ToLower(cfg.Policy.Domain)) {
|
||||
return "DUNNO"
|
||||
}
|
||||
|
||||
// primary email from keycloak (cached)
|
||||
key := "email_by_username:" + strings.ToLower(saslUser)
|
||||
email, ok, hit := cache.Get(key)
|
||||
if !hit {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
e, exists, err := kc.EmailByUsername(ctx, saslUser)
|
||||
if err != nil {
|
||||
if cfg.Policy.KeycloakFailureMode == "dunno" {
|
||||
return "DUNNO"
|
||||
}
|
||||
return "451 4.3.0 Temporary authentication/lookup failure"
|
||||
}
|
||||
cache.Put(key, e, exists)
|
||||
email, ok = e, exists
|
||||
}
|
||||
|
||||
// 1) sender == primary email
|
||||
if ok && strings.EqualFold(sender, email) {
|
||||
return "DUNNO"
|
||||
}
|
||||
|
||||
// 2) sender is sqlite alias belonging to this user
|
||||
belongs, err := db.AliasBelongsTo(sender, saslUser)
|
||||
if err != nil {
|
||||
log.Printf("sqlite sender lookup error: %v", err)
|
||||
return "451 4.3.0 Temporary internal error"
|
||||
}
|
||||
if belongs {
|
||||
return "DUNNO"
|
||||
}
|
||||
|
||||
return "553 5.7.1 Sender not owned by authenticated user"
|
||||
}
|
||||
31
internal/kcpolicy/socket_perms.go
Normal file
31
internal/kcpolicy/socket_perms.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package kcpolicy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/user"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
func ChownChmodSocket(path string, cfg *Config) error {
|
||||
u, err := user.Lookup(cfg.Sockets.SocketOwnerUser)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
g, err := user.LookupGroup(cfg.Sockets.SocketOwnerGroup)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
uid, _ := strconv.Atoi(u.Uid)
|
||||
gid, _ := strconv.Atoi(g.Gid)
|
||||
|
||||
if err := os.Chown(path, uid, gid); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
mode, err := strconv.ParseUint(cfg.Sockets.SocketMode, 8, 32)
|
||||
if err != nil {
|
||||
return fmt.Errorf("bad socket_mode: %w", err)
|
||||
}
|
||||
return os.Chmod(path, os.FileMode(mode))
|
||||
}
|
||||
85
internal/kcpolicy/socketmap.go
Normal file
85
internal/kcpolicy/socketmap.go
Normal file
@@ -0,0 +1,85 @@
|
||||
package kcpolicy
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func RunSocketmap(ctx context.Context, cfg *Config, db *AliasDB) error {
|
||||
sock := cfg.Sockets.SocketmapSocket
|
||||
_ = os.Remove(sock)
|
||||
|
||||
l, err := net.Listen("unix", sock)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer l.Close()
|
||||
|
||||
if err := ChownChmodSocket(sock, cfg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for {
|
||||
conn, err := l.Accept()
|
||||
if err != nil {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
default:
|
||||
return err
|
||||
}
|
||||
}
|
||||
go handleSocketmapConn(conn, cfg, db)
|
||||
}
|
||||
}
|
||||
|
||||
// Socketmap protocol: "mapname key\n" -> "OK value\n" or "NOTFOUND\n" or "TEMP\n"
|
||||
func handleSocketmapConn(conn net.Conn, cfg *Config, db *AliasDB) {
|
||||
defer conn.Close()
|
||||
r := bufio.NewReader(conn)
|
||||
|
||||
for {
|
||||
line, err := r.ReadString('\n')
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
parts := strings.SplitN(line, " ", 2)
|
||||
if len(parts) != 2 {
|
||||
fmt.Fprint(conn, "TEMP\n")
|
||||
continue
|
||||
}
|
||||
mapName := parts[0]
|
||||
key := strings.ToLower(strings.TrimSpace(parts[1]))
|
||||
|
||||
if mapName != "alias" {
|
||||
fmt.Fprint(conn, "NOTFOUND\n")
|
||||
continue
|
||||
}
|
||||
|
||||
// Only handle our domain
|
||||
if !strings.HasSuffix(key, "@"+strings.ToLower(cfg.Policy.Domain)) {
|
||||
fmt.Fprint(conn, "NOTFOUND\n")
|
||||
continue
|
||||
}
|
||||
|
||||
username, ok, err := db.AliasOwner(key)
|
||||
if err != nil {
|
||||
fmt.Fprint(conn, "TEMP\n")
|
||||
continue
|
||||
}
|
||||
if !ok {
|
||||
fmt.Fprint(conn, "NOTFOUND\n")
|
||||
continue
|
||||
}
|
||||
// rewrite alias -> primary rcpt (username@domain)
|
||||
fmt.Fprintf(conn, "OK %s@%s\n", username, strings.ToLower(cfg.Policy.Domain))
|
||||
}
|
||||
}
|
||||
62
internal/kcpolicy/sqlite.go
Normal file
62
internal/kcpolicy/sqlite.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package kcpolicy
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
_ "modernc.org/sqlite"
|
||||
)
|
||||
|
||||
type AliasDB struct{ DB *sql.DB }
|
||||
|
||||
func OpenAliasDB(path string) (*AliasDB, error) {
|
||||
db, err := sql.Open("sqlite", path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if _, err := db.Exec(`
|
||||
CREATE TABLE IF NOT EXISTS aliases (
|
||||
alias_email TEXT PRIMARY KEY,
|
||||
username TEXT NOT NULL,
|
||||
enabled INTEGER NOT NULL DEFAULT 1,
|
||||
updated_at INTEGER NOT NULL DEFAULT (strftime('%s','now'))
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_aliases_username ON aliases(username);
|
||||
`); err != nil {
|
||||
_ = db.Close()
|
||||
return nil, fmt.Errorf("init schema: %w", err)
|
||||
}
|
||||
return &AliasDB{DB: db}, nil
|
||||
}
|
||||
|
||||
func (a *AliasDB) Close() error { return a.DB.Close() }
|
||||
|
||||
// Returns username owning alias, ok
|
||||
func (a *AliasDB) AliasOwner(aliasEmail string) (string, bool, error) {
|
||||
var username string
|
||||
var enabled int
|
||||
err := a.DB.QueryRow(`SELECT username, enabled FROM aliases WHERE alias_email=?`, aliasEmail).Scan(&username, &enabled)
|
||||
if err == sql.ErrNoRows {
|
||||
return "", false, nil
|
||||
}
|
||||
if err != nil {
|
||||
return "", false, err
|
||||
}
|
||||
if enabled != 1 {
|
||||
return "", false, nil
|
||||
}
|
||||
return username, true, nil
|
||||
}
|
||||
|
||||
// Returns true if alias belongs to username
|
||||
func (a *AliasDB) AliasBelongsTo(aliasEmail, username string) (bool, error) {
|
||||
var enabled int
|
||||
err := a.DB.QueryRow(`SELECT enabled FROM aliases WHERE alias_email=? AND username=?`, aliasEmail, username).Scan(&enabled)
|
||||
if err == sql.ErrNoRows {
|
||||
return false, nil
|
||||
}
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return enabled == 1, nil
|
||||
}
|
||||
Reference in New Issue
Block a user