This commit is contained in:
binwiederhier 2025-07-04 10:16:49 +02:00
parent d8c8f31846
commit 54514454bf
6 changed files with 125 additions and 20 deletions

View file

@ -61,6 +61,8 @@ const (
DefaultVisitorAuthFailureLimitReplenish = time.Minute
DefaultVisitorAttachmentTotalSizeLimit = 100 * 1024 * 1024 // 100 MB
DefaultVisitorAttachmentDailyBandwidthLimit = 500 * 1024 * 1024 // 500 MB
DefaultVisitorPrefixBitsIPv4 = 32 // Use the entire IPv4 address for rate limiting
DefaultVisitorPrefixBitsIPv6 = 64 // Use /64 for IPv6 rate limiting
)
var (
@ -143,6 +145,8 @@ type Config struct {
VisitorAuthFailureLimitReplenish time.Duration
VisitorStatsResetTime time.Time // Time of the day at which to reset visitor stats
VisitorSubscriberRateLimiting bool // Enable subscriber-based rate limiting for UnifiedPush topics
VisitorPrefixBitsIPv4 int // Number of bits for IPv4 rate limiting (default: 32)
VisitorPrefixBitsIPv6 int // Number of bits for IPv6 rate limiting (default: 64)
BehindProxy bool // If true, the server will trust the proxy client IP header to determine the client IP address (IPv4 and IPv6 supported)
ProxyForwardedHeader string // The header field to read the real/client IP address from, if BehindProxy is true, defaults to "X-Forwarded-For" (IPv4 and IPv6 supported)
ProxyTrustedAddresses []string // List of trusted proxy addresses (IPv4 or IPv6) that will be stripped from the Forwarded header if BehindProxy is true
@ -234,6 +238,8 @@ func NewConfig() *Config {
VisitorAuthFailureLimitReplenish: DefaultVisitorAuthFailureLimitReplenish,
VisitorStatsResetTime: DefaultVisitorStatsResetTime,
VisitorSubscriberRateLimiting: false,
VisitorPrefixBitsIPv4: 32, // Default: use full IPv4 address
VisitorPrefixBitsIPv6: 64, // Default: use /64 for IPv6
BehindProxy: false, // If true, the server will trust the proxy client IP header to determine the client IP address
ProxyForwardedHeader: "X-Forwarded-For", // Default header for reverse proxy client IPs
StripeSecretKey: "",

View file

@ -2023,7 +2023,7 @@ func (s *Server) authenticateBearerAuth(r *http.Request, token string) (*user.Us
func (s *Server) visitor(ip netip.Addr, user *user.User) *visitor {
s.mu.Lock()
defer s.mu.Unlock()
id := visitorID(ip, user)
id := visitorID(ip, user, s.config)
v, exists := s.visitors[id]
if !exists {
s.visitors[id] = newVisitor(s.config, s.messageCache, s.userManager, ip, user)

View file

@ -1169,7 +1169,7 @@ func (t *testMailer) Count() int {
return t.count
}
func TestServer_PublishTooRequests_Defaults(t *testing.T) {
func TestServer_PublishTooManyRequests_Defaults(t *testing.T) {
s := newTestServer(t, newTestConfig(t))
for i := 0; i < 60; i++ {
response := request(t, s, "PUT", "/mytopic", fmt.Sprintf("message %d", i), nil)
@ -1179,7 +1179,50 @@ func TestServer_PublishTooRequests_Defaults(t *testing.T) {
require.Equal(t, 429, response.Code)
}
func TestServer_PublishTooRequests_Defaults_ExemptHosts(t *testing.T) {
func TestServer_PublishTooManyRequests_Defaults_IPv6(t *testing.T) {
s := newTestServer(t, newTestConfig(t))
overrideRemoteAddr1 := func(r *http.Request) {
r.RemoteAddr = "[2001:db8:9999:8888:1::1]:1234"
}
overrideRemoteAddr2 := func(r *http.Request) {
r.RemoteAddr = "[2001:db8:9999:8888:2::1]:1234" // Same /64
}
for i := 0; i < 30; i++ {
response := request(t, s, "PUT", "/mytopic", fmt.Sprintf("message %d", i), nil, overrideRemoteAddr1)
require.Equal(t, 200, response.Code)
}
for i := 0; i < 30; i++ {
response := request(t, s, "PUT", "/mytopic", fmt.Sprintf("message %d", i), nil, overrideRemoteAddr2)
require.Equal(t, 200, response.Code)
}
response := request(t, s, "PUT", "/mytopic", "message", nil, overrideRemoteAddr1)
require.Equal(t, 429, response.Code)
}
func TestServer_PublishTooManyRequests_IPv6_Slash48(t *testing.T) {
c := newTestConfig(t)
c.VisitorRequestLimitBurst = 6
c.VisitorPrefixBitsIPv6 = 48 // Use /48 for IPv6 prefixes
s := newTestServer(t, c)
overrideRemoteAddr1 := func(r *http.Request) {
r.RemoteAddr = "[2001:db8:9999::1]:1234"
}
overrideRemoteAddr2 := func(r *http.Request) {
r.RemoteAddr = "[2001:db8:9999::2]:1234" // Same /48
}
for i := 0; i < 3; i++ {
response := request(t, s, "PUT", "/mytopic", fmt.Sprintf("message %d", i), nil, overrideRemoteAddr1)
require.Equal(t, 200, response.Code)
}
for i := 0; i < 3; i++ {
response := request(t, s, "PUT", "/mytopic", fmt.Sprintf("message %d", i), nil, overrideRemoteAddr2)
require.Equal(t, 200, response.Code)
}
response := request(t, s, "PUT", "/mytopic", "message", nil, overrideRemoteAddr1)
require.Equal(t, 429, response.Code)
}
func TestServer_PublishTooManyRequests_Defaults_ExemptHosts(t *testing.T) {
c := newTestConfig(t)
c.VisitorRequestLimitBurst = 3
c.VisitorRequestExemptIPAddrs = []netip.Prefix{netip.MustParsePrefix("9.9.9.9/32")} // see request()
@ -1190,7 +1233,21 @@ func TestServer_PublishTooRequests_Defaults_ExemptHosts(t *testing.T) {
}
}
func TestServer_PublishTooRequests_Defaults_ExemptHosts_MessageDailyLimit(t *testing.T) {
func TestServer_PublishTooManyRequests_Defaults_ExemptHosts_IPv6(t *testing.T) {
c := newTestConfig(t)
c.VisitorRequestLimitBurst = 3
c.VisitorRequestExemptIPAddrs = []netip.Prefix{netip.MustParsePrefix("2001:db8:9999::/48")}
s := newTestServer(t, c)
overrideRemoteAddr := func(r *http.Request) {
r.RemoteAddr = "[2001:db8:9999::1]:1234"
}
for i := 0; i < 5; i++ { // > 3
response := request(t, s, "PUT", "/mytopic", fmt.Sprintf("message %d", i), nil, overrideRemoteAddr)
require.Equal(t, 200, response.Code)
}
}
func TestServer_PublishTooManyRequests_Defaults_ExemptHosts_MessageDailyLimit(t *testing.T) {
c := newTestConfig(t)
c.VisitorRequestLimitBurst = 10
c.VisitorMessageDailyLimit = 4
@ -1202,7 +1259,7 @@ func TestServer_PublishTooRequests_Defaults_ExemptHosts_MessageDailyLimit(t *tes
}
}
func TestServer_PublishTooRequests_ShortReplenish(t *testing.T) {
func TestServer_PublishTooManyRequests_ShortReplenish(t *testing.T) {
t.Parallel()
c := newTestConfig(t)
c.VisitorRequestLimitBurst = 60
@ -2244,6 +2301,19 @@ func TestServer_Visitor_Custom_ClientIP_Header(t *testing.T) {
require.Equal(t, "1.2.3.4", v.ip.String())
}
func TestServer_Visitor_Custom_ClientIP_Header_IPv6(t *testing.T) {
c := newTestConfig(t)
c.BehindProxy = true
c.ProxyForwardedHeader = "X-Client-IP"
s := newTestServer(t, c)
r, _ := http.NewRequest("GET", "/bla", nil)
r.RemoteAddr = "[2001:db8:9999::1]:1234"
r.Header.Set("X-Client-IP", "2001:db8:7777::1")
v, err := s.maybeAuthenticate(r)
require.Nil(t, err)
require.Equal(t, "2001:db8:7777::1", v.ip.String())
}
func TestServer_Visitor_Custom_Forwarded_Header(t *testing.T) {
c := newTestConfig(t)
c.BehindProxy = true
@ -2258,6 +2328,20 @@ func TestServer_Visitor_Custom_Forwarded_Header(t *testing.T) {
require.Equal(t, "5.6.7.8", v.ip.String())
}
func TestServer_Visitor_Custom_Forwarded_Header_IPv6(t *testing.T) {
c := newTestConfig(t)
c.BehindProxy = true
c.ProxyForwardedHeader = "Forwarded"
c.ProxyTrustedAddresses = []string{"2001:db8:1111::1"}
s := newTestServer(t, c)
r, _ := http.NewRequest("GET", "/bla", nil)
r.RemoteAddr = "[2001:db8:2222::1]:1234"
r.Header.Set("Forwarded", " for=[2001:db8:1111::1], by=example.com;for=[2001:db8:3333::1]")
v, err := s.maybeAuthenticate(r)
require.Nil(t, err)
require.Equal(t, "2001:db8:3333::1", v.ip.String())
}
func TestServer_PublishWhileUpdatingStatsWithLotsOfMessages(t *testing.T) {
t.Parallel()
count := 50000

View file

@ -22,8 +22,13 @@ var (
priorityHeaderIgnoreRegex = regexp.MustCompile(`^u=\d,\s*(i|\d)$|^u=\d$`)
// forwardedHeaderRegex parses IPv4 and IPv6 addresses from the "Forwarded" header (RFC 7239)
// IPv6 addresses in Forwarded header are enclosed in square brackets, e.g. for="[2001:db8::1]"
forwardedHeaderRegex = regexp.MustCompile(`(?i)\\bfor=\"?((?:[0-9]{1,3}\.){3}[0-9]{1,3}|\[[0-9a-fA-F:]+\])\"?`)
// IPv6 addresses in Forwarded header are enclosed in square brackets. The port is optional.
//
// Examples:
// for="1.2.3.4"
// for="[2001:db8::1]"; for=1.2.3.4:8080, by=phil
// for="1.2.3.4:8080"
forwardedHeaderRegex = regexp.MustCompile(`(?i)\bfor="?(\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}|\[[0-9a-f:]+])(?::\d+)?"?`)
)
func readBoolParam(r *http.Request, defaultValue bool, names ...string) bool {
@ -105,7 +110,7 @@ func extractIPAddress(r *http.Request, behindProxy bool, proxyForwardedHeader st
// then take the right-most address in the list (as this is the one added by our proxy server).
// See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-Forwarded-For for details.
func extractIPAddressFromHeader(r *http.Request, forwardedHeader string, trustedAddresses []string) (netip.Addr, error) {
value := strings.TrimSpace(r.Header.Get(forwardedHeader))
value := strings.TrimSpace(strings.ToLower(r.Header.Get(forwardedHeader)))
if value == "" {
return netip.IPv4Unspecified(), fmt.Errorf("no %s header found", forwardedHeader)
}

View file

@ -2,13 +2,13 @@ package server
import (
"fmt"
"heckel.io/ntfy/v2/log"
"heckel.io/ntfy/v2/user"
"net/netip"
"sync"
"time"
"golang.org/x/time/rate"
"heckel.io/ntfy/v2/log"
"heckel.io/ntfy/v2/user"
"heckel.io/ntfy/v2/util"
)
@ -151,7 +151,7 @@ func (v *visitor) Context() log.Context {
func (v *visitor) contextNoLock() log.Context {
info := v.infoLightNoLock()
fields := log.Context{
"visitor_id": visitorID(v.ip, v.user),
"visitor_id": visitorID(v.ip, v.user, v.config),
"visitor_ip": v.ip.String(),
"visitor_seen": util.FormatTime(v.seen),
"visitor_messages": info.Stats.Messages,
@ -524,15 +524,15 @@ func dailyLimitToRate(limit int64) rate.Limit {
return rate.Limit(limit) * rate.Every(oneDay)
}
func visitorID(ip netip.Addr, u *user.User) string {
// visitorID returns a unique identifier for a visitor based on user or IP, using configurable prefix bits for IPv4/IPv6
func visitorID(ip netip.Addr, u *user.User, conf *Config) string {
if u != nil && u.Tier != nil {
return fmt.Sprintf("user:%s", u.ID)
}
if ip.Is6() {
// IPv6 addresses are too long to be used as visitor IDs, so we use the first 8 bytes
ip = netip.PrefixFrom(ip, 64).Masked().Addr()
} else if ip.Is4() {
ip = netip.PrefixFrom(ip, 20).Masked().Addr()
if ip.Is4() {
ip = netip.PrefixFrom(ip, conf.VisitorPrefixBitsIPv4).Masked().Addr()
} else if ip.Is6() {
ip = netip.PrefixFrom(ip, conf.VisitorPrefixBitsIPv6).Masked().Addr()
}
return fmt.Sprintf("ip:%s", ip.String())
}