aboutsummaryrefslogtreecommitdiffstats
path: root/internal/embedder/embedder.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/embedder/embedder.go')
-rw-r--r--internal/embedder/embedder.go222
1 files changed, 222 insertions, 0 deletions
diff --git a/internal/embedder/embedder.go b/internal/embedder/embedder.go
new file mode 100644
index 0000000..42f8518
--- /dev/null
+++ b/internal/embedder/embedder.go
@@ -0,0 +1,222 @@
1package embedder
2
3import (
4 "bytes"
5 "context"
6 "encoding/json"
7 "fmt"
8 "net/http"
9 "os"
10)
11
12// Embedder generates embeddings for text
13type Embedder interface {
14 Embed(ctx context.Context, texts []string) ([][]float32, error)
15 Dimensions() int
16}
17
// OllamaEmbedder uses Ollama's embedding API.
type OllamaEmbedder struct {
	baseURL string // Ollama server root, e.g. "http://localhost:11434"
	model   string // embedding model name, e.g. "nomic-embed-text"
	dims    int    // vector length for the chosen model
}
24
25// NewOllamaEmbedder creates an Ollama embedder
26func NewOllamaEmbedder(model string) *OllamaEmbedder {
27 baseURL := os.Getenv("CODEVEC_BASE_URL")
28 if baseURL == "" {
29 baseURL = "http://localhost:11434"
30 }
31 if model == "" {
32 model = "nomic-embed-text"
33 }
34
35 // Model dimensions
36 dims := 768 // nomic-embed-text default
37 switch model {
38 case "mxbai-embed-large":
39 dims = 1024
40 case "all-minilm":
41 dims = 384
42 }
43
44 return &OllamaEmbedder{
45 baseURL: baseURL,
46 model: model,
47 dims: dims,
48 }
49}
50
// Dimensions returns the embedding vector length for the configured model.
func (e *OllamaEmbedder) Dimensions() int {
	return e.dims
}
54
// ollamaRequest is the JSON payload for Ollama's /api/embeddings endpoint.
type ollamaRequest struct {
	Model  string `json:"model"`
	Prompt string `json:"prompt"`
}

// ollamaResponse is the JSON body Ollama returns: one embedding per request.
type ollamaResponse struct {
	Embedding []float32 `json:"embedding"`
}
63
64func (e *OllamaEmbedder) Embed(ctx context.Context, texts []string) ([][]float32, error) {
65 embeddings := make([][]float32, len(texts))
66
67 // Ollama's /api/embeddings takes one prompt at a time
68 for i, text := range texts {
69 req := ollamaRequest{
70 Model: e.model,
71 Prompt: text,
72 }
73
74 body, err := json.Marshal(req)
75 if err != nil {
76 return nil, err
77 }
78
79 httpReq, err := http.NewRequestWithContext(ctx, "POST", e.baseURL+"/api/embeddings", bytes.NewReader(body))
80 if err != nil {
81 return nil, err
82 }
83 httpReq.Header.Set("Content-Type", "application/json")
84
85 resp, err := http.DefaultClient.Do(httpReq)
86 if err != nil {
87 return nil, fmt.Errorf("ollama request failed: %w", err)
88 }
89 defer resp.Body.Close()
90
91 if resp.StatusCode != http.StatusOK {
92 return nil, fmt.Errorf("ollama returned status %d", resp.StatusCode)
93 }
94
95 var result ollamaResponse
96 if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
97 return nil, err
98 }
99
100 embeddings[i] = result.Embedding
101 }
102
103 return embeddings, nil
104}
105
// OpenAIEmbedder uses OpenAI-compatible embedding API.
type OpenAIEmbedder struct {
	baseURL string // API root, e.g. "https://api.openai.com"
	apiKey  string // bearer token from CODEVEC_API_KEY; checked at Embed time
	model   string // embedding model name, e.g. "text-embedding-3-small"
	dims    int    // vector length for the chosen model
}
113
114// NewOpenAIEmbedder creates an OpenAI-compatible embedder
115func NewOpenAIEmbedder(model string) *OpenAIEmbedder {
116 baseURL := os.Getenv("CODEVEC_BASE_URL")
117 if baseURL == "" {
118 baseURL = "https://api.openai.com"
119 }
120 apiKey := os.Getenv("CODEVEC_API_KEY")
121 if model == "" {
122 model = "text-embedding-3-small"
123 }
124
125 dims := 1536 // text-embedding-3-small default
126 switch model {
127 case "text-embedding-3-large":
128 dims = 3072
129 case "text-embedding-ada-002":
130 dims = 1536
131 }
132
133 return &OpenAIEmbedder{
134 baseURL: baseURL,
135 apiKey: apiKey,
136 model: model,
137 dims: dims,
138 }
139}
140
// Dimensions returns the embedding vector length for the configured model.
func (e *OpenAIEmbedder) Dimensions() int {
	return e.dims
}
144
// openaiRequest is the JSON payload for the /v1/embeddings endpoint;
// Input may carry up to a whole batch of texts in one call.
type openaiRequest struct {
	Model string   `json:"model"`
	Input []string `json:"input"`
}

// openaiResponse is the subset of the /v1/embeddings response we consume:
// one Data entry per input, in request order.
type openaiResponse struct {
	Data []struct {
		Embedding []float32 `json:"embedding"`
	} `json:"data"`
}
155
156func (e *OpenAIEmbedder) Embed(ctx context.Context, texts []string) ([][]float32, error) {
157 if e.apiKey == "" {
158 return nil, fmt.Errorf("CODEVEC_API_KEY not set")
159 }
160
161 // Batch in groups of 100
162 const batchSize = 100
163 embeddings := make([][]float32, len(texts))
164
165 for start := 0; start < len(texts); start += batchSize {
166 end := start + batchSize
167 if end > len(texts) {
168 end = len(texts)
169 }
170 batch := texts[start:end]
171
172 req := openaiRequest{
173 Model: e.model,
174 Input: batch,
175 }
176
177 body, err := json.Marshal(req)
178 if err != nil {
179 return nil, err
180 }
181
182 httpReq, err := http.NewRequestWithContext(ctx, "POST", e.baseURL+"/v1/embeddings", bytes.NewReader(body))
183 if err != nil {
184 return nil, err
185 }
186 httpReq.Header.Set("Content-Type", "application/json")
187 httpReq.Header.Set("Authorization", "Bearer "+e.apiKey)
188
189 resp, err := http.DefaultClient.Do(httpReq)
190 if err != nil {
191 return nil, fmt.Errorf("openai request failed: %w", err)
192 }
193 defer resp.Body.Close()
194
195 if resp.StatusCode != http.StatusOK {
196 return nil, fmt.Errorf("openai returned status %d", resp.StatusCode)
197 }
198
199 var result openaiResponse
200 if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
201 return nil, err
202 }
203
204 for i, d := range result.Data {
205 embeddings[start+i] = d.Embedding
206 }
207 }
208
209 return embeddings, nil
210}
211
212// New creates an embedder based on provider name
213func New(provider, model string) (Embedder, error) {
214 switch provider {
215 case "ollama":
216 return NewOllamaEmbedder(model), nil
217 case "openai":
218 return NewOpenAIEmbedder(model), nil
219 default:
220 return nil, fmt.Errorf("unknown provider: %s", provider)
221 }
222}