diff --git a/flake.nix b/flake.nix index de5752d..31e8a02 100644 --- a/flake.nix +++ b/flake.nix @@ -113,7 +113,7 @@ packages.default = pkgs.buildGoModule { inherit pname version; src = gitignore.lib.gitignoreSource ./server; - vendorHash = "sha256-PE9ns1W+7/ZBBxb7+96aXqBTzpDo5tGcfnCXAV8vp8E="; + vendorHash = "sha256-uaHWj0u71hGoOGRwH6rEZxvYXoeoyN6/FZeQ5/7zRfg="; preBuild = '' cp -r ${client} client diff --git a/server/go.mod b/server/go.mod index 96d41ad..5aad989 100644 --- a/server/go.mod +++ b/server/go.mod @@ -11,6 +11,7 @@ require ( github.com/rs/cors v1.11.1 golang.org/x/crypto v0.36.0 golang.org/x/net v0.37.0 + golang.org/x/time v0.11.0 google.golang.org/protobuf v1.36.5 gorm.io/driver/postgres v1.5.11 gorm.io/gorm v1.25.12 diff --git a/server/go.sum b/server/go.sum index 48f8e81..c6e5819 100644 --- a/server/go.sum +++ b/server/go.sum @@ -58,6 +58,8 @@ golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik= golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY= golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4= +golang.org/x/time v0.11.0 h1:/bpjEDfN9tkoN/ryeYHnv5hcMlc8ncjMcM4XBk5NWV0= +golang.org/x/time v0.11.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg= google.golang.org/protobuf v1.36.5 h1:tPhr+woSbjfYvY6/GPufUoYizxw1cF/yFoxJ2fmpwlM= google.golang.org/protobuf v1.36.5/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/server/internal/handlers/auth.go b/server/internal/handlers/auth.go index 81a6a54..4cb11f2 100644 --- a/server/internal/handlers/auth.go +++ b/server/internal/handlers/auth.go @@ -9,6 +9,7 @@ import ( "connectrpc.com/connect" "github.com/golang-jwt/jwt/v5" + "github.com/spotdemo4/trevstack/server/internal/interceptors" "github.com/spotdemo4/trevstack/server/internal/models" userv1 "github.com/spotdemo4/trevstack/server/internal/services/user/v1" "github.com/spotdemo4/trevstack/server/internal/services/user/v1/userv1connect" @@ -118,8 +119,13 @@ func (s *AuthHandler) Logout(ctx context.Context, req *connect.Request[userv1.Lo } func NewAuthHandler(db *gorm.DB, key string) (string, http.Handler) { - return userv1connect.NewAuthServiceHandler(&AuthHandler{ - db: db, - key: []byte(key), - }) + interceptors := connect.WithInterceptors(interceptors.NewRateLimitInterceptor(key)) + + return userv1connect.NewAuthServiceHandler( + &AuthHandler{ + db: db, + key: []byte(key), + }, + interceptors, + ) } diff --git a/server/internal/interceptors/ratelimit.go b/server/internal/interceptors/ratelimit.go new file mode 100644 index 0000000..74138f6 --- /dev/null +++ b/server/internal/interceptors/ratelimit.go @@ -0,0 +1,106 @@ +package interceptors + +import ( + "context" + "log" + "sync" + "time" + + "connectrpc.com/connect" + "golang.org/x/time/rate" +) + +type visitor struct { + limiter *rate.Limiter + lastSeen time.Time +} + +type ratelimitInterceptor struct { + key string + visitors map[string]*visitor + mu sync.Mutex +} + +func NewRateLimitInterceptor(key string) *ratelimitInterceptor { + rl := &ratelimitInterceptor{ + key: key, + visitors: make(map[string]*visitor), + mu: sync.Mutex{}, + } + + go rl.cleanupVisitors() + + return rl +} + +func (i *ratelimitInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc { + // Same as previous UnaryInterceptorFunc. + return connect.UnaryFunc(func( + ctx context.Context, + req connect.AnyRequest, + ) (connect.AnyResponse, error) { + // Check if the request is from a client + if req.Spec().IsClient { + return next(ctx, req) + } + + // Get ip + log.Println(req.Peer().Addr) + + return next(ctx, req) + }) +} + +func (*ratelimitInterceptor) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc { + return connect.StreamingClientFunc(func( + ctx context.Context, + spec connect.Spec, + ) connect.StreamingClientConn { + return next(ctx, spec) + }) +} + +func (i *ratelimitInterceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc { + return connect.StreamingHandlerFunc(func( + ctx context.Context, + conn connect.StreamingHandlerConn, + ) error { + // Get ip + log.Println(conn.Peer().Query) + + return next(ctx, conn) + }) +} + +func (i *ratelimitInterceptor) getVisitor(ip string) *rate.Limiter { + i.mu.Lock() + defer i.mu.Unlock() + + v, exists := i.visitors[ip] + if !exists { + limiter := rate.NewLimiter(1, 3) + // Include the current time when creating a new visitor. + i.visitors[ip] = &visitor{limiter, time.Now()} + return limiter + } + + // Update the last seen time for the visitor. + v.lastSeen = time.Now() + return v.limiter +} + +// Every minute check the map for visitors that haven't been seen for +// more than 3 minutes and delete the entries. +func (i *ratelimitInterceptor) cleanupVisitors() { + for { + time.Sleep(time.Minute) + + i.mu.Lock() + for ip, v := range i.visitors { + if time.Since(v.lastSeen) > 3*time.Minute { + delete(i.visitors, ip) + } + } + i.mu.Unlock() + } +}