diff --git a/errors.go b/errors.go index 365250f9..728d9627 100755 --- a/errors.go +++ b/errors.go @@ -271,6 +271,7 @@ var errorMessages = map[string]string{ "EDIT_BOT_INVITE_FORBIDDEN": "Normal users can't edit invites that were created by bots.", "EMAIL_HASH_EXPIRED": "Email hash expired.", "EMAIL_INVALID": "The specified email is invalid.", + "EMAIL_INSTALL_MISSING": "No email was set up for this account.", "EMAIL_NOT_SETUP": "Login email not set up.", "EMAIL_UNCONFIRMED": "Email unconfirmed.", "EMAIL_VERIFY_EXPIRED": "The verification email has expired.", diff --git a/examples/contextstore/contextstore.go b/examples/contextstore/contextstore.go new file mode 100644 index 00000000..c2e5954b --- /dev/null +++ b/examples/contextstore/contextstore.go @@ -0,0 +1,63 @@ +package main + +import ( + "fmt" + "time" + + tg "github.com/amarnathcjd/gogram/telegram" +) + +func main() { + client, _ := tg.NewClient(tg.ClientConfig{ + AppID: 6, + AppHash: "", + }) + + // --- Basic Get/Set --- + client.Data.Set("bot_status", "running") + fmt.Println(client.Data.GetString("bot_status", "unknown")) + + // --- TTL-based expiration --- + client.Data.SetWithTTL("temp_token", "abc123", 5*time.Minute) + + // --- Scoped by chat/user ID --- + chatID := int64(123456789) + client.Data.SetScoped(chatID, "step", 1) + client.Data.SetScoped(chatID, "awaiting_reply", true) + + step := client.Data.GetScoped(chatID, "step") + fmt.Printf("Chat %d is on step: %v\n", chatID, step) + + // --- Atomic counter --- + count := client.Data.IncrementScoped(chatID, "message_count") + fmt.Printf("Message count: %d\n", count) + + // --- Type-safe generic access --- + client.Data.Set("user_settings", map[string]any{"theme": "dark", "lang": "en"}) + if settings, ok := tg.GetTyped[map[string]any](client.Data, "user_settings"); ok { + fmt.Printf("Theme: %s\n", settings["theme"]) + } + + // --- Clean up all data for a chat --- + client.Data.DeleteByScope(chatID) + + // --- Example in an update handler --- + client.On(tg.OnMessage, func(m *tg.NewMessage) error { + // Track conversation state per chat + if m.Text() == "/start" { + client.Data.SetScoped(m.ChatID(), "state", "awaiting_name") + m.Reply("What's your name?") + return nil + } + + state := client.Data.GetString(fmt.Sprintf("%d:state", m.ChatID()), "") + if state == "awaiting_name" { + client.Data.SetScoped(m.ChatID(), "name", m.Text()) + client.Data.DeleteScoped(m.ChatID(), "state") + m.Reply("Nice to meet you, " + m.Text() + "!") + } + return nil + }) + + client.Idle() +} diff --git a/internal/transport/mtproxy.go b/internal/transport/mtproxy.go index edea679c..be8d0a5d 100644 --- a/internal/transport/mtproxy.go +++ b/internal/transport/mtproxy.go @@ -114,9 +114,10 @@ func DialMTProxy(ctx context.Context, proxy *utils.Proxy, targetHost string, dcI m := &mtproxyConn{conn: conn, config: config, useFakeTls: config.FakeTlsDomain != nil, isFirstWrite: true} if config.FakeTlsDomain != nil { + conn.SetDeadline(time.Now().Add(15 * time.Second)) if err := m.fakeTlsHandshake(); err != nil { conn.Close() - return nil, err + return nil, fmt.Errorf("TLS handshake failed: %w", err) } } @@ -134,10 +135,14 @@ func DialMTProxy(ctx context.Context, proxy *utils.Proxy, targetHost string, dcI if config.FakeTlsDomain != nil { m.obfTag = obfTag - } else if _, err = conn.Write(obfTag); err != nil { - conn.Close() - return nil, err + } else { + conn.SetDeadline(time.Now().Add(10 * time.Second)) + if _, err = conn.Write(obfTag); err != nil { + conn.Close() + return nil, fmt.Errorf("writing obfuscation tag: %w", err) + } } + conn.SetDeadline(time.Time{}) return m, nil } diff --git a/mtproto.go b/mtproto.go index 7d93f9f9..408503cb 100755 --- a/mtproto.go +++ b/mtproto.go @@ -971,8 +971,8 @@ func (m *MTProto) Disconnect() error { select { case <-done: m.Logger.Trace("all routines stopped gracefully") - case <-time.After(3 * time.Second): - m.Logger.Warn("timeout waiting for routines to stop (possible goroutine leak)") + case <-time.After(10 * time.Second): + m.Logger.Debug("timeout waiting for routines to stop on disconnect") } return nil diff --git a/telegram/buttons.go b/telegram/buttons.go index 6de39808..73651e0e 100755 --- a/telegram/buttons.go +++ b/telegram/buttons.go @@ -150,7 +150,11 @@ func (ButtonBuilder) SwitchInline(text string, samePeer bool, query string) *Key return &KeyboardButtonSwitchInline{Text: text, SamePeer: samePeer, Query: query} } -func (ButtonBuilder) WebView(text, url string) *KeyboardButtonSimpleWebView { +func (ButtonBuilder) WebView(text, url string) *KeyboardButtonWebView { + return &KeyboardButtonWebView{Text: text, URL: url} +} + +func (ButtonBuilder) SimpleWebView(text, url string) *KeyboardButtonSimpleWebView { return &KeyboardButtonSimpleWebView{Text: text, URL: url} } diff --git a/telegram/channels.go b/telegram/channels.go index 0cfddb9b..ff0dc826 100755 --- a/telegram/channels.go +++ b/telegram/channels.go @@ -56,59 +56,64 @@ func (c *Client) GetChatPhoto(chatID any) (Photo, error) { // // Params: // - Channel: the username or id of the channel or chat -func (c *Client) JoinChannel(Channel any) (bool, error) { - switch p := Channel.(type) { +func (c *Client) JoinChannel(channel any) (*Channel, error) { + switch p := channel.(type) { case string: if TgJoinRe.MatchString(p) { result, err := c.MessagesImportChatInvite(TgJoinRe.FindStringSubmatch(p)[1]) if err != nil { - return false, err + return nil, err } switch result := result.(type) { case *UpdatesObj: c.Cache.UpdatePeersToCache(result.Users, result.Chats) + switch result.Chats[0].(type) { + case *Channel: + return result.Chats[0].(*Channel), nil + } } - return true, nil + return nil, nil } else if UsernameRe.MatchString(p) { return c.joinChannelByPeer(UsernameRe.FindStringSubmatch(p)[1]) } - return false, errors.New("invalid channel or chat") + return nil, errors.New("invalid channel or chat") case *InputPeerChannel, *InputPeerChat, int, int32, int64: return c.joinChannelByPeer(p) case *ChatInviteExported: _, err := c.MessagesImportChatInvite(p.Link) if err != nil { - return false, err + return nil, err } - return true, nil + return nil, nil default: - return c.joinChannelByPeer(Channel) + return c.joinChannelByPeer(channel) } } -func (c *Client) joinChannelByPeer(Channel any) (bool, error) { - channel, err := c.ResolvePeer(Channel) +func (c *Client) joinChannelByPeer(channel any) (*Channel, error) { + channel, err := c.ResolvePeer(channel) if err != nil { - return false, err + return nil, err } if chat, ok := channel.(*InputPeerChannel); ok { _, err = c.ChannelsJoinChannel(&InputChannelObj{ChannelID: chat.ChannelID, AccessHash: chat.AccessHash}) if err != nil { - return false, err + return nil, err } + return c.GetChannel(chat.ChannelID) } else if chat, ok := channel.(*InputPeerChat); ok { _, err = c.MessagesAddChatUser(chat.ChatID, &InputUserEmpty{}, 0) if err != nil { - return false, err + return nil, err } } else { - return false, errors.New("peer is not a channel or chat") + return nil, errors.New("peer is not a channel or chat") } - return true, nil + return nil, nil } // LeaveChannel leaves a channel or chat @@ -1050,3 +1055,28 @@ func (c *Client) GetLinkedChannel(channel any) (*Channel, error) { return nil, errors.New("could not get full channel info") } } + +func (c *Client) ExportInvite(channel any) (ExportedChatInvite, error) { + peer, err := c.ResolvePeer(channel) + if err != nil { + return nil, err + } + + return c.MessagesExportChatInvite(&MessagesExportChatInviteParams{ + Peer: peer, + }) +} + +func (c *Client) RevokeInvite(channel any, invite string) error { + peer, err := c.ResolvePeer(channel) + if err != nil { + return err + } + + _, err = c.MessagesEditExportedChatInvite(&MessagesEditExportedChatInviteParams{ + Peer: peer, + Link: invite, + Revoked: true, + }) + return err +} diff --git a/telegram/client.go b/telegram/client.go index e1425a59..b4ec5ccd 100644 --- a/telegram/client.go +++ b/telegram/client.go @@ -63,6 +63,7 @@ type Client struct { secretChats *e2e.SecretChatManager exportedKeys map[int]*AuthExportedAuthorization Log Logger + Data *ContextStore } type DeviceConfig struct { @@ -160,6 +161,7 @@ func NewClient(config ClientConfig) (*Client, error) { } client.exSenders = NewExSenders() + client.Data = NewContextStore() return client, nil } diff --git a/telegram/conversation.go b/telegram/conversation.go index 1397051e..4db627ef 100755 --- a/telegram/conversation.go +++ b/telegram/conversation.go @@ -5,8 +5,8 @@ import ( "errors" "fmt" "regexp" - "slices" "strings" + "sync" "time" ) @@ -17,13 +17,12 @@ var ( ErrValidationFailed = errors.New("validation failed after max retries") ) -// State Machine for conversation with users and in groups +// Conversation is a state machine for interactive conversations with users type Conversation struct { Client *Client Peer InputPeer isPrivate bool timeout int32 - openHandlers []Handle lastMsg *NewMessage stopPropagation bool ctx context.Context @@ -31,6 +30,7 @@ type Conversation struct { closed bool abortKeywords []string fromUser int64 + mu sync.RWMutex } // ConversationOptions for configuring a conversation @@ -73,7 +73,6 @@ func (c *Client) NewConversation(peer any, options ...*ConversationOptions) (*Co }, nil } -// NewConversation creates a new conversation with user (standalone function) func NewConversation(client *Client, peer InputPeer, options ...*ConversationOptions) *Conversation { opts := getVariadic(options, &ConversationOptions{ Timeout: 60, @@ -123,21 +122,23 @@ func (c *Conversation) SetTimeout(timeout int32) *Conversation { return c } -// when stopPropagation is set to true, the event handler blocks all other handlers func (c *Conversation) SetStopPropagation(stop bool) *Conversation { c.stopPropagation = stop return c } func (c *Conversation) LastMessage() *NewMessage { + c.mu.RLock() + defer c.mu.RUnlock() return c.lastMsg } func (c *Conversation) IsClosed() bool { + c.mu.RLock() + defer c.mu.RUnlock() return c.closed } -// SetAbortKeywords sets keywords that will abort the conversation func (c *Conversation) SetAbortKeywords(keywords ...string) *Conversation { c.abortKeywords = keywords return c @@ -153,7 +154,6 @@ func (c *Conversation) WithFromUser(userID int64) *Conversation { return c } -// checkAbort checks if the message contains an abort keyword func (c *Conversation) checkAbort(msg *NewMessage) bool { if len(c.abortKeywords) == 0 { return false @@ -178,9 +178,11 @@ func (c *Conversation) RespondMedia(media InputMedia, opts ...*MediaOptions) (*N func (c *Conversation) Reply(text any, opts ...*SendOptions) (*NewMessage, error) { var options = getVariadic(opts, &SendOptions{}) if options.ReplyID == 0 { + c.mu.RLock() if c.lastMsg != nil { options.ReplyID = c.lastMsg.ID } + c.mu.RUnlock() } return c.Client.SendMessage(c.Peer, text, opts...) @@ -189,153 +191,202 @@ func (c *Conversation) Reply(text any, opts ...*SendOptions) (*NewMessage, error func (c *Conversation) ReplyMedia(media InputMedia, opts ...*MediaOptions) (*NewMessage, error) { var options = getVariadic(opts, &MediaOptions{}) if options.ReplyID == 0 { + c.mu.RLock() if c.lastMsg != nil { options.ReplyID = c.lastMsg.ID } + c.mu.RUnlock() } return c.Client.SendMedia(c.Peer, media, opts...) } func (c *Conversation) GetResponse() (*NewMessage, error) { + return c.waitForMessage(nil) +} + +func (c *Conversation) waitForMessage(check func(*NewMessage) bool) (*NewMessage, error) { + c.mu.RLock() + if c.closed { + c.mu.RUnlock() + return nil, ErrConversationClosed + } + c.mu.RUnlock() + resp := make(chan *NewMessage, 1) + done := make(chan struct{}) + waitFunc := func(m *NewMessage) error { + if check != nil && !check(m) { + return nil + } select { case resp <- m: - c.lastMsg = m - default: + case <-done: } - if c.stopPropagation { return ErrEndGroup } return nil } - var filters []Filter - switch c.Peer.(type) { - case *InputPeerChannel, *InputPeerChat: - filters = append(filters, InChat(c.Client.GetPeerID(c.Peer))) - case *InputPeerUser, *InputPeerSelf: - filters = append(filters, FromUser(c.Client.GetPeerID(c.Peer))) - } - - if c.isPrivate { - filters = append(filters, FilterPrivate) + filters := c.buildFilters() + args := make([]any, 0, 2+len(filters)) + args = append(args, OnMessage, waitFunc) + for _, f := range filters { + args = append(args, f) } + h := c.Client.On(args...) + h.SetGroup(ConversationGroup) - h := c.Client.On(OnMessage, waitFunc, filters) - h.SetGroup(-1) + timeout := time.Duration(c.timeout) * time.Second + timer := time.NewTimer(timeout) + defer timer.Stop() - c.openHandlers = append(c.openHandlers, h) select { - case <-time.After(time.Duration(c.timeout) * time.Second): - go c.removeHandle(h) - return nil, fmt.Errorf("conversation timeout: %d", c.timeout) + case <-c.ctx.Done(): + close(done) + c.Client.RemoveHandle(h) + return nil, ErrConversationClosed + case <-timer.C: + close(done) + c.Client.RemoveHandle(h) + return nil, ErrConversationTimeout case m := <-resp: - go c.removeHandle(h) + close(done) + c.Client.RemoveHandle(h) + c.mu.Lock() + c.lastMsg = m + c.mu.Unlock() return m, nil } } func (c *Conversation) GetEdit() (*NewMessage, error) { - resp := make(chan *NewMessage) + c.mu.RLock() + if c.closed { + c.mu.RUnlock() + return nil, ErrConversationClosed + } + c.mu.RUnlock() + + resp := make(chan *NewMessage, 1) + done := make(chan struct{}) + waitFunc := func(m *NewMessage) error { select { case resp <- m: - c.lastMsg = m - default: + case <-done: } - if c.stopPropagation { return ErrEndGroup } return nil } - var filters []Filter - switch c.Peer.(type) { - case *InputPeerChannel, *InputPeerChat: - filters = append(filters, InChat(c.Client.GetPeerID(c.Peer))) - case *InputPeerUser, *InputPeerSelf: - filters = append(filters, FromUser(c.Client.GetPeerID(c.Peer))) - } + filters := c.buildFilters() + h := c.Client.On(OnEdit, waitFunc, filters) + h.SetGroup(ConversationGroup) - if c.isPrivate { - filters = append(filters, FilterPrivate) - } + timeout := time.Duration(c.timeout) * time.Second + timer := time.NewTimer(timeout) + defer timer.Stop() - h := c.Client.On(OnEdit, waitFunc, filters) - h.SetGroup(-1) - c.openHandlers = append(c.openHandlers, h) select { - case <-time.After(time.Duration(c.timeout) * time.Second): - go c.removeHandle(h) - return nil, fmt.Errorf("conversation timeout: %d", c.timeout) + case <-c.ctx.Done(): + close(done) + c.Client.RemoveHandle(h) + return nil, ErrConversationClosed + case <-timer.C: + close(done) + c.Client.RemoveHandle(h) + return nil, ErrConversationTimeout case m := <-resp: - go c.removeHandle(h) + close(done) + c.Client.RemoveHandle(h) + c.mu.Lock() + c.lastMsg = m + c.mu.Unlock() return m, nil } } func (c *Conversation) GetReply() (*NewMessage, error) { - resp := make(chan *NewMessage) + c.mu.RLock() + if c.closed { + c.mu.RUnlock() + return nil, ErrConversationClosed + } + c.mu.RUnlock() + + resp := make(chan *NewMessage, 1) + done := make(chan struct{}) + waitFunc := func(m *NewMessage) error { select { case resp <- m: - c.lastMsg = m - default: + case <-done: } - if c.stopPropagation { return ErrEndGroup } return nil } - var filters []Filter - switch c.Peer.(type) { - case *InputPeerChannel, *InputPeerChat: - filters = append(filters, InChat(c.Client.GetPeerID(c.Peer))) - case *InputPeerUser, *InputPeerSelf: - filters = append(filters, FromUser(c.Client.GetPeerID(c.Peer))) - } - - if c.isPrivate { - filters = append(filters, FilterPrivate) - } - + filters := c.buildFilters() filters = append(filters, FilterReply) - h := c.Client.On(OnMessage, waitFunc, filters) - h.SetGroup(-1) - c.openHandlers = append(c.openHandlers, h) + h.SetGroup(ConversationGroup) + + timeout := time.Duration(c.timeout) * time.Second + timer := time.NewTimer(timeout) + defer timer.Stop() + select { - case <-time.After(time.Duration(c.timeout) * time.Second): - go c.removeHandle(h) - return nil, fmt.Errorf("conversation timeout: %d", c.timeout) + case <-c.ctx.Done(): + close(done) + c.Client.RemoveHandle(h) + return nil, ErrConversationClosed + case <-timer.C: + close(done) + c.Client.RemoveHandle(h) + return nil, ErrConversationTimeout case m := <-resp: - go c.removeHandle(h) + close(done) + c.Client.RemoveHandle(h) + c.mu.Lock() + c.lastMsg = m + c.mu.Unlock() return m, nil } } func (c *Conversation) MarkRead() (*MessagesAffectedMessages, error) { - if c.lastMsg != nil { - return c.Client.SendReadAck(c.Peer, c.lastMsg.ID) - } else { - return c.Client.SendReadAck(c.Peer) + c.mu.RLock() + lastMsg := c.lastMsg + c.mu.RUnlock() + if lastMsg != nil { + return c.Client.SendReadAck(c.Peer, lastMsg.ID) } + return c.Client.SendReadAck(c.Peer) } func (c *Conversation) WaitClick() (*CallbackQuery, error) { - resp := make(chan *CallbackQuery) + c.mu.RLock() + if c.closed { + c.mu.RUnlock() + return nil, ErrConversationClosed + } + c.mu.RUnlock() + + resp := make(chan *CallbackQuery, 1) + done := make(chan struct{}) + waitFunc := func(b *CallbackQuery) error { select { case resp <- b: - default: + case <-done: } - if c.stopPropagation { return ErrEndGroup } @@ -345,87 +396,128 @@ func (c *Conversation) WaitClick() (*CallbackQuery, error) { h := c.Client.On(OnCallbackQuery, waitFunc, CustomCallback(func(b *CallbackQuery) bool { return c.Client.PeerEquals(b.Peer, c.Peer) })) - c.openHandlers = append(c.openHandlers, h) + h.SetGroup(ConversationGroup) + + timeout := time.Duration(c.timeout) * time.Second + timer := time.NewTimer(timeout) + defer timer.Stop() + select { - case <-time.After(time.Duration(c.timeout) * time.Second): - go c.removeHandle(h) - return nil, fmt.Errorf("conversation timeout: %d", c.timeout) + case <-c.ctx.Done(): + close(done) + c.Client.RemoveHandle(h) + return nil, ErrConversationClosed + case <-timer.C: + close(done) + c.Client.RemoveHandle(h) + return nil, ErrConversationTimeout case b := <-resp: - go c.removeHandle(h) + close(done) + c.Client.RemoveHandle(h) return b, nil } } func (c *Conversation) WaitEvent(ev Update) (Update, error) { - resp := make(chan Update) - waitFunc := func(u Update, c *Client) error { + c.mu.RLock() + if c.closed { + c.mu.RUnlock() + return nil, ErrConversationClosed + } + c.mu.RUnlock() + + resp := make(chan Update, 1) + done := make(chan struct{}) + + waitFunc := func(u Update, _ *Client) error { select { case resp <- u: - default: + case <-done: } - return nil } h := c.Client.On(ev, waitFunc) - c.openHandlers = append(c.openHandlers, h) + h.SetGroup(ConversationGroup) + + timeout := time.Duration(c.timeout) * time.Second + timer := time.NewTimer(timeout) + defer timer.Stop() + select { - case <-time.After(time.Duration(c.timeout) * time.Second): - go c.removeHandle(h) - return nil, fmt.Errorf("conversation timeout: %d", c.timeout) + case <-c.ctx.Done(): + close(done) + c.Client.RemoveHandle(h) + return nil, ErrConversationClosed + case <-timer.C: + close(done) + c.Client.RemoveHandle(h) + return nil, ErrConversationTimeout case u := <-resp: - go c.removeHandle(h) + close(done) + c.Client.RemoveHandle(h) return u, nil } } func (c *Conversation) WaitRead() (*UpdateReadChannelInbox, error) { - resp := make(chan *UpdateReadChannelInbox) + c.mu.RLock() + if c.closed { + c.mu.RUnlock() + return nil, ErrConversationClosed + } + c.mu.RUnlock() + + resp := make(chan *UpdateReadChannelInbox, 1) + done := make(chan struct{}) + waitFunc := func(u Update) error { switch v := u.(type) { case *UpdateReadChannelInbox: select { case resp <- v: - default: + case <-done: } } - return nil } h := c.Client.On(&UpdateReadChannelInbox{}, waitFunc) - c.openHandlers = append(c.openHandlers, h) + h.SetGroup(ConversationGroup) + + timeout := time.Duration(c.timeout) * time.Second + timer := time.NewTimer(timeout) + defer timer.Stop() select { - case <-time.After(time.Duration(c.timeout) * time.Second): - go c.removeHandle(h) - return nil, fmt.Errorf("conversation timeout: %d", c.timeout) + case <-c.ctx.Done(): + close(done) + c.Client.RemoveHandle(h) + return nil, ErrConversationClosed + case <-timer.C: + close(done) + c.Client.RemoveHandle(h) + return nil, ErrConversationTimeout case u := <-resp: - go c.removeHandle(h) + close(done) + c.Client.RemoveHandle(h) return u, nil } } -func (c *Conversation) removeHandle(h Handle) { - for i, v := range c.openHandlers { - if v == h { - c.openHandlers = slices.Delete(c.openHandlers, i, i+1) - break - } - } - c.Client.removeHandle(h) -} - -// Close closes the conversation, removing all open event handlers +// Close closes the conversation and cancels any pending operations func (c *Conversation) Close() { + c.mu.Lock() + if c.closed { + c.mu.Unlock() + return + } c.closed = true + c.mu.Unlock() + if c.cancel != nil { c.cancel() } - for _, h := range c.openHandlers { - c.Client.removeHandle(h) - } - c.openHandlers = nil } func (c *Conversation) Ask(text any, opts ...*SendOptions) (*NewMessage, error) { @@ -478,13 +570,13 @@ func (c *Conversation) AskVoice(text any, opts ...*SendOptions) (*NewMessage, er } func (c *Conversation) GetResponseMatching(pattern *regexp.Regexp) (*NewMessage, error) { - return c.getResponseWithFilter(func(m *NewMessage) bool { + return c.waitForMessage(func(m *NewMessage) bool { return pattern.MatchString(m.Text()) }) } func (c *Conversation) GetResponseContaining(words ...string) (*NewMessage, error) { - return c.getResponseWithFilter(func(m *NewMessage) bool { + return c.waitForMessage(func(m *NewMessage) bool { text := strings.ToLower(m.Text()) for _, word := range words { if strings.Contains(text, strings.ToLower(word)) { @@ -497,7 +589,7 @@ func (c *Conversation) GetResponseContaining(words ...string) (*NewMessage, erro // GetResponseExact waits for a message with exact text match (case-insensitive) func (c *Conversation) GetResponseExact(options ...string) (*NewMessage, error) { - return c.getResponseWithFilter(func(m *NewMessage) bool { + return c.waitForMessage(func(m *NewMessage) bool { text := strings.ToLower(strings.TrimSpace(m.Text())) for _, opt := range options { if text == strings.ToLower(opt) { @@ -508,54 +600,20 @@ func (c *Conversation) GetResponseExact(options ...string) (*NewMessage, error) }) } -func (c *Conversation) getResponseWithFilter(check func(*NewMessage) bool) (*NewMessage, error) { - resp := make(chan *NewMessage, 1) - waitFunc := func(m *NewMessage) error { - if check(m) { - select { - case resp <- m: - c.lastMsg = m - default: - } - } - if c.stopPropagation { - return ErrEndGroup - } - return nil - } - - filters := c.buildFilters() - h := c.Client.On(OnMessage, waitFunc, filters) - h.SetGroup(-1) - c.openHandlers = append(c.openHandlers, h) - - select { - case <-c.ctx.Done(): - go c.removeHandle(h) - return nil, ErrConversationClosed - case <-time.After(time.Duration(c.timeout) * time.Second): - go c.removeHandle(h) - return nil, ErrConversationTimeout - case m := <-resp: - go c.removeHandle(h) - return m, nil - } -} - func (c *Conversation) WaitForPhoto() (*NewMessage, error) { - return c.getResponseWithFilter(func(m *NewMessage) bool { + return c.waitForMessage(func(m *NewMessage) bool { return m.Photo() != nil }) } func (c *Conversation) WaitForDocument() (*NewMessage, error) { - return c.getResponseWithFilter(func(m *NewMessage) bool { + return c.waitForMessage(func(m *NewMessage) bool { return m.Document() != nil }) } func (c *Conversation) WaitForVoice() (*NewMessage, error) { - return c.getResponseWithFilter(func(m *NewMessage) bool { + return c.waitForMessage(func(m *NewMessage) bool { if doc := m.Document(); doc != nil { for _, attr := range doc.Attributes { if _, ok := attr.(*DocumentAttributeAudio); ok { @@ -568,19 +626,19 @@ func (c *Conversation) WaitForVoice() (*NewMessage, error) { } func (c *Conversation) WaitForVideo() (*NewMessage, error) { - return c.getResponseWithFilter(func(m *NewMessage) bool { + return c.waitForMessage(func(m *NewMessage) bool { return m.Video() != nil }) } func (c *Conversation) WaitForSticker() (*NewMessage, error) { - return c.getResponseWithFilter(func(m *NewMessage) bool { + return c.waitForMessage(func(m *NewMessage) bool { return m.Sticker() != nil }) } func (c *Conversation) WaitForMedia() (*NewMessage, error) { - return c.getResponseWithFilter(func(m *NewMessage) bool { + return c.waitForMessage(func(m *NewMessage) bool { return m.Media() != nil }) } @@ -644,7 +702,7 @@ func (c *Conversation) AskUntil(question string, validator func(*NewMessage) boo } if attempt < maxRetries-1 { - question = retry // Use retry message for subsequent attempts + question = retry } } @@ -774,7 +832,6 @@ func WithAskFunc(fn func(*Conversation, string, ...*SendOptions) (*NewMessage, e } } -// Convenience helpers for common media types func ExpectPhoto() func(*ConversationStep) { return WithMediaType("photo") } @@ -884,11 +941,16 @@ func (w *ConversationWizard) Run() (map[string]*NewMessage, error) { if step.Skippable && msg != nil { text := strings.ToLower(strings.TrimSpace(msg.Text())) + skipped := false for _, skipWord := range step.SkipWords { if text == strings.ToLower(skipWord) { - continue + skipped = true + break } } + if skipped { + continue + } } w.answers[step.Name] = msg diff --git a/telegram/helpers.go b/telegram/helpers.go index 23ef359a..352c2686 100644 --- a/telegram/helpers.go +++ b/telegram/helpers.go @@ -15,6 +15,8 @@ import ( "reflect" "strconv" "strings" + "sync" + "time" ige "github.com/amarnathcjd/gogram/internal/aes_ige" "github.com/amarnathcjd/gogram/internal/session" @@ -1724,3 +1726,238 @@ func (c *Client) Stringify(object any) string { func (c *Client) JSON(object any, noindent ...any) string { return MarshalWithTypeName(object, true) } + +// GetTyped retrieves a typed value from the ContextStore +func GetTyped[T any](cs *ContextStore, key string) (T, bool) { + var zero T + val, ok := cs.GetOk(key) + if !ok { + return zero, false + } + typed, ok := val.(T) + return typed, ok +} + +// GetScopedTyped retrieves a typed scoped value from the ContextStore +func GetScopedTyped[T any](cs *ContextStore, id int64, key string) (T, bool) { + return GetTyped[T](cs, cs.scopedKey(id, key)) +} + +type ContextStore struct { + mu sync.RWMutex + data map[string]*contextEntry + stopC chan struct{} +} + +type contextEntry struct { + value any + expiresAt time.Time + hasTTL bool +} + +func NewContextStore() *ContextStore { + cs := &ContextStore{ + data: make(map[string]*contextEntry), + stopC: make(chan struct{}), + } + go cs.cleanupLoop() + return cs +} + +func (cs *ContextStore) cleanupLoop() { + ticker := time.NewTicker(5 * time.Minute) + defer ticker.Stop() + for { + select { + case <-ticker.C: + cs.cleanup() + case <-cs.stopC: + return + } + } +} + +func (cs *ContextStore) cleanup() { + cs.mu.Lock() + defer cs.mu.Unlock() + now := time.Now() + for key, entry := range cs.data { + if entry.hasTTL && now.After(entry.expiresAt) { + delete(cs.data, key) + } + } +} + +func (cs *ContextStore) Close() { close(cs.stopC) } +func (cs *ContextStore) Has(key string) bool { _, ok := cs.GetOk(key); return ok } +func (cs *ContextStore) Len() int { return len(cs.Keys()) } + +func (cs *ContextStore) Set(key string, value any) { + cs.mu.Lock() + defer cs.mu.Unlock() + cs.data[key] = &contextEntry{value: value, hasTTL: false} +} + +func (cs *ContextStore) SetWithTTL(key string, value any, ttl time.Duration) { + cs.mu.Lock() + defer cs.mu.Unlock() + cs.data[key] = &contextEntry{value: value, expiresAt: time.Now().Add(ttl), hasTTL: true} +} + +func (cs *ContextStore) Get(key string) any { + cs.mu.RLock() + defer cs.mu.RUnlock() + entry, ok := cs.data[key] + if !ok || (entry.hasTTL && time.Now().After(entry.expiresAt)) { + return nil + } + return entry.value +} + +func (cs *ContextStore) GetOk(key string) (any, bool) { + cs.mu.RLock() + defer cs.mu.RUnlock() + entry, ok := cs.data[key] + if !ok || (entry.hasTTL && time.Now().After(entry.expiresAt)) { + return nil, false + } + return entry.value, true +} + +func (cs *ContextStore) Delete(key string) { + cs.mu.Lock() + defer cs.mu.Unlock() + delete(cs.data, key) +} + +func (cs *ContextStore) scopedKey(id int64, key string) string { + return strconv.FormatInt(id, 10) + ":" + key +} + +func (cs *ContextStore) SetScoped(id int64, key string, value any) { + cs.Set(cs.scopedKey(id, key), value) +} +func (cs *ContextStore) GetScoped(id int64, key string) any { return cs.Get(cs.scopedKey(id, key)) } +func (cs *ContextStore) GetScopedOk(id int64, key string) (any, bool) { + return cs.GetOk(cs.scopedKey(id, key)) +} +func (cs *ContextStore) DeleteScoped(id int64, key string) { cs.Delete(cs.scopedKey(id, key)) } +func (cs *ContextStore) HasScoped(id int64, key string) bool { return cs.Has(cs.scopedKey(id, key)) } +func (cs *ContextStore) IncrementScoped(id int64, key string) int { + return cs.Increment(cs.scopedKey(id, key)) +} + +func (cs *ContextStore) SetScopedWithTTL(id int64, key string, value any, ttl time.Duration) { + cs.SetWithTTL(cs.scopedKey(id, key), value, ttl) +} + +func (cs *ContextStore) DeleteByPrefix(prefix string) int { + cs.mu.Lock() + defer cs.mu.Unlock() + count := 0 + for key := range cs.data { + if strings.HasPrefix(key, prefix) { + delete(cs.data, key) + count++ + } + } + return count +} + +func (cs *ContextStore) DeleteByScope(id int64) int { + return cs.DeleteByPrefix(strconv.FormatInt(id, 10) + ":") +} + +func (cs *ContextStore) GetInt(key string, defaultVal int) int { + if val := cs.Get(key); val != nil { + switch v := val.(type) { + case int: + return v + case int32: + return int(v) + case int64: + return int(v) + } + } + return defaultVal +} + +func (cs *ContextStore) GetString(key string, defaultVal string) string { + if val := cs.Get(key); val != nil { + if s, ok := val.(string); ok { + return s + } + } + return defaultVal +} + +func (cs *ContextStore) GetBool(key string, defaultVal bool) bool { + if val := cs.Get(key); val != nil { + if b, ok := val.(bool); ok { + return b + } + } + return defaultVal +} + +func (cs *ContextStore) Increment(key string) int { + cs.mu.Lock() + defer cs.mu.Unlock() + entry, ok := cs.data[key] + if !ok || (entry.hasTTL && time.Now().After(entry.expiresAt)) { + cs.data[key] = &contextEntry{value: 1, hasTTL: false} + return 1 + } + switch v := entry.value.(type) { + case int: + entry.value = v + 1 + return v + 1 + case int32: + entry.value = v + 1 + return int(v + 1) + case int64: + entry.value = v + 1 + return int(v + 1) + default: + cs.data[key] = &contextEntry{value: 1, hasTTL: false} + return 1 + } +} + +func (cs *ContextStore) Keys() []string { + cs.mu.RLock() + defer cs.mu.RUnlock() + now := time.Now() + keys := make([]string, 0, len(cs.data)) + for key, entry := range cs.data { + if !entry.hasTTL || now.Before(entry.expiresAt) { + keys = append(keys, key) + } + } + return keys +} + +func (cs *ContextStore) Clear() { + cs.mu.Lock() + defer cs.mu.Unlock() + cs.data = make(map[string]*contextEntry) +} + +func (cs *ContextStore) SetOrUpdate(key string, defaultVal any, updateFn func(current any) any) any { + cs.mu.Lock() + defer cs.mu.Unlock() + entry, ok := cs.data[key] + if !ok || (entry.hasTTL && time.Now().After(entry.expiresAt)) { + cs.data[key] = &contextEntry{value: defaultVal, hasTTL: false} + return defaultVal + } + newVal := updateFn(entry.value) + entry.value = newVal + return newVal +} + +func (cs *ContextStore) String() string { + cs.mu.RLock() + defer cs.mu.RUnlock() + return fmt.Sprintf("ContextStore{entries: %d}", len(cs.data)) +} diff --git a/telegram/newmessage.go b/telegram/newmessage.go index 5e5a9868..b890a58d 100755 --- a/telegram/newmessage.go +++ b/telegram/newmessage.go @@ -34,7 +34,10 @@ type CustomFile struct { } func (m *NewMessage) MessageText() string { - return m.Message.Message + if m.Message != nil { + return m.Message.Message + } + return "" } func (m *NewMessage) ReplyToMsgID() int32 { @@ -593,19 +596,20 @@ func (m *NewMessage) GetCommand() string { return "" } -// Conv starts a new conversation with the user func (m *NewMessage) Conv(timeout ...int32) (*Conversation, error) { return m.Client.NewConversation(m.Peer, &ConversationOptions{ - Timeout: getVariadic(timeout, 60), - Private: m.IsPrivate(), + Timeout: getVariadic(timeout, 60), + Private: m.IsPrivate(), + StopPropagation: true, }) } // Wizard starts a new conversation wizard with the user func (m *NewMessage) Wizard(timeout ...int32) (*ConversationWizard, error) { conv, err := m.Client.NewConversation(m.Peer, &ConversationOptions{ - Private: m.IsPrivate(), - Timeout: getVariadic(timeout, 60), + Private: m.IsPrivate(), + Timeout: getVariadic(timeout, 60), + StopPropagation: true, }) if err != nil { return nil, err diff --git a/telegram/participant.go b/telegram/participant.go index 1fd24d87..6e9cef91 100755 --- a/telegram/participant.go +++ b/telegram/participant.go @@ -312,6 +312,26 @@ func (jru *JoinRequestUpdate) Approve(userID int64) (bool, error) { return err == nil, err } +func (jru *JoinRequestUpdate) ApproveAll() (bool, error) { + if jru.Channel == nil && jru.Chat == nil { + return false, fmt.Errorf("channel/chat is nil") + } + peer, err := jru.GetInputPeer() + if err != nil { + return false, err + } + link := "" + + if jru.BotOriginalUpdate != nil { + switch jru.BotOriginalUpdate.Invite.(type) { + case *ChatInviteExported: + link = jru.BotOriginalUpdate.Invite.(*ChatInviteExported).Link + } + } + _, err = jru.Client.MessagesHideAllChatJoinRequests(true, peer, link) + return err == nil, err +} + func (jru *JoinRequestUpdate) Decline(userID int64) (bool, error) { if jru.Channel == nil && jru.Chat == nil { return false, fmt.Errorf("channel/chat is nil") @@ -328,6 +348,25 @@ func (jru *JoinRequestUpdate) Decline(userID int64) (bool, error) { return err == nil, err } +func (jru *JoinRequestUpdate) DeclineAll() (bool, error) { + if jru.Channel == nil && jru.Chat == nil { + return false, fmt.Errorf("channel/chat is nil") + } + peer, err := jru.GetInputPeer() + if err != nil { + return false, err + } + link := "" + if jru.BotOriginalUpdate != nil { + switch jru.BotOriginalUpdate.Invite.(type) { + case *ChatInviteExported: + link = jru.BotOriginalUpdate.Invite.(*ChatInviteExported).Link + } + } + _, err = jru.Client.MessagesHideAllChatJoinRequests(false, peer, link) + return err == nil, err +} + func (jru *JoinRequestUpdate) Marshal(noindent ...bool) string { if jru.OriginalUpdate != nil { return MarshalWithTypeName(jru.OriginalUpdate, noindent...) diff --git a/telegram/updates.go b/telegram/updates.go index 7c29a46b..49f2b080 100644 --- a/telegram/updates.go +++ b/telegram/updates.go @@ -8,7 +8,6 @@ import ( "errors" "fmt" "maps" - "reflect" "regexp" "slices" "sort" @@ -106,6 +105,28 @@ func (c *lruCache) Contains(key int64) bool { return exists } +func (c *lruCache) TryAdd(key int64) bool { + c.Lock() + defer c.Unlock() + + if _, exists := c.items[key]; exists { + return false + } + + entry := &lruEntry{key: key, timestamp: time.Now()} + elem := c.list.PushFront(entry) + c.items[key] = elem + + if c.list.Len() > c.maxSize { + oldest := c.list.Back() + if oldest != nil { + c.list.Remove(oldest) + delete(c.items, oldest.Value.(*lruEntry).key) + } + } + return true +} + type patternCache struct { sync.RWMutex patterns map[string]*regexp.Regexp @@ -206,7 +227,14 @@ type Handle interface { GetPriority() int } +var handleIDCounter atomic.Uint64 + +func nextHandleID() uint64 { + return handleIDCounter.Add(1) +} + type baseHandle struct { + id uint64 Group int priority int name string @@ -308,8 +336,9 @@ type joinRequestHandle struct { type rawHandle struct { baseHandle - updateType Update - Handler RawHandler + updateType Update + updateTypeID uint32 + Handler RawHandler } type e2eHandle struct { @@ -504,17 +533,11 @@ func (u *UpdateDispatcher) getLastUpdateTime() time.Time { return time.Unix(0, u.lastUpdateTimeNano.Load()) } -// TryMarkUpdateProcessed atomically checks if an update was processed and marks it if not. -// Returns true if this call marked it (first processor), false if already processed. func (d *UpdateDispatcher) TryMarkUpdateProcessed(updateID int64) bool { if d.processedUpdatesLRU == nil { return true } - if d.processedUpdatesLRU.Contains(updateID) { - return false - } - d.processedUpdatesLRU.Add(updateID) - return true + return d.processedUpdatesLRU.TryAdd(updateID) } func (c *Client) NewUpdateDispatcher(sessionName ...string) { @@ -596,11 +619,25 @@ func (c *Client) removeHandle(handle Handle) error { return nil } -func removeHandleFromMap[T any](handle T, handlesMap map[int][]T) { +type handleWithID interface { + getID() uint64 + getPriority() int +} + +func (h *baseHandle) getID() uint64 { + return h.id +} + +func (h *baseHandle) getPriority() int { + return h.priority +} + +func removeHandleFromMap[T handleWithID](handle T, handlesMap map[int][]T) { + targetID := handle.getID() for key := range handlesMap { handles := handlesMap[key] for i := len(handles) - 1; i >= 0; i-- { - if reflect.DeepEqual(handles[i], handle) { + if handles[i].getID() == targetID { handlesMap[key] = slices.Delete(handles, i, i+1) return } @@ -608,6 +645,35 @@ func removeHandleFromMap[T any](handle T, handlesMap map[int][]T) { } } +var ( + updateTypeIDs = make(map[string]uint32) + updateTypeIDMu sync.RWMutex + nextTypeIDValue uint32 = 1 +) + +func getUpdateTypeID(update Update) uint32 { + if update == nil { + return 0 + } + typeName := fmt.Sprintf("%T", update) + updateTypeIDMu.RLock() + if id, ok := updateTypeIDs[typeName]; ok { + updateTypeIDMu.RUnlock() + return id + } + updateTypeIDMu.RUnlock() + + updateTypeIDMu.Lock() + defer updateTypeIDMu.Unlock() + if id, ok := updateTypeIDs[typeName]; ok { + return id + } + id := nextTypeIDValue + nextTypeIDValue++ + updateTypeIDs[typeName] = id + return id +} + // ---------------------------- Handle Functions ---------------------------- func (c *Client) handleMessageUpdate(update Message) { @@ -622,6 +688,10 @@ func (c *Client) handleMessageUpdate(update Message) { updateID = (peerID << 32) | int64(msg.ID) } + if msg.Out { + msg.FromID = &PeerUser{UserID: c.Me().ID} + } + if !c.dispatcher.TryMarkUpdateProcessed(updateID) { c.dispatcher.logger.Trace("duplicate message update skipped: %d", updateID) return @@ -633,6 +703,9 @@ func (c *Client) handleMessageUpdate(update Message) { packed := packMessage(c, msg) handle := func(h *messageHandle) error { + if msg.Out && !h.hasOutgoingFilter() { + return nil + } if h.runFilterChain(packed, h.Filters) { defer c.NewRecovery()() start := time.Now() @@ -735,7 +808,24 @@ func (c *Client) handleMessageUpdate(update Message) { } case *MessageService: + updateID := int64(msg.ID) + peerID := c.GetPeerID(msg.PeerID) + if peerID == 0 { + peerID = c.GetPeerID(msg.FromID) + } + if peerID != 0 { + updateID = (peerID << 32) | int64(msg.ID) + } + + if !c.dispatcher.TryMarkUpdateProcessed(updateID) { + c.dispatcher.logger.Trace("duplicate message update skipped: %d", updateID) + return + } + packed := packMessage(c, msg) + if msg.Out { + return + } c.dispatcher.RLock() actionHandles := make(map[int][]*chatActionHandle) @@ -958,6 +1048,12 @@ func (c *Client) handleInlineCallbackUpdate(update *UpdateInlineBotCallbackQuery } func (c *Client) handleParticipantUpdate(update *UpdateChannelParticipant) { + updateID := (update.ChannelID << 32) | (update.UserID << 16) | int64(update.Date&0xFFFF) + + if !c.dispatcher.TryMarkUpdateProcessed(updateID) { + return + } + packed := packChannelParticipant(c, update) c.dispatcher.RLock() @@ -1141,12 +1237,14 @@ func (c *Client) handleRawUpdate(update Update) { maps.Copy(rawHandles, c.dispatcher.rawHandles) c.dispatcher.RUnlock() + updateTypeID := getUpdateTypeID(update) + for group, handlers := range rawHandles { for _, handler := range handlers { if handler == nil || handler.Handler == nil { continue } - if reflect.TypeOf(update) == reflect.TypeOf(handler.updateType) || handler.updateType == nil { + if handler.updateTypeID == updateTypeID || handler.updateTypeID == 0 { handle := func(h *rawHandle) error { defer c.NewRecovery()() return h.Handler(update, c) @@ -1275,6 +1373,24 @@ func (h *messageHandle) runFilterChain(m *NewMessage, filters []Filter) bool { return true } +func (h *messageHandle) hasOutgoingFilter() bool { + for _, f := range h.Filters { + if f.flags.Has(FOutgoing) { + return true + } + + if f.Func != nil { + return true + } + for _, of := range f.orFilters { + if of.flags.Has(FOutgoing) { + return true + } + } + } + return false +} + func (e *messageEditHandle) runFilterChain(m *NewMessage, filters []Filter) bool { for _, f := range filters { if !f.check(m) { @@ -1678,15 +1794,15 @@ func addHandleToMap[T Handle](handleMap map[int][]T, handle T) T { return handleMap[group][len(handleMap[group])-1] } -func makePriorityChangeCallback[T Handle](handleMap map[int][]T, handle T, mu *sync.RWMutex) func() { +func makePriorityChangeCallback[T handleWithID](handleMap map[int][]T, handle T, handleID uint64, getGroup func() int, getPriority func() int, mu *sync.RWMutex) func() { return func() { mu.Lock() defer mu.Unlock() - group := handle.GetGroup() + group := getGroup() handlers := handleMap[group] for i := range handlers { - if reflect.DeepEqual(handlers[i], handle) { + if handlers[i].getID() == handleID { handlers = append(handlers[:i], handlers[i+1:]...) handleMap[group] = handlers break @@ -1695,8 +1811,9 @@ func makePriorityChangeCallback[T Handle](handleMap map[int][]T, handle T, mu *s handlers = handleMap[group] inserted := false - for i, h := range handlers { - if handle.GetPriority() > h.GetPriority() { + myPriority := getPriority() + for i := range handlers { + if myPriority > handlers[i].getPriority() { handleMap[group] = append(handlers[:i], append([]T{handle}, handlers[i:]...)...) inserted = true break @@ -1709,13 +1826,13 @@ func makePriorityChangeCallback[T Handle](handleMap map[int][]T, handle T, mu *s } } -func makeGroupChangeCallback[T Handle](handleMap map[int][]T, handle T, mu *sync.RWMutex) func(int, int) { +func makeGroupChangeCallback[T handleWithID](handleMap map[int][]T, handle T, handleID uint64, mu *sync.RWMutex) func(int, int) { return func(oldGroup, newGroup int) { mu.Lock() defer mu.Unlock() if old, ok := handleMap[oldGroup]; ok { for i := range old { - if reflect.DeepEqual(old[i], handle) { + if old[i].getID() == handleID { handleMap[oldGroup] = append(old[:i], old[i+1:]...) break } @@ -1733,19 +1850,21 @@ func (c *Client) AddMessageHandler(pattern any, handler MessageHandler, filters messageFilters = filters } + handleID := nextHandleID() handle := &messageHandle{ Pattern: pattern, Handler: handler, Filters: messageFilters, baseHandle: baseHandle{ + id: handleID, Group: DefaultGroup, enabled: true, metrics: &HandlerMetrics{}, }, } - handle.onGroupChanged = makeGroupChangeCallback(c.dispatcher.messageHandles, handle, &c.dispatcher.RWMutex) - handle.onPriorityChanged = makePriorityChangeCallback(c.dispatcher.messageHandles, handle, &c.dispatcher.RWMutex) + handle.onGroupChanged = makeGroupChangeCallback(c.dispatcher.messageHandles, handle, handleID, &c.dispatcher.RWMutex) + handle.onPriorityChanged = makePriorityChangeCallback(c.dispatcher.messageHandles, handle, handleID, handle.GetGroup, handle.GetPriority, &c.dispatcher.RWMutex) return addHandleToMap(c.dispatcher.messageHandles, handle) } @@ -1760,37 +1879,40 @@ func (c *Client) AddCommandHandler(pattern string, handler MessageHandler, filte func (c *Client) AddDeleteHandler(pattern any, handler func(d *DeleteMessage) error) Handle { c.dispatcher.Lock() defer c.dispatcher.Unlock() + handleID := nextHandleID() h := &messageDeleteHandle{ Pattern: pattern, Handler: handler, - baseHandle: baseHandle{Group: DefaultGroup}, + baseHandle: baseHandle{id: handleID, Group: DefaultGroup}, } - h.onGroupChanged = makeGroupChangeCallback(c.dispatcher.messageDeleteHandles, h, &c.dispatcher.RWMutex) - h.onPriorityChanged = makePriorityChangeCallback(c.dispatcher.messageDeleteHandles, h, &c.dispatcher.RWMutex) + h.onGroupChanged = makeGroupChangeCallback(c.dispatcher.messageDeleteHandles, h, handleID, &c.dispatcher.RWMutex) + h.onPriorityChanged = makePriorityChangeCallback(c.dispatcher.messageDeleteHandles, h, handleID, h.GetGroup, h.GetPriority, &c.dispatcher.RWMutex) return addHandleToMap(c.dispatcher.messageDeleteHandles, h) } func (c *Client) AddAlbumHandler(handler func(m *Album) error) Handle { c.dispatcher.Lock() defer c.dispatcher.Unlock() + handleID := nextHandleID() h := &albumHandle{ Handler: handler, - baseHandle: baseHandle{Group: DefaultGroup}, + baseHandle: baseHandle{id: handleID, Group: DefaultGroup}, } - h.onGroupChanged = makeGroupChangeCallback(c.dispatcher.albumHandles, h, &c.dispatcher.RWMutex) - h.onPriorityChanged = makePriorityChangeCallback(c.dispatcher.albumHandles, h, &c.dispatcher.RWMutex) + h.onGroupChanged = makeGroupChangeCallback(c.dispatcher.albumHandles, h, handleID, &c.dispatcher.RWMutex) + h.onPriorityChanged = makePriorityChangeCallback(c.dispatcher.albumHandles, h, handleID, h.GetGroup, h.GetPriority, &c.dispatcher.RWMutex) return addHandleToMap(c.dispatcher.albumHandles, h) } func (c *Client) AddActionHandler(handler MessageHandler) Handle { c.dispatcher.Lock() defer c.dispatcher.Unlock() + handleID := nextHandleID() h := &chatActionHandle{ Handler: handler, - baseHandle: baseHandle{Group: DefaultGroup}, + baseHandle: baseHandle{id: handleID, Group: DefaultGroup}, } - h.onGroupChanged = makeGroupChangeCallback(c.dispatcher.actionHandles, h, &c.dispatcher.RWMutex) - h.onPriorityChanged = makePriorityChangeCallback(c.dispatcher.actionHandles, h, &c.dispatcher.RWMutex) + h.onGroupChanged = makeGroupChangeCallback(c.dispatcher.actionHandles, h, handleID, &c.dispatcher.RWMutex) + h.onPriorityChanged = makePriorityChangeCallback(c.dispatcher.actionHandles, h, handleID, h.GetGroup, h.GetPriority, &c.dispatcher.RWMutex) return addHandleToMap(c.dispatcher.actionHandles, h) } @@ -1801,39 +1923,42 @@ func (c *Client) AddEditHandler(pattern any, handler MessageHandler, filters ... if len(filters) > 0 { messageFilters = filters } + handleID := nextHandleID() h := &messageEditHandle{ Pattern: pattern, Handler: handler, Filters: messageFilters, - baseHandle: baseHandle{Group: DefaultGroup}, + baseHandle: baseHandle{id: handleID, Group: DefaultGroup}, } - h.onGroupChanged = makeGroupChangeCallback(c.dispatcher.messageEditHandles, h, &c.dispatcher.RWMutex) - h.onPriorityChanged = makePriorityChangeCallback(c.dispatcher.messageEditHandles, h, &c.dispatcher.RWMutex) + h.onGroupChanged = makeGroupChangeCallback(c.dispatcher.messageEditHandles, h, handleID, &c.dispatcher.RWMutex) + h.onPriorityChanged = makePriorityChangeCallback(c.dispatcher.messageEditHandles, h, handleID, h.GetGroup, h.GetPriority, &c.dispatcher.RWMutex) return addHandleToMap(c.dispatcher.messageEditHandles, h) } func (c *Client) AddInlineHandler(pattern any, handler InlineHandler) Handle { c.dispatcher.Lock() defer c.dispatcher.Unlock() + handleID := nextHandleID() h := &inlineHandle{ Pattern: pattern, Handler: handler, - baseHandle: baseHandle{Group: DefaultGroup}, + baseHandle: baseHandle{id: handleID, Group: DefaultGroup}, } - h.onGroupChanged = makeGroupChangeCallback(c.dispatcher.inlineHandles, h, &c.dispatcher.RWMutex) - h.onPriorityChanged = makePriorityChangeCallback(c.dispatcher.inlineHandles, h, &c.dispatcher.RWMutex) + h.onGroupChanged = makeGroupChangeCallback(c.dispatcher.inlineHandles, h, handleID, &c.dispatcher.RWMutex) + h.onPriorityChanged = makePriorityChangeCallback(c.dispatcher.inlineHandles, h, handleID, h.GetGroup, h.GetPriority, &c.dispatcher.RWMutex) return addHandleToMap(c.dispatcher.inlineHandles, h) } func (c *Client) AddInlineSendHandler(handler InlineSendHandler) Handle { c.dispatcher.Lock() defer c.dispatcher.Unlock() + handleID := nextHandleID() h := &inlineSendHandle{ Handler: handler, - baseHandle: baseHandle{Group: DefaultGroup}, + baseHandle: baseHandle{id: handleID, Group: DefaultGroup}, } - h.onGroupChanged = makeGroupChangeCallback(c.dispatcher.inlineSendHandles, h, &c.dispatcher.RWMutex) - h.onPriorityChanged = makePriorityChangeCallback(c.dispatcher.inlineSendHandles, h, &c.dispatcher.RWMutex) + h.onGroupChanged = makeGroupChangeCallback(c.dispatcher.inlineSendHandles, h, handleID, &c.dispatcher.RWMutex) + h.onPriorityChanged = makePriorityChangeCallback(c.dispatcher.inlineSendHandles, h, handleID, h.GetGroup, h.GetPriority, &c.dispatcher.RWMutex) return addHandleToMap(c.dispatcher.inlineSendHandles, h) } @@ -1844,76 +1969,87 @@ func (c *Client) AddCallbackHandler(pattern any, handler CallbackHandler, filter if len(filters) > 0 { messageFilters = filters } + handleID := nextHandleID() h := &callbackHandle{ Pattern: pattern, Handler: handler, Filters: messageFilters, - baseHandle: baseHandle{Group: DefaultGroup}, + baseHandle: baseHandle{id: handleID, Group: DefaultGroup}, } - h.onGroupChanged = makeGroupChangeCallback(c.dispatcher.callbackHandles, h, &c.dispatcher.RWMutex) - h.onPriorityChanged = makePriorityChangeCallback(c.dispatcher.callbackHandles, h, &c.dispatcher.RWMutex) + h.onGroupChanged = makeGroupChangeCallback(c.dispatcher.callbackHandles, h, handleID, &c.dispatcher.RWMutex) + h.onPriorityChanged = makePriorityChangeCallback(c.dispatcher.callbackHandles, h, handleID, h.GetGroup, h.GetPriority, &c.dispatcher.RWMutex) return addHandleToMap(c.dispatcher.callbackHandles, h) } func (c *Client) AddInlineCallbackHandler(pattern any, handler InlineCallbackHandler) Handle { c.dispatcher.Lock() defer c.dispatcher.Unlock() + handleID := nextHandleID() h := &inlineCallbackHandle{ Pattern: pattern, Handler: handler, - baseHandle: baseHandle{Group: DefaultGroup}, + baseHandle: baseHandle{id: handleID, Group: DefaultGroup}, } - h.onGroupChanged = makeGroupChangeCallback(c.dispatcher.inlineCallbackHandles, h, &c.dispatcher.RWMutex) - h.onPriorityChanged = makePriorityChangeCallback(c.dispatcher.inlineCallbackHandles, h, &c.dispatcher.RWMutex) + h.onGroupChanged = makeGroupChangeCallback(c.dispatcher.inlineCallbackHandles, h, handleID, &c.dispatcher.RWMutex) + h.onPriorityChanged = makePriorityChangeCallback(c.dispatcher.inlineCallbackHandles, h, handleID, h.GetGroup, h.GetPriority, &c.dispatcher.RWMutex) return addHandleToMap(c.dispatcher.inlineCallbackHandles, h) } func (c *Client) AddJoinRequestHandler(handler PendingJoinHandler) Handle { c.dispatcher.Lock() defer c.dispatcher.Unlock() + handleID := nextHandleID() h := &joinRequestHandle{ Handler: handler, - baseHandle: baseHandle{Group: DefaultGroup}, + baseHandle: baseHandle{id: handleID, Group: DefaultGroup}, } - h.onGroupChanged = makeGroupChangeCallback(c.dispatcher.joinRequestHandles, h, &c.dispatcher.RWMutex) - h.onPriorityChanged = makePriorityChangeCallback(c.dispatcher.joinRequestHandles, h, &c.dispatcher.RWMutex) + h.onGroupChanged = makeGroupChangeCallback(c.dispatcher.joinRequestHandles, h, handleID, &c.dispatcher.RWMutex) + h.onPriorityChanged = makePriorityChangeCallback(c.dispatcher.joinRequestHandles, h, handleID, h.GetGroup, h.GetPriority, &c.dispatcher.RWMutex) return addHandleToMap(c.dispatcher.joinRequestHandles, h) } func (c *Client) AddParticipantHandler(handler ParticipantHandler) Handle { c.dispatcher.Lock() defer c.dispatcher.Unlock() + handleID := nextHandleID() h := &participantHandle{ Handler: handler, - baseHandle: baseHandle{Group: DefaultGroup}, + baseHandle: baseHandle{id: handleID, Group: DefaultGroup}, } - h.onGroupChanged = makeGroupChangeCallback(c.dispatcher.participantHandles, h, &c.dispatcher.RWMutex) - h.onPriorityChanged = makePriorityChangeCallback(c.dispatcher.participantHandles, h, &c.dispatcher.RWMutex) + h.onGroupChanged = makeGroupChangeCallback(c.dispatcher.participantHandles, h, handleID, &c.dispatcher.RWMutex) + h.onPriorityChanged = makePriorityChangeCallback(c.dispatcher.participantHandles, h, handleID, h.GetGroup, h.GetPriority, &c.dispatcher.RWMutex) return addHandleToMap(c.dispatcher.participantHandles, h) } func (c *Client) AddRawHandler(updateType Update, handler RawHandler) Handle { c.dispatcher.Lock() defer c.dispatcher.Unlock() + handleID := nextHandleID() + var typeID uint32 + if updateType != nil { + typeID = getUpdateTypeID(updateType) + } h := &rawHandle{ - updateType: updateType, - Handler: handler, - baseHandle: baseHandle{Group: DefaultGroup}, + updateType: updateType, + updateTypeID: typeID, + Handler: handler, + baseHandle: baseHandle{id: handleID, Group: DefaultGroup}, } - h.onGroupChanged = makeGroupChangeCallback(c.dispatcher.rawHandles, h, &c.dispatcher.RWMutex) - h.onPriorityChanged = makePriorityChangeCallback(c.dispatcher.rawHandles, h, &c.dispatcher.RWMutex) + h.onGroupChanged = makeGroupChangeCallback(c.dispatcher.rawHandles, h, handleID, &c.dispatcher.RWMutex) + h.onPriorityChanged = makePriorityChangeCallback(c.dispatcher.rawHandles, h, handleID, h.GetGroup, h.GetPriority, &c.dispatcher.RWMutex) return addHandleToMap(c.dispatcher.rawHandles, h) } func (c *Client) AddE2EHandler(handler func(update Update, c *Client) error) Handle { c.dispatcher.Lock() defer c.dispatcher.Unlock() + handleID := nextHandleID() h := &e2eHandle{ Handler: handler, - baseHandle: baseHandle{Group: DefaultGroup}, + baseHandle: baseHandle{id: handleID, Group: DefaultGroup}, } - h.onGroupChanged = makeGroupChangeCallback(c.dispatcher.e2eHandles, h, &c.dispatcher.RWMutex) - h.onPriorityChanged = makePriorityChangeCallback(c.dispatcher.e2eHandles, h, &c.dispatcher.RWMutex) + h.onGroupChanged = makeGroupChangeCallback(c.dispatcher.e2eHandles, h, handleID, &c.dispatcher.RWMutex) + h.onPriorityChanged = makePriorityChangeCallback(c.dispatcher.e2eHandles, h, handleID, h.GetGroup, h.GetPriority, &c.dispatcher.RWMutex) return addHandleToMap(c.dispatcher.e2eHandles, h) } @@ -2190,7 +2326,7 @@ func (c *Client) FetchDifference(fromPts int32, limit int32) { return default: - c.Log.Debug("unhandled difference type: %v", reflect.TypeOf(updates)) + c.Log.Debug("unhandled difference type: %T", updates) return } } @@ -2499,7 +2635,7 @@ func (c *Client) FetchChannelDifference(channelID int64, fromPts int32, limit in return default: - c.Log.Debug("unhandled channel difference type: %v (channel=%d)", reflect.TypeOf(diff), channelID) + c.Log.Debug("unhandled channel difference type: %T (channel=%d)", diff, channelID) return } } diff --git a/telegram/users.go b/telegram/users.go index 4b20d3e1..1c575f65 100644 --- a/telegram/users.go +++ b/telegram/users.go @@ -172,6 +172,26 @@ func (d *TLDialog) GetID() int64 { return 0 } +func (d *TLDialog) GetChannelID() int64 { + if d.Peer != nil { + switch peer := d.Peer.(type) { + case *PeerChannel: + if peer != nil { + return -100_000_000_0000 - peer.ChannelID + } + case *PeerChat: + if peer != nil { + return -peer.ChatID + } + case *PeerUser: + if peer != nil { + return peer.UserID + } + } + } + return 0 +} + func (d *TLDialog) GetInputPeer(c *Client) (InputPeer, error) { return c.GetSendablePeer(d.Peer) }