斯坦纳树
给一个联通图,求 $k$ 个关键点联通的最小生成树权值
设 $f[o][i]$ 表示当前关键点选择状态为 $o$ ,以点 $i$ 为根的树的最小权值
初始 $f[1<<(i-1)][i]=val[i]$ ,$val[i]$ 表示点 $i$ 的权值
那么从小到大枚举状态 $o$
对于每一个状态枚举 $o$ 的真子集 $op$,
则 $f[o][i]=min(f[o][i],f[o-op][i]+f[op][i]-val[i])$ 注意代价要减去 $val[i]$ ,因为两个状态合并时点 $i$ 的代价会算两次
这样转移还不够,还要考虑一个树自己扩展出去
所以枚举与根 $i$ 相连的点 $v$
则 $f[o][v]=min(f[o][v],f[o][i]+val[v])$ ,这样dp的顺序不好确定,但是发现这个很像 SPFA 的式子,所以用 SPFA 来进行转移
总结一下,对于每个状态,先考虑树的合并,再考虑树的扩展
至于为什么这样做是对的呢:
感性理解一下,这样显然会考虑到所有的情况,所以是对的2333....
SPFA时以 $f[o][i]!=INF$ 为起点
因为此题要输出路径,所以维护一个 $fa[o][i]$ 存状态 $o,i$ 是从哪两个子树合并的,对于扩展的子树就特殊处理一下
骚操作:枚举一个状态的真子集 : $for(int op=(o-1)&o;op;op=(op-1)&o)$
#include#include #include #include #include #include using namespace std;typedef long long ll;inline int read(){ int x=0,f=1; char ch=getchar(); while(ch<'0'||ch>'9') { if(ch=='-') f=-1; ch=getchar(); } while(ch>='0'&&ch<='9') { x=(x<<1)+(x<<3)+(ch^48); ch=getchar(); } return x*f;}const int N=27,M=2027,INF=1e9+7;int fir[M],from[M<<3],to[M<<3],cntt;void add(int &a,int &b){ from[++cntt]=fir[a]; fir[a]=cntt; to[cntt]=b;}int n,m,K,tot;int val[M],pos[N],id[N][N];int f[M][M];bool inq[M],mp[N][N];struct path { int o1,x1,o2,x2;}fa[M][M];void SPFA(int p){ queue q; for(int i=1;i<=tot;i++) if(f[p][i] f[p][x]+val[v]) { f[p][v]=f[p][x]+val[v]; fa[p][v]=(path){p,x,-1,v};//扩展的子树特殊处理成-1 if(!inq[v]) q.push(v),inq[v]=1; } } }}void dfs(int o,int x){ int o1=fa[o][x].o1,x1=fa[o][x].x1,o2=fa[o][x].o2,x2=fa[o][x].x2; if(!(o1|x1|o2|x2)) return; dfs(o1,x1); if(o2==-1) mp[(x2-1)/m+1][(x2-1)%m+1]=1;//如果是-1则说明此点有志愿者 else dfs(o2,x2);//否则向下一个子树转移}int main(){ //freopen("data.in","r",stdin); //freopen("data.out","w",stdout); memset(f,0x3f,sizeof(f)); n=read(),m=read(); tot=n*m; for(int i=1;i<=n;i++) for(int j=1;j<=m;j++) { id[i][j]=(i-1)*m+j;//把点缩起来 val[id[i][j]]=read(); if(!val[id[i][j]]) pos[++K]=id[i][j]; if(j>1) add(id[i][j-1],id[i][j]),add(id[i][j],id[i][j-1]); if(i>1) add(id[i-1][j],id[i][j]),add(id[i][j],id[i-1][j]); } int mx=(1< f[o^op][j]+f[op][j]-val[j]) { f[o][j]=f[o^op][j]+f[op][j]-val[j]; fa[o][j]=(path){o-op,j,op,j}; } SPFA(o); } int ans=INF,rt=0; for(int i=1;i<=tot;i++) if(f[mx][i]