Title here
Summary here
// CircuitBreaker is a middleware that wraps the handler in a circuit breaker.
// Based on the configuration, the circuit breaker will fail fast if the handler keeps returning errors.
// This is useful for preventing cascading failures.
//
// Construct it with NewCircuitBreaker, which populates the underlying breaker.
type CircuitBreaker struct {
// cb is the gobreaker circuit breaker through which Middleware executes every handler call.
cb *gobreaker.CircuitBreaker
}
// NewCircuitBreaker builds a CircuitBreaker middleware whose breaker is
// configured with the given settings.
// The settings are passed to gobreaker.NewCircuitBreaker unchanged; refer to
// the gobreaker documentation for the available options.
func NewCircuitBreaker(settings gobreaker.Settings) CircuitBreaker {
	breaker := gobreaker.NewCircuitBreaker(settings)
	return CircuitBreaker{cb: breaker}
}
// Middleware returns the CircuitBreaker middleware.
//
// Every call to h is executed through the circuit breaker. The messages
// produced by h (if any) and the error reported by the breaker are returned
// to the caller.
func (c CircuitBreaker) Middleware(h message.HandlerFunc) message.HandlerFunc {
	return func(msg *message.Message) ([]*message.Message, error) {
		guarded := func() (interface{}, error) {
			return h(msg)
		}

		raw, err := c.cb.Execute(guarded)
		if raw == nil {
			// The breaker yielded no value; there is nothing to convert.
			return nil, err
		}
		return raw.([]*message.Message), err
	}
}
// SetCorrelationID sets a correlation ID for the message.
//
// SetCorrelationID should be called when the message enters the system.
// When a message is produced in a request (for example HTTP),
// the message correlation ID should be the same as the request's correlation ID.
//
// A message that already carries a non-empty correlation ID is left untouched.
func SetCorrelationID(id string, msg *message.Message) {
	if existing := MessageCorrelationID(msg); existing == "" {
		msg.Metadata.Set(CorrelationIDMetadataKey, id)
	}
}
// MessageCorrelationID returns the correlation ID stored in the message
// metadata under CorrelationIDMetadataKey.
func MessageCorrelationID(message *message.Message) string {
	metadata := message.Metadata
	return metadata.Get(CorrelationIDMetadataKey)
}
// CorrelationID adds a correlation ID to all messages produced by the handler.
// The ID is taken from the message received by the handler, read after the
// handler has returned.
//
// For CorrelationID to work correctly, SetCorrelationID must be called on the
// first message entering the system.
func CorrelationID(h message.HandlerFunc) message.HandlerFunc {
	return func(message *message.Message) ([]*message.Message, error) {
		produced, err := h(message)

		id := MessageCorrelationID(message)
		for i := range produced {
			SetCorrelationID(id, produced[i])
		}

		return produced, err
	}
}
// Deduplicator drops similar messages if they are present
// in a [ExpiringKeyRepository]. The similarity is determined
// by a [MessageHasher]. Time out is applied to repository
// operations using [context.WithTimeout].
//
// Call [Deduplicator.Middleware] for a new middleware
// or [Deduplicator.PublisherDecorator] for a [message.PublisherDecorator].
//
// KeyFactory defaults to [NewMessageHasherAdler32] with read
// limit set to [math.MaxInt64] for fast tagging.
// Use [NewMessageHasherSHA256] for minimal collisions.
//
// Repository defaults to [NewMapExpiringKeyRepository] with one
// minute retention window. This default setting is performant
// but **does not support distributed operations**. If you
// implement a [ExpiringKeyRepository] backed by Redis,
// please submit a pull request.
//
// Timeout defaults to one minute. If lower than
// five milliseconds, it is set to five milliseconds.
//
// [ExpiringKeyRepository] must expire values
// in a certain time window. If there is no expiration, only one
// unique message will be ever delivered as long as the repository
// keeps its state.
type Deduplicator struct {
// KeyFactory computes the key that is looked up in Repository for each message.
KeyFactory MessageHasher
// Repository decides whether a key has already been seen.
Repository ExpiringKeyRepository
// Timeout bounds every Repository lookup; it is applied on top of the message context.
Timeout time.Duration
}
// IsDuplicate reports whether the message's hash tag, calculated with the
// configured KeyFactory, is known to the Repository.
//
// An error from the KeyFactory is returned as-is together with false.
// The Repository lookup runs under a context derived from the message
// context and limited by Timeout.
func (d *Deduplicator) IsDuplicate(m *message.Message) (bool, error) {
	tag, hashErr := d.KeyFactory(m)
	if hashErr != nil {
		return false, hashErr
	}

	lookupCtx, release := context.WithTimeout(m.Context(), d.Timeout)
	defer release()

	return d.Repository.IsDuplicate(lookupCtx, tag)
}
// Middleware returns the [message.HandlerMiddleware]
// that drops similar messages in a given time window.
//
// Defaults are applied to the Deduplicator once, when the middleware is
// created. For each message: a lookup error is returned to the caller, a
// duplicate is dropped by returning (nil, nil) without calling h, and any
// other message is passed to h.
func (d *Deduplicator) Middleware(h message.HandlerFunc) message.HandlerFunc {
	dedup := applyDefaultsToDeduplicator(d)

	return func(msg *message.Message) ([]*message.Message, error) {
		switch seen, err := dedup.IsDuplicate(msg); {
		case err != nil:
			return nil, err
		case seen:
			return nil, nil
		default:
			return h(msg)
		}
	}
}
// NewMapExpiringKeyRepository returns a memory store
// backed by a regular hash map protected by
// a [sync.Mutex]. The state **cannot be shared or synchronized
// between instances** by design for performance.
//
// If you need to drop duplicate messages by orchestration,
// implement [ExpiringKeyRepository] interface backed by Redis
// or similar.
//
// Window specifies the minimum duration of how long the
// duplicate tags are remembered for. Real duration can
// extend up to 50% longer because it depends on the
// clean up cycle, which ticks every half window.
// A window shorter than one millisecond is rejected with an error.
//
// The clean up goroutine is started with [context.Background], so this
// constructor offers no way to stop it.
func NewMapExpiringKeyRepository(window time.Duration) (ExpiringKeyRepository, error) {
	if window < time.Millisecond {
		return nil, errors.New("deduplication window of less than a millisecond is impractical")
	}

	repo := &mapExpiringKeyRepository{
		window: window,
		mu:     new(sync.Mutex),
		tags:   make(map[string]time.Time),
	}

	go repo.cleanOutLoop(context.Background(), time.NewTicker(window/2))

	return repo, nil
}
// Len reports how many tags are currently stored, i.e. tags that have
// not been cleaned out yet. The count is read while holding the mutex.
func (kr *mapExpiringKeyRepository) Len() (count int) {
	kr.mu.Lock()
	defer kr.mu.Unlock()

	return len(kr.tags)
}
// NewMessageHasherAdler32 generates message hashes using a fast
// Adler-32 checksum of the [message.Message] body. Read
// limit specifies how many bytes of the message are
// used for calculating the hash.
//
// Lower limit improves performance but results in more false
// positives. A read limit below [MessageHasherReadLimitMinimum]
// is raised to that minimum.
//
// The returned hash is the raw checksum bytes converted to a string.
func NewMessageHasherAdler32(readLimit int64) MessageHasher {
	limit := readLimit
	if limit < MessageHasherReadLimitMinimum {
		limit = MessageHasherReadLimitMinimum
	}

	return func(m *message.Message) (string, error) {
		digest := adler32.New()
		// io.EOF only means the payload is shorter than the limit.
		if _, err := io.CopyN(digest, bytes.NewReader(m.Payload), limit); err != nil && err != io.EOF {
			return "", err
		}
		return string(digest.Sum(nil)), nil
	}
}
// NewMessageHasherSHA256 generates message hashes using a slower
// but more resilient hashing of the [message.Message] body. Read
// limit specifies how many bytes of the message are
// used for calculating the hash.
//
// Lower limit improves performance but results in more false
// positives. A read limit below [MessageHasherReadLimitMinimum]
// is raised to that minimum.
//
// The returned hash is the raw digest bytes converted to a string.
func NewMessageHasherSHA256(readLimit int64) MessageHasher {
	if readLimit < MessageHasherReadLimitMinimum {
		readLimit = MessageHasherReadLimitMinimum
	}

	hashPayload := func(m *message.Message) (string, error) {
		hasher := sha256.New()
		payload := bytes.NewReader(m.Payload)

		_, copyErr := io.CopyN(hasher, payload, readLimit)
		// A payload shorter than the limit ends with io.EOF, which is expected.
		if copyErr == nil || copyErr == io.EOF {
			return string(hasher.Sum(nil)), nil
		}
		return "", copyErr
	}
	return hashPayload
}
// NewMessageHasherFromMetadataField looks for a hash value
// inside message metadata instead of calculating a new one.
// Useful if a [MessageHasher] was applied in a previous
// [message.HandlerFunc].
//
// The returned hasher fails with an error when the field is absent
// from the message metadata.
func NewMessageHasherFromMetadataField(field string) MessageHasher {
	return func(m *message.Message) (string, error) {
		value, found := m.Metadata[field]
		if !found {
			return "", fmt.Errorf("cannot recover hash value from metadata of message #%s: field %q is absent", m.UUID, field)
		}
		return value, nil
	}
}
// PublisherDecorator returns a decorator that
// acknowledges and drops every [message.Message] that
// was recognized by a [Deduplicator].
//
// The returned decorator provides the same functionality
// to a [message.Publisher] as [Deduplicator.Middleware]
// to a [message.Router].
//
// Decorating a nil publisher yields an error. Defaults are applied to the
// Deduplicator each time the decorator wraps a publisher.
func (d *Deduplicator) PublisherDecorator() message.PublisherDecorator {
	return func(pub message.Publisher) (message.Publisher, error) {
		if pub == nil {
			return nil, errors.New("cannot decorate a <nil> publisher")
		}

		decorated := &deduplicatingPublisherDecorator{
			Publisher:    pub,
			deduplicator: applyDefaultsToDeduplicator(d),
		}
		return decorated, nil
	}
}
// DelayOnError is a middleware that adds the delay metadata to the message if an error occurs.
//
// IMPORTANT: The delay metadata doesn't cause delays with all Pub/Subs! Using it won't have any effect on Pub/Subs that don't support it.
// See the list of supported Pub/Subs in the documentation: https://watermill.io/advanced/delayed-messages/
//
// NOTE(review): the code that turns these fields into delay metadata is not part of this section.
type DelayOnError struct {
// InitialInterval is the first interval between retries. Subsequent intervals will be scaled by Multiplier.
InitialInterval time.Duration
// MaxInterval sets the limit for the exponential backoff of retries. The interval will not be increased beyond MaxInterval.
MaxInterval time.Duration
// Multiplier is the factor by which the waiting interval will be multiplied between retries.
Multiplier float64
}
// Duplicator processes every message twice, to ensure that the endpoint is idempotent.
//
// The handler is called two times with the same message. If either call
// fails, its error is returned and no messages are produced. Otherwise the
// messages produced by both calls are returned, first call first.
func Duplicator(h message.HandlerFunc) message.HandlerFunc {
	return func(msg *message.Message) ([]*message.Message, error) {
		firstProducedMessages, firstErr := h(msg)
		if firstErr != nil {
			return nil, firstErr
		}

		secondProducedMessages, secondErr := h(msg)
		if secondErr != nil {
			return nil, secondErr
		}

		// Cap the first slice at its length so that append has to copy instead
		// of writing into spare capacity of a backing array the handler may
		// still own or share.
		first := firstProducedMessages[:len(firstProducedMessages):len(firstProducedMessages)]
		return append(first, secondProducedMessages...), nil
	}
}
// IgnoreErrors provides a middleware that makes the handler ignore some explicitly whitelisted errors.
// Errors are matched by the text returned from their Error method.
type IgnoreErrors struct {
// ignoredErrors is the set of error messages that the middleware suppresses.
ignoredErrors map[string]struct{}
}
// NewIgnoreErrors creates a new IgnoreErrors middleware that ignores the
// given errors.
//
// Errors are registered by their Error() text, so duplicates collapse into a
// single entry. Nil entries in errs are skipped instead of causing a panic.
func NewIgnoreErrors(errs []error) IgnoreErrors {
	errsMap := make(map[string]struct{}, len(errs))
	for _, err := range errs {
		if err == nil {
			// Calling Error() on a nil error would panic; there is nothing to ignore.
			continue
		}
		errsMap[err.Error()] = struct{}{}
	}

	return IgnoreErrors{ignoredErrors: errsMap}
}
// Middleware returns the IgnoreErrors middleware.
//
// When the handler fails, the message of the error's root cause
// (errors.Cause) is looked up in the ignored set. A match suppresses the
// error; any other error is passed through. The handler's produced messages
// are returned in every case.
func (i IgnoreErrors) Middleware(h message.HandlerFunc) message.HandlerFunc {
	return func(msg *message.Message) ([]*message.Message, error) {
		events, err := h(msg)
		if err == nil {
			return events, nil
		}

		rootCause := errors.Cause(err).Error()
		if _, ignored := i.ignoredErrors[rootCause]; ignored {
			return events, nil
		}
		return events, err
	}
}
// InstantAck makes the handler instantly acknowledge the incoming message, regardless of any errors.
// The message is acknowledged before the handler runs.
// It may be used to gain throughput, but at a cost:
// If you had exactly-once delivery, you may expect at-least-once instead.
// If you had ordered messages, the ordering might be broken.
func InstantAck(h message.HandlerFunc) message.HandlerFunc {
	return func(msg *message.Message) ([]*message.Message, error) {
		msg.Ack()

		produced, err := h(msg)
		return produced, err
	}
}
// PoisonQueue provides a middleware that salvages unprocessable messages and publishes them on a separate topic.
// The main middleware chain then continues on, business as usual.
//
// Every error qualifies for the poison queue. An empty topic is rejected
// with ErrInvalidPoisonQueueTopic.
func PoisonQueue(pub message.Publisher, topic string) (message.HandlerMiddleware, error) {
	acceptAll := func(err error) bool {
		return true
	}
	return PoisonQueueWithFilter(pub, topic, acceptAll)
}
// PoisonQueueWithFilter is just like PoisonQueue, but accepts a function that decides which errors qualify for the poison queue.
//
// An empty topic is rejected with ErrInvalidPoisonQueueTopic.
func PoisonQueueWithFilter(pub message.Publisher, topic string, shouldGoToPoisonQueue func(err error) bool) (message.HandlerMiddleware, error) {
	if topic == "" {
		return nil, ErrInvalidPoisonQueueTopic
	}

	var pq poisonQueue
	pq.topic = topic
	pq.pub = pub
	pq.shouldGoToPoisonQueue = shouldGoToPoisonQueue

	return pq.Middleware, nil
}
// RandomFail makes the handler fail with an error based on random chance. Error probability should be in the range (0,1).
// When the failure is triggered, the wrapped handler is not called.
func RandomFail(errorProbability float32) message.HandlerMiddleware {
	return func(h message.HandlerFunc) message.HandlerFunc {
		return func(msg *message.Message) ([]*message.Message, error) {
			if !shouldFail(errorProbability) {
				return h(msg)
			}
			return nil, errors.New("random fail occurred")
		}
	}
}
// RandomPanic makes the handler panic based on random chance. Panic probability should be in the range (0,1).
// When the panic is triggered, the wrapped handler is not called.
func RandomPanic(panicProbability float32) message.HandlerMiddleware {
	return func(h message.HandlerFunc) message.HandlerFunc {
		wrapped := func(msg *message.Message) ([]*message.Message, error) {
			if unlucky := shouldFail(panicProbability); unlucky {
				panic("random panic occurred")
			}
			return h(msg)
		}
		return wrapped
	}
}
// RecoveredPanicError holds the recovered panic's error along with the stacktrace.
type RecoveredPanicError struct {
// V is the value obtained from recover().
V interface{}
// Stacktrace is the stack trace captured with debug.Stack() while recovering.
Stacktrace string
}
// Recoverer recovers from any panic in the handler and reports it as an
// error: a RecoveredPanicError carrying the recovered value and the
// stacktrace, wrapped with errors.WithStack.
//
// If the handler returns normally, its messages and error are passed through
// unchanged.
func Recoverer(h message.HandlerFunc) message.HandlerFunc {
	return func(event *message.Message) (events []*message.Message, err error) {
		// completed stays false when the handler does not return normally.
		// It catches the case where recover() yields nil although the handler
		// panicked (for example panic(nil) in Go versions before 1.21).
		completed := false

		defer func() {
			recovered := recover()
			if recovered == nil && completed {
				return
			}
			err = errors.WithStack(RecoveredPanicError{V: recovered, Stacktrace: string(debug.Stack())})
		}()

		events, err = h(event)
		completed = true
		return events, err
	}
}
// Retry provides a middleware that retries the handler if errors are returned.
// The retry behaviour is configurable, with exponential backoff and maximum elapsed time.
type Retry struct {
// MaxRetries is maximum number of times a retry will be attempted.
MaxRetries int
// InitialInterval is the first interval between retries. Subsequent intervals will be scaled by Multiplier.
InitialInterval time.Duration
// MaxInterval sets the limit for the exponential backoff of retries. The interval will not be increased beyond MaxInterval.
MaxInterval time.Duration
// Multiplier is the factor by which the waiting interval will be multiplied between retries.
Multiplier float64
// MaxElapsedTime sets the time limit of how long retries will be attempted. Disabled if 0.
MaxElapsedTime time.Duration
// RandomizationFactor randomizes the spread of the backoff times within the interval of:
// [currentInterval * (1 - randomization_factor), currentInterval * (1 + randomization_factor)].
RandomizationFactor float64
// OnRetryHook is an optional function that is executed after each retry attempt that failed.
// The number of the current retry is passed as retryNum and the wait that
// preceded the attempt as delay.
OnRetryHook func(retryNum int, delay time.Duration)
// Logger is optional. When set, an error is logged after each retry attempt that failed.
Logger watermill.LoggerAdapter
}
// Middleware returns the Retry middleware.
//
// The handler is called once. If it fails, it is called again up to
// MaxRetries times, waiting between attempts according to an exponential
// backoff built from InitialInterval, MaxInterval, Multiplier,
// RandomizationFactor and MaxElapsedTime. With MaxRetries <= 0 no retry is
// attempted.
//
// Retrying stops early, returning the last produced messages and error, when
// the message context is done, when MaxElapsedTime (if > 0) has passed, or
// when the backoff signals backoff.Stop. Once all retries are used up, nil
// messages and the last error are returned.
//
// After each failed retry the error is logged to Logger (if set) and
// OnRetryHook (if set) is called.
func (r Retry) Middleware(h message.HandlerFunc) message.HandlerFunc {
	return func(msg *message.Message) ([]*message.Message, error) {
		producedMessages, err := h(msg)
		if err == nil {
			return producedMessages, nil
		}

		expBackoff := backoff.NewExponentialBackOff()
		expBackoff.InitialInterval = r.InitialInterval
		expBackoff.MaxInterval = r.MaxInterval
		expBackoff.Multiplier = r.Multiplier
		expBackoff.MaxElapsedTime = r.MaxElapsedTime
		expBackoff.RandomizationFactor = r.RandomizationFactor

		ctx := msg.Context()
		if r.MaxElapsedTime > 0 {
			var cancel func()
			ctx, cancel = context.WithTimeout(ctx, r.MaxElapsedTime)
			defer cancel()
		}

		expBackoff.Reset()

		// The limit is checked before every attempt, so MaxRetries <= 0 means
		// that no retry is made at all.
		for retryNum := 1; retryNum <= r.MaxRetries; retryNum++ {
			waitTime := expBackoff.NextBackOff()
			if waitTime == backoff.Stop {
				// The backoff gave up (MaxElapsedTime would be exceeded). Waiting
				// for a negative duration would return at once and cause
				// back-to-back retries.
				return producedMessages, err
			}

			timer := time.NewTimer(waitTime)
			select {
			case <-ctx.Done():
				// Release the timer instead of leaving it pending.
				timer.Stop()
				return producedMessages, err
			case <-timer.C:
				// go on
			}

			producedMessages, err = h(msg)
			if err == nil {
				return producedMessages, nil
			}

			if r.Logger != nil {
				r.Logger.Error("Error occurred, retrying", err, watermill.LogFields{
					"retry_no":     retryNum,
					"max_retries":  r.MaxRetries,
					"wait_time":    waitTime,
					"elapsed_time": expBackoff.GetElapsedTime(),
				})
			}
			if r.OnRetryHook != nil {
				r.OnRetryHook(retryNum, waitTime)
			}
		}

		return nil, err
	}
}
// Throttle provides a middleware that limits the amount of messages processed per unit of time.
// This may be done e.g. to prevent excessive load caused by running a handler on a long queue of unprocessed messages.
type Throttle struct {
// ticker paces the handlers: every handler invocation waits for one tick before it runs.
ticker *time.Ticker
}
// NewThrottle creates a new Throttle middleware that lets through count
// messages per duration.
// Example duration and count: NewThrottle(10, time.Second) for 10 messages per second
//
// The tick interval is duration divided by count. A count of zero panics
// with a division by zero, and time.NewTicker panics when the resulting
// interval is not positive.
func NewThrottle(count int64, duration time.Duration) *Throttle {
	interval := duration / time.Duration(count)

	return &Throttle{ticker: time.NewTicker(interval)}
}
// Middleware returns the Throttle middleware.
//
// All handlers wrapped by the same Throttle share one ticker; each
// invocation blocks until it receives the next tick and only then runs h.
func (t Throttle) Middleware(h message.HandlerFunc) message.HandlerFunc {
	return func(msg *message.Message) ([]*message.Message, error) {
		// wait for this invocation's "tick"
		tick := t.ticker.C
		<-tick

		return h(msg)
	}
}
// Timeout makes the handler cancel the incoming message's context after a specified time.
// Any timeout-sensitive functionality of the handler should listen on msg.Context().Done() to know when to fail.
//
// The message's context is replaced with the derived one before the handler
// runs, and the derived context is cancelled when the handler returns.
func Timeout(timeout time.Duration) func(message.HandlerFunc) message.HandlerFunc {
	return func(h message.HandlerFunc) message.HandlerFunc {
		return func(msg *message.Message) ([]*message.Message, error) {
			timeoutCtx, cancel := context.WithTimeout(msg.Context(), timeout)
			defer cancel()

			msg.SetContext(timeoutCtx)
			return h(msg)
		}
	}
}