diff --git a/Mastodon/Scene/Thread/ThreadViewController+DataSourceProvider.swift b/Mastodon/Scene/Thread/ThreadViewController+DataSourceProvider.swift index 187885175..14b5f18a2 100644 --- a/Mastodon/Scene/Thread/ThreadViewController+DataSourceProvider.swift +++ b/Mastodon/Scene/Thread/ThreadViewController+DataSourceProvider.swift @@ -30,7 +30,33 @@ extension ThreadViewController: DataSourceProvider { } func update(status: MastodonStatus) { - viewModel.root = .root(context: .init(status: status)) + switch viewModel.root { + case let .root(context): + if context.status.id == status.id { + viewModel.root = .root(context: .init(status: status)) + } else { + handle(status: status) + } + case let .reply(context): + if context.status.id == status.id { + viewModel.root = .reply(context: .init(status: status)) + } else { + handle(status: status) + } + case let .leaf(context): + if context.status.id == status.id { + viewModel.root = .leaf(context: .init(status: status)) + } else { + handle(status: status) + } + case .none: + assertionFailure("This should not have happened") + } + } + + private func handle(status: MastodonStatus) { + viewModel.mastodonStatusThreadViewModel.ancestors.handle(status: status, for: viewModel) + viewModel.mastodonStatusThreadViewModel.descendants.handle(status: status, for: viewModel) } func delete(status: MastodonStatus) { @@ -42,3 +68,41 @@ extension ThreadViewController: DataSourceProvider { return tableView.indexPath(for: cell) } } + +private extension [StatusItem] { + mutating func handle(status: MastodonStatus, for viewModel: ThreadViewModel) { + for (index, ancestor) in enumerated() { + switch ancestor { + case let .feed(record): + if record.status?.id == status.id { + self[index] = .feed(record: .fromStatus(status, kind: record.kind)) + } + case let.feedLoader(record): + if record.status?.id == status.id { + self[index] = .feedLoader(record: .fromStatus(status, kind: record.kind)) + } + case let .status(record): + if record.id == status.id { + self[index] = .status(record: status) + } + case let .thread(thread): + switch thread { + case let .root(context): + if context.status.id == status.id { + self[index] = .thread(.root(context: .init(status: status))) + } + case let .reply(context): + if context.status.id == status.id { + self[index] = .thread(.reply(context: .init(status: status))) + } + case let .leaf(context): + if context.status.id == status.id { + self[index] = .thread(.leaf(context: .init(status: status))) + } + } + case .bottomLoader, .topLoader: + break + } + } + } +}