diff options
Diffstat (limited to 'internal/embedder')
| -rw-r--r-- | internal/embedder/embedder.go | 222 |
1 file changed, 222 insertions, 0 deletions
diff --git a/internal/embedder/embedder.go b/internal/embedder/embedder.go new file mode 100644 index 0000000..42f8518 --- /dev/null +++ b/internal/embedder/embedder.go | |||
| @@ -0,0 +1,222 @@ | |||
| 1 | package embedder | ||
| 2 | |||
| 3 | import ( | ||
| 4 | "bytes" | ||
| 5 | "context" | ||
| 6 | "encoding/json" | ||
| 7 | "fmt" | ||
| 8 | "net/http" | ||
| 9 | "os" | ||
| 10 | ) | ||
| 11 | |||
// Embedder turns batches of text into vector embeddings.
type Embedder interface {
	// Embed returns one embedding per input text, in the same order.
	Embed(ctx context.Context, texts []string) ([][]float32, error)
	// Dimensions reports the length of each embedding vector.
	Dimensions() int
}
| 17 | |||
// OllamaEmbedder generates embeddings via Ollama's /api/embeddings
// endpoint. The server root comes from CODEVEC_BASE_URL
// (default http://localhost:11434).
type OllamaEmbedder struct {
	baseURL string // Ollama server root, e.g. "http://localhost:11434"
	model   string // embedding model name, e.g. "nomic-embed-text"
	dims    int    // embedding vector length for the chosen model
}

// NewOllamaEmbedder creates an Ollama embedder for the given model.
// An empty model selects "nomic-embed-text". Dimensions come from a
// small table of known models and default to 768.
func NewOllamaEmbedder(model string) *OllamaEmbedder {
	baseURL := os.Getenv("CODEVEC_BASE_URL")
	if baseURL == "" {
		baseURL = "http://localhost:11434"
	}
	if model == "" {
		model = "nomic-embed-text"
	}

	// Known model dimensions; unknown models fall back to the
	// nomic-embed-text default of 768.
	dims := 768
	switch model {
	case "mxbai-embed-large":
		dims = 1024
	case "all-minilm":
		dims = 384
	}

	return &OllamaEmbedder{
		baseURL: baseURL,
		model:   model,
		dims:    dims,
	}
}

// Dimensions returns the embedding vector length for the configured model.
func (e *OllamaEmbedder) Dimensions() int {
	return e.dims
}

// ollamaRequest is the JSON payload for POST /api/embeddings.
type ollamaRequest struct {
	Model  string `json:"model"`
	Prompt string `json:"prompt"`
}

// ollamaResponse is the JSON reply from POST /api/embeddings.
type ollamaResponse struct {
	Embedding []float32 `json:"embedding"`
}

// Embed returns one embedding per input text, preserving order.
// Ollama's /api/embeddings accepts a single prompt per call, so each
// text costs one HTTP round trip.
func (e *OllamaEmbedder) Embed(ctx context.Context, texts []string) ([][]float32, error) {
	embeddings := make([][]float32, len(texts))
	for i, text := range texts {
		// One request per text via a helper so the response body is
		// closed as soon as each call finishes, instead of a defer
		// accumulating for every iteration until Embed returns.
		emb, err := e.embedOne(ctx, text)
		if err != nil {
			return nil, err
		}
		embeddings[i] = emb
	}
	return embeddings, nil
}

// embedOne performs a single /api/embeddings call and decodes the result.
func (e *OllamaEmbedder) embedOne(ctx context.Context, text string) ([]float32, error) {
	body, err := json.Marshal(ollamaRequest{Model: e.model, Prompt: text})
	if err != nil {
		return nil, err
	}

	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, e.baseURL+"/api/embeddings", bytes.NewReader(body))
	if err != nil {
		return nil, err
	}
	httpReq.Header.Set("Content-Type", "application/json")

	resp, err := http.DefaultClient.Do(httpReq)
	if err != nil {
		return nil, fmt.Errorf("ollama request failed: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("ollama returned status %d", resp.StatusCode)
	}

	var result ollamaResponse
	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
		return nil, err
	}
	return result.Embedding, nil
}
| 105 | |||
// OpenAIEmbedder generates embeddings via an OpenAI-compatible
// /v1/embeddings endpoint. The base URL comes from CODEVEC_BASE_URL
// (default https://api.openai.com) and the key from CODEVEC_API_KEY.
type OpenAIEmbedder struct {
	baseURL string // API root, e.g. "https://api.openai.com"
	apiKey  string // bearer token; Embed fails fast when empty
	model   string // embedding model name
	dims    int    // embedding vector length for the chosen model
}

// NewOpenAIEmbedder creates an OpenAI-compatible embedder for the given
// model. An empty model selects "text-embedding-3-small". Dimensions
// come from a small table of known models and default to 1536.
func NewOpenAIEmbedder(model string) *OpenAIEmbedder {
	baseURL := os.Getenv("CODEVEC_BASE_URL")
	if baseURL == "" {
		baseURL = "https://api.openai.com"
	}
	apiKey := os.Getenv("CODEVEC_API_KEY")
	if model == "" {
		model = "text-embedding-3-small"
	}

	// Known model dimensions; unknown models fall back to the
	// text-embedding-3-small default of 1536.
	dims := 1536
	switch model {
	case "text-embedding-3-large":
		dims = 3072
	case "text-embedding-ada-002":
		dims = 1536
	}

	return &OpenAIEmbedder{
		baseURL: baseURL,
		apiKey:  apiKey,
		model:   model,
		dims:    dims,
	}
}

// Dimensions returns the embedding vector length for the configured model.
func (e *OpenAIEmbedder) Dimensions() int {
	return e.dims
}

// openaiRequest is the JSON payload for POST /v1/embeddings.
type openaiRequest struct {
	Model string   `json:"model"`
	Input []string `json:"input"`
}

// openaiResponse is the JSON reply from POST /v1/embeddings.
type openaiResponse struct {
	Data []struct {
		Embedding []float32 `json:"embedding"`
	} `json:"data"`
}

// Embed returns one embedding per input text, preserving order.
// Inputs are sent in batches of up to 100 per request.
func (e *OpenAIEmbedder) Embed(ctx context.Context, texts []string) ([][]float32, error) {
	if e.apiKey == "" {
		return nil, fmt.Errorf("CODEVEC_API_KEY not set")
	}

	const batchSize = 100
	embeddings := make([][]float32, len(texts))

	for start := 0; start < len(texts); start += batchSize {
		end := start + batchSize
		if end > len(texts) {
			end = len(texts)
		}
		// Per-batch helper so each response body is closed as soon as
		// the batch finishes, instead of a defer accumulating for
		// every iteration until Embed returns.
		if err := e.embedBatch(ctx, texts[start:end], embeddings[start:end]); err != nil {
			return nil, err
		}
	}

	return embeddings, nil
}

// embedBatch performs a single /v1/embeddings call for batch and writes
// the resulting vectors into out (same length as batch, same order).
func (e *OpenAIEmbedder) embedBatch(ctx context.Context, batch []string, out [][]float32) error {
	body, err := json.Marshal(openaiRequest{Model: e.model, Input: batch})
	if err != nil {
		return err
	}

	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, e.baseURL+"/v1/embeddings", bytes.NewReader(body))
	if err != nil {
		return err
	}
	httpReq.Header.Set("Content-Type", "application/json")
	httpReq.Header.Set("Authorization", "Bearer "+e.apiKey)

	resp, err := http.DefaultClient.Do(httpReq)
	if err != nil {
		return fmt.Errorf("openai request failed: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("openai returned status %d", resp.StatusCode)
	}

	var result openaiResponse
	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
		return err
	}

	// Guard against a short reply (would silently leave nil rows) or a
	// long one (would previously panic with index out of range).
	if len(result.Data) != len(batch) {
		return fmt.Errorf("openai returned %d embeddings for %d inputs", len(result.Data), len(batch))
	}
	for i, d := range result.Data {
		out[i] = d.Embedding
	}
	return nil
}
| 211 | |||
| 212 | // New creates an embedder based on provider name | ||
| 213 | func New(provider, model string) (Embedder, error) { | ||
| 214 | switch provider { | ||
| 215 | case "ollama": | ||
| 216 | return NewOllamaEmbedder(model), nil | ||
| 217 | case "openai": | ||
| 218 | return NewOpenAIEmbedder(model), nil | ||
| 219 | default: | ||
| 220 | return nil, fmt.Errorf("unknown provider: %s", provider) | ||
| 221 | } | ||
| 222 | } | ||
