common/pkg/proxy/proxy.go
2024-12-04 19:30:22 +08:00

109 lines
2.3 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package proxy
import (
"context"
"fmt"
"log/slog"
"math/rand"
"os"
"sync"
"time"
"github.com/golang/glog"
"github.com/timerzz/proxypool/pkg/getter"
"github.com/timerzz/proxypool/pkg/proxy"
"github.com/timerzz/proxypool/pkg/tool"
)
// ProxyPool holds the proxies collected from the configured subscription and
// Clash sources, together with the time of the last refresh.
type ProxyPool struct {
	// m is held while proxies is replaced (Update) or sampled (RandomIterator).
	// NOTE(review): updated is shared state too and should be guarded by m.
	m sync.Mutex
	// proxies is the merged proxy list produced by the most recent Update.
	proxies proxy.ProxyList
	// cfg supplies the source URL lists (Subscribes and Clash) read by Update.
	cfg *Option
	// updated is the time at which the most recent Update finished.
	updated time.Time
}
// InitDefaultProxyPool builds a ProxyPool from the config file named by the
// ProxyConfigEnv environment variable. It falls back to DefaultProxyConfigPath
// when that variable is unset or empty.
//
// The returned pool has already performed one Update (see NewProxyPool). A
// config-loading failure is returned wrapped with %w, so callers can inspect
// the underlying cause with errors.Is / errors.As.
func InitDefaultProxyPool() (*ProxyPool, error) {
	path := os.Getenv(ProxyConfigEnv)
	if path == "" {
		path = DefaultProxyConfigPath
	}
	cfg, err := LoadProxyConfig(path)
	if err != nil {
		return nil, fmt.Errorf("获取代理池配置失败:%w", err)
	}
	return NewProxyPool(cfg), nil
}
// NewProxyPool creates a pool bound to cfg and runs one synchronous Update
// before returning. The returned pool therefore already reflects whatever the
// configured sources yielded at construction time.
func NewProxyPool(cfg *Option) *ProxyPool {
	pool := &ProxyPool{cfg: cfg}
	pool.Update()
	return pool
}
// Status returns the current proxy list and the time of the last refresh.
//
// Both fields are read under p.m because Update may replace them concurrently
// (for example, from CronUpdate). The returned slice is the pool's own list
// rather than a copy, so callers should treat it as read-only.
func (p *ProxyPool) Status() (proxy.ProxyList, time.Time) {
	p.m.Lock()
	defer p.m.Unlock()
	return p.proxies, p.updated
}
// Update refreshes the proxy pool.
//
// It creates one getter per configured subscription URL and per Clash URL.
// A URL whose getter cannot be created is logged and skipped. The proxies
// from all getters are merged via UniqAppendProxyList, and the new list and
// refresh timestamp are then swapped in together under p.m.
//
// The fetches run without holding p.m, so readers are not blocked by slow
// sources; only the short length read and the final swap take the lock.
func (p *ProxyPool) Update() {
	getters := make([]getter.Getter, 0, len(p.cfg.Subscribes)+len(p.cfg.Clash))
	for _, link := range p.cfg.Subscribes {
		gtr, err := getter.NewSubscribe(tool.Options{"url": link})
		if err != nil {
			slog.Warn("创建Subscribe Getter失败", "url", link, "err", err)
			continue
		}
		getters = append(getters, gtr)
	}
	for _, link := range p.cfg.Clash {
		gtr, err := getter.NewClashGetter(tool.Options{"url": link})
		if err != nil {
			slog.Warn("创建Clash Getter失败", "url", link, "err", err)
			continue
		}
		getters = append(getters, gtr)
	}

	// The previous list's length is only a capacity hint. It is read under the
	// lock because another Update may be swapping p.proxies concurrently.
	p.m.Lock()
	capHint := len(p.proxies)
	p.m.Unlock()

	list := make(proxy.ProxyList, 0, capHint)
	for _, gtr := range getters {
		list = list.UniqAppendProxyList(gtr.Get())
	}

	// Report every configured source, not just the subscription ones.
	glog.Infof("代理源共 %d 个: %v %v",
		len(p.cfg.Subscribes)+len(p.cfg.Clash), p.cfg.Subscribes, p.cfg.Clash)
	glog.Infof("获取代理共 %d 个", len(list))

	// Publish the list and its timestamp together so Status never observes a
	// new list paired with a stale time (or vice versa).
	p.m.Lock()
	p.proxies = list
	p.updated = time.Now()
	p.m.Unlock()
}
// CronUpdate calls Update every interval until ctx is cancelled. It blocks,
// so it is normally started in its own goroutine.
//
// A zero or negative interval falls back to 30 minutes. time.NewTicker panics
// on non-positive durations, so negative values must not reach it. The first
// refresh happens one interval after the call; NewProxyPool already performs
// an initial Update.
func (p *ProxyPool) CronUpdate(ctx context.Context, interval time.Duration) {
	const defaultInterval = 30 * time.Minute
	if interval <= 0 {
		interval = defaultInterval
	}
	ticker := time.NewTicker(interval)
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
			p.Update()
		}
	}
}
// RandomIterator returns a function that, on each call, yields a uniformly
// random proxy from the pool's current list, or nil when the list is empty.
//
// The emptiness check and the index selection both happen under p.m. An
// Update that swaps in an empty list between the two can therefore no longer
// cause rand.Intn(0) to panic.
func (p *ProxyPool) RandomIterator() func() proxy.Proxy {
	return func() proxy.Proxy {
		p.m.Lock()
		defer p.m.Unlock()
		n := len(p.proxies)
		if n == 0 {
			return nil
		}
		return p.proxies[rand.Intn(n)]
	}
}